Skip to content

Commit

Permalink
add option to test against the Gemini API
Browse files Browse the repository at this point in the history
  • Loading branch information
pleary committed Sep 10, 2024
1 parent dee8b6e commit 8c915ef
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
1 change: 1 addition & 0 deletions config.yml.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
app_secret: "somesecret"
gemini_api_key: "gemini_api_key"
models:
- name: "ModelGenerationName"
vision_model_path: "models/.../vision_model.h5"
Expand Down
16 changes: 11 additions & 5 deletions lib/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@

class TestObservation:

def __init__(self, row):
row["taxon_ancestry"] = row["taxon_ancestry"].split("/")
row["taxon_ancestry"] = list(map(int, row["taxon_ancestry"]))
# remove life
row["taxon_ancestry"].pop(0)
def __init__(self, row, gemini_attributes=False):
if not gemini_attributes:
row["taxon_ancestry"] = row["taxon_ancestry"].split("/")
row["taxon_ancestry"] = list(map(int, row["taxon_ancestry"]))
# remove life
row["taxon_ancestry"].pop(0)
for key in row:
setattr(self, key, row[key])
if gemini_attributes:
self.gemini_response_text = None
self.gemini_error = None
return

self.inferrer_results = None
self.summarized_results = {}
85 changes: 78 additions & 7 deletions lib/vision_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import aiofiles.os
import re
import traceback
import google.generativeai as gemini
from datetime import datetime
from PIL import Image
from lib.test_observation import TestObservation
Expand All @@ -27,6 +28,11 @@ def __init__(self, config, **args):
self.start_timestamp = currentDatetime.strftime("%Y%m%d")
self.set_run_hash(config)

if self.cmd_args["gemini"]:
gemini.configure(api_key=config["gemini_api_key"])
self.gemini_model = gemini.GenerativeModel(model_name="gemini-1.5-pro")
return

print("Models:")
for inferrer_index, model_config in enumerate(config["models"]):
print(json.dumps(model_config, indent=4))
Expand Down Expand Up @@ -64,14 +70,20 @@ async def run_async(self):
path = os.path.join(self.cmd_args["data_dir"], file)
print(f"\nProcessing {file}")
await self.test_observations_at_path(path, label)
self.display_and_save_results(label)
if self.cmd_args["gemini"]:
self.display_and_save_results_gemini(label)
else:
self.display_and_save_results(label)
else:
print(f"\nProcessing {self.cmd_args['path']}")
await self.test_observations_at_path(self.cmd_args["path"], self.cmd_args["label"])
self.display_and_save_results(self.cmd_args["label"])
if self.cmd_args["gemini"]:
self.display_and_save_results_gemini(self.cmd_args["label"])
else:
self.display_and_save_results(self.cmd_args["label"])

async def test_observations_at_path(self, path, label):
N_WORKERS = 5
N_WORKERS = 1 if self.cmd_args["gemini"] else 5
self.limit = self.cmd_args["limit"] or 100
target_observation_id = self.cmd_args["observation_id"]
self.start_time = time.time()
Expand All @@ -96,7 +108,7 @@ async def test_observations_at_path(self, path, label):
for index, observation in df.iterrows():
if target_observation_id and observation.observation_id != target_observation_id:
continue
obs = TestObservation(observation.to_dict())
obs = TestObservation(observation.to_dict(), gemini_attributes=self.cmd_args["gemini"])
self.test_observations[obs.observation_id] = obs
self.queue.put_nowait(obs.observation_id)

Expand All @@ -113,9 +125,14 @@ async def worker_task(self):
if self.processed_counter >= self.limit:
continue
observation = self.test_observations[observation_id]
await self.test_observation_async(observation)
if observation.inferrer_results is None:
continue
if self.cmd_args["gemini"]:
await self.test_observation_with_gemini_async(observation)
if observation.gemini_response_text is None:
continue
else:
await self.test_observation_async(observation)
if observation.inferrer_results is None:
continue
self.processed_counter += 1
self.report_progress()

Expand All @@ -127,6 +144,60 @@ async def worker_task(self):
finally:
self.queue.task_done()

async def test_observation_with_gemini_async(self, observation):
cache_path = await self.download_photo_async(observation.photo_url)

# due to asynchronous processing, the requested limit of observations to test
# has been reached, so do not test this observation. The rest of this method
# will be processed synchronously, so no need to check this again this method
if self.processed_counter >= self.limit:
return

if cache_path is None \
or not os.path.exists(cache_path) \
or observation.lat == "" \
or observation.lng == "":
return

try:
print(f"Uploading {cache_path}, for {observation.observation_id}")
sample_file = gemini.upload_file(path=cache_path)
# Prompt the model with text and the previously uploaded image.
response = self.gemini_model.generate_content([
sample_file,
"Return only the binomial species name of the organism in the photo. "
f"The photo was taken on {observation.observed_on} "
f"at latitude {observation.lat} and longitude {observation.lng}"
])
observation.gemini_response_text = response.text.strip()
print({observation.observation_id: observation.gemini_response_text})
except Exception as e:
observation.gemini_error = True
print(f"Error scoring observation {observation.observation_id}")
print(response)
print(e)
print(traceback.format_exc())
return

def display_and_save_results_gemini(self, label):
scored_observations = list(filter(
lambda observation: observation.gemini_response_text is not None or observation.gemini_error is not None,
self.test_observations.values()
))
if len(scored_observations) == 0:
return
all_obs_responses = []
for obs in scored_observations:
all_obs_responses.append({
"observation_id": obs.observation_id,
"taxon_id": obs.taxon_id,
"taxon_ancestry": obs.taxon_ancestry,
"gemini_response": obs.gemini_response_text,
"gemini_error": obs.gemini_error
})
export_path = self.export_path("gemini", label=label)
pd.DataFrame(all_obs_responses).to_csv(export_path)

def display_and_save_results(self, label):
scored_observations = list(filter(
lambda observation: len(observation.summarized_results) > 0,
Expand Down
6 changes: 6 additions & 0 deletions test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
@click.option("--observation_id", type=str, help="Single observation UUID to test.")
@click.option("--filter-iconic/--no-filter-iconic", show_default=True, default=True,
help="Use iconic taxon for filtering.")
@click.option("--gemini", is_flag=True, show_default=True, default=False,
help="Output debug messages.")
@click.option("--debug", is_flag=True, show_default=True, default=False,
help="Output debug messages.")
def test(**args):
if not args["path"] and not args["data_dir"]:
print("\nYou must specify either a `--path` or a `--data_dir` option\n")
exit()

if args["gemini"] and ("gemini_api_key" not in CONFIG or not CONFIG["gemini_api_key"]):
print("\nconfig.yml does not configure a `gemini_api_key`\n")
exit()

# some libraries are slow to import, so wait until command is validated and properly invoked
from lib.vision_testing import VisionTesting
print("\nArguments:")
Expand Down

0 comments on commit 8c915ef

Please sign in to comment.