Skip to content

Commit fcb9721

Browse files
committed
fix minor typos
1 parent 9c011a8 commit fcb9721

3 files changed

+20
-10
lines changed

evaluate_link_prediction.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@
169169
evaluate_neg_edge_sampler=val_neg_edge_sampler,
170170
evaluate_data=val_data,
171171
loss_func=loss_func,
172-
num_neighbors=args.num_neighbors)
172+
num_neighbors=args.num_neighbors,
173+
time_gap=args.time_gap)
173174

174175
new_node_val_losses, new_node_val_metrics = evaluate_model_link_prediction(model_name=args.model_name,
175176
model=model,
@@ -178,7 +179,8 @@
178179
evaluate_neg_edge_sampler=new_node_val_neg_edge_sampler,
179180
evaluate_data=new_node_val_data,
180181
loss_func=loss_func,
181-
num_neighbors=args.num_neighbors)
182+
num_neighbors=args.num_neighbors,
183+
time_gap=args.time_gap)
182184

183185
if args.model_name in ['JODIE', 'DyRep', 'TGN']:
184186
# the memory in the best model has seen the validation edges, we need to backup the memory for new testing nodes
@@ -191,7 +193,8 @@
191193
evaluate_neg_edge_sampler=test_neg_edge_sampler,
192194
evaluate_data=test_data,
193195
loss_func=loss_func,
194-
num_neighbors=args.num_neighbors)
196+
num_neighbors=args.num_neighbors,
197+
time_gap=args.time_gap)
195198

196199
if args.model_name in ['JODIE', 'DyRep', 'TGN']:
197200
# reload validation memory bank for new testing nodes
@@ -204,7 +207,8 @@
204207
evaluate_neg_edge_sampler=new_node_test_neg_edge_sampler,
205208
evaluate_data=new_node_test_data,
206209
loss_func=loss_func,
207-
num_neighbors=args.num_neighbors)
210+
num_neighbors=args.num_neighbors,
211+
time_gap=args.time_gap)
208212
# store the evaluation metrics at the current run
209213
val_metric_dict, new_node_val_metric_dict, test_metric_dict, new_node_test_metric_dict = {}, {}, {}, {}
210214

evaluate_node_classification.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,17 @@
139139
evaluate_idx_data_loader=val_idx_data_loader,
140140
evaluate_data=val_data,
141141
loss_func=loss_func,
142-
num_neighbors=args.num_neighbors)
142+
num_neighbors=args.num_neighbors,
143+
time_gap=args.time_gap)
143144

144145
test_total_loss, test_metrics = evaluate_model_node_classification(model_name=args.model_name,
145146
model=model,
146147
neighbor_sampler=full_neighbor_sampler,
147148
evaluate_idx_data_loader=test_idx_data_loader,
148149
evaluate_data=test_data,
149150
loss_func=loss_func,
150-
num_neighbors=args.num_neighbors)
151+
num_neighbors=args.num_neighbors,
152+
time_gap=args.time_gap)
151153

152154
# store the evaluation metrics at the current run
153155
val_metric_dict, test_metric_dict = {}, {}

train_node_classification.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@
234234
evaluate_idx_data_loader=val_idx_data_loader,
235235
evaluate_data=val_data,
236236
loss_func=loss_func,
237-
num_neighbors=args.num_neighbors)
237+
num_neighbors=args.num_neighbors,
238+
time_gap=args.time_gap)
238239

239240
logger.info(f'Epoch: {epoch + 1}, learning rate: {optimizer.param_groups[0]["lr"]}, train loss: {train_total_loss:.4f}')
240241
for metric_name in train_metrics.keys():
@@ -255,7 +256,8 @@
255256
evaluate_idx_data_loader=test_idx_data_loader,
256257
evaluate_data=test_data,
257258
loss_func=loss_func,
258-
num_neighbors=args.num_neighbors)
259+
num_neighbors=args.num_neighbors,
260+
time_gap=args.time_gap)
259261

260262
if args.model_name in ['JODIE', 'DyRep', 'TGN']:
261263
# reload validation memory bank for saving models
@@ -289,15 +291,17 @@
289291
evaluate_idx_data_loader=val_idx_data_loader,
290292
evaluate_data=val_data,
291293
loss_func=loss_func,
292-
num_neighbors=args.num_neighbors)
294+
num_neighbors=args.num_neighbors,
295+
time_gap=args.time_gap)
293296

294297
test_total_loss, test_metrics = evaluate_model_node_classification(model_name=args.model_name,
295298
model=model,
296299
neighbor_sampler=full_neighbor_sampler,
297300
evaluate_idx_data_loader=test_idx_data_loader,
298301
evaluate_data=test_data,
299302
loss_func=loss_func,
300-
num_neighbors=args.num_neighbors)
303+
num_neighbors=args.num_neighbors,
304+
time_gap=args.time_gap)
301305

302306
# store the evaluation metrics at the current run
303307
val_metric_dict, test_metric_dict = {}, {}

0 commit comments

Comments
 (0)