diff --git a/README.md b/README.md index bf7cbc495d3b9a0938b08693cecd19fddd12db9a..c87eff422ba537185b16d5c10e9808a595ce3e71 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ pip uninstall mindpet | LowRankAdapter | Compacter: Efficient low-rank hypercom plex adapter layers | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) 第四章 | | BitFit | BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) 第五章 | | R_Drop | R-Drop: Regularized Dropout for Neural Networks | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) 第六章 | +| PromptTuning | The Power of Scale for Parameter-Efficient Prompt Tuning | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) 第七章 | diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md index d29f76963d334344939fc0c59345284576e4aaf8..8795db6908c1046e67d660707bcbf2673e16ebc1 100644 --- a/doc/TK_DeltaAlgorithm_README.md +++ b/doc/TK_DeltaAlgorithm_README.md @@ -1463,6 +1463,161 @@ class BertClsModel(BaseModel): +## 七、Prompt-Tuning算法 + +### 7.1 算法介绍 + +Prompt-Tuning算法通过为文本任务的输入添加前缀提示信息Prompt,无需修改上游预训练模型参数,即能实现下游任务的微调。前缀Prompt并非来自人工标注,而是由深度神经网络表示,对每一个下游任务,在微调时只需要冻结预训练权重,额外训练Prompt部分的参数即可。 + +算法原理如下图所示,算法具体细节可参考相关论文[The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691)。 + +

Prompt-Tuning算法原理图: 对于每个下游任务,给输入添加前缀信息,冻结预训练模型参数,只训练这些前缀。
+ + + +### 7.2 API接口 + +#### PromptTuning + +```python +class tk.delta.prompt_tuning.PromptTuning(num_virtual_tokens, + token_dim, + num_transformer_submodules) +``` + +**参数** + +- **num_virtual_tokens**(int) - 提示词缀标记的长度。 +- **token_dim**(int) - embedding后每个标记对应向量的维度,与原模型hidden_size一致。 +- **num_transformer_submodules**(int) - 原模型中transformer子模块的个数,默认为1。 + + +**输入** + +shape为 `(batch_size, num_virtual_tokens * num_transformer_submodules)` 的Tensor。参数中的 `batch_size` 应等于模型参数中的 `batch_size` 。 + +**输出** + +shape为 `(batch_size, num_virtual_tokens * num_transformer_submodules, hidden_size)` 的Tensor 。参数中的 `batch_size`、`hidden_size`应等于模型参数中的 `batch_size`、`hidden_size`。 + + + +**异常** + +- **TypeError** - `num_virtual_tokens`不是整数;**ValueError** - `num_virtual_tokens`不是正数。 +- **TypeError** - `token_dim`不是整数;**ValueError** - `token_dim`不是正数。 +- **TypeError** - `num_transformer_submodules`不是整数;**ValueError** - `num_transformer_submodules`不是正数。 + +### 7.3 使用样例 + +通过以下步骤将模型结构中经过Embedding层的输入与带有PromptTuning结构的模块进行拼接,并冻结网络进行训练: + +1)安装mindpet工具包。([安装方法参考《README.md》第二章](../README.md)) + +2)在模型的主体结构中,从工具包引入`PromptTuning`类,初始化PromptTuning模块和用于输入模块的prompt_ids,并在模型construct中,将prompt_output拼接至原模型的embedding_output之前,裁剪至原Tensor长度后替换原模型embedding_output。此外,根据不同模型的输入,对例如input_ids等Tensor进行类似拼接替换操作。PromptTuning相关参数可参考API接口自行指定。如果进行分布式训练,可调用`shard`方法指定分布式策略。 + +```python +from tk.delta import PromptTuning +# original ModelClass +def __init__(self, **kwargs): + # add prompttuning initialization + self.prompt_cell = PromptTuning(num_virtual_tokens=20, token_dim=4096, num_transformer_submodules=1) + self.prompt_ids = Tensor(list(range(0, num_virtual_tokens * num_transformer_submodules)), dtype=mstype.int32) + self.prompt_ids = self.expand_dims(self.prompt_ids, 0) + # if distributed training is required, invoke shard method + self.tile = P.Tile().shard(((1, 1),)) + self.concat = P.Concat(axis=1).shard(((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1))) + self.slice = P.StridedSlice().shard(((config.parallel_config.data_parallel, 1, 1),)) + +def construct(self, input_ids, 
embedding_output, **kwargs): + # get prompt embedding output + prompt_ids = self.tile(self.prompt_ids, (batch_size, 1)) + prompt_output = self.cast(self.prompt_cell(prompt_ids), mstype.float16) + # concat prompt_output and embedding_output + embedding_output = self.concat((prompt_output, embedding_output)) + embedding_output = self.slice(embedding_output, + (0, 0, 0), + (batch_size, seq_length, self.hidden_size), + (1, 1, 1)) +``` + +3)在训练脚本中,从工具包中引入`freeze_delta`方法,定义优化器之前调用`freeze_delta`冻结除`PromptTuning`模块外其它原模型权重。([冻结方法参考《TK_GraphOperation_README.md》第一章](TK_GraphOperation_README.md)) + +```Python +from tk.graph import freeze_delta + +# freeze all cells except LoRA and head +freeze_delta(model=network, mode='prompttuning') +``` + +然后从工具包中引入`TrainableParamsCheckPoint`类,将保存ckpt的类改为`TrainableParamsCheckPoint`,仅保存需要更新的参数,可节约存储空间。([详细方法参考《TK_GraphOperation_README.md》第二章](TK_GraphOperation_README.md)) + +由于微调后只保存了部分参数,推理时具体如何加载ckpt请参考[附录A](###A 分布式微调后模型评估方法)。 + +```python +from tk.graph import TrainableParamsCheckPoint + +# original callback +# ckpt_callback = ModelCheckpoint(...) + +# replace ModelCheckpoint with TrainableParamsCheckPoint +ckpt_callback = TrainableParamsCheckPoint(...) +``` + +### 7.4 实验效果 + +下面实验基于Mindspore/Mindformers开源仓中的[llama](https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/llama.md)复现。 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
模型下游任务模式训练参数微调参数占比静态内存+动态内存Em/F1
epoch优化器学习率seq_lengthnum_virtual_tokens
llama_7bSQuADbaseline2Adam3.00E-052048\100%21976MB+8744MB82.57/65.84
PromptTuning2Adam1.00E-022048200.0095%14184MB+5265MB84.35/67.88
+ + + + ## 附录 ### A 微调后模型评估方法 diff --git a/doc/image/prompt_tuning.png b/doc/image/prompt_tuning.png new file mode 100644 index 0000000000000000000000000000000000000000..316652df5ad558a4a62cb825fbce05a074ecb39f Binary files /dev/null and b/doc/image/prompt_tuning.png differ diff --git a/test/unit_test/delta/test_prompt_tuning.py b/test/unit_test/delta/test_prompt_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..e2763980cc94d26596f3c78f1a25715d0178af88 --- /dev/null +++ b/test/unit_test/delta/test_prompt_tuning.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright © Huawei Technologies Co., Ltd. 2010-2022. All rights reserved. + +import os +import logging +import unittest + +import mindspore +import pytest +from mindspore.common.tensor import Tensor + +from tk.delta.prompt_tuning import PromptTuning + +logging.getLogger().setLevel(logging.INFO) +mindspore.set_context(device_id=1) + +class TestPromptTuning(unittest.TestCase): + # _check_num + def test_check_num_with_zero_num_virtual_tokens(self): + logging.info('Start test_check_num_with_zero_num_virtual_tokens.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=0, token_dim=1, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_zero_num_virtual_tokens.') + + def test_check_num_with_float_num_virtual_tokens(self): + logging.info('Start test_check_num_with_float_num_virtual_tokens.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1.5, token_dim=1, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_float_num_virtual_tokens.') + + def test_check_num_with_negative_num_virtual_tokens(self): + logging.info('Start test_check_num_with_negative_num_virtual_tokens.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=-1, token_dim=1, num_transformer_submodules=1) + 
logging.error(ex.exception) + logging.info('Finish test_check_num_with_negative_num_virtual_tokens.') + + def test_check_num_with_str_num_virtual_tokens(self): + logging.info('Start test_check_num_with_str_num_virtual_tokens.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens='a', token_dim=1, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_str_num_virtual_tokens.') + + def test_check_num_with_bool_num_virtual_tokens(self): + logging.info('Start test_check_num_with_bool_num_virtual_tokens.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=True, token_dim=1, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_bool_num_virtual_tokens.') + + def test_check_num_with_zero_token_dim(self): + logging.info('Start test_check_num_with_zero_token_dim.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=0, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_zero_token_dim.') + + def test_check_num_with_float_token_dim(self): + logging.info('Start test_check_num_with_float_token_dim.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1.5, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_float_token_dim.') + + def test_check_num_with_negative_token_dim(self): + logging.info('Start test_check_num_with_negative_token_dim.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=-1, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_negative_token_dim.') + + def test_check_num_with_str_token_dim(self): + logging.info('Start test_check_num_with_str_token_dim.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, 
token_dim='a', num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_str_token_dim.') + + def test_check_num_with_bool_token_dim(self): + logging.info('Start test_check_num_with_bool_token_dim.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=False, num_transformer_submodules=1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_bool_token_dim.') + + def test_check_num_with_zero_num_transformer_submodules(self): + logging.info('Start test_check_num_with_zero_num_transformer_submodules.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=0) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_zero_num_transformer_submodules.') + + def test_check_num_with_float_num_transformer_submodules(self): + logging.info('Start test_check_num_with_float_num_transformer_submodules.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=1.5) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_float_num_transformer_submodules.') + + def test_check_num_with_negative_num_transformer_submodules(self): + logging.info('Start test_check_num_with_negative_num_transformer_submodules.') + with self.assertRaises(ValueError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=-1) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_negative_num_transformer_submodules.') + + def test_check_num_with_str_num_transformer_submodules(self): + logging.info('Start test_check_num_with_str_num_transformer_submodules.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules='a') + logging.error(ex.exception) + logging.info('Finish test_check_num_with_str_num_transformer_submodules.') + + def 
test_check_num_with_bool_num_transformer_submodules(self): + logging.info('Start test_check_num_with_bool_num_transformer_submodules.') + with self.assertRaises(TypeError) as ex: + PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=True) + logging.error(ex.exception) + logging.info('Finish test_check_num_with_bool_num_transformer_submodules.') + + def test_params_with_legal_tk_delta_prompttuning_embedding(self): + logging.info('Start test_params_with_legal_tk_delta_prompttuning_embedding') + prompttuning = PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=1) + target = Tensor([[1], [1]]).asnumpy() == prompttuning.tk_delta_prompttuning_embedding.asnumpy() + for result in target: + self.assertTrue(result) + logging.info("Finish test_params_with_legal_tk_delta_prompttuning_embedding") + + +if __name__ == '__main__': + pytest.main(["-s", os.path.abspath(__file__)]) diff --git a/tk/delta/__init__.py b/tk/delta/__init__.py index 6fcb9286f325c1c215c46f38a2136a852e26ad5a..ad821eef79f512bf5f891704d56c4639503bb660 100644 --- a/tk/delta/__init__.py +++ b/tk/delta/__init__.py @@ -7,6 +7,7 @@ from tk.delta.prefix_layer import PrefixLayer from tk.delta.low_rank_adapter import LowRankAdapterDense, LowRankAdapterLayer from tk.delta.adapter import AdapterDense, AdapterLayer from tk.delta.r_drop import RDropLoss, rdrop_repeat +from tk.delta.prompt_tuning import PromptTuning __all__ = ['LoRADense', 'PrefixLayer', 'LowRankAdapterDense', 'LowRankAdapterLayer', - 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat'] + 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat', 'PromptTuning'] diff --git a/tk/delta/prompt_tuning.py b/tk/delta/prompt_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..edd426be87550d605aed1d2a5356edc8b5151b23 --- /dev/null +++ b/tk/delta/prompt_tuning.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright © Huawei Technologies Co., Ltd. 2022-2023. 
All rights reserved. +import mindspore as ms +import mindspore.nn as nn + +from tk.utils.version_utils import is_version_ge + +if is_version_ge(ms.__version__, '2.0.0'): + import mindspore._checkparam as Validator + + INC_LEFT = Validator.INC_LEFT +else: + from mindspore._checkparam import Validator, Rel + + INC_LEFT = Rel.INC_LEFT + + +class PromptTuning(nn.Cell): + """Define a cell with PromptTuning structure. + + Attributes: + num_virtual_tokens (int): The number of virtual tokens to use. + token_dim (int): The hidden embedding dimension of the base model. + num_transformer_submodules (int): The number of transformer submodules in the base model. + """ + + def __init__(self, + num_virtual_tokens: int, + token_dim: int, + num_transformer_submodules: int = 1): + super().__init__() + self.num_virtual_tokens = Validator.check_positive_int(num_virtual_tokens, int, "num_virtual_tokens") + self.token_dim = Validator.check_positive_int(token_dim, int, "token_dim") + self.num_transformer_submodules = Validator.check_positive_int(num_transformer_submodules, int, + "num_transformer_submodules") + self.total_virtual_tokens = self.num_virtual_tokens * self.num_transformer_submodules + self.tk_delta_prompttuning_embedding = nn.Embedding(self.total_virtual_tokens, self.token_dim) + + def construct(self, indices): + prompt_embeddings = self.tk_delta_prompttuning_embedding(indices) + return prompt_embeddings diff --git a/tk/utils/constants.py b/tk/utils/constants.py index 8d0f90d9f153880e688638094b00afb66030d9e6..c7a9be0101cb8c32c2fe153382dcdfafb9ec9b40 100644 --- a/tk/utils/constants.py +++ b/tk/utils/constants.py @@ -104,5 +104,5 @@ EVAL_INFER_TASK_NAMES = [EVALUATE_TASK_NAME, INFER_TASK_NAME] TK_SDK_INTERFACE_NAMES = [FINETUNE_TASK_NAME, EVALUATE_TASK_NAME, INFER_TASK_NAME] # 微调算法清单 -DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit'] +DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit', 'prompttuning']