Commit 58ee896

ADD: self-attention model (#281)
1 parent e10c08a commit 58ee896

17 files changed: +1655 -72 lines changed

README.md

Lines changed: 5 additions & 3 deletions
@@ -122,6 +122,7 @@ Here are two tutorials published on Medium that can help you
 | Shopper | [![alt text](docs/illustrations/logos/jupyter_logo.png)](notebooks/basket_models/shopper.ipynb)   [![Open In Colab](https://img.shields.io/badge/-grey?logo=googlecolab)](https://colab.research.google.com/github/artefactory/choice-learn/blob/main/notebooks/basket_models/shopper.ipynb) | Ruiz et al. [[16]](#trident-references)           | *Shopper* | [#](https://artefactory.github.io/choice-learn/references/basket_models/references_shopper/) |
 | Alea Carta | [![alt text](docs/illustrations/logos/jupyter_logo.png)](notebooks/basket_models/alea_carta.ipynb)   [![Open In Colab](https://img.shields.io/badge/-grey?logo=googlecolab)](https://colab.research.google.com/github/artefactory/choice-learn/blob/main/notebooks/basket_models/alea_carta.ipynb) | Désir et al. [[17]](#trident-references) | *AleaCarta* | [#](https://artefactory.github.io/choice-learn/references/basket_models/references_alea_carta/) |
 | Base Attention | [![alt text](docs/illustrations/logos/jupyter_logo.png)](notebooks/basket_models/basic_attention.ipynb)   [![Open In Colab](https://img.shields.io/badge/-grey?logo=googlecolab)](https://colab.research.google.com/github/artefactory/choice-learn/blob/main/notebooks/basket_models/basic_attention.ipynb) | Wang et al. [[18]](#trident-references) | *AttentionBasedContextEmbedding* | [#]() |
+| Self Attention | [![alt text](docs/illustrations/logos/jupyter_logo.png)](notebooks/basket_models/self_attention.ipynb)   [![Open In Colab](https://img.shields.io/badge/-grey?logo=googlecolab)](https://colab.research.google.com/github/artefactory/choice-learn/blob/main/notebooks/basket_models/self_attention.ipynb) | Zhang et al. [[20]](#trident-references) | *SelfAttentionModel* | [#]() |
 
 
 ### Data
@@ -377,8 +378,8 @@ The use of this software is under the MIT license, with no limitation of usage,
 [16] [SHOPPER: A Probabilistic Model of Consumer Choice with Substitutes and Complements](https://arxiv.org/abs/1711.03560), Ruiz, F. J. R.; Athey, S.; Blei, D. M. (2019)\
 [17] [Better Capturing Interactions between Products in Retail: Revisited Negative Sampling for Basket Choice Modeling](https://ojs.aaai.org/index.php/AAAI/article/view/11851), Désir, J.; Auriau, V.; Možina, M.; Malherbe, E. (2025), ECML PKDD\
 [18] [Attention-based Transactional Context Embedding for Next-Item Recommendation](https://ojs.aaai.org/index.php/AAAI/article/view/11851), Wang, S.; Hu, L.; Cao, L.; Huang, X.; Lian, D.; Liu, W. (2018)\
-[19] [A Discrete Choice Model for Subset Selection.](https://www.cs.cornell.edu/~arb/papers/higher-order-choice-wsdm-2018.pdf), Benson, A.; Kumar, R.; Tomkins, A. (2018)
-
+[19] [A Discrete Choice Model for Subset Selection.](https://www.cs.cornell.edu/~arb/papers/higher-order-choice-wsdm-2018.pdf), Benson, A.; Kumar, R.; Tomkins, A. (2018)\
+[20] [Next Item Recommendation with Self-Attention.](https://recnlp2019.github.io/papers/RecNLP2019_paper_21.pdf), Zhang, S.; Yao, L.; Tay, Y.; Sun, A. (2018)\
 ### Code and Repositories
 
 *Official models implementations:*
@@ -388,4 +389,5 @@ The use of this software is under the MIT license, with no limitation of usage,
 [12] [ResLogit](https://github.com/LiTrans/reslogit)\
 [13] [Learning-MNL](https://github.com/BSifringer/EnhancedDCM)\
 [16] [Shopper](https://github.com/franrruiz/shopper-src)\
-[17] [AleaCarta](https://github.com/artefactory/alea-carta-est)
+[17] [AleaCarta](https://github.com/artefactory/alea-carta-est)\
+[20] [SelfAttention](https://github.com/artefactory/rd-self-attentive)
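
The new README row points to notebooks/basket_models/self_attention.ipynb and the *SelfAttentionModel* class. As a quick orientation, here is a hypothetical usage sketch: the class name comes from the table above, but the import path, constructor arguments, and fit call are assumptions rather than the confirmed API (the notebook is the reference walkthrough).

```python
# Hypothetical sketch only: the import path and every argument are assumptions,
# mirroring the other basket models (Shopper, AleaCarta) in choice-learn.
from choice_learn.basket_models import SelfAttentionModel  # assumed export path

model = SelfAttentionModel(
    latent_sizes={"preferences": 8},  # assumed embedding size, as in other basket models
    n_negative_samples=2,             # assumed negative-sampling setting
    lr=1e-3,                          # assumed optimizer settings
    epochs=50,
)

# `trip_dataset` stands for a choice-learn basket dataset; the fit call is assumed:
# history = model.fit(trip_dataset)
```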

choice_learn/basket_models/alea_carta.py

Lines changed: 17 additions & 5 deletions
@@ -34,6 +34,7 @@ def __init__(
         weight_decay: Union[float, None] = None,
         momentum: float = 0.0,
         epsilon_price: float = 1e-5,
+        l2_regularization: float = 0.0,
         **kwargs,
     ) -> None:
         """Initialize the AleaCarta model.
@@ -79,6 +80,7 @@ def __init__(
         self.item_intercept = item_intercept
         self.price_effects = price_effects
         self.seasonal_effects = seasonal_effects
+        self.l2_regularization = l2_regularization
 
         if "preferences" not in latent_sizes.keys():
             logging.warning(
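
The coefficient stored here is what later scales the ridge penalty added to the loss in compute_batch_loss (last hunk below). A hedged instantiation sketch, where the import path and the remaining argument values are illustrative rather than taken from this commit:

```python
# Sketch: enabling the new ridge penalty on an AleaCarta model.
# The export path is assumed; weight_decay, momentum and epsilon_price appear in
# the signature above, and their values here are only illustrative defaults.
from choice_learn.basket_models import AleaCarta  # assumed export path

model = AleaCarta(
    weight_decay=None,
    momentum=0.0,
    epsilon_price=1e-5,
    l2_regularization=0.01,  # new argument: strength of the L2 penalty on trainable weights
)
```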
@@ -265,6 +267,7 @@ def compute_batch_utility(
         week_batch: np.ndarray,
         price_batch: np.ndarray,
         available_item_batch: np.ndarray,
+        user_batch: Union[np.ndarray, tf.Tensor],
     ) -> tf.Tensor:
         """Compute the utility of all the items in item_batch given the items in basket_batch.
 
@@ -297,6 +300,7 @@ def compute_batch_utility(
             Utility of all the items in item_batch
             Shape must be (batch_size,)
         """
+        _ = user_batch
         _ = available_item_batch
         item_batch = tf.cast(item_batch, dtype=tf.int32)
         if len(tf.shape(item_batch)) == 1:
@@ -448,7 +452,6 @@ def compute_basket_utility(
             [np.delete(basket, i) for i in range(len_basket)]
         ) # Shape: (len_basket, len(basket) - 1)
 
-        # Basket utility = sum of the utilities of the items in the basket
         return tf.reduce_sum(
             self.compute_batch_utility(
                 item_batch=basket,
@@ -457,6 +460,7 @@ def compute_basket_utility(
                 week_batch=np.array([week] * len_basket),
                 price_batch=prices,
                 available_item_batch=available_item_batch,
+                user_batch=None,
             )
         ).numpy()
 
@@ -536,6 +540,7 @@ def compute_batch_loss(
         week_batch: np.ndarray,
         price_batch: np.ndarray,
         available_item_batch: np.ndarray,
+        user_batch: np.ndarray,
     ) -> tuple[tf.Variable]:
         """Compute log-likelihood and loss for one batch of items.
 
@@ -576,6 +581,7 @@ def compute_batch_loss(
             Approximated by difference of utilities between positive and negative samples
             Shape must be (1,)
         """
+        _ = user_batch
         _ = future_batch
         batch_size = len(item_batch)
         item_batch = tf.cast(item_batch, dtype=tf.int32)
@@ -593,7 +599,6 @@ def compute_batch_loss(
             ],
             axis=0,
         )
-
         augmented_item_batch = tf.cast(
             tf.concat([tf.expand_dims(item_batch, axis=-1), negative_samples], axis=1),
             dtype=tf.int32,
@@ -612,7 +617,8 @@ def compute_batch_loss(
             week_batch=week_batch,
             price_batch=augmented_price_batch,
             available_item_batch=available_item_batch,
-        )
+            user_batch=None,
+        ) # Shape: (batch_size * (n_negative_samples + 1),)
 
         positive_samples_utilities = tf.gather(params=all_utilities, indices=[0], axis=1)
         negative_samples_utilities = tf.gather(
@@ -645,6 +651,12 @@ def compute_batch_loss(
             ),
             output=tf.nn.sigmoid(all_utilities),
         ) # Shape: (batch_size * (n_negative_samples + 1),)
-
+        ridge_regularization = self.l2_regularization * tf.add_n(
+            [tf.nn.l2_loss(weight) for weight in self.trainable_weights]
+        )
         # Normalize by the batch size and the number of negative samples
-        return tf.reduce_sum(bce) / (batch_size * self.n_negative_samples), loglikelihood
+        return (
+            tf.reduce_sum(bce + ridge_regularization)
+            / (batch_size * (self.n_negative_samples + 1)),
+            loglikelihood,
+        )
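
To make the new loss term concrete, the sketch below reproduces the ridge computation outside the model class: tf.nn.l2_loss sums half the squared entries of each weight tensor, the penalties are added up and scaled by l2_regularization, and the result is combined with the per-sample binary cross-entropy before normalizing by batch_size * (n_negative_samples + 1). The tensors stand in for self.trainable_weights and bce and are illustrative only.

```python
# Standalone illustration of the ridge penalty introduced in compute_batch_loss.
import tensorflow as tf

l2_regularization = 0.01
# Stand-ins for self.trainable_weights (illustrative shapes and values).
weights = [tf.Variable(tf.random.normal((4, 3))), tf.Variable(tf.random.normal((3,)))]

# tf.nn.l2_loss(w) = sum(w ** 2) / 2 for each weight tensor.
ridge_regularization = l2_regularization * tf.add_n(
    [tf.nn.l2_loss(w) for w in weights]
)

# Stand-in for the per-sample binary cross-entropy over positive and negative samples.
batch_size, n_negative_samples = 2, 2
bce = tf.random.uniform((batch_size * (n_negative_samples + 1),))

loss = tf.reduce_sum(bce + ridge_regularization) / (batch_size * (n_negative_samples + 1))
```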
