-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
1413 lines (1297 loc) · 73.2 KB
/
train.py
File metadata and controls
1413 lines (1297 loc) · 73.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import os
import time
import random
import argparse
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import GradScaler
from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
import torch.backends.cudnn as cudnn
from torch import amp
from contextlib import nullcontext
from tqdm import tqdm
from ot_jepa.models.encoders import (
VisionEncoder,
StateEncoder,
LangEncoder,
update_target_network,
)
from ot_jepa.models.metric import MetricNet
from ot_jepa.models.jepa import TemporalPredictor, GoalDistributionHead
from ot_jepa.models.ot_losses import sliced_w2, sinkhorn_w2, string_prior, cross_modal_ot, goal_w2
from ot_jepa.models.gromov_wasserstein import (
representation_alignment_loss,
batch_ot_coupling,
bilevel_ot_contrastive_loss,
entropic_gromov_wasserstein,
)
from ot_jepa.models.flow_matching import FlowMatchingHead
from ot_jepa.data.buffers import EpisodeDataset, WindowSpec
class OTJEPAModel(nn.Module):
    """Bundle of the online (trainable) modules shared by every training architecture.

    Only ``__init__`` is defined here: the module is a container whose submodules
    (``E_v``, ``E_s``, ``E_l``, ``Fusion``, ``Fusion_single``, ``Pred``, ``Metric``,
    ``FM``, ``GoalHead``, ``OTProj``) are accessed individually by the training loop.

    NOTE(review): the optimizer is built from ``model.parameters()`` and resumed via
    ``opt.load_state_dict``, so the order in which submodules are created below
    determines whether saved optimizer state lines up; avoid reordering them.
    """

    def __init__(self, state_dim: int, cfg: dict, act_dim: int | None = None):
        """Build all online submodules from the config.

        Args:
            state_dim: Size of the per-timestep state vector fed to ``StateEncoder``.
            cfg: Full training config. Reads ``cfg["model"]`` (``latent_dim``,
                ``metric_rank`` and the optional ``act_dim``, ``patch_size``,
                ``vision_depth``, ``vision_heads``, ``vision_backbone``,
                ``architecture``), ``cfg["data"]["image_size"]`` and, for the
                ``"vjepa2_hub"`` backbone, ``cfg["vjepa2"]``.
            act_dim: Action dimensionality for the flow-matching head. When
                ``None``, falls back to ``cfg["model"]["act_dim"]`` (default 7).
        """
        super().__init__()
        d = cfg["model"]["latent_dim"]
        # Use dataset-inferred action dim (7D for VJEPA2-AC with ee_delta)
        # Falls back to config if not provided (for backward compatibility)
        if act_dim is None:
            act_dim = int(cfg["model"].get("act_dim", 7))
        self.act_dim = act_dim
        # Online encoders
        # Vision hyperparameters: patch_size goes to every backbone below,
        # depth/heads only to the non-hub encoders.
        v_patch = int(cfg["model"].get("patch_size", 16))
        v_depth = int(cfg["model"].get("vision_depth", 4))
        v_heads = int(cfg["model"].get("vision_heads", 4))
        # Backbone selection: "vjepa2" -> VJEPA2VisionEncoder, "vjepa2_hub" ->
        # VJEPA2HubEncoder (variant/pretrained/freeze taken from cfg["vjepa2"]),
        # anything else -> the internal VisionEncoder.
        vision_backbone = str(cfg["model"].get("vision_backbone", "internal")).lower()
        if vision_backbone == "vjepa2":
            # Lazy import to avoid requiring vjepa2 when not used
            from ot_jepa.models.vjepa2_backbone import VJEPA2VisionEncoder
            self.E_v = VJEPA2VisionEncoder(
                latent_dim=d,
                img_size=tuple(cfg.get("data", {}).get("image_size", (256, 256))),
                patch_size=v_patch,
                depth=v_depth,
                heads=v_heads,
            )
        elif vision_backbone == "vjepa2_hub":
            from ot_jepa.models.vjepa2_backbone import VJEPA2HubEncoder
            v2 = cfg.get("vjepa2", {})
            self.E_v = VJEPA2HubEncoder(
                latent_dim=d,
                variant=str(v2.get("variant", "vjepa2_ac_vit_giant")),
                pretrained=bool(v2.get("pretrained", True)),
                freeze=bool(v2.get("freeze_encoder", True)),
                img_size=tuple(cfg.get("data", {}).get("image_size", (256, 256))),
                patch_size=v_patch,
            )
        else:
            self.E_v = VisionEncoder(latent_dim=d, patch_size=v_patch, depth=v_depth, heads=v_heads)
        self.E_s = StateEncoder(in_dim=state_dim, latent_dim=d)
        # NOTE(review): vocab_size is hard-coded to 512; presumably it must cover every
        # token id produced by the dataset — confirm against the tokenizer.
        self.E_l = LangEncoder(vocab_size=512, emb_dim=d, latent_dim=d)
        # Multi-view fusion: front + wrist + state -> d
        self.Fusion = nn.Sequential(nn.Linear(d * 3, d), nn.ReLU(), nn.Linear(d, d))
        # Single-view fusion: vision + state -> d (for independent camera training)
        self.Fusion_single = nn.Sequential(nn.Linear(d * 2, d), nn.ReLU(), nn.Linear(d, d))
        # Prediction and control heads
        # V-JEPA2-AC variants use the official hub action-conditioned predictor attached to E_v.
        # OT-JEPA and related architectures use a simple temporal predictor over latents.
        arch = cfg.get("model", {}).get("architecture", "ot-jepa").lower()
        self.use_action_conditioning = arch in ("vjepa2ac-baseline", "vjepa2ac-continued", "vjepa2ac-unfreeze", "vjepa2ac-ot")
        # VJEPA2-AC hub predictor operates on patch tokens for MPC planning
        self.use_patch_token_mpc = self.use_action_conditioning
        # Only non action-conditioned JEPA/OT architectures rely on this latent-level predictor.
        # For V-JEPA2-AC variants, self.Pred is not used in training.
        if self.use_action_conditioning:
            self.Pred = None
        else:
            self.Pred = TemporalPredictor(latent_dim=d)
        self.Metric = MetricNet(d, cfg["model"]["metric_rank"])
        # Flow-matching action head, built from the latent size and action dimensionality.
        self.FM = FlowMatchingHead(d, act_dim)
        self.GoalHead = GoalDistributionHead(latent_dim=d)
        # Projection head for bilevel batch-OT (R^d -> R^64)
        self.OTProj = nn.Sequential(nn.Linear(d, 128), nn.GELU(), nn.Linear(128, 64))
def set_seed(seed: int, rank: int = 0):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    The process rank is added to the base seed so that each distributed worker
    draws its own reproducible random stream.
    """
    base = int(seed)
    offset = int(rank)
    combined = base + offset
    # Same order as before: stdlib, NumPy, torch CPU, then all CUDA devices.
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(combined)
def _find_latest_checkpoint(arch: str, directory: str = "checkpoints") -> tuple[str | None, int]:
    """Locate the newest ``{arch}_{step}.pt`` checkpoint in *directory*.

    Args:
        arch: Architecture name used as the checkpoint filename prefix.
        directory: Directory to scan (non-recursive).

    Returns:
        ``(path, step)`` for the matching checkpoint with the largest step, or
        ``(None, 0)`` when the directory is missing or holds no matching file.
    """
    if not os.path.isdir(directory):
        return None, 0
    prefix = f"{arch}_"
    suffix = ".pt"
    latest_step = -1
    latest_path: str | None = None
    # Sorted so ties (e.g. "arch_5.pt" vs "arch_05.pt") resolve deterministically
    # instead of depending on the filesystem's listing order.
    for name in sorted(os.listdir(directory)):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        step_str = name[len(prefix):-len(suffix)]
        # Accept plain ASCII digits only: int() also accepts "1_000", "+5" or " 5",
        # so an unrelated file such as "ot_1_000.pt" (for arch "ot") would
        # otherwise be read as step 1000.
        if not (step_str.isascii() and step_str.isdigit()):
            continue
        step_val = int(step_str)
        if step_val > latest_step:
            latest_step = step_val
            latest_path = os.path.join(directory, name)
    return latest_path, max(latest_step, 0)
def parse(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the training script's command-line arguments.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        Namespace with ``config`` (path to the base YAML config) and
        ``override`` (list of dot-notation ``key=value`` strings).
    """
    ap = argparse.ArgumentParser(description="Train OT-JEPA / V-JEPA2-AC models.")
    ap.add_argument(
        "--config",
        type=str,
        default="configs/default.yaml",
        help="Path to the base YAML configuration file.",
    )
    # Optional overrides in dot-notation, e.g.:
    #   --override train.total_steps=100 device=cpu
    ap.add_argument(
        "--override",
        type=str,
        nargs="*",
        default=[],
        help="Config overrides as key.path=value, e.g. train.total_steps=100 device=cpu",
    )
    return ap.parse_args(argv)
def _parse_override_value(raw: str):
    """Best-effort scalar parsing for ``--override`` values.

    Tries, in order: YAML-style null (``null`` / ``~``, case-insensitive) ->
    bool -> int -> float -> str, so common CLI values such as "null", "true",
    "123" or "1e-4" map to the expected Python types. "1e-4" is deliberately
    parsed as a float here, because YAML scientific notation can load as a
    string (see the ``float(cfg["train"]["lr"])`` cast in main).
    """
    text = str(raw).strip()
    lowered = text.lower()
    # Allow an override to reset a config entry to None, mirroring `key: null`
    # in the YAML file; "none" stays a string, as it would in YAML.
    if lowered in ("null", "~"):
        return None
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text
def _apply_overrides(cfg: dict, overrides: list[str] | None) -> None:
    """Apply ``--override key=value`` updates in place to the nested config dict.

    Keys use dot-notation, e.g. ``"train.total_steps"``; each segment is
    whitespace-stripped. Missing intermediate dictionaries are created on
    demand, and a non-dict intermediate value is replaced by an empty dict.
    Values are converted with ``_parse_override_value``.

    Malformed items (no ``=``, an empty key, or an empty segment such as
    ``"a..b"``) are skipped with a ``UserWarning`` rather than dropped silently,
    so a typo like ``--override train.total_steps 100`` is visible before a long
    run starts. ``overrides=None`` is treated as "no overrides".
    """
    import warnings

    for item in overrides or ():
        key_str, sep, value_str = item.partition("=")
        keys = [part.strip() for part in key_str.split(".")]
        # Validate before touching cfg so a bad item leaves the config unchanged.
        if not sep or not all(keys):
            warnings.warn(
                f"Ignoring malformed --override {item!r}; expected key.path=value",
                stacklevel=2,
            )
            continue
        node: dict = cfg
        for k in keys[:-1]:
            child = node.get(k)
            if not isinstance(child, dict):
                child = {}
                node[k] = child
            node = child
        node[keys[-1]] = _parse_override_value(value_str)
def _cuda_device_index(local_rank: int, device_count: int) -> int:
    """Map a node-local rank onto a CUDA device ordinal visible to this process.

    Some launchers (e.g. SLURM jobs that give each task its own
    CUDA_VISIBLE_DEVICES) expose fewer devices than there are local ranks, in
    which case the raw local rank is an invalid ordinal; wrap it instead.
    """
    if device_count <= 0 or local_rank < device_count:
        return local_rank
    return local_rank % device_count


def init_distributed() -> tuple[int, int, int]:
    """Initialise torch.distributed from torchrun or SLURM environment variables.

    Launcher detection (torchrun is preferred):
      * torchrun / torch.distributed.run: ``RANK`` and ``WORLD_SIZE`` are set.
      * SLURM fallback: ``SLURM_PROCID`` and ``SLURM_NTASKS`` are set.

    When a launcher is detected, missing rendezvous variables (``MASTER_ADDR``,
    ``MASTER_PORT``, ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK``) are filled with
    defaults, the CUDA device is selected when CUDA is available, and a process
    group is created with NCCL (CUDA) or Gloo (CPU) via ``env://``.

    Returns:
        ``(rank, world_size, local_rank)``, or ``(0, 1, 0)`` when no launcher is
        detected (no process group is created). With CUDA available, the returned
        ``local_rank`` is the device ordinal actually selected; it differs from
        the launcher's local rank only when that rank is >= the number of
        visible devices.
    """
    has_torchrun = "WORLD_SIZE" in os.environ and "RANK" in os.environ
    has_slurm = "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ
    if not (has_torchrun or has_slurm):
        return 0, 1, 0

    if has_torchrun:
        rank = int(os.environ["RANK"])  # global rank
        world_size = int(os.environ["WORLD_SIZE"])  # total processes
        local_rank = int(os.environ.get("LOCAL_RANK", rank % max(1, torch.cuda.device_count())))
    else:
        # Single srun per node; use SLURM env as a fallback
        rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))
        local_rank = int(os.environ.get("SLURM_LOCALID", rank % max(1, torch.cuda.device_count())))

    # Provide sane defaults if not set (LOCAL_RANK keeps the launcher's value).
    os.environ.setdefault("MASTER_ADDR", os.environ.get("SLURM_LAUNCH_NODE_IPADDR", "127.0.0.1"))
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", str(rank))
    os.environ.setdefault("WORLD_SIZE", str(world_size))
    os.environ.setdefault("LOCAL_RANK", str(local_rank))

    # Set device before creating process group; wrap ranks that exceed the
    # number of visible devices instead of failing with an invalid ordinal.
    if torch.cuda.is_available():
        local_rank = _cuda_device_index(local_rank, torch.cuda.device_count())
        torch.cuda.set_device(local_rank)
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://", rank=rank, world_size=world_size)
    return rank, world_size, local_rank
def main():
args = parse()
cfg = yaml.safe_load(open(args.config))
# Apply any command-line overrides after reading the base config
if getattr(args, "override", None):
_apply_overrides(cfg, args.override)
rank, world_size, local_rank = init_distributed()
is_distributed = world_size > 1
set_seed(cfg["seed"], rank)
is_main = (rank == 0)
if cfg["device"] == "cuda" and torch.cuda.is_available():
device = torch.device("cuda", local_rank if torch.cuda.device_count() > 0 else 0)
cudnn.benchmark = True
# Enable TF32 for faster matmuls on Ampere+ GPUs
try:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, "set_float32_matmul_precision"):
torch.set_float32_matmul_precision("high")
except Exception:
pass
else:
device = torch.device("cpu")
# Data - aligned with Meta VJEPA2-AC clip lengths
os.makedirs(cfg["data"]["episode_dir"], exist_ok=True)
window = WindowSpec(k=cfg["data"]["window_k"], H=cfg["data"]["horizon_H"])
# V-JEPA 2 data augmentation parameters from config
random_resize_scale = tuple(cfg["data"].get("random_resize_scale", [0.3, 1.0]))
random_resize_aspect_ratio = tuple(cfg["data"].get("random_resize_aspect_ratio", [0.75, 1.35]))
horizontal_flip = bool(cfg["data"].get("horizontal_flip", False))
augment = bool(cfg["data"].get("augment", True)) # Enable augmentation by default
ds = EpisodeDataset(
cfg["data"]["episode_dir"],
window,
cfg["data"].get("image_size", (256, 256)),
rank=rank,
world_size=world_size,
use_embeddings=bool(cfg["data"].get("use_embeddings", False)),
# V-JEPA 2 data augmentation
random_resize_scale=random_resize_scale,
random_resize_aspect_ratio=random_resize_aspect_ratio,
horizontal_flip=horizontal_flip,
augment=augment,
)
if len(ds)==0:
raise RuntimeError("No episodes found; run data.py first.")
# Detect dataset format for logging
dataset_format = "unknown"
if ds.episodes:
import pandas as pd
first_ep = ds.episodes[0]
parquet_path = os.path.join(first_ep, "episode.parquet")
if os.path.exists(parquet_path):
df = pd.read_parquet(parquet_path)
has_ee_state = any(c.startswith("state_ee_state") for c in df.columns)
has_ee_delta = any(c.startswith("act_ee_delta") for c in df.columns)
if has_ee_state and has_ee_delta:
dataset_format = "7D end-effector (VJEPA2-AC format)"
elif any(c.startswith("state_q") for c in df.columns):
dataset_format = "9D joint-space (legacy format)"
if is_main:
print(f"Dataset: {len(ds)} episodes, state_dim={ds.state_dim}, action_dim={ds.action_dim}")
print(f"Dataset format: {dataset_format}")
if ds.state_dim == 9 or ds.action_dim == 9:
print(f"Using legacy 9D format. Consider regenerating dataset for 7D end-effector format.")
# Log V-JEPA 2 augmentation settings
if augment:
print(f"V-JEPA 2 Augmentation: ENABLED")
print(f"RandomResizedCrop scale: {random_resize_scale}")
print(f"RandomResizedCrop aspect ratio: {random_resize_aspect_ratio}")
print(f"Horizontal flip: {horizontal_flip}")
else:
print(f"V-JEPA 2 Augmentation: DISABLED")
# Synchronize inferred dims across ranks to avoid model shape mismatches
if is_distributed:
sd = torch.tensor([int(ds.state_dim)], device=device)
ad = torch.tensor([int(ds.action_dim)], device=device)
dist.all_reduce(sd, op=dist.ReduceOp.MAX)
dist.all_reduce(ad, op=dist.ReduceOp.MAX)
ds.state_dim = int(sd.item())
ds.action_dim = int(ad.item())
# Select architecture
arch = cfg.get("model", {}).get("architecture", "ot-jepa").lower()
# Optional fast-mode shortcuts for compute-constrained runs. These only
# adjust hyperparameters (not loss structure) and are fully controlled by config.
fast_mode = bool(cfg.get("train", {}).get("fast_mode", False))
if fast_mode:
mconf = cfg.setdefault("model", {})
oconf = cfg.setdefault("ot", {})
# Smaller action-conditioned predictor for V-JEPA2-AC variants
if arch in ("vjepa2ac-baseline", "vjepa2ac-continued", "vjepa2ac-unfreeze", "vjepa2ac-ot"):
mconf.setdefault("pred_hidden_dim", 384)
mconf.setdefault("pred_layers", 6)
mconf.setdefault("pred_heads", 6)
# OT hyperparameters: fewer projections / iterations by default
oconf.setdefault("num_projections_time", 32)
oconf.setdefault("num_projections_xmod", 16)
oconf.setdefault("iters", 10)
oconf.setdefault("batch_ot_iters", 10)
oconf.setdefault("bilevel_iters", 10)
oconf.setdefault("gw_iters", 10)
# Ensure act_dim aligns with dataset
cfg.setdefault("model", {})
cfg["model"]["act_dim"] = int(ds.action_dim)
# Sensible default backbones per arch
vb = str(cfg["model"].get("vision_backbone", "internal")).lower()
if arch in ("vjepa", "ot-vjepa") and vb == "internal":
cfg["model"]["vision_backbone"] = "vjepa2_hub"
vb = "vjepa2_hub"
# For all V-JEPA2-AC variants, force the official robotics-trained hub backbone
# so that continued pretraining always starts from vjepa2_ac_vit_giant weights.
if arch in ("vjepa2ac-baseline", "vjepa2ac-continued", "vjepa2ac-unfreeze", "vjepa2ac-ot"):
cfg["model"]["vision_backbone"] = "vjepa2_hub"
v2 = cfg.setdefault("vjepa2", {})
v2.setdefault("variant", "vjepa2_ac_vit_giant")
v2.setdefault("pretrained", True)
# Model (online)
if is_main:
print(f"Building model (arch='{arch}', backbone='{cfg['model'].get('vision_backbone','internal')}') ...")
# JEPA or OT-JEPA use the same backbone; OT only affects losses
model = OTJEPAModel(ds.state_dim, cfg, act_dim=ds.action_dim).to(device)
if is_main:
print(f"Model created with state_dim={ds.state_dim}, action_dim={ds.action_dim}")
if is_main:
print("Model built; moving to optional compile if enabled ...")
# Variant-specific freezing for VJEPA2-AC
def _freeze(m: nn.Module):
for p in m.parameters():
p.requires_grad = False
def _unfreeze(m: nn.Module):
for p in m.parameters():
p.requires_grad = True
if arch in ("vjepa2ac-baseline", "vjepa2ac-continued", "vjepa2ac-unfreeze", "vjepa2ac-ot"):
ev = getattr(model, "E_v", None)
has_hub_predictor = (ev is not None and hasattr(ev, "predictor") and ev.predictor is not None)
if arch == "vjepa2ac-baseline":
if is_main:
print("Variant vjepa2ac-baseline: training FM head only; freezing encoders and predictors.")
_freeze(model)
_unfreeze(model.FM)
# Unfreeze fusion layers as they are new and random initialized
_unfreeze(model.Fusion)
_unfreeze(model.Fusion_single)
_unfreeze(model.E_s) # State encoder is also new/random
_unfreeze(model.E_l) # Lang encoder is also new/random
# Keep hub predictor frozen (no training)
if has_hub_predictor:
_freeze(ev.predictor)
ev.predictor.eval()
elif arch == "vjepa2ac-continued":
if is_main:
print("Variant vjepa2ac-continued: freezing vision backbone; training hub predictor/FM/fusion/state.")
# Canonical VJEPA2-AC continued training: freeze encoder, train predictor
if ev is not None and hasattr(ev, "backbone"):
_freeze(ev.backbone)
_freeze(ev.norm)
_freeze(ev.proj)
ev.backbone.eval()
# Unfreeze hub predictor for training
if has_hub_predictor:
_unfreeze(ev.predictor)
ev.predictor.train()
_unfreeze(model.E_s)
_unfreeze(model.E_l)
_unfreeze(model.Fusion)
_unfreeze(model.Fusion_single) # Unfreeze single-view fusion layer
_unfreeze(model.FM)
elif arch == "vjepa2ac-unfreeze":
if is_main:
print("Variant vjepa2ac-unfreeze: training all components (encoder + predictor + FM).")
_unfreeze(model)
# Unfreeze hub encoder backbone and predictor
if ev is not None and hasattr(ev, "backbone"):
_unfreeze(ev.backbone)
ev.backbone.train()
if has_hub_predictor:
_unfreeze(ev.predictor)
ev.predictor.train()
elif arch == "vjepa2ac-ot":
if is_main:
print("Variant vjepa2ac-ot: training all + OT losses enabled.")
_unfreeze(model)
# For OT variant: unfreeze encoder (for GW alignment) and predictor
if ev is not None and hasattr(ev, "backbone"):
_unfreeze(ev.backbone)
ev.backbone.train()
if has_hub_predictor:
_unfreeze(ev.predictor)
ev.predictor.train()
# Optional compile for speed (PyTorch 2.x); enabled by default but configurable
# Use 'reduce-overhead' for better iteration time vs 'max-autotune' for peak throughput
if bool(cfg.get("train", {}).get("compile", True)) and hasattr(torch, "compile"):
try:
compile_mode = str(cfg.get("train", {}).get("compile_mode", "reduce-overhead"))
model = torch.compile(model, mode=compile_mode)
if is_main:
print(f"Model compiled with mode='{compile_mode}'")
except Exception as e:
if is_main:
print(f"torch.compile failed: {e}")
# Target encoders/fuser (EMA) for JEPA/OT-JEPA
d = cfg["model"]["latent_dim"]
v_patch = int(cfg["model"].get("patch_size", 16))
v_depth = int(cfg["model"].get("vision_depth", 4))
v_heads = int(cfg["model"].get("vision_heads", 4))
if arch in ("jepa", "ot-jepa", "ot-vjepa", "vjepa2ac-ot"):
if is_main:
print("Building target encoders/fuser (EMA) ...")
vision_backbone = str(cfg["model"].get("vision_backbone", "internal")).lower()
if vision_backbone == "vjepa2":
from ot_jepa.models.vjepa2_backbone import VJEPA2VisionEncoder
E_v_t = VJEPA2VisionEncoder(
latent_dim=d,
img_size=tuple(cfg.get("data", {}).get("image_size", (256, 256))),
patch_size=v_patch,
depth=v_depth,
heads=v_heads,
).to(device)
elif vision_backbone == "vjepa2_hub":
from ot_jepa.models.vjepa2_backbone import VJEPA2HubEncoder
v2 = cfg.get("vjepa2", {})
E_v_t = VJEPA2HubEncoder(
latent_dim=d,
variant=str(v2.get("variant", "vjepa2_ac_vit_giant")),
pretrained=bool(v2.get("pretrained", True)),
freeze=True,
img_size=tuple(cfg.get("data", {}).get("image_size", (256, 256))),
patch_size=v_patch,
).to(device)
else:
E_v_t = VisionEncoder(latent_dim=d, patch_size=v_patch, depth=v_depth, heads=v_heads).to(device)
E_s_t = StateEncoder(in_dim=ds.state_dim, latent_dim=d).to(device)
Fusion_t = nn.Sequential(nn.Linear(d * 3, d), nn.ReLU(), nn.Linear(d, d)).to(device)
# Single-view fusion target (for independent camera trajectory training)
Fusion_single_t = nn.Sequential(nn.Linear(d * 2, d), nn.ReLU(), nn.Linear(d, d)).to(device)
# Initialize targets with online weights
E_v_t.load_state_dict(model.E_v.state_dict())
E_s_t.load_state_dict(model.E_s.state_dict())
Fusion_t.load_state_dict(model.Fusion.state_dict())
Fusion_single_t.load_state_dict(model.Fusion_single.state_dict())
if is_main:
print("Target encoders/fuser ready (including Fusion_single_t for single-view).")
else:
E_v_t = E_s_t = Fusion_t = Fusion_single_t = None
# GW alignment infrastructure
# For OT architectures (including V-JEPA2-AC OT) we keep the
# original representation-based GW alignment that uses a frozen pretrained encoder.
E_v_pretrained = None
gw_ref_weights = None
use_gw_alignment = bool(cfg.get("ot", {}).get("use_gw_alignment", False))
# Ensure vision_backbone is defined even for architectures that skip EMA targets
vision_backbone = str(cfg["model"].get("vision_backbone", "internal")).lower()
# Representation-based GW path: only for non V-JEPA2-AC OT architectures
if (
use_gw_alignment
and vision_backbone in ("vjepa2_hub")
# Explicitly exclude vjepa2ac-ot (using EMA target instead of frozen copy)
and arch not in ("vjepa2ac-ot",)
):
if is_main:
print(f"Creating frozen pretrained encoder for GW alignment (backbone={vision_backbone}) ...")
if vision_backbone == "vjepa2_hub":
from ot_jepa.models.vjepa2_backbone import VJEPA2HubEncoder
v2 = cfg.get("vjepa2", {})
E_v_pretrained = VJEPA2HubEncoder(
latent_dim=d,
variant=str(v2.get("variant", "vjepa2_ac_vit_giant")),
pretrained=bool(v2.get("pretrained", True)),
freeze=True, # Always frozen for alignment reference
img_size=tuple(cfg.get("data", {}).get("image_size", (256, 256))),
patch_size=v_patch,
).to(device).eval()
if is_main:
print("Frozen pretrained encoder ready for GW alignment.")
# Weight-Space GW path removed as per user request.
# BEFORE DDP wrapping: ensure hub predictor is on correct device for this rank
# This ensures DDP replicates the predictor correctly on each GPU
if hasattr(model, 'E_v') and hasattr(model.E_v, 'predictor') and model.E_v.predictor is not None:
predictor = model.E_v.predictor
# Recursively move predictor and ALL its submodules to device
predictor = predictor.to(device)
# Explicitly move all named submodules to ensure they're on device
for name, module in predictor.named_modules():
if module is not predictor: # Don't move the root module twice
if hasattr(module, 'to'):
module.to(device)
# Also check for direct attributes that might be modules (state_encoder, action_encoder, etc.)
for attr_name in dir(predictor):
if not attr_name.startswith('_'):
try:
attr = getattr(predictor, attr_name)
if isinstance(attr, nn.Module) and attr is not predictor:
attr.to(device)
except (AttributeError, RuntimeError):
pass # Skip attributes that can't be accessed
model.E_v.predictor = predictor
if is_main:
print(f"Hub predictor moved to {device} (before DDP wrapping)")
if is_distributed:
# gradient_as_bucket_view reduces memory copies
# static_graph=True enables extra optimizations if model structure is fixed
find_unused = arch not in ("vjepa2ac-baseline", "vjepa2ac-continued", "vjepa2ac-unfreeze", "vjepa2ac-ot")
# Only pass device_ids when model is on GPU; for CPU training, DDP handles it automatically
ddp_kwargs = {
"find_unused_parameters": find_unused,
"gradient_as_bucket_view": True,
}
if device.type == "cuda":
ddp_kwargs["device_ids"] = [local_rank]
model = DDP(model, **ddp_kwargs)
# Ensure all ranks complete DDP wrapping before moving on
dist.barrier()
print(f"Rank {rank}: DDP wrapped and synchronized.")
model_for_losses = model.module if isinstance(model, DDP) else model
lr = float(cfg["train"]["lr"]) # YAML scientific notation can load as str; ensure optimizer receives float
# Prefer fused AdamW when available
try:
opt = optim.AdamW(model.parameters(), lr=lr, fused=True, weight_decay=float(cfg["train"].get("weight_decay", 1e-4)))
except TypeError:
opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=float(cfg["train"].get("weight_decay", 1e-4)))
# Learning rate warmup for faster convergence within 100 iterations
warmup_steps = int(cfg["train"].get("lr_warmup_steps", 0))
total_steps = int(cfg["train"]["total_steps"])
start_lr = float(cfg["train"].get("start_lr", lr * 0.1))
final_lr = float(cfg["train"].get("final_lr", 1.0e-6))
def get_lr_scale(step):
"""Linear warmup then Cosine decay (V-JEPA 2 schedule)"""
# 1. Linear Warmup: start_lr -> lr
if warmup_steps > 0 and step < warmup_steps:
alpha = step / warmup_steps
# scale = current_lr / base_lr
# current_lr = start_lr + alpha * (lr - start_lr)
return (start_lr + alpha * (lr - start_lr)) / lr
# 2. Cosine Decay: lr -> final_lr
if step >= total_steps:
return final_lr / lr
progress = (step - warmup_steps) / (total_steps - warmup_steps)
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
current_lr = final_lr + (lr - final_lr) * cosine_decay
return current_lr / lr
# AMP dtype: default bfloat16 for stability; choose via config train.amp_dtype: [bfloat16|float16]
amp_dtype_name = str(cfg.get("train", {}).get("amp_dtype", "bfloat16")).lower()
amp_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(amp_dtype_name, torch.bfloat16)
scaler = GradScaler(enabled=(device.type == "cuda" and amp_dtype is torch.float16))
# Auto-resume from latest checkpoint if available
latest_ckpt_path, latest_ckpt_step = _find_latest_checkpoint(arch)
start_step = 0
if latest_ckpt_path is not None:
map_location = device
ckpt = torch.load(latest_ckpt_path, map_location=map_location, weights_only=False)
model_state = ckpt.get("model")
if model_state is not None:
model_for_losses.load_state_dict(model_state, strict=False)
if E_v_t is not None and "E_v_t" in ckpt:
E_v_t.load_state_dict(ckpt["E_v_t"], strict=False)
if E_s_t is not None and "E_s_t" in ckpt:
E_s_t.load_state_dict(ckpt["E_s_t"], strict=False)
if Fusion_t is not None and "Fusion_t" in ckpt:
Fusion_t.load_state_dict(ckpt["Fusion_t"], strict=False)
if Fusion_single_t is not None and "Fusion_single_t" in ckpt:
Fusion_single_t.load_state_dict(ckpt["Fusion_single_t"], strict=False)
opt_state = ckpt.get("optimizer")
if opt_state is not None:
opt.load_state_dict(opt_state)
scaler_state = ckpt.get("scaler")
if scaler_state is not None:
try:
scaler.load_state_dict(scaler_state)
except Exception:
pass
start_step = int(ckpt.get("step", latest_ckpt_step))
if is_main:
print(f"Loaded checkpoint '{latest_ckpt_path}' (step={start_step})")
else:
start_step = 0
per_rank_cfg = cfg["data"].get("per_device_batch_size", None)
if per_rank_cfg is not None:
per_rank_batch = int(per_rank_cfg)
global_batch = per_rank_batch * world_size
else:
global_batch = int(cfg["data"]["batch_size"])
if global_batch % world_size != 0 and is_main:
print(f"data.batch_size {global_batch} not divisible by world_size {world_size}; using floor per-rank batch.")
per_rank_batch = max(1, global_batch // world_size)
if is_main:
print(f"world_size={world_size}, per_rank_batch={per_rank_batch}, effective_global_batch={per_rank_batch*world_size}")
K = window.k; H = window.H
w = cfg["loss_weights"]; total = cfg["train"]["total_steps"]
iterator = range(start_step, total)
if is_main:
iterator = tqdm(iterator, desc="train", initial=start_step, total=total)
grad_accum = int(cfg.get("train", {}).get("grad_accum_steps", 1))
teacher_stride = max(1, int(cfg.get("train", {}).get("fm_teacher_stride", 4)))
if is_main:
print(f"FM teacher stride={teacher_stride} (teacher runs every {teacher_stride} steps)")
# Initialize OT-CFM Loss
fm_loss_fn = ExactOptimalTransportConditionalFlowMatcher(sigma=0.0)
# Camera view selection: Meta VJEPA2 trains on single camera view (left exocentric)
# If front_camera_only=True, always use front camera; otherwise randomly select front/wrist
front_camera_only = cfg.get("data", {}).get("front_camera_only", False)
if is_main:
if front_camera_only:
print("Using FRONT camera only (matching Meta's single-view training)")
else:
print("Using random front/wrist view selection (50/50)")
cumulative_step_time = 0.0
optimizer_step = 0
group_start_time = None
for step in iterator:
step_start = time.perf_counter()
if is_main and step == start_step == 0:
print("Sampling first batch ...", flush=True)
# Sample batch with Meta VJEPA2-AC aligned parameters:
# Clips enforced to 5-10 seconds (20-40 frames at 4fps)
# Goal images sampled with fractional conditioning (0.7, 0.2, 0.1)
# Per-scene subgoal images used when available (front_goal_grasp.png,
# front_goal_near.png, front_goal_final.png)
batch = ds.sample_batch(per_rank_batch, sample_goal_images=True,
max_goal_offset=20, use_fractional_goals=True,
use_scene_subgoals=True)
if is_main and step == start_step == 0:
print("First batch ready; entering forward pass ...", flush=True)
# Keep CPU copy for teacher to avoid GPU->CPU sync
cpu_imgs_front = batch["imgs_front"]
imgs_front = batch["imgs_front"].to(device, non_blocking=True).float().contiguous() # (B, K+H+1, C, H, W)
imgs_wrist = batch["imgs_wrist"].to(device, non_blocking=True).float().contiguous()
state = batch["state"].to(device, non_blocking=True).float()
token = batch["token"].to(device, non_blocking=True)
actions = batch.get("actions", torch.zeros(imgs_front.size(0), K+H+1, cfg["model"]["act_dim"], device=device)).to(device, non_blocking=True).float()
# Batch shape: images (B, T, C, H, W) or precomputed embeddings (B, T, D)
if imgs_front.ndim == 5:
B, T, C, H_img, W_img = imgs_front.shape
else:
B, T, _ = imgs_front.shape
C = H_img = W_img = None
# Encode sequence with ONLINE encoders or precomputed embeddings (multi-view fusion)
# ImageNet normalization is now handled internally by VJEPA2HubEncoder
def encode_sequence(Ev, Es, Fusion, imgs_f, imgs_w, s):
if imgs_f.ndim == 3:
# Precomputed embeddings: imgs_f/imgs_w are (B, T, D)
b, t, d = imgs_f.shape
zf = imgs_f.reshape(b * t, d)
zw = imgs_w.reshape(b * t, d)
else:
b, t, c, hh, ww = imgs_f.shape
f = imgs_f.reshape(b * t, c, hh, ww)
wv = imgs_w.reshape(b * t, c, hh, ww)
# Apply channels_last only to 4D image tensors for better throughput
f = f.contiguous(memory_format=torch.channels_last)
wv = wv.contiguous(memory_format=torch.channels_last)
# If encoder is frozen, avoid building autograd graph for its forward
ctx = torch.no_grad() if not next(Ev.parameters()).requires_grad else nullcontext()
with ctx:
zf = Ev(f)
zw = Ev(wv)
zs = Es(s.reshape(b * t, -1))
fused = torch.cat([zf, zw, zs], dim=-1)
z = Fusion(fused)
return z.reshape(b, t, -1)
# Encode single camera view (V-JEPA 2 style: independent trajectories)
# ImageNet normalization is now handled internally by VJEPA2HubEncoder
def encode_single_view(Ev, Es, Fusion_single, imgs, s):
    """Encode one camera view plus state into a fused latent per timestep.

    Used when training on independent front/wrist trajectories, where the
    caller picks one view per step. The vision latent is always fused with
    the state latent through ``Fusion_single``. This function never returns
    the vision latent on its own.

    Args:
        Ev: Vision encoder. Not called when ``imgs`` are precomputed embeddings.
        Es: State encoder, applied to ``s`` reshaped to ``(B * T, -1)``.
        Fusion_single: Module mapping ``cat([z_vision, z_state], dim=-1)`` to
            the output latent.
        imgs: Images ``(B, T, C, H, W)`` or precomputed embeddings ``(B, T, D)``.
        s: Per-timestep state; reshaped to ``(B * T, -1)`` before encoding.

    Returns:
        Fused latents reshaped to ``(B, T, -1)``.
    """
    if imgs.ndim == 3:
        # Precomputed embeddings: the vision encoder is bypassed entirely.
        b, t, d = imgs.shape
        zv = imgs.reshape(b * t, d)
    else:
        b, t, c, hh, ww = imgs.shape
        # channels_last is applied only to the 4D image batch (throughput hint).
        imgs_flat = imgs.reshape(b * t, c, hh, ww).contiguous(memory_format=torch.channels_last)
        # Skip building an autograd graph only when NO encoder parameter is
        # trainable. Inspecting just the first parameter would disable
        # gradients for partially-unfrozen encoders (frozen early layers,
        # trainable later ones) and raise StopIteration for parameter-free ones.
        encoder_trainable = any(p.requires_grad for p in Ev.parameters())
        ctx = nullcontext() if encoder_trainable else torch.no_grad()
        with ctx:
            zv = Ev(imgs_flat)
    # Fuse vision + state with the two-input fusion module.
    zs = Es(s.reshape(b * t, -1))
    z = Fusion_single(torch.cat([zv, zs], dim=-1))
    return z.reshape(b, t, -1)
teacher_start = None
with amp.autocast(device_type='cuda', dtype=amp_dtype, enabled=(device.type == "cuda")):
# Initialize logging variables to avoid UnboundLocalError
gw_weight = 0.0
if arch in ("jepa-hf", "vjepa", "vjepa2ac-baseline"):
# FM-only fine-tuning on top of a chosen backbone; teacher supervision
# V-JEPA 2 style: randomly select front OR wrist trajectory (independent training)
# If front_camera_only=True, always use front camera (matching Meta's single-view training)
use_wrist_view = False if front_camera_only else (random.random() < 0.5)
imgs_selected = imgs_wrist if use_wrist_view else imgs_front
# Encode only the final context frame K from selected view
z_seq_online = encode_single_view(
model_for_losses.E_v,
model_for_losses.E_s,
model_for_losses.Fusion_single,
imgs_selected[:, K : K + 1, :, :, :],
state[:, K : K + 1, :],
)
z_t = z_seq_online[:, -1, :]
z_l = model_for_losses.E_l(token)
# FM supervision: OT-CFM
B = imgs_front.size(0)
# Infer action dim from FM head's output layer
fm_net = model_for_losses.FM.net
pred_dim = fm_net[-1].out_features
# Get x1 (target)
# For VJEPA2-AC baseline: use ee_delta (7D end-effector deltas) if available
ee_delta_batch = batch.get("ee_delta")
if ee_delta_batch is not None:
ee_delta_batch = ee_delta_batch.to(device, non_blocking=True)
# Zero rotation component to match MPC behavior
# Meta's MPC implementation zeros rotation (simplified kinematics)
ee_delta_batch = ee_delta_batch.clone()
ee_delta_batch[..., 3:6] = 0.0 # Zero rotation deltas
# Use ee_delta for VJEPA2-AC, fall back to joint actions otherwise
if ee_delta_batch is not None and arch in ("vjepa2ac-baseline",):
# Use K+1 for departure action (target)
# ee_delta[K] is arrival action (s_K - s_{K-1})
# ee_delta[K+1] is departure action (s_{K+1} - s_K)
gt = ee_delta_batch[:, K+1, :] # (B, 7)
else:
gt = actions[:, K, :]
if gt.shape[1] != pred_dim:
if gt.shape[1] > pred_dim:
gt = gt[:, :pred_dim]
else:
pad = torch.zeros(B, pred_dim - gt.shape[1], device=device, dtype=gt.dtype)
gt = torch.cat([gt, pad], dim=1)
x1 = gt
do_teacher = True
if do_teacher:
# 4. Compute Flow State xt and Target Velocity ut using torchcfm
x0 = torch.randn_like(x1)
t, xt, ut = fm_loss_fn.sample_location_and_conditional_flow(x0, x1)
# 5. Predict Velocity
# Use goal image for conditioning (Meta's hindsight relabeling)
# Use goal corresponding to the selected camera view
if use_wrist_view and "goal_imgs_wrist" in batch and batch["goal_imgs_wrist"] is not None:
goal_imgs = batch["goal_imgs_wrist"].to(device, non_blocking=True)
z_goal = model_for_losses.E_v(goal_imgs)
elif not use_wrist_view and "goal_imgs_front" in batch and batch["goal_imgs_front"] is not None:
goal_imgs = batch["goal_imgs_front"].to(device, non_blocking=True)
z_goal = model_for_losses.E_v(goal_imgs)
else:
# Fallback to final frame from selected view
z_goal = model_for_losses.E_v(imgs_selected[:, K+H, :, :, :])
vt = model_for_losses.FM(xt, z_t, z_goal, t)
fm_loss = ((vt - ut) ** 2).mean()
else:
fm_loss = torch.zeros(1, device=device)
if teacher_start is not None:
elapsed = time.perf_counter() - teacher_start
if is_main:
print(f"step={step}: teacher finished in {elapsed:.2f}s", flush=True)
loss = w.get("fm", 0.5) * fm_loss
else:
# JEPA / OT pipelines
# Also used by vjepa2ac-{continued,unfreeze,ot}
# V-JEPA 2 style: randomly select front OR wrist trajectory (independent training)
# If front_camera_only=True, always use front camera (matching Meta's single-view training)
use_wrist_view = False if front_camera_only else (random.random() < 0.5)
imgs_selected = imgs_wrist if use_wrist_view else imgs_front
z_seq_online = encode_single_view(
model_for_losses.E_v,
model_for_losses.E_s,
model_for_losses.Fusion_single,
imgs_selected[:, : K + 1, :, :, :],
state[:, : K + 1, :],
)
z_hist = z_seq_online # (B, K+1, d)
z_t = z_hist[:, -1, :]
rollout_loss = torch.tensor(0.0, device=device)
# For V-JEPA 2-AC variants: use action-conditioned prediction
if model_for_losses.use_action_conditioning:
# Get end-effector data from batch (B, K+H+1, 7)
ee_state_batch = batch.get("ee_state") # (B, K+H+1, 7)
ee_delta_batch = batch.get("ee_delta") # (B, K+H+1, 7)
# Ensure end-effector tensors are on the correct device
if ee_state_batch is not None:
ee_state_batch = ee_state_batch.to(device, non_blocking=True)
if ee_delta_batch is not None:
ee_delta_batch = ee_delta_batch.to(device, non_blocking=True)
# Zero rotation component to match MPC behavior
# Meta's MPC implementation zeros rotation (simplified kinematics)
# Training must match this to avoid distribution shift
ee_delta_batch = ee_delta_batch.clone()
ee_delta_batch[..., 3:6] = 0.0 # Zero rotation deltas
# Validate dimensions
if ee_state_batch is not None and ee_state_batch.shape[-1] != 7:
raise ValueError(
f"Invalid ee_state dimensions: expected (B, T, 7), got {ee_state_batch.shape}. "
f"EE state must be 7D: [pos(3), rot_rpy(3), gripper(1)]"
)
if ee_delta_batch is not None and ee_delta_batch.shape[-1] != 7:
raise ValueError(
f"Invalid ee_delta dimensions: expected (B, T, 7), got {ee_delta_batch.shape}. "
f"EE delta must be 7D: [dpos(3), drot_rpy(3), dgripper(1)]"
)
# Decide whether to use the hub AC predictor on patch tokens
use_hub_predictor = (
hasattr(model_for_losses.E_v, "encode_patches")
and getattr(model_for_losses.E_v, "predictor", None) is not None
and ee_state_batch is not None
and ee_delta_batch is not None
and imgs_selected.ndim == 5
and not bool(cfg["data"].get("use_embeddings", False))
)
if use_hub_predictor:
# Patch-token-level action-conditioned prediction with hub predictor
B_tokens = imgs_selected.size(0)
total_T = K + H + 1
img_seq = imgs_selected[:, : total_T, :, :, :] # (B, K+H+1, C, H, W) from selected view
patch_tokens = model_for_losses.E_v.encode_patches(img_seq) # (B, T, N_p, D_enc)
Bp, T_p, N_p, D_enc = patch_tokens.shape
if Bp != B_tokens or T_p != total_T:
raise RuntimeError(
f"encode_patches returned shape {(Bp, T_p)} but expected {(B_tokens, total_T)}"
)
# Apply layer_norm to encoder output (Meta's normalize_reps=True)
# Meta's WorldModel.encode() applies F.layer_norm(h, (h.size(-1),)) to encoder output
# This must be done BEFORE passing to predictor for consistency with MPC testing
patch_tokens = F.layer_norm(patch_tokens, (patch_tokens.size(-1),))
# Meta's loss is computed on PATCH TOKENS, not pooled representations:
# loss_fn(z, h) = torch.mean(torch.abs(z - _h) ** loss_exp) / loss_exp
# where z and h are both (B, T*N_p, D) patch token tensors.
#
# We implement both:
# 1. Teacher-forcing: predict all future frames given all context
# 2. Autoregressive: predict step-by-step, feeding predictions back
# Get target tokens (from target encoder, already layer_normed)
# We need targets for frames K+1 to K+H
with torch.no_grad():
# Target tokens are the encoder outputs for future frames
# patch_tokens shape: (B, T, N_p, D_enc) - already layer_normed
target_tokens = patch_tokens[:, K+1:K+1+H, :, :] # (B, H, N_p, D_enc)
target_tokens_flat = target_tokens.reshape(B_tokens, H * N_p, D_enc)
# Teacher-Forcing Prediction (Meta's z_tf)
# Meta uses ALL frames except the last for TF input:
# _z = z[:, :-tokens_per_frame] (T-1 frames)
# actions = all T-1 actions
# states = states[:, :-1] (T-1 states)
# The predictor predicts the NEXT token for each position
# Use T-1 frames (all except last) as input
T_tf = total_T - 1 # Number of input frames for TF
context_tokens = patch_tokens[:, :T_tf, :, :] # (B, T-1, N_p, D_enc)
x_in = context_tokens.reshape(B_tokens, T_tf * N_p, D_enc)
# Actions and states for T-1 frames
# Shift actions by +1 (departure action alignment)
# ee_delta_batch[:, 1:T_tf+1, :] gives T_tf departure actions
actions_in = ee_delta_batch[:, 1:T_tf+1, :] # (B, T-1, 7)
states_in = ee_state_batch[:, :T_tf, :] # (B, T-1, 7)
# Run predictor (teacher-forcing: single forward pass)
pred_tf = model_for_losses.E_v.predictor(x_in, actions_in, states_in)
# pred_tf shape: (B, (T-1)*N_p, D_enc)
# Apply layer_norm to predictor output (Meta's normalize_reps=True)
pred_tf = F.layer_norm(pred_tf, (pred_tf.size(-1),))
# Teacher-forcing loss: compare pred_tf with target_tokens
# Meta: _h = h[:, tokens_per_frame : z.size(1) + tokens_per_frame]
# This means: target is frames 1 to T-1 (shifted by 1 frame)
# pred_tf[i] predicts the token at position i+1 in the sequence
tf_target = patch_tokens[:, 1:T_tf, :, :].reshape(B_tokens, (T_tf-1)*N_p, D_enc)
tf_loss = torch.mean(torch.abs(pred_tf[:, N_p:, :] - tf_target))
# Autoregressive Prediction (Meta's z_ar)
# Meta starts with [frame0_GT, frame1_from_TF] and continues from there
# Use the TF prediction for frame 1, not a fresh prediction
auto_steps = min(H, total_T - 1) # Number of AR steps (predict frames 2, 3, ...)
# Start with frame 0 (GT) + frame 1 (from TF prediction)
# pred_tf[:, :N_p] is the TF prediction for frame 1
pred_frame1_tf = pred_tf[:, :N_p, :].reshape(B_tokens, 1, N_p, D_enc)
curr_tokens = torch.cat([
patch_tokens[:, :1, :, :], # Frame 0 (GT)
pred_frame1_tf # Frame 1 (TF prediction)
], dim=1) # (B, 2, N_p, D_enc)
# Continue autoregressive rollout for frames 2, 3, ...
for n in range(1, auto_steps):
T_curr = curr_tokens.shape[1]
x_in_ar = curr_tokens.reshape(B_tokens, T_curr * N_p, D_enc)
# Actions: need T_curr actions for T_curr frames
# ee_delta_batch[:, 1:T_curr+1, :] gives actions [a_1, a_2, ..., a_{T_curr}]
# which are departure actions from states [s_0, s_1, ..., s_{T_curr-1}]
actions_ar = ee_delta_batch[:, 1:T_curr+1, :] # (B, T_curr, 7)
states_ar = ee_state_batch[:, :T_curr, :] # (B, T_curr, 7)
pred_ar = model_for_losses.E_v.predictor(x_in_ar, actions_ar, states_ar)
pred_ar = F.layer_norm(pred_ar, (pred_ar.size(-1),))
pred_next = pred_ar[:, -N_p:, :] # (B, N_p, D_enc)
curr_tokens = torch.cat([curr_tokens, pred_next.unsqueeze(1)], dim=1)
# Autoregressive loss: compare predicted tokens with targets
# curr_tokens[:, 1:] are the predictions (skip frame 0 which is GT)
ar_pred_tokens = curr_tokens[:, 1:, :, :].reshape(B_tokens, -1, D_enc)
ar_tgt_tokens = patch_tokens[:, 1:1+auto_steps, :, :].reshape(B_tokens, -1, D_enc)
ar_loss = torch.mean(torch.abs(ar_pred_tokens - ar_tgt_tokens))
# Combined JEPA loss (Meta: loss = jloss + sloss)
jepa_patch_loss = tf_loss + ar_loss
# For compatibility with existing code, also compute pooled representations
# (used for FM conditioning and other downstream tasks)
z_future_pred_frames = []
z_future_tgt_frames = []
for h in range(1, H + 1):
t_pred = K + h
if t_pred < total_T:
# Predicted: use autoregressive predictions
if h <= auto_steps:
pred_tokens_h = curr_tokens[:, h, :, :] # (B, N_p, D_enc)
else:
pred_tokens_h = curr_tokens[:, -1, :, :] # Use last prediction
z_pred = pred_tokens_h.mean(dim=1)
z_pred = model_for_losses.E_v.norm(z_pred)
z_pred = model_for_losses.E_v.proj(z_pred)
z_future_pred_frames.append(z_pred)
# Target: ground truth
with torch.no_grad():
tgt_tokens_h = patch_tokens[:, t_pred, :, :]
z_tgt = tgt_tokens_h.mean(dim=1)
z_tgt = model_for_losses.E_v.norm(z_tgt)
z_tgt = model_for_losses.E_v.proj(z_tgt)
z_future_tgt_frames.append(z_tgt)
# Stack pooled representations for FM conditioning
if len(z_future_pred_frames) > 0:
z_future_pred = torch.stack(z_future_pred_frames, dim=1) # (B, H, d)
z_future_tgt = torch.stack(z_future_tgt_frames, dim=1) # (B, H, d)
else:
# Fallback if no predictions were made