44
55"""
66
7- import os
87from abc import ABC , abstractmethod
98
9+ import tensorflow as tf
1010import keras
1111
12+ try :
13+ HAS_NL_EXPR = True
14+ except ImportError :
15+ HAS_NL_EXPR = False
16+
1217
1318def layers_as_string (layers ):
1419 if isinstance (layers , str ):
@@ -32,25 +37,11 @@ def __init__(
3237 self ,
3338 dataset ,
3439 ):
35- self .basedir = os .path .join (os .path .dirname (__file__ ), ".." , "predictors" )
3640 self .dataset = dataset
3741
3842 # Filled with get data if needed
3943 self ._data = None
4044
41- keras_version_file = f"{ dataset } _keras_version"
42-
43- try :
44- with open (os .path .join (self .basedir , keras_version_file )) as file_in :
45- version = file_in .read ().strip ()
46- except FileNotFoundError :
47- version = None
48- if version != keras .__version__ :
49- print (f"Keras version changed. Regenerate predictors for { dataset } " )
50- self .build_all_predictors ()
51- with open (os .path .join (self .basedir , keras_version_file ), "w" ) as file_out :
52- print (keras .__version__ , file = file_out )
53-
5445 def __iter__ (self ):
5546 return self .all_tested_layers .__iter__ ()
5647
@@ -69,39 +60,28 @@ def data(self):
6960 self .load_data ()
7061 return self ._data
7162
72- def predictor_file (self , predictor ):
73- return f"{ self .dataset } _{ layers_as_string (predictor )} .keras"
74-
75- def build_predictor (self , layers ):
63+ def get_case (self , layers ):
7664 """Build model for one predictor"""
7765 X , y = self .data
7866 predictor = self .compile (layers )
7967 predictor .fit (X , y )
8068
81- predictor .save (self .predictor_file (layers ))
8269 return predictor
8370
84- def build_all_predictors (self ):
85- """Build all the predictor for this case.
86- (Done when we have a new sklearn version)"""
87- for predictor in self :
88- self .build_predictor (predictor )
89-
90- def get_case (self , predictor ):
91- filename = self .predictor_file (predictor )
92- try :
93- return keras .saving .load_model (os .path .join (self .basedir , filename ))
94- except ValueError :
95- return self .build_predictor (predictor )
96-
9771
9872class HousingCases (Cases ):
9973 """Base class to have cases for testing regression models on diabetes set
10074
10175 This is appropriate for testing a regression with a single output."""
10276
10377 def __init__ (self ):
104- self .all_tested_layers = [[keras .layers .Dense (16 , activation = "relu" )]]
78+ self .all_tested_layers = [
79+ [keras .layers .Dense (16 , activation = "relu" )],
80+ ]
81+ if HAS_NL_EXPR :
82+ self .all_tested_layers .append (
83+ [keras .layers .Dense (16 , activation = "sigmoid" )],
84+ )
10585 super ().__init__ ("housing" )
10686 self .load_data ()
10787
@@ -117,3 +97,37 @@ def compile(self, layers):
11797 )
11898 nn .compile (loss = "mean_squared_error" , optimizer = "adam" )
11999 return nn
100+
101+
102+ class MNISTCases (Cases ):
103+ """Base class to have cases for testing regression models on diabetes set
104+
105+ This is appropriate for testing a regression with a single output."""
106+
107+ def __init__ (self ):
108+ self .all_tested_layers = [
109+ [keras .layers .Dense (20 , activation = "relu" )],
110+ ]
111+
112+ if HAS_NL_EXPR :
113+ self .all_tested_layers += [
114+ [keras .layers .Dense (20 , activation = "sigmoid" )],
115+ ]
116+ super ().__init__ ("housing" )
117+ self .load_data ()
118+
119+ def load_data (self ):
120+ (X_train , y_train ), (_ , _ ) = keras .datasets .fashion_mnist .load_data ()
121+ X_train = tf .reshape (tf .cast (X_train , tf .float32 ) / 255.0 , [- 1 , 28 * 28 ])
122+ self ._data = (X_train , y_train )
123+
124+ def compile (self , layers ):
125+ nn = keras .models .Sequential (
126+ [keras .layers .InputLayer ((28 * 28 ,))] + layers + [keras .layers .Dense (10 )]
127+ )
128+ nn .compile (
129+ optimizer = "adam" ,
130+ loss = tf .keras .losses .SparseCategoricalCrossentropy (),
131+ metrics = [tf .keras .metrics .SparseCategoricalAccuracy ()],
132+ )
133+ return nn
0 commit comments