Skip to content

Commit 75dfb7f

Browse files
EddyLXJmeta-codesync[bot]
authored andcommitted
Free mem trigger with all2all for sync trigger eviction (#3490)
Summary: X-link: pytorch/FBGEMM#5062 Pull Request resolved: #3490 X-link: https://github.com/facebookresearch/FBGEMM/pull/2070 Before KVZCH is using ID_COUNT and MEM_UTIL eviction trigger mode, both are very tricky and hard for model engineer to decide what num to use for the id count or mem util threshold. Besides that, the eviction start time is out of sync after some time in training, which can cause great qps drop during eviction. This diff is adding support for free memory trigger eviction. It will check how many free memory left every N batch in every rank and if free memory below the threshold, it will trigger eviction in all tbes of all ranks using all reduce. In this way, we can force the start time of eviction in all ranks. Reviewed By: emlin Differential Revision: D85604160 fbshipit-source-id: 177ec779960a4ac9bfc3d41f38beeb7e56665db8
1 parent fb8932a commit 75dfb7f

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
242242
)
243243
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
244244

245-
if "kvzch_eviction_trigger_mode" in fused_params and config.is_using_virtual_table:
246-
ssd_tbe_params["kvzch_eviction_trigger_mode"] = fused_params.get(
247-
"kvzch_eviction_trigger_mode"
245+
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table:
246+
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
247+
"kvzch_eviction_tbe_config"
248248
)
249249

250250
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
@@ -337,11 +337,25 @@ def _populate_zero_collision_tbe_params(
337337
eviction_strategy = -1
338338
table_names = [table.name for table in config.embedding_tables]
339339
l2_cache_size = tbe_params["l2_cache_size"]
340-
if "kvzch_eviction_trigger_mode" in tbe_params:
341-
eviction_trigger_mode = tbe_params["kvzch_eviction_trigger_mode"]
342-
tbe_params.pop("kvzch_eviction_trigger_mode")
343-
else:
344-
eviction_trigger_mode = 2 # 2 means mem_util based eviction
340+
341+
assert (
342+
"kvzch_eviction_tbe_config" in tbe_params
343+
), "kvzch_eviction_tbe_config should be in tbe_params"
344+
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
345+
tbe_params.pop("kvzch_eviction_tbe_config")
346+
eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
347+
eviction_free_mem_threshold_gb = (
348+
eviction_tbe_config.eviction_free_mem_threshold_gb
349+
)
350+
eviction_free_mem_check_interval_batch = (
351+
eviction_tbe_config.eviction_free_mem_check_interval_batch
352+
)
353+
threshold_calculation_bucket_stride = (
354+
eviction_tbe_config.threshold_calculation_bucket_stride
355+
)
356+
threshold_calculation_bucket_num = (
357+
eviction_tbe_config.threshold_calculation_bucket_num
358+
)
345359
for i, table in enumerate(config.embedding_tables):
346360
policy_t = table.virtual_table_eviction_policy
347361
if policy_t is not None:
@@ -421,6 +435,10 @@ def _populate_zero_collision_tbe_params(
421435
training_id_keep_count=training_id_keep_count,
422436
l2_weight_thresholds=l2_weight_thresholds,
423437
meta_header_lens=meta_header_lens,
438+
eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb,
439+
eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch,
440+
threshold_calculation_bucket_stride=threshold_calculation_bucket_stride,
441+
threshold_calculation_bucket_num=threshold_calculation_bucket_num,
424442
)
425443
else:
426444
eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens)
@@ -1768,6 +1786,7 @@ def __init__(
17681786
feature_table_map=self._feature_table_map,
17691787
ssd_cache_location=embedding_location,
17701788
pooling_mode=PoolingMode.NONE,
1789+
pg=pg,
17711790
**ssd_tbe_params,
17721791
).to(device)
17731792

@@ -2000,6 +2019,7 @@ def __init__(
20002019
ssd_cache_location=embedding_location,
20012020
pooling_mode=PoolingMode.NONE,
20022021
backend_type=backend_type,
2022+
pg=pg,
20032023
**ssd_tbe_params,
20042024
).to(device)
20052025

@@ -2680,6 +2700,7 @@ def __init__(
26802700
feature_table_map=self._feature_table_map,
26812701
ssd_cache_location=embedding_location,
26822702
pooling_mode=self._pooling,
2703+
pg=pg,
26832704
**ssd_tbe_params,
26842705
).to(device)
26852706

@@ -2900,6 +2921,7 @@ def __init__(
29002921
ssd_cache_location=embedding_location,
29012922
pooling_mode=self._pooling,
29022923
backend_type=backend_type,
2924+
pg=pg,
29032925
**ssd_tbe_params,
29042926
).to(device)
29052927

torchrec/distributed/types.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3434
BoundsCheckMode,
3535
CacheAlgorithm,
36+
KVZCHEvictionTBEConfig,
3637
MultiPassPrefetchConfig,
3738
)
3839

@@ -667,7 +668,7 @@ class KeyValueParams:
667668
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
668669
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
669670
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
670-
kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH
671+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
671672
672673
# Parameter Server (PS) Attributes
673674
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -693,7 +694,7 @@ class KeyValueParams:
693694
None # enable raw embedding streaming for SSD TBE
694695
)
695696
res_store_shards: Optional[int] = None # shards to store the raw embeddings
696-
kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH
697+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
697698

698699
# Parameter Server (PS) Attributes
699700
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -722,7 +723,7 @@ def __hash__(self) -> int:
722723
self.lazy_bulk_init_enabled,
723724
self.enable_raw_embedding_streaming,
724725
self.res_store_shards,
725-
self.kvzch_eviction_trigger_mode,
726+
self.kvzch_eviction_tbe_config,
726727
)
727728
)
728729

0 commit comments

Comments
 (0)