data.py
import itertools
from typing import Optional

import hivemind
import numpy as np
from datasets import load_dataset

logger = hivemind.get_logger(__name__)


def preprocess_batch(batch, tokenizer, max_sequence_length: int):
    """Filter out unusable examples, tokenize the remaining captions, and decode their VQGAN image codes."""
    # Keep only examples that have a caption of at least 3 characters, are labeled
    # unlikely to be NSFW, and have an aspect ratio of at most 2:1 in either orientation
    mask = [
        (
            caption is not None and len(caption) >= 3 and
            nsfw == 'UNLIKELY' and
            orig_width > 0 and orig_height > 0 and
            max(orig_height / orig_width, orig_width / orig_height) <= 2
        ) for caption, nsfw, orig_width, orig_height in
        zip(batch['caption'], batch['NSFW'], batch['original_width'], batch['original_height'])
    ]
    logger.debug(f'{np.mean(mask) * 100:.1f}% of examples left after filtering')

    if any(mask):
        result = tokenizer(list(itertools.compress(batch['caption'], mask)),
                           add_special_tokens=False, max_length=max_sequence_length, truncation=True)
    else:
        # This branch is necessary because tokenizer([]) raises IndexError
        result = {'input_ids': [], 'attention_mask': []}
    # Unpack each image's codes from an int16 byte buffer into an int64 array of VQGAN token ids
    result['image'] = [np.frombuffer(encoded, np.int16).astype(np.int64)
                       for encoded in itertools.compress(batch['code'], mask)]
    return result


def make_dataset(
    tokenizer,
    *,
    shuffle_buffer_size: int = 8192,
    shuffle_seed: Optional[int],
    preprocessing_batch_size: int = 256,
    max_sequence_length: int,
):
    """Stream the LAION-100M VQGAN dataset, shuffle it, and preprocess it into model-ready batches."""
    ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
    ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
    ds = ds.map(lambda batch: preprocess_batch(batch, tokenizer, max_sequence_length),
                batched=True, batch_size=preprocessing_batch_size)
    ds = ds.with_format('torch')
    return ds
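

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): build the streaming dataset
    # and inspect one preprocessed example. The tokenizer choice below is an assumption
    # made for illustration; the actual training code may load a different tokenizer.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('t5-small')  # assumed tokenizer, illustration only
    dataset = make_dataset(tokenizer, shuffle_seed=42, max_sequence_length=256)

    example = next(iter(dataset))  # streams and preprocesses the first shuffled example
    print(example['input_ids'].shape, example['image'].shape)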