diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md index d29f76963d334344939fc0c59345284576e4aaf8..c7aff9c850209294920fe588634c35291c3c9cc6 100644 --- a/doc/TK_DeltaAlgorithm_README.md +++ b/doc/TK_DeltaAlgorithm_README.md @@ -1462,6 +1462,199 @@ class BertClsModel(BaseModel): +## 七、P-Tuning v2算法 + +### 7.1 算法介绍 + +P-Tuning v2方法将可训练的连续提示向量独立添加到每个transformer层的输入中,只训练这部分任务相关的向量,保持预训练模型的参数不变。P-Tuning v2会在每个transformer层的key和value向量的前面插入`l`个用于更新参数的连续提示向量,然后冻结预训练模型的参数,只更新这些向量的参数,就可以达到近似全参微调的效果。 + + +算法原理如下图所示,算法具体实现细节可参考论文[P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/abs/2110.07602)。 + +

+P-Tuning v2算法原理图: 对于每个下游任务,在网络的每一层添加一份连续提示向量,冻结预训练模型的其他参数,只训练这些向量。 +
+ + + +### 7.2 API接口 + +``` python + class PrefixEncoder(pre_seq_len, + num_layers, + num_heads, + kv_channels, + prefix_projection, + projection_dim, + dropout_prob): +``` + +定义PrefixEncoder层。 + + + +**参数** + +- **pre_seq_len**(int) - 网络每层提示向量的长度。 +- **num_layers**(int) - 网络层数,与原模型参数一致。 +- **num_heads**(int) - 多头注意力头数,与原模型参数一致。 +- **kv_channels**(int) - `key`、`value`隐藏维度,与原模型参数一致。 +- **prefix_projection**(bool) - 是否使用MLP表征。 +- **projection_dim**(int) - MLP维度,仅当`prefix_projection`为True时生效。 +- **dropout_prob**(float) - 丢弃率。 + + + +**异常** + +- **TypeError** - `pre_seq_len`不是整数。 +- **TypeError** - `num_layers`不是整数。 +- **TypeError** - `num_heads`不是整数。 +- **TypeError** - `kv_channels`不是整数。 +- **TypeError** - `projection_dim`不是整数(仅当`prefix_projection`为True时校验)。 +- **ValueError** - `pre_seq_len`、`num_layers`、`num_heads`、`kv_channels`或`projection_dim`不是正数。 +- **ValueError** - `dropout_prob`不在[0,1)之内。 + + + +### 7.3 使用样例 + +通过以下步骤将模型结构中`key`、`value`和`attention_mask`修改为新的`key`、`value`和`attention_mask`,冻结网络进行训练: + +1)安装mindpet工具包。([安装方法参考《README.md》第二章](../README.md)) + +2)在模型的初始化时,从工具包中引入`PrefixEncoder`类,创建`prefix_encoder`,在`construct`时构造提示向量传递给网络的每层。 + +```python +class ChatModelWithPt2(ChatModel): + def __init__(self, config): + super().__init__(config) + self.prefix_encoder = PrefixEncoder( + config.pet_config.pre_seq_len, + config.pet_config.num_layers, + config.pet_config.num_heads, + config.pet_config.kv_channels, + config.pet_config.prefix_projection, + config.pet_config.projection_dim, + config.pet_config.dropout_prob + ) + ... 
+ + def construct(self, ...): + prefix_key_values = self.prefix_encoder(batch_size) + return super().construct(..., prefix_key_values) +``` + +3)在模型的Attention结构中,将`PrefixEncoder`构造的每层`prefix_key_value`矩阵与原`key`、`value`矩阵进行`concat`操作。然后定义全为1的`prefix_mask`矩阵,将`prefix_mask`矩阵与原`attention_mask`矩阵进行`concat`(新的`attention_mask`矩阵shape与新的`query`*`key`矩阵的shape相同)。 + +```python +# 模型的Attention层 +class SelfAttention(nn.Cell): + def add_prefix(self, prefix_key_value, pre_seq_len, key, value, attention_mask): + # [bs, num_heads, seq_length, head_dim] + seq_len = key.shape[2] + + # [bs, num_heads, pre_seq_len, head_dim] + prefix_key = prefix_key_value[0] + prefix_value = prefix_key_value[1] + cat = P.Concat(2) + key = cat([prefix_key, key]) + value = cat([prefix_value, value]) + + batch_size = attention_mask.shape[0] + prefix_mask = attention_mask.new_ones((batch_size, 1, seq_len, pre_seq_len)) + m_cat = P.Concat(3) + + # [bs, 1, seq_len, pre_seq_len + seq_len] + attention_mask = m_cat((prefix_mask, attention_mask)) + + return key, value, attention_mask + + def construct(self, input_tensor, attention_mask, prefix_key_value): + ... + ... + key_layer, value_layer, attention_mask = self.add_prefix( + prefix_key_value, + self.pre_seq_len, + key_layer, + value_layer, + attention_mask + ) + context_layer = self.attention(query_layer, key_layer, value_layer, attention_mask) + ... +``` + +4)在训练脚本中,从工具包中引入`freeze_delta`方法,定义优化器之前调用`freeze_delta`冻结除`Prefix`矩阵外其它原模型权重。注意,为了适配下游任务引入的额外模型结构无需冻结,可以用`exclude`参数指定无需冻结的结构名称。([冻结方法参考《TK_GraphOperation_README.md》第一章](TK_GraphOperation_README.md)) + +```Python +# freeze all cell except ptuning2 +freeze_delta(model=network, mode='ptuning2') +``` + +然后从工具包中引入`TrainableParamsCheckPoint`类,将保存ckpt的类改为`TrainableParamsCheckPoint`,仅保存需要更新的参数,可节约存储空间。([详细方法参考《TK_GraphOperation_README.md》第二章](TK_GraphOperation_README.md)) + +由于微调后只保存了部分参数,推理时具体如何加载ckpt请参考[附录A](###A 分布式微调后模型评估方法)。 + +```python +# original callback +# ckpt_callback = ModelCheckpoint(...) 
+ +# replace ModelCheckpoint with TrainableParamsCheckPoint +ckpt_callback = TrainableParamsCheckPoint(...) +``` + + + +### 7.4 实验效果 + +下面实验基于MindFormers开源仓中的[**GLM2-6B**](https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/glm2.md)复现。 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
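<table>
<tr><th rowspan="2">模型</th><th rowspan="2">下游任务</th><th rowspan="2">模式</th><th colspan="4">训练参数</th><th rowspan="2">微调参数占比</th><th rowspan="2">静态内存+动态内存</th><th rowspan="2">rouge-1</th></tr>
<tr><th>epoch</th><th>优化器</th><th>学习率</th><th>pre_seq_len</th></tr>
<tr><td rowspan="2">glm2-6b</td><td rowspan="2">language modeling</td><td>baseline</td><td>1</td><td>Adamw</td><td>1.00E-04</td><td>\</td><td>100%</td><td>60056MB+92141MB</td><td>30.7</td></tr>
<tr><td>p-tuning v2</td><td>1</td><td>Adamw</td><td>5.00E-03</td><td>128</td><td>0.03%</td><td>12992MB+35588MB</td><td>31.5</td></tr>
</table>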
+ + + ## 附录 diff --git a/doc/image/ptuning2.png b/doc/image/ptuning2.png new file mode 100644 index 0000000000000000000000000000000000000000..94c9b41a6a045246a51455b0240c4e3329cd09a0 Binary files /dev/null and b/doc/image/ptuning2.png differ diff --git a/mindpet/delta/__init__.py b/mindpet/delta/__init__.py index 605be9cd98f1243fee9368a278739f5f95e92cab..7219c79e4b0875863b582faa5be352e7351625ef 100644 --- a/mindpet/delta/__init__.py +++ b/mindpet/delta/__init__.py @@ -6,7 +6,8 @@ from mindpet.delta.lora import LoRADense from mindpet.delta.prefix_layer import PrefixLayer from mindpet.delta.low_rank_adapter import LowRankAdapterDense, LowRankAdapterLayer from mindpet.delta.adapter import AdapterDense, AdapterLayer +from mindpet.delta.ptuning2 import PrefixEncoder from mindpet.delta.r_drop import RDropLoss, rdrop_repeat __all__ = ['LoRADense', 'PrefixLayer', 'LowRankAdapterDense', 'LowRankAdapterLayer', - 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat'] + 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat', 'PrefixEncoder'] diff --git a/mindpet/delta/ptuning2.py b/mindpet/delta/ptuning2.py new file mode 100644 index 0000000000000000000000000000000000000000..dcaa1a2d8ef573b14d9e4b39fe6eac5a64f71925 --- /dev/null +++ b/mindpet/delta/ptuning2.py @@ -0,0 +1,114 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""
P-Tuning v2 prefix encoder.

Reference: https://arxiv.org/pdf/2110.07602.pdf
"""

import mindspore as ms
import mindspore.nn as nn
import numpy as np

from mindspore import dtype as mstype
from mindspore.ops import operations as P

from mindpet.utils.version_control import get_dropout

try:
    from mindspore._checkparam import Validator, Rel

    INC_LEFT = Rel.INC_LEFT
except ImportError:
    # `Validator` / `Rel` could not be imported from this MindSpore build.
    # Fall back to using the `_checkparam` module itself as the validator
    # namespace (presumably newer releases expose the checks as module-level
    # functions -- confirm against the supported MindSpore versions).
    # Only ImportError is caught so that unrelated failures, as well as
    # KeyboardInterrupt / SystemExit, are not silently swallowed.
    import mindspore._checkparam as Validator

    INC_LEFT = Validator.INC_LEFT


class PrefixEncoder(nn.Cell):
    """
    Cell that encodes the trainable prefix used by P-Tuning v2.

    Input: batch_size
    Output: the encoded prefix, reshaped/permuted to
        (2 * layers, bs, num_heads, pre_len, kv_channels) and then split with
        ``split(2)``; the intended result is
        layers * (2, bs, num_heads, pre_len, kv_channels).
        NOTE(review): this relies on ``Tensor.split(2)`` producing chunks of
        size 2 along axis 0 -- confirm for the targeted MindSpore version.
    """

    def __init__(self, pre_seq_len, num_layers, num_heads, kv_channels, prefix_projection,
                 projection_dim, dropout_prob):
        """
        Validate the configuration and create the prefix embedding (and the
        optional projection MLP).

        Args:
            pre_seq_len: sequence length of the prefix.
            num_layers: number of transformer layers of the original model.
            num_heads: number of attention heads of the original model.
            kv_channels: key/value dimension of the original model.
            prefix_projection: whether to pass the embedding output through the
                two-layer MLP (Dense -> Tanh -> Dense).
            projection_dim: hidden dimension of that MLP; only validated and
                stored when ``prefix_projection`` is truthy.
            dropout_prob: dropout probability; range-checked with
                ``check_float_range(..., 0.0, 1.0, INC_LEFT)`` and forwarded to
                ``get_dropout``.

        Exceptions raised by the ``Validator`` checks propagate to the caller.
        """
        super().__init__()
        self.pre_seq_len = Validator.check_positive_int(pre_seq_len, "pre_seq_len")
        self.num_layers = Validator.check_positive_int(num_layers, "num_layers")
        self.num_heads = Validator.check_positive_int(num_heads, "num_heads")
        self.kv_channels = Validator.check_positive_int(kv_channels, "kv_channels")

        dropout_prob = Validator.check_float_range(dropout_prob, 0.0, 1.0, INC_LEFT)
        self.dropout = get_dropout(dropout_prob)

        self.prefix_projection = prefix_projection

        # Fixed index vector [0, pre_seq_len) used to look up the embedding;
        # it is not trained (requires_grad=False).
        self.tk_delta_ptuning2_prefix = ms.Parameter(np.arange(self.pre_seq_len),
                                                     requires_grad=False)

        # One embedding row per prefix position holds the key and value
        # vectors for every layer and head (hence the factor 2).
        out_embed_dim = self.num_layers * self.kv_channels * self.num_heads * 2
        self.tk_delta_ptuning2_embedding = nn.Embedding(self.pre_seq_len, out_embed_dim)

        if self.prefix_projection:
            self.projection_dim = Validator.check_positive_int(projection_dim, "projection_dim")
            # two-layer MLP to encode the prefix
            self.tk_delta_ptuning2_trans = nn.SequentialCell(
                nn.Dense(out_embed_dim, self.projection_dim),
                nn.Tanh(),
                nn.Dense(self.projection_dim, out_embed_dim)
            )

        self.expand_dims = P.ExpandDims()
        self.tile = P.Tile()
        self.cast = P.Cast()

    def construct(self, batch_size, dtype=mstype.half):
        """
        Build the prefix key/value tensors for a batch.

        Args:
            batch_size: number of samples; the prefix indices are tiled this
                many times along the first axis.
            dtype: dtype the encoded prefix is cast to (default ``mstype.half``).

        Returns:
            The result of ``permute([2, 0, 3, 1, 4]).split(2)`` on the
            (bs, pre_len, 2 * layers, num_heads, kv_channels) tensor.
        """
        prefix_tokens = self.expand_dims(self.tk_delta_ptuning2_prefix, 0)
        prefix_tokens = self.tile(prefix_tokens, (batch_size, 1))

        # (bs, pre_len) -> (bs, pre_len, 2 * layers * num_heads * kv_channels)
        past_key_values = self.tk_delta_ptuning2_embedding(prefix_tokens)

        if self.prefix_projection:
            past_key_values = self.tk_delta_ptuning2_trans(past_key_values)

        past_key_values = self.cast(past_key_values, dtype)

        # (bs, pre_len, 2 * layers * num_heads * kv_channels)
        #   -> (bs, pre_len, 2 * layers, num_heads, kv_channels)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.num_heads,
            self.kv_channels
        )

        past_key_values = self.dropout(past_key_values)

        # (bs, pre_len, 2 * layers, num_heads, kv_channels)
        #   -> layers * (2, bs, num_heads, pre_len, kv_channels)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)

        return past_key_values
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright © Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
"""Unit tests for the constructor validation of ``PrefixEncoder``."""

import logging
import os
import shutil
import unittest
import pytest
import argparse
import mindspore

from mindpet.delta.ptuning2 import PrefixEncoder
from mindpet.utils.constants import DEFAULT_MODES, DEFAULT_FLAGS

logging.getLogger().setLevel(logging.INFO)
LOCAL_PATH = os.path.join('/', 'tmp', 'ut').replace('\\', '/')
LOCAL_FILE = os.path.join(LOCAL_PATH, 'ut_sample.txt')


class TestPrefixEncoder(unittest.TestCase):
    """Checks that ``PrefixEncoder`` stores valid arguments and rejects invalid ones."""

    # A known-good argument set; each test overrides only the field under test.
    VALID_KWARGS = {
        'pre_seq_len': 10,
        'num_layers': 2,
        'num_heads': 8,
        'kv_channels': 32,
        'prefix_projection': False,
        'projection_dim': 32,
        'dropout_prob': 0.1,
    }

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory and sample file used by the suite."""
        logging.info('-----准备执行单元测试前置动作, 创建本地临时文件-----')
        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(LOCAL_PATH, exist_ok=True)

        if not os.path.exists(LOCAL_FILE):
            with os.fdopen(os.open(LOCAL_FILE, DEFAULT_FLAGS, DEFAULT_MODES), 'w+') as file:
                file.write('ut test sample')

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary sample file and directory."""
        logging.info('-----单元测试执行完毕, 清除本地临时文件-----')
        # LOCAL_FILE is a regular file, so it must be deleted with os.remove;
        # shutil.rmtree only works on directories. Remove it before its parent.
        if os.path.isfile(LOCAL_FILE):
            os.remove(LOCAL_FILE)

        if os.path.exists(LOCAL_PATH):
            shutil.rmtree(LOCAL_PATH)

    @classmethod
    def _build(cls, **overrides):
        """Construct a PrefixEncoder from VALID_KWARGS with the given overrides."""
        return PrefixEncoder(**{**cls.VALID_KWARGS, **overrides})

    def _assert_rejected(self, **overrides):
        """Assert that constructing with the given overrides raises ValueError."""
        with self.assertRaises(ValueError):
            self._build(**overrides)

    def test_pre_seq_len_correct(self):
        logging.info('Start test_pre_seq_len_correct')
        prefix = self._build()
        self.assertEqual(10, prefix.pre_seq_len)
        logging.info("Finish test_pre_seq_len_correct")

    def test_num_layers_correct(self):
        logging.info('Start test_num_layers_correct')
        prefix = self._build()
        self.assertEqual(2, prefix.num_layers)
        logging.info("Finish test_num_layers_correct")

    def test_num_heads_correct(self):
        logging.info('Start test_num_heads_correct')
        prefix = self._build()
        self.assertEqual(8, prefix.num_heads)
        logging.info("Finish test_num_heads_correct")

    def test_kv_channels_correct(self):
        logging.info('Start test_kv_channels_correct')
        prefix = self._build()
        self.assertEqual(32, prefix.kv_channels)
        logging.info("Finish test_kv_channels_correct")

    def test_projection_dim_correct(self):
        logging.info('Start test_projection_dim_correct')
        # projection_dim is only stored when prefix_projection is enabled.
        prefix = self._build(prefix_projection=True)
        self.assertEqual(32, prefix.projection_dim)
        logging.info("Finish test_projection_dim_correct")

    def test_pre_seq_len_is_not_positive_integer(self):
        logging.info("Start test_pre_seq_len_is_not_positive_integer")
        self._assert_rejected(pre_seq_len=-1)
        logging.info("Finish test_pre_seq_len_is_not_positive_integer")

    def test_num_layers_is_not_positive_integer(self):
        logging.info("Start test_num_layers_is_not_positive_integer")
        self._assert_rejected(num_layers=-1)
        logging.info("Finish test_num_layers_is_not_positive_integer")

    def test_num_heads_is_not_positive_integer(self):
        logging.info("Start test_num_heads_is_not_positive_integer")
        self._assert_rejected(num_heads=-1)
        logging.info("Finish test_num_heads_is_not_positive_integer")

    def test_kv_channels_is_not_positive_integer(self):
        logging.info("Start test_kv_channels_is_not_positive_integer")
        self._assert_rejected(kv_channels=-1)
        logging.info("Finish test_kv_channels_is_not_positive_integer")

    def test_projection_dim_is_not_positive_integer(self):
        logging.info("Start test_projection_dim_is_not_positive_integer")
        # projection_dim is only validated when prefix_projection is enabled.
        self._assert_rejected(prefix_projection=True, projection_dim=-1)
        logging.info("Finish test_projection_dim_is_not_positive_integer")

    def test_dropout_prob_negative(self):
        logging.info("Start test_dropout_prob_negative")
        self._assert_rejected(dropout_prob=-0.1)
        logging.info("Finish test_dropout_prob_negative")

    def test_dropout_prob_is_one(self):
        logging.info("Start test_dropout_prob_is_one")
        # The accepted range is [0, 1), so exactly 1.0 must be rejected.
        self._assert_rejected(dropout_prob=1.0)
        logging.info("Finish test_dropout_prob_is_one")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device_id', type=int, default=0)
    args = parser.parse_args()
    mindspore.set_context(device_id=args.device_id)
    pytest.main(["-s", os.path.abspath(__file__)])