Skip to content

Commit cf6fd2e

Browse files
authored
fix: Enhance chat response handling for v1 and v2 endpoints (#199)
* fix: Enhance chat response handling for v1 and v2 endpoints * fix: Improve handling of chat message formats for v1 and v2 endpoints * fix: Add extract_message_content utility for handling chat API responses * fix: Addressing the lint issues
1 parent da25c56 commit cf6fd2e

File tree

8 files changed

+393
-33
lines changed

8 files changed

+393
-33
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ Change Log
1313
1414
Unreleased
1515
**********
16+
4.10.8 - 2025-07-31
17+
*******************
18+
* Chat history to support XPert Chat API V2 response with a list of messages.
19+
1620
4.10.7 - 2025-07-15
1721
*******************
1822
* XPert Chat API response will be message object for V1 endpoint and a list of messages for V2 endpoint.

learning_assistant/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
Plugin for a learning assistant backend, intended for use within edx-platform.
33
"""
44

5-
__version__ = '4.10.7'
5+
__version__ = '4.10.8'
66

77
default_app_config = 'learning_assistant.apps.LearningAssistantConfig' # pylint: disable=invalid-name

learning_assistant/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,32 @@ def parse_lms_datetime(datetime_string):
208208
return None
209209

210210
return parsed_datetime
211+
212+
213+
def extract_message_content(message):
214+
"""
215+
Extract content from a message response handling both v1 and v2 endpoint formats.
216+
217+
Args:
218+
message: The message response from the chat API. Can be:
219+
- v2 format: List of message objects
220+
- v1 format: Single message dict with 'content' key
221+
- Error format: Plain string or other format
222+
223+
Returns:
224+
str: The extracted message content, or empty string for empty lists
225+
"""
226+
if v2_endpoint_enabled() and isinstance(message, list):
227+
# For v2 endpoint, message is an array - get the last message content
228+
if len(message) > 0 and isinstance(message[-1], dict):
229+
return message[-1].get('content', '')
230+
elif len(message) > 0:
231+
return str(message[-1])
232+
else:
233+
return '' # Fallback for empty list
234+
elif isinstance(message, dict) and 'content' in message:
235+
# For v1 endpoint, message is a dict with content key
236+
return message['content']
237+
else:
238+
# Fallback for other formats (e.g., error strings)
239+
return str(message)

learning_assistant/views.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from learning_assistant.serializers import MessageSerializer
3939
from learning_assistant.toggles import chat_history_enabled
4040
from learning_assistant.utils import (
41+
extract_message_content,
4142
get_audit_trial_length_days,
4243
get_chat_response,
4344
parse_lms_datetime,
@@ -115,7 +116,8 @@ def _get_next_message(self, request, courserun_key, course_run_id):
115116
status_code, message = get_chat_response(prompt_template, message_list)
116117

117118
if chat_history_enabled(courserun_key):
118-
save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, message['content'])
119+
content = extract_message_content(message)
120+
save_chat_message(courserun_key, user_id, LearningAssistantMessage.ASSISTANT_ROLE, content)
119121

120122
return Response(status=status_code, data=message)
121123

test_settings.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def root(*args):
4343

4444
SECRET_KEY = 'insecure-secret-key'
4545

46+
# Timezone settings
47+
USE_TZ = True
48+
TIME_ZONE = 'UTC' # Use UTC to avoid timezone conversion issues in tests
49+
4650
MIDDLEWARE = (
4751
'django.contrib.sessions.middleware.SessionMiddleware',
4852
'django.contrib.auth.middleware.AuthenticationMiddleware',

tests/test_api.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from django.contrib.auth import get_user_model
1313
from django.core.cache import cache
1414
from django.test import TestCase, override_settings
15+
from django.utils import timezone
1516
from freezegun import freeze_time
1617
from opaque_keys import InvalidKeyError
1718
from opaque_keys.edx.keys import CourseKey, UsageKey
@@ -568,18 +569,18 @@ def setUp(self):
568569
datetime(2024, 1, 29),
569570
)
570571
def test_exists(self, audit_trial_expiration_date):
571-
audit_trial_start_date = datetime.now()
572+
audit_trial_start_date = timezone.now()
572573

573574
LearningAssistantAuditTrial.objects.create(
574575
user=self.user,
575576
start_date=audit_trial_start_date,
576-
expiration_date=audit_trial_expiration_date,
577+
expiration_date=timezone.make_aware(audit_trial_expiration_date)
577578
)
578579

579580
expected_return = LearningAssistantAuditTrialData(
580581
user_id=self.user.id,
581582
start_date=audit_trial_start_date,
582-
expiration_date=audit_trial_expiration_date,
583+
expiration_date=timezone.make_aware(audit_trial_expiration_date),
583584
)
584585
self.assertEqual(expected_return, get_audit_trial(self.user))
585586

@@ -608,19 +609,19 @@ def setUp(self):
608609
)
609610
@patch('learning_assistant.api.get_audit_trial_expiration_date_from_start_date')
610611
def test_exists_get(self, audit_trial_expiration_date, get_audit_trial_expiration_date_mock):
611-
audit_trial_start_date = datetime.now()
612-
get_audit_trial_expiration_date_mock.return_value = audit_trial_expiration_date
612+
audit_trial_start_date = timezone.now()
613+
get_audit_trial_expiration_date_mock.return_value = timezone.make_aware(audit_trial_expiration_date)
613614

614615
LearningAssistantAuditTrial.objects.create(
615616
user=self.user,
616617
start_date=audit_trial_start_date,
617-
expiration_date=audit_trial_expiration_date,
618+
expiration_date=timezone.make_aware(audit_trial_expiration_date),
618619
)
619620

620621
expected_return = LearningAssistantAuditTrialData(
621622
user_id=self.user.id,
622623
start_date=audit_trial_start_date,
623-
expiration_date=audit_trial_expiration_date,
624+
expiration_date=timezone.make_aware(audit_trial_expiration_date),
624625
)
625626
self.assertEqual(expected_return, get_or_create_audit_trial(self.user, 'verified'))
626627

@@ -634,13 +635,13 @@ def test_not_exists_create(self, audit_trial_expiration_date, get_audit_trial_ex
634635
other_user = User(username='other-tester', email='other-tester@test.com')
635636
other_user.save()
636637

637-
start_date = datetime.now()
638-
get_audit_trial_expiration_date_mock.return_value = audit_trial_expiration_date
638+
start_date = timezone.now()
639+
get_audit_trial_expiration_date_mock.return_value = timezone.make_aware(audit_trial_expiration_date)
639640

640641
expected_return = LearningAssistantAuditTrialData(
641642
user_id=self.user.id,
642643
start_date=start_date,
643-
expiration_date=audit_trial_expiration_date
644+
expiration_date=timezone.make_aware(audit_trial_expiration_date)
644645
)
645646

646647
self.assertEqual(expected_return, get_or_create_audit_trial(self.user, 'verified'))
@@ -660,7 +661,7 @@ def setUp(self):
660661
self.user.save()
661662

662663
def test_upgrade_deadline_expired(self):
663-
today = datetime.now()
664+
today = timezone.now()
664665
mock_enrollment = MagicMock()
665666
mock_enrollment.upgrade_deadline = today - timedelta(days=1) # yesterday
666667

@@ -674,7 +675,7 @@ def test_upgrade_deadline_expired(self):
674675
self.assertEqual(audit_trial_is_expired(mock_enrollment, audit_trial_data), True)
675676

676677
def test_upgrade_deadline_none(self):
677-
today = datetime.now()
678+
today = timezone.now()
678679
mock_enrollment = MagicMock()
679680
mock_enrollment.upgrade_deadline = None
680681

@@ -705,10 +706,12 @@ def test_upgrade_deadline_none(self):
705706
datetime(year=2024, month=1, day=1) - timedelta(days=settings.LEARNING_ASSISTANT_AUDIT_TRIAL_LENGTH_DAYS + 1),
706707
)
707708
def test_audit_trial_expired(self, start_date):
708-
today = datetime.now()
709+
today = timezone.now()
709710
mock_enrollment = MagicMock()
710711
mock_enrollment.upgrade_deadline = today + timedelta(days=1) # tomorrow
711712

713+
# Convert naive start_date to timezone-aware
714+
start_date = timezone.make_aware(start_date)
712715
audit_trial_data = LearningAssistantAuditTrialData(
713716
user_id=self.user.id,
714717
start_date=start_date,
@@ -718,7 +721,7 @@ def test_audit_trial_expired(self, start_date):
718721
self.assertEqual(audit_trial_is_expired(mock_enrollment, audit_trial_data), True)
719722

720723
def test_audit_trial_unexpired(self):
721-
today = datetime.now()
724+
today = timezone.now()
722725
mock_enrollment = MagicMock()
723726
mock_enrollment.upgrade_deadline = today + timedelta(days=1) # tomorrow
724727

tests/test_utils.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from learning_assistant.constants import LMS_DATETIME_FORMAT
1515
from learning_assistant.utils import (
16+
extract_message_content,
1617
get_audit_trial_length_days,
1718
get_chat_response,
1819
get_optimizely_variation,
@@ -148,6 +149,31 @@ def test_post_request_structure_v2_endpoint(self, mock_requests, mock_v2_enabled
148149
timeout=(connect_timeout, read_timeout)
149150
)
150151

152+
@ddt.data(
153+
# (v2_enabled, response_data, description)
154+
(False, {'role': 'assistant', 'content': 'v1 response'}, 'v1 single dict'),
155+
(True, [{'role': 'assistant', 'content': 'v2 response'}], 'v2 array of dicts'),
156+
(True, {'role': 'assistant', 'content': 'v2 dict'}, 'v2 unexpected dict format'),
157+
)
158+
@ddt.unpack
159+
@responses.activate
160+
@patch('learning_assistant.utils.v2_endpoint_enabled')
161+
def test_endpoint_response_formats(self, v2_enabled, response_data, description, mock_v2_enabled):
162+
"""Test that both v1 and v2 endpoint response formats are handled correctly."""
163+
mock_v2_enabled.return_value = v2_enabled
164+
165+
endpoint = settings.CHAT_COMPLETION_API_V2 if v2_enabled else settings.CHAT_COMPLETION_API
166+
responses.add(
167+
responses.POST,
168+
endpoint,
169+
status=200,
170+
body=json.dumps(response_data),
171+
)
172+
173+
status_code, message = self.get_response()
174+
self.assertEqual(status_code, 200, f"Failed for {description}")
175+
self.assertEqual(message, response_data, f"Response mismatch for {description}")
176+
151177

152178
class GetReducedMessageListTests(TestCase):
153179
"""
@@ -225,7 +251,9 @@ def test_get_audit_trial_length_days_no_setting(self):
225251
)
226252
@ddt.unpack
227253
@patch('learning_assistant.utils.get_optimizely_variation')
228-
def test_get_audit_trial_length_days_experiment(self, variation_key, expected_value, mock_get_optimizely_variation):
254+
def test_get_audit_trial_length_days_experiment(
255+
self, variation_key, expected_value, mock_get_optimizely_variation
256+
):
229257
mock_get_optimizely_variation.return_value = {'enabled': True, 'variation_key': variation_key}
230258
with patch.object(settings, 'OPTIMIZELY_LEARNING_ASSISTANT_TRIAL_VARIATION_KEY_28', 'variation'):
231259
self.assertEqual(get_audit_trial_length_days(1, 'verified'), expected_value)
@@ -273,3 +301,107 @@ def test_wrong_date(self):
273301
response = parse_lms_datetime('when I get my homework done')
274302

275303
self.assertEqual(response, expected_value)
304+
305+
306+
@ddt.ddt
307+
class ExtractMessageContentTests(TestCase):
308+
"""
309+
Tests for the extract_message_content utility function
310+
"""
311+
312+
@patch('learning_assistant.utils.v2_endpoint_enabled')
313+
def test_v2_endpoint_with_list_dict_message(self, mock_v2_enabled):
314+
"""Test v2 endpoint with list containing dict messages"""
315+
mock_v2_enabled.return_value = True
316+
317+
message = [
318+
{'role': 'assistant', 'content': 'First message'},
319+
{'role': 'assistant', 'content': 'Last message'}
320+
]
321+
322+
result = extract_message_content(message)
323+
self.assertEqual(result, 'Last message')
324+
325+
@patch('learning_assistant.utils.v2_endpoint_enabled')
326+
def test_v2_endpoint_with_list_non_dict_message(self, mock_v2_enabled):
327+
"""Test v2 endpoint with list containing non-dict messages"""
328+
mock_v2_enabled.return_value = True
329+
330+
message = ['First message', 'Last message']
331+
332+
result = extract_message_content(message)
333+
self.assertEqual(result, 'Last message')
334+
335+
@patch('learning_assistant.utils.v2_endpoint_enabled')
336+
def test_v2_endpoint_with_empty_list(self, mock_v2_enabled):
337+
"""Test v2 endpoint with empty list"""
338+
mock_v2_enabled.return_value = True
339+
340+
message = []
341+
342+
result = extract_message_content(message)
343+
self.assertEqual(result, '')
344+
345+
@patch('learning_assistant.utils.v2_endpoint_enabled')
346+
def test_v2_endpoint_with_dict_missing_content(self, mock_v2_enabled):
347+
"""Test v2 endpoint with dict message missing content key"""
348+
mock_v2_enabled.return_value = True
349+
350+
message = [{'role': 'assistant'}]
351+
352+
result = extract_message_content(message)
353+
self.assertEqual(result, '')
354+
355+
@patch('learning_assistant.utils.v2_endpoint_enabled')
356+
def test_v1_endpoint_with_dict_message(self, mock_v2_enabled):
357+
"""Test v1 endpoint with dict message containing content"""
358+
mock_v2_enabled.return_value = False
359+
360+
message = {'role': 'assistant', 'content': 'v1 response'}
361+
362+
result = extract_message_content(message)
363+
self.assertEqual(result, 'v1 response')
364+
365+
@patch('learning_assistant.utils.v2_endpoint_enabled')
366+
def test_v1_endpoint_with_dict_missing_content(self, mock_v2_enabled):
367+
"""Test v1 endpoint with dict message missing content key"""
368+
mock_v2_enabled.return_value = False
369+
370+
message = {'role': 'assistant'}
371+
372+
result = extract_message_content(message)
373+
self.assertEqual(result, "{'role': 'assistant'}")
374+
375+
@patch('learning_assistant.utils.v2_endpoint_enabled')
376+
def test_fallback_with_string_message(self, mock_v2_enabled):
377+
"""Test fallback case with string message"""
378+
mock_v2_enabled.return_value = False
379+
380+
message = 'Error: Something went wrong'
381+
382+
result = extract_message_content(message)
383+
self.assertEqual(result, 'Error: Something went wrong')
384+
385+
@patch('learning_assistant.utils.v2_endpoint_enabled')
386+
def test_fallback_with_none_message(self, mock_v2_enabled):
387+
"""Test fallback case with None message"""
388+
mock_v2_enabled.return_value = False
389+
390+
message = None
391+
392+
result = extract_message_content(message)
393+
self.assertEqual(result, 'None')
394+
395+
@patch('learning_assistant.utils.v2_endpoint_enabled')
396+
def test_v2_endpoint_mixed_message_types(self, mock_v2_enabled):
397+
"""Test v2 endpoint with mixed message types in list"""
398+
mock_v2_enabled.return_value = True
399+
400+
message = [
401+
'First string message',
402+
{'role': 'assistant', 'content': 'Dict message'},
403+
'Last string message'
404+
]
405+
406+
result = extract_message_content(message)
407+
self.assertEqual(result, 'Last string message')

0 commit comments

Comments
 (0)