update data pipeline change from qinyiyan@ and other changes from zhaoyuec@ #1

Open · wants to merge 1 commit into main

3 changes: 2 additions & 1 deletion README.md
@@ -1,12 +1,13 @@
# RankML
RankML Library for TPU in Keras and Flax
RankML Library in Jax and Keras

## Setup

To set up the environment for this project, follow these steps:

## Install dependencies:
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
```
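
After installing, a quick sanity check (not part of this change, just a common way to confirm the TPU backend is visible) is to list the devices JAX can see:

```python
# Hedged sanity check: on a correctly provisioned TPU VM this should print TPU
# devices rather than only the CPU.
import jax

print(jax.devices())
```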

1 change: 1 addition & 0 deletions tpu/flax/README.md
@@ -7,6 +7,7 @@ To set up the environment for this project, follow these steps:

## Install dependencies:
```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -r requirements.txt
```

52 changes: 50 additions & 2 deletions tpu/flax/configs.py
@@ -14,8 +14,16 @@
limitations under the License.
"""

import enum
import ml_collections

class DatasetFormat(enum.Enum):
"""Defines the dataset format."""
TSV = "tsv"
TFRECORD = "tfrecord"
SYNTHETIC = "synthetic"


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()
@@ -37,14 +45,54 @@ def get_config():
config.train_data.sharding = True
config.train_data.num_shards_per_host = 8
config.train_data.cycle_length = 8
config.train_data.use_synthetic_data = True
config.train_data.dataset_format = DatasetFormat.SYNTHETIC

config.validation_data = ml_collections.ConfigDict()
config.validation_data.input_path = 'path/to/validation/data/*.tsv'
config.validation_data.global_batch_size = 1024
config.validation_data.is_training = False
config.validation_data.sharding = False
config.validation_data.use_synthetic_data = True
config.validation_data.dataset_format = DatasetFormat.SYNTHETIC

# Global configuration
config.num_epochs = 10  # Number of training epochs
config.steps_per_epoch = 100 # Adjust this value based on your dataset size and batch size

return config

def get_criteo_config():
"""Get the configuration for the Criteo dataset."""
config = ml_collections.ConfigDict()

# Model configuration
config.model = ml_collections.ConfigDict()
config.model.vocab_sizes = [40000000, 39060, 17295, 7424, 20265, 3, 7122, 1543, 63, 40000000, 3067956, 405282, 10, 2209, 11938, 155, 4, 976, 14, 40000000, 40000000, 40000000, 590152, 12973, 108, 36]  # Vocabulary size for each of the 26 sparse features
config.model.num_dense_features = 13  # Criteo has 13 dense (continuous) features
config.model.embedding_dim = 16 # TODO(qinyiyan): Use larger embedding vector when ready.
config.model.bottom_mlp_dims = [512, 256, 128]  # Bottom MLP layer sizes
config.model.top_mlp_dims = [1024, 1024, 512, 256, 1]  # Top MLP layer sizes; the last layer produces the logit
config.model.learning_rate = 0.025

# Data configuration
config.train_data = ml_collections.ConfigDict()
config.train_data.input_path = 'gs://rankml-datasets/criteo/train/day_0/*00100*'
config.train_data.global_batch_size = 1024
config.train_data.is_training = True
config.train_data.sharding = True
config.train_data.num_shards_per_host = 8
config.train_data.cycle_length = 8
config.train_data.multi_hot_sizes = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1]
config.train_data.dataset_format = DatasetFormat.TFRECORD

config.validation_data = ml_collections.ConfigDict()
config.validation_data.input_path = 'gs://rankml-datasets/criteo/eval/day_23/*00000*'
config.validation_data.global_batch_size = 1024
config.validation_data.is_training = False
config.validation_data.sharding = False
config.validation_data.cycle_length = 8
config.validation_data.multi_hot_sizes = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5, 1, 1, 1, 12, 100, 27, 10, 3, 1, 1]
config.validation_data.dataset_format = DatasetFormat.TFRECORD


# Global configuration
config.num_epochs = 10  # Number of training epochs
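
As a hedged usage sketch (the function and enum names come from this change; the TSV path below is only a placeholder), a training script could pick one of the two configs and override the dataset format like this:

```python
# Hypothetical config selection; the TSV path is a placeholder, not a real dataset.
from configs import DatasetFormat, get_config, get_criteo_config

criteo_config = get_criteo_config()   # TFRecord-backed Criteo run

debug_config = get_config()           # default config uses synthetic data
debug_config.train_data.dataset_format = DatasetFormat.TSV
debug_config.train_data.input_path = "path/to/train/data/*.tsv"  # placeholder
```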
148 changes: 132 additions & 16 deletions tpu/flax/data_pipeline.py
@@ -16,9 +16,102 @@

from typing import List
import tensorflow as tf
from configs import get_config
from configs import DatasetFormat
import ml_collections


class CriteoTFRecordReader(object):
"""Input reader fn for TFRecords that have been serialized in batched form."""

def __init__(
self,
config: ml_collections.ConfigDict,
is_training: bool,
use_cached_data: bool = False,
):
self._params = config.train_data if is_training else config.validation_data
self._num_dense_features = config.model.num_dense_features
self._vocab_sizes = config.model.vocab_sizes
self._use_cached_data = use_cached_data

self.label_features = "label"
self.dense_features = ["dense-feature-%d" % x for x in range(1, 14)]
self.sparse_features = ["sparse-feature-%d" % x for x in range(14, 40)]

def __call__(self) -> tf.data.Dataset:
params = self._params
# Batch size taken from the config (this is the global batch size).
batch_size = params.global_batch_size

def _get_feature_spec():
feature_spec = {}
feature_spec[self.label_features] = tf.io.FixedLenFeature(
[], dtype=tf.int64
)
for dense_feat in self.dense_features:
feature_spec[dense_feat] = tf.io.FixedLenFeature(
[],
dtype=tf.float32,
)
for i, sparse_feat in enumerate(self.sparse_features):
feature_spec[sparse_feat] = tf.io.FixedLenFeature(
[params.multi_hot_sizes[i]], dtype=tf.int64
)
return feature_spec

def _parse_fn(serialized_example):
feature_spec = _get_feature_spec()
parsed_features = tf.io.parse_single_example(
serialized_example, feature_spec
)
label = parsed_features[self.label_features]
features = {}
dense_values = []
for dense_ft in self.dense_features:
  dense_values.append(parsed_features[dense_ft])
features["dense_features"] = tf.stack(dense_values)

features["sparse_features"] = {}
for i, sparse_ft in enumerate(self.sparse_features):
features['sparse_features'][str(i)] = parsed_features[sparse_ft]
return features, label

# TODO(qinyiyan): Enable sharding.
filenames = tf.data.Dataset.list_files(self._params.input_path, shuffle=False)

num_shards_per_host = 1
if params.sharding:
num_shards_per_host = params.num_shards_per_host

def make_dataset(shard_index):
filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
dataset = tf.data.TFRecordDataset(filenames_for_shard)
if params.is_training:
dataset = dataset.repeat()
dataset = dataset.map(
_parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
return dataset

indices = tf.data.Dataset.range(num_shards_per_host)
dataset = indices.interleave(
map_func=make_dataset,
cycle_length=params.cycle_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)

dataset = dataset.batch(
batch_size,
drop_remainder=True,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
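# With use_cached_data, the block below caches a single batch and repeats it,
# so the input pipeline adds almost no per-step cost (e.g. for benchmarking).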
if self._use_cached_data:
dataset = dataset.take(1).cache().repeat()

return dataset
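
For reference, a hedged sketch (not part of this change) of how one serialized record matching `_get_feature_spec` above could be produced; the multi-hot sizes are copied from `get_criteo_config()` and the feature values are dummies:

```python
import tensorflow as tf

# Multi-hot sizes as configured in get_criteo_config(); all values are dummies.
multi_hot_sizes = [3, 2, 1, 2, 6, 1, 1, 1, 1, 7, 3, 8, 1, 6, 9, 5,
                   1, 1, 1, 12, 100, 27, 10, 3, 1, 1]

feature = {"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}
for i in range(1, 14):  # dense-feature-1 .. dense-feature-13, one float each
    feature[f"dense-feature-{i}"] = tf.train.Feature(
        float_list=tf.train.FloatList(value=[0.5]))
for i, size in zip(range(14, 40), multi_hot_sizes):  # sparse-feature-14 .. 39
    feature[f"sparse-feature-{i}"] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[0] * size))

serialized = tf.train.Example(
    features=tf.train.Features(feature=feature)).SerializeToString()
```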


class CriteoTsvReader:
"""Input reader callable for pre-processed Criteo data."""

@@ -27,24 +120,28 @@ def __init__(self, config: ml_collections.ConfigDict, is_training: bool):
self._model_config = config.model

def __call__(self) -> tf.data.Dataset:
if self._params.use_synthetic_data:
if self._params.dataset_format == DatasetFormat.SYNTHETIC:
return self._generate_synthetic_data()

@tf.function
def _parse_fn(example: tf.Tensor):
"""Parser function for pre-processed Criteo TSV records."""
label_defaults = [[0.0]]
dense_defaults = [[0.0] for _ in range(self._model_config.num_dense_features)]
dense_defaults = [
[0.0] for _ in range(self._model_config.num_dense_features)
]
num_sparse_features = len(self._model_config.vocab_sizes)
categorical_defaults = [[0] for _ in range(num_sparse_features)]
record_defaults = label_defaults + dense_defaults + categorical_defaults
fields = tf.io.decode_csv(example, record_defaults, field_delim='\t', na_value='-1')
fields = tf.io.decode_csv(
example, record_defaults, field_delim="\t", na_value="-1"
)

label = tf.reshape(fields[0], [1])
features = {}
dense_features = fields[1:self._model_config.num_dense_features + 1]
features['dense_features'] = tf.stack(dense_features, axis=0)
features['sparse_features'] = {
dense_features = fields[1 : self._model_config.num_dense_features + 1]
features["dense_features"] = tf.stack(dense_features, axis=0)
features["sparse_features"] = {
str(i): fields[i + self._model_config.num_dense_features + 1]
for i in range(num_sparse_features)
}
@@ -65,19 +162,32 @@ def _generate_synthetic_data(self) -> tf.data.Dataset:
num_dense = self._model_config.num_dense_features
dataset_size = 100 * self._params.global_batch_size

dense_tensor = tf.random.uniform(shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32)
dense_tensor = tf.random.uniform(
shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32
)
sparse_tensors = [
tf.random.uniform(shape=(dataset_size,), maxval=int(size), dtype=tf.int32)
for size in self._model_config.vocab_sizes
]
sparse_tensor_elements = {str(i): sparse_tensors[i] for i in range(len(sparse_tensors))}
sparse_tensor_elements = {
str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
}

dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1)
sparse_tensors_mean = tf.math.reduce_sum(tf.stack(sparse_tensors, axis=-1), axis=1)
sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32) / sum(self._model_config.vocab_sizes)
label_tensor = tf.cast((dense_tensor_mean + sparse_tensors_mean) / 2.0 + 0.5, tf.int32)

input_elem = {'dense_features': dense_tensor, 'sparse_features': sparse_tensor_elements}, label_tensor
sparse_tensors_mean = tf.math.reduce_sum(
tf.stack(sparse_tensors, axis=-1), axis=1
)
sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32) / sum(
self._model_config.vocab_sizes
)
label_tensor = tf.cast(
(dense_tensor_mean + sparse_tensors_mean) / 2.0 + 0.5, tf.int32
)

input_elem = {
"dense_features": dense_tensor,
"sparse_features": sparse_tensor_elements,
}, label_tensor
dataset = tf.data.Dataset.from_tensor_slices(input_elem)
dataset = dataset.cache()
if self._params.is_training:
@@ -86,10 +196,16 @@ def _generate_synthetic_data(self) -> tf.data.Dataset:

return dataset


def train_input_fn(config: ml_collections.ConfigDict) -> tf.data.Dataset:
"""Returns dataset of batched training examples."""
return CriteoTsvReader(config, is_training=True)()
if config.train_data.dataset_format in {DatasetFormat.SYNTHETIC, DatasetFormat.TSV}:
return CriteoTsvReader(config, is_training=True)()
return CriteoTFRecordReader(config, is_training=True)()


def eval_input_fn(config: ml_collections.ConfigDict) -> tf.data.Dataset:
"""Returns dataset of batched eval examples."""
return CriteoTsvReader(config, is_training=False)()
if config.validation_data.dataset_format in {DatasetFormat.SYNTHETIC, DatasetFormat.TSV}:
return CriteoTsvReader(config, is_training=False)()
return CriteoTFRecordReader(config, is_training=False)()
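
Putting the pieces together, a hedged end-to-end sketch (assuming the gs:// paths in `get_criteo_config()` are readable from the machine running it; nothing below is part of this change):

```python
# Build the Criteo config and pull one batch through the new TFRecord reader.
from configs import get_criteo_config
from data_pipeline import train_input_fn

config = get_criteo_config()
train_ds = train_input_fn(config)  # TFRECORD format dispatches to CriteoTFRecordReader

features, label = next(iter(train_ds))
print(features["dense_features"].shape)        # (1024, 13): batch x dense features
print(features["sparse_features"]["0"].shape)  # (1024, 3): first multi-hot feature
print(label.shape)                             # (1024,)
```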