Hello, I ran into the same problem as the person above. I changed line 100 of crnn_recognizer.py to `def __init__(self, model_path='/root/zjut/ocr.pytorch/checkpoints/CRNN.pth')`. When I run `python demo.py`, I get the following error:
```
Traceback (most recent call last):
  File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
    run()
  File "/root/.vscode-server/extensions/ms-python.python-2019.11.50794/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
    runpy.run_path(target, run_name='__main__')
  File "/root/anaconda3/lib/python3.6/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/root/anaconda3/lib/python3.6/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/root/anaconda3/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/root/zjut/ocr.pytorch/demo.py", line 10, in <module>
    from ocr import ocr
  File "/root/zjut/ocr.pytorch/ocr.py", line 6, in <module>
    recognizer = PytorchOcr()
  File "/root/zjut/ocr.pytorch/recognize/crnn_recognizer.py", line 111, in __init__
    self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
  File "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 845, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
	Missing key(s) in state_dict: "conv1.weight", "conv1.bias", "conv2.weight", "conv2.bias", "conv3_1.weight", "conv3_1.bias", "bn3.weight", "bn3.bias", "bn3.running_mean", "bn3.running_var", "conv3_2.weight", "conv3_2.bias", "conv4_1.weight", "conv4_1.bias", "bn4.weight", "bn4.bias", "bn4.running_mean", "bn4.running_var", "conv4_2.weight", "conv4_2.bias", "conv5.weight", "conv5.bias", "bn5.weight", "bn5.bias", "bn5.running_mean", "bn5.running_var".
	Unexpected key(s) in state_dict: "cnn.conv0.weight", "cnn.conv0.bias", "cnn.conv1.weight", "cnn.conv1.bias", "cnn.conv2.weight", "cnn.conv2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.bias", "cnn.batchnorm2.running_mean", "cnn.batchnorm2.running_var", "cnn.batchnorm2.num_batches_tracked", "cnn.conv3.weight", "cnn.conv3.bias", "cnn.conv4.weight", "cnn.conv4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.bias", "cnn.batchnorm4.running_mean", "cnn.batchnorm4.running_var", "cnn.batchnorm4.num_batches_tracked", "cnn.conv5.weight", "cnn.conv5.bias", "cnn.conv6.weight", "cnn.conv6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.bias", "cnn.batchnorm6.running_mean", "cnn.batchnorm6.running_var", "cnn.batchnorm6.num_batches_tracked".
	size mismatch for rnn.1.embedding.weight: copying a param with shape torch.Size([5997, 512]) from checkpoint, the shape in current model is torch.Size([5835, 512]).
	size mismatch for rnn.1.embedding.bias: copying a param with shape torch.Size([5997]) from checkpoint, the shape in current model is torch.Size([5835]).
```
The CRNN.pth file here is the one you provided on Baidu Netdisk.
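To see exactly where the checkpoint and the model diverge, the parameter names and shapes can be compared before calling `load_state_dict`. Below is a minimal sketch of such a check. `diff_state_dict_shapes` is a hypothetical helper (not part of ocr.pytorch); it strips the `module.` prefix the same way line 111 of crnn_recognizer.py does and sorts the differences into the same three categories the error reports. The example dictionaries use only a few key names and the embedding shapes from the traceback above; the conv-layer shapes are placeholders, since only the names matter for those keys.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(
    model_shapes: Dict[str, Shape],
    ckpt_shapes: Dict[str, Shape],
    strip_prefix: str = "module.",
) -> Tuple[List[str], List[str], List[Tuple[str, Shape, Shape]]]:
    """Hypothetical helper: compare a model's parameter names/shapes with a checkpoint's.

    Returns (missing, unexpected, mismatched), mirroring the categories
    reported by load_state_dict.
    """
    # Strip the DataParallel-style prefix, as crnn_recognizer.py does.
    ckpt = {k.replace(strip_prefix, ""): v for k, v in ckpt_shapes.items()}
    # Keys the model expects but the checkpoint lacks.
    missing = sorted(k for k in model_shapes if k not in ckpt)
    # Keys present in the checkpoint that the model does not define.
    unexpected = sorted(k for k in ckpt if k not in model_shapes)
    # Keys present in both whose shapes disagree: (name, checkpoint shape, model shape).
    mismatched = sorted(
        (k, ckpt[k], model_shapes[k])
        for k in model_shapes
        if k in ckpt and ckpt[k] != model_shapes[k]
    )
    return missing, unexpected, mismatched


# Illustrative subset of the keys from the traceback; () marks a placeholder shape.
model = {
    "conv1.weight": (),
    "conv1.bias": (),
    "rnn.1.embedding.weight": (5835, 512),
    "rnn.1.embedding.bias": (5835,),
}
ckpt = {
    "cnn.conv0.weight": (),
    "cnn.conv0.bias": (),
    "rnn.1.embedding.weight": (5997, 512),
    "rnn.1.embedding.bias": (5997,),
}

missing, unexpected, mismatched = diff_state_dict_shapes(model, ckpt)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

With PyTorch installed, the real inputs would be `{k: tuple(v.shape) for k, v in self.model.state_dict().items()}` and the same comprehension over `torch.load(model_path, map_location='cpu').items()`. Applied to the full key lists above, a diff like this shows the checkpoint uses a `cnn.convN` / `cnn.batchnormN` layout with a 5997-class output layer, while the current model expects top-level names such as `conv1` and `bn3` and 5835 classes. That suggests the checkpoint was saved from a different model definition or character set than the code being run.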