-
Notifications
You must be signed in to change notification settings - Fork 13
Change session_metadata_api Session from std lib dataclass to Pydantic #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,43 +35,49 @@ | |
| # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF | ||
| # DATA. | ||
| # | ||
| import dataclasses | ||
| import json | ||
| from dataclasses import dataclass, field | ||
| from datetime import datetime | ||
| from typing import List, Any, Optional | ||
|
Comment on lines
+38
to
40
|
||
|
|
||
| import requests | ||
| from pydantic import BaseModel, ConfigDict, Field, alias_generators | ||
|
|
||
| from app.config import settings | ||
| from app.services.utils import raise_for_http_error, body_to_json | ||
|
|
||
|
|
||
| @dataclass | ||
| class SessionQueryConfiguration: | ||
| class SessionQueryConfiguration(BaseModel): | ||
| model_config = ConfigDict( | ||
| alias_generator=alias_generators.to_camel, | ||
| validate_by_name=True, | ||
| revalidate_instances="always", | ||
| ) | ||
|
|
||
| enable_hyde: bool | ||
| enable_summary_filter: bool | ||
| enable_tool_calling: bool = False | ||
| selected_tools: list[str] = field(default_factory=list) | ||
| selected_tools: list[str] = Field(default_factory=list) | ||
| disable_streaming: bool = False | ||
|
|
||
|
|
||
| @dataclass | ||
| class Session: | ||
| class Session(BaseModel): | ||
| model_config = ConfigDict( | ||
| alias_generator=alias_generators.to_camel, | ||
|
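There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This allows us to output a camelCased dictionary by dumping the model with aliases enabled (the inline code was lost in extraction; presumably `model_dump(by_alias=True)`). |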
||
| validate_by_name=True, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is needed to allow fields to still be set by their snake_case names. |
||
| revalidate_instances="always", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without this, [inline code lost in extraction; presumably: model instances passed in are not re-validated]. It is probably not necessary everywhere we have Pydantic models at this time, but I need it here for a check in #307. |
||
| ) | ||
|
|
||
| id: int | ||
| name: str | ||
| data_source_ids: List[int] | ||
| data_source_ids: list[int] | ||
| project_id: int | ||
| time_created: datetime | ||
| time_updated: datetime | ||
| created_by_id: str | ||
| updated_by_id: str | ||
|
Comment on lines
-64
to
-67
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing these because [reason lost in extraction]. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does removing these raise an error while updating sessions, or does it just work fine? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great question! I've tested this locally and encountered no errors when updating a session, nor when creating a new one (which calls [inline code lost in extraction]). We're not actually using these fields when making session update requests anyway (since they're managed at the DB layer), and we're also not using them here in the Python service. |
||
| inference_model: str | ||
| rerank_model: str | ||
| rerank_model: Optional[str] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems |
||
| response_chunks: int | ||
| query_configuration: SessionQueryConfiguration | ||
| associated_data_source_id: Optional[int] = None | ||
|
|
||
| def get_all_data_source_ids(self) -> List[int]: | ||
| def get_all_data_source_ids(self) -> list[int]: | ||
| """ | ||
| Returns all data source IDs associated with the session. | ||
| If the session has an associated data source ID, it is included in the list. | ||
|
|
@@ -81,14 +87,14 @@ def get_all_data_source_ids(self) -> List[int]: | |
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| @dataclasses.dataclass | ||
| class UpdatableSession: | ||
| id: int | ||
| name: str | ||
| dataSourceIds: List[int] | ||
| projectId: int | ||
| inferenceModel: str | ||
| rerankModel: str | ||
| rerankModel: Optional[str] | ||
| responseChunks: int | ||
| queryConfiguration: dict[str, bool | List[str]] | ||
| associatedDataSourceId: Optional[int] | ||
|
|
@@ -114,10 +120,6 @@ def session_from_java_response(data: dict[str, Any]) -> Session: | |
| name=data["name"], | ||
| data_source_ids=data["dataSourceIds"], | ||
| project_id=data["projectId"], | ||
| time_created=datetime.fromtimestamp(data["timeCreated"]), | ||
| time_updated=datetime.fromtimestamp(data["timeUpdated"]), | ||
| created_by_id=data["createdById"], | ||
| updated_by_id=data["updatedById"], | ||
| inference_model=data["inferenceModel"], | ||
| rerank_model=data["rerankModel"], | ||
| response_chunks=data["responseChunks"], | ||
|
|
@@ -127,9 +129,7 @@ def session_from_java_response(data: dict[str, Any]) -> Session: | |
| enable_tool_calling=data["queryConfiguration"].get( | ||
| "enableToolCalling", False | ||
| ), | ||
| disable_streaming=data["queryConfiguration"].get( | ||
| "disableStreaming", False | ||
| ), | ||
| disable_streaming=data["queryConfiguration"].get("disableStreaming", False), | ||
mliu-cloudera marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| selected_tools=data["queryConfiguration"]["selectedTools"] or [], | ||
| ), | ||
| associated_data_source_id=data.get("associatedDataSourceId", None), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The dataclasses import is only used for the UpdatableSession class. Consider removing this import and converting UpdatableSession to Pydantic as well for consistency, or add a comment explaining why UpdatableSession remains a dataclass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the time it would take (to be fair, minutes) to test changing
UpdatableSession too isn't really worth it right now. But maybe I'm being too lazy 😶