Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JaxCNN #35

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open

JaxCNN #35

wants to merge 13 commits into from

Conversation

andyyPark
Copy link
Member

@andyyPark andyyPark commented Sep 1, 2020

This method was inspired by #9 where I exploited the auto differentiable metrics from jax-cosmo library. Flax was used to implement a simple Convolutional Neural Network that assigns bins to galaxies.

My network was

  • optimized for the total 3x2 FOM
  • trained on riz bands

I have created a jupyter notebook, JaxCNN.ipynb, that walks through my code in the notebooks folder, but I still have yet to finish running the notebook on NERSC.

Below is an example of the binning generated for 4 bins:
4_riz

Scores.ipynb in the notebooks folder shows the plots of the metrics for a different number of bins.

FOM_3x2

FOM_3x2

FOM_DETF_3x2

FOM_DETF_3x2

SNR_3x2

SNR_3x2

@andyyPark andyyPark marked this pull request as draft September 1, 2020 07:01
@EiffL
Copy link
Member

EiffL commented Sep 1, 2020

@andyyPark Thanks for your entry! Ahaha, am I glad to see another JAX neural network approach! When you have them, could you add to your description a few metrics, I'm very curious to see how the CNN compares to the Dense network ;-)

@EiffL EiffL added the entry Challenge entry label Sep 1, 2020
@andyyPark andyyPark marked this pull request as ready for review September 8, 2020 03:02
@andyyPark
Copy link
Member Author

andyyPark commented Sep 15, 2020

Since my original submission, I have changed my original jaxCNN to maximize the FOM_DETF score, and have implemented ResNet50 (jaxResNet.py) using jax and jax-cosmo library. Although I don't have any scores to show, below is an example of the binning generated for 5 and 6 bins using the Buzzard dataset:

5 Bins (jaxCNN)

image

6 Bins (jaxCNN)

image

5 Bins (jaxResNet)

image

6 Bins (jaxResNet)

image

I haven't tried this yet but it seems like my original CNN network (jaxCNN.py) performs better than ResNet50 with epochs ~O(100).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
entry Challenge entry
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants