I am currently trying to run the naive gflownet training code in the molecular docking setting by running
python gflownet.py
in the mols directory. I have unzipped the datasets and installed all the requirements, and I have successfully run the model in the toy grid environment.
Looking into it further, the problem seems to occur around line 70 of model_block.py. I tried printing stem_block_batch_idx, but it doesn't look like it can be converted directly to the long type that indexing requires. This is what it prints:
tensor([[-8.4156e-02, -4.2767e-02, -7.2483e-02, -3.3011e-02, -1.1865e-02,
2.0981e-03, 1.3293e-02, -7.3515e-03, -4.1853e-02, 2.1048e-02,
3.8597e-02, -1.5558e-02, 2.1581e-02, 4.9257e-03, 9.5167e-02,
4.0965e-02, 2.0146e-02, -5.5610e-02, -3.5318e-02, -3.1394e-02,
7.2078e-02, 1.8894e-02, -3.0249e-02, 2.9740e-02, 5.6950e-02,
-3.8425e-02, 2.8620e-02, 9.2052e-02, -8.5357e-03, 1.6788e-02,
7.7801e-02, -4.2119e-02, 1.3606e-02, 7.5316e-02, 4.7131e-02,
-4.3429e-03, 1.4157e-04, 2.0939e-02, -2.3499e-02, -6.5888e-02,
-2.8960e-02, 3.1548e-02, -9.2680e-03, 5.4192e-02, -9.6579e-03,
2.0602e-02, 1.8935e-02, 4.1228e-03, -6.3467e-02, 3.6747e-02,
1.4168e-02, -6.1473e-03, -1.9472e-02, -3.3970e-02, -5.7308e-03,
-4.6021e-02, -3.8956e-02, 4.7375e-02, -8.4562e-02, -1.0087e-02,
2.0478e-02, -6.8286e-02, 5.4663e-02, -5.1468e-02, 1.2617e-02,
2.4625e-02, 5.2167e-02, 5.7779e-02, -5.7788e-02, -1.3323e-02,
1.3913e-02, -7.4439e-02, -4.0981e-02, 5.0797e-02, -5.6230e-02,
-5.0963e-02, -5.5488e-02, -2.7339e-02, 1.0469e-02, 3.4695e-02,
-3.2623e-02, 7.6694e-03, -5.8748e-03, 7.0495e-02, -2.2805e-02,
-5.4334e-03, -2.1636e-02, 1.9597e-02, 6.2370e-02, -2.4995e-02,
1.6165e-02, -4.6878e-03, 2.9743e-02, 1.2653e-02, -5.4271e-02,
1.1247e-02, -3.8340e-03, -4.7489e-02, 1.5719e-02, 3.2552e-02,
6.0665e-02, -1.2330e-02, 2.6115e-02, -2.7376e-02, 3.4152e-02,
-1.0086e-02, -2.4257e-02, 3.2202e-02, -3.2659e-02, 8.6094e-02,
-3.1996e-02, 7.8751e-02, 4.5367e-02, -3.8693e-02, -3.6531e-02,
6.7311e-03, 3.2884e-02, -3.2774e-02, -3.8855e-02, 2.8814e-02,
4.3942e-02, -1.3374e-02, 3.0905e-02, -7.0064e-02, -5.7230e-03,
4.5093e-02, 3.8167e-02, -3.0602e-02, -4.0387e-02, -1.5985e-02,
-9.5962e-02, -1.1354e-02, 2.0879e-02, 1.4092e-02, -3.8405e-02,
1.4337e-02, -6.0682e-02, -9.0190e-03, -5.0898e-02, -4.7344e-02,
4.1045e-02, -6.7031e-02, 8.8112e-02, 3.2149e-02, 3.7748e-02,
-4.0757e-02, 1.4378e-02, -1.0749e-01, 6.1679e-02, -6.7268e-03,
-2.7889e-02, -5.9315e-02, -5.5883e-02, -2.6489e-02, 7.3640e-02,
1.8273e-02, -5.2330e-02, -7.7003e-05, 6.8413e-04, -1.4364e-01,
-1.9389e-02, 4.5649e-02, -4.0468e-02, -4.2819e-02, 4.5874e-02,
-1.6481e-02, 1.2627e-02, -8.4941e-02, -3.7458e-02, 2.1359e-02,
-9.2863e-02, -3.4932e-03, 7.1990e-02, 6.2144e-02, 8.1462e-02,
-2.0569e-02, 5.9194e-02, 1.6996e-03, 8.0618e-03, 6.1753e-02,
4.1602e-02, 1.0910e-02, 2.0523e-02, -9.9781e-04, 1.9131e-02,
-1.0267e-02, -9.4474e-02, -3.5725e-02, 9.9953e-03, -4.3195e-02,
-7.9051e-02, -3.1881e-02, 9.2158e-03, -9.6167e-04, -2.7508e-02,
7.1478e-02, -5.4107e-02, 8.0026e-02, -1.8887e-02, 4.6941e-02,
6.5166e-02, 1.2000e-02, 3.9906e-02, -2.8206e-02, 3.7483e-02,
3.5408e-02, -2.5863e-02, 2.3528e-02, 7.1814e-03, 8.0863e-02,
-1.3736e-02, -8.5978e-02, -4.1238e-02, -1.2545e-02, 5.5479e-02,
7.3487e-03, 8.9125e-02, -3.4814e-02, -4.5358e-02, 4.9893e-02,
3.5286e-02, 3.2084e-02, 5.0868e-02, 2.3549e-02, -9.2907e-02,
-6.9315e-03, -1.3088e-02, 8.7066e-02, 1.1554e-02, 1.3771e-02,
-1.7489e-02, -5.2921e-02, 9.2110e-03, 1.6766e-02, 4.8030e-02,
1.4481e-02, 2.9254e-03, 3.5795e-02, 1.0397e-01, -2.0675e-03,
-2.9916e-02, -5.3299e-02, -2.1396e-02, -5.3189e-02, 3.2805e-02,
-2.6538e-03, -2.6352e-02, -1.2823e-02, 6.1972e-02, 5.4822e-02,
4.5579e-02, -3.6638e-02, 8.1013e-03, -5.6014e-02, 1.5187e-02,
-6.5561e-02]], device='cuda:0', dtype=torch.float64,
grad_fn=)
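To check my reading of the printout, here is a minimal sketch, assuming PyTorch. `index_problems` is a hypothetical helper (not part of the gflownet code), and `suspect` is a deterministic stand-in with the same 1x256 shape, float64 dtype and magnitude as the printed tensor, not its real values. It illustrates why such a tensor presumably cannot serve as a batch index (one integer graph id per stem): indexing with a float tensor raises IndexError, and casting with .long() truncates every value in (-1, 1) to 0, so every stem would map to graph 0.

```python
import torch


def index_problems(idx: torch.Tensor, num_graphs: int) -> list:
    """Hypothetical helper: list the reasons `idx` cannot be used as a
    batch index (assumed here to be one integer graph id per stem)."""
    problems = []
    if idx.dtype not in (torch.int32, torch.int64):
        problems.append(f"dtype is {idx.dtype}, expected an integer type")
    if idx.dim() != 1:
        problems.append(f"shape is {tuple(idx.shape)}, expected 1-D")
    if idx.numel() and (idx.min() < 0 or idx.max() >= num_graphs):
        problems.append("values fall outside [0, num_graphs)")
    return problems


# Stand-in for the printed tensor: same shape, dtype and rough magnitude
# (all values in [-0.15, 0.15]); the real values are not reproduced.
suspect = torch.linspace(-0.15, 0.15, 256, dtype=torch.float64).reshape(1, 256)

# An illustrative well-formed index for 3 graphs.
good = torch.tensor([0, 0, 1, 2, 2], dtype=torch.long)

print(index_problems(good, num_graphs=3))     # []
print(index_problems(suspect, num_graphs=3))  # dtype, shape and range problems

# PyTorch refuses to index with a floating-point tensor.
features = torch.zeros(3, 4)
try:
    features[suspect]
except IndexError as err:
    print("IndexError:", err)

# Casting does not rescue it: .long() truncates toward zero, so every
# value in (-1, 1) becomes 0 and all stems would point at graph 0.
print(torch.unique(suspect.long()))  # tensor([0])
```

If the real stem_block_batch_idx fails the same checks, that would suggest a feature tensor is being passed where the batch index is expected, rather than an index that merely needs casting.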
I wonder whether I am running the code correctly. Is this index correct, and if so, do you know what is happening?