diff --git a/eland/ml/pytorch/__init__.py b/eland/ml/pytorch/__init__.py index 5f063af3..109d4a3e 100644 --- a/eland/ml/pytorch/__init__.py +++ b/eland/ml/pytorch/__init__.py @@ -20,6 +20,7 @@ FillMaskInferenceOptions, NerInferenceOptions, NlpBertTokenizationConfig, + NlpDebertaV2TokenizationConfig, NlpMPNetTokenizationConfig, NlpRobertaTokenizationConfig, NlpTrainedModelConfig, @@ -27,6 +28,7 @@ QuestionAnsweringInferenceOptions, TextClassificationInferenceOptions, TextEmbeddingInferenceOptions, + TextExpansionInferenceOptions, TextSimilarityInferenceOptions, ZeroShotClassificationInferenceOptions, ) @@ -43,12 +45,14 @@ "NerInferenceOptions", "NlpTrainedModelConfig", "NlpBertTokenizationConfig", + "NlpDebertaV2TokenizationConfig", "NlpRobertaTokenizationConfig", "NlpXLMRobertaTokenizationConfig", "NlpMPNetTokenizationConfig", "QuestionAnsweringInferenceOptions", "TextClassificationInferenceOptions", "TextEmbeddingInferenceOptions", + "TextExpansionInferenceOptions", "TextSimilarityInferenceOptions", "ZeroShotClassificationInferenceOptions", "task_type_from_model_config", diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py index eddd39b7..d848afe7 100644 --- a/eland/ml/pytorch/nlp_ml_model.py +++ b/eland/ml/pytorch/nlp_ml_model.py @@ -315,10 +315,12 @@ def __init__( *, tokenization: NlpTokenizationConfig, results_field: t.Optional[str] = None, + expansion_type: t.Optional[str] = "elser", ): super().__init__(configuration_type="text_expansion") self.tokenization = tokenization self.results_field = results_field + self.expansion_type = expansion_type class TrainedModelInput: diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py index 04d4ba86..101e0f22 100644 --- a/eland/ml/pytorch/transformers.py +++ b/eland/ml/pytorch/transformers.py @@ -32,7 +32,6 @@ from torch import Tensor from torch.profiler import profile # type: ignore from transformers import ( - BertTokenizer, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, @@ -509,7 +508,10 @@ def _find_max_sequence_length(self) -> int: if max_len is not None and max_len < REASONABLE_MAX_LENGTH: return int(max_len) - if isinstance(self._tokenizer, BertTokenizer): + # Known max input sizes for some tokenizers + if isinstance(self._tokenizer, transformers.BertTokenizer): + return 512 + if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer): return 512 raise UnknownModelInputSizeError("Cannot determine model max input length") @@ -552,6 +554,22 @@ def _create_config( tokenization=tokenization_config, embedding_size=embedding_size, ) + elif self._task_type == "text_expansion" and es_version >= (9, 2, 0): + sample_embedding = self._traceable_model.sample_output() + if type(sample_embedding) is tuple: + text_embedding = sample_embedding[0] + else: + text_embedding = sample_embedding + shape = text_embedding.shape + token_window = shape[1] + if token_window > 1: + expansion_type = "splade" + else: + expansion_type = "elser" + inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( + tokenization=tokenization_config, + expansion_type=expansion_type, + ) else: inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type]( tokenization=tokenization_config diff --git a/tests/ml/pytorch/test_pytorch_model_config_pytest.py b/tests/ml/pytorch/test_pytorch_model_config_pytest.py index aa141eac..7a45ee71 100644 --- a/tests/ml/pytorch/test_pytorch_model_config_pytest.py +++ b/tests/ml/pytorch/test_pytorch_model_config_pytest.py @@ -19,13 +19,6 @@ import pytest -try: - import sklearn # noqa: F401 - - HAS_SKLEARN = True -except ImportError: - HAS_SKLEARN = False - try: from eland.ml.pytorch.transformers import TransformerModel @@ -46,6 +39,7 @@ QuestionAnsweringInferenceOptions, TextClassificationInferenceOptions, TextEmbeddingInferenceOptions, + TextExpansionInferenceOptions, TextSimilarityInferenceOptions, ZeroShotClassificationInferenceOptions, ) @@ -54,13 +48,9 @@ except ImportError: HAS_PYTORCH = False - from tests import ES_VERSION pytestmark = [ - pytest.mark.skipif( - not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run" - ), pytest.mark.skipif( not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run" ), @@ -70,9 +60,9 @@ ] # If the required imports are missing the test will be skipped. -# Only define th test configurations if the referenced classes +# Only define the test configurations if the referenced classes # have been imported -if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS: +if HAS_PYTORCH and HAS_TRANSFORMERS: MODEL_CONFIGURATIONS = [ ( "sentence-transformers/all-distilroberta-v1", @@ -154,6 +144,14 @@ 512, None, ), + ( + "naver/splade-v3-distilbert", + "text_expansion", + TextExpansionInferenceOptions, + NlpBertTokenizationConfig, + 512, + None, + ), ] else: MODEL_CONFIGURATIONS = [] @@ -216,6 +214,9 @@ def test_model_config( if task_type == "text_similarity": assert tokenization.truncate == "second" + if task_type == "text_expansion": + assert config.inference_config.expansion_type == "splade" + del tm def test_model_config_with_prefix_string(self):