
Trainer: model analysis on the AOT-compiled JAX program. #1036

Open
wants to merge 1 commit into base: main

Conversation

ds-hwang
Contributor

@ds-hwang ds-hwang commented Mar 5, 2025

This will help researchers estimate HBM usage and computation costs before launching a job, allowing them to determine whether a model is compute-bound or memory-bound.

Introduced aot_model_analysis(), which returns analysis results as a string, making it reusable (e.g., in Jupyter notebooks).
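
To make the mechanism concrete, here is a rough, hypothetical sketch of how such a report can be assembled from public JAX AOT APIs (`lower`/`compile`, `memory_analysis`, `cost_analysis`). This is not the code added by this PR, and the attribute and key names are backend-dependent guesses, so everything is guarded:

```python
import jax
import jax.numpy as jnp


def analysis_summary(fn, *example_args) -> str:
    """Best-effort memory/cost summary for an AOT-compiled `fn` (illustrative only)."""
    compiled = jax.jit(fn).lower(*example_args).compile()
    lines = []

    # memory_analysis() may be unavailable or return None (e.g. on CPU); the
    # attribute names below follow the TPU backend and may not exist elsewhere.
    mem = compiled.memory_analysis() if hasattr(compiled, "memory_analysis") else None
    if mem is not None:
        for name in ("argument_size_in_bytes", "output_size_in_bytes",
                     "temp_size_in_bytes", "generated_code_size_in_bytes"):
            value = getattr(mem, name, None)
            if value is not None:
                lines.append(f"{name}: {value / 1024**3:.2f} GB")

    # cost_analysis() returns a dict of backend-specific metrics (a list of
    # dicts in older JAX versions); only report keys that are present.
    cost = compiled.cost_analysis() if hasattr(compiled, "cost_analysis") else None
    if isinstance(cost, (list, tuple)) and cost:
        cost = cost[0]
    if isinstance(cost, dict):
        for key in ("flops", "bytes accessed", "transcendentals"):
            if key in cost:
                lines.append(f"{key}: {cost[key]:.3e}")
    return "\n".join(lines)


if __name__ == "__main__":
    x = jnp.ones((1024, 1024), jnp.bfloat16)
    print(analysis_summary(lambda a: a @ a, x))
```

The `aot_model_analysis()` introduced here presumably builds a more detailed report (including the hardware-utilization scores below) on top of the same compiled-program hooks.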

In addition, `run_aot_compilation` is changed to support this on CPU. The `run_aot_compilation` tool prints the fuji-1B-v3 model analysis as follows:

```
XLA_FLAGS=--xla_dump_to=/tmp/aot_xla_dump \
python -m axlearn.experiments.run_aot_compilation \
    --module=axlearn.experiments.text.gpt.c4_trainer \
    --config=fuji-1B-v3 \
    --topology=v4-1024 --cpu 1> /tmp/aot_stdout
```
```
======= Memory Analysis ==================================
Input memory: 4465.0 MB / 4.36 GB
Output memory: 4464.8 MB / 4.36 GB
Temp memory: 174977.1 MB / 170.88 GB
Code memory: 0.0 MB / 0.00 GB
Total HBM memory: 183906.9 MB / 179.60 GB
======= Cost Analysis ====================================
FLOPS: 71733280.0 M / 70052.03 G
The number of exp/log/sin/cos ops: 21364.8 M / 20.86 G
The total memory traffic: 1792723.2 MB / 1750.71 GB
  HBM access: 751479.1 MB / 733.87 GB
  L2 cache access: 328740.8 MB / 321.04 GB
  Register usage: 61266.7 MB / 59.83 GB
  Output data transferred: 677251.9 MB / 661.38 GB
Hardware utilization scores
  Tensor Cores / MatMul units: 647.0
  ALU (Arithmetic Logic Unit): 430.0
  Memory Load/Store Units: 144.0
  L1 Cache Operations: 92.0
  L2 Cache Operations: 60.0
  Special Function Units (exp/log/sin/cos): 41.0
  Integer Units (for indexing, loop counters): 16.0
  Branch Divergence (Control Flow Processing): 12.0
  Load Balancing / Dispatch: 10.0
  Texture Units (or Rarely Used Compute Units): 8.0
```
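
One rough way to act on the compute-bound vs. memory-bound question from these numbers is a roofline-style comparison (a back-of-the-envelope heuristic, not part of the tool's output; the peak FLOP/s and bandwidth below are placeholders to replace with your accelerator's datasheet values):

```python
# Roofline-style check using the figures printed above.
flops = 71733280.0e6            # "FLOPS: 71733280.0 M" from the report
hbm_bytes = 751479.1 * 1024**2  # "HBM access: 751479.1 MB" from the report
arithmetic_intensity = flops / hbm_bytes  # ~91 FLOPs per HBM byte here

peak_flops_per_sec = 275e12  # placeholder peak FLOP/s; use your chip's value
hbm_bandwidth = 1.2e12       # placeholder HBM bytes/s; use your chip's value
ridge_point = peak_flops_per_sec / hbm_bandwidth  # ~229 FLOPs/byte with these placeholders

print("compute-bound" if arithmetic_intensity > ridge_point else "memory-bound")
```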

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners March 5, 2025 19:53
@ds-hwang
Contributor Author

ds-hwang commented Mar 5, 2025

@markblee Could you review it? From 1112

```python
if not hasattr(compiled, "memory_analysis"):
    return ""

to_mb_gb = lambda x: f"{x / (1024**2):.1f} MB / {x / (1024**3):.2f} GB"
```
Contributor


I think this could be confusing, since users might think the left number is the usage and the right number is the maximum available (e.g., they might read it as "x MB out of y GB used"). Can we eliminate the MB?

Contributor Author


But for some downstream models (e.g., speech models), GB is too coarse a unit. Let me make it dynamically decide between MB and GB.
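
For concreteness, a dynamic formatter could look roughly like this (the helper name and threshold are made up, not the final implementation):

```python
def _format_bytes(num_bytes: float) -> str:
    """Formats a byte count as GB when large, MB otherwise (threshold is arbitrary)."""
    gb = num_bytes / 1024**3
    if gb >= 1.0:
        return f"{gb:.2f} GB"
    return f"{num_bytes / 1024**2:.1f} MB"


assert _format_bytes(512 * 1024**2) == "512.0 MB"
assert _format_bytes(4 * 1024**3) == "4.00 GB"
```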

```
@@ -91,6 +101,14 @@ def _compile_and_dump_programs(
logging.info("Wrote serialized %s to %s", program_name, serialized_compiled_output_path)
```
Contributor

@apghml apghml Mar 5, 2025


This PR still doesn't deduplicate the memory-printing code in run_aot_compilation.py with the new function you have in trainer.py?
