diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 968d7aa15..6e5daaaef 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -242,9 +242,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ) ssd_tbe_params["cache_sets"] = int(max_cache_sets) - if "kvzch_eviction_trigger_mode" in fused_params and config.is_using_virtual_table: - ssd_tbe_params["kvzch_eviction_trigger_mode"] = fused_params.get( - "kvzch_eviction_trigger_mode" + if "kvzch_eviction_tbe_config" in fused_params and config.is_using_virtual_table: + ssd_tbe_params["kvzch_eviction_tbe_config"] = fused_params.get( + "kvzch_eviction_tbe_config" ) ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables] @@ -337,11 +337,25 @@ def _populate_zero_collision_tbe_params( eviction_strategy = -1 table_names = [table.name for table in config.embedding_tables] l2_cache_size = tbe_params["l2_cache_size"] - if "kvzch_eviction_trigger_mode" in tbe_params: - eviction_trigger_mode = tbe_params["kvzch_eviction_trigger_mode"] - tbe_params.pop("kvzch_eviction_trigger_mode") - else: - eviction_trigger_mode = 2 # 2 means mem_util based eviction + + assert ( + "kvzch_eviction_tbe_config" in tbe_params + ), "kvzch_eviction_tbe_config should be in tbe_params" + eviction_tbe_config = tbe_params["kvzch_eviction_tbe_config"] + tbe_params.pop("kvzch_eviction_tbe_config") + eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode + eviction_free_mem_threshold_gb = ( + eviction_tbe_config.eviction_free_mem_threshold_gb + ) + eviction_free_mem_check_interval_batch = ( + eviction_tbe_config.eviction_free_mem_check_interval_batch + ) + threshold_calculation_bucket_stride = ( + eviction_tbe_config.threshold_calculation_bucket_stride + ) + threshold_calculation_bucket_num = ( + eviction_tbe_config.threshold_calculation_bucket_num + ) for i, table in enumerate(config.embedding_tables): policy_t = table.virtual_table_eviction_policy if policy_t is not None: @@ -421,6 +435,10 @@ def _populate_zero_collision_tbe_params( training_id_keep_count=training_id_keep_count, l2_weight_thresholds=l2_weight_thresholds, meta_header_lens=meta_header_lens, + eviction_free_mem_threshold_gb=eviction_free_mem_threshold_gb, + eviction_free_mem_check_interval_batch=eviction_free_mem_check_interval_batch, + threshold_calculation_bucket_stride=threshold_calculation_bucket_stride, + threshold_calculation_bucket_num=threshold_calculation_bucket_num, ) else: eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens) @@ -1768,6 +1786,7 @@ def __init__( feature_table_map=self._feature_table_map, ssd_cache_location=embedding_location, pooling_mode=PoolingMode.NONE, + pg=pg, **ssd_tbe_params, ).to(device) @@ -2000,6 +2019,7 @@ def __init__( ssd_cache_location=embedding_location, pooling_mode=PoolingMode.NONE, backend_type=backend_type, + pg=pg, **ssd_tbe_params, ).to(device) @@ -2680,6 +2700,7 @@ def __init__( feature_table_map=self._feature_table_map, ssd_cache_location=embedding_location, pooling_mode=self._pooling, + pg=pg, **ssd_tbe_params, ).to(device) @@ -2900,6 +2921,7 @@ def __init__( ssd_cache_location=embedding_location, pooling_mode=self._pooling, backend_type=backend_type, + pg=pg, **ssd_tbe_params, ).to(device) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 01dfe3acf..0fb74665e 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -33,6 +33,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, CacheAlgorithm, + KVZCHEvictionTBEConfig, MultiPassPrefetchConfig, ) @@ -667,7 +668,7 @@ class KeyValueParams: lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings - kvzch_eviction_trigger_mode: Optional[int]: eviction trigger mode for KVZCH + kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig]: KVZCH eviction config for TBE # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -693,7 +694,7 @@ class KeyValueParams: None # enable raw embedding streaming for SSD TBE ) res_store_shards: Optional[int] = None # shards to store the raw embeddings - kvzch_eviction_trigger_mode: Optional[int] = None # eviction trigger mode for KVZCH + kvzch_eviction_tbe_config: Optional[KVZCHEvictionTBEConfig] = None # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -722,7 +723,7 @@ def __hash__(self) -> int: self.lazy_bulk_init_enabled, self.enable_raw_embedding_streaming, self.res_store_shards, - self.kvzch_eviction_trigger_mode, + self.kvzch_eviction_tbe_config, ) )