Skip to content

Commit eda28f3

Browse files
committed
feat(recommender): implement VisualizationRecommender for orchestrating model querying, parsing, and ensemble scoring
1 parent 1eda4eb commit eda28f3

1 file changed

Lines changed: 177 additions & 0 deletions

File tree

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import pandas as pd
2+
from pprint import pprint
3+
from typing import Dict, List, Optional, Tuple
4+
5+
from plotsense.core.ai_interface import AIModelInterface
6+
from plotsense.core.enums.strategy import StrategyName
7+
from plotsense.core.providers.provider_manager import ProviderManager
8+
from plotsense.visual_suggestion.recommender.dataframe_analyzer import DataFrameAnalyzer
9+
from plotsense.visual_suggestion.recommender.ensemble_scorer import EnsembleScorer
10+
from plotsense.visual_suggestion.recommender.prompt_builder import PromptBuilder
11+
from plotsense.visual_suggestion.recommender.response_parser import ResponseParser
12+
13+
14+
class VisualizationRecommender:
    """Orchestrate model querying, response parsing, and ensemble scoring.

    The recommender builds a prompt from a description of the DataFrame,
    sends it to every available model through ``AIModelInterface``, parses
    each raw response with ``ResponseParser`` and merges the parsed
    recommendations with ``EnsembleScorer``.
    """

    def __init__(
        self,
        api_keys: Optional[Dict[str, str]],
        strategy: StrategyName,
        selected_models: Optional[List[Tuple[str, str]]],
        timeout: int,
        interactive: bool,
        debug: bool,
    ):
        """
        Initialize VisualizationRecommender with API keys and configuration.

        Args:
            api_keys: Optional dictionary of API keys. If not provided,
                an empty dict is handed to ``ProviderManager`` (which is
                expected to load keys from environment variables).
            strategy: Strategy name forwarded to the AI interface when the
                strategy instance is created.
            selected_models: Optional ``(provider, model)`` pairs. Only the
                provider names are used here, to restrict which providers
                ``ProviderManager`` considers.
            timeout: Timeout in seconds for API requests
            interactive: Whether to prompt for missing API keys
            debug: Enable debug output

        Raises:
            ValueError: If the provider manager reports no models.
        """
        self.timeout = timeout
        self.interactive = interactive
        self.debug = debug
        self.strategy_name = strategy

        # Distinct provider names taken from the (provider, model) pairs.
        # NOTE(review): the model half of each pair is not used for
        # filtering in this class -- confirm that is intended.
        selected_providers = {p for p, _ in (selected_models or [])}

        self.manager = ProviderManager(
            api_keys=api_keys or {},
            interactive=interactive,
            restrict_to=list(selected_providers) if selected_providers else None,
        )
        self.ai_interface = AIModelInterface(self.manager, timeout=self.timeout)

        # Flatten {provider: [models]} into a list of (provider, model) pairs.
        all_models = self.manager.list_all_models()
        self.available_models = [
            (provider, model)
            for provider, models in all_models.items()
            for model in models
        ]

        if not self.available_models:
            raise ValueError(
                "No available models detected — check API keys or selection input."
            )

        # initialize strategy instance
        self.strategy = self.ai_interface._init_strategy(
            self.strategy_name, self.available_models
        )

        self.df = None
        # Left empty here; recommend_visualizations() falls back to the
        # AIModelInterface weights while this stays empty.
        self.model_weights = {}

        if self.debug:
            print("\n[DEBUG] Initialization Complete")
            print(f"Available models: {self.available_models}")
            print(f"Model weights: {self.model_weights}")

    def set_dataframe(self, df: pd.DataFrame):
        """Set the DataFrame to analyze and provide debug info"""
        self.df = df
        if self.debug:
            print("\n[DEBUG] DataFrame Info:")
            print(f"Shape: {df.shape}")
            print("Columns:", df.columns.tolist())
            print("\nSample data:")
            print(df.head(2))

    def recommend_visualizations(
        self, n: int = 5, custom_weights: Optional[Dict[str, float]] = None
    ) -> pd.DataFrame:
        """
        Generate visualization recommendations using weighted ensemble approach.

        Args:
            n: Number of recommendations to return (default: 5)
            custom_weights: Optional dictionary to override default model weights

        Returns:
            pd.DataFrame: Recommended visualizations with ensemble scores.
                When fewer than ``n`` rows survive scoring, the output of
                ``EnsembleScorer.supplement_recommendations`` is returned
                instead.

        Raises:
            ValueError: If no DataFrame is set or no models are available
        """
        # At least five candidates are always requested from the models,
        # even when the caller wants fewer rows back.
        self.n_to_request = max(n, 5)

        if self.df is None:
            raise ValueError("No DataFrame set. Call set_dataframe() first.")

        if not self.available_models:
            raise ValueError("No available models detected")

        if self.debug:
            print("\n[DEBUG] Starting recommendation process")
            print(f"Using models: {self.available_models}")

        # Weight precedence: explicit custom_weights, then self.model_weights,
        # then whatever AIModelInterface reports as defaults.
        if custom_weights:
            weights = custom_weights
        elif self.model_weights:
            weights = self.model_weights
        else:
            weights = self.ai_interface.get_model_weights()

        # Describe the DataFrame and turn that description into the prompt.
        analyzer = DataFrameAnalyzer(self.df)
        df_description = analyzer.describe_dataframe()
        prompt = PromptBuilder(self.n_to_request).build_prompt(df_description)

        if self.debug:
            print("\n[DEBUG] Prompt being sent to models:")
            print(prompt)

        # Expecting ai_interface.query_all_models to return
        # dict { "provider:model": "raw text" }
        all_recommendations = self.ai_interface.query_all_models(
            prompt, self.debug
        )

        if self.debug:
            print("\n[DEBUG] Raw recommendations from models:")
            pprint(all_recommendations)

        # Parse model responses into structured recommendation lists
        parser = ResponseParser(self.df, debug=self.debug)
        parsed_recs = {
            model: parser.parse_recommendations(response, model)
            for model, response in all_recommendations.items()
        }

        if self.debug:
            print("\n[DEBUG] Applying ensemble scoring")

        scorer = EnsembleScorer(
            self.df,
            debug=self.debug,
            available_models=self.available_models,
        )
        # Use weights determined above (which respects custom_weights)
        ensemble_df = scorer.apply_ensemble_scoring(parsed_recs, weights)

        # Validate and correct variable order. An empty result is passed
        # through unchanged so it keeps the columns the scorer produced.
        if ensemble_df.empty:
            final_df = ensemble_df
        else:
            final_df = parser.validate_variable_order(ensemble_df)

        # If we don't have enough results, try to supplement
        if len(final_df) < n:
            if self.debug:
                print(f"\n[DEBUG] Only got {len(final_df)} recommendations, trying to supplement")
            # The scorer is given ensemble_df so it can access source_models.
            # NOTE(review): these rows have not been through
            # validate_variable_order -- confirm supplement handles that.
            return scorer.supplement_recommendations(ensemble_df, n)

        if self.debug:
            print("\n[DEBUG] Ensemble results before filtering:")
            print(final_df)

        # Return the validated & ordered results (top-n); previously the
        # unvalidated ensemble_df was returned and the validation was lost.
        return final_df.head(n)
177+

0 commit comments

Comments
 (0)