-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathruntime_data.py
More file actions
136 lines (115 loc) · 4.99 KB
/
runtime_data.py
File metadata and controls
136 lines (115 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations
from typing import Tuple
import torch
# ---- Non-autoregressive (blind-tail) schedule: defaults for apply_nat ----
# Fraction of max_steps run as pure teacher forcing before any blinding.
NAT_WARMUP_FRAC = 0.15
# Per-sample blinding probability reached when the ramp completes
# (p = NAT_PEAK_PROB * progress).
NAT_PEAK_PROB = 0.5
# Upper bound on blind length as a fraction of L, interpolated from MIN
# (at the end of warmup) to MAX (at max_steps).
NAT_MIN_BLIND_FRAC, NAT_MAX_BLIND_FRAC = 0.1, 0.7

# ---- Noise-tolerant span corruption: defaults for apply_fim ----
# Fraction of max_steps before span corruption starts.
FIM_WARMUP_FRAC = 0.15
# Corrupted span length as a fraction of L, interpolated from MIN to MAX
# over the post-warmup steps.
FIM_MIN_SPAN_FRAC, FIM_MAX_SPAN_FRAC = 0.05, 0.08
def apply_nat(
    input_ids: torch.Tensor,
    pad_id: int,
    step: int,
    max_steps: int,
    warmup_frac: float = NAT_WARMUP_FRAC,
    peak_prob: float = NAT_PEAK_PROB,
    max_blind_frac: float = NAT_MAX_BLIND_FRAC,
    min_blind_frac: float = NAT_MIN_BLIND_FRAC,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Blind-tail (non-autoregressive) prefix-completion task.

    Closes the train/inference gap of one-shot `generate()`: at inference
    the model sees PAD on emission positions and must produce tokens from
    the ray-carried causal state alone. Trains that regime by randomly
    replacing a per-sample suffix of `input_ids` with PAD while keeping
    the original tokens as `target_ids` for the LM loss.

    Schedule:
      * steps [0 .. warmup_frac*max_steps): pure teacher forcing (no-op).
      * [warmup_frac .. end]: per-sample Bernoulli (p = peak_prob *
        progress) selects NAT-masked samples; blind length uniform in
        [1, L * (min_blind + (max_blind - min_blind) * progress)],
        capped at L - 1 so position 0 is never blinded (a fully-PAD
        sample would leave no prefix to complete from).

    Masks/loss in the caller are built from the UNMUTATED target so
    lengths/EOS stay correct even when the tail is all-PAD.

    Args:
        input_ids: 2-D (B, L) tensor of token ids.
        pad_id: id written into every blinded position.
        step: current step; with `max_steps` it drives the schedule.
        max_steps: total number of steps the schedule spans.
        warmup_frac: fraction of `max_steps` with no blinding.
        peak_prob: per-sample blinding probability at progress == 1.
        max_blind_frac: blind-length upper bound (fraction of L) at
            progress == 1.
        min_blind_frac: blind-length upper bound (fraction of L) at
            progress == 0.

    Returns:
        (masked_inputs, target_ids). `target_ids` is `input_ids` itself,
        not a copy. On a no-op call `masked_inputs` is also `input_ids`.

    NOTE(review): the blind suffix is measured from the padded length L,
    not each sample's real (non-PAD) length. In right-padded batches a
    short sample's blinded region may already be all PAD, which makes NAT
    a no-op for it. Confirm this is intended.
    """
    B, L = input_ids.shape
    device = input_ids.device

    warmup_steps = int(max_steps * warmup_frac)
    if step < warmup_steps:
        return input_ids, input_ids

    # Linear ramp 0 -> 1 across the post-warmup steps, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))

    p_nat = peak_prob * progress
    apply_mask = torch.rand(B, device=device) < p_nat
    if not apply_mask.any():
        return input_ids, input_ids

    blind_frac = min_blind_frac + (max_blind_frac - min_blind_frac) * progress
    max_blind = max(1, int(L * blind_frac))
    # Keep at least position 0 visible. Without this cap, L == 1 (or blind
    # fractions >= 1) could turn an entire sample into PAD.
    max_blind = min(max_blind, L - 1)
    if max_blind < 1:
        return input_ids, input_ids

    blind_len = torch.randint(1, max_blind + 1, (B,), device=device)
    # Unselected samples get a zero-length blind region, i.e. stay unchanged.
    blind_len = torch.where(apply_mask, blind_len, torch.zeros_like(blind_len))

    first_blind = L - blind_len  # per-sample index of the first PAD position
    positions = torch.arange(L, device=device).unsqueeze(0)  # (1, L)
    blind_positions = positions >= first_blind.unsqueeze(1)  # broadcasts to (B, L)

    masked_inputs = input_ids.masked_fill(blind_positions, pad_id)
    return masked_inputs, input_ids
def apply_fim(
    input_ids: torch.Tensor,
    pad_id: int,
    vocab_size: int,
    step: int,
    max_steps: int,
    warmup_frac: float = FIM_WARMUP_FRAC,
    min_span_frac: float = FIM_MIN_SPAN_FRAC,
    max_span_frac: float = FIM_MAX_SPAN_FRAC,
    min_token_id: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Noise-tolerant continuation task.

    Corrupt a random interior span (by default 5-8% of L, growing with
    schedule progress) in the input, using either PAD-zeroing or random
    tokens (50/50). Mask those same positions in the TARGET so the LM loss
    gets NO signal on them. The model thus sees broken text but is trained
    to "continue as if not broken" on every position outside the span:
    loss on the clean prefix remains (teacher forcing), and loss on the
    continuation after the span is the core objective (emit real tokens
    despite having read garbage). The model is explicitly NOT pushed to
    guess the original corrupted tokens, only to recover and continue
    gracefully.

    The span starts at position >= 1 and ends at least 2 positions before
    each sample's non-PAD count. It therefore never overlaps the EOS
    position (keeping `lengths_from_tokens(target)` correct) and always
    leaves at least 2 tokens of continuation after it. Samples too short
    to fit such a span are left untouched.

    Args:
        input_ids: 2-D (B, L) tensor of token ids.
        pad_id: id used for PAD-zeroing and for masking target labels.
        vocab_size: exclusive upper bound for random replacement ids.
        step: current step; with `max_steps` it drives the schedule.
        max_steps: total number of steps the schedule spans.
        warmup_frac: fraction of `max_steps` with no corruption.
        min_span_frac: span length (fraction of L) at progress == 0.
        max_span_frac: span length (fraction of L) at progress == 1.
        min_token_id: inclusive lower bound for random replacement ids.
            Presumably ids below 3 are special tokens; TODO confirm.

    Returns:
        (corrupted_inputs, target). Both are fresh clones after warmup;
        during warmup both are `input_ids` itself.

    NOTE(review): `target` holds PAD inside the span. If
    `lengths_from_tokens` counts non-PAD tokens (as `lengths` below does)
    rather than locating EOS, target lengths come out short by the span
    length. Confirm against its implementation.
    """
    B, L = input_ids.shape
    device = input_ids.device

    warmup_steps = int(max_steps * warmup_frac)
    if step < warmup_steps:
        return input_ids, input_ids

    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    rate = min_span_frac + (max_span_frac - min_span_frac) * progress
    span_len = max(1, int(L * rate))

    # Per-sample real lengths (non-pad count == EOS_pos + 1 in tokenizer-padded
    # batches). Fetched to the host once, instead of one `.item()` device sync
    # per sample inside the loop.
    lengths = (input_ids != pad_id).sum(dim=1).tolist()

    corrupted = input_ids.clone()
    target = input_ids.clone()
    for i, length in enumerate(lengths):
        # Need room for BOS (pos 0), the span, and at least 2 continuation
        # tokens (incl. EOS). Too-short samples are left uncorrupted.
        max_start = length - span_len - 2
        if max_start < 1:
            continue
        start = int(torch.randint(1, max_start + 1, (1,)).item())
        end = start + span_len
        if torch.rand(1).item() < 0.5:
            corrupted[i, start:end] = pad_id
        else:
            corrupted[i, start:end] = torch.randint(
                min_token_id, vocab_size, (end - start,), device=device,
            )
        # Mask labels inside the corrupted span so CE (ignore_index=pad)
        # produces zero gradient at those positions. Predictions outside
        # the span are still supervised against the clean tokens.
        target[i, start:end] = pad_id
    return corrupted, target