diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md
index d29f76963d334344939fc0c59345284576e4aaf8..c7aff9c850209294920fe588634c35291c3c9cc6 100644
--- a/doc/TK_DeltaAlgorithm_README.md
+++ b/doc/TK_DeltaAlgorithm_README.md
@@ -1462,6 +1462,199 @@ class BertClsModel(BaseModel):
+## 七、P-Tuning v2算法
+
+### 7.1 算法介绍
+
+P-Tuning v2该方法将可训练的连续提示向量独立添加到每个transformer层的输入中,只训练这部分任务相关的向量,保持预训练模型的参数不变。P-Tuning v2会在每个transformer层的key和value向量的前面插入l个用于更新参数的连续提示向量,然后冻结预训练模型的参数, 只更新这些向量的参数,就可以达到近似全参微调的效果。
+
+
+算法原理如下图所示,算法具体实现细节可参考论文[P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/abs/2110.07602)
+
+
![]()
+P-Tuning v2算法原理图: 对于每个下游任务,在网络的每一层添加一份连续提示向量,冻结预训练模型的其他参数,只训练这些向量。
+
+
+
+
+### 7.2 API接口
+
+``` python
+ class PrefixEncoder(pre_seq_len,
+ num_layers,
+ num_heads,
+ kv_channels,
+ prefix_projection,
+ projection_dim,
+ dropout_prob):
+```
+
+定义PrefixEncoder层
+
+
+
+**参数**
+
+- **pre_seq_len**(int) - 网络每层提示向量的长度.
+- **num_layers**(int) - 网络层数,与原模型参数一致。
+- **num_heads**(int) - 多头注意力头数,与原模型参数一致。
+- **kv_channels**(int) - `key`、`value`隐藏维度,与原模型参数一致。
+- **prefix_projection**(bool) - 是否使用MLP表征。
+- **projection_dim**(int) - MLP维度。
+- **dropout_prob**(float) - 丢弃率。
+
+
+
+**异常**
+
+- **TypeError** - `pre_seq_len`不是正整数。
+- **TypeError** - `num_layers`不是正整数。
+- **TypeError** - `num_heads`不是正整数。
+- **TypeError** - `kv_channels`不是正整数。
+- **TypeError** - `projection_dim`不是正整数。
+- **ValueError** - `dropout_prob`不在[0,1)之内。
+
+
+
+### 7.3 使用样例
+
+通过以下步骤将模型结构中`key`、`value`和`attention_mask`修改为新的`key`、`value`和`attention_mask`,冻结网络进行训练:
+
+1)安装mindpet工具包。([安装方法参考《README.md》第二章](../README.md))
+
+2)在模型的初始化时,从工具包中引入`PrefixEncoder`类,创建`prefixEncoder`,在`construct`时构造提示向量传递给网络的每层。
+
+```python
+class ChatModelWithPt2(ChatModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.prefix_encoder = PrefixEncoder(
+ config.pet_config.pre_seq_len,
+ config.pet_config.num_layers,
+ config.pet_config.num_heads,
+ config.pet_config.kv_channels,
+ config.pet_config.prefix_projection,
+ config.pet_config.projection_dim,
+ config.pet_config.dropout_prob
+ )
+ ...
+
+ def construct(self, ...):
+ prefix_key_values = self.prefix_encoder(batch_size)
+ return super().construct(..., prefix_key_values)
+```
+
+3)在模型的Attention结构中,将`prefixlayer`构造的每层`prefix_key_value`矩阵与原`key`、`value`矩阵进行`concat`操作。然后定义全为1的`help`矩阵,将原`attention_mask`矩阵与`help`矩阵进行`concat`(新的`attention_mask`矩阵shape与新的`query`*`key`矩阵的shape相同)。
+
+```python
+#模型的Attention层
+class SelfAttention(nn.Cell):
+ def add_prefix(prefix_key_value, pre_seq_len, key, value, attention_mask):
+ # [bs, num_heads, seq_length, head_dim]
+ seq_len = key.shape[2]
+
+ # [bs, num_heads, pre_seq_len, head_dim]
+ prefix_key = prefix_key_value[0]
+ prefix_value = prefix_key_value[1]
+ cat = P.Concat(2)
+ key = cat([prefix_key, key])
+ value = cat([prefix_value, value])
+
+ batch_size = attention_mask.shape[0]
+ prefix_mask = attention_mask.new_ones((batch_size, 1, seq_len, pre_seq_len))
+ m_cat = P.Concat(3)
+
+ # [bs, 1, seq_len, pre_seq_len + seq_len]
+ attention_mask = m_cat((prefix_mask, attention_mask))
+
+ return key, value, attention_mask
+
+ def construct(self, input_tensor, attention_mask):
+ ...
+ ...
+ key_layer, value_layer, attention_mask = self.add_prefix(
+ prefix_key_value,
+ self.pre_seq_len,
+ key_layer,
+ value_layer,
+ attention_mask
+ )
+ context_layer = self.attention(query_layer, key_layer, value_layer, attention_mask)
+ ...
+```
+
+4)在训练脚本中,从工具包中引入`freeze_delta`方法,定义优化器之前调用`freeze_delta`冻结除`Prefix`矩阵外其它原模型权重。注意,为了适配下游任务引入的额外模型结构无需冻结,可以用`exclude`参数指定无需冻结的结构名称。([冻结方法参考《TK_GraphOperation_README.md》第一章](TK_GraphOperation_README.md))
+
+```Python
+# freeze all cell except ptuning2
+freeze_delta(model=network, mode='ptuning2')
+```
+
+然后从工具包中引入`TrainableParamsCheckPoint`类,将保存ckpt的类改为`TrainableParamsCheckPoint`,仅保存需要更新的参数,可节约存储空间。([详细方法参考《TK_GraphOperation_README.md》第二章](TK_GraphOperation_README.md))
+
+由于微调后只保存了部分参数,推理时具体如何加载ckpt请参考[附录A](###A 分布式微调后模型评估方法)。
+
+```python
+# original callback
+# ckpt_callback = ModelCheckpoint(...)
+
+# replace ModelCheckpoint with TrainableParamsCheckPoint
+ckpt_callback = TrainableParamsCheckPoint(...)
+```
+
+
+
+### 7.4 实验效果
+
+下面实验基于MindFormers开源仓中的[**GLM2-6B**](https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/glm2.md)复现。
+
+
+
+
+ 模型 |
+ 下游任务 |
+ 模式 |
+ 训练参数 |
+ 微调参数占比 |
+ 静态内存+动态内存 |
+ rouge-1 |
+
+
+ epoch |
+ 优化器 |
+ 学习率 |
+ pre_seq_num |
+
+
+
+
+ glm2-6m |
+ language modeling |
+ baseline |
+ 1 |
+ Adamw |
+ 1.00E-04 |
+ \ |
+ 100% |
+ 60056MB+92141MB |
+ 30.7 |
+
+
+ p-tuning v2 |
+ 1 |
+ Adamw |
+ 5.00E-03 |
+ 128 |
+ 0.03% |
+ 12992MB+35588MB |
+ 31.5 |
+
+
+
+
+
+
## 附录
diff --git a/doc/image/ptuning2.png b/doc/image/ptuning2.png
new file mode 100644
index 0000000000000000000000000000000000000000..94c9b41a6a045246a51455b0240c4e3329cd09a0
Binary files /dev/null and b/doc/image/ptuning2.png differ
diff --git a/mindpet/delta/__init__.py b/mindpet/delta/__init__.py
index 605be9cd98f1243fee9368a278739f5f95e92cab..7219c79e4b0875863b582faa5be352e7351625ef 100644
--- a/mindpet/delta/__init__.py
+++ b/mindpet/delta/__init__.py
@@ -6,7 +6,8 @@ from mindpet.delta.lora import LoRADense
from mindpet.delta.prefix_layer import PrefixLayer
from mindpet.delta.low_rank_adapter import LowRankAdapterDense, LowRankAdapterLayer
from mindpet.delta.adapter import AdapterDense, AdapterLayer
+from mindpet.delta.ptuning2 import PrefixEncoder
from mindpet.delta.r_drop import RDropLoss, rdrop_repeat
__all__ = ['LoRADense', 'PrefixLayer', 'LowRankAdapterDense', 'LowRankAdapterLayer',
- 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat']
+ 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat', 'PrefixEncoder']
diff --git a/mindpet/delta/ptuning2.py b/mindpet/delta/ptuning2.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcaa1a2d8ef573b14d9e4b39fe6eac5a64f71925
--- /dev/null
+++ b/mindpet/delta/ptuning2.py
@@ -0,0 +1,114 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+p-tuning-v2
+Reference: https://arxiv.org/pdf/2110.07602.pdf
+"""
+
+import mindspore as ms
+import mindspore.nn as nn
+import numpy as np
+
+from mindspore import dtype as mstype
+from mindspore.ops import operations as P
+
+from mindpet.utils.version_control import get_dropout
+
+try:
+ from mindspore._checkparam import Validator, Rel
+
+ INC_LEFT = Rel.INC_LEFT
+except:
+ import mindspore._checkparam as Validator
+
+ INC_LEFT = Validator.INC_LEFT
+
+
+class PrefixEncoder(nn.Cell):
+ """
+ The cell to encode the prefix
+ Input : batch_size
+ Output shape: layers * (2, bs, num_heads, pre_len, kv_channels)
+ """
+
+ def __init__(self, pre_seq_len, num_layers, num_heads, kv_channels, prefix_projection,
+ projection_dim, dropout_prob):
+ """
+ Args:
+ pre_seq_len: prefix的序列长度
+ num_layers: 原模型transformer层数
+ num_heads: 原模型transformer多头注意力头数
+ kv_channels: 原模型transformer kv维度
+ prefix_projection 是否使用MLP表征
+ projection_dim: MLP维度
+ dropout_prob: 丢弃率
+ """
+ super().__init__()
+ self.pre_seq_len = Validator.check_positive_int(pre_seq_len, "pre_seq_len")
+ self.num_layers = Validator.check_positive_int(num_layers, "num_layers")
+ self.num_heads = Validator.check_positive_int(num_heads, "num_heads")
+ self.kv_channels = Validator.check_positive_int(kv_channels, "kv_channels")
+
+ dropout_prob = Validator.check_float_range(dropout_prob, 0.0, 1.0, INC_LEFT)
+ self.dropout = get_dropout(dropout_prob)
+
+ self.prefix_projection = prefix_projection
+
+ self.tk_delta_ptuning2_prefix = ms.Parameter(np.arange(self.pre_seq_len),
+ requires_grad=False)
+
+ out_embed_dim = self.num_layers * self.kv_channels * self.num_heads * 2
+ self.tk_delta_ptuning2_embedding = nn.Embedding(self.pre_seq_len, out_embed_dim)
+
+ if self.prefix_projection:
+ self.projection_dim = Validator.check_positive_int(projection_dim, "projection_dim")
+ # two-layer MLP to encode the prefix
+ self.tk_delta_ptuning2_trans = nn.SequentialCell(
+ nn.Dense(out_embed_dim, self.projection_dim),
+ nn.Tanh(),
+ nn.Dense(self.projection_dim, out_embed_dim)
+ )
+
+ self.expand_dims = P.ExpandDims()
+ self.tile = P.Tile()
+ self.cast = P.Cast()
+
+ def construct(self, batch_size, dtype=mstype.half):
+ prefix_tokens = self.expand_dims(self.tk_delta_ptuning2_prefix, 0)
+ prefix_tokens = self.tile(prefix_tokens, (batch_size, 1))
+
+ # (bs, pre_len) -> (bs, pre_len, 2 * layers * num_heads * kv_channels)
+ past_key_values = self.tk_delta_ptuning2_embedding(prefix_tokens)
+
+ if self.prefix_projection:
+ past_key_values = self.tk_delta_ptuning2_trans(past_key_values)
+
+ past_key_values = self.cast(past_key_values, dtype)
+
+ # (bs, pre_len, 2 * layers * num_heads * kv_channels) -> (bs, pre_len, 2 * layers, num_heads, kv_channels)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.num_layers * 2,
+ self.num_heads,
+ self.kv_channels
+ )
+
+ past_key_values = self.dropout(past_key_values)
+
+ # (bs, pre_len, 2 * layers, num_heads, kv_channels) -> layers * (2, bs, num_heads, pre_len, kv_channels)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+
+ return past_key_values
diff --git a/mindpet/utils/constants.py b/mindpet/utils/constants.py
index 8d0f90d9f153880e688638094b00afb66030d9e6..61a1910eefd3cde70ad8847e645efb4448d916ba 100644
--- a/mindpet/utils/constants.py
+++ b/mindpet/utils/constants.py
@@ -104,5 +104,5 @@ EVAL_INFER_TASK_NAMES = [EVALUATE_TASK_NAME, INFER_TASK_NAME]
TK_SDK_INTERFACE_NAMES = [FINETUNE_TASK_NAME, EVALUATE_TASK_NAME, INFER_TASK_NAME]
# 微调算法清单
-DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit']
+DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit', 'ptuning2']
diff --git a/test/unit_test/delta/test_prefix_encoder.py b/test/unit_test/delta/test_prefix_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..17e8a70d70963756213cd4bab51a46485f9d2a6a
--- /dev/null
+++ b/test/unit_test/delta/test_prefix_encoder.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright © Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
+
+import logging
+import os
+import shutil
+import unittest
+import pytest
+import argparse
+import mindspore
+
+from mindpet.delta.ptuning2 import PrefixEncoder
+from mindpet.utils.constants import DEFAULT_MODES, DEFAULT_FLAGS
+
+logging.getLogger().setLevel(logging.INFO)
+LOCAL_PATH = os.path.join('/', 'tmp', 'ut').replace('\\', '/')
+LOCAL_FILE = os.path.join(LOCAL_PATH, 'ut_sample.txt')
+
+
+class TestPrefixEncoder(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ logging.info('-----准备执行单元测试前置动作, 创建本地临时文件-----')
+ if not os.path.exists(LOCAL_PATH):
+ os.makedirs(LOCAL_PATH, exist_ok=True)
+
+ if not os.path.exists(LOCAL_FILE):
+ with os.fdopen(os.open(LOCAL_FILE, DEFAULT_FLAGS, DEFAULT_MODES), 'w+') as file:
+ file.write('ut test sample')
+
+ @classmethod
+ def tearDownClass(cls):
+ logging.info('-----单元测试执行完毕, 清除本地临时文件-----')
+ if os.path.exists(LOCAL_PATH):
+ shutil.rmtree(LOCAL_PATH)
+
+ if os.path.exists(LOCAL_FILE):
+ shutil.rmtree(LOCAL_FILE)
+
+ def test_pre_seq_len_correct(self):
+ logging.info('Start test_pre_seq_len')
+ prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ self.assertEqual(10, prefix.pre_seq_len)
+ logging.info("Finish test_pre_seq_len_correct")
+
+ def test_num_layers_correct(self):
+ logging.info('Start test_num_layers_correct')
+ prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ self.assertEqual(2, prefix.num_layers)
+ logging.info("Finish test_num_layers_correct")
+
+ def test_num_heads_correct(self):
+ logging.info('Start test_num_heads_correct')
+ prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ self.assertEqual(8, prefix.num_heads)
+ logging.info("Finish test_num_heads_correct")
+
+ def test_kv_channels_correct(self):
+ logging.info('Start test_kv_channels_correct')
+ prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ self.assertEqual(32, prefix.kv_channels)
+ logging.info("Finish test_kv_channels_correct")
+
+ def test_projection_dim_correct(self):
+ logging.info('Start test_projection_dim_correct')
+ prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=True, projection_dim=32, dropout_prob=0.1)
+ self.assertEqual(32, prefix.projection_dim)
+ logging.info("Finish test_projection_dim_correct")
+
+ def test_pre_seq_len_is_not_positive_integer(self):
+ logging.info("Start test_pre_seq_len_is_not_positive_integer")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=-1, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ logging.info("Finish test_pre_seq_len_is_not_positive_integer")
+
+ def test_num_layers_is_not_positive_integer(self):
+ logging.info("Start test_num_layers_is_not_positive_integer")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=-1, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ logging.info("Finish test_num_layers_is_not_positive_integer")
+
+ def test_num_heads_is_not_positive_integer(self):
+ logging.info("Start test_num_heads_is_not_positive_integer")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=-1,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ logging.info("Finish test_num_heads_is_not_positive_integer")
+
+ def test_kv_channels_is_not_positive_integer(self):
+ logging.info("Start test_kv_channels_is_not_positive_integer")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=0.1)
+ logging.info("Finish test_kv_channels_is_not_positive_integer")
+
+ def test_projection_dim_is_not_positive_integer(self):
+ logging.info("Start test_projection_dim_is_not_positive_integer")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=True, projection_dim=-1, dropout_prob=0.1)
+ logging.info("Finish test_projection_dim_is_not_positive_integer")
+
+ def test_dropout_prob_negative(self):
+ logging.info("Start test_dropout_prob_is_negative")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=-0.1)
+ logging.info("Finish test_dropout_prob_is_negative")
+
+ def test_dropout_prob_is_one(self):
+ logging.info("Start test_dropout_prob_scope")
+ self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8,
+ kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=1.0)
+ logging.info("Finish test_dropout_prob_scope")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--device_id', type=int, default=0)
+ args = parser.parse_args()
+ mindspore.set_context(device_id=args.device_id)
+ pytest.main(["-s", os.path.abspath(__file__)])