diff --git a/.github/workflows/e2e-nvidia-l40s-x4-sdk.yml b/.github/workflows/e2e-nvidia-l40s-x4-sdk.yml
index f7d79ab7..780ffcb6 100644
--- a/.github/workflows/e2e-nvidia-l40s-x4-sdk.yml
+++ b/.github/workflows/e2e-nvidia-l40s-x4-sdk.yml
@@ -3,15 +3,317 @@
name: E2E (NVIDIA L40S x4) SDK Test
on:
+ pull_request:
+ branches:
+ - "main"
+ schedule:
+ - cron: '0 16 * * *' # Runs at 4PM UTC every day
workflow_dispatch:
inputs:
pr_or_branch:
description: 'pull request number or branch name'
required: true
default: 'main'
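+# Allow only one run of this workflow per PR (or branch) at a time; newer pushes cancel any in-progress run.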
+concurrency:
+ group: ${{ github.workflow }}-${{ github.event.number || github.ref }}
+ cancel-in-progress: true
+
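+# Use a TMPDIR on the runner's larger volume; the default /tmp on the CI AMI may be too small for model downloads and checkpoints.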
+env:
+ TMPDIR: /home/tmp
+
jobs:
- noop:
+ start-large-ec2-runner:
+ runs-on: ubuntu-latest
+ outputs:
+ label: ${{ steps.launch-ec2-instance-with-fallback.outputs.label }}
+ ec2-instance-id: ${{ steps.launch-ec2-instance-with-fallback.outputs.ec2-instance-id }}
+ ec2-instance-region: ${{ steps.launch-ec2-instance-with-fallback.outputs.ec2-instance-region }}
+ steps:
+ - name: Checkout "launch-ec2-runner-with-fallback" in-house CI action
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ repository: instructlab/ci-actions
+ # clone the "ci-actions" repo to a local directory called "ci-actions", instead of overwriting the current WORKDIR contents
+ path: ci-actions
+ ref: release-v0.1
+ sparse-checkout: |
+ actions/launch-ec2-runner-with-fallback
+
+ - name: Launch EC2 Runner with Fallback
+ id: launch-ec2-instance-with-fallback
+ uses: ./ci-actions/actions/launch-ec2-runner-with-fallback
+ env:
+ TMPDIR: "/tmp"
+ with:
+ aws_access_key_id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ github_token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+ regions_config: >
+ [
+ {
+ "region": "us-east-2",
+ "subnets": {
+ "us-east-2a": "${{ vars.SUBNET_US_EAST_2A }}",
+ "us-east-2b": "${{ vars.SUBNET_US_EAST_2B }}",
+ "us-east-2c": "${{ vars.SUBNET_US_EAST_2C }}"
+ },
+ "ec2-ami": "${{ vars.AWS_EC2_AMI_US_EAST_2 }}",
+ "security-group-id": "${{ vars.SECURITY_GROUP_ID_US_EAST_2 }}"
+ },
+ {
+ "region": "us-east-1",
+ "subnets": {
+ "us-east-1a": "${{ vars.SUBNET_US_EAST_1A }}",
+ "us-east-1b": "${{ vars.SUBNET_US_EAST_1B }}",
+ "us-east-1c": "${{ vars.SUBNET_US_EAST_1C }}",
+ "us-east-1d": "${{ vars.SUBNET_US_EAST_1D }}",
+ "us-east-1e": "${{ vars.SUBNET_US_EAST_1E }}",
+ "us-east-1f": "${{ vars.SUBNET_US_EAST_1F }}"
+ },
+ "ec2-ami": "${{ vars.AWS_EC2_AMI_US_EAST_1 }}",
+ "security-group-id": "${{ vars.SECURITY_GROUP_ID_US_EAST_1 }}"
+ }
+ ]
+ try_spot_instance_first: false
+ ec2_instance_type: g6e.12xlarge
+ aws_resource_tags: >
+ [
+ {"Key": "Name", "Value": "instructlab-ci-github-large-runner"},
+ {"Key": "GitHubRepository", "Value": "${{ github.repository }}"},
+ {"Key": "GitHubRef", "Value": "${{ github.ref }}"},
+ {"Key": "GitHubPR", "Value": "${{ github.event.number }}"}
+ ]
+
+ e2e-large-test:
+ needs:
+ - start-large-ec2-runner
+ runs-on: ${{ needs.start-large-ec2-runner.outputs.label }}
+
+ permissions:
+ pull-requests: write
+
+ steps:
+ - name: "Harden Runner"
+ # v2.10.1
+ uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf
+ with:
+ egress-policy: audit
+ - name: Install Packages
+ run: |
+ cat /etc/os-release
+ mkdir -p "${TMPDIR}"
+ sudo dnf install -y gcc gcc-c++ make git python3.11 python3.11-devel
+
+ - name: Checkout
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ # https://github.com/actions/checkout/issues/249
+ fetch-depth: 0
+
+ - name: Install dependent PRs if needed
+ uses: depends-on/depends-on-action@61cb3f4a0e2c8ae4b90c9448dc57c7ba9ca24c35 # main
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Fetch and checkout PR
+ if: ${{ github.event_name == 'pull_request_target' }}
+ run: |
+ git fetch origin pull/${{ github.event.number }}/head:pr-${{ github.event.number }}
+ git checkout pr-${{ github.event.number }}
+
+ - name: Update instructlab-training library
+ run: |
+ export CUDA_HOME="/usr/local/cuda"
+ export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64"
+ export PATH="$PATH:$CUDA_HOME/bin"
+ nvidia-smi
+ python3.11 -m venv --upgrade-deps venv
+ . venv/bin/activate
+ pip install instructlab
+ pip install instructlab[cuda]
+ pip install vllm
+ python3.11 -m pip install packaging wheel setuptools-scm
+ pip install .
+ pip install .[cuda]
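+          # Rebuild flash-attn against the torch/CUDA stack installed above: remove any
+          # prebuilt wheel, purge the pip cache, and recompile with ninja. MAX_JOBS caps
+          # parallel compile jobs so the build does not exhaust the runner's memory.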
+ python3.11 -m pip uninstall -y flash-attn
+ python3.11 -m pip cache purge
+ python3.11 -m pip install ninja
+ MAX_JOBS=8 python3.11 -m pip install flash-attn --no-build-isolation
+
+ - name: Check disk before tests
+ run: |
+ df -h
+
+      # TODO: switch to downloading an SDG dataset artifact rather than generating one
+ # - name: Download SDG Dataset
+ # working-directory: ./training
+ # uses: actions/download-artifact@v4
+ # with:
+ # name: sdg-dataset.jsonl
+ # path: dataset
+
+ - name: Run e2e test
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+ run: |
+ . venv/bin/activate
+ ls scripts
+ ls ./
+ ./scripts/test-sdk.sh
+
+          # The training SDK writes a metrics file named
+          # f"training_params_and_metrics_global{os.environ['RANK']}.jsonl" into each
+          # phase's checkpoint directory, which lives under the directory created by
+          # `mktemp -d` in test-sdk.sh (and therefore under ${TMPDIR}).
+          # Find the rank-0 file for each phase:
+          log_files=$(find "${TMPDIR:-/tmp}" -name "training_params_and_metrics_global0.jsonl")
+ phase_num=1;
+ for log_file in $log_files; do
+ mv "${log_file}" phase-${phase_num}-training-log.jsonl
+ ((phase_num++))
+ done
+
+ - name: Check disk after tests
+ run: |
+ df -h
+
+ - name: Upload training logs Phase 1
+ uses: actions/upload-artifact@v4
+ with:
+ name: phase-1-training-log.jsonl
+ path: ./phase-1-training-log.jsonl
+ retention-days: 1
+ overwrite: true
+
+ - name: Upload training logs Phase 2
+ uses: actions/upload-artifact@v4
+ with:
+ name: phase-2-training-log.jsonl
+ path: ./phase-2-training-log.jsonl
+ retention-days: 1
+ overwrite: true
+
+ stop-large-ec2-runner:
+ needs:
+ - start-large-ec2-runner
+ - e2e-large-test
runs-on: ubuntu-latest
+ if: ${{ always() }}
steps:
- - name: No-op
- run: 'true'
+ - name: "Harden Runner"
+ # v2.10.1
+ uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf
+ with:
+ egress-policy: audit
+
+ - name: Configure AWS credentials
+ uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
+ with:
+ aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ aws-region: ${{ vars.AWS_REGION }}
+
+ - name: Stop EC2 runner
+ uses: machulav/ec2-github-runner@a8c20fc0876503410b2b966c124abc2311984ce2 # v2.3.9
+ with:
+ mode: stop
+ github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+ label: ${{ needs.start-large-ec2-runner.outputs.label }}
+ ec2-instance-id: ${{ needs.start-large-ec2-runner.outputs.ec2-instance-id }}
+
+ loss-graphs:
+ needs:
+ - stop-large-ec2-runner
+ runs-on: ubuntu-latest
+ if: ${{ always() }}
+ steps:
+ - name: "Harden Runner"
+ # v2.10.1
+ uses: step-security/harden-runner@c6295a65d1254861815972266d5933fd6e532bdf
+ with:
+ egress-policy: audit
+
+ - name: Configure AWS credentials
+ uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
+ with:
+ aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ aws-region: ${{ vars.AWS_REGION }}
+
+ - name: Checkout
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ with:
+ # https://github.com/actions/checkout/issues/249
+ fetch-depth: 0
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements-dev.txt
+
+ - name: Download loss data Phase 1
+ id: phase-1-download-logs
+ uses: actions/download-artifact@v4
+ with:
+ name: phase-1-training-log.jsonl
+ path: downloaded-data
+
+ - name: Download loss data Phase 2
+ id: phase-2-download-logs
+ uses: actions/download-artifact@v4
+ with:
+ name: phase-2-training-log.jsonl
+ path: downloaded-data
+
+ - name: Try to upload Phase 1 to s3
+ id: phase-1-upload-s3
+ continue-on-error: true
+ run: |
+ python ./scripts/create-loss-graph.py \
+ --log-file "${{ steps.phase-1-download-logs.outputs.download-path }}/phase-1-training-log.jsonl" \
+ --output-file "./phase-1-test.md" \
+ --phase "1" \
+ --aws-region "${{ vars.AWS_REGION }}" \
+ --bucket-name "${{ vars.AWS_S3_LOSS_GRAPHS_BUCKET_NAME }}" \
+ --base-branch "${GITHUB_REF##*/}" \
+ --head-sha "${{ github.sha }}" \
+ --pr-number "${{ github.event.number }}" \
+ --origin-repository "${{ github.repository }}"
+
+ - name: Try to upload Phase 2 to s3
+ id: phase-2-upload-s3
+ continue-on-error: true
+ run: |
+ python ./scripts/create-loss-graph.py \
+ --log-file "${{ steps.phase-2-download-logs.outputs.download-path }}/phase-2-training-log.jsonl" \
+ --output-file "./phase-2-test.md" \
+ --phase "2" \
+ --aws-region "${{ vars.AWS_REGION }}" \
+ --bucket-name "${{ vars.AWS_S3_LOSS_GRAPHS_BUCKET_NAME }}" \
+ --base-branch "${GITHUB_REF##*/}" \
+ --head-sha "${{ github.sha }}" \
+ --pr-number "${{ github.event.number }}" \
+ --origin-repository "${{ github.repository }}"
+
+ - name: Check Phase 1 S3 upload status for success
+ if: steps.phase-1-upload-s3.outcome == 'success'
+ run: |
+ echo "Uploaded Phase 1 loss graph to S3."
+ cat ./phase-1-test.md >> "${GITHUB_STEP_SUMMARY}"
+
+ - name: Check Phase 2 S3 upload status for success
+ if: steps.phase-2-upload-s3.outcome == 'success'
+ run: |
+ echo "Uploaded Phase 2 loss graph to S3."
+ cat ./phase-2-test.md >> "${GITHUB_STEP_SUMMARY}"
+
+ - name: Check Phase 1 S3 upload status for failure
+ if: steps.phase-1-upload-s3.outcome == 'failure'
+ run: |
+ echo "::warning::Failed to upload Phase 1 loss graph to S3. This won't block the workflow, but you may want to investigate."
+ echo "Loss graph upload failed" >> "${GITHUB_STEP_SUMMARY}"
+
+ - name: Check Phase 2 S3 upload status for failure
+ if: steps.phase-2-upload-s3.outcome == 'failure'
+ run: |
+ echo "::warning::Failed to upload Phase 2 loss graph to S3. This won't block the workflow, but you may want to investigate."
+ echo "Loss graph upload failed" >> "${GITHUB_STEP_SUMMARY}"
diff --git a/scripts/ibm_legacy_tmpl.py b/scripts/ibm_legacy_tmpl.py
new file mode 100644
index 00000000..0f09468f
--- /dev/null
+++ b/scripts/ibm_legacy_tmpl.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# First Party
+from instructlab.training.chat_templates.utils import SpecialTokens, TokenInfo
+
+SPECIAL_TOKENS = SpecialTokens(
+ system=TokenInfo("<|system|>", add_to_tokenizer=True),
+ user=TokenInfo("<|user|>", add_to_tokenizer=True),
+ assistant=TokenInfo("<|assistant|>", add_to_tokenizer=True),
+ eos=TokenInfo("<|endoftext|>", add_to_tokenizer=True),
+ pad=TokenInfo("<|pad|>", add_to_tokenizer=True),
+ bos=TokenInfo("<|begginingoftext|>", add_to_tokenizer=True),
+)
+
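+# Jinja chat template for the legacy IBM Granite format: pretraining records are wrapped
+# in <|pretrain|>...<|/pretrain|>, other turns are prefixed with their <|system|>,
+# <|user|>, or <|assistant|> tag, and assistant turns end with <|endoftext|>.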
+CHAT_TEMPLATE = (
+ "{% for message in messages %}"
+ "{% if message['role'] == 'pretraining' %}"
+ "{{'<|pretrain|>' + message['content'] + '<|endoftext|>' + '<|/pretrain|>' }}"
+ "{% elif message['role'] == 'system' %}"
+ "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
+ "{% elif message['role'] == 'user' %}"
+ "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
+ "{% endif %}"
+ "{% if loop.last and add_generation_prompt %}"
+ "{{ '<|assistant|>' + '\n' }}"
+ "{% endif %}"
+ "{% endfor %}"
+)
diff --git a/scripts/test-data/e2e-qna-grounded-employee-skill.yaml b/scripts/test-data/e2e-qna-grounded-employee-skill.yaml
new file mode 100644
index 00000000..5522dd5a
--- /dev/null
+++ b/scripts/test-data/e2e-qna-grounded-employee-skill.yaml
@@ -0,0 +1,228 @@
+version: 2
+created_by: ae2015
+task_description: >
+ Given a context document and a user's question, determine whether you have
+ enough information to provide a yes-or-no answer to the user's question.
+seed_examples:
+ - context: >
+ Document:
+
+ All regular employees except those on a leave of absence (LOA) must file Form XYZ.
+ If you are a regular employee on a LOA, you do not need to file the form. If you
+ are a supplemental employee (type B1, type B2, or type B3), your employee type
+ determines which form you must file, as follows:
+
+ * If you are a type B1 employee, you must follow the same rules that apply to the
+ regular employees.
+
+ * If you are a type B2 employee, the rules that apply to you are different from
+ those that apply to the regular employees. See Pub. 31416 to find out which forms
+ you should file.
+
+ * If you are a supplemental employee of neither type B1 nor type B2, you are a
+ type B3 employee. Different rules apply for each part of the year. For information
+ on type B3 employees, see Pub. 31416.
+ question: >
+ User asks the following question:
+
+ "I am a supplemental employee of type B1 on a LOA. Should I file Form XYZ?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "no". Indeed, the user is a
+ supplemental employee of type B1 and is on a LOA. According to the document,
+ supplemental employees of type B1 must follow the same rules that apply to the
+ regular employees, and a regular employee on a LOA does not need to file Form XYZ.
+ Hence, the user does not need to file Form XYZ. I reached this conclusion using
+ only the information available in the document and in the user's question.
+
+ - context: >
+ Document:
+
+ All regular employees except those on a leave of absence (LOA) must file Form XYZ.
+ If you are a regular employee on a LOA, you do not need to file the form. If you
+ are a supplemental employee (type B1, type B2, or type B3), your employee type
+ determines which form you must file, as follows:
+
+ * If you are a type B1 employee, you must follow the same rules that apply to the
+ regular employees.
+
+ * If you are a type B2 employee, the rules that apply to you are different from
+ those that apply to the regular employees. See Pub. 31416 to find out which forms
+ you should file.
+
+ * If you are a supplemental employee of neither type B1 nor type B2, you are a
+ type B3 employee. Different rules apply for each part of the year. For information
+ on type B3 employees, see Pub. 31416.
+ question: >
+ User asks the following question:
+
+ "I am a supplemental employee of type B1, not on a LOA. Should I file Form XYZ?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "yes". Indeed, the user is a
+ supplemental employee of type B1 and is not on a LOA. According to the document,
+ supplemental employees of type B1 must follow the same rules that apply to the
+ regular employees, and all regular employees except those on a LOA must file
+ Form XYZ. Hence, the user must file Form XYZ. I reached this conclusion using only
+ the information available in the document and in the user's question.
+
+ - context: >
+ Document:
+
+ The Helping Hands Program enables the employees of XYZ Technologies to support
+ nonprofit organizations within the communities where the employees live and work.
+ All active, regular employees of XYZ Technologies are eligible to participate in
+ the program. The following categories are ineligible to participate in the program:
+
+ * Retirees who no longer work at XYZ Technologies
+
+ * XYZ employees on a leave of absence
+
+ * Supplemental employees and contractors who are not regular XYZ employees
+
+ Please note that XYZ employees on a leave of absence can participate in the Helping
+ Hands Program once they return to an active status.
+ question: >
+ User asks the following question:
+
+ "I am an active, regular employee of XYZ Technologies. I am not on a leave of absence.
+ Can I participate in the Helping Hands Program?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "yes". Indeed, the user is an
+ active, regular employee of XYZ Technologies. According to the document, all active,
+ regular employees of XYZ Technologies are eligible to participate in the Helping
+ Hands Program. The document excludes retirees, XYZ employees on a leave of absence,
+ and those who are not regular XYZ employees; however, the user is an active and
+ regular employee who is not on a leave of absence, so the exclusions do not apply.
+ Therefore, the user can participate in the Helping Hands Program. I reached this
+ conclusion using only the information available in the document and in the user's
+ question.
+
+ - context: >
+ Document:
+
+ Applying for a driver's license is as easy as 1-2-3! You can apply for a driver's license
+ if you bring documents from each of the three sections below:
+
+ Section 1. Proof of Identity - one document is required from the list below:
+
+ * Unexpired passport
+
+ * Birth certificate
+
+ * Certificate of naturalization or citizenship
+
+ * Valid, unexpired Permanent Resident Card
+
+ Section 2. Proof of Social Security Number - one document is required from the list below:
+
+ * Social Security card
+
+ * W-2 form
+
+ * SSA-1099 form
+
+ * Paystub
+
+ Section 3. Proof of Residency - two different documents are required from the list below:
+
+ * Home utility bill
+
+ * Medical document
+
+ * Proof of car registration
+
+ * Proof of employment
+
+ * Proof of insurance
+
+ * Mortgage bill
+
+ * Lease agreement
+
+ * Tax return
+ question: >
+ User asks the following question:
+
+ "Can I apply for a driver's license if I bring my birth certificate, my W-2 form, and
+ my proof of car registration?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "no". Indeed, the user wants to
+ bring a birth certificate, a W-2 form, and a proof of car registration. According to
+ the context document, to apply for a driver's license, the user must bring one proof
+ of Identity, one proof of Social Security Number, and two different proof of Residency
+ documents. A birth certificate is one proof of Identity, a W-2 form is one proof of
+ Social Security Number, and a proof of car registration is one proof of Residency
+ document. The user cannot apply for a driver's license without bringing a second proof
+ of Residency document. I reached this conclusion using only the information available
+ in the context document and in the user's question.
+
+ - context: >
+ Document:
+
+ XYZ Corporation released three models of phones in 2024: Model MX, Model MY, and
+      Model MZ. Model MX is the budget-conscious choice: it has a 6-inch screen, a single
+ rear camera, and 32 GB of storage. Model MY is the most portable choice, with a
+ 5-inch screen, a dual rear camera, and 32 GB of storage. Model MZ is the luxury
+ choice, sporting a 7-inch screen, a dual rear camera, and 64 GB of storage.
+ These three models are the only ones available for sale in our store.
+ question: >
+ User asks the following question:
+
+ "Does your store sell a phone made by XYZ Corporation that has a 6-inch screen and
+ a dual rear camera?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "no". Indeed, the user is looking
+ for a phone made by XYZ Corporation that has a 6-inch screen and a dual rear camera.
+ According to the document, our store sells only three models of phones made by
+ XYZ Corporation: Model MX, Model MY, and Model MZ. Only Model MX has a 6-inch screen,
+ but it has a single rear camera; models MY and MZ do not have a 6-inch screen. Hence,
+ our store does not sell a phone the user is looking for. I reached this conclusion
+ using only the information available in the document and in the user's question.
+
+ - context: >
+ Document:
+
+ Solar System Overview
+
+ The solar system has one star, eight planets, five dwarf planets, at least 290 moons,
+ more than 1.3 million asteroids, and about 3,900 comets. It is located in an outer
+ spiral arm of the Milky Way galaxy called the Orion Arm, or Orion Spur. Our solar
+ system orbits the center of the galaxy at about 515,000 mph (828,000 kph). It takes
+ about 230 million years to complete one orbit around the galactic center.
+
+ We call it the solar system because it is made up of our star, the Sun, and everything
+ bound to it by gravity - the planets Mercury, Venus, Earth, Mars, Jupiter, Saturn,
+ Uranus, and Neptune; dwarf planets Pluto, Ceres, Makemake, Haumea, and Eris - along
+ with hundreds of moons; and millions of asteroids, comets, and meteoroids.
+ question: >
+ User asks the following question:
+
+ "Does the solar system have two stars?"
+
+ As the responding agent, do you have enough information to provide a yes-or-no answer
+ to the user's question? Answer Yes or No, then explain your answer.
+ answer: >
+ Yes, I have enough information to provide a yes-or-no answer to the user's question.
+ Specifically, the answer to the user's question is "no". According to the document,
+ the solar system has only one star - the Sun, not two stars. I reached this
+ conclusion using only the information available in the document and in the user's
+ question.
diff --git a/scripts/test-data/e2e-qna-knowledge-phoenix.yaml b/scripts/test-data/e2e-qna-knowledge-phoenix.yaml
new file mode 100644
index 00000000..be4c3ef9
--- /dev/null
+++ b/scripts/test-data/e2e-qna-knowledge-phoenix.yaml
@@ -0,0 +1,198 @@
+version: 3
+domain: astronomy
+created_by: juliadenham
+seed_examples:
+ - context: |
+ **Phoenix** is a minor [constellation](constellation "wikilink") in the
+ [southern sky](southern_sky "wikilink"). Named after the mythical
+ [phoenix](Phoenix_(mythology) "wikilink"), it was first depicted on a
+ celestial atlas by [Johann Bayer](Johann_Bayer "wikilink") in his 1603
+ *[Uranometria](Uranometria "wikilink")*. The French explorer and
+ astronomer [Nicolas Louis de
+ Lacaille](Nicolas_Louis_de_Lacaille "wikilink") charted the brighter
+ stars and gave their [Bayer designations](Bayer_designation "wikilink")
+ in 1756. The constellation stretches from roughly −39 degrees to −57 degrees
+ [declination](declination "wikilink"), and from 23.5h to 2.5h of [right
+ ascension](right_ascension "wikilink"). The constellations Phoenix,
+ [Grus](Grus_(constellation) "wikilink"),
+ [Pavo](Pavo_(constellation) "wikilink") and [Tucana](Tucana "wikilink"),
+ are known as the Southern Birds.
+ questions_and_answers:
+ - question: |
+ What is the Phoenix constellation?
+ answer: |
+ Phoenix is a minor constellation in the southern sky.
+ - question: |
+ Who charted the Phoenix constellation?
+ answer: |
+        The Phoenix constellation was charted by French explorer and
+ astronomer Nicolas Louis de Lacaille.
+ - question: |
+ How far does the Phoenix constellation stretch?
+ answer: |
+        The Phoenix constellation stretches from roughly −39° to −57°
+ declination, and from 23.5h to 2.5h of right ascension.
+ - context: |
+ Phoenix was the largest of the 12 constellations established by [Petrus
+ Plancius](Petrus_Plancius "wikilink") from the observations of [Pieter
+ Dirkszoon Keyser](Pieter_Dirkszoon_Keyser "wikilink") and [Frederick de
+ Houtman](Frederick_de_Houtman "wikilink"). It first appeared on a 35cm
+ diameter celestial globe published in 1597 (or 1598) in Amsterdam by
+ Plancius with [Jodocus Hondius](Jodocus_Hondius "wikilink"). The first
+ depiction of this constellation in a celestial atlas was in [Johann
+ Bayer](Johann_Bayer "wikilink")'s
+ *[Uranometria](Uranometria "wikilink")* of 1603. De Houtman included
+ it in his southern star catalog the same year under the Dutch name *Den
+ voghel Fenicx*, "The Bird Phoenix", symbolising the
+ [phoenix](Phoenix_(mythology) "wikilink") of classical mythology. One
+ name of the brightest star [Alpha
+ Phoenicis](Alpha_Phoenicis "wikilink")—Ankaa—is derived from the Arabic:
+ العنقاء, romanized: al-‘anqā’, lit. 'the phoenix', and
+ was coined sometime after 1800 in relation to the constellation.
+ questions_and_answers:
+ - question: |
+ What is the brightest star in the Phoenix constellation
+ called?
+ answer: |
+ Alpha Phoenicis or Ankaa is the brightest star in the Phoenix
+ Constellation.
+ - question: Where did the Phoenix constellation first appear?
+ answer: |
+ The Phoenix constellation first appeared on a 35-cm diameter
+ celestial globe published in 1597 (or 1598) in Amsterdam by
+ Plancius with Jodocus Hondius.
+ - question: |
+ What does "The Bird Phoenix" symbolize?
+ answer: |
+ "The Bird Phoenix" symbolizes the phoenix of classical mythology.
+ - context: |
+ Phoenix is a small constellation bordered by [Fornax](Fornax "wikilink")
+ and Sculptor to the north, Grus to the west, Tucana to the south,
+ touching on the corner of [Hydrus](Hydrus "wikilink") to the south, and
+ [Eridanus](Eridanus_(constellation) "wikilink") to the east and
+ southeast. The bright star [Achernar](Achernar "wikilink") is
+ nearby. The three-letter abbreviation for the constellation, as
+ adopted by the [International Astronomical
+ Union](International_Astronomical_Union "wikilink") in 1922, is
+ "Phe". The official constellation boundaries, as set by Belgian
+ astronomer [Eugène Delporte](Eugène_Joseph_Delporte "wikilink") in 1930,
+ are defined by a polygon of 10 segments. In the [equatorial coordinate
+ system](equatorial_coordinate_system "wikilink"), the [right
+ ascension](right_ascension "wikilink") coordinates of these borders lie
+ between 23h 26.5m and 02h 25.0m,
+ while the [declination](declination "wikilink")
+ coordinates are between −39.31° and −57.84°. This means it remains
+ below the horizon to anyone living north of the [40th
+ parallel](40th_parallel_north "wikilink") in the [Northern
+ Hemisphere](Northern_Hemisphere "wikilink"), and remains low in the sky
+ for anyone living north of the [equator](equator "wikilink"). It is most
+ visible from locations such as Australia and South Africa during late
+ [Southern Hemisphere](Southern_Hemisphere "wikilink") spring. Most
+ of the constellation lies within, and can be located by, forming a
+ triangle of the bright stars Achernar, [Fomalhaut](Fomalhaut "wikilink")
+ and [Beta Ceti](Beta_Ceti "wikilink")—Ankaa lies roughly in the centre
+ of this.
+ questions_and_answers:
+ - question: What are the characteristics of the Phoenix constellation?
+ answer: |
+ Phoenix is a small constellation bordered by Fornax and Sculptor to
+ the north, Grus to the west, Tucana to the south, touching on the
+ corner of Hydrus to the south, and Eridanus to the east and southeast.
+ The bright star Achernar is nearby.
+ - question: |
+ When is the phoenix constellation most visible?
+ answer: |
+ Phoenix is most visible from locations such as Australia and
+ South Africa during late Southern Hemisphere spring.
+ - question: |
+ What are the Phoenix Constellation boundaries?
+ answer: |
+ The official constellation boundaries for Phoenix, as set by Belgian
+ astronomer Eugène Delporte in 1930, are defined by a polygon of 10
+ segments.
+ - context: |
+ Ten stars have been found to have planets to date, and four planetary
+ systems have been discovered with the [SuperWASP](SuperWASP "wikilink")
+ project. [HD 142](HD_142 "wikilink") is a yellow giant that has an
+ apparent magnitude of 5.7, and has a planet ([HD 142b](HD_142_b
+ "wikilink")) 1.36 times the mass of Jupiter which orbits every 328 days.
+ [HD 2039](HD_2039 "wikilink") is a yellow subgiant with an apparent
+ magnitude of 9.0 around 330 light years away which has a planet ([HD 2039
+ b](HD_2039_b "wikilink")) six times the mass of Jupiter. [WASP-18](WASP-18
+ "wikilink") is a star of magnitude 9.29 which was discovered to have a hot
+ Jupiter-like planet ([WASP-18b](WASP-18b "wikilink")) taking less than a
+ day to orbit the star. The planet is suspected to be causing WASP-18 to
+ appear older than it really is. [WASP-4](WASP-4 "wikilink") and
+ [WASP-5](WASP-5 "wikilink") are solar-type yellow stars around 1000
+ light years distant and of 13th magnitude, each with a single planet
+ larger than Jupiter. [WASP-29](WASP-29 "wikilink") is an orange
+ dwarf of spectral type K4V and visual magnitude 11.3, which has a
+ planetary companion of similar size and mass to Saturn. The planet
+ completes an orbit every 3.9 days.
+ questions_and_answers:
+ - question: In the Phoenix constellation, how many stars have planets?
+ answer: |
+ In the Phoenix constellation, ten stars have been found to have
+ planets to date, and four planetary systems have been discovered
+ with the SuperWASP project.
+ - question: What is HD 142?
+ answer: |
+ HD 142 is a yellow giant that has an apparent magnitude of 5.7, and
+ has a planet (HD 142 b) 1.36 times the mass of Jupiter which
+ orbits every 328 days.
+ - question: |
+ Are WASP-4 and WASP-5 solar-type yellow stars?
+ answer: |
+ Yes, WASP-4 and WASP-5 are solar-type yellow stars around 1000 light
+ years distant and of 13th magnitude, each with a single planet
+ larger than Jupiter.
+ - context: |
+ The constellation does not lie on the
+ [galactic plane](galactic_plane "wikilink") of the Milky Way, and there
+ are no prominent star clusters. [NGC 625](NGC_625 "wikilink") is a dwarf
+ [irregular galaxy](irregular_galaxy "wikilink") of apparent magnitude 11.0
+ and lying some 12.7 million light years distant. Only 24000 light years in
+ diameter, it is an outlying member of the [Sculptor Group](Sculptor_Group
+ "wikilink"). NGC 625 is thought to have been involved in a collision and
+ is experiencing a burst of [active star formation](Active_galactic_nucleus
+ "wikilink"). [NGC 37](NGC_37 "wikilink") is a
+ [lenticular galaxy](lenticular_galaxy "wikilink") of apparent magnitude
+ 14.66. It is approximately 42 [kiloparsecs](kiloparsecs "wikilink")
+ (137,000 [light-years](light-years "wikilink")) in diameter and about
+ 12.9 billion years old. [Robert's Quartet](Robert's_Quartet "wikilink")
+ (composed of the irregular galaxy [NGC 87](NGC_87 "wikilink"), and three
+ spiral galaxies [NGC 88](NGC_88 "wikilink"), [NGC 89](NGC_89 "wikilink")
+ and [NGC 92](NGC_92 "wikilink")) is a group of four galaxies located
+ around 160 million light-years away which are in the process of colliding
+ and merging. They are within a circle of radius of 1.6 arcmin,
+ corresponding to about 75,000 light-years. Located in the galaxy ESO
+ 243-49 is [HLX-1](HLX-1 "wikilink"), an
+ [intermediate-mass black hole](intermediate-mass_black_hole
+ "wikilink")—the first one of its kind identified. It is thought to be a
+ remnant of a dwarf galaxy that was absorbed in a
+ [collision](Interacting_galaxy "wikilink") with ESO 243-49. Before its
+ discovery, this class of black hole was only hypothesized.
+ questions_and_answers:
+ - question: |
+ Is the Phoenix Constellation part of the Milky Way?
+ answer: |
+ The Phoenix constellation does not lie on the galactic plane of
+ the Milky Way, and there are no prominent star clusters.
+ - question: |
+ How many light years away is NGC 625?
+ answer: |
+        NGC 625 lies some 12.7 million light years away and is an outlying
+        member of the Sculptor Group.
+ - question: |
+ What is Robert's Quartet composed of?
+ answer: |
+ Robert's Quartet is composed of the irregular galaxy NGC 87,
+ and three spiral galaxies NGC 88, NGC 89 and NGC 92.
+document_outline: |
+ Information about the Phoenix Constellation including the
+ history, characteristics, and features of the stars in the constellation.
+document:
+ repo: https://github.com/RedHatOfficial/rhelai-taxonomy-data
+ commit: c87a82eb15567f28c0a8d30025e0cd77a2150646
+ patterns:
+ - phoenix.md
diff --git a/scripts/test-data/profile-l40s-x4.yaml b/scripts/test-data/profile-l40s-x4.yaml
new file mode 100644
index 00000000..5afd7f94
--- /dev/null
+++ b/scripts/test-data/profile-l40s-x4.yaml
@@ -0,0 +1,157 @@
+chat:
+ context: default
+ # Directory where chat logs are stored
+ logs_dir: ~/.local/share/instructlab/chatlogs
+ # The maximum number of tokens that can be generated in the chat completion
+ max_tokens: null
+ # Directory where model to be used for chatting with is stored
+ model: ~/.cache/instructlab/models/instructlab/granite-7b-lab
+ session: null
+ # visual mode
+ vi_mode: false
+ # renders vertical overflow if enabled, displays ellipses otherwise
+ visible_overflow: true
+evaluate:
+ # Base taxonomy branch
+ base_branch: null
+ # Directory where the model to be evaluated is stored
+ base_model: ~/.cache/instructlab/models/instructlab/granite-7b-lab
+ # Taxonomy branch containing custom skills/knowledge that should be used for evaluation runs
+ branch: null
+  # Number of GPUs to use for running evaluation
+  gpus: 4
+  dk_bench:
+    # File with questions and reference answers used for evaluation during DK-Bench.
+    input_questions: null
+    # Judge model for DK-Bench.
+    judge_model: gpt-4o
+    # Directory where DK-Bench evaluation results are stored.
+    output_dir: ~/.local/share/instructlab/internal/eval_data/dk_bench
+    # Comma-separated list of file formats for results of the DK-Bench evaluation.
+    output_file_formats: jsonl
+ # MMLU benchmarking settings
+ mmlu:
+ # batch size for evaluation.
+ # Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory
+ batch_size: auto
+ # number of question-answer pairs provided in the context preceding the question used for evaluation
+ few_shots: 5
+ # Settings to run MMLU against a branch of taxonomy containing
+ # custom skills/knowledge used for training
+ mmlu_branch:
+ # Directory where custom MMLU tasks are stored
+ tasks_dir: ~/.local/share/instructlab/datasets
+ model: null
+ # multi-turn benchmarking settings for skills
+ mt_bench:
+ # Directory where model to be used as judge is stored
+ judge_model: ~/.cache/instructlab/models/prometheus-eval/prometheus-8x7b-v2.0
+ max_workers: auto
+ # Directory where evaluation results are stored
+ output_dir: ~/.local/share/instructlab/internal/eval_data/mt_bench
+ # Settings to run MT-Bench against a branch of taxonomy containing
+ # custom skills/knowledge used for training
+ mt_bench_branch:
+ # Directory where model to be used as judge is stored
+ judge_model: ~/.cache/instructlab/models/prometheus-eval/prometheus-8x7b-v2.0
+ # Directory where evaluation results are stored
+ output_dir: ~/.local/share/instructlab/internal/eval_data/mt_bench_branch
+ # Path to where base taxonomy is stored
+ taxonomy_path: ~/.local/share/instructlab/taxonomy
+ # System prompt for model getting responses during DK-Bench.
+ system_prompt: You are an advanced AI assistant designed to provide precise and
+ accurate information. Your primary goal is to answer queries with the most up-to-date
+ and factual information available. Focus on delivering clear, concise, and correct
+ responses. If you're uncertain about any aspect of the query, state your level
+ of confidence and provide the most accurate information you can. Your responses
+ should prioritize accuracy over all other considerations.
+ # Temperature for model getting responses during DK-Bench.
+ temperature: 0.0
+general:
+ debug_level: 0
+ log_level: INFO
+generate:
+ # Teacher model that will be used to synthetically generate training data
+ model: ~/.cache/instructlab/models/mistralai/Mixtral-8x7B-Instruct-v0.1
+ # Number of CPU cores to use for generation
+ num_cpus: 10
+ # Directory where generated datasets are stored
+ output_dir: ~/.local/share/instructlab/datasets
+ # Directory where pipeline config files are stored
+ pipeline: full
+ # The total number of instructions to be generated
+ sdg_scale_factor: 30
+ # Branch of taxonomy used to calculate diff against
+ taxonomy_base: empty
+ # Directory where taxonomy is stored and accessed from
+ taxonomy_path: ~/.local/share/instructlab/taxonomy
+ # Teacher model specific settings
+ teacher:
+ # Serving backend to use to host the teacher model
+ backend: vllm
+ # Path to teacher model that will be used to synthetically generate training data
+ model_path: ~/.cache/instructlab/models/mistralai/Mixtral-8x7B-Instruct-v0.1
+ # vLLM serving settings
+ vllm:
+ # number of GPUs to allocate to vLLM
+ gpus: 4
+ # the family of model being served - used to determine the appropriate chat template
+ llm_family: 'mixtral'
+serve:
+ # Serving backend to use to host the model
+ backend: vllm
+ # Chat template to supply to the served model. Possible values:
+ # - Custom chat template string
+ # - Auto: Uses default for serving backend
+ chat_template: auto
+ # Llamacpp serving settings
+ llama_cpp:
+ # number of model layers to offload to GPU
+ # -1 means all
+ gpu_layers: -1
+ # the family of model being served - used to determine the appropriate chat template
+ llm_family: ''
+ # maximum number of tokens that can be processed by the model
+ max_ctx_size: 4096
+ # Path to model that will be served for inference
+ model_path: ~/.cache/instructlab/models/instructlab/granite-7b-lab
+ # vLLM serving settings
+ vllm:
+ gpus: 4
+ # the family of model being served - used to determine the appropriate chat template
+ llm_family: ''
+ # additional arguments to be supplied directly to vLLM
+ vllm_args: ["--tensor-parallel-size", "4"]
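+# Training settings for this profile (phased training across the 4 L40S GPUs)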
+train:
+ additional_args:
+ warmup_steps: 10
+ learning_rate: 2e-6
+ lora_dropout: 0.1
+ lora_alpha: 32
+ deepspeed_cpu_offload_optimizer_pin_memory: false
+    deepspeed_cpu_offload_optimizer_ratio: 1
+ ckpt_output_dir: checkpoints
+ data_output_dir: train-output
+ data_path: ./taxonomy_data
+ deepspeed_cpu_offload_optimizer: true
+ effective_batch_size: 32
+ lora_quantize_dtype: null
+ lora_rank: 0
+ max_batch_len: 10000
+ max_seq_len: 4096
+ model_path: ~/.cache/instructlab/models/instructlab/granite-7b-lab
+ num_epochs: 1
+ save_samples: 0
+ is_padding_free: true
+ nproc_per_node: 4
+ phased_phase1_effective_batch_size: 32
+ phased_phase1_num_epochs: 2
+ phased_phase1_samples_per_save: 0
+ phased_phase2_effective_batch_size: 32
+ phased_phase2_num_epochs: 2
+ phased_phase2_samples_per_save: 0
+ distributed_backend: fsdp
+ pipeline: accelerated
+ device: cuda
+version: 1.0.0
+
diff --git a/scripts/test-sdk.sh b/scripts/test-sdk.sh
new file mode 100755
index 00000000..6a534a96
--- /dev/null
+++ b/scripts/test-sdk.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+set -xeuf
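+# -x: trace each command, -e: exit on first error, -u: error on unset variables, -f: disable globbing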
+
+# generic globals
+BOLD='\033[1m'
+NC='\033[0m' # No Color
+PRESERVE=0
+
+# path and token globals
+SCRIPTDIR=$(dirname "$0")
+E2E_TEST_DIR=""
+CONFIG_HOME=""
+DATA_HOME=""
+CACHE_HOME=""
+E2E_LOG_DIR=""
+HF_TOKEN=${HF_TOKEN:-}
+
+GRANITE_7B_MODEL="instructlab/granite-7b-lab"
+MIXTRAL_8X7B_MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+init_e2e_tests() {
+ E2E_TEST_DIR=$(mktemp -d)
+ export HOME="${E2E_TEST_DIR}" # update the HOME directory used to resolve paths
+
+ CONFIG_HOME=$(python -c 'import platformdirs; print(platformdirs.user_config_dir())')
+ DATA_HOME=$(python -c 'import platformdirs; print(platformdirs.user_data_dir())')
+ CACHE_HOME=$(python -c 'import platformdirs; print(platformdirs.user_cache_dir())')
+ # ensure that our mock e2e dirs exist
+ for dir in "${CONFIG_HOME}" "${DATA_HOME}" "${CACHE_HOME}"; do
+ mkdir -p "${dir}"
+ done
+
+ E2E_LOG_DIR="${HOME}/log"
+ mkdir -p "${E2E_LOG_DIR}"
+}
+
+test_train() {
+ task initialize ilab
+ # TODO: get profiles
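+    # initialize ilab non-interactively using the 4x L40S system profile from test-data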
+ ilab config init --non-interactive --profile="${SCRIPTDIR}/test-data/profile-l40s-x4.yaml"
+ mkdir -p "$DATA_HOME"/instructlab/taxonomy/knowledge/phoenix/overview/e2e-phoenix
+ cp "$SCRIPTDIR"/test-data/e2e-qna-knowledge-phoenix.yaml "$DATA_HOME"/instructlab/taxonomy/knowledge/phoenix/overview/e2e-phoenix/qna.yaml
+
+ mkdir -p "$DATA_HOME"/instructlab/taxonomy/compositional_skills/extraction/answerability/e2e-yes_or_no
+ cp "$SCRIPTDIR"/test-data/e2e-qna-grounded-employee-skill.yaml "$DATA_HOME"/instructlab/taxonomy/compositional_skills/extraction/answerability/e2e-yes_or_no/qna.yaml
+ task ilab initialization complete
+
+ task download models
+
+ step Downloading the mixtral-8x7b instruct model as the teacher model for SDG
+ ilab model download --repository ${MIXTRAL_8X7B_MODEL} --hf-token "${HF_TOKEN}"
+ step Downloading granite-7b-lab model to train
+ ilab model download --repository ${GRANITE_7B_MODEL}
+
+ task model downloading complete
+
+ task generate ilab data
+ ilab data generate --enable-serving-output
+ task generation complete
+
+ task Train the model with instructlab/training SDK
+
+ local knowledge_data_path
+ local skills_data_path
+ knowledge_data_path=$(find "${DATA_HOME}"/instructlab/datasets -name 'knowledge_train_msgs*' | head -n 1)
+ skills_data_path=$(find "${DATA_HOME}"/instructlab/datasets -name 'skills_train_msgs*' | head -n 1)
+
+
+ export INSTRUCTLAB_EVAL_FIRST_N_QUESTIONS=10
+ export HF_DATASETS_TRUST_REMOTE_CODE=true
+
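+    # run the SDK test driver; scripts/test_sdk.py re-launches itself under torchrun for each training phase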
+ python -c "import sys; sys.path.insert(0, '${SCRIPTDIR}'); from test_sdk import run_test; run_test('${knowledge_data_path}', '${skills_data_path}', nnodes=1, node_rank=0, nproc_per_node=4)"
+ task Training complete
+}
+
+step() {
+ echo -e "$BOLD$* - $(date)$NC"
+}
+
+task() {
+ echo -e "$BOLD------------------------------------------------------$NC"
+ step "$@"
+}
+
+check_disk() {
+ task Check disk
+ df -h
+}
+
+init_e2e_tests
+test_train
\ No newline at end of file
diff --git a/scripts/test_sdk.py b/scripts/test_sdk.py
new file mode 100644
index 00000000..90a258fc
--- /dev/null
+++ b/scripts/test_sdk.py
@@ -0,0 +1,372 @@
+# Standard
+from pathlib import Path
+import argparse
+import datetime
+import os
+import subprocess
+
+# Third Party
+from transformers import AutoConfig
+import torch
+
+# First Party
+from instructlab.training.async_logger import AsyncStructuredLogger
+from instructlab.training.config import DistributedBackend, ModelTypes
+from instructlab.training.model import Accelerator, Checkpointer, Model, setup_optimizer
+from instructlab.training.multipack_sampler import (
+ find_packing_max_batch_len_and_grad_accum,
+)
+
+# to SDK-ify below:
+from instructlab.training.token_dataset import setup_dataloader, setup_dataset
+from instructlab.training.tokenizer_utils import setup_tokenizer
+from instructlab.training.train import train
+from instructlab.training.utils import StreamablePopen, set_random_seed, setup_logger
+import instructlab.training.data_process as dp
+
+
+def main(args):
+ """
+    This script uses the classes defined in src/instructlab/training and follows a similar flow to main() in main_ds.py.
+    It is kept separate so that our testing exercises a consistent script that will catch breakages to the SDK classes.
+    main() and run_training() expect a set of arguments specific to ilab; this script only requires the following arguments:
+
+ model_path
+ knowledge_data_path
+ skills_data_path
+ effective_batch_size
+ ckpt_dir
+
+ The other arguments are set per phase by the script and are not configurable
+
+ PHASE 1
+
+    EBS = 32
+    model = granite-7b-lab
+    MBL = 10000
+    model_type = Causallm
+    seed = 42
+    nproc_per_node = 4
+    nnodes = 1
+    use_dolomite = False
+    is_padding_free = True
+    lr_scheduler = cosine
+    num_epochs = 2
+    data_path = KNOWLEDGE DATA
+
+ PHASE 2
+
+    EBS = 32
+    model = last CKPT of PHASE 1
+    MBL = 10000
+    model_type = Causallm
+    seed = 42
+    nproc_per_node = 4
+    nnodes = 1
+    use_dolomite = False
+    is_padding_free = True
+    lr_scheduler = cosine
+    num_epochs = 2
+    data_path = SKILLS DATA
+
+ """
+
+ # Third Party
+ import yaml
+
+ # granite teacher model
+ # model_path = os.path.abspath(
+ # os.path.expanduser("~/.cache/instructlab/models/instructlab/granite-7b-lab")
+ # )
+ # data path to put processed data into
+ data_output_path = os.path.abspath(
+ os.path.expanduser("~/.local/share/instructlab/internal")
+ )
+ if not os.path.exists(args.ckpt_dir):
+ os.makedirs(args.ckpt_dir, exist_ok=True)
+ # # checkpoint dir to put ilab checkpoints into.
+ # ckpt_dir = os.path.abspath(
+ # os.path.expanduser("~/.local/share/instructlab/checkpoints")
+ # )
+
+ data_path = os.path.join(data_output_path, "data.jsonl")
+ metric_logger = AsyncStructuredLogger(
+ args.ckpt_dir + f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl"
+ )
+ if os.environ["LOCAL_RANK"] == "0":
+ print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")
+ metric_logger.log_sync({"script_params": vars(args)})
+
+ setup_logger("INFO")
+ tokenizer = setup_tokenizer(args.model_path)
+
+ model_conf = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path=args.model_path
+ )
+ args.model_type = model_conf.model_type
+
+ #### distributed init #####
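+    # bind this worker to its local GPU, join the NCCL process group, and run a small
+    # all_reduce plus a barrier as a sanity check that collectives are working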
+ torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))
+ args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ torch.distributed.init_process_group("nccl")
+ args.global_rank = torch.distributed.get_rank()
+ tensor = torch.ByteTensor([False]).cuda()
+ torch.distributed.all_reduce(tensor)
+ torch.distributed.barrier()
+
+ flash_enabled = True
+
+ dataset = setup_dataset(
+ data_path,
+ mock=False,
+ mock_len=None,
+ )
+
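+    # ask the multipack sampler for the largest packed batch length and the gradient
+    # accumulation that satisfy the requested effective batch size; if the dataset is
+    # too small for multipack, fall back to a plain distributed sampler below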
+ try:
+ packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum(
+ num_gpus=torch.distributed.get_world_size(),
+ avg_sample_len=dataset.get_lengths().mean(),
+ effective_batch_size=args.effective_batch_size,
+ max_batch_len_per_gpu=10000,
+ is_padding=False,
+ dataset=dataset,
+ seed=42,
+ )
+ args.sampler = "multipack"
+ except RuntimeError as e:
+ if os.environ["LOCAL_RANK"] == "0":
+ print(f"\033[38;5;120m{e}\033[0m")
+
+ # fallback to grad accum = 1
+ # NOTE: packing max batch len will not be used
+ packing_max_batch_len = None
+ grad_accum = 1
+ args.sampler = "distributed"
+
+ # This model class wraps the various AutoModel classes we support
+ # based on model_type, and model_path -> choose auto_model
+ lora_config = None
+ m = Model(
+ model_path=args.model_path,
+ output_dir=args.ckpt_dir,
+ lora_config=lora_config,
+ distributed_framework=DistributedBackend.FSDP,
+ tokenizer=tokenizer,
+ model_type=ModelTypes("Causallm"),
+ flash_enabled=flash_enabled,
+ noise_alpha=None,
+ )
+
+ args.samples_per_gpu = (
+ args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
+ )
+
+ train_loader = setup_dataloader(
+ dataset,
+ tokenizer.pad_token_id,
+ num_workers=8,
+ use_dolomite=False,
+ flash_enabled=flash_enabled,
+ max_batch_len=10000,
+ packing_max_batch_len=packing_max_batch_len,
+ samples_per_gpu=args.samples_per_gpu,
+ sampler=args.sampler,
+ seed=42,
+ )
+ if len(train_loader) == 0:
+ # this happens sometimes when we have more GPUs than data to process. In this case
+ # we should either alert the user to switch samplers, or do it automatically and
+ # warn them about it happening
+ print(
+ "\033[93mThe dataset is too small for multipack to distribute all of the samples across GPUs. Falling back to the distributed sampler!\033[0m"
+ )
+        # actually switch to the distributed sampler before rebuilding the dataloader
+        args.sampler = "distributed"
+        train_loader = setup_dataloader(
+ dataset,
+ tokenizer.pad_token_id,
+ num_workers=8,
+ use_dolomite=False,
+ flash_enabled=flash_enabled,
+ max_batch_len=10000,
+ packing_max_batch_len=packing_max_batch_len,
+ samples_per_gpu=args.samples_per_gpu,
+ sampler=args.sampler,
+ seed=42,
+ )
+
+ if args.local_rank == 0:
+ metric_logger.log_sync(
+ {
+ "num_gpus": torch.distributed.get_world_size(),
+ "avg_sample_len": dataset.get_lengths().mean(),
+ "effective_batch_size": args.effective_batch_size,
+ "max_batch_len_per_gpu": 10000,
+ "packing_max_batch_len": packing_max_batch_len,
+ "grad_accum": grad_accum,
+ "num_batches": len(train_loader),
+ "avg_samples_per_batch": len(dataset) / len(train_loader),
+ "samples_per_gpu": args.samples_per_gpu,
+ "total_samples": len(dataset), # emit the total number of samples
+ }
+ )
+    # the Accelerator does not need an optimizer to initialize; in fact, the optimizer must be created AFTER the Accelerator
+ accelerator = Accelerator(
+ model=m,
+ samples_per_gpu=args.samples_per_gpu,
+ grad_accum=grad_accum,
+ train_loader=train_loader,
+ distributed_framework=DistributedBackend.FSDP,
+ fsdp_sharding_strategy="SHARD_GRAD_OP",
+ fsdp_cpu_offload_params=False,
+ save_samples=0,
+ )
+ # optimizer needs model that has been prepared by accelerator
+ # and then accelerator needs to be prepared AGAIN once optimizer is initialized
+ optimizer = setup_optimizer(
+ model=m,
+ cpu_offload=False,
+ name=None, # choose based on backend
+ learning_rate=2e-6,
+ )
+ accelerator.prepare_with_optimizer(
+ optimizer=optimizer,
+ lr_scheduler="cosine",
+ num_epochs=2,
+ num_warmup_steps=10,
+ )
+ # TODO: make this work more seamlessly
+ optimizer = accelerator.optimizer
+ m = accelerator.model
+
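+    # the checkpointer saves checkpoints under ckpt_dir and resumes from the latest full
+    # state if one exists; phase 2 later consumes phase 1's hf_format/last_epoch
+    # checkpoint as its base model (see run_test below)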
+ checkpointer = Checkpointer(
+ strategy="all", model=m, optimizer=optimizer, accelerator=accelerator
+ )
+ checkpointer.load_latest_full_state(output_dir=Path(args.ckpt_dir))
+ train(
+ model=m,
+ optimizer=optimizer,
+ accelerator=accelerator,
+ checkpointer=checkpointer,
+ sampler=args.sampler,
+ use_dolomite=False,
+ metric_logger=metric_logger,
+ output_dir=args.ckpt_dir,
+ checkpoint_at_epoch=True,
+ effective_batch_size=args.effective_batch_size,
+ last_step=0,
+ num_epochs=2,
+ save_last=True,
+ )
+ torch.distributed.barrier()
+ torch.distributed.destroy_process_group()
+
+
+def run_test(knowledge_data_path, skills_data_path, nnodes, node_rank, nproc_per_node):
+ phase1_model_path = os.path.abspath(
+ os.path.expanduser("~/.cache/instructlab/models/instructlab/granite-7b-lab")
+ )
+ data_output_path = os.path.abspath(
+ os.path.expanduser("~/.local/share/instructlab/internal")
+ )
+
+ phase1_checkpoint_dir = os.path.abspath(
+ os.path.expanduser("~/.local/share/instructlab/phased/phase1/checkpoints")
+ )
+ phase2_checkpoint_dir = os.path.abspath(
+ os.path.expanduser("~/.local/share/instructlab/phased/phase2/checkpoints")
+ )
+ num_phases = 2
+ effective_batch_size = 32
+ data_path = knowledge_data_path
+ ckpt_dir = phase1_checkpoint_dir
+ model_path = phase1_model_path
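+    # phase 1 trains granite-7b-lab on the knowledge dataset; phase 2 continues from the
+    # resulting checkpoint and trains on the skills dataset in its own checkpoint dir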
+ for phase in range(1, num_phases + 1):
+ # override model
+ # override checkpoints dir
+ # override EBS
+ if phase == 2:
+ model_path = os.path.join(phase1_checkpoint_dir, "hf_format", "last_epoch")
+ effective_batch_size = 32
+ data_path = skills_data_path
+ ckpt_dir = phase2_checkpoint_dir
+ print(f"RUNNING PHASE {phase} of {num_phases}")
+
+ dp.process_data(
+ data_output_path=data_output_path,
+ model_path=model_path,
+ data_path=data_path,
+ max_seq_len=4096,
+ num_cpu_procs=16,
+ )
+
+ command = [
+ "torchrun",
+ f"--nnodes={nnodes}",
+ f"--node_rank={node_rank}",
+ f"--nproc_per_node={nproc_per_node}",
+ f"--rdzv_id=123",
+ f"--rdzv_endpoint=127.0.0.1:12222",
+ __file__,
+ f"--effective-batch-size={effective_batch_size}",
+ f"--model-path={model_path}",
+ f"--ckpt-dir={ckpt_dir}",
+ ]
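+        # torchrun re-executes this file, so each spawned worker enters main() below with
+        # RANK/LOCAL_RANK set; the parent process streams the subprocess output into log_path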
+ process = None
+ interrupt: KeyboardInterrupt | Exception | None = None
+ failure = False
+ try:
+ log_path = os.path.abspath(
+ os.path.expanduser(
+ f"~/.local/share/instructlab/checkpoints/full_logs_global{node_rank}.log"
+ )
+ )
+ if not os.path.exists(os.path.dirname(log_path)):
+ os.makedirs(os.path.dirname(log_path), exist_ok=True)
+ process = StreamablePopen(
+ log_path,
+ command,
+ )
+ process.listen()
+ except KeyboardInterrupt as e:
+ print("Training subprocess interrupted by user.")
+ interrupt = e
+ except Exception as e:
+ print("Unexpected exception received during distributed training")
+ interrupt = e
+ finally:
+ if "process" not in locals() or process is None:
+ return
+
+ failure = process.poll() != 0
+ if not failure:
+ print("\033[92mOperation completed successfully! 🎉\033[0m")
+ else:
+ print(
+ "\033[91mTraining subprocess has not exited yet. Sending SIGTERM.\033[0m"
+ )
+
+ process.terminate()
+ try:
+ print("Waiting for process to exit, 60s...")
+ process.wait(timeout=60)
+ except subprocess.TimeoutExpired:
+ print(
+ "\033[91mTraining subprocess did not terminate before timeout, sending SIGKILL.\033[0m"
+ )
+ process.kill()
+
+ if interrupt:
+ raise interrupt
+ if failure:
+ raise RuntimeError(
+ "Suffered a failure during distributed training. Please see the training logs for more context."
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--effective-batch-size", type=int)
+ parser.add_argument("--model-path", type=str)
+ parser.add_argument("--ckpt-dir", type=str)
+ args = parser.parse_args()
+ set_random_seed(42)
+ main(args)
diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index 92248fe7..d205b5d1 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -27,6 +27,20 @@ class DeepSpeedOffloadStrategy(Enum):
NONE = None
+# public API
+class Optimizers(Enum):
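+    """Supported optimizer choices (AdamW, plus the DeepSpeed CPUAdam and FusedAdam variants)."""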
+ ADAMW = "Adamw"
+ CPUAdam = "CPUAdam"
+ FusedAdam = "FusedAdam"
+
+
+# public API
+class ModelTypes(Enum):
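+    """Model wrapper variants selectable when constructing a Model: plain CausalLM, Liger kernels, or Dolomite."""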
+ LIGER = "Liger"
+ CAUSALLM = "Causallm"
+ DOLOMITE = "Dolomite"
+
+
# public API
class DistributedBackend(Enum):
FSDP = "fsdp"
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index d88c81db..be733dda 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -1,18 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
-from copy import deepcopy
from pathlib import Path
import argparse
import datetime
-import math
import os
import re
import subprocess
-import time
-
-# Third Party
-from accelerate import Accelerator
try:
# Third Party
@@ -37,15 +31,7 @@
print("DeepSpeed is not available. Some features may be unavailable.")
# Third Party
-from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-from transformers import (
- AutoConfig,
- AutoModelForCausalLM,
- PreTrainedTokenizer,
- get_scheduler,
-)
+from transformers import AutoConfig
import torch
import torch.distributed
@@ -54,230 +40,28 @@
from instructlab.training.async_logger import AsyncStructuredLogger
# pylint: disable=no-name-in-module
-from instructlab.training.config import DistributedBackend, TorchrunArgs, TrainingArgs
+from instructlab.training.config import (
+ DistributedBackend,
+ ModelTypes,
+ TorchrunArgs,
+ TrainingArgs,
+)
+from instructlab.training.model import Accelerator, Checkpointer, Model, setup_optimizer
from instructlab.training.multipack_sampler import (
find_packing_max_batch_len_and_grad_accum,
)
-from instructlab.training.setup_accelerator import setup_accelerator
from instructlab.training.token_dataset import setup_dataloader, setup_dataset
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import (
StreamablePopen,
- add_noisy_embeddings,
- apply_gradient_checkpointing,
- check_flash_attn_enabled,
check_valid_train_args,
- convert_loss_to_reduce_sum,
- create_lora_config,
- ensure_loadable_dolomite_checkpoint,
- load_latest_full_state,
- prepare_peft_model,
prepare_universal_checkpoint_from_latest,
- save_checkpoint,
- save_hf_format_accelerate,
set_random_seed,
setup_logger,
)
import instructlab.training.data_process as dp
-def setup_optimizer(args, model):
- if args.distributed_training_framework == DistributedBackend.FSDP.value:
- optimizer = torch.optim.AdamW(
- model.parameters(),
- lr=args.learning_rate,
- betas=(0.9, 0.95),
- weight_decay=0.0,
- )
- elif args.distributed_training_framework == DistributedBackend.DEEPSPEED.value:
- # need to use this only when the CPU offload optimizer is enabled
- if args.cpu_offload_optimizer:
- print(
- "\033[33m!!! CPU offload optimizer enabled, using DeepSpeedCPUAdam !!!\033[0m"
- )
- optimizer = DeepSpeedCPUAdam(
- model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
- )
- else:
- optimizer = FusedAdam(
- model.parameters(), lr=args.learning_rate, betas=(0.9, 0.95)
- )
- else:
- raise ValueError(
- f"Sharding framework {args.distributed_training_framework} is not supported."
- )
- return optimizer
-
-
-def setup_model(
- args, tokenizer: PreTrainedTokenizer, train_loader, grad_accum, flash_enabled
-):
- bnb_config = None
- if args.lora_r > 0 and args.lora_quant_bits == 4:
- # Third Party
- from transformers import BitsAndBytesConfig
-
- bnb_config = BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_use_double_quant=True,
- bnb_4bit_compute_dtype=torch.float16, # if not set will throw a warning about slow speeds when training
- )
-
- base_model_args = {
- "pretrained_model_name_or_path": args.model_name_or_path,
- "torch_dtype": torch.bfloat16,
- "quantization_config": bnb_config,
- }
- if flash_enabled:
- base_model_args["attn_implementation"] = "flash_attention_2"
-
- if args.use_dolomite:
- with ensure_loadable_dolomite_checkpoint(
- args.model_name_or_path, args.output_dir
- ) as path:
- base_model_args["pretrained_model_name_or_path"] = path
- base_model_args["use_padding_free_transformer"] = True
- model = GPTDolomiteForCausalLM.from_pretrained(
- **base_model_args,
- )
- elif args.use_liger:
- # TODO(osilkin): we duplicate some checks here because someone may run this script through
- # torchrun directly and not `run_training`. To fix this, we should eventually move everything
- # to using `torch.multiprocessing` and simplify the CLI.
- if args.lora_r > 0:
- raise ValueError(
- "Using LoRA and Liger kernels is not supported. Please use either LoRA or Liger kernels, but not both."
- )
- try:
- # Third Party
- from liger_kernel.transformers import AutoLigerKernelForCausalLM
- except ImportError as e:
- raise ValueError(
- "Liger kernels are not installed. Please install Liger kernels using the following command: pip install liger-kernel"
- ) from e
-
- # NOTE: (jkunstle) we disable fused_linear_cross_entropy, even though it's a default for most of the models with LK support,
- # because reduce_sum_loss requires the logits, and fused_linear_cross_entropy explicitly skips materializing them for
- # performance.
- model = AutoLigerKernelForCausalLM.from_pretrained(
- **base_model_args, cross_entropy=True, fused_linear_cross_entropy=False
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(**base_model_args)
-
- # store the base model args so we can recall them later if saving a LoRA model
- args.base_model_args = base_model_args
-
- if len(tokenizer) > model.config.vocab_size:
- print(
- f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
- )
- model.resize_token_embeddings(
- int(8 * math.ceil(len(tokenizer) / 8.0))
- ) # make the vocab size multiple of 8 for sharding the embedding layer.
-
- # Fix any discrepancy between model and tokenizer
- if (
- model.config.pad_token_id is not None
- and tokenizer.pad_token_id is not None
- and model.config.pad_token_id != tokenizer.pad_token_id
- ):
- print(
- f"WARNING: There is a mismatch between pad token id of model ({model.config.pad_token_id}) and tokenizer({tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
- )
- model.config.pad_token_id = tokenizer.pad_token_id
- if (
- model.config.bos_token_id is not None
- and tokenizer.bos_token_id is not None
- and model.config.bos_token_id != tokenizer.bos_token_id
- ):
- print(
- f"WARNING: There is a mismatch between bos token id of model({model.config.bos_token_id}) and tokenizer({tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
- )
- model.config.bos_token_id = tokenizer.bos_token_id
- if (
- model.config.eos_token_id is not None
- and tokenizer.eos_token_id
- and model.config.eos_token_id != tokenizer.eos_token_id
- ):
- print(
- f"WARNING: There is a mismatch between eos token id of model({model.config.eos_token_id}) and tokenizer({tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
- )
- model.config.eos_token_id = tokenizer.eos_token_id
-
- if "ForCausalLM" not in model.__class__.__name__:
- raise ValueError(
- f"Model class name: {model.__class__.__name__} is not supported."
- )
-
- # ensure the model has any tokens which were added to the tokenizer
- if tokenizer.pad_token_id is not None and model.config.pad_token_id is None:
- model.config.pad_token_id = tokenizer.pad_token_id
- if tokenizer.bos_token_id is not None and model.config.bos_token_id is None:
- model.config.bos_token_id = tokenizer.bos_token_id
- if tokenizer.eos_token_id is not None and model.config.eos_token_id is None:
- model.config.eos_token_id = tokenizer.eos_token_id
-
- model = convert_loss_to_reduce_sum(model, use_dolomite=args.use_dolomite)
- model = add_noisy_embeddings(model, noise_alpha=args.NEFTune_alpha)
-
- # handling of gradient checkpointing
- # it is handled differently for lora and full
- # - with the exception of granite, which handles it
- # in the later stanza
- if args.lora_r > 0:
- lora_config = create_lora_config(model, args)
- model = prepare_peft_model(
- model,
- lora_config,
- args.distributed_training_framework,
- gradient_checkpointing=not args.use_dolomite,
- )
- args.lora_config = lora_config
- elif not args.use_dolomite:
- model.gradient_checkpointing_enable()
-
- # granite gradient checkpointing is handled uniformly
- # for both lora and full here
- if args.use_dolomite:
- block_name = model._no_split_modules[0]
- apply_gradient_checkpointing(
- model,
- block_name=block_name,
- use_reentrant=True, # this should be the HF default mode
- )
-
- if args.lora_r > 0:
-
- def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
-
- accelerator = setup_accelerator(args, model, grad_accum)
- if args.distributed_training_framework == DistributedBackend.FSDP.value:
- model = accelerator.prepare(model)
- optimizer = setup_optimizer(args, model)
-
- lr_scheduler = get_scheduler(
- name=args.lr_scheduler,
- optimizer=optimizer,
- num_warmup_steps=args.num_warmup_steps,
- num_training_steps=args.num_epochs * len(train_loader) // grad_accum,
- )
- model, optimizer, _, lr_scheduler = accelerator.prepare(
- model,
- optimizer,
- deepcopy(train_loader),
- lr_scheduler,
- )
- # Necessary so that Accelerate does not step once per GPU
- # see https://github.com/huggingface/accelerate/blob/127818fc27ebe5cb236357fff59ff1748326d643/src/accelerate/scheduler.py#L69
- lr_scheduler.split_batches = True
- return model, lr_scheduler, optimizer, accelerator
-
-
# this function is to check if the checkpoint provided can be resumed
def maybe_resume_training(args, model):
local_rank = int(os.environ["LOCAL_RANK"])
@@ -338,205 +122,13 @@ def maybe_resume_training(args, model):
return model
-def train(
- args,
- model,
- optimizer,
- lr_scheduler,
- accelerator: Accelerator,
- tokenizer: PreTrainedTokenizer,
- train_loader: DataLoader,
- grad_accum,
- metric_logger,
-):
- model.train()
-
- global_step = 1
- local_rank = int(os.environ["LOCAL_RANK"])
- world_size = int(os.environ["WORLD_SIZE"])
-
- batch_size = args.effective_batch_size // grad_accum
- samples_seen = 0
-
- if hasattr(args, "samples_seen"):
- print(f"\033[93mUpdating 'samples_seen' {args.samples_seen}\033[0m")
- samples_seen = args.samples_seen
-
- if args.save_samples > 0:
- args.save_samples = (args.save_samples // batch_size) * batch_size
- (
- print(f"\033[93mNumber of samples per save: {args.save_samples}\033[0m")
- if local_rank == 0
- else None
- )
-
- if args.save_samples_ds is not None:
- args.save_samples_ds = (args.save_samples_ds // batch_size) * batch_size
- (
- print(
- f"\033[93mNumber of samples per DS save: {args.save_samples_ds}\033[0m"
- )
- if local_rank == 0
- else None
- )
-
- global_grad_norm = None
- for epoch in range(args.current_epoch, args.num_epochs):
- if args.sampler in ("multipack"):
- train_loader.batch_sampler.set_epoch(epoch)
- elif args.sampler in ("distributed"):
- train_loader.sampler.set_epoch(epoch)
- else:
- raise NotADirectoryError
-
- if local_rank == 0:
- inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")
-
- # blast through the batches in the train loader up to the last step within the epoch.
- for batch in train_loader:
- if global_step <= args.last_step:
- # in the case of resuming, last_step > 0
- global_step += 1
- if local_rank == 0:
- inner_pb.update(1)
- continue
- start = time.time()
- num_loss_counted_tokens = float(
- torch.tensor([batch.pop("num_loss_counted_tokens")])
- )
- micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
- if not args.use_dolomite:
- for k in batch:
- batch[k] = batch[k].to(local_rank)
- output = model(
- **batch,
- use_cache=False,
- )
- loss = output.loss
- log_loss = loss.detach().item()
-
- num_loss_counted_tokens, micro_batch_size, log_loss = map(
- float,
- accelerator.reduce(
- torch.tensor(
- [num_loss_counted_tokens, micro_batch_size, log_loss],
- dtype=torch.float32,
- device=accelerator.device,
- ),
- reduction="sum",
- ),
- )
- samples_seen += int(micro_batch_size)
-
- # num_loss_counted_tokens = aggregated_values[0]
- loss = (
- loss / num_loss_counted_tokens * world_size
- ) # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss.
- print(
- f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}"
- )
- accelerator.backward(loss)
-
- if global_step % grad_accum == 0:
- global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad()
-
- if local_rank == 0:
- elapsed_time = time.time() - start
- overall_throughput = args.samples_per_gpu * world_size / elapsed_time
- current_lr = lr_scheduler.get_last_lr()[0]
- cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3)
- cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
- global_grad_norm = (
- model.get_global_grad_norm()
- if hasattr(model, "get_global_grad_norm")
- else global_grad_norm
- )
- global_grad_norm = (
- float(global_grad_norm) if global_grad_norm is not None else None
- )
- # TODO - Bring back weight_norm gather
- # weight_norm = float(
- # model.optimizer.single_partition_of_fp32_groups[0].norm()
- # )
-
- # TODO - Bring back consistent gradnorm and weight_norm logging
- metric_logger.log_sync(
- {
- "epoch": epoch,
- "step": global_step,
- "rank": torch.distributed.get_rank(),
- "overall_throughput": overall_throughput,
- "lr": current_lr,
- "cuda_mem_allocated": cuda_mem_allocated,
- "cuda_malloc_retries": cuda_malloc_retries,
- "num_loss_counted_tokens": int(num_loss_counted_tokens),
- "batch_size": int(micro_batch_size),
- "total_loss": float(log_loss / num_loss_counted_tokens),
- "samples_seen": samples_seen,
- "gradnorm": global_grad_norm,
- "total_samples": len(train_loader.dataset),
- # "weight_norm": weight_norm,
- }
- )
-
- if args.save_samples > 0 and (
- global_step * batch_size % args.save_samples == 0
- ):
- save_checkpoint(
- args=args,
- accelerator=accelerator,
- model=model,
- tokenizer=tokenizer,
- samples_seen=samples_seen,
- is_lora=bool(args.lora_r),
- hf_format=True,
- )
-
- # if (
- # args.save_samples_ds is not None
- # and global_step * batch_size % args.save_samples_ds == 0
- # ):
- # save_model_ds_native(
- # args,
- # model,
- # tokenizer,
- # global_step * args.samples_per_gpu * world_size,
- # )
- global_step += 1
- if local_rank == 0:
- inner_pb.update(1)
- torch.cuda.empty_cache()
- if args.checkpoint_at_epoch:
- save_checkpoint(
- args=args,
- accelerator=accelerator,
- model=model,
- tokenizer=tokenizer,
- samples_seen=samples_seen,
- is_lora=bool(args.lora_r),
- full_state=args.accelerate_full_state_at_epoch,
- hf_format=True,
- epoch=epoch,
- )
-
- if args.save_last:
- save_hf_format_accelerate(
- args,
- model,
- tokenizer,
- accelerator,
- samples_seen,
- is_lora=bool(args.lora_r),
- )
-
-
def main(args):
# Third Party
import yaml
+ # First Party
+ from instructlab.training.train import train
+
if args.distributed_training_framework == "deepspeed" and not FusedAdam:
raise ImportError(
"DeepSpeed was selected but we cannot import the `FusedAdam` optimizer"
@@ -561,7 +153,6 @@ def main(args):
setup_logger(args.log_level)
tokenizer = setup_tokenizer(args.model_name_or_path, args.chat_tmpl_path)
- # device = torch.device("cuda", args.local_rank)
model_conf = AutoConfig.from_pretrained(args.model_name_or_path)
args.model_type = model_conf.model_type
@@ -580,8 +171,14 @@ def main(args):
torch.distributed.all_reduce(tensor)
torch.distributed.barrier()
- flash_enabled = check_flash_attn_enabled(args.disable_flash_attn, args.use_dolomite)
+ flash_enabled = Model.check_flash_attn_enabled(
+ args.disable_flash_attn, args.use_dolomite
+ )
+ # TODO: would like to replace this with either
+ # a) a dataset class
+ # b) a dataloader class
+ # c) a bit of both
dataset = setup_dataset(
args.data_path,
mock=args.mock_data,
@@ -609,6 +206,28 @@ def main(args):
grad_accum = 1
args.sampler = "distributed"
+ # This model class wraps the various AutoModel classes we support
+ # based on model_type, and model_path -> choose auto_model
+ lora_config = None
+
+ if args.lora_r > 0:
+ lora_config = Model.create_lora_config(
+ lora_target_modules=args.lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ lora_r=args.lora_r,
+ )
+ m = Model(
+ model_path=args.model_name_or_path,
+ output_dir=args.output_dir,
+ lora_config=lora_config,
+ distributed_framework=DistributedBackend(args.distributed_training_framework),
+ tokenizer=tokenizer,
+ model_type=ModelTypes(args.model_class),
+ flash_enabled=flash_enabled,
+ noise_alpha=args.NEFTune_alpha,
+ )
+
args.samples_per_gpu = (
args.effective_batch_size // grad_accum // torch.distributed.get_world_size()
)
@@ -661,25 +280,60 @@ def main(args):
"total_samples": len(dataset), # emit the total number of samples
}
)
-
- model, lr_scheduler, optimizer, accelerator = setup_model(
- args, tokenizer, train_loader, grad_accum, flash_enabled
+    # The Accelerator does not need the optimizer to initialize; in fact, the
+    # optimizer must be created AFTER the Accelerator has prepared the model.
+ accelerator = Accelerator(
+ model=m,
+ samples_per_gpu=args.samples_per_gpu,
+ grad_accum=grad_accum,
+ train_loader=train_loader,
+ distributed_framework=DistributedBackend(args.distributed_training_framework),
+ fsdp_sharding_strategy=args.fsdp_sharding_strategy,
+ deepspeed_cpu_offload_optimizer=args.cpu_offload_optimizer,
+ deepspeed_cpu_offload_optimizer_pin_memory=args.cpu_offload_optimizer_pin_memory,
+ deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio,
+ fsdp_cpu_offload_params=args.cpu_offload_params_fsdp,
+ save_samples=args.save_samples,
)
-
- load_latest_full_state(args=args, accelerator=accelerator)
-
+    # The optimizer needs the model after it has been prepared by the Accelerator,
+    # and the Accelerator must then be prepared AGAIN once the optimizer exists.
+ optimizer = setup_optimizer(
+ model=m,
+ cpu_offload=args.cpu_offload_optimizer,
+ name=None, # choose based on backend
+ learning_rate=args.learning_rate,
+ )
+ accelerator.prepare_with_optimizer(
+ optimizer=optimizer,
+ lr_scheduler=args.lr_scheduler,
+ num_epochs=args.num_epochs,
+ num_warmup_steps=args.num_warmup_steps,
+ )
+ # TODO: make this work more seamlessly
+ optimizer = accelerator.optimizer
+ m = accelerator.model
+
+ strategy = "all"
+ if not args.accelerate_full_state_at_epoch:
+ strategy = "hf_format"
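+    # "all" saves both the full accelerate state and HF-format weights at each
+    # checkpoint; "hf_format" saves only the HF-format weights (see Checkpointer).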
+ checkpointer = Checkpointer(
+ strategy=strategy, model=m, optimizer=optimizer, accelerator=accelerator
+ )
+ checkpointer.load_latest_full_state(Path(args.output_dir))
train(
- args,
- model,
- optimizer,
- lr_scheduler,
- accelerator,
- tokenizer,
- train_loader,
- grad_accum,
- metric_logger,
+ model=m,
+ optimizer=optimizer,
+ accelerator=accelerator,
+ checkpointer=checkpointer,
+ sampler=args.sampler,
+ use_dolomite=args.use_dolomite,
+ metric_logger=metric_logger,
+ output_dir=args.output_dir,
+ checkpoint_at_epoch=args.checkpoint_at_epoch,
+ effective_batch_size=args.effective_batch_size,
+ last_step=args.last_step,
+ num_epochs=args.num_epochs,
+ save_last=args.save_last,
)
-
torch.distributed.barrier()
torch.distributed.destroy_process_group()
@@ -876,6 +530,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
# Maybe switch out from argparse to something smarter
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str)
+ parser.add_argument(
+ "--model-class",
+ type=str,
+ default=ModelTypes.LIGER.value,
+ help=f"valid model classes are {ModelTypes.LIGER.value}, {ModelTypes.DOLOMITE.value}, and {ModelTypes.CAUSALLM.value}.",
+ )
parser.add_argument("--data_path", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--num_epochs", type=int, default=1)
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
new file mode 100644
index 00000000..fb396c74
--- /dev/null
+++ b/src/instructlab/training/model.py
@@ -0,0 +1,1081 @@
+# Standard
+from copy import deepcopy
+from pathlib import Path
+from typing import List, Optional, Tuple
+import math
+import os
+import shutil
+import time
+import warnings
+
+# Third Party
+from accelerate import Accelerator as TransformersAccel
+
+try:
+ # Third Party
+ from deepspeed.ops.adam import DeepSpeedCPUAdam
+except ImportError:
+ DeepSpeedCPUAdam = None
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ if __name__ == "__main__" and (not local_rank or local_rank == 0):
+ print(
+ "DeepSpeed CPU Optimizer is not available. Some features may be unavailable."
+ )
+
+try:
+ # Third Party
+ from deepspeed.ops.adam import FusedAdam
+except ImportError:
+ FusedAdam = None
+ local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ if __name__ == "__main__" and (not local_rank or local_rank == 0):
+ print("DeepSpeed is not available. Some features may be unavailable.")
+
+# Third Party
+from instructlab.dolomite.hf_models import export_to_huggingface
+from peft import LoraConfig
+from torch import distributed as dist
+from torch.distributed.fsdp import FullStateDictConfig
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import StateDictType
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizer, get_scheduler
+import torch
+
+# First Party
+from instructlab.training.config import (
+ DeepSpeedOptions,
+ DistributedBackend,
+ ModelTypes,
+ Optimizers,
+)
+
+# Local
+from .utils import log_rank_0, wraps
+
+# mypy: disable_error_code="has-type"
+
+
+class Model:
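+    """Wrapper around the underlying HuggingFace causal LM.
+
+    Handles loading the requested variant (Liger, Dolomite, or plain CausalLM),
+    reconciling special tokens with the tokenizer, and optionally preparing the
+    model for LoRA training. Unknown attributes are forwarded to the wrapped model.
+    """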
+ def __init__(
+ self,
+ model_path: str,
+ output_dir: str,
+ distributed_framework: DistributedBackend,
+ model_type: ModelTypes,
+ noise_alpha: Optional[float],
+ tokenizer: PreTrainedTokenizer,
+ flash_enabled: bool = False,
+ lora_config: Optional[LoraConfig] = None,
+ ):
+ self.lora_config = lora_config
+ self.noise_alpha = noise_alpha
+ self.model_type = model_type
+ self.tokenizer = tokenizer
+ self.distributed_framework = distributed_framework
+ self.base_model_args = {
+ "pretrained_model_name_or_path": model_path,
+ "torch_dtype": torch.bfloat16,
+ }
+
+ if flash_enabled:
+ self.base_model_args["attn_implementation"] = "flash_attention_2"
+
+ # Pick model loader based on type
+ if model_type == ModelTypes.LIGER:
+ try:
+ # Third Party
+ # pylint: disable-next=W0611
+ from liger_kernel.transformers import AutoLigerKernelForCausalLM
+ except ImportError as e:
+ raise ValueError(
+ "Liger kernels are not installed. Please install Liger kernels using the following command: pip install liger-kernel"
+ ) from e
+ self.model = AutoLigerKernelForCausalLM.from_pretrained(
+ **self.base_model_args,
+ cross_entropy=True,
+ fused_linear_cross_entropy=False,
+ )
+ self.model.gradient_checkpointing_enable()
+ elif model_type == ModelTypes.DOLOMITE:
+ # Third Party
+ from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
+
+ # First Party
+ from instructlab.training.utils import (
+ apply_gradient_checkpointing,
+ ensure_loadable_dolomite_checkpoint,
+ )
+
+ with ensure_loadable_dolomite_checkpoint(model_path, output_dir) as path:
+                self.base_model_args["pretrained_model_name_or_path"] = path
+ self.base_model_args["use_padding_free_transformer"] = True
+ self.model = GPTDolomiteForCausalLM.from_pretrained(
+ **self.base_model_args
+ )
+ apply_gradient_checkpointing(
+ model=self.model,
+ block_name=self.model._no_split_modules[0],
+ use_reentrant=True,
+ )
+ elif model_type == ModelTypes.CAUSALLM:
+ # Third Party
+ from transformers import AutoModelForCausalLM
+
+ self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args)
+ self.model.gradient_checkpointing_enable()
+ else:
+ raise AttributeError(
+ f"Invalid Model Type {model_type} valid types are {ModelTypes.LIGER.value}, {ModelTypes.DOLOMITE.value}, and {ModelTypes.CAUSALLM.value}."
+ )
+
+ self.reconcile_tokenizer()
+ if self.lora_config:
+            self.model = self.prepare_peft_model(
+                gradient_checkpointing=(model_type != ModelTypes.DOLOMITE),
+            )
+            if model_type == ModelTypes.DOLOMITE:
+ # pylint: disable=unused-argument
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ self.model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+
+ def train(self, mode=True):
+ """Set the model in training mode.
+
+ Args:
+ mode (bool): Whether to set training mode (True) or evaluation mode (False).
+ """
+ return self.model.train(mode)
+
+ @property
+ def module(self):
+ return getattr(self.model, "module", self.model)
+
+ def parameters(self):
+ return self.model.parameters()
+
+ def update_model(self, new_model):
+ if isinstance(new_model, Model):
+ raise AttributeError("This will cause recursion")
+ self.model = new_model
+
+ def __getattr__(self, name):
+ if name == "model":
+ return super().__getattribute__("model")
+ return getattr(self.model, name)
+
+ def __call__(self, *args, **kwargs):
+ return self.model(*args, **kwargs)
+
+ def get_projection_layer_names(self) -> List[str]:
+ """
+ Given a pretrained model, returns all of the projection layers (matching '_proj')
+ """
+ proj_layers = set(
+ name.split(".")[-1]
+ for name, _ in self.model.named_modules()
+ if name.endswith("_proj")
+ )
+ return list(proj_layers)
+
+ def prepare_peft_model(
+ self,
+ gradient_checkpointing=True,
+ gradient_checkpointing_kwargs={"use_reentrant": True},
+ mixed_precision="bf16",
+ ):
+ # Third Party
+ from peft import (
+ LoraModel,
+ PeftModel,
+ get_peft_model,
+ prepare_model_for_kbit_training,
+ )
+ from trl.trainer.utils import peft_module_casting_to_bf16
+
+ proj_layers = self.get_projection_layer_names()
+ if not self.lora_config.target_modules:
+ print(
+ "WARNING: lora_target_modules not specified. Using all projection layers."
+ )
+ if not proj_layers:
+ raise RuntimeError("No projection layers found in the model.")
+ self.lora_config.target_modules = proj_layers
+ else:
+ requested = set(self.lora_config.target_modules)
+ available = set(proj_layers)
+ missing = requested - available
+ valid = requested & available
+
+ if not valid:
+ raise ValueError(
+ f"None of the requested LoRA target modules exist in the model.\n"
+ f"Requested: {self.lora_config.target_modules}\nAvailable: {proj_layers}"
+ )
+ if missing:
+ print(
+ f"\033[33mWARNING: The following modules were not found in the model: {list(missing)}. "
+ f"Applying LoRA only to: {list(valid)}.\033[0m"
+ )
+ self.lora_config.target_modules = list(valid)
+
+ # if not isinstance(peft_config, PeftConfig):
+ # raise ValueError(
+ # "If you want to use the PeftModel, you need to pass a PeftConfig object, "
+ # f"and you passed a {type(peft_config)}."
+ # )
+
+ if not isinstance(self.model, PeftModel):
+ if getattr(self.model, "is_loaded_in_8bit", False) or getattr(
+ self.model, "is_loaded_in_4bit", False
+ ):
+                prepare_model_kwargs = {
+                    "use_gradient_checkpointing": gradient_checkpointing
+                }
+
+                # if _support_gc_kwargs:
+                prepare_model_kwargs["gradient_checkpointing_kwargs"] = (
+                    gradient_checkpointing_kwargs
+                )
+
+                self.model = prepare_model_for_kbit_training(
+                    self.model, **prepare_model_kwargs
+                )
+
+ elif gradient_checkpointing:
+ # For backward compatibility with older versions of transformers
+ if hasattr(self.model, "enable_input_require_grads"):
+ self.model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument
+ output.requires_grad_(True)
+
+ self.model.get_input_embeddings().register_forward_hook(
+ make_inputs_require_grad
+ )
+
+            if self.distributed_framework == DistributedBackend.FSDP:
+ # FSDP doesn't like `get_peft_model` as it leads to dtype mismatches
+ self.model = LoraModel(self.model, self.lora_config, "default")
+ else:
+ self.model = get_peft_model(self.model, self.lora_config)
+ if mixed_precision == "bf16" and getattr(
+ self.model, "is_loaded_in_4bit", False
+ ):
+ peft_module_casting_to_bf16(self.model)
+
+ return self.model
+
+ @staticmethod
+ def create_lora_config(
+ lora_target_modules: List[str],
+ lora_alpha: Optional[int],
+ lora_dropout: Optional[float],
+ lora_r: int,
+ ):
+        return LoraConfig(
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ r=lora_r,
+ bias="none",
+ task_type="CAUSAL_LM",
+ target_modules=lora_target_modules,
+ )
+
+ @classmethod
+ def setup_liger(
+ cls,
+ model_path: str,
+ output_dir: str,
+ lora_config: LoraConfig,
+ distributed_framework: DistributedBackend,
+ noise_alpha: Optional[float],
+ tokenizer: PreTrainedTokenizer,
+ flash_enabled: bool = False,
+ ):
+ return cls(
+ model_path=model_path,
+ output_dir=output_dir,
+ lora_config=lora_config,
+ distributed_framework=distributed_framework,
+ model_type=ModelTypes.LIGER,
+ noise_alpha=noise_alpha,
+ tokenizer=tokenizer,
+ flash_enabled=flash_enabled,
+ )
+
+ @classmethod
+ def setup_dolomite(
+ cls,
+ model_path: str,
+ output_dir: str,
+ lora_config: LoraConfig,
+ distributed_framework: DistributedBackend,
+ noise_alpha: Optional[float],
+ tokenizer: PreTrainedTokenizer,
+ flash_enabled: bool = False,
+ ):
+ return cls(
+ model_path=model_path,
+ output_dir=output_dir,
+ lora_config=lora_config,
+ distributed_framework=distributed_framework,
+ model_type=ModelTypes.DOLOMITE,
+ noise_alpha=noise_alpha,
+ tokenizer=tokenizer,
+ flash_enabled=flash_enabled,
+ )
+
+ def reconcile_tokenizer(self):
+ if len(self.tokenizer) > self.model.config.vocab_size:
+ print(
+ f"WARNING: tokenizer has {len(self.tokenizer)} tokens but model has {self.model.config.vocab_size} vocab size"
+ )
+ self.model.resize_token_embeddings(
+ int(8 * math.ceil(len(self.tokenizer) / 8.0))
+ ) # make the vocab size multiple of 8 for sharding the embedding layer.
+
+ # Fix any discrepancy between model and tokenizer
+ if (
+ self.model.config.pad_token_id is not None
+ and self.tokenizer.pad_token_id is not None
+ and self.model.config.pad_token_id != self.tokenizer.pad_token_id
+ ):
+ print(
+ f"WARNING: There is a mismatch between pad token id of model ({self.model.config.pad_token_id}) and tokenizer({self.tokenizer.pad_token_id}). Fixing model pad token id to be same as tokenizer's pad token id"
+ )
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
+ if (
+ self.model.config.bos_token_id is not None
+ and self.tokenizer.bos_token_id is not None
+ and self.model.config.bos_token_id != self.tokenizer.bos_token_id
+ ):
+ print(
+ f"WARNING: There is a mismatch between bos token id of model({self.model.config.bos_token_id}) and tokenizer({self.tokenizer.bos_token_id}). Fixing model bos token id to be same as tokenizer's bos token id"
+ )
+ self.model.config.bos_token_id = self.tokenizer.bos_token_id
+ if (
+ self.model.config.eos_token_id is not None
+ and self.tokenizer.eos_token_id
+ and self.model.config.eos_token_id != self.tokenizer.eos_token_id
+ ):
+ print(
+ f"WARNING: There is a mismatch between eos token id of model({self.model.config.eos_token_id}) and tokenizer({self.tokenizer.eos_token_id}). Fixing model eos token id to be same as tokenizer's eos token id"
+ )
+ self.model.config.eos_token_id = self.tokenizer.eos_token_id
+
+ if "ForCausalLM" not in self.model.__class__.__name__:
+ raise ValueError(
+ f"Model class name: {self.model.__class__.__name__} is not supported."
+ )
+
+ # ensure the model has any tokens which were added to the tokenizer
+ if (
+ self.tokenizer.pad_token_id is not None
+ and self.model.config.pad_token_id is None
+ ):
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
+ if (
+ self.tokenizer.bos_token_id is not None
+ and self.model.config.bos_token_id is None
+ ):
+ self.model.config.bos_token_id = self.tokenizer.bos_token_id
+ if (
+ self.tokenizer.eos_token_id is not None
+ and self.model.config.eos_token_id is None
+ ):
+ self.model.config.eos_token_id = self.tokenizer.eos_token_id
+
+ # Local
+ from .utils import add_noisy_embeddings, convert_loss_to_reduce_sum
+
+ self.model = convert_loss_to_reduce_sum(
+ self.model, use_dolomite=(self.model_type == "dolomite")
+ )
+ self.model = add_noisy_embeddings(self.model, noise_alpha=self.noise_alpha)
+
+ @staticmethod
+ def supports_flash_attention(device_id=0):
+ """Check if a GPU supports FlashAttention."""
+ major, minor = torch.cuda.get_device_capability(device_id)
+ # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
+ is_sm8x = major == 8 and minor >= 0
+ is_sm90 = major == 9 and minor == 0
+ dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
+ is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
+ return is_sm8x or is_sm90 or is_compat_amd
+
+ @staticmethod
+ def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
+ """Check if flash attention should be enabled based on configuration.
+
+ Args:
+ disable_flash_attn: Whether flash attention is explicitly disabled
+ use_dolomite: Whether dolomite padding-free transformer is being used
+
+ Returns:
+ bool: Whether flash attention should be enabled
+
+ Raises:
+ RuntimeError: If trying to use flash attention on unsupported hardware
+ or trying to use dolomite without flash attention
+ """
+ if not disable_flash_attn:
+ if Model.supports_flash_attention():
+ return True
+ else:
+ raise RuntimeError(
+ "ERROR: Trying to use Flash Attention on unsupported hardware. Please set disable_flash_attn to True."
+ )
+ elif use_dolomite:
+ raise RuntimeError(
+ "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
+ )
+ return False
+
+
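+# Optimizer selection: an explicit `name` takes precedence; otherwise the choice
+# falls back to the model's distributed backend (FSDP -> torch AdamW, DeepSpeed ->
+# DeepSpeedCPUAdam when the optimizer is offloaded to CPU, FusedAdam otherwise).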
+def setup_optimizer(
+ model: Model,
+ cpu_offload: bool,
+ name: Optimizers | None,
+    learning_rate: float,
+ betas: Tuple[float, float] = (0.9, 0.95),
+) -> torch.optim.Optimizer:
+ """Setup and return an optimizer based on the given parameters.
+
+ Args:
+ model: The model to optimize
+ cpu_offload: Whether to offload optimizer to CPU (for DeepSpeed)
+ name: Optional optimizer name to use
+ learning_rate: Learning rate for the optimizer
+ betas: Beta parameters for Adam optimizers
+
+ Returns:
+ A PyTorch optimizer instance
+ """
+ if name is not None:
+ if name == Optimizers.ADAMW:
+ return AdamW(
+ model.parameters(),
+ lr=learning_rate,
+ betas=betas,
+ weight_decay=0.0,
+ )
+ elif name == Optimizers.CPUAdam:
+ return DeepSpeedCPUAdam(model.parameters(), lr=learning_rate, betas=betas)
+ elif name == Optimizers.FusedAdam:
+ return FusedAdam(model.parameters(), lr=learning_rate, betas=betas)
+ else:
+ raise ValueError(f"Unknown optimizer type: {name}")
+ else:
+ if model.distributed_framework == DistributedBackend.FSDP:
+ return AdamW(model.parameters(), lr=learning_rate, betas=betas)
+ elif model.distributed_framework == DistributedBackend.DEEPSPEED:
+ if cpu_offload:
+ return DeepSpeedCPUAdam(
+ model.parameters(), lr=learning_rate, betas=betas
+ )
+            else:
+                return FusedAdam(model.parameters(), lr=learning_rate, betas=betas)
+        else:
+            raise ValueError(
+                f"Sharding framework {model.distributed_framework} is not supported."
+            )
+
+
+class Accelerator:
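+    """Thin wrapper over `accelerate.Accelerator`.
+
+    Builds the DeepSpeed or FSDP plugin from the supplied options, prepares the
+    wrapped Model at construction time, and later prepares the optimizer,
+    dataloader, and LR scheduler via `prepare_with_optimizer`. Unknown attributes
+    are forwarded to the underlying accelerate object.
+    """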
+ def __init__(
+ self,
+ model: Model,
+ samples_per_gpu: int,
+ grad_accum: int,
+ train_loader: DataLoader,
+ save_samples: int,
+ distributed_framework: DistributedBackend, # dist framework is assoc with Accelerator primarily.
+ fsdp_sharding_strategy: Optional[str] = None,
+ deepspeed_cpu_offload_optimizer: Optional[bool] = False,
+ deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
+ deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
+ fsdp_cpu_offload_params: Optional[bool] = False,
+ ):
+ self.samples_per_gpu = samples_per_gpu
+ self.save_samples = save_samples
+ self.grad_accum = grad_accum
+ self.model = model
+ self.distributed_framework = distributed_framework
+ self.fsdp_sharding_strategy = fsdp_sharding_strategy
+ self.deepspeed_cpu_offload_optimizer = deepspeed_cpu_offload_optimizer
+ self.deepspeed_cpu_offload_optimizer_pin_memory = (
+ deepspeed_cpu_offload_optimizer_pin_memory
+ )
+ self.train_loader = train_loader
+ self.deepspeed_cpu_offload_optimizer_ratio = (
+ deepspeed_cpu_offload_optimizer_ratio
+ )
+ self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
+
+ if self.distributed_framework == DistributedBackend.DEEPSPEED:
+            accel_args = {
+ "deepspeed_plugin": self.get_ds_plugin(
+ world_size=torch.distributed.get_world_size(),
+ samples_per_gpu=samples_per_gpu,
+ grad_accum=grad_accum,
+ opts=DeepSpeedOptions(
+ cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
+ cpu_offload_optimizer_ratio=self.deepspeed_cpu_offload_optimizer_ratio,
+ cpu_offload_optimizer_pin_memory=self.deepspeed_cpu_offload_optimizer_pin_memory,
+ save_samples=save_samples,
+ ),
+ ),
+ }
+ elif self.distributed_framework == DistributedBackend.FSDP:
+ accel_args = {
+ "fsdp_plugin": self.get_fsdp_config(),
+ "mixed_precision": "bf16",
+ }
+ self.accelerator = TransformersAccel(
+ **accel_args,
+ )
+ self.accelerator.even_batches = False
+
+ new_m = self.accelerator.prepare(model.model)
+ self.model.update_model(new_m)
+
+ def prepare_with_optimizer(
+ self,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: str,
+ num_epochs: int,
+ num_warmup_steps: int,
+ ):
+ self.setup_lr_scheduler(
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ num_epochs=num_epochs,
+ num_warmup_steps=num_warmup_steps,
+ )
+ new_m, new_opt, _, self.lr_scheduler = self.accelerator.prepare(
+ self.model.model,
+ optimizer,
+ deepcopy(self.train_loader),
+ self.lr_scheduler,
+ )
+ self.lr_scheduler.split_batches = True
+ self.model.update_model(new_m)
+ self.optimizer = new_opt
+
+ def setup_lr_scheduler(
+ self,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: str,
+ num_epochs: int,
+ num_warmup_steps: int,
+ ):
+ self.lr_scheduler = get_scheduler(
+ name=lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_epochs * len(self.train_loader) // self.grad_accum,
+ )
+
+ def __getattr__(self, name):
+        # Forward anything not found to the underlying accelerate.Accelerator object
+ return getattr(self.accelerator, name)
+
+ def get_fsdp_config(self):
+ # Standard
+ from functools import partial
+
+ # Third Party
+ from accelerate.utils import FullyShardedDataParallelPlugin
+ from peft.utils.other import fsdp_auto_wrap_policy
+ from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
+ from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+ from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+ # First Party
+ from instructlab.training.utils import get_module_class_from_name
+
+ is_lora = self.model.lora_config is not None
+ block_name = self.model._no_split_modules[0]
+
+ wrap_policy = None
+        if is_lora:
+ wrap_policy = fsdp_auto_wrap_policy(self.model)
+ else:
+ wrap_policy = partial(
+ transformer_auto_wrap_policy,
+ transformer_layer_cls={
+ get_module_class_from_name(self.model, block_name),
+ },
+ )
+
+ # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA
+ # We should have this be configurable in the future.
+ prefetch_policy = (
+ BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
+ )
+ fsdp_plugin = FullyShardedDataParallelPlugin(
+ auto_wrap_policy=wrap_policy,
+ limit_all_gathers=True,
+ backward_prefetch=prefetch_policy,
+ sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+ cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+ )
+
+ # `use_orig_params` must be disabled when using LoRA and FSDP together
+ # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
+ if self.model.lora_config is not None:
+ fsdp_plugin.use_orig_params = False
+
+ return fsdp_plugin
+
+ def get_ds_plugin(
+ self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions
+ ):
+ # Third Party
+ from accelerate.utils import DeepSpeedPlugin
+
+ ds_config = {
+ "train_batch_size": samples_per_gpu * world_size * grad_accum,
+ "gradient_accumulation_steps": grad_accum,
+ "train_micro_batch_size_per_gpu": samples_per_gpu,
+ "steps_per_print": 1,
+ "zero_optimization": {
+ "stage": 2,
+ # this option is only supported with DeepSpeed ZeRO stage 3
+ "offload_param": {"device": "none"},
+ "offload_optimizer": {"device": "none"},
+ },
+ "bf16": {"enabled": True},
+ "gradient_clipping": 1.0,
+ "prescale_gradients": False,
+ "wall_clock_breakdown": False,
+ }
+
+ if opts.cpu_offload_optimizer:
+ # this only works when the cpu offload optimizer is enabled
+ ds_config["zero_optimization"]["offload_optimizer"] = {
+ # CPU offloading is the only option available in ZeRO stage 2
+ "device": "cpu",
+ "pin_memory": opts.cpu_offload_optimizer_pin_memory,
+ "ratio": opts.cpu_offload_optimizer_ratio,
+ }
+ ds_plugin = DeepSpeedPlugin(
+ hf_ds_config=ds_config,
+ )
+ return ds_plugin
+
+ @classmethod
+ def setup_deepspeed(
+ cls,
+ model: Model,
+ samples_per_gpu: int,
+ grad_accum: int,
+ train_loader: DataLoader,
+ deepspeed_cpu_offload_optimizer: Optional[bool],
+ deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool],
+ deepspeed_cpu_offload_optimizer_ratio: float,
+ save_samples: int,
+ ):
+ return cls(
+ model=model,
+ grad_accum=grad_accum,
+ train_loader=train_loader,
+ distributed_framework=DistributedBackend.DEEPSPEED,
+ samples_per_gpu=samples_per_gpu,
+ deepspeed_cpu_offload_optimizer=deepspeed_cpu_offload_optimizer,
+ deepspeed_cpu_offload_optimizer_pin_memory=deepspeed_cpu_offload_optimizer_pin_memory,
+ deepspeed_cpu_offload_optimizer_ratio=deepspeed_cpu_offload_optimizer_ratio,
+ save_samples=save_samples,
+ )
+
+ @classmethod
+ def setup_fsdp(
+ cls,
+ model: Model,
+ samples_per_gpu: int,
+ grad_accum: int,
+ train_loader: DataLoader,
+ fsdp_sharding_strategy: Optional[str],
+ fsdp_cpu_offload_params: bool,
+ save_samples: int,
+ ):
+ return cls(
+ model=model,
+ grad_accum=grad_accum,
+ train_loader=train_loader,
+ distributed_framework=DistributedBackend.FSDP,
+ samples_per_gpu=samples_per_gpu,
+ fsdp_sharding_strategy=fsdp_sharding_strategy,
+ fsdp_cpu_offload_params=fsdp_cpu_offload_params,
+ save_samples=save_samples,
+ )
+
+
+class Checkpointer:
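+    """Dispatches checkpoint saving based on the configured strategy:
+    "full_state" saves the full accelerate state, "hf_format" saves HF-format
+    weights, "all" saves both, and any other value disables checkpointing.
+    """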
+ def __init__(
+ self,
+ model: Model,
+ optimizer: torch.optim.Optimizer,
+ accelerator: Accelerator,
+ strategy="all",
+ ):
+ self.strategy = strategy.lower()
+ self.model = model
+ self.optimizer = optimizer
+ self.accelerator = accelerator
+
+ # Map strategies to internal methods
+ self._checkpoint_fn = {
+ "full_state": self.save_full_state,
+ "hf_format": self.save_hf_format_accelerate,
+ "all": self.save_all_checkpoints,
+ }.get(self.strategy, self._no_checkpoint)
+
+ def checkpoint(self, *args, **kwargs):
+ # Calls the method chosen at init
+ return self._checkpoint_fn(*args, **kwargs)
+
+ # pylint: disable=unused-argument
+ def _no_checkpoint(self, *args, **kwargs):
+ print("[None] Skipping checkpointing.")
+
+ # pylint: disable=unused-argument
+ def save_fsdp_lora_model(
+ self,
+ output_dir: Path,
+ **kwargs,
+ ):
+ """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
+ model with the trained LoRA adapters merged into the copy.
+
+ This function creates a full copy of the model being trained and stores it in CPU memory.
+ If encountering OOM errors on CPU, this is likely a culprit.
+
+        Args:
+            output_dir (Path): Directory where the merged model is saved.
+        """
+ # Third Party
+ from peft import LoraModel
+
+        if self.accelerator.distributed_framework != DistributedBackend.FSDP:
+ raise RuntimeError(
+ "`save_fsdp_lora_model` was called when FSDP was not being used."
+ )
+ if not wraps(self.model, FSDP):
+ raise RuntimeError(
+ "`save_fsdp_lora_model` was called but provided model is not an FSDP model."
+ )
+ if not wraps(self.model, LoraModel):
+ raise RuntimeError(
+ "`save_fsdp_lora_model` was called but provided model is not a LoRA model."
+ )
+
+ # okay now that validation is out of the way, we are free to implement saving
+ sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+ with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, sd_config):
+ state = self.model.state_dict()
+
+ # When training a LoRA with FSDP and Accelerate, you cannot directly merge the adapters into
+ # the model wrapped by FSDP. To get around this limitation, we get a copy of the state dict
+ # create an identical model on CPU, load the state dict into the CPU model, merge the adapters
+ # and save the model to disk.
+ if self.accelerator.is_main_process:
+ # Third Party
+ from transformers import AutoModelForCausalLM
+
+ # remove device_map from args list so we can load the model on CPU
+ old_device_map = self.model.base_model_args.pop("device_map", None)
+ model_copy = AutoModelForCausalLM.from_pretrained(
+ **self.model.base_model_args, device_map="cpu"
+ )
+ model_copy = LoraModel(model_copy, self.model.lora_config, "default")
+ model_copy.load_state_dict(state)
+ model_copy.merge_and_unload(progressbar=True)
+ model_copy.save_pretrained(output_dir, safe_serialization=True)
+ self.model.config.to_json_file(f"{output_dir}/config.json")
+ self.model.tokenizer.save_pretrained(output_dir)
+ del model_copy
+ if old_device_map:
+ # return the previous device_map so it can be used later on if needed
+ self.model.base_model_args["device_map"] = old_device_map
+
+ dist.barrier()
+
+ # pylint: disable=unused-argument
+ def save_full_state(
+ self,
+ output_dir,
+ epoch: int,
+ samples_seen: int,
+ **kwargs,
+ ):
+ """
+ Saves model, optimizer, and lr_scheduler state.
+ TODO: save model config - decided not to do this.
+ TODO: save tokenizer - decided not to do this.
+ TODO: handle LoRA
+ TODO: handle granite
+ """
+ if self.model.lora_config is not None:
+ raise NotImplementedError("Can't save full state for LoRA at the moment.")
+
+ # if args.is_granite:
+ # raise NotImplementedError("Can't save full state for Granite models yet.")
+
+ output_dir = Path(output_dir) / "full_state" / f"epoch_{epoch}"
+ log_rank_0(
+ f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True
+ )
+
+ # patch FSDP state dict method so it works correctly.
+ def _get_state_dict_patched(model, unwrap=False):
+ return get_state_dict_unpatched(model, unwrap=unwrap)
+
+        if self.accelerator.distributed_framework == DistributedBackend.FSDP:
+ get_state_dict_unpatched = self.accelerator.get_state_dict
+ self.accelerator.get_state_dict = _get_state_dict_patched
+
+ self.accelerator.save_state(
+ output_dir=output_dir,
+ # max_shard_size="5GB",
+ # safe_serialization=True,
+ )
+
+ # save metadata file for current training status
+ if self.accelerator.is_main_process:
+ # TODO: should we set the global_step here rather than calculating global_step
+ # based on samples_seen?
+ metadata = {"current_epoch": epoch, "samples_seen": samples_seen}
+ torch.save(metadata, output_dir / "training_metadata.json")
+ log_rank_0(
+ f"\033[93mSaving training state: {metadata}\033[0m", to_print=True
+ )
+
+ log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True)
+
+ # cleanup
+        if self.accelerator.distributed_framework == DistributedBackend.FSDP:
+ self.accelerator.get_state_dict = get_state_dict_unpatched
+
+ # pylint: disable=unused-argument
+ def save_hf_format_accelerate(
+ self,
+ output_dir,
+ epoch: int,
+ samples_seen: int,
+ last_epoch: bool = False,
+ **kwargs,
+ ):
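+        """Save the model in HuggingFace format under
+        `output_dir/hf_format/last_epoch` (final save) or
+        `output_dir/hf_format/samples_{samples_seen}` otherwise; Dolomite
+        checkpoints are written to a temporary directory first and then exported
+        to HF format.
+        """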
+ # Standard
+ from tempfile import TemporaryDirectory
+
+ # Build the subdirectory name
+ subdir = "last_epoch" if last_epoch else f"samples_{samples_seen}"
+
+ log_rank_0(
+ f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
+ to_print=True,
+ )
+ start = time.time()
+
+ if self.model.model_type in ("gpt_megatron", "gpt_dolomite"):
+ convert_dolomite = False
+ else:
+ convert_dolomite = True
+
+ # Build the final output directory path
+ final_output_dir = Path(output_dir) / "hf_format" / subdir
+
+        if self.model.model_type == ModelTypes.DOLOMITE and convert_dolomite:
+ tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with
+ output_dir = Path(tmpdir.name)
+ else:
+ output_dir = final_output_dir
+
+ CONFIG_NAME = "config.json"
+ output_config_file = output_dir / CONFIG_NAME
+
+ # XXX(osilkin): LoRA + FSDP requires a different saving path than the others
+ # so we set this variable and use it to avoid those paths further down.
+ is_fsdp_lora = (
+ self.model.lora_config is not None
+            and self.accelerator.distributed_framework == DistributedBackend.FSDP
+ )
+ if is_fsdp_lora:
+ self.save_fsdp_lora_model(
+ model=self.model,
+ accelerator=self.accelerator,
+ output_dir=output_dir,
+ )
+
+ get_state_dict_unpatched = self.accelerator.get_state_dict
+
+ def _get_state_dict_patched(model, unwrap=False):
+ return get_state_dict_unpatched(model, unwrap=unwrap)
+
+ self.accelerator.get_state_dict = _get_state_dict_patched
+
+ if not is_fsdp_lora and self.accelerator.is_main_process:
+ if self.model.lora_config is not None:
+ self.model.module.merge_adapter()
+ model_state = self.model.module.state_dict()
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ if not self.model.module.config.architectures and convert_dolomite:
+ arch_added = False
+ if self.model.model_type == "llama":
+ self.model.module.config.architectures = ["LlamaForCausalLM"]
+ arch_added = True
+ elif self.model.model_type == "granite":
+ self.model.module.config.architectures = ["GraniteForCausalLM"]
+ arch_added = True
+ if arch_added:
+ warnings.warn(
+ f"Adding architectures to ckpt: {self.model.module.config.architectures}",
+ )
+ else:
+ warnings.warn(
+                        "Converting from dolomite, but no architecture field was added to config.json",
+ )
+ self.model.module.config.to_json_file(output_config_file)
+ self.model.tokenizer.save_pretrained(output_dir)
+
+ if self.model.lora_config is not None:
+ self.save_dict_accelerate(
+ self.accelerator,
+ model_state,
+ save_directory=output_dir,
+ max_shard_size="5GB",
+ safe_serialization=True,
+ )
+ self.model.module.unmerge_adapter()
+
+ if self.model.lora_config is None:
+ self.accelerator.save_model(
+ self.model,
+ save_directory=output_dir,
+ max_shard_size="5GB",
+ safe_serialization=True,
+ )
+
+ if (
+            self.model.model_type == ModelTypes.DOLOMITE
+ and convert_dolomite
+ and self.accelerator.is_main_process
+ ):
+            # export doesn't like the directory to already exist
+ if final_output_dir.exists():
+ shutil.rmtree(final_output_dir)
+ export_to_huggingface(
+ pretrained_model_name_or_path=tmpdir.name,
+ save_path=final_output_dir,
+ model_type=self.model.model_type,
+ )
+ tmpdir.cleanup()
+
+ log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True)
+ log_rank_0(f"saving took {time.time() - start} seconds")
+ dist.barrier()
+
+ self.accelerator.get_state_dict = get_state_dict_unpatched
+
+ def save_dict_accelerate(
+ self,
+ accelerator: Accelerator,
+ state_to_save,
+ save_directory,
+ max_shard_size="5GB",
+ safe_serialization=True,
+ ):
+ old_get_state = accelerator.get_state_dict
+ accelerator.get_state_dict = self._copy_no_lora_dict
+
+ def skip_precheck_loops():
+ return []
+
+ # The save model does a loop over modules and params in order to determine how to get state dict. Since we already have the state dict directly, we want to bypass those checks.
+ state_to_save.modules = skip_precheck_loops
+ state_to_save.parameters = skip_precheck_loops
+
+ accelerator.save_model(
+ state_to_save,
+ save_directory=save_directory,
+ max_shard_size=max_shard_size,
+ safe_serialization=safe_serialization,
+ )
+
+ accelerator.get_state_dict = old_get_state
+
+ def _copy_no_lora_dict(self, state_dict):
+ # Standard
+ from collections import OrderedDict
+
+ cleaned_state_dict = OrderedDict()
+ for param_tensor in state_dict:
+            if "lora" not in param_tensor:
+ cleaned_state_dict[
+ param_tensor.replace(".base_layer", "").replace(
+ "basemodel.model.", ""
+ )
+ ] = deepcopy(state_dict[param_tensor]).cpu()
+ return cleaned_state_dict
+
+ def load_latest_full_state(self, output_dir: Path) -> None:
+ """Loads accelerator state from most recently saved checkpoint
+ in `output_dir/full_state`.
+
+ Args:
+ output_dir: Base output directory containing the full_state subdirectory
+ """
+ full_state_dir = output_dir / "full_state"
+
+ if not full_state_dir.is_dir():
+ return
+
+        # pick the checkpoint with the highest epoch number by splitting the
+        # "epoch_N" directory name on "_" and comparing the trailing number
+ checkpoint_list = sorted(
+ list(full_state_dir.iterdir()),
+ reverse=True,
+ key=lambda x: int(str(x).rsplit("_", maxsplit=1)[-1]),
+ )
+
+ if len(checkpoint_list) == 0:
+ log_rank_0(
+ f"\033[93mNo checkpoints to load from: {full_state_dir}\033[0m",
+ to_print=True,
+ )
+ return
+
+ latest_checkpoint = checkpoint_list[0]
+ log_rank_0(
+ f"\033[93mLoading checkpoint from: {latest_checkpoint}\033[0m",
+ to_print=True,
+ )
+ self.accelerator.load_state(latest_checkpoint)
+
+ def save_all_checkpoints(
+ self,
+ output_dir,
+ epoch: int,
+ samples_seen: int,
+ last_epoch: bool = False,
+ ):
+ self.save_hf_format_accelerate(
+ output_dir=output_dir,
+ epoch=epoch,
+ samples_seen=samples_seen,
+ last_epoch=last_epoch,
+ )
+ self.save_full_state(
+ output_dir=output_dir, epoch=epoch, samples_seen=samples_seen
+ )
diff --git a/src/instructlab/training/train.py b/src/instructlab/training/train.py
new file mode 100644
index 00000000..064da890
--- /dev/null
+++ b/src/instructlab/training/train.py
@@ -0,0 +1,238 @@
+# Standard
+from typing import List
+import os
+import time
+
+# Third Party
+from pydantic import BaseModel
+from tqdm import tqdm
+import torch
+
+# First Party
+from instructlab.training.async_logger import AsyncStructuredLogger
+from instructlab.training.model import Accelerator, Checkpointer, Model
+
+
+class Metrics(BaseModel):
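+    """Per-step training metrics collected on rank 0 and returned to the caller."""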
+ samples_seen: int
+ total_loss: float
+ batch_size: int
+ num_loss_counted_tokens: int
+ global_grad_norm: float | None = None
+ total_samples: int
+ overall_throughput: float
+ current_lr: float
+
+
+def train(
+ model: Model,
+ optimizer: torch.optim.Optimizer,
+ accelerator: Accelerator,
+ metric_logger: AsyncStructuredLogger,
+ checkpointer: Checkpointer,
+ effective_batch_size: int,
+ num_epochs: int,
+ last_step: int,
+ checkpoint_at_epoch: bool,
+ output_dir,
+ use_dolomite: bool,
+ save_last: bool,
+ sampler: str,
+):
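+    """Run `num_epochs` epochs over `accelerator.train_loader` via `train_epoch`,
+    then optionally write a final HF-format checkpoint when `save_last` is set.
+    """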
+ model.train()
+ global_step = 1
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ batch_size = effective_batch_size // accelerator.grad_accum
+ samples_seen = 0
+
+    if accelerator.save_samples > 0:
+        accelerator.save_samples = (
+            accelerator.save_samples // batch_size
+        ) * batch_size
+        if local_rank == 0:
+            print(
+                f"\033[93mNumber of samples per save: {accelerator.save_samples}\033[0m"
+            )
+ all_metrics: List[Metrics] = []
+ for epoch in range(num_epochs):
+ # TODO, implement better metrics gathering with returned `List[Metrics]` type
+        samples_seen, global_step, new_metrics = train_epoch(
+ epoch_number=epoch,
+ samples_seen=samples_seen,
+ local_rank=local_rank,
+ global_step=global_step,
+ last_step=last_step,
+ world_size=world_size,
+ batch_size=batch_size,
+ samples_per_gpu=accelerator.samples_per_gpu,
+ checkpoint_at_epoch=checkpoint_at_epoch,
+ output_dir=output_dir,
+ use_dolomite=use_dolomite,
+ checkpointer=checkpointer,
+ accelerator=accelerator,
+ optimizer=optimizer,
+ model=model,
+ sampler=sampler,
+ metric_logger=metric_logger,
+ )
+ all_metrics = all_metrics + new_metrics
+
+ if save_last:
+ checkpointer.save_hf_format_accelerate(
+ output_dir=output_dir,
+ epoch=num_epochs,
+ samples_seen=samples_seen,
+ last_epoch=True,
+ )
+
+
+def train_epoch(
+ epoch_number: int,
+ samples_seen: int,
+ local_rank: int,
+ global_step: int,
+ last_step: int,
+ world_size: int,
+ batch_size: int,
+ samples_per_gpu: int,
+ checkpoint_at_epoch: bool,
+ output_dir: str,
+ sampler: str,
+ checkpointer: Checkpointer,
+ model: Model,
+ optimizer: torch.optim.Optimizer,
+ accelerator: Accelerator,
+ use_dolomite: bool,
+ metric_logger: AsyncStructuredLogger,
+) -> tuple[int, int, List[Metrics]]:
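+    """Train a single epoch, skipping batches up to `last_step` when resuming.
+
+    The loss is normalized by the number of non-padding tokens summed across
+    ranks, the optimizer steps every `accelerator.grad_accum` batches, and rank 0
+    collects a `Metrics` entry per step. Returns the updated samples_seen,
+    global_step, and the collected metrics.
+    """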
+ all_metrics: List[Metrics] = []
+ global_grad_norm = None
+    if sampler == "multipack":
+        accelerator.train_loader.batch_sampler.set_epoch(epoch_number)
+    elif sampler == "distributed":
+        accelerator.train_loader.sampler.set_epoch(epoch_number)
+ else:
+ raise AttributeError(
+ f"Sampler {sampler} is invalid. Valid samplers are multipack and distributed."
+ )
+ if local_rank == 0:
+ inner_pb = tqdm(
+ range(len(accelerator.train_loader)), desc=f"Epoch {epoch_number}"
+ )
+
+ # blast through the batches in the train loader up to the last step within the epoch.
+ for batch in accelerator.train_loader:
+ if global_step <= last_step:
+ # in the case of resuming, last_step > 0
+ global_step += 1
+ if local_rank == 0:
+ inner_pb.update(1)
+ continue
+ start = time.time()
+ num_loss_counted_tokens = float(
+ torch.tensor([batch.pop("num_loss_counted_tokens")])
+ )
+ micro_batch_size = float(torch.tensor([batch.pop("num_samples")]))
+ if not use_dolomite:
+ for k in batch:
+ batch[k] = batch[k].to(local_rank)
+ output = model(
+ **batch,
+ use_cache=False,
+ )
+ loss = output.loss
+ log_loss = loss.detach().item()
+ num_loss_counted_tokens, micro_batch_size, log_loss = map(
+ float,
+ accelerator.reduce(
+ torch.tensor(
+ [num_loss_counted_tokens, micro_batch_size, log_loss],
+ dtype=torch.float32,
+ device=accelerator.device,
+ ),
+ reduction="sum",
+ ),
+ )
+ samples_seen += int(micro_batch_size)
+ loss = (
+ loss / num_loss_counted_tokens * world_size
+ ) # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss.
+ accelerator.backward(loss)
+
+ if global_step % accelerator.grad_accum == 0:
+ global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ accelerator.lr_scheduler.step()
+ optimizer.zero_grad()
+
+ if local_rank == 0:
+ elapsed_time = time.time() - start
+ overall_throughput = samples_per_gpu * world_size / elapsed_time
+ current_lr = accelerator.lr_scheduler.get_last_lr()[0]
+ global_grad_norm = (
+ model.get_global_grad_norm()
+ if hasattr(model, "get_global_grad_norm")
+ else global_grad_norm
+ )
+ global_grad_norm = (
+ float(global_grad_norm) if global_grad_norm is not None else None
+ )
+ # TODO - Bring back weight_norm gather
+ # weight_norm = float(
+ # model.optimizer.single_partition_of_fp32_groups[0].norm()
+ # )
+ metrics = Metrics(
+ samples_seen=samples_seen,
+ total_loss=float(log_loss / num_loss_counted_tokens),
+ batch_size=int(micro_batch_size),
+ num_loss_counted_tokens=int(num_loss_counted_tokens),
+ global_grad_norm=global_grad_norm,
+ total_samples=len(accelerator.train_loader.dataset),
+ overall_throughput=overall_throughput,
+ current_lr=current_lr,
+ )
+ all_metrics.append(metrics)
+ metric_logger.log_sync(
+ {
+ "epoch": epoch_number,
+ "step": global_step,
+ "rank": torch.distributed.get_rank(),
+ "overall_throughput": metrics.overall_throughput,
+ "lr": metrics.current_lr,
+ "cuda_mem_allocated": torch.cuda.memory_allocated() / (1024**3),
+ "cuda_malloc_retries": torch.cuda.memory_stats()[
+ "num_alloc_retries"
+ ],
+ "num_loss_counted_tokens": metrics.num_loss_counted_tokens,
+ "batch_size": metrics.batch_size,
+ "total_loss": metrics.total_loss,
+ "samples_seen": metrics.samples_seen,
+ "gradnorm": metrics.global_grad_norm,
+ "total_samples": len(accelerator.train_loader.dataset),
+ # "weight_norm": weight_norm,
+ }
+ )
+
+ if accelerator.save_samples > 0 and (
+ global_step * batch_size % accelerator.save_samples == 0
+ ):
+ checkpointer.checkpoint(
+ output_dir=output_dir,
+ epoch=epoch_number,
+ samples_seen=samples_seen,
+ )
+
+ global_step += 1
+ if local_rank == 0:
+ inner_pb.update(1)
+ torch.cuda.empty_cache()
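+    # Optionally write one final checkpoint at the end of the epoch.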
+ if checkpoint_at_epoch:
+ checkpointer.checkpoint(
+ output_dir=output_dir,
+ epoch=epoch_number,
+ samples_seen=samples_seen,
+ )
+ return samples_seen, all_metrics
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index cd8034b9..36f6c6db 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -1,13 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
-from argparse import Namespace
-from collections import OrderedDict
from contextlib import contextmanager
-from copy import deepcopy
from functools import partial
from pathlib import Path
-from tempfile import TemporaryDirectory
from typing import Any, List, Optional, Tuple
import importlib
import inspect
@@ -22,12 +18,7 @@
# Third Party
# pylint: disable=no-name-in-module
-from accelerate import Accelerator, DistributedType
-from instructlab.dolomite.hf_models import (
- GPTDolomiteConfig,
- export_to_huggingface,
- import_from_huggingface,
-)
+from instructlab.dolomite.hf_models import GPTDolomiteConfig, import_from_huggingface
from rich.logging import RichHandler
from torch import distributed as dist
from torch import nn
@@ -37,10 +28,6 @@
apply_activation_checkpointing,
checkpoint_wrapper,
)
-from torch.distributed.fsdp import FullStateDictConfig
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import StateDictType
-from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
import numpy as np
import torch
import torch.nn.functional as F
@@ -99,20 +86,6 @@ def check_valid_train_args(train_args: TrainingArgs):
"Quantization is not supported when training LoRA models with FSDP. For quantized LoRA training, please switch to DeepSpeed."
)
- if check_flash_attn_enabled(train_args.disable_flash_attn, train_args.use_dolomite):
- # verify that the flash_attn package is actually installed
- try:
- # pylint: disable=unused-import
- # Third Party
- import flash_attn
- except ImportError as exc:
- raise ImportError(
- "Flash attention is enabled but flash_attn is not installed. You can resolve this in the following ways:\n"
- "1. Ensure the CUDA/ROCM version of the training library is installed via: `pip install instructlab-training[cuda]` or `pip install instructlab-training[rocm]`\n"
- "2. Install flash_attn manually via: `pip install flash-attn --no-build-isolation`\n"
- "3. Disable flash attention by setting `disable_flash_attn=True` in your training arguments\n"
- ) from exc
-
# liger checks
if train_args.lora and train_args.lora.rank > 0 and train_args.use_liger:
raise ValueError(
@@ -206,34 +179,6 @@ def listen(self):
break
-def supports_flash_attention(device_id=0):
- """Check if a GPU supports FlashAttention."""
- major, minor = torch.cuda.get_device_capability(device_id)
- # Check if the GPU architecture is Ampere (SM 8.x) or newer (SM 9.0)
- is_sm8x = major == 8 and minor >= 0
- is_sm90 = major == 9 and minor == 0
- dev_name = torch.cuda.get_device_properties(device_id).gcnArchName.split(":")[0]
- is_compat_amd = dev_name in ("gfx90a", "gfx940", "gfx941", "gfx942")
- return is_sm8x or is_sm90 or is_compat_amd
-
-
-def check_flash_attn_enabled(disable_flash_attn: bool, use_dolomite: bool) -> bool:
- if not disable_flash_attn:
- if supports_flash_attention():
- flash_enabled = True
- else:
- raise RuntimeError(
- "ERROR: Trying to use Flash Attention on unsupported hardware. Please set disable_flash_attn to True."
- )
- elif use_dolomite:
- raise RuntimeError(
- "ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
- )
- else:
- flash_enabled = False
- return flash_enabled
-
-
def make_collate_fn(
pad_token_id, use_dolomite=False, flash_enabled=True, max_batch_len=60000
):
@@ -487,173 +432,6 @@ def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool:
return False
-def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig":
- # if lora
- # Third Party
- from peft import LoraConfig
-
- # ensure we select only the modules that exist in the model
- proj_layers = get_projection_layer_names(model)
- if not args.lora_target_modules:
- print(
- f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
- )
- if not proj_layers:
- raise RuntimeError("could not find any projection layers in the model")
- args.__dict__["lora_target_modules"] = proj_layers
- else:
- # when the user specifies the module, we should verify that they align with what's in the model
- lora_target_modules_set = set(args.lora_target_modules)
- diff = lora_target_modules_set - set(proj_layers)
- layers_to_target = lora_target_modules_set - diff
- if len(diff) == len(args.lora_target_modules):
- raise ValueError(
- f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
- )
- if diff:
- print(
- f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
- )
- args.__dict__["lora_target_modules"] = list(layers_to_target)
-
- return LoraConfig(
- lora_alpha=args.lora_alpha,
- lora_dropout=args.lora_dropout,
- r=args.lora_r,
- bias="none",
- task_type="CAUSAL_LM",
- target_modules=args.lora_target_modules,
- )
-
-
-def save_fsdp_lora_model(
- args: Namespace,
- model: FSDP,
- tokenizer: PreTrainedTokenizer,
- accelerator: Accelerator,
- output_dir: Path,
-):
- """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
- model with the trained LoRA adapters merged into the copy.
-
- This function creates a full copy of the model being trained and stores it in CPU memory.
- If encountering OOM errors on CPU, this is likely a culprit.
-
- Args:
- args (Namespace): Args received by the ArgumentParser.
- model (FSDP): FSDP model as prepared by `accelerate.Accelerator`
- accelerator (Accelerator): The given accelerator object.
- """
- # Third Party
- from peft import LoraConfig, LoraModel
-
- if accelerator.distributed_type != DistributedType.FSDP:
- raise RuntimeError(
- "`save_fsdp_lora_model` was called when FSDP was not being used."
- )
- if not wraps(model, FSDP):
- raise RuntimeError(
- "`save_fsdp_lora_model` was called but provided model is not an FSDP model."
- )
- if not wraps(model, LoraModel):
- raise RuntimeError(
- "`save_fsdp_lora_model` was called but provided model is not a LoRA model."
- )
-
- # okay now that validation is out of the way, we are free to implement saving
- lora_conf: LoraConfig = args.lora_config
- sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
- with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config):
- state = model.state_dict()
-
- # When training a LoRA with FSDP and Accelerate, you cannot directly merge the adapters into
- # the model wrapped by FSDP. To get around this limitation, we get a copy of the state dict
- # create an identical model on CPU, load the state dict into the CPU model, merge the adapters
- # and save the model to disk.
- if accelerator.is_main_process:
- # remove device_map from args list so we can load the model on CPU
- old_device_map = args.base_model_args.pop("device_map", None)
- model_copy = AutoModelForCausalLM.from_pretrained(
- **args.base_model_args, device_map="cpu"
- )
- model_copy = LoraModel(model_copy, lora_conf, "default")
- model_copy.load_state_dict(state)
- model_copy.merge_and_unload(progressbar=True)
- model_copy.save_pretrained(output_dir, safe_serialization=True)
- model.config.to_json_file(f"{output_dir}/config.json")
- tokenizer.save_pretrained(output_dir)
- del model_copy
- if old_device_map:
- # return the previous device_map so it can be used later on if needed
- args.base_model_args["device_map"] = old_device_map
-
- dist.barrier()
-
-
-def prepare_peft_model(
- model: PreTrainedModel,
- peft_config,
- distributed_backend: str,
- gradient_checkpointing=True,
- gradient_checkpointing_kwargs={"use_reentrant": True},
- mixed_precision="bf16",
-):
- # will guard this
- # Third Party
- from peft import (
- LoraModel,
- PeftConfig,
- PeftModel,
- get_peft_model,
- prepare_model_for_kbit_training,
- )
- from trl.trainer.utils import peft_module_casting_to_bf16
-
- if not isinstance(peft_config, PeftConfig):
- raise ValueError(
- "If you want to use the PeftModel, you need to pass a PeftConfig object, "
- f"and you passed a {type(peft_config)}."
- )
-
- if not isinstance(model, PeftModel):
- if getattr(model, "is_loaded_in_8bit", False) or getattr(
- model, "is_loaded_in_4bit", False
- ):
- preprare_model_kwargs = {
- "use_gradient_checkpointing": gradient_checkpointing
- }
-
- # if _support_gc_kwargs:
- preprare_model_kwargs["gradient_checkpointing_kwargs"] = (
- gradient_checkpointing_kwargs
- )
-
- model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
-
- elif gradient_checkpointing:
- # For backward compatibility with older versions of transformers
- if hasattr(model, "enable_input_require_grads"):
- model.enable_input_require_grads()
- else:
-
- def make_inputs_require_grad(module, input, output): # pylint: disable=unused-argument
- output.requires_grad_(True)
-
- model.get_input_embeddings().register_forward_hook(
- make_inputs_require_grad
- )
-
- if distributed_backend == DistributedBackend.FSDP.value:
- # FSDP doesn't like `get_peft_model` as it leads to dtype mismatches
- model = LoraModel(model, peft_config, "default")
- else:
- model = get_peft_model(model, peft_config)
- if mixed_precision == "bf16" and getattr(model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(model)
-
- return model
-
-
def prepare_universal_checkpoint_from_latest(output_dir):
"""Populate the universal checkpoint in output_dir/step_folder
- 1. read output_dir/latest to get step_folder
@@ -916,328 +694,9 @@ def log_rank_0(msg, include_caller=False, rank=None, to_print=False):
# print(msg)
-def _copy_no_lora_dict(state_dict):
- cleaned_state_dict = OrderedDict()
- for param_tensor in state_dict:
- if not "lora" in param_tensor:
- cleaned_state_dict[
- param_tensor.replace(".base_layer", "").replace("base_model.model.", "")
- ] = deepcopy(state_dict[param_tensor]).cpu()
- return cleaned_state_dict
-
-
-def save_dict_accelerate(
- accelerator: Accelerator,
- state_to_save,
- save_directory,
- max_shard_size="5GB",
- safe_serialization=True,
-):
- old_get_state = accelerator.get_state_dict
- accelerator.get_state_dict = _copy_no_lora_dict
-
- def skip_precheck_loops():
- return []
-
- # The save model does a loop over modules and params in order to determine how to get state dict. Since we already have the state dict directly, we want to bypass those checks.
- state_to_save.modules = skip_precheck_loops
- state_to_save.parameters = skip_precheck_loops
-
- accelerator.save_model(
- state_to_save,
- save_directory=save_directory,
- max_shard_size=max_shard_size,
- safe_serialization=safe_serialization,
- )
-
- accelerator.get_state_dict = old_get_state
-
-
-def save_hf_format_accelerate(
- args,
- model,
- tokenizer,
- accelerator: Accelerator,
- samples_seen,
- is_lora=False,
-):
- # Build the subdirectory name
- subdir = (
- "last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
- )
-
- log_rank_0(
- f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
- to_print=True,
- )
- start = time.time()
-
- if args.model_type in ("gpt_megatron", "gpt_dolomite"):
- convert_dolomite = False
- else:
- convert_dolomite = True
-
- # Build the final output directory path
- final_output_dir = Path(args.output_dir) / "hf_format" / subdir
-
- if args.use_dolomite and convert_dolomite:
- tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with
- output_dir = Path(tmpdir.name)
- else:
- output_dir = final_output_dir
-
- CONFIG_NAME = "config.json"
- output_config_file = output_dir / CONFIG_NAME
-
- # XXX(osilkin): LoRA + FSDP requires a different saving path than the others
- # so we set this variable and use it to avoid those paths further down.
- is_fsdp_lora = is_lora and accelerator.distributed_type == DistributedType.FSDP
- if is_fsdp_lora:
- save_fsdp_lora_model(
- args=args,
- model=model,
- tokenizer=tokenizer,
- accelerator=accelerator,
- output_dir=output_dir,
- )
-
- get_state_dict_unpatched = accelerator.get_state_dict
-
- def _get_state_dict_patched(model, unwrap=False):
- return get_state_dict_unpatched(model, unwrap=unwrap)
-
- accelerator.get_state_dict = _get_state_dict_patched
-
- if not is_fsdp_lora and accelerator.is_main_process:
- if is_lora:
- model.module.merge_adapter()
- model_state = model.module.state_dict()
-
- output_dir.mkdir(parents=True, exist_ok=True)
- if not model.module.config.architectures and convert_dolomite:
- arch_added = False
- if args.model_type == "llama":
- model.module.config.architectures = ["LlamaForCausalLM"]
- arch_added = True
- elif args.model_type == "granite":
- model.module.config.architectures = ["GraniteForCausalLM"]
- arch_added = True
- if arch_added:
- warnings.warn(
- f"Adding architectures to ckpt: {model.module.config.architectures}",
- )
- else:
- warnings.warn(
- f"Converting from dolomite, but no architecture field added to config.json",
- )
- model.module.config.to_json_file(output_config_file)
- tokenizer.save_pretrained(output_dir)
-
- if is_lora:
- save_dict_accelerate(
- accelerator,
- model_state,
- save_directory=output_dir,
- max_shard_size="5GB",
- safe_serialization=True,
- )
- model.module.unmerge_adapter()
-
- if not is_lora:
- accelerator.save_model(
- model,
- save_directory=output_dir,
- max_shard_size="5GB",
- safe_serialization=True,
- )
-
- if args.use_dolomite and convert_dolomite and accelerator.is_main_process:
- # export doesnt like the directory to exist
- if final_output_dir.exists():
- shutil.rmtree(final_output_dir)
- export_to_huggingface(
- pretrained_model_name_or_path=tmpdir.name,
- save_path=final_output_dir,
- model_type=args.model_type,
- )
- tmpdir.cleanup()
-
- log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True)
- log_rank_0(f"saving took {time.time() - start} seconds")
- dist.barrier()
-
- accelerator.get_state_dict = get_state_dict_unpatched
-
-
-# this is native deepspeed saving with optimizer, scheduler
-def save_model_ds_native(
- args,
- model,
- tokenizer, # pylint: disable=unused-argument
- samples_seen,
-):
- # to get a statedict from a zero checkpoint, all you need to do is
- # - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
- # - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt')
- # - sum([math.prod(x.shape) for x in sd.values()]) # check the size (should be correct)
-
- log_rank_0(
- f"\033[93mSaving model+optimizer+scheduler in format at samples_seen: {samples_seen}\033[0m",
- to_print=True,
- )
- start = time.time()
- # used to save huggingface format, so we can use it for hf.from_pretrained
- output_dir = Path(args.output_dir) / "ds_native"
- tag = f"samples_{samples_seen}"
- use_lora = args.lora_r > 0
-
- # NOTE: this is a distributed save
- # if its lora, we only save the adapters
- # - so we exclude frozen if use_lora==True
- model.save_checkpoint(
- output_dir,
- exclude_frozen_parameters=use_lora,
- tag=tag, # this will create the subdirectory with the correct name
- )
-
- # for now we are not saving tokenizer, config, eg..
- # so it is not totally "HF compatible"
-
- log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
- log_rank_0(f"saving took {time.time() - start} seconds")
-
-
def set_random_seed(seed):
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
-
-
-def save_checkpoint(
- args,
- accelerator: Accelerator,
- model,
- tokenizer,
- samples_seen,
- is_lora: bool,
- epoch: int = None,
- hf_format: bool = True,
- full_state: bool = False,
-) -> None:
- if hf_format:
- save_hf_format_accelerate(
- args=args,
- model=model,
- accelerator=accelerator,
- tokenizer=tokenizer,
- samples_seen=samples_seen,
- is_lora=is_lora,
- )
-
- if full_state:
- save_full_state(
- args=args,
- accelerator=accelerator,
- is_lora=is_lora,
- epoch=epoch,
- samples_seen=samples_seen,
- )
-
-
-def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
- """
- Saves model, optimizer, and lr_scheduler state.
- TODO: save model config - decided not to do this.
- TODO: save tokenizer - decided not to do this.
- TODO: handle LoRA
- TODO: handle granite
- """
- if is_lora:
- raise NotImplementedError("Can't save full state for LoRA at the moment.")
-
- # if args.is_granite:
- # raise NotImplementedError("Can't save full state for Granite models yet.")
-
- output_dir = Path(args.output_dir) / "full_state" / f"epoch_{epoch}"
- log_rank_0(f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True)
-
- # patch FSDP state dict method so it works correctly.
- def _get_state_dict_patched(model, unwrap=False):
- return get_state_dict_unpatched(model, unwrap=unwrap)
-
- if args.distributed_training_framework == "fsdp":
- get_state_dict_unpatched = accelerator.get_state_dict
- accelerator.get_state_dict = _get_state_dict_patched
-
- accelerator.save_state(
- output_dir=output_dir,
- # max_shard_size="5GB",
- # safe_serialization=True,
- )
-
- # save metadata file for current training status
- if accelerator.is_main_process:
- # TODO: should we set the global_step here rather than calculating global_step
- # based on samples_seen?
- metadata = {"current_epoch": epoch, "samples_seen": samples_seen}
- torch.save(metadata, output_dir / "training_metadata.json")
- log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True)
-
- log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True)
-
- # cleanup
- if args.distributed_training_framework == "fsdp":
- accelerator.get_state_dict = get_state_dict_unpatched
-
-
-def load_latest_full_state(args, accelerator) -> None:
- """
- Loads accelerator state from most recently saved checkpoint
- in `output_dir/full_state`.
- """
- output_dir = Path(args.output_dir) / "full_state"
-
- if not output_dir.is_dir():
- return
-
- # picks checkpoint with the largest number of samples by splitting the "samples_NNNN" string on _
- # and comparing the number at the end of the string
- checkpoint_list = sorted(
- list(output_dir.iterdir()),
- reverse=True,
- key=lambda x: int(str(x).rsplit("_", maxsplit=1)[-1]),
- )
-
- if len(checkpoint_list) == 0:
- log_rank_0(
- f"\033[93mNo checkpoints to load from: {output_dir}\033[0m", to_print=True
- )
- return
-
- latest = checkpoint_list[0]
-
- log_rank_0(f"\033[93mLoading state from: {latest}\033[0m", to_print=True)
- accelerator.load_state(latest)
-
- training_metadata = torch.load(latest / "training_metadata.json")
- log_rank_0(
- f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True
- )
-
- # previous epoch is basis for current epoch.
- args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1
- args.__dict__["samples_seen"] = training_metadata["samples_seen"]
-
-
-def get_projection_layer_names(model: PreTrainedModel) -> List[str]:
- """
- Given a pretrained model, returns all of the projection layers (matching '_proj')
- """
- proj_layers = set(
- name.split(".")[-1]
- for name, _ in model.named_modules()
- if name.endswith("_proj")
- )
- return list(proj_layers)
diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py
new file mode 100644
index 00000000..8fb9b7be
--- /dev/null
+++ b/tests/unit/test_model.py
@@ -0,0 +1,681 @@
+# Standard
+from pathlib import Path
+from unittest.mock import MagicMock, PropertyMock, patch
+import os
+import sys
+
+# Third Party
+from torch.distributed.fsdp import ShardingStrategy
+from torch.optim import AdamW
+from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+import pytest
+import torch
+import torch.nn as nn
+
+# First Party
+from instructlab.training.config import DistributedBackend, ModelTypes, Optimizers
+from instructlab.training.model import Accelerator, Checkpointer, Model, setup_optimizer
+
+
+# Define base model class at module level
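+# TestAccelerator patches `get_module_class_from_name` to return this class, giving
+# FSDP auto-wrapping a concrete module class to target in the unit tests.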
+class MockBaseModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embed_tokens = MagicMock()
+ # Add layers to base model
+ layer0 = MagicMock()
+ layer1 = MagicMock()
+ self.layers = nn.ModuleList([layer0, layer1])
+
+
+# Test fixtures
+@pytest.fixture
+def mock_tokenizer():
+ tokenizer = MagicMock(spec="PreTrainedTokenizer")
+ tokenizer.__len__.return_value = 1000
+ tokenizer.pad_token_id = 0
+ tokenizer.bos_token_id = 1
+ tokenizer.eos_token_id = 2
+ return tokenizer
+
+
+@pytest.fixture
+def mock_config():
+ config = MagicMock(spec=PretrainedConfig)
+ config.vocab_size = 1000
+ config.pad_token_id = 0
+ config.bos_token_id = 1
+ config.eos_token_id = 2
+ config.architectures = ["LlamaForCausalLM"]
+ return config
+
+
+@pytest.fixture
+def mock_model(mock_config, mock_tokenizer):
+ # Create a mock model that matches the expected structure
+ class MockBaseModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.embed_tokens = nn.Embedding(1000, 768)
+ self.forward = MagicMock()
+ # Add the projection layers directly to the base model
+ self.q_proj = nn.Linear(768, 768)
+ self.v_proj = nn.Linear(768, 768)
+
+ def prepare_inputs_for_generation(*args, **kwargs):
+ return {"input_ids": torch.tensor([[1, 2, 3]])}
+
+ self.prepare_inputs_for_generation = prepare_inputs_for_generation
+
+ class MockLlamaForCausalLM(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.config = mock_config
+            self.lora_config = Model.create_lora_config(
+                lora_target_modules=["q_proj", "v_proj"],
+                lora_alpha=16,
+                lora_dropout=0.1,
+                lora_r=8,
+            )
+ self._no_split_modules = ["transformer"]
+ self.__class__.__name__ = "LlamaForCausalLM"
+ self.gradient_checkpointing_enable = MagicMock()
+ self.gradient_checkpointing = False
+ self.parameters = MagicMock(return_value=[])
+ self.base_model_args = {}
+ self.model_type = ModelTypes.CAUSALLM
+ self.model = self
+ self.module = self
+ self.update_model = MagicMock()
+ self.tokenizer = mock_tokenizer
+ self.base_model = MockBaseModel()
+
+ def named_modules_mock(*args, **kwargs):
+ return [
+ ("base_model.q_proj", self.base_model.q_proj),
+ ("base_model.v_proj", self.base_model.v_proj),
+ ]
+
+ self.named_modules = named_modules_mock
+
+ def get_submodule_mock(name):
+ if name == "base_model":
+ return self.base_model
+ elif name == "base_model.q_proj":
+ return self.base_model.q_proj
+ elif name == "base_model.v_proj":
+ return self.base_model.v_proj
+ return None
+
+ self.get_submodule = get_submodule_mock
+
+ def get_input_embeddings():
+ return self.base_model.embed_tokens
+
+ self.get_input_embeddings = get_input_embeddings
+
+ # Override _apply to prevent recursion
+ def _apply_mock(fn):
+ return self
+
+ self._apply = _apply_mock
+
+ # Add prepare_inputs_for_generation to match base model
+ def prepare_inputs_for_generation(*args, **kwargs):
+ return self.base_model.prepare_inputs_for_generation(*args, **kwargs)
+
+ self.prepare_inputs_for_generation = prepare_inputs_for_generation
+
+ model = MockLlamaForCausalLM()
+ return model
+
+
+@pytest.fixture
+def mock_peft_model(mock_model):
+ # Create a mock PEFT model that wraps the base model
+ class MockPEFTModel(nn.Module):
+ def __init__(self, base_model):
+ super().__init__()
+ self.base_model = base_model
+            self.lora_config = Model.create_lora_config(
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.1,
+ lora_r=8,
+ )
+ self.lora_config.lora_alpha = 16
+ self.lora_config.lora_dropout = 0.1
+ self.lora_config.target_modules = ["q_proj", "v_proj"]
+ self.lora_config.task_type = "CAUSAL_LM"
+
+ def __getattr__(self, name):
+ if name == "base_model":
+ return self.base_model
+ return getattr(self.base_model, name)
+
+ # Override _apply to prevent recursion
+ def _apply_mock(fn):
+ return self
+
+ self._apply = _apply_mock
+
+ # Add prepare_inputs_for_generation to match base model
+ def prepare_inputs_for_generation(*args, **kwargs):
+ return self.base_model.prepare_inputs_for_generation(*args, **kwargs)
+
+ self.prepare_inputs_for_generation = prepare_inputs_for_generation
+
+ return MockPEFTModel(mock_model)
+
+
+@pytest.fixture
+def mock_dataloader():
+ dataloader = MagicMock()
+ dataloader.dataset = MagicMock()
+ dataloader.dataset.__len__.return_value = 1000
+ return dataloader
+
+
+@pytest.fixture
+def mock_distributed():
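+    # Pretend to be rank 0 of a two-process group so code that queries
+    # torch.distributed works without spinning up a real process group.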
+ with (
+ patch("torch.distributed.is_initialized", return_value=True),
+ patch("torch.distributed.get_rank", return_value=0),
+ patch("torch.distributed.get_world_size", return_value=2),
+ patch("torch.distributed.barrier"),
+ ):
+ yield
+
+
+@pytest.fixture
+def mock_model_path(tmp_path):
+ return str(tmp_path / "model")
+
+
+@pytest.fixture
+def mock_output_dir(tmp_path):
+ return str(tmp_path / "output")
+
+
+# Model class tests
+class TestModel:
+ def test_model_initialization(self, mock_tokenizer, mock_model, mock_peft_model):
+ with (
+ patch(
+ "transformers.AutoModelForCausalLM.from_pretrained",
+ return_value=mock_model,
+ ) as mock_from_pretrained,
+ patch("peft.LoraModel", return_value=mock_peft_model),
+ ):
+ model = Model(
+ model_path="test_path",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=0.1,
+ tokenizer=mock_tokenizer,
+ flash_enabled=True,
+ lora_config=Model.create_lora_config(
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.1,
+ lora_r=8,
+ ),
+ )
+ mock_from_pretrained.assert_called_once()
+ assert model.model_type == ModelTypes.CAUSALLM
+ assert model.lora_config.r == 8
+ assert model.lora_config.lora_alpha == 16
+ assert model.lora_config.lora_dropout == 0.1
+ assert sorted(model.lora_config.target_modules) == sorted(
+ ["q_proj", "v_proj"]
+ )
+ mock_model.gradient_checkpointing_enable.assert_called_once()
+
+ def test_get_projection_layer_names(self, mock_model, mock_tokenizer):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model
+ ):
+ model = Model(
+ model_path="test_path",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=0.1,
+ tokenizer=mock_tokenizer,
+ )
+ proj_layers = model.get_projection_layer_names()
+ assert set(proj_layers) == {"q_proj", "v_proj"}
+
+ def test_prepare_peft_model_fsdp(self, mock_model, mock_tokenizer):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model
+ ):
+ model = Model(
+ model_path="test_path",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.FSDP.value,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=0.1,
+ tokenizer=mock_tokenizer,
+ lora_config=Model.create_lora_config(
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.1,
+ lora_r=8,
+ ),
+ )
+ with patch("peft.LoraModel") as mock_lora_model:
+ model.prepare_peft_model()
+ mock_lora_model.assert_called_once()
+
+ def test_prepare_peft_model_deepspeed(self, mock_model, mock_tokenizer):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model
+ ):
+ # Mock the PeftModel check
+ mock_model.is_loaded_in_8bit = False
+ mock_model.is_loaded_in_4bit = False
+ mock_model.__class__.__name__ = "LlamaForCausalLM" # Not a PeftModel
+
+ # Create a mock PeftModel class
+ class MockPeftModel(nn.Module):
+ pass
+
+ model = Model(
+ model_path="test_path",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.DEEPSPEED.value,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=0.1,
+ tokenizer=mock_tokenizer,
+ lora_config=Model.create_lora_config(
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.1,
+ lora_r=8,
+ ),
+ )
+ with (
+ patch("peft.get_peft_model") as mock_get_peft_model,
+ patch("peft.PeftModel", MockPeftModel),
+ ):
+ model.prepare_peft_model()
+ mock_get_peft_model.assert_called_once()
+
+ def test_create_lora_config(self, mock_tokenizer, mock_model):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model
+ ):
+ lora_config = Model.create_lora_config(
+ lora_target_modules=["q_proj", "v_proj"],
+ lora_alpha=16,
+ lora_dropout=0.1,
+ lora_r=8,
+ )
+ assert lora_config.r == 8
+ assert lora_config.lora_alpha == 16
+ assert lora_config.lora_dropout == 0.1
+ assert sorted(lora_config.target_modules) == sorted(["q_proj", "v_proj"])
+
+ def test_reconcile_tokenizer(self, mock_tokenizer, mock_model):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model
+ ):
+ model = Model(
+ model_path="test_path",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=0.1,
+ tokenizer=mock_tokenizer,
+ )
+ model.reconcile_tokenizer()
+ assert model.model.config.pad_token_id == mock_tokenizer.pad_token_id
+ assert model.model.config.bos_token_id == mock_tokenizer.bos_token_id
+ assert model.model.config.eos_token_id == mock_tokenizer.eos_token_id
+
+ def test_supports_flash_attention(self):
+ with (
+ patch(
+ "torch.cuda.get_device_capability", return_value=(8, 0)
+ ) as mock_capability,
+ patch(
+ "torch.cuda.get_device_properties",
+ return_value=MagicMock(gcnArchName="gfx90a:0"),
+ ) as mock_props,
+ ):
+ assert Model.supports_flash_attention() is True
+ mock_capability.assert_called_once()
+ mock_props.assert_called_once()
+
+ def test_check_flash_attn_enabled(self):
+ # Test when flash attention is enabled and supported
+ with patch.object(Model, "supports_flash_attention", return_value=True):
+ assert Model.check_flash_attn_enabled(False, False) is True
+
+ # Test when flash attention is enabled but not supported
+ with patch.object(Model, "supports_flash_attention", return_value=False):
+ with pytest.raises(
+ RuntimeError,
+ match="Trying to use Flash Attention on unsupported hardware",
+ ):
+ Model.check_flash_attn_enabled(False, False)
+
+ # Test when flash attention is disabled but dolomite is enabled
+ with pytest.raises(
+ RuntimeError,
+ match="Trying to use dolomite padding-free transformer without flash attention",
+ ):
+ Model.check_flash_attn_enabled(True, True)
+
+ # Test when flash attention is disabled and dolomite is disabled
+ assert Model.check_flash_attn_enabled(True, False) is False
+
+ def test_setup_optimizer(self, mock_model, mock_tokenizer):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained"
+ ) as mock_from_pretrained:
+ mock_model = MagicMock()
+ mock_model.parameters.return_value = [torch.nn.Parameter(torch.randn(2, 2))]
+ mock_model.config = MagicMock()
+ mock_model.config.vocab_size = 1000
+ mock_model.__class__.__name__ = "LlamaForCausalLM" # Set correct class name
+ mock_from_pretrained.return_value = mock_model
+ mock_tokenizer.__len__.return_value = 1000
+
+ model = Model(
+ model_path="instructlab/granite-7b-lab",
+ output_dir="test_output",
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=None,
+ tokenizer=mock_tokenizer,
+ )
+ model.model = mock_model
+
+ # Test FSDP with AdamW
+ optimizer = setup_optimizer(
+ model=model,
+ cpu_offload=False,
+ name=None,
+ learning_rate=1e-4,
+ )
+ assert isinstance(optimizer, torch.optim.AdamW)
+
+ # Test DeepSpeed with FusedAdam
+ model.distributed_framework = DistributedBackend.DEEPSPEED
+ with patch("instructlab.training.model.FusedAdam") as mock_fused_adam:
+ optimizer = setup_optimizer(
+ model=model,
+ cpu_offload=False,
+ name=None,
+ learning_rate=1e-4,
+ )
+ mock_fused_adam.assert_called_once()
+
+ # Test DeepSpeed with CPUAdam
+ with patch("instructlab.training.model.DeepSpeedCPUAdam") as mock_cpu_adam:
+ optimizer = setup_optimizer(
+ model=model,
+ cpu_offload=True,
+ name=None,
+ learning_rate=1e-4,
+ )
+ mock_cpu_adam.assert_called_once()
+
+ # Test explicit optimizer selection
+ with patch("instructlab.training.model.AdamW") as mock_adamw:
+ optimizer = setup_optimizer(
+ model=model,
+ cpu_offload=False,
+ name=Optimizers.ADAMW,
+ learning_rate=1e-4,
+ )
+ mock_adamw.assert_called_once()
+
+ def test_model_lora_initialization(
+ self, mock_model_path, mock_output_dir, mock_tokenizer
+ ):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained"
+ ) as mock_from_pretrained:
+ # Create a simpler mock model
+ mock_model = MagicMock()
+ mock_model._no_split_modules = ["transformer"]
+ mock_model.config = MagicMock()
+ mock_model.config.vocab_size = 1000
+ mock_model.config.pad_token_id = 0
+ mock_model.config.bos_token_id = 1
+ mock_model.config.eos_token_id = 2
+ mock_model.config.architectures = ["LlamaForCausalLM"]
+ mock_model.gradient_checkpointing = False
+ mock_model.__class__.__name__ = "LlamaForCausalLM"
+ mock_model.parameters.return_value = [torch.nn.Parameter(torch.randn(2, 2))]
+ mock_model.base_model_args = {}
+ mock_model.model_type = ModelTypes.CAUSALLM
+ mock_model.model = mock_model
+ mock_model.module = mock_model
+ mock_model.tokenizer = mock_tokenizer
+ mock_model.base_model = MagicMock()
+ mock_model.base_model.embed_tokens = MagicMock()
+ mock_model.get_input_embeddings = MagicMock(
+ return_value=mock_model.base_model.embed_tokens
+ )
+ mock_model.base_model.q_proj = nn.Linear(768, 768)
+ mock_model.base_model.v_proj = nn.Linear(768, 768)
+
+ # Add named_modules to support LoRA
+ def named_modules_mock(*args, **kwargs):
+ return [
+ ("base_model.v_proj", mock_model.base_model.v_proj),
+ ("base_model.q_proj", mock_model.base_model.q_proj),
+ ]
+
+ mock_model.named_modules = named_modules_mock
+ mock_from_pretrained.return_value = mock_model
+
+ with (
+ patch("peft.LoraModel") as mock_peft_model,
+ patch("peft.get_peft_model") as mock_get_peft_model,
+ patch("peft.prepare_model_for_kbit_training") as mock_prepare_model,
+ ):
+ # Mock the PEFT model initialization
+ mock_peft_model.return_value = mock_model
+ mock_get_peft_model.return_value = mock_model
+ mock_prepare_model.return_value = mock_model
+
+ model = Model(
+ model_path=mock_model_path,
+ output_dir=mock_output_dir,
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=None,
+ tokenizer=mock_tokenizer,
+ lora_config=Model.create_lora_config(
+ lora_target_modules=["v_proj", "q_proj"],
+ lora_alpha=32,
+ lora_dropout=0.1,
+ lora_r=8,
+ ),
+ )
+
+ assert model.lora_config is not None
+ assert model.lora_config.r == 8
+ assert model.lora_config.lora_alpha == 32
+ assert model.lora_config.lora_dropout == 0.1
+ assert set(model.lora_config.target_modules) == {"v_proj", "q_proj"}
+
+ def test_model_reconcile_tokenizer(
+ self, mock_model_path, mock_output_dir, mock_tokenizer
+ ):
+ with patch(
+ "transformers.AutoModelForCausalLM.from_pretrained"
+ ) as mock_from_pretrained:
+ mock_model = MagicMock()
+ mock_model.config.vocab_size = 1000
+ mock_model.config.pad_token_id = None
+ mock_model.config.bos_token_id = None
+ mock_model.config.eos_token_id = None
+ mock_model.__class__.__name__ = "LlamaForCausalLM" # Set a valid class name
+ mock_model.gradient_checkpointing = False
+ mock_from_pretrained.return_value = mock_model
+
+ model = Model(
+ model_path=mock_model_path,
+ output_dir=mock_output_dir,
+ distributed_framework=DistributedBackend.FSDP,
+ model_type=ModelTypes.CAUSALLM,
+ noise_alpha=None,
+ tokenizer=mock_tokenizer,
+ )
+
+ model.reconcile_tokenizer()
+
+ assert model.model.config.pad_token_id == mock_tokenizer.pad_token_id
+ assert model.model.config.bos_token_id == mock_tokenizer.bos_token_id
+ assert model.model.config.eos_token_id == mock_tokenizer.eos_token_id
+
+
+# Accelerator class tests
+class TestAccelerator:
+ def test_accelerator_initialization(self, mock_model, mock_dataloader):
+ mock_model.lora_config = None
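+        # Fake an Ampere / gfx90a-class device with bf16 support so the hardware
+        # capability checks inside Accelerator pass without a real GPU.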
+ with (
+ patch(
+ "instructlab.training.utils.get_module_class_from_name",
+ return_value=MockBaseModel,
+ ),
+ patch("torch.cuda.is_available", return_value=True),
+ patch("torch.cuda.get_device_capability", return_value=(8, 0)),
+ patch(
+ "torch.cuda.get_device_properties",
+ return_value=MagicMock(gcnArchName="gfx90a:0"),
+ ),
+ patch("accelerate.utils.is_bf16_available", return_value=True),
+ patch("torch.cuda.is_bf16_supported", return_value=True),
+ patch("torch.cuda.current_device", return_value=0),
+ patch("torch.cuda._initialized", True),
+ patch("torch.cuda.is_initialized", return_value=True),
+ ):
+ accelerator = Accelerator(
+ model=mock_model,
+ samples_per_gpu=4,
+ grad_accum=2,
+ train_loader=mock_dataloader,
+ save_samples=1000,
+ distributed_framework=DistributedBackend.FSDP,
+ fsdp_sharding_strategy="FULL_SHARD",
+ fsdp_cpu_offload_params=True,
+ )
+ assert accelerator.samples_per_gpu == 4
+ assert accelerator.grad_accum == 2
+ assert accelerator.save_samples == 1000
+ mock_model.update_model.assert_called_once()
+
+ def test_setup_lr_scheduler(self, mock_model, mock_dataloader):
+ mock_model.lora_config = None
+ with (
+ patch(
+ "instructlab.training.utils.get_module_class_from_name",
+ return_value=MockBaseModel,
+ ),
+ patch("torch.cuda.is_available", return_value=True),
+ patch("torch.cuda.get_device_capability", return_value=(8, 0)),
+ patch(
+ "torch.cuda.get_device_properties",
+ return_value=MagicMock(gcnArchName="gfx90a:0"),
+ ),
+ patch("accelerate.utils.is_bf16_available", return_value=True),
+ patch("torch.cuda.is_bf16_supported", return_value=True),
+ patch("torch.cuda.current_device", return_value=0),
+ patch("torch.cuda._initialized", True),
+ patch("torch.cuda.is_initialized", return_value=True),
+ ):
+ accelerator = Accelerator(
+ model=mock_model,
+ samples_per_gpu=4,
+ grad_accum=2,
+ train_loader=mock_dataloader,
+ save_samples=1000,
+ distributed_framework=DistributedBackend.FSDP,
+ fsdp_sharding_strategy="FULL_SHARD",
+ )
+
+ # Create a real AdamW optimizer
+ params = [torch.nn.Parameter(torch.randn(2, 2))]
+ optimizer = AdamW(params, lr=0.001)
+ optimizer.param_groups = [{"lr": 0.001}]
+ optimizer.state_dict = MagicMock(
+ return_value={"param_groups": [{"lr": 0.001}]}
+ )
+ optimizer.get_lr = MagicMock(return_value=0.001)
+
+ accelerator.setup_lr_scheduler(
+ optimizer=optimizer,
+ lr_scheduler="cosine",
+ num_epochs=10,
+ num_warmup_steps=100,
+ )
+ assert hasattr(accelerator, "lr_scheduler")
+
+
+# Checkpointer class tests
+class TestCheckpointer:
+ def test_checkpointer_initialization(self, mock_model):
+ optimizer = MagicMock()
+ accelerator = MagicMock()
+
+ checkpointer = Checkpointer(
+ model=mock_model,
+ optimizer=optimizer,
+ accelerator=accelerator,
+ strategy="full_state",
+ )
+ assert checkpointer.strategy == "full_state"
+
+ def test_save_full_state(self, mock_model, tmp_path, mock_distributed):
+ optimizer = MagicMock()
+ accelerator = MagicMock()
+ accelerator.is_main_process = True
+ accelerator.save_state = MagicMock()
+ mock_model.lora_config = None
+ checkpointer = Checkpointer(
+ model=mock_model,
+ optimizer=optimizer,
+ accelerator=accelerator,
+ strategy="full_state",
+ )
+
+ output_dir = tmp_path / "test_output"
+ os.makedirs(output_dir / "full_state" / "epoch_1", exist_ok=True)
+ checkpointer.save_full_state(output_dir=output_dir, epoch=1, samples_seen=1000)
+ accelerator.save_state.assert_called_once()
+
+ def test_save_hf_format_accelerate(self, mock_model, tmp_path, mock_distributed):
+ optimizer = MagicMock()
+ accelerator = MagicMock()
+ accelerator.is_main_process = True
+ accelerator.save_model = MagicMock()
+ mock_model.lora_config = None
+ mock_model.model_type = ModelTypes.CAUSALLM
+ mock_model.module = mock_model # Ensure module is set
+ mock_model.tokenizer = MagicMock() # Ensure tokenizer is set
+
+ checkpointer = Checkpointer(
+ model=mock_model,
+ optimizer=optimizer,
+ accelerator=accelerator,
+ strategy="hf_format",
+ )
+
+ output_dir = tmp_path / "test_output"
+ os.makedirs(output_dir / "hf_format" / "samples_1000", exist_ok=True)
+ checkpointer.save_hf_format_accelerate(
+ output_dir=output_dir, epoch=1, samples_seen=1000
+ )
+ accelerator.save_model.assert_called_once()
diff --git a/tests/unit/test_train.py b/tests/unit/test_train.py
new file mode 100644
index 00000000..2156a5f1
--- /dev/null
+++ b/tests/unit/test_train.py
@@ -0,0 +1,409 @@
+# Standard
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+import os
+
+# Third Party
+from torch.optim import Optimizer
+import pytest
+import torch
+
+# First Party
+from instructlab.training.async_logger import AsyncStructuredLogger
+from instructlab.training.model import Model
+from instructlab.training.train import Metrics, train, train_epoch
+
+
+# Test fixtures
+@pytest.fixture
+def mock_model():
+ model = MagicMock(spec=Model)
+ model.train.return_value = None
+ model.get_global_grad_norm = MagicMock(return_value=1.0)
+ model.parameters.return_value = [torch.tensor([1.0])]
+ return model
+
+
+@pytest.fixture
+def mock_optimizer():
+ optimizer = MagicMock(spec=Optimizer)
+ optimizer.step.return_value = None
+ optimizer.zero_grad.return_value = None
+ return optimizer
+
+
+@pytest.fixture
+def mock_accelerator():
+ accelerator = MagicMock()
+ accelerator.device = "cpu"
+ accelerator.samples_per_gpu = 4
+ accelerator.grad_accum = 2
+ accelerator.save_samples = 1000
+ accelerator.train_loader = MagicMock()
+ accelerator.train_loader.dataset = MagicMock()
+ accelerator.train_loader.dataset.__len__.return_value = 1000
+ accelerator.train_loader.batch_sampler = MagicMock()
+ accelerator.train_loader.sampler = MagicMock()
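+    # Stand-in for the cross-rank sums of (num_loss_counted_tokens, num_samples, loss);
+    # 1.5 / 3.0 yields the 0.5 total_loss asserted in the tests below.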
+ accelerator.reduce = MagicMock(return_value=torch.tensor([3.0, 1.0, 1.5]))
+ accelerator.backward = MagicMock()
+ accelerator.clip_grad_norm_ = MagicMock(return_value=1.0)
+ accelerator.lr_scheduler = MagicMock()
+ accelerator.lr_scheduler.get_last_lr.return_value = [0.001]
+ return accelerator
+
+
+@pytest.fixture
+def mock_checkpointer():
+ checkpointer = MagicMock()
+ checkpointer.checkpoint.return_value = None
+ return checkpointer
+
+
+@pytest.fixture
+def mock_logger():
+ return MagicMock(spec=AsyncStructuredLogger)
+
+
+@pytest.fixture
+def mock_environment():
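+    # Provide the rank/world-size environment variables the training code expects.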
+ with patch.dict(os.environ, {"LOCAL_RANK": "0", "WORLD_SIZE": "2"}):
+ yield
+
+
+@pytest.fixture
+def mock_distributed():
+ with (
+ patch("torch.distributed.is_initialized", return_value=True),
+ patch("torch.distributed.get_rank", return_value=0),
+ patch("torch.distributed.get_world_size", return_value=2),
+ ):
+ yield
+
+
+@pytest.fixture
+def mock_cuda():
+ with (
+ patch("torch.cuda.memory_allocated", return_value=0),
+ patch("torch.cuda.memory_stats", return_value={"num_alloc_retries": 0}),
+ ):
+ yield
+
+
+class TestTrainEpoch:
+ def test_train_epoch_basic(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ mock_environment,
+ mock_distributed,
+ mock_cuda,
+ ):
+ # Setup mock batch
+ mock_batch = {
+ "input_ids": torch.tensor([[1, 2, 3]]),
+ "attention_mask": torch.tensor([[1, 1, 1]]),
+ "labels": torch.tensor([[1, 2, 3]]),
+ "num_loss_counted_tokens": 3,
+ "num_samples": 1,
+ }
+ mock_accelerator.train_loader.__iter__.return_value = [mock_batch]
+
+ # Setup mock model output
+ mock_output = MagicMock()
+ mock_output.loss = torch.tensor(0.5)
+ mock_model.return_value = mock_output
+
+ # Run train_epoch
+ _, metrics = train_epoch(
+ epoch_number=0,
+ samples_seen=0,
+ local_rank=0,
+ global_step=2,
+ last_step=0,
+ world_size=2,
+ batch_size=4,
+ samples_per_gpu=4,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ sampler="distributed",
+ checkpointer=mock_checkpointer,
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ use_dolomite=True,
+ metric_logger=mock_logger,
+ )
+
+ # Verify metrics
+ assert metrics is not None
+ assert metrics[-1].samples_seen == 1
+ assert metrics[-1].total_loss == 0.5
+ assert metrics[-1].batch_size == 1
+ assert metrics[-1].num_loss_counted_tokens == 3
+ assert metrics[-1].global_grad_norm == 1.0
+ assert metrics[-1].total_samples == 1000
+ assert metrics[-1].overall_throughput > 0
+ assert metrics[-1].current_lr == 0.001
+
+ # Verify calls
+ mock_model.assert_called_once()
+ mock_optimizer.step.assert_called_once()
+ mock_optimizer.zero_grad.assert_called_once()
+ mock_accelerator.backward.assert_called_once()
+ mock_accelerator.clip_grad_norm_.assert_called_once()
+ mock_accelerator.lr_scheduler.step.assert_called_once()
+
+ def test_train_epoch_with_multipack_sampler(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ mock_environment,
+ mock_distributed,
+ mock_cuda,
+ ):
+ # Setup mock batch
+ mock_batch = {
+ "input_ids": torch.tensor([[1, 2, 3]]),
+ "attention_mask": torch.tensor([[1, 1, 1]]),
+ "labels": torch.tensor([[1, 2, 3]]),
+ "num_loss_counted_tokens": 3,
+ "num_samples": 1,
+ }
+ mock_accelerator.train_loader.__iter__.return_value = [mock_batch]
+
+ # Setup mock model output
+ mock_output = MagicMock()
+ mock_output.loss = torch.tensor(0.5)
+ mock_model.return_value = mock_output
+
+ # Run train_epoch with multipack sampler
+ _, metrics = train_epoch(
+ epoch_number=0,
+ samples_seen=0,
+ local_rank=0,
+ global_step=2,
+ last_step=0,
+ world_size=2,
+ batch_size=4,
+ samples_per_gpu=4,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ sampler="multipack",
+ checkpointer=mock_checkpointer,
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ use_dolomite=True,
+ metric_logger=mock_logger,
+ )
+
+ # Verify metrics
+ assert metrics is not None
+ assert metrics[-1].samples_seen == 1
+ assert metrics[-1].total_loss == 0.5
+ assert metrics[-1].batch_size == 1
+ assert metrics[-1].num_loss_counted_tokens == 3
+ assert metrics[-1].global_grad_norm == 1.0
+ assert metrics[-1].total_samples == 1000
+ assert metrics[-1].overall_throughput > 0
+ assert metrics[-1].current_lr == 0.001
+
+ # Verify calls
+ mock_model.assert_called_once()
+ mock_optimizer.step.assert_called_once()
+ mock_optimizer.zero_grad.assert_called_once()
+ mock_accelerator.backward.assert_called_once()
+ mock_accelerator.clip_grad_norm_.assert_called_once()
+ mock_accelerator.lr_scheduler.step.assert_called_once()
+ mock_accelerator.train_loader.batch_sampler.set_epoch.assert_called_once_with(0)
+
+ def test_train_epoch_invalid_sampler(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ ):
+ with pytest.raises(AttributeError) as exc_info:
+ train_epoch(
+ epoch_number=0,
+ samples_seen=0,
+ local_rank=0,
+ global_step=1,
+ last_step=0,
+ world_size=2,
+ batch_size=8,
+ samples_per_gpu=4,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ sampler="invalid",
+ checkpointer=mock_checkpointer,
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ use_dolomite=False,
+ metric_logger=mock_logger,
+ )
+ assert "Sampler invalid is invalid" in str(exc_info.value)
+
+
+class TestTrain:
+ def test_train_basic(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ mock_environment,
+ mock_distributed,
+ mock_cuda,
+ ):
+ # Setup mock metrics
+ mock_metrics = Metrics(
+ samples_seen=4,
+ total_loss=0.5,
+ batch_size=8,
+ num_loss_counted_tokens=100,
+ global_grad_norm=1.0,
+ total_samples=1000,
+ overall_throughput=100.0,
+ current_lr=0.001,
+ )
+ mock_list_metrics = [mock_metrics]
+
+ # Run train with mock train_epoch
+ with patch(
+ "instructlab.training.train.train_epoch",
+ return_value=(0, mock_list_metrics),
+ ):
+ train(
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ metric_logger=mock_logger,
+ checkpointer=mock_checkpointer,
+ effective_batch_size=8,
+ num_epochs=1,
+ last_step=0,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ use_dolomite=False,
+ save_last=True,
+ sampler="distributed",
+ )
+
+ # Verify calls
+ mock_model.train.assert_called_once()
+ mock_checkpointer.save_hf_format_accelerate.assert_called_with(
+ output_dir="test_output",
+ epoch=1,
+ samples_seen=0,
+ last_epoch=True,
+ )
+
+ def test_train_with_save_samples(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ mock_environment,
+ mock_distributed,
+ mock_cuda,
+ ):
+ # Setup mock metrics
+ mock_metrics = Metrics(
+ samples_seen=4,
+ total_loss=0.5,
+ batch_size=8,
+ num_loss_counted_tokens=100,
+ global_grad_norm=1.0,
+ total_samples=1000,
+ overall_throughput=100.0,
+ current_lr=0.001,
+ )
+ mock_list_metrics = [mock_metrics]
+
+ # Set save_samples to match batch size
+ mock_accelerator.save_samples = 8
+
+ # Run train with mock train_epoch
+ with patch(
+ "instructlab.training.train.train_epoch",
+ return_value=(0, mock_list_metrics),
+ ):
+ train(
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ metric_logger=mock_logger,
+ checkpointer=mock_checkpointer,
+ effective_batch_size=8,
+ num_epochs=1,
+ last_step=0,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ use_dolomite=False,
+ save_last=True,
+ sampler="distributed",
+ )
+
+        # Verify save_samples still equals the effective batch size after train() runs
+ assert mock_accelerator.save_samples == 8
+
+ def test_train_with_resume(
+ self,
+ mock_model,
+ mock_optimizer,
+ mock_accelerator,
+ mock_checkpointer,
+ mock_logger,
+ mock_environment,
+ mock_distributed,
+ mock_cuda,
+ ):
+ # Setup mock metrics
+ mock_metrics = Metrics(
+ samples_seen=4,
+ total_loss=0.5,
+ batch_size=8,
+ num_loss_counted_tokens=100,
+ global_grad_norm=1.0,
+ total_samples=1000,
+ overall_throughput=100.0,
+ current_lr=0.001,
+ )
+ mock_list_metrics = [mock_metrics]
+
+ # Run train with mock train_epoch and resume from step 5
+ with patch(
+ "instructlab.training.train.train_epoch",
+ return_value=(0, mock_list_metrics),
+ ):
+ train(
+ model=mock_model,
+ optimizer=mock_optimizer,
+ accelerator=mock_accelerator,
+ metric_logger=mock_logger,
+ checkpointer=mock_checkpointer,
+ effective_batch_size=8,
+ num_epochs=1,
+ last_step=5,
+ checkpoint_at_epoch=True,
+ output_dir="test_output",
+ use_dolomite=False,
+ save_last=True,
+ sampler="distributed",
+ )
+
+        # With train_epoch mocked out, verify that training still completes when resuming from a later step
+ mock_model.train.assert_called_once()