diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 5126faec8..49bb3583d 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -7,6 +7,7 @@ __Bug Fixes__:
#1426 Sequential.eval() does not put model into eval mode
`torch.optim.lr_scheduler.LinearLR` `end_factor` default has been corrected, is now 1.0.
+`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.
# NuGet Version 0.105.0
diff --git a/src/TorchSharp/Optimizers/LRScheduler.cs b/src/TorchSharp/Optimizers/LRScheduler.cs
index dac28c25d..9305467d4 100644
--- a/src/TorchSharp/Optimizers/LRScheduler.cs
+++ b/src/TorchSharp/Optimizers/LRScheduler.cs
@@ -325,7 +325,7 @@ public class PolynomialLR : LRScheduler
/// The index of last epoch. Default: -1.
/// If true, prints a message to stdout for each update. Default: false.
/// A scheduler
- public PolynomialLR(Optimizer optimizer, int total_iters = 5, int power = 1, int last_epoch = -1, bool verbose = false) : base(optimizer, last_epoch, verbose)
+ public PolynomialLR(Optimizer optimizer, int total_iters = 5, double power = 1.0, int last_epoch = -1, bool verbose = false) : base(optimizer, last_epoch, verbose)
{
if (optimizer == null) throw new ArgumentNullException("optimizer");
_power = power;
@@ -359,7 +359,7 @@ protected override IEnumerable get_closed_form_lr()
}
private double _total_iters;
- private int _power;
+ private double _power;
}
///
@@ -1306,7 +1306,7 @@ public static LRScheduler MultiStepLR(Optimizer optimizer, IList milestones
/// The index of last epoch. Default: -1.
/// If true, prints a message to stdout for each update. Default: false.
/// A scheduler
- public static LRScheduler PolynomialLR(Optimizer optimizer, int total_iters = 5, int power = 1, int last_epoch = -1, bool verbose = false)
+ public static LRScheduler PolynomialLR(Optimizer optimizer, int total_iters = 5, double power = 1.0, int last_epoch = -1, bool verbose = false)
{
return new impl.PolynomialLR(optimizer, total_iters, power, last_epoch, verbose);
}
diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs
index 0493b604e..4055a5cce 100644
--- a/test/TorchSharpTest/TestTorchTensorBugs.cs
+++ b/test/TorchSharpTest/TestTorchTensorBugs.cs
@@ -873,7 +873,7 @@ public void ValidatePolynomialLR()
double learning_rate = 0.1;
var optimizer = torch.optim.SGD(seq.parameters(), learning_rate);
- var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 1);
+ var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 1.0);
optimizer.zero_grad();
optimizer.step();
@@ -907,7 +907,7 @@ public void ValidatePolynomialLR()
double learning_rate = 0.1;
var optimizer = torch.optim.SGD(seq.parameters(), learning_rate);
- var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 2);
+ var scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 10, 2.0);
optimizer.zero_grad();
optimizer.step();
diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs
index 2d3f02bca..3169b0f57 100644
--- a/test/TorchSharpTest/TestTraining.cs
+++ b/test/TorchSharpTest/TestTraining.cs
@@ -1654,7 +1654,7 @@ public void TrainingSGDSequentialLRWithAllClosedFormSchedulers()
var scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, 2);
var scheduler3 = torch.optim.lr_scheduler.MultiStepLR(optimizer, new[] { 2, 4 });
var scheduler4 = torch.optim.lr_scheduler.ExponentialLR(optimizer);
- var scheduler5 = torch.optim.lr_scheduler.PolynomialLR(optimizer, power: 2);
+ var scheduler5 = torch.optim.lr_scheduler.PolynomialLR(optimizer, power: 2.0);
var scheduler6 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5, 0.1);
var scheduler7 = torch.optim.lr_scheduler.LinearLR(optimizer, end_factor: 0.75);
var scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, new[] { scheduler0, scheduler1, scheduler2, scheduler3, scheduler4, scheduler5, scheduler6, scheduler7}, new[] { 5, 5, 5, 5, 5, 5, 5 });