Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,46 @@ Current contents include:
* [Kimi K2](kimi_k2/)
* [OpenAI GPT OSS](gpt_oss/)

## CLI

### `python3 -m jax_llm_examples.cli --help`
```
usage: python3 -m jax_llm_examples [-h] [--version] [-s SEARCH] {ls,run} ...

A collection of JAX implementations for various Large Language Models.

positional arguments:
{ls,run}
ls List installed models
run Run specified model. Explicitly calls the main.py as
`if __name__ == "__main__"`

options:
-h, --help show this help message and exit
--version show program's version number and exit
-s SEARCH, --search SEARCH
Alternative filepath(s) or fully-qualified name (FQN)
to use models from.
```

### `python3 -m jax_llm_examples.cli ls --help`
```
usage: python3 -m jax_llm_examples ls [-h]

options:
-h, --help show this help message and exit
```

### `python3 -m jax_llm_examples.cli run --help`
```
usage: python3 -m jax_llm_examples run [-h] -n MODEL_NAME

options:
-h, --help show this help message and exit
-n MODEL_NAME, --model-name MODEL_NAME
Model name
```

---

For multi-host cluster setup and distributed training, see [multi_host_README.md](./multi_host_README.md) and the [tpu_toolkit.sh script](./misc/tpu_toolkit.sh).
4 changes: 4 additions & 0 deletions jax_llm_examples/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Package version string; the CLI's `--version` flag prints this value.
# NOTE(review): pyproject.toml declares `version = "0.1.0"` -- confirm which of the two is authoritative.
__version__ = "2025.08.09"
# One-line package summary; the CLI passes it to argparse as the parser description.
__description__ = "A collection of JAX implementations for various Large Language Models."

# Names exported by `from jax_llm_examples import *`.
__all__ = ["__description__", "__version__"]
Empty file.
156 changes: 156 additions & 0 deletions jax_llm_examples/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#!/usr/bin/env python3

import argparse
import importlib
import importlib.util
import os
import runpy
import sys

from jax_llm_examples import __description__, __version__


def get_module(name, package=None):
    """
    Resolve a (possibly relative) module name to a module object without executing the module

    A module that is already present in `sys.modules` is returned as-is. Otherwise the finders in
    `sys.meta_path` are queried for a spec and a fresh module object is created from it; that
    object is neither executed nor registered in `sys.modules`.

    :param name: Absolute or relative module name
    :type name: ```str```

    :param package: Anchor package used to resolve a relative `name`
    :type package: ```None | str```

    :return: The cached module, or a new (unexecuted) module built from the located spec
    :rtype: ```types.ModuleType```

    :raises ModuleNotFoundError: if no finder in `sys.meta_path` yields a spec
    """
    fq_name = importlib.util.resolve_name(name, package)
    try:
        return sys.modules[fq_name]
    except KeyError:
        pass

    search_locations = owner = leaf = None
    if "." in fq_name:
        # A submodule is searched for inside its parent package's locations.
        owner_name, _, leaf = fq_name.rpartition(".")
        owner = get_module(owner_name)
        search_locations = owner.__spec__.submodule_search_locations

    for finder in sys.meta_path:
        spec = finder.find_spec(fq_name, search_locations)
        if spec is not None:
            break
    else:
        raise ModuleNotFoundError(f"No module named {fq_name!r}", name=fq_name)

    module = importlib.util.module_from_spec(spec)
    if search_locations is not None:
        # Expose the child as an attribute of its parent module object.
        setattr(owner, leaf, module)
    return module


def filepath_from_module(module):
    """
    Return the directory that contains the file of the named module

    :param module: Module name, resolved through `get_module`
    :type module: ```str```

    :return: Directory portion of the resolved module's `__file__`
    :rtype: ```str```
    """
    resolved = get_module(module)
    return os.path.dirname(resolved.__file__)


def _build_parser():
    """
    Construct the command-line parser with its `ls` and `run` sub-commands

    :return: Fully configured parser
    :rtype: ```argparse.ArgumentParser```
    """
    # Three levels above this file; used as the default location to look for models.
    default_search_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

    # NOTE(review): this module is `jax_llm_examples.cli.__main__`; confirm that the
    # advertised `python3 -m jax_llm_examples` invocation is actually runnable.
    parser = argparse.ArgumentParser(prog="python3 -m jax_llm_examples", description=__description__)
    parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
    parser.add_argument(
        "-s",
        "--search",
        action="append",
        default=[default_search_root],
        help="Alternative filepath(s) or fully-qualified name (FQN) to use models from.",
    )

    # A sub-command is mandatory; the chosen one is stored on the namespace as `command`.
    subparsers = parser.add_subparsers(dest="command", required=True)

    # ls: takes no arguments of its own
    subparsers.add_parser("ls", help="List installed models")

    # run: needs the name of the model to execute
    run_parser = subparsers.add_parser(
        "run",
        help='Run specified model. Explicitly calls the main.py as `if __name__ == "__main__"`',
    )
    run_parser.add_argument("-n", "--model-name", help="Model name", required=True)

    return parser


# Directory names that are never reported as models by `ls`.
_NON_MODEL_DIRS = frozenset(
    (
        ".git",
        ".github",
        ".idea",
        ".venv",
        ".vscode",
        "__pycache__",
        "build",
        "jax_llm_examples",
        "jax_llm_examples.egg-info",
        "misc",
    )
)


def _resolve_search_dirs(search):
    """Map every `--search` entry (directory or module name) to a directory; de-duplicated and sorted."""
    return sorted({entry if os.path.isdir(entry) else filepath_from_module(entry) for entry in search})


def _list_models(search_dirs):
    """Return the sorted names of sub-directories of `search_dirs` that contain a `main.py`."""
    return sorted(
        name
        for search_dir in search_dirs
        for name in os.listdir(search_dir)
        if name not in _NON_MODEL_DIRS and os.path.isfile(os.path.join(search_dir, name, "main.py"))
    )


def main(cli_argv=None, return_args=False):
    """
    Run the CLI parser

    :param cli_argv: CLI arguments. If None uses `sys.argv`.
    :type cli_argv: ```None | list[str]```

    :param return_args: Primary use is for tests. Returns the args rather than executing anything.
    :type return_args: ```bool```

    :return: the args if `return_args`, else None
    :rtype: ```None | Namespace```

    :raises ImportError: if `run` finds no `<search path>/<model name>/main.py`
    :raises ModuleNotFoundError: if a search entry is neither a directory nor a locatable module
    """
    _parser: argparse.ArgumentParser = _build_parser()
    args: argparse.Namespace = _parser.parse_args(args=cli_argv)
    if return_args:
        return args

    if args.command == "ls":
        # Resolve module-name entries to directories *before* listing them, so that
        # both kinds of search entry are scanned the same way.
        print("\n".join(f"- {name}" for name in _list_models(_resolve_search_dirs(args.search))))
        return None

    if args.command == "run":
        search_dirs = _resolve_search_dirs(args.search)
        # Try every search directory in order; the first one holding the model wins.
        for search_dir in search_dirs:
            candidate = os.path.join(search_dir, args.model_name, "main.py")
            if os.path.isfile(candidate):
                runpy.run_path(candidate, run_name="__main__")
                return None
        raise ImportError(
            f"Could not find model {args.model_name!r} "
            f"(expected `<search path>/{args.model_name}/main.py`) in {search_dirs!r}"
        )

    raise NotImplementedError


# Entry point: parse `sys.argv` and dispatch only when this file is executed as a
# script / via `python3 -m`, not when it is imported.
if __name__ == "__main__":
    main()
128 changes: 128 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,131 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "jax-llm-examples"
version = "0.1.0"
description = "A collection of JAX implementations for various Large Language Models."
requires-python = ">=3.10"
readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]

dependencies = [
"jax",
"jaxlib",
"numpy",
"transformers",
]

[project.optional-dependencies]
deepseek_r1 = [
"etils",
"ipykernel",
"orbax-checkpoint",
"torch",
"torchvision>=0.21.0",
"tpu-info",
"tqdm",
"transformers>=4.49.0",
]
gpt_oss = [
"absl-py",
"datasets",
"etils",
"flatbuffers",
"gcsfs",
"tensorstore",
"torch",
"tqdm",
]
kimi_k2 = [
"etils",
"ipykernel",
"orbax-checkpoint",
"torch",
"torchvision>=0.21.0",
"tpu-info",
"tqdm",
"transformers>=4.49.0",
]
llama3 = [
"datasets",
"etils",
"gcsfs",
"orbax-checkpoint",
"torch",
"tqdm",
]
llama4 = [
"datasets",
"etils",
"gcsfs",
"orbax-checkpoint",
"torch",
"tqdm",
]
# qwen3 has the same dependencies as llama3
qwen3 = [
"datasets",
"etils",
"gcsfs",
"orbax-checkpoint",
"torch",
"tqdm",
]

# A convenience extra to install dependencies for all models.
# Only extras that are defined above may be referenced here; there is no `cli`
# extra (the CLI imports nothing beyond the standard library and this package).
all = [
    "jax-llm-examples[deepseek_r1]",
    "jax-llm-examples[gpt_oss]",
    "jax-llm-examples[kimi_k2]",
    "jax-llm-examples[llama3]",
    "jax-llm-examples[llama4]",
    "jax-llm-examples[qwen3]",
]

# Setuptools configuration to handle the monorepo structure
[tool.setuptools]
# Explicitly list all packages and their subpackages to be included in the distribution.
# Assumes source code is in the directories specified in `package_dir`.
packages = [
"deepseek_r1_jax",
"deepseek_r1_jax.third_party",
"deepseek_r1_jax.third_party.tokenizer",
"gpt_oss_jax",
"jax_llm_examples.cli",
"kimi_k2_jax",
"kimi_k2_jax.third_party",
"llama3_jax",
"llama4_jax",
"qwen3_jax",
]

[tool.setuptools.package-dir]
# Map the Python package names to their actual directories on disk.
deepseek_r1_jax = "deepseek_r1_jax/deepseek_r1_jax"
gpt_oss_jax = "gpt_oss/gpt_oss_jax"
kimi_k2_jax = "kimi_k2/kimi_k2_jax"
llama3_jax = "llama3/llama3_jax"
llama4_jax = "llama4/llama4_jax"
qwen3_jax = "qwen3/qwen3_jax"

[tool.setuptools.package-data]
# Include necessary non-python files (configs, tokenizers, etc.) for specific packages
"deepseek_r1_jax.third_party" = ["*.json"]
"deepseek_r1_jax.third_party.tokenizer" = ["*.gz", "*.json"]
"kimi_k2_jax.third_party" = ["*.json"]

[tool.black]
line-length = 120

Expand Down