From fd119ea15658a7b6bc47bf7b6c84870c5b3d644c Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 21 Feb 2025 12:51:38 -0500 Subject: [PATCH] Fixes --- .../zip_paratext_project_terms_parser.py | 2 +- .../zip_paratext_project_text_updater.py | 2 +- .../huggingface/hugging_face_nmt_engine.py | 20 +++++++++++-------- .../huggingface/hugging_face_nmt_model.py | 2 ++ .../hugging_face_nmt_model_trainer.py | 6 +++++- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 7 files changed, 26 insertions(+), 16 deletions(-) diff --git a/machine/corpora/zip_paratext_project_terms_parser.py b/machine/corpora/zip_paratext_project_terms_parser.py index 3f781b21..ebc208a0 100644 --- a/machine/corpora/zip_paratext_project_terms_parser.py +++ b/machine/corpora/zip_paratext_project_terms_parser.py @@ -19,5 +19,5 @@ def _exists(self, file_name: StrPath) -> bool: def _open(self, file_name: StrPath) -> Optional[BinaryIO]: if file_name in self._archive.namelist(): - return BytesIO(self._archive.read(file_name)) + return BytesIO(self._archive.read(str(file_name))) return None diff --git a/machine/corpora/zip_paratext_project_text_updater.py b/machine/corpora/zip_paratext_project_text_updater.py index 75e8ff02..b4dbd8bd 100644 --- a/machine/corpora/zip_paratext_project_text_updater.py +++ b/machine/corpora/zip_paratext_project_text_updater.py @@ -18,5 +18,5 @@ def _exists(self, file_name: StrPath) -> bool: def _open(self, file_name: StrPath) -> Optional[BinaryIO]: if file_name in self._archive.namelist(): - return BytesIO(self._archive.read(file_name)) + return BytesIO(self._archive.read(str(file_name))) return None diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 04086afd..c5a196ed 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -80,17 +80,21 @@ def __init__( else: src_lang_token = src_lang tgt_lang_token = tgt_lang - if ( - src_lang is not None - and src_lang_token not in self._tokenizer.added_tokens_encoder - and src_lang_token not in additional_special_tokens + if src_lang is not None and ( + src_lang_token is None + or ( + src_lang_token not in self._tokenizer.added_tokens_encoder + and src_lang_token not in additional_special_tokens # type: ignore - we already check for None + ) ): raise ValueError(f"The specified model does not support the language code '{src_lang}'") - if ( - tgt_lang is not None - and tgt_lang_token not in self._tokenizer.added_tokens_encoder - and tgt_lang_token not in additional_special_tokens + if tgt_lang is not None and ( + tgt_lang_token is None + or ( + tgt_lang_token not in self._tokenizer.added_tokens_encoder + and tgt_lang_token not in additional_special_tokens # type: ignore - we already check for None + ) ): raise ValueError(f"The specified model does not support the language code '{tgt_lang}'") diff --git a/machine/translation/huggingface/hugging_face_nmt_model.py b/machine/translation/huggingface/hugging_face_nmt_model.py index 253eb9be..e0a80f27 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model.py +++ b/machine/translation/huggingface/hugging_face_nmt_model.py @@ -89,6 +89,8 @@ def __init__(self, model: HuggingFaceNmtModel, corpus: Union[ParallelTextCorpus, def save(self) -> None: super().save() + if self._model.training_args.output_dir is None: + raise ValueError("Output directory must not be None.") output_dir = Path(self._model.training_args.output_dir) if output_dir != self._model._model_path: shutil.copytree(output_dir, self._model._model_path) diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 1192243e..448050a6 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -115,6 +115,8 @@ def train( check_canceled: Optional[Callable[[], None]] = None, ) -> None: last_checkpoint = None + if self._training_args.output_dir is None: + raise ValueError("Output directory is not set") if os.path.isdir(self._training_args.output_dir) and not self._training_args.overwrite_output_dir: last_checkpoint = get_last_checkpoint(self._training_args.output_dir) if last_checkpoint is None and any(os.path.isfile(p) for p in os.listdir(self._training_args.output_dir)): @@ -176,6 +178,8 @@ def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes: return missing_characters def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any: + if self._training_args.output_dir is None: + raise ValueError("Output directory is not set") tokenizer_dir = Path(self._training_args.output_dir) tokenizer.save_pretrained(str(tokenizer_dir)) with open(tokenizer_dir / "tokenizer.json", "r+", encoding="utf-8") as file: @@ -317,7 +321,7 @@ def preprocess_function(examples): model=model, args=self._training_args, train_dataset=cast(Any, train_dataset), - tokenizer=tokenizer, + processing_class=tokenizer, data_collator=data_collator, callbacks=[ _ProgressCallback( diff --git a/poetry.lock b/poetry.lock index 96e89cdc..966bc455 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3098,13 +3098,13 @@ diagrams = ["jinja2", "railroad-diagrams"] [[package]] name = "pyright" -version = "1.1.386" +version = "1.1.394" description = "Command line wrapper for pyright" optional = false python-versions = ">=3.7" files = [ - {file = "pyright-1.1.386-py3-none-any.whl", hash = "sha256:7071ac495593b2258ccdbbf495f1a5c0e5f27951f6b429bed4e8b296eb5cd21d"}, - {file = "pyright-1.1.386.tar.gz", hash = "sha256:8e9975e34948ba5f8e07792a9c9d2bdceb2c6c0b61742b068d2229ca2bc4a9d9"}, + {file = "pyright-1.1.394-py3-none-any.whl", hash = "sha256:5f74cce0a795a295fb768759bbeeec62561215dea657edcaab48a932b031ddbb"}, + {file = "pyright-1.1.394.tar.gz", hash = "sha256:56f2a3ab88c5214a451eb71d8f2792b7700434f841ea219119ade7f42ca93608"}, ] [package.dependencies] @@ -4739,4 +4739,4 @@ thot = ["sil-thot"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "111b1e7d775714231ffcd7562afc60f4cc9dc56a43e14ffb08be52880db30284" +content-hash = "1b6eb89d3dbc2052746df113ffc526a34430b34adea49d6224b9f2679192f74a" diff --git a/pyproject.toml b/pyproject.toml index 83bfe753..8d0ca2e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ pytest-cov = "^4.1.0" ipykernel = "^6.7.0" jupyter = "^1.0.0" pandas = "^2.0.3" -pyright = { extras = ["nodejs"], version = "^1.1.362" } +pyright = { extras = ["nodejs"], version = "^1.1.394" } decoy = "^2.1.0" pep8-naming = "^0.14.1"