feat(model): Support new zhipuai SDK (#1592)
Co-authored-by: yyhhyy <[email protected]>
Co-authored-by: Fangyin Cheng <[email protected]>
3 people authored Jun 4, 2024
1 parent 85bf64e commit c3c0636
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions dbgpt/model/proxy/llms/zhipu.py
@@ -1,3 +1,4 @@
+import os
 from concurrent.futures import Executor
 from typing import Iterator, Optional
 
@@ -37,23 +38,37 @@ def __init__(
         self,
         model: Optional[str] = None,
         api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
         model_alias: Optional[str] = "zhipu_proxyllm",
         context_length: Optional[int] = 8192,
         executor: Optional[Executor] = None,
     ):
         try:
-            import zhipuai
+            from zhipuai import ZhipuAI
 
         except ImportError as exc:
-            raise ValueError(
-                "Could not import python package: zhipuai "
-                "Please install dashscope by command `pip install zhipuai"
-            ) from exc
+            if (
+                "No module named" in str(exc)
+                or "cannot find module" in str(exc).lower()
+            ):
+                raise ValueError(
+                    "The python package 'zhipuai' is not installed. "
+                    "Please install it by running `pip install zhipuai`."
+                ) from exc
+            else:
+                raise ValueError(
+                    "Could not import python package: zhipuai "
+                    "This may be due to a version that is too low. "
+                    "Please upgrade the zhipuai package by running `pip install --upgrade zhipuai`."
+                ) from exc
         if not model:
             model = CHATGLM_DEFAULT_MODEL
-        if api_key:
-            zhipuai.api_key = api_key
+        if not api_key:
+            # Compatible with DB-GPT's config
+            api_key = os.getenv("ZHIPU_PROXY_API_KEY")
+
         self._model = model
+        self.client = ZhipuAI(api_key=api_key, base_url=api_base)
 
         super().__init__(
             model_names=[model, model_alias],
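
For context: the old SDK configured a module-level key (`zhipuai.api_key = ...`), while the new SDK replaces that with an explicit `ZhipuAI` client object. A minimal standalone sketch (not part of the commit) of the construction path above, assuming `ZHIPU_PROXY_API_KEY` is exported in the environment:

```python
# Sketch only: constructing the new-style client the same way
# __init__ above does.
import os

from zhipuai import ZhipuAI

# Fall back to DB-GPT's environment variable when no key is given,
# mirroring the `if not api_key` branch in the diff.
api_key = os.getenv("ZHIPU_PROXY_API_KEY")

# base_url may be None; the committed code relies on the SDK's default
# endpoint in that case (api_base defaults to None).
client = ZhipuAI(api_key=api_key, base_url=None)
```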
@@ -84,26 +99,25 @@ def sync_generate_stream(
         request: ModelRequest,
         message_converter: Optional[MessageConverter] = None,
     ) -> Iterator[ModelOutput]:
-        import zhipuai
 
         request = self.local_covert_message(request, message_converter)
 
         messages = request.to_common_messages(support_system_role=False)
 
         model = request.model or self._model
         try:
-            res = zhipuai.model_api.sse_invoke(
+            response = self.client.chat.completions.create(
                 model=model,
-                prompt=messages,
+                messages=messages,
                 temperature=request.temperature,
                 # top_p=params.get("top_p"),
-                incremental=False,
+                stream=True,
             )
-            for r in res.events():
-                if r.event == "add":
-                    yield ModelOutput(text=r.data, error_code=0)
-                elif r.event == "error":
-                    yield ModelOutput(text=r.data, error_code=1)
+            partial_text = ""
+            for chunk in response:
+                delta_content = chunk.choices[0].delta.content
+                partial_text += delta_content
+                yield ModelOutput(text=partial_text, error_code=0)
         except Exception as e:
             return ModelOutput(
                 text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
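
The streaming change is the core of the migration: the old `sse_invoke`/`events()` SSE interface is replaced by the OpenAI-style `chat.completions.create(..., stream=True)` iterator. A minimal usage sketch of that call outside DB-GPT; the model name "glm-4", the placeholder key, and the `or ""` guard are illustrative additions, not from the commit:

```python
# Sketch only: consuming the new SDK's streaming response the same way
# sync_generate_stream above does.
from zhipuai import ZhipuAI

client = ZhipuAI(api_key="your-api-key")  # placeholder key

response = client.chat.completions.create(
    model="glm-4",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

partial_text = ""
for chunk in response:
    # Each chunk carries an incremental delta; accumulate it to rebuild
    # the full text, as the handler above does. The `or ""` is a
    # defensive addition for chunks whose delta has no content.
    delta_content = chunk.choices[0].delta.content or ""
    partial_text += delta_content
print(partial_text)
```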
