Fix/update sklearn (#92)
L-M-Sherlock authored Mar 18, 2024
1 parent a6c215f commit 9da7847
Showing 2 changed files with 11 additions and 13 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -4,14 +4,14 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
-version = "4.26.3"
+version = "4.26.4"
readme = "README.md"
dependencies = [
    "matplotlib>=3.7.0",
    "numpy>=1.22.4",
    "pandas>=1.5.3",
    "pytz>=2022.7.1",
-    "scikit_learn>=1.2.2",
+    "scikit_learn>=1.4.0",
    "torch>=1.13.1",
    "tqdm>=4.64.1",
    "statsmodels>=0.13.5",
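Note: the scikit_learn floor rises to 1.4.0 because sklearn.metrics.root_mean_squared_error, adopted throughout the optimizer below, first shipped in scikit-learn 1.4.0. A minimal sketch (not part of this commit) of a runtime guard that makes the requirement explicit; the error message is invented for illustration:

try:
    from sklearn.metrics import root_mean_squared_error
except ImportError as e:  # scikit-learn < 1.4.0 lacks this function
    raise ImportError(
        "FSRS-Optimizer requires scikit-learn >= 1.4.0"
    ) from e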
20 changes: 9 additions & 11 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -17,7 +17,7 @@
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import StratifiedGroupKFold, TimeSeriesSplit
-from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
+from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
from scipy.optimize import minimize
from itertools import accumulate
from tqdm.auto import tqdm
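scikit-learn 1.4.0 added root_mean_squared_error and deprecated the squared keyword of mean_squared_error (slated for removal in 1.6), so every mean_squared_error(..., squared=False) call below becomes a direct call to the new function. A minimal sketch of the equivalence, with made-up values:

import numpy as np
from sklearn.metrics import root_mean_squared_error

y_true = np.array([1.0, 0.0, 1.0, 1.0])
y_pred = np.array([0.9, 0.2, 0.8, 0.6])
weight = np.array([3.0, 1.0, 2.0, 4.0])

# New API (scikit-learn >= 1.4):
rmse = root_mean_squared_error(y_true, y_pred, sample_weight=weight)

# Old spelling, deprecated in 1.4:
# rmse = mean_squared_error(y_true, y_pred, sample_weight=weight, squared=False)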
@@ -184,7 +184,7 @@ def __init__(
    ):
        if dataframe.empty:
            raise ValueError("Training data is inadequate.")
-        if sort_by_length > 0:
+        if sort_by_length:
            dataframe = dataframe.sort_values(by=["i"])
        self.x_train = pad_sequence(
            dataframe["tensor"].to_list(), batch_first=True, padding_value=0
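Dropping the > 0 comparison changes nothing for the boolean and non-negative integer values sort_by_length takes (True > 0 is True in Python); plain truthiness is simply the idiomatic test. A one-line check of the equivalence:

# Same outcome for every value the flag can reasonably take:
assert all(bool(v) == (v > 0) for v in (True, False, 1, 0))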
@@ -624,7 +624,7 @@ def create_time_series(
        ).to_julian_date()
        df.drop_duplicates(["card_id", "real_days"], keep="first", inplace=True)
        df["delta_t"] = df.real_days.diff()
-        df["delta_t"].fillna(0, inplace=True)
+        df.fillna({"delta_t": 0}, inplace=True)
        df["i"] = df.groupby("card_id").cumcount() + 1
        df.loc[df["i"] == 1, "delta_t"] = 0
        if df.empty:
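Calling fillna with inplace=True on a selected column is chained assignment: pandas 2.1+ emits a FutureWarning for it, and under copy-on-write the fill can land on a temporary copy instead of df. Passing a column-to-value mapping to DataFrame.fillna updates the frame itself. A small sketch with invented data:

import pandas as pd

df = pd.DataFrame({"card_id": [1, 1, 2], "delta_t": [None, 1.5, None]})

# Deprecated chained pattern:
# df["delta_t"].fillna(0, inplace=True)

# Pattern adopted here: fill only the delta_t column, in place.
df.fillna({"delta_t": 0}, inplace=True)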
@@ -856,8 +856,8 @@ def loss(stability):
            rating_stability[int(first_rating)] = stability
            rating_count[int(first_rating)] = sum(count)
            predict_recall = power_forgetting_curve(delta_t, *params)
-            rmse = mean_squared_error(
-                recall, predict_recall, sample_weight=count, squared=False
+            rmse = root_mean_squared_error(
+                recall, predict_recall, sample_weight=count
            )

            if verbose:
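With sample_weight=count, the metric is a weighted RMSE: the square root of the weighted mean of squared errors, rmse = sqrt(sum(w_i * (y_i - p_i)^2) / sum(w_i)). Using the names from this hunk, the call should match this sketch (not code from the repository):

import numpy as np

rmse_manual = np.sqrt(np.average((recall - predict_recall) ** 2, weights=count))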
@@ -1555,11 +1555,10 @@ def loss(stability):
        analysis_group.dropna(inplace=True)
        analysis_group.drop_duplicates(subset=[(group_key, "")], inplace=True)
        analysis_group.sort_values(by=[group_key], inplace=True)
-        rmse = mean_squared_error(
+        rmse = root_mean_squared_error(
            analysis_group["true_s"],
            analysis_group["predicted_s"],
            sample_weight=analysis_group["total_count"],
-            squared=False,
        )
        fig = plt.figure()
        ax1 = fig.add_subplot(111)
@@ -1824,11 +1823,10 @@ def get_bin(x, bins=20):
    cross_comparison_group = cross_comparison_record.groupby(by=f"{algoA}_bin").agg(
        {"y": ["mean"], f"{algoB}_B-W": ["mean"], f"R ({algoB})": ["mean", "count"]}
    )
-    universal_metric = mean_squared_error(
+    universal_metric = root_mean_squared_error(
        y_true=cross_comparison_group["y", "mean"],
        y_pred=cross_comparison_group[f"R ({algoB})", "mean"],
        sample_weight=cross_comparison_group[f"R ({algoB})", "count"],
-        squared=False,
    )
    cross_comparison_group[f"R ({algoB})", "percent"] = (
        cross_comparison_group[f"R ({algoB})", "count"]
@@ -1879,8 +1877,8 @@ def rmse_matrix(df):
        .agg({"y": "mean", "p": "mean", "card_id": "count"})
        .reset_index()
    )
-    return mean_squared_error(
-        tmp["y"], tmp["p"], sample_weight=tmp["card_id"], squared=False
+    return root_mean_squared_error(
+        tmp["y"], tmp["p"], sample_weight=tmp["card_id"]
    )


