1515from typing import Optional
1616
1717import git
18+ import pandas as pd
1819from pathy import Pathy
1920from pydantic import BaseModel , Field , root_validator , validator
2021
22+ # nowcasting_dataset imports
2123from nowcasting_dataset .consts import (
2224 DEFAULT_N_GSP_PER_EXAMPLE ,
2325 DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE ,
2426 NWP_VARIABLE_NAMES ,
2527 SAT_VARIABLE_NAMES ,
2628)
27-
29+ from nowcasting_dataset . dataset . split import split
2830
2931IMAGE_SIZE_PIXELS_FIELD = Field (64 , description = "The number of pixels of the region of interest." )
3032METERS_PER_PIXEL_FIELD = Field (2000 , description = "The number of meters per pixel." )
@@ -102,7 +104,7 @@ class Satellite(DataSourceMixin):
102104 """Satellite configuration model"""
103105
104106 satellite_zarr_path : str = Field (
105- "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr" ,
107+ "gs://solar-pv-nowcasting-data/satellite/EUMETSAT/SEVIRI_RSS/OSGB36/all_zarr_int16_single_timestep.zarr" , # noqa: E501
106108 description = "The path which holds the satellite zarr." ,
107109 )
108110 satellite_channels : tuple = Field (
@@ -116,7 +118,7 @@ class NWP(DataSourceMixin):
116118 """NWP configuration model"""
117119
118120 nwp_zarr_path : str = Field (
119- "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr" ,
121+ "gs://solar-pv-nowcasting-data/NWP/UK_Met_Office/UKV__2018-01_to_2019-12__chunks__variable10__init_time1__step1__x548__y704__.zarr" , # noqa: E501
120122 description = "The path which holds the NWP zarr." ,
121123 )
122124 nwp_channels : tuple = Field (NWP_VARIABLE_NAMES , description = "the channels used in the nwp data" )
@@ -213,7 +215,8 @@ def set_forecast_and_history_minutes(cls, values):
213215 Run through the different data sources and if the forecast or history minutes are not set,
214216 then set them to the default values
215217 """
216-
218+ # It would be much better to use nowcasting_dataset.data_sources.ALL_DATA_SOURCE_NAMES,
219+ # but that causes a circular import.
217220 ALL_DATA_SOURCE_NAMES = ("pv" , "satellite" , "nwp" , "gsp" , "topographic" , "sun" )
218221 enabled_data_sources = [
219222 data_source_name
@@ -249,8 +252,8 @@ def set_all_to_defaults(cls):
249252class OutputData (BaseModel ):
250253 """Output data model"""
251254
252- filepath : str = Field (
253- "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/" ,
255+ filepath : Pathy = Field (
256+ Pathy ( "gs://solar-pv-nowcasting-data/prepared_ML_training_data/v7/" ) ,
254257 description = (
255258 "Where the data is saved to. If this is running on the cloud then should include"
256259 " 'gs://' or 's3://'"
@@ -262,7 +265,29 @@ class Process(BaseModel):
262265 """Pydantic model of how the data is processed"""
263266
264267 seed : int = Field (1234 , description = "Random seed, so experiments can be repeatable" )
265- batch_size : int = Field (32 , description = "the number of examples per batch" )
268+ batch_size : int = Field (32 , description = "The number of examples per batch" )
269+ t0_datetime_frequency : pd .Timedelta = Field (
270+ pd .Timedelta ("5 minutes" ),
271+ description = (
272+ "The temporal frequency at which t0 datetimes will be sampled."
273+ " Can be any string that `pandas.Timedelta()` understands."
274+ " For example, if this is set to '5 minutes', then, for each example, the t0 datetime"
275+ " could be at 0, 5, ..., 55 minutes past the hour. If there are DataSources with a"
276+ " lower sample rate (e.g. half-hourly) then these lower-sample-rate DataSources will"
277+ " still produce valid examples. For example, if a half-hourly DataSource is asked for"
278+ " an example with t0=12:05, history_minutes=60, forecast_minutes=60, then it will"
279+ " return data at 11:30, 12:00, 12:30, and 13:00."
280+ ),
281+ )
282+ split_method : split .SplitMethod = Field (
283+ split .SplitMethod .DAY ,
284+ description = (
285+ "The method used to split the t0 datetimes into train, validation and test sets."
286+ ),
287+ )
288+ n_train_batches : int = 250
289+ n_validation_batches : int = 10
290+ n_test_batches : int = 10
266291 upload_every_n_batches : int = Field (
267292 16 ,
268293 description = (
0 commit comments