
Commit 8447505

EddyLXJ authored and facebook-github-bot committed
Support ebc for kv zch
Differential Revision: D76636927
1 parent dc035b2 commit 8447505

File tree

4 files changed (+900, -23 lines)


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 334 additions & 4 deletions
@@ -1076,6 +1076,11 @@ def __init__(
         assert (
             len({table.embedding_dim for table in config.embedding_tables}) == 1
         ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
 
         ssd_tbe_params = _populate_ssd_tbe_params(config)
         compute_kernel = config.embedding_tables[0].compute_kernel
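The same local_cols guard is added to each of the SSD/KV constructors touched by this commit. Below is a standalone sketch of what it checks, using a hypothetical stand-in for the table config (the alignment rationale is an assumption, not stated in the diff):

from dataclasses import dataclass

@dataclass
class _Table:  # hypothetical stand-in for an embedding table config
    name: str
    local_cols: int  # column count of this table's local shard

# shard widths that are a multiple of 4 pass; others would trip the new assert
for table in [_Table("t0", 128), _Table("t1", 66)]:
    ok = table.local_cols % 4 == 0
    print(f"table {table.name}: local_cols={table.local_cols} -> "
          f"{'ok' if ok else 'would fail the new assert'}")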
@@ -1263,6 +1268,11 @@ def __init__(
         assert (
             len({table.embedding_dim for table in config.embedding_tables}) == 1
         ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
 
         ssd_tbe_params = _populate_ssd_tbe_params(config)
         self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
@@ -1553,10 +1563,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         self._split_weights_res = None
         self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
 
-        return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
-        )
+        return super().forward(features)
 
 
 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
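The forward change in the hunk above swaps a direct emb_module call (indices/offsets only) for delegation to the base class, matching what the new ZeroCollisionKeyValueEmbeddingBag does further down. A toy sketch of the pattern follows; the names are illustrative, BaseBatchedEmbeddingBag's real forward is not reproduced here, and the idea that the base class owns any extra input plumbing (e.g. per-sample weights) is an assumption rather than something shown in this diff:

import torch

class _Base(torch.nn.Module):
    def forward(self, values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        # base class owns the common indices/offsets plumbing
        return self._lookup(values.long(), offsets.long())

class _KVBag(_Base):
    def _lookup(self, indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return torch.zeros(offsets.numel() - 1, 4)  # stand-in for the TBE kernel call

    def forward(self, values: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        self._cached_split_weights = None        # invalidate any cached snapshot view before training
        return super().forward(values, offsets)  # delegate instead of calling the kernel directly

out = _KVBag()(torch.tensor([1, 2, 3]), torch.tensor([0, 2, 3]))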
@@ -1885,6 +1892,324 @@ def named_parameters_by_table(
             yield name, param
 
 
+class ZeroCollisionKeyValueEmbeddingBag(
+    BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule
+):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+        sharding_type: Optional[ShardingType] = None,
+        backend_type: BackendType = BackendType.SSD,
+    ) -> None:
+        super().__init__(config, pg, device, sharding_type)
+
+        assert (
+            len(config.embedding_tables) > 0
+        ), "Expected to see at least one table in SSD TBE, but found 0."
+        assert (
+            len({table.embedding_dim for table in config.embedding_tables}) == 1
+        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
+
+        ssd_tbe_params = _populate_ssd_tbe_params(config)
+        self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
+        _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
+        compute_kernel = config.embedding_tables[0].compute_kernel
+        embedding_location = compute_kernel_to_embedding_location(compute_kernel)
+
+        # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use the split weights result cache so that multiple calls in the same train iteration only trigger it once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
+        self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
+            embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
+            feature_table_map=self._feature_table_map,
+            ssd_cache_location=embedding_location,
+            pooling_mode=self._pooling,
+            backend_type=backend_type,
+            **ssd_tbe_params,
+        ).to(device)
+
+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+        self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {}
+        self._init_sharded_split_embedding_weights()  # this will populate self._split_weights_res
+        self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = (
+            ZeroCollisionKeyValueEmbeddingFusedOptimizer(
+                config,
+                self._emb_module,
+                # pyre-ignore[16]
+                sharded_embedding_weights_by_table=self._split_weights_res[0],
+                table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank,
+                sharded_embedding_weight_ids=self._split_weights_res[1],
+                pg=pg,
+            )
+        )
+        self._param_per_table: Dict[str, nn.Parameter] = dict(
+            _gen_named_parameters_by_table_ssd_pmt(
+                emb_module=self._emb_module,
+                table_name_to_count=self.table_name_to_count.copy(),
+                config=self._config,
+                pg=pg,
+            )
+        )
+        self.init_parameters()
+
+    def init_parameters(self) -> None:
+        """
+        An advantage of KV TBE is that we don't need to init weights. Hence skipping.
+        """
+        pass
+
+    @property
+    def emb_module(
+        self,
+    ) -> SSDTableBatchedEmbeddingBags:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        """
+        SSD Embedding fuses the optimizer step with backward.
+        """
+        return self._optim
+
+    def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
+        """
+        Utility to compute (bucket offset start, bucket offset end, bucket size)
+        per table based on the embedding sharding spec.
+        """
+        sharded_local_buckets: List[Tuple[int, int, int]] = []
+        world_size = dist.get_world_size(self._pg)
+        local_rank = dist.get_rank(self._pg)
+
+        for table in self._config.embedding_tables:
+            total_num_buckets = none_throws(table.total_num_buckets)
+            assert (
+                total_num_buckets % world_size == 0
+            ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
+            assert (
+                table.total_num_buckets
+                and table.num_embeddings % table.total_num_buckets == 0
+            ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
+            bucket_offset_start = total_num_buckets // world_size * local_rank
+            bucket_offset_end = min(
+                total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            )
+            bucket_size = (
+                table.num_embeddings + total_num_buckets - 1
+            ) // total_num_buckets
+            sharded_local_buckets.append(
+                (bucket_offset_start, bucket_offset_end, bucket_size)
+            )
+            logger.info(
+                f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
+            )
+        return sharded_local_buckets
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+        no_snapshot: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()
+
+        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        """
+        Only allowed way to get state_dict.
+        """
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            # can delete after PEA deprecation
+            # pyre-ignore [6]
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._in_backward_optimizers = [EmptyFusedOptimizer()]
+            yield name, param
+
+    # pyre-ignore [15]
+    def named_split_embedding_weights(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, torch.Tensor]]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights()[0],
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, tensor
+
+    # Initialize the sharded _split_weights_res if it is None.
+    # This method generates sharded embedding weights once for all following state_dict
+    # calls during checkpointing and publishing.
+    # When training resumes, the cached value is reset to None and needs to be rebuilt
+    # for the next checkpointing and publishing, since the weight ids and weight embeddings
+    # are updated in the backend k/v store during training.
+    def _init_sharded_split_embedding_weights(
+        self, prefix: str = "", force_regenerate: bool = False
+    ) -> None:
+        if not force_regenerate and self._split_weights_res is not None:
+            return
+
+        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
+            no_snapshot=False,
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            none_throws(
+                none_throws(
+                    emb_table.local_metadata,
+                    f"local_metadata is None for emb_table: {emb_table.name}",
+                ).placement,
+                f"placement is None for local_metadata of emb table: {emb_table.name}",
+            )._device = torch.device("cpu")
+
+        pmt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            pmt_list,
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+        weight_id_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            weight_ids_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+        bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            bucket_cnt_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+            use_param_size_as_rows=True,
+        )
+        # pyre-ignore
+        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        assert (
+            len(pmt_sharded_t_list)
+            == len(weight_id_sharded_t_list)
+            == len(bucket_cnt_sharded_t_list)
+        )
+        self._split_weights_res = (
+            pmt_sharded_t_list,
+            weight_id_sharded_t_list,
+            bucket_cnt_sharded_t_list,
+        )
+
+    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
+        Tuple[
+            str,
+            Union[ShardedTensor, PartiallyMaterializedTensor],
+            Optional[ShardedTensor],
+            Optional[ShardedTensor],
+        ]
+    ]:
+        """
+        Return an iterator over embedding tables, yielding for each table:
+            the table name,
+            a PMT for the embedding table with a valid RocksDB snapshot to support tensor IO,
+            an optional ShardedTensor for weight_id,
+            an optional ShardedTensor for bucket_cnt.
+        """
+        self._init_sharded_split_embedding_weights()
+        # pyre-ignore[16]
+        self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1])
+
+        pmt_sharded_t_list = self._split_weights_res[0]
+        weight_id_sharded_t_list = self._split_weights_res[1]
+        bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+            table_config = self._config.embedding_tables[table_idx]
+            key = append_prefix(prefix, f"{table_config.name}")
+
+            yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                table_idx
+            ], bucket_cnt_sharded_t_list[table_idx]
+
+    def flush(self) -> None:
+        """
+        Flush the embeddings in cache back to SSD. Should be pretty expensive.
+        """
+        self.emb_module.flush()
+
+    def purge(self) -> None:
+        """
+        Reset the cache space. This is needed when we load state dict.
+        """
+        # TODO: move the following to SSD TBE.
+        self.emb_module.lxu_cache_weights.zero_()
+        self.emb_module.lxu_cache_state.fill_(-1)
+
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        """
+        Create a RocksDB checkpoint. This is needed before we call state_dict() for publish.
+        """
+        self.emb_module.create_rocksdb_hard_link_snapshot()
+
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+    ]:
+        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        # reset split weights during training
+        self._split_weights_res = None
+        self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
+
+        return super().forward(features)
+
+
 class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
     def __init__(
         self,
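For reference, the arithmetic in get_sharded_local_buckets above, pulled out as a free function and run on made-up numbers (the wrapper function and the example sizes are illustrative only; the formulas mirror the diff):

from typing import Tuple

def sharded_local_buckets(
    num_embeddings: int, total_num_buckets: int, world_size: int, local_rank: int
) -> Tuple[int, int, int]:
    assert total_num_buckets % world_size == 0
    assert num_embeddings % total_num_buckets == 0
    bucket_offset_start = total_num_buckets // world_size * local_rank
    bucket_offset_end = min(
        total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
    )
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    return bucket_offset_start, bucket_offset_end, bucket_size

# e.g. a 1,000,000-row table split into 40 buckets across 4 ranks:
for rank in range(4):
    print(rank, sharded_local_buckets(1_000_000, 40, 4, rank))
# rank 0 -> (0, 10, 25000), rank 1 -> (10, 20, 25000), rank 2 -> (20, 30, 25000), rank 3 -> (30, 40, 25000)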
@@ -1901,6 +2226,11 @@ def __init__(
         assert (
             len({table.embedding_dim for table in config.embedding_tables}) == 1
         ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
 
         ssd_tbe_params = _populate_ssd_tbe_params(config)
         compute_kernel = config.embedding_tables[0].compute_kernel
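The comments in the new class describe a cache-once / invalidate-on-train scheme for the sharded split-weights result: the expensive scan of the backend KV store runs at most once per checkpoint/publish window, and forward() drops the cache because training mutates the store. A minimal generic sketch of that scheme, with names that are not the torchrec API:

from typing import Callable, List, Optional, Tuple

class SplitWeightsCache:
    def __init__(self, backend_scan: Callable[[], Tuple[List[int], List[int]]]) -> None:
        self._backend_scan = backend_scan          # expensive full scan of the KV store
        self._cached: Optional[Tuple[List[int], List[int]]] = None

    def get(self, force_regenerate: bool = False) -> Tuple[List[int], List[int]]:
        if not force_regenerate and self._cached is not None:
            return self._cached                    # reuse within the same checkpoint/publish window
        self._cached = self._backend_scan()
        return self._cached

    def on_train_step(self) -> None:
        self._cached = None                        # backend contents changed; rebuild lazily next time

# usage: weights, ids = cache.get(); ... cache.on_train_step() inside forward()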

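And a hedged usage sketch for the no_snapshot flag on the new state_dict override. Here `module` stands for an already-constructed ZeroCollisionKeyValueEmbeddingBag (construction needs a full sharding setup and is omitted), and the explicit flush mirrors the in-code comment that no_snapshot=False relies on a prior flush (done by ShardedEmbeddingBagCollection._pre_state_dict_hook in the sharded path):

from typing import Any, Dict

def snapshot_state_dict(module) -> Dict[str, Any]:
    # PMTs in the result hold a RocksDB snapshot handle; requires a prior flush
    module.flush()
    return module.state_dict(no_snapshot=False)

def plain_state_dict(module) -> Dict[str, Any]:
    # default no_snapshot=True: PMTs carry no RocksDB snapshot handle
    return module.state_dict()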