Skip to content

Commit 215bef1

Browse files
B2B Provisioning: Add user/contract FKs, user provisioning on purchase (#2645)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 936d598 commit 215bef1

File tree

18 files changed

+418
-45
lines changed

18 files changed

+418
-45
lines changed

b2b/admin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""B2B model admin. Only for convenience; you should use the Wagtail interface instead."""
2+
3+
from django.contrib import admin
4+
5+
from b2b.models import ContractPage, OrganizationPage
6+
7+
admin.site.register(OrganizationPage)
8+
admin.site.register(ContractPage)

b2b/api.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from b2b.models import ContractPage, OrganizationIndexPage, OrganizationPage
1313
from cms.api import get_home_page
1414
from courses.models import Course, CourseRun
15+
from ecommerce.api import establish_basket
1516
from ecommerce.models import Product
1617
from main.utils import date_to_datetime
1718

@@ -133,3 +134,79 @@ def create_contract_run(
133134
)
134135

135136
return course_run, course_run_product
137+
138+
139+
def validate_basket_for_b2b_purchase(request) -> bool:
140+
"""
141+
Validate the basket for a B2B purchase.
142+
143+
This function checks if the basket is valid for a B2B purchase. It ensures
144+
that the basket contains only products that are part of a contract and that
145+
the contract is active.
146+
147+
When SSO integrations are implemented, this will need to be revised since
148+
it'll fail those baskets (unless we create orders for those completely
149+
differently).
150+
151+
Args:
152+
request: The HTTP request object containing the basket data.
153+
154+
Returns: bool
155+
"""
156+
157+
basket = establish_basket(request)
158+
if not basket:
159+
return False
160+
161+
basket_contracts = []
162+
course_run_content_type = ContentType.objects.get_for_model(CourseRun)
163+
164+
# This system only supports one item per basket, but this is done so it can
165+
# be lifted out and into UE later (which supports >1 item).
166+
167+
for item in (
168+
basket.basket_items.filter(product__content_type=course_run_content_type)
169+
.prefetch_related(
170+
"product",
171+
"product__purchasable_object",
172+
"product__purchasable_object__b2b_contract",
173+
)
174+
.all()
175+
):
176+
contract = item.product.purchasable_object.b2b_contract
177+
178+
if contract and contract.is_active:
179+
basket_contracts.append(contract)
180+
181+
if len(basket_contracts) == 0:
182+
# No contracts in the basket, so we don't need to check further.
183+
# The other validity checks that run before will make sure the discount
184+
# applies to the basket products.
185+
return True
186+
187+
discounts_with_contracts = (
188+
basket.discounts.filter(
189+
redeemed_discount__products__product__content_type=course_run_content_type
190+
)
191+
.prefetch_related(
192+
"redeemed_discount",
193+
"redeemed_discount__products",
194+
"redeemed_discount__products__product__purchasable_object",
195+
"redeemed_discount__products__product__purchasable_object__b2b_contract",
196+
)
197+
.distinct()
198+
.all()
199+
)
200+
201+
if len(discounts_with_contracts) != len(basket_contracts):
202+
# We should have a code for each contract in the basket.
203+
return False
204+
205+
for discount_item in discounts_with_contracts:
206+
for discount_product in discount_item.redeemed_discount.products.all():
207+
contract = discount_product.product.purchasable_object.b2b_contract
208+
209+
if contract and contract.is_active and contract in basket_contracts:
210+
return True
211+
212+
return False

b2b/api_test.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
from mitol.common.utils import now_in_utc
88

99
from b2b import factories
10-
from b2b.api import create_contract_run
10+
from b2b.api import create_contract_run, validate_basket_for_b2b_purchase
1111
from b2b.constants import B2B_RUN_TAG_FORMAT
12+
from b2b.factories import ContractPageFactory
1213
from courses.factories import CourseFactory
14+
from ecommerce.api_test import create_basket
15+
from ecommerce.factories import ProductFactory, UnlimitedUseDiscountFactory
16+
from ecommerce.models import BasketDiscount, DiscountProduct
1317
from main.utils import date_to_datetime
1418

1519
FAKE = faker.Factory.create()
@@ -73,3 +77,71 @@ def test_create_single_course_run(mocker, has_start, has_end):
7377
assert run.enrollment_end is None
7478

7579
assert product.purchasable_object == run
80+
81+
82+
@pytest.mark.parametrize(
83+
(
84+
"run_contract",
85+
"apply_code",
86+
),
87+
[
88+
(False, False),
89+
(False, True),
90+
(True, False),
91+
(True, True),
92+
],
93+
)
94+
def test_b2b_basket_validation(user, run_contract, apply_code):
95+
"""
96+
Test that a basket is validated correctly for B2B contracts.
97+
98+
Basically, if the user is adding a product that links to a course run that
99+
is also linked to a contract, we need to have also applied the discount code
100+
that matches the product, or we shouldn't be allowed to buy it.
101+
102+
The truth table for this should be:
103+
104+
| run_contract | apply_code | result |
105+
|--------------|------------|--------|
106+
| False | False | True |
107+
| False | True | True |
108+
| True | False | False |
109+
| True | True | True |
110+
"""
111+
112+
product = ProductFactory.create()
113+
discount = UnlimitedUseDiscountFactory.create()
114+
discount_product = DiscountProduct.objects.create(
115+
discount=discount, product=product
116+
)
117+
discount_product.save()
118+
discount.products.add(discount_product)
119+
120+
if run_contract:
121+
contract = ContractPageFactory.create()
122+
123+
product.purchasable_object.b2b_contract = contract
124+
product.purchasable_object.save()
125+
product.refresh_from_db()
126+
127+
basket = create_basket(user, [product])
128+
129+
if apply_code:
130+
redemption = BasketDiscount(
131+
redemption_date=now_in_utc(),
132+
redeemed_by=user,
133+
redeemed_discount=discount,
134+
redeemed_basket=basket,
135+
)
136+
137+
redemption.save()
138+
basket.refresh_from_db()
139+
140+
check_result = validate_basket_for_b2b_purchase(basket)
141+
142+
if run_contract and not apply_code:
143+
# User is trying to buy something that's linked to a contract but hasn't
144+
# applied the code, so this should be false.
145+
assert check_result is False
146+
else:
147+
assert check_result is True

b2b/serializers/v0/__init__.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,55 +5,61 @@
55
from b2b.models import ContractPage, OrganizationPage
66

77

8-
class OrganizationPageSerializer(serializers.ModelSerializer):
8+
class ContractPageSerializer(serializers.ModelSerializer):
99
"""
10-
Serializer for the OrganizationPage model.
10+
Serializer for the ContractPage model.
1111
"""
1212

1313
class Meta:
14-
model = OrganizationPage
14+
model = ContractPage
1515
fields = [
1616
"id",
1717
"name",
1818
"description",
19-
"logo",
19+
"integration_type",
20+
"organization",
21+
"contract_start",
22+
"contract_end",
23+
"active",
2024
"slug",
25+
"organization",
2126
]
2227
read_only_fields = [
2328
"id",
2429
"name",
2530
"description",
26-
"logo",
31+
"integration_type",
32+
"organization",
33+
"contract_start",
34+
"contract_end",
35+
"active",
2736
"slug",
37+
"organization",
2838
]
2939

3040

31-
class ContractPageSerializer(serializers.ModelSerializer):
41+
class OrganizationPageSerializer(serializers.ModelSerializer):
3242
"""
33-
Serializer for the ContractPage model.
43+
Serializer for the OrganizationPage model.
3444
"""
3545

46+
contracts = ContractPageSerializer(many=True, read_only=True)
47+
3648
class Meta:
37-
model = ContractPage
49+
model = OrganizationPage
3850
fields = [
3951
"id",
4052
"name",
4153
"description",
42-
"integration_type",
43-
"organization",
44-
"contract_start",
45-
"contract_end",
46-
"active",
54+
"logo",
4755
"slug",
56+
"contracts",
4857
]
4958
read_only_fields = [
5059
"id",
5160
"name",
5261
"description",
53-
"integration_type",
54-
"organization",
55-
"contract_start",
56-
"contract_end",
57-
"active",
62+
"logo",
5863
"slug",
64+
"contracts",
5965
]

courses/api.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def _enroll_learner_into_associated_programs():
196196

197197
_enroll_learner_into_associated_programs()
198198

199+
# If the run is associated with a B2B contract, add the contract
200+
# to the user's contract list
201+
if run.b2b_contract:
202+
user.b2b_contracts.add(run.b2b_contract)
203+
user.save()
204+
199205
if not created:
200206
enrollment_mode_changed = mode != enrollment.enrollment_mode
201207
enrollment.edx_enrolled = edx_request_success

courses/management/commands/create_courseware.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,13 @@ def _create_departments(self, departments: List[str]) -> models.QuerySet: # noq
7777
models.QuerySet: Query set containing all of the departments specified
7878
in the list of department names.
7979
"""
80-
add_depts = Department.objects.filter(name__in=departments.split()).all()
81-
for dept in departments.split():
80+
add_depts = Department.objects.filter(name__in=departments).all()
81+
for dept in departments:
8282
found = len([db_dept for db_dept in add_depts if db_dept.name == dept]) > 0
8383
if not found:
8484
Department.objects.create(name=dept)
8585

86-
return Department.objects.filter(name__in=departments.split()).all()
86+
return Department.objects.filter(name__in=departments).all()
8787

8888
def _department_must_be_defined_error(self):
8989
"""

ecommerce/api.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import logging
44
import uuid
55
from decimal import Decimal
6+
from urllib.parse import urljoin
67

8+
from django.conf import settings
79
from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError
810
from django.db import transaction
911
from django.db.models import Q
@@ -57,7 +59,11 @@
5759
log = logging.getLogger(__name__)
5860

5961

60-
def generate_checkout_payload(request):
62+
def generate_checkout_payload(request): # noqa: PLR0911
63+
"""Generate the checkout payload for the current basket."""
64+
65+
from b2b.api import validate_basket_for_b2b_purchase
66+
6167
basket = establish_basket(request)
6268

6369
if basket.has_user_blocked_products(request.user):
@@ -99,6 +105,15 @@ def generate_checkout_payload(request):
99105
),
100106
}
101107

108+
if not validate_basket_for_b2b_purchase(request):
109+
return {
110+
"invalid_discounts": True,
111+
"response": redirect_with_user_message(
112+
reverse("cart"),
113+
{"type": USER_MSG_TYPE_DISCOUNT_INVALID},
114+
),
115+
}
116+
102117
order = PendingOrder.create_from_basket(basket)
103118
total_price = 0
104119

@@ -141,7 +156,7 @@ def generate_checkout_payload(request):
141156
),
142157
}
143158

144-
callback_uri = request.build_absolute_uri(reverse("checkout-result-callback"))
159+
callback_uri = urljoin(settings.SITE_BASE_URL, reverse("checkout-result-callback"))
145160
payload = PaymentGateway.start_payment(
146161
ECOMMERCE_DEFAULT_PAYMENT_GATEWAY,
147162
gateway_order,

0 commit comments

Comments
 (0)