4 changes: 4 additions & 0 deletions eland/ml/pytorch/__init__.py
@@ -20,13 +20,15 @@
FillMaskInferenceOptions,
NerInferenceOptions,
NlpBertTokenizationConfig,
NlpDebertaV2TokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpTrainedModelConfig,
NlpXLMRobertaTokenizationConfig,
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextExpansionInferenceOptions,
TextSimilarityInferenceOptions,
ZeroShotClassificationInferenceOptions,
)
@@ -43,12 +45,14 @@
"NerInferenceOptions",
"NlpTrainedModelConfig",
"NlpBertTokenizationConfig",
"NlpDebertaV2TokenizationConfig",
"NlpRobertaTokenizationConfig",
"NlpXLMRobertaTokenizationConfig",
"NlpMPNetTokenizationConfig",
"QuestionAnsweringInferenceOptions",
"TextClassificationInferenceOptions",
"TextEmbeddingInferenceOptions",
"TextExpansionInferenceOptions",
"TextSimilarityInferenceOptions",
"ZeroShotClassificationInferenceOptions",
"task_type_from_model_config",
2 changes: 2 additions & 0 deletions eland/ml/pytorch/nlp_ml_model.py
@@ -315,10 +315,12 @@ def __init__(
*,
tokenization: NlpTokenizationConfig,
results_field: t.Optional[str] = None,
expansion_type: t.Optional[str] = "elser",
):
super().__init__(configuration_type="text_expansion")
self.tokenization = tokenization
self.results_field = results_field
self.expansion_type = expansion_type


class TrainedModelInput:
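For orientation, a minimal sketch of how the new keyword is passed through follows; the no-argument NlpBertTokenizationConfig and the final assert are purely illustrative and are not part of this diff.

from eland.ml.pytorch import (
    NlpBertTokenizationConfig,
    TextExpansionInferenceOptions,
)

# Placeholder tokenization settings; in the real flow these come from the
# traced model's tokenizer.
tokenization = NlpBertTokenizationConfig()

# expansion_type defaults to "elser", so existing callers are unaffected;
# SPLADE-style models can now pass "splade" explicitly.
config = TextExpansionInferenceOptions(
    tokenization=tokenization,
    expansion_type="splade",
)
assert config.expansion_type == "splade"
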
22 changes: 20 additions & 2 deletions eland/ml/pytorch/transformers.py
@@ -32,7 +32,6 @@
from torch import Tensor
from torch.profiler import profile # type: ignore
from transformers import (
BertTokenizer,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
@@ -509,7 +508,10 @@ def _find_max_sequence_length(self) -> int:
if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
return int(max_len)

if isinstance(self._tokenizer, BertTokenizer):
# Known max input sizes for some tokenizers
if isinstance(self._tokenizer, transformers.BertTokenizer):
return 512
if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
return 512

raise UnknownModelInputSizeError("Cannot determine model max input length")
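One common reason no usable limit is found here: when a checkpoint's tokenizer_config.json does not set one, transformers substitutes a very large sentinel for model_max_length, which fails the REASONABLE_MAX_LENGTH check above and falls through to the isinstance fallback. A rough way to inspect what a given checkpoint reports (the model id is an illustrative assumption and the call downloads the tokenizer):

from transformers import AutoTokenizer

# Checkpoints that omit a real limit surface transformers' huge sentinel
# instead of e.g. 512; the exact value printed depends on the checkpoint.
tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
print(type(tok).__name__, tok.model_max_length)
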
@@ -552,6 +554,22 @@ def _create_config(
tokenization=tokenization_config,
embedding_size=embedding_size,
)
elif self._task_type == "text_expansion" and es_version >= (9, 2, 0):
sample_embedding = self._traceable_model.sample_output()
if type(sample_embedding) is tuple:
text_embedding = sample_embedding[0]

Member
Please rename text_embedding to sparse_embedding

Contributor Author (@daixque, Jul 23, 2025)
In the current codebase, text_expansion is used everywhere; there is no sparse_embedding:
https://github.com/search?q=repo%3Aelastic%2Feland%20text_expansion&type=code

% grep -inR "text_expansion" eland tests | grep -v "Binary file"
eland/ml/pytorch/transformers.py:76:    "text_expansion",
eland/ml/pytorch/transformers.py:97:    "text_expansion": TextExpansionInferenceOptions,
eland/ml/pytorch/transformers.py:557:        elif self._task_type == "text_expansion":
eland/ml/pytorch/transformers.py:747:        if self._task_type == "text_expansion":
eland/ml/pytorch/nlp_ml_model.py:320:        super().__init__(configuration_type="text_expansion")
tests/ml/pytorch/test_pytorch_model_config_pytest.py:149:            "text_expansion",
tests/ml/pytorch/test_pytorch_model_config_pytest.py:217:            if task_type == "text_expansion":

Should I rename everything? That would change the CLI interface. Should we keep --task-type=text_expansion for compatibility? (I feel the renaming should be a separate PR.)

Member
Ah, ok, thanks. Yes, the rename is not necessary in this PR.

Contributor Author
Got it, thanks.

else:
text_embedding = sample_embedding
shape = text_embedding.shape
token_window = shape[1]
if token_window > 1:
expansion_type = "splade"
else:
expansion_type = "elser"
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config,
expansion_type=expansion_type,
)
else:
inference_config = TASK_TYPE_TO_INFERENCE_CONFIG[self._task_type](
tokenization=tokenization_config
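To make the new branch easier to follow in isolation, here is a self-contained sketch of the same heuristic; the helper name and the example tensor shapes are illustrative assumptions, not part of the eland API.

import torch


def infer_expansion_type(sample_output) -> str:
    # Mirror of the heuristic above: the second dimension of the traced
    # model's sample output decides between SPLADE- and ELSER-style expansion.
    embedding = sample_output[0] if isinstance(sample_output, tuple) else sample_output
    token_window = embedding.shape[1]
    return "splade" if token_window > 1 else "elser"


# Illustrative shapes only (batch, token_window, vocab); real values come from
# self._traceable_model.sample_output().
print(infer_expansion_type(torch.zeros(1, 12, 30522)))  # "splade"
print(infer_expansion_type(torch.zeros(1, 1, 30522)))   # "elser"
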
27 changes: 14 additions & 13 deletions tests/ml/pytorch/test_pytorch_model_config_pytest.py
@@ -19,13 +19,6 @@

import pytest

try:
import sklearn # noqa: F401

HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False

try:
from eland.ml.pytorch.transformers import TransformerModel

@@ -46,6 +39,7 @@
QuestionAnsweringInferenceOptions,
TextClassificationInferenceOptions,
TextEmbeddingInferenceOptions,
TextExpansionInferenceOptions,
TextSimilarityInferenceOptions,
ZeroShotClassificationInferenceOptions,
)
@@ -54,13 +48,9 @@
except ImportError:
HAS_PYTORCH = False


from tests import ES_VERSION

pytestmark = [
pytest.mark.skipif(
not HAS_SKLEARN, reason="This test requires 'scikit-learn' package to run"
),
pytest.mark.skipif(
not HAS_TRANSFORMERS, reason="This test requires 'transformers' package to run"
),
@@ -70,9 +60,9 @@
]

# If the required imports are missing the test will be skipped.
# Only define th test configurations if the referenced classes
# Only define the test configurations if the referenced classes
# have been imported
if HAS_PYTORCH and HAS_SKLEARN and HAS_TRANSFORMERS:
if HAS_PYTORCH and HAS_TRANSFORMERS:
MODEL_CONFIGURATIONS = [
(
"sentence-transformers/all-distilroberta-v1",
@@ -154,6 +144,14 @@
512,
None,
),
(
"naver/splade-v3-distilbert",
"text_expansion",
TextExpansionInferenceOptions,
NlpBertTokenizationConfig,
512,
None,
),
]
else:
MODEL_CONFIGURATIONS = []
@@ -216,6 +214,9 @@ def test_model_config(
if task_type == "text_similarity":
assert tokenization.truncate == "second"

if task_type == "text_expansion":
assert config.inference_config.expansion_type == "splade"

del tm

def test_model_config_with_prefix_string(self):
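A rough way to exercise the new assertion outside pytest, under some assumptions: the TransformerModel keyword arguments and the (model_path, config, vocab_path) return value of save() are modeled on how the existing tests use the class but are not shown in this diff, and the checkpoint download is sizeable.

import tempfile

from eland.ml.pytorch.transformers import TransformerModel

with tempfile.TemporaryDirectory() as tmp_dir:
    # es_version must be at least (9, 2, 0) to reach the new text_expansion
    # branch in _create_config.
    tm = TransformerModel(
        model_id="naver/splade-v3-distilbert",
        task_type="text_expansion",
        es_version=(9, 2, 0),
    )
    _, config, _ = tm.save(tmp_dir)  # assumed to return (model_path, config, vocab_path)

# The new test case expects a SPLADE checkpoint to be detected as "splade".
assert config.inference_config.expansion_type == "splade"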