diff --git a/acestep/training/trainer.py b/acestep/training/trainer.py
index fa5b7ea8..1b625881 100644
--- a/acestep/training/trainer.py
+++ b/acestep/training/trainer.py
@@ -135,12 +135,27 @@ def __init__(
 
         # Inject LoRA into the decoder only
         if check_peft_available():
+            # Fix: force tensors out of inference mode before injection
+            for param in model.parameters():
+                param.data = param.data.clone()
+                if param.is_inference():
+                    with torch.no_grad():
+                        param.data = param.data.clone()
+
             self.model, self.lora_info = inject_lora_into_dit(model, lora_config)
             logger.info(f"LoRA injected: {self.lora_info['trainable_params']:,} trainable params")
         else:
             self.model = model
             self.lora_info = {}
             logger.warning("PEFT not available, training without LoRA adapters")
+
+        # Compile the DiT decoder when torch.compile exists and we are on CUDA
+        if hasattr(torch, "compile") and self.device_type == "cuda":
+            logger.info("Compiling DiT decoder...")
+            self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
+            logger.info("torch.compile successful")
+        else:
+            logger.warning("torch.compile unavailable or non-CUDA device; skipping decoder compilation")
 
         # Model config for flow matching
         self.config = model.config
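
A minimal standalone sketch (not part of the patch; variable names are illustrative) of the mechanism the "force tensors out of inference mode" fix relies on: a tensor materialized under torch.inference_mode() is an inference tensor and cannot participate in autograd, but cloning it outside inference mode yields a regular tensor that can be trained.

import torch

with torch.inference_mode():
    w = torch.randn(4, 4)      # inference tensor, e.g. weights loaded for inference

print(w.is_inference())        # True -> requires_grad_/backward would fail later
w = w.clone()                  # clone *outside* inference mode returns a normal tensor
print(w.is_inference())        # False -> safe to hand to LoRA injection / training
w.requires_grad_(True)         # now allowed

The hasattr(torch, "compile") guard keeps the trainer working on PyTorch versions before 2.0, where torch.compile does not exist; on those versions (or on non-CUDA devices) the decoder is simply left uncompiled.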