Skip to content

Commit f6c7236

Browse files
committed
Updated train metrics generation for AutoMLX model
1 parent 789786f commit f6c7236

File tree

1 file changed

+29
-0
lines changed
  • ads/opctl/operator/lowcode/forecast/model

1 file changed

+29
-0
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

+29
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def _build_model(self) -> pd.DataFrame:
159159
self.models[s_id] = {}
160160
self.models[s_id]["model"] = model
161161
self.models[s_id]["le"] = self.le[s_id]
162+
self.models[s_id]["score"] = self.get_validation_score_and_metric(model)
162163

163164
# In case of Naive model, model.forecast function call does not return confidence intervals.
164165
if f"{target}_ci_upper" not in summary_frame:
@@ -511,3 +512,31 @@ def explain_model(self):
511512
f"Failed to generate explanations for series {s_id} with error: {e}."
512513
)
513514
logger.debug(f"Full Traceback: {traceback.format_exc()}")
515+
516+
def get_validation_score_and_metric(self, model):
517+
trials = model.completed_trials_summary_
518+
model_params = model.selected_model_params_
519+
if len(trials) > 0:
520+
score_col = [col for col in trials.columns if "Score" in col][0]
521+
validation_score = trials[trials.Hyperparameters == model_params][score_col].iloc[0]
522+
else:
523+
validation_score = 0
524+
return -1 * validation_score
525+
526+
def generate_train_metrics(self) -> pd.DataFrame:
527+
"""
528+
Generate Training Metrics when fitted data is not available.
529+
"""
530+
total_metrics = pd.DataFrame()
531+
for s_id in self.forecast_output.list_series_ids():
532+
try:
533+
metrics = {self.spec.metric.upper(): self.models[s_id]["score"]}
534+
metrics_df = pd.DataFrame.from_dict(metrics, orient="index", columns=[s_id])
535+
logger.warning("AutoMLX failed to generate training metrics. Recovering validation loss instead")
536+
total_metrics = pd.concat([total_metrics, metrics_df], axis=1)
537+
except Exception as e:
538+
logger.debug(
539+
f"Failed to generate training metrics for target_series: {s_id}"
540+
)
541+
logger.debug(f"Error: {e}")
542+
return total_metrics

0 commit comments

Comments
 (0)