Commit abbc3e4

faran928 authored and facebook-github-bot committed
Support ssd device propagation in TorchRec for RecSys Inference
Summary: For RecSys inference when tables are offloaded onto SSD:

1. Specify and propagate the tables to be offloaded to SSD in TorchRec via FUSED_PARAMS.
2. Continue using torch.device("cpu") as the compute device while using a separate input/output dist for SSD (the SSD kernel, EmbeddingDB, is different from the CPU kernel) by creating a new device group for SSD.

Will also rename device_type_from_sharding_info to storage_device_type_from_sharding_info to make its meaning clearer.

Differential Revision: D74378974
1 parent 949278c commit abbc3e4
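For context, a minimal sketch of how a caller might populate the new fused param when sharding a quantized model for inference. Only the FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST constant comes from this commit; the table names and the idea of passing this dict as a sharder's fused_params are illustrative assumptions, not part of the diff:

# Hedged sketch: only the constant below is introduced by this commit; the table
# names and the surrounding sharder wiring are illustrative.
from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST

fused_params = {
    # Tables listed here keep torch.device("cpu") as the compute device but get
    # routed to the new "ssd" device group for input/output dist and kernels.
    FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: ["t_user_history", "t_item_ids"],
}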

File tree

8 files changed, +90 -29 lines


torchrec/distributed/embedding.py

Lines changed: 16 additions & 3 deletions
@@ -45,6 +45,10 @@
     ShardedEmbeddingModule,
     ShardingType,
 )
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
+    FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
+)
 from torchrec.distributed.sharding.cw_sequence_sharding import (
     CwSequenceEmbeddingSharding,
 )
@@ -184,9 +188,16 @@ def create_sharding_infos_by_sharding_device_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        device_group: TypeUnion[str, Tuple[str, ...]] = (
-            get_device_from_parameter_sharding(parameter_sharding)
-        )
+        # if a table name is overridden to be offloaded to ssd storage for inference
+        # update the device group accordingly
+        if fused_params and table_name in fused_params.get(
+            FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}
+        ):
+            device_group: TypeUnion[str, Tuple[str, ...]] = "ssd"
+        else:
+            device_group: TypeUnion[str, Tuple[str, ...]] = (
+                get_device_from_parameter_sharding(parameter_sharding)
+            )
         if (
             parameter_sharding.sharding_type,
             device_group,
@@ -214,6 +225,8 @@ def create_sharding_infos_by_sharding_device_group(
             per_table_fused_params, parameter_sharding
         )
         per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+        if device_group == "ssd":
+            per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})

         sharding_type_device_group_to_sharding_infos[
             (parameter_sharding.sharding_type, device_group)
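The override above reduces to a membership check against the new fused param. A standalone sketch of that behavior; resolve_device_group is a hypothetical helper for illustration, not a function added by this diff:

FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"

def resolve_device_group(table_name, fused_params, default_device_group):
    # Tables explicitly listed for SSD placement join the "ssd" device group;
    # all other tables keep the device derived from their parameter sharding.
    if fused_params and table_name in fused_params.get(
        FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}
    ):
        return "ssd"
    return default_device_group

ssd_tables = {FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: ["t_big"]}
assert resolve_device_group("t_big", ssd_tables, "cpu") == "ssd"
assert resolve_device_group("t_small", ssd_tables, "cpu") == "cpu"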

torchrec/distributed/embedding_sharding.py

Lines changed: 5 additions & 1 deletion
@@ -34,6 +34,7 @@
     ListOfKJTList,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST
 from torchrec.distributed.types import (
     Awaitable,
     EmbeddingEvent,
@@ -420,7 +421,7 @@ def _get_grouping_fused_params(
 ) -> Optional[Dict[str, Any]]:
     """
     Only shallow copy the fused params we need for grouping tables into TBEs. In
-    particular, we do not copy cache_load_factor.
+    particular, we do not copy cache_load_factor or ssd embedding table list.
     """
     grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)

@@ -430,6 +431,9 @@ def _get_grouping_fused_params(
         if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
             del grouping_fused_params[CACHE_LOAD_FACTOR_STR]

+        if FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST in grouping_fused_params:
+            del grouping_fused_params[FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST]
+
         if grouping_fused_params.get(USE_ONE_TBE_PER_TABLE, False):
             # Replace with unique value to force it into singleton group.
             # Name is used as unique value so we won't group multiple shard belonging
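Why the key is stripped here: grouping tables into TBEs should not be affected by routing-only fused params, so they are dropped from the shallow copy used for grouping. A hedged sketch, assuming the cache_load_factor key string and mirroring the new constant; the helper itself is illustrative, not TorchRec's _get_grouping_fused_params:

import copy

CACHE_LOAD_FACTOR_STR = "cache_load_factor"  # assumed value of the existing key
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"

def grouping_fused_params(fused_params):
    # Shallow copy, then drop keys that must not influence TBE grouping.
    grouping = copy.copy(fused_params)
    if grouping:
        grouping.pop(CACHE_LOAD_FACTOR_STR, None)
        grouping.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, None)
    return grouping

print(grouping_fused_params(
    {CACHE_LOAD_FACTOR_STR: 0.2, FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: ["t_big"], "output_dtype": "fp16"}
))  # -> {'output_dtype': 'fp16'}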

torchrec/distributed/embeddingbag.py

Lines changed: 16 additions & 1 deletion
@@ -51,6 +51,10 @@
     KJTList,
     ShardedEmbeddingModule,
 )
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
+    FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
+)
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -227,7 +231,16 @@ def create_sharding_infos_by_sharding_device_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        device_group = get_device_from_parameter_sharding(parameter_sharding)
+        # if a table name is overridden to be offloaded to ssd storage for inference
+        # update the device group accordingly
+        if fused_params and table_name in fused_params.get(
+            FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}
+        ):
+            device_group: Union[str, Tuple[str, ...]] = "ssd"
+        else:
+            device_group: Union[str, Tuple[str, ...]] = (
+                get_device_from_parameter_sharding(parameter_sharding)
+            )

         if (
             parameter_sharding.sharding_type,
@@ -257,6 +270,8 @@ def create_sharding_infos_by_sharding_device_group(
             per_table_fused_params, parameter_sharding
         )
         per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+        if device_group == "ssd":
+            per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})

         sharding_type_device_group_to_sharding_infos[
             (parameter_sharding.sharding_type, device_group)
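Once a table's device group resolves to "ssd", its per-table fused params are tagged with the new boolean so downstream inference sharding can detect SSD placement without carrying the full placement list. A minimal sketch; tag_ssd_placement is a hypothetical helper, only the constant value mirrors the diff:

FUSED_PARAM_IS_SSD_TABLE_PLACEMENT = "__register_is_ssd_table_placement"

def tag_ssd_placement(per_table_fused_params, device_group):
    # Mirror of the two-line addition above: mark SSD-placed tables in place.
    if device_group == "ssd":
        per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})
    return per_table_fused_params

assert tag_ssd_placement({}, "ssd") == {FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True}
assert tag_ssd_placement({}, "cpu") == {}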

torchrec/distributed/fused_params.py

Lines changed: 7 additions & 0 deletions
@@ -28,6 +28,11 @@
 # with certain ways to split models.
 FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"

+# List of cpu embedding tables offloaded to ssd to scale the embedding table size
+FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: str = "__register_ssd_table_placement_list"
+# Bool param per table to check if the table is offloaded to SSD
+FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: str = "__register_is_ssd_table_placement"
+

 class TBEToRegisterMixIn:
     def get_tbes_to_register(
@@ -111,5 +116,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
     if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)
+    if FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST)

     return fused_params_for_tbe
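The placement list is a routing hint rather than a TBE kernel argument, so tbe_fused_params strips it before the params are handed to the kernel. A hedged sketch of that filtering; filter_for_tbe is a hypothetical stand-in and the second dict key is illustrative:

FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"

def filter_for_tbe(fused_params):
    # Copy, then drop keys the TBE constructor does not understand.
    fused_params_for_tbe = dict(fused_params or {})
    fused_params_for_tbe.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, None)
    return fused_params_for_tbe

print(filter_for_tbe({FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: ["t_big"], "output_dtype": "fp16"}))
# -> {'output_dtype': 'fp16'}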

torchrec/distributed/quant_embedding.py

Lines changed: 27 additions & 15 deletions
@@ -47,6 +47,7 @@
     ShardingType,
 )
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
     get_tbes_to_register_from_iterable,
@@ -173,12 +174,19 @@ def get_device_from_parameter_sharding(
 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
 ) -> Union[str, Tuple[str, ...]]:
-    res = list(
-        {
-            get_device_from_parameter_sharding(ps.param_sharding)
-            for ps in emb_shard_infos
-        }
-    )
+    res_set = set()
+    for emb_shard_info in emb_shard_infos:
+        if (
+            emb_shard_info.fused_params
+            and FUSED_PARAM_IS_SSD_TABLE_PLACEMENT in emb_shard_info.fused_params
+            and emb_shard_info.fused_params[FUSED_PARAM_IS_SSD_TABLE_PLACEMENT]
+        ):
+            res_set.add("ssd")
+        else:
+            res_set.add(
+                get_device_from_parameter_sharding(emb_shard_info.param_sharding)
+            )
+    res = list(res_set)
     assert len(res) == 1, "All shards should be on the same type of device"
     return res[0]

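Behaviorally, the rewrite above resolves one storage device type for a group of shards, letting the per-table SSD flag override the device derived from parameter sharding while still requiring the result to be homogeneous. A standalone sketch with a stand-in ShardInfo type (not TorchRec's EmbeddingShardingInfo); only the constant value mirrors the diff:

from dataclasses import dataclass, field
from typing import Any, Dict

FUSED_PARAM_IS_SSD_TABLE_PLACEMENT = "__register_is_ssd_table_placement"

@dataclass
class ShardInfo:
    device_type: str
    fused_params: Dict[str, Any] = field(default_factory=dict)

def get_storage_device_type(shard_infos):
    # The SSD flag wins over the parameter-sharding device; mixed results are rejected.
    res = {
        "ssd"
        if info.fused_params.get(FUSED_PARAM_IS_SSD_TABLE_PLACEMENT, False)
        else info.device_type
        for info in shard_infos
    }
    assert len(res) == 1, "All shards should be on the same type of device"
    return next(iter(res))

ssd_shards = [ShardInfo("cpu", {FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})] * 2
print(get_storage_device_type(ssd_shards))                            # -> ssd
print(get_storage_device_type([ShardInfo("cpu"), ShardInfo("cpu")]))  # -> cpu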
@@ -201,11 +209,11 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+    storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
         get_device_from_sharding_infos(sharding_infos)
     )

-    if device_type_from_sharding_infos in ["cuda", "mtia"]:
+    if storage_device_type_from_sharding_infos in ["cuda", "mtia"]:
         if sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.COLUMN_WISE.value:
@@ -215,31 +223,31 @@ def create_infer_embedding_sharding(
                 sharding_infos=sharding_infos,
                 env=env,
                 device=device,
-                device_type_from_sharding_infos=device_type_from_sharding_infos,
+                device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
             )
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+                f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
             )
-    elif device_type_from_sharding_infos == "cpu" or isinstance(
-        device_type_from_sharding_infos, tuple
+    elif storage_device_type_from_sharding_infos in ["cpu", "ssd"] or isinstance(
+        storage_device_type_from_sharding_infos, tuple
     ):
         if sharding_type == ShardingType.ROW_WISE.value:
             return InferRwSequenceEmbeddingSharding(
                 sharding_infos=sharding_infos,
                 env=env,
                 device=device,
-                device_type_from_sharding_infos=device_type_from_sharding_infos,
+                device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
             )
         elif sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+                f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
            )
     else:
         raise ValueError(
-            f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+            f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
         )


@@ -542,6 +550,10 @@ def __init__(
             module, table_name_to_parameter_sharding, fused_params
         )

+        for x, y in self._sharding_type_device_group_to_sharding_infos.items():
+            print(f"SHARDING INFO: {x}")
+            print("=========================")
+
         self._sharding_type_device_group_to_sharding: Dict[
             Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
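The net effect of the rename plus the new branch in create_infer_embedding_sharding is that "ssd" rides the same row-wise/table-wise inference path as "cpu", since compute stays on CPU and only the storage backend differs. A hedged sketch of that dispatch; the helper and its supported-set encoding are illustrative, not the TorchRec factory or its sharding classes:

def pick_infer_sequence_sharding(sharding_type, storage_device_type):
    # Mirrors the branching above: cuda/mtia support table-wise and column-wise;
    # cpu and ssd (or a heterogeneous tuple of device types) support row-wise and
    # table-wise; anything else is rejected.
    if storage_device_type in ("cuda", "mtia"):
        supported = {"table_wise", "column_wise"}
    elif storage_device_type in ("cpu", "ssd") or isinstance(storage_device_type, tuple):
        supported = {"row_wise", "table_wise"}
    else:
        supported = set()
    if sharding_type not in supported:
        raise ValueError(
            f"Sharding type not supported {sharding_type} for {storage_device_type} sharding"
        )
    return (sharding_type, storage_device_type)

print(pick_infer_sequence_sharding("row_wise", "ssd"))  # -> ('row_wise', 'ssd')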

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 from torchrec.distributed.fused_params import (
     fused_param_bounds_check_mode,
     fused_param_lengths_to_offsets_lookup,
+    FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
     is_fused_param_quant_state_dict_split_scale_bias,
     is_fused_param_register_tbe,
     tbe_fused_params,

torchrec/distributed/quant_embeddingbag.py

Lines changed: 14 additions & 8 deletions
@@ -35,6 +35,7 @@
     create_sharding_infos_by_sharding_device_group,
 )
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
     get_tbes_to_register_from_iterable,
@@ -97,12 +98,17 @@ def get_device_from_parameter_sharding(
 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
 ) -> Union[str, Tuple[str, ...]]:
-    res = list(
-        {
-            get_device_from_parameter_sharding(ps.param_sharding)
-            for ps in emb_shard_infos
-        }
-    )
+    res_set = set()
+    for emb_shard_info in emb_shard_infos:
+        if emb_shard_info.fused_params and emb_shard_info.fused_params.get(
+            FUSED_PARAM_IS_SSD_TABLE_PLACEMENT, False
+        ):
+            res_set.add("ssd")
+        else:
+            res_set.add(
+                get_device_from_parameter_sharding(emb_shard_info.param_sharding)
+            )
+    res = list(res_set)
     assert len(res) == 1, "All shards should be on the same type of device"
     return res[0]

@@ -131,7 +137,7 @@ def create_infer_embedding_bag_sharding(
     NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor
 ]:
     propogate_device: bool = get_propogate_device()
-    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+    storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
         get_device_from_sharding_infos(sharding_infos)
     )
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -143,7 +149,7 @@ def create_infer_embedding_bag_sharding(
             sharding_infos,
             env,
             device=device if propogate_device else None,
-            device_type_from_sharding_infos=device_type_from_sharding_infos,
+            device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
         )
     elif sharding_type == ShardingType.COLUMN_WISE.value:
         return InferCwPooledEmbeddingSharding(

torchrec/distributed/sharding/rw_sequence_sharding.py

Lines changed: 4 additions & 1 deletion
@@ -214,6 +214,9 @@ def forward(
             # using _device_type_from_sharding_infos to iterate on local_embs list as
             # that's a better practice.
             for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                assert (
+                    device_type != "ssd"
+                ), "Heterogeneous sharding across multiple storage device types for a single table is not supported for the ssd storage device type"
                 if device_type != "cpu":
                     non_cpu_local_embs.append(
                         _get_batching_hinted_output(
@@ -235,7 +238,7 @@ def forward(
                     result.append(non_cpu_local_embs_dist[index])
                     index += 1
             return result
-        elif self._device_type_from_sharding_infos == "cpu":
+        elif self._device_type_from_sharding_infos in ["cpu", "ssd"]:
             # for cpu sharder, output dist should be a no-op
             return local_embs
         else:
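For the sequence output dist, the practical consequence is that the homogeneous "cpu" and "ssd" storage device types are both treated as pass-through, while a heterogeneous tuple of device types takes the dist path (and, per the new assert, may not include "ssd" at all). A tiny sketch of that predicate; the helper is hypothetical and not part of the diff:

def output_dist_is_noop(device_type_from_sharding_infos):
    # Homogeneous cpu/ssd storage: embeddings are already local, nothing to dist.
    # A tuple means heterogeneous sharding and takes the dist path instead.
    return device_type_from_sharding_infos in ["cpu", "ssd"]

assert output_dist_is_noop("ssd")
assert output_dist_is_noop("cpu")
assert not output_dist_is_noop(("cpu", "cuda"))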
