Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions .github/workflows/wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ jobs:

test-modal-recompute:
# These tests verify that recomputation options do not change the results at all
name: Test recompute - ${{ matrix.recompute.name }} - ${{ matrix.dtype.name }}
name: Recompute - ${{ matrix.recompute.name }} - ${{ matrix.dtype.name }}
needs: deploy-modal
runs-on: ubuntu-latest
if: false
strategy:
fail-fast: false
max-parallel: 3
Expand All @@ -129,7 +130,7 @@ jobs:
args: "--offload-opt-m --offload-opt-v --offload-master"
# While not strictly a recomputation, chunked attention should be bitwise identical, too
- name: "Chunked attention"
args: "--recompute-att --attn-bwd-chunks 4"
args: "--recompute-att --attn-bwd-chunks=2"
dtype:
- name: "BF16"
args: "--matmul-dtype=bf16"
Expand Down Expand Up @@ -158,6 +159,7 @@ jobs:
name: Test fixed - ${{ matrix.config.name }}
needs: deploy-modal
runs-on: ubuntu-latest
if: false
strategy:
fail-fast: false
max-parallel: 3
Expand Down Expand Up @@ -192,11 +194,59 @@ jobs:
- name: Run test on Modal
run: python3 scripts/modal_test_ci.py ${{ matrix.config.args }}

# Multi-GPU (2x L4) test job: exercises weight/gradient sharding in both the
# "recompute" and "fixed" test harnesses through scripts/modal_test_ci.py.
test-modal-multi-gpu:
  name: Test Multi-GPU - ${{ matrix.config.name }}
  #needs:
  # - test-modal-fixed
  # - test-modal-recompute
  needs: deploy-modal

  runs-on: ubuntu-latest
  strategy:
    fail-fast: false
    # Limit concurrent Modal runs to avoid exhausting GPU quota.
    max-parallel: 3
    matrix:
      config:
        # "func" selects the test harness (recompute/fixed) in modal_test_ci.py;
        # "args" are forwarded to that harness.
        - name: "BF16 weight sharding"
          func: "recompute"
          args: "--matmul-dtype bf16 --shard-weights"
        - name: "FP8 + memcpy"
          func: "recompute"
          args: "--matmul-dtype e4m3 --shard-weights --memcpy-all-gather"
        - name: "FP8 + persistent quants"
          func: "recompute"
          args: "--matmul-dtype e4m3 --shard-weights --persistent-quants --offload-quants"
        - name: "Fixed BF16"
          func: "fixed"
          args: "bf16"
        - name: "Fixed FP8"
          func: "fixed"
          args: "e4m3"
        - name: "Fixed BF16 Shard gradient"
          func: "fixed"
          args: "bf16 --shard-gradients"
        - name: "Fixed FP8 Shard gradient"
          func: "fixed"
          args: "e4m3 --shard-gradients"
  steps:
    - name: Checkout code
      uses: actions/checkout@v4

    # Note: No need to download wheel again, it's already in the deployed image

    - name: Install Modal
      run: pip install modal

    - name: Set Modal token
      run: modal token set --token-id ${{ secrets.MODAL_TOKEN_ID }} --token-secret ${{ secrets.MODAL_TOKEN_SECRET }}

    - name: Run test on Modal
      run: python3 scripts/modal_test_ci.py ${{ matrix.config.func }} ${{ matrix.config.args }} --gpus 2

release:
if: github.event_name == 'workflow_dispatch' || startsWith(github.ref, 'refs/heads/release-') || startsWith(github.ref, 'refs/tags/')
needs:
- test-modal-recompute
- test-modal-fixed
- test-modal-multi-gpu

runs-on: ubuntu-latest
permissions:
Expand Down
103 changes: 90 additions & 13 deletions scripts/modal_test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Usage:
modal run modal_test_app.py [-- test args...]
"""
import argparse
import io
import sys
from pathlib import Path
Expand Down Expand Up @@ -68,13 +69,7 @@ def compare_and_create_report(result, expected):
}


@app.function(
gpu="L4",
memory=8192,
timeout=300,
image=image,
)
def run_recompute_test(test_args: list[str]):
def _run_recompute_test(test_args: list[str]):
"""Run recomputation tests on Modal."""
from pyllmq.tests.run import parse_args, run_training
from pyllmq.tests.recompute import disable_recompute
Expand All @@ -89,6 +84,26 @@ def run_recompute_test(test_args: list[str]):
return compare_and_create_report(test_run, ref_run)


@app.function(
    gpu="L4",  # single GPU
    memory=8192,
    timeout=300,
    image=image,
)
def run_recompute_test(test_args: list[str]):
    # Single-GPU Modal entry for the shared recompute test logic; the x2
    # variant below differs only in the GPU count requested from Modal.
    return _run_recompute_test(test_args)


@app.function(
    gpu="L4:2",  # two GPUs, for sharding/multi-GPU configurations
    memory=8192,
    timeout=300,
    image=image,
)
def run_recompute_test_x2(test_args: list[str]):
    # Two-GPU Modal entry for the shared recompute test logic. The caller is
    # responsible for passing matching multi-GPU test args (e.g. --gpus=2).
    return _run_recompute_test(test_args)


def run_with_config(test_args: list[str]):
from pyllmq.tests.run import parse_args, run_training
config = parse_args(test_args)
Expand Down Expand Up @@ -144,6 +159,56 @@ def run_fixed_result_test(dtype: str = "bf16"):
return report


@app.function(
    gpu="L4:2",
    memory=8192,
    timeout=300,
    image=image,
)
def run_fixed_result_test_x2(dtype: str = "bf16", shard_gradients: bool = False):
    """Run the fixed-result training test on Modal with 2 GPUs.

    Builds the training CLI arguments for the requested matmul dtype
    (optionally with gradient sharding), runs training, and compares the
    resulting losses and gradient norms against hard-coded reference values.

    Args:
        dtype: One of "bf16", "e4m3" or "e5m2". "e5m2" selects e4m3 matmuls
            combined with e5m2 gradients.
        shard_gradients: If True, adds --shard-gradients to the run.

    Returns:
        The comparison report dict produced by compare_and_create_report.

    Raises:
        ValueError: If ``dtype`` is not one of the known values.
    """
    from pyllmq.tests.run import RunResult

    print(f"Launching Modal fixed_result test with dtype: {dtype}")

    # "e5m2" is shorthand for e4m3 matmuls with an e5m2 gradient dtype.
    if dtype == "e5m2":
        args = ["--matmul-dtype=e4m3", "--gradient-dtype=e5m2"]
    else:
        args = [f"--matmul-dtype={dtype}"]

    if shard_gradients:
        args += ["--shard-gradients"]

    args += ["--gpus=2"]

    result = run_with_config(args)

    # Reference values recorded from known-good 2-GPU runs; any numerical
    # drift from these fails the test.
    if dtype == "bf16":
        expected = {
            "losses": [3.4119365215301514, 3.394049882888794, 3.4545254707336426, 3.0694894790649414, 3.007321834564209, 3.3855042457580566, 3.368359088897705, 3.421376943588257, 3.1316380500793457, 3.2092301845550537, 3.01995849609375],
            "norms": [5.42860746383667, 5.231578826904297, 5.656546115875244, 4.69525146484375, 4.644282341003418, 5.210570812225342, 5.396310806274414, 4.417316913604736, 4.4374165534973145, 4.28884220123291],
        }
    elif dtype == "e4m3":
        expected = {
            "losses": [3.4303817749023438, 3.43670392036438, 3.483766555786133, 3.0972299575805664, 3.0326924324035645, 3.409470558166504, 3.3872318267822266, 3.4421865940093994, 3.152552843093872, 3.229149341583252, 3.0453014373779297],
            "norms": [5.8067474365234375, 8.371203422546387, 5.1532464027404785, 4.662567615509033, 4.763641834259033, 4.693724632263184, 5.259921073913574, 4.645272731781006, 4.207671165466309, 4.346331596374512]
        }
    elif dtype == "e5m2":
        expected = {
            "losses": [3.4303817749023438, 3.4341166019439697, 3.4837355613708496, 3.09706711769104, 3.0316996574401855, 3.410259962081909, 3.3873462677001953, 3.441790819168091, 3.1511523723602295, 3.2284598350524902, 3.0418832302093506],
            "norms": [5.7736382484436035, 8.317730903625488, 5.149673938751221, 4.641636371612549, 4.685691833496094, 4.650301933288574, 5.228470325469971, 4.605687618255615, 4.183129787445068, 4.276437759399414],
        }
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    report = compare_and_create_report(result, RunResult(**expected))
    if not report["passed"]:
        import json
        import dataclasses
        # this helps with debugging/updating in case of failure
        print(json.dumps(dataclasses.asdict(result)))
    return report


@app.function(
gpu="L4",
memory=8192,
Expand All @@ -166,10 +231,21 @@ def run_torch_compare_step(test_args: list):
}


def _get_gpu_arg(args: tuple[str, ...]) -> int:
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default="1")
parsed_args, _ = parser.parse_known_args(args)
return parsed_args.gpus


@app.local_entrypoint()
def test_recompute(*test_args: str):
print(f"Launching Modal recomputation test with args: {test_args}")
result = run_recompute_test.remote(list(test_args))
gpus = _get_gpu_arg(test_args)
if gpus == 2:
result = run_recompute_test_x2.remote(list(test_args))
else:
result = run_recompute_test.remote(list(test_args))

# Print the comparison report
print("\n" + result["report"])
Expand All @@ -188,15 +264,16 @@ def test_torch_step(*test_args: str):


@app.local_entrypoint()
def test_fixed(dtype: str = "bf16", gpus: int = 1, shard_gradients: bool = False):
    """Local entrypoint: run the fixed-result test on Modal, exit 1 on failure.

    Args:
        dtype: Matmul dtype to test ("bf16", "e4m3", or "e5m2").
        gpus: Number of GPUs (1 or 2); selects the matching deployed function.
        shard_gradients: Enable gradient sharding (2-GPU runs only).
    """
    print(f"Launching Modal test with dtype: {dtype}")
    if gpus == 2:
        result = run_fixed_result_test_x2.remote(dtype, shard_gradients)
    else:
        # NOTE: assert is stripped under `python -O`; acceptable for a CI helper.
        assert not shard_gradients, "shard_gradients is only supported with --gpus 2"
        result = run_fixed_result_test.remote(dtype)

    # Print the comparison report
    print("\n" + result["report"])

    if not result["passed"]:
        sys.exit(1)



37 changes: 30 additions & 7 deletions scripts/modal_test_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,56 @@
Usage:
python modal_test_ci.py [test args...]
"""
import argparse
import sys
import modal


def _get_gpu_arg(args: list[str]) -> tuple[int, list[str]]:
parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default="1")
parsed_args, rest = parser.parse_known_args(args)
return parsed_args.gpus, rest


if __name__ == "__main__":
# Reference the already-deployed app
app = modal.App.lookup("llmq-test", create_if_missing=False)

test_name = sys.argv[1]
gpus, rest = _get_gpu_arg(sys.argv[2:])
test_args_pos = []
test_args_kw = {}

if test_name == "recompute":
# Get the run_recompute_test function from the deployed app
test_fn = modal.Function.from_name("llmq-test", "run_recompute_test")
test_args = sys.argv[2:]
if gpus == 2:
test_fn = modal.Function.from_name("llmq-test", "run_recompute_test_x2")
else:
test_fn = modal.Function.from_name("llmq-test", "run_recompute_test")
test_args_pos = [sys.argv[2:]]
elif test_name == "fixed":
test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test")
test_args = sys.argv[2]
parser = argparse.ArgumentParser()
parser.add_argument("dtype", type=str)
parser.add_argument("--shard-gradient", action="store_true")
parsed_args, rest = parser.parse_known_args(rest)
if gpus == 2:
test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test_x2")
test_args_kw = {"dtype": parsed_args.dtype, "shard_gradients": parsed_args.shard_gradient}
else:
assert not parsed_args.shard_gradient, "shard_gradient only supported for 2 gpus"
test_fn = modal.Function.from_name("llmq-test", "run_fixed_result_test")
test_args_kw = {"dtype": parsed_args.dtype}
elif test_name == "torch-step":
test_fn = modal.Function.from_name("llmq-test", "run_torch_compare_step")
test_args = sys.argv[2:]
test_args_pos = [sys.argv[2:]]
else:
raise RuntimeError(f"Unknown test type {test_name}")

# Get test arguments from command line

print(f"Launching Modal test with args: {test_args}")
result = test_fn.remote(test_args)
print(f"Launching Modal test with args: {test_args_pos}, {test_args_kw}")
result = test_fn.remote(*test_args_pos, **test_args_kw)

# Print the comparison report
print("\n" + result["report"])
Expand Down
Loading