Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions backend/lcfs/tests/credit_ledger/test_credit_ledger_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async def test_get_rows_default_sort(
):
fake_row = MagicMock()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = [fake_row]
execute_result.all.return_value = [fake_row]

mock_session.execute.return_value = execute_result
mock_session.scalar.return_value = 1
Expand All @@ -51,7 +51,7 @@ async def test_get_rows_with_sort_and_paging(
):
fake_rows = [MagicMock(), MagicMock()]
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = fake_rows
execute_result.all.return_value = fake_rows

mock_session.execute.return_value = execute_result
mock_session.scalar.return_value = 2
Expand All @@ -73,7 +73,9 @@ async def test_get_rows_with_sort_and_paging(


@pytest.mark.anyio
async def test_get_distinct_years(repo: CreditLedgerRepository, mock_session: MagicMock):
async def test_get_distinct_years(
repo: CreditLedgerRepository, mock_session: MagicMock
):
"""Test getting distinct years for an organization."""
fake_years = ["2024", "2023", "2022"]
execute_result = MagicMock()
Expand All @@ -89,11 +91,13 @@ async def test_get_distinct_years(repo: CreditLedgerRepository, mock_session: Ma


@pytest.mark.anyio
async def test_get_distinct_years_filters_nulls(repo: CreditLedgerRepository, mock_session: MagicMock):
async def test_get_distinct_years_filters_nulls(
repo: CreditLedgerRepository, mock_session: MagicMock
):
"""Test that get_distinct_years filters out null years."""
fake_years_with_nulls = ["2024", None, "2023", "", "2022"]
expected_years = ["2024", "2023", "2022"]

execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = fake_years_with_nulls

Expand Down
65 changes: 43 additions & 22 deletions backend/lcfs/tests/credit_ledger/test_credit_ledger_services.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from math import ceil
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -29,17 +30,15 @@ async def test_get_ledger_paginated_success(credit_ledger_service, mock_repo):
page=2, size=5, filters=[], sort_orders=[]
)

mock_rows = [
SimpleNamespace(
transaction_type="Credit",
compliance_period="2023",
organization_id=1,
compliance_units=10,
available_balance=10,
update_date="2024-01-01",
)
for _ in range(3)
]
ledger_view = SimpleNamespace(
transaction_type="ComplianceReport",
compliance_period="2023",
organization_id=1,
compliance_units=10,
available_balance=10,
update_date="2024-01-01",
)
mock_rows = [(ledger_view, 2)]
mock_repo.get_rows_paginated.return_value = (mock_rows, 12)

data = await credit_ledger_service.get_ledger_paginated(
Expand All @@ -48,17 +47,28 @@ async def test_get_ledger_paginated_success(credit_ledger_service, mock_repo):

assert data.pagination.total == 12
assert data.pagination.total_pages == ceil(12 / 5)
assert len(data.ledger) == 3
assert len(data.ledger) == 1
assert isinstance(data.ledger[0], CreditLedgerTxnSchema)
assert data.ledger[0].description == "Supplemental 2"


@pytest.mark.anyio
async def test_export_transactions_generates_stream(credit_ledger_service, mock_repo):
with patch(
"lcfs.web.api.credit_ledger.services.SpreadsheetBuilder.build_spreadsheet",
return_value=b"dummy-bytes",
):
mock_repo.get_rows_paginated.return_value = ([], 0)
), patch(
"lcfs.web.api.credit_ledger.services.SpreadsheetBuilder.add_sheet"
) as mock_add_sheet:
ledger_view = SimpleNamespace(
transaction_type="ComplianceReport",
compliance_period="2023",
organization_id=1,
compliance_units=10,
available_balance=10,
update_date=datetime(2024, 1, 1),
)
mock_repo.get_rows_paginated.return_value = ([(ledger_view, 1)], 1)

resp = await credit_ledger_service.export_transactions(
organization_id=1, compliance_year=None, export_format="csv"
Expand All @@ -67,28 +77,39 @@ async def test_export_transactions_generates_stream(credit_ledger_service, mock_
assert isinstance(resp, StreamingResponse)
assert resp.media_type == "text/csv"
assert resp.headers["Content-Disposition"].startswith("attachment;")
assert mock_add_sheet.called
_, kwargs = mock_add_sheet.call_args
assert kwargs["rows"][0][3] == "Compliance Report – Supplemental 1"


@pytest.mark.anyio
async def test_get_organization_years_success(credit_ledger_service, mock_repo):
"""Test getting organization years returns years from repo."""
expected_years = ["2024", "2023", "2022"]
mock_repo.get_distinct_years.return_value = expected_years

organization_id = 123
years = await credit_ledger_service.get_organization_years(organization_id=organization_id)

years = await credit_ledger_service.get_organization_years(
organization_id=organization_id
)

assert years == expected_years
mock_repo.get_distinct_years.assert_called_once_with(organization_id=organization_id)
mock_repo.get_distinct_years.assert_called_once_with(
organization_id=organization_id
)


@pytest.mark.anyio
async def test_get_organization_years_empty_list(credit_ledger_service, mock_repo):
"""Test getting organization years returns empty list when no data."""
mock_repo.get_distinct_years.return_value = []

organization_id = 456
years = await credit_ledger_service.get_organization_years(organization_id=organization_id)

years = await credit_ledger_service.get_organization_years(
organization_id=organization_id
)

assert years == []
mock_repo.get_distinct_years.assert_called_once_with(organization_id=organization_id)
mock_repo.get_distinct_years.assert_called_once_with(
organization_id=organization_id
)
41 changes: 28 additions & 13 deletions backend/lcfs/web/api/credit_ledger/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from typing import Optional, List

from fastapi import Depends
from sqlalchemy import func, select, and_, desc, asc, distinct
from sqlalchemy import func, select, and_, desc, distinct
from sqlalchemy.ext.asyncio import AsyncSession

from lcfs.db.dependencies import get_async_db_session
from lcfs.web.core.decorators import repo_handler
from lcfs.db.models.transaction.CreditLedgerView import CreditLedgerView
from lcfs.db.models.compliance.ComplianceReport import ComplianceReport

log = structlog.get_logger(__name__)

Expand All @@ -28,24 +29,38 @@ async def get_rows_paginated(
limit: Optional[int],
conditions: List[any],
sort_orders: List[any],
) -> tuple[List[CreditLedgerView], int]:
# Base query
stmt = select(CreditLedgerView).where(and_(*conditions))
) -> tuple[List[tuple], int]:
# Base query - join with compliance_report to get version for ComplianceReport transactions
stmt = (
select(
CreditLedgerView,
ComplianceReport.version.label("compliance_report_version"),
)
.outerjoin(
ComplianceReport,
and_(
CreditLedgerView.transaction_id
== ComplianceReport.compliance_report_id,
CreditLedgerView.transaction_type == "ComplianceReport",
),
)
.where(and_(*conditions))
)

# Sort and order
for order in sort_orders:
direction = asc if order.direction == "asc" else desc
stmt = stmt.order_by(direction(getattr(CreditLedgerView, order.field)))
if not sort_orders:
stmt = stmt.order_by(CreditLedgerView.update_date.desc())
# Always sort by update_date DESC - sorting is not allowed on credit ledger
stmt = stmt.order_by(CreditLedgerView.update_date.desc())

# Count before pagination
total = await self.db.scalar(select(func.count()).select_from(stmt.subquery()))
count_stmt = select(func.count()).select_from(
select(CreditLedgerView).where(and_(*conditions)).subquery()
)
total = await self.db.scalar(count_stmt)

# Pagination
stmt = stmt.offset(offset).limit(limit)

rows = (await self.db.execute(stmt)).scalars().all()
result = await self.db.execute(stmt)
rows = result.all()
return rows, total or 0

@repo_handler
Expand All @@ -64,7 +79,7 @@ async def get_distinct_years(
.where(CreditLedgerView.compliance_period.isnot(None))
.order_by(desc(CreditLedgerView.compliance_period))
)

result = await self.db.execute(stmt)
years = result.scalars().all()
return [str(year) for year in years if year]
5 changes: 3 additions & 2 deletions backend/lcfs/web/api/credit_ledger/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@

class CreditLedgerTxnSchema(BaseSchema):
transaction_type: str
description: Optional[str] = None
compliance_period: str
organization_id: int
compliance_units: int
available_balance: Optional[int]
update_date: datetime

model_config = ConfigDict(from_attributes=True)
@field_validator('available_balance')

@field_validator("available_balance")
@classmethod
def validate_available_balance(cls, v: Optional[int]) -> int:
"""Ensure available balance is never negative - display 0 instead"""
Expand Down
61 changes: 49 additions & 12 deletions backend/lcfs/web/api/credit_ledger/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ class CreditLedgerService:
def __init__(self, repo: CreditLedgerRepository = Depends()) -> None:
self.repo = repo

def _apply_filters(self, pagination: PaginationRequestSchema, conditions: List[any]) -> None:
def _apply_filters(
self, pagination: PaginationRequestSchema, conditions: List[any]
) -> None:
for f in pagination.filters:
field = get_field_for_filter(CreditLedgerView, f.field)
filter_val = f.filter
Expand Down Expand Up @@ -62,8 +64,26 @@ async def get_ledger_paginated(
sort_orders=pagination.sort_orders,
)

# Transform rows with compliance report version (e.g., "Original", "Supplemental 1")
ledger_items = []
for row in rows:
ledger_view, version = row
# Create schema from the ledger view
item = CreditLedgerTxnSchema.model_validate(ledger_view)

# Add formatted description for compliance reports
if (
ledger_view.transaction_type == "ComplianceReport"
and version is not None
):
item.description = (
"Original" if version == 0 else f"Supplemental {version}"
)

ledger_items.append(item)

return CreditLedgerListSchema(
ledger=[CreditLedgerTxnSchema.model_validate(r) for r in rows],
ledger=ledger_items,
pagination=PaginationResponseSchema(
total=total,
page=pagination.page,
Expand Down Expand Up @@ -113,16 +133,33 @@ async def export_transactions(
sort_orders=sort_orders,
)

sheet_rows = [
[
int(r.compliance_period),
int(r.available_balance or 0),
int(r.compliance_units or 0),
r.transaction_type,
r.update_date.strftime("%Y-%m-%d"),
]
for r in rows
]
sheet_rows = []
for row in rows:
ledger_view, version = row

# Format transaction type with version for compliance reports
transaction_type = ledger_view.transaction_type
if transaction_type == "ComplianceReport" and version is not None:
# Format as "Original", "Supplemental 1", etc.
description = "Original" if version == 0 else f"Supplemental {version}"
transaction_type = f"Compliance Report – {description}"
elif transaction_type == "StandaloneTransaction":
transaction_type = "Legacy Transaction"
else:
# Add spaces to camelCase
transaction_type = "".join(
[" " + c if c.isupper() else c for c in transaction_type]
).strip()

sheet_rows.append(
[
int(ledger_view.compliance_period),
int(ledger_view.available_balance or 0),
int(ledger_view.compliance_units or 0),
transaction_type,
ledger_view.update_date.strftime("%Y-%m-%d"),
]
)

builder = SpreadsheetBuilder(file_format=export_format)
builder.add_sheet(
Expand Down
Loading