diff --git a/README.md b/README.md
index bf7cbc495d3b9a0938b08693cecd19fddd12db9a..c87eff422ba537185b16d5c10e9808a595ce3e71 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,7 @@ pip uninstall mindpet
| LowRankAdapter | Compacter: Efficient Low-Rank Hypercomplex Adapter Layers | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) Chapter 4 |
| BitFit | BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) Chapter 5 |
| R_Drop | R-Drop: Regularized Dropout for Neural Networks | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) Chapter 6 |
+| PromptTuning | The Power of Scale for Parameter-Efficient Prompt Tuning | [TK_DeltaAlgorithm_README](doc/TK_DeltaAlgorithm_README.md) Chapter 7 |
diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md
index d29f76963d334344939fc0c59345284576e4aaf8..8795db6908c1046e67d660707bcbf2673e16ebc1 100644
--- a/doc/TK_DeltaAlgorithm_README.md
+++ b/doc/TK_DeltaAlgorithm_README.md
@@ -1463,6 +1463,161 @@ class BertClsModel(BaseModel):
+## 7. Prompt-Tuning
+
+### 7.1 Algorithm Overview
+
+The Prompt-Tuning algorithm prepends a learnable prefix prompt to the input of a text task; the parameters of the upstream pretrained model are left untouched, yet the model can still be fine-tuned for downstream tasks. The prefix prompt is not hand-crafted: it is represented by a deep neural network, and for each downstream task only the prompt parameters are trained while the pretrained weights stay frozen.
+
+The principle is illustrated in the figure below; for details, see the paper [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691).
+
+![prompt_tuning](image/prompt_tuning.png)
+
+Prompt-Tuning schematic: for each downstream task, prefix information is prepended to the input; the pretrained model's parameters are frozen and only these prefixes are trained.
+
+
+
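The mechanism described above can be sketched in a few lines of NumPy. This is a conceptual illustration only: the toy sizes, the random tables, and the variable names are made up for the example; the actual module (Section 7.2) uses a MindSpore `nn.Embedding`.

```python
import numpy as np

# Conceptual sketch: a frozen word-embedding table plus a small
# trainable prompt-embedding table (toy sizes, illustrative only).
rng = np.random.default_rng(0)

vocab_size, hidden_size = 100, 8
num_virtual_tokens = 4
seq_length = 6

word_embedding = rng.normal(size=(vocab_size, hidden_size))            # frozen
prompt_embedding = rng.normal(size=(num_virtual_tokens, hidden_size))  # trainable

input_ids = np.array([1, 5, 9, 2, 7, 3])
token_embeds = word_embedding[input_ids]        # (seq_length, hidden_size)

# Prepend the prompt; during tuning, gradients flow only into prompt_embedding.
full_embeds = np.concatenate([prompt_embedding, token_embeds], axis=0)
print(full_embeds.shape)  # (10, 8) = (num_virtual_tokens + seq_length, hidden_size)
```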
+### 7.2 API
+
+#### PromptTuning
+
+```python
+class tk.delta.prompt_tuning.PromptTuning(num_virtual_tokens,
+ token_dim,
+ num_transformer_submodules)
+```
+
+**Parameters**
+
+- **num_virtual_tokens** (int) - Number of virtual prompt tokens, i.e. the length of the prompt prefix.
+- **token_dim** (int) - Dimension of each token's embedding vector; must match the base model's hidden_size.
+- **num_transformer_submodules** (int) - Number of transformer submodules in the base model. Default: 1.
+
+
+**Inputs**
+
+A Tensor of shape `(batch_size, num_virtual_tokens * num_transformer_submodules)`, where `batch_size` must equal the model's `batch_size`.
+
+**Outputs**
+
+A Tensor of shape `(batch_size, num_virtual_tokens * num_transformer_submodules, hidden_size)`, where `batch_size` and `hidden_size` must equal the model's `batch_size` and `hidden_size`.
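For instance, taking `num_virtual_tokens=20` and `num_transformer_submodules=1` (the values used later in this chapter) together with a hypothetical `batch_size=4` and `hidden_size=4096`, the shapes work out as follows; this is plain shape arithmetic and needs no framework:

```python
num_virtual_tokens = 20
num_transformer_submodules = 1
batch_size = 4          # hypothetical, set by the caller
hidden_size = 4096      # must equal token_dim

# Input indices fed to the PromptTuning cell, and its embedded output.
input_shape = (batch_size, num_virtual_tokens * num_transformer_submodules)
output_shape = (batch_size, num_virtual_tokens * num_transformer_submodules, hidden_size)
print(input_shape)   # (4, 20)
print(output_shape)  # (4, 20, 4096)
```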
+
+
+
+**Exceptions**
+
+- **TypeError** - `num_virtual_tokens`, `token_dim`, or `num_transformer_submodules` is not an integer (bool is also rejected).
+- **ValueError** - `num_virtual_tokens`, `token_dim`, or `num_transformer_submodules` is zero or negative.
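The intent behind these exceptions can be sketched in plain Python. The real implementation delegates to MindSpore's `Validator.check_positive_int`; the stand-alone version below merely mirrors the behavior the unit tests exercise:

```python
def check_positive_int(value, name):
    """Return value if it is a positive int; raise otherwise."""
    # bool is a subclass of int in Python, so reject it explicitly
    if not isinstance(value, int) or isinstance(value, bool):
        raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    if value <= 0:
        raise ValueError(f"{name} must be positive, got {value}")
    return value

check_positive_int(20, "num_virtual_tokens")       # ok, returns 20
try:
    check_positive_int(0, "num_virtual_tokens")
except ValueError as e:
    print(e)  # num_virtual_tokens must be positive, got 0
try:
    check_positive_int(True, "num_virtual_tokens")
except TypeError as e:
    print(e)  # num_virtual_tokens must be an int, got bool
```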
+
+### 7.3 Usage Example
+
+Follow the steps below to concatenate the model's embedded input with the output of a PromptTuning module, then freeze the network for training:
+
+1) Install the mindpet toolkit. ([See Chapter 2 of README.md for installation](../README.md))
+
+2) In the model's main structure, import the `PromptTuning` class from the toolkit and initialize the PromptTuning module together with the `prompt_ids` fed to it. In the model's `construct`, prepend `prompt_output` to the original `embedding_output`, trim the result back to the original tensor length, and use it in place of the original `embedding_output`. Depending on the model's inputs, apply the same concatenate-and-trim treatment to tensors such as `input_ids`. Choose the PromptTuning parameters according to the API reference; for distributed training, call the `shard` method to set the parallel strategy.
+
+```python
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+from tk.delta import PromptTuning
+
+# inside the original ModelClass; batch_size, seq_length and config
+# come from the surrounding model definition
+def __init__(self, **kwargs):
+    # add PromptTuning initialization
+    self.prompt_cell = PromptTuning(num_virtual_tokens=20, token_dim=4096, num_transformer_submodules=1)
+    self.prompt_ids = Tensor(list(range(0, self.prompt_cell.total_virtual_tokens)), dtype=mstype.int32)
+    self.expand_dims = P.ExpandDims()
+    self.prompt_ids = self.expand_dims(self.prompt_ids, 0)
+    self.cast = P.Cast()
+    # if distributed training is required, invoke the shard method
+    self.tile = P.Tile().shard(((1, 1),))
+    self.concat = P.Concat(axis=1).shard(((config.parallel_config.data_parallel, 1, 1),
+                                          (config.parallel_config.data_parallel, 1, 1)))
+    self.slice = P.StridedSlice().shard(((config.parallel_config.data_parallel, 1, 1),))
+
+def construct(self, input_ids, embedding_output, **kwargs):
+    # get the prompt embedding output
+    prompt_ids = self.tile(self.prompt_ids, (batch_size, 1))
+    prompt_output = self.cast(self.prompt_cell(prompt_ids), mstype.float16)
+    # prepend prompt_output to embedding_output, then trim back to seq_length
+    embedding_output = self.concat((prompt_output, embedding_output))
+    embedding_output = self.slice(embedding_output,
+                                  (0, 0, 0),
+                                  (batch_size, seq_length, self.hidden_size),
+                                  (1, 1, 1))
+```
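The concat-and-slice in `construct` can be verified with a small NumPy sketch (toy shapes, purely illustrative): the prompt occupies the first `num_virtual_tokens` positions, and the trailing `num_virtual_tokens` input embeddings are trimmed off so the overall sequence length is unchanged.

```python
import numpy as np

batch_size, seq_length, hidden_size = 2, 8, 4
num_virtual_tokens = 3

# Mark input embeddings with 1s and prompt embeddings with 0s
# so we can see where each ends up after the concat-and-slice.
embedding_output = np.ones((batch_size, seq_length, hidden_size))
prompt_output = np.zeros((batch_size, num_virtual_tokens, hidden_size))

# concat the prompt in front, then slice back to the original sequence length
merged = np.concatenate([prompt_output, embedding_output], axis=1)
trimmed = merged[:, :seq_length, :]

print(trimmed.shape)                       # (2, 8, 4)
print(trimmed[0, :num_virtual_tokens, 0])  # [0. 0. 0.]  (prompt rows)
```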
+
+3) In the training script, import the `freeze_delta` method from the toolkit and call it before defining the optimizer to freeze all original model weights except the `PromptTuning` module. ([See Chapter 1 of TK_GraphOperation_README.md for freezing](TK_GraphOperation_README.md))
+
+```python
+from tk.graph import freeze_delta
+
+# freeze all cells except PromptTuning and head
+freeze_delta(model=network, mode='prompttuning')
+```
+
+Then import the `TrainableParamsCheckPoint` class from the toolkit and use it in place of the regular checkpoint class, so that only the parameters being updated are saved, which reduces storage. ([See Chapter 2 of TK_GraphOperation_README.md for details](TK_GraphOperation_README.md))
+
+Since only part of the parameters are saved after fine-tuning, refer to Appendix A for how to load the checkpoint at inference time.
+
+```python
+from tk.graph import TrainableParamsCheckPoint
+
+# original callback
+# ckpt_callback = ModelCheckpoint(...)
+
+# replace ModelCheckpoint with TrainableParamsCheckPoint
+ckpt_callback = TrainableParamsCheckPoint(...)
+```
+
+### 7.4 Experimental Results
+
+The experiments below reproduce [llama](https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/llama.md) from the open-source MindSpore/MindFormers repository.
+
+| Model | Downstream task | Mode | epoch | Optimizer | Learning rate | seq_length | num_virtual_tokens | Trainable param ratio | Static + dynamic memory | EM/F1 |
+| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| llama_7b | SQuAD | baseline | 2 | Adam | 3.00E-05 | 2048 | \ | 100% | 21976MB+8744MB | 82.57/65.84 |
+| llama_7b | SQuAD | PromptTuning | 2 | Adam | 1.00E-02 | 2048 | 20 | 0.0095% | 14184MB+5265MB | 84.35/67.88 |
+
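For reference, the PromptTuning run above trains only the prompt embedding table, whose size follows directly from its hyperparameters (assuming `token_dim=4096`, the hidden size used for llama_7b in the usage example):

```python
num_virtual_tokens = 20
token_dim = 4096               # hidden_size of llama_7b
num_transformer_submodules = 1

# size of the only trainable tensor: the prompt embedding table
total_virtual_tokens = num_virtual_tokens * num_transformer_submodules
trainable_params = total_virtual_tokens * token_dim
print(trainable_params)  # 81920
```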
+
## Appendix
### A Evaluating the model after fine-tuning
diff --git a/doc/image/prompt_tuning.png b/doc/image/prompt_tuning.png
new file mode 100644
index 0000000000000000000000000000000000000000..316652df5ad558a4a62cb825fbce05a074ecb39f
Binary files /dev/null and b/doc/image/prompt_tuning.png differ
diff --git a/test/unit_test/delta/test_prompt_tuning.py b/test/unit_test/delta/test_prompt_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2763980cc94d26596f3c78f1a25715d0178af88
--- /dev/null
+++ b/test/unit_test/delta/test_prompt_tuning.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright © Huawei Technologies Co., Ltd. 2010-2022. All rights reserved.
+
+import os
+import logging
+import unittest
+
+import mindspore
+import pytest
+from mindspore.common.tensor import Tensor
+
+from tk.delta.prompt_tuning import PromptTuning
+
+logging.getLogger().setLevel(logging.INFO)
+mindspore.set_context(device_id=1)
+
+class TestPromptTuning(unittest.TestCase):
+ # _check_num
+ def test_check_num_with_zero_num_virtual_tokens(self):
+ logging.info('Start test_check_num_with_zero_num_virtual_tokens.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=0, token_dim=1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_zero_num_virtual_tokens.')
+
+ def test_check_num_with_float_num_virtual_tokens(self):
+ logging.info('Start test_check_num_with_float_num_virtual_tokens.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1.5, token_dim=1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_float_num_virtual_tokens.')
+
+ def test_check_num_with_negative_num_virtual_tokens(self):
+ logging.info('Start test_check_num_with_negative_num_virtual_tokens.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=-1, token_dim=1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_negative_num_virtual_tokens.')
+
+ def test_check_num_with_str_num_virtual_tokens(self):
+ logging.info('Start test_check_num_with_str_num_virtual_tokens.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens='a', token_dim=1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_str_num_virtual_tokens.')
+
+ def test_check_num_with_bool_num_virtual_tokens(self):
+ logging.info('Start test_check_num_with_bool_num_virtual_tokens.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=True, token_dim=1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_bool_num_virtual_tokens.')
+
+ def test_check_num_with_zero_token_dim(self):
+ logging.info('Start test_check_num_with_zero_token_dim.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=0, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_zero_token_dim.')
+
+ def test_check_num_with_float_token_dim(self):
+ logging.info('Start test_check_num_with_float_token_dim.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1.5, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_float_token_dim.')
+
+ def test_check_num_with_negative_token_dim(self):
+ logging.info('Start test_check_num_with_negative_token_dim.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=-1, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_negative_token_dim.')
+
+ def test_check_num_with_str_token_dim(self):
+ logging.info('Start test_check_num_with_str_token_dim.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim='a', num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_str_token_dim.')
+
+ def test_check_num_with_bool_token_dim(self):
+ logging.info('Start test_check_num_with_bool_token_dim.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=False, num_transformer_submodules=1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_bool_token_dim.')
+
+ def test_check_num_with_zero_num_transformer_submodules(self):
+ logging.info('Start test_check_num_with_zero_num_transformer_submodules.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=0)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_zero_num_transformer_submodules.')
+
+ def test_check_num_with_float_num_transformer_submodules(self):
+ logging.info('Start test_check_num_with_float_num_transformer_submodules.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=1.5)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_float_num_transformer_submodules.')
+
+ def test_check_num_with_negative_num_transformer_submodules(self):
+ logging.info('Start test_check_num_with_negative_num_transformer_submodules.')
+ with self.assertRaises(ValueError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=-1)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_negative_num_transformer_submodules.')
+
+ def test_check_num_with_str_num_transformer_submodules(self):
+ logging.info('Start test_check_num_with_str_num_transformer_submodules.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules='a')
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_str_num_transformer_submodules.')
+
+ def test_check_num_with_bool_num_transformer_submodules(self):
+ logging.info('Start test_check_num_with_bool_num_transformer_submodules.')
+ with self.assertRaises(TypeError) as ex:
+ PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=True)
+ logging.error(ex.exception)
+ logging.info('Finish test_check_num_with_bool_num_transformer_submodules.')
+
+ def test_params_with_legal_tk_delta_prompttuning_embedding(self):
+ logging.info('Start test_params_with_legal_tk_delta_prompttuning_embedding')
+ prompttuning = PromptTuning(num_virtual_tokens=1, token_dim=1, num_transformer_submodules=1)
+        embedding_table = prompttuning.tk_delta_prompttuning_embedding.embedding_table.asnumpy()
+        self.assertEqual(embedding_table.shape, (prompttuning.total_virtual_tokens, prompttuning.token_dim))
+ logging.info("Finish test_params_with_legal_tk_delta_prompttuning_embedding")
+
+
+if __name__ == '__main__':
+ pytest.main(["-s", os.path.abspath(__file__)])
diff --git a/tk/delta/__init__.py b/tk/delta/__init__.py
index 6fcb9286f325c1c215c46f38a2136a852e26ad5a..ad821eef79f512bf5f891704d56c4639503bb660 100644
--- a/tk/delta/__init__.py
+++ b/tk/delta/__init__.py
@@ -7,6 +7,7 @@ from tk.delta.prefix_layer import PrefixLayer
from tk.delta.low_rank_adapter import LowRankAdapterDense, LowRankAdapterLayer
from tk.delta.adapter import AdapterDense, AdapterLayer
from tk.delta.r_drop import RDropLoss, rdrop_repeat
+from tk.delta.prompt_tuning import PromptTuning
__all__ = ['LoRADense', 'PrefixLayer', 'LowRankAdapterDense', 'LowRankAdapterLayer',
- 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat']
+ 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat', 'PromptTuning']
diff --git a/tk/delta/prompt_tuning.py b/tk/delta/prompt_tuning.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd426be87550d605aed1d2a5356edc8b5151b23
--- /dev/null
+++ b/tk/delta/prompt_tuning.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright © Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+import mindspore as ms
+import mindspore.nn as nn
+
+from tk.utils.version_utils import is_version_ge
+
+if is_version_ge(ms.__version__, '2.0.0'):
+ import mindspore._checkparam as Validator
+
+ INC_LEFT = Validator.INC_LEFT
+else:
+ from mindspore._checkparam import Validator, Rel
+
+ INC_LEFT = Rel.INC_LEFT
+
+
+class PromptTuning(nn.Cell):
+ """Define a cell with PromptTuning structure.
+
+ Attributes:
+ num_virtual_tokens (int): The number of virtual tokens to use.
+ token_dim (int): The hidden embedding dimension of the base model.
+ num_transformer_submodules (int): The number of transformer submodules in the base model.
+ """
+
+ def __init__(self,
+ num_virtual_tokens: int,
+ token_dim: int,
+ num_transformer_submodules: int = 1):
+ super().__init__()
+        self.num_virtual_tokens = Validator.check_positive_int(num_virtual_tokens, "num_virtual_tokens")
+        self.token_dim = Validator.check_positive_int(token_dim, "token_dim")
+        self.num_transformer_submodules = Validator.check_positive_int(num_transformer_submodules,
+                                                                       "num_transformer_submodules")
+ self.total_virtual_tokens = self.num_virtual_tokens * self.num_transformer_submodules
+ self.tk_delta_prompttuning_embedding = nn.Embedding(self.total_virtual_tokens, self.token_dim)
+
+ def construct(self, indices):
+ prompt_embeddings = self.tk_delta_prompttuning_embedding(indices)
+ return prompt_embeddings
diff --git a/tk/utils/constants.py b/tk/utils/constants.py
index 8d0f90d9f153880e688638094b00afb66030d9e6..c7a9be0101cb8c32c2fe153382dcdfafb9ec9b40 100644
--- a/tk/utils/constants.py
+++ b/tk/utils/constants.py
@@ -104,5 +104,5 @@ EVAL_INFER_TASK_NAMES = [EVALUATE_TASK_NAME, INFER_TASK_NAME]
TK_SDK_INTERFACE_NAMES = [FINETUNE_TASK_NAME, EVALUATE_TASK_NAME, INFER_TASK_NAME]
# List of supported fine-tuning (delta) algorithms
-DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit']
+DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit', 'prompttuning']