From be81a5a378e06f192df50c47102246cf7ce78c4d Mon Sep 17 00:00:00 2001 From: pk Date: Sun, 24 Sep 2023 12:32:08 +0800 Subject: [PATCH 1/5] add p-tuning-v2 --- mindpet/delta/__init__.py | 3 +- mindpet/delta/ptuning2.py | 114 +++++++++++++++++++++++++++++++++++++ mindpet/utils/constants.py | 2 +- 3 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 mindpet/delta/ptuning2.py diff --git a/mindpet/delta/__init__.py b/mindpet/delta/__init__.py index 605be9c..7219c79 100644 --- a/mindpet/delta/__init__.py +++ b/mindpet/delta/__init__.py @@ -6,7 +6,8 @@ from mindpet.delta.lora import LoRADense from mindpet.delta.prefix_layer import PrefixLayer from mindpet.delta.low_rank_adapter import LowRankAdapterDense, LowRankAdapterLayer from mindpet.delta.adapter import AdapterDense, AdapterLayer +from mindpet.delta.ptuning2 import PrefixEncoder from mindpet.delta.r_drop import RDropLoss, rdrop_repeat __all__ = ['LoRADense', 'PrefixLayer', 'LowRankAdapterDense', 'LowRankAdapterLayer', - 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat'] + 'AdapterDense', 'AdapterLayer', 'RDropLoss', 'rdrop_repeat', 'PrefixEncoder'] diff --git a/mindpet/delta/ptuning2.py b/mindpet/delta/ptuning2.py new file mode 100644 index 0000000..d091380 --- /dev/null +++ b/mindpet/delta/ptuning2.py @@ -0,0 +1,114 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +p-tuning-v2 +Reference: https://arxiv.org/pdf/2110.07602.pdf +""" + +import mindspore as ms +import mindspore.nn as nn +import numpy as np + +from mindspore import dtype as mstype +from mindspore.ops import operations as P + +from mindpet.utils.version_control import get_dropout + +try: + from mindspore._checkparam import Validator, Rel + + INC_LEFT = Rel.INC_LEFT +except: + import mindspore._checkparam as Validator + + INC_LEFT = Validator.INC_LEFT + + +class PrefixEncoder(nn.Cell): + """ + The cell to encode the prefix + Input : batch_size + Output shape: layers * (2, bs, num_heads, pre_len, kv_channels) + """ + + def __init__(self, pre_seq_len, num_layers, num_heads, kv_channels, prefix_projection, + projection_dim, dropout_prob): + """ + Args: + pre_seq_len: prefix的序列长度 + num_layers: 原模型transformer层数 + num_heads: 原模型transformer多头注意力头数 + kv_channels: 原模型transformer kv维度 + projection_dim: MLP编码维度 + dropout_prob: 丢弃率 + prefix_projection 是否使用MLP编码 + """ + super().__init__() + self.pre_seq_len = Validator.check_positive_int(pre_seq_len, "pre_seq_len") + self.num_layers = Validator.check_positive_int(num_layers, "num_layers") + self.num_heads = Validator.check_positive_int(num_heads, "num_heads") + self.kv_channels = Validator.check_positive_int(kv_channels, "kv_channels") + + dropout_prob = Validator.check_float_range(dropout_prob, 0.0, 1.0, INC_LEFT) + self.dropout = get_dropout(dropout_prob) + + self.prefix_projection = prefix_projection + + self.tk_delta_ptuning2_prefix = ms.Parameter(np.arange(self.pre_seq_len), + requires_grad=False) + + out_embed_dim = self.num_layers * self.kv_channels * self.num_heads * 2 + self.tk_delta_ptuning2_embedding = nn.Embedding(self.pre_seq_len, out_embed_dim) + + if self.prefix_projection: + self.projection_dim = Validator.check_positive_int(projection_dim, "projection_dim") + # two-layer MLP to encode the prefix + self.tk_delta_ptuning2_trans = nn.SequentialCell( + nn.Dense(out_embed_dim, self.projection_dim), + nn.Tanh(), + nn.Dense(self.projection_dim, out_embed_dim) + ) + + self.expand_dims = P.ExpandDims() + self.tile = P.Tile() + self.cast = P.Cast() + + def construct(self, batch_size, dtype=mstype.half): + prefix_tokens = self.expand_dims(self.tk_delta_ptuning2_prefix, 0) + prefix_tokens = self.tile(prefix_tokens, (batch_size, 1)) + + # (bs, pre_len) -> (bs, pre_len, 2 * layers * num_heads * kv_channels) + past_key_values = self.tk_delta_ptuning2_embedding(prefix_tokens) + + if self.prefix_projection: + past_key_values = self.tk_delta_ptuning2_trans(past_key_values) + + past_key_values = self.cast(past_key_values, dtype) + + # (bs, pre_len, 2 * layers * num_heads * kv_channels) -> (bs, pre_len, 2 * layers, num_heads, kv_channels) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_heads, + self.kv_channels + ) + + past_key_values = self.dropout(past_key_values) + + # (bs, pre_len, 2 * layers, num_heads, kv_channels) -> layers * (2, bs, num_heads, pre_len, kv_channels) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + + return past_key_values diff --git a/mindpet/utils/constants.py b/mindpet/utils/constants.py index 8d0f90d..61a1910 100644 --- a/mindpet/utils/constants.py +++ b/mindpet/utils/constants.py @@ -104,5 +104,5 @@ EVAL_INFER_TASK_NAMES = [EVALUATE_TASK_NAME, INFER_TASK_NAME] TK_SDK_INTERFACE_NAMES = [FINETUNE_TASK_NAME, EVALUATE_TASK_NAME, INFER_TASK_NAME] # 微调算法清单 -DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit'] +DELTA_LIST = ['lora', 'prefixtuning', 'adapter', 'low_rank_adapter', 'bitfit', 'ptuning2'] -- Gitee From a8508eb5696aea68444ccd1dd3e1999d49b09c6f Mon Sep 17 00:00:00 2001 From: pk Date: Sun, 24 Sep 2023 13:36:15 +0800 Subject: [PATCH 2/5] p-tuning-v2 test py --- test/unit_test/delta/test_prefix_encoder.py | 125 ++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 test/unit_test/delta/test_prefix_encoder.py diff --git a/test/unit_test/delta/test_prefix_encoder.py b/test/unit_test/delta/test_prefix_encoder.py new file mode 100644 index 0000000..9bc95a7 --- /dev/null +++ b/test/unit_test/delta/test_prefix_encoder.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright © Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + +import logging +import os +import shutil +import unittest +import pytest +import argparse +import mindspore + +from mindpet.delta.ptuning2 import PrefixEncoder +from mindpet.utils.constants import DEFAULT_MODES, DEFAULT_FLAGS + +logging.getLogger().setLevel(logging.INFO) +LOCAL_PATH = os.path.join('/', 'tmp', 'ut').replace('\\', '/') +LOCAL_FILE = os.path.join(LOCAL_PATH, 'ut_sample.txt') + + +class TestPrefixEncoder(unittest.TestCase): + + @classmethod + def setUpClass(cls): + logging.info('-----准备执行单元测试前置动作, 创建本地临时文件-----') + if not os.path.exists(LOCAL_PATH): + os.makedirs(LOCAL_PATH, exist_ok=True) + + if not os.path.exists(LOCAL_FILE): + with os.fdopen(os.open(LOCAL_FILE, DEFAULT_FLAGS, DEFAULT_MODES), 'w+') as file: + file.write('ut test sample') + + @classmethod + def tearDownClass(cls): + logging.info('-----单元测试执行完毕, 清除本地临时文件-----') + if os.path.exists(LOCAL_PATH): + shutil.rmtree(LOCAL_PATH) + + if os.path.exists(LOCAL_FILE): + shutil.rmtree(LOCAL_FILE) + + def test_pre_seq_len_correct(self): + logging.info('Start test_pre_seq_len') + prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + self.assertEqual(10, prefix.pre_seq_len) + logging.info("Finish test_pre_seq_len_correct") + + def test_num_layers_correct(self): + logging.info('Start test_num_layers_correct') + prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + self.assertEqual(2, prefix.num_layers) + logging.info("Finish test_num_layers_correct") + + def test_num_heads_correct(self): + logging.info('Start test_num_heads_correct') + prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + self.assertEqual(8, prefix.num_heads) + logging.info("Finish test_num_heads_correct") + + def test_kv_channels_correct(self): + logging.info('Start test_kv_channels_correct') + prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + self.assertEqual(32, prefix.kv_channels) + logging.info("Finish test_kv_channels_correct") + + def test_projection_dim_correct(self): + logging.info('Start test_projection_dim_correct') + prefix = PrefixEncoder(pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=True, projection_dim=32, dropout_prob=0.1) + self.assertEqual(32, prefix.projection_dim) + logging.info("Finish test_projection_dim_correct") + + def test_pre_seq_len_is_not_positive_integer(self): + logging.info("Start test_pre_seq_len_is_not_positive_integer") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=-1, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + logging.info("Finish test_pre_seq_len_is_not_positive_integer") + + def test_num_layers_is_not_positive_integer(self): + logging.info("Start test_num_layers_is_not_positive_integer") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=-1, num_heads=8, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + logging.info("Finish test_num_layers_is_not_positive_integer") + + def test_num_heads_is_not_positive_integer(self): + logging.info("Start test_num_heads_is_not_positive_integer") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=-1, + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + logging.info("Finish test_num_heads_is_not_positive_integer") + + def test_kv_channels_is_not_positive_integer(self): + logging.info("Start test_kv_channels_is_not_positive_integer") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=0.1) + logging.info("Finish test_kv_channels_is_not_positive_integer") + + def test_projection_dim_is_not_positive_integer(self): + logging.info("Start test_projection_dim_is_not_positive_integer") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=32, prefix_projection=True, projection_dim=-1, dropout_prob=0.1) + logging.info("Finish test_projection_dim_is_not_positive_integer") + + def test_dropout_prob_negative(self): + logging.info("Start test_dropout_prob_is_negative") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=-0.1) + logging.info("Finish test_dropout_prob_is_negative") + + def test_dropout_prob_is_one(self): + logging.info("Start test_dropout_prob_scope") + self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, + kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=1.0) + logging.info("Finish test_dropout_prob_scope") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--device_id', type=int, default=0) + args = parser.parse_args() + mindspore.set_context(device_id=args.device_id) + pytest.main(["-s", os.path.abspath(__file__)]) -- Gitee From 9b19d2491f5fac49360b5f1adcd552d007d64daf Mon Sep 17 00:00:00 2001 From: pk Date: Sun, 24 Sep 2023 13:41:49 +0800 Subject: [PATCH 3/5] p-tuning-v2 test py --- test/unit_test/delta/test_prefix_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit_test/delta/test_prefix_encoder.py b/test/unit_test/delta/test_prefix_encoder.py index 9bc95a7..17e8a70 100644 --- a/test/unit_test/delta/test_prefix_encoder.py +++ b/test/unit_test/delta/test_prefix_encoder.py @@ -107,13 +107,13 @@ class TestPrefixEncoder(unittest.TestCase): def test_dropout_prob_negative(self): logging.info("Start test_dropout_prob_is_negative") self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, - kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=-0.1) + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=-0.1) logging.info("Finish test_dropout_prob_is_negative") def test_dropout_prob_is_one(self): logging.info("Start test_dropout_prob_scope") self.assertRaises(ValueError, PrefixEncoder, pre_seq_len=10, num_layers=2, num_heads=8, - kv_channels=-1, prefix_projection=False, projection_dim=32, dropout_prob=1.0) + kv_channels=32, prefix_projection=False, projection_dim=32, dropout_prob=1.0) logging.info("Finish test_dropout_prob_scope") -- Gitee From b94117e96a4cce8b31b7febb22b6fe98dd0ddd28 Mon Sep 17 00:00:00 2001 From: pk Date: Sun, 24 Sep 2023 20:36:20 +0800 Subject: [PATCH 4/5] add p-tuning-v2 doc --- doc/TK_DeltaAlgorithm_README.md | 192 ++++++++++++++++++++++++++++++++ doc/image/ptuning2.png | Bin 0 -> 60070 bytes mindpet/delta/ptuning2.py | 4 +- 3 files changed, 194 insertions(+), 2 deletions(-) create mode 100644 doc/image/ptuning2.png diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md index d29f769..b8cb27f 100644 --- a/doc/TK_DeltaAlgorithm_README.md +++ b/doc/TK_DeltaAlgorithm_README.md @@ -1462,6 +1462,198 @@ class BertClsModel(BaseModel): +## 七、P-Tuning v2算法 + +### 7.1 算法介绍 + +P-Tuning v2该方法将可训练的连续提示向量独立添加到每个transformer层的输入中,只训练这部分任务相关的向量,保持预训练模型的参数不变。P-Tuning v2会在每个transformer层的key和value向量的前面插入l个用于更新参数的连续提示向量,然后冻结预训练模型的参数, 只更新这些向量的参数,就可以达到近似全参微调的效果。 + + +算法原理如下图所示,算法具体实现细节可参考论文[P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/abs/2110.07602) + +

+P-Tuning v2算法原理图: 对于每个下游任务,在网络的每一层添加一份连续提示向量,冻结预训练模型的其他参数,只训练这些向量。 +
+ + + +### 7.2 API接口 + +``` python + class PrefixEncoder(pre_seq_len, + num_layers, + num_heads, + kv_channels, + prefix_projection, + projection_dim, + dropout_prob): +``` + +定义PrefixEncoder层 + + + +**参数** + +- **pre_seq_len**(int) - 网络每层提示向量的长度. +- **num_layers**(int) - 网络层数,与原模型参数一致。 +- **num_heads**(int) - 多头注意力头数,与原模型参数一致。 +- **kv_channels**(int) - `key`、`value`隐藏维度,与原模型参数一致。 +- **prefix_projection**(bool) - 是否使用MLP表征。 +- **projection_dim**(int) - MLP维度。 +- **dropout_prob**(float) - 丢弃率。 + + + +**异常** + +- **TypeError** - `pre_seq_len`不是正整数。 +- **TypeError** - `num_layers`不是正整数。 +- **TypeError** - `num_heads`不是正整数。 +- **TypeError** - `kv_channels`不是正整数。 +- **TypeError** - `projection_dim`不是正整数。 +- **ValueError** - `dropout_prob`不在[0,1)之内。 + + + +### 7.3 使用样例 + +通过以下步骤将模型结构中`key`、`value`和`attention_mask`修改为新的`key`、`value`和`attention_mask`,冻结网络进行训练: + +1)安装mindpet工具包。([安装方法参考《README.md》第二章](../README.md)) + +2)在模型的初始化时,从工具包中引入`PrefixEncoder`类,创建`prefixEncoder`,在`construct`时构造提示向量传递给网络的每层。 + +```python +class ChatModelWithPt2(ChatModel): + def __init__(self, config): + super().__init__(config) + self.prefix_encoder = PrefixEncoder( + config.pet_config.pre_seq_len, + config.pet_config.num_layers, + config.pet_config.num_heads, + config.pet_config.kv_channels, + config.pet_config.prefix_projection, + config.pet_config.projection_dim, + config.pet_config.dropout_prob + ) + ... + + def construct(self, ...): + prefix_key_values = self.prefix_encoder(batch_size) + return super().construct(..., prefix_key_values) +``` + +3)在模型的Attention结构中,将`prefixlayer`构造的每层`prefix_key_value`矩阵与原`key`、`value`矩阵进行`concat`操作。然后定义全为1的`help`矩阵,将原`attention_mask`矩阵与`help`矩阵进行`concat`(新的`attention_mask`矩阵shape与新的`query`*`key`矩阵的shape相同)。 + +```python +#模型的Attention层 +class SelfAttention(nn.Cell): + def add_prefix(prefix_key_value, pre_seq_len, key, value, attention_mask): + # [bs, num_heads, seq_length, head_dim] + seq_len = key.shape[2] + + # [bs, num_heads, pre_seq_len, head_dim] + prefix_key = prefix_key_value[0] + prefix_value = prefix_key_value[1] + cat = P.Concat(2) + key = cat([prefix_key, key]) + value = cat([prefix_value, value]) + + batch_size = attention_mask.shape[0] + prefix_mask = attention_mask.new_ones((batch_size, 1, seq_len, pre_seq_len)) + m_cat = P.Concat(3) + + # [bs, 1, seq_len, pre_seq_len + seq_len] + attention_mask = m_cat((prefix_mask, attention_mask)) + + return key, value, attention_mask + def construct(self, input_tensor, attention_mask): + ... + ... + key_layer, value_layer, attention_mask = self.add_prefix( + prefix_key_value, + self.pre_seq_len, + key_layer, + value_layer, + attention_mask + ) + context_layer = self.attention(query_layer, key_layer, value_layer, attention_mask) + ... +``` + +4)在训练脚本中,从工具包中引入`freeze_delta`方法,定义优化器之前调用`freeze_delta`冻结除`Prefix`矩阵外其它原模型权重。注意,为了适配下游任务引入的额外模型结构无需冻结,可以用`exclude`参数指定无需冻结的结构名称。([冻结方法参考《TK_GraphOperation_README.md》第一章](TK_GraphOperation_README.md)) + +```Python +# freeze all cell except ptuning2 +freeze_delta(model=network, mode='ptuning2') +``` + +然后从工具包中引入`TrainableParamsCheckPoint`类,将保存ckpt的类改为`TrainableParamsCheckPoint`,仅保存需要更新的参数,可节约存储空间。([详细方法参考《TK_GraphOperation_README.md》第二章](TK_GraphOperation_README.md)) + +由于微调后只保存了部分参数,推理时具体如何加载ckpt请参考[附录A](###A 分布式微调后模型评估方法)。 + +```python +# original callback +# ckpt_callback = ModelCheckpoint(...) + +# replace ModelCheckpoint with TrainableParamsCheckPoint +ckpt_callback = TrainableParamsCheckPoint(...) +``` + + + +### 7.4 实验效果 + +下面实验基于MindFormers开源仓中的[**GLM2-6B**](https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/glm2.md)复现。 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
模型下游任务模式训练参数微调参数占比静态内存+动态内存rouge-1
epoch优化器学习率pref_seq_num
glm2-6mlanguage modelingbaseline1Adamw1.00E-04\100%60056MB+92141MB30.7
p-tuning v21Adamw5.00E-031280.03%12992MB+35588MB31.5
+ + + ## 附录 diff --git a/doc/image/ptuning2.png b/doc/image/ptuning2.png new file mode 100644 index 0000000000000000000000000000000000000000..94c9b41a6a045246a51455b0240c4e3329cd09a0 GIT binary patch literal 60070 zcmZ^~WmFu|wtyMjA-KD{ySoMt?gV#t2<{#v!QCym2M_KN+-V@VHk6y+e_glTv^84ifzC9Rw{rEO2E-`=}lG2jQkJEAg&& zis%^l0&OM!LHymjhD4+nQyAbog0q~i+q-ud?tlLvHZ4iL-@Q8#l$R3M^fvmP180yr zvNtR*@dL#+rwy*QknTL*PgVAm#AC7BZn498QP}w+I;MOZM^#K33)cgpFbG#qys4Jq z<#WpI(^j_+lzf|5R5`-t!;Oz#`@`nA-$T#n^Lop>$blVaJ|fh}fl*KAqUrznyz@!2 zZT9Ft*LR%1ldpr^{@nxo{Oi5rE&t#9fzR3Y@Vwyv`AnqD6Gj;DpX)@Zw1oZfBL8hV z6=g>z?|;Qa{C|rXHA~ss+6m~AcsX?|5}PRLUTHd_C=?x#WiVqgy(krKd*sKMhL=WW zIok*m#x0U(;r|-`UwPRoGQrEd_hwqXJl2QnB!+(X;?pcOs=ZnP?nDicR)rWt-g&|M zr1A*h)IkeTLosZ5mKxB|o zD!&LSMkn^CeIECI(Zx`aexe}SX|QpPC})IddubOH$^!I5yOx@r9jD*wXnML-gD2w8 z*WqyrtF5iQQ29}@PQc?N%r8PnX(?P^z=%X;Rs59-Ig z;dqz5VHEZjo<^%_%iJQZJduEW&677i$l2Lh&`bf;4G{w^E$#2MR+YG#$z)1d)lN@` zLm$l=y=Dnb%`bR@*x?evrUMb|(`u)!TY*61nCR%ihm%>H(*>2>W^v1*nLua%2}fJnYB=3_)I?4C6r4+@ht0 z3XP$a)%A8axXEwa=DCv67su^xM}PZa*qIIpHSBOl?@8^LxnOzCZJbiG1MUKv12xas zndL|>HrZ%|g@;EsG(WvQY<)c{=XTx}V=v`V$rF+8^gPQvQk}in=?jGeB+U5bYzjNf zm&RWT=8yOOd5Z-ub55(-fp3rhw#WHlJlF#j`&Ahl1_o%(#m!9-v(;g(C6mJD-QO>|65OzK|*y&M;6^tBn4w5b)onDh`C_XfuA4#ocM z(9a}tphN2$8|0~@zzZBqe8ntN$-^hKi+QaVFdrT%DcYPh_$*P23!ej{u^8tAKL8 za$2Es_UOSN;b7(B8(JUK_&;3(Hy5U%;q>(M)VH))=1N-F*i^KB=QsKB?vUR$ja~z` z&Zry1ybJ>qGn7I)h7EYnThJ`T)` zdR|t>V(_8@QOm!%r^N3YNvWsGJR>FIDN(c1Q!Wh^)%NT2gZJZQ!OS;NpX*N(d7`2l zt{nVd8diS&!cj-USBSrq?NCYQOW2e!i}yvM;G#)qxY{v=uN6^v|EA#Y0(bQK;_|d3 zcs5$L-}q9t5FWpPfvp3)B2IFj_N7jm5@IwvuggyQN^Jh)=kl2wKEDI>O&c6HGvxDZ z(a;|=Rwm0IvAhZK;}|=l36&^}Hc^8ARA656dqJs)+HRR!REw4_GzTwyhnxp7Zpu(R zj+xhs<`fhAzy$Za!K22`V!}lQd7|K+cd4gIsx$S6zPOxwDuExE2!H(D2hmALSn&K; zua%7n8X6jQ(f=JI-;=!jW)@{6Cj3yFwyyP%7i5LHe2-|;AQ7qwn1=?6>7c8#;&Sg4 zDH10TC$%E>u4AU?vQmp>H%&2gA^762hBf-mdZjPtHl01`VC|xzea5(@I`7YRILvU* zMG$|ijtZ-eu+SE3P3hcZ>;2*j^I;t>NByX#$|9L8K61t5&l|DC0xiOL!SH&;Hm{FOrinApWj&k9yKaoa^r?;DxJANVj zozEDn*d(NR%A^c`PwELsu)av(i+8tVX%Efla%sW0Y&;noHDr-wH9KDR5sk_pDR{Ov zI}m*7Npbu8&<0sr#I8vEmCa}2o2eVHEy?OoQ5RQ75mAyYaok)=ggb5WZMIPKINSazFG0dNx$vd(5(24?q@yLrJ*ZTqq9$V=-F#y+6&nXZWi@EKz*A@Ejgbu~gIh+Ks{>k5es! zr8fAXR6~%S60H0bK_&)OhwqSN8~+i_aFEbc)gA+hi3=2e^fI=e`U!&lUtx8@jcK*B zLGe)Cu|K|p$cUD0=P8#vN}#TjyMDdZGgD7T_eh%3rtWr530^W$n^U1_RQ0E5dUu>n z_n?ISuL5j9VGhl|mU;VL?ULMWFn%-J#0Ih>i+oj!4djrpnpyHgl0&*5K_~CPVC8K= zS90fUEh~)1dD;%|2`{csS{{IG7}Bkq9bvS$HP2ZB8Rv#D41(qh1ixvQRp$Rx^_<|Z zVL!Rb-ALA^h%7q&P-7ux>|hIJ$6Qn=sSd5x%p`W6j6%TLRdA;})}}MqRZ>BJ#j=a1 z6?cpr8A4^e#Iln79kn8`NJWImp}IM~=la02Z9h%SC`LEtBe5M>m;)+6i1>%~+xselXkU@Hkc;+XL+5geS?#9w`cR ztnJCsdp$W6o(FtpsNWvgmbciAaNtw8wu6v;{q`DgSxSm~zbgnN0A(>lO z;tV%gwX;s+Y+&tMyw@LUbg9FS7)3$#wD!@nbs{$r3A0XF3A2eKi@!XQWhtZ!Pzhb~ z*#8K!%n{ZKXU~5BBSZv`GBig+W&D*5?BEd>xch6Cj*;?UOc!d zV~_>TGcYa6Xib9dnTODr<2p?kFmBCTW)=w)(wXR$^kQQ72yK9uQTB{XSS;&(>h^oU zk?i4!fqzk3M#MnEAB$gfBvuLKfIZdMO>Ijr=QxP5UYKVG4|X|~9&UVRC^cd?Ec(OZ z3|XnGXrlWaO^91WABKH?WVlePKmiZ*B;Lr$xQA)8AmyXFeUR6>9mRA^jh+aSO;zq#C^K^kFx?ons_9zJqwnGm6OsP(bIk^>T_Tx@4ICpoKC_|# zDZfctFX_b?zppw`%@Wz9YU~}W&+*su$;5YO_V4J!+GvVyUbFF(xZ=Vi$vQMIS9IgD zTh(0oQl`sk)pWLkPHfOo_>g3LHByV@o$jcCadeszcREi@mZ4L4?L;oX3S#}-7TEgF zoq5{Ln-Ys&&fqZ_&JJixvxKsRG1^**)qP#f|LW%Fe)bDaxU)&(Rqt-rWe_nt8tjke zdwqB8oCFNk6K|#zd-8X0e@pH_2V)$5s&6upnFINYcz;e|4Fa+Z&!PiCcmcMzcMfy+ zL-HN*7LyL-X*hj}=d-@Z=8(0W0_0^}qTg6MYbepQx0i(J_---JA=)0nohZ2ga`H-dLj39WB&!p2~c78yzAtJZV= zYR;w51&^lUQ(;MerPhYwjGrj_K*xE26c#yB>SHq`YbgHXV5Gle$rx9Cn0BP-_49^z z_1ZzLfAmdIljNv>+Kks4b1~Pw+U~{PHXiNDs(|Ul4jnV58S)tJ4!nw6(H^=7^}7tw z=MQ5lI#IsclQ0$qPVM}1y~pp-8SrW zgGFCg1cbI;uXhXwEMYX9Q0{@o3o^!tSxcvvbO*FT_{omUKW{eiVx3Bs1XAnX5^M+g zj>IsYHE6fHkr{r399gMM*r5pEub% z2!TLPo~*edk)HN^@2Q*9nQ)clBhvWK=B56){8K*neF6G{PX@2u1WeTYs#Y>=xl+0P z9XaF@dSkSVvkPO8*RZk8G@HZW6CoNT^!sTteHZZ^s-vB*(<*e1%h-Y>WirZqU`|AN zkDIC~+_Mip5^rOB;+ZN`Cjiw<9&NANKG93iMKHzZ>qCGhD)_CoxEki zm>Zyf@ls9Ak5wzGyri@C>qQRLV<<7zqx0@VWbtL1wOVjj7&md%4>$w&8aTqQDBqZo ztc-iF@z;Zx5WhAH2l6wV-`ZgP$5Wf!mtoASv1DSElkY!!LQ&P5hL)ScrD4G^YV0WV ztx&LL{{)*}vlVtXZe?aI2&Y-TNbHs%Z1P*6iU&LxCa@qaoZ+I=o)e1!80>pjOk^9fv}L8W}(N8hQ)=b_%qA3;#1&5ZQ+Nz{z% z$a~tvKNeS(+jg1>F!ixQCnO-h8Y(Ke3r5Xov%~-DgqA$T#p1b5$vBU)A%wbm+|n`F zY=wU%0tw25WPkKE>v@)b!EQnRa6LEOC{xkl9?EzZp?fx4+K#Dk&@9mnldMW-zRleR zS$%w{>0VTya>n*&9g_%ORjS!V=eTwSuk`ySLPqHTJMRWloPS`n^2+-+(LtllNaMLhyu=i?R2#ySdYPO71Ka zhpi~3)px`3+Gyx)4iMnLx-sqClh+1oCwL>gKSgHW{|$7mNB~3AA+ODx6!H^K{wcY3 z{ik@SkK3F5;L7H&H||45HE$XTJ9L+y^l7mgm)8e}CRBy`T9k)O@+@C!qmzh&4<_x^ z95uxV1~9!$F)HZy{dS7u+@nCN~z@-;HR<#;p8-8Wx-1jOs+7pLtty;DCT zLUi37VddKcH};SWe|atOW25zO7O}=ZvhxjY$Au3bO{Bb5=LyWC?DI^yn`4Rk#hVoW z9c$gY^~s5(ra`>*>xbdMX-atb09(lJ^imjL8*AMAnM@fno5y97f_LqZzXf7MZ~l(& zqxQbYX`VP29C(0rarLN2=I;;bB77pbiW_zfMG|2!L}^N9MPangk0$zO-1|{EF+k2g zx|95k&!Zpi$WdrWcIm>k-z^Mw3xZfT3Cj=e=tc5Uu93n_j$mKU%+crvb3=vn6**$T z(e&U#(JH?*+O?7GCnjOVUXlf6 zcM=yZ-sO=vq25We`^Y1U7TrMXuX=g5gfB=$s8hbotEgl!term~<`{pIN9drWiKeh2gV>pA% zuqY9YZ=sWHza+9RL;HI2{fDBu&J)zgo8y1;lSJ4z>}0WOZ>dJ#>}PAR5nvi5IEbpI zQb`w(KCpF=q($*_;%=;Y4S;mr*qfakuu_g>6K#g>ZCaSMtCyET?hg9f7*@*P% zZIUT@as8=)Iy4n9ph}si38H8|Zo*_u!!Hq!LJED@RRgJ}W#Z|gBMZFY>~uyJdWtJH ztaae&vGm07Ni#tnwMf!;q;sQ8+ay9f7d7egrm^#JwhFF#5ytDfl@&UjNhBTB1&a&z z!!+LWc+>j|3g%}vL#GHG%+wqDHF@C&8}5EqwV1%4INK!IkeTjZuLw}7iNDXEIRtPjn4 zhGG?IR-Z1HRAwOAM0wm6w-Vl&2wCsxinh=&j!-crB^>;)++ZVz#C6PtQHx}0t|aTX zf>WNbStUV+?Dw%pvV24c*+#nmAe~N-<~=)ZMxkQ(deWSgR0?iFZ*Jdd91)_AGH=Ty zluho`N{t}I!v=|abHPCx)LHANCK2DikqB*%XI{OSXR07r(cb@;-+r6~g&VixfvW~# z$KEbWymIh`1Lf}j&t>0Lpp^Ndac#~T=b%Gy?;vY_E}@O=#S%=^!p-)QbodZ25_F6o z;4)8Y6Zq1``|Jo#4VdKu3sjvu^Xb0(XZHNy0uS!z=51j}f8yBpB8V)Cg8;i^7|!M& z!ao)(t;7zL(7VIn+4+l7<6Vu8$|<)+*eo7*3@pyHXgeu+qRFhB-K>aoB-GnyFSNjD zZE0j}Iyg6^^_OPBEfkT75)pqkQ&-f)(oeybmzRs#K+W)X_v$w*7yN`2B00VD;}GSt z>z-f~wpg4IYgC@Kp$f@xvhZDYq|%S_Uw-egaH3D)hKqP=`1U=9e~!MgnIzD%$Vly0 zj>z#HQXJN<>trk1h{O+AEW?)b_sfEGRbdSKu{sdr?;bKSnC{056p{eb}h6xgYBBZ&wzSy-o!6+{j(55 zWX;v(iwu(nzAA0Kxa<`v7T>V>K_|Fb0yxs-88zO8p6O#cM5O8@t~TO)b#0P!YAIAH z=-tuEuo(q9HzXiI6CR3I78c*;MLJZs-5dE48=^NZo1y@5&|aHZxr=sbjKfOg=h@2!xZ)fOYO@4Pq<>} zGoyK;xttGwF-%T_L}qm6s4GR@s+kfsYFlnhRP*Bw*q?a zp)C{e^M4ds#FMus4cJWOxJNuX^9}a_$r^K-^7x)6G2&YSVZ`3sCy2Z8L&=_#?uMB? z1|v`=EyMw*y{+~;k3&0T7cwdY#oXkk7}PVZsTgEFHcyU72frDC{b6FzdTebZ9WKv?oeAK;hcm@3NW|9mWd}AWvZg5 ztLvM6=gO$@OAHYGB}Ty<;^@akQ$?R=9?xm$Q;41(J+4ic=O7`dZh1r5r0}i}^DJ8T zOCjjTUL{%StoH7&QcpTnE0HD@v|%Yfv+QVgVmd6c$#dr~RF+k$d6IICW>oUDz9<{E zhGo)!PPFW?y&SOsvw_95y(7Dvom#FMn*MTT62CqAiZn>1r0Eg}lc?*CN7wqPD=J

R3s*&-Y*abyBlYrcvp?vWAJTtM}fimJ6X^S`$&A_<))qBUOcJewll8V;5 z;QHVBR94#7>r2S-J-eF{3$)nOiYe)B!6&@pnwnUv&32K!Z(v1acFx^?7uR`6WG7?7 z1a2a=sspl#4Q~weKE3A1{Cg)$?bd1HF^KblURv;xBWYuC~}3tq2X<2 z2VeCb$-xF}y!l_^Z8&xr#1+gj(ZP5h^Z>Qh)r!m? z6i?4FsBtsy=7=RA&xD9#rzumSfC_45)Qgf|wqapTEhx@qKdj?vmbXIkQNWix`@B#!nHf?dLjq6`_d_nFU7S2SzWK3osDP= zhbQOiL2TF@8r;RR1~YF@}vh@m^Y# z>XROGic?3wnqz!RmU}De0y%L}v3ET*%Ft53=I;CBvlx;o=GmZ4h_DD2X-o{-_*vGX zX^;l8T*G1iXYlfB7fBa)XzCX{Ct+svT>f4R6e|G-Wa){Vs0=B8t zjfskX)+MVArrb!5+DX&qdL8q|$d9DI;%N#{^=Qrgbb^d-$6{bAww%6FU#8A122T1% zHtFAANA3BCqxf zL+6R+DTs^0v@4e+1X}dO@Z-(R?Fv*&f-n7x6+T{h`y9sg6xe(Up6%s1^ZUts#?H>S zfk7t!J?@QL>DdTX@j3eE<*f=k_yUr>;2SqpI`!-b=b4MeEqh$eNBA^YUx|+fwtFx1 zGRA4fNbOLi#>+K2RC`lI0zxtg1e{Lr3^S9BH_G!dYL zNh_`kyyFgJ@LxATO7T4&S#gX_zoLMo8)C#Okij&Q)7#zMt@nFyS)N0 z6PX$SPyy!YOgi=Nc%3$}{(yN*l!vJ;cj7PTR`lh`M%Ug2te3bm=jBHqK&8Q!}^bP`JOy$jStxv2JO2}GDZOiLJG zR=U(Ih1xN{*XCw}Lc<^SUVBQ|x>y9yY76y$A09MNyxxV2f6k}CFW{__ zWMK~L*{tLy6rBX63hr&-!W0iyx9Ck9;4|22_?hn7fBRVCuKIuDj4n(cXmX9gzdr#J zg@jqseCPQtROTs`D$>l-gTFW&Eb(`vKb{11Mu{!oN9yaa)nZy1>_kjqMp(H1$ zMXkd@_+NiwV)fqkp;(kE#3!nu+jA_pQx5f3V>)$o@>r-dF;R6@~R5VDu1PCOjXWpW&U1}lEJKb_7|5dym zG?H2#OV-zU2%xQVjtZ&Y(P|)e?M1VAsh~JI`P$YA>tdRU9~B$ z6a&exv)^b^``N%$@oeEXr9z5T8M2I>upo5A zY#HV!ZJueA=!JAmB@a*fSGfsaam6n$2byDn2odS2wh~>?Ey6+?4|=NWxHNluI$ocL z9VTZ$RS_V0lSOa*L4MG$Tmb1nxPX9zgJYNC=CT4v13gO-;fA=kUyW7`&N~v_0s-JA zJdLcuK`GPPcc(;^)Qah#9Ilt`J}4mXeFj*OI+|QoS^1~u=UDdUf(Q3k@ZE%4c{+<>49I$sh}rP%;?G;+-{g9R{FZRZ z!3dC*KV7biZ9oYn3`qDT-V5JE9r>m@&pO}QANjgr%GmJvw)2gzQ1B=J8^Phyv?6jh zk%jMXu}?9yTDyzU1?6H&pTh?cJ>y|bR#Vb+4h4uD&Z zgwTr9Qt!}(*BUkk!=9-~&>xd2nFZEx-hFdh7u<^|IIAjTBi>ZCgAJ)A?N2`4hqOj& z&p}pSbG)3f&Y9387-YyH)WG3Trw0Ydan74nRZpNw25?X63Eg{99+=o98ZS1JA24R) z=|Q-vdtxvOauKkhC~C2Uo?f=p*Kxcv+NDyUG9J2E1YtyZuu(oDP^i=lOeJ}{f)S8w z$Z+r9)ql4OgP2b@IpEMncCt%kfa=N@$!RH_U(~Yi^s&B`D_7ec9V;CKjn0}}y?(j= zfaoO|`=YKcGo{1AG+W*miIh2Kh3ZJNIg%2g)qhzH)==JlyYb`OcHF%pgnXlQARu1Rqh-lwf=>k0OkKJbiWi}S*#_Ve`hPp7d zz+K-5zB@BAwO8uL{NK>C`k3L5?W2s8xT;=lO6*CZa!yJ-6$%y8jj8uLQP-mi zT+|VInNbC}rVY&A(Tr}i z2K!Q+haq5w^;WoyW%1UfCr`_D$mI3tSJ%jgXQFG@guo!hG&KL}4HD(-DZO>LUMc-R zXzBkU{4Wip5gi?EI=rZPw%Yuc{4sLT^8cWgI$mft6f^N?J8P7VM!H}(Tq+T%)39q- zDVsh78i@F)eqP}JhV&Ld^zlG_}BXrmW5wfiQ*AssV`6*Y>_iNiq@{fen zX4%9iOQGX!(mJiC_%tnbafnl#+f}pcRdHAS+lcOQT=KD+BnmBAiEb(hCSvH|Us`k< z9+b6jK3I;J*=likULBz}?Wh{{JN73uViCpYIDg>Z2ulkL<;{ASKhdq*3uzIz#jGsz zZa4c<=wVStMC~uu9!RwjBM((&zW;5=fpvFWNnPBp(Xq#6kL+I!{t@QrtwlJ`x0c=x(T8g!HtyC3TBn3kJ)5s1$IN5(`KQ=OjNbgAS7 zkb9ln`f*LX!oYX}5_6yUxv&-k)MAlhLTa*}cNk%lS1OXx=A$NYGzRv*` z+~Pc(A78M0og3YD)QE$b6PxrRPu6G`N^jTgvbQJ+?Wq5`XRGQut;{!wX_c@aFTF@6 zxHFcJEV2v8jLudpSYzNK)ae?kb!$>L;c-BXFFEyt4>;MIa-bL%P7?S>NQ%EsWPJtJ zm$Qukk=y+B7crwkOrEemXbl+;rfQ;a`-Ry3*&k=zdIVWYv&(+W?`iDVXF#)ACOtT}X*N7*k43?P>JIi(qrsn<;3+G+VX^5MfKe<*kj4cGtKS>#q> zI)Co=HvBm2wy6betI1?NB%K+VQUrl>d~5m1s`6tMjXx=u{%6s^m#7;3*8a^+qdVbA zv1**xwN@wdi_LCdk3Vk#08fV)4mpxYYLzQ_aC{sc7WM@PkJ;Le8^D!9LPLi#ZQ9+B z0sg7^wwZ5!=|VCRm3_LPP;h2ms3BkYJ%Fsu@m8mQJLj~TVw)~lZFS0q}~BBNg(+9L3Nmc#sLQ^PtHy*2}%1gxWs9{@iKf=q^Atyq(_5d+$dgE(V%! zaueZ*y(T^flZK7#kO`>SlprE9_(_Qnnt;9%po}d>h6fCmXR5pKeoQSbU zA``9XwfsM|pcaHli~-!NLf0<|ZO8TtXF@Pw^0Y&bKLHWx;|C7uR-IPJ55QXdEya z@YllytMvR<0*JL6kLdV#X}KudG&afq^3DnJ*B8Bi%+A-_=`>2?>M4M`m?bmtOUA?J zJ_aEFFTYF!XDHaCRioGM?Mp~VC?zc&Ez6hfl$sX^mVXqfU)o-V;{|8iBN)b;j(wpI z0dd&sD=F#m4TySAw8Fue=)W&7=WT)YO1+7~6c%Owd8ie@evAp&KP)E%JB$ zIq5qi3jbSgV^Kz!z)#l1)0#x=L)bUReC z^lt{~O;qRnDgQeApli!*zoH7j&J`MEWEtZC@!ITjV+jD#0nhgq!`h5NZ!i7OB!<>V zx&;#9nE<0)XDkZ*n#JQ_Hmn`nGWY}r14cxF#SdVUQ7=(oH~tOS()~0xz^3($+b$hX zpU}5X_i7P{so=WB`qCEF@W!iN{*kidK33~4SJh1@8 z%DGo1otM_>brAy?AlRR8dDIGRe|;|jp0bsi?<{LlGkrMU04$XSR_g!;oq9628_`ut z`NW1Pwb+C;i-JvYgFKMa6kw|WHcxrv1W02j274%TbtXEb<8<-Vwl4$*co-nfsfe6b zn<{4LQjG@S0d@al_dgRLgTEKAebjIhGo}Hs=S4fIqkBqgofzM~W1$P=5K;RMy@Leo`W9Ij7Ci~E- zklKkao5E*{Rk%jozSIDTZZvYaGw^qm0rUeHN3ZjB@#lL0bXJ}9Q($Ce0#bB;DP;h0 zCl?Vp>)>dTZU;zkz&Dy}EhgL0jM~&#;@1SgcARvoPweKy$iNVN$Pq}hxwRQ8R?J`l zoFKg>>w;ES@!sj<4GlfiCKz2@Tarjr;xQpvW?o)iKxC`q21Z0O;C&VLGJ;$Fl$B-y z19;mSy7(+J83de|K<9unl74uJhKx*%P7;vU^gB<;hhlgsLMuTR0Ms2;xekR(j%TgH z0o=TsBp?k59H3ZN7CZ2?^lu~aR|B!pfr}(j`2JIMF%4`7M5srx68hf9i3zc>v3ULe z>NbwtX1e{nc5?jwp3qe0|0+0SZUA$1M>Y~wcl%#C$IVl|n@Qm7o%Ufpf52ZJ$vN(D zC>A$lwrtPwKV==6IRBWwq5{TV)uOFpc}(`Of3+J46~Ms&Y&!%Uu4vrgyOTvs$^X=Z zWFr3016ePsMJmhdht7Q8=q`WS5u*w(^ggBg6#l?`=&5fVja6r{L!Zj%;MK)aS)X0r zg!NHuH(~Bo+hY2k=P^U{nr*|c_C{jKvsvX`SCo~Agj83c8KpxC4!lL{^uel^KsMN@Dsa+;$nywA@I5ZR|6jel!J{yP|v=A~3X{69x2 zM5qR5yL$DraaRnS6TYL)6Rb9GVu;(Z&`<3i>WC>rw1ypEwHP8Ieq z&ErURvfAs#J!CA1ydlb_N19exY0-jN3cZrAJB$r~7Jvxw0;sn)f_vk)sogSik@k6{ z%4gOtFQ)#VJlM&sBVKg+wQTm;+B>>q7PRCvztt5=cKd^8os;zc5sqKwp`If5LVtEQ ze6HSqpADvib>$mudn_>3%E!4%qP2B5f-Vpz8-LBe#HWwIigyZxoRpDy z!uIXvFNN8!p90roAIvg&F17``h+Hixi|X1s;Za<5`opM( zjN8UB*&Q6KLFp6{9ryIz@ry zM7h&yfdku89Vd}COi35*FG)fYdYYI(mrDzG{focgD8r%*d7WwuZN8B?WxZLC zT8Yyd9G?TVny?x*UN^Bu#Ax2S2S*6*ZVY=J+_vY&&;47Ob~E&JGcD@!bZv3PI+a01 zSwwnqgT(I$ffSAXtu~IsJKKBi^|Y;mzu*BAz=Gd6xJKoTI|hFr(M2l9PLb2ZA9+U- zW;kM=4h;Lp^^ZAT59AxNy{Ihs`t!}w)^JDan516 z+!*}3GeL*6f*`acia&TKZ9nN}Own|ocExmmWX@w>3kR4<6u#FDSDi)-h+Uew^|~g_ zE9h;4GyRxJ#xfgMx_)B6+V11j%5O?J+-~h$pATgUHbEfDjo#4IMSrDzFqNCUE?;aR z^nKe7D|Ed^mUEajz0t@yQVpi{d(O}Ha9kkrQ|Atndcw*0h(B6q@#MmXe+#FlU7L_DdMB|ATN0*F9p3c5poC$X?viTj79 zvn9-&(=^qO&NamB8mEWvzp(SM*V5Q4w*5Fe(f{Z(xHa%*+#|~pO&!x>n4mhzI?&1Z zSoGC)Dsr%t@xGWvB~9WbTAc39P`|oUWh|;V$n%Dx?Y=5_{T96cVGi$guybu@qX0F+ zZJF)PJ`sW@W0)!LtYcrd$HJa&%Yd`VG$%Ee=vVZDjIR8~f-VC`oJBuY# zNDub%8K>OF2V$_O)G{FxH;^4VMxk{Y&XS4rfiP*Kw~v*7N}|AZ<3tmMpT33M^|6|i zJ8%*fL8t;SYt6X&E*B%C~b}q}m2ehesu+_TN7%KRS$j=)$CA zFc+FB!mEHz8q;fEHDq%=%K37wEwX`_0jcg9G;S6*{QEd;4=Q{|?^@s+W?8C>xh2Za zQ7bGGZ=nP5r7xrJ!v@@jz|Tk|jcqwb>$wo?MjrJ&Bl@d_8vN>?-AdeG3ex+0J4at- z2u*uo#Lj^QM{e|GtE6cfnz_oagoRv4JChK%3{2rv#<;(ix}x^8GCxYP2(dO>Zt*vBt61B-INzOcEAAT~I!UtE6s-1b=7~*I>*>h*x^o%wd)&EIe zuoC`TloUDew5wUIgPxwAzH1DaBVd_7!k%7J)JF=CwrRJ=O+k#v&uS^TMY&vGFdJD_&<8D1VXCW}>jH2kHRWog^uy;+bWU zi;yj%bmv);(N@;D(b^MT`6l2Oc`WHui)(hiQ+iaoxvn5K(bdju^C8_( zySlr`X5+;?h+hiflyDa-Je^QQ{2Gtm`O;VWY-F9VTj;-X3~A?zHjnojR`l5ZzQS8a zYiNOSm52RmUFS>Dczim3xL7onrIY5w zr4G-r$Zo@Mh3a)v*rr}H1tHwAEPtz?cV+82gC;MwP6+a~s4HW5iVtJmjP~+Aa<1G1 zF>nOoRZ&0s!Ala<`D0(kS7}~~&?3^kFd`Ifu1skSs^kf_mfb&n-Pefk5#uPFW|6GBa3KHKdX9-n+gy_nXeW`hnazL026l3;_fo_9{aw+@ zf)eJO#r881*L3=QrgQMe2KoMQ1ovGQ;tt6X5Rm-A42=;}hztC%tS`X4=CZO8O~A?6 zq(fM=A&?z{aY3ol@1R7xK%b_9_m+GvtSK_klS@sODo55z8j%`zwR&}K81_=k4W?yl zqhkkOf3s1ty3XM*mp(6Ox)dXtw>JC0S9%hVghP;zox2J}F7Bw#&VK%NoFIioNlaRC zafVnx#yB+$GgLL~K6k-AymmTaZ9~5Ny>7EkEM5bi1_MfoEcxdpTemOlo8BqlBy|DrHZ)nt2 zi(>Hv_KWsBP*%FjR;W}krv5N`?JW|VyaSO;-=N%40=_Cm{H<#x$N!J|sX&Us-&%k; zfj07vKmDB3Y*6?I2hAEcoa+xqP10t>Ir;v=os4~%W^`HfnxlAoRKVD4 zCaJkO4BPuGt<&yWlCCvz65={YHz(gd6e!ZKW66a8;5&B&uc@K2C+y%W=Zk6%@RVI{ z89Ns#C7QWsqQGO%l^BzNKU}h%UZq~<-R8ij^{p9_Ed=Hg0|C7lV}YfxM5isyfDRkMO#^Q%~%*%iUz%^&EwF6gvcHI{KZ z-01Y`?`oU)&wfH>yIjkh+5GtU`!3WLwb4^5woJG z!#juoGkS+)X{?3ZjaJW_&Jj?Q7pLEtlho=f$~43N#K1h_6T4AtSaSKGdt`jxPGw6;<`OAKkIfq=T;#0|NnZ<*N7222P z^#0NvCWH5#o}0${&_|HzfFZ|T8+4GDUd;(cl$xkR;YUct!G}K7R(HV-S|F$S=#4WK z8*Gc=a?C4mAp>~|Mhbtz9c?|zR}MG5{D2=@0?;0@DDS?j;~hKIez?cK66dy`{Vp`j zZl|&J&#TN#^liN*glSMA~e zMO1Y9^0#;{w+m$IR2flzGv1aw4MVlUEo=9$L9#*Al23>v!iY%1Xz6LRdmIBknJs`$ z(`3D{D%c2ukDG-U^Gpn@x&}>>{nR`vl|>roK)rpnQ|{Nb+^ddD*=H4|w}$TL4x2Y)RvNG=fsP#X~uqK*u^IS)t?W&)2ho-@lxcC?02kp`~x zf%-5a&*faX`Z1OUYW;(Wb3p1@>R4w2$cE!Q1Fiv<={C36ZLQ@Z)jwZTaljDz&5rqU z5^SaS*(f&~d(8N)hISIURh;WjRZ)LU=2RcwyA$e6k#u`|`=q2xgaIi&*wp0NGG_f{ zgB(q}fbp23|A(=+0IF;2wsm(xAUFgIP6+Pq?(XjH65J)h-QC?axD(uA;ZAT5ws3os z{qJ-4z3-m$>ZNK`s;H%7${3@!*1uL9+zFEo)E8!tkt$kqE`r7qK(BvwNf6bjFo%cw z^eW^YPCLPSl)X!2(|$i5C_0-`4^mt0{YlDN!f-4a#6w!`;`)hHY1$L?TC-<(kQQ%0 zSJ16B){Vxi17RB`mErZ#xBvO7A**_FDBEEDAQU+s0#s@HtT zb_WZ?T;fVplpALFE{iP3(@ypURMPc#4HcJ>ZL`j^JPb0_K(jnV*>#i}fqVDLB3@jd z?S-%Q2`4`HN{%Yz^lGj?r(De3vrJ_~ly@*94`0H74!Bwc$wos~Dp?=OmDv21M0=mb zG;0&TsKuLDWSQ;OHJHaL1@-Wq8r;|_@Md5N{`vgek=WCCCEIZ^6p0e)V}y_|hrjgE z)^0+mXP9NDlmR#Emf42~g43vrt$Sl06^@!BBZdTrM>68=?pvc)Z4CnAec|aRA37vJ(aL=-D&z^REZD z!i9|LulG0B2D@06H5FFU0!P{lZn)2>X?zs02<)O29i#-COTDU?mJ0Sgkf)O0fv!W% z>Bl?{W~xtSV#zAwABS|%pH<=o^U<3*BHWm#30qOzc&czMn0G$sOHU?L*joO2oyUnMPo zZkR;U;?s>7*&46~Vq|X2D`GnR29qdk>+diwX6ULuu>05&!uL>@mH1<5?c#(`d7YX{ z_Pa|G3uO&wheg<68>q4X7v<2^ypB?O+o!9431%*)yiabi-_k-Y-U(8ve0qHHg#MuU z#j@%0kdpcHJZK>^f*{(;ml-AL0!_7Z#NuOdYp!u2o{zFd zdTcMtRcN8A@2|l+{shdexuKUNeoYg7zX$pQGfe}}!i0mK=6w5_m%?jIJ&@RU@(*Jk zjkZH{XOgylk0ZG`IwN7o9{G(pp;CHO=;J+c8(~GQ$_Rt+D^vK2&?VtWHQtXJ%FtW? zvDNmeBo>{<>ljgw9eMXm?NA|4^yESfYcGH*|6?LJMm;a&%IiayPlU{AAD&uaBl&4g z-B_LwdID8KI5%74fpZME0fAtx??tt_o@!#Fuqv8N@MQ0V6extf5m%{ARW*{23_a=W z2*L08GQT+;Idt2$j-lY)8gj%Q*Jn!S)CiLG&9Zfpv^Ztxo{GoW692I|tlS*WOvB7?(U>38RQlsD8F0knxKZ69VL`;VTuAx5KCMGV32 z)|q@zm6GAR#v1T4LF4^@MyvY+hMRULWZ^5X@OF|ma#e3bQ`+ok2V5)tk7x!Az6pya zVnhC+`{vi4NGf2Jn}^2x#p%`XX}Pn4UO1T7N+ca|ust*RoUzbslf^IDC9q@aCtVWL z+E~>uVcbwgU)KyLcUa4p8kRLt7)N8kdql3?Y*)0`Xa9E@*l;+G1e?w<@c^*}Hc*14 z#ncL(<>5UHlD(JDw^@oLArsyL4*8k{o-z()fA3JgyNOPOfiq~CrP#zgFL#(fRC|Wf zD-J#K2#HuRU)44*OX3V+%eNm#(Djq~f@^p?g>}`e_9uk9K9wkpLauSK3K#Z3%C{Q0 zMm%`aV$fs-4?}3>wA&D}-)Q+$9a^`1yNtJ&LEcTL&sc4EhvS;vleu8Iee)@nm#cQ$ zj@(eDe-uP0()IlP#~+L2DP`hLv6Jl;_9iR%BaD(Cso6h;D))Qzrul-*jMVK)YYRu9b7;*cAEl1SD|;vMx1e+mth$*g_) z)OC*wArU2fP9n5Nocj5#M3_M*2?#q45if5l3Q-2#4+YT;_)Sqh&%B7`kVaOb&E(0h7t35eGSRq7@l}H`?7HsHmvmv6;o;hc%l)lbp7z%|AXO z;fMHo+iZOTXd97X0c7 z>@jFGiJ#6_+g#<|wjDs%t8OQx4?a6GL@88sFY8*SCodDA;h<;zxH(-3C(Itv`aq49 zH-ZHEZtJWdJP_%lDqo{7z-&1y{xm8q?iNKA%x17ze}T#2L@pYU?Z;7mvw70dl+C>~ zn~}>?;J815PHBeQ>;CfeDd^_pdjk#GYBp|W~xeizAb#>l--j~Yrxke6N7Z zC^RYxfI7_|&J#z-*9h)#jf+6Wqa;uP@Ky(#Ec~)2beT5DBo6 zBtkxa`b7AohcMn7ib@9Xf@}auQoSL#ANZWW?TV|CC|r&_SC~n_zz4vEjseV%p|nk0 ze7ybd?stHA2hQ@j>t0XUtL}O+1U?%e`+xyG_W~0TDYM>GD`uOY>j#d!zvm(nIf6g3zlo&*V{KtR7t2=|D5uhRdsQBER7``OHT>lW1$;aDg88-0ABR` zc-93x?Vw1bqB-qop&GC*;$TmNRa49tY1jY!8e7nG7!?>pfWiD)ahI02JLH$dMb%k^ z&*wuuAS<$_nVfZSWr%2(|Bd&b2kz3+0S*M9u7AT=_ULrjmF&nX52*sqAu^E>BMW{B|ICy~qLcB_Qw*06i9OiX(R zATVGs?=xI9<{gxUxH#A_U~gT1x05L1}4tPL~t%ltWM1T|J$wZ)X~SW1E?RwE!f$z|WLYfYOa({S29FbtZsJ>Jw+ zj|bd$oLv&Px(bQ3IuGgiesQ0U3dx*VTu#&(g6@Ub+VZci8UJ(CTQws(O|Us|qYQ9) zz0{8jCDqEeHmGKdE_a4CZ$raqDZ3eX4*Zbn(KT;<#SgZ&L2H(Jq@P zK#IRvFjycFUq504Tzy-yC+~pUTGWrXHdg1c;e#2SqQl!#@z}3`g$6D6QJ@3?d960p zsCBCL&SW%6rbW*wUsQHrYHf1Pgfvm|`p<%4)!XP}zH$GQ1~Vc$J!GB%7jYONuF?)* z4WJ+mnarp6W_D#MPJgh67t%3xT9kEpTd5B6r-<^gGJBG@^^VkWP(rASzQW(zu&G2k zul;pK^(}Q!S|T+tF^HqwV&%^Pi#6v}O3R44xe$-NTzyH(MRvk@L^I7RwsETS%G5~kJJ1PyhIl)9HiCs;18KTOf7gMz38nQccg}QT7fg9eYAO-+&IF6~BHLiel&Jp=S zk-cVqx(MaZy=}jD^AWcl+5cICe@ieqaYJE*5Ef|J6+P<$9*o` zmO>3Ii7(l)nz`)g53pGj3o=fd=utr+BT2VR*xq%mqbtQk4X4S;_Js4B(7qSDd4{Tr znA)0YM#)d#YM+&vy)*`S6{K7zMzjYXrQpDmIlyBP5TCx=86U1~vwIH-cP`tI@X#?`wbKGp56sD~4UP7x!uX6vA zrb-7}T7lB)UZQ#NLoiY+b{~BT5UZH8*nNntt zXd_1|h0_d4)zsh&S;~hnb`dwFo7AgIJ6mg$!A`3cM(M^o zD-24-!=yUtZM-N=oU03qIx#+d`lQ}qnRu6M0Gzh>q`6kUY=Aq+^N=P9i%HDHRo$Ic z%!D7folTXi(OuA?SA3Q&LVJNm%^dAD*s?JkQZb;+{{|Vnk!^`(D1*oK6Klfcig5Ba zpJDgh%0}oDc2sFcMvT^1fhw1e%wrQ;(%)Y3h%r!xtZ08lMZkCFxK-fm5VVB9gy?yr zDD^y7t@rg2Q1*`Wz_OSEy_Q92i`QC2L(CDKIa8wt3sYD;ja`x)WOW5TqNas?FLVC# zMa#m}h!WKztcY-)^CIPQXPAT1I>YL0WA7N@q!Mt^{FKJnH3~wK)p@Ek<^=ZUR=zOR%pVQdL0^dS&yMRLui-QVH4W+a zv}cXfi45tlf)#6Q+p%;hHv2wAmB>oxpaH^Kp)&P!{KitG-3Wx zMl*1zEM8_bWj~jGMzp26f+IZrzXB{+jkbMT@J4~bi}guM~mjkdZ&I$ zFvy;iXqXKA)Vy?ARq>_;Ma+>V1I#e3UExsRB+rYo;)fK?9(yC9;j&OOnMRHLskyX{ zVthXjH>e-DF(_Z0SHi|8I!az_1@0x#cM{SPCvGg^iZ-cH+WlPf?i!n}eh=3#ks+f2uN%k;s?x#4vTC4r`BO)n{vI9Lvj0`Vr;^ zxgcwzU~ty}e_7c*K{2jxWAF~n-5Z|w-aWX>7;ddI@?BJ+g4xY;!bnbSfdm)oW!p|< zl1szXqy7VWi#3VpQn?9rnuT#FW=E3QSJc{yZyMsmx6+!nu&P!nW7E~$2k{(#+vb8y z2F4VY@ZN4`qTL>w=7*xptSyfQzy4?~Rx^E?yIs^)u0IVMCS2Y;#Cncz8;mmooDk2S zl<2Goq{`V+7)-I|3*kY-?I6B}bX~an_SJ@qnWPkE)1KQtMQ}!@Mm`d*&K!V&RHMsI*>5oOL|CGp!(Mi6*Dp6dN+g!MchDliBOv!93Oxg=EQO~Mk==WTDKizbXwkNCQpduh5(_bhi1QqA%vxq4J&tW zwc#8t^-db(M#CtaeS2Vr-Yw>yH~Js~sXj`8b1vi0T17ZnI0i+WUEMXk(y9T5I+qXGOhp z?(=AraIv>bZ_l}7!`YNM$m-NiK*K4z3a zbgfG<(#xVdo~PiovZC{dE}kv5riE`A7_2!B`+X4JOg-)_#UwwHJ;zy^#2X#co006z zc{W=!mcl{#_iR4`txgBvOlcN5Za?lMOU9HBlb#d~9F4rYWK0r_Tfpj-oBFKAt*rx_ zD<92dG4@?IpWp?#Yc>IOhDP8I;>SpJg+AzfIj6A|1ZUWp$-*`*NqwC>8V~BW$lk*= z5IFKBb>+p5g4CZ_&-KKY4x{LYc`c*A!&G1G?CCZgG9`1iWv6fvmfP zyXavG0-i{FcL9coy!|XqM(b7}#)gCe;c#>ioY73ytr@1aFJ2($E}S4KjUOw{y5VSE zZi%^7DJ9vcq21FpWmMpx;c$uzxiN`Cpw>$M^(A@b$ikmFBZY=$o{WdF{v3r~g?w?u zohoKrCI}^jiPGm8iDegZopO?lw0bc-1(uW37_KG+hh&0qk%hb(;{jvl=XFF?lxt!O z5=2k;-}~8`zi6v*726opQJg+?)mMDSMe&e9*{OJL&gN!yRw+O&Pi$wH`vf26EPoXSveNk)07pl8dNt z#z>;DdLi!Uu1t5xhESw)LTsr^!~?kS!42o<=luHi{C9?c^0%z<*{lhpMO1wFtSOZ10N)RFUdx$cHI%*+M=l<|02X_AP zu4o{r$r)aRQ}p$Q#ED8sMshF0qb!2yLAi7PnV}?+hatER8YXB1dt>eyeMa^m-7e(^ z2&pSrfJNMWh*IB6K2AL?C^juhutZ*C$2QvDtaQmE91E`#b&qZ~1v^%=iVZPeS;c1BANUdwo)L(u4NA9J2%R4O@P!||W7>+EkbL*ToV)Gb>-NUB&&sK+i=?{ zOFtg8z9auqD}0enR{)R0+2Z;B2oYwUk=y5&*i_L63ZYS0^QKJ_chw=5Kob4kZpoci zB{iW-IHFB~^xp~bUM{(N=0}WwJtsjLjSeGAAMg)K?DNm*5h;yQJ@JaN)C7*%e!z&U z4n%f2ru0zB%QlH70C*RJ{n4D0HeWGrDIqCJF!P z!Q5ZyF2UUsEACXap=k9#I(W(GEu!H>iE#D`{oKL_?OP9KL=9zcK%zp-Cc=>*T4VYd zCMoVzw!Kg9Z|*Fpko%J!fWioHMCtm-Hf|+o;P?eQ*&0lbm2nTh`RAiP;41pcZPj=v zLnNFX#&e23%PrT5j_AGcc%00g+E3AXIycvv7gx$*U8^e|o;4J3Hn@Hb8>Loo7(%}} zXXgQKUJsh*2zBhZ5P+ATgl_-^e6pkjPx$Yxj1nRYohk^hsfPNNieL4>@PzE$7Yq?$ z8XC?ILK&^)9;>^aqqV|`ZoL0_1UD>o0u!__5jld&44*{)Ec(jmk&rLCajxQ6Z_*A@ zqrm548xz@fVC}0Zf$t*zhF4k0sFe5v6zvsx5m_OD-7@9wx))&}iJdH=5f7VKG4EiY zNcTKexoruz?Nh3tm1S78(8FLx1LjIw-!qUbW#l)H&o>(gIEhcp8hH%Hgj*b~*`<}6PF<(A(HBgGe+sP*2pd=_tO74UA~%7iS;uWSY&Bv z-mH+#k_}2g7WyJHsyR!EaLII!R?eI zG@0eg7o84+JB#1MG|t&>5q;jl6605SV>MPmIL{%N5mj>IwI3Wfs}ezyNu9iN2jQ*H zXyLx6K`1TCLVN}N!3OUrcu(E1u%6kaEmHk6zCrG@fOcTmaGkLSS3O|RDMvZ4YT{Ij z5OsP9{7bP42Ws%4G8M59yS^vd4W&M?DJ?|h%77^EhK2IyvQs)?WJ0lscU78$35@Re zQB;vi1tXyGF!DHbxjspi2*f6d6tAUqKwqS+v(mM&ih+45s@Tcxso9~{XLCf`d*`pI zrb+`LF5|q0)6c|pT;!!B7C9?PLWxAGG~sFm%!c}YaoBv3GXg`L3)!dAJDHV=mN!{e zeYmHNG*Opxmll=4g0ZTH`Pffi=I^*1=X6coAi2YmjGx?{ri>h6GqWIjMP%wSR9&V# zlD@g$0mj!GT$YJf5zW?yt0bRcbjn^v<0zx+Ze-@6%l3#)U_TFrbC{k?E8${Dz)fF4 zRArtlcv9@4v++aiJ;;WE6Cv00Ruw(PR;G|Ivv z2&us6t7o!8>X6Pdse`GfmLa$-Ke}nJq0eB&QjNKC1%+l1zninV@Ggs6=po|e0-{bK zRQV_Iz0&R1sf3Z3W9(=CMo@7F-fV)UD()*-=p%}nVhx*VQN>9@1cDP1lm?^|h;g8l zv#}ZOUOVwuB^^gV67 z1O6^_YCFw<&NMfhw?$Hi#~h&%NJdNT2-c;QCyCjXvm?pxfE$m?)5+fKnYW~`tovYQ z!SC-%Rx?PBVD~uvT(b+e{5J5n2&bbD? z26qTuISj)NbjYo_5*)i|4^^S^3GQ{>$i1xL@8Fui*P7L!Pp>EY;;7Q3Tt5xp7U`oR zl22Feho2^t>?DlPr1Qd?;w*H#Sd8-=$Ri0Y+BH^lYyn?>R?AU_-!93d)O(HXZF#~o zm4Ikpdm6MsN7CypilbPR*5NU%E|r{_u(Unc@o-g6HU3lWw&(3)T$ZfVFWxYc*giUI zKtsci%En7K42gcpVpRop_?>@SkDz)YJF)&#dxU3rFzY!ply|05E;4F69F^CF=-_67 z7wuyj({Pi#KAXxn*QvmV!d>tI6`FxdKWXrku{y*^28>k^6j@b_gdvKQ$^)KNN+RSn zlqQP0YI;5s)YD94=Rz58P2!%&Gi0^m)*;J7fy2&nA$5zdD*4i&+Nj;*y5e>Bl&^0q zs5prEuZ|T&;f>+C3#GHz14e6#q7%<7S-KdMtUAqKaok$0nbAcvLovLX7I9%acLN+M z8${u_CA>75D)9%dr8N-w3wyu|-?x#mk>1CBA!{K>6G|=d>M+f#AQ`R;ceQlm6wNN` zYW{pmI23pkADvrm1P4FE*eZFFy%e{~C$)IZmF&{2E)2g%%axB;1T649jj5L{jwu`b zGJT>+5?^$s64wn4kMUE+o@^9CRepeqtKR^aZ{_K41OIg!d@==dY?Kk z5Um@mw7Q$mxl^d>%s@v=cN46g3^_~0^tzKd%G?u)HRZW4wNe;`l!EB2EF!mP+iGQs z8)k387_)TzQO;Zgi@z8FCs}bay@iJA5;_}#UZO02X3*Jipkd$$>~4@nQZ9<;!f6k| zuhPBjDr2RR2QO<0rl3az)7Aw4Z72H*1|gs1RZxg_d0+)MLtt9R`)tdvDx%_+2H{0i zios&{4^sfqWR-ah>MX$#-A->29rD0~K0>k`$8s5;fH9RdMUbr04;$1Y7^_OSjeNrFTaXC5hoh zKp=KG&S99&e|I!{BY5XLGg@4*dKKp>)jH4$-SMU@byHN2lQsgM&2=xto?Lk6_>IlM zrC)Uo5|8ba>9Rfo0CESk7krhY#vr3pPI9FQm}_5-5`(eC#k0*&`)wc5du}HbycT>7 zpGA(Ms$t!>Sf(-+ieT>9@8TCFcqn!EIbm0t^6Qb}09ME$olKGl? zNY-W**_pARTOM`<6Kcntk-oN6)|S;LJ;K0~yr@ zq5zVo82Ega*|}TdY^fxL%RjVr$8mrqFFi;iwr!KsKs^2^#RJDzi&?Kw;~jFar}*iV zPk3W!#ylwySu( z&)T!G$BjVg%o#5-WM0B|e8iQbtvxq1h?9EDuuXaz?AXlp>_zMuTIHRIb{Nz;>Rqd^b3-HsvlVkL@VAk2Q9bLV#kAQVwi|B&yD~ z2}owW=`g4L+|FuMtuaHOKcBr{Wpr?S#nW zja_JKjNp&=LA<|OPlFk3F7Tzba6)80Gi!-ECuV`v8fHh}ZiB7F_snc83q>fE`yIm) zKU1dCj5w`_ja!WH!%XgW$QNDHs!&leDA_ua<+;Fc60(1mtj`EDC_gK}qTF=7=PZ)J zg!f#|qw^BQPq5~IG-~{4lzYrPOIx-eL5ygBRrO@OOxVyOc;KG0%R zC`=SNZ0h(IGGOje*Qfv7wqi_|PggCg%IWf%yepFCh9|jsLWmA=Pekk$DR5W;EI> zIYIneT3l>#B;)4c>F-hoO3bMk44L7<^+MZ+t^xBe(pcW{D!>fbSfxfCSR|!_^Thh zR-Ui^Yr%Z+J^A$cu6*#9v#k0)O?*96=ZF|`r(uT94Qv9JJS~4ER-Q~bTwCS;Da9V2 zoKak)vR?8=n@RTEM2Ze(&nwpUf>;C9Km0XTfQ~V6AvZUiqov;N!BsNJ2h`$+F7Meew-LAmx*Ef^=gOadJ;BL|J3BNnpW2NXAl0d#n|~L${71zmOSmUOb~_5-Z#k`*1V{U#hGgSd6V+i70%%f+QCh%4(Fnm@KqrA>ukVo_Q z)hzc*>a+X4X$gy6Hmsbovfo5fP6$=sw3*gJ<4`9LnpFGxg0BBbTDvN}L;&|5TwI-B zsW-fIFXElB%*aT=o~*SZLMT-f=QLF&OB>0u|Lc$3Rsj9 zlHq03(FMe@aOE5#`A;^Ahl6ns# zx->=xg^4872Mxyf3-TWSW4*p=jH0e)`uB(=Nw@bc2Ke*^HN$dFr2h;jPxPwufb;M# z)_{UE$u221dLxN1t5`;Zks@q(w2jkLo(+J~KA)^C`gb+0Z*XuMuorLX1^stol%0tz zp!!24TfB+!_E==VSGz7W7XN$E4p?D6s~USgKQv&*G5qZ#LW6~DbvpC|JQv}FxKS1i zo;TJ&GM(;IIAA0BNkENigZN+7I|JL4>r}wSGdi0bcjbJ)@%j=R5(0fuUH#Fjms)8b zur~EP-k$Y$y*04MC;shd7)oY-0C+S3+O&i5WO9ATQ51Cmk4?N#@oyj6J!eWfs2HbZ zXEK`rh`_7E>|qcATtyLY-7>!R>`T!=lg#IZYQ4{|Bf9_V?Sz8hKk*&|iE4nkA_0rb zrNXZ3Jp@F5OtpUt)^ig8DczlKwl`u3FbYGX)5WFHYD9b8O5yxWtNr(M_tMYsp>^rY zGYzNbO$BD0_kTGPI39kBZm>6l*pQ4Fu_#r_h=C`XWbygZf@uL`mm>ToW)A>e4BYtY z>gvis{P@;-V)SO*0|Ry-dthOW|7Lx9tB>Ew0{$Nv;Q8;_9+rVGeNPG)dh(HM!I6Mr zjqvvF06s(GUS#Ci0zL?wfVXjW$$IGQEdL?1N5=jW${% z&p(@*KL7cQH(1fB_r^sq^5Gc!bN|Gy<>! zo%g`Q-qh_nK21X#n~T&NNo?Yb^96 z^Rp}9u4FJB!(<>GoCQYHuD!iIU_Uh}(cj-+-u7)&+$XwRkWLqXzRd|XDJQ4;noRKi z7X}7}(Q|Xl)na z7r?dXZs$`5Ko;eg;ebo&OgbEcW(o)?oGn+AYZe0n>`-2cHmjv9pQ zrj5(EpTG#PXNQ9F#h84Z{`QCcPC@Bx@mzFdle%&v- zd3qMN@L4MO_c7qTJf86xn4=%-DxAKBKMsQm0prjZqZJT$h`GJ;){wf-B*1nu_57<; zAwLK(P+0=0iG?kC=v1m`y6qWY1v%uIS-^?dpPx~UfY+&%!~cW%wgiCG7{W(U)S83x zn1NxZZMbjo$gP90ny)ZO!oi!rqIqi^yGqT0ss(D~ntbfGwzgXeQ51?m>SX1ki0kmK z@fhGqtRJBPyj;Nbi5l%7{qEXQF)eQ(#K>dn8sO`rt~~+(7kUP$V8}If=_c3MiT->c zi!PJ5f*LOofaH+5|CL*=)9!Vz$PHYC0NOUqx%gdHF zfkR`-Yb3->9hKm{_c}Hq;co?9W!|v;~CP)u(4T zXN=Z6o&Odv$h2?8#W}u9Rs`rw3T4U^dkjkGnby2-6cUCi;1-F5Uy$-;*|R^a*Mt+B z-A;exInU~-Zp*Gv1wbYRFs5_6GhQ|4<{3T8ROZDvi|XCKS3|Vv-va4uu2z)R_*m5? ztufecdYr(!g%~rWapZY za>}_i=#x+HIR0tXy;rM78mkrcPZS;(z2z)9VDl3@OFPv5=by4ei?(`Zm~22%M|=Ix zCzjy{jOV7TgZoT03=}J7O^1yN1m2jC8viSXsy0;EQ4%0sjCJZDoIKk|XEY;5UA@N# zjMYw3Ov%UFmrA0eu&(*u{X$jU?jjR&=}DmzTFkA&T}}7k{KrhO9Gua%vl>0QYDLPq zua=e;5W4!iF8q((2<$$iOlr?_p46d+mKIL+rXUPIeaG zu5GfC6_22Q0O+WU5xg{{YO@LPICQllY|%)DCHHx5KM^v+mW=$cu)QDr*)66E5-_}4 ziPnKT^^7#)4gl^tdYd8b?j*U2AUg=&y?b{6AZx=m34j1xUavdG^~rz;pZ_yp(sUdd zx&Lvn8RGyVyZbogp^C|`!aq?a2l;N)fo^V331m`{yzABiTvi^vmA~5` zuoSDtVvZ%eLWih+oL?h}sz-F{#tZkYwtv^6Q>LKe6($5fX=uBHX|L;`Lm6D{UZlSW zW-}!1DEV~v#+17edkOz=e&J8MChP-ocr9TyAM~qzl$m&F}_dJcAcsWW)#ev3kiy#!~ zr6^8KT=SM6!m@U5M*lF}q`fJw1uhJ{xB^8(=}vlTZ!mwLNptEEUw};EDIz+tfy?nY zQ-M~HJBj=!Y^34K&d;xyPERyFu1F%Meo4eMaB;zg<}G)WTWGb=87qi|4kenkBpjy< z`o0PB>tQ8-Riv+?UQ6x$7%Rh44bmG3n=k=jCsd>;L$K@y5lu(yzisG~Zd`(d7_7vJ z-Q}wcxaz;1`F(2p6YmcLy)p~ojD0<@a}P&v6$Gd`GoZQ)gUTk)w(6}AHFKe{^@PP*TY>Ca}9S-(a{rJhFV zf6gDE3U&S`LQ)*-*X1TT6k2@G&#eyy%*ej&-zDZMzQiVW)AMCIp~tE(=1>P^$de%l zk5(bbI|f8gF;>C022db6h*xwCCj53IkBf^zYr~_xCyv>n3~Ci>{K5{#MYLZ3YtpK^ zL9?T7^?+)l+f`UZU6MFvCAxPhnivjCH~Rf8o3k)TkaqDw*c(Bz51+vA(t!e&v$`8^ z6Cy4)1B}wv8pD`~@%<_W29w1T$*6x+@Wh|DoRGF~osRABxCF zyoC|21kmC%Apx93QCcHEN2PpU|(4*ZbYNf2XZnPoCN^JYdX$3VNyUHtyrN?+% zRQ*PLESnJnM>C#Y$}<9Ssu^=v$=aFXFN)H3=|8F%cz-!r?$YFw^imhEl0&Io)O)}n0%N*jLW$Mlj z4nhhFlwZLO!PHF8vvhRJErE&ceq&MeD~3Tf0EoO5%`$FUWf?CKXK_(%g}2CqX=>e9e@B z8Ya0|nNTzvaSepqn;r5!w(v`QVn1f?p{@)!cN-v9IR+fd)S@N6MMf|@sJ>-JwD|&) zlpb1~Cco~tADBI?4`GAlps!6I6j!hh>h6f~p`Ecwc-uX*HHYtdrskglmg*qAGg#n4 z*VpAc-3alYKd76I;U=yb>@MLZ!T?6Ie>Sd+G+1y{G_>2sfMgk}p1DK#-vYYFtnisJK@xiA{tyl;vowE1&U3BKbVs$dTGJ zr?G3e`S1w3#aYYRg9Pf(AN!b5Z{VOfHDXlo?3#G%K-?S02{Eyoz5+xU40^q zBSJ+uTOV?hV~Fa)Q0IFj+YMQUhcwCM2Dq@t+l%blmGip)Wvin^1xixdsY<2>g}+QR z88K8fwFlnz(W(Amfql5?XLn$KC>Z+60{L_3fGQ@1A1@mP^lv7oNf3H;ZH5IHrUw~= z!niQFXP>C_%N>GPLsWV!BYXztv(-#z@ai-^sy|cBys0r5l1L_^vlo0!!SmTN?riJ_ zD#Q9EHdBAhq!swWl4)5JHDAxU3^Ql)pj_MoKLEihhML$R;h+^V7Ao!emDeX= zYg33kn=cq(Vg`0ul(*ML0IOyT@SLP}8^-WX_{~PwYnU14|DP*@oag3biKzYvh+|iU zzi%}T1*wU)8C@g=cc%*-CMIy3h*5T)1@be52<^Yf$bK$_8Bo$K~?x9$QpP!xZ{X^tdSoos;SuQU-p>lX9$CV9|EX4sbj z_gAqsP828$6#`pHG~y1fV$(k$?rs+uo;)9zcIz4wzg3835Oep&hll_Kr?5BYMJV11G_xa~DA1F?w4ix;;NBv{R)KDy#r$k;%j zlkYQ-5W25{nNLUcz3R`FV$l(7WCg(2FG$B^M~%O+l530P;)U)vQ8n)zskzp+IOd!Q z@fCRX1d}J!mbC-8d+fKYh?h#pe{Y*$epmZWw~M^LEC2uf@Q7CEw5ifRIQiMC#$Dr> zlHC3N%!Nfw>^@~tY%0~3cH7+1z(%>?44HXb+T8U4p)%yh|2y3>o_VK%p6GxR5w8VN zrVX~e7xa!2LI>zjH)L{_@9YW9MF)ae`3+5iMInj|hDS{7@YXSIRG_-4RGCLYeK19< zR4i8z@GU$Lncag=@hXR}+4V)uO@kua2eYc~zvNOq!*f=}e}p>JmBSS@1us01p{YweqBrWU= zm@e}hM$2>X2f?~dD*KB{m)&{uw@!=0T>c;LOEeqG z{=W{@RHnn#-gvmz{UzZ&Y7qHx9I(U0@<*PGOnG&bSrk{nq0-oE66;N}=~wx-oEzy# zG)tYFjcFH&Ve62!@Ef)UgU#qC|eH87{LHT6K4o-5kCkyug(YW@z!TMkrPZ$Mldm6JhE6g#cB@D!9hr zm;P*HVw{^Q`D^c)-HN>`cNG#Oe)dYFa}{4MBVMvBub!n6D&K4EC;VTNPP{^a0t@+W zVcSZX2Fbk`bDpxnD^P4>2P^L<{8_&$cLaXl`t-ymEz}n0#+Q6aXGHU4*8lrHY*e_smA9>-cmI7*~*W}_Te?Q zy^PYFFE)Uo%d=GHcn4ki)(&r94sDZ`}gf-H*>}fjk%oOfhZaqQXoa96I-<%%uf0 z`>U`G`+>DoABJ7&7ly;Qdx^6ua0iMiIvzvcYW}0g1QAW*cY3$vbMVKnljBUh%?WOd*bUB4;$ZW$%^;LNM z7JdD1e>T1Ei`TCNVC5kMfA1vSdtjC(to>BUs?Ex?fayCqtjHYM`KOFaFXnZ2B#RK{|vjpn`-TAkr-$DV>|{+|;Iii*w%J zJ?Gr}zW09rTb~W9=A3KHF~&2-^Yn!%D#pac;f-&P7KLYh{(IN~^m7)<|93RO28%c1 zG1D#q2Gp4z-?YCTwqzuZn`r6_WlllZ?34d>`j1c6Z>UHB_V9WkYieOHsbRWnFq4b+ozfV<;macn|JSy^ zQ)J!|)>+Dy^Pk*33|dZ3&cVdim!Be}Q@?Mz0U3u^(R14&#w~^#< zwLGEFhnuzs{D<*3e>Da9zbAi%h=ddu@(IAr8hOixI+el)>mJ?a#)LFrB6*QXJGZ!D zy#SEVx358dM-1>KtsT@-{PlqU(Jx7%{R0b$o!g!uJt#!{e?J3N5Qg+;@5q04axnmO z0;^8}075kW>~6|_bKVIV$oqFou&`^p&Ho&2y3z#yffwU(d;kjN|N9OQ50X28Il5zD z0FBGE8ApaU>QU_`l@KJ!!-ETSbaw&`w%s^?{|KUKWunZ1-se>_*8dcWL>&t!v}V5s z*d+qV3@o76Q3n57MdnUfJ=dqzrlEUrhi00c01Zy5p5%_Bd&*&pXULKuJOKmqxu)<32L7~DJ;xW@}%^#xv2fBvQtJsko|U(Cj!&Z;)tEN{~$3hF35{YjHeO!;0C?l>Yx&ckG705=s zqb9W?Aq~fnXQPVJ-`wC7+CVI@3kEO{LrE3fu=Ecf7;XWU_*JvN-?IbAUeIeR@LROF z0(U&;syIRGGc4RXM)5?}w>4q5TK=D&bRGHgX!oM-zC zdHyoQMV&eFAVwa!r(W%@x(h)aaXW8}Hov&^GwH0g)@fuMs{8NH0Ssu6&UvQTUpt0ZPttlC99rT<9+t`z&FZ#?w=WYmK*`6 zxK|uMtQI=YXGo7z^lhHYo;l|?a$6g8Tx$xL+OS#YTG z*(+%SKBp;+;V<*JU9-577vz!Ij!7%S4@gv*%h9^SGFa>QF?}!V6aT%5&$H@sYIl*Fbyl>BM zBc22A!t&2|q2hqKh>x9#vamBu+10JL`jG{O3PlWgN-t6LXeJHyJ-Eu3 zQGHtQDop@a^tU0!;hKoM?dQJbG*Z;mt3z*r5;xUxmL)WF&Q^T-T*(1q2ud>luYvf| zG4<06Q0M+ED~DkOu=TzV>_w*Oc5j=7ARmc7l!(8I8DlD)xjBul zgKSpBB+qyOTtO6}51A=b+hy?2#W~adxfyO+K+|M+YMUsC#yd%T*`+Y}B-2L(iX2;L zj~j`ZqEg6l>H&JA!C1GzHqS7>#Gd$>DG_SI$lTbE?$V+qS!T9BkD>ndeE zHKQLR=wUxRA>@GHqO{NE<=-<17>}k8&;+}M?Nwzv3=Oq!FZb>nc5e^vAIGb222-fU zHr$k2T*w_LI;A(525D+Lz{h=-7K?-_j)JWtGw>6AtA>WR^--d@x9c1Gr zNXk^%(p&JR%(X3t{Wgd_BEx{79=WyCqQ*8dW5a@0j11W&-onD|gIJ-}59&nnco{+d z!8qfGUzf-4H>8q|gl+_}mUSH{)Y;nRsoiLpbJCU5KN}C?lPq&Z%V?|pfkTWF8bbFl5O&Xaiz`g?vDu8`IS|*m^-F!}Z zC)?ai{`5hMX~p9ZsUUbLhn6PpGTDt|6j>s48uvUj_0=CT@f~dL%q$mJ!`IO(WAb>F z8r+$xU3P~yj?9pA=e zV;V;nUE`RuyH2_TLK+mfodHO?hK}kf!$pJ>fZ+GDDUYnLvs!Yo-vl{NQc$m zU$*h>PSwz#ZuN(@0UmUU=P}L1!t-ygwD;oP7z1VMC!6Pv{TuEWRATzgB9wo*tUd36 zDAR7LBKyyxK`q|V?VI~+q{I8Rq3c}%_H;dCTa6}UqVEaP9h;KRe%bF{){bn=ovGBjaWHFMqbADEb zRNr))R9wdQgqBmgt&^U3g=;Gum|#jZR`IO@i!&aP3$u3@8RnLC zcSmVgyJ8Vm`628VSCYN{Me?D$v=@uoom*s|0#k-Qow3HGzq|2r7$OuRyJh&6Ap1px zEE3-GnUu1R{Vj8avo))Bq)nxZFn`GW3yQ&Vj9 zt4<8clloSeg9I=*_o~?bmagu{k%%4}f#tg;0$@^d3wOPR{Y(R8;&*fto`vrS?lV*E z$1>`X8TkKCc$TGO1}~EBrnmf!GnmKmP#gNi1m*OWpW9Ae2rp(bh`b5=EZ9C#>3SpO z9YtLJO;iy`Vv-~qq5Hw^E#t2Y*jUJ5|wE`yE2KEAb%J6HlX*1qSL zyq6pGUM@|IWT)v#ZJ%dEixZsoQ41RFZfuL2#j&Xv8s%rA2xDD@@T4}mtHYv?CZlgZ z*ucd$lNfTg#tF%>Ki#UA`9y&SflFDh>w9yjDecEr{IuHHUHvy#HZpJdD}B;nQZ|uH zIW^QTSqe04D-k^Fkz6x(4ITP(^Q2fk6||}AQs1KvMq69ZLHou0eFzawyP+b9{#K{A zh*|v2{H!zjv_3Z+zN#bqZ`R|9XOV8%w%1NKu38lFZx>49eZBthYzCOqeZ4J_ngl*Z zHc>LJw8M*kGY2^8%*M%8^BusGqfU)qj+6;};B%wu-lSgvgkFzH3+4wO6Bo)+!>0zthpwY zVNnW&gv%2);DX4h+#?m*)ri@RM9;;f#mTH7U)gUYe^5PB zt2{`VYR1sX{)llWzTXI|!U@hvf5?vU-N(%|F!6#)2;})E!55tfxnC1ruJEy#t3X`M*DY)S`;U-Xc)1&*%qmOXZXTl{%QHqrz(2f;4Y;Xr4z+f35Lm2 zh=%6BJiP4jo^Jyc176(f;HoJgbCFu-NDQRD?XeFYU6gdq=Dz#G^=x`cMxwV1cC|b? zca%yP?fv7tIveAk1vYm`A2I@MIrA#C-)X>%3#OSj)uaAczEuuTxoZnj&sUdR?^eJsbRqgvZ! z>s0#5!+O>-dERfaQrgKSB)%n+*CIFAYCBn5^tO}YaH%taNfgsc@^D-sHRGPEh)MQ& z%7Y<;5Z8RP)Q`svEYmH!=^&ftCGMa?oI~d+s_kz48O*8pjT3qK9&v8 zH&oE{TyfLUSS7{aVhOL)pr+upJ|W*g-#;`&*oda+a||~mnfzbSQm?NrP19$vC;r4W zvSU(x-vaRd6I|**R(UydnT8o_gs^zg{UM62&?>Q(kmo$HdOubF)qwtMJJnA110InV zYnJKiHIpSzcsw&caznyM`UPMBeXoYJDK7npU0$ItTwC(IbK_W>XRLj_DhQp3=6XT+ zR=y$Q)#GAGWj0%mn1*fWw{RhFjgwAYKn@jOn?j=UZ3|yKc-|q13ID)%W~*P(wk8^} zcti$!+6fZwcgjb!H?>Jj!F|7RwGMY4aL9v9+l~6JuA{^>v-OUcDSnfNt@&(Zf5z)j zF+x3hqU-WzAn;{v^MMD<(9uf%@Xd>NNIY|^ z=c7c|ErR@&f~(@*6q+S1807z^`6+#0NgV2(k}nRlCO=zXYFBEbV%>RvqGLp|)V3$JekR|({6;xwPY`dC z0@O9lov98;c5qGG5?(l)hDyHrVL2T3ZpIrbQtTzEb#%^r;{F5uRB8QV+jgr%zY5xi zr*+>c-IW;nC!K3aG9m`GKI4GZqYv#2Kgl8_up?$yu=k>+)-pk4)LQ*eZ2b!u`2j1^%!t(BwN9ewavek(6@dtq`>Q@{oxIsBEm>(e{TaTLrek zftF{!-lP2^yrP{a3sN%vAZQ#Q1h#qb6T%^jD^XP-YXXpSGlM09?+jBPpFboG%rKdB zx+MLo*$SoexqQ0xKIkwyN9hp(|}JInbo1vsJA0i(?v%3t+- zpZ)?S8;6w6`gAWA^;~5nqSO{j7((0aSZEFv4=?z@$NNdM%nb&5y{muQj%_x%UAE8= zBeYLsxaj3i%9{?96!d5Pk3q~O4MN1j-yQ+~AHX&XK^x@bl-u){QK4yt3x5Z;qX0>Y z>FTR!fh_Nf#14)z*000c=%5^upy-|HiDvEp26XG-zXgU&L!=b^L zi9m!WtmL^ErU3t|l~=@&`zB0UdVKfP#>wP+@t1Nv#4l9bPr4kYd)Xu<8Jrgf_)h%4 znZwg^Kmi6DZBxpzPPBJSgLEUBhzuIyg%z_=jvQym;v#ww!F z(xM?gm`cXfIdaOk<5M~VCP&DI4K{ByP%H=(!-pDk%4=(<*F6h*Z3*R17POrD{Nu#} zmgkLZ=mO;w?RR(y4lq~NKrNU}ip582kI-Uf3C!Di7N6E?g@PIVj3Y&L6Qu#tS0wn< zG@cPH;10|2!6j)-y)pEiiIL?t6@D7}M7J);EHS`?tV@wm=J3M*oFSEf7jE*MJHh@T zms3s2m`m6-c7i{YezUmdWt14Dpn{PKAm@X?u$}lkuB`_3?(_N^6;l-P+eE9VSL`r1 z=<@#3FF`0Q2blA&PpuwIes*fs)*rIu0p8Z0kxwhxJOSFe(R!}}2#|m9zywXZ^jrgr zXVd9^Y^Pg*%*x1PX9#bxGnk@3S!?>=D=*IHy-L29$(t?WhXse~9alko2hzo)t)79Y zqc)h%gV(~z?(|+Jx4Nn@32g|lpeMAdZT`EVbaQhP@;S_d{uCSMfo1X(MVXIp$c=@x zv0PFlYqeeW?2|jV&BW${F)B^!>1JV{T;kK*DD&pcF>NRHy}B*y`Z>bC2C8aT%}|Ge z3N$Gt;4$-TQf#WCibDO<)FL0`j>y9}wKyFmflpKiqCa2GkL09$=O6&NKuX95=O#Fy z5n|nIsPMtw_2z{jgSpaoh6DNV`+oSeH!;hs-8sk$uKTcuQeW(`c1D5B46rAvDJ#lb*=wQOY#9Kq>gcHcEj}G$5&%bS(%9e93!yzHbk>>WC)AG4#??N8;Inwz&{Z1P= z6I&@N8`;y5h8UM{Z@c=ZLe0$6Y^nP7KNPmY*~^loD(v}2^p4HO9`--eEAC8@8IAm9 ztLY;UWJJ&-Hcd-6n_TW1s87mlUAf&P^FUHk;3LO_A!KI|HOH`hyHK87DR=PTgq3fj?aVj=)*P~I`%xd@qavv@)BtMg|EfNtyaQ+qo(m3rv zE18+b?#U}(y+q#;2zafO19}*dJOQ|dBj_aGy);4{NtgwH6=~+32Md8)|C>g`zMTI?gM z#K?+tXPd*|ruQI37xS}KVtFM!S$Rrzopo!MH9szyQffy^#A=`fBreMA<&m-r-2_Xfujq z1dB#R53{cewv@_mgVzP(PcZe<>9g@l=I)a=`siLJw2m@=)UA@3!z_$_Hn?f7y%)!P z#6yLgxWvkkM;e_YY+zucL3H%S#)7qD8ylopbYGk492%OoX?3z-Hd+r99+>*V31DSq zaehj;PiIkBXd>J-YUrNyaeDecrvAaAmY{Gc@11H|;{9;I@mh}|F~*v+i8ax~kC{N> z66C8Y$A>PEn*#O?=ToC>KTN$f^DKeF^y8e%>SsDjL)&}BhubYI=Cst8cZui3=fs=i zyH#xJJybrgy0y-E_b`756ZFOsWL%LDU%(W*eD1M%{FCXNs5(z5Y%!B_^*yHghxSN= zoY?~uSs?sJu1v$FBtbXCIvAtDi2pt)#TFNv-z@W+jy>eYK)FG&mX@J;B|iMP@F7>efg10$G|(u zjxhgTl2nim3DW1&G2(-aWg3WbAAo^rx*!&pFPcneY}l)PwoX>LklfWZs)e!uiRbtc zjsxqq-eZ;-6NHgTxkTpv-2BU+4ygEIQ^Yf5NiRNCTE$7{V{-j!YR!UmKkP~m&MFQM zvzQh@LRf3qqM9m92z42SXGd?0#_aQCtBta<8P$3%c8gsf$Ee4}c{|)zv-6OUOm7p) ze6)y7>uY!5{Yqk%PA!1SeDL8JDamw@bTGQX(B%pOc#UAdjg$A9#3!#g`T^qyR1VLo z5>j=*C8w%@{k?Yr83S^*gd_a`m;J6>Oqpdm?oblIi zI6eq|FZe-TU3c<4I`tn8ONAqp0~9&YHYN3gF!|->SDbpOAsr@wnnP)s!t(?qcty}lXN$FkC>33T*6a=3HPhO@x zI8}&{B_8;7@ok+ZlGJlWUw)S8iX}W}d zDc&B1Uo%63KX{wkqRsBLm^*g7&th** zT9qr!>}YFG$3p;PhLwT)+Gdmay3fjs;|GYL9m%2O#F2?Vy3PHytSq%Db?x0l?sZCw4)5VnGf~Y}Nct84!`sia4-)Jn%Z17k>$!NsC!bk$c;12Je=u-sB1! z`7I_G=|-l&VlR7ZhCb^``^MliFzZ3IR5x9CxUM@PGr^l?SNU-r8MO~xrq9pRGaFov zSvP49m$3Q(M(fDaZh8$e=ky(~HOvcBoLy1jgC<^8=NU0D_w)eI|Eiy z*wadmPibs84bZu`-g-T{AQvuUgvSPnXv+7-AJ|Af|3-Z;j zr!|%^nz~9Ye13;1P<*Rawb*e~Y|ZsWC`UpoXHqc0ef;4cK9$Vj!go!expN0!3a38G zcIChJK?uD74B^CirgK};_xQz7*ZM#s!bvEEBTEhkL23rs8mJ^z>fgqvh6j$qX!ePD zVF(t_tf8^VaKMoi&ILLkOH+-<3+wKC0vyqx+(afwNa>@J2CZ7X0giAY?xaD5G!+pE z5+^(_fFqR7;GZbqHp?C~?Mk2k$t;&zG&MoO!UuSOzMdXhqx#p|W8mnz$g7!Av4!Z9 zg2E#276Fp)2Uu5ZDt0&lmI6%zZu0qDNanxzwL*;7-2CK1_ljGY8f!%MUu}(fJ6>UYaJ+ZSUW8E!Stt8b*rR7_^x=3w-1e4XVv_33oSzzlCz*h9w1-cl;6c)L zS)rm4>b%sYzAh7eLMwmD7}Gu$FzdY~%9+ROKk;|c65mIESvwrw1sjBr$p$}{8W zvpNjdj0A@(OVJv;VDU2j&i2n=rcpraA@E`9*B6m9xb>vdgmwDQockTj@$XFBY$Dcs zdPf?w`=|weyX6Vdl$q4T;4Hy8lF!znRB0J*TS=2HYhP4kEY(WCLlyW))tNzUm&Pb%^&DpaM*(!9?-8~U@!b=g#w$(AugmQ`h!HELz5V@l-fJ*nEEjDvp zAOM6PfGCeIZ`$-502OlF`zWfKG1yCX%v@losAMkvji;K=a|5WE-7OT zFaH0(KCfgLv}S1*Q^;0ENF@c9AT=&-C;)};r1||hail?sAo@7;zxyk;?g|+{U;6r` z@MZ$rJn>JF$4@jMMc6uG>gN*WRL@%> zKJ`MSPvcPu!9<~d#KycZQhH^`hSY%ZKjNmxue?HbdAn`4#fT3Ed&yy%AlR2Ab+(SW z)EdNgV1q(>$Nuzjlf?q&36#+}7@q$9AFg@mqXmftxL7e%Y@p{QhK)EbNfe=*Csyzi z>LAJZ_Qx0zp7bsu_21jg)Kl=HRTvT*T z)eo5V@BQi!IKB?9tZYXWMHfnO@Vzx%6i=8f*jQoLD&8iZaLZY)td`_c#~~u(NI$;o z;9P2mjrC<)pvkcv>E^i{%5p@os%#A+->v2zIvj(C-}~Ls78(kuru6Zj?>H=pT|xI_ z&nL>(I9$g5zF6&+n!jJ==u$7i|L1k?$ju8=M9J2mR6-U;JD-h5urF^ZZWTmMe1|11 zHwc-tX5?(kXyvA3v3}|>O%biU$@f|td9>}LG|-8_eIZ?3r@3-YnowCBp;G$d(>uoE_iHAOZ6epG~?mNvz{jv-(RgF$%t`QZ6|LY z9@@j6f#%}J?>-y#KR){)Y0%_^Yo7R5>KiPIzKwo8 z0kXrQ-uRe!A@7wZNhbbpGV@gMsUm5F_&`@Jva-=seW|sV$4amy(!0{h;EogJkWTt4 zV{X#OvEa!A?N{9bAnF*SNwndiAt66=h|FRKMI!t+?296|oM1v^qs^O-i)FV%z7A!5_A6 zkb(-trO>g#n$M8|RL z9J@E|p?ZNlQD^U4>r+Ki!bJLLZVlr9w&O7*lCRkaJ=GfyZ&1fMEIUEth6LjTHg8!> z-se`^lV<@7AU*q-O6b2flz&8mUxD0ZT>Ee#S#i3NGDbcxylJ?YhFC6}1I8TOLA5U? z6?Hb&2rJ^sTPAE{v?lP@eo$k{px6n~h3Xv=p~)H8F!Nf7@Qc@%oR3{uee?4{~@-lvCOrO_-0m*nN?Li4|BCAZVm%)?sH1%NU}Umi?2+}pql-LC*71uKZ;5|#c*O!LP~z#VQo9Rd zRpVuj!|tX9Crk?3uGI40)u?C6-3t2#>7q+gnTwx~mp-@BQ03ko*6E|U`bHKT8yjOf zkL8E(JO`{G4mBf8NLib&cAwQXPMAD=X9=AxBNh^_uNM;5Gbm-MmmJ@CO%AS#6^)^m z^omg7UUKdVKUT_Trm*6hsvo3(Js3)N4ztp6MM-ioinTG~3ZpEBU7*!*7~>r1E~_7E zPRk;{4-tO0el5<;rJWn{@+DnCyFMIDQ(?@zF*0_}o#b#+!M!X#hgw zBX<_r@HSb-%Z?_F4am9kV^>251MOP3HBbG<#1>*)lsmCUC_JVncZivDfi;%7J8 zDO-(!uU#Hm&2_3T@Do*+*NMrVY;dzHv?1##NH&muP8NA~SAc}AS0~-tG~p_BsTJQm&=yFDS2RODj1Bn zQ2bNdktCT;iQRz9##a*WO4X;;8t*R6c9`D)6hN^j(C>eJH*sB)<$7*67f%0+@7nd< zV@npi{KZJne(;A?IAU!DNfS6$-t+f!3w21_h>Q7Av&B_Yj`=oa-k9?ux)WLKG*eXk zEB>}EB(bsV@ek!$+Etv*mD|SN6oDe^VOFVzcC$_VhD>|td=;9wgsj>!$>+&E>7T7+ zG?a>7#8M}_CiA}=BOaAK>cMtaOZoY=X!p3wav@;PA39h(ANEUM_y}9&G$&!_Y}$OJ zqz}|r)LBDY@?ImoJy{;3ZA{E0-ut_ahCmhzR_vHOehc20uY)v$V`yw8MJlmdz3<;@ z_`Kduy4SFq9P40a`7M=g(T5p)b3QfP&!JUZI!$=8fbxBH=d=m|d=yc-l-kK8Zh9zt z8Rjm;yk`YdJ?L;vY!!Kmp1T#Hau`Rp_=@&8sa_83cTK^C!8@cM7y)|vW3)^q`{^&q zhvUQ+hI%FE@kn&va^8vufMp16`i+WlVA$Ux5Fr8%LjpiyELFlD$VdJAScKiMP65h# zwl$Jj-z-YqyZwCCrE|F)UCa>L{o9Ea<@(HnHu*P*dUrl+o&ni<*o%il+ffpKCOGKETv-gi~yF6f3hK*H8K-^&< zvmoSJa@NRYwq#tUZ*6^m{OIT7#BH+)LNtNt_V3JO7ni*!eVf{O2Po+1_*_u$gSl9@ z3oX*ypJ}O$Ht{p_E)JTghNOc3a{M{6-u>G!=aG{q2;}NoUAM95R|PWrqBMH3D!CM2 z=s+wRDPiQM?it}AJH&4RMYL@2>S1z6RKgT`qZiXxB&9RzZ8MPN!eG*Ur1Q?zlRR?8 zGqCO($V{yFeEEHrG?HHdQr7kv`EIKVn2XH5V_R>n`$!$F$J(dEm(9_<)8-(P)BzgF zI606pES=4|)^Ik*3#f_G#U{JD&vNr|aq!JiSZfenc7M*@)Glz+D4*ozJW~q8Y6M}pD|3~1d;m=+85@7ZLie8)jZ7A%={Fe&H zrp0=D4a!~hD5EJCB|_UCI8^jAqb0AP#7h*B?y7fh3QCPWm81@ zUol1V2T?L>UfTQAcGftPD1~D$Fb=X><}?sWhqwuw9r!LVcfwd+a``lY9;Q_I zDo>xQ<({s6cFdX_TsSagVgjs#N z`I1xM6BvBJ@y)*XJMtRKYQQ~mUlKizf6yJ3v+EPQPKczez8!9^I)!bN5X2Lc#p;n*+6#4^5 zrGb}JIM0pi%@_y3llo@r52R8v;8xLiu7*4ln$NO!O$8Leh60(|e#?`Dc+)L-{4A>AMu?uNFvrEDa)A-a z#rs*B$i`9XQBp_oqLCW^IONtxQJz1md{G03jG1_!p(brLfa@$-nBl(;v=3uta%HZV zx=$%K?krMNRwnZ7IVg2K+v(KTc1x!1(U4ml1n1*Sl*&{AIY>? z4u6E*ME;fu%`iFcOQz1O&%*w%w1gPAF#F{4lKAKoDn6rv(O0DAXt95$Y)X3G zvFUZUMFRU%^=Eosj-`%mM;tG2whZ;XNJCC{9{g-8(?)yDu8MYaFHN=ZxuNN4`c*yHrl@|$zc`C!n zC+%>a7qwwDvKv{*w{CMR#CH|wJ_j!+nkGtY87@1!2Q!?_)OA7VUM zzBjWbjz=XIfs{CE?#4NfWR-!a82@rdgPwOq(3fBkX?l0rBJYydH~A=OlPai2Ot#_U z8?wtJ_O7EZN-^lxENJ{aF^Kc7v?4^AKAhbjH!HBLH3%F-HJ9db@$e&Sb8Q_5hkHjd z9paZWsbwMzZK07Msc-%tmz{-PaT-JkgAKu!e%J|;c2Hef2uAPZ5DRDPGc6+dXd4TS zPf3=piip;;%HW=M#jG}8&R}~ZN`4WpZp2(p9|tOSF~%PP1DRZ0l`m{jcOLV2;Xv2P z{0B=b)js?npxeZUI)8=?r!*!ZVg0ZRdJK_a(Q(jbC4Bt3ld?1>HWmV8E4PLzYdC$Y zCCz<~mNpLipt2z)7W6Y&a!V2-Et?#cj3-woR)-QQfb<~aL-kjlwB{gBrM*0ffQy75 z$!ad+Bm%va{7Vul54UH@>8$+2!^1!|4#1Vjx}>F5ruMELM|SgXX?K;5f+Ijnm*9cO zd8S%w&OsEEOvNEVem{v$SU?35hfTfjI6J1FIcz{@I%@u%Bd5wqFAQgGWZ2bXRC}!hYPV_$jFB?7{YYrg8V|4MjqM%+%m@xTmo3c zL`nd_-8d%GToer3VX}|+zX;Gx@x8r4b?h5EHvv{FZ zbgb^~4lK$uX<_yP{D!U*dyYT>4{+Ly55Sa`CSemS;Jm<0bvBg!yo3Y5DyWG~O-)yq zrHBYer&_KtrcirDlG!_uK+9;_WhajL-uguiA{?n>+WXK-rS35 zwKHA1UI@w;o;$N2VS`zN{D3h?(OT3 zi9Wl78)T?(hW51!oSy#{mlaO=m9%=UZHTks@7 zfXUrYm<#fD+j@dS8reD~7}sRA#ohJOf5t$)s^k@b)XY~0Ub~NN*CL@kUIO9ywVIk* z@z`=pKtmnB$8m5%a3_n(Yi+WgMY8uMWf>ngmZ(KP2|kE}ds~DPi###0pUd0dzQc8?NY}$SqHw^_2wNMB;f+y$m?&*Vsm; zju3JFW`AJkRChG3YmX;}g9$Rf<8Rpmq97g{KN{x`xQYE#W7$Ek1ZHM#>KMcE5G;%~ zhx6H%0$!tFOX0}sMq>uq$T-X{w+uO)l2X3{GRoFCuT?yO?%r;4u8{j%{bpYt)7IN3 zw4o@g^x>+VHkXvS>B=|=sX7+;mH@e4BAr}ZOvstU9aIaDPzJ#StSmfw*(o2?fr>@< zYSbEbIt)jP+FWjc@X}Jd0N}4IAw^21LD{pg;GezAFbMbt01*|Dh_&W$&x7yI^}_r8 z*}3HnA@7}j0TSTspfM{K{2>4-OSSnPbW=~l5lI@fsQNSN;&+u9bvQD1Z+9#{z-1Zo z*mKN)w?nxB#KcQh`#J(g0eG>pUx`%bH1@3E}u26X7ILN1E zK>>ZzOq#D$ueY&sX|8?=F2A?_`SzzETdXWyssh7He@lzz*@ZSlbfOrG7D<9fcmM;4 zyxp#jEqHV#;O834KhjD-XW5MB~ezt|cn%&Z4a z$EgPpK`CGjNNq2*{J45z>1goBuK=9#`0J8_ib{mx7n7Z4U}#t2w_)ddtG17qtFK1w zWIqml=J+)AkLy+x1E#3>yY1V-jny`wU*2+P2k)5@4KMFYCB?_5L;78OZ#vzjT;V6EkbPafAvw%`ueFJwEH>m;g5`Cl&FuV@*5S-Hm&=zfo5L z9_+8vJT`~2wLu~4{i)*(zzw>i|9l?{en4nj?&sTQm$LG+T(EoE4udu z`YYT=BNp*lr+l2j7#M1d|9=Rxc&v~>2_&>Zk5B&^^A8+$HN3qy;m`j|2>)YHA0OXZ z>d&_ZE+k9z=jng`?0o1ju%+# z|Hh<%)2iP5ozM|WJQ9>^{37zBvM5i8HWMFnHt{cFvEMhG_;}Q@Y-QoN6hjmC5wgwc zcymvFDKDU&cBE4q7R@pBj0RwrOBq>}!}Qj6X3wuYmP{9Pm2+Q5kzERUN9QU#r@w9> zu}~b}saaoJ4aIUKXDA_%CE#|ZGFffX+ESC=WpB%^gq zJ*qFAK?^bnsr=?hAplf2&UE&;zNt}j8qxKIyWs(PU8(Ox@Z2^#pnVZvg6K=<&d%i@ zMW(6y{W>2o;x@cNs#R=namQD7Vs)}0lkof=KUY>(5Rz3~pwug;KktsOf5Ar*jyrrg zeQL*TCsDtT-6@v~G;1!*2qRDQ)yc8Q7dfvZjty>tmOUXO;Yu&YV_86GG*KOeNg~F*X z)Z&Y}t<;>rbz3QyiG@J|Po-y%4DjjC|S| z^(XDWwN36&QTFu^+{sCU#EcTBhN)}D)KO&H=-frQVGa?hU*MnhZ9g}qponqzBe&?? zNN6fVSzUXCEUn!F#n4efN>K}D>>$}D6YDS@=s^M=<#d`Xnb)t=#0;sJe zS<~;e!nl`;+lU?5#pKSh;1E>!QYoFn73XBU{xTWK4H>C8P{1c?Ob@d1#!_QuV7m<} z^&@$9yBy2%1y7_2$zATmUM?R@W4R-5q?X<`e5WxZJ6;OM`>_V2f1Kg{Q*ud854pZ` zp7DB51)&5;3vov^ygql%!Ov|EMYOKpmljZt$`Hcgb_-k%^-fW)$6$=jfarR2$(QBwlmlp^I{4b4+o(1dML=4 zarL5v&G%J`7$*1=08bQ}jn)JcYFC=V6UGz22z~i&%=yYN;sY5MxO35DZM4Sunf7Nx zN*y-cumBHrUjgSZGc`sfdy>sk<#3-z$QJ#igNt$xthBNFQL{z}jet3F?jXwQi(3&! z!o6d6^tFZO4H;1R_CZFFE;SUc=R|O$5`6I5n3;I zwuF0+5GXTy^w5UURclJUqFphT)x~J=_Iw#O3}tsdMCg7d0NFv6V)#q$8;iS}xcQzV zq=fGDd_e~Mf;XdI(zuOpTe*<@4)7X27jt$j-`INE>yusQ48L2Vc(zjiK;lbalg5d7 z#_>D(MzEzv>_&Z#{ueQTbmoxcna)tF4sOXL_E6lxOcTM6(g7T;gnT)l(6I8xSD#f^-^L87O=?Qpe#J%=CC3a5Ez4#HOb*ZH0 zccXrhhhkan{o9U~YE~8xrObs~(XJf2(z$MMD|%(u)5V(xD1pH{F)%=4!9OvM+4TJr zrAg;#h@*@OMzXT(deQ9mbg_uU7Y#&0Q}oC34>|TEK^9eKZAD+f7_($UvWrWbx*l2nl|YFeX48_1;lE zfH$Iy3<_s`!J_rvuHgU)1nEgF)cB-;KGb|=oGU+$Pf@rZ`gFjZq|$cfmue>CkvVj2 zL{&P77!j?`BrbPOn$}^NMq6VbqG>mZ6~#i!+L~D20DL;aGX#}VRi4eM3$c~OV*7&U&Wm^rH8z6R8m`9>yOsk7uBA&j1zR~|z*I5Nb^@VF+5ClOzHS#LSs*MqNBH_alEI* zA32NmD&A>_5o7Q+6#78XX^{KXyV(dJU9d5|8ej?|eL4kvFXhcFpM(%svDgJCAP4s*kpVT%eX#Ti+4k19p2MJaGc zOG&e_rqp9cT{D2CeL<_$$$aNmi1o+U>}&ysN=Ub(YE60YbCpjZY z7D6RIWT|MgoGp;7g_RafbPWbkL7^|A`Q=qF?_bbPFn@35&LNI(SLtFr;`vsc)DW{n z5Yt|X5Ah(-mY>9ob-a5uTAOTt_*Kb=le|>SSDoVKZpn7gsK^<)phO@c#G)B?82U;u z*6pj`ykePvb(`m~{XI_6<4KWkap}U1dAz@LS{$YkjL-2D(x~D;dA3)?%hibeq?%{L zjp+Q%3VrfSE*0Z0va`kH6oig%w?YpD5%+B*63+hcSpUD=R(pZ z`W+WNVqZRpqP5u#DU`>7lw5S9mdnsY_~GW=GI8v~`%qU%;q zBVkA(Vt=y^y_h2STb%sKwAs@V@8$e-4e4kTSB{ImmK;jx8M~QqQGZZ#?`GoTrcrc- zhqc4?wp`-)qQj513auN55~m>ebn3E z(OD46PLJBp?$h^aZbaI4l_Be6_~C`eq== zN{p7Pbh)EJVmnzDwFDVjj+o(BnhN8I%R-A>NHW7SR)P4660t8WyalUB{$>YBvin`& zTd=C=YtEu>ra{B48GgO9W9)n(C0P`zRj8tzGHC~XSQOoP=J@vf0Ty@klvXr}1vdx+ z^T^ALiXbkY4~cYpx9EI{;O)bAi8!jt6(c(dbcw}7Q>eW&DSqjEsZJNTB@Z`_ovoK`~kt^bLa@>>fjUz{7D_-95UUjv)DGcudcc>xInf798dfHGZPqwI| zhKcZF-Y+w^`1S_FsMMxU$1w;=tS%>vw-XIS}cpdMR()Vfqjwr@C1v zndY)_#Ibr0hB&4B?J2~W7_H$a4d(lQODIU$$tm?E*Ea=WT1G43P3VxRM4hV<^?20` zG}K{72|c9h)JY|9Ph40~u&dSfu^lyWVPqt6KUQosFj=7s-kX+@Ex^C<1KrUzWRU_8 zhlip&A{4Dk9yUqS2&YVxJ`Aa0NUNN&8rZ_fP)}sejz zgP9sj#UmIMVGEV$nXO7*K88eI(n&tW!wGfXM(=v zdAAdilSUr4FkyFSnetmUb@l|*F8#Pn{A=*0Y^sE|8bmt{D5BXXp`r88r@i{Z2M<_- zhazh5M9R6Opz4xx?IC<(XW?kmN3KT|@Y|YWXf~KoEIXbL7bBBQ0bUOMV+Ocm42#3; z=?n}1ov?12j7M3#F(gj9CbLbYkAH zuMuiCB+akeFvaW=2bKq2cLb>(sy9EAXNX>Uz*aLTa@HCI9`pbD&fNt|m8%uMkka7; zx=!U1mEW3b6!`oR9-?A(IA}|go2qqrdxKx zx`pvzwXitYue3N39~eBBxKEC2;%lZ3R6I z-KXEqMbD@#Ueh8WN9vD@al3=g*owJii)woKQQu_lr+ws~M82Xs7V&kqJY%9ig!m$xHWNarJ4HoteW7G09ZhW#O{czl7Z>YZxAJ6xS3 z=OrOZbvv2`-nS{$wZdoj)|{mEWn`6dxGLb1Mzt3%%Eo+2q*%K4Y~RNt7&E7%=#O=u zb#LuU9Qd6|Sid(Kg_ovLZzd@8hQ^e$>1mvokzzk)j}nk3G+GEY)QpqN#z*X-F#+IC zYU3#WBE`TYu=C{6kU}`-cdMqAk28O(I^NZUL7T-R&T2CEL1G_*j;$Nrq}?(C+x=Zd zg5vR4Z52y|S3k~0rFS-N-C7>>)M#IvjyiZ7H;T6}ECOs>Ko%taNdBQ|)siA0YU?Fy z1of2JsKw)z12UR2q7T(xbJD!=W%{dd+O+9&4d}G2JLu1NN){|tJqK8_AQL9=5XwH=b zWR?NXfX)4aVU`N7-dSbQ=eO#glnFcWx^GG`JVc5g)%*<4gp_2MUOE?wlsV|xJ&0XS z42oDz;a*-lkXoP9`f&c97gokC`XFPOXHr3kz5Er~{^!1CI2_(isoGPA8)m=yoY+9q z`o2_3ObR><=Vl z)&tlf=Of0_J&q_wY>8VJ?F9Bf5zu*UQl{H3jWh4M>!DXG=YgWmlt|lAW+u8bR`ngi zyuF!_TZu=PCQh&w*r*n<$Kj2zIa+I{W@y%V##EHtS}G((=Aojpp<1VolHP+I;oE=l zv2_6|uAcV7^Kf^%uy!9It3Ir7kVfY{xo@a!-3sj80MStFxJU zVe2Aw61KbNQ1xS{=<&~tT0i@vXFBi^Rtt4;XZl1I4NR(Wijq0*{$;eM>k_KZmLCy@ zcHDum@sBqNm)2RmVHmKD^3oZyGa`OkPN1*Jl3|cLk;vBXZxJEaYGP&AwCJVbR{15$ zl2e>mE&LuLf$n66Bd;6%Y7qVNelQb74O#r?Tz)J^Wm3-;iN1Kb)&q&GCeqdtes2PY zsr|7et1&Jk4dfr;K_p>J!-e#F`fk@J*rKEcbPDmocK6akJw^NOIPJ04Po?Ua4z@?q z;c4I*dS3=5mOm<#yrIr%xq74Q`*JapbX}q3t9#IAb9+zN^BCZ-ln_d3MRLUDW81EV z(FKL}QED`F3C1*X>wi{XT-wAt&68S+c~Q`%zWQ!a7clT79u{DtQPUF%&o$6_BB}@3 zFtS|Pu>UI=GR>qNebAadp3vc6DND|2v$`ibTtktbVM|&zF;01TojrlZLY4N9>tZt*s6~XN2ve1dT|U1F zpD=8+38WmzCh=y0>ws+W6t=!Hvbz1?w^RcICw?lOY8^JakpbX749kJ?`C#!F$07=! z#@=Q5HBOG5ZsYUjrBCXJ;fx6y`r2SAUF4q10j-xGA8pWfvWeU;EzSNUk$Q2fLcyLT zMBGJ!5#2H>x}^YS(4M0$6+@3_A%}5o##xmG@A8yLj{_`HeZah{F-YqH<8@e!)j(xF z>N7DNF6&!GL#Qr9H+m$q^n2smsL>Y2Z@if{(G@akzZ2DLFam?nAB^rRE{(U|Z1|{; zc=^=wZi~4?0&EX%lNi3Fv9r$4NtNB7#omWTok`q}-jZ=#-I%m*n0luZ-j#nFxn{49 zyip2^aqHxZly-|yG!FG#-zVA5J*Si9t+&?e3Om3pTbv^DIgMglcDUBYdED#hHciHB znm~w_3=+p8-k+h6tQdX$<~X_YnZsU?9(d+P`D`oy{>R;<=bnDEzVW{d6rZ$d%MN3A z(DGsY!{{31Ap6bbRtGD)UB$)fZX4%nY)rID$}lZ`s_b(F_N&ZH zOx24P@J7dmtzEE6Ze$xnPbA(|6U@Z?cTMB1V+^d>RDM4rW@!x)BwJneE3RC?+nMQE z=$iV)@~oF#`7?BaPa&<7jFa@ISlV`%aq%~jB>EA|d=X%k5n@-CzldnLFt2m}W5Z`w zPrl_(ul5ZQFc5U>7=>I8S_f)d$y!#f02_*VCy04c&^&?UT$wD}h>L#jajXRk3L@8KHXcCcMTL`X zYQ0%Mu_xn6W4PIFt&1ERRo*in>TkZzxzkqrSN6S99UFdp{@SGY8LqAQ&QsB{L5$AA4bsoNql~q={-M zp(MjglD~R-WUd#=>XkwDRzcfg=SvJhYd2Ry3wv@^E#C#q-y{%s{4S>$qJjhp^RJ1D z{}nLslDahsEFdE@PivXS>zznIpK%?sfWIV(xKwsYBaaeR_={o7$;~BWVBr6MiW3t~ z-07)}Qcd`ykiVHE8#ts*1#|N8>45FPEb{+u-{^sZtjAJ-Qv8z@C;!Qc3o_e<2uAsB z7L_Y;i?~uG`l9NK-&PtZM)FVS+Qkb-*TyE6mX`igDa8Qn)B832dLb%)`<7( zL)pFav71ecVTeZ8sVgp_h0?Y+0WdA1z*@B}=vYfD7gwefln?#6Q#1w)Gy${G7O;;y zR-(JQgAfhF2F!Ilx72$2t@|}~%Q?Eiu8(|2>S4viqZV&=5L`rLBsf7f3@bV0BW>SL zi9O%#xI#4;q>&&@5;}1Zg{FqaSK$8EaxPTo({GZ$-pLB9+YLMyz5vFJFB22fjY_`` zf5^{g0?nQLp5-hSU;A^oksJeRJ42E{!V_=-i>00D#m%6&7~fBvRQp^Hscc5Y4T z=|Y>78l09rPr>KEpP=G2@yUG?p464KC&s{Z_3hi4H$b)Z9@~1I>BzW^^G&-Z<%vM= zm*CU)g6VbArCev6@q;qp0s#_7J$WNuWpUtWb-QOy(R@{V_LZG&~b3S)xTx) z`1ohKB^Xksq!+8im-dj*P;kkCn`}-ttP~gk`si*0`%rLLw-V0*(9dPSF~fOsyv2Iq zVgl20o!T|VAQutzUuo9(3iSWcVJI2QxF3%BUSnIxVQ!K29<&msq!@&t4XB!utaO+q zAZJpH=!`(BGYG|deXx(jFA&&Dx(9`l(A-l2v%!Bc^eSGe&A})CJF9;JKA(V?So8yc zO1O(zD$`%lJO%_*C3TM2-2m`+9=cjiGjAgH&m>|0D16zljrrQ)-4;?(b18Nt} z8iu)p$Lt9mr~=Q?l;BUVSTrhkte}7yFkR{i^ujb7xnD?@e-}|~4T)dTeE?88I{2x5 zHLlLVg-?4Y);1pgI=RLt4iTX9k_ELo>@*z&40k){fcAj+#Du@frym)7=s*Z^q1EFF zPj3ey1~v#ALiuh6_`1`r$73KQ6~Q)r7Jm_(Vw%mLtO=US5YRa>@oeGQ7VAKANg} KD%DC>5&r{w6TMXc literal 0 HcmV?d00001 diff --git a/mindpet/delta/ptuning2.py b/mindpet/delta/ptuning2.py index d091380..dcaa1a2 100644 --- a/mindpet/delta/ptuning2.py +++ b/mindpet/delta/ptuning2.py @@ -51,9 +51,9 @@ class PrefixEncoder(nn.Cell): num_layers: 原模型transformer层数 num_heads: 原模型transformer多头注意力头数 kv_channels: 原模型transformer kv维度 - projection_dim: MLP编码维度 + prefix_projection 是否使用MLP表征 + projection_dim: MLP维度 dropout_prob: 丢弃率 - prefix_projection 是否使用MLP编码 """ super().__init__() self.pre_seq_len = Validator.check_positive_int(pre_seq_len, "pre_seq_len") -- Gitee From f8701fad75db9ebecf190b7a464ee2c5c1b2c0f5 Mon Sep 17 00:00:00 2001 From: pk Date: Sun, 24 Sep 2023 20:44:23 +0800 Subject: [PATCH 5/5] add p-tuning-v2 doc --- doc/TK_DeltaAlgorithm_README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/doc/TK_DeltaAlgorithm_README.md b/doc/TK_DeltaAlgorithm_README.md index b8cb27f..c7aff9c 100644 --- a/doc/TK_DeltaAlgorithm_README.md +++ b/doc/TK_DeltaAlgorithm_README.md @@ -1568,7 +1568,8 @@ class SelfAttention(nn.Cell): # [bs, 1, seq_len, pre_seq_len + seq_len] attention_mask = m_cat((prefix_mask, attention_mask)) - return key, value, attention_mask + return key, value, attention_mask + def construct(self, input_tensor, attention_mask): ... ... @@ -1623,7 +1624,7 @@ ckpt_callback = TrainableParamsCheckPoint(...) epoch 优化器 学习率 - pref_seq_num + pre_seq_num -- Gitee