I wrote a model in PyTorch and saved its `state_dict()` to a `.pth` file. Now I want to rewrite it with TensorLayerX so that other people (using TensorFlow, etc.) can use this model.
My model definition is the same in PyTorch and TensorLayerX, but I can't load the pretrained `.pth` weights in TensorLayerX.
Below is my code (a simple model is used here for clarity; the actual model is more complex).
```python
"""
a_torch.py
"""
import torch
from torch import nn


class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # the activation is stored as self.relu (self.act is never defined)
        return self.relu(self.bn(self.conv(x)))


if __name__ == '__main__':
    a = A()
    torch.save(a.state_dict(), 'a.pth')
```
```python
"""
a_tlx.py
"""
import os

import tensorlayerx as tlx
import torch
from tensorlayerx import nn


class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        # in_channels=3 mirrors nn.Conv2d(3, 16, ...) in the PyTorch model
        self.conv = nn.Conv2d(16, kernel_size=1, in_channels=3, data_format='channels_first')
        self.bn = nn.BatchNorm2d(num_features=16, data_format='channels_first')
        self.relu = nn.activation.ReLU()

    def forward(self, x):
        # the activation is stored as self.relu (self.act is never defined)
        return self.relu(self.bn(self.conv(x)))


def pth2npz(pth_path):
    temp = torch.load(pth_path)  # type(temp) = OrderedDict
    # os.path.splitext also handles paths such as './a.pth'
    tlx.files.save_npz_dict(temp.items(), os.path.splitext(pth_path)[0] + '.npz')


if __name__ == '__main__':
    a = A()
    pth2npz('a.pth')
    tlx.files.load_and_assign_npz_dict('a.npz', a)
```
First run `a_torch.py`, then run `a_tlx.py`.
The error is below.
```
Using PyTorch backend.
Traceback (most recent call last):
  File "test/test_03.py", line 25, in <module>
    tlx.files.load_and_assign_npz_dict('test/a.npz', a)
  File "/home/mchen/anaconda3/envs/kpconv/lib/python3.8/site-packages/tensorlayerx/files/utils.py", line 2208, in load_and_assign_npz_dict
    raise RuntimeError(
RuntimeError: Weights named 'conv.weight' not found in network. Hint: set argument skip=Ture if you want to skip redundant or mismatch weights
```
Then I debugged and read the source code of `tlx.files.load_and_assign_npz_dict()`. I found that TensorLayerX parameter names differ from PyTorch's, which causes a key mismatch when loading the pretrained model.
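To see exactly which names differ, the two key lists can be compared directly. The sketch below is plain Python and assumes you have already collected the PyTorch keys (for example from `torch.load('a.pth').keys()`) and the TensorLayerX parameter names from your model; `diff_keys` is a hypothetical helper, and the TensorLayerX-style names in the demo are made-up placeholders, not the names the library actually generates.

```python
from typing import Iterable, List, Tuple


def diff_keys(src_keys: Iterable[str], dst_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (keys only in the checkpoint, keys only in the model), keeping original order."""
    src = list(src_keys)
    dst = list(dst_keys)
    src_set, dst_set = set(src), set(dst)
    only_src = [k for k in src if k not in dst_set]
    only_dst = [k for k in dst if k not in src_set]
    return only_src, only_dst


if __name__ == "__main__":
    torch_keys = ["conv.weight", "conv.bias", "bn.weight", "bn.bias",
                  "bn.running_mean", "bn.running_var", "bn.num_batches_tracked"]
    # Hypothetical TensorLayerX-style names, for illustration only
    tlx_keys = ["conv.filters", "conv.biases", "bn.gamma", "bn.beta",
                "bn.moving_mean", "bn.moving_var"]
    only_torch, only_tlx = diff_keys(torch_keys, tlx_keys)
    print("only in .pth:", only_torch)
    print("only in tlx :", only_tlx)
```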
In the following two figures, the first shows the parameter names in PyTorch and the second shows the parameter names in TensorLayerX.
The only solution I can think of is to write a key-mapping table by hand, but that is hard for a large model. Is there a simpler solution (same model definition in PyTorch and TensorLayerX, loading a pretrained `.pth` model)? :grin:
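A key map does not necessarily have to be written by hand if both definitions register their layers in the same order. The sketch below pairs PyTorch entries with TensorLayerX parameters by position and checks that each pair has the same shape. It is plain Python and rests on assumptions: (1) you extract the `(name, shape)` pairs from both sides yourself; (2) both frameworks enumerate weights and running statistics in the same order, which may not hold (for example if TensorLayerX lists non-trainable weights separately); (3) `num_batches_tracked` has no TensorLayerX counterpart. `build_key_map` and `apply_key_map` are hypothetical helpers, not TensorLayerX APIs.

```python
from typing import Dict, Mapping, Sequence, Tuple

Shape = Tuple[int, ...]

# Assumption: these PyTorch-only buffers have no TensorLayerX counterpart and are dropped.
TORCH_ONLY_SUFFIXES = ("num_batches_tracked",)


def build_key_map(
    src: Mapping[str, Shape],
    dst: Sequence[Tuple[str, Shape]],
    skip_suffixes: Tuple[str, ...] = TORCH_ONLY_SUFFIXES,
) -> Dict[str, str]:
    """Pair source and target entries by position, raising if counts or shapes disagree."""
    src_items = [(k, tuple(s)) for k, s in src.items() if not k.endswith(skip_suffixes)]
    if len(src_items) != len(dst):
        raise ValueError(f"count mismatch: {len(src_items)} source vs {len(dst)} target entries")
    key_map: Dict[str, str] = {}
    for (s_name, s_shape), (d_name, d_shape) in zip(src_items, dst):
        if s_shape != tuple(d_shape):
            raise ValueError(f"shape mismatch: {s_name} {s_shape} vs {d_name} {tuple(d_shape)}")
        key_map[s_name] = d_name
    return key_map


def apply_key_map(state: Mapping[str, object], key_map: Mapping[str, str]) -> Dict[str, object]:
    """Rename checkpoint entries; entries not in the map (e.g. skipped buffers) are dropped."""
    return {key_map[k]: v for k, v in state.items() if k in key_map}


if __name__ == "__main__":
    # Shapes of the PyTorch checkpoint for model A
    torch_shapes = {
        "conv.weight": (16, 3, 1, 1), "conv.bias": (16,),
        "bn.weight": (16,), "bn.bias": (16,),
        "bn.running_mean": (16,), "bn.running_var": (16,),
        "bn.num_batches_tracked": (),
    }
    # Hypothetical TensorLayerX names and shapes, for illustration only
    tlx_shapes = [
        ("conv.filters", (16, 3, 1, 1)), ("conv.biases", (16,)),
        ("bn.gamma", (16,)), ("bn.beta", (16,)),
        ("bn.moving_mean", (16,)), ("bn.moving_var", (16,)),
    ]
    for src_name, dst_name in build_key_map(torch_shapes, tlx_shapes).items():
        print(f"{src_name} -> {dst_name}")
```

The shape check catches layout differences (if a backend stores convolution kernels as `(kh, kw, in, out)`, the weights would need transposing, not just renaming), but it cannot detect two swapped entries of the same shape, such as a BatchNorm bias and running mean, so the resulting map is worth inspecting once for a large model.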