
Commit af64ebc

committed
Reverted changes
Converted structures to tuples; updated pages
1 parent b370230 commit af64ebc

7 files changed, +64 −34 lines

7 files changed

+64
-34
lines changed

docs/images/functional_diagram.svg

Lines changed: 1 addition & 1 deletion

docs/index.md

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ label_enc = LabelBinarizer().fit(df['label'])
 # define a batch generator
 train_gen = BatchGenerator(
     df,
-    x_structure=[('var1', var1_enc), ('var2', var2_enc)],
+    x_structure=(('var1', var1_enc), ('var2', var2_enc)),
     y_structure=('label', label_enc),
     batch_size=4,
     train_mode=True
@@ -49,7 +49,7 @@ The generator returns batches of format (x_structure, y_structure) and the shape
 
 ```python
 >>> train_gen.shape
-([(None, ), (None, )], (None, 3))
+(((None, ), (None, )), (None, 3))
 ```
 
 The first element is a x_structure and it is a list if two inputs. Both of them are outputs of LabelEncoders, that
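
The practical effect of the tuple convention above is that the x part of each batch now nests the way tf.keras expects for multi-input models. As a rough illustration (not part of this commit), a two-input keras model matching the shapes shown could be wired up as follows; the `Input(shape=(1,))` choice and layer names are assumptions:

```python
# Sketch only: a two-input keras model matching the example above
# (two encoded columns in, a 3-class one-hot label out).
# Feeding each encoded column as a length-1 vector per sample is an assumption.
from tensorflow import keras

in_var1 = keras.Input(shape=(1,), name='var1')
in_var2 = keras.Input(shape=(1,), name='var2')
merged = keras.layers.concatenate([in_var1, in_var2])
out = keras.layers.Dense(3, activation='softmax')(merged)
model = keras.Model(inputs=(in_var1, in_var2), outputs=out)
model.compile(optimizer='adam', loss='categorical_crossentropy')

# model.fit(train_gen) can then consume the (x, y) batches directly,
# now that the x part is a tuple rather than a list.
```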

docs/user_guide.md

Lines changed: 3 additions & 3 deletions
@@ -59,11 +59,11 @@ embark_encoder = LabelEncoder().fit(titanic_data['Embarked'])
 cabinclass_encoder = LabelEncoder().fit(titanic_data['Pclass'])
 
 train_generator = BatchGenerator(titanic_data,
-                                 x_structure=[
+                                 x_structure=(
                                      ('Embarked', embark_encoder),
                                      ('Pclass', cabinclass_encoder),
-                                     ('Age', None)
-                                 ],
+                                     ('Age', None)
+                                 ),
                                  y_structure=('Survived', None),
                                  batch_size=32,
                                  shuffle=True

keras_batchflow/base/batch_generators/batch_generator.py

Lines changed: 11 additions & 6 deletions
@@ -17,12 +17,17 @@ class BatchGenerator:
     **Parameters:**
 
     - **data** - a *Pandas dataframe* containing a dataset with both x and y
-    - **x_structure** - *tuple* or *list of tuples* - a structure describing mapping of dataframe columns to
-      pre-fitted encoders and to keras model inputs. When model has multiple inputs, keras expects
-      a list of numpy arrays as model X's. Each tuple is a mapping of a dataframe column to a relevant encoder.
-      It has format `('column name', encoder)`. If encoder is None, the column values will be converted to numpy
-      array and passed unchanged. If `(None, value)` is used, a new constant of value = `value` will be added
-      to Batch generator's output.
+    - **x_structure** - *tuple* or *tuple of tuples* - a structure describing mapping of dataframe columns to
+      pre-fitted encoders and to keras model inputs. When the model has a single input, x_structure will look like
+      `x_structure=('column_name', encoder)`. When the model has multiple inputs, keras expects a tuple of numpy
+      arrays as model X's, and the structure will look like `x_structure=(('column_name1', encoder1), ('column_name2', encoder2))`.
+      If encoder is None, the column values will be converted to a numpy array and passed unchanged. If you want to
+      add a constant to inputs or outputs, you can add tuples with `column_name = None` and a constant value instead
+      of an encoder, like so: `(None, value)`.
+      **Example:** `x_structure=(('column_name1', encoder1), ('column_name2', None), (None, 1))` - values in
+      column_name1 are encoded by encoder1, values from column_name2 are passed through unchanged, and the third
+      element of the x structure will be a constant of 1, so the batch could be
+      `(np.array(...), np.array(...), np.array([1, 1, ...]))`
     - **y_structure** - (optional) *tuple* or *list of tuples* - a structure describing mapping of dataframe columns to
       pre-fitted encoders and to keras model output. When model has multiple output, keras expects
       a list of numpy arrays as model Y's. **Default: None**. Same rules and same format applies (see x_structure)
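
To make the updated docstring concrete, here is a small usage sketch that exercises the three leaf forms it describes (encoder, pass-through, constant). The column names, the toy DataFrame and the import path (inferred from the file layout above) are illustrative assumptions, not part of the commit:

```python
# Illustrative only: the structure format follows the docstring above;
# names and the import path are assumed from the repository layout.
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from keras_batchflow.base.batch_generators.batch_generator import BatchGenerator

df = pd.DataFrame({'colour': ['red', 'blue', 'red', 'blue'],
                   'size': [1.0, 2.5, 3.0, 0.5],
                   'label': ['a', 'b', 'a', 'b']})

colour_enc = LabelEncoder().fit(df['colour'])
label_enc = LabelEncoder().fit(df['label'])

gen = BatchGenerator(
    df,
    x_structure=(
        ('colour', colour_enc),   # column encoded by a pre-fitted encoder
        ('size', None),           # column passed through unchanged
        (None, 1),                # constant added to every batch
    ),
    y_structure=('label', label_enc),
    batch_size=2,
)
# each batch is (x, y) with x a tuple of three numpy arrays, e.g.
# (np.array([...]), np.array([...]), np.array([1, 1]))
```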

keras_batchflow/base/batch_shapers/batch_shaper.py

Lines changed: 5 additions & 4 deletions
@@ -97,10 +97,11 @@ def _walk_structure(self, data: pd.DataFrame, struc, func, **kwargs):
             ret = func(data=data, leaf=struc, **kwargs)
             return ret
         elif type(struc) in [list, tuple]:
-            ret = [self._walk_structure(data, s, func, **kwargs) for s in struc]
-            # we always return lists as tuples as tensorflow wants x and y to be tuples in case of multiple
-            # components
-            return tuple(ret)
+            walked_structure = [self._walk_structure(data, s, func, **kwargs) for s in struc]
+            if isinstance(struc, tuple):
+                return tuple(walked_structure)
+            else:
+                return walked_structure
         else:
             raise ValueError('Error: structure definition in {} class only supports lists and tuples, but {}'
                              'was found'.format(type(self).__name__, type(struc)))
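
The behavioural change in `_walk_structure` is easier to see in isolation: the walked structure now keeps the container type of its input instead of always collapsing to a tuple. Below is a self-contained sketch of that behaviour; the leaf-detection rule is an assumption, since the leaf branch sits outside the hunk shown:

```python
# Standalone sketch of the list/tuple-preserving walk; it mirrors, but is not,
# the BatchShaper implementation.  Leaf detection here is an assumption.
def walk_structure(struc, func):
    # treat a ('column_name', encoder) pair as a leaf (first element is a str or None)
    is_leaf = (isinstance(struc, tuple) and len(struc) == 2
               and (struc[0] is None or isinstance(struc[0], str)))
    if is_leaf:
        return func(struc)
    if isinstance(struc, (list, tuple)):
        walked = [walk_structure(s, func) for s in struc]
        # preserve the container type: tuples stay tuples, lists stay lists
        return tuple(walked) if isinstance(struc, tuple) else walked
    raise ValueError('structure must be built from lists and tuples of (name, encoder) pairs')

names = lambda leaf: leaf[0]
assert walk_structure((('var1', None), ('var2', None)), names) == ('var1', 'var2')
assert walk_structure([('var1', None), ('var2', None)], names) == ['var1', 'var2']
```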

tests/batch_generators/test_batch_generator.py

Lines changed: 26 additions & 2 deletions
@@ -165,7 +165,7 @@ def test_transform(self):
         le = LabelEncoder().fit(self.df['var2'])
         bg = BatchGenerator(
             self.df,
-            x_structure=[('var1', self.lb), ('var2', le)],
+            x_structure=(('var1', self.lb), ('var2', le)),
             y_structure=('label', self.le),
             shuffle=False,
             batch_size=3,
@@ -183,6 +183,19 @@ def test_transform(self):
         assert type(batch[0]) == tuple
         assert len(batch[0]) == 2
 
+        # test the same with x_structure a list
+        bg = BatchGenerator(
+            self.df,
+            x_structure=[('var1', self.lb), ('var2', le)],
+            y_structure=('label', self.le),
+            shuffle=False,
+            batch_size=3,
+        )
+        batch = bg.transform(self.df)
+        assert type(batch) == tuple
+        assert len(batch) == 2
+        assert type(batch[0]) == list
+
     def test_inverse_transform(self):
         # batch size equals to dataset size
         bg = BatchGenerator(
@@ -215,7 +228,7 @@ def test_shapes(self):
         le = LabelEncoder().fit(self.df['var2'])
         bg = BatchGenerator(
             self.df,
-            x_structure=[('var1', self.lb), ('var2', le)],
+            x_structure=(('var1', self.lb), ('var2', le)),
             y_structure=('label', self.le),
             shuffle=False,
         )
@@ -227,6 +240,17 @@ def test_shapes(self):
         assert sh[0][0] == (3,)
         assert sh[0][1] == (1,)
         assert sh[1] == (1,)
+        # test the same with x_structure a list
+        bg = BatchGenerator(
+            self.df,
+            x_structure=[('var1', self.lb), ('var2', le)],
+            y_structure=('label', self.le),
+            shuffle=False,
+        )
+        sh = bg.shapes
+        assert type(sh) == tuple
+        assert len(sh) == 2
+        assert type(sh[0]) == list
 
     def test_encoder_adaptor(self):
         """

tests/batch_shapers/test_batch_shaper.py

Lines changed: 16 additions & 16 deletions
@@ -69,7 +69,7 @@ def test_2d_transformer(self, data, one_hot_encoder, label_encoder):
 
     def test_many_x(self, data, label_binarizer, label_encoder):
         lb2 = LabelBinarizer().fit(data['var2'])
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('var2', lb2)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('var2', lb2)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         batch = bs.transform(data)
@@ -87,7 +87,7 @@ def test_many_x(self, data, label_binarizer, label_encoder):
     def test_many_y(self, data, label_binarizer, label_encoder):
         lb2 = LabelBinarizer().fit(data['var2'])
         bs = BatchShaper(x_structure=('var1', label_binarizer),
-                         y_structure=[('label', label_encoder), ('var2', lb2)],
+                         y_structure=(('label', label_encoder), ('var2', lb2)),
                          data_sample=data)
         batch = bs.transform(data)
         assert type(batch) == tuple
@@ -111,7 +111,7 @@ def test_predict_batch(self, data, label_binarizer, label_encoder):
         so for predict, the generator must return a tuple (x,), where x is a list of inputs
         """
         lb2 = LabelBinarizer().fit(data['var2'])
-        batch_shaper = BatchShaper(x_structure=[('var1', label_binarizer), ('var2', lb2)], data_sample=data)
+        batch_shaper = BatchShaper(x_structure=(('var1', label_binarizer), ('var2', lb2)), data_sample=data)
         batch = batch_shaper.transform(data)
         assert isinstance(batch, tuple)
         assert len(batch) == 1
@@ -146,7 +146,7 @@ def test_init_with_data_sample(self):
         pass
 
     def test_none_transformer(self, data, label_binarizer, label_encoder):
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('var2', None)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('var2', None)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         batch = bs.transform(data)
@@ -157,7 +157,7 @@ def test_none_transformer(self, data, label_binarizer, label_encoder):
         assert np.array_equal(batch[0][1], np.expand_dims(data['var2'].values, axis=-1))
 
     def test_const_component_int(self, data, label_binarizer, label_encoder):
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), (None, 0)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), (None, 0)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         batch = bs.transform(data)
@@ -169,7 +169,7 @@ def test_const_component_int(self, data, label_binarizer, label_encoder):
         assert batch[0][1].dtype == int
 
     def test_const_component_float(self, data, label_binarizer, label_encoder):
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), (None, 0.)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), (None, 0.)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         batch = bs.transform(data)
@@ -181,7 +181,7 @@ def test_const_component_float(self, data, label_binarizer, label_encoder):
         assert batch[0][1].dtype == float
 
     def test_const_component_str(self, data, label_binarizer, label_encoder):
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), (None, u'a')],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), (None, u'a')),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         batch = bs.transform(data)
@@ -194,7 +194,7 @@ def test_const_component_str(self, data, label_binarizer, label_encoder):
 
     def test_metadata(self, data, label_binarizer, label_encoder):
         VarShaper._dummy_constant_counter = 0
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), (None, 0.)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), (None, 0.)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         md = bs.metadata
@@ -228,7 +228,7 @@ def test_metadata(self, data, label_binarizer, label_encoder):
 
     def test_dummy_var_naming(self, data, label_binarizer, label_encoder):
         VarShaper._dummy_constant_counter = 0
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), (None, 0.), (None, 1.)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), (None, 0.), (None, 1.)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         md = bs.metadata
@@ -258,7 +258,7 @@ def inverse_transform(self, data):
                 return data
 
         a = A()
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('var1', a)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('var1', a)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         shapes = bs.shape
@@ -283,22 +283,22 @@ def inverse_transform(self, data):
                 return data
 
         a = A()
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('var1', a)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('var1', a)),
                          y_structure=('label', label_encoder), data_sample=data)
         n_classes = bs.n_classes
         pass
 
     def test_inverse_transform(self, data, label_binarizer, label_encoder):
         le2 = LabelEncoder().fit(data['var2'])
         bs = BatchShaper(x_structure=('var1', label_binarizer),
-                         y_structure=[('label', label_encoder), ('var2', le2)],
+                         y_structure=(('label', label_encoder), ('var2', le2)),
                          data_sample=data)
         batch = bs.transform(data)
         inverse = bs.inverse_transform(batch[1])
         assert inverse.equals(data[['label', 'var2']])
         # Check inverse transform when constant field is in the structure
         bs = BatchShaper(x_structure=('var1', label_binarizer),
-                         y_structure=[('label', label_encoder), ('var2', le2), (None, 0.)],
+                         y_structure=(('label', label_encoder), ('var2', le2), (None, 0.)),
                          data_sample=data)
         batch = bs.transform(data)
         # check that the constant field was added to the y output
@@ -309,7 +309,7 @@ def test_inverse_transform(self, data, label_binarizer, label_encoder):
         assert inverse.equals(data[['label', 'var2']])
         # Check inverse transform when direct mapping field is in the structure
         bs = BatchShaper(x_structure=('var1', label_binarizer),
-                         y_structure=[('label', label_encoder), ('var2', le2), ('var1', None)],
+                         y_structure=(('label', label_encoder), ('var2', le2), ('var1', None)),
                          data_sample=data)
         batch = bs.transform(data)
         # check that the constant field was added to the y output
@@ -366,7 +366,7 @@ def test_batch_forking(self, data, label_binarizer, label_encoder):
         # check that data is not modified
         assert data.equals(data_snapshot)
         assert data_xy_fork.columns.nlevels == 2
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('label', label_encoder)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('label', label_encoder)),
                          y_structure=('label', label_encoder),
                          data_sample=data)
         tr = bs.transform(data_xy_fork)
@@ -384,7 +384,7 @@ def test_batch_forking(self, data, label_binarizer, label_encoder):
         batch_fork_01 = BatchFork(levels=(0, 1))
         data_01_fork = batch_fork_01.transform(data)
         assert data_01_fork.columns.nlevels == 2
-        bs = BatchShaper(x_structure=[('var1', label_binarizer), ('label', label_encoder)],
+        bs = BatchShaper(x_structure=(('var1', label_binarizer), ('label', label_encoder)),
                          y_structure=('label', label_encoder),
                          multiindex_xy_keys=(0, 1),
                          data_sample=data)
