@@ -1076,6 +1076,11 @@ def __init__(
        assert (
            len({table.embedding_dim for table in config.embedding_tables}) == 1
        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )

        ssd_tbe_params = _populate_ssd_tbe_params(config)
        compute_kernel = config.embedding_tables[0].compute_kernel
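
The new per-table check sits next to the existing same-embedding-dim assertion. A minimal sketch of the same validation in isolation, using a hypothetical `TableStub` in place of the real embedding table config, would look like this:

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class TableStub:
        # hypothetical stand-in for the real embedding table config
        name: str
        embedding_dim: int
        local_cols: int


    def validate_tables(tables: List[TableStub]) -> None:
        # all tables grouped into one SSD TBE must share an embedding dimension
        assert len({t.embedding_dim for t in tables}) == 1, (
            "Currently we expect all tables in SSD TBE to have the same embedding dimension."
        )
        # the diff additionally requires each table's local column count to be divisible by 4
        for t in tables:
            assert t.local_cols % 4 == 0, (
                f"table {t.name} has local_cols={t.local_cols} not divisible by 4."
            )


    validate_tables([TableStub("t0", 128, 64), TableStub("t1", 128, 64)])  # passes
    # validate_tables([TableStub("t2", 128, 66)])  # would raise AssertionError
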
@@ -1263,6 +1268,11 @@ def __init__(
        assert (
            len({table.embedding_dim for table in config.embedding_tables}) == 1
        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )

        ssd_tbe_params = _populate_ssd_tbe_params(config)
        self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
@@ -1553,10 +1563,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        self._split_weights_res = None
        self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)

-        return self.emb_module(
-            indices=features.values().long(),
-            offsets=features.offsets().long(),
-        )
+        return super().forward(features)


class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
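
The forward override now only resets the cached split-weights state and defers the lookup to the base class instead of duplicating the emb_module call. A minimal sketch of this delegation pattern, with toy stand-ins (`BaseLookup`, `CachedLookup`) for the real classes, is:

    class BaseLookup:
        def forward(self, indices, offsets):
            # the base class owns the single call into the lookup kernel,
            # mirroring the removed self.emb_module(...) call in the diff
            return ("lookup", indices, offsets)


    class CachedLookup(BaseLookup):
        def __init__(self):
            self._split_weights_res = None  # checkpoint-time cache

        def forward(self, indices, offsets):
            # a training step invalidates the cached split-weights state,
            # then defers to the shared base-class lookup path
            self._split_weights_res = None
            return super().forward(indices, offsets)


    print(CachedLookup().forward([1, 2, 3], [0, 2, 3]))
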
@@ -1885,6 +1892,324 @@ def named_parameters_by_table(
            yield name, param


+class ZeroCollisionKeyValueEmbeddingBag(
+    BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule
+):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+        sharding_type: Optional[ShardingType] = None,
+        backend_type: BackendType = BackendType.SSD,
+    ) -> None:
+        super().__init__(config, pg, device, sharding_type)
+
+        assert (
+            len(config.embedding_tables) > 0
+        ), "Expected to see at least one table in SSD TBE, but found 0."
+        assert (
+            len({table.embedding_dim for table in config.embedding_tables}) == 1
+        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )
+
+        ssd_tbe_params = _populate_ssd_tbe_params(config)
+        self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
+        _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
+        compute_kernel = config.embedding_tables[0].compute_kernel
+        embedding_location = compute_kernel_to_embedding_location(compute_kernel)
+
+        # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use split weights result cache so that multiple calls in the same train iteration will only trigger once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
+        self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
+            embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
+            feature_table_map=self._feature_table_map,
+            ssd_cache_location=embedding_location,
+            pooling_mode=self._pooling,
+            backend_type=backend_type,
+            **ssd_tbe_params,
+        ).to(device)
+
+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+        self._table_name_to_weight_count_per_rank: Dict[str, List[int]] = {}
+        self._init_sharded_split_embedding_weights()  # this will populate self._split_weights_res
+        self._optim: ZeroCollisionKeyValueEmbeddingFusedOptimizer = (
+            ZeroCollisionKeyValueEmbeddingFusedOptimizer(
+                config,
+                self._emb_module,
+                # pyre-ignore[16]
+                sharded_embedding_weights_by_table=self._split_weights_res[0],
+                table_name_to_weight_count_per_rank=self._table_name_to_weight_count_per_rank,
+                sharded_embedding_weight_ids=self._split_weights_res[1],
+                pg=pg,
+            )
+        )
+        self._param_per_table: Dict[str, nn.Parameter] = dict(
+            _gen_named_parameters_by_table_ssd_pmt(
+                emb_module=self._emb_module,
+                table_name_to_count=self.table_name_to_count.copy(),
+                config=self._config,
+                pg=pg,
+            )
+        )
+        self.init_parameters()
+
+    def init_parameters(self) -> None:
+        """
+        An advantage of KV TBE is that we don't need to init weights. Hence skipping.
+        """
+        pass
+
+    @property
+    def emb_module(
+        self,
+    ) -> SSDTableBatchedEmbeddingBags:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        """
+        SSD Embedding fuses the optimizer update with the backward pass.
+        """
+        return self._optim
+
+    def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
+        """
+        Utility to compute (bucket offset start, bucket offset end, bucket size) per table
+        based on the embedding sharding spec.
+        """
+        sharded_local_buckets: List[Tuple[int, int, int]] = []
+        world_size = dist.get_world_size(self._pg)
+        local_rank = dist.get_rank(self._pg)
+
+        for table in self._config.embedding_tables:
+            total_num_buckets = none_throws(table.total_num_buckets)
+            assert (
+                total_num_buckets % world_size == 0
+            ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
+            assert (
+                table.total_num_buckets
+                and table.num_embeddings % table.total_num_buckets == 0
+            ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
+            bucket_offset_start = total_num_buckets // world_size * local_rank
+            bucket_offset_end = min(
+                total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            )
+            bucket_size = (
+                table.num_embeddings + total_num_buckets - 1
+            ) // total_num_buckets
+            sharded_local_buckets.append(
+                (bucket_offset_start, bucket_offset_end, bucket_size)
+            )
+            logger.info(
+                f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
+            )
+        return sharded_local_buckets
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+        no_snapshot: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle; False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()
+
+        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        """
+        Only allowed ways to get state_dict.
+        """
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            # can delete after PEA deprecation
+            # pyre-ignore [6]
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._in_backward_optimizers = [EmptyFusedOptimizer()]
+            yield name, param
+
+    # pyre-ignore [15]
+    def named_split_embedding_weights(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, torch.Tensor]]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights()[0],
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, tensor
+
+    # Initialize sharded _split_weights_res if it's None.
+    # This method is used to generate sharded embedding weights once for all following state_dict
+    # calls in checkpointing and publishing.
+    # When training is resumed, the cached value will be reset to None and needs to be rebuilt
+    # for the next checkpointing and publishing, as the weight ids and weight embeddings are
+    # updated during training in the backend k/v store.
+    def _init_sharded_split_embedding_weights(
+        self, prefix: str = "", force_regenerate: bool = False
+    ) -> None:
+        if not force_regenerate and self._split_weights_res is not None:
+            return
+
+        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
+            no_snapshot=False,
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            none_throws(
+                none_throws(
+                    emb_table.local_metadata,
+                    f"local_metadata is None for emb_table: {emb_table.name}",
+                ).placement,
+                f"placement is None for local_metadata of emb table: {emb_table.name}",
+            )._device = torch.device("cpu")
+
+        pmt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            pmt_list,
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+        weight_id_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            weight_ids_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+        )
+        bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy,
+            bucket_cnt_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
+            use_param_size_as_rows=True,
+        )
+        # pyre-ignore
+        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        assert (
+            len(pmt_sharded_t_list)
+            == len(weight_id_sharded_t_list)
+            == len(bucket_cnt_sharded_t_list)
+        )
+        self._split_weights_res = (
+            pmt_sharded_t_list,
+            weight_id_sharded_t_list,
+            bucket_cnt_sharded_t_list,
+        )
+
+    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
+        Tuple[
+            str,
+            Union[ShardedTensor, PartiallyMaterializedTensor],
+            Optional[ShardedTensor],
+            Optional[ShardedTensor],
+        ]
+    ]:
+        """
+        Return an iterator over embedding tables, yielding for each table:
+            the table name,
+            a PMT for the embedding table with a valid RocksDB snapshot to support tensor IO,
+            an optional ShardedTensor for weight_id,
+            an optional ShardedTensor for bucket_cnt.
+        """
+        self._init_sharded_split_embedding_weights()
+        # pyre-ignore[16]
+        self._optim.set_sharded_embedding_weight_ids(self._split_weights_res[1])
+
+        pmt_sharded_t_list = self._split_weights_res[0]
+        weight_id_sharded_t_list = self._split_weights_res[1]
+        bucket_cnt_sharded_t_list = self._split_weights_res[2]
+        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+            table_config = self._config.embedding_tables[table_idx]
+            key = append_prefix(prefix, f"{table_config.name}")
+
+            yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                table_idx
+            ], bucket_cnt_sharded_t_list[table_idx]
+
+    def flush(self) -> None:
+        """
+        Flush the embeddings in cache back to SSD. Should be pretty expensive.
+        """
+        self.emb_module.flush()
+
+    def purge(self) -> None:
+        """
+        Reset the cache space. This is needed when we load state dict.
+        """
+        # TODO: move the following to SSD TBE.
+        self.emb_module.lxu_cache_weights.zero_()
+        self.emb_module.lxu_cache_state.fill_(-1)
+
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        """
+        Create a RocksDB checkpoint. This is needed before we call state_dict() for publish.
+        """
+        self.emb_module.create_rocksdb_hard_link_snapshot()
+
+    # pyre-ignore [15]
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
+        Union[List[PartiallyMaterializedTensor], List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+    ]:
+        return self.emb_module.split_embedding_weights(no_snapshot, should_flush)
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        # reset split weights during training
+        self._split_weights_res = None
+        self._optim.set_sharded_embedding_weight_ids(sharded_embedding_weight_ids=None)
+
+        return super().forward(features)
+
+
class KeyValueEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor], FusedOptimizerModule):
    def __init__(
        self,
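
To make the bucket arithmetic in get_sharded_local_buckets above concrete, here is a standalone sketch of the same computation, with world_size and local_rank passed in explicitly instead of being read from the process group (the function name and arguments here are illustrative, not part of the diff):

    from typing import Tuple


    def local_bucket_spec(
        num_embeddings: int,
        total_num_buckets: int,
        world_size: int,
        local_rank: int,
    ) -> Tuple[int, int, int]:
        # same preconditions the diff asserts
        assert total_num_buckets % world_size == 0
        assert num_embeddings % total_num_buckets == 0
        buckets_per_rank = total_num_buckets // world_size
        bucket_offset_start = buckets_per_rank * local_rank
        bucket_offset_end = min(total_num_buckets, buckets_per_rank * (local_rank + 1))
        # ceil-division, although the assertion above makes it exact here
        bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
        return bucket_offset_start, bucket_offset_end, bucket_size


    # e.g. 1024 rows, 8 buckets, 4 ranks: rank 1 owns buckets [2, 4) of 128 rows each
    assert local_bucket_spec(1024, 8, 4, 1) == (2, 4, 128)
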
@@ -1901,6 +2226,11 @@ def __init__(
        assert (
            len({table.embedding_dim for table in config.embedding_tables}) == 1
        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+        for table in config.embedding_tables:
+            assert table.local_cols % 4 == 0, (
+                f"table {table.name} has local_cols={table.local_cols} "
+                "not divisible by 4. "
+            )

        ssd_tbe_params = _populate_ssd_tbe_params(config)
        compute_kernel = config.embedding_tables[0].compute_kernel
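
A hedged sketch of how the checkpoint-time methods added in this diff might be driven, assuming `module` is an already-constructed ZeroCollisionKeyValueEmbeddingBag (construction and the exact call order used by the sharded collection are not shown in this diff and are assumed here):

    def publish_snapshot(module) -> None:
        # flush cached rows to the SSD backend; the diff notes this can be expensive
        module.flush()
        # take a RocksDB hard-link checkpoint before reading state for publish
        module.create_rocksdb_hard_link_snapshot()
        # iterate tables: (name, PMT with snapshot, optional weight ids, optional bucket counts)
        for name, pmt, weight_ids, bucket_cnt in (
            module.get_named_split_embedding_weights_snapshot()
        ):
            print(name, type(pmt).__name__, weight_ids is not None, bucket_cnt is not None)
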