From 74776a4c23b5f53050195f01d3f69a260c96aab4 Mon Sep 17 00:00:00 2001 From: Joey Cheng Date: Fri, 18 Oct 2024 15:47:34 +0800 Subject: [PATCH] update data pipeline change to load real data from qinyiyan@ and other changes from zhaoyuec@ --- README.md | 3 +- tpu/flax/README.md | 1 + tpu/flax/configs.py | 52 +++++++++++- tpu/flax/data_pipeline.py | 148 ++++++++++++++++++++++++++++---- tpu/flax/layers.py | 56 ++++++++---- tpu/flax/losses.py | 5 +- tpu/flax/metrics.py | 26 +++++- tpu/flax/models.py | 24 ++++-- tpu/flax/requirements.txt | 17 ---- tpu/flax/train.py | 173 +++++++++++++++++++++++--------------- tpu/keras/configs.py | 16 ++-- tpu/keras/models.py | 41 ++++----- tpu/keras/train.py | 33 ++------ 13 files changed, 409 insertions(+), 186 deletions(-) diff --git a/README.md b/README.md index 127a1da..0dd7577 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # RankML -RankML Library for TPU in Keras and Flax +RankML Library in Jax and Keras ## Setup @@ -7,6 +7,7 @@ To set up the environment for this project, follow these steps: ## Install dependencies: ```bash +pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install -r requirements.txt ``` diff --git a/tpu/flax/README.md b/tpu/flax/README.md index f1a2fc8..8dc5e30 100644 --- a/tpu/flax/README.md +++ b/tpu/flax/README.md @@ -7,6 +7,7 @@ To set up the environment for this project, follow these steps: ## Install dependencies: ```bash +pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install -r requirements.txt ``` diff --git a/tpu/flax/configs.py b/tpu/flax/configs.py index 45b2160..37373bd 100644 --- a/tpu/flax/configs.py +++ b/tpu/flax/configs.py @@ -14,8 +14,16 @@ limitations under the License. """ +import enum import ml_collections +class DatasetFormat(enum.Enum): + """Defines the dataset format.""" + TSV = "tsv" + TFRECORD = "tfrecord" + SYNTHETIC = "synthetic" + + def get_config(): """Get the default hyperparameter configuration.""" config = ml_collections.ConfigDict() @@ -37,14 +45,54 @@ def get_config(): config.train_data.sharding = True config.train_data.num_shards_per_host = 8 config.train_data.cycle_length = 8 - config.train_data.use_synthetic_data = True + config.train_data.dataset_format = DatasetFormat.SYNTHETIC config.validation_data = ml_collections.ConfigDict() config.validation_data.input_path = 'path/to/validation/data/*.tsv' config.validation_data.global_batch_size = 1024 config.validation_data.is_training = False config.validation_data.sharding = False - config.validation_data.use_synthetic_data = True + config.validation_data.dataset_format = DatasetFormat.SYNTHETIC + + # Global configuration + config.num_epochs = 10 # Make sure this is defined + config.steps_per_epoch = 100 # Adjust this value based on your dataset size and batch size + + return config + +def get_criteo_config(): + """Get the configuration for the Criteo dataset.""" + config = ml_collections.ConfigDict() + + # Model configuration + config.model = ml_collections.ConfigDict() + config.model.vocab_sizes = [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36] # Example vocab sizes + config.model.num_dense_features = 13 # Example number of dense features + config.model.embedding_dim = 16 # TODO(qinyiyan): Use larger embedding vector when ready. 
+ config.model.bottom_mlp_dims = [512, 256, 128] # Add this line + config.model.top_mlp_dims = [1024, 1024, 512, 256, 1] # Add this line + config.model.learning_rate = 0.025 # Add this line + + # Data configuration + config.train_data = ml_collections.ConfigDict() + config.train_data.input_path = 'gs://rankml-datasets/criteo/train/day_0/*00100*' + config.train_data.global_batch_size = 1024 + config.train_data.is_training = True + config.train_data.sharding = True + config.train_data.num_shards_per_host = 8 + config.train_data.cycle_length = 8 + config.train_data.multi_hot_sizes = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1] + config.train_data.dataset_format = DatasetFormat.TFRECORD + + config.validation_data = ml_collections.ConfigDict() + config.validation_data.input_path = 'gs://rankml-datasets/criteo/eval/day_23/*00000*' + config.validation_data.global_batch_size = 1024 + config.validation_data.is_training = False + config.validation_data.sharding = False + config.validation_data.cycle_length = 8 + config.validation_data.multi_hot_sizes = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1] + config.validation_data.dataset_format = DatasetFormat.TFRECORD + # Global configuration config.num_epochs = 10 # Make sure this is defined diff --git a/tpu/flax/data_pipeline.py b/tpu/flax/data_pipeline.py index 74a94e0..6778dc9 100644 --- a/tpu/flax/data_pipeline.py +++ b/tpu/flax/data_pipeline.py @@ -16,9 +16,102 @@ from typing import List import tensorflow as tf -from configs import get_config +from configs import DatasetFormat import ml_collections + +class CriteoTFRecordReader(object): + """Input reader fn for TFRecords that have been serialized in batched form.""" + + def __init__( + self, + config: ml_collections.ConfigDict, + is_training: bool, + use_cached_data: bool = False, + ): + self._params = config.train_data if is_training else config.validation_data + self._num_dense_features = config.model.num_dense_features + self._vocab_sizes = config.model.vocab_sizes + self._use_cached_data = use_cached_data + + self.label_features = "label" + self.dense_features = ["dense-feature-%d" % x for x in range(1, 14)] + self.sparse_features = ["sparse-feature-%d" % x for x in range(14, 40)] + + def __call__(self) -> tf.data.Dataset: + params = self._params + # Per replica batch size. + batch_size = params.global_batch_size + + def _get_feature_spec(): + feature_spec = {} + feature_spec[self.label_features] = tf.io.FixedLenFeature( + [], dtype=tf.int64 + ) + for dense_feat in self.dense_features: + feature_spec[dense_feat] = tf.io.FixedLenFeature( + [], + dtype=tf.float32, + ) + for i, sparse_feat in enumerate(self.sparse_features): + feature_spec[sparse_feat] = tf.io.FixedLenFeature( + [params.multi_hot_sizes[i]], dtype=tf.int64 + ) + return feature_spec + + def _parse_fn(serialized_example): + feature_spec = _get_feature_spec() + parsed_features = tf.io.parse_single_example( + serialized_example, feature_spec + ) + label = parsed_features[self.label_features] + features = {} + int_features = [] + for dense_ft in self.dense_features: + int_features.append(parsed_features[dense_ft]) + features["dense_features"] = tf.stack(int_features) + + features["sparse_features"] = {} + for i, sparse_ft in enumerate(self.sparse_features): + features['sparse_features'][str(i)] = parsed_features[sparse_ft] + return features, label + + # TODO(qinyiyan): Enable sharding. 
+ filenames = tf.data.Dataset.list_files(self._params.input_path, shuffle=False) + + num_shards_per_host = 1 + if params.sharding: + num_shards_per_host = params.num_shards_per_host + + def make_dataset(shard_index): + filenames_for_shard = filenames.shard(num_shards_per_host, shard_index) + dataset = tf.data.TFRecordDataset(filenames_for_shard) + if params.is_training: + dataset = dataset.repeat() + dataset = dataset.map( + _parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE + ) + return dataset + + indices = tf.data.Dataset.range(num_shards_per_host) + dataset = indices.interleave( + map_func=make_dataset, + cycle_length=params.cycle_length, + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) + + dataset = dataset.batch( + batch_size, + drop_remainder=True, + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) + dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + if self._use_cached_data: + dataset = dataset.take(1).cache().repeat() + + return dataset + + class CriteoTsvReader: """Input reader callable for pre-processed Criteo data.""" @@ -27,24 +120,28 @@ def __init__(self, config: ml_collections.ConfigDict, is_training: bool): self._model_config = config.model def __call__(self) -> tf.data.Dataset: - if self._params.use_synthetic_data: + if self._params.dataset_format == DatasetFormat.SYNTHETIC: return self._generate_synthetic_data() @tf.function def _parse_fn(example: tf.Tensor): """Parser function for pre-processed Criteo TSV records.""" label_defaults = [[0.0]] - dense_defaults = [[0.0] for _ in range(self._model_config.num_dense_features)] + dense_defaults = [ + [0.0] for _ in range(self._model_config.num_dense_features) + ] num_sparse_features = len(self._model_config.vocab_sizes) categorical_defaults = [[0] for _ in range(num_sparse_features)] record_defaults = label_defaults + dense_defaults + categorical_defaults - fields = tf.io.decode_csv(example, record_defaults, field_delim='\t', na_value='-1') + fields = tf.io.decode_csv( + example, record_defaults, field_delim="\t", na_value="-1" + ) label = tf.reshape(fields[0], [1]) features = {} - dense_features = fields[1:self._model_config.num_dense_features + 1] - features['dense_features'] = tf.stack(dense_features, axis=0) - features['sparse_features'] = { + dense_features = fields[1 : self._model_config.num_dense_features + 1] + features["dense_features"] = tf.stack(dense_features, axis=0) + features["sparse_features"] = { str(i): fields[i + self._model_config.num_dense_features + 1] for i in range(num_sparse_features) } @@ -65,19 +162,32 @@ def _generate_synthetic_data(self) -> tf.data.Dataset: num_dense = self._model_config.num_dense_features dataset_size = 100 * self._params.global_batch_size - dense_tensor = tf.random.uniform(shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32) + dense_tensor = tf.random.uniform( + shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32 + ) sparse_tensors = [ tf.random.uniform(shape=(dataset_size,), maxval=int(size), dtype=tf.int32) for size in self._model_config.vocab_sizes ] - sparse_tensor_elements = {str(i): sparse_tensors[i] for i in range(len(sparse_tensors))} + sparse_tensor_elements = { + str(i): sparse_tensors[i] for i in range(len(sparse_tensors)) + } dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1) - sparse_tensors_mean = tf.math.reduce_sum(tf.stack(sparse_tensors, axis=-1), axis=1) - sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32) / sum(self._model_config.vocab_sizes) - label_tensor = 
tf.cast((dense_tensor_mean + sparse_tensors_mean) / 2.0 + 0.5, tf.int32)
-
-    input_elem = {'dense_features': dense_tensor, 'sparse_features': sparse_tensor_elements}, label_tensor
+    sparse_tensors_mean = tf.math.reduce_sum(
+        tf.stack(sparse_tensors, axis=-1), axis=1
+    )
+    sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32) / sum(
+        self._model_config.vocab_sizes
+    )
+    label_tensor = tf.cast(
+        (dense_tensor_mean + sparse_tensors_mean) / 2.0 + 0.5, tf.int32
+    )
+
+    input_elem = {
+        "dense_features": dense_tensor,
+        "sparse_features": sparse_tensor_elements,
+    }, label_tensor
     dataset = tf.data.Dataset.from_tensor_slices(input_elem)
     dataset = dataset.cache()
     if self._params.is_training:
@@ -86,10 +196,16 @@ def _generate_synthetic_data(self) -> tf.data.Dataset:
 
     return dataset
 
+
 def train_input_fn(config: ml_collections.ConfigDict) -> tf.data.Dataset:
   """Returns dataset of batched training examples."""
-  return CriteoTsvReader(config, is_training=True)()
+  if config.train_data.dataset_format in {DatasetFormat.SYNTHETIC, DatasetFormat.TSV}:
+    return CriteoTsvReader(config, is_training=True)()
+  return CriteoTFRecordReader(config, is_training=True)()
+
 
 def eval_input_fn(config: ml_collections.ConfigDict) -> tf.data.Dataset:
   """Returns dataset of batched eval examples."""
-  return CriteoTsvReader(config, is_training=False)()
+  if config.validation_data.dataset_format in {DatasetFormat.SYNTHETIC, DatasetFormat.TSV}:
+    return CriteoTsvReader(config, is_training=False)()
+  return CriteoTFRecordReader(config, is_training=False)()
diff --git a/tpu/flax/layers.py b/tpu/flax/layers.py
index cb3a1ba..aa443b8 100644
--- a/tpu/flax/layers.py
+++ b/tpu/flax/layers.py
@@ -25,8 +25,10 @@
 import numpy as np
 import optax
 
+
 class MLP(nn.Module):
   """Multi-layer perceptron."""
+
   layer_sizes: Sequence[int]
 
   @nn.compact
@@ -36,16 +38,20 @@ def __call__(self, x):
       x = nn.relu(x)
     return x
 
+
 class DenseArch(nn.Module):
   """Dense features architecture."""
+
   layer_sizes: Sequence[int]
 
   @nn.compact
   def __call__(self, dense_features):
     return MLP(self.layer_sizes)(dense_features)
 
+
 class EmbeddingArch(nn.Module):
   """Embedding architecture."""
+
   vocab_sizes: List[int]
   embedding_dim: int
 
@@ -53,11 +59,16 @@ class EmbeddingArch(nn.Module):
   def __call__(self, embedding_ids):
     embeddings = []
     for i, vocab_size in enumerate(self.vocab_sizes):
-      embedding_table = self.param(f'embedding_{i}', nn.initializers.uniform(), (vocab_size, self.embedding_dim))
+      embedding_table = self.param(
+          f"embedding_{i}",
+          nn.initializers.uniform(),
+          (vocab_size, self.embedding_dim),
+      )
       embedding = jnp.take(embedding_table, embedding_ids[:, i], axis=0)
       embeddings.append(embedding)
     return embeddings
 
+
 class InteractionArch(nn.Module):
   """Base interaction architecture."""
 
@@ -65,45 +76,54 @@ class InteractionArch(nn.Module):
   def __call__(self, dense_output, embedding_outputs):
     return jnp.concatenate([dense_output] + embedding_outputs, axis=1)
 
+
 class DotInteractionArch(InteractionArch):
   """Dot product interaction architecture."""
 
   @nn.compact
   def __call__(self, dense_output, embedding_outputs):
-    base_output = jnp.concatenate([dense_output] + embedding_outputs, axis=1)
-
-    # Combine dense and embedding outputs
-    combined_values = jnp.concatenate([dense_output.reshape(dense_output.shape[0], 1, -1)] + [e.reshape(e.shape[0], 1, -1) for e in embedding_outputs], axis=1)
-
-    # Compute pairwise interactions
+    combined_values = jnp.concatenate(
+        [dense_output.reshape(dense_output.shape[0], 1, -1)]
+        + [e.reshape(e.shape[0], 1, -1) for e in embedding_outputs],
+        axis=1,
+    )
+
     interactions = jnp.matmul(combined_values, combined_values.transpose((0, 2, 1)))
-
-    # Get upper triangular indices
+
     num_features = combined_values.shape[1]
     triu_indices = jnp.triu_indices(num_features, k=1)
-
-    # Extract upper triangular elements
+
     interactions_flat = interactions[:, triu_indices[0], triu_indices[1]]
-
+
     return jnp.concatenate([dense_output, interactions_flat], axis=1)
 
+
 class LowRankCrossNetInteractionArch(InteractionArch):
   """Low Rank Cross Network interaction architecture."""
+
   num_layers: int
   low_rank: int
 
   @nn.compact
   def __call__(self, dense_output, embedding_outputs):
     base_output = jnp.concatenate([dense_output] + embedding_outputs, axis=1)
-
+
     x_0 = base_output
     x_l = x_0
     in_features = x_0.shape[-1]
 
     for layer in range(self.num_layers):
-      W = self.param(f'W_{layer}', nn.initializers.glorot_uniform(), (in_features, self.low_rank))
-      V = self.param(f'V_{layer}', nn.initializers.glorot_uniform(), (self.low_rank, in_features))
-      b = self.param(f'b_{layer}', nn.initializers.zeros, (in_features,))
+      W = self.param(
+          f"W_{layer}",
+          nn.initializers.glorot_uniform(),
+          (in_features, self.low_rank),
+      )
+      V = self.param(
+          f"V_{layer}",
+          nn.initializers.glorot_uniform(),
+          (self.low_rank, in_features),
+      )
+      b = self.param(f"b_{layer}", nn.initializers.zeros, (in_features,))
 
       x_l_v = jnp.matmul(x_l, V.T)
       x_l_w = jnp.matmul(x_l_v, W.T)
@@ -111,11 +131,13 @@ def __call__(self, dense_output, embedding_outputs):
     return x_l
 
+
 class OverArch(nn.Module):
   """Over-architecture (top MLP)."""
+
   layer_sizes: Sequence[int]
 
   @nn.compact
   def __call__(self, x):
     x = MLP(self.layer_sizes)(x)
-    return nn.Dense(features=1)(x)
\ No newline at end of file
+    return nn.Dense(features=1)(x)
diff --git a/tpu/flax/losses.py b/tpu/flax/losses.py
index 94bdc55..27a437f 100644
--- a/tpu/flax/losses.py
+++ b/tpu/flax/losses.py
@@ -25,6 +25,9 @@
 import numpy as np
 import optax
 
+
 def bce_with_logits_loss(logits, labels):
   """Binary Cross Entropy with Logits Loss."""
-  return -jnp.mean(labels * jax.nn.log_sigmoid(logits) + (1 - labels) * jax.nn.log_sigmoid(-logits))
+  return -jnp.mean(
+      labels * jax.nn.log_sigmoid(logits) + (1 - labels) * jax.nn.log_sigmoid(-logits)
+  )
diff --git a/tpu/flax/metrics.py b/tpu/flax/metrics.py
index d3096d6..0161b8e 100644
--- a/tpu/flax/metrics.py
+++ b/tpu/flax/metrics.py
@@ -24,16 +24,38 @@
 import ml_collections
 import numpy as np
 import optax
+from jax.scipy.integrate import trapezoid
 
 
 def accuracy(logits, labels):
   """Calculates the accuracy of predictions."""
   predictions = jax.nn.sigmoid(logits) > 0.5
   return jnp.mean(predictions == labels)
 
+
+def auc(logits, labels):
+  """Calculates the Area Under the Receiver Operating Characteristic Curve (AUC)."""
+  # Sort the data by predicted probabilities in descending order
+  sorted_indices = jnp.argsort(logits)[::-1]
+  sorted_logits = logits[sorted_indices]
+  sorted_labels = labels[sorted_indices]
+
+  # Calculate the cumulative sum of positive and negative labels
+  cumsum_pos = jnp.cumsum(sorted_labels)
+  cumsum_neg = jnp.cumsum(1 - sorted_labels)
+
+  # Calculate the total number of positive and negative labels
+  total_pos = jnp.sum(sorted_labels)
+  total_neg = len(sorted_labels) - total_pos
+
+  # Calculate the AUC using the trapezoidal rule
+  auc_score = trapezoid(cumsum_pos / total_pos, cumsum_neg / total_neg)
+
+  return auc_score
+
+
 @jax.jit
 def compute_metrics(logits, labels):
   """Computes all metrics at once."""
   return {
-      'accuracy': accuracy(logits, labels),
+      "accuracy": accuracy(logits, labels),
+      "auc": auc(logits, labels),
   }
-
diff --git a/tpu/flax/models.py b/tpu/flax/models.py
index af5be5a..a2682ab 100644
--- a/tpu/flax/models.py
+++ b/tpu/flax/models.py
@@ -26,8 +26,10 @@
 import optax
 from layers import MLP, DenseArch, EmbeddingArch, InteractionArch, OverArch
 
+
 class DLRMV2(nn.Module):
   """DLRM V2 model."""
+
   vocab_sizes: List[int]
   embedding_dim: int
   bottom_mlp_dims: List[int]
@@ -35,22 +37,28 @@ class DLRMV2(nn.Module):
   @nn.compact
   def __call__(self, dense_features, embedding_ids):
-    # Bottom MLP
     x = self.bottom_mlp(dense_features)
 
-    # Embedding layer
     embeddings = []
     for i, vocab_size in enumerate(self.vocab_sizes):
       embedding = nn.Embed(vocab_size, self.embedding_dim)(embedding_ids[str(i)])
       embeddings.append(embedding)
-
-    # Flatten and concatenate embeddings
-    embedding_output = jnp.concatenate([e.reshape(-1, self.embedding_dim) for e in embeddings], axis=1)
+    if len(embeddings[0].shape) == 3:
+      # Multi-hot embeddings have 3 dimensions: (batch_size, multihot_size, embedding_dim).
+      # TODO(qinyiyan): Find a proper way to do pooling.
+      pooled_embeddings = [
+          jnp.mean(embedding, axis=1) for embedding in embeddings
+      ]
+      embedding_output = jnp.concatenate(
+          [e.reshape(-1, self.embedding_dim) for e in pooled_embeddings], axis=1
+      )
+    else:
+      embedding_output = jnp.concatenate(
+          [e.reshape(-1, self.embedding_dim) for e in embeddings], axis=1
+      )
 
-    # Concatenate bottom MLP output and embedding output
     concatenated = jnp.concatenate([x, embedding_output], axis=1)
 
-    # Top MLP
     y = self.top_mlp(concatenated)
 
     return y.squeeze(-1)
@@ -66,4 +74,4 @@ def top_mlp(self, x):
       x = nn.Dense(dim)(x)
       x = nn.relu(x)
     x = nn.Dense(self.top_mlp_dims[-1])(x)
-    return x
\ No newline at end of file
+    return x
diff --git a/tpu/flax/requirements.txt b/tpu/flax/requirements.txt
index 885c473..61a7e2b 100644
--- a/tpu/flax/requirements.txt
+++ b/tpu/flax/requirements.txt
@@ -1,22 +1,5 @@
-"""
-Copyright 2024 Google LLC
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-""" - absl-py flax -# jax[tpu] @ https://storage.googleapis.com/jax-releases/libtpu_releases.html jaxlib ml_collections numpy diff --git a/tpu/flax/train.py b/tpu/flax/train.py index 3102f82..c26684d 100644 --- a/tpu/flax/train.py +++ b/tpu/flax/train.py @@ -22,92 +22,129 @@ import jax.numpy as jnp import optax from models import DLRMV2 -from configs import get_config +from configs import get_config, get_criteo_config from losses import bce_with_logits_loss -from metrics import accuracy +from metrics import accuracy, compute_metrics import ml_collections from data_pipeline import train_input_fn, eval_input_fn import tensorflow as tf -import metrics +from jax.experimental import mesh_utils +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P +import numpy as np +import time + + +def create_train_state(rng, config, mesh): + """Creates initial `TrainState` with sharding.""" + dlrm = DLRMV2( + vocab_sizes=config.model.vocab_sizes, + embedding_dim=config.model.embedding_dim, + bottom_mlp_dims=config.model.bottom_mlp_dims, + top_mlp_dims=config.model.top_mlp_dims, + ) + + dummy_dense = jnp.ones([1, config.model.num_dense_features]) + dummy_sparse = { + str(i): jnp.ones([1], dtype=jnp.int32) + for i in range(len(config.model.vocab_sizes)) + } + + params = dlrm.init(rng, dummy_dense, dummy_sparse)["params"] + tx = optax.adam(config.model.learning_rate) + + return train_state.TrainState.create( + apply_fn=dlrm.apply, + params=jax.tree.map( + lambda x: jax.device_put(x, NamedSharding(mesh, P())), params + ), + tx=tx, + ) + @jax.jit -def apply_model(state, dense_features, sparse_features, labels): - """Computes gradients, loss and accuracy for a single batch.""" +def train_step(state, batch): def loss_fn(params): - logits = state.apply_fn({'params': params}, dense_features, sparse_features) - loss = bce_with_logits_loss(logits, labels) + logits = state.apply_fn( + {"params": params}, batch["dense_features"], batch["sparse_features"] + ) + loss = bce_with_logits_loss(logits, batch["labels"]) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) (loss, logits), grads = grad_fn(state.params) - return grads, loss, logits + state = state.apply_gradients(grads=grads) + metrics = compute_metrics(logits, batch["labels"]) + metrics["loss"] = loss + return state, metrics -@jax.jit -def update_model(state, grads): - return state.apply_gradients(grads=grads) -def create_train_state(rng, config): - """Creates initial `TrainState`.""" - dlrm = DLRMV2( - vocab_sizes=config.model.vocab_sizes, - embedding_dim=config.model.embedding_dim, - bottom_mlp_dims=config.model.bottom_mlp_dims, - top_mlp_dims=config.model.top_mlp_dims +@jax.jit +def eval_step(state, batch): + logits = state.apply_fn( + {"params": state.params}, batch["dense_features"], batch["sparse_features"] ) - - # Create dummy inputs for initialization - dummy_dense = jnp.ones([1, config.model.num_dense_features]) - dummy_sparse = {str(i): jnp.ones([1], dtype=jnp.int32) for i in range(len(config.model.vocab_sizes))} - - params = dlrm.init(rng, dummy_dense, dummy_sparse)['params'] - tx = optax.adam(config.model.learning_rate) - return train_state.TrainState.create(apply_fn=dlrm.apply, params=params, tx=tx) + loss = bce_with_logits_loss(logits, batch["labels"]) + metrics = compute_metrics(logits, batch["labels"]) + metrics["loss"] = loss + return metrics + -def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState: - """Execute model training and evaluation loop.""" - train_ds = 
train_input_fn(config) - # test_ds = eval_input_fn(config) +def train_and_evaluate( + config: ml_collections.ConfigDict, workdir: str +) -> train_state.TrainState: + devices = jax.devices() + num_devices = len(devices) + mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), axis_names=("batch",)) rng = jax.random.PRNGKey(0) rng, init_rng = jax.random.split(rng) - state = create_train_state(init_rng, config) - print('start training') - for epoch in range(1, config.num_epochs + 1): - rng, input_rng = jax.random.split(rng) - - # Train loop - epoch_metrics = [] - for features, labels in train_ds.take(config.steps_per_epoch): - dense_features = jnp.array(features['dense_features']) - sparse_features = {k: jnp.array(v) for k, v in features['sparse_features'].items()} - labels = jnp.array(labels) - grads, loss, logits = apply_model(state, dense_features, sparse_features, labels) - state = update_model(state, grads) - batch_metrics = metrics.compute_metrics(logits, labels) - batch_metrics['loss'] = loss - epoch_metrics.append(batch_metrics) - - # Compute average metrics for the epoch - train_metrics = jax.tree_map(lambda *args: jnp.mean(jnp.array(args)), *epoch_metrics) - - print('epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f' % (epoch, train_metrics['loss'], train_metrics['accuracy'] * 100)) - - # # Evaluation loop - # test_loss = [] - # test_accuracy = [] - # for features, labels in test_ds: - # dense_features = jnp.array(features['dense_features']) - # sparse_features = {k: jnp.array(v) for k, v in features['sparse_features'].items()} - # labels = jnp.array(labels) - # _, loss, accuracy = apply_model(state, dense_features, sparse_features, labels) - # test_loss.append(loss) - # test_accuracy.append(accuracy) - - # test_loss = jnp.mean(jnp.array(test_loss)) - # test_accuracy = jnp.mean(jnp.array(test_accuracy)) - + + with mesh: + state = create_train_state(init_rng, config, mesh) + + train_ds = train_input_fn(config) + eval_ds = eval_input_fn(config) + + batch_sharding = NamedSharding( + mesh, + P( + "batch", + ), + ) + + print("Start training") + for epoch in range(1, config.num_epochs + 1): + train_metrics = [] + for features, labels in train_ds.take(config.steps_per_epoch): + batch = { + "dense_features": jax.device_put( + np.array(features["dense_features"]), batch_sharding + ), + "sparse_features": jax.tree.map( + lambda x: jax.device_put(np.array(x), batch_sharding), + features["sparse_features"], + ), + "labels": jax.device_put(np.array(labels), batch_sharding), + } + + # jax.debug.visualize_array_sharding(batch['labels']) + + state, metrics = train_step(state, batch) + train_metrics.append(metrics) + + train_metrics = jax.tree.map( + lambda *args: jnp.mean(jnp.array(args)), *train_metrics + ) + + print(f"Epoch {epoch}:") + print( + f' Train loss: {train_metrics["loss"]:.4f}, accuracy: {train_metrics["accuracy"]:.4f}, auc: {train_metrics["auc"]: .4f}' + ) + return state + if __name__ == "__main__": - config = get_config() - train_and_evaluate(config, '/tmp/dlrm_v2') + # config = get_config() + config = get_criteo_config() + train_and_evaluate(config, "/tmp/dlrm_v2") diff --git a/tpu/keras/configs.py b/tpu/keras/configs.py index 45b2160..766c1c0 100644 --- a/tpu/keras/configs.py +++ b/tpu/keras/configs.py @@ -22,12 +22,12 @@ def get_config(): # Model configuration config.model = ml_collections.ConfigDict() - config.model.vocab_sizes = [1000, 1000, 1000] # Example vocab sizes - config.model.num_dense_features = 13 # Example number of dense features - config.model.embedding_dim = 
32 # Add this line - config.model.bottom_mlp_dims = [64, 32, 16] # Add this line - config.model.top_mlp_dims = [64, 32, 1] # Add this line - config.model.learning_rate = 0.001 # Add this line + config.model.vocab_sizes = [1000, 1000, 1000] + config.model.num_dense_features = 13 + config.model.embedding_dim = 32 + config.model.bottom_mlp_dims = [64, 32, 16] + config.model.top_mlp_dims = [64, 32, 1] + config.model.learning_rate = 0.001 # Data configuration config.train_data = ml_collections.ConfigDict() @@ -47,7 +47,7 @@ def get_config(): config.validation_data.use_synthetic_data = True # Global configuration - config.num_epochs = 10 # Make sure this is defined - config.steps_per_epoch = 100 # Adjust this value based on your dataset size and batch size + config.num_epochs = 10 + config.steps_per_epoch = 100 return config \ No newline at end of file diff --git a/tpu/keras/models.py b/tpu/keras/models.py index 5dad0d2..0856f6b 100644 --- a/tpu/keras/models.py +++ b/tpu/keras/models.py @@ -23,7 +23,6 @@ def DLRM(config): - # Extract model parameters from the configuration num_dense_features = config.model.num_dense_features vocab_sizes = config.model.vocab_sizes embedding_dim = config.model.embedding_dim @@ -32,32 +31,34 @@ def DLRM(config): num_sparse_features = len(vocab_sizes) - # Inputs - dense_input = keras.Input(shape=(num_dense_features,), name='dense_features') - sparse_inputs = [keras.Input(shape=(), dtype='int32', name=f'sparse_feature_{i}') for i in range(num_sparse_features)] + dense_input = keras.Input(shape=(num_dense_features,), name='dense_input') + + sparse_inputs = [ + keras.Input(shape=(), dtype='int32', name=f'sparse_input_{i}') + for i in range(num_sparse_features) + ] - # Bottom MLP for dense features x = dense_input for units in bottom_mlp_units: x = keras.layers.Dense(units, activation='relu')(x) - dense_embedding = x # Shape: (batch_size, embedding_dim) - # Embedding layers for sparse features - sparse_embeddings = [] - for i in range(num_sparse_features): - embedding_layer = keras.layers.Embedding(input_dim=vocab_sizes[i], output_dim=embedding_dim, name=f'embedding_{i}') - sparse_embedding = embedding_layer(sparse_inputs[i]) # Shape: (batch_size, embedding_dim) - sparse_embeddings.append(sparse_embedding) + embeddings = [] + for i, vocab_size in enumerate(vocab_sizes): + embedding = keras.layers.Embedding( + input_dim=vocab_size, + output_dim=embedding_dim, + name=f'embedding_{i}' + )(sparse_inputs[i]) + embeddings.append(embedding) + + concatenated_embeddings = keras.layers.Concatenate()(embeddings) - # Interactions between dense and sparse embeddings - all_embeddings = [dense_embedding] + sparse_embeddings # List of tensors with shape (batch_size, embedding_dim) - # Concatenate interactions - x = keras.layers.Concatenate()(all_embeddings) # Shape: (batch_size, num_interactions) + interaction = keras.layers.Concatenate()([x, concatenated_embeddings]) - # Top MLP + y = interaction for units in top_mlp_units[:-1]: - x = keras.layers.Dense(units, activation='relu')(x) - output = keras.layers.Dense(top_mlp_units[-1], activation='sigmoid')(x) + y = keras.layers.Dense(units, activation='relu')(y) + outputs = keras.layers.Dense(top_mlp_units[-1], activation='sigmoid')(y) - model = keras.Model(inputs=[dense_input] + sparse_inputs, outputs=output) + model = keras.Model(inputs=[dense_input] + sparse_inputs, outputs=outputs) return model \ No newline at end of file diff --git a/tpu/keras/train.py b/tpu/keras/train.py index 220e6c2..bf50ad7 100644 --- a/tpu/keras/train.py +++ 
b/tpu/keras/train.py @@ -28,22 +28,15 @@ from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P -# Import the configuration and data pipeline modules from configs import get_config import data_pipeline from models import DLRM config = get_config() -# Get training and validation datasets train_data = data_pipeline.train_input_fn(config) eval_data = data_pipeline.eval_input_fn(config) -""" -## Multi-Device Synchronous Training - -Now, we will set up the training loop to perform synchronous training across multiple devices using JAX sharding APIs. -""" # Configurations num_epochs = config.num_epochs @@ -54,13 +47,10 @@ optimizer = keras.optimizers.Adam(learning_rate) loss_fn = keras.losses.BinaryCrossentropy(from_logits=False) -# Initialize all state with .build() -# Need to generate one batch of data to build the model (one_batch_inputs, one_batch_labels) = next(iter(train_data)) model.build(one_batch_inputs) optimizer.build(model.trainable_variables) -# This is the loss function that will be differentiated. def compute_loss(trainable_variables, non_trainable_variables, inputs, y_true): y_pred, updated_non_trainable_variables = model.stateless_call( trainable_variables, non_trainable_variables, inputs @@ -68,10 +58,8 @@ def compute_loss(trainable_variables, non_trainable_variables, inputs, y_true): loss_value = loss_fn(y_true, y_pred) return loss_value, updated_non_trainable_variables -# Function to compute gradients compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) -# Training step @jax.jit def train_step(train_state, inputs, y_true): trainable_variables, non_trainable_variables, optimizer_variables = train_state @@ -89,21 +77,16 @@ def train_step(train_state, inputs, y_true): optimizer_variables, ) -# Replicate the model and optimizer variable on all devices def get_replicated_train_state(devices): - # All variables will be replicated on all devices var_mesh = Mesh(devices, axis_names=("_")) - # In NamedSharding, axes not mentioned are replicated (all axes here) var_replication = NamedSharding(var_mesh, P()) - # Apply the distribution settings to the model variables trainable_variables = jax.device_put(model.trainable_variables, var_replication) non_trainable_variables = jax.device_put( model.non_trainable_variables, var_replication ) optimizer_variables = jax.device_put(optimizer.variables, var_replication) - # Combine all state in a tuple return (trainable_variables, non_trainable_variables, optimizer_variables) @@ -111,14 +94,13 @@ def get_replicated_train_state(devices): print(f"Running on {num_devices} devices: {jax.local_devices()}") devices = mesh_utils.create_device_mesh((num_devices,)) -# Data will be split along the batch axis -data_mesh = Mesh(devices, axis_names=("batch",)) # naming axes of the mesh +data_mesh = Mesh(devices, axis_names=("batch",)) data_sharding = NamedSharding( data_mesh, P( "batch", ), -) # naming axes of the sharded partition +) train_state = get_replicated_train_state(devices) @@ -128,18 +110,17 @@ def get_replicated_train_state(devices): for step in range(config.steps_per_epoch): batch = next(data_iter) inputs, y_true = batch - # Convert inputs to the expected format - # inputs is a dict with 'dense_features' and 'sparse_features' + dense_features = inputs['dense_features'].numpy() sparse_features = inputs['sparse_features'] sparse_features = [sparse_features[str(i)].numpy() for i in range(len(config.model.vocab_sizes))] - # Prepare the input list + input_list = [dense_features] + sparse_features y_true = 
y_true.numpy() - # Shard inputs sharded_inputs = [jax.device_put(x, data_sharding) for x in input_list] - sharded_y_true = jax.device_put(y_true, data_sharding) - loss_value, train_state = train_step(train_state, sharded_inputs, sharded_y_true) + # sharded_y_true = jax.device_put(y_true, data_sharding) + # loss_value, train_state = train_step(train_state, sharded_inputs, sharded_y_true) + loss_value, train_state = train_step(train_state, sharded_inputs, y_true) print(f"Epoch {epoch+1}, loss: {loss_value}") trainable_variables, non_trainable_variables, optimizer_variables = train_state
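
For reference, a minimal launch sketch for the Flax entry point touched by this patch, assuming it is run from `tpu/flax/` so the flat module imports resolve; `get_config()` selects the synthetic-data setup and `get_criteo_config()` the Criteo TFRecord setup:

```python
# Sketch only: launch the Flax DLRM trainer with either config from this patch.
# Assumes the working directory is tpu/flax/ so `configs` and `train` import directly.
from configs import get_config, get_criteo_config
from train import train_and_evaluate

config = get_config()            # synthetic data (DatasetFormat.SYNTHETIC)
# config = get_criteo_config()   # real Criteo TFRecords on GCS

state = train_and_evaluate(config, "/tmp/dlrm_v2")
```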