diff --git a/enterprise/filters/__init__.py b/enterprise/filters/__init__.py new file mode 100644 index 000000000..87ba36052 --- /dev/null +++ b/enterprise/filters/__init__.py @@ -0,0 +1,3 @@ +""" +Filter pipeline step implementations for edx-enterprise openedx-filters integrations. +""" diff --git a/enterprise/filters/course_modes.py b/enterprise/filters/course_modes.py new file mode 100644 index 000000000..d242b5797 --- /dev/null +++ b/enterprise/filters/course_modes.py @@ -0,0 +1,37 @@ +""" +Pipeline steps for the course mode checkout filter. +""" +import logging + +from openedx_filters.filters import PipelineStep + +log = logging.getLogger(__name__) + + +class CheckoutEnterpriseContextInjector(PipelineStep): + """ + Inject enterprise customer data into the course mode checkout context. + + If the current request is associated with an enterprise customer, this step adds the + enterprise customer dict to the checkout context under the key 'enterprise_customer'. + This allows downstream checkout logic to apply enterprise-specific pricing. + """ + + def run_filter(self, context, request, course_mode): # pylint: disable=arguments-differ + """ + Inject enterprise customer into the checkout context. + """ + # Deferred import — will be replaced with internal path in epic 17. + from openedx.features.enterprise_support.api import \ + enterprise_customer_for_request # pylint: disable=import-outside-toplevel + + try: + enterprise_customer = enterprise_customer_for_request(request) + except Exception: # pylint: disable=broad-except + log.warning('Failed to retrieve enterprise customer for checkout context.', exc_info=True) + enterprise_customer = None + + if enterprise_customer: + context['enterprise_customer'] = enterprise_customer + + return {'context': context, 'request': request, 'course_mode': course_mode} diff --git a/tests/filters/__init__.py b/tests/filters/__init__.py new file mode 100644 index 000000000..5a5f3f9ac --- /dev/null +++ b/tests/filters/__init__.py @@ -0,0 +1 @@ +"""Tests for enterprise filter pipeline steps.""" diff --git a/tests/filters/test_course_modes.py b/tests/filters/test_course_modes.py new file mode 100644 index 000000000..8551d7b96 --- /dev/null +++ b/tests/filters/test_course_modes.py @@ -0,0 +1,100 @@ +""" +Tests for enterprise.filters.course_modes pipeline step. +""" +import sys +import unittest +from unittest.mock import MagicMock, patch + + +class TestCheckoutEnterpriseContextInjector(unittest.TestCase): + """ + Tests for CheckoutEnterpriseContextInjector pipeline step. + """ + + def _make_step(self): + from enterprise.filters.course_modes import CheckoutEnterpriseContextInjector + return CheckoutEnterpriseContextInjector( + "org.openedx.learning.course_mode.checkout.started.v1", + [], + ) + + def _run_with_patched_imports(self, enterprise_customer, context=None, request=None, course_mode=None): + """ + Helper: run step.run_filter with deferred imports patched via sys.modules. + """ + if context is None: + context = {} + if request is None: + request = MagicMock() + if course_mode is None: + course_mode = MagicMock() + + step = self._make_step() + + api_mod = MagicMock() + api_mod.enterprise_customer_for_request = MagicMock(return_value=enterprise_customer) + + with patch.dict(sys.modules, { + 'openedx': MagicMock(), + 'openedx.features': MagicMock(), + 'openedx.features.enterprise_support': MagicMock(), + 'openedx.features.enterprise_support.api': api_mod, + }): + return step.run_filter(context=context, request=request, course_mode=course_mode) + + def test_injects_enterprise_customer_when_found(self): + """ + When an enterprise customer is found for the request, it is injected into the context. + """ + enterprise_customer = {"uuid": "test-uuid", "name": "Test Enterprise"} + context = {"course_id": "course-v1:org+course+run"} + request = MagicMock() + course_mode = MagicMock() + + result = self._run_with_patched_imports( + enterprise_customer=enterprise_customer, + context=context, + request=request, + course_mode=course_mode, + ) + + self.assertEqual(result["context"]["enterprise_customer"], enterprise_customer) + self.assertIs(result["request"], request) + self.assertIs(result["course_mode"], course_mode) + + def test_does_not_inject_when_no_enterprise_customer(self): + """ + When no enterprise customer is found, the context is returned unchanged. + """ + context = {"course_id": "course-v1:org+course+run"} + + result = self._run_with_patched_imports( + enterprise_customer=None, + context=context, + ) + + self.assertNotIn("enterprise_customer", result["context"]) + self.assertEqual(result["context"], context) + + def test_handles_exception_gracefully(self): + """ + When enterprise_customer_for_request raises an exception, the context is returned unchanged. + """ + step = self._make_step() + context = {"course_id": "course-v1:org+course+run"} + request = MagicMock() + course_mode = MagicMock() + + api_mod = MagicMock() + api_mod.enterprise_customer_for_request = MagicMock(side_effect=Exception("Connection error")) + + with patch.dict(sys.modules, { + 'openedx': MagicMock(), + 'openedx.features': MagicMock(), + 'openedx.features.enterprise_support': MagicMock(), + 'openedx.features.enterprise_support.api': api_mod, + }): + result = step.run_filter(context=context, request=request, course_mode=course_mode) + + self.assertNotIn("enterprise_customer", result["context"]) + self.assertEqual(result["context"], context)