I've selected your work as part of a reproducibility study that I am conducting, and I'm having difficulty running your code.
There appears to be a missing dependency on TensorFlow (it is used in this file: https://github.com/zhanglu-cst/ClassKG/blob/e66ad7f9aa89bff57c75adc7bfd5e5063b2958ea/compent/checkpoint.py).
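For reference, this is roughly how I checked which packages were unresolved in my environment before launching the pipeline. It is only a sketch: the helper name `missing_modules` and the package list are my own assumptions, not something taken from the repository's requirements.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed list of top-level packages; adjust it to the project's actual imports.
    print(missing_modules(["torch", "dgl", "tensorflow"]))
```

An empty list means every probed package can be located. It does not prove that the installed versions are compatible.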
Aside from that, I'm getting this error:
Using backend: pytorch
Using backend: pytorch
Using backend: pytorch
Setting up a new session...
/workspace/task/../keyword_sentence/keywords.py:125: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
label_one_hot = torch.tensor(label_one_hot).float()
/workspace/task/../keyword_sentence/keywords.py:126: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
index_one_hot = torch.tensor(index_one_hot).float()
/workspace/task/../keyword_sentence/keywords.py:125: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
label_one_hot = torch.tensor(label_one_hot).float()
/workspace/task/../keyword_sentence/keywords.py:126: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
index_one_hot = torch.tensor(index_one_hot).float()
Traceback (most recent call last):
File "pipeline.py", line 102, in <module>
spawn(main, args = (), nprocs = world_size, join = True)
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
while not context.join():
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/workspace/task/pipeline.py", line 76, in main
res_dict = GCN_trainer.train_model(sentences = voted_sentences,
File "/workspace/task/../Models/Graph_SSL/trainer_gcn.py", line 90, in train_model
self.pretrain_model(dataloader_train.dataset.Large_G, self.model)
File "/workspace/task/../Models/Graph_SSL/trainer_gcn.py", line 46, in pretrain_model
trainer.do_train()
File "/workspace/task/../Models/SSL/trainer_SSL.py", line 60, in do_train
output = self.model(batch['batch_graphs'])
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 930, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/workspace/task/../Models/Graph_SSL/GIN_model.py", line 217, in forward
h = self.ginlayers[i](graphs, h)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/dgl/nn/pytorch/conv/ginconv.py", line 133, in forward
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!
STDOUT:
loading unlabeled sentence:11527
test not exist
class:0, GT count:956.0, Pred count:1175.0
class:1, GT count:858.0, Pred count:931.0
class:2, GT count:81.0, Pred count:395.0
class:3, GT count:5872.0, Pred count:5441.0
class:4, GT count:535.0, Pred count:360.0
rank:1,analyse of keywords:
              precision    recall  f1-score   support
           0       0.76      0.94      0.84       956
           1       0.77      0.84      0.80       858
           2       0.18      0.89      0.30        81
           3       1.00      0.92      0.96      5872
           4       0.80      0.54      0.64       535
    accuracy                           0.89      8302
   macro avg       0.70      0.82      0.71      8302
weighted avg       0.92      0.89      0.90      8302
rank:1,analyse of keywords:, keywords count:25
rank:1,cover:8302, all:11527
rank:1, cover%:0.7202220872733582
2022-06-23 15:52:18,324 NYT_5 INFO: rank:1, rank:1, keywords f1_micro:0.8900264996386414
2022-06-23 15:52:18,324 NYT_5 INFO: rank:1, rank:1, keywords f1_macro:0.7091776614199704
loading unlabeled sentence:11527
test not exist
class:0, GT count:956.0, Pred count:1175.0
class:1, GT count:858.0, Pred count:931.0
class:2, GT count:81.0, Pred count:395.0
class:3, GT count:5872.0, Pred count:5441.0
class:4, GT count:535.0, Pred count:360.0
rank:0,analyse of keywords:
              precision    recall  f1-score   support
           0       0.76      0.94      0.84       956
           1       0.77      0.84      0.80       858
           2       0.18      0.89      0.30        81
           3       1.00      0.92      0.96      5872
           4       0.80      0.54      0.64       535
    accuracy                           0.89      8302
   macro avg       0.70      0.82      0.71      8302
weighted avg       0.92      0.89      0.90      8302
rank:0,analyse of keywords:, keywords count:25
rank:0,cover:8302, all:11527
rank:0, cover%:0.7202220872733582
2022-06-23 15:52:18,695 NYT_5 INFO: rank:0, rank:0, keywords f1_micro:0.8900264996386414
2022-06-23 15:52:18,695 NYT_5 INFO: rank:0, rank:0, keywords f1_macro:0.7091776614199704
2022-06-23 15:52:19,188 NYT_5 INFO: rank:1, iteration:0, start
2022-06-23 15:52:19,542 NYT_5 INFO: rank:0, iteration:0, start
2022-06-23 15:52:22,563 NYT_5 INFO: rank:1, vote generate sentences:8302. total count:11527, cover:0.7202220872733582
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, labels:0, count:1175
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, labels:1, count:931
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, labels:2, count:395
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, labels:3, count:5441
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, labels:4, count:360
2022-06-23 15:52:22,565 NYT_5 INFO: rank:1, build graphs, total number keywords:25
2022-06-23 15:52:22,841 NYT_5 INFO: rank:0, vote generate sentences:8302. total count:11527, cover:0.7202220872733582
2022-06-23 15:52:22,851 NYT_5 INFO: rank:0, labels:0, count:1175
2022-06-23 15:52:22,851 NYT_5 INFO: rank:0, labels:1, count:931
2022-06-23 15:52:22,851 NYT_5 INFO: rank:0, labels:2, count:395
2022-06-23 15:52:22,852 NYT_5 INFO: rank:0, labels:3, count:5441
2022-06-23 15:52:22,852 NYT_5 INFO: rank:0, labels:4, count:360
2022-06-23 15:52:22,864 NYT_5 INFO: rank:0, build graphs, total number keywords:25 edge_number:527
2022-06-23 15:52:27,405 NYT_5 INFO: rank:1, berfor balance, sample number each class:[1175 931 395 5441 360]
2022-06-23 15:52:27,420 NYT_5 INFO: rank:1, after balance, sample number each class:[5441 5441 5441 5441 5441]
2022-06-23 15:52:27,420 NYT_5 INFO: rank:1, build graphs, total number keywords:25
edge_number:527
2022-06-23 15:52:27,545 NYT_5 INFO: rank:0, berfor balance, sample number each class:[1175 931 395 5441 360]
2022-06-23 15:52:27,570 NYT_5 INFO: rank:0, after balance, sample number each class:[5441 5441 5441 5441 5441]
2022-06-23 15:52:27,577 NYT_5 INFO: rank:0, build graphs, total number keywords:25
edge_number:527
2022-06-23 15:52:44,330 NYT_5 INFO: rank:1, start SSL
edge_number:527
2022-06-23 15:52:44,330 NYT_5 INFO: rank:0, start SSL
I'm using nvcr.io/nvidia/pytorch:22.02-py3 as the base Docker image, and I'm trying to test the code on NYT5. Let me know if you need more info.
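In case it helps narrow down the `RuntimeError`, here is a small diagnostic I would use to see which inputs sit on which device just before the failing call. It is a sketch under my own assumptions: `group_by_device` is a hypothetical helper, not part of your code, and it only relies on each object exposing a `.device` attribute (as PyTorch tensors and DGL graphs do).

```python
from collections import defaultdict


def group_by_device(named_objects):
    """Map each device (as a string) to the names of the objects living on it.

    Works with anything exposing a `.device` attribute, e.g. torch tensors.
    """
    groups = defaultdict(list)
    for name, obj in named_objects.items():
        groups[str(obj.device)].append(name)
    return dict(groups)
```

For example, calling `group_by_device({"graphs": graphs, "h": h})` just before the `self.ginlayers[i](graphs, h)` call in `GIN_model.py` should show whether one of the two is still on the CPU. If so, my guess is that moving the batch to the process's GPU with `.to(device)` before the forward pass would resolve it, but I have not verified that against the repository.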