diff --git a/plotsense/visual_suggestion/__init__.py b/plotsense/visual_suggestion/__init__.py index 9eba1bb..cf571f7 100644 --- a/plotsense/visual_suggestion/__init__.py +++ b/plotsense/visual_suggestion/__init__.py @@ -1 +1,17 @@ from plotsense.visual_suggestion.suggestions import recommender, VisualizationRecommender +from .viz_cache import ( + create_cache, + CacheKeyBuilder, + normalize_prompt, + schema_signature, + weights_signature, +) +__all__ = [ + "VisualizationRecommender", + "recommender", + "create_cache", + "CacheKeyBuilder", + "normalize_prompt", + "schema_signature", + "weights_signature", +] \ No newline at end of file diff --git a/plotsense/visual_suggestion/suggestions.py b/plotsense/visual_suggestion/suggestions.py index 052fd6f..2558310 100644 --- a/plotsense/visual_suggestion/suggestions.py +++ b/plotsense/visual_suggestion/suggestions.py @@ -12,52 +12,61 @@ from pprint import pprint from groq import Groq +# cache utils +from viz_cache import ( + create_cache, + CacheKeyBuilder, + normalize_prompt, + schema_signature, + weights_signature +) load_dotenv() class VisualizationRecommender: DEFAULT_MODELS = { - 'groq': [ - ('llama-3.3-70b-versatile', 0.5), # (model_name, weight) - ('llama-3.1-8b-instant', 0.5) - + "groq": [ + ("llama-3.3-70b-versatile", 0.5), + ("llama-3.1-8b-instant", 0.5), ], - } - def __init__(self, - api_keys: Optional[Dict[str, - str]] = None, - timeout: int = 30, - interactive: bool = True, - debug: bool = False): + def __init__( + self, + api_keys: Optional[Dict[str, str]] = None, + timeout: int = 30, + interactive: bool = True, + debug: bool = False, + ): """ Initialize VisualizationRecommender with API keys and configuration. - - Args: - api_keys: Optional dictionary of API keys. If not provided, - keys will be loaded from environment variables. 
- timeout: Timeout in seconds for API requests - interactive: Whether to prompt for missing API keys - debug: Enable debug output """ self.interactive = interactive self.debug = debug api_keys = api_keys or {} self.api_keys = { - 'groq': os.getenv('GROQ_API_KEY') - # Add other services here + "groq": os.getenv("GROQ_API_KEY"), } + self.api_keys.update(api_keys) self.timeout = timeout - self.clients = {} - self.available_models = [] - self.df = None - self.model_weights = {} + self.clients: Dict[str, Groq] = {} + self.available_models: List[str] = [] + self.df: Optional[pd.DataFrame] = None + self.model_weights: Dict[str, float] = {} self.n_to_request = 5 - self.api_keys.update(api_keys) + # versions (bump these when prompt/scoring logic changes to bust cache) + self.code_version = "v1.0" + self.prompt_version = "v1" + + # in-session cache (memory only) + self.cache = create_cache( + capacity=256, + default_ttl=30 * 60, # 30 minutes + log_hits=self.debug, + ) self._validate_keys() self._initialize_clients() @@ -68,34 +77,25 @@ def __init__(self, print("\n[DEBUG] Initialization Complete") print(f"Available models: {self.available_models}") print(f"Model weights: {self.model_weights}") - if hasattr(self, 'clients'): - print(f"Clients initialized: {bool(self.clients)}") + print(f"Clients initialized: {bool(self.clients)}") - def _validate_keys(self): - """Validate that required API keys are present""" - service_links = { - 'groq': 'https://console.groq.com/keys ' - } + # ------------------------ setup ------------------------ - for service in ['groq']: + def _validate_keys(self): + """Validate that required API keys are present.""" + service_links = {"groq": "https://console.groq.com/keys"} + for service in ["groq"]: if not self.api_keys.get(service): if self.interactive: try: - link = service_links.get( - service, f"the {service.upper()} website") - message = ( - f"Enter { - service.upper()} API key (get it at {link}): ") - self.api_keys[service] = builtins.input( - 
message).strip() + link = service_links.get(service, f"the {service.upper()} website") + message = f"Enter {service.upper()} API key (get it at {link}): " + self.api_keys[service] = builtins.input(message).strip() if not self.api_keys[service]: - raise ValueError( - f"{service.upper()} API key is required") + raise ValueError(f"{service.upper()} API key is required") except (EOFError, OSError): - # Handle cases where input is not available raise ValueError( - f"{service.upper()} API key is required " - f"(get it at {service_links.get(service)})" + f"{service.upper()} API key is required (get it at {service_links.get(service)})" ) else: raise ValueError( @@ -105,47 +105,40 @@ def _validate_keys(self): ) def _initialize_clients(self): - """Initialize API clients""" + """Initialize API clients.""" self.clients = {} - if self.api_keys.get('groq'): + if self.api_keys.get("groq"): try: - self.clients['groq'] = Groq(api_key=self.api_keys['groq']) + self.clients["groq"] = Groq(api_key=self.api_keys["groq"]) except ImportError: - warnings.warn( - "Groq Python client not installed. pip install groq") + warnings.warn("Groq Python client not installed. 
pip install groq") def _detect_available_models(self): self.available_models = [] for provider, client in self.clients.items(): if client and provider in self.DEFAULT_MODELS: - self.available_models.extend( - [m[0] for m in self.DEFAULT_MODELS[provider]]) - + self.available_models.extend([m[0] for m in self.DEFAULT_MODELS[provider]]) if self.debug: - print( - f"[DEBUG] Detected available models: {self.available_models}") + print(f"[DEBUG] Detected available models: {self.available_models}") def _initialize_model_weights(self): - total_weight = 0 + total_weight = 0.0 self.model_weights = {} - - # Only include weights for available models for provider in self.DEFAULT_MODELS: for model, weight in self.DEFAULT_MODELS[provider]: if model in self.available_models: - self.model_weights[model] = weight - total_weight += weight - - # Normalize weights to sum to 1 + self.model_weights[model] = float(weight) + total_weight += float(weight) if total_weight > 0: - for model in self.model_weights: - self.model_weights[model] /= total_weight - + for model in list(self.model_weights.keys()): + self.model_weights[model] = self.model_weights[model] / total_weight if self.debug: print(f"[DEBUG] Model weights: {self.model_weights}") + # ------------------------ public API ------------------------ + def set_dataframe(self, df: pd.DataFrame): - """Set the DataFrame to analyze and provide debug info""" + """Set the DataFrame to analyze and provide debug info.""" self.df = df if self.debug: print("\n[DEBUG] DataFrame Info:") @@ -154,30 +147,18 @@ def set_dataframe(self, df: pd.DataFrame): print("\nSample data:") print(df.head(2)) - def recommend_visualizations(self, - n: int = 5, - custom_weights: Optional[Dict[str, - float]] = None) -> pd.DataFrame: + def recommend_visualizations( + self, + n: int = 5, + custom_weights: Optional[Dict[str, float]] = None, + ) -> pd.DataFrame: """ Generate visualization recommendations using weighted ensemble approach. 
- - Args: - n: Number of recommendations to return (default: 3) - custom_weights: Optional dictionary to override default model weights - - Returns: - pd.DataFrame: Recommended visualizations with ensemble scores - - Raises: - ValueError: - If no DataFrame is set or no models are available """ - """Generate visualization recommendations using weighted ensemble approach.""" self.n_to_request = max(n, 5) if self.df is None: raise ValueError("No DataFrame set. Call set_dataframe() first.") - if not self.available_models: raise ValueError("No available models detected") @@ -185,19 +166,34 @@ def recommend_visualizations(self, print("\n[DEBUG] Starting recommendation process") print(f"Using models: {self.available_models}") - # Use custom weights if provided, otherwise use defaults + # weights weights = custom_weights if custom_weights else self.model_weights - # Get recommendations from all models in parallel + # ---- Try ensemble cache first (skips model calls entirely) ---- + ens_key = CacheKeyBuilder.ensemble( + df_schema_sig=schema_signature(self.df), + models=self.available_models, + weights_sig=weights_signature(weights), + n=n, + code_version=self.code_version, + prompt_version=self.prompt_version, + ) + cached_df = self.cache.get(ens_key) + if cached_df is not None: + if self.debug: + print("[DEBUG] Ensemble result from cache") + print(self.cache.stats()) + return cached_df.head(n) + + # ---- Get recommendations from all models in parallel ---- all_recommendations = self._get_all_recommendations() if self.debug: print("\n[DEBUG] Raw recommendations from models:") pprint(all_recommendations) - # Apply weighted ensemble scoring - ensemble_results = self._apply_ensemble_scoring( - all_recommendations, weights) + # ---- Apply weighted ensemble scoring ---- + ensemble_results = self._apply_ensemble_scoring(all_recommendations, weights) # Validate and correct variable order if not ensemble_results.empty: @@ -206,21 +202,25 @@ def recommend_visualizations(self, # If 
we don't have enough results, try to supplement if len(ensemble_results) < n: if self.debug: - print( - f"\n[DEBUG] Only got { - len(ensemble_results)} recommendations, trying to supplement") - return self._supplement_recommendations(ensemble_results, n) + print(f"\n[DEBUG] Only got {len(ensemble_results)} recommendations, trying to supplement") + out = self._supplement_recommendations(ensemble_results, n) + # store what we have in cache + self.cache.set(ens_key, out) + if self.debug: + print("\n[DEBUG] Cache stats:", self.cache.stats()) + return out + # Store full ensemble in cache and return top-n + self.cache.set(ens_key, ensemble_results) if self.debug: print("\n[DEBUG] Ensemble results before filtering:") print(ensemble_results) - + print("\n[DEBUG] Cache stats:", self.cache.stats()) return ensemble_results.head(n) - def _supplement_recommendations( - self, - existing: pd.DataFrame, - target: int) -> pd.DataFrame: + # ------------------------ internals ------------------------ + + def _supplement_recommendations(self, existing: pd.DataFrame, target: int) -> pd.DataFrame: """Generate additional recommendations if we didn't get enough initially.""" if len(existing) >= target: return existing.head(target) @@ -228,12 +228,12 @@ def _supplement_recommendations( needed = target - len(existing) df_description = self._describe_dataframe() - # Try to get more recommendations from the best-performing model - best_model = existing.iloc[0]['source_models'][0] if not existing.empty else self.available_models[0] + # Try to get more recommendations from the first contributing model + best_model = existing.iloc[0]["source_models"][0] if not existing.empty else self.available_models[0] prompt = textwrap.dedent(f""" You already recommended these visualizations: - {existing[['plot_type', 'variables']].to_string()} + {existing[['plot_type', 'variables']].to_string(index=False)} Please recommend {needed} ADDITIONAL different visualizations for: {df_description} @@ -242,28 +242,21 
@@ def _supplement_recommendations( """) try: + # model call is cached by _get_model_recommendations path; here we call directly response = self._query_llm(prompt, best_model) - new_recs = self._parse_recommendations( - response, f"{best_model}-supplement") + new_recs = self._parse_recommendations(response, f"{best_model}-supplement") - # Combine with existing - combined = pd.concat( - [existing, pd.DataFrame(new_recs)], ignore_index=True) - combined = combined.drop_duplicates( - subset=['plot_type', 'variables']) + combined = pd.concat([existing, pd.DataFrame(new_recs)], ignore_index=True) + combined = combined.drop_duplicates(subset=["plot_type", "variables"]) if self.debug: - print( - f"\n[DEBUG] Supplemented with { - len(new_recs)} new recommendations") + print(f"\n[DEBUG] Supplemented with {len(new_recs)} new recommendations") return combined.head(target) except Exception as e: if self.debug: - print( - f"\n[WARNING] Couldn't supplement recommendations: { - str(e)}") - return existing.head(target) # Return what we have + print(f"\n[WARNING] Couldn't supplement recommendations: {str(e)}") + return existing.head(target) def _get_all_recommendations(self) -> Dict[str, List[Dict]]: df_description = self._describe_dataframe() @@ -273,23 +266,21 @@ def _get_all_recommendations(self) -> Dict[str, List[Dict]]: print("\n[DEBUG] Prompt being sent to models:") print(prompt) - model_handlers = { - 'llama': self._query_llm - - # Add other model handlers here + model_handlers: Dict[str, Callable[[str, str], str]] = { + "llama": self._query_llm, + # add other model families here and map accordingly } - all_recommendations = {} + all_recommendations: Dict[str, List[Dict]] = {} with ThreadPoolExecutor() as executor: futures = {} for model in self.available_models: - model_type = model.split('-')[0].lower() + model_type = model.split("-")[0].lower() if model_type.startswith(("llama", "mistral")): model_type = "llama" if "llama" in model_type else "mistral" - query_func = 
model_handlers[model_type] - futures[executor.submit( - self._get_model_recommendations, model, prompt, query_func)] = model + query_func = model_handlers.get(model_type, self._query_llm) + futures[executor.submit(self._get_model_recommendations, model, prompt, query_func)] = model for future in concurrent.futures.as_completed(futures): model = futures[future] @@ -297,93 +288,102 @@ def _get_all_recommendations(self) -> Dict[str, List[Dict]]: result = future.result() all_recommendations[model] = result if self.debug: - print( - f"\n[DEBUG] Got { - len(result)} recommendations from {model}") + print(f"\n[DEBUG] Got {len(result)} recommendations from {model}") except Exception as e: - warnings.warn( - f"Failed to get recommendations from {model}: { - str(e)}") + warnings.warn(f"Failed to get recommendations from {model}: {str(e)}") if self.debug: print(f"\n[ERROR] Failed to process {model}: {str(e)}") return all_recommendations - def _get_model_recommendations(self, - model: str, - prompt: str, - query_func: Callable[[str, - str], - str]) -> List[Dict]: - try: - response = query_func(prompt, model) - + def _get_model_recommendations( + self, + model: str, + prompt: str, + query_func: Callable[[str, str], str], + ) -> List[Dict]: + """ + Call the model with caching at the model-response layer, + then parse into structured recommendations. 
+ """ + # build namespaced cache key for this model response + key = CacheKeyBuilder.model_response( + provider="groq", + model=model, + norm_prompt=normalize_prompt(prompt), + df_schema_sig=schema_signature(self.df), + prompt_version=self.prompt_version, + # include params that can change output + temperature=0.4, + ) + + def compute_fn() -> str: if self.debug: - print(f"\n[DEBUG] Raw response from {model}:") - print(response) + print(f"[DEBUG] Cache MISS → calling model {model}") + return query_func(prompt, model) + + raw_text = self.cache.get_or_compute(key, compute_fn) + if self.debug and raw_text: + print(f"\n[DEBUG] Raw response from {model} (cached={raw_text is not None}):") + # print the first few lines for brevity + print("\n".join(raw_text.splitlines()[:12])) - return self._parse_recommendations(response, model) + try: + return self._parse_recommendations(raw_text, model) except Exception as e: warnings.warn(f"Error processing model {model}: {str(e)}") if self.debug: - print( - f"\n[ERROR] Failed to parse response from {model}: { - str(e)}") + print(f"\n[ERROR] Failed to parse response from {model}: {str(e)}") return [] - def _apply_ensemble_scoring(self, - all_recommendations: Dict[str, - List[Dict]], - weights: Dict[str, - float]) -> pd.DataFrame: - output_columns = ['plot_type', 'variables', - 'ensemble_score', 'model_agreement', 'source_models'] + def _apply_ensemble_scoring( + self, + all_recommendations: Dict[str, List[Dict]], + weights: Dict[str, float], + ) -> pd.DataFrame: + output_columns = ["plot_type", "variables", "ensemble_score", "model_agreement", "source_models"] if self.debug: print("\n[DEBUG] Applying ensemble scoring with weights:") pprint(weights) recommendation_weights = defaultdict(float) - recommendation_details = {} + recommendation_details: Dict[tuple, Dict] = {} for model, recs in all_recommendations.items(): - model_weight = weights.get(model, 0) + model_weight = weights.get(model, 0.0) if model_weight <= 0: continue for rec in 
recs: - # Create a consistent key for the recommendation - variables = rec['variables'] + variables = rec["variables"] if isinstance(variables, str): - variables = [v.strip() for v in variables.split(',')] + variables = [v.strip() for v in variables.split(",")] - # Filter variables to only those in the DataFrame - valid_vars = [ - var for var in variables if var in self.df.columns] + # keep only variables that exist in dataframe + valid_vars = [var for var in variables if var in self.df.columns] if not valid_vars: if self.debug: - print( - f"\n[DEBUG] Skipping recommendation from {model} with invalid variables: {variables}") + print(f"\n[DEBUG] Skipping recommendation from {model} with invalid variables: {variables}") continue - var_key = ', '.join(sorted(valid_vars)) - rec_key = (rec['plot_type'].lower(), var_key) + var_key = ", ".join(sorted(valid_vars)) + rec_key = (rec["plot_type"].lower(), var_key) - model_score = rec.get('score', 1.0) + model_score = float(rec.get("score", 1.0)) total_weight = model_weight * model_score recommendation_weights[rec_key] += total_weight if rec_key not in recommendation_details: recommendation_details[rec_key] = { - 'plot_type': rec['plot_type'], - 'variables': var_key, - 'source_models': [model], - 'raw_weight': total_weight + "plot_type": rec["plot_type"], + "variables": var_key, + "source_models": [model], + "raw_weight": total_weight, } else: - recommendation_details[rec_key]['source_models'].append( - model) - recommendation_details[rec_key]['raw_weight'] += total_weight + recommendation_details[rec_key]["source_models"].append(model) + recommendation_details[rec_key]["raw_weight"] += total_weight if not recommendation_details: if self.debug: @@ -398,11 +398,13 @@ def _apply_ensemble_scoring(self, if not results.empty: total_possible = sum(weights.values()) - results['ensemble_score'] = results['raw_weight'] / total_possible - results['ensemble_score'] = results['ensemble_score'].round(2) - results['model_agreement'] = 
results['source_models'].apply(len) - results = results.sort_values(['ensemble_score', 'model_agreement'], ascending=[ - False, False]).reset_index(drop=True) + results["ensemble_score"] = results["raw_weight"] / total_possible if total_possible else 0.0 + results["ensemble_score"] = results["ensemble_score"].round(2) + results["model_agreement"] = results["source_models"].apply(len) + results = results.sort_values( + ["ensemble_score", "model_agreement"], + ascending=[False, False] + ).reset_index(drop=True) return results[output_columns] return pd.DataFrame(columns=output_columns) @@ -424,7 +426,7 @@ def _describe_dataframe(self) -> str: col_type = "datetime" elif pd.api.types.is_numeric_dtype(self.df[col]): col_type = "numerical" - elif self.df[col].nunique() / len(self.df[col]) < 0.05: # Low cardinality + elif self.df[col].nunique() / len(self.df[col]) < 0.05: col_type = "categorical" else: col_type = "text/other" @@ -432,17 +434,14 @@ def _describe_dataframe(self) -> str: # Basic info unique_count = self.df[col].nunique() sample_values = self.df[col].dropna().head(sample_size).tolist() - desc.append( - f"- {col}: {col_type} ({unique_count} unique values), sample: {sample_values}") + desc.append(f"- {col}: {col_type} ({unique_count} unique values), sample: {sample_values}") # Add stats for numerical/datetime if col_type == "numerical": desc.append( - f" Stats: min={ - self.df[col].min()}, max={ - self.df[col].max()}, " f"mean={ - self.df[col].mean():.2f}, missing={ - self.df[col].isna().sum()}") + f" Stats: min={self.df[col].min()}, max={self.df[col].max()}, " + f"mean={self.df[col].mean():.2f}, missing={self.df[col].isna().sum()}" + ) elif col_type == "datetime": desc.append( f" Range: {self.df[col].min()} to {self.df[col].max()}, " @@ -450,18 +449,13 @@ def _describe_dataframe(self) -> str: ) # --- Relationship Analysis --- - numerical_cols = self.df.select_dtypes( - include=np.number).columns.tolist() + numerical_cols = 
self.df.select_dtypes(include=np.number).columns.tolist() if len(numerical_cols) > 1: desc.append("\nNumerical Variable Correlations (Pearson):") corr = self.df[numerical_cols].corr().round(2) desc.append(str(corr)) - # Categorical-numerical potential groupings - categorical_cols = [ - col for col in self.df.columns - if self.df[col].nunique() / len(self.df[col]) < 0.05 - ] + categorical_cols = [col for col in self.df.columns if self.df[col].nunique() / len(self.df[col]) < 0.05] if categorical_cols and numerical_cols: desc.append("\nPotential Groupings (categorical vs numerical):") desc.append(f" - Could group by: {categorical_cols}") @@ -485,8 +479,8 @@ def _create_prompt(self, df_description: str) -> str: CRITICAL VARIABLE ORDERING RULES: 1. If a suggestion includes both numerical and categorical variables, NUMERICAL VARIABLES MUST COME FIRST. - - Correct: "income, gender" - - Incorrect: "gender, income" + - Correct: "income, gender" + - Incorrect: "gender, income" 2. For plots requiring two numerical variables (e.g., scatter), order by analysis priority (dependent variable first). 3. For single-variable plots, use natural order (e.g., "age" for a histogram). 
@@ -548,66 +542,40 @@ def _create_prompt(self, df_description: str) -> str: Example INCORRECT suggestions (REJECT THESE): Plot Type: boxplot - Variables: gender, income # WRONG - categorical listed first + Variables: gender, income --- Plot Type: scatter - Variables: price, weight # WRONG - no clear priority order + Variables: price, weight Rationale: Should specify independent/dependent variable order """) def _query_llm(self, prompt: str, model: str) -> str: - if not self.clients.get('groq'): + if not self.clients.get("groq"): raise ValueError("Groq client not initialized") - try: - response = self.clients['groq'].chat.completions.create( + response = self.clients["groq"].chat.completions.create( model=model, messages=[{"role": "user", "content": prompt}], temperature=0.4, max_tokens=1000, - timeout=self.timeout + timeout=self.timeout, ) return response.choices[0].message.content except Exception as e: raise RuntimeError(f"Groq API query failed for {model}: {str(e)}") - def _validate_variable_order( - self, recommendations: pd.DataFrame) -> pd.DataFrame: - """ - Validate and correct the order of variables in recommendations, - ensuring numerical variables come first. 
- - Args: - recommendations: DataFrame of visualization recommendations - - Returns: - DataFrame with corrected variable order - """ + def _validate_variable_order(self, recommendations: pd.DataFrame) -> pd.DataFrame: + """Ensure datetime/numerical vars come before non-numerical in 'variables'.""" def _reorder_variables(row): - # Split variables - variables = [var.strip() for var in row['variables'].split(',')] + variables = [var.strip() for var in row["variables"].split(",")] - # Identify numerical and non-numerical variables - numerical_vars = [ - var for var in variables - if pd.api.types.is_numeric_dtype(self.df[var]) - ] - - date_vars = [ - var for var in variables - if pd.api.types.is_datetime64_any_dtype(self.df[var]) - ] - - non_numerical_vars = [ - var for var in variables - if var not in numerical_vars and var not in date_vars - ] + numerical_vars = [var for var in variables if pd.api.types.is_numeric_dtype(self.df[var])] + date_vars = [var for var in variables if pd.api.types.is_datetime64_any_dtype(self.df[var])] + non_numerical_vars = [var for var in variables if var not in numerical_vars and var not in date_vars] # Combine with numerical variables first corrected_vars = date_vars + numerical_vars + non_numerical_vars - - # Update the row with corrected variable order - row['variables'] = ', '.join(corrected_vars) + row["variables"] = ", ".join(corrected_vars) return row # Apply reordering @@ -616,58 +584,45 @@ def _reorder_variables(row): if self.debug: print("\n[DEBUG] Variable Order Validation:") - for orig, corrected in zip( - recommendations['variables'], corrected_recommendations['variables']): - if orig != corrected: - print(f" Corrected: {orig} → {corrected}") + for orig, corrected_vars in zip(recommendations["variables"], corrected["variables"]): + if orig != corrected_vars: + print(f" Corrected: {orig} → {corrected_vars}") - return corrected_recommendations + return corrected def _parse_recommendations(self, response: str, model: str) -> 
List[Dict]: - """Parse the LLM response into structured recommendations""" - recommendations = [] - - # Split response into recommendation blocks - blocks = [b.strip() for b in response.split('---') if b.strip()] + """Parse the LLM response into structured recommendations.""" + recommendations: List[Dict] = [] + blocks = [b.strip() for b in response.split("---") if b.strip()] if self.debug: print(f"\n[DEBUG] Parsing {len(blocks)} blocks from {model}") for block in blocks: - lines = [line.strip() - for line in block.split('\n') if line.strip()] + lines = [line.strip() for line in block.split("\n") if line.strip()] if not lines: continue try: - rec = {'source_model': model} + rec = {"source_model": model} for line in lines: - if line.lower().startswith('plot type:'): - rec['plot_type'] = line.split( - ':', 1)[1].strip().lower() - elif line.lower().startswith('variables:'): - raw_vars = line.split(':', 1)[1].strip() - # Filter variables to only those that exist in - # DataFrame - variables = [v.strip() for v in raw_vars.split( - ',') if v.strip() in self.df.columns] - rec['variables'] = ', '.join( - [var for var in variables if var in self.df.columns]) - # rec['variables'] = self._reorder_variables(', - # '.join(variables)) # Keep original order for now - - if 'plot_type' in rec and 'variables' in rec and rec['variables']: + if line.lower().startswith("plot type:"): + rec["plot_type"] = line.split(":", 1)[1].strip().lower() + elif line.lower().startswith("variables:"): + raw_vars = line.split(":", 1)[1].strip() + vars_list = [v.strip() for v in raw_vars.split(",") if v.strip() in self.df.columns] + rec["variables"] = ", ".join(vars_list) + if "plot_type" in rec and rec.get("variables"): recommendations.append(rec) except Exception as e: - warnings.warn( - f"Failed to parse recommendation from {model}: {str(e)}") + warnings.warn(f"Failed to parse recommendation from {model}: {str(e)}") continue return recommendations # Package-level convenience function 
-_recommender_instance = None +_recommender_instance: Optional[VisualizationRecommender] = None def recommender( @@ -675,7 +630,7 @@ def recommender( n: int = 5, api_keys: dict = {}, custom_weights: Optional[Dict[str, float]] = None, - debug: bool = False + debug: bool = False, ) -> pd.DataFrame: """ Generate visualization recommendations using weighted ensemble of LLMs. @@ -692,11 +647,9 @@ def recommender( """ global _recommender_instance if _recommender_instance is None: - _recommender_instance = VisualizationRecommender( - api_keys=api_keys, debug=debug) - + _recommender_instance = VisualizationRecommender(api_keys=api_keys, debug=debug) _recommender_instance.set_dataframe(df) return _recommender_instance.recommend_visualizations( n=n, - custom_weights=custom_weights + custom_weights=custom_weights, ) diff --git a/plotsense/visual_suggestion/viz_cache.py b/plotsense/visual_suggestion/viz_cache.py new file mode 100644 index 0000000..454b192 --- /dev/null +++ b/plotsense/visual_suggestion/viz_cache.py @@ -0,0 +1,740 @@ +# viz_cache.py +from __future__ import annotations + +import time +import threading +import hashlib +import json +import logging +from typing import Any, Optional, Dict, Callable, TypeVar, Generic, Protocol +from dataclasses import dataclass, asdict, field +from contextlib import contextmanager +import random +import pandas as pd + +logger = logging.getLogger(__name__) +T = TypeVar("T") + +# ---------- Cache Metrics ---------- +@dataclass +class CacheStats: + """Track cache performance metrics.""" + hits: int = 0 + misses: int = 0 + evictions: int = 0 + expirations: int = 0 + sets: int = 0 + total_size: int = 0 + # Real savings: sum of last_compute_ms attributed on cache HITs + total_compute_time_saved_ms: float = 0.0 + lock_contentions: int = 0 + + @property + def hit_rate(self) -> float: + total = self.hits + self.misses + return self.hits / total if total > 0 else 0.0 + + def to_dict(self) -> dict: + return {**asdict(self), "hit_rate": 
self.hit_rate} + + +# ---------- Cache Entry ---------- +@dataclass +class CacheEntry(Generic[T]): + """Wrapper for cached values with metadata.""" + value: T + expires_at: float + created_at: float + key: str + metadata: Optional[Dict[str, Any]] = None + access_count: int = field(default=0) # Track popularity + last_accessed: float = field(default_factory=time.time) + + @property + def is_expired(self) -> bool: + return time.time() >= self.expires_at + + @property + def age_seconds(self) -> float: + return time.time() - self.created_at + + @property + def ttl_remaining(self) -> float: + """Remaining TTL in seconds (negative if expired).""" + return self.expires_at - time.time() + + +# ---------- Hash Utilities ---------- +class HashUtils: + """Centralized hashing utilities.""" + + @staticmethod + def sha256(s: str) -> str: + return hashlib.sha256(s.encode("utf-8")).hexdigest() + + @staticmethod + def json_hash(obj: Any) -> str: + """Deterministic hash of JSON-serializable objects.""" + return HashUtils.sha256(json.dumps(obj, sort_keys=True, default=str)) + + @staticmethod + def combine_hashes(*hashes: str) -> str: + """Combine multiple hashes into one.""" + return HashUtils.sha256("".join(hashes)) + + +# ---------- Prompt Normalization ---------- +def normalize_prompt(s: str) -> str: + """Normalize prompt for consistent caching.""" + return "\n".join(line.strip() for line in s.strip().splitlines() if line.strip()) + + +# ---------- Schema Signatures ---------- +def schema_signature(df: pd.DataFrame) -> dict: + """Generate schema-only signature (no data values).""" + return { + "shape": [int(x) for x in df.shape], + "rowcount": int(len(df)), + "schema": [ + { + "name": str(col), + "dtype": str(df[col].dtype), + "nullable": bool(df[col].isnull().any()), + } + for col in df.columns + ], + } + + +def weights_signature(weights: dict) -> dict: + """Normalize weights dictionary for caching.""" + return {"weights": {k: float(weights[k]) for k in sorted(weights)}} + + +# 
---------- Cache Key Builders ---------- +class CacheKeyBuilder: + """Factory for building cache keys with consistent structure (namespaced).""" + + @staticmethod + def model_response( + *, + provider: str, + model: str, + norm_prompt: str, + df_schema_sig: dict, + prompt_version: str, + temperature: float = 0.0, + ns: str = "model", + **extra_params, + ) -> str: + """ + Build key for individual model responses. + Returns a namespaced key like 'model:' so prefix invalidation works. + """ + key_data = { + "type": "model_response", + "provider": provider, + "model": model, + "prompt_hash": HashUtils.sha256(norm_prompt), + "df_schema_hash": HashUtils.json_hash(df_schema_sig), + "prompt_version": prompt_version, + "temperature": float(temperature), + } + if extra_params: + key_data["extra"] = HashUtils.json_hash(extra_params) + + digest = HashUtils.json_hash(key_data) + return f"{ns}:{digest}" + + @staticmethod + def ensemble( + *, + df_schema_sig: dict, + models: list, + weights_sig: dict, + n: int, + code_version: str, + prompt_version: str, + ns: str = "ensemble", + ) -> str: + """ + Build key for ensemble results. + Returns a namespaced key like 'ensemble:' so prefix invalidation works. + """ + digest = HashUtils.json_hash( + { + "type": "ensemble", + "df_schema_hash": HashUtils.json_hash(df_schema_sig), + "models": sorted(models), + "weights_hash": HashUtils.json_hash(weights_sig), + "n": int(n), + "code_version": code_version, + "prompt_version": prompt_version, + } + ) + return f"{ns}:{digest}" + + +# ---------- Eviction Policies ---------- +class EvictionPolicy(Protocol): + """Protocol for custom eviction strategies.""" + def select_victim(self, entries: Dict[str, CacheEntry], order: list[str]) -> Optional[str]: + """Select key to evict. Return None if nothing to evict.""" + ... 
class LRUEviction:
    """Least Recently Used eviction."""

    def select_victim(self, entries: Dict[str, CacheEntry], order: list[str]) -> Optional[str]:
        """Return the first key of ``order`` (the least recently touched), or None."""
        return order[0] if order else None


class LFUEviction:
    """Least Frequently Used eviction."""

    def select_victim(self, entries: Dict[str, CacheEntry], order: list[str]) -> Optional[str]:
        """Return the key with the smallest ``access_count``, or None if empty."""
        if not entries:
            return None
        return min(entries, key=lambda k: entries[k].access_count)


class TTLEviction:
    """Evict soonest-to-expire entry."""

    def select_victim(self, entries: Dict[str, CacheEntry], order: list[str]) -> Optional[str]:
        """Return the key with the smallest ``expires_at``, or None if empty."""
        if not entries:
            return None
        return min(entries, key=lambda k: entries[k].expires_at)


# ---------- Cache Backend ----------
class MemoryTTLCache(Generic[T]):
    """
    Thread-safe in-memory cache with TTL support and dogpile prevention.
    - Configurable eviction policy (LRU by default)
    - TTL expiry on read and via optional cleanup()
    - Per-key compute locks to avoid thundering herds on cache miss
    - Optional periodic background cleanup

    Lock order: ``_lock`` (entries, order list, stats) may be held while
    taking ``_locks_guard`` (the per-key lock table), never the reverse.
    """

    def __init__(
        self,
        capacity: int = 512,
        default_ttl: int = 3600,
        ttl_jitter_ratio: float = 0.1,
        max_value_bytes: Optional[int] = None,
        eviction_policy: Optional[EvictionPolicy] = None,
        background_cleanup_interval: Optional[int] = None,
        maintain_lru: Optional[bool] = None,  # auto: only keep order for LRU
    ):
        """
        Args:
            capacity: maximum number of entries (0 disables the capacity check)
            default_ttl: default TTL in seconds for set() without ttl_seconds
            ttl_jitter_ratio: +/- jitter applied to TTL to avoid stampedes (0.1 = ±10%)
            max_value_bytes: if set, values larger than this will not be cached
            eviction_policy: custom eviction strategy (defaults to LRU)
            background_cleanup_interval: if set, spawn background thread to cleanup every N seconds
            maintain_lru: override LRU maintenance; default True if LRUEviction else False
        """
        self.capacity = int(capacity)
        self.default_ttl = int(default_ttl)
        self.ttl_jitter_ratio = float(ttl_jitter_ratio)
        self.max_value_bytes = max_value_bytes
        self.eviction_policy = eviction_policy or LRUEviction()

        # Maintain _order list only when using LRU to reduce churn
        if maintain_lru is None:
            self._maintain_lru = isinstance(self.eviction_policy, LRUEviction)
        else:
            self._maintain_lru = bool(maintain_lru)

        self._data: Dict[str, CacheEntry[T]] = {}
        self._order: list[str] = []  # Used by LRU
        self._lock = threading.RLock()
        self._stats = CacheStats()

        # Dogpile prevention
        self._locks_guard = threading.Lock()
        self._key_locks: Dict[str, threading.Lock] = {}

        # Background cleanup
        self._cleanup_interval = background_cleanup_interval
        self._cleanup_thread: Optional[threading.Thread] = None
        self._stop_cleanup = threading.Event()
        if self._cleanup_interval:
            self._start_background_cleanup()

    # ------------- internal helpers -------------

    def _now(self) -> float:
        """Current wall-clock time in seconds."""
        return time.time()

    def _apply_ttl_jitter(self, ttl: int) -> int:
        """Return ``ttl`` shifted by a random integer within ±``ttl * ttl_jitter_ratio``."""
        if self.ttl_jitter_ratio <= 0:
            return ttl
        delta = int(ttl * self.ttl_jitter_ratio)
        if delta <= 0:
            return ttl
        return ttl + random.randint(-delta, +delta)

    def _value_size_ok(self, value: Any) -> bool:
        """Return False only when the value is measurably larger than ``max_value_bytes``."""
        if self.max_value_bytes is None:
            return True
        try:
            # Estimate size via JSON serialization
            payload = json.dumps(value, default=str)
            size = len(payload.encode("utf-8"))
        except Exception as e:  # best-effort estimate: never block caching on it
            logger.warning("Could not estimate value size: %s", e)
            return True  # Allow if we can't measure
        if size > self.max_value_bytes:
            logger.debug("Value size %s exceeds limit %s", size, self.max_value_bytes)
            return False
        return True

    def _touch(self, key: str) -> None:
        """Update access tracking for LRU and popularity counters."""
        entry = self._data.get(key)
        if entry:
            entry.access_count += 1
            entry.last_accessed = self._now()
        if self._maintain_lru:
            try:
                self._order.remove(key)
            except ValueError:
                pass
            self._order.append(key)

    def _evict(self, key: str, *, reason: str, cleanup_locks: bool = True) -> None:
        """Remove an entry and update stats. Caller must hold ``_lock``.

        Args:
            key: entry to remove (missing keys are tolerated).
            reason: ``"expiration"`` counts as an expiration; anything else
                counts as an eviction.
            cleanup_locks: prune stale per-key locks afterwards. Bulk callers
                pass False and prune once at the end instead of per key.
        """
        self._data.pop(key, None)
        if self._maintain_lru:
            try:
                self._order.remove(key)
            except ValueError:
                pass

        if reason == "expiration":
            self._stats.expirations += 1
        else:
            self._stats.evictions += 1
        self._stats.total_size = len(self._data)

        if cleanup_locks:
            self._cleanup_key_locks()

    def _evict_victim(self) -> bool:
        """Evict one entry based on policy. Returns False if the cache is empty.

        If the policy yields no usable key (None, or a key that is not cached,
        e.g. LRU with order maintenance disabled), the oldest inserted entry is
        evicted instead so the capacity loop in ``set`` always makes progress.
        """
        if not self._data:
            return False
        victim = self.eviction_policy.select_victim(self._data, self._order)
        if victim is None or victim not in self._data:
            victim = next(iter(self._data))
        logger.debug("Evicting victim: %s...", victim[:16])
        self._evict(victim, reason="capacity")
        return True

    def _get_key_lock(self, key: str) -> threading.Lock:
        """Get or create per-key lock for dogpile prevention."""
        with self._locks_guard:
            lock = self._key_locks.get(key)
            if lock is None:
                lock = threading.Lock()
                self._key_locks[key] = lock
            return lock

    def _cleanup_key_locks(self) -> None:
        """Remove locks for keys no longer in cache.

        Locks that are currently held are kept: they belong to an in-flight
        ``get_or_compute`` whose key is not cached *yet*. Dropping them would
        hand the next caller a fresh lock and let it compute concurrently.
        """
        with self._locks_guard:
            stale_keys = [
                k
                for k, lock in self._key_locks.items()
                if k not in self._data and not lock.locked()
            ]
            for k in stale_keys:
                self._key_locks.pop(k, None)

    def _background_cleanup_loop(self) -> None:
        """Background thread body: purge expired entries every interval until stopped."""
        while not self._stop_cleanup.wait(self._cleanup_interval):
            try:
                removed = self.cleanup_expired()  # also prunes stale key locks
                if removed > 0:
                    logger.info("Background cleanup removed %s expired entries", removed)
            except Exception as e:  # keep the daemon alive; report and retry next tick
                logger.error("Background cleanup error: %s", e, exc_info=True)

    def _start_background_cleanup(self) -> None:
        """Start background cleanup thread (no-op if one is already running)."""
        if self._cleanup_thread and self._cleanup_thread.is_alive():
            return

        self._stop_cleanup.clear()
        self._cleanup_thread = threading.Thread(
            target=self._background_cleanup_loop,
            daemon=True,
            name="CacheCleanup",
        )
        self._cleanup_thread.start()
        logger.info("Started background cleanup (interval: %ss)", self._cleanup_interval)

    def _stop_background_cleanup(self) -> None:
        """Signal the cleanup thread to stop and wait up to 5 seconds for it."""
        thread = self._cleanup_thread
        if thread and thread.is_alive():
            self._stop_cleanup.set()
            # A thread cannot join itself; just let the loop observe the event.
            if thread is not threading.current_thread():
                thread.join(timeout=5.0)
            logger.info("Stopped background cleanup")

    # ------------- public API -------------

    def peek(self, key: str) -> Optional[T]:
        """Return value if present and valid WITHOUT affecting LRU order or stats."""
        with self._lock:
            entry = self._data.get(key)
            if entry is None or entry.is_expired:
                return None
            return entry.value

    def get(self, key: str) -> Optional[T]:
        """Get value from cache, None if missing or expired.

        Expired entries are evicted on read. A hit refreshes access tracking
        and, when the entry's metadata carries ``last_compute_ms``, credits
        that amount to the "compute time saved" statistic.
        """
        with self._lock:
            entry = self._data.get(key)
            if entry is None:
                self._stats.misses += 1
                return None

            if entry.is_expired:
                logger.debug("Cache expired: %s (age=%.1fs)", key[:16], entry.age_seconds)
                self._evict(key, reason="expiration")
                self._stats.misses += 1
                return None

            self._touch(key)
            self._stats.hits += 1

            # credit "time saved" if we know last compute cost
            last_ms = (entry.metadata or {}).get("last_compute_ms")
            if last_ms is not None:
                try:
                    self._stats.total_compute_time_saved_ms += float(last_ms)
                except (TypeError, ValueError):
                    # Caller-supplied metadata was not numeric; skip the credit.
                    pass

            logger.debug(
                "Cache hit: %s (age=%.1fs, ttl=%.1fs)",
                key[:16], entry.age_seconds, entry.ttl_remaining
            )
            return entry.value

    def set(
        self,
        key: str,
        value: T,
        ttl_seconds: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store value in cache with TTL.

        The TTL (``ttl_seconds`` or the default) is jittered and clamped to at
        least one second. Values rejected by the size check are silently not
        cached. After inserting, entries are evicted until the cache is back
        within ``capacity``.
        """
        ttl = int(ttl_seconds if ttl_seconds is not None else self.default_ttl)
        ttl = max(1, self._apply_ttl_jitter(ttl))
        now = self._now()

        if not self._value_size_ok(value):
            logger.debug("Cache skip (value too large): %s", key[:16])
            return

        with self._lock:
            is_update = key in self._data
            self._data[key] = CacheEntry[T](
                value=value,
                expires_at=now + ttl,
                created_at=now,
                key=key,
                metadata=metadata,
            )
            if is_update:
                self._touch(key)
            elif self._maintain_lru:
                self._order.append(key)

            self._stats.sets += 1

            # Evict if over capacity; stop if nothing could be evicted.
            while self.capacity and len(self._data) > self.capacity:
                if not self._evict_victim():
                    break

            # Recorded after eviction so the stat reflects the settled size.
            self._stats.total_size = len(self._data)

        logger.debug("Cache set: %s (ttl=%ss)", key[:16], ttl)

    @contextmanager
    def timed_compute(self):
        """
        Context manager to time compute and expose elapsed ms via yielded dict.
        The ``"ms"`` key is filled in when the ``with`` block exits.
        Usage:
            with cache.timed_compute() as t:
                value = compute()
            cache.set(key, value, metadata={"last_compute_ms": t["ms"]})
        """
        start = time.perf_counter()
        box: Dict[str, float] = {}
        try:
            yield box
        finally:
            box["ms"] = (time.perf_counter() - start) * 1000.0

    def _compute_and_store(
        self,
        key: str,
        compute_fn: Callable[[], T],
        ttl_seconds: Optional[int],
        metadata: Optional[Dict[str, Any]],
    ) -> T:
        """Run ``compute_fn``, cache its result with timing metadata, and return it."""
        with self.timed_compute() as t:
            value = compute_fn()
        meta = dict(metadata or {})
        meta.setdefault("last_compute_ms", t.get("ms", 0.0))
        self.set(key, value, ttl_seconds=ttl_seconds, metadata=meta)
        return value

    def get_or_compute(
        self,
        key: str,
        compute_fn: Callable[[], T],
        ttl_seconds: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> T:
        """
        Dogpile-safe: only one thread computes value for a key on miss.
        Others wait and get the fresh value.

        A cached ``None`` is indistinguishable from a miss, so ``compute_fn``
        results of ``None`` are recomputed on every call. Exceptions raised by
        ``compute_fn`` propagate; the per-key lock is always released.
        """
        # Fast path
        cached = self.get(key)
        if cached is not None:
            return cached

        lock = self._get_key_lock(key)
        if not lock.acquire(blocking=False):
            with self._lock:
                self._stats.lock_contentions += 1
            lock.acquire()  # Wait for the computing thread

        try:
            # Re-check under the key lock: another thread may have filled it.
            cached = self.get(key)
            if cached is not None:
                return cached
            return self._compute_and_store(key, compute_fn, ttl_seconds, metadata)
        finally:
            lock.release()

    def invalidate(self, key: str) -> bool:
        """Remove specific key from cache. Returns True if key existed."""
        with self._lock:
            existed = key in self._data
            if existed:
                self._evict(key, reason="manual")
            else:
                self._cleanup_key_locks()
            return existed

    def invalidate_pattern(self, prefix: str) -> int:
        """
        Invalidate all keys starting with prefix. Returns count removed.
        Works because keys are namespaced (e.g., 'ensemble:').
        """
        with self._lock:
            keys_to_remove = [k for k in self._data if k.startswith(prefix)]
            for key in keys_to_remove:
                self._evict(key, reason="manual", cleanup_locks=False)
            self._cleanup_key_locks()
            return len(keys_to_remove)

    def clear(self) -> None:
        """Clear all cache entries (statistics counters other than size are kept)."""
        with self._lock:
            self._data.clear()
            self._order.clear()
            self._stats.total_size = 0
            # Goes through the guarded helper so in-flight compute locks survive.
            self._cleanup_key_locks()
            logger.info("Cache cleared")

    def get_stats(self) -> CacheStats:
        """Get current cache statistics (snapshot)."""
        with self._lock:
            return CacheStats(**asdict(self._stats))

    def cleanup_expired(self) -> int:
        """Remove all expired entries. Returns count of removed entries."""
        now = self._now()
        with self._lock:
            expired = [k for k, entry in self._data.items() if now >= entry.expires_at]
            for key in expired:
                self._evict(key, reason="expiration", cleanup_locks=False)
            self._cleanup_key_locks()
            return len(expired)

    def get_entry_info(self, key: str) -> Optional[Dict[str, Any]]:
        """Get metadata about a cached entry without retrieving the value.

        Returns None for unknown keys. Expired-but-not-yet-evicted entries are
        still reported (with a negative ``ttl_remaining``).
        """
        with self._lock:
            entry = self._data.get(key)
            if entry is None:
                return None
            return {
                "key": key,
                "created_at": entry.created_at,
                "expires_at": entry.expires_at,
                "age_seconds": entry.age_seconds,
                "ttl_remaining": entry.ttl_remaining,
                "access_count": entry.access_count,
                "last_accessed": entry.last_accessed,
                "metadata": entry.metadata,
            }

    def close(self) -> None:
        """Explicitly stop background cleanup thread (if any)."""
        self._stop_background_cleanup()

    def __del__(self):
        """Best-effort cleanup on GC (not guaranteed)."""
        try:
            self.close()
        except Exception:
            # Deliberately silent: the object may be partially initialised or
            # the interpreter may be shutting down.
            pass


# ---------- Cache Client ----------
class CacheClient(Generic[T]):
    """High-level cache interface with optional backend swapping.

    When ``enabled`` is False every read behaves as a miss, writes and
    invalidations are no-ops, and ``get_or_compute`` always calls the
    compute function.
    """

    def __init__(
        self,
        backend: Optional[MemoryTTLCache[T]] = None,
        enabled: bool = True,
        log_hits: bool = False,
    ):
        """
        Args:
            backend: cache backend; a default ``MemoryTTLCache`` is created if omitted
            enabled: master switch for all cache operations
            log_hits: log HIT/MISS for every ``get`` at INFO level
        """
        self.backend = backend or MemoryTTLCache()
        self.enabled = enabled
        self.log_hits = log_hits

    def get(self, key: str) -> Optional[T]:
        """Retrieve value from cache (None on miss)."""
        if not self.enabled:
            return None
        value = self.backend.get(key)
        if self.log_hits:
            logger.info("Cache %s: %s", "HIT" if value is not None else "MISS", key[:16])
        return value

    def peek(self, key: str) -> Optional[T]:
        """Peek without affecting LRU or stats."""
        if not self.enabled:
            return None
        return self.backend.peek(key)

    def set(
        self,
        key: str,
        value: T,
        ttl_seconds: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store value in cache."""
        if not self.enabled:
            return
        self.backend.set(key, value, ttl_seconds, metadata)

    def get_or_compute(
        self,
        key: str,
        compute_fn: Callable[[], T],
        ttl_seconds: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> T:
        """Get from cache or compute and store if missing (dogpile-safe)."""
        if not self.enabled:
            return compute_fn()
        return self.backend.get_or_compute(key, compute_fn, ttl_seconds, metadata)

    def invalidate(self, key: str) -> bool:
        """Remove specific key from cache."""
        if not self.enabled:
            return False
        return self.backend.invalidate(key)

    def invalidate_pattern(self, prefix: str) -> int:
        """Invalidate all keys with given prefix (e.g., 'ensemble:')."""
        if not self.enabled:
            return 0
        return self.backend.invalidate_pattern(prefix)

    def clear(self) -> None:
        """Clear entire cache."""
        if not self.enabled:
            return
        self.backend.clear()

    def stats(self) -> dict:
        """Current cache statistics as dict (empty when disabled)."""
        if not self.enabled:
            return {}
        return self.backend.get_stats().to_dict()

    def cleanup(self) -> int:
        """Clean up expired entries."""
        if not self.enabled:
            return 0
        return self.backend.cleanup_expired()

    def info(self, key: str) -> Optional[Dict[str, Any]]:
        """Get info about a cache entry."""
        if not self.enabled:
            return None
        return self.backend.get_entry_info(key)

    def close(self) -> None:
        """Close underlying backend (stop background cleanup).

        Runs even when the client is disabled: the backend may still own a
        background cleanup thread that has to be stopped.
        """
        self.backend.close()


# ---------- Convenience Factory ----------
def create_cache(
    capacity: int = 512,
    default_ttl: int = 3600,
    enabled: bool = True,
    log_hits: bool = False,
    ttl_jitter_ratio: float = 0.1,
    max_value_bytes: Optional[int] = None,
    eviction_policy: Optional[EvictionPolicy] = None,
    background_cleanup_interval: Optional[int] = None,
) -> CacheClient[Any]:
    """
    Factory function to create a configured cache client.

    Args:
        capacity: max entries
        default_ttl: seconds
        enabled: master switch
        log_hits: log HIT/MISS at INFO
        ttl_jitter_ratio: ±ratio jitter to TTLs
        max_value_bytes: if set, avoid caching huge values
        eviction_policy: LRU (default), LFU, or TTL
        background_cleanup_interval: auto-cleanup every N seconds

    Returns:
        A ``CacheClient`` wrapping a freshly built ``MemoryTTLCache``.
    """
    backend = MemoryTTLCache(
        capacity=capacity,
        default_ttl=default_ttl,
        ttl_jitter_ratio=ttl_jitter_ratio,
        max_value_bytes=max_value_bytes,
        eviction_policy=eviction_policy,
        background_cleanup_interval=background_cleanup_interval,
    )
    return CacheClient(backend=backend, enabled=enabled, log_hits=log_hits)