Minimal edit to enable Deepseek prefilling #96
base: main
Changes from all commits: 302867e, e9843f7, 0b6541a, 7d80a38, 71fe0d4, a30b651, cd3190f
```diff
@@ -117,11 +117,29 @@ async def _make_api_call(self, prompt: Prompt, model_id, start_time, **kwargs) -
             )
         else:
             api_func = self.aclient.chat.completions.create
-        api_response: openai.types.chat.ChatCompletion = await api_func(
-            messages=prompt.openai_format(),
-            model=model_id,
-            **kwargs,
-        )
+        original_base_url = self.aclient.base_url
+        try:
+            if model_id in {"deepseek-chat", "deepseek-reasoner"}:
+                if prompt.is_last_message_assistant():
+                    # Use the beta endpoint for assistant prefilled prompts with DeepSeek
+                    self.aclient.base_url = "https://api.deepseek.com/beta"
+                else:
+                    # Use the standard v1 endpoint otherwise
+                    self.aclient.base_url = "https://api.deepseek.com/v1"
+                messages = prompt.deepseek_format()
+            else:
+                messages = prompt.openai_format()
+
+            api_response: openai.types.chat.ChatCompletion = await api_func(
+                messages=messages,
+                model=model_id,
+                **kwargs,
+            )
+        finally:
+            # Always revert the base_url after the call
+            self.aclient.base_url = original_base_url
```
Comment on lines +139 to +141:

Contributor: One thought - could this have strange async race conditions? :/

Contributor: Hmm, maybe the base URL should be passed directly to the `api_func` on a per-call basis, rather than setting it as an attribute of the entire class, since the class itself could be handling many requests with different models (and even different providers, if it were set up differently).

Contributor (Author): Ah, that's true. Should we maybe use an asyncio lock?

Contributor (Author): The `api_func` doesn't accept a base URL, unfortunately. And I guess locking would harm concurrency... would you be against locking?

Contributor: I think locking would hurt throughput, since it would hold the lock until the async call completes. Perhaps you can override the URL via `extra_headers`?

Contributor: I'm still a little worried about this. Can we just use "https://api.deepseek.com/beta" always? Then we can set it in api.py and remove all this internal logic for swapping between endpoints.

Contributor: What ended up happening here?
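The race the reviewers worry about is easy to reproduce with a toy client: because `base_url` is shared mutable state, a second task can overwrite it between one task's assignment and the point where that task's request is actually sent. A sketch with a fake client (not the PR's code; `asyncio.sleep(0)` stands in for the real network await):

```python
import asyncio

class FakeClient:
    """Stand-in for the shared aclient: one mutable base_url for all tasks."""
    def __init__(self):
        self.base_url = "https://api.deepseek.com/v1"

async def make_call(client, intended_url, used_urls):
    client.base_url = intended_url     # set shared state for "this" request
    await asyncio.sleep(0)             # yield to the event loop, like a real await
    used_urls.append(client.base_url)  # URL the request would actually use

async def main():
    client, used = FakeClient(), []
    await asyncio.gather(
        make_call(client, "https://api.deepseek.com/beta", used),
        make_call(client, "https://api.deepseek.com/v1", used),
    )
    return used

used = asyncio.run(main())
# The first task intended the beta endpoint, but by the time it resumes,
# the second task has already overwritten base_url, so both tasks end up
# observing "https://api.deepseek.com/v1".
```

The `try`/`finally` in the diff restores the URL afterwards but does not prevent this interleaving, which is why the reviewers lean toward a per-call solution or a single fixed endpoint.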
```diff
         if hasattr(api_response, "error") and (
             "Rate limit exceeded" in api_response.error["message"] or api_response.error["code"] == 429
         ):  # OpenRouter routes through the error messages from the different providers, so we catch them here
@@ -160,6 +178,7 @@ async def _make_api_call(self, prompt: Prompt, model_id, start_time, **kwargs) -
                 duration=duration,
                 cost=context_cost + count_tokens(choice.message.content, model_id) * completion_token_cost,
                 logprobs=(self.convert_top_logprobs(choice.logprobs) if choice.logprobs is not None else None),
+                reasoning_content=getattr(choice.message, "reasoning_content", None),
             )
         )
         self.add_response_to_prompt_file(prompt_file, responses)
```
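The `getattr(..., None)` in the second hunk is what keeps non-DeepSeek responses working: OpenAI message objects have no `reasoning_content` attribute, so the default is returned instead of raising `AttributeError`. A minimal illustration with stand-in message objects:

```python
from types import SimpleNamespace

# Stand-ins for response message objects from two providers.
openai_msg = SimpleNamespace(content="Hello")  # no reasoning_content attribute
deepseek_msg = SimpleNamespace(content="42", reasoning_content="First, ...")

# Missing attribute falls back to the default instead of raising.
assert getattr(openai_msg, "reasoning_content", None) is None
assert getattr(deepseek_msg, "reasoning_content", None) == "First, ..."
```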
Comment on the model check:

Do we have a DEEPSEEK_MODELS dict (or list) somewhere?

Reply (Author): We have one in safety-tooling/safetytooling/apis/inference/api.py, but importing it here would result in a circular import. We can create a constants file somewhere, maybe?
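One way to break the circular import, as suggested: move the model set into a leaf module that imports nothing from the rest of the package, so both `api.py` and the inference module can depend on it. The module path and constant below are hypothetical, not part of the PR:

```python
# safetytooling/utils/constants.py  (hypothetical location)
# A leaf module: it imports nothing from the package, so importing it
# from api.py and from the inference module can never create a cycle.
DEEPSEEK_MODELS = frozenset({"deepseek-chat", "deepseek-reasoner"})

# Both call sites would then use:
#   from safetytooling.utils.constants import DEEPSEEK_MODELS
```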