diff --git a/torchao/prototype/parq/optim/parq.py b/torchao/prototype/parq/optim/parq.py index ade403a87d..43dd5ea149 100644 --- a/torchao/prototype/parq/optim/parq.py +++ b/torchao/prototype/parq/optim/parq.py @@ -47,7 +47,7 @@ def __init__( steepness: float = 10, anneal_center: float = 0.5, ) -> None: - assert anneal_start < anneal_end, "PARQ annealing: start before end." + assert anneal_start <= anneal_end, "PARQ annealing: start before end." assert steepness > 0, "PARQ annealing steepness should be positive." self.anneal_start = anneal_start self.anneal_end = anneal_end