
Training loop from scratch with JAX doesn't converge #21489

@heydaari


Hello, I have tried writing a training loop from scratch with JAX in Keras 3, but the model doesn't converge during training. To narrow down the issue I also trained the same model with the model.fit() API, and it trains correctly and updates the weights.

This is the gist of the issue:
https://gist.github.com/heydaari/0645956a773fac764cd974f5f1850ea3

[Screenshot: training logs from the custom loop and from model.fit()]

The first logs are from the last epoch of the from-scratch training loop; the second logs are from model.fit().
