@@ -47,6 +47,7 @@ def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_s
47
47
evaluate_losses , evaluate_metrics = [], []
48
48
evaluate_idx_data_loader_tqdm = tqdm (evaluate_idx_data_loader , ncols = 120 )
49
49
for batch_idx , evaluate_data_indices in enumerate (evaluate_idx_data_loader_tqdm ):
50
+ evaluate_data_indices = evaluate_data_indices .numpy ()
50
51
batch_src_node_ids , batch_dst_node_ids , batch_node_interact_times , batch_edge_ids = \
51
52
evaluate_data .src_node_ids [evaluate_data_indices ], evaluate_data .dst_node_ids [evaluate_data_indices ], \
52
53
evaluate_data .node_interact_times [evaluate_data_indices ], evaluate_data .edge_ids [evaluate_data_indices ]
@@ -178,6 +179,7 @@ def evaluate_model_node_classification(model_name: str, model: nn.Module, neighb
178
179
evaluate_total_loss , evaluate_y_trues , evaluate_y_predicts = 0.0 , [], []
179
180
evaluate_idx_data_loader_tqdm = tqdm (evaluate_idx_data_loader , ncols = 120 )
180
181
for batch_idx , evaluate_data_indices in enumerate (evaluate_idx_data_loader_tqdm ):
182
+ evaluate_data_indices = evaluate_data_indices .numpy ()
181
183
batch_src_node_ids , batch_dst_node_ids , batch_node_interact_times , batch_edge_ids , batch_labels = \
182
184
evaluate_data .src_node_ids [evaluate_data_indices ], evaluate_data .dst_node_ids [evaluate_data_indices ], \
183
185
evaluate_data .node_interact_times [evaluate_data_indices ], evaluate_data .edge_ids [evaluate_data_indices ], evaluate_data .labels [evaluate_data_indices ]
@@ -305,6 +307,7 @@ def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Dat
305
307
test_idx_data_loader_tqdm = tqdm (test_idx_data_loader , ncols = 120 )
306
308
307
309
for batch_idx , test_data_indices in enumerate (test_idx_data_loader_tqdm ):
310
+ test_data_indices = test_data_indices .numpy ()
308
311
batch_src_node_ids , batch_dst_node_ids , batch_node_interact_times = \
309
312
test_data .src_node_ids [test_data_indices ], test_data .dst_node_ids [test_data_indices ], \
310
313
test_data .node_interact_times [test_data_indices ]
0 commit comments