Skip to content

Commit 785253b

Browse files
EddyLXJ authored and facebook-github-bot committed
Free mem trigger with all2all for sync trigger eviction (#3490)
Summary: X-link: pytorch/FBGEMM#5062 X-link: facebookresearch/FBGEMM#2070 Previously, KVZCH used the ID_COUNT and MEM_UTIL eviction trigger modes. Both are tricky to use: it is hard for model engineers to decide what value to pick for the ID count or the memory-utilization threshold. In addition, the eviction start time drifts out of sync after some time in training, which can cause a large QPS drop during eviction. This diff adds support for free-memory-triggered eviction. Every N batches, each rank checks how much free memory is left; if the free memory is below the threshold, eviction is triggered in all TBEs on all ranks using all-reduce. In this way, we can force eviction to start at the same time on all ranks. Reviewed By: emlin Differential Revision: D85604160
1 parent 4f1f62d commit 785253b

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 30 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -242,9 +242,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
242242
)
243243
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
244244

245-
if "kvzch_eviction_trigger_mode" in fused_params and config.is_using_virtual_table:
246-
ssd_tbe_params["kvzch_eviction_trigger_mode"] = fused_params.get(
247-
"kvzch_eviction_trigger_mode"
245+
if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table:
246+
ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get(
247+
"kvzch_eviction_tbe_config"
248248
)
249249

250250
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
@@ -337,11 +337,25 @@ def _populate_zero_collision_tbe_params(
337337
eviction_strategy = -1
338338
table_names = [table.name for table in config.embedding_tables]
339339
l2_cache_size = tbe_params["l2_cache_size"]
340-
if "kvzch_eviction_trigger_mode" in tbe_params:
341-
eviction_trigger_mode = tbe_params["kvzch_eviction_trigger_mode"]
342-
tbe_params.pop("kvzch_eviction_trigger_mode")
343-
else:
344-
eviction_trigger_mode = 2 # 2 means mem_util based eviction
340+
341+
assert (
342+
"kvzch_eviction_tbe_config" in tbe_params
343+
), "kvzch_eviction_tbe_config should be in tbe_params"
344+
eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"]
345+
tbe_params.pop("kvzch_eviction_tbe_config")
346+
eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode
347+
eviction_free_mem_threshold_gb = (
348+
eviction_tbe_config.eviction_free_mem_threshold_gb
349+
)
350+
eviction_free_mem_check_interval_batch = (
351+
eviction_tbe_config.eviction_free_mem_check_interval_batch
352+
)
353+
threshold_calculation_bucket_stride = (
354+
eviction_tbe_config.threshold_calculation_bucket_stride
355+
)
356+
threshold_calculation_bucket_num = (
357+
eviction_tbe_config.threshold_calculation_bucket_num
358+
)
345359
for i, table in enumerate(config.embedding_tables):
346360
policy_t = table.virtual_table_eviction_policy
347361
if policy_t is not None:
@@ -421,6 +435,10 @@ def _populate_zero_collision_tbe_params(
421435
training_id_keep_count=training_id_keep_count,
422436
l2_weight_thresholds=l2_weight_thresholds,
423437
meta_header_lens=meta_header_lens,
438+
eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb,
439+
eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch,
440+
threshold_calculation_bucket_stride=threshold_calculation_bucket_stride,
441+
threshold_calculation_bucket_num=threshold_calculation_bucket_num,
424442
)
425443
else:
426444
eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens)
@@ -1768,6 +1786,7 @@ def __init__(
17681786
feature_table_map=self._feature_table_map,
17691787
ssd_cache_location=embedding_location,
17701788
pooling_mode=PoolingMode.NONE,
1789+
pg=pg,
17711790
**ssd_tbe_params,
17721791
).to(device)
17731792

@@ -2000,6 +2019,7 @@ def __init__(
20002019
ssd_cache_location=embedding_location,
20012020
pooling_mode=PoolingMode.NONE,
20022021
backend_type=backend_type,
2022+
pg=pg,
20032023
**ssd_tbe_params,
20042024
).to(device)
20052025

@@ -2680,6 +2700,7 @@ def __init__(
26802700
feature_table_map=self._feature_table_map,
26812701
ssd_cache_location=embedding_location,
26822702
pooling_mode=self._pooling,
2703+
pg=pg,
26832704
**ssd_tbe_params,
26842705
).to(device)
26852706

@@ -2900,6 +2921,7 @@ def __init__(
29002921
ssd_cache_location=embedding_location,
29012922
pooling_mode=self._pooling,
29022923
backend_type=backend_type,
2924+
pg=pg,
29032925
**ssd_tbe_params,
29042926
).to(device)
29052927

torchrec/distributed/types.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
3131
BoundsCheckMode,
3232
CacheAlgorithm,
33+
KVZCHEvictionTBEConfig,
3334
MultiPassPrefetchConfig,
3435
)
3536

@@ -662,7 +663,7 @@ class KeyValueParams:
662663
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
663664
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
664665
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
665-
kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH
666+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE
666667
667668
# Parameter Server (PS) Attributes
668669
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -688,7 +689,7 @@ class KeyValueParams:
688689
None # enable raw embedding streaming for SSD TBE
689690
)
690691
res_store_shards: Optional[int] = None # shards to store the raw embeddings
691-
kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH
692+
kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None
692693

693694
# Parameter Server (PS) Attributes
694695
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -717,7 +718,7 @@ def __hash__(self) -> int:
717718
self.lazy_bulk_init_enabled,
718719
self.enable_raw_embedding_streaming,
719720
self.res_store_shards,
720-
self.kvzch_eviction_trigger_mode,
721+
self.kvzch_eviction_tbe_config,
721722
)
722723
)
723724

0 commit comments

Comments
 (0)