Skip to content

Commit

Permalink
Add geo_scores_for_taxa endpoint (#28)
Browse files Browse the repository at this point in the history
* add endpoint for fetching obs geo scores in bulk

* additional validation for geo_scores_for_taxa endpoint
  • Loading branch information
pleary authored Sep 19, 2024
1 parent dee8b6e commit 9c4f210
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
26 changes: 26 additions & 0 deletions lib/inat_inferrer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def setup_elevation_dataframe(self):
self.geo_elevation_cells = InatInferrer.add_lat_lng_to_h3_geo_dataframe(
self.geo_elevation_cells
)
self.geo_elevation_cell_indices = {
index: idx for idx, index in enumerate(self.geo_elevation_cells.index)
}

def setup_elevation_dataframe_from_worldclim(self, resolution):
# preventing from processing at too high a resolution
Expand Down Expand Up @@ -422,6 +425,29 @@ def aggregate_results(self, leaf_scores, debug=False,
# InatInferrer.print_aggregated_scores(all_node_scores)
return all_node_scores

def h3_04_geo_results_for_taxon_and_cell(self, taxon_id, lat, lng):
if lat is None or lng is None:
return None
try:
lat_float = float(lat)
lng_float = float(lng)
except ValueError:
return None

try:
taxon = self.taxonomy.df.loc[taxon_id]
except KeyError:
return None

if pd.isna(taxon["leaf_class_id"]) or pd.isna(taxon["geo_threshold"]):
return None

h3_cell = h3.geo_to_h3(lat_float, lng_float, 4)
return float(self.geo_elevation_model.eval_one_class_elevation_from_features(
[self.geo_model_features[self.geo_elevation_cell_indices[h3_cell]]],
int(taxon["leaf_class_id"])
)[0][0]) / taxon["geo_threshold"]

def h3_04_geo_results_for_taxon(self, taxon_id, bounds=[],
thresholded=False, raw_results=False):
if (self.geo_elevation_cells is None) or (self.geo_elevation_model is None):
Expand Down
10 changes: 10 additions & 0 deletions lib/inat_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(self, config):
self.h3_04_taxon_range_comparison_route, methods=["GET"])
self.app.add_url_rule("/h3_04_bounds", "h3_04_bounds",
self.h3_04_bounds_route, methods=["GET"])
self.app.add_url_rule("/geo_scores_for_taxa", "geo_scores_for_taxa",
self.geo_scores_for_taxa_route, methods=["POST"])
self.app.add_url_rule("/build_info", "build_info", self.build_info_route, methods=["GET"])

def setup_inferrer(self, config):
Expand Down Expand Up @@ -86,6 +88,14 @@ def build_info_route(self):
"build_date": os.getenv("BUILD_DATE", "")
}

def geo_scores_for_taxa_route(self):
return {
obs["id"]: self.inferrer.h3_04_geo_results_for_taxon_and_cell(
obs["taxon_id"], obs["lat"], obs["lng"]
)
for obs in request.json["observations"]
}

def index_route(self):
form = ImageForm()
if "observation_id" in request.args:
Expand Down

0 comments on commit 9c4f210

Please sign in to comment.