diff --git a/README.md b/README.md index 9375564..c28744f 100644 --- a/README.md +++ b/README.md @@ -117,19 +117,11 @@ Some things to take note of: - If you do not populate your database with data (ie only create the tables without inserting data), you would return empty dataframes most of the time (regardless of whether the query generated was what you want), and it would result in results matching all the time and generate a lot of false positives. Hence, you might want to consider populating your database with some meaningful data that would return different results if the queries should be different from what you want. - If testing out on your private data, you would also need to change the questions file to point to your own questions file (tailored to your database schema). -### Query Generator - -To test your own query generator with our framework, you would need to extend [Query Generator](query_generators/query_generator.py) and implement the [generate_query](query_generators/query_generator.py#L18) method to return the query of interest. We create a new class for each question/query pair to isolate each pair's runtime state against the others when running concurrently. You can also reference [OpenAIQueryGenerator](query_generators/openai.py) which implements `Query Generator` and uses a simple prompt to send a message over to OpenAI's API. Feel free to extend it for your own use. - -If there are functions that are generally useful for all query generators, they can be placed in the `utils` folder. If you need to incorporate specific verbose templates (e.g. for prompt testing), you can store them in the `prompts` folder, and later import them. Being able to version control the prompts in a central place has been a productivity win for our team. - ### Runner -Having implemented the query generator, the next piece of abstraction would be the runner. The runner calls the query generator, and is responsible for handling the configuration of work (e.g. parallelization / batching / model selected etc.) to the query generator for each question/query pair. +The runner calls is responsible for handling the configuration of work (e.g. parallelization / batching / model selected etc.) for each question/query pair. -We have provided a few common runners: `eval/openai_runner.py` for calling OpenAI's API (with parallelization support), `eval/anthropic_runner` for calling Anthropic's API, `eval/hf_runner.py` for calling a local Hugging Face model and finally, `eval/api_runner.py` makes it possible to use a custom API for evaluation. - -When testing your own query generator with an existing runner, you can replace the `qg_class` in the runner's code with your own query generator class. +We have provided a few common runners: `runners/openai_runner.py` for calling OpenAI's API (with parallelization support), `runners/anthropic_runner` for calling Anthropic's API, `runners/hf_runner.py` for calling a local Hugging Face model and finally, `runners/api_runner.py` makes it possible to use a custom API for evaluation. ## Running the Test @@ -500,8 +492,6 @@ You can use the following flags in the command line to change the configurations ## Checking the Results -To better understand your query generator's performance, you can explore the results generated and aggregated for the various metrics that you care about. - ### Upload URL If you would like to start a google cloud function to receive the results, you can use the `--upload_url` flag to specify the URL that you want to report the results to. Before running the evaluation code with this flag, you would need to create a server that serves at the provided URL. We have provided 2 sample cloud function endpoints for writing either to bigquery or postgres, in the `results_fn_bigquery` and `results_fn_postgres` folders. You may also implement your own server to take in similar arguments. Before deploying either cloud functions, you would need to set up the environment variables by making a copy of .env.yaml.template and renaming it to .env.yaml, and then filling in the relevant fields. For the bigquery cloud function, you would also need to put your service account's key.json file in the same folder, and put the file name in the `CREDENTIALS_PATH` field in the .env.yaml file. @@ -572,7 +562,6 @@ We welcome contributions to our project, specifically: - Dataset - Adding new database schema/data - Framework code - - New query generators/runners (in the [query_generators](query_generators) and [eval](eval) folders respectively) - Improving existing generators/runners (e.g. adding new metrics) Please see [CONTRIBUTING.md](https://github.com/defog-ai/sql-generation-evaluation/blob/main/CONTRIBUTING.md) for more information. diff --git a/eval/anthropic_runner.py b/eval/anthropic_runner.py deleted file mode 100644 index e58f903..0000000 --- a/eval/anthropic_runner.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -from concurrent.futures import ThreadPoolExecutor, as_completed -import copy -import os -from eval.eval import compare_query_results -import pandas as pd -from psycopg2.extensions import QueryCanceledError -from query_generators.anthropic import AnthropicQueryGenerator -from tqdm import tqdm -from utils.questions import prepare_questions_df -from utils.creds import db_creds_all -from utils.reporting import upload_results - - -def run_anthropic_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - output_file_list = args.output_file - num_questions = args.num_questions - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - question_query_df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - # for each query in the csv, generate a query using the generator asynchronously - futures = [] - for row in input_rows: - # get db creds for each row's db_name - db_name = row["db_name"] - db_creds = db_creds_all[row["db_type"]] - - qg = AnthropicQueryGenerator( - db_creds=copy.deepcopy(db_creds), - db_name=db_name, - db_type=db_type, - model=args.model, - prompt_file=prompt_file, - timeout=args.timeout_gen, - use_public_data=not args.use_private_data, - verbose=args.verbose, - ) - - generated_query_fut = executor.submit( - qg.generate_query, - question=row["question"], - instructions=row["instructions"], - k_shot_prompt=row["k_shot_prompt"], - glossary=row["glossary"], - table_metadata_string=row["table_metadata_string"], - prev_invalid_sql=row["prev_invalid_sql"], - prev_error_msg=row["prev_error_msg"], - cot_instructions=row["cot_instructions"], - columns_to_keep=args.num_columns, - shuffle=args.shuffle_metadata, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - elif "EXECUTION ERROR" in err: - row["error_db_exec"] = 1 - elif "TIMEOUT" in err: - row["timeout"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[row["db_type"]] - # try executing the queries and compare the results if they succeed - try: - exact_match, correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - timeout=args.timeout_exec, - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - if correct: - total_correct += 1 - except QueryCanceledError as e: - row["timeout"] = 1 - row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - output_rows.append(row) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") - - # get average rate of correct results - avg_subset = output_df["correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="anthropic", - prompt=prompt, - args=args, - ) diff --git a/eval/openai_runner.py b/eval/openai_runner.py deleted file mode 100644 index 6a879ac..0000000 --- a/eval/openai_runner.py +++ /dev/null @@ -1,176 +0,0 @@ -import json -from concurrent.futures import ThreadPoolExecutor, as_completed -import copy -import os - -from eval.eval import compare_query_results -import pandas as pd -from psycopg2.extensions import QueryCanceledError -from query_generators.openai import OpenAIQueryGenerator -from tqdm import tqdm -from utils.questions import prepare_questions_df -from utils.creds import db_creds_all -from utils.reporting import upload_results - - -def run_openai_eval(args): - # get params from args - questions_file_list = args.questions_file - prompt_file_list = args.prompt_file - output_file_list = args.output_file - num_questions = args.num_questions - k_shot = args.k_shot - db_type = args.db_type - cot_table_alias = args.cot_table_alias - - for questions_file, prompt_file, output_file in zip( - questions_file_list, prompt_file_list, output_file_list - ): - print(f"Using prompt file {prompt_file}") - # get questions - print("Preparing questions...") - print( - f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" - ) - question_query_df = prepare_questions_df( - questions_file, db_type, num_questions, k_shot, cot_table_alias - ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - # for each query in the csv, generate a query using the generator asynchronously - futures = [] - for row in input_rows: - # get db creds for each row's db_name - db_name = row["db_name"] - db_creds = db_creds_all[row["db_type"]] - - qg = OpenAIQueryGenerator( - db_creds=copy.deepcopy(db_creds), - db_name=db_name, - db_type=db_type, - model=args.model, - prompt_file=prompt_file, - timeout=args.timeout_gen, - use_public_data=not args.use_private_data, - verbose=args.verbose, - ) - - generated_query_fut = executor.submit( - qg.generate_query, - question=row["question"], - instructions=row["instructions"], - k_shot_prompt=row["k_shot_prompt"], - glossary=row["glossary"], - table_metadata_string=row["table_metadata_string"], - prev_invalid_sql=row["prev_invalid_sql"], - prev_error_msg=row["prev_error_msg"], - table_aliases=row["table_aliases"], - columns_to_keep=args.num_columns, - shuffle=args.shuffle_metadata, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - table_metadata_string = result_dict["table_metadata_string"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - row["table_metadata_string"] = table_metadata_string - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - elif "EXECUTION ERROR" in err: - row["error_db_exec"] = 1 - elif "TIMEOUT" in err: - row["timeout"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] - # try executing the queries and compare the results if they succeed - try: - exact_match, correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - db_creds=db_creds, - timeout=args.timeout_exec, - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - if correct: - total_correct += 1 - except QueryCanceledError as e: - row["timeout"] = 1 - row["error_msg"] = f"QUERY EXECUTION TIMEOUT: {e}" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - output_rows.append(row) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) - output_df = pd.DataFrame(output_rows) - output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - if "prompt" in output_df.columns: - del output_df["prompt"] - # get num rows, mean correct, mean error_db_exec for each query_category - agg_stats = ( - output_df.groupby("query_category") - .agg( - num_rows=("db_name", "count"), - mean_correct=("correct", "mean"), - mean_error_db_exec=("error_db_exec", "mean"), - ) - .reset_index() - ) - print(agg_stats) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") - - # get average rate of correct results - avg_subset = output_df["correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") - - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="openai", - prompt=prompt, - args=args, - ) diff --git a/main.py b/main.py index 28d4301..b6500fc 100644 --- a/main.py +++ b/main.py @@ -80,13 +80,13 @@ ) if args.model_type == "oa": - from eval.openai_runner import run_openai_eval + from runners.openai_runner import run_openai_eval if args.model is None: args.model = "gpt-3.5-turbo-0613" run_openai_eval(args) elif args.model_type == "anthropic": - from eval.anthropic_runner import run_anthropic_eval + from runners.anthropic_runner import run_anthropic_eval if args.model is None: args.model = "claude-2" @@ -98,11 +98,11 @@ raise ValueError( "vLLM is not supported on macOS. Please run on another OS supporting CUDA." ) - from eval.vllm_runner import run_vllm_eval + from runners.vllm_runner import run_vllm_eval run_vllm_eval(args) elif args.model_type == "hf": - from eval.hf_runner import run_hf_eval + from runners.hf_runner import run_hf_eval run_hf_eval(args) elif args.model_type == "api": @@ -110,35 +110,35 @@ assert args.api_type is not None, "api_type must be provided for api model" assert args.api_type in ["vllm", "tgi"], "api_type must be one of 'vllm', 'tgi'" - from eval.api_runner import run_api_eval + from runners.api_runner import run_api_eval run_api_eval(args) elif args.model_type == "llama_cpp": - from eval.llama_cpp_runner import run_llama_cpp_eval + from runners.llama_cpp_runner import run_llama_cpp_eval run_llama_cpp_eval(args) elif args.model_type == "mlx": - from eval.mlx_runner import run_mlx_eval + from runners.mlx_runner import run_mlx_eval run_mlx_eval(args) elif args.model_type == "gemini": - from eval.gemini_runner import run_gemini_eval + from runners.gemini_runner import run_gemini_eval run_gemini_eval(args) elif args.model_type == "mistral": - from eval.mistral_runner import run_mistral_eval + from runners.mistral_runner import run_mistral_eval run_mistral_eval(args) elif args.model_type == "bedrock": - from eval.bedrock_runner import run_bedrock_eval + from runners.bedrock_runner import run_bedrock_eval run_bedrock_eval(args) elif args.model_type == "together": - from eval.together_runner import run_together_eval + from runners.together_runner import run_together_eval run_together_eval(args) elif args.model_type == "deepseek": - from eval.deepseek_runner import run_deepseek_eval + from runners.deepseek_runner import run_deepseek_eval run_deepseek_eval(args) else: diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py deleted file mode 100644 index 15c46ec..0000000 --- a/query_generators/anthropic.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Dict -from func_timeout import FunctionTimedOut, func_timeout -import os -import time - -from query_generators.query_generator import QueryGenerator -from utils.gen_prompt import to_prompt_schema -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.llm import chat_anthropic - - -class AnthropicQueryGenerator(QueryGenerator): - """ - Query generator that uses Anthropic's models - Models available: claude-2, claude-instant-1 - """ - - def __init__( - self, - db_type: str, - db_creds: Dict[str, str], - db_name: str, - model: str, - prompt_file: str, - timeout: int, - use_public_data: bool, - verbose: bool, - **kwargs, - ): - self.db_creds = db_creds - self.db_type = db_type - self.db_name = db_name - self.model = model - self.prompt_file = prompt_file - self.use_public_data = use_public_data - self.timeout = timeout - self.verbose = verbose - - def get_completion( - self, - model, - prompt, - max_tokens=600, - temperature=0, - stop=["```", ";"], - logit_bias={}, - ): - """Get Anthropic chat completion using the new utility function""" - messages = [{"role": "user", "content": prompt}] - try: - response = chat_anthropic( - messages=messages, - model=model, - max_completion_tokens=max_tokens, - temperature=temperature, - stop=stop, - ) - return response.content - except Exception as e: - print(str(e)) - if self.verbose: - print(type(e), e) - return "" - - def generate_query( - self, - question: str, - instructions: str, - k_shot_prompt: str, - glossary: str, - table_metadata_string: str, - prev_invalid_sql: str, - prev_error_msg: str, - cot_instructions: str, - columns_to_keep: int, - shuffle: bool, - ) -> dict: - start_time = time.time() - self.err = "" - self.query = "" - self.reason = "" - tokens_used = 0 - - if self.use_public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - # raise Exception("Replace this with your private data import") - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(self.prompt_file) as file: - model_prompt = file.read() - question_instructions = question + " " + instructions - - if table_metadata_string == "": - md = dbs[self.db_name]["table_metadata"] - pruned_metadata_str = to_prompt_schema(md, shuffle) - pruned_metadata_str = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_str, - to_dialect=self.db_type, - db_name=self.db_name, - ) - column_join = sup.columns_join.get(self.db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join( - join_list - ) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_str + join_str - else: - pruned_metadata_str = table_metadata_string - - prompt = model_prompt.format( - user_question=question, - db_type=self.db_type, - table_metadata_string=pruned_metadata_str, - instructions=instructions, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - cot_instructions=cot_instructions, - ) - function_to_run = self.get_completion - package = prompt - - try: - self.completion = func_timeout( - self.timeout, - function_to_run, - args=( - self.model, - package, - 600, - 0, - ["```", ";"], - ), - ) - results = self.completion - self.query = results.split("```sql")[-1].split(";")[0].split("```")[0] - self.reason = "-" - except FunctionTimedOut: - if self.verbose: - print("generating query timed out") - self.err = "QUERY GENERATION TIMEOUT" - except Exception as e: - if self.verbose: - print(f"Error while generating query: {type(e)}, {e})") - self.query = "" - self.reason = "" - if isinstance(e, KeyError): - self.err = f"QUERY GENERATION ERROR: {type(e)}, {e}, Completion: {self.completion}" - else: - self.err = f"QUERY GENERATION ERROR: {type(e)}, {e}" - - return { - "query": self.query, - "reason": self.reason, - "err": self.err, - "latency_seconds": time.time() - start_time, - "tokens_used": tokens_used, - "table_metadata_string": pruned_metadata_str, - } diff --git a/query_generators/openai.py b/query_generators/openai.py deleted file mode 100644 index cdfeb49..0000000 --- a/query_generators/openai.py +++ /dev/null @@ -1,201 +0,0 @@ -from typing import Dict, List -import time -from func_timeout import FunctionTimedOut, func_timeout -import json - -from query_generators.query_generator import QueryGenerator -from utils.gen_prompt import to_prompt_schema -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.llm import chat_openai, LLMResponse - - -class OpenAIQueryGenerator(QueryGenerator): - """ - Query generator that uses OpenAI's models - """ - - def __init__( - self, - db_creds: Dict[str, str], - db_name: str, - db_type: str, - model: str, - prompt_file: str, - timeout: int, - use_public_data: bool, - verbose: bool, - **kwargs, - ): - self.db_creds = db_creds - self.db_type = db_type - self.db_name = db_name - self.model = model - self.o1 = self.model.startswith("o1-") - self.prompt_file = prompt_file - self.use_public_data = use_public_data - self.timeout = timeout - self.verbose = verbose - - def get_chat_completion( - self, - model, - messages, - max_tokens=600, - temperature=0, - stop=[], - logit_bias={}, - seed=100, - ) -> str: - """Get OpenAI chat completion using the new utility function""" - try: - response: LLMResponse = chat_openai( - messages=messages, - model=model, - max_completion_tokens=max_tokens, - temperature=temperature, - stop=stop, - seed=seed, - ) - return response.content - except Exception as e: - print(str(e)) - if self.verbose: - print(type(e), e) - return "" - - def generate_query( - self, - question: str, - instructions: str, - k_shot_prompt: str, - glossary: str, - table_metadata_string: str, - prev_invalid_sql: str, - prev_error_msg: str, - table_aliases: str, - columns_to_keep: int, - shuffle: bool, - ) -> dict: - start_time = time.time() - self.err = "" - self.query = "" - self.reason = "" - tokens_used = 0 - - if self.use_public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - # raise Exception("Replace this with your private data import") - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - with open(self.prompt_file) as file: - chat_prompt = json.load(file) - question_instructions = question + " " + instructions - - if table_metadata_string == "": - md = dbs[self.db_name]["table_metadata"] - table_metadata_ddl = to_prompt_schema(md, shuffle) - table_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=table_metadata_ddl, - to_dialect=self.db_type, - db_name=self.db_name, - ) - column_join = sup.columns_join.get(self.db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join( - join_list - ) - else: - join_str = "" - table_metadata_string = table_metadata_ddl + join_str - - if glossary == "": - glossary = dbs[self.db_name]["glossary"] - try: - if self.o1: - sys_prompt = "" - user_prompt = chat_prompt[0]["content"] - else: - sys_prompt = chat_prompt[0]["content"] - sys_prompt = sys_prompt.format( - db_type=self.db_type, - ) - user_prompt = chat_prompt[1]["content"] - if len(chat_prompt) == 3: - assistant_prompt = chat_prompt[2]["content"] - except: - raise ValueError("Invalid prompt file. Please use prompt_openai.md") - user_prompt = user_prompt.format( - db_type=self.db_type, - user_question=question, - table_metadata_string=table_metadata_string, - instructions=instructions, - k_shot_prompt=k_shot_prompt, - glossary=glossary, - prev_invalid_sql=prev_invalid_sql, - prev_error_msg=prev_error_msg, - table_aliases=table_aliases, - ) - - if self.o1: - messages = [{"role": "user", "content": user_prompt}] - else: - messages = [] - messages.append({"role": "system", "content": sys_prompt}) - messages.append({"role": "user", "content": user_prompt}) - if len(chat_prompt) == 3: - messages.append({"role": "assistant", "content": assistant_prompt}) - - function_to_run = self.get_chat_completion - package = messages - - try: - self.completion = func_timeout( - self.timeout, - function_to_run, - args=(self.model, package, 1200, 0), - ) - results = self.completion - self.query = results.split("```sql")[-1].split("```")[0] - self.reason = "-" - except FunctionTimedOut: - if self.verbose: - print("generating query timed out") - self.err = "QUERY GENERATION TIMEOUT" - except Exception as e: - if self.verbose: - print(f"Error while generating query: {type(e)}, {e})") - self.query = "" - self.reason = "" - print(e) - if isinstance(e, KeyError): - self.err = f"QUERY GENERATION ERROR: {type(e)}, {e}, Completion: {self.completion}" - else: - self.err = f"QUERY GENERATION ERROR: {type(e)}, {e}" - - return { - "table_metadata_string": table_metadata_string, - "query": self.query, - "reason": self.reason, - "err": self.err, - "latency_seconds": time.time() - start_time, - "tokens_used": tokens_used, - } diff --git a/query_generators/query_generator.py b/query_generators/query_generator.py deleted file mode 100644 index 5cb48a0..0000000 --- a/query_generators/query_generator.py +++ /dev/null @@ -1,62 +0,0 @@ -import psycopg2 - - -class QueryGenerator: - """ - To customize a query generator, you would implement/override the following functions: - __init__: for initializing the question-specific parameters (eg credentials for the database). - generate_query: implement your query generation logic given a question. add your secret sauce here! - - The following function(s) are implemented, as these are common across all query generators: - exec_query: executes the query generated by generate_query; only postgres for now. It has - an implicit dependency on self.db_creds and self.verbose from __init__. - """ - - def __init__(self, **kwargs): - pass - - def generate_query( - self, - question: str, - instructions: str, - k_shot_prompt: str, - glossary: str, - table_metadata_string: str, - prev_invalid_sql: str, - prev_error_msg: str, - ) -> dict: - # generate a query given a question, instructions and k-shot prompt - # any hard-coded logic, prompt-engineering, table-pruning, api calls etc - # should be completely contained within this function - # do add try-except blocks to catch any errors and return an empty string - # these are the keys that you should store in the returned dict: - # query: the generated query - # reason: the reason for the query - # err: the error message if any - # any other fields you might want to track (eg tokens used in query, latency etc) - pass - - def exec_query(self, query: str) -> str: - """ - Tries to execute a query and returns an error message if unsuccessful - This function implicitly relies on self.db_creds from init - """ - if self.db_type != "postgres": - raise ValueError("Only postgres is supported for now") - try: - self.conn = psycopg2.connect(**self.db_creds) - self.cur = self.conn.cursor() - self.cur.execute(query) - _ = self.cur.fetchall() - self.cur.close() - self.conn.close() - return "" - except Exception as e: - if self.verbose: - print(f"Error while executing query:\n{type(e)}, {e}") - # cleanup connections - if self.cur: - self.cur.close() - if self.conn: - self.conn.close() - return str(e) diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py new file mode 100644 index 0000000..7461bf3 --- /dev/null +++ b/runners/anthropic_runner.py @@ -0,0 +1,244 @@ +import os +from time import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pandas as pd +import sqlparse +from tqdm import tqdm + +from eval.eval import compare_query_results +from utils.creds import db_creds_all +from utils.dialects import convert_postgres_ddl_to_dialect +from utils.gen_prompt import to_prompt_schema +from utils.questions import prepare_questions_df +from utils.reporting import upload_results +from utils.llm import chat_anthropic + + +def generate_prompt( + prompt_file, + question, + db_name, + db_type, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + prev_invalid_sql="", + prev_error_msg="", + public_data=True, + shuffle=True, +): + if "anthropic" not in prompt_file: + raise ValueError("Invalid prompt file. Please use prompt_anthropic.md") + + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + with open(prompt_file, "r") as f: + prompt = f.read() + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + if len(join_list) > 0: + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + else: + join_str = "" + pruned_metadata_str = pruned_metadata_ddl + join_str + else: + pruned_metadata_str = table_metadata_string + + prompt = prompt.format( + user_question=question, + db_type=db_type, + instructions=instructions, + table_metadata_string=pruned_metadata_str, + k_shot_prompt=k_shot_prompt, + glossary=glossary, + prev_invalid_sql=prev_invalid_sql, + prev_error_msg=prev_error_msg, + ) + return prompt + + +def process_row(row, model_name, args): + start_time = time() + prompt = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + public_data=not args.use_private_data, + shuffle=args.shuffle_metadata, + ) + messages = [{"role": "user", "content": prompt}] + try: + response = chat_anthropic(messages=messages, model=model_name, temperature=0.0) + generated_query = ( + response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() + ) + try: + generated_query = sqlparse.format( + generated_query, reindent=True, keyword_case="upper" + ) + except: + pass + return { + "query": generated_query, + "reason": "", + "err": "", + "latency_seconds": time() - start_time, + "tokens_used": response.input_tokens + response.output_tokens, + } + except Exception as e: + return { + "query": "", + "reason": "", + "err": f"GENERATION ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": 0, + } + + +def run_anthropic_eval(args): + # get params from args + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + num_questions = args.num_questions + k_shot = args.k_shot + db_type = args.db_type + cot_table_alias = args.cot_table_alias + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" + ) + question_query_df = prepare_questions_df( + questions_file, db_type, num_questions, k_shot, cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + output_rows = [] + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in input_rows: + generated_query_fut = executor.submit( + process_row, + row=row, + model_name=args.model, + args=args, + ) + futures.append(generated_query_fut) + + total_tried = 0 + total_correct = 0 + for f in (pbar := tqdm(as_completed(futures), total=len(futures))): + total_tried += 1 + i = futures.index(f) + row = input_rows[i] + result_dict = f.result() + query_gen = result_dict["query"] + reason = result_dict["reason"] + err = result_dict["err"] + # save custom metrics + if "latency_seconds" in result_dict: + row["latency_seconds"] = result_dict["latency_seconds"] + if "tokens_used" in result_dict: + row["tokens_used"] = result_dict["tokens_used"] + row["generated_query"] = query_gen + row["reason"] = reason + row["error_msg"] = err + # save failures into relevant columns in the dataframe + if "GENERATION ERROR" in err: + row["error_query_gen"] = 1 + elif "TIMEOUT" in err: + row["timeout"] = 1 + else: + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = compare_query_results( + query_gold=expected_query, + query_gen=query_gen, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=row["question"], + query_category=row["query_category"], + decimal_points=args.decimal_points, + ) + if is_correct: + total_correct += 1 + row["is_correct"] = 1 + row["error_msg"] = "" + else: + row["is_correct"] = 0 + row["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + output_rows.append(row) + pbar.set_description( + f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + ) + + # save results to csv + output_df = pd.DataFrame(output_rows) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + # get directory of output_file and create if not exist + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # get average rate of correct results + avg_subset = output_df["is_correct"].sum() / len(output_df) + print(f"Average correct rate: {avg_subset:.2f}") + + results = output_df.to_dict("records") + # upload results + with open(prompt_file, "r") as f: + prompt = f.read() + if args.upload_url is not None: + upload_results( + results=results, + url=args.upload_url, + runner_type="anthropic", + prompt=prompt, + args=args, + ) diff --git a/eval/api_runner.py b/runners/api_runner.py similarity index 100% rename from eval/api_runner.py rename to runners/api_runner.py diff --git a/eval/bedrock_runner.py b/runners/bedrock_runner.py similarity index 100% rename from eval/bedrock_runner.py rename to runners/bedrock_runner.py diff --git a/eval/deepseek_runner.py b/runners/deepseek_runner.py similarity index 100% rename from eval/deepseek_runner.py rename to runners/deepseek_runner.py diff --git a/eval/gemini_runner.py b/runners/gemini_runner.py similarity index 98% rename from eval/gemini_runner.py rename to runners/gemini_runner.py index dc805db..3300ad3 100644 --- a/eval/gemini_runner.py +++ b/runners/gemini_runner.py @@ -42,7 +42,6 @@ def generate_prompt( with open(prompt_file, "r") as f: prompt = f.read() - question_instructions = question + " " + instructions if table_metadata_string == "": md = dbs[db_name]["table_metadata"] @@ -128,7 +127,7 @@ def process_row(row, model_name, args): query_gen=generated_query, db_name=db_name, db_type=db_type, - db_creds=db_creds_all[row["db_type"]], + db_creds=db_creds_all[db_type], question=question, query_category=query_category, decimal_points=args.decimal_points, diff --git a/eval/hf_runner.py b/runners/hf_runner.py similarity index 99% rename from eval/hf_runner.py rename to runners/hf_runner.py index 2f34540..9046a65 100644 --- a/eval/hf_runner.py +++ b/runners/hf_runner.py @@ -1,11 +1,9 @@ -import json import os from typing import Optional from eval.eval import compare_query_results import pandas as pd import torch -import traceback from transformers import ( AutoTokenizer, AutoModelForCausalLM, diff --git a/eval/llama_cpp_runner.py b/runners/llama_cpp_runner.py similarity index 96% rename from eval/llama_cpp_runner.py rename to runners/llama_cpp_runner.py index f084c26..0297ca0 100644 --- a/eval/llama_cpp_runner.py +++ b/runners/llama_cpp_runner.py @@ -1,5 +1,4 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from eval.eval import compare_query_results import pandas as pd @@ -70,11 +69,10 @@ def run_llama_cpp_eval(args): model_path = args.model output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type cot_table_alias = args.cot_table_alias - llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=2048) + llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096) for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list diff --git a/eval/mistral_runner.py b/runners/mistral_runner.py similarity index 99% rename from eval/mistral_runner.py rename to runners/mistral_runner.py index bd49b37..4abdf81 100644 --- a/eval/mistral_runner.py +++ b/runners/mistral_runner.py @@ -41,8 +41,6 @@ def generate_prompt( sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() user_prompt = prompt.split("User:")[1].strip() - question_instructions = question + " " + instructions - if table_metadata_string == "": if public_data: from defog_data.metadata import dbs diff --git a/eval/mlx_runner.py b/runners/mlx_runner.py similarity index 98% rename from eval/mlx_runner.py rename to runners/mlx_runner.py index 254d4fe..e773008 100644 --- a/eval/mlx_runner.py +++ b/runners/mlx_runner.py @@ -1,5 +1,4 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from eval.eval import compare_query_results import pandas as pd diff --git a/runners/openai_runner.py b/runners/openai_runner.py new file mode 100644 index 0000000..5d207ef --- /dev/null +++ b/runners/openai_runner.py @@ -0,0 +1,260 @@ +import os +from time import time +from concurrent.futures import ThreadPoolExecutor, as_completed +import json + +import pandas as pd +import sqlparse +from tqdm import tqdm + +from eval.eval import compare_query_results +from utils.creds import db_creds_all +from utils.dialects import convert_postgres_ddl_to_dialect +from utils.gen_prompt import to_prompt_schema +from utils.questions import prepare_questions_df +from utils.reporting import upload_results +from utils.llm import chat_openai + + +def generate_prompt( + prompt_file, + question, + db_name, + db_type, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + prev_invalid_sql="", + prev_error_msg="", + public_data=True, + shuffle=True, +): + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + with open(prompt_file, "r") as f: + prompt = json.load(f) + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + if len(join_list) > 0: + join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + else: + join_str = "" + pruned_metadata_str = pruned_metadata_ddl + join_str + else: + pruned_metadata_str = table_metadata_string + + if prompt[0]["role"] == "system": + prompt[0]["content"] = prompt[0]["content"].format( + db_type=db_type, + ) + prompt[1]["content"] = prompt[1]["content"].format( + user_question=question, + instructions=instructions, + table_metadata_string=pruned_metadata_str, + k_shot_prompt=k_shot_prompt, + ) + else: + prompt[0]["content"] = prompt[1]["content"].format( + db_type=db_type, + user_question=question, + instructions=instructions, + table_metadata_string=pruned_metadata_str, + k_shot_prompt=k_shot_prompt, + ) + return prompt + + +def process_row(row, model_name, args): + start_time = time() + messages = generate_prompt( + prompt_file=args.prompt_file[0], + question=row["question"], + db_name=row["db_name"], + db_type=args.db_type, + instructions=row["instructions"], + k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], + prev_invalid_sql=row["prev_invalid_sql"], + prev_error_msg=row["prev_error_msg"], + public_data=not args.use_private_data, + shuffle=args.shuffle_metadata, + ) + try: + response = chat_openai(messages=messages, model=model_name, temperature=0.0) + generated_query = ( + response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() + ) + try: + generated_query = sqlparse.format( + generated_query, reindent=True, keyword_case="upper" + ) + except: + pass + return { + "query": generated_query, + "reason": "", + "err": "", + "latency_seconds": time() - start_time, + "tokens_used": response.input_tokens + response.output_tokens, + } + except Exception as e: + return { + "query": "", + "reason": "", + "err": f"GENERATION ERROR: {str(e)}", + "latency_seconds": time() - start_time, + "tokens_used": 0, + } + + +def run_openai_eval(args): + # get params from args + questions_file_list = args.questions_file + prompt_file_list = args.prompt_file + output_file_list = args.output_file + num_questions = args.num_questions + k_shot = args.k_shot + db_type = args.db_type + cot_table_alias = args.cot_table_alias + + for questions_file, prompt_file, output_file in zip( + questions_file_list, prompt_file_list, output_file_list + ): + print(f"Using prompt file {prompt_file}") + print("Preparing questions...") + print( + f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" + ) + question_query_df = prepare_questions_df( + questions_file, db_type, num_questions, k_shot, cot_table_alias + ) + input_rows = question_query_df.to_dict("records") + output_rows = [] + with ThreadPoolExecutor(args.parallel_threads) as executor: + futures = [] + for row in input_rows: + generated_query_fut = executor.submit( + process_row, + row=row, + model_name=args.model, + args=args, + ) + futures.append(generated_query_fut) + + total_tried = 0 + total_correct = 0 + for f in (pbar := tqdm(as_completed(futures), total=len(futures))): + total_tried += 1 + i = futures.index(f) + row = input_rows[i] + result_dict = f.result() + query_gen = result_dict["query"] + reason = result_dict["reason"] + err = result_dict["err"] + # save custom metrics + if "latency_seconds" in result_dict: + row["latency_seconds"] = result_dict["latency_seconds"] + if "tokens_used" in result_dict: + row["tokens_used"] = result_dict["tokens_used"] + row["generated_query"] = query_gen + row["reason"] = reason + row["error_msg"] = err + # save failures into relevant columns in the dataframe + if "GENERATION ERROR" in err: + row["error_query_gen"] = 1 + else: + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = compare_query_results( + query_gold=expected_query, + query_gen=query_gen, + db_name=db_name, + db_type=db_type, + question=row["question"], + query_category=row["query_category"], + db_creds=db_creds_all[db_type], + ) + if is_correct: + total_correct += 1 + row["is_correct"] = 1 + row["error_msg"] = "" + else: + row["is_correct"] = 0 + row["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"EXECUTION ERROR: {str(e)}" + output_rows.append(row) + pbar.set_description( + f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" + ) + + # save results to csv + output_df = pd.DataFrame(output_rows) + output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + # get num rows, mean correct, mean error_db_exec for each query_category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("is_correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + # get directory of output_file and create if not exist + output_dir = os.path.dirname(output_file) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # get average rate of correct results + avg_subset = output_df["correct"].sum() / len(output_df) + print(f"Average correct rate: {avg_subset:.2f}") + + results = output_df.to_dict("records") + + # upload results + with open(prompt_file, "r") as f: + prompt = f.read() + if args.upload_url is not None: + upload_results( + results=results, + url=args.upload_url, + runner_type="openai", + prompt=prompt, + args=args, + ) diff --git a/eval/together_runner.py b/runners/together_runner.py similarity index 100% rename from eval/together_runner.py rename to runners/together_runner.py diff --git a/eval/vllm_runner.py b/runners/vllm_runner.py similarity index 100% rename from eval/vllm_runner.py rename to runners/vllm_runner.py