diff --git a/.gitignore b/.gitignore index d646eb5568..1670e78af3 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,8 @@ templates/examples/graph/* templates/**/guides/**/*.md templates/keras_hub/getting_started.md templates/keras_tuner/getting_started.md +templates/keras_rs/examples/* datasets/* .history .vscode/* -.idea/* +.idea/* \ No newline at end of file diff --git a/scripts/autogen.py b/scripts/autogen.py index a8e7d1dd0d..18f689f2b6 100644 --- a/scripts/autogen.py +++ b/scripts/autogen.py @@ -32,7 +32,7 @@ GUIDES_GH_LOCATION = Path("keras-team") / "keras-io" / "blob" / "master" / "guides" KERAS_TEAM_GH = "https://github.com/keras-team" PROJECT_URL = { - "keras": f"{KERAS_TEAM_GH}/keras/tree/v3.11.3/", + "keras": f"{KERAS_TEAM_GH}/keras/tree/v3.12.0/", "keras_tuner": f"{KERAS_TEAM_GH}/keras-tuner/tree/v1.4.7/", "keras_hub": f"{KERAS_TEAM_GH}/keras-hub/tree/v0.23.0/", "tf_keras": f"{KERAS_TEAM_GH}/tf-keras/tree/v2.19.0/", diff --git a/templates/examples/audio/ctc_asr.md b/templates/examples/audio/ctc_asr.md deleted file mode 100644 index 095e102a45..0000000000 --- a/templates/examples/audio/ctc_asr.md +++ /dev/null @@ -1,659 +0,0 @@ -# Automatic Speech Recognition using CTC - -**Authors:** [Mohamed Reda Bouadjenek](https://rbouadjenek.github.io/) and [Ngoc Dung Huynh](https://www.linkedin.com/in/parkerhuynh/)
-**Date created:** 2021/09/26
-**Last modified:** 2021/09/26
-**Description:** Training a CTC-based model for automatic speech recognition. - - -
ⓘ This example uses Keras 2
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/ctc_asr.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/ctc_asr.py)
-
-
----
-## Introduction
-
-Speech recognition is an interdisciplinary subfield of computer science
-and computational linguistics that develops methodologies and technologies
-that enable the recognition and translation of spoken language into text
-by computers. It is also known as automatic speech recognition (ASR),
-computer speech recognition or speech-to-text (STT). It incorporates
-knowledge and research from the computer science, linguistics and computer
-engineering fields.
-
-This demonstration shows how to combine a 2D CNN, an RNN and a Connectionist
-Temporal Classification (CTC) loss to build an ASR model. CTC is an algorithm
-used to train deep neural networks in speech recognition, handwriting
-recognition and other sequence problems. CTC is used when we don't know
-how the input aligns with the output (how the characters in the transcript
-align to the audio). The model we create is similar to
-[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html).
-
-We will use the LJSpeech dataset from the
-[LibriVox](https://librivox.org/) project. It consists of short
-audio clips of a single speaker reading passages from 7 non-fiction books.
-
-We will evaluate the quality of the model using
-[Word Error Rate (WER)](https://en.wikipedia.org/wiki/Word_error_rate).
-WER is obtained by adding up the substitutions, insertions, and deletions
-that occur in a sequence of recognized words, and dividing that number by
-the total number of words originally spoken. To get the WER score you need
-to install the [jiwer](https://pypi.org/project/jiwer/) package, using the
-following command:
-
-```
-pip install jiwer
-```
-
-**References:**
-
-- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
-- [Speech recognition](https://en.wikipedia.org/wiki/Speech_recognition)
-- [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
-- [DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html)
-
----
-## Setup
-
-
-```python
-import pandas as pd
-import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-import matplotlib.pyplot as plt
-from IPython import display
-from jiwer import wer
-
-```
-
----
-## Load the LJSpeech Dataset
-
-Let's download the [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/).
-The dataset contains 13,100 audio files as `wav` files in the `/wavs/` folder.
-The label (transcript) for each audio file is a string
-given in the `metadata.csv` file. The fields are:
-
-- **ID**: this is the name of the corresponding .wav file
-- **Transcription**: words spoken by the reader (UTF-8)
-- **Normalized transcription**: transcription with numbers,
-ordinals, and monetary units expanded into full words (UTF-8).
-
-For this demo we will use only the "Normalized transcription" field.
-
-Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 22,050 Hz.
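-Before downloading the data, here is a quick illustration of how the WER metric
-introduced above behaves. This is a toy sketch: the two strings below are made-up
-examples, not model outputs.
-
-
-```python
-from jiwer import wer
-
-reference = "the cat sat on the mat"   # 6 words
-hypothesis = "the cat sat mat"         # "on" and "the" were dropped: 2 deletions
-
-# WER = (substitutions + insertions + deletions) / number of reference words
-print(wer(reference, hypothesis))  # 2 / 6 ≈ 0.33
-```
-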
- - -```python -data_url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" -data_path = keras.utils.get_file("LJSpeech-1.1", data_url, untar=True) -wavs_path = data_path + "/wavs/" -metadata_path = data_path + "/metadata.csv" - - -# Read metadata file and parse it -metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3) -metadata_df.columns = ["file_name", "transcription", "normalized_transcription"] -metadata_df = metadata_df[["file_name", "normalized_transcription"]] -metadata_df = metadata_df.sample(frac=1).reset_index(drop=True) -metadata_df.head(3) - -``` - - - - -
- - - - - - - - - - - - - - - - - - - - - - - - - - -
file_namenormalized_transcription
0LJ029-0199On November eighteen the Dallas City Council a...
1LJ028-0237with orders to march into the town by the bed ...
2LJ009-0116On the following day the capital convicts, who...
-
-</div>
-
-
-We now split the data into training and validation sets.
-
-
-```python
-split = int(len(metadata_df) * 0.90)
-df_train = metadata_df[:split]
-df_val = metadata_df[split:]
-
-print(f"Size of the training set: {len(df_train)}")
-print(f"Size of the validation set: {len(df_val)}")
-
-```
-
-<div class="k-default-codeblock">
-```
-Size of the training set: 11790
-Size of the validation set: 1310
-
-```
-</div>
-
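-Note that `sample(frac=1)` above shuffles the rows without a fixed seed, so the
-split differs between runs. For a reproducible split you could pass a seed; a
-minimal sketch (the seed value is arbitrary):
-
-
-```python
-# Reproducible shuffle + split (illustrative; any fixed seed works)
-metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)
-split = int(len(metadata_df) * 0.90)
-df_train, df_val = metadata_df[:split], metadata_df[split:]
-```
-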
---- -## Preprocessing - -We first prepare the vocabulary to be used. - - -```python -# The set of characters accepted in the transcription. -characters = [x for x in "abcdefghijklmnopqrstuvwxyz'?! "] -# Mapping characters to integers -char_to_num = keras.layers.StringLookup(vocabulary=characters, oov_token="") -# Mapping integers back to original characters -num_to_char = keras.layers.StringLookup( - vocabulary=char_to_num.get_vocabulary(), oov_token="", invert=True -) - -print( - f"The vocabulary is: {char_to_num.get_vocabulary()} " - f"(size ={char_to_num.vocabulary_size()})" -) -``` - -
-``` -The vocabulary is: ['', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", '?', '!', ' '] (size =31) - -``` -
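-As a quick sanity check, the two lookup layers invert each other. A minimal
-round-trip sketch using the layers defined above (the sample string is arbitrary):
-
-
-```python
-# Text -> ids -> text round trip
-sample = tf.strings.unicode_split("hello world", input_encoding="UTF-8")
-ids = char_to_num(sample)
-back = tf.strings.reduce_join(num_to_char(ids)).numpy().decode("utf-8")
-print(ids.numpy())  # integer ids; exact values depend on the vocabulary order
-print(back)         # "hello world"
-```
-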
-Next, we create the function that describes the transformation that we apply to each -element of our dataset. - - -```python -# An integer scalar Tensor. The window length in samples. -frame_length = 256 -# An integer scalar Tensor. The number of samples to step. -frame_step = 160 -# An integer scalar Tensor. The size of the FFT to apply. -# If not provided, uses the smallest power of 2 enclosing frame_length. -fft_length = 384 - - -def encode_single_sample(wav_file, label): - ########################################### - ## Process the Audio - ########################################## - # 1. Read wav file - file = tf.io.read_file(wavs_path + wav_file + ".wav") - # 2. Decode the wav file - audio, _ = tf.audio.decode_wav(file) - audio = tf.squeeze(audio, axis=-1) - # 3. Change type to float - audio = tf.cast(audio, tf.float32) - # 4. Get the spectrogram - spectrogram = tf.signal.stft( - audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length - ) - # 5. We only need the magnitude, which can be derived by applying tf.abs - spectrogram = tf.abs(spectrogram) - spectrogram = tf.math.pow(spectrogram, 0.5) - # 6. normalisation - means = tf.math.reduce_mean(spectrogram, 1, keepdims=True) - stddevs = tf.math.reduce_std(spectrogram, 1, keepdims=True) - spectrogram = (spectrogram - means) / (stddevs + 1e-10) - ########################################### - ## Process the label - ########################################## - # 7. Convert label to Lower case - label = tf.strings.lower(label) - # 8. Split the label - label = tf.strings.unicode_split(label, input_encoding="UTF-8") - # 9. Map the characters in label to numbers - label = char_to_num(label) - # 10. Return a dict as our model is expecting two inputs - return spectrogram, label - -``` - ---- -## Creating `Dataset` objects - -We create a `tf.data.Dataset` object that yields -the transformed elements, in the same order as they -appeared in the input. - - -```python -batch_size = 32 -# Define the training dataset -train_dataset = tf.data.Dataset.from_tensor_slices( - (list(df_train["file_name"]), list(df_train["normalized_transcription"])) -) -train_dataset = ( - train_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE) - .padded_batch(batch_size) - .prefetch(buffer_size=tf.data.AUTOTUNE) -) - -# Define the validation dataset -validation_dataset = tf.data.Dataset.from_tensor_slices( - (list(df_val["file_name"]), list(df_val["normalized_transcription"])) -) -validation_dataset = ( - validation_dataset.map(encode_single_sample, num_parallel_calls=tf.data.AUTOTUNE) - .padded_batch(batch_size) - .prefetch(buffer_size=tf.data.AUTOTUNE) -) - -``` - ---- -## Visualize the data - -Let's visualize an example in our dataset, including the -audio clip, the spectrogram and the corresponding label. 
- - -```python -fig = plt.figure(figsize=(8, 5)) -for batch in train_dataset.take(1): - spectrogram = batch[0][0].numpy() - spectrogram = np.array([np.trim_zeros(x) for x in np.transpose(spectrogram)]) - label = batch[1][0] - # Spectrogram - label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8") - ax = plt.subplot(2, 1, 1) - ax.imshow(spectrogram, vmax=1) - ax.set_title(label) - ax.axis("off") - # Wav - file = tf.io.read_file(wavs_path + list(df_train["file_name"])[0] + ".wav") - audio, _ = tf.audio.decode_wav(file) - audio = audio.numpy() - ax = plt.subplot(2, 1, 2) - plt.plot(audio) - ax.set_title("Signal Wave") - ax.set_xlim(0, len(audio)) - display.display(display.Audio(np.transpose(audio), rate=16000)) -plt.show() -``` - - - - - - - - - -![png](/img/examples/audio/ctc_asr/ctc_asr_15_1.png) - - - ---- -## Model - -We first define the CTC Loss function. - - -```python - -def CTCLoss(y_true, y_pred): - # Compute the training-time loss value - batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64") - input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64") - label_length = tf.cast(tf.shape(y_true)[1], dtype="int64") - - input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64") - label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64") - - loss = keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length) - return loss - -``` - -We now define our model. We will define a model similar to -[DeepSpeech2](https://nvidia.github.io/OpenSeq2Seq/html/speech-recognition/deepspeech2.html). - - -```python - -def build_model(input_dim, output_dim, rnn_layers=5, rnn_units=128): - """Model similar to DeepSpeech2.""" - # Model's input - input_spectrogram = layers.Input((None, input_dim), name="input") - # Expand the dimension to use 2D CNN. 
- x = layers.Reshape((-1, input_dim, 1), name="expand_dim")(input_spectrogram) - # Convolution layer 1 - x = layers.Conv2D( - filters=32, - kernel_size=[11, 41], - strides=[2, 2], - padding="same", - use_bias=False, - name="conv_1", - )(x) - x = layers.BatchNormalization(name="conv_1_bn")(x) - x = layers.ReLU(name="conv_1_relu")(x) - # Convolution layer 2 - x = layers.Conv2D( - filters=32, - kernel_size=[11, 21], - strides=[1, 2], - padding="same", - use_bias=False, - name="conv_2", - )(x) - x = layers.BatchNormalization(name="conv_2_bn")(x) - x = layers.ReLU(name="conv_2_relu")(x) - # Reshape the resulted volume to feed the RNNs layers - x = layers.Reshape((-1, x.shape[-2] * x.shape[-1]))(x) - # RNN layers - for i in range(1, rnn_layers + 1): - recurrent = layers.GRU( - units=rnn_units, - activation="tanh", - recurrent_activation="sigmoid", - use_bias=True, - return_sequences=True, - reset_after=True, - name=f"gru_{i}", - ) - x = layers.Bidirectional( - recurrent, name=f"bidirectional_{i}", merge_mode="concat" - )(x) - if i < rnn_layers: - x = layers.Dropout(rate=0.5)(x) - # Dense layer - x = layers.Dense(units=rnn_units * 2, name="dense_1")(x) - x = layers.ReLU(name="dense_1_relu")(x) - x = layers.Dropout(rate=0.5)(x) - # Classification layer - output = layers.Dense(units=output_dim + 1, activation="softmax")(x) - # Model - model = keras.Model(input_spectrogram, output, name="DeepSpeech_2") - # Optimizer - opt = keras.optimizers.Adam(learning_rate=1e-4) - # Compile the model and return - model.compile(optimizer=opt, loss=CTCLoss) - return model - - -# Get the model -model = build_model( - input_dim=fft_length // 2 + 1, - output_dim=char_to_num.vocabulary_size(), - rnn_units=512, -) -model.summary(line_length=110) -``` - -
-``` -Model: "DeepSpeech_2" -______________________________________________________________________________________________________________ - Layer (type) Output Shape Param # -============================================================================================================== - input (InputLayer) [(None, None, 193)] 0 - - expand_dim (Reshape) (None, None, 193, 1) 0 - - conv_1 (Conv2D) (None, None, 97, 32) 14432 - - conv_1_bn (BatchNormalization) (None, None, 97, 32) 128 - - conv_1_relu (ReLU) (None, None, 97, 32) 0 - - conv_2 (Conv2D) (None, None, 49, 32) 236544 - - conv_2_bn (BatchNormalization) (None, None, 49, 32) 128 - - conv_2_relu (ReLU) (None, None, 49, 32) 0 - - reshape (Reshape) (None, None, 1568) 0 - - bidirectional_1 (Bidirectional) (None, None, 1024) 6395904 - - dropout (Dropout) (None, None, 1024) 0 - - bidirectional_2 (Bidirectional) (None, None, 1024) 4724736 - - dropout_1 (Dropout) (None, None, 1024) 0 - - bidirectional_3 (Bidirectional) (None, None, 1024) 4724736 - - dropout_2 (Dropout) (None, None, 1024) 0 - - bidirectional_4 (Bidirectional) (None, None, 1024) 4724736 - - dropout_3 (Dropout) (None, None, 1024) 0 - - bidirectional_5 (Bidirectional) (None, None, 1024) 4724736 - - dense_1 (Dense) (None, None, 1024) 1049600 - - dense_1_relu (ReLU) (None, None, 1024) 0 - - dropout_4 (Dropout) (None, None, 1024) 0 - - dense (Dense) (None, None, 32) 32800 - -============================================================================================================== -Total params: 26,628,480 -Trainable params: 26,628,352 -Non-trainable params: 128 -______________________________________________________________________________________________________________ - -``` -
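-Before training, it is worth recalling what CTC decoding does with the per-frame
-predictions this model emits: consecutive duplicate characters are merged, then
-blanks are removed. A toy sketch of the collapse rule, independent of the model
-(here `""` plays the role of the blank token):
-
-
-```python
-def ctc_collapse(frame_labels, blank=""):
-    # Merge consecutive duplicates, then drop blank tokens
-    out, prev = [], None
-    for c in frame_labels:
-        if c != prev and c != blank:
-            out.append(c)
-        prev = c
-    return "".join(out)
-
-
-# The blank between the two "l" frames is what allows double letters:
-print(ctc_collapse(["h", "h", "e", "", "l", "l", "", "l", "o"]))  # "hello"
-```
-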
---- -## Training and Evaluating - - -```python -# A utility function to decode the output of the network -def decode_batch_predictions(pred): - input_len = np.ones(pred.shape[0]) * pred.shape[1] - # Use greedy search. For complex tasks, you can use beam search - results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0] - # Iterate over the results and get back the text - output_text = [] - for result in results: - result = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8") - output_text.append(result) - return output_text - - -# A callback class to output a few transcriptions during training -class CallbackEval(keras.callbacks.Callback): - """Displays a batch of outputs after every epoch.""" - - def __init__(self, dataset): - super().__init__() - self.dataset = dataset - - def on_epoch_end(self, epoch: int, logs=None): - predictions = [] - targets = [] - for batch in self.dataset: - X, y = batch - batch_predictions = model.predict(X) - batch_predictions = decode_batch_predictions(batch_predictions) - predictions.extend(batch_predictions) - for label in y: - label = ( - tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8") - ) - targets.append(label) - wer_score = wer(targets, predictions) - print("-" * 100) - print(f"Word Error Rate: {wer_score:.4f}") - print("-" * 100) - for i in np.random.randint(0, len(predictions), 2): - print(f"Target : {targets[i]}") - print(f"Prediction: {predictions[i]}") - print("-" * 100) - -``` - -Let's start the training process. - - -```python -# Define the number of epochs. -epochs = 1 -# Callback function to check transcription on the val set. -validation_callback = CallbackEval(validation_dataset) -# Train the model -history = model.fit( - train_dataset, - validation_data=validation_dataset, - epochs=epochs, - callbacks=[validation_callback], -) - -``` - -
-``` -369/369 [==============================] - ETA: 0s - loss: 302.4755---------------------------------------------------------------------------------------------------- -Word Error Rate: 1.0000 ----------------------------------------------------------------------------------------------------- -Target : special agent lyndal l shaneyfelt a photography expert with the fbi -Prediction: s ----------------------------------------------------------------------------------------------------- -Target : dissolved in water the sugar is transported down delicate tubes chiefly in the growing bark region of the stem -Prediction: sss ----------------------------------------------------------------------------------------------------- -369/369 [==============================] - 407s 1s/step - loss: 302.4755 - val_loss: 252.1534 - -``` -
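-For longer runs you would typically also checkpoint the model. A minimal sketch
-using a standard Keras callback (the file path is illustrative):
-
-
-```python
-# Save the best weights seen so far, judged by validation loss
-checkpoint_callback = keras.callbacks.ModelCheckpoint(
-    "ctc_asr_best.h5",  # illustrative path
-    monitor="val_loss",
-    save_best_only=True,
-    save_weights_only=True,
-)
-# Then pass it alongside the evaluation callback:
-# model.fit(..., callbacks=[validation_callback, checkpoint_callback])
-```
-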
---- -## Inference - - -```python -# Let's check results on more validation samples -predictions = [] -targets = [] -for batch in validation_dataset: - X, y = batch - batch_predictions = model.predict(X) - batch_predictions = decode_batch_predictions(batch_predictions) - predictions.extend(batch_predictions) - for label in y: - label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8") - targets.append(label) -wer_score = wer(targets, predictions) -print("-" * 100) -print(f"Word Error Rate: {wer_score:.4f}") -print("-" * 100) -for i in np.random.randint(0, len(predictions), 5): - print(f"Target : {targets[i]}") - print(f"Prediction: {predictions[i]}") - print("-" * 100) - -``` - -
-``` ----------------------------------------------------------------------------------------------------- -Word Error Rate: 1.0000 ----------------------------------------------------------------------------------------------------- -Target : the owners of the latter would then issue a second set of warrants on these goods in total ignorance of the fact that they were already pledged -Prediction: ssnssss ----------------------------------------------------------------------------------------------------- -Target : till the whole body of the slaves were manumitted in eighteen thirtythree -Prediction: sr ----------------------------------------------------------------------------------------------------- -Target : the committee most of all insisted upon the entire individual separation of prisoners except during the hours of labor -Prediction: ssssss ----------------------------------------------------------------------------------------------------- -Target : he made no attempt to help her and there are other indications that he did not want her to learn that language -Prediction: s ----------------------------------------------------------------------------------------------------- -Target : the building of the babylon so famous in history began with nabopolassar -Prediction: sssrs ----------------------------------------------------------------------------------------------------- - -``` -
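-The decoder above uses greedy search. As mentioned earlier, beam search can give
-better transcriptions on a well-trained model. A minimal variant of
-`decode_batch_predictions` (a sketch; the `beam_width` value is illustrative):
-
-
-```python
-def decode_batch_predictions_beam(pred, beam_width=100):
-    input_len = np.ones(pred.shape[0]) * pred.shape[1]
-    # greedy=False switches keras.backend.ctc_decode to beam search
-    results = keras.backend.ctc_decode(
-        pred, input_length=input_len, greedy=False, beam_width=beam_width
-    )[0][0]
-    return [
-        tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
-        for result in results
-    ]
-```
-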
----
-## Conclusion
-
-In practice, you should train for around 50 epochs or more. Each epoch
-takes approximately 5-6 minutes using a `GeForce RTX 2080 Ti` GPU.
-The model we trained for 50 epochs reached a `Word Error Rate (WER)` of roughly 16% to 17%.
-
-Some of the transcriptions around epoch 50:
-
-**Audio file: LJ017-0009.wav**
-```
-- Target    : sir thomas overbury was undoubtedly poisoned by lord rochester in the reign
-of james the first
-- Prediction: cer thomas overbery was undoubtedly poisoned by lordrochester in the reign
-of james the first
-```
-
-**Audio file: LJ003-0340.wav**
-```
-- Target    : the committee does not seem to have yet understood that newgate could be
-only and properly replaced
-- Prediction: the committee does not seem to have yet understood that newgate could be
-only and proberly replace
-```
-
-**Audio file: LJ011-0136.wav**
-```
-- Target    : still no sentence of death was carried out for the offense and in eighteen
-thirtytwo
-- Prediction: still no sentence of death was carried out for the offense and in eighteen
-thirtytwo
-```
-
-A trained model and an interactive demo are available on Hugging Face:
-
-| Trained Model | Demo |
-| :--: | :--: |
-| [![Generic badge](https://img.shields.io/badge/🤗%20Model-CTC%20ASR-black.svg)](https://huggingface.co/keras-io/ctc_asr) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-CTC%20ASR-black.svg)](https://huggingface.co/spaces/keras-io/ctc_asr) |

diff --git a/templates/examples/audio/melgan_spectrogram_inversion.md b/templates/examples/audio/melgan_spectrogram_inversion.md
deleted file mode 100644
index eef8effd1f..0000000000
--- a/templates/examples/audio/melgan_spectrogram_inversion.md
+++ /dev/null
@@ -1,953 +0,0 @@
-# MelGAN-based spectrogram inversion using feature matching
-
-**Author:** [Darshan Deshpande](https://twitter.com/getdarshan)<br>
-**Date created:** 02/09/2021
-**Last modified:** 15/09/2021
- - -
ⓘ This example uses Keras 2
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/melgan_spectrogram_inversion.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/melgan_spectrogram_inversion.py)
-
-
-**Description:** Inversion of audio from mel-spectrograms using the MelGAN architecture and feature matching.
-
----
-## Introduction
-
-Autoregressive vocoders have been ubiquitous for a majority of the history of speech processing,
-but for most of their existence they have lacked parallelism.
-[MelGAN](https://arxiv.org/pdf/1910.06711v3.pdf) is a
-non-autoregressive, fully convolutional vocoder architecture used for tasks ranging
-from spectral inversion and speech enhancement to present-day state-of-the-art
-speech synthesis, when used as a decoder
-with models like Tacotron2 or FastSpeech that convert text to mel spectrograms.
-
-In this tutorial, we will have a look at the MelGAN architecture and how it can achieve
-fast spectral inversion, i.e. conversion of spectrograms to audio waves. The MelGAN
-implemented in this tutorial is similar to the original implementation, the only
-difference being the padding method for the convolutions: we use 'same' padding
-instead of reflect padding.
-
----
-## Importing and Defining Hyperparameters
-
-
-```python
-!pip install -qqq tensorflow_addons
-!pip install -qqq tensorflow-io
-```
-
-```python
-import tensorflow as tf
-import tensorflow_io as tfio
-from tensorflow import keras
-from tensorflow.keras import layers
-from tensorflow_addons import layers as addon_layers
-
-# Setting logger level to avoid input shape warnings
-tf.get_logger().setLevel("ERROR")
-
-# Defining hyperparameters
-
-DESIRED_SAMPLES = 8192
-LEARNING_RATE_GEN = 1e-5
-LEARNING_RATE_DISC = 1e-6
-BATCH_SIZE = 16
-
-mse = keras.losses.MeanSquaredError()
-mae = keras.losses.MeanAbsoluteError()
-```
-``` -|████████████████████████████████| 1.1 MB 5.1 MB/s -|████████████████████████████████| 22.7 MB 1.7 MB/s -|████████████████████████████████| 2.1 MB 36.2 MB/s - -``` -
---- -## Loading the Dataset - -This example uses the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/). - -The LJSpeech dataset is primarily used for text-to-speech and consists of 13,100 discrete -speech samples taken from 7 non-fiction books, having a total length of approximately 24 -hours. The MelGAN training is only concerned with the audio waves so we process only the -WAV files and ignore the audio annotations. - - -```python -!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 -!tar -xf /content/LJSpeech-1.1.tar.bz2 -``` - -
-``` ---2021-09-16 11:45:24-- https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 -Resolving data.keithito.com (data.keithito.com)... 174.138.79.61 -Connecting to data.keithito.com (data.keithito.com)|174.138.79.61|:443... connected. -HTTP request sent, awaiting response... 200 OK -Length: 2748572632 (2.6G) [application/octet-stream] -Saving to: ‘LJSpeech-1.1.tar.bz2’ -``` -
- -
-``` -LJSpeech-1.1.tar.bz 100%[===================>] 2.56G 68.3MB/s in 36s -``` -
- -
-``` -2021-09-16 11:46:01 (72.2 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632] -``` -
- - - -We create a `tf.data.Dataset` to load and process the audio files on the fly. -The `preprocess()` function takes the file path as input and returns two instances of the -wave, one for input and one as the ground truth for comparison. The input wave will be -mapped to a spectrogram using the custom `MelSpec` layer as shown later in this example. - - -```python -# Splitting the dataset into training and testing splits -wavs = tf.io.gfile.glob("LJSpeech-1.1/wavs/*.wav") -print(f"Number of audio files: {len(wavs)}") - -# Mapper function for loading the audio. This function returns two instances of the wave -def preprocess(filename): - audio = tf.audio.decode_wav(tf.io.read_file(filename), 1, DESIRED_SAMPLES).audio - return audio, audio - - -# Create tf.data.Dataset objects and apply preprocessing -train_dataset = tf.data.Dataset.from_tensor_slices((wavs,)) -train_dataset = train_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE) -``` - -
-``` -Number of audio files: 13100 - -``` -
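-As a quick sanity check of the pipeline, we can peek at one element. Since both
-entries of the pair come from `preprocess()`, they are identical; a minimal sketch
-(shapes assume `DESIRED_SAMPLES = 8192` and clips longer than that, which holds
-for LJSpeech):
-
-
-```python
-for x, y in train_dataset.take(1):
-    # Both tensors are the same decoded wave
-    print(x.shape, y.shape)  # expected: (8192, 1) (8192, 1)
-```
-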
---- -## Defining custom layers for MelGAN - -The MelGAN architecture consists of 3 main modules: - -1. The residual block -2. Dilated convolutional block -3. Discriminator block - -![MelGAN](https://i.imgur.com/ZdxwzPG.png) - -Since the network takes a mel-spectrogram as input, we will create an additional custom -layer -which can convert the raw audio wave to a spectrogram on-the-fly. We use the raw audio -tensor from `train_dataset` and map it to a mel-spectrogram using the `MelSpec` layer -below. - - -```python -# Custom keras layer for on-the-fly audio to spectrogram conversion - - -class MelSpec(layers.Layer): - def __init__( - self, - frame_length=1024, - frame_step=256, - fft_length=None, - sampling_rate=22050, - num_mel_channels=80, - freq_min=125, - freq_max=7600, - **kwargs, - ): - super().__init__(**kwargs) - self.frame_length = frame_length - self.frame_step = frame_step - self.fft_length = fft_length - self.sampling_rate = sampling_rate - self.num_mel_channels = num_mel_channels - self.freq_min = freq_min - self.freq_max = freq_max - # Defining mel filter. This filter will be multiplied with the STFT output - self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix( - num_mel_bins=self.num_mel_channels, - num_spectrogram_bins=self.frame_length // 2 + 1, - sample_rate=self.sampling_rate, - lower_edge_hertz=self.freq_min, - upper_edge_hertz=self.freq_max, - ) - - def call(self, audio, training=True): - # We will only perform the transformation during training. - if training: - # Taking the Short Time Fourier Transform. Ensure that the audio is padded. - # In the paper, the STFT output is padded using the 'REFLECT' strategy. - stft = tf.signal.stft( - tf.squeeze(audio, -1), - self.frame_length, - self.frame_step, - self.fft_length, - pad_end=True, - ) - - # Taking the magnitude of the STFT output - magnitude = tf.abs(stft) - - # Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale - mel = tf.matmul(tf.square(magnitude), self.mel_filterbank) - log_mel_spec = tfio.audio.dbscale(mel, top_db=80) - return log_mel_spec - else: - return audio - - def get_config(self): - config = super().get_config() - config.update( - { - "frame_length": self.frame_length, - "frame_step": self.frame_step, - "fft_length": self.fft_length, - "sampling_rate": self.sampling_rate, - "num_mel_channels": self.num_mel_channels, - "freq_min": self.freq_min, - "freq_max": self.freq_max, - } - ) - return config - -``` - -The residual convolutional block extensively uses dilations and has a total receptive -field of 27 timesteps per block. The dilations must grow as a power of the `kernel_size` -to ensure reduction of hissing noise in the output. The network proposed by the paper is -as follows: - -![ConvBlock](https://i.imgur.com/sFnnsCll.jpg) - - -```python -# Creating the residual stack block - - -def residual_stack(input, filters): - """Convolutional residual stack with weight normalization. - - Args: - filters: int, determines filter size for the residual stack. - - Returns: - Residual stack output. 
- """ - c1 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False - )(input) - lrelu1 = layers.LeakyReLU()(c1) - c2 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False - )(lrelu1) - add1 = layers.Add()([c2, input]) - - lrelu2 = layers.LeakyReLU()(add1) - c3 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=3, padding="same"), data_init=False - )(lrelu2) - lrelu3 = layers.LeakyReLU()(c3) - c4 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False - )(lrelu3) - add2 = layers.Add()([add1, c4]) - - lrelu4 = layers.LeakyReLU()(add2) - c5 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=9, padding="same"), data_init=False - )(lrelu4) - lrelu5 = layers.LeakyReLU()(c5) - c6 = addon_layers.WeightNormalization( - layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False - )(lrelu5) - add3 = layers.Add()([c6, add2]) - - return add3 - -``` - -Each convolutional block uses the dilations offered by the residual stack -and upsamples the input data by the `upsampling_factor`. - - -```python -# Dilated convolutional block consisting of the Residual stack - - -def conv_block(input, conv_dim, upsampling_factor): - """Dilated Convolutional Block with weight normalization. - - Args: - conv_dim: int, determines filter size for the block. - upsampling_factor: int, scale for upsampling. - - Returns: - Dilated convolution block. - """ - conv_t = addon_layers.WeightNormalization( - layers.Conv1DTranspose(conv_dim, 16, upsampling_factor, padding="same"), - data_init=False, - )(input) - lrelu1 = layers.LeakyReLU()(conv_t) - res_stack = residual_stack(lrelu1, conv_dim) - lrelu2 = layers.LeakyReLU()(res_stack) - return lrelu2 - -``` - -The discriminator block consists of convolutions and downsampling layers. This block is -essential for the implementation of the feature matching technique. - -Each discriminator outputs a list of feature maps that will be compared during training -to compute the feature matching loss. 
- - -```python - -def discriminator_block(input): - conv1 = addon_layers.WeightNormalization( - layers.Conv1D(16, 15, 1, "same"), data_init=False - )(input) - lrelu1 = layers.LeakyReLU()(conv1) - conv2 = addon_layers.WeightNormalization( - layers.Conv1D(64, 41, 4, "same", groups=4), data_init=False - )(lrelu1) - lrelu2 = layers.LeakyReLU()(conv2) - conv3 = addon_layers.WeightNormalization( - layers.Conv1D(256, 41, 4, "same", groups=16), data_init=False - )(lrelu2) - lrelu3 = layers.LeakyReLU()(conv3) - conv4 = addon_layers.WeightNormalization( - layers.Conv1D(1024, 41, 4, "same", groups=64), data_init=False - )(lrelu3) - lrelu4 = layers.LeakyReLU()(conv4) - conv5 = addon_layers.WeightNormalization( - layers.Conv1D(1024, 41, 4, "same", groups=256), data_init=False - )(lrelu4) - lrelu5 = layers.LeakyReLU()(conv5) - conv6 = addon_layers.WeightNormalization( - layers.Conv1D(1024, 5, 1, "same"), data_init=False - )(lrelu5) - lrelu6 = layers.LeakyReLU()(conv6) - conv7 = addon_layers.WeightNormalization( - layers.Conv1D(1, 3, 1, "same"), data_init=False - )(lrelu6) - return [lrelu1, lrelu2, lrelu3, lrelu4, lrelu5, lrelu6, conv7] - -``` - -### Create the generator - - -```python - -def create_generator(input_shape): - inp = keras.Input(input_shape) - x = MelSpec()(inp) - x = layers.Conv1D(512, 7, padding="same")(x) - x = layers.LeakyReLU()(x) - x = conv_block(x, 256, 8) - x = conv_block(x, 128, 8) - x = conv_block(x, 64, 2) - x = conv_block(x, 32, 2) - x = addon_layers.WeightNormalization( - layers.Conv1D(1, 7, padding="same", activation="tanh") - )(x) - return keras.Model(inp, x) - - -# We use a dynamic input shape for the generator since the model is fully convolutional -generator = create_generator((None, 1)) -generator.summary() -``` - -
-``` -Model: "model" -__________________________________________________________________________________________________ -Layer (type) Output Shape Param # Connected to -================================================================================================== -input_1 (InputLayer) [(None, None, 1)] 0 -__________________________________________________________________________________________________ -mel_spec (MelSpec) (None, None, 80) 0 input_1[0][0] -__________________________________________________________________________________________________ -conv1d (Conv1D) (None, None, 512) 287232 mel_spec[0][0] -__________________________________________________________________________________________________ -leaky_re_lu (LeakyReLU) (None, None, 512) 0 conv1d[0][0] -__________________________________________________________________________________________________ -weight_normalization (WeightNor (None, None, 256) 2097921 leaky_re_lu[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_1 (LeakyReLU) (None, None, 256) 0 weight_normalization[0][0] -__________________________________________________________________________________________________ -weight_normalization_1 (WeightN (None, None, 256) 197121 leaky_re_lu_1[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_2 (LeakyReLU) (None, None, 256) 0 weight_normalization_1[0][0] -__________________________________________________________________________________________________ -weight_normalization_2 (WeightN (None, None, 256) 197121 leaky_re_lu_2[0][0] -__________________________________________________________________________________________________ -add (Add) (None, None, 256) 0 weight_normalization_2[0][0] - leaky_re_lu_1[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_3 (LeakyReLU) (None, None, 256) 0 add[0][0] -__________________________________________________________________________________________________ -weight_normalization_3 (WeightN (None, None, 256) 197121 leaky_re_lu_3[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_4 (LeakyReLU) (None, None, 256) 0 weight_normalization_3[0][0] -__________________________________________________________________________________________________ -weight_normalization_4 (WeightN (None, None, 256) 197121 leaky_re_lu_4[0][0] -__________________________________________________________________________________________________ -add_1 (Add) (None, None, 256) 0 add[0][0] - weight_normalization_4[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_5 (LeakyReLU) (None, None, 256) 0 add_1[0][0] -__________________________________________________________________________________________________ -weight_normalization_5 (WeightN (None, None, 256) 197121 leaky_re_lu_5[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_6 (LeakyReLU) (None, None, 256) 0 weight_normalization_5[0][0] -__________________________________________________________________________________________________ -weight_normalization_6 (WeightN (None, None, 256) 197121 leaky_re_lu_6[0][0] -__________________________________________________________________________________________________ -add_2 (Add) (None, None, 256) 0 
weight_normalization_6[0][0] - add_1[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_7 (LeakyReLU) (None, None, 256) 0 add_2[0][0] -__________________________________________________________________________________________________ -weight_normalization_7 (WeightN (None, None, 128) 524673 leaky_re_lu_7[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_8 (LeakyReLU) (None, None, 128) 0 weight_normalization_7[0][0] -__________________________________________________________________________________________________ -weight_normalization_8 (WeightN (None, None, 128) 49409 leaky_re_lu_8[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_9 (LeakyReLU) (None, None, 128) 0 weight_normalization_8[0][0] -__________________________________________________________________________________________________ -weight_normalization_9 (WeightN (None, None, 128) 49409 leaky_re_lu_9[0][0] -__________________________________________________________________________________________________ -add_3 (Add) (None, None, 128) 0 weight_normalization_9[0][0] - leaky_re_lu_8[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_10 (LeakyReLU) (None, None, 128) 0 add_3[0][0] -__________________________________________________________________________________________________ -weight_normalization_10 (Weight (None, None, 128) 49409 leaky_re_lu_10[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_11 (LeakyReLU) (None, None, 128) 0 weight_normalization_10[0][0] -__________________________________________________________________________________________________ -weight_normalization_11 (Weight (None, None, 128) 49409 leaky_re_lu_11[0][0] -__________________________________________________________________________________________________ -add_4 (Add) (None, None, 128) 0 add_3[0][0] - weight_normalization_11[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_12 (LeakyReLU) (None, None, 128) 0 add_4[0][0] -__________________________________________________________________________________________________ -weight_normalization_12 (Weight (None, None, 128) 49409 leaky_re_lu_12[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_13 (LeakyReLU) (None, None, 128) 0 weight_normalization_12[0][0] -__________________________________________________________________________________________________ -weight_normalization_13 (Weight (None, None, 128) 49409 leaky_re_lu_13[0][0] -__________________________________________________________________________________________________ -add_5 (Add) (None, None, 128) 0 weight_normalization_13[0][0] - add_4[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_14 (LeakyReLU) (None, None, 128) 0 add_5[0][0] -__________________________________________________________________________________________________ -weight_normalization_14 (Weight (None, None, 64) 131265 leaky_re_lu_14[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_15 (LeakyReLU) (None, None, 64) 0 weight_normalization_14[0][0] 
-__________________________________________________________________________________________________ -weight_normalization_15 (Weight (None, None, 64) 12417 leaky_re_lu_15[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_16 (LeakyReLU) (None, None, 64) 0 weight_normalization_15[0][0] -__________________________________________________________________________________________________ -weight_normalization_16 (Weight (None, None, 64) 12417 leaky_re_lu_16[0][0] -__________________________________________________________________________________________________ -add_6 (Add) (None, None, 64) 0 weight_normalization_16[0][0] - leaky_re_lu_15[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_17 (LeakyReLU) (None, None, 64) 0 add_6[0][0] -__________________________________________________________________________________________________ -weight_normalization_17 (Weight (None, None, 64) 12417 leaky_re_lu_17[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_18 (LeakyReLU) (None, None, 64) 0 weight_normalization_17[0][0] -__________________________________________________________________________________________________ -weight_normalization_18 (Weight (None, None, 64) 12417 leaky_re_lu_18[0][0] -__________________________________________________________________________________________________ -add_7 (Add) (None, None, 64) 0 add_6[0][0] - weight_normalization_18[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_19 (LeakyReLU) (None, None, 64) 0 add_7[0][0] -__________________________________________________________________________________________________ -weight_normalization_19 (Weight (None, None, 64) 12417 leaky_re_lu_19[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_20 (LeakyReLU) (None, None, 64) 0 weight_normalization_19[0][0] -__________________________________________________________________________________________________ -weight_normalization_20 (Weight (None, None, 64) 12417 leaky_re_lu_20[0][0] -__________________________________________________________________________________________________ -add_8 (Add) (None, None, 64) 0 weight_normalization_20[0][0] - add_7[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_21 (LeakyReLU) (None, None, 64) 0 add_8[0][0] -__________________________________________________________________________________________________ -weight_normalization_21 (Weight (None, None, 32) 32865 leaky_re_lu_21[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_22 (LeakyReLU) (None, None, 32) 0 weight_normalization_21[0][0] -__________________________________________________________________________________________________ -weight_normalization_22 (Weight (None, None, 32) 3137 leaky_re_lu_22[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_23 (LeakyReLU) (None, None, 32) 0 weight_normalization_22[0][0] -__________________________________________________________________________________________________ -weight_normalization_23 (Weight (None, None, 32) 3137 leaky_re_lu_23[0][0] 
-__________________________________________________________________________________________________ -add_9 (Add) (None, None, 32) 0 weight_normalization_23[0][0] - leaky_re_lu_22[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_24 (LeakyReLU) (None, None, 32) 0 add_9[0][0] -__________________________________________________________________________________________________ -weight_normalization_24 (Weight (None, None, 32) 3137 leaky_re_lu_24[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_25 (LeakyReLU) (None, None, 32) 0 weight_normalization_24[0][0] -__________________________________________________________________________________________________ -weight_normalization_25 (Weight (None, None, 32) 3137 leaky_re_lu_25[0][0] -__________________________________________________________________________________________________ -add_10 (Add) (None, None, 32) 0 add_9[0][0] - weight_normalization_25[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_26 (LeakyReLU) (None, None, 32) 0 add_10[0][0] -__________________________________________________________________________________________________ -weight_normalization_26 (Weight (None, None, 32) 3137 leaky_re_lu_26[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_27 (LeakyReLU) (None, None, 32) 0 weight_normalization_26[0][0] -__________________________________________________________________________________________________ -weight_normalization_27 (Weight (None, None, 32) 3137 leaky_re_lu_27[0][0] -__________________________________________________________________________________________________ -add_11 (Add) (None, None, 32) 0 weight_normalization_27[0][0] - add_10[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_28 (LeakyReLU) (None, None, 32) 0 add_11[0][0] -__________________________________________________________________________________________________ -weight_normalization_28 (Weight (None, None, 1) 452 leaky_re_lu_28[0][0] -================================================================================================== -Total params: 4,646,912 -Trainable params: 4,646,658 -Non-trainable params: 254 -__________________________________________________________________________________________________ - -``` -
-### Create the discriminator - - -```python - -def create_discriminator(input_shape): - inp = keras.Input(input_shape) - out_map1 = discriminator_block(inp) - pool1 = layers.AveragePooling1D()(inp) - out_map2 = discriminator_block(pool1) - pool2 = layers.AveragePooling1D()(pool1) - out_map3 = discriminator_block(pool2) - return keras.Model(inp, [out_map1, out_map2, out_map3]) - - -# We use a dynamic input shape for the discriminator -# This is done because the input shape for the generator is unknown -discriminator = create_discriminator((None, 1)) - -discriminator.summary() -``` - -
-``` -Model: "model_1" -__________________________________________________________________________________________________ -Layer (type) Output Shape Param # Connected to -================================================================================================== -input_2 (InputLayer) [(None, None, 1)] 0 -__________________________________________________________________________________________________ -average_pooling1d (AveragePooli (None, None, 1) 0 input_2[0][0] -__________________________________________________________________________________________________ -average_pooling1d_1 (AveragePoo (None, None, 1) 0 average_pooling1d[0][0] -__________________________________________________________________________________________________ -weight_normalization_29 (Weight (None, None, 16) 273 input_2[0][0] -__________________________________________________________________________________________________ -weight_normalization_36 (Weight (None, None, 16) 273 average_pooling1d[0][0] -__________________________________________________________________________________________________ -weight_normalization_43 (Weight (None, None, 16) 273 average_pooling1d_1[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_29 (LeakyReLU) (None, None, 16) 0 weight_normalization_29[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_35 (LeakyReLU) (None, None, 16) 0 weight_normalization_36[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_41 (LeakyReLU) (None, None, 16) 0 weight_normalization_43[0][0] -__________________________________________________________________________________________________ -weight_normalization_30 (Weight (None, None, 64) 10625 leaky_re_lu_29[0][0] -__________________________________________________________________________________________________ -weight_normalization_37 (Weight (None, None, 64) 10625 leaky_re_lu_35[0][0] -__________________________________________________________________________________________________ -weight_normalization_44 (Weight (None, None, 64) 10625 leaky_re_lu_41[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_30 (LeakyReLU) (None, None, 64) 0 weight_normalization_30[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_36 (LeakyReLU) (None, None, 64) 0 weight_normalization_37[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_42 (LeakyReLU) (None, None, 64) 0 weight_normalization_44[0][0] -__________________________________________________________________________________________________ -weight_normalization_31 (Weight (None, None, 256) 42497 leaky_re_lu_30[0][0] -__________________________________________________________________________________________________ -weight_normalization_38 (Weight (None, None, 256) 42497 leaky_re_lu_36[0][0] -__________________________________________________________________________________________________ -weight_normalization_45 (Weight (None, None, 256) 42497 leaky_re_lu_42[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_31 (LeakyReLU) (None, None, 256) 0 weight_normalization_31[0][0] 
-__________________________________________________________________________________________________ -leaky_re_lu_37 (LeakyReLU) (None, None, 256) 0 weight_normalization_38[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_43 (LeakyReLU) (None, None, 256) 0 weight_normalization_45[0][0] -__________________________________________________________________________________________________ -weight_normalization_32 (Weight (None, None, 1024) 169985 leaky_re_lu_31[0][0] -__________________________________________________________________________________________________ -weight_normalization_39 (Weight (None, None, 1024) 169985 leaky_re_lu_37[0][0] -__________________________________________________________________________________________________ -weight_normalization_46 (Weight (None, None, 1024) 169985 leaky_re_lu_43[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_32 (LeakyReLU) (None, None, 1024) 0 weight_normalization_32[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_38 (LeakyReLU) (None, None, 1024) 0 weight_normalization_39[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_44 (LeakyReLU) (None, None, 1024) 0 weight_normalization_46[0][0] -__________________________________________________________________________________________________ -weight_normalization_33 (Weight (None, None, 1024) 169985 leaky_re_lu_32[0][0] -__________________________________________________________________________________________________ -weight_normalization_40 (Weight (None, None, 1024) 169985 leaky_re_lu_38[0][0] -__________________________________________________________________________________________________ -weight_normalization_47 (Weight (None, None, 1024) 169985 leaky_re_lu_44[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_33 (LeakyReLU) (None, None, 1024) 0 weight_normalization_33[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_39 (LeakyReLU) (None, None, 1024) 0 weight_normalization_40[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_45 (LeakyReLU) (None, None, 1024) 0 weight_normalization_47[0][0] -__________________________________________________________________________________________________ -weight_normalization_34 (Weight (None, None, 1024) 5244929 leaky_re_lu_33[0][0] -__________________________________________________________________________________________________ -weight_normalization_41 (Weight (None, None, 1024) 5244929 leaky_re_lu_39[0][0] -__________________________________________________________________________________________________ -weight_normalization_48 (Weight (None, None, 1024) 5244929 leaky_re_lu_45[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_34 (LeakyReLU) (None, None, 1024) 0 weight_normalization_34[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_40 (LeakyReLU) (None, None, 1024) 0 weight_normalization_41[0][0] -__________________________________________________________________________________________________ -leaky_re_lu_46 (LeakyReLU) (None, None, 
1024) 0 weight_normalization_48[0][0] -__________________________________________________________________________________________________ -weight_normalization_35 (Weight (None, None, 1) 3075 leaky_re_lu_34[0][0] -__________________________________________________________________________________________________ -weight_normalization_42 (Weight (None, None, 1) 3075 leaky_re_lu_40[0][0] -__________________________________________________________________________________________________ -weight_normalization_49 (Weight (None, None, 1) 3075 leaky_re_lu_46[0][0] -================================================================================================== -Total params: 16,924,107 -Trainable params: 16,924,086 -Non-trainable params: 21 -__________________________________________________________________________________________________ - -``` -
----
-## Defining the loss functions
-
-**Generator Loss**
-
-The generator architecture uses a combination of two losses:
-
-1. Mean Squared Error:
-
-This is the standard MSE generator loss, calculated between ones and the final
-outputs of each of the $K$ discriminators.
-

-$$
-\mathcal{L}_{\text{gen}} = \frac{1}{K} \sum_{k=1}^{K} \mathbb{E}_{\hat{x}} \big[ (D_k(\hat{x}) - 1)^2 \big]
-$$
-
-where $\hat{x}$ is the generated audio and $D_k(\hat{x})$ is the final output of the
-$k$-th discriminator (this matches `generator_loss` below).
-
-2. Feature Matching Loss:
-
-This loss involves extracting the outputs of every layer from the discriminator for both
-the generated and the ground-truth audio, and comparing the corresponding feature maps
-using the Mean Absolute Error.
-

-$$
-\mathcal{L}_{\text{FM}} = \frac{1}{K} \sum_{k=1}^{K} \frac{1}{T_k} \sum_{i=1}^{T_k} \big\lVert D_k^{(i)}(x) - D_k^{(i)}(\hat{x}) \big\rVert_1
-$$
-
-where $D_k^{(i)}$ is the $i$-th feature map of the $k$-th discriminator, $T_k$ is the
-number of feature maps, $x$ is the ground-truth audio and $\hat{x}$ is the generated
-audio (this matches `feature_matching_loss` below).
-
-**Discriminator Loss**
-
-The discriminator uses the Mean Squared Error and compares the real data predictions
-with ones and the generated predictions with zeros.
-

-$$
-\mathcal{L}_{\text{disc}} = \frac{1}{K} \sum_{k=1}^{K} \Big( \mathbb{E}_{x} \big[ (D_k(x) - 1)^2 \big] + \mathbb{E}_{\hat{x}} \big[ D_k(\hat{x})^2 \big] \Big)
-$$
-
- - -```python -# Generator loss - - -def generator_loss(real_pred, fake_pred): - """Loss function for the generator. - - Args: - real_pred: Tensor, output of the ground truth wave passed through the discriminator. - fake_pred: Tensor, output of the generator prediction passed through the discriminator. - - Returns: - Loss for the generator. - """ - gen_loss = [] - for i in range(len(fake_pred)): - gen_loss.append(mse(tf.ones_like(fake_pred[i][-1]), fake_pred[i][-1])) - - return tf.reduce_mean(gen_loss) - - -def feature_matching_loss(real_pred, fake_pred): - """Implements the feature matching loss. - - Args: - real_pred: Tensor, output of the ground truth wave passed through the discriminator. - fake_pred: Tensor, output of the generator prediction passed through the discriminator. - - Returns: - Feature Matching Loss. - """ - fm_loss = [] - for i in range(len(fake_pred)): - for j in range(len(fake_pred[i]) - 1): - fm_loss.append(mae(real_pred[i][j], fake_pred[i][j])) - - return tf.reduce_mean(fm_loss) - - -def discriminator_loss(real_pred, fake_pred): - """Implements the discriminator loss. - - Args: - real_pred: Tensor, output of the ground truth wave passed through the discriminator. - fake_pred: Tensor, output of the generator prediction passed through the discriminator. - - Returns: - Discriminator Loss. - """ - real_loss, fake_loss = [], [] - for i in range(len(real_pred)): - real_loss.append(mse(tf.ones_like(real_pred[i][-1]), real_pred[i][-1])) - fake_loss.append(mse(tf.zeros_like(fake_pred[i][-1]), fake_pred[i][-1])) - - # Calculating the final discriminator loss after scaling - disc_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss) - return disc_loss - -``` - -Defining the MelGAN model for training. -This subclass overrides the `train_step()` method to implement the training logic. - - -```python - -class MelGAN(keras.Model): - def __init__(self, generator, discriminator, **kwargs): - """MelGAN trainer class - - Args: - generator: keras.Model, Generator model - discriminator: keras.Model, Discriminator model - """ - super().__init__(**kwargs) - self.generator = generator - self.discriminator = discriminator - - def compile( - self, - gen_optimizer, - disc_optimizer, - generator_loss, - feature_matching_loss, - discriminator_loss, - ): - """MelGAN compile method. 
-
-        Args:
-            gen_optimizer: keras.optimizer, optimizer to be used for training
-            disc_optimizer: keras.optimizer, optimizer to be used for training
-            generator_loss: callable, loss function for generator
-            feature_matching_loss: callable, loss function for feature matching
-            discriminator_loss: callable, loss function for discriminator
-        """
-        super().compile()
-
-        # Optimizers
-        self.gen_optimizer = gen_optimizer
-        self.disc_optimizer = disc_optimizer
-
-        # Losses
-        self.generator_loss = generator_loss
-        self.feature_matching_loss = feature_matching_loss
-        self.discriminator_loss = discriminator_loss
-
-        # Trackers
-        self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
-        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
-
-    def train_step(self, batch):
-        x_batch_train, y_batch_train = batch
-
-        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
-            # Generating the audio wave
-            gen_audio_wave = self.generator(x_batch_train, training=True)
-
-            # Generating the features using the discriminator
-            real_pred = self.discriminator(y_batch_train)
-            fake_pred = self.discriminator(gen_audio_wave)
-
-            # Calculating the generator losses
-            gen_loss = self.generator_loss(real_pred, fake_pred)
-            fm_loss = self.feature_matching_loss(real_pred, fake_pred)
-
-            # Calculating the final generator loss; the feature matching term
-            # is weighted by a factor of 10, as in the MelGAN paper
-            gen_fm_loss = gen_loss + 10 * fm_loss
-
-            # Calculating the discriminator losses
-            disc_loss = self.discriminator_loss(real_pred, fake_pred)
-
-        # Calculating and applying the gradients for generator and discriminator
-        grads_gen = gen_tape.gradient(gen_fm_loss, self.generator.trainable_weights)
-        grads_disc = disc_tape.gradient(disc_loss, self.discriminator.trainable_weights)
-        self.gen_optimizer.apply_gradients(zip(grads_gen, self.generator.trainable_weights))
-        self.disc_optimizer.apply_gradients(
-            zip(grads_disc, self.discriminator.trainable_weights)
-        )
-
-        self.gen_loss_tracker.update_state(gen_fm_loss)
-        self.disc_loss_tracker.update_state(disc_loss)
-
-        return {
-            "gen_loss": self.gen_loss_tracker.result(),
-            "disc_loss": self.disc_loss_tracker.result(),
-        }
-
-```
-
----
-## Training
-
-The paper suggests that training with dynamic shapes takes around 400,000 steps (~500
-epochs). For this example, we will run it only for a single epoch (819 steps).
-Longer training (greater than 300 epochs) will almost certainly provide better results.
-
-
-```python
-gen_optimizer = keras.optimizers.Adam(
-    LEARNING_RATE_GEN, beta_1=0.5, beta_2=0.9, clipnorm=1
-)
-disc_optimizer = keras.optimizers.Adam(
-    LEARNING_RATE_DISC, beta_1=0.5, beta_2=0.9, clipnorm=1
-)
-
-# Start training
-generator = create_generator((None, 1))
-discriminator = create_discriminator((None, 1))
-
-mel_gan = MelGAN(generator, discriminator)
-mel_gan.compile(
-    gen_optimizer,
-    disc_optimizer,
-    generator_loss,
-    feature_matching_loss,
-    discriminator_loss,
-)
-mel_gan.fit(
-    train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE), epochs=1
-)
-```
-
-``` -819/819 [==============================] - 641s 696ms/step - gen_loss: 0.9761 - disc_loss: 0.9350 - - - -``` -
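-Note that `gen_loss_tracker` and `disc_loss_tracker` are stateful `keras.metrics.Mean`
-objects. For multi-epoch runs you would typically also expose them through a `metrics`
-property so that Keras resets them at every epoch boundary. A minimal sketch of such an
-addition to the `MelGAN` class (not part of the original example):
-
-```python
-    @property
-    def metrics(self):
-        # Everything listed here is reset automatically at each epoch start.
-        return [self.gen_loss_tracker, self.disc_loss_tracker]
-```
-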
---
-## Testing the model
-
-The trained model can now be used for real-time text-to-speech tasks.
-To test how fast MelGAN inference can be, let us take a sample audio mel-spectrogram
-and convert it. Note that the actual model pipeline will not include the `MelSpec` layer,
-hence this layer is disabled during inference. The inference input will be a
-mel-spectrogram processed similarly to the `MelSpec` layer configuration.
-
-To test this, we will create a uniformly distributed random tensor to simulate the
-input of the inference pipeline.
-
-
-```python
-# Sampling a random tensor to mimic a batch of 128 spectrograms of shape [50, 80]
-audio_sample = tf.random.uniform([128, 50, 80])
-```
-
-Timing the inference speed of a single sample. Running this, you can see that the average
-inference time per spectrogram ranges from 8 to 10 milliseconds on a K80 GPU, which is
-quite fast.
-
-
-```python
-pred = generator.predict(audio_sample, batch_size=32, verbose=1)
-```
-
-``` -4/4 [==============================] - 5s 280ms/step - -``` -
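-If you want to reproduce the timing yourself, you can wrap the call in a wall-clock
-timer, as in the sketch below (a rough measurement; it assumes the `generator` and
-`audio_sample` defined above, and the exact numbers will vary with hardware and
-batch size):
-
-```python
-import time
-
-start = time.time()
-generator.predict(audio_sample, batch_size=32, verbose=0)
-elapsed = time.time() - start
-# Divide the total wall-clock time by the number of spectrograms in the batch
-print(f"~{1000 * elapsed / audio_sample.shape[0]:.2f} ms per spectrogram")
-```
-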
---
-## Conclusion
-
-The MelGAN is a highly effective architecture for spectral inversion: it achieves a Mean
-Opinion Score (MOS) of 3.61, considerably outperforming the Griffin-Lim algorithm,
-which has a MOS of just 1.57. At the same time, MelGAN is competitive with
-the state-of-the-art WaveGlow and WaveNet architectures on text-to-speech and speech
-enhancement tasks on the LJSpeech and VCTK datasets [1].
-
-This tutorial highlights:
-
-1. The advantages of using dilated convolutions that grow with the filter size
-2. Implementation of a custom layer for on-the-fly conversion of audio waves to
-mel-spectrograms
-3. Effectiveness of using the feature matching loss function for training GAN generators.
-
-Further reading
-
-1. [MelGAN paper](https://arxiv.org/pdf/1910.06711v3.pdf) (Kundan Kumar et al.) to
-understand the reasoning behind the architecture and training process
-2. For an in-depth understanding of the feature matching loss, you can refer to [Improved
-Techniques for Training GANs](https://arxiv.org/pdf/1606.03498v1.pdf) (Tim Salimans et
-al.).
-
-Example available on HuggingFace
-
-| Trained Model | Demo |
-| :--: | :--: |
-| [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Model-MelGan%20spectrogram%20inversion-black.svg)](https://huggingface.co/keras-io/MelGAN-spectrogram-inversion) | [![Generic badge](https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-MelGan%20spectrogram%20inversion-black.svg)](https://huggingface.co/spaces/keras-io/MelGAN-spectrogram-inversion) |
-
diff --git a/templates/examples/audio/speaker_recognition_using_cnn.md b/templates/examples/audio/speaker_recognition_using_cnn.md
deleted file mode 100644
index c2d6657891..0000000000
--- a/templates/examples/audio/speaker_recognition_using_cnn.md
+++ /dev/null
@@ -1,852 +0,0 @@
-# Speaker Recognition
-
-**Author:** [Fadi Badine](https://twitter.com/fadibadine)<br>
-**Date created:** 14/06/2020
-**Last modified:** 19/07/2023
-**Description:** Classify speakers using Fast Fourier Transform (FFT) and a 1D Convnet. - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/speaker_recognition_using_cnn.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/speaker_recognition_using_cnn.py)
-
-
-
----
-## Introduction
-
-This example demonstrates how to create a model to classify speakers from the
-frequency domain representation of speech recordings, obtained via Fast Fourier
-Transform (FFT).
-
-It shows the following:
-
-- How to use `tf.data` to load, preprocess and feed audio streams into a model
-- How to create a 1D convolutional network with residual
-connections for audio classification.
-
-Our process:
-
-- We prepare a dataset of speech samples from different speakers, with the speaker as label.
-- We add background noise to these samples to augment our data.
-- We take the FFT of these samples.
-- We train a 1D convnet to predict the correct speaker given a noisy FFT speech sample.
-
-Note:
-
-- This example should be run with TensorFlow 2.3 or higher, or `tf-nightly`.
-- The noise samples in the dataset need to be resampled to a sampling rate of 16000 Hz
-before using the code in this example. In order to do this, you will need to have
-`ffmpeg` installed.
-
----
-## Setup
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "tensorflow"
-
-import shutil
-import numpy as np
-
-import tensorflow as tf
-import keras
-
-from pathlib import Path
-from IPython.display import display, Audio
-
-# Get the data from https://www.kaggle.com/kongaevans/speaker-recognition-dataset/
-# and save it to ./speaker-recognition-dataset.zip
-# then unzip it to ./16000_pcm_speeches
-```
-
-
-```python
-!kaggle datasets download -d kongaevans/speaker-recognition-dataset
-!unzip -qq speaker-recognition-dataset.zip
-```
-
-```python
-DATASET_ROOT = "16000_pcm_speeches"
-
-# The folders in which we will put the audio samples and the noise samples
-AUDIO_SUBFOLDER = "audio"
-NOISE_SUBFOLDER = "noise"
-
-DATASET_AUDIO_PATH = os.path.join(DATASET_ROOT, AUDIO_SUBFOLDER)
-DATASET_NOISE_PATH = os.path.join(DATASET_ROOT, NOISE_SUBFOLDER)
-
-# Percentage of samples to use for validation
-VALID_SPLIT = 0.1
-
-# Seed to use when shuffling the dataset and the noise
-SHUFFLE_SEED = 43
-
-# The sampling rate to use.
-# This is the one used in all the audio samples.
-# We will resample all of the noise to this sampling rate.
-# This will also be the output size of the audio wave samples
-# (since all samples are 1 second long)
-SAMPLING_RATE = 16000
-
-# The factor to multiply the noise with according to:
-#   noisy_sample = sample + noise * prop * scale
-# where prop = sample_amplitude / noise_amplitude
-SCALE = 0.5
-
-BATCH_SIZE = 128
-EPOCHS = 1
-
-```
-``` -Warning: Your Kaggle API key is readable by other users on this system! To fix this, you can run 'chmod 600 /home/fchollet/.kaggle/kaggle.json' -Downloading speaker-recognition-dataset.zip to /home/fchollet/keras-io/scripts/tmp_5022915 - 90%|████████████████████████████████████▉ | 208M/231M [00:00<00:00, 217MB/s] -100%|█████████████████████████████████████████| 231M/231M [00:01<00:00, 227MB/s] - -``` -
---
-## Data preparation
-
-The dataset is composed of 7 folders, divided into 2 groups:
-
-- Speech samples, with 5 folders for 5 different speakers. Each folder contains
-1500 audio files, each 1 second long and sampled at 16000 Hz.
-- Background noise samples, with 2 folders and a total of 6 files. These files
-are longer than 1 second (and originally not sampled at 16000 Hz, but we will resample them to 16000 Hz).
-We will use those 6 files to create 354 1-second-long noise samples to be used for training.
-
-Let's sort these 2 categories into 2 folders:
-
-- An `audio` folder which will contain all the per-speaker speech sample folders
-- A `noise` folder which will contain all the noise samples
-
-Before sorting the audio and noise categories into 2 folders,
-we have the following directory structure:
-
-```
-main_directory/
-...speaker_a/
-...speaker_b/
-...speaker_c/
-...speaker_d/
-...speaker_e/
-...other/
-..._background_noise_/
-```
-
-After sorting, we end up with the following structure:
-
-```
-main_directory/
-...audio/
-......speaker_a/
-......speaker_b/
-......speaker_c/
-......speaker_d/
-......speaker_e/
-...noise/
-......other/
-......_background_noise_/
-```
-
-
-```python
-for folder in os.listdir(DATASET_ROOT):
-    if os.path.isdir(os.path.join(DATASET_ROOT, folder)):
-        if folder in [AUDIO_SUBFOLDER, NOISE_SUBFOLDER]:
-            # If folder is `audio` or `noise`, do nothing
-            continue
-        elif folder in ["other", "_background_noise_"]:
-            # If folder is one of the folders that contains noise samples,
-            # move it to the `noise` folder
-            shutil.move(
-                os.path.join(DATASET_ROOT, folder),
-                os.path.join(DATASET_NOISE_PATH, folder),
-            )
-        else:
-            # Otherwise, it should be a speaker folder, so move it to the
-            # `audio` folder
-            shutil.move(
-                os.path.join(DATASET_ROOT, folder),
-                os.path.join(DATASET_AUDIO_PATH, folder),
-            )
-```
-
----
-## Noise preparation
-
-In this section:
-
-- We load all noise samples (which should have been resampled to 16000 Hz)
-- We split those noise samples into chunks of 16000 samples, each
-corresponding to a duration of 1 second
-
-
-```python
-# Get the list of all noise files
-noise_paths = []
-for subdir in os.listdir(DATASET_NOISE_PATH):
-    subdir_path = Path(DATASET_NOISE_PATH) / subdir
-    if os.path.isdir(subdir_path):
-        noise_paths += [
-            os.path.join(subdir_path, filepath)
-            for filepath in os.listdir(subdir_path)
-            if filepath.endswith(".wav")
-        ]
-if not noise_paths:
-    raise RuntimeError(f"Could not find any files at {DATASET_NOISE_PATH}")
-print(
-    "Found {} files belonging to {} directories".format(
-        len(noise_paths), len(os.listdir(DATASET_NOISE_PATH))
-    )
-)
-```
-
-``` -Found 6 files belonging to 2 directories - -``` -
-Resample all noise samples to 16000 Hz - - -```python -command = ( - "for dir in `ls -1 " + DATASET_NOISE_PATH + "`; do " - "for file in `ls -1 " + DATASET_NOISE_PATH + "/$dir/*.wav`; do " - "sample_rate=`ffprobe -hide_banner -loglevel panic -show_streams " - "$file | grep sample_rate | cut -f2 -d=`; " - "if [ $sample_rate -ne 16000 ]; then " - "ffmpeg -hide_banner -loglevel panic -y " - "-i $file -ar 16000 temp.wav; " - "mv temp.wav $file; " - "fi; done; done" -) -os.system(command) - - -# Split noise into chunks of 16,000 steps each -def load_noise_sample(path): - sample, sampling_rate = tf.audio.decode_wav( - tf.io.read_file(path), desired_channels=1 - ) - if sampling_rate == SAMPLING_RATE: - # Number of slices of 16000 each that can be generated from the noise sample - slices = int(sample.shape[0] / SAMPLING_RATE) - sample = tf.split(sample[: slices * SAMPLING_RATE], slices) - return sample - else: - print("Sampling rate for {} is incorrect. Ignoring it".format(path)) - return None - - -noises = [] -for path in noise_paths: - sample = load_noise_sample(path) - if sample: - noises.extend(sample) -noises = tf.stack(noises) - -print( - "{} noise files were split into {} noise samples where each is {} sec. long".format( - len(noise_paths), noises.shape[0], noises.shape[1] // SAMPLING_RATE - ) -) -``` - -
-``` -6 noise files were split into 354 noise samples where each is 1 sec. long - -``` -
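-If you prefer not to shell out to `ffmpeg`, the same resampling pass can be done directly
-in Python, for instance with `scipy.signal.resample`. A minimal sketch (a hypothetical
-helper, not part of the original example; assumes SciPy is installed):
-
-```python
-import scipy.io.wavfile
-from scipy.signal import resample
-
-
-def resample_wav_in_place(path, target_sr=16000):
-    # Read the wav file, resample it to `target_sr`, and write it back.
-    sr, wav = scipy.io.wavfile.read(path)
-    if sr != target_sr:
-        num_samples = int(len(wav) * target_sr / sr)
-        wav = resample(wav, num_samples).astype(wav.dtype)
-        scipy.io.wavfile.write(path, target_sr, wav)
-```
-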
---- -## Dataset generation - - -```python - -def paths_and_labels_to_dataset(audio_paths, labels): - """Constructs a dataset of audios and labels.""" - path_ds = tf.data.Dataset.from_tensor_slices(audio_paths) - audio_ds = path_ds.map( - lambda x: path_to_audio(x), num_parallel_calls=tf.data.AUTOTUNE - ) - label_ds = tf.data.Dataset.from_tensor_slices(labels) - return tf.data.Dataset.zip((audio_ds, label_ds)) - - -def path_to_audio(path): - """Reads and decodes an audio file.""" - audio = tf.io.read_file(path) - audio, _ = tf.audio.decode_wav(audio, 1, SAMPLING_RATE) - return audio - - -def add_noise(audio, noises=None, scale=0.5): - if noises is not None: - # Create a random tensor of the same size as audio ranging from - # 0 to the number of noise stream samples that we have. - tf_rnd = tf.random.uniform( - (tf.shape(audio)[0],), 0, noises.shape[0], dtype=tf.int32 - ) - noise = tf.gather(noises, tf_rnd, axis=0) - - # Get the amplitude proportion between the audio and the noise - prop = tf.math.reduce_max(audio, axis=1) / tf.math.reduce_max(noise, axis=1) - prop = tf.repeat(tf.expand_dims(prop, axis=1), tf.shape(audio)[1], axis=1) - - # Adding the rescaled noise to audio - audio = audio + noise * prop * scale - - return audio - - -def audio_to_fft(audio): - # Since tf.signal.fft applies FFT on the innermost dimension, - # we need to squeeze the dimensions and then expand them again - # after FFT - audio = tf.squeeze(audio, axis=-1) - fft = tf.signal.fft( - tf.cast(tf.complex(real=audio, imag=tf.zeros_like(audio)), tf.complex64) - ) - fft = tf.expand_dims(fft, axis=-1) - - # Return the absolute value of the first half of the FFT - # which represents the positive frequencies - return tf.math.abs(fft[:, : (audio.shape[1] // 2), :]) - - -# Get the list of audio file paths along with their corresponding labels - -class_names = os.listdir(DATASET_AUDIO_PATH) -print( - "Our class names: {}".format( - class_names, - ) -) - -audio_paths = [] -labels = [] -for label, name in enumerate(class_names): - print( - "Processing speaker {}".format( - name, - ) - ) - dir_path = Path(DATASET_AUDIO_PATH) / name - speaker_sample_paths = [ - os.path.join(dir_path, filepath) - for filepath in os.listdir(dir_path) - if filepath.endswith(".wav") - ] - audio_paths += speaker_sample_paths - labels += [label] * len(speaker_sample_paths) - -print( - "Found {} files belonging to {} classes.".format(len(audio_paths), len(class_names)) -) - -# Shuffle -rng = np.random.RandomState(SHUFFLE_SEED) -rng.shuffle(audio_paths) -rng = np.random.RandomState(SHUFFLE_SEED) -rng.shuffle(labels) - -# Split into training and validation -num_val_samples = int(VALID_SPLIT * len(audio_paths)) -print("Using {} files for training.".format(len(audio_paths) - num_val_samples)) -train_audio_paths = audio_paths[:-num_val_samples] -train_labels = labels[:-num_val_samples] - -print("Using {} files for validation.".format(num_val_samples)) -valid_audio_paths = audio_paths[-num_val_samples:] -valid_labels = labels[-num_val_samples:] - -# Create 2 datasets, one for training and the other for validation -train_ds = paths_and_labels_to_dataset(train_audio_paths, train_labels) -train_ds = train_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch( - BATCH_SIZE -) - -valid_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels) -valid_ds = valid_ds.shuffle(buffer_size=32 * 8, seed=SHUFFLE_SEED).batch(32) - - -# Add noise to the training set -train_ds = train_ds.map( - lambda x, y: (add_noise(x, noises, scale=SCALE), y), - 
num_parallel_calls=tf.data.AUTOTUNE, -) - -# Transform audio wave to the frequency domain using `audio_to_fft` -train_ds = train_ds.map( - lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE -) -train_ds = train_ds.prefetch(tf.data.AUTOTUNE) - -valid_ds = valid_ds.map( - lambda x, y: (audio_to_fft(x), y), num_parallel_calls=tf.data.AUTOTUNE -) -valid_ds = valid_ds.prefetch(tf.data.AUTOTUNE) -``` - -
-``` -Our class names: ['Nelson_Mandela', 'Jens_Stoltenberg', 'Benjamin_Netanyau', 'Julia_Gillard', 'Magaret_Tarcher'] -Processing speaker Nelson_Mandela -Processing speaker Jens_Stoltenberg -Processing speaker Benjamin_Netanyau -Processing speaker Julia_Gillard -Processing speaker Magaret_Tarcher -Found 7501 files belonging to 5 classes. -Using 6751 files for training. -Using 750 files for validation. - -``` -
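-Note that `audio_to_fft` keeps only the positive-frequency half of the FFT, so each
-1-second clip of 16000 samples becomes a feature tensor of shape `(8000, 1)`, matching
-the model input shape defined below. A quick sanity check (hypothetical snippet, not
-part of the original example):
-
-```python
-dummy_batch = tf.zeros([1, SAMPLING_RATE, 1])
-print(audio_to_fft(dummy_batch).shape)  # -> (1, 8000, 1)
-```
-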
---
-## Model Definition
-
-
-```python
-
-def residual_block(x, filters, conv_num=3, activation="relu"):
-    # Shortcut
-    s = keras.layers.Conv1D(filters, 1, padding="same")(x)
-    for i in range(conv_num - 1):
-        x = keras.layers.Conv1D(filters, 3, padding="same")(x)
-        x = keras.layers.Activation(activation)(x)
-    x = keras.layers.Conv1D(filters, 3, padding="same")(x)
-    x = keras.layers.Add()([x, s])
-    x = keras.layers.Activation(activation)(x)
-    return keras.layers.MaxPool1D(pool_size=2, strides=2)(x)
-
-
-def build_model(input_shape, num_classes):
-    inputs = keras.layers.Input(shape=input_shape, name="input")
-
-    x = residual_block(inputs, 16, 2)
-    x = residual_block(x, 32, 2)
-    x = residual_block(x, 64, 3)
-    x = residual_block(x, 128, 3)
-    x = residual_block(x, 128, 3)
-
-    x = keras.layers.AveragePooling1D(pool_size=3, strides=3)(x)
-    x = keras.layers.Flatten()(x)
-    x = keras.layers.Dense(256, activation="relu")(x)
-    x = keras.layers.Dense(128, activation="relu")(x)
-
-    outputs = keras.layers.Dense(num_classes, activation="softmax", name="output")(x)
-
-    return keras.models.Model(inputs=inputs, outputs=outputs)
-
-
-model = build_model((SAMPLING_RATE // 2, 1), len(class_names))
-
-model.summary()
-
-# Compile the model using Adam's default learning rate
-model.compile(
-    optimizer="Adam",
-    loss="sparse_categorical_crossentropy",
-    metrics=["accuracy"],
-)
-
-# Add callbacks:
-# 'EarlyStopping' to stop training when the model is no longer improving
-# 'ModelCheckpoint' to always keep the model that has the best val_accuracy
-model_save_filename = "model.keras"
-
-earlystopping_cb = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
-mdlcheckpoint_cb = keras.callbacks.ModelCheckpoint(
-    model_save_filename, monitor="val_accuracy", save_best_only=True
-)
-```
-
Model: "functional_1"
-
- - - - -
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)         Output Shape       Param #  Connected to         ┃
-┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
-│ input (InputLayer)  │ (None, 8000, 1)   │       0 │ -                    │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_1 (Conv1D)   │ (None, 8000, 16)  │      64 │ input[0][0]          │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation          │ (None, 8000, 16)  │       0 │ conv1d_1[0][0]       │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_2 (Conv1D)   │ (None, 8000, 16)  │     784 │ activation[0][0]     │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d (Conv1D)     │ (None, 8000, 16)  │      32 │ input[0][0]          │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ add (Add)           │ (None, 8000, 16)  │       0 │ conv1d_2[0][0],      │
-│                     │                   │         │ conv1d[0][0]         │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_1        │ (None, 8000, 16)  │       0 │ add[0][0]            │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ max_pooling1d       │ (None, 4000, 16)  │       0 │ activation_1[0][0]   │
-│ (MaxPooling1D)      │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_4 (Conv1D)   │ (None, 4000, 32)  │   1,568 │ max_pooling1d[0][0]  │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_2        │ (None, 4000, 32)  │       0 │ conv1d_4[0][0]       │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_5 (Conv1D)   │ (None, 4000, 32)  │   3,104 │ activation_2[0][0]   │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_3 (Conv1D)   │ (None, 4000, 32)  │     544 │ max_pooling1d[0][0]  │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ add_1 (Add)         │ (None, 4000, 32)  │       0 │ conv1d_5[0][0],      │
-│                     │                   │         │ conv1d_3[0][0]       │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_3        │ (None, 4000, 32)  │       0 │ add_1[0][0]          │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ max_pooling1d_1     │ (None, 2000, 32)  │       0 │ activation_3[0][0]   │
-│ (MaxPooling1D)      │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_7 (Conv1D)   │ (None, 2000, 64)  │   6,208 │ max_pooling1d_1[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_4        │ (None, 2000, 64)  │       0 │ conv1d_7[0][0]       │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_8 (Conv1D)   │ (None, 2000, 64)  │  12,352 │ activation_4[0][0]   │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_5        │ (None, 2000, 64)  │       0 │ conv1d_8[0][0]       │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_9 (Conv1D)   │ (None, 2000, 64)  │  12,352 │ activation_5[0][0]   │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_6 (Conv1D)   │ (None, 2000, 64)  │   2,112 │ max_pooling1d_1[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ add_2 (Add)         │ (None, 2000, 64)  │       0 │ conv1d_9[0][0],      │
-│                     │                   │         │ conv1d_6[0][0]       │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_6        │ (None, 2000, 64)  │       0 │ add_2[0][0]          │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ max_pooling1d_2     │ (None, 1000, 64)  │       0 │ activation_6[0][0]   │
-│ (MaxPooling1D)      │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_11 (Conv1D)  │ (None, 1000, 128) │  24,704 │ max_pooling1d_2[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_7        │ (None, 1000, 128) │       0 │ conv1d_11[0][0]      │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_12 (Conv1D)  │ (None, 1000, 128) │  49,280 │ activation_7[0][0]   │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_8        │ (None, 1000, 128) │       0 │ conv1d_12[0][0]      │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_13 (Conv1D)  │ (None, 1000, 128) │  49,280 │ activation_8[0][0]   │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_10 (Conv1D)  │ (None, 1000, 128) │   8,320 │ max_pooling1d_2[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ add_3 (Add)         │ (None, 1000, 128) │       0 │ conv1d_13[0][0],     │
-│                     │                   │         │ conv1d_10[0][0]      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_9        │ (None, 1000, 128) │       0 │ add_3[0][0]          │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ max_pooling1d_3     │ (None, 500, 128)  │       0 │ activation_9[0][0]   │
-│ (MaxPooling1D)      │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_15 (Conv1D)  │ (None, 500, 128)  │  49,280 │ max_pooling1d_3[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_10       │ (None, 500, 128)  │       0 │ conv1d_15[0][0]      │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_16 (Conv1D)  │ (None, 500, 128)  │  49,280 │ activation_10[0][0]  │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_11       │ (None, 500, 128)  │       0 │ conv1d_16[0][0]      │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_17 (Conv1D)  │ (None, 500, 128)  │  49,280 │ activation_11[0][0]  │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ conv1d_14 (Conv1D)  │ (None, 500, 128)  │  16,512 │ max_pooling1d_3[0][ │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ add_4 (Add)         │ (None, 500, 128)  │       0 │ conv1d_17[0][0],     │
-│                     │                   │         │ conv1d_14[0][0]      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ activation_12       │ (None, 500, 128)  │       0 │ add_4[0][0]          │
-│ (Activation)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ max_pooling1d_4     │ (None, 250, 128)  │       0 │ activation_12[0][0]  │
-│ (MaxPooling1D)      │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ average_pooling1d   │ (None, 83, 128)   │       0 │ max_pooling1d_4[0][ │
-│ (AveragePooling1D)  │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ flatten (Flatten)   │ (None, 10624)     │       0 │ average_pooling1d[0… │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ dense (Dense)       │ (None, 256)       │ 2,720,… │ flatten[0][0]        │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ dense_1 (Dense)     │ (None, 128)       │  32,896 │ dense[0][0]          │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ output (Dense)      │ (None, 5)         │     645 │ dense_1[0][0]        │
-└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
-
- - - - -
 Total params: 3,088,597 (11.78 MB)
-
- - - - -
 Trainable params: 3,088,597 (11.78 MB)
-
- - - - -
 Non-trainable params: 0 (0.00 B)
-
- - - ---- -## Training - - -```python -history = model.fit( - train_ds, - epochs=EPOCHS, - validation_data=valid_ds, - callbacks=[earlystopping_cb, mdlcheckpoint_cb], -) -``` - -
-``` -WARNING: All log messages before absl::InitializeLog() is called are written to STDERR -I0000 00:00:1699469571.349760 302130 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. -W0000 00:00:1699469571.377393 302130 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update - - 52/53 ━━━━━━━━━━━━━━━━━━━━ 0s 396ms/step - accuracy: 0.4496 - loss: 5.2439 - -W0000 00:00:1699469622.140353 302129 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update - - 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 974ms/step - accuracy: 0.4531 - loss: 5.1842 - -W0000 00:00:1699469625.456199 302130 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update -W0000 00:00:1699469627.405341 302129 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update - - 53/53 ━━━━━━━━━━━━━━━━━━━━ 101s 1s/step - accuracy: 0.4564 - loss: 5.1267 - val_accuracy: 0.8720 - val_loss: 0.3273 - -``` -
---- -## Evaluation - - -```python -print(model.evaluate(valid_ds)) -``` - -
-``` - 24/24 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8641 - loss: 0.3521 -[0.32171541452407837, 0.871999979019165] - -``` -
-We get ~87% validation accuracy after this single epoch of training; training for more
-epochs would push this higher.
-
----
-## Demonstration
-
-Let's take some samples and:
-
-- Predict the speaker
-- Compare the prediction with the real speaker
-- Listen to the audio to confirm that, despite the samples being noisy,
-the model is still fairly accurate
-
-
-```python
-SAMPLES_TO_DISPLAY = 10
-
-test_ds = paths_and_labels_to_dataset(valid_audio_paths, valid_labels)
-test_ds = test_ds.shuffle(buffer_size=BATCH_SIZE * 8, seed=SHUFFLE_SEED).batch(
-    BATCH_SIZE
-)
-
-test_ds = test_ds.map(
-    lambda x, y: (add_noise(x, noises, scale=SCALE), y),
-    num_parallel_calls=tf.data.AUTOTUNE,
-)
-
-for audios, labels in test_ds.take(1):
-    # Get the signal FFT
-    ffts = audio_to_fft(audios)
-    # Predict
-    y_pred = model.predict(ffts)
-    # Take random samples
-    rnd = np.random.randint(0, BATCH_SIZE, SAMPLES_TO_DISPLAY)
-    audios = audios.numpy()[rnd, :, :]
-    labels = labels.numpy()[rnd]
-    y_pred = np.argmax(y_pred, axis=-1)[rnd]
-
-    for index in range(SAMPLES_TO_DISPLAY):
-        # For every sample, print the true and predicted label,
-        # and play back the noisy audio
-        print(
-            "Speaker:\33{} {}\33[0m\tPredicted:\33{} {}\33[0m".format(
-                "[92m" if labels[index] == y_pred[index] else "[91m",
-                class_names[labels[index]],
-                "[92m" if labels[index] == y_pred[index] else "[91m",
-                class_names[y_pred[index]],
-            )
-        )
-        display(Audio(audios[index, :, :].squeeze(), rate=SAMPLING_RATE))
-```
-
-``` - 4/4 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step -Speaker: Magaret_Tarcher Predicted: Benjamin_Netanyau - -W0000 00:00:1699469629.002282 302130 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update - -``` -
- - - - -
-``` -Speaker: Julia_Gillard Predicted: Julia_Gillard - -``` -
- - - - -
-``` -Speaker: Nelson_Mandela Predicted: Nelson_Mandela - -``` -
- - - - -
-``` -Speaker: Magaret_Tarcher Predicted: Magaret_Tarcher - -``` -
- - - - -
-``` -Speaker: Julia_Gillard Predicted: Julia_Gillard - -``` -
- - - - -
-``` -Speaker: Julia_Gillard Predicted: Julia_Gillard - -``` -
- - - - -
-``` -Speaker: Jens_Stoltenberg Predicted: Jens_Stoltenberg - -``` -
- - - - -
-``` -Speaker: Benjamin_Netanyau Predicted: Benjamin_Netanyau - -``` -
- - - - -
-``` -Speaker: Nelson_Mandela Predicted: Nelson_Mandela - -``` -
- - - - -
-``` -Speaker: Nelson_Mandela Predicted: Nelson_Mandela - -``` -
- - - - diff --git a/templates/examples/audio/stft.md b/templates/examples/audio/stft.md deleted file mode 100644 index 331e929f7f..0000000000 --- a/templates/examples/audio/stft.md +++ /dev/null @@ -1,1822 +0,0 @@ -# Audio Classification with the STFTSpectrogram layer - -**Author:** [Mostafa M. Amin](https://mostafa-amin.com)
-**Date created:** 2024/10/04
-**Last modified:** 2024/10/04
-**Description:** Introducing the `STFTSpectrogram` layer to extract spectrograms for audio classification. - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/stft.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/stft.py)
-
-
-
----
-## Introduction
-
-Preprocessing audio as spectrograms is an essential step in the vast majority
-of audio-based applications. Spectrograms represent the frequency content of a
-signal over time and are widely used for this purpose. In this tutorial, we'll
-demonstrate how to use the `STFTSpectrogram` layer in Keras to convert raw
-audio waveforms into spectrograms **within the model**. We'll then feed
-these spectrograms into convolutional networks to perform audio classification
-on the ESC-10 dataset.
-
-We will:
-
-- Load the ESC-10 dataset.
-- Preprocess the raw audio waveforms and generate spectrograms using
-  `STFTSpectrogram`.
-- Build two models, one using the spectrograms as 1D signals and the other
-using them as images (2D signals) with a pretrained image model.
-- Train and evaluate the models.
-
----
-## Setup
-
-### Importing the necessary libraries
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"
-```
-
-
-```python
-import keras
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import scipy.io.wavfile
-from keras import layers
-from scipy.signal import resample
-
-keras.utils.set_random_seed(41)
-```
-
-### Define some variables
-
-
-```python
-BASE_DATA_DIR = "./datasets/esc-50_extracted/ESC-50-master/"
-BATCH_SIZE = 16
-NUM_CLASSES = 10
-EPOCHS = 200
-SAMPLE_RATE = 16000
-```
-
----
-## Download and Preprocess the ESC-10 Dataset
-
-We'll use the ESC-10 dataset, a subset of the Dataset for Environmental Sound
-Classification (ESC-50). It consists of five-second .wav files of environmental sounds.
-
-### Download and Extract the dataset
-
-
-```python
-keras.utils.get_file(
-    "esc-50.zip",
-    "https://github.com/karoldvl/ESC-50/archive/master.zip",
-    cache_dir="./",
-    cache_subdir="datasets",
-    extract=True,
-)
-```
-
-
-
-
-    './datasets/esc-50_extracted'
-
-
-
-### Read the CSV file
-
-
-```python
-pd_data = pd.read_csv(os.path.join(BASE_DATA_DIR, "meta", "esc50.csv"))
-# filter ESC-50 to ESC-10 and reassign the targets
-pd_data = pd_data[pd_data["esc10"]]
-targets = sorted(pd_data["target"].unique().tolist())
-assert len(targets) == NUM_CLASSES
-old_target_to_new_target = {old: new for new, old in enumerate(targets)}
-pd_data["target"] = pd_data["target"].map(lambda t: old_target_to_new_target[t])
-pd_data
-```
-
-|      | filename          | fold | target | category       | esc10 | src_file | take |
-|------|-------------------|------|--------|----------------|-------|----------|------|
-| 0    | 1-100032-A-0.wav  | 1    | 0      | dog            | True  | 100032   | A    |
-| 14   | 1-110389-A-0.wav  | 1    | 0      | dog            | True  | 110389   | A    |
-| 24   | 1-116765-A-41.wav | 1    | 9      | chainsaw       | True  | 116765   | A    |
-| 54   | 1-17150-A-12.wav  | 1    | 4      | crackling_fire | True  | 17150    | A    |
-| 55   | 1-172649-A-40.wav | 1    | 8      | helicopter     | True  | 172649   | A    |
-| ...  | ...               | ...  | ...    | ...            | ...   | ...      | ...  |
-| 1876 | 5-233160-A-1.wav  | 5    | 1      | rooster        | True  | 233160   | A    |
-| 1888 | 5-234879-A-1.wav  | 5    | 1      | rooster        | True  | 234879   | A    |
-| 1889 | 5-234879-B-1.wav  | 5    | 1      | rooster        | True  | 234879   | B    |
-| 1894 | 5-235671-A-38.wav | 5    | 7      | clock_tick     | True  | 235671   | A    |
-| 1999 | 5-9032-A-0.wav    | 5    | 0      | dog            | True  | 9032     | A    |
-
-400 rows × 7 columns
-
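-As a quick sanity check, the remapping above should have produced ten contiguous class
-labels (hypothetical snippet, not part of the original example):
-
-```python
-assert sorted(pd_data["target"].unique().tolist()) == list(range(NUM_CLASSES))
-```
-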
-
-### Define functions to read and preprocess the WAV files
-
-
-```python
-def read_wav_file(path, target_sr=SAMPLE_RATE):
-    sr, wav = scipy.io.wavfile.read(os.path.join(BASE_DATA_DIR, "audio", path))
-    wav = wav.astype(np.float32) / 32768.0  # normalize to [-1, 1]
-    num_samples = int(len(wav) * target_sr / sr)  # resample to 16 kHz
-    wav = resample(wav, num_samples)
-    return wav[:, None]  # Add a channel dimension (of size 1)
-```
-
-Create a function that uses the `STFTSpectrogram` to compute a spectrogram and
-then plots it.
-
-
-```python
-def plot_single_spectrogram(sample_wav_data):
-    spectrogram = layers.STFTSpectrogram(
-        mode="log",
-        frame_length=SAMPLE_RATE * 20 // 1000,
-        frame_step=SAMPLE_RATE * 5 // 1000,
-        fft_length=1024,
-        trainable=False,
-    )(sample_wav_data[None, ...])[0, ...]
-
-    # Plot the spectrogram
-    plt.imshow(spectrogram.T, origin="lower")
-    plt.title("Single Channel Spectrogram")
-    plt.xlabel("Time")
-    plt.ylabel("Frequency")
-    plt.show()
-```
-
-Create a function that uses the `STFTSpectrogram` to compute three
-spectrograms with multiple bandwidths, then aligns them as an image
-with different channels, to get a multi-bandwidth spectrogram, and then
-plots it.
-
-
-```python
-def plot_multi_bandwidth_spectrogram(sample_wav_data):
-    # All spectrograms must use the same `fft_length`, `frame_step`, and
-    # `padding="same"` in order to produce spectrograms with identical shapes,
-    # hence aligning them together. `expand_dims` ensures that the shapes are
-    # compatible with image models.
-
-    spectrograms = np.concatenate(
-        [
-            layers.STFTSpectrogram(
-                mode="log",
-                frame_length=SAMPLE_RATE * x // 1000,
-                frame_step=SAMPLE_RATE * 5 // 1000,
-                fft_length=1024,
-                padding="same",
-                expand_dims=True,
-            )(sample_wav_data[None, ...])[0, ...]
-            for x in [5, 10, 20]
-        ],
-        axis=-1,
-    ).transpose([1, 0, 2])
-
-    # normalize each color channel for better viewing
-    mn = spectrograms.min(axis=(0, 1), keepdims=True)
-    mx = spectrograms.max(axis=(0, 1), keepdims=True)
-    spectrograms = (spectrograms - mn) / (mx - mn)
-
-    plt.imshow(spectrograms, origin="lower")
-    plt.title("Multi-bandwidth Spectrogram")
-    plt.xlabel("Time")
-    plt.ylabel("Frequency")
-    plt.show()
-```
-
-Demonstrate a sample wav file.
-
-
-```python
-sample_wav_data = read_wav_file(pd_data["filename"].tolist()[52])
-plt.plot(sample_wav_data[:, 0])
-plt.show()
-```
-
-
-
-![png](https://github.com/keras-team/keras-io/blob/master/examples/audio/img/stft/raw_audio.png)
-
-
-
-Plot a Spectrogram
-
-
-```python
-plot_single_spectrogram(sample_wav_data)
-```
-
-
-
-![png](https://github.com/keras-team/keras-io/blob/master/examples/audio/img/stft/spectrogram.png)
-
-
-
-Plot a multi-bandwidth spectrogram
-
-
-```python
-plot_multi_bandwidth_spectrogram(sample_wav_data)
-```
-
-
-
-![png](https://github.com/keras-team/keras-io/blob/master/examples/audio/img/stft/multiband_spectrogram.png)
-
-
-
-### Define functions to construct a TF Dataset
-
-
-```python
-def read_dataset(df, folds):
-    msk = df["fold"].isin(folds)
-    filenames = df["filename"][msk]
-    targets = df["target"][msk].values
-    waves = np.array([read_wav_file(fil) for fil in filenames], dtype=np.float32)
-    return waves, targets
-```
-
-### Create the datasets
-
-
-```python
-train_x, train_y = read_dataset(pd_data, [1, 2, 3])
-valid_x, valid_y = read_dataset(pd_data, [4])
-test_x, test_y = read_dataset(pd_data, [5])
-```
-
----
-## Training the Models
-
-In this tutorial we demonstrate the different use cases of the `STFTSpectrogram`
-layer.
-
-The first model will use a non-trainable `STFTSpectrogram` layer, so it is
-intended purely for preprocessing. Additionally, the model will use 1D signals,
-hence it makes use of `Conv1D` layers.
-
-The second model will use a trainable `STFTSpectrogram` layer with the
-`expand_dims` option, which expands the shapes to be compatible with image
-models.
-
-### Create the 1D model
-
-1. Create non-trainable spectrograms, extracting a 1D time signal.
-2. Apply `Conv1D` layers with `LayerNormalization`, similar to the
-   classic VGG design.
-3. Apply global maximum pooling to obtain a fixed-size set of features.
-4. Add `Dense` layers to make the final predictions based on the features.
-
-
-```python
-model1d = keras.Sequential(
-    [
-        layers.InputLayer((None, 1)),
-        layers.STFTSpectrogram(
-            mode="log",
-            frame_length=SAMPLE_RATE * 40 // 1000,
-            frame_step=SAMPLE_RATE * 15 // 1000,
-            trainable=False,
-        ),
-        layers.Conv1D(64, 64, activation="relu"),
-        layers.Conv1D(128, 16, activation="relu"),
-        layers.LayerNormalization(),
-        layers.MaxPooling1D(4),
-        layers.Conv1D(128, 8, activation="relu"),
-        layers.Conv1D(256, 8, activation="relu"),
-        layers.Conv1D(512, 4, activation="relu"),
-        layers.LayerNormalization(),
-        layers.Dropout(0.5),
-        layers.GlobalMaxPooling1D(),
-        layers.Dense(256, activation="relu"),
-        layers.Dense(256, activation="relu"),
-        layers.Dropout(0.5),
-        layers.Dense(NUM_CLASSES, activation="softmax"),
-    ],
-    name="model_1d_non_trainble_stft",
-)
-model1d.compile(
-    optimizer=keras.optimizers.Adam(1e-5),
-    loss="sparse_categorical_crossentropy",
-    metrics=["accuracy"],
-)
-model1d.summary()
-```
-
-
Model: "model_1d_non_trainble_stft"
-
- - - - -
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)                          Output Shape                         Param # ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
-│ stft_spectrogram_4 (STFTSpectrogram) │ (None, None, 513)           │         656,640 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ conv1d (Conv1D)                      │ (None, None, 64)            │       2,101,312 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ conv1d_1 (Conv1D)                    │ (None, None, 128)           │         131,200 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ layer_normalization                  │ (None, None, 128)           │             256 │
-│ (LayerNormalization)                 │                             │                 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ max_pooling1d (MaxPooling1D)         │ (None, None, 128)           │               0 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ conv1d_2 (Conv1D)                    │ (None, None, 128)           │         131,200 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ conv1d_3 (Conv1D)                    │ (None, None, 256)           │         262,400 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ conv1d_4 (Conv1D)                    │ (None, None, 512)           │         524,800 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ layer_normalization_1                │ (None, None, 512)           │           1,024 │
-│ (LayerNormalization)                 │                             │                 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ dropout (Dropout)                    │ (None, None, 512)           │               0 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ global_max_pooling1d                 │ (None, 512)                 │               0 │
-│ (GlobalMaxPooling1D)                 │                             │                 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ dense (Dense)                        │ (None, 256)                 │         131,328 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ dense_1 (Dense)                      │ (None, 256)                 │          65,792 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ dropout_1 (Dropout)                  │ (None, 256)                 │               0 │
-├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
-│ dense_2 (Dense)                      │ (None, 10)                  │           2,570 │
-└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
-
- - - - -
 Total params: 4,008,522 (15.29 MB)
-
- - - - -
 Trainable params: 3,351,882 (12.79 MB)
-
- - - - -
 Non-trainable params: 656,640 (2.50 MB)
-
- - - -Train the model and restore the best weights. - - -```python -history_model1d = model1d.fit( - train_x, - train_y, - batch_size=BATCH_SIZE, - validation_data=(valid_x, valid_y), - epochs=EPOCHS, - callbacks=[ - keras.callbacks.EarlyStopping( - monitor="val_loss", - patience=EPOCHS, - restore_best_weights=True, - ) - ], -) -``` - - Epoch 1/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 9s 271ms/step - accuracy: 0.1092 - loss: 3.1307 - val_accuracy: 0.0875 - val_loss: 2.4073 - Epoch 2/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - accuracy: 0.1434 - loss: 2.6563 - val_accuracy: 0.1000 - val_loss: 2.4051 - Epoch 3/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.1324 - loss: 2.5414 - val_accuracy: 0.1000 - val_loss: 2.4050 - Epoch 4/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.1552 - loss: 2.4542 - val_accuracy: 0.1000 - val_loss: 2.3832 - Epoch 5/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.1204 - loss: 2.3896 - val_accuracy: 0.1000 - val_loss: 2.3405 - Epoch 6/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.1210 - loss: 2.3499 - val_accuracy: 0.1000 - val_loss: 2.3108 - Epoch 7/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.1547 - loss: 2.2899 - val_accuracy: 0.1000 - val_loss: 2.2994 - Epoch 8/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.1672 - loss: 2.2049 - val_accuracy: 0.1250 - val_loss: 2.2802 - Epoch 9/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.2025 - loss: 2.1537 - val_accuracy: 0.1000 - val_loss: 2.2709 - Epoch 10/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.1832 - loss: 2.1482 - val_accuracy: 0.1500 - val_loss: 2.2698 - Epoch 11/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.2389 - loss: 2.0647 - val_accuracy: 0.1000 - val_loss: 2.2354 - Epoch 12/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.2253 - loss: 1.9860 - val_accuracy: 0.2125 - val_loss: 2.1661 - Epoch 13/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.2123 - loss: 2.0868 - val_accuracy: 0.1125 - val_loss: 2.1726 - Epoch 14/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.2390 - loss: 2.0544 - val_accuracy: 0.2375 - val_loss: 2.1123 - Epoch 15/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.2656 - loss: 2.0536 - val_accuracy: 0.2625 - val_loss: 2.1235 - Epoch 16/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.3263 - loss: 1.9533 - val_accuracy: 0.1750 - val_loss: 2.1477 - Epoch 17/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.3790 - loss: 1.8721 - val_accuracy: 0.1875 - val_loss: 2.0823 - Epoch 18/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.3292 - loss: 1.8978 - val_accuracy: 0.3125 - val_loss: 2.0181 - Epoch 19/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.3430 - loss: 1.8915 - val_accuracy: 0.3625 - val_loss: 1.9877 - Epoch 20/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.3613 - loss: 1.7638 - val_accuracy: 0.3500 - val_loss: 1.9599 - Epoch 21/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.4141 - loss: 1.6976 - val_accuracy: 0.4125 - val_loss: 1.9317 - Epoch 22/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.4173 - loss: 1.6408 - val_accuracy: 0.3000 - val_loss: 1.9310 - Epoch 23/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.3887 - loss: 1.5914 - val_accuracy: 0.4500 - val_loss: 1.8504 - Epoch 24/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.3943 - loss: 1.5998 - val_accuracy: 0.2875 - val_loss: 1.8993 - Epoch 25/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step 
- accuracy: 0.5392 - loss: 1.4692 - val_accuracy: 0.4000 - val_loss: 1.8548 - Epoch 26/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.4735 - loss: 1.5004 - val_accuracy: 0.4250 - val_loss: 1.8440 - Epoch 27/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5132 - loss: 1.4321 - val_accuracy: 0.5000 - val_loss: 1.7961 - Epoch 28/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5147 - loss: 1.3093 - val_accuracy: 0.4250 - val_loss: 1.8132 - Epoch 29/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.5344 - loss: 1.3614 - val_accuracy: 0.5000 - val_loss: 1.7522 - Epoch 30/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5545 - loss: 1.2561 - val_accuracy: 0.5375 - val_loss: 1.7180 - Epoch 31/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5697 - loss: 1.2651 - val_accuracy: 0.5500 - val_loss: 1.6538 - Epoch 32/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5385 - loss: 1.2571 - val_accuracy: 0.6125 - val_loss: 1.6453 - Epoch 33/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5734 - loss: 1.3083 - val_accuracy: 0.5125 - val_loss: 1.6801 - Epoch 34/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.5976 - loss: 1.1720 - val_accuracy: 0.4625 - val_loss: 1.6860 - Epoch 35/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5268 - loss: 1.3844 - val_accuracy: 0.6375 - val_loss: 1.6253 - Epoch 36/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.6021 - loss: 1.1720 - val_accuracy: 0.4625 - val_loss: 1.7012 - Epoch 37/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5144 - loss: 1.2672 - val_accuracy: 0.6250 - val_loss: 1.5866 - Epoch 38/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.6075 - loss: 1.1400 - val_accuracy: 0.6125 - val_loss: 1.5615 - Epoch 39/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.6272 - loss: 1.1138 - val_accuracy: 0.5000 - val_loss: 1.6364 - Epoch 40/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.5718 - loss: 1.1956 - val_accuracy: 0.6000 - val_loss: 1.6239 - Epoch 41/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.5934 - loss: 1.1302 - val_accuracy: 0.5250 - val_loss: 1.5490 - Epoch 42/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.5930 - loss: 1.0970 - val_accuracy: 0.5625 - val_loss: 1.5530 - Epoch 43/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.6369 - loss: 0.9976 - val_accuracy: 0.6375 - val_loss: 1.5028 - Epoch 44/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.6918 - loss: 0.9205 - val_accuracy: 0.6625 - val_loss: 1.4681 - Epoch 45/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.6543 - loss: 0.9118 - val_accuracy: 0.6000 - val_loss: 1.4737 - Epoch 46/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.6243 - loss: 1.0268 - val_accuracy: 0.5750 - val_loss: 1.5423 - Epoch 47/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.6391 - loss: 1.0181 - val_accuracy: 0.6625 - val_loss: 1.4783 - Epoch 48/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.6863 - loss: 0.9874 - val_accuracy: 0.7000 - val_loss: 1.3977 - Epoch 49/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7209 - loss: 0.8359 - val_accuracy: 0.6625 - val_loss: 1.3844 - Epoch 50/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7659 - loss: 0.8241 - val_accuracy: 0.6500 - val_loss: 1.4206 - Epoch 51/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7143 - loss: 0.8972 - val_accuracy: 0.6750 - val_loss: 1.3756 - Epoch 52/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 
5ms/step - accuracy: 0.7081 - loss: 0.9544 - val_accuracy: 0.6375 - val_loss: 1.3703 - Epoch 53/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.6907 - loss: 0.9446 - val_accuracy: 0.6750 - val_loss: 1.3564 - Epoch 54/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7460 - loss: 0.7399 - val_accuracy: 0.6000 - val_loss: 1.3840 - Epoch 55/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7293 - loss: 0.8620 - val_accuracy: 0.6000 - val_loss: 1.3743 - Epoch 56/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7504 - loss: 0.7715 - val_accuracy: 0.6875 - val_loss: 1.3175 - Epoch 57/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7643 - loss: 0.7617 - val_accuracy: 0.6625 - val_loss: 1.3407 - Epoch 58/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7568 - loss: 0.7798 - val_accuracy: 0.6875 - val_loss: 1.2950 - Epoch 59/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7863 - loss: 0.6884 - val_accuracy: 0.6625 - val_loss: 1.3306 - Epoch 60/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7550 - loss: 0.7504 - val_accuracy: 0.6500 - val_loss: 1.3260 - Epoch 61/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8069 - loss: 0.6624 - val_accuracy: 0.6375 - val_loss: 1.3168 - Epoch 62/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.7089 - loss: 0.8183 - val_accuracy: 0.7500 - val_loss: 1.2525 - Epoch 63/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7407 - loss: 0.7860 - val_accuracy: 0.7000 - val_loss: 1.2101 - Epoch 64/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7526 - loss: 0.7691 - val_accuracy: 0.7250 - val_loss: 1.2327 - Epoch 65/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7827 - loss: 0.7485 - val_accuracy: 0.6750 - val_loss: 1.2848 - Epoch 66/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7195 - loss: 0.7853 - val_accuracy: 0.7000 - val_loss: 1.2047 - Epoch 67/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7539 - loss: 0.7530 - val_accuracy: 0.7125 - val_loss: 1.1954 - Epoch 68/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7912 - loss: 0.6220 - val_accuracy: 0.6750 - val_loss: 1.2297 - Epoch 69/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7688 - loss: 0.6403 - val_accuracy: 0.6375 - val_loss: 1.2524 - Epoch 70/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.7699 - loss: 0.7181 - val_accuracy: 0.6625 - val_loss: 1.2147 - Epoch 71/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - accuracy: 0.8300 - loss: 0.5858 - val_accuracy: 0.7000 - val_loss: 1.1705 - Epoch 72/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7518 - loss: 0.6276 - val_accuracy: 0.7625 - val_loss: 1.1478 - Epoch 73/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8192 - loss: 0.5830 - val_accuracy: 0.6750 - val_loss: 1.1484 - Epoch 74/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8044 - loss: 0.6725 - val_accuracy: 0.7500 - val_loss: 1.1518 - Epoch 75/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7974 - loss: 0.5536 - val_accuracy: 0.6625 - val_loss: 1.2326 - Epoch 76/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.7249 - loss: 0.7748 - val_accuracy: 0.7500 - val_loss: 1.1622 - Epoch 77/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8083 - loss: 0.5952 - val_accuracy: 0.7125 - val_loss: 1.1240 - Epoch 78/200 - 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.8133 - loss: 0.5249 - val_accuracy: 0.7000 - val_loss: 1.1463 - Epoch 79/200 - 15/15 
━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8088 - loss: 0.5889 - val_accuracy: 0.7375 - val_loss: 1.0684
- Epoch 80/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.8715 - loss: 0.4484 - val_accuracy: 0.7500 - val_loss: 1.0295
- ...
- Epoch 199/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9861 - loss: 0.0663 - val_accuracy: 0.7750 - val_loss: 0.7760
- Epoch 200/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9982 - loss: 0.0558 - val_accuracy: 0.7750 - val_loss: 0.6585
-
-
-### Create the 2D model
-
-1. Create three spectrograms with multiple bandwidths from the raw input.
-2. Concatenate the three spectrograms so they form three channels.
-3. Load `MobileNet` with weights pre-trained on `ImageNet`.
-4. Apply global maximum pooling to obtain a fixed-length feature vector.
-5. Add `Dense` layers to make the final predictions based on the features.
-
-
-```python
-input = layers.Input((None, 1))
-spectrograms = [
-    layers.STFTSpectrogram(
-        mode="log",
-        frame_length=SAMPLE_RATE * frame_size // 1000,
-        frame_step=SAMPLE_RATE * 15 // 1000,
-        fft_length=2048,
-        padding="same",
-        expand_dims=True,
-        # trainable=True,  # trainable by default
-    )(input)
-    for frame_size in [30, 40, 50]  # frame size in milliseconds
-]
-
-multi_spectrograms = layers.Concatenate(axis=-1)(spectrograms)
-
-img_model = keras.applications.MobileNet(include_top=False, pooling="max")
-output = img_model(multi_spectrograms)
-
-output = layers.Dropout(0.5)(output)
-output = layers.Dense(256, activation="relu")(output)
-output = layers.Dense(256, activation="relu")(output)
-output = layers.Dense(NUM_CLASSES, activation="softmax")(output)
-model2d = keras.Model(input, output, name="model_2d_trainable_stft")
-
-model2d.compile(
-    optimizer=keras.optimizers.Adam(1e-4),
-    loss="sparse_categorical_crossentropy",
-    metrics=["accuracy"],
-)
-model2d.summary()
-```
-
-    :17: UserWarning: `input_shape` is undefined or non-square, or `rows` is not in [128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default.
-      img_model = keras.applications.MobileNet(include_top=False, pooling="max")
-
-Model: "model_2d_trainable_stft"
-
-┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)              ┃ Output Shape           ┃        Param # ┃ Connected to           ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ input_layer_1             │ (None, None, 1)        │              0 │ -                      │
-│ (InputLayer)              │                        │                │                        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ stft_spectrogram_5        │ (None, None, 1025, 1)  │        984,000 │ input_layer_1[0][0]    │
-│ (STFTSpectrogram)         │                        │                │                        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ stft_spectrogram_6        │ (None, None, 1025, 1)  │      1,312,000 │ input_layer_1[0][0]    │
-│ (STFTSpectrogram)         │                        │                │                        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ stft_spectrogram_7        │ (None, None, 1025, 1)  │      1,640,000 │ input_layer_1[0][0]    │
-│ (STFTSpectrogram)         │                        │                │                        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ concatenate (Concatenate) │ (None, None, 1025, 3)  │              0 │ stft_spectrogram_5[0]… │
-│                           │                        │                │ stft_spectrogram_6[0]… │
-│                           │                        │                │ stft_spectrogram_7[0]… │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ mobilenet_1.00_224        │ (None, 1024)           │      3,228,864 │ concatenate[0][0]      │
-│ (Functional)              │                        │                │                        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ dropout_2 (Dropout)       │ (None, 1024)           │              0 │ mobilenet_1.00_224[0]… │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ dense_3 (Dense)           │ (None, 256)            │        262,400 │ dropout_2[0][0]        │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ dense_4 (Dense)           │ (None, 256)            │         65,792 │ dense_3[0][0]          │
-├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
-│ dense_5 (Dense)           │ (None, 10)             │          2,570 │ dense_4[0][0]          │
-└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘
-
- Total params: 7,495,626 (28.59 MB)
-
- Trainable params: 7,473,738 (28.51 MB)
-
- Non-trainable params: 21,888 (85.50 KB)
-
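-One detail worth a second look in the summary above: the 21,888 non-trainable
-parameters. Every layer we added is trainable, so these should be MobileNet's
-`BatchNormalization` moving statistics. A minimal illustrative check (not part
-of the original example; it assumes the Keras 3 `keras.ops` namespace is
-available):
-
-
-```python
-# Sum the sizes of MobileNet's non-trainable BatchNormalization weights
-# (the moving means and variances) and compare against the summary.
-bn_non_trainable = sum(
-    int(keras.ops.size(w))
-    for layer in img_model.layers
-    if isinstance(layer, layers.BatchNormalization)
-    for w in layer.non_trainable_weights
-)
-print(bn_non_trainable)  # expected to match the 21,888 reported above
-```
-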
-
-Train the model and restore the best weights. Note that `patience` is set to
-the total number of epochs, so early stopping never actually cuts training
-short; the callback is used only to restore the weights from the epoch with
-the lowest validation loss.
-
-
-```python
-history_model2d = model2d.fit(
-    train_x,
-    train_y,
-    batch_size=BATCH_SIZE,
-    validation_data=(valid_x, valid_y),
-    epochs=EPOCHS,
-    callbacks=[
-        keras.callbacks.EarlyStopping(
-            monitor="val_loss",
-            patience=EPOCHS,
-            restore_best_weights=True,
-        )
-    ],
-)
-```
-
- Epoch 1/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 50s 776ms/step - accuracy: 0.0855 - loss: 7.6484 - val_accuracy: 0.0625 - val_loss: 3.7484
- Epoch 2/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 8s 55ms/step - accuracy: 0.1293 - loss: 5.8848 - val_accuracy: 0.0750 - val_loss: 4.0622
- ...
- Epoch 199/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 1s 46ms/step - accuracy: 1.0000 - loss: 0.0014 - val_accuracy: 0.9375 - val_loss: 0.2095
- Epoch 200/200
- 15/15 ━━━━━━━━━━━━━━━━━━━━ 1s 46ms/step - accuracy: 1.0000 - loss: 0.0042 - val_accuracy: 0.9375 - val_loss: 0.1972
-
-
-### Plot Training History
-
-
-```python
-epochs_range = range(EPOCHS)
-
-plt.figure(figsize=(14, 5))
-plt.subplot(1, 2, 1)
-plt.plot(
-    epochs_range,
-    history_model1d.history["accuracy"],
-    label="Training Accuracy, 1D model with non-trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model1d.history["val_accuracy"],
-    label="Validation Accuracy, 1D model with non-trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model2d.history["accuracy"],
-    label="Training Accuracy, 2D model with trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model2d.history["val_accuracy"],
-    label="Validation Accuracy, 2D model with trainable STFT",
-)
-plt.legend(loc="lower right")
-plt.title("Training and Validation Accuracy")
-
-plt.subplot(1, 2, 2)
-plt.plot(
-    epochs_range,
-    history_model1d.history["loss"],
-    label="Training Loss, 1D model with non-trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model1d.history["val_loss"],
-    label="Validation Loss, 1D model with non-trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model2d.history["loss"],
-    label="Training Loss, 2D model with trainable STFT",
-)
-plt.plot(
-    epochs_range,
-    history_model2d.history["val_loss"],
-    label="Validation Loss, 2D model with trainable STFT",
-)
-plt.legend(loc="upper right")
-plt.title("Training and Validation Loss")
-plt.show()
-```
-
-![png](https://github.com/keras-team/keras-io/blob/master/examples/audio/img/stft/training.png)
-
-
-### Evaluate on Test Data
-
-Running the models on the test set.
-
-
-```python
-_, test_acc = model1d.evaluate(test_x, test_y)
-print(f"1D model with non-trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%")
-```
-
- 3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 307ms/step - accuracy: 0.8148 - loss: 0.6244
- 1D model with non-trainable STFT -> Test Accuracy: 82.50%
-
-
-```python
-_, test_acc = model2d.evaluate(test_x, test_y)
-print(f"2D model with trainable STFT -> Test Accuracy: {test_acc * 100:.2f}%")
-```
-
- 3/3 ━━━━━━━━━━━━━━━━━━━━ 17s 546ms/step - accuracy: 0.9195 - loss: 0.5271
- 2D model with trainable STFT -> Test Accuracy: 92.50%
-
-
diff --git a/templates/examples/audio/transformer_asr.md b/templates/examples/audio/transformer_asr.md
deleted file mode 100644
index 12581582eb..0000000000
--- a/templates/examples/audio/transformer_asr.md
+++ /dev/null
@@ -1,618 +0,0 @@
-# Automatic Speech Recognition with Transformer
-
-**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>
-**Date created:** 2021/01/13<br>
-**Last modified:** 2021/01/13<br>
-**Description:** Training a sequence-to-sequence Transformer for automatic speech recognition.
-
-
-ⓘ This example uses Keras 3
-
-[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/transformer_asr.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/transformer_asr.py)
-
-
-
----
-## Introduction
-
-Automatic speech recognition (ASR) consists of transcribing audio speech segments into text.
-ASR can be treated as a sequence-to-sequence problem, where the
-audio can be represented as a sequence of feature vectors
-and the text as a sequence of characters, words, or subword tokens.
-
-For this demonstration, we will use the LJSpeech dataset from the
-[LibriVox](https://librivox.org/) project. It consists of short
-audio clips of a single speaker reading passages from 7 non-fiction books.
-Our model will be similar to the original Transformer (both encoder and decoder)
-as proposed in the paper, "Attention is All You Need".
-
-
-**References:**
-
-- [Attention is All You Need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
-- [Very Deep Self-Attention Networks for End-to-End Speech Recognition](https://arxiv.org/abs/1904.13377)
-- [Speech Transformers](https://ieeexplore.ieee.org/document/8462506)
-- [LJSpeech Dataset](https://keithito.com/LJ-Speech-Dataset/)
-
-
-```python
-
-import re
-import os
-
-os.environ["KERAS_BACKEND"] = "tensorflow"
-
-from glob import glob
-import tensorflow as tf
-import keras
-from keras import layers
-
-pattern_wav_name = re.compile(r'([^/\\\.]+)')
-
-```
-
----
-## Define the Transformer Input Layer
-
-When processing past target tokens for the decoder, we compute the sum of
-position embeddings and token embeddings.
-
-When processing audio features, we apply convolutional layers to downsample
-them (via convolution strides) and process local relationships.
-
-
-```python
-
-class TokenEmbedding(layers.Layer):
-    def __init__(self, num_vocab=1000, maxlen=100, num_hid=64):
-        super().__init__()
-        self.emb = keras.layers.Embedding(num_vocab, num_hid)
-        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=num_hid)
-
-    def call(self, x):
-        maxlen = tf.shape(x)[-1]
-        x = self.emb(x)
-        positions = tf.range(start=0, limit=maxlen, delta=1)
-        positions = self.pos_emb(positions)
-        return x + positions
-
-
-class SpeechFeatureEmbedding(layers.Layer):
-    def __init__(self, num_hid=64, maxlen=100):
-        super().__init__()
-        self.conv1 = keras.layers.Conv1D(
-            num_hid, 11, strides=2, padding="same", activation="relu"
-        )
-        self.conv2 = keras.layers.Conv1D(
-            num_hid, 11, strides=2, padding="same", activation="relu"
-        )
-        self.conv3 = keras.layers.Conv1D(
-            num_hid, 11, strides=2, padding="same", activation="relu"
-        )
-
-    def call(self, x):
-        x = self.conv1(x)
-        x = self.conv2(x)
-        return self.conv3(x)
-
-```
-
----
-## Transformer Encoder Layer
-
-
-```python
-
-class TransformerEncoder(layers.Layer):
-    def __init__(self, embed_dim, num_heads, feed_forward_dim, rate=0.1):
-        super().__init__()
-        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
-        self.ffn = keras.Sequential(
-            [
-                layers.Dense(feed_forward_dim, activation="relu"),
-                layers.Dense(embed_dim),
-            ]
-        )
-        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
-        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
-        self.dropout1 = layers.Dropout(rate)
-        self.dropout2 = layers.Dropout(rate)
-
-    def call(self, inputs, training=False):
-        attn_output = self.att(inputs, inputs)
-        attn_output = self.dropout1(attn_output, training=training)
-        out1 = self.layernorm1(inputs + attn_output)
-        ffn_output = self.ffn(out1)
-        ffn_output = self.dropout2(ffn_output, training=training)
-        return self.layernorm2(out1 + ffn_output)
-
-```
-
----
-## Transformer Decoder Layer
-
-
-```python
-
-class TransformerDecoder(layers.Layer):
-    def __init__(self, embed_dim, num_heads, feed_forward_dim, dropout_rate=0.1):
-        super().__init__()
-        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
-        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
-        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
-        self.self_att = layers.MultiHeadAttention(
-            num_heads=num_heads, key_dim=embed_dim
-        )
-        self.enc_att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
-        self.self_dropout = layers.Dropout(0.5)
-        self.enc_dropout = layers.Dropout(0.1)
-        self.ffn_dropout = layers.Dropout(0.1)
-        self.ffn = keras.Sequential(
-            [
-                layers.Dense(feed_forward_dim, activation="relu"),
-                layers.Dense(embed_dim),
-            ]
-        )
-
-    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
-        """Masks the upper half of the dot product matrix in self attention.
-
-        This prevents the flow of information from future tokens to the current token.
-        1's in the lower triangle, counting from the lower right corner.
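-
-        For example (illustrative), with n_dest = n_src = 3 the mask is
-
-            [[1, 0, 0],
-             [1, 1, 0],
-             [1, 1, 1]]
-
-        i.e. position i may only attend to positions j <= i.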
-        """
-        i = tf.range(n_dest)[:, None]
-        j = tf.range(n_src)
-        m = i >= j - n_src + n_dest
-        mask = tf.cast(m, dtype)
-        mask = tf.reshape(mask, [1, n_dest, n_src])
-        mult = tf.concat(
-            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
-        )
-        return tf.tile(mask, mult)
-
-    def call(self, enc_out, target):
-        input_shape = tf.shape(target)
-        batch_size = input_shape[0]
-        seq_len = input_shape[1]
-        causal_mask = self.causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
-        target_att = self.self_att(target, target, attention_mask=causal_mask)
-        target_norm = self.layernorm1(target + self.self_dropout(target_att))
-        enc_out = self.enc_att(target_norm, enc_out)
-        enc_out_norm = self.layernorm2(self.enc_dropout(enc_out) + target_norm)
-        ffn_out = self.ffn(enc_out_norm)
-        ffn_out_norm = self.layernorm3(enc_out_norm + self.ffn_dropout(ffn_out))
-        return ffn_out_norm
-
-```
-
----
-## Complete the Transformer model
-
-Our model takes audio spectrograms as inputs and predicts a sequence of characters.
-During training, we give the decoder the target character sequence shifted to the left
-as input. During inference, the decoder uses its own past predictions to predict the
-next token.
-
-
-```python
-
-class Transformer(keras.Model):
-    def __init__(
-        self,
-        num_hid=64,
-        num_head=2,
-        num_feed_forward=128,
-        source_maxlen=100,
-        target_maxlen=100,
-        num_layers_enc=4,
-        num_layers_dec=1,
-        num_classes=10,
-    ):
-        super().__init__()
-        self.loss_metric = keras.metrics.Mean(name="loss")
-        self.num_layers_enc = num_layers_enc
-        self.num_layers_dec = num_layers_dec
-        self.target_maxlen = target_maxlen
-        self.num_classes = num_classes
-
-        self.enc_input = SpeechFeatureEmbedding(num_hid=num_hid, maxlen=source_maxlen)
-        self.dec_input = TokenEmbedding(
-            num_vocab=num_classes, maxlen=target_maxlen, num_hid=num_hid
-        )
-
-        self.encoder = keras.Sequential(
-            [self.enc_input]
-            + [
-                TransformerEncoder(num_hid, num_head, num_feed_forward)
-                for _ in range(num_layers_enc)
-            ]
-        )
-
-        for i in range(num_layers_dec):
-            setattr(
-                self,
-                f"dec_layer_{i}",
-                TransformerDecoder(num_hid, num_head, num_feed_forward),
-            )
-
-        self.classifier = layers.Dense(num_classes)
-
-    def decode(self, enc_out, target):
-        y = self.dec_input(target)
-        for i in range(self.num_layers_dec):
-            y = getattr(self, f"dec_layer_{i}")(enc_out, y)
-        return y
-
-    def call(self, inputs):
-        source = inputs[0]
-        target = inputs[1]
-        x = self.encoder(source)
-        y = self.decode(x, target)
-        return self.classifier(y)
-
-    @property
-    def metrics(self):
-        return [self.loss_metric]
-
-    def train_step(self, batch):
-        """Processes one batch inside model.fit()."""
-        source = batch["source"]
-        target = batch["target"]
-        dec_input = target[:, :-1]
-        dec_target = target[:, 1:]
-        with tf.GradientTape() as tape:
-            preds = self([source, dec_input])
-            one_hot = tf.one_hot(dec_target, depth=self.num_classes)
-            mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
-            loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
-        trainable_vars = self.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-        self.loss_metric.update_state(loss)
-        return {"loss": self.loss_metric.result()}
-
-    def test_step(self, batch):
-        source = batch["source"]
-        target = batch["target"]
-        dec_input = target[:, :-1]
-        dec_target = target[:, 1:]
-        preds = self([source, dec_input])
-        one_hot = tf.one_hot(dec_target, depth=self.num_classes)
-        mask = tf.math.logical_not(tf.math.equal(dec_target, 0))
-        loss = self.compute_loss(None, one_hot, preds, sample_weight=mask)
-        self.loss_metric.update_state(loss)
-        return {"loss": self.loss_metric.result()}
-
-    def generate(self, source, target_start_token_idx):
-        """Performs inference over one batch of inputs using greedy decoding."""
-        bs = tf.shape(source)[0]
-        enc = self.encoder(source)
-        dec_input = tf.ones((bs, 1), dtype=tf.int32) * target_start_token_idx
-        dec_logits = []
-        for i in range(self.target_maxlen - 1):
-            dec_out = self.decode(enc, dec_input)
-            logits = self.classifier(dec_out)
-            logits = tf.argmax(logits, axis=-1, output_type=tf.int32)
-            last_logit = tf.expand_dims(logits[:, -1], axis=-1)
-            dec_logits.append(last_logit)
-            dec_input = tf.concat([dec_input, last_logit], axis=-1)
-        return dec_input
-
-```
-
----
-## Download the dataset
-
-Note: This requires ~3.6 GB of disk space and
-takes ~5 minutes for the extraction of files.
-
-
-```python
-keras.utils.get_file(
-    os.path.join(os.getcwd(), "data.tar.gz"),
-    "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
-    extract=True,
-    archive_format="tar",
-    cache_dir=".",
-)
-
-
-saveto = "./datasets/LJSpeech-1.1"
-wavs = glob("{}/**/*.wav".format(saveto), recursive=True)
-
-id_to_text = {}
-with open(os.path.join(saveto, "metadata.csv"), encoding="utf-8") as f:
-    for line in f:
-        id = line.strip().split("|")[0]
-        text = line.strip().split("|")[2]
-        id_to_text[id] = text
-
-
-def get_data(wavs, id_to_text, maxlen=50):
-    """returns mapping of audio paths and transcription texts"""
-    data = []
-    for w in wavs:
-        id = pattern_wav_name.split(w)[-4]
-        if len(id_to_text[id]) < maxlen:
-            data.append({"audio": w, "text": id_to_text[id]})
-    return data
-
-```
-
-``` -Downloading data from https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 - 2748572632/2748572632 ━━━━━━━━━━━━━━━━━━━━ 18s 0us/step - -``` -
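-
-As a quick sanity check (a sketch, not part of the original example), we can confirm
-how many wav files were found and look at one transcript:
-
-
-```python
-print(len(wavs))  # expected: 13100 wav files
-sample_id = os.path.basename(wavs[0]).split(".")[0]
-print(sample_id, "->", id_to_text[sample_id])
-```
-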
----
-## Preprocess the dataset
-
-
-```python
-
-class VectorizeChar:
-    def __init__(self, max_len=50):
-        self.vocab = (
-            ["-", "#", "<", ">"]
-            + [chr(i + 96) for i in range(1, 27)]
-            + [" ", ".", ",", "?"]
-        )
-        self.max_len = max_len
-        self.char_to_idx = {}
-        for i, ch in enumerate(self.vocab):
-            self.char_to_idx[ch] = i
-
-    def __call__(self, text):
-        text = text.lower()
-        text = text[: self.max_len - 2]
-        text = "<" + text + ">"
-        pad_len = self.max_len - len(text)
-        return [self.char_to_idx.get(ch, 1) for ch in text] + [0] * pad_len
-
-    def get_vocabulary(self):
-        return self.vocab
-
-
-max_target_len = 200  # all transcripts in our data are < 200 characters
-data = get_data(wavs, id_to_text, max_target_len)
-vectorizer = VectorizeChar(max_target_len)
-print("vocab size", len(vectorizer.get_vocabulary()))
-
-
-def create_text_ds(data):
-    texts = [_["text"] for _ in data]
-    text_ds = [vectorizer(t) for t in texts]
-    text_ds = tf.data.Dataset.from_tensor_slices(text_ds)
-    return text_ds
-
-
-def path_to_audio(path):
-    # spectrogram using stft
-    audio = tf.io.read_file(path)
-    audio, _ = tf.audio.decode_wav(audio, 1)
-    audio = tf.squeeze(audio, axis=-1)
-    stfts = tf.signal.stft(audio, frame_length=200, frame_step=80, fft_length=256)
-    x = tf.math.pow(tf.abs(stfts), 0.5)
-    # normalisation
-    means = tf.math.reduce_mean(x, 1, keepdims=True)
-    stddevs = tf.math.reduce_std(x, 1, keepdims=True)
-    x = (x - means) / stddevs
-    audio_len = tf.shape(x)[0]
-    # padding to 10 seconds
-    pad_len = 2754
-    paddings = tf.constant([[0, pad_len], [0, 0]])
-    x = tf.pad(x, paddings, "CONSTANT")[:pad_len, :]
-    return x
-
-
-def create_audio_ds(data):
-    flist = [_["audio"] for _ in data]
-    audio_ds = tf.data.Dataset.from_tensor_slices(flist)
-    audio_ds = audio_ds.map(path_to_audio, num_parallel_calls=tf.data.AUTOTUNE)
-    return audio_ds
-
-
-def create_tf_dataset(data, bs=4):
-    audio_ds = create_audio_ds(data)
-    text_ds = create_text_ds(data)
-    ds = tf.data.Dataset.zip((audio_ds, text_ds))
-    ds = ds.map(lambda x, y: {"source": x, "target": y})
-    ds = ds.batch(bs)
-    ds = ds.prefetch(tf.data.AUTOTUNE)
-    return ds
-
-
-split = int(len(data) * 0.99)
-train_data = data[:split]
-test_data = data[split:]
-ds = create_tf_dataset(train_data, bs=64)
-val_ds = create_tf_dataset(test_data, bs=4)
-```
-
-<div class="k-default-codeblock">
-``` -vocab size 34 - -``` -
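-
-As a quick shape check (a sketch, not part of the original example), each source batch
-should contain spectrograms with `fft_length // 2 + 1 = 129` frequency bins padded to
-2754 frames, and each target batch should contain token sequences padded to
-`max_target_len = 200`:
-
-
-```python
-for batch in ds.take(1):
-    print(batch["source"].shape)  # (64, 2754, 129)
-    print(batch["target"].shape)  # (64, 200)
-```
-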
----
-## Callbacks to display predictions
-
-
-```python
-
-class DisplayOutputs(keras.callbacks.Callback):
-    def __init__(
-        self, batch, idx_to_token, target_start_token_idx=27, target_end_token_idx=28
-    ):
-        """Displays a batch of outputs every 5 epochs
-
-        Args:
-            batch: A test batch containing the keys "source" and "target"
-            idx_to_token: A List containing the vocabulary tokens corresponding to their indices
-            target_start_token_idx: A start token index in the target vocabulary
-            target_end_token_idx: An end token index in the target vocabulary
-        """
-        self.batch = batch
-        self.target_start_token_idx = target_start_token_idx
-        self.target_end_token_idx = target_end_token_idx
-        self.idx_to_char = idx_to_token
-
-    def on_epoch_end(self, epoch, logs=None):
-        if epoch % 5 != 0:
-            return
-        source = self.batch["source"]
-        target = self.batch["target"].numpy()
-        bs = tf.shape(source)[0]
-        preds = self.model.generate(source, self.target_start_token_idx)
-        preds = preds.numpy()
-        for i in range(bs):
-            target_text = "".join([self.idx_to_char[_] for _ in target[i, :]])
-            prediction = ""
-            for idx in preds[i, :]:
-                prediction += self.idx_to_char[idx]
-                if idx == self.target_end_token_idx:
-                    break
-            print(f"target:     {target_text.replace('-','')}")
-            print(f"prediction: {prediction}\n")
-
-```
-
----
-## Learning rate schedule
-
-
-```python
-
-class CustomSchedule(keras.optimizers.schedules.LearningRateSchedule):
-    def __init__(
-        self,
-        init_lr=0.00001,
-        lr_after_warmup=0.001,
-        final_lr=0.00001,
-        warmup_epochs=15,
-        decay_epochs=85,
-        steps_per_epoch=203,
-    ):
-        super().__init__()
-        self.init_lr = init_lr
-        self.lr_after_warmup = lr_after_warmup
-        self.final_lr = final_lr
-        self.warmup_epochs = warmup_epochs
-        self.decay_epochs = decay_epochs
-        self.steps_per_epoch = steps_per_epoch
-
-    def calculate_lr(self, epoch):
-        """linear warm up - linear decay"""
-        warmup_lr = (
-            self.init_lr
-            + ((self.lr_after_warmup - self.init_lr) / (self.warmup_epochs - 1)) * epoch
-        )
-        decay_lr = tf.math.maximum(
-            self.final_lr,
-            self.lr_after_warmup
-            - (epoch - self.warmup_epochs)
-            * (self.lr_after_warmup - self.final_lr)
-            / self.decay_epochs,
-        )
-        return tf.math.minimum(warmup_lr, decay_lr)
-
-    def __call__(self, step):
-        epoch = step // self.steps_per_epoch
-        epoch = tf.cast(epoch, "float32")
-        return self.calculate_lr(epoch)
-
-```
-
----
-## Create & train the end-to-end model
-
-
-```python
-batch = next(iter(val_ds))
-
-# The vocabulary to convert predicted indices into characters
-idx_to_char = vectorizer.get_vocabulary()
-display_cb = DisplayOutputs(
-    batch, idx_to_char, target_start_token_idx=2, target_end_token_idx=3
-)  # set the arguments as per vocabulary index for '<' and '>'
-
-model = Transformer(
-    num_hid=200,
-    num_head=2,
-    num_feed_forward=400,
-    target_maxlen=max_target_len,
-    num_layers_enc=4,
-    num_layers_dec=1,
-    num_classes=34,
-)
-loss_fn = keras.losses.CategoricalCrossentropy(
-    from_logits=True,
-    label_smoothing=0.1,
-)
-
-learning_rate = CustomSchedule(
-    init_lr=0.00001,
-    lr_after_warmup=0.001,
-    final_lr=0.00001,
-    warmup_epochs=15,
-    decay_epochs=85,
-    steps_per_epoch=len(ds),
-)
-optimizer = keras.optimizers.Adam(learning_rate)
-model.compile(optimizer=optimizer, loss=loss_fn)
-
-history = model.fit(ds, validation_data=val_ds, callbacks=[display_cb], epochs=1)
-```
-
-<div class="k-default-codeblock">
-``` - 1/203 ━━━━━━━━━━━━━━━━━━━━ 9:20:11 166s/step - loss: 2.2387 - -WARNING: All log messages before absl::InitializeLog() is called are written to STDERR -I0000 00:00:1700071380.331418 678094 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. - - 203/203 ━━━━━━━━━━━━━━━━━━━━ 0s 947ms/step - loss: 1.8285target: -prediction: - -
-``` -target: -prediction: - -
-``` -target: -prediction: - -
-``` -target: -prediction: - -
-``` - 203/203 ━━━━━━━━━━━━━━━━━━━━ 428s 1s/step - loss: 1.8276 - val_loss: 1.5233 - -``` -
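-
-If you want to see the warmup/decay behavior of the schedule defined above, you can
-plot the learning rate against the training step (a sketch, not part of the original
-example; it assumes the `learning_rate` schedule and `ds` defined earlier):
-
-
-```python
-import matplotlib.pyplot as plt
-import numpy as np
-
-steps = np.arange(100 * len(ds))  # 100 epochs' worth of steps
-lrs = [float(learning_rate(step)) for step in steps]
-plt.plot(steps, lrs)
-plt.xlabel("Train step")
-plt.ylabel("Learning rate")
-plt.show()
-```
-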
-In practice, you should train for around 100 epochs or more. - -Some of the predicted text at or around epoch 35 may look as follows: -``` -target: -prediction: - -target: -prediction: -``` - diff --git a/templates/examples/audio/uk_ireland_accent_recognition.md b/templates/examples/audio/uk_ireland_accent_recognition.md deleted file mode 100644 index 4c1b889250..0000000000 --- a/templates/examples/audio/uk_ireland_accent_recognition.md +++ /dev/null @@ -1,1203 +0,0 @@ -# English speaker accent recognition using Transfer Learning - -**Author:** [Fadi Badine](https://twitter.com/fadibadine)
-**Date created:** 2022/04/16
-**Last modified:** 2022/04/16
-**Description:** Training a model to classify UK & Ireland accents using feature extraction from Yamnet. - - -
ⓘ This example uses Keras 2
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/uk_ireland_accent_recognition.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/uk_ireland_accent_recognition.py)
-
-
-
----
-## Introduction
-
-The following example shows how to use feature extraction in order to
-train a model to classify the English accent spoken in an audio wave.
-
-Instead of training a model from scratch, transfer learning enables us to
-take advantage of existing state-of-the-art deep learning models and use them as feature extractors.
-
-Our process:
-
-* Use a TF Hub pre-trained model (Yamnet) and apply it as part of the `tf.data` pipeline, which transforms
-the audio files into feature vectors.
-* Train a dense model on the feature vectors.
-* Use the trained model for inference on a new audio file.
-
-Note:
-
-* We need to install TensorFlow IO in order to resample audio files to 16 kHz, as required by the Yamnet model.
-* In the test section, ffmpeg is used to convert the mp3 file to wav.
-
-You can install TensorFlow IO with the following command:
-
-
-```python
-!pip install -U -q tensorflow_io
-```
-
----
-## Configuration
-
-
-```python
-SEED = 1337
-EPOCHS = 100
-BATCH_SIZE = 64
-VALIDATION_RATIO = 0.1
-MODEL_NAME = "uk_irish_accent_recognition"
-
-# Location where the dataset will be downloaded.
-# By default (None), keras.utils.get_file will use ~/.keras/ as the CACHE_DIR
-CACHE_DIR = None
-
-# The location of the dataset
-URL_PATH = "https://www.openslr.org/resources/83/"
-
-# List of datasets compressed files that contain the audio files
-zip_files = {
-    0: "irish_english_male.zip",
-    1: "midlands_english_female.zip",
-    2: "midlands_english_male.zip",
-    3: "northern_english_female.zip",
-    4: "northern_english_male.zip",
-    5: "scottish_english_female.zip",
-    6: "scottish_english_male.zip",
-    7: "southern_english_female.zip",
-    8: "southern_english_male.zip",
-    9: "welsh_english_female.zip",
-    10: "welsh_english_male.zip",
-}
-
-# We see that there are 2 compressed files for each accent (except Irish):
-# - One for male speakers
-# - One for female speakers
-# However, we will be using a gender agnostic dataset.
-
-# List of gender agnostic categories
-gender_agnostic_categories = [
-    "ir",  # Irish
-    "mi",  # Midlands
-    "no",  # Northern
-    "sc",  # Scottish
-    "so",  # Southern
-    "we",  # Welsh
-]
-
-class_names = [
-    "Irish",
-    "Midlands",
-    "Northern",
-    "Scottish",
-    "Southern",
-    "Welsh",
-    "Not a speech",
-]
-```
-
----
-## Imports
-
-
-```python
-import os
-import io
-import csv
-import numpy as np
-import pandas as pd
-import tensorflow as tf
-import tensorflow_hub as hub
-import tensorflow_io as tfio
-from tensorflow import keras
-import matplotlib.pyplot as plt
-import seaborn as sns
-from scipy import stats
-from IPython.display import Audio
-
-
-# Set all random seeds in order to get reproducible results
-keras.utils.set_random_seed(SEED)
-
-# Where to download the dataset
-DATASET_DESTINATION = os.path.join(CACHE_DIR if CACHE_DIR else "~/.keras/", "datasets")
-```
-
----
-## Yamnet Model
-
-Yamnet is an audio event classifier trained on the AudioSet dataset to predict audio
-events from the AudioSet ontology. It is available on TensorFlow Hub.
-
-Yamnet accepts a 1-D tensor of audio samples with a sample rate of 16 kHz.
-As output, the model returns a 3-tuple:
-
-* Scores of shape `(N, 521)` representing the scores of the 521 classes.
-* Embeddings of shape `(N, 1024)`.
-* The log-mel spectrogram of the entire audio frame. - -We will use the embeddings, which are the features extracted from the audio samples, as the input to our dense model. - -For more detailed information about Yamnet, please refer to its [TensorFlow Hub](https://tfhub.dev/google/yamnet/1) page. - - -```python -yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1") -``` - ---- -## Dataset - -The dataset used is the -[Crowdsourced high-quality UK and Ireland English Dialect speech data set](https://openslr.org/83/) -which consists of a total of 17,877 high-quality audio wav files. - -This dataset includes over 31 hours of recording from 120 volunteers who self-identify as -native speakers of Southern England, Midlands, Northern England, Wales, Scotland and Ireland. - -For more info, please refer to the above link or to the following paper: -[Open-source Multi-speaker Corpora of the English Accents in the British Isles](https://aclanthology.org/2020.lrec-1.804.pdf) - ---- -## Download the data - - -```python -# CSV file that contains information about the dataset. For each entry, we have: -# - ID -# - wav file name -# - transcript -line_index_file = keras.utils.get_file( - fname="line_index_file", origin=URL_PATH + "line_index_all.csv" -) - -# Download the list of compressed files that contain the audio wav files -for i in zip_files: - fname = zip_files[i].split(".")[0] - url = URL_PATH + zip_files[i] - - zip_file = keras.utils.get_file(fname=fname, origin=url, extract=True) - os.remove(zip_file) -``` - -
-``` -Downloading data from https://www.openslr.org/resources/83/line_index_all.csv -1990656/1986139 [==============================] - 1s 0us/step -1998848/1986139 [==============================] - 1s 0us/step -Downloading data from https://www.openslr.org/resources/83/irish_english_male.zip -164536320/164531638 [==============================] - 9s 0us/step -164544512/164531638 [==============================] - 9s 0us/step -Downloading data from https://www.openslr.org/resources/83/midlands_english_female.zip -103088128/103085118 [==============================] - 6s 0us/step -103096320/103085118 [==============================] - 6s 0us/step -Downloading data from https://www.openslr.org/resources/83/midlands_english_male.zip -166838272/166833961 [==============================] - 9s 0us/step -166846464/166833961 [==============================] - 9s 0us/step -Downloading data from https://www.openslr.org/resources/83/northern_english_female.zip -314990592/314983063 [==============================] - 15s 0us/step -314998784/314983063 [==============================] - 15s 0us/step -Downloading data from https://www.openslr.org/resources/83/northern_english_male.zip -817774592/817772034 [==============================] - 39s 0us/step -817782784/817772034 [==============================] - 39s 0us/step -Downloading data from https://www.openslr.org/resources/83/scottish_english_female.zip -351444992/351443880 [==============================] - 17s 0us/step -351453184/351443880 [==============================] - 17s 0us/step -Downloading data from https://www.openslr.org/resources/83/scottish_english_male.zip -620257280/620254118 [==============================] - 30s 0us/step -620265472/620254118 [==============================] - 30s 0us/step -Downloading data from https://www.openslr.org/resources/83/southern_english_female.zip -1636704256/1636701939 [==============================] - 77s 0us/step -1636712448/1636701939 [==============================] - 77s 0us/step -Downloading data from https://www.openslr.org/resources/83/southern_english_male.zip -1700962304/1700955740 [==============================] - 79s 0us/step -1700970496/1700955740 [==============================] - 79s 0us/step -Downloading data from https://www.openslr.org/resources/83/welsh_english_female.zip -595689472/595683538 [==============================] - 29s 0us/step -595697664/595683538 [==============================] - 29s 0us/step -Downloading data from https://www.openslr.org/resources/83/welsh_english_male.zip -757653504/757645790 [==============================] - 37s 0us/step -757661696/757645790 [==============================] - 37s 0us/step - -``` -
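-
-Before building the input pipeline, it is worth sanity-checking the Yamnet model we
-loaded earlier (a sketch, not part of the original example). Calling it on a few
-seconds of silence shows the three outputs described above:
-
-
-```python
-waveform = tf.zeros(3 * 16000, dtype=tf.float32)  # 3 seconds of silence at 16 kHz
-scores, embeddings, mel_spectrogram = yamnet_model(waveform)
-# One row per ~0.48 s hop: scores are (N, 521), embeddings are (N, 1024)
-print(scores.shape, embeddings.shape, mel_spectrogram.shape)
-```
-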
---- -## Load the data in a Dataframe - -Of the 3 columns (ID, filename and transcript), we are only interested in the filename column in order to read the audio file. -We will ignore the other two. - - -```python -dataframe = pd.read_csv( - line_index_file, names=["id", "filename", "transcript"], usecols=["filename"] -) -dataframe.head() -``` - - - - -
-
-|   | filename              |
-|---|-----------------------|
-| 0 | wef_12484_01482829612 |
-| 1 | wef_12484_01345932698 |
-| 2 | wef_12484_00999757777 |
-| 3 | wef_12484_00036278823 |
-| 4 | wef_12484_00458512623 |
-
-
-Let's now preprocess the dataset by:
-
-* Adjusting the filename (removing a leading space & adding the ".wav" extension to the
-filename).
-* Creating a label using the first 2 characters of the filename, which indicate the
-accent.
-* Shuffling the samples.
-
-
-```python
-# The purpose of this function is to preprocess the dataframe by applying the following:
-# - Cleaning the filename from a leading space
-# - Generating a label column that is gender agnostic i.e.
-#   welsh english male and welsh english female for example are both labeled as
-#   welsh english
-# - Adding the extension .wav to the filename
-# - Shuffling the samples
-def preprocess_dataframe(dataframe):
-    # Remove leading space in filename column
-    dataframe["filename"] = dataframe.apply(lambda row: row["filename"].strip(), axis=1)
-
-    # Create gender agnostic labels based on the filename first 2 letters
-    dataframe["label"] = dataframe.apply(
-        lambda row: gender_agnostic_categories.index(row["filename"][:2]), axis=1
-    )
-
-    # Add the file path to the name
-    dataframe["filename"] = dataframe.apply(
-        lambda row: os.path.join(DATASET_DESTINATION, row["filename"] + ".wav"), axis=1
-    )
-
-    # Shuffle the samples
-    dataframe = dataframe.sample(frac=1, random_state=SEED).reset_index(drop=True)
-
-    return dataframe
-
-
-dataframe = preprocess_dataframe(dataframe)
-dataframe.head()
-```
-
-
-|   | filename                                        | label |
-|---|-------------------------------------------------|-------|
-| 0 | /root/.keras/datasets/som_03853_01027933689.wav | 4     |
-| 1 | /root/.keras/datasets/som_04310_01833253760.wav | 4     |
-| 2 | /root/.keras/datasets/sof_06136_01210700905.wav | 4     |
-| 3 | /root/.keras/datasets/som_02484_00261230384.wav | 4     |
-| 4 | /root/.keras/datasets/nom_06136_00616878975.wav | 2     |
-
-
----
-## Prepare training & validation sets
-
-Let's split the samples to create the training and validation sets.
-
-
-```python
-split = int(len(dataframe) * (1 - VALIDATION_RATIO))
-train_df = dataframe[:split]
-valid_df = dataframe[split:]
-
-print(
-    f"We have {train_df.shape[0]} training samples & {valid_df.shape[0]} validation ones"
-)
-```
-
-<div class="k-default-codeblock">
-``` -We have 16089 training samples & 1788 validation ones - -``` -
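-
-A quick look at the label balance of the training split (a sketch, not part of the
-original example) shows why we will need class weights later on:
-
-
-```python
-print(train_df["label"].value_counts())
-```
-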
----
-## Prepare a TensorFlow Dataset
-
-Next, we need to create a `tf.data.Dataset`.
-This is done by creating a `dataframe_to_dataset` function that does the following:
-
-* Create a dataset using filenames and labels.
-* Get the Yamnet embeddings by calling another function `filepath_to_embeddings`.
-* Apply caching, batching and prefetching.
-
-The `filepath_to_embeddings` function does the following:
-
-* Load the audio file.
-* Resample the audio to 16 kHz.
-* Generate scores and embeddings from the Yamnet model.
-* Since Yamnet generates multiple samples for each audio file, duplicate the label
-across all generated samples whose top Yamnet score is speech (class 0), and set the
-label of the remaining samples to the extra 'Not a speech' category, indicating that
-those audio segments are not speech and should not be labeled as one of the accents.
-
-The `load_16k_audio_wav` function below is copied from the following tutorial
-[Transfer learning with YAMNet for environmental sound classification](https://www.tensorflow.org/tutorials/audio/transfer_learning_audio)
-
-
-```python
-
-@tf.function
-def load_16k_audio_wav(filename):
-    # Read file content
-    file_content = tf.io.read_file(filename)
-
-    # Decode audio wave
-    audio_wav, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)
-    audio_wav = tf.squeeze(audio_wav, axis=-1)
-    sample_rate = tf.cast(sample_rate, dtype=tf.int64)
-
-    # Resample to 16k
-    audio_wav = tfio.audio.resample(audio_wav, rate_in=sample_rate, rate_out=16000)
-
-    return audio_wav
-
-
-def filepath_to_embeddings(filename, label):
-    # Load 16k audio wave
-    audio_wav = load_16k_audio_wav(filename)
-
-    # Get audio embeddings & scores.
-    # The embeddings are the audio features extracted using transfer learning
-    # while scores will be used to identify time slots that are not speech
-    # which will then be gathered into the extra 'Not a speech' category
-    scores, embeddings, _ = yamnet_model(audio_wav)
-
-    # Number of embeddings in order to know how many times to repeat the label
-    embeddings_num = tf.shape(embeddings)[0]
-    labels = tf.repeat(label, embeddings_num)
-
-    # Relabel time slots that are not speech as the extra 'Not a speech' category
-    labels = tf.where(tf.argmax(scores, axis=1) == 0, label, len(class_names) - 1)
-
-    # Using one-hot in order to use AUC
-    return (embeddings, tf.one_hot(labels, len(class_names)))
-
-
-def dataframe_to_dataset(dataframe, batch_size=64):
-    dataset = tf.data.Dataset.from_tensor_slices(
-        (dataframe["filename"], dataframe["label"])
-    )
-
-    dataset = dataset.map(
-        lambda x, y: filepath_to_embeddings(x, y),
-        num_parallel_calls=tf.data.experimental.AUTOTUNE,
-    ).unbatch()
-
-    return dataset.cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)
-
-
-train_ds = dataframe_to_dataset(train_df)
-valid_ds = dataframe_to_dataset(valid_df)
-```
-
----
-## Build the model
-
-The model that we use consists of:
-
-* An input layer which is the embedding output of the Yamnet classifier.
-* 4 dense hidden layers and 4 dropout layers.
-* An output dense layer.
-
-The model's hyperparameters were selected using
-[KerasTuner](https://keras.io/keras_tuner/).
-
-
-```python
-keras.backend.clear_session()
-
-
-def build_and_compile_model():
-    inputs = keras.layers.Input(shape=(1024,), name="embedding")
-
-    x = keras.layers.Dense(256, activation="relu", name="dense_1")(inputs)
-    x = keras.layers.Dropout(0.15, name="dropout_1")(x)
-
-    x = keras.layers.Dense(384, activation="relu", name="dense_2")(x)
-    x = keras.layers.Dropout(0.2, name="dropout_2")(x)
-
-    x = keras.layers.Dense(192, activation="relu", name="dense_3")(x)
-    x = keras.layers.Dropout(0.25, name="dropout_3")(x)
-
-    x = keras.layers.Dense(384, activation="relu", name="dense_4")(x)
-    x = keras.layers.Dropout(0.2, name="dropout_4")(x)
-
-    outputs = keras.layers.Dense(len(class_names), activation="softmax", name="ouput")(
-        x
-    )
-
-    model = keras.Model(inputs=inputs, outputs=outputs, name="accent_recognition")
-
-    model.compile(
-        optimizer=keras.optimizers.Adam(learning_rate=1.9644e-5),
-        loss=keras.losses.CategoricalCrossentropy(),
-        metrics=["accuracy", keras.metrics.AUC(name="auc")],
-    )
-
-    return model
-
-
-model = build_and_compile_model()
-model.summary()
-```
-
-<div class="k-default-codeblock">
-``` -Model: "accent_recognition" -_________________________________________________________________ - Layer (type) Output Shape Param # -================================================================= - embedding (InputLayer) [(None, 1024)] 0 - - dense_1 (Dense) (None, 256) 262400 - - dropout_1 (Dropout) (None, 256) 0 - - dense_2 (Dense) (None, 384) 98688 - - dropout_2 (Dropout) (None, 384) 0 - - dense_3 (Dense) (None, 192) 73920 - - dropout_3 (Dropout) (None, 192) 0 - - dense_4 (Dense) (None, 384) 74112 - - dropout_4 (Dropout) (None, 384) 0 - - ouput (Dense) (None, 7) 2695 - -================================================================= -Total params: 511,815 -Trainable params: 511,815 -Non-trainable params: 0 -_________________________________________________________________ - -``` -
-
----
-## Class weights calculation
-
-Since the dataset is quite unbalanced, we will use the `class_weight` argument during training.
-
-Getting the class weights is a little tricky because even though we know the number of
-audio files for each class, it does not represent the number of samples for that class,
-since Yamnet transforms each audio file into multiple audio samples of 0.96 seconds each.
-Every audio file will therefore be split into a number of samples that is proportional to its length.
-
-Therefore, to get those weights, we have to calculate the number of samples for each class
-after preprocessing through Yamnet.
-
-
-```python
-class_counts = tf.zeros(shape=(len(class_names),), dtype=tf.int32)
-
-for x, y in iter(train_ds):
-    class_counts = class_counts + tf.math.bincount(
-        tf.cast(tf.math.argmax(y, axis=1), tf.int32), minlength=len(class_names)
-    )
-
-class_weight = {
-    i: tf.math.reduce_sum(class_counts).numpy() / class_counts[i].numpy()
-    for i in range(len(class_counts))
-}
-
-print(class_weight)
-```
-
-<div class="k-default-codeblock">
-``` -{0: 50.430241233524, 1: 30.668481548699333, 2: 7.322956917409988, 3: 8.125175301518611, 4: 2.4034894333226657, 5: 6.4197296356095865, 6: 8.613175890922992} - -``` -
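-
-To read these weights more easily (a sketch, not part of the original example), we can
-print the per-class sample counts next to the class names:
-
-
-```python
-for name, count in zip(class_names, class_counts.numpy()):
-    print(f"{name:15} {count}")
-```
-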
-
----
-## Callbacks
-
-We use Keras callbacks in order to:
-
-* Stop whenever the validation AUC stops improving.
-* Save the best model.
-* Log to TensorBoard so that we can later view the training and validation curves.
-
-
-```python
-early_stopping_cb = keras.callbacks.EarlyStopping(
-    monitor="val_auc", patience=10, restore_best_weights=True
-)
-
-model_checkpoint_cb = keras.callbacks.ModelCheckpoint(
-    MODEL_NAME + ".h5", monitor="val_auc", save_best_only=True
-)
-
-tensorboard_cb = keras.callbacks.TensorBoard(
-    os.path.join(os.curdir, "logs", model.name)
-)
-
-callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]
-```
-
----
-## Training
-
-
-```python
-history = model.fit(
-    train_ds,
-    epochs=EPOCHS,
-    validation_data=valid_ds,
-    class_weight=class_weight,
-    callbacks=callbacks,
-    verbose=2,
-)
-```
-
-<div class="k-default-codeblock">
-``` -Epoch 1/100 -3169/3169 - 131s - loss: 10.6145 - accuracy: 0.3426 - auc: 0.7585 - val_loss: 1.3781 - val_accuracy: 0.4084 - val_auc: 0.8118 - 131s/epoch - 41ms/step -Epoch 2/100 -3169/3169 - 12s - loss: 9.3787 - accuracy: 0.3957 - auc: 0.8055 - val_loss: 1.3291 - val_accuracy: 0.4470 - val_auc: 0.8294 - 12s/epoch - 4ms/step -Epoch 3/100 -3169/3169 - 13s - loss: 8.9948 - accuracy: 0.4216 - auc: 0.8212 - val_loss: 1.3144 - val_accuracy: 0.4497 - val_auc: 0.8340 - 13s/epoch - 4ms/step -Epoch 4/100 -3169/3169 - 13s - loss: 8.7682 - accuracy: 0.4327 - auc: 0.8291 - val_loss: 1.3052 - val_accuracy: 0.4515 - val_auc: 0.8368 - 13s/epoch - 4ms/step -Epoch 5/100 -3169/3169 - 12s - loss: 8.6352 - accuracy: 0.4375 - auc: 0.8328 - val_loss: 1.2993 - val_accuracy: 0.4482 - val_auc: 0.8377 - 12s/epoch - 4ms/step -Epoch 6/100 -3169/3169 - 12s - loss: 8.5149 - accuracy: 0.4421 - auc: 0.8367 - val_loss: 1.2930 - val_accuracy: 0.4462 - val_auc: 0.8398 - 12s/epoch - 4ms/step -Epoch 7/100 -3169/3169 - 12s - loss: 8.4321 - accuracy: 0.4438 - auc: 0.8393 - val_loss: 1.2881 - val_accuracy: 0.4460 - val_auc: 0.8412 - 12s/epoch - 4ms/step -Epoch 8/100 -3169/3169 - 12s - loss: 8.3385 - accuracy: 0.4459 - auc: 0.8413 - val_loss: 1.2730 - val_accuracy: 0.4503 - val_auc: 0.8450 - 12s/epoch - 4ms/step -Epoch 9/100 -3169/3169 - 12s - loss: 8.2704 - accuracy: 0.4478 - auc: 0.8434 - val_loss: 1.2718 - val_accuracy: 0.4486 - val_auc: 0.8451 - 12s/epoch - 4ms/step -Epoch 10/100 -3169/3169 - 12s - loss: 8.2023 - accuracy: 0.4489 - auc: 0.8455 - val_loss: 1.2714 - val_accuracy: 0.4450 - val_auc: 0.8450 - 12s/epoch - 4ms/step -Epoch 11/100 -3169/3169 - 12s - loss: 8.1402 - accuracy: 0.4504 - auc: 0.8474 - val_loss: 1.2616 - val_accuracy: 0.4496 - val_auc: 0.8479 - 12s/epoch - 4ms/step -Epoch 12/100 -3169/3169 - 12s - loss: 8.0935 - accuracy: 0.4521 - auc: 0.8488 - val_loss: 1.2569 - val_accuracy: 0.4503 - val_auc: 0.8494 - 12s/epoch - 4ms/step -Epoch 13/100 -3169/3169 - 12s - loss: 8.0281 - accuracy: 0.4541 - auc: 0.8507 - val_loss: 1.2537 - val_accuracy: 0.4516 - val_auc: 0.8505 - 12s/epoch - 4ms/step -Epoch 14/100 -3169/3169 - 12s - loss: 7.9817 - accuracy: 0.4540 - auc: 0.8519 - val_loss: 1.2584 - val_accuracy: 0.4478 - val_auc: 0.8496 - 12s/epoch - 4ms/step -Epoch 15/100 -3169/3169 - 12s - loss: 7.9342 - accuracy: 0.4556 - auc: 0.8534 - val_loss: 1.2469 - val_accuracy: 0.4515 - val_auc: 0.8526 - 12s/epoch - 4ms/step -Epoch 16/100 -3169/3169 - 12s - loss: 7.8945 - accuracy: 0.4560 - auc: 0.8545 - val_loss: 1.2332 - val_accuracy: 0.4574 - val_auc: 0.8564 - 12s/epoch - 4ms/step -Epoch 17/100 -3169/3169 - 12s - loss: 7.8461 - accuracy: 0.4585 - auc: 0.8560 - val_loss: 1.2406 - val_accuracy: 0.4534 - val_auc: 0.8545 - 12s/epoch - 4ms/step -Epoch 18/100 -3169/3169 - 12s - loss: 7.8091 - accuracy: 0.4604 - auc: 0.8570 - val_loss: 1.2313 - val_accuracy: 0.4574 - val_auc: 0.8570 - 12s/epoch - 4ms/step -Epoch 19/100 -3169/3169 - 12s - loss: 7.7604 - accuracy: 0.4605 - auc: 0.8583 - val_loss: 1.2342 - val_accuracy: 0.4563 - val_auc: 0.8565 - 12s/epoch - 4ms/step -Epoch 20/100 -3169/3169 - 13s - loss: 7.7205 - accuracy: 0.4624 - auc: 0.8596 - val_loss: 1.2245 - val_accuracy: 0.4619 - val_auc: 0.8594 - 13s/epoch - 4ms/step -Epoch 21/100 -3169/3169 - 12s - loss: 7.6892 - accuracy: 0.4637 - auc: 0.8605 - val_loss: 1.2264 - val_accuracy: 0.4576 - val_auc: 0.8587 - 12s/epoch - 4ms/step -Epoch 22/100 -3169/3169 - 12s - loss: 7.6396 - accuracy: 0.4636 - auc: 0.8614 - val_loss: 1.2180 - val_accuracy: 0.4632 - val_auc: 0.8614 - 12s/epoch - 
4ms/step -Epoch 23/100 -3169/3169 - 12s - loss: 7.5927 - accuracy: 0.4672 - auc: 0.8627 - val_loss: 1.2127 - val_accuracy: 0.4630 - val_auc: 0.8626 - 12s/epoch - 4ms/step -Epoch 24/100 -3169/3169 - 13s - loss: 7.5766 - accuracy: 0.4666 - auc: 0.8632 - val_loss: 1.2112 - val_accuracy: 0.4636 - val_auc: 0.8632 - 13s/epoch - 4ms/step -Epoch 25/100 -3169/3169 - 12s - loss: 7.5511 - accuracy: 0.4678 - auc: 0.8644 - val_loss: 1.2096 - val_accuracy: 0.4664 - val_auc: 0.8641 - 12s/epoch - 4ms/step -Epoch 26/100 -3169/3169 - 12s - loss: 7.5108 - accuracy: 0.4679 - auc: 0.8648 - val_loss: 1.2033 - val_accuracy: 0.4664 - val_auc: 0.8652 - 12s/epoch - 4ms/step -Epoch 27/100 -3169/3169 - 12s - loss: 7.4751 - accuracy: 0.4692 - auc: 0.8659 - val_loss: 1.2050 - val_accuracy: 0.4668 - val_auc: 0.8653 - 12s/epoch - 4ms/step -Epoch 28/100 -3169/3169 - 12s - loss: 7.4332 - accuracy: 0.4704 - auc: 0.8668 - val_loss: 1.2004 - val_accuracy: 0.4688 - val_auc: 0.8665 - 12s/epoch - 4ms/step -Epoch 29/100 -3169/3169 - 12s - loss: 7.4195 - accuracy: 0.4709 - auc: 0.8675 - val_loss: 1.2037 - val_accuracy: 0.4665 - val_auc: 0.8654 - 12s/epoch - 4ms/step -Epoch 30/100 -3169/3169 - 12s - loss: 7.3719 - accuracy: 0.4718 - auc: 0.8683 - val_loss: 1.1979 - val_accuracy: 0.4694 - val_auc: 0.8674 - 12s/epoch - 4ms/step -Epoch 31/100 -3169/3169 - 12s - loss: 7.3513 - accuracy: 0.4728 - auc: 0.8690 - val_loss: 1.2030 - val_accuracy: 0.4662 - val_auc: 0.8661 - 12s/epoch - 4ms/step -Epoch 32/100 -3169/3169 - 12s - loss: 7.3218 - accuracy: 0.4738 - auc: 0.8697 - val_loss: 1.1982 - val_accuracy: 0.4689 - val_auc: 0.8673 - 12s/epoch - 4ms/step -Epoch 33/100 -3169/3169 - 12s - loss: 7.2744 - accuracy: 0.4750 - auc: 0.8708 - val_loss: 1.1921 - val_accuracy: 0.4715 - val_auc: 0.8688 - 12s/epoch - 4ms/step -Epoch 34/100 -3169/3169 - 12s - loss: 7.2520 - accuracy: 0.4765 - auc: 0.8715 - val_loss: 1.1935 - val_accuracy: 0.4717 - val_auc: 0.8685 - 12s/epoch - 4ms/step -Epoch 35/100 -3169/3169 - 12s - loss: 7.2214 - accuracy: 0.4769 - auc: 0.8721 - val_loss: 1.1940 - val_accuracy: 0.4688 - val_auc: 0.8681 - 12s/epoch - 4ms/step -Epoch 36/100 -3169/3169 - 12s - loss: 7.1789 - accuracy: 0.4798 - auc: 0.8732 - val_loss: 1.1796 - val_accuracy: 0.4733 - val_auc: 0.8717 - 12s/epoch - 4ms/step -Epoch 37/100 -3169/3169 - 12s - loss: 7.1520 - accuracy: 0.4813 - auc: 0.8739 - val_loss: 1.1844 - val_accuracy: 0.4738 - val_auc: 0.8709 - 12s/epoch - 4ms/step -Epoch 38/100 -3169/3169 - 12s - loss: 7.1393 - accuracy: 0.4813 - auc: 0.8743 - val_loss: 1.1785 - val_accuracy: 0.4753 - val_auc: 0.8721 - 12s/epoch - 4ms/step -Epoch 39/100 -3169/3169 - 12s - loss: 7.1081 - accuracy: 0.4821 - auc: 0.8749 - val_loss: 1.1792 - val_accuracy: 0.4754 - val_auc: 0.8723 - 12s/epoch - 4ms/step -Epoch 40/100 -3169/3169 - 12s - loss: 7.0664 - accuracy: 0.4831 - auc: 0.8758 - val_loss: 1.1829 - val_accuracy: 0.4719 - val_auc: 0.8716 - 12s/epoch - 4ms/step -Epoch 41/100 -3169/3169 - 12s - loss: 7.0625 - accuracy: 0.4831 - auc: 0.8759 - val_loss: 1.1831 - val_accuracy: 0.4737 - val_auc: 0.8716 - 12s/epoch - 4ms/step -Epoch 42/100 -3169/3169 - 12s - loss: 7.0190 - accuracy: 0.4845 - auc: 0.8767 - val_loss: 1.1886 - val_accuracy: 0.4689 - val_auc: 0.8705 - 12s/epoch - 4ms/step -Epoch 43/100 -3169/3169 - 13s - loss: 7.0000 - accuracy: 0.4839 - auc: 0.8770 - val_loss: 1.1720 - val_accuracy: 0.4776 - val_auc: 0.8744 - 13s/epoch - 4ms/step -Epoch 44/100 -3169/3169 - 12s - loss: 6.9733 - accuracy: 0.4864 - auc: 0.8777 - val_loss: 1.1704 - val_accuracy: 0.4772 - val_auc: 0.8745 - 
12s/epoch - 4ms/step -Epoch 45/100 -3169/3169 - 12s - loss: 6.9480 - accuracy: 0.4872 - auc: 0.8784 - val_loss: 1.1695 - val_accuracy: 0.4767 - val_auc: 0.8747 - 12s/epoch - 4ms/step -Epoch 46/100 -3169/3169 - 12s - loss: 6.9208 - accuracy: 0.4880 - auc: 0.8789 - val_loss: 1.1687 - val_accuracy: 0.4792 - val_auc: 0.8753 - 12s/epoch - 4ms/step -Epoch 47/100 -3169/3169 - 12s - loss: 6.8756 - accuracy: 0.4902 - auc: 0.8800 - val_loss: 1.1667 - val_accuracy: 0.4785 - val_auc: 0.8755 - 12s/epoch - 4ms/step -Epoch 48/100 -3169/3169 - 12s - loss: 6.8618 - accuracy: 0.4902 - auc: 0.8801 - val_loss: 1.1714 - val_accuracy: 0.4781 - val_auc: 0.8752 - 12s/epoch - 4ms/step -Epoch 49/100 -3169/3169 - 12s - loss: 6.8411 - accuracy: 0.4916 - auc: 0.8807 - val_loss: 1.1676 - val_accuracy: 0.4793 - val_auc: 0.8756 - 12s/epoch - 4ms/step -Epoch 50/100 -3169/3169 - 12s - loss: 6.8144 - accuracy: 0.4922 - auc: 0.8812 - val_loss: 1.1622 - val_accuracy: 0.4784 - val_auc: 0.8767 - 12s/epoch - 4ms/step -Epoch 51/100 -3169/3169 - 12s - loss: 6.7880 - accuracy: 0.4931 - auc: 0.8819 - val_loss: 1.1591 - val_accuracy: 0.4844 - val_auc: 0.8780 - 12s/epoch - 4ms/step -Epoch 52/100 -3169/3169 - 12s - loss: 6.7653 - accuracy: 0.4932 - auc: 0.8823 - val_loss: 1.1579 - val_accuracy: 0.4808 - val_auc: 0.8776 - 12s/epoch - 4ms/step -Epoch 53/100 -3169/3169 - 12s - loss: 6.7188 - accuracy: 0.4961 - auc: 0.8832 - val_loss: 1.1526 - val_accuracy: 0.4845 - val_auc: 0.8791 - 12s/epoch - 4ms/step -Epoch 54/100 -3169/3169 - 12s - loss: 6.6964 - accuracy: 0.4969 - auc: 0.8836 - val_loss: 1.1571 - val_accuracy: 0.4843 - val_auc: 0.8788 - 12s/epoch - 4ms/step -Epoch 55/100 -3169/3169 - 12s - loss: 6.6855 - accuracy: 0.4981 - auc: 0.8841 - val_loss: 1.1595 - val_accuracy: 0.4825 - val_auc: 0.8781 - 12s/epoch - 4ms/step -Epoch 56/100 -3169/3169 - 12s - loss: 6.6555 - accuracy: 0.4969 - auc: 0.8843 - val_loss: 1.1470 - val_accuracy: 0.4852 - val_auc: 0.8806 - 12s/epoch - 4ms/step -Epoch 57/100 -3169/3169 - 13s - loss: 6.6346 - accuracy: 0.4992 - auc: 0.8852 - val_loss: 1.1487 - val_accuracy: 0.4884 - val_auc: 0.8804 - 13s/epoch - 4ms/step -Epoch 58/100 -3169/3169 - 12s - loss: 6.5984 - accuracy: 0.5002 - auc: 0.8854 - val_loss: 1.1496 - val_accuracy: 0.4879 - val_auc: 0.8806 - 12s/epoch - 4ms/step -Epoch 59/100 -3169/3169 - 12s - loss: 6.5793 - accuracy: 0.5004 - auc: 0.8858 - val_loss: 1.1430 - val_accuracy: 0.4899 - val_auc: 0.8818 - 12s/epoch - 4ms/step -Epoch 60/100 -3169/3169 - 12s - loss: 6.5508 - accuracy: 0.5009 - auc: 0.8862 - val_loss: 1.1375 - val_accuracy: 0.4918 - val_auc: 0.8829 - 12s/epoch - 4ms/step -Epoch 61/100 -3169/3169 - 12s - loss: 6.5200 - accuracy: 0.5026 - auc: 0.8870 - val_loss: 1.1413 - val_accuracy: 0.4919 - val_auc: 0.8824 - 12s/epoch - 4ms/step -Epoch 62/100 -3169/3169 - 12s - loss: 6.5148 - accuracy: 0.5043 - auc: 0.8871 - val_loss: 1.1446 - val_accuracy: 0.4889 - val_auc: 0.8814 - 12s/epoch - 4ms/step -Epoch 63/100 -3169/3169 - 12s - loss: 6.4885 - accuracy: 0.5044 - auc: 0.8881 - val_loss: 1.1382 - val_accuracy: 0.4918 - val_auc: 0.8826 - 12s/epoch - 4ms/step -Epoch 64/100 -3169/3169 - 12s - loss: 6.4309 - accuracy: 0.5053 - auc: 0.8883 - val_loss: 1.1425 - val_accuracy: 0.4885 - val_auc: 0.8822 - 12s/epoch - 4ms/step -Epoch 65/100 -3169/3169 - 12s - loss: 6.4270 - accuracy: 0.5071 - auc: 0.8891 - val_loss: 1.1425 - val_accuracy: 0.4926 - val_auc: 0.8826 - 12s/epoch - 4ms/step -Epoch 66/100 -3169/3169 - 12s - loss: 6.4116 - accuracy: 0.5069 - auc: 0.8892 - val_loss: 1.1418 - val_accuracy: 0.4900 - val_auc: 
0.8823 - 12s/epoch - 4ms/step -Epoch 67/100 -3169/3169 - 12s - loss: 6.3855 - accuracy: 0.5069 - auc: 0.8896 - val_loss: 1.1360 - val_accuracy: 0.4942 - val_auc: 0.8838 - 12s/epoch - 4ms/step -Epoch 68/100 -3169/3169 - 12s - loss: 6.3426 - accuracy: 0.5094 - auc: 0.8905 - val_loss: 1.1360 - val_accuracy: 0.4931 - val_auc: 0.8836 - 12s/epoch - 4ms/step -Epoch 69/100 -3169/3169 - 12s - loss: 6.3108 - accuracy: 0.5102 - auc: 0.8910 - val_loss: 1.1364 - val_accuracy: 0.4946 - val_auc: 0.8839 - 12s/epoch - 4ms/step -Epoch 70/100 -3169/3169 - 12s - loss: 6.3049 - accuracy: 0.5105 - auc: 0.8909 - val_loss: 1.1246 - val_accuracy: 0.4984 - val_auc: 0.8862 - 12s/epoch - 4ms/step -Epoch 71/100 -3169/3169 - 12s - loss: 6.2819 - accuracy: 0.5105 - auc: 0.8918 - val_loss: 1.1338 - val_accuracy: 0.4965 - val_auc: 0.8848 - 12s/epoch - 4ms/step -Epoch 72/100 -3169/3169 - 12s - loss: 6.2571 - accuracy: 0.5109 - auc: 0.8918 - val_loss: 1.1305 - val_accuracy: 0.4962 - val_auc: 0.8852 - 12s/epoch - 4ms/step -Epoch 73/100 -3169/3169 - 12s - loss: 6.2476 - accuracy: 0.5126 - auc: 0.8922 - val_loss: 1.1235 - val_accuracy: 0.4981 - val_auc: 0.8865 - 12s/epoch - 4ms/step -Epoch 74/100 -3169/3169 - 13s - loss: 6.2087 - accuracy: 0.5137 - auc: 0.8930 - val_loss: 1.1252 - val_accuracy: 0.5015 - val_auc: 0.8866 - 13s/epoch - 4ms/step -Epoch 75/100 -3169/3169 - 12s - loss: 6.1919 - accuracy: 0.5150 - auc: 0.8932 - val_loss: 1.1210 - val_accuracy: 0.5012 - val_auc: 0.8872 - 12s/epoch - 4ms/step -Epoch 76/100 -3169/3169 - 12s - loss: 6.1675 - accuracy: 0.5167 - auc: 0.8938 - val_loss: 1.1194 - val_accuracy: 0.5038 - val_auc: 0.8879 - 12s/epoch - 4ms/step -Epoch 77/100 -3169/3169 - 12s - loss: 6.1344 - accuracy: 0.5173 - auc: 0.8944 - val_loss: 1.1366 - val_accuracy: 0.4944 - val_auc: 0.8845 - 12s/epoch - 4ms/step -Epoch 78/100 -3169/3169 - 12s - loss: 6.1222 - accuracy: 0.5170 - auc: 0.8946 - val_loss: 1.1273 - val_accuracy: 0.4975 - val_auc: 0.8861 - 12s/epoch - 4ms/step -Epoch 79/100 -3169/3169 - 12s - loss: 6.0835 - accuracy: 0.5197 - auc: 0.8953 - val_loss: 1.1268 - val_accuracy: 0.4994 - val_auc: 0.8866 - 12s/epoch - 4ms/step -Epoch 80/100 -3169/3169 - 13s - loss: 6.0967 - accuracy: 0.5182 - auc: 0.8951 - val_loss: 1.1287 - val_accuracy: 0.5024 - val_auc: 0.8863 - 13s/epoch - 4ms/step -Epoch 81/100 -3169/3169 - 12s - loss: 6.0538 - accuracy: 0.5210 - auc: 0.8958 - val_loss: 1.1287 - val_accuracy: 0.4983 - val_auc: 0.8860 - 12s/epoch - 4ms/step -Epoch 82/100 -3169/3169 - 12s - loss: 6.0255 - accuracy: 0.5209 - auc: 0.8964 - val_loss: 1.1180 - val_accuracy: 0.5054 - val_auc: 0.8885 - 12s/epoch - 4ms/step -Epoch 83/100 -3169/3169 - 12s - loss: 5.9945 - accuracy: 0.5209 - auc: 0.8966 - val_loss: 1.1102 - val_accuracy: 0.5068 - val_auc: 0.8897 - 12s/epoch - 4ms/step -Epoch 84/100 -3169/3169 - 12s - loss: 5.9736 - accuracy: 0.5232 - auc: 0.8972 - val_loss: 1.1121 - val_accuracy: 0.5051 - val_auc: 0.8896 - 12s/epoch - 4ms/step -Epoch 85/100 -3169/3169 - 12s - loss: 5.9699 - accuracy: 0.5228 - auc: 0.8973 - val_loss: 1.1190 - val_accuracy: 0.5038 - val_auc: 0.8887 - 12s/epoch - 4ms/step -Epoch 86/100 -3169/3169 - 12s - loss: 5.9586 - accuracy: 0.5232 - auc: 0.8975 - val_loss: 1.1147 - val_accuracy: 0.5049 - val_auc: 0.8891 - 12s/epoch - 4ms/step -Epoch 87/100 -3169/3169 - 12s - loss: 5.9343 - accuracy: 0.5239 - auc: 0.8978 - val_loss: 1.1220 - val_accuracy: 0.5027 - val_auc: 0.8883 - 12s/epoch - 4ms/step -Epoch 88/100 -3169/3169 - 12s - loss: 5.8928 - accuracy: 0.5256 - auc: 0.8987 - val_loss: 1.1123 - val_accuracy: 0.5111 - 
val_auc: 0.8902 - 12s/epoch - 4ms/step -Epoch 89/100 -3169/3169 - 12s - loss: 5.8686 - accuracy: 0.5257 - auc: 0.8989 - val_loss: 1.1118 - val_accuracy: 0.5064 - val_auc: 0.8901 - 12s/epoch - 4ms/step -Epoch 90/100 -3169/3169 - 12s - loss: 5.8582 - accuracy: 0.5277 - auc: 0.8995 - val_loss: 1.1055 - val_accuracy: 0.5098 - val_auc: 0.8913 - 12s/epoch - 4ms/step -Epoch 91/100 -3169/3169 - 12s - loss: 5.8352 - accuracy: 0.5280 - auc: 0.8996 - val_loss: 1.1036 - val_accuracy: 0.5088 - val_auc: 0.8916 - 12s/epoch - 4ms/step -Epoch 92/100 -3169/3169 - 12s - loss: 5.8186 - accuracy: 0.5274 - auc: 0.8999 - val_loss: 1.1128 - val_accuracy: 0.5066 - val_auc: 0.8901 - 12s/epoch - 4ms/step -Epoch 93/100 -3169/3169 - 12s - loss: 5.8003 - accuracy: 0.5278 - auc: 0.9002 - val_loss: 1.1047 - val_accuracy: 0.5076 - val_auc: 0.8912 - 12s/epoch - 4ms/step -Epoch 94/100 -3169/3169 - 12s - loss: 5.7763 - accuracy: 0.5297 - auc: 0.9008 - val_loss: 1.1205 - val_accuracy: 0.5042 - val_auc: 0.8891 - 12s/epoch - 4ms/step -Epoch 95/100 -3169/3169 - 12s - loss: 5.7656 - accuracy: 0.5280 - auc: 0.9006 - val_loss: 1.1119 - val_accuracy: 0.5051 - val_auc: 0.8904 - 12s/epoch - 4ms/step -Epoch 96/100 -3169/3169 - 12s - loss: 5.7510 - accuracy: 0.5304 - auc: 0.9012 - val_loss: 1.1095 - val_accuracy: 0.5083 - val_auc: 0.8912 - 12s/epoch - 4ms/step -Epoch 97/100 -3169/3169 - 12s - loss: 5.7480 - accuracy: 0.5302 - auc: 0.9013 - val_loss: 1.1021 - val_accuracy: 0.5091 - val_auc: 0.8922 - 12s/epoch - 4ms/step -Epoch 98/100 -3169/3169 - 12s - loss: 5.7046 - accuracy: 0.5310 - auc: 0.9019 - val_loss: 1.1050 - val_accuracy: 0.5097 - val_auc: 0.8920 - 12s/epoch - 4ms/step -Epoch 99/100 -3169/3169 - 12s - loss: 5.7046 - accuracy: 0.5324 - auc: 0.9022 - val_loss: 1.0983 - val_accuracy: 0.5136 - val_auc: 0.8930 - 12s/epoch - 4ms/step -Epoch 100/100 -3169/3169 - 12s - loss: 5.6727 - accuracy: 0.5335 - auc: 0.9026 - val_loss: 1.1125 - val_accuracy: 0.5039 - val_auc: 0.8907 - 12s/epoch - 4ms/step - -``` -
---- -## Results - -Let's plot the training and validation AUC and accuracy. - - -```python -fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(14, 5)) - -axs[0].plot(range(EPOCHS), history.history["accuracy"], label="Training") -axs[0].plot(range(EPOCHS), history.history["val_accuracy"], label="Validation") -axs[0].set_xlabel("Epochs") -axs[0].set_title("Training & Validation Accuracy") -axs[0].legend() -axs[0].grid(True) - -axs[1].plot(range(EPOCHS), history.history["auc"], label="Training") -axs[1].plot(range(EPOCHS), history.history["val_auc"], label="Validation") -axs[1].set_xlabel("Epochs") -axs[1].set_title("Training & Validation AUC") -axs[1].legend() -axs[1].grid(True) - -plt.show() -``` - - -![png](/img/examples/audio/uk_ireland_accent_recognition/uk_ireland_accent_recognition_29_0.png) - - ---- -## Evaluation - - -```python -train_loss, train_acc, train_auc = model.evaluate(train_ds) -valid_loss, valid_acc, valid_auc = model.evaluate(valid_ds) -``` - -
-``` -3169/3169 [==============================] - 10s 3ms/step - loss: 1.0117 - accuracy: 0.5423 - auc: 0.9079 -349/349 [==============================] - 1s 3ms/step - loss: 1.1125 - accuracy: 0.5039 - auc: 0.8907 - -``` -
-
-Let's try to compare our model's performance to Yamnet's, using one of Yamnet's own metrics (d-prime).
-Yamnet achieved a d-prime value of 2.318.
-Let's check our model's performance.
-
-
-```python
-# The following function calculates the d-prime score from the AUC
-def d_prime(auc):
-    standard_normal = stats.norm()
-    d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
-    return d_prime
-
-
-print(
-    "train d-prime: {0:.3f}, validation d-prime: {1:.3f}".format(
-        d_prime(train_auc), d_prime(valid_auc)
-    )
-)
-```
-
-<div class="k-default-codeblock">
-``` -train d-prime: 1.878, validation d-prime: 1.740 - -``` -
-
-We can see that the model achieves the following results:
-
-Results    | Training  | Validation
------------|-----------|------------
-Accuracy   | 54%       | 50%
-AUC        | 0.91      | 0.89
-d-prime    | 1.878     | 1.740
-
----
-## Confusion Matrix
-
-Let's now plot the confusion matrix for the validation dataset.
-
-The confusion matrix lets us see, for every class, not only how many samples were correctly classified,
-but also which other classes the samples were confused with.
-
-It allows us to calculate the precision and recall for every class.
-
-
-```python
-# Create x and y tensors
-x_valid = None
-y_valid = None
-
-for x, y in iter(valid_ds):
-    if x_valid is None:
-        x_valid = x.numpy()
-        y_valid = y.numpy()
-    else:
-        x_valid = np.concatenate((x_valid, x.numpy()), axis=0)
-        y_valid = np.concatenate((y_valid, y.numpy()), axis=0)
-
-# Generate predictions
-y_pred = model.predict(x_valid)
-
-# Calculate confusion matrix
-confusion_mtx = tf.math.confusion_matrix(
-    np.argmax(y_valid, axis=1), np.argmax(y_pred, axis=1)
-)
-
-# Plot the confusion matrix
-plt.figure(figsize=(10, 8))
-sns.heatmap(
-    confusion_mtx, xticklabels=class_names, yticklabels=class_names, annot=True, fmt="g"
-)
-plt.xlabel("Prediction")
-plt.ylabel("Label")
-plt.title("Validation Confusion Matrix")
-plt.show()
-```
-
-
-![png](/img/examples/audio/uk_ireland_accent_recognition/uk_ireland_accent_recognition_36_0.png)
-
-
----
-## Precision & recall
-
-For every class:
-
-* Recall is the ratio of correctly classified samples, i.e. it shows how many samples
-of this specific class the model is able to detect.
-It is the ratio of diagonal elements to the sum of all elements in the row.
-* Precision shows the accuracy of the classifier. It is the ratio of correctly predicted
-samples among the ones classified as belonging to this class.
-It is the ratio of diagonal elements to the sum of all elements in the column.
-
-
-```python
-for i, label in enumerate(class_names):
-    precision = confusion_mtx[i, i] / np.sum(confusion_mtx[:, i])
-    recall = confusion_mtx[i, i] / np.sum(confusion_mtx[i, :])
-    print(
-        "{0:15} Precision:{1:.2f}%; Recall:{2:.2f}%".format(
-            label, precision * 100, recall * 100
-        )
-    )
-```
-
-<div class="k-default-codeblock">
-``` -Irish Precision:17.22%; Recall:63.36% -Midlands Precision:13.35%; Recall:51.70% -Northern Precision:30.22%; Recall:50.58% -Scottish Precision:28.85%; Recall:32.57% -Southern Precision:76.34%; Recall:28.14% -Welsh Precision:74.33%; Recall:83.34% -Not a speech Precision:98.83%; Recall:99.93% - -``` -
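-
-If a single summary number is preferred, the per-class precision and recall can be
-combined into a macro-averaged F1 score (a sketch, not part of the original example):
-
-
-```python
-cm = confusion_mtx.numpy()
-per_class_precision = np.diag(cm) / cm.sum(axis=0)
-per_class_recall = np.diag(cm) / cm.sum(axis=1)
-f1 = 2 * per_class_precision * per_class_recall / (per_class_precision + per_class_recall)
-print("macro F1: {0:.2f}".format(f1.mean()))
-```
-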
----
-## Run inference on test data
-
-Let's now run a test on a single audio file.
-Let's check this example from [The Scottish Voice](https://www.thescottishvoice.org.uk/home/)
-
-We will:
-
-* Download the mp3 file.
-* Convert it to a 16k wav file.
-* Run the model on the wav file.
-* Plot the results.
-
-
-```python
-filename = "audio-sample-Stuart"
-url = "https://www.thescottishvoice.org.uk/files/cm/files/"
-
-if not os.path.exists(filename + ".wav"):
-    print(f"Downloading {filename}.mp3 from {url}")
-    command = f"wget {url}{filename}.mp3"
-    os.system(command)
-
-    print("Converting mp3 to wav and resampling to 16 kHZ")
-    command = (
-        f"ffmpeg -hide_banner -loglevel panic -y -i {filename}.mp3 -acodec "
-        f"pcm_s16le -ac 1 -ar 16000 {filename}.wav"
-    )
-    os.system(command)
-
-filename = filename + ".wav"
-
-```
-
-<div class="k-default-codeblock">
-``` -Downloading audio-sample-Stuart.mp3 from https://www.thescottishvoice.org.uk/files/cm/files/ -Converting mp3 to wav and resampling to 16 kHZ - -``` -
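-
-To verify that the conversion worked (a sketch, not part of the original example), we
-can decode the wav file and check its sample rate and channel count:
-
-
-```python
-file_content = tf.io.read_file(filename)
-wav, sample_rate = tf.audio.decode_wav(file_content, desired_channels=1)
-print(sample_rate.numpy(), wav.shape)  # expected: 16000 and (num_samples, 1)
-```
-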
-
-The below function `yamnet_class_names_from_csv` was copied and very slightly changed
-from this [Yamnet Notebook](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/yamnet.ipynb).
-
-
-```python
-
-def yamnet_class_names_from_csv(yamnet_class_map_csv_text):
-    """Returns list of class names corresponding to score vector."""
-    yamnet_class_map_csv = io.StringIO(yamnet_class_map_csv_text)
-    yamnet_class_names = [
-        name for (class_index, mid, name) in csv.reader(yamnet_class_map_csv)
-    ]
-    yamnet_class_names = yamnet_class_names[1:]  # Skip CSV header
-    return yamnet_class_names
-
-
-yamnet_class_map_path = yamnet_model.class_map_path().numpy()
-yamnet_class_names = yamnet_class_names_from_csv(
-    tf.io.read_file(yamnet_class_map_path).numpy().decode("utf-8")
-)
-
-
-def calculate_number_of_non_speech(scores):
-    number_of_non_speech = tf.math.reduce_sum(
-        tf.where(tf.math.argmax(scores, axis=1, output_type=tf.int32) != 0, 1, 0)
-    )
-
-    return number_of_non_speech
-
-
-def filename_to_predictions(filename):
-    # Load 16k audio wave
-    audio_wav = load_16k_audio_wav(filename)
-
-    # Get audio embeddings & scores.
-    scores, embeddings, mel_spectrogram = yamnet_model(audio_wav)
-
-    print(
-        "Out of {} samples, {} are not speech".format(
-            scores.shape[0], calculate_number_of_non_speech(scores)
-        )
-    )
-
-    # Predict the output of the accent recognition model with embeddings as input
-    predictions = model.predict(embeddings)
-
-    return audio_wav, predictions, mel_spectrogram
-
-```
-
-Let's run the model on the audio file:
-
-
-```python
-audio_wav, predictions, mel_spectrogram = filename_to_predictions(filename)
-
-inferred_class = class_names[predictions.mean(axis=0).argmax()]
-print(f"The main accent is: {inferred_class} English")
-```
-
-<div class="k-default-codeblock">
-``` -Out of 66 samples, 0 are not speech -The main accent is: Scottish English - -``` -
-
-Listen to the audio
-
-
-```python
-Audio(audio_wav, rate=16000)
-```
-
-The below function was copied from this [Yamnet notebook](https://tinyurl.com/4a8xn7at) and adjusted to our needs.
-
-This function plots the following:
-
-* Audio waveform
-* Mel spectrogram
-* Predictions for every time step
-
-
-```python
-plt.figure(figsize=(10, 6))
-
-# Plot the waveform.
-plt.subplot(3, 1, 1)
-plt.plot(audio_wav)
-plt.xlim([0, len(audio_wav)])
-
-# Plot the log-mel spectrogram (returned by the model).
-plt.subplot(3, 1, 2)
-plt.imshow(
-    mel_spectrogram.numpy().T, aspect="auto", interpolation="nearest", origin="lower"
-)
-
-# Plot and label the model output scores for the top-scoring classes.
-mean_predictions = np.mean(predictions, axis=0)
-
-top_class_indices = np.argsort(mean_predictions)[::-1]
-plt.subplot(3, 1, 3)
-plt.imshow(
-    predictions[:, top_class_indices].T,
-    aspect="auto",
-    interpolation="nearest",
-    cmap="gray_r",
-)
-
-# patch_padding = (PATCH_WINDOW_SECONDS / 2) / PATCH_HOP_SECONDS
-# values from the model documentation
-patch_padding = (0.025 / 2) / 0.01
-plt.xlim([-patch_padding - 0.5, predictions.shape[0] + patch_padding - 0.5])
-# Label the top_N classes.
-yticks = range(0, len(class_names), 1)
-plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])
-_ = plt.ylim(-0.5 + np.array([len(class_names), 0]))
-```
-
-
-![png](/img/examples/audio/uk_ireland_accent_recognition/uk_ireland_accent_recognition_48_0.png)
-
diff --git a/templates/examples/audio/wav2vec2_audiocls.md b/templates/examples/audio/wav2vec2_audiocls.md
deleted file mode 100644
index b8817bac69..0000000000
--- a/templates/examples/audio/wav2vec2_audiocls.md
+++ /dev/null
@@ -1,482 +0,0 @@
-# Audio Classification with Hugging Face Transformers
-
-**Author:** Sreyan Ghosh<br>
-**Date created:** 2022/07/01
-**Last modified:** 2022/08/27
-**Description:** Training Wav2Vec 2.0 using Hugging Face Transformers for Audio Classification. - - -
ⓘ This example uses Keras 2
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/audio/ipynb/wav2vec2_audiocls.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/audio/wav2vec2_audiocls.py)
-
-
-
----
-## Introduction
-
-Identification of speech commands, also known as *keyword spotting* (KWS),
-is important from an engineering perspective for a wide range of applications,
-from indexing audio databases and spotting keywords, to running speech models locally
-on microcontrollers. Currently, many human-computer interfaces (HCI) like Google
-Assistant, Microsoft Cortana, Amazon Alexa, Apple Siri and others rely on keyword
-spotting. There is a significant amount of research in the field by all major companies,
-notably Google and Baidu.
-
-In the past decade, deep learning has led to significant performance
-gains on this task. Though low-level audio features extracted from raw audio like MFCC or
-mel-filterbanks have been used for decades, the design of these low-level features
-is [flawed by biases](https://arxiv.org/abs/2101.08596). Moreover, deep learning models
-trained on these low-level features can easily overfit to noise or signals irrelevant to the
-task. This makes it essential for any system to learn speech representations that expose
-high-level information from the speech signal, such as acoustic and linguistic content,
-including phonemes, words, semantic meaning, tone, and speaker characteristics, in order
-to solve the downstream task. [Wav2Vec 2.0](https://arxiv.org/abs/2006.11477), which solves a
-self-supervised contrastive learning task to learn high-level speech representations,
-provides a great alternative to traditional low-level features for training deep learning
-models for KWS.
-
-In this notebook, we train the Wav2Vec 2.0 (base) model, built on the
-Hugging Face Transformers library, in an end-to-end fashion on the keyword spotting task and
-achieve state-of-the-art results on the Google Speech Commands Dataset.
-
----
-## Setup
-
-### Installing the requirements
-
-
-```python
-!pip install git+https://github.com/huggingface/transformers.git
-!pip install datasets
-!pip install huggingface-hub
-!pip install joblib
-!pip install librosa
-```
-
-### Importing the necessary libraries
-
-
-```python
-import random
-import logging
-
-import numpy as np
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-
-# Only log error messages
-tf.get_logger().setLevel(logging.ERROR)
-# Set random seed
-tf.keras.utils.set_random_seed(42)
-```
-
-### Define certain variables
-
-
-```python
-# Maximum duration of the input audio file we feed to our Wav2Vec 2.0 model.
-MAX_DURATION = 1
-# Sampling rate is the number of samples of audio recorded every second
-SAMPLING_RATE = 16000
-BATCH_SIZE = 32  # Batch-size for training and evaluating our model.
-NUM_CLASSES = 10  # Number of classes our model will predict (the ten keyword classes).
-HIDDEN_DIM = 768  # Dimension of our model output (768 in case of Wav2Vec 2.0 - Base).
-MAX_SEQ_LENGTH = MAX_DURATION * SAMPLING_RATE  # Maximum length of the input audio file.
-# Wav2Vec 2.0 results in an output frequency with a stride of about 20ms.
-MAX_FRAMES = 49
-MAX_EPOCHS = 2  # Maximum number of training epochs.
-
-MODEL_CHECKPOINT = "facebook/wav2vec2-base"  # Name of pretrained model from Hugging Face Model Hub
-```
-
----
-## Load the Google Speech Commands Dataset
-
-We now download the [Google Speech Commands V1 Dataset](https://arxiv.org/abs/1804.03209),
-a popular benchmark for training and evaluating deep learning models built for solving the KWS task.
-The dataset consists of a total of 60,973 audio files, each of 1 second duration,
-divided into ten classes of keywords ("Yes", "No", "Up", "Down", "Left", "Right", "On",
-"Off", "Stop", and "Go"), a class for silence, and an unknown class to cover false
-positives. We load the dataset from [Hugging Face Datasets](https://github.com/huggingface/datasets).
-This can be easily done with the `load_dataset` function.
-
-
-```python
-from datasets import load_dataset
-
-speech_commands_v1 = load_dataset("superb", "ks")
-```
-
-The dataset has the following fields:
-
-- **file**: the path to the raw .wav file of the audio
-- **audio**: the audio file sampled at 16kHz
-- **label**: label ID of the audio utterance
-
-
-```python
-print(speech_commands_v1)
-```
-
-<div class="k-default-codeblock">
-``` -DatasetDict({ - train: Dataset({ - features: ['file', 'audio', 'label'], - num_rows: 51094 - }) - validation: Dataset({ - features: ['file', 'audio', 'label'], - num_rows: 6798 - }) - test: Dataset({ - features: ['file', 'audio', 'label'], - num_rows: 3081 - }) -}) - -``` -
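-
-To see what a single example looks like (a sketch, not part of the original example),
-we can index into the training split; the `audio` field holds the decoded waveform
-and its sampling rate:
-
-
-```python
-sample = speech_commands_v1["train"][0]
-print(sample["label"])
-print(sample["audio"]["sampling_rate"])
-print(len(sample["audio"]["array"]))
-```
-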
----
-## Data Pre-processing
-
-For the sake of demonstrating the workflow, in this notebook we only take
-small stratified balanced splits (50%) of the train set as our training and test sets.
-We can easily split the dataset using the `train_test_split` method, which expects
-the split size and the name of the column to stratify on.
-
-After splitting the dataset, we remove the `unknown` and `silence` classes and only
-focus on the ten main classes. The `filter` method does that easily for you.
-
-Next we truncate our train and test splits to a multiple of the `BATCH_SIZE` to
-facilitate smooth training and inference. You can achieve that using the `select`
-method, which expects the indices of the samples you want to keep. The rest are
-discarded.
-
-
-```python
-speech_commands_v1 = speech_commands_v1["train"].train_test_split(
-    train_size=0.5, test_size=0.5, stratify_by_column="label"
-)
-
-# Drop both the "_unknown_" and "_silence_" classes, keeping the ten keywords
-unknown_id = speech_commands_v1["train"].features["label"].names.index("_unknown_")
-silence_id = speech_commands_v1["train"].features["label"].names.index("_silence_")
-speech_commands_v1 = speech_commands_v1.filter(
-    lambda x: x["label"] not in (unknown_id, silence_id)
-)
-
-speech_commands_v1["train"] = speech_commands_v1["train"].select(
-    [i for i in range((len(speech_commands_v1["train"]) // BATCH_SIZE) * BATCH_SIZE)]
-)
-speech_commands_v1["test"] = speech_commands_v1["test"].select(
-    [i for i in range((len(speech_commands_v1["test"]) // BATCH_SIZE) * BATCH_SIZE)]
-)
-
-print(speech_commands_v1)
-```
-
-<div class="k-default-codeblock">
-``` -DatasetDict({ - train: Dataset({ - features: ['file', 'audio', 'label'], - num_rows: 896 - }) - test: Dataset({ - features: ['file', 'audio', 'label'], - num_rows: 896 - }) -}) - -``` -
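-
-We can verify that the filter worked as intended (a sketch, not part of the original
-example) by checking that only the ten keyword label IDs remain:
-
-
-```python
-print(sorted(set(speech_commands_v1["train"]["label"])))  # should be [0, 1, ..., 9]
-```
-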
-Additionally, you can check the actual labels corresponding to each label ID. - - -```python -labels = speech_commands_v1["train"].features["label"].names -label2id, id2label = dict(), dict() -for i, label in enumerate(labels): - label2id[label] = str(i) - id2label[str(i)] = label - -print(id2label) -``` - -
-``` -{'0': 'yes', '1': 'no', '2': 'up', '3': 'down', '4': 'left', '5': 'right', '6': 'on', '7': 'off', '8': 'stop', '9': 'go', '10': '_silence_', '11': '_unknown_'} - -``` -
-
-Before we can feed the audio utterance samples to our model, we need to
-pre-process them. This is done by a Hugging Face Transformers "Feature Extractor",
-which will re-sample your inputs to the sampling rate the model expects (in case
-they come with a different sampling rate), as well as generate the other inputs
-that the model requires.
-
-To do all of this, we instantiate our `Feature Extractor` with the
-`AutoFeatureExtractor.from_pretrained` method, which will ensure:
-
-- We get a `Feature Extractor` that corresponds to the model architecture we want to use.
-- We download the config that was used when pretraining this specific checkpoint.
-  This will be cached so that it's not downloaded again the next time we run the cell.
-
-The `from_pretrained()` method expects the name of a model on the Hugging Face Hub. This is
-exactly what `MODEL_CHECKPOINT` contains, so we just pass that.
-
-We write a simple function that helps us in the pre-processing and that is compatible
-with Hugging Face Datasets. To summarize, our pre-processing function should:
-
-- Call the audio column to load, and if necessary resample, the audio file.
-- Check that the sampling rate of the audio file matches the sampling rate of the
-audio data the model was pretrained with. You can find this information on the
-Wav2Vec 2.0 model card.
-- Set a maximum input length so that all inputs are padded or truncated to the
-same length.
-
-
-```python
-from transformers import AutoFeatureExtractor
-
-feature_extractor = AutoFeatureExtractor.from_pretrained(
-    MODEL_CHECKPOINT, return_attention_mask=True
-)
-
-
-def preprocess_function(examples):
-    audio_arrays = [x["array"] for x in examples["audio"]]
-    inputs = feature_extractor(
-        audio_arrays,
-        sampling_rate=feature_extractor.sampling_rate,
-        max_length=MAX_SEQ_LENGTH,
-        truncation=True,
-        padding=True,
-    )
-    return inputs
-
-
-# This line will pre-process our speech_commands_v1 dataset. We also remove the "audio"
-# and "file" columns as they will be of no use to us while training.
-processed_speech_commands_v1 = speech_commands_v1.map(
-    preprocess_function, remove_columns=["audio", "file"], batched=True
-)
-
-# Load the whole dataset splits as a dict of numpy arrays
-train = processed_speech_commands_v1["train"].shuffle(seed=42).with_format("numpy")[:]
-test = processed_speech_commands_v1["test"].shuffle(seed=42).with_format("numpy")[:]
-```
-
----
-## Defining the Wav2Vec 2.0 with Classification-Head
-
-We now define our model. To be precise, we define a Wav2Vec 2.0 model and add a
-Classification-Head on top to output a probability distribution over all classes for each
-input audio sample. Since the model might get complex, we first define the Wav2Vec
-2.0 model with the Classification-Head as a Keras layer and then build the model using that.
-
-We instantiate our main Wav2Vec 2.0 model using the `TFWav2Vec2Model` class. This will
-instantiate a model which will output 768 or 1024 dimensional embeddings according to
-the config you choose (BASE or LARGE). The `from_pretrained()` method additionally helps you
-load pre-trained weights from the Hugging Face Model Hub. It will download the pre-trained weights
-together with the config corresponding to the name of the model you mention when
-calling the method. For our task, we choose the BASE variant of the model, which has
-only been pre-trained (not fine-tuned), since we fine-tune it ourselves.
-
-
-```python
-from transformers import TFWav2Vec2Model
-
-
-def mean_pool(hidden_states, feature_lengths):
-    # Mask of shape (BATCH_SIZE, MAX_FRAMES) with ones marking valid frames.
-    attention_mask = tf.sequence_mask(
-        feature_lengths, maxlen=MAX_FRAMES, dtype=tf.dtypes.int64
-    )
-    # The reversed cumulative sum yields a boolean mask that is True for every
-    # frame up to (and including) the last valid frame of each sample.
-    padding_mask = tf.cast(
-        tf.reverse(tf.cumsum(tf.reverse(attention_mask, [-1]), -1), [-1]),
-        dtype=tf.dtypes.bool,
-    )
-    # Zero out the hidden states of padding frames so they don't contribute.
-    hidden_states = tf.where(
-        tf.broadcast_to(
-            tf.expand_dims(~padding_mask, -1), (BATCH_SIZE, MAX_FRAMES, HIDDEN_DIM)
-        ),
-        0.0,
-        hidden_states,
-    )
-    # Average over the valid frames only.
-    pooled_state = tf.math.reduce_sum(hidden_states, axis=1) / tf.reshape(
-        tf.math.reduce_sum(tf.cast(padding_mask, dtype=tf.dtypes.float32), axis=1),
-        [-1, 1],
-    )
-    return pooled_state
-
-
-class TFWav2Vec2ForAudioClassification(layers.Layer):
-    """Combines the Wav2Vec 2.0 encoder and a Classification-Head into an
-    end-to-end model for training."""
-
-    def __init__(self, model_checkpoint, num_classes):
-        super().__init__()
-        # Instantiate the Wav2Vec 2.0 model without the Classification-Head
-        self.wav2vec2 = TFWav2Vec2Model.from_pretrained(
-            model_checkpoint, apply_spec_augment=False, from_pt=True
-        )
-        self.pooling = layers.GlobalAveragePooling1D()
-        # Drop-out layer before the final Classification-Head
-        self.intermediate_layer_dropout = layers.Dropout(0.5)
-        # Classification-Head
-        self.final_layer = layers.Dense(num_classes, activation="softmax")
-
-    def call(self, inputs):
-        # We take only the first output in the returned dictionary, corresponding
-        # to the output of the last layer of Wav2Vec 2.0.
-        hidden_states = self.wav2vec2(inputs["input_values"])[0]
-
-        # If an attention mask exists, mean-pool only the unmasked output frames
-        if tf.is_tensor(inputs["attention_mask"]):
-            # Get the length of each audio input by summing up the attention_mask
-            # (attention_mask = (BATCH_SIZE x MAX_SEQ_LENGTH) ∈ {1,0})
-            audio_lengths = tf.cumsum(inputs["attention_mask"], -1)[:, -1]
-            # Get the number of Wav2Vec 2.0 output frames for each corresponding audio input
-            # length
-            feature_lengths = self.wav2vec2.wav2vec2._get_feat_extract_output_lengths(
-                audio_lengths
-            )
-            pooled_state = mean_pool(hidden_states, feature_lengths)
-        # If no attention mask exists, mean-pool all output frames
-        else:
-            pooled_state = self.pooling(hidden_states)
-
-        intermediate_state = self.intermediate_layer_dropout(pooled_state)
-        final_state = self.final_layer(intermediate_state)
-
-        return final_state
-
-```
-
----
-## Building and Compiling the model
-
-We now build and compile our model. We use the `SparseCategoricalCrossentropy` loss
-to train our model since this is a classification task. Following much of the
-literature, we evaluate our model on the `accuracy` metric.
-
-
-```python
-
-def build_model():
-    # Model's inputs
-    inputs = {
-        "input_values": tf.keras.Input(shape=(MAX_SEQ_LENGTH,), dtype="float32"),
-        "attention_mask": tf.keras.Input(shape=(MAX_SEQ_LENGTH,), dtype="int32"),
-    }
-    # Instantiate the Wav2Vec 2.0 model with Classification-Head using the desired
-    # pre-trained checkpoint
-    wav2vec2_model = TFWav2Vec2ForAudioClassification(MODEL_CHECKPOINT, NUM_CLASSES)(
-        inputs
-    )
-    # Model
-    model = tf.keras.Model(inputs, wav2vec2_model)
-    # Loss
-    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
-    # Optimizer
-    optimizer = keras.optimizers.Adam(learning_rate=1e-5)
-    # Compile and return
-    model.compile(loss=loss, optimizer=optimizer, metrics=["accuracy"])
-    return model
-
-
-model = build_model()
-```
-
----
-## Training the model
-
-Before we start training our model, we separate the model inputs from the
-target labels.
-
-
-```python
-# Remove targets from training dictionaries
-train_x = {x: y for x, y in train.items() if x != "label"}
-test_x = {x: y for x, y in test.items() if x != "label"}
-```
-
-And now we can finally start training our model.
-
-
-```python
-model.fit(
-    train_x,
-    train["label"],
-    validation_data=(test_x, test["label"]),
-    batch_size=BATCH_SIZE,
-    epochs=MAX_EPOCHS,
-)
-```
-
-<div class="k-default-codeblock">
-``` -Epoch 1/2 -28/28 [==============================] - 25s 338ms/step - loss: 2.3122 - accuracy: 0.1205 - val_loss: 2.2023 - val_accuracy: 0.2176 -Epoch 2/2 -28/28 [==============================] - 5s 189ms/step - loss: 2.0533 - accuracy: 0.2868 - val_loss: 1.8177 - val_accuracy: 0.5089 - - - -``` -
-
-Great! Now that we have trained our model, we predict the classes
-for audio samples in the test set using the `model.predict()` method! The
-predictions are not that great, as the model has been trained on a very small
-number of samples for only a couple of epochs. For best results, we recommend
-training on the complete dataset for at least 5 epochs!
-
-
-```python
-preds = model.predict(test_x)
-```
-
-<div class="k-default-codeblock">
-``` -28/28 [==============================] - 4s 44ms/step - -``` -
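-
-To put a number on "not that great", here is a quick sketch that computes the
-overall test accuracy from the predictions above (assuming `preds` and `test`
-as defined earlier):
-
-
-```python
-import numpy as np
-
-# `preds` has shape (num_samples, num_classes); pick the most likely class.
-predicted_ids = np.argmax(preds, axis=-1)
-accuracy = np.mean(predicted_ids == test["label"])
-print(f"Test accuracy: {accuracy:.2%}")
-```
-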
-
-Now we run the trained model on a randomly sampled audio file.
-We listen to the audio file and then see how well the model predicted!
-
-
-```python
-import IPython.display as ipd
-
-# Sample a random index into the test set (`test_x` is a dict of arrays, so we
-# index into one of its columns to get the number of samples).
-rand_int = random.randint(0, len(test_x["input_values"]) - 1)
-
-ipd.Audio(data=np.asarray(test_x["input_values"][rand_int]), autoplay=True, rate=16000)
-
-print("Original Label is ", id2label[str(test["label"][rand_int])])
-print("Predicted Label is ", id2label[str(np.argmax(preds[rand_int]))])
-```
-
-<div class="k-default-codeblock">
-``` -Original Label is up -Predicted Label is on - -``` -
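-
-Instead of a single hard label, we can also inspect the model's top-3 class
-probabilities for the same clip (a quick sketch using the softmax outputs
-computed above):
-
-
-```python
-import numpy as np
-
-# Indices of the three highest-probability classes, best first.
-top3 = np.argsort(preds[rand_int])[::-1][:3]
-for class_id in top3:
-    print(id2label[str(class_id)], f"(p={preds[rand_int][class_id]:.3f})")
-```
-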
-
-Now you can push this model to Hugging Face Model Hub and also share it with all your friends,
-family, favorite pets: they can all load it with the identifier
-`"your-username/the-name-you-picked"`, for instance:
-
-```python
-model.push_to_hub("wav2vec2-ks", organization="keras-io")
-feature_extractor.push_to_hub("wav2vec2-ks", organization="keras-io")
-```
-And after you push your model, this is how you can load it in the future:
-
-```python
-from transformers import TFWav2Vec2Model
-
-model = TFWav2Vec2Model.from_pretrained("your-username/my-awesome-model", from_pt=True)
-```
-
diff --git a/templates/examples/keras_rs/basic_ranking.md b/templates/examples/keras_rs/basic_ranking.md
deleted file mode 100644
index 2f11e2d353..0000000000
--- a/templates/examples/keras_rs/basic_ranking.md
+++ /dev/null
@@ -1,614 +0,0 @@
-# Recommending movies: ranking
-
-**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)<br>
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Rank movies using a two tower model. - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/basic_ranking.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/basic_ranking.py) - - - ---- -## Introduction - -Recommender systems are often composed of two stages: - -1. The retrieval stage is responsible for selecting an initial set of hundreds - of candidates from all possible candidates. The main objective of this model - is to efficiently weed out all candidates that the user is not interested in. - Because the retrieval model may be dealing with millions of candidates, it - has to be computationally efficient. -2. The ranking stage takes the outputs of the retrieval model and fine-tunes - them to select the best possible handful of recommendations. Its task is to - narrow down the set of items the user may be interested in to a shortlist of - likely candidates. - -In this tutorial, we're going to focus on the second stage, ranking. If you are -interested in the retrieval stage, have a look at our -[retrieval](/keras_rs/examples/basic_retrieval/) -tutorial. - -In this tutorial, we're going to: - -1. Get our data and split it into a training and test set. -2. Implement a ranking model. -3. Fit and evaluate it. -4. Test running predictions with the model. - -Let's begin by choosing JAX as the backend we want to run on, and import all -the necessary libraries. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import keras -import tensorflow as tf # Needed for the dataset -import tensorflow_datasets as tfds -``` - ---- -## Preparing the dataset - -We're going to use the same data as the -[retrieval](/keras_rs/examples/basic_retrieval/) -tutorial. The ratings are the objectives we are trying to predict. - - -```python -# Ratings data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") -``` - -
-
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-
-Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-
-Dl Completed...: 0 url [00:00, ? url/s]
-
-Dl Size...: 0 MiB [00:00, ? MiB/s]
-
-Extraction completed...: 0 file [00:00, ? file/s]
-
-Generating splits...: 0%|          | 0/1 [00:00
-```
-</div>
-
-In the Movielens dataset, user IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
-user id directly as an index in our model, in particular to look up the user
-embedding from the user embedding table. So we need to know the number of users.
-
-
-```python
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-In the Movielens dataset, movie IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
-movie id directly as an index in our model, in particular to look up the movie
-embedding from the movie embedding table. So we need to know the number of
-movies.
-
-
-```python
-movies_count = movies.cardinality().numpy()
-```
-
-The inputs to the model are the user IDs and movie IDs, and the labels are the
-ratings.
-
-
-```python
-
-def preprocess_rating(x):
-    return (
-        # Inputs are user IDs and movie IDs
-        {
-            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-        },
-        # Labels are ratings between 0 and 1.
-        (x["user_rating"] - 1.0) / 4.0,
-    )
-
-```
-
-We'll split the data by putting 80% of the ratings in the train set, and 20% in
-the test set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
----
-## Implementing the Model
-
-### Architecture
-
-Ranking models do not face the same efficiency constraints as retrieval models
-do, and so we have a little bit more freedom in our choice of architectures.
-
-A model composed of multiple stacked dense layers is a relatively common
-architecture for ranking tasks. We can implement it as follows:
-
-
-```python
-
-class RankingModel(keras.Model):
-    """Create the ranking model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Embedding table for users.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Embedding table for candidates.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # Predictions.
-        self.ratings = keras.Sequential(
-            [
-                # Learn multiple dense layers.
- keras.layers.Dense(256, activation="relu"), - keras.layers.Dense(64, activation="relu"), - # Make rating predictions in the final layer. - keras.layers.Dense(1), - ] - ) - - def call(self, inputs): - user_id, movie_id = inputs["user_id"], inputs["movie_id"] - user_embeddings = self.user_embedding(user_id) - candidate_embeddings = self.candidate_embedding(movie_id) - return self.ratings( - keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1) - ) - -``` - -Let's first instantiate the model. Note that we add `+ 1` to the number of users -and movies to account for the fact that id zero is not used for either (IDs -start at 1), but still takes a row in the embedding tables. - - -```python -model = RankingModel(users_count + 1, movies_count + 1) -``` - -### Loss and metrics - -The next component is the loss used to train our model. Keras has several losses -to make this easy. In this instance, we'll make use of the `MeanSquaredError` -loss in order to predict the ratings. We'll also look at the -`RootMeanSquaredError` metric. - - -```python -model.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[keras.metrics.RootMeanSquaredError()], - optimizer=keras.optimizers.Adagrad(learning_rate=0.1), -) -``` - ---- -## Fitting and evaluating - -After defining the model, we can use the standard Keras `model.fit()` to train -the model. - - -```python -model.fit(train_ratings, epochs=5) -``` - -
-```
-Epoch 1/5
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 10ms/step - loss: 0.1058 - root_mean_squared_error: 0.3200
-Epoch 2/5
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0774 - root_mean_squared_error: 0.2783
-Epoch 3/5
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0759 - root_mean_squared_error: 0.2754
-Epoch 4/5
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0735 - root_mean_squared_error: 0.2711
-Epoch 5/5
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0703 - root_mean_squared_error: 0.2651
-```
-</div>
-
-As the model trains, the loss is falling and the RMSE metric is improving.
-
-Finally, we can evaluate our model on the test set. The lower the RMSE metric,
-the more accurate our model is at predicting ratings.
-
-
-```python
-model.evaluate(test_ratings, return_dict=True)
-```
-
-<div class="k-default-codeblock">
-```
- 20/20 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - loss: 0.0707 - root_mean_squared_error: 0.2658
-```
-</div>
-<div class="k-default-codeblock">
-``` -{'loss': 0.0712985172867775, 'root_mean_squared_error': 0.26701781153678894} - -``` -
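-
-Since the ratings were rescaled from the 1-5 star range to `[0, 1]` before
-training, a quick back-of-the-envelope conversion (a sketch using the RMSE
-value printed above) expresses this error in stars:
-
-
-```python
-eval_rmse = 0.2670  # `root_mean_squared_error` reported by `model.evaluate`.
-print(f"Typical rating error: {eval_rmse * 4.0:.2f} stars")  # ~1.07 stars
-```
-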
---
-## Testing the ranking model
-
-So far, we have only handled movies by id. Now is the time to create a mapping
-keyed by movie IDs to be able to surface the titles.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-Now we can test the ranking model by computing predictions for a set of movies
-and then ranking these movies based on the predictions (the short sketch after
-the output below performs the actual sort):
-
-
-```python
-user_id = 42
-movie_ids = [204, 141, 131]
-predictions = model.predict(
-    {
-        "user_id": keras.ops.array([user_id] * len(movie_ids)),
-        "movie_id": keras.ops.array(movie_ids),
-    }
-)
-predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1))
-
-for movie_id, prediction in zip(movie_ids, predictions):
-    print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")
-```
-
-<div class="k-default-codeblock">
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 273ms/step
-```
-</div>
-<div class="k-default-codeblock">
-``` -b'Back to the Future (1985)': 3.86 -b'20,000 Leagues Under the Sea (1954)': 3.93 -b"Breakfast at Tiffany's (1961)": 3.72 - -``` -
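-
-The printed scores are in the order the movie IDs were given. Here is a minimal
-sketch of the ranking step itself, ordering the candidates by predicted rating,
-best first (it reuses `movie_ids`, `predictions` and `movie_id_to_movie_title`
-from above):
-
-
-```python
-# Sort (movie_id, score) pairs by descending predicted score.
-ranked = sorted(zip(movie_ids, predictions), key=lambda pair: pair[1], reverse=True)
-for movie_id, score in ranked:
-    print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * score:,.2f}")
-```
-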
diff --git a/templates/examples/keras_rs/basic_retrieval.md b/templates/examples/keras_rs/basic_retrieval.md deleted file mode 100644 index f8e96c8393..0000000000 --- a/templates/examples/keras_rs/basic_retrieval.md +++ /dev/null @@ -1,2170 +0,0 @@ -# Recommending movies: retrieval - -**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Retrieve movies using a two tower model. - - -
ⓘ This example uses Keras 3</div>
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/basic_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/basic_retrieval.py)
-
-
-
----
-## Introduction
-
-Recommender systems are often composed of two stages:
-
-1. The retrieval stage is responsible for selecting an initial set of hundreds
-   of candidates from all possible candidates. The main objective of this model
-   is to efficiently weed out all candidates that the user is not interested in.
-   Because the retrieval model may be dealing with millions of candidates, it
-   has to be computationally efficient.
-2. The ranking stage takes the outputs of the retrieval model and fine-tunes
-   them to select the best possible handful of recommendations. Its task is to
-   narrow down the set of items the user may be interested in to a shortlist of
-   likely candidates.
-
-In this tutorial, we're going to focus on the first stage, retrieval. If you are
-interested in the ranking stage, have a look at our
-[ranking](/keras_rs/examples/basic_ranking/) tutorial.
-
-Retrieval models are often composed of two sub-models:
-
-1. A query tower computing the query representation (normally a
-   fixed-dimensionality embedding vector) using query features.
-2. A candidate tower computing the candidate representation (an equally-sized
-   vector) using the candidate features. The outputs of the two models are then
-   multiplied together to give a query-candidate affinity score, with higher
-   scores expressing a better match between the candidate and the query.
-
-In this tutorial, we're going to build and train such a two-tower model using
-the Movielens dataset.
-
-We're going to:
-
-1. Get our data and split it into a training and test set.
-2. Implement a retrieval model.
-3. Fit and evaluate it.
-4. Test running predictions with the model.
-
-### The dataset
-
-The Movielens dataset is a classic dataset from the
-[GroupLens](https://grouplens.org/datasets/movielens/) research group at the
-University of Minnesota. It contains a set of ratings given to movies by a set
-of users, and is a standard for recommender systems research.
-
-The data can be treated in two ways:
-
-1. It can be interpreted as expressing which movies the users watched (and
-   rated), and which they did not. This is a form of implicit feedback, where
-   users' watches tell us which things they prefer to see and which they'd
-   rather not see.
-2. It can also be seen as expressing how much the users liked the movies they
-   did watch. This is a form of explicit feedback: given that a user watched a
-   movie, we can tell how much they liked it by looking at the rating they have
-   given.
-
-In this tutorial, we are focusing on a retrieval system: a model that predicts a
-set of movies from the catalogue that the user is likely to watch. For this, the
-model will try to predict the rating users would give to all the movies in the
-catalogue. We will therefore use the explicit rating data.
-
-Let's begin by choosing JAX as the backend we want to run on, and import all
-the necessary libraries.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-
-import keras_rs
-```
-
----
-## Preparing the dataset
-
-Let's first have a look at the data.
-
-We use the MovieLens dataset from
-[Tensorflow Datasets](https://www.tensorflow.org/datasets).
Loading
-`movielens/100k-ratings` yields a `tf.data.Dataset` object containing the
-ratings alongside user and movie data. Loading `movielens/100k-movies` yields a
-`tf.data.Dataset` object containing only the movies data.
-
-Note that since the MovieLens dataset does not have predefined splits, all the
-data is under the `train` split.
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train")
-# Features of all the available movies.
-movies = tfds.load("movielens/100k-movies", split="train")
-```
-
-The ratings dataset returns a dictionary of movie id, user id, the assigned
-rating, timestamp, movie information, and user information:
-
-
-```python
-for data in ratings.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-<div class="k-default-codeblock">
-``` -{'bucketized_user_age': np.float32(45.0), - 'movie_genres': array([7]), - 'movie_id': b'357', - 'movie_title': b"One Flew Over the Cuckoo's Nest (1975)", - 'raw_user_age': np.float32(46.0), - 'timestamp': np.int64(879024327), - 'user_gender': np.True_, - 'user_id': b'138', - 'user_occupation_label': np.int64(4), - 'user_occupation_text': b'doctor', - 'user_rating': np.float32(4.0), - 'user_zip_code': b'53211'} - -``` -
-
-In the Movielens dataset, user IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
-user id directly as an index in our model, in particular to look up the user
-embedding from the user embedding table. So we need to know the number of users.
-
-
-```python
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-The movies dataset contains the movie id, movie title, and the genres it belongs
-to. Note that the genres are encoded with integer labels.
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-<div class="k-default-codeblock">
-``` -{'movie_genres': array([4]), - 'movie_id': b'1681', - 'movie_title': b'You So Crazy (1994)'} - -``` -
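-
-The integer genre labels can be decoded to human-readable names through the
-dataset metadata. This is a side note (the guide itself does not use genres);
-the sketch below assumes the standard TFDS `with_info=True` idiom:
-
-
-```python
-import tensorflow_datasets as tfds
-
-# Re-load the movies split together with its metadata (cached, so cheap).
-_, movies_info = tfds.load("movielens/100k-movies", split="train", with_info=True)
-genre_names = movies_info.features["movie_genres"].feature.names
-print(genre_names[4])  # Genre ID 4 from the record printed above.
-```
-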
-
-In the Movielens dataset, movie IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
-movie id directly as an index in our model, in particular to look up the movie
-embedding from the movie embedding table. So we need to know the number of
-movies.
-
-
-```python
-movies_count = movies.cardinality().numpy()
-```
-
-In this example, we're going to focus on the ratings data. Other tutorials
-explore how to use the movie information data as well as the user information to
-improve the model quality.
-
-We keep only the `user_id`, `movie_id` and `rating` fields in the dataset. Our
-input is the `user_id`. The labels are the `movie_id` alongside the `rating` for
-the given movie and user.
-
-The `rating` is a number between 1 and 5; we rescale it to be between 0 and 1.
-
-
-```python
-
-def preprocess_rating(x):
-    return (
-        # Input is the user IDs
-        tf.strings.to_number(x["user_id"], out_type=tf.int32),
-        # Labels are movie IDs + ratings between 0 and 1.
-        {
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-            "rating": (x["user_rating"] - 1.0) / 4.0,
-        },
-    )
-
-```
-
-To fit and evaluate the model, we need to split the data into a training and
-evaluation set. In a real recommender system, this would most likely be done by
-time: the data up to time *T* would be used to predict interactions after *T*.
-
-In this simple example, however, let's use a random split, putting 80% of the
-ratings in the train set, and 20% in the test set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
----
-## Implementing the Model
-
-Choosing the architecture of our model is a key part of modelling.
-
-We are building a two-tower retrieval model; therefore, we need to combine a
-query tower for users and a candidate tower for movies.
-
-The first step is to decide on the dimensionality of the query and candidate
-representations. This is the `embedding_dimension` argument in our model
-constructor. We'll test with a value of `32`. Higher values will correspond to
-models that may be more accurate, but will also be slower to fit and more prone
-to overfitting.
-
-### Query and Candidate Towers
-
-The second step is to define the model itself. In this simple example, the query
-tower and candidate tower are simply embeddings with nothing else. We'll use
-Keras' `Embedding` layer.
-
-We can easily extend the towers to make them arbitrarily complex using standard
-Keras components, as long as we return an `embedding_dimension`-wide output at
-the end.
-
-### Retrieval
-
-The retrieval itself will be performed by the `BruteForceRetrieval` layer from
-Keras Recommenders. This layer computes the affinity scores for the given users
-and all the candidate movies, then returns the top K in order.
-
-Note that during training, we don't actually need to perform any retrieval since
-the only affinity scores we need are the ones for the users and movies in the
-batch. As an optimization, we skip the retrieval entirely in the `call` method.
-
-### Loss
-
-The next component is the loss used to train our model.
In this case, we use a
-mean squared error loss to measure the difference between the predicted movie
-ratings and the actual ratings from users.
-
-Note that we override `compute_loss` from the `keras.Model` class. This allows
-us to compute the query-candidate affinity score, which is obtained by
-multiplying the outputs of the two towers together. That affinity score can then
-be passed to the loss function.
-
-
-```python
-
-class RetrievalModel(keras.Model):
-    """Create the retrieval model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Our candidate tower, simply an embedding table.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
-        self.loss_fn = keras.losses.MeanSquaredError()
-
-    def build(self, input_shape):
-        self.user_embedding.build(input_shape)
-        self.candidate_embedding.build(input_shape)
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        user_embeddings = self.user_embedding(inputs)
-        result = {
-            "user_embeddings": user_embeddings,
-        }
-        if not training:
-            # Retrieve top movies only at inference time; during training the
-            # predictions are not used, so we skip the retrieval entirely.
-            result["predictions"] = self.retrieval(user_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id, rating = y["movie_id"], y["rating"]
-        user_embeddings = y_pred["user_embeddings"]
-        candidate_embeddings = self.candidate_embedding(candidate_id)
-
-        labels = keras.ops.expand_dims(rating, -1)
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.sum(
-            keras.ops.multiply(user_embeddings, candidate_embeddings),
-            axis=1,
-            keepdims=True,
-        )
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
----
-## Fitting and evaluating
-
-After defining the model, we can use the standard Keras `model.fit()` to train
-and evaluate the model.
-
-Let's first instantiate the model. Note that we add `+ 1` to the number of users
-and movies to account for the fact that id zero is not used for either (IDs
-start at 1), but still takes a row in the embedding tables.
-
-
-```python
-model = RetrievalModel(users_count + 1, movies_count + 1)
-model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))
-```
-
-Then train the model. Evaluation takes a bit of time, so we only evaluate the
-model every 5 epochs.
-
-
-```python
-history = model.fit(
-    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
-)
-```
-
-<div class="k-default-codeblock">
-```
-Epoch 1/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 0.4772
-Epoch 5/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4771 - val_loss: 0.4836
-Epoch 10/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4836
-Epoch 15/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767 - val_loss: 0.4835
-Epoch 20/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765 - val_loss: 0.4834
-Epoch 25/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4762 - val_loss: 0.4832
-Epoch 30/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4756 - val_loss: 0.4828
-Epoch 35/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4748 - val_loss: 0.4821
-Epoch 40/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4734 - val_loss: 0.4807
-Epoch 45/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4709 - val_loss: 0.4783
-Epoch 50/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4667 - val_loss: 0.4739
-```
-</div>
-
----
-## Making predictions
-
-Now that we have a model, we would like to be able to make predictions.
-
-So far, we have only handled movies by id. Now is the time to create a mapping
-keyed by movie IDs to be able to surface the titles.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies already watched by the user. We could
-easily add logic to remove them if that is desirable; a sketch of such a filter
-follows the output below.
-
-
-```python
-user_id = 42
-predictions = model.predict(keras.ops.convert_to_tensor([user_id]))
-predictions = keras.ops.convert_to_numpy(predictions["predictions"])
-
-print(f"Recommended movies for user {user_id}:")
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-<div class="k-default-codeblock">
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 105ms/step
-```
-</div>
-<div class="k-default-codeblock">
-``` -Recommended movies for user 42: -b'Raiders of the Lost Ark (1981)' -b'Godfather, The (1972)' -b'Star Trek: The Wrath of Khan (1982)' -b'Indiana Jones and the Last Crusade (1989)' -b'Birdcage, The (1996)' -b'Silence of the Lambs, The (1991)' -b'Blade Runner (1982)' -b'Aliens (1986)' -b'Contact (1997)' -b'Star Wars (1977)' - -``` -
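-
-As mentioned above, retrieved movies may include ones the user has already
-watched. Here is a minimal sketch of a post-filter, assuming the `ratings`
-dataset, `predictions` and `movie_id_to_movie_title` from the cells above:
-
-
-```python
-# Collect the IDs of movies this user has already rated (hence watched).
-watched = {
-    int(x["movie_id"].decode())
-    for x in ratings.as_numpy_iterator()
-    if int(x["user_id"].decode()) == user_id
-}
-
-print(f"Unseen recommendations for user {user_id}:")
-for movie_id in predictions[0]:
-    if int(movie_id) not in watched:
-        print(movie_id_to_movie_title[movie_id])
-```
-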
---- -## Item-to-item recommendation - -In this model, we created a user-movie model. However, for some applications -(for example, product detail pages) it's common to perform item-to-item (for -example, movie-to-movie or product-to-product) recommendations. - -Training models like this would follow the same pattern as shown in this -tutorial, but with different training data. Here, we had a user and a movie -tower, and used (user, movie) pairs to train them. In an item-to-item model, we -would have two item towers (for the query and candidate item), and train the -model using (query item, candidate item) pairs. These could be constructed from -clicks on product detail pages. - diff --git a/templates/examples/keras_rs/data_parallel_retrieval.md b/templates/examples/keras_rs/data_parallel_retrieval.md deleted file mode 100644 index 36ecb2d692..0000000000 --- a/templates/examples/keras_rs/data_parallel_retrieval.md +++ /dev/null @@ -1,4222 +0,0 @@ -# Retrieval with data parallel training - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Retrieve movies using a two tower model (data parallel training). - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/data_parallel_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/data_parallel_retrieval.py) - - - ---- -## Introduction - -In this tutorial, we are going to train the exact same retrieval model as we -did in our -[basic retrieval](/keras_rs/examples/basic_retrieval/) -tutorial, but in a distributed way. - -Distributed training is used to train models on multiple devices or machines -simultaneously, thereby reducing training time. Here, we focus on synchronous -data parallel training. Each accelerator (GPU/TPU) holds a complete replica -of the model, and sees a different mini-batch of the input data. Local gradients -are computed on each device, aggregated and used to compute a global gradient -update. - -Before we begin, let's note down a few things: - -1. The number of accelerators should be greater than 1. -2. The `keras.distribution` API works only with JAX. So, make sure you select - JAX as your backend! - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" - -import random - -import jax -import keras -import tensorflow as tf # Needed only for the dataset -import tensorflow_datasets as tfds - -import keras_rs -``` - ---- -## Data Parallel - -For the synchronous data parallelism strategy in distributed training, -we will use the `DataParallel` class present in the `keras.distribution` -API. - - -```python -devices = jax.devices() # Assume it has >1 local devices. -data_parallel = keras.distribution.DataParallel(devices=devices) -``` - -Alternatively, you can choose to create the `DataParallel` object -using a 1D `DeviceMesh` object, like so: - -``` -mesh_1d = keras.distribution.DeviceMesh( - shape=(len(devices),), axis_names=["data"], devices=devices -) -data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d) -``` - - -```python -# Set the global distribution strategy. -keras.distribution.set_distribution(data_parallel) -``` - ---- -## Preparing the dataset - -Now that we are done defining the global distribution -strategy, the rest of the guide looks exactly the same -as the previous basic retrieval guide. - -Let's load and prepare the dataset. Here too, we use the -MovieLens dataset. - - -```python -# Ratings data with user and movie data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") - -# User, movie counts for defining vocabularies. -users_count = ( - ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) - .reduce(tf.constant(0, tf.int32), tf.maximum) - .numpy() -) -movies_count = movies.cardinality().numpy() - - -# Preprocess dataset, and split it into train-test datasets. -def preprocess_rating(x): - return ( - # Input is the user IDs - tf.strings.to_number(x["user_id"], out_type=tf.int32), - # Labels are movie IDs + ratings between 0 and 1. - { - "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), - "rating": (x["user_rating"] - 1.0) / 4.0, - }, - ) - - -shuffled_ratings = ratings.map(preprocess_rating).shuffle( - 100_000, seed=42, reshuffle_each_iteration=False -) -train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() -test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() -``` - -
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-
-Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-```
-
----
-## Implementing the Model
-
-We build a two-tower retrieval model, combining a query tower for users and a
-candidate tower for movies. Note that we don't have to change anything here
-compared to the previous basic retrieval tutorial.
-
-
-```python
-
-class RetrievalModel(keras.Model):
-    """Create the retrieval model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Our candidate tower, simply an embedding table.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
-        self.loss_fn = keras.losses.MeanSquaredError()
-
-    def build(self, input_shape):
-        self.user_embedding.build(input_shape)
-        self.candidate_embedding.build(input_shape)
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        user_embeddings = self.user_embedding(inputs)
-        result = {
-            "user_embeddings": user_embeddings,
-        }
-        if not training:
-            # Skip the retrieval of top movies during training as the
-            # predictions are not used.
-            result["predictions"] = self.retrieval(user_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id, rating = y["movie_id"], y["rating"]
-        user_embeddings = y_pred["user_embeddings"]
-        candidate_embeddings = self.candidate_embedding(candidate_id)
-
-        labels = keras.ops.expand_dims(rating, -1)
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.sum(
-            keras.ops.multiply(user_embeddings, candidate_embeddings),
-            axis=1,
-            keepdims=True,
-        )
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
----
-## Fitting and evaluating
-
-After defining the model, we can use the standard Keras `model.fit()` to train
-and evaluate the model.
-
-
-```python
-model = RetrievalModel(users_count + 1, movies_count + 1)
-model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))
-```
-
-Let's train the model. Evaluation takes a bit of time, so we only evaluate the
-model every 5 epochs.
-
-
-```python
-history = model.fit(
-    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
-)
-```
-
-```
-Epoch 1/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 8ms/step - loss: 0.4773
-Epoch 5/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4770 - val_loss: 0.4835
-Epoch 10/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4765 - val_loss: 0.4832
-Epoch 15/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4754 - val_loss: 0.4824
-Epoch 20/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4726 - val_loss: 0.4795
-Epoch 25/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4642 - val_loss: 0.4701
-Epoch 30/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4401 - val_loss: 0.4427
-Epoch 35/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.3817 - val_loss: 0.3787
-Epoch 40/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2896 - val_loss: 0.2856
-Epoch 45/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2098 - val_loss: 0.2106
-Epoch 50/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1620 - val_loss: 0.1660
-```
-
----
-## Making predictions
-
-Now that we have a model, let's run inference and make predictions.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-
-```python
-user_ids = random.sample(range(1, 1001), len(devices))
-predictions = model.predict(keras.ops.convert_to_tensor(user_ids))
-predictions = keras.ops.convert_to_numpy(predictions["predictions"])
-
-for i, user_id in enumerate(user_ids):
-    print(f"\n==Recommended movies for user {user_id}==")
-    for movie_id in predictions[i]:
-        print(movie_id_to_movie_title[movie_id])
-```
-
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 205ms/step
-```
-
-``` -==Recommended movies for user 449== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Silence of the Lambs, The (1991)' -b'Shawshank Redemption, The (1994)' -b'Pulp Fiction (1994)' -b'Raiders of the Lost Ark (1981)' -b"Schindler's List (1993)" -b'Blade Runner (1982)' -b"One Flew Over the Cuckoo's Nest (1975)" -b'Casablanca (1942)' -``` -
- -
-``` -==Recommended movies for user 681== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 151== -b'Princess Bride, The (1987)' -b'Pulp Fiction (1994)' -b'English Patient, The (1996)' -b'Alien (1979)' -b'Raiders of the Lost Ark (1981)' -b'Willy Wonka and the Chocolate Factory (1971)' -b'Amadeus (1984)' -b'Liar Liar (1997)' -b'Psycho (1960)' -b"It's a Wonderful Life (1946)" -``` -
- -
-``` -==Recommended movies for user 442== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b'Empire Strikes Back, The (1980)' -b"Schindler's List (1993)" -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 134== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Pulp Fiction (1994)' -b'Return of the Jedi (1983)' -b'Empire Strikes Back, The (1980)' -b'Twelve Monkeys (1995)' -b'Contact (1997)' -``` -
- -
-``` -==Recommended movies for user 853== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Raiders of the Lost Ark (1981)' -b'Silence of the Lambs, The (1991)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 707== -b'Star Wars (1977)' -b'Raiders of the Lost Ark (1981)' -b'Toy Story (1995)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Return of the Jedi (1983)' -b'Terminator, The (1984)' -b'Princess Bride, The (1987)' -``` -
- -
-``` -==Recommended movies for user 511== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Raiders of the Lost Ark (1981)' -b'Silence of the Lambs, The (1991)' -b'Return of the Jedi (1983)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Pulp Fiction (1994)' -b'Shawshank Redemption, The (1994)' - -``` -
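-To recap, the distribution-specific code from this tutorial boils down to the
-following lines, gathered here in one place from the setup above:
-
-
-```python
-# Pick up all local accelerators, wrap them in a `DataParallel`
-# distribution, and set it as the global distribution *before* building
-# the model. Everything else in the tutorial is unchanged.
-import jax
-import keras
-
-devices = jax.devices()
-data_parallel = keras.distribution.DataParallel(devices=devices)
-keras.distribution.set_distribution(data_parallel)
-```
-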
-And we're done! For data parallel training, all we had to do was add ~3-5 LoC. -The rest is exactly the same. - diff --git a/templates/examples/keras_rs/dcn.md b/templates/examples/keras_rs/dcn.md deleted file mode 100644 index 6f53f2160e..0000000000 --- a/templates/examples/keras_rs/dcn.md +++ /dev/null @@ -1,678 +0,0 @@ -# Ranking with Deep and Cross Networks - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Rank movies using Deep and Cross Networks (DCN). - - -
ⓘ This example uses Keras 3
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/dcn.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/dcn.py) - - - ---- -## Introduction - -This tutorial demonstrates how to use Deep & Cross Networks (DCN) to effectively -learn feature crosses. Before diving into the example, let's briefly discuss -feature crosses. - -Imagine that we are building a recommender system for blenders. Individual -features might include a customer's past purchase history (e.g., -`purchased_bananas`, `purchased_cooking_books`) or geographic location. However, -a customer who has purchased both bananas and cooking books is more likely to be -interested in a blender than someone who purchased only one or the other. The -combination of `purchased_bananas` and `purchased_cooking_books` is a feature -cross. Feature crosses capture interaction information between individual -features, providing richer context than the individual features alone. - -![Why are feature crosses important?](https://i.imgur.com/qDK6UZh.gif) - -Learning effective feature crosses presents several challenges. In web-scale -applications, data is often categorical, resulting in high-dimensional and -sparse feature spaces. Identifying impactful feature crosses in such -environments typically relies on manual feature engineering or computationally -expensive exhaustive searches. While traditional feed-forward multilayer -perceptrons (MLPs) are universal function approximators, they often struggle to -efficiently learn even second- or third-order feature interactions. - -The Deep & Cross Network (DCN) architecture is designed for more effective -learning of explicit and bounded-degree feature crosses. It comprises three main -components: an input layer (typically an embedding layer), a cross network for -modeling explicit feature interactions, and a deep network for capturing -implicit interactions. - -The cross network is the core of the DCN. It explicitly performs feature -crossing at each layer, with the highest polynomial degree of feature -interaction increasing with depth. The following figure shows the `(i+1)`-th -cross layer. - -![Feature Cross Layer](https://i.imgur.com/ip5uRsl.png) - -The deep network is a standard feedforward multilayer perceptron -(MLP). These two networks are then combined to form the DCN. Two common -combination strategies exist: a stacked structure, where the deep network is -placed on top of the cross network, and a parallel structure, where they -operate in parallel. - - - - - - -
-*Figure: the two combination strategies, parallel layers (left) and stacked
-layers (right).*
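-
-As a point of reference before we write any code, a single cross layer can be
-sketched in a few lines. This follows the commonly used DCN-v2 formulation,
-`x_{i+1} = x_0 * (W @ x_i + b) + x_i`. The snippet is an illustration only
-(the `cross_layer` helper and its variable names are ours); the example itself
-relies on `keras_rs.layers.FeatureCross` for the real implementation:
-
-
-```python
-import numpy as np
-
-
-def cross_layer(x0, xi, w, b):
-    # x0: the original input features, shape (batch, dim).
-    # xi: the output of the previous cross layer, shape (batch, dim).
-    # The element-wise product with x0 is what creates explicit feature
-    # crosses; the residual `+ xi` preserves the lower-order terms.
-    return x0 * (xi @ w + b) + xi
-
-
-dim = 3
-x0 = np.random.rand(8, dim)
-w = np.random.rand(dim, dim)
-b = np.zeros(dim)
-x1 = cross_layer(x0, x0, w, b)  # The first cross layer takes xi = x0.
-print(x1.shape)  # (8, 3)
-```
-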
-
-Now that we know a little bit about DCN, let's start writing some code. We will
-first train a DCN on a toy dataset, and demonstrate that the model has indeed
-learnt important feature crosses.
-
-Let's set the backend to JAX, and get our imports sorted.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-import keras_rs
-```
-
-Let's also define variables which will be reused throughout the example.
-
-
-```python
-TOY_CONFIG = {
-    "learning_rate": 0.01,
-    "num_epochs": 100,
-    "batch_size": 1024,
-}
-
-MOVIELENS_CONFIG = {
-    # features
-    "int_features": [
-        "movie_id",
-        "user_id",
-        "user_gender",
-        "bucketized_user_age",
-    ],
-    "str_features": [
-        "user_zip_code",
-        "user_occupation_text",
-    ],
-    # model
-    "embedding_dim": 32,
-    "deep_net_num_units": [192, 192, 192],
-    "projection_dim": 20,
-    "dcn_num_units": [192, 192],
-    # training
-    "learning_rate": 0.01,
-    "num_epochs": 10,
-    "batch_size": 1024,
-}
-
-LOOKUP_LAYERS = {
-    "int": keras.layers.IntegerLookup,
-    "str": keras.layers.StringLookup,
-}
-```
-
-Here, we define a helper function for visualising the weights of the cross
-layer, to better understand how it works. We also define a function for
-compiling, training and evaluating a given model.
-
-
-```python
-
-def visualize_layer(matrix, features):
-    plt.figure(figsize=(9, 9))
-
-    # Plot absolute weights so that strong positive and negative
-    # interactions both show up clearly.
-    im = plt.matshow(np.abs(matrix), cmap=plt.cm.Blues)
-
-    ax = plt.gca()
-    divider = make_axes_locatable(plt.gca())
-    cax = divider.append_axes("right", size="5%", pad=0.05)
-    plt.colorbar(im, cax=cax)
-    cax.tick_params(labelsize=10)
-    ax.set_xticklabels([""] + features, rotation=45, fontsize=10)
-    ax.set_yticklabels([""] + features, fontsize=10)
-
-
-def train_and_evaluate(
-    learning_rate,
-    epochs,
-    train_data,
-    test_data,
-    model,
-):
-    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate)
-    loss = keras.losses.MeanSquaredError()
-    rmse = keras.metrics.RootMeanSquaredError()
-
-    model.compile(
-        optimizer=optimizer,
-        loss=loss,
-        metrics=[rmse],
-    )
-
-    model.fit(
-        train_data,
-        epochs=epochs,
-        verbose=0,
-    )
-
-    results = model.evaluate(test_data, return_dict=True, verbose=0)
-    rmse_value = results["root_mean_squared_error"]
-
-    return rmse_value, model.count_params()
-
-
-def print_stats(rmse_list, num_params, model_name):
-    # Report metrics.
-    num_trials = len(rmse_list)
-    avg_rmse = np.mean(rmse_list)
-    std_rmse = np.std(rmse_list)
-
-    if num_trials == 1:
-        print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
-    else:
-        print(
-            f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; "
-            f"#params = {num_params}"
-        )
-
-```
-
----
-## Toy Example
-
-To illustrate the benefits of DCNs, let's consider a simple example. Suppose we
-have a dataset for modeling the likelihood of a customer clicking on a blender
-advertisement.
The features and label are defined as follows:
-
-| **Features / Label** | **Description**                | **Range**|
-|:--------------------:|:------------------------------:|:--------:|
-| `x1` = country       | Customer's resident country    | [0, 199] |
-| `x2` = bananas       | # bananas purchased            | [0, 23]  |
-| `x3` = cookbooks     | # cooking books purchased      | [0, 5]   |
-| `y`                  | Blender ad click likelihood    | -        |
-
-We then generate data according to the following underlying distribution:
-`y = f(x1, x2, x3) = 0.1x1 + 0.4x2 + 0.7x3 + 0.1x1x2 + 3.1x2x3 + 0.1x3^2`.
-
-This distribution shows that the click likelihood (`y`) depends linearly on
-individual features (`xi`) and on multiplicative interactions between them. In
-this scenario, the likelihood of purchasing a blender (`y`) is influenced not
-only by purchasing bananas (`x2`) or cookbooks (`x3`) individually, but also
-significantly by the interaction of purchasing both bananas and cookbooks
-(`x2x3`).
-
-### Preparing the dataset
-
-Let's create synthetic data based on the above equation, and form the
-train-test splits.
-
-
-```python
-
-def get_mixer_data(data_size=100_000):
-    country = np.random.randint(200, size=[data_size, 1]) / 200.0
-    bananas = np.random.randint(24, size=[data_size, 1]) / 24.0
-    cookbooks = np.random.randint(6, size=[data_size, 1]) / 6.0
-
-    x = np.concatenate([country, bananas, cookbooks], axis=1)
-
-    # Create 1st-order terms.
-    y = 0.1 * country + 0.4 * bananas + 0.7 * cookbooks
-
-    # Create 2nd-order cross terms.
-    y += (
-        0.1 * country * bananas
-        + 3.1 * bananas * cookbooks
-        + (0.1 * cookbooks * cookbooks)
-    )
-
-    return x, y
-
-
-x, y = get_mixer_data(data_size=100_000)
-num_train = 90_000
-train_x = x[:num_train]
-train_y = y[:num_train]
-test_x = x[num_train:]
-test_y = y[num_train:]
-```
-
-### Building the model
-
-To demonstrate the advantages of a cross network in recommender systems, we'll
-compare its performance with a deep network. Since our example data only
-contains second-order feature interactions, a single-layered cross network will
-suffice. For datasets with higher-order interactions, multiple cross layers can
-be stacked to form a multi-layered cross network. We will build two models:
-
-1. A cross network with a single cross layer.
-2. A deep network with wider and deeper feedforward layers.
-
-
-```python
-cross_network = keras.Sequential(
-    [
-        keras_rs.layers.FeatureCross(),
-        keras.layers.Dense(1),
-    ]
-)
-
-deep_network = keras.Sequential(
-    [
-        keras.layers.Dense(512, activation="relu"),
-        keras.layers.Dense(256, activation="relu"),
-        keras.layers.Dense(128, activation="relu"),
-    ]
-)
-```
-
-### Model training
-
-Before we train the model, we need to batch our datasets.
-
-
-```python
-train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).batch(
-    TOY_CONFIG["batch_size"]
-)
-test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(
-    TOY_CONFIG["batch_size"]
-)
-```
-
-Let's train both models. Remember we have set `verbose=0` for brevity's
-sake, so do not be alarmed if you do not see any output for a while.
-
-After training, we evaluate the models on the unseen dataset. We will report
-the Root Mean Squared Error (RMSE) here.
-
-We observe that the cross network achieved significantly lower RMSE compared to
-a ReLU-based DNN, while also using fewer parameters. This points to the
-efficiency of the cross network in learning feature interactions.
- - -```python -cross_network_rmse, cross_network_num_params = train_and_evaluate( - learning_rate=TOY_CONFIG["learning_rate"], - epochs=TOY_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=cross_network, -) -print_stats( - rmse_list=[cross_network_rmse], - num_params=cross_network_num_params, - model_name="Cross Network", -) - -deep_network_rmse, deep_network_num_params = train_and_evaluate( - learning_rate=TOY_CONFIG["learning_rate"], - epochs=TOY_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=deep_network, -) -print_stats( - rmse_list=[deep_network_rmse], - num_params=deep_network_num_params, - model_name="Deep Network", -) -``` - -
-``` -Cross Network: RMSE = 0.0001293081877520308; #params = 16 - -Deep Network: RMSE = 0.13307014107704163; #params = 166272 - -``` -
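-
-As a quick sanity check on the parameter counts reported above, both numbers
-can be reproduced by hand. This is a back-of-the-envelope sketch, assuming that
-`FeatureCross` stores a full-rank `3 x 3` kernel plus a bias of size 3 (the
-DCN-v2 formulation), and using the fact that the deep network above ends with a
-128-unit layer rather than a final `Dense(1)`:
-
-
-```python
-num_features = 3  # country, bananas, cookbooks
-
-# FeatureCross: full-rank kernel (3 x 3) and bias (3), followed by Dense(1).
-cross_params = num_features * num_features + num_features + (num_features + 1)
-assert cross_params == 16
-
-# Deep network: three ReLU layers, no final projection layer.
-deep_params = (num_features * 512 + 512) + (512 * 256 + 256) + (256 * 128 + 128)
-assert deep_params == 166272
-```
-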
-### Visualizing feature interactions - -Since we already know which feature crosses are important in our data, it would -be interesting to verify whether our model has indeed learned these key feature -interactions. This can be done by visualizing the learned weight matrix in the -cross network, where the weight `Wij` represents the learned importance of -the interaction between features `xi` and `xj`. - - -```python -visualize_layer( - matrix=cross_network.weights[0].numpy(), - features=["country", "purchased_bananas", "purchased_cookbooks"], -) -``` - -
- -![png](/img/examples/keras_rs/dcn/dcn_16_2.png) - - - ---- -## Real-world example - -Let's use the MovieLens 100K dataset. This dataset is used to train models to -predict users' movie ratings, based on user-related features and movie-related -features. - -### Preparing the dataset - -The dataset processing steps here are similar to what's given in the -[basic ranking](/keras_rs/examples/basic_ranking/) -tutorial. Let's load the dataset, and keep only the useful columns. - - -```python -ratings_ds = tfds.load("movielens/100k-ratings", split="train") -ratings_ds = ratings_ds.map( - lambda x: ( - { - "movie_id": int(x["movie_id"]), - "user_id": int(x["user_id"]), - "user_gender": int(x["user_gender"]), - "user_zip_code": x["user_zip_code"], - "user_occupation_text": x["user_occupation_text"], - "bucketized_user_age": int(x["bucketized_user_age"]), - }, - x["user_rating"], # label - ) -) -``` - -
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-Downloading and preparing dataset to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-```
-
-For every feature, let's get the list of unique values, i.e., the vocabulary,
-so that we can use it for the embedding layers.
-
-
-```python
-vocabularies = {}
-for feature_name in MOVIELENS_CONFIG["int_features"] + MOVIELENS_CONFIG["str_features"]:
-    vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
-    vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))
-```
-
-We then use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to
-convert all features into indices, which can be fed into embedding layers.
-
-
-```python
-lookup_layers = {}
-lookup_layers.update(
-    {
-        feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
-        for feature in MOVIELENS_CONFIG["int_features"]
-    }
-)
-lookup_layers.update(
-    {
-        feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
-        for feature in MOVIELENS_CONFIG["str_features"]
-    }
-)
-
-ratings_ds = ratings_ds.map(
-    lambda x, y: (
-        {
-            feature_name: lookup_layers[feature_name](x[feature_name])
-            for feature_name in vocabularies
-        },
-        y,
-    )
-)
-```
-
-Let's split our data into train and test sets. We also use `cache()` and
-`prefetch()` for better performance. Note that we fix the shuffling seed and
-disable reshuffling, so that the train and test splits cannot overlap across
-iterations.
-
-
-```python
-ratings_ds = ratings_ds.shuffle(100_000, seed=42, reshuffle_each_iteration=False)
-
-train_ds = (
-    ratings_ds.take(80_000)
-    .batch(MOVIELENS_CONFIG["batch_size"])
-    .cache()
-    .prefetch(tf.data.AUTOTUNE)
-)
-test_ds = (
-    ratings_ds.skip(80_000)
-    .take(20_000)
-    .batch(MOVIELENS_CONFIG["batch_size"])
-    .cache()
-    .prefetch(tf.data.AUTOTUNE)
-)
-```
-
-### Building the model
-
-The model will have embedding layers, followed by cross and/or feedforward
-layers.
-
-
-```python
-
-def get_model(
-    dense_num_units_lst,
-    embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
-    use_cross_layer=False,
-    projection_dim=None,
-):
-    inputs = {}
-    embeddings = []
-    for feature_name, vocabulary in vocabularies.items():
-        inputs[feature_name] = keras.Input(shape=(), dtype="int32", name=feature_name)
-        embedding_layer = keras.layers.Embedding(
-            input_dim=len(vocabulary) + 1,
-            output_dim=embedding_dim,
-        )
-        embedding = embedding_layer(inputs[feature_name])
-        embeddings.append(embedding)
-
-    x = keras.ops.concatenate(embeddings, axis=1)
-
-    # Cross layer.
-    if use_cross_layer:
-        x = keras_rs.layers.FeatureCross(projection_dim=projection_dim)(x)
-
-    # Dense layers.
-    for num_units in dense_num_units_lst:
-        x = keras.layers.Dense(num_units, activation="relu")(x)
-
-    x = keras.layers.Dense(1)(x)
-
-    return keras.Model(inputs=inputs, outputs=x)
-
-```
-
-We have three models: a deep cross network, an optimised deep cross
-network with a low-rank matrix (to reduce training and serving costs) and a
-normal deep network without cross layers. The deep cross network is a stacked
-DCN model, i.e., the inputs are fed to cross layers, followed by feedforward
-layers. Let's run each model 10 times, and report the average/standard
-deviation of the RMSE. 
- - -```python -cross_network_rmse_list = [] -opt_cross_network_rmse_list = [] -deep_network_rmse_list = [] - -for _ in range(10): - cross_network = get_model( - dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - use_cross_layer=True, - ) - rmse, cross_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=cross_network, - ) - cross_network_rmse_list.append(rmse) - - opt_cross_network = get_model( - dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - use_cross_layer=True, - projection_dim=MOVIELENS_CONFIG["projection_dim"], - ) - rmse, opt_cross_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=opt_cross_network, - ) - opt_cross_network_rmse_list.append(rmse) - - deep_network = get_model(dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"]) - rmse, deep_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=deep_network, - ) - deep_network_rmse_list.append(rmse) - -print_stats( - rmse_list=cross_network_rmse_list, - num_params=cross_network_num_params, - model_name="Cross Network", -) -print_stats( - rmse_list=opt_cross_network_rmse_list, - num_params=opt_cross_network_num_params, - model_name="Optimised Cross Network", -) -print_stats( - rmse_list=deep_network_rmse_list, - num_params=deep_network_num_params, - model_name="Deep Network", -) -``` - -
-``` -Cross Network: RMSE = 0.9427602052688598 ± 0.07614302893494468; #params = {num_params} -Optimised Cross Network: RMSE = 0.9187218248844147 ± 0.031170624868084987; #params = {num_params} -Deep Network: RMSE = 0.8789893209934234 ± 0.025684711934398047; #params = {num_params} - -``` -
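-
-The literal `#params = {num_params}` strings above come from a small formatting
-bug in the `print_stats` helper defined earlier in this example: the
-multi-trial branch is evidently missing the `f` prefix on its format string.
-Here is a minimal sketch of a corrected helper, assuming it simply reports the
-mean and standard deviation of the RMSE list:
-
-
-```python
-import numpy as np
-
-
-def print_stats(rmse_list, num_params, model_name):
-    # Report RMSE as mean (single trial) or mean +/- std (multiple trials).
-    avg_rmse = np.mean(rmse_list)
-    if len(rmse_list) == 1:
-        print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
-    else:
-        std_rmse = np.std(rmse_list)
-        print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")
-```
-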
-
-In this particular run, the plain ReLU DNN actually achieved the lowest RMSE,
-so the cross network's advantage from the toy example does not carry over here;
-with only ten trials and sizeable standard deviations, the ranking of the three
-models can vary from run to run. What the results do show is that the low-rank
-(optimised) cross network reduces the number of parameters without compromising
-accuracy relative to the full cross network.
-
-### Visualizing feature interactions
-
-Like we did for the toy example, we will plot the weight matrix of the cross
-layer to see which feature crosses are important. In the previous example,
-the importance of the interaction between the `i`-th and `j`-th features was
-captured by the `(i, j)`-th element of the weight matrix.
-
-In this case, the feature embeddings are of size 32 rather than 1. Therefore,
-the importance of feature interactions is represented by the `(i, j)`-th
-block of the weight matrix, which has dimensions `32 x 32`. To quantify the
-significance of these interactions, we use the Frobenius norm of each block. A
-larger value implies higher importance.
-
-
-```python
-features = list(vocabularies.keys())
-mat = cross_network.weights[len(features)].numpy()
-embedding_dim = MOVIELENS_CONFIG["embedding_dim"]
-
-block_norm = np.zeros([len(features), len(features)])
-
-# Compute the norms of the blocks.
-for i in range(len(features)):
-    for j in range(len(features)):
-        block = mat[
-            i * embedding_dim : (i + 1) * embedding_dim,
-            j * embedding_dim : (j + 1) * embedding_dim,
-        ]
-        block_norm[i, j] = np.linalg.norm(block, ord="fro")
-
-visualize_layer(
-    matrix=block_norm,
-    features=features,
-)
-```
-
- -![png](/img/examples/keras_rs/dcn/dcn_31_2.png) - - - -And we are all done! - diff --git a/templates/examples/keras_rs/deep_recommender.md b/templates/examples/keras_rs/deep_recommender.md deleted file mode 100644 index b643e8d25b..0000000000 --- a/templates/examples/keras_rs/deep_recommender.md +++ /dev/null @@ -1,5441 +0,0 @@ -# Deep Recommenders - -**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Building a deep retrieval model with multiple stacked layers. - - -
ⓘ This example uses Keras 3
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/deep_recommender.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/deep_recommender.py) - - - ---- -## Introduction - -One of the great advantages of using Keras to build recommender models is the -freedom to build rich, flexible feature representations. - -The first step in doing so is preparing the features, as raw features will -usually not be immediately usable in a model. - -For example: -- User and item IDs may be strings (titles, usernames) or large, non-contiguous - integers (database IDs). -- Item descriptions could be raw text. -- Interaction timestamps could be raw Unix timestamps. - -These need to be appropriately transformed in order to be useful in building -models: -- User and item IDs have to be translated into embedding vectors, - high-dimensional numerical representations that are adjusted during training - to help the model predict its objective better. -- Raw text needs to be tokenized (split into smaller parts such as individual - words) and translated into embeddings. -- Numerical features need to be normalized so that their values lie in a small - interval around 0. - -Fortunately, the Keras -[`FeatureSpace`](/api/utils/feature_space/) utility makes this -preprocessing easy. - -In this tutorial, we are going to incorporate multiple features in our models. -These features will come from preprocessing the MovieLens dataset. - -In the -[basic retrieval](/keras_rs/examples/basic_retrieval/) -tutorial, the models consist of only an embedding layer. In this tutorial, we -add more dense layers to our models to increase their expressive power. - -In general, deeper models are capable of learning more complex patterns than -shallower models. For example, our user model incorporates user IDs and user -features such as age, gender and occupation. A shallow model (say, a single -embedding layer) may only be able to learn the simplest relationships between -those features and movies: a given user generally prefers horror movies to -comedies. To capture more complex relationships, such as user preferences -evolving with their age, we may need a deeper model with multiple stacked dense -layers. - -Of course, complex models also have their disadvantages. The first is -computational cost, as larger models require both more memory and more -computation to train and serve. The second is the requirement for more data. In -general, more training data is needed to take advantage of deeper models. With -more parameters, deep models might overfit or even simply memorize the training -examples instead of learning a function that can generalize. Finally, training -deeper models may be harder, and more care needs to be taken in choosing -settings like regularization and learning rate. - -Finding a good architecture for a real-world recommender system is a complex -art, requiring good intuition and careful hyperparameter tuning. For example, -factors such as the depth and width of the model, activation function, learning -rate, and optimizer can radically change the performance of the model. Modelling -choices are further complicated by the fact that good offline evaluation metrics -may not correspond to good online performance, and that the choice of what to -optimize for is often more critical than the choice of model itself. - -Nevertheless, effort put into building and fine-tuning larger models often pays -off. 
In this tutorial, we will illustrate how to build a deep retrieval model. -We'll do this by building progressively more complex models to see how this -affects model performance. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import keras -import matplotlib.pyplot as plt -import tensorflow as tf # Needed for the dataset -import tensorflow_datasets as tfds - -import keras_rs -``` - ---- -## The MovieLens dataset - -Let's first have a look at what features we can use from the MovieLens dataset. - - -```python -# Ratings data with user and movie data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") -``` - -The ratings dataset returns a dictionary of movie id, user id, the assigned -rating, timestamp, movie information, and user information: - - -```python -for data in ratings.take(1).as_numpy_iterator(): - print(str(data).replace(", '", ",\n '")) -``` - -
-``` -{'bucketized_user_age': np.float32(45.0), - 'movie_genres': array([7]), - 'movie_id': b'357', - 'movie_title': b"One Flew Over the Cuckoo's Nest (1975)", - 'raw_user_age': np.float32(46.0), - 'timestamp': np.int64(879024327), - 'user_gender': np.True_, - 'user_id': b'138', - 'user_occupation_label': np.int64(4), - 'user_occupation_text': b'doctor', - 'user_rating': np.float32(4.0), - 'user_zip_code': b'53211'} - -``` -
-
-In the Movielens dataset, user IDs are integers (represented as strings)
-starting at 1 and with no gaps. Normally, you would need to create a lookup
-table to map user IDs to integers from 0 to N-1. But as a simplification, we'll
-use the user id directly as an index in our model, in particular to look up the
-user embedding from the user embedding table. So we need to know the number of
-users.
-
-
-```python
-USERS_COUNT = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-The movies dataset contains the movie id, movie title, and the genres it belongs
-to. Note that the genres are encoded with integer labels.
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-``` -{'movie_genres': array([4]), - 'movie_id': b'1681', - 'movie_title': b'You So Crazy (1994)'} - -``` -
-
-In the Movielens dataset, movie IDs are integers (represented as strings)
-starting at 1 and with no gaps. Normally, you would need to create a lookup
-table to map movie IDs to integers from 0 to N-1. But as a simplification, we'll
-use the movie id directly as an index in our model, in particular to look up the
-movie embedding from the movie embedding table. So we need to know the number of
-movies.
-
-
-```python
-MOVIES_COUNT = movies.cardinality().numpy()
-```
-
----
-## Preprocessing the dataset
-
-### Normalizing continuous features
-
-Continuous features may need normalization so that they fall within an
-acceptable range for the model. We will give two examples of such normalization.
-
-#### Discretization
-
-A common transformation is to turn a continuous feature into a number of
-categorical features. This makes good sense if we have reasons to suspect that a
-feature's effect is non-continuous.
-
-We need to decide on the number of buckets we will use for discretization. Then,
-we will use the Keras `FeatureSpace` utility to automatically find the minimum
-and maximum value, and divide that range by the number of buckets to perform the
-discretization.
-
-In this example, we will discretize the user age.
-
-
-```python
-AGE_BINS_COUNT = 10
-user_age_feature = keras.utils.FeatureSpace.float_discretized(
-    num_bins=AGE_BINS_COUNT, output_mode="int"
-)
-```
-
-#### Rescaling
-
-Often, we want continuous features to be between 0 and 1, or between -1 and 1.
-To achieve this, we can rescale features that have a different range.
-
-In this example, we rescale the rating, which is an integer between 1 and 5, to
-be a float between 0 and 1. We need to rescale it and offset it: a rating of 1
-maps to 1/4 - 1/4 = 0, and a rating of 5 maps to 5/4 - 1/4 = 1.
-
-
-```python
-user_rating_feature = keras.utils.FeatureSpace.float_rescaled(
-    scale=1.0 / 4.0, offset=-1.0 / 4.0
-)
-```
-
-### Turning categorical features into embeddings
-
-A categorical feature is a feature that does not express a continuous quantity,
-but rather takes on one of a set of fixed values.
-
-Most deep learning models express these features by turning them into
-high-dimensional vectors. During model training, the value of that vector is
-adjusted to help the model predict its objective better.
-
-For example, suppose that our goal is to predict which user is going to watch
-which movie. To do that, we represent each user and each movie by an embedding
-vector. Initially, these embeddings will take on random values. During training,
-we adjust them so that embeddings of users and the movies they watch end up
-closer together.
-
-Taking raw categorical features and turning them into embeddings is normally a
-two-step process:
-1. First, we need to translate the raw values into a range of contiguous
-   integers, normally by building a mapping (called a "vocabulary") that maps
-   raw values to integers.
-2. Second, we need to take these integers and turn them into embeddings.
-
-#### Defining categorical features
-
-We will use the Keras `FeatureSpace` utility for the first step. Its `adapt`
-method automatically discovers the vocabulary for categorical features.
-
-
-```python
-user_gender_feature = keras.utils.FeatureSpace.integer_categorical(
-    num_oov_indices=0, output_mode="int"
-)
-user_occupation_feature = keras.utils.FeatureSpace.integer_categorical(
-    num_oov_indices=0, output_mode="int"
-)
-```
-
-#### Using feature crosses
-
-With crosses we can do feature interactions between multiple categorical
-features. 
This can be a powerful way to express that the combination of features
-represents a specific taste for movies.
-
-Note that the combination of multiple features can result in a very large
-feature space, which is why the `crossing_dim` parameter is important: it limits
-the output dimension of the cross feature.
-
-In this example, we will cross age and gender with the Keras `FeatureSpace`
-utility.
-
-
-```python
-USER_GENDER_CROSS_COUNT = 20
-user_gender_age_cross = keras.utils.FeatureSpace.cross(
-    feature_names=("user_gender", "raw_user_age"),
-    crossing_dim=USER_GENDER_CROSS_COUNT,
-    output_mode="int",
-)
-```
-
-### Processing text features
-
-We may also want to add text features to our model. Usually, things like product
-descriptions are free-form text, and we can hope that our model can learn to use
-the information they contain to make better recommendations, especially in a
-cold-start or long-tail scenario.
-
-While the MovieLens dataset does not give us rich textual features, we can still
-use movie titles. This may help us capture the fact that movies with very
-similar titles are likely to belong to the same series.
-
-The first transformation we need to apply to text is tokenization (splitting
-into constituent words or word-pieces), followed by vocabulary learning,
-followed by an embedding.
-
-The
-[`keras.layers.TextVectorization`](/api/layers/preprocessing_layers/text/text_vectorization/)
-layer can do the first two steps for us.
-
-
-```python
-title_vectorizer = keras.layers.TextVectorization(
-    max_tokens=10_000, output_sequence_length=16, dtype="int32"
-)
-title_vectorizer.adapt(movies.map(lambda x: x["movie_title"]))
-```
-
-Let's try it out:
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(title_vectorizer(data["movie_title"]))
-```
-
-``` -[ 59 187 622 5 0 0 0 0 0 0 0 0 0 0 0 0] - -``` -
-Each title is translated into a sequence of tokens, one for each piece we've -tokenized. - -We can check the learned vocabulary to verify that the layer is using the -correct tokenization: - - -```python -print(title_vectorizer.get_vocabulary()[40:50]) -``` - -
-``` -[np.str_('paris'), np.str_('little'), np.str_('last'), np.str_('ii'), np.str_('1988'), np.str_('king'), np.str_('from'), np.str_('city'), np.str_('boys'), np.str_('murder')] - -``` -
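-
-We can also check the mapping in the other direction. Since `get_vocabulary()`
-returns tokens indexed by token id, looking up the ids printed for the sample
-title above should give us back its words (a quick sketch; the trailing zeros
-in the vectorized title are just padding):
-
-
-```python
-# Decode the token ids from the sample title back into words.
-vocab = title_vectorizer.get_vocabulary()
-print([vocab[i] for i in [59, 187, 622, 5]])
-```
-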
-
-This looks correct: the layer is tokenizing titles into individual words. Later,
-we will see how to embed this tokenized text. For now, we turn this vectorizer
-into a Keras `FeatureSpace` feature.
-
-
-```python
-title_feature = keras.utils.FeatureSpace.feature(
-    preprocessor=title_vectorizer, dtype="string", output_mode="float"
-)
-TITLE_TOKEN_COUNT = title_vectorizer.vocabulary_size()
-```
-
-### Putting the FeatureSpace features together
-
-We're now ready to assemble the features with preprocessors in a `FeatureSpace`
-object. We then use `adapt` to go through the dataset and learn what needs
-to be learned, such as the vocabulary size for categorical features or the
-minimum and maximum values for bucketized features.
-
-
-```python
-feature_space = keras.utils.FeatureSpace(
-    features={
-        # Numerical features to discretize.
-        "raw_user_age": user_age_feature,
-        # Categorical features encoded as integers.
-        "user_gender": user_gender_feature,
-        "user_occupation_label": user_occupation_feature,
-        # Labels are ratings between 0 and 1.
-        "user_rating": user_rating_feature,
-        "movie_title": title_feature,
-    },
-    crosses=[user_gender_age_cross],
-    output_mode="dict",
-)
-
-feature_space.adapt(ratings)
-GENDERS_COUNT = feature_space.preprocessors["user_gender"].vocabulary_size()
-OCCUPATIONS_COUNT = feature_space.preprocessors[
-    "user_occupation_label"
-].vocabulary_size()
-```
-
----
-## Pre-building the candidate set
-
-Our model is going to be based on a `Retrieval` layer, which provides a set of
-the best candidates from the full set of candidates. To do this, the retrieval
-layer needs to know all the candidates and their features. In this section, we
-assemble the full set of movies with the associated features.
-
-### Extract raw candidate features
-
-First, we gather all the raw features from the dataset in lists, namely the
-titles and the genres of the movies. Note that one or more genres are
-associated with each movie, and the number of genres varies per movie.
-
-
-```python
-movie_titles = [""] * (MOVIES_COUNT + 1)
-movie_genres = [[]] * (MOVIES_COUNT + 1)
-for x in movies.as_numpy_iterator():
-    movie_id = int(x["movie_id"])
-    movie_titles[movie_id] = x["movie_title"]
-    movie_genres[movie_id] = x["movie_genres"].tolist()
-```
-
-### Preprocess candidate features
-
-Genres are already in the form of category numbers starting at zero. However, we
-do need to figure out two things:
-- The maximum number of genres a single movie can have; this will determine the
-  dimension for this feature.
-- The maximum value for the genre, which will give us the total number of genres
-  and determine the size of our embedding table for genres.
-
-
-```python
-MAX_GENRES_PER_MOVIE = 0
-max_genre_id = 0
-for one_movie_genres in movie_genres:
-    MAX_GENRES_PER_MOVIE = max(MAX_GENRES_PER_MOVIE, len(one_movie_genres))
-    if one_movie_genres:
-        max_genre_id = max(max_genre_id, max(one_movie_genres))
-
-GENRES_COUNT = max_genre_id + 1
-```
-
-Now we need to pad genres with an Out Of Vocabulary value to be able to
-represent genres as a fixed-size vector. We'll pad with zeros for simplicity, so
-we're adding one to the genres to avoid conflicting with genre zero, which is a
-valid genre.
-
-
-```python
-movie_genres = [
-    [g + 1 for g in genres] + [0] * (MAX_GENRES_PER_MOVIE - len(genres))
-    for genres in movie_genres
-]
-```
-
-Then, we vectorize all the movie titles. 
-
-
-```python
-movie_titles_vectors = title_vectorizer(movie_titles)
-```
-
-### Convert candidate set to native tensors
-
-We're now ready to combine these into a dataset. The last step is to make sure
-everything is a native tensor that can be consumed by the retrieval layer.
-As a reminder, movie id zero does not exist.
-
-
-```python
-MOVIES_DATASET = {
-    "movie_id": keras.ops.arange(0, MOVIES_COUNT + 1, dtype="int32"),
-    "movie_title_vector": movie_titles_vectors,
-    "movie_genres": keras.ops.convert_to_tensor(movie_genres, dtype="int32"),
-}
-```
-
----
-## Preparing the data
-
-We can now define our preprocessing function. Most features will be handled
-by the `FeatureSpace`. User IDs and movie IDs need to be extracted. Movie genres
-need to be padded. Then everything is packaged as a tuple with a dict of input
-features and a float for the rating, which is used as a label.
-
-
-```python
-
-def preprocess_rating(x):
-    features = feature_space(
-        {
-            "raw_user_age": x["raw_user_age"],
-            "user_gender": x["user_gender"],
-            "user_occupation_label": x["user_occupation_label"],
-            "user_rating": x["user_rating"],
-            "movie_title": x["movie_title"],
-        }
-    )
-    features = {k: tf.squeeze(v, axis=0) for k, v in features.items()}
-    movie_genres = x["movie_genres"]
-
-    return (
-        {
-            # User inputs are user ID and user features.
-            "user_id": int(x["user_id"]),
-            "raw_user_age": features["raw_user_age"],
-            "user_gender": features["user_gender"],
-            "user_occupation_label": features["user_occupation_label"],
-            "user_gender_X_raw_user_age": tf.squeeze(
-                features["user_gender_X_raw_user_age"], axis=-1
-            ),
-            # Movie inputs are movie ID, vectorized title and genres.
-            "movie_id": int(x["movie_id"]),
-            "movie_title_vector": features["movie_title"],
-            "movie_genres": tf.pad(
-                movie_genres + 1,
-                [[0, MAX_GENRES_PER_MOVIE - tf.shape(movie_genres)[0]]],
-            ),
-        },
-        # Label is the user rating between 0 and 1.
-        features["user_rating"],
-    )
-
-```
-
-We shuffle and then split the data into a training set and a testing set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
----
-## Model definition
-
-### Query model
-
-The query model is first tasked with converting user features to embeddings. The
-embeddings are then concatenated into a single vector.
-
-Defining deeper models will require us to stack more layers on top of this first
-set of embeddings. A progressively narrower stack of layers, separated by an
-activation function, is a common pattern:
-
-```
-                            +----------------------+
-                            |      64 x 32         |
-                            +----------------------+
-                                       | relu
-                          +--------------------------+
-                          |        128 x 64          |
-                          +--------------------------+
-                                       | relu
-                        +------------------------------+
-                        |         ... x 128            |
-                        +------------------------------+
-```
-
-Since the expressive power of deep linear models is no greater than that of
-shallow linear models, we use ReLU activations for all but the last hidden
-layer. The final hidden layer does not use any activation function: using an
-activation function would limit the output space of the final embeddings and
-might negatively impact the performance of the model. For instance, if ReLUs are
-used in the projection layer, all components in the output embedding would be
-non-negative.
-
-We're going to try this here. 
To make experimentation with different depths
-easy, let's define a model whose depth (and width) is defined by constructor
-parameters. The `layer_sizes` parameter gives us the depth and width of the
-model. We can vary it to experiment with shallower or deeper models.
-
-
-```python
-
-class QueryModel(keras.Model):
-    """Model for encoding user queries."""
-
-    def __init__(self, layer_sizes, embedding_dimension=32):
-        """Construct a model for encoding user queries.
-
-        Args:
-            layer_sizes: A list of integers where the i-th entry represents the
-                number of units the i-th layer contains.
-            embedding_dimension: Output dimension for all embedding tables.
-        """
-        super().__init__()
-
-        # We first generate embeddings.
-        self.user_embedding = keras.layers.Embedding(
-            # +1 for user ID zero, which does not exist
-            USERS_COUNT + 1,
-            embedding_dimension,
-        )
-        self.gender_embedding = keras.layers.Embedding(
-            GENDERS_COUNT, embedding_dimension
-        )
-        self.age_embedding = keras.layers.Embedding(AGE_BINS_COUNT, embedding_dimension)
-        self.gender_x_age_embedding = keras.layers.Embedding(
-            USER_GENDER_CROSS_COUNT, embedding_dimension
-        )
-        self.occupation_embedding = keras.layers.Embedding(
-            OCCUPATIONS_COUNT, embedding_dimension
-        )
-
-        # Then construct the layers.
-        self.dense_layers = keras.Sequential()
-
-        # Use the ReLU activation for all but the last layer.
-        for layer_size in layer_sizes[:-1]:
-            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))
-
-        # No activation for the last layer.
-        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))
-
-    def call(self, inputs):
-        # Take the inputs, pass each through its embedding layer, concatenate.
-        feature_embedding = keras.ops.concatenate(
-            [
-                self.user_embedding(inputs["user_id"]),
-                self.gender_embedding(inputs["user_gender"]),
-                self.age_embedding(inputs["raw_user_age"]),
-                self.gender_x_age_embedding(inputs["user_gender_X_raw_user_age"]),
-                self.occupation_embedding(inputs["user_occupation_label"]),
-            ],
-            axis=1,
-        )
-        return self.dense_layers(feature_embedding)
-
-```
-
----
-## Candidate model
-
-We can adopt the same approach for the candidate model. Again, we start by
-converting movie features to embeddings, concatenating them, and then expanding
-them with hidden layers:
-
-
-```python
-
-class CandidateModel(keras.Model):
-    """Model for encoding candidates (movies)."""
-
-    def __init__(self, layer_sizes, embedding_dimension=32):
-        """Construct a model for encoding candidates (movies).
-
-        Args:
-            layer_sizes: A list of integers where the i-th entry represents the
-                number of units the i-th layer contains.
-            embedding_dimension: Output dimension for all embedding tables.
-        """
-        super().__init__()
-
-        # We first generate embeddings.
-        self.movie_embedding = keras.layers.Embedding(
-            # +1 for movie ID zero, which does not exist
-            MOVIES_COUNT + 1,
-            embedding_dimension,
-        )
-        # Take all the title tokens for the title of the movie, embed each
-        # token, and then take the mean of all token embeddings.
-        self.movie_title_embedding = keras.Sequential(
-            [
-                keras.layers.Embedding(
-                    # +1 for OOV token, which is used for padding
-                    TITLE_TOKEN_COUNT + 1,
-                    embedding_dimension,
-                    mask_zero=True,
-                ),
-                keras.layers.GlobalAveragePooling1D(),
-            ]
-        )
-        # Take all the genres for the movie, embed each genre, and then take the
-        # mean of all genre embeddings. 
- self.movie_genres_embedding = keras.Sequential( - [ - keras.layers.Embedding( - # +1 for OOV genre, which is used for padding - GENRES_COUNT + 1, - embedding_dimension, - mask_zero=True, - ), - keras.layers.GlobalAveragePooling1D(), - ] - ) - - # Then construct the layers. - self.dense_layers = keras.Sequential() - - # Use the ReLU activation for all but the last layer. - for layer_size in layer_sizes[:-1]: - self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu")) - - # No activation for the last layer. - self.dense_layers.add(keras.layers.Dense(layer_sizes[-1])) - - def call(self, inputs): - movie_id = inputs["movie_id"] - movie_title_vector = inputs["movie_title_vector"] - movie_genres = inputs["movie_genres"] - feature_embedding = keras.ops.concatenate( - [ - self.movie_embedding(movie_id), - self.movie_title_embedding(movie_title_vector), - self.movie_genres_embedding(movie_genres), - ], - axis=1, - ) - return self.dense_layers(feature_embedding) - -``` - ---- -## Combined model - -With both QueryModel and CandidateModel defined, we can put together a combined -model and implement our loss and metrics logic. To make things simple, we'll -enforce that the model structure is the same across the query and candidate -models. - - -```python - -class RetrievalModel(keras.Model): - """Combined model.""" - - def __init__( - self, - layer_sizes=(32,), - embedding_dimension=32, - retrieval_k=100, - ): - """Construct a combined model. - - Args: - layer_sizes: A list of integers where the i-th entry represents the - number of units the i-th layer contains. - embedding_dimension: Output dimension for all embedding tables. - retrieval_k: How many candidate movies to retrieve. - """ - super().__init__() - self.query_model = QueryModel(layer_sizes, embedding_dimension) - self.candidate_model = CandidateModel(layer_sizes, embedding_dimension) - self.retrieval = keras_rs.layers.BruteForceRetrieval( - k=retrieval_k, return_scores=False - ) - self.update_candidates() # Provide an initial set of candidates - self.loss_fn = keras.losses.MeanSquaredError() - self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy( - k=100, from_sorted_ids=True - ) - - def update_candidates(self): - self.retrieval.update_candidates( - self.candidate_model.predict(MOVIES_DATASET, verbose=0) - ) - - def call(self, inputs, training=False): - query_embeddings = self.query_model( - { - "user_id": inputs["user_id"], - "raw_user_age": inputs["raw_user_age"], - "user_gender": inputs["user_gender"], - "user_occupation_label": inputs["user_occupation_label"], - "user_gender_X_raw_user_age": inputs["user_gender_X_raw_user_age"], - } - ) - candidate_embeddings = self.candidate_model( - { - "movie_id": inputs["movie_id"], - "movie_title_vector": inputs["movie_title_vector"], - "movie_genres": inputs["movie_genres"], - } - ) - - result = { - "query_embeddings": query_embeddings, - "candidate_embeddings": candidate_embeddings, - } - if not training: - # No need to spend time extracting top predicted movies during - # training, they are not used. - result["predictions"] = self.retrieval(query_embeddings) - return result - - def evaluate( - self, - x=None, - y=None, - batch_size=None, - verbose="auto", - sample_weight=None, - steps=None, - callbacks=None, - return_dict=False, - **kwargs, - ): - """Overridden to update the candidate set. - - Before evaluating the model, we need to update our retrieval layer by - re-computing the values predicted by the candidate model for all the - candidates. 
- """ - self.update_candidates() - return super().evaluate( - x, - y, - batch_size=batch_size, - verbose=verbose, - sample_weight=sample_weight, - steps=steps, - callbacks=callbacks, - return_dict=return_dict, - **kwargs, - ) - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - query_embeddings = y_pred["query_embeddings"] - candidate_embeddings = y_pred["candidate_embeddings"] - - labels = keras.ops.expand_dims(y, -1) - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.sum( - keras.ops.multiply(query_embeddings, candidate_embeddings), - axis=1, - keepdims=True, - ) - return self.loss_fn(labels, scores, sample_weight) - - def compute_metrics(self, x, y, y_pred, sample_weight=None): - if "predictions" in y_pred: - # We are evaluating or predicting. Update `top_k_metric`. - movie_ids = x["movie_id"] - predictions = y_pred["predictions"] - # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we - # only take top rated movies, and we put a weight of 0 for the rest. - rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32") - sample_weight = ( - rating_weight - if sample_weight is None - else keras.ops.multiply(rating_weight, sample_weight) - ) - self.top_k_metric.update_state( - movie_ids, predictions, sample_weight=sample_weight - ) - return self.get_metrics_result() - else: - # We are training. `top_k_metric` is not updated and is zero, so - # don't report it. - result = self.get_metrics_result() - result.pop(self.top_k_metric.name) - return result - -``` - ---- -## Training the model - -### Shallow model - -We're ready to try out our first, shallow, model! - - -```python -NUM_EPOCHS = 30 - -one_layer_model = RetrievalModel((32,)) -one_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05)) - -one_layer_history = one_layer_model.fit( - train_ratings, - validation_data=test_ratings, - validation_freq=5, - epochs=NUM_EPOCHS, -) -``` - -
-```
-Epoch 1/30 - loss: 0.2390
-Epoch 2/30 - loss: 0.0762
-Epoch 3/30 - loss: 0.0743
-Epoch 4/30 - loss: 0.0729
-Epoch 5/30 - loss: 0.0717 - val_loss: 0.0727 - val_sparse_top_k_categorical_accuracy: 0.1794
-Epoch 6/30 - loss: 0.0706
-Epoch 7/30 - loss: 0.0696
-Epoch 8/30 - loss: 0.0686
-Epoch 9/30 - loss: 0.0676
-Epoch 10/30 - loss: 0.0667 - val_loss: 0.0679 - val_sparse_top_k_categorical_accuracy: 0.2392
-Epoch 11/30 - loss: 0.0657
-Epoch 12/30 - loss: 0.0648
-Epoch 13/30 - loss: 0.0639
-Epoch 14/30 - loss: 0.0630
-Epoch 15/30 - loss: 0.0622 - val_loss: 0.0636 - val_sparse_top_k_categorical_accuracy: 0.2836
-Epoch 16/30 - loss: 0.0614
-Epoch 17/30 - loss: 0.0607
-Epoch 18/30 - loss: 0.0600
-Epoch 19/30 - loss: 0.0594
-Epoch 20/30 - loss: 0.0589 - val_loss: 0.0605 - val_sparse_top_k_categorical_accuracy: 0.3118
-Epoch 21/30 - loss: 0.0584
-Epoch 22/30 - loss: 0.0579
-Epoch 23/30 - loss: 0.0575
-Epoch 24/30 - loss: 0.0571
-Epoch 25/30 - loss: 0.0567 - val_loss: 0.0586 - val_sparse_top_k_categorical_accuracy: 0.3219
-Epoch 26/30 - loss: 0.0564
-Epoch 27/30 - loss: 0.0561
-Epoch 28/30 - loss: 0.0559
-Epoch 29/30 - loss: 0.0556
-Epoch 30/30 - loss: 0.0554 - val_loss: 0.0574 - val_sparse_top_k_categorical_accuracy: 0.3216
-```
-
-This gives us a top-100 accuracy of around 0.32. We can use this as a reference
-point for evaluating deeper models.
-
-### Deeper model
-
-What about a deeper model with two layers?
-
-
-```python
-two_layer_model = RetrievalModel((64, 32))
-two_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
-two_layer_history = two_layer_model.fit(
-    train_ratings,
-    validation_data=test_ratings,
-    validation_freq=5,
-    epochs=NUM_EPOCHS,
-)
-```
-
-```
-Epoch 1/30 - loss: 0.2010
-Epoch 2/30 - loss: 0.0758
-Epoch 3/30 - loss: 0.0743
-Epoch 4/30 - loss: 0.0730
-Epoch 5/30 - loss: 0.0718 - val_loss: 0.0725 - val_sparse_top_k_categorical_accuracy: 0.1145
-Epoch 6/30 - loss: 0.0707
-Epoch 7/30 - loss: 0.0695
-Epoch 8/30 - loss: 0.0684
-Epoch 9/30 - loss: 0.0673
-Epoch 10/30 - loss: 0.0662 - val_loss: 0.0670 - val_sparse_top_k_categorical_accuracy: 0.2066
-Epoch 11/30 - loss: 0.0650
-Epoch 12/30 - loss: 0.0639
-Epoch 13/30 - loss: 0.0629
-Epoch 14/30 - loss: 0.0619
-```
-
-``` -Epoch 15/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0590 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0612 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0611 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0610 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0610 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0610 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0610 - val_loss: 0.0622 - val_sparse_top_k_categorical_accuracy: 0.2694 - - -
-``` -Epoch 16/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0580 - -
-``` - -``` -
- 15/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0605 - -
-``` - -``` -
- 30/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0603 - -
-``` - -``` -
- 44/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0602 - -
-``` - -``` -
- 59/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0602 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0603 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0603 - - -
-``` -Epoch 17/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.0572 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0598 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596 - -
-``` - -``` -
- 48/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0595 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0595 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596 - - -
-``` -Epoch 18/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0565 - -
-``` - -``` -
- 15/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0592 - -
-``` - -``` -
- 29/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0590 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0589 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589 - - -
-``` -Epoch 19/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.0558 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0586 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0584 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0583 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0583 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0584 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0584 - - -
-``` -Epoch 20/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0552 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0580 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0579 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0578 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0578 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0578 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0578 - val_loss: 0.0594 - val_sparse_top_k_categorical_accuracy: 0.2793 - - -
-``` -Epoch 21/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0547 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0576 - -
-``` - -``` -
- 34/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574 - -
-``` - -``` -
- 49/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0573 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0573 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574 - - -
-``` -Epoch 22/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0542 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0572 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0570 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0569 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0569 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0570 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0570 - - -
-``` -Epoch 23/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0538 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0568 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0565 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 77/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - - -
-``` -Epoch 24/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.0534 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0565 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0563 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0562 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0562 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0563 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0563 - - -
-``` -Epoch 25/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0530 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0562 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0560 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0559 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0559 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0560 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0560 - val_loss: 0.0579 - val_sparse_top_k_categorical_accuracy: 0.2896 - - -
-``` -Epoch 26/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0527 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0559 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - - -
-``` -Epoch 27/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0524 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556 - -
-``` - -``` -
- 34/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - -
-``` - -``` -
- 51/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - -
-``` - -``` -
- 68/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555 - - -
-``` -Epoch 28/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0521 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 49/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - - -
-``` -Epoch 29/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0519 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 48/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - - -
-``` -Epoch 30/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0517 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0548 - -
-``` - -``` -
- 48/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0548 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0548 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0548 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0548 - val_loss: 0.0570 - val_sparse_top_k_categorical_accuracy: 0.2964 - - -While the deeper model seems to learn a bit better than the shallow model at -first, the difference becomes minimal towards the end of the trainign. We can -plot the validation accuracy curves to illustrate this: - - -```python -METRIC = "val_sparse_top_k_categorical_accuracy" -num_validation_runs = len(one_layer_history.history[METRIC]) -epochs = [(x + 1) * 5 for x in range(num_validation_runs)] - -plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer") -plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers") -plt.title("Accuracy vs epoch") -plt.xlabel("epoch") -plt.ylabel("Top-100 accuracy") -plt.legend() -plt.show() -``` - - - -![png](/img/examples/keras_rs/deep_recommender/deep_recommender_57_0.png) - - - -Deeper models are not necessarily better. The following model extends the depth -to three layers: - - -```python -three_layer_model = RetrievalModel((128, 64, 32)) -three_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05)) -three_layer_history = three_layer_model.fit( - train_ratings, - validation_data=test_ratings, - validation_freq=5, - epochs=NUM_EPOCHS, -) -``` - -
-```
-Epoch 1/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - loss: 0.1843
-Epoch 2/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0760
-Epoch 3/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0745
-Epoch 4/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0733
-Epoch 5/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.0721 - val_loss: 0.0726 - val_sparse_top_k_categorical_accuracy: 0.1402
-Epoch 6/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0710
-Epoch 7/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0699
-Epoch 8/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0688
-Epoch 9/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0677
-Epoch 10/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0665 - val_loss: 0.0667 - val_sparse_top_k_categorical_accuracy: 0.2197
-Epoch 11/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0654
-Epoch 12/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0643
-Epoch 13/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0633
-Epoch 14/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0623
-Epoch 15/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0613 - val_loss: 0.0618 - val_sparse_top_k_categorical_accuracy: 0.2900
-Epoch 16/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0605
-Epoch 17/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0597
-Epoch 18/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0591
-Epoch 19/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0585
-Epoch 20/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0580 - val_loss: 0.0591 - val_sparse_top_k_categorical_accuracy: 0.3015
-Epoch 21/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0575
-Epoch 22/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0571
-Epoch 23/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0567
-Epoch 24/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0564
-Epoch 25/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0561 - val_loss: 0.0577 - val_sparse_top_k_categorical_accuracy: 0.3049
-Epoch 26/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0559
-Epoch 27/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556
-Epoch 28/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554
-Epoch 29/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552
-Epoch 30/30
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0550 - val_loss: 0.0569 - val_sparse_top_k_categorical_accuracy: 0.3072
-```
-
-We don't really see an improvement over the shallow model:
-
-
-```python
-plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
-plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
-plt.plot(epochs, three_layer_history.history[METRIC], label="3 layers")
-plt.title("Accuracy vs epoch")
-plt.xlabel("epoch")
-plt.ylabel("Top-100 accuracy")
-plt.legend()
-plt.show()
-```
-
-
-![png](/img/examples/keras_rs/deep_recommender/deep_recommender_61_0.png)
-
-
-This is a good illustration of the fact that deeper and larger models, while
-capable of superior performance, often require very careful tuning. For example,
-throughout this tutorial we used a single, fixed learning rate. Alternative
-choices may give very different results and are worth exploring.
-
-With appropriate tuning and sufficient data, the effort put into building larger
-and deeper models is in many cases well worth it: larger models can lead to
-substantial improvements in prediction accuracy.
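-As a concrete illustration of the kind of tuning this suggests, here is a
-minimal sketch that was not part of the original runs: it reuses the
-three-layer configuration from above, but swaps the fixed `Adagrad(0.05)`
-optimizer for a decaying learning-rate schedule. The schedule values are
-hypothetical starting points for a sweep, not tuned settings:
-
-```python
-# Illustrative sketch only: decay the learning rate over time instead of
-# keeping it fixed. All schedule values below are guesses worth sweeping.
-lr_schedule = keras.optimizers.schedules.ExponentialDecay(
-    initial_learning_rate=0.1,
-    decay_steps=80,  # roughly one decay step per epoch (80 batches per epoch)
-    decay_rate=0.9,
-)
-tuned_three_layer_model = RetrievalModel((128, 64, 32))
-tuned_three_layer_model.compile(optimizer=keras.optimizers.Adagrad(lr_schedule))
-tuned_three_layer_history = tuned_three_layer_model.fit(
-    train_ratings,
-    validation_data=test_ratings,
-    validation_freq=5,
-    epochs=NUM_EPOCHS,
-)
-```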
----
-## Next Steps
-
-In this tutorial we expanded our retrieval model with dense layers and
-activation functions. To see how to create a model that can perform not only
-retrieval tasks but also rating tasks, take a look at the multitask tutorial.
-
diff --git a/templates/examples/keras_rs/dlrm.md b/templates/examples/keras_rs/dlrm.md
deleted file mode 100644
index 50ad0f7ef0..0000000000
--- a/templates/examples/keras_rs/dlrm.md
+++ /dev/null
@@ -1,522 +0,0 @@
-# Ranking with Deep Learning Recommendation Model
-
-**Author:** [Harshith Kulkarni](https://github.com/kharshith-k)<br>
-**Date created:** 2025/06/02<br>
-**Last modified:** 2025/09/04
-**Description:** Rank movies with DLRM using KerasRS. - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/dlrm.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/dlrm.py)
-
-
-
----
-## Introduction
-
-This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to
-effectively learn the relationships between items and user preferences using a
-dot-product interaction mechanism. For more details, please refer to the
-[DLRM](https://arxiv.org/abs/1906.00091) paper.
-
-DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and
-is particularly effective at processing both categorical and continuous (sparse/dense)
-input features. The architecture consists of three main components: dedicated input
-layers to handle diverse features (typically embedding layers for categorical features),
-a dot-product interaction layer to explicitly model feature interactions, and a
-Multi-Layer Perceptron (MLP) to capture implicit feature relationships.
-
-The dot-product interaction layer lies at the heart of DLRM, efficiently computing
-pairwise interactions between different feature embeddings. This contrasts with models
-like Deep & Cross Network (DCN), which can treat elements within a feature vector as
-independent units, potentially leading to a higher-dimensional space and increased
-computational cost. The MLP is a standard feedforward network. The DLRM is formed by
-combining the interaction layer and the MLP.
-
-The following image illustrates the DLRM architecture:
-
-![DLRM Architecture](https://raw.githubusercontent.com/kharshith-k/keras-io/refs/heads/keras-rs-examples/examples/keras_rs/img/dlrm/dlrm_architecture.gif)
-
-Now that we have a foundational understanding of DLRM's architecture and key
-characteristics, let's dive into the code. We will train a DLRM on a real-world dataset
-to demonstrate its capability to learn meaningful feature interactions. Let's begin by
-setting the backend to TensorFlow and organizing our imports.
-
-
-```python
-!pip install -q keras-rs
-```
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "tensorflow"  # or `"jax"`/`"torch"`
-
-import keras
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-import keras_rs
-```
-
-Let's also define variables which will be reused throughout the example.
-
-
-```python
-MOVIELENS_CONFIG = {
-    # features
-    "continuous_features": [
-        "raw_user_age",
-        "hour_of_day_sin",
-        "hour_of_day_cos",
-        "hour_of_week_sin",
-        "hour_of_week_cos",
-    ],
-    "categorical_int_features": [
-        "user_gender",
-    ],
-    "categorical_str_features": [
-        "user_zip_code",
-        "user_occupation_text",
-        "movie_id",
-        "user_id",
-    ],
-    # model
-    "embedding_dim": 8,
-    "mlp_dim": 8,
-    "deep_net_num_units": [192, 192, 192],
-    # training
-    "learning_rate": 1e-4,
-    "num_epochs": 30,
-    "batch_size": 8192,
-}
-```
-
-Here, we define helper functions for plotting training metrics and for
-visualising the strength of the feature interactions learned by the
-interaction layer, in order to better understand the model. We also define a
-function for compiling, training and evaluating a given model.
- - -```python - -def plot_training_metrics(history): - """Graphs all metrics tracked in the history object.""" - plt.figure(figsize=(12, 6)) - - for metric_name, metric_values in history.history.items(): - plt.plot(metric_values, label=metric_name.replace("_", " ").title()) - - plt.title("Metrics over Epochs") - plt.xlabel("Epoch") - plt.ylabel("Metric Value") - plt.legend() - plt.grid(True) - - -def visualize_layer(matrix, features, cmap=plt.cm.Blues): - - im = plt.matshow( - matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5] - ) - - ax = plt.gca() - divider = make_axes_locatable(plt.gca()) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) - cax.tick_params(labelsize=10) - - # Set tick locations explicitly before setting labels - ax.set_xticks(np.arange(len(features))) - ax.set_yticks(np.arange(len(features))) - - ax.set_xticklabels(features, rotation=45, fontsize=5) - ax.set_yticklabels(features, fontsize=5) - - plt.show() - - -def train_and_evaluate( - learning_rate, - epochs, - train_data, - test_data, - model, - plot_metrics=False, -): - optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0) - loss = keras.losses.MeanSquaredError() - rmse = keras.metrics.RootMeanSquaredError() - - model.compile( - optimizer=optimizer, - loss=loss, - metrics=[rmse], - ) - - history = model.fit( - train_data, - epochs=epochs, - verbose=1, - ) - if plot_metrics: - plot_training_metrics(history) - - results = model.evaluate(test_data, return_dict=True, verbose=1) - rmse_value = results["root_mean_squared_error"] - - return rmse_value, model.count_params() - - -def print_stats(rmse_list, num_params, model_name): - # Report metrics. - num_trials = len(rmse_list) - avg_rmse = np.mean(rmse_list) - std_rmse = np.std(rmse_list) - - if num_trials == 1: - print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}") - else: - print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}") - -``` - ---- -## Real-world example - -Let's use the MovieLens 100K dataset. This dataset is used to train models to -predict users' movie ratings, based on user-related features and movie-related -features. - -### Preparing the dataset - -The dataset processing steps here are similar to what's given in the -[basic ranking](/keras_rs/examples/basic_ranking/) -tutorial. Let's load the dataset, and keep only the useful columns. - - -```python -ratings_ds = tfds.load("movielens/100k-ratings", split="train") - - -def preprocess_features(x): - """Extracts and cyclically encodes timestamp features.""" - features = { - "movie_id": x["movie_id"], - "user_id": x["user_id"], - "user_gender": tf.cast(x["user_gender"], dtype=tf.int32), - "user_zip_code": x["user_zip_code"], - "user_occupation_text": x["user_occupation_text"], - "raw_user_age": tf.cast(x["raw_user_age"], dtype=tf.float32), - } - label = tf.cast(x["user_rating"], dtype=tf.float32) - - # The timestamp is in seconds since the epoch. 
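-    # (Note added for clarity: the sin/cos pairs computed below map each
-    # periodic quantity onto the unit circle, so values at the end of a
-    # cycle, e.g. hour 23, end up close to values at the start, e.g. hour 0,
-    # instead of being numerically far apart.)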
- timestamp = tf.cast(x["timestamp"], dtype=tf.float32) - - # Constants for time periods - SECONDS_IN_HOUR = 3600.0 - HOURS_IN_DAY = 24.0 - HOURS_IN_WEEK = 168.0 - - # Calculate hour of day and encode it - hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY - features["hour_of_day_sin"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY) - features["hour_of_day_cos"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY) - - # Calculate hour of week and encode it - hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK - features["hour_of_week_sin"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK) - features["hour_of_week_cos"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK) - - return features, label - - -# Apply the new preprocessing function -ratings_ds = ratings_ds.map(preprocess_features) -``` - -For every categorical feature, let's get the list of unique values, i.e., vocabulary, so -that we can use that for the embedding layer. - - -```python -vocabularies = {} -for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] -): - vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name]) - vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary))) -``` - -One thing we need to do is to use `keras.layers.StringLookup` and -`keras.layers.IntegerLookup` to convert all the categorical features into indices, which -can -then be fed into embedding layers. - - -```python -lookup_layers = {} -lookup_layers.update( - { - feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature]) - for feature in MOVIELENS_CONFIG["categorical_int_features"] - } -) -lookup_layers.update( - { - feature: keras.layers.StringLookup(vocabulary=vocabularies[feature]) - for feature in MOVIELENS_CONFIG["categorical_str_features"] - } -) -``` - -Let's normalize all the continuous features, so that we can use that for the MLP layers. - - -```python -normalization_layers = {} -for feature_name in MOVIELENS_CONFIG["continuous_features"]: - normalization_layers[feature_name] = keras.layers.Normalization(axis=-1) - -training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x) - -for feature_name in MOVIELENS_CONFIG["continuous_features"]: - feature_ds = training_data_for_adaptation.map( - lambda x: tf.expand_dims(x[feature_name], axis=-1) - ) - normalization_layers[feature_name].adapt(feature_ds) - -ratings_ds = ratings_ds.map( - lambda x, y: ( - { - **{ - feature_name: lookup_layers[feature_name](x[feature_name]) - for feature_name in vocabularies - }, - # Apply the adapted normalization layers to the continuous features. - **{ - feature_name: tf.squeeze( - normalization_layers[feature_name]( - tf.expand_dims(x[feature_name], axis=-1) - ), - axis=-1, - ) - for feature_name in MOVIELENS_CONFIG["continuous_features"] - }, - }, - y, - ) -) -``` - -Let's split our data into train and test sets. We also use `cache()` and -`prefetch()` for better performance. - - -```python -ratings_ds = ratings_ds.shuffle(100_000) - -train_ds = ( - ratings_ds.take(80_000) - .batch(MOVIELENS_CONFIG["batch_size"]) - .cache() - .prefetch(tf.data.AUTOTUNE) -) -test_ds = ( - ratings_ds.skip(80_000) - .batch(MOVIELENS_CONFIG["batch_size"]) - .take(20_000) - .cache() - .prefetch(tf.data.AUTOTUNE) -) -``` - -### Building the model - -The model will have embedding layers, followed by DotInteraction and feedforward -layers. 
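-To make the role of the interaction layer concrete before assembling the full
-model, here is a small, self-contained sketch (added for illustration; the
-tensors and shapes are made up). `keras_rs.layers.DotInteraction` takes a list
-of same-shaped feature vectors and returns, for each sample, the pairwise dot
-products between them:
-
-```python
-import keras
-import keras_rs
-
-# Three toy "feature embeddings" for a batch of 2 samples, 4 dims each.
-f1 = keras.random.normal((2, 4))
-f2 = keras.random.normal((2, 4))
-f3 = keras.random.normal((2, 4))
-
-dot_layer = keras_rs.layers.DotInteraction()
-pairwise = dot_layer([f1, f2, f3])
-# With default settings this yields one value per distinct feature pair,
-# i.e. 3 * 2 / 2 = 3 values per sample here.
-print(pairwise.shape)
-```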
- - -```python - -class DLRM(keras.Model): - def __init__( - self, - dense_num_units_lst, - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - mlp_dim=MOVIELENS_CONFIG["mlp_dim"], - **kwargs, - ): - super().__init__(**kwargs) - - self.embedding_layers = {} - for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] - ): - vocab_size = len(vocabularies[feature_name]) + 1 # +1 for OOV token - self.embedding_layers[feature_name] = keras.layers.Embedding( - input_dim=vocab_size, - output_dim=embedding_dim, - ) - - self.bottom_mlp = keras.Sequential( - [ - keras.layers.Dense(mlp_dim, activation="relu"), - keras.layers.Dense(embedding_dim), # Output must match embedding_dim - ] - ) - - self.dot_layer = keras_rs.layers.DotInteraction() - - self.top_mlp = [] - for num_units in dense_num_units_lst: - self.top_mlp.append(keras.layers.Dense(num_units, activation="relu")) - - self.output_layer = keras.layers.Dense(1) - - self.dense_num_units_lst = dense_num_units_lst - self.embedding_dim = embedding_dim - - def call(self, inputs): - embeddings = [] - for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] - ): - embedding = self.embedding_layers[feature_name](inputs[feature_name]) - embeddings.append(embedding) - - # Process all continuous features together. - continuous_inputs = [] - for feature_name in MOVIELENS_CONFIG["continuous_features"]: - # Reshape each feature to (batch_size, 1) - feature = keras.ops.reshape( - keras.ops.cast(inputs[feature_name], dtype="float32"), (-1, 1) - ) - continuous_inputs.append(feature) - - # Concatenate into a single tensor: (batch_size, num_continuous_features) - concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1) - - # Pass through the Bottom MLP to get one combined vector. - processed_continuous = self.bottom_mlp(concatenated_continuous) - - # Combine with categorical embeddings. Note: we add a list containing the - # single tensor. - combined_features = embeddings + [processed_continuous] - - # Pass the list of features to the DotInteraction layer. - x = self.dot_layer(combined_features) - - for layer in self.top_mlp: - x = layer(x) - - x = self.output_layer(x) - - return x - - -dot_network = DLRM( - dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - mlp_dim=MOVIELENS_CONFIG["mlp_dim"], -) - -rmse, dot_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=dot_network, - plot_metrics=True, -) -print_stats( - rmse_list=[rmse], - num_params=dot_network_num_params, - model_name="Dot Network", -) -``` - -![png](/img/examples/keras_rs/dlrm/dlrm_19_158.png) - - -### Visualizing feature interactions - -The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense -layer. Instead, its function is to compute the dot product between the embedding vectors -of your features. - -To visualize the strength of these interactions, we can calculate a matrix representing -the pairwise interaction strength between all feature embeddings. A common way to do this -is to take the dot product of the embedding matrices for each pair of features and then -aggregate the result into a single value (like the mean of the absolute values) that -represents the overall interaction strength. 
- - -```python - -def get_dot_interaction_matrix(model, categorical_features, continuous_features): - # The new feature list for the plot labels - all_feature_names = categorical_features + ["all_continuous_features"] - num_features = len(all_feature_names) - - # Store all feature outputs in the correct order. - all_feature_outputs = [] - - # Get outputs for categorical features from embedding layers (unchanged). - for feature_name in categorical_features: - embedding = model.embedding_layers[feature_name](keras.ops.array([0])) - all_feature_outputs.append(embedding) - - # Get a single output for ALL continuous features from the shared MLP. - num_continuous_features = len(continuous_features) - # Create a dummy input of zeros for the MLP - dummy_continuous_input = keras.ops.zeros((1, num_continuous_features)) - processed_continuous = model.bottom_mlp(dummy_continuous_input) - all_feature_outputs.append(processed_continuous) - - interaction_matrix = np.zeros((num_features, num_features)) - - # Iterate through each pair to calculate interaction strength. - for i in range(num_features): - for j in range(num_features): - interaction = keras.ops.dot( - all_feature_outputs[i], keras.ops.transpose(all_feature_outputs[j]) - ) - interaction_strength = keras.ops.convert_to_numpy(np.abs(interaction))[0][0] - interaction_matrix[i, j] = interaction_strength - - return interaction_matrix, all_feature_names - - -# Get the list of categorical feature names. -categorical_feature_names = ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] -) - -# Calculate the interaction matrix with the corrected function. -interaction_matrix, feature_names = get_dot_interaction_matrix( - model=dot_network, - categorical_features=categorical_feature_names, - continuous_features=MOVIELENS_CONFIG["continuous_features"], -) - -# Visualize the matrix as a heatmap. -print("\nVisualizing the feature interaction strengths:") -visualize_layer(interaction_matrix, feature_names) -``` - -![png](/img/examples/keras_rs/dlrm/dlrm_21_1.png) - - - diff --git a/templates/examples/keras_rs/listwise_ranking.md b/templates/examples/keras_rs/listwise_ranking.md deleted file mode 100644 index 7143859333..0000000000 --- a/templates/examples/keras_rs/listwise_ranking.md +++ /dev/null @@ -1,669 +0,0 @@ -# List-wise ranking - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Rank movies using pairwise losses instead of pointwise losses. - - -
ⓘ This example uses Keras 3</div>
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/listwise_ranking.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/listwise_ranking.py) - - - ---- -## Introduction - -In our -[basic ranking tutorial](/keras_rs/examples/basic_ranking/), we explored a model -that learned to predict ratings for specific user-movie combinations. This model -took (user, movie) pairs as input and was trained using mean-squared error to -precisely predict the rating a user might give to a movie. - -However, solely optimizing a model's accuracy in predicting individual movie -scores isn't always the most effective strategy for developing ranking systems. -For ranking models, pinpoint accuracy in predicting scores is less critical than -the model's capability to generate an ordered list of items that aligns with a -user's preferences. In essence, the relative order of items matters more than -the exact predicted values. - -Instead of focusing on the model's predictions for individual query-item pairs -(a pointwise approach), we can optimize the model based on its ability to -correctly order items. One common method for this is pairwise ranking. In this -approach, the model learns by comparing pairs of items (e.g., item A and item B) -and determining which one should be ranked higher for a given user or query. The -goal is to minimize the number of incorrectly ordered pairs. - -Let's begin by importing all the necessary libraries. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import collections - -import keras -import numpy as np -import tensorflow as tf # Needed only for the dataset -import tensorflow_datasets as tfds -from keras import ops - -import keras_rs -``` - -Let's define some hyperparameters here. - - -```python -# Data args -TRAIN_NUM_LIST_PER_USER = 50 -TEST_NUM_LIST_PER_USER = 1 -NUM_EXAMPLES_PER_LIST = 5 - -# Model args -EMBEDDING_DIM = 32 - -# Train args -BATCH_SIZE = 1024 -EPOCHS = 5 -LEARNING_RATE = 0.1 -``` - ---- -## Preparing the dataset - -We use the MovieLens dataset. The data loading and processing steps are similar -to previous tutorials, so, we will only discuss the differences here. - - -```python -# Ratings data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") - -users_count = ( - ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) - .reduce(tf.constant(0, tf.int32), tf.maximum) - .numpy() -) -movies_count = movies.cardinality().numpy() - - -def preprocess_rating(x): - return { - "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), - "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), - # Normalise ratings between 0 and 1. - "user_rating": (x["user_rating"] - 1.0) / 4.0, - } - - -shuffled_ratings = ratings.map(preprocess_rating).shuffle( - 100_000, seed=42, reshuffle_each_iteration=False -) -train_ratings = shuffled_ratings.take(70_000) -val_ratings = shuffled_ratings.skip(70_000).take(15_000) -test_ratings = shuffled_ratings.skip(85_000).take(15_000) -``` - -So far, we've replicated what we have in the basic ranking tutorial. - -However, this existing dataset is not directly applicable to list-wise -optimization. 
List-wise optimization requires, for each user, a list of movies -they have rated, allowing the model to learn from the relative orderings within -that list. The MovieLens 100K dataset, in its original form, provides individual -rating instances (one user, one movie, one rating per example), rather than -these aggregated user-specific lists. - -To enable listwise optimization, we need to restructure the dataset. This -involves transforming it so that each data point or example represents a single -user ID accompanied by a list of movies that user has rated. Within these lists, -some movies will naturally be ranked higher by the user (as evidenced by their -ratings) than others. The primary objective for our model will then be to learn -to predict item orderings that correspond to these observed user preferences. - -Let's start by getting the entire list of movies and corresponding ratings for -every user. We remove `user_ids` corresponding to users who have rated less than -`NUM_EXAMPLES_PER_LIST` number of movies. - - -```python - -def get_movie_sequence_per_user(ratings, min_examples_per_list): - """Gets movieID sequences and ratings for every user.""" - sequences = collections.defaultdict(list) - - for sample in ratings: - user_id = sample["user_id"] - movie_id = sample["movie_id"] - user_rating = sample["user_rating"] - - sequences[int(user_id.numpy())].append( - { - "movie_id": int(movie_id.numpy()), - "user_rating": float(user_rating.numpy()), - } - ) - - # Remove lists with < `min_examples_per_list` number of elements. - sequences = { - user_id: sequence - for user_id, sequence in sequences.items() - if len(sequence) >= min_examples_per_list - } - - return sequences - -``` - -We now sample 50 lists for each user for the training data. For each list, we -randomly sample 5 movies from the movies the user rated. 
-
-
-```python
-
-def sample_sublist_from_list(
-    lst,
-    num_examples_per_list,
-):
-    """Randomly selects `num_examples_per_list` number of elements from list."""
-
-    indices = np.random.choice(
-        range(len(lst)),
-        size=num_examples_per_list,
-        replace=False,
-    )
-
-    samples = [lst[i] for i in indices]
-    return samples
-
-
-def get_examples(
-    sequences,
-    num_list_per_user,
-    num_examples_per_list,
-):
-    inputs = {
-        "user_id": [],
-        "movie_id": [],
-    }
-    labels = []
-    # Note: a single list is sampled per user below; `num_list_per_user` is
-    # not used in this implementation.
-    for user_id, user_list in sequences.items():
-        sampled_list = sample_sublist_from_list(
-            user_list,
-            num_examples_per_list,
-        )
-
-        inputs["user_id"].append(user_id)
-        inputs["movie_id"].append(
-            tf.convert_to_tensor([f["movie_id"] for f in sampled_list])
-        )
-        labels.append(tf.convert_to_tensor([f["user_rating"] for f in sampled_list]))
-
-    return (
-        {"user_id": inputs["user_id"], "movie_id": inputs["movie_id"]},
-        labels,
-    )
-
-
-train_sequences = get_movie_sequence_per_user(
-    ratings=train_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST
-)
-train_examples = get_examples(
-    train_sequences,
-    num_list_per_user=TRAIN_NUM_LIST_PER_USER,
-    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
-)
-train_ds = tf.data.Dataset.from_tensor_slices(train_examples)
-
-val_sequences = get_movie_sequence_per_user(
-    ratings=val_ratings, min_examples_per_list=5
-)
-val_examples = get_examples(
-    val_sequences,
-    num_list_per_user=TEST_NUM_LIST_PER_USER,
-    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
-)
-val_ds = tf.data.Dataset.from_tensor_slices(val_examples)
-
-test_sequences = get_movie_sequence_per_user(
-    ratings=test_ratings, min_examples_per_list=5
-)
-test_examples = get_examples(
-    test_sequences,
-    num_list_per_user=TEST_NUM_LIST_PER_USER,
-    num_examples_per_list=NUM_EXAMPLES_PER_LIST,
-)
-test_ds = tf.data.Dataset.from_tensor_slices(test_examples)
-```
-
-Batch up the dataset, and cache it.
-
-
-```python
-train_ds = train_ds.batch(BATCH_SIZE).cache()
-val_ds = val_ds.batch(BATCH_SIZE).cache()
-test_ds = test_ds.batch(BATCH_SIZE).cache()
-```
-
----
-## Building the model
-
-We build a typical two-tower ranking model, similar to the
-[basic ranking tutorial](/keras_rs/examples/basic_ranking/).
-We have separate embedding layers for user ID and movie IDs. After obtaining
-these embeddings, we concatenate them and pass them through a network of dense
-layers.
-
-The only point of difference is that for movie IDs, we take a list of IDs
-rather than just one movie ID. So, when we concatenate the user ID embedding and
-the movie IDs' embeddings, we "repeat" the user ID `NUM_EXAMPLES_PER_LIST` times
-so as to get the same shape as the movie IDs' embeddings.
-
-
-```python
-
-class RankingModel(keras.Model):
-    """Create the ranking model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Embedding table for users.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Embedding table for candidates.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # Predictions.
-        self.ratings = keras.Sequential(
-            [
-                # Learn multiple dense layers.
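-                # (Note added for clarity: the tower narrows from 256 units to
-                # 64 and finally to a single unit, which is the predicted
-                # rating score for one movie of the list.)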
- keras.layers.Dense(256, activation="relu"), - keras.layers.Dense(64, activation="relu"), - # Make rating predictions in the final layer. - keras.layers.Dense(1), - ] - ) - - def build(self, input_shape): - self.user_embedding.build(input_shape["user_id"]) - self.candidate_embedding.build(input_shape["movie_id"]) - - output_shape = self.candidate_embedding.compute_output_shape( - input_shape["movie_id"] - ) - - self.ratings.build(list(output_shape[:-1]) + [2 * output_shape[-1]]) - - def call(self, inputs): - user_id, movie_id = inputs["user_id"], inputs["movie_id"] - user_embeddings = self.user_embedding(user_id) - candidate_embeddings = self.candidate_embedding(movie_id) - - list_length = ops.shape(movie_id)[-1] - user_embeddings_repeated = ops.repeat( - ops.expand_dims(user_embeddings, axis=1), - repeats=list_length, - axis=1, - ) - concatenated_embeddings = ops.concatenate( - [user_embeddings_repeated, candidate_embeddings], axis=-1 - ) - - scores = self.ratings(concatenated_embeddings) - scores = ops.squeeze(scores, axis=-1) - - return scores - - def compute_output_shape(self, input_shape): - return (input_shape[0], input_shape[1]) - -``` - -Let's instantiate, compile and train our model. We will train two models: -one with vanilla mean-squared error, and the other with pairwise hinge loss. -For the latter, we will use `keras_rs.losses.PairwiseHingeLoss`. - -Pairwise losses compare pairs of items within each list, penalizing cases where -an item with a higher true label has a lower predicted score than an item with a -lower true label. This is why they are more suited for ranking tasks than -pointwise losses. - -To quantify these results, we compute nDCG. nDCG is a measure of ranking quality -that evaluates how well a system orders items based on relevance, giving more -importance to highly relevant items appearing at the top of the list and -normalizing the score against an ideal ranking. -To compute it, we just need to pass `keras_rs.metrics.NDCG()` as a metric to -`model.compile`. - - -```python -model_mse = RankingModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - embedding_dimension=EMBEDDING_DIM, -) -model_mse.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")], - optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE), -) -model_mse.fit(train_ds, validation_data=val_ds, epochs=EPOCHS) -``` - -
```
Epoch 1/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step - loss: 0.4960 - ndcg: 0.8892 - val_loss: 0.1187 - val_ndcg: 0.8846
Epoch 2/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 0.1150 - ndcg: 0.8898 - val_loss: 0.0893 - val_ndcg: 0.8878
Epoch 3/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.0876 - ndcg: 0.8884 - val_loss: 0.0864 - val_ndcg: 0.8857
Epoch 4/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.0834 - ndcg: 0.8896 - val_loss: 0.0815 - val_ndcg: 0.8876
Epoch 5/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.0794 - ndcg: 0.8887 - val_loss: 0.0810 - val_ndcg: 0.8868
```
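Before training the second model, it helps to see concretely what a pairwise hinge loss computes. Below is a minimal NumPy sketch of the idea (an illustration of the pairwise formulation, not the actual `keras_rs.losses.PairwiseHingeLoss` implementation): every pair where the true rating of item `i` exceeds that of item `j` contributes a penalty when the predicted margin `scores[i] - scores[j]` falls below 1.


```python
import numpy as np

# Toy list of 5 movies: true ratings (labels) and predicted scores.
labels = np.array([4.0, 2.0, 5.0, 1.0, 3.0])
scores = np.array([0.8, 0.3, 0.5, 0.1, 0.6])

loss = 0.0
for i in range(len(labels)):
    for j in range(len(labels)):
        if labels[i] > labels[j]:
            # Item i should outrank item j: hinge on the score margin.
            loss += max(0.0, 1.0 - (scores[i] - scores[j]))
print(loss)
```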
-And now, the model with pairwise hinge loss. - - -```python -model_hinge = RankingModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - embedding_dimension=EMBEDDING_DIM, -) -model_hinge.compile( - loss=keras_rs.losses.PairwiseHingeLoss(), - metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")], - optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE), -) -model_hinge.fit(train_ds, validation_data=val_ds, epochs=EPOCHS) -``` - -
```
Epoch 1/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 1.4067 - ndcg: 0.8933 - val_loss: 1.3927 - val_ndcg: 0.8930
Epoch 2/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 1.4061 - ndcg: 0.8953 - val_loss: 1.3925 - val_ndcg: 0.8936
Epoch 3/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 1.4054 - ndcg: 0.8977 - val_loss: 1.3923 - val_ndcg: 0.8941
Epoch 4/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 1.4047 - ndcg: 0.8999 - val_loss: 1.3921 - val_ndcg: 0.8941
Epoch 5/5
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 1.4041 - ndcg: 0.9004 - val_loss: 1.3919 - val_ndcg: 0.8940
```
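Before comparing the two models on the test set, here is a small hand computation of nDCG, using the common exponential-gain formulation (a sketch of the metric's definition; `keras_rs.metrics.NDCG` may differ in details such as the gain and rank-discount functions):


```python
import numpy as np


def dcg(relevances):
    # Gain (2^rel - 1) discounted by log2 of the 1-indexed rank + 1.
    return sum(
        (2.0**rel - 1.0) / np.log2(rank + 2.0)
        for rank, rel in enumerate(relevances)
    )


# True relevance of items, in the order the model ranked them.
ranked = [3, 2, 3, 0, 1]
ideal = sorted(ranked, reverse=True)  # best possible ordering

print(dcg(ranked) / dcg(ideal))  # nDCG in [0, 1]; 1.0 means a perfect ranking
```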
---
## Evaluation

Comparing the validation nDCG values, it is clear that the model trained with
the pairwise hinge loss outperforms the other one. Let's make this observation
more concrete by comparing results on the test set.


```python
ndcg_mse = model_mse.evaluate(test_ds, return_dict=True)["ndcg"]
ndcg_hinge = model_hinge.evaluate(test_ds, return_dict=True)["ndcg"]
print(ndcg_mse, ndcg_hinge)
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 0.0805 - ndcg: 0.8886
 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 914ms/step - loss: 1.3878 - ndcg: 0.8924
0.8885537385940552 0.8924424052238464
```
---
## Prediction

Now, let's rank some lists!

Let's create a mapping from movie ID to title so that we can surface the titles
for the ranked list.


```python
movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 42
movie_ids = [409, 237, 131, 941, 543]
predictions = model_hinge.predict(
    {
        "user_id": keras.ops.array([user_id]),
        "movie_id": keras.ops.array([movie_ids]),
    }
)
predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=0))
# Order the candidate movies by predicted score (`np.argsort` sorts in
# ascending order).
sorted_indices = np.argsort(predictions)
sorted_movies = [movie_ids[i] for i in sorted_indices]

for i, movie_id in enumerate(sorted_movies):
    print(f"{i + 1}. ", movie_id_to_movie_title[movie_id])
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 262ms/step
1. b'Jack (1996)'
2. b'Mis\xc3\xa9rables, Les (1995)'
3. b'Jerry Maguire (1996)'
4. b"Breakfast at Tiffany's (1961)"
5. b'With Honors (1994)'
```
-And we're all done! - diff --git a/templates/examples/keras_rs/multi_task.md b/templates/examples/keras_rs/multi_task.md deleted file mode 100644 index 6b764b1f22..0000000000 --- a/templates/examples/keras_rs/multi_task.md +++ /dev/null @@ -1,1464 +0,0 @@ -# Multi-task recommenders: retrieval + ranking - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Using one model for both retrieval and ranking. - - -
ⓘ This example uses Keras 3
[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/multi_task.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/multi_task.py)


---
## Introduction

In the
[basic retrieval](/keras_rs/examples/basic_retrieval/)
and
[basic ranking](/keras_rs/examples/basic_ranking/)
tutorials, we created separate models for retrieval and ranking tasks,
respectively. However, in many cases, building a single, joint model for
multiple tasks can lead to better performance than creating distinct models for
each task. This is especially true when dealing with data that is unevenly
distributed — such as abundant data (e.g., clicks) versus sparse data
(e.g., purchases, returns, or manual reviews). In such scenarios, a joint model
can leverage representations learned from the abundant data to improve
predictions on the sparse data, a technique known as transfer learning.
For instance, [research](https://openreview.net/forum?id=SJxPVcSonN) shows that
a model trained to predict user ratings from sparse survey data can be
significantly enhanced by incorporating an auxiliary task using abundant click
log data.

In this example, we develop a multi-objective recommender system using the
MovieLens dataset. We incorporate both implicit feedback (e.g., movie watches)
and explicit feedback (e.g., ratings) to create a more robust and effective
recommendation model. For the former, we predict "movie watches", i.e., whether
a user has watched a movie, and for the latter, we predict the rating given by
a user to a movie.

Let's start by importing the necessary packages.


```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

import keras_rs
```

---
## Prepare the dataset

We use the MovieLens dataset. The data loading and processing steps are similar
to previous tutorials, so we will not discuss them in detail here.


```python
# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
```

Get user and movie counts so that we can define embedding layers.


```python
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

movies_count = movies.cardinality().numpy()
```

Our inputs are `"user_id"` and `"movie_id"`. Our label for the ranking task is
`"user_rating"`. `"user_rating"` is an integer between 1 and 5. We rescale it
to `[0, 1]`.


```python

def preprocess_rating(x):
    return (
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        (x["user_rating"] - 1.0) / 4.0,
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)

```

Split the dataset into train-test sets.


```python
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
```

---
## Building the model

We build the model in a similar way to the basic retrieval and basic ranking
guides.
- -For the retrieval task (i.e., predicting whether a user watched a movie), -we compute the similarity of the corresponding user and movie embeddings, and -use cross entropy loss, where the positive pairs are labelled one, and all other -samples in the batch are considered "negatives". We report top-k accuracy for -this task. - -For the ranking task (i.e., given a user-movie pair, predict rating), we -concatenate user and movie embeddings and pass it to a dense module. We use -MSE loss here, and report the Root Mean Squared Error (RMSE). - -The final loss is a weighted combination of the two losses mentioned above, -where the weights are `"retrieval_loss_wt"` and `"ranking_loss_wt"`. These -weights decide which task the model will focus on. - - -```python - -class MultiTaskModel(keras.Model): - def __init__( - self, - num_users, - num_candidates, - embedding_dimension=32, - layer_sizes=(256, 128), - retrieval_loss_wt=1.0, - ranking_loss_wt=1.0, - **kwargs, - ): - super().__init__(**kwargs) - # Our query tower, simply an embedding table. - self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) - - # Our candidate tower, simply an embedding table. - self.candidate_embedding = keras.layers.Embedding( - num_candidates, embedding_dimension - ) - - # Rating model. - self.rating_model = tf.keras.Sequential( - [ - keras.layers.Dense(layer_size, activation="relu") - for layer_size in layer_sizes - ] - + [keras.layers.Dense(1)] - ) - - # The layer that performs the retrieval. - self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False) - - self.retrieval_loss_fn = keras.losses.CategoricalCrossentropy( - from_logits=True, - reduction="sum", - ) - self.ranking_loss_fn = keras.losses.MeanSquaredError() - - # Top-k accuracy for retrieval - self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy( - k=100, from_sorted_ids=True - ) - # RMSE for ranking - self.rmse_metric = keras.metrics.RootMeanSquaredError() - - # Attributes. - self.num_users = num_users - self.num_candidates = num_candidates - self.embedding_dimension = embedding_dimension - self.layer_sizes = layer_sizes - self.retrieval_loss_wt = retrieval_loss_wt - self.ranking_loss_wt = ranking_loss_wt - - def build(self, input_shape): - self.user_embedding.build(input_shape) - self.candidate_embedding.build(input_shape) - # In this case, the candidates are directly the movie embeddings. - # We take a shortcut and directly reuse the variable. - self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings - self.retrieval.build(input_shape) - - self.rating_model.build((None, 2 * self.embedding_dimension)) - - super().build(input_shape) - - def call(self, inputs, training=False): - # Unpack inputs. Note that we have the if condition throughout this - # `call()` method so that we can do a `.predict()` for the retrieval - # task. - user_id = inputs["user_id"] - if "movie_id" in inputs: - movie_id = inputs["movie_id"] - - result = {} - - # Get user, movie embeddings. - user_embeddings = self.user_embedding(user_id) - result["user_embeddings"] = user_embeddings - - if "movie_id" in inputs: - candidate_embeddings = self.candidate_embedding(movie_id) - result["candidate_embeddings"] = candidate_embeddings - - # Pass both embeddings through the rating block of the model. 
- rating = self.rating_model( - keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1) - ) - result["rating"] = rating - - if not training: - # Skip the retrieval of top movies during training as the - # predictions are not used. - result["predictions"] = self.retrieval(user_embeddings) - - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - user_embeddings = y_pred["user_embeddings"] - candidate_embeddings = y_pred["candidate_embeddings"] - - # 1. Retrieval - - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.matmul( - user_embeddings, - keras.ops.transpose(candidate_embeddings), - ) - - # Retrieval labels: One-hot vectors - num_users = keras.ops.shape(user_embeddings)[0] - num_candidates = keras.ops.shape(candidate_embeddings)[0] - retrieval_labels = keras.ops.eye(num_users, num_candidates) - # Retrieval loss - retrieval_loss = self.retrieval_loss_fn(retrieval_labels, scores, sample_weight) - - # 2. Ranking - ratings = y - pred_rating = y_pred["rating"] - - # Ranking labels are just ratings. - ranking_labels = keras.ops.expand_dims(ratings, -1) - # Ranking loss - ranking_loss = self.ranking_loss_fn(ranking_labels, pred_rating, sample_weight) - - # Total loss is a weighted combination of the two losses. - total_loss = ( - self.retrieval_loss_wt * retrieval_loss - + self.ranking_loss_wt * ranking_loss - ) - - return total_loss - - def compute_metrics(self, x, y, y_pred, sample_weight=None): - # RMSE can be computed irrespective of whether we are - # training/evaluating. - self.rmse_metric.update_state( - y, - y_pred["rating"], - sample_weight=sample_weight, - ) - - if "predictions" in y_pred: - # We are evaluating or predicting. Update `top_k_metric`. - movie_ids = x["movie_id"] - predictions = y_pred["predictions"] - # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we - # only take top rated movies, and we put a weight of 0 for the rest. - rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32") - sample_weight = ( - rating_weight - if sample_weight is None - else keras.ops.multiply(rating_weight, sample_weight) - ) - self.top_k_metric.update_state( - movie_ids, predictions, sample_weight=sample_weight - ) - - return self.get_metrics_result() - else: - # We are training. `top_k_metric` is not updated and is zero, so - # don't report it. - result = self.get_metrics_result() - result.pop(self.top_k_metric.name) - return result - -``` - ---- -## Training and evaluating - -We will train three different models here. This can be done easily by passing -the correct loss weights: - -1. Rating-specialised model -2. Retrieval-specialised model -3. 
Multi-task model - - -```python -# Rating-specialised model -model = MultiTaskModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - ranking_loss_wt=1.0, - retrieval_loss_wt=0.0, -) -model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1)) -model.fit(train_ratings, epochs=5) - -model.evaluate(test_ratings) - -# Retrieval-specialised model -model = MultiTaskModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - ranking_loss_wt=0.0, - retrieval_loss_wt=1.0, -) -model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1)) -model.fit(train_ratings, epochs=5) - -model.evaluate(test_ratings) - -# Multi-task model -model = MultiTaskModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - ranking_loss_wt=1.0, - retrieval_loss_wt=1.0, -) -model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1)) -model.fit(train_ratings, epochs=5) - -model.evaluate(test_ratings) -``` - -
```
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 4s 13ms/step - loss: 0.1042 - root_mean_squared_error: 0.3180
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0777 - root_mean_squared_error: 0.2787
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0759 - root_mean_squared_error: 0.2755
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0735 - root_mean_squared_error: 0.2710
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0713 - root_mean_squared_error: 0.2669
 20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 38ms/step - loss: 0.0670 - root_mean_squared_error: 0.2589 - sparse_top_k_categorical_accuracy: 0.0046
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 6861.2632 - root_mean_squared_error: 0.6933
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 6527.7217 - root_mean_squared_error: 0.6952
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6379.3403 - root_mean_squared_error: 0.7077
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6291.6152 - root_mean_squared_error: 0.7164
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6227.6514 - root_mean_squared_error: 0.7231
 20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step - loss: 6558.5298 - root_mean_squared_error: 0.7323 - sparse_top_k_categorical_accuracy: 0.0156
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 10ms/step - loss: 6861.9297 - root_mean_squared_error: 0.3174
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 6526.9360 - root_mean_squared_error: 0.2591
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6379.2231 - root_mean_squared_error: 0.2537
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6290.9727 - root_mean_squared_error: 0.2506
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6226.5605 - root_mean_squared_error: 0.2485
 20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 38ms/step - loss: 6554.1953 - root_mean_squared_error: 0.2492 - sparse_top_k_categorical_accuracy: 0.0158

[6555.712890625, 0.016953036189079285, 0.2508334815502167]
```
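The list returned by the last `model.evaluate` call contains, in order, the total loss, the top-k accuracy, and the RMSE. To avoid depending on this ordering, you can also request a dictionary; the metric names below are the ones that appear in the training logs above:


```python
results = model.evaluate(test_ratings, return_dict=True)
print(results["loss"])
print(results["sparse_top_k_categorical_accuracy"])
print(results["root_mean_squared_error"])
```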
Let's put the metrics in a table and pen down our observations:

| Model                 | Top-K Accuracy (↑) | RMSE (↓) |
|-----------------------|--------------------|----------|
| rating-specialised    | 0.005              | 0.26     |
| retrieval-specialised | 0.020              | 0.78     |
| multi-task            | 0.022              | 0.25     |

As expected, the rating-specialised model has good RMSE, but poor top-k
accuracy. For the retrieval-specialised model, it's the opposite.

For the multi-task model, we notice that the model does well (or even slightly
better than the two specialised models) on both tasks. In general, we can
expect multi-task learning to bring about better results, especially when one
task has a data-abundant source, and the other task is trained on sparse data.

Now, let's make a prediction! We will first do a retrieval, and then, for the
retrieved list of movies, we will predict the rating using the same model.


```python
movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 5
retrieved_movie_ids = model.predict(
    {
        "user_id": keras.ops.array([user_id]),
    }
)
retrieved_movie_ids = keras.ops.convert_to_numpy(retrieved_movie_ids["predictions"][0])
retrieved_movies = [movie_id_to_movie_title[x] for x in retrieved_movie_ids]
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 110ms/step
```

For these retrieved movies, we can now get the corresponding ratings.


```python
pred_ratings = model.predict(
    {
        "user_id": keras.ops.array([user_id] * len(retrieved_movie_ids)),
        "movie_id": keras.ops.array(retrieved_movie_ids),
    }
)["rating"]
pred_ratings = keras.ops.convert_to_numpy(keras.ops.squeeze(pred_ratings, axis=1))

for movie_id, prediction in zip(retrieved_movie_ids, pred_ratings):
    print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 274ms/step
b'Blob, The (1958)': 2.01
b'Mighty Morphin Power Rangers: The Movie (1995)': 2.03
b'Flintstones, The (1994)': 2.18
b'Beverly Hillbillies, The (1993)': 1.89
b'Lawnmower Man, The (1992)': 2.57
b'Hot Shots! Part Deux (1993)': 2.28
b'Street Fighter (1994)': 1.84
b'Cabin Boy (1994)': 1.94
b'Little Rascals, The (1994)': 2.12
b'Jaws 3-D (1983)': 2.27
```
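Note that the `5.0 * prediction` rescaling in the loop above is only a rough conversion back to the star scale. Since `preprocess_rating` mapped ratings with `(rating - 1.0) / 4.0`, the exact inverse would be the following small helper:


```python
def denormalize_rating(pred):
    # Exact inverse of the `(x["user_rating"] - 1.0) / 4.0` normalization
    # applied in `preprocess_rating`: maps [0, 1] back to the 1-5 scale.
    return 4.0 * pred + 1.0
```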
diff --git a/templates/examples/keras_rs/sas_rec.md b/templates/examples/keras_rs/sas_rec.md deleted file mode 100644 index 40305906d3..0000000000 --- a/templates/examples/keras_rs/sas_rec.md +++ /dev/null @@ -1,2972 +0,0 @@ -# Sequential retrieval using SASRec - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Recommend movies using a Transformer-based retrieval model (SASRec). - - -
ⓘ This example uses Keras 3
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/sas_rec.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/sas_rec.py) - - - ---- -## Introduction - -Sequential recommendation is a popular model that looks at a sequence of items -that users have interacted with previously and then predicts the next item. -Here, the order of the items within each sequence matters. Previously, in the -[Recommending movies: retrieval using a sequential model](/keras_rs/examples/sequential_retrieval/) -example, we built a GRU-based sequential retrieval model. In this example, we -will build a popular Transformer decoder-based model named -[Self-Attentive Sequential Recommendation (SASRec)](https://arxiv.org/abs/1808.09781) -for the same sequential recommendation task. - -Let's begin by importing all the necessary libraries. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import collections -import os - -import keras -import keras_hub -import numpy as np -import pandas as pd -import tensorflow as tf # Needed only for the dataset -from keras import ops - -import keras_rs -``` - -Let's also define all important variables/hyperparameters below. - - -```python -DATA_DIR = "./raw/data/" - -# MovieLens-specific variables -MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip" -MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20" - -RATINGS_FILE_NAME = "ratings.dat" -MOVIES_FILE_NAME = "movies.dat" - -# Data processing args -MAX_CONTEXT_LENGTH = 200 -MIN_SEQUENCE_LENGTH = 3 -PAD_ITEM_ID = 0 - -RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"] -MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"] -MIN_RATING = 2 - -# Training/model args picked from SASRec paper -BATCH_SIZE = 128 -NUM_EPOCHS = 10 -LEARNING_RATE = 0.001 - -NUM_LAYERS = 2 -NUM_HEADS = 1 -HIDDEN_DIM = 50 -DROPOUT = 0.2 -``` - ---- -## Dataset - -Next, we need to prepare our dataset. Like we did in the -[sequential retrieval](/keras_rs/examples/sequential_retrieval/) -example, we are going to use the MovieLens dataset. - -The dataset preparation step is fairly involved. The original ratings dataset -contains `(user, movie ID, rating, timestamp)` tuples (among other columns, -which are not important for this example). Since we are dealing with sequential -retrieval, we need to create movie sequences for every user, where the sequences -are ordered by timestamp. - -Let's start by downloading and reading the dataset. - - -```python -# Download the MovieLens dataset. -if not os.path.exists(DATA_DIR): - os.makedirs(DATA_DIR) - -path_to_zip = keras.utils.get_file( - fname="ml-1m.zip", - origin=MOVIELENS_1M_URL, - file_hash=MOVIELENS_ZIP_HASH, - hash_algorithm="sha256", - extract=True, - cache_dir=DATA_DIR, -) -movielens_extracted_dir = os.path.join( - os.path.dirname(path_to_zip), - "ml-1m_extracted", - "ml-1m", -) - - -# Read the dataset. -def read_data(data_directory, min_rating=None): - """Read movielens ratings.dat and movies.dat file - into dataframe. - """ - - ratings_df = pd.read_csv( - os.path.join(data_directory, RATINGS_FILE_NAME), - sep="::", - names=RATINGS_DATA_COLUMNS, - encoding="unicode_escape", - ) - ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int) - - # Remove movies with `rating < min_rating`. 
- if min_rating is not None: - ratings_df = ratings_df[ratings_df["Rating"] >= min_rating] - - movies_df = pd.read_csv( - os.path.join(data_directory, MOVIES_FILE_NAME), - sep="::", - names=MOVIES_DATA_COLUMNS, - encoding="unicode_escape", - ) - return ratings_df, movies_df - - -ratings_df, movies_df = read_data( - data_directory=movielens_extracted_dir, min_rating=MIN_RATING -) - -# Need to know #movies so as to define embedding layers. -movies_count = movies_df["MovieID"].max() -``` - -
```
Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
 5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step
```

```
:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  ratings_df = pd.read_csv(

:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
  movies_df = pd.read_csv(
```
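The `ParserWarning`s above are raised because the two-character separator `"::"` is treated as a regex, which the fast C engine does not support. As the warning itself suggests, passing `engine="python"` makes the fallback explicit and silences the warning; for example, for the ratings file:


```python
ratings_df = pd.read_csv(
    os.path.join(movielens_extracted_dir, RATINGS_FILE_NAME),
    sep="::",
    names=RATINGS_DATA_COLUMNS,
    encoding="unicode_escape",
    engine="python",  # Explicitly opt in to the engine pandas falls back to.
)
```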
-Now that we have read the dataset, let's create sequences of movies -for every user. Here is the function for doing just that. - - -```python - -def get_movie_sequence_per_user(ratings_df): - """Get movieID sequences for every user.""" - sequences = collections.defaultdict(list) - - for user_id, movie_id, rating, timestamp in ratings_df.values: - sequences[user_id].append( - { - "movie_id": movie_id, - "timestamp": timestamp, - "rating": rating, - } - ) - - # Sort movie sequences by timestamp for every user. - for user_id, context in sequences.items(): - context.sort(key=lambda x: x["timestamp"]) - sequences[user_id] = context - - return sequences - - -sequences = get_movie_sequence_per_user(ratings_df) -``` - -So far, we have essentially replicated what we did in the sequential retrieval -example. We have a sequence of movies for every user. - -SASRec is trained contrastively, which means the model learns to distinguish -between sequences of movies a user has actually interacted with (positive -examples) and sequences they have not interacted with (negative examples). - -The following function, `format_data`, prepares the data in this specific -format. For each user's movie sequence, it generates a corresponding -"negative sequence". This negative sequence consists of randomly -selected movies that the user has *not* interacted with, but are of the same -length as the original sequence. - - -```python - -def format_data(sequences): - examples = { - "sequence": [], - "negative_sequence": [], - } - - for user_id in sequences: - sequence = [int(d["movie_id"]) for d in sequences[user_id]] - - # Get negative sequence. - def random_negative_item_id(low, high, positive_lst): - sampled = np.random.randint(low=low, high=high) - while sampled in positive_lst: - sampled = np.random.randint(low=low, high=high) - return sampled - - negative_sequence = [ - random_negative_item_id(1, movies_count + 1, sequence) - for _ in range(len(sequence)) - ] - - examples["sequence"].append(np.array(sequence)) - examples["negative_sequence"].append(np.array(negative_sequence)) - - examples["sequence"] = tf.ragged.constant(examples["sequence"]) - examples["negative_sequence"] = tf.ragged.constant(examples["negative_sequence"]) - - return examples - - -examples = format_data(sequences) -ds = tf.data.Dataset.from_tensor_slices(examples).batch(BATCH_SIZE) -``` - -Now that we have the original movie interaction sequences for each user (from -`format_data`, stored in `examples["sequence"]`) and their corresponding -random negative sequences (in `examples["negative_sequence"]`), the next step is -to prepare this data for input to the model. The primary goals of this -preprocessing are: - -1. Creating Input Features and Target Labels: For sequential - recommendation, the model learns to predict the next item in a sequence - given the preceding items. This is achieved by: - - taking the original `example["sequence"]` and creating the model's - input features (`item_ids`) from all items *except the last one* - (`example["sequence"][..., :-1]`); - - creating the target "positive sequence" (what the model tries to predict - as the actual next items) by taking the original `example["sequence"]` - and shifting it, using all items *except the first one* - (`example["sequence"][..., 1:]`); - - shifting `example["negative_sequence"]` (from `format_data`) is - to create the target "negative sequence" for the contrastive loss - (`example["negative_sequence"][..., 1:]`). - -2. 
Handling Variable Length Sequences: Neural networks typically require - fixed-size inputs. Therefore, both the input feature sequences and the - target sequences are padded (with a special `PAD_ITEM_ID`) or truncated - to a predefined `MAX_CONTEXT_LENGTH`. A `padding_mask` is also generated - from the input features to ensure the model ignores these padded tokens - during attention calculations, i.e, these tokens will be masked. - -3. Differentiating Training and Validation/Testing: - - During training: - - Input features (`item_ids`) and context for negative sequences - are prepared as described above (all but the last item of the - original sequences). - - Target positive and negative sequences are the shifted versions of - the original sequences. - - `sample_weight` is created based on the input features to ensure - that loss is calculated only on actual items, not on padding tokens - in the targets. - - During validation/testing: - - Input features are prepared similarly. - - The model's performance is typically evaluated on its ability to - predict the actual last item of the original sequence. Thus, - `sample_weight` is configured to focus the loss calculation - only on this final prediction in the target sequences. - -Note: SASRec does the same thing we've done above, except that they take the -`item_ids[:-2]` for the validation set and `item_ids[:-1]` for the test set. -We skip that here for brevity. - - -```python - -def _preprocess(example, train=False): - sequence = example["sequence"] - negative_sequence = example["negative_sequence"] - - if train: - sequence = example["sequence"][..., :-1] - negative_sequence = example["negative_sequence"][..., :-1] - - batch_size = tf.shape(sequence)[0] - - if not train: - # Loss computed only on last token. - sample_weight = tf.zeros_like(sequence, dtype="float32")[..., :-1] - sample_weight = tf.concat( - [sample_weight, tf.ones((batch_size, 1), dtype="float32")], axis=1 - ) - - # Truncate/pad sequence. +1 to account for truncation later. - sequence = sequence.to_tensor( - shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID - ) - negative_sequence = negative_sequence.to_tensor( - shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID - ) - if train: - sample_weight = tf.cast(sequence != PAD_ITEM_ID, dtype="float32") - else: - sample_weight = sample_weight.to_tensor( - shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=0 - ) - - example = ( - { - # last token does not have a next token - "item_ids": sequence[..., :-1], - # padding mask for controlling attention mask - "padding_mask": (sequence != PAD_ITEM_ID)[..., :-1], - }, - { - "positive_sequence": sequence[ - ..., 1: - ], # 0th token's label will be 1st token, and so on - "negative_sequence": negative_sequence[..., 1:], - }, - sample_weight[..., 1:], # loss will not be computed on pad tokens - ) - return example - - -def preprocess_train(examples): - return _preprocess(examples, train=True) - - -def preprocess_val(examples): - return _preprocess(examples, train=False) - - -train_ds = ds.map(preprocess_train) -val_ds = ds.map(preprocess_val) -``` - -We can see a batch for each. - - -```python -for batch in train_ds.take(1): - print(batch) - -for batch in val_ds.take(1): - print(batch) - -``` - -
```
({'item_ids': <tf.Tensor ...>, 'padding_mask': <tf.Tensor ...>}, {'positive_sequence': <tf.Tensor ...>, 'negative_sequence': <tf.Tensor ...>}, <tf.Tensor ...>)
({'item_ids': <tf.Tensor ...>, 'padding_mask': <tf.Tensor ...>}, {'positive_sequence': <tf.Tensor ...>, 'negative_sequence': <tf.Tensor ...>}, <tf.Tensor ...>)
```
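To make the shift-by-one construction above concrete, here is what `_preprocess` effectively does to a single toy sequence, before padding (the movie IDs are made up):


```python
import numpy as np

sequence = np.array([10, 42, 7, 99])  # movies in watch order

inputs = sequence[:-1]  # [10, 42, 7]  -> what the model sees
positive_targets = sequence[1:]  # [42, 7, 99] -> what it should predict
# At position t, the model reads inputs[: t + 1] and is trained to score
# positive_targets[t] above a randomly sampled negative movie ID.
print(inputs, positive_targets)
```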
---- -## Model - -To encode the input sequence, we use a Transformer decoder-based model. This -part of the model is very similar to the GPT-2 architecture. Refer to the -[GPT text generation from scratch with KerasHub](/examples/generative/text_generation_gpt/#build-the-model) -guide for more details on this part. - -One part to note is that when we are "predicting", i.e., `training` is `False`, -we get the embedding corresponding to the last movie in the sequence. This makes -sense, because at inference time, we want to predict the movie the user will -likely watch after watching the last movie. - -Also, it's worth discussing the `compute_loss` method. We embed the positive -and negative sequences using the input embedding matrix. We compute the -similarity of (positive sequence, input sequence) and (negative sequence, -input sequence) pair embeddings by computing the dot product. The goal now is -to maximize the similarity of the former and minimize the similarity of -the latter. Let's see this mathematically. Binary Cross Entropy is written -as follows: - -``` - loss = - (y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)) -``` - -Here, we assign the positive pairs a label of 1 and the negative pairs a label -of 0. So, for a positive pair, the loss reduces to: - -``` -loss = -np.log(positive_logits) -``` - -Minimising the loss means we want to maximize the log term, which in turn, -implies maximising `positive_logits`. Similarly, we want to minimize -`negative_logits`. - - -```python - -class SasRec(keras.Model): - def __init__( - self, - vocabulary_size, - num_layers, - num_heads, - hidden_dim, - dropout=0.0, - max_sequence_length=100, - dtype=None, - **kwargs, - ): - super().__init__(dtype=dtype, **kwargs) - - # ======== Layers ======== - - # === Embeddings === - self.item_embedding = keras_hub.layers.ReversibleEmbedding( - input_dim=vocabulary_size, - output_dim=hidden_dim, - embeddings_initializer="glorot_uniform", - embeddings_regularizer=keras.regularizers.l2(0.001), - dtype=dtype, - name="item_embedding", - ) - self.position_embedding = keras_hub.layers.PositionEmbedding( - initializer="glorot_uniform", - sequence_length=max_sequence_length, - dtype=dtype, - name="position_embedding", - ) - self.embeddings_add = keras.layers.Add( - dtype=dtype, - name="embeddings_add", - ) - self.embeddings_dropout = keras.layers.Dropout( - dropout, - dtype=dtype, - name="embeddings_dropout", - ) - - # === Decoder layers === - self.transformer_layers = [] - for i in range(num_layers): - self.transformer_layers.append( - keras_hub.layers.TransformerDecoder( - intermediate_dim=hidden_dim, - num_heads=num_heads, - dropout=dropout, - layer_norm_epsilon=1e-05, - # SASRec uses ReLU, although GeLU might be a better option - activation="relu", - kernel_initializer="glorot_uniform", - normalize_first=True, - dtype=dtype, - name=f"transformer_layer_{i}", - ) - ) - - # === Final layer norm === - self.layer_norm = keras.layers.LayerNormalization( - axis=-1, - epsilon=1e-8, - dtype=dtype, - name="layer_norm", - ) - - # === Retrieval === - # The layer that performs the retrieval. 
- self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False) - - # === Loss === - self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True, reduction=None) - - # === Attributes === - self.vocabulary_size = vocabulary_size - self.num_layers = num_layers - self.num_heads = num_heads - self.hidden_dim = hidden_dim - self.dropout = dropout - self.max_sequence_length = max_sequence_length - - def _get_last_non_padding_token(self, tensor, padding_mask): - valid_token_mask = ops.logical_not(padding_mask) - seq_lengths = ops.sum(ops.cast(valid_token_mask, "int32"), axis=1) - last_token_indices = ops.maximum(seq_lengths - 1, 0) - - indices = ops.expand_dims(last_token_indices, axis=(-2, -1)) - gathered_tokens = ops.take_along_axis(tensor, indices, axis=1) - last_token_embedding = ops.squeeze(gathered_tokens, axis=1) - - return last_token_embedding - - def build(self, input_shape): - embedding_shape = list(input_shape) + [self.hidden_dim] - - # Model - self.item_embedding.build(input_shape) - self.position_embedding.build(embedding_shape) - - self.embeddings_add.build((embedding_shape, embedding_shape)) - self.embeddings_dropout.build(embedding_shape) - - for transformer_layer in self.transformer_layers: - transformer_layer.build(decoder_sequence_shape=embedding_shape) - - self.layer_norm.build(embedding_shape) - - # Retrieval - self.retrieval.candidate_embeddings = self.item_embedding.embeddings - self.retrieval.build(input_shape) - - # Chain to super - super().build(input_shape) - - def call(self, inputs, training=False): - item_ids, padding_mask = inputs["item_ids"], inputs["padding_mask"] - - x = self.item_embedding(item_ids) - position_embedding = self.position_embedding(x) - x = self.embeddings_add((x, position_embedding)) - x = self.embeddings_dropout(x) - - for transformer_layer in self.transformer_layers: - x = transformer_layer(x, decoder_padding_mask=padding_mask) - - item_sequence_embedding = self.layer_norm(x) - result = {"item_sequence_embedding": item_sequence_embedding} - - # At inference, perform top-k retrieval. - if not training: - # need to extract last non-padding token. - last_item_embedding = self._get_last_non_padding_token( - item_sequence_embedding, padding_mask - ) - result["predictions"] = self.retrieval(last_item_embedding) - - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=False): - item_sequence_embedding = y_pred["item_sequence_embedding"] - y_positive_sequence = y["positive_sequence"] - y_negative_sequence = y["negative_sequence"] - - # Embed positive, negative sequences. 
- positive_sequence_embedding = self.item_embedding(y_positive_sequence) - negative_sequence_embedding = self.item_embedding(y_negative_sequence) - - # Logits - positive_logits = ops.sum( - ops.multiply(positive_sequence_embedding, item_sequence_embedding), - axis=-1, - ) - negative_logits = ops.sum( - ops.multiply(negative_sequence_embedding, item_sequence_embedding), - axis=-1, - ) - logits = ops.concatenate([positive_logits, negative_logits], axis=1) - - # Labels - labels = ops.concatenate( - [ - ops.ones_like(positive_logits), - ops.zeros_like(negative_logits), - ], - axis=1, - ) - - # sample weights - sample_weight = ops.concatenate( - [sample_weight, sample_weight], - axis=1, - ) - - loss = self.loss_fn( - y_true=ops.expand_dims(labels, axis=-1), - y_pred=ops.expand_dims(logits, axis=-1), - sample_weight=sample_weight, - ) - loss = ops.divide_no_nan(ops.sum(loss), ops.sum(sample_weight)) - - return loss - - def compute_output_shape(self, inputs_shape): - return list(inputs_shape) + [self.hidden_dim] - -``` - -Let's instantiate our model and do some sanity checks. - - -```python -model = SasRec( - vocabulary_size=movies_count + 1, - num_layers=NUM_LAYERS, - num_heads=NUM_HEADS, - hidden_dim=HIDDEN_DIM, - dropout=DROPOUT, - max_sequence_length=MAX_CONTEXT_LENGTH, -) - -# Training -output = model( - inputs={ - "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"), - "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"), - }, - training=True, -) -print(output["item_sequence_embedding"].shape) - -# Inference -output = model( - inputs={ - "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"), - "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"), - }, - training=False, -) -print(output["predictions"].shape) -``` - -
-``` -(2, 200, 50) - -(2, 10) - -``` -
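-To make the training objective in `compute_loss` concrete, here is a minimal
-NumPy sketch of the same computation, with toy shapes and random values (an
-illustrative re-implementation, not the model's actual code path): each
-timestep is scored against one positive item (the true next item) and one
-sampled negative item via dot products, and a binary cross-entropy over these
-logits is averaged over non-padding positions only.
-
-
-```python
-import numpy as np
-
-rng = np.random.default_rng(0)
-batch, seq, dim = 2, 4, 8
-seq_emb = rng.normal(size=(batch, seq, dim))  # transformer output
-pos_emb = rng.normal(size=(batch, seq, dim))  # embeddings of true next items
-neg_emb = rng.normal(size=(batch, seq, dim))  # embeddings of sampled negatives
-weights = np.array([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]])  # 0 = padding
-
-pos_logits = np.sum(pos_emb * seq_emb, axis=-1)  # (batch, seq)
-neg_logits = np.sum(neg_emb * seq_emb, axis=-1)  # (batch, seq)
-logits = np.concatenate([pos_logits, neg_logits], axis=1)
-labels = np.concatenate(
-    [np.ones_like(pos_logits), np.zeros_like(neg_logits)], axis=1
-)
-weights = np.concatenate([weights, weights], axis=1)
-
-# Numerically stable binary cross-entropy from logits, then a weighted mean
-# so that padded timesteps do not contribute to the loss.
-bce = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
-print(np.sum(bce * weights) / np.sum(weights))
-```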
-Now, let's compile and train our model. - - -```python -model.compile( - optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_2=0.98), -) -model.fit( - x=train_ds, - validation_data=val_ds, - epochs=NUM_EPOCHS, -) -``` - -
-```
-Epoch 1/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 13s 199ms/step - loss: 0.5908 - val_loss: 0.5149
-Epoch 2/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.4441 - val_loss: 0.5084
-Epoch 3/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4308 - val_loss: 0.4923
-Epoch 4/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4182 - val_loss: 0.4797
-Epoch 5/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4014 - val_loss: 0.4611
-Epoch 6/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3760 - val_loss: 0.4355
-Epoch 7/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3500 - val_loss: 0.4174
-Epoch 8/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3306 - val_loss: 0.4035
-Epoch 9/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3149 - val_loss: 0.3927
-Epoch 10/10
-48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3015 - val_loss: 0.3829
-```
-
----
-## Making predictions
-
-Now that we have a trained model, we would like to be able to make predictions.
-
-So far, we have only handled movies by ID. Now is the time to create a mapping
-from movie IDs to titles so that we can surface them.
-
-
-```python
-movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies the user has already watched. We could
-easily add logic to remove them if that is desirable (a quick sketch follows
-the output below).
-
-
-```python
-for ele in val_ds.unbatch().take(1):
-    test_sample = ele[0]
-    test_sample["item_ids"] = tf.expand_dims(test_sample["item_ids"], axis=0)
-    test_sample["padding_mask"] = tf.expand_dims(test_sample["padding_mask"], axis=0)
-
-movie_sequence = np.array(test_sample["item_ids"])[0]
-for movie_id in movie_sequence:
-    if movie_id == 0:
-        continue
-    print(movie_id_to_movie_title[movie_id], end="; ")
-print()
-
-predictions = model.predict(test_sample)["predictions"]
-predictions = keras.ops.convert_to_numpy(predictions)
-
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-``` -Girl, Interrupted (1999); Back to the Future (1985); Titanic (1997); Cinderella (1950); Meet Joe Black (1998); Last Days of Disco, The (1998); Erin Brockovich (2000); Christmas Story, A (1983); To Kill a Mockingbird (1962); One Flew Over the Cuckoo's Nest (1975); Wallace & Gromit: The Best of Aardman Animation (1996); Star Wars: Episode IV - A New Hope (1977); Wizard of Oz, The (1939); Fargo (1996); Run Lola Run (Lola rennt) (1998); Rain Man (1988); Saving Private Ryan (1998); Awakenings (1990); Gigi (1958); Sound of Music, The (1965); Driving Miss Daisy (1989); Bambi (1942); Apollo 13 (1995); Mary Poppins (1964); E.T. the Extra-Terrestrial (1982); My Fair Lady (1964); Ben-Hur (1959); Big (1988); Sixth Sense, The (1999); Dead Poets Society (1989); James and the Giant Peach (1996); Ferris Bueller's Day Off (1986); Secret Garden, The (1993); Toy Story 2 (1999); Airplane! (1980); Pleasantville (1998); Dumbo (1941); Princess Bride, The (1987); Snow White and the Seven Dwarfs (1937); Miracle on 34th Street (1947); Ponette (1996); Schindler's List (1993); Beauty and the Beast (1991); Tarzan (1999); Close Shave, A (1995); Aladdin (1992); Toy Story (1995); Bug's Life, A (1998); Antz (1998); Hunchback of Notre Dame, The (1996); Hercules (1997); Mulan (1998); Pocahontas (1995); - -``` -
-
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 791ms/step
-
-``` -Groundhog Day (1993) -Aladdin (1992) -Toy Story (1995) -Forrest Gump (1994) -Bug's Life, A (1998) -Lion King, The (1994) -Shakespeare in Love (1998) -American Beauty (1999) -Sixth Sense, The (1999) -Ghostbusters (1984) - -``` -
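-As mentioned above, the retrieved list can contain movies the user has already
-watched. Here is a quick, hypothetical post-filtering sketch (not part of the
-original example) that reuses `movie_sequence` and `predictions` from the cell
-above; note that to still surface ten movies after filtering, you would
-instantiate the retrieval layer with a larger `k`:
-
-
-```python
-# Drop retrieved movies that already appear in the input sequence.
-watched = {int(movie_id) for movie_id in movie_sequence if movie_id != 0}
-for movie_id in predictions[0]:
-    if int(movie_id) not in watched:
-        print(movie_id_to_movie_title[int(movie_id)])
-```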
-And that's all!
-
diff --git a/templates/examples/keras_rs/scann.md b/templates/examples/keras_rs/scann.md
deleted file mode 100644
index 473c125c41..0000000000
--- a/templates/examples/keras_rs/scann.md
+++ /dev/null
@@ -1,2161 +0,0 @@
-# Faster retrieval with Scalable Nearest Neighbors (ScaNN)
-
-**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)<br>
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Using ScaNN for faster retrieval.
-
-
ⓘ This example uses Keras 3
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/scann.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/scann.py)
-
-
-
----
-## Introduction
-
-Retrieval models are designed to quickly identify a small set of highly relevant
-candidates from vast pools of data, often comprising millions or even hundreds
-of millions of items. To effectively respond to the user's context and behavior
-in real time, these models must perform this task in just milliseconds.
-
-Approximate nearest neighbor (ANN) search is the key technology that enables
-this level of efficiency. In this tutorial, we'll demonstrate how to leverage
-ScaNN, a state-of-the-art nearest neighbor retrieval library, to effortlessly
-scale retrieval to millions of items.
-
-[ScaNN](https://research.google/blog/announcing-scann-efficient-vector-similarity-search/),
-developed by Google Research, is a high-performance library designed for
-dense vector similarity search at scale. It efficiently indexes a database of
-candidate embeddings, enabling rapid search during inference. By leveraging
-advanced vector compression techniques and finely tuned algorithms, ScaNN
-strikes an optimal balance between speed and accuracy. As a result, it can
-significantly outperform brute-force search methods, delivering fast retrieval
-with minimal loss in accuracy.
-
-We will start with the same code as the
-[basic retrieval example](/keras_rs/examples/basic_retrieval/).
-Data processing, model building, and training remain exactly the same. Feel free
-to skip this part if you have gone over the basic retrieval example before.
-
-Note: ScaNN does not have its own layer in KerasRS because the ScaNN library is
-TensorFlow-only. In this example, we use the ScaNN library directly and
-demonstrate its usage with KerasRS.
-
----
-## Imports
-
-Let's install the `scann` library and import all the necessary packages. We will
-also set the backend to JAX.
-
-
-```python
-# ruff: noqa: E402
-```
-
-
-```python
-!pip install -q scann
-```
-
-```
- ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.8/11.8 MB 16.4 MB/s eta 0:00:00
-```
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import time
-import uuid
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-from scann import scann_ops
-
-import keras_rs
-```
-
----
-## Preparing the dataset
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") - -# Get user and movie counts so that we can define embedding layers for both. -users_count = ( - ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) - .reduce(tf.constant(0, tf.int32), tf.maximum) - .numpy() -) - -movies_count = movies.cardinality().numpy() - - -# Preprocess the dataset, by selecting only the relevant columns. -def preprocess_rating(x): - return ( - # Input is the user IDs - tf.strings.to_number(x["user_id"], out_type=tf.int32), - # Labels are movie IDs + ratings between 0 and 1. - { - "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), - "rating": (x["user_rating"] - 1.0) / 4.0, - }, - ) - - -shuffled_ratings = ratings.map(preprocess_rating).shuffle( - 100_000, seed=42, reshuffle_each_iteration=False -) -# Train-test split. -train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() -test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() -``` - ---- -## Implementing the Model - - -```python - -class RetrievalModel(keras.Model): - def __init__( - self, - num_users, - num_candidates, - embedding_dimension=32, - **kwargs, - ): - super().__init__(**kwargs) - # Our query tower, simply an embedding table. - self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) - # Our candidate tower, simply an embedding table. - self.candidate_embedding = keras.layers.Embedding( - num_candidates, embedding_dimension - ) - - self.loss_fn = keras.losses.MeanSquaredError() - - def build(self, input_shape): - self.user_embedding.build(input_shape) - self.candidate_embedding.build(input_shape) - - super().build(input_shape) - - def call(self, inputs, training=False): - user_embeddings = self.user_embedding(inputs) - result = { - "user_embeddings": user_embeddings, - } - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - candidate_id, rating = y["movie_id"], y["rating"] - user_embeddings = y_pred["user_embeddings"] - candidate_embeddings = self.candidate_embedding(candidate_id) - - labels = keras.ops.expand_dims(rating, -1) - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.sum( - keras.ops.multiply(user_embeddings, candidate_embeddings), - axis=1, - keepdims=True, - ) - return self.loss_fn(labels, scores, sample_weight) - -``` - ---- -## Training the model - - -```python -model = RetrievalModel(users_count + 1000, movies_count + 1000) -model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1)) - -history = model.fit( - train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50 -) -``` - -
-```
-Epoch 1/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 0.4772
-Epoch 2/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
-Epoch 3/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
-Epoch 4/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4771
-Epoch 5/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4770 - val_loss: 0.4836
-Epoch 6/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
-Epoch 7/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
-Epoch 8/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769
-Epoch 9/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4769
-Epoch 10/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4836
-Epoch 11/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
-Epoch 12/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
-Epoch 13/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
-Epoch 14/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
-Epoch 15/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767 - val_loss: 0.4835
-Epoch 16/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
-Epoch 17/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766
-Epoch 18/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766
-Epoch 19/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765
-Epoch 20/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765 - val_loss: 0.4835
-Epoch 21/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4764
-Epoch 22/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4763
-Epoch 23/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4763
-Epoch 24/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4762
-Epoch 25/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4761 - val_loss: 0.4833
-Epoch 26/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4761
-Epoch 27/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4760
-Epoch 28/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4759
-Epoch 29/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4758
-Epoch 30/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4757 - val_loss: 0.4830
-Epoch 31/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4755
-Epoch 32/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4754
-Epoch 33/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4753
-Epoch 34/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4751
-Epoch 35/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4749 - val_loss: 0.4823
-Epoch 36/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4747
-Epoch 37/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4745
-Epoch 38/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4743
-Epoch 39/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4740
-Epoch 40/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4737 - val_loss: 0.4812
-Epoch 41/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4733
-Epoch 42/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4730
-Epoch 43/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4725
-Epoch 44/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4721
-Epoch 45/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4716 - val_loss: 0.4791
-Epoch 46/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4710
-Epoch 47/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4703
-Epoch 48/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4696
-Epoch 49/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4688
-Epoch 50/50
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4679 - val_loss: 0.4753
-```
-
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4679 - val_loss: 0.4753 - - ---- -## Making predictions - -Before we try out ScANN, let's go with the brute force method, i.e., for a given -user, scores are computed for all movies, sorted and then the top-k -movies are picked. This is, of course, not very scalable when we have a huge -number of movies. - - -```python -candidate_embeddings = keras.ops.array(model.candidate_embedding.embeddings.numpy()) -# Artificially duplicate candidate embeddings to simulate a large number of -# movies. -candidate_embeddings = keras.ops.concatenate( - [candidate_embeddings] - + [ - candidate_embeddings - * keras.random.uniform(keras.ops.shape(candidate_embeddings)) - for _ in range(100) - ], - axis=0, -) - -user_embedding = model.user_embedding(keras.ops.array([10, 5, 42, 345])) - -# Define the brute force retrieval layer. -brute_force_layer = keras_rs.layers.BruteForceRetrieval( - candidate_embeddings=candidate_embeddings, - k=10, - return_scores=False, -) -``` - -Now, let's do a forward pass on the layer. Note that in previous tutorials, we -have the above layer as an attribute of the model class, and we then call -`.predict()`. This will obviously be faster (since it's compiled XLA code), but -since we cannot do the same for ScANN, we just do a normal forward pass here -without compilation to ensure a fair comparison. - - -```python -t0 = time.time() -pred_movie_ids = brute_force_layer(user_embedding) -print("Time taken by brute force layer (sec):", time.time() - t0) -``` - -
-```
-Time taken by brute force layer (sec): 0.22817683219909668
-```
-
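-Under the hood, brute-force retrieval is conceptually just a dot-product
-scoring of the query against every candidate, followed by a top-k. As a
-minimal sketch (an illustration of the idea, not the actual
-`BruteForceRetrieval` internals), using the `user_embedding` and
-`candidate_embeddings` defined above:
-
-
-```python
-# Conceptual equivalent of brute-force retrieval: score all candidates
-# with a dot product, then keep the k highest-scoring ones.
-scores = keras.ops.matmul(
-    user_embedding, keras.ops.transpose(candidate_embeddings)
-)
-_, top_movie_ids = keras.ops.top_k(scores, k=10)
-```
-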
-Now, let's retrieve movies using ScaNN. We will use the ScaNN library from
-Google Research to build the layer and then call it. To fully understand all the
-arguments, please refer to the
-[ScaNN README file](https://github.com/google-research/google-research/tree/master/scann#readme).
-
-
-```python
-
-def build_scann(
-    candidates,
-    k=10,
-    distance_measure="dot_product",
-    dimensions_per_block=2,
-    num_reordering_candidates=500,
-    num_leaves=100,
-    num_leaves_to_search=30,
-    training_iterations=12,
-):
-    builder = scann_ops.builder(
-        db=candidates,
-        num_neighbors=k,
-        distance_measure=distance_measure,
-    )
-
-    builder = builder.tree(
-        num_leaves=num_leaves,
-        num_leaves_to_search=num_leaves_to_search,
-        training_iterations=training_iterations,
-    )
-    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
-
-    if num_reordering_candidates is not None:
-        builder = builder.reorder(num_reordering_candidates)
-
-    # Set a unique name to prevent unintentional sharing between
-    # ScaNN instances.
-    searcher = builder.build(shared_name=str(uuid.uuid4()))
-    return searcher
-
-
-def run_scann(searcher):
-    pred_movie_ids = searcher.search_batched_parallel(
-        user_embedding,
-        final_num_neighbors=10,
-    ).indices
-    return pred_movie_ids
-
-
-searcher = build_scann(candidates=candidate_embeddings)
-
-t0 = time.time()
-pred_movie_ids = run_scann(searcher)
-print("Time taken by ScaNN (sec):", time.time() - t0)
-```
-
-```
-Time taken by ScaNN (sec): 0.0032587051391601562
-```
-
-You can clearly see the performance improvement in terms of latency. ScaNN
-(0.003 seconds) takes about one-seventieth of the time the brute force layer
-(0.23 seconds) needs to run!
-
diff --git a/templates/examples/keras_rs/sequential_retrieval.md b/templates/examples/keras_rs/sequential_retrieval.md
deleted file mode 100644
index 5341b55d85..0000000000
--- a/templates/examples/keras_rs/sequential_retrieval.md
+++ /dev/null
@@ -1,2334 +0,0 @@
-# Sequential retrieval [GRU4Rec]
-
-**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)<br>
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Recommend movies using a GRU-based sequential retrieval model. - - -
ⓘ This example uses Keras 3
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/sequential_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/sequential_retrieval.py)
-
-
-
---
-## Introduction
-
-In this example, we are going to build a sequential retrieval model. Sequential
-recommendation is a popular task in which the model looks at a sequence of
-items that a user has interacted with previously and then predicts the next
-item. Here, the order of the items within each sequence matters, so we are
-going to use a recurrent neural network to model the sequential relationship.
-For more details, please refer to the
-[GRU4Rec](https://arxiv.org/abs/1511.06939) paper.
-
-Let's begin by choosing JAX as the backend we want to run on, and import all
-the necessary libraries.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import collections
-import random
-
-import keras
-import pandas as pd
-import tensorflow as tf  # Needed only for the dataset
-
-import keras_rs
-```
-
-Let's also define all important variables/hyperparameters below.
-
-
-```python
-DATA_DIR = "./raw/data/"
-
-# MovieLens-specific variables
-MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
-MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"
-
-RATINGS_FILE_NAME = "ratings.dat"
-MOVIES_FILE_NAME = "movies.dat"
-
-# Data processing args
-MAX_CONTEXT_LENGTH = 10
-MIN_SEQUENCE_LENGTH = 3
-TRAIN_DATA_FRACTION = 0.9
-
-RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
-MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
-MIN_RATING = 2
-
-# Training/model args
-BATCH_SIZE = 4096
-TEST_BATCH_SIZE = 2048
-EMBEDDING_DIM = 32
-NUM_EPOCHS = 5
-LEARNING_RATE = 0.05
-```
-
---
-## Dataset
-
-Next, we need to prepare our dataset. Like we did in the
-[basic retrieval](/keras_rs/examples/basic_retrieval/)
-example, we are going to use the MovieLens dataset.
-
-The dataset preparation step is fairly involved. The original ratings dataset
-contains `(user, movie ID, rating, timestamp)` tuples (among other columns,
-which are not important for this example). Since we are dealing with sequential
-retrieval, we need to create movie sequences for every user, where the sequences
-are ordered by timestamp.
-
-Let's start by downloading and reading the dataset.
-
-
-```python
-# Download the MovieLens dataset.
-if not os.path.exists(DATA_DIR):
-    os.makedirs(DATA_DIR)
-
-path_to_zip = keras.utils.get_file(
-    fname="ml-1m.zip",
-    origin=MOVIELENS_1M_URL,
-    file_hash=MOVIELENS_ZIP_HASH,
-    hash_algorithm="sha256",
-    extract=True,
-    cache_dir=DATA_DIR,
-)
-movielens_extracted_dir = os.path.join(
-    os.path.dirname(path_to_zip),
-    "ml-1m_extracted",
-    "ml-1m",
-)
-
-
-# Read the dataset.
-def read_data(data_directory, min_rating=None):
-    """Read the MovieLens ratings.dat and movies.dat files
-    into dataframes.
-    """
-
-    ratings_df = pd.read_csv(
-        os.path.join(data_directory, RATINGS_FILE_NAME),
-        sep="::",
-        names=RATINGS_DATA_COLUMNS,
-        encoding="unicode_escape",
-    )
-    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)
-
-    # Filter out ratings below `min_rating`.
-    if min_rating is not None:
-        ratings_df = ratings_df[ratings_df["Rating"] >= min_rating]
-
-    movies_df = pd.read_csv(
-        os.path.join(data_directory, MOVIES_FILE_NAME),
-        sep="::",
-        names=MOVIES_DATA_COLUMNS,
-        encoding="unicode_escape",
-    )
-    return ratings_df, movies_df
-
-
-ratings_df, movies_df = read_data(
-    data_directory=movielens_extracted_dir, min_rating=MIN_RATING
-)
-
-# We need the number of movies to define the embedding tables.
-movies_count = movies_df["MovieID"].max()
-```
-
-```
-Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
-5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
-```
-
-``` -:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - ratings_df = pd.read_csv( - -:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - movies_df = pd.read_csv( - -``` -
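-As an aside, the `ParserWarning` above is harmless: pandas falls back to the
-Python engine because `"::"` is a multi-character separator. If you prefer to
-silence it, you can pass `engine="python"` explicitly, as the warning itself
-suggests. A minimal sketch of this tweak (not part of the original example):
-
-
-```python
-# Passing engine="python" explicitly avoids the ParserWarning triggered
-# by the multi-character "::" separator.
-ratings_df = pd.read_csv(
-    os.path.join(movielens_extracted_dir, RATINGS_FILE_NAME),
-    sep="::",
-    names=RATINGS_DATA_COLUMNS,
-    encoding="unicode_escape",
-    engine="python",
-)
-```
-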
-Now that we have read the dataset, let's create sequences of movies
-for every user. Here is the function for doing just that.
-
-
-```python
-
-def get_movie_sequence_per_user(ratings_df):
-    """Get movie ID sequences for every user."""
-    sequences = collections.defaultdict(list)
-
-    for user_id, movie_id, rating, timestamp in ratings_df.values:
-        sequences[user_id].append(
-            {
-                "movie_id": movie_id,
-                "timestamp": timestamp,
-                "rating": rating,
-            }
-        )
-
-    # Sort movie sequences by timestamp for every user.
-    for user_id, context in sequences.items():
-        context.sort(key=lambda x: x["timestamp"])
-        sequences[user_id] = context
-
-    return sequences
-
-```
-
-We need to do some filtering and processing before we proceed
-with training the model:
-
-1. Form sequences of all lengths up to
-   `min(user_sequence_length, MAX_CONTEXT_LENGTH)`. So, every user
-   will have multiple sequences corresponding to it.
-2. Get labels, i.e., given a sequence of length `n`, the first
-   `n-1` tokens will be fed to the model as input, and the label
-   will be the last token.
-3. Remove all user sequences with fewer than `MIN_SEQUENCE_LENGTH`
-   movies.
-4. Pad all sequences to `MAX_CONTEXT_LENGTH`.
-
-
-```python
-
-def generate_examples_from_user_sequences(sequences):
-    """Generates sequences for all users, with padding, truncation, etc."""
-
-    def generate_examples_from_user_sequence(sequence):
-        """Generates examples for a single user sequence."""
-
-        examples = []
-        for label_idx in range(1, len(sequence)):
-            start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH)
-            context = sequence[start_idx:label_idx]
-
-            # Padding
-            while len(context) < MAX_CONTEXT_LENGTH:
-                context.append(
-                    {
-                        "movie_id": 0,
-                        "timestamp": 0,
-                        "rating": 0.0,
-                    }
-                )
-
-            label_movie_id = int(sequence[label_idx]["movie_id"])
-            context_movie_id = [int(movie["movie_id"]) for movie in context]
-
-            examples.append(
-                {
-                    "context_movie_id": context_movie_id,
-                    "label_movie_id": label_movie_id,
-                },
-            )
-        return examples
-
-    all_examples = []
-    for sequence in sequences.values():
-        if len(sequence) < MIN_SEQUENCE_LENGTH:
-            continue
-
-        user_examples = generate_examples_from_user_sequence(sequence)
-
-        all_examples.extend(user_examples)
-
-    return all_examples
-
-```
-
-Let's split the dataset into train and test sets. Also, we need to
-change the format of the dataset dictionary so as to enable conversion
-to a `tf.data.Dataset` object.
-
-
-```python
-sequences = get_movie_sequence_per_user(ratings_df)
-examples = generate_examples_from_user_sequences(sequences)
-
-# Train-test split.
-random.shuffle(examples)
-split_index = int(TRAIN_DATA_FRACTION * len(examples))
-train_examples = examples[:split_index]
-test_examples = examples[split_index:]
-
-
-def list_of_dicts_to_dict_of_lists(list_of_dicts):
-    """Convert a list of dictionaries to a dictionary of lists for
-    `tf.data` conversion.
-    """
-    dict_of_lists = collections.defaultdict(list)
-    for dictionary in list_of_dicts:
-        for key, value in dictionary.items():
-            dict_of_lists[key].append(value)
-    return dict_of_lists
-
-
-train_examples = list_of_dicts_to_dict_of_lists(train_examples)
-test_examples = list_of_dicts_to_dict_of_lists(test_examples)
-
-train_ds = tf.data.Dataset.from_tensor_slices(train_examples).map(
-    lambda x: (x["context_movie_id"], x["label_movie_id"])
-)
-test_ds = tf.data.Dataset.from_tensor_slices(test_examples).map(
-    lambda x: (x["context_movie_id"], x["label_movie_id"])
-)
-```
-
-We need to batch our datasets.
We also use `cache()` and `prefetch()`
-for better performance.
-
-
-```python
-train_ds = train_ds.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
-test_ds = test_ds.batch(TEST_BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
-```
-
-Let's print out one batch.
-
-
-```python
-for sample in train_ds.take(1):
-    print(sample)
-```
-
-```
-(<tf.Tensor: shape=(4096, 10), dtype=int32, numpy=...>, <tf.Tensor: shape=(4096,), dtype=int32, numpy=...>)
-```
-
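-Before defining the model, here is a quick illustration of the in-batch
-negatives trick used in the loss below (an illustrative sketch, not part of
-the original example): for a batch of `N` sequences, the label matrix is the
-`N x N` identity, so each query's own candidate is its positive and the other
-`N - 1` candidates in the batch act as negatives.
-
-
-```python
-# For a toy batch of 4 queries, the one-hot label matrix is the 4x4
-# identity: row i marks candidate i as the positive for query i.
-labels = keras.ops.eye(4, 4)
-print(labels)
-```
-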
---
-## Model and Training
-
-In the basic retrieval example, we used one query tower for the
-user, and one candidate tower for the candidate movie. We are
-going to use a two-tower architecture here as well. However,
-the query tower now uses a Gated Recurrent Unit (GRU) layer
-to encode the sequence of historical movies, while we keep the same
-candidate tower for the candidate movie.
-
-Note: Take a look at how the labels are defined. The label tensor
-(of shape `(batch_size, batch_size)`) contains one-hot vectors, as
-illustrated in the sketch above. The idea is: for every sample, consider
-movie IDs corresponding to other samples in the batch as negatives.
-
-
-```python
-
-class SequentialRetrievalModel(keras.Model):
-    """Create the sequential retrieval model.
-
-    Args:
-        movies_count: Total number of unique movies in the dataset.
-        embedding_dimension: Output dimension for movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        movies_count,
-        embedding_dimension=128,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table followed by
-        # a GRU unit. This encodes the sequence of historical movies.
-        self.query_model = keras.Sequential(
-            [
-                keras.layers.Embedding(movies_count + 1, embedding_dimension),
-                keras.layers.GRU(embedding_dimension),
-            ]
-        )
-
-        # Our candidate tower, simply an embedding table.
-        self.candidate_model = keras.layers.Embedding(
-            movies_count + 1, embedding_dimension
-        )
-
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
-        self.loss_fn = keras.losses.CategoricalCrossentropy(
-            from_logits=True,
-        )
-
-    def build(self, input_shape):
-        self.query_model.build(input_shape)
-        self.candidate_model.build(input_shape)
-
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_model.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        query_embeddings = self.query_model(inputs)
-        result = {
-            "query_embeddings": query_embeddings,
-        }
-
-        if not training:
-            # Only retrieve the top movies when not training; during
-            # training the predictions are not used.
-            result["predictions"] = self.retrieval(query_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id = y
-        query_embeddings = y_pred["query_embeddings"]
-        candidate_embeddings = self.candidate_model(candidate_id)
-
-        num_queries = keras.ops.shape(query_embeddings)[0]
-        num_candidates = keras.ops.shape(candidate_embeddings)[0]
-
-        # One-hot vectors for labels.
-        labels = keras.ops.eye(num_queries, num_candidates)
-
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.matmul(
-            query_embeddings, keras.ops.transpose(candidate_embeddings)
-        )
-
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
-Let's instantiate, compile and train our model.
-
-
-```python
-model = SequentialRetrievalModel(
-    movies_count=movies_count + 1, embedding_dimension=EMBEDDING_DIM
-)
-
-# Compile.
-model.compile(optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE))
-
-# Train.
-model.fit(
-    train_ds,
-    validation_data=test_ds,
-    epochs=NUM_EPOCHS,
-)
-```
-
-```
-Epoch 1/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 8s 28ms/step - loss: 7.3487 - val_loss: 5.9852
-Epoch 2/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 6.6328 - val_loss: 5.9231
-Epoch 3/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.5498 - val_loss: 5.9322
-Epoch 4/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.5158 - val_loss: 5.9527
-Epoch 5/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.4960 - val_loss: 5.9651
-```
-
---
-## Making predictions
-
-Now that we have a model, we would like to be able to make predictions.
-
-So far, we have only handled movies by ID. Now is the time to create a mapping
-from movie IDs to titles so that we can surface them.
-
-
-```python
-movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
-movie_id_to_movie_title[0] = ""  # Because ID 0 is used for padding, not a real movie.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies already watched by the user. We could
-easily add logic to remove them if that is desirable; a sketch of this is shown
-after the output below.
-
-
-```python
-print("\n==> Movies the user has watched:")
-movie_sequence = test_ds.unbatch().take(1)
-for element in movie_sequence:
-    for movie_id in element[0][:-1]:
-        print(movie_id_to_movie_title[movie_id.numpy()], end=", ")
-    print(movie_id_to_movie_title[element[0][-1].numpy()])
-
-predictions = model.predict(movie_sequence.batch(1))
-predictions = keras.ops.convert_to_numpy(predictions["predictions"])
-
-print("\n==> Recommended movies for the above sequence:")
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-``` -==> Movies the user has watched: -10 Things I Hate About You (1999), American Beauty (1999), Bachelor, The (1999), Austin Powers: The Spy Who Shagged Me (1999), Arachnophobia (1990), Big Daddy (1999), Bone Collector, The (1999), Bug's Life, A (1998), Bowfinger (1999), Dead Calm (1989) - -``` -
- -
-```
-1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 302ms/step
-```
-
-``` -==> Recommended movies for the above sequence: -Creepshow (1982) -Bringing Out the Dead (1999) -Civil Action, A (1998) -Doors, The (1991) -Cruel Intentions (1999) -Brokedown Palace (1999) -Dead Calm (1989) -Condorman (1981) -Clan of the Cave Bear, The (1986) -Clerks (1994) - -/usr/local/lib/python3.11/dist-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset. - self._interrupted_warning() - -``` -
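-As mentioned above, here is a minimal sketch of how already-watched movies
-could be filtered out of the recommendations. It is an illustration using the
-`predictions` and `movie_sequence` objects defined earlier, not part of the
-original example:
-
-
-```python
-# Collect the IDs of movies the user has already watched...
-watched_movie_ids = set()
-for element in movie_sequence:
-    watched_movie_ids.update(int(movie_id) for movie_id in element[0].numpy())
-
-# ...and drop them from the retrieved list before surfacing titles.
-print("\n==> Recommended movies, excluding already-watched ones:")
-for movie_id in predictions[0]:
-    if int(movie_id) not in watched_movie_ids:
-        print(movie_id_to_movie_title[movie_id])
-```
-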
diff --git a/templates/keras_rs/examples/basic_ranking.md b/templates/keras_rs/examples/basic_ranking.md deleted file mode 100644 index 87c557733b..0000000000 --- a/templates/keras_rs/examples/basic_ranking.md +++ /dev/null @@ -1,613 +0,0 @@ -# Recommending movies: ranking - -**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Rank movies using a two-tower model.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/basic_ranking.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/basic_ranking.py)
-
-
-
---
-## Introduction
-
-Recommender systems are often composed of two stages:
-
-1. The retrieval stage is responsible for selecting an initial set of hundreds
-   of candidates from all possible candidates. The main objective of this model
-   is to efficiently weed out all candidates that the user is not interested in.
-   Because the retrieval model may be dealing with millions of candidates, it
-   has to be computationally efficient.
-2. The ranking stage takes the outputs of the retrieval model and fine-tunes
-   them to select the best possible handful of recommendations. Its task is to
-   narrow down the set of items the user may be interested in to a shortlist of
-   likely candidates.
-
-In this tutorial, we're going to focus on the second stage, ranking. If you are
-interested in the retrieval stage, have a look at our
-[retrieval](/keras_rs/examples/basic_retrieval/)
-tutorial.
-
-In this tutorial, we're going to:
-
-1. Get our data and split it into a training and test set.
-2. Implement a ranking model.
-3. Fit and evaluate it.
-4. Test running predictions with the model.
-
-Let's begin by choosing JAX as the backend we want to run on, and import all
-the necessary libraries.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-```
-
---
-## Preparing the dataset
-
-We're going to use the same data as the
-[retrieval](/keras_rs/examples/basic_retrieval/)
-tutorial. The ratings are the objectives we are trying to predict.
-
-
-```python
-# Ratings data.
-ratings = tfds.load("movielens/100k-ratings", split="train")
-# Features of all the available movies.
-movies = tfds.load("movielens/100k-movies", split="train")
-```
-
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-
-Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-```
-
-In the MovieLens dataset, user IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
-user ID directly as an index in our model, in particular to look up the user
-embedding from the user embedding table. So we need to know the number of users.
-
-
-```python
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-In the MovieLens dataset, movie IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
-movie ID directly as an index in our model, in particular to look up the movie
-embedding from the movie embedding table. So we need to know the number of
-movies.
-
-
-```python
-movies_count = movies.cardinality().numpy()
-```
-
-The inputs to the model are the user IDs and movie IDs and the labels are the
-ratings.
-
-
-```python
-
-def preprocess_rating(x):
-    return (
-        # Inputs are user IDs and movie IDs
-        {
-            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-        },
-        # Labels are ratings between 0 and 1.
-        (x["user_rating"] - 1.0) / 4.0,
-    )
-
-```
-
-We'll split the data by putting 80% of the ratings in the train set, and 20% in
-the test set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
---
-## Implementing the Model
-
-### Architecture
-
-Ranking models do not face the same efficiency constraints as retrieval models
-do, and so we have a little bit more freedom in our choice of architectures.
-
-A model composed of multiple stacked dense layers is a relatively common
-architecture for ranking tasks. We can implement it as follows:
-
-
-```python
-
-class RankingModel(keras.Model):
-    """Create the ranking model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Embedding table for users.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Embedding table for candidates.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # Predictions.
-        self.ratings = keras.Sequential(
-            [
-                # Learn multiple dense layers.
-                keras.layers.Dense(256, activation="relu"),
-                keras.layers.Dense(64, activation="relu"),
-                # Make rating predictions in the final layer.
-                keras.layers.Dense(1),
-            ]
-        )
-
-    def call(self, inputs):
-        user_id, movie_id = inputs["user_id"], inputs["movie_id"]
-        user_embeddings = self.user_embedding(user_id)
-        candidate_embeddings = self.candidate_embedding(movie_id)
-        return self.ratings(
-            keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1)
-        )
-
-```
-
-Let's first instantiate the model. Note that we add `+ 1` to the number of users
-and movies to account for the fact that ID zero is not used for either (IDs
-start at 1), but still takes a row in the embedding tables.
-
-
-```python
-model = RankingModel(users_count + 1, movies_count + 1)
-```
-
-### Loss and metrics
-
-The next component is the loss used to train our model. Keras has several losses
-to make this easy. In this instance, we'll make use of the `MeanSquaredError`
-loss in order to predict the ratings. We'll also look at the
-`RootMeanSquaredError` metric.
-
-
-```python
-model.compile(
-    loss=keras.losses.MeanSquaredError(),
-    metrics=[keras.metrics.RootMeanSquaredError()],
-    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
-)
-```
-
---
-## Fitting and evaluating
-
-After defining the model, we can use the standard Keras `model.fit()` to train
-the model.
-
-
-```python
-model.fit(train_ratings, epochs=5)
-```
-
-```
-Epoch 1/5
-80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 10ms/step - loss: 0.1058 - root_mean_squared_error: 0.3200
-Epoch 2/5
-80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0774 - root_mean_squared_error: 0.2783
-Epoch 3/5
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0759 - root_mean_squared_error: 0.2754
-Epoch 4/5
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0735 - root_mean_squared_error: 0.2711
-Epoch 5/5
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0703 - root_mean_squared_error: 0.2651
-```
-
-As the model trains, the loss is falling and the RMSE metric is improving.
-
-Finally, we can evaluate our model on the test set. The lower the RMSE metric,
-the more accurate our model is at predicting ratings.
-
-
-```python
-model.evaluate(test_ratings, return_dict=True)
-```
-
- 20/20 ━━━━━━━━━━━━━━━━━━━━ 2s 12ms/step - loss: 0.0707 - root_mean_squared_error: 0.2658
-
-
-``` -{'loss': 0.0712985172867775, 'root_mean_squared_error': 0.26701781153678894} - -``` -
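-The RMSE above is on the normalized rating scale: the preprocessing mapped the
-1 to 5 star ratings to the [0, 1] range by subtracting 1 and dividing by 4.
-As a quick sketch (not part of the original tutorial), the error can be
-expressed back in stars:
-
-
-```python
-# The rescaling divided ratings by 4, so the error shrinks by 4 as well;
-# multiplying by 4 expresses the RMSE on the original 1-5 star scale.
-normalized_rmse = 0.2670  # Value reported by `model.evaluate()` above.
-print(f"RMSE in stars: {4.0 * normalized_rmse:.2f}")  # ~1.07 stars.
-```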
----
-## Testing the ranking model
-
-So far, we have only handled movies by ID. Now is the time to create a mapping
-from movie IDs to titles so that we can surface them.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-Now we can test the ranking model by computing predictions for a set of movies
-and then ranking these movies based on the predictions:
-
-
-```python
-user_id = 42
-movie_ids = [204, 141, 131]
-predictions = model.predict(
-    {
-        "user_id": keras.ops.array([user_id] * len(movie_ids)),
-        "movie_id": keras.ops.array(movie_ids),
-    }
-)
-predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1))
-
-for movie_id, prediction in zip(movie_ids, predictions):
-    print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")
-```
-
-
-``` -b'Back to the Future (1985)': 3.86 -b'20,000 Leagues Under the Sea (1954)': 3.93 -b"Breakfast at Tiffany's (1961)": 3.72 - -``` -
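-The same approach scales to the whole catalogue. Below is a minimal sketch
-(not part of the original tutorial; `all_movie_ids` is a helper introduced
-here) that scores every movie for one user and keeps the ten highest
-predictions:
-
-
-```python
-import numpy as np
-
-# Score every movie in the title mapping for user 42. This includes the
-# unused ID 0, which is harmless since it has an embedding row anyway.
-all_movie_ids = np.array(sorted(movie_id_to_movie_title.keys()), dtype="int32")
-scores = model.predict(
-    {
-        "user_id": np.full_like(all_movie_ids, 42),
-        "movie_id": all_movie_ids,
-    },
-    verbose=0,
-)
-scores = np.squeeze(scores, axis=1)
-
-# Sort by descending predicted rating and print the top ten titles.
-for idx in np.argsort(-scores)[:10]:
-    title = movie_id_to_movie_title[int(all_movie_ids[idx])]
-    print(f"{title}: {5.0 * scores[idx]:,.2f}")
-```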
\ No newline at end of file diff --git a/templates/keras_rs/examples/basic_retrieval.md b/templates/keras_rs/examples/basic_retrieval.md deleted file mode 100644 index 06eb46818c..0000000000 --- a/templates/keras_rs/examples/basic_retrieval.md +++ /dev/null @@ -1,2168 +0,0 @@ -# Recommending movies: retrieval - -**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Retrieve movies using a two-tower model.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/basic_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/basic_retrieval.py)
-
-
-
----
-## Introduction
-
-Recommender systems are often composed of two stages:
-
-1. The retrieval stage is responsible for selecting an initial set of hundreds
-   of candidates from all possible candidates. The main objective of this model
-   is to efficiently weed out all candidates that the user is not interested in.
-   Because the retrieval model may be dealing with millions of candidates, it
-   has to be computationally efficient.
-2. The ranking stage takes the outputs of the retrieval model and fine-tunes
-   them to select the best possible handful of recommendations. Its task is to
-   narrow down the set of items the user may be interested in to a shortlist of
-   likely candidates.
-
-In this tutorial, we're going to focus on the first stage, retrieval. If you are
-interested in the ranking stage, have a look at our
-[ranking](/keras_rs/examples/basic_ranking/) tutorial.
-
-Retrieval models are often composed of two sub-models:
-
-1. A query tower computing the query representation (normally a
-   fixed-dimensionality embedding vector) using query features.
-2. A candidate tower computing the candidate representation (an equally-sized
-   vector) using the candidate features. The outputs of the two models are then
-   multiplied together to give a query-candidate affinity score, with higher
-   scores expressing a better match between the candidate and the query.
-
-In this tutorial, we're going to build and train such a two-tower model using
-the Movielens dataset.
-
-We're going to:
-
-1. Get our data and split it into a training and test set.
-2. Implement a retrieval model.
-3. Fit and evaluate it.
-4. Test running predictions with the model.
-
-### The dataset
-
-The Movielens dataset is a classic dataset from the
-[GroupLens](https://grouplens.org/datasets/movielens/) research group at the
-University of Minnesota. It contains a set of ratings given to movies by a set
-of users, and is a standard for recommender systems research.
-
-The data can be treated in two ways:
-
-1. It can be interpreted as expressing which movies the users watched (and
-   rated), and which they did not. This is a form of implicit feedback, where
-   users' watches tell us which things they prefer to see and which they'd
-   rather not see.
-2. It can also be seen as expressing how much the users liked the movies they
-   did watch. This is a form of explicit feedback: given that a user watched a
-   movie, we can tell how much they liked it by looking at the rating they have
-   given.
-
-In this tutorial, we are focusing on a retrieval system: a model that predicts a
-set of movies from the catalogue that the user is likely to watch. For this, the
-model will try to predict the rating users would give to all the movies in the
-catalogue. We will therefore use the explicit rating data.
-
-Let's begin by choosing JAX as the backend we want to run on, and import all
-the necessary libraries.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-
-import keras_rs
-```
-
----
-## Preparing the dataset
-
-Let's first have a look at the data.
-
-We use the MovieLens dataset from
-[Tensorflow Datasets](https://www.tensorflow.org/datasets). Loading
-`movielens/100k-ratings` yields a `tf.data.Dataset` object containing the
-ratings alongside user and movie data. Loading `movielens/100k-movies` yields a
-`tf.data.Dataset` object containing only the movies data.
-
-Note that since the MovieLens dataset does not have predefined splits, all the
-data is under the `train` split.
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train")
-# Features of all the available movies.
-movies = tfds.load("movielens/100k-movies", split="train")
-```
-
-The ratings dataset returns a dictionary of movie id, user id, the assigned
-rating, timestamp, movie information, and user information:
-
-
-```python
-for data in ratings.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-
-``` -{'bucketized_user_age': np.float32(45.0), - 'movie_genres': array([7]), - 'movie_id': b'357', - 'movie_title': b"One Flew Over the Cuckoo's Nest (1975)", - 'raw_user_age': np.float32(46.0), - 'timestamp': np.int64(879024327), - 'user_gender': np.True_, - 'user_id': b'138', - 'user_occupation_label': np.int64(4), - 'user_occupation_text': b'doctor', - 'user_rating': np.float32(4.0), - 'user_zip_code': b'53211'} - -``` -
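-The next few cells rely on user and movie IDs being contiguous integers that
-start at 1. This can be spot-checked with a small sketch (not part of the
-original tutorial, and a full scan of the dataset, so it takes a few seconds):
-
-
-```python
-# Collect the distinct user IDs and verify they form the range 1..N.
-unique_user_ids = sorted({int(x["user_id"]) for x in ratings.as_numpy_iterator()})
-assert unique_user_ids == list(range(1, len(unique_user_ids) + 1))
-print(f"{len(unique_user_ids)} users, IDs contiguous from 1.")
-```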
-In the Movielens dataset, user IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map user IDs to integers from 0 to N-1. But as a simplification, we'll use the
-user id directly as an index in our model, in particular to look up the user
-embedding from the user embedding table. So we need to know the number of users.
-
-
-```python
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-The movies dataset contains the movie id, movie title, and the genres it belongs
-to. Note that the genres are encoded with integer labels.
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-
-``` -{'movie_genres': array([4]), - 'movie_id': b'1681', - 'movie_title': b'You So Crazy (1994)'} - -``` -
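-The integer genre labels can be decoded using the dataset metadata. A small
-sketch, assuming the standard `tfds` `with_info=True` API (this reload is not
-part of the original tutorial):
-
-
-```python
-# Reload the movies split together with its metadata, then map the integer
-# genre labels of the first movie back to human-readable names.
-movies_with_info, movies_info = tfds.load(
-    "movielens/100k-movies", split="train", with_info=True
-)
-genre_names = movies_info.features["movie_genres"].feature.names
-for data in movies_with_info.take(1).as_numpy_iterator():
-    print([genre_names[genre] for genre in data["movie_genres"]])
-```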
-In the Movielens dataset, movie IDs are integers (represented as strings)
-starting at 1 and with no gap. Normally, you would need to create a lookup table
-to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use the
-movie id directly as an index in our model, in particular to look up the movie
-embedding from the movie embedding table. So we need to know the number of
-movies.
-
-
-```python
-movies_count = movies.cardinality().numpy()
-```
-
-In this example, we're going to focus on the ratings data. Other tutorials
-explore how to use the movie information data as well as the user information to
-improve the model quality.
-
-We keep only the `user_id`, `movie_id` and `rating` fields in the dataset. Our
-input is the `user_id`. The labels are the `movie_id` alongside the `rating` for
-the given movie and user.
-
-The `rating` is a number between 1 and 5; we rescale it to be between 0 and 1.
-
-
-```python
-
-def preprocess_rating(x):
-    return (
-        # Input is the user IDs
-        tf.strings.to_number(x["user_id"], out_type=tf.int32),
-        # Labels are movie IDs + ratings between 0 and 1.
-        {
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-            "rating": (x["user_rating"] - 1.0) / 4.0,
-        },
-    )
-
-```
-
-To fit and evaluate the model, we need to split it into a training and
-evaluation set. In a real recommender system, this would most likely be done by
-time: the data up to time *T* would be used to predict interactions after *T*.
-
-In this simple example, however, let's use a random split, putting 80% of the
-ratings in the train set, and 20% in the test set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
----
-## Implementing the Model
-
-Choosing the architecture of our model is a key part of modelling.
-
-We are building a two-tower retrieval model, therefore we need to combine a
-query tower for users and a candidate tower for movies.
-
-The first step is to decide on the dimensionality of the query and candidate
-representations. This is the `embedding_dimension` argument in our model
-constructor. We'll test with a value of `32`. Higher values will correspond to
-models that may be more accurate, but will also be slower to fit and more prone
-to overfitting.
-
-### Query and Candidate Towers
-
-The second step is to define the model itself. In this simple example, the query
-tower and candidate tower are simply embeddings with nothing else. We'll use
-Keras' `Embedding` layer.
-
-We can easily extend the towers to make them arbitrarily complex using standard
-Keras components, as long as we return an `embedding_dimension`-wide output at
-the end.
-
-### Retrieval
-
-The retrieval itself will be performed by the `BruteForceRetrieval` layer from
-Keras Recommenders. This layer computes the affinity scores for the given users
-and all the candidate movies, then returns the top K in order.
-
-Note that during training, we don't actually need to perform any retrieval since
-the only affinity scores we need are the ones for the users and movies in the
-batch. As an optimization, we skip the retrieval entirely in the `call` method.
-
-### Loss
-
-The next component is the loss used to train our model. In this case, we use a
-mean squared error loss to measure the difference between the predicted movie
-ratings and the actual ratings from users.
-
-Note that we override `compute_loss` from the `keras.Model` class. This allows
-us to compute the query-candidate affinity score, which is obtained by
-multiplying the outputs of the two towers together. That affinity score can then
-be passed to the loss function.
-
-
-```python
-
-class RetrievalModel(keras.Model):
-    """Create the retrieval model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Our candidate tower, simply an embedding table.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
-        self.loss_fn = keras.losses.MeanSquaredError()
-
-    def build(self, input_shape):
-        self.user_embedding.build(input_shape)
-        self.candidate_embedding.build(input_shape)
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        user_embeddings = self.user_embedding(inputs)
-        result = {
-            "user_embeddings": user_embeddings,
-        }
-        if not training:
-            # Skip the retrieval of top movies during training as the
-            # predictions are not used.
-            result["predictions"] = self.retrieval(user_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id, rating = y["movie_id"], y["rating"]
-        user_embeddings = y_pred["user_embeddings"]
-        candidate_embeddings = self.candidate_embedding(candidate_id)
-
-        labels = keras.ops.expand_dims(rating, -1)
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.sum(
-            keras.ops.multiply(user_embeddings, candidate_embeddings),
-            axis=1,
-            keepdims=True,
-        )
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
----
-## Fitting and evaluating
-
-After defining the model, we can use the standard Keras `model.fit()` to train
-and evaluate the model.
-
-Let's first instantiate the model. Note that we add `+ 1` to the number of users
-and movies to account for the fact that id zero is not used for either (IDs
-start at 1), but still takes a row in the embedding tables.
-
-
-```python
-model = RetrievalModel(users_count + 1, movies_count + 1)
-model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))
-```
-
-Then train the model. Evaluation takes a bit of time, so we only evaluate the
-model every 5 epochs.
-
-
-```python
-history = model.fit(
-    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
-)
-```
-
-
-```
-Epoch 1/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 0.4772
-Epoch 5/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4771 - val_loss: 0.4836
-Epoch 10/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4836
-Epoch 15/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767 - val_loss: 0.4835
-Epoch 20/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765 - val_loss: 0.4834
-Epoch 25/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4762 - val_loss: 0.4832
-Epoch 30/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4756 - val_loss: 0.4828
-Epoch 35/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4748 - val_loss: 0.4821
-Epoch 40/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4734 - val_loss: 0.4807
-Epoch 45/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4709 - val_loss: 0.4783
-Epoch 50/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4667 - val_loss: 0.4739
-
-```
-
-
----
-## Making predictions
-
-Now that we have a model, we would like to be able to make predictions.
-
-So far, we have only handled movies by ID. Now is the time to create a mapping
-from movie IDs to titles so that we can surface them.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies already watched by the user. We could
-easily add logic to remove them if that is desirable; a sketch of such filtering
-follows the output below.
-
-
-```python
-user_id = 42
-predictions = model.predict(keras.ops.convert_to_tensor([user_id]))
-predictions = keras.ops.convert_to_numpy(predictions["predictions"])
-
-print(f"Recommended movies for user {user_id}:")
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-
-``` -Recommended movies for user 42: -b'Raiders of the Lost Ark (1981)' -b'Godfather, The (1972)' -b'Star Trek: The Wrath of Khan (1982)' -b'Indiana Jones and the Last Crusade (1989)' -b'Birdcage, The (1996)' -b'Silence of the Lambs, The (1991)' -b'Blade Runner (1982)' -b'Aliens (1986)' -b'Contact (1997)' -b'Star Wars (1977)' - -``` -
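-As mentioned above, the retrieved list can contain movies the user has already
-watched. A minimal sketch (not part of the original tutorial) of one way to
-filter them out, using the raw `ratings` dataset:
-
-
-```python
-# Build the set of movies user 42 has already rated (i.e. watched), then
-# keep only the recommendations that are not in that set.
-watched = {
-    int(x["movie_id"])
-    for x in ratings.as_numpy_iterator()
-    if int(x["user_id"]) == user_id
-}
-print(f"Unwatched recommendations for user {user_id}:")
-for movie_id in predictions[0]:
-    if int(movie_id) not in watched:
-        print(movie_id_to_movie_title[int(movie_id)])
-```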
---- -## Item-to-item recommendation - -In this model, we created a user-movie model. However, for some applications -(for example, product detail pages) it's common to perform item-to-item (for -example, movie-to-movie or product-to-product) recommendations. - -Training models like this would follow the same pattern as shown in this -tutorial, but with different training data. Here, we had a user and a movie -tower, and used (user, movie) pairs to train them. In an item-to-item model, we -would have two item towers (for the query and candidate item), and train the -model using (query item, candidate item) pairs. These could be constructed from -clicks on product detail pages. diff --git a/templates/keras_rs/examples/data_parallel_retrieval.md b/templates/keras_rs/examples/data_parallel_retrieval.md deleted file mode 100644 index f7fe3a7df7..0000000000 --- a/templates/keras_rs/examples/data_parallel_retrieval.md +++ /dev/null @@ -1,4220 +0,0 @@ -# Retrieval with data parallel training - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Retrieve movies using a two-tower model (data parallel training).
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/data_parallel_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/data_parallel_retrieval.py)
-
-
-
----
-## Introduction
-
-In this tutorial, we are going to train the exact same retrieval model as we
-did in our
-[basic retrieval](/keras_rs/examples/basic_retrieval/)
-tutorial, but in a distributed way.
-
-Distributed training is used to train models on multiple devices or machines
-simultaneously, thereby reducing training time. Here, we focus on synchronous
-data parallel training. Each accelerator (GPU/TPU) holds a complete replica
-of the model, and sees a different mini-batch of the input data. Local gradients
-are computed on each device, then aggregated and used to compute a global
-gradient update.
-
-Before we begin, let's note down a few things:
-
-1. The number of accelerators should be greater than 1.
-2. The `keras.distribution` API works only with JAX. So, make sure you select
-   JAX as your backend!
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"
-
-import random
-
-import jax
-import keras
-import tensorflow as tf  # Needed only for the dataset
-import tensorflow_datasets as tfds
-
-import keras_rs
-```
-
----
-## Data Parallel
-
-For the synchronous data parallelism strategy in distributed training,
-we will use the `DataParallel` class present in the `keras.distribution`
-API.
-
-
-```python
-devices = jax.devices()  # Assumes there is more than one local device.
-data_parallel = keras.distribution.DataParallel(devices=devices)
-```
-
-Alternatively, you can choose to create the `DataParallel` object
-using a 1D `DeviceMesh` object, like so:
-
-```
-mesh_1d = keras.distribution.DeviceMesh(
-    shape=(len(devices),), axis_names=["data"], devices=devices
-)
-data_parallel = keras.distribution.DataParallel(device_mesh=mesh_1d)
-```
-
-
-```python
-# Set the global distribution strategy.
-keras.distribution.set_distribution(data_parallel)
-```
-
----
-## Preparing the dataset
-
-Now that we are done defining the global distribution
-strategy, the rest of the guide looks exactly the same
-as the previous basic retrieval guide.
-
-Let's load and prepare the dataset. Here too, we use the
-MovieLens dataset.
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train")
-# Features of all the available movies.
-movies = tfds.load("movielens/100k-movies", split="train")
-
-# User, movie counts for defining vocabularies.
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-movies_count = movies.cardinality().numpy()
-
-
-# Preprocess dataset, and split it into train-test datasets.
-def preprocess_rating(x):
-    return (
-        # Input is the user IDs
-        tf.strings.to_number(x["user_id"], out_type=tf.int32),
-        # Labels are movie IDs + ratings between 0 and 1.
-        {
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-            "rating": (x["user_rating"] - 1.0) / 4.0,
-        },
-    )
-
-
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
-
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-
-Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-
-Dl Completed...: 0 url [00:00, ? url/s]
-
-Dl Size...: 0 MiB [00:00, ? MiB/s]
-
-Extraction completed...: 0 file [00:00, ? file/s]
-
-Generating splits...: 0%|          | 0/1 [00:00
-
-```
-
-
-
----
-## Implementing the Model
-
-We build a two-tower retrieval model. Therefore, we need to combine a
-query tower for users and a candidate tower for movies. Note that we don't
-have to change anything here from the previous basic retrieval tutorial.
-
-
-```python
-
-class RetrievalModel(keras.Model):
-    """Create the retrieval model with the provided parameters.
-
-    Args:
-        num_users: Number of entries in the user embedding table.
-        num_candidates: Number of entries in the candidate embedding table.
-        embedding_dimension: Output dimension for user and movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        num_users,
-        num_candidates,
-        embedding_dimension=32,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table.
-        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
-        # Our candidate tower, simply an embedding table.
-        self.candidate_embedding = keras.layers.Embedding(
-            num_candidates, embedding_dimension
-        )
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
-        self.loss_fn = keras.losses.MeanSquaredError()
-
-    def build(self, input_shape):
-        self.user_embedding.build(input_shape)
-        self.candidate_embedding.build(input_shape)
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        user_embeddings = self.user_embedding(inputs)
-        result = {
-            "user_embeddings": user_embeddings,
-        }
-        if not training:
-            # Skip the retrieval of top movies during training as the
-            # predictions are not used.
-            result["predictions"] = self.retrieval(user_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id, rating = y["movie_id"], y["rating"]
-        user_embeddings = y_pred["user_embeddings"]
-        candidate_embeddings = self.candidate_embedding(candidate_id)
-
-        labels = keras.ops.expand_dims(rating, -1)
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.sum(
-            keras.ops.multiply(user_embeddings, candidate_embeddings),
-            axis=1,
-            keepdims=True,
-        )
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
----
-## Fitting and evaluating
-
-After defining the model, we can use the standard Keras `model.fit()` to train
-and evaluate the model.
-
-
-```python
-model = RetrievalModel(users_count + 1, movies_count + 1)
-model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.2))
-```
-
-Let's train the model. Evaluation takes a bit of time, so we only evaluate the
-model every 5 epochs.
-
-
-```python
-history = model.fit(
-    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
-)
-```
-
-
-```
-Epoch 1/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 8ms/step - loss: 0.4773
-Epoch 5/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4770 - val_loss: 0.4835
-Epoch 10/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4765 - val_loss: 0.4832
-Epoch 15/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4754 - val_loss: 0.4824
-Epoch 18/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4741
-Epoch 19/50
-  1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4444
-
-```
-
-
-
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4690 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4718 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4725 - -
-``` - -``` -
- 35/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4726 - -
-``` - -``` -
- 44/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4728 - -
-``` - -``` -
- 52/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4730 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4731 - -
-``` - -``` -
- 70/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4733 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4734 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4734 - - -
-``` -Epoch 20/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.4437 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4673 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4707 - -
-``` - -``` -
- 25/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4716 - -
-``` - -``` -
- 34/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4718 - -
-``` - -``` -
- 43/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4720 - -
-``` - -``` -
- 51/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4722 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4723 - -
-``` - -``` -
- 69/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4725 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4726 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4726 - val_loss: 0.4795 - - -
-``` -Epoch 21/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.4427 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4673 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4701 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4707 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4709 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4711 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4712 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4714 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4715 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4716 - - -
-``` -Epoch 22/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4416 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4652 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4685 - -
-``` - -``` -
- 25/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4693 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4696 - -
-``` - -``` -
- 42/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4697 - -
-``` - -``` -
- 50/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4699 - -
-``` - -``` -
- 59/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4700 - -
-``` - -``` -
- 67/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4701 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4703 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4703 - - -
-``` -Epoch 23/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4401 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4636 - -
-``` - -``` -
- 18/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4672 - -
-``` - -``` -
- 27/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4679 - -
-``` - -``` -
- 36/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4681 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4683 - -
-``` - -``` -
- 53/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4684 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4685 - -
-``` - -``` -
- 70/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4686 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4687 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4687 - - -
-``` -Epoch 24/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4383 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4618 - -
-``` - -``` -
- 18/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4653 - -
-``` - -``` -
- 27/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4660 - -
-``` - -``` -
- 35/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4661 - -
-``` - -``` -
- 44/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4663 - -
-``` - -``` -
- 53/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4664 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4665 - -
-``` - -``` -
- 71/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4666 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4667 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4667 - - -
-``` -Epoch 25/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4361 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4603 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4631 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4637 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4638 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4639 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4640 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4641 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4642 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4642 - val_loss: 0.4701 - - -
-``` -Epoch 26/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.4333 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4574 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4601 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4607 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4608 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4610 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4610 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4611 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4612 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4612 - - -
-``` -Epoch 27/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4299 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4538 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4565 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4571 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4572 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4573 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4573 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4574 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4574 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4574 - - -
-``` -Epoch 28/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4256 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4485 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4517 - -
-``` - -``` -
- 26/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4525 - -
-``` - -``` -
- 34/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4526 - -
-``` - -``` -
- 43/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4527 - -
-``` - -``` -
- 49/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.4527 - -
-``` - -``` -
- 50/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.4527 - -
-``` - -``` -
- 59/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.4527 - -
-``` - -``` -
- 68/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4527 - -
-``` - -``` -
- 77/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4527 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4527 - - -
-``` -Epoch 29/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.4204 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4440 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4466 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4471 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4471 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4472 - -
-``` - -``` -
- 54/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4471 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4471 - -
-``` - -``` -
- 72/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4471 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.4470 - - -
-``` -Epoch 30/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.4141 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4374 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4399 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4404 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4404 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4404 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4403 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4402 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4402 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.4401 - val_loss: 0.4427 - - -
-``` -Epoch 31/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.4064 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4295 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4319 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4323 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4323 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4322 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4321 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4320 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4319 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4317 - - -
-``` -Epoch 32/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.3973 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4200 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4223 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4227 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4226 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4225 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4224 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4222 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4220 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4218 - - -
-``` -Epoch 33/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.3866 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4089 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4111 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4114 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4113 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4111 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4109 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4107 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4104 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.4102 - - -
-``` -Epoch 34/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.3742 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3960 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3981 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3984 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3982 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3979 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3977 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3974 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3971 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3968 - - -
-``` -Epoch 35/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.3601 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3813 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3834 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3836 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3833 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3830 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3827 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3823 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3820 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.3817 - val_loss: 0.3787 - - -
-``` -Epoch 36/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.3443 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3651 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3670 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3671 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3668 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3665 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3661 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3657 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3653 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3649 - - -
-``` -Epoch 37/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.3273 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3475 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3493 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3494 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3490 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3487 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3482 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3478 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3473 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3469 - - -
-``` -Epoch 38/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.3093 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.3282 - -
-``` - -``` -
- 18/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.3305 - -
-``` - -``` -
- 27/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.3306 - -
-``` - -``` -
- 36/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3303 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3299 - -
-``` - -``` -
- 54/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3294 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3289 - -
-``` - -``` -
- 72/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3285 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.3280 - - -
-``` -Epoch 39/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 19ms/step - loss: 0.2907 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3098 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3114 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3114 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3111 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3106 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3101 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3096 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3091 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.3087 - - -
-``` -Epoch 40/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.2722 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2907 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2923 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2923 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2919 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2915 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2910 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2905 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2900 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2896 - val_loss: 0.2856 - - -
-``` -Epoch 41/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.2542 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2722 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2737 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2737 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2734 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2729 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2725 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2720 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2715 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2711 - - -
-``` -Epoch 42/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.2372 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2547 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2561 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2562 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2558 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2554 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2550 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2545 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2540 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2537 - - -
-``` -Epoch 43/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.2215 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2384 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2399 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2399 - -
-``` - -``` -
- 36/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2396 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2392 - -
-``` - -``` -
- 54/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2388 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2384 - -
-``` - -``` -
- 72/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2380 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.2376 - - -
-``` -Epoch 44/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.2072 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2236 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2250 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2251 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2248 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2244 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2240 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2237 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2233 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2230 - - -
-``` -Epoch 45/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.1944 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2103 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2116 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2117 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2115 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2111 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2108 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2104 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.2101 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.2098 - val_loss: 0.2106 - - -
-``` -Epoch 46/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.1831 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1984 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1997 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1998 - -
-``` - -``` -
- 36/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1995 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1993 - -
-``` - -``` -
- 54/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1990 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1987 - -
-``` - -``` -
- 71/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1983 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1981 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1980 - - -
-``` -Epoch 47/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.1730 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1877 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1890 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1891 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1888 - -
-``` - -``` -
- 44/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1886 - -
-``` - -``` -
- 53/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1884 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1881 - -
-``` - -``` -
- 71/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1878 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1875 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1875 - - -
-``` -Epoch 48/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.1641 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1782 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1794 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1795 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1793 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1791 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1788 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1786 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1783 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1781 - - -
-``` -Epoch 49/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 21ms/step - loss: 0.1562 - -
-``` - -``` -
- 9/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1693 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1707 - -
-``` - -``` -
- 25/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1709 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1708 - -
-``` - -``` -
- 41/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1706 - -
-``` - -``` -
- 49/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1704 - -
-``` - -``` -
- 58/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1702 - -
-``` - -``` -
- 67/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1700 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1697 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 0.1696 - - -
-``` -Epoch 50/50 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 1s 20ms/step - loss: 0.1492 - -
-``` - -``` -
- 10/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1620 - -
-``` - -``` -
- 19/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1631 - -
-``` - -``` -
- 28/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1631 - -
-``` - -``` -
- 37/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1630 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1628 - -
-``` - -``` -
- 55/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1626 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1624 - -
-``` - -``` -
- 73/80 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 0.1622 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 8ms/step - loss: 0.1620 - val_loss: 0.1660 - - ---- -## Making predictions - -Now that we have a model, let's run inference and make predictions. - - -```python -movie_id_to_movie_title = { - int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator() -} -movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset. -``` - -We then simply use the Keras `model.predict()` method. Under the hood, it calls -the `BruteForceRetrieval` layer to perform the actual retrieval. - - -```python -user_ids = random.sample(range(1, 1001), len(devices)) -predictions = model.predict(keras.ops.convert_to_tensor(user_ids)) -predictions = keras.ops.convert_to_numpy(predictions["predictions"]) - -for i, user_id in enumerate(user_ids): - print(f"\n==Recommended movies for user {user_id}==") - for movie_id in predictions[i]: - print(movie_id_to_movie_title[movie_id]) -``` - - - 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 204ms/step - -
-``` - -``` -
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 205ms/step - - - -
-``` -==Recommended movies for user 449== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Silence of the Lambs, The (1991)' -b'Shawshank Redemption, The (1994)' -b'Pulp Fiction (1994)' -b'Raiders of the Lost Ark (1981)' -b"Schindler's List (1993)" -b'Blade Runner (1982)' -b"One Flew Over the Cuckoo's Nest (1975)" -b'Casablanca (1942)' -``` -
- -
-``` -==Recommended movies for user 681== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 151== -b'Princess Bride, The (1987)' -b'Pulp Fiction (1994)' -b'English Patient, The (1996)' -b'Alien (1979)' -b'Raiders of the Lost Ark (1981)' -b'Willy Wonka and the Chocolate Factory (1971)' -b'Amadeus (1984)' -b'Liar Liar (1997)' -b'Psycho (1960)' -b"It's a Wonderful Life (1946)" -``` -
- -
-``` -==Recommended movies for user 442== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b'Empire Strikes Back, The (1980)' -b"Schindler's List (1993)" -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 134== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Silence of the Lambs, The (1991)' -b'Raiders of the Lost Ark (1981)' -b'Pulp Fiction (1994)' -b'Return of the Jedi (1983)' -b'Empire Strikes Back, The (1980)' -b'Twelve Monkeys (1995)' -b'Contact (1997)' -``` -
- -
-``` -==Recommended movies for user 853== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Raiders of the Lost Ark (1981)' -b'Silence of the Lambs, The (1991)' -b'Return of the Jedi (1983)' -b'Pulp Fiction (1994)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Shawshank Redemption, The (1994)' -``` -
- -
-``` -==Recommended movies for user 707== -b'Star Wars (1977)' -b'Raiders of the Lost Ark (1981)' -b'Toy Story (1995)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Return of the Jedi (1983)' -b'Terminator, The (1984)' -b'Princess Bride, The (1987)' -``` -
- -
-``` -==Recommended movies for user 511== -b'Star Wars (1977)' -b'Fargo (1996)' -b'Godfather, The (1972)' -b'Raiders of the Lost Ark (1981)' -b'Silence of the Lambs, The (1991)' -b'Return of the Jedi (1983)' -b"Schindler's List (1993)" -b'Empire Strikes Back, The (1980)' -b'Pulp Fiction (1994)' -b'Shawshank Redemption, The (1994)' - -``` -
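-As a reminder, enabling data parallelism only required telling Keras about the
-available devices before building the model. A minimal sketch of that setup,
-assuming the JAX backend (the actual cell appears earlier in this guide):
-
-
-```python
-import jax
-import keras
-
-# Replicate the model on every available accelerator and shard each batch
-# across them.
-devices = jax.devices()
-keras.distribution.set_distribution(keras.distribution.DataParallel(devices=devices))
-```
-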
-And we're done! For data parallel training, all we had to do was add ~3-5 LoC. -The rest is exactly the same. diff --git a/templates/keras_rs/examples/dcn.md b/templates/keras_rs/examples/dcn.md deleted file mode 100644 index 3999ef35db..0000000000 --- a/templates/keras_rs/examples/dcn.md +++ /dev/null @@ -1,676 +0,0 @@ -# Ranking with Deep and Cross Networks - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Rank movies using Deep and Cross Networks (DCN). - - - [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/dcn.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/dcn.py) - - - ---- -## Introduction - -This tutorial demonstrates how to use Deep & Cross Networks (DCN) to effectively -learn feature crosses. Before diving into the example, let's briefly discuss -feature crosses. - -Imagine that we are building a recommender system for blenders. Individual -features might include a customer's past purchase history (e.g., -`purchased_bananas`, `purchased_cooking_books`) or geographic location. However, -a customer who has purchased both bananas and cooking books is more likely to be -interested in a blender than someone who purchased only one or the other. The -combination of `purchased_bananas` and `purchased_cooking_books` is a feature -cross. Feature crosses capture interaction information between individual -features, providing richer context than the individual features alone. - -![Why are feature crosses important?](https://i.imgur.com/qDK6UZh.gif) - -Learning effective feature crosses presents several challenges. In web-scale -applications, data is often categorical, resulting in high-dimensional and -sparse feature spaces. Identifying impactful feature crosses in such -environments typically relies on manual feature engineering or computationally -expensive exhaustive searches. While traditional feed-forward multilayer -perceptrons (MLPs) are universal function approximators, they often struggle to -efficiently learn even second- or third-order feature interactions. - -The Deep & Cross Network (DCN) architecture is designed for more effective -learning of explicit and bounded-degree feature crosses. It comprises three main -components: an input layer (typically an embedding layer), a cross network for -modeling explicit feature interactions, and a deep network for capturing -implicit interactions. - -The cross network is the core of the DCN. It explicitly performs feature -crossing at each layer, with the highest polynomial degree of feature -interaction increasing with depth. The following figure shows the `(i+1)`-th -cross layer. - -![Feature Cross Layer](https://i.imgur.com/ip5uRsl.png) - -The deep network is a standard feedforward multilayer perceptron -(MLP). These two networks are then combined to form the DCN. Two common -combination strategies exist: a stacked structure, where the deep network is -placed on top of the cross network, and a parallel structure, where they -operate in parallel. - - - - - - -
-*(Two figures appear here in the rendered page, illustrating the two
-combination strategies: "Parallel layers", where the cross and deep networks
-run side by side, and "Stacked layers", where the deep network sits on top of
-the cross network.)*
-
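-In code, a single cross layer amounts to just a few operations. Below is a
-minimal NumPy sketch of the computation pictured above,
-`x_{l+1} = x_0 * (W x_l + b) + x_l`; the inputs and weights are made-up
-illustrative values, not the actual implementation of
-`keras_rs.layers.FeatureCross`:
-
-
-```python
-import numpy as np
-
-
-def cross_layer(x0, xl, w, b):
-    # The element-wise product with the original input `x0` raises the degree
-    # of the feature interactions; the residual term `+ xl` preserves what
-    # earlier layers already captured.
-    return x0 * (xl @ w + b) + xl
-
-
-x0 = np.array([0.1, 0.2, 0.3])  # original input features
-w = np.full((3, 3), 0.5)  # in practice, each layer learns its own `w` and `b`
-b = np.zeros(3)
-x1 = cross_layer(x0, x0, w, b)  # captures degree-2 interactions
-x2 = cross_layer(x0, x1, w, b)  # stacking raises the polynomial degree
-```
-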
-Now that we know a little bit about DCN, let's start writing some code. We will
-first train a DCN on a toy dataset, and demonstrate that the model has indeed
-learnt important feature crosses.
-
-Let's set the backend to JAX, and get our imports sorted.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-import keras_rs
-```
-
-Let's also define variables which will be reused throughout the example.
-
-
-```python
-TOY_CONFIG = {
-    "learning_rate": 0.01,
-    "num_epochs": 100,
-    "batch_size": 1024,
-}
-
-MOVIELENS_CONFIG = {
-    # features
-    "int_features": [
-        "movie_id",
-        "user_id",
-        "user_gender",
-        "bucketized_user_age",
-    ],
-    "str_features": [
-        "user_zip_code",
-        "user_occupation_text",
-    ],
-    # model
-    "embedding_dim": 32,
-    "deep_net_num_units": [192, 192, 192],
-    "projection_dim": 20,
-    "dcn_num_units": [192, 192],
-    # training
-    "learning_rate": 0.01,
-    "num_epochs": 10,
-    "batch_size": 1024,
-}
-
-LOOKUP_LAYERS = {
-    "int": keras.layers.IntegerLookup,
-    "str": keras.layers.StringLookup,
-}
-```
-
-Here, we define a helper function for visualising the weights of the cross
-layer, to better understand how it works. We also define functions for
-compiling, training and evaluating a given model, and for reporting results.
-
-
-```python
-
-def visualize_layer(matrix, features):
-    plt.figure(figsize=(9, 9))
-
-    # Plot the absolute values of the weights as a heatmap.
-    im = plt.matshow(np.abs(matrix), cmap=plt.cm.Blues)
-
-    ax = plt.gca()
-    divider = make_axes_locatable(plt.gca())
-    cax = divider.append_axes("right", size="5%", pad=0.05)
-    plt.colorbar(im, cax=cax)
-    cax.tick_params(labelsize=10)
-    ax.set_xticklabels([""] + features, rotation=45, fontsize=10)
-    ax.set_yticklabels([""] + features, fontsize=10)
-
-
-def train_and_evaluate(
-    learning_rate,
-    epochs,
-    train_data,
-    test_data,
-    model,
-):
-    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate)
-    loss = keras.losses.MeanSquaredError()
-    rmse = keras.metrics.RootMeanSquaredError()
-
-    model.compile(
-        optimizer=optimizer,
-        loss=loss,
-        metrics=[rmse],
-    )
-
-    model.fit(
-        train_data,
-        epochs=epochs,
-        verbose=0,
-    )
-
-    results = model.evaluate(test_data, return_dict=True, verbose=0)
-    rmse_value = results["root_mean_squared_error"]
-
-    return rmse_value, model.count_params()
-
-
-def print_stats(rmse_list, num_params, model_name):
-    # Report the mean (and, across multiple trials, the standard deviation)
-    # of the RMSE, along with the parameter count.
-    num_trials = len(rmse_list)
-    avg_rmse = np.mean(rmse_list)
-    std_rmse = np.std(rmse_list)
-
-    if num_trials == 1:
-        print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
-    else:
-        print(
-            f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; "
-            f"#params = {num_params}"
-        )
-
-```
-
----
-## Toy Example
-
-To illustrate the benefits of DCNs, let's consider a simple example. Suppose we
-have a dataset for modeling the likelihood of a customer clicking on a blender
-advertisement.
The features and label are defined as follows:
-
-| **Features / Label** | **Description**                | **Range**|
-|:--------------------:|:------------------------------:|:--------:|
-| `x1` = country       | Customer's resident country    | [0, 199] |
-| `x2` = bananas       | # bananas purchased            | [0, 23]  |
-| `x3` = cookbooks     | # cooking books purchased      | [0, 5]   |
-| `y`                  | Blender ad click likelihood    | -        |
-
-We then generate labels according to the following underlying distribution:
-`y = f(x1, x2, x3) = 0.1x1 + 0.4x2 + 0.7x3 + 0.1x1x2 + 3.1x2x3 + 0.1x3^2`.
-
-This distribution shows that the click likelihood (`y`) depends linearly on
-individual features (`xi`) and on multiplicative interactions between them. In
-this scenario, the likelihood of purchasing a blender (`y`) is influenced not
-only by purchasing bananas (`x2`) or cookbooks (`x3`) individually, but also
-significantly by the interaction of purchasing both bananas and cookbooks
-(`x2x3`).
-
-### Preparing the dataset
-
-Let's create synthetic data based on the above equation, and form the
-train-test splits.
-
-
-```python
-
-def get_mixer_data(data_size=100_000):
-    country = np.random.randint(200, size=[data_size, 1]) / 200.0
-    bananas = np.random.randint(24, size=[data_size, 1]) / 24.0
-    cookbooks = np.random.randint(6, size=[data_size, 1]) / 6.0
-
-    x = np.concatenate([country, bananas, cookbooks], axis=1)
-
-    # Create 1st-order terms.
-    y = 0.1 * country + 0.4 * bananas + 0.7 * cookbooks
-
-    # Create 2nd-order cross terms.
-    y += (
-        0.1 * country * bananas
-        + 3.1 * bananas * cookbooks
-        + (0.1 * cookbooks * cookbooks)
-    )
-
-    return x, y
-
-
-x, y = get_mixer_data(data_size=100_000)
-num_train = 90_000
-train_x = x[:num_train]
-train_y = y[:num_train]
-test_x = x[num_train:]
-test_y = y[num_train:]
-```
-
-### Building the model
-
-To demonstrate the advantages of a cross network in recommender systems, we'll
-compare its performance with a deep network. Since our example data only
-contains second-order feature interactions, a single-layered cross network will
-suffice. For datasets with higher-order interactions, multiple cross layers can
-be stacked to form a multi-layered cross network. We will build two models:
-
-1. A cross network with a single cross layer.
-2. A deep network with wider and deeper feedforward layers.
-
-
-```python
-cross_network = keras.Sequential(
-    [
-        keras_rs.layers.FeatureCross(),
-        keras.layers.Dense(1),
-    ]
-)
-
-deep_network = keras.Sequential(
-    [
-        keras.layers.Dense(512, activation="relu"),
-        keras.layers.Dense(256, activation="relu"),
-        keras.layers.Dense(128, activation="relu"),
-    ]
-)
-```
-
-### Model training
-
-Before we train the models, we need to batch our datasets.
-
-
-```python
-train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).batch(
-    TOY_CONFIG["batch_size"]
-)
-test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(
-    TOY_CONFIG["batch_size"]
-)
-```
-
-Let's train both models. Remember that we set `verbose=0` for brevity's sake,
-so do not be alarmed if you do not see any output for a while.
-
-After training, we evaluate the models on the unseen dataset, reporting the
-Root Mean Squared Error (RMSE). As we will see below, the cross network
-achieves a significantly lower RMSE than a ReLU-based DNN, while also using
-far fewer parameters. This points to the efficiency of the cross network in
-learning feature interactions.
- - -```python -cross_network_rmse, cross_network_num_params = train_and_evaluate( - learning_rate=TOY_CONFIG["learning_rate"], - epochs=TOY_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=cross_network, -) -print_stats( - rmse_list=[cross_network_rmse], - num_params=cross_network_num_params, - model_name="Cross Network", -) - -deep_network_rmse, deep_network_num_params = train_and_evaluate( - learning_rate=TOY_CONFIG["learning_rate"], - epochs=TOY_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=deep_network, -) -print_stats( - rmse_list=[deep_network_rmse], - num_params=deep_network_num_params, - model_name="Deep Network", -) -``` - -
-``` -Cross Network: RMSE = 0.0001293081877520308; #params = 16 - -Deep Network: RMSE = 0.13307014107704163; #params = 166272 - -``` -
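-These parameter counts are easy to verify by hand: with three input features,
-the cross layer learns a `3 x 3` weight matrix plus a bias vector of size 3
-(12 parameters), and the final `Dense(1)` layer adds 3 weights and a bias
-(4 parameters), for 16 in total. The deep network's three dense layers
-contribute `4 * 512 + 513 * 256 + 257 * 128 = 166,272` parameters, roughly
-four orders of magnitude more.
-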
-### Visualizing feature interactions - -Since we already know which feature crosses are important in our data, it would -be interesting to verify whether our model has indeed learned these key feature -interactions. This can be done by visualizing the learned weight matrix in the -cross network, where the weight `Wij` represents the learned importance of -the interaction between features `xi` and `xj`. - - -```python -visualize_layer( - matrix=cross_network.weights[0].numpy(), - features=["country", "purchased_bananas", "purchased_cookbooks"], -) -``` - -
-``` -:11: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator. - ax.set_xticklabels([""] + features, rotation=45, fontsize=10) -:12: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator. - ax.set_yticklabels([""] + features, fontsize=10) - -
- -``` -
- -![png](/img/examples/keras_rs/dcn/dcn_16_2.png) - - - ---- -## Real-world example - -Let's use the MovieLens 100K dataset. This dataset is used to train models to -predict users' movie ratings, based on user-related features and movie-related -features. - -### Preparing the dataset - -The dataset processing steps here are similar to what's given in the -[basic ranking](/keras_rs/examples/basic_ranking/) -tutorial. Let's load the dataset, and keep only the useful columns. - - -```python -ratings_ds = tfds.load("movielens/100k-ratings", split="train") -ratings_ds = ratings_ds.map( - lambda x: ( - { - "movie_id": int(x["movie_id"]), - "user_id": int(x["user_id"]), - "user_gender": int(x["user_gender"]), - "user_zip_code": x["user_zip_code"], - "user_occupation_text": x["user_occupation_text"], - "bucketized_user_age": int(x["bucketized_user_age"]), - }, - x["user_rating"], # label - ) -) -``` - -
-```
-WARNING:absl:Variant folder /root/tensorflow_datasets/movielens/100k-ratings/0.1.1 has no dataset_info.json
-
-Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/movielens/100k-ratings/0.1.1...
-
-```
-
-For every feature, let's get the list of unique values, i.e., the vocabulary,
-so that we can use it for the embedding layer.
-
-
-```python
-vocabularies = {}
-for feature_name in MOVIELENS_CONFIG["int_features"] + MOVIELENS_CONFIG["str_features"]:
-    vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
-    vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))
-```
-
-We use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to convert
-all features into indices, which can then be fed into embedding layers.
-
-
-```python
-lookup_layers = {}
-lookup_layers.update(
-    {
-        feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
-        for feature in MOVIELENS_CONFIG["int_features"]
-    }
-)
-lookup_layers.update(
-    {
-        feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
-        for feature in MOVIELENS_CONFIG["str_features"]
-    }
-)
-
-ratings_ds = ratings_ds.map(
-    lambda x, y: (
-        {
-            feature_name: lookup_layers[feature_name](x[feature_name])
-            for feature_name in vocabularies
-        },
-        y,
-    )
-)
-```
-
-Let's split our data into train and test sets. We also use `cache()` and
-`prefetch()` for better performance.
-
-
-```python
-ratings_ds = ratings_ds.shuffle(100_000)
-
-train_ds = (
-    ratings_ds.take(80_000)
-    .batch(MOVIELENS_CONFIG["batch_size"])
-    .cache()
-    .prefetch(tf.data.AUTOTUNE)
-)
-test_ds = (
-    ratings_ds.skip(80_000)
-    .batch(MOVIELENS_CONFIG["batch_size"])
-    .take(20_000)
-    .cache()
-    .prefetch(tf.data.AUTOTUNE)
-)
-```
-
-### Building the model
-
-The model will have embedding layers, followed by cross and/or feedforward
-layers.
-
-
-```python
-
-def get_model(
-    dense_num_units_lst,
-    embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
-    use_cross_layer=False,
-    projection_dim=None,
-):
-    inputs = {}
-    embeddings = []
-    for feature_name, vocabulary in vocabularies.items():
-        inputs[feature_name] = keras.Input(shape=(), dtype="int32", name=feature_name)
-        embedding_layer = keras.layers.Embedding(
-            input_dim=len(vocabulary) + 1,
-            output_dim=embedding_dim,
-        )
-        embedding = embedding_layer(inputs[feature_name])
-        embeddings.append(embedding)
-
-    x = keras.ops.concatenate(embeddings, axis=1)
-
-    # Cross layer.
-    if use_cross_layer:
-        x = keras_rs.layers.FeatureCross(projection_dim=projection_dim)(x)
-
-    # Dense layers.
-    for num_units in dense_num_units_lst:
-        x = keras.layers.Dense(num_units, activation="relu")(x)
-
-    x = keras.layers.Dense(1)(x)
-
-    return keras.Model(inputs=inputs, outputs=x)
-
-```
-
-We have three models: a deep cross network, an optimised deep cross network
-with a low-rank matrix (to reduce training and serving costs), and a normal
-deep network without cross layers. The deep cross network is a stacked DCN
-model, i.e., the inputs are fed to cross layers, followed by feedforward
-layers. Let's run each model 10 times, and report the average/standard
-deviation of the RMSE.
- - -```python -cross_network_rmse_list = [] -opt_cross_network_rmse_list = [] -deep_network_rmse_list = [] - -for _ in range(10): - cross_network = get_model( - dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - use_cross_layer=True, - ) - rmse, cross_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=cross_network, - ) - cross_network_rmse_list.append(rmse) - - opt_cross_network = get_model( - dense_num_units_lst=MOVIELENS_CONFIG["dcn_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - use_cross_layer=True, - projection_dim=MOVIELENS_CONFIG["projection_dim"], - ) - rmse, opt_cross_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=opt_cross_network, - ) - opt_cross_network_rmse_list.append(rmse) - - deep_network = get_model(dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"]) - rmse, deep_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=deep_network, - ) - deep_network_rmse_list.append(rmse) - -print_stats( - rmse_list=cross_network_rmse_list, - num_params=cross_network_num_params, - model_name="Cross Network", -) -print_stats( - rmse_list=opt_cross_network_rmse_list, - num_params=opt_cross_network_num_params, - model_name="Optimised Cross Network", -) -print_stats( - rmse_list=deep_network_rmse_list, - num_params=deep_network_num_params, - model_name="Deep Network", -) -``` - -
-``` -Cross Network: RMSE = 0.9427602052688598 ± 0.07614302893494468; #params = {num_params} -Optimised Cross Network: RMSE = 0.9187218248844147 ± 0.031170624868084987; #params = {num_params} -Deep Network: RMSE = 0.8789893209934234 ± 0.025684711934398047; #params = {num_params} - -``` -
-The DCN outperforms a similarly sized DNN with ReLU layers. Furthermore, the
-low-rank DCN effectively reduces the number of parameters without compromising
-accuracy.
-
-### Visualizing feature interactions
-
-Like we did for the toy example, we will plot the weight matrix of the cross
-layer to see which feature crosses are important. In the previous example,
-the importance of the interaction between the `i`-th and `j`-th features was
-captured by the `(i, j)`-th element of the weight matrix.
-
-In this case, the feature embeddings are of size 32 rather than 1. Therefore,
-the importance of feature interactions is represented by the `(i, j)`-th
-block of the weight matrix, which has dimensions `32 x 32`. To quantify the
-significance of these interactions, we use the Frobenius norm of each block. A
-larger value implies higher importance.
-
-
-```python
-features = list(vocabularies.keys())
-mat = cross_network.weights[len(features)].numpy()
-embedding_dim = MOVIELENS_CONFIG["embedding_dim"]
-
-block_norm = np.zeros([len(features), len(features)])
-
-# Compute the Frobenius norm of each block.
-for i in range(len(features)):
-    for j in range(len(features)):
-        block = mat[
-            i * embedding_dim : (i + 1) * embedding_dim,
-            j * embedding_dim : (j + 1) * embedding_dim,
-        ]
-        block_norm[i, j] = np.linalg.norm(block, ord="fro")
-
-visualize_layer(
-    matrix=block_norm,
-    features=features,
-)
-```
-
-``` -:11: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator. - ax.set_xticklabels([""] + features, rotation=45, fontsize=10) -:12: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator. - ax.set_yticklabels([""] + features, fontsize=10) - -
- -``` -
- -![png](/img/examples/keras_rs/dcn/dcn_31_2.png) - - - -And we are all done! diff --git a/templates/keras_rs/examples/deep_recommender.md b/templates/keras_rs/examples/deep_recommender.md deleted file mode 100644 index ad154675c7..0000000000 --- a/templates/keras_rs/examples/deep_recommender.md +++ /dev/null @@ -1,5439 +0,0 @@ -# Deep Recommenders - -**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
-**Description:** Building a deep retrieval model with multiple stacked layers. - - - [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/deep_recommender.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/deep_recommender.py) - - - ---- -## Introduction - -One of the great advantages of using Keras to build recommender models is the -freedom to build rich, flexible feature representations. - -The first step in doing so is preparing the features, as raw features will -usually not be immediately usable in a model. - -For example: -- User and item IDs may be strings (titles, usernames) or large, non-contiguous - integers (database IDs). -- Item descriptions could be raw text. -- Interaction timestamps could be raw Unix timestamps. - -These need to be appropriately transformed in order to be useful in building -models: -- User and item IDs have to be translated into embedding vectors, - high-dimensional numerical representations that are adjusted during training - to help the model predict its objective better. -- Raw text needs to be tokenized (split into smaller parts such as individual - words) and translated into embeddings. -- Numerical features need to be normalized so that their values lie in a small - interval around 0. - -Fortunately, the Keras -[`FeatureSpace`](/api/utils/feature_space/) utility makes this -preprocessing easy. - -In this tutorial, we are going to incorporate multiple features in our models. -These features will come from preprocessing the MovieLens dataset. - -In the -[basic retrieval](/keras_rs/examples/basic_retrieval/) -tutorial, the models consist of only an embedding layer. In this tutorial, we -add more dense layers to our models to increase their expressive power. - -In general, deeper models are capable of learning more complex patterns than -shallower models. For example, our user model incorporates user IDs and user -features such as age, gender and occupation. A shallow model (say, a single -embedding layer) may only be able to learn the simplest relationships between -those features and movies: a given user generally prefers horror movies to -comedies. To capture more complex relationships, such as user preferences -evolving with their age, we may need a deeper model with multiple stacked dense -layers. - -Of course, complex models also have their disadvantages. The first is -computational cost, as larger models require both more memory and more -computation to train and serve. The second is the requirement for more data. In -general, more training data is needed to take advantage of deeper models. With -more parameters, deep models might overfit or even simply memorize the training -examples instead of learning a function that can generalize. Finally, training -deeper models may be harder, and more care needs to be taken in choosing -settings like regularization and learning rate. - -Finding a good architecture for a real-world recommender system is a complex -art, requiring good intuition and careful hyperparameter tuning. For example, -factors such as the depth and width of the model, activation function, learning -rate, and optimizer can radically change the performance of the model. Modelling -choices are further complicated by the fact that good offline evaluation metrics -may not correspond to good online performance, and that the choice of what to -optimize for is often more critical than the choice of model itself. 
- -Nevertheless, effort put into building and fine-tuning larger models often pays -off. In this tutorial, we will illustrate how to build a deep retrieval model. -We'll do this by building progressively more complex models to see how this -affects model performance. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import keras -import matplotlib.pyplot as plt -import tensorflow as tf # Needed for the dataset -import tensorflow_datasets as tfds - -import keras_rs -``` - ---- -## The MovieLens dataset - -Let's first have a look at what features we can use from the MovieLens dataset. - - -```python -# Ratings data with user and movie data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") -``` - -The ratings dataset returns a dictionary of movie id, user id, the assigned -rating, timestamp, movie information, and user information: - - -```python -for data in ratings.take(1).as_numpy_iterator(): - print(str(data).replace(", '", ",\n '")) -``` - -
-``` -{'bucketized_user_age': np.float32(45.0), - 'movie_genres': array([7]), - 'movie_id': b'357', - 'movie_title': b"One Flew Over the Cuckoo's Nest (1975)", - 'raw_user_age': np.float32(46.0), - 'timestamp': np.int64(879024327), - 'user_gender': np.True_, - 'user_id': b'138', - 'user_occupation_label': np.int64(4), - 'user_occupation_text': b'doctor', - 'user_rating': np.float32(4.0), - 'user_zip_code': b'53211'} - -``` -
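-
-As a quick, illustrative sanity check (an assumption worth verifying: the
-`100k` variant should contain 100,000 ratings, which is what the 80,000/20,000
-train/test split later in this tutorial relies on):
-
-
-```python
-# The "100k" variant holds 100,000 user/movie ratings in total.
-print(ratings.cardinality().numpy())  # 100000
-```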
-In the MovieLens dataset, user IDs are integers (represented as strings)
-starting at 1, with no gaps. Normally, you would need to create a lookup table
-to map user IDs to integers from 0 to N-1. But as a simplification, we'll use
-the user ID directly as an index in our model, in particular to look up the
-user embedding in the user embedding table. For that, we need to know the
-number of users.
-
-
-```python
-USERS_COUNT = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-```
-
-The movies dataset contains the movie ID, movie title, and the genres the movie
-belongs to. Note that the genres are encoded with integer labels.
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(str(data).replace(", '", ",\n '"))
-```
-
-<div class="k-default-codeblock">
-``` -{'movie_genres': array([4]), - 'movie_id': b'1681', - 'movie_title': b'You So Crazy (1994)'} - -``` -
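-
-For reference, the lookup-table approach mentioned above (and deliberately
-skipped in this tutorial) could look something like the following sketch.
-This is illustrative only; `user_id_lookup` is a hypothetical name and is not
-used anywhere else in this example.
-
-
-```python
-# Illustrative sketch of the lookup-table approach we skip here: map raw
-# string IDs to contiguous integers via a learned vocabulary.
-user_id_lookup = keras.layers.StringLookup()
-user_id_lookup.adapt(ratings.map(lambda x: x["user_id"]))
-# After adapting, raw IDs map to integers in [0, vocabulary_size),
-# with index 0 reserved for out-of-vocabulary values.
-```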
-In the MovieLens dataset, movie IDs are integers (represented as strings)
-starting at 1, with no gaps. Normally, you would need to create a lookup table
-to map movie IDs to integers from 0 to N-1. But as a simplification, we'll use
-the movie ID directly as an index in our model, in particular to look up the
-movie embedding in the movie embedding table. For that, we need to know the
-number of movies.
-
-
-```python
-MOVIES_COUNT = movies.cardinality().numpy()
-```
-
----
-## Preprocessing the dataset
-
-### Normalizing continuous features
-
-Continuous features may need normalization so that they fall within an
-acceptable range for the model. We will give two examples of such normalization.
-
-#### Discretization
-
-A common transformation is to turn a continuous feature into a number of
-categorical features. This makes good sense if we have reasons to suspect that a
-feature's effect is non-continuous.
-
-We need to decide on the number of buckets to use for discretization. Then, we
-will use the Keras `FeatureSpace` utility to automatically find the minimum and
-maximum value of the feature, and divide that range into the requested number
-of buckets.
-
-In this example, we will discretize the user age.
-
-
-```python
-AGE_BINS_COUNT = 10
-user_age_feature = keras.utils.FeatureSpace.float_discretized(
-    num_bins=AGE_BINS_COUNT, output_mode="int"
-)
-```
-
-#### Rescaling
-
-Often, we want continuous features to be between 0 and 1, or between -1 and 1.
-To achieve this, we can rescale features that have a different range.
-
-In this example, we will rescale the rating, which is an integer between 1 and
-5, to be a float between 0 and 1. We need to scale it and offset it: a rating
-`r` is mapped to `r / 4 - 1 / 4`, so 1 becomes 0 and 5 becomes 1.
-
-
-```python
-user_rating_feature = keras.utils.FeatureSpace.float_rescaled(
-    scale=1.0 / 4.0, offset=-1.0 / 4.0
-)
-```
-
-### Turning categorical features into embeddings
-
-A categorical feature is a feature that does not express a continuous quantity,
-but rather takes on one of a set of fixed values.
-
-Most deep learning models express these features by turning them into
-high-dimensional vectors. During model training, the value of that vector is
-adjusted to help the model predict its objective better.
-
-For example, suppose that our goal is to predict which user is going to watch
-which movie. To do that, we represent each user and each movie by an embedding
-vector. Initially, these embeddings will take on random values. During training,
-we adjust them so that embeddings of users and the movies they watch end up
-closer together.
-
-Taking raw categorical features and turning them into embeddings is normally a
-two-step process:
-1. First, we need to translate the raw values into a range of contiguous
-   integers, normally by building a mapping (called a "vocabulary") that maps
-   raw values to integers.
-2. Second, we need to take these integers and turn them into embeddings.
-
-#### Defining categorical features
-
-We will use the Keras `FeatureSpace` utility for the first step. Its `adapt`
-method automatically discovers the vocabulary for categorical features.
-
-
-```python
-user_gender_feature = keras.utils.FeatureSpace.integer_categorical(
-    num_oov_indices=0, output_mode="int"
-)
-user_occupation_feature = keras.utils.FeatureSpace.integer_categorical(
-    num_oov_indices=0, output_mode="int"
-)
-```
-
-#### Using feature crosses
-
-Crosses let us model interactions between multiple categorical features. This
-can be a powerful way to express that a particular combination of features
-signals a specific taste in movies.
-
-Note that crossing multiple features can result in a very large feature space,
-which is why the `crossing_dim` parameter is important: it limits the output
-dimension of the crossed feature.
-
-In this example, we will cross age and gender with the Keras `FeatureSpace`
-utility.
-
-
-```python
-USER_GENDER_CROSS_COUNT = 20
-user_gender_age_cross = keras.utils.FeatureSpace.cross(
-    feature_names=("user_gender", "raw_user_age"),
-    crossing_dim=USER_GENDER_CROSS_COUNT,
-    output_mode="int",
-)
-```
-
-### Processing text features
-
-We may also want to add text features to our model. Usually, things like product
-descriptions are free-form text, and we can hope that our model can learn to use
-the information they contain to make better recommendations, especially in a
-cold-start or long-tail scenario.
-
-While the MovieLens dataset does not give us rich textual features, we can still
-use movie titles. This may help us capture the fact that movies with very
-similar titles are likely to belong to the same series.
-
-The first transformation we need to apply to text is tokenization (splitting
-into constituent words or word-pieces), followed by vocabulary learning,
-followed by an embedding.
-
-The
-[`keras.layers.TextVectorization`](/api/layers/preprocessing_layers/text/text_vectorization/)
-layer can do the first two steps for us.
-
-
-```python
-title_vectorizer = keras.layers.TextVectorization(
-    max_tokens=10_000, output_sequence_length=16, dtype="int32"
-)
-title_vectorizer.adapt(movies.map(lambda x: x["movie_title"]))
-```
-
-Let's try it out:
-
-
-```python
-for data in movies.take(1).as_numpy_iterator():
-    print(title_vectorizer(data["movie_title"]))
-```
-
-<div class="k-default-codeblock">
-``` -[ 59 187 622 5 0 0 0 0 0 0 0 0 0 0 0 0] - -``` -
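-
-To see what these integers correspond to, we can map them back through the
-learned vocabulary. This is a quick, illustrative check; the expected words in
-the comment below are an assumption based on the title "You So Crazy (1994)":
-
-
-```python
-# Map the token IDs from the output above back to words.
-vocab = title_vectorizer.get_vocabulary()
-print([vocab[i] for i in [59, 187, 622, 5]])  # e.g. ['you', 'so', 'crazy', '1994']
-```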
-Each title is translated into a sequence of token IDs, one per tokenized word,
-zero-padded up to the fixed output sequence length of 16.
-
-We can check the learned vocabulary to verify that the layer is using the
-correct tokenization:
-
-
-```python
-print(title_vectorizer.get_vocabulary()[40:50])
-```
-
-<div class="k-default-codeblock">
-``` -[np.str_('paris'), np.str_('little'), np.str_('last'), np.str_('ii'), np.str_('1988'), np.str_('king'), np.str_('from'), np.str_('city'), np.str_('boys'), np.str_('murder')] - -``` -
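-
-It is also worth remembering that, by convention, the first two entries of a
-`TextVectorization` vocabulary are reserved: index 0 for padding and index 1
-for out-of-vocabulary tokens. A quick, illustrative check:
-
-
-```python
-# Index 0 is the padding token, index 1 is the OOV token.
-print(title_vectorizer.get_vocabulary()[:2])  # typically ['', '[UNK]']
-```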
-The vocabulary indeed consists of individual words: the layer is tokenizing
-titles as expected. Later, we will see how to embed this tokenized text. For
-now, we turn this vectorizer into a Keras `FeatureSpace` feature.
-
-
-```python
-title_feature = keras.utils.FeatureSpace.feature(
-    preprocessor=title_vectorizer, dtype="string", output_mode="float"
-)
-TITLE_TOKEN_COUNT = title_vectorizer.vocabulary_size()
-```
-
-### Putting the FeatureSpace features together
-
-We're now ready to assemble the features with preprocessors in a `FeatureSpace`
-object. We then use `adapt` to go through the dataset and learn what needs to
-be learned, such as the vocabulary for categorical features or the minimum and
-maximum values for bucketized features.
-
-
-```python
-feature_space = keras.utils.FeatureSpace(
-    features={
-        # Numerical features to discretize.
-        "raw_user_age": user_age_feature,
-        # Categorical features encoded as integers.
-        "user_gender": user_gender_feature,
-        "user_occupation_label": user_occupation_feature,
-        # Labels are ratings between 0 and 1.
-        "user_rating": user_rating_feature,
-        "movie_title": title_feature,
-    },
-    crosses=[user_gender_age_cross],
-    output_mode="dict",
-)
-
-feature_space.adapt(ratings)
-GENDERS_COUNT = feature_space.preprocessors["user_gender"].vocabulary_size()
-OCCUPATIONS_COUNT = feature_space.preprocessors[
-    "user_occupation_label"
-].vocabulary_size()
-```
-
----
-## Pre-building the candidate set
-
-Our model is going to be based on a `Retrieval` layer, which provides a set of
-the best candidates from the full set of candidates. To do this, the retrieval
-layer needs to know all the candidates and their features. In this section, we
-assemble the full set of movies with their associated features.
-
-### Extract raw candidate features
-
-First, we gather all the raw features from the dataset in lists: the movie
-titles and the genres. Note that one or more genres are associated with each
-movie, and the number of genres varies per movie.
-
-
-```python
-movie_titles = [""] * (MOVIES_COUNT + 1)
-movie_genres = [[]] * (MOVIES_COUNT + 1)
-for x in movies.as_numpy_iterator():
-    movie_id = int(x["movie_id"])
-    movie_titles[movie_id] = x["movie_title"]
-    movie_genres[movie_id] = x["movie_genres"].tolist()
-```
-
-### Preprocess candidate features
-
-Genres are already in the form of category numbers starting at zero. However, we
-do need to figure out two things:
-- The maximum number of genres a single movie can have; this will determine the
-  dimension for this feature.
-- The maximum value for the genre, which will give us the total number of genres
-  and determine the size of our embedding table for genres.
-
-
-```python
-MAX_GENRES_PER_MOVIE = 0
-max_genre_id = 0
-for one_movie_genres in movie_genres:
-    MAX_GENRES_PER_MOVIE = max(MAX_GENRES_PER_MOVIE, len(one_movie_genres))
-    if one_movie_genres:
-        max_genre_id = max(max_genre_id, max(one_movie_genres))
-
-GENRES_COUNT = max_genre_id + 1
-```
-
-Now we need to pad genres with an out-of-vocabulary (OOV) value so that we can
-represent genres as a fixed-size vector. We'll pad with zeros for simplicity,
-and therefore add one to each genre so the padding doesn't conflict with genre
-zero, which is a valid genre.
-
-
-```python
-movie_genres = [
-    [g + 1 for g in genres] + [0] * (MAX_GENRES_PER_MOVIE - len(genres))
-    for genres in movie_genres
-]
-```
-
-Then, we vectorize all the movie titles.
-
-
-```python
-movie_titles_vectors = title_vectorizer(movie_titles)
-```
-
-### Convert candidate set to native tensors
-
-We're now ready to combine these in a dataset. The last step is to make sure
-everything is a native tensor that can be consumed by the retrieval layer.
-As a reminder, movie ID zero does not exist.
-
-
-```python
-MOVIES_DATASET = {
-    "movie_id": keras.ops.arange(0, MOVIES_COUNT + 1, dtype="int32"),
-    "movie_title_vector": movie_titles_vectors,
-    "movie_genres": keras.ops.convert_to_tensor(movie_genres, dtype="int32"),
-}
-```
-
----
-## Preparing the data
-
-We can now define our preprocessing function. Most features will be handled
-by the `FeatureSpace`. User IDs and movie IDs need to be extracted. Movie genres
-need to be padded. Then everything is packaged as a tuple with a dict of input
-features and a float for the rating, which is used as a label.
-
-
-```python
-
-def preprocess_rating(x):
-    features = feature_space(
-        {
-            "raw_user_age": x["raw_user_age"],
-            "user_gender": x["user_gender"],
-            "user_occupation_label": x["user_occupation_label"],
-            "user_rating": x["user_rating"],
-            "movie_title": x["movie_title"],
-        }
-    )
-    features = {k: tf.squeeze(v, axis=0) for k, v in features.items()}
-    movie_genres = x["movie_genres"]
-
-    return (
-        {
-            # User inputs are user ID and user features
-            "user_id": int(x["user_id"]),
-            "raw_user_age": features["raw_user_age"],
-            "user_gender": features["user_gender"],
-            "user_occupation_label": features["user_occupation_label"],
-            "user_gender_X_raw_user_age": tf.squeeze(
-                features["user_gender_X_raw_user_age"], axis=-1
-            ),
-            # Movie inputs are movie ID, vectorized title and genres
-            "movie_id": int(x["movie_id"]),
-            "movie_title_vector": features["movie_title"],
-            "movie_genres": tf.pad(
-                movie_genres + 1,
-                [[0, MAX_GENRES_PER_MOVIE - tf.shape(movie_genres)[0]]],
-            ),
-        },
-        # Label is user rating between 0 and 1
-        features["user_rating"],
-    )
-
-```
-
-We shuffle and then split the data into a training set and a testing set.
-
-
-```python
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-
-train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
-test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
-```
-
----
-## Model definition
-
-### Query model
-
-The query model is first tasked with converting user features to embeddings. The
-embeddings are then concatenated into a single vector.
-
-Defining deeper models will require us to stack more layers on top of this first
-set of embeddings. A progressively narrower stack of layers, separated by an
-activation function, is a common pattern:
-
-```
-                       +----------------------+
-                       |       64 x 32        |
-                       +----------------------+
-                                  | relu
-                     +--------------------------+
-                     |         128 x 64         |
-                     +--------------------------+
-                                  | relu
-                   +------------------------------+
-                   |          ... x 128           |
-                   +------------------------------+
-```
-
-Since the expressive power of deep linear models is no greater than that of
-shallow linear models, we use ReLU activations for all but the last hidden
-layer. The final hidden layer does not use any activation function: using an
-activation function would limit the output space of the final embeddings and
-might negatively impact the performance of the model. For instance, if ReLUs are
-used in the projection layer, all components in the output embedding would be
-non-negative.
-
-We're going to try this here.
-To make experimentation with different depths easy, let's define a model whose
-depth (and width) is defined by a constructor parameter: the `layer_sizes`
-parameter gives us the depth and width of the model. We can vary it to
-experiment with shallower or deeper models.
-
-
-```python
-
-class QueryModel(keras.Model):
-    """Model for encoding user queries."""
-
-    def __init__(self, layer_sizes, embedding_dimension=32):
-        """Construct a model for encoding user queries.
-
-        Args:
-            layer_sizes: A list of integers where the i-th entry represents the
-                number of units the i-th layer contains.
-            embedding_dimension: Output dimension for all embedding tables.
-        """
-        super().__init__()
-
-        # We first generate embeddings.
-        self.user_embedding = keras.layers.Embedding(
-            # +1 for user ID zero, which does not exist
-            USERS_COUNT + 1,
-            embedding_dimension,
-        )
-        self.gender_embedding = keras.layers.Embedding(
-            GENDERS_COUNT, embedding_dimension
-        )
-        self.age_embedding = keras.layers.Embedding(AGE_BINS_COUNT, embedding_dimension)
-        self.gender_x_age_embedding = keras.layers.Embedding(
-            USER_GENDER_CROSS_COUNT, embedding_dimension
-        )
-        self.occupation_embedding = keras.layers.Embedding(
-            OCCUPATIONS_COUNT, embedding_dimension
-        )
-
-        # Then construct the layers.
-        self.dense_layers = keras.Sequential()
-
-        # Use the ReLU activation for all but the last layer.
-        for layer_size in layer_sizes[:-1]:
-            self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu"))
-
-        # No activation for the last layer.
-        self.dense_layers.add(keras.layers.Dense(layer_sizes[-1]))
-
-    def call(self, inputs):
-        # Take the inputs, pass each through its embedding layer, concatenate.
-        feature_embedding = keras.ops.concatenate(
-            [
-                self.user_embedding(inputs["user_id"]),
-                self.gender_embedding(inputs["user_gender"]),
-                self.age_embedding(inputs["raw_user_age"]),
-                self.gender_x_age_embedding(inputs["user_gender_X_raw_user_age"]),
-                self.occupation_embedding(inputs["user_occupation_label"]),
-            ],
-            axis=1,
-        )
-        return self.dense_layers(feature_embedding)
-
-```
-
----
-## Candidate model
-
-We can adopt the same approach for the candidate model. Again, we start by
-converting movie features to embeddings, then concatenate them and pass the
-result through hidden layers:
-
-
-```python
-
-class CandidateModel(keras.Model):
-    """Model for encoding candidates (movies)."""
-
-    def __init__(self, layer_sizes, embedding_dimension=32):
-        """Construct a model for encoding candidates (movies).
-
-        Args:
-            layer_sizes: A list of integers where the i-th entry represents the
-                number of units the i-th layer contains.
-            embedding_dimension: Output dimension for all embedding tables.
-        """
-        super().__init__()
-
-        # We first generate embeddings.
-        self.movie_embedding = keras.layers.Embedding(
-            # +1 for movie ID zero, which does not exist
-            MOVIES_COUNT + 1,
-            embedding_dimension,
-        )
-        # Take all the title tokens for the title of the movie, embed each
-        # token, and then take the mean of all token embeddings.
-        self.movie_title_embedding = keras.Sequential(
-            [
-                keras.layers.Embedding(
-                    # +1 for OOV token, which is used for padding
-                    TITLE_TOKEN_COUNT + 1,
-                    embedding_dimension,
-                    mask_zero=True,
-                ),
-                keras.layers.GlobalAveragePooling1D(),
-            ]
-        )
-        # Take all the genres for the movie, embed each genre, and then take the
-        # mean of all genre embeddings.
- self.movie_genres_embedding = keras.Sequential( - [ - keras.layers.Embedding( - # +1 for OOV genre, which is used for padding - GENRES_COUNT + 1, - embedding_dimension, - mask_zero=True, - ), - keras.layers.GlobalAveragePooling1D(), - ] - ) - - # Then construct the layers. - self.dense_layers = keras.Sequential() - - # Use the ReLU activation for all but the last layer. - for layer_size in layer_sizes[:-1]: - self.dense_layers.add(keras.layers.Dense(layer_size, activation="relu")) - - # No activation for the last layer. - self.dense_layers.add(keras.layers.Dense(layer_sizes[-1])) - - def call(self, inputs): - movie_id = inputs["movie_id"] - movie_title_vector = inputs["movie_title_vector"] - movie_genres = inputs["movie_genres"] - feature_embedding = keras.ops.concatenate( - [ - self.movie_embedding(movie_id), - self.movie_title_embedding(movie_title_vector), - self.movie_genres_embedding(movie_genres), - ], - axis=1, - ) - return self.dense_layers(feature_embedding) - -``` - ---- -## Combined model - -With both QueryModel and CandidateModel defined, we can put together a combined -model and implement our loss and metrics logic. To make things simple, we'll -enforce that the model structure is the same across the query and candidate -models. - - -```python - -class RetrievalModel(keras.Model): - """Combined model.""" - - def __init__( - self, - layer_sizes=(32,), - embedding_dimension=32, - retrieval_k=100, - ): - """Construct a combined model. - - Args: - layer_sizes: A list of integers where the i-th entry represents the - number of units the i-th layer contains. - embedding_dimension: Output dimension for all embedding tables. - retrieval_k: How many candidate movies to retrieve. - """ - super().__init__() - self.query_model = QueryModel(layer_sizes, embedding_dimension) - self.candidate_model = CandidateModel(layer_sizes, embedding_dimension) - self.retrieval = keras_rs.layers.BruteForceRetrieval( - k=retrieval_k, return_scores=False - ) - self.update_candidates() # Provide an initial set of candidates - self.loss_fn = keras.losses.MeanSquaredError() - self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy( - k=100, from_sorted_ids=True - ) - - def update_candidates(self): - self.retrieval.update_candidates( - self.candidate_model.predict(MOVIES_DATASET, verbose=0) - ) - - def call(self, inputs, training=False): - query_embeddings = self.query_model( - { - "user_id": inputs["user_id"], - "raw_user_age": inputs["raw_user_age"], - "user_gender": inputs["user_gender"], - "user_occupation_label": inputs["user_occupation_label"], - "user_gender_X_raw_user_age": inputs["user_gender_X_raw_user_age"], - } - ) - candidate_embeddings = self.candidate_model( - { - "movie_id": inputs["movie_id"], - "movie_title_vector": inputs["movie_title_vector"], - "movie_genres": inputs["movie_genres"], - } - ) - - result = { - "query_embeddings": query_embeddings, - "candidate_embeddings": candidate_embeddings, - } - if not training: - # No need to spend time extracting top predicted movies during - # training, they are not used. - result["predictions"] = self.retrieval(query_embeddings) - return result - - def evaluate( - self, - x=None, - y=None, - batch_size=None, - verbose="auto", - sample_weight=None, - steps=None, - callbacks=None, - return_dict=False, - **kwargs, - ): - """Overridden to update the candidate set. - - Before evaluating the model, we need to update our retrieval layer by - re-computing the values predicted by the candidate model for all the - candidates. 
- """ - self.update_candidates() - return super().evaluate( - x, - y, - batch_size=batch_size, - verbose=verbose, - sample_weight=sample_weight, - steps=steps, - callbacks=callbacks, - return_dict=return_dict, - **kwargs, - ) - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - query_embeddings = y_pred["query_embeddings"] - candidate_embeddings = y_pred["candidate_embeddings"] - - labels = keras.ops.expand_dims(y, -1) - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.sum( - keras.ops.multiply(query_embeddings, candidate_embeddings), - axis=1, - keepdims=True, - ) - return self.loss_fn(labels, scores, sample_weight) - - def compute_metrics(self, x, y, y_pred, sample_weight=None): - if "predictions" in y_pred: - # We are evaluating or predicting. Update `top_k_metric`. - movie_ids = x["movie_id"] - predictions = y_pred["predictions"] - # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we - # only take top rated movies, and we put a weight of 0 for the rest. - rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32") - sample_weight = ( - rating_weight - if sample_weight is None - else keras.ops.multiply(rating_weight, sample_weight) - ) - self.top_k_metric.update_state( - movie_ids, predictions, sample_weight=sample_weight - ) - return self.get_metrics_result() - else: - # We are training. `top_k_metric` is not updated and is zero, so - # don't report it. - result = self.get_metrics_result() - result.pop(self.top_k_metric.name) - return result - -``` - ---- -## Training the model - -### Shallow model - -We're ready to try out our first, shallow, model! - - -```python -NUM_EPOCHS = 30 - -one_layer_model = RetrievalModel((32,)) -one_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05)) - -one_layer_history = one_layer_model.fit( - train_ratings, - validation_data=test_ratings, - validation_freq=5, - epochs=NUM_EPOCHS, -) -``` - -
-```
-Epoch 1/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 18s 18ms/step - loss: 0.2390
-Epoch 2/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0762
-Epoch 3/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0743
-Epoch 4/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0729
-Epoch 5/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 17s 221ms/step - loss: 0.0717 - val_loss: 0.0727 - val_sparse_top_k_categorical_accuracy: 0.1794
-Epoch 6/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0706
-Epoch 7/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0696
-Epoch 8/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0686
-Epoch 9/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0676
-Epoch 10/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0667 - val_loss: 0.0679 - val_sparse_top_k_categorical_accuracy: 0.2392
-Epoch 11/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0657
-Epoch 12/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0648
-Epoch 13/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0639
-Epoch 14/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0630
-Epoch 15/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0622 - val_loss: 0.0636 - val_sparse_top_k_categorical_accuracy: 0.2836
-Epoch 16/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0614
-Epoch 17/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0607
-Epoch 18/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0600
-Epoch 19/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0594
-Epoch 20/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0589 - val_loss: 0.0605 - val_sparse_top_k_categorical_accuracy: 0.3118
-Epoch 21/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0584
-Epoch 22/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0579
-Epoch 23/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0575
-Epoch 24/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0571
-Epoch 25/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0567 - val_loss: 0.0586 - val_sparse_top_k_categorical_accuracy: 0.3219
-Epoch 26/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0564
-Epoch 27/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0561
-Epoch 28/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0559
-Epoch 29/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556
-Epoch 30/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0554 - val_loss: 0.0574 - val_sparse_top_k_categorical_accuracy: 0.3216
-```
-</div>
-
-This gives us a top-100 accuracy of around 0.30. We can use this as a reference
-point for evaluating deeper models.
-
-### Deeper model
-
-What about a deeper model with two layers?
-
-
-```python
-two_layer_model = RetrievalModel((64, 32))
-two_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
-two_layer_history = two_layer_model.fit(
-    train_ratings,
-    validation_data=test_ratings,
-    validation_freq=5,
-    epochs=NUM_EPOCHS,
-)
-```
-
-<div class="k-default-codeblock">
-```
-Epoch 1/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - loss: 0.2010
-Epoch 2/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0758
-Epoch 3/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0743
-Epoch 4/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0730
-Epoch 5/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.0718 - val_loss: 0.0725 - val_sparse_top_k_categorical_accuracy: 0.1145
-Epoch 6/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0707
-Epoch 7/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0695
-Epoch 8/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0684
-Epoch 9/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0673
-Epoch 10/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0662 - val_loss: 0.0670 - val_sparse_top_k_categorical_accuracy: 0.2066
-Epoch 11/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0650
-Epoch 12/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0639
-Epoch 13/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0629
-Epoch 14/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0619
-Epoch 15/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0610 - val_loss: 0.0622 - val_sparse_top_k_categorical_accuracy: 0.2694
-Epoch 16/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0603
-Epoch 17/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596
-Epoch 18/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589
-Epoch 19/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0584
-Epoch 20/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0578 - val_loss: 0.0594 - val_sparse_top_k_categorical_accuracy: 0.2793
-Epoch 21/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574
-Epoch 22/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0570
-Epoch 23/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566
-Epoch 24/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0563
-Epoch 25/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0560 - val_loss: 0.0579 - val_sparse_top_k_categorical_accuracy: 0.2896
-Epoch 26/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557
-Epoch 27/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555
-Epoch 28/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552
-Epoch 29/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550
-Epoch 30/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0548 - val_loss: 0.0570 - val_sparse_top_k_categorical_accuracy: 0.2964
-```
-</div>
-
-While the deeper model seems to learn a bit better than the shallow model at
-first, the difference becomes minimal towards the end of training. We can plot
-the validation accuracy curves to illustrate this:
-
-
-```python
-METRIC = "val_sparse_top_k_categorical_accuracy"
-num_validation_runs = len(one_layer_history.history[METRIC])
-epochs = [(x + 1) * 5 for x in range(num_validation_runs)]
-
-plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer")
-plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers")
-plt.title("Accuracy vs epoch")
-plt.xlabel("epoch")
-plt.ylabel("Top-100 accuracy")
-plt.legend()
-plt.show()
-```
-
-
-
-![png](/img/examples/keras_rs/deep_recommender/deep_recommender_57_0.png)
-
-
-
-Deeper models are not necessarily better. The following model extends the depth
-to three layers:
-
-
-```python
-three_layer_model = RetrievalModel((128, 64, 32))
-three_layer_model.compile(optimizer=keras.optimizers.Adagrad(0.05))
-three_layer_history = three_layer_model.fit(
-    train_ratings,
-    validation_data=test_ratings,
-    validation_freq=5,
-    epochs=NUM_EPOCHS,
-)
-```
-
-<div class="k-default-codeblock">
-```
-Epoch 1/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - loss: 0.1843
-Epoch 2/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0760
-Epoch 3/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0745
-Epoch 4/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0733
-Epoch 5/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.0721 - val_loss: 0.0726 - val_sparse_top_k_categorical_accuracy: 0.1402
-Epoch 6/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0710
-Epoch 7/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0699
-Epoch 8/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0688
-Epoch 9/30
-80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0677
-```
-</div>
-
-<div class="k-default-codeblock">
-``` -Epoch 10/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0654 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0665 - -
-``` - -``` -
- 30/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0663 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0663 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0664 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0665 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0665 - val_loss: 0.0667 - val_sparse_top_k_categorical_accuracy: 0.2197 - - -
-``` -Epoch 11/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0640 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0653 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0652 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0652 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0652 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0654 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0654 - - -
-``` -Epoch 12/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0626 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0642 - -
-``` - -``` -
- 34/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0641 - -
-``` - -``` -
- 49/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0640 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0641 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0643 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0643 - - -
-``` -Epoch 13/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0613 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0631 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0630 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0630 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0630 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0632 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0633 - - -
-``` -Epoch 14/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0601 - -
-``` - -``` -
- 18/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0620 - -
-``` - -``` -
- 35/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0620 - -
-``` - -``` -
- 52/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0619 - -
-``` - -``` -
- 69/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0621 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0623 - - -
-``` -Epoch 15/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0590 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0611 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0610 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0610 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0611 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0613 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0613 - val_loss: 0.0618 - val_sparse_top_k_categorical_accuracy: 0.2900 - - -
-``` -Epoch 16/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0580 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0603 - -
-``` - -``` -
- 33/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0602 - -
-``` - -``` -
- 48/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0602 - -
-``` - -``` -
- 64/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0603 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0605 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0605 - - -
-``` -Epoch 17/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0572 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0595 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0595 - -
-``` - -``` -
- 65/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0596 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0597 - - -
-``` -Epoch 18/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0564 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0588 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0588 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0589 - -
-``` - -``` -
- 77/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0590 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0591 - - -
-``` -Epoch 19/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 0.0557 - -
-``` - -``` -
- 17/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0583 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0583 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0582 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0583 - -
-``` - -``` -
- 79/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0585 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0585 - - -
-``` -Epoch 20/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0551 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0578 - -
-``` - -``` -
- 30/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0578 - -
-``` - -``` -
- 44/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0577 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0578 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0579 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0580 - val_loss: 0.0591 - val_sparse_top_k_categorical_accuracy: 0.3015 - - -
-``` -Epoch 21/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0546 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574 - -
-``` - -``` -
- 32/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0573 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0573 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0574 - -
-``` - -``` -
- 77/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0575 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0575 - - -
-``` -Epoch 22/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0541 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0570 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0569 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0569 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0570 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0571 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0571 - - -
-``` -Epoch 23/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0537 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0565 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0566 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0567 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0567 - - -
-``` -Epoch 24/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0533 - -
-``` - -``` -
- 15/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0563 - -
-``` - -``` -
- 29/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0562 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0562 - -
-``` - -``` -
- 59/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0563 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0564 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0564 - - -
-``` -Epoch 25/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0530 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0560 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0560 - -
-``` - -``` -
- 45/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0559 - -
-``` - -``` -
- 60/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0560 - -
-``` - -``` -
- 75/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0561 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0561 - val_loss: 0.0577 - val_sparse_top_k_categorical_accuracy: 0.3049 - - -
-``` -Epoch 26/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0527 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0558 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0557 - -
-``` - -``` -
- 65/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0558 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0559 - - -
-``` -Epoch 27/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0524 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0555 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0556 - - -
-``` -Epoch 28/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0522 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0553 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0553 - -
-``` - -``` -
- 48/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 63/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0553 - -
-``` - -``` -
- 78/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0554 - - -
-``` -Epoch 29/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0520 - -
-``` - -``` -
- 16/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0551 - -
-``` - -``` -
- 31/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0551 - -
-``` - -``` -
- 47/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0551 - -
-``` - -``` -
- 62/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0551 - -
-``` - -``` -
- 77/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0552 - - -
-``` -Epoch 30/30 - -``` -
- - 1/80 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step - loss: 0.0517 - -
-``` - -``` -
- 15/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0550 - -
-``` - -``` -
- 29/80 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.0549 - -
-``` - -``` -
- 46/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0549 - -
-``` - -``` -
- 61/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0549 - -
-``` - -``` -
- 76/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0550 - -
-``` - -``` -
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0550 - val_loss: 0.0569 - val_sparse_top_k_categorical_accuracy: 0.3072 - - -We don't really see an improvement over the shallow model: - - -```python -plt.plot(epochs, one_layer_history.history[METRIC], label="1 layer") -plt.plot(epochs, two_layer_history.history[METRIC], label="2 layers") -plt.plot(epochs, three_layer_history.history[METRIC], label="3 layers") -plt.title("Accuracy vs epoch") -plt.xlabel("epoch") -plt.ylabel("Top-100 accuracy") -plt.legend() -plt.show() -``` - - - -![png](/img/examples/keras_rs/deep_recommender/deep_recommender_61_0.png) - - - -This is a good illustration of the fact that deeper and larger models, while -capable of superior performance, often require very careful tuning. For example, -throughout this tutorial we used a single, fixed learning rate. Alternative -choices may give very different results and are worth exploring. - -With appropriate tuning and sufficient data, the effort put into building larger -and deeper models is in many cases well worth it: larger models can lead to -substantial improvements in prediction accuracy. - ---- -## Next Steps - -In this tutorial we expanded our retrieval model with dense layers and -activation functions. To see how to create a model that can perform not only -retrieval tasks but also rating tasks, take a look at the multitask tutorial. diff --git a/templates/keras_rs/examples/dlrm.md b/templates/keras_rs/examples/dlrm.md deleted file mode 100644 index 46131c0bad..0000000000 --- a/templates/keras_rs/examples/dlrm.md +++ /dev/null @@ -1,520 +0,0 @@ -# Ranking with Deep Learning Recommendation Model - -**Author:** [Harshith Kulkarni](https://github.com/kharshith-k)
-**Date created:** 2025/06/02<br>
-**Last modified:** 2025/09/04<br>
-**Description:** Rank movies with DLRM using KerasRS.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/dlrm.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/dlrm.py)
-
-
-
----
-## Introduction
-
-This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to
-effectively learn the relationships between items and user preferences using a
-dot-product interaction mechanism. For more details, please refer to the
-[DLRM](https://arxiv.org/abs/1906.00091) paper.
-
-DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and
-is particularly effective at processing both categorical and continuous (sparse/dense)
-input features. The architecture consists of three main components: dedicated input
-layers to handle diverse features (typically embedding layers for categorical features),
-a dot-product interaction layer to explicitly model feature interactions, and a
-Multi-Layer Perceptron (MLP) to capture implicit feature relationships.
-
-The dot-product interaction layer lies at the heart of DLRM, efficiently computing
-pairwise interactions between different feature embeddings. This contrasts with models
-like the Deep & Cross Network (DCN), which can treat elements within a feature vector as
-independent units, potentially leading to a higher-dimensional space and increased
-computational cost. The MLP is a standard feedforward network. The DLRM is formed by
-combining the interaction layer and the MLP.
-
-The following image illustrates the DLRM architecture:
-
-![DLRM Architecture](https://raw.githubusercontent.com/kharshith-k/keras-io/refs/heads/keras-rs-examples/examples/keras_rs/img/dlrm/dlrm_architecture.gif)
-
-Now that we have a foundational understanding of DLRM's architecture and key
-characteristics, let's dive into the code. We will train a DLRM on a real-world dataset
-to demonstrate its capability to learn meaningful feature interactions. Let's begin by
-setting the backend to TensorFlow and organizing our imports.
-
-
-```python
-!pip install -q keras-rs
-```
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "tensorflow"  # or `"jax"`/`"torch"`
-
-import keras
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorflow as tf
-import tensorflow_datasets as tfds
-from mpl_toolkits.axes_grid1 import make_axes_locatable
-
-import keras_rs
-```
-
-Let's also define variables that will be reused throughout the example.
-
-
-```python
-MOVIELENS_CONFIG = {
-    # features
-    "continuous_features": [
-        "raw_user_age",
-        "hour_of_day_sin",
-        "hour_of_day_cos",
-        "hour_of_week_sin",
-        "hour_of_week_cos",
-    ],
-    "categorical_int_features": [
-        "user_gender",
-    ],
-    "categorical_str_features": [
-        "user_zip_code",
-        "user_occupation_text",
-        "movie_id",
-        "user_id",
-    ],
-    # model
-    "embedding_dim": 8,
-    "mlp_dim": 8,
-    "deep_net_num_units": [192, 192, 192],
-    # training
-    "learning_rate": 1e-4,
-    "num_epochs": 30,
-    "batch_size": 8192,
-}
-```
-
-Here, we define a helper function for visualizing the strength of the learned
-feature interactions, in order to better understand the model's functioning.
-We also define a function for compiling, training and evaluating a given model.
- - -```python - -def plot_training_metrics(history): - """Graphs all metrics tracked in the history object.""" - plt.figure(figsize=(12, 6)) - - for metric_name, metric_values in history.history.items(): - plt.plot(metric_values, label=metric_name.replace("_", " ").title()) - - plt.title("Metrics over Epochs") - plt.xlabel("Epoch") - plt.ylabel("Metric Value") - plt.legend() - plt.grid(True) - - -def visualize_layer(matrix, features, cmap=plt.cm.Blues): - - im = plt.matshow( - matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5] - ) - - ax = plt.gca() - divider = make_axes_locatable(plt.gca()) - cax = divider.append_axes("right", size="5%", pad=0.05) - plt.colorbar(im, cax=cax) - cax.tick_params(labelsize=10) - - # Set tick locations explicitly before setting labels - ax.set_xticks(np.arange(len(features))) - ax.set_yticks(np.arange(len(features))) - - ax.set_xticklabels(features, rotation=45, fontsize=5) - ax.set_yticklabels(features, fontsize=5) - - plt.show() - - -def train_and_evaluate( - learning_rate, - epochs, - train_data, - test_data, - model, - plot_metrics=False, -): - optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0) - loss = keras.losses.MeanSquaredError() - rmse = keras.metrics.RootMeanSquaredError() - - model.compile( - optimizer=optimizer, - loss=loss, - metrics=[rmse], - ) - - history = model.fit( - train_data, - epochs=epochs, - verbose=1, - ) - if plot_metrics: - plot_training_metrics(history) - - results = model.evaluate(test_data, return_dict=True, verbose=1) - rmse_value = results["root_mean_squared_error"] - - return rmse_value, model.count_params() - - -def print_stats(rmse_list, num_params, model_name): - # Report metrics. - num_trials = len(rmse_list) - avg_rmse = np.mean(rmse_list) - std_rmse = np.std(rmse_list) - - if num_trials == 1: - print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}") - else: - print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}") - -``` - ---- -## Real-world example - -Let's use the MovieLens 100K dataset. This dataset is used to train models to -predict users' movie ratings, based on user-related features and movie-related -features. - -### Preparing the dataset - -The dataset processing steps here are similar to what's given in the -[basic ranking](/keras_rs/examples/basic_ranking/) -tutorial. Let's load the dataset, and keep only the useful columns. - - -```python -ratings_ds = tfds.load("movielens/100k-ratings", split="train") - - -def preprocess_features(x): - """Extracts and cyclically encodes timestamp features.""" - features = { - "movie_id": x["movie_id"], - "user_id": x["user_id"], - "user_gender": tf.cast(x["user_gender"], dtype=tf.int32), - "user_zip_code": x["user_zip_code"], - "user_occupation_text": x["user_occupation_text"], - "raw_user_age": tf.cast(x["raw_user_age"], dtype=tf.float32), - } - label = tf.cast(x["user_rating"], dtype=tf.float32) - - # The timestamp is in seconds since the epoch. 
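-    # A raw timestamp is unbounded and carries no notion of periodicity, so
-    # below we re-express it as sine/cosine pairs: times that are close on
-    # the clock (e.g. 23:59 and 00:01) end up close in feature space too.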
- timestamp = tf.cast(x["timestamp"], dtype=tf.float32) - - # Constants for time periods - SECONDS_IN_HOUR = 3600.0 - HOURS_IN_DAY = 24.0 - HOURS_IN_WEEK = 168.0 - - # Calculate hour of day and encode it - hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY - features["hour_of_day_sin"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY) - features["hour_of_day_cos"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY) - - # Calculate hour of week and encode it - hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK - features["hour_of_week_sin"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK) - features["hour_of_week_cos"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK) - - return features, label - - -# Apply the new preprocessing function -ratings_ds = ratings_ds.map(preprocess_features) -``` - -For every categorical feature, let's get the list of unique values, i.e., vocabulary, so -that we can use that for the embedding layer. - - -```python -vocabularies = {} -for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] -): - vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name]) - vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary))) -``` - -One thing we need to do is to use `keras.layers.StringLookup` and -`keras.layers.IntegerLookup` to convert all the categorical features into indices, which -can -then be fed into embedding layers. - - -```python -lookup_layers = {} -lookup_layers.update( - { - feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature]) - for feature in MOVIELENS_CONFIG["categorical_int_features"] - } -) -lookup_layers.update( - { - feature: keras.layers.StringLookup(vocabulary=vocabularies[feature]) - for feature in MOVIELENS_CONFIG["categorical_str_features"] - } -) -``` - -Let's normalize all the continuous features, so that we can use that for the MLP layers. - - -```python -normalization_layers = {} -for feature_name in MOVIELENS_CONFIG["continuous_features"]: - normalization_layers[feature_name] = keras.layers.Normalization(axis=-1) - -training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x) - -for feature_name in MOVIELENS_CONFIG["continuous_features"]: - feature_ds = training_data_for_adaptation.map( - lambda x: tf.expand_dims(x[feature_name], axis=-1) - ) - normalization_layers[feature_name].adapt(feature_ds) - -ratings_ds = ratings_ds.map( - lambda x, y: ( - { - **{ - feature_name: lookup_layers[feature_name](x[feature_name]) - for feature_name in vocabularies - }, - # Apply the adapted normalization layers to the continuous features. - **{ - feature_name: tf.squeeze( - normalization_layers[feature_name]( - tf.expand_dims(x[feature_name], axis=-1) - ), - axis=-1, - ) - for feature_name in MOVIELENS_CONFIG["continuous_features"] - }, - }, - y, - ) -) -``` - -Let's split our data into train and test sets. We also use `cache()` and -`prefetch()` for better performance. - - -```python -ratings_ds = ratings_ds.shuffle(100_000) - -train_ds = ( - ratings_ds.take(80_000) - .batch(MOVIELENS_CONFIG["batch_size"]) - .cache() - .prefetch(tf.data.AUTOTUNE) -) -test_ds = ( - ratings_ds.skip(80_000) - .batch(MOVIELENS_CONFIG["batch_size"]) - .take(20_000) - .cache() - .prefetch(tf.data.AUTOTUNE) -) -``` - -### Building the model - -The model will have embedding layers, followed by DotInteraction and feedforward -layers. 
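-Before assembling the full model, it is worth seeing what the dot-product
-interaction computes on toy inputs. The following NumPy sketch illustrates the
-idea only; the model below uses the actual `keras_rs.layers.DotInteraction`
-layer, whose exact output (for example, whether self-interactions are kept)
-depends on its configuration.
-
-
-```python
-import numpy as np
-
-# Three toy "feature embeddings" for a batch of 2 examples, embedding_dim 4.
-rng = np.random.default_rng(0)
-embeddings = [rng.normal(size=(2, 4)) for _ in range(3)]
-
-stacked = np.stack(embeddings, axis=1)           # (batch, num_features, dim)
-pairwise = stacked @ stacked.transpose(0, 2, 1)  # all pairwise dot products
-
-# Keep one dot product per distinct feature pair (strictly lower triangle).
-i, j = np.tril_indices(3, k=-1)
-interactions = pairwise[:, i, j]
-print(interactions.shape)  # (2, 3): num_features * (num_features - 1) / 2
-```
-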
- - -```python - -class DLRM(keras.Model): - def __init__( - self, - dense_num_units_lst, - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - mlp_dim=MOVIELENS_CONFIG["mlp_dim"], - **kwargs, - ): - super().__init__(**kwargs) - - self.embedding_layers = {} - for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] - ): - vocab_size = len(vocabularies[feature_name]) + 1 # +1 for OOV token - self.embedding_layers[feature_name] = keras.layers.Embedding( - input_dim=vocab_size, - output_dim=embedding_dim, - ) - - self.bottom_mlp = keras.Sequential( - [ - keras.layers.Dense(mlp_dim, activation="relu"), - keras.layers.Dense(embedding_dim), # Output must match embedding_dim - ] - ) - - self.dot_layer = keras_rs.layers.DotInteraction() - - self.top_mlp = [] - for num_units in dense_num_units_lst: - self.top_mlp.append(keras.layers.Dense(num_units, activation="relu")) - - self.output_layer = keras.layers.Dense(1) - - self.dense_num_units_lst = dense_num_units_lst - self.embedding_dim = embedding_dim - - def call(self, inputs): - embeddings = [] - for feature_name in ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] - ): - embedding = self.embedding_layers[feature_name](inputs[feature_name]) - embeddings.append(embedding) - - # Process all continuous features together. - continuous_inputs = [] - for feature_name in MOVIELENS_CONFIG["continuous_features"]: - # Reshape each feature to (batch_size, 1) - feature = keras.ops.reshape( - keras.ops.cast(inputs[feature_name], dtype="float32"), (-1, 1) - ) - continuous_inputs.append(feature) - - # Concatenate into a single tensor: (batch_size, num_continuous_features) - concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1) - - # Pass through the Bottom MLP to get one combined vector. - processed_continuous = self.bottom_mlp(concatenated_continuous) - - # Combine with categorical embeddings. Note: we add a list containing the - # single tensor. - combined_features = embeddings + [processed_continuous] - - # Pass the list of features to the DotInteraction layer. - x = self.dot_layer(combined_features) - - for layer in self.top_mlp: - x = layer(x) - - x = self.output_layer(x) - - return x - - -dot_network = DLRM( - dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"], - embedding_dim=MOVIELENS_CONFIG["embedding_dim"], - mlp_dim=MOVIELENS_CONFIG["mlp_dim"], -) - -rmse, dot_network_num_params = train_and_evaluate( - learning_rate=MOVIELENS_CONFIG["learning_rate"], - epochs=MOVIELENS_CONFIG["num_epochs"], - train_data=train_ds, - test_data=test_ds, - model=dot_network, - plot_metrics=True, -) -print_stats( - rmse_list=[rmse], - num_params=dot_network_num_params, - model_name="Dot Network", -) -``` - -![png](/img/examples/keras_rs/dlrm/dlrm_19_158.png) - - -### Visualizing feature interactions - -The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense -layer. Instead, its function is to compute the dot product between the embedding vectors -of your features. - -To visualize the strength of these interactions, we can calculate a matrix representing -the pairwise interaction strength between all feature embeddings. A common way to do this -is to take the dot product of the embedding matrices for each pair of features and then -aggregate the result into a single value (like the mean of the absolute values) that -represents the overall interaction strength. 
- - -```python - -def get_dot_interaction_matrix(model, categorical_features, continuous_features): - # The new feature list for the plot labels - all_feature_names = categorical_features + ["all_continuous_features"] - num_features = len(all_feature_names) - - # Store all feature outputs in the correct order. - all_feature_outputs = [] - - # Get outputs for categorical features from embedding layers (unchanged). - for feature_name in categorical_features: - embedding = model.embedding_layers[feature_name](keras.ops.array([0])) - all_feature_outputs.append(embedding) - - # Get a single output for ALL continuous features from the shared MLP. - num_continuous_features = len(continuous_features) - # Create a dummy input of zeros for the MLP - dummy_continuous_input = keras.ops.zeros((1, num_continuous_features)) - processed_continuous = model.bottom_mlp(dummy_continuous_input) - all_feature_outputs.append(processed_continuous) - - interaction_matrix = np.zeros((num_features, num_features)) - - # Iterate through each pair to calculate interaction strength. - for i in range(num_features): - for j in range(num_features): - interaction = keras.ops.dot( - all_feature_outputs[i], keras.ops.transpose(all_feature_outputs[j]) - ) - interaction_strength = keras.ops.convert_to_numpy(np.abs(interaction))[0][0] - interaction_matrix[i, j] = interaction_strength - - return interaction_matrix, all_feature_names - - -# Get the list of categorical feature names. -categorical_feature_names = ( - MOVIELENS_CONFIG["categorical_int_features"] - + MOVIELENS_CONFIG["categorical_str_features"] -) - -# Calculate the interaction matrix with the corrected function. -interaction_matrix, feature_names = get_dot_interaction_matrix( - model=dot_network, - categorical_features=categorical_feature_names, - continuous_features=MOVIELENS_CONFIG["continuous_features"], -) - -# Visualize the matrix as a heatmap. -print("\nVisualizing the feature interaction strengths:") -visualize_layer(interaction_matrix, feature_names) -``` - -![png](/img/examples/keras_rs/dlrm/dlrm_21_1.png) - - diff --git a/templates/keras_rs/examples/listwise_ranking.md b/templates/keras_rs/examples/listwise_ranking.md deleted file mode 100644 index 89f302b3fd..0000000000 --- a/templates/keras_rs/examples/listwise_ranking.md +++ /dev/null @@ -1,667 +0,0 @@ -# List-wise ranking - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28<br>
-**Last modified:** 2025/04/28<br>
-**Description:** Rank movies using pairwise losses instead of pointwise losses. - - - [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/listwise_ranking.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/listwise_ranking.py) - - - ---- -## Introduction - -In our -[basic ranking tutorial](/keras_rs/examples/basic_ranking/), we explored a model -that learned to predict ratings for specific user-movie combinations. This model -took (user, movie) pairs as input and was trained using mean-squared error to -precisely predict the rating a user might give to a movie. - -However, solely optimizing a model's accuracy in predicting individual movie -scores isn't always the most effective strategy for developing ranking systems. -For ranking models, pinpoint accuracy in predicting scores is less critical than -the model's capability to generate an ordered list of items that aligns with a -user's preferences. In essence, the relative order of items matters more than -the exact predicted values. - -Instead of focusing on the model's predictions for individual query-item pairs -(a pointwise approach), we can optimize the model based on its ability to -correctly order items. One common method for this is pairwise ranking. In this -approach, the model learns by comparing pairs of items (e.g., item A and item B) -and determining which one should be ranked higher for a given user or query. The -goal is to minimize the number of incorrectly ordered pairs. - -Let's begin by importing all the necessary libraries. - - -```python -import os - -os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` - -import collections - -import keras -import numpy as np -import tensorflow as tf # Needed only for the dataset -import tensorflow_datasets as tfds -from keras import ops - -import keras_rs -``` - -Let's define some hyperparameters here. - - -```python -# Data args -TRAIN_NUM_LIST_PER_USER = 50 -TEST_NUM_LIST_PER_USER = 1 -NUM_EXAMPLES_PER_LIST = 5 - -# Model args -EMBEDDING_DIM = 32 - -# Train args -BATCH_SIZE = 1024 -EPOCHS = 5 -LEARNING_RATE = 0.1 -``` - ---- -## Preparing the dataset - -We use the MovieLens dataset. The data loading and processing steps are similar -to previous tutorials, so, we will only discuss the differences here. - - -```python -# Ratings data. -ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") - -users_count = ( - ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) - .reduce(tf.constant(0, tf.int32), tf.maximum) - .numpy() -) -movies_count = movies.cardinality().numpy() - - -def preprocess_rating(x): - return { - "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), - "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), - # Normalise ratings between 0 and 1. - "user_rating": (x["user_rating"] - 1.0) / 4.0, - } - - -shuffled_ratings = ratings.map(preprocess_rating).shuffle( - 100_000, seed=42, reshuffle_each_iteration=False -) -train_ratings = shuffled_ratings.take(70_000) -val_ratings = shuffled_ratings.skip(70_000).take(15_000) -test_ratings = shuffled_ratings.skip(85_000).take(15_000) -``` - -So far, we've replicated what we have in the basic ranking tutorial. - -However, this existing dataset is not directly applicable to list-wise -optimization. 
List-wise optimization requires, for each user, a list of movies
-they have rated, allowing the model to learn from the relative orderings within
-that list. The MovieLens 100K dataset, in its original form, provides individual
-rating instances (one user, one movie, one rating per example), rather than
-these aggregated user-specific lists.
-
-To enable listwise optimization, we need to restructure the dataset. This
-involves transforming it so that each example represents a single user ID
-accompanied by a list of movies that user has rated. Within these lists, some
-movies will naturally be ranked higher by the user (as evidenced by their
-ratings) than others. The primary objective for our model will then be to learn
-to predict item orderings that correspond to these observed user preferences.
-
-Let's start by getting the entire list of movies and corresponding ratings for
-every user. We remove users who have rated fewer than `NUM_EXAMPLES_PER_LIST`
-movies.
-
-
-```python
-
-def get_movie_sequence_per_user(ratings, min_examples_per_list):
-    """Gets movieID sequences and ratings for every user."""
-    sequences = collections.defaultdict(list)
-
-    for sample in ratings:
-        user_id = sample["user_id"]
-        movie_id = sample["movie_id"]
-        user_rating = sample["user_rating"]
-
-        sequences[int(user_id.numpy())].append(
-            {
-                "movie_id": int(movie_id.numpy()),
-                "user_rating": float(user_rating.numpy()),
-            }
-        )
-
-    # Remove lists with < `min_examples_per_list` number of elements.
-    sequences = {
-        user_id: sequence
-        for user_id, sequence in sequences.items()
-        if len(sequence) >= min_examples_per_list
-    }
-
-    return sequences
-
-```
-
-We now build the training data by sampling, for each user, a list of
-`NUM_EXAMPLES_PER_LIST` movies from the movies that user rated (the sampling
-code follows the sketch below).
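-Concretely, after this restructuring, a single training example pairs one user
-with a fixed-size list of movies and the corresponding ratings. An illustrative
-sketch (the IDs and ratings below are made up):
-
-
-```python
-# Illustrative only: the structure of one listwise training example.
-example_inputs = {
-    "user_id": 42,                          # one user per example
-    "movie_id": [409, 237, 131, 941, 543],  # NUM_EXAMPLES_PER_LIST rated movies
-}
-example_label = [1.0, 0.75, 0.5, 0.25, 0.0]  # their ratings, rescaled to [0, 1]
-```
-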
- - -```python - -def sample_sublist_from_list( - lst, - num_examples_per_list, -): - """Random selects `num_examples_per_list` number of elements from list.""" - - indices = np.random.choice( - range(len(lst)), - size=num_examples_per_list, - replace=False, - ) - - samples = [lst[i] for i in indices] - return samples - - -def get_examples( - sequences, - num_list_per_user, - num_examples_per_list, -): - inputs = { - "user_id": [], - "movie_id": [], - } - labels = [] - for user_id, user_list in sequences.items(): - sampled_list = sample_sublist_from_list( - user_list, - num_examples_per_list, - ) - - inputs["user_id"].append(user_id) - inputs["movie_id"].append( - tf.convert_to_tensor([f["movie_id"] for f in sampled_list]) - ) - labels.append(tf.convert_to_tensor([f["user_rating"] for f in sampled_list])) - - return ( - {"user_id": inputs["user_id"], "movie_id": inputs["movie_id"]}, - labels, - ) - - -train_sequences = get_movie_sequence_per_user( - ratings=train_ratings, min_examples_per_list=NUM_EXAMPLES_PER_LIST -) -train_examples = get_examples( - train_sequences, - num_list_per_user=TRAIN_NUM_LIST_PER_USER, - num_examples_per_list=NUM_EXAMPLES_PER_LIST, -) -train_ds = tf.data.Dataset.from_tensor_slices(train_examples) - -val_sequences = get_movie_sequence_per_user( - ratings=val_ratings, min_examples_per_list=5 -) -val_examples = get_examples( - val_sequences, - num_list_per_user=TEST_NUM_LIST_PER_USER, - num_examples_per_list=NUM_EXAMPLES_PER_LIST, -) -val_ds = tf.data.Dataset.from_tensor_slices(val_examples) - -test_sequences = get_movie_sequence_per_user( - ratings=test_ratings, min_examples_per_list=5 -) -test_examples = get_examples( - test_sequences, - num_list_per_user=TEST_NUM_LIST_PER_USER, - num_examples_per_list=NUM_EXAMPLES_PER_LIST, -) -test_ds = tf.data.Dataset.from_tensor_slices(test_examples) -``` - -Batch up the dataset, and cache it. - - -```python -train_ds = train_ds.batch(BATCH_SIZE).cache() -val_ds = val_ds.batch(BATCH_SIZE).cache() -test_ds = test_ds.batch(BATCH_SIZE).cache() -``` - ---- -## Building the model - -We build a typical two-tower ranking model, similar to the -[basic ranking tutorial](/keras_rs/examples/basic_ranking/). -We have separate embedding layers for user ID and movie IDs. After obtaining -these embeddings, we concatenate them and pass them through a network of dense -layers. - -The only point of difference is that for movie IDs, we take a list of IDs -rather than just one movie ID. So, when we concatenate user ID embedding and -movie IDs' embeddings, we "repeat" the user ID 'NUM_EXAMPLES_PER_LIST' times so -as to get the same shape as the movie IDs' embeddings. - - -```python - -class RankingModel(keras.Model): - """Create the ranking model with the provided parameters. - - Args: - num_users: Number of entries in the user embedding table. - num_candidates: Number of entries in the candidate embedding table. - embedding_dimension: Output dimension for user and movie embedding tables. - """ - - def __init__( - self, - num_users, - num_candidates, - embedding_dimension=32, - **kwargs, - ): - super().__init__(**kwargs) - # Embedding table for users. - self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) - # Embedding table for candidates. - self.candidate_embedding = keras.layers.Embedding( - num_candidates, embedding_dimension - ) - # Predictions. - self.ratings = keras.Sequential( - [ - # Learn multiple dense layers. 
- keras.layers.Dense(256, activation="relu"), - keras.layers.Dense(64, activation="relu"), - # Make rating predictions in the final layer. - keras.layers.Dense(1), - ] - ) - - def build(self, input_shape): - self.user_embedding.build(input_shape["user_id"]) - self.candidate_embedding.build(input_shape["movie_id"]) - - output_shape = self.candidate_embedding.compute_output_shape( - input_shape["movie_id"] - ) - - self.ratings.build(list(output_shape[:-1]) + [2 * output_shape[-1]]) - - def call(self, inputs): - user_id, movie_id = inputs["user_id"], inputs["movie_id"] - user_embeddings = self.user_embedding(user_id) - candidate_embeddings = self.candidate_embedding(movie_id) - - list_length = ops.shape(movie_id)[-1] - user_embeddings_repeated = ops.repeat( - ops.expand_dims(user_embeddings, axis=1), - repeats=list_length, - axis=1, - ) - concatenated_embeddings = ops.concatenate( - [user_embeddings_repeated, candidate_embeddings], axis=-1 - ) - - scores = self.ratings(concatenated_embeddings) - scores = ops.squeeze(scores, axis=-1) - - return scores - - def compute_output_shape(self, input_shape): - return (input_shape[0], input_shape[1]) - -``` - -Let's instantiate, compile and train our model. We will train two models: -one with vanilla mean-squared error, and the other with pairwise hinge loss. -For the latter, we will use `keras_rs.losses.PairwiseHingeLoss`. - -Pairwise losses compare pairs of items within each list, penalizing cases where -an item with a higher true label has a lower predicted score than an item with a -lower true label. This is why they are more suited for ranking tasks than -pointwise losses. - -To quantify these results, we compute nDCG. nDCG is a measure of ranking quality -that evaluates how well a system orders items based on relevance, giving more -importance to highly relevant items appearing at the top of the list and -normalizing the score against an ideal ranking. -To compute it, we just need to pass `keras_rs.metrics.NDCG()` as a metric to -`model.compile`. - - -```python -model_mse = RankingModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - embedding_dimension=EMBEDDING_DIM, -) -model_mse.compile( - loss=keras.losses.MeanSquaredError(), - metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")], - optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE), -) -model_mse.fit(train_ds, validation_data=val_ds, epochs=EPOCHS) -``` - -
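-While this model trains, it is worth spelling out why a purely pointwise loss
-can be a weak ranking objective: MSE scores each item in isolation, so a list
-can have a small error and still be ordered incorrectly. A minimal illustration
-with made-up numbers:
-
-
-```python
-import numpy as np
-
-labels = np.array([1.0, 0.5, 0.0])     # true relevance, higher is better
-scores = np.array([0.45, 0.55, 0.05])  # second item wrongly scored above first
-
-mse = np.mean((labels - scores) ** 2)  # ~0.10: small, despite the wrong order
-print(mse)
-```
-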
-```
-Epoch 1/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step - loss: 0.4960 - ndcg: 0.8892 - val_loss: 0.1187 - val_ndcg: 0.8846
-Epoch 2/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 0.1150 - ndcg: 0.8898 - val_loss: 0.0893 - val_ndcg: 0.8878
-Epoch 3/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.0876 - ndcg: 0.8884 - val_loss: 0.0864 - val_ndcg: 0.8857
-Epoch 4/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.0834 - ndcg: 0.8896 - val_loss: 0.0815 - val_ndcg: 0.8876
-Epoch 5/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.0794 - ndcg: 0.8887 - val_loss: 0.0810 - val_ndcg: 0.8868
-```
-
-And now, the model with pairwise hinge loss. - - -```python -model_hinge = RankingModel( - num_users=users_count + 1, - num_candidates=movies_count + 1, - embedding_dimension=EMBEDDING_DIM, -) -model_hinge.compile( - loss=keras_rs.losses.PairwiseHingeLoss(), - metrics=[keras_rs.metrics.NDCG(k=NUM_EXAMPLES_PER_LIST, name="ndcg")], - optimizer=keras.optimizers.Adagrad(learning_rate=LEARNING_RATE), -) -model_hinge.fit(train_ds, validation_data=val_ds, epochs=EPOCHS) -``` - -
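-Before looking at the training output, here is what a pairwise hinge loss
-computes, in a simplified scalar form. The actual
-`keras_rs.losses.PairwiseHingeLoss` additionally handles batching, masking and
-reduction, and its exact margin convention may differ:
-
-
-```python
-import numpy as np
-
-
-def pairwise_hinge(labels, scores, margin=1.0):
-    """Simplified pairwise hinge loss for a single list."""
-    valid_pairs = labels[:, None] > labels[None, :]  # pairs where i outranks j
-    score_diffs = scores[:, None] - scores[None, :]  # s_i - s_j
-    # Penalize every valid pair whose score gap falls short of the margin.
-    return np.sum(np.maximum(0.0, margin - score_diffs) * valid_pairs)
-
-
-labels = np.array([1.0, 0.5, 0.0])
-scores = np.array([0.45, 0.55, 0.05])  # the top two items are swapped
-print(pairwise_hinge(labels, scores))  # 2.2: the inverted pair dominates
-```
-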
-```
-Epoch 1/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 1.4067 - ndcg: 0.8933 - val_loss: 1.3927 - val_ndcg: 0.8930
-Epoch 2/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 1.4061 - ndcg: 0.8953 - val_loss: 1.3925 - val_ndcg: 0.8936
-Epoch 3/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 1.4054 - ndcg: 0.8977 - val_loss: 1.3923 - val_ndcg: 0.8941
-Epoch 4/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 1.4047 - ndcg: 0.8999 - val_loss: 1.3921 - val_ndcg: 0.8941
-Epoch 5/5
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 1.4041 - ndcg: 0.9004 - val_loss: 1.3919 - val_ndcg: 0.8940
-```
-
----
-## Evaluation
-
-Comparing the validation nDCG values, it is clear that the model trained with
-the pairwise hinge loss outperforms the other one. Let's make this observation
-more concrete by comparing results on the test set.
-
-
-```python
-ndcg_mse = model_mse.evaluate(test_ds, return_dict=True)["ndcg"]
-ndcg_hinge = model_hinge.evaluate(test_ds, return_dict=True)["ndcg"]
-print(ndcg_mse, ndcg_hinge)
-```
-
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 0.0805 - ndcg: 0.8886
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 914ms/step - loss: 1.3878 - ndcg: 0.8924
-0.8885537385940552 0.8924424052238464
-```
-
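-As a reminder of what this metric measures, here is a minimal sketch of nDCG
-using one common gain/discount convention (`keras_rs.metrics.NDCG` may differ
-in details such as the gain function or tie handling):
-
-
-```python
-import numpy as np
-
-
-def dcg(relevances):
-    positions = np.arange(1, len(relevances) + 1)
-    return np.sum((2.0**relevances - 1.0) / np.log2(positions + 1))
-
-
-def ndcg(relevances, scores):
-    ranked = relevances[np.argsort(scores)[::-1]]  # order by predicted score
-    ideal = np.sort(relevances)[::-1]              # best possible ordering
-    return dcg(ranked) / dcg(ideal)
-
-
-relevances = np.array([3.0, 2.0, 1.0, 0.0])
-print(ndcg(relevances, np.array([0.9, 0.8, 0.4, 0.2])))  # perfect order: 1.0
-print(ndcg(relevances, np.array([0.2, 0.4, 0.8, 0.9])))  # reversed: much lower
-```
-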
----
-## Prediction
-
-Now, let's rank some lists!
-
-Let's create a mapping from movie ID to title so that we can surface the titles
-for the ranked list.
-
-
-```python
-movie_id_to_movie_title = {
-    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
-}
-movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.
-
-user_id = 42
-movie_ids = [409, 237, 131, 941, 543]
-predictions = model_hinge.predict(
-    {
-        "user_id": keras.ops.array([user_id]),
-        "movie_id": keras.ops.array([movie_ids]),
-    }
-)
-predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=0))
-sorted_indices = np.argsort(predictions)
-sorted_movies = [movie_ids[i] for i in sorted_indices]
-
-for i, movie_id in enumerate(sorted_movies):
-    print(f"{i + 1}. ", movie_id_to_movie_title[movie_id])
-```
-
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 262ms/step
-1. b'Jack (1996)'
-2. b'Mis\xc3\xa9rables, Les (1995)'
-3. b'Jerry Maguire (1996)'
-4. b"Breakfast at Tiffany's (1961)"
-5. b'With Honors (1994)'
-```
-
-And we're all done! diff --git a/templates/keras_rs/examples/multi_task.md b/templates/keras_rs/examples/multi_task.md deleted file mode 100644 index 68d00a60e6..0000000000 --- a/templates/keras_rs/examples/multi_task.md +++ /dev/null @@ -1,1463 +0,0 @@ -# Multi-task recommenders: retrieval + ranking - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28<br>
-**Last modified:** 2025/04/28<br>
-**Description:** Using one model for both retrieval and ranking.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/multi_task.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/multi_task.py)
-
-
-
----
-## Introduction
-
-In the
-[basic retrieval](/keras_rs/examples/basic_retrieval/)
-and
-[basic ranking](/keras_rs/examples/basic_ranking/)
-tutorials, we created separate models for retrieval and ranking tasks,
-respectively. However, in many cases, building a single, joint model for
-multiple tasks can lead to better performance than creating distinct models for
-each task. This is especially true when dealing with data that is unevenly
-distributed — such as abundant data (e.g., clicks) versus sparse data
-(e.g., purchases, returns, or manual reviews). In such scenarios, a joint model
-can leverage representations learned from the abundant data to improve
-predictions on the sparse data, a technique known as transfer learning.
-For instance, [research](https://openreview.net/forum?id=SJxPVcSonN) shows that
-a model trained to predict user ratings from sparse survey data can be
-significantly enhanced by incorporating an auxiliary task using abundant click
-log data.
-
-In this example, we develop a multi-objective recommender system using the
-MovieLens dataset. We incorporate both implicit feedback (e.g., movie watches)
-and explicit feedback (e.g., ratings) to create a more robust and effective
-recommendation model. For the former, we predict "movie watches", i.e., whether
-a user has watched a movie, and for the latter, we predict the rating given by a
-user to a movie.
-
-Let's start by importing the necessary packages.
-
-
-```python
-import os

-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-
-import keras_rs
-```
-
----
-## Prepare the dataset
-
-We use the MovieLens dataset. The data loading and processing steps are similar
-to previous tutorials, so we will not discuss them in detail here.
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train")
-# Features of all the available movies.
-movies = tfds.load("movielens/100k-movies", split="train")
-```
-
-Get user and movie counts so that we can define embedding layers.
-
-
-```python
-users_count = (
-    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
-    .reduce(tf.constant(0, tf.int32), tf.maximum)
-    .numpy()
-)
-
-movies_count = movies.cardinality().numpy()
-```
-
-Our inputs are `"user_id"` and `"movie_id"`. Our label for the ranking task is
-`"user_rating"`, an integer between 1 and 5, which we rescale to `[0, 1]`.
-
-
-```python
-
-def preprocess_rating(x):
-    return (
-        {
-            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
-            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
-        },
-        (x["user_rating"] - 1.0) / 4.0,
-    )
-
-
-shuffled_ratings = ratings.map(preprocess_rating).shuffle(
-    100_000, seed=42, reshuffle_each_iteration=False
-)
-
-```
-
-Split the dataset into train-test sets.
- - -```python -train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() -test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() -``` - ---- -## Building the model - -We build the model in a similar way to the basic retrieval and basic ranking -guides. - -For the retrieval task (i.e., predicting whether a user watched a movie), -we compute the similarity of the corresponding user and movie embeddings, and -use cross entropy loss, where the positive pairs are labelled one, and all other -samples in the batch are considered "negatives". We report top-k accuracy for -this task. - -For the ranking task (i.e., given a user-movie pair, predict rating), we -concatenate user and movie embeddings and pass it to a dense module. We use -MSE loss here, and report the Root Mean Squared Error (RMSE). - -The final loss is a weighted combination of the two losses mentioned above, -where the weights are `"retrieval_loss_wt"` and `"ranking_loss_wt"`. These -weights decide which task the model will focus on. - - -```python - -class MultiTaskModel(keras.Model): - def __init__( - self, - num_users, - num_candidates, - embedding_dimension=32, - layer_sizes=(256, 128), - retrieval_loss_wt=1.0, - ranking_loss_wt=1.0, - **kwargs, - ): - super().__init__(**kwargs) - # Our query tower, simply an embedding table. - self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) - - # Our candidate tower, simply an embedding table. - self.candidate_embedding = keras.layers.Embedding( - num_candidates, embedding_dimension - ) - - # Rating model. - self.rating_model = tf.keras.Sequential( - [ - keras.layers.Dense(layer_size, activation="relu") - for layer_size in layer_sizes - ] - + [keras.layers.Dense(1)] - ) - - # The layer that performs the retrieval. - self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False) - - self.retrieval_loss_fn = keras.losses.CategoricalCrossentropy( - from_logits=True, - reduction="sum", - ) - self.ranking_loss_fn = keras.losses.MeanSquaredError() - - # Top-k accuracy for retrieval - self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy( - k=100, from_sorted_ids=True - ) - # RMSE for ranking - self.rmse_metric = keras.metrics.RootMeanSquaredError() - - # Attributes. - self.num_users = num_users - self.num_candidates = num_candidates - self.embedding_dimension = embedding_dimension - self.layer_sizes = layer_sizes - self.retrieval_loss_wt = retrieval_loss_wt - self.ranking_loss_wt = ranking_loss_wt - - def build(self, input_shape): - self.user_embedding.build(input_shape) - self.candidate_embedding.build(input_shape) - # In this case, the candidates are directly the movie embeddings. - # We take a shortcut and directly reuse the variable. - self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings - self.retrieval.build(input_shape) - - self.rating_model.build((None, 2 * self.embedding_dimension)) - - super().build(input_shape) - - def call(self, inputs, training=False): - # Unpack inputs. Note that we have the if condition throughout this - # `call()` method so that we can do a `.predict()` for the retrieval - # task. - user_id = inputs["user_id"] - if "movie_id" in inputs: - movie_id = inputs["movie_id"] - - result = {} - - # Get user, movie embeddings. 
- user_embeddings = self.user_embedding(user_id) - result["user_embeddings"] = user_embeddings - - if "movie_id" in inputs: - candidate_embeddings = self.candidate_embedding(movie_id) - result["candidate_embeddings"] = candidate_embeddings - - # Pass both embeddings through the rating block of the model. - rating = self.rating_model( - keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1) - ) - result["rating"] = rating - - if not training: - # Skip the retrieval of top movies during training as the - # predictions are not used. - result["predictions"] = self.retrieval(user_embeddings) - - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - user_embeddings = y_pred["user_embeddings"] - candidate_embeddings = y_pred["candidate_embeddings"] - - # 1. Retrieval - - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.matmul( - user_embeddings, - keras.ops.transpose(candidate_embeddings), - ) - - # Retrieval labels: One-hot vectors - num_users = keras.ops.shape(user_embeddings)[0] - num_candidates = keras.ops.shape(candidate_embeddings)[0] - retrieval_labels = keras.ops.eye(num_users, num_candidates) - # Retrieval loss - retrieval_loss = self.retrieval_loss_fn(retrieval_labels, scores, sample_weight) - - # 2. Ranking - ratings = y - pred_rating = y_pred["rating"] - - # Ranking labels are just ratings. - ranking_labels = keras.ops.expand_dims(ratings, -1) - # Ranking loss - ranking_loss = self.ranking_loss_fn(ranking_labels, pred_rating, sample_weight) - - # Total loss is a weighted combination of the two losses. - total_loss = ( - self.retrieval_loss_wt * retrieval_loss - + self.ranking_loss_wt * ranking_loss - ) - - return total_loss - - def compute_metrics(self, x, y, y_pred, sample_weight=None): - # RMSE can be computed irrespective of whether we are - # training/evaluating. - self.rmse_metric.update_state( - y, - y_pred["rating"], - sample_weight=sample_weight, - ) - - if "predictions" in y_pred: - # We are evaluating or predicting. Update `top_k_metric`. - movie_ids = x["movie_id"] - predictions = y_pred["predictions"] - # For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we - # only take top rated movies, and we put a weight of 0 for the rest. - rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32") - sample_weight = ( - rating_weight - if sample_weight is None - else keras.ops.multiply(rating_weight, sample_weight) - ) - self.top_k_metric.update_state( - movie_ids, predictions, sample_weight=sample_weight - ) - - return self.get_metrics_result() - else: - # We are training. `top_k_metric` is not updated and is zero, so - # don't report it. - result = self.get_metrics_result() - result.pop(self.top_k_metric.name) - return result - -``` - ---- -## Training and evaluating - -We will train three different models here. This can be done easily by passing -the correct loss weights: - -1. Rating-specialised model -2. Retrieval-specialised model -3. 
Multi-task model


```python
# Rating-specialised model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=1.0,
    retrieval_loss_wt=0.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)

# Retrieval-specialised model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=0.0,
    retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)

# Multi-task model
model = MultiTaskModel(
    num_users=users_count + 1,
    num_candidates=movies_count + 1,
    ranking_loss_wt=1.0,
    retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)

model.evaluate(test_ratings)
```
```
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 4s 13ms/step - loss: 0.1042 - root_mean_squared_error: 0.3180
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0777 - root_mean_squared_error: 0.2787
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0759 - root_mean_squared_error: 0.2755
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0735 - root_mean_squared_error: 0.2710
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.0713 - root_mean_squared_error: 0.2669
 20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 38ms/step - loss: 0.0670 - root_mean_squared_error: 0.2589 - sparse_top_k_categorical_accuracy: 0.0046
```
```
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 6861.2632 - root_mean_squared_error: 0.6933
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 6527.7217 - root_mean_squared_error: 0.6952
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6379.3403 - root_mean_squared_error: 0.7077
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6291.6152 - root_mean_squared_error: 0.7164
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6227.6514 - root_mean_squared_error: 0.7231
 20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step - loss: 6558.5298 - root_mean_squared_error: 0.7323 - sparse_top_k_categorical_accuracy: 0.0156
```
```
Epoch 1/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 10ms/step - loss: 6861.9297 - root_mean_squared_error: 0.3174
Epoch 2/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 6526.9360 - root_mean_squared_error: 0.2591
Epoch 3/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6379.2231 - root_mean_squared_error: 0.2537
Epoch 4/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6290.9727 - root_mean_squared_error: 0.2506
Epoch 5/5
 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6226.5605 - root_mean_squared_error: 0.2485
 20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 38ms/step - loss: 6554.1953 - root_mean_squared_error: 0.2492 - sparse_top_k_categorical_accuracy: 0.0158
```

```
[6555.712890625, 0.016953036189079285, 0.2508334815502167]
```
Let's put the metrics in a table and note our observations:

| Model                 | Top-K Accuracy (↑) | RMSE (↓) |
|-----------------------|--------------------|----------|
| rating-specialised    | 0.005              | 0.26     |
| retrieval-specialised | 0.020              | 0.78     |
| multi-task            | 0.022              | 0.25     |

As expected, the rating-specialised model has good RMSE, but poor top-k
accuracy. For the retrieval-specialised model, it's the opposite.

For the multi-task model, we notice that the model does well (or even slightly
better than the two specialised models) on both tasks. In general, we can
expect multi-task learning to bring about better results, especially when one
task has a data-abundant source, and the other task is trained on sparse data.

Now, let's make a prediction! We will first do a retrieval, and then, for the
retrieved list of movies, we will predict the rating using the same model.


```python
movie_id_to_movie_title = {
    int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = ""  # Because id 0 is not in the dataset.

user_id = 5
retrieved_movie_ids = model.predict(
    {
        "user_id": keras.ops.array([user_id]),
    }
)
retrieved_movie_ids = keras.ops.convert_to_numpy(retrieved_movie_ids["predictions"][0])
retrieved_movies = [movie_id_to_movie_title[x] for x in retrieved_movie_ids]
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 110ms/step
```

For these retrieved movies, we can now get the corresponding ratings.


```python
pred_ratings = model.predict(
    {
        "user_id": keras.ops.array([user_id] * len(retrieved_movie_ids)),
        "movie_id": keras.ops.array(retrieved_movie_ids),
    }
)["rating"]
pred_ratings = keras.ops.convert_to_numpy(keras.ops.squeeze(pred_ratings, axis=1))

for movie_id, prediction in zip(retrieved_movie_ids, pred_ratings):
    print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")
```

```
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 274ms/step
```
-``` -b'Blob, The (1958)': 2.01 -b'Mighty Morphin Power Rangers: The Movie (1995)': 2.03 -b'Flintstones, The (1994)': 2.18 -b'Beverly Hillbillies, The (1993)': 1.89 -b'Lawnmower Man, The (1992)': 2.57 -b'Hot Shots! Part Deux (1993)': 2.28 -b'Street Fighter (1994)': 1.84 -b'Cabin Boy (1994)': 1.94 -b'Little Rascals, The (1994)': 2.12 -b'Jaws 3-D (1983)': 2.27 - -``` -
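To wrap up, the two prediction steps above can be combined into one call. The
sketch below is a hypothetical `recommend` helper, built purely from the
`model.predict` calls already shown; nothing here is a KerasRS API.


```python
# A hypothetical convenience wrapper around the two predict calls above:
# it retrieves the top movies for a user and then rates each of them.
def recommend(model, user_id, movie_id_to_movie_title):
    # Step 1: retrieve the top movies for this user.
    retrieved = model.predict({"user_id": keras.ops.array([user_id])})
    movie_ids = keras.ops.convert_to_numpy(retrieved["predictions"][0])

    # Step 2: predict a rating for every retrieved movie.
    ratings = model.predict(
        {
            "user_id": keras.ops.array([user_id] * len(movie_ids)),
            "movie_id": keras.ops.array(movie_ids),
        }
    )["rating"]
    ratings = keras.ops.convert_to_numpy(keras.ops.squeeze(ratings, axis=1))

    # Scale back to the 5-star range, matching the `5.0 * prediction`
    # used in the print loop above.
    return [
        (movie_id_to_movie_title[movie_id], 5.0 * rating)
        for movie_id, rating in zip(movie_ids, ratings)
    ]
```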
\ No newline at end of file diff --git a/templates/keras_rs/examples/sas_rec.md b/templates/keras_rs/examples/sas_rec.md deleted file mode 100644 index 54399cd2e5..0000000000 --- a/templates/keras_rs/examples/sas_rec.md +++ /dev/null @@ -1,2970 +0,0 @@ -# Sequential retrieval using SASRec - -**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)
-**Date created:** 2025/04/28
-**Last modified:** 2025/04/28
**Description:** Recommend movies using a Transformer-based retrieval model (SASRec).


 [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/sas_rec.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/sas_rec.py)


---
## Introduction

Sequential recommendation is a popular task in which we look at a sequence of
items that users have interacted with previously and then predict the next
item. Here, the order of the items within each sequence matters. Previously, in
the
[Recommending movies: retrieval using a sequential model](/keras_rs/examples/sequential_retrieval/)
example, we built a GRU-based sequential retrieval model. In this example, we
will build a popular Transformer decoder-based model named
[Self-Attentive Sequential Recommendation (SASRec)](https://arxiv.org/abs/1808.09781)
for the same sequential recommendation task.

Let's begin by importing all the necessary libraries.


```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import collections

import keras
import keras_hub
import numpy as np
import pandas as pd
import tensorflow as tf  # Needed only for the dataset
from keras import ops

import keras_rs
```

Let's also define all the important variables/hyperparameters below.


```python
DATA_DIR = "./raw/data/"

# MovieLens-specific variables
MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"

RATINGS_FILE_NAME = "ratings.dat"
MOVIES_FILE_NAME = "movies.dat"

# Data processing args
MAX_CONTEXT_LENGTH = 200
MIN_SEQUENCE_LENGTH = 3
PAD_ITEM_ID = 0

RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
MIN_RATING = 2

# Training/model args picked from the SASRec paper
BATCH_SIZE = 128
NUM_EPOCHS = 10
LEARNING_RATE = 0.001

NUM_LAYERS = 2
NUM_HEADS = 1
HIDDEN_DIM = 50
DROPOUT = 0.2
```

---
## Dataset

Next, we need to prepare our dataset. Like we did in the
[sequential retrieval](/keras_rs/examples/sequential_retrieval/)
example, we are going to use the MovieLens dataset.

The dataset preparation step is fairly involved. The original ratings dataset
contains `(user, movie ID, rating, timestamp)` tuples (among other columns,
which are not important for this example). Since we are dealing with sequential
retrieval, we need to create movie sequences for every user, where the
sequences are ordered by timestamp.

Let's start by downloading and reading the dataset.


```python
# Download the MovieLens dataset.
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

path_to_zip = keras.utils.get_file(
    fname="ml-1m.zip",
    origin=MOVIELENS_1M_URL,
    file_hash=MOVIELENS_ZIP_HASH,
    hash_algorithm="sha256",
    extract=True,
    cache_dir=DATA_DIR,
)
movielens_extracted_dir = os.path.join(
    os.path.dirname(path_to_zip),
    "ml-1m_extracted",
    "ml-1m",
)


# Read the dataset.
def read_data(data_directory, min_rating=None):
    """Read the MovieLens ratings.dat and movies.dat files
    into dataframes.
    """

    ratings_df = pd.read_csv(
        os.path.join(data_directory, RATINGS_FILE_NAME),
        sep="::",
        names=RATINGS_DATA_COLUMNS,
        encoding="unicode_escape",
    )
    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)

    # Drop ratings below `min_rating`.
- if min_rating is not None: - ratings_df = ratings_df[ratings_df["Rating"] >= min_rating] - - movies_df = pd.read_csv( - os.path.join(data_directory, MOVIES_FILE_NAME), - sep="::", - names=MOVIES_DATA_COLUMNS, - encoding="unicode_escape", - ) - return ratings_df, movies_df - - -ratings_df, movies_df = read_data( - data_directory=movielens_extracted_dir, min_rating=MIN_RATING -) - -# Need to know #movies so as to define embedding layers. -movies_count = movies_df["MovieID"].max() -``` - -
```
Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
 5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step
```
-``` -:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - ratings_df = pd.read_csv( - -:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - movies_df = pd.read_csv( - -``` -
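As an aside, the `ParserWarning` above is raised because the two-character
`"::"` separator forces pandas to fall back to its Python parsing engine. It
can be silenced by opting into that engine explicitly; a minimal sketch with
the same arguments used in `read_data` above:


```python
# Passing engine="python" up front avoids the fallback warning; the parsed
# dataframe is identical to the one produced by the call in `read_data`.
ratings_df = pd.read_csv(
    os.path.join(movielens_extracted_dir, RATINGS_FILE_NAME),
    sep="::",
    names=RATINGS_DATA_COLUMNS,
    encoding="unicode_escape",
    engine="python",
)
```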
-Now that we have read the dataset, let's create sequences of movies -for every user. Here is the function for doing just that. - - -```python - -def get_movie_sequence_per_user(ratings_df): - """Get movieID sequences for every user.""" - sequences = collections.defaultdict(list) - - for user_id, movie_id, rating, timestamp in ratings_df.values: - sequences[user_id].append( - { - "movie_id": movie_id, - "timestamp": timestamp, - "rating": rating, - } - ) - - # Sort movie sequences by timestamp for every user. - for user_id, context in sequences.items(): - context.sort(key=lambda x: x["timestamp"]) - sequences[user_id] = context - - return sequences - - -sequences = get_movie_sequence_per_user(ratings_df) -``` - -So far, we have essentially replicated what we did in the sequential retrieval -example. We have a sequence of movies for every user. - -SASRec is trained contrastively, which means the model learns to distinguish -between sequences of movies a user has actually interacted with (positive -examples) and sequences they have not interacted with (negative examples). - -The following function, `format_data`, prepares the data in this specific -format. For each user's movie sequence, it generates a corresponding -"negative sequence". This negative sequence consists of randomly -selected movies that the user has *not* interacted with, but are of the same -length as the original sequence. - - -```python - -def format_data(sequences): - examples = { - "sequence": [], - "negative_sequence": [], - } - - for user_id in sequences: - sequence = [int(d["movie_id"]) for d in sequences[user_id]] - - # Get negative sequence. - def random_negative_item_id(low, high, positive_lst): - sampled = np.random.randint(low=low, high=high) - while sampled in positive_lst: - sampled = np.random.randint(low=low, high=high) - return sampled - - negative_sequence = [ - random_negative_item_id(1, movies_count + 1, sequence) - for _ in range(len(sequence)) - ] - - examples["sequence"].append(np.array(sequence)) - examples["negative_sequence"].append(np.array(negative_sequence)) - - examples["sequence"] = tf.ragged.constant(examples["sequence"]) - examples["negative_sequence"] = tf.ragged.constant(examples["negative_sequence"]) - - return examples - - -examples = format_data(sequences) -ds = tf.data.Dataset.from_tensor_slices(examples).batch(BATCH_SIZE) -``` - -Now that we have the original movie interaction sequences for each user (from -`format_data`, stored in `examples["sequence"]`) and their corresponding -random negative sequences (in `examples["negative_sequence"]`), the next step is -to prepare this data for input to the model. The primary goals of this -preprocessing are: - -1. Creating Input Features and Target Labels: For sequential - recommendation, the model learns to predict the next item in a sequence - given the preceding items. This is achieved by: - - taking the original `example["sequence"]` and creating the model's - input features (`item_ids`) from all items *except the last one* - (`example["sequence"][..., :-1]`); - - creating the target "positive sequence" (what the model tries to predict - as the actual next items) by taking the original `example["sequence"]` - and shifting it, using all items *except the first one* - (`example["sequence"][..., 1:]`); - - shifting `example["negative_sequence"]` (from `format_data`) is - to create the target "negative sequence" for the contrastive loss - (`example["negative_sequence"][..., 1:]`). - -2. 
Handling Variable Length Sequences: Neural networks typically require
   fixed-size inputs. Therefore, both the input feature sequences and the
   target sequences are padded (with a special `PAD_ITEM_ID`) or truncated
   to a predefined `MAX_CONTEXT_LENGTH`. A `padding_mask` is also generated
   from the input features so that the model ignores these padded tokens
   during attention calculations, i.e., these tokens are masked.

3. Differentiating Training and Validation/Testing:
   - During training:
     - Input features (`item_ids`) and context for negative sequences
       are prepared as described above (all but the last item of the
       original sequences).
     - Target positive and negative sequences are the shifted versions of
       the original sequences.
     - `sample_weight` is created based on the input features to ensure
       that the loss is calculated only on actual items, not on padding
       tokens in the targets.
   - During validation/testing:
     - Input features are prepared similarly.
     - The model's performance is typically evaluated on its ability to
       predict the actual last item of the original sequence. Thus,
       `sample_weight` is configured to focus the loss calculation
       only on this final prediction in the target sequences.

Note: the SASRec paper does the same thing we've done above, except that it
uses `item_ids[:-2]` for the validation set and `item_ids[:-1]` for the test
set. We skip that here for brevity.


```python

def _preprocess(example, train=False):
    sequence = example["sequence"]
    negative_sequence = example["negative_sequence"]

    if train:
        sequence = example["sequence"][..., :-1]
        negative_sequence = example["negative_sequence"][..., :-1]

    batch_size = tf.shape(sequence)[0]

    if not train:
        # Loss is computed only on the last token.
        sample_weight = tf.zeros_like(sequence, dtype="float32")[..., :-1]
        sample_weight = tf.concat(
            [sample_weight, tf.ones((batch_size, 1), dtype="float32")], axis=1
        )

    # Truncate/pad sequence. +1 to account for truncation later.
    sequence = sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    negative_sequence = negative_sequence.to_tensor(
        shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=PAD_ITEM_ID
    )
    if train:
        sample_weight = tf.cast(sequence != PAD_ITEM_ID, dtype="float32")
    else:
        sample_weight = sample_weight.to_tensor(
            shape=[batch_size, MAX_CONTEXT_LENGTH + 1], default_value=0
        )

    example = (
        {
            # The last token does not have a next token.
            "item_ids": sequence[..., :-1],
            # Padding mask for controlling the attention mask.
            "padding_mask": (sequence != PAD_ITEM_ID)[..., :-1],
        },
        {
            # The 0th token's label is the 1st token, and so on.
            "positive_sequence": sequence[..., 1:],
            "negative_sequence": negative_sequence[..., 1:],
        },
        # The loss is not computed on pad tokens.
        sample_weight[..., 1:],
    )
    return example


def preprocess_train(examples):
    return _preprocess(examples, train=True)


def preprocess_val(examples):
    return _preprocess(examples, train=False)


train_ds = ds.map(preprocess_train)
val_ds = ds.map(preprocess_val)
```

Let's look at a batch from each dataset.


```python
for batch in train_ds.take(1):
    print(batch)

for batch in val_ds.take(1):
    print(batch)

```
-``` -({'item_ids': , 'padding_mask': }, {'positive_sequence': , 'negative_sequence': }, ) -({'item_ids': , 'padding_mask': }, {'positive_sequence': , 'negative_sequence': }, ) - -``` -
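To make the shift-by-one labeling described above concrete, here is a minimal
sketch on a toy watch history, in plain NumPy and independent of the `tf.data`
pipeline:


```python
import numpy as np

# Toy watch history: the user watched movies 3, 7, 9, then 12.
toy_sequence = np.array([3, 7, 9, 12])

item_ids = toy_sequence[:-1]          # model input:       [3, 7, 9]
positive_sequence = toy_sequence[1:]  # prediction target: [7, 9, 12]

# At each position, the model sees the items so far and must predict the
# item the user watched next.
print(item_ids, positive_sequence)
```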
---- -## Model - -To encode the input sequence, we use a Transformer decoder-based model. This -part of the model is very similar to the GPT-2 architecture. Refer to the -[GPT text generation from scratch with KerasHub](/examples/generative/text_generation_gpt/#build-the-model) -guide for more details on this part. - -One part to note is that when we are "predicting", i.e., `training` is `False`, -we get the embedding corresponding to the last movie in the sequence. This makes -sense, because at inference time, we want to predict the movie the user will -likely watch after watching the last movie. - -Also, it's worth discussing the `compute_loss` method. We embed the positive -and negative sequences using the input embedding matrix. We compute the -similarity of (positive sequence, input sequence) and (negative sequence, -input sequence) pair embeddings by computing the dot product. The goal now is -to maximize the similarity of the former and minimize the similarity of -the latter. Let's see this mathematically. Binary Cross Entropy is written -as follows: - -``` - loss = - (y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)) -``` - -Here, we assign the positive pairs a label of 1 and the negative pairs a label -of 0. So, for a positive pair, the loss reduces to: - -``` -loss = -np.log(positive_logits) -``` - -Minimising the loss means we want to maximize the log term, which in turn, -implies maximising `positive_logits`. Similarly, we want to minimize -`negative_logits`. - - -```python - -class SasRec(keras.Model): - def __init__( - self, - vocabulary_size, - num_layers, - num_heads, - hidden_dim, - dropout=0.0, - max_sequence_length=100, - dtype=None, - **kwargs, - ): - super().__init__(dtype=dtype, **kwargs) - - # ======== Layers ======== - - # === Embeddings === - self.item_embedding = keras_hub.layers.ReversibleEmbedding( - input_dim=vocabulary_size, - output_dim=hidden_dim, - embeddings_initializer="glorot_uniform", - embeddings_regularizer=keras.regularizers.l2(0.001), - dtype=dtype, - name="item_embedding", - ) - self.position_embedding = keras_hub.layers.PositionEmbedding( - initializer="glorot_uniform", - sequence_length=max_sequence_length, - dtype=dtype, - name="position_embedding", - ) - self.embeddings_add = keras.layers.Add( - dtype=dtype, - name="embeddings_add", - ) - self.embeddings_dropout = keras.layers.Dropout( - dropout, - dtype=dtype, - name="embeddings_dropout", - ) - - # === Decoder layers === - self.transformer_layers = [] - for i in range(num_layers): - self.transformer_layers.append( - keras_hub.layers.TransformerDecoder( - intermediate_dim=hidden_dim, - num_heads=num_heads, - dropout=dropout, - layer_norm_epsilon=1e-05, - # SASRec uses ReLU, although GeLU might be a better option - activation="relu", - kernel_initializer="glorot_uniform", - normalize_first=True, - dtype=dtype, - name=f"transformer_layer_{i}", - ) - ) - - # === Final layer norm === - self.layer_norm = keras.layers.LayerNormalization( - axis=-1, - epsilon=1e-8, - dtype=dtype, - name="layer_norm", - ) - - # === Retrieval === - # The layer that performs the retrieval. 
- self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False) - - # === Loss === - self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True, reduction=None) - - # === Attributes === - self.vocabulary_size = vocabulary_size - self.num_layers = num_layers - self.num_heads = num_heads - self.hidden_dim = hidden_dim - self.dropout = dropout - self.max_sequence_length = max_sequence_length - - def _get_last_non_padding_token(self, tensor, padding_mask): - valid_token_mask = ops.logical_not(padding_mask) - seq_lengths = ops.sum(ops.cast(valid_token_mask, "int32"), axis=1) - last_token_indices = ops.maximum(seq_lengths - 1, 0) - - indices = ops.expand_dims(last_token_indices, axis=(-2, -1)) - gathered_tokens = ops.take_along_axis(tensor, indices, axis=1) - last_token_embedding = ops.squeeze(gathered_tokens, axis=1) - - return last_token_embedding - - def build(self, input_shape): - embedding_shape = list(input_shape) + [self.hidden_dim] - - # Model - self.item_embedding.build(input_shape) - self.position_embedding.build(embedding_shape) - - self.embeddings_add.build((embedding_shape, embedding_shape)) - self.embeddings_dropout.build(embedding_shape) - - for transformer_layer in self.transformer_layers: - transformer_layer.build(decoder_sequence_shape=embedding_shape) - - self.layer_norm.build(embedding_shape) - - # Retrieval - self.retrieval.candidate_embeddings = self.item_embedding.embeddings - self.retrieval.build(input_shape) - - # Chain to super - super().build(input_shape) - - def call(self, inputs, training=False): - item_ids, padding_mask = inputs["item_ids"], inputs["padding_mask"] - - x = self.item_embedding(item_ids) - position_embedding = self.position_embedding(x) - x = self.embeddings_add((x, position_embedding)) - x = self.embeddings_dropout(x) - - for transformer_layer in self.transformer_layers: - x = transformer_layer(x, decoder_padding_mask=padding_mask) - - item_sequence_embedding = self.layer_norm(x) - result = {"item_sequence_embedding": item_sequence_embedding} - - # At inference, perform top-k retrieval. - if not training: - # need to extract last non-padding token. - last_item_embedding = self._get_last_non_padding_token( - item_sequence_embedding, padding_mask - ) - result["predictions"] = self.retrieval(last_item_embedding) - - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=False): - item_sequence_embedding = y_pred["item_sequence_embedding"] - y_positive_sequence = y["positive_sequence"] - y_negative_sequence = y["negative_sequence"] - - # Embed positive, negative sequences. 
- positive_sequence_embedding = self.item_embedding(y_positive_sequence) - negative_sequence_embedding = self.item_embedding(y_negative_sequence) - - # Logits - positive_logits = ops.sum( - ops.multiply(positive_sequence_embedding, item_sequence_embedding), - axis=-1, - ) - negative_logits = ops.sum( - ops.multiply(negative_sequence_embedding, item_sequence_embedding), - axis=-1, - ) - logits = ops.concatenate([positive_logits, negative_logits], axis=1) - - # Labels - labels = ops.concatenate( - [ - ops.ones_like(positive_logits), - ops.zeros_like(negative_logits), - ], - axis=1, - ) - - # sample weights - sample_weight = ops.concatenate( - [sample_weight, sample_weight], - axis=1, - ) - - loss = self.loss_fn( - y_true=ops.expand_dims(labels, axis=-1), - y_pred=ops.expand_dims(logits, axis=-1), - sample_weight=sample_weight, - ) - loss = ops.divide_no_nan(ops.sum(loss), ops.sum(sample_weight)) - - return loss - - def compute_output_shape(self, inputs_shape): - return list(inputs_shape) + [self.hidden_dim] - -``` - -Let's instantiate our model and do some sanity checks. - - -```python -model = SasRec( - vocabulary_size=movies_count + 1, - num_layers=NUM_LAYERS, - num_heads=NUM_HEADS, - hidden_dim=HIDDEN_DIM, - dropout=DROPOUT, - max_sequence_length=MAX_CONTEXT_LENGTH, -) - -# Training -output = model( - inputs={ - "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"), - "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"), - }, - training=True, -) -print(output["item_sequence_embedding"].shape) - -# Inference -output = model( - inputs={ - "item_ids": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="int32"), - "padding_mask": ops.ones((2, MAX_CONTEXT_LENGTH), dtype="bool"), - }, - training=False, -) -print(output["predictions"].shape) -``` - -
-``` -(2, 200, 50) - -(2, 10) - -``` -
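Before training, here is a minimal NumPy sketch (with made-up logits) of the
contrastive binary cross-entropy discussed above: with `from_logits=True`, the
loss reduces to `-log(sigmoid(logit))` for positive pairs and
`-log(1 - sigmoid(logit))` for negative pairs.


```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Made-up dot products between the input embedding and the positive/negative
# item embeddings, for illustration only.
positive_logit = 2.0
negative_logit = -1.0

loss_positive = -np.log(sigmoid(positive_logit))        # label 1
loss_negative = -np.log(1.0 - sigmoid(negative_logit))  # label 0

# Both terms shrink as the positive score grows and the negative score drops.
print(loss_positive, loss_negative)
```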
-Now, let's compile and train our model. - - -```python -model.compile( - optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE, beta_2=0.98), -) -model.fit( - x=train_ds, - validation_data=val_ds, - epochs=NUM_EPOCHS, -) -``` - -
```
Epoch 1/10
 48/48 ━━━━━━━━━━━━━━━━━━━━ 13s 199ms/step - loss: 0.5908 - val_loss: 0.5149
Epoch 2/10
 48/48 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - loss: 0.4441 - val_loss: 0.5084
Epoch 3/10
 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4308 - val_loss: 0.4923
Epoch 4/10
 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4182 - val_loss: 0.4797
Epoch 5/10
 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.4014 - val_loss: 0.4611
```
-``` -Epoch 6/10 - -``` -
- - 1/48 ━━━━━━━━━━━━━━━━━━━━ 1s 41ms/step - loss: 0.3831 - - - 2/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3830 - - - 3/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3827 - - - 4/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3816 - -
-``` - -``` -
- 6/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3811 - -
-``` - -``` -
- 7/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3811 - -
-``` - -``` -
- 8/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3811 - -
-``` - -``` -
- 9/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3812 - -
-``` - -``` -
- 10/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3813 - -
-``` - -``` -
- 11/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3814 - -
-``` - -``` -
- 12/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3815 - -
-``` - -``` -
- 13/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3814 - -
-``` - -``` -
- 14/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3814 - -
-``` - -``` -
- 15/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3814 - -
-``` - -``` -
- 16/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3812 - -
-``` - -``` -
- 17/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3811 - -
-``` - -``` -
- 19/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3809 - 21/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3805 - 20/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3807 - 18/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3809 - -
-``` - -``` -
- 23/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3802 - 25/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3799 - -
-``` - -``` -
- 22/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3804 - 24/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3798 - -
-``` - -``` -
- 26/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3795 - -
-``` - -``` -
- 27/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3793 - -
-``` - -``` -
- 28/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3792 - 33/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3786 - 29/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3791 - 32/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3785 - 31/48 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.3788 - 30/48 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.3789 - -
-``` - -``` -
- 34/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3783 - 39/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3778 - 38/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3775 - 37/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3778 - -
-``` - -``` -
- 36/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3776 - 35/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3777 - -
-``` - -``` -
- 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3760 - val_loss: 0.4355 - - -
-``` -Epoch 7/10 - -``` -
- - 1/48 ━━━━━━━━━━━━━━━━━━━━ 1s 41ms/step - loss: 0.3559 - - - 2/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3559 - - - 3/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3555 - - - 4/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3544 - -
-``` - -``` -
- 6/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3539 - -
-``` - -``` -
- 8/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3538 - 7/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3539 - -
-``` - -``` -
- 9/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3538 - -
-``` - -``` -
- 10/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3539 - -
-``` - -``` -
- 11/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3540 - -
-``` - -``` -
- 12/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3540 - -
-``` - -``` -
- 13/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3540 - 14/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3540 - -
-``` - -``` -
- 15/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3539 - -
-``` - -``` -
- 16/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3538 - -
-``` - -``` -
- 17/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3537 - 18/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3536 - -
-``` - -``` -
- 20/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3534 - -
-``` - -``` -
- 19/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3533 - 21/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3531 - -
-``` - -``` -
- 26/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3524 - -
-``` - -``` -
- 22/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3529 - -
-``` - -``` -
- 25/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3525 - 24/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3527 - 23/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3523 - -
-``` - -``` -
- 27/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3519 - -
-``` - -``` -
- 29/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3517 - -
-``` - -``` -
- 32/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3513 - 33/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3512 - 31/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3513 - 30/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3514 - 28/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3516 - -
-``` - -``` -
- 34/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3511 - 37/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3508 - 36/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3510 - -
-``` - -``` -
- 35/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3509 - 39/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3508 - 38/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3508 - -
-``` - -``` -
- 40/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3507 - -
-``` - -``` -
- 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3500 - val_loss: 0.4174 - - -
-``` -Epoch 8/10 - -``` -
- - 1/48 ━━━━━━━━━━━━━━━━━━━━ 1s 41ms/step - loss: 0.3339 - - - 2/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3343 - - - 3/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3341 - - - 4/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3333 - -
-``` - -``` -
- 6/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3329 - -
-``` - -``` -
- 7/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3329 - -
-``` - -``` -
- 8/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3328 - -
-``` - -``` -
- 9/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3328 - -
-``` - -``` -
- 10/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3329 - -
-``` - -``` -
- 11/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3330 - -
-``` - -``` -
- 12/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3330 - -
-``` - -``` -
- 13/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3330 - -
-``` - -``` -
- 14/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3330 - -
-``` - -``` -
- 15/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3329 - -
-``` - -``` -
- 16/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3328 - -
-``` - -``` -
- 17/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3327 - -
-``` - -``` -
- 18/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3326 - -
-``` - -``` -
- 19/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3323 - 21/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3324 - -
-``` - -``` -
- 20/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3325 - -
-``` - -``` -
- 24/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3318 - -
-``` - -``` -
- 23/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3319 - 22/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3320 - -
-``` - -``` -
- 26/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3317 - -
-``` - -``` -
- 25/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3318 - -
-``` - -``` -
- 27/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3316 - -
-``` - -``` -
- 28/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3315 - -
-``` - -``` -
- 30/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3314 - 33/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3311 - -
-``` - -``` -
- 31/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3312 - -
-``` - -``` -
- 29/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3314 - 32/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3312 - 35/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3309 - -
-``` - -``` -
- 39/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3309 - 40/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3308 - 36/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3308 - 37/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3307 - 34/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3310 - -
-``` - -``` -
- 38/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3308 - 41/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3307 - -
-``` - -``` -
- 44/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3307 - -
-``` - -``` -
- 42/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3307 - 43/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3307 - -
-``` - -``` -
- 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3306 - val_loss: 0.4035 - - -
-``` -Epoch 9/10 - -``` -
- - 1/48 ━━━━━━━━━━━━━━━━━━━━ 2s 44ms/step - loss: 0.3179 - - - 2/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3187 - - - 3/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3186 - - - 4/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3178 - -
-``` - -``` -
- 6/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3174 - -
-``` - -``` -
- 7/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3173 - -
-``` - -``` -
- 8/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3172 - -
-``` - -``` -
- 9/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3171 - -
-``` - -``` -
- 10/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3172 - -
-``` - -``` -
- 11/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3172 - 12/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3172 - -
-``` - -``` -
- 14/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3171 - -
-``` - -``` -
- 13/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3172 - -
-``` - -``` -
- 15/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3170 - -
-``` - -``` -
- 16/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3169 - -
-``` - -``` -
- 17/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3168 - -
-``` - -``` -
- 18/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3167 - -
-``` - -``` -
- 20/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3165 - -
-``` - -``` -
- 19/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3164 - 21/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3163 - -
-``` - -``` -
- 23/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3161 - 24/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3158 - 22/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3160 - -
-``` - -``` -
- 26/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3157 - 25/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3156 - -
-``` - -``` -
- 27/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3155 - 28/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3155 - -
-``` - -``` -
- 29/48 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.3154 - 32/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3152 - 33/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3152 - -
-``` - -``` -
- 30/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3153 - -
-``` - -``` -
- 31/48 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - loss: 0.3151 - 40/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3150 - 34/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3150 - 37/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3150 - -
-``` - -``` -
- 39/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3150 - 36/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3150 - -
-``` - -``` -
- 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3149 - val_loss: 0.3927 - - -
-``` -Epoch 10/10 - -``` -
- - 1/48 ━━━━━━━━━━━━━━━━━━━━ 1s 42ms/step - loss: 0.3042 - - - 2/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3054 - - - 3/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3054 - - - 4/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3047 - -
-``` - -``` -
- 6/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3043 - -
-``` - -``` -
- 7/48 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - loss: 0.3042 - -
-``` - -``` -
- 8/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3040 - -
-``` - -``` -
- 9/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3040 - -
-``` - -``` -
- 10/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3040 - -
-``` - -``` -
- 11/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3040 - -
-``` - -``` -
- 12/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3039 - -
-``` - -``` -
- 13/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3037 - -
-``` - -``` -
- 14/48 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - loss: 0.3038 - -
-``` - -``` -
- 15/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3035 - -
-``` - -``` -
- 16/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3033 - -
-``` - -``` -
- 17/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3032 - -
-``` - -``` -
- 20/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3026 - 22/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3026 - -
-``` - -``` -
- 21/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3024 - 19/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3025 - 18/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3031 - -
-``` - -``` -
- 24/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3022 - 25/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3020 - 23/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3022 - -
-``` - -``` -
- 27/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3020 - -
-``` - -``` -
- 26/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3019 - -
-``` - -``` -
- 28/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3019 - -
-``` - -``` -
- 34/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3017 - 29/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3018 - 33/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3015 - 30/48 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step - loss: 0.3016 - 32/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3015 - -
-``` - -``` -
- 31/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3016 - -
-``` - -``` -
- 37/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3015 - 38/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3015 - -
-``` - -``` -
- 36/48 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - loss: 0.3015 - 35/48 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step - loss: 0.3014 - -
-``` - -``` -
- 41/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3015 - 39/48 ━━━━━━━━━━━━━━━━━━━━ 0s 16ms/step - loss: 0.3015 - -
-``` - -``` -
- 42/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3015 - 40/48 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.3015 - 44/48 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - loss: 0.3015 - -
-``` - -``` -
- 48/48 ━━━━━━━━━━━━━━━━━━━━ 1s 18ms/step - loss: 0.3015 - val_loss: 0.3829 - - - - - -
-``` - - -``` -
----
-## Making predictions
-
-Now that we have a model, we would like to be able to make predictions.
-
-So far, we have only handled movies by ID. Now is the time to create a mapping
-from movie ID to title so that we can surface the titles.
-
-
-```python
-movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
-movie_id_to_movie_title[0] = ""  # Because ID 0 is not in the dataset.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies the user has already watched. We could
-easily add logic to remove them if that is desirable; a minimal sketch of such
-filtering is shown after the output below.
-
-
-```python
-for ele in val_ds.unbatch().take(1):
-    test_sample = ele[0]
-    test_sample["item_ids"] = tf.expand_dims(test_sample["item_ids"], axis=0)
-    test_sample["padding_mask"] = tf.expand_dims(test_sample["padding_mask"], axis=0)
-
-movie_sequence = np.array(test_sample["item_ids"])[0]
-for movie_id in movie_sequence:
-    if movie_id == 0:
-        continue
-    print(movie_id_to_movie_title[movie_id], end="; ")
-print()
-
-predictions = model.predict(test_sample)["predictions"]
-predictions = keras.ops.convert_to_numpy(predictions)
-
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-``` -Girl, Interrupted (1999); Back to the Future (1985); Titanic (1997); Cinderella (1950); Meet Joe Black (1998); Last Days of Disco, The (1998); Erin Brockovich (2000); Christmas Story, A (1983); To Kill a Mockingbird (1962); One Flew Over the Cuckoo's Nest (1975); Wallace & Gromit: The Best of Aardman Animation (1996); Star Wars: Episode IV - A New Hope (1977); Wizard of Oz, The (1939); Fargo (1996); Run Lola Run (Lola rennt) (1998); Rain Man (1988); Saving Private Ryan (1998); Awakenings (1990); Gigi (1958); Sound of Music, The (1965); Driving Miss Daisy (1989); Bambi (1942); Apollo 13 (1995); Mary Poppins (1964); E.T. the Extra-Terrestrial (1982); My Fair Lady (1964); Ben-Hur (1959); Big (1988); Sixth Sense, The (1999); Dead Poets Society (1989); James and the Giant Peach (1996); Ferris Bueller's Day Off (1986); Secret Garden, The (1993); Toy Story 2 (1999); Airplane! (1980); Pleasantville (1998); Dumbo (1941); Princess Bride, The (1987); Snow White and the Seven Dwarfs (1937); Miracle on 34th Street (1947); Ponette (1996); Schindler's List (1993); Beauty and the Beast (1991); Tarzan (1999); Close Shave, A (1995); Aladdin (1992); Toy Story (1995); Bug's Life, A (1998); Antz (1998); Hunchback of Notre Dame, The (1996); Hercules (1997); Mulan (1998); Pocahontas (1995); - -``` -
-```
- 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 791ms/step
-```
-
-```
-Groundhog Day (1993)
-Aladdin (1992)
-Toy Story (1995)
-Forrest Gump (1994)
-Bug's Life, A (1998)
-Lion King, The (1994)
-Shakespeare in Love (1998)
-American Beauty (1999)
-Sixth Sense, The (1999)
-Ghostbusters (1984)
-
-```
-
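-Below is a minimal sketch of the already-watched filtering mentioned above. It
-assumes the `movie_sequence` and `predictions` variables from the snippet above
-are still in scope, and simply drops watched IDs from the returned list; a
-production system would typically retrieve more than `k` candidates before
-filtering, so it can still return `k` results.
-
-
-```python
-# Minimal post-filtering sketch (assumes `movie_sequence`, `predictions` and
-# `movie_id_to_movie_title` from the snippet above). Drop padding (ID 0) and
-# any movie the user has already watched.
-watched = {int(movie_id) for movie_id in movie_sequence if movie_id != 0}
-
-unseen_recommendations = [
-    int(movie_id) for movie_id in predictions[0] if int(movie_id) not in watched
-]
-for movie_id in unseen_recommendations:
-    print(movie_id_to_movie_title[movie_id])
-```
-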
-And that's all!
diff --git a/templates/keras_rs/examples/scann.md b/templates/keras_rs/examples/scann.md
deleted file mode 100644
index 17c5c70bb3..0000000000
--- a/templates/keras_rs/examples/scann.md
+++ /dev/null
@@ -1,2159 +0,0 @@
-# Faster retrieval with Scalable Nearest Neighbours (ScaNN)
-
-**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)<br>
-**Date created:** 2025/04/28<br>
-**Last modified:** 2025/04/28<br>
-**Description:** Using ScaNN for faster retrieval.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/scann.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/scann.py)
-
-
-
----
-## Introduction
-
-Retrieval models are designed to quickly identify a small set of highly relevant
-candidates from vast pools of data, often comprising millions or even hundreds
-of millions of items. To effectively respond to the user's context and behavior
-in real time, these models must perform this task in just milliseconds.
-
-Approximate nearest neighbor (ANN) search is the key technology that enables
-this level of efficiency. In this tutorial, we'll demonstrate how to leverage
-ScaNN, a cutting-edge nearest neighbor retrieval library, to easily scale
-retrieval to millions of items.
-
-[ScaNN](https://research.google/blog/announcing-scann-efficient-vector-similarity-search/),
-developed by Google Research, is a high-performance library designed for
-dense vector similarity search at scale. It efficiently indexes a database of
-candidate embeddings, enabling rapid search during inference. By leveraging
-advanced vector compression techniques and finely tuned algorithms, ScaNN
-strikes an optimal balance between speed and accuracy. As a result, it can
-significantly outperform brute-force search methods, delivering fast retrieval
-with minimal loss in accuracy.
-
-We will start with the same code as the
-[basic retrieval example](/keras_rs/examples/basic_retrieval/).
-Data processing, model building, and training remain exactly the same. Feel free
-to skip this part if you have gone over the basic retrieval example before.
-
-Note: ScaNN does not have its own separate layer in KerasRS because the ScaNN
-library is TensorFlow-only. In this example, we use the ScaNN library directly
-and demonstrate its usage with KerasRS.
-
----
-## Imports
-
-Let's install the `scann` library and import all necessary packages. We will
-also set the backend to JAX.
-
-
-```python
-# ruff: noqa: E402
-```
-
-
-```python
-!pip install -q scann
-```
-
-```
-━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 11.8/11.8 MB 16.4 MB/s eta 0:00:00
-```
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import time
-import uuid
-
-import keras
-import tensorflow as tf  # Needed for the dataset
-import tensorflow_datasets as tfds
-from scann import scann_ops
-
-import keras_rs
-```
-
----
-## Preparing the dataset
-
-
-```python
-# Ratings data with user and movie data.
-ratings = tfds.load("movielens/100k-ratings", split="train") -# Features of all the available movies. -movies = tfds.load("movielens/100k-movies", split="train") - -# Get user and movie counts so that we can define embedding layers for both. -users_count = ( - ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) - .reduce(tf.constant(0, tf.int32), tf.maximum) - .numpy() -) - -movies_count = movies.cardinality().numpy() - - -# Preprocess the dataset, by selecting only the relevant columns. -def preprocess_rating(x): - return ( - # Input is the user IDs - tf.strings.to_number(x["user_id"], out_type=tf.int32), - # Labels are movie IDs + ratings between 0 and 1. - { - "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), - "rating": (x["user_rating"] - 1.0) / 4.0, - }, - ) - - -shuffled_ratings = ratings.map(preprocess_rating).shuffle( - 100_000, seed=42, reshuffle_each_iteration=False -) -# Train-test split. -train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() -test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() -``` - ---- -## Implementing the Model - - -```python - -class RetrievalModel(keras.Model): - def __init__( - self, - num_users, - num_candidates, - embedding_dimension=32, - **kwargs, - ): - super().__init__(**kwargs) - # Our query tower, simply an embedding table. - self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) - # Our candidate tower, simply an embedding table. - self.candidate_embedding = keras.layers.Embedding( - num_candidates, embedding_dimension - ) - - self.loss_fn = keras.losses.MeanSquaredError() - - def build(self, input_shape): - self.user_embedding.build(input_shape) - self.candidate_embedding.build(input_shape) - - super().build(input_shape) - - def call(self, inputs, training=False): - user_embeddings = self.user_embedding(inputs) - result = { - "user_embeddings": user_embeddings, - } - return result - - def compute_loss(self, x, y, y_pred, sample_weight, training=True): - candidate_id, rating = y["movie_id"], y["rating"] - user_embeddings = y_pred["user_embeddings"] - candidate_embeddings = self.candidate_embedding(candidate_id) - - labels = keras.ops.expand_dims(rating, -1) - # Compute the affinity score by multiplying the two embeddings. - scores = keras.ops.sum( - keras.ops.multiply(user_embeddings, candidate_embeddings), - axis=1, - keepdims=True, - ) - return self.loss_fn(labels, scores, sample_weight) - -``` - ---- -## Training the model - - -```python -model = RetrievalModel(users_count + 1000, movies_count + 1000) -model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1)) - -history = model.fit( - train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50 -) -``` - -
-```
-Epoch 1/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 0.4772
-...
-Epoch 5/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 27ms/step - loss: 0.4770 - val_loss: 0.4836
-...
-Epoch 10/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4836
-...
-Epoch 15/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767 - val_loss: 0.4835
-...
-Epoch 20/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765 - val_loss: 0.4835
-...
-Epoch 25/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4761 - val_loss: 0.4833
-...
-Epoch 30/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4757 - val_loss: 0.4830
-...
-Epoch 35/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4749 - val_loss: 0.4823
-...
-Epoch 40/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4737 - val_loss: 0.4812
-...
-Epoch 45/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4716 - val_loss: 0.4791
-...
-Epoch 50/50
- 80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4679 - val_loss: 0.4753
-```
-
----
-## Making predictions
-
-Before we try out ScaNN, let's start with the brute-force method: for a given
-user, we compute scores for all movies, sort them, and pick the top-k movies.
-This is, of course, not very scalable when we have a huge number of movies.
-
-
-```python
-candidate_embeddings = keras.ops.array(model.candidate_embedding.embeddings.numpy())
-# Artificially duplicate candidate embeddings to simulate a large number of
-# movies.
-candidate_embeddings = keras.ops.concatenate(
-    [candidate_embeddings]
-    + [
-        candidate_embeddings
-        * keras.random.uniform(keras.ops.shape(candidate_embeddings))
-        for _ in range(100)
-    ],
-    axis=0,
-)
-
-user_embedding = model.user_embedding(keras.ops.array([10, 5, 42, 345]))
-
-# Define the brute force retrieval layer.
-brute_force_layer = keras_rs.layers.BruteForceRetrieval(
-    candidate_embeddings=candidate_embeddings,
-    k=10,
-    return_scores=False,
-)
-```
-
-Now, let's do a forward pass on the layer. Note that in previous tutorials, we
-had the above layer as an attribute of the model class and called
-`model.predict()`. That is obviously faster (since it runs compiled XLA code),
-but since we cannot do the same for ScaNN, we do a plain forward pass here,
-without compilation, to ensure a fair comparison.
-
-
-```python
-t0 = time.time()
-pred_movie_ids = brute_force_layer(user_embedding)
-print("Time taken by brute force layer (sec):", time.time() - t0)
-```
-
-```
-Time taken by brute force layer (sec): 0.22817683219909668
-
-```
-
-Now, let's retrieve movies using ScaNN. We will use the ScaNN library from
-Google Research to build the layer and then call it. To fully understand all the
-arguments, please refer to the
-[ScaNN README file](https://github.com/google-research/google-research/tree/master/scann#readme).
-
-
-```python
-
-def build_scann(
-    candidates,
-    k=10,
-    distance_measure="dot_product",
-    dimensions_per_block=2,
-    num_reordering_candidates=500,
-    num_leaves=100,
-    num_leaves_to_search=30,
-    training_iterations=12,
-):
-    builder = scann_ops.builder(
-        db=candidates,
-        num_neighbors=k,
-        distance_measure=distance_measure,
-    )
-
-    builder = builder.tree(
-        num_leaves=num_leaves,
-        num_leaves_to_search=num_leaves_to_search,
-        training_iterations=training_iterations,
-    )
-    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
-
-    if num_reordering_candidates is not None:
-        builder = builder.reorder(num_reordering_candidates)
-
-    # Set a unique name to prevent unintentional sharing between
-    # ScaNN instances.
-    searcher = builder.build(shared_name=str(uuid.uuid4()))
-    return searcher
-
-
-def run_scann(searcher):
-    pred_movie_ids = searcher.search_batched_parallel(
-        user_embedding,
-        final_num_neighbors=10,
-    ).indices
-    return pred_movie_ids
-
-
-searcher = build_scann(candidates=candidate_embeddings)
-
-t0 = time.time()
-pred_movie_ids = run_scann(searcher)
-print("Time taken by ScaNN (sec):", time.time() - t0)
-```
-
-```
-Time taken by ScaNN (sec): 0.0032587051391601562
-
-```
-
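-Since ScaNN is approximate, it is worth checking how closely its results match
-brute force. The snippet below is a minimal sketch of such a check: it re-runs
-both methods on the same queries and computes the fraction of brute-force
-top-10 IDs that ScaNN also returns (the variable names `brute_force_ids` and
-`scann_ids` are ours, not part of either library).
-
-
-```python
-import numpy as np
-
-# Re-run both retrieval methods on the same user embeddings.
-brute_force_ids = np.asarray(brute_force_layer(user_embedding))
-scann_ids = np.asarray(run_scann(searcher))
-
-# Per-query overlap between the two top-10 lists, averaged over queries.
-recall_at_10 = np.mean(
-    [
-        len(set(bf.tolist()) & set(approx.tolist())) / len(bf)
-        for bf, approx in zip(brute_force_ids, scann_ids)
-    ]
-)
-print("Fraction of brute-force top-10 also returned by ScaNN:", recall_at_10)
-```
-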
-You can clearly see the performance improvement in terms of latency. ScaNN
-(0.003 seconds) takes roughly one-seventieth of the time the brute force layer
-(0.23 seconds) needs to run!
diff --git a/templates/keras_rs/examples/sequential_retrieval.md b/templates/keras_rs/examples/sequential_retrieval.md
deleted file mode 100644
index 54fd7fbe3a..0000000000
--- a/templates/keras_rs/examples/sequential_retrieval.md
+++ /dev/null
@@ -1,2333 +0,0 @@
-# Sequential retrieval [GRU4Rec]
-
-**Author:** [Abheesht Sharma](https://github.com/abheesht17/), [Fabien Hertschuh](https://github.com/hertschuh/)<br>
-**Date created:** 2025/04/28<br>
-**Last modified:** 2025/04/28<br>
-**Description:** Recommend movies using a GRU-based sequential retrieval model.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/sequential_retrieval.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/sequential_retrieval.py)
-
-
-
----
-## Introduction
-
-In this example, we are going to build a sequential retrieval model. Sequential
-recommendation is a popular task that looks at a sequence of items that users
-have interacted with previously and then predicts the next item. Here, the
-order of the items within each sequence matters, so we are going to use a
-recurrent neural network to model the sequential relationship. For more
-details, please refer to the [GRU4Rec](https://arxiv.org/abs/1511.06939) paper.
-
-Let's begin by choosing JAX as the backend we want to run on, and import all
-the necessary libraries.
-
-
-```python
-import os
-
-os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`
-
-import collections
-import random
-
-import keras
-import pandas as pd
-import tensorflow as tf  # Needed only for the dataset
-
-import keras_rs
-```
-
-Let's also define all important variables/hyperparameters below.
-
-
-```python
-DATA_DIR = "./raw/data/"
-
-# MovieLens-specific variables
-MOVIELENS_1M_URL = "https://files.grouplens.org/datasets/movielens/ml-1m.zip"
-MOVIELENS_ZIP_HASH = "a6898adb50b9ca05aa231689da44c217cb524e7ebd39d264c56e2832f2c54e20"
-
-RATINGS_FILE_NAME = "ratings.dat"
-MOVIES_FILE_NAME = "movies.dat"
-
-# Data processing args
-MAX_CONTEXT_LENGTH = 10
-MIN_SEQUENCE_LENGTH = 3
-TRAIN_DATA_FRACTION = 0.9
-
-RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
-MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
-MIN_RATING = 2
-
-# Training/model args
-BATCH_SIZE = 4096
-TEST_BATCH_SIZE = 2048
-EMBEDDING_DIM = 32
-NUM_EPOCHS = 5
-LEARNING_RATE = 0.05
-```
-
----
-## Dataset
-
-Next, we need to prepare our dataset. Like we did in the
-[basic retrieval](/keras_rs/examples/basic_retrieval/)
-example, we are going to use the MovieLens dataset.
-
-The dataset preparation step is fairly involved. The original ratings dataset
-contains `(user, movie ID, rating, timestamp)` tuples (among other columns,
-which are not important for this example). Since we are dealing with sequential
-retrieval, we need to create movie sequences for every user, where the sequences
-are ordered by timestamp.
-
-Let's start by downloading and reading the dataset.
-
-
-```python
-# Download the MovieLens dataset.
-if not os.path.exists(DATA_DIR):
-    os.makedirs(DATA_DIR)
-
-path_to_zip = keras.utils.get_file(
-    fname="ml-1m.zip",
-    origin=MOVIELENS_1M_URL,
-    file_hash=MOVIELENS_ZIP_HASH,
-    hash_algorithm="sha256",
-    extract=True,
-    cache_dir=DATA_DIR,
-)
-movielens_extracted_dir = os.path.join(
-    os.path.dirname(path_to_zip),
-    "ml-1m_extracted",
-    "ml-1m",
-)
-
-
-# Read the dataset.
-def read_data(data_directory, min_rating=None):
-    """Read movielens ratings.dat and movies.dat files
-    into dataframes.
-    """
-
-    ratings_df = pd.read_csv(
-        os.path.join(data_directory, RATINGS_FILE_NAME),
-        sep="::",
-        names=RATINGS_DATA_COLUMNS,
-        encoding="unicode_escape",
-    )
-    ratings_df["Timestamp"] = ratings_df["Timestamp"].apply(int)
-
-    # Keep only ratings with `rating >= min_rating`.
- if min_rating is not None: - ratings_df = ratings_df[ratings_df["Rating"] >= min_rating] - - movies_df = pd.read_csv( - os.path.join(data_directory, MOVIES_FILE_NAME), - sep="::", - names=MOVIES_DATA_COLUMNS, - encoding="unicode_escape", - ) - return ratings_df, movies_df - - -ratings_df, movies_df = read_data( - data_directory=movielens_extracted_dir, min_rating=MIN_RATING -) - -# Need to know #movies so as to define embedding layers. -movies_count = movies_df["MovieID"].max() -``` - -
-```
-Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
-
-```
-
-```
- 5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
-```
-
-``` -:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - ratings_df = pd.read_csv( - -:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'. - movies_df = pd.read_csv( - -``` -
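-As the warning itself suggests, you can silence it by making the parser
-explicit. A minimal variant of the `ratings.dat` read above, with the single
-added `engine="python"` argument:
-
-
-```python
-# Same read as above, with the parser engine made explicit to avoid the
-# ParserWarning about regex separators.
-ratings_df = pd.read_csv(
-    os.path.join(movielens_extracted_dir, RATINGS_FILE_NAME),
-    sep="::",
-    names=RATINGS_DATA_COLUMNS,
-    encoding="unicode_escape",
-    engine="python",
-)
-```
-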
-Now that we have read the dataset, let's create sequences of movies -for every user. Here is the function for doing just that. - - -```python - -def get_movie_sequence_per_user(ratings_df): - """Get movieID sequences for every user.""" - sequences = collections.defaultdict(list) - - for user_id, movie_id, rating, timestamp in ratings_df.values: - sequences[user_id].append( - { - "movie_id": movie_id, - "timestamp": timestamp, - "rating": rating, - } - ) - - # Sort movie sequences by timestamp for every user. - for user_id, context in sequences.items(): - context.sort(key=lambda x: x["timestamp"]) - sequences[user_id] = context - - return sequences - -``` - -We need to do some filtering and processing before we proceed -with training the model: - -1. Form sequences of all lengths up to - `min(user_sequence_length, MAX_CONTEXT_LENGTH)`. So, every user - will have multiple sequences corresponding to it. -2. Get labels, i.e., Given a sequence of length `n`, the first - `n-1` tokens will be fed to the model as input, and the label - will be the last token. -3. Remove all user sequences with less than `MIN_SEQUENCE_LENGTH` - movies. -4. Pad all sequences to `MAX_CONTEXT_LENGTH`. - - -```python - -def generate_examples_from_user_sequences(sequences): - """Generates sequences for all users, with padding, truncation, etc.""" - - def generate_examples_from_user_sequence(sequence): - """Generates examples for a single user sequence.""" - - examples = [] - for label_idx in range(1, len(sequence)): - start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH) - context = sequence[start_idx:label_idx] - - # Padding - while len(context) < MAX_CONTEXT_LENGTH: - context.append( - { - "movie_id": 0, - "timestamp": 0, - "rating": 0.0, - } - ) - - label_movie_id = int(sequence[label_idx]["movie_id"]) - context_movie_id = [int(movie["movie_id"]) for movie in context] - - examples.append( - { - "context_movie_id": context_movie_id, - "label_movie_id": label_movie_id, - }, - ) - return examples - - all_examples = [] - for sequence in sequences.values(): - if len(sequence) < MIN_SEQUENCE_LENGTH: - continue - - user_examples = generate_examples_from_user_sequence(sequence) - - all_examples.extend(user_examples) - - return all_examples - -``` - -Let's split the dataset into train and test sets. Also, we need to -change the format of the dataset dictionary so as to enable conversion -to a `tf.data.Dataset` object. - - -```python -sequences = get_movie_sequence_per_user(ratings_df) -examples = generate_examples_from_user_sequences(sequences) - -# Train-test split. -random.shuffle(examples) -split_index = int(TRAIN_DATA_FRACTION * len(examples)) -train_examples = examples[:split_index] -test_examples = examples[split_index:] - - -def list_of_dicts_to_dict_of_lists(list_of_dicts): - """Convert list of dictionaries to dictionary of lists for - `tf.data` conversion. - """ - dict_of_lists = collections.defaultdict(list) - for dictionary in list_of_dicts: - for key, value in dictionary.items(): - dict_of_lists[key].append(value) - return dict_of_lists - - -train_examples = list_of_dicts_to_dict_of_lists(train_examples) -test_examples = list_of_dicts_to_dict_of_lists(test_examples) - -train_ds = tf.data.Dataset.from_tensor_slices(train_examples).map( - lambda x: (x["context_movie_id"], x["label_movie_id"]) -) -test_ds = tf.data.Dataset.from_tensor_slices(test_examples).map( - lambda x: (x["context_movie_id"], x["label_movie_id"]) -) -``` - -We need to batch our datasets. 
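-To make the windowing and padding concrete, here is a toy run on a
-hypothetical user who watched three movies. It produces one example per
-possible label position, with contexts padded to `MAX_CONTEXT_LENGTH`.
-
-
-```python
-# Hypothetical three-movie history for a single user (ID 42).
-toy_sequences = {
-    42: [
-        {"movie_id": 1, "timestamp": 10, "rating": 4.0},
-        {"movie_id": 2, "timestamp": 20, "rating": 5.0},
-        {"movie_id": 3, "timestamp": 30, "rating": 3.0},
-    ]
-}
-for example in generate_examples_from_user_sequences(toy_sequences):
-    print(example)
-# {'context_movie_id': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'label_movie_id': 2}
-# {'context_movie_id': [1, 2, 0, 0, 0, 0, 0, 0, 0, 0], 'label_movie_id': 3}
-```
-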
-We need to batch our datasets. We also use `cache()` and `prefetch()`
-for better performance.
-
-
-```python
-train_ds = train_ds.batch(BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
-test_ds = test_ds.batch(TEST_BATCH_SIZE).cache().prefetch(tf.data.AUTOTUNE)
-```
-
-Let's print out one batch.
-
-
-```python
-for sample in train_ds.take(1):
-    print(sample)
-```
-
-<div class="k-default-codeblock">
-```
-(context_movie_id batch tensor, label_movie_id batch tensor)
-# Note: the original tensor reprs were stripped during HTML conversion.
-
-```
-</div>
----
-## Model and Training
-
-In the basic retrieval example, we used one query tower for the
-user and one candidate tower for the candidate movie. We use a
-two-tower architecture here as well; however, the query tower now
-uses a Gated Recurrent Unit (GRU) layer to encode the sequence of
-historical movies, while the candidate tower stays the same.
-
-Note: Take a look at how the labels are defined. The label tensor
-(of shape `(batch_size, batch_size)`) contains one-hot vectors. The idea
-is: for every sample, consider movie IDs corresponding to the other samples
-in the batch as negatives. A toy illustration of this labeling scheme is
-shown right after the training run below.
-
-
-```python
-
-class SequentialRetrievalModel(keras.Model):
-    """Create the sequential retrieval model.
-
-    Args:
-        movies_count: Total number of unique movies in the dataset.
-        embedding_dimension: Output dimension for movie embedding tables.
-    """
-
-    def __init__(
-        self,
-        movies_count,
-        embedding_dimension=128,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-        # Our query tower, simply an embedding table followed by
-        # a GRU layer. This encodes the sequence of historical movies.
-        self.query_model = keras.Sequential(
-            [
-                keras.layers.Embedding(movies_count + 1, embedding_dimension),
-                keras.layers.GRU(embedding_dimension),
-            ]
-        )
-
-        # Our candidate tower, simply an embedding table.
-        self.candidate_model = keras.layers.Embedding(
-            movies_count + 1, embedding_dimension
-        )
-
-        # The layer that performs the retrieval.
-        self.retrieval = keras_rs.layers.BruteForceRetrieval(
-            k=10, return_scores=False
-        )
-        self.loss_fn = keras.losses.CategoricalCrossentropy(
-            from_logits=True,
-        )
-
-    def build(self, input_shape):
-        self.query_model.build(input_shape)
-        self.candidate_model.build(input_shape)
-
-        # In this case, the candidates are directly the movie embeddings.
-        # We take a shortcut and directly reuse the variable.
-        self.retrieval.candidate_embeddings = self.candidate_model.embeddings
-        self.retrieval.build(input_shape)
-        super().build(input_shape)
-
-    def call(self, inputs, training=False):
-        query_embeddings = self.query_model(inputs)
-        result = {
-            "query_embeddings": query_embeddings,
-        }
-
-        if not training:
-            # Skip the retrieval of top movies during training, as the
-            # predictions are not used then.
-            result["predictions"] = self.retrieval(query_embeddings)
-        return result
-
-    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
-        candidate_id = y
-        query_embeddings = y_pred["query_embeddings"]
-        candidate_embeddings = self.candidate_model(candidate_id)
-
-        num_queries = keras.ops.shape(query_embeddings)[0]
-        num_candidates = keras.ops.shape(candidate_embeddings)[0]
-
-        # One-hot vectors for labels.
-        labels = keras.ops.eye(num_queries, num_candidates)
-
-        # Compute the affinity score by multiplying the two embeddings.
-        scores = keras.ops.matmul(
-            query_embeddings, keras.ops.transpose(candidate_embeddings)
-        )
-
-        return self.loss_fn(labels, scores, sample_weight)
-
-```
-
-Let's instantiate, compile, and train our model.
-
-
-```python
-model = SequentialRetrievalModel(
-    movies_count=movies_count + 1, embedding_dimension=EMBEDDING_DIM
-)
-
-# Compile.
-model.compile(optimizer=keras.optimizers.AdamW(learning_rate=LEARNING_RATE))
-
-# Train.
-model.fit(
-    train_ds,
-    validation_data=test_ds,
-    epochs=NUM_EPOCHS,
-)
-```
-
-<div class="k-default-codeblock">
-``` -Epoch 1/5 - -``` -
- 207/207 ━━━━━━━━━━━━━━━━━━━━ 8s 28ms/step - loss: 7.3487 - val_loss: 5.9852 - - -
-``` -Epoch 2/5 - -``` -
- 207/207 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 6.6328 - val_loss: 5.9231 - - -
-``` -Epoch 3/5 - -``` -
- 207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.5498 - val_loss: 5.9322 - - -
-``` -Epoch 4/5 - -``` -
- 207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.5158 - val_loss: 5.9527 - - -
-``` -Epoch 5/5 - -``` -
-
- 207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - loss: 6.4960 - val_loss: 5.9651
-
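-
-As promised above, here is a toy illustration (not part of the original
-example) of the in-batch negatives labeling used in `compute_loss`: with a
-batch of 4 queries and 4 candidates, the labels form a 4x4 identity matrix,
-so each query's own candidate is the positive and the other 3 candidates in
-the batch act as negatives.
-
-
-```python
-import keras
-
-batch_size, embedding_dim = 4, 8  # toy values for illustration
-query_embeddings = keras.random.normal((batch_size, embedding_dim))
-candidate_embeddings = keras.random.normal((batch_size, embedding_dim))
-
-# One-hot positives on the diagonal; every off-diagonal entry is a negative.
-labels = keras.ops.eye(batch_size, batch_size)
-
-# (4, 4) affinity matrix: entry (i, j) scores query i against candidate j.
-scores = keras.ops.matmul(
-    query_embeddings, keras.ops.transpose(candidate_embeddings)
-)
-
-loss = keras.losses.CategoricalCrossentropy(from_logits=True)(labels, scores)
-print(loss)
-```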
----
-## Making predictions
-
-Now that we have a model, we would like to be able to make predictions.
-
-So far, we have only handled movies by ID. It is time to create a mapping
-from movie IDs to titles so that we can surface the titles.
-
-
-```python
-movie_id_to_movie_title = dict(zip(movies_df["MovieID"], movies_df["Title"]))
-movie_id_to_movie_title[0] = ""  # ID 0 is the padding value, not a real movie.
-```
-
-We then simply use the Keras `model.predict()` method. Under the hood, it calls
-the `BruteForceRetrieval` layer to perform the actual retrieval.
-
-Note that this model can retrieve movies already watched by the user. We could
-easily add logic to remove them if that is desirable; a sketch of such
-filtering is shown after the output below.
-
-
-```python
-print("\n==> Movies the user has watched:")
-movie_sequence = test_ds.unbatch().take(1)
-for element in movie_sequence:
-    for movie_id in element[0][:-1]:
-        print(movie_id_to_movie_title[movie_id.numpy()], end=", ")
-    print(movie_id_to_movie_title[element[0][-1].numpy()])
-
-predictions = model.predict(movie_sequence.batch(1))
-predictions = keras.ops.convert_to_numpy(predictions["predictions"])
-
-print("\n==> Recommended movies for the above sequence:")
-for movie_id in predictions[0]:
-    print(movie_id_to_movie_title[movie_id])
-```
-
-
-<div class="k-default-codeblock">
-``` -==> Movies the user has watched: -10 Things I Hate About You (1999), American Beauty (1999), Bachelor, The (1999), Austin Powers: The Spy Who Shagged Me (1999), Arachnophobia (1990), Big Daddy (1999), Bone Collector, The (1999), Bug's Life, A (1998), Bowfinger (1999), Dead Calm (1989) - -``` -
-
-<div class="k-default-codeblock">
-``` -==> Recommended movies for the above sequence: -Creepshow (1982) -Bringing Out the Dead (1999) -Civil Action, A (1998) -Doors, The (1991) -Cruel Intentions (1999) -Brokedown Palace (1999) -Dead Calm (1989) -Condorman (1981) -Clan of the Cave Bear, The (1986) -Clerks (1994) - -/usr/local/lib/python3.11/dist-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset. - self._interrupted_warning() - -``` -
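-
-As noted above, the retrieved list can include movies the user has already
-watched. A minimal sketch of filtering them out (an illustrative addition,
-not part of the original example; it reuses `element`, `predictions`, and
-`movie_id_to_movie_title` from the cell above):
-
-
-```python
-# Collect the IDs the user has already watched (0 is the padding value).
-watched = {int(movie_id) for movie_id in element[0].numpy() if movie_id != 0}
-
-print("\n==> Recommended movies, excluding already-watched ones:")
-for movie_id in predictions[0]:
-    if int(movie_id) not in watched:
-        print(movie_id_to_movie_title[int(movie_id)])
-```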
\ No newline at end of file