From a9aeecdd4f4324503013ec3de2d3d9cc7cc0d725 Mon Sep 17 00:00:00 2001
From: 黎同学 <958292128@qq.com>
Date: Thu, 16 May 2024 20:11:23 +0800
Subject: [PATCH] add FillMaskPipeline and register the fill-mask task
---
llm/pipelines/fill_mask_demo.py | 27 ++
mindnlp/transformers/pipelines/__init__.py | 15 +-
mindnlp/transformers/pipelines/fill_mask.py | 253 ++++++++++++
.../pipelines/test_pipelines_fill_mask.py | 379 ++++++++++++++++++
4 files changed, 672 insertions(+), 2 deletions(-)
create mode 100644 llm/pipelines/fill_mask_demo.py
create mode 100644 mindnlp/transformers/pipelines/fill_mask.py
create mode 100644 tests/ut/transformers/pipelines/test_pipelines_fill_mask.py
diff --git a/llm/pipelines/fill_mask_demo.py b/llm/pipelines/fill_mask_demo.py
new file mode 100644
index 00000000..86904d63
--- /dev/null
+++ b/llm/pipelines/fill_mask_demo.py
@@ -0,0 +1,27 @@
+from mindnlp.transformers import pipeline
+
+unmasker = pipeline('fill-mask', model='distilbert-base-uncased')
+output = unmasker("Hello I'm a [MASK] model.")
+print(output)
+'''
+[{'sequence': "[CLS] hello i'm a role model. [SEP]",
+ 'score': 0.05292865261435509,
+ 'token': 2535,
+ 'token_str': 'role'},
+ {'sequence': "[CLS] hello i'm a fashion model. [SEP]",
+ 'score': 0.0396859310567379,
+ 'token': 4827,
+ 'token_str': 'fashion'},
+ {'sequence': "[CLS] hello i'm a business model. [SEP]",
+ 'score': 0.034743666648864746,
+ 'token': 2449,
+ 'token_str': 'business'},
+ {'sequence': "[CLS] hello i'm a model model. [SEP]",
+ 'score': 0.034622687846422195,
+ 'token': 2944,
+ 'token_str': 'model'},
+ {'sequence': "[CLS] hello i'm a modeling model. [SEP]",
+ 'score': 0.018145263195037842,
+ 'token': 11643,
+ 'token_str': 'modeling'}]
+'''
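+
+# A further, hedged sketch (scores and ordering vary by checkpoint):
+# `targets` restricts scoring to the listed tokens and `top_k` caps how many
+# predictions are returned per mask.
+output = unmasker(
+    "Hello I'm a [MASK] model.",
+    targets=["role", "fashion"],
+    top_k=1,
+)
+print(output)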
diff --git a/mindnlp/transformers/pipelines/__init__.py b/mindnlp/transformers/pipelines/__init__.py
index 526ffaf9..4c49813a 100644
--- a/mindnlp/transformers/pipelines/__init__.py
+++ b/mindnlp/transformers/pipelines/__init__.py
@@ -53,6 +53,7 @@ from .question_answering import QuestionAnsweringPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .document_question_answering import DocumentQuestionAnsweringPipeline
+from .fill_mask import FillMaskPipeline
from ..models.auto.modeling_auto import (
# AutoModel,
@@ -60,7 +61,7 @@ from ..models.auto.modeling_auto import (
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForDocumentQuestionAnswering,
- # AutoModelForMaskedLM,
+ AutoModelForMaskedLM,
# AutoModelForMaskGeneration,
# AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
@@ -162,7 +163,16 @@ SUPPORTED_TASKS = {
},
"type": "multimodal",
},
-
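+    # Register the fill-mask task: pipeline("fill-mask") resolves to
+    # FillMaskPipeline backed by AutoModelForMaskedLM and, when no model is
+    # given, falls back to the pinned distilroberta-base checkpoint below.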
+ "fill-mask": {
+ "impl": FillMaskPipeline,
+ "ms": (AutoModelForMaskedLM,),
+ "default": {
+ "model": {
+ "ms": ("distilbert/distilroberta-base", "ec58a5b"),
+ }
+ },
+ "type": "text",
+ },
}
NO_FEATURE_EXTRACTOR_TASKS = set()
@@ -597,6 +607,7 @@ def pipeline(
__all__ = [
'CsvPipelineDataFormat',
+ 'FillMaskPipeline',
'JsonPipelineDataFormat',
'PipedPipelineDataFormat',
'Pipeline',
diff --git a/mindnlp/transformers/pipelines/fill_mask.py b/mindnlp/transformers/pipelines/fill_mask.py
new file mode 100644
index 00000000..7a4c7034
--- /dev/null
+++ b/mindnlp/transformers/pipelines/fill_mask.py
@@ -0,0 +1,253 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"Fill Mask Pipeline"
+from typing import Dict
+
+import numpy as np
+
+from mindnlp.utils import is_mindspore_available, logging
+from .base import GenericTensor, Pipeline, PipelineException
+
+
+if is_mindspore_available():
+ import mindspore as ms
+ from mindspore import ops
+
+
+logger = logging.get_logger(__name__)
+
+
+class FillMaskPipeline(Pipeline):
+ """
+ Masked language modeling prediction pipeline using any `ModelWithLMHead`.
+ See the [masked language modeling
+ examples](../task_summary#masked-language-modeling) for more information.
+
+ Example:
+
+ ```python
+ >>> from mindnlp.transformers import pipeline
+
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
+ >>> fill_masker("This is a simple [MASK].")
+ [{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
+ ```
+
+ Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
+
+ This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
+ `"fill-mask"`.
+
+ The models that this pipeline can use are models that have been trained with a masked language modeling objective,
+ which includes the bi-directional models in the library. See the up-to-date list of available models on
+ [huggingface.co/models](https://huggingface.co/models?filter=fill-mask).
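+
+    ```python
+    >>> # Hedged sketch: loading purely by the task identifier falls back to
+    >>> # the default checkpoint registered for "fill-mask".
+    >>> fill_masker = pipeline("fill-mask")
+    ```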
+
+    <Tip>
+
+ This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+ masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
+ joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
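+
+    ```python
+    >>> # Hedged sketch: with more than one mask, each masked position gets
+    >>> # its own list of candidates, so the result is a list of lists.
+    >>> fill_masker("Paris is the [MASK] of [MASK].")
+    ```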
+
+    </Tip>
+
+    <Tip>
+
+    This pipeline now supports `tokenizer_kwargs`. For example, try:
+
+ ```python
+ >>> from mindnlp.transformers import pipeline
+
+ >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
+ >>> tokenizer_kwargs = {"truncation": True}
+ >>> fill_masker(
+ ... "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
+ ... tokenizer_kwargs=tokenizer_kwargs,
+ ... )
+ ```
+
+    </Tip>
+
+ """
+    def get_masked_index(self, input_ids: GenericTensor) -> GenericTensor:
+        # One row of indices per occurrence of the mask token in `input_ids`.
+        masked_index = ops.nonzero(input_ids == self.tokenizer.mask_token_id)
+        return masked_index
+
+    def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> None:
+ masked_index = self.get_masked_index(input_ids)
+ numel = np.prod(masked_index.shape)
+ if numel < 1:
+ raise PipelineException(
+ "fill-mask",
+ self.model.base_model_prefix,
+ f"No mask_token ({self.tokenizer.mask_token}) found on the input",
+ )
+
+ def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
+ if isinstance(model_inputs, list):
+ for model_input in model_inputs:
+ self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
+ else:
+ for input_ids in model_inputs["input_ids"]:
+ self._ensure_exactly_one_mask_token(input_ids)
+
+ def preprocess(
+ self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
+ ) -> Dict[str, GenericTensor]:
+ if return_tensors is None:
+            return_tensors = "ms"
+ if tokenizer_kwargs is None:
+ tokenizer_kwargs = {}
+
+ model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
+ self.ensure_exactly_one_mask_token(model_inputs)
+ return model_inputs
+
+ def _forward(self, model_inputs):
+ model_outputs = self.model(**model_inputs)
+ model_outputs["input_ids"] = model_inputs["input_ids"]
+ return model_outputs
+
+ def postprocess(self, model_outputs, top_k=5, target_ids=None):
+ # Cap top_k if there are targets
+ if target_ids is not None and target_ids.shape[0] < top_k:
+ top_k = target_ids.shape[0]
+ input_ids = model_outputs["input_ids"][0]
+ outputs = model_outputs["logits"]
+
+        masked_index = ops.nonzero(input_ids == self.tokenizer.mask_token_id).squeeze(-1)
+        # Score every masked position; multiple ${mask_token}s per sample are
+        # supported (experimental), each scored independently.
+        logits = outputs[0, masked_index, :]
+        probs = ops.softmax(logits, axis=-1)
+        if target_ids is not None:
+            # Index with a MindSpore tensor; a raw numpy index may not be accepted.
+            probs = probs[..., ms.Tensor(target_ids)]
+
+ values, predictions = probs.topk(top_k)
+
+ result = []
+ single_mask = values.shape[0] == 1
+ for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+ row = []
+ for v, p in zip(_values, _predictions):
+ # Copy is important since we're going to modify this array in place
+ tokens = input_ids.numpy().copy()
+ if target_ids is not None:
+ p = target_ids[p].tolist()
+
+                # Cast the tensor index to a Python int for the numpy assignment.
+                tokens[int(masked_index[i])] = p
+ # Filter padding out:
+ tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+ # Originally we skip special tokens to give readable output.
+ # For multi masks though, the other [MASK] would be removed otherwise
+ # making the output look odd, so we add them back
+ sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+ proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
+ row.append(proposition)
+ result.append(row)
+ if single_mask:
+ return result[0]
+ return result
+
+ def get_target_ids(self, targets, top_k=None):
+ if isinstance(targets, str):
+ targets = [targets]
+ try:
+ vocab = self.tokenizer.get_vocab()
+ except Exception:
+ vocab = {}
+ target_ids = []
+ for target in targets:
+ id_ = vocab.get(target, None)
+ if id_ is None:
+ input_ids = self.tokenizer(
+ target,
+ add_special_tokens=False,
+ return_attention_mask=False,
+ return_token_type_ids=False,
+ max_length=1,
+ truncation=True,
+ )["input_ids"]
+ if len(input_ids) == 0:
+ logger.warning(
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
+ "We cannot replace it with anything meaningful, ignoring it"
+ )
+ continue
+ id_ = input_ids[0]
+                # XXX: if users hit this path, tokenizing each target makes the
+                # call noticeably slower, so warn them so they can fix the
+                # input and get faster performance.
+ logger.warning(
+ f"The specified target token `{target}` does not exist in the model vocabulary. "
+ f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
+ )
+ target_ids.append(id_)
+ target_ids = list(set(target_ids))
+ if len(target_ids) == 0:
+            raise ValueError("At least one valid target must be provided when `targets` is passed.")
+ target_ids = np.array(target_ids)
+ return target_ids
+
+ def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
+ preprocess_params = {}
+
+ if tokenizer_kwargs is not None:
+ preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
+
+ postprocess_params = {}
+
+ if targets is not None:
+ target_ids = self.get_target_ids(targets, top_k)
+ postprocess_params["target_ids"] = target_ids
+
+ if top_k is not None:
+ postprocess_params["top_k"] = top_k
+
+ if self.tokenizer.mask_token_id is None:
+ raise PipelineException(
+ "fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
+ )
+ return preprocess_params, {}, postprocess_params
+
+ def __call__(self, inputs, *args, **kwargs):
+ """
+ Fill the masked token in the text(s) given as inputs.
+
+ Args:
+            inputs (`str` or `List[str]`):
+ One or several texts (or one list of prompts) with masked tokens.
+ targets (`str` or `List[str]`, *optional*):
+ When passed, the model will limit the scores to the passed targets instead of looking up in the whole
+ vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
+ resulting token will be used (with a warning, and that might be slower).
+ top_k (`int`, *optional*):
+ When passed, overrides the number of predictions to return.
+
+ Return:
+            A list or a list of lists of `dict`: Each result comes as a list of dictionaries with the following keys:
+
+ - **sequence** (`str`) -- The corresponding input with the mask token prediction.
+ - **score** (`float`) -- The corresponding probability.
+ - **token** (`int`) -- The predicted token id (to replace the masked one).
+ - **token_str** (`str`) -- The predicted token (to replace the masked one).
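+
+        Example (a hedged sketch; exact tokens and scores depend on the
+        checkpoint in use):
+
+        ```python
+        >>> fill_masker("This is a simple [MASK].", targets=["problem", "question"], top_k=1)
+        ```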
+ """
+ outputs = super().__call__(inputs, **kwargs)
+ if isinstance(inputs, list) and len(inputs) == 1:
+ return outputs[0]
+ return outputs
diff --git a/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py b/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py
new file mode 100644
index 00000000..0ac90ca3
--- /dev/null
+++ b/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py
@@ -0,0 +1,379 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+from mindnlp.transformers import MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
+from mindnlp.transformers.pipelines.base import PipelineException
+from mindnlp.utils.testing_utils import (
+ is_pipeline_test,
+ nested_simplify,
+ require_mindspore,
+ slow,
+)
+
+from .test_pipelines_common import ANY
+
+
+@is_pipeline_test
+class FillMaskPipelineTests(unittest.TestCase):
+ model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+
+ def tearDown(self):
+ super().tearDown()
+        # clean up as much GPU memory as possible between tests
+ gc.collect()
+
+ @require_mindspore
+ def test_small_model_ms(self):
+ unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2)
+
+        outputs = unmasker("My name is <mask>")
+ self.assertEqual(
+ nested_simplify(outputs, decimals=6),
+ [
+ {"sequence": "My name is Maul", "score": 2.2e-05, "token": 35676, "token_str": " Maul"},
+ {"sequence": "My name isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
+ ],
+ )
+
+        outputs = unmasker("The largest city in France is <mask>")
+ self.assertEqual(
+ nested_simplify(outputs, decimals=6),
+ [
+ {
+ "sequence": "The largest city in France is Maul",
+ "score": 2.2e-05,
+ "token": 35676,
+ "token_str": " Maul",
+ },
+ {"sequence": "The largest city in France isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
+ ],
+ )
+
+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+ self.assertEqual(
+ nested_simplify(outputs, decimals=6),
+ [
+ [
+ {
+ "score": 2.2e-05,
+ "token": 35676,
+ "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+ ],
+ [
+ {
+ "score": 2.2e-05,
+ "token": 35676,
+ "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+ ],
+ ],
+ )
+
+
+    @require_mindspore
+    def test_fp16_casting(self):
+ pipe = pipeline(
+ "fill-mask",
+ model="hf-internal-testing/tiny-random-distilbert",
+ )
+
+        # Convert the model to fp16 (left disabled until half() is supported
+        # in this backend):
+        # pipe.model.half()
+
+        response = pipe("Paris is the [MASK] of France.")
+        # We don't care about the exact result; we only want to make sure the
+        # pipeline runs, i.e. that any float16 tensors are cast back to
+        # float32 for postprocessing.
+ self.assertIsInstance(response, list)
+
+ @slow
+ @require_mindspore
+ def test_large_model_ms(self):
+ unmasker = pipeline(task="fill-mask", model="distilbert/distilroberta-base", top_k=2)
+ self.run_large_test(unmasker)
+
+ def run_large_test(self, unmasker):
+        outputs = unmasker("My name is <mask>")
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ {"sequence": "My name is John", "score": 0.008, "token": 610, "token_str": " John"},
+ {"sequence": "My name is Chris", "score": 0.007, "token": 1573, "token_str": " Chris"},
+ ],
+ )
+        outputs = unmasker("The largest city in France is <mask>")
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ {
+ "sequence": "The largest city in France is Paris",
+ "score": 0.251,
+ "token": 2201,
+ "token_str": " Paris",
+ },
+ {
+ "sequence": "The largest city in France is Lyon",
+ "score": 0.214,
+ "token": 12790,
+ "token_str": " Lyon",
+ },
+ ],
+ )
+
+        outputs = unmasker("My name is <mask>", targets=[" Patrick", " Clara", " Teven"], top_k=3)
+ self.assertEqual(
+ nested_simplify(outputs),
+ [
+ {"sequence": "My name is Patrick", "score": 0.005, "token": 3499, "token_str": " Patrick"},
+ {"sequence": "My name is Clara", "score": 0.000, "token": 13606, "token_str": " Clara"},
+ {"sequence": "My name is Te", "score": 0.000, "token": 2941, "token_str": " Te"},
+ ],
+ )
+
+ dummy_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100
+ outputs = unmasker(
+            "My name is <mask>" + dummy_str,
+ tokenizer_kwargs={"truncation": True},
+ )
+ simplified = nested_simplify(outputs, decimals=4)
+ self.assertEqual(
+ [{"sequence": x["sequence"][:100]} for x in simplified],
+ [
+ {"sequence": f"My name is,{dummy_str}"[:100]},
+ {"sequence": f"My name is:,{dummy_str}"[:100]},
+ ],
+ )
+ self.assertEqual(
+ [{k: x[k] for k in x if k != "sequence"} for x in simplified],
+ [
+ {"score": 0.2819, "token": 6, "token_str": ","},
+ {"score": 0.0954, "token": 46686, "token_str": ":,"},
+ ],
+ )
+
+ @require_mindspore
+ def test_model_no_pad_ms(self):
+ unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base")
+ unmasker.tokenizer.pad_token_id = None
+ unmasker.tokenizer.pad_token = None
+ self.run_pipeline_test(unmasker, [])
+
+ def get_test_pipeline(self, model, tokenizer, processor):
+ if tokenizer is None or tokenizer.mask_token_id is None:
+            self.skipTest("The provided tokenizer has no mask token (probably reformer or wav2vec2)")
+
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+ examples = [
+ f"This is another {tokenizer.mask_token} test",
+ ]
+ return fill_masker, examples
+
+ def run_pipeline_test(self, fill_masker, examples):
+ tokenizer = fill_masker.tokenizer
+ model = fill_masker.model
+ outputs = fill_masker(
+ f"This is a {tokenizer.mask_token}",
+ )
+ self.assertEqual(
+ outputs,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+
+ outputs = fill_masker([f"This is a {tokenizer.mask_token}"])
+ self.assertEqual(
+ outputs,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+
+ outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
+ self.assertEqual(
+ outputs,
+ [
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ ],
+ )
+
+ with self.assertRaises(ValueError):
+ fill_masker([None])
+        # Inputs without a mask_token are not supported
+ with self.assertRaises(PipelineException):
+ fill_masker("This is")
+
+ self.run_test_top_k(model, tokenizer)
+ self.run_test_targets(model, tokenizer)
+ self.run_test_top_k_targets(model, tokenizer)
+ self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+ self.fill_mask_with_multiple_masks(model, tokenizer)
+
+ def run_test_targets(self, model, tokenizer):
+ vocab = tokenizer.get_vocab()
+ targets = sorted(vocab.keys())[:2]
+ # Pipeline argument
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, targets=targets)
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}")
+ self.assertEqual(
+ outputs,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+ target_ids = {vocab[el] for el in targets}
+ self.assertEqual({el["token"] for el in outputs}, target_ids)
+ processed_targets = [tokenizer.decode([x]) for x in target_ids]
+ self.assertEqual({el["token_str"] for el in outputs}, set(processed_targets))
+
+ # Call argument
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
+ self.assertEqual(
+ outputs,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+ target_ids = {vocab[el] for el in targets}
+ self.assertEqual({el["token"] for el in outputs}, target_ids)
+ processed_targets = [tokenizer.decode([x]) for x in target_ids]
+ self.assertEqual({el["token_str"] for el in outputs}, set(processed_targets))
+
+ # Score equivalence
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
+ tokens = [top_mask["token_str"] for top_mask in outputs]
+ scores = [top_mask["score"] for top_mask in outputs]
+
+        # For some BPE tokenizers, `</w>` is removed during decoding, so `token_str` won't be the same as in `targets`.
+ if set(tokens) == set(targets):
+ unmasked_targets = fill_masker(f"This is a {tokenizer.mask_token}", targets=tokens)
+ target_scores = [top_mask["score"] for top_mask in unmasked_targets]
+ self.assertEqual(nested_simplify(scores), nested_simplify(target_scores))
+
+ # Raises with invalid
+ with self.assertRaises(ValueError):
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[])
+        # For some tokenizers, `""` is actually in the vocabulary and the expected error won't be raised
+ if "" not in tokenizer.get_vocab():
+ with self.assertRaises(ValueError):
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[""])
+ with self.assertRaises(ValueError):
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets="")
+
+ def run_test_top_k(self, model, tokenizer):
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, top_k=2)
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}")
+ self.assertEqual(
+ outputs,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+ outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2)
+ self.assertEqual(
+ outputs2,
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ )
+ self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
+
+ def run_test_top_k_targets(self, model, tokenizer):
+ vocab = tokenizer.get_vocab()
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+ # top_k=2, ntargets=3
+ targets = sorted(vocab.keys())[:3]
+ outputs = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2, targets=targets)
+
+        # If we use the most probable targets, and filter differently, we should still
+        # have the same results
+ targets2 = [el["token_str"] for el in sorted(outputs, key=lambda x: x["score"], reverse=True)]
+        # For some BPE tokenizers, `</w>` is removed during decoding, so `token_str` won't be the same as in `targets`.
+ if set(targets2).issubset(targets):
+ outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=3, targets=targets2)
+ # They should yield exactly the same result
+ self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
+
+ def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+ vocab = tokenizer.get_vocab()
+ # String duplicates + id duplicates
+ targets = sorted(vocab.keys())[:3]
+ targets = [targets[0], targets[1], targets[0], targets[2], targets[1]]
+ outputs = fill_masker(f"My name is {tokenizer.mask_token}", targets=targets, top_k=10)
+
+        # The target list contains duplicates, so we can't return more
+        # predictions than there are unique targets.
+ self.assertEqual(len(outputs), 3)
+
+ def fill_mask_with_multiple_masks(self, model, tokenizer):
+ fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+ outputs = fill_masker(
+ f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+ )
+ self.assertEqual(
+ outputs,
+ [
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ [
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+ ],
+ ],
+ )
--
Gitee