diff --git a/pyproject.toml b/pyproject.toml index 1605f046..c8d9a152 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vanna" -version = "0.0.9" +version = "0.0.10" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index aab97465..eb78c23c 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -17,11 +17,31 @@ import warnings import traceback -"""Set the API key for Vanna.AI.""" api_key: Union[str, None] = None # API key for Vanna.AI +""" +## Example +```python +# Login to Vanna.AI +vn.login('user@example.com') +print(vn.api_key) + +vn.api_key='my_api_key' +``` + +This is the API key for Vanna.AI. You can set it manually if you have it or use [`vn.login(...)`][vanna.login] to login and set it automatically. + +""" -"""Set the SQL to DataFrame function for Vanna.AI.""" sql_to_df: Union[Callable[[str], pd.DataFrame], None] = None # Function to convert SQL to a Pandas DataFrame +""" +## Example +```python +vn.sql_to_df = lambda sql: pd.read_sql(sql, engine) +``` + +Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function. + +""" __org: Union[str, None] = None # Organization name for Vanna.AI @@ -45,10 +65,10 @@ def __rpc_call(method, params): global __org if api_key is None: - raise Exception("API key not set") + raise Exception("API key not set. Use vn.login(...) to login.") if __org is None and method != "list_orgs": - raise Exception("Organization name not set") + raise Exception("Datasets not set. Use vn.use_datasets([...]) to set the datasets to use.") if method != "list_orgs": headers = { @@ -124,17 +144,17 @@ def login(email: str, otp_code: Union[str, None] = None) -> bool: return True -def list_orgs() -> List[str]: +def list_datasets() -> List[str]: """ ## Example ```python - orgs = vn.list_orgs() + datasets = vn.list_datasets() ``` - List the organizations that the user is a member of. + List the datasets that the user is a member of. Returns: - List[str]: A list of organization names. + List[str]: A list of dataset names. """ d = __rpc_call(method="list_orgs", params=[]) @@ -145,23 +165,23 @@ def list_orgs() -> List[str]: return orgs.organizations -def create_org(org: str, db_type: str) -> bool: +def create_dataset(dataset: str, db_type: str) -> bool: """ ## Example ```python - vn.create_org(org="my-org", db_type="postgres") + vn.create_dataset(dataset="my-dataset", db_type="postgres") ``` - Create a new organization. + Create a new dataset. Args: - org (str): The name of the organization to create. - db_type (str): The type of database to use for the organization. This can be "Snowflake", "BigQuery", "Postgres", or anything else. + dataset (str): The name of the dataset to create. + db_type (str): The type of database to use for the dataset. This can be "Snowflake", "BigQuery", "Postgres", or anything else. Returns: - bool: True if the organization was created successfully, False otherwise. + bool: True if the dataset was created successfully, False otherwise. """ - params = [NewOrganization(org_name=org, db_type=db_type)] + params = [NewOrganization(org_name=dataset, db_type=db_type)] d = __rpc_call(method="create_org", params=params) @@ -172,24 +192,24 @@ def create_org(org: str, db_type: str) -> bool: return status.success -def add_user_to_org(org: str, email: str, is_admin: bool) -> bool: +def add_user_to_dataset(dataset: str, email: str, is_admin: bool) -> bool: """ ## Example ```python - vn.add_user_to_org(org="my-org", email="user@example.com") + vn.add_user_to_dataset(dataset="my-dataset", email="user@example.com") ``` - Add a user to an organization. + Add a user to an dataset. Args: - org (str): The name of the organization to add the user to. + dataset (str): The name of the dataset to add the user to. email (str): The email address of the user to add. Returns: bool: True if the user was added successfully, False otherwise. """ - params = [NewOrganizationMember(org_name=org, email=email, is_admin=is_admin)] + params = [NewOrganizationMember(org_name=dataset, email=email, is_admin=is_admin)] d = __rpc_call(method="add_user_to_org", params=params) @@ -203,21 +223,20 @@ def add_user_to_org(org: str, email: str, is_admin: bool) -> bool: return status.success -def set_org_visibility(visibility: bool) -> bool: +def set_dataset_visibility(visibility: bool) -> bool: """ ## Example ```python - vn.set_org_visibility(org="my-org", visibility=True) + vn.set_dataset_visibility(visibility=True) ``` - Set the visibility of an organization. If an organization is visible, anyone can see it. If it is not visible, only members of the organization can see it. + Set the visibility of the current dataset. If a dataset is visible, anyone can see it. If it is not visible, only members of the dataset can see it. Args: - org (str): The name of the organization to set the visibility of. - visibility (bool): Whether or not the organization should be visible. + visibility (bool): Whether or not the dataset should be publicly visible. Returns: - bool: True if the organization visibility was set successfully, False otherwise. + bool: True if the dataset visibility was set successfully, False otherwise. """ params = [Visibility(visibility=visibility)] @@ -230,47 +249,10 @@ def set_org_visibility(visibility: bool) -> bool: return status.success -def set_org(org: str) -> None: - """ - DEPRECATED. Use [`use_datasets`][vanna.use_datasets] instead. - - Args: - org (str): The organization name. - """ - global __org - print("vn.set_org is deprecated. Please use vn.use_datasets instead.") - warnings.warn("vn.set_org is deprecated. Please use vn.use_datasets instead.", DeprecationWarning) - - my_orgs = list_orgs() - if org not in my_orgs: - # Check if org exists - d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)]) - - if 'result' not in d: - raise Exception("Failed to check if organization exists") - - status = Status(**d['result']) - - if status.success: - raise Exception(f"An organization with the name {org} already exists") - - create = input(f"Would you like to create organization '{org}'? (y/n): ") - - if create.lower() == 'y': - db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ") - __org = 'demo-tpc-h' - if create_org(org=org, db_type=db_type): - __org = org - else: - __org = None - raise Exception("Failed to create organization") - else: - __org = org - def _set_org(org: str) -> None: global __org - my_orgs = list_orgs() + my_orgs = list_datasets() if org not in my_orgs: # Check if org exists d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)]) @@ -288,7 +270,7 @@ def _set_org(org: str) -> None: if create.lower() == 'y': db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ") __org = 'demo-tpc-h' - if create_org(org=org, db_type=db_type): + if create_dataset(dataset=org, db_type=db_type): __org = org else: __org = None @@ -379,7 +361,7 @@ def store_documentation(documentation: str) -> bool: ## Example ```python vn.store_documentation( - documentation="This is a documentation string for the employees table." + documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold." ) ``` @@ -423,7 +405,7 @@ def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg: ```python vn.flag_sql_for_review(question="What is the average salary of employees?") ``` - Flag a question and its corresponding SQL query for review. You can later retrieve the flagged questions using [`get_flagged_questions()`][vanna.get_flagged_questions]. + Flag a question and its corresponding SQL query for review. You can see the tag show up in [`vn.get_all_questions()`][vanna.get_all_questions] Args: question (str): The question to flag. @@ -591,15 +573,15 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai if print_results: print(df.head().to_markdown()) + if len(df) > 0 and auto_train: + store_sql(question=question, sql=sql, tag="SQL Ran") + try: plotly_code = generate_plotly_code(question=question, sql=sql, df=df) fig = get_plotly_figure(plotly_code=plotly_code, df=df) if print_results: fig.show() - if len(df) > 0 and auto_train: - store_sql(question=question, sql=sql, tag="Assumed Correct") - return sql, df, fig except Exception as e: @@ -687,7 +669,7 @@ def get_plotly_figure(plotly_code: str, df: pd.DataFrame, dark_mode: bool = True def get_results(cs, default_database: str, sql: str) -> pd.DataFrame: """ - DEPRECATED. Use [`vanna.sql_to_df()`][vanna.sql_to_df] instead. + DEPRECATED. Use `vn.sql_to_df` instead. Run the SQL query and return the results as a pandas dataframe. This is just a helper function that does not use the Vanna.AI API. Args: diff --git a/src/vanna/types.py b/src/vanna/types.py index 1c6f0d34..ab09f4a2 100644 --- a/src/vanna/types.py +++ b/src/vanna/types.py @@ -83,7 +83,7 @@ class QuestionCategory: NO_SQL_GENERATED = "No SQL Generated" SQL_UNABLE_TO_RUN = "SQL Unable to Run" BOOTSTRAP_TRAINING_QUERY = "Bootstrap Training Query" - ASSUMED_CORRECT = "Assumed Correct" + SQL_RAN = "SQL Ran Successfully" FLAGGED_FOR_REVIEW = "Flagged for Review" REVIEWED_AND_APPROVED = "Reviewed and Approved" REVIEWED_AND_REJECTED = "Reviewed and Rejected"