Skip to content

Commit

Permalink
🎨 ore1
Browse files Browse the repository at this point in the history
  • Loading branch information
ferzcam committed Oct 24, 2024
1 parent cec2a80 commit d6bce1e
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ def graph_path(self):
return self._graph_path

if "foodon" in self.use_case:
# graph_name = f"{self.use_case}-merged.train.cat.s1_filtered.edgelist"
graph_name = f"{self.use_case}-merged.train.cat.transitive_filtered.edgelist"
graph_name = f"{self.use_case}-merged.train.cat.s1_filtered.edgelist"
# graph_name = f"{self.use_case}-merged.train.cat.transitive_filtered.edgelist"
elif "go" in self.use_case:
graph_name = f"{self.use_case}.train.cat.s1_filtered.edgelist"
# graph_name = f"{self.use_case}.train.cat.s1_filtered.edgelist"
graph_name = f"{self.use_case}.train.cat.transitive_filtered.edgelist"
elif "ore1" in self.use_case:
# graph_name = f"ORE1.cat.edgelist"
graph_name = f"_train_ORE1_wrapped.cat.s1.edgelist"
# graph_name = f"_train_ORE1_wrapped.cat.transitive_filtered.edgelist"
graph_name = f"_train_ORE1_wrapped.cat.s1.transitive_filtered.edgelist"
else:
raise ValueError(f"Unknown use case {self.use_case}")

Expand Down Expand Up @@ -439,12 +441,12 @@ def train(self, wandb_logger):

tolerance = 0
best_loss = float("inf")
best_mr = float("inf")
best_mr = 10000000
best_mrr = 0
ont_classes_idxs = th.tensor(list(self.ontology_classes_idxs), dtype=th.long,
device=self.device)

for epoch in trange(self.epochs, desc=f"Training..."):
for epoch in trange(self.epochs, desc=f"Training. Best MRR: {best_mrr:.6f}. Best MR: {int(best_mr)}"):
# logging.info(f"Epoch: {epoch+1}")
self.model.train()

Expand Down

0 comments on commit d6bce1e

Please sign in to comment.