Skip to content

Commit

Permalink
Merge pull request #260 from briefercloud/fix-ai-big-schema
Browse files: browse the repository at this point in the history
Give table info directly to the AI API via the request payload
  • Loading branch information
vieiralucas authored Nov 26, 2024
2 parents be5d442 + fbfaa8d commit a77c6a7
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 414 deletions.
87 changes: 10 additions & 77 deletions ai/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,16 @@
# from langchain.globals import set_verbose
# set_verbose(True)

import tempfile
import json
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi import FastAPI, Depends, HTTPException, status
from pydantic import BaseModel
from typing import List, Optional
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from decouple import config
from api.chains.sql import create_sql_query_chain
from api.llms import initialize_llm
from api.chains.sql_edit import create_sql_edit_query_chain
from api.chains.python_edit import create_python_edit_query_chain
from api.chains.stream.python_edit import create_python_edit_stream_query_chain
from api.chains.stream.sql_edit import create_sql_edit_stream_query_chain
from api.chains.vega import create_vega_chain
import secrets
from sqlalchemy.engine import create_engine
from typing import Any
Expand Down Expand Up @@ -52,54 +45,25 @@ def get_current_username(credentials: HTTPBasicCredentials = Depends(security)):
)
return credentials.username

class SQLInputData(BaseModel):
credentialsInfo: Any
databaseURL: str
question: str
modelId: Optional[str] = None

@app.post("/v1/sql")
async def v1_sql(data: SQLInputData, _ = Depends(get_current_username)):
engine = get_database_engine(data.databaseURL, data.credentialsInfo)
db = SQLDatabase(engine=engine)

llm = initialize_llm(model_id=data.modelId)
chain = create_sql_query_chain(llm, db)
res = chain.invoke({"question": data.question})

return res["text"]


class VegaInputData(BaseModel):
sql: str
model_id: Optional[str] = None

@app.post("/v1/vega")
async def v1_vega(data: VegaInputData, _ = Depends(get_current_username)):
llm = initialize_llm(model_id=data.modelId)
chain = create_vega_chain(llm)
res = chain.invoke({"sql": data.sql})

return res["text"]

class SQLEditInputData(BaseModel):
databaseURL: str
credentialsInfo: Any
query: str
instructions: str
dialect: str
tableInfo: Optional[str] = None
modelId: Optional[str] = None
openaiApiKey: Optional[str] = None

@app.post("/v1/sql/edit")
async def v1_sql_edit(data: SQLEditInputData, _ = Depends(get_current_username)):
engine = get_database_engine(data.databaseURL, data.credentialsInfo)
db = SQLDatabase(engine=engine)

@app.post("/v1/stream/sql/edit")
async def v1_steam_sql_edit(data: SQLEditInputData, _ = Depends(get_current_username)):
llm = initialize_llm(model_id=data.modelId, openai_api_key=data.openaiApiKey)
chain = create_sql_edit_query_chain(llm, db)
res = chain.invoke({"query": data.query, "instructions": data.instructions})
chain = create_sql_edit_stream_query_chain(llm, data.dialect, data.tableInfo)

return res["text"]
async def generate():
async for result in chain.astream({"query": data.query, "instructions": data.instructions}):
yield json.dumps(result) + "\n"

return StreamingResponse(generate(), media_type="text/plain")

class PythonEditInputData(BaseModel):
source: str
Expand All @@ -109,37 +73,6 @@ class PythonEditInputData(BaseModel):
modelId: Optional[str] = None
openaiApiKey: Optional[str] = None

@app.post("/v1/python/edit")
async def v1_python_edit(data: PythonEditInputData, _ = Depends(get_current_username)):
llm = initialize_llm(model_id=data.modelId, openai_api_key=data.openaiApiKey)
chain = create_python_edit_query_chain(llm)
res = chain.invoke({"source": data.source, "instructions": data.instructions, "allowed_libraries": data.allowedLibraries})

return res["text"]

@app.post("/v1/stream/sql/edit")
async def v1_steam_sql_edit(data: SQLEditInputData, _ = Depends(get_current_username)):
with tempfile.NamedTemporaryFile(delete=True) as temp_cert_file:
credentialsInfo = data.credentialsInfo
cert_temp_file_path = None
if data.credentialsInfo and "sslrootcert" in data.credentialsInfo:
certdata = bytes.fromhex(data.credentialsInfo["sslrootcert"])
temp_cert_file.write(certdata)
temp_cert_file.flush()
cert_temp_file_path = temp_cert_file.name
credentialsInfo = None

engine = None
if data.databaseURL != "duckdb":
engine = get_database_engine(data.databaseURL, credentialsInfo, cert_temp_file_path)
llm = initialize_llm(model_id=data.modelId, openai_api_key=data.openaiApiKey)
chain = create_sql_edit_stream_query_chain(llm, engine)

async def generate():
async for result in chain.astream({"query": data.query, "instructions": data.instructions}):
yield json.dumps(result) + "\n"

return StreamingResponse(generate(), media_type="text/plain")

@app.post("/v1/stream/python/edit")
async def v1_stream_python_edit(data: PythonEditInputData, _ = Depends(get_current_username)):
Expand Down
47 changes: 0 additions & 47 deletions ai/api/chains/python_edit.py

This file was deleted.

54 changes: 0 additions & 54 deletions ai/api/chains/sql.py

This file was deleted.

56 changes: 0 additions & 56 deletions ai/api/chains/sql_edit.py

This file was deleted.

7 changes: 2 additions & 5 deletions ai/api/chains/stream/sql_edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def get_table_info(engine):
system_catalogs = ["system", "information_schema", "current", "jmx", "memory"]
all_catalogs = [row[0] for row in engine.execute("SHOW CATALOGS").fetchall()]
user_catalogs = list(filter(lambda c: c not in system_catalogs, all_catalogs))
print(user_catalogs)
with ThreadPoolExecutor() as executor:
catalog_engines = [create_engine(engine.url.set(database=catalog)) for catalog in user_catalogs]
results = executor.map(lambda ce: get_catalog_table_info(ce), catalog_engines)
Expand All @@ -83,15 +82,13 @@ def get_table_info(engine):
return table_info[:100000]


def create_sql_edit_stream_query_chain(llm, engine):
table_info = get_table_info(engine)

def create_sql_edit_stream_query_chain(llm, dialect, table_info):
prompt = PromptTemplate(
template=template,
input_variables=["query", "instructions"],
partial_variables={
"table_info": table_info,
"dialect": engine.dialect.name if engine else "DuckDB",
"dialect": dialect,
},
)

Expand Down
44 changes: 0 additions & 44 deletions ai/api/chains/vega.py

This file was deleted.

Loading

0 comments on commit a77c6a7

Please sign in to comment.