From a9aeecdd4f4324503013ec3de2d3d9cc7cc0d725 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=8E=E5=90=8C=E5=AD=A6?= <958292128@qq.com>
Date: Thu, 16 May 2024 20:11:23 +0800
Subject: [PATCH] first commit

---
 llm/pipelines/fill_mask_demo.py             |  27 ++
 mindnlp/transformers/pipelines/__init__.py  |  15 +-
 mindnlp/transformers/pipelines/fill_mask.py | 253 ++++++++++++
 .../pipelines/test_pipelines_fill_mask.py   | 379 ++++++++++++++++++
 4 files changed, 672 insertions(+), 2 deletions(-)
 create mode 100644 llm/pipelines/fill_mask_demo.py
 create mode 100644 mindnlp/transformers/pipelines/fill_mask.py
 create mode 100644 tests/ut/transformers/pipelines/test_pipelines_fill_mask.py

diff --git a/llm/pipelines/fill_mask_demo.py b/llm/pipelines/fill_mask_demo.py
new file mode 100644
index 00000000..86904d63
--- /dev/null
+++ b/llm/pipelines/fill_mask_demo.py
@@ -0,0 +1,27 @@
+from mindnlp.transformers import pipeline
+
+unmasker = pipeline('fill-mask', model='distilbert-base-uncased')
+output = unmasker("Hello I'm a [MASK] model.")
+print(output)
+'''
+[{'sequence': "[CLS] hello i'm a role model. [SEP]",
+  'score': 0.05292865261435509,
+  'token': 2535,
+  'token_str': 'role'},
+ {'sequence': "[CLS] hello i'm a fashion model. [SEP]",
+  'score': 0.0396859310567379,
+  'token': 4827,
+  'token_str': 'fashion'},
+ {'sequence': "[CLS] hello i'm a business model. [SEP]",
+  'score': 0.034743666648864746,
+  'token': 2449,
+  'token_str': 'business'},
+ {'sequence': "[CLS] hello i'm a model model. [SEP]",
+  'score': 0.034622687846422195,
+  'token': 2944,
+  'token_str': 'model'},
+ {'sequence': "[CLS] hello i'm a modeling model. [SEP]",
+  'score': 0.018145263195037842,
+  'token': 11643,
+  'token_str': 'modeling'}]
+'''
diff --git a/mindnlp/transformers/pipelines/__init__.py b/mindnlp/transformers/pipelines/__init__.py
index 526ffaf9..4c49813a 100644
--- a/mindnlp/transformers/pipelines/__init__.py
+++ b/mindnlp/transformers/pipelines/__init__.py
@@ -53,6 +53,7 @@ from .question_answering import QuestionAnsweringPipeline
 from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
 from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
 from .document_question_answering import DocumentQuestionAnsweringPipeline
+from .fill_mask import FillMaskPipeline
 
 from ..models.auto.modeling_auto import (
     # AutoModel,
@@ -60,7 +61,7 @@ from ..models.auto.modeling_auto import (
     AutoModelForCausalLM,
     AutoModelForCTC,
     AutoModelForDocumentQuestionAnswering,
-    # AutoModelForMaskedLM,
+    AutoModelForMaskedLM,
     # AutoModelForMaskGeneration,
     # AutoModelForObjectDetection,
     AutoModelForQuestionAnswering,
@@ -162,7 +163,16 @@ SUPPORTED_TASKS = {
         },
         "type": "multimodal",
     },
-
+    "fill-mask": {
+        "impl": FillMaskPipeline,
+        "ms": (AutoModelForMaskedLM,),
+        "default": {
+            "model": {
+                "ms": ("distilbert/distilroberta-base", "ec58a5b"),
+            }
+        },
+        "type": "text",
+    },
 }
 
 NO_FEATURE_EXTRACTOR_TASKS = set()
@@ -597,6 +607,7 @@ def pipeline(
 
 __all__ = [
     'CsvPipelineDataFormat',
+    'FillMaskPipeline',
     'JsonPipelineDataFormat',
     'PipedPipelineDataFormat',
     'Pipeline',
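The demo above only exercises the default call path. As a quick illustration of the `targets` and `top_k` arguments wired up by the new pipeline below, here is a minimal sketch (it reuses the demo's `distilbert-base-uncased` checkpoint and assumes the lowercase target words are in that model's vocabulary):

```python
from mindnlp.transformers import pipeline

unmasker = pipeline('fill-mask', model='distilbert-base-uncased')

# `targets` restricts scoring to the given candidate tokens; the scores are the
# raw softmax probabilities of those tokens (not renormalized). `top_k` is
# capped at the number of deduplicated targets.
output = unmasker("The capital of France is [MASK].", targets=['paris', 'london'], top_k=2)
for candidate in output:
    print(candidate['token_str'], round(candidate['score'], 4))
```
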
diff --git a/mindnlp/transformers/pipelines/fill_mask.py b/mindnlp/transformers/pipelines/fill_mask.py
new file mode 100644
index 00000000..7a4c7034
--- /dev/null
+++ b/mindnlp/transformers/pipelines/fill_mask.py
@@ -0,0 +1,253 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Fill mask pipeline."""
+from typing import Dict
+
+import numpy as np
+
+from mindnlp.utils import is_mindspore_available, logging
+from .base import GenericTensor, Pipeline, PipelineException
+
+if is_mindspore_available():
+    import mindspore as ms
+    from mindspore import ops
+
+logger = logging.get_logger(__name__)
+
+
+class FillMaskPipeline(Pipeline):
+    """
+    Masked language modeling prediction pipeline using any `ModelWithLMHead`. See the [masked language modeling
+    examples](../task_summary#masked-language-modeling) for more information.
+
+    Example:
+
+    ```python
+    >>> from mindnlp.transformers import pipeline
+
+    >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
+    >>> fill_masker("This is a simple [MASK].")
+    [{'score': 0.042, 'token': 3291, 'token_str': 'problem', 'sequence': 'this is a simple problem.'}, {'score': 0.031, 'token': 3160, 'token_str': 'question', 'sequence': 'this is a simple question.'}, {'score': 0.03, 'token': 8522, 'token_str': 'equation', 'sequence': 'this is a simple equation.'}, {'score': 0.027, 'token': 2028, 'token_str': 'one', 'sequence': 'this is a simple one.'}, {'score': 0.024, 'token': 3627, 'token_str': 'rule', 'sequence': 'this is a simple rule.'}]
+    ```
+
+    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
+
+    This mask filling pipeline can currently be loaded from [`pipeline`] using the following task identifier:
+    `"fill-mask"`.
+
+    The models that this pipeline can use are models that have been trained with a masked language modeling objective,
+    which includes the bi-directional models in the library. See the up-to-date list of available models on
+    [huggingface.co/models](https://huggingface.co/models?filter=fill-mask).
+
+    <Tip>
+
+    This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+    masks. The returned values are raw model output, and correspond to disjoint probabilities where one might expect
+    joint probabilities (See [discussion](https://github.com/huggingface/transformers/pull/10222)).
+
+    </Tip>
+
+    <Tip>
+
+    This pipeline now supports `tokenizer_kwargs`. For example try:
+
+    ```python
+    >>> from mindnlp.transformers import pipeline
+
+    >>> fill_masker = pipeline(model="google-bert/bert-base-uncased")
+    >>> tokenizer_kwargs = {"truncation": True}
+    >>> fill_masker(
+    ...     "This is a simple [MASK]. " + "...with a large amount of repeated text appended. " * 100,
+    ...     tokenizer_kwargs=tokenizer_kwargs,
+    ... )
+    ```
+
+    </Tip>
+    """
+    def get_masked_index(self, input_ids: GenericTensor) -> "ms.Tensor":
+        # Indices of every mask token in the input (one row per occurrence).
+        masked_index = ops.nonzero(input_ids == self.tokenizer.mask_token_id)
+        return masked_index
+
+    def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor):
+        # Despite the name (kept for API compatibility), this only enforces that
+        # *at least* one mask token is present; multiple masks are supported
+        # experimentally.
+        masked_index = self.get_masked_index(input_ids)
+        numel = np.prod(masked_index.shape)
+        if numel < 1:
+            raise PipelineException(
+                "fill-mask",
+                self.model.base_model_prefix,
+                f"No mask_token ({self.tokenizer.mask_token}) found on the input",
+            )
+
+    def ensure_exactly_one_mask_token(self, model_inputs: GenericTensor):
+        if isinstance(model_inputs, list):
+            for model_input in model_inputs:
+                self._ensure_exactly_one_mask_token(model_input["input_ids"][0])
+        else:
+            for input_ids in model_inputs["input_ids"]:
+                self._ensure_exactly_one_mask_token(input_ids)
+
+    def preprocess(
+        self, inputs, return_tensors=None, tokenizer_kwargs=None, **preprocess_parameters
+    ) -> Dict[str, GenericTensor]:
+        if return_tensors is None:
+            return_tensors = "ms"
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+
+        model_inputs = self.tokenizer(inputs, return_tensors=return_tensors, **tokenizer_kwargs)
+        self.ensure_exactly_one_mask_token(model_inputs)
+        return model_inputs
+
+    def _forward(self, model_inputs):
+        model_outputs = self.model(**model_inputs)
+        model_outputs["input_ids"] = model_inputs["input_ids"]
+        return model_outputs
+
+    def postprocess(self, model_outputs, top_k=5, target_ids=None):
+        # Cap top_k if there are targets
+        if target_ids is not None and target_ids.shape[0] < top_k:
+            top_k = target_ids.shape[0]
+        input_ids = model_outputs["input_ids"][0]
+        outputs = model_outputs["logits"]
+
+        masked_index = ops.nonzero(input_ids == self.tokenizer.mask_token_id).squeeze(-1)
+        # Fill mask pipeline supports only one ${mask_token} per sample
+
+        logits = outputs[0, masked_index, :]
+        probs = ops.softmax(logits, axis=-1)
+        if target_ids is not None:
+            probs = probs[..., target_ids]
+
+        values, predictions = probs.topk(top_k)
+
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
+        return result
+    def get_target_ids(self, targets, top_k=None):
+        if isinstance(targets, str):
+            targets = [targets]
+        try:
+            vocab = self.tokenizer.get_vocab()
+        except Exception:
+            vocab = {}
+        target_ids = []
+        for target in targets:
+            id_ = vocab.get(target, None)
+            if id_ is None:
+                # The target is not a single vocabulary entry: tokenize it and
+                # fall back to its first sub-token.
+                input_ids = self.tokenizer(
+                    target,
+                    add_special_tokens=False,
+                    return_attention_mask=False,
+                    return_token_type_ids=False,
+                    max_length=1,
+                    truncation=True,
+                )["input_ids"]
+                if len(input_ids) == 0:
+                    logger.warning(
+                        f"The specified target token `{target}` does not exist in the model vocabulary. "
+                        "We cannot replace it with anything meaningful, ignoring it"
+                    )
+                    continue
+                id_ = input_ids[0]
+                # XXX: If users encounter this pass
+                # it becomes pretty slow, so let's make sure
+                # The warning enables them to fix the input to
+                # get faster performance.
+                logger.warning(
+                    f"The specified target token `{target}` does not exist in the model vocabulary. "
+                    f"Replacing with `{self.tokenizer.convert_ids_to_tokens(id_)}`."
+                )
+            target_ids.append(id_)
+        target_ids = list(set(target_ids))
+        if len(target_ids) == 0:
+            raise ValueError("At least one valid target must be provided when `targets` is passed.")
+        target_ids = np.array(target_ids)
+        return target_ids
+
+    def _sanitize_parameters(self, top_k=None, targets=None, tokenizer_kwargs=None):
+        preprocess_params = {}
+
+        if tokenizer_kwargs is not None:
+            preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
+
+        postprocess_params = {}
+
+        if targets is not None:
+            target_ids = self.get_target_ids(targets, top_k)
+            postprocess_params["target_ids"] = target_ids
+
+        if top_k is not None:
+            postprocess_params["top_k"] = top_k
+
+        if self.tokenizer.mask_token_id is None:
+            raise PipelineException(
+                "fill-mask", self.model.base_model_prefix, "The tokenizer does not define a `mask_token`."
+            )
+        return preprocess_params, {}, postprocess_params
+
+    def __call__(self, inputs, *args, **kwargs):
+        """
+        Fill the masked token in the text(s) given as inputs.
+
+        Args:
+            inputs (`str` or `List[str]`):
+                One or several texts (or one list of prompts) with masked tokens.
+            targets (`str` or `List[str]`, *optional*):
+                When passed, the model will limit the scores to the passed targets instead of looking up in the whole
+                vocab. If the provided targets are not in the model vocab, they will be tokenized and the first
+                resulting token will be used (with a warning, and that might be slower).
+            top_k (`int`, *optional*):
+                When passed, overrides the number of predictions to return.
+
+        Return:
+            A list or a list of lists of `dict`: Each result comes as a list of dictionaries with the following keys:
+
+            - **sequence** (`str`) -- The corresponding input with the mask token prediction.
+            - **score** (`float`) -- The corresponding probability.
+            - **token** (`int`) -- The predicted token id (to replace the masked one).
+            - **token_str** (`str`) -- The predicted token (to replace the masked one).
+        """
+        outputs = super().__call__(inputs, **kwargs)
+        if isinstance(inputs, list) and len(inputs) == 1:
+            return outputs[0]
+        return outputs
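Before the unit tests, here is a small end-to-end sketch of how the class above composes: `preprocess` tokenizes and verifies a mask is present, `_forward` runs the model, and `postprocess` ranks the candidates. It builds the pipeline directly from a model/tokenizer pair, the way the tests do. This assumes `AutoModelForMaskedLM` and `AutoTokenizer` are exported from `mindnlp.transformers`; the checkpoint name is just for illustration:

```python
from mindnlp.transformers import AutoModelForMaskedLM, AutoTokenizer
from mindnlp.transformers.pipelines import FillMaskPipeline

# Any masked-LM checkpoint whose tokenizer defines a mask token works here.
model = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)

# One mask -> a flat list of dicts (top_k entries).
print(fill_masker("Paris is the [MASK] of France.", top_k=3))

# Several masks -> one list of dicts per mask; note the scores are per-mask
# ("disjoint") probabilities, not joint ones.
print(fill_masker("The [MASK] of France is [MASK].", top_k=3))
```
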
diff --git a/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py b/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py
new file mode 100644
index 00000000..0ac90ca3
--- /dev/null
+++ b/tests/ut/transformers/pipelines/test_pipelines_fill_mask.py
@@ -0,0 +1,379 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+from mindnlp.transformers import MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
+from mindnlp.transformers.pipelines.base import PipelineException
+from mindnlp.utils.testing_utils import (
+    is_pipeline_test,
+    nested_simplify,
+    require_mindspore,
+    slow,
+)
+
+from .test_pipelines_common import ANY
+
+
+@is_pipeline_test
+class FillMaskPipelineTests(unittest.TestCase):
+    model_mapping = MODEL_FOR_MASKED_LM_MAPPING
+
+    def tearDown(self):
+        super().tearDown()
+        # Clean up as much GPU memory as possible.
+        gc.collect()
+
+    @require_mindspore
+    def test_small_model_ms(self):
+        unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", top_k=2)
+
+        outputs = unmasker("My name is <mask>")
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                {"sequence": "My name is Maul", "score": 2.2e-05, "token": 35676, "token_str": " Maul"},
+                {"sequence": "My name isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
+            ],
+        )
+
+        outputs = unmasker("The largest city in France is <mask>")
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                {
+                    "sequence": "The largest city in France is Maul",
+                    "score": 2.2e-05,
+                    "token": 35676,
+                    "token_str": " Maul",
+                },
+                {"sequence": "The largest city in France isELS", "score": 2.2e-05, "token": 16416, "token_str": "ELS"},
+            ],
+        )
+
+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
+    def test_fp16_casting(self):
+        pipe = pipeline(
+            "fill-mask",
+            model="hf-internal-testing/tiny-random-distilbert",
+        )
+
+        # Convert the model to fp16 (currently left disabled).
+        # pipe.model.half()
+
+        response = pipe("Paris is the [MASK] of France.")
+        # We actually don't care about the result, we just want to make sure
+        # it works, meaning the float16 tensor got cast back to float32
+        # for postprocessing.
+        self.assertIsInstance(response, list)
+
+    @slow
+    @require_mindspore
+    def test_large_model_ms(self):
+        unmasker = pipeline(task="fill-mask", model="distilbert/distilroberta-base", top_k=2)
+        self.run_large_test(unmasker)
+
+    def run_large_test(self, unmasker):
+        outputs = unmasker("My name is <mask>")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"sequence": "My name is John", "score": 0.008, "token": 610, "token_str": " John"},
+                {"sequence": "My name is Chris", "score": 0.007, "token": 1573, "token_str": " Chris"},
+            ],
+        )
+        outputs = unmasker("The largest city in France is <mask>")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {
+                    "sequence": "The largest city in France is Paris",
+                    "score": 0.251,
+                    "token": 2201,
+                    "token_str": " Paris",
+                },
+                {
+                    "sequence": "The largest city in France is Lyon",
+                    "score": 0.214,
+                    "token": 12790,
+                    "token_str": " Lyon",
+                },
+            ],
+        )
+
+        outputs = unmasker("My name is <mask>", targets=[" Patrick", " Clara", " Teven"], top_k=3)
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"sequence": "My name is Patrick", "score": 0.005, "token": 3499, "token_str": " Patrick"},
+                {"sequence": "My name is Clara", "score": 0.000, "token": 13606, "token_str": " Clara"},
+                {"sequence": "My name is Te", "score": 0.000, "token": 2941, "token_str": " Te"},
+            ],
+        )
+
+        dummy_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100
+        outputs = unmasker(
+            "My name is <mask>" + dummy_str,
+            tokenizer_kwargs={"truncation": True},
+        )
+        simplified = nested_simplify(outputs, decimals=4)
+        self.assertEqual(
+            [{"sequence": x["sequence"][:100]} for x in simplified],
+            [
+                {"sequence": f"My name is,{dummy_str}"[:100]},
+                {"sequence": f"My name is:,{dummy_str}"[:100]},
+            ],
+        )
+        self.assertEqual(
+            [{k: x[k] for k in x if k != "sequence"} for x in simplified],
+            [
+                {"score": 0.2819, "token": 6, "token_str": ","},
+                {"score": 0.0954, "token": 46686, "token_str": ":,"},
+            ],
+        )
+
+    @require_mindspore
+    def test_model_no_pad_ms(self):
+        unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base")
+        unmasker.tokenizer.pad_token_id = None
+        unmasker.tokenizer.pad_token = None
+        self.run_pipeline_test(unmasker, [])
+
+    def get_test_pipeline(self, model, tokenizer, processor):
+        if tokenizer is None or tokenizer.mask_token_id is None:
+            self.skipTest("The provided tokenizer has no mask token (probably reformer or wav2vec2)")
+
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+        examples = [
+            f"This is another {tokenizer.mask_token} test",
+        ]
+        return fill_masker, examples
+
+    def run_pipeline_test(self, fill_masker, examples):
+        tokenizer = fill_masker.tokenizer
+        model = fill_masker.model
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token}",
+        )
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+
+        outputs = fill_masker([f"This is a {tokenizer.mask_token}"])
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+        outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
+
+        with self.assertRaises(ValueError):
+            fill_masker([None])
+        # An input without the mask_token is not supported
+        with self.assertRaises(PipelineException):
+            fill_masker("This is")
+
+        self.run_test_top_k(model, tokenizer)
+        self.run_test_targets(model, tokenizer)
+        self.run_test_top_k_targets(model, tokenizer)
+        self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)
+
+    def run_test_targets(self, model, tokenizer):
+        vocab = tokenizer.get_vocab()
+        targets = sorted(vocab.keys())[:2]
+        # Pipeline argument
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, targets=targets)
+        outputs = fill_masker(f"This is a {tokenizer.mask_token}")
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+        target_ids = {vocab[el] for el in targets}
+        self.assertEqual({el["token"] for el in outputs}, target_ids)
+        processed_targets = [tokenizer.decode([x]) for x in target_ids]
+        self.assertEqual({el["token_str"] for el in outputs}, set(processed_targets))
+
+        # Call argument
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+        outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+        target_ids = {vocab[el] for el in targets}
+        self.assertEqual({el["token"] for el in outputs}, target_ids)
+        processed_targets = [tokenizer.decode([x]) for x in target_ids]
+        self.assertEqual({el["token_str"] for el in outputs}, set(processed_targets))
+
+        # Score equivalence
+        outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=targets)
+        tokens = [top_mask["token_str"] for top_mask in outputs]
+        scores = [top_mask["score"] for top_mask in outputs]
+
+        # For some BPE tokenizers, `</w>` is removed during decoding, so `token_str` won't be the same as in `targets`.
+        if set(tokens) == set(targets):
+            unmasked_targets = fill_masker(f"This is a {tokenizer.mask_token}", targets=tokens)
+            target_scores = [top_mask["score"] for top_mask in unmasked_targets]
+            self.assertEqual(nested_simplify(scores), nested_simplify(target_scores))
+
+        # Raises with invalid targets
+        with self.assertRaises(ValueError):
+            outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[])
+        # For some tokenizers, `""` is actually in the vocabulary and the expected error won't be raised
+        if "" not in tokenizer.get_vocab():
+            with self.assertRaises(ValueError):
+                outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets=[""])
+        with self.assertRaises(ValueError):
+            outputs = fill_masker(f"This is a {tokenizer.mask_token}", targets="")
+
+    def run_test_top_k(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer, top_k=2)
+        outputs = fill_masker(f"This is a {tokenizer.mask_token}")
+        self.assertEqual(
+            outputs,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+        outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2)
+        self.assertEqual(
+            outputs2,
+            [
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+            ],
+        )
+        self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
+
+    def run_test_top_k_targets(self, model, tokenizer):
+        vocab = tokenizer.get_vocab()
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        # top_k=2, ntargets=3
+        targets = sorted(vocab.keys())[:3]
+        outputs = fill_masker(f"This is a {tokenizer.mask_token}", top_k=2, targets=targets)
+
+        # If we use the most probable targets, and filter differently, we should still
+        # have the same results
+        targets2 = [el["token_str"] for el in sorted(outputs, key=lambda x: x["score"], reverse=True)]
+        # For some BPE tokenizers, `</w>` is removed during decoding, so `token_str` won't be the same as in `targets2`.
+        if set(targets2).issubset(targets):
+            outputs2 = fill_masker(f"This is a {tokenizer.mask_token}", top_k=3, targets=targets2)
+            # They should yield exactly the same result
+            self.assertEqual(nested_simplify(outputs), nested_simplify(outputs2))
+
+    def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+        vocab = tokenizer.get_vocab()
+        # String duplicates + id duplicates
+        targets = sorted(vocab.keys())[:3]
+        targets = [targets[0], targets[1], targets[0], targets[2], targets[1]]
+        outputs = fill_masker(f"My name is {tokenizer.mask_token}", targets=targets, top_k=10)
+
+        # The target list contains duplicates, so the pipeline can't return more
+        # predictions than there are unique targets
+        self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
-- 
Gitee
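A closing usage note for reviewers: the error handling and target deduplication pinned down by the last tests can be reproduced directly. A minimal sketch, again assuming the demo's `distilbert-base-uncased` checkpoint and that the lowercase target words are in its vocabulary:

```python
from mindnlp.transformers import pipeline
from mindnlp.transformers.pipelines.base import PipelineException

unmasker = pipeline('fill-mask', model='distilbert-base-uncased')

# An input without the mask token is rejected during preprocessing.
try:
    unmasker("This sentence has no mask token.")
except PipelineException as exc:
    print("raised as expected:", exc)

# Duplicate targets are deduplicated before top_k is applied, so at most
# three candidates come back here even though top_k=10.
output = unmasker("I love [MASK].", targets=['paris', 'london', 'paris', 'rome'], top_k=10)
print(len(output))  # 3
```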