diff --git a/plotsense.egg-info/PKG-INFO b/plotsense.egg-info/PKG-INFO index d9dde76..23b086a 100644 --- a/plotsense.egg-info/PKG-INFO +++ b/plotsense.egg-info/PKG-INFO @@ -1,10 +1,11 @@ Metadata-Version: 2.4 Name: plotsense -Version: 0.1.0 +Version: 0.1.3 Summary: An intelligent plotting package with suggestions and explanations -Author-email: Christian Chimezie -License: MIT -Project-URL: Homepage, https://github.com/christianchimezie/PlotSenseAI +Home-page: https://github.com/christianchimezie/PlotSenseAI +Author: Christian Chimezie, Toluwaleke Ogidan, Grace Farayola, Amaka Iduwe, Nelson Ogbeide, Onyekachukwu Ojumah, Olamilekan Ajao +Author-email: chimeziechristiancc@gmail.com, gbemilekeogidan@gmail.com, gracefarayola@gmail.com, nwaamaka_iduwe@yahoo.com, Ogbeide331@gmail.com, Onyekaojumah22@gmail.com, olamilekan011@gmail.com +License: Apache License 2.0 Project-URL: Documentation, https://github.com/christianchimezie/PlotSenseAI/blob/main/README.md Classifier: Development Status :: 3 - Alpha Classifier: Intended Audience :: Science/Research @@ -21,11 +22,28 @@ Requires-Python: >=3.8 Description-Content-Type: text/markdown License-File: LICENCE License-File: NOTICE -Requires-Dist: matplotlib>=3.0 +Requires-Dist: matplotlib>=3.8.0 Requires-Dist: seaborn>=0.11 Requires-Dist: pandas>=1.0 Requires-Dist: numpy>=1.18 +Requires-Dist: python-dotenv +Requires-Dist: groq +Requires-Dist: anthropic +Requires-Dist: openai +Requires-Dist: google-genai +Requires-Dist: requests +Dynamic: author +Dynamic: author-email +Dynamic: classifier +Dynamic: description +Dynamic: description-content-type +Dynamic: home-page +Dynamic: license Dynamic: license-file +Dynamic: project-url +Dynamic: requires-dist +Dynamic: requires-python +Dynamic: summary # 🌟 PlotSense: AI-Powered Data Visualization Assistant @@ -51,7 +69,7 @@ pip install plotsense ```bash import plotsense as ps -from plotsense import recommender, generate_plot, explainer, +from plotsense import recommender, plotgen, explainer ``` ### πŸ” Authenticate with Groq API: Get your free API key from Groq Cloud https://console.groq.com/home @@ -82,7 +100,7 @@ print(suggestions) ``` ### πŸ“Š Sample Output: -![alt text](suggestions_table.png) +![alt text](image.png) πŸŽ›οΈ Want more suggestions? @@ -90,7 +108,30 @@ print(suggestions) suggestions = ps.recommender(df, n=10) ``` -### 🧾 2. AI-Powered Plot Explanation +### πŸ“ˆ 2. One-Click Plot Generation +Generate recommended charts instantly: + +```bash +plot1 = ps.plotgen(df, suggestions.iloc[0]) # This will plot a bar chart with variables 'survived', 'pclass' +plot2 = ps.plotgen(df, suggestions.iloc[1]) # This will plot a bar chart with variables 'survived', 'sex' +plot3 = ps.plotgen(df, suggestions.iloc[2]) # This will plot a histogram with variable 'age' +``` +πŸŽ›οΈ Want more control? + +``` bash +plot1 = ps.plotgen(df, suggestions.iloc[0], x='pclass', y='survived') +``` +Supported Plots +- scatter +- bar +- barh +- histogram +- boxplot +- violinplot +- pie +- hexbin + +### 🧾 3. AI-Powered Plot Explanation Turn your visualizations into stories with natural language insights: ``` bash @@ -103,7 +144,7 @@ print(explanation) - Custom Prompts: You can provide your own prompt to guide the explanation ``` bash -explanation = refine_plot_explanation( +explanation = explainer( fig, prompt="Explain the key trends in this sales data visualization" ) @@ -111,7 +152,14 @@ explanation = refine_plot_explanation( - Multiple Refinement Iterations: Increase the number of refinement cycles for more polished explanations: ```bash -explanation = refine_plot_explanation(fig, iterations=3) # Default is 2 +explanation = explainer(fig, max_iterations=3) # Default is 2 +``` + +## πŸ”„ Combined Workflow: Suggest β†’ Plot β†’ Explain +``` bash +suggestions = ps.recommender(df) +plot = ps.plotgen(df, suggestions.iloc[0]) +insight = ps.explainer(plot) ``` ## 🀝 Contributing @@ -131,13 +179,15 @@ We welcome contributions! - More model integrations - Automated insight highlighting - Jupyter widget support +- Features/target analysis +- More supported plots ### πŸ“₯ Install or Update ``` bash pip install --upgrade plotsense # Get the latest features! ``` ## πŸ›‘ License -MIT License (Open Source) +Apache License 2.0 ## πŸ” API & Privacy Notes - Your API key is securely held in memory for your current Python session. @@ -146,3 +196,12 @@ MIT License (Open Source) Let your data speakβ€”with clarity, power, and PlotSense. πŸ“Šβœ¨ + +## Your Feedback +[Feedback Form](https://forms.gle/QEjipzHiMagpAQU99) + + + + + + diff --git a/plotsense.egg-info/SOURCES.txt b/plotsense.egg-info/SOURCES.txt index a251ad9..0edb61d 100644 --- a/plotsense.egg-info/SOURCES.txt +++ b/plotsense.egg-info/SOURCES.txt @@ -1,3 +1,5 @@ +LICENCE +NOTICE README.md pyproject.toml setup.py @@ -10,6 +12,24 @@ plotsense.egg-info/top_level.txt plotsense/explanations/__init__.py plotsense/explanations/explanations.py plotsense/plot_generator/__init__.py +plotsense/plot_generator/base_generator.py +plotsense/plot_generator/basic_generator.py plotsense/plot_generator/generator.py +plotsense/plot_generator/helpers.py +plotsense/plot_generator/registry.py +plotsense/plot_generator/smart_generator.py +plotsense/plot_generator/plots/__init__.py +plotsense/visual_suggestion/__init__.py plotsense/visual_suggestion/suggestions.py -plotsense/visual_suggestion/__init__.py \ No newline at end of file +plotsense/visual_suggestion/recommender/__init__.py +plotsense/visual_suggestion/recommender/dataframe_analyzer.py +plotsense/visual_suggestion/recommender/ensemble_scorer.py +plotsense/visual_suggestion/recommender/prompt_builder.py +plotsense/visual_suggestion/recommender/response_parser.py +plotsense/visual_suggestion/recommender/visualization_recommender.py +test/__init__.py +test/my_ptce_test.py +test/my_test.py +test/test_explanations.py +test/test_plotgen.py +test/test_suggestions.py \ No newline at end of file diff --git a/plotsense.egg-info/requires.txt b/plotsense.egg-info/requires.txt index 6fbc6f2..a329864 100644 --- a/plotsense.egg-info/requires.txt +++ b/plotsense.egg-info/requires.txt @@ -1,4 +1,10 @@ -matplotlib>=3.0 +matplotlib>=3.8.0 seaborn>=0.11 pandas>=1.0 numpy>=1.18 +python-dotenv +groq +anthropic +openai +google-genai +requests diff --git a/plotsense.egg-info/top_level.txt b/plotsense.egg-info/top_level.txt index 100c7e8..f5522c3 100644 --- a/plotsense.egg-info/top_level.txt +++ b/plotsense.egg-info/top_level.txt @@ -1 +1,2 @@ plotsense +test diff --git a/plotsense/__init__.py b/plotsense/__init__.py index 301ab30..25b7b41 100644 --- a/plotsense/__init__.py +++ b/plotsense/__init__.py @@ -1,3 +1,3 @@ from plotsense.visual_suggestion.suggestions import recommender, VisualizationRecommender -from plotsense.explanations.explanations import explainer,PlotExplainer -from plotsense.plot_generator.generator import plotgen, PlotGenerator \ No newline at end of file +from plotsense.explanations.explanations import explainer, PlotExplainer +from plotsense.plot_generator.generator import plotgen, BasicPlotGenerator, SmartPlotGenerator diff --git a/plotsense/core/ai_interface.py b/plotsense/core/ai_interface.py new file mode 100644 index 0000000..eb17f22 --- /dev/null +++ b/plotsense/core/ai_interface.py @@ -0,0 +1,423 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed +import warnings +from typing import Dict, List, Optional, Tuple + +from plotsense.core.enums.strategy import StrategyName +from plotsense.core.strategies.round_robin import RoundRobinStrategy +from plotsense.core.strategies.cost_optimized import CostOptimizedStrategy +from plotsense.core.strategies.performance_optimized import PerformanceOptimizedStrategy +from plotsense.core.strategies.fallback_chain import FallbackChainStrategy + + +class AIModelInterface: + """ + Handles all low-level interactions with LLM providers. + Acts as a bridge between PlotExplainer (or any client) + and ProviderManager. + """ + + def __init__(self, provider_manager, timeout: int = 30): + self.manager = provider_manager + self.timeout = timeout + + def _init_strategy( + self, strategy_name: StrategyName, + available_models: List[Tuple[str, str]] + ): + try: + strategy_name = StrategyName(strategy_name) + except ValueError: + raise ValueError(f"Invalid strategy name: {strategy_name}") + + if strategy_name == StrategyName.ROUND_ROBIN: + return RoundRobinStrategy(available_models) + elif strategy_name == StrategyName.COST_OPTIMIZED: + return CostOptimizedStrategy(available_models, self.manager) + elif strategy_name == StrategyName.PERFORMANCE: + return PerformanceOptimizedStrategy(available_models, self.manager) + elif strategy_name == StrategyName.FALLBACK_CHAIN: + return FallbackChainStrategy(available_models) + + def query_all_models( + self, + prompt: str, + debug: bool = False, + base64_image: Optional[str] = None, + custom_parameters: Optional[Dict] = None, + strategy: StrategyName = StrategyName.ROUND_ROBIN, + max_workers: int = 6, + ) -> Dict[str, str]: + """ + Query all available models (across all providers) in parallel. + Uses ThreadPoolExecutor for concurrency. + Returns a mapping of "provider:model" -> response_text. + + Notes: + - Keeps strategy initialization for compatibility (strategy instance + can be used later to reorder or filter models). + - Each model is queried independently; failures don't stop the rest. + """ + # Get available models as list of tuples: [(provider, model_name), ...] + all_models = self.manager.list_all_models() + self.available_models = [ + (provider, model) + for provider, models in all_models.items() + for model in models + ] + if not self.available_models: + raise ValueError("No available models found from provider manager.") + + # Initialize strategy instance (keeps previous behavior) + strategy_instance = self._init_strategy( + strategy, self.available_models + ) + + results: Dict[str, str] = {} + + # --- 1️⃣ Let strategy select or order models --- + # Most strategies (RoundRobin, CostOptimized, etc.) will implement a method + # like `.select_models(n: int)` or `.get_next_batch()`. + # If not, we simply use all available models. + try: + # Example interface: select_models returns a prioritized list + models_to_query = strategy_instance.select_model(len(self.available_models)) + except AttributeError: + # Fallback: strategy not yet implementing selection + models_to_query = self.available_models + + if not models_to_query: + raise ValueError("Strategy did not return any models to query.") + + if debug: + print(f"\n[DEBUG] Strategy '{strategy_instance.__class__.__name__}' selected models:") + for prov, mod in models_to_query: + print(f" - {prov}:{mod}") + + def _query_one(provider: str, model_name: str): + model_key = f"{provider}:{model_name}" + try: + resp = self.query_model( + provider=provider, + model=model_name, + prompt=prompt, + base64_image=base64_image, + custom_parameters=custom_parameters, + ) + return model_key, resp + except Exception as e: + warnings.warn(f"[AIModelInterface] Query failed for {model_key} -> {e}") + return model_key, f"Error: {e}" + + # FallbackChainStrategy -> sequential queries until one succeeds + if isinstance(strategy_instance, FallbackChainStrategy): + for provider, model_name in models_to_query: + key, resp = _query_one(provider, model_name) + results[key] = resp + if not resp.lower().startswith("error"): + # Stop at first success (fallback semantics) + break + else: + # Run queries concurrently + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_key = { + executor.submit(_query_one, provider, model_name): (provider, model_name) + for provider, model_name in self.available_models + } + + for future in as_completed(future_to_key): + key, resp = future.result() + results[key] = resp + + return results + + def query_model( + self, + provider: str, + model: str, + prompt: str, + base64_image: Optional[str] = None, + custom_parameters: Optional[Dict] = None + ) -> str: + """ + Query a model via the provider manager. + Handles provider-specific formatting and error management. + """ + if provider not in self.manager.providers: + raise ValueError(f"Unknown provider: {provider}") + + try: + # Build messages depending on provider/model + messages = self._build_messages( + provider, model, prompt, base64_image + ) + generation_params = {"temperature": 0.7, "max_tokens": 1000, **(custom_parameters or {})} + + provider_lower = provider.lower() + # model_lower = model.lower() + + # -------------------- OPENAI (Chat + Response) -------------------- + if "openai" in provider_lower: + # if "gpt" in model_lower or "chat" in model_lower: + if "chat" in provider_lower: + # Chat-based models (GPT-4, GPT-3.5, etc.) + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + **generation_params, + ) + elif "response" in provider_lower: + # Response-based models (completion endpoints) + return self.manager.query( + provider, + model=model, + prompt=prompt, + **generation_params, + ) + + # -------------------- AZURE OPENAI -------------------- + elif "azure" in provider_lower: + # Azure follows OpenAI's API style but requires deployment-specific name + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + **generation_params, + ) + + # -------------------- GROQ -------------------- + elif "groq" in provider_lower: + # Typically text-only Llama-style models + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + **generation_params, + ) + + # -------------------- ANTHROPIC -------------------- + elif "anthropic" in provider_lower: + # Claude models (text + multimodal optional) + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + **generation_params, + ) + + # -------------------- GEMINI -------------------- + elif "gemini" in provider_lower: + # Supports text + images + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + image=base64_image, + **generation_params, + ) + + # -------------------- OLLAMA -------------------- + elif "ollama" in provider_lower: + # Local models; prompt only, may support images if model allows + return self.manager.query( + provider, + model=model, + prompt=prompt, + image=base64_image, + **generation_params, + ) + + # -------------------- DEFAULT / UNKNOWN -------------------- + else: + print(f"[AIModelInterface] Warning: Using default query handling for {provider}:{model}") + # Fallback for new or custom providers + return self.manager.query( + provider, + model=model, + messages=messages, + prompt=prompt, + **generation_params, + ) + + except Exception as e: + warnings.warn(f"[AIModelInterface] Querying error for {provider}:{model} -> {str(e)}") + return f"Error: {e}" + finally: + return f"Error: No valid query handler found for provider '{provider}'." + + def get_model_weights(self) -> Dict[str, float]: + """ + Return model weighting for ensemble scoring. + + Weighting strategy (default heuristics): + - OpenAI GPT-4 variants -> higher weight (2.0) + - Anthropic Claude family -> high weight (1.8) + - Google Gemini -> high weight (1.6) + - Azure (OpenAI in Azure) -> treated similar to openai (1.8 for gpt-4) + - Groq (Llama variants) -> moderate weight (1.2) + - Ollama / local models -> lower/moderate weight (1.0) + - Other / unknown -> base weight (1.0) + + Returns: + dict of "provider:model" -> normalized_weight + """ + all_models = self.manager.list_all_models() + self.available_models = [ + (provider, model) + for provider, models in all_models.items() + for model in models + ] + + raw_weights: Dict[str, float] = {} + + for provider, model_name in self.available_models: + key = f"{provider}:{model_name}" + lname = model_name.lower() + lprov = provider.lower() + + # Base preference by model name + if "gpt-4" in lname or "gpt4" in lname or "gpt-4o" in lname: + base = 2.0 + elif "claude" in lname: + base = 1.8 + elif "gemini" in lname or "gemini-pro" in lname: + base = 1.6 + elif "llama" in lname or "groq" in lprov: + # groq's Llama-based models - decent but not highest + base = 1.2 + elif "azure" in lprov: + # Azure OpenAI often runs OpenAI models; favor if contains gpt-4 + base = 1.8 if "gpt-4" in lname or "gpt4" in lname else 1.1 + elif "ollama" in lprov: + base = 1.0 + else: + base = 1.0 + + # Provider-level adjustments (optional) + if lprov == "anthropic": + base *= 1.0 # already accounted by 'claude' checks + if lprov == "openai": + base *= 1.0 + if lprov == "groq": + base *= 1.0 + if lprov == "azure": + base *= 1.0 + + raw_weights[key] = base + + # Normalize to sum to 1 + total = sum(raw_weights.values()) or 1.0 + normalized = {k: (v / total) for k, v in raw_weights.items()} + return normalized + + def _build_messages( + self, provider: str, model: str, prompt: str, + base64_image: Optional[str] = None + ): + """ + Build messages dynamically based on provider capabilities. + Supports multimodal input where possible (OpenAI GPT-4o, Gemini, Anthropic, etc.). + Falls back to text-only prompt for providers without image support. + """ + provider_lower = provider.lower() + model_lower = model.lower() + + # --- 1️⃣ OpenAI / Azure (GPT-4, GPT-4o, GPT-3.5 etc.) --- + if provider_lower in {"openai", "azure"}: + if base64_image and any(tag in model_lower for tag in ["gpt-4o", "gpt-4-turbo", "gpt-4-vision"]): + # Chat message with multimodal support + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, + ], + } + ] + else: + # Standard chat completion format + return [ + {"role": "system", "content": "You are a helpful data visualization assistant."}, + {"role": "user", "content": prompt}, + ] + + # --- 2️⃣ Anthropic (Claude) --- + elif provider_lower == "anthropic": + # Claude supports multimodal via text + image blocks in messages + if base64_image: + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_image}}, + ], + } + ] + else: + return [ + {"role": "user", "content": prompt} + ] + + # --- 3️⃣ Gemini (Google) --- + elif provider_lower == "gemini": + # Gemini API supports multimodal via a combined structure + if base64_image: + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image", "data": base64_image, "mime_type": "image/jpeg"}, + ], + } + ] + else: + return [ + {"role": "user", "content": prompt} + ] + + # --- 4️⃣ Groq (LLaMA / Mistral etc. – text-only) --- + elif provider_lower == "groq": + return [ + {"role": "user", "content": prompt} + ] + + # --- 5️⃣ Ollama (local models; may support image, but prompt-based) --- + elif provider_lower == "ollama": + if base64_image: + # Send inline text prompt mentioning image context + return [ + { + "role": "user", + "content": f"{prompt}\n\n[Image attached as base64 input]" + } + ] + else: + return [ + {"role": "user", "content": prompt} + ] + + # --- 6️⃣ Default / Unknown provider fallback --- + else: + if base64_image: + return [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, + ], + } + ] + else: + return [ + {"role": "user", "content": prompt} + ] + diff --git a/plotsense/core/enums/strategy.py b/plotsense/core/enums/strategy.py new file mode 100644 index 0000000..61dbecc --- /dev/null +++ b/plotsense/core/enums/strategy.py @@ -0,0 +1,8 @@ +from enum import Enum + +class StrategyName(str, Enum): + ROUND_ROBIN = "round_robin" + COST_OPTIMIZED = "cost_optimized" + PERFORMANCE = "performance" + FALLBACK_CHAIN = "fallback" + diff --git a/plotsense/core/providers/anthropic.py b/plotsense/core/providers/anthropic.py new file mode 100644 index 0000000..a24e290 --- /dev/null +++ b/plotsense/core/providers/anthropic.py @@ -0,0 +1,75 @@ +from typing import List +from anthropic import Anthropic +from .base import LLMProvider + + +class AnthropicProvider(LLMProvider): + """Provider integration for Anthropic's Claude models.""" + + LINK = "πŸ‘‰ https://console.anthropic.com/account/keys πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize Anthropic client if not already created.""" + if not self.client: + self.client = Anthropic(api_key=self.api_key) + + def query(self, prompt: str, model: str, **kwargs) -> str: + """Send a message to Anthropic Claude model and return its response text.""" + if not self.client: + raise ValueError( + "Anthropic client not initialized. Call validate_key() first." + ) + + messages = kwargs.pop("messages", None) + if not messages and prompt: + # Default to a single user message + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + response = self.client.messages.create( + model=model, + messages=[{"role": "user", "content": prompt}], + **kwargs, + ) + return response.content[0].text if response and response.content else "" + except Exception as e: + raise RuntimeError(f"Anthropic query failed: {e}") + + def list_models(self) -> List[str]: + """ + Return a list of supported Anthropic models. + This list can be expanded as new Claude versions are released. + """ + return [ + "claude-3-5-sonnet-20241022", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + ] + + def validate_key(self) -> bool: + """ + Validate the provided API key by performing a lightweight test. + Returns True if successful, False otherwise. + """ + try: + self._init_client() + if not self.client: + raise ValueError( + "Anthropic client not initialized. Call validate_key() first." + ) + # Perform a trivial, cheap call to verify authentication + self.client.messages.create( + model="claude-3-haiku-20240307", + messages=[{"role": "user", "content": "ping"}], + max_tokens=5, + ) + return True + except Exception: + return False + diff --git a/plotsense/core/providers/azure_openai.py b/plotsense/core/providers/azure_openai.py new file mode 100644 index 0000000..4a81200 --- /dev/null +++ b/plotsense/core/providers/azure_openai.py @@ -0,0 +1,98 @@ +from typing import List +from openai import OpenAI +# AzureOpenAI, +from openai.types.chat import ChatCompletionUserMessageParam +from .base import LLMProvider + + +class AzureOpenAIProvider(LLMProvider): + """Provider integration for Azure-hosted OpenAI models.""" + + LINK = "πŸ‘‰ https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/feature/OpenAI πŸ‘ˆ" + + def __init__( + self, api_key: str, + endpoint: str = "https://models.github.ai/inference", + api_version: str = "2024-02-15-preview" + ): + """ + Args: + api_key: Azure OpenAI API key + endpoint: Full Azure endpoint (e.g. https://.openai.azure.com/) + api_version: Azure OpenAI API version + """ + self.api_key = api_key + self.endpoint = endpoint + # self.api_version = api_version + self.client = None + + def _init_client(self): + """Initialize the Azure OpenAI client.""" + if not self.endpoint: + raise ValueError("Azure OpenAI endpoint not provided.") + if not self.client: + self.client = OpenAI( + api_key=self.api_key, + # api_version=self.api_version, + # azure_endpoint=self.endpoint, + base_url=self.endpoint, + ) + + def query(self, prompt: str, model: str, **kwargs) -> str: + """ + Send a prompt to Azure OpenAI Chat Completion API. + """ + self._init_client() + if not self.client: + raise ValueError("AzureOpenAI client not initialized. Call validate_key() first.") + + # Ensure messages format exists in kwargs + messages: list[ChatCompletionUserMessageParam] = kwargs.pop("messages", None) + if not messages and prompt: + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + if "max_tokens" in kwargs: + kwargs["max_output_tokens"] = kwargs.pop("max_tokens") + response = self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs + ) + return response.choices[0].message.content + except Exception as e: + raise RuntimeError(f"Azure OpenAI query failed: {e}") + + def list_models(self) -> List[str]: + """ + Return a suggested list of Azure OpenAI deployable model names. + (These must match your deployment names in Azure.) + """ + return [ + "openai/gpt-5", + # "gpt-4o", + # "gpt-4-turbo", + # "gpt-35-turbo", + # "gpt-4", + ] + + def validate_key(self) -> bool: + """ + Attempt a lightweight ping to validate Azure OpenAI credentials. + """ + try: + self._init_client() + if not self.client: + raise ValueError("AzureOpenAI client not initialized. Call validate_key() first.") + response = self.client.chat.completions.create( + model="openai/gpt-5", + messages=[{"role": "user", "content": "ping"}], + max_completion_tokens=5 + ) + return bool(response) + except Exception as e: + print(f"⚠️ Azure OpenAI API key validation failed: {e}") + return False + diff --git a/plotsense/core/providers/base.py b/plotsense/core/providers/base.py new file mode 100644 index 0000000..e1f8d94 --- /dev/null +++ b/plotsense/core/providers/base.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import List + +class LLMProvider(ABC): + """Abstract base class for LLM providers.""" + + LINK: str + + @abstractmethod + def __init__(self, api_key: str): + """Initialize the provider with an API key.""" + pass + + @abstractmethod + def query(self, prompt: str, model: str, **kwargs) -> str: + pass + + @abstractmethod + def list_models(self) -> List[str]: + pass + + @abstractmethod + def validate_key(self) -> bool: + pass + diff --git a/plotsense/core/providers/gemini.py b/plotsense/core/providers/gemini.py new file mode 100644 index 0000000..5d3d791 --- /dev/null +++ b/plotsense/core/providers/gemini.py @@ -0,0 +1,92 @@ +from typing import List, Optional +from google import genai +from google.genai.types import GenerateContentConfig +from .base import LLMProvider + + +class GeminiProvider(LLMProvider): + """Provider integration for Google's Gemini models (v2 SDK).""" + + LINK = "πŸ‘‰ https://aistudio.google.com/app/apikey πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize Anthropic client if not already created.""" + if not self.client: + self.client = genai.Client(api_key=self.api_key) + + def query( + self, + prompt: str, + model: str, + base64_image: Optional[str] = None, + **kwargs, + ) -> str: + """ + Send a text (or multimodal) prompt to Gemini and return the response text. + Supports text-only and text+image queries. + + Uses the new google-genai v2 API. + """ + try: + # Build input depending on image presence + if base64_image: + # Multimodal: send both text and image + contents = [ + {"text": prompt}, + { + "inline_data": { + "mime_type": "image/jpeg", + "data": base64_image, + } + }, + ] + else: + # Text-only + contents = prompt + + self._init_client() + if not self.client: + raise ValueError("Gemini client initialization failed.") + + response = self.client.models.generate_content( + model=model, + contents=contents, + **kwargs, + ) + + # Return clean text or empty string if missing + return getattr(response, "text", "") or "" + + except Exception as e: + raise RuntimeError(f"Gemini query failed: {e}") + + def list_models(self) -> List[str]: + """ + Return a curated list of Gemini models. + """ + return [ + "gemini-2.5-flash", + "gemini-2.0-pro", + "gemini-1.5-flash", + ] + + def validate_key(self) -> bool: + """ + Validate the provided Gemini API key by attempting a trivial generation. + """ + try: + self._init_client() + if not self.client: + raise ValueError("Gemini client initialization failed.") + response = self.client.models.generate_content( + model="gemini-2.5-flash", + contents="ping", + config=GenerateContentConfig(max_output_tokens=5), + ) + return bool(response.text) + except Exception: + return False diff --git a/plotsense/core/providers/groq.py b/plotsense/core/providers/groq.py new file mode 100644 index 0000000..abbb412 --- /dev/null +++ b/plotsense/core/providers/groq.py @@ -0,0 +1,74 @@ +from typing import List +from groq import Groq +from groq.types.chat import ChatCompletionUserMessageParam +from .base import LLMProvider + + +class GroqProvider(LLMProvider): + """Provider integration for Groq's fast inference API.""" + + LINK = "πŸ‘‰ https://console.groq.com/keys πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize Groq client if not already created.""" + if not self.client: + self.client = Groq(api_key=self.api_key) + + def query( + self, + prompt: str, + model: str, + **kwargs, + ) -> str: + """ + Send a text prompt to Groq (Llama models) and return the response text. + Supports OpenAI-style chat completions. + """ + self._init_client() + if not self.client: + raise ValueError("Groq client not initialized. Call validate_key() first.") + + # Build messages dynamically (fallback if only prompt is given) + messages: list[ChatCompletionUserMessageParam] = kwargs.pop("messages", None) + if not messages and prompt: + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs, + ) + return response.choices[0].message.content + except Exception as e: + raise RuntimeError(f"Groq query failed: {e}") + + def list_models(self) -> List[str]: + """Return a curated list of supported Groq models.""" + return [ + "llama-3.1-8b-instant", + "llama-3.3-70b-versatile", + ] + + def validate_key(self) -> bool: + """ + Validate the provided Groq API key by making a lightweight request. + """ + try: + self._init_client() + if not self.client: + raise ValueError("Groq client not initialized.") + response = self.client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[{"role": "user", "content": "ping"}], + max_tokens=5, + ) + return bool(response.choices[0].message.content) + except Exception: + return False diff --git a/plotsense/core/providers/groq_openai.py b/plotsense/core/providers/groq_openai.py new file mode 100644 index 0000000..ca0f4f7 --- /dev/null +++ b/plotsense/core/providers/groq_openai.py @@ -0,0 +1,73 @@ +from typing import List +from openai.types.chat import ChatCompletionUserMessageParam +from openai import OpenAI +from .base import LLMProvider + + +class GroqProvider(LLMProvider): + """Provider for Groq models using the unified OpenAI SDK interface.""" + + LINK = "πŸ‘‰ https://console.groq.com/keys πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize Groq client via OpenAI-compatible endpoint.""" + if not self.client: + self.client = OpenAI( + api_key=self.api_key, + base_url="https://api.groq.com/openai/v1" # Key difference + ) + + def query(self, prompt: str, model: str, **kwargs) -> str: + """ + Send a chat completion query to Groq via OpenAI SDK. + """ + self._init_client() + if not self.client: + raise ValueError("Groq client not initialized. Call validate_key() first.") + + # Ensure messages are present + messages: list[ChatCompletionUserMessageParam] = kwargs.pop( + "messages", None + ) + if not messages and prompt: + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs + ) + return response.choices[0].message.content + except Exception as e: + raise RuntimeError(f"Groq query failed: {e}") + + def list_models(self) -> List[str]: + """ + Available Groq Llama models (you can update this dynamically later). + """ + return ["llama-3.1-8b-instant", "llama-3.3-70b-versatile"] + + def validate_key(self) -> bool: + """ + Simple ping to check API validity. + """ + try: + self._init_client() + if not self.client: + raise ValueError("Groq OpenAI client not initialized.") + response = self.client.chat.completions.create( + model="llama-3.1-8b-instant", + messages=[{"role": "user", "content": "ping"}], + max_tokens=5 + ) + return bool(response) + except Exception: + return False + diff --git a/plotsense/core/providers/ollama_openai.py b/plotsense/core/providers/ollama_openai.py new file mode 100644 index 0000000..185a045 --- /dev/null +++ b/plotsense/core/providers/ollama_openai.py @@ -0,0 +1,76 @@ +from typing import List +from openai import OpenAI +from openai.types.chat import ChatCompletionUserMessageParam +from .base import LLMProvider + + +class OllamaProvider(LLMProvider): + """ + Provider for Ollama models using the OpenAI-compatible API. + This allows querying a locally running Ollama instance. + """ + + LINK = "πŸ‘‰ https://ollama.ai/library πŸ‘ˆ" + + def __init__(self, api_key: str = ""): + # Ollama typically doesn't require an API key (local service) + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize the OpenAI-compatible client for local Ollama.""" + if not self.client: + # Default local Ollama endpoint + self.client = OpenAI( + base_url="http://localhost:11434/v1", # Ollama’s OpenAI-compatible API + api_key=self.api_key or "ollama", # Dummy key for OpenAI client compatibility + ) + + def query(self, prompt: str, model: str, **kwargs) -> str: + """ + Query the Ollama model using OpenAI-compatible endpoint. + """ + self._init_client() + if not self.client: + raise ValueError("Ollama client not initialized. Call validate_key() first.") + + messages: list[ChatCompletionUserMessageParam] = kwargs.pop( + "messages", None + ) + if not messages and prompt: + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + response = self.client.chat.completions.create( + model=model, + **kwargs + ) + return response.choices[0].message.content + except Exception as e: + raise RuntimeError(f"Ollama query failed: {e}") + + def list_models(self) -> List[str]: + """ + List of example models. In a real setup, this could query `ollama list`. + """ + return ["llama3", "mistral", "codellama", "phi3", "neural-chat"] + + def validate_key(self) -> bool: + """ + Validate connection to local Ollama instance. + """ + try: + self._init_client() + if not self.client: + raise ValueError("Ollama OpenAI client not initialized.") + response = self.client.chat.completions.create( + model="llama3", + messages=[{"role": "user", "content": "ping"}], + max_tokens=5 + ) + return bool(response) + except Exception: + return False + diff --git a/plotsense/core/providers/openai_chat.py b/plotsense/core/providers/openai_chat.py new file mode 100644 index 0000000..e4bcd39 --- /dev/null +++ b/plotsense/core/providers/openai_chat.py @@ -0,0 +1,76 @@ +from typing import List, Optional +from openai import OpenAI +from openai.types.chat import ChatCompletionUserMessageParam +from .base import LLMProvider + + +class OpenAIChatProvider(LLMProvider): + """Provider integration for OpenAI Chat models.""" + + LINK = "πŸ‘‰ https://platform.openai.com/api-keys πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize OpenAI client if not already created.""" + if not self.client: + self.client = OpenAI(api_key=self.api_key) + + def query( + self, + prompt: Optional[str], + model: str, + **kwargs, + ) -> str: + """ + Send a prompt or messages to OpenAI Chat Completion API. + Supports both chat-style input and single text prompts. + """ + self._init_client() + if not self.client: + raise ValueError("OpenAI client not initialized. Call validate_key() first.") + + # Handle either prompt or messages + messages: list[ChatCompletionUserMessageParam] = kwargs.pop("messages", None) + if not messages and prompt: + messages = [{"role": "user", "content": prompt}] + elif not messages and not prompt: + raise ValueError("Either 'prompt' or 'messages' must be provided.") + + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + **kwargs, + ) + return response.choices[0].message.content + except Exception as e: + raise RuntimeError(f"OpenAI chat query failed: {e}") + + def list_models(self) -> List[str]: + """Return a curated list of supported OpenAI chat models.""" + return [ + "gpt-4o-mini", + "gpt-4.1", + "gpt-4-turbo", + "gpt-4o", + ] + + def validate_key(self) -> bool: + """ + Validate the provided OpenAI API key by performing a lightweight test query. + """ + try: + self._init_client() + if not self.client: + raise ValueError("OpenAI client not initialized.") + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "ping"}], + max_tokens=5, + ) + return bool(response.choices[0].message.content) + except Exception: + return False diff --git a/plotsense/core/providers/openai_response.py b/plotsense/core/providers/openai_response.py new file mode 100644 index 0000000..bd4ec25 --- /dev/null +++ b/plotsense/core/providers/openai_response.py @@ -0,0 +1,77 @@ +from typing import List, Optional +from openai import OpenAI +from .base import LLMProvider + + +class OpenAIResponseProvider(LLMProvider): + """Provider integration for OpenAI's Responses API.""" + + LINK = "πŸ‘‰ https://platform.openai.com/api-keys πŸ‘ˆ" + + def __init__(self, api_key: str): + self.api_key = api_key + self.client = None + + def _init_client(self): + """Initialize OpenAI client if not already created.""" + if not self.client: + self.client = OpenAI(api_key=self.api_key) + + def query( + self, + prompt: Optional[str], + model: str, + **kwargs, + ) -> str: + """ + Send a prompt to OpenAI's Responses API and return the generated text. + The Responses endpoint supports text, image, and JSON outputs. + """ + self._init_client() + if not self.client: + raise ValueError("OpenAI client not initialized. Call validate_key() first.") + + if not prompt: + raise ValueError("'prompt' must be provided for Responses API queries.") + + try: + if "max_tokens" in kwargs: + kwargs["max_output_tokens"] = kwargs.pop("max_tokens") + # The Responses API expects `input`, not `messages` + response = self.client.responses.create( + model=model, + input=prompt, + **kwargs, + ) + + # The unified field for plain text output is `output_text` + return getattr(response, "output_text", "") or "" + except Exception as e: + raise RuntimeError(f"OpenAI response query failed: {e}") + + def list_models(self) -> List[str]: + """Return a curated list of supported OpenAI Response models.""" + return [ + "gpt-4.1", + "gpt-4.1-mini", + "gpt-4o", + ] + + def validate_key(self) -> bool: + """ + Validate the provided OpenAI API key by performing a lightweight test query. + """ + try: + self._init_client() + if not self.client: + raise ValueError("OpenAI client not initialized.") + response = self.client.responses.create( + model="gpt-4.1-mini", + input="ping", + max_output_tokens=16, + ) + return bool(getattr(response, "output_text", "")) + except Exception as e: + print(f"OpenAI Responses API key validation failed: {e}") + return False + diff --git a/plotsense/core/providers/provider_manager.py b/plotsense/core/providers/provider_manager.py new file mode 100644 index 0000000..fc50e4b --- /dev/null +++ b/plotsense/core/providers/provider_manager.py @@ -0,0 +1,241 @@ +from typing import Dict, List, Optional, Type + +from plotsense.core.providers.anthropic import AnthropicProvider +from plotsense.core.providers.azure_openai import AzureOpenAIProvider +from plotsense.core.providers.base import LLMProvider +from plotsense.core.providers.gemini import GeminiProvider +from plotsense.core.providers.ollama_openai import OllamaProvider +from plotsense.core.providers.openai_chat import OpenAIChatProvider +from plotsense.core.utils import prompt_for_api_key +from .groq import GroqProvider +from .openai_response import OpenAIResponseProvider + + +class ProviderManager: + """Manages multiple LLM providers, their API keys, and interactions.""" + + SUPPORTED_PROVIDERS: Dict[str, Dict[str, Type[LLMProvider]]] = { + "groq": { + "default": GroqProvider, + }, + "openai": { + "chat": OpenAIChatProvider, + "response": OpenAIResponseProvider, + }, + "anthropic": { + "default": AnthropicProvider, + }, + "gemini": { + "default": GeminiProvider, + }, + "azure": { + "default": AzureOpenAIProvider, + }, + "ollama": { + "default": OllamaProvider, + }, + } + + def __init__( + self, api_keys: Dict[str, str], interactive: bool = True, + restrict_to: Optional[List[str]] = None + ): + self.api_keys = api_keys or {} + self.interactive = interactive + self.providers = {} + self.restrict_to = set(restrict_to) if restrict_to else None + + # Normalize restrict_to list + if restrict_to: + invalid = [p for p in restrict_to if p not in self.SUPPORTED_PROVIDERS] + if invalid: + raise ValueError( + f"Unsupported provider(s): {invalid}. " + f"Supported providers: {list(self.SUPPORTED_PROVIDERS.keys())}" + ) + self.restrict_to = set(restrict_to) + else: + self.restrict_to = None + + self._init_providers() + + def _init_providers(self): + """Initialize all registered providers and validate their API keys.""" + for vendor_name, variants in self.SUPPORTED_PROVIDERS.items(): + # Skip if restrict_to is provided and this vendor isn’t included + if self.restrict_to and vendor_name not in self.restrict_to: + continue + + for variant_name, provider_cls in variants.items(): + full_name = f"{vendor_name}_{variant_name}" + link = getattr(provider_cls, "LINK", f"https://{vendor_name}.com") + + api_key: Optional[str] = self.api_keys.get(vendor_name) + if not api_key: + # Try to prompt only if interactive and not restricted + api_key = prompt_for_api_key( + vendor_name, + link, + self.interactive, + skip_if_missing=bool(self.restrict_to), + ) + if not api_key: + # Skip this provider if key is still missing + print(f"⏩ Skipping {full_name.upper()} (no API key provided).") + continue + + self.api_keys[vendor_name] = api_key + + if not isinstance(api_key, str) or not api_key.strip(): + print(f"⚠️ Skipping {full_name.upper()} due to invalid API key format.") + continue + + provider = provider_cls(api_key=api_key) + + try: + if provider.validate_key(): + print(f"βœ… {full_name.upper()} API key validated successfully.") + self.providers[full_name] = provider + else: + print(f"❌ {full_name.upper()} API key invalid or unverified.") + except Exception as e: + print(f"⚠️ Error validating {full_name.upper()} API key: {e}") + + def get_provider(self, vendor_name: str, variant_name: str = ""): + """ + Get or initialize a provider (with optional variant) on demand. + + Args: + vendor_name: Name of the AI provider (e.g., "openai", "groq") + variant_name: Optional variant name (e.g., "chat", "completion") + + Returns: + Initialized provider instance + """ + # Compose a unique key for storage + full_name = f"{vendor_name}_{variant_name}" if variant_name else vendor_name + + if vendor_name not in self.SUPPORTED_PROVIDERS: + raise ValueError(f"Unknown provider: {vendor_name}") + + if full_name not in self.providers: + variants = self.SUPPORTED_PROVIDERS[vendor_name] + + # Determine class safely + provider_cls = None + if variant_name: + provider_cls = variants.get(variant_name) + if not provider_cls: + raise ValueError( + f"Unknown variant '{variant_name}' for provider '{vendor_name}'" + ) + else: + variant_name, provider_cls = next(iter(variants.items())) + full_name = f"{vendor_name}_{variant_name}" + + link = getattr(provider_cls, "LINK", f"https://{vendor_name}.com") + + api_key: Optional[str] = self.api_keys.get(vendor_name) + if not api_key: + api_key = prompt_for_api_key(vendor_name, link, self.interactive) + if not api_key: + raise ValueError(f"Missing API key for {vendor_name}") + self.api_keys[vendor_name] = api_key + + # if not isinstance(api_key, str): + # raise TypeError(f"API key for {vendor_name} must be a string") + + provider = provider_cls(api_key=api_key) + + try: + if provider.validate_key(): + print(f"βœ… {full_name.upper()} API key validated successfully.") + else: + print(f"❌ {full_name.upper()} API key invalid or unverified.") + except Exception as e: + print(f"⚠️ Error validating {full_name.upper()} API key: {e}") + + self.providers[full_name] = provider + + return self.providers[full_name] + + def list_all_models(self): + all_models = {} + for name, provider in self.providers.items(): + try: + all_models[name] = provider.list_models() + except Exception as e: + print(f"⚠️ Failed to list models for {name}: {str(e)}") + return all_models + + def query(self, provider_name: str, model: str, prompt: str, **kwargs): + """Query a specific provider with a prompt and model.""" + provider = self.providers.get(provider_name) + if not provider: + raise ValueError(f"Provider {provider_name} not initialized.") + return provider.query(prompt, model, **kwargs) + + def get_model_costs(self) -> Dict[str, float]: + """ + Return a global map of model names to approximate per-request cost multipliers. + This helps CostOptimizedStrategy prioritize cheaper models. + """ + # In a real system, this could come from provider-specific metadata + return { + # OpenAI + "gpt-4o-mini": 0.01, + "gpt-4o": 0.03, + "gpt-4-turbo": 0.025, + "gpt-3.5-turbo": 0.008, + # Groq (Llama) + "llama-3.1-8b-instant": 0.005, + "llama-3.3-70b-versatile": 0.02, + # Anthropic + "claude-3-haiku": 0.009, + "claude-3-sonnet": 0.02, + "claude-3-opus": 0.05, + # Gemini + "gemini-1.5-flash": 0.006, + "gemini-1.5-pro": 0.02, + # Azure (proxy to GPT costs) + "azure-gpt-4o-mini": 0.011, + "azure-gpt-4o": 0.031, + # Ollama (local = near-zero cost) + "llama3": 0.001, + "mistral": 0.002, + } + + def get_model_performance(self) -> Dict[str, float]: + """ + Return approximate relative performance scores for each model. + Higher means better performance (accuracy, reasoning ability, etc.). + """ + return { + # OpenAI + "gpt-4o": 10.0, + "gpt-4o-mini": 8.5, + "gpt-4-turbo": 9.5, + "gpt-3.5-turbo": 7.5, + + # Anthropic + "claude-3-opus": 9.8, + "claude-3-sonnet": 9.0, + "claude-3-haiku": 7.0, + + # Groq + "llama-3.3-70b-versatile": 8.8, + "llama-3.1-8b-instant": 6.5, + + # Gemini + "gemini-1.5-pro": 9.3, + "gemini-1.5-flash": 7.8, + + # Azure (maps to OpenAI) + "azure-gpt-4o": 9.8, + "azure-gpt-4o-mini": 8.3, + + # Ollama (local models) + "mistral": 6.0, + "llama3": 6.8, + } + diff --git a/plotsense/core/strategies/cost_optimized.py b/plotsense/core/strategies/cost_optimized.py new file mode 100644 index 0000000..dd3b7a5 --- /dev/null +++ b/plotsense/core/strategies/cost_optimized.py @@ -0,0 +1,33 @@ +from typing import Dict, List, Optional, Tuple +from plotsense.core.strategies.strategy import Strategy + + +class CostOptimizedStrategy(Strategy): + """Prioritize cheaper models first, fallback to pricier if needed.""" + + def __init__(self, provider_models: List[Tuple[str, str]], provider_manager): + super().__init__(provider_models) + + self.cost_map: Dict[str, float] = provider_manager.get_model_costs() + + # Sort models by ascending cost (lowest first) + self.model_list = sorted( + provider_models, + key=lambda p_m: self.cost_map.get(p_m[1], float("inf")) + ) + + def select_models(self, n: int) -> List[Tuple[str, str]]: + """ + Return the top `n` cheapest models. + """ + return self.model_list[:n] + + def select_model( + self, iteration: int, current_explanation: Optional[str] = None + ) -> Tuple[str, str]: + """Use the cheapest available model, escalate if iteration increases.""" + if not self.model_list: + raise ValueError("No models available in strategy.") + index = min(iteration, len(self.model_list) - 1) + return self.model_list[index] + diff --git a/plotsense/core/strategies/fallback_chain.py b/plotsense/core/strategies/fallback_chain.py new file mode 100644 index 0000000..2110e8d --- /dev/null +++ b/plotsense/core/strategies/fallback_chain.py @@ -0,0 +1,24 @@ +from typing import List, Optional, Tuple +from plotsense.core.strategies.strategy import Strategy + +class FallbackChainStrategy(Strategy): + """Try providers/models in fixed order until one succeeds.""" + + def __init__(self, provider_models: List[Tuple[str, str]]): + super().__init__(provider_models) + # Deterministic order; could later be made configurable + self.model_list = provider_models + self._last_success_index = 0 + + def select_model( + self, iteration: int, current_explanation: Optional[str] = None + ) -> Tuple[str, str]: + """If previous success exists, keep using it; otherwise go to next.""" + if not self.model_list: + raise ValueError("No models available in strategy.") + index = min(iteration, len(self.model_list) - 1) + return self.model_list[index] + + def report_success(self, index: int): + """Optionally record which model last succeeded.""" + self._last_success_index = index diff --git a/plotsense/core/strategies/performance_optimized.py b/plotsense/core/strategies/performance_optimized.py new file mode 100644 index 0000000..bff04bb --- /dev/null +++ b/plotsense/core/strategies/performance_optimized.py @@ -0,0 +1,40 @@ +from typing import Dict, List, Optional, Tuple +from plotsense.core.strategies.strategy import Strategy + +MODEL_PERFORMANCE_MAP = { + "gpt-4o": 10, + "gpt-4o-mini": 8, + "llama-3.3-70b-versatile": 9, + "llama-3.1-8b-instant": 6, +} + +class PerformanceOptimizedStrategy(Strategy): + """Prefer highest-performance models first.""" + + def __init__( + self, provider_models: List[Tuple[str, str]], provider_manager + ): + super().__init__(provider_models) + + # Get dynamic performance scores from ProviderManager + self.performance_map: Dict[str, float] = provider_manager.get_model_performance() + + # Sort models descending by performance score + self.model_list = sorted( + provider_models, + key=lambda p_m: self.performance_map.get(p_m[1], 0), + reverse=True, + ) + + def select_models(self, n: int) -> List[Tuple[str, str]]: + """Return the top `n` highest-performing models.""" + return self.model_list[:n] + + def select_model( + self, iteration: int, current_explanation: Optional[str] = None + ) -> Tuple[str, str]: + """Start from best model; fallback to lower-tier ones if needed.""" + if not self.model_list: + raise ValueError("No models available in strategy.") + index = min(iteration, len(self.model_list) - 1) + return self.model_list[index] diff --git a/plotsense/core/strategies/round_robin.py b/plotsense/core/strategies/round_robin.py new file mode 100644 index 0000000..45a3c89 --- /dev/null +++ b/plotsense/core/strategies/round_robin.py @@ -0,0 +1,19 @@ +from typing import List, Optional, Tuple +from plotsense.core.strategies.strategy import Strategy + +class RoundRobinStrategy(Strategy): + """Cycle through all models evenly.""" + + def __init__(self, provider_models: List[Tuple[str, str]]): + super().__init__(provider_models) + self.model_list = provider_models + self._last_index = -1 + + def select_model( + self, iteration: int, current_explanation: Optional[str] = None + ) -> Tuple[str, str]: + if not self.model_list: + raise ValueError("No models available in strategy.") + # Pick model based directly on iteration count + index = iteration % len(self.model_list) + return self.model_list[index] diff --git a/plotsense/core/strategies/strategy.py b/plotsense/core/strategies/strategy.py new file mode 100644 index 0000000..52940b6 --- /dev/null +++ b/plotsense/core/strategies/strategy.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional + +class Strategy(ABC): + """ + Base Strategy interface for selecting provider/model pairs. + Each strategy returns a tuple: (provider_name, model_name) + """ + + def __init__(self, provider_models: List[Tuple[str, str]]): + """ + Args: + provider_models: dict mapping provider_name -> list of models + """ + self.provider_models = provider_models + + @abstractmethod + def select_model(self, iteration: int, current_explanation: Optional[str] = None) -> Tuple[str, str]: + """ + Return a (provider, model) tuple for the given iteration. + + Args: + iteration: current iteration index (0-based) + current_explanation: optionally, the current explanation for refinement + + Returns: + (provider_name, model_name) + """ + pass + diff --git a/plotsense/core/utils.py b/plotsense/core/utils.py new file mode 100644 index 0000000..89c9447 --- /dev/null +++ b/plotsense/core/utils.py @@ -0,0 +1,55 @@ +import builtins +import base64 +from matplotlib.figure import Figure +from matplotlib.axes import Axes +from typing import Optional, Union, cast + + +def prompt_for_api_key( + service_name: str, service_link: str, interactive: bool = True, + skip_if_missing: bool = False +) -> Optional[str]: + """Prompt user for API key or raise if unavailable.""" + if not interactive: + if skip_if_missing: + return None + raise ValueError( + f"{service_name.upper()} API key is required. " + f"Set it in the environment or pass it as an argument. " + f"You can get it at {service_link}" + ) + + try: + print(f"βš™οΈ {service_name.upper()} API key not found.") + print(f"πŸ”— Get it at {service_link}") + key = builtins.input(f"Enter {service_name.upper()} API key (or press Enter to skip): ").strip() + if not key and skip_if_missing: + return None + if not key: + raise ValueError(f"{service_name.upper()} API key is required.") + return key + except (EOFError, OSError): + if skip_if_missing: + return None + raise ValueError(f"{service_name.upper()} API key is required (get it at {service_link})") + +def save_plot_to_image( + plot_object: Union[Figure, Axes], + output_path: str = "temp_plot.jpg" +) -> str: + """Save a matplotlib Figure or Axes object to a JPEG image file.""" + if isinstance(plot_object, Axes): + fig = plot_object.figure + else: + fig = plot_object + cast(Figure, fig).savefig( + output_path, format='jpeg', dpi=100, bbox_inches='tight' + ) + return output_path + + +def encode_image(image_path: str) -> str: + """Encode image file to base64 string.""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + diff --git a/plotsense/explanations/explanations.py b/plotsense/explanations/explanations.py index e30533e..46413a1 100644 --- a/plotsense/explanations/explanations.py +++ b/plotsense/explanations/explanations.py @@ -1,186 +1,67 @@ -import base64 import os -import matplotlib.pyplot as plt -from typing import Union, Optional, Dict, List +# import matplotlib.pyplot as plt +from matplotlib.figure import Figure +from matplotlib.axes import Axes +from typing import List, Tuple, Union, Optional, Dict from dotenv import load_dotenv -from groq import Groq -import warnings -import builtins +from plotsense.core.ai_interface import AIModelInterface +from plotsense.core.enums.strategy import StrategyName +from plotsense.core.providers.provider_manager import ProviderManager +from plotsense.core.utils import encode_image, save_plot_to_image load_dotenv() + class PlotExplainer: """ - A class to generate and refine explanations for plots using LLMs.""" - DEFAULT_MODELS = { - 'groq': ['meta-llama/llama-4-scout-17b-16e-instruct', - 'meta-llama/llama-4-maverick-17b-128e-instruct'], - } - + A class to generate and refine explanations for plots using LLMs. + """ + def __init__( - self, - api_keys: Optional[Dict[str, str]] = None, - max_iterations: int = 3, - interactive: bool = True, - timeout: int = 30 + self, + api_keys: Optional[Dict[str, str]], + strategy: StrategyName = StrategyName.ROUND_ROBIN, + selected_models: Optional[List[Tuple[str, str]]] = None, + max_iterations: int = 3, + interactive: bool = True, + timeout: int = 30, ): - # Default to empty dict if None - api_keys = api_keys or {} - - ## Initialize API keys with environment variable or provided keys - self.api_keys = { - 'groq': os.getenv('GROQ_API_KEY') - } - # Update with provided API keys - self.api_keys.update(api_keys) - # Set interactive mode and timeout for API calls - self.interactive = interactive - # Set timeout for API calls - self.timeout = timeout - # Initialize empty dict for clients - self.clients = {} - # Initialize empty list for available models - self.available_models = [] - # Set max iterations for refinement - self.max_iterations = max_iterations - - # Validate API keys and initialize clients - self._validate_keys() - # Initialize clients - self._initialize_clients() - # Detect available models - self._detect_available_models() - - def _validate_keys(self): - """Validate that required API keys are present""" - service_links = { - 'groq': 'πŸ‘‰ https://console.groq.com/keys πŸ‘ˆ' - } - - for service in ['groq']: - if not self.api_keys.get(service): - if self.interactive: - try: - link = service_links.get(service, f"the {service.upper()} website") - message = ( - f"Enter {service.upper()} API key (get it at {link}): " - ) - self.api_keys[service] = builtins.input(message).strip() - if not self.api_keys[service]: - raise ValueError(f"{service.upper()} API key is required") - except (EOFError, OSError): - # Handle cases where input is not available - raise ValueError(f"{service.upper()} API key is required (get it at {service_links.get(service)})") - else: - raise ValueError( - f"{service.upper()} API key is required. " - f"Set it in the environment or pass it as an argument. " - f"You can get it at {service_links.get(service)}" - ) + self.timeout = timeout # timeout for API calls + self.max_iterations = max_iterations # max iterations for refinement + self.strategy_name = strategy # strategy for provider selection - def _initialize_clients(self): - """Initialize API clients based on provided API keys""" - self.clients = {} - if self.api_keys.get('groq'): - try: - self.clients['groq'] = Groq(api_key=self.api_keys['groq']) - except Exception as e: - warnings.warn(f"Could not initialize Groq client: {e}", ImportWarning) - - def _detect_available_models(self): - """Detect available models based on initialized clients""" - self.available_models = [] - for provider, client in self.clients.items(): - if client and provider in self.DEFAULT_MODELS: - self.available_models.extend(self.DEFAULT_MODELS[provider]) - - def save_plot_to_image( - self, - plot_object: Union[plt.Figure, plt.Axes], - output_path: str = "temp_plot.jpg" - ): - """Save plot to an image file""" - if isinstance(plot_object, plt.Axes): - fig = plot_object.figure - else: - fig = plot_object - - fig.savefig(output_path, format='jpeg', dpi=100, bbox_inches='tight') - return output_path - - def encode_image( - self, - image_path: str - ) -> str: - """Encode image file to base64 string""" - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') + selected_providers = {p for p, _ in (selected_models or [])} - def _query_model( - self, - model: str, - prompt: str, - image_path: str, - custom_parameters: Optional[Dict] = None - ) -> str: - - """Generic model querying method with provider-specific logic""" - - base64_image = self.encode_image(image_path) - - # Determine provider based on model name - provider = next( - (p for p, models in self.DEFAULT_MODELS.items() if model in models), - None + self.manager = ProviderManager( + api_keys=api_keys or {}, + interactive=interactive, + restrict_to=list(selected_providers) if selected_providers else None + ) + self.ai_interface = AIModelInterface(self.manager, timeout=self.timeout) + + # if selected_models: + # self.available_models = self.manager.list_all_models + # else: + all_models = self.manager.list_all_models() + self.available_models = [ + (provider, model) + for provider, models in all_models.items() + for model in models + ] + + if not self.available_models: + raise ValueError( + "No available models detected β€” check API keys or selection input." + ) + + self.strategy = self.ai_interface._init_strategy( + self.strategy_name, self.available_models ) - - if not provider: - raise ValueError(f"No provider found for model {model}") - - try: - if provider == 'groq': - client = self.clients['groq'] - - # Merge default and custom parameters - default_params = { - 'max_tokens': 1000, - 'temperature': 0.7 - } - generation_params = {**default_params, **(custom_parameters or {})} - - response = client.chat.completions.create( - model=model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } - ] - } - ], - **generation_params - ) - - return response.choices[0].message.content - - except Exception as e: - if "503" in str(e): - print(f"Groq service temporarily unavailable, retrying... Error: {e}") - raise # This will trigger retry - error_message = f"Model querying error for {model}: {str(e)}" - warnings.warn(error_message) - return error_message def refine_plot_explanation( self, - plot_object: Union[plt.Figure, plt.Axes], + plot_object: Union[Figure, Axes], prompt: str = "Explain this data visualization", temp_image_path: str = "temp_plot.jpg", custom_parameters: Optional[Dict] = None @@ -190,40 +71,50 @@ def refine_plot_explanation( raise ValueError("No available models detected") # Save plot to temporary image file - image_path = self.save_plot_to_image(plot_object, temp_image_path) - + image_path = save_plot_to_image(plot_object, temp_image_path) + try: # Iterative refinement process current_explanation = None - + for iteration in range(self.max_iterations): - current_model = self.available_models[iteration % len(self.available_models)] - + provider, current_model = self.strategy.select_model( + iteration, current_explanation + ) + if current_explanation is None: current_explanation = self._generate_initial_explanation( - current_model, image_path, prompt, custom_parameters + provider, current_model, image_path, prompt, custom_parameters ) else: critique = self._generate_critique( - image_path, current_explanation, prompt, current_model, custom_parameters + provider, current_model, image_path, current_explanation, prompt, custom_parameters ) - + current_explanation = self._generate_refinement( - image_path, current_explanation, critique, prompt, current_model, custom_parameters + provider, current_model, image_path, + current_explanation, critique, prompt, + custom_parameters ) + if current_explanation is None: + raise RuntimeError( + "Failed to generate an explanation β€” no models available or initial step failed." + ) + return current_explanation - + finally: # Clean up temporary image file if os.path.exists(image_path): os.remove(image_path) def _generate_initial_explanation( - self, - model: str, + self, + provider: str, + model: str, image_path: str, - original_prompt: str, + original_prompt: str, custom_parameters: Optional[Dict] = None ) -> str: """Generate initial plot explanation with structured format""" @@ -237,7 +128,7 @@ def _generate_initial_explanation( 4. Conclusion - Be specific and data-driven - Highlight key statistical and visual elements - + Specific Prompt: {original_prompt} Formatting Instructions: @@ -246,20 +137,22 @@ def _generate_initial_explanation( - Provide quantitative insights - Explain the significance of visual elements """ - + return self._query_model( - model=model, + provider=provider, + model=model, prompt=base_prompt, - image_path=image_path, + image_path=image_path, custom_parameters=custom_parameters ) def _generate_critique( - self, - image_path: str, - current_explanation: str, - original_prompt: str, + self, + provider: str, model: str, + image_path: str, + current_explanation: str, + original_prompt: str, custom_parameters: Optional[Dict] = None ) -> str: """Generate critique of current explanation""" @@ -293,21 +186,23 @@ def _generate_critique( Provide a constructive critique that will help refine the explanation. """ - + return self._query_model( - model=model, - prompt=critique_prompt, - image_path=image_path, + provider=provider, + model=model, + prompt=critique_prompt, + image_path=image_path, custom_parameters=custom_parameters ) def _generate_refinement( - self, - image_path: str, - current_explanation: str, - critique: str, - original_prompt: str, + self, + provider: str, model: str, + image_path: str, + current_explanation: str, + critique: str, + original_prompt: str, custom_parameters: Optional[Dict] = None ) -> str: """Generate refined explanation based on critique""" @@ -342,50 +237,84 @@ def _generate_refinement( - Use markdown-style headers for clarity - Include bullet points for clarity - Provide quantitative insights - - Ensure the explanation is comprehensive and insightful - + - Ensure the explanation is comprehensive and insightful """ - + return self._query_model( + provider=provider, model=model, - prompt= refinement_prompt, + prompt= refinement_prompt, image_path=image_path, custom_parameters= custom_parameters ) + def _query_model( + self, provider: str, model: str, prompt: str, image_path: str, + custom_parameters: Optional[Dict] = None + ) -> str: + base64_image = encode_image(image_path) + return self.ai_interface.query_model( + provider=provider, + model=model, + prompt=prompt, + base64_image=base64_image, + custom_parameters=custom_parameters + ) + # Package-level convenience function _explainer_instance = None def explainer( - plot_object: Union[plt.Figure, plt.Axes], + plot_object: Union[Figure, Axes], prompt: str = "Explain this data visualization", + *, # force keyword args after this + + custom_parameters: Optional[Dict] = None, + strategy: StrategyName = StrategyName.ROUND_ROBIN, + selected_models: Optional[List[Tuple[str, str]]] = None, + api_keys: Optional[Dict[str, str]] = None, max_iterations: int = 3, - custom_parameters: Optional[Dict] = None, - temp_image_path: str = "temp_plot.jpg" + interactive: bool = True, + timeout: int = 30, + temp_image_path: str = "temp_plot.jpg", ) -> str: """ - Convenience function for iterative plot explanation - + Convenience function to generate and refine plot explanations + Uses a singleton PlotExplainer instance for efficiency. + Args: - data: Original data used to create the plot (DataFrame or numpy array) - plot_object: Matplotlib Figure or Axes - prompt: Explanation prompt - api_keys: API keys for different providers - max_iterations: Maximum refinement iterations - custom_parameters: Additional generation parameters - + - plot_object: Matplotlib Figure or Axes object to explain + - prompt: Initial prompt for explanation generation + - custom_parameters: Optional dict of custom parameters for the model + - strategy: StrategyName enum for model selection strategy + - selected_models: Optional list of (provider, model) tuples to restrict models + - api_keys: Optional dict of API keys for providers + - max_iterations: Max refinement iterations + - interactive: Whether to prompt user for input when needed + - timeout: Timeout in seconds for API calls + - temp_image_path: Path to save temporary plot image + Returns: - Comprehensive explanation with refinement details + A comprehensive, refined explanation generated by the chosen AI models. """ global _explainer_instance + if _explainer_instance is None: - _explainer_instance = PlotExplainer(api_keys=api_keys, - max_iterations=max_iterations) + _explainer_instance = PlotExplainer( + api_keys=api_keys, + strategy=strategy, + selected_models=selected_models, + max_iterations=max_iterations, + interactive=interactive, + timeout=timeout, + ) + return _explainer_instance.refine_plot_explanation( plot_object=plot_object, prompt=prompt, custom_parameters=custom_parameters, temp_image_path=temp_image_path ) + diff --git a/plotsense/plot_chat/__init__.py b/plotsense/plot_chat/__init__.py new file mode 100644 index 0000000..32528d3 --- /dev/null +++ b/plotsense/plot_chat/__init__.py @@ -0,0 +1 @@ +from plotsense.plot_chat.client import PlotChatClient diff --git a/plotsense/plot_chat/action.py b/plotsense/plot_chat/action.py new file mode 100644 index 0000000..b62df69 --- /dev/null +++ b/plotsense/plot_chat/action.py @@ -0,0 +1,118 @@ +import re, json, base64 +from io import BytesIO +from typing import Optional, Dict, Any +import pandas as pd +import matplotlib.pyplot as plt +import io + +from plotsense.plot_generator.generator import plotgen + + +class ActionClient: + """ + Handles AI-powered PlotSense actions (plotgen, explainer, etc.) + Takes the user's prompt, analyzes it, calls the right PlotSense function, + and streams back human-like text + generated image. + """ + + def __init__(self, client): + self.client = client + + @staticmethod + def _fig_to_base64(fig) -> str: + """Convert matplotlib Figure to base64 string.""" + buffer = BytesIO() + fig.savefig(buffer, format="png", bbox_inches="tight", dpi=100) + buffer.seek(0) + img_str = base64.b64encode(buffer.read()).decode("utf-8") + plt.close(fig) + return f"data:image/png;base64,{img_str}" + + def handle_plotgen_extension( + self, + model: str, + message: str, + # df: pd.DataFrame, + previous_response_id: Optional[str] = None, + upload_fn=None, + ) -> Dict[str, Any]: + """ + Handles the plotgen extension: analyzes prompt, generates plot, + and streams AI text + image inline. + """ + + # Based on the dataframe provided (columns: {list(df.columns)}), + extraction_instructions = f""" + You are a PlotSense assistant. + The user says: "{message}" + identify a suitable plot type and columns to use. + Respond *only* in JSON like: + {{ + "df": , (In a format I will later cast to a DataFrame using + `pd.DataFrame()`) + "plot_type": "scatter", + "variables": ["a", "b"] + }} + If unsure, respond with: + {{ "error": "Could not extract columns." }} + """ + + extraction_response = self.client.responses.create( + model=model, + instructions=extraction_instructions, + input=[{"role": "user", "content": [{"type": "input_text", "text": message}]}], + previous_response_id=previous_response_id, + stream=False, + ) + + extraction_output = extraction_response.output_text.strip() + extraction_output = re.sub(r"^```(json)?", "", extraction_output) + extraction_output = re.sub(r"```$", "", extraction_output).strip() + + print("Extraction JSON:", extraction_output) + + try: + plot_request = json.loads(extraction_output) + if "error" in plot_request: + return {"error": plot_request["error"]} + df = pd.DataFrame(plot_request["df"]) + plot_type = plot_request["plot_type"] + variables = plot_request["variables"] + except json.JSONDecodeError: + return {"error": "Could not parse the request for plotting."} + + try: + suggestion_row = pd.Series({ + "plot_type": plot_type, + "variables": ",".join(variables) + }) + file_obj = io.BytesIO() + fig = plotgen(df, suggestion_row, generator="basic") + fig.savefig(file_obj, format="png", bbox_inches="tight", dpi=150) + file_obj.seek(0) # rewind to start + img_base64 = self._fig_to_base64(fig) + image_url = img_base64 + if upload_fn: + image_url = upload_fn(file_obj=file_obj, + key="plotgen_image.png", content_type="image/png") + print("Uploaded image URL:", image_url) + except Exception as e: + return {"error": f"Plot generation failed: {str(e)}"} + + followup_prompt = f""" + The plot has been generated successfully using: + - Plot Type: {plot_type} + - Variables: {variables} + + The image below shows the resulting visualization. + Include this url in your response: {image_url} + Please explain the resut of this plot + in a friendly, human-like conversational tone. + Showing that the image is provided below + """ + + return { + "text": followup_prompt.strip(), + "image": image_url, + } + diff --git a/plotsense/plot_chat/audio.py b/plotsense/plot_chat/audio.py new file mode 100644 index 0000000..6a64317 --- /dev/null +++ b/plotsense/plot_chat/audio.py @@ -0,0 +1,29 @@ +from io import BytesIO +from typing import BinaryIO + +from plotsense.plot_chat.streaming import ChatStreamWrapper + +class AudioClient: + def __init__(self, client): + self.client = client + + def create_audio_transcription(self, file_obj: BinaryIO, model: str): + stream = self.client.audio.transcriptions.create( + file=file_obj, + model=model, + language="en", + stream=True + ) + return ChatStreamWrapper(stream) + + def create_audio_speech(self, text: str, voice: str, model: str) -> BytesIO: + with self.client.audio.speech.with_streaming_response.create( + model=model, + voice=voice, + response_format="mp3", + input=text + ) as response: + audio_bytes = response.read() + buffer = BytesIO(audio_bytes) + buffer.seek(0) + return buffer diff --git a/plotsense/plot_chat/chat.py b/plotsense/plot_chat/chat.py new file mode 100644 index 0000000..5a46ce3 --- /dev/null +++ b/plotsense/plot_chat/chat.py @@ -0,0 +1,130 @@ +from typing import List, Optional + +from plotsense.plot_chat.action import ActionClient +from plotsense.plot_chat.streaming import ChatStreamWrapper +from .prompts import get_instructions + + +class ChatClient: + def __init__(self, client): + self.client = client + self.action_client = ActionClient(client) + + def chat_stream( + self, + model: str, + message: str, + previous_response_id: Optional[str] = None, + fileIds: List[str] = [], + imageUrls: List[str] = [], + instructions: List[str] = [], + extension: Optional[str] = None, + upload_fn=None, + # df=None, + ): + """ + Main streaming entrypoint for chat. + Handles PlotSense extensions (plotgen, explainer, etc.) + before falling back to normal chat streaming. + """ + + content_blocks = [] + prompt = message + + print("ChatClient.chat_stream: extension =", extension) + if extension and extension.lower() == "plotgen": + print("ChatClient.chat_stream: extension =", extension) + # Delegate to specialized handler + action_result = self.action_client.handle_plotgen_extension( + previous_response_id=previous_response_id if model.lower().startswith("gpt") else None, + model=model, + message=message, + # df=df, + upload_fn=upload_fn, + ) + + if "error" in action_result: + content_blocks.append({ + "type": "input_text", + "text": f"⚠️ {action_result['error']}", + }) + else: + # Include AI follow-up text + plot image + content_blocks.append({ + "type": "input_text", + "text": action_result["text"], + }) + # "type": "input_image" is only supported for gpt models that can render images inline + if model.lower().startswith("gpt"): + content_blocks.append({ + "type": "input_image", + # "text": f"imageUrl: {action_result["image"]}", + "image_url": action_result["image"], + "detail": "high", + }) + else: + content_blocks.append({ + "type": "input_text", + "text": f"imageUrl: {action_result["image"]}", + }) + + # content_blocks = [] + if fileIds: + for fileId in fileIds: + content_blocks.append({ + "type": "input_file", + "file_id": fileId + }) + + if imageUrls and model.lower().startswith("gpt"): + for imageUrl in imageUrls: + content_blocks.append({ + "type": f"input_image", + # "text": f"{imageUrl}", + "image_url": imageUrl, + "detail": "high", + }) + + content_blocks.append({ + "type": "input_text", + "text": prompt + }) + print("Content blocks:", content_blocks) + + stream = self.client.responses.create( + model=model, + instructions=get_instructions(instructions), + input=[ + { + "role": "user", + "content": content_blocks + }, + ], + previous_response_id=previous_response_id if model.lower().startswith("gpt") else None, + stream=True, + ) + + return ChatStreamWrapper(stream) + + def prompt( + self, model: str, prompt: str, previous_response_id: Optional[str] = None + ) -> str: + response = self.client.responses.create( + model=model, + instructions=get_instructions([]), + input=prompt, + previous_response_id=previous_response_id if model.lower().startswith("gpt") else None, + ) + return response.output_text + + def generate_chat_title( + self, model: str, assessment_title: str, initial_prompt: str + ) -> str: + from .prompts import generate_chat_title_prompt + prompt = generate_chat_title_prompt(assessment_title, initial_prompt) + response = self.client.responses.create( + model=model, + instructions=get_instructions([]), + input=prompt + ) + return response.output_text diff --git a/plotsense/plot_chat/client.py b/plotsense/plot_chat/client.py new file mode 100644 index 0000000..d9fe385 --- /dev/null +++ b/plotsense/plot_chat/client.py @@ -0,0 +1,19 @@ +from openai import OpenAI +from .chat import ChatClient +from .audio import AudioClient +from .file import FileClient +from .realtime import RealtimeClient + + +class PlotChatClient: + def __init__(self, api_key: str = ""): + self.client = OpenAI( + api_key=api_key, + # base_url="https://api.groq.com/openai/v1" + ) + + self.chat = ChatClient(self.client) + self.audio = AudioClient(self.client) + self.files = FileClient(self.client) + self.realtime = RealtimeClient(self.client) + diff --git a/plotsense/plot_chat/file.py b/plotsense/plot_chat/file.py new file mode 100644 index 0000000..07cc8e7 --- /dev/null +++ b/plotsense/plot_chat/file.py @@ -0,0 +1,15 @@ +from io import BytesIO +from openai.types import FilePurpose + + +class FileClient: + def __init__(self, client): + self.client = client + + def upload_file(self, file_obj: BytesIO, purpose: FilePurpose) -> str: + response = self.client.files.create(file=file_obj, purpose=purpose) + return response.id + + def delete_file(self, file_id: str): + self.client.files.delete(file_id) + diff --git a/plotsense/plot_chat/function_calls.py b/plotsense/plot_chat/function_calls.py new file mode 100644 index 0000000..3a59e6a --- /dev/null +++ b/plotsense/plot_chat/function_calls.py @@ -0,0 +1,187 @@ +import json +from openai.types.responses import ResponseInputParam, ToolParam, FunctionToolParam +from openai.types.responses.response_input_param import Message +import pandas as pd +from typing import List, Dict, Callable, Optional +from io import StringIO +import matplotlib.pyplot as plt +from matplotlib.figure import Figure +import io +import base64 + +from plotsense.explanations.explanations import explainer +from plotsense.plot_generator.generator import plotgen +from plotsense.visual_suggestion.suggestions import recommender + + +class FunctionCallClient: + """Orchestrates OpenAI function calls for multiple tools automatically.""" + + # Predefined internal tools mapping + TOOL_DEFINITIONS = { + "plotgen": { + "type": "function", + "name": "generate_plot", + "description": "Generate a plot based on suggestions", + "parameters": { + "type": "object", + "properties": { + "df": {"type": "string", "description": "JSON-serialized DataFrame"}, + "suggestion": {"type": "integer", "description": "Index or row identifier of suggestion"} + }, + "required": ["df", "suggestion"], + "additionalProperties": False + } + }, + "explainer": { + "type": "function", + "name": "explain_plot", + "description": "Generate an explanation for a plot", + "parameters": { + "type": "object", + "properties": { + "plot_object": {"type": "string", "description": "Reference to internal Figure object"} + }, + "required": ["plot_object"], + "additionalProperties": False + } + }, + "recommender": { + "type": "function", + "name": "recommend", + "description": "Generate top-N recommended plots or actions", + "parameters": { + "type": "object", + "properties": { + "df": {"type": "string", "description": "JSON-serialized DataFrame"}, + "n": {"type": "integer", "description": "Number of recommendations to return"} + }, + "required": ["df", "n"], + "additionalProperties": False + } + } + } + + # Map tool identifiers to actual Python functions + TOOL_FUNCTION_MAPPING = { + "plotgen": lambda df, suggestion, api_key="", **kwargs: plotgen( + pd.read_json(StringIO(df)), + api_keys={"openai": api_key}, + suggestion=suggestion, suggestions_df=pd.read_json(StringIO(df)), + selected_models = [("openai", "gpt-5"), ("openai", "gpt-4-turbo")], + **kwargs + ), + "explainer": lambda plot_object, api_key="", **kwargs: explainer( + plot_object, + api_keys={"openai": api_key}, + selected_models = [("openai", "gpt-5"), ("openai", "gpt-4-turbo")], + **kwargs + ), + "recommender": lambda df, n=5, api_key="", **kwargs: recommender( + pd.read_json(StringIO(df)), n=n, + api_keys={"openai": api_key}, + selected_models = [("openai", "gpt-5"), ("openai", "gpt-4-turbo")], + **kwargs + ) + } + + def __init__(self, client): + self.client = client + self.tools: list[ToolParam] = [] + self.function_mapping: Dict[str, Callable] = {} + + def register_tools(self, tool_names: List[str]): + """Register tools by their identifier name (plotgen, explainer, etc.)""" + for name in tool_names: + if name not in self.TOOL_DEFINITIONS: + raise ValueError(f"No predefined tool for identifier '{name}'") + tool_def = self.TOOL_DEFINITIONS[name] + if tool_def['name'] in self.function_mapping: + continue # Already registered + + func_tool: FunctionToolParam = { + "name": tool_def["name"], + "description": tool_def.get("description"), + "parameters": tool_def["parameters"], + "type": "function", + "strict": True + } + self.tools.append(func_tool) + self.function_mapping[tool_def['name']] = self.TOOL_FUNCTION_MAPPING[name] + + def handle_user_input( + self, + user_input: str, + instructions: Optional[str] = None + ) -> str: + """Main orchestrator for multi-tool function calls""" + input_list: ResponseInputParam = [ + Message( + role="user", + type="message", + content=[ + {"type": "input_text", "text": user_input} + ] + ) + ] + + calls_remaining = True + final_response_text = "" + conversation_history = input_list.copy() + + while calls_remaining: + # Ask model with all tools + response = self.client.responses.create( + model="gpt-5", + tools=self.tools, + input=conversation_history, + instructions=instructions, + # stream=True, + ) + + calls_remaining = False # will be True if model requests any function call + new_inputs = [] + + # Process function calls + for item in response.output: + if item.type == "function_call": + calls_remaining = True + func_name = item.name + args = json.loads(item.arguments) + + if func_name in self.function_mapping: + print("Args:", args) + # Fix suggestion parameter type + if func_name == "generate_plot" and "suggestion" in args: + if isinstance(args["suggestion"], str) and args["suggestion"].isdigit(): + args["suggestion"] = int(args["suggestion"]) + + result = self.function_mapping[func_name]( + **args, + api_key=self.client.api_key + ) + + # Convert result to JSON-serializable format + if isinstance(result, pd.DataFrame): + result_serializable = result.to_json(orient='split') + elif isinstance(result, Figure): + buf = io.BytesIO() + result.savefig(buf, format='png') + buf.seek(0) + result_serializable = {"image_base64": base64.b64encode(buf.read()).decode('utf-8')} + else: + result_serializable = result + + input_list.append({ + "type": "function_call_output", + "call_id": item.call_id, + "output": json.dumps(result_serializable) + }) + # Append outputs to conversation history so next request has context + conversation_history.extend(new_inputs) + + # If no more function calls, capture final text output + if not calls_remaining: + final_response_text = response.output_text + + return final_response_text diff --git a/plotsense/plot_chat/prompts.py b/plotsense/plot_chat/prompts.py new file mode 100644 index 0000000..0f3d196 --- /dev/null +++ b/plotsense/plot_chat/prompts.py @@ -0,0 +1,51 @@ +from typing import List + +PLOTCHAT_SYSTEM_PROMPT = ''' +Your name is Plotly. You are an intelligent and analytical AI assistant integrated into the PlotChat platform β€” an AI-powered data visualization and analysis assistant built on top of PlotSense. + +Your purpose is to help users explore, visualize, and understand their data easily through natural conversation. You can analyze user input, suggest visualization types, generate plots automatically, and explain their meanings clearly. + +PlotSense supports three main modes of operation: + +1. **Recommender** β€” Analyze a user's dataset or question and recommend suitable visualization types (e.g., scatter plot, bar chart, histogram, box plot, etc.). Explain why those charts fit the data or analysis goal. + +2. **PlotGen** β€” Automatically generate plots from user data and instructions. You extract relevant variables, select the right visualization, and use PlotSense’s `plotgen()` function to render the chart. Be precise about variable selection and chart intent. + +3. **Explainer** β€” Interpret and explain plots that have already been generated. Describe what the visualization shows, highlight patterns, correlations, or outliers, and help the user understand data insights. Provide interpretations that are clear and human-like β€” avoid generic commentary. + +You can handle all types of data (numeric, categorical, text-based, or time series) and work across use cases such as analytics, research, and reporting. + +When unsure about user intent or missing data, ask for clarification instead of guessing. Always explain your reasoning in simple, intuitive language with examples if needed. + +Be professional yet friendly, concise but clear. You aim to make data analysis feel effortless and interactive. + +Additionally, you are a conversational AI and have access to the ongoing chat history within this session. Use this context to make your responses relevant, connected, and aware of prior discussions. +''' + +def get_instructions(user_instructions: List[str]) -> str: + if not user_instructions: + return PLOTCHAT_SYSTEM_PROMPT + + pre_text = "\n\n---\nHere are additional instructions provided by the user:\n" + formatted_instructions = "\n".join( + f"- {instruction.strip()}" for instruction in user_instructions if instruction.strip() + ) + return f"{PLOTCHAT_SYSTEM_PROMPT}{pre_text}{formatted_instructions}" + + +def generate_chat_title_prompt(project_title: str, initial_prompt: str) -> str: + # Default project title may be "Untitled Project" + return f""" +You are a helpful assistant tasked with naming PlotChat conversations. + +Based on the project title and the first user message, generate a short, clear, and descriptive title for the chat session. +Keep it between 3 to 6 words. Do not include punctuation at the end. + +Project Title: +\"\"\"{project_title}\"\"\" + +First message: +\"\"\"{initial_prompt}\"\"\" + +Suggested Chat Title: +""" diff --git a/plotsense/plot_chat/realtime.py b/plotsense/plot_chat/realtime.py new file mode 100644 index 0000000..475e21c --- /dev/null +++ b/plotsense/plot_chat/realtime.py @@ -0,0 +1,29 @@ +import requests + + +class RealtimeClient: + def __init__(self, client): + self.client = client + + def generate_ephemeral_key(self) -> str: + url = "https://api.openai.com/v1/realtime/client_secrets" + headers = { + "Authorization": f"Bearer {self.client.api_key}", + "Content-Type": "application/json", + } + payload = { + "session": { + "type": "realtime", + "model": "gpt-realtime" + } + } + + try: + response = requests.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + print("Ephemeral key generated:", data) + return data.get("value") + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate ephemeral key: {e}") from e + diff --git a/plotsense/plot_chat/streaming.py b/plotsense/plot_chat/streaming.py new file mode 100644 index 0000000..f85226c --- /dev/null +++ b/plotsense/plot_chat/streaming.py @@ -0,0 +1,37 @@ +from openai.types.audio import TranscriptionTextDeltaEvent +from openai.types.responses import ResponseTextDeltaEvent, ResponseTextDoneEvent, ResponseCompletedEvent + + +class ChatStreamWrapper: + def __init__(self, stream): + self._stream = stream + self.item_id = None + self.response_id = None + + def __iter__(self): + for event in self._stream: + if ( + event.type == "response.output_text.delta" + and isinstance(event, ResponseTextDeltaEvent) + ): + if not self.item_id: + self.item_id = event.item_id + yield event.delta + elif ( + event.type == "response.output_text.done" + and isinstance(event, ResponseTextDoneEvent) + ): + self.item_id = event.item_id + elif ( + event.type == "response.completed" + and isinstance(event, ResponseCompletedEvent) + ): + if not self.response_id: + self.response_id = event.response.id + elif ( + event.type == "transcript.text.delta" + and isinstance(event, TranscriptionTextDeltaEvent) + ): + # TranscriptionTextDeltaEvent(delta='To', type='transcript.text.delta', logprobs=None) + yield event.delta + diff --git a/plotsense/plot_generator/__init__.py b/plotsense/plot_generator/__init__.py index 3b22971..ab45af9 100644 --- a/plotsense/plot_generator/__init__.py +++ b/plotsense/plot_generator/__init__.py @@ -1 +1 @@ -from plotsense.plot_generator.generator import plotgen, PlotGenerator \ No newline at end of file +from plotsense.plot_generator.generator import plotgen diff --git a/plotsense/plot_generator/base_generator.py b/plotsense/plot_generator/base_generator.py new file mode 100644 index 0000000..21618b1 --- /dev/null +++ b/plotsense/plot_generator/base_generator.py @@ -0,0 +1,76 @@ +import pandas as pd +from matplotlib.figure import Figure +from typing import Callable, Dict, Optional + +from plotsense.plot_generator.registry import PlotRequirements, PlotTypeRegistry + + +class PlotGenerator: + """ + A class to generate various types of plots based on suggestions. + It uses matplotlib for plotting and can handle both univariate and bivariate cases. + """ + def __init__(self, data, suggestions: Optional[pd.DataFrame] = None): + """ + Initialize with data and plot suggestions. + + Args: + data: DataFrame containing the actual data + suggestions: DataFrame with plot suggestions + """ + if not isinstance(data, pd.DataFrame): + raise TypeError("Data must be a pandas DataFrame") + if data.empty: + raise ValueError("DataFrame is empty") + if not isinstance(suggestions, pd.DataFrame): + raise TypeError("Suggestions must be a pandas DataFrame") + if suggestions.empty: + raise ValueError("Suggestions DataFrame is empty") + if 'plot_type' not in suggestions.columns or 'variables' not in suggestions.columns: + raise ValueError("Suggestions DataFrame must contain 'plot_type' and 'variables' columns") + + self.data = data.copy() + self.suggestions = suggestions + self.registry = PlotTypeRegistry() + self._register_default_plots(self._default_plots) + + @property + def _default_plots(self) -> Dict[str, Callable[..., Figure]]: + """Subclasses override this to define plot type β†’ function mapping.""" + return {} + + def _register_default_plots( + self, plots_to_register: Dict[str, Callable[..., Figure]] + ): + for name, func in plots_to_register.items(): + self.registry.register( + name, + PlotRequirements( + min_variables=1, max_variables=2, numeric_only=True + ), + lambda variables, f=func: f(self.data, variables) + ) + + def generate_plot(self, suggestion_index: int, **kwargs) -> Figure: + """ + Generate a plot based on the suggestion at given index. + + Args: + suggestion_index: Index of the suggestion in dataframe + **kwargs: Additional arguments for the plot + + Returns: + matplotlib Figure object + """ + suggestion = self.suggestions.iloc[suggestion_index] + plot_type = suggestion['plot_type'].lower() + variables = [v.strip() for v in suggestion['variables'].split(',')] + + plot_func = self.registry.get_generator(plot_type) + if not plot_func: + raise ValueError(f"Plot type '{plot_type}' not supported") + + if not self.registry.validate(plot_type, variables, self.data): + raise ValueError(f"Invalid variables for plot '{plot_type}'") + + return plot_func(variables, **kwargs) diff --git a/plotsense/plot_generator/basic_generator.py b/plotsense/plot_generator/basic_generator.py new file mode 100644 index 0000000..7acd6a6 --- /dev/null +++ b/plotsense/plot_generator/basic_generator.py @@ -0,0 +1,29 @@ +from plotsense.plot_generator.base_generator import PlotGenerator +from plotsense.plot_generator.plots.basic.kde import create_kde_plot +from plotsense.plot_generator.plots.basic.barh import create_barh_plot +from plotsense.plot_generator.plots.basic.box import create_box_plot +from plotsense.plot_generator.plots.basic.ecdf import create_ecdf_plot +from plotsense.plot_generator.plots.basic.hist import create_hist_plot +from plotsense.plot_generator.plots.basic.violin import create_violin_plot +from plotsense.plot_generator.plots.basic.bar import create_bar_plot +from plotsense.plot_generator.plots.basic.hexbin import create_hexbin_plot +from plotsense.plot_generator.plots.basic.pie import create_pie_plot +from plotsense.plot_generator.plots.basic.scatter import create_scatter_plot + + +class BasicPlotGenerator(PlotGenerator): + @property + def _default_plots(self): + return { + 'bar': create_bar_plot, + 'barh': create_barh_plot, + 'box': create_box_plot, + 'ecdf': create_ecdf_plot, + 'hexbin': create_hexbin_plot, + 'hist': create_hist_plot, + 'kde': create_kde_plot, + 'pie': create_pie_plot, + 'scatter': create_scatter_plot, + 'violin': create_violin_plot, + } + diff --git a/plotsense/plot_generator/generator.py b/plotsense/plot_generator/generator.py index 54d1c64..721d228 100644 --- a/plotsense/plot_generator/generator.py +++ b/plotsense/plot_generator/generator.py @@ -1,564 +1,132 @@ import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -from typing import List, Dict, Any, Optional, Union - - -class PlotGenerator: - """ - A class to generate various types of plots based on suggestions. - It uses matplotlib for plotting and can handle both univariate and bivariate cases. - """ - def __init__(self, data: pd.DataFrame, suggestions: Optional[pd.DataFrame] = None): - """ - Initialize with data and plot suggestions. - - Args: - data: DataFrame containing the actual data - suggestions: DataFrame with plot suggestions - """ - if not isinstance(data, pd.DataFrame): - raise TypeError("Data must be a pandas DataFrame") - if data.empty: - raise ValueError("DataFrame is empty") - if not isinstance(suggestions, pd.DataFrame): - raise TypeError("Suggestions must be a pandas DataFrame") - if suggestions.empty: - raise ValueError("Suggestions DataFrame is empty") - if 'plot_type' not in suggestions.columns or 'variables' not in suggestions.columns: - raise ValueError("Suggestions DataFrame must contain 'plot_type' and 'variables' columns") - - self.data = data.copy() - self.suggestions = suggestions - self.plot_functions = self._initialize_plot_functions() - - def generate_plot(self, suggestion_index: int, **kwargs) -> plt.Figure: - """ - Generate a plot based on the suggestion at given index. - - Args: - suggestion_index: Index of the suggestion in dataframe - **kwargs: Additional arguments for the plot - - Returns: - matplotlib Figure object - """ - # if suggestion_index < 0 or suggestion_index >= len(self.suggestions): - # raise IndexError("Suggestion index out of range") - if not isinstance(suggestion_index, int): - raise TypeError("Suggestion index must be an integer") - if not isinstance(kwargs, dict): - raise TypeError("Additional arguments must be provided as a dictionary") - if self.suggestions.empty: - raise ValueError("No suggestions available to generate a plot") - if self.data.empty: - raise ValueError("No data available to generate a plot") - if not isinstance(self.suggestions, pd.DataFrame): - raise TypeError("Suggestions must be a pandas DataFrame") - if not isinstance(self.data, pd.DataFrame): - raise TypeError("Data must be a pandas DataFrame") - if self.suggestions.empty: - raise ValueError("Suggestions DataFrame is empty") - if self.data.empty: - raise ValueError("DataFrame is empty") - - suggestion = self.suggestions.iloc[suggestion_index] - plot_type = suggestion['plot_type'].lower() - variables = [v.strip() for v in suggestion['variables'].split(',')] - - if plot_type not in self.plot_functions: - print(f"This version of PlotSense does not support plot type: {plot_type}") - return None - - plot_func = self.plot_functions[plot_type] - return plot_func(variables, **kwargs) - - def _initialize_plot_functions(self) -> Dict[str, callable]: - """Initialize all matplotlib plot functions with their requirements.""" - return { - # Basic plots - 'scatter': self._create_scatter, - 'bar': self._create_bar, - 'barh': self._create_barh, - - # Statistical plots - 'hist': self._create_hist, - 'boxplot': self._create_box, - 'violinplot': self._create_violin, - - # Specialized plots - 'pie': self._create_pie, - 'hexbin': self._create_hexbin - - } - - - # ========== Basic Plot Functions ========== - def _create_scatter(self, variables: List[str], **kwargs) -> plt.Figure: - if len(variables) < 2: - raise ValueError("scatter requires at least 2 variables (x, y)") - fig, ax = plt.subplots() - ax.scatter(self.data[variables[0]], self.data[variables[1]], **kwargs) - self._set_labels(ax, variables) - ax.set_title(f"Scatter: {variables[0]} vs {variables[1]}") - return fig - - def _create_bar(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots(figsize=(10, 6)) - - # Extract label-related kwargs if provided - x_label = kwargs.pop('x_label', None) - y_label = kwargs.pop('y_label', None) - title = kwargs.pop('title', None) - - # Define font sizes - tick_fontsize = kwargs.pop('tick_fontsize', 12) - label_fontsize = kwargs.pop('label_fontsize', 14) - title_fontsize = kwargs.pop('title_fontsize', 16) - - if len(variables) == 1: - # Single variable - show value counts - value_counts = self.data[variables[0]].value_counts().sort_values(ascending=False) - ax.bar(value_counts.index.astype(str), value_counts.values, **kwargs) - ax.set_xlabel(variables[0] if x_label is None else x_label, fontsize=label_fontsize) - ax.set_ylabel('Count' if y_label is None else y_label, fontsize=label_fontsize) - ax.set_title(f"Bar plot of {variables[0]}" if title is None else title, fontsize=title_fontsize) - ax.tick_params(axis='x', labelsize=tick_fontsize) - ax.tick_params(axis='y', labelsize=tick_fontsize) - - - if len(value_counts) > 10: - fig.set_size_inches(max(12, len(value_counts)), 8) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - else: - # First variable is numeric, second is categorical - grouped = self.data.groupby(variables[1])[variables[0]].mean().sort_values(ascending=False) - ax.bar(grouped.index.astype(str), grouped.values, **kwargs) - ax.set_xlabel(variables[1] if x_label is None else x_label, fontsize=label_fontsize) - ax.set_ylabel(f"{variables[0]}" if y_label is None else y_label, fontsize=label_fontsize) - ax.set_title(f"{variables[0]} by {variables[1]}" if title is None else title, fontsize=title_fontsize) - ax.tick_params(axis='x', labelsize=tick_fontsize) - ax.tick_params(axis='y', labelsize=tick_fontsize) - - if len(grouped) > 10: - fig.set_size_inches(max(12, len(grouped)), 8) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - return fig - - def _create_barh(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots(figsize=(10, 6)) - - - # Extract label-related kwargs if provided - x_label = kwargs.pop('x_label', None) - y_label = kwargs.pop('y_label', None) - title = kwargs.pop('title', None) - - # Define font sizes - tick_fontsize = kwargs.pop('tick_fontsize', 12) - label_fontsize = kwargs.pop('label_fontsize', 14) - title_fontsize = kwargs.pop('title_fontsize', 16) - - if len(variables) == 1: - # Single variable - show value counts - value_counts = self.data[variables[0]].value_counts() - ax.barh(value_counts.index.astype(str), value_counts.values, **kwargs) - ax.set_xlabel(variables[0] if x_label is None else x_label, fontsize=label_fontsize) - ax.set_ylabel('Count' if y_label is None else y_label, fontsize=label_fontsize) - ax.set_title(f"Bar plot of {variables[0]}" if title is None else title, fontsize=title_fontsize) - ax.tick_params(axis='x', labelsize=tick_fontsize) - ax.tick_params(axis='y', labelsize=tick_fontsize) - - - if len(value_counts) > 10: - fig.set_size_inches(max(12, len(value_counts)), 8) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - else: - # First variable is numeric, second is categorical - grouped = self.data.groupby(variables[1])[variables[0]].mean() - ax.barh(grouped.index.astype(str), grouped.values, **kwargs) - ax.set_xlabel(variables[1] if x_label is None else x_label, fontsize=label_fontsize) - ax.set_ylabel(f"{variables[0]}" if y_label is None else y_label, fontsize=label_fontsize) - ax.set_title(f"{variables[0]} by {variables[1]}" if title is None else title, fontsize=title_fontsize) - ax.tick_params(axis='x', labelsize=tick_fontsize) - ax.tick_params(axis='y', labelsize=tick_fontsize) - - if len(grouped) > 10: - fig.set_size_inches(max(12, len(grouped)), 8) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - return fig - - - # ========== Statistical Plot Functions ========== - def _create_hist(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots() - ax.hist(self.data[variables[0]], **kwargs) - ax.set_xlabel(variables[0]) - ax.set_ylabel('Frequency') - ax.set_title(f"Histogram of {variables[0]}") - return fig - - def _create_box(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots(figsize=(10,6)) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - ax.boxplot(self.data[variables[0]], **kwargs) - ax.set_ylabel(variables[0]) - ax.set_title(f"Box plot of {variables[0]}") - - return fig - - def _create_violin(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots(figsize=(10,6)) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - ax.violinplot(self.data[variables[0]], **kwargs) - ax.set_ylabel(variables[0]) - ax.set_title(f"Violin plot of {variables[0]}") - return fig - - - # ========== Specialized Plot Functions ========== - def _create_pie(self, variables: List[str], **kwargs) -> plt.Figure: - value_counts = self.data[variables[0]].value_counts() - fig, ax = plt.subplots() - ax.pie(value_counts, labels=value_counts.index, autopct='%1.1f%%', **kwargs) - ax.set_title(f"Pie chart of {variables[0]}") - return fig - - def _create_hexbin(self, variables: List[str], **kwargs) -> plt.Figure: - fig, ax = plt.subplots() - ax.hexbin(self.data[variables[0]], self.data[variables[1]], **kwargs) - self._set_labels(ax, variables) - ax.set_title(f"Hexbin: {variables[0]} vs {variables[1]}") - return fig - - - - # ========== Helper Methods ========== - def _set_labels(self, ax, variables: List[str]): - """Set labels for x and y axes based on variables.""" - if len(variables) > 0: - ax.set_xlabel(variables[0]) - if len(variables) > 1: - ax.set_ylabel(variables[1]) - -class SmartPlotGenerator(PlotGenerator): - def _create_box(self, variables: List[str], **kwargs) -> plt.Figure: - """Enhanced boxplot that handles both univariate and bivariate cases with NaN handling.""" - fig, ax = plt.subplots(figsize=(10, 6)) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - if len(variables) == 1: - # Univariate case - single numerical variable - data = self.data[variables[0]].dropna() # Remove NaN values - if len(data) == 0: - raise ValueError(f"No valid data remaining after dropping NaN values for {variables[0]}") - ax.boxplot(data, **kwargs) - ax.set_ylabel(variables[0]) - ax.set_title(f"Box plot of {variables[0]}") - elif len(variables) >= 2: - # Bivariate case - numerical vs categorical - numerical_var = variables[0] - categorical_var = variables[1] - - # Clean data - remove rows where either variable is NaN - clean_data = self.data[[numerical_var, categorical_var]].dropna() - if len(clean_data) == 0: - raise ValueError(f"No valid data remaining after cleaning {numerical_var} and {categorical_var}") - - # Group data by categorical variable - grouped_data = [clean_data[clean_data[categorical_var] == cat][numerical_var] - for cat in clean_data[categorical_var].unique()] - - # Filter out empty groups - grouped_data = [group for group in grouped_data if len(group) > 0] - if not grouped_data: - raise ValueError("No valid groups remaining after filtering") - - ax.boxplot(grouped_data, **kwargs) - ax.set_xticklabels(clean_data[categorical_var].unique()) - ax.set_xlabel(categorical_var) - ax.set_ylabel(numerical_var) - ax.set_title(f"Box plot of {numerical_var} by {categorical_var}") - else: - raise ValueError("Box plot requires at least 1 variable") - - return fig - - def _create_violin(self, variables: List[str], **kwargs) -> plt.Figure: - """Enhanced violin plot that handles both univariate and bivariate cases with NaN handling.""" - fig, ax = plt.subplots(figsize=(10,6)) - plt.setp(ax.get_xticklabels(), rotation=90, ha='center') - - if len(variables) == 1: - # Univariate case - single numerical variable - data = self.data[variables[0]].dropna() # Remove NaN values - if len(data) == 0: - raise ValueError(f"No valid data remaining after dropping NaN values for {variables[0]}") - ax.violinplot(data, **kwargs) - ax.set_ylabel(variables[0]) - ax.set_title(f"Violin plot of {variables[0]}") - elif len(variables) >= 2: - # Bivariate case - numerical vs categorical - numerical_var = variables[0] - categorical_var = variables[1] - - # Clean data - remove rows where either variable is NaN - clean_data = self.data[[numerical_var, categorical_var]].dropna() - if len(clean_data) == 0: - raise ValueError(f"No valid data remaining after cleaning {numerical_var} and {categorical_var}") - - # Group data by categorical variable - grouped_data = [clean_data[clean_data[categorical_var] == cat][numerical_var] - for cat in clean_data[categorical_var].unique()] - - # Filter out empty groups - grouped_data = [group for group in grouped_data if len(group) > 0] - if not grouped_data: - raise ValueError("No valid groups remaining after filtering") - - ax.violinplot(grouped_data, **kwargs) - ax.set_xticks(np.arange(1, len(grouped_data)+1)) - ax.set_xticklabels(clean_data[categorical_var].unique()) - ax.set_xlabel(categorical_var) - ax.set_ylabel(numerical_var) - ax.set_title(f"Violin plot of {numerical_var} by {categorical_var}") - else: - raise ValueError("Violin plot requires at least 1 variable") - - return fig - - def _create_hist(self, variables: List[str], **kwargs) -> plt.Figure: - """Enhanced histogram that can handle grouping by a second variable.""" - fig, ax = plt.subplots(figsize=(12, 8)) - - if len(variables) == 1: - # Simple histogram - data = self.data[variables[0]].dropna() - if len(data) == 0: - raise ValueError(f"No valid data remaining for {variables[0]}") - - ax.hist(data, **kwargs) - ax.set_xlabel(variables[0]) - ax.set_ylabel('Frequency') - ax.set_title(f"Histogram of {variables[0]}") - elif len(variables) >= 2: - # Grouped histogram - numerical_var = variables[0] - categorical_var = variables[1] - - # Clean data - clean_data = self.data[[numerical_var, categorical_var]].dropna() - if len(clean_data) == 0: - raise ValueError(f"No valid data remaining after cleaning {numerical_var} and {categorical_var}") - - # Get unique categories - categories = clean_data[categorical_var].unique() - - # Set default colors if not provided - if 'color' not in kwargs and 'colors' not in kwargs: - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] - else: - colors = [kwargs.pop('color')] * len(categories) if 'color' in kwargs else kwargs.pop('colors') - - # Plot each group - for i, cat in enumerate(categories): - ax.hist(clean_data[clean_data[categorical_var] == cat][numerical_var], - alpha=0.5, - label=str(cat), - color=colors[i % len(colors)], - **kwargs) - - ax.set_xlabel(numerical_var) - ax.set_ylabel('Frequency') - ax.set_title(f"Histogram of {numerical_var} by {categorical_var}") - ax.legend() - else: - raise ValueError("Histogram requires at least 1 variable") - - return fig - - def _create_scatter(self, variables: List[str], - size_scale: float = 100.0, - **kwargs) -> plt.Figure: - """ - Create a scatter plot with optional color and size dimensions. - - Parameters: - ----------- - variables : List[str] - - 2 variables: x, y - - 3 variables: x, y, color - - 4 variables: x, y, color, size - size_scale : float - Scaling factor for bubble sizes (default: 100) - - Returns: - -------- - matplotlib.figure.Figure - """ - if len(variables) < 2: - raise ValueError("Scatter requires at least 2 variables (x, y)") - if len(variables) > 4: - raise ValueError("Scatter supports maximum 4 variables (x, y, color, size)") - - # Check data types - for var in variables[:2]: # x and y must be numeric - if not np.issubdtype(self.data[var].dtype, np.number): - raise ValueError(f"Variable '{var}' must be numeric") - - fig, ax = plt.subplots() - scatter_params = { - 'x': self.data[variables[0]], - 'y': self.data[variables[1]], - } - - # Handle color (3rd variable) - if len(variables) >= 3: - color_data = self.data[variables[2]] - if pd.api.types.is_numeric_dtype(color_data): - # For numeric color data, use continuous colormap - scatter_params['c'] = color_data - kwargs.setdefault('cmap', 'viridis') - else: - # For categorical data, convert to numeric codes - scatter_params['c'] = pd.factorize(color_data)[0] - kwargs.setdefault('cmap', 'tab10') - - # Handle size (4th variable) - if len(variables) == 4: - size_data = self.data[variables[3]] - if not pd.api.types.is_numeric_dtype(size_data): - raise ValueError(f"Size variable '{variables[3]}' must be numeric") - - # Normalize and scale sizes - sizes = np.abs(size_data) # Ensure positive - sizes = (sizes - sizes.min()) / (sizes.max() - sizes.min() + 1e-8) * size_scale - scatter_params['s'] = sizes - - # Apply any additional kwargs - scatter_params.update(kwargs) - - scatter = ax.scatter(**scatter_params) - - # Set labels and title - self._set_labels(ax, variables[:2]) # Assuming this sets x and y labels - title = f"Scatter: {variables[0]} vs {variables[1]}" - if len(variables) >= 3: - title += f" (colored by {variables[2]})" - # Add colorbar for continuous data - if pd.api.types.is_numeric_dtype(self.data[variables[2]]): - fig.colorbar(scatter, ax=ax, label=variables[2]) - if len(variables) == 4: - title += f" (sized by {variables[3]})" - ax.set_title(title) - - return fig - - +from matplotlib.figure import Figure +from typing import Optional, Union, Callable +from plotsense.plot_generator.basic_generator import BasicPlotGenerator +from plotsense.plot_generator.smart_generator import SmartPlotGenerator +from plotsense.plot_generator.registry import PlotRequirements # Global instance of the plot generator _plot_generator_instance = None +_GENERATOR_MAP = { + "basic": BasicPlotGenerator, + "smart": SmartPlotGenerator +} + def plotgen( df: pd.DataFrame, suggestion: Union[int, pd.Series], suggestions_df: Optional[pd.DataFrame] = None, + generator: str = "basic", + plot_function: Optional[Callable] = None, + plot_type: Optional[str] = None, + plot_requirements: Optional[PlotRequirements] = None, **plot_kwargs -) -> plt.Figure: +) -> Figure: """ Generate a plot based on visualization suggestions. - + + Users can also register a custom plot function temporarily by providing: + plot_function: callable(df, variables, **kwargs) -> Figure + plot_type: string name for the custom plot + plot_requirements: optional PlotRequirements object + Args: df: Input DataFrame containing the data to plot suggestion: Either an integer index or a pandas Series containing the suggestion row suggestions_df: DataFrame containing visualization suggestions (required if suggestion is an index) + generator: String identifier for generator to use ("basic" or "smart") + plot_function: Optional custom plot function + plot_type: Name of the custom plot + plot_requirements: Optional PlotRequirements for the custom plot **plot_kwargs: Additional arguments to pass to the plot function - + Returns: matplotlib.Figure: The generated figure - - Example: - # Using index (requires suggestions_df) - fig = plotgen(df, 7, suggestions_df=recommendations) - - # Using direct row access with additional plot arguments - fig = plotgen(df, recommendations.iloc[7], bins=30, color='red') - - # Using specific variable names - fig = plotgen(df, recommendations.iloc[7], x='age', y='fare') """ global _plot_generator_instance - - # Handle case where suggestion is a row from recommendations - if isinstance(suggestion, pd.Series): - # Create a temporary single-row suggestions DataFrame - temp_df = pd.DataFrame([suggestion]) - # Initialize the plot generator with this single suggestion - _plot_generator_instance = SmartPlotGenerator(df, temp_df) - - - # Get the variables from the suggestion - variables = [v.strip() for v in suggestion['variables'].split(',')] - plot_type = suggestion['plot_type'].lower() - - # Handle x, y, z arguments if provided - if 'x' in plot_kwargs: - variables[0] = plot_kwargs.pop('x') - if 'y' in plot_kwargs and len(variables) > 1: - variables[1] = plot_kwargs.pop('y') - if 'z' in plot_kwargs and len(variables) > 2: - variables[2] = plot_kwargs.pop('z') - # Create a new suggestion with updated variables - updated_suggestion = suggestion.copy() - updated_suggestion['variables'] = ','.join(variables) - temp_df = pd.DataFrame([updated_suggestion]) - _plot_generator_instance.suggestions = temp_df - - # Generate the plot - return _plot_generator_instance.generate_plot(0, **plot_kwargs) - - # Handle case where suggestion is an index - elif isinstance(suggestion, int): - if suggestions_df is None: + # Determine generator class from string + generator_class = _GENERATOR_MAP.get(generator.lower(), BasicPlotGenerator) + + # Initialize generator instance if needed + if _plot_generator_instance is None or not isinstance( + _plot_generator_instance, generator_class + ): + # Handle case where suggestion is a row from recommendations + if isinstance(suggestion, pd.Series): + temp_df = pd.DataFrame([suggestion]) + _plot_generator_instance = generator_class(df, temp_df) + # Handle case where suggestion is an index + elif isinstance(suggestion, int): + if suggestions_df is None: + raise ValueError("suggestions_df must be provided when using an index") + _plot_generator_instance = generator_class(df, suggestions_df) + else: + # Update data if it changed + if not _plot_generator_instance.data.equals(df): + _plot_generator_instance.data = df + + # If user provides a custom plot function, register it temporarily + if plot_function is not None: + if not plot_type: + raise ValueError("plot_type name must be provided when registering a custom plot") + if plot_requirements is None: + plot_requirements = PlotRequirements(min_variables=1, max_variables=2, numeric_only=True) + + pg = _plot_generator_instance + if pg is None: + raise RuntimeError("Plot generator instance is not initialized") + + pg.registry.register( + plot_type, + plot_requirements, + lambda variables, + f=plot_function: f(pg.data, variables, **plot_kwargs) + ) + + # Extract suggestion row + if isinstance(suggestion, pd.Series): + suggestion_row = suggestion.copy() + else: + s_df = suggestions_df + if s_df is None: + raise ValueError("suggestions_df must be provided when using an index") + suggestion_row = s_df.iloc[suggestion].copy() + + # Override variables if x/y/z provided + variables = [v.strip() for v in str(suggestion_row['variables']).split(',')] + if 'x' in plot_kwargs: + variables[0] = plot_kwargs.pop('x') + if 'y' in plot_kwargs and len(variables) > 1: + variables[1] = plot_kwargs.pop('y') + if 'z' in plot_kwargs and len(variables) > 2: + variables[2] = plot_kwargs.pop('z') + + suggestion_row['variables'] = ','.join(variables) + + # Update the generator's suggestion DataFrame if using index + if isinstance(suggestion, int): + s_df = suggestions_df + if s_df is None: raise ValueError("suggestions_df must be provided when using an index") - - # Initialize the plot generator if it doesn't exist - if _plot_generator_instance is None: - _plot_generator_instance = SmartPlotGenerator(df, suggestions_df) - else: - # Update the data if the generator exists but the data changed - if not _plot_generator_instance.data.equals(df): - _plot_generator_instance.data = df - - # Get the variables from the suggestion - suggestion_row = suggestions_df.iloc[suggestion] - variables = [v.strip() for v in suggestion_row['variables'].split(',')] - plot_type = suggestion_row['plot_type'].lower() - - # Handle x, y, z arguments if provided - if 'x' in plot_kwargs: - variables[0] = plot_kwargs.pop('x') - if 'y' in plot_kwargs and len(variables) > 1: - variables[1] = plot_kwargs.pop('y') - if 'z' in plot_kwargs and len(variables) > 2: - variables[2] = plot_kwargs.pop('z') - - # Create a new suggestion with updated variables - updated_suggestion = suggestion_row.copy() - updated_suggestion['variables'] = ','.join(variables) - suggestions_df.iloc[suggestion] = updated_suggestion + s_df.iloc[suggestion] = suggestion_row _plot_generator_instance.suggestions = suggestions_df - - # Generate the plot - return _plot_generator_instance.generate_plot(suggestion, **plot_kwargs) - # else: - # raise TypeError("suggestion must be either an integer index or a pandas Series") + else: + _plot_generator_instance.suggestions = pd.DataFrame([suggestion_row]) + + # Determine plot_type to use + active_plot_type = plot_type or str(suggestion_row['plot_type']).lower() + + # Generate the plot + plot_func = _plot_generator_instance.registry.get_generator(active_plot_type) + if not plot_func: + raise ValueError(f"Plot type '{active_plot_type}' not supported") + + if not _plot_generator_instance.registry.validate(active_plot_type, variables, _plot_generator_instance.data): + raise ValueError(f"Invalid variables for plot '{active_plot_type}'") + + return plot_func(variables, **plot_kwargs) +# fig = plotgen(df, 0, suggestions_df, generator="smart") diff --git a/plotsense/plot_generator/helpers.py b/plotsense/plot_generator/helpers.py new file mode 100644 index 0000000..722dd50 --- /dev/null +++ b/plotsense/plot_generator/helpers.py @@ -0,0 +1,10 @@ +from typing import List + + +def set_labels(ax, variables: List[str]): + """Set labels for x and y axes based on variables.""" + if len(variables) > 0: + ax.set_xlabel(variables[0]) + if len(variables) > 1: + ax.set_ylabel(variables[1]) + diff --git a/plotsense/plot_generator/plots/__init__.py b/plotsense/plot_generator/plots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plotsense/plot_generator/plots/basic/bar.py b/plotsense/plot_generator/plots/basic/bar.py new file mode 100644 index 0000000..66a9f74 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/bar.py @@ -0,0 +1,46 @@ +from typing import List +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np + +def create_bar_plot(df: pd.DataFrame, variables: List[str], **kwargs): + fig, ax = plt.subplots(figsize=(10, 6)) + + # Extract label-related kwargs if provided + x_label = kwargs.pop('x_label', None) + y_label = kwargs.pop('y_label', None) + title = kwargs.pop('title', None) + + # Define font sizes + tick_fontsize = kwargs.pop('tick_fontsize', 12) + label_fontsize = kwargs.pop('label_fontsize', 14) + title_fontsize = kwargs.pop('title_fontsize', 16) + + if len(variables) == 1: + value_counts = df[variables[0]].value_counts().sort_values(ascending=False) + ax.bar( + value_counts.index.astype(str), + np.asarray(value_counts.values, **kwargs) + ) + ax.set_xlabel(variables[0] if x_label is None else x_label, fontsize=label_fontsize) + ax.set_ylabel('Count' if y_label is None else y_label, fontsize=label_fontsize) + ax.set_title(f"Bar plot of {variables[0]}" if title is None else title, fontsize=title_fontsize) + ax.tick_params(axis='x', labelsize=tick_fontsize) + ax.tick_params(axis='y', labelsize=tick_fontsize) + if len(value_counts) > 10: + fig.set_size_inches(max(12, len(value_counts)), 8) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + else: + grouped = df.groupby(variables[1])[variables[0]].mean() + grouped = pd.Series(grouped).sort_values(ascending=False) + ax.bar(grouped.index.astype(str), np.asarray(grouped.values), **kwargs) + ax.set_xlabel(variables[1] if x_label is None else x_label, fontsize=label_fontsize) + ax.set_ylabel(f"{variables[0]}" if y_label is None else y_label, fontsize=label_fontsize) + ax.set_title(f"{variables[0]} by {variables[1]}" if title is None else title, fontsize=title_fontsize) + ax.tick_params(axis='x', labelsize=tick_fontsize) + ax.tick_params(axis='y', labelsize=tick_fontsize) + if len(grouped) > 10: + fig.set_size_inches(max(12, len(grouped)), 8) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + return fig + diff --git a/plotsense/plot_generator/plots/basic/barh.py b/plotsense/plot_generator/plots/basic/barh.py new file mode 100644 index 0000000..cbf5100 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/barh.py @@ -0,0 +1,48 @@ +from typing import List +from matplotlib.figure import Figure +import matplotlib.pyplot as plt + +def create_barh_plot(df, variables: List[str], **kwargs) -> Figure: + fig, ax = plt.subplots(figsize=(10, 6)) + + + # Extract label-related kwargs if provided + x_label = kwargs.pop('x_label', None) + y_label = kwargs.pop('y_label', None) + title = kwargs.pop('title', None) + + # Define font sizes + tick_fontsize = kwargs.pop('tick_fontsize', 12) + label_fontsize = kwargs.pop('label_fontsize', 14) + title_fontsize = kwargs.pop('title_fontsize', 16) + + if len(variables) == 1: + # Single variable - show value counts + value_counts = df[variables[0]].value_counts() + ax.barh(value_counts.index.astype(str), value_counts.values, **kwargs) + ax.set_xlabel(variables[0] if x_label is None else x_label, fontsize=label_fontsize) + ax.set_ylabel('Count' if y_label is None else y_label, fontsize=label_fontsize) + ax.set_title(f"Bar plot of {variables[0]}" if title is None else title, fontsize=title_fontsize) + ax.tick_params(axis='x', labelsize=tick_fontsize) + ax.tick_params(axis='y', labelsize=tick_fontsize) + + if len(value_counts) > 10: + fig.set_size_inches(max(12, len(value_counts)), 8) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + else: + # First variable is numeric, second is categorical + grouped = df.groupby(variables[1])[variables[0]].mean() + ax.barh(grouped.index.astype(str), grouped.values, **kwargs) + ax.set_xlabel(variables[1] if x_label is None else x_label, fontsize=label_fontsize) + ax.set_ylabel(f"{variables[0]}" if y_label is None else y_label, fontsize=label_fontsize) + ax.set_title(f"{variables[0]} by {variables[1]}" if title is None else title, fontsize=title_fontsize) + ax.tick_params(axis='x', labelsize=tick_fontsize) + ax.tick_params(axis='y', labelsize=tick_fontsize) + + if len(grouped) > 10: + fig.set_size_inches(max(12, len(grouped)), 8) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + return fig + diff --git a/plotsense/plot_generator/plots/basic/box.py b/plotsense/plot_generator/plots/basic/box.py new file mode 100644 index 0000000..e9d1733 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/box.py @@ -0,0 +1,15 @@ +from typing import List +from matplotlib.figure import Figure +import matplotlib.pyplot as plt + + +def create_box_plot(df, variables: List[str], **kwargs) -> Figure: + fig, ax = plt.subplots(figsize=(10,6)) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + ax.boxplot(df[variables[0]], **kwargs) + ax.set_ylabel(variables[0]) + ax.set_title(f"Box plot of {variables[0]}") + + return fig + diff --git a/plotsense/plot_generator/plots/basic/ecdf.py b/plotsense/plot_generator/plots/basic/ecdf.py new file mode 100644 index 0000000..aa20f1f --- /dev/null +++ b/plotsense/plot_generator/plots/basic/ecdf.py @@ -0,0 +1,21 @@ +import matplotlib.pyplot as plt +import numpy as np + +def create_ecdf_plot(df, variables, **kwargs): + """Empirical Cumulative Distribution Function (ECDF) plot.""" + var = variables[0] + data = df[var].dropna() + if data.empty: + raise ValueError(f"No valid data for {var}") + + sorted_data = np.sort(data) + n = len(sorted_data) + y = np.arange(1, n + 1) / n + + fig, ax = plt.subplots(figsize=(8, 5)) + ax.plot(sorted_data, y, marker='.', linestyle='none', **kwargs) + ax.set_title(f"ECDF of {var}") + ax.set_xlabel(var) + ax.set_ylabel("Cumulative Probability") + return fig + diff --git a/plotsense/plot_generator/plots/basic/hexbin.py b/plotsense/plot_generator/plots/basic/hexbin.py new file mode 100644 index 0000000..f574d71 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/hexbin.py @@ -0,0 +1,11 @@ +from typing import List +import matplotlib.pyplot as plt +from matplotlib.figure import Figure + +def create_hexbin_plot(df, variables: List[str], **kwargs) -> Figure: + fig, ax = plt.subplots() + ax.hexbin(df[variables[0]], df[variables[1]], **kwargs) + df._set_labels(ax, variables) + ax.set_title(f"Hexbin: {variables[0]} vs {variables[1]}") + return fig + diff --git a/plotsense/plot_generator/plots/basic/hist.py b/plotsense/plot_generator/plots/basic/hist.py new file mode 100644 index 0000000..c5ea05a --- /dev/null +++ b/plotsense/plot_generator/plots/basic/hist.py @@ -0,0 +1,12 @@ +from typing import List +from matplotlib.figure import Figure +import matplotlib.pyplot as plt + +def create_hist_plot(df, variables: List[str], **kwargs) -> Figure: + fig, ax = plt.subplots() + ax.hist(df[variables[0]], **kwargs) + ax.set_xlabel(variables[0]) + ax.set_ylabel('Frequency') + ax.set_title(f"Histogram of {variables[0]}") + return fig + diff --git a/plotsense/plot_generator/plots/basic/kde.py b/plotsense/plot_generator/plots/basic/kde.py new file mode 100644 index 0000000..88d3c16 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/kde.py @@ -0,0 +1,16 @@ +import matplotlib.pyplot as plt + +def create_kde_plot(df, variables, **kwargs): + """Kernel Density Estimation plot for a numeric variable.""" + var = variables[0] + data = df[var].dropna() + if data.empty: + raise ValueError(f"No valid data for {var}") + + fig, ax = plt.subplots(figsize=(8, 5)) + data.plot(kind='kde', ax=ax, **kwargs) + ax.set_title(f"KDE Plot of {var}") + ax.set_xlabel(var) + ax.set_ylabel("Density") + return fig + diff --git a/plotsense/plot_generator/plots/basic/pie.py b/plotsense/plot_generator/plots/basic/pie.py new file mode 100644 index 0000000..b89e64c --- /dev/null +++ b/plotsense/plot_generator/plots/basic/pie.py @@ -0,0 +1,13 @@ +from typing import List +import matplotlib.pyplot as plt +from matplotlib.figure import Figure + + +def create_pie_plot(df, variables: List[str], **kwargs) -> Figure: + value_counts = df[variables[0]].value_counts() + fig, ax = plt.subplots() + ax.pie(value_counts, labels=value_counts.index, autopct='%1.1f%%', **kwargs) + ax.set_title(f"Pie chart of {variables[0]}") + return fig + + diff --git a/plotsense/plot_generator/plots/basic/scatter.py b/plotsense/plot_generator/plots/basic/scatter.py new file mode 100644 index 0000000..6799dd8 --- /dev/null +++ b/plotsense/plot_generator/plots/basic/scatter.py @@ -0,0 +1,14 @@ +from typing import List +import matplotlib.pyplot as plt + +def create_scatter_plot(df, variables: List[str], **kwargs): + """Scatter plot: requires at least 2 variables (x, y).""" + if len(variables) < 2: + raise ValueError("scatter requires at least 2 variables (x, y)") + fig, ax = plt.subplots() + ax.scatter(df[variables[0]], df[variables[1]], **kwargs) + ax.set_xlabel(variables[0]) + ax.set_ylabel(variables[1]) + ax.set_title(f"Scatter: {variables[0]} vs {variables[1]}") + return fig + diff --git a/plotsense/plot_generator/plots/basic/violin.py b/plotsense/plot_generator/plots/basic/violin.py new file mode 100644 index 0000000..08cfb1a --- /dev/null +++ b/plotsense/plot_generator/plots/basic/violin.py @@ -0,0 +1,14 @@ +from typing import List +from matplotlib.figure import Figure +import matplotlib.pyplot as plt + + +def create_violin_plot(df, variables: List[str], **kwargs) -> Figure: + fig, ax = plt.subplots(figsize=(10,6)) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + ax.violinplot(df[variables[0]], **kwargs) + ax.set_ylabel(variables[0]) + ax.set_title(f"Violin plot of {variables[0]}") + return fig + diff --git a/plotsense/plot_generator/plots/smart/box.py b/plotsense/plot_generator/plots/smart/box.py new file mode 100644 index 0000000..e45e72c --- /dev/null +++ b/plotsense/plot_generator/plots/smart/box.py @@ -0,0 +1,49 @@ +from typing import List +import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.figure import Figure + +def create_box_plot(df: pd.DataFrame, variables: List[str], **kwargs) -> Figure: + """Enhanced boxplot that handles both univariate and bivariate cases with NaN handling.""" + fig, ax = plt.subplots(figsize=(10, 6)) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + if len(variables) == 1: + # Univariate case - single numerical variable + data = df[variables[0]].dropna() # Remove NaN values + if data.empty: + raise ValueError(f"No valid data remaining after dropping NaN values for {variables[0]}") + ax.boxplot(data, **kwargs) + ax.set_ylabel(variables[0]) + ax.set_title(f"Box plot of {variables[0]}") + elif len(variables) >= 2: + # Bivariate case - numerical vs categorical + numerical_var = variables[0] + categorical_var = variables[1] + + # Clean data - remove rows where either variable is NaN + clean_data = df[[numerical_var, categorical_var]].dropna() + if clean_data.empty: + raise ValueError(f"No valid data remaining after cleaning {numerical_var} and {categorical_var}") + + # Group data by categorical variable + grouped_data = [ + clean_data[clean_data[categorical_var] == cat][numerical_var] + for cat in clean_data[categorical_var].unique() + ] + + # Filter out empty groups + grouped_data = [group for group in grouped_data if len(group) > 0] + if not grouped_data: + raise ValueError("No valid groups remaining after filtering") + + ax.boxplot(grouped_data, **kwargs) + ax.set_xticklabels(clean_data[categorical_var].unique()) + ax.set_xlabel(categorical_var) + ax.set_ylabel(numerical_var) + ax.set_title(f"Box plot of {numerical_var} by {categorical_var}") + else: + raise ValueError("Box plot requires at least 1 variable") + + return fig + diff --git a/plotsense/plot_generator/plots/smart/ecdf.py b/plotsense/plot_generator/plots/smart/ecdf.py new file mode 100644 index 0000000..09cfe5e --- /dev/null +++ b/plotsense/plot_generator/plots/smart/ecdf.py @@ -0,0 +1,43 @@ +import matplotlib.pyplot as plt +import numpy as np +from typing import List +import pandas as pd +from matplotlib.figure import Figure + +from plotsense.plot_generator.helpers import set_labels + +def create_ecdf_plot(df: pd.DataFrame, variables: List[str], **kwargs) -> Figure: + """ + Enhanced ECDF plot that handles univariate and grouped data with NaN handling. + """ + if len(variables) == 0: + raise ValueError("ECDF plot requires at least 1 variable") + + var = variables[0] + fig, ax = plt.subplots(figsize=(8, 5)) + + if len(variables) == 1: + data = df[var].dropna() + if data.empty: + raise ValueError(f"No valid data for {var}") + sorted_data = np.sort(data) + n = len(sorted_data) + y = np.arange(1, n + 1) / n + ax.plot(sorted_data, y, marker='.', linestyle='none', **kwargs) + else: + # Grouped ECDF + group_var = variables[1] + clean_data = df[[var, group_var]].dropna() + if clean_data.empty: + raise ValueError(f"No valid data after cleaning {var} and {group_var}") + for cat, group in clean_data.groupby(group_var): + sorted_data = np.sort(group[var]) + n = len(sorted_data) + y = np.arange(1, n + 1) / n + ax.plot(sorted_data, y, marker='.', linestyle='none', label=str(cat), **kwargs) + ax.legend(title=group_var) + + ax.set_title(f"ECDF of {var}") + set_labels(ax, variables[:2]) + ax.set_ylabel("Cumulative Probability") + return fig diff --git a/plotsense/plot_generator/plots/smart/histogram.py b/plotsense/plot_generator/plots/smart/histogram.py new file mode 100644 index 0000000..d134320 --- /dev/null +++ b/plotsense/plot_generator/plots/smart/histogram.py @@ -0,0 +1,56 @@ +from typing import List +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np + +def create_histogram_plot(df: pd.DataFrame, variables: List[str], **kwargs) -> plt.Figure: + """Enhanced histogram that can handle grouping by a second variable.""" + fig, ax = plt.subplots(figsize=(12, 8)) + + if len(variables) == 1: + # Simple histogram + data = df[variables[0]].dropna() + if data.empty: + raise ValueError(f"No valid data remaining for {variables[0]}") + ax.hist(data, **kwargs) + ax.set_xlabel(variables[0]) + ax.set_ylabel("Frequency") + ax.set_title(f"Histogram of {variables[0]}") + elif len(variables) >= 2: + # Grouped histogram + num, cat = variables[0], variables[1] + + # Clean data - remove rows where either variable is NaN + clean_data = df[[num, cat]].dropna() + if clean_data.empty: + raise ValueError(f"No valid data remaining after cleaning {num} and {cat}") + + # Get unique categories + categories = clean_data[cat].unique() + + # Set default colors if not provided + if 'color' in kwargs: + colors = [kwargs['color']] * len(categories) + elif 'colors' in kwargs: + colors = kwargs['colors'] + else: + colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + + # Plot each group + for i, c in enumerate(categories): + ax.hist( + clean_data[clean_data[cat] == c][num], + alpha=0.5, + label=str(c), + color=colors[i % len(colors)], + **kwargs + ) + ax.set_xlabel(num) + ax.set_ylabel("Frequency") + ax.set_title(f"Histogram of {num} by {cat}") + ax.legend() + else: + raise ValueError("Histogram requires at least one variable") + + return fig + diff --git a/plotsense/plot_generator/plots/smart/kde.py b/plotsense/plot_generator/plots/smart/kde.py new file mode 100644 index 0000000..742c361 --- /dev/null +++ b/plotsense/plot_generator/plots/smart/kde.py @@ -0,0 +1,37 @@ +from matplotlib.figure import Figure +import matplotlib.pyplot as plt +import seaborn as sns # optional, for nicer KDE plots +from typing import List +import pandas as pd + +from plotsense.plot_generator.helpers import set_labels + +def create_kde_plot(df: pd.DataFrame, variables: List[str], **kwargs) -> Figure: + """ + Enhanced KDE plot that handles univariate and grouped data with NaN handling. + """ + if len(variables) == 0: + raise ValueError("KDE plot requires at least 1 variable") + + var = variables[0] + data: pd.DataFrame = pd.DataFrame(df[[var]].dropna()) + if data.empty: + raise ValueError(f"No valid data for {var}") + + fig, ax = plt.subplots(figsize=(8, 5)) + + if len(variables) == 1: + # Univariate + sns.kdeplot(data=data, ax=ax, **kwargs) + else: + # Bivariate / group-by + group_var = variables[1] + clean_data: pd.DataFrame = pd.DataFrame(df[[var, group_var]].dropna()) + if clean_data.empty: + raise ValueError(f"No valid data after cleaning {var} and {group_var}") + sns.kdeplot(data=clean_data, x=var, hue=group_var, ax=ax, **kwargs) + + ax.set_title(f"KDE Plot of {var}") + set_labels(ax, variables[:2]) # x + y labels + return fig + diff --git a/plotsense/plot_generator/plots/smart/scatter.py b/plotsense/plot_generator/plots/smart/scatter.py new file mode 100644 index 0000000..6b081cf --- /dev/null +++ b/plotsense/plot_generator/plots/smart/scatter.py @@ -0,0 +1,79 @@ +from typing import List +from matplotlib.figure import Figure +from matplotlib import pyplot as plt +import pandas as pd +import numpy as np + +def create_scatter_plot( + df, variables: List[str], + size_scale: float = 100.0, **kwargs +) -> Figure: + """ + Create a scatter plot with optional color and size dimensions. + + Parameters: + ----------- + variables : List[str] + - 2 variables: x, y + - 3 variables: x, y, color + - 4 variables: x, y, color, size + size_scale : float + Scaling factor for bubble sizes (default: 100) + + Returns: + -------- + matplotlib.figure.Figure + """ + if len(variables) < 2: + raise ValueError("Scatter requires at least 2 variables (x, y)") + if len(variables) > 4: + raise ValueError("Scatter supports up to 4 variables (x, y, color, size)") + + # Check data types + for var in variables[:2]: + if not np.issubdtype(df[var].dtype, np.number): + raise ValueError(f"Variable '{var}' must be numeric") + + fig, ax = plt.subplots() + scatter_params = {"x": df[variables[0]], "y": df[variables[1]]} + + # Handle color (3rd variable) + if len(variables) >= 3: + color_data = df[variables[2]] + if pd.api.types.is_numeric_dtype(color_data): + # For numeric color data, use continuous colormap + scatter_params["c"] = color_data + kwargs.setdefault("cmap", "viridis") + else: + # For categorical data, convert to numeric codes + scatter_params["c"] = pd.factorize(color_data)[0] + kwargs.setdefault("cmap", "tab10") + + # Handle size (4th variable) + if len(variables) == 4: + size_data = df[variables[3]] + if not pd.api.types.is_numeric_dtype(size_data): + raise ValueError(f"Size variable '{variables[3]}' must be numeric") + + # Normalize and scale sizes + sizes = np.abs(size_data) + sizes = (sizes - sizes.min()) / (sizes.max() - sizes.min() + 1e-8) * size_scale + scatter_params["s"] = sizes + + # Apply any additional kwargs + scatter_params.update(kwargs) + scatter = ax.scatter(**scatter_params) + + # Set labels and title + ax.set_xlabel(variables[0]) + ax.set_ylabel(variables[1]) + title = f"Scatter: {variables[0]} vs {variables[1]}" + if len(variables) >= 3: + title += f" (colored by {variables[2]})" + if pd.api.types.is_numeric_dtype(df[variables[2]]): + fig.colorbar(scatter, ax=ax, label=variables[2]) + if len(variables) == 4: + title += f" (sized by {variables[3]})" + ax.set_title(title) + return fig + diff --git a/plotsense/plot_generator/plots/smart/violin.py b/plotsense/plot_generator/plots/smart/violin.py new file mode 100644 index 0000000..1b7b41f --- /dev/null +++ b/plotsense/plot_generator/plots/smart/violin.py @@ -0,0 +1,48 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +def create_violin_plot(df: pd.DataFrame, variables, **kwargs): + """Enhanced violin plot that handles both univariate and bivariate cases with NaN handling.""" + fig, ax = plt.subplots(figsize=(10, 6)) + plt.setp(ax.get_xticklabels(), rotation=90, ha='center') + + if len(variables) == 1: + # Univariate case - single numerical variable + data = df[variables[0]].dropna() + if data.empty: + raise ValueError(f"No valid data remaining after dropping NaN values for {variables[0]}") + ax.violinplot(data, **kwargs) + ax.set_ylabel(variables[0]) + ax.set_title(f"Violin plot of {variables[0]}") + elif len(variables) >= 2: + # Bivariate case - numerical vs categorical + num, cat = variables[0], variables[1] + + # Clean data - remove rows where either variable is NaN + clean_data = df[[num, cat]].dropna() + if clean_data.empty: + raise ValueError(f"No valid data remaining after cleaning {num} and {cat}") + + # Group data by categorical variable + grouped_data = [ + clean_data[clean_data[cat] == c][num] + for c in clean_data[cat].unique() + ] + + # Filter out empty groups + grouped_data = [g for g in grouped_data if len(g) > 0] + if not grouped_data: + raise ValueError("No valid groups remaining after filtering") + + ax.violinplot(grouped_data, **kwargs) + ax.set_xticks(np.arange(1, len(grouped_data) + 1)) + ax.set_xticklabels(clean_data[cat].unique()) + ax.set_xlabel(cat) + ax.set_ylabel(num) + ax.set_title(f"Violin plot of {num} by {cat}") + else: + raise ValueError("Violin plot requires at least one variable") + + return fig + diff --git a/plotsense/plot_generator/registry.py b/plotsense/plot_generator/registry.py new file mode 100644 index 0000000..f84bf69 --- /dev/null +++ b/plotsense/plot_generator/registry.py @@ -0,0 +1,49 @@ +from typing import Callable, Dict, List, Optional, Any +import pandas as pd +from dataclasses import dataclass + +@dataclass +class PlotRequirements(): + """Define constraints for a plot type.""" + min_variables: int = 1 # minimum required variables + max_variables: int = 2 # maximum supported variables + numeric_only: bool = True # whether data must be numeric + +class PlotTypeRegistry: + """Central registry for all supported plot types.""" + + def __init__(self): + self._registry: Dict[str, Dict[str, Any]] = {} + + def register(self, name: str, requirements: PlotRequirements, generator: Callable): + """Register a plot type and its generation function.""" + self._registry[name.lower()] = { + "requirements": requirements, + "generator": generator + } + + def get_generator(self, name: str) -> Optional[Callable]: + """Retrieve generator function by plot type name.""" + entry = self._registry.get(name.lower()) + return entry["generator"] if entry else None + + def validate(self, name: str, variables: List[str], df: pd.DataFrame) -> bool: + """Check if given data fits the plot type requirements.""" + entry = self._registry.get(name.lower()) + if not entry: + return False + + req = entry["requirements"] + if not (req.min_variables <= len(variables) <= req.max_variables): + return False + + if req.numeric_only: + for var in variables: + if not pd.api.types.is_numeric_dtype(df[var]): + return False + return True + + def list_plot_types(self) -> List[str]: + """List all registered plot types.""" + return list(self._registry.keys()) + diff --git a/plotsense/plot_generator/smart_generator.py b/plotsense/plot_generator/smart_generator.py new file mode 100644 index 0000000..eb83d19 --- /dev/null +++ b/plotsense/plot_generator/smart_generator.py @@ -0,0 +1,25 @@ +from plotsense.plot_generator.base_generator import PlotGenerator +from plotsense.plot_generator.plots.smart.box import create_box_plot +from plotsense.plot_generator.plots.smart.ecdf import create_ecdf_plot +from plotsense.plot_generator.plots.smart.histogram import create_histogram_plot +from plotsense.plot_generator.plots.smart.kde import create_kde_plot +from plotsense.plot_generator.plots.smart.scatter import create_scatter_plot +from plotsense.plot_generator.plots.smart.violin import create_violin_plot + + +class SmartPlotGenerator(PlotGenerator): + """ + An enhanced PlotGenerator with advanced plotting capabilities. + """ + + @property + def _default_plots(self): + return { + 'box': create_box_plot, + 'ecdf': create_ecdf_plot, + 'histogram': create_histogram_plot, + 'kde': create_kde_plot, + 'scatter': create_scatter_plot, + 'violin': create_violin_plot, + } + diff --git a/plotsense/visual_suggestion/__init__.py b/plotsense/visual_suggestion/__init__.py index 9eba1bb..9c9fe01 100644 --- a/plotsense/visual_suggestion/__init__.py +++ b/plotsense/visual_suggestion/__init__.py @@ -1 +1,2 @@ -from plotsense.visual_suggestion.suggestions import recommender, VisualizationRecommender +from plotsense.visual_suggestion.suggestions import recommender +from plotsense.visual_suggestion.recommender import VisualizationRecommender diff --git a/plotsense/visual_suggestion/recommender/__init__.py b/plotsense/visual_suggestion/recommender/__init__.py new file mode 100644 index 0000000..853b52a --- /dev/null +++ b/plotsense/visual_suggestion/recommender/__init__.py @@ -0,0 +1 @@ +from .visualization_recommender import VisualizationRecommender diff --git a/plotsense/visual_suggestion/recommender/dataframe_analyzer.py b/plotsense/visual_suggestion/recommender/dataframe_analyzer.py new file mode 100644 index 0000000..c7a893f --- /dev/null +++ b/plotsense/visual_suggestion/recommender/dataframe_analyzer.py @@ -0,0 +1,69 @@ +from typing import List +import pandas as pd +import numpy as np + + +class DataFrameAnalyzer: + def __init__(self, df: pd.DataFrame) -> None: + self.df = df + + def describe_dataframe(self) -> str: + num_cols = len(self.df.columns) + sample_size = min(3, len(self.df)) + desc: List[str] = [] + + # --- Basic Metadata --- + desc.append(f"DataFrame Shape: {self.df.shape}") + desc.append(f"Columns ({num_cols}): {', '.join(self.df.columns)}") + desc.append("\nColumn Details:") + + # --- Column-Level Analysis --- + for col in self.df.columns: + # Determine semantic type (more granular than dtype) + if pd.api.types.is_datetime64_dtype(self.df[col]): + col_type = "datetime" + elif pd.api.types.is_numeric_dtype(self.df[col]): + col_type = "numerical" + elif self.df[col].nunique() / len(self.df[col]) < 0.05: # Low cardinality + col_type = "categorical" + else: + col_type = "text/other" + + # Basic info + unique_count = self.df[col].nunique() + sample_values = self.df[col].dropna().head(sample_size).tolist() + desc.append( + f"- {col}: {col_type} ({unique_count} unique values), sample: {sample_values}" + ) + + # Add stats for numerical/datetime + if col_type == "numerical": + desc.append( + f" Stats: min={self.df[col].min()}, max={self.df[col].max()}, " + f"mean={self.df[col].mean():.2f}, missing={self.df[col].isna().sum()}" + ) + elif col_type == "datetime": + desc.append( + f" Range: {self.df[col].min()} to {self.df[col].max()}, " + f"missing={self.df[col].isna().sum()}" + ) + + # --- Relationship Analysis --- + numerical_cols = self.df.select_dtypes(include=np.number).columns.tolist() + if len(numerical_cols) > 1: + desc.append("\nNumerical Variable Correlations (Pearson):") + corr = self.df[numerical_cols].corr().round(2) + desc.append(str(corr)) + + # Categorical-numerical potential groupings + categorical_cols = [ + col for col in self.df.columns + if self.df[col].nunique() / len(self.df[col]) < 0.05 + ] + if categorical_cols and numerical_cols: + desc.append("\nPotential Groupings (categorical vs numerical):") + desc.append(f" - Could group by: {categorical_cols}") + desc.append(f" - To analyze: {numerical_cols}") + + return "\n".join(desc) + diff --git a/plotsense/visual_suggestion/recommender/ensemble_scorer.py b/plotsense/visual_suggestion/recommender/ensemble_scorer.py new file mode 100644 index 0000000..efa7002 --- /dev/null +++ b/plotsense/visual_suggestion/recommender/ensemble_scorer.py @@ -0,0 +1,127 @@ +from typing import Dict, List, Tuple +import pandas as pd +from collections import defaultdict +from pprint import pprint +import textwrap + +from plotsense.visual_suggestion.recommender.dataframe_analyzer import DataFrameAnalyzer + + +class EnsembleScorer: + def __init__( + self, df: pd.DataFrame, available_models: List[Tuple[str, str]], + debug: bool = False + ): + self.df = df + self.debug = debug + self.available_models = available_models + + def apply_ensemble_scoring( + self, all_recommendations: Dict[str, List[Dict]], + weights: Dict[str, float] + ) -> pd.DataFrame: + output_columns = ['plot_type', 'variables', 'ensemble_score', 'model_agreement', 'source_models'] + + if self.debug: + print("\n[DEBUG] Applying ensemble scoring with weights:") + pprint(weights) + + recommendation_weights = defaultdict(float) + recommendation_details = {} + + for model, recs in all_recommendations.items(): + model_weight = weights.get(model, 0) + if model_weight <= 0: + continue + + for rec in recs: + # Create a consistent key for the recommendation + variables = rec['variables'] + if isinstance(variables, str): + variables = [v.strip() for v in variables.split(',')] + + # Filter variables to only those in the DataFrame + valid_vars = [var for var in variables if var in self.df.columns] + if not valid_vars: + if self.debug: + print(f"\n[DEBUG] Skipping recommendation from {model} with invalid variables: {variables}") + continue + + var_key = ', '.join(sorted(valid_vars)) + rec_key = (rec['plot_type'].lower(), var_key) + + model_score = rec.get('score', 1.0) + total_weight = model_weight * model_score + recommendation_weights[rec_key] += total_weight + + if rec_key not in recommendation_details: + recommendation_details[rec_key] = { + 'plot_type': rec['plot_type'], + 'variables': var_key, + 'source_models': [model], + 'raw_weight': total_weight + } + else: + recommendation_details[rec_key]['source_models'].append(model) + recommendation_details[rec_key]['raw_weight'] += total_weight + + if not recommendation_details: + if self.debug: + print("\n[DEBUG] No valid recommendations after filtering") + return pd.DataFrame(columns=output_columns) + + results = pd.DataFrame(list(recommendation_details.values())) + + if self.debug: + print("\n[DEBUG] Recommendations before scoring:") + print(results) + + if not results.empty: + total_possible = sum(weights.values()) + results['ensemble_score'] = results['raw_weight'] / total_possible + results['ensemble_score'] = results['ensemble_score'].round(2) + results['model_agreement'] = results['source_models'].apply(len) + results = results.sort_values(['ensemble_score', 'model_agreement'], ascending=[False, False]).reset_index(drop=True) + return results[output_columns] + + return pd.DataFrame(columns=output_columns) + + def supplement_recommendations(self, existing: pd.DataFrame, target: int) -> pd.DataFrame: + """Generate additional recommendations if we didn't get enough initially.""" + if len(existing) >= target: + return existing.head(target) + + needed = target - len(existing) + analyzer = DataFrameAnalyzer(self.df) + df_description = analyzer.describe_dataframe() + + # Try to get more recommendations from the best-performing model + best_model = existing.iloc[0]['source_models'][0] if not existing.empty else self.available_models[0] + + prompt = textwrap.dedent(f""" + You already recommended these visualizations: + {existing[['plot_type', 'variables']].to_string()} + + Please recommend {needed} ADDITIONAL different visualizations for: + {df_description} + + Use the same format but ensure they're distinct from the above. + """) + + try: + response = self._query_llm(prompt, best_model) + new_recs = self._parse_recommendations(response, f"{best_model}-supplement") + + # Combine with existing + combined = pd.concat([existing, pd.DataFrame(new_recs)], ignore_index=True) + combined = combined.drop_duplicates(subset=['plot_type', 'variables']) + + if self.debug: + print(f"\n[DEBUG] Supplemented with {len(new_recs)} new recommendations") + + return combined.head(target) + except Exception as e: + if self.debug: + print(f"\n[WARNING] Couldn't supplement recommendations: {str(e)}") + return existing.head(target) # Return what we have + diff --git a/plotsense/visual_suggestion/recommender/prompt_builder.py b/plotsense/visual_suggestion/recommender/prompt_builder.py new file mode 100644 index 0000000..c1eb973 --- /dev/null +++ b/plotsense/visual_suggestion/recommender/prompt_builder.py @@ -0,0 +1,93 @@ +import textwrap + + +class PromptBuilder: + def __init__(self, n_to_request: int): + self.n_to_request = n_to_request + + def build_prompt(self, df_description: str) -> str: + return textwrap.dedent(f""" + You are a data visualization expert analyzing this dataset: + + {df_description} + + Recommend {self.n_to_request} insightful visualizations using matplotlib's plotting functions. + For each suggestion, follow this exact format: + + Plot Type: + Variables: + Rationale: <1-2 sentences explaining why this visualization is useful> + --- + + CRITICAL VARIABLE ORDERING RULES: + 1. If a suggestion includes both numerical and categorical variables, NUMERICAL VARIABLES MUST COME FIRST. + - Correct: "income, gender" + - Incorrect: "gender, income" + 2. For plots requiring two numerical variables (e.g., scatter), order by analysis priority (dependent variable first). + 3. For single-variable plots, use natural order (e.g., "age" for a histogram). + + GENERAL RULES FOR ALL PLOT TYPES: + 1. Ensure the plot type is a valid matplotlib function + 2. The plot type must be appropriate for the variables' data types + 3. The number of variables must match what the plot type requires + 4. Variables must exist in the dataset + 5. Never combine incompatible variables + 6. Always specify complete variable sets + 7. Ensure plot type names are in lowercase and match matplotlib's naming conventions eg hist for histogram, bar for barplot + 8. Ensure the common plot types requirements are met including the data types + + COMMON PLOT TYPE REQUIREMENTS (non-exhaustive): + 1. bar: 1 categorical (x) + 1 numerical (y) β†’ Variables: [numerical], [categorical] + 2. scatter: Exactly 2 numerical β†’ Variables: [independent], [dependent] + 3. hist: Exactly 1 numerical β†’ Variables: [numerical] + 4. boxplot: 1 numerical OR 1 numerical + 1 categorical β†’ Variables: [numerical], [categorical] (if grouped) + 5. pie: Exactly 1 categorical β†’ Variables: [categorical] + 6. line: 1 numerical (y) OR 1 numerical (y) + 1 datetime (x) β†’ Variables: [y], [x] (if applicable) + 7. heatmap: 2 categorical + 1 numerical OR correlation matrix β†’ Variables: [numerical], [categorical], [categorical] + 8. violinplot: Same as boxplot + 9. hexbin: Exactly 2 numerical variables + 10. pairplot: 2+ numerical variables + 11. jointplot: Exactly 2 numerical variables + 12. contour: 2 numerical variables for grid + 1 for values + 13. quiver: 2 numerical variables for grid + 2 for vectors + 14. imshow: 2D array of numerical values + 15. errorbar: 1 numerical (x) + 1 numerical (y) + error values + 16. stackplot: 1 numerical (x) + multiple numerical (y) + 17. stem: 1 numerical (x) + 1 numerical (y) + 18. fill_between: 1 numerical (x) + 2 numerical (y) + 19. pcolormesh: 2D grid of numerical values + 20. polar: Angular and radial coordinates + + If suggesting a plot not listed above, ensure: + - The function exists in matplotlib + - Variable types and counts are explicitly compatible + - The rationale clearly explains the insight provided + + Additional Requirements: + 1. For specialized plots (like quiver, contour), ensure all required components are specified + 2. Consider the statistical properties and relationships of the variables + 3. Suggest plots that would reveal meaningful insights about the data + 4. Include both common and advanced plots when appropriate + + Example CORRECT suggestions (NUMERICAL FIRST): + Plot Type: boxplot + Variables: income, gender + Rationale: Compares income distribution across genders + --- + Plot Type: scatter + Variables: age, income + Rationale: Shows relationship between age and income + --- + Plot Type: bar + Variables: revenue, product_category + Rationale: Compares revenue across product categories + + Example INCORRECT suggestions (REJECT THESE): + Plot Type: boxplot + Variables: gender, income # WRONG - categorical listed first + --- + Plot Type: scatter + Variables: price, weight # WRONG - no clear priority order + Rationale: Should specify independent/dependent variable order + """) + diff --git a/plotsense/visual_suggestion/recommender/response_parser.py b/plotsense/visual_suggestion/recommender/response_parser.py new file mode 100644 index 0000000..e2629d5 --- /dev/null +++ b/plotsense/visual_suggestion/recommender/response_parser.py @@ -0,0 +1,98 @@ +from typing import Dict, List +import pandas as pd +import warnings + + +class ResponseParser: + def __init__(self, df: pd.DataFrame, debug: bool = False): + self.df = df + self.debug = debug + + def parse_recommendations(self, response: str, model: str) -> List[Dict]: + """Parse the LLM response into structured recommendations""" + recommendations = [] + + # Split response into recommendation blocks + blocks = [b.strip() for b in response.split('---') if b.strip()] + + if self.debug: + print(f"\n[DEBUG] Parsing {len(blocks)} blocks from {model}") + + for block in blocks: + lines = [line.strip() for line in block.split('\n') if line.strip()] + if not lines: + continue + + try: + rec = {'source_model': model} + for line in lines: + if line.lower().startswith('plot type:'): + rec['plot_type'] = line.split(':', 1)[1].strip().lower() + elif line.lower().startswith('variables:'): + raw_vars = line.split(':', 1)[1].strip() + # Filter variables to only those that exist in DataFrame + variables = [ + v.strip() for v in raw_vars.split(',') if v.strip() in self.df.columns + ] + rec['variables'] = ', '.join([ + var for var in variables if var in self.df.columns + ]) + #rec['variables'] = self._reorder_variables(', '.join(variables)) # Keep original order for now + + if 'plot_type' in rec and 'variables' in rec and rec['variables']: + recommendations.append(rec) + except Exception as e: + warnings.warn(f"Failed to parse recommendation from {model}: {str(e)}") + continue + + return recommendations + + def validate_variable_order(self, recommendations: pd.DataFrame) -> pd.DataFrame: + """ + Validate and correct the order of variables in recommendations, + ensuring numerical variables come first. + + Args: + recommendations: DataFrame of visualization recommendations + + Returns: + DataFrame with corrected variable order + """ + def _reorder_variables(row): + # Split variables + variables = [var.strip() for var in row['variables'].split(',')] + + # Identify numerical and non-numerical variables + numerical_vars = [ + var for var in variables + if pd.api.types.is_numeric_dtype(self.df[var]) + ] + + date_vars = [ + var for var in variables + if pd.api.types.is_datetime64_any_dtype(self.df[var]) + ] + + non_numerical_vars = [ + var for var in variables + if var not in numerical_vars and var not in date_vars + ] + + # Combine with numerical variables first + corrected_vars = date_vars + numerical_vars + non_numerical_vars + + # Update the row with corrected variable order + row['variables'] = ', '.join(corrected_vars) + return row + + # Apply reordering + corrected_recommendations = recommendations.apply(_reorder_variables, axis=1) + + if self.debug: + print("\n[DEBUG] Variable Order Validation:") + for orig, corrected in zip(recommendations['variables'], corrected_recommendations['variables']): + if orig != corrected: + print(f" Corrected: {orig} β†’ {corrected}") + + return corrected_recommendations + diff --git a/plotsense/visual_suggestion/recommender/visualization_recommender.py b/plotsense/visual_suggestion/recommender/visualization_recommender.py new file mode 100644 index 0000000..ec1f09f --- /dev/null +++ b/plotsense/visual_suggestion/recommender/visualization_recommender.py @@ -0,0 +1,177 @@ +import pandas as pd +from pprint import pprint +from typing import Dict, List, Optional, Tuple + +from plotsense.core.ai_interface import AIModelInterface +from plotsense.core.enums.strategy import StrategyName +from plotsense.core.providers.provider_manager import ProviderManager +from plotsense.visual_suggestion.recommender.dataframe_analyzer import DataFrameAnalyzer +from plotsense.visual_suggestion.recommender.ensemble_scorer import EnsembleScorer +from plotsense.visual_suggestion.recommender.prompt_builder import PromptBuilder +from plotsense.visual_suggestion.recommender.response_parser import ResponseParser + + +class VisualizationRecommender: + + def __init__( + self, + api_keys: Optional[Dict[str, str]], + strategy: StrategyName, + selected_models: Optional[List[Tuple[str, str]]], + timeout: int, + interactive: bool, + debug: bool, + ): + """ + Initialize VisualizationRecommender with API keys and configuration. + + Args: + api_keys: Optional dictionary of API keys. If not provided, + keys will be loaded from environment variables. + timeout: Timeout in seconds for API requests + interactive: Whether to prompt for missing API keys + debug: Enable debug output + """ + self.timeout = timeout + self.interactive = interactive + self.debug = debug + self.strategy_name = strategy + + selected_providers = {p for p, _ in (selected_models or [])} + + self.manager = ProviderManager( + api_keys=api_keys or {}, + interactive=interactive, + restrict_to=list(selected_providers) if selected_providers else None + ) + self.ai_interface = AIModelInterface(self.manager, timeout=self.timeout) + + all_models = self.manager.list_all_models() + self.available_models = [ + (provider, model) + for provider, models in all_models.items() + for model in models + ] + + if not self.available_models: + raise ValueError( + "No available models detected β€” check API keys or selection input." + ) + + # initialize strategy instance + self.strategy = self.ai_interface._init_strategy( + self.strategy_name, self.available_models + ) + + self.df = None + # model_weights will be lazily obtained from AIModelInterface if not provided + self.model_weights = {} + + if self.debug: + print("\n[DEBUG] Initialization Complete") + print(f"Available models: {self.available_models}") + print(f"Model weights: {self.model_weights}") + + def set_dataframe(self, df: pd.DataFrame): + """Set the DataFrame to analyze and provide debug info""" + self.df = df + if self.debug: + print("\n[DEBUG] DataFrame Info:") + print(f"Shape: {df.shape}") + print("Columns:", df.columns.tolist()) + print("\nSample data:") + print(df.head(2)) + + def recommend_visualizations( + self, n: int = 5, custom_weights: Optional[Dict[str, float]] = None + ) -> pd.DataFrame: + """ + Generate visualization recommendations using weighted ensemble approach. + + Args: + n: Number of recommendations to return (default: 3) + custom_weights: Optional dictionary to override default model weights + + Returns: + pd.DataFrame: Recommended visualizations with ensemble scores + + Raises: + ValueError: If no DataFrame is set or no models are available + """ + """Generate visualization recommendations using weighted ensemble approach.""" + self.n_to_request = max(n, 5) + + if self.df is None: + raise ValueError("No DataFrame set. Call set_dataframe() first.") + + if not self.available_models: + raise ValueError("No available models detected") + + if self.debug: + print("\n[DEBUG] Starting recommendation process") + print(f"Using models: {self.available_models}") + + # Use custom weights if provided, otherwise try self.model_weights then ai_interface weights + if custom_weights: + weights = custom_weights + elif self.model_weights: + weights = self.model_weights + else: + # Defer to AIModelInterface for default weights (keeps compatibility with provider-manager) + weights = self.ai_interface.get_model_weights() + + # Get recommendations from all models in parallel via AIModelInterface + analyzer = DataFrameAnalyzer(self.df) + df_description = analyzer.describe_dataframe() + prompt = PromptBuilder(self.n_to_request).build_prompt(df_description) + + if self.debug: + print("\n[DEBUG] Prompt being sent to models:") + print(prompt) + + # Expecting ai_interface.query_all_models to return dict { "provider:model": "raw text" } + all_recommendations = self.ai_interface.query_all_models( + prompt, self.debug + ) + + if self.debug: + print("\n[DEBUG] Raw recommendations from models:") + pprint(all_recommendations) + + # Parse model responses into structured recommendation lists + parser = ResponseParser(self.df, debug=self.debug) + parsed_recs = { + model: parser.parse_recommendations(response, model) + for model, response in all_recommendations.items() + } + + if self.debug: + print("\n[DEBUG] Applying ensemble scoring") + + scorer = EnsembleScorer( + self.df, debug=self.debug, + available_models=self.available_models + ) + # Use weights determined above (which respects custom_weights) + ensemble_df = scorer.apply_ensemble_scoring(parsed_recs, weights) + + final_df = pd.DataFrame() + # Validate and correct variable order + if not ensemble_df.empty: + final_df = parser.validate_variable_order(ensemble_df) + + # If we don't have enough results, try to supplement (mirror original behavior) + if len(final_df) < n: + if self.debug: + print(f"\n[DEBUG] Only got {len(final_df)} recommendations, trying to supplement") + # Use the same ensemble_df context when supplementing, so the scorer/parser can access source_models + supplemented = scorer.supplement_recommendations(ensemble_df, n) + return supplemented + + if self.debug: + print("\n[DEBUG] Ensemble results before filtering:") + print(ensemble_df) + + # Return the validated & ordered results (top-n) + return ensemble_df.head(n) + diff --git a/plotsense/visual_suggestion/suggestions.py b/plotsense/visual_suggestion/suggestions.py index 7225fac..62863fa 100644 --- a/plotsense/visual_suggestion/suggestions.py +++ b/plotsense/visual_suggestion/suggestions.py @@ -1,609 +1,12 @@ -import os -from typing import Dict, List, Optional, Tuple, Callable -from collections import defaultdict +from typing import Dict, List, Optional, Tuple from dotenv import load_dotenv import pandas as pd -import numpy as np -import warnings -import concurrent.futures -from concurrent.futures import ThreadPoolExecutor -import textwrap -import builtins -from pprint import pprint -from groq import Groq +from plotsense.core.enums.strategy import StrategyName +from plotsense.visual_suggestion.recommender.visualization_recommender import VisualizationRecommender -load_dotenv() - -class VisualizationRecommender: - DEFAULT_MODELS = { - 'groq': [ - ('llama-3.3-70b-versatile', 0.5), # (model_name, weight) - ('llama-3.1-8b-instant', 0.5), - ('llama-3.3-70b-versatile', 0.5) - ], - # Add other providers here - } - - def __init__(self, api_keys: Optional[Dict[str, str]] = None, timeout: int = 30, interactive: bool = True, debug: bool = False): - """ - Initialize VisualizationRecommender with API keys and configuration. - - Args: - api_keys: Optional dictionary of API keys. If not provided, - keys will be loaded from environment variables. - timeout: Timeout in seconds for API requests - interactive: Whether to prompt for missing API keys - debug: Enable debug output - """ - self.interactive = interactive - self.debug = debug - api_keys = api_keys or {} - self.api_keys = { - 'groq': os.getenv('GROQ_API_KEY') - # Add other services here - } - - self.timeout = timeout - self.clients = {} - self.available_models = [] - self.df = None - self.model_weights = {} - self.n_to_request = 5 - - self.api_keys.update(api_keys) - - self._validate_keys() - self._initialize_clients() - self._detect_available_models() - self._initialize_model_weights() - - - if self.debug: - print("\n[DEBUG] Initialization Complete") - print(f"Available models: {self.available_models}") - print(f"Model weights: {self.model_weights}") - if hasattr(self, 'clients'): - print(f"Clients initialized: {bool(self.clients)}") - - def _validate_keys(self): - """Validate that required API keys are present""" - service_links = { - 'groq': 'πŸ‘‰ https://console.groq.com/keys πŸ‘ˆ' - } - - for service in ['groq']: - if not self.api_keys.get(service): - if self.interactive: - try: - link = service_links.get(service, f"the {service.upper()} website") - message = ( - f"Enter {service.upper()} API key (get it at {link}): " - ) - self.api_keys[service] = builtins.input(message).strip() - if not self.api_keys[service]: - raise ValueError(f"{service.upper()} API key is required") - except (EOFError, OSError): - # Handle cases where input is not available - raise ValueError(f"{service.upper()} API key is required (get it at {service_links.get(service)})") - else: - raise ValueError( - f"{service.upper()} API key is required. " - f"Set it in the environment or pass it as an argument. " - f"You can get it at {service_links.get(service)}" - ) - - def _initialize_clients(self): - """Initialize API clients""" - self.clients = {} - if self.api_keys.get('groq'): - try: - self.clients['groq'] = Groq(api_key=self.api_keys['groq']) - except ImportError: - warnings.warn("Groq Python client not installed. pip install groq") - - def _detect_available_models(self): - self.available_models = [] - for provider, client in self.clients.items(): - if client and provider in self.DEFAULT_MODELS: - # For now we'll assume all DEFAULT_MODELS are available - # In a real implementation, you might want to check which models are actually available - self.available_models.extend([m[0] for m in self.DEFAULT_MODELS[provider]]) - - if self.debug: - print(f"[DEBUG] Detected available models: {self.available_models}") - - def _initialize_model_weights(self): - total_weight = 0 - self.model_weights = {} - - # Only include weights for available models - for provider in self.DEFAULT_MODELS: - for model, weight in self.DEFAULT_MODELS[provider]: - if model in self.available_models: - self.model_weights[model] = weight - total_weight += weight - - # Normalize weights to sum to 1 - if total_weight > 0: - for model in self.model_weights: - self.model_weights[model] /= total_weight - - if self.debug: - print(f"[DEBUG] Model weights: {self.model_weights}") - - def set_dataframe(self, df: pd.DataFrame): - """Set the DataFrame to analyze and provide debug info""" - self.df = df - if self.debug: - print("\n[DEBUG] DataFrame Info:") - print(f"Shape: {df.shape}") - print("Columns:", df.columns.tolist()) - print("\nSample data:") - print(df.head(2)) - - def recommend_visualizations(self, n: int = 5, custom_weights: Optional[Dict[str, float]] = None) -> pd.DataFrame: - """ - Generate visualization recommendations using weighted ensemble approach. - - Args: - n: Number of recommendations to return (default: 3) - custom_weights: Optional dictionary to override default model weights - - Returns: - pd.DataFrame: Recommended visualizations with ensemble scores - - Raises: - ValueError: If no DataFrame is set or no models are available - """ - """Generate visualization recommendations using weighted ensemble approach.""" - self.n_to_request = max(n, 5) - - if self.df is None: - raise ValueError("No DataFrame set. Call set_dataframe() first.") - - if not self.available_models: - raise ValueError("No available models detected") - - if self.debug: - print("\n[DEBUG] Starting recommendation process") - print(f"Using models: {self.available_models}") - - # Use custom weights if provided, otherwise use defaults - weights = custom_weights if custom_weights else self.model_weights - - # Get recommendations from all models in parallel - all_recommendations = self._get_all_recommendations() - - if self.debug: - print("\n[DEBUG] Raw recommendations from models:") - pprint(all_recommendations) - - # Apply weighted ensemble scoring - ensemble_results = self._apply_ensemble_scoring(all_recommendations, weights) - - # Validate and correct variable order - if not ensemble_results.empty: - ensemble_results = self._validate_variable_order(ensemble_results) - - # If we don't have enough results, try to supplement - if len(ensemble_results) < n: - if self.debug: - print(f"\n[DEBUG] Only got {len(ensemble_results)} recommendations, trying to supplement") - return self._supplement_recommendations(ensemble_results, n) - - if self.debug: - print("\n[DEBUG] Ensemble results before filtering:") - print(ensemble_results) - - return ensemble_results.head(n) - - - def _supplement_recommendations(self, existing: pd.DataFrame, target: int) -> pd.DataFrame: - """Generate additional recommendations if we didn't get enough initially.""" - if len(existing) >= target: - return existing.head(target) - - needed = target - len(existing) - df_description = self._describe_dataframe() - - # Try to get more recommendations from the best-performing model - best_model = existing.iloc[0]['source_models'][0] if not existing.empty else self.available_models[0] - - prompt = textwrap.dedent(f""" - You already recommended these visualizations: - {existing[['plot_type', 'variables']].to_string()} - - Please recommend {needed} ADDITIONAL different visualizations for: - {df_description} - - Use the same format but ensure they're distinct from the above. - """) - - try: - response = self._query_llm(prompt, best_model) - new_recs = self._parse_recommendations(response, f"{best_model}-supplement") - - # Combine with existing - combined = pd.concat([existing, pd.DataFrame(new_recs)], ignore_index=True) - combined = combined.drop_duplicates(subset=['plot_type', 'variables']) - - if self.debug: - print(f"\n[DEBUG] Supplemented with {len(new_recs)} new recommendations") - - return combined.head(target) - except Exception as e: - if self.debug: - print(f"\n[WARNING] Couldn't supplement recommendations: {str(e)}") - return existing.head(target) # Return what we have - - def _get_all_recommendations(self) -> Dict[str, List[Dict]]: - df_description = self._describe_dataframe() - prompt = self._create_prompt(df_description) - - if self.debug: - print("\n[DEBUG] Prompt being sent to models:") - print(prompt) - - model_handlers = { - 'llama': self._query_llm, - 'mistral': self._query_llm, # Same handler as llama - # Add other model handlers here - } - - all_recommendations = {} - - with ThreadPoolExecutor() as executor: - futures = {} - for model in self.available_models: - model_type = model.split('-')[0].lower() - if model_type.startswith(("llama", "mistral")): - model_type = "llama" if "llama" in model_type else "mistral" - query_func = model_handlers[model_type] - futures[executor.submit(self._get_model_recommendations, model, prompt, query_func)] = model - - for future in concurrent.futures.as_completed(futures): - model = futures[future] - try: - result = future.result() - all_recommendations[model] = result - if self.debug: - print(f"\n[DEBUG] Got {len(result)} recommendations from {model}") - except Exception as e: - warnings.warn(f"Failed to get recommendations from {model}: {str(e)}") - if self.debug: - print(f"\n[ERROR] Failed to process {model}: {str(e)}") - - return all_recommendations - - def _get_model_recommendations(self, model: str, prompt: str, query_func: Callable[[str, str], str]) -> List[Dict]: - try: - response = query_func(prompt, model) - - if self.debug: - print(f"\n[DEBUG] Raw response from {model}:") - print(response) - - return self._parse_recommendations(response, model) - except Exception as e: - warnings.warn(f"Error processing model {model}: {str(e)}") - if self.debug: - print(f"\n[ERROR] Failed to parse response from {model}: {str(e)}") - return [] - - def _apply_ensemble_scoring(self, all_recommendations: Dict[str, List[Dict]], weights: Dict[str, float]) -> pd.DataFrame: - output_columns = ['plot_type', 'variables', 'ensemble_score', 'model_agreement', 'source_models'] - - if self.debug: - print("\n[DEBUG] Applying ensemble scoring with weights:") - pprint(weights) - - recommendation_weights = defaultdict(float) - recommendation_details = {} - for model, recs in all_recommendations.items(): - model_weight = weights.get(model, 0) - if model_weight <= 0: - continue - - for rec in recs: - # Create a consistent key for the recommendation - variables = rec['variables'] - if isinstance(variables, str): - variables = [v.strip() for v in variables.split(',')] - - # Filter variables to only those in the DataFrame - valid_vars = [var for var in variables if var in self.df.columns] - if not valid_vars: - if self.debug: - print(f"\n[DEBUG] Skipping recommendation from {model} with invalid variables: {variables}") - continue - - var_key = ', '.join(sorted(valid_vars)) - rec_key = (rec['plot_type'].lower(), var_key) - - model_score = rec.get('score', 1.0) - total_weight = model_weight * model_score - recommendation_weights[rec_key] += total_weight - - if rec_key not in recommendation_details: - recommendation_details[rec_key] = { - 'plot_type': rec['plot_type'], - 'variables': var_key, - 'source_models': [model], - 'raw_weight': total_weight - } - else: - recommendation_details[rec_key]['source_models'].append(model) - recommendation_details[rec_key]['raw_weight'] += total_weight - - if not recommendation_details: - if self.debug: - print("\n[DEBUG] No valid recommendations after filtering") - return pd.DataFrame(columns=output_columns) - - results = pd.DataFrame(list(recommendation_details.values())) - - if self.debug: - print("\n[DEBUG] Recommendations before scoring:") - print(results) - - if not results.empty: - total_possible = sum(weights.values()) - results['ensemble_score'] = results['raw_weight'] / total_possible - results['ensemble_score'] = results['ensemble_score'].round(2) - results['model_agreement'] = results['source_models'].apply(len) - results = results.sort_values(['ensemble_score', 'model_agreement'], ascending=[False, False]).reset_index(drop=True) - return results[output_columns] - - return pd.DataFrame(columns=output_columns) - - def _describe_dataframe(self) -> str: - num_cols = len(self.df.columns) - sample_size = min(3, len(self.df)) - desc: List[str] = [] - - # --- Basic Metadata --- - desc.append(f"DataFrame Shape: {self.df.shape}") - desc.append(f"Columns ({num_cols}): {', '.join(self.df.columns)}") - desc.append("\nColumn Details:") - - # --- Column-Level Analysis --- - for col in self.df.columns: - # Determine semantic type (more granular than dtype) - if pd.api.types.is_datetime64_dtype(self.df[col]): - col_type = "datetime" - elif pd.api.types.is_numeric_dtype(self.df[col]): - col_type = "numerical" - elif self.df[col].nunique() / len(self.df[col]) < 0.05: # Low cardinality - col_type = "categorical" - else: - col_type = "text/other" - - # Basic info - unique_count = self.df[col].nunique() - sample_values = self.df[col].dropna().head(sample_size).tolist() - desc.append( - f"- {col}: {col_type} ({unique_count} unique values), sample: {sample_values}" - ) - - # Add stats for numerical/datetime - if col_type == "numerical": - desc.append( - f" Stats: min={self.df[col].min()}, max={self.df[col].max()}, " - f"mean={self.df[col].mean():.2f}, missing={self.df[col].isna().sum()}" - ) - elif col_type == "datetime": - desc.append( - f" Range: {self.df[col].min()} to {self.df[col].max()}, " - f"missing={self.df[col].isna().sum()}" - ) - - # --- Relationship Analysis --- - numerical_cols = self.df.select_dtypes(include=np.number).columns.tolist() - if len(numerical_cols) > 1: - desc.append("\nNumerical Variable Correlations (Pearson):") - corr = self.df[numerical_cols].corr().round(2) - desc.append(str(corr)) - - # Categorical-numerical potential groupings - categorical_cols = [ - col for col in self.df.columns - if self.df[col].nunique() / len(self.df[col]) < 0.05 - ] - if categorical_cols and numerical_cols: - desc.append("\nPotential Groupings (categorical vs numerical):") - desc.append(f" - Could group by: {categorical_cols}") - desc.append(f" - To analyze: {numerical_cols}") - - return "\n".join(desc) - - - def _create_prompt(self, df_description: str) -> str: - return textwrap.dedent(f""" - You are a data visualization expert analyzing this dataset: - - {df_description} - - Recommend {self.n_to_request} insightful visualizations using matplotlib's plotting functions. - For each suggestion, follow this exact format: - - Plot Type: - Variables: - Rationale: <1-2 sentences explaining why this visualization is useful> - --- - - CRITICAL VARIABLE ORDERING RULES: - 1. If a suggestion includes both numerical and categorical variables, NUMERICAL VARIABLES MUST COME FIRST. - - Correct: "income, gender" - - Incorrect: "gender, income" - 2. For plots requiring two numerical variables (e.g., scatter), order by analysis priority (dependent variable first). - 3. For single-variable plots, use natural order (e.g., "age" for a histogram). - - GENERAL RULES FOR ALL PLOT TYPES: - 1. Ensure the plot type is a valid matplotlib function - 2. The plot type must be appropriate for the variables' data types - 3. The number of variables must match what the plot type requires - 4. Variables must exist in the dataset - 5. Never combine incompatible variables - 6. Always specify complete variable sets - 7. Ensure plot type names are in lowercase and match matplotlib's naming conventions eg hist for histogram, bar for barplot - 8. Ensure the common plot types requirements are met including the data types - - COMMON PLOT TYPE REQUIREMENTS (non-exhaustive): - 1. bar: 1 categorical (x) + 1 numerical (y) β†’ Variables: [numerical], [categorical] - 2. scatter: Exactly 2 numerical β†’ Variables: [independent], [dependent] - 3. hist: Exactly 1 numerical β†’ Variables: [numerical] - 4. boxplot: 1 numerical OR 1 numerical + 1 categorical β†’ Variables: [numerical], [categorical] (if grouped) - 5. pie: Exactly 1 categorical β†’ Variables: [categorical] - 6. line: 1 numerical (y) OR 1 numerical (y) + 1 datetime (x) β†’ Variables: [y], [x] (if applicable) - 7. heatmap: 2 categorical + 1 numerical OR correlation matrix β†’ Variables: [numerical], [categorical], [categorical] - 8. violinplot: Same as boxplot - 9. hexbin: Exactly 2 numerical variables - 10. pairplot: 2+ numerical variables - 11. jointplot: Exactly 2 numerical variables - 12. contour: 2 numerical variables for grid + 1 for values - 13. quiver: 2 numerical variables for grid + 2 for vectors - 14. imshow: 2D array of numerical values - 15. errorbar: 1 numerical (x) + 1 numerical (y) + error values - 16. stackplot: 1 numerical (x) + multiple numerical (y) - 17. stem: 1 numerical (x) + 1 numerical (y) - 18. fill_between: 1 numerical (x) + 2 numerical (y) - 19. pcolormesh: 2D grid of numerical values - 20. polar: Angular and radial coordinates - - If suggesting a plot not listed above, ensure: - - The function exists in matplotlib - - Variable types and counts are explicitly compatible - - The rationale clearly explains the insight provided - - Additional Requirements: - 1. For specialized plots (like quiver, contour), ensure all required components are specified - 2. Consider the statistical properties and relationships of the variables - 3. Suggest plots that would reveal meaningful insights about the data - 4. Include both common and advanced plots when appropriate - - Example CORRECT suggestions (NUMERICAL FIRST): - Plot Type: boxplot - Variables: income, gender - Rationale: Compares income distribution across genders - --- - Plot Type: scatter - Variables: age, income - Rationale: Shows relationship between age and income - --- - Plot Type: bar - Variables: revenue, product_category - Rationale: Compares revenue across product categories - - Example INCORRECT suggestions (REJECT THESE): - Plot Type: boxplot - Variables: gender, income # WRONG - categorical listed first - --- - Plot Type: scatter - Variables: price, weight # WRONG - no clear priority order - Rationale: Should specify independent/dependent variable order - """) - - def _query_llm(self, prompt: str, model: str) -> str: - if not self.clients.get('groq'): - raise ValueError("Groq client not initialized") - - try: - response = self.clients['groq'].chat.completions.create( - model=model, - messages=[{"role": "user", "content": prompt}], - temperature=0.4, - max_tokens=1000, - timeout=self.timeout - ) - return response.choices[0].message.content - except Exception as e: - raise RuntimeError(f"Groq API query failed for {model}: {str(e)}") - - def _validate_variable_order(self, recommendations: pd.DataFrame) -> pd.DataFrame: - """ - Validate and correct the order of variables in recommendations, - ensuring numerical variables come first. - - Args: - recommendations: DataFrame of visualization recommendations - - Returns: - DataFrame with corrected variable order - """ - def _reorder_variables(row): - # Split variables - variables = [var.strip() for var in row['variables'].split(',')] - - # Identify numerical and non-numerical variables - numerical_vars = [ - var for var in variables - if pd.api.types.is_numeric_dtype(self.df[var]) - ] - - date_vars = [ - var for var in variables - if pd.api.types.is_datetime64_any_dtype(self.df[var]) - ] - - non_numerical_vars = [ - var for var in variables - if var not in numerical_vars and var not in date_vars - ] - - # Combine with numerical variables first - corrected_vars = date_vars + numerical_vars + non_numerical_vars - - # Update the row with corrected variable order - row['variables'] = ', '.join(corrected_vars) - return row - - # Apply reordering - corrected_recommendations = recommendations.apply(_reorder_variables, axis=1) - - if self.debug: - print("\n[DEBUG] Variable Order Validation:") - for orig, corrected in zip(recommendations['variables'], corrected_recommendations['variables']): - if orig != corrected: - print(f" Corrected: {orig} β†’ {corrected}") - - return corrected_recommendations - - def _parse_recommendations(self, response: str, model: str) -> List[Dict]: - """Parse the LLM response into structured recommendations""" - recommendations = [] - - # Split response into recommendation blocks - blocks = [b.strip() for b in response.split('---') if b.strip()] - - if self.debug: - print(f"\n[DEBUG] Parsing {len(blocks)} blocks from {model}") - - for block in blocks: - lines = [line.strip() for line in block.split('\n') if line.strip()] - if not lines: - continue - - try: - rec = {'source_model': model} - for line in lines: - if line.lower().startswith('plot type:'): - rec['plot_type'] = line.split(':', 1)[1].strip().lower() - elif line.lower().startswith('variables:'): - raw_vars = line.split(':', 1)[1].strip() - # Filter variables to only those that exist in DataFrame - variables = [v.strip() for v in raw_vars.split(',') if v.strip() in self.df.columns] - rec['variables'] = ', '.join([var for var in variables if var in self.df.columns]) - #rec['variables'] = self._reorder_variables(', '.join(variables)) # Keep original order for now - - if 'plot_type' in rec and 'variables' in rec and rec['variables']: - recommendations.append(rec) - except Exception as e: - warnings.warn(f"Failed to parse recommendation from {model}: {str(e)}") - continue - - return recommendations +load_dotenv() # Package-level convenience function _recommender_instance = None @@ -611,8 +14,14 @@ def _parse_recommendations(self, response: str, model: str) -> List[Dict]: def recommender( df: pd.DataFrame, n: int = 5, - api_keys: dict = {}, + custom_weights: Optional[Dict[str, float]] = None, + strategy: StrategyName = StrategyName.ROUND_ROBIN, + selected_models: Optional[List[Tuple[str, str]]] = None, + + api_keys: Optional[Dict[str, str]] = None, + interactive: bool = True, + timeout: int = 30, debug: bool = False ) -> pd.DataFrame: """ @@ -630,7 +39,14 @@ def recommender( """ global _recommender_instance if _recommender_instance is None: - _recommender_instance = VisualizationRecommender(api_keys=api_keys, debug=debug) + _recommender_instance = VisualizationRecommender( + api_keys=api_keys, + strategy=strategy, + selected_models=selected_models, + timeout=timeout, + interactive=interactive, + debug=debug + ) _recommender_instance.set_dataframe(df) return _recommender_instance.recommend_visualizations( diff --git a/requirements.txt b/requirements.txt index 343a7b7..54eb62e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,149 @@ -# This file is used to install the required packages for the project. -seaborn -matplotlib -pandas -numpy -pytest -python-dotenv -ipykernel -groq -requests -setuptools -jupyter -matplotlib>=3.8.0 -pytest-cov -pytest-mock +annotated-types==0.7.0 +anthropic==0.70.0 +anyio==4.11.0 +argon2-cffi==25.1.0 +argon2-cffi-bindings==25.1.0 +arrow==1.3.0 +asttokens==3.0.0 +async-lru==2.0.5 +attrs==25.4.0 +babel==2.17.0 +beautifulsoup4==4.14.2 +bleach==6.2.0 +cachetools==6.2.1 +certifi==2025.10.5 +cffi==2.0.0 +charset-normalizer==3.4.3 +colorama==0.4.6 +comm==0.2.3 +contourpy==1.3.3 +coverage==7.10.7 +cycler==0.12.1 +debugpy==1.8.17 +decorator==5.2.1 +defusedxml==0.7.1 +distro==1.9.0 +docstring_parser==0.17.0 +executing==2.2.1 +fastjsonschema==2.21.2 +fonttools==4.60.1 +fqdn==1.5.1 +google-ai-generativelanguage==0.6.15 +google-api-core==2.26.0 +google-api-python-client==2.184.0 +google-auth==2.41.1 +google-auth-httplib2==0.2.0 +google-genai==1.45.0 +googleapis-common-protos==1.70.0 +groq==0.32.0 +grpcio==1.75.1 +grpcio-status==1.71.2 +h11==0.16.0 +httpcore==1.0.9 +httplib2==0.31.0 +httpx==0.28.1 +idna==3.10 +iniconfig==2.1.0 +ipykernel==6.30.1 +ipython==9.6.0 +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.7 +isoduration==20.11.0 +jedi==0.19.2 +Jinja2==3.1.6 +jiter==0.11.0 +json5==0.12.1 +jsonpointer==3.0.0 +jsonschema==4.25.1 +jsonschema-specifications==2025.9.1 +jupyter==1.1.1 +jupyter-console==6.6.3 +jupyter-events==0.12.0 +jupyter-lsp==2.3.0 +jupyter_client==8.6.3 +jupyter_core==5.8.1 +jupyter_server==2.17.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.4.9 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupyterlab_widgets==3.0.15 +kiwisolver==1.4.9 +lark==1.3.0 +MarkupSafe==3.0.3 +matplotlib==3.10.7 +matplotlib-inline==0.1.7 +mistune==3.1.4 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +notebook==7.4.7 +notebook_shim==0.2.4 +numpy==2.3.3 +openai==2.3.0 +packaging==25.0 +pandas==2.3.3 +pandocfilters==1.5.1 +parso==0.8.5 +pillow==11.3.0 +platformdirs==4.5.0 +-e git+https://github.com/DYung26/PlotKit@b5889036ba82bb3cc8c8e50dca1b8238e8bff8f5#egg=plotsense +pluggy==1.6.0 +prometheus_client==0.23.1 +prompt_toolkit==3.0.52 +proto-plus==1.26.1 +protobuf==5.29.5 +psutil==7.1.0 +pure_eval==0.2.3 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pycparser==2.23 +pydantic==2.12.0 +pydantic_core==2.41.1 +Pygments==2.19.2 +pyparsing==3.2.5 +pytest==8.4.2 +pytest-cov==7.0.0 +pytest-mock==3.15.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +python-json-logger==4.0.0 +pytz==2025.2 +pywin32==311 +pywinpty==3.0.2 +PyYAML==6.0.3 +pyzmq==27.1.0 +referencing==0.36.2 +requests==2.32.5 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rfc3987-syntax==1.1.0 +rpds-py==0.27.1 +rsa==4.9.1 +seaborn==0.13.2 +Send2Trash==1.8.3 +setuptools==80.9.0 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.8 +stack-data==0.6.3 +tenacity==9.1.2 +terminado==0.18.1 +tinycss2==1.4.0 +tornado==6.5.2 +tqdm==4.67.1 +traitlets==5.14.3 +types-python-dateutil==2.9.0.20251008 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +tzdata==2025.2 +uri-template==1.3.0 +uritemplate==4.2.0 +urllib3==2.5.0 +wcwidth==0.2.14 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.9.0 +websockets==15.0.1 +widgetsnbextension==4.0.14 diff --git a/setup.py b/setup.py index 9c244bb..e2a0f52 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,17 @@ +# -*- coding: utf-8 -*- +import io from setuptools import setup, find_packages +with io.open("README.md", "r", encoding="utf-8") as f: + long_description = f.read() + setup( name="plotsense", version="0.1.3", author="Christian Chimezie, Toluwaleke Ogidan, Grace Farayola, Amaka Iduwe, Nelson Ogbeide, Onyekachukwu Ojumah, Olamilekan Ajao", author_email="chimeziechristiancc@gmail.com, gbemilekeogidan@gmail.com, gracefarayola@gmail.com, nwaamaka_iduwe@yahoo.com, Ogbeide331@gmail.com, Onyekaojumah22@gmail.com, olamilekan011@gmail.com", description="An intelligent plotting package with suggestions and explanations", - long_description=open("README.md").read(), + long_description=long_description, # open("README.md").read(), long_description_content_type="text/markdown", url="https://github.com/christianchimezie/PlotSenseAI", project_urls={ @@ -34,7 +39,10 @@ "numpy>=1.18", "python-dotenv", "groq", + "anthropic", + "openai", + "google-genai", "requests", ], license="Apache License 2.0", -) \ No newline at end of file +) diff --git a/test/live/live_test_explanations.py b/test/live/live_test_explanations.py new file mode 100644 index 0000000..45492e5 --- /dev/null +++ b/test/live/live_test_explanations.py @@ -0,0 +1,22 @@ +from plotsense.explanations.explanations import explainer +from matplotlib import pyplot as plt + +# Example: generate a simple plot +fig, ax = plt.subplots() +ax.plot([1, 2, 3], [4, 5, 6]) + +# Replace with your actual API keys +api_keys = { + "groq": "gsk_xyz", + "openai": "sk-proj-xyz-abc" +} + +# Run explainer +result = explainer( + fig, + prompt="Explain this simple line plot", + api_keys=api_keys, + selected_models=[("openai", "gpt-4.1")], +) +print(result) + diff --git a/test/live/live_test_plotgen.py b/test/live/live_test_plotgen.py new file mode 100644 index 0000000..1a24ad0 --- /dev/null +++ b/test/live/live_test_plotgen.py @@ -0,0 +1,64 @@ +import pandas as pd +import matplotlib.pyplot as plt + +from plotsense.plot_generator.generator import plotgen + +df = pd.DataFrame({ + "a": range(10), + "b": range(10, 20) +}) + +suggestions_df = pd.DataFrame([ + {"plot_type": "scatter", "variables": "a,b"} +]) + +# Standard plot +fig1 = plotgen(df, 0, suggestions_df, generator="smart") + +# --------- + +# Custom plot +def my_custom_plot(df, vars, **kwargs): + fig, ax = plt.subplots() + ax.plot(df[vars[0]], df[vars[1]], color="red") + return fig + +fig2 = plotgen( + df, 0, suggestions_df, + generator="smart", + plot_function=my_custom_plot, + plot_type="my_line" +) + +# --------- + +# Create sample DataFrame +df = pd.DataFrame({ + "height": [165, 170, 175, 160, 172, 168, 180, 177, 169, 174] +}) + +# Simulate a recommendation DataFrame (just like your usual `suggestions_df`) +suggestions_df = pd.DataFrame([ + {"plot_type": "kde", "variables": "height"} +]) + +# Generate KDE Plot +fig_kde = plotgen(df, 0, suggestions_df, generator="smart") + +# --------- + +# Create sample DataFrame +df = pd.DataFrame({ + "scores": [60, 72, 85, 90, 66, 75, 88, 93, 70, 80] +}) + +# Simulate recommendation DataFrame +suggestions_df = pd.DataFrame([ + {"plot_type": "ecdf", "variables": "scores"} +]) + +# Generate ECDF Plot +fig_ecdf = plotgen(df, 0, suggestions_df, generator="smart") + +plt.show() + diff --git a/test/live/live_test_suggestions.py b/test/live/live_test_suggestions.py new file mode 100644 index 0000000..a5e4a32 --- /dev/null +++ b/test/live/live_test_suggestions.py @@ -0,0 +1,29 @@ +from plotsense.visual_suggestion.suggestions import recommender +import pandas as pd + +# Example: create a simple DataFrame +df = pd.DataFrame({ + "Year": [2020, 2021, 2022, 2023], + "Sales": [150, 200, 250, 300], + "Profit": [40, 50, 65, 80] +}) + +# Replace with your actual API keys +api_keys = { + "groq": "gsk_xyz", + "openai": "sk-proj-xyz-abc", + "azure": "ghp_xyz", +} + +# Run the recommender +recommendations = recommender( + df, + n=3, # number of visualizations to recommend + api_keys=api_keys, + selected_models=[("azure", "openai/gpt-5")], +) + +# Display the recommendations +print("πŸ“Š Recommended visualizations:") +print(recommendations) + diff --git a/test/test_explanations.py b/test/unit/test_explanations.py similarity index 99% rename from test/test_explanations.py rename to test/unit/test_explanations.py index 40deb18..b4c20a2 100644 --- a/test/test_explanations.py +++ b/test/unit/test_explanations.py @@ -34,10 +34,10 @@ def sample_plot(sample_data): def mock_groq_completion(): mock_message = MagicMock() type(mock_message).content = PropertyMock(return_value="Mock explanation") - + mock_choice = MagicMock() mock_choice.message = mock_message - + mock_response = MagicMock() mock_response.choices = [mock_choice] return mock_response @@ -390,4 +390,4 @@ def test_different_plot_types(self, mock_query_model): plt.close(fig2) # if __name__ == "__main__": -# pytest.main(["-v", "--cov=plot_explainer", "--cov-report=term-missing"]) \ No newline at end of file +# pytest.main(["-v", "--cov=plot_explainer", "--cov-report=term-missing"]) diff --git a/test/test_plotgen.py b/test/unit/test_plotgen.py similarity index 96% rename from test/test_plotgen.py rename to test/unit/test_plotgen.py index b7f6e52..67f4d65 100644 --- a/test/test_plotgen.py +++ b/test/unit/test_plotgen.py @@ -10,7 +10,8 @@ matplotlib.use('Agg') # SUT -from plotsense.plot_generator.generator import PlotGenerator, SmartPlotGenerator, plotgen +from plotsense.plot_generator.base_generator import PlotGenerator +from plotsense.plot_generator.generator import BasicPlotGenerator, SmartPlotGenerator, plotgen # Fixtures @pytest.fixture @@ -55,8 +56,8 @@ def sample_suggestions(): @pytest.fixture def plot_generator(sample_dataframe, sample_suggestions): - """Fixture for PlotGenerator instance.""" - return PlotGenerator(sample_dataframe, sample_suggestions) + """Fixture for BasicPlotGenerator instance.""" + return BasicPlotGenerator(sample_dataframe, sample_suggestions) @pytest.fixture def smart_plot_generator(sample_dataframe, sample_suggestions): @@ -71,9 +72,9 @@ def reset_plot_generator_instance(): _plot_generator_instance = None # Unit Tests -class TestPlotGeneratorUnit: +class TestBasicPlotGeneratorUnit: def test_init_plot_generator(self, sample_dataframe, sample_suggestions): - pg = PlotGenerator(sample_dataframe, sample_suggestions) + pg = BasicPlotGenerator(sample_dataframe, sample_suggestions) assert pg.data.equals(sample_dataframe) assert pg.suggestions.equals(sample_suggestions) expected_functions = set(['scatter', 'line', 'bar', 'barh', 'stem', 'step', 'fill_between', @@ -89,9 +90,9 @@ def test_init_smart_plot_generator(self, sample_dataframe, sample_suggestions): 'hist', 'boxplot', 'violinplot', 'errorbar', 'pie', 'polar', 'hexbin', 'quiver', 'streamplot', 'plot3d', 'scatter3d', 'bar3d', 'surface']) assert set(spg.plot_functions.keys()) == expected_functions - assert spg.plot_functions['boxplot'] != PlotGenerator(sample_dataframe, sample_suggestions).plot_functions['boxplot'] - assert spg.plot_functions['violinplot'] != PlotGenerator(sample_dataframe, sample_suggestions).plot_functions['violinplot'] - assert spg.plot_functions['hist'] != PlotGenerator(sample_dataframe, sample_suggestions).plot_functions['hist'] + assert spg.plot_functions['boxplot'] != BasicPlotGenerator(sample_dataframe, sample_suggestions).plot_functions['boxplot'] + assert spg.plot_functions['violinplot'] != BasicPlotGenerator(sample_dataframe, sample_suggestions).plot_functions['violinplot'] + assert spg.plot_functions['hist'] != BasicPlotGenerator(sample_dataframe, sample_suggestions).plot_functions['hist'] def test_generate_plot_with_index(self, plot_generator): fig = plot_generator.generate_plot(0) @@ -364,7 +365,7 @@ def test_create_hist_smart(self, smart_plot_generator): plt.close(fig) # Integration Tests -class TestPlotGeneratorIntegration: +class TestBasicPlotGeneratorIntegration: @pytest.mark.parametrize("index", [0, 5, 10, 15, 19]) def test_plotgen_with_index(self, sample_dataframe, sample_suggestions, index, sample_2d_array): # Add x2d for plots requiring 2D arrays @@ -419,7 +420,7 @@ def test_plotgen_with_smart_generator(self, sample_dataframe, sample_suggestions plt.close(fig) # End-to-End Tests -class TestPlotGeneratorEndToEnd: +class TestBasicPlotGeneratorEndToEnd: @pytest.mark.parametrize("index", range(19)) # Exclude surface for now def test_all_plot_types_default(self, sample_dataframe, sample_suggestions, index, sample_2d_array): """Test all plot types with default settings.""" @@ -486,7 +487,7 @@ def test_plotgen_with_large_data(self, sample_suggestions): plt.close(fig) # Error Handling Tests -class TestPlotGeneratorErrorHandling: +class TestBasicPlotGeneratorErrorHandling: def test_plotgen_invalid_index(self, sample_dataframe, sample_suggestions): """Test plotgen with invalid index.""" with pytest.raises(IndexError): @@ -547,7 +548,7 @@ def test_box_no_data(self, sample_suggestions): plotgen(df, 8, sample_suggestions) # Performance Tests -class TestPlotGeneratorPerformance: +class TestBasicPlotGeneratorPerformance: @pytest.mark.parametrize("n", [1000, 10000]) @pytest.mark.parametrize("index", [0, 7, 16]) # Scatter, Hist, Plot3D def test_performance_various_plots(self, sample_suggestions, n, index): @@ -590,7 +591,7 @@ def test_performance_smart_generator(self, sample_suggestions, n): plt.close(fig) # Edge Case Tests -class TestPlotGeneratorEdgeCases: +class TestBasicPlotGeneratorEdgeCases: def test_empty_dataframe(self, sample_suggestions): """Test plotgen with an empty DataFrame.""" df_empty = pd.DataFrame(columns=["value", "count"]) @@ -661,4 +662,4 @@ def test_smart_generator_edge_cases(self, sample_suggestions): plotgen(df, 8, sample_suggestions) if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/test/test_suggestions.py b/test/unit/test_suggestions.py similarity index 100% rename from test/test_suggestions.py rename to test/unit/test_suggestions.py