Merged
161 changes: 139 additions & 22 deletions src/asleep/collate_outputs.py
@@ -1,5 +1,7 @@
import argparse
import json
import gzip
import zipfile
from collections import OrderedDict

import pandas as pd
@@ -10,40 +12,104 @@
def collate_outputs(
results_dir,
collated_results_dir="collated_outputs/",
include_files=None,
):
"""Collate all results files in <results_dir>.
"""Collate selected results files in <results_dir>.
:param str results_dir: Root directory in which to search for result files.
:param str collated_results_dir: Directory to write the collated files to.
:param list include_files: List of file types to collate.
Options: 'info', 'summary', 'predictions', 'sleep_block', 'day_summary'
:return: Collated files written to <collated_results_dir>
    :rtype: None
"""

# Default to only JSON files if none specified
if include_files is None:
include_files = ['info', 'summary']

# Convert to set for faster lookups
include_set = set(include_files)

print("Searching files...")

info_files = []
summary_files = []
predictions_files = []
sleep_block_files = []
day_summary_files = []

# Define file patterns for each type (including compressed versions)
info_patterns = ['info.json', 'info.json.gz', 'info.json.zip']
summary_patterns = ['summary.json', 'summary.json.gz', 'summary.json.zip']
predictions_patterns = ['predictions.csv', 'predictions.csv.gz', 'predictions.csv.zip']
sleep_block_patterns = ['sleep_block.csv', 'sleep_block.csv.gz', 'sleep_block.csv.zip']
day_summary_patterns = ['day_summary.csv', 'day_summary.csv.gz', 'day_summary.csv.zip']

    # Iterate through the files and append each to the appropriate list based on its file name
for file in Path(results_dir).rglob('*'):
if file.is_file():
if file.name.endswith("info.json"):
if any(file.name == pattern for pattern in info_patterns) and 'info' in include_set:
info_files.append(file)
if file.name.endswith("summary.json"):
elif any(file.name == pattern for pattern in summary_patterns) and \
'summary' in include_set:
summary_files.append(file)
elif any(file.name == pattern for pattern in predictions_patterns) and \
'predictions' in include_set:
predictions_files.append(file)
elif any(file.name == pattern for pattern in sleep_block_patterns) and \
'sleep_block' in include_set:
sleep_block_files.append(file)
elif any(file.name == pattern for pattern in day_summary_patterns) and \
'day_summary' in include_set:
day_summary_files.append(file)

collated_results_dir = Path(collated_results_dir)
collated_results_dir.mkdir(parents=True, exist_ok=True)

-    # Collate info.json files
-    print(f"Collating {len(info_files)} info files...")
-    outfile = collated_results_dir / "info.csv.gz"
-    collate_jsons(info_files, outfile)
-    print('Collated info CSV written to', outfile)
-    # Collate summary.json files
-    print(f"Collating {len(summary_files)} summary files...")
-    outfile = collated_results_dir / "summary.csv.gz"
-    collate_jsons(summary_files, outfile)
-    print('Collated summary CSV written to', outfile)
# Collate files based on what was requested
if info_files:
print(f"Collating {len(info_files)} info files...")
outfile = collated_results_dir / "info.csv.gz"
collate_jsons(info_files, outfile)
print('Collated info CSV written to', outfile)

if summary_files:
print(f"Collating {len(summary_files)} summary files...")
outfile = collated_results_dir / "summary.csv.gz"
collate_jsons(summary_files, outfile)
print('Collated summary CSV written to', outfile)

if predictions_files:
print(f"Collating {len(predictions_files)} predictions files...")
outfile = collated_results_dir / "predictions.csv.gz"
collate_csvs(predictions_files, outfile)
print('Collated predictions CSV written to', outfile)

if sleep_block_files:
print(f"Collating {len(sleep_block_files)} sleep_block files...")
outfile = collated_results_dir / "sleep_block.csv.gz"
collate_csvs(sleep_block_files, outfile)
print('Collated sleep_block CSV written to', outfile)

if day_summary_files:
print(f"Collating {len(day_summary_files)} day_summary files...")
outfile = collated_results_dir / "day_summary.csv.gz"
collate_csvs(day_summary_files, outfile)
print('Collated day_summary CSV written to', outfile)

# Print summary of what was processed
file_counts = {
'info': len(info_files),
'summary': len(summary_files),
'predictions': len(predictions_files),
'sleep_block': len(sleep_block_files),
'day_summary': len(day_summary_files)
}
processed_types = [k for k, v in file_counts.items() if v > 0 and k in include_set]
if processed_types:
print(f"\nSummary: Processed {processed_types} file types")
else:
print("\nWarning: No matching files found for the requested file types")

return
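
For context, a minimal usage sketch of the new include_files parameter. The import path follows this file's location under src/asleep/; the results directory name is hypothetical:

# Sketch: collate only summaries and per-night sleep blocks from a batch run.
# Assumes the package is installed so that `asleep` is importable.
from asleep.collate_outputs import collate_outputs

collate_outputs(
    results_dir="my_results/",
    collated_results_dir="collated_outputs/",
    include_files=["summary", "sleep_block"],
)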

@@ -57,19 +123,45 @@ def collate_jsons(file_list, outfile, overwrite=True):

df = []
for file in tqdm(file_list):
-        with open(file, 'r') as f:
-            j = json.load(f, object_pairs_hook=OrderedDict)
-        j['filepath'] = file
-        df.append(j)
-    df = pd.DataFrame.from_dict(df)  # merge to a dataframe
-    df = df.applymap(convert_ordereddict)  # convert any OrderedDict cell values to regular dict
-    df.to_csv(outfile, index=False)
file_path = Path(file)

# Handle different compression formats for JSON files
if file_path.name.endswith('.gz'):
with gzip.open(file, 'rt', encoding='utf-8') as f:
j = json.load(f, object_pairs_hook=OrderedDict)
elif file_path.name.endswith('.zip'):
with zipfile.ZipFile(file, 'r') as zf:
# Find the JSON file inside the zip
json_file = None
for name in zf.namelist():
if name.endswith('.json'):
json_file = name
break
if json_file:
with zf.open(json_file) as f:
j = json.load(f, object_pairs_hook=OrderedDict)
else:
print(f"Warning: No JSON file found in {file}")
continue
else:
# Regular uncompressed JSON file
with open(file, 'r', encoding='utf-8') as f:
j = json.load(f, object_pairs_hook=OrderedDict)

j['filepath'] = str(file)
df.append(j)

if df: # Only create DataFrame if we have data
df = pd.DataFrame.from_dict(df) # merge to a dataframe
# convert any OrderedDict cell values to regular dict
df = df.applymap(convert_ordereddict)
df.to_csv(outfile, index=False)

return
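
One caution on the collation step: pandas deprecated DataFrame.applymap in 2.1 in favour of DataFrame.map. If the project moves to a newer pandas, a guarded call like the sketch below keeps both versions working; this is a suggestion only, not part of the diff:

# Sketch: prefer DataFrame.map where it exists (pandas >= 2.1), else fall back.
if hasattr(df, "map"):
    df = df.map(convert_ordereddict)
else:
    df = df.applymap(convert_ordereddict)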


def collate_csvs(file_list, outfile, overwrite=True):
""" Collate a list of CSV files into a single CSV file."""
""" Collate a list of CSV files into a single CSV file, adding filepath column."""

if overwrite and outfile.exists():
print(f"Overwriting existing file: {outfile}")
@@ -78,6 +170,7 @@ def collate_csvs(file_list, outfile, overwrite=True):
header_written = False
for file in tqdm(file_list):
df = pd.read_csv(file)
df['filepath'] = str(file)
df.to_csv(outfile, mode='a', index=False, header=not header_written)
header_written = True
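
The new filepath column lets each collated row be traced back to its source file, for example (the path matches the default collated output above; the groupby is illustrative):

# Sketch: count prediction epochs per source file in the collated output.
import pandas as pd

preds = pd.read_csv("collated_outputs/predictions.csv.gz")
print(preds.groupby("filepath").size())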

@@ -92,17 +185,41 @@ def convert_ordereddict(value):


def main():
-    parser = argparse.ArgumentParser()
    parser = argparse.ArgumentParser(
description="Collate asleep output files from multiple runs into single CSV files",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""Available file types:
info - Device and recording metadata (info.json)
summary - Aggregated sleep statistics (summary.json)
predictions - Epoch-by-epoch sleep predictions (predictions.csv)
sleep_block - Sleep onset/wake times (sleep_block.csv)
day_summary - Daily sleep statistics (day_summary.csv)

Examples:
# Collate info and summary files (default)
collate_sleep results/

# Collate only summary and sleep blocks
collate_sleep results/ --include summary sleep_block

# Collate all file types
collate_sleep results/ --include info summary predictions sleep_block day_summary"""
)
parser.add_argument('results_dir',
help="Root directory in which to search for result files")
parser.add_argument('--output', '-o',
default="collated-outputs/",
help="Directory to write the collated files to")
parser.add_argument('--include', '-i',
nargs='+',
choices=['info', 'summary', 'predictions', 'sleep_block', 'day_summary'],
help="Specify which file types to collate. Default: info summary")
args = parser.parse_args()

return collate_outputs(
results_dir=args.results_dir,
collated_results_dir=args.output,
include_files=args.include,
)
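
A side note on the argparse wiring: nargs='+' combined with choices validates each token individually, so an unknown file type fails fast with a usage error. A standalone sketch of that behaviour:

# Sketch: '--include' accepts one or more values, each checked against choices.
import argparse

p = argparse.ArgumentParser()
p.add_argument('--include', '-i', nargs='+',
               choices=['info', 'summary', 'predictions', 'sleep_block', 'day_summary'])
print(p.parse_args(['--include', 'summary', 'sleep_block']).include)
# -> ['summary', 'sleep_block']; an invalid value exits with an error message.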

