diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py
index 84924e7..ed0a3b7 100644
--- a/examples/run_music_generation.py
+++ b/examples/run_music_generation.py
@@ -48,6 +48,8 @@ def parse_args():
     parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16")
     parser.add_argument("--codec_dtype", type=str2dtype, default="float32")
     parser.add_argument("--lazy_load", type=str2bool, default=False)
+    parser.add_argument("--seed", type=int, default=None,
+                        help="Random seed for reproducibility (default: random)")
     return parser.parse_args()
 
 
@@ -67,7 +69,7 @@ def parse_args():
     lazy_load=args.lazy_load,
 )
 with torch.no_grad():
-    pipe(
+    result = pipe(
         {
             "lyrics": args.lyrics,
             "tags": args.tags,
@@ -77,5 +79,7 @@
         topk=args.topk,
         temperature=args.temperature,
         cfg_scale=args.cfg_scale,
+        seed=args.seed,
     )
 print(f"Generated music saved to {args.save_path}")
+print(f"Seed used: {result['seed']}")
diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py
index c9111ff..5e9a57f 100644
--- a/src/heartlib/pipelines/music_generation.py
+++ b/src/heartlib/pipelines/music_generation.py
@@ -187,6 +187,7 @@ def _sanitize_parameters(self, **kwargs):
             "temperature": kwargs.get("temperature", 1.0),
             "topk": kwargs.get("topk", 50),
             "cfg_scale": kwargs.get("cfg_scale", 1.5),
+            "seed": kwargs.get("seed", None),
         }
         postprocess_kwargs = {
             "save_path": kwargs.get("save_path", "output.mp3"),
@@ -271,7 +272,15 @@ def _forward(
         temperature: float,
         topk: int,
         cfg_scale: float,
+        seed: Optional[int] = None,
     ):
+        # Set seed for reproducibility
+        if seed is None:
+            seed = torch.randint(0, 2**32, (1,)).item()
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+
         prompt_tokens = model_inputs["tokens"].to(self.mula_device)
         prompt_tokens_mask = model_inputs["tokens_mask"].to(self.mula_device)
         continuous_segment = model_inputs["muq_embed"].to(self.mula_device)
@@ -333,13 +342,15 @@ def _pad_audio_token(token: torch.Tensor):
             frames.append(curr_token[0:1,])
         frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
         self._unload()
-        return {"frames": frames}
+        return {"frames": frames, "seed": seed}
 
     def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
         frames = model_outputs["frames"].to(self.codec_device)
+        seed = model_outputs.get("seed")
         wav = self.codec.detokenize(frames)
         self._unload()
         torchaudio.save(save_path, wav.to(torch.float32).cpu(), 48000)
+        return {"save_path": save_path, "seed": seed}
 
     def __call__(self, inputs: Dict[str, Any], **kwargs):
         preprocess_kwargs, forward_kwargs, postprocess_kwargs = (
@@ -347,7 +358,7 @@ def __call__(self, inputs: Dict[str, Any], **kwargs):
         )
         model_inputs = self.preprocess(inputs, **preprocess_kwargs)
         model_outputs = self._forward(model_inputs, **forward_kwargs)
-        self.postprocess(model_outputs, **postprocess_kwargs)
+        return self.postprocess(model_outputs, **postprocess_kwargs)
 
     @classmethod
     def from_pretrained(
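
Usage sketch (not part of the patch): with this plumbing, a caller can recover the seed of an unseeded run from the returned dict and feed it back to replay the same sampling trajectory. A minimal sketch, assuming `pipe` has already been constructed via the pipeline's from_pretrained as in examples/run_music_generation.py; the lyrics/tags strings and output paths below are placeholders.

import torch

with torch.no_grad():
    # First run: seed omitted, so _forward draws one at random and returns it.
    first = pipe(
        {"lyrics": "...", "tags": "pop, upbeat"},
        topk=50,            # these values mirror the _sanitize_parameters defaults
        temperature=1.0,
        cfg_scale=1.5,
        save_path="take1.mp3",
    )

    # Second run: pass the reported seed back in to reproduce the first take.
    second = pipe(
        {"lyrics": "...", "tags": "pop, upbeat"},
        topk=50,
        temperature=1.0,
        cfg_scale=1.5,
        seed=first["seed"],
        save_path="take2.mp3",
    )

assert first["seed"] == second["seed"]

Drawing the fallback seed inside _forward, rather than leaving it unset, is what makes unseeded runs reproducible after the fact: the pipeline always knows, and now reports, the seed it actually used.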