Skip to content

Commit 8006f96

Browse files
zoyahav authored and tfx-copybara committed
Update census example to use RaggedFeature (and resolve related TODO).
PiperOrigin-RevId: 489961032
1 parent ab665ad commit 8006f96

File tree

3 files changed

+23
-29
lines changed

3 files changed

+23
-29
lines changed

examples/census_example_common.py

+14 -16
Original file line number | Diff line number | Diff line change
@@ -53,14 +53,16 @@
5353
]
5454

5555

56-
RAW_DATA_FEATURE_SPEC = dict([(name, tf.io.FixedLenFeature([], tf.string))
57-
for name in CATEGORICAL_FEATURE_KEYS] +
58-
[(name, tf.io.FixedLenFeature([], tf.float32))
59-
for name in NUMERIC_FEATURE_KEYS] +
60-
[(name, tf.io.VarLenFeature(tf.float32))
61-
for name in OPTIONAL_NUMERIC_FEATURE_KEYS] +
62-
[(LABEL_KEY,
63-
tf.io.FixedLenFeature([], tf.string))])
56+
RAW_DATA_FEATURE_SPEC = dict(
57+
[(name, tf.io.FixedLenFeature([], tf.string))
58+
for name in CATEGORICAL_FEATURE_KEYS] +
59+
[(name, tf.io.FixedLenFeature([], tf.float32))
60+
for name in NUMERIC_FEATURE_KEYS] +
61+
[(name, # pylint: disable=g-complex-comprehension
62+
tf.io.RaggedFeature(
63+
tf.float32, value_key=name, partitions=[], row_splits_dtype=tf.int64))
64+
for name in OPTIONAL_NUMERIC_FEATURE_KEYS] +
65+
[(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))])
6466

6567
_SCHEMA = tft.DatasetMetadata.from_feature_spec(RAW_DATA_FEATURE_SPEC).schema
6668

@@ -121,14 +123,10 @@ def preprocessing_fn(inputs):
121123
outputs[key] = tft.scale_to_0_1(inputs[key])
122124

123125
for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
124-
# This is a SparseTensor because it is optional. Here we fill in a default
125-
# value when it is missing.
126-
sparse = tf.sparse.SparseTensor(inputs[key].indices, inputs[key].values,
127-
[inputs[key].dense_shape[0], 1])
128-
dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.)
129-
# Reshaping from a batch of vectors of size 1 to a batch to scalars.
130-
dense = tf.squeeze(dense, axis=1)
131-
outputs[key] = tft.scale_to_0_1(dense)
126+
# This is a RaggedTensor because it is optional. Here we fill in a default
127+
# value when it is missing, after scaling it.
128+
outputs[key] = tft.scale_to_0_1(inputs[key]).to_tensor(
129+
default_value=0., shape=[None, 1])
132130

133131
# For all categorical columns except the label column, we generate a
134132
# vocabulary, and convert the string feature to a one-hot encoding.

examples/census_example_test.py

+5 -5
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,13 @@
1515

1616
import os
1717

18-
import tensorflow as tf
1918
import census_example
2019
import census_example_common
20+
from tensorflow_transform import test_case
2121
import local_model_server
2222

2323

24-
class CensusExampleTest(tf.test.TestCase):
24+
class CensusExampleTest(test_case.TransformTestCase):
2525

2626
def testCensusExampleAccuracy(self):
2727
raw_data_dir = os.path.join(os.path.dirname(__file__), 'testdata/census')
@@ -106,13 +106,13 @@ def testCensusExampleAccuracy(self):
106106
}"""
107107
results = local_model_server.make_classification_request(
108108
address, ascii_classification_request)
109-
self.assertEqual(len(results), 1)
110-
self.assertEqual(len(results[0].classes), 2)
109+
self.assertLen(results, 1)
110+
self.assertLen(results[0].classes, 2)
111111
self.assertEqual(results[0].classes[0].label, '0')
112112
self.assertLess(results[0].classes[0].score, 0.01)
113113
self.assertEqual(results[0].classes[1].label, '1')
114114
self.assertGreater(results[0].classes[1].score, 0.99)
115115

116116

117117
if __name__ == '__main__':
118-
tf.test.main()
118+
test_case.main()

examples/census_example_v2.py

+4 -8
Original file line number | Diff line number | Diff line change
@@ -87,11 +87,10 @@ def transform_dataset(data):
8787
for key, val in data.items():
8888
if key not in common.RAW_DATA_FEATURE_SPEC:
8989
continue
90-
if isinstance(common.RAW_DATA_FEATURE_SPEC[key], tf.io.VarLenFeature):
91-
# TODO(b/169666856): Remove conversion to sparse once ragged tensors are
92-
# natively supported.
90+
if isinstance(common.RAW_DATA_FEATURE_SPEC[key], tf.io.RaggedFeature):
91+
# make_csv_dataset will set the value to 0 when it's missing.
9392
raw_features[key] = tf.RaggedTensor.from_tensor(
94-
tf.expand_dims(val, -1)).to_sparse()
93+
tf.expand_dims(val, axis=-1), padding=0)
9594
continue
9695
raw_features[key] = val
9796
transformed_features = tft_layer(raw_features)
@@ -189,10 +188,7 @@ def train_and_evaluate(raw_train_eval_data_path_pattern,
189188

190189
inputs = {}
191190
for key, spec in feature_spec.items():
192-
if isinstance(spec, tf.io.VarLenFeature):
193-
inputs[key] = tf.keras.layers.Input(
194-
shape=[None], name=key, dtype=spec.dtype, sparse=True)
195-
elif isinstance(spec, tf.io.FixedLenFeature):
191+
if isinstance(spec, tf.io.FixedLenFeature):
196192
# TODO(b/208879020): Move into schema such that spec.shape is [1] and not
197193
# [] for scalars.
198194
inputs[key] = tf.keras.layers.Input(

0 commit comments

Comments
 (0)