-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
186 lines (162 loc) · 6.46 KB
/
Copy pathapp.py
File metadata and controls
186 lines (162 loc) · 6.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from prophet import Prophet
import pandas as pd
from typing import List, Dict, Optional
# Module-level FastAPI application; the route decorators below register
# their handlers on this instance, and the __main__ guard serves it.
app = FastAPI(title="Onboarders Example Time Series Prediction Service")
class TimeSeriesData(BaseModel):
    # Historical observations supplied by the client as two lists:
    # date strings in `dates` and the measured numbers in `values`.
    # (Comments are used instead of a docstring so the generated JSON
    # schema is unchanged.)
    dates: List[str]
    values: List[float]
class ProphetParameters(BaseModel):
    # Tunable Prophet settings accepted from API clients. Every field has a
    # default, so ProphetParameters() with no arguments is a complete
    # configuration. Comments are used instead of docstrings so the JSON
    # schema produced by model_json_schema() stays exactly as before.

    # ---- Core parameters: prior scales must be strictly positive ----
    changepoint_prior_scale: float = Field(
        default=0.05,
        gt=0,
        description="Flexibility of the trend changes. Higher values allow more flexibility.",
    )
    seasonality_prior_scale: float = Field(
        default=10.0,
        gt=0,
        description="Strength of the seasonality model. Higher values allow stronger seasonal patterns.",
    )
    holidays_prior_scale: float = Field(
        default=10.0,
        gt=0,
        description="Strength of the holiday effects. Higher values allow stronger holiday effects.",
    )
    # Restricted by regex to exactly one of the two listed modes.
    seasonality_mode: str = Field(
        default="additive",
        pattern="^(additive|multiplicative)$",
        description="Type of seasonality, either 'additive' or 'multiplicative'",
    )

    # ---- Seasonality switches ----
    yearly_seasonality: Optional[bool] = Field(
        default=True, description="Whether to include yearly seasonality"
    )
    weekly_seasonality: Optional[bool] = Field(
        default=True, description="Whether to include weekly seasonality"
    )
    daily_seasonality: Optional[bool] = Field(
        default=False, description="Whether to include daily seasonality"
    )

    # ---- Growth: trend type plus optional saturation bounds ----
    growth: str = Field(
        default="linear",
        pattern="^(linear|logistic|flat)$",
        description="Type of growth trend: 'linear', 'logistic', or 'flat'",
    )
    # cap and floor default to None; they are not validated against
    # `growth` at the model level.
    cap: Optional[float] = Field(
        default=None, description="Growth cap for logistic growth"
    )
    floor: Optional[float] = Field(
        default=None, description="Growth floor for logistic growth"
    )

    # ---- Changepoints ----
    n_changepoints: int = Field(
        default=25, ge=0, description="Number of potential changepoints"
    )
    # Constrained to the half-open interval (0, 1].
    changepoint_range: float = Field(
        default=0.8,
        gt=0,
        le=1,
        description="Proportion of history where changepoints are considered",
    )
class ForecastRequest(BaseModel):
    # Request body for the forecast endpoint: the history to learn from, the
    # forecast horizon, optional model tuning, and an output toggle.
    # (Comments rather than a docstring, so the published schema is unchanged.)

    # pydantic v2 treats the "model_" attribute prefix as a protected
    # namespace, and some v2 releases emit a UserWarning at class-creation
    # time for a field called `model_parameters`. The field name is part of
    # the public request format and must stay, so opt this model out of the
    # check instead. A plain dict is accepted for model_config, so no extra
    # import is needed.
    model_config = {"protected_namespaces": ()}

    data: TimeSeriesData
    # Must be a positive integer (gt=0).
    periods: int = Field(
        default=30,
        gt=0,
        description="Number of periods to forecast"
    )
    # None means the caller did not supply any custom parameters.
    model_parameters: Optional[ProphetParameters] = Field(
        default=None,
        description="Custom Prophet model parameters"
    )
    return_components: bool = Field(
        default=False,
        description="Whether to return trend and seasonal components"
    )
class ForecastResponse(BaseModel):
    # Response payload of the forecast endpoint. The four forecast_* fields
    # are lists of dates, point estimates and interval bounds.
    # (Comments rather than a docstring, so the published schema is unchanged.)
    forecast_dates: List[str]
    forecast_values: List[float]
    forecast_lower_bound: List[float]
    forecast_upper_bound: List[float]
    # Optional mapping from component name to its list of values; stays None
    # when not populated.
    components: Optional[Dict[str, List[float]]] = None
def prepare_data(data: TimeSeriesData) -> pd.DataFrame:
    """Build the two-column frame used for model training.

    The date strings in ``data.dates`` are parsed into the ``ds`` column and
    the numbers in ``data.values`` are placed, unchanged, in the ``y`` column.
    """
    timestamps = pd.to_datetime(data.dates)
    observations = data.values
    frame = pd.DataFrame({'ds': timestamps, 'y': observations})
    return frame
def configure_prophet_model(params: ProphetParameters) -> Prophet:
    """Build an unfitted Prophet model from the supplied parameters.

    Args:
        params: Model settings (prior scales, seasonality switches, growth
            type, changepoint options, and optional cap/floor).

    Returns:
        A new ``Prophet`` instance configured from ``params``; it has not
        been fitted.

    Raises:
        ValueError: If ``growth`` is ``'logistic'`` and either ``cap`` or
            ``floor`` is missing.
    """
    # Reject an incomplete logistic configuration up front, before any model
    # object is constructed.
    if params.growth == 'logistic' and (params.cap is None or params.floor is None):
        raise ValueError("Cap and floor must be specified for logistic growth")

    # `cap` and `floor` are deliberately NOT forwarded here: Prophet's
    # constructor has no such keyword arguments, so passing them raises
    # TypeError. Prophet reads saturation bounds from 'cap' / 'floor' columns
    # of the dataframes given to fit() and predict() instead.
    return Prophet(
        changepoint_prior_scale=params.changepoint_prior_scale,
        seasonality_prior_scale=params.seasonality_prior_scale,
        holidays_prior_scale=params.holidays_prior_scale,
        seasonality_mode=params.seasonality_mode,
        yearly_seasonality=params.yearly_seasonality,
        weekly_seasonality=params.weekly_seasonality,
        daily_seasonality=params.daily_seasonality,
        growth=params.growth,
        n_changepoints=params.n_changepoints,
        changepoint_range=params.changepoint_range,
    )
@app.post("/forecast/", response_model=ForecastResponse)
async def create_forecast(request: ForecastRequest):
    """Fit a Prophet model on the submitted series and return a forecast.

    The response covers only the last ``periods`` rows of the prediction,
    i.e. the newly forecast dates, with point estimates and lower/upper
    bounds. When ``return_components`` is set, the trend and any seasonal
    components present in the prediction are included as well.

    Raises:
        HTTPException: With status 400 and the underlying error message when
            data preparation, model configuration, fitting or prediction
            fails.
    """
    # NOTE(review): fitting is synchronous work inside an async handler, so
    # it occupies the event loop for the duration of the request. Consider
    # off-loading to a worker thread if concurrent traffic matters.
    try:
        # Prepare the input data
        df = prepare_data(request.data)

        # Configure the model (falls back to the defaults when no custom
        # parameters were sent)
        model_params = request.model_parameters or ProphetParameters()
        model = configure_prophet_model(model_params)

        # Logistic growth needs the saturation bounds on the TRAINING frame
        # too: Prophet reads them from the 'cap' / 'floor' columns during
        # fit(), so setting them only on the future frame is not enough.
        uses_logistic = model_params.growth == 'logistic'
        if uses_logistic:
            df['cap'] = model_params.cap
            df['floor'] = model_params.floor

        model.fit(df)

        # Create future dates for forecasting; the new frame only carries
        # the date column, so the bounds have to be attached again.
        future = model.make_future_dataframe(periods=request.periods)
        if uses_logistic:
            future['cap'] = model_params.cap
            future['floor'] = model_params.floor

        # Make predictions
        forecast = model.predict(future)

        # Keep only the forecast horizon (the last `periods` rows); slice
        # once and reuse it for every output column.
        horizon = forecast.iloc[-request.periods:]
        response_data = {
            'forecast_dates': horizon['ds'].dt.strftime('%Y-%m-%d').tolist(),
            'forecast_values': horizon['yhat'].tolist(),
            'forecast_lower_bound': horizon['yhat_lower'].tolist(),
            'forecast_upper_bound': horizon['yhat_upper'].tolist(),
        }

        # Add components if requested; seasonal columns exist only for the
        # seasonalities that were enabled, so absent ones are skipped.
        if request.return_components:
            response_data['components'] = {
                name: horizon[name].tolist()
                for name in ('trend', 'yearly', 'weekly', 'daily')
                if name in horizon.columns
            }

        return ForecastResponse(**response_data)
    except HTTPException:
        # Already an HTTP error: pass it through without rewrapping.
        raise
    except Exception as e:
        # Endpoint boundary: report the failure to the client as a 400 and
        # keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
@app.get("/parameters/default")
async def get_default_parameters():
    """Return the default model parameters and their descriptions."""
    # The JSON schema of ProphetParameters carries each field's default,
    # constraints and description text.
    parameter_schema = ProphetParameters.model_json_schema()
    return parameter_schema
# Script entry point: serve the app with uvicorn when this file is executed
# directly rather than imported.
if __name__ == "__main__":
    # Imported here so uvicorn is only required when running as a script.
    import uvicorn
    # host "0.0.0.0" listens on all network interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)