Skip to content

Commit 03611a2

Browse files
cpsievertCopilot
andauthored
Add deepcopy() support to Chat instances (#96)
* Allow Chat instances to be deepcopied * Update changelog * Update chatlas/_chat.py Co-authored-by: Copilot <[email protected]> * Update types --------- Co-authored-by: Copilot <[email protected]>
1 parent c2a7bfe commit 03611a2

File tree

5 files changed

+43
-1
lines changed

5 files changed

+43
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Changes
1313

14-
* `ChatDatabricks()`'s `model` now defaults to `databricks-claude-3-7-sonnet` instead of `databricks-dbrx-instruct`
14+
* `ChatDatabricks()`'s `model` now defaults to `databricks-claude-3-7-sonnet` instead of `databricks-dbrx-instruct`. (#95)
15+
16+
### Improvements
17+
18+
* `Chat` instances can now be deep copied, which is useful for forking the chat session. (#96)
1519

1620
### Bug fixes
1721

chatlas/_chat.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
import inspect
45
import os
56
import sys
@@ -1545,6 +1546,21 @@ def __repr__(self):
15451546
res += "\n" + turn.__repr__(indent=2)
15461547
return res + "\n"
15471548

1549+
def __deepcopy__(self, memo):
1550+
result = self.__class__.__new__(self.__class__)
1551+
1552+
# Avoid recursive references
1553+
memo[id(self)] = result
1554+
1555+
# Copy all attributes except the problematic provider attribute
1556+
for key, value in self.__dict__.items():
1557+
if key != "provider":
1558+
setattr(result, key, copy.deepcopy(value, memo))
1559+
else:
1560+
setattr(result, key, value)
1561+
1562+
return result
1563+
15481564

15491565
class ChatResponse:
15501566
"""

chatlas/types/openai/_submit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class SubmitInputArgs(TypedDict, total=False):
7070
"gpt-4o-search-preview-2025-03-11",
7171
"gpt-4o-mini-search-preview-2025-03-11",
7272
"chatgpt-4o-latest",
73+
"codex-mini-latest",
7374
"gpt-4o-mini",
7475
"gpt-4o-mini-2024-07-18",
7576
"gpt-4-turbo",

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ build-backend = "hatchling.build"
103103
[tool.hatch.version]
104104
source = "vcs"
105105

106+
# Need this to have github refs in dependencies
107+
# [tool.hatch.metadata]
108+
# allow-direct-references = true
109+
106110
[tool.hatch.build]
107111
skip-excluded-dirs = true
108112

tests/test_chat.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,20 @@ def test_json_serialize():
175175
# Completion objects, at least of right now, aren't included in the JSON
176176
turns[1].completion = None
177177
assert turns == turns_restored
178+
179+
180+
# Chat can be deepcopied/forked
181+
def test_deepcopy_chat():
182+
import copy
183+
184+
chat = ChatOpenAI()
185+
chat.chat("Hi", echo="none")
186+
chat_fork = copy.deepcopy(chat)
187+
188+
assert len(chat.get_turns()) == 2
189+
assert len(chat_fork.get_turns()) == 2
190+
191+
chat_fork.chat("Bye", echo="none")
192+
193+
assert len(chat.get_turns()) == 2
194+
assert len(chat_fork.get_turns()) == 4

0 commit comments

Comments
 (0)