-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlayout.py
More file actions
104 lines (83 loc) · 4.01 KB
/
layout.py
File metadata and controls
104 lines (83 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import annotations
from typing import Dict, List, Tuple
import torch
class ContinuousSnakeLayout:
    """Sequence <-> volume layout: Faced Continuous Snake Layout.

    Each depth slice (G) is an HxW face traversed by a 2-D snake
    (boustrophedon: even rows left-to-right, odd rows right-to-left).
    Odd faces reuse the same face path in reverse, so token continuity
    is preserved at face boundaries as well as inside each face: the
    last cell visited on one face sits directly below/above the first
    cell visited on the next face.

    Example on a 2x3x3 grid (numbers are token indices)::

        face g=0      face g=1
        0  1  2       17 16 15
        5  4  3       12 13 14
        6  7  8       11 10  9

    The public representation remains [B, L, D]; only the placement
    inside [B, D, G, H, W] follows this layout.
    """

    def __init__(self, grid: Tuple[int, int, int]):
        """Create a layout for ``grid`` = (G, H, W).

        Raises:
            ValueError: if any grid dimension is smaller than 1.
        """
        # Per-device copies of the index tables, keyed by (table name, device).
        # Must exist before set_grid(), which clears it.
        self._cache: Dict[Tuple[str, torch.device], torch.Tensor] = {}
        self.set_grid(grid)

    @staticmethod
    def _build_indices(grid: Tuple[int, int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the two permutation tables for ``grid`` = (G, H, W).

        Returns:
            ``(token_to_flat, flat_to_token)`` as CPU ``torch.long``
            tensors of length G*H*W. ``token_to_flat[t]`` is the
            row-major flat voxel index (``z*H*W + y*W + x``) of token
            ``t``; ``flat_to_token`` is the inverse permutation.

        Raises:
            ValueError: if any grid dimension is smaller than 1.
        """
        g, h, w = (int(grid[0]), int(grid[1]), int(grid[2]))
        if g < 1 or h < 1 or w < 1:
            raise ValueError(f'grid dimensions must be positive, got {grid}')

        # 2-D snake over a single face: even rows run left-to-right,
        # odd rows right-to-left.
        face: List[Tuple[int, int]] = []
        for y in range(h):
            xs = range(w) if (y % 2 == 0) else range(w - 1, -1, -1)
            face.extend((y, x) for x in xs)

        # Odd faces walk the same path backwards so that the first cell of
        # face z+1 has the same (y, x) as the last cell of face z.
        token_to_flat: List[int] = []
        for z in range(g):
            order = face if (z % 2 == 0) else reversed(face)
            token_to_flat.extend(z * h * w + y * w + x for y, x in order)

        total = g * h * w
        if len(token_to_flat) != total:
            raise RuntimeError(
                f'internal layout error: built {len(token_to_flat)} indices for L={total}'
            )

        # Invert the permutation.
        flat_to_token = [0] * total
        for tok, flat in enumerate(token_to_flat):
            flat_to_token[flat] = tok

        return (
            torch.tensor(token_to_flat, dtype=torch.long),
            torch.tensor(flat_to_token, dtype=torch.long),
        )

    def set_grid(self, grid: Tuple[int, int, int]) -> None:
        """Switch the layout to a new ``grid`` = (G, H, W).

        The index tables are built (and the grid thereby validated)
        before any attribute is changed, so a rejected grid leaves the
        instance exactly as it was.

        Raises:
            ValueError: if any grid dimension is smaller than 1.
        """
        new_grid = (int(grid[0]), int(grid[1]), int(grid[2]))
        # May raise; nothing on ``self`` has been touched yet.
        token_to_flat, flat_to_token = self._build_indices(new_grid)

        self.grid = new_grid
        self.seq_len = new_grid[0] * new_grid[1] * new_grid[2]
        self._token_to_flat_cpu = token_to_flat
        self._flat_to_token_cpu = flat_to_token
        # Cached device copies belong to the previous grid.
        self._cache.clear()

    def _idx(self, name: str, device: torch.device) -> torch.Tensor:
        """Return the index table ``name`` on ``device``, caching the copy.

        ``name`` is ``'token_to_flat'`` for the forward table; any other
        value selects the inverse (``flat_to_token``) table.
        """
        key = (name, device)
        cached = self._cache.get(key)
        if cached is not None:
            return cached
        base = self._token_to_flat_cpu if name == 'token_to_flat' else self._flat_to_token_cpu
        idx = base.to(device=device, non_blocking=True)
        self._cache[key] = idx
        return idx

    def token_to_flat(self, device: torch.device) -> torch.Tensor:
        """Return the token -> flat-voxel index table on ``device``."""
        return self._idx('token_to_flat', device)

    def flat_to_token(self, device: torch.device) -> torch.Tensor:
        """Return the flat-voxel -> token index table on ``device``."""
        return self._idx('flat_to_token', device)

    def mask_to_vol(self, mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Scatter a per-token mask [B, L] into a volume [B, 1, G, H, W].

        The mask is cast to ``dtype`` before being rearranged.

        Raises:
            ValueError: if L differs from ``self.seq_len``.
        """
        b, length = mask.shape
        if length != self.seq_len:
            raise ValueError(f'mask length mismatch: got {length}, expected {self.seq_len}')
        g, h, w = self.grid
        flat_to_token = self.flat_to_token(mask.device)
        # Voxel ``flat`` receives the value of token ``flat_to_token[flat]``.
        return mask.to(dtype).index_select(1, flat_to_token).reshape(b, 1, g, h, w)

    def seq_to_vol(self, seq: torch.Tensor) -> torch.Tensor:
        """Place a token sequence [B, L, D] into a volume [B, D, G, H, W].

        Raises:
            ValueError: if L differs from ``self.seq_len``.
        """
        b, length, d = seq.shape
        if length != self.seq_len:
            raise ValueError(f'seq length mismatch: got {length}, expected {self.seq_len}')
        g, h, w = self.grid
        flat_to_token = self.flat_to_token(seq.device)
        # Reorder tokens into row-major voxel order, then move channels first.
        return seq.index_select(1, flat_to_token).transpose(1, 2).reshape(b, d, g, h, w)

    def vol_to_seq(self, volume: torch.Tensor) -> torch.Tensor:
        """Read a volume [B, D, G, H, W] back into a sequence [B, L, D].

        Inverse of :meth:`seq_to_vol`.

        Raises:
            ValueError: if (G, H, W) differs from ``self.grid``.
        """
        b, d, g, h, w = volume.shape
        if (g, h, w) != self.grid:
            raise ValueError(f'volume grid mismatch: got {(g, h, w)}, expected {self.grid}')
        token_to_flat = self.token_to_flat(volume.device)
        # Flatten voxels row-major to [B, L, D], then gather in token order.
        flat = volume.reshape(b, d, self.seq_len).transpose(1, 2)
        return flat.index_select(1, token_to_flat)