-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Merge pull request #115 from alan-turing-institute/add_video_gemini
Add file upload to support multimodal Gemini API
Showing
9 changed files
with
275 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 10 additions & 3 deletions
13
examples/gemini/data/input/gemini-multimodal-example.jsonl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,10 @@ | ||
{"id": 0, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 1, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this image", {"type": "image", "media": "pantani_giro.jpg"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 2, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": [{"type": "image", "media": "mortadella.jpg"}, "what is this?"]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 3, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["what is in this image?", {"type": "image", "media": "pantani_giro.jpg"}]}, {"role": "model", "parts": "This is image shows a group of cyclists."}, {"role": "user", "parts": "are there any notable cyclists in this image? what are their names?"}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 4, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 5, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.wav"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 6, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe what is happening in this audio", {"type": "audio", "media": "test.mp3"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 7, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 8, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "user", "parts": ["describe this video", {"type": "video", "media": "test2.mp4"}]}], "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 1000}} | ||
{"id": 9, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": "How does technology impact us?", "safety_filter": "none", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} | ||
{"id": 10, "api": "gemini", "model_name": "gemini-1.5-flash", "prompt": [{"role": "system", "parts": "You are a helpful assistant designed to answer questions briefly."}, {"role": "user", "parts": "Hello, I'm Bob and I'm 6 years old"}, {"role": "model", "parts": "Hi Bob, how may I assist you?"}, {"role": "user", "parts": "How old will I be next year?"}], "safety_filter": "most", "parameters": {"candidate_count": 1, "temperature": 1, "max_output_tokens": 100}} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Uploading media to Gemini\n", | ||
"\n", | ||
"This notebook processes an experiment file and associates each media element with the ID of the corresponding file uploaded through the Files API." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import json\n", | ||
"import time\n", | ||
"import tqdm\n", | ||
"import base64\n", | ||
"import hashlib\n", | ||
"import google.generativeai as genai" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Set the location of the experiment and media\n", | ||
"\n", | ||
"experiment_location = \"data/input\"\n", | ||
"filename = \"gemini-multimodal-example.jsonl\"\n", | ||
"media_location = \"data/media\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the GEMINI_API_KEY from the environment\n", | ||
"\n", | ||
"GEMINI_API_KEY = os.environ.get(\"GEMINI_API_KEY\")\n", | ||
"if GEMINI_API_KEY is None:\n", | ||
" raise ValueError(\"GEMINI_API_KEY is not set\")\n", | ||
"\n", | ||
"genai.configure(api_key=GEMINI_API_KEY)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def compute_sha256_base64(file_path, chunk_size=8192):\n", | ||
" \"\"\"\n", | ||
" Compute the SHA256 hash of the file at 'file_path' and return it as a base64-encoded string.\n", | ||
" \"\"\"\n", | ||
" hasher = hashlib.sha256()\n", | ||
" with open(file_path, \"rb\") as f:\n", | ||
" for chunk in iter(lambda: f.read(chunk_size), b\"\"):\n", | ||
" hasher.update(chunk)\n", | ||
" return base64.b64encode(hasher.digest()).decode(\"utf-8\")\n", | ||
"\n", | ||
"\n", | ||
"def remote_file_hash_base64(remote_file):\n", | ||
" \"\"\"\n", | ||
" Convert a remote file's SHA256 hash (stored as a hex-encoded UTF-8 bytes object)\n", | ||
" to a base64-encoded string.\n", | ||
" \"\"\"\n", | ||
" hex_str = remote_file.sha256_hash.decode(\"utf-8\")\n", | ||
" raw_bytes = bytes.fromhex(hex_str)\n", | ||
" return base64.b64encode(raw_bytes).decode(\"utf-8\")\n", | ||
"\n", | ||
"\n", | ||
"def wait_for_processing(file_obj, poll_interval=10):\n", | ||
" \"\"\"\n", | ||
" Poll until the file is no longer in the 'PROCESSING' state.\n", | ||
" Returns the updated file object.\n", | ||
" \"\"\"\n", | ||
" while file_obj.state.name == \"PROCESSING\":\n", | ||
" print(\"Waiting for file to be processed...\")\n", | ||
" time.sleep(poll_interval)\n", | ||
" file_obj = genai.get_file(file_obj.name)\n", | ||
" return file_obj\n", | ||
"\n", | ||
"\n", | ||
"def upload(file_path, already_uploaded_files):\n", | ||
" \"\"\"\n", | ||
" Upload the file at 'file_path' if it hasn't been uploaded yet.\n", | ||
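" If a file with the same SHA256 (base64-encoded) hash is already recorded in\n", | ||
" 'already_uploaded_files', its stored name is reused. Otherwise, uploads the file,\n", | ||
" waits for it to be processed, and records the new file's name in the mapping.\n", | ||
" Returns a (file name, already_uploaded_files) tuple. Raises a ValueError if processing fails.\n", | ||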
" \"\"\"\n", | ||
" local_hash = compute_sha256_base64(file_path)\n", | ||
"\n", | ||
" if local_hash in already_uploaded_files:\n", | ||
" return already_uploaded_files[local_hash], already_uploaded_files\n", | ||
"\n", | ||
" # Upload the file if it hasn't been found.\n", | ||
" file_obj = genai.upload_file(path=file_path)\n", | ||
" file_obj = wait_for_processing(file_obj)\n", | ||
"\n", | ||
" if file_obj.state.name == \"FAILED\":\n", | ||
" raise ValueError(\"File processing failed\")\n", | ||
" already_uploaded_files[local_hash] = file_obj.name\n", | ||
" return already_uploaded_files[local_hash], already_uploaded_files" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Retrieve already uploaded files\n", | ||
"\n", | ||
"uploaded_files = {\n", | ||
" remote_file_hash_base64(remote_file): remote_file.name\n", | ||
" for remote_file in genai.list_files()\n", | ||
"}\n", | ||
"print(f\"Found {len(uploaded_files)} files already uploaded\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"files_to_upload = set()\n", | ||
"experiment_path = f\"{experiment_location}/{filename}\"\n", | ||
"\n", | ||
"# Read and collect media file paths\n", | ||
"with open(experiment_path, \"r\") as f:\n", | ||
" lines = f.readlines()\n", | ||
"\n", | ||
"data_list = []\n", | ||
"\n", | ||
"for line in lines:\n", | ||
" data = json.loads(line)\n", | ||
" data_list.append(data)\n", | ||
"\n", | ||
" if not isinstance(data.get(\"prompt\"), list):\n", | ||
" continue\n", | ||
"\n", | ||
" files_to_upload.update(\n", | ||
" f'{media_location}/{el[\"media\"]}'\n", | ||
" for prompt in data[\"prompt\"]\n", | ||
" for part in prompt.get(\"parts\", [])\n", | ||
" if isinstance(el := part, dict) and \"media\" in el\n", | ||
" )\n", | ||
"\n", | ||
"# Upload files and store mappings\n", | ||
"genai_files = {}\n", | ||
"for file_path in tqdm.tqdm(files_to_upload):\n", | ||
" uploaded_filename, uploaded_files = upload(file_path, uploaded_files)\n", | ||
" genai_files[file_path] = uploaded_filename\n", | ||
"\n", | ||
"# Modify data to include uploaded filenames\n", | ||
"for data in data_list:\n", | ||
" if isinstance(data.get(\"prompt\"), list):\n", | ||
" for prompt in data[\"prompt\"]:\n", | ||
" for part in prompt.get(\"parts\", []):\n", | ||
" if isinstance(part, dict) and \"media\" in part:\n", | ||
" file_path = f'{media_location}/{part[\"media\"]}'\n", | ||
" if file_path in genai_files:\n", | ||
" part[\"uploaded_filename\"] = genai_files[file_path]\n", | ||
" else:\n", | ||
" print(f\"Failed to find {file_path} in genai_files\")\n", | ||
"\n", | ||
"# Write modified data back to the JSONL file\n", | ||
"with open(experiment_path, \"w\") as f:\n", | ||
" for data in data_list:\n", | ||
" f.write(json.dumps(data) + \"\\n\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import os | ||
|
||
from PIL import Image | ||
|
||
from prompto.apis.gemini.gemini_utils import parse_parts_value | ||
|
||
|
||
def test_parse_parts_value_text():
    """A plain string part must come back from parse_parts_value unchanged."""
    text_part = "text"
    # The media folder is passed but no file is created for this case.
    parsed = parse_parts_value(text_part, "media")
    assert parsed == text_part
|
||
|
||
def test_parse_parts_value_image():
    """The bare string "image" is an ordinary string part and is returned as-is."""
    string_part = "image"
    # Only the literal word is passed here, not an image dict.
    parsed = parse_parts_value(string_part, "media")
    assert parsed == string_part
|
||
|
||
def test_parse_parts_value_image_dict():
    """Check that an image dict part is resolved against the media folder.

    A 100x100 red RGB JPEG is written into a throwaway directory and
    ``parse_parts_value`` is called with a ``{"type": "image", "media": ...}``
    part pointing at it. The returned object must report the same mode and
    size, and a ``filename`` ending in the media file name.
    """
    # Function-scope import keeps this block self-contained.
    import tempfile

    media_name = "pantani_giro.jpg"
    part = {"type": "image", "media": media_name}

    # Use an isolated temporary directory rather than ./media: the fixture can
    # no longer overwrite or delete a real file in the working directory, and
    # the directory is removed even when an assertion below fails.
    with tempfile.TemporaryDirectory() as media_folder:
        # Create a mock image
        image_path = os.path.join(media_folder, media_name)
        Image.new("RGB", (100, 100), color="red").save(image_path)

        result = parse_parts_value(part, media_folder)

        try:
            # Assert the result
            assert result.mode == "RGB"
            assert result.size == (100, 100)
            assert result.filename.endswith(media_name)
        finally:
            # NOTE(review): the result is presumably a lazily-loaded PIL image
            # that still holds the file open; release it so the directory can
            # be deleted on platforms that lock open files - confirm.
            close = getattr(result, "close", None)
            if callable(close):
                close()
|
||
|
||
def test_parse_parts_value_video():
    """A video dict part is expected to be returned unchanged.

    No media file is created on disk for this test; the part carries only
    its ``type`` and ``media`` keys.
    """
    video_part = {"type": "video", "media": "pantani_giro.mp4"}
    parsed = parse_parts_value(video_part, "media")
    assert parsed == video_part