Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,17 @@
Detailed installation instructions can be found [here](https://calcil.readthedocs.io/en/latest/installation.html).
```
# Create a virtual environment
conda create -n calcil python=3.9
conda create -n calcil python=3.10
conda activate calcil

# (optional, if needed) Install CUDA in conda virtual env
conda install -c conda-forge cudatoolkit~=11.8.0 cudnn~=8.8.0
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc

# Install jaxlib for GPU
pip install jaxlib==0.3.18+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Install this library
# Install this library (defaults to JAX CPU version)
pip install git+https://github.com/rmcao/CalCIL.git
```
**GPU Support:** To enable GPU support, you must install the CUDA-enabled version of JAX. For example:
```bash
pip install -U "jax[cuda12]"
```
Please refer to the official [JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html) for more details on installing JAX with CUDA or TPU support.

## Tutorials
A step-by-step tutorial on how to use CalCIL for image reconstruction can be found [here](https://calcil.readthedocs.io/en/latest/getting_started.html).
Expand Down
6 changes: 3 additions & 3 deletions calcil/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def run_reconstruction(state: train_state.TrainState,

batch_info = []
for i_batch, input_dict in enumerate(input_batches):
cur_rngs = jax.tree_map(lambda rng: jax.random.split(rng)[0], rngs)
rngs = jax.tree_map(lambda rng: jax.random.split(rng)[1], rngs)
cur_rngs = jax.tree.map(lambda rng: jax.random.split(rng)[0], rngs)
rngs = jax.tree.map(lambda rng: jax.random.split(rng)[1], rngs)

state, info = update_fn(state, input_dict, cur_rngs)

Expand Down Expand Up @@ -356,7 +356,7 @@ def run_reconstruction(state: train_state.TrainState,

# save checkpoints
if ((s+1) % recon_param.checkpoint_every == 0) or (s + 1 == recon_param.n_epoch):
checkpoints.save_checkpoint(recon_param.save_dir, state, s+1,
checkpoints.save_checkpoint(os.path.abspath(recon_param.save_dir), state, s+1,
keep=recon_param.keep_checkpoints, overwrite=True)

print('Total elapsed time in sec: {:#.5g}'.format(time.time() - recon_start_time), end='')
Expand Down
29 changes: 14 additions & 15 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,28 @@ Step-by-step Installation

.. code-block:: bash

$ conda create -n calcil python=3.9
$ conda create -n calcil python=3.10
$ conda activate calcil


2. Install CUDA and cuDNN in conda virtual env (you may opt to skip this step if you have CUDA installed in your system and you know what you are doing)
2. Install CalCIL. You may use -e flag to install in editable mode.

.. code-block:: bash

$ conda install -c conda-forge cudatoolkit~=11.8.0 cudnn~=8.8.0
$ conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
$ pip install git+https://github.com/rmcao/CalCIL.git

.. note::

3. Install jaxlib. Note that the following command is for CUDA 11.x and cuDNN 8.2+. If you have different versions of CUDA, please refer to `JAX installation guide <https://jax.readthedocs.io/en/latest/installation.html>`__ and make sure to match the version numbers of jaxlib and jax (as specified in requirements.txt).
This will install the standard CPU-only version of JAX by default.

3. (Optional) Enable GPU support.

To enable GPU support, you must install the CUDA-enabled version of JAX. For example:

.. code-block:: bash

$ pip install jaxlib==0.3.18+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install -U "jax[cuda12]"

Please refer to the official `JAX installation guide <https://jax.readthedocs.io/en/latest/installation.html>`__ for more details on installing JAX with CUDA or TPU support.

.. note::

Expand All @@ -41,15 +47,8 @@ Step-by-step Installation

$ python -c "import jax.numpy as jnp; print(jnp.ones(5)+jnp.zeros(5))"

4. pip install CalCIL. You may use -e flag to install in editable mode.

.. code-block:: bash

$ pip install git+https://github.com/rmcao/CalCIL.git

5. Install optional dependencies for interactive visualization via Jupyter lab
4. Install optional dependencies for interactive visualization via Jupyter lab

.. code-block:: bash

$ conda install -c conda-forge jupyterlab nodejs ipympl

23 changes: 11 additions & 12 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
chex==0.1.6
flax==0.6.0
jax==0.3.18
optax==0.1.3
protobuf~=3.20.0
tensorflow==2.8.1
scikit-image~=0.19.3
scikit-learn~=1.2.2
scipy==1.10.1
pandas~=1.5.2
matplotlib~=3.7.0
https://storage.googleapis.com/jax-releases/nocuda/jaxlib-0.3.18-cp39-cp39-manylinux2014_x86_64.whl
chex==0.1.91
flax==0.12.2
jax==0.8.2
optax==0.2.6
protobuf==5.29.3
tensorflow==2.20.0
scikit-image==0.26.0
scikit-learn==1.6.1
scipy==1.17.0
pandas==2.2.3
matplotlib==3.10.0
.
sphinx-rtd-theme
17 changes: 8 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
chex==0.1.6
flax==0.6.0
jax==0.3.18
optax==0.1.3
protobuf~=3.20.0
tensorflow==2.8.1
scikit-image~=0.19.3
scipy==1.10.1
numpy<2.0.00
chex==0.1.91
flax==0.12.2
jax==0.8.2
numpy==2.4.1
optax==0.2.6
scikit-image==0.26.0
scipy==1.17.0
tensorflow==2.20.0