Skip to content

Commit 66fd5d1

Browse files
authored
Merge pull request #12 from qedsoftware/csv-reader
Add csv reader
2 parents 1abc89e + f6bfd96 commit 66fd5d1

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

timm/data/readers/reader_factory.py

+8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .reader_image_folder import ReaderImageFolder
55
from .reader_image_in_tar import ReaderImageInTar
6+
from .reader_paths_csv import ReaderPathsCsv
67

78

89
def create_reader(
@@ -34,6 +35,13 @@ def create_reader(
3435
from .reader_wds import ReaderWds
3536
kwargs.pop('download', False)
3637
reader = ReaderWds(root=root, name=name, split=split, **kwargs)
38+
elif "samples_csv_path" in kwargs:
39+
assert "class_map" in kwargs
40+
reader = ReaderPathsCsv(
41+
images_dir=root,
42+
samples_csv_path=kwargs["samples_csv_path"],
43+
class_map=kwargs["class_map"],
44+
)
3745
else:
3846
assert os.path.exists(root)
3947
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

timm/data/readers/reader_paths_csv.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
A dataset reader that extracts images from a single folder
3+
based on a csv with labels and filenames relative to that folder.
4+
"""
5+
import os
6+
import pandas as pd
7+
8+
from .reader import Reader
9+
10+
11+
class ReaderPathsCsv(Reader):
    """Reader that serves the images listed in a CSV file.

    The CSV is loaded with pandas and must provide a ``filename`` column
    (paths relative to ``images_dir``) and a ``label`` column. All values are
    cast to ``str`` and every label is then translated through ``class_map``.

    Args:
        images_dir: Directory that the CSV filenames are relative to.
        samples_csv_path: Path of the CSV file listing the samples.
        class_map: Mapping from label string to the target value.

    Raises:
        ValueError: If the CSV contains a label that is not a key of
            ``class_map``.
    """

    def __init__(
            self,
            images_dir,
            samples_csv_path,
            class_map: dict[str, int],
    ):
        super().__init__()
        assert isinstance(class_map, dict)

        self.images_dir = images_dir
        # Everything is cast to str so labels are compared with the (str) keys of class_map.
        samples_df = pd.read_csv(samples_csv_path).astype(str)

        # Compute the membership mask once and reuse it for the error report.
        known = samples_df["label"].isin(list(class_map))
        if not known.all():
            unrecognized_labels = set(samples_df.loc[~known, "label"])
            raise ValueError(f"Unrecognized labels found in samples_df: {unrecognized_labels}")

        samples_df["label"] = samples_df["label"].map(class_map)

        self.samples_df = samples_df

    def __getitem__(self, index):
        """Return ``(binary file object, target)`` for the sample at position ``index``.

        The caller owns the returned file object and is responsible for closing it.
        """
        # Select by column name so extra CSV columns or a different column
        # order cannot break (or silently swap) filename and target.
        row = self.samples_df.iloc[index]
        path = os.path.join(self.images_dir, row["filename"])
        return open(path, 'rb'), row["label"]

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.samples_df)

    def _filename(self, index, basename=False, absolute=False):
        """Return the filename of the sample at position ``index``.

        Args:
            index: Positional index of the sample.
            basename: If true, return only the final path component.
            absolute: If true (and ``basename`` is false), return the filename
                joined onto ``images_dir``; otherwise return it exactly as
                stored in the CSV, i.e. relative to ``images_dir``.
        """
        # ``iloc`` only accepts positional indexers, so pick the column by
        # label first and then the row by position.
        filename = self.samples_df["filename"].iloc[index]
        if basename:
            return os.path.basename(filename)
        if absolute:
            return os.path.join(self.images_dir, filename)
        # Stored filenames are already relative to images_dir.
        return filename

timm/train.py

+6
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@
9393
help='path to dataset (root dir)')
9494
parser.add_argument('--dataset', metavar='NAME', default='',
9595
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
96+
# Optional per-split CSV files listing sample filenames and labels.
parser.add_argument('--train-samples-csv-path', metavar='PATH',
                    help='path to csv with train filenames and labels')
parser.add_argument('--val-samples-csv-path', metavar='PATH',
                    help='path to csv with validation filenames and labels')
96100
group.add_argument('--train-split', metavar='NAME', default='train',
97101
help='dataset train split (default: train)')
98102
group.add_argument('--val-split', metavar='NAME', default='validation',
@@ -669,6 +673,7 @@ def train(config: dict[str, t.Any]):
669673
input_key=args.input_key,
670674
target_key=args.target_key,
671675
num_samples=args.train_num_samples,
676+
samples_csv_path=args.train_samples_csv_path,
672677
)
673678

674679
if args.val_split:
@@ -684,6 +689,7 @@ def train(config: dict[str, t.Any]):
684689
input_key=args.input_key,
685690
target_key=args.target_key,
686691
num_samples=args.val_num_samples,
692+
samples_csv_path=args.val_samples_csv_path,
687693
)
688694

689695
# setup mixup / cutmix

0 commit comments

Comments
 (0)