diff --git a/pdap_api_client/AccessManager.py b/pdap_api_client/AccessManager.py new file mode 100644 index 00000000..87877466 --- /dev/null +++ b/pdap_api_client/AccessManager.py @@ -0,0 +1,123 @@ +from http import HTTPStatus +from typing import Optional + +import requests + +from pdap_api_client.DTOs import RequestType, Namespaces, RequestInfo, ResponseInfo + +API_URL = "https://data-sources-v2.pdap.dev/api" +request_methods = { + RequestType.POST: requests.post, + RequestType.PUT: requests.put, + RequestType.GET: requests.get, + RequestType.DELETE: requests.delete, +} + + +class CustomHTTPException(Exception): + pass + + +def build_url( + namespace: Namespaces, + subdomains: Optional[list[str]] = None +) -> str: + url = f"{API_URL}/{namespace.value}" + if subdomains is not None: + url = f"{url}/{'/'.join(subdomains)}" + return url + + +class AccessManager: + """ + Manages login, api key, access and refresh tokens + """ + def __init__(self, email: str, password: str, api_key: Optional[str] = None): + self.access_token = None + self.refresh_token = None + self.api_key = api_key + self.login(email=email, password=password) + + # TODO: Add means to refresh if token expired. + + def load_api_key(self): + url = build_url( + namespace=Namespaces.AUTH, + subdomains=["api-key"] + ) + request_info = RequestInfo( + type_=RequestType.POST, + url=url, + headers=self.jwt_header() + ) + response_info = self.make_request(request_info) + self.api_key = response_info.data["api_key"] + + def refresh_access_token(self): + url = build_url( + namespace=Namespaces.AUTH, + subdomains=["refresh-session"], + ) + raise NotImplementedError("Waiting on https://github.com/Police-Data-Accessibility-Project/data-sources-app/issues/566") + + def make_request(self, ri: RequestInfo) -> ResponseInfo: + try: + response = request_methods[ri.type_]( + ri.url, + json=ri.json, + headers=ri.headers, + params=ri.params, + timeout=ri.timeout + ) + response.raise_for_status() + except requests.RequestException as e: + # TODO: Precise string matching here is brittle. Consider changing later. + # e.response is None for non-HTTP failures; response.json() returns a dict, so access the message by key. + if e.response is not None and e.response.json().get("message") == "Token is expired. 
Please request a new token.": + self.refresh_access_token() + return self.make_request(ri) + else: + raise CustomHTTPException(f"Error making {ri.type_} request to {ri.url}: {e}") + return ResponseInfo( + status_code=HTTPStatus(response.status_code), + data=response.json() + ) + + def login(self, email: str, password: str): + url = build_url( + namespace=Namespaces.AUTH, + subdomains=["login"] + ) + request_info = RequestInfo( + type_=RequestType.POST, + url=url, + json={ + "email": email, + "password": password + } + ) + response_info = self.make_request(request_info) + data = response_info.data + self.access_token = data["access_token"] + self.refresh_token = data["refresh_token"] + + + def jwt_header(self) -> dict: + """ + Retrieve JWT header + Returns: Dictionary of Bearer Authorization with JWT key + """ + return { + "Authorization": f"Bearer {self.access_token}" + } + + def api_key_header(self): + """ + Retrieve API key header + Returns: Dictionary of Basic Authorization with API key + + """ + if self.api_key is None: + self.load_api_key() + return { + "Authorization": f"Basic {self.api_key}" + } diff --git a/pdap_api_client/DTOs.py b/pdap_api_client/DTOs.py new file mode 100644 index 00000000..31c8c2cf --- /dev/null +++ b/pdap_api_client/DTOs.py @@ -0,0 +1,54 @@ +from enum import Enum +from http import HTTPStatus +from typing import Optional + +from pydantic import BaseModel + + +class MatchAgencyInfo(BaseModel): + submitted_name: str + id: str + +class ApprovalStatus(Enum): + APPROVED = "approved" + REJECTED = "rejected" + PENDING = "pending" + NEEDS_IDENTIFICATION = "needs identification" + + + +class UniqueURLDuplicateInfo(BaseModel): + original_url: str + approval_status: ApprovalStatus + rejection_note: str + +class UniqueURLResponseInfo(BaseModel): + is_unique: bool + duplicates: list[UniqueURLDuplicateInfo] + + +class Namespaces(Enum): + AUTH = "auth" + MATCH = "match" + CHECK = "check" + + +class RequestType(Enum): + POST = "POST" + PUT = "PUT" + GET = "GET" + DELETE = "DELETE" + + +class RequestInfo(BaseModel): + type_: RequestType + url: str + json: Optional[dict] = None + headers: Optional[dict] = None + params: Optional[dict] = None + timeout: Optional[int] = 10 + + +class ResponseInfo(BaseModel): + status_code: HTTPStatus + data: Optional[dict] diff --git a/pdap_api_client/PDAPClient.py b/pdap_api_client/PDAPClient.py new file mode 100644 index 00000000..6c03ce0f --- /dev/null +++ b/pdap_api_client/PDAPClient.py @@ -0,0 +1,65 @@ +from typing import List + +from pdap_api_client.AccessManager import build_url, AccessManager +from pdap_api_client.DTOs import MatchAgencyInfo, UniqueURLDuplicateInfo, UniqueURLResponseInfo, Namespaces, \ + RequestType, RequestInfo + + +class PDAPClient: + + def __init__(self, access_manager: AccessManager): + self.access_manager = access_manager + + def match_agency( + self, + name: str, + state: str, + county: str, + locality: str + ) -> List[MatchAgencyInfo]: + """ + Returns agencies, if any, that match or partially match the search criteria + """ + url = build_url( + namespace=Namespaces.MATCH, + subdomains=["agency"] + ) + request_info = RequestInfo( + type_=RequestType.POST, + url=url, + json={ + "name": name, + "state": state, + "county": county, + "locality": locality + } + ) + response_info = self.access_manager.make_request(request_info) + return [MatchAgencyInfo(**agency) for agency in response_info.data["agencies"]] + + + def is_url_unique( + self, + url_to_check: str + ) -> UniqueURLResponseInfo: + """ + Check if a URL is unique. 
Returns duplicate info otherwise + """ + url = build_url( + namespace=Namespaces.CHECK, + subdomains=["unique-url"] + ) + request_info = RequestInfo( + type_=RequestType.GET, + url=url, + params={ + "url": url_to_check + } + ) + response_info = self.access_manager.make_request(request_info) + duplicates = [UniqueURLDuplicateInfo(**entry) for entry in response_info.data["duplicates"]] + is_unique = (len(duplicates) == 0) + return UniqueURLResponseInfo( + is_unique=is_unique, + duplicates=duplicates + ) diff --git a/pdap_api_client/__init__.py b/pdap_api_client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source_collectors/__init__.py b/source_collectors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source_collectors/muckrock/README.md b/source_collectors/muckrock/README.md index d74b77f0..43bae80d 100644 --- a/source_collectors/muckrock/README.md +++ b/source_collectors/muckrock/README.md @@ -56,8 +56,6 @@ pip install -r requirements.txt ### 2. Clone Muckrock database & search locally -~~- `download_muckrock_foia.py` `search_local_foia_json.py`~~ (deprecated) - - scripts to clone the MuckRock foia requests collection for fast local querying (total size <2GB at present) - `create_foia_data_db.py` creates and populates a SQLite database (`foia_data.db`) with all MuckRock foia requests. Various errors outside the scope of this script may occur; a counter (`last_page_fetched.txt`) is created to keep track of the most recent page fetched and inserted into the database. If the program exits prematurely, simply run `create_foia_data_db.py` again to continue where you left off. A log file is created to capture errors for later reference. diff --git a/source_collectors/muckrock/__init__.py b/source_collectors/muckrock/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source_collectors/muckrock/classes/FOIADBSearcher.py b/source_collectors/muckrock/classes/FOIADBSearcher.py new file mode 100644 index 00000000..391f7a8d --- /dev/null +++ b/source_collectors/muckrock/classes/FOIADBSearcher.py @@ -0,0 +1,65 @@ +import os +import sqlite3 + +import pandas as pd + +from source_collectors.muckrock.constants import FOIA_DATA_DB + +check_results_table_query = """ + SELECT name FROM sqlite_master + WHERE (type = 'table') + AND (name = 'results') + """ + +search_foia_query = """ + SELECT * FROM results + WHERE (title LIKE ? OR tags LIKE ?) + AND (status = 'done') + """ + + +class FOIADBSearcher: + + def __init__(self, db_path = FOIA_DATA_DB): + self.db_path = db_path + if not os.path.exists(self.db_path): + raise FileNotFoundError("foia_data.db does not exist.\nRun create_foia_data_db.py first to create and populate it.") + + + def search(self, search_string: str) -> pd.DataFrame | None: + """ + Searches the foia_data.db database for FOIA request entries matching the provided search string. + + Args: + search_string (str): The string to search for in the `title` and `tags` of the `results` table. + + Returns: + Union[pandas.DataFrame, None]: + - pandas.DataFrame: A DataFrame containing the matching entries from the database. + - None: If an error occurs during the database operation. + + Raises: + sqlite3.Error: If any database operation fails, prints error and returns None. + Exception: If any unexpected error occurs, prints error and returns None. 
+ """ + try: + with sqlite3.connect(self.db_path) as conn: + results_table = pd.read_sql_query(check_results_table_query, conn) + if results_table.empty: + print("The `results` table does not exist in the database.") + return None + + df = pd.read_sql_query( + sql=search_foia_query, + con=conn, + params=[f"%{search_string}%", f"%{search_string}%"] + ) + + except sqlite3.Error as e: + print(f"Sqlite error: {e}") + return None + except Exception as e: + print(f"An unexpected error occurred: {e}") + return None + + return df \ No newline at end of file diff --git a/source_collectors/muckrock/classes/FOIASearcher.py b/source_collectors/muckrock/classes/FOIASearcher.py new file mode 100644 index 00000000..f88f8242 --- /dev/null +++ b/source_collectors/muckrock/classes/FOIASearcher.py @@ -0,0 +1,58 @@ +from typing import Optional + +from source_collectors.muckrock.classes.muckrock_fetchers import FOIAFetcher +from tqdm import tqdm + +class FOIASearcher: + """ + Used for searching FOIA data from MuckRock + """ + + def __init__(self, fetcher: FOIAFetcher, search_term: Optional[str] = None): + self.fetcher = fetcher + self.search_term = search_term + + def fetch_page(self) -> dict | None: + """ + Fetches the next page of results using the fetcher. + """ + data = self.fetcher.fetch_next_page() + if data is None or data.get("results") is None: + return None + return data + + def filter_results(self, results: list[dict]) -> list[dict]: + """ + Filters the results based on the search term. + Override or modify as needed for custom filtering logic. + """ + if self.search_term: + return [result for result in results if self.search_term.lower() in result["title"].lower()] + return results + + def update_progress(self, pbar: tqdm, results: list[dict]) -> int: + """ + Updates the progress bar and returns the count of results processed. + """ + num_results = len(results) + pbar.update(num_results) + return num_results + + def search_to_count(self, max_count: int) -> list[dict]: + """ + Fetches and processes results up to a maximum count. 
+ """ + count = max_count + all_results = [] + with tqdm(total=max_count, desc="Fetching results", unit="result") as pbar: + while count > 0: + data = self.fetch_page() + if not data: + break + + results = self.filter_results(data["results"]) + all_results.extend(results) + count -= self.update_progress(pbar, results) + + return all_results + diff --git a/source_collectors/muckrock/classes/SQLiteClient.py b/source_collectors/muckrock/classes/SQLiteClient.py new file mode 100644 index 00000000..96a59d82 --- /dev/null +++ b/source_collectors/muckrock/classes/SQLiteClient.py @@ -0,0 +1,38 @@ +import logging +import sqlite3 + + +class SQLClientError(Exception): + pass + + +class SQLiteClient: + + def __init__(self, db_path: str) -> None: + self.conn = sqlite3.connect(db_path) + + def execute_query(self, query: str, many=None): + + try: + if many is not None: + self.conn.executemany(query, many) + else: + self.conn.execute(query) + self.conn.commit() + except sqlite3.Error as e: + print(f"SQLite error: {e}") + error_msg = f"Failed to execute query due to SQLite error: {e}" + logging.error(error_msg) + self.conn.rollback() + raise SQLClientError(error_msg) + +class SQLiteClientContextManager: + + def __init__(self, db_path: str) -> None: + self.client = SQLiteClient(db_path) + + def __enter__(self): + return self.client + + def __exit__(self, exc_type, exc_value, traceback): + self.client.conn.close() \ No newline at end of file diff --git a/source_collectors/muckrock/classes/__init__.py b/source_collectors/muckrock/classes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/AgencyFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/AgencyFetcher.py new file mode 100644 index 00000000..b70c07e0 --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/AgencyFetcher.py @@ -0,0 +1,14 @@ +from source_collectors.muckrock.constants import BASE_MUCKROCK_URL +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import FetchRequest, MuckrockFetcher + + +class AgencyFetchRequest(FetchRequest): + agency_id: int + +class AgencyFetcher(MuckrockFetcher): + + def build_url(self, request: AgencyFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/agency/{request.agency_id}/" + + def get_agency(self, agency_id: int): + return self.fetch(AgencyFetchRequest(agency_id=agency_id)) \ No newline at end of file diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/FOIAFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/FOIAFetcher.py new file mode 100644 index 00000000..619b92ae --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/FOIAFetcher.py @@ -0,0 +1,36 @@ +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import MuckrockFetcher, FetchRequest +from source_collectors.muckrock.constants import BASE_MUCKROCK_URL + +FOIA_BASE_URL = f"{BASE_MUCKROCK_URL}/foia" + + +class FOIAFetchRequest(FetchRequest): + page: int + page_size: int + + +class FOIAFetcher(MuckrockFetcher): + + def __init__(self, start_page: int = 1, per_page: int = 100): + """ + Constructor for the FOIAFetcher class. + + Args: + start_page (int): The page number to start fetching from (default is 1). + per_page (int): The number of results to fetch per page (default is 100). 
+ """ + self.current_page = start_page + self.per_page = per_page + + def build_url(self, request: FOIAFetchRequest) -> str: + return f"{FOIA_BASE_URL}?page={request.page}&page_size={request.page_size}&format=json" + + def fetch_next_page(self) -> dict | None: + """ + Fetches data from a specific page of the MuckRock FOIA API. + """ + page = self.current_page + self.current_page += 1 + request = FOIAFetchRequest(page=page, page_size=self.per_page) + return self.fetch(request) + diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/FOIALoopFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/FOIALoopFetcher.py new file mode 100644 index 00000000..ad78f0b6 --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/FOIALoopFetcher.py @@ -0,0 +1,31 @@ +from tqdm import tqdm + +from source_collectors.muckrock.constants import BASE_MUCKROCK_URL +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import FetchRequest +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockLoopFetcher import MuckrockLoopFetcher + +class FOIALoopFetchRequest(FetchRequest): + jurisdiction: int + +class FOIALoopFetcher(MuckrockLoopFetcher): + + def __init__(self, initial_request: FOIALoopFetchRequest): + super().__init__(initial_request) + self.pbar_records = tqdm( + desc="Fetching FOIA records", + unit="record", + ) + self.num_found = 0 + self.results = [] + + def process_results(self, results: list[dict]): + self.results.extend(results) + + def build_url(self, request: FOIALoopFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/foia/?status=done&jurisdiction={request.jurisdiction}" + + def report_progress(self): + old_num_found = self.num_found + self.num_found = len(self.results) + difference = self.num_found - old_num_found + self.pbar_records.update(difference) diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionByIDFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionByIDFetcher.py new file mode 100644 index 00000000..a038418c --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionByIDFetcher.py @@ -0,0 +1,14 @@ +from source_collectors.muckrock.constants import BASE_MUCKROCK_URL +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import FetchRequest, MuckrockFetcher + + +class JurisdictionByIDFetchRequest(FetchRequest): + jurisdiction_id: int + +class JurisdictionByIDFetcher(MuckrockFetcher): + + def build_url(self, request: JurisdictionByIDFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/jurisdiction/{request.jurisdiction_id}/" + + def get_jurisdiction(self, jurisdiction_id: int) -> dict: + return self.fetch(request=JurisdictionByIDFetchRequest(jurisdiction_id=jurisdiction_id)) diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionLoopFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionLoopFetcher.py new file mode 100644 index 00000000..46c1bbf6 --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/JurisdictionLoopFetcher.py @@ -0,0 +1,47 @@ +from tqdm import tqdm + +from source_collectors.muckrock.constants import BASE_MUCKROCK_URL +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import FetchRequest +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockLoopFetcher import MuckrockLoopFetcher + + +class JurisdictionLoopFetchRequest(FetchRequest): + level: str + parent: int + town_names: list + +class 
JurisdictionLoopFetcher(MuckrockLoopFetcher): + + def __init__(self, initial_request: JurisdictionLoopFetchRequest): + super().__init__(initial_request) + self.town_names = initial_request.town_names + self.pbar_jurisdictions = tqdm( + total=len(self.town_names), + desc="Fetching jurisdictions", + unit="jurisdiction", + position=0, + leave=False + ) + self.pbar_page = tqdm( + desc="Processing pages", + unit="page", + position=1, + leave=False + ) + self.num_found = 0 + self.jurisdictions = {} + + def build_url(self, request: JurisdictionLoopFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/jurisdiction/?level={request.level}&parent={request.parent}" + + def process_results(self, results: list[dict]): + for item in results: + if item["name"] in self.town_names: + self.jurisdictions[item["name"]] = item["id"] + + def report_progress(self): + old_num_found = self.num_found + self.num_found = len(self.jurisdictions) + difference = self.num_found - old_num_found + self.pbar_jurisdictions.update(difference) + self.pbar_page.update(1) diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockFetcher.py new file mode 100644 index 00000000..e7a1dff5 --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockFetcher.py @@ -0,0 +1,42 @@ +import abc +from abc import ABC +from dataclasses import dataclass + +import requests +from pydantic import BaseModel + +class MuckrockNoMoreDataError(Exception): + pass + +class MuckrockServerError(Exception): + pass + +class FetchRequest(BaseModel): + pass + +class MuckrockFetcher(ABC): + + def fetch(self, request: FetchRequest): + url = self.build_url(request) + response = requests.get(url) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"Failed to get records on request `{url}`: {e}") + # If code is 404, raise NoMoreData error + if e.response.status_code == 404: + raise MuckrockNoMoreDataError + if 500 <= e.response.status_code < 600: + raise MuckrockServerError + + + + + return None + + return response.json() + + @abc.abstractmethod + def build_url(self, request: FetchRequest) -> str: + pass + diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockLoopFetcher.py b/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockLoopFetcher.py new file mode 100644 index 00000000..2b3d0149 --- /dev/null +++ b/source_collectors/muckrock/classes/muckrock_fetchers/MuckrockLoopFetcher.py @@ -0,0 +1,41 @@ +from abc import ABC, abstractmethod +from time import sleep + +import requests + +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import FetchRequest + + +class MuckrockLoopFetcher(ABC): + + + def __init__(self, initial_request: FetchRequest): + self.initial_request = initial_request + + def loop_fetch(self): + url = self.build_url(self.initial_request) + while url is not None: + response = requests.get(url) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"Failed to get records on request `{url}`: {e}") + return None + + data = response.json() + self.process_results(data["results"]) + self.report_progress() + url = data["next"] + sleep(1) + + @abstractmethod + def process_results(self, results: list[dict]): + pass + + @abstractmethod + def build_url(self, request: FetchRequest) -> str: + pass + + @abstractmethod + def report_progress(self): + pass diff --git a/source_collectors/muckrock/classes/muckrock_fetchers/__init__.py 
b/source_collectors/muckrock/classes/muckrock_fetchers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/source_collectors/muckrock/constants.py b/source_collectors/muckrock/constants.py new file mode 100644 index 00000000..07dca8f4 --- /dev/null +++ b/source_collectors/muckrock/constants.py @@ -0,0 +1,4 @@ + + +BASE_MUCKROCK_URL = "https://www.muckrock.com/api_v1" +FOIA_DATA_DB = "foia_data.db" \ No newline at end of file diff --git a/source_collectors/muckrock/create_foia_data_db.py b/source_collectors/muckrock/create_foia_data_db.py index 4adc5556..f012f5d3 100644 --- a/source_collectors/muckrock/create_foia_data_db.py +++ b/source_collectors/muckrock/create_foia_data_db.py @@ -19,20 +19,24 @@ and/or printed to the console. """ -import requests -import sqlite3 import logging import os import json import time -from typing import List, Tuple, Dict, Any, Union, Literal +from typing import List, Tuple, Dict, Any + +from tqdm import tqdm + +from source_collectors.muckrock.classes.SQLiteClient import SQLiteClientContextManager, SQLClientError +from source_collectors.muckrock.classes.muckrock_fetchers import FOIAFetcher +from source_collectors.muckrock.classes.muckrock_fetchers.MuckrockFetcher import MuckrockNoMoreDataError logging.basicConfig( filename="errors.log", level=logging.ERROR, format="%(levelname)s: %(message)s" ) +# TODO: Why are we pulling every single FOIA request? -base_url = "https://www.muckrock.com/api_v1/foia/" last_page_fetched = "last_page_fetched.txt" NO_MORE_DATA = -1 # flag for program exit @@ -83,69 +87,32 @@ def create_db() -> bool: bool: True, if database is successfully created; False otherwise. Raises: - sqlite3.Error: If the table creation operation fails, prints error and returns False. - """ - - try: - with sqlite3.connect("foia_data.db") as conn: - conn.execute(create_table_query) - conn.commit() - print("Successfully created foia_data.db!") - return True - except sqlite3.Error as e: - print(f"SQLite error: {e}.") - logging.error(f"Failed to create foia_data.db due to SQLite error: {e}") - return False - - -def fetch_page(page: int) -> Union[JSON, Literal[NO_MORE_DATA], None]: + sqlite3.Error: If the table creation operation fails, + prints error and returns False. """ - Fetches a page of 100 results from the MuckRock FOIA API. - - Args: - page (int): The page number to fetch from the API. - - Returns: - Union[JSON, None, Literal[NO_MORE_DATA]]: - - JSON Dict[str, Any]: The response's JSON data, if the request is successful. - - NO_MORE_DATA (int = -1): A constant, if there are no more pages to fetch (indicated by a 404 response). - - None: If there is an error other than 404. 
- """ - - per_page = 100 - response = requests.get( - base_url, params={"page": page, "page_size": per_page, "format": "json"} - ) - - if response.status_code == 200: - return response.json() - elif response.status_code == 404: - print("No more pages to fetch") - return NO_MORE_DATA # Typically 404 response will mean there are no more pages to fetch - elif 500 <= response.status_code < 600: - logging.error(f"Server error {response.status_code} on page {page}") - page = page + 1 - return fetch_page(page) - else: - print(f"Error fetching page {page}: {response.status_code}") - logging.error( - f"Fetching page {page} failed with response code: { - response.status_code}" - ) - return None - + with SQLiteClientContextManager("foia_data.db") as client: + try: + client.execute_query(create_table_query) + return True + except SQLClientError as e: + print(f"SQLite error: {e}.") + logging.error(f"Failed to create foia_data.db due to SQLite error: {e}") + return False def transform_page_data(data_to_transform: JSON) -> List[Tuple[Any, ...]]: """ - Transforms the data recieved from the MuckRock FOIA API into a structured format for insertion into a database with `populate_db()`. + Transforms the data received from the MuckRock FOIA API + into a structured format for insertion into a database with `populate_db()`. - Transforms JSON input into a list of tuples, as well as serializes the nested `tags` and `communications` fields into JSON strings. + Transforms JSON input into a list of tuples, + as well as serializes the nested `tags` and `communications` fields + into JSON strings. Args: - data_to_transform (JSON: Dict[str, Any]): The JSON data from the API response. - + data_to_transform: The JSON data from the API response. Returns: - transformed_data (List[Tuple[Any, ...]]: A list of tuples, where each tuple contains the fields of a single FOIA request. + A list of tuples, where each tuple contains the fields + of a single FOIA request. """ transformed_data = [] @@ -197,39 +164,40 @@ def populate_db(transformed_data: List[Tuple[Any, ...]], page: int) -> None: sqlite3.Error: If the insertion operation fails, attempts to retry operation (max_retries = 2). If retries are exhausted, logs error and exits. """ - - with sqlite3.connect("foia_data.db") as conn: - + with SQLiteClientContextManager("foia_data.db") as client: retries = 0 max_retries = 2 while retries < max_retries: try: - conn.executemany(foia_insert_query, transformed_data) - conn.commit() + client.execute_query(foia_insert_query, many=transformed_data) print("Successfully inserted data!") return - except sqlite3.Error as e: - print(f"SQLite error: {e}. Retrying...") - conn.rollback() + except SQLClientError as e: + print(f"{e}. Retrying...") retries += 1 time.sleep(1) if retries == max_retries: - print( - f"Failed to insert data from page {page} after { - max_retries} attempts. Skipping to next page." - ) - logging.error( - f"Failed to insert data from page {page} after { - max_retries} attempts." - ) + report_max_retries_error(max_retries, page) + + +def report_max_retries_error(max_retries, page): + print( + f"Failed to insert data from page {page} after { + max_retries} attempts. Skipping to next page." + ) + logging.error( + f"Failed to insert data from page {page} after { + max_retries} attempts." + ) def main() -> None: """ Main entry point for create_foia_data_db.py. 
- This function orchestrates the process of fetching FOIA requests data from the MuckRock FOIA API, transforming it, + This function orchestrates the process of fetching + FOIA requests data from the MuckRock FOIA API, transforming it, and storing it in a SQLite database. """ @@ -240,33 +208,46 @@ def main() -> None: print("Failed to create foia_data.db") return - if os.path.exists(last_page_fetched): - with open(last_page_fetched, mode="r") as file: - page = int(file.read()) + 1 - else: - page = 1 - - while True: + start_page = get_start_page() + fetcher = FOIAFetcher( + start_page=start_page + ) - print(f"Fetching page {page}...") - page_data = fetch_page(page) + with tqdm(initial=start_page, unit="page") as pbar: + while True: - if page_data == NO_MORE_DATA: - break # Exit program because no more data exixts - if page_data is None: - print(f"Skipping page {page}...") - page += 1 - continue + # TODO: Replace with TQDM + try: + pbar.update() + page_data = fetcher.fetch_next_page() + except MuckrockNoMoreDataError: + # Exit program because no more data exists + break + if page_data is None: + continue + transformed_data = transform_page_data(page_data) + populate_db(transformed_data, fetcher.current_page) + + with open(last_page_fetched, mode="w") as file: + file.write(str(fetcher.current_page)) - transformed_data = transform_page_data(page_data) + print("create_foia_data_db.py run finished") - populate_db(transformed_data, page) - with open(last_page_fetched, mode="w") as file: - file.write(str(page)) - page += 1 +def get_start_page(): + """ + Returns the page number to start fetching from. - print("create_foia_data_db.py run finished") + If the file `last_page_fetched` exists, + reads the page number from the file and returns it + 1. + Otherwise, returns 1. + """ + if os.path.exists(last_page_fetched): + with open(last_page_fetched, mode="r") as file: + page = int(file.read()) + 1 + else: + page = 1 + return page if __name__ == "__main__": diff --git a/source_collectors/muckrock/download_muckrock_foia.py b/source_collectors/muckrock/download_muckrock_foia.py deleted file mode 100644 index 0abd527d..00000000 --- a/source_collectors/muckrock/download_muckrock_foia.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -***DEPRECATED*** - -download_muckrock_foia.py - -This script fetches data from the MuckRock FOIA API and stores the results in a JSON file. - -""" - -import requests -import csv -import time -import json - -# Define the base API endpoint -base_url = "https://www.muckrock.com/api_v1/foia/" - -# Set initial parameters -page = 1 -per_page = 100 -all_data = [] -output_file = "foia_data.json" - - -def fetch_page(page): - """ - Fetches data from a specific page of the MuckRock FOIA API. 
- """ - response = requests.get( - base_url, params={"page": page, "page_size": per_page, "format": "json"} - ) - if response.status_code == 200: - return response.json() - else: - print(f"Error fetching page {page}: {response.status_code}") - return None - - -# Fetch and store data from all pages -while True: - print(f"Fetching page {page}...") - data = fetch_page(page) - if data is None: - print(f"Skipping page {page}...") - page += 1 - continue - - all_data.extend(data["results"]) - if not data["next"]: - break - - page += 1 - -# Write data to CSV -with open(output_file, mode="w", encoding="utf-8") as json_file: - json.dump(all_data, json_file, indent=4) - -print(f"Data written to {output_file}") diff --git a/source_collectors/muckrock/generate_detailed_muckrock_csv.py b/source_collectors/muckrock/generate_detailed_muckrock_csv.py index a077dbc7..3cb884c0 100644 --- a/source_collectors/muckrock/generate_detailed_muckrock_csv.py +++ b/source_collectors/muckrock/generate_detailed_muckrock_csv.py @@ -1,182 +1,169 @@ -import json +""" +Converts JSON file of MuckRock FOIA requests to CSV for further processing +""" + +# TODO: Look into linking up this logic with other components in pipeline. + import argparse import csv -import requests import time -from utils import format_filename_json_to_csv - -# Load the JSON data -parser = argparse.ArgumentParser(description="Parse JSON from a file.") -parser.add_argument( - "--json_file", type=str, required=True, help="Path to the JSON file" -) - -args = parser.parse_args() - -with open(args.json_file, "r") as f: - json_data = json.load(f) - -# Define the CSV headers -headers = [ - "name", - "agency_described", - "record_type", - "description", - "source_url", - "readme_url", - "scraper_url", - "state", - "county", - "municipality", - "agency_type", - "jurisdiction_type", - "View Archive", - "agency_aggregation", - "agency_supplied", - "supplying_entity", - "agency_originated", - "originating_agency", - "coverage_start", - "source_last_updated", - "coverage_end", - "number_of_records_available", - "size", - "access_type", - "data_portal_type", - "access_notes", - "record_format", - "update_frequency", - "update_method", - "retention_schedule", - "detail_level", -] - - -def get_agency(agency_id): - """ - Function to get agency_described - """ - if agency_id: - agency_url = f"https://www.muckrock.com/api_v1/agency/{agency_id}/" - response = requests.get(agency_url) - - if response.status_code == 200: - agency_data = response.json() - return agency_data - else: - return "" - else: - print("Agency ID not found in item") - - -def get_jurisdiction(jurisdiction_id): - """ - Function to get jurisdiction_described - """ - if jurisdiction_id: - jurisdiction_url = ( - f"https://www.muckrock.com/api_v1/jurisdiction/{jurisdiction_id}/" - ) - response = requests.get(jurisdiction_url) - - if response.status_code == 200: - jurisdiction_data = response.json() - return jurisdiction_data - else: - return "" - else: - print("Jurisdiction ID not found in item") - - -output_csv = format_filename_json_to_csv(args.json_file) -# Open a CSV file for writing -with open(output_csv, "w", newline="") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=headers) - - # Write the header row - writer.writeheader() - +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +from source_collectors.muckrock.classes.muckrock_fetchers import AgencyFetcher +from source_collectors.muckrock.classes.muckrock_fetchers.JurisdictionByIDFetcher import 
JurisdictionByIDFetcher +from utils import format_filename_json_to_csv, load_json_file + + +class JurisdictionType(Enum): + FEDERAL = "federal" + STATE = "state" + COUNTY = "county" + LOCAL = "local" + + +class AgencyInfo(BaseModel): + name: Optional[str] = "" + agency_described: Optional[str] = "" + record_type: Optional[str] = "" + description: Optional[str] = "" + source_url: Optional[str] = "" + readme_url: Optional[str] = "" + scraper_url: Optional[str] = "" + state: Optional[str] = "" + county: Optional[str] = "" + municipality: Optional[str] = "" + agency_type: Optional[str] = "" + jurisdiction_type: Optional[JurisdictionType] = None + agency_aggregation: Optional[str] = "" + agency_supplied: Optional[bool] = False + supplying_entity: Optional[str] = "MuckRock" + agency_originated: Optional[bool] = True + originating_agency: Optional[str] = "" + coverage_start: Optional[str] = "" + source_last_updated: Optional[str] = "" + coverage_end: Optional[str] = "" + number_of_records_available: Optional[str] = "" + size: Optional[str] = "" + access_type: Optional[str] = "" + data_portal_type: Optional[str] = "MuckRock" + access_notes: Optional[str] = "" + record_format: Optional[str] = "" + update_frequency: Optional[str] = "" + update_method: Optional[str] = "" + retention_schedule: Optional[str] = "" + detail_level: Optional[str] = "" + + + def model_dump(self, *args, **kwargs): + original_dict = super().model_dump(*args, **kwargs) + original_dict['View Archive'] = '' + return {key: (value.value if isinstance(value, Enum) else value) + for key, value in original_dict.items()} + + def keys(self) -> list[str]: + return list(self.model_dump().keys()) + + +def main(): + json_filename = get_json_filename() + json_data = load_json_file(json_filename) + output_csv = format_filename_json_to_csv(json_filename) + agency_infos = get_agency_infos(json_data) + write_to_csv(agency_infos, output_csv) + + +def get_agency_infos(json_data): + a_fetcher = AgencyFetcher() + j_fetcher = JurisdictionByIDFetcher() + agency_infos = [] # Iterate through the JSON data for item in json_data: print(f"Writing data for {item.get('title')}") - agency_data = get_agency(item.get("agency")) + agency_data = a_fetcher.get_agency(agency_id=item.get("agency")) time.sleep(1) - jurisdiction_data = get_jurisdiction(agency_data.get("jurisdiction")) - + jurisdiction_data = j_fetcher.get_jurisdiction( + jurisdiction_id=agency_data.get("jurisdiction") + ) + agency_name = agency_data.get("name", "") + agency_info = AgencyInfo( + name=item.get("title", ""), + originating_agency=agency_name, + agency_described=agency_name + ) jurisdiction_level = jurisdiction_data.get("level") - # federal jurisduction level - if jurisdiction_level == "f": - state = "" - county = "" - municipality = "" - juris_type = "federal" - # state jurisdiction level - if jurisdiction_level == "s": - state = jurisdiction_data.get("name") - county = "" - municipality = "" - juris_type = "state" - # local jurisdiction level - if jurisdiction_level == "l": - parent_juris_data = get_jurisdiction(jurisdiction_data.get("parent")) - state = parent_juris_data.get("abbrev") + add_locational_info(agency_info, j_fetcher, jurisdiction_data, jurisdiction_level) + optionally_add_agency_type(agency_data, agency_info) + optionally_add_access_info(agency_info, item) + + # Extract the relevant fields from the JSON object + # TODO: I question the utility of creating columns that are then left blank until later + # and possibly in a different file entirely. 
+ agency_infos.append(agency_info) + return agency_infos + + +def write_to_csv(agency_infos, output_csv): + # Open a CSV file for writing + with open(output_csv, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=AgencyInfo().keys()) + + # Write the header row + writer.writeheader() + + for agency_info in agency_infos: + csv_row = agency_info.model_dump() + + # Write the extracted row to the CSV file + writer.writerow(csv_row) + + +def get_json_filename(): + # Load the JSON data + parser = argparse.ArgumentParser(description="Parse JSON from a file.") + parser.add_argument( + "--json_file", type=str, required=True, help="Path to the JSON file" + ) + args = parser.parse_args() + json_filename = args.json_file + return json_filename + + +def add_locational_info(agency_info, j_fetcher, jurisdiction_data, jurisdiction_level): + match jurisdiction_level: + case "f": # federal jurisdiction level + agency_info.jurisdiction_type = JurisdictionType.FEDERAL + case "s": # state jurisdiction level + agency_info.jurisdiction_type = JurisdictionType.STATE + agency_info.state = jurisdiction_data.get("name") + case "l": # local jurisdiction level + parent_juris_data = j_fetcher.get_jurisdiction( + jurisdiction_id=jurisdiction_data.get("parent") + ) + agency_info.state = parent_juris_data.get("abbrev") if "County" in jurisdiction_data.get("name"): - county = jurisdiction_data.get("name") - municipality = "" - juris_type = "county" + agency_info.county = jurisdiction_data.get("name") + agency_info.jurisdiction_type = JurisdictionType.COUNTY else: - county = "" - municipality = jurisdiction_data.get("name") - juris_type = "local" - - if "Police" in agency_data.get("types"): - agency_type = "law enforcement/police" - else: - agency_type = "" - - source_url = "" - absolute_url = item.get("absolute_url") - access_type = "" - for comm in item["communications"]: - if comm["files"]: - source_url = absolute_url + "#files" - access_type = "Web page,Download,API" - break + agency_info.municipality = jurisdiction_data.get("name") + agency_info.jurisdiction_type = JurisdictionType.LOCAL - # Extract the relevant fields from the JSON object - csv_row = { - "name": item.get("title", ""), - "agency_described": agency_data.get("name", "") + " - " + state, - "record_type": "", - "description": "", - "source_url": source_url, - "readme_url": absolute_url, - "scraper_url": "", - "state": state, - "county": county, - "municipality": municipality, - "agency_type": agency_type, - "jurisdiction_type": juris_type, - "View Archive": "", - "agency_aggregation": "", - "agency_supplied": "no", - "supplying_entity": "MuckRock", - "agency_originated": "yes", - "originating_agency": agency_data.get("name", ""), - "coverage_start": "", - "source_last_updated": "", - "coverage_end": "", - "number_of_records_available": "", - "size": "", - "access_type": access_type, - "data_portal_type": "MuckRock", - "access_notes": "", - "record_format": "", - "update_frequency": "", - "update_method": "", - "retention_schedule": "", - "detail_level": "", - } - - # Write the extracted row to the CSV file - writer.writerow(csv_row) + +def optionally_add_access_info(agency_info, item): + absolute_url = item.get("absolute_url") + for comm in item["communications"]: + if comm["files"]: + agency_info.source_url = absolute_url + "#files" + agency_info.access_type = "Web page,Download,API" + break + + +def optionally_add_agency_type(agency_data, agency_info): + if "Police" in agency_data.get("types"): + agency_info.agency_type = "law 
enforcement/police" + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/source_collectors/muckrock/get_allegheny_foias.py b/source_collectors/muckrock/get_allegheny_foias.py index a559f67f..b269ff18 100644 --- a/source_collectors/muckrock/get_allegheny_foias.py +++ b/source_collectors/muckrock/get_allegheny_foias.py @@ -1,44 +1,30 @@ """ -get_allegheny_foias.py +Get Allegheny County FOIA requests +and save them to a JSON file """ -import requests -import json -import time +from source_collectors.muckrock.classes.muckrock_fetchers.FOIALoopFetcher import FOIALoopFetchRequest, FOIALoopFetcher +from source_collectors.muckrock.classes.muckrock_fetchers import JurisdictionLoopFetchRequest, \ + JurisdictionLoopFetcher +from source_collectors.muckrock.utils import save_json_file -def fetch_jurisdiction_ids(town_file, base_url): + +def fetch_jurisdiction_ids(town_file, level="l", parent=126): """ fetch jurisdiction IDs based on town names from a text file """ with open(town_file, "r") as file: town_names = [line.strip() for line in file] - jurisdiction_ids = {} - url = base_url - - while url: - response = requests.get(url) - if response.status_code == 200: - data = response.json() - for item in data.get("results", []): - if item["name"] in town_names: - jurisdiction_ids[item["name"]] = item["id"] - - url = data.get("next") - print( - f"Processed page, found {len(jurisdiction_ids)}/{len(town_names)} jurisdictions so far..." - ) - time.sleep(1) # To respect the rate limit + request = JurisdictionLoopFetchRequest( + level=level, parent=parent, town_names=town_names + ) - elif response.status_code == 503: - print("Error 503: Skipping page") - break - else: - print(f"Error fetching data: {response.status_code}") - break + fetcher = JurisdictionLoopFetcher(request) + fetcher.loop_fetch() + return fetcher.jurisdictions - return jurisdiction_ids def fetch_foia_data(jurisdiction_ids): @@ -47,28 +33,14 @@ def fetch_foia_data(jurisdiction_ids): """ all_data = [] for name, id_ in jurisdiction_ids.items(): - url = f"https://www.muckrock.com/api_v1/foia/?status=done&jurisdiction={id_}" - while url: - response = requests.get(url) - if response.status_code == 200: - data = response.json() - all_data.extend(data.get("results", [])) - url = data.get("next") - print( - f"Fetching records for {name}, {len(all_data)} total records so far..." 
- ) - time.sleep(1) # To respect the rate limit - elif response.status_code == 503: - print(f"Error 503: Skipping page for {name}") - break - else: - print(f"Error fetching data: {response.status_code} for {name}") - break + print(f"\nFetching records for {name}...") + request = FOIALoopFetchRequest(jurisdiction=id_) + fetcher = FOIALoopFetcher(request) + fetcher.loop_fetch() + all_data.extend(fetcher.results) # Save the combined data to a JSON file - with open("foia_data_combined.json", "w") as json_file: - json.dump(all_data, json_file, indent=4) - + save_json_file(file_path="foia_data_combined.json", data=all_data) print(f"Saved {len(all_data)} records to foia_data_combined.json") @@ -77,12 +49,12 @@ def main(): Execute the script """ town_file = "allegheny-county-towns.txt" - jurisdiction_url = ( - "https://www.muckrock.com/api_v1/jurisdiction/?level=l&parent=126" - ) - # Fetch jurisdiction IDs based on town names - jurisdiction_ids = fetch_jurisdiction_ids(town_file, jurisdiction_url) + jurisdiction_ids = fetch_jurisdiction_ids( + town_file, + level="l", + parent=126 + ) print(f"Jurisdiction IDs fetched: {jurisdiction_ids}") # Fetch FOIA data for each jurisdiction ID diff --git a/source_collectors/muckrock/muck_get.py b/source_collectors/muckrock/muck_get.py index 20c29338..f51bf9e0 100644 --- a/source_collectors/muckrock/muck_get.py +++ b/source_collectors/muckrock/muck_get.py @@ -1,61 +1,16 @@ """ -muck_get.py - +A straightforward standalone script for downloading data from MuckRock +and searching for it with a specific search string. """ - -import requests -import json - -# Define the base API endpoint -base_url = "https://www.muckrock.com/api_v1/foia/" - -# Define the search string -search_string = "use of force" -per_page = 100 -page = 1 -all_results = [] -max_count = 20 - -while True: - - # Make the GET request with the search string as a query parameter - response = requests.get( - base_url, params={"page": page, "page_size": per_page, "format": "json"} - ) - - # Check if the request was successful - if response.status_code == 200: - # Parse the JSON response - data = response.json() - - if not data["results"]: - break - - filtered_results = [ - item - for item in data["results"] - if search_string.lower() in item["title"].lower() - ] - - all_results.extend(filtered_results) - - if len(filtered_results) > 0: - num_results = len(filtered_results) - print(f"found {num_results} more matching result(s)...") - - if len(all_results) >= max_count: - print("max count reached... 
exiting") - break - - page += 1 - - else: - print(f"Error: {response.status_code}") - break - -# Dump list into a JSON file -json_out_file = search_string.replace(" ", "_") + ".json" -with open(json_out_file, "w") as json_file: - json.dump(all_results, json_file) - -print(f"List dumped into {json_out_file}") +from source_collectors.muckrock.classes.muckrock_fetchers import FOIAFetcher +from source_collectors.muckrock.classes.FOIASearcher import FOIASearcher +from source_collectors.muckrock.utils import save_json_file + +if __name__ == "__main__": + search_term = "use of force" + fetcher = FOIAFetcher() + searcher = FOIASearcher(fetcher=fetcher, search_term=search_term) + results = searcher.search_to_count(20) + json_out_file = search_term.replace(" ", "_") + ".json" + save_json_file(file_path=json_out_file, data=results) + print(f"List dumped into {json_out_file}") diff --git a/source_collectors/muckrock/muckrock_ml_labeler.py b/source_collectors/muckrock/muckrock_ml_labeler.py index b313c045..e3cb5cc7 100644 --- a/source_collectors/muckrock/muckrock_ml_labeler.py +++ b/source_collectors/muckrock/muckrock_ml_labeler.py @@ -1,6 +1,5 @@ """ -muckrock_ml_labeler.py - +Utilizes a fine-tuned model to label a dataset of URLs. """ from transformers import AutoTokenizer, AutoModelForSequenceClassification @@ -8,45 +7,73 @@ import pandas as pd import argparse -# Load the tokenizer and model -model_name = "PDAP/fine-url-classifier" -tokenizer = AutoTokenizer.from_pretrained(model_name) -model = AutoModelForSequenceClassification.from_pretrained(model_name) -model.eval() - -# Load the dataset from command line argument -parser = argparse.ArgumentParser(description="Load CSV file into a pandas DataFrame.") -parser.add_argument("--csv_file", type=str, required=True, help="Path to the CSV file") -args = parser.parse_args() -df = pd.read_csv(args.csv_file) - -# Combine multiple columns (e.g., 'url', 'html_title', 'h1') into a single text field for each row -columns_to_combine = [ - "url_path", - "html_title", - "h1", -] # Add other columns here as needed -df["combined_text"] = df[columns_to_combine].apply( - lambda row: " ".join(row.values.astype(str)), axis=1 -) - -# Convert the combined text into a list -texts = df["combined_text"].tolist() - -# Tokenize the inputs -inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") - -# Perform inference -with torch.no_grad(): - outputs = model(**inputs) - -# Get the predicted labels -predictions = torch.argmax(outputs.logits, dim=-1) - -# Map predictions to labels -labels = model.config.id2label -predicted_labels = [labels[int(pred)] for pred in predictions] - -# Add the predicted labels to the dataframe and save -df["predicted_label"] = predicted_labels -df.to_csv("labeled_muckrock_dataset.csv", index=False) + +def load_dataset_from_command_line() -> pd.DataFrame: + parser = argparse.ArgumentParser(description="Load CSV file into a pandas DataFrame.") + parser.add_argument("--csv_file", type=str, required=True, help="Path to the CSV file") + args = parser.parse_args() + return pd.read_csv(args.csv_file) + + +def create_combined_text_column(df: pd.DataFrame) -> None: + # Combine multiple columns (e.g., 'url', 'html_title', 'h1') into a single text field for each row + columns_to_combine = [ + "url_path", + "html_title", + "h1", + ] # Add other columns here as needed + df["combined_text"] = df[columns_to_combine].apply( + lambda row: " ".join(row.values.astype(str)), axis=1 + ) + + +def get_list_of_combined_texts(df: pd.DataFrame) -> 
list[str]: + # Convert the combined text into a list + return df["combined_text"].tolist() + + +def save_labeled_muckrock_dataset_to_csv(): + df.to_csv("labeled_muckrock_dataset.csv", index=False) + + +def create_predicted_labels_column(df: pd.DataFrame, predicted_labels: list[str]) -> None: + df["predicted_label"] = predicted_labels + + +def map_predictions_to_labels(model, predictions) -> list[str]: + labels = model.config.id2label + return [labels[int(pred)] for pred in predictions] + + +def get_predicted_labels(texts: list[str]) -> list[str]: + # Load the tokenizer and model + model_name = "PDAP/fine-url-classifier" + tokenizer = AutoTokenizer.from_pretrained(model_name) + + model = AutoModelForSequenceClassification.from_pretrained(model_name) + model.eval() + # Tokenize the inputs + inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") + # Perform inference + with torch.no_grad(): + outputs = model(**inputs) + # Get the predicted labels + predictions = torch.argmax(outputs.logits, dim=-1) + # Map predictions to labels + predicted_labels = map_predictions_to_labels(model=model, predictions=predictions) + + return predicted_labels + + +if __name__ == "__main__": + df = load_dataset_from_command_line() + # TODO: Check for existence of required columns prior to further processing + create_combined_text_column(df=df) + + texts = get_list_of_combined_texts(df=df) + + predicted_labels = get_predicted_labels(texts=texts) + # Add the predicted labels to the dataframe and save + create_predicted_labels_column(df=df, predicted_labels=predicted_labels) + + save_labeled_muckrock_dataset_to_csv() \ No newline at end of file diff --git a/source_collectors/muckrock/search_foia_data_db.py b/source_collectors/muckrock/search_foia_data_db.py index e7550608..51357663 100644 --- a/source_collectors/muckrock/search_foia_data_db.py +++ b/source_collectors/muckrock/search_foia_data_db.py @@ -18,24 +18,12 @@ Errors encountered during database operations, JSON parsing, or file writing are printed to the console. """ -import sqlite3 import pandas as pd import json import argparse -import os from typing import Union, List, Dict -check_results_table_query = """ - SELECT name FROM sqlite_master - WHERE (type = 'table') - AND (name = 'results') - """ - -search_foia_query = """ - SELECT * FROM results - WHERE (title LIKE ? OR tags LIKE ?) - AND (status = 'done') - """ +from source_collectors.muckrock.classes.FOIADBSearcher import FOIADBSearcher def parser_init() -> argparse.ArgumentParser: @@ -61,45 +49,8 @@ def parser_init() -> argparse.ArgumentParser: def search_foia_db(search_string: str) -> Union[pd.DataFrame, None]: - """ - Searches the foia_data.db database for FOIA request entries matching the provided search string. - - Args: - search_string (str): The string to search for in the `title` and `tags` of the `results` table. - - Returns: - Union[pandas.DataFrame, None]: - - pandas.DataFrame: A DataFrame containing the matching entries from the database. - - None: If an error occurs during the database operation. - - Raises: - sqlite3.Error: If any database operation fails, prints error and returns None. - Exception: If any unexpected error occurs, prints error and returns None. 
- """ - - print(f'Searching foia_data.db for "{search_string}"...') - - try: - with sqlite3.connect("foia_data.db") as conn: - - results_table = pd.read_sql_query(check_results_table_query, conn) - - if results_table.empty: - print("The `results` table does not exist in the database.") - return None - - params = [f"%{search_string}%", f"%{search_string}%"] - - df = pd.read_sql_query(search_foia_query, conn, params=params) - - except sqlite3.Error as e: - print(f"Sqlite error: {e}") - return None - except Exception as e: - print(f"An unexpected error occurred: {e}") - return None - - return df + searcher = FOIADBSearcher() + return searcher.search(search_string) def parse_communications_column(communications) -> List[Dict]: @@ -164,24 +115,25 @@ def main() -> None: args = parser.parse_args() search_string = args.search_for - if not os.path.exists("foia_data.db"): - print( - "foia_data.db does not exist.\nRun create_foia_data_db.py first to create and populate it." - ) - return - df = search_foia_db(search_string) if df is None: return + update_communications_column(df) - if not df["communications"].empty: - df["communications"] = df["communications"].apply(parse_communications_column) + announce_matching_entries(df, search_string) + generate_json(df, search_string) + + +def announce_matching_entries(df, search_string): print( f'Found {df.shape[0]} matching entries containing "{search_string}" in the title or tags' ) - generate_json(df, search_string) + +def update_communications_column(df): + if not df["communications"].empty: + df["communications"] = df["communications"].apply(parse_communications_column) if __name__ == "__main__": diff --git a/source_collectors/muckrock/search_local_foia_json.py b/source_collectors/muckrock/search_local_foia_json.py deleted file mode 100644 index 562c4bae..00000000 --- a/source_collectors/muckrock/search_local_foia_json.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -***DEPRECATED*** - -search_local_foia_json.py - -""" - -import json - -# Specify the JSON file path -json_file = "foia_data.json" -search_string = "use of force" - -# Load the JSON data -with open(json_file, "r", encoding="utf-8") as file: - data = json.load(file) - -# List to store matching entries -matching_entries = [] - - -def search_entry(entry): - """ - search within an entry - """ - # Check if 'status' is 'done' - if entry.get("status") != "done": - return False - - # Check if 'title' or 'tags' field contains the search string - title_match = "title" in entry and search_string.lower() in entry["title"].lower() - tags_match = "tags" in entry and any( - search_string.lower() in tag.lower() for tag in entry["tags"] - ) - - return title_match or tags_match - - -# Iterate through the data and collect matching entries -for entry in data: - if search_entry(entry): - matching_entries.append(entry) - -# Output the results -print( - f"Found {len(matching_entries)} entries containing '{search_string}' in the title or tags." 
-) - -# Optionally, write matching entries to a new JSON file -with open("matching_entries.json", "w", encoding="utf-8") as file: - json.dump(matching_entries, file, indent=4) - -print("Matching entries written to 'matching_entries.json'") diff --git a/source_collectors/muckrock/utils.py b/source_collectors/muckrock/utils.py index 3d8b63db..3c7eba28 100644 --- a/source_collectors/muckrock/utils.py +++ b/source_collectors/muckrock/utils.py @@ -8,6 +8,7 @@ """ import re +import json def format_filename_json_to_csv(json_filename: str) -> str: @@ -24,3 +25,12 @@ def format_filename_json_to_csv(json_filename: str) -> str: csv_filename = re.sub(r"_(?=[^.]*$)", "-", json_filename[:-5]) + ".csv" return csv_filename + +def load_json_file(file_path: str) -> dict: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data + +def save_json_file(file_path: str, data: dict | list[dict]): + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=4) \ No newline at end of file
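Usage sketch: a minimal example of how the new pdap_api_client classes introduced in this diff might be wired together, assuming the package is importable as laid out above. The environment variable names and the example search values are hypothetical, not part of this PR.

    import os

    from pdap_api_client.AccessManager import AccessManager
    from pdap_api_client.PDAPClient import PDAPClient

    # Log in once; AccessManager stores the access and refresh tokens internally.
    access_manager = AccessManager(
        email=os.environ["PDAP_EMAIL"],        # hypothetical variable name
        password=os.environ["PDAP_PASSWORD"],  # hypothetical variable name
    )
    client = PDAPClient(access_manager=access_manager)

    # Look for agencies matching a location (example values only).
    matches = client.match_agency(
        name="Pittsburgh Bureau of Police",
        state="PA",
        county="Allegheny",
        locality="Pittsburgh",
    )
    for match in matches:
        print(match.id, match.submitted_name)

    # Check whether a candidate source URL is already known.
    url_info = client.is_url_unique("https://example.com/police-records")
    if url_info.is_unique:
        print("URL is unique")
    else:
        print(f"Found {len(url_info.duplicates)} duplicate(s)")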