diff --git a/.travis.yml b/.travis.yml
index a64a792..a76f628 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,5 @@
 language: python
-virtualenv:
-  system_site_packages: true
+
 env:
   matrix:
     # let's start simple:
diff --git a/datacleaner/datacleaner.py b/datacleaner/datacleaner.py
index 0fcb87b..e97edf9 100644
--- a/datacleaner/datacleaner.py
+++ b/datacleaner/datacleaner.py
@@ -30,7 +30,7 @@ update_checked = False
 
 
 def autoclean(input_dataframe, drop_nans=False, copy=False, encoder=None,
-              encoder_kwargs=None, ignore_update_check=False):
+              encoder_kwargs=None, ignore_update_check=False, **kwargs):
     """Performs a series of automated data cleaning transformations on the provided data set
 
     Parameters
@@ -48,6 +48,13 @@ def autoclean(input_dataframe, drop_nans=False, copy=False, encoder=None,
     ignore_update_check: bool
         Do not check for the latest version of datacleaner
 
+    fill_func: callable or str (default: 'median')
+        Function that computes the value used to replace the NaNs of a column;
+        it is called with the column (a pandas.Series). A string selects the
+        NaN-ignoring numpy function of that name (e.g. 'median' means
+        numpy.nanmedian) and must be one of 'sum', 'max', 'min', 'argmax',
+        'argmin', 'mean', 'median' or 'prod'
+
     Returns
     ----------
     output_dataframe: pandas.DataFrame
@@ -71,10 +78,33 @@ def autoclean(input_dataframe, drop_nans=False, copy=False, encoder=None,
     if encoder_kwargs is None:
         encoder_kwargs = {}
 
+    fill_func = kwargs.pop('fill_func', 'median')
+    if kwargs:
+        # Unknown keywords were a TypeError before **kwargs was added; keep rejecting typos
+        raise TypeError('autoclean() got unexpected keyword argument(s): ' +
+                        ', '.join(sorted(kwargs)))
+
+    # Statistics that can be requested by name
+    full_func_list = [
+        'sum', 'max', 'min', 'argmax', 'argmin', 'mean',
+        'median', 'prod'
+    ]
+
+    if isinstance(fill_func, str):
+        if fill_func not in full_func_list:
+            raise ValueError('fill_func must be a callable or one of {}, '
+                             'got {!r}'.format(full_func_list, fill_func))
+        # Resolve the name to the public NaN-ignoring numpy function, e.g. 'sum' -> numpy.nansum
+        import numpy as np
+        fill_func = getattr(np, 'nan' + fill_func)
+    elif not callable(fill_func):
+        raise TypeError('fill_func must be a callable or one of {}, '
+                        'got {!r}'.format(full_func_list, fill_func))
+
     for column in input_dataframe.columns.values:
-        # Replace NaNs with the median or mode of the column depending on the column type
+        # Replace NaNs with the fill_func value, or with the mode when fill_func raises TypeError
         try:
-            input_dataframe[column].fillna(input_dataframe[column].median(), inplace=True)
+            input_dataframe[column].fillna(fill_func(input_dataframe[column]), inplace=True)
         except TypeError:
             most_frequent = input_dataframe[column].mode()
             # If the mode can't be computed, use the nearest valid value
diff --git a/tests.py b/tests.py
index fc7b8ff..bcf843a 100644
--- a/tests.py
+++ b/tests.py
@@ -235,3 +235,35 @@ def test_autoclean_cv_real_data():
 
     assert cleaned_adult_training_data.equals(hand_cleaned_training_adult_data)
     assert cleaned_adult_testing_data.equals(hand_cleaned_testing_adult_data)
+
+
+def test_autoclean_with_nans_all_numerical_with_fill_func():
+    """Test autoclean() with a named fill_func on a data set that has all numerical values and some NaNs"""
+    data = pd.DataFrame({'A': np.random.rand(1000),
+                         'B': np.random.rand(1000),
+                         'C': np.random.randint(0, 3, 1000)})
+
+    data.loc[10:20, 'A'] = np.nan
+    data.loc[50:70, 'C'] = np.nan
+
+    # Expected values use the same NaN-ignoring numpy sum that fill_func="sum" selects,
+    # so the exact comparison below does not depend on how another sum rounds
+    hand_cleaned_data = data.copy()
+    hand_cleaned_data['A'].fillna(np.nansum(hand_cleaned_data['A']), inplace=True)
+    hand_cleaned_data['C'].fillna(np.nansum(hand_cleaned_data['C']), inplace=True)
+
+    cleaned_data = autoclean(data, fill_func="sum")
+
+    assert cleaned_data.equals(hand_cleaned_data)
+
+
+def test_autoclean_with_unknown_fill_func_name():
+    """Test that autoclean() rejects a fill_func name it does not recognize"""
+    data = pd.DataFrame({'A': [1., np.nan, 3.]})
+
+    try:
+        autoclean(data, fill_func="not_a_statistic")
+    except ValueError:
+        pass
+    else:
+        assert False, 'autoclean() accepted an unknown fill_func name'