diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 86e9d41..56aed00 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install Python dependencies run: | - uv pip install -e ".[dev]" --system + uv pip install -e ".[dev,docs]" --system - name: Run tests with coverage run: make test - name: Upload coverage to Codecov @@ -31,28 +31,20 @@ jobs: file: ./coverage.xml fail_ci_if_error: false verbose: true - publish-to-pypi: - name: Publish to PyPI - needs: Test - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - fetch-depth: 0 # Fetch all history for all tags and branches - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install package - run: make install - - name: Build package - run: python -m build - - name: Publish a git tag - run: ".github/publish-git-tag.sh || true" - - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Test documentation builds + run: make documentation + - name: Check documentation build + run: | + for notebook in $(find docs/_build/jupyter_execute -name "*.ipynb"); do + if grep -q '"output_type": "error"' "$notebook"; then + echo "Error found in $notebook" + cat "$notebook" + exit 1 + fi + done + - name: Deploy documentation + uses: JamesIves/github-pages-deploy-action@releases/v3 with: - user: __token__ - password: ${{ secrets.PYPI }} - skip-existing: true + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BRANCH: gh-pages # The branch the action should deploy to. + FOLDER: docs/_build/html # The folder the action should deploy. diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index b2e6cc6..7d0023a 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -61,6 +61,17 @@ jobs: python-version: "3.11" - name: Install dependencies run: | - uv pip install -e ".[dev]" --system + uv pip install -e ".[dev,docs]" --system - name: Build package run: make build + - name: Test documentation builds + run: make documentation + - name: Check documentation build + run: | + for notebook in $(find docs/_build/jupyter_execute -name "*.ipynb"); do + if grep -q '"output_type": "error"' "$notebook"; then + echo "Error found in $notebook" + cat "$notebook" + exit 1 + fi + done diff --git a/.github/workflows/versioning.yaml b/.github/workflows/versioning.yaml index acb6361..c16790a 100644 --- a/.github/workflows/versioning.yaml +++ b/.github/workflows/versioning.yaml @@ -35,4 +35,28 @@ jobs: with: add: "." 
message: Update package version - \ No newline at end of file + publish-to-pypi: + name: Publish to PyPI + if: (github.event.head_commit.message == 'Update package version') + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all tags and branches + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install package + run: make install + - name: Build package + run: python -m build + - name: Publish a git tag + run: ".github/publish-git-tag.sh || true" + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI }} + skip-existing: true diff --git a/.gitignore b/.gitignore index b7faf40..6c8074f 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,9 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Data files +*.csv +*.h5 + +.DS_Store diff --git a/changelog.yaml b/changelog.yaml index f3e6e24..0d02920 100644 --- a/changelog.yaml +++ b/changelog.yaml @@ -1,5 +1,5 @@ - changes: added: - Initialized project. - date: 2025-07-22 + date: 2025-07-22 00:00:00 version: 0.1.0 diff --git a/changelog_entry.yaml b/changelog_entry.yaml index 842a252..3b10786 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -1,6 +1,5 @@ - bump: minor changes: changed: - - Initialized changelogging. - - Added CI workflows and tests. - - Set up basic package structure and coverage. + - Added Single and Multi Year Dataset classes. + - Added data download and upload functionality. diff --git a/docs/_config.yml b/docs/_config.yml new file mode 100644 index 0000000..0a040c7 --- /dev/null +++ b/docs/_config.yml @@ -0,0 +1,26 @@ +title: Policyengine-data documentation +author: PolicyEngine +logo: logo.png + +execute: + execute_notebooks: force + timeout: 180 + +repository: + url: https://github.com/policyengine/policyengine-data + branch: main + path_to_book: docs + +sphinx: + config: + html_js_files: + - https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js + html_theme: furo + pygments_style: default + html_css_files: + - style.css + extra_extensions: + - "sphinx.ext.autodoc" + - "sphinx.ext.viewcode" + - "sphinx.ext.napoleon" + - "sphinx.ext.mathjax" diff --git a/docs/_static/__init__.py b/docs/_static/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/docs/_static/__init__.py @@ -0,0 +1 @@ + diff --git a/docs/_static/style.css b/docs/_static/style.css new file mode 100644 index 0000000..5cfd703 --- /dev/null +++ b/docs/_static/style.css @@ -0,0 +1 @@ +@import url('https://fonts.googleapis.com/css2?family=Roboto+Serif:opsz@8..144&family=Roboto:wght@300&display=swap'); \ No newline at end of file diff --git a/docs/_toc.yml b/docs/_toc.yml new file mode 100644 index 0000000..3439a33 --- /dev/null +++ b/docs/_toc.yml @@ -0,0 +1,4 @@ +format: jb-book +root: intro +chapters: + - file: dataset.ipynb diff --git a/docs/add_plotly_to_book.py b/docs/add_plotly_to_book.py new file mode 100644 index 0000000..822e77a --- /dev/null +++ b/docs/add_plotly_to_book.py @@ -0,0 +1,27 @@ +import argparse +from pathlib import Path + +# This command-line tools enables Plotly charts to show in the HTML files for the Jupyter Book documentation. + +parser = argparse.ArgumentParser() +parser.add_argument("book_path", help="Path to the Jupyter Book.") + +args = parser.parse_args() + +# Find every HTML file in the Jupyter Book. 
Then, add a script tag to the start of the tag in each file, with the contents: +# + +book_folder = Path(args.book_path) + +for html_file in book_folder.glob("**/*.html"): + with open(html_file, "r") as f: + html = f.read() + + # Add the script tag to the start of the tag. + html = html.replace( + "", + '', + ) + + with open(html_file, "w") as f: + f.write(html) diff --git a/docs/dataset.ipynb b/docs/dataset.ipynb new file mode 100644 index 0000000..5d7e884 --- /dev/null +++ b/docs/dataset.ipynb @@ -0,0 +1,692 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "gq2207gugn", + "metadata": {}, + "source": [ + "# PolicyEngine Dataset classes documentation\n", + "\n", + "This notebook provides documentation for the `SingleYearDataset` and `MultiYearDataset` classes in PolicyEngine Data. These classes are designed to handle structured data for policy analysis and microsimulation.\n", + "\n", + "More information on how to integrate with PolicyEngine Core and country-specific data packages will be added as this develops." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "baee1feb", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from pathlib import Path\n", + "import warnings\n", + "from tables import NaturalNameWarning\n", + "\n", + "from policyengine_data.single_year_dataset import SingleYearDataset\n", + "from policyengine_data.multi_year_dataset import MultiYearDataset" + ] + }, + { + "cell_type": "markdown", + "id": "7832e1ab", + "metadata": {}, + "source": [ + "## SingleYearDataset\n", + "\n", + "The `SingleYearDataset` class is designed to handle data for a single year, organizing it by entities (typically \"person\" and \"household\" in addition to others). Each entity contains a pandas DataFrame with variables relevant to that entity.\n", + "\n", + "### Key features:\n", + "- Stores data for a single time period\n", + "- Organizes data by entities (person, household, etc.)\n", + "- All data in a given entity is combined into a single table\n", + "- Forces data shape validation in the dataset creation process given the table format\n", + "- Supports basic functionality from the legacy `Dataset` like loading and saving but deprecates multiple data format and loading to the cloud complexity\n", + "\n", + "### Creating a SingleYearDataset\n", + "\n", + "There are three main ways to create a `SingleYearDataset`:\n", + "\n", + "1. **From entity DataFrames**: Create directly from a dictionary of entity DataFrames\n", + "2. **From HDF5 file**: Load from an existing HDF5 file\n", + "3. 
**From simulation**: Create from a PolicyEngine Core microsimulation" + ] + }, + { + "cell_type": "markdown", + "id": "79636fbd", + "metadata": {}, + "source": [ + "#### Method 1: From entity DataFrames" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0cf913d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset created for year: 2023\n", + "Available entities: ['person', 'household']\n", + "Person data shape: (1000, 4)\n", + "Household data shape: (400, 4)\n" + ] + } + ], + "source": [ + "# Create sample data for demonstration\n", + "np.random.seed(42)\n", + "\n", + "# Person-level data\n", + "person_data = pd.DataFrame({\n", + " 'person_id': range(1000),\n", + " 'age': np.random.randint(18, 80, 1000),\n", + " 'income': np.random.normal(50000, 15000, 1000),\n", + " 'household_id': np.repeat(range(400), [3, 2, 3, 2] * 100) # Varying household sizes\n", + "})\n", + "\n", + "# Household-level data\n", + "household_data = pd.DataFrame({\n", + " 'household_id': range(400),\n", + " 'household_size': np.random.randint(1, 6, 400),\n", + " 'housing_cost': np.random.normal(1200, 300, 400),\n", + " 'state': np.random.choice(['CA', 'TX', 'NY', 'FL'], 400)\n", + "})\n", + "\n", + "# Create entities dictionary\n", + "entities = {\n", + " 'person': person_data,\n", + " 'household': household_data\n", + "}\n", + "\n", + "# Create SingleYearDataset\n", + "dataset_2023 = SingleYearDataset(\n", + " entities=entities,\n", + " time_period=2023\n", + ")\n", + "\n", + "print(f\"Dataset created for year: {dataset_2023.time_period}\")\n", + "print(f\"Available entities: {list(dataset_2023.entities.keys())}\")\n", + "print(f\"Person data shape: {dataset_2023.entities['person'].shape}\")\n", + "print(f\"Household data shape: {dataset_2023.entities['household'].shape}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4cade694", + "metadata": {}, + "source": [ + "#### Method 2: Loading from HDF5 file" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "93f4652e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset saved to sample_dataset_2023.h5\n", + "Dataset loaded from file for year: 2023\n", + "Loaded entities: ['household', 'person']\n", + "Original person data shape: (1000, 4)\n", + "Loaded person data shape: (1000, 4)\n", + "Data integrity check: True\n" + ] + } + ], + "source": [ + "# Save the dataset to an HDF5 file\n", + "file_path = \"sample_dataset_2023.h5\"\n", + "dataset_2023.save(file_path)\n", + "print(f\"Dataset saved to {file_path}\")\n", + "\n", + "# Load the dataset from the HDF5 file\n", + "loaded_dataset = SingleYearDataset(file_path=file_path)\n", + "print(f\"Dataset loaded from file for year: {loaded_dataset.time_period}\")\n", + "print(f\"Loaded entities: {list(loaded_dataset.entities.keys())}\")\n", + "\n", + "# Verify the data is the same\n", + "print(f\"Original person data shape: {dataset_2023.entities['person'].shape}\")\n", + "print(f\"Loaded person data shape: {loaded_dataset.entities['person'].shape}\")\n", + "print(f\"Data integrity check: {dataset_2023.entities['person'].equals(loaded_dataset.entities['person'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "504f33ea", + "metadata": {}, + "source": [ + "#### Method 3: From a PolicyEngine MicroSimulation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4f1a90f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ 
+ "Dataset created from PolicyEngine US microdata stored in hf://policyengine/policyengine-us-data/cps_2023.h5\n", + "Dataset created for time period: 2023\n" + ] + } + ], + "source": [ + "from policyengine_us import Microsimulation\n", + "\n", + "start_year = 2023\n", + "dataset = \"hf://policyengine/policyengine-us-data/cps_2023.h5\"\n", + "\n", + "sim = Microsimulation(dataset=dataset)\n", + "\n", + "single_year_dataset = SingleYearDataset.from_simulation(sim, time_period=start_year)\n", + "single_year_dataset.time_period = start_year\n", + "\n", + "print(f\"Dataset created from PolicyEngine US microdata stored in {dataset}\")\n", + "print(f\"Dataset created for time period: {single_year_dataset.time_period}\")" + ] + }, + { + "cell_type": "markdown", + "id": "9f4c218b", + "metadata": {}, + "source": [ + "### Main functionalities of SingleYearDataset\n", + "\n", + "#### 1. Data access and properties" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4041f3d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Person entity columns: ['person_id', 'age', 'income', 'household_id']\n", + "Household entity columns: ['household_id', 'household_size', 'housing_cost', 'state']\n", + "\n", + "Variables by entity:\n", + "person: ['person_id', 'age', 'income', 'household_id']\n", + "household: ['household_id', 'household_size', 'housing_cost', 'state']\n", + "\n", + "Time period: 2023\n", + "Data format: arrays\n", + "Table names: ('person', 'household')\n", + "Number of tables: 2\n" + ] + } + ], + "source": [ + "# Access entity data\n", + "print(\"Person entity columns:\", dataset_2023.entities['person'].columns.tolist())\n", + "print(\"Household entity columns:\", dataset_2023.entities['household'].columns.tolist())\n", + "\n", + "# Get variables by entity\n", + "print(\"\\nVariables by entity:\")\n", + "variables = dataset_2023.variables\n", + "for entity, vars_list in variables.items():\n", + " print(f\"{entity}: {vars_list}\")\n", + "\n", + "# Access basic properties\n", + "print(f\"\\nTime period: {dataset_2023.time_period}\")\n", + "print(f\"Data format: {dataset_2023.data_format}\")\n", + "print(f\"Table names: {dataset_2023.table_names}\")\n", + "print(f\"Number of tables: {len(dataset_2023.tables)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "75c9d669", + "metadata": {}, + "source": [ + "Note that the data format property will be removed once we fully move away from legacy code that used the old `Dataset` classes as only entity tables as DataFrames will be supported" + ] + }, + { + "cell_type": "markdown", + "id": "09aca442", + "metadata": {}, + "source": [ + "#### 2. Data loading and copying" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2c09e7c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded data keys (first 10): ['person_id', 'age', 'income', 'household_id', 'household_size', 'housing_cost', 'state']\n", + "Sample variable 'age' shape: (1000,)\n", + "\n", + "Original dataset time period: 2023\n", + "Copied dataset time period: 2023\n", + "Are they the same object? False\n", + "Do they have the same data? 
True\n" + ] + } + ], + "source": [ + "# Load data as a flat dictionary (useful for PolicyEngine Core)\n", + "loaded_data = dataset_2023.load()\n", + "print(\"Loaded data keys (first 10):\", list(loaded_data.keys())[:10])\n", + "print(\"Sample variable 'age' shape:\", loaded_data['age'].shape)\n", + "\n", + "# Create a copy of the dataset\n", + "dataset_copy = dataset_2023.copy()\n", + "print(f\"\\nOriginal dataset time period: {dataset_2023.time_period}\")\n", + "print(f\"Copied dataset time period: {dataset_copy.time_period}\")\n", + "print(f\"Are they the same object? {dataset_2023 is dataset_copy}\")\n", + "print(f\"Do they have the same data? {dataset_2023.entities['person'].equals(dataset_copy.entities['person'])}\")" + ] + }, + { + "cell_type": "markdown", + "id": "d4edf01f", + "metadata": {}, + "source": [ + "#### 3. Data validation" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "db42fa38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset validation passed - no NaN values found\n", + "Validation correctly failed: Column 'income' contains NaN values.\n" + ] + } + ], + "source": [ + "# Validate the dataset (checks for NaN values)\n", + "try:\n", + " dataset_2023.validate()\n", + " print(\"Dataset validation passed - no NaN values found\")\n", + "except ValueError as e:\n", + " print(f\"Validation failed: {e}\")\n", + "\n", + "# Create a dataset with NaN values to demonstrate validation\n", + "invalid_person_data = person_data.copy()\n", + "invalid_person_data.loc[0, 'income'] = np.nan\n", + "\n", + "invalid_entities = {\n", + " 'person': invalid_person_data,\n", + " 'household': household_data\n", + "}\n", + "\n", + "invalid_dataset = SingleYearDataset(\n", + " entities=invalid_entities,\n", + " time_period=2023\n", + ")\n", + "\n", + "# Try to validate the invalid dataset\n", + "try:\n", + " invalid_dataset.validate()\n", + " print(\"Invalid dataset validation passed\")\n", + "except ValueError as e:\n", + " print(f\"Validation correctly failed: {e}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3bb8cd5c", + "metadata": {}, + "source": [ + "## MultiYearDataset\n", + "\n", + "The `MultiYearDataset` class is designed to handle data across multiple years, containing a collection of `SingleYearDataset` instances. This is useful for storing all the data necessary for multi-year analysis in a single object, rather than having to load and manage multiple `Dataset` objects one per year.\n", + "\n", + "### Key features:\n", + "- Stores multiple `SingleYearDataset` instances indexed by year\n", + "- Maintains consistency across years for entity structures\n", + "- Supports copying and data extraction across all years\n", + "\n", + "### Creating a MultiYearDataset\n", + "\n", + "There are two main ways to create a `MultiYearDataset`:\n", + "\n", + "1. **From a list of SingleYearDatasets**: Create from existing SingleYearDataset instances\n", + "2. 
**From HDF5 file**: Load from an existing multi-year HDF5 file" + ] + }, + { + "cell_type": "markdown", + "id": "1dce1010", + "metadata": {}, + "source": [ + "#### Method 1: From SingleYearDataset list" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "32967970", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Multi-year dataset created with years: [2021, 2022, 2023, 2024]\n", + "Earliest time period present: 2021\n", + "Data format: time_period_arrays\n" + ] + } + ], + "source": [ + "# Create datasets for multiple years\n", + "datasets_by_year = []\n", + "\n", + "for year in [2021, 2022, 2023, 2024]:\n", + " # Create slightly different data for each year (e.g., income growth)\n", + " year_person_data = person_data.copy()\n", + " year_person_data['income'] = year_person_data['income'] * (1.03 ** (year - 2023)) # 3% annual growth\n", + " \n", + " year_household_data = household_data.copy()\n", + " year_household_data['housing_cost'] = year_household_data['housing_cost'] * (1.05 ** (year - 2023)) # 5% annual growth\n", + " \n", + " year_entities = {\n", + " 'person': year_person_data,\n", + " 'household': year_household_data\n", + " }\n", + " \n", + " year_dataset = SingleYearDataset(\n", + " entities=year_entities,\n", + " time_period=year\n", + " )\n", + " datasets_by_year.append(year_dataset)\n", + "\n", + "# Create MultiYearDataset\n", + "multi_year_dataset = MultiYearDataset(datasets=datasets_by_year)\n", + "\n", + "print(f\"Multi-year dataset created with years: {sorted(multi_year_dataset.datasets.keys())}\")\n", + "print(f\"Earliest time period present: {multi_year_dataset.time_period}\")\n", + "print(f\"Data format: {multi_year_dataset.data_format}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c82b1b84", + "metadata": {}, + "source": [ + "#### Method 2: Save and load from HDF5 File" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "34167c6a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Multi-year dataset saved to sample_multi_year_dataset.h5\n", + "Multi-year dataset loaded with years: [2021, 2022, 2023, 2024]\n", + "Original 2022 average income: $49201.98\n", + "Loaded 2022 average income: $49201.98\n", + "Data integrity check: True\n" + ] + } + ], + "source": [ + "warnings.filterwarnings(\"ignore\", category=NaturalNameWarning)\n", + "\n", + "# Save the multi-year dataset to an HDF5 file\n", + "multi_year_file_path = \"sample_multi_year_dataset.h5\"\n", + "multi_year_dataset.save(multi_year_file_path)\n", + "print(f\"Multi-year dataset saved to {multi_year_file_path}\")\n", + "\n", + "# Load the multi-year dataset from the HDF5 file\n", + "loaded_multi_year = MultiYearDataset(file_path=multi_year_file_path)\n", + "print(f\"Multi-year dataset loaded with years: {sorted(loaded_multi_year.datasets.keys())}\")\n", + "\n", + "# Verify the data integrity\n", + "original_2022_income = multi_year_dataset[2022].entities['person']['income'].mean()\n", + "loaded_2022_income = loaded_multi_year[2022].entities['person']['income'].mean()\n", + "\n", + "print(f\"Original 2022 average income: ${original_2022_income:.2f}\")\n", + "print(f\"Loaded 2022 average income: ${loaded_2022_income:.2f}\")\n", + "print(f\"Data integrity check: {abs(original_2022_income - loaded_2022_income) < 0.01}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5721fdce", + "metadata": {}, + "source": [ + "### Main functionalities of MultiYearDataset\n", + "\n", + "#### 
1. Accessing data by year" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a4c3bbdd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2022 dataset time period: 2022\n", + "2022 person data shape: (1000, 4)\n", + "2024 dataset time period: 2024\n", + "Error accessing non-existent year: No dataset found for year 2025.\n", + "Available years: [2021, 2022, 2023, 2024]\n" + ] + } + ], + "source": [ + "# Access specific years using get_year() method\n", + "dataset_2022 = multi_year_dataset.get_year(2022)\n", + "print(f\"2022 dataset time period: {dataset_2022.time_period}\")\n", + "print(f\"2022 person data shape: {dataset_2022.entities['person'].shape}\")\n", + "\n", + "# Access specific years using indexing operator []\n", + "dataset_2024 = multi_year_dataset[2024]\n", + "print(f\"2024 dataset time period: {dataset_2024.time_period}\")\n", + "\n", + "# Try to access a year that doesn't exist\n", + "try:\n", + " dataset_2025 = multi_year_dataset.get_year(2025)\n", + "except ValueError as e:\n", + " print(f\"Error accessing non-existent year: {e}\")\n", + "\n", + "# List all available years\n", + "print(f\"Available years: {sorted(multi_year_dataset.datasets.keys())}\")" + ] + }, + { + "cell_type": "markdown", + "id": "7e210fa1", + "metadata": {}, + "source": [ + "#### 2. Variables and data structure" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "e744b959", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Variables by year and entity:\n", + "\n", + "Year 2021:\n", + " person: ['person_id', 'age', 'income', 'household_id']\n", + " household: ['household_id', 'household_size', 'housing_cost', 'state']\n", + "\n", + "Year 2022:\n", + " person: ['person_id', 'age', 'income', 'household_id']\n", + " household: ['household_id', 'household_size', 'housing_cost', 'state']\n", + "\n", + "Year 2023:\n", + " person: ['person_id', 'age', 'income', 'household_id']\n", + " household: ['household_id', 'household_size', 'housing_cost', 'state']\n", + "\n", + "Year 2024:\n", + " person: ['person_id', 'age', 'income', 'household_id']\n", + " household: ['household_id', 'household_size', 'housing_cost', 'state']\n" + ] + } + ], + "source": [ + "# Get variables across all years\n", + "variables_by_year = multi_year_dataset.variables\n", + "print(\"Variables by year and entity:\")\n", + "for year, entities in variables_by_year.items():\n", + " print(f\"\\nYear {year}:\")\n", + " for entity, vars_list in entities.items():\n", + " print(f\" {entity}: {vars_list}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c2292511", + "metadata": {}, + "source": [ + "#### 3. 
Data loading and copying" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ucthzwr8ppc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sample of loaded data structure:\n", + "\n", + "Variable 'person_id':\n", + " Year 2021: shape (1000,), mean = 499.50\n", + " Year 2022: shape (1000,), mean = 499.50\n", + " Year 2023: shape (1000,), mean = 499.50\n", + " Year 2024: shape (1000,), mean = 499.50\n", + "\n", + "Variable 'age':\n", + " Year 2021: shape (1000,), mean = 49.86\n", + " Year 2022: shape (1000,), mean = 49.86\n", + " Year 2023: shape (1000,), mean = 49.86\n", + " Year 2024: shape (1000,), mean = 49.86\n", + "\n", + "Original dataset years: [2021, 2022, 2023, 2024]\n", + "Copied dataset years: [2021, 2022, 2023, 2024]\n", + "Are they the same object? False\n", + "Original 2023 income mean: $50678.04\n", + "Copy 2023 income mean: $50678.04\n", + "Data integrity check: True\n" + ] + } + ], + "source": [ + "# Load all data as a time-period indexed dictionary\n", + "all_data = multi_year_dataset.load()\n", + "print(\"Sample of loaded data structure:\")\n", + "for var_name, year_data in list(all_data.items())[:2]: # Show first 2 variables\n", + " print(f\"\\nVariable '{var_name}':\")\n", + " for year, data_array in year_data.items():\n", + " print(f\" Year {year}: shape {data_array.shape}, mean = {data_array.mean():.2f}\")\n", + "\n", + "# Create a copy of the multi-year dataset\n", + "multi_year_copy = multi_year_dataset.copy()\n", + "print(f\"\\nOriginal dataset years: {sorted(multi_year_dataset.datasets.keys())}\")\n", + "print(f\"Copied dataset years: {sorted(multi_year_copy.datasets.keys())}\")\n", + "print(f\"Are they the same object? {multi_year_dataset is multi_year_copy}\")\n", + "\n", + "# Verify independence of the copy\n", + "original_2023_income_mean = multi_year_dataset[2023].entities['person']['income'].mean()\n", + "copy_2023_income_mean = multi_year_copy[2023].entities['person']['income'].mean()\n", + "print(f\"Original 2023 income mean: ${original_2023_income_mean:.2f}\")\n", + "print(f\"Copy 2023 income mean: ${copy_2023_income_mean:.2f}\")\n", + "print(f\"Data integrity check: {abs(original_2023_income_mean - copy_2023_income_mean) < 0.01}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "qrue1qeh4ub", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cleaned up sample_dataset_2023.h5\n", + "Cleaned up sample_multi_year_dataset.h5\n", + "Documentation complete! The notebook now contains comprehensive documentation for both SingleYearDataset and MultiYearDataset classes.\n" + ] + } + ], + "source": [ + "# Clean up temporary files\n", + "import os\n", + "\n", + "temp_files = [\"sample_dataset_2023.h5\", \"sample_multi_year_dataset.h5\"]\n", + "for file in temp_files:\n", + " if os.path.exists(file):\n", + " os.remove(file)\n", + " print(f\"Cleaned up {file}\")\n", + "\n", + "print(\"Documentation complete! 
The notebook now contains comprehensive documentation for both SingleYearDataset and MultiYearDataset classes.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/intro.md b/docs/intro.md new file mode 100644 index 0000000..3718747 --- /dev/null +++ b/docs/intro.md @@ -0,0 +1,3 @@ +## PolicyEngine Data + +This is the documentation for PolicyEngine Data, the open-source Python package powering PolicyEngine's data processing and storing functionality. It is used by PolicyEngine UK Data and PolicyEngine US Data, which each define the custom logic specific to processing UK and US data sources. diff --git a/docs/logo.png b/docs/logo.png new file mode 100644 index 0000000..12736e4 Binary files /dev/null and b/docs/logo.png differ diff --git a/pyproject.toml b/pyproject.toml index 75e29b2..c8c45bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,9 +11,10 @@ dependencies = [ "h5py", "numpy", "pandas", - "huggingface_hub", + "huggingface_hub>=0.25.1", + "tables", "policyengine-core>=3.6.4", - "microdf-python==0.4.4", + "microdf-python", ] [project.optional-dependencies] @@ -29,6 +30,27 @@ dev = [ "yaml-changelog>=0.1.7", ] +docs = [ + "sphinx>=5.0.0", + "docutils>=0.17.0", + "jupyter-book>=0.15.0", + "sphinx-book-theme>=1.0.0", + "sphinx-copybutton>=0.5.0", + "sphinx-design>=0.3.0", + "ipywidgets>=7.8.0", + "plotly", + "sphinx-argparse>=0.5.0", + "sphinx-math-dollar>=1.2.1", + "myst-parser>=0.18.1", + "myst-nb>=0.17.2", + "pyyaml", + "furo>=2022.12.7", + "h5py>=3.1.0,<4.0.0", + "policyengine-core", + "policyengine-us", + "policyengine-us-data", +] + [tool.setuptools] packages = ["policyengine_data"] include-package-data = true diff --git a/src/policyengine_data/__init__.py b/src/policyengine_data/__init__.py index 8b13789..9a4f8a7 100644 --- a/src/policyengine_data/__init__.py +++ b/src/policyengine_data/__init__.py @@ -1 +1,3 @@ - +from .dataset_legacy import Dataset +from .multi_year_dataset import MultiYearDataset +from .single_year_dataset import SingleYearDataset diff --git a/src/policyengine_data/data_download_upload.py b/src/policyengine_data/data_download_upload.py new file mode 100644 index 0000000..0892998 --- /dev/null +++ b/src/policyengine_data/data_download_upload.py @@ -0,0 +1,263 @@ +""" +Functionality for uploading and downloading datasets in PolicyEngine. +""" + +import logging +import os +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Dict, Enum, List, Optional, Tuple, Union + +import h5py +import numpy as np +import pandas as pd +import requests + +from .tools.hugging_face import * +from .tools.win_file_manager import WindowsAtomicFileManager + +logger = logging.getLogger(__name__) + + +def atomic_write(file: Path, content: bytes) -> None: + """ + Atomically update the target file with the content. Any existing file will be unlinked rather than overritten. + + Implemented by + 1. Downloading the file to a temporary file with a unique name + 2. 
renaming (not copying) the file to the target name so that the operation is atomic (either the file is there or it's not, no partial file) + + If a process is reading the original file when a new file is renamed, that should relink the file, not clear and overwrite the old one so + both processes should continue happily. + """ + if sys.platform == "win32": + manager = WindowsAtomicFileManager(file) + manager.write(content) + else: + with tempfile.NamedTemporaryFile( + mode="wb", + dir=file.parent.absolute().as_posix(), + prefix=file.name + ".download.", + delete=False, + ) as f: + try: + f.write(content) + f.close() + os.rename(f.name, file.absolute().as_posix()) + except: + f.delete = True + f.close() + raise + + +class CloudLocation(Enum): + HUGGING_FACE = "HUGGING_FACE" + GOOGLE_CLOUD_STORAGE = "GOOGLE_CLOUD_STORAGE" + + +def download( + local_dir: Path, + url: str, + cloud_location: Optional[str] = None, + version: Optional[str] = None, +) -> None: + """Downloads a file from a cloud location to the local directory. + + Args: + local_dir (Path): The path to save the downloaded file. + url (str): The url to download from. + cloud_location (Optional[str]): The cloud location to download from. Defaults to None. + version (Optional[str]): The version of the file to download. Defaults to None. + """ + if cloud_location is None: + cloud_location = identify_location(url) + else: + if cloud_location not in CloudLocation: + raise ValueError(f"Unsupported cloud location: {cloud_location}") + + if cloud_location == CloudLocation.HUGGING_FACE: + owner, model, filename = parse_hugging_face_url(url) + download_from_hugging_face(local_dir, owner, model, filename, version) + elif cloud_location == CloudLocation.GOOGLE_CLOUD_STORAGE: + download_from_gcs(local_dir, url) + else: + raise ValueError(f"Unsupported cloud location for URL: {url}") + + +def upload( + local_dir: Path, url: str, cloud_location: Optional[str] = None +) -> None: + """Uploads a file from the local directory to a cloud location. + + Args: + local_dir (Path): The path to the directory containing the file to upload. + url (str): The url to upload to. + cloud_location (Optional[str]): The cloud location to upload to. Defaults to None. + """ + if cloud_location is None: + cloud_location = identify_location(url) + else: + if cloud_location not in CloudLocation: + raise ValueError(f"Unsupported cloud location: {cloud_location}") + + if cloud_location == CloudLocation.HUGGING_FACE: + owner, model, filename = parse_hugging_face_url(url) + upload_to_hugging_face(local_dir, owner, model, filename) + elif cloud_location == CloudLocation.GOOGLE_CLOUD_STORAGE: + upload_to_gcs(local_dir, url) + else: + raise ValueError(f"Unsupported cloud location for URL: {url}") + + +def identify_location(url: str) -> CloudLocation: + """Identifies the cloud storage location from a URL. + + Args: + url (str): The URL to analyze. + + Returns: + CloudLocation: The identified cloud storage location. + + Raises: + ValueError: If the URL format is not recognized. + """ + if url.startswith("hf://"): + return CloudLocation.HUGGING_FACE + elif url.startswith("gs://") or url.startswith( + "https://storage.googleapis.com/" + ): + return CloudLocation.GOOGLE_CLOUD_STORAGE + else: + # Default to Hugging Face for backwards compatibility + return CloudLocation.HUGGING_FACE + + +def validate_hugging_face_url(url: str) -> bool: + """Validates a Hugging Face URL format. + + Args: + url (str): The URL to validate. + + Returns: + bool: True if valid, False otherwise. 
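+
+    Example (illustrative; the repository path mirrors the one used in the docs):
+        >>> validate_hugging_face_url("hf://policyengine/policyengine-us-data/cps_2023.h5")
+        True
+        >>> validate_hugging_face_url("hf://policyengine/cps_2023.h5")
+        False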
+ """ + if not url.startswith("hf://"): + return False + + parts = url[5:].split("/") # Remove "hf://" prefix + return len(parts) == 3 and all(part for part in parts) + + +def parse_hugging_face_url(url: str) -> Tuple[str, str, str]: + """Parses a Hugging Face URL into its components. + + Args: + url (str): The Hugging Face URL to parse. + + Returns: + Tuple[str, str, str]: Owner name, model name, and filename. + + Raises: + ValueError: If the URL is not a valid Hugging Face URL. + """ + if not validate_hugging_face_url(url): + raise ValueError(f"Invalid Hugging Face URL format: {url}") + + parts = url[5:].split("/") # Remove "hf://" prefix + return parts[0], parts[1], parts[2] + + +def download_from_hugging_face( + local_dir: Path, + owner_name: str, + model_name: str, + file_name: str, + version: Optional[str] = None, +) -> None: + """Downloads a file from Hugging Face. + + Args: + local_dir (Path): The path to save the downloaded file. + owner_name (str): The owner name. + model_name (str): The model name. + file_name (str): The file name. + version (Optional[str]): The version of the file to download. + """ + logger.info( + f"Downloading from HuggingFace {owner_name}/{model_name}/{file_name}", + ) + + download_huggingface_dataset( + repo=f"{owner_name}/{model_name}", + repo_filename=file_name, + version=version, + local_dir=local_dir, + ) + + +def upload_to_hugging_face( + local_dir: Path, owner_name: str, model_name: str, file_name: str +) -> None: + """Uploads a file to Hugging Face. + + Args: + local_dir (Path): The path to the directory containing the file to upload. + owner_name (str): The owner name. + model_name (str): The model name. + file_name (str): The file name. + """ + logger.info( + f"Uploading to HuggingFace {owner_name}/{model_name}/{file_name}", + ) + + token = get_or_prompt_hf_token() + api = HfApi() + + api.upload_file( + path_or_fileobj=local_dir, + path_in_repo=file_name, + repo_id=f"{owner_name}/{model_name}", + repo_type="model", + token=token, + ) + + +def download_from_gcs(local_dir: Path, url: str) -> None: + """Downloads a file from Google Cloud Storage. + + Args: + local_dir (Path): The path to save the downloaded file. + url (str): The GCS URL to download from. + """ + # Convert gs:// URLs to https:// if needed + if url.startswith("gs://"): + bucket_and_path = url[5:] + url = f"https://storage.googleapis.com/{bucket_and_path}" + + response = requests.get(url) + + if response.status_code != 200: + raise ValueError( + f"Failed to download from GCS. Status code: {response.status_code}" + ) + + # Extract filename from URL + filename = url.split("/")[-1] + file_path = local_dir / filename + + atomic_write(file_path, response.content) + + +def upload_to_gcs(local_dir: Path, url: str) -> None: + """Uploads a file to Google Cloud Storage. + + Args: + local_dir (Path): The path to the directory containing the file to upload. + url (str): The GCS URL to upload to. + """ + raise NotImplementedError( + "Google Cloud Storage upload is not yet implemented. " + "Please use Hugging Face for now." + ) diff --git a/src/policyengine_data/dataset_legacy.py b/src/policyengine_data/dataset_legacy.py new file mode 100644 index 0000000..ad516f8 --- /dev/null +++ b/src/policyengine_data/dataset_legacy.py @@ -0,0 +1,515 @@ +""" +Legacy dataset class for PolicyEngine, supporting single-year datasets. 
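+
+This module is retained for backwards compatibility; new code is expected to use the
+SingleYearDataset and MultiYearDataset classes instead.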
+""" + +import os +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Dict, List, Union + +import h5py +import numpy as np +import pandas as pd +import requests +from policyengine_core.tools.hugging_face import * +from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager + + +def atomic_write(file: Path, content: bytes) -> None: + """ + Atomically update the target file with the content. Any existing file will be unlinked rather than overritten. + + Implemented by + 1. Downloading the file to a temporary file with a unique name + 2. renaming (not copying) the file to the target name so that the operation is atomic (either the file is there or it's not, no partial file) + + If a process is reading the original file when a new file is renamed, that should relink the file, not clear and overwrite the old one so + both processes should continue happily. + """ + if sys.platform == "win32": + manager = WindowsAtomicFileManager(file) + manager.write(content) + else: + with tempfile.NamedTemporaryFile( + mode="wb", + dir=file.parent.absolute().as_posix(), + prefix=file.name + ".download.", + delete=False, + ) as f: + try: + f.write(content) + f.close() + os.rename(f.name, file.absolute().as_posix()) + except: + f.delete = True + f.close() + raise + + +class Dataset: + """The `Dataset` class is a base class for datasets used directly or indirectly for microsimulation models. + A dataset defines a generation function to create it from other data, and this class provides common features + like storage, metadata and loading.""" + + name: str = None + """The name of the dataset. This is used to generate filenames and is used as the key in the `datasets` dictionary.""" + label: str = None + """The label of the dataset. This is used for logging and is used as the key in the `datasets` dictionary.""" + data_format: str = None + """The format of the dataset. This can be either `Dataset.ARRAYS`, `Dataset.TIME_PERIOD_ARRAYS` or `Dataset.TABLES`. If `Dataset.ARRAYS`, the dataset is stored as a collection of arrays. If `Dataset.TIME_PERIOD_ARRAYS`, the dataset is stored as a collection of arrays, with one array per time period. If `Dataset.TABLES`, the dataset is stored as a collection of tables (DataFrames).""" + file_path: Path = None + """The path to the dataset file. This is used to load the dataset from a file.""" + time_period: str = None + """The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`.""" + url: str = None + """The URL to download the dataset from. This is used to download the dataset if it does not exist.""" + + # Data formats + TABLES = "tables" + ARRAYS = "arrays" + TIME_PERIOD_ARRAYS = "time_period_arrays" + FLAT_FILE = "flat_file" + + _table_cache: Dict[str, pd.DataFrame] = None + + def __init__(self, require: bool = False): + # Setup dataset + if self.file_path is None: + raise ValueError( + "Dataset file_path must be specified in the dataset class definition." + ) + elif isinstance(self.file_path, str): + self.file_path = Path(self.file_path) + + self.file_path.parent.mkdir(parents=True, exist_ok=True) + + assert ( + self.name + ), "You tried to instantiate a Dataset object, but no name has been provided." + assert ( + self.label + ), "You tried to instantiate a Dataset object, but no label has been provided." 
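+        # data_format must be one of the class constants defined above:
+        # TABLES (one DataFrame per table), ARRAYS (one array per variable),
+        # TIME_PERIOD_ARRAYS (one array per variable per time period) or
+        # FLAT_FILE (a single CSV-backed DataFrame).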
+ + assert self.data_format in [ + Dataset.TABLES, + Dataset.ARRAYS, + Dataset.TIME_PERIOD_ARRAYS, + Dataset.FLAT_FILE, + ], f"You tried to instantiate a Dataset object, but your data_format attribute is invalid ({self.data_format})." + + self._table_cache = {} + + if not self.exists and require: + if self.url is not None: + self.download() + else: + self.generate() + + def load( + self, key: str = None, mode: str = "r" + ) -> Union[h5py.File, np.array, pd.DataFrame, pd.HDFStore]: + """Loads the dataset for a given year, returning a H5 file reader. You can then access the + dataset like a dictionary (e.g.e Dataset.load(2022)["variable"]). + + Args: + key (str, optional): The key to load. Defaults to None. + mode (str, optional): The mode to open the file with. Defaults to "r". + + Returns: + Union[h5py.File, np.array, pd.DataFrame, pd.HDFStore]: The dataset. + """ + file = self.file_path + if self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS): + if key is None: + # If no key provided, return the basic H5 reader. + return h5py.File(file, mode=mode) + else: + # If key provided, return only the values requested. + with h5py.File(file, mode=mode) as f: + values = np.array(f[key]) + return values + elif self.data_format == Dataset.TABLES: + if key is None: + # Non-openfisca datasets are assumed to be of the format (table name: [table], ...). + return pd.HDFStore(file) + else: + if key in self._table_cache: + return self._table_cache[key] + # If a table name is provided, return that table. + with pd.HDFStore(file) as f: + values = f[key] + self._table_cache[key] = values + return values + elif self.data_format == Dataset.FLAT_FILE: + if key is None: + return pd.read_csv(file) + else: + raise ValueError( + "You tried to load a key from a flat file dataset, but flat file datasets do not support keys." + ) + else: + raise ValueError( + f"Invalid data format {self.data_format} for dataset {self.label}." + ) + + def save(self, key: str, values: Union[np.array, pd.DataFrame]): + """Overwrites the values for `key` with `values`. + + Args: + key (str): The key to save. + values (Union[np.array, pd.DataFrame]): The values to save. + """ + file = self.file_path + if self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS): + with h5py.File(file, "a") as f: + # Overwrite if existing + if key in f: + del f[key] + f.create_dataset(key, data=values) + elif self.data_format == Dataset.TABLES: + with pd.HDFStore(file, "a") as f: + f.put(key, values) + self._table_cache = {} + elif self.data_format == Dataset.FLAT_FILE: + values.to_csv(file, index=False) + else: + raise ValueError( + f"Invalid data format {self.data_format} for dataset {self.label}." + ) + + def save_dataset(self, data, file_path: str = None) -> None: + """Writes a complete dataset to disk. + + Args: + data: The data to save. + + >>> example_data: Dict[str, Dict[str, Sequence]] = { + ... "employment_income": { + ... "2022": np.array([25000, 25000, 30000, 30000]), + ... }, + ... 
} + >>> example_data["employment_income"]["2022"] = [25000, 25000, 30000, 30000] + """ + if file_path is not None: + file = Path(file_path) + elif not isinstance(self.file_path, Path): + self.file_path = Path(self.file_path) + file = self.file_path + if self.data_format == Dataset.TABLES: + for table_name, dataframe in data.items(): + self.save(table_name, dataframe) + elif self.data_format == Dataset.TIME_PERIOD_ARRAYS: + with h5py.File(file, "w") as f: + for variable, values in data.items(): + for time_period, value in values.items(): + key = f"{variable}/{time_period}" + # Overwrite if existing + if key in f: + del f[key] + try: + f.create_dataset(key, data=value) + except: + raise ValueError( + f"Could not save {key} to {file}. The value is {value}." + ) + elif self.data_format == Dataset.ARRAYS: + with h5py.File(file, "a" if file.exists() else "w") as f: + for variable, value in data.items(): + # Overwrite if existing + if variable in f: + del f[variable] + try: + f.create_dataset(variable, data=value) + except: + raise ValueError( + f"Could not save {variable} to {file}. The value is {value}." + ) + elif self.data_format == Dataset.FLAT_FILE: + data.to_csv(file, index=False) + + def load_dataset( + self, + ): + """Loads a complete dataset from disk. + + Returns: + Dict[str, Dict[str, Sequence]]: The dataset. + """ + file = self.file_path + if self.data_format == Dataset.TABLES: + with pd.HDFStore(file) as f: + data = {table_name: f[table_name] for table_name in f.keys()} + elif self.data_format == Dataset.TIME_PERIOD_ARRAYS: + with h5py.File(file, "r") as f: + data = {} + for variable in f.keys(): + data[variable] = {} + for time_period in f[variable].keys(): + key = f"{variable}/{time_period}" + data[variable][time_period] = np.array(f[key]) + elif self.data_format == Dataset.ARRAYS: + with h5py.File(file, "r") as f: + data = { + variable: np.array(f[variable]) for variable in f.keys() + } + return data + + def generate(self): + """Generates the dataset for a given year (all datasets should implement this method). + + Raises: + NotImplementedError: If the function has not been overriden. + """ + raise NotImplementedError( + f"You tried to generate the dataset for {self.label}, but no dataset generation implementation has been provided for {self.label}." + ) + + @property + def exists(self) -> bool: + """Checks whether the dataset exists. + + Returns: + bool: Whether the dataset exists. + """ + return self.file_path.exists() + + @property + def variables(self) -> List[str]: + """Returns the variables in the dataset. + + Returns: + List[str]: The variables in the dataset. + """ + if self.data_format == Dataset.TABLES: + with pd.HDFStore(self.file_path) as f: + return list(f.keys()) + elif self.data_format in (Dataset.ARRAYS, Dataset.TIME_PERIOD_ARRAYS): + with h5py.File(self.file_path, "r") as f: + return list(f.keys()) + elif self.data_format == Dataset.FLAT_FILE: + return pd.read_csv(self.file_path, nrows=0).columns.tolist() + else: + raise ValueError( + f"Invalid data format {self.data_format} for dataset {self.label}." + ) + + def __getattr__(self, name): + """Allows the dataset to be accessed like a dictionary. + + Args: + name (str): The key to access. + + Returns: + Union[np.array, pd.DataFrame]: The dataset. + """ + return self.load(name) + + def store_file(self, file_path: str): + """Moves a file to the dataset's file path. + + Args: + file_path (str): The file path to move. 
+ """ + + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"File {file_path} does not exist.") + shutil.move(file_path, self.file_path) + + def download(self, url: str = None, version: str = None) -> None: + """Downloads a file to the dataset's file path. + + Args: + url (str): The url to download. + """ + + if url is None: + url = self.url + + if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os.environ: + auth_headers = {} + else: + auth_headers = { + "Authorization": f"token {os.environ['POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN']}", + } + + # "release://" is a special protocol for downloading from GitHub releases + # e.g. release://policyengine/policyengine-us/cps-2023/cps_2023.h5 + # release://org/repo/release_tag/file_path + # Use the GitHub API to get the download URL for the release asset + + if url.startswith("release://"): + org, repo, release_tag, file_path = url.split("/")[2:] + url = f"https://api.github.com/repos/{org}/{repo}/releases/tags/{release_tag}" + response = requests.get(url, headers=auth_headers) + if response.status_code != 200: + raise ValueError( + f"Invalid response code {response.status_code} for url {url}." + ) + assets = response.json()["assets"] + for asset in assets: + if asset["name"] == file_path: + url = asset["url"] + break + else: + raise ValueError( + f"File {file_path} not found in release {release_tag} of {org}/{repo}." + ) + elif url.startswith("hf://"): + owner_name, model_name, file_name = url.split("/")[2:] + self.download_from_huggingface( + owner_name, model_name, file_name, version + ) + return + else: + url = url + + response = requests.get( + url, + headers={ + "Accept": "application/octet-stream", + **auth_headers, + }, + ) + + if response.status_code != 200: + raise ValueError( + f"Invalid response code {response.status_code} for url {url}." + ) + + atomic_write(self.file_path, response.content) + + def upload(self, url: str = None): + """Uploads the dataset to a URL. + + Args: + url (str): The url to upload. + """ + if url is None: + url = self.url + + if url.startswith("hf://"): + owner_name, model_name, file_name = url.split("/")[2:] + self.upload_to_huggingface(owner_name, model_name, file_name) + + def remove(self): + """Removes the dataset from disk.""" + if self.exists: + self.file_path.unlink() + + @staticmethod + def from_file(file_path: str, time_period: str = None): + """Creates a dataset from a file. + + Args: + file_path (str): The file path to create the dataset from. + + Returns: + Dataset: The dataset. + """ + file_path = Path(file_path) + + # If it's a h5 file, check the first key + + if file_path.suffix == ".h5": + with h5py.File(file_path, "r") as f: + first_key = list(f.keys())[0] + first_value = f[first_key] + if isinstance(first_value, h5py.Dataset): + data_format = Dataset.ARRAYS + else: + data_format = Dataset.TIME_PERIOD_ARRAYS + subkeys = list(first_value.keys()) + if len(subkeys) > 0: + time_period = subkeys[0] + else: + data_format = Dataset.FLAT_FILE + dataset = type( + "Dataset", + (Dataset,), + { + "name": file_path.stem, + "label": file_path.stem, + "data_format": data_format, + "file_path": file_path, + "time_period": time_period, + }, + )() + + return dataset + + @staticmethod + def from_dataframe(dataframe: pd.DataFrame, time_period: str = None): + """Creates a dataset from a DataFrame. + + Returns: + Dataset: The dataset. 
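+
+        Example (illustrative; assumes pandas is imported as pd):
+            >>> ds = Dataset.from_dataframe(pd.DataFrame({"age": [30, 40]}), time_period="2023")
+            >>> ds.data_format == Dataset.FLAT_FILE
+            True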
+ """ + dataset = type( + "Dataset", + (Dataset,), + { + "name": "dataframe", + "label": "DataFrame", + "data_format": Dataset.FLAT_FILE, + "file_path": "dataframe", + "time_period": time_period, + "load": lambda self: dataframe, + }, + )() + + return dataset + + def upload_to_huggingface( + self, owner_name: str, model_name: str, file_name: str + ): + """Uploads the dataset to HuggingFace. + + Args: + owner_name (str): The owner name. + model_name (str): The model name. + """ + + print( + f"Uploading to HuggingFace {owner_name}/{model_name}/{file_name}", + file=sys.stderr, + ) + + token = get_or_prompt_hf_token() + api = HfApi() + + api.upload_file( + path_or_fileobj=self.file_path, + path_in_repo=file_name, + repo_id=f"{owner_name}/{model_name}", + repo_type="model", + token=token, + ) + + def download_from_huggingface( + self, + owner_name: str, + model_name: str, + file_name: str, + version: str = None, + ): + """Downloads the dataset from HuggingFace. + + Args: + owner_name (str): The owner name. + model_name (str): The model name. + """ + + print( + f"Downloading from HuggingFace {owner_name}/{model_name}/{file_name}", + file=sys.stderr, + ) + + download_huggingface_dataset( + repo=f"{owner_name}/{model_name}", + repo_filename=file_name, + version=version, + local_dir=self.file_path.parent, + ) diff --git a/src/policyengine_data/multi_year_dataset.py b/src/policyengine_data/multi_year_dataset.py index 8b13789..76e2119 100644 --- a/src/policyengine_data/multi_year_dataset.py +++ b/src/policyengine_data/multi_year_dataset.py @@ -1 +1,167 @@ +""" +Class for handling multi-year datasets in PolicyEngine. +""" +import shutil +from pathlib import Path +from typing import Dict, List, Optional + +import h5py +import numpy as np +import pandas as pd + +from policyengine_data.single_year_dataset import SingleYearDataset + + +class MultiYearDataset: + datasets: Dict[int, SingleYearDataset] + + def __init__( + self, + file_path: Optional[str] = None, + datasets: Optional[List[SingleYearDataset]] = None, + ): + if datasets is not None: + self.datasets = {} + for dataset in datasets: + if not isinstance(dataset, SingleYearDataset): + raise TypeError( + "All items in datasets must be of type SingleYearDataset." 
+ ) + year = dataset.time_period + self.datasets[year] = dataset + + if file_path is not None: + self.validate_file_path(file_path) + with pd.HDFStore(file_path) as f: + self.datasets = {} + + # First, discover all years and entities in the file + years_entities = {} # {year: {entity_name: df}} + + for key in f.keys(): + parts = key.strip("/").split("/") + if len(parts) == 2 and parts[0] != "time_period": + entity_name, year_str = parts + year = int(year_str) + if year not in years_entities: + years_entities[year] = {} + years_entities[year][entity_name] = f[key] + + # Create SingleYearDataset for each year + for year, entities in years_entities.items(): + self.datasets[year] = SingleYearDataset( + entities=entities, + time_period=year, + ) + + self.data_format = "time_period_arrays" # remove once -core does not expect different data formats + self.time_period = ( + list(sorted(self.datasets.keys()))[0] if self.datasets else None + ) + + def get_year(self, time_period: int) -> "SingleYearDataset": + if time_period in self.datasets: + return self.datasets[time_period] + else: + raise ValueError(f"No dataset found for year {time_period}.") + + def __getitem__(self, time_period: int) -> "SingleYearDataset": + return self.get_year(time_period) + + def save(self, file_path: str) -> None: + Path(file_path).unlink( + missing_ok=True + ) # Remove existing file if it exists + with pd.HDFStore(file_path) as f: + for year, dataset in self.datasets.items(): + for entity_name, entity_df in dataset.entities.items(): + f.put( + f"{entity_name}/{year}", + entity_df, + format="table", + data_columns=True, + ) + f.put( + f"time_period/{year}", + pd.Series([year]), + format="table", + data_columns=True, + ) + + def copy(self) -> "MultiYearDataset": + new_datasets = { + year: dataset.copy() for year, dataset in self.datasets.items() + } + return MultiYearDataset(datasets=list(new_datasets.values())) + + @staticmethod + def validate_file_path(file_path: str) -> None: + if not file_path.endswith(".h5"): + raise ValueError( + "File path must end with '.h5' for MultiYearDataset." + ) + if not Path(file_path).exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + # Check if the file contains datasets for multiple years + with h5py.File(file_path, "r") as f: + required_entities = ["person", "household"] + for entity in required_entities: + if entity not in f: + raise ValueError( + f"No data for '{entity}' found in file: {file_path}" + ) + entity_group = f[entity] + if not any(key.isdigit() for key in entity_group.keys()): + raise ValueError( + f"No year data for '{entity}' found in file: {file_path}" + ) + + def load(self) -> Dict[str, Dict[int, np.ndarray]]: + data = {} + for year, dataset in self.datasets.items(): + for entity_name, entity_df in dataset.entities.items(): + for col in entity_df.columns: + if col not in data: + data[col] = {} + data[col][year] = entity_df[col].values + return data + + def remove(self) -> None: + """Removes the dataset from disk.""" + if self.exists(): + self.file_path.unlink() + + def store_file(self, file_path: str): + """Moves a file to the dataset's file path. + + Args: + file_path (str): The file path to move. + """ + + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"File {file_path} does not exist.") + shutil.move(file_path, self.file_path) + + @property + def variables(self) -> Dict[int, Dict[str, List[str]]]: + """ + Returns a dictionary mapping years to entity variables dictionaries. 
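+
+        Illustrative shape (hypothetical variable names):
+            {2023: {"person": ["age", "income"], "household": ["housing_cost", "state"]}}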
+ """ + variables_by_year = {} + + for year, dataset in self.datasets.items(): + variables_by_year[year] = dataset.variables + + return variables_by_year + + @property + def exists(self) -> bool: + """Checks whether the dataset exists. + + Returns: + bool: Whether the dataset exists. + """ + return self.file_path.exists() diff --git a/src/policyengine_data/single_year_dataset.py b/src/policyengine_data/single_year_dataset.py index 8b13789..7639b35 100644 --- a/src/policyengine_data/single_year_dataset.py +++ b/src/policyengine_data/single_year_dataset.py @@ -1 +1,164 @@ +""" +Class for handling single-year datasets in PolicyEngine. +""" +import shutil +from pathlib import Path +from typing import Dict, List, Optional + +import h5py +import pandas as pd +from policyengine_core.simulations import Microsimulation + + +class SingleYearDataset: + entities: Dict[str, pd.DataFrame] + time_period: int # manually convert to str when -core expects str + + def __init__( + self, + file_path: Optional[str] = None, + entities: Optional[Dict[str, pd.DataFrame]] = None, + time_period: Optional[int] = 2025, + ) -> None: + self.entities: Dict[str, pd.DataFrame] = {} + + if file_path is not None: + self.validate_file_path(file_path) + with pd.HDFStore(file_path) as f: + self.time_period = int(f["time_period"].iloc[0]) + # Load all entities from the file (except time_period) + for key in f.keys(): + if key != "/time_period": + entity_name = key.strip("/") + self.entities[entity_name] = f[entity_name] + else: + if entities is None: + raise ValueError( + "Must provide either a file path or a dictionary of entities' dataframes." + ) + self.entities = entities.copy() + self.time_period = time_period + + self.data_format = "arrays" # remove once -core does not expect different data formats + self.tables = tuple(self.entities.values()) + self.table_names = tuple(self.entities.keys()) + + @staticmethod + def validate_file_path(file_path: str) -> None: + if not file_path.endswith(".h5"): + raise ValueError("File path must end with '.h5' for Dataset.") + if not Path(file_path).exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + with h5py.File(file_path, "r") as f: + required_datasets = [ + "time_period", + "person", + "household", + ] # all datasets will have at least person and household entities + for dataset in required_datasets: + if dataset not in f: + raise ValueError( + f"Dataset '{dataset}' not found in the file: {file_path}" + ) + + def save(self, file_path: str) -> None: + with pd.HDFStore(file_path) as f: + for entity, df in self.entities.items(): + f.put(entity, df, format="table", data_columns=True) + f.put("time_period", pd.Series([self.time_period]), format="table") + + def load(self) -> Dict[str, pd.Series]: + data = {} + for entity_name, entity_df in self.entities.items(): + for col in entity_df.columns: + data[col] = entity_df[col].values + + return data + + def copy(self) -> "SingleYearDataset": + return SingleYearDataset( + entities={name: df.copy() for name, df in self.entities.items()}, + time_period=self.time_period, + ) + + def remove(self) -> None: + """Removes the dataset from disk.""" + if self.exists(): + self.file_path.unlink() + + def store_file(self, file_path: str): + """Moves a file to the dataset's file path. + + Args: + file_path (str): The file path to move. 
+ """ + + file_path = Path(file_path) + if not file_path.exists(): + raise FileNotFoundError(f"File {file_path} does not exist.") + shutil.move(file_path, self.file_path) + + def validate(self) -> None: + # Check for NaNs in the tables + for df in self.tables: + for col in df.columns: + if df[col].isna().any(): + raise ValueError(f"Column '{col}' contains NaN values.") + + @staticmethod + def from_simulation( + simulation: "Microsimulation", + time_period: int = 2025, + entity_names_to_include: Optional[List[str]] = None, + ) -> "SingleYearDataset": + entity_dfs = {} + + # If no entity names specified, use all available entities + if entity_names_to_include is None: + entity_names = list( + set( + simulation.tax_benefit_system.variables[var].entity.key + for var in simulation.input_variables + ) + ) + else: + entity_names = entity_names_to_include + + for entity in entity_names: + input_variables = [ + variable + for variable in simulation.input_variables + if simulation.tax_benefit_system.variables[variable].entity.key + == entity + ] + entity_dfs[entity] = simulation.calculate_dataframe( + input_variables, period=time_period + ) + + return SingleYearDataset( + entities=entity_dfs, + time_period=time_period, + ) + + @property + def variables(self) -> Dict[str, List[str]]: + """ + Returns a dictionary mapping entity names to lists of variables (column names). + """ + variables_by_entity = {} + + for entity_name, entity_df in self.entities.items(): + variables_by_entity[entity_name] = entity_df.columns.tolist() + + return variables_by_entity + + @property + def exists(self) -> bool: + """Checks whether the dataset exists. + + Returns: + bool: Whether the dataset exists. + """ + return self.file_path.exists() diff --git a/src/policyengine_data/tools/__init__.py b/src/policyengine_data/tools/__init__.py new file mode 100644 index 0000000..59fb25b --- /dev/null +++ b/src/policyengine_data/tools/__init__.py @@ -0,0 +1,2 @@ +from .hugging_face import download_huggingface_dataset, get_or_prompt_hf_token +from .win_file_manager import WindowsAtomicFileManager diff --git a/src/policyengine_data/tools/hugging_face.py b/src/policyengine_data/tools/hugging_face.py new file mode 100644 index 0000000..b55a385 --- /dev/null +++ b/src/policyengine_data/tools/hugging_face.py @@ -0,0 +1,78 @@ +import os +import traceback +import warnings +from getpass import getpass + +from huggingface_hub import ModelInfo, hf_hub_download, model_info +from huggingface_hub.errors import RepositoryNotFoundError + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + +def download_huggingface_dataset( + repo: str, + repo_filename: str, + version: str = None, + local_dir: str | None = None, +): + """ + Download a dataset from the Hugging Face Hub. + + Args: + repo (str): The Hugging Face repo name, in format "{org}/{repo}". + repo_filename (str): The filename of the dataset. + version (str, optional): The version of the dataset. Defaults to None. + local_dir (str, optional): The local directory to save the dataset to. Defaults to None. + """ + # Attempt connection to Hugging Face model_info endpoint + # (https://huggingface.co/docs/huggingface_hub/v0.26.5/en/package_reference/hf_api#huggingface_hub.HfApi.model_info) + # Attempt to fetch model info to determine if repo is private + # A RepositoryNotFoundError & 401 likely means the repo is private, + # but this error will also surface for public repos with malformed URL, etc. 
+    try:
+        fetched_model_info: ModelInfo = model_info(repo)
+        is_repo_private: bool = fetched_model_info.private
+    except RepositoryNotFoundError:
+        # If this error type arises, it's likely the repo is private; see docs above
+        is_repo_private = True
+        pass
+    except Exception:
+        # Otherwise, there probably is just a download error
+        raise Exception(
+            f"Unable to download dataset {repo_filename} from Hugging Face. This may be because the repo "
+            + f"is private, the URL is malformed, or the dataset does not exist. The full error is {traceback.format_exc()}"
+        )
+
+    authentication_token: str | None = None
+    if is_repo_private:
+        authentication_token = get_or_prompt_hf_token()
+
+    return hf_hub_download(
+        repo_id=repo,
+        repo_type="model",
+        filename=repo_filename,
+        revision=version,
+        token=authentication_token,
+        local_dir=local_dir,
+    )
+
+
+def get_or_prompt_hf_token() -> str:
+    """
+    Either get the Hugging Face token from the environment,
+    or prompt the user for it and store it in the environment.
+
+    Returns:
+        str: The Hugging Face token.
+    """
+
+    token = os.environ.get("HUGGING_FACE_TOKEN")
+    if token is None:
+        token = getpass(
+            "Enter your Hugging Face token (or set HUGGING_FACE_TOKEN environment variable): "
+        )
+        # Store in the environment for subsequent calls in the same session
+        os.environ["HUGGING_FACE_TOKEN"] = token
+
+    return token
diff --git a/src/policyengine_data/tools/win_file_manager.py b/src/policyengine_data/tools/win_file_manager.py
new file mode 100644
index 0000000..9228dcd
--- /dev/null
+++ b/src/policyengine_data/tools/win_file_manager.py
@@ -0,0 +1,46 @@
+import os
+import tempfile
+from pathlib import Path
+from threading import Lock
+
+
+class WindowsAtomicFileManager:
+    """
+    Atomic file writer based on https://stackoverflow.com/a/2368286.
+    - Each instance manages a specific logical file by name.
+    - Files are written atomically via a temporary file.
+    - Thread safety is ensured with a `Lock`.
+    - Temporary files are created in the target file's directory.
+    - The target file is replaced with the temporary file using `os.replace`.
+    - The temporary file is cleaned up if the replacement fails.
+    """
+
+    def __init__(self, file: Path):
+        self.logical_name = file.name
+        self.target_path = file
+        self.lock = Lock()
+
+    def write(self, content: bytes):
+        with self.lock:
+            with tempfile.NamedTemporaryFile(
+                mode="wb",
+                dir=self.target_path.parent.absolute().as_posix(),
+                delete=False,
+            ) as temp_file:
+                temp_file.write(content)
+                temp_file.flush()
+                os.fsync(temp_file.fileno())
+                temp_file_path = Path(temp_file.name)
+
+            try:
+                os.replace(temp_file_path, self.target_path)
+            except Exception as e:
+                temp_file_path.unlink(missing_ok=True)
+                raise e
+
+    def read(self) -> bytes:
+        with self.lock:
+            if not self.target_path.exists():
+                raise FileNotFoundError(f"{self.target_path} does not exist.")
+            with open(self.target_path, "rb") as file:
+                return file.read()
diff --git a/tests/test_data_download_upload_tools.py b/tests/test_data_download_upload_tools.py
new file mode 100644
index 0000000..9013fc5
--- /dev/null
+++ b/tests/test_data_download_upload_tools.py
@@ -0,0 +1,77 @@
+"""
+Test data download and upload tools' functionality.
+""" + +from pathlib import Path +from tempfile import NamedTemporaryFile +import sys +import threading +from policyengine_core.tools.win_file_manager import WindowsAtomicFileManager +import tempfile +from pathlib import Path +import uuid + + +def test_atomic_write(): + if sys.platform != "win32": + from policyengine_core.data.dataset import atomic_write + + with NamedTemporaryFile(mode="w") as file: + file.write("Hello, world\n") + file.flush() + # Open the file before overwriting + with open(file.name, "r") as file_original: + + atomic_write(Path(file.name), "NOPE\n".encode()) + + # Open file descriptor still points to the old node + assert file_original.readline() == "Hello, world\n" + # But if I open it again it has the new content + with open(file.name, "r") as file_updated: + assert file_updated.readline() == "NOPE\n" + + +def test_atomic_write_windows(): + if sys.platform == "win32": + temp_dir = Path(tempfile.gettempdir()) + temp_files = [ + temp_dir / f"tempfile_{uuid.uuid4().hex}.tmp" for _ in range(5) + ] + + managers = [WindowsAtomicFileManager(path) for path in temp_files] + + contents_list = [ + [f"Content_{i}_{j}".encode() for j in range(5)] for i in range(5) + ] + + check_results = [[] for _ in range(5)] + + threads = [] + for i in range(5): + thread = threading.Thread( + target=file_task, + args=(managers[i], contents_list[i], check_results[i]), + ) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + for i, results in enumerate(check_results): + for expected, actual in results: + assert ( + expected == actual + ), f"Mismatch in file {i}: expected {expected}, got {actual}" + + for temp_file in temp_files: + if temp_file.exists(): + temp_file.unlink() + + +def file_task(manager, contents, check_results): + for content in contents: + manager.write(content) + actual_content = manager.read().decode() + expected_content = content.decode() + check_results.append((expected_content, actual_content)) diff --git a/tests/test_dataset_classes.py b/tests/test_dataset_classes.py new file mode 100644 index 0000000..5738b0b --- /dev/null +++ b/tests/test_dataset_classes.py @@ -0,0 +1,111 @@ +""" +Test cases for SingleYearDataset and MultiYearDataset classes. 
+""" + + +def test_single_year_dataset() -> None: + from policyengine_data.single_year_dataset import SingleYearDataset + import pandas as pd + + # Create a sample dataset + entities = { + "person": pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}), + "household": pd.DataFrame({"id": [1], "income": [50000]}), + } + time_period = 2023 + + # Initialize SingleYearDataset + dataset = SingleYearDataset(entities=entities, time_period=time_period) + dataset.validate() + + # Check if entities are correctly set + assert len(dataset.entities) == 2 + assert "person" in dataset.entities + assert "household" in dataset.entities + pd.testing.assert_frame_equal( + dataset.entities["person"], entities["person"] + ) + pd.testing.assert_frame_equal( + dataset.entities["household"], entities["household"] + ) + assert dataset.time_period == time_period + + # Save the dataset to a file + file_path = "test_single_year_dataset.h5" + dataset.save(file_path) + + loaded_dataset = SingleYearDataset(file_path=file_path) + loaded_dataset.validate() + + # Check if loaded entities match original entities + assert len(loaded_dataset.entities) == len(entities) + for entity_name in entities: + pd.testing.assert_frame_equal( + loaded_dataset.entities[entity_name], entities[entity_name] + ) + assert loaded_dataset.time_period == time_period + + variables = dataset.variables + assert variables.keys() == {"person", "household"} + + +def test_multi_year_dataset() -> None: + from policyengine_data.multi_year_dataset import MultiYearDataset + from policyengine_data.single_year_dataset import SingleYearDataset + import pandas as pd + + # Create SingleYearDataset instances for multiple years + entities_2023 = { + "person": pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}), + "household": pd.DataFrame({"id": [1], "income": [50000]}), + } + entities_2024 = { + "person": pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]}), + "household": pd.DataFrame({"id": [1], "income": [55000]}), + } + + dataset_2023 = SingleYearDataset(entities=entities_2023, time_period=2023) + dataset_2024 = SingleYearDataset(entities=entities_2024, time_period=2024) + + # Initialize MultiYearDataset with list of SingleYearDataset instances + multi_dataset = MultiYearDataset(datasets=[dataset_2023, dataset_2024]) + + # Check if datasets are correctly set + assert len(multi_dataset.datasets) == 2 + assert 2023 in multi_dataset.datasets + assert 2024 in multi_dataset.datasets + + retrieved_2023 = multi_dataset.get_year(2023) + assert isinstance(retrieved_2023, SingleYearDataset) + pd.testing.assert_frame_equal( + retrieved_2023.entities["person"], entities_2023["person"] + ) + retrieved_2024 = multi_dataset[2024] + assert isinstance(retrieved_2024, SingleYearDataset) + pd.testing.assert_frame_equal( + retrieved_2024.entities["household"], entities_2024["household"] + ) + + # Save the dataset to a file + file_path = "test_multi_year_dataset.h5" + multi_dataset.save(file_path) + loaded_multi_dataset = MultiYearDataset(file_path=file_path) + + # Check if loaded datasets match original datasets + assert len(loaded_multi_dataset.datasets) == 2 + assert 2023 in loaded_multi_dataset.datasets + assert 2024 in loaded_multi_dataset.datasets + + for year in [2023, 2024]: + loaded_year_data = loaded_multi_dataset[year] + original_year_data = multi_dataset[year] + + for entity_name in original_year_data.entities: + pd.testing.assert_frame_equal( + loaded_year_data.entities[entity_name], + original_year_data.entities[entity_name], + ) + + variables_by_year = 
multi_dataset.variables + assert variables_by_year.keys() == {2023, 2024} + assert ["person", "household"] == list(variables_by_year[2023].keys()) diff --git a/tests/test_hugging_face.py b/tests/test_hugging_face.py new file mode 100644 index 0000000..d8c29bf --- /dev/null +++ b/tests/test_hugging_face.py @@ -0,0 +1,162 @@ +""" +Test Hugging Face tools. +""" + +import os +import pytest +from unittest.mock import patch +from huggingface_hub import ModelInfo +from huggingface_hub.errors import RepositoryNotFoundError +from policyengine_core.tools.hugging_face import ( + get_or_prompt_hf_token, + download_huggingface_dataset, +) + + +class TestHuggingFaceDownload: + def test_download_public_repo(self): + """Test downloading from a public repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" + test_dir = "test_dir" + + with patch( + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + # Create mock ModelInfo object emulating public repo + test_id = 0 + mock_model_info.return_value = ModelInfo( + id=test_id, private=False + ) + + download_huggingface_dataset( + test_repo, test_filename, test_version, test_dir + ) + + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + local_dir=test_dir, + token=None, + ) + + def test_download_private_repo(self): + """Test downloading from a private repo""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" + test_dir = "test_dir" + + with patch( + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "test_token" + + download_huggingface_dataset( + test_repo, test_filename, test_version, test_dir + ) + mock_download.assert_called_with( + repo_id=test_repo, + repo_type="model", + filename=test_filename, + revision=test_version, + token=mock_token.return_value, + local_dir=test_dir, + ) + + def test_download_private_repo_no_token(self): + """Test handling of private repo with no token""" + test_repo = "test_repo" + test_filename = "test_filename" + test_version = "test_version" + test_dir = "test_dir" + + with patch( + "policyengine_core.tools.hugging_face.hf_hub_download" + ) as mock_download: + with patch( + "policyengine_core.tools.hugging_face.model_info" + ) as mock_model_info: + mock_model_info.side_effect = RepositoryNotFoundError( + "Test error" + ) + with patch( + "policyengine_core.tools.hugging_face.get_or_prompt_hf_token" + ) as mock_token: + mock_token.return_value = "" + + with pytest.raises(Exception): + download_huggingface_dataset( + test_repo, test_filename, test_version, test_dir + ) + mock_download.assert_not_called() + + +class TestGetOrPromptHfToken: + def test_get_token_from_environment(self): + """Test retrieving token when it exists in environment variables""" + test_token = "test_token_123" + with patch.dict( + os.environ, {"HUGGING_FACE_TOKEN": test_token}, clear=True + ): + result = get_or_prompt_hf_token() + assert result == test_token + + def test_get_token_from_user_input(self): + """Test retrieving token via user input when not in environment""" + 
test_token = "user_input_token_456" + + # Mock both empty environment and user input + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + result = get_or_prompt_hf_token() + assert result == test_token + + # Verify token was stored in environment + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token + + def test_empty_user_input(self): + """Test handling of empty user input""" + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", return_value="" + ): + result = get_or_prompt_hf_token() + assert result == "" + assert os.environ.get("HUGGING_FACE_TOKEN") == "" + + def test_environment_variable_persistence(self): + """Test that environment variable persists across multiple calls""" + test_token = "persistence_test_token" + + # First call with no environment variable + with patch.dict(os.environ, {}, clear=True): + with patch( + "policyengine_core.tools.hugging_face.getpass", + return_value=test_token, + ): + first_result = get_or_prompt_hf_token() + + # Second call should use environment variable + second_result = get_or_prompt_hf_token() + + assert first_result == second_result == test_token + assert os.environ.get("HUGGING_FACE_TOKEN") == test_token
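
Below is a minimal usage sketch (not part of the diff) showing how the new SingleYearDataset and MultiYearDataset classes introduced above might be exercised end to end; the entity column names and output path are illustrative, and it assumes pandas plus PyTables (for HDF5 table storage) are installed.

# Usage sketch (illustrative): build, save, and reload the new dataset classes.
import pandas as pd

from policyengine_data.single_year_dataset import SingleYearDataset
from policyengine_data.multi_year_dataset import MultiYearDataset

# Two years of toy data; every dataset carries at least person and household entities.
entities_2023 = {
    "person": pd.DataFrame({"person_id": [1, 2], "age": [34, 29]}),
    "household": pd.DataFrame({"household_id": [1], "income": [50_000]}),
}
entities_2024 = {
    "person": pd.DataFrame({"person_id": [1, 2], "age": [35, 30]}),
    "household": pd.DataFrame({"household_id": [1], "income": [55_000]}),
}

year_2023 = SingleYearDataset(entities=entities_2023, time_period=2023)
year_2024 = SingleYearDataset(entities=entities_2024, time_period=2024)
year_2023.validate()  # raises ValueError if any column contains NaNs

# Combine the years, write them to HDF5, then round-trip from disk.
multi = MultiYearDataset(datasets=[year_2023, year_2024])
multi.save("example_multi_year.h5")
reloaded = MultiYearDataset(file_path="example_multi_year.h5")

print(reloaded[2024].entities["household"])  # indexing by year returns a SingleYearDataset
print(reloaded.variables[2023])              # {"person": [...], "household": [...]}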
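
And a short sketch of the atomic-write helper; despite the Windows-oriented name it relies only on os.replace, so it should behave the same outside Windows. The target path here is illustrative.

# Usage sketch (illustrative): writes that readers never see half-finished.
from pathlib import Path
from tempfile import gettempdir

from policyengine_data.tools import WindowsAtomicFileManager

target = Path(gettempdir()) / "example_dataset.h5"
manager = WindowsAtomicFileManager(target)

# Each write stages the bytes in a temporary file next to the target,
# flushes and fsyncs it, then promotes it with os.replace.
manager.write(b"first version")
manager.write(b"second version")

assert manager.read() == b"second version"  # read() returns the raw bytes
target.unlink()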