Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eland/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
# index_field.setter
self._is_source_field = False

self.es_index_field = es_index_field
self.es_index_field = es_index_field # type: ignore

@property
def sort_field(self) -> str:
Expand Down
8 changes: 4 additions & 4 deletions eland/ml/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
FillMaskInferenceOptions,
NerInferenceOptions,
NlpBertTokenizationConfig,
NlpDebertaV2TokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpTrainedModelConfig,
Expand All @@ -30,11 +31,9 @@
TextSimilarityInferenceOptions,
ZeroShotClassificationInferenceOptions,
)
from eland.ml.pytorch.tokenizers import UnknownModelInputSizeError
from eland.ml.pytorch.traceable_model import TraceableModel # noqa: F401
from eland.ml.pytorch.transformers import (
UnknownModelInputSizeError,
task_type_from_model_config,
)
from eland.ml.pytorch.transformers import task_type_from_model_config

__all__ = [
"PyTorchModel",
Expand All @@ -43,6 +42,7 @@
"NerInferenceOptions",
"NlpTrainedModelConfig",
"NlpBertTokenizationConfig",
"NlpDebertaV2TokenizationConfig",
"NlpRobertaTokenizationConfig",
"NlpXLMRobertaTokenizationConfig",
"NlpMPNetTokenizationConfig",
Expand Down
162 changes: 162 additions & 0 deletions eland/ml/pytorch/tokenizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Optional, Union

import transformers

from eland.ml.pytorch.nlp_ml_model import (
NlpBertJapaneseTokenizationConfig,
NlpBertTokenizationConfig,
NlpDebertaV2TokenizationConfig,
NlpMPNetTokenizationConfig,
NlpRobertaTokenizationConfig,
NlpTokenizationConfig,
NlpXLMRobertaTokenizationConfig,
)

# Tokenizer classes that eland knows how to map onto an Elasticsearch
# NLP tokenization config. Membership is checked with isinstance(), so
# fast and slow variants are both listed explicitly.
SUPPORTED_TOKENIZERS = (
    transformers.BertTokenizer,
    transformers.BertTokenizerFast,
    transformers.BertJapaneseTokenizer,
    transformers.MPNetTokenizer,
    transformers.MPNetTokenizerFast,
    transformers.DPRContextEncoderTokenizer,
    transformers.DPRContextEncoderTokenizerFast,
    transformers.DPRQuestionEncoderTokenizer,
    transformers.DPRQuestionEncoderTokenizerFast,
    transformers.DistilBertTokenizer,
    transformers.DistilBertTokenizerFast,
    transformers.ElectraTokenizer,
    transformers.ElectraTokenizerFast,
    transformers.MobileBertTokenizer,
    transformers.MobileBertTokenizerFast,
    transformers.RetriBertTokenizer,
    transformers.RetriBertTokenizerFast,
    transformers.RobertaTokenizer,
    transformers.RobertaTokenizerFast,
    transformers.BartTokenizer,
    transformers.BartTokenizerFast,
    transformers.SqueezeBertTokenizer,
    transformers.SqueezeBertTokenizerFast,
    transformers.XLMRobertaTokenizer,
    transformers.XLMRobertaTokenizerFast,
    transformers.DebertaV2Tokenizer,
    transformers.DebertaV2TokenizerFast,
)

# Human-readable, alphabetically sorted list used in error messages.
SUPPORTED_TOKENIZERS_NAMES = ", ".join(sorted(map(str, SUPPORTED_TOKENIZERS)))


class UnknownModelInputSizeError(Exception):
    """Raised when the maximum input sequence length of a model cannot be
    determined from its tokenizer (see ``find_max_sequence_length``)."""

    pass


def find_max_sequence_length(
    model_id: str,
    tokenizer: Union[
        transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast
    ],
) -> int:
    """Return the maximum input sequence length (in tokens) for *model_id*.

    The length is discovered, in order of preference, from:
    1. ``tokenizer.model_max_length``
    2. ``tokenizer.max_model_input_sizes[model_id]``
    3. the single common value in ``max_model_input_sizes`` when every
       known model agrees on one size
    4. a default of 512 for BERT tokenizers

    :raises UnknownModelInputSizeError: when none of the above yields a
        plausible length.
    """
    # Sometimes the max_... values are present but contain
    # a random or very large sentinel value, so anything above this
    # bound is treated as "not declared".
    REASONABLE_MAX_LENGTH = 8192

    max_len = getattr(tokenizer, "model_max_length", None)
    if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
        return int(max_len)

    max_sizes = getattr(tokenizer, "max_model_input_sizes", dict())
    max_len = max_sizes.get(model_id)
    # NOTE: uses <= to match the model_max_length check above (previously
    # `<`, which inconsistently rejected a length of exactly 8192 here).
    if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
        return int(max_len)

    if max_sizes:
        # The model id wasn't found in the max sizes dict but
        # if all the values correspond then take that value
        sizes = set(max_sizes.values())
        if len(sizes) == 1:
            max_len = sizes.pop()
            if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
                return int(max_len)

    # BERT models conventionally accept 512 tokens when nothing is declared.
    if isinstance(
        tokenizer, (transformers.BertTokenizer, transformers.BertTokenizerFast)
    ):
        return 512

    raise UnknownModelInputSizeError("Cannot determine model max input length")


def create_tokenization_config(
    model_id: str,
    max_model_input_size: Optional[int],
    tokenizer: Union[
        transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast
    ],
) -> NlpTokenizationConfig:
    """Build the Elasticsearch NLP tokenization config matching *tokenizer*.

    *max_model_input_size* overrides auto-detection of the maximum
    sequence length when it is not None.
    """
    if max_model_input_size is None:
        max_len = find_max_sequence_length(model_id, tokenizer)
    else:
        max_len = max_model_input_size

    # Not every tokenizer exposes do_lower_case; default to None.
    lower = getattr(tokenizer, "do_lower_case", None)

    if isinstance(
        tokenizer, (transformers.MPNetTokenizer, transformers.MPNetTokenizerFast)
    ):
        return NlpMPNetTokenizationConfig(
            do_lower_case=lower,
            max_sequence_length=max_len,
        )

    if isinstance(
        tokenizer,
        (
            transformers.RobertaTokenizer,
            transformers.RobertaTokenizerFast,
            transformers.BartTokenizer,
            transformers.BartTokenizerFast,
        ),
    ):
        return NlpRobertaTokenizationConfig(
            add_prefix_space=getattr(tokenizer, "add_prefix_space", None),
            max_sequence_length=max_len,
        )

    if isinstance(
        tokenizer,
        (transformers.XLMRobertaTokenizer, transformers.XLMRobertaTokenizerFast),
    ):
        return NlpXLMRobertaTokenizationConfig(max_sequence_length=max_len)

    if isinstance(
        tokenizer,
        (transformers.DebertaV2Tokenizer, transformers.DebertaV2TokenizerFast),
    ):
        return NlpDebertaV2TokenizationConfig(
            max_sequence_length=max_len,
            do_lower_case=lower,
        )

    # Everything else is treated as BERT-compatible. Japanese BERT models
    # that use a morphological word tokenizer get a dedicated config.
    if getattr(tokenizer, "word_tokenizer_type", None) in ("mecab",):
        return NlpBertJapaneseTokenizationConfig(
            do_lower_case=lower,
            max_sequence_length=max_len,
        )
    return NlpBertTokenizationConfig(
        do_lower_case=lower,
        max_sequence_length=max_len,
    )
174 changes: 172 additions & 2 deletions eland/ml/pytorch/traceable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@

import os.path
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import torch # type: ignore
from torch import nn
import transformers
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

from eland.ml.pytorch.wrappers import (
_DistilBertWrapper,
_DPREncoderWrapper,
_QuestionAnsweringWrapperModule,
_SentenceTransformerWrapperModule,
)

TracedModelTypes = Union[
torch.nn.Module,
Expand Down Expand Up @@ -68,3 +77,164 @@ def save(self, path: str) -> str:
@property
def model(self) -> nn.Module:
    # Read-only access to the wrapped PyTorch module.
    return self._model


class TransformerTraceableModel(TraceableModel):
    """A base class representing a HuggingFace transformer model that can be traced."""

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        model: Union[
            PreTrainedModel,
            _SentenceTransformerWrapperModule,
            _DPREncoderWrapper,
            _DistilBertWrapper,
            _QuestionAnsweringWrapperModule,
        ],
    ):
        # The tokenizer is kept so subclasses can build example inputs for
        # tracing (see _prepare_inputs / _compatible_inputs).
        super(TransformerTraceableModel, self).__init__(model=model)
        self._tokenizer = tokenizer

    def _trace(self) -> TracedModelTypes:
        # JIT-trace the model with tokenizer-compatible example inputs.
        inputs = self._compatible_inputs()
        return torch.jit.trace(self._model, example_inputs=inputs)

    def sample_output(self) -> Tensor:
        # Run the (untraced) model once on the example inputs; useful for
        # comparing against the traced model's output.
        inputs = self._compatible_inputs()
        return self._model(*inputs)

    def _compatible_inputs(self) -> Tuple[Tensor, ...]:
        """Arrange the tokenizer's output into the positional-argument tuple
        expected by this tokenizer family's model forward() signature."""
        inputs = self._prepare_inputs()

        # Add params when not provided by the tokenizer (e.g. DistilBERT), to conform to BERT interface
        if "token_type_ids" not in inputs:
            # NOTE(review): this creates a 1-D tensor of length seq_len,
            # not the (batch, seq_len) shape of input_ids — presumably
            # relying on broadcasting; confirm against the traced models.
            inputs["token_type_ids"] = torch.zeros(
                inputs["input_ids"].size(1), dtype=torch.long
            )
        # RoBERTa-family (and MPNet/BART) models take only ids + mask.
        if isinstance(
            self._tokenizer,
            (
                transformers.BartTokenizer,
                transformers.MPNetTokenizer,
                transformers.RobertaTokenizer,
                transformers.XLMRobertaTokenizer,
            ),
        ):
            return (inputs["input_ids"], inputs["attention_mask"])

        # DeBERTa-v2 additionally takes token_type_ids but no position_ids.
        if isinstance(self._tokenizer, transformers.DebertaV2Tokenizer):
            return (
                inputs["input_ids"],
                inputs["attention_mask"],
                inputs["token_type_ids"],
            )

        # BERT-style default: also pass explicit 0..seq_len-1 position ids.
        position_ids = torch.arange(inputs["input_ids"].size(1), dtype=torch.long)
        inputs["position_ids"] = position_ids
        return (
            inputs["input_ids"],
            inputs["attention_mask"],
            inputs["token_type_ids"],
            inputs["position_ids"],
        )

    @abstractmethod
    def _prepare_inputs(self) -> transformers.BatchEncoding:
        """Tokenize a representative example input for tracing."""
        ...


class TraceableClassificationModel(TransformerTraceableModel, ABC):
    """Traceable model with classification labels taken from its config."""

    def classification_labels(self) -> Optional[List[str]]:
        id2label = self._model.config.id2label
        ordered = [label for _, label in sorted(id2label.items(), key=lambda kv: kv[0])]

        # Make classes like I-PER into I_PER which fits Java enumerations
        return [label.replace("-", "_") for label in ordered]


class TraceableFillMaskModel(TransformerTraceableModel):
    """Traceable wrapper for fill-mask (masked-LM) models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # A question/answer pair with one masked token as the tracing example.
        question = "Who was Jim Henson?"
        masked_answer = "[MASK] Henson was a puppeteer"
        return self._tokenizer(
            question, masked_answer, padding="max_length", return_tensors="pt"
        )


class TraceableTextExpansionModel(TransformerTraceableModel):
    """Traceable wrapper for text-expansion models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # Any short sentence works as the tracing example.
        sample_text = "This is an example sentence."
        return self._tokenizer(
            sample_text, padding="max_length", return_tensors="pt"
        )


class TraceableNerModel(TraceableClassificationModel):
    """Traceable wrapper for named-entity-recognition models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # An entity-rich sentence as the tracing example (the two literals
        # are concatenated into a single string).
        sample_text = (
            "Hugging Face Inc. is a company based in New York City. "
            "Its headquarters are in DUMBO, therefore very close to the Manhattan Bridge."
        )
        return self._tokenizer(
            sample_text, padding="max_length", return_tensors="pt"
        )


class TraceablePassThroughModel(TransformerTraceableModel):
    """Traceable wrapper for pass-through models (raw model output)."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # Any short sentence works as the tracing example.
        sample_text = "This is an example sentence."
        return self._tokenizer(
            sample_text, padding="max_length", return_tensors="pt"
        )


class TraceableTextClassificationModel(TraceableClassificationModel):
    """Traceable wrapper for text-classification models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # Any short sentence works as the tracing example.
        sample_text = "This is an example sentence."
        return self._tokenizer(
            sample_text, padding="max_length", return_tensors="pt"
        )


class TraceableTextEmbeddingModel(TransformerTraceableModel):
    """Traceable wrapper for text-embedding models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # Embedding models trace with "longest" padding (not "max_length"),
        # matching how they are typically invoked.
        sample_text = "This is an example sentence."
        return self._tokenizer(
            sample_text, padding="longest", return_tensors="pt"
        )


class TraceableZeroShotClassificationModel(TraceableClassificationModel):
    """Traceable wrapper for zero-shot classification models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # A premise/hypothesis sentence pair as the tracing example.
        premise = "This is an example sentence."
        hypothesis = "This example is an example."
        return self._tokenizer(
            premise, hypothesis, padding="max_length", return_tensors="pt"
        )


class TraceableQuestionAnsweringModel(TransformerTraceableModel):
    """Traceable wrapper for question-answering models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # NOTE(review): there is no comma between the two string literals,
        # so Python concatenates them and the tokenizer receives ONE string,
        # not a (question, context) pair. Possibly a missing comma — confirm
        # intent before changing, as it alters the traced input encoding.
        return self._tokenizer(
            "What is the meaning of life?"
            "The meaning of life, according to the hitchikers guide, is 42.",
            padding="max_length",
            return_tensors="pt",
        )


class TraceableTextSimilarityModel(TransformerTraceableModel):
    """Traceable wrapper for text-similarity models."""

    def _prepare_inputs(self) -> transformers.BatchEncoding:
        # NOTE(review): as in TraceableQuestionAnsweringModel, the adjacent
        # literals are implicitly concatenated into a single string; a
        # (text, text_pair) call may have been intended — confirm before
        # changing, as it alters the traced input encoding.
        return self._tokenizer(
            "What is the meaning of life?"
            "The meaning of life, according to the hitchikers guide, is 42.",
            padding="max_length",
            return_tensors="pt",
        )
Loading