Skip to content

Commit

Permalink
Apply suggestions from Ruff (#556)
Browse files Browse the repository at this point in the history
* Add ruff to dev dependencies

* Use safer defusedxml

* Fix lints in parsers

* Fix all other lints

* Use latest version of poetry on CI

* Use black format for now
  • Loading branch information
KapJI authored Sep 8, 2024
1 parent c2a03ea commit d29be9e
Show file tree
Hide file tree
Showing 19 changed files with 516 additions and 63 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ jobs:

- name: Install Poetry
uses: abatilo/[email protected]
with:
poetry-version: 1.3.2

- name: Install dependencies
run: poetry install
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ jobs:

- name: Install Poetry
uses: abatilo/[email protected]
with:
poetry-version: 1.3.2

- name: Install dependencies
run: poetry install
Expand Down
13 changes: 8 additions & 5 deletions cgt_calc/currency_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
import datetime
from decimal import Decimal
from pathlib import Path
from typing import Final
from xml.etree import ElementTree
from typing import TYPE_CHECKING, Final

from defusedxml import ElementTree as ET
import requests

from .dates import is_date
from .exceptions import ExchangeRateMissingError, ParsingError
from .model import BrokerTransaction

if TYPE_CHECKING:
from .model import BrokerTransaction

EXCHANGE_RATES_HEADER: Final = ["month", "currency", "rate"]
NEW_ENDPOINT_FROM_YEAR: Final = 2021


class CurrencyConverter:
Expand Down Expand Up @@ -78,7 +81,7 @@ def _write_exchange_rates_file(

def _query_hmrc_api(self, date: datetime.date) -> None:
# Pre 2021 we need to use the old HMRC endpoint
if date.year < 2021:
if date.year < NEW_ENDPOINT_FROM_YEAR:
month_str = date.strftime("%m%y")
url = (
"http://www.hmrc.gov.uk/softwaredevelopers/rates/"
Expand All @@ -105,7 +108,7 @@ def _query_hmrc_api(self, date: datetime.date) -> None:
url, f"HMRC API returned a {response.status_code} response"
)

tree = ElementTree.fromstring(response.text)
tree = ET.fromstring(response.text)
rates = {
str(getattr(row.find("currencyCode"), "text", None)).upper(): Decimal(
str(getattr(row.find("rateNew"), "text", None))
Expand Down
6 changes: 4 additions & 2 deletions cgt_calc/current_price_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from contextlib import suppress
import datetime
from decimal import Decimal
from typing import TYPE_CHECKING

import yfinance as yf # type: ignore
import yfinance as yf # type: ignore[import-untyped]

from .currency_converter import CurrencyConverter
if TYPE_CHECKING:
from .currency_converter import CurrencyConverter


class CurrentPriceFetcher:
Expand Down
2 changes: 1 addition & 1 deletion cgt_calc/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def is_date(date: datetime.date) -> bool:
"""Check if date has only date but not time."""
if not isinstance(date, datetime.date) or isinstance(date, datetime.datetime):
raise ValueError(f'should be datetime.date: {type(date)} "{date}"')
raise TypeError(f'should be datetime.date: {type(date)} "{date}"')
return True


Expand Down
9 changes: 6 additions & 3 deletions cgt_calc/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@

from __future__ import annotations

import datetime
from decimal import Decimal
from typing import TYPE_CHECKING

from .model import BrokerTransaction
if TYPE_CHECKING:
import datetime
from decimal import Decimal

from .model import BrokerTransaction


class ParsingError(Exception):
Expand Down
7 changes: 5 additions & 2 deletions cgt_calc/initial_prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
from __future__ import annotations

from dataclasses import dataclass
import datetime
from decimal import Decimal
from typing import TYPE_CHECKING

from .dates import is_date
from .exceptions import ExchangeRateMissingError

if TYPE_CHECKING:
import datetime
from decimal import Decimal


@dataclass
class InitialPrices:
Expand Down
8 changes: 4 additions & 4 deletions cgt_calc/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""Capital Gain Calculator main module."""

from __future__ import annotations

from collections import defaultdict
Expand Down Expand Up @@ -149,8 +150,7 @@ def handle_spin_off(
self,
transaction: BrokerTransaction,
) -> tuple[Decimal, Decimal]:
"""
Handle spin off transaction.
"""Handle spin off transaction.
Doc basing on SOLV spin off out of MMM.
Expand Down Expand Up @@ -377,7 +377,7 @@ def convert_to_hmrc_transactions(
print(f"Dividend taxes: £{round_decimal(-dividends_tax, 2)}")
print(f"Interest: £{round_decimal(interest, 2)}")
print(f"Disposal proceeds: £{round_decimal(total_sells, 2)}")
print("")
print()

def process_acquisition(
self,
Expand Down Expand Up @@ -698,7 +698,7 @@ def calculate_capital_gain(

for date_index in (
begin_index + datetime.timedelta(days=x)
for x in range(0, (end_index - begin_index).days + 1)
for x in range((end_index - begin_index).days + 1)
):
if date_index in self.acquisition_list:
for symbol in self.acquisition_list[date_index]:
Expand Down
8 changes: 4 additions & 4 deletions cgt_calc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(

def __repr__(self) -> str:
"""Return print representation."""
return f"<CalculationEntry {str(self)}>"
return f"<CalculationEntry {self!s}>"

def __str__(self) -> str:
"""Return string representation."""
Expand Down Expand Up @@ -196,7 +196,7 @@ def unrealized_gains_str(self) -> str:

def __repr__(self) -> str:
"""Return print representation."""
return f"<PortfolioEntry {str(self)}>"
return f"<PortfolioEntry {self!s}>"

def __str__(self) -> str:
"""Return string representation."""
Expand Down Expand Up @@ -243,7 +243,7 @@ def taxable_gain(self) -> Decimal:

def __repr__(self) -> str:
"""Return string representation."""
return f"<CalculationEntry: {str(self)}>"
return f"<CalculationEntry: {self!s}>"

def __str__(self) -> str:
"""Return string representation."""
Expand All @@ -253,7 +253,7 @@ def __str__(self) -> str:
unrealized_gains_str = (
entry.unrealized_gains_str() if self.show_unrealized_gains else ""
)
out += f"{str(entry)}{unrealized_gains_str}\n"
out += f"{entry!s}{unrealized_gains_str}\n"
out += f"For tax year {self.tax_year}/{self.tax_year + 1}:\n"
out += f"Number of disposals: {self.disposal_count}\n"
out += f"Disposal proceeds: £{self.disposal_proceeds}\n"
Expand Down
19 changes: 13 additions & 6 deletions cgt_calc/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from decimal import Decimal
from importlib import resources
from pathlib import Path
from typing import TYPE_CHECKING, Final

from cgt_calc.const import DEFAULT_INITIAL_PRICES_FILE
from cgt_calc.exceptions import UnexpectedColumnCountError
from cgt_calc.model import BrokerTransaction
from cgt_calc.resources import RESOURCES_PACKAGE

from .mssb import read_mssb_transactions
Expand All @@ -20,14 +20,19 @@
from .sharesight import read_sharesight_transactions
from .trading212 import read_trading212_transactions

if TYPE_CHECKING:
from cgt_calc.model import BrokerTransaction

INITIAL_PRICES_COLUMNS_NUM: Final = 3


class InitialPricesEntry:
"""Entry from initial stock prices file."""

def __init__(self, row: list[str], file: str):
"""Create entry from CSV row."""
if len(row) != 3:
raise UnexpectedColumnCountError(row, 3, file)
if len(row) != INITIAL_PRICES_COLUMNS_NUM:
raise UnexpectedColumnCountError(row, INITIAL_PRICES_COLUMNS_NUM, file)
# date,symbol,price
self.date = self._parse_date(row[0])
self.symbol = row[1]
Expand Down Expand Up @@ -98,9 +103,11 @@ def read_initial_prices(
"""Read initial stock prices from CSV file."""
initial_prices: dict[datetime.date, dict[str, Decimal]] = {}
if initial_prices_file is None:
with resources.files(RESOURCES_PACKAGE).joinpath(
DEFAULT_INITIAL_PRICES_FILE
).open(encoding="utf-8") as csv_file:
with (
resources.files(RESOURCES_PACKAGE)
.joinpath(DEFAULT_INITIAL_PRICES_FILE)
.open(encoding="utf-8") as csv_file
):
lines = list(csv.reader(csv_file))
else:
with Path(initial_prices_file).open(encoding="utf-8") as csv_file:
Expand Down
10 changes: 6 additions & 4 deletions cgt_calc/parsers/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import datetime
from decimal import Decimal
from pathlib import Path
from typing import Final

from cgt_calc.const import TICKER_RENAMES
from cgt_calc.exceptions import ParsingError, UnexpectedColumnCountError
from cgt_calc.model import ActionType, BrokerTransaction

CSV_COLUMNS_NUM: Final = 7


def action_from_str(label: str) -> ActionType:
"""Convert string label to ActionType."""
Expand All @@ -21,8 +24,7 @@ def action_from_str(label: str) -> ActionType:


class RawTransaction(BrokerTransaction):
"""
Represents a single raw transaction.
"""Represents a single raw transaction.
Example format:
2023-02-09,DIVIDEND,OPRA,4200,0.80,0.0,USD
Expand All @@ -40,8 +42,8 @@ def __init__(
file: str,
):
"""Create transaction from CSV row."""
if len(row) != 7:
raise UnexpectedColumnCountError(row, 7, file)
if len(row) != CSV_COLUMNS_NUM:
raise UnexpectedColumnCountError(row, CSV_COLUMNS_NUM, file)

date_str = row[0]
date = datetime.datetime.strptime(date_str, "%Y-%m-%d").date()
Expand Down
14 changes: 10 additions & 4 deletions cgt_calc/parsers/schwab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Final

from cgt_calc.const import TICKER_RENAMES
from cgt_calc.exceptions import (
Expand All @@ -19,6 +20,9 @@
)
from cgt_calc.model import ActionType, BrokerTransaction

OLD_COLUMNS_NUM: Final = 9
NEW_COLUMNS_NUM: Final = 8


class SchwabTransactionsFileRequiredHeaders(str, Enum):
"""Enum to list the headers in Schwab transactions file that we will use."""
Expand Down Expand Up @@ -144,11 +148,13 @@ def __init__(
file: str,
):
"""Create transaction from CSV row."""
if len(row_dict) < 8 or len(row_dict) > 9:
if len(row_dict) < NEW_COLUMNS_NUM or len(row_dict) > OLD_COLUMNS_NUM:
# Old transactions had empty 9th column.
raise UnexpectedColumnCountError(list(row_dict.values()), 8, file)
if len(row_dict) == 9 and list(row_dict.values())[8] != "":
raise ParsingError(file, "Column 9 should be empty")
raise UnexpectedColumnCountError(
list(row_dict.values()), NEW_COLUMNS_NUM, file
)
if len(row_dict) == OLD_COLUMNS_NUM and list(row_dict.values())[-1] != "":
raise ParsingError(file, f"Column {OLD_COLUMNS_NUM} should be empty")
as_of_str = " as of "
date_header = SchwabTransactionsFileRequiredHeaders.DATE.value
if as_of_str in row_dict[date_header]:
Expand Down
27 changes: 16 additions & 11 deletions cgt_calc/parsers/schwab_equity_award_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from decimal import Decimal
import json
from pathlib import Path
from typing import Any
from typing import Any, Final

from pandas.tseries.holiday import USFederalHolidayCalendar
from pandas.tseries.offsets import CustomBusinessDay
Expand All @@ -30,11 +30,9 @@
from cgt_calc.util import round_decimal

# Delay between a (sale) trade, and when it is settled.
SETTLEMENT_DELAY = 2 * CustomBusinessDay(calendar=USFederalHolidayCalendar())

OPTIONAL_DETAILS_NAME = "Details"

field2schema = {"transactions": 1, "Transactions": 2}
SETTLEMENT_DELAY: Final = 2 * CustomBusinessDay(calendar=USFederalHolidayCalendar())
OPTIONAL_DETAILS_NAME: Final = "Details"
FIELD_TO_SCHEMA: Final = {"transactions": 1, "Transactions": 2}


@dataclass
Expand Down Expand Up @@ -196,6 +194,12 @@ def __init__(self, row: JsonRowType, file: str, field_names: FieldNames) -> None
action = action_from_str(self.raw_action)
symbol = row.get(names.symbol)
symbol = TICKER_RENAMES.get(symbol, symbol)
if symbol != "GOOG":
# Stock split hardcoded for GOOG
raise ParsingError(
file,
f"Schwab Equity Award JSON only supports GOOG stock but found {symbol}",
)
quantity = _decimal_from_number_or_str(row, names.quantity)
amount = _decimal_from_number_or_str(row, names.amount)
fees = _decimal_from_number_or_str(row, names.fees)
Expand Down Expand Up @@ -312,20 +316,21 @@ def __init__(self, row: JsonRowType, file: str, field_names: FieldNames) -> None
def _normalize_split(self) -> None:
"""Ensure past transactions are normalized to split values.
This is in the context of the 20:1 stock split which happened at close
on 2022-07-15 20:1.
This is in the context of the 20:1 GOOG stock split which happened at
close on 2022-07-15 20:1.
As of 2022-08-07, Schwab's data exports have some past transactions
corrected for the 20:1 split on 2022-07-15, whereas others are not.
"""
split_factor = 20
threshold_price = 175

# The share price has never been above $175*20=$3500 before 2022-07-15
# so this price is expressed in pre-split amounts: normalize to post-split
if (
self.date <= datetime.date(2022, 7, 15)
and self.price
and self.price > 175
and self.price > threshold_price
and self.quantity
):
self.price = round_decimal(self.price / split_factor, ROUND_DIGITS)
Expand All @@ -346,14 +351,14 @@ def read_schwab_equity_award_json_transactions(
"Cloud not parse content as JSON",
) from exception

for field_name, schema_version in field2schema.items():
for field_name, schema_version in FIELD_TO_SCHEMA.items():
if field_name in data:
fields = FieldNames(schema_version)
break
if not fields:
raise ParsingError(
transactions_file,
f"Expected top level field ({', '.join(field2schema.keys())}) "
f"Expected top level field ({', '.join(FIELD_TO_SCHEMA.keys())}) "
"not found: the JSON data is not in the expected format",
)

Expand Down
Loading

0 comments on commit d29be9e

Please sign in to comment.