Enable to assign different parameters dtype during training #1037

Open: wants to merge 1 commit into base: main

@jialingt jialingt commented Mar 5, 2025

This PR enables train_dtype to accept either a jnp.dtype or a PerParamFn[jnp.dtype]:

  1. jnp.dtype, where both float inputs and model parameters will be cast to this dtype.

  2. ConfigOr[PerParamFn[jnp.dtype]], allowing different dtypes to be applied to different parameters during training.
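The two modes above can be sketched in plain JAX. This is only an illustration of the idea, not the code in this PR: `cast_floats`, `cast_per_param`, and `dtype_for` are hypothetical helper names, and the rule "keep anything with `norm` in its path in float32" is an assumed example policy. It does not use the actual PerParamFn or ConfigOr types.

```python
import jax
import jax.numpy as jnp


def cast_floats(tree, dtype):
    """Mode 1 (hypothetical helper): cast every floating-point leaf to one dtype."""

    def cast(x):
        # Leave integer/bool leaves (e.g. step counters) untouched.
        return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x

    return jax.tree_util.tree_map(cast, tree)


def cast_per_param(params, dtype_for):
    """Mode 2 (hypothetical helper): pick a dtype per parameter from its path."""

    def cast(path, x):
        if not jnp.issubdtype(x.dtype, jnp.floating):
            return x
        # keystr turns the key path into a string such as "['norm']['scale']".
        return x.astype(dtype_for(jax.tree_util.keystr(path)))

    return jax.tree_util.tree_map_with_path(cast, params)


def dtype_for(path: str):
    """Assumed example policy: norm parameters stay float32, the rest use bfloat16."""
    return jnp.float32 if "norm" in path else jnp.bfloat16


params = {
    "linear": {"weight": jnp.ones((2, 2), jnp.float32)},
    "norm": {"scale": jnp.ones((2,), jnp.float32)},
    "step": jnp.zeros((), jnp.int32),
}

uniform = cast_floats(params, jnp.bfloat16)       # single dtype for all float leaves
casted = cast_per_param(params, dtype_for)        # dtype chosen per parameter
```

In this sketch the per-parameter variant only differs from the uniform one in where the dtype comes from; non-float leaves are left unchanged in both cases.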

@jialingt jialingt requested review from ruomingp, markblee and a team as code owners March 5, 2025 22:07
@jialingt jialingt changed the title Enable assign different train_dtype for parameters during training Enable assign different parameters dtype during training Mar 5, 2025
@jialingt jialingt changed the title Enable assign different parameters dtype during training Enable to assign different parameters dtype during training Mar 5, 2025