From 304b632299d6d5cd14d862c73677f77d0f81f8c1 Mon Sep 17 00:00:00 2001
From: Bireflection <1412947499@qq.com>
Date: Tue, 25 Mar 2025 16:18:11 +0800
Subject: [PATCH] audio_spectrogram_transformer finetune

---
 .../audio_spectrogram_transformer/README.md   |  39 ++++++
 .../audio_spectrogram_transformer/finetune.py | 121 ++++++++++++++++++
 2 files changed, 160 insertions(+)
 create mode 100644 llm/finetune/audio_spectrogram_transformer/README.md
 create mode 100644 llm/finetune/audio_spectrogram_transformer/finetune.py

diff --git a/llm/finetune/audio_spectrogram_transformer/README.md b/llm/finetune/audio_spectrogram_transformer/README.md
new file mode 100644
index 000000000..5ff0badbb
--- /dev/null
+++ b/llm/finetune/audio_spectrogram_transformer/README.md
@@ -0,0 +1,39 @@
# audio_spectrogram_transformer Fine-tuning Report

## Task
- Model: MIT/ast-finetuned-audioset-10-10-0.4593
- Dataset: ashraq/esc50

## Results Comparison

**MindNLP + Ascend 910B**

| Epoch | Training Loss | Validation Loss | Accuracy |
|------:|--------------:|----------------:|---------:|
|     1 |        3.0928 |          2.2305 |     0.8150 |
|     2 |        1.4845 |          0.9815 |     0.8950 |
|     3 |        0.5733 |          0.4876 |     0.9250 |
|     4 |        0.2061 |          0.3770 |     0.9125 |
|     5 |        0.0742 |          0.3207 |     0.9300 |
|     6 |        0.0443 |          0.1821 | **0.9600** |
|     7 |        0.0178 |          0.2144 |     0.9575 |
|     8 |        0.0111 |          0.2155 |     0.9575 |
|     9 |        0.0094 |          0.2167 |     0.9575 |
|    10 |        0.0087 |          0.2174 |     0.9575 |

---

**PyTorch + RTX 3090**

| Epoch | Training Loss | Validation Loss | Accuracy |
|------:|--------------:|----------------:|---------:|
|     1 |        1.2550 |          0.3934 |     0.8825 |
|     2 |        0.4477 |          0.3656 |     0.8925 |
|     3 |        0.3289 |          0.2777 |     0.9200 |
|     4 |        0.2200 |          0.3645 |     0.9175 |
|     5 |        0.1679 |          0.2345 |     0.9350 |
|     6 |        0.1140 |          0.1877 |     0.9575 |
|     7 |        0.0925 |          0.1641 |     0.9575 |
|     8 |        0.0648 |          0.1810 |     0.9475 |
|     9 |        0.0593 |          0.1285 |     0.9550 |
|    10 |        0.0269 |          0.1222 | **0.9575** |
\ No newline at end of file
diff --git a/llm/finetune/audio_spectrogram_transformer/finetune.py b/llm/finetune/audio_spectrogram_transformer/finetune.py
new file mode 100644
index 000000000..eeed8b6bc
--- /dev/null
+++ b/llm/finetune/audio_spectrogram_transformer/finetune.py
@@ -0,0 +1,121 @@
import mindspore as ms
import numpy as np
from datasets import Audio, ClassLabel, load_dataset
from mindspore.dataset import GeneratorDataset
from sklearn.metrics import accuracy_score

from mindnlp.engine import Trainer, TrainingArguments
from mindnlp.transformers import (ASTConfig, ASTFeatureExtractor,
                                  ASTForAudioClassification)

ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")

# Load the ESC-50 dataset.
esc50 = load_dataset("ashraq/esc50", split="train")

# Recover the class names ordered by target id, so the ClassLabel ids line up
# with the original integer targets.
df = esc50.select_columns(["target", "category"]).to_pandas()
first_occurrence = np.unique(df["target"], return_index=True)[1]
class_names = df.iloc[first_occurrence]["category"].to_list()

esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))
esc50 = esc50.rename_column("target", "labels")
num_labels = len(np.unique(esc50["labels"]))

# Initialize the AST feature extractor.
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)
model_input_name = feature_extractor.model_input_names[0]
SAMPLING_RATE = feature_extractor.sampling_rate


# Preprocess audio: turn raw waveforms into log-mel spectrogram features.
def preprocess_audio(batch):
    wavs = [audio["array"] for audio in batch["input_values"]]
    inputs = feature_extractor(
        wavs, sampling_rate=SAMPLING_RATE, return_tensors="ms")
    return {model_input_name: inputs.get(model_input_name),
            "labels": list(batch["labels"])}
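

# Optional shape sanity check. This assumes the checkpoint's default feature
# extractor settings (spectrograms padded/truncated to 1024 frames of 128 mel
# bins), under which one second of silence becomes a (1, 1024, 128) batch.
probe = feature_extractor([np.zeros(SAMPLING_RATE)],
                          sampling_rate=SAMPLING_RATE, return_tensors="np")
print("spectrogram batch shape:", probe[model_input_name].shape)
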
list(batch["labels"])} + + +dataset = esc50 +label2id = dataset.features["labels"]._str2int + +# 构造训练集和测试集 +if "test" not in dataset: + dataset = dataset.train_test_split( + test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels") + + +dataset = dataset.cast_column("audio", Audio( + sampling_rate=feature_extractor.sampling_rate)) +dataset = dataset.rename_column("audio", "input_values") + +dataset["train"].set_transform( + preprocess_audio, output_all_columns=False) +dataset["test"].set_transform(preprocess_audio, output_all_columns=False) + +# 加载config +config = ASTConfig.from_pretrained(pretrained_model) +config.num_labels = num_labels +config.label2id = label2id +config.id2label = {v: k for k, v in label2id.items()} + +model = ASTForAudioClassification.from_pretrained( + pretrained_model, config=config, ignore_mismatched_sizes=True) + + +def convert_mindspore_datatset(hf_dataset, batch_size): + data_list = list(hf_dataset) + + def generator(): + for item in data_list: + yield item[model_input_name], item["labels"] + # 构造MindSpore的GeneratorDataset + ds = GeneratorDataset( + source=generator, + column_names=[model_input_name, "labels"], + shuffle=False + ) + ds = ds.batch(batch_size, drop_remainder=True) + return ds + + +# 初始化训练参数 +training_args = TrainingArguments( + output_dir="./checkpoint", + logging_dir="./logs", + learning_rate=5e-5, + num_train_epochs=10, + per_device_train_batch_size=8, + evaluation_strategy="epoch", + save_strategy="epoch", + eval_steps=1, + save_steps=1, + load_best_model_at_end=True, + metric_for_best_model="accuracy", + logging_strategy="epoch", + logging_steps=20, +) + +train_ms_dataset = convert_mindspore_datatset( + dataset["train"], training_args.per_device_train_batch_size) +eval_ms_dataset = convert_mindspore_datatset( + dataset["test"], training_args.per_device_train_batch_size) + + +def compute_metrics(eval_pred): + logits = eval_pred.predictions + labels = eval_pred.label_ids + predictions = np.argmax(logits, axis=1) + return {"accuracy": accuracy_score(predictions, labels)} + + +# 初始化trainer +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_ms_dataset, + eval_dataset=eval_ms_dataset, + compute_metrics=compute_metrics, +) + +trainer.train()