diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index e4953ac28..cbc419146 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -84,6 +84,7 @@ CountTimestampMixedEvictionPolicy, data_type_to_sparse_type, FeatureL2NormBasedEvictionPolicy, + FeatureScoreBasedEvictionPolicy, NoEvictionPolicy, pooling_type_to_pooling_mode, TimestampBasedEvictionPolicy, @@ -235,6 +236,9 @@ def _populate_zero_collision_tbe_params( counter_thresholds = [0] * len(config.embedding_tables) ttls_in_mins = [0] * len(config.embedding_tables) counter_decay_rates = [0.0] * len(config.embedding_tables) + feature_score_counter_decay_rates = [0.0] * len(config.embedding_tables) + max_training_id_num_per_table = [0] * len(config.embedding_tables) + target_eviction_percent_per_table = [0.0] * len(config.embedding_tables) l2_weight_thresholds = [0.0] * len(config.embedding_tables) eviction_strategy = -1 table_names = [table.name for table in config.embedding_tables] @@ -251,6 +255,20 @@ def _populate_zero_collision_tbe_params( raise ValueError( f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 1 for tables {table_names}" ) + elif isinstance(policy_t, FeatureScoreBasedEvictionPolicy): + feature_score_counter_decay_rates[i] = policy_t.decay_rate + max_training_id_num_per_table[i] = ( + policy_t.max_training_id_num_per_rank + ) + target_eviction_percent_per_table[i] = ( + policy_t.target_eviction_percent + ) + if eviction_strategy == -1 or eviction_strategy == 5: + eviction_strategy = 5 + else: + raise ValueError( + f"Do not support multiple eviction strategy in one tbe {eviction_strategy} and 5 for tables {table_names}" + ) elif isinstance(policy_t, TimestampBasedEvictionPolicy): ttls_in_mins[i] = policy_t.eviction_ttl_mins if eviction_strategy == -1 or eviction_strategy == 0: @@ -288,6 +306,9 @@ def _populate_zero_collision_tbe_params( counter_thresholds=counter_thresholds, ttls_in_mins=ttls_in_mins, counter_decay_rates=counter_decay_rates, + feature_score_counter_decay_rates=feature_score_counter_decay_rates, + max_training_id_num_per_table=max_training_id_num_per_table, + target_eviction_percent_per_table=target_eviction_percent_per_table, l2_weight_thresholds=l2_weight_thresholds, meta_header_lens=meta_header_lens, ) diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py index 393b1025f..7505ade9b 100644 --- a/torchrec/modules/embedding_configs.py +++ b/torchrec/modules/embedding_configs.py @@ -203,6 +203,17 @@ def __post_init__(self) -> None: self.inference_eviction_threshold = self.eviction_threshold +@dataclass +class FeatureScoreBasedEvictionPolicy(VirtualTableEvictionPolicy): + """ + Feature score based eviction policy for virtual table. + """ + + decay_rate: float = 0.99 # default decay by default #TODO: Change to real value + max_training_id_num_per_rank: int = 0 # max number of training ids per rank + target_eviction_percent: float = 0.0 # target eviction percent + + @dataclass class TimestampBasedEvictionPolicy(VirtualTableEvictionPolicy): """