From df93cfeb2ee1d5e7b7ca09ce3b24be8f81d48f5e Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Thu, 9 Jan 2025 11:44:44 +0800 Subject: [PATCH] Fix/fallback to 1 when weights doesn't exist & create "tensor" if it doesn't exist (#159) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5876ced..28c86ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "5.6.4" +version = "5.6.5" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 93aff68..a18e703 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -1541,6 +1541,13 @@ def moving_average(data, window_size=365 // 20): def evaluate(self, save_to_file=True): my_collection = Collection(DEFAULT_PARAMETER, self.float_delta_t) + if "tensor" not in self.dataset.columns: + self.dataset["tensor"] = self.dataset.progress_apply( + lambda x: lineToTensor( + list(zip([x["t_history"]], [x["r_history"]]))[0] + ), + axis=1, + ) stabilities, difficulties = my_collection.batch_predict(self.dataset) self.dataset["stability"] = stabilities self.dataset["difficulty"] = difficulties @@ -1551,6 +1558,8 @@ def evaluate(self, save_to_file=True): lambda row: -np.log(row["p"]) if row["y"] == 1 else -np.log(1 - row["p"]), axis=1, ) + if "weights" not in self.dataset.columns: + self.dataset["weights"] = 1 self.dataset["log_loss"] = ( self.dataset["log_loss"] * self.dataset["weights"]