Skip to content

Commit

Permalink
Merge pull request #115 from alan-turing-institute/add_video_gemini
Browse files Browse the repository at this point in the history
Add file upload to support multimodal Gemini API
fedenanni authored Mar 3, 2025

Verified

This commit was signed with the committer’s verified signature.
2 parents 20474de + e748e87 commit 3d1dc94
Showing 9 changed files with 275 additions and 12 deletions.
3 changes: 2 additions & 1 deletion examples/gemini/README.md
Original file line number Diff line number Diff line change
@@ -11,7 +11,8 @@ prompto_run_experiment --file data/input/gemini-example.jsonl --max-queries 30

## Multimodal prompting

Multimodal prompting is available with the Gemini API. We provide an example notebook in the [Multimodal prompting with Vertex AI notebook](./gemini-multimodal.ipynb) and example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl). You can run it with the following command:
Multimodal prompting is available with the Gemini API. To use it, you first need to upload your files to dedicated cloud storage using the [File API](https://ai.google.dev/api/files#v1beta.files). To help with this step, we provide a [notebook](./gemini-upload.ipynb) that takes your multimodal prompts as input and adds a corresponding `uploaded_filename` key/value pair to each of your `media` elements. You can test this with the example experiment file in [data/input/gemini-multimodal-example.jsonl](https://github.com/alan-turing-institute/prompto/blob/main/examples/gemini/data/input/gemini-multimodal-example.jsonl).
Then, we provide an example in the [multimodal prompting notebook](./gemini-multimodal.ipynb). You can run it with the following command:
```bash
prompto_run_experiment --file data/input/gemini-multimodal-example.jsonl --max-queries 30
```
13 changes: 10 additions & 3 deletions examples/gemini/data/input/gemini-multimodal-example.jsonl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.wav"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 6, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.mp3"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 7, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 8, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test2.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}}
{"id": 9, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}}
{"id": 10, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}}
Binary file added examples/gemini/data/media/test.mp3
Binary file not shown.
Binary file added examples/gemini/data/media/test.mp4
Binary file not shown.
Binary file added examples/gemini/data/media/test.wav
Binary file not shown.
Binary file added examples/gemini/data/media/test2.mp4
Binary file not shown.
206 changes: 206 additions & 0 deletions examples/gemini/gemini-upload.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Uploading media to Gemini\n",
"\n",
"This notebook processes an experiment file and associates each media element with the ID of the file uploaded via the Files API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import time\n",
"import tqdm\n",
"import base64\n",
"import hashlib\n",
"import google.generativeai as genai"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set the location of the experiment and media\n",
"\n",
"experiment_location = \"data/input\"\n",
"filename = \"gemini-multimodal-example.jsonl\"\n",
"media_location = \"data/media\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the GEMINI_API_KEY from the environment\n",
"\n",
"GEMINI_API_KEY = os.environ.get(\"GEMINI_API_KEY\")\n",
"if GEMINI_API_KEY is None:\n",
" raise ValueError(\"GEMINI_API_KEY is not set\")\n",
"\n",
"genai.configure(api_key=GEMINI_API_KEY)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_sha256_base64(file_path, chunk_size=8192):\n",
" \"\"\"\n",
" Compute the SHA256 hash of the file at 'file_path' and return it as a base64-encoded string.\n",
" \"\"\"\n",
" hasher = hashlib.sha256()\n",
" with open(file_path, \"rb\") as f:\n",
" for chunk in iter(lambda: f.read(chunk_size), b\"\"):\n",
" hasher.update(chunk)\n",
" return base64.b64encode(hasher.digest()).decode(\"utf-8\")\n",
"\n",
"\n",
"def remote_file_hash_base64(remote_file):\n",
" \"\"\"\n",
" Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object)\n",
" to a base64-encoded string.\n",
" \"\"\"\n",
" hex_str = remote_file.sha256_hash.decode(\"utf-8\")\n",
" raw_bytes = bytes.fromhex(hex_str)\n",
" return base64.b64encode(raw_bytes).decode(\"utf-8\")\n",
"\n",
"\n",
"def wait_for_processing(file_obj, poll_interval=10):\n",
" \"\"\"\n",
" Poll until the file is no longer in the 'PROCESSING' state.\n",
" Returns the updated file object.\n",
" \"\"\"\n",
" while file_obj.state.name == \"PROCESSING\":\n",
" print(\"Waiting for file to be processed...\")\n",
" time.sleep(poll_interval)\n",
" file_obj = genai.get_file(file_obj.name)\n",
" return file_obj\n",
"\n",
"\n",
"def upload(file_path, already_uploaded_files):\n",
" \"\"\"\n",
" Upload the file at 'file_path' if it hasn't been uploaded yet.\n",
" If a file with the same SHA256 (base64-encoded) hash exists, returns its name.\n",
" Otherwise, uploads the file, waits for it to be processed,\n",
" and returns the new file's name. Raises a ValueError if processing fails.\n",
" \"\"\"\n",
" local_hash = compute_sha256_base64(file_path)\n",
"\n",
" if local_hash in already_uploaded_files:\n",
" return already_uploaded_files[local_hash], already_uploaded_files\n",
"\n",
" # Upload the file if it hasn't been found.\n",
" file_obj = genai.upload_file(path=file_path)\n",
" file_obj = wait_for_processing(file_obj)\n",
"\n",
" if file_obj.state.name == \"FAILED\":\n",
" raise ValueError(\"File processing failed\")\n",
" already_uploaded_files[local_hash] = file_obj.name\n",
" return already_uploaded_files[local_hash], already_uploaded_files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Retrieve already uploaded files\n",
"\n",
"uploaded_files = {\n",
" remote_file_hash_base64(remote_file): remote_file.name\n",
" for remote_file in genai.list_files()\n",
"}\n",
"print(f\"Found {len(uploaded_files)} files already uploaded\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"files_to_upload = set()\n",
"experiment_path = f\"{experiment_location}/{filename}\"\n",
"\n",
"# Read and collect media file paths\n",
"with open(experiment_path, \"r\") as f:\n",
" lines = f.readlines()\n",
"\n",
"data_list = []\n",
"\n",
"for line in lines:\n",
" data = json.loads(line)\n",
" data_list.append(data)\n",
"\n",
" if not isinstance(data.get(\"prompt\"), list):\n",
" continue\n",
"\n",
" files_to_upload.update(\n",
" f'{media_location}/{el[\"media\"]}'\n",
" for prompt in data[\"prompt\"]\n",
" for part in prompt.get(\"parts\", [])\n",
" if isinstance(el := part, dict) and \"media\" in el\n",
" )\n",
"\n",
"# Upload files and store mappings\n",
"genai_files = {}\n",
"for file_path in tqdm.tqdm(files_to_upload):\n",
" uploaded_filename, uploaded_files = upload(file_path, uploaded_files)\n",
" genai_files[file_path] = uploaded_filename\n",
"\n",
"# Modify data to include uploaded filenames\n",
"for data in data_list:\n",
" if isinstance(data.get(\"prompt\"), list):\n",
" for prompt in data[\"prompt\"]:\n",
" for part in prompt.get(\"parts\", []):\n",
" if isinstance(part, dict) and \"media\" in part:\n",
" file_path = f'{media_location}/{part[\"media\"]}'\n",
" if file_path in genai_files:\n",
" part[\"uploaded_filename\"] = genai_files[file_path]\n",
" else:\n",
" print(f\"Failed to find {file_path} in genai_files\")\n",
"\n",
"# Write modified data back to the JSONL file\n",
"with open(experiment_path, \"w\") as f:\n",
" for data in data_list:\n",
" f.write(json.dumps(data) + \"\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
15 changes: 7 additions & 8 deletions src/prompto/apis/gemini/gemini_utils.py
Original file line number Diff line number Diff line change
@@ -31,29 +31,28 @@ def parse_parts_value(part: dict | str, media_folder: str) -> any:

# read multimedia type
type = part.get("type")
uploaded_filename = part.get("uploaded_filename")
if type is None:
raise ValueError("Multimedia type is not specified")
# read file location
media = part.get("media")
if media is None:
raise ValueError("File location is not specified")

# create Part object based on multimedia type
if type == "text":
return media
else:
if type == "image":
media_file_path = os.path.join(media_folder, media)
return PIL.Image.open(media_file_path)
elif type == "file":
if uploaded_filename is None:
raise ValueError(
f"File {media} not uploaded. Please upload the file first."
)
else:
try:
return get_file(name=media)
return get_file(name=uploaded_filename)
except Exception as err:
raise ValueError(
f"Failed to get file: {media} due to error: {type(err).__name__} - {err}"
)
else:
raise ValueError(f"Unsupported multimedia type: {type}")


def parse_parts(parts: list[dict | str] | dict | str, media_folder: str) -> list[any]:
50 changes: 50 additions & 0 deletions tests/apis/gemini/test_gemini_image_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import os
import tempfile

from PIL import Image

from prompto.apis.gemini.gemini_utils import parse_parts_value


def test_parse_parts_value_text():
    """A plain string part is passed through unchanged."""
    media_dir = "media"
    text_part = "text"

    returned = parse_parts_value(text_part, media_dir)

    assert returned == text_part


def test_parse_parts_value_image():
    """A bare string is returned as-is, even when it spells a media type.

    NOTE(review): this passes the string "image", not an image dict, so it
    exercises the same string pass-through branch as the text test — confirm
    this is intentional.
    """
    media_dir = "media"
    string_part = "image"

    assert parse_parts_value(string_part, media_dir) == string_part


def test_parse_parts_value_image_dict():
    """An image part dict is loaded from media_folder and returned as a PIL image.

    Uses a temporary directory as the media folder so the fixture image is
    always cleaned up, even if an assertion fails (the original cleanup ran
    after the asserts and leaked the file/directory on failure, and
    ``os.rmdir`` would also fail if ``media/`` pre-existed with other files).
    """
    with tempfile.TemporaryDirectory() as media_folder:
        part = {"type": "image", "media": "pantani_giro.jpg"}

        # Create a mock image for parse_parts_value to load
        image_path = os.path.join(media_folder, "pantani_giro.jpg")
        Image.new("RGB", (100, 100), color="red").save(image_path)

        result = parse_parts_value(part, media_folder)

        # The returned object reflects the image we wrote to disk
        assert result.mode == "RGB"
        assert result.size == (100, 100)
        assert result.filename.endswith("pantani_giro.jpg")


def test_parse_parts_value_video():
    # NOTE(review): the visible branches of parse_parts_value handle dict
    # parts of type "text", "image" and "file" and raise ValueError for any
    # other type, so expecting a "video" dict to be returned unchanged looks
    # inconsistent — TODO confirm against the full implementation of
    # parse_parts_value (only a diff hunk of it is in view here).
    part = {"type": "video", "media": "pantani_giro.mp4"}
    media_folder = "media"

    result = parse_parts_value(part, media_folder)
    assert result == part

0 comments on commit 3d1dc94

Please sign in to comment.