Hi, thanks for releasing this project!
I get the following error when running run_a3c.py:
[2017-02-19 18:18:07,137] Making new env: Breakout-v0
Process Process-1:
Traceback (most recent call last):
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "run_a3c.py", line 38, in train
    p, v = local_policy(Variable(torch.from_numpy(o)).unsqueeze(0))
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 210, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ajay/PythonProjects/pytorch_a3c-master/policy.py", line 61, in forward
    x = F.relu(self.head(x))
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/site-packages/torch/nn/modules/module.py", line 210, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 54, in forward
    return self._backend.Linear()(input, self.weight, self.bias)
  File "/home/ajay/anaconda3/envs/pyphi/lib/python3.6/site-packages/torch/nn/_functions/linear.py", line 10, in forward
    output.addmm_(0, 1, input, weight.t())
TypeError: addmm_ received an invalid combination of arguments - got (int, int, torch.ByteTensor, torch.FloatTensor), but expected one of:
* (torch.ByteTensor mat1, torch.ByteTensor mat2)
* (torch.SparseByteTensor mat1, torch.ByteTensor mat2)
* (int beta, torch.ByteTensor mat1, torch.ByteTensor mat2)
* (int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2)
* (int beta, torch.SparseByteTensor mat1, torch.ByteTensor mat2)
* (int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2)
* (int beta, int alpha, torch.ByteTensor mat1, torch.ByteTensor mat2)
* (int beta, int alpha, torch.SparseByteTensor mat1, torch.ByteTensor mat2)
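From the traceback, the input that reaches self.head in policy.py is a torch.ByteTensor while the layer's weight is a torch.FloatTensor, which suggests the Breakout observation from gym (uint8 pixels) goes into the network without a cast. Below is a minimal sketch of what I think is happening, written against a current PyTorch release rather than the Variable API the repo uses; the layer size and the random observation are hypothetical stand-ins. If that's right, line 38 would presumably need something like Variable(torch.from_numpy(o).float()), but I haven't confirmed this against the repo.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for a gym Atari observation: uint8 pixel values.
obs = np.random.randint(0, 256, size=(4,), dtype=np.uint8)

# Hypothetical tiny layer standing in for self.head in policy.py.
head = nn.Linear(4, 2)

# torch.from_numpy keeps the numpy dtype, so this is a uint8 (Byte) tensor.
x_bytes = torch.from_numpy(obs)
try:
    head(x_bytes.unsqueeze(0))
    mismatch_raised = False
except (RuntimeError, TypeError):
    # Current PyTorch rejects uint8 input against float32 weights with a
    # RuntimeError; the older release in the traceback raised TypeError.
    mismatch_raised = True

# Casting to float32 matches the weight dtype; dividing by 255 is an
# optional (assumed) rescaling of pixel values into [0, 1].
x = torch.from_numpy(obs).float().div(255.0).unsqueeze(0)
out = head(x)
print(mismatch_raised, x.dtype, out.shape)
```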
Also, can you tell me where reinforce() is defined? It's called on line 80 as:
a.reinforce(r - v.data.squeeze())
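My unconfirmed guess is that reinforce() isn't defined in this repo at all, but is the method older PyTorch releases provided on Variables returned by stochastic ops such as multinomial(); it has since been removed in favor of torch.distributions. The sketch below shows how I understand the same score-function (REINFORCE) update would be written today, assuming a categorical policy with a value baseline like the r - v term on line 80. The layer sizes, the return r, and the 0.5 value-loss weight are all hypothetical.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical policy and value heads standing in for local_policy.
policy_head = nn.Linear(4, 3)
value_head = nn.Linear(4, 1)

x = torch.randn(1, 4)                          # hypothetical state features
probs = torch.softmax(policy_head(x), dim=-1)  # action probabilities, shape (1, 3)
v = value_head(x)                              # value estimate, shape (1, 1)

dist = Categorical(probs=probs)
a = dist.sample()                              # sampled action, shape (1,)
r = torch.tensor([1.0])                        # hypothetical return

# Advantage uses a detached baseline, mirroring r - v.data.squeeze().
advantage = r - v.detach().squeeze(-1)

# REINFORCE loss: minimizing -log pi(a|s) * advantage ascends the policy gradient.
policy_loss = -(dist.log_prob(a) * advantage).sum()
# Squared-error loss that trains the value baseline (weight 0.5 is an assumption).
value_loss = (r - v.squeeze(-1)).pow(2).sum()

(policy_loss + 0.5 * value_loss).backward()
print(policy_head.weight.grad.shape, value_head.weight.grad.shape)
```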
Thanks for your help!