Commit 4a2b415

more readme

1 parent 52afbee commit 4a2b415

File tree

4 files changed: +112 −73 lines changed

README.md (+5 −6)
@@ -1,9 +1,8 @@
-# Training GNNs with AxoNN
+# Plexus
 
 ## Directory Structure
 
-- **main**: Contains the parallel implementation and the core code for training the model.
-- **scripts**: Contains all shell scripts to run experiments and benchmarks. This is where you can find the scripts to set up and execute various experiments.
-- **results**: The output files of experiments are stored here, along with plotting scripts to visualize the results.
-- **validation**: Contains baselines used for comparison and validation purposes.
-- **performance**: Holds the code for performance modeling and benchmarking.
+- **benchmarking**: Contains a serial implementation using PyTorch Geometric (PyG) for validation and testing, plus utilities for benchmarking Sparse Matrix-Matrix Multiplication (SpMM), a key operation in GNN computations.
+- **examples**: Demonstrates how to use Plexus to parallelize a GNN model, with scripts for running the parallelized training and utilities for parsing the resulting performance data.
+- **performance**: Houses the performance models for parallel GNN training: communication overhead, computation cost (specifically SpMM), and memory utilization.
+- **plexus**: Contains the core logic of the Plexus framework: the parallel implementation of a Graph Convolutional Network (GCN) layer, plus utilities for dataset preprocessing, efficient data loading, and other components of distributed GNN training.

benchmarking/README.md (+34)

@@ -0,0 +1,34 @@
+# benchmarking
+
+This directory contains files used for validating the parallel implementation and benchmarking key operations.
+
+## Files
+
+- **pyg_serial.py**: A serial implementation of a GNN model using PyTorch Geometric (PyG), used primarily for validating the parallelized version against a baseline. By default, the script trains a model with 3 Graph Convolutional Network (GCN) layers and a hidden dimension of 128 on the ogbn-products dataset.
+
+The script offers several command-line arguments to customize the training process:
+- `--download_path`: Path to the directory where the dataset is stored.
+- `--num_epochs` (optional): Number of training epochs (default is 2).
+- `--seed` (optional): Random seed for reproducible experiments.
+- Other aspects, such as the number of GCN layers and the hidden dimension, can be changed by adjusting the model definition within the script or the dataset loading in the `get_dataset` function.
+
+**Example Usage:**
+```bash
+python pyg_serial.py --download_path <path/to/dataset> --num_epochs 10
+```
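For orientation, here is a minimal sketch of a 3-layer GCN of the shape described above. This is an illustration of the stated defaults, not pyg_serial.py's actual model definition; the class and variable names are ours:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    """Illustrative 3-layer GCN with hidden dimension 128 (assumed shape)."""

    def __init__(self, in_channels, num_classes, hidden=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return self.conv3(x, edge_index)  # class logits; pair with cross-entropy
```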
+
+- **spmm.py**: Tests the performance of Sparse Matrix-Matrix Multiplication (SpMM), a fundamental operation in GNN computations, and can be configured to analyze SpMM performance under various sharding conditions.
+
+It accepts the following command-line arguments:
+- `--pt_file`: Path to a `.pt` file, expected to be the output of preprocessing a dataset with Plexus: a tuple `(data, num_classes)` where `data` is a processed PyG `Data` object. The dimensions of the sparse matrix and the dense feature matrix used in the SpMM benchmark are derived from this data.
+- `--shard_row` (optional): Number of shards for the row dimension (M) of the sparse matrix (A, sized M x K), for investigating the impact of row sharding on SpMM performance (default is 1).
+- `--shard_col` (optional): Number of shards for the column dimension (K) of the sparse matrix (A, sized M x K), which is also the row dimension of the dense feature matrix (F, sized K x N), for investigating sharding along the shared dimension (default is 1).
+- `--shard_x_col` (optional): Number of shards for the column dimension (N) of the dense feature matrix (F, sized K x N), for investigating column sharding of the features (default is 1).
+- `--iterations` (optional): Total number of SpMM iterations to run for the benchmark (default is 25).
+- `--warmup` (optional): Number of initial warmup iterations whose timings are discarded, for more stable measurements (default is 5).
+- Note that each matrix is padded so that its sharded dimensions are divisible by the corresponding shard counts (see the sketch after this list).
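The padding is plain ceil-division rounding. A minimal sketch of the rule, mirroring the computation in spmm.py (the helper name is ours):

```python
def pad_to_multiple(n, shards):
    # Round n up to the nearest multiple of `shards` (ceil division).
    return (n + shards - 1) // shards * shards


# e.g. ogbn-products has 2,449,029 nodes; with --shard_row 4 the row
# dimension is padded to pad_to_multiple(2449029, 4) == 2449032,
# so each row shard holds 612,258 rows.
```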
+
+**Example Usage:**
+```bash
+python spmm.py --pt_file <path/to/data/processed_data.pt>
+```
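Sharding flags can be combined; for instance, appending `--shard_row 2 --shard_col 2` to the command above benchmarks a single shard of a 2x2 partitioning of the adjacency matrix. For reference, a minimal sketch of the `.pt` layout the script expects, assuming (as described above) a `(data, num_classes)` tuple; the path is a placeholder:

```python
import torch

# Load the tuple written by Plexus preprocessing (assumed layout).
data, num_classes = torch.load("processed_data.pt", weights_only=False)

# spmm.py derives its matrix sizes from these fields:
print(data.num_nodes)          # N: rows/cols of the implied adjacency matrix
print(data.x.shape)            # (N, D): dense feature matrix
print(data.edge_index.shape)   # (2, nnz): COO edge list
print(data.edge_weight.shape)  # (nnz,): edge weights
```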

benchmarking/pyg_serial.py (+3 −5)
@@ -31,11 +31,9 @@ def create_parser():
 
 
 def get_dataset(download_path=None):
-    # dataset = Reddit(download_path, transform=T.NormalizeFeatures())
-    # dataset = PygNodePropPredDataset(name="ogbn-products", root=input_dir, transform=T.NormalizeFeatures())
-    # gcn_norm = T.GCNNorm()
-    # return (gcn_norm.forward(dataset[0]), dataset.num_classes)
-    return torch.load(download_path)
+    dataset = PygNodePropPredDataset(name="ogbn-products", root=download_path, transform=T.NormalizeFeatures())
+    gcn_norm = T.GCNNorm()
+    return (gcn_norm.forward(dataset[0]), dataset.num_classes)
 
 
 class Net(torch.nn.Module):

benchmarking/spmm.py (+70 −62)
@@ -1,35 +1,34 @@
-import torch
-import torch.sparse
 import time
+import torch
 import argparse
+import torch.sparse
 
 
-def multiply_partitioned_matrices_padded(
-    partition_row,
-    partition_col_x,
-    partition_x_col,
-    iterations=25,
-    warmup_iterations=5,
-    pt_file="/pscratch/sd/a/aranjan/gnn-env/gnn-datasets/products/processed_products.pt",
+def multiply_sharded_matrices_padded(
+    pt_file,
+    shard_row,
+    shard_col,
+    shard_x_col,
+    iterations,
+    warmup_iterations,
 ):
     """
-    Reads a .pt file, pads and partitions the edge_index and x matrices,
-    multiplies the partitions using sparse matrix multiplication, and measures the time.
+    Reads a .pt file, pads and shards the edge_index and x matrices,
+    multiplies the shards using sparse matrix multiplication, and measures the time.
 
     Args:
-        pt_file (str): Path to the .pt file containing the data dictionary.
-        partition_row (int): The number of partitions for the first dimension of edge_index.
-        partition_col_x (int): The number of partitions for the second dimension of edge_index
+        pt_file (str): Path to the .pt file containing the data.
+        shard_row (int): The number of shards for the first dimension of edge_index.
+        shard_col (int): The number of shards for the second dimension of edge_index
             and the first dimension of x.
-        partition_x_col (int): The number of partitions for the second dimension of x.
+        shard_x_col (int): The number of shards for the second dimension of x.
         iterations (int): The total number of multiplication iterations to perform.
         warmup_iterations (int): The number of initial iterations to exclude from timing.
     """
+
     try:
         data, _ = torch.load(pt_file, weights_only=False)
 
-        print(data.edge_weight[0:100])
-
         edge_index = data.edge_index
         x = data.x
 
@@ -38,74 +37,74 @@ def multiply_partitioned_matrices_padded(
 
         # Calculate padded dimensions for edge_index (implied adjacency matrix)
         padded_rows = (
-            (original_num_nodes + partition_row - 1) // partition_row * partition_row
+            (original_num_nodes + shard_row - 1) // shard_row * shard_row
         )
         padded_cols_x = (
-            (original_num_nodes + partition_col_x - 1)
-            // partition_col_x
-            * partition_col_x
+            (original_num_nodes + shard_col - 1)
+            // shard_col
+            * shard_col
         )
 
         # Calculate padded dimensions for x
         padded_x_rows = (
             padded_cols_x  # Align with the padded columns of the adjacency matrix
         )
 
-        if partition_x_col == 1:
-            padded_x_cols = 128
+        if shard_x_col == 1:
+            padded_x_cols = original_x_cols
         else:
             padded_x_cols = (
-                (original_x_cols + partition_x_col - 1)
-                // partition_x_col
-                * partition_x_col
+                (original_x_cols + shard_x_col - 1)
+                // shard_x_col
+                * shard_x_col
             )
 
-        # Calculate partition sizes for padded dimensions
-        row_partition_size = padded_rows // partition_row
-        col_partition_x_size = padded_cols_x // partition_col_x
-        x_col_partition_size = padded_x_cols // partition_x_col
+        # Calculate shard sizes for padded dimensions
+        row_shard_size = padded_rows // shard_row
+        col_shard_x_size = padded_cols_x // shard_col
+        x_col_shard_size = padded_x_cols // shard_x_col
 
-        # Extract the first partition of padded_adj_indices
+        # Extract the first shard of padded_adj_indices
         start_row = 0
-        end_row = row_partition_size
+        end_row = row_shard_size
         start_col = 0
-        end_col = col_partition_x_size
+        end_col = col_shard_x_size
 
         relevant_edges_mask = (
             (edge_index[0] >= start_row)
             & (edge_index[0] < end_row)
             & (edge_index[1] >= start_col)
             & (edge_index[1] < end_col)
         )
-        partitioned_edge_index = edge_index[:, relevant_edges_mask]
+        sharded_edge_index = edge_index[:, relevant_edges_mask]
 
-        # Adjust the indices in the partitioned_edge_index to be relative to the partition
-        partitioned_edge_index[0] = partitioned_edge_index[0] - start_row
-        partitioned_edge_index[1] = partitioned_edge_index[1] - start_col
+        # Adjust the indices in the sharded_edge_index to be relative to the shard
+        sharded_edge_index[0] = sharded_edge_index[0] - start_row
+        sharded_edge_index[1] = sharded_edge_index[1] - start_col
 
-        # Create the sparse adjacency matrix from the partitioned edge_index
-        partitioned_adj_t = torch.sparse_coo_tensor(
-            partitioned_edge_index,
+        # Create the sparse adjacency matrix from the sharded edge_index
+        sharded_adj_t = torch.sparse_coo_tensor(
+            sharded_edge_index,
             data.edge_weight[relevant_edges_mask],
-            (row_partition_size, col_partition_x_size),
+            (row_shard_size, col_shard_x_size),
         ).to_sparse_csr()
 
         padded_x = torch.zeros((padded_x_rows, padded_x_cols), dtype=x.dtype)
         padded_x[:original_num_nodes, :original_x_cols] = x
 
-        # Extract the first partition of padded_x
+        # Extract the first shard of padded_x
         x_start_row = 0
-        x_end_row = col_partition_x_size
+        x_end_row = col_shard_x_size
         x_start_col = 0
-        x_end_col = x_col_partition_size
-        partitioned_x = padded_x[x_start_row:x_end_row, x_start_col:x_end_col]
+        x_end_col = x_col_shard_size
+        sharded_x = padded_x[x_start_row:x_end_row, x_start_col:x_end_col]
 
-        print("Workload: " + str(partitioned_adj_t._nnz() * partitioned_x.shape[1]))
+        print("Theoretical # of FLOPs (2 * NNZ * D): " + str(2 * sharded_adj_t._nnz() * sharded_x.shape[1]))
 
         # Move tensors to CUDA if available
         if torch.cuda.is_available():
-            partitioned_adj_t = partitioned_adj_t.cuda()
-            partitioned_x = partitioned_x.cuda()
+            sharded_adj_t = sharded_adj_t.cuda()
+            sharded_x = sharded_x.cuda()
 
         # Perform sparse matrix multiplication and measure time
         times = []
@@ -114,7 +113,7 @@ def multiply_partitioned_matrices_padded(
                 torch.cuda.synchronize()
 
                 start_time = time.time()
-                result = torch.sparse.mm(partitioned_adj_t, partitioned_x)
+                result = torch.sparse.mm(sharded_adj_t, sharded_x)
 
                 if torch.cuda.is_available():
                     torch.cuda.synchronize()
@@ -123,7 +122,7 @@ def multiply_partitioned_matrices_padded(
             else:
                 if i >= warmup_iterations:
                     start_time = time.time()
-                result = torch.sparse.mm(partitioned_adj_t, partitioned_x)
+                result = torch.sparse.mm(sharded_adj_t, sharded_x)
                 end_time = time.time()
                 times.append(end_time - start_time)
 
@@ -147,22 +146,30 @@ def multiply_partitioned_matrices_padded(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Multiply partitioned sparse and dense tensors with padding."
+        description="Multiply sharded sparse and dense tensors with padding."
+    )
+    parser.add_argument(
+        "--pt_file",
+        type=str,
+        help="Path to Plexus-processed .pt file containing the data.",
     )
     parser.add_argument(
-        "partition_row",
+        "--shard_row",
         type=int,
-        help="Number of partitions for the first dimension of edge_index.",
+        default=1,
+        help="Number of shards for the first dimension of edge_index.",
     )
     parser.add_argument(
-        "partition_col_x",
+        "--shard_col",
         type=int,
-        help="Number of partitions for the second dimension of edge_index and the first dimension of x.",
+        default=1,
+        help="Number of shards for the second dimension of edge_index and the first dimension of x.",
     )
     parser.add_argument(
-        "partition_x_col",
+        "--shard_x_col",
         type=int,
-        help="Number of partitions for the second dimension of x.",
+        default=1,
+        help="Number of shards for the second dimension of x.",
     )
     parser.add_argument(
         "--iterations",
@@ -176,10 +183,11 @@ def multiply_partitioned_matrices_padded(
 
     args = parser.parse_args()
 
-    multiply_partitioned_matrices_padded(
-        args.partition_row,
-        args.partition_col_x,
-        args.partition_x_col,
+    multiply_sharded_matrices_padded(
+        args.pt_file,
+        args.shard_row,
+        args.shard_col,
+        args.shard_x_col,
         args.iterations,
-        args.warmup,
+        args.warmup
     )
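For readers skimming the diff, the operation being timed reduces to the pattern below: a sparse CSR shard of the adjacency matrix multiplied by a dense feature shard. A minimal self-contained sketch with made-up sizes, not the script itself:

```python
import time
import torch

# Tiny sparse adjacency shard in COO form, converted to CSR as spmm.py does.
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])  # (2, nnz) edge list
values = torch.ones(3)                           # edge weights
adj = torch.sparse_coo_tensor(indices, values, (3, 3)).to_sparse_csr()

x = torch.randn(3, 8)  # dense feature shard (K x N)

start = time.time()
result = torch.sparse.mm(adj, x)  # SpMM: (M x K) @ (K x N)
print(f"shape={tuple(result.shape)} time={time.time() - start:.6f}s")
```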
