diff --git a/torch-neuronx/inference/hf_pretrained_perceiver_multimodal_inference.ipynb b/torch-neuronx/inference/hf_pretrained_perceiver_multimodal_inference.ipynb
index a99e04a..574113c 100644
--- a/torch-neuronx/inference/hf_pretrained_perceiver_multimodal_inference.ipynb
+++ b/torch-neuronx/inference/hf_pretrained_perceiver_multimodal_inference.ipynb
@@ -1,789 +1,798 @@
{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## HuggingFace Multimodal Perceiver Inference on Trn1 / Inf2"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Introduction**\n",
- "\n",
- "This notebook demonstrates how to compile and run the HuggingFace Multimodal Perceiver model to classify and autoencode video inputs on Neuron. The script is loosely based on HuggingFace's official tutorial for running inference on the multimodal perceiver at https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Perceiver_for_Multimodal_Autoencoding.ipynb\n",
- "\n",
- "This notebook can be run on the smallest Inf2 instance `inf2.xlarge`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option on the top of this Jupyter notebook page."
- ]
- },
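- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Optionally, run the sanity-check cell below (it is not part of the original tutorial) to confirm that the selected kernel can import `torch` and `torch_neuronx` before continuing."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Optional sanity check (not part of the original tutorial): these imports fail\n",
- "# with an ImportError if the Neuron PyTorch environment is not active.\n",
- "import torch\n",
- "import torch_neuronx\n",
- "print(torch.__version__)"
- ]
- },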
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Install Dependencies**\n",
- "\n",
- "This tutorial requires the following pip packages to be installed:\n",
- "- `torch-neuronx`\n",
- "- `neuronx-cc`\n",
- "- `transformers==4.30.2`\n",
- "- `opencv-python-headless`\n",
- "- `imageio`\n",
- "- `scipy`\n",
- "- `accelerate`\n",
- "Furthermore, it requires the `ffmpeg` video-audio converter which is used to extract audio from the input videos.\n",
- "\n",
- "`torch-neuronx` and `neuronx-cc` should be installed when you configure your environment following the Inf2 setup guide. The remaining dependencies can be installed below:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%env TOKENIZERS_PARALLELISM=True #Supresses tokenizer warnings making errors easier to detect\n",
- "!pip install transformers==4.30.2 opencv-python-headless==4.8.0.74 imageio scipy accelerate opencv-python==4.8.0.74\n",
- "\n",
- "!wget https://johnvansickle.com/ffmpeg/builds/ffmpeg-git-amd64-static.tar.xz\n",
- "!tar xvf ffmpeg-git-amd64-static.tar.xz\n",
- "!mv ffmpeg-git-*-amd64-static/ffmpeg .\n",
- "!rm -rf ffmpeg-git-*-amd64-static ffmpeg-git-amd64-static.tar.xz"
- ]
- },
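- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "As a quick optional check (not part of the original tutorial), the cell below prints the version banner of the static `ffmpeg` binary downloaded above to confirm it runs from the notebook's working directory:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Optional: confirm the static ffmpeg binary extracted above is runnable\n",
- "!./ffmpeg -version | head -n 1"
- ]
- },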
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Imports**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import base64\n",
- "import os\n",
- "import ssl\n",
- "import re\n",
- "from urllib import request\n",
- "import cv2\n",
- "import imageio\n",
- "import time\n",
- "import random\n",
- "from tqdm import tqdm\n",
- "import numpy as np\n",
- "import scipy.io.wavfile\n",
- "from IPython.display import HTML\n",
- "\n",
- "from typing import Optional, Tuple, Union\n",
- "from transformers import PerceiverForMultimodalAutoencoding\n",
- "from transformers.modeling_outputs import BaseModelOutputWithCrossAttentions\n",
- "from transformers.models.perceiver.modeling_perceiver import PerceiverBasicDecoder, PerceiverClassifierOutput\n",
- "from transformers.models.perceiver.modeling_perceiver import restructure\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch_neuronx"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Video Preprocessing Utilities**\n",
- "\n",
- "The following code cell defines some useful functions for fetching, preprocessing and visualizing the input video. Most of these are taken directly from HuggingFace's official multimodal perceiver tutorial at https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Perceiver/Perceiver_for_Multimodal_Autoencoding.ipynb. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Utilities to fetch videos from UCF101 dataset\n",
- "UCF_ROOT = 'https://www.crcv.ucf.edu/THUMOS14/UCF101/UCF101/'\n",
- "_VIDEO_LIST = None\n",
- "_CACHE_DIR_NAME = \"video_cache\"\n",
- "\n",
- "os.makedirs(\"video_cache\", exist_ok=True)\n",
- "# As of July 2020, crcv.ucf.edu doesn't use a certificate accepted by the\n",
- "# default Colab environment anymore.\n",
- "unverified_context = ssl._create_unverified_context()\n",
- "\n",
- "def list_ucf_videos():\n",
- " \"\"\"Lists videos available in UCF101 dataset.\"\"\"\n",
- " global _VIDEO_LIST\n",
- " if not _VIDEO_LIST:\n",
- " index = request.urlopen(UCF_ROOT, context=unverified_context).read().decode('utf-8')\n",
- " videos = re.findall('(v_[\\w_]+\\.avi)', index)\n",
- " _VIDEO_LIST = sorted(set(videos))\n",
- " return list(_VIDEO_LIST)\n",
- "\n",
- "def fetch_ucf_video(video):\n",
- " \"\"\"Fetchs a video and cache into local filesystem.\"\"\"\n",
- " cache_path = os.path.join(_CACHE_DIR_NAME, video)\n",
- " if not os.path.exists(cache_path):\n",
- " urlpath = request.urljoin(UCF_ROOT, video)\n",
- " print('Fetching %s => %s' % (urlpath, cache_path))\n",
- " data = request.urlopen(urlpath, context=unverified_context).read()\n",
- " open(cache_path, \"wb\").write(data)\n",
- " return cache_path\n",
- "\n",
- "# Utilities to open video files using CV2\n",
- "def crop_center_square(frame):\n",
- " y, x = frame.shape[0:2]\n",
- " min_dim = min(y, x)\n",
- " start_x = (x // 2) - (min_dim // 2)\n",
- " start_y = (y // 2) - (min_dim // 2)\n",
- " return frame[start_y:start_y+min_dim,start_x:start_x+min_dim]\n",
- "\n",
- "def load_video(path, max_frames=0, resize=(224, 224)):\n",
- " cap = cv2.VideoCapture(path)\n",
- " frames = []\n",
- " try:\n",
- " while True:\n",
- " ret, frame = cap.read()\n",
- " if not ret:\n",
- " break\n",
- " frame = crop_center_square(frame)\n",
- " frame = cv2.resize(frame, resize)\n",
- " frame = frame[:, :, [2, 1, 0]]\n",
- " frames.append(frame)\n",
- "\n",
- " if len(frames) == max_frames:\n",
- " break\n",
- " finally:\n",
- " cap.release()\n",
- " return np.array(frames) / 255.0\n",
- "\n",
- "def to_gif(images):\n",
- " converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)\n",
- " imageio.mimsave('./animation.gif', converted_images, duration=40, loop=100)\n",
- " with open('./animation.gif', 'rb') as f:\n",
- " gif_64 = base64.b64encode(f.read()).decode('utf-8')\n",
- " return HTML('
' % gif_64)\n",
- "\n",
- "def play_audio(data, sample_rate=48000):\n",
- " scipy.io.wavfile.write('tmp_audio.wav', sample_rate, data)\n",
- "\n",
- " with open('./tmp_audio.wav', 'rb') as f:\n",
- " audio_64 = base64.b64encode(f.read()).decode('utf-8')\n",
- " return HTML('