diff --git a/datasets.py b/datasets.py index 679c736e..d83df63b 100644 --- a/datasets.py +++ b/datasets.py @@ -811,7 +811,7 @@ def load_labels(self) -> pd.DataFrame: filename="ceo-2019-Zambia-Cropland-(RCMRD-Set-1)-sample-data-2021-12-12.csv", class_prob=lambda df: (df["Crop/non-crop"] == "Crop"), start_year=2019, - train_val_test=(0.2, 0.4, 0.4), + train_val_test=(0.6, 0.2, 0.2), latitude_col="lat", longitude_col="lon", filter_df=clean_ceo_data, @@ -820,7 +820,7 @@ def load_labels(self) -> pd.DataFrame: filename="ceo-2019-Zambia-Cropland-(RCMRD-Set-2)-sample-data-2021-12-12.csv", class_prob=lambda df: (df["Crop/non-crop"] == "Crop"), start_year=2019, - train_val_test=(0.2, 0.4, 0.4), + train_val_test=(0.6, 0.2, 0.2), latitude_col="lat", longitude_col="lon", filter_df=clean_ceo_data,