Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions src/stepcount/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def read(
- resample_hz (str, optional): The resampling frequency for the data. If 'uniform', it will use `sample_rate`
and resample to ensure it is evenly spaced. Default is 'uniform'.
- sample_rate (float, optional): The sample rate of the data. If None, it will be inferred. Default is None.
- csv_start_row (int, optional): The row number to start reading from (0-indexed, excluding header).
Only applies to CSV files. Default is None (read from the beginning).
- csv_end_row (int, optional): The row number to stop reading at, inclusive (0-indexed, excluding header).
- csv_start_row (int, optional): The file row number where the header is located (0-indexed).
Rows before this are skipped. Only applies to CSV files. Default is None (header at row 0).
- csv_end_row (int, optional): The file row number to stop reading at, inclusive (0-indexed).
Only applies to CSV files. Default is None (read to the end).
- csv_time_format (str, optional): Format string for parsing the time column (e.g., '%Y-%m-%d %H:%M:%S.%f').
Only applies to CSV files. Default is None (auto-detect).
Expand Down Expand Up @@ -82,7 +82,8 @@ def read(
if any(i < 0 for i in [tidx, xidx, yidx, zidx]):
raise ValueError(f"csv_txyz_idxs must be non-negative integers, got: '{csv_txyz_idxs}'")
# Read header to get column names at those indices
header = pd.read_csv(filepath, nrows=0).columns.tolist()
# Skip csv_start_row rows to reach the actual header row
header = pd.read_csv(filepath, nrows=0, skiprows=csv_start_row).columns.tolist()
max_idx = max(tidx, xidx, yidx, zidx)
if max_idx >= len(header):
raise ValueError(f"Column index {max_idx} out of range. CSV has {len(header)} columns.")
Expand All @@ -95,18 +96,20 @@ def read(
if csv_end_row < csv_start_row:
raise ValueError(f"csv_end_row ({csv_end_row}) must be >= csv_start_row ({csv_start_row})")

# skiprows: skip rows after header if csv_start_row is specified
# csv_start_row is 0-indexed (excluding header), so skiprows skips file rows 1 to csv_start_row
skiprows = None if csv_start_row is None else range(1, csv_start_row + 1)
# skiprows: skip rows before the header if csv_start_row is specified
# csv_start_row is 0-indexed file row where the header is located
# skiprows=N skips rows 0 to N-1, making row N the header
skiprows = csv_start_row

# nrows: number of data rows to read
# csv_end_row is inclusive and 0-indexed (excluding header)
# csv_end_row is file row to stop at (inclusive, 0-indexed)
# Data rows are from (csv_start_row + 1) to csv_end_row
if csv_end_row is None:
nrows = None
elif csv_start_row is None:
nrows = csv_end_row + 1
nrows = csv_end_row # rows 1 to csv_end_row = csv_end_row rows
else:
nrows = csv_end_row - csv_start_row + 1
nrows = csv_end_row - csv_start_row # rows (csv_start_row+1) to csv_end_row

# Common read_csv kwargs
read_kwargs = dict(
Expand Down
Loading