Skip to content

Commit 2b5c9f6

Browse files
committed
convert data indices from torch.tensor to numpy.ndarray
1 parent 4497a62 commit 2b5c9f6

3 files changed

+5
-0
lines changed

evaluate_models_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def evaluate_model_link_prediction(model_name: str, model: nn.Module, neighbor_s
4747
evaluate_losses, evaluate_metrics = [], []
4848
evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
4949
for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
50+
evaluate_data_indices = evaluate_data_indices.numpy()
5051
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
5152
evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
5253
evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices]
@@ -178,6 +179,7 @@ def evaluate_model_node_classification(model_name: str, model: nn.Module, neighb
178179
evaluate_total_loss, evaluate_y_trues, evaluate_y_predicts = 0.0, [], []
179180
evaluate_idx_data_loader_tqdm = tqdm(evaluate_idx_data_loader, ncols=120)
180181
for batch_idx, evaluate_data_indices in enumerate(evaluate_idx_data_loader_tqdm):
182+
evaluate_data_indices = evaluate_data_indices.numpy()
181183
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
182184
evaluate_data.src_node_ids[evaluate_data_indices], evaluate_data.dst_node_ids[evaluate_data_indices], \
183185
evaluate_data.node_interact_times[evaluate_data_indices], evaluate_data.edge_ids[evaluate_data_indices], evaluate_data.labels[evaluate_data_indices]
@@ -305,6 +307,7 @@ def evaluate_edge_bank_link_prediction(args: argparse.Namespace, train_data: Dat
305307
test_idx_data_loader_tqdm = tqdm(test_idx_data_loader, ncols=120)
306308

307309
for batch_idx, test_data_indices in enumerate(test_idx_data_loader_tqdm):
310+
test_data_indices = test_data_indices.numpy()
308311
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times = \
309312
test_data.src_node_ids[test_data_indices], test_data.dst_node_ids[test_data_indices], \
310313
test_data.node_interact_times[test_data_indices]

train_link_prediction.py

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157
train_losses, train_metrics = [], []
158158
train_idx_data_loader_tqdm = tqdm(train_idx_data_loader, ncols=120)
159159
for batch_idx, train_data_indices in enumerate(train_idx_data_loader_tqdm):
160+
train_data_indices = train_data_indices.numpy()
160161
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids = \
161162
train_data.src_node_ids[train_data_indices], train_data.dst_node_ids[train_data_indices], \
162163
train_data.node_interact_times[train_data_indices], train_data.edge_ids[train_data_indices]

train_node_classification.py

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@
164164
train_total_loss, train_y_trues, train_y_predicts = 0.0, [], []
165165
train_idx_data_loader_tqdm = tqdm(train_idx_data_loader, ncols=120)
166166
for batch_idx, train_data_indices in enumerate(train_idx_data_loader_tqdm):
167+
train_data_indices = train_data_indices.numpy()
167168
batch_src_node_ids, batch_dst_node_ids, batch_node_interact_times, batch_edge_ids, batch_labels = \
168169
train_data.src_node_ids[train_data_indices], train_data.dst_node_ids[train_data_indices], train_data.node_interact_times[train_data_indices], \
169170
train_data.edge_ids[train_data_indices], train_data.labels[train_data_indices]

0 commit comments

Comments
 (0)