diff --git a/docs/source/data_loaders_on_cpu_with_jax.ipynb b/docs/source/data_loaders_on_cpu_with_jax.ipynb index 6f15829..d367e63 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.ipynb +++ b/docs/source/data_loaders_on_cpu_with_jax.ipynb @@ -6,7 +6,7 @@ "id": "PUFGZggH49zp" }, "source": [ - "# Introduction to Data Loaders on CPU with JAX" + "# Data loading on a CPU with JAX" ] }, { @@ -17,20 +17,16 @@ "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/data_loaders_on_cpu_with_jax.ipynb)\n", "\n", - "This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "This tutorial shows how to efficiently load data on a [**single CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU) for image classification on the MNIST datset using various dataset libraries, such as:\n", "\n", - "- [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", - "- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", - "- [**Grain**](https://github.com/google/grain)\n", - "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "- [Grain](https://github.com/google/grain)\n", + "- [PyTorch DataLoader](https://github.com/pytorch/data)\n", + "- [Hugging Face](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "- [TensorFlow Datasets (TFDS)](https://github.com/tensorflow/datasets)\n", "\n", - "In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n", + "Compared with [GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) or [multi-device setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html), data loading on a CPU can avoid such challenges, as GPU memory management and data synchronization across devices. This can be helpful for smaller-scale tasks or scenarios where data resides exclusively on a CPU.\n", "\n", - "Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n", - "\n", - "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n", - "\n", - "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." + "**Note:** To learn about **GPU-based data loading** with JAX, go to [Data loading on a GPU with JAX](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html). For a **multi-device data loading strategy** with JAX, check out [Data loading on multiple devices with JAX](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." ] }, { @@ -39,9 +35,9 @@ "id": "pEsb135zE-Jo" }, "source": [ - "## Setting JAX to Use CPU Only\n", + "## Set JAX to use only the CPU\n", "\n", - "First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading." + "First, let's set JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows us to focus on CPU-based data loading:" ] }, { @@ -62,7 +58,7 @@ "id": "-rsMgVtO6asW" }, "source": [ - "Import JAX API" + "Next, let's import certain JAX modules, including JAX NumPy, `jax.random`, and three JAX transformations, such as `jax.grad`, `jax.jit` and `jax.vmap`:" ] }, { @@ -75,7 +71,11 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "from jax import random, grad, jit, vmap" + "from jax import random, grad, jit, vmap\n", + "\n", + "from jax.scipy.special import logsumexp\n", + "\n", + "import time" ] }, { @@ -84,7 +84,7 @@ "id": "TsFdlkSZKp9S" }, "source": [ - "### CPU Setup Verification" + "Let's verify the CPU setup with `jax.devices()`:" ] }, { @@ -119,9 +119,9 @@ "id": "qyJ_WTghDnIc" }, "source": [ - "## Setting Hyperparameters and Initializing Parameters\n", + "## Setting hyperparameters and initializing parameters\n", "\n", - "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." + "We'll define certain hyperparameters for the model and data loading, including layer sizes, the learning rate, the batch size, and the data directory. We'll also initialize the weights and biases for the fully-connected neural network in our example." ] }, { @@ -132,24 +132,25 @@ }, "outputs": [], "source": [ - "# A helper function to randomly initialize weights and biases\n", - "# for a dense neural network layer\n", + "# Define a helper function to initialize model weights and biases\n", + "# using a random normal distribution using `jax.random.normal()`.\n", "def random_layer_params(m, n, key, scale=1e-2):\n", - " w_key, b_key = random.split(key)\n", - " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", + " w_key, b_key = random.split(key) # Split the JAX PRNG key.\n", + " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", "\n", - "# Function to initialize network parameters for all layers based on defined sizes\n", + "# Define a function to initialize network parameters for all layers based on defined sizes\n", + "# using the previously created `random_layer_params()` function.\n", "def init_network_params(sizes, key):\n", - " keys = random.split(key, len(sizes))\n", - " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", + " keys = random.split(key, len(sizes)) # Split the JAX PRNG key.\n", + " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", "\n", - "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", - "step_size = 0.01 # Learning rate for optimization\n", - "num_epochs = 8 # Number of training epochs\n", - "batch_size = 128 # Batch size for training\n", - "n_targets = 10 # Number of classes (digits 0-9)\n", - "num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels)\n", - "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", + "layer_sizes = [784, 512, 512, 10] # Layers of the network.\n", + "step_size = 0.01 # Learning rate for optimization.\n", + "num_epochs = 8 # Number of training epochs.\n", + "batch_size = 128 # Batch size for training.\n", + "n_targets = 10 # Number of classes (digits 0-9).\n", + "num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels).\n", + "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset.\n", "\n", "# Initialize network parameters using the defined layer sizes and a random seed\n", "params = init_network_params(layer_sizes, random.PRNGKey(0))" @@ -161,11 +162,11 @@ "id": "6Ci_CqW7q6XM" }, "source": [ - "## Model Prediction with Auto-Batching\n", + "## Model prediction with auto-batching with `jax.vmap`\n", "\n", - "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", + "Next, we'll define the `predict()` function that computes the output of the network for a single input image.\n", "\n", - "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." + "To efficiently process multiple images simultaneously, we'll use the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation, which will vectorize the `predict` function and apply it across a batch of inputs. This is also known as auto-batching, which improves computational efficiency by leveraging hardware acceleration." ] }, { @@ -176,23 +177,23 @@ }, "outputs": [], "source": [ - "from jax.scipy.special import logsumexp\n", - "\n", + "# Create the ReLU activation function.\n", "def relu(x):\n", - " return jnp.maximum(0, x)\n", + " return jnp.maximum(0, x)\n", "\n", + "# Define the prediction function.\n", "def predict(params, image):\n", - " # per-example prediction\n", - " activations = image\n", - " for w, b in params[:-1]:\n", - " outputs = jnp.dot(w, activations) + b\n", - " activations = relu(outputs)\n", + " # Per-example prediction.\n", + " activations = image\n", + " for w, b in params[:-1]:\n", + " outputs = jnp.dot(w, activations) + b\n", + " activations = relu(outputs)\n", "\n", - " final_w, final_b = params[-1]\n", - " logits = jnp.dot(final_w, activations) + final_b\n", - " return logits - logsumexp(logits)\n", + " final_w, final_b = params[-1]\n", + " logits = jnp.dot(final_w, activations) + final_b\n", + " return logits - logsumexp(logits)\n", "\n", - "# Make a batched version of the `predict` function\n", + "# Using `jax.vmap`, make a batched version of the `predict()` function.\n", "batched_predict = vmap(predict, in_axes=(None, 0))" ] }, @@ -202,18 +203,18 @@ "id": "niTSr34_sDZi" }, "source": [ - "## Utility and Loss Functions\n", + "## Set up one-hot encoding, accuracy calculation, and the loss function with `jax.grad` and `jax.jit`\n", "\n", - "You'll now define utility functions for:\n", + "Next, we'll define some utility functions for:\n", "\n", - "- One-hot encoding: Converts class indices to binary vectors.\n", - "- Accuracy calculation: Measures the performance of the model on the dataset.\n", - "- Loss computation: Calculates the difference between predictions and targets.\n", + "- One-hot encoding to convert class indices to binary vectors.\n", + "- Accuracy calculation for measuring the performance of the model on the dataset.\n", + "- The loss function for calculating the difference between predictions and targets.\n", "\n", - "To optimize performance:\n", + "To optimize performance, we'll use the following JAX automatic differentiation and compilation transformations:\n", "\n", - "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", - "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation." + "- [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", + "- [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla)." ] }, { @@ -224,47 +225,34 @@ }, "outputs": [], "source": [ - "import time\n", - "\n", "def one_hot(x, k, dtype=jnp.float32):\n", - " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", - " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", + " \"\"\"Creates a one-hot encoding function of x of size k.\"\"\"\n", + " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", "\n", "def accuracy(params, images, targets):\n", - " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", - " target_class = jnp.argmax(targets, axis=1)\n", - " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", - " return jnp.mean(predicted_class == target_class)\n", + " \"\"\"Calculates the accuracy of predictions.\"\"\"\n", + " target_class = jnp.argmax(targets, axis=1)\n", + " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", + " return jnp.mean(predicted_class == target_class)\n", "\n", "def loss(params, images, targets):\n", - " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", - " preds = batched_predict(params, images)\n", - " return -jnp.mean(preds * targets)\n", + " \"\"\"Calculates the loss between predictions and targets.\"\"\"\n", + " preds = batched_predict(params, images)\n", + " return -jnp.mean(preds * targets)\n", "\n", + "# Apply the `@jax.jit` decorator for faster execution.\n", "@jit\n", "def update(params, x, y):\n", - " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", - " grads = grad(loss)(params, x, y)\n", - " return [(w - step_size * dw, b - step_size * db)\n", - " for (w, b), (dw, db) in zip(params, grads)]\n", + " \"\"\"Updates the network parameters using gradient descent.\"\"\"\n", + " grads = grad(loss)(params, x, y)\n", + " return [(w - step_size * dw, b - step_size * db)\n", + " for (w, b), (dw, db) in zip(params, grads)]\n", "\n", "def reshape_and_one_hot(x, y):\n", - " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", + " \"\"\"Reshapes and one-hot encode the inputs.\"\"\"\n", " x = jnp.reshape(x, (len(x), num_pixels))\n", " y = one_hot(y, n_targets)\n", - " return x, y\n", - "\n", - "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", - " \"\"\"Train the model for a given number of epochs.\"\"\"\n", - " for epoch in range(num_epochs):\n", - " start_time = time.time()\n", - " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", - " x, y = reshape_and_one_hot(x, y)\n", - " params = update(params, x, y)\n", - "\n", - " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", - " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", - " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")" + " return x, y" ] }, { @@ -273,9 +261,11 @@ "id": "Hsionp5IYsQ9" }, "source": [ - "## Loading Data with PyTorch DataLoader\n", + "## PyTorch `DataLoader`\n", "\n", - "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." + "This section shows how to load the MNIST dataset using PyTorch's `DataLoader`, convert the data into NumPy arrays, and apply transformations to flatten and cast images.\n", + "\n", + "**Note:** We'll be using PyTorch with `torchvision`, so the next step shows how to install the packages." ] }, { @@ -312,6 +302,16 @@ "!pip install torch torchvision" ] }, + { + "cell_type": "markdown", + "id": "423c5f49", + "metadata": {}, + "source": [ + "### Load the `torchvision` dataset and standardize it\n", + "\n", + "First, we'lll create some helper functions and classes for convert `torchvision` MNIST dataset into NumPy arrays, and further transforming and preprocessing the data." + ] + }, { "cell_type": "code", "execution_count": 8, @@ -335,18 +335,24 @@ "outputs": [], "source": [ "def numpy_collate(batch):\n", - " \"\"\"Convert a batch of PyTorch data to NumPy arrays.\"\"\"\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", + " \"\"\"Converts a batch of PyTorch data into NumPy arrays.\"\"\"\n", + " return tree_map(np.asarray, data.default_collate(batch))\n", "\n", "class NumpyLoader(data.DataLoader):\n", - " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", + " \"\"\"A custom NumPy `DataLoader` for the PyTorch generator,\n", + " subclasses `torch.utils.data.DataLoader`.\n", + "\n", + " Returns:\n", + " NumPy arrays from the PyTorch dataset.\n", + "\n", + " \"\"\"\n", " def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):\n", " super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs)\n", "\n", "class FlattenAndCast(object):\n", - " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" + " \"\"\"A transformation class for flattening and casting images to `float32`.\"\"\"\n", + " def __call__(self, pic):\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" ] }, { @@ -355,9 +361,7 @@ "id": "mfSnfJND6I8G" }, "source": [ - "### Load Dataset with Transformations\n", - "\n", - "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." + "Now we can load the data, standardize it by flattening the images, casting them to `float32`, and ensuring consistent data types." ] }, { @@ -485,9 +489,9 @@ "id": "kbdsqvPZGrsa" }, "source": [ - "### Full Training Dataset for Accuracy Checks\n", + "### Convert the training and test sets\n", "\n", - "Convert the entire training dataset to JAX arrays." + "Next, we'll convert the whole training dataset to JAX arrays (`jax.Array`s) and one-hot encode the labels." ] }, { @@ -508,9 +512,7 @@ "id": "WXUh0BwvG8Ko" }, "source": [ - "### Get Full Test Dataset\n", - "\n", - "Load and process the full test dataset." + "After this, we can load and process the test dataset:" ] }, { @@ -521,7 +523,9 @@ }, "outputs": [], "source": [ + "# Load the test set.\n", "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "# Convert to JAX arrays, one-hot encode the labels.\n", "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" ] @@ -557,11 +561,11 @@ "id": "m3zfxqnMiCbm" }, "source": [ - "### Training Data Generator\n", + "### Create the PyTorch training generator\n", "\n", - "Define a generator function using PyTorch's DataLoader for batch training. Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", + "We'll define a generator function using PyTorch's `DataLoader` for batch training. By setting `num_workers > 0`, this enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. You can experiment with different values to find the optimal setting for your hardware and workload.\n", "\n", - "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes." + "**Note:** When setting `num_workers > 0`, you may get the following Warning: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since `DataLoader`s do not use JAX within the forked processes." ] }, { @@ -582,7 +586,7 @@ "id": "Xzt2x9S1HC3T" }, "source": [ - "### Training Loop (PyTorch DataLoader)\n", + "### Define the training loop using the PyTorch `DataLoader`\n", "\n", "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." ] @@ -614,6 +618,18 @@ } ], "source": [ + "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", + " \"\"\"Trains the model for a given number of epochs.\"\"\"\n", + " for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", + " x, y = reshape_and_one_hot(x, y)\n", + " params = update(params, x, y)\n", + "\n", + " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", + " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", + " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")\n", + "\n", "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" ] }, @@ -623,9 +639,9 @@ "id": "Nm45ZTo6yrf5" }, "source": [ - "## Loading Data with TensorFlow Datasets (TFDS)\n", + "## Load data with TensorFlow Datasets (TFDS)\n", "\n", - "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." + "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlowm as shown below:" ] }, { @@ -639,7 +655,7 @@ "import tensorflow_datasets as tfds\n", "import tensorflow as tf\n", "\n", - "# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF\n", + "# To ensure CPU-only execution, disable any GPU usage (if applicable) for TF.\n", "tf.config.set_visible_devices([], device_type='GPU')" ] }, @@ -649,9 +665,9 @@ "id": "3xdQY7H6wr3n" }, "source": [ - "### Fetch Full Dataset for Evaluation\n", + "### Load the entire TF dataset for evaluation\n", "\n", - "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." + "First, load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." ] }, { @@ -709,7 +725,7 @@ } ], "source": [ - "# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)\n", + "# `tfds.load` returns `tf.Tensor`s (or `tf.data.Dataset`s if `batch_size` is not `-1`).\n", "mnist_data, info = tfds.load(name=\"mnist\", batch_size=-1, data_dir=data_dir, with_info=True)\n", "mnist_data = tfds.as_numpy(mnist_data)\n", "train_data, test_data = mnist_data['train'], mnist_data['test']\n", @@ -1426,9 +1442,9 @@ "id": "kk_4zJlz7T1E" }, "source": [ - "### Define Training Generator\n", + "### Define the training HF generator\n", "\n", - "Set up a generator to yield batches of images and labels for training." + "Set up a HF generator to yield batches of images and labels for training:" ] }, { @@ -1440,7 +1456,11 @@ "outputs": [], "source": [ "def hf_training_generator():\n", - " \"\"\"Yield batches for training.\"\"\"\n", + " \"\"\"Yield batches for training.\n", + "\n", + " Yields:\n", + " x, y: A tuple containing a batch of images (x) and labels (y).\n", + " \"\"\"\n", " for batch in mnist_dataset[\"train\"].iter(batch_size):\n", " x, y = batch[\"image\"], batch[\"label\"]\n", " yield x, y" @@ -1452,7 +1472,7 @@ "id": "HIsGfkLI7dvZ" }, "source": [ - "### Training Loop (Hugging Face Datasets)\n", + "### Train the model with HF Datasets\n", "\n", "Run the training loop using the Hugging Face training generator." ] @@ -1493,9 +1513,9 @@ "id": "qXylIOwidWI3" }, "source": [ - "## Summary\n", + "## Conclusion\n", "\n", - "This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements." + "This document demonstrates efficient data loading techniques for JAX on a CPU with PyTorch, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages." ] } ], diff --git a/docs/source/data_loaders_on_cpu_with_jax.md b/docs/source/data_loaders_on_cpu_with_jax.md index b109474..e67dd40 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.md +++ b/docs/source/data_loaders_on_cpu_with_jax.md @@ -13,32 +13,28 @@ kernelspec: +++ {"id": "PUFGZggH49zp"} -# Introduction to Data Loaders on CPU with JAX +# Data loading on a CPU with JAX +++ {"id": "3ia4PKEV5Dr8"} [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/data_loaders_on_cpu_with_jax.ipynb) -This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including: +This tutorial shows how to efficiently load data on a [**single CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU) for image classification on the MNIST datset using various dataset libraries, such as: -- [**PyTorch DataLoader**](https://github.com/pytorch/data) -- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets) -- [**Grain**](https://github.com/google/grain) -- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) +- [Grain](https://github.com/google/grain) +- [PyTorch DataLoader](https://github.com/pytorch/data) +- [Hugging Face](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) +- [TensorFlow Datasets (TFDS)](https://github.com/tensorflow/datasets) -In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset. +Compared with [GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) or [multi-device setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html), data loading on a CPU can avoid such challenges, as GPU memory management and data synchronization across devices. This can be helpful for smaller-scale tasks or scenarios where data resides exclusively on a CPU. -Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU. - -If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html). - -If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). +**Note:** To learn about **GPU-based data loading** with JAX, go to [Data loading on a GPU with JAX](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html). For a **multi-device data loading strategy** with JAX, check out [Data loading on multiple devices with JAX](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). +++ {"id": "pEsb135zE-Jo"} -## Setting JAX to Use CPU Only +## Set JAX to use only the CPU -First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading. +First, let's set JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows us to focus on CPU-based data loading: ```{code-cell} :id: vqP6xyObC0_9 @@ -49,7 +45,7 @@ os.environ['JAX_PLATFORM_NAME'] = 'cpu' +++ {"id": "-rsMgVtO6asW"} -Import JAX API +Next, let's import certain JAX modules, including JAX NumPy, `jax.random`, and three JAX transformations, such as `jax.grad`, `jax.jit` and `jax.vmap`: ```{code-cell} :id: tDJNQ6V-Dg5g @@ -57,11 +53,15 @@ Import JAX API import jax import jax.numpy as jnp from jax import random, grad, jit, vmap + +from jax.scipy.special import logsumexp + +import time ``` +++ {"id": "TsFdlkSZKp9S"} -### CPU Setup Verification +Let's verify the CPU setup with `jax.devices()`: ```{code-cell} --- @@ -75,31 +75,32 @@ jax.devices() +++ {"id": "qyJ_WTghDnIc"} -## Setting Hyperparameters and Initializing Parameters +## Setting hyperparameters and initializing parameters -You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. +We'll define certain hyperparameters for the model and data loading, including layer sizes, the learning rate, the batch size, and the data directory. We'll also initialize the weights and biases for the fully-connected neural network in our example. ```{code-cell} :id: qLNOSloFDka_ -# A helper function to randomly initialize weights and biases -# for a dense neural network layer +# Define a helper function to initialize model weights and biases +# using a random normal distribution using `jax.random.normal()`. def random_layer_params(m, n, key, scale=1e-2): - w_key, b_key = random.split(key) - return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,)) + w_key, b_key = random.split(key) # Split the JAX PRNG key. + return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,)) -# Function to initialize network parameters for all layers based on defined sizes +# Define a function to initialize network parameters for all layers based on defined sizes +# using the previously created `random_layer_params()` function. def init_network_params(sizes, key): - keys = random.split(key, len(sizes)) - return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)] + keys = random.split(key, len(sizes)) # Split the JAX PRNG key. + return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)] -layer_sizes = [784, 512, 512, 10] # Layers of the network -step_size = 0.01 # Learning rate for optimization -num_epochs = 8 # Number of training epochs -batch_size = 128 # Batch size for training -n_targets = 10 # Number of classes (digits 0-9) -num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels) -data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset +layer_sizes = [784, 512, 512, 10] # Layers of the network. +step_size = 0.01 # Learning rate for optimization. +num_epochs = 8 # Number of training epochs. +batch_size = 128 # Batch size for training. +n_targets = 10 # Number of classes (digits 0-9). +num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels). +data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset. # Initialize network parameters using the defined layer sizes and a random seed params = init_network_params(layer_sizes, random.PRNGKey(0)) @@ -107,101 +108,90 @@ params = init_network_params(layer_sizes, random.PRNGKey(0)) +++ {"id": "6Ci_CqW7q6XM"} -## Model Prediction with Auto-Batching +## Model prediction with auto-batching with `jax.vmap` -In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image. +Next, we'll define the `predict()` function that computes the output of the network for a single input image. -To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration. +To efficiently process multiple images simultaneously, we'll use the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation, which will vectorize the `predict` function and apply it across a batch of inputs. This is also known as auto-batching, which improves computational efficiency by leveraging hardware acceleration. ```{code-cell} :id: bKIYPSkvD1QV -from jax.scipy.special import logsumexp - +# Create the ReLU activation function. def relu(x): - return jnp.maximum(0, x) + return jnp.maximum(0, x) +# Define the prediction function. def predict(params, image): - # per-example prediction - activations = image - for w, b in params[:-1]: - outputs = jnp.dot(w, activations) + b - activations = relu(outputs) + # Per-example prediction. + activations = image + for w, b in params[:-1]: + outputs = jnp.dot(w, activations) + b + activations = relu(outputs) - final_w, final_b = params[-1] - logits = jnp.dot(final_w, activations) + final_b - return logits - logsumexp(logits) + final_w, final_b = params[-1] + logits = jnp.dot(final_w, activations) + final_b + return logits - logsumexp(logits) -# Make a batched version of the `predict` function +# Using `jax.vmap`, make a batched version of the `predict()` function. batched_predict = vmap(predict, in_axes=(None, 0)) ``` +++ {"id": "niTSr34_sDZi"} -## Utility and Loss Functions +## Set up one-hot encoding, accuracy calculation, and the loss function with `jax.grad` and `jax.jit` -You'll now define utility functions for: +Next, we'll define some utility functions for: -- One-hot encoding: Converts class indices to binary vectors. -- Accuracy calculation: Measures the performance of the model on the dataset. -- Loss computation: Calculates the difference between predictions and targets. +- One-hot encoding to convert class indices to binary vectors. +- Accuracy calculation for measuring the performance of the model on the dataset. +- The loss function for calculating the difference between predictions and targets. -To optimize performance: +To optimize performance, we'll use the following JAX automatic differentiation and compilation transformations: -- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters. -- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation. +- [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters. +- [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla). ```{code-cell} :id: sA0a06raEQfS -import time - def one_hot(x, k, dtype=jnp.float32): - """Create a one-hot encoding of x of size k.""" - return jnp.array(x[:, None] == jnp.arange(k), dtype) + """Creates a one-hot encoding function of x of size k.""" + return jnp.array(x[:, None] == jnp.arange(k), dtype) def accuracy(params, images, targets): - """Calculate the accuracy of predictions.""" - target_class = jnp.argmax(targets, axis=1) - predicted_class = jnp.argmax(batched_predict(params, images), axis=1) - return jnp.mean(predicted_class == target_class) + """Calculates the accuracy of predictions.""" + target_class = jnp.argmax(targets, axis=1) + predicted_class = jnp.argmax(batched_predict(params, images), axis=1) + return jnp.mean(predicted_class == target_class) def loss(params, images, targets): - """Calculate the loss between predictions and targets.""" - preds = batched_predict(params, images) - return -jnp.mean(preds * targets) + """Calculates the loss between predictions and targets.""" + preds = batched_predict(params, images) + return -jnp.mean(preds * targets) +# Apply the `@jax.jit` decorator for faster execution. @jit def update(params, x, y): - """Update the network parameters using gradient descent.""" - grads = grad(loss)(params, x, y) - return [(w - step_size * dw, b - step_size * db) - for (w, b), (dw, db) in zip(params, grads)] + """Updates the network parameters using gradient descent.""" + grads = grad(loss)(params, x, y) + return [(w - step_size * dw, b - step_size * db) + for (w, b), (dw, db) in zip(params, grads)] def reshape_and_one_hot(x, y): - """Reshape and one-hot encode the inputs.""" + """Reshapes and one-hot encode the inputs.""" x = jnp.reshape(x, (len(x), num_pixels)) y = one_hot(y, n_targets) return x, y - -def train_model(num_epochs, params, training_generator, data_loader_type='streamed'): - """Train the model for a given number of epochs.""" - for epoch in range(num_epochs): - start_time = time.time() - for x, y in training_generator() if data_loader_type == 'streamed' else training_generator: - x, y = reshape_and_one_hot(x, y) - params = update(params, x, y) - - print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: " - f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, " - f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}") ``` +++ {"id": "Hsionp5IYsQ9"} -## Loading Data with PyTorch DataLoader +## PyTorch `DataLoader` + +This section shows how to load the MNIST dataset using PyTorch's `DataLoader`, convert the data into NumPy arrays, and apply transformations to flatten and cast images. -This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images. +**Note:** We'll be using PyTorch with `torchvision`, so the next step shows how to install the packages. ```{code-cell} --- @@ -213,6 +203,10 @@ outputId: 33dfeada-a763-4d26-f778-a27966e34d55 !pip install torch torchvision ``` +### Load the `torchvision` dataset and standardize it + +First, we'lll create some helper functions and classes for convert `torchvision` MNIST dataset into NumPy arrays, and further transforming and preprocessing the data. + ```{code-cell} :id: kO5_WzwY59gE @@ -226,25 +220,29 @@ from torchvision.datasets import MNIST :id: 6f6qU8PCc143 def numpy_collate(batch): - """Convert a batch of PyTorch data to NumPy arrays.""" - return tree_map(np.asarray, data.default_collate(batch)) + """Converts a batch of PyTorch data into NumPy arrays.""" + return tree_map(np.asarray, data.default_collate(batch)) class NumpyLoader(data.DataLoader): - """Custom DataLoader to return NumPy arrays from a PyTorch Dataset.""" + """A custom NumPy `DataLoader` for the PyTorch generator, + subclasses `torch.utils.data.DataLoader`. + + Returns: + NumPy arrays from the PyTorch dataset. + + """ def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs) class FlattenAndCast(object): - """Transform class to flatten and cast images to float32.""" - def __call__(self, pic): - return np.ravel(np.array(pic, dtype=jnp.float32)) + """A transformation class for flattening and casting images to `float32`.""" + def __call__(self, pic): + return np.ravel(np.array(pic, dtype=jnp.float32)) ``` +++ {"id": "mfSnfJND6I8G"} -### Load Dataset with Transformations - -Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types. +Now we can load the data, standardize it by flattening the images, casting them to `float32`, and ensuring consistent data types. ```{code-cell} --- @@ -258,9 +256,9 @@ mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast()) +++ {"id": "kbdsqvPZGrsa"} -### Full Training Dataset for Accuracy Checks +### Convert the training and test sets -Convert the entire training dataset to JAX arrays. +Next, we'll convert the whole training dataset to JAX arrays (`jax.Array`s) and one-hot encode the labels. ```{code-cell} :id: c9ZCJq_rzPck @@ -271,14 +269,14 @@ train_labels = one_hot(np.array(mnist_dataset.targets), n_targets) +++ {"id": "WXUh0BwvG8Ko"} -### Get Full Test Dataset - -Load and process the full test dataset. +After this, we can load and process the test dataset: ```{code-cell} :id: brlLG4SqGphm +# Load the test set. mnist_dataset_test = MNIST(data_dir, download=True, train=False) +# Convert to JAX arrays, one-hot encode the labels. test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32) test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets) ``` @@ -296,11 +294,11 @@ print('Test:', test_images.shape, test_labels.shape) +++ {"id": "m3zfxqnMiCbm"} -### Training Data Generator +### Create the PyTorch training generator -Define a generator function using PyTorch's DataLoader for batch training. Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload. +We'll define a generator function using PyTorch's `DataLoader` for batch training. By setting `num_workers > 0`, this enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. You can experiment with different values to find the optimal setting for your hardware and workload. -Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes. +**Note:** When setting `num_workers > 0`, you may get the following Warning: `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since `DataLoader`s do not use JAX within the forked processes. ```{code-cell} :id: B-fES82EiL6Z @@ -311,7 +309,7 @@ def pytorch_training_generator(mnist_dataset): +++ {"id": "Xzt2x9S1HC3T"} -### Training Loop (PyTorch DataLoader) +### Define the training loop using the PyTorch `DataLoader` The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters. @@ -322,14 +320,26 @@ colab: id: vtUjHsh-rJs8 outputId: 4766333e-4366-493b-995a-102778d1345a --- +def train_model(num_epochs, params, training_generator, data_loader_type='streamed'): + """Trains the model for a given number of epochs.""" + for epoch in range(num_epochs): + start_time = time.time() + for x, y in training_generator() if data_loader_type == 'streamed' else training_generator: + x, y = reshape_and_one_hot(x, y) + params = update(params, x, y) + + print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: " + f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, " + f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}") + train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable') ``` +++ {"id": "Nm45ZTo6yrf5"} -## Loading Data with TensorFlow Datasets (TFDS) +## Load data with TensorFlow Datasets (TFDS) -This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow. +This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlowm as shown below: ```{code-cell} :id: sGaQAk1DHMUx @@ -337,15 +347,15 @@ This section demonstrates how to load the MNIST dataset using TFDS, fetch the fu import tensorflow_datasets as tfds import tensorflow as tf -# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF +# To ensure CPU-only execution, disable any GPU usage (if applicable) for TF. tf.config.set_visible_devices([], device_type='GPU') ``` +++ {"id": "3xdQY7H6wr3n"} -### Fetch Full Dataset for Evaluation +### Load the entire TF dataset for evaluation -Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. +First, load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. ```{code-cell} --- @@ -359,7 +369,7 @@ colab: id: 1hOamw_7C8Pb outputId: ca166490-22db-4732-b29f-866b7593e489 --- -# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1) +# `tfds.load` returns `tf.Tensor`s (or `tf.data.Dataset`s if `batch_size` is not `-1`). mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True) mnist_data = tfds.as_numpy(mnist_data) train_data, test_data = mnist_data['train'], mnist_data['test'] @@ -654,15 +664,19 @@ print('Test:', test_images.shape, test_labels.shape) +++ {"id": "kk_4zJlz7T1E"} -### Define Training Generator +### Define the training HF generator -Set up a generator to yield batches of images and labels for training. +Set up a HF generator to yield batches of images and labels for training: ```{code-cell} :id: -zLJhogj7RL- def hf_training_generator(): - """Yield batches for training.""" + """Yield batches for training. + + Yields: + x, y: A tuple containing a batch of images (x) and labels (y). + """ for batch in mnist_dataset["train"].iter(batch_size): x, y = batch["image"], batch["label"] yield x, y @@ -670,7 +684,7 @@ def hf_training_generator(): +++ {"id": "HIsGfkLI7dvZ"} -### Training Loop (Hugging Face Datasets) +### Train the model with HF Datasets Run the training loop using the Hugging Face training generator. @@ -686,6 +700,6 @@ train_model(num_epochs, params, hf_training_generator) +++ {"id": "qXylIOwidWI3"} -## Summary +## Conclusion -This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements. +This document demonstrates efficient data loading techniques for JAX on a CPU with PyTorch, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages.