Add support for huggingface ASR models in hg recipes #751

Open: 29 commits to merge into base: main

Commits
feff2cd
refactor the hg interface to support multiple models through presets
Ahmedsaed Aug 15, 2024
d9f753e
Refactor ASR evaluation code for improved extensibility
Ahmedsaed Aug 15, 2024
a3329b7
correctly manage and free cuda memory
Ahmedsaed Aug 15, 2024
da5d8d3
fix type hints and lint
Ahmedsaed Aug 15, 2024
107aa5e
Use custom `EvalSeqbatch` instead of Seq2Seq2Batch
Ahmedsaed Aug 15, 2024
721baa4
Introduce AsrDatasetConfig to handle different ASR datasets
Ahmedsaed Aug 15, 2024
1f6772c
lint and fix type hints
Ahmedsaed Aug 15, 2024
45e6571
move split to AsrDatasetConfig and move tokenizer back to AsrEvalConfig
Ahmedsaed Aug 15, 2024
7fb74ed
update default split
Ahmedsaed Aug 15, 2024
657f40f
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 15, 2024
d80d154
Add whisper integration
Ahmedsaed Aug 16, 2024
9f035ce
Introduce ModelConfig for dynamically loading huggingface models
Ahmedsaed Aug 16, 2024
58dd05c
Lint and update type hints
Ahmedsaed Aug 16, 2024
f92c551
lint and fix type hints
Ahmedsaed Aug 16, 2024
dc684f0
Refactor AsrDatasetConfig
Ahmedsaed Aug 20, 2024
e081e77
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 20, 2024
775f191
Merge branch 'main' into hg/whisper
Ahmedsaed Aug 20, 2024
433f463
refactor dynamically imported libraries
Ahmedsaed Aug 22, 2024
0f08e00
Merge branch 'hg/whisper' of https://github.com/Ahmedsaed/fairseq2 in…
Ahmedsaed Aug 22, 2024
460c777
Merge branch 'main' into hg/AsrDatasetConfig
Ahmedsaed Aug 26, 2024
58d4e03
Merge branch 'main' into hg/whisper
Ahmedsaed Aug 26, 2024
a0a266b
refactor code and lint
Ahmedsaed Sep 4, 2024
058bdf0
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
201127c
lint
Ahmedsaed Sep 4, 2024
667194f
refactor code and lint
Ahmedsaed Sep 4, 2024
5fc5fed
lint
Ahmedsaed Sep 4, 2024
c645950
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
0aaa6bb
Merge branch 'hg/AsrDatasetConfig' of https://github.com/Ahmedsaed/fa…
Ahmedsaed Sep 4, 2024
9d860f5
Merge branch 'main' into hg/whisper
Ahmedsaed Sep 4, 2024
42 changes: 19 additions & 23 deletions src/fairseq2/recipes/hg/__init__.py
@@ -6,44 +6,40 @@

 from __future__ import annotations

-try:
-    import datasets  # type: ignore[attr-defined,import-untyped,import-not-found]
+from typing import Dict, List

-    _has_hg_datasets = True
-except ImportError:
-    _has_hg_datasets = False
-
-
-try:
-    import evaluate  # type: ignore[attr-defined,import-untyped,import-not-found]
-
-    _has_hg_evaluate = True
-except ImportError:
-    _has_hg_evaluate = False
+from fairseq2.recipes.cli import Cli, RecipeCommandHandler


-from fairseq2.recipes.cli import Cli, RecipeCommandHandler
+def check_libraries(libraries: List[str]) -> Dict[str, bool]:
+    """Check if the given libraries are available."""
+    availability = {}
+    for lib in libraries:
+        try:
+            __import__(lib)
+            availability[lib] = True
+        except ImportError:
+            availability[lib] = False
+    return availability


 def _setup_hg_cli(cli: Cli) -> None:
-    if not _has_hg_datasets or not _has_hg_evaluate:
+    required_libraries = ["transformers", "datasets", "evaluate", "hydra"]
+    if not all(check_libraries(required_libraries).values()):
         return

     group = cli.add_group("hg", help="Hugging Face recipes")

-    from fairseq2.recipes.hg.asr_eval import (
-        asr_eval_presets,
-        load_wav2vec2_asr_evaluator,
-    )
+    from fairseq2.recipes.hg.asr_eval import asr_eval_presets, load_asr_evaluator

     handler = RecipeCommandHandler(
-        load_wav2vec2_asr_evaluator,
+        load_asr_evaluator,
         preset_configs=asr_eval_presets,
-        default_preset="librispeech_asr",
+        default_preset="default_asr",
     )

     group.add_command(
-        "wav2vec2_asr",
+        "asr",
         handler,
-        help="evaluate a wav2vec 2.0 ASR model on a downstream benchmark (default: librispeech_asr)",
+        help="evaluate an ASR model (default: wav2vec2) on a downstream benchmark (default: librispeech_asr)",
     )
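
For reference, a minimal standalone sketch of how the new availability gate behaves. `check_libraries` is copied from the diff. Everything else is an assumption added for illustration and is not part of this PR: the module names passed in (`json` from the standard library and a made-up missing module) and the `probe_libraries` helper, a hypothetical variant that uses `importlib.util.find_spec` to test for a top-level module without importing it.

```python
from importlib.util import find_spec
from typing import Dict, List


def check_libraries(libraries: List[str]) -> Dict[str, bool]:
    """Check if the given libraries are available (as in the diff)."""
    availability = {}
    for lib in libraries:
        try:
            __import__(lib)  # imports the module as a side effect
            availability[lib] = True
        except ImportError:  # also catches ModuleNotFoundError
            availability[lib] = False
    return availability


def probe_libraries(libraries: List[str]) -> Dict[str, bool]:
    """Hypothetical variant (not in the PR): look up top-level modules without importing them."""
    return {lib: find_spec(lib) is not None for lib in libraries}


# "json" ships with CPython; the second name is invented and should not exist.
names = ["json", "hypothetical_missing_lib_xyz"]

result = check_libraries(names)
print(result)  # {'json': True, 'hypothetical_missing_lib_xyz': False}

# Same gate as in _setup_hg_cli: one missing library disables the whole group.
gate_open = all(result.values())
print(gate_open)  # False

probed = probe_libraries(names)
```

Because the gate uses `all(...)`, a single missing optional dependency causes `_setup_hg_cli` to return before the `hg` group is registered, and it does so silently, as the diff is written.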