From 4b843422a347e0e02f74adb166426648b8d10144 Mon Sep 17 00:00:00 2001 From: odashi Date: Mon, 30 Jun 2025 12:43:25 +0900 Subject: [PATCH 1/4] add --source-weights option --- .../v4-corpus-ratio-abci/merge/merge.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py index a63c6d00..35985858 100644 --- a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py +++ b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py @@ -31,6 +31,16 @@ def parse_args(): " All models should be in the same format and have compatible parameters." ), ) + p.add_argument( + "--source-weights", + type=float, + nargs="+", + default=None, + help=( + "Weights for each source model. If not provided, " + "all models will be treated equally (weight = 1)." + ), + ) p.add_argument( "--output-model", type=pathlib.Path, @@ -71,8 +81,23 @@ def main(): logging.info(f"Source models: {args.source_models}") + # Check weights + if args.source_weights is None: + logging.info("No source weights provided, treating all models equally.") + args.source_weights = [1.0] * model_count + else + if len(args.source_weights) != model_count: + raise ValueError( + f"Number of source weights ({len(args.source_weights)}) " + f"does not match number of source models ({model_count})." + ) + if any(weight <= 0 for weight in args.source_weights): + raise ValueError("All source weights must be positive."); + + logging.info(f"Source weights: {args.source_weights}") + # Iterate through each model and accumulate the parameters - for model_path in args.source_models: + for model_path, weight in zip(args.source_models, args.source_weights): if not model_path.exists(): raise FileNotFoundError(f"Model path {model_path} does not exist.") if not model_path.is_dir(): @@ -89,8 +114,9 @@ def main(): param_sums[key] += tensor # Average the parameters + total_weight = sum(args.source_weights) for key in param_sums: - param_sums[key] /= model_count + param_sums[key] /= total_weight logging.info("Merging completed. Saving the merged model...") args.output_model.mkdir(parents=True, exist_ok=True) From 280a6664c5a556bffba2a71358693c6bea214a7a Mon Sep 17 00:00:00 2001 From: odashi Date: Fri, 4 Jul 2025 14:02:38 +0900 Subject: [PATCH 2/4] fix bug --- .../v4-corpus-ratio-abci/merge/merge.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py index 35985858..118d5cd3 100644 --- a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py +++ b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py @@ -1,6 +1,6 @@ # Script to apply average merging to Hugging Face models. # This script can work with the pretrain Python environment. -# +# # Usage: # python merge.py \ # --source-models /path/to/model1 /path/to/model2 \ @@ -54,7 +54,7 @@ def parse_args(): def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]: """ Iterate through the parameters of a model stored in a .safetensors file. - + Args: model_path (pathlib.Path): Path to the .safetensors file. @@ -75,7 +75,7 @@ def main(): # Initialize a dictionary to hold the sum of parameters param_sums = {} model_count = len(args.source_models) - + if model_count == 0: raise ValueError("No input models provided for merging.") @@ -85,7 +85,7 @@ def main(): if args.source_weights is None: logging.info("No source weights provided, treating all models equally.") args.source_weights = [1.0] * model_count - else + else: if len(args.source_weights) != model_count: raise ValueError( f"Number of source weights ({len(args.source_weights)}) " @@ -112,7 +112,7 @@ def main(): raise ValueError(f"Shape mismatch for key '{key}': " f"{param_sums[key].shape} vs {tensor.shape}") param_sums[key] += tensor - + # Average the parameters total_weight = sum(args.source_weights) for key in param_sums: @@ -120,12 +120,12 @@ def main(): logging.info("Merging completed. Saving the merged model...") args.output_model.mkdir(parents=True, exist_ok=True) - + # Copy original files other than .safetensors for file in args.source_models[0].iterdir(): if file.suffix != ".safetensors": shutil.copy(file, args.output_model / file.name) - + # There should be `model.safetensors.index.json` file # containing the mapping of parameter names to their destination file names. index_file = args.output_model / "model.safetensors.index.json" @@ -133,16 +133,16 @@ def main(): raise FileNotFoundError(f"Index file {index_file} does not exist.") with index_file.open("r") as f: weight_map = json.load(f)["weight_map"] - + # Check if the weight map is consistent with the parameters if set(weight_map.keys()) != set(param_sums.keys()): raise ValueError("Weight map keys do not match the parameter keys.") - + # Make inverse mapping for saving output_map = {k: [] for k in set(weight_map.values())} for k, v in weight_map.items(): output_map[v].append(k) - + metadata = {"format": "pt"} # Save all parameters @@ -151,7 +151,7 @@ def main(): output_path = args.output_model / file_name logging.info(f" Saving parameters to {output_path}") safetensors.torch.save_file(tensors, output_path, metadata=metadata) - + logging.info("Merged model saved successfully.") From 34fb5c7d2c19d25253477dce355493cc66432c87 Mon Sep 17 00:00:00 2001 From: odashi Date: Fri, 11 Jul 2025 09:52:19 +0900 Subject: [PATCH 3/4] add aggregation-method --- .../v4-corpus-ratio-abci/merge/merge.py | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py index 118d5cd3..6713e7f8 100644 --- a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py +++ b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py @@ -4,6 +4,8 @@ # Usage: # python merge.py \ # --source-models /path/to/model1 /path/to/model2 \ +# --source-weights 1.0 2.0 \ # Optional +# --aggregation-method average \ # Optional, # --output-model /path/to/output_model import argparse @@ -41,6 +43,15 @@ def parse_args(): "all models will be treated equally (weight = 1)." ), ) + p.add_argument( + "--aggregation-method", + choices=["average", "sum"], + default="average", + help=( + "Method to aggregate parameters. " + "'average' will average the parameters, while 'sum' will sum them up." + ), + ) p.add_argument( "--output-model", type=pathlib.Path, @@ -72,30 +83,32 @@ def iter_params(model_path: pathlib.Path) -> tuple[str, torch.Tensor]: def main(): args = parse_args() - # Initialize a dictionary to hold the sum of parameters - param_sums = {} - model_count = len(args.source_models) - - if model_count == 0: - raise ValueError("No input models provided for merging.") - logging.info(f"Source models: {args.source_models}") - # Check weights - if args.source_weights is None: - logging.info("No source weights provided, treating all models equally.") - args.source_weights = [1.0] * model_count + if args.source_weights is not None: + if len(args.source_weights) != len(args.source_models): + raise ValueError("Number of source weights must match number of source models.") else: - if len(args.source_weights) != model_count: - raise ValueError( - f"Number of source weights ({len(args.source_weights)}) " - f"does not match number of source models ({model_count})." - ) - if any(weight <= 0 for weight in args.source_weights): - raise ValueError("All source weights must be positive."); + args.source_weights = [1.0] * len(args.source_models) logging.info(f"Source weights: {args.source_weights}") + match args.aggregation_method: + case "average": + if any(x <= 0 for x in args.source_weights): + raise ValueError("All source weights must be positive for --aggregation-method=average.") + denominator = sum(args.source_weights) + case "sum": + denominator = 1.0 + case _: + raise ValueError(f"Unknown aggregation method: {args.aggregation_method}") + + logging.info(f"Aggregation method: {args.aggregation_method}, denominator: {denominator}") + + # Initialize a dictionary to hold the sum of parameters + param_sums = {} + model_count = len(args.source_models) + # Iterate through each model and accumulate the parameters for model_path, weight in zip(args.source_models, args.source_weights): if not model_path.exists(): @@ -113,10 +126,9 @@ def main(): f"{param_sums[key].shape} vs {tensor.shape}") param_sums[key] += tensor - # Average the parameters - total_weight = sum(args.source_weights) + # Normalize the parameters for key in param_sums: - param_sums[key] /= total_weight + param_sums[key] /= denominator logging.info("Merging completed. Saving the merged model...") args.output_model.mkdir(parents=True, exist_ok=True) From f9c20a0eded0f6bbaf341477d31761e0cb56ac49 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 19 Aug 2025 14:08:33 +0900 Subject: [PATCH 4/4] fix bug --- pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py index 6713e7f8..7f8e0852 100644 --- a/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py +++ b/pretrain/scripts/v4-corpus-ratio-abci/merge/merge.py @@ -119,12 +119,12 @@ def main(): for key, tensor in iter_params(model_path): if key not in param_sums: - param_sums[key] = tensor + param_sums[key] = tensor * weight else: if param_sums[key].shape != tensor.shape: raise ValueError(f"Shape mismatch for key '{key}': " f"{param_sums[key].shape} vs {tensor.shape}") - param_sums[key] += tensor + param_sums[key] += tensor * weight # Normalize the parameters for key in param_sums: