Skip to content

Commit 995715c

Browse files
committed
chore(preprocess): convenient preprocessing; better statistics display
i) Preprocess all datasets with a single command. ii) Display the dataset statistics as a more concise table.
1 parent 591044e commit 995715c

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

preprocess_data/data_statistics.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
import numpy as np
2+
import pandas as pd
3+
from tabulate import tabulate
4+
5+
6+
def pprint_df(df, tablefmt='psql'):
    """Print ``df`` to stdout as a text table rendered by ``tabulate``.

    :param df: tabular data to print; the caller in this file passes a pandas DataFrame.
    :param tablefmt: name of the ``tabulate`` table format, 'psql' by default.
    """
    # headers='keys' asks tabulate to take the header row from the data's own keys.
    rendered_table = tabulate(df, headers='keys', tablefmt=tablefmt)
    print(rendered_table)
8+
29

310
def collect_dataset_statistics(dataset_names, data_dir='../processed_data'):
    """Collect node/edge counts and feature dimensions for each dataset.

    For every dataset name, ``<data_dir>/<name>/ml_<name>.npy`` (edge features) and
    ``<data_dir>/<name>/ml_<name>_node.npy`` (node features) are opened and only their
    shapes are inspected.

    :param dataset_names: iterable of dataset names (strings).
    :param data_dir: directory that holds one sub-directory per dataset.
    :return: list of dicts, one per dataset, ordered case-insensitively by name, with the
        keys 'name', 'num_nodes', 'node_fea_dim', 'num_edges' and 'edge_fea_dim'.
    """
    records = []
    for dataset_name in sorted(dataset_names, key=str.upper):
        # mmap_mode='r' maps the file instead of reading it, so only the array header is
        # needed to obtain the shape; large feature files are not loaded into memory.
        edge_raw_features = np.load(
            '{}/{}/ml_{}.npy'.format(data_dir, dataset_name, dataset_name), mmap_mode='r')
        node_raw_features = np.load(
            '{}/{}/ml_{}_node.npy'.format(data_dir, dataset_name, dataset_name), mmap_mode='r')
        # One row of each array is excluded from the counts (the "- 1");
        # presumably a padding row at index 0 -- confirm against preprocess_data.py.
        records.append({'name': dataset_name,
                        'num_nodes': node_raw_features.shape[0] - 1,
                        'node_fea_dim': node_raw_features.shape[-1],
                        'num_edges': edge_raw_features.shape[0] - 1,
                        'edge_fea_dim': edge_raw_features.shape[-1]})
    return records


if __name__ == "__main__":
    all_datasets = ['wikipedia', 'reddit', 'mooc', 'lastfm', 'myket', 'enron', 'SocialEvo', 'uci',
                    'Flights', 'CanParl', 'USLegis', 'UNtrade', 'UNvote', 'Contacts']
    # Show every dataset's statistics as one table instead of one printed block per dataset.
    info_df = pd.DataFrame.from_records(collect_dataset_statistics(all_datasets))
    pprint_df(info_df)
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import os
import subprocess
import sys
2+
3+
# Run preprocess_data.py once per dataset so that all datasets are preprocessed by
# launching this single script.
dataset_names = ['wikipedia', 'reddit', 'mooc', 'lastfm', 'enron', 'SocialEvo', 'myket',
                 'uci', 'Flights', 'CanParl', 'USLegis', 'UNtrade', 'UNvote', 'Contacts']
failed_datasets = []
for name in dataset_names:
    # sys.executable keeps the child in the same interpreter/environment as this script,
    # and the argument list avoids going through a shell.
    completed = subprocess.run([sys.executable, 'preprocess_data.py', '--dataset_name', name])
    # Keep going after a failure so the remaining datasets are still processed,
    # but remember it instead of silently dropping the exit status.
    if completed.returncode != 0:
        failed_datasets.append(name)

if failed_datasets:
    sys.exit('preprocessing failed for: {}'.format(', '.join(failed_datasets)))

requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torch>=1.8.1
2+
numpy
3+
pandas
4+
tqdm
5+
tabulate

0 commit comments

Comments
 (0)