-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JaxCNN #35
base: master
Are you sure you want to change the base?
JaxCNN #35
Conversation
@andyyPark Thanks for your entry! Ahaha, am I glad to see another JAX neural network approach! When you have them, could you add to your description a few metrics, I'm very curious to see how the CNN compares to the Dense network ;-) |
Since my original submission, I have changed my original jaxCNN to maximize the FOM_DETF score, and have implemented ResNet50 (jaxResNet.py) using jax and jax-cosmo library. Although I don't have any scores to show, below is an example of the binning generated for 5 and 6 bins using the Buzzard dataset: 5 Bins (jaxCNN)6 Bins (jaxCNN)5 Bins (jaxResNet)6 Bins (jaxResNet)I haven't tried this yet but it seems like my original CNN network (jaxCNN.py) performs better than ResNet50 with epochs ~O(100). |
This method was inspired by #9 where I exploited the auto differentiable metrics from
jax-cosmo
library. Flax was used to implement a simple Convolutional Neural Network that assigns bins to galaxies.My network was
riz
bandsI have created a jupyter notebook,
JaxCNN.ipynb
, that walks through my code in thenotebooks
folder, but I still have yet to finish running the notebook on NERSC.Below is an example of the binning generated for 4 bins:
Scores.ipynb
in thenotebooks
folder shows the plots of the metrics for a different number of bins.FOM_3x2
FOM_DETF_3x2
SNR_3x2