|
53 | 53 | ]
|
54 | 54 |
|
55 | 55 |
|
56 |
| -RAW_DATA_FEATURE_SPEC = dict([(name, tf.io.FixedLenFeature([], tf.string)) |
57 |
| - for name in CATEGORICAL_FEATURE_KEYS] + |
58 |
| - [(name, tf.io.FixedLenFeature([], tf.float32)) |
59 |
| - for name in NUMERIC_FEATURE_KEYS] + |
60 |
| - [(name, tf.io.VarLenFeature(tf.float32)) |
61 |
| - for name in OPTIONAL_NUMERIC_FEATURE_KEYS] + |
62 |
| - [(LABEL_KEY, |
63 |
| - tf.io.FixedLenFeature([], tf.string))]) |
| 56 | +RAW_DATA_FEATURE_SPEC = dict( |
| 57 | + [(name, tf.io.FixedLenFeature([], tf.string)) |
| 58 | + for name in CATEGORICAL_FEATURE_KEYS] + |
| 59 | + [(name, tf.io.FixedLenFeature([], tf.float32)) |
| 60 | + for name in NUMERIC_FEATURE_KEYS] + |
| 61 | + [(name, # pylint: disable=g-complex-comprehension |
| 62 | + tf.io.RaggedFeature( |
| 63 | + tf.float32, value_key=name, partitions=[], row_splits_dtype=tf.int64)) |
| 64 | + for name in OPTIONAL_NUMERIC_FEATURE_KEYS] + |
| 65 | + [(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))]) |
64 | 66 |
|
65 | 67 | _SCHEMA = tft.DatasetMetadata.from_feature_spec(RAW_DATA_FEATURE_SPEC).schema
|
66 | 68 |
|
@@ -121,14 +123,10 @@ def preprocessing_fn(inputs):
|
121 | 123 | outputs[key] = tft.scale_to_0_1(inputs[key])
|
122 | 124 |
|
123 | 125 | for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
|
124 |
| - # This is a SparseTensor because it is optional. Here we fill in a default |
125 |
| - # value when it is missing. |
126 |
| - sparse = tf.sparse.SparseTensor(inputs[key].indices, inputs[key].values, |
127 |
| - [inputs[key].dense_shape[0], 1]) |
128 |
| - dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.) |
129 |
| - # Reshaping from a batch of vectors of size 1 to a batch of scalars. |
130 |
| - dense = tf.squeeze(dense, axis=1) |
131 |
| - outputs[key] = tft.scale_to_0_1(dense) |
| 126 | + # This is a RaggedTensor because it is optional. Here we fill in a default |
| 127 | + # value when it is missing, after scaling it. |
| 128 | + outputs[key] = tft.scale_to_0_1(inputs[key]).to_tensor( |
| 129 | + default_value=0., shape=[None, 1]) |
132 | 130 |
|
133 | 131 | # For all categorical columns except the label column, we generate a
|
134 | 132 | # vocabulary, and convert the string feature to a one-hot encoding.
|
|
0 commit comments