Commit

pre-commit fixes
Hakimovich99 committed Nov 17, 2023
1 parent 463fc80 commit fb0bee9
Showing 2 changed files with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/components/load_from_csv/src/main.py
@@ -1,7 +1,6 @@
 import logging
 import typing as t
 
-import dask
 import dask.dataframe as dd
 import pandas as pd
 from fondant.component import DaskLoadComponent
@@ -20,16 +19,19 @@ def __init__(
         column_name_mapping: t.Optional[dict],
         n_rows_to_load: t.Optional[int],
         index_column: t.Optional[str],
-    ) -> None:
+    ) -> None:
         """
         Args:
             spec: the component spec
-            dataset_uri: The remote path to the parquet file/folder containing the dataset
-            column_name_mapping: Mapping of the consumed dataset to fondant column names
-            n_rows_to_load: optional argument that defines the number of rows to load. Useful for
-                testing pipeline runs on a small scale.
-            index_column: Column to set index to in the load component, if not specified a default
-                globally unique index will be set.
+            dataset_uri: The remote path to the parquet file/folder
+                containing the dataset column_name_mapping: Mapping of
+                the consumed dataset to fondant column names
+            n_rows_to_load: optional argument that defines the
+                number of rows to load. Useful for testing pipeline
+                runs on a small scale.
+            index_column: Column to set index to in the load component,
+                if not specified a default globally unique index will
+                be set.
         """
         self.dataset_uri = dataset_uri
         self.column_separator = column_separator
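
For orientation, here is a minimal sketch of how the documented constructor arguments fit together. The LoaderArgs dataclass and every value below are illustrative assumptions standing in for the component's real signature, not code from the repository.

import typing as t
from dataclasses import dataclass

@dataclass
class LoaderArgs:  # hypothetical stand-in mirroring the documented arguments
    dataset_uri: str
    column_separator: str
    column_name_mapping: t.Optional[dict]
    n_rows_to_load: t.Optional[int]
    index_column: t.Optional[str]

args = LoaderArgs(
    dataset_uri="s3://bucket/dataset.csv",          # placeholder remote path
    column_separator=",",
    column_name_mapping={"raw_text": "text_data"},  # consumed name -> fondant name
    n_rows_to_load=1000,                            # small sample for test runs
    index_column=None,                              # None -> a unique index is generated
)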
@@ -72,7 +74,10 @@ def set_df_index(self, dask_df: dd.DataFrame) -> dd.DataFrame:
         )
 
         def _set_unique_index(dataframe: pd.DataFrame, partition_info=None):
-            """Function that sets a unique index based on the partition and row number."""
+            """
+            Function that sets a unique index
+            based on the partition and row number.
+            """
             dataframe["id"] = 1
             dataframe["id"] = (
                 str(partition_info["number"])
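
The hunk above is truncated, but the visible pattern suggests the id combines the partition number with a per-partition row number. A minimal runnable sketch of that technique, assuming Dask's map_partitions supplies partition_info and that the hidden continuation uses a cumulative sum (an assumption on my part):

import dask.dataframe as dd
import pandas as pd

def set_unique_index(df: pd.DataFrame, partition_info=None) -> pd.DataFrame:
    # Dask passes partition_info when the mapped function accepts it;
    # cumsum over a column of ones numbers the rows within the partition.
    partition = partition_info["number"] if partition_info else 0
    df = df.assign(id=1)
    df["id"] = str(partition) + "_" + df["id"].cumsum().astype(str)
    return df

ddf = dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=2)
ddf = ddf.map_partitions(set_unique_index, meta=ddf._meta.assign(id=""))
print(ddf.compute()["id"].tolist())  # ['0_1', '0_2', '0_3', '1_1', '1_2', '1_3']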
@@ -98,7 +103,7 @@ def _get_meta_df() -> pd.DataFrame:
             dask_df = dask_df.set_index(self.index_column, drop=True)
 
         return dask_df
-
+
     def return_subset_of_df(self, dask_df: dd.DataFrame) -> dd.DataFrame:
         if self.n_rows_to_load is not None:
             partitions_length = 0
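
return_subset_of_df is cut off above at partitions_length = 0, but the visible opening suggests it accumulates partition sizes until n_rows_to_load is covered. A hedged sketch of that pattern (an illustration, not the component's actual code):

import dask.dataframe as dd
import pandas as pd

def return_subset(ddf: dd.DataFrame, n_rows: int) -> dd.DataFrame:
    # Walk partition sizes until enough rows are covered, then take
    # head(n_rows) over just those partitions instead of the whole frame.
    partitions_length = 0
    npartitions = 0
    for length in ddf.map_partitions(len).compute():
        npartitions += 1
        partitions_length += length
        if partitions_length >= n_rows:
            break
    return ddf.head(n_rows, npartitions=npartitions, compute=False)

ddf = dd.from_pandas(pd.DataFrame({"a": range(100)}), npartitions=10)
print(len(return_subset(ddf, 25)))  # 25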
@@ -122,7 +127,9 @@ def load(self) -> dd.DataFrame:
         columns = self.get_columns_to_keep()
 
         logger.debug(f"Columns to keep: {columns}")
-        dask_df = dd.read_csv(self.dataset_uri, sep=self.column_separator, usecols=columns)
+        dask_df = dd.read_csv(
+            self.dataset_uri, sep=self.column_separator, usecols=columns
+        )
 
         # 2) Rename columns
         if self.column_name_mapping:
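
To make the reformatted load step concrete, a small usage sketch with placeholder values; the path, separator, and column names are invented, and the rename call stands in for however the component applies column_name_mapping:

import dask.dataframe as dd

dask_df = dd.read_csv(
    "data/input.csv",          # placeholder for dataset_uri
    sep=";",                   # placeholder for column_separator
    usecols=["text", "lang"],  # the columns to keep
)
dask_df = dask_df.rename(columns={"text": "text_data"})  # column_name_mapping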
The second changed file is empty.
