diff --git a/README.md b/README.md
index 590f9e6..42ff779 100644
--- a/README.md
+++ b/README.md
@@ -242,6 +242,10 @@ python utils/api_server.py \
   --enable-lora \
   --max-lora-rank 64
 
+# to use vllm's openai-compatible api server
+export MODEL_NAME="Qwen/Qwen3-4B"
+vllm serve "$MODEL_NAME" --port 8000
+
 # to run sql-eval using the api runner - depending on how much your GPUs can take, can increase p and b to higher values
 python main.py \
   -db postgres \
@@ -304,6 +308,16 @@ python -W ignore main.py \
 
 While you can do the same for the other runners, the time savings are most significant when loading a large model locally, vs calling an always-on API.
 
+#### Thinking Models
+
+If you'd like to use a model that outputs thinking tokens, you can pass the `--thinking` flag to the runner, and the thinking tokens will be removed from the LLM output before the generated query is run.
+
+You can check out `run_qwen.sh` for an example of how to run a thinking model.
+
+```bash
+./run_qwen.sh --thinking # add --thinking to generate thinking tokens
+```
+
 ### Bedrock
 
 ```bash
diff --git a/main.py b/main.py
index b6500fc..d0cb6cf 100644
--- a/main.py
+++ b/main.py
@@ -30,6 +30,7 @@
 parser.add_argument(
     "--cot_table_alias", type=str, choices=["instruct", "pregen", ""], default=""
 )
+parser.add_argument("--thinking", action="store_true")
 # execution-related parameters
 parser.add_argument("-o", "--output_file", nargs="+", type=str, required=True)
 parser.add_argument("-p", "--parallel_threads", type=int, default=5)
@@ -83,7 +84,7 @@
     from runners.openai_runner import run_openai_eval
 
     if args.model is None:
-        args.model = "gpt-3.5-turbo-0613"
+        args.model = "gpt-4o"
     run_openai_eval(args)
 elif args.model_type == "anthropic":
     from runners.anthropic_runner import run_anthropic_eval
@@ -108,7 +109,11 @@
 elif args.model_type == "api":
     assert args.api_url is not None, "api_url must be provided for api model"
     assert args.api_type is not None, "api_type must be provided for api model"
-    assert args.api_type in ["vllm", "tgi"], "api_type must be one of 'vllm', 'tgi'"
+    assert args.api_type in [
+        "openai",
+        "vllm",
+        "tgi",
+    ], "api_type must be one of 'openai', 'vllm', 'tgi'"
 
     from runners.api_runner import run_api_eval
 
diff --git a/prompts/prompt_qwen.json b/prompts/prompt_qwen.json
new file mode 100644
index 0000000..1a56f8e
--- /dev/null
+++ b/prompts/prompt_qwen.json
@@ -0,0 +1,11 @@
+[
+  {
+    "role": "system",
+    "content": "Your task is to convert a user question to a {db_type} query, given a database schema."
+  },
+  {
+    "role": "user",
+    "content": "Generate a SQL query that answers the question `{user_question}`.\n{instructions}\nThis query will run on a database whose schema is represented in this string:\n{table_metadata_string}\n{join_str}\n{table_aliases}\nAfter reasoning, return only the SQL query, and nothing else."
+  }
+]
+ 
\ No newline at end of file
diff --git a/run_model_cot.sh b/run_model_cot.sh
index 3c09d67..656ae09 100755
--- a/run_model_cot.sh
+++ b/run_model_cot.sh
@@ -24,7 +24,7 @@ for model_name in "${model_names[@]}"; do
   echo "Running model ${model_name}"
   # first, get the API up
-  python3 utils/api_server.py --model "${model_name}" --tensor-parallel-size 1 --dtype float16 --max-model-len 8192 --gpu-memory-utilization 0.40 --block-size 16 --disable-log-requests --port "${PORT}" &
+  python3 utils/api_server.py --model "${model_name}" --tensor-parallel-size 1 --dtype float16 --max-model-len 16384 --gpu-memory-utilization 0.90 --block-size 16 --disable-log-requests --port "${PORT}" &
 
   # run a loop to check if the http://localhost:8080/health endpoint returns a valid 200 result
   while true; do
diff --git a/run_qwen.sh b/run_qwen.sh
new file mode 100755
index 0000000..93aca68
--- /dev/null
+++ b/run_qwen.sh
@@ -0,0 +1,36 @@
+export db_type="postgres"
+export prompt_file="prompts/prompt_qwen.json"
+export model_name="qwen"
+export PORT=8000
+
+# assume you already have the vllm server running
+# vllm serve "$model_name" --port 8000
+
+if [[ "$1" == "--thinking" ]]; then
+  echo "Running sql-eval on $model_name with thinking tokens"
+  python3 main.py -db "${db_type}" \
+    -f "${prompt_file}" \
+    -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" \
+    -o "results/${model_name}/openai_api_v1.csv" "results/${model_name}/openai_api_basic.csv" "results/${model_name}/openai_api_advanced.csv" \
+    -g api \
+    -m "Qwen/Qwen3-4B" \
+    -b 1 \
+    -c 0 \
+    --thinking \
+    --api_url "http://localhost:${PORT}/v1/chat/completions" \
+    --api_type "openai" \
+    -p 10
+else
+  echo "Running sql-eval on $model_name without generating thinking tokens"
+  python3 main.py -db "${db_type}" \
+    -f "${prompt_file}" \
+    -q "data/questions_gen_${db_type}.csv" "data/instruct_basic_${db_type}.csv" "data/instruct_advanced_${db_type}.csv" \
+    -o "results/${model_name}/openai_api_v1.csv" "results/${model_name}/openai_api_basic.csv" "results/${model_name}/openai_api_advanced.csv" \
+    -g api \
+    -m "Qwen/Qwen3-4B" \
+    -b 1 \
+    -c 0 \
+    --api_url "http://localhost:${PORT}/v1/chat/completions" \
+    --api_type "openai" \
+    -p 10
+fi
\ No newline at end of file
diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py
index aa9f571..0779bc8 100644
--- a/runners/anthropic_runner.py
+++ b/runners/anthropic_runner.py
@@ -15,6 +15,7 @@
 from utils.llm import chat_anthropic
 import json
 
+
 def generate_prompt(
     prompt_file,
     question,
diff --git a/runners/api_runner.py b/runners/api_runner.py
index 6b2afdf..50f4267 100644
--- a/runners/api_runner.py
+++ b/runners/api_runner.py
@@ -1,7 +1,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Optional
+from typing import Dict, List, Optional
 
 from eval.eval import compare_query_results
 import pandas as pd
@@ -75,6 +75,25 @@ def mk_tgi_json(prompt, num_beams):
     }
 
 
+def mk_openai_json(
+    prompt: List[Dict[str, str]],
+    model_name: Optional[str] = None,
+    thinking: bool = False,
+):
+    # See https://docs.vllm.ai/en/v0.7.1/serving/openai_compatible_server.html
+    # for the full list of routes and their respective parameters
+    data = {
+        "model": model_name,
+        "messages": prompt,
+        "temperature": 0,
+        "max_tokens": 32768,
+    }
+    # See https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
+    if not thinking:
+        data["chat_template_kwargs"] = {"enable_thinking": False}
{"enable_thinking": False} + return data + + def process_row( row, api_url: str, @@ -84,6 +103,8 @@ def process_row( logprobs: bool = False, sql_lora_path: Optional[str] = None, sql_lora_name: Optional[str] = None, + model_name: Optional[str] = None, + thinking: bool = False, ): start_time = time() if api_type == "tgi": @@ -92,6 +113,8 @@ def process_row( json_data = mk_vllm_json( row["prompt"], num_beams, logprobs, sql_lora_path, sql_lora_name ) + elif api_type == "openai": + json_data = mk_openai_json(row["prompt"], model_name, thinking) else: # add any custom JSON data here, e.g. for a custom API json_data = { @@ -129,7 +152,7 @@ def process_row( except KeyError: print(r.json()) generated_query = "" - elif "[SQL]" not in row["prompt"]: + elif "[SQL]" not in row["prompt"] and api_type == "vllm": generated_query = ( r.json()["text"][0] .split("```sql")[-1] @@ -139,9 +162,18 @@ def process_row( + ";" ) else: - generated_query = r.json()["text"][0] + response_json = r.json() + if "choices" in response_json: + generated_query = response_json["choices"][0]["message"]["content"] + elif "text" in response_json: + generated_query = response_json["text"][0] + else: + print(f"choice/text not found as a key in response:\n{response_json}") + generated_query = "" if "[SQL]" in generated_query: generated_query = generated_query.split("[SQL]", 1)[1].strip() + elif "" in generated_query: + generated_query = generated_query.split("", 1)[-1].strip() else: generated_query = generated_query.strip() @@ -224,6 +256,9 @@ def run_api_eval(args): sql_lora_path = args.adapter if args.adapter else None sql_lora_name = args.adapter_name if args.adapter_name else None run_name = args.run_name if args.run_name else None + model_name = args.model if args.model else None + thinking = True if args.thinking else False + if sql_lora_path: print("Using LoRA adapter at:", sql_lora_path) if logprobs: @@ -250,6 +285,7 @@ def run_api_eval(args): questions_file, db_type, num_questions, k_shot, cot_table_alias ) # create a prompt for each question + # note that prompt will be a list of dicts for json prompt templates df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -294,6 +330,8 @@ def run_api_eval(args): logprobs, sql_lora_path, sql_lora_name, + model_name, + thinking, ) ) diff --git a/utils/gen_prompt.py b/utils/gen_prompt.py index bc089b8..ede9c07 100644 --- a/utils/gen_prompt.py +++ b/utils/gen_prompt.py @@ -1,6 +1,6 @@ from copy import deepcopy import json -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np from utils.dialects import ( ddl_to_bigquery, @@ -123,7 +123,7 @@ def generate_prompt( columns_to_keep=40, shuffle_metadata=False, table_aliases="", -): +) -> Union[List[Dict[str, str]], str]: """ Generates the prompt for the given question. If a json file is passed in as the prompt_file, please ensure that it is a list