-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
55 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[project] | ||
name = "vanna" | ||
version = "0.0.9" | ||
version = "0.0.10" | ||
authors = [ | ||
{ name="Zain Hoda", email="[email protected]" }, | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,31 @@ | |
import warnings | ||
import traceback | ||
|
||
"""Set the API key for Vanna.AI.""" | ||
api_key: Union[str, None] = None # API key for Vanna.AI | ||
""" | ||
## Example | ||
```python | ||
# Login to Vanna.AI | ||
vn.login('[email protected]') | ||
print(vn.api_key) | ||
vn.api_key='my_api_key' | ||
``` | ||
This is the API key for Vanna.AI. You can set it manually if you have it or use [`vn.login(...)`][vanna.login] to login and set it automatically. | ||
""" | ||
|
||
"""Set the SQL to DataFrame function for Vanna.AI.""" | ||
sql_to_df: Union[Callable[[str], pd.DataFrame], None] = None # Function to convert SQL to a Pandas DataFrame | ||
""" | ||
## Example | ||
```python | ||
vn.sql_to_df = lambda sql: pd.read_sql(sql, engine) | ||
``` | ||
Set the SQL to DataFrame function for Vanna.AI. This is used in the [`vn.ask(...)`][vanna.ask] function. | ||
""" | ||
|
||
__org: Union[str, None] = None # Organization name for Vanna.AI | ||
|
||
|
@@ -45,10 +65,10 @@ def __rpc_call(method, params): | |
global __org | ||
|
||
if api_key is None: | ||
raise Exception("API key not set") | ||
raise Exception("API key not set. Use vn.login(...) to login.") | ||
|
||
if __org is None and method != "list_orgs": | ||
raise Exception("Organization name not set") | ||
raise Exception("Datasets not set. Use vn.use_datasets([...]) to set the datasets to use.") | ||
|
||
if method != "list_orgs": | ||
headers = { | ||
|
@@ -124,17 +144,17 @@ def login(email: str, otp_code: Union[str, None] = None) -> bool: | |
|
||
return True | ||
|
||
def list_orgs() -> List[str]: | ||
def list_datasets() -> List[str]: | ||
""" | ||
## Example | ||
```python | ||
orgs = vn.list_orgs() | ||
datasets = vn.list_datasets() | ||
``` | ||
List the organizations that the user is a member of. | ||
List the datasets that the user is a member of. | ||
Returns: | ||
List[str]: A list of organization names. | ||
List[str]: A list of dataset names. | ||
""" | ||
d = __rpc_call(method="list_orgs", params=[]) | ||
|
||
|
@@ -145,23 +165,23 @@ def list_orgs() -> List[str]: | |
|
||
return orgs.organizations | ||
|
||
def create_org(org: str, db_type: str) -> bool: | ||
def create_dataset(dataset: str, db_type: str) -> bool: | ||
""" | ||
## Example | ||
```python | ||
vn.create_org(org="my-org", db_type="postgres") | ||
vn.create_dataset(dataset="my-dataset", db_type="postgres") | ||
``` | ||
Create a new organization. | ||
Create a new dataset. | ||
Args: | ||
org (str): The name of the organization to create. | ||
db_type (str): The type of database to use for the organization. This can be "Snowflake", "BigQuery", "Postgres", or anything else. | ||
dataset (str): The name of the dataset to create. | ||
db_type (str): The type of database to use for the dataset. This can be "Snowflake", "BigQuery", "Postgres", or anything else. | ||
Returns: | ||
bool: True if the organization was created successfully, False otherwise. | ||
bool: True if the dataset was created successfully, False otherwise. | ||
""" | ||
params = [NewOrganization(org_name=org, db_type=db_type)] | ||
params = [NewOrganization(org_name=dataset, db_type=db_type)] | ||
|
||
d = __rpc_call(method="create_org", params=params) | ||
|
||
|
@@ -172,24 +192,24 @@ def create_org(org: str, db_type: str) -> bool: | |
|
||
return status.success | ||
|
||
def add_user_to_org(org: str, email: str, is_admin: bool) -> bool: | ||
def add_user_to_dataset(dataset: str, email: str, is_admin: bool) -> bool: | ||
""" | ||
## Example | ||
```python | ||
vn.add_user_to_org(org="my-org", email="[email protected]") | ||
vn.add_user_to_dataset(dataset="my-dataset", email="[email protected]") | ||
``` | ||
Add a user to an organization. | ||
Add a user to an dataset. | ||
Args: | ||
org (str): The name of the organization to add the user to. | ||
dataset (str): The name of the dataset to add the user to. | ||
email (str): The email address of the user to add. | ||
Returns: | ||
bool: True if the user was added successfully, False otherwise. | ||
""" | ||
|
||
params = [NewOrganizationMember(org_name=org, email=email, is_admin=is_admin)] | ||
params = [NewOrganizationMember(org_name=dataset, email=email, is_admin=is_admin)] | ||
|
||
d = __rpc_call(method="add_user_to_org", params=params) | ||
|
||
|
@@ -203,21 +223,20 @@ def add_user_to_org(org: str, email: str, is_admin: bool) -> bool: | |
|
||
return status.success | ||
|
||
def set_org_visibility(visibility: bool) -> bool: | ||
def set_dataset_visibility(visibility: bool) -> bool: | ||
""" | ||
## Example | ||
```python | ||
vn.set_org_visibility(org="my-org", visibility=True) | ||
vn.set_dataset_visibility(visibility=True) | ||
``` | ||
Set the visibility of an organization. If an organization is visible, anyone can see it. If it is not visible, only members of the organization can see it. | ||
Set the visibility of the current dataset. If a dataset is visible, anyone can see it. If it is not visible, only members of the dataset can see it. | ||
Args: | ||
org (str): The name of the organization to set the visibility of. | ||
visibility (bool): Whether or not the organization should be visible. | ||
visibility (bool): Whether or not the dataset should be publicly visible. | ||
Returns: | ||
bool: True if the organization visibility was set successfully, False otherwise. | ||
bool: True if the dataset visibility was set successfully, False otherwise. | ||
""" | ||
params = [Visibility(visibility=visibility)] | ||
|
||
|
@@ -230,47 +249,10 @@ def set_org_visibility(visibility: bool) -> bool: | |
|
||
return status.success | ||
|
||
def set_org(org: str) -> None: | ||
""" | ||
DEPRECATED. Use [`use_datasets`][vanna.use_datasets] instead. | ||
Args: | ||
org (str): The organization name. | ||
""" | ||
global __org | ||
print("vn.set_org is deprecated. Please use vn.use_datasets instead.") | ||
warnings.warn("vn.set_org is deprecated. Please use vn.use_datasets instead.", DeprecationWarning) | ||
|
||
my_orgs = list_orgs() | ||
if org not in my_orgs: | ||
# Check if org exists | ||
d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)]) | ||
|
||
if 'result' not in d: | ||
raise Exception("Failed to check if organization exists") | ||
|
||
status = Status(**d['result']) | ||
|
||
if status.success: | ||
raise Exception(f"An organization with the name {org} already exists") | ||
|
||
create = input(f"Would you like to create organization '{org}'? (y/n): ") | ||
|
||
if create.lower() == 'y': | ||
db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ") | ||
__org = 'demo-tpc-h' | ||
if create_org(org=org, db_type=db_type): | ||
__org = org | ||
else: | ||
__org = None | ||
raise Exception("Failed to create organization") | ||
else: | ||
__org = org | ||
|
||
def _set_org(org: str) -> None: | ||
global __org | ||
|
||
my_orgs = list_orgs() | ||
my_orgs = list_datasets() | ||
if org not in my_orgs: | ||
# Check if org exists | ||
d = __unauthenticated_rpc_call(method="check_org_exists", params=[Organization(name=org, user=None, connection=None)]) | ||
|
@@ -288,7 +270,7 @@ def _set_org(org: str) -> None: | |
if create.lower() == 'y': | ||
db_type = input("What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): ") | ||
__org = 'demo-tpc-h' | ||
if create_org(org=org, db_type=db_type): | ||
if create_dataset(dataset=org, db_type=db_type): | ||
__org = org | ||
else: | ||
__org = None | ||
|
@@ -379,7 +361,7 @@ def store_documentation(documentation: str) -> bool: | |
## Example | ||
```python | ||
vn.store_documentation( | ||
documentation="This is a documentation string for the employees table." | ||
documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold." | ||
) | ||
``` | ||
|
@@ -423,7 +405,7 @@ def flag_sql_for_review(question: str, sql: Union[str, None] = None, error_msg: | |
```python | ||
vn.flag_sql_for_review(question="What is the average salary of employees?") | ||
``` | ||
Flag a question and its corresponding SQL query for review. You can later retrieve the flagged questions using [`get_flagged_questions()`][vanna.get_flagged_questions]. | ||
Flag a question and its corresponding SQL query for review. You can see the tag show up in [`vn.get_all_questions()`][vanna.get_all_questions] | ||
Args: | ||
question (str): The question to flag. | ||
|
@@ -591,15 +573,15 @@ def ask(question: Union[str, None] = None, print_results: bool = True, auto_trai | |
if print_results: | ||
print(df.head().to_markdown()) | ||
|
||
if len(df) > 0 and auto_train: | ||
store_sql(question=question, sql=sql, tag="SQL Ran") | ||
|
||
try: | ||
plotly_code = generate_plotly_code(question=question, sql=sql, df=df) | ||
fig = get_plotly_figure(plotly_code=plotly_code, df=df) | ||
if print_results: | ||
fig.show() | ||
|
||
if len(df) > 0 and auto_train: | ||
store_sql(question=question, sql=sql, tag="Assumed Correct") | ||
|
||
return sql, df, fig | ||
|
||
except Exception as e: | ||
|
@@ -687,7 +669,7 @@ def get_plotly_figure(plotly_code: str, df: pd.DataFrame, dark_mode: bool = True | |
|
||
def get_results(cs, default_database: str, sql: str) -> pd.DataFrame: | ||
""" | ||
DEPRECATED. Use [`vanna.sql_to_df()`][vanna.sql_to_df] instead. | ||
DEPRECATED. Use `vn.sql_to_df` instead. | ||
Run the SQL query and return the results as a pandas dataframe. This is just a helper function that does not use the Vanna.AI API. | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters