When I run python ssl_graphmodels/pyg_models/train_docs.py --name R8, I get the following error:
building gnn_note_attn_gumbel model...
train and validate: 0/200 | |
Traceback (most recent call last):
File "D:\Code\TextSSL\ssl_graphmodels\pyg_models\train_docs.py", line 126, in <module>
best_model, best_epoch, test_results, all_results, best_preds, labels = train_main(train_loader, val_loader, test_loader, params['patience'])
File "D:\Code\TextSSL\ssl_graphmodels\pyg_models\train_docs.py", line 70, in train_main
train_loss = train(train_loader, training=True)
File "D:\Code\TextSSL\ssl_graphmodels\pyg_models\train_docs.py", line 22, in train
for index, data in enumerate(loader):
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 521, in __next__
data = self._next_data()
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
return self.collate_fn(data)
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\loader\dataloader.py", line 19, in __call__
return Batch.from_data_list(batch, self.follow_batch,
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\data\batch.py", line 68, in from_data_list
batch, slice_dict, inc_dict = collate(
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\data\collate.py", line 84, in collate
value, slices, incs = _collate(attr, values, data_list, stores,
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\data\collate.py", line 133, in _collate
incs = get_incs(key, values, data_list, stores)
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\data\collate.py", line 223, in get_incs
repeats = [
File "D:\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\data\collate.py", line 224, in <listcomp>
data.__inc__(key, value, store)
TypeError: __inc__() takes 3 positional arguments but 4 were given
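The TypeError is raised inside PyG 2.x's collate step, which calls data.__inc__(key, value, store) with three arguments. A plausible cause, assumed here rather than confirmed from the TextSSL source, is a custom Data subclass that still overrides __inc__ with the PyG 1.x two-argument signature __inc__(self, key, value). The sketch below uses plain-Python stand-ins (BaseData, OldStyleData, NewStyleData and get_incs_like_pyg2 are hypothetical names, not PyG API) to reproduce the mismatch and show the widened signature:

```python
class BaseData:
    """Hypothetical, simplified stand-in for torch_geometric.data.Data (PyG 2.x)."""

    def __init__(self, num_nodes):
        self.num_nodes = num_nodes

    def __inc__(self, key, value, *args, **kwargs):
        # Simplified default: offset index attributes by the number of nodes
        return self.num_nodes if "index" in key else 0


class OldStyleData(BaseData):
    """Override written against the PyG 1.x two-argument hook."""

    def __inc__(self, key, value):
        return super().__inc__(key, value)


class NewStyleData(BaseData):
    """Same override, widened to accept the extra `store` argument PyG 2.x passes."""

    def __inc__(self, key, value, *args, **kwargs):
        return super().__inc__(key, value, *args, **kwargs)


def get_incs_like_pyg2(key, values, data_list, stores):
    # Mirrors the shape of the call in torch_geometric/data/collate.py (get_incs)
    return [data.__inc__(key, value, store)
            for value, data, store in zip(values, data_list, stores)]


old = [OldStyleData(3), OldStyleData(4)]
try:
    get_incs_like_pyg2("edge_index", [None, None], old, [None, None])
except TypeError as err:
    # Same kind of error as in the traceback above
    print("PyG 1.x-style override:", err)

new = [NewStyleData(3), NewStyleData(4)]
print(get_incs_like_pyg2("edge_index", [None, None], new, [None, None]))  # [3, 4]
```

If the repository's data class follows this pattern, widening its __inc__ to accept *args, **kwargs and forwarding them to super() is the likely adjustment; a __cat_dim__ override, if present, would need the same change, since PyG 2.x also passes it a store argument.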
My environment:
PyG: 2.0.4
PyTorch: 1.10
CUDA: 11.3
OS: Windows 10