# -*- coding: utf-8 -*-
import argparse
import glob
import io
import os
import pathlib
import threading
import cv2 as cv
import lmdb
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm
# plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese characters correctly
# plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly
# Dataset locations; every other directory is derived from the root.
root_path = pathlib.Path('/root/autodl-tmp/hwdb')
output_path = os.path.join(root_path, 'lmdb')
train_path = os.path.join(root_path, 'train_3755')
val_path = os.path.join(root_path, 'test')

# Class labels, one per line of the character list. The position of a label
# in this list is the class index used when slicing ranges of classes.
with open('../character-3755.txt', 'r', encoding='utf-8') as f:
    characters = [line.strip() for line in f]
def write_cache(env, cache):
    """Store every entry of ``cache`` in ``env`` inside one write transaction.

    Args:
        env: Object whose ``begin(write=True)`` returns a context manager
            yielding a transaction with a ``put(key, value)`` method.
        cache: Mapping from ``str`` keys to values. ``bytes`` values are
            stored unchanged; every other value is encoded with
            ``.encode()`` before being stored.
    """
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            raw_key = key.encode()
            # Image payloads are already bytes; labels and counters are str.
            payload = value if isinstance(value, bytes) else value.encode()
            txn.put(raw_key, payload)
def create_dataset(env, image_path, label, index):
    """Write one class worth of images and labels through ``write_cache``.

    Every readable file produces two records: ``image-%09d`` holding the raw
    file bytes and ``label-%09d`` holding ``label``. Records are numbered
    consecutively starting at ``index + 1`` and are handed to ``write_cache``
    in a single call.

    Args:
        env: Target environment, passed unchanged to ``write_cache``.
        image_path: Sequence of image file paths that share ``label``.
        label: Label stored next to every image.
        index: Number of samples stored so far; numbering continues after it.

    Returns:
        The number of samples actually written. Paths that do not exist are
        reported and skipped without being counted, so a caller that
        accumulates the return value keeps the key sequence free of gaps and
        ends up with an exact total.
    """
    cache = {}
    first = index + 1
    cnt = first
    for image in image_path:
        if not os.path.exists(image):
            print('%s does not exist' % image)
            continue
        with open(image, 'rb') as fs:
            image_bin = fs.read()
        # Image and label share the same running number under separate
        # key prefixes.
        cache['image-%09d' % cnt] = image_bin
        cache['label-%09d' % cnt] = label
        cnt += 1
    if cache:
        write_cache(env, cache)
    return cnt - first
def show_image(samples):
    """Draw the given samples in a 4 x 5 grid of subplots without axes.

    Args:
        samples: Iterable of sequences; only the first element of each
            (the image) is drawn.

    NOTE(review): the grid has 20 cells, so passing more than 20 samples
    presumably makes ``plt.subplot`` fail -- confirm before relying on it.
    """
    plt.figure(figsize=(20, 10))
    for slot, sample in enumerate(samples, start=1):
        picture = sample[0]
        plt.subplot(4, 5, slot)
        plt.imshow(picture)
        # Hide ticks and the frame so that only the image is visible.
        plt.xticks([])
        plt.yticks([])
        plt.axis("off")
    plt.show()
def lmdb_test(root):
    """Open an LMDB dataset read-only and print a summary of its first samples.

    For each of the first five samples (fewer if the dataset is smaller) the
    function prints the stored sample count, the number of image bands and
    the decoded label. It stops early, after printing a message, when a
    record is missing or an image cannot be opened.

    Args:
        root: Directory of the LMDB environment to inspect.
    """
    preview_limit = 5
    env = lmdb.open(
        root,
        max_readers=1,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False)
    if not env:
        print('cannot open lmdb from %s' % root)
        return
    try:
        # One read transaction covers both the counter and the samples.
        with env.begin(write=False) as txn:
            raw_count = txn.get('num-samples'.encode())
            if raw_count is None:
                print('num-samples not found in %s' % root)
                return
            n_samples = int(raw_count)
            for index in range(1, min(n_samples, preview_limit) + 1):
                img_buf = txn.get(('image-%09d' % index).encode())
                if img_buf is None:
                    print('Missing image for %d' % index)
                    return
                try:
                    img = Image.open(io.BytesIO(img_buf))
                except IOError:
                    print('Corrupted image for %d' % index)
                    return
                label_buf = txn.get(('label-%09d' % index).encode())
                if label_buf is None:
                    print('Missing label for %d' % index)
                    return
                label = label_buf.decode('utf-8')
                print(n_samples, len(img.split()), label)
    finally:
        # Release the environment on every exit path, including early returns.
        env.close()
def lmdb_init(directory, out, left, right):
    """Build an LMDB dataset from the class folders ``characters[left:right]``.

    Each selected class is expected to be a sub-directory of ``directory``
    named after the class label and containing ``*.png`` files. The images of
    every class are stored through ``create_dataset`` and the final sample
    count is written under the key ``num-samples``.

    The map size is an estimate: the decoded size of the first image of the
    first class, times the number of images in that class, times the number
    of classes, times two.

    Args:
        directory: Directory that contains one sub-directory per class.
        out: Directory of the LMDB environment; created when missing.
        left: Start (inclusive) of the class slice.
        right: End (exclusive) of the class slice.

    Raises:
        ValueError: If the slice selects no classes, the first class has no
            PNG images, or its first image cannot be decoded. All three are
            needed to estimate the map size.
    """
    entries = characters[left:right]
    if not entries:
        raise ValueError('no classes selected by range [%s:%s)' % (left, right))

    # Estimate the required map size from the first class.
    first_images = glob.glob(os.path.join(directory, entries[0], '*.png'))
    if not first_images:
        raise ValueError(
            'no png images found in %s' % os.path.join(directory, entries[0]))
    # np.fromfile + imdecode is used instead of a path-based read.
    # NOTE(review): presumably to cope with non-ASCII paths -- confirm.
    probe = cv.imdecode(
        np.fromfile(first_images[0], dtype=np.uint8), cv.IMREAD_UNCHANGED)
    if probe is None:
        raise ValueError('cannot decode image %s' % first_images[0])
    # Bytes of one class, then of all classes, doubled as head-room.
    data_size = probe.nbytes * len(first_images)
    total_byte = 2 * data_size * len(entries)

    # exist_ok avoids a check-then-create race between concurrent builders.
    os.makedirs(out, exist_ok=True)
    env = lmdb.open(out, map_size=total_byte)
    n_samples = 0
    try:
        pbar = tqdm(entries)
        for dir_name in pbar:
            image_path = glob.glob(os.path.join(directory, dir_name, '*.png'))
            label = dir_name
            n_samples += create_dataset(env, image_path, label, n_samples)
            pbar.set_description(
                f'character[{left + 1}:{right}]: {label} | nSamples: {n_samples} | total_byte: {total_byte}byte | progressing')
        write_cache(env, {'num-samples': str(n_samples)})
    finally:
        # Close the environment even when a class fails to import.
        env.close()
def begin(mode, left, right, valid=False):
    """Build or inspect the LMDB dataset for one range of classes.

    The dataset directory is ``<output_path>/<mode>_<suffix>``: for ``'train'``
    the suffix is ``right``, for ``'test'`` it is ``right - left``. Any other
    ``mode`` does nothing.

    Args:
        mode: ``'train'`` or ``'test'``.
        left: Start (inclusive) of the class range.
        right: End (exclusive) of the class range.
        valid: When true, print a preview of the existing dataset with
            ``lmdb_test`` instead of building it with ``lmdb_init``.
    """
    if mode == 'train':
        suffix = right
    elif mode == 'test':
        suffix = right - left
    else:
        return

    path = os.path.join(output_path, pathlib.Path(mode + '_' + str(suffix)))
    if valid:
        print(f"show:{valid},path:{path}")
        lmdb_test(path)
        return

    source = train_path if mode == 'train' else val_path
    lmdb_init(source, path, left=left, right=right)
class MyThread(threading.Thread):
    """Thread that calls ``begin`` with the arguments given at construction."""

    def __init__(self, mode, left, right, valid):
        """Remember the arguments that ``run`` forwards to ``begin``."""
        super().__init__()
        self.mode, self.left = mode, left
        self.right, self.valid = right, valid

    def run(self):
        """Thread body: forward the stored arguments to ``begin``."""
        begin(
            mode=self.mode,
            left=self.left,
            right=self.right,
            valid=self.valid,
        )
if __name__ == '__main__':
    # Class ranges of the 3755 classes covered by each dataset
    # (1-based inclusive range = 0-based half-open range):
    #   train_500:  classes 1-500      [1, 500]     = [0, 500)
    #   train_1000: classes 501-1000   [501, 1000]  = [500, 1000)
    #   train_1500: classes 1001-1500  [1001, 1500] = [1000, 1500)
    #   train_2000: classes 1501-2000  [1501, 2000] = [1500, 2000)
    #   train_2755: classes 2001-2755  [2001, 2755] = [2000, 2755)
    #   train_3755: classes 2756-3755  [2756, 3755] = [2755, 3755)
    #   test_1000:  last 1000 classes  [2756, 3755] = [2755, 3755)
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="generate train lmdb")
    parser.add_argument("--test", action="store_true", help="generate test lmdb")
    parser.add_argument("--all", action="store_true", help="generate all lmdb")
    parser.add_argument("--show", action="store_true", help="show result")
    parser.add_argument("--start", type=int, default=0, help="class start from where,default 0")
    parser.add_argument("--end", type=int, default=3755, help="class end from where,default 3755")
    args = parser.parse_args()
    start, end = args.start, args.end
    show = args.show

    # Without an action flag nothing would happen; show the usage instead of
    # exiting silently.
    if not (args.train or args.test or args.all):
        parser.print_help()

    if args.train:
        print(f"args: mode=train, [start:end)=[{start}:{end})")
        begin(mode='train', left=start, right=end, valid=show)
    if args.test:
        print(f"args: mode=test, [start:end)=[{start}:{end})")
        begin(mode='test', left=start, right=end, valid=show)
    if args.all:
        # (mode, first class index, number of classes). Five training sets
        # and one test set; train_3755 from the table above is not part of
        # this plan.
        plan = [
            ('train', 0, 500),
            ('train', 500, 500),
            ('train', 1000, 500),
            ('train', 1500, 500),
            ('train', 2000, 755),
            ('test', 2755, 1000),
        ]
        threads = []
        for mode, first, count in plan:
            if show:
                # Previews run one after another so their output stays in order.
                begin(mode=mode, left=first, right=first + count, valid=show)
            else:
                # Each dataset is built in its own thread and output directory.
                worker = MyThread(mode=mode, left=first, right=first + count, valid=show)
                threads.append(worker)
                worker.start()
        for worker in threads:
            worker.join()