From b6cfaa31e958e69cb482475ff6122486a4969837 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Sat, 9 Aug 2025 22:27:11 -0500 Subject: [PATCH] [jax_llm_examples/cli] New CLI implementation ; [pyproject.toml] Prepare for meta package ; [README.md] Document new CLI --- README.md | 40 ++++++++ jax_llm_examples/__init__.py | 4 + jax_llm_examples/cli/__init__.py | 0 jax_llm_examples/cli/__main__.py | 156 +++++++++++++++++++++++++++++++ pyproject.toml | 128 +++++++++++++++++++++++++ 5 files changed, 328 insertions(+) create mode 100644 jax_llm_examples/__init__.py create mode 100644 jax_llm_examples/cli/__init__.py create mode 100644 jax_llm_examples/cli/__main__.py diff --git a/README.md b/README.md index f843e64..074614e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,46 @@ Current contents include: * [Kimi K2](kimi_k2/) * [OpenAI GPT OSS](gpt_oss/) +## CLI + +### `python3 -m jax_llm_examples.cli --help` +``` +usage: python3 -m jax_llm_examples [-h] [--version] [-s SEARCH] {ls,run} ... + +A collection of JAX implementations for various Large Language Models. + +positional arguments: + {ls,run} + ls List installed models + run Run specified model. Explicitly calls the main.py as + `if __name__ == "__main__"` + +options: + -h, --help show this help message and exit + --version show program's version number and exit + -s SEARCH, --search SEARCH + Alternative filepath(s) or fully-qualified name (FQN) + to use models from. 
+```
+
+### `python3 -m jax_llm_examples.cli ls --help`
+```
+usage: python3 -m jax_llm_examples ls [-h]
+
+options:
+  -h, --help  show this help message and exit
+```
+
+### `python3 -m jax_llm_examples.cli run --help`
+```
+usage: python3 -m jax_llm_examples run [-h] -n MODEL_NAME
+
+options:
+  -h, --help            show this help message and exit
+  -n MODEL_NAME, --model-name MODEL_NAME
+                        Model name
+```
+
 ---
 
 For multi-host cluster setup and distributed training, see [multi_host_README.md](./multi_host_README.md) and the [tpu_toolkit.sh script](./misc/tpu_toolkit.sh).
diff --git a/jax_llm_examples/__init__.py b/jax_llm_examples/__init__.py
new file mode 100644
index 0000000..b28f429
--- /dev/null
+++ b/jax_llm_examples/__init__.py
@@ -0,0 +1,4 @@
+__version__ = "2025.08.09"
+__description__ = "A collection of JAX implementations for various Large Language Models."
+
+__all__ = ["__description__", "__version__"]
diff --git a/jax_llm_examples/cli/__init__.py b/jax_llm_examples/cli/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/jax_llm_examples/cli/__main__.py b/jax_llm_examples/cli/__main__.py
new file mode 100644
index 0000000..0c35db9
--- /dev/null
+++ b/jax_llm_examples/cli/__main__.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+
+import argparse
+import importlib
+import importlib.util
+import os
+import runpy
+import sys
+
+from jax_llm_examples import __description__, __version__
+
+# Directories that never hold a runnable model; `ls` skips them.
+_EXCLUDED_DIRS = frozenset(
+    (
+        ".git",
+        ".github",
+        ".idea",
+        ".venv",
+        ".vscode",
+        "__pycache__",
+        "build",
+        "jax_llm_examples",
+        "jax_llm_examples.egg-info",
+        "misc",
+    )
+)
+
+
+def get_module(name, package=None):
+    """
+    Find a module and build it from its spec without executing it
+
+    :param name: Module name, absolute or relative
+    :type name: ```str```
+
+    :param package: Anchor package, only needed when `name` is relative
+    :type package: ```None | str```
+
+    :return: the entry in `sys.modules` if already imported, else a new unexecuted module
+    :rtype: ```types.ModuleType```
+    """
+    absolute_name = importlib.util.resolve_name(name, package)
+    if absolute_name in sys.modules:
+        return sys.modules[absolute_name]
+
+    path, parent_module, child_name, spec = None, None, None, None
+    if "." in absolute_name:
+        parent_name, _, child_name = absolute_name.rpartition(".")
+        parent_module = get_module(parent_name)
+        path = parent_module.__spec__.submodule_search_locations
+    for finder in sys.meta_path:
+        spec = finder.find_spec(absolute_name, path)
+        if spec is not None:
+            break
+    if spec is None:
+        raise ModuleNotFoundError(f"No module named {absolute_name!r}", name=absolute_name)
+    module = importlib.util.module_from_spec(spec)
+    if path is not None:
+        setattr(parent_module, child_name, module)
+    return module
+
+
+def filepath_from_module(module):
+    """
+    Resolve the directory that holds a module
+
+    :param module: Fully-qualified module name
+    :type module: ```str```
+
+    :return: directory of the module's file, or its first search location when it has no file
+    :rtype: ```str```
+    """
+    mod = get_module(module)
+    module_file = getattr(mod, "__file__", None)
+    if module_file is not None:
+        return os.path.dirname(module_file)
+    # Namespace packages have no `__file__`; their first search location is the package directory.
+    locations = list(getattr(mod.__spec__, "submodule_search_locations", None) or ())
+    if not locations:
+        raise ModuleNotFoundError(f"Module {module!r} has no filesystem location", name=module)
+    return locations[0]
+
+
+def _build_parser():
+    """
+    Parser builder
+
+    :return: instanceof argparse.ArgumentParser
+    :rtype: ```argparse.ArgumentParser```
+    """
+    parser = argparse.ArgumentParser(prog="python3 -m jax_llm_examples", description=__description__)
+    parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
+    parser.add_argument(
+        "-s",
+        "--search",
+        action="append",
+        default=[os.path.dirname(os.path.dirname(os.path.dirname(__file__)))],
+        help="Alternative filepath(s) or fully-qualified name (FQN) to use models from.",
+    )
+
+    subparsers = parser.add_subparsers(dest="command", required=True)
+    subparsers.add_parser("ls", help="List installed models")
+    run_parser = subparsers.add_parser(
+        "run",
+        help='Run specified model. Explicitly calls the main.py as `if __name__ == "__main__"`',
+    )
+    run_parser.add_argument("-n", "--model-name", help="Model name", required=True)
+
+    return parser
+
+
+def main(cli_argv=None, return_args=False):
+    """
+    Run the CLI parser
+
+    :param cli_argv: CLI arguments. If None uses `sys.argv`.
+    :type cli_argv: ```None | list[str]```
+
+    :param return_args: Primary use is for tests. Returns the args rather than executing anything.
+    :type return_args: ```bool```
+
+    :return: the args if `return_args`, else None
+    :rtype: ```None | Namespace```
+
+    :raises ImportError: if `run` finds no `<model-name>/main.py` under any search path
+    """
+    args = _build_parser().parse_args(args=cli_argv)
+    if return_args:
+        return args
+
+    # A search entry is either a directory or a module name; normalise both to directories.
+    search_dirs = sorted({p if os.path.isdir(p) else filepath_from_module(p) for p in args.search})
+
+    if args.command == "ls":
+        models = sorted(
+            f"- {entry}"
+            for search_dir in search_dirs
+            for entry in os.listdir(search_dir)
+            if entry not in _EXCLUDED_DIRS and os.path.isfile(os.path.join(search_dir, entry, "main.py"))
+        )
+        print("\n".join(models))
+        return None
+    elif args.command == "run":
+        for search_dir in search_dirs:
+            candidate = os.path.join(search_dir, args.model_name, "main.py")
+            if os.path.isfile(candidate):
+                # Stop at the first match so later search paths cannot start a second model.
+                runpy.run_path(candidate, run_name="__main__")
+                return None
+        raise ImportError(f"Could not find model {args.model_name!r} (`<model>/main.py`) in {search_dirs!r}")
+    else:
+        raise NotImplementedError
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index e6c0762..4fce337 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,131 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "jax-llm-examples"
+version = "0.1.0"
+description = "A collection of JAX implementations for various Large Language Models."
+requires-python = ">=3.10"
+readme = "README.md"
+license = { file = "LICENSE" }
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+dependencies = [
+    "jax",
+    "jaxlib",
+    "numpy",
+    "transformers",
+]
+
+[project.optional-dependencies]
+deepseek_r1 = [
+    "etils",
+    "ipykernel",
+    "orbax-checkpoint",
+    "torch",
+    "torchvision>=0.21.0",
+    "tpu-info",
+    "tqdm",
+    "transformers>=4.49.0",
+]
+gpt_oss = [
+    "absl-py",
+    "datasets",
+    "etils",
+    "flatbuffers",
+    "gcsfs",
+    "tensorstore",
+    "torch",
+    "tqdm",
+]
+kimi_k2 = [
+    "etils",
+    "ipykernel",
+    "orbax-checkpoint",
+    "torch",
+    "torchvision>=0.21.0",
+    "tpu-info",
+    "tqdm",
+    "transformers>=4.49.0",
+]
+llama3 = [
+    "datasets",
+    "etils",
+    "gcsfs",
+    "orbax-checkpoint",
+    "torch",
+    "tqdm",
+]
+llama4 = [
+    "datasets",
+    "etils",
+    "gcsfs",
+    "orbax-checkpoint",
+    "torch",
+    "tqdm",
+]
+# qwen3 has the same dependencies as llama3
+qwen3 = [
+    "datasets",
+    "etils",
+    "gcsfs",
+    "orbax-checkpoint",
+    "torch",
+    "tqdm",
+]
+
+# A convenience extra to install dependencies for all models.
+# The CLI uses only the standard library, so there is no `cli` extra to include here.
+all = [
+    "jax-llm-examples[deepseek_r1]",
+    "jax-llm-examples[gpt_oss]",
+    "jax-llm-examples[kimi_k2]",
+    "jax-llm-examples[llama3]",
+    "jax-llm-examples[llama4]",
+    "jax-llm-examples[qwen3]",
+]
+
+# Setuptools configuration to handle the monorepo structure
+[tool.setuptools]
+# Explicitly list all packages and their subpackages to be included in the distribution.
+# Assumes source code is in the directories specified in `package_dir`.
+packages = [
+    "deepseek_r1_jax",
+    "deepseek_r1_jax.third_party",
+    "deepseek_r1_jax.third_party.tokenizer",
+    "gpt_oss_jax",
+    "jax_llm_examples", "jax_llm_examples.cli",  # the parent package must ship too: the CLI imports from it
+    "kimi_k2_jax",
+    "kimi_k2_jax.third_party",
+    "llama3_jax",
+    "llama4_jax",
+    "qwen3_jax",
+]
+
+[tool.setuptools.package-dir]
+# Map the Python package names to their actual directories on disk.
+deepseek_r1_jax = "deepseek_r1_jax/deepseek_r1_jax"
+gpt_oss_jax = "gpt_oss/gpt_oss_jax"
+kimi_k2_jax = "kimi_k2/kimi_k2_jax"
+llama3_jax = "llama3/llama3_jax"
+llama4_jax = "llama4/llama4_jax"
+qwen3_jax = "qwen3/qwen3_jax"
+
+[tool.setuptools.package-data]
+# Include necessary non-python files (configs, tokenizers, etc.) for specific packages
+"deepseek_r1_jax.third_party" = ["*.json"]
+"deepseek_r1_jax.third_party.tokenizer" = ["*.gz", "*.json"]
+"kimi_k2_jax.third_party" = ["*.json"]
+
 [tool.black]
 line-length = 120