
Commit 528554f

first commit
0 parents  commit 528554f

37 files changed, +3061 −0 lines changed

.github/workflows/ci.yaml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
name: ci

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

concurrency:
  group: ci-${{github.ref}}-${{github.event.pull_request.number || github.run_number}}
  cancel-in-progress: true

jobs:
  formatting:
    uses: ./.github/workflows/formatting.yaml
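Note that ci.yaml delegates its single job to formatting.yaml as a reusable workflow (via workflow_call), so the same formatting check can also be triggered manually through workflow_dispatch.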

.github/workflows/formatting.yaml

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
name: formatting tests

on:
  workflow_dispatch:
  workflow_call:

concurrency:
  group: unit_tests-${{github.ref}}-${{github.event.pull_request.number || github.run_number}}
  cancel-in-progress: true

jobs:
  formatting:
    runs-on: [ubuntu-latest]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.9
        uses: actions/setup-python@v2
        with:
          python-version: 3.9

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip

      - name: Update black
        run: |
          pip install --upgrade black

      - name: Lint and Format Check with black
        run: |
          black --diff --check .
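To reproduce the check locally, the workflow's steps amount to `pip install --upgrade black` followed by `black --diff --check .` from the repository root.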

README.md

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Training GNNs with AxoNN

## Directory Structure

- **main**: Contains the parallel implementation and the core code for training the model.
- **scripts**: Contains all shell scripts to run experiments and benchmarks. This is where you can find the scripts to set up and execute various experiments.
- **results**: The output files of experiments are stored here, along with plotting scripts to visualize the results.
- **validation**: Contains baselines used for comparison and validation purposes.
- **performance**: Holds the code for performance modeling and benchmarking.

benchmarking/plot.sh

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#!/bin/bash

# Loop through all directories in the current folder
for gpu_dir in */; do
    gpu_dir=${gpu_dir%/}  # Remove trailing slash

    # Check if it is a directory
    if [ -d "$gpu_dir" ]; then

        # Loop through subdirectories 0, 1, and 2
        for sub_dir in 0 1 2; do
            dir="./$gpu_dir/$sub_dir"

            # Check if the subdirectory exists
            if [ -d "$dir" ]; then
                # Run the Python script inside the subdirectory
                (cd "$dir" && python ../../process_comm_model.py)

                # Copy and rename times.npy to the top-level directory
                npy_file="$dir/times.npy"
                if [ -f "$npy_file" ]; then
                    cp "$npy_file" "./times_${gpu_dir}_${sub_dir}.npy"
                    echo "Copied and renamed $npy_file to ./times_${gpu_dir}_${sub_dir}.npy"
                fi
            else
                echo "Subdirectory $dir not found, skipping."
            fi
        done
    fi
done

# Run the final plotting script
python plot_comm_model_avg.py
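For reference, the script appears to assume a layout of <gpu-count>/<run>/ directories (runs 0-2), where each run directory holds the log files that process_comm_model.py parses and where that script leaves a times.npy; the copied times_<gpu-count>_<run>.npy files are what the aggregation script below averages per GPU count.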
benchmarking/plot_comm_model_avg.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import matplotlib.pyplot as plt
import numpy as np
import os


def aggregate_npy_data(directory):
    num_gpus = [4, 8, 16, 32, 64, 128]

    mean_comm_times, std_comm_times, mean_epoch_times, std_epoch_times = [], [], [], []

    # Read all .npy files in the directory
    for i in range(len(num_gpus)):
        comm_times_list, epoch_times_list = [], []
        for file in os.listdir(directory):
            if "_" + str(num_gpus[i]) + "_" in file and file.endswith(".npy"):
                data = np.load(os.path.join(directory, file), allow_pickle=True)
                if len(data) >= 2:
                    comm_times, epoch_times = data
                    comm_times_list.append(comm_times)
                    epoch_times_list.append(epoch_times)
        mean_comm_times.append(np.mean(np.array(comm_times_list), axis=0))
        std_comm_times.append(np.std(np.array(comm_times_list), axis=0))
        mean_epoch_times.append(np.mean(np.array(epoch_times_list), axis=0))
        # Use epoch_times_list here (not comm_times_list, which would report
        # the communication-time spread as the epoch-time spread).
        std_epoch_times.append(np.std(np.array(epoch_times_list), axis=0))

    print((np.array(mean_epoch_times) - np.array(mean_comm_times)).flatten().tolist())

    np.save(
        "scaling_perlmutter_reddit.npy",
        (mean_comm_times, std_comm_times, mean_epoch_times, std_epoch_times),
    )


aggregate_npy_data(os.getcwd())
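
For reference, a minimal sketch of reading the saved aggregate back (assuming the file was produced by the np.save call above); the tuple layout and ordering mirror the script's num_gpus list:

import numpy as np

# Unpack the four per-GPU-count lists saved by aggregate_npy_data().
mean_comm, std_comm, mean_epoch, std_epoch = np.load(
    "scaling_perlmutter_reddit.npy", allow_pickle=True
)

# One entry per GPU count, ordered as [4, 8, 16, 32, 64, 128].
for gpus, comm, epoch in zip([4, 8, 16, 32, 64, 128], mean_comm, mean_epoch):
    print(f"{gpus} GPUs: mean comm times {comm}, mean epoch times {epoch}")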

benchmarking/plot_val_text.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import os
import re
import matplotlib.pyplot as plt

# Directory containing the files
directory = "./"  # Change if needed

# Pattern to extract config name and train loss values
file_pattern = re.compile(r"products_(.+)\.txt")
loss_pattern = re.compile(r"Epoch: \d+, Train Loss: ([\d\.]+)")

# Dictionary to store losses per config
losses_dict = {}

# Iterate over all files in the directory
for filename in os.listdir(directory):
    match = file_pattern.match(filename)
    if match:
        config_name = match.group(1)
        losses = []

        # Read the file and extract losses
        with open(os.path.join(directory, filename), "r") as file:
            for line in file:
                loss_match = loss_pattern.search(line)
                if loss_match:
                    losses.append(float(loss_match.group(1)))

        # Store the extracted losses
        if losses:
            losses_dict[config_name] = losses

# Plot the losses
plt.figure(figsize=(10, 6))
for config, losses in losses_dict.items():
    plt.plot(losses, label=config)

plt.xlabel("Epochs")
plt.ylabel("Train Loss")
plt.title("Training Loss per Configuration")
plt.legend()
plt.grid(True)

plt.savefig("val.png")
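
As a sanity check, the two regexes expect file names like products_<config>.txt and loss lines in the format printed by pyg_serial.py below; the sample values here are hypothetical:

import re

file_pattern = re.compile(r"products_(.+)\.txt")
loss_pattern = re.compile(r"Epoch: \d+, Train Loss: ([\d\.]+)")

print(file_pattern.match("products_X2Y2Z1.txt").group(1))  # -> "X2Y2Z1"
print(loss_pattern.search("Epoch: 003, Train Loss: 1.9812").group(1))  # -> "1.9812"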

benchmarking/process_comm_model.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import re
import os
import numpy as np
import matplotlib.pyplot as plt
from comm_model import compute_config_costs
from comp_model import comp_model


def extract_avg_time(line):
    match = re.search(r"Avg Time: ([0-9]*\.?[0-9]+)", line)
    return float(match.group(1)) if match else 0


def process_log_file(filename):
    comm_times, epoch_times = [], []
    comm_time, comp_time, cross_time = None, None, None

    with open(filename, "r") as file:
        for line in file:
            line = line.strip()

            if (
                "epoch " in line
                and comm_time is not None
                and comp_time is not None
                and cross_time is not None
            ):
                epoch_times.append(comp_time + comm_time + cross_time)
                comm_times.append(comm_time)
                comm_time = 0
                comp_time = 0
                cross_time = 0
            elif "epoch " in line:
                comm_time = 0
                comp_time = 0
                cross_time = 0
            elif comm_time is not None and any(
                keyword in line
                for keyword in ["gather ", "all-reduce ", "reduce-scatter "]
            ):
                comm_time += extract_avg_time(line)
            elif comp_time is not None and any(
                keyword in line
                for keyword in [
                    "AGG = A * H ",
                    # "OUT = AGG * W ",
                    # "GRAD_W = AGG.T * GRAD_OUT ",
                    # "GRAD_AGG = GRAD_OUT * W.T ",
                    "GRAD_H = A.T * GRAD_AGG ",
                ]
            ):
                comp_time += extract_avg_time(line)
            elif cross_time is not None and any(
                keyword in line for keyword in ["cross entropy"]
            ):
                cross_time += extract_avg_time(line)

    if comm_time is not None and comp_time is not None and cross_time is not None:
        # epoch_times.append(comp_time + comm_time + cross_time)
        epoch_times.append(comp_time + comm_time)
        comm_times.append(comm_time)

    return sum(epoch_times[1:]) / (len(epoch_times) - 1), sum(comm_times[1:]) / (
        len(comm_times) - 1
    )


def parse_config(filename):
    match = re.search(r"reddit_X(\d+)Y(\d+)Z(\d+)\.txt", filename)
    if match:
        x, y, z = map(int, match.groups())
        return (x, y, z)
    return None


def main():
    num_configs = len([f for f in os.listdir() if f.endswith(".txt")])

    epoch_times = [0] * num_configs
    comm_times = [0] * num_configs

    num_gpus = None
    for filename in os.listdir():
        if filename.startswith("reddit_") and filename.endswith(".txt"):
            config = parse_config(filename)
            num_gpus = config[0] * config[1] * config[2]

            """
            CONFIG_RANKS = compute_config_costs(
                num_gpus, 232965, [602, 128, 128, 41], "v3", "perlmutter"
            )
            """

            CONFIG_RANKS = comp_model(232965, 114848857, num_gpus, [602, 128, 128])

            sorted_items = sorted(CONFIG_RANKS.items(), key=lambda x: x[1])

            for i in range(len(sorted_items)):
                CONFIG_RANKS[sorted_items[i][0]] = i

            if config and config in CONFIG_RANKS:
                rank = CONFIG_RANKS[config]
                avg_epoch_time, avg_comm_time = process_log_file(filename)

                if avg_comm_time > 0 and avg_epoch_time > 0:
                    rank = 0
                    comm_times[rank] = avg_comm_time
                    epoch_times[rank] = avg_epoch_time

    x_ticks = list(range(len(CONFIG_RANKS)))
    x_labels = list(range(len(CONFIG_RANKS)))

    np.save("times", (comm_times, epoch_times))


if __name__ == "__main__":
    main()
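
For orientation, a small sketch of the inputs this parser assumes: per-configuration log files named reddit_X{x}Y{y}Z{z}.txt (X, Y, Z being the process-grid dimensions whose product main() uses as the GPU count), with timer lines ending in "Avg Time: <seconds>". The sample strings below are hypothetical:

import re

# Hypothetical file name: an X=2, Y=4, Z=1 configuration, i.e. 8 GPUs.
print(re.search(r"reddit_X(\d+)Y(\d+)Z(\d+)\.txt", "reddit_X2Y4Z1.txt").groups())
# -> ('2', '4', '1')

# Hypothetical timer line; it contains "all-reduce ", so its value is
# accumulated into comm_time by process_log_file().
print(re.search(r"Avg Time: ([0-9]*\.?[0-9]+)", "all-reduce  Avg Time: 0.0123").group(1))
# -> '0.0123'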

benchmarking/pyg_serial.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
import os
import math
import torch
import random
import argparse
import numpy as np
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Reddit
from torch_geometric.data.storage import GlobalStorage
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr


torch.serialization.add_safe_globals([GlobalStorage, DataEdgeAttr, DataTensorAttr])


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--download_path", type=str)
    parser.add_argument("--num_epochs", type=int, default=2)
    return parser


def get_dataset(download_path=None):
    # dataset = Reddit(download_path, transform=T.NormalizeFeatures())
    # dataset = PygNodePropPredDataset(name="ogbn-products", root=input_dir, transform=T.NormalizeFeatures())
    # gcn_norm = T.GCNNorm()
    # return (gcn_norm.forward(dataset[0]), dataset.num_classes)
    return torch.load(download_path)


class Net(torch.nn.Module):
    def __init__(self, num_input_features, num_classes):
        super(Net, self).__init__()

        self.conv1 = GCNConv(num_input_features, 128, normalize=False, bias=False)
        self.conv2 = GCNConv(128, 128, normalize=False, bias=False)
        self.conv3 = GCNConv(128, num_classes, normalize=False, bias=False)

        torch.nn.init.kaiming_uniform_(self.conv1.lin.weight, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.conv2.lin.weight, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.conv3.lin.weight, a=math.sqrt(5))

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        return x


def train(model, optimizer, input_features, adj, labels):
    model.train()

    optimizer.zero_grad()

    output = model(input_features, adj)

    loss = F.cross_entropy(output, labels)

    loss.backward()

    optimizer.step()

    return loss


if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    set_seed(args.seed)

    data, num_classes = get_dataset(args.download_path)
    num_input_features = data.x.shape[1]

    data.y = data.y.type(torch.LongTensor)
    data.y = data.y.to(torch.device("cuda"))

    features_local = data.x.to(torch.device("cuda")).requires_grad_()

    model = Net(num_input_features, num_classes).to(torch.device("cuda"))

    optimizer = torch.optim.AdamW(
        list(model.parameters()) + [features_local], lr=3e-3, weight_decay=0
    )

    adj = torch.sparse_coo_tensor(
        data.edge_index, data.edge_weight, (data.x.shape[0], data.x.shape[0])
    )
    adj = adj.to_sparse_csr()
    adj = adj.to(torch.device("cuda"))

    losses = []
    for i in range(args.num_epochs):
        loss = train(model, optimizer, features_local, adj, data.y)
        losses.append(loss.item())
        log = "Epoch: {:03d}, Train Loss: {:.4f}"
        print(log.format(i, loss))
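
The script expects --download_path to point at a preprocessed dataset saved as a (data, num_classes) tuple, with normalized features and GCN edge weights already attached (it reads data.edge_weight when building the sparse adjacency). A minimal sketch of producing such a file for Reddit, following the commented-out lines in get_dataset(); the output path below is hypothetical:

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Reddit

# Download Reddit, normalize features, and precompute GCN-normalized
# edge weights so that data.edge_weight is populated.
dataset = Reddit("./reddit", transform=T.NormalizeFeatures())
data = T.GCNNorm().forward(dataset[0])

# Save the (data, num_classes) tuple that get_dataset() torch.load()s.
torch.save((data, dataset.num_classes), "./reddit_preprocessed.pt")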
