diff --git a/eland/index.py b/eland/index.py
index 2f3cfd33..550c88a9 100644
--- a/eland/index.py
+++ b/eland/index.py
@@ -50,7 +50,7 @@ def __init__(
 
         # index_field.setter
         self._is_source_field = False
-        self.es_index_field = es_index_field
+        self.es_index_field = es_index_field  # type: ignore
 
     @property
     def sort_field(self) -> str:
diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py
index 5f063af3..9a2c2b1b 100644
--- a/eland/ml/pytorch/__init__.py
+++ b/eland/ml/pytorch/__init__.py
@@ -20,6 +20,7 @@
     FillMaskInferenceOptions,
     NerInferenceOptions,
     NlpBertTokenizationConfig,
+    NlpDebertaV2TokenizationConfig,
     NlpMPNetTokenizationConfig,
     NlpRobertaTokenizationConfig,
     NlpTrainedModelConfig,
@@ -30,11 +31,9 @@
     TextSimilarityInferenceOptions,
     ZeroShotClassificationInferenceOptions,
 )
+from eland.ml.pytorch.tokenizers import UnknownModelInputSizeError
 from eland.ml.pytorch.traceable_model import TraceableModel  # noqa: F401
-from eland.ml.pytorch.transformers import (
-    UnknownModelInputSizeError,
-    task_type_from_model_config,
-)
+from eland.ml.pytorch.transformers import task_type_from_model_config
 
 __all__ = [
     "PyTorchModel",
@@ -43,6 +42,7 @@
     "NerInferenceOptions",
     "NlpTrainedModelConfig",
     "NlpBertTokenizationConfig",
+    "NlpDebertaV2TokenizationConfig",
    "NlpRobertaTokenizationConfig",
     "NlpXLMRobertaTokenizationConfig",
     "NlpMPNetTokenizationConfig",
diff --git a/eland/ml/pytorch/tokenizers.py b/eland/ml/pytorch/tokenizers.py
new file mode 100644
index 00000000..fe6c3556
--- /dev/null
+++ b/eland/ml/pytorch/tokenizers.py
@@ -0,0 +1,162 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import transformers
+
+from eland.ml.pytorch.nlp_ml_model import (
+    NlpBertJapaneseTokenizationConfig,
+    NlpBertTokenizationConfig,
+    NlpDebertaV2TokenizationConfig,
+    NlpMPNetTokenizationConfig,
+    NlpRobertaTokenizationConfig,
+    NlpTokenizationConfig,
+    NlpXLMRobertaTokenizationConfig,
+)
+
+SUPPORTED_TOKENIZERS = (
+    transformers.BertTokenizer,
+    transformers.BertTokenizerFast,
+    transformers.BertJapaneseTokenizer,
+    transformers.MPNetTokenizer,
+    transformers.MPNetTokenizerFast,
+    transformers.DPRContextEncoderTokenizer,
+    transformers.DPRContextEncoderTokenizerFast,
+    transformers.DPRQuestionEncoderTokenizer,
+    transformers.DPRQuestionEncoderTokenizerFast,
+    transformers.DistilBertTokenizer,
+    transformers.DistilBertTokenizerFast,
+    transformers.ElectraTokenizer,
+    transformers.ElectraTokenizerFast,
+    transformers.MobileBertTokenizer,
+    transformers.MobileBertTokenizerFast,
+    transformers.RetriBertTokenizer,
+    transformers.RetriBertTokenizerFast,
+    transformers.RobertaTokenizer,
+    transformers.RobertaTokenizerFast,
+    transformers.BartTokenizer,
+    transformers.BartTokenizerFast,
+    transformers.SqueezeBertTokenizer,
+    transformers.SqueezeBertTokenizerFast,
+    transformers.XLMRobertaTokenizer,
+    transformers.XLMRobertaTokenizerFast,
+    transformers.DebertaV2Tokenizer,
+    transformers.DebertaV2TokenizerFast,
+)
+SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS]))
+
+
+class UnknownModelInputSizeError(Exception):
+    pass
+
+
+def find_max_sequence_length(
+    model_id: str,
+    tokenizer: Union[
+        transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast
+    ],
+) -> int:
+    # Sometimes the max_... values are present but contain
+    # a random or very large value.
+    REASONABLE_MAX_LENGTH = 8192
+    max_len = getattr(tokenizer, "model_max_length", None)
+    if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
+        return int(max_len)
+
+    max_sizes = getattr(tokenizer, "max_model_input_sizes", dict())
+    max_len = max_sizes.get(model_id)
+    if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+        return int(max_len)
+
+    if max_sizes:
+        # The model id wasn't found in the max sizes dict but
+        # if all the values correspond then take that value
+        sizes = {size for size in max_sizes.values()}
+        if len(sizes) == 1:
+            max_len = sizes.pop()
+            if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
+                return int(max_len)
+
+    if isinstance(
+        tokenizer, (transformers.BertTokenizer, transformers.BertTokenizerFast)
+    ):
+        return 512
+
+    raise UnknownModelInputSizeError("Cannot determine model max input length")
+
+
+def create_tokenization_config(
+    model_id: str,
+    max_model_input_size: Optional[int],
+    tokenizer: Union[
+        transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast
+    ],
+) -> NlpTokenizationConfig:
+    if max_model_input_size is not None:
+        _max_sequence_length = max_model_input_size
+    else:
+        _max_sequence_length = find_max_sequence_length(model_id, tokenizer)
+
+    if isinstance(
+        tokenizer, (transformers.MPNetTokenizer, transformers.MPNetTokenizerFast)
+    ):
+        return NlpMPNetTokenizationConfig(
+            do_lower_case=getattr(tokenizer, "do_lower_case", None),
+            max_sequence_length=_max_sequence_length,
+        )
+    elif isinstance(
+        tokenizer,
+        (
+            transformers.RobertaTokenizer,
+            transformers.RobertaTokenizerFast,
+            transformers.BartTokenizer,
+            transformers.BartTokenizerFast,
+        ),
+    ):
+        return NlpRobertaTokenizationConfig(
+            add_prefix_space=getattr(tokenizer, "add_prefix_space", None),
+            max_sequence_length=_max_sequence_length,
+        )
+    elif isinstance(
+        tokenizer,
+        (transformers.XLMRobertaTokenizer, transformers.XLMRobertaTokenizerFast),
+    ):
+        return NlpXLMRobertaTokenizationConfig(max_sequence_length=_max_sequence_length)
+    elif isinstance(
+        tokenizer,
+        (transformers.DebertaV2Tokenizer, transformers.DebertaV2TokenizerFast),
+    ):
+        return NlpDebertaV2TokenizationConfig(
+            max_sequence_length=_max_sequence_length,
+            do_lower_case=getattr(tokenizer, "do_lower_case", None),
+        )
+    else:
+        japanese_morphological_tokenizers = ["mecab"]
+        if (
+            hasattr(tokenizer, "word_tokenizer_type")
+            and tokenizer.word_tokenizer_type in japanese_morphological_tokenizers
+        ):
+            return NlpBertJapaneseTokenizationConfig(
+                do_lower_case=getattr(tokenizer, "do_lower_case", None),
+                max_sequence_length=_max_sequence_length,
+            )
+        else:
+            return NlpBertTokenizationConfig(
+                do_lower_case=getattr(tokenizer, "do_lower_case", None),
+                max_sequence_length=_max_sequence_length,
+            )
diff --git a/eland/ml/pytorch/traceable_model.py b/eland/ml/pytorch/traceable_model.py
index 7b8e13c3..6dd1bcda 100644
--- a/eland/ml/pytorch/traceable_model.py
+++ b/eland/ml/pytorch/traceable_model.py
@@ -17,10 +17,19 @@
 
 import os.path
 from abc import ABC, abstractmethod
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import torch  # type: ignore
-from torch import nn
+import transformers
+from torch import Tensor, nn
+from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+
+from eland.ml.pytorch.wrappers import (
+    _DistilBertWrapper,
+    _DPREncoderWrapper,
+    _QuestionAnsweringWrapperModule,
+    _SentenceTransformerWrapperModule,
+)
 
 TracedModelTypes = Union[
     torch.nn.Module,
@@ -68,3 +77,164 @@ def save(self, path: str) -> str:
     @property
     def model(self) -> nn.Module:
         return self._model
+
+
+class TransformerTraceableModel(TraceableModel):
+    """A base class representing a HuggingFace transformer model that can be traced."""
+
+    def __init__(
+        self,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        model: Union[
+            PreTrainedModel,
+            _SentenceTransformerWrapperModule,
+            _DPREncoderWrapper,
+            _DistilBertWrapper,
+            _QuestionAnsweringWrapperModule,
+        ],
+    ):
+        super(TransformerTraceableModel, self).__init__(model=model)
+        self._tokenizer = tokenizer
+
+    def _trace(self) -> TracedModelTypes:
+        inputs = self._compatible_inputs()
+        return torch.jit.trace(self._model, example_inputs=inputs)
+
+    def sample_output(self) -> Tensor:
+        inputs = self._compatible_inputs()
+        return self._model(*inputs)
+
+    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
+        inputs = self._prepare_inputs()
+
+        # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
+        if "token_type_ids" not in inputs:
+            inputs["token_type_ids"] = torch.zeros(
+                inputs["input_ids"].size(1), dtype=torch.long
+            )
+        if isinstance(
+            self._tokenizer,
+            (
+                transformers.BartTokenizer,
+                transformers.MPNetTokenizer,
+                transformers.RobertaTokenizer,
+                transformers.XLMRobertaTokenizer,
+            ),
+        ):
+            return (inputs["input_ids"], inputs["attention_mask"])
+
+        if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
+            return (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                inputs["token_type_ids"],
+            )
+
+        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
+        inputs["position_ids"] = position_ids
+        return (
+            inputs["input_ids"],
+            inputs["attention_mask"],
+            inputs["token_type_ids"],
+            inputs["position_ids"],
+        )
+
+    @abstractmethod
+    def _prepare_inputs(self) -> transformers.BatchEncoding: ...
+
+
+class TraceableClassificationModel(TransformerTraceableModel, ABC):
+    def classification_labels(self) -> Optional[List[str]]:
+        id_label_items = self._model.config.id2label.items()
+        labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])]
+
+        # Make classes like I-PER into I_PER which fits Java enumerations
+        return [label.replace("-", "_") for label in labels]
+
+
+class TraceableFillMaskModel(TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "Who was Jim Henson?",
+            "[MASK] Henson was a puppeteer",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceableTextExpansionModel(TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceableNerModel(TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            (
+                "Hugging Face Inc. is a company based in New York City. "
+                "Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge."
+            ),
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceablePassThroughModel(TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceableTextClassificationModel(TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceableTextEmbeddingModel(TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            padding="longest",
+            return_tensors="pt",
+        )
+
+
+class TraceableZeroShotClassificationModel(TraceableClassificationModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "This is an example sentence.",
+            "This example is an example.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
+class TraceableQuestionAnsweringModel(TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+ "The meaning of life, according to the hitchikers guide, is 42.", + padding="max_length", + return_tensors="pt", + ) + + +class TraceableTextSimilarityModel(TransformerTraceableModel): + def _prepare_inputs(self) -> transformers.BatchEncoding: + return self._tokenizer( + "What is the meaning of life?" + "The meaning of life, according to the hitchikers guide, is 42.", + padding="max_length", + return_tensors="pt", + ) diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 04d4ba86..9f70af63 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -24,32 +24,18 @@ import os.path import random import re -from abc import ABC, abstractmethod from typing import Dict, List, Optional, Set, Tuple, Union import torch # type: ignore import transformers # type: ignore from torch import Tensor from torch.profiler import profile # type: ignore -from transformers import ( - BertTokenizer, - PretrainedConfig, - PreTrainedModel, - PreTrainedTokenizer, - PreTrainedTokenizerFast, -) +from transformers import PretrainedConfig from eland.ml.pytorch.nlp_ml_model import ( FillMaskInferenceOptions, NerInferenceOptions, - NlpBertJapaneseTokenizationConfig, - NlpBertTokenizationConfig, - NlpDebertaV2TokenizationConfig, - NlpMPNetTokenizationConfig, - NlpRobertaTokenizationConfig, - NlpTokenizationConfig, NlpTrainedModelConfig, - NlpXLMRobertaTokenizationConfig, PassThroughInferenceOptions, PrefixStrings, QuestionAnsweringInferenceOptions, @@ -60,7 +46,23 @@ TrainedModelInput, ZeroShotClassificationInferenceOptions, ) -from eland.ml.pytorch.traceable_model import TraceableModel +from eland.ml.pytorch.tokenizers import ( + SUPPORTED_TOKENIZERS, + SUPPORTED_TOKENIZERS_NAMES, + create_tokenization_config, +) +from eland.ml.pytorch.traceable_model import ( + TraceableFillMaskModel, + TraceableNerModel, + TraceablePassThroughModel, + TraceableQuestionAnsweringModel, + TraceableTextClassificationModel, + TraceableTextEmbeddingModel, + TraceableTextExpansionModel, + TraceableTextSimilarityModel, + TraceableZeroShotClassificationModel, + TransformerTraceableModel, +) from eland.ml.pytorch.wrappers import ( _DistilBertWrapper, _DPREncoderWrapper, @@ -104,23 +106,7 @@ "text_similarity": TextSimilarityInferenceOptions, } SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES)) -SUPPORTED_TOKENIZERS = ( - transformers.BertTokenizer, - transformers.BertJapaneseTokenizer, - transformers.MPNetTokenizer, - transformers.DPRContextEncoderTokenizer, - transformers.DPRQuestionEncoderTokenizer, - transformers.DistilBertTokenizer, - transformers.ElectraTokenizer, - transformers.MobileBertTokenizer, - transformers.RetriBertTokenizer, - transformers.RobertaTokenizer, - transformers.BartTokenizer, - transformers.SqueezeBertTokenizer, - transformers.XLMRobertaTokenizer, - transformers.DebertaV2Tokenizer, -) -SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted([str(x) for x in SUPPORTED_TOKENIZERS])) + TracedModelTypes = Union[ torch.nn.Module, @@ -134,10 +120,6 @@ class TaskTypeError(Exception): pass -class UnknownModelInputSizeError(Exception): - pass - - def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]: if model_config.architectures is None: if model_config.name_or_path.startswith("sentence-transformers/"): @@ -173,166 +155,6 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str] return potential_task_types.pop() -class _TransformerTraceableModel(TraceableModel): - """A base class representing a HuggingFace 
transformer model that can be traced.""" - - def __init__( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - model: Union[ - PreTrainedModel, - _SentenceTransformerWrapperModule, - _DPREncoderWrapper, - _DistilBertWrapper, - ], - ): - super(_TransformerTraceableModel, self).__init__(model=model) - self._tokenizer = tokenizer - - def _trace(self) -> TracedModelTypes: - inputs = self._compatible_inputs() - return torch.jit.trace(self._model, example_inputs=inputs) - - def sample_output(self) -> Tensor: - inputs = self._compatible_inputs() - return self._model(*inputs) - - def _compatible_inputs(self) -> Tuple[Tensor, ...]: - inputs = self._prepare_inputs() - - # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface - if "token_type_ids" not in inputs: - inputs["token_type_ids"] = torch.zeros( - inputs["input_ids"].size(1), dtype=torch.long - ) - if isinstance( - self._tokenizer, - ( - transformers.BartTokenizer, - transformers.MPNetTokenizer, - transformers.RobertaTokenizer, - transformers.XLMRobertaTokenizer, - ), - ): - return (inputs["input_ids"], inputs["attention_mask"]) - - if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer): - return ( - inputs["input_ids"], - inputs["attention_mask"], - inputs["token_type_ids"], - ) - - position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long) - inputs["position_ids"] = position_ids - return ( - inputs["input_ids"], - inputs["attention_mask"], - inputs["token_type_ids"], - inputs["position_ids"], - ) - - @abstractmethod - def _prepare_inputs(self) -> transformers.BatchEncoding: ... - - -class _TraceableClassificationModel(_TransformerTraceableModel, ABC): - def classification_labels(self) -> Optional[List[str]]: - id_label_items = self._model.config.id2label.items() - labels = [v for _, v in sorted(id_label_items, key=lambda kv: kv[0])] - - # Make classes like I-PER into I_PER which fits Java enumerations - return [label.replace("-", "_") for label in labels] - - -class _TraceableFillMaskModel(_TransformerTraceableModel): - def _prepare_inputs(self) -> transformers.BatchEncoding: - return self._tokenizer( - "Who was Jim Henson?", - "[MASK] Henson was a puppeteer", - padding="max_length", - return_tensors="pt", - ) - - -class _TraceableTextExpansionModel(_TransformerTraceableModel): - def _prepare_inputs(self) -> transformers.BatchEncoding: - return self._tokenizer( - "This is an example sentence.", - padding="max_length", - return_tensors="pt", - ) - - -class _TraceableNerModel(_TraceableClassificationModel): - def _prepare_inputs(self) -> transformers.BatchEncoding: - return self._tokenizer( - ( - "Hugging Face Inc. is a company based in New York City. " - "Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge." 
-            ),
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
-class _TraceablePassThroughModel(_TransformerTraceableModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "This is an example sentence.",
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
-class _TraceableTextClassificationModel(_TraceableClassificationModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "This is an example sentence.",
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
-class _TraceableTextEmbeddingModel(_TransformerTraceableModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "This is an example sentence.",
-            padding="longest",
-            return_tensors="pt",
-        )
-
-
-class _TraceableZeroShotClassificationModel(_TraceableClassificationModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "This is an example sentence.",
-            "This example is an example.",
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
-class _TraceableQuestionAnsweringModel(_TransformerTraceableModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "What is the meaning of life?"
-            "The meaning of life, according to the hitchhikers guide, is 42.",
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
-class _TraceableTextSimilarityModel(_TransformerTraceableModel):
-    def _prepare_inputs(self) -> transformers.BatchEncoding:
-        return self._tokenizer(
-            "What is the meaning of life?"
-            "The meaning of life, according to the hitchhikers guide, is 42.",
-            padding="max_length",
-            return_tensors="pt",
-        )
-
-
 class TransformerModel:
     def __init__(
@@ -395,7 +217,9 @@ def __init__(
         # use padding in the tokenizer to ensure max length sequences are used for tracing (at call time)
         # - see: https://huggingface.co/transformers/serialization.html#dummy-inputs-and-standard-lengths
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
-            self._model_id, token=self._access_token, use_fast=False
+            self._model_id,
+            token=self._access_token,
+            use_fast=True,  # TODO not all tokenizers support fast mode
         )
 
         # check for a supported tokenizer
@@ -443,81 +267,12 @@ def _load_vocab(self) -> Dict[str, List[str]]:
             vocab_obj["scores"] = scores
         return vocab_obj
 
-    def _create_tokenization_config(self) -> NlpTokenizationConfig:
-        if self._max_model_input_size:
-            _max_sequence_length = self._max_model_input_size
-        else:
-            _max_sequence_length = self._find_max_sequence_length()
-
-        if isinstance(self._tokenizer, transformers.MPNetTokenizer):
-            return NlpMPNetTokenizationConfig(
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                max_sequence_length=_max_sequence_length,
-            )
-        elif isinstance(
-            self._tokenizer, (transformers.RobertaTokenizer, transformers.BartTokenizer)
-        ):
-            return NlpRobertaTokenizationConfig(
-                add_prefix_space=getattr(self._tokenizer, "add_prefix_space", None),
-                max_sequence_length=_max_sequence_length,
-            )
-        elif isinstance(self._tokenizer, transformers.XLMRobertaTokenizer):
-            return NlpXLMRobertaTokenizationConfig(
-                max_sequence_length=_max_sequence_length
-            )
-        elif isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
-            return NlpDebertaV2TokenizationConfig(
-                max_sequence_length=_max_sequence_length,
-                do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-            )
-        else:
-            japanese_morphological_tokenizers = ["mecab"]
-            if (
-                hasattr(self._tokenizer, "word_tokenizer_type")
-                and self._tokenizer.word_tokenizer_type
-                in japanese_morphological_tokenizers
-            ):
-                return NlpBertJapaneseTokenizationConfig(
-                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=_max_sequence_length,
-                )
-            else:
-                return NlpBertTokenizationConfig(
-                    do_lower_case=getattr(self._tokenizer, "do_lower_case", None),
-                    max_sequence_length=_max_sequence_length,
-                )
-
-    def _find_max_sequence_length(self) -> int:
-        # Sometimes the max_... values are present but contain
-        # a random or very large value.
-        REASONABLE_MAX_LENGTH = 8192
-        max_len = getattr(self._tokenizer, "model_max_length", None)
-        if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
-            return int(max_len)
-
-        max_sizes = getattr(self._tokenizer, "max_model_input_sizes", dict())
-        max_len = max_sizes.get(self._model_id)
-        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
-            return int(max_len)
-
-        if max_sizes:
-            # The model id wasn't found in the max sizes dict but
-            # if all the values correspond then take that value
-            sizes = {size for size in max_sizes.values()}
-            if len(sizes) == 1:
-                max_len = sizes.pop()
-                if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
-                    return int(max_len)
-
-        if isinstance(self._tokenizer, BertTokenizer):
-            return 512
-
-        raise UnknownModelInputSizeError("Cannot determine model max input length")
-
     def _create_config(
         self, es_version: Optional[Tuple[int, int, int]]
     ) -> NlpTrainedModelConfig:
-        tokenization_config = self._create_tokenization_config()
+        tokenization_config = create_tokenization_config(
+            self._model_id, self._max_model_input_size, self._tokenizer
+        )
 
         # Set squad well known defaults
         if self._task_type == "question_answering":
@@ -713,7 +468,7 @@ def _make_inputs_compatible(
             inputs["position_ids"],
         )
 
-    def _create_traceable_model(self) -> _TransformerTraceableModel:
+    def _create_traceable_model(self) -> TransformerTraceableModel:
         if self._task_type == "auto":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
@@ -731,28 +486,28 @@ def _create_traceable_model(self) -> _TransformerTraceableModel:
                 self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableTextExpansionModel(self._tokenizer, model)
+            return TraceableTextExpansionModel(self._tokenizer, model)
 
         if self._task_type == "fill_mask":
             model = transformers.AutoModelForMaskedLM.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableFillMaskModel(self._tokenizer, model)
+            return TraceableFillMaskModel(self._tokenizer, model)
 
         elif self._task_type == "ner":
             model = transformers.AutoModelForTokenClassification.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableNerModel(self._tokenizer, model)
+            return TraceableNerModel(self._tokenizer, model)
 
         elif self._task_type == "text_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableTextClassificationModel(self._tokenizer, model)
+            return TraceableTextClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "text_embedding":
             model = _DPREncoderWrapper.from_pretrained(
@@ -762,33 +517,33 @@ def _create_traceable_model(self) -> _TransformerTraceableModel:
                 model = _SentenceTransformerWrapperModule.from_pretrained(
                     self._model_id, self._tokenizer, token=self._access_token
                 )
-            return _TraceableTextEmbeddingModel(self._tokenizer, model)
+            return TraceableTextEmbeddingModel(self._tokenizer, model)
 
         elif self._task_type == "zero_shot_classification":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableZeroShotClassificationModel(self._tokenizer, model)
+            return TraceableZeroShotClassificationModel(self._tokenizer, model)
 
         elif self._task_type == "question_answering":
             model = _QuestionAnsweringWrapperModule.from_pretrained(
                 self._model_id, token=self._access_token
             )
-            return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+            return TraceableQuestionAnsweringModel(self._tokenizer, model)
 
         elif self._task_type == "text_similarity":
             model = transformers.AutoModelForSequenceClassification.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
            )
             model = _DistilBertWrapper.try_wrapping(model)
-            return _TraceableTextSimilarityModel(self._tokenizer, model)
+            return TraceableTextSimilarityModel(self._tokenizer, model)
 
         elif self._task_type == "pass_through":
             model = transformers.AutoModel.from_pretrained(
                 self._model_id, token=self._access_token, torchscript=True
             )
-            return _TraceablePassThroughModel(self._tokenizer, model)
+            return TraceablePassThroughModel(self._tokenizer, model)
 
         else:
             raise TypeError(
diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
index aa141eac..60756319 100644
--- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py
+++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -19,13 +19,6 @@
 
 import pytest
 
-try:
-    import sklearn  # noqa: F401
-
-    HAS_SKLEARN = True
-except ImportError:
-    HAS_SKLEARN = False
-
 try:
     from eland.ml.pytorch.transformers import TransformerModel
 
@@ -58,9 +51,6 @@
 from tests import ES_VERSION
 
 pytestmark = [
-    pytest.mark.skipif(
-        not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
-    ),
     pytest.mark.skipif(
         not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
     ),
@@ -72,7 +62,7 @@
 # If the required imports are missing the test will be skipped.
 # Only define the test configurations if the referenced classes
 # have been imported
-if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
+if HAS_PYTORCH and HAS_TRANSFORMERS:
     MODEL_CONFIGURATIONS = [
         (
             "sentence-transformers/all-distilroberta-v1",
@@ -154,6 +144,14 @@
             512,
             None,
         ),
+        (
+            "jinaai/jina-reranker-v2-base-multilingual",
+            "text_similarity",
+            TextSimilarityInferenceOptions,
+            NlpXLMRobertaTokenizationConfig,
+            1024,
+            None,
+        ),
     ]
 else:
     MODEL_CONFIGURATIONS = []
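
---
Review note (not part of the patch): the sketch below shows how the helpers relocated
into eland/ml/pytorch/tokenizers.py fit together. It is a minimal sketch, assuming the
`transformers` package is installed; "bert-base-uncased" is only an illustrative model
id, and any tokenizer listed in SUPPORTED_TOKENIZERS is handled the same way.

    import transformers

    from eland.ml.pytorch.tokenizers import (
        SUPPORTED_TOKENIZERS,
        UnknownModelInputSizeError,
        create_tokenization_config,
        find_max_sequence_length,
    )

    model_id = "bert-base-uncased"  # illustrative only
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=True)
    assert isinstance(tokenizer, SUPPORTED_TOKENIZERS)

    try:
        # Resolution order: tokenizer.model_max_length (when <= 8192), then the
        # tokenizer's max_model_input_sizes table, then the BERT default of 512.
        max_len = find_max_sequence_length(model_id, tokenizer)
    except UnknownModelInputSizeError:
        max_len = 512  # caller-chosen fallback for this sketch

    # Dispatches on the tokenizer class to the matching Nlp*TokenizationConfig;
    # passing None instead of max_len would re-run find_max_sequence_length internally.
    config = create_tokenization_config(model_id, max_len, tokenizer)
    print(type(config).__name__)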
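
Review note (not part of the patch): a second sketch, this time of the Traceable*
classes made public by the move into eland/ml/pytorch/traceable_model.py. The model id
is again illustrative; sample_output() is defined in this patch, while save() is
inherited from TraceableModel and is assumed here to trace the model via _trace()
before serializing it.

    import transformers

    from eland.ml.pytorch.traceable_model import TraceableFillMaskModel
    from eland.ml.pytorch.wrappers import _DistilBertWrapper

    model_id = "distilbert-base-uncased"  # illustrative only
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = transformers.AutoModelForMaskedLM.from_pretrained(
        model_id, torchscript=True
    )

    # DistilBERT's forward() takes no token_type_ids/position_ids; the wrapper
    # conforms it to the four-argument BERT-style interface used for tracing.
    model = _DistilBertWrapper.try_wrapping(model)

    traceable = TraceableFillMaskModel(tokenizer, model)
    sample = traceable.sample_output()   # eager forward pass over the example inputs
    model_path = traceable.save("/tmp")  # assumption: traces and writes the model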