From 79851433f878216d8887658a36b35d4c07a12c60 Mon Sep 17 00:00:00 2001 From: josc146 Date: Tue, 3 Oct 2023 13:33:55 +0800 Subject: [PATCH] upgrade rwkv pip (0.8.13) --- backend-python/requirements.txt | Bin 672 -> 674 bytes backend-python/requirements_without_cyac.txt | Bin 650 -> 652 bytes backend-python/rwkv_pip/cuda/att_one.cu | 124 ++ backend-python/rwkv_pip/cuda/att_seq.cu | 179 ++ backend-python/rwkv_pip/cuda/element_wise.h | 21 + backend-python/rwkv_pip/cuda/ffn.cu | 165 ++ .../rwkv_pip/cuda/gemm_fp16_cublas.cpp | 86 + backend-python/rwkv_pip/cuda/operators.cu | 246 +++ backend-python/rwkv_pip/cuda/rwkv5.cu | 88 + backend-python/rwkv_pip/cuda/rwkv5_op.cpp | 30 + backend-python/rwkv_pip/cuda/util.h | 7 + backend-python/rwkv_pip/cuda/wrapper.cpp | 141 ++ backend-python/rwkv_pip/model.py | 1827 +++++++++++++++++ backend-python/rwkv_pip/utils.py | 9 +- backend-python/utils/rwkv.py | 2 +- .../wkv_cuda_utils/wkv_cuda10_30.pyd | Bin 433152 -> 0 bytes backend-python/wkv_cuda_utils/wkv_cuda40.pyd | Bin 445440 -> 0 bytes .../wkv_cuda_utils/wkv_cuda_model.py | 734 ------- 18 files changed, 2922 insertions(+), 737 deletions(-) create mode 100644 backend-python/rwkv_pip/cuda/att_one.cu create mode 100644 backend-python/rwkv_pip/cuda/att_seq.cu create mode 100644 backend-python/rwkv_pip/cuda/element_wise.h create mode 100644 backend-python/rwkv_pip/cuda/ffn.cu create mode 100644 backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp create mode 100644 backend-python/rwkv_pip/cuda/operators.cu create mode 100644 backend-python/rwkv_pip/cuda/rwkv5.cu create mode 100644 backend-python/rwkv_pip/cuda/rwkv5_op.cpp create mode 100644 backend-python/rwkv_pip/cuda/util.h create mode 100644 backend-python/rwkv_pip/cuda/wrapper.cpp create mode 100644 backend-python/rwkv_pip/model.py delete mode 100644 backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd delete mode 100644 backend-python/wkv_cuda_utils/wkv_cuda40.pyd delete mode 100644 backend-python/wkv_cuda_utils/wkv_cuda_model.py diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt index f570558aca26dfd8c11dd1b732647122654f275c..9075ce7a6c410b6f8eee6596f4affded386a0130 100644 GIT binary patch delta 14 VcmZ3$x`=f`7_%XR@y2jZCIBE$1ML6+ delta 12 TcmZ3)x`1^;7^A_)NKYmJ8hZn| diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt index 5950fac5cea7d10329be665d3450fbe7b03e70a6..b3f8834dc172cbe1ecae4d02327f25222603ebc5 100644 GIT binary patch delta 14 VcmeBT?O~k|#%#!7yfNIE2>>1E1FHZ4 delta 12 TcmeBS?P8q}#%QoH(wGSV7;6J} diff --git a/backend-python/rwkv_pip/cuda/att_one.cu b/backend-python/rwkv_pip/cuda/att_one.cu new file mode 100644 index 0000000..f22858d --- /dev/null +++ b/backend-python/rwkv_pip/cuda/att_one.cu @@ -0,0 +1,124 @@ +#include "ATen/ATen.h" +#include +#include +#include + +#include "element_wise.h" +#include "util.h" + +// Equivalent Python code: +// ww = t_first + k +// p = torch.maximum(pp, ww) +// e1 = torch.exp(pp - p) +// e2 = torch.exp(ww - p) +// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) +// ww = t_decay + pp +// p = torch.maximum(ww, k) +// e1 = torch.exp(ww - p) +// e2 = torch.exp(k - p) +// t1 = e1 * aa + e2 * v +// t2 = e1 * bb + e2 +// r = r * wkv +// return t1, t2, p, r +struct WkvForwardOne { + const float *t_first; + const float *k; + const float *pp; + const float *aa; + const float *bb; + const float *t_decay; + const float *v; + /* out */ float *t1; + /* out */ float *t2; + /* out */ float *p; + /* in & out */ half *r; + + __device__ void operator()(int i) const { + float ww = t_first[i] + k[i]; + float pp_ = pp[i]; + float p_ = (pp_ > ww) ? pp_ : ww; + float e1 = expf(pp_ - p_); + float e2 = expf(ww - p_); + float aa_ = aa[i]; + float bb_ = bb[i]; + float v_ = v[i]; + r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2)))); + ww = t_decay[i] + pp_; + float k_ = k[i]; + p_ = (ww > k_) ? ww : k_; + e1 = expf(ww - p_); + e2 = expf(k_ - p_); + t1[i] = e1 * aa_ + e2 * v_; + t2[i] = e1 * bb_ + e2; + p[i] = p_; + } +}; + +/* + Equivalent Python code: + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) +*/ + +struct Mix { + const half *xx; + const half *sx; + const half *k_mix; + const half *v_mix; + const half *r_mix; + /* out */ half *kx; + /* out */ half *vx; + /* out */ half *rx; + + __device__ void operator()(int i) const { + half xx_ = xx[i]; + half sx_ = sx[i]; + half k_mix_ = k_mix[i]; + half v_mix_ = v_mix[i]; + half r_mix_ = r_mix[i]; + kx[i] = __hadd(__hmul(xx_, k_mix_), + __hmul(sx_, __hsub(__float2half(1), k_mix_))); + vx[i] = __hadd(__hmul(xx_, v_mix_), + __hmul(sx_, __hsub(__float2half(1), v_mix_))); + rx[i] = __hadd(__hmul(xx_, r_mix_), + __hmul(sx_, __hsub(__float2half(1), r_mix_))); + } +}; + +using torch::Tensor; + +void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); + +Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix, + Tensor v_mix, Tensor r_mix, Tensor kw, + /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw, + /* imm */ Tensor rx, Tensor ow, Tensor t_first, + /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb, + Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r, + /* out */ Tensor x_plus_out, /* out */ Tensor t1, + /* out */ Tensor t2, /* out */ Tensor p) { + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); + element_wise(Mix{data_ptr(xx), data_ptr(sx), + data_ptr(k_mix), data_ptr(v_mix), + data_ptr(r_mix), data_ptr(kx), + data_ptr(vx), data_ptr(rx)}, + x.numel()); + + gemm_fp16_cublas(kx, kw, k); + gemm_fp16_cublas(vx, vw, v); + gemm_fp16_cublas(rx, rw, r); + at::sigmoid_(r); + + element_wise(WkvForwardOne{data_ptr(t_first), data_ptr(k), + data_ptr(pp), data_ptr(aa), + data_ptr(bb), data_ptr(t_decay), + data_ptr(v), data_ptr(t1), + data_ptr(t2), data_ptr(p), + data_ptr(r)}, + x.numel()); + + gemm_fp16_cublas(r, ow, x_plus_out); + x_plus_out += x; + return xx; +} diff --git a/backend-python/rwkv_pip/cuda/att_seq.cu b/backend-python/rwkv_pip/cuda/att_seq.cu new file mode 100644 index 0000000..4d506a3 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/att_seq.cu @@ -0,0 +1,179 @@ +#include "ATen/ATen.h" +#include +#include +#include + +#include "util.h" +#include "element_wise.h" + +using torch::Tensor; + +void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); +void gemm_fp16_cublas(const void *a, const void *b, void *c, int m, + int n, int k, bool output_fp32); + +// based on `kernel_wkv_forward`, fusing more operations +__global__ void kernel_wkv_forward_new( + const int B, const int T, const int C, const float *__restrict__ const _w, + const float *__restrict__ const _u, const float *__restrict__ const _k, + const float *__restrict__ const _v, const half *__restrict__ const r, + half *__restrict__ const _y, float *__restrict__ const _aa, + float *__restrict__ const _bb, float *__restrict__ const _pp) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + const int _state_offset = _b * C + _c; + + float u = _u[_c]; + float w = _w[_c]; + const float *__restrict__ const k = _k + _offset; + const float *__restrict__ const v = _v + _offset; + half *__restrict__ const y = _y + _offset; + + float aa = _aa[_state_offset]; + float bb = _bb[_state_offset]; + float pp = _pp[_state_offset]; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = k[ii]; + const float vv = v[ii]; + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)); + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + _aa[_state_offset] = aa; + _bb[_state_offset] = bb; + _pp[_state_offset] = pp; +} + +void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k, + float *v, half *r, half *y, float *aa, float *bb, + float *pp) { + dim3 threadsPerBlock(min(C, 32)); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_wkv_forward_new<<>>(B, T, C, w, u, k, v, r, + y, aa, bb, pp); +} + +__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix, + const half *v_mix, const half *r_mix, + const int outer_size, const int inner_size, half *kx, + half *vx, half *rx) { + for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size; + idx2 += blockDim.x * gridDim.x) { + half k_mix_ = k_mix[idx2]; + half v_mix_ = v_mix[idx2]; + half r_mix_ = r_mix[idx2]; + for (int row = 0; row < outer_size; ++row) { + int idx1 = row * inner_size + idx2; + half xx_ = xx[idx1]; + half sx_ = sx[idx1]; + kx[idx1] = __hadd(__hmul(xx_, k_mix_), + __hmul(sx_, __hsub(__float2half(1), k_mix_))); + vx[idx1] = __hadd(__hmul(xx_, v_mix_), + __hmul(sx_, __hsub(__float2half(1), v_mix_))); + rx[idx1] = __hadd(__hmul(xx_, r_mix_), + __hmul(sx_, __hsub(__float2half(1), r_mix_))); + } + } +} + +void att_mix(const half *xx, const half *sx, const half *k_mix, + const half *v_mix, const half *r_mix, const int outer_size, + const int inner_size, half *kx, half *vx, half *rx) { + // 256 is good enough on most GPUs + const int32_t BLOCK_SIZE = 256; + assert(inner_size % BLOCK_SIZE == 0); + _att_mix<<>>( + xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx); +} + +struct InplaceSigmoid { + __device__ __forceinline__ half operator()(int i) const { + ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i])))); + } + half *ptr; +}; + +struct InplaceMul { + __device__ __forceinline__ half operator()(int i) const { + y[i] = __hmul(x[i], y[i]); + } + half *y; + half *x; +}; + +/* + Equivalent Python code: + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + k = gemm(kx, kw, output_dtype=torch.float32) + v = gemm(vx, vw, output_dtype=torch.float32) + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = gemm(r * sx, ow) + return x + out, xx[-1,:], aa, bb, pp +*/ +Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, + Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, + Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb, + Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) { + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); + sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0); + char* buf_ptr = (char*)buf.data_ptr(); + half* kx = (half*)buf_ptr; + half* vx = kx + x.numel(); + half* rx = vx + x.numel(); + half* wkv_y = rx + x.numel(); + att_mix(data_ptr(xx), data_ptr(sx), data_ptr(k_mix), + data_ptr(v_mix), data_ptr(r_mix), xx.size(0), xx.size(1), + kx, vx, rx); + float* k = reinterpret_cast(wkv_y + x.numel()); + float* v = k + x.size(0) * kw.size(1); + half* r = reinterpret_cast(v + x.size(0) * vw.size(1)); + + gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true); + gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true); + gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false); + element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1)); + cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr(t_decay), + data_ptr(t_first), k, v, r, + wkv_y, data_ptr(aa), + data_ptr(bb), data_ptr(pp)); + element_wise(InplaceMul{wkv_y, r}, x.numel()); + gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false); + x_plus_out += x; + return xx; +} diff --git a/backend-python/rwkv_pip/cuda/element_wise.h b/backend-python/rwkv_pip/cuda/element_wise.h new file mode 100644 index 0000000..eedc2f9 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/element_wise.h @@ -0,0 +1,21 @@ +#include +#include +#include + +template __global__ void _element_wise(Func func, int n) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; + i += blockDim.x * gridDim.x) { + func(i); + } +} + +// NOTE: packed data type (e.g. float4) is a overkill for current sizes +// (4096 in 7B model and 768 in 0.1B model), +// and is not faster than the plain float version. +template +void element_wise(Func func, int n) { + // 256 is good enough on most GPUs + const int32_t BLOCK_SIZE = 256; + assert(n % BLOCK_SIZE == 0); + _element_wise<<>>(func, n); +} diff --git a/backend-python/rwkv_pip/cuda/ffn.cu b/backend-python/rwkv_pip/cuda/ffn.cu new file mode 100644 index 0000000..c1c2c80 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/ffn.cu @@ -0,0 +1,165 @@ +#include "ATen/ATen.h" +#include +#include +#include + +#include "element_wise.h" +#include "util.h" + +using torch::Tensor; + +void gemm_fp16_cublas(const void *a, const void *b, void *c, int ori_m, + int ori_n, int ori_k, bool output_fp32); + +__global__ void _ffn_seq_mix(const half *xx, const half *sx, const half *k_mix, + const half *r_mix, const int outer_size, + const int inner_size, half *kx, half *rx) { + for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size; + idx2 += blockDim.x * gridDim.x) { + half k_mix_ = k_mix[idx2]; + half r_mix_ = r_mix[idx2]; + for (int row = 0; row < outer_size; ++row) { + int idx1 = row * inner_size + idx2; + half xx_ = xx[idx1]; + half sx_ = sx[idx1]; + kx[idx1] = __hadd(__hmul(xx_, k_mix_), + __hmul(sx_, __hsub(__float2half(1), k_mix_))); + rx[idx1] = __hadd(__hmul(xx_, r_mix_), + __hmul(sx_, __hsub(__float2half(1), r_mix_))); + } + } +} + +void ffn_seq_mix(const half *xx, const half *sx, const half *k_mix, + const half *r_mix, const int outer_size, const int inner_size, + half *kx, half *rx) { + // 256 is good enough on most GPUs + const int32_t BLOCK_SIZE = 256; + assert(inner_size % BLOCK_SIZE == 0); + _ffn_seq_mix<<>>( + xx, sx, k_mix, r_mix, outer_size, inner_size, kx, rx); +} + +struct InplaceSigmoid { + __device__ __forceinline__ void operator()(int i) const { + ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i])))); + } + half *ptr; +}; + +struct InplaceReLUAndSquare { + __device__ __forceinline__ void operator()(int i) const { + // __hmax is not defined in old cuda + if (__hgt(ptr[i], __float2half(0))) { + ptr[i] = __hmul(ptr[i], ptr[i]); + } else { + ptr[i] = __float2half(0); + } + } + half *ptr; +}; + +struct InplaceFma { + __device__ __forceinline__ void operator()(int i) const { + a[i] = __hfma(a[i], b[i], c[i]); + } + half *a; + const half *b; + const half *c; +}; + +/* + Equivalent Python code: + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + vx = torch.square(torch.relu(gemm(kx, kw))) + out = r * gemm(vx, vw) + return x + out, xx[-1,:] +*/ +Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, + Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, + /* imm */ Tensor buf, + /* out */ Tensor x_plus_out) { + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); + sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0); + char *buf_ptr = (char *)buf.data_ptr(); + half *kx = (half *)buf_ptr; + half *rx = kx + x.numel(); + half *vx = rx + x.numel(); + half *r = vx + x.size(0) * kw.size(1); + ffn_seq_mix(data_ptr(xx), data_ptr(sx), data_ptr(k_mix), + data_ptr(r_mix), xx.size(0), xx.size(1), kx, rx); + + gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), x.size(1), + false); + element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1)); + gemm_fp16_cublas(kx, kw.data_ptr(), vx, x.size(0), kw.size(1), x.size(1), + false); + element_wise(InplaceReLUAndSquare{vx}, x.size(0) * kw.size(1)); + gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), x.size(0), + vw.size(1), vw.size(0), false); + element_wise(InplaceFma{data_ptr(x_plus_out), r, data_ptr(x)}, + x_plus_out.numel()); + return xx; +} + +struct FfnOneMix { + __device__ __forceinline__ void operator()(int idx) { + half k_mix_ = k_mix[idx]; + half r_mix_ = r_mix[idx]; + half xx_ = xx[idx]; + half sx_ = sx[idx]; + kx[idx] = __hadd(__hmul(xx_, k_mix_), + __hmul(sx_, __hsub(__float2half(1), k_mix_))); + rx[idx] = __hadd(__hmul(xx_, r_mix_), + __hmul(sx_, __hsub(__float2half(1), r_mix_))); + } + half *k_mix; + half *r_mix; + half *xx; + half *sx; + half *kx; + half *rx; +}; + +/* + Equivalent Python code: + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + vx = torch.square(torch.relu(gemm(kx, kw))) + out = r * gemm(vx, vw) + return x + out, xx +*/ +Tensor ffn_one(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, + Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, + /* imm */ Tensor buf, + /* out */ Tensor x_plus_out) { + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); + char *buf_ptr = (char *)buf.data_ptr(); + half *kx = (half *)buf_ptr; + half *rx = kx + x.numel(); + half *vx = rx + x.numel(); + half *r = vx + x.size(0) * kw.size(1); + element_wise(FfnOneMix{data_ptr(k_mix), data_ptr(r_mix), + data_ptr(xx), data_ptr(sx), kx, rx}, + x.numel()); + // vector * matrix, so m = 1 + gemm_fp16_cublas(rx, rw.data_ptr(), r, 1, rw.size(1), rw.size(0), false); + element_wise(InplaceSigmoid{r}, rw.size(1)); + gemm_fp16_cublas(kx, kw.data_ptr(), vx, 1, kw.size(1), kw.size(0), false); + element_wise(InplaceReLUAndSquare{vx}, kw.size(1)); + gemm_fp16_cublas(vx, vw.data_ptr(), x_plus_out.data_ptr(), 1, vw.size(1), + vw.size(0), false); + element_wise(InplaceFma{data_ptr(x_plus_out), r, data_ptr(x)}, + x_plus_out.numel()); + return xx; +} diff --git a/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp b/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp new file mode 100644 index 0000000..db48fcf --- /dev/null +++ b/backend-python/rwkv_pip/cuda/gemm_fp16_cublas.cpp @@ -0,0 +1,86 @@ +#include +#include +#include +#include +#include + +#define CUBLAS_CHECK(condition) \ + for (cublasStatus_t _cublas_check_status = (condition); \ + _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ + throw std::runtime_error("cuBLAS error " + \ + std::to_string(_cublas_check_status) + " at " + \ + std::to_string(__LINE__)); + +#define CUDA_CHECK(condition) \ + for (cudaError_t _cuda_check_status = (condition); \ + _cuda_check_status != cudaSuccess;) \ + throw std::runtime_error( \ + "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ + " at " + std::to_string(__LINE__)); + +cublasHandle_t get_cublas_handle() { + static cublasHandle_t cublas_handle = []() { + cublasHandle_t handle = nullptr; + CUBLAS_CHECK(cublasCreate(&handle)); +#if CUDA_VERSION < 11000 + CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); +#else + CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); +#endif // CUDA_VERSION < 11000 + return handle; + }(); + return cublas_handle; +} + +/* + NOTE: blas gemm is column-major by default, but we need row-major output. + The data of row-major, transposed matrix is exactly the same as the + column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T + */ +void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { + const auto cuda_data_type = CUDA_R_16F; + const auto cuda_c_data_type = + c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; + const auto compute_type = CUDA_R_32F; + const float sp_alpha = 1.f; + // swap a and b, and use CUBLAS_OP_N. see the notes above + std::swap(a, b); + const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; + const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; + // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap, + // negative axis is used because of the existence of batch matmul. + const int m = a.size(-1); + const int k = a.size(-2); + const int n = b.size(-2); + const int cublas_lda = m; + const int cublas_ldb = k; + const int cublas_ldc = m; + cublasHandle_t cublas_handle = get_cublas_handle(); + +#if CUDA_VERSION >= 11000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; +#else + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; +#endif + const float sp_beta = 0.f; + if (a.sizes().size() == 2 && b.sizes().size() == 2) { + CUBLAS_CHECK(cublasGemmEx( + cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, + a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, + cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, + compute_type, algo)); + } else { + // batch matmul + assert(a.sizes().size() == 3 && b.sizes().size() == 3); + + const long long int cublas_stride_a = m * k; + const long long int cublas_stride_b = k * n; + const long long int cublas_stride_c = m * n; + CUBLAS_CHECK(cublasGemmStridedBatchedEx( + cublas_handle, cublas_trans_a, cublas_trans_b, m, + n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, + cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, + &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, + a.size(0), compute_type, algo)); + } +} diff --git a/backend-python/rwkv_pip/cuda/operators.cu b/backend-python/rwkv_pip/cuda/operators.cu new file mode 100644 index 0000000..fa5a44f --- /dev/null +++ b/backend-python/rwkv_pip/cuda/operators.cu @@ -0,0 +1,246 @@ +#include +#include +#include "ATen/ATen.h" +#include +#define MIN_VALUE (-1e38) +typedef at::Half fp16; +__half *cast(fp16 *ptr) { + return reinterpret_cast<__half *>(ptr); +} + +template +__global__ void kernel_wkv_forward(const int B, const int T, const int C, + const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + const int _state_offset = _b * C + _c; + + float u = _u[_c]; + float w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + float aa = _aa[_state_offset]; + float bb = _bb[_state_offset]; + float pp = _pp[_state_offset]; + for (int i = 0; i < T; i++) { + const int ii = i * C; + const float kk = float(k[ii]); + const float vv = float(v[ii]); + float ww = u + kk; + float p = max(pp, ww); + float e1 = exp(pp - p); + float e2 = exp(ww - p); + y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2)); + ww = w + pp; + p = max(ww, kk); + e1 = exp(ww - p); + e2 = exp(kk - p); + aa = e1 * aa + e2 * vv; + bb = e1 * bb + e2; + pp = p; + } + _aa[_state_offset] = aa; + _bb[_state_offset] = bb; + _pp[_state_offset] = pp; +} + +template +void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) { + dim3 threadsPerBlock( min(C, 32) ); + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_wkv_forward<<>>(B, T, C, w, u, k, v, y, aa, bb, pp); +} + +template void cuda_wkv_forward( + int B, int T, int C, + float *w, float *u, fp16 *k, fp16 *v, fp16 *y, + float *aa, float *bb, float *pp); +template void cuda_wkv_forward( + int B, int T, int C, + float *w, float *u, float *k, float *v, float *y, + float *aa, float *bb, float *pp); + +__global__ void kernel_mm_seq_fp32i8( + const int B, const int N, const int M, + const float *__restrict__ const x, const int x_stride, + const uint8_t *__restrict__ const w, const int w_stride, + const float *__restrict__ const mx, + const float *__restrict__ const rx, + const float *__restrict__ const my, + const float *__restrict__ const ry, + float *__restrict__ const y, const int y_stride) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + const int k = blockIdx.y * blockDim.y + threadIdx.y; + + if (i < B && k < M) { + float y_local = 0; + for (int j = 0; j < N; ++j) { + y_local += x[i * x_stride + j] * ( + (float(w[j * w_stride + k]) + 0.5f) + * rx[k] * ry[j] + mx[k] + my[j] + ); + } + y[i * y_stride + k] = y_local; + } +} + +template +void cuda_mm8_seq(int B, int N, int M, + F *x, int x_stride, + uint8_t *w, int w_stride, + F *mx, F *rx, + F *my, F *ry, + F *y, int y_stride); + +template <> +void cuda_mm8_seq(int B, int N, int M, + float *x, int x_stride, + uint8_t *w, int w_stride, + float *mx, float *rx, + float *my, float *ry, + float *y, int y_stride) { + dim3 blockSize(1, 128); + dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); + kernel_mm_seq_fp32i8<<>>( + B, N, M, x, x_stride, w, w_stride, + mx, rx, my, ry, y, y_stride); +} + +__global__ void kernel_mm_seq_fp16i8( + const int B, const int N, const int M, + const __half *__restrict__ const x, const int x_stride, + const uint8_t *__restrict__ const w, const int w_stride, + const __half *__restrict__ const mx, + const __half *__restrict__ const rx, + const __half *__restrict__ const my, + const __half *__restrict__ const ry, + __half *__restrict__ const y, const int y_stride) { + + const int i = blockIdx.x * blockDim.x + threadIdx.x; + const int k = blockIdx.y * blockDim.y + threadIdx.y; + + if (i < B && k < M) { + float y_local = 0; + for (int j = 0; j < N; ++j) { + y_local += __half2float(x[i * x_stride + j]) * ( + (float(w[j * w_stride + k]) + 0.5f) + * __half2float(rx[k]) * __half2float(ry[j]) + + __half2float(mx[k]) + __half2float(my[j]) + ); + } + y[i * y_stride + k] = __float2half(y_local); + } +} + +template <> +void cuda_mm8_seq(int B, int N, int M, + fp16 *x, int x_stride, + uint8_t *w, int w_stride, + fp16 *mx, fp16 *rx, + fp16 *my, fp16 *ry, + fp16 *y, int y_stride) { + dim3 blockSize(1, 128); + dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); + kernel_mm_seq_fp16i8<<>>( + B, N, M, cast(x), x_stride, w, w_stride, + cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride); +} + +#define MM8_ONE_JSPLIT 24 +#define MM8_ONE_TILE 1024 + +__global__ void kernel_mm_one_fp32i8( + const int N, const int M, + const float *__restrict__ const x, + const uint8_t *__restrict__ const w, const int w_stride, + const float *__restrict__ const mx, + const float *__restrict__ const rx, + const float *__restrict__ const my, + const float *__restrict__ const ry, + float *__restrict__ const y) { + + const int k = blockIdx.y * blockDim.y + threadIdx.y; + const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); + const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); + + if (k < M) { + float y_local = 0; + for (int j = j0; j < j1; ++j) { + y_local += x[j] * ( + (float(w[j * w_stride + k]) + 0.5f) + * rx[k] * ry[j] + mx[k] + my[j] + ); + } + atomicAdd(&y[k], y_local); + } +} + +template +void cuda_mm8_one(int N, int M, + F *x, + uint8_t *w, int w_stride, + F *mx, F *rx, + F *my, F *ry, + float *y); + +template <> +void cuda_mm8_one(int N, int M, + float *x, + uint8_t *w, int w_stride, + float *mx, float *rx, + float *my, float *ry, + float *y) { + dim3 blockSize(1, MM8_ONE_TILE); + dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); + kernel_mm_one_fp32i8<<>>( + N, M, x, w, w_stride, + mx, rx, my, ry, y); +} + +__global__ void kernel_mm_one_fp16i8( + const int N, const int M, + const __half *__restrict__ const x, + const uint8_t *__restrict__ const w, const int w_stride, + const __half *__restrict__ const mx, + const __half *__restrict__ const rx, + const __half *__restrict__ const my, + const __half *__restrict__ const ry, + float *__restrict__ const y) { + + const int k = blockIdx.y * blockDim.y + threadIdx.y; + const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); + const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); + + if (k < M) { + float y_local = 0; + for (int j = j0; j < j1; ++j) { + y_local += __half2float(x[j]) * ( + (float(w[j * w_stride + k]) + 0.5f) + * __half2float(rx[k]) * __half2float(ry[j]) + + __half2float(mx[k]) + __half2float(my[j]) + ); + } + atomicAdd(&y[k], y_local); + } +} + +template <> +void cuda_mm8_one(int N, int M, + fp16 *x, + uint8_t *w, int w_stride, + fp16 *mx, fp16 *rx, + fp16 *my, fp16 *ry, + float *y) { + dim3 blockSize(1, MM8_ONE_TILE); + dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); + kernel_mm_one_fp16i8<<>>( + N, M, cast(x), w, w_stride, + cast(mx), cast(rx), cast(my), cast(ry), y); +} diff --git a/backend-python/rwkv_pip/cuda/rwkv5.cu b/backend-python/rwkv_pip/cuda/rwkv5.cu new file mode 100644 index 0000000..a7f13c3 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv5.cu @@ -0,0 +1,88 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _w += h*_N_; + _u += h*_N_; + _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + + float state[_N_]; + #pragma unroll + for (int j = 0; j < _N_; j++) + state[j] = _state[j]; + + __syncthreads(); + u[i] = float(_u[i]); + w[i] = _w[i]; + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } + #pragma unroll + for (int j = 0; j < _N_; j++) + _state[j] = state[j]; +} + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} diff --git a/backend-python/rwkv_pip/cuda/rwkv5_op.cpp b/backend-python/rwkv_pip/cuda/rwkv5_op.cpp new file mode 100644 index 0000000..5471bcf --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv5_op.cpp @@ -0,0 +1,30 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); + +void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16"); + m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16"); + m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32"); +} +TORCH_LIBRARY(rwkv5, m) { + m.def("forward_bf16", forward_bf16); + m.def("forward_fp16", forward_fp16); + m.def("forward_fp32", forward_fp32); +} diff --git a/backend-python/rwkv_pip/cuda/util.h b/backend-python/rwkv_pip/cuda/util.h new file mode 100644 index 0000000..f00af22 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/util.h @@ -0,0 +1,7 @@ +#include "ATen/ATen.h" +#include + +template T *data_ptr(torch::Tensor x) { return x.data_ptr(); } +template <> inline half *data_ptr(torch::Tensor x) { + return reinterpret_cast(x.data_ptr()); +} diff --git a/backend-python/rwkv_pip/cuda/wrapper.cpp b/backend-python/rwkv_pip/cuda/wrapper.cpp new file mode 100644 index 0000000..5e91eb5 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/wrapper.cpp @@ -0,0 +1,141 @@ +#include +#include "ATen/ATen.h" +#include +#include + +typedef at::Half fp16; + +template +void cuda_wkv_forward(int B, int T, int C, + float *w, float *u, F *k, F *v, F *y, + float *aa, float *bb, float *pp); +template +void cuda_mm8_seq(int B, int N, int M, + F *x, int x_stride, + uint8_t *w, int w_stride, + F *mx, F *rx, + F *my, F *ry, + F *y, int y_stride); +template +void cuda_mm8_one(int N, int M, + F *x, + uint8_t *w, int w_stride, + F *mx, F *rx, + F *my, F *ry, + float *y); + +void wkv_forward(int64_t B, int64_t T, int64_t C, + torch::Tensor &w, torch::Tensor &u, + torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, + torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); + switch (k.scalar_type()) { + case c10::ScalarType::Half: + cuda_wkv_forward(B, T, C, + w.data_ptr(), u.data_ptr(), + k.data_ptr(), v.data_ptr(), y.data_ptr(), + aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); + break; + case c10::ScalarType::Float: + cuda_wkv_forward(B, T, C, + w.data_ptr(), u.data_ptr(), + k.data_ptr(), v.data_ptr(), y.data_ptr(), + aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); + break; + default: + assert(false && "Only FP16 and FP32 are currently supported"); + } +} + +void mm8_seq(int64_t B, int64_t N, int64_t M, + torch::Tensor &x, torch::Tensor &w, + torch::Tensor &mx, torch::Tensor &rx, + torch::Tensor &my, torch::Tensor &ry, + torch::Tensor &y) { + assert(x.stride(1) == 1); + assert(w.stride(1) == 1); + assert(mx.stride(0) == 1 && rx.stride(0) == 1); + assert(my.stride(0) == 1 && ry.stride(0) == 1); + assert(y.stride(1) == 1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); + switch (x.scalar_type()) { + case c10::ScalarType::Half: + cuda_mm8_seq( + B, N, M, + x.data_ptr(), x.stride(0), + w.data_ptr(), w.stride(0), + mx.data_ptr(), rx.data_ptr(), + my.data_ptr(), ry.data_ptr(), + y.data_ptr(), y.stride(0)); + break; + case c10::ScalarType::Float: + cuda_mm8_seq( + B, N, M, + x.data_ptr(), x.stride(0), + w.data_ptr(), w.stride(0), + mx.data_ptr(), rx.data_ptr(), + my.data_ptr(), ry.data_ptr(), + y.data_ptr(), y.stride(0)); + break; + default: + assert(false && "Only FP16 and FP32 are currently supported"); + } +} +void mm8_one(int64_t N, int64_t M, + torch::Tensor &x, torch::Tensor &w, + torch::Tensor &mx, torch::Tensor &rx, + torch::Tensor &my, torch::Tensor &ry, + torch::Tensor &y) { + assert(x.stride(0) == 1); + assert(w.stride(1) == 1); + assert(mx.stride(0) == 1 && rx.stride(0) == 1); + assert(my.stride(0) == 1 && ry.stride(0) == 1); + assert(y.stride(0) == 1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); + switch (x.scalar_type()) { + case c10::ScalarType::Half: + cuda_mm8_one( + N, M, + x.data_ptr(), + w.data_ptr(), w.stride(0), + mx.data_ptr(), rx.data_ptr(), + my.data_ptr(), ry.data_ptr(), + y.data_ptr()); + break; + case c10::ScalarType::Float: + cuda_mm8_one( + N, M, + x.data_ptr(), + w.data_ptr(), w.stride(0), + mx.data_ptr(), rx.data_ptr(), + my.data_ptr(), ry.data_ptr(), + y.data_ptr()); + break; + default: + assert(false && "Only FP16 and FP32 are currently supported"); + } +} + +using torch::Tensor; + +#ifndef DISABLE_CUBLAS_GEMM +void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); +#endif + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("wkv_forward", &wkv_forward, "wkv forward"); + m.def("mm8_seq", &mm8_seq, "mm8 seq"); + m.def("mm8_one", &mm8_one, "mm8 one"); +#ifndef DISABLE_CUBLAS_GEMM + m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas"); +#endif +} + +TORCH_LIBRARY(rwkv, m) { + m.def("wkv_forward", wkv_forward); + m.def("mm8_seq", mm8_seq); + m.def("mm8_one", mm8_one); +#ifndef DISABLE_CUBLAS_GEMM + m.def("gemm_fp16_cublas", gemm_fp16_cublas); +#endif +} diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py new file mode 100644 index 0000000..75c543e --- /dev/null +++ b/backend-python/rwkv_pip/model.py @@ -0,0 +1,1827 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +from typing import Optional +import types, gc, os, time, re +import torch +from torch.nn import functional as F + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +current_path = os.path.dirname(os.path.abspath(__file__)) + +######################################################################################################## + +if os.environ.get("RWKV_JIT_ON") != "0": + os.environ["RWKV_JIT_ON"] = "1" + MyModule = torch.jit.ScriptModule + MyFunction = torch.jit.script_method + MyStatic = torch.jit.script +else: + MyModule = torch.nn.Module + + def __nop(ob): + return ob + + MyFunction = __nop + MyStatic = __nop + +if os.environ.get("RWKV_CUDA_ON") == "1": + from torch.utils.cpp_extension import load + + try: + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + f"{current_path}/cuda/gemm_fp16_cublas.cpp", + ], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = False + except: + print( + "Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow." + ) + load( + name=f"wkv_cuda", + sources=[ + f"{current_path}/cuda/wrapper.cpp", + f"{current_path}/cuda/operators.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "--use_fast_math", + "-O3", + "--extra-device-vectorization", + ], + extra_cflags=["-DDISABLE_CUBLAS_GEMM"], + is_python_module=False, + ) + DISABLE_CUBLAS_GEMM = True + + @MyStatic + def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): + assert 1 * C % min(C, 32) == 0 + assert ( + k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32 + ) + assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32 + w = w.contiguous() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + y = torch.empty( + (T, C), + device=w.device, + memory_format=torch.contiguous_format, + dtype=k.dtype, + ) + torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) + return y, aa, bb, pp + + @MyStatic + def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == (B, N) + assert w.shape == (N, M) + assert rx.shape == mx.shape == (M,) + assert ry.shape == my.shape == (N, 1) + y = torch.empty((B, M), device=w.device, dtype=x.dtype) + torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) + return y + + @MyStatic + def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): + assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype + assert x.dtype == torch.float32 or x.dtype == torch.float16 + assert w.dtype == torch.uint8 + assert x.shape == (N,) + assert w.shape == (N, M) + assert rx.shape == mx.shape == (M,) + assert ry.shape == my.shape == (N, 1) + y = torch.zeros((M,), device=w.device, dtype=torch.float32) + torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) + return y.to(dtype=x.dtype) + +else: + os.environ["RWKV_CUDA_ON"] = "0" + +if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM: + + @MyStatic + def gemm(a, b, output_dtype: Optional[torch.dtype] = None): + if output_dtype is None: + output_dtype = a.dtype + if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda": + if len(a.shape) == 1: + assert len(b.shape) == 2 + c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device) + a = a.unsqueeze(0) + else: + assert len(a.shape) == len(b.shape) + assert len(a.shape) == 2 or len(a.shape) == 3 + # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit + if len(a.shape) == 2: + c = torch.empty( + (a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device + ) + else: + c = torch.empty( + (a.shape[0], a.shape[1], b.shape[-1]), + dtype=output_dtype, + device=a.device, + ) + torch.ops.rwkv.gemm_fp16_cublas(a, b, c) + return c + else: + return (a @ b).to(output_dtype) + +else: + + def gemm(a, b, output_dtype: Optional[torch.dtype] = None): + if output_dtype is None: + output_dtype = a.dtype + return (a @ b).to(output_dtype) + + +######################################################################################################## + + +class RWKV(MyModule): + def __init__(self, model, strategy, verbose=True, convert_and_save_and_exit=None): + super().__init__() + if verbose: + prxxx = lambda *args, **kwargs: print(*args, **kwargs) + else: + prxxx = lambda *args, **kwargs: None + + STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" + if not re.match(STRATEGY_REGEX, strategy): + raise ValueError( + "Invalid strategy. Please read https://pypi.org/project/rwkv/" + ) + + strategy = ("->".join([x.strip() for x in strategy.split("->")])).replace( + "->", " -> " + ) + self.args = types.SimpleNamespace() + args = self.args + args.MODEL_NAME = model + args.strategy_string = strategy + + # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) + try: + self.RESCALE_LAYER = int( + os.environ["RWKV_RESCALE_LAYER"] + ) # !!! NOTE: SEEMS YOU SHOULD SET IT TO 999 (disable) FOR RWKV-MUSIC MODELS !!! + except: + self.RESCALE_LAYER = 6 if "fp16" in strategy else 0 + prxxx( + f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n' + ) + + args.MODEL_NAME = args.MODEL_NAME.strip() + if not args.MODEL_NAME.endswith(".pth"): + args.MODEL_NAME += ".pth" + prxxx(f"Loading {args.MODEL_NAME} ...") + with torch.no_grad(): + self.w = torch.load( + args.MODEL_NAME, map_location="cpu" + ) # load model to CPU first + gc.collect() + w = self.w + + ALREADY_CONVERTED = False + if "_strategy" in w: + ALREADY_CONVERTED = True + assert ( + convert_and_save_and_exit == None + ) # you should only convert a raw model + prxxx( + f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n" + ) + assert ( + w["_strategy"] == args.strategy_string + ) # if you are using a new strategy, re-convert the model + assert ( + float(w["_version"]) >= 0.7 + ) # sometimes you should re-convert using latest convert_model.py + assert ( + w["_rescale_layer"] == self.RESCALE_LAYER + ) # must use same RESCALE_LAYER to avoid mistakes + del w["_strategy"] + del w["_version"] + del w["_rescale_layer"] + + args.n_embd = w["emb.weight"].shape[1] + args.n_att = w["blocks.0.att.key.weight"].shape[ + 0 + ] # note: transposed matrix + args.n_ffn = w["blocks.0.ffn.key.weight"].shape[ + 0 + ] # note: transposed matrix + args.n_layer = 0 + keys = list(w.keys()) + self.version = 4 + for x in keys: + layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 + args.n_layer = max(args.n_layer, layer_id + 1) + if "ln_x" in x: + self.version = max(5, self.version) + if "gate.weight" in x: + self.version = max(5.1, self.version) + if int(self.version) == 5 and "att.time_decay" in x: + args.n_head = w[x].shape[0] + if len(w[x].shape) > 1: + if w[x].shape[1] > 1: + self.version = max(5.2, self.version) + + ####################### Compute strategy + + s = [x.strip().split(" ") for x in strategy.split("->")] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith("fp32"): + si[1] = [torch.float] + elif si1.startswith("fp16"): + si[1] = [torch.float16] + elif si1.startswith("bf16"): + si[1] = [torch.bfloat16] + if si1.endswith("i8"): + si[1] += [torch.uint8] + else: + si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith("*") + if ss.endswith("+"): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s) - 1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + prxxx(f"Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)") + for i in range(len(s)): + ss = s[i] + if i != stream_i: + prxxx( + f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers' + ) + else: + prxxx( + f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers' + ) + plan[i] += 0 if i == 0 else plan[i - 1] + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + prxxx( + f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}", + end=" ", + ) + prxxx() + + ####################### Load weights to self.w + + if not ALREADY_CONVERTED: + try: # precompute embedding + w["emb.weight"] = F.layer_norm( + w["emb.weight"], + (args.n_embd,), + weight=w["blocks.0.ln0.weight"], + bias=w["blocks.0.ln0.bias"], + ) + except: + w["emb.weight"] = F.layer_norm( + w["emb.weight"].float(), + (args.n_embd,), + weight=w["blocks.0.ln0.weight"].float(), + bias=w["blocks.0.ln0.bias"].float(), + ) + del w["blocks.0.ln0.weight"] + del w["blocks.0.ln0.bias"] + + print_need_newline = False + + REAL_TIME_FIRST = False + for x in list(w.keys()): + if ".time_faaaa" in x: + REAL_TIME_FIRST = True + if REAL_TIME_FIRST: + w = { + k.replace(".time_faaaa", ".time_first") + if ".time_faaaa" in k + else k: v + for k, v in w.items() + } + self.w = w + + keys = list(w.keys()) + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0 + if ("ln_out." in x) or ("head." in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if not ALREADY_CONVERTED: + if self.RESCALE_LAYER > 0: + if "att.output.weight" in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if "ffn.value.weight" in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if ".time_" in x: + w[x] = w[x].squeeze() + if ( + "key.weight" in x + or "value.weight" in x + or "receptance.weight" in x + or "gate.weight" in x + or "output.weight" in x + or "head.weight" in x + ): + w[x] = w[x].t() + + if ".time_decay" in x: # need fp32 for this + if self.version == 4: + w[x] = -torch.exp(w[x].float()) + elif int(self.version) == 5: + w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1) + if self.version == 5.2: + w[x] = w[x].reshape(args.n_head, -1, 1) + elif ".time_first" in x: # need fp32 for this + if self.version == 4: + w[x] = w[x].float() + elif int(self.version) == 5: + if REAL_TIME_FIRST: + w[x] = w[x].float().reshape(-1, 1, 1) + else: + w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1) + if self.version == 5.2: + w[x] = w[x].reshape(args.n_head, -1, 1) + elif ".ln_x" in x: # need fp32 for group_norm + w[x] = w[x].float() + else: + if (len(w[x].shape) == 2) and ("emb" not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + + if w[x].shape[0] > w[x].shape[1]: + w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x + "_my"] + w[x + "_mx"] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x + "_mx"] + w[x + "_rx"] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x + "_rx"] + w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x + "_ry"] + else: + w[x + "_mx"] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x + "_mx"] + w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x + "_my"] + w[x + "_rx"] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x + "_rx"] + w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x + "_ry"] + + w[x] = torch.clip( + torch.floor(w[x] * 256), min=0, max=255 + ).to(dtype=torch.uint8) + w[x + "_mx"] = w[x + "_mx"].to(dtype=ATYPE).contiguous() + w[x + "_rx"] = ( + (w[x + "_rx"] / 16).to(dtype=ATYPE).contiguous() + ) + w[x + "_my"] = w[x + "_my"].to(dtype=ATYPE).contiguous() + w[x + "_ry"] = ( + (w[x + "_ry"] / 16).to(dtype=ATYPE).contiguous() + ) + else: + w[x] = w[x].to(dtype=ATYPE) + + if convert_and_save_and_exit == None: + if "emb." in x: + w[x] = w[x].contiguous() + elif (dd.stream) and ( + x.endswith("key.weight") + or x.endswith("value.weight") + or x.endswith("receptance.weight") + or x.endswith("output.weight") + ): + try: + w[x] = ( + w[x].contiguous().pin_memory() + ) # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print( + "Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower." + ) + elif DEVICE != "cpu": + w[x] = w[x].to(device=DEVICE).contiguous() + + if (dd.stream) or (DEVICE != "cpu"): + try: + w[x + "_mx"] = w[x + "_mx"].to(device=DEVICE).contiguous() + w[x + "_rx"] = w[x + "_rx"].to(device=DEVICE).contiguous() + w[x + "_my"] = w[x + "_my"].to(device=DEVICE).contiguous() + w[x + "_ry"] = w[x + "_ry"].to(device=DEVICE).contiguous() + except: + pass + + if "ffn.value.weight" in x: + gc.collect() + if "cuda" in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer - 1: + if print_need_newline: + prxxx("\n", end="") + print_need_newline = False + dt = str(w[x].dtype).replace("torch.", "") + dt = ( + dt.replace("float32", "f32") + .replace("bfloat16", "bf16") + .replace("float16", "f16") + .replace("uint8", "i8") + ) + prxxx( + x.ljust(32), + dt.rjust(4), + str(w[x].device).rjust(8), + shape, + " (pinned)" if w[x].is_pinned() else "", + ) + else: + print_need_newline = True + prxxx(".", end="", flush=True) + + if convert_and_save_and_exit: + w["_strategy"] = args.strategy_string + w["_rescale_layer"] = self.RESCALE_LAYER + w["_version"] = "0.7" + if not convert_and_save_and_exit.endswith(".pth"): + convert_and_save_and_exit += ".pth" + prxxx(f"Saving to {convert_and_save_and_exit}...") + torch.save(w, convert_and_save_and_exit) + prxxx(f"Converted and saved. Now this will exit.") + exit(0) + + if self.version == 5.2: + assert ( + os.environ["RWKV_CUDA_ON"] == "1" + ) # latest RWKV-5 requires os.environ["RWKV_CUDA_ON"] == '1' (will fix soon) + HEAD_SIZE = args.n_att // args.n_head + rwkv5 = load( + name="rwkv5", + sources=[ + f"{current_path}/cuda/rwkv5_op.cpp", + f"{current_path}/cuda/rwkv5.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + ], + ) + + class RWKV_5(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, state, r, k, v, w, u): + with torch.no_grad(): + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert state.dtype == torch.float32 + assert w.dtype == torch.float32 + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + assert state.is_contiguous() + + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float16: + rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float32: + rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) + return y, state + + self.RWKV_5 = RWKV_5 + + gc.collect() + if "cuda" in args.strategy_string: + torch.cuda.empty_cache() + + def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): + return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) + + @MyFunction + def torch_mm8_seq(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + @MyFunction + def torch_mm8_one(self, x, w, mx, rx, my, ry): + return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) + + if os.environ.get("RWKV_CUDA_ON") == "1": + + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + if w.device.type == "cuda" and x.dtype == torch.float16: + B, N, M = x.shape[0], w.shape[0], w.shape[1] + return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) + else: + return self.torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + if w.device.type == "cuda": + N, M = w.shape[0], w.shape[1] + return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) + else: + return self.torch_mm8_one(x, w, mx, rx, my, ry) + + else: + + @MyFunction + def mm8_seq(self, x, w, mx, rx, my, ry): + return self.torch_mm8_seq(x, w, mx, rx, my, ry) + + @MyFunction + def mm8_one(self, x, w, mx, rx, my, ry): + return self.torch_mm8_one(x, w, mx, rx, my, ry) + + ######################################################################################################## + + @MyFunction + def ffn_one( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + vx = torch.square(torch.relu(gemm(kx, kw))) + out = r * gemm(vx, vw) + return x + out, xx + + @MyFunction + def ffn_one_i8( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx + + ######################################################################################################## + + @MyFunction + def ffn_seq( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + vx = torch.square(torch.relu(gemm(kx, kw))) + out = r * gemm(vx, vw) + return x + out, xx[-1, :] + + @MyFunction + def ffn_seq_i8( + self, + x, + sx, + ln_w, + ln_b, + k_mix, + r_mix, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) + out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) + return x + out, xx[-1, :] + + ######################################################################################################## + + @MyFunction + def att_one( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + k = gemm(kx, kw, output_dtype=torch.float32) + v = gemm(vx, vw, output_dtype=torch.float32) + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = gemm(r * wkv, ow) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + @MyFunction + def att_one_i8( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) + k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() + v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + ######################################################################################################## + + @MyFunction + def att_seq( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + k = gemm(kx, kw, output_dtype=torch.float32) + v = gemm(vx, vw, output_dtype=torch.float32) + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = gemm(r * sx, ow) + return x + out, xx[-1, :], aa, bb, pp + + @MyFunction + def att_seq_i8( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) + return x + out, xx[-1, :], aa, bb, pp + + ######################################################################################################## + + @MyFunction + def att_one_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + + r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) + k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) + v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) + + a = gemm(k, v) + out = r @ (t_first * a + s) + s = a + t_decay * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + ).squeeze(0) + out = out.to(dtype=x.dtype) + out = gemm(out, ow) + + return x + out, xx, s + + @MyFunction + def att_seq_v5( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + T = x.shape[0] + + w = t_decay.reshape(-1, 1) + u = t_first.reshape(-1, 1) + ws = w.pow(T).reshape(H, 1, 1) + ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) + w = w.repeat(1, T).pow(ind) + wk = w.reshape(H, 1, T) + wb = wk.transpose(-2, -1).flip(1) + w = torch.cat([w[:, 1:], u], dim=1) + w = F.pad(w, (0, T)) + w = torch.tile(w, [T]) + w = w[:, :-T].reshape(-1, T, 2 * T - 1) + w = w[:, :, T - 1 :].reshape(H, T, T) + + r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + k = ( + gemm(kx, kw, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + .transpose(-2, -1) + ) + v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + + out = ((r @ k) * w) @ v + (r @ s) * wb + s = ws * s + (k * wk) @ v + + out = out.transpose(0, 1).contiguous().reshape(T, H * S) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.to(dtype=x.dtype) + out = gemm(out, ow) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_one_v5_1( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + + r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) + k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) + v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) + g = F.silu(gemm(gx, gw)) + + a = gemm(k, v) + out = r @ (t_first * a + s) + s = a + t_decay * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + ).squeeze(0) + out = out.to(dtype=x.dtype) * g + out = gemm(out, ow) + + return x + out, xx, s + + @MyFunction + def att_seq_v5_1( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + S = x.shape[-1] // H + T = x.shape[0] + + w = t_decay.reshape(-1, 1) + u = t_first.reshape(-1, 1) + ws = w.pow(T).reshape(H, 1, 1) + ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) + w = w.repeat(1, T).pow(ind) + wk = w.reshape(H, 1, T) + wb = wk.transpose(-2, -1).flip(1) + w = torch.cat([w[:, 1:], u], dim=1) + w = F.pad(w, (0, T)) + w = torch.tile(w, [T]) + w = w[:, :-T].reshape(-1, T, 2 * T - 1) + w = w[:, :, T - 1 :].reshape(H, T, T) + + r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + k = ( + gemm(kx, kw, output_dtype=torch.float32) + .view(T, H, S) + .transpose(0, 1) + .transpose(-2, -1) + ) + v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) + g = F.silu(gemm(gx, gw)) + + out = ((r @ k) * w) @ v + (r @ s) * wb + s = ws * s + (k * wk) @ v + + out = out.transpose(0, 1).contiguous().reshape(T, H * S) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.to(dtype=x.dtype) * g + out = gemm(out, ow) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + def att_seq_v5_2( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + gx = xx * g_mix + sx * (1 - g_mix) + + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r = gemm(rx, rw, output_dtype=torch.float32) + k = gemm(kx, kw, output_dtype=torch.float32) + v = gemm(vx, vw, output_dtype=torch.float32) + g = F.silu(gemm(gx, gw)) + + out, s = self.RUN_RWKV_5( + 1, + T, + self.args.n_att, + H, + s.transpose(-1, -2).contiguous(), + r, + k, + v, + w=t_decay, + u=t_first, + ) + s = s.transpose(-1, -2) + + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.to(dtype=x.dtype) * g + out = gemm(out, ow) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + if os.environ["RWKV_CUDA_ON"] == "1": + + @MyFunction + def cuda_att_seq( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + T, C = x.shape + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(gemm(rx, rw)) + k = gemm(kx, kw, output_dtype=torch.float32) + v = gemm(vx, vw, output_dtype=torch.float32) + y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp) + + out = gemm(r * y.to(x.dtype), ow) + return x + out, xx[-1, :], aa, bb, pp + + @MyFunction + def cuda_att_seq_i8( + self, + x, + sx, + aa, + bb, + pp, + ln_w, + ln_b, + k_mix, + v_mix, + r_mix, + t_decay, + t_first, + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ): + T, C = x.shape + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) + k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) + v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) + y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) + + out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) + return x + out, xx[-1, :], aa, bb, pp + + ######################################################################################################## + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + if self.version == 4: + state = [None] * args.n_layer * 5 + for i in range( + args.n_layer + ): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 5 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + state[i * 5 + 1] = torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 5 + 2] = torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 5 + 3] = ( + torch.zeros( + args.n_att, + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + - 1e30 + ) + state[i * 5 + 4] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + elif int(self.version) == 5: + state = [None] * args.n_layer * 3 + for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i * 3 + 0] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + state[i * 3 + 1] = torch.zeros( + ( + args.n_head, + args.n_att // args.n_head, + args.n_att // args.n_head, + ), + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() + state[i * 3 + 2] = torch.zeros( + args.n_embd, dtype=atype, requires_grad=False, device=dev + ).contiguous() + + seq_mode = len(tokens) > 1 + + x = w["emb.weight"][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f"blocks.{i}." + att = f"blocks.{i}.att." + ffn = f"blocks.{i}.ffn." + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1": + ATT = ( + self.cuda_att_seq + if wtype != torch.uint8 + else self.cuda_att_seq_i8 + ) + else: + ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 + if self.version == 5: + ATT = self.att_seq_v5 + elif self.version == 5.1: + ATT = self.att_seq_v5_1 + elif self.version == 5.2: + ATT = self.att_seq_v5_2 + FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 + else: + ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 + if self.version == 5: + ATT = self.att_one_v5 + elif self.version == 5.1: + ATT = self.att_one_v5_1 + elif self.version == 5.2: + ATT = self.att_one_v5_1 # same as v5.1 + FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 + + x = x.to(dtype=atype, device=dev) + + kw = w[f"{att}key.weight"] + vw = w[f"{att}value.weight"] + rw = w[f"{att}receptance.weight"] + ow = w[f"{att}output.weight"] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x + krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x + kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x + kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x + vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x + vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x + vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x + vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x + rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x + rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x + rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x + rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x + omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x + orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x + omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x + ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x + if self.version == 5.1 or self.version == 5.2: + gw = w[f"{att}gate.weight"] + if dd.stream: + gw = gw.to(device=dev, non_blocking=True) + gmx = w[f"{att}gate.weight_mx"] if wtype == torch.uint8 else x + grx = w[f"{att}gate.weight_rx"] if wtype == torch.uint8 else x + gmy = w[f"{att}gate.weight_my"] if wtype == torch.uint8 else x + gry = w[f"{att}gate.weight_ry"] if wtype == torch.uint8 else x + if self.version == 4: + ( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + ) = ATT( + x, + state[i * 5 + 0], + state[i * 5 + 1], + state[i * 5 + 2], + state[i * 5 + 3], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + elif self.version == 5: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + elif self.version == 5.1 or self.version == 5.2: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_mix_k"], + w[f"{att}time_mix_v"], + w[f"{att}time_mix_r"], + w[f"{att}time_mix_g"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + omx, + orx, + omy, + ory, + ) + if dd.stream: + del kw, vw, rw, ow + + kw = w[f"{ffn}key.weight"] + vw = w[f"{ffn}value.weight"] + rw = w[f"{ffn}receptance.weight"] + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + kmx = w[f"{ffn}key.weight_mx"] if wtype == torch.uint8 else x + krx = w[f"{ffn}key.weight_rx"] if wtype == torch.uint8 else x + kmy = w[f"{ffn}key.weight_my"] if wtype == torch.uint8 else x + kry = w[f"{ffn}key.weight_ry"] if wtype == torch.uint8 else x + vmx = w[f"{ffn}value.weight_mx"] if wtype == torch.uint8 else x + vrx = w[f"{ffn}value.weight_rx"] if wtype == torch.uint8 else x + vmy = w[f"{ffn}value.weight_my"] if wtype == torch.uint8 else x + vry = w[f"{ffn}value.weight_ry"] if wtype == torch.uint8 else x + rmx = w[f"{ffn}receptance.weight_mx"] if wtype == torch.uint8 else x + rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x + rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x + rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x + if self.version == 4: + offset = i * 5 + 4 + elif int(self.version) == 5: + offset = i * 3 + 2 + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_mix_k"], + w[f"{ffn}time_mix_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) + if dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i + 1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1, :] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm( + x, (args.n_embd,), weight=w["ln_out.weight"], bias=w["ln_out.bias"] + ) + if w["head.weight"].dtype != torch.uint8: + x = x @ w["head.weight"] + else: + if seq_mode and full_output: + x = self.mm8_seq( + x, + w["head.weight"], + w["head.weight_mx"], + w["head.weight_rx"], + w["head.weight_my"], + w["head.weight_ry"], + ) + else: + x = self.mm8_one( + x, + w["head.weight"], + w["head.weight_mx"], + w["head.weight_rx"], + w["head.weight_my"], + w["head.weight_ry"], + ) + + return x.float(), state diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index 0429fa7..8da5741 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -16,6 +16,7 @@ class PIPELINE_ARGS: top_k=0, alpha_frequency=0.2, alpha_presence=0.2, + alpha_decay=0.996, token_ban=[], token_stop=[], chunk_len=256, @@ -25,6 +26,7 @@ class PIPELINE_ARGS: self.top_k = top_k self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3) self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3) + self.alpha_decay = alpha_decay # gradually decay the penalty self.token_ban = token_ban # ban the generation of some tokens self.token_stop = token_stop # stop generation whenever you see any token here self.chunk_len = ( @@ -84,7 +86,7 @@ class PIPELINE: sorted_ids = np.argsort(probs) sorted_probs = probs[sorted_ids][::-1] cumulative_probs = np.cumsum(sorted_probs) - cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)]) probs[probs < cutoff] = 0 if top_k < len(probs) and top_k > 0: probs[sorted_ids[:-top_k]] = 0 @@ -98,7 +100,7 @@ class PIPELINE: sorted_probs = probs[sorted_ids] sorted_probs = torch.flip(sorted_probs, dims=(0,)) cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() - cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)]) probs[probs < cutoff] = 0 if top_k < len(probs) and top_k > 0: probs[sorted_ids[:-top_k]] = 0 @@ -133,10 +135,13 @@ class PIPELINE: if token in args.token_stop: break all_tokens += [token] + for xxx in occurrence: + occurrence[xxx] *= args.alpha_decay if token not in occurrence: occurrence[token] = 1 else: occurrence[token] += 1 + # print(occurrence) # debug # output tmp = self.decode(all_tokens[out_last:]) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index ce025ec..7aa3042 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -36,7 +36,7 @@ class AbstractRWKV(ABC): RWKV as Model, ) else: - from rwkv.model import ( + from rwkv_pip.model import ( RWKV as Model, ) from rwkv_pip.utils import PIPELINE diff --git a/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd b/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd deleted file mode 100644 index 09927ff1684428b1104421684a7757cb6a5a4687..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 433152 zcmeFa33yc1`3F3K1cL~POBB@VAW;WV9Ka=t)Ip6Y5Vm zBiV5JslPk@sFRb&9CgYmr#2>!J}!Csq*Ib7oRX|QV0`lAQ;$7vc>n(ERx0X|+ibtv zF=e`qJzijl_iIed8?H=y)y(c!HJZ%eG z{uJ%bL3#3id!P8gzI&Z`(jo*Q@HYvP%xzj61BeE#}`W9muMm9Tc7MB>Wq9y3Z0@CHwUnFeLE>G;~KVscu?=@8$A!3dFbH@u-hf`JaizI##Ku@0G{`)HZTI zqGel?-LfPx|L2KB>2FIC6NG>wZZAokbc}xI$7g*K^`M)6rZ};ApxXU91@$R{<%$!T zAfEE#MBREWRP=Lr<8e)msDAQ4Dko*9a|He+6N&8b(~mu>@u)=Nwks(b0OsIxHDCk( zGH|2e7A#S>2_7scu|m$mXW(BZk!TyIU;O7Z zudH)gt1C0kjOt3KY+=QF{fiS$OHHL?{%}fB;_2$j^nQ@rY3;w&7(AF!Q<-W>05nR9Yt3(;@0F-&sjkejF6Ff3%$x5QNe!5sP?=`UGarBeJ}*y$t>&dp06D0H zP`iz4Dzm0HP&+M<0K5}d6)9%T%I7S_de&u#_e4-sz*G%jl4K~w$Bu1s>np$%mdXiapuM0~iiL#@V-$+GhH!rOUV5j~9{5zd(WUkbiQRD)- zRs(4vk#IV5R;RWMeK|8=)^kCvu~$bww$F4{Jl_Y-lTV8Zaq*?tB%NqnmkrWp5~Klq z$|dV5xB&cdh2VppX4Ty?Wj=(58x(`%6sQ#s$+axGsI9RF+7IyiAdpO<65y0hshM4X zfgt@V%9trSb6vC&oplW2H>>1~^Wut@cJmzY&_(ee8K*E;e$~G!f7DBhOTmx16s^k^ z=B!)2Y=}Bf={2x#GOrUJe26o3TKf;jM`uOb z_q`LN+Pp^8b!Wx5CG2^&QPMW5-%%Q!C5>975#l%4r^z}imT8Mg8{1l%wzJ~i0G`1X z&vJ|B>NJV(9l+yAi>Y~dYNd8!0MAwy&wL<&5J!z5o^QVk2~j7Fb_p~ZVjDFB8jTu7 zh^M3x-IqdB_fr66I&ZQb+z|2L2Ew!xX5u7xuocwOdZrkN1K!D6<6LK5*7Z)xVn%7D zcnj1=vCS4*ITy&OQ|TCewbJ!YD%yN;t;M+jkd^^!v2n`4BE`U3^SlDLdBD~uHdo-B z0<+d!LtxOJj!O7ymfwYQ6zW(Ij!}?4F_%qIkUru1cZOi(j8CLbkR5!NF43kBq)U)A z9=JV%9T$ce$d8$o<>cqg+1ex80cwUmEInd1zDIjR!TvxhR!g7ZTW7_P-Y}E;^+?p9 zGn89D8oUGj=-zKaene`W6`%C-8;!D!T5O{`q|y1($dX~t?yPuNTezK}RBHbmz_XLZ z^S;G%t~l{70X)<|XGLQk9;&Ca;!DFEJ!fJH!L zlok6ifC}IhZwlZK-&w#k05^q05ikwFALv{H_}yCo3}n4XpqoIUKa5XSsyZuv|E;wh zWJk|6*WF=@X{NF0yqfu@Ix7ZYx`xaUb@G7P>Ya01mXtl3XkMK@HxpaG+Xa#o;}3{| zvy+6glKV-t1?lE<-_T=M%%B#t9muU)`wyB|lGtZd8-VYUJ_1vizhBmiV`PLvxKWjxHaENu?`SQlr(IhsIIdBLY4h}}QlI@+>SuWMm%S#&2xh4+{wg=$ zM$fFfx+G!t-kln1C&G+bcwR3_zngPK^dAbA$O@2PVO$?M7oZ^is>)G{;rH5l5@{MY#TIr;4aQ%&_+r8M?+>r&|e zx$yik41eGovG6t?c>GU%Gmrm*^gn+q7GFX58<)qx=cV#8pN2QpaweNN(l)4dn%hdL zUyS`Z^NY>dCD1$4X3LEUz%-zF2?S_211cDX5DCl&9kwruZxR85tW48$W<@_2=+nD= z0vY1o;t&}{NMB=ASNvsv>o@oLJKN{)v5GwYPT9yM{Y7E^#(Df5kk8-w2d@c#ZAZlM zcbm`Ok@@_+u%XLe@ydewyLDX`=&n0`0;0dS4~p`K=qE~??mEUHFmi@#Hi+D-N3~(7 z!o2J@R*VQ%lr03rbVQfxvV}F3BN(vjX~V2o-0ZzRxDZciUYa@54Ua+p$3s}9rV=YP z_(Xv~9v5ID#iwaAG-WY?ot)WxUn^2tJ5>z)*V_i*Xa1sF_+>%(hcjXL(0Yh4RAuH! z1E#7oipY4{c{)?t7h19fKe)*HgR53opej=r*B(aeY5Ub%BHyBFg7ne6r|n1#j(t^BiUdV#q3UzK5ABVPgkI6%y4a@ZA-_H!+BB_}bw+GJtPX z5Z~KdyOI~;FB(tO=r@O&!tNYrFLf)+7AB^Tg%YTfko^nd8}BxiwLr!9^H60R!f;o0 zeJW#$LoL-UL!8+0zEJz6SP!)9qmw)iaIC(6wTHwns($=`*^WtYU!t>Vj%4lQkUuXeH~+yzo&cCPE|J7Ya=ub? z&7%p7Pq0*MMiiy*+F#K@$a1p_Gd1ymnO=4C64<}oY$lI9Kk(<9UZnbI;_E1^ic<=X zgz@pUd1V!YG8Wx+{X1iRRcaggqQ}{15!Ov8ogZI(NAl$fwZzCjf#ILi!CQA%BW2Ae zt7s%$Ejfy4VJB!16D#v;EyBKB4cuFc8A0N-s1>*Od4VbU3_>mAXU4Z{7E=SW!6Rzd zEHvV63G{&9uXpqKJwM3r-Gf8?-lRq4_m5i49qKcI?uqrj`1T(p$yd9uX3p%zf+E}Q=RstJ+RqEN-!;@e#7xL5L{*Ndm$qWD zHgQICC0Dbh3$|92tm0sjGh6-4GUfXw0 zi8p;onM=Ov>_nC#Gof;-pw<*0=x0JziAnu2lAU&0al%}?E9dEczbd4U*;pVIWv(hK zPpm3izfi>W{V&F-2e?T5F~Bjqhbr^O03TmS(}<4<1(Jhlch8yiu_}N>T$a;`R4f|eOuS}TtK+?uDEtwnWr1<9;$r~oj6Bw4yezd=P{Enw`x{89VUH=~37 zDexRY!91^yk1T;xrkr^$1RdD_5$7_G1mMdFe{Ij<2NM3Ip2NfBIkQL4;bG#OS+D2t z&>i+$F*sN^vtvr0euDDC8*=832z-7%oi_Cd;Pg3hpR^f=4pg?V4NFfYrvElfm1!AE zF6e(c!KgAhaqV`i@%<4P0xd%RBucC{-PMtyfF)!2}Yg1$$T_H(gb!H2n)1(tUDbBFocTR)&a&X2UT0-tj=;{B z&!5PXHmNFtaWU`FP1(g%69*o{LO^$H@AnevzR~ffDv)mcO}v-k_BcT-=~=s(j&MxG zzSP8Q|Hb!5v22~$l2}}A)S7Z1DKeazmdDE$mYe;#7V0A}kk6&fs9hm6ERr$H1~FBT;|u11`g@kIN7 z{m+B*Z}yuj59S~#n#p52WM6S($MCF@JAAPf)@!*g?>nt^HThdToii9n5*u-tXUW^oFk`ySkXK>aG?n!?HWf`7TmLYpX&af=hP#pgG34pr1 zJ7&cJb~u2mV-}?R9{X$X8$}n$pOAhsRT7=24K^MKAKdB_1Nr&Z)Rl7p<8Aw4e|aw3 zwCd_g^LdgHXz7akimb}Zrrjr#^=TFw+Bhat?6l~qBiWA?_xj>>iod$(f_@H2)B$s% zW4P1OU9(5n%pMX$*JdZDdb@9T9@+sijV9@>+lms{5zdrNdx}*16HiK!rsSYs*+q9y zRD4e*$}XMD@7;2?;cLAU*qx~C9`f^7tvRza2jCH}aGnCq!Qf8dWhZ$? z&uTNo(#MowLS;3(rem_P^nWS^zVvrd;MU#CumF&>djk~+cr532R*bR&u%|J7sWnIc zSrj;d0$^dGrgDP00aXPQ*wM9#dC~(xabIe$D_6j|0H$~XJcw+%z+kD#>pM^DJgEr> z>DbBaR7v_@p#M)@IUmpb4mh)N5z6LlL^n#0(QleP(d2is1$0%vlU+s*E^0gf$nHUp zCFp2iJ%42Pz`GK77+C-xbsc@ZE*$KVvn_Lo7t_i1__H)#Na?r*8ZRc52VMEs8aqUe zf?_RrMSQgiEvx*_{0OMz`J&r{Z=zbd)z*P8`#rd826hA_MGy7hdr*nGJ;2o-_{^Ah zZx#nc=iy0gIAFEeVxH@O{R3iml&sjt>%%>;)qV(P~YOB_K zhAP`#W3<(!J+-@D1tPO-&o0Pfu*d%HMJU$w6C9bGLcehxc!O&=x@+asrRpR{6d+!&LtK<)`QFs`TQ@Zzi0FSZTJ0E{9JjSj%t;-OBAx)3M;|sLLe1qwLKc0L0 z+lSr};4Cno;(V$9laM%=7UVtO0o%@@=h3MmPYPj&j8je+kuK^khBEP5FI~iwLY4NU zkk~&Xj7s$5d^>;Y>7P7;@U>KyV?G?s6ewrTFJ;GWXJy8Gj!6YjV_fXw5K=XKHaxZU z@Y6jN|HYfr5lCtCAcf1ir_eM0Use7eT_?c*i0-j>sUJ}_J)eDN_A0BP-vlZAZToID z6y?7j_N=BoiMYvp`(KN=HxWyRZP6UC=u0~=VXme%odcl8VQ>JBEVXBpn74OekFKs9 z3oB$iuz~y7hLX`aM2bG)=XMMVsFRspQ;qgO8VeY>lecVyNL6j^Ud^TraH)D3RipCM zL7pV#>8#ii0|}OG*TZJaE5k5%Dci!DJ!`h^LwR%Nmz$!-GDK1Xq`IBqDTBlORs%>s znLq*x{S&zeH zzXGT|eVPFOUGbBR@H6~J=+h>t{&+Bz_nXQSiPP7kqo>VQ=u#aUGv&gyC#`$?m7+vP zFPxS2h7`9wf&>lS+_o({j&z|~*$F(e#LQWbA%>^+u19<6`$Q`ju-5CW==&Uxmmg0gfp=XYn)u7f;T0d85?;ERk}4HUgO(;f!qnp*XPug(}@ca zZ5BSwm|HAEV1JrkLC|hrC6vhvu0>(x(v^y(a+b1BRaXucxv+lQpE!9h&@$MU%Lc=i zI7?O+6BMC?e|EF8;mI{j@#Y5E=<^ZSXtl?`4H9b^Yrzwwt!aUrE_kWg{1j5+S?C!> zE5-nlsG)ArOCDMDJK99amR)=oNP_94O(&r6g9re9icK;dxSPi?uJUE563%99g?_ix z5;Se5v4Wfl2Xz&DN3nmjHAl6bKAcC>ef`y?{Gx}s- zz>XuGiB@OjR>IyzR3}yfRaL(tduW`mrCo3uPKHnN0_{liI&-9+c@ceKsi73}HM&D+ zzQ*&iXjK~QA1q^w0#(-ES%%uYa`Vc8Y?BmK79$J7|IKnQGT;ArofVK-${}pvyvsU> z@z7_(6Ydbw3B*TU zKl;1nRYHH2To`QU6URJ)1c9QaEkD4GZ#lEC%nL=95Wi@He&;}$u&HUl$@?(ZzdfQC zx!%ee-5kK9`Mt%pCRDT6l?xy9G7p^TMfAOXfT8N&>qt&r$BQs$v2fi z34R((`Hj9wE^D^~bHpSQv5x9AKP;`w&)CPxjm|)9@rD_@2vHX#Ch%Tj&iuPCFpJ(Z zNX}QxmuDfp&#s&zmE`Fpji{pFbWrtsZEsaCM#1qG9lW3Q@I{!S(yjKb@WckS)KdB? z^PcQ6$rNCT$L#ZmyZZhkVl?z^vowDI_v5iWizMN+sWj0znwkM`xSKlvOC43YzAL^$ zH)!dr=XxQUlyRbKhEuwrsH4}&Hg5n5;YTATa#g}9)^Fvr-CjK%)oP8GB#dJ=QfW|x zS&D$q-ZPHFXXHp;&McSD_Qi1Cgn55v2;WjX(>;BxL3fRBj^dl8__D0hs6eS9Eu{Hh@nF! zo~`d-$Q*cVVew{8a>e@(Q{YHIEU;fzM><)vML@pmp6yn?Q}>LOFZVpEJtJB9X6xNF zmG2@H3(8mDg?w)|y7FyMK)(F-fqGiI!HiM{(kwQaN-e_Obh5jF)?{oEE z+~3haZkErbE|kb0wSR%EuGlWo$8gY`X_O^v__m~7Ia7%lVaK%&ws#wxmiuIh*X`6` zlJoC6>Di(91BGy8nDd#nuN$c212<5lo`_iwQhVIY%Wlw}^!Fw+TZ;$}YL5%GMszP$ z0PjI7IBX6c!m4MbcS`2?5$(ciy_ZL0KS6V=SR|WBH8p7wU zC!lAF>X9Cv;6mGqamED+XT_>*1PQv#x6Qc?UBJ;k5x z0sh=45&hiHsyO{Tm&bok^iv^3{soJc!bqMY zX3fvK;1@pbJ7O%q-T2>foEi-Fhol6Nxc+b@bM*rD&IZ{)9_`*Nuo+rNf8Fp?o5xSO zAU}P&;77V}!Tu1<=ejg%gL!C0z^L`+4lRmB;YEGTHCh8#;uQNKuG8Qfd{V9@;(DU^ zCLW7#?tgJ1TaVCD8Ca!LiMFJ`sY1Zkj`}7w9MUhC(g_rrF`jAsgQrMHlwLHh+}W+v zX1{>|1cnbIo1?C`{O}G z{!i`clAn3lYm`NDOV)MgWUecS`4>6^_pF zd9AlXp7c~TiAN&BoJY1sC4hG048bxB~zMyus;J&>JSB zRTtiH$fJ4Q;CLJ5-PynSC$Rq-#e>cM?R&tf>vZV>3l>@Cf*x>gS)2zHjt6hP6XyXh znx(AI_kcYwaE*f}h#PAY5B@Zv8xJ@e00lfCS6a{ms?n+o4>*pUt~afkT~Gbfx&={ALE zszJy&M7zGE2YD&PApi)3i~vPSQ!r*E(X5M@u|9nyfXi~9mqv6me*E>kpTs}1-TFrW z!uOA|fBF6~q^N8ExG(A-3)T(#2hgg2WDEJnW&a5J2O9YP(a1;n{;{X)ABuY|{bPQ2 z{&7jvKLXMe^pD!E{o`}jKSJEEnSZpN^Aq@op2|$K3uLHd%t`W&PJ%_9MBM(K?-OO6 zppHfb<{OE7qCQc>l*Ee}<(^N-R$yGrIOnyXPoRPC6InjW_la{RYs^sGYv~i!`*#yF zh5#ThW(1@u=o2e{(PhkdizyjnMu_`>PsHwz+pDbJ9IffBfK$w6*N4LE$ZU6$(D?+m zlhB?`I*8q@e#du-@4hSK5(4@>-=GO@YH3IO;S_XeAJ-+oVUl&6g2qld7LnWue|J3s zKDmH3he(Yc(nEV7vGdXc%^rMYeFV#iGb^V6J~&1F?ljkomTqx$QbMJ4IYq77Cz_Iq zCLs^HQmNSl00CbKiy7xCY;E4HF61h6`S|g>t)^X#f%C|1%i#MxLMx5Z(;BdSnWj>v z1D`X*9b7YH6+dsWoO%PyxPeYfu?)L%Y7#AYUg(LBK^|ugU@IE&gOl825V&8-yP%>@ zd)C4&IHmT{(-|I*6OA;j!23kyTOPscC8hKer(<1AT%h_JADMvq)*gq#0@!}vd zA?oF%v#A#p@q;D3I!4;cr<<$@H@xJD(B`z*<4oqk=Uowcb@b}E+S4ZXbaGqPsJs4UI)e}uA;jO++Wea;;8RZ#+=4gJL4mhjfoy5g_BOa5MxyGm=!-|buk zi}1HKj@+&>f9Wpy8@YD;^^4`Nq7Z-i_x;?Vz`gvg%XSx?;C9cF6~%Jce44$EV7E$A zT%NKaOTu$K%X48a9&lI;|L?@x7i;iKHBem5%gf;r@}_~W+VL$@w$R?whKH}q&Ovbb z)gtL8Xi){Iov3qM`GB`gI>)89I7t%2=*df5Dz%nM?aKzOjV1fwFD>qf3 z81487gt{1taHdxTn%pCfcIe4dM75^~1af<4mDWm1IBw6Km`Vyn5-7u7aFq z9t}U(Eyn39YxR{CeUURq^mVK-4ndDqLD}oW`iL7YV)nxvXUTF|0F;Q9G3z44va4fr zc!}IYHh1n1-ka^VR{nejzZb>u?CJ|41CY<|igjRO(d6#C*XgUG-^ zS4a8(UHxtLY`*@!e(L|%U(fyTjb-`z8~6XF{zAPrHP+sgMCe1O!Qf%e;}U0c;80`; zUV0W>6=CCPc1G!qD9f0bL@Ekjl`PQBhqln}AOI2^KEHz^H^<^L3|r=8OXQv+Tbe+;(KMukLK_0 zr2!Srp43zR?C;eKe5$xo!VyAG@CQcS^Wjyz$Xw(sV$O+w7uYn?FCy>A;PU3r^lFvN zwwsZ%6==U_xc#0?+}Zvuy}Fcu;#v5j8~;M+2cZ%Y?nYXD#Gm~dGLUs)JmnTu}> z@l6D;<%%qypWE!eb+x@$RQbMtPTLPG(B7)IbDc2#bzJ(7xb0~sDkq=*@%CRIZSNJ8 z{&{YDoV|^0e~|rGqV2t+wjblRPZwx^dF43%r<0pSUeT5RS(P7wGnT$8OeO>=XOUOb z_K&#jvw7{&^=14udE;+^_oq4xKFAj_a|o?RLM?b{U8?NTRXRH3#ddA(p%{$4OpJTC zQ8|)W3I=X810QA`c7SeO??Cy9oqEZbvHK+R5oF(NUdJM4I<^InVbze!CBwK&6=;S! zlhrmtGQcXnB5{A04Z(e|A(zUrO0{j_hOs=7*ebk_rejl=13@IqIAs1czN^O0mBR6n zvo4HvT%}vQ^^mQmeQ@o6o;7FvHDCkRKVhFR)V(|+8X6jv=Ntw8y{^ESqEz{=KBs{FYTHaW1N( zQ_A!`EMQ8wLA&{ggNtBmebQpXLU@s_`Ihgk_UD=T{N=wT(p?rQ^xntsF2-CKucQOr z?m%B8=xDoka;Cf{GR2&;V^sG|uVm9<`g%M&xztx;K*}cb^mg(tI(~igC|&r8e-YhW z=|+Ww*%oBjmEc8zw@Ma%wc5ou3A*kC&rS1RMAtO~e+DE?jzuATm1J#a#n6j{EB{rJ zbq7UWC3yv^3cgxR)-orZ*u9#!Rq}OQt>I1s7is6qCl;DID=yLJS)!8X2!q+i&@2Cb zw6?Bk_FqKTwYHi7Np0cXm-^mFyf5!tRHDv*@#S_CL|a&`_FjDXJ5&|aR+nnh7tsxX zBCr$QC1~(PbZ=h>PTdXy?i$z^o6KeVvv;!S?PTPyZ|YJ)9S7XnvT3Edq$ihF(B~NU z^V0e5h;cpd)pLC<Ri-f?k!KIbxMw+A<_J8+&=x256t|_s6?*<)yt< z=3F0ZidYv~V#kH6SM)+&LV9(9P$u!(`@I)j_p5hWN=!yur%}{{AbL2 ze3K5jB_825i7!p4*&ZFHDRJKLQEkIeu#4q>dx^~^ zR2pS(UW&nP12JiA&7XV3kN9R-qJWOU^+Rd6db!+PY+tF4ltld8n| zhgfB~*PY##*bia?ku#4V16LnI8o~_dcGlnr{u<;(^i~Z(t#i_>nTTzJjvFCsout+U z@+0v>=!ngL7WWN~bseWHFetz9d;8uW$pJDJ@){2kuB~Jr0e2GcZVZR9i)%*QOgWzyD~^V@(c`4P-t_vy zhO7)>lWdtZr!b76m?95J>aYF`?0)dk%|JuX0;D8~P3A=}wShwsWQu?Y4GH5-=1va; zPQkY7Wyy}~S#BVO)EWCKUf$kEXKFHMdd-IM5xCKTp~?Kldn6eBfuZA0Vzlwz(t;f6 z4!JvzcprFKfl#eF<6NSaj;YKRxu_e9JC=1I=MZ7Fhc}t`vThGA?#Oh=up}OqYTVN6 z$0_sX#ps>Fh&y6g7JNomoCqb`O8BcChs2B0qVKZi8H}e;XxiZCuJd$$M1YA$I*zbl zr2-bz`z`^nF}!vW)V*zlS3SzTg=C|QbZPTAE8M}Q0%EKd?it@CWG*{|FIE*!1(qeJV zxB|C!!K}v)d=$nGY18*LJ`-O~Hwpv5Mdgn48L&nWH&gB^@X-DySP-jb9aZMbX_mb2 zSSv379tv)LNa8 z+(y8&A}i{gANUPL%77WT&b?dhi5ydcpBB~()T+{ss39D7M4foIH!=ve`r~I5A86Zv z$NCk&Q>U@fcM{$PLn~eFG%qi5Mm_EfdmOW~NjtMvcxzDJa%O`@Q!{@-11YtTQuEDUYcvyhVo#2+oUNGpWzw zb;p?NB6pIlSE3lrcfK`y3vtnW#-XgWg4z$&&8vHz`wKom!f3q@7ANubEJyq)3D}8# z`psRG7q$xs3t(^YVN)vHn?T$A42Vlz#HA!t>JG7m^-&F3J>N=pl(Si|#+DJVYtK83 z-!o(wp|jXwqy$Ipy9hjLboh#Y>fE@Rfd;EV3vJ>idakFc*y|SiBNditDIpR&`K~)*M-DYw9iY zo|@~;)^FM;84^gFSCOPEY9!aLj=?mHk$V2BC9z#^ps*jmVEg8v+Pk(>y;i#f#a-q3U5=}$lvc36Q)6qI*h~(M+c4K;Zh754f$ce{9elsF zC)fj+B`Gj};H zoijgU$bl{9l3Q<)%niO`|NNs%unIMB4A<6Rd1#ms3*Vpv{=qlIUw-3OIZ8>E@CU9Q zyaymjU53vx%&eO@F_Wec9eEK3;^}M+WumqI{v!&cebv1jjFdJom$N!$rl8uRg|Vx2 z#Vr)eqfmD762|JSqxku{gh43&E{yF9p9pjaTi0@eRc&dl@ z*rjm;SiQ*lqhls*L`G!vDD#8*8DAMubz@wL zS2y$f76Zp%$~+Qu)Nt(~?XkRN4>*Y@GeHJZIs%?0({ z?EENw>m8V&!pCzXxk-Gr%CT?SPKo})&ix=V<``}F)*Q*#_H}vMcIS{W&Mk{b-Ko2X z>Xz{YGzvMlbh>t0+%Xybeb*#5zw({>3+4vCn!Sh;)i=79PTI16McBgyyy*}O32zbi ze%)tZ>}$;d*0gKESu@t1@l9mQxDlt8_4vbBAGmoWT^yZmwovd;pP+Z1s@Ny-2YzF& zV9w!o2Uv-tlm;ZISHoPE-+)gqJjmiMyFa3pCPlXUr_JPY4igOZXi?8^Sq0P{2WHFp z4E&awruDsx8b}VTKFAyRP)+S}KGa155%VykWD(e zVi1i#bv3O_V9>v?(oAw!V~cgx;v|BQISg;{r*e`^2l0CvLkQ7eb_GKkv^8Gb|FG3! z3a8DnQItYs@XOAP)_C=ljDZ?mvSvRlao~KFlduHMrlK<0pvj-=@>R9&ny>Eh5icyv z5!KsI;`+-bSQVF;sC3e1&rh&2`*=uS1=imohBRfI3m{Eiz!M@8mNZ?-m|F5y z<%3RY6Bq0lfmr>4{`zU1m*r&45XijZuV^M~_p%!8@G){x2HX$uekJly%f&#*R_iSJ zx|h>iu8Gxp%q4y100Z|3E@(NUmz=kaj~%VNGKBQ?2-_`n*B%02ot-P@lEOWL}kIPLx-Er&+zr#c}dnVSy zb#c8eujf3CjVBmovDz}5n6xfIYghb5y%z(V?+|3o!Tt3jY6q&<5jNw8on(jn zHYoOP+ia&1T~f8yY?Ko&(N3&Mb591Nli#-{q2mZY9orKQq;+vCY3guU3Hp#lql`4R z_;8)=GN_!#;=F%d`M2n!_6+V&!b0nL@f1Pc{tPG6v~~&t#$mH*m2D!Gt{iDx2hh_<(|WZf1ar<5e#|+nfe*H7D)kUw=K)r-ZfiLRcF`Wd0N$`33SZ?nBy7uS zT1&p1$de{$&T(MRGkL1mrTPa@wR;Y~^y<#Lo_LVd6GxCrs{ivME?T z-jMuQongK|ACJR7D6w|-{P2%f%?0hba}Vq}7;L({h3$C>Ga&Fk(mZmNCFI-lcGd*! zxfx!)6t(C5yHoBMdv0@Lz@Cpa2JN{7RP*g=-`5_n=k_39(4KpCVbAf`6|(1y%e%4X z>IToA&#&K&J%ihI1L26N8O3s2%sXH+>ZO+A%Ejx#5x)AoV{C+k2@gcyUB8NDEOnB z4QyY_qeBmt>d_(Lr$D^N@|DJyejve)KZKv4i~w%ui{^6%K|JZJPtx|teT!}`yyV76 zvYIyUYY~Bxg1jP)C63@M3x{U=I@9m2BT^BEtCC(Wr&uwG#aRk>c4|L6j;ldLz<`1t zGTScZgG8+%Kb5Z6+Sd(M{*8 z(0g8xk)vd2+hE>i9Dz!{)OaH3asE`R-_dH(6r!O!HFg(PnLO=(FJ~wo+n5K4GGf@y zm>WE{GiKG*mTfuzlrc-o*|E4D+3c+q(clH_T}RpRdoIf8E-wExDHvtIfVLV4Rb}Si zLl!r)X5zTuvWS&n8My0Z^*Iqxag|v$7b{-xaR#jW2gpl7K zUgvI78*W-}%Zgfa#N{-AWDPl!kqai8D?!5I=^ckt7IZrpTCgj~{XgyZ8FQ?rleYNH0IO+scE|3AM*TzfWP@~%?HS!8Zt&@@AEtX@ zh^}&thl7FE6siedRKm1HU+UvJR|y)}~W zdZP8^yG)BJUlzr%e0L&Si`dUjP-(ugQSqL;Gbo;{IL4Uo%S!{|Exy(_-^o{m#CyA6 zo_O8FpGS$;m+vkurt+O^8*r&Xl3s=(TWpx`Wd{V=d=Ddw>t#=A%Bf!Vy9k?Zcj|`U z0bzbSdFTqPMEISIg@XXU5;0@cdV?I=(q3_SfZvV7{C>1_4Ym8RI`oR+r$nRwX1f*tuKOjV`()=tju`_PJ|jz>RtZId~B#Vl~ zuVUB-vo*%(-Z`wPM(y*j+k*BfHw^-Nm$D=~=dTUa4$D=ua@*~ByvqF;g&zH2oP@fH?7-xZl0x1T5iJVEu zwwQ!g6`wCMGzV^IS7n(lb(g-+x*ZRKm`+&TE zh-0t^`aS_j!umeb$JB%NL#2I0U#I&>eDn82U;p^78-J_|XsHMM%>cfzd{_Cb^`QMY zX&;gAavw?e@&%pr>u-aZbxdTi0>yY+3J_mg3UgO3wRlapL?PoYwx0_+mOf@0DmCQ;)BmlLndr z|L&DIen!c`C<0(S0Zs%!*#jkevnoIOBOeLUyDXMo30exh@4f)KbA;Y!_)2YeJc8Vf>qtn`x`73^_N@FQsldq=pF*P<}tJImGqL9-uXn{ z+W#B^>~85zWK}oxj*O)@1}%l&PDHoA&`aSf=`H+7^tmySxAxzV0MFnDY2cr|S=9}_ zPoeXueAnU4qe5>szL+`5AelI z6nd}YE9qT`H{`=#otIzk^mdG; z_as^hy@!czGoiNtUrDdl(z~9>Tl@cs0N3CrYv3;}tm=l|W3X{lzE-ppdeeyRcA#q> z(}b_2x2~mkJdwBdKZXF?S$gAH)eXG^V(I<&Q>C{OU(9Hsw;W$d@7xbWpD&;Yde0Kz z!48q{Ay#!m?;ThrD&Ki%De^TE-LpW4@c~~+uaBj7Fp*IIBmsEyk&b#H!m^ zG1zGdR8VDRRulMuCWNhwnTlc!)=QFdU<>fvlCZmmUzRd|u$y%0BZK<-G4#vB^ot79 zUx@wZpPv3Ku5Se7N3nqXyR7$rL0{HaeELO&>0jKxhxB27N8T8YU*1dbOA?>`=nkK*5A+OQ-|q4Kp*V)W*z;RcSMnUo*hLijLE*Gz zX614e%wI>d*XGR5580)$1_cuGF3@xZDp#N;1%lJV>lD*53e@o+G1)Wb3Wv90`Ea8Z z?g@pPj)w||H*NTE8!FtN6mFIeR}zHl*xgEShQe_lUhzRcKE5XuZhwWF>%-yPz7O{& zg&U%9^L;o>n0>f26mBzx!?vSKmqEqT%l-(a-T|KJ2Z4~GQ0wJ!*(oL4% znz)DqhqR|5SilmQJpjRi!va;qLWTv(&jP}?z2(IMR$cvDa|l}gZMP%~qzro?q;p?_eGQ*r^LoC5+K)g1Sq9I?(d(~ zE0F6q0@Pc9S`-sJRWP|uBS3FgS#rlK5Ij|YT#pf;#R|mNh5DWyYd;07t&5`~-V!C8+Hpf7f{l3q(7 z_(_&#mcS*;s&EqZB6x^;k+{bcAoJXV5_#@HwpkGOAa5*)d(1__cMrM(-J=bk=49(0 zG`}773i!u-0E}Hk*LbiLf4ufPs2KP2J;2sK1jsdg8-Puz0?kq&_=fqwIzDM+G=-+1XF;AG){W4bLLB3x=H8=Tyt~82!8-ULf}N z0?qiLq?I$9Da0QYqTYiTtPo$HVsV|O5Dgy0rzwl$HHA1pA(}jhM-<|Ih1fwLtc^XO z=?XMUfl3u8eHG|Y1)`cgaw`;QumUYppt(Vy zR}|>GlhytTG(QNmrSkNe0$rd$v>(a6j!*NTJk*HslY0)y%mta67Q8(}$D9_NWYQ6* zrEh0>a5%B+^HKk?V!1E*%b60D==+nb_><80=|A9K-E%`#&>IRdULg<@72@4(tT7fV z#0Z7RdJuOh#8nDWrV!$TRH8XgAtour#|SUvLwrym_E(7U3h}H$hz|_|Me6}HFvCcJ_z|(cGXMRF zYJcXwVf(9?Zz#lgg$UTc(yDZ^LX1#|fc+KXDupOhh=BbSVv<69j1eKv{t7W(A)Zx; zfc+I>ghJdZ5YsBBKqa1OmkzPauBSlvDNxw{3iJ`S7#KS*Qy|y=V%lfqp_*0+&*aP=n_<*x_V1{K)&pRdlAB+^y2>c;1&T@Vn&2c3Cy2S$aP0vWJw)R~9)!ey z)8-h`#fb|#v)Sg-L;}x~_Facb+qCEtLP+0cYK5cAonRKs)Cb1D5I=kgqB6jvw24Z9 zgB6E@z*Hd;K+Ge(fvs0rWhsPD_PBWp;#=(Q23gI+Q>Z=pX~VmiRAlGJ@Z0?p6yj~}|0c>_rb7;0z zsIy{^DlDauKwzzmU-S4GqezeA2`$S{;C)~r|8Y#Y0E>%e%rRplcF)A>t1!N{<2?Ik z%-%i2m$!a}^%Dskjyc03TZ|ByFq6N8j@(XMYYvfTFcsge7-q5{z9EXQl=!v@<6FnY zC->?@7z^yb)=T^{FB~fp%kS};7yoOS#KC%o8?cMx@k}9XP|#}wf3i*3K4pgQi}rX7 znQiPd=lYM+rt=z$F~i4G$oLWR7=cGZCpsPCrk)paOaqrR!r{Y?e^I3<{{xi&Byq3b zBmUnw#xrutoU%_o{}20*albXf|8?@X0RQEgIR5kZ&(%N1TSALD6*32Z5%P~+@l4

#I0%5R#Jr(H>M~iS4X^|q8u3Tr{<1^m}UlvgKAXaOO)C8%{ij_M^KlQ?HRP{>g zk4h><{;$OA?E|>aw9n`xNpsB&c;-nEj=zL(jH9T^iE95Ks$^%y{-Z(+738OP5J{gP zl3Mf8QKA;%r?X;25Mx1n_u4j#g{g)A3CVX=2;VRYt=O}~-Y1A1@@IYd&8Ck9?U9Es$nV{@mDRv~ zp9eIMHCMF{k&JZ2m7l0JsZg)W2Yt7Fq<$@N5r;Z?7Bj?GwnkSElDLj=g6 zdbQK?xws3w0V7A2)j*$D-Jnbbs<~}a%bTTWBG;I7xELVWuFt}NN6FH|`fmIC4zu5xTYX>oxqzN*DVw)i_Ow%KBx7AX%>YsqpgmfGTrTEy)*K7Cw^18wnMEhcSoo))n+OW-TC zn6|}>wOGaC({`#07UcVMe!)N6qYyXCIM)`JYEk{=el4oM{8@|YFS}~7&9+TxQT=6*7S&(Y)uQ^# zw;NfmmfN-~wW$8`mKN1tUeKcY%VSzpe_5bK^_Ls9sQz-f7S&&xwW$8ms73LYwf2Zd zM|AHIhbc@~9+8FRIxCVxBObvA^a$}|c*OV5ibv4))+2m>5Z`Vo+3-=%A1G?TAKu+C zG{uEM5zi zJ1d578?iVa(Bd>6V?e)W#NxCrEuNQ8;%yz`!7U}5J?zy_%QqpnhVmW$V7z=S6C(2Eh6d$35dfZi=iL*O?*KmX z<;w_UjC@B{hU6OU@mq%9 z@JFhV#k(7KQpP^4mh$T z4iGe2Gu1yN!HKN?i_6auG;)GN?J?&etBIe$9w%?wjXl0otw-(A>(N5?NWtG}Dc>HG z4_pg-j0xIfqj!V$cnTB49^2#Arb703Owjyjdz3!-liA~~f!)~SB-MJ<9;fRo@uK}> z&bImX`0#+Wum@HnXL%iY8eCaQ(m!^2a}De<=*&X)*iq1QX^%jZ!c7Stbl}p$s~ZC) zm&fz&v$)T{?B@BEn{?fN7bGEF;eU10uAhHd(e?8y3-Cwj(#DyG)*|QdU(s&uw)(|E zP`U{11`m9O^^@QJlV1kfug}-X=C{uq3IazzTEz4CY9iSVco8{11d~{FMxuZ7DavP+gQ+OT9(O7#UJ~aY z^*50)`_T*gQ?>2ZZT_%XjDLjMLq4wG&GB1hOb6HT!E6$)g;mNte{$DnYwkoK@-}e~DTBc$_NyWS|`VQjTA{JkL7x+T<9!RQ5N&AZR zt!5Aj1p3O`Us2Z_cvYD&6qOGzd}42gw`R=GxnCyV8pl~vCs}mNaaz=uvBH@G@3z-h z^T^vF(Yv7w_zl*@J6-{OZ)H4IAF`hBP{Z|Xh2Qt!j+c!wtDTm{%6fPC6m_t()Mec| zNZ)|%A!KvjqVOI~Xy3flaXVR5UZ_VWtcH6J?=Q81l=%$|tlJMU<7<_fLoart zX3G#3XmRPBq5cC`V3e&7!gmbMgJ(QLd&W2Tawy(q^2)dx5PF`@D@#5Uz?aN>J}|Fb zo>wl-E7!)w^Qisivrj?e82(k+lUXA!&igN=ofX$@=+0n~52O;4HnZJ+3vMvCR%|;E zb88sHF@v5!K8&s?7OY~YQ5`PboEXbJuXx1?o3lk@xo}7oouH9}HAa}|fC#WGFqvRr_CwXWxKm+ga`Ku>rgL%Hz&L5y9QI(8z(09c~R^LhHUrE(> z+T4l!0Q!DlW2^79xfPdAC_murv89mzIwVqS<|72S>Zob$-}jDUye_uno?DCM;Bn40 z5Q(UQH^=m~dg#E+Wmh6!>*zJI4fx@`dHJg1cf$B@yJQ;z1QXFS}(TrO4wI=3eNXl(9PqHjhHt}K&dj%Y3*Qj@Hqd@ zVeAO__fR-C?Dy}#xb~a%Xkq=XhmPBw{p!Q`m%%ZB|8%eaq`Sa>#}D1(FAd{A!Qwyq z&o2M%j}+#=rYrmsRS<=#?K7T2@Ut=0Mplil!582L;fJgt{HQgAFNojH5B&HF?lWfR z_X>p6Xgs#?AFLUC#fM$OPhDx4R&MU~&*qzJ#> z&3eC#iug0L5>s2ePhiV1Wnjvfg->W+gFC@z58(3fG=XL-(&fg*ukVIeOkGOGBIA?C z?{PPG!EYOU$M(@Ues8LZ<+ooPzkLGyKGQ#*-Vn@)m;C-`*I0fJ-y%l8 z2S4uXchds={^zbZe&g1A`A+w_v$u1AkJGi?%>I^Ua>2at7lk&dTy&@I)XeFt$cvBPFrN3Uce;59md4`G#`s)^Q$+Ktv`uIQ4+I5wH zzkWEckiWLc#mb-5Uv~&HQNUk^24K>1A63}?Kk?U(OS<>hWta58UuW|aVaQ)EWDWh5 zH(=?SzQ3M1KkTpT^bs!B$X^d>nUssAM8nd-QYEy2Y6yMEs*!6PFCeuHOmF~XNa z3NzqI0K-?3eCC)>u}@I;fUGJc&B*-{iF+jcXUxSXLuIm}5J$f*hW@;3g#IHpll}wW z`ScH0`U6S-QZAzT^iK`azdS&{PnYyx-#kWsOMi^e-!(x0p)0NYP!r@2)Bkj2=pKP9 z{q&PC6}95$ub<2Q7e{t47h6ueeTGM4#a8{tOItR&4UM-#;JXtIFN@@JSo6`~CB8I0&a#U+^Nj$$oq| z{)A|Jkgl;Be9_cNx~)COw%TdEo8K1q#xov#m%S1_-t)+5-U9oDC-5usUujl?%}-%P zVqTRwmS^(7873Yy_`kSCg<{6O8XJr2ZSt;vJm%sseySnexHrG!gbs&B4=s-x?lTIP zj1m8;%N{6lT4z^IA=6k#obsWlf<3j)tU~X^FK)GJ$_*@-Pk8V4DcJ<<0Inf#M^_K$ zw*jMh@njLm>cQMQ87ga;txB0$$&G6B$qEr)cc>Ex6S8evXC~dR59SinRq_BDl}2aX z33jJ0rK;UsCV$06_Je|JE>+GkI;~2Qe+Yqu-P|$^IF5l1SZ4_<-!{kyktun{Y=C&jcV8T&= zQXJ(0PLx68o#zt?xu^ouSkNVKo@Kh`__0DC0fMK&9*;XaXU$hj%;*lkLxFtI+LbgQ7CRN%KmF(q^8am!O zBrSo>bR}Ewk=;hgZtarABTM#!Z$;XRB4qn0*|d^fMWm2+GfQ@x5Lp77^_47QoI2EF z?^`AFS3X5rJhEg@wq!4kkX=W~W|Zu?N|x{6hqSoa9q?lUn``5}&HXhq}b43u{*L?=XbJt7k-rudBox zlz2ZSKFSi`E<~KbW~dUMCFs@Eui%UzCEv~^Zl7A>7g*xh8-<+HDqHaB9($519o#dA z*ynSJ7#78{l6ypKHP2R5x#uWab)9NYtfdk#O~5u`Cj!W7v+|SgMin;!I2Rv^*)tF3KM- zt|v}-`ct3|L1nz~sLFWpu+G3RC@9rUm1<4)6!lNqV#lA|l z&ZGLu+g8@-cyZI0qT(|mWb?<1yNMLS{sZJ8tX3WdNA*Z%3y&;;&8j7oc3R~W zwsa}KfVfm}yclPRH${l&j~8`He2OJ5Qt;_Ta57@!0X=xm6Tv_;5X zFk3?Q90=fr?6`PsHD0p#VlQ52DKiHOe){pcO>CJnui*==-?>DrpQL>)sKo2ldIuuI z`dL6>j{m1seI#Cw*-+GO-O8uu$Ln?_d90G0X-P`Fc1hZJeW6m7cr8?y|HD#^#OtRw z5UK}Kx`32^yk4$UKOs&Cy0N7y@!F%>3qWQ)r7H1Ss9vK~^W$}AMLm8B0D}__H0P1+ zO>zifpDEchmFx{|qF)Kv9$5mLEA(1}1Z*KYO36k8_UCv*Kgf=ckYx-e*@QgFnS_!} zS+WwaJ+cHgTPWGAXW8dow>pjn?6WP|2@$exz%Ey^Gl&#Qz8mBbOC(@>WF=t#g_{-V zZ4$6W+T)aLG+_6(WDkmvbpv)%$!?%zcei9EV0&aGV5gO=1Z*Mu!)sRBXuzI{p9z4H z4~~#^19n=;{)tE-?GqpmX(eELWF=rP=FSPEm4Gc|uTrwnfL&q99ugtz2JDQItyHo# zmaGJ9kE{gjy@aeEu!Zu+n_Gcn6Z|=!h|F~)ftYzHJ~hp)l`6Z1We%;e&n1MpF>@|` z&mI>Jrn#&(bsGj^=7M-YR+}APv@w%mJ|<>L^pC{Mp^TYB^JAt&JjBd%FO!(r1OX&w z#>Gns{v003njamfun5a8=5;fCz6*uI!b$cO3o6m^pWMk1HE>i!8#5idyc`w1=vZZ^ zP-e$ZJsO`Q_Vo80rU`+26M^-rsVl9bO0rYDfkkwL z>CAk}J!?}CZb?8NiBkF5x=-Kxq4T1)vS}Bh#&Mxe6eyg@EW&T6P;p5;lZhYK>wo{u z-nehcok$mavIQMY zIl$<4X4HFJtS)D>SuFRS0UTl*G}FUj82hS-8oZ~?J<(gDdeN%lfsd{E29o|S%58UX zNME`i-aJz>jUaF*tRK&m*MFGD%Pbs}u2*;Ki4V+*hV^Z+GY42Gj4|w=V6X|;hAXxv zVq@$uje^aM&id)Ch|Km5!^|b{Iww$H6{Q{A)6m6+P9%e~Jao5QX*K#f!RKgUH480T z5LwUH0<~mL(}Jk8@n(ylj!zE3lUXzZ4aacT#7{WGa30;6x|DWTuMGHuyz0(dCLaRg z)Imlf;dBH9MjwBnU6=3!?We9(e@dL|#TODT;7>R8E8tHndgM>z>^y7s3%aG;j7!v1vF2#d}4r%J)r zLw|bUc!<(kvB(7c>8{Iz{&di6La0wWS`dF4tp$;FxE87?_+TxFKULmn5!CU?I(QQ6 zPjB;RO`bpLc~$VY`~vs^!cx+_${4{|f{@->(ewfPW~@7T_u;ViR`I&m{>o7Msvfp) z_u3yEYQOsZ!u0d+C&8)F5!}GMxPe`=@dC6LSCe6-@i3=VBMN*=yLYwbOc~b9VXKj> zv9~R%IxBXf48S2#3yvtxg#v?k(|GHtX?w0Pjn<+s*&D32UI?sWr#-5}qX0WS!C@49 zMDM#yUCE(7EYH02yOgI54NRlvEh5h%)*Q_mPo7HZw17NyJ(Fhvv7WSdL>|^d9`YS0 z50&f5!?e-7a$Z=TlW8A`XL$G$}gAe~kNnL+xYSz7^By zyWvu%L|?lbYcfO#y#O{ByqZJDMnq&iM5HL8Y#}dS;8mra0E1ptG9KwYIfXyyJ3Xs# zq9BdaviOU12yb|G!<==a3>&&n<@b{u)^x}(AW1xQsGWKu1;=`h^IVRu*2;Be%m3KW zoVC3dfJytVjMkE!Fw>%#cx02H&WanC_hN#wlYPnXP9Dmk9Jk=7Q$QU9N>fPUz-q}%k zm%JS>?>LSlJ+^NeUH~_ESTd%>q>)I`7vXZ9+57)t?mOV5s?PS8vO1{iu4QyJqJyqF zf@N%Ruw^jB3<;Wv8d;5+RV-`GDj;iz>|(%m3`Xo4F*fWP6%pH_D1&wF*doSua4q<0 zKGp*Bf1dZ8bMLu#$_}hw^7ogpb5D8C+s}LYI!G1`%hqg-(_bUzV3v`+-gi6-hPN6~ zC?DPOO_;j1Y3Q#FytpIcVoW^}0-Un71WGJLC!1CWeyM5o{plpYR~P;u+v5{jX)%x^ zq=}aRQgMn3^tA{*RXc#5adi$o35TAcj0N!4CDQWI-z_Rl&yi*6Y5SrmJ$DqN=a#?r ziJtf#py!5KpPu88;K^s`oOn?by(SC*S`yc4aUnn13)?%VI0*gox4|ZHApC8M!e6%P@ZVk|fd6y+ z4CBA7H2i@63J;QMs=(+Xn1s5TdIf?D&jWi`fz`Em@qQipU+`BO0WLk84x2&}6dWT$ z{spHqdh1O_bze#gzlcKFSzRYZs zKUIQSQfj+2^bb=`@!B>{Yta5Nv^SYuw=gBAB6ph8d$m)DFY~+bS7qK^fomyj6-hdW zRe*A{m%NlEZ$1#sAi0U0YT~aJ{=xvWlLcD{@*Wm^7c+zKVJqIzk-Mb*`+O6*3#hB7 zIvD)z&meNikbFKbbL#X)DK+iiK+98NFqV|EO$i^PMhUfkE+Qr@X)h>&E)xUd?NC2^xxCx z)XA%+&rAOq*5|!fC;!le6zcQN(xKK-GelD_S+f#-KKvJD^ttLLQkqlA+luRR;;jOG zzVk*@0gpra{CcKBeXb!D+35CfIr_Y1yQ|O7PC*MGvWx2fLVFQ?eh5)zIeq@NS}U;8 z=Z!n{=c?%QT?FOobIl7Oecl9pHT3zM?Wxb#y--x27j3Qexk(Un^?82{f93W07z~l1 z@Iif^iEeGWryx(XQM!ftd}wP4eRdB0Kh#R@KhmP5YNkWSZ${N@R#y%rQ^eiJ^U?KQ z(}jW~8Sm8QNgrz+hcL|6$0ij)8m-lNO(dlTTtIZ`4cpuc0~~vE@rx0^O+xr>s`2}F zfRCTr6F!*=8$=J^#-$M{xR$hf@*`a^MMUurr*r>+v=E|cb=5Q%(Q^2m{l_5x41R~W z_<~Ogmc#FzQ;OpEP=AMi*HnsM8*m{{UH)xW48KE5<5y3- zW8^&|i$Uq(k#TD>1+%>a8!#x9O#DoD;QweU!juT~1^0T`{{r+1Jb;qkqR;SfC|J6N z>dbMTzh8tDuWLo8(SvU_bqjekrS_r+r97^amcQaS6X~p8YgaqHsu>3SghYt}#BR)# z@s7`W5zI`gd%x2lSda}Vjn~WugE(=rJ?f0DGW- zwc{)^w|=ZZHd9ZiGadBV{ch~nCG-OtouvEQ-y##^xK}g@lJ*s*Y{b>|u*Bu(v4q(^ zeu4QA63#InLU)V#XeHWX&BsgvIQBvf;0or9LJq>X8Cq-rJBO*`u*I1u&O&**Rr!l= zhGYyE_*b2q2QSd0o-Tj8h|?mbf81a^vw+H>SEJ;Z#i~mQ|49@9bA6|1sy=+KZjXgn zcoF|1EV<+~7%z^mQt6L7L=qK-@ zh*i-D%MCW=N~)+LnxA}ARz;fxtgoUu)_D?9GvzI+Sw=-g?gw)^FB={6l2%bk+aGbG zsHlMZ=4_TQN1Z+~A40-g=0oUy-h8A9JiQ4w2qbi)uemih_R z`~C;ujlW;xuU_ztaS|^^-I=m;@Q9!~$AlPSxWLG+H5Kj-$6~kom+wmKV{y^&J%iU82lbNLTxk+)?#a zu74FKg5_$eglf|&ec|&|X%6hqW!%`NZ%MMft$EL`5UEBEzgk$aS&d=`0e43nnbN9u zlN9DTr{iBGFTsvZjz@92XQI$q{PIAEC-&sd#Ps zMJUYP3%<-8(4u9lJ#RK+pYc0tG9PB=C`G`@yE-(=33a0+6H#*#E7rX`!GXtSia@}9 zjv4rz6V2@WlJqYt53%S?eewKpxI zzjc@#+TVifzS3kp_|M6z7wE*~SHo_a(QxjZp^Tc(cL3Jf3qv-s5hv;usBSC0qg7fsZLxkQQD z48v1uK~ms)u>rDE5CAVk>z&`+h-R5BEC#bfj7@rsMIHw}uj>ZA;ib>mLQ{9-!78H>tat#hrA0lr ztFR@dyIrxf7T1fH(4PSRo7jTOe?V>?4LsnNJ^pX;-{k+kyOc7#Le^KHugi8S&?Rec zkX=%3h}WM5AK?DZ{Zd^lARu3&og_R%4JLSq#u{M3>(8=O!(!(=5l*o)V{O)^F)CL<;+%ib;Yvd!#Vn}}UT zb@c?-4v}5pvVP3=mi{R18Lw2Fq?7f^m+J_FB$c4YGfMIZ4$o|qz<$osOKBpQ-r=D? z3j-~?K3CLV(;pi7mZksNRivNsQ3`4UQ}GwR(6jj-9?Il*4q(2H`AJzT+0{aoYEoI< zzU%IrU`36d8kgW#rS$FfT`wwI=x@`+QF>~SK(*iIARW>|G4x&d(v86Lzn267(W|TwfO`JCe%)--`2<=n%x)| zwZ`M@r-_9V>L!ePyFXyvgc|;hOy0;C=mc>}a+9m}3*tZQ>0<`9Oe?#t)Mf!&ja66~|Y-e0hps z+?wN0ZB%8Ta2j$axs(TkY&F^{Z_RdysJ%GHM^gC2aW#8?iYN zyGXX-YK%=8MsyFB4BSmYXNdR#hjMaiLe^w66WBr?{FOAX;5OmHlnD8THg@vc5CWm! zdVw;5tx3=wi53&AEpz-S{z^G;WdD+96h-NKW41%zr#t%e%|=FX`UaM#@87-CS7+!u zz;DB&4}u0cmC+~H#C!BP@-@#reg-g1AX&>jXWR|B%_%4znTe(=*NU7PV7)SPLqW zk&`R;%N(wA+9IR&ze)#oUK7yI7(qU8Eq#`Om>$S^7QN2^hhfoXml|0^m9f*pa}%F2#9Efi86B$AAc~}B_@y}u1)5YP0AwbP=6gJHRvwI%vy+N z7WYAiQ5%FydAcD*!z8`a(;g@+)DK{uQ#WnmDa2gNf#@qWMOl(+)ZW@UP_Rqwc@=Mg z@-?Icyx;loP6)&Mb}t9sw!OlGT7i{pd+|`CTaLdhY~L?GX;|6vAOil=^{NX}5oin; zN>8XArmy-;q*B-Ez3!$!F@G=I0pO&#YqPqTP&;zPO!ACvv4tY_*?*?kH)%@LmE7GB zu4ml2nZHGD#4$;WJPLX7uHICWP{Y0mFr~14v})QwQ>sA_o8ME5ZLEI3?o$i^fv;Ke(b8{--|Zw^=#zpZ!~fpdew&r!|L}mK_;29Da`3u_DCWA$)ZlW18-vO+^IXX+q4t7 z5mBUeG})R-2Vl(s)7*&nGHMvk6iN};ejR7IFrMgSco9=G!A)JKZN6Pc^8B0jkdW+| zaSa(UN{W6%n&Agb-hTKFID{{uggOQIEI-tKI#q|mZ+`cn{VS=A83gbm1V|-!z6Plh z7+`u~%lp{pNw73O=3Y1~`1e*{{S>$tt!~=~=sL2wF3@(}AD*=QN!E_(gtHylpjYC^ z*an6YnA19_0jKkG>a?eQ?sBGC|L63tY(wAI7||ChvYQF!7qSlFjv1wBkQMRUuS88C z_d%RydbmD9fmtzENsaCBkLH&E z@P|x@)g1rn`Y9)J(?E*om*KpPFDa=16}@DbjJtX$@(1?y3_!ouQ7~e>?!EVMe{Q%2&TGZshKC ziZ|)09}d`oT)>#GX7M_@|C7&ZM3>NH^cll(5%zSXg_hAU`2nU6gk9m@PxF#C6fl*K zR5N->aN-gdenmcaba*ik5rp~|$8Yy6-=U{pUBu?}%j zAFu9*H(ux`{0RFWPW)eSj;_@twuJ_~9{iLCtV790m)}7f%&CpvV4AmY3!^(5!f z4Hbi*6!>Pn;p&H3-EU-w8iYoK4QYb^$vcE6VZ;ciYnPN&H-UYIRe|uYpgl;4fg@IF z%*?>;_xmL5Zb%sOy1_wZmjm~pu(gSj7ym5*mjZ93imFDy?NyPkaYU1iR?-=XXUClpLvUGjDe&UpQeP+G(^)hL}(fU zn&iGAa$9Z6MKs_vOK>tM4`L#OsS|-gPMr_ZWR9n{9rtGlpe@J~Y>Olhj=I(I;G9=9 z4@8!Lu^0tZ>w5$!7)R7u)O`q{i7CPZ;0%g6wJzkEGcHYPp4{+2RrvEFQZF*=BGpZ5 zS1enHm~aHVG|`qOcl{^w(Ml@|iKtr*TlPS!H$F0r9raj=r>#znZw+*SoS6mt7@I5ct-7t z;v`yBO`v$H#zIEbUfdb=h1g3MyXHH2J_Mq-1{U)N*uFqaC!vgbWFi|9-@CcmY&J=R ziW;}Dh05M-;Lb#0R?@hLQ9^u4S$EPB;?^D{K??i_zfzEqx|uS9GcW?;ybKT?qULW* zc!U6JO^iE5npK;=!^eg@J((xf=-W{a0Kf?TLph3r;NFNT!JS6NR~mDwtG`yNG=Zhn zz0`HVLomupmy}b#lKx2$m5@u1MSwQB4Iz$(&S<$=+0$(wjV%MzCM(gDc2_K{Sz|l` z&=`)Q;-fTufNKhKwIoJFe>q6YZHRgoas70bYeDzeEaNYGp8^v7uH~lNwDN(c(HS?< znsHy(i&|PK^~n;=Ow-3Y-NAEiyrOB$M;GBLQE4=RVCD9R7p!nl-3&Qb0p+Cn)3XE$ zrc4{rqM6^;IiBqv^0L8MM|5s$&pqe&I>0~jw8qPcf&=}(=n5o@=4xk;YMq+a?qE`D zYt;n3#KHEKyhJ*?1Q&Hj8smtK4QF%ec{gt{@(iY4;A-Q|Br|+|K5(n$-P_j%c<19c zhBfNgDtfAj{&2$lIvp-E+Khl@dU5~*^dLKe@rM?GN`9&bXi2rUi~hZ0^VIS-5B*B! zlSg4l6(-xm!qRV=J||qZF#af*Pg=NF8TjG)lsb8aI6!zNA7G%JGzgdXhw5g*>FQgw zrh^NNjAgbxr>6L+K5s~*E3dZCU>@Qs{4s#+V}^=(H{(OAtEQ{hWVwgK%i;ZM+pk3$ zIhbKAEzRcC2d5$pM;~P!)MP*#a_SkKF`g&a_udO9qo=(Lal+4Fk7LdANEk6^FLrg= zXxBjQy{)-&Z=NkXk@xpHZ6MOuXFLINYuv)<D)N6UYnj>uuzf;g{G$@$sYT#0^Du-+LFSO{r^FXqV!R8-4z9CmH3EV%O%6C-%UL zm%8q~h)AV?9{FgK?6v;)x7+&bT(c_mANJ~>x=QuW^6I}-tiJEx#d6U-sEtf?8}M7! z0fqd}=Mn@uJ}#7>HJ0eXD$?HZ;WD-@{w-4e+i;t1`B3}Mhs!$oz}oa@;+tw`M_>CJ z^)RQ#l*T9^?>H4Bp;{G{%A4HWM#d!$Ph9+5dbGNsE0hTj_5SanAM`t}l7-u2gvCZV zGKnqK%Iq;04F%~VF=yiP07BHI3gLkXn0R1>kah}%0%7_+seNrC zGO$-XBVtFHk8#>2B_-=0Xgo_nc^&^5RXIFd%Atp<^e%?&zOah4M>);PSw0yQpT=yZ zu7&Fqcb~_|mV8Y`3MsJd$%&4bI-H=Oa~K(MjCIN<)D_vyO%32rtOviZ6{tuw>g$_< zR}+M{X)$=W;++8^SP8sG7X@K7-<&mGh-oW&NJVvQ^|B2WxP zMzF}urbrWqfHJ~uhZknfXqlZgLtqroP~o8+jJNPh&V4kQ;bcOsxRaV_NAS^SuFwwP z6Ilo}6rhDkK!OhCMUyVyZE|yzuaDFSZ+30u!^11PA~|*BFvI;OsCYB?Gv(CD;t)Ce*w)#?gtOQ83RUZwI;^BDle;g`OtX4!9`ywc zrYFzK8MiU@O|z%1DP7HpFv6LSO%V`1Ay+RoX_Gb+*qlm=Vw0f-!wZbKG4@`&NP;SI znGC>e_ixQgB-Kl(m{TV>kP{~DU`dm)DC{^!$SrD!pxVWcnA!y!u|4|-G&X7x;Nf92 zU)H4_t`d=x)=BW;ph56Wv$*id*&H6||FPeaBNQNk_Ir3Ub)&ZhkxmUR8W}ltC?7rm zC>m|jZY%UOI;>yTQ1j7^JLtS=irJcd2Z>VillKilyi`5CwK5;=nXXGROSFSo_|c`d z$98uOqs>H&rQ7=}?y?mLvccEbwp#uD z?$yfl_s4?g#^?FyCQ~g_vNbJuBF8v6k1(Ge@CS0I9rxU3b~KOQnWQvCmHy~uhkj6rm4-<)o1jZ}@_S*^FQuw&T)oydqe zL+EO<(S5!Kw%M9)=~$lP#lC{e0*npBKNfK=VR;*R;vPN2g7o}szaTwbL_4YI-vV@e z@ zKmK)B2AbJTq)dqCFMA|Z4aZ7WBU=!f&jh)+q6C!aYmOP8#YW)n_^Iy?0H^%54k8rB-|t^uZquTqyMF;s6M2?J8KgT#MeL-q%kCocKl$i%K$3p@a}Db2|6T+KkZV0-j7!YO1@F5;cG#meT+f?FaIiHGY|Y)+JS>c`DXsOP#=OGDLU?=2 zl|xZW-)~(IZaMZK{S{d1jXfuCuBDZN7Vm0X77{6sEZWYVTJMB-RHJyfg@8N-I)n;; zq}DV0(rF`G_xHlNfN_Ol{x=(K`jqxU1VYs|^BFgD=p(OuqSjVrKS4J;_z8%f`^R4!Yt?0j4Sg1%ijXzQXM9KEY zIxM)QL}j;qL;%yhmOu8U8Nt8Ol?(XijEdJA?9=}<3Pz9ie9n>j(*9n#_4CCz-%I^1 z;Fh%oXKOAmO+sisp~jKgk;Q72=?m|O5N z*zYH!KeMr}4j>U{m~C@y$*S7+rG$WfNap%tUY>p$Bc|?}N(d3PT0ifM<73jMfx)E( zbdOncqu_ZbBV3WL4PBM&u}QEugR_dBO|OS?SD?Dun!6ixN5Ct&MzBoy+4c;gq^8<|d8Vzm%?BrE)C>6VInbf6 zz+<0x6GI!^c85a2_#)KbbIu||Z`h}v**Y$}482_LY5!cJwIOS6DkYd$(iknAF!adrRsEAn=&xsFLKYDY{nPA9f4#2y=9 zfF{oK4iGfwPu6HA)$DBznkY8EDZC*dv>%}OO1Ev;3oV|mMOq3mAB?te&7q zeZYG>_uQ)S|EK@d%q+%#v`;Cf$|zhu`uUl}$LLmLdq}r3>NNz#qFWi!tr`ezC^`*w z@_d=!(X)TtT@tcKG>jmTbb@?%5)%zMlLXu}hOSrZj}%K`?14C#>{`*v@s-Oar~ehG zIkm=9R*_0W#Yq%|HwA;`a{beKn8vXB`7^x!Zg4M{EmXHnrw_I``c|$?deU}fZxt(m34Jd^_WRPBbuk=_oc31U^_5zR+0X93iU zsR?8-a^u~DxJnntr(mNy0pbP0cGknoUws2vF4sEc)MKE{id`=eiG2N`KPpDhK%EvW zr0br8p{I8EG5IqbkOl1Ox;n2vOi4<6LzDc?sYms6y-8@(XB%dP66-j#bx9}@8k18i zFc%Gcb@iuHNwj8v5BGk9P=WB0>E#k>bTonhwgISz{k~79s#}=?!q8V@ky<7XxLq_K z&GAM=(*RpvI@+9+ubpy>@OMT-vv^-jdti91s=A

`cjTXgRt~cE1>VMUUasUluEu zEH~?V-Lq!2Tu`nNf?1bEAecs$Yht;WDc8huO@!A>{F>FoPFkfcb_S&F!9K}+2tt3r zhsQ2b@EH4a#1l`?)`5airrIfwhv>ng0`8l{+7nxCVf!uo)k=C=LC>Xl|C((sCZ3Bp zPlesglLUzQYqm*)uxTvY;a|TB#NAk*kB*!` z;!Fc*@Gurzu(EJ(TGb zfaWj0bF7oCK8uJVERt_?^>E@b!(OSf(^k5^w7I4EXdWc4M_X_^-sT;wlsQTQg!{0o}(d@~r$I#Ue(OHCd5 zi#k25ldYc3xoObMPf>l=`HE}&C{qRfqDbT~zB!FW3G%Pn^HD?dBv#pbOI=^u+|hgp z(hd0VX%?00N_Q|ad+TcNEXD5yvM=Jrg!$dGnq0jk*skB102*8m1nNIGsLS#oMNF)I z3e?4*1;Q+D!p%kxkqLNHV&t)!D6OnqIqDJxVeQ;8QhvBwetjt~TP+NP{!GcvXo!g` zHoouk(Uf32&dq4>@)zH6A`2e!FHrg4MOq)LP**LvM)O%OX141GHBcuhFlOk#ma=Vu zn@}cIU)aKY5ee0aFOP^D9yWpzhTJ3tfj|=qQg|u;#X#X~c7v1Wv!{`Pf{hn2OC{kc zUH@ptsXQ0Sv5ago;j=67=`vP;WB-zOYwPmfM5+#Fo&JnLRg9$)2!;l~ff_ctk#37y zi0hz!_SCl|vhy%|&AyH4HW&dqC&2ywQitAxMYiVA*FE^ao!w#(zKp+OI5iNLEm)=u zRDhlTMW!jFhN5%CoABD0?6sSM%uIE*X0+d6U9kp*ji{*iW&8Wq-TCfp&1$_t2@gLU z{)XbG*|LM!hDAmi&U_ffno#ZGQU)IpRq!Z2i|@Y{wMPThLhxCk+h-miLw)R>Sd7>t z(1GOi;o4bm!J+|$InW3IMh<>-sl7RuY;saI2S2)0!-@EH#oqH7;a9coF7fU*`07K9ErBZiDy>?i^-g3Bj*zSA+a1u!WqvjzTAV`vm1 z{y3*ogY3i(T)-6^}h6F&y^Qy-#lW5gF{%lKz%75X;3M|UE<8Xc@&PKL|DRK(g?=jK}@ z1sQ*^wIfn{v!aFg3ip&{7;IKCDn)ZCEd4B)G?TdFKa)~tL1{$hpMZ<>wpTSru?q3o zqn}(Pzz(cRE)s}&mGz6o+LK|^QYwN;9=X074Pn~nUVsu9>N;8OlYYeZ3zlE;e00>w zx?IBD$Q$-2-gxi>ImAN~&ept9Z01*Z__Q#ZK-_BoTqluRcfr0gs zjZ0Ss38nF0V)Zgxvl}Otz1}n60V=A$?Bl-Gf3RGAhyJc&^q*O}ve0kqi@S#XU4RV& zoV1yAydP0m^tEoG}+i8 zljcKlHx9FVTI8u&@z&k-oS>D;C~*{aubI{2gihr+9+Md{VuZU7^i;NtkimqpH5$zf zx)i&)wTM2_0DULNjv_*EY=9oKjzJ=N3%sXc;Dx%VlKMG7wxVnMo+ytu3+Ml;VB)9<;ZE4%uw{&bjjwO9g|7?kV-$vA(BX%DC#9x{wFjLJQK<2TN)30~0 zLsH^uT(z!R^kFezC}Rof2u1EO*{*;I9$+b2JR;38sQ$#94aY8>YoH~)O&y%o|CQ9Q zabH`rg#Ub&D6X7-UEX2&#+#pqxF(n7?Ger z3_Q@+u-VZ3duR|q9uvaTN9srCyjnD4)vYIs~(i*4g$Jmss}(I(@WB=ZH;60fI4N zkx6F<<^&yoC;@bAe*xAq4%gs2?{{PKDI$u1yI> zpk6WR55=iodQFP6=D`x&n{W*FYVq*{v9du;CQyZ7@ z!d&r^#x1Y|*%@(FBiMkTLLN>`sb6s~PBwbrVcKV=)Gt;?ezxY&x8T+Hx?a4x1u(xQ zyZ{pRr(iJ%YZT3=Z=b_xWJL$Wps1`ELaz;lKxn^U-U@KzSTFJNY6M9@ZLc zx1lqI`#nWo==68^OcUWakd>-`-KTx4|Cc`2zq{A^kld%7r}Q}F&=2MTY3WVt?4sSMg0J`u_-<6#c5&U8Ks=zxo1aDst-e;xZ zvAN}c0U|#3#q8^?K#?GQ2rBB}nb-Fjrv28Se4SxRusmbm@iiVBAHO)HJq7It%jYjy zx$@^-+FN;UGh1?y;{rnXy4L>wc_CgE`YYTX4t)u`iCG*Y7$yhd7cW0`CCjfY{9H-+ z)b;Pm`CDTTuhXacy;e>=Wu0u0K5IWW+w3C#{VHp!WM`55xqZmTWevDO0%vOB!>Z=S z#H-}g_|F(5Xn_RrMfva1xtj7|1gOlZktNHgJ0`Vq->bR%8R1a$tOub@(`NLObgc@0 z8Sf5g^dr-BO??k^Vp{te7dluSzw*)Fr-KbqIQY-u7b_fo3o5uxgz^W!T0Z|yUNBIu z{*8#K%H%KaCI8)<%GZCC1eb%~Oa5NC!M9xfb9%`?nF21;ezuqVJA280zDz#kuO9$F zl&Sk(34`DIhJBHf@FoL!i=!joO6y;_a`msO220%b4_UeTZ-wh`RJuN`d+rae3I^%3 z5%+35iy|y=@~QvJ$Jz{S9S$IYo--ZOTj^C1eva+tq>1o6)(%<$olPOP(u*yNFp zm}z|DU48t~t;)oo%oWD@fq%Spb@%p=b@u?fWCBsI21_&O$=o;20(Y4M^uA;p9ToCN3s=6@w3)g{n)PTdzULyAWJbJ}chTtc-ncgh1Emm$6jco3V9_j33o3SKwK$nJ%0=VLYD zOKfJNR5A>u{&lDQpopOGSn81FPR^?ToBcFC5k~RzV|2^icrBZ!L(ar{m(G&U+t;$I z@tf-WXE1q5YmwP@)K03#Bpye{TQ#y~t>-g;u}n`3rc2o?tkYrt>E31=aC*uAIyPkK zMS1e?PM#9CyrKX?`TK#tJ~N&Cq#B3i61|_M=)*rlcw^=eOb$9by&wN6bPw&O)TMF( zV_V_=7u!D6isTEMg{3$A5tb@$Mm`%D_~6r7#t@-E*~kZ3#*k+ykl^qZ_tqz|as4?B z9n9HQGsokrZQtW(Bb9Y4mU0>VZ`;Q?>xIa#D_D=GJbUl}3ofkG9A7i&Al&;%@yG^t z4}oRdX(rusibHZrUH?Za3RJ3N(3&}YR5muPU$-gTW(td7Q))a4FOw$Y9fP(o#pB+g z0gv%+`|kT{A7XC^(ZTr{Z*E)bhqM0*ZOMKdavT=Y(xxp0e#NWf@TcuF2$lEm&cO~v zN+>g7b&*5&flSY+E=;|NOm(q&<}+^hAnQ`kw2=*FCkozImdoC}=Oad+&Rnq^*_Iq= z-{MQnZ@U2neLP7U{|l5yC&;~35#o#4f3e@69Ka8rY=d6hbss2}b&k+CSo*m#{H?BA%!=UHf z6owsG)5Zr`D15R16vIcISR9|KAUOL1n~KN0Y2Z3TSonxq&AM6!()V16u&fxr=A2NR zUtKktK?c8h<}<{v@6HSGYuFr%v6bJ#(<7K`U19Bc%q?C_&2A6>A! zpCx+2Ed1zFpZfy`ZSnzn1Mp8M4gZfN;Scu4O~zD9!{2>pjihidS?U79QuWn(G6s`A z)N@{8`*rA1ulNy8Mdg3KlI7v$ii<-X;@#|0=)Qvf%f)x+9|s=wgD-+<*PZAFw}odq zOu67AuLocPHxZdLoG17>F1t$OyY0&!D7QjDevmS^+XA1|c=?^8MT&d<^$_tvwzSj=Gclq~qSQE8J43w`oUI8H>TO1ylFlfQR1_-{Ydv*N4U-YTi zP~~r*3epc6y#37M#)b(1IGHN5B#nxsho>VG)}T9;HX_!svg3~FjwpbiOV5)lKzgA9 zl`OwrcqCSYej@`jgz1M}kmB#P2%re+i(^O>*|SHk7sNIA0{*Z+j!y#W>_`YIr8eTH z0M*xY9uG|QF1Mk&z??_=RiQXGva$AU_`;8d$}kA<1N5=MG!_PRu-+R{)w?1*jby*G~b-T1Cvt0(6$deuwA{viIq$J5UN@cb1pAwYj$ z>l*^JtZ84p(^0bh53uah7yMsYh57;fN7ci8>wBZ1B>rRj$p4kEe=67NLP1Nzx< zsDmH#W3&+D=kQdl6mjz5t0T;YhaQ--^+Wp&vE8uIdNIBAOxwBK6UUMDp2r|;Mw;n& z@C!;c8rRW{#fHM^Jfeti!LnH;+VgoNxbHbP!{i!m+XeP1pwg-ypcvxW;eATooTJ}@ zy>@bJQ!>6c@|^}3-@*SIx0D??6}9Ta{P1@17M9PgWO?|AUf`bycL}zv)d1Ofg~R^I z!Jp+J%!2ZT^1>EDGYyhbUtB>~Y#tGfsk4KI;ju(}g#n@9-A8@a|hsA%1p0G^oP*Ebq&4drd7**{aR_!TS}>2P5Emgw=3P zBO<;Bu=SJbsH@36cC*8aUS;3}Ri1WPee&8rXWydjOTIWQ>^C?JD&Ub|+ZVM?sC^cc zsBcWO*S=fBBfXpv4ZP<*`7^FF`68at(1~w`%W<`(J?$$=`j{Eq1@?-c)gZ^r&Z9SO zr0S-)jaL`LZ0bovIqDLfOR4jq^d zW$ramRk+L6*F=qxuRprKb4*1Fgzzb`K4a|-s~aAtT=<^76saGcUrdm*N%iwH!4bd> zO-Kg(7ntas5B$25oUFY3fO#d#!~gIU1k|2-d(A%f9tEs(G2tf|7>P&k9!e95>6>rp z!Ey5q)F-B{#Bf2E{c>GXmWah!3&ioCHtz&lh`-y}<|aU%|EuMB29JegYc}tsJTDyL z%X2!xd!EJRHyVj}{*UwFa)SV68){GFVNi-wbfNrneyh1el#Au}RzHi;uqC-z{m+Z% zTm137IGB!!{sR1Ae>z%#+w;#kP1w~WXT*P|#G{=dlled+6(ckxj>1yTNZvKYuW2Ko zMgl8U&-cttq61*or6fXniV^kWh7@UH1JeW~9Md;~#-I?WBygb=SbB}=FFzD=?90-N zMu1hyAvh$D?e~cC)*RZSi|@q*D=1cuZ#j6URst{R&;9{Dr}Nu@ z&Pn(k$WP9JYpH^GfG1);AlpxlWyR)7I?+~_bV6BuB@5J#3-o;S7IX?4JnY@kk9v|H z$EXKTF`nTVqr`c{B>PgZUd{k9{CxEvnt4bgW-=gT7v03(?Gvn);x4_I&)!DJ53if* zkp$M0^|mJrw2I0L5cAOsHpO}V$2uW2nb^tXqpdmx4PtWVoHZEfl!>Ip+zQ}N{%jUm z3cbpxrKihjG(S$tsKb9{``4l7b#Qa_d;3tiL6X&m^ZjhZNQ@Jf=siIfO2DN4cx*p( zzHL4I#&5s^oooqKCj14XyD`1zrOM?f=@S={M-h=Kq(y0fU{P5~OF1=aXQvz09Y^=8 z;xmOT3tvV96Iw>fvK1J|J5<4Dc&(gz?@ijgMs9S+sEodMymq$Cp(a$*OyVbt+J=C& zkD=$&q*1UqK#0K=i5}VI6Kx((k)FR&eYx=-wk0~`=?itLnFury**~F#r!+|BeMefY zf*KZM;VF7wUb3vdW(Zp=sA_PbBf~?&)5)DINUb>LU+a%WBq{jnI@ZMx%+3NjeymHA zijC&z!)}ds6wwEzAP!<0hqN_)k}GF0b@{v8hEn{ zMi5<_DcPD&@FyP~vXO)u;)XmP6xwGq6(XU=jiW-$XqZGLmz3)wc$o$zqm-HcrA{)sCG#j;S!M&{RxnIa7fEJVSuuni^s%lX1z6HI+0?C8P1&=r_a& zO6{7;{cmVfA%Fr?Nvd71C4Lfo8Umg)SJ<=?6l@jhz^j|v<~Ii-V#nY5t|zH(x=mb< zb}sqoMeuCI5We_qTrNIgr+oC`!CF`{Xv(NvZUDQp(Y=P~K~qNU@ElAvTQlQ2>|@y) zeutc;M^baw!3SU^h%f%JZR}&(i0NBB^k}SAz5RCq{2^|PC3{-c>(3Gxx|Td9W2Q4L zsgC7geEnonfV$9p1#E4kzDOJK#+m@r)(QF0jp&*C+E0Mrr)waix)Z2siglmBL270S z((nme<7dCc_7gb7)B(8gPSk7s2qG0z*UKg=%nh2gm`&+VVgw+K##I8%QK`U;nBA@O z97d?!w-P|J%G{6qw%~kWSl{)an&t7{Z;RDuv(axQQPps zV(2E*z5aqA)X$m%Mj&7d`U{g<&IrUg`Td2dAx3ca5xxCL7>0b;T3HJ z1W;fE36;2t{4`3%GU}gzZb{REw*AE}&-g?`bG$LN^>=Zq3iN_H0%xC7>jk!rdj7|; z9}FsgfuU_Sde1DF2CYtgv6K?*$&X(L$=7?y6R`&ig5?0jwkIq~RqVf?dX=0g;lGpX zAk&0g<@&!SGAlz&)9S`s$vavHKXu1OeQjYGVZfi<5{>C$IZUye@reO6;H7lOGc4c4 zn(5I4o*0FhnE{+T*IiJja|^TDiSYcu&~UU+zx{Zq^W=hfsAfc=ev6kGKg84TL2_Y9 zv3TeYhX?e#6Dm{aUN09a~iBnO6B^661aB?-1(#W>p zQ>xTA`#Lfb9{$Uuh>RAqOqWR!87*SUVk!jQux#pfpxSHCogBrgrtkp{IB6D{jf{3Wath6(GSh=z5-Fwt2s46HNUnglTB zNeY4Nacd?rG`GkHhg+d|NZ=;bqV0i}39Ezs8CxrpFSw>u59;g4r&?`xH$4)JOYOs{ zA%-^mOsk)d)a2q5H?50hbukCuh%2#8U-t!p_5q&(fdJ8)`#10+_n7*My(sbmZ{CDE zlwRQdge?GtDj#mXt{l*C|6~?Xg_96iubXZP>!YfP2$O2l;b4-Hg!UK#{NwtP&ZIj3 zZY!}sewx%YnvgQVNE81s6kUy*hsY{ZfHQg=#!E1Z5<}F7Sk9+7;^}$suQvCgGj{u1 zdcYIu0rO1SlK$TBopbE?EijA-0y_=^_=tXBs^q;j_3%Mtonou_l~YY)s4L{ZFSSDc zfGF;k+~PMVtXt)B`}+J8eD}zW-;KYTRFx?rhSR9N+FuvJCjt}PeUH~<>j=@K9EX=}&mLzW#0-IiMxyBxQ4@R6so*;cyAs{# zG#FC`80PW|&j(l4W--X%3du0J4VCk+!$(BMjbqh_OvMiHGR);kxS9(0-a&lBc*Xg= zonL+^1`z3GUtyoa6dVN;J|Sn2iP^v_h`*n{ds$$56*{o%Uvfu8x(_eWWYK!8ga)z$ zCmR60ffwKs|2p%1ahifCO+tN8!G1Tl{F8gqznGXEL_oQ*?Eya$X0~208MyrB(`8+tpVAGQ0u_H)Gw~rt4iQ^ z5b%3HfM1I-+=Ti?0Kd&NekZ1SS}-&PP7ZQA_^tWKT@XUf2iya|8k#S`GWdNJz^~rJ z?`TAh0E5`3jlr#+S@%vP7 z6oDd9J!iy!C;CFohzqlvgr*0~a zAG-dK{@o2q3i$V{5Pm-3R=;?0{Iqo~#P$haziQ{!d6?~12q2&z#q0Y5Tn5TQ22iB_ zsDA|8k9qAsSW*2!50t7uvDf-^;O5evoa(}TIzA)nHRh5#R|!rzwWSA9*YXY9ILRXP z9Y4nfTRXSaqwkFyE8)+rq53JW{^b?bUotn?zONu-DyqLQR6pjmzhOo7mpmS9KjPJY zL*B{^bfQK0SGxY?71duDYTu#1lKNP8#08YTUbq6>pa4J-{ErW{@Ar&v{JjGC=cD(8 z+V>6gvH}1_+7I*3@1VvMR1f-pZQtSgDNlYIR#g8G+|cZ@&qx0a+~(Q4e^`Q0rFi|z zE3049zm2J={x*pI-S&O`*>IKW`~LBbnU&CgYlwe-{mUz=e@F;_Uw&gMs()puzTf_a z71h5rRKLm7pEr0r2w1C!4u$ch>0d(iCwcWRuc-dCkp21c8&gsJg--_Me}&inhK2P5 z@p+oWHR>tvtw{`KpfK!_nMWDuChP_e1kN%uIoP1?_Mh(o&P+8~OJg(+VqMNyZ-%JNN+)gtwTg2#k1toZ+W8C&sBL zXZ(?k9z2kXH#s$K`|wnZ$lsR7GFA0376rkWU;r27P^LHIGO(3{4X(Y*P%s zqbFMY-g>1Fzss76;?-#1Na2)}>rUJgIe9Pc?RqgtYh z5!HHx8QRMA+QAUe>PNZy&3&xj*vI-6;tz+UzP%D2+NH2?{t*6+lY{G>gpv?RiYC`Y z3%_?eFqnqPlx8`_S=)A-I8RyfiI;7OG`J_7Nuat>dUlk+92ZX>G{w(jgmuYB|AH(*{h*z3lb=+g#!s(+J>%`To zdt;x#PfThUhcTEu16LYlel6hNK=yY#KwmwUS$z;Dq-6csqJl4(p2u{(zQ!L35q0~E zG)yyt`?uq?LL;Y2Ty;LjO87)2*`g_%h>UOIpOtQJ54XLGj{QGx@7Wg|`d&C@<=YF+ zXHrRXyFm4DNH}?VwL0kl>VY>amZiFR!Vvxx){paM7*l`SVEuP1sXx`J|Cwx{_S)Z$ zvtT9Lhv`tYtU^E$jb5XRhl9nJs13hl(FAg*l}yDhY|_|FE9*D}EubWHT`HR{3?HcL ze;`|%!c?t5h(knyop=C8^HvR?9zH;B@%i~wQ`${{)= zOFx7?d$Q~R@ENn#fO2VDZpR|!?9NN{pIW=sz6$5ZPqZ99|d#2r&|t zQHSiPJ-w`!A=Skz?&D?fk9Y5;@d}iS8AbytSjHe3lqALp#f{z7T9+Ao1@zP4(-_2O zqyIKOJ-oog=g*r^2eHT^bpJ>Ab43uJEC1X0q=Wbz^54csofeSSM*mrS0`UU-^PIW` z*QaBFll?@|{TAYx>G?0>6tcCMCQ&k)3gR6T-ss32U*c8x<#!x$_2mw_yG4hBBW(M6 zVH66({RxPOddn1&&ITs$925!Ni<3}yO$8@)McGOb;YTywM2lS>OqwCDq_*1Bl=ZIV z>HQm*L{^>NJja-VUKx5nHH8Z49Z*2;7u+LWl3vJ?>!Nyy@6ju%71DdxQ+=YhHbCz$ zip%#qhUK6kOz%rQ1?2vrEV=2To)(Hz}9Us2Jtpi~!|2YU+jlC;CSJunOp3!xSo{zq&y3`{P!oLi%H+ zBp=RLdHON6*%$d^EUeL`e=DK?b4{Tz{Z~HJB=KaS{^yEvp{W!X!WjWB^df&uZTE%# zZCv_G^k-|BLSg!g>i?gL$v>>-p5=!#0+eh0x6_|`K4TDENJ3rwXIb!dX1#S;{#$nZ zF6+T~pMFa{UX4xQJFG)5#0iie&PNiX`AX{dG`RIY?LhrUR;GS)p#E*Ff3{!W89%@? znlULZ6-(ERt5%=y0|p1j!#W;9w;?m#IA3M^`-Iu_h3JICj?*^pSM_!6I%cGs#6hOB zYram`4$UDEeN6Xeh@{dlElFb^L3+&`b6)zow(;fK_%phIhH?^z%c#0K3v}{T#89P$ zsI;&)KJV+=Izaup_K2MQh}Z5g@ANe=GY|63?lyDf7m`2}cDEIrZ*|U#%KlxDH%xY` z#q=&cjK(r*YCWvLjK%`v9n(u*!M!?`D^2_*jIAu{Yf?*$xSBCNvfl_QxS4)yQv371 z2EY^U>t6JP;=O3YR@&4mv7(-HtraZ*qxA2jy_XGQw$Ba|5wy%epn53~;b_~}Um2(# z&S`Y{wRSs{79-dgE2Px=UdGT{P(+~m`j5{@x~&y&eQn6Bh%sRj}}%Ge8#EIx5ofW;D~+<2Qv^Y)Ay)-I;cwtQ#gp*;RuT01Y2U)V7Z3pDZ& zCu&0ua0b{7%Rg#l7?5!ZeQtlPaQkRbjMB7^snufm!{{g|=myjR%S~0|H(?66?{4MA zu3~Cq_aJWMWH?1;1f~=M`kW@M8Kr)Qasn9&AUd8=4IAk+yQ{Wl*zs^Hn?@;HDJ%?) z%f!fOGEETc9+sX=%rGr7&g?Uuw|D@rc#fYYPk_hL_|}sN7zt_nEQU{G5qxG~m6JGn z_}tJ8d@fr5KY>r~phEr(4C6DF{8_QykB?7ad}^*Q|2RtPho0ixYrXgOo7-POJXGUr ze}{r4^oeGB=RjDo7Ck?|I1rSgBGgK@5O1%mI~rv5UpfKYzp^$eRC+g_#?R@GM5&>K zIrTf5e8qg;}GOcC}2oSkME&%5+ zsAD0Wt5`uQ-Jt=W6$D&3R1(EcrC(a?2>^p=^dSRrES@GB$lxs(AS=Qw{KMfWu9f^% zx`U<77dBOj1Hj5Jih??w2UGT|@uM5ZZ9FJE?~Ek8CF4M<^}o0v5%eUQKZW$92^Fik zvZx~VkNW_6pvOR_N%HI}iHIvh?%&KU;>b@hzfm({K+f!whw?~Gk~hpI#=Vku%1vMO zffm?@$Wb~8|DksfjT*lhWJbjJ8MiAk8S6GQG5FW&V?7JQHRfc&UQmdoYai{LHn4#l zve&g0+&{qc4CboPZa66v<<55Ul++C$*f8#AmZ@ek8r$Fn#k6-YAeNqk`=qO?xVAv# z)3tY7vSZ!18^j9li<|tJ-nr1;FjE}0ev0#LUfT7NB9~+6sKDeDP)R++21sm6?Q8(& z;@B&TDS}phw5Er*%PE`F8G+^|!xf!*HW_Hn-fZxBK5roCgI@xqyYTKUxSD2VkH)2zs&R*j|qLfol z)WJibD%4Z7+B9+E{-}cFCc1tx?^OsCH;H#0Y?}+G))d=EBT{+(%7U8-No4Ot)-M`n(zf2B;uNPo@C;uK* zTLkHI*04xJVkGf1kQi>F5lwVgFF6TJS+X$^L$J~ejT2EeV~rb0ljqYo0_RhT+so}6 zc=oas7d1n&D{U`VoC|yT)ftYxe2la0D{U|R5O8_6P{LmJTTJGyqPqL*J4ducuk_OkKmj=jwOcopnr$1Quem%;Cn zd8=qIP5&yemn|x=mou!2y7tmQDT=*ZfbFZ||KRG1^e=0se^h%}>jtZNt6(qh-K|l5 za8MC@d3j*SUPjzY5nr6bj%&s8D9wx2Pa zj+w~rv?U+?=Zi1~d>?1NU*|VWE5_pDoP}JJyx-&`1Q6B(n|Zbjij`3Fswpjo4+s+U ztW?UtJA$=_8Bpo&MjwlT7<`Q~rJ4;E!_9Y-X$ze`#^A)qVzHHhCy8i;;}&Eiz>+d& zGpbyw8E_+Wd5=z`-0Bt>Um=W%et#-1s*#}LUQg&U`p4u97CYhihObWT%c&jlx4S=% zdKQdFckT@CPn{ECFmPfg<`p>l?HFihe{dd34~Oeb%nsaQa!_f)ppJnJZMlKd5bsMG zPQjV9{$5%HBGTm|UK$u^((0Jzjukk_(Blh$l=-X8#<}xXof^I1{1q?^TE^Y7C-Miq1mPczGEcWQ8f)b*|VW){LY{&&#T+OOk4<# zFT})=rM362gS3J8#BPs)69(SXuX{~YBon(A>=YD~8F9J%=v=E1pJM9^uq+e1LA{`D zJ?W178#=)!b8p~iIQSgd2Iqt93+46$h0UnOhLGIiC4Y}&_B43dzu84R91g<=F{5FU zfMW2u8C?l%E0gZ{*14t-STr#iG5{)HIaqWVED+&y7W6#6b%fF=_=w#9-~a*6b68qC z67R;tV)SFuxsd_^5qgX-!f%pcwH*ad8zC~og1ium_u9m2#PpoKz1NylK#(gXA_lH} zczY;yD{1Djm3rDX7^}Y$tlFI3B4>u{NoFL?v_;B}fMTIkGxR}avK|o9oXD_qkUlN~ zy_HiFyL&}jSX4Zx)RbJQDTOhI7@1JR!d_i|h|4Lsk5E);kx0o18tSqX-|B-sXm?%I zlOBGzVU$R^G1K^nTr*Nt+m0bI%9slM*VaB7t(ugoeUJju5@yAS>>tPrKprGX$pYe6 zE4|0OiXP}Fod1F;iR$Il^eX)h0eoYMK$B|kjR>CRVf(Cd^)m(aj~LubeF-|y?V6!` zliFlBg-h;%6;-Ei!z{6jCe=@}py6OS?yn+bi5z0;mdEj-HSTO~_i-;|Z3gOT&pj2H z|3dld{f|K>zVm$Ai~k%1RA&+drEZ;WpI=!Y{r%hW?32klWoK%Y*{Y?VuB8+6FQ)iOuzH>={HDZ!VwfAI9M6CPg|=2Q(PSb8!)-$&Hd^AJUsslQRmd9D=<{Bn5UW2W>R7!=ET4S(h(!; z<7#KlthJeCl|VcMrDPCgS3)~wd9>I{sK2O*^-}?xQkz$EJ=Iy{%tvb%@DNe$4%Lsf z^shQ?)30j}tG*wpdz%M-n%=q6V+*X1G)NONSJB-E7j}HVj3Gz&HZT3Bt@#_-n z6!*ISWx3`8iV!VA5&r7yIU)?wG{2FJc&KwT|LzSNzy%iqwp%-^lroo z=Z2Gj+G-InLP#5!U-EYK&AG0R_*RGUJ+=aT|8YubeD{5)9KMmh;yW;m?^c}p z?S;R8|6^%6?A-vCBcvanfhPR!}&l~&}V->)2T!DC=w{rV&4Go9} zse2{7uYNbWZsSrM`v=Vt{p=t9$j|~0UnLwK=CIfi7A#x9`!mnZd4>_wbM?wg{y~6rSzk&OeQnoev zTZ^AA@110Ml7+R;V?$`<;zMydf@;-$cJ^g$|jvk88uIJ}4K$f*UXI{C0T1lmutOvcz8LKCAR-X?_|phrZfQT@9H zce&mxP&uc4^S`Zv?LZJPUm5Lz-tRQmV&3)_57<#yiRj)(N~(kVYgVEn9Dp99;Wy15 z^1y{Ao#tjw(3(@SUjZ$|)TX9hoM2Fa55a2|cqtf+rRz1)DIh%^+?S0?)*g8a7XqZ1 zI%))nL%j7A4vk8}^W!A$a7#l!yz)EkBjvYPV9Q@ytUTvWB@F!w^MAT(pQS5C0un{0s zdm4UTD_RY9`}+&9%8q#O5Urs8)+mB>G?yP%b6tlY#7E{0d5z8)(t1Anz{8Lsz7cbw z`FwN=zllzFok{7iDdziXev3A_Yw1#Xl4Rvn{W5ARlHhK}i%)w^Ou&u|DXc+ zI~Bm6**pA`OTqt!p>_cOeG1@DF9W}k!f8@920r`{$Bb4IGOr8&Ln~rmgmf@ee@Q?C z^WaeC6b}n-ST3>E@DNi->Hw${my075|0+UAsO=8ccH{e3aq?8GDk9+7lLUv3J{fxU znxW;fiuC;L2g}tm^h`5_3h604V3Pe^-{?7+YYpK3D$?`9_Z86dkSSD1Px#^Ng!-IY z*&v5imCv<*?GrsKyzZCsIo>SXh;l-KM3EwDi76DOr^tbr{a2P8&wm+xd;a1+(X)Mk zo|fYDY*T`s(g$8{k!^S@AtmEH3+?Ii>3yPa|8FbMzbj3lu)K@zUklKzR`1T zPX+WGZweLCQ+Pn-&P?Cv+3zoXlFwbsTzX3A->4Gwls<4W#=JUKv6m>Re?V+A$- zIv&#)dUXC<7GfN3x-f1QzxqPW0TWLkR^G;|xtSpkRpc!A7^~u-%sNj_d;iZ zz93a!i|9a}8k&Oz8t;8Av)|I2h8Tk3e0aFz`Ex->0#Gzgq-}WawxNTbGZ6Fl-r9Ye z0+Ud1D*n|OkGG~f%pCOtU!yeEi>3m9yia6em07X|NP)HF>r6{S)Dt{hkIN zjzc*t&`3W7xcV}KC%7Aw2y@|#f2&UawPq>zcXoL&Pdnc;1p|V@L--!!%sV=-R~EAZ zCMWb`yYf@+aA>uY#bJJm=rC!**%O)Q$f>i52%y;&n85si5jx9*+!zOLqoq8maD`Hz z{7t)J4AhtRO8^`STE`kmwO?Kj{B3gu^Dk1EBPg%R#(^q9_{2O!PoB~$LC=ps|0V*z z68gswjrw%Q*$wsB>_XY;%`WwSaKsa$U+8xE_3rWle#xw$oenHmZ?aEsv^R87QrNbf z!(Xod>Xaofa@ijroy<1tkmfkTr;QvluFT*Sk_~ zzzy(5o!OK+aC5W=&`o5E-ulZCwQTK|jk|$VvmD;Y!Z1$Gh+!2wFr^_^qvOJKxnLOu z0o@v)`3uWWN}VD7x~t)lhRDV2Lr|wd6Y?AE_$8PI%oz0VhBvlcW5vCgvpI6WI}ZtN zJT`@_6$?v#_eaeV+Vc|J3{|z{bSR32AdB%r25bfpIMt>ND8hLO`Lxf1_ze*K+VI?_!#<%>>jO-e=GfN;QmaWitK|rd=UTw98W@beEcGTWF?jAxzr>Oxsd3 zZR6RRuq}mYn9y^Xwj<%9O$sc?tC?&>c(u(bnpaHe{(1KluZm?3@DlG5S+W#!hF$~S zCVJ}NhbQZtSp@66ixTh;NrHxZ(i*uAw=3GI>-p&HsE)F@6{9;$v(=P3@p%HK2^cqt z^qiWoI*PZ+dY#ih1>xWOz61a0Qt;O-3V#Zpk|+G|`uBgro+EM1OL~fu_Wa7?pxR0Q z{r{Oge~Hb#KaM?5IZ?CZKVZ*iz87Tdp@X$%{-^ADF7}0zX{&6{hn%2!)hByi-v9~Q z^MAu0dazM~UZ<=*Jo#==Y~y~S#rFSI|ArmMKaM?2X#S7c!wK&M8N2bu|5ZNqnKpc7G|GAi~GZMCr zgn;0MxdlB^E#q!Pilg4jSYird6EC>Qs z4wZB2WbAS8maVKN;jrtWBtCl2jhxEGL@*lT6i+iIf5sh;*k||gUS^-29=+7gZkNrf zQ#xy>ydWjo8y3r79s8v_X2^?fWNAzdxxwtit8Kqt3YpBTIoWvX3XH|jvY;O4FDz7( z$=T@ZNiN%($%T!&z;NVBbciq*HZnR}eF->9ZO4#@3rjOz)P>R{9)YhJVPsVI7iJ4z zlUjN(Xtz*5Fds94ntB}{*ym?9tTqF2GA~0D0Hnrrhk2K7Q!xrG;LGx>B{q_eeslwu zCH2-(?)aO4cfifYQCFM%^~gr`@>Ry8==%{SY!13tY@?T}lWIO=$ZT}VZ8qV&g8)%T zz4CRS)j?l?-5Q*-!dyaJ{k)UhPVvaVy1Lw{QZAwf{SD>A>sMf=;ayB6!2){dIO@uC z!7+8ySyU6uQ`q>Pca{)rOkH5J#Ig8V9jcU`i$wTu1J!olEn;efp;2I=6tvb&p1^)B&X|?~U?2B>DJKsQ_VGc2{ zXf=hJHP1^XUuEMkX)_W4CyInAo6 z16l&)V!}D5{PP&HQuL-p9f>&vA-1+X$`as7i@Ku~@Tu59wUuC6)$y-r{epMr=ck{e z1W2a~>^4`mmEq#n~Eeadcyzpl@$TvrWL6B!!GGTmG_c3)*Q?&N9s{ zMls!--VX)*P{{aKP9lsS#-@-aP?bP!*At(ZmEyOZ&*|SihgOCIUlXF_57Aa9k{^`ot7$ZMkJX5Iufr?V!+ogwo4P}{G`lo@i&!Asoc zSgZ0G_52ysV^aVE#HcEnNn2q~0|b*TxUfGO%1FL^2tDy0qm*U&!YW3BQA)!B!G zf157X0A>=ifsOc>Jf9XHaQMD|)cfH$Y1=3#Yjdjec{JJFsKliAm#>klojuQ{O&uXlPG_eQ zU8JY5&@w2he8?!g0D;~eA?}nwvrVsLVKTAp~M&Ndh4&dV5_p*`%p3Tflsq-#F z!=X}hBSImn=A*k^qdnkECdmmnO*qcJ)KyK8(5XeewISU?bdT$GM>l{pGc38Q%}jQ_|S&Y9Xy!V#WoV^apKt) za4;p#`^8P)A}5J*>U50RC4~Vwq+*i)fcXzf0x725{n#$vf5A$gB%6l!^%@T2ONGtN zrOv<~z&^~9E4nozjBqA|aA;^U23|D7)|-Trzy&qs3xJVS&x6R&h?d$vPbMtnwGedD zzjw7sOHL}~)Pn-BAxz4zK{uf5jVYps2G_T$f|;{NK~$WtP}_^pg(?AziwVXu_8 zdohE_B$IJ(ZskiGg_u8#ynMCxfJUfxgM0HRp(8z5Gng36QL>r&*#`GaDA}w`lZWUt zI4gy{y6XPX)9vM*(s;Vy4~d+`=boqkuk;5QnY>XJG^O!8`;ad3?{}B{!FLj78iabP zZ+noU%OUEnbJJgn!@MxfjLeGXe7{qCWjMHF4h`V8z=8%y?irIi*%sW$^!#}PlY4z6 zch8-|e6eogTwc>G+?`kN);h@yY&4F<2&N%clSNz zEisAFW+GVBkeWU6rd+p1eduz|(GKWVAx**0EPT4?IyCmOguO<5`ic2RJE2h=Pf zpS-HaJKd$EcAr>N)ukVW;&Uj56IJ&!>+^5o=aXMomB^^4^{dgC6}7`BHh}Fh^dkdo zMqNUMZGv6#LKWVd--t*9W#p1a*+bLIZ~qA?0G(AmJPq?< zv-1<@azJ08uNmmr2bYhaYg^HFkL+iQJA4urBM{{d42e0^al4b~k-fKmGxl_Tz)z~e zYuDWjj(>09;$BhFF4ne^;0=b%A(ZG$SeVr?zO*uNvJI|J^zkiyGE(8dXjpkz-6E zewL|DamdH|xJ)-Ah^}xafy_cy9cs)cU%is{y_pEBXT4{O1&7zNYU?S_Cm*?DM55{e z)F^=5@S#Ag{#DcuD$Y3UcQ0)>A|W{t57#AHkAs)3pR0v1HQf9in>{A)*is@9OV?7) z#JfjKocmjKCs%lXgnLbN5HINvG+t1}&7^o1V9^A+J6hlMXXpTkB*VZEIW(3yVs%;Sul;0^5kqMEAvN5wZN4jb?Kl{x%k zJ!}p>6_rW5*WZ#7F}|9m@X|IVvLV5+pIO^PLL;!zhQutt%wv^r+)WmEmS4}aoP349 zsGD0J!BsC!NxR!GRcz4Z!G^3Dl*@1G5=E1d3-{lV15Le~AW}t$Up|~w>lYhU)Eo$@ z7l{UZt|ugR@s>F*5Mp3c7pR>f_g64eT|xWz;)z$f{g1e|6?v7mf8$ioh_d$a+e4Y4 z&J824?M1{_5Yd`LM0A{|tiD7X7NO5D&^Pjv=_t3F6WmVwA%`^5&Kj5yD-(Qcbr`!Jdp1Q9!^F1fQunm$I4%8BSSVfyy#lD-Dcgootp`f0g8(2AYsN;?*2FeN=?<&$^vO%iVaEk5^)p=tNX8mN6P`Q+t3C{Me&I|CP~ zM!f=1yr6=XldcrTYGIrXMmOQ%vQ>$&*VVb8_&7fPY*>4v_*wKW)RV0w3$H<~;pdtj zpea-?tHw0)-2ow|Sve`?>=bfZlvAUe%|cGQa;lY6Z8;HB!cNks_X~L-7D$Xi!u>O3 zidU+x>6?2YWJ)(PsQgl8J{&T8L#9*ats!%E$Xu+ky4_UhNkERT7NJ1lx|jSOhn?RYdizt`ngIR zWUN3QbAgvU`t*c5))TNAXZ!Cfl815j{*Ytxs1yfp2{|SYFXxu)NFenWXs z#REbPY83Qvr;w9Xj#K-ag&dQ~V&zna9FvK0Yx%t%v&qC*wZw9QOyFO?D!v>tr;`v< z`;_^ZGGQvKEkSpi^vhUjjE0p#$LsaW7>&vX9Z!kS4iNE`1pO55g1(B+uEle3Hos#I zs)bQN4=jgYP+{K&5a|H85e zeC6=RZ|cU?yt7|<`bI@HSE@hpo0(CqxH)g1C-JsJP5DZJmUEY1gaY^_;rg&797j4H zhuIUWKfP!nI+MxCXXH1;04c->!m;_fHaC77NB7Dc@qvA04nE$4L>t}q!oA6kZb zuo^bFUc4IS$cdbwf@T>coZ6|Y%BcmF{<97XrNM}QRuT@I# zsIyw`>tVsm8hSNKX0{sbG0iRz5D)n7aBtF=&~rLeW)AL|A^Knf9Yk`z0*(zHAMp%wj{DCV%di3;dM z;@N*L4PXrXjP1P+W(AWGD-YtcWc0Hs@X_}#dP}9^_2U1;5!HKtosw>vnm>t$-om4c-H;O zgRg1e-tM{^eU0d3U1;MEyx|C*0CR@w(R?q|s`m$vGes&mwMf3E_BbePu?=vd%2a|w zU>AvTjqa^MR%u^+uu^C~{@02{_gi>^fLJBK!)9rltGDb``xqO@zXOT zCB3-o5cCU$X<*kIIGl>SgDFqzCdvVhP)sF^VXM^O>iAB@;6g9kNdL?}r+sxWr`!i2 z+~z(fwV}L>-*_3Y_tXMvpKl*Cidu`a{PA-xDM(9iHR^|aZ_&uL_B!MCUOD*xM|`~Q5Sy!|k~y?n$AuS5DJOF{|?O|5JmHCnINt}?i% z>4|$lfinGXbYrK)`VV|{s8Z>EWe|qMCS~DkSo(gwqSC2L`c3nr39SKY1mILMu5J?) zBp0Ft9B;THuEFi|h6(X*9~gVGQYvVEE)&f zq8(m>?Y?vG%!*h#!DpQMkY4~v`LjrLQr__ct`@L*c9iYI{D_2`!}%m}pw14C`EzD^ zdqKl(r*k5#WSH*Z2~?7AY#7*+jX96?5yWRGID*BlAjft66fK|JFMRJqvygmpW2H=F zeiv_tzv}#j`sk$MiFbeSx$OqFJ#Vdu&^|XRIupkM=P)Boid|(>VRI0X z{H_b~>(vuN8r){z^}kGS>M|8+EW4J-48mK?ez)=Q;-$oG<&teyX77mx^?RWtRPWmN z4I*jVXyV)sa}dc{wl_NVVB2TD>`dLu>w*7p5WEY{@6caMFAYnJJ24z%h)AK;dbQVV z?H%jsXfRRzOU!T3;WV_aMcqj&C+#*7LXs(;{AC|)kqjsI;-NyZ!c|V+q+0jVZyQb- z{o;@DhammJ+QuN0Onm#H#?aAcL^2t?O?_%{myM^LcnQ!u=J2;XzGwNgu=Ob_>)lud zvz_6T+$6@^ko_m!SqNy8+xu6(e|4mqUCVNW00G?38}jiQqVVHe{CIB&ya>aUZsT*7 zW~%NUqXJxz3Uq$aK4q|}Z{r~wL1VjZJxbOJEz3h$vJLuJ2g=xJ1cT9so^)`wpT25hPUCNoS@W zZ;@XJh@0H}%RVnTX9VZ!+<1vjIG|NGMHbn6ai~+@8|inu91wLT8c$B+sJmkS(eW%z z16^~KrFt60L9uUpl}N~az@l%aJFIS819QAFn&58dwpiw;^_p?i^MmE*vG@OdL}KFH z2S!pQm%QzD^dmR+m^XA}Uvg0q$NLtHN6DycdWd}EX?%PhsAIJfAr@v?v-~-ih>|p05k2x}qDuph zV;G~jTAjQ=>zYQm2K9fgjQj_$Kq%v!KPD3QtYZ4T^epwLjh;+=?gBUy6T^mBjx+%dq$-p&G?YJ<&<`it8(4A$K4?gj~xb0e>Y; zYNdRXzP=tKS~t{VmDY9i7|CTK>#72CwZ!1QQPshQ34NW*Mzrq0FsRVFNN(&Ef0i;Y zI&B2zuz%uQJS@4%>-aPag`fAuM;D(f;xKG-@hCsV&bsHGFpWzd zX&B_;9#)PWD*B!w59giI`Q#@cMiT{e83CF7%(@HtY4)N!pPbVS%jhYBx1&V7&|^Bh zDCA=#%u*n?Sz(178g&N%p>-nKwa!zsu>-*Co4Y}0uh^c?DnS}dP>ktyJI>+2cOG$v(?=TMv ziE3Zjm(&y|QiTL;uuy77rq0#>B^u+r;Ti%bv3JluCtY7u2{v<(*Bq{y7%TBG~LhLq>VZcD`GuK69v805C@mzM8EH>6p6wMu%`YDik)-ukum&HUkopPR>#2G6o9{o9(~X=kh4tDY3T8A^awwy6QG|d zf*!uVg3reAKU*ny7#i@ORsvtDUz%OSuN?VLFx9h+Tg`XqWFup#NfQ2FX?-RS-5tFe zf&CI7bw|&bsE6jAJ)=?GFZZuuK2u-N*;n5UQ#Hf+SRxUz%AQ5>7gYmZErp6To%5>;UZKIY zn6Z&dDlmSy`r5#$_5=bm;3oV>;^(w#>}bt+~xUE}aVTq^2|jAmz(+qXzG`Ia{={KSW<;kMHU# za$%^9zqR3GUaAf}FW)+wx};tHyu0F{&u|eYz2zrrgOS1Pga6mVT<=Ve*ozIQqXnDUb4$_b+Yy47^M}^fAS` zu!lbOrsa7%w7pPS9f0MN-}z1v!1t3r9;%wgg)I51dEkRee9k+0ivnw*jt>By@1FDN z!=bL$Jw-y)G3luIW}wc(fLz7zmwD=XX$$J)e$VB9Q3rH#qd=X)4tM5_#ncHwKKXdo zQ)d9rts7ByqC$k^^5ccvMu6OSD+5LmS#rx0e12Mr%OBVs-VRFVdyFvPGH!R@hdTPi z;?()>6BZ1>)x{=|0zI|45&-WJgaU@l4e)W<{jN~%vlCo>;WoSW$n zs$Ab&+$SN(C;uh#B7jfUiTWgt({#PN*wZl{y2gvH6MgzP3dg%+16}P9^3Ndpe$bLU zG|;t>By=s}aS{KEE})ZhKJ?9|#p{tJbO}K|c~qdw08WhQ;w~h{(@d|&#J3HrfoJLh zDf;8(`X3vua3KzsdotE5VJ}o62|LiPh&7ywFYFpmT_Oe6TMe=@yzKrr@W%mDce-7| zXpb{G+gW{DCknKp?0h3;%D+HXrY33 z>4pUJ`c;@uk}01&;Vu|MveCc!ioNSggEC8U8763P(gg|!*h0{!uKsLr@`kN0G@ z_+Y<<%G8kn88;Hgz}m}!h~wbM36}UT^8#t70qebx_^}OYQFT8y5tW*+*NW6$KKZTV zM6?l8JfDvJbw2rTLBxs|8jQkNuWGc!%?#}i zwEhi?>pj+%@=K}Eg+3W#ABLsfz0l1C{dmRWAwGM~Isc%!LVldJUj)(ZHV+(^Tx%3rKc zO_`j=oHfI*bQr7XbUt~G2=L5LxpPlZn8cWoBbfXX$fpiU5%!_8u0jmxB3N1+D?#5jAJPaYW< z-7Kwb7IE!5QZY)SvJwb8S~2xVyO@PNhhHk-$R)E&=pySZ?`jcdouh2eA+Ep-hLJhj z`t`6q6XzZ|6zTXF>5U%g>K^Iuc;u7wjug_8n+QBPlSW`hyX+m9!8J1YrUpi_{$*H7 z(eF%drui%P>32WRdpkmKmvu`>W^!D!V$MKs>87Z~&BDMd&+Rb+c8Uv($D_l>PgPe_ zG0OJ(C_uE$42HB{wXV*A8>;4vDr)~PGbJ+CE0x*(f=1QvciC@rbD*eoP+ zl|Px#H?ZTz0+uKO>iY~!zx(_!3TR?!aTD!UX!)9Iuv#?*F?f(!wJoa(7m}vPZLIGb zgsz59DOvw-i89^lKa3*O?VgG!55Dg%13VRHw@l zV(TIsePPg&t~bQ9!jl7Z;=yBkE!VaAhe7}0wTB5;ud-L~^*wC}_k%P5lS=rzs^UB~ zX1pQ9YN+b`h1!#zM1nXHMk@SjC=rlNEBoEky1EUj($wX(iowXK-{8oqf5Y&PtqpoVLGK6j9^$=9V&1>-bKgaw zkRXlWlf+>D1P1>Vn88&BD+V?<>r~vIb2j!&xqa<$nzTcNf5xnC;G9pqNjLlIhtT0@ z9s0TBP#y*I>cK|XDtbK90*lVu!n4SxSt1Qax~qNxi{b%9q!P_K7&XUu($1Xbh1nQ{ zY4tUZ7~Pgt=`Q@xqi(V^Z(Qn06Iq__`QyCZQ9??gc8xV^|5AoI8o9q1yuT)c?~|*< zj(T~V?MFaqCn#Hs;gu4-+U1FARt}W-C%*n^Orb@HGiwFy4pZy2+8zM?6IEpFsQhl( zamqWaoMV@Inx0Sf%!3)0np(on`jp zKp{C0BozbB_!&gMnr|bcV&e9yuE!0hw6oNapA(Hz^TyX1gM#H5`IyLHW}!?z#?B%c zK^X~m+Y-;$UX?(oJ>m0gOXE!S6nCZX$Ne-MZ?&qws1oC6#foAervg&kVH;l!ExH~u@B`s`{qB9)j zHWO`UN%yls_vZ%2`B0sg0ZF@aj`sEXtdK=6dzyAHRQWoZpMtmF^0}*0Piyf(1&lDl zz{W%Nz)ac{(`HKzh^S=g&)!2662uGHE(#z(wkRO8ry=`(KxSZn8w;|10y1+K-F`za z6TW3hH?1^!Gn?ZBLVqhKl7!v`^!7CLV*`2v>j1sxCVh~pe}fih_VLSN5nrHwSR?d9 zMOB@|&%-$Mk&`2bey(@uZ@g{}y_;pNeDdnaa_9;Ug+nT-DW&J%Z7Ilr+GR#F>s}k7luo29#l}< zGWVE&-ntqFEs~QBmA$WlE2((U(o$6Nre?L>_SR1&)@%xWIlOj?ihGfW+nI{KcNbiO zEPTa2l?-wn8zf3q69mjf`W%gbs;+B|M;bAEsW!Z*4!p1@?=Q6Nq^45p^ENMf#jG)q zS+({Af0IjT!#o1}$?8hz&=fh=x$j(UF7RL@{ot7N^T`%{O|I|anDl;;UJvOJFKoV2 zSIo+NAMvcTObnP^4;oKx>I0~8=j8xunS0Fl`(GUYaoZc)cnOT@qdSaqn>{iW?WobTCbK+pbK)N(KRX;c&8BhqB@G#v(4w@=NxqTYc%oNh zVD!Yn+U^4z+d|TXta=8GF`4I{V`6}(2*Efqn`O>ToR#g^=JAz zGS^Jcs=Am&c{;qaD4gbGXwIb!*u>!KA;IrN=A>|0FGMd^p^aPp3ggy(GS>Q)#kf`a zoRc;BtORz|(5INzMH@-3{Vst{`(-+`tBj?i7WfKcp1bwJSTSs3d>N;hiOJ5@P>UFF z^Qf2sHKU6ea2o6w`1?{HBO3!)Q!ru-*BxWU8oTgjKg-vu$S!J--Zln?+Y^RYP21F9 z2uvkRAhJp`p!IZ{hHB=lqs*UO1C>BFfosrWRx9r8%PQA{`ckn@>F8qJPzqF?wTdy9 zYK-Pm4OMM7eF}xzf|(V=bx;(vEE~Gup9L*5RGUauA}yPEf$wxns9{fO+147KS@&ba zg_Au*L*x*=sl|~-nI_XLFd&kRgGh|Eb5;>{`&86)a1MN*_u_rBN6Bx;$?KFnBu@4! zxev(XPj(~p{fae5YdjdR{GzS}ZR6{hNr0P_w&V6Tj+4O}v_N2~2)zQeaeD0|Z zm6onXrIr?7X;G+TPbw`|rCi0SulR>>TT+{=B}|`h}LQ#i?+LpI}LOsQYVpVmi@5BHu@`e!u}eL`si1aPMOz zD087R_9VSjouRUxMCg_36m?ou=Wl@l%R(J{QYV4&h5?N)c?Kw{I*qDxP^eR>Oz_&1 zI@OkD3K@Xw#h>>Yf~B#lJ$%>OcxR1Y8l`%Kwv`lqdxGbTEN{G)I$N4VOP)cErYMcV z-pCBE2lsPNol&Q(&$T|-*|(ZFu{5y9U}v2|0ro?gRh=17Y8yJ2?3v^FVL*A+rsf>g zq(ysn_N}74W#mnMiUaGismZ737gbwyldo12@kx|y9vT#mv$qkIm0=yrQUPvSMc2S6MC(M#X5Bd@0D5AxtspdT|QZ77preJVvwW6 z$$frP)R9AWLH{KOkB585aZu{N;Y+gwrIzcgd*MmC-t-6)YWXlr-TjAxM+zGPusunq zL)t&BCmG9PB~4En%kp&JB8XM}#sF_m(oF`$-FMOuq_4DFRnrm*dnEwW#|k5{ZLo6~ z)XH_g{~!!%YvmEntJ@G4?e1SALrUFtL}1LjuG(a3H2R3r>Hm9$gJAy ztelkro^JJRy(^w>nJlO3Nu?e=>D`>52THr;XeiDUNQN7esV3^xL~cs$q3%n{pf}=)Kw6T=I+;zz1C|#*b@7S2vL-uPRr^*$?jE^;j^T2Ormp z2e+x8@L>NK^0&W>jz?^I%~Sz1G(6{^fhEf`OX@cdX>J7)k$Lq1OkhC3?N6 z9Lo;^78ML;MSAhJsy<2*c3pSgBaHaA?Uc8O2HlNzw$PM$G4=g^wSGw&OJSNKeq^3? z=LJXx@j`+Rotc004W7J`LelQl9R55=eD(m!I6B|2eo9Al;$*YS%Ff6SISV43(RmM+ zN!X2Hzp^sXHu@>%n4lV5jUPNQV56c7Du=*xk^EPHXWCJzmiRM8e>e-pi>F_JU`Ts4 zE7u-dJ~|t6!(DI$hd4%@vOZNtJd;(X!(~TC3g=ler>85KGKtKQBjRHq<>wnFaSWnU z=TXDbzr8Y;p7gnho+0?w92Q@Dm+#j~;PYaTX8q|F*&=@N)^Lr)pV1_BtfC@V6r(E8 zqfY_Q3C?8}4sB4s!9D%60$Wo+jeFp>*o8Y+KdMx}a^sh+% zG<*1Mu_rzdmB?U@+)IZaLu#c5x_aQ3*HY zSe>j&Tye=I)X~wqG#{mdU0B_4nKBaL-KCH~=d~z;HMFKq5S6o7bqfQBN7DVoTohhN) z&C(QqK1!X5n1<75e=@+d2&Tp2HZ?+mB>?GHuHx1?Y7VW_m7E=vGAVytdAa0&^dMJs zy0?jjEruhZP9&@|VLI5oEYtn(2Wu$$;1*3qdJHJ~1QK`w&L{so!O*_6yWhZ5rW^my zuxxAezNkCsqYgen!a^S7In>fCWHhh|3I`sa zD$BG`iL_gfLt>_AtsDB%liBCsQkupub(D5KuME8n$HXLf#efUPH`948t?#zaMfAJX zbSBMv+7i+jCc(sTKlCidY-e-~I$}pDKa1gS{)+JHTsHFV#WXlp z^y$KsYdTrErWBdrS7UoW6dyzvjA_67!(<(TGM2@Cf9f#fk4)pR?<;mLX%Z3!Wia+O zBC$ByfgmrfUp`LI==2S8K%uVn1pQ{XLafeP(Uq(lV8u)n7-0ZcE4$vkahErmrY~8y zPGx-aM+DR0ext=UVN{idjJ{(kb$Wn9iQ)QqYopnne37^K&4{$*g_8dd{Tn@e|1K@< z--OkM>EHIl_wVJC%KJC%!WH(feEem^FcXt8yio!((fPnJ5?G0l8r`JuVp1K^A1^tX zw1|kp6+ZXh8OtG&>2yf2-|YyI!fo7T#G?loBxhK0TF$U zGBh^kBa!C)TR37`ZO$r_i=34mH-qj7#}wLYnL=@gwG(}~A)Nj=brm?7D+D4&{j8q_ z>I2YD>VV91pyEI6qM{IqSfvxFBwav;Y|+643HE1JGaSA0cia^@xe|NZ5MQ7XuvpZp z5&WT|2IHSjs^_^1h});+HkKT!<@!FE4_X!IhMZ9wd^DKzFn^5Snm_r;IB*j$qPZ3U z!bk-O@G0#!xF3GdmJJ6BTNw^g!bWk=^PrcXF8!m@8LZX5z+M*&XPe+w2auXD-}@(? zL+@1aQyH6ft34-PEn0SvT4)vhmlue+{0zp7DG{;J5MhZfTC>d%PF5y9$>_*J#9fL zYZ*dwXepn(aqIAQN81GV?rB3WFWLGr-SV-isLCaOZW|e6|43dw`Acd_96QD07&}%_ zQ)+k0Cyx=@d40z0xUKztChF9%M}O;u8poc9@u8sK?a*&8we5+b@0rj7TQW;yUKCd_ z6GSS>^^6M1=|bh${Rx{Hfra}M9s*+!=7E2Wg}Em`?f!%=iE%b{svdycQ{*arT7Bsd zE|hlpLY^f`JA}W5S#{~FIFqU#h=Pc`CdQ@H@;H7fi$^fkolDv?Idef+u-}=M$EHy4 zbHLqnKKU#bS*&N!`;m}O&fkhG-M6!tANKtsv==!Tq*-`smf2QjLHT68Ah1||x?z>} zvaE1ytaCwoY_W!;w?Xtm2=unrW+P%v#l!najp6+yc^UvY#gcD?6tBSzF><&PEmbf; zCEhdi&h%GVA1YUW;zbZOx^?dZk)gc(SovmbCl@EpM=w^_!UjU(@a`qry8QlFD`<<&LQFTF!uvwJm2svDM%3UW}X|} zT1rAm6U??HA2Z+!KGdgiv{qpB6<%sAXmh`te;hs3^qG%9T0cP+ z)qi51)vp`o{wV0vlEj3`bf`dFE69u-37A;QX58geoe)t~kU$y#ER@k4wnX>cuYUq| zMg#o|eqq$Vq8!Fc%Dwebh&5VkHGooX!8$UbVamH(rdns1l)up`c+V$;Li_iMGTlob ztf$EjZi!8Ve{D;8TV&s-F3~x=*rkHJnY8=aJJDQjJeF(8FErL>`f2~8my|kCEKrsz z11$O!z9N)$*AcAu_~mKC*u>d@;5ZoHFY3CnsIj})wIdv02}vwcfN=db05j~E2}x66 zk|}V#yQjxy>hG7iHy63>+sB1 zW!wc3=C%Vxkj0g|D&$3cSuriE?4>m7s`XY*1as2=D3RRhS7LJL$e92P?wZN5e9g_d!T2{+^INWlg7sg}GV0^S2F z8^# zl0PaY;=fw+C7P&9RCV==58WtrK6&{T5K2vKhNUkhO)^dOLyN0F6qbyE3q4`C8TqdjVzw8aH9>f?jZ=zdY7s_ow5JB9wFhV|!2Yfpanv#)`(+Za4?oEY9! z+bguna(}bJkRu@u9M0ye6vj@8EESR4z7@2q?M&MdA{eMDYH0~unp1M&L`yDt!%e=| zb3|0n6yG03d19hHAz+TBqif$nf6f7s$;YZD>5s&#-efRx%Y<)fyP504x4ptRbe!O| zs@C^cb)J{jR7F)6>Pakl&yzgaF(g@JprUxEM&e0jW{j%MbnAZw1H2CXtO5qN>rd#2 z4-U%wp`yju;~>`_v8-Pw^I&JaBrQ#G_uE-L^WSUv_j|Fhe%RQ+@NYh6-0`=X?tE_x z45V3$qv=v6r%?euMq}_E4Q|vnl54d(*r68weB=;UrhKSwFY4`Z_y5HA(J9J6LE|~^ z#IKc_dy3QGYbe1tOgv6Fiso@I!bWizoXRJU-IyMG=X2^@w{mR4 zQz!o7rCHRq@%1$QMGZBuj$}i}erDW3{Im(rRd%IC?LEw2y;kIt&wouk37Z#;$Y$y3 zm<7dRE1`dQPKc{~^$}edh6n7tRN9oLyJFzJqMqxf7E_QdxYGn8#vZ$&dN^vnqSLMR zA{+g_`M!b2&jE z(2cFe*dI_MpIkrrcq72wXe$a=pClRAfock|zK9}6`_v*pe%Xs9IlOZb{MgXCEN3Rble0XJp z5%v`~p3(hv?t8wNCa6@d#)?$ICN2GE$S%x{xx`Fj-dW!xwQ zU{+*|{E%;*hPAoL>Xl{o9OCVY4O0FDGA2g^HbCI+u)M0P+UkT8eGt)t0B?s*ng0%> zwqim^%096*Cd2h#m7q)o)b9;?s`fPoEUpm-#!V4-o(&L_3}T)j=o?#TS_90DJ=j{2 z;iCsE?J_vXJuu`l59vpZGD_x>Q}v*7i`(jWnx!GUceIrk)&xp7#|)z2N0vf=KKYPn z$R&TjscdNQ5NrQ|sQxGTHfcawoHJcvwEet~K2KdBWwL2-nSD1+tF#EI; zW;G11nMKH;ZV#c9OPQSPYqZ7}Q>_x_yg zwM;mrcDLR8tGZw}fZ+{RNNtJ1rU1qW+#Q9StocYmA=kdlCyNyF0UZXy`|YIeqBs?s z^a#cV@RZu!QJpQ6kxL$_2h3`5Ga2f!>fk&n{UX(gAnNH^K6#1tI@z(YR~==#m;2*B zSE`sLHLdsNX8A+BU^iL*nE@-)-tRViE!h4>cLJYg8_0z7Y>M98A2pdlLFzfzDqXB! z-Um18tLE7s^kk3u7ac_#Btt6+<9TT12ehI#u3^-fOmu@)P`~$i*yL6|;3~4`~OUK}u7pw=N%?(35r7CJ{sR&}^oT!(OrY_Ln zC|W=&@mHpyT4!w%(esh4w6*E%39f>0#X2q?@Iw~)#;atJFQ@=2mmD=sf*pJCPh^oF zSXbgzbYiIV5*GMIXcE*}cHUx`_#ui%CW9p`%cZy{%y3$;gTDVUwDA8Z9ziO@_+t66 zNjT)x7Z7>NIwN%VRbSQp_Np1c+}~Y7>b@nUt_Yt|P+xvOku4`>N|r{kNX4X<$`2d2 z6Kd<3&!s|b7O?cYAMO2lup|@Q?#vm=MY+o*5~Ys$BFd*PHc`Go9`{469f=YNUiqHX zd7(yblRJ?uLrB@lYa?dJuTGX8r}68PW|ikk54$VYxhuBHesbv}es}Yt$p5dlW8VPN zlHk)9^2=SrtQr;KJSXIccqz7(ESXlkfD^*^mumpeg*XlK!%(rs6RnUAUhI8(qk55c zKYu0i=?$)4$1Kbe=d@xuj^n@Hkan#p+AZ_p1H4aXsF*GWGCtHhcOvlTYne~4uj+cq zjF~^gA`6$r9&VX~oBWo4e{rh+-Htc!0QLehn#$JLY0)+$j!LtHt=0Z3nb$XJ<9}h zm2rD5rh7p6C8=nZ;|&3|3roX=Lj%1ab32DuA7)!KrAtA+-YOan3WgOquc%i2vsBFxO9keeEgvF2$ZcoySta>;;NdrqsyLGpc zvh%K@6^18iXKDW>K6mJ@0HI{{_Km|(bHhIdIjd!N+&mi1RZN>HrHS4Fi`64j(P89; z4=phAzkd`4_IwN7f_Akv%4{}QTEm(lY{8jqDImMq7uGfI)^!+?)o}@D4bc zva1{Y37%9Yfk!9v{b5w}B;>0rBh zmQY7%nnB3r1WL@P#9vFBDV+{zKDjBE^eEL4DgsNL5bw{hdztm(l8O_pyy9lqfCx7> zyVXY|Pz5GKWA|q+;+7s6HwjSmnWX^yIKTxAJ}~ebebfE?Fz5@ugI5)w*4v7-2U{=< zGyTgf=_|d$R9e?~71MgvWk-QD+*DO6)^1M?(C1_a=^!qR#>10zouy~_)30<`X*Y3v zc-2+)?R(2Tanmn7V`}prd=oy6p13oy&iLGZ*Kh>&KiE?s*f=|-JJkv2KD(CWEiJ~< zJzYt+#_lB`%L0h0hac2wWnq5OE)7veVvlOB&18v`Oxv!^OVwQiow{HS#nHKMbpLMV z;qe-z)@SEdBj4`6?yo+?0 zTENeI+ZJS=uFNN;iMg?deh3}UFdU#`>`f~8{4Nn&aR9%u<{UsQ_=`)mKZf?24{t2; z7ytcFiNAQ~xgvjY!=lKDRjaFYZjCq8)fM}Ts6CdE6dYW1oxmA9`Q+KA{Istj#^sso z<&)P+N73lj1IWi{n8chCoy?zw;f|+;LCH>?JQV-j1Dc|Cr#1O0JAUrVU$(74qR}UX zumni;xer|pTYh0HdusT`S;bl0{Fs?wht=npT!?Xby1Uh?80b6DI>tyHc%>NT+six* z)Y_fP22(e{VFFCPH?^|ZHy77AH6!6K^s4k4-xEZd(yFs&g3EqwMQCW#>*QWnLOA2I zM+9~7Du4(=S@nls4HM^Vrs&J2kS*?)J9>b$0MY`;h|?K&dV0$&`EU57DG%)P%l4t^`o#=pTZI z52u_fL$=Ws7_zM($Cqu)-k`-1(XzeG7l&KqT4W=&VgYx524VFn2lJW0WSjX6TZnW) zS+WbeWh&ctiBDBw_GHFQhD(${-A^Yt;ip#H$^DyBwi}>R7>ET8$am z_BY8c><2|PdZUpQ&Kz{GZEE36D?9xm@3;@-QXdYxK}#P@fz=-*FiKd5Hbs&NXJ?B$ zS3GFZD=47@Ba1umE$dpu4~`e9Ir7{YA9{HeR4A)ndzLWd@l66dAb?qv9<7s0rEBw;G;!4LMyYu%7ND4n8ulLY0#-Axc;~O(9{V9TXB-QLuMN zbSSYyNc1XEOQJ`}`$%bYj8p_Mb%E5fc>Xus`C3i42x%l*Sf{mw11ylTg?565)T*$Q z-bO}*(ag|R4q_^t$5>OEQl076aGkctkQqbc;1>Dhc0y%LHLzdoJ`@Z)ShQwQHq;=s zWWUf@dEG8*Ih=~s?Z0PMlo0H92l1EPVsdiwXicDH9rn-*?!)!vpJPK`d@SVdFQ+pI zwXUQlG7BDx=^Zr8bG*ivPUD`j<<50=fp3+09M zz>C$y^usQIRsi?QM$)vG{Te5uolE8Z#Z8X%>B&iFwo5f#vwNTO$+r`l zP_%JFkhJNkbPaJ+*%GyHKk!q_B0@l`(Hpcy`U8+zBT6%I4~8wr6VC_`rR5Tu8laOC zn~e!8E^i6te*DygSDR7lngaJ$`2J0XmDm~AhR^}y8k=J{*&+%W-M;XGY#Tyiuw^&7 zgM4>Kj?=?JQtz|1$B3&~bPH93)g8`1XvwACu2rgd>zX*cNs$Ap z&;siKq}_HqAYxO9T-g3NoiCvD__3s=nYH6#g+uIz47*I}l#^w!eTFQ#3NjIfqivO2 zgV4WYjG^Eya!rX%@qF|!hZ^qd1;~krPQ@~qAYLehct<_m^x*B`V8~e&F4{bi)gix2 z*g{Pjw5>PhjpS?VrKPHCQ~6g=oBwarFEpKw**73aTaaSZ+-PqSXk2)@ffL=qR?7HC zJv^@fIk9qJuF-0E{pw>KYj|Oq>F04-xZxlJIqrLkSLtP)(H5Sh>y^&Us@3&h>Z{W3 z@~0%CXbu~mq|;V81Vq0&1BPf05Ty9h68ReI3!EiaDBW*=W0Qlh26XrD*{$o-q$HQz zOXn24d@9{2CV^PYP08Q6IDZKWi@5{&Z^Wh20 zm_G4nLt1iO65F1>DdonOFIP=FGs;Jj&2o-sMJE;s4n_S4PK*Z4^ioSJcaZIvo+-I5 ze_LlRd61pBAWY*H6t{d!&$~v$-Q`cu!Mi5Xgn}PWSFdOJ`zm45Vj&z8>SRM5ds3&_ z>cDq^R7q)ZwojFkTlj`rLM3}rsog47tQs0pQgwP&=f+T{!|H@c`L)yY(FcgSD=zz7 zP2oeI{H& z=v=mqg#u!$4>Gc9N^UqyHEzJ=4_da$;MKA26t1>!v(cF10{r9r*||+s4M<` zpvaVfBCfP!RS-Rnv47BFL}}gIy$uzn=f>Ac{TpEF)9LG=VyLcPbu^aJWCK@yOcwLh$eaRGKKZt4( zy{6ppCHe!$;zdHowy2u6hkOcSp-E!_lz&#n;RA)9iK?!jEb4Ht;iY2L zEPrnnsI2V%tVcB`)U_ueO^QQ0VqY)^m}lcs?!&Y=N#qof+>klJ>oJW(#xu_twwsSZ*Q zajK-WDrHq^pG|!wzx)uEA*kdRtY`Trz@{(nvvXw=daF8%lc<)m)_)YGdWE(~>fpL=7pZ1Cs;2R_f`^f_g9lG^cqVt4lJUOG@m_!) z7fa>BNl(~IHvXo~*oH5VEYiy{6sD}<@Fd+B5fU0{WCvKnX%3C_ghuK@J$q8GJ~WaF zjVSHT35_%aV0!|bCQS#jXL%tiDS0HQ>mJ?26P5{e>`9#_)sb7VeEe^yv{)QHUq|Wz z%7#kzB;6c(xSqWZ!oJWBT}AMG8H9ZLRFfdUq(*dW<%t$J>Uq z^{JLpblOT1=oWD^jX&LCAIeVCwv9cj`xoXo0~O_>KDjzn zshTzb^jyjlnbZ5|gN8x(=&E_6h)hoiCpqZ@na1#5*@w$Ao&Mj1(T*Y|7!aya}@2u)vY5m64hdTD8&SFdB+j+WI8^S`g zRe7z~^Ud0PJ|udD>@gBO@;B*3)wBx6$b%_5?Y7F%AULlR2Z%ECOtfyUjR5^D+6V|R zz$Rdq{}PaX>7N%2O05irYk&zdsn`b{BTa-i9sa!Ipc}3_?vs;_#XPDJ<$zUk(-uvS=?^v(X zD8@+cdX&1wXfZm&lk_sBaTb+ku~2r>?wHW2gk^H?k)9Em?}j}C)o^KzAQh<{clHe^ z`maPuE_qFRF{RDX8=zHP)5<74Lx4~)Cs3L@(o-s~Ykb=iO3kJMS7-Ue8EStHiv{Xd z$};hJdPbD)7bsN?cTKd=pmtpU&x4MN@z;l#pgvSKLH&qKjFZHu{BS2$q)PeZ@k>M= zYaMcoTAd;8&`%a}SbfZm-Q|3I-nd=D!01pr>qI1215Yq{8OT_trJtSNpjU~O2rY(P zn0+Iz79L9W>Ean%raSh=qr5lXV}4cFCnjak7K38pbPf53KZM_~tQw87m9wQq89}$q z&h%A6^QMzLgB!#f`e;vT&$YBCT+!mKHsNL2ElBYxRu9!mI@Gf#>5Qdk`6q&Mq}?+k zy+MK-F_9O_IDb?ix<>r{t!BI1;mex_7u z^L2bfy`hplsWe-aauvO&Xawh!4&4b8QDrc2rYudI1WcPU*OdTira!Icg%auiHWSG}*`(8foN?D?Y`TZ3Wk z`p0%+XQ9Zk(ie1`iG&&ZHz_n*X+=-|nh<=7jc2tPQA9+c?=e>Ug7D$RwBGa$O>A#X z7}XPPzu0h3S1t|d?pF7Q1E7g`HA62kNOHHWliki7L-Skn_DRiD=!MmIDGQ7ol0O=-Kr;f;VcWnu_ug6&|?P6%kp;&7A1`l(uNnS zz1~+Ulfhw6>Qq}AQxO64bZxH$SGHJH*GH>NSl0Mx=@^8SFzps9D3gx+*@JcG3Ud9+ zxt2UM2GUAWRo9L5a4vj_Lp^Fwu2ih21L&lVrzz6RA0HvS{rP0yV(NiaFr`?rjmn#mVayG6V3P_|u?GF( zo2u>IW4RC3Nt_p1#&I`n?}0?PYkB=MFIHp2j@Lv&#sIq7C}kt+pNvo8PRXgW!}ZY& zo>LcS6Dv9h=eLLB{o%Syw<5xCuZPlg7M&weeVt)!!phCMlBm2XJGW5+iCpLBnl`rBF$A)ul*BrVEHK0vE z=T^nlK%5@r18cs(mKe@kJ)VdU*2E=XiI%2SCnhEnRxI*;BEk1YJRiw*bI;gecv5bS zadMqjZPB&SkQy|aOqJ?(WOW2(uveWnB$>DvChoA1eX2cmjA4)dFFj=v^&fVIlTeM$ZO3r2y@xS4aPx=#j?osjN>ud^qK_d-u zeJ`Jm_Vy89e8X!|ds487o-N)8ZtaU(i_W_s6w%8J`^wOdppL=b($`8^y25#2`ezE= z@18NgZ^~qUa3_eUN97zqi_T$2_==BLWdwsY!`tzgCndFgb)dbM#}gB?rlnNSJ~!d%DxjPv6kNMp!XFgmE4y*b%{d>IPGzL zTYQpYq2+{^5g@QWlM4%1+us}35Zd$--XxiToY?y7Dyf0p!d;CouXGiQv_7iuXzPkK z?B!B$x<~)RC#}|`*4-oA;~2^i%5|tcg_xWXA#xPLBZLqQhK0E1NB;@{gm7Zv!rDfr z#Q9_^ng$8;aV|2Fz=iXZJ!)vxU2tDiXqMqq(<&y>@5bLq3EeTUORiX+pEF{Bcg5&O za4*zqnb609E|yAj6>Tk=-?DW(xQ#A-u)MXxi+N2f8RaTgk5UqLlfn8fT*diAdut#% zwm4e7z>uSeEpFX^hR|OJ;ND)_LI?ePmQ&QmmbOjv$t$UWb4?kS>s*R=f^4`BJfkkb zs5&>^PC3G($6q50O4B}n-6;Rwkl!%jX?JKU3I@$%ro)C_TBNfr$Znt9jV0uL#HTr8 z=L0)=cD5I@lLY(Azx3=hu=drU; zEEFJ-(Q7~w8A&P4AeHH>JVlG>lK*wToD(ri>XZ%kx1p+oEfRwT zAkb59J!D&6!)6mtvis=Abj^sJWW^i#;{59BF+Dm(eZJ1bP2^)${}GiLKcom=D}+VE zt}fq}%qJ^gh2BVzjbG6*N>fqj9iyhJx(KUzwPyh>hJvcDKay2oOjXw-#f&^ABIFgS z2mx0&!Dhyw=xwzA@J-yn%JC8OPt+v<(t}@%3#+2Si=#q|sxsXKiteg6;t`l}Cs#27 z^T|yFhyG`TsVR)!rVRSo0TGTniuG!=9b3(~AO9@q@8Ehn_6@Erp$JOY(f!IfBQU=$ z#j%upM43MOz(+@%|8!TeE^lWhN%;qjk<7-PDKsZKx(6P`?rz;0>fPr_?*TY8`q z%7798;kb?QStMc~N^Ek=Mhhi`>ZMVZBXVO`<xMJPcy%eeGaCi`1K5daaLF6P0@s&&kp38i8VWHjP-jviJBSEFzr%PZ) z=YskmsIOeV`|bZ-{p%s;?haA3Rdk>f}e%kX|9-%&onOL-gtX#XZ(uZ5~~r zAgon_2I0N3A&^xNNEGj&N-}VJtwwM0J7dmfZKvN|{0AkMeAMi%+QH4iv!>nJ zCwb&dI;9f|vl6ZAX<3QBqfhQeU5ug5`qvjt$%Q5D4xAOR)Vn>ewif;2Tg7AHKHyx% zX5aRqLW8<$lNR{F+z!vIzxduD|=DW;eW(}V zkaTf8U+3oCZ8t>G5Z+O>#vbEqDM+tVm4f}i{4`-jQ{YNG*uft+vqinvc#%I%uz@1M zx!u0&`&IAtu3i;HNxficM~y*Z#-$ZAl9!IXShIWWt`gS5_uEcgl<2md5bBhrXlvXk^od7VY?*H2xt%!sTw`P6vU;FOcA7>OV~;kjso z)hM=+uqH|tkA$@}LDuTG82RREVI+htSDdL7us~U4f}||JMWk-+uc6kmyXfs;DDXr zdOq>`LBJAEMd(1*UA95kEHJPcQ_+qH_k|>SGKi{I1R(czu>R{Pk9E4RZgaL*&t_q5 z7D$VJ@dx-YR-=*E?)H2|BQIpS%F}-sfcboEpD#ZWq)aY}hrGHxAkrOzufz8J3Z z^+zlZU5M)~|LMdIn_ZWXn&&+lletKcZFpSm;N;N?#Fl0)O-B*nm-=lVs_52(enrKXiC23mk7mIh)a3f)mA9fYv69EB_ zfv(T>;XQ+U3EI+60gd!SH-0!xDHt!7pK)D2$6df$ zgn_Fo?u?_O2y+MQ#-_WO-I&gU_vvYxh3ScN&l0v=^3nz?9J#v+_opx-!8tplhvSim zrp}-1xsp#-Ls`r;PQv*6;4a5_#kpR$Sg68J_q$E?g<3im5c%=9&{e*J;zYf_ptLay z^*Hl2%edR+-!ww>j5B%iGXzbt&L`jFUSkR9I!N^I1>qL;eQbjW{+1)?byZ!ntryA% zzSk1P;upC+2iaG+?GCk>wl3@r9&Y5!wEEmK#=;Fp>rB}F?w)(pSrbf(EO+TfVspTA zwH@h;IZ`KDL`c`|rhd3;{wKCjE zO4=()SRcDpj#mMGxZ37>TcaWX^%?zzt8+|cpLvGM`z{x!tWO7B8peio>fK913jO4~ zoOjM)+)Omcc-)-Nn+QE>(#Fj#7tkb5kB*lFd_LZrx@5XY$%U{WrC%B22p4R%m}|qS zxPNUYx=oi%aV-2K!pbM#!U?Oy??{FCGpwQ>%~$+rS}kb(yy#PGXswuTBe>svXQ{Xp zhRV94r?YtG=A+b-`XTPk{I_Tjq}|a$NlCB#TQ@0w_!XxRxd!}m zwW`6tOFWyRc-RIQ@jOSlKB%fLkmk#u>QfaT-%nwyAVRWF*w2X4n?~u|{l76$`l{N= z`fpT}KEnU2h|)sOze}R@DgNK8QQ8;htDQsNPD8NR&&KBkU=8C?D~ zgc)zMFEJc5hT?nor=B0!setc}VennM()ex&_=-&hHuY|rtHsO+BQKjC?qi)9*WTZ|aq`2xO1|M8Cb2?cm zKQt#XsA@`$Vif$EYui?Fyc;TRAj-XD3#3*K}|Mn!fLLmfA5#F`Pe;z0Fmu zTQ|ZSTQt5QLP57SGj@BN`mhO(X-%(fOU~0iMiwrCtpFLjmk&hia>>O{C<0W*i!HuM zF}ZYsB*GtD7gLupGIp`2Dl2EP6vS_96@*c8l+ebY=5R{*9gwP$7g>WF->3m0N9*fF z_#50eqi=EUHlo#D+)H=sdm6Qao4Z{p(gusfSL*PWg&cPiQ zDD~|D4oPD_o=2|k0OG{`DEbk;eDW^h zG3fwqB+5EJWTBCIowb#Dox&e`LGyqZ*JlJJZ0V#5^|qm2*kJI2@dtj?j&5C`3!-jd zxk5|?P*0+WkUc}Ij5}*JDI6IwzbTv_f=a(Xk~+tqKT3gMaQDnqwoGu;`Bf1Uy9(C@6 zLj)jNpslWO&?>9UtgYyKPO8(&18oc#4*URXhfE@te|&#|8k6xT#^A;}Z&Vld(&`F) zegoh^kVch1QRQs|vm$~*>6Grt&Ebeu5r*FK<~a7KH$7wu4>uDXdeAw)RYr%dE_8Q% zOZ_hBVZWV;??47KCbnd9LB-yB&(y-Urd5_tu0vx|i*eH{g9xoFs)fne(9_VJI zMaSou7Aa=mBd)ccq|_LXt<_vLMT18_guYl!C^+{%2HJn8S<`iK({*l_7>R{+P42CC zqy8J^$5StK-jjBOykaMoUf-k8zHX%tr{nkwUT7>!^Yly8)JmmNO7RCf=_=J9j-w#~ zb6|~bZwu{<0s7rHAjNZ2Ty&@R$Y6>O{yVR$YjRzVQ|C8DUD)_2FGV|cD|%!|Cga)o-X^-A;)4#oKIErp9XTn;b%b;bBUU8RJs z;oz5{eqqSfyMweXQs~EU@cqT`*MB+ub;a8+Lcd zMX009VA+R@r4l}SSNVKY; z{uS<-(Jw67{J_?~X-tG52>|C2H z{kX$vCzsq1wYAQoBVkzUJBof~Sq zZm_~7Y1~f~#$exvJH#Ymw#y_RdJNd3m91g)!7Lh%NU^bbBwfmeMMA0bjaeI21${4^ zaz<}Z%huZF=#;bBEQn^IZ4xXK{rnuKpE!gPr7YmXzJ6i*EoJQs9{hN6Dcw3nrluJt za+YT8LRLOm3!I&*DKl{8`!PlY4KupimVnu^nsr9`~A%I2{-dFLX0eUR9PU%NJ z+4DQ2xs8ExwW9u9a#cnMn_7?BK$-j-4HS*)lvIJZ#@P1_Y7NPJ$#`ls=!;hl#chD$$_O57@`bZ$3dWckQ& znHdp7Q*~}tYcw3#MvXt3qmn}UJ<^I`Q^g|s4mhPBA@VBMPt&)CiFTQ?sx|sS{ls^u zWxUd@;QAJ8HB8nOs+t`!Wh_y*IG!`pm(C~e6wQS=4)~Jhi~V1Gv^YIE*oVTsXUj~m zNBtl6z63C;BKdm)83;#6L*UK+lgO0y`}KbLT>YoSVh~} zn|c7mF^V!t0TSa3BYgdHpt{rZ;K-k{uwatDC%t37E z==e7Q5($;=wLPW6;j8K8xB}S_Hk)<`R@~O144JF(Jc`cYazyT3uoS3>SB!G)auy|) zKix~c@~b28J$oZ5sM%JL1w3Q_$WCUpNSA|FQ;&8g{YOs;bH_GD_Uv|f-KqsvG@aTb z+94SUL`fmb>hPwdkmV#x7|=hm5r)-a)6GgnMfWKcwQUdgR!L8zP}~UUn;nFshG^gb zJ_&v=fudo~(pT_+{DL|K{z+Hw0%8F~SPr_s7)=?(h>t&}-Vf_N3O@TdzWt*E*FsG3)C7(qz3+YoQJKSv?I;O zl8OH5S22fkIiRd3Nduv*n24agcls5~ulHh~ded{%X2_XeO)ekmA(kSe9#kJ3BQlvS zZpvj~SiHaQ7t_W&G!F^~XQjp-(Zrt;vZD)1K=Y7OJ4&g4h2P|r0aoFzPi@UNEIPRt z4EjoYx_(ooRR_}h@Jm0v{$_ej_x(D(o!I$p;_p5Kf3@{Y{N?^Cy;u=NWsp`ZM&LR+ zV}bPgb%;l?NsYvj;^-R^G3OJKrEFl3Ddvu#lna_8DNY2;ebbDLTm#1pv9TZc9U`tS zV?L8>PgA}6`2_q3CO2Sc$1P8oaziawi}o@%o@FOnBbSxgNhJ!=22w8mb+Z)JjmX(u zOB?(vi!ZaI86L;0hAwftxOYWd&(gXMqx4CR-({W;}N_zmT!{b#=V zKV+_AS^tr${MfH&R1QP^SO1Rk5B2F6=Zc=mZSm3-=&Ti_b*u-iWUjwS4t;%z*?Pd# z;1on_3XB19-+WVyH}Z_i!2K7kccB9kceRxAu%kW5#S1$rnBrfI+qL(iNo-LJ*{2BTaf7cZ+b0-dVfYI@L4YS{Lc*e^iIP`@%8I<_Jr74{fT<8k}cvEs+2GTmOn zASc`*kEuS)bI3&qukC}!s4`^{Hre}PAV!q|XSFE?hen@3xA)924(we<{x?w+mt$nO zz$`4t>-6@;O(j)aumLQ|BA|jr7Je&*=_ut?{wd_j)qhna75kG|ZeGS2g8+WQ11n6& zbFL|3;G6my70yZ8(&A)|XyWAoq=2?nPv`-Km}Qg3ka{_;dVa)n+X3I^O$|^7SrOQm z`~XrSf?b!VZ$*ojOKGresm1h_9{&Xu9~_d?$Rd%37ni;$RVD#rVLQ7R8^(-56^?8_ zq1lR%W{iUmdp6y60lFzjy9l}CHJZIh{3aco$q-i;0`**HA6N=3&1rG}egN|`N(M*~ zdjNebZCs{(-WE_=3Q)eUe)+Nbu3XFsEhtnindpbBp;QWU1&wKAL|S zJJm;yGU+Bo>hftslG}mV8=CP#=L6a5%9M=y^`#_6dP%8*Azz$!h@%$31C#%VQMSYF z?|!H|5{(Qj8%;r7?ZKy(;)FVW0aFFz?6tj5SwNrcq!oqcw!vywi>A&o?^iKVmry@S zqx%PKIW0xgxK%VWd~=iF^ix#&Xdf1D{>DI|@YP{Nr?&R#6b7l{T=asMc(07`*G3(|-23;3E1--HSPm?b{g2G5qY_G@14T zQj&v&fsfoBE;47rrobgx#I3glDp1IS0B`0a5K(C%Iouwc5oBd3P~Ucnq(+<}-&_ETc5& zj&HOB8V~pxhlyDwNGwvsL*V2E4xEA zT)U#*g{VN_qRD-PA}O^@7oTDji!u|HFRxrO=0|3eOE5EEoqNI0q_R_~W zs=LV{T*^$m*a|r!6-q=8ys4WcJho+FRIRsl*oh28>mBM>EsxRHFGFSL!~9y0P%VL} zfB2cg1Sry~{r4bh_mxh?W~tJ8tCXsPxN{J9vCO&R&Zf9aA?|q6pQrCIGiRHxxQn7v z-1t6>4YEm>HH$Y$iLMUO*?-W@^{Z|s9|ZPMWw&u=A_iab>X5Ox>@#llqcl73u{EOK z7(*O#*1fiJOomeqg_MJww05|8Ztq&c(C`k-lc-1%53N>LjLF1<5QXz&QC23Z+^Js` zFI>qwbDO+}`TJ{dBVb`IT9rD6fNr=9(Wo&{or>??CUXPn$;Kfb>r4WaZU7Adh$eMq zWS%J!S29lV$^1D-yD7(;kt0P+CAp)iH6I8#dn4(n(p~EC0_`J=6c^Rnx~ETx=Wz%W zk)Tg0?#lIP>HtE3$7#3-)dw0StPqnRyRo@bpkib*4rU8boIFR8@GAx?#+)@| z1aDg0tO%u(hNV*_IgEr}RD_@}q%`2F)K|l3E!L4upalqpc}HzyMdLaO@YrTBxD9S3 zxLNWcsh?g`3hI>t*-~Z1t)3T3#kW0=A0C52eg_1+wwZUUnqfLm@`l&j)d!VZ z+=h<@R83s^9v$@El^221B2W^oH_H~FCE=Vc0^(Lz$m}kc*^S}sez^-Q=yLKMjVicS zs=zYJ>8>hLFqsifRUM(~!IITxtSV4zScz|OdkAx*%bTl^EU7|YhVdhjSV{ES(|KQi zmZ+v`(U$gYpzetn2{)FiDpp)dS>rQCey;wMR6wfqna=x=GP5INIPZgA3c{V^yZn49 zeS;MW7q4vcwH*o<)95?`GT^?Cv`EF`RkNRWj$LZ-?{IT>M_4GSP5aqwS{9oUx zOFvgnKkARAU#O>V@yF7S)6?%DL7cJtzWqeFFM)KA4QM(m&6UU$-Z^`|nm8+p zEQhVty=3;tBWW7PGlf&`)}$E1E1+) zBYjePVDFr{gjU{;;^_=Vi^6l5!RZSngVAEL{WJz?flf@ijm)l{_NYc1+A>Id#i1SW6WUQY z8_Q@lRj~e;l^ZZc>GBqeRjy`ndRLkmmhwjOFrtSCX3@b3CQQ~JtBkN_UsL=uB&~-GqbR4}7{C7Sioflc9{sn{_x*zAZ{)|SXCUKC`QZ43FF&p) zh#X2Z{dR+D>hulxiqr8-4;Fqypg+a)NhCCCz2ynv1k*zKah`R@_Jm%TcE{oXTB5& z@dDjR=J5|NQfU0>kXktPCu%u-Q66br=MmVu+nCy+L#&;oqxkcgF+MMF0Tq)2hrA)% zYAm&PJQTNOgcwP)(a+H;CpmG2nt~|2TTYQbXn6XeSZ$qXPV|dkX zGu=&+?xCNU?n6xXVV&+W$-Om(3Jl@Kk`~#;dvPQW`t(f*80G_ft{>b!+iS8#VB|rK zsFV8CC^a55yL2mcKe1+Oj`cR=W}cdNP?{9-B`b0Elemc?qN!q=`iv$?Bbe_owNAT) zv+#tspMsl-=FHaasE@mv*L(L78K@D>#8-sWp1;8S&E6?0jha8UAIJ_T;(W`&+v{5z zfdu{g+3+rFd&?HTe{3v%OPw`Y{B!W>TdwEQ-=Y`ZEI$}A4UH0-Niz-QwC#Xx6wXs2 zlnN$Ql;!9+ll)Bd(C)2+HD!e%BtLmIwBr+7HL`MFz=o0qWCS53H70-VdDJ&Ie^(2# z8fY1WW?k5D1USUDQ*1I*;*n3xJ5o5>gMDQaq8L>2bZS4-y5mD`R59_3oJqO3jxbbP zN_L$?SQ1ooyf?8%bfVvERhTFlzkEGjYD61DnsapwUTQ==1E!6R!Ap&B7^p$AG9SFu zh^=oSpK_y<6v!_dEt)cWv65tEjLIrviqA!Jh+~*kAt%W0>QA}hKD)6=FgQ`W%W6$wNkfJ4-}V zy`ozimY+jiq^Ew%AsP)Sn!5-RBA>vFb|ug|hoVSIC=PL|BN=Bo`p2As1=--*5flj5 z;E1?^yTX+A z+rO!zj8Z?8O2lg^{1!<3#QQL^p&jps@Aku$dJ0tT`;3!mvnTOgoCg z=294XKE<*fVYMzpKu-$DGz7GzfDwj(a0-Ys1e}D6P#a?isHK4O4FS6-04shf)6Mu? z0L37I>f<$3WhynCba5Ugb)_$ZyJM;DrIl2kx+e0@k`0S3U$kiZ2}zBImg!;&WfY}< zWztt;N%b1MqiQjH!T1F&B5)9{-`S=rj2bSrbZLW7?F&+a@~u~S=Q#_VF@A>@e9Y)WMp_YTnq7}261MideS8qs1iCbhZ55p&>^vFYKK7cFMf z>?YzAidSBsOvA+arLt+N9e)vXHc<*_I|4BM1|A%gpF=!J38d)gv-7*4gwIlGjx<(v ziKbBq%#r50iRfq)3nq;pa@CHvBI6?9YpDj8s{B%D!VOW;-za4VjaXL;2Xvk7Be;n5 z-ov*RB_SPEHti#lIJKS`D+=Jj*IEx`!&DsjgvJLbZK}u_1a9>;GgEud_P2Ku%f1=9 zLrs@)-Sr(wFMK%w_L5xf)0+V;(pm?QjXY;qnx8sLKKQlC?A4+Mu!RBNtePSlsY!GT6_9>gyGUopTZ+= z*j1V3Og&5bj1k@Ukiuh~mU1m4K_P?e+O~R7o-YVjDo~uU2w0$WpBxmnXDdT|CONtTu6;!%PXzEdps!(!oQ3sfbQS&akf6eLH z^@zggxIrs8vlWPiiVx!?9(>>+5js-erZq)Aiic(fees3vpUX?Z zd~F3Pn@18BSGLHJ;V}@MCMMuRD!Iq{`q^^Gjy7(+wz+&Krc!!)BU!@wI3ak&y;GDFvn*-0Xe_^@!6<$^ z4<97n7>IO`5-+4^Jx)?YxEQXD+6h>o3xP%uWV07|`X=GSV` zpIs&29y?tgoI6T@ofUk8OAO|l!|2W!{v92Wuu7ULFdZgP@m!RcP4+NLfPpd4Bm43C za?)>X#czouQvQbddxtKNREqP+LG&OfernN$F!h-~a6d3r^#eCsLBM1C`bz@gUovtf zkbVlX^R5-n=2^&^FFr-lCb7XhAz?Ab+4CJ2lKs#eORlh@8!7h!w9bm(b93#GFm9fgZL`e&S65mQan@h!|25 z+PktAqhjv0AV^%fj%9KbzFMrkpE=jp1@!cgDEVU_<7Q212Dr~j7i}6LIu)sw`rHO* z$tD>i?|1XqJU8h9pL8brMB_l1N<#T-s=!og^2iMKy8IBGr|AQ}kWjf|EH5^q4snzw zEJS}gCIItShG=51RrdgTi*H{(eM)W6S2U%jPNN_F9E3gLdlBZfJ&s1eQ)(-GdHE*P z7GY*l7P^lERmhGNYFV*lbO~a`9vpaMYei+n_+w&nYskVN6!$5r5ni-d1uM(yZk~^Tkf=b> z%VeXq5R~dqvN`$>Ngu6&r#ZGZM|~K_c2p1pzm(@d8bzWsDY_?CEskT?i>w`JAStxW z!4L-mr}D(%GG(Q0c0 zdP`&}k6%G7R*Rx9WOgtL6u-QWOc^bv zP0%ThLx(~i4mbW8T_y}?!i+ReAKD03LTlHis?(Qi5#){2bce%@eN`MoE&*T$F$0!f z9ol2j%8spJ@5%!QIHGFAsBAKRD6CzqYE6Yz4{XE`J@~&<^ajtr1c6X2Jj!1*CHpG6;}%M) zj4#-TSHeWCJwb)E4&d%9%xi01gl~_W{t$;dP}PvKw8%S&y;QZ04}gHGEw}#&^D6RH z^dros8kG8?yH;u&6ZlOLBZw6ipsd+SJ2^6FYpTl_F_sz@S9+zLLJNuUllaXd-;m@v zyOKu4oF!COkVvd(MO)aY?IZT^{fLe8XAyH@cNP9Pc4-lwF5IscuMMUWLIMybe$ORh z?r2}cZKjBKtB6w8EO3nRkfUAfoI#pSocgXGwiE3mXq^;}zP^my$4Vi8-wrF(EyQCJ zfkde|=DmnX6B$H954%}U%@qKfgP1j9fC@!km83IZ$3aP8|H!o5xEsa6XS)3CLDJ$+ zFd8jpU8ZYK9C02kw$cyHBO&yp5&fD?Tk5r0KT9Qnah#p(h2;uhwD=hUu6>FNM%ky5 zTZRMB{-5r^@877~$+K%4(I_AroLU@iwS(PlJ--kwC?!Vs!Ho)I>hHJ-!7@Y0UfW$$ zWN~2x_)t|R{8m(35F)fDUdrJ+>$w-!#1+@88mZ-9Jd-4#0Jy+!MqEnVq+47)DHX&Bsp5CSkczh<=wazz+YAmX zp&#eeLTau2Gh%%55ep&h=%pP}U~P}x9eWs61S(NC6y60i#fx!l8y`Zlh(c}4r$Q}~ zg$gj=zWxHDBvt(w3&xa(U0P4lHwtIX^DeMc;qa}Aju0n{C{R9nOSYb9U>GF9Sh7&r zBy$Or!CB&+!qNWs@iRgIOl#)hXSSuk4E}RmTM$gLUOFF&!arN5Fdluh4yH$ zI4aV3L;X6` z#&!s!yS!j1!>_!!zD1nZ7@is|skK*2)JO#`Gr*eWKwNr3@35)XUg{lvk*zyT_aGda zT1^lc?(~Vc!rGEaSU2?=V#HoC;TinI%CNr z*_e`cU_>*E1^$FPyb(=zc&zmZR9Nc+9QY!PaP;SJZ*)|oqJpv`9TkO9BB)y{pPjOXe2$0ng1|OaXwsZ2Dd=uwMeL7aNaJ!&!7E3e-GZ zWFjXF`^Jz%Ta9>0?)OLquuXas1yvHM_bJ5gewPAPBY-}kRcg3{57eHTeZ@yK-j^a4 zUNw=7VM~%{Y=^=`_n62;_vENVcd?e}-tW6dEy=wnv0_YEV#Vah#EN2+Z}uT-xp5x* zthkC{V!ddX=sp_f-WgXh_>lV`N>j+C`x$SpGDJjq@LEcNC{tY3;3M>!=oz^)(KAel zISsjZH-RVeo|4>K!B%>_djoL{6JT+FK&44VsS$7l0rUyo*LQYOW5LHhg+UYa?<9`J zdkqjs!4e?(QSxYnSITfrY-R~O$oA+ALg4n}sL1Vr>UDL9x+d=GIFdmblCRb9h<6{4 z@4vylKAEI&EG{BxCtf0+tL(k~r|;Zi?siwJxYZXS9@H-c%@b;N7l79`%E_`p>r97S zAfhai3-`VFV3i;RC`8VMBum2Tf)vBi5TF#3sjwuIWcOAUOT4>^f&_x%+#gbq0~mzt zLkNAE#X}g94`Iv#B21;iXd)YWn&f$~Lm5IKBIw8DXy{1{^dv4VZoE%dxIK`)8SU`o zcBp}vV!$>mv= zArf#AEnncoZ1PgF%h^eugk@MzgWq%Hyk+9m~8C?lPI@zRzY7}-U6&5BoVcEA4 zXCp z5V%a@=eg%Dm4R?EWleD{#DX|#LDmaZ162pAf!eOD2dbAS3P{0=h8R>q(SVBpQ48Rr zA_3I`#wrr*fGYtxlhmY&Mups-y6T^^=EfT!TZ|})sK=t)eltw%8RBncaeOZ zht~!2wG6M^`oIyGo?FJj)NoXF?V}f}EW}J;>JRM-yXqjgWqSRr)a%F8z+zB8&xikk29^j385-C=UjwV8CgmAj<7;3$sa@&y zgSLeR70|$H#5W`mce0%tSabMER6nE_s2{3|aT~lHTW7$x>ln^YZ#umn<>!8We!$Nv ze!j}j<@|hzpLg(cJ*O?^cPBru;b$U0_v8LnZ^sY#S;f!iCsN2`{Paws_X2)S;pbw0 z&gJJfOgU!)#i`}@ulXs%kMjF>{QQBRKk@TtexBgxFZ}fKGlc2X=ckpQ4f%QJIHF(9 z&k}x4;^(#e9LCRq{OrlkcKke>pH_Z;GnOcP%FoICd}|Dayu#0?__>&$9)8Z^=M;Ws z^7CVUzQ@n?{Ctt0Pw?|Tepc{v20!!gBwsD5!hzLldz7OhC357|=*{7;#iH*{bVEJ9 z-DIyul^;)xci@OY{!x8GVqLs^c&!Xb6!W1+{ba#L+{;0IM&gBuj|~zuILPBA{Qp?wej7c z-VL=M*P`nk+aU%0?rD9leI&To9PW>3cfe-hg#!y7akPH3_!;lS@K<0-8?(Vte!}Z& zAR5BtY!U6~d7wW%w`@7^9r^+CV`*Re{-#W!ozvkC?HFSr{da`bTJgK}t`@y3k2oTp z5AV~so~7ilx>*vJPX45){fA83P#i?NsXbTRfp5%d%#=)p;um%siV}SA&+1~`2l?R_ ztp|o!EORO&7FL2!9w;#%vk-4iWeFu}E{fzwLy?1*+Czr);u?gA5qPLT;)VPVgTF=+ zf+G$O+`Y@;J|Ha2W=YHP6CovM!HK3>9XQTaUfTm%R9+0_F^EH_HkMqXT+|LO`m~c2 z(+)fxi5MiOwU8Zz3&R)Qrx9Em@P1%9kYY&`~T|FS>-V-AVtZ@g4MzmWh3>pI&j& zTT<%C3~w=nhVQm*_;%EPXLjF|1dF94ZFLiHJh5+FRLQ8Mu5U)Xa#agQ`I}LW`r9I2 ziEH5SG)1r8^c_bEcpeIKBrc$sJmQJCA`HToH9;3-`ddoQDsSSg?ei+E1b!;EaDtEt+bx;cNgg;?O z(S>rMdQ5i5!3@5W=-wtU=&-CNd&QNE1biua7%nd``Eg~K6!I9`z0K$HZ52+vFYJ%C!kox8OjwIEuPU?+p z-X3#cA+iXZTS0%Uafc0qek~aKHragyCLL^HlnWc53va=YS=|?;qal z+Yn1hq+|X1$caC#_j*f0HaOP59eJQ1l;2k;9_x#Zsoi~qz4AZ48RBZZCJKqgEyt+E z=++zZ&a00aLVb^*wD=x}?*mc97vzrFgTCblLtN4HXjy(7Wc;!h$3d&sDmDXPg@@}S z5%Zzc7r}s7l5x@A!u9+ag$h&sqKoD+F#jdx2pJ~1kD_H?|1-7hh&j~4#8Lcc`23xi zX13d32`L}gw;n8F#N4-lsKbBE#J^IaA?!;ZxiTkTn8mXqCI1Y>I z_1WDR`~08X)}vr$Zsn|Snb;2T;fLkD)fLvAK!^{2RFcuk2ev1sOD+|)D4D=2n^&xd z0+e@bjBx5RZUwKsx7+8xCmzAuk2_oMF5e#_zD8{KhT2CZ6SrbwWx)ca=j@llEyd1~ zFy)!r$!$4~Dnb_*(bqN#q-W8(;b92NbGa7_RnxE!+nN#vRfCutW7HZ$wNu-|YugNu z*S!-1y%;FMoEH!ZskuLd$i!0EaX$*@g8itR270P;Z%D}{AbJryNt7%ZILF){p;clN z?Q3#dKWxYuyPD~7Ft3Gd(Zg8Fz4c4%z8rfPa&+M;9_vFgWl1AXY6v{Gm|C2VklcA~ zXZbjISbjGA0J|&<& z>cAafozm074Q6*VF`urC@YrfGp(6uIP)-B>pu}gU>iE%wKQNu1k`L+U6CaV(e?~kC zBb?BCcf5Oh-0UOyE@y7s>~C}Oor&@8&F&-AzYY#bXi^jLXx!|hL!9{~6`>sxnruTk zL)|}Pv>b*3a#KwBmWO}!H9Gx{M=tY5II67g;lom9u?(!lX*XB&z;yI+u66?p3AD1V z#1F5OAJe9FaA1EB{fDZGcx-)fi+H!S9cF*xutOpq^YwVGZHTaNysW0)I^N?r(f`QI zPQ*4CiFcF0Lv=PH*6eRo+&H1hdZgWTDgOC zmI~`Wyv6;vL372HS8G+)xA94yW`(s1Z@a@gy9HxCau!{2XQa2xW?SWeDP>hC*>Q_M+#TNcAQ4Uq)$~Bz5*Xe_4@AP6df?DYrm>t)SlBVc zfFjx926%0~hZ5&69s|+dUTA|{W_NGJ{!YT6tbKK@z>FCUMAQ8mv5Dkg$okOxZBkVY z_v|G?o>k@EdNZFV!^pD^(jIA}9qvQOSc5d?901d6FM9No?)P^^HvKo0xE$rqD3mif z3i<>Z4tD~kLlPprHi0kvL|`v^-yz>Um~7_n*XWz3W|QIh?86K+-g{#Xz4v;@t7r=N z1EmY_(c7ODhCaA=&l*B3pfNfeKh@y#pui_;ejJaD$zVrfDCG`8g_K}qP>hVYp&|}u z1hDY-5o{Hjgo5yk#8=49Q6APod?{b(w7HD?I>)89%-sT}NiWcloo8Sh0NOE8pCpZJ zo?)+lLztugCq)NjA3@!Nv;%tDpU&sBM*xgz3xlLZcd*6Vu^u$n+wmDxs<-0`ey-%_ z3;cYEpR4$}nxAX<`5Hgh@$(ozcky!%KR5AH#($dMzvt&ooYy>lUdYcKhWHwXwC86B zes0yN^m4_-DX(93clMDvEuyRK$a@X)B@o zVsaC#+jI3BQ8k20IU{N*p)NiXDjDaw@JljEnb505>X~8jclRt(ZR^8)*um)a4$c>x*;x57XLEhB~Z`XuxD4Bl+el<8sL(}pWSq$(j* z5qk6xHgh)}Yq}uG(w+CV}q4@mQx9`YQf(emtIeh%elA=SMp2&yT-O z&u^B>Z+~%J`6Z+NLrjG;)ep(H$MN9$!SeqV{P5H8Cix|~v1sESjov5OA&M$aV&68k zfYBNPqQ#w`;V8g%j21`1Ln85A`M_z#F*hI+z4n3h9?|KADth4${vLYU=LPV0^i##( zOBg{Ce{VeWd+6O9NbeDyUY4Rai+2fA&!uY%^|uCtVJh$Cf%HDr>9tVwMn3X;_`55R z-g=$hMhq~iycaz7yXf5-Aiq!dDfzh+y|3htNBq|E2GU!v)9awX?I9j>AtZnIho_M{O+Wnfng`OmQ>XVmCJBh%{GGptUi(0LXX*4F zSM+ZA=yZBU`NxVSbP<~0xEkfZMq|mM-0SHI5IvGHeWBP&KXk66xRIFg*z3v94iY z`!L9#{V%Vrp2OgN82l6JKbHJdNB*)uY#nOK!brai8Y23eSuo&1zpXj3QNJMnba9w| zqbq3nW5wP&(_{Z_-VkHq0_dlRQzwJdcYC5R1QEN&-l1A;^FPFzqZ@#;u!Kr{j4)11LRvx z{7p7@L5B368p5{@jz2>9+Msy&cYDpsW6WQ-uhAj{R~Ps+#F+mguK%F)rGCyZXJJfV z%Jm;S{c6HDF%Kb(>1julXc9cVgYdQV!`)}pZ%;JgHwVXu5q^1ayde3ygX3lW1;vwn z`^LnaQT|Y^bkVks_D4Rzx&w8jm-0UmOg?Hff_rwvHGUcU7m5o_eX1Ehdp+5F%p2SFYy$DzW~;vPu^EWpNp<*@zfTUtv(2sYvk-`0kkx*CtVA)|kCRw{&rWZb-{<4wYhzqP=mcQ3_0EhN{ zm1xYL@TuZ@Q-^HC$9iqg8q)^SC;n2!g~8K{+w0WtLa%Lfo%Hd0hnE_7H|Otx{;o6G zA~W9Yk^QN^Zq0a)XA$Yf3^NIT`8~;{V!Rm%0)9N%``gXL1Mr0+zfOFLXdB$0SOb+O ze?=2_?mT0E?)K0x$%rQAEDZcH9#s@3zVdfVXTV2_sygxLZ)(rbzC50#q}%=dxIp=q z5x#YBJb7f|gxQ4p`SW;aQD5v1j^9V=*9ON|5`J-Td^*ve5gczP{FvZ)jqv@0<0}c@ zIyinQ;ZK-d0%Q5h2){cxejnl22FKgE{K4@W>CcSdcv?yoV}jw~Z?JMbZEE9w{v}=$ zPn*VMfq1(Z>UU-W+7Hophzo<`#nw9gF-mV2^G$Kg`OjWoMmFG9f0RIcCE+Ip$6H9o zorB}K4jKl>mr?p#{oSlNe`>sq{fW);k@MM`5Bu7tx&65r&m!o2Ci))C(ca|mI4Lpy zeEF{lL&kNOFYJnZk7DRcb!`S-y88vaG_{AoUCxOuR7TKXP!ricIbd2srwd?(B#jQnA~ z!i+@zf%A))Ygf~W^){>N#2;`7ou>ar)Gd+EYN`KMy zzbpSDJ^jTh{iA0}Z;+21FRe9;*obHSSsWCP{`)2~67n&m$9R8*$##+ajCklzU3kf# zJ81uh{F~L8p8bJM#rQh$niy0k-Y(kKiN}1h**OTvzfkNAZl7U)9q|8q2lbB`|258z z8l+>C57v{#?Ym6EGv)85`RqFNlYcR3Sg`)biZ^}tJt7_Xn%+6 z7~f()0QomBPjnK zI&Ut<1jomqixnq==f9HhyMyB+3BNWt{s=bciN(S3F@&EH9KV3*j|q-1Bz*ti_!Ppo zt^-f`pV%8z{yd*o8`OS!EF|Ay9|ujJ&Gi>tKhp^x+&+625PxHW^T%WRVZr%ZO7y!1 z$1fs$li>K3l>hGF&13kg5Nj5qlg@&~Iu9y<*LCR~iwKlfaE z=kCU1^X4Al^`i^ zml+s9!<4^RaXz`Om3BwP+)~Qd?I9qB+@;@3(*k4{`N>8Q*~y}FD9zO`E!)i3C>!hO3khJe&YrG#shT^_AJs!qY1`1vcnK`NS$?4RJ<;=c zVyLRh&ZE0SJ=T80pU_lZ#NTL$pQ7SVKj@Ed)DQMAy4KadpE4Uiv;8Cg2|1XHV5DDz zx)QB-2Wy`ge+)M(znQ)#h5Cox<}S)ezmV{2gX0$vesOU8O2W^m15fI#JyHp^2gq&+})K%|kRFeoZ8LUpOJ@;^-Q;>2>_xXzXf5nb-)OdWruOJLM(t&cIfk)4EG2%s2FVXO zlo7s3aC|!9>yS75OJm*-S|2N^zWN8pSCT*K4qhL0EJB>v9+bWy7hue%LGeq8e%Iji z%Lv~jI6j^5wZZwf6Ml1Wyhixt!SNQtyMyERk^CnI$5#@5Sa3YAh4sg+um+dFs2}@C zzDxl7TaWAAvB>3GHlHD6~ZD-D8rk zPk&<^?t>2Z!b->h>d95LuL+{6_YTvgs=oDrQq?2(v#Q1zH*X`q0D71+dtmrUFHXmy z;5bC;2|uwa1i<Pmyv44^*;F}=IZ872u)K4p)PT!%^cS!mU zAN|RSez~H5l}`V;MQ0LV5dKrd*<2Z?AE0TX=wG1d?*m!p|H|J(KUJrnD)~?K@xQ#4 zl+Om-mrK$w(CL43-|yf*UA!pyPnY!5ee};(^fMLxQy|Oy-|>6sXN$9?u3-I#bS2wI zzm}Ra>bJh4KVPTc^Y_p%6eA`7g_3@ukN#vu|1qbOPfMNtd-whh`IL(H!&pB{CH+z# z{j(MQ!HWK)Aj|S8_#O1+){X^@kOl6Id5|Pxr5hATFbA_F;}`3MTN(+Q_0eN{O(wok zC4N>V{(Pb0>La|M2lJT)zkQ^kalU!*XS5d@+aK7AH7?1|2Rc8mNq*!p2OmE&KFkK! zH-2H8ZNpu9qFarLkdOZ7=ZBVtZm-K;(taAVF$a;r{Qb5mjBn3{QR!SGQy+pMjN!drzIlySDpbBg=tiVPGHwhkxqLUGt5Is9K zvp8>`t7%mNTX__|w!K)e37P=^(Luk9O3_op#LW~|5ha!Y67}eDxO54uP}dXE`|a7W z7@<{0+_T=jo{l~1rw`{NQ{v=6`75|FdQ-d=t0y;H?g zaFl|P3cj;I;^!$CuV6g|*WDrUGZl=>@JH2mnSx^#?4e*|1wVI7 zIFo%DoEZ&xr>K|TJ-fcQt$chr+IT~h_?Dfqog z-$%i#70g!fW(6Ntuu{Q23ceGNUsDy{QbA3@3l-G)957w-(^kP#rSf~Df^!s1Qm~DJ zho;GRWeTP%*iXUZQzd?jf)6M-Ucqh(?kbV-o>4Ge@!?S4Efw78lJN=^Y^&ho#WMU7 z1;Z5FRV2gn6}&{jdJ68CBJryyOSnXR=PP)Hf^8K%UMS;jQt&|q9STM(7^dKt1u}lM zf)6TqgMujv?#h?(<|{Z#!Dt0v%ai!q6dbSMkz5%*Q^8IO-k{{A*Kg`f8LzE^pU;rr z&nS3j{$ioTxjJry$jIt8~XxL3jN6^xuO3<2f>$Vbor3ua z&QWlYg3A@jglSGX-}jxKY8!l^m~E@Hz#P6da^rHw9ZM*g(M_mHhW9xK+V* z3a(J_K?Ub2=u*(G;CKa_DR_9Xrs-~$TYqM%E`u?og3*iFF}3R)E0rSwRs zdflYtrMJsz8ONCgkhlJPFORl-vBeTjmnZj#~E3NBIbCIzz;yh6d23Ld^urrW6C z(+bX0Fk8V`1zRY1WTxUn@v}<76^ftX>RXp{5^Z=f{2dO!JCwcLsNywK-!lW!Szf4D zhUry{4v$pV=kLh3%}~#(pTg_kOP`kUERhyTM@J~=SflU(;lKCORZyViW#muSa`Lls zGM&ZR#DXHt<;*VzD43{a<>cW#ryySg;V~9VNJ43$Gt=eF(ux6dCgwP^(h#Ru8!%9d z&dGOO(kIQ;rP!G}F|9b~24`{iK?I_^Xs!Znf>U!9l{mX;6G~j#qym>VF(bFw*+bK! z2N#s&W@)bMoMO#cRD_K4oW;c%lbn}(hbaK6c*-AADNqznUg<>WRzl= zZW(1Mu@qWzElx|CWw0e1Fw-*G5^X8MTQP9C7MG<5Vz>|zZRvvlaBp6}v2lrGhbN_E zbR0-qU{9J~{M8Q^4u zQ{%hd=eLxm5i6?y)5l-JT3Wu*Ei z%^KXbtLBOiSerq1=u&D0x1VKSP_E zk(-l6&k`p$x*}&*Nu~=;4Ncx+@qMLDa=OxR%VZt`@#kxQ+zQ$|}d3s~h{^zBGYu_We8{VvSN%7TH@@(Qw?TJ+yif7za=WMmD`f}*POrcrs* zid{w2-3%TbA4hVbzdy&mCkEK-OjB>r6SjH+Tk9gW`nuJ=Isx^DkX-ynF{tO$EbS}< z@f>a$jot?Lo2ozR9|s4-HTRN%l*l^%wdMP(^G`)PD6mY#f5nij3sN0~e=#hV^c$u9 z|10aC7@2ID2HnZ93|{_65&z^9o6`3LH? zv1C`H?3q6QN4AGFw17BxS3@i#;p>rvB%O)q^>ryxP5zbTON4$&z5+`J|CaeZJ^Zia z_sS8YENN-z8C)4~h|%E{-l_4B=H*%`%JUP-Pi2SW~wEG9aOobc(2 zEy*b(;g9Gpo^l9LwyR5WCuiiN|3f|b{lYv9U$luu1$nx2h`y1#N!V=lhkbqc3k_4B z7Wr){mEQ=)Mk9+|VfEplroJt-t^7qN1nNZ-$)~IL4 z5Hw4J%L9$j$Q95%)#8#u8lgI~x@piOijbF4jL~UPMiz$I&>DJb5%ljDjeQJp3`T^YH6@dYK23eR(({T4Yh3nt}b`T|oEaz2Dda*6T%G{ha9KRLf( zTD~?xsfL213gVqpF^@6|!$ZXSTgo@HS4>Y9uWA+~Qph9AIT1gSkflpRZwC56H;d)M zqhXTE|CgV$r%`@M88=KPx%>^=;~Dj3WMM{8v6DuQ7Ryzm(IduMc%ClMUr_OQ1chIm zvB+;4VIm!p^~M2fOv9*f8sXqQ8+I3v=4WJl^309#1{W8}tQDTVHXg_z3*FQX-O%{nTk>K=wj!Um@tFh4<9!C%4>$FT^=`h zSi^~eKx6_fC85r$Lvqp^ltMdwa`!~`h z9{$ey?=LHxOpHM+7H2L`T#tlGSu%?)mYm{|Ig?;UNzyz~=yM}@l5mQyer!k)-zTqj zImOf-C|$JcFP?Bmx%B!~`l|WM&gH0cg!z9w|DLJ$Uw)oseM}6H|Nojk=Aghv{iE^+ z+GC%-{e9-2gX%*L^ubl2G+23bSdGdnaIt?^EWK-glf+c@1ZPpP?)hL*vZXPRl9wXEWJ3J>5u>y#;u4b*;x%-&?Xg*ec2BA8Q~v!^0UM}z2UWlpDz%5ow>w1bqT&(| z=RLC9_v>}hC8uoQl>gtVKY3@+baP2HODFuNp?-ghO#XfL@n6+Pb1C%(op%P!HOKVJ<8&Q2 z$0VBnsq&!BsQ=0Q<>X`3g&mMLWbojNHk>(Wf%QLCo_Tq(Gz|Dp<}Z!-`(J7AR3Eej z`G2K7qr7QUUbVq^D&`q!(jqUXIFGg|^rH=ym`2p5VXr|Z){Lk*@z_Zv=vXW_-jfTM zgWt3-DlBs1sp*8UEgAdMkr3?>i{SmVj>Xs=06%5&)U?c!EM5njgmD|biR`qDqAZIG zE0%=UzuDgW5r4l3RK`*+V`+$vovA%C3*lx^Fwt~PFE-^vd(KE6tnW%Lo%LB)wG2(^ z=7Llk+fyAY=qZB>in6dxi4J9Cx@c8H-{-518OQ6Z!Q^C& zhZRd+L;U~pvsi{Xi}Ibhn!b|+yZm{4g}poZCAqnUE?%^Spz<;%JFmdn!07x51*Oic z5hYk9&U0R&HlUzi^Btt2B>)V+iZb$X!X?_NpQn)Hq!PzC&(#UVJW7gJ*InDO{ zF7Ex?gp4x&3G%F)8{|Rz2wVjt<>q!J*Kz)HU1&4YDT?WDdd0;t)rC86__q1v^aY=; z{K36=d`jw_*>~KrDsEBn7qOK;Uglc1cHzD?H$HyhvHhX*E)MN}esxz@%9SOjT4cX; zD0+F*Y3)*HK6qr?t%qiXbqHTH{*-fQ?4cX&9gaIjAE{~-Ik|j&dS%4)>a)VG8?@jW zv2@9l_nRzzcv0!+vkTVE>+cP_@z|H{)!*Gw)3Er)55)4etLHWTe&5c_{Pb^|edLXF zw2FCsTFN_G+4Z98isf4xs; zo;B^%S;uEwo!IV;+utrb^6oiPzB>Prh>H03aXkyBmw0jpR)n_h((n6yVXtnPeAGT` z#?=SzIe!0p4Zh9%XzcvBm#4MMp1z@Z->aG}8<*H*!lzy5_O5^K$nPp1Ke^_TE{n^r z+I#H2E}yRaDdf)CZQE=sz4fEo#&a*9t+jkUtGMa3=f%XFw&g3|ygw=A+*^7Jhlo@ya7p0nQh!mST| zeAT3q#Ofc1|MTdvuX}a7qiX;9n1>Er9Xj^)JF}b4{^-UJ<~_1>=CzMaKXzY^JFL+m z%alj29bCHYUiZP1janag_1teW-|qYV{TZ1B-ScmFeb)`2w0wNb(2}Pged3|ZtgX8D zf9Jx_SH)cT#jxfn`xf{7Z0lv0pMQVzR@)m~`GB)|Mf}a<_kMeDvSrE(_pZ2bL96#~ zXx`|O5w6VZu57>liuujN`m>7e|M0WfcRD&;e$_wvomlAXKmUgJuYc*#p=N7dcTyJ-6(b3O@w?9-I~*GDyNQ2(at!MA73?cICT>c!t5Y@3&P zxXX7l?!GT|WT$JphTOMi$l~zsGln1TU9|kIagX(JF6(yUp)H%&wrO|n=FzX6z5dWc z&+I?CJ2Q1*i~Ha2yL@f?ws%yOXY_mj*+~t)eetQGO+L)*`t^*P2aQ?O>d~FUh7K?P z?9RTov~K=MZJt>9V%J^I^*-nR%+>CMONV{8@1SR`XUN?Zlc&62`@ug?G-!V8%iZTa zdBMzA_F;yS^E+Tifz#QJeHZBWx$9)Z5sq->k-cuA6$#=)pf;7J65Mq@LAz8Kd60ufx*k zKOdD<@2iu$Dlbjj_|U95gI-;C{r96I#>QXNvCre-OZsdc_Dskzdo?VAh=1?=Bf|(~PD~Hny@iUDagAls-@X(4l$m<083dWs6;HX5O{wr6ut< zds>|L#=zySp}%aIv^Mgk*D@*|jDGX}*|$CQbNqv`@y&uL2t9sBXt+rl3h;K}*A!O0Ky z#V(!NvHQHa&&72aZ2RKaxDVcW;k<;`b1uH_#s1MHCFg#vy(upE;Na|q8`oyE{G$EQ z_g|k`X4{^>U-7Cww&hKUU|hY z5AAI@dfBj#AA9$fdk1#DXw<#G4FB?raf5GtXxz4&9(uK3vt@bbHcQy>{K6jx#y#9U zsk+UE@QQDKxMWMOJ#Rg__=)K+uRid@^|d$FA9TShPh50^?Y1FL-@Nd6<9C1D^?ApC z-jn#xp^xQ7emS;pvzX`~hp!4=8YtC6AxA_w9ad-xzp{^|@BP)@e_i`t+Z7hu$te=(KoT*mZkX#t(1NVRpgc(mwa* zYtIyR8fo9u?%IwCEf>eM9kYF1>nk=qGIhuDOM6UhvTenRG3aQXTebJn_VrfePukQV z`KkMl47%)^{5NYN-)i<#=T)nB4BWM8&&V%-diJGner$2UAWx^B&p(;;!qxTCo}6>< zoHuVT?NM^aC9SXf@xg0Ou4(Z8g+0F9GwJiI-D|JNdE)T4iGw#UIs1_fTUTViv32au zp(mCseXnJwk=vfwdTjhnPhFV0;8aYHC*R-M|JthkuT-sklY4&V^kr`bY8A z(-KB5ZFt|>PS0I-jVCH?!3b|*kI?P~sqH&Ae&XTc%PcQmId@OXDovf z8S_o+9Z97RKiu`wjI<9sy;rp+?Qo~@XL}>wx%Hm*#Z52mu*;uER_yloUOm6;qkmlf z&(EIxaQU-8g+Dj>?yoXG?m4Z)k#}F|pETsyxar4k{cga+H+OsbyUw|fe>HLR1)I)! ze8**N_O4zu?bgcKYm@GjPn*W=m=IX#(2i@$^W|i{?&cl zzP(h@#a()6@y<-_66Yy>Y+CG$iMq~mT!?*+&Kz9hp;4RGZ6no}F+I;$`YNh3-8XWY zUG>!$)|2xJaN`0I#GgFRlbFss)_SB1v=(R3<5&KY4!R*X;{>8cdC2n!T9zQTT7` z(6vXO==fFr`}^4EU#&m7Z|7gNEC223;pxdMV1C1SVmjC4i%*+X=zxyCr&KGDq?;*1C;3gBRXM(Q=_(fBP8X%$K*U$nzyc|MYsI^YJ(wP*6hK*UwZx*n;Jc z^H+Pm|JU*ftT&&&wxfFc&+z_v5T{XX{VDpw&2(&#`;YSTdrG}kDpcg0gyE&Lh%T7H zdT&s@R<|y|3S>dcqv^gio|5;;lf{V>yIIkldvxZuU+l{2*H2zq)30AX_Qn6TE%-k} zzMZv<+#+X2)^yFhu>4!s1%2{XBmc|-tPDnDCnbwK<}VlVpU|Fv*Ldykvz}S^dT)-p z|4P~ltmbBDG2L-fq9%_d^w6Yw|GE85U2C-}KiG~!pNlXFnoXNPTsW$x-QX-LF!Vnh zagg>`YpCzuMw(2+-=Cnb+&@c?DA9laq9p&9`>i-AuB9s-p&75}F6nR$^{ZoW7R87h zFlvy(^Oc0SgRmeU8+zz=L8G#fk;%Xg0y>?jimC3y#JM5Nw&`~`63%C@+Tl7+%vF+8 z7(8}2Mh6_Fk_1B*|{=e4`dFGR55i}MhjSkh0jtMaBS@Hr^2n!y!x;B;LNj`H^7l{ol9W^k2Gf^$*`W`mfg8e~Wznuzvm9 z_V@p5_5c6Ncjo(;Lqh6>hT;Fvu*i^*kkEP|A#83fI3GpbC+=F4_mCReKc5wpv9Q06 zE}{On?y56zf+~GY|^Jn>|zjZ&Oe$iVouIF`O4>$hN(kT|0;XlCdb4@G2%@P^d60G4~@}i8Q9J)^; zL!;?C88DqMutNylWv7DrV@LqUM&8Cw*B};TW|rWx3f%7H%*-f3i=m6^Ko!$_(`jNH zIZzECu4xZ zhp>q@1}wz|`MBtA6!x*3i2XnI-UTquqH6e`JiAHLP(oVT(m=qa7FhbWVcUjKYN4C7 zDQ)RxY10xave_h?CXgm;Hh~raV?`_q8W9i?wIcFXMZRzm3ol^g1w}!O2wD*|A}U_M zh^XN8|2s4D>^{5QP11tM_y64F%$aBA%$fWB%*>jN8|q~>&+wTOsM}PxHn4KJ4roi4 zY|OW`%WWH6m?t5XEANxrw>EWkFk+L-8AUMO%kby`7dD9~{Z=-qx#~F4%?>_k5({*Y zz9PXpcjelpt2ZvITeJje_=LXjJ}g4);aE>Nd{)hhhRrf^C!xA^G}0SRZ!S5Jh*KTW}=xkJASBJiVUt)awOoo!7^C+}MR&0wQk%=;dl# zY}Tz^+Yqp_*HRaFh#(Mb>*+}(7oBiIS7+z8ZX=gn8EZe@1abGb%)Pvc>E^wL{?D3s z72NxaIPV%y=v&!M4G%nG=m(}{(l^CgSSjl0;)-m!!d>mMbR)Uwy1TjJk>&M3N4#@O zk6fRp%(@%N(M1X73b0+@$?Z>ID7X2DHBV&uCwIrlZ`*f`yU<~o;WoD15Bj*Zx9$UH znsf+kGI;1V<1YL&gLhfHMV7eRo7GP1-c6S>n~tPQ^n!Toyr>nGpXqvNacaK#|DXLa zxlgE<>nJ#x!RAMjZEuOZUB7Mko=NBIauIyg=IdmWzMuQDN%v1%JltTu2j6b+$E|tz zlji%Q_WMeU-*55A`Nm#dXz+Wi`MWIcwfGh5|F%iSf61Pm!0XzdA>1C7U7~0c!citc zURVBdaHZEWw&ja@MRk%4h>HA`KHV-SH@Xr%Z|JB`_vdc-II6zjZAyQN-^|num2)Yysr$j?(8~$S^ldJu^>P4iCEr}R7w>N69tDcC+!)KTIMqZqpoP~6gk1Q$ zz3zP8rZFL+CztzS^%ot4jGAgM6Vi45`}SAKx2%*f zK^%C}y?nTgW;oB)m)AXuF96-f&)fA|a{61YW07YZs061!7`mBMdiX zuxJyCQ7iX(*^#;3sM{wgf_lVjYgi|eXuK}f-A=lR?5R9Bb^M|dHwbeACF{D|w=neW zK$6#v{nFOTn~HDH>9E}8ZRVnrt3u=ILX_?mZ3N#;;XiTS_`)wiXJ;eXvq`#l6A?G- z%+#{gRGT&i;ZiuV%8WY(7MD6Lr~$ENDF$-4m$`P)tw+oaX@2D@cFWTin^d})w%G&% z^=M`X3@!fMMd8Gg#+j6y?JPDmu<&5 zf1i{W>uh^wN?T>Ry=H=5CvDT4qP6C{XZQB?ab&~d=f`!h)|e;yI8&Zp0iww^k)y(C zqLl1}*yxB-n@r2?*2#4LWGd3XbxLcY=hVb8A3Y5%C%=k9Bu62W_?Fn|07Ra^hMbmxK!Hn(g$lR>O z@tg4o8v1b)--{Pt=W8d6SRR06?$qqTbA)iqkRZ3 zGvbUH+MC<-lDN#{VyqgoD$ZCW?q2X!kAq@kR(f_T1+=77VY}_ z=Tx0@ou$i%3{I0t8;CcDT%_bGF<=6-GKst`O=-G#gzUU9bg-F^x$~lA2dRaO1Bk<2 z_H@R-yWOk`jr+m{S)U>7sREs97js7J-`}qFb|z!3@%3?emC6lxVD729Ov;M*Y!+YM z##i6>?L7VvlV9$-xV!#R`~zd3D9yAFUuNdr*X6Mve4nuo?#=C9S@(PL*biF!6+g`F zzR%k4&10{u|7}m@ac}Kk$zwlwsR{qiALVu*w(e&=wL4|hMd^Gf6i z-22$YO!4`n^_a)WiO+O z>sOw2k05W5ns+YZJyQPMaaU>}=_ZuR&)gIz-AY%Jyep?0SuN9V``a#e`pyGcti7ja z@wyWhKjf#MIF%;13*YULG0iI5mFvJb^SaC9e~KL&Lit0mD6nSbhBY@Y+s0x6bOq5p9igt67v66r!o*ShR?Kk>VmJP;9A+u}nnw zNj4&IR~bkpO5OVP>(*;u^YJ(PZA$I7^>o0ta{;?=m3C*sW2L~kb1dD~6P-8kKEVBx zAM)1Uv$?Cal3JY5FNSZ|x49-Uk&Q=@D;m2NE45C)9<2A5W$dZX5=i{b+wFSDV|G6g z0?-!&Q9b3o;iH9!Uk*Cix?redpH;Mksq?!m@VG~)#Mk4RV+lq}ao6MYc_?~dii5V) zL#g;S$)VBYrMA*ynOEK2Rw`%sz?3r?J-;ftHFZ-yM3iLu99=I*i-N}=ol5_xyx+;jhg-^`_Mn$}Jy)q? zD(8nn$;v3Di_P>X1>8OFX)A?ZyJm3T%{fllHE}6zyxc;d4J5| z%^As*t;96HyHO&EyG3kdN!_E9pj}|>XOC66jnw}w<6(J0J;|TxMQ!4O5{!}4uTQu0 zk~DY0Fe;W8r8_iz0$mT|>bSNO4x)3%%&Gh_g`IaybTN|*sb}?74_K|49QN(ynNo8g z5G%K?Nm^#{G5KYAN@FutS@xDscdSgY)!Ckljhg~yUQo}dLhg)@suen&mZO(g%M3y8 z+V8RUb5ub0)KOIrUe5dTHTnf^aaW@@F*1qry)-qOkoU-VZ>vMb`=&-jqiMZ_K|gQM z5{)(Mw%prc>Ly)UcbfJVb5@#!%A%(gZ;mo1k@WUbrfZpqL3XBwHKVyt4Q-^?W~HCg zA&3gFnJM`*iocn?o)GiPs>XaX64l+j+cFrs<7{Up&25t`EF!6gTmogKr82MqIX)R5 z^d`2*E5v~t!`;KF3EAx=42%zcdXQNCuWj87k$JYa4y2a>#-7}LAw!Fk7L%`?WJpMSiXY4-JEsX>I>LF(mO zDPtyH_I(X$U}ea{aF64-5(d_V=ek0#cGMU^*o}*fqpmSy63>?uu^gjfm)P>K$*IT; zMlUS^_uH-O@-4*x5j_2K113RYnb|nfuR!Yus|~4MymmQQBE(9K2ukX~$+y{l&ZIRr zl6nLekPS6q3C<>`Sk;|9`#hvzw8joo8CUvb1=~#~ zeRvtBB_+$+d2$b#S7e4|YQXTJ@vZib5c{3 zXp`jWKthW!hYfB^iSEozO_L1OV!3VFern$If2T`%G6?jF6HjITSgY{r7ev@+L#&&x+Txx%ywl_s<4{j zqx-@RBIuMm#)lN=3q%^vo~D#W!X08JbxWPzKa2D3R6>fsNSE6YR(f{GwkfZ+cgXUW zc9&)&jjvaP!R2p8$k5T$xlhZd1yejjDjhMm%(}6a)`nu0y5K#4MS!1cU&29|E|DOT zRAl93aX0-2cTu;CJVc$W*U4&-Ql^%h1)X?l@kL zmUS%CZ8rDTmrR7e_A1cQ+b!iiHLcZiG48mR6DdeaWyVC0AoUx5B16x6_j)zZ3L0nT zFAispe|^=J*{)){bYS4fR?S3@JOtK&zN;xW#!v;{l=_+ zhpc>C2-b4)S`V zhRXFPBe1R^xTR9AUm;G+wB8)4xL{m+`m8c>lxs`e6a~+Iwv8LduG-?96KXll-RL$J zZ@&1B@A7j-!){b{L{jnFkZGZ-141l}$5VIGxhr_B9kDGod6a1!(t^7!tnPJpr6qaw zb)R}6PQSk)%#i(+Mhm)lY^#i7Tf8~OC(VpNLad1Q>Y2FgaIMoT!tHv8OJ?fs z_(Vo5?V3tV+tr24U){O7L@3oo#ktXy@1xaBJzB}_CVl!PyZmUG@m$46I?Dk4qAs$w zNLe%SKUH2bOM*~P5aj*@?`2Tqlp5D9KLO{Pof#Vs~2- zse{ffh+21bu4l{>t?ca-A<(EKo_*uvjcWPRlD=6dW#~H>j_fo=i!+ypf%9R-xOVTY5dj(ZibsO)4*#3QD{%XEKKPeULgZPH1x z$ZQsO+zyFvGut<*96%o55E>aSkz3m)a~@AS#BQ`D4q0A0UapeWK3RD-Pord7DP0-g z1l^W8pFK=OF*)KJwLtYBojYQ4=~+eZfN$8SU>5Aw`wT{H}=v$c+LQ_ z_#MBXWBXh)xL_@pxgBgpY+aU*Lh>}-kYgtqnw6lTw;T0f)Jg1s_U0hj?_$|tIyIfkl27b#v2sLNmV|X*h0nEb-Gc-mi#h4-ZC08X z1eyfxMO=ZE9Qu!~vBU|CO1+*y;`+n-Y#VRWG6EL?K$eGXejGi@=s`$!+xcgr%FO!` z^9?Su=SJ01^BsKGF&7^%d!1|L)wuL8_%4dX#d88#NM7?b7gLDczxc}=u76K>(jmz9 z{|$_~&XnMf&3eL6naMX-F0%M1xy}3Yn9F*_ylSNcFPJa)#l~I1)5ux1E7vZYH!oUy zR(;)uXw8P`hWhm@*PcepYxn9{L-f*bm+M<=%yGX@b{1_22*B=0#Kgk=AT_o?>qIIKXL^iw)6s*Pn&fAJ5a#+HLEGP&9J-+6C;_ zRIR|*x(!v)$cFjRNb82`DF4>|+NuRncLo}bpB#@bm^c6A=)Cz+_vIMbuZb?K3eAhf z<}av5iiZags*~H1&fQvVgd69D2yZWY0K7oI0BxOSt(AXf{bHneyPC_{ha~C+dUiIK zG?dhjW0FLYsgv(E^c}1!&sFn*5GmZkck1{kX>O#z;+QRfCfj~PM?9WbiYU9@$lS#J z1TjQymTquEb|&(OH60!MP`LIphRwLU?0&Ov=$a3Gt+20vkHiB8hvnm6@CyzY{@JV> zDonWAEY7x9|Dy2h}T^^1Ro6a`8Xr-oL?<| z)HAse%W?jOe`&(=vc-cozlJ>iNgL=*rac3NYUo>j#oc9jeG&TSJ@&77^x{Z(xIZcK zq7MVY*ZcGO+mD+SkN6dLpm|9$^wRD&dhEBM&-Pm!2rrkn*MH&f{kb$^)`**jvFC0z zFG+@;>`=pAd-5e;7{sCQlrYQB8xHZu`}2k)@RdRpfiHJx0hydFh6-=v*4%`Mk$2+i$Vo35YL4@D6;NeDcs^obeOi7eeFy+R*ol z8zAx>(ij-+Cw{mT-tddR-d|`>p^`k5pOY22(f$GXe&Tc9$m$%n{;h*a#w0k^u z4}0`NUwn7&_F-2Ro$IH#X&?+UzUx=q9VwqVe^Ty4(73+{`h0N`8sRBvAU|)s#2@d^ z8;%feb_w4s8ZW)L-s7EL45ibN8iu=KM(>UFWgigFau~6$ZjHS2+`ZyuV%f!uv!f zo$E%&-(xqN#m^hBA$T`H%Uy-8^L?;{G2yY_i9VaBI1*kiZ*M*cfA7zw5wlltvln~r zob!@o=*9mLkA3lzx&0SjE^n{@!r%LIX~ZmW5AlZ|cRP4VGW0tB9{Zhn{1;v>Z?FHt z-}`fE#B2p__F~VSgI_`)N+bA?Cz^ZvZyDf=JB z|M2Gy3ol8AUToSt_Wk1wkMPbF9`Von^M)r#TR4n8cV>7=GW25eipM_i!||n$@Xi$; z@z4A7hG!dPE&+e;=)0gd|$ zqaROs7T&qSBmQ}R-tZ)F(~muOka$Tl^kTEyV?Ts`JmC@ExxypR;St`s!Xy59f8Ov6-of|>d+u2Al4R(`=3$Th2>S7aM|kH7kND^P zdBZb;oAA?q#hor*k_^4rG&ijQPU~0dMd0B(xpU{s^qW6b>N#NV z3rcp8Ny#<7K680Y-o;{z&=&9e+}4`lJH&A>ix4sRtLR^MMzE ztBx#C7r(VYJphD`FHqe;OJ#w21Wx!?Rw1YUAf<)88ARY(00sQrthehSR^ ztD~0v-BCM$OK8Ps_Qi4m#HY5D@jb<-Zk*v$2bKHOa^Tksd}`oipL!XH zEb^&-z<-KQeSNu4owm}au6u`16`bQ!3xJ)#BTYVab+b?XrWJl|KK0+gap(Dzuh*x3 z0GxM`Pc>icQ>S0zQ;+reRM~s5yVR$eKHyWwU+GiZKZIX^|HD3Y0q{J~(eG1#z22vO z^a;XsgHOE(sGEFh_UC*m^m(7U8F=z8`~?2zZo=}vKD7|o0dzm;Q#S$0hkWXHfdA`0 zb@?}a>N`N+BR(|(3_eOYzU@;>EcWiR?v7#Mg3+P~rl5dIrH0rhwM0tSJ=Kj04x1L1$d2N(eYigiO^1h}!l zuWkmu1l$JP2RsNo20RJK9`Nrz>r=OU*{7;lpkz;tgT>Uhhka&0nYe!LWLs;YO`xsp z;T++t7QM-t5VJi#YqqshD2H(yPSI`coD~YG0DKbgF>bf1e#fs)Ownv_L5t_Cn@obH zT`pOjCDw{PH6>>gD?_2?)-CL%@Ga}8l}p#x&uUm&BQ)1#n$?Thu1vFyVzxWeY_pgR zWSFhg{Ptv;Efu=KOf&8F;S96Y+AVuop75>KbbB+*BARYE(@fJnpJ7(7=|(UkiEVuA zHJv(=dGnVr>y~s(d`xS=UNhd!); zZ&QwNP5&Kb{a1J!vjt+-h?%FnRf*Z2^9t2Vi&At*net=btPxac*elsb8!WQlv$~b+ zYXvBj+J08Ij$4>MaE6#QCH1!&DpVcKd?zfEaFGi3^WAxVv^Cbl9RqkUTx7!7dH(WP zPi=eWhWHlVR_(#&K!1JWu;9=Fb?8iRQGqI=0v8-YyTAURL_t}BDs#W*^Ick?N^7y7 zSfD1lSlrQX9_)TM@hx=V$(C+C^}X?V%3wQ67vPjeUZuHIYKl`(=wUO)rIYG-x^ zAOEi`pM%k0U+b(=lMgFUhdrQ_xSu8Vb5RxmPU(R0XQ%TX%G>0_9d&rY5t?qb&}~F{ zyQb?mbQd~*FIB2s!o9=r>%qElyU-A?83k%aEqJPuVg_GD}--7-DqZhX=kgpfFThS!7+mUCr+a~86%nM5J``MyIA#Qbfk}y|Lt_bt+ zy*||dO*vz;Qu`edH&P~MVm{eXlL@E9{b5)yd=lju;I}|g=@8v;FyuIn|L1%=1%1*l z6VL5FXNA9JqH_tX@ZV7-^NS+oWsWKG%(x=sc`>Nw^5Vj>itL!d>$*&tt9=U$$G=A2UlNrDD zE4)>~qtP4(;Qp67&n|JU#azb^T3r@**m%mf#P5QRnLOQ5(=A=UOZRa!P@L^_T0X~VJ_BNZH_BH5>Y~f%ItiE5dHmOHM6F53 zVWIyO%I`G&(9_!A3MVA=QYUXDe=K&>LYL{yS2O(Yf zU6dz)p8-4s3rpu|v=r44)ti+3u@ffFlO>$BCft9+K3MpoUmXNEr6CiZ7AF8*!8D(m zwnJI~^fRPAc#CciblTYZU2wQh9lm3FVj48%j6IHoM!r&MqRkG~W`@TRnupGx9VQHY zcpbbP<-(BWL` z+~#DKw}{Z4g%Z_tje9lSLPy#_UCx@)WyV`hcZ$1@pnOcb8#eqm>2g~BR%z#s)NS~k zSU)cDz7Nf}fuhm`b)d3e6ls-dlz!nf_-`osfbzQgt=HWj`GD?d#}c|>vvHj|V$0Lf zXrw$5N9TC-6egLUOg_}F4lS6hq=rb|EW;|e2BiV;mn!0189_{21FS`sZi=z|BRyA;#DQHC!JavEDKIKb?Vyb!5P;&)A7^OhnD>wKg&_x z0x(`-Dn#80I|N&m>vn31uJa4AURZ~c0Nx8sE#38Jd>a&HSQPaN-A`!%wV^eFTA4uE z)Q{p6mrUA*6UR42M9^e9M{~5;Ur|(1^8BRAl5-~Yl+1IA{Y4f23g;*_t9Vh__QEOm zmxZTZ<14EuhD@8Q$&(y4iF}Yeco=@c@1p!b7?+m4s>@xUuFKwXJR)?jqWE8uyh=b< z*^feTsBkL%=SzB+aNo0E)k_NBuYd+6DTv4 zTJ9|J?=7k-ne0r8lpZ^&qGU%|xa24TU!i9DZ$`Cs=UG!Gu7G{;8kA1}v_~fG8=Vtr z|H_Y&xSBe!2WwpizJ}(zKv8Mg^LU-A2uvOl&M`j!IYl3m@rlk~Tc$=J559`>Pk?m# zt4^n|^9*s+>GXgJ%ZxujcQndzfK%GYs8zS2S4fyWW6QA6twdR;=@Ntkx@t!T=O%y5 z_@qtf-idO7rVHUONq&sqnW<(tF~)j9=>r$&G`|j#4^S6^pF#6KKuKx8B<=9`Q4}ph zWdP;o0nDZ#vE}^9e^t>9#n(C~J0;inNLAfe+4v6QeBlcye*_MBS<bkF-w~PtluZd1CJ4>5&bRzCof-Z! z3Y|qo?{G{W>v7tD>2jLDS#Ueby8-%dT#;UnJ1c1voD?lTM%t#+b=xH4&*vbiC2xZ-p!qvMdx@WwBR@uo4C`EIb;=9< zYm1~Qw{6>7Qf*t{kA(Mllv9D?(!uZHY*-Za1WEvfNFU1xY_iMKy?cyn%MMRfK zyS>5Vw#oSWw7C5_%5MPLGPi6p#Lp1aOYG;OEC3h_n)E)_xl8zX z<~6H@ZX?RuHQk6AGY&g4fiBQ(6>|!qx9M_`(7z9*A7C!xCJFt;UHO|bKZU@+o z(KMcEzlo+BSntXjf_C=-d#I32u5AyAwP^P!UWF6Af=v}fLNgRIf_*at-?tI(Mcief8=3i^m z$&)@A>591?yFiy3i+RQ*D_*hf@fm1hfE|+zKJHWJh(d<$aDqj<$WQOQ`XWe!??d?@ z;FOL~m2#C&d?%z2==zW5X<$<6fK=rnQBswyH#gidUC3O}aziGR!a8?PePLK06Z3EV|)S+Mu&3Ay=r6Jwih>{ScUz7n+hC~@bxw-5o=t81I zP)>0YsMY=QJs{t{A^9GW?=qULQ~V*+YH7h<@QYfVK&k9US;(#%!>-vyvx{$YN+NW< zt;N@uHBD)o+EHBSo8C5~?eLD`TPoU)Y72A}-#By8u{BPxZT6gC#j)$%c$}i+5%>%F zDrtQ@nl(Ua>5#OYO8TJ@beqa(`7a5f+#Er9EP-;AQr%94|KXxj9DiY9@v;67l+5wp zKWUEtS5E01|7K|%O~d~(DP5isPT`)*>xAwzDF3bLcA0t6EqXqn>25dt28HgMDBsa^L&THt+#}04 zYO?nKFkNS#7rNi0{8`gQOgV4R{qAI0dwRS~&r!?%23ORhLqOq9_RCZ5_GlG zSlsDcmC;W>inHL8D8qoibZ{?DccIYyZewgOO zksY)NQ$#i}AscBErpR%eKG{#3Fr^HMlz!P(n=oY|+Fi1{HepH|+5y>Q!x&YS6hM`QyI0=$ShvJqF%CtGopf*F*OS^f(p-OSow#5?!6nqUV% ziSk*1xd+2d+$?r(Q8OJhW<8PiMauOcBn!Wd@+9ywz`7@XnD{=!w;dQHWCPGdw=uvV zpPD(v-@|6x41nL{i9R*4pjtWPq41rd7{nZnQV&cm9inBg91$fTjepq)N*Lvmgy{NF z9veXU)4*@|96~o+sZY@j%qrSjoG7`zcyZC=lQx&mo>WovbEoXWqA&TTv`n2drJ^Y2 zxZ`CR-%+{6@2kb{PonGw{H2W){K`HQ;@>YPLDD>Y(`5xftdP5Cv(BlYM zr*r#_5v7^=MN+3UzEf-DgA5BL9lw2LB5D@~qc}a27ln zWeLETg{b1!Mb3Q57d=NkOxNeLu_lWs|7b1*ic1Hm#nsZUog+QmWf2q#b0agfh@2zhm zLh~xhKLOfzA`GvH9`BfSokLS|p9%-sV^BhXzqDWas`+ z)gwMCmDGi%K=1yh7s1CbRwJ+|Co%Ftd|1b>S1OMrNqG;3t%2Y>Zk+qAoWU!iGt z4@NT+C@xJkxs?%IRVGl#r@Nepe?9%!qN3?~0RujI99e7pTY=Ny87Qrq$AApRhC~?= zrHqD)8rMR@HH&%LO@zVDXQaGeiW9wl`w=wmdfUg*J50z~Z}ZX)xU~0qXdg1P38B?- zJ5ukR3`siawvqJFYw>1(<7LQ$e@F58b=v$&mw}s{x$I|^A0z7!leK%A0_(oqwz0;2 z70wpcpsWNM0QN^{qfDD~2H%7`M3xidN@tJ%9Okyq&+^3uY#MQw{arX1`)wL1KKZD06_ zWqsjapIR|xu5Wo_#pYF;S8rb17HM05ZhhOPa~saxyxp1d?Ae>9Jbun?zQ(quIgJ(P zyko)?S^c9d=>4|Ic75l_0(YLQwP#uW=cgQjUhp?V0|TsrVPYow(^sFfq4lu>3Y zLnxafqD!FM+>a709YAgPj5B=*JlXdd-?x1u=&H+pEe=G9h>}3LM3l!wS>uShwMf+J zew4}qQHD@%9YOiNR}>L}nab@ez-G7D-%K3e#otM2(9hDkDe|5UYSWt14Gr<~=}nmXU%vnx6duPEB(uTX{lnVV;|9eZwY(VU9P zicrP;a;K`Tx^3Z{>WZq0`Ijy#ESi6!v-k|>)N{j|Yx|ZhT0Uo4MQug6;?#=8*R6V| zfAzU*=S1eLtyo>Ls$%`-`bC@O)K{#pSXH|C3-c>f)#3)!oty7fRhy;WPqusA(mw1e zrhSlc`{&WnQ{5xucA39FYub-Nv6r#JH_^~BeSrbOB08q>xw7|ogl;#U$Laz227iyH zV4}2FA?ceUq9jmocd|qKzDjl{D4V*iq|NlKRRwVtoQHA}!1~Z5y3L&LR73B!xza|} zWBoX0dam^jKH%M6>YIEO7(D#vuNx8v!*>&C3ZYC z5zSP9xwj03WPpLQovy3$fb2g_*8D}r;9cUUUhgn|uE1%^&olVI@jS1eS?)#frCj#i znoQiZ+=Bhq`*4_-Ph%hMS+fr(lkJ&>6p}7 zk~*kRAx6Ra51MkIs8p6Vs`Vm=Y$niao#J!ORx5q1ri7^l@R%**jq=ob^fH|3amyMs z4FKaL&sH3=2D40g^rXQy$Yksvd=Hu{w40FLNumEH=|^`)l$BlMhQkPFN z<{HGg-VgsK8cIZ~b9+X)9Txjvp_u*1FvCGR7AV*KubPDYM3lDzC8aXStTbH~{njM` zlv_oiEAMo!^`GbzSCmMRGcp)Ak{&vpkJ0HIfrU=bv(U5v{?c8-ZxDreUFB3Wk|In?2&@Ybf8)ZUW!cZYKJj zPf}K;eb`~j*09k13gvg2F8m!$x5*hK4tk%Y)>8)cyq9`!_wvP`z`?jb9_0kUDc$v` zcK@8yk9*xd5H_J(A#`V;oT=#s%(!~7^Et;&@8dOHLg==myj#{$3|^ zpFz=bLZySmFT;M1*nbV>8vx@b6aQ+@*# zh#TT{n_%|gzpdl-jL$jC#7ozybp6;1S@5SQrhX_>e$s81l-a)tO<@`F2PoU7-d*C% zWsI%+DZ(v%N)YQ+q)Bidnv*nb_z9cu>m{wE-0m=Oj9?u+3ngmkOh36?mK8F}d@f1rt4b3b-!f~YFf)5=|&@?}hNLwjO=9|k_)88J1{LD-vKjTfKLD&bsiDJrGALY!3WuC6P z&tY|;gyjV^lc%!oRjNJj+cyN2iUv?G;IrI8de~5`p423vki4mo`S?^N^KofE7Q;@r zA9ZL*kjtD?XgQ$stmQQdO*={#Kv^|y?BARw=)5u=S75Ej&DWs$gr@Ct+au{1ym7u0 z>)?GTrtedfpA7pa#QvF-{jjus-m?BORyuDBrja**)?BZ?x!ff)N}*vi_k6AvaV9;Y?rVeXWDF$ zyILrDLA?_>;?%dyQ#{{>x9K;lro;Ob6vI0KZ|k?QuNV7wr0j>p-qe%BOxd{6l%I=) z?tLg91ZYc1ACi8R(=BoKl%bo1?n@|lXu3XAhUPfmCEpC+X{-Zs7O8t`Lg&%5hX2#z z?&m0G3>DVpK#$)~&|^=DYop8ywAq4`SzpTR#FSaTnAtgxxa*TN)8A(7yZSIY_T7l) zTmV@SsZPV9C}MU=0HsV%G{5do6kYAiEmtAcfSYr!_Jlub!k%*bk92 z9{L-7RiQd$3g0vv)2i-K(+c+J{4OYFyRX}y@HIN2f`k3);Hru0VC3crx=s%-7X85Y zC6$oF<6KVmkLE!2W4;F%&S5X~jRVG=drowQxVgsg6npuV0kUZ)KQCQb7M<9;cl~+k zhB$|r{gXSXuI!;!$h?4N`qU%I&N zcIR;y%oF}svck0oKi1tTB`x!}(V7(>xo^0HI*9pT9zTTBjN3EQN&IxZ4c@OM;gUNls4tqxLpM8U?k;!+Gl8A>^S7DEARNn9J7)t|12+Iy09OOI0CxhH z0?U9F;B24+xCyuj=m7@4sZ`)wq%WWzq1^?Bfd~srgFxuptoi_Dk16hKq)Y*g-{D+5 z5FWzqPSTbEYXTSs`aj4Sz$**X5N(V=+0_Lq0`vpwV+O+DK44J3r7QqppilIYS3nr( z0|tRzlw&E&;ZI}tB?Dn_A22B2w;Bk8`+z}Um_q>qf!hjHBd`l7`?7%uxE~k-hW8i< z+zlPD3y^Xw5V#LBU>8vKRRa-lKQIJ@z6U%3EgvuJ`&?;95j>t^osK^u^axv%rmf_W_OQ!{9++7ZBb= zzTfPq0boe*=h=7v0`~xX(NTT(II8h~95o08e&nc;pExS?8|;6NIcNHJDQbg{z1OL% zg-+ufI%RDIa2>D{=m&NK13(`j=UW?r2A~bt1#ANnK>tE`oD7ep;L{mUyb*F*Cu z{P{Fz;CEsE8TbRcZp8o3Vt0g6Cs7}}fd2-50vvvPq1p)a15W~zXBVo=fRg`Wj|2EI zuxL)9x)yjDIBG8G0dxbO1|9`o1wxgD>Kx!9;B6g;11v?K;42u)d_qQxF7f>u!=ES0=Noz02l!pPGmU+xCQtgP*sf?unpJ^ zd>{Bb5Im_+odMhlJO%s>ICf#7S_=#SPXd1ejyjq42Dl0M2Jj1@d=Y*FUj}{#9Q-!Y z0*C|O0$u{9FJ@^BXa&9j{02DaltOhP@DAXsz;i(1scbm_H9!=&6u24qHt+&4bxEPB z2F?V|2X+Au0?z@3Vd4kGfNOwn0IvYY)!+y45nvE_0eEXIJb<0RSAbsv)0P&hr9c<3 z8~6z@Wf^8bH}GlTJHX$88Fhtf0k9TG03QbK0EU2s(d5%3k@ zC17rZ@B+62Z#{!D47_DM%gewIfCU>!3*a%pS5H0xU))F@0|##+?|>w5CouiYLbV)7 z0-ppP0)7X~XdpjLp`j% zslLUa;ZgN%^_cpO8dBe7E$(~j3H7AA?~l~eN}el_hoCDH+gsF zgu+YaEm^W;R#aXuTodn!E#vBm@UpNxu_&M6=2&ZMe115*wC1#i&EcvQHKVyZ?Eu}a zIzV@kx|#zGDo$4%xUXvt-04Pn7b4695#jKGGl{M}K)0Kg%xa1y+nb~2k-+e>jZ0?n z%650Or#sf(lO*SpJuTrdWU-EpuI3mINn%F+YvTj=w70XJ*D9ikp6)O{!ggzWbC_6Y zwwScUTVuT)J<*mpuh_I$2puQ6RM?^PMR zSG}&hPtV|e`s>O&ni=ru_(MM4CQI$pxaDi4_NC0%EbU7^vlaqMt z<|U`Iyba2})D(x;zX{3d&A+_S7|U00L$Gfn<+UBFnaJtW0XGbzam%bx|r6p`*qG@kyNoQB+H-Nt45L26^2(N1RAWVD%^B_!+g)4LfVIYUA+YZho@ zXdXT8sB|@Twd0`L$jhf)om^dHTObPkM0WuxEO65lqil1FQtk1{jzqtm?} zanJ9AC+hi$Ow7-VZy#5@^U`I+dQ_^{AL##KgBdPLC4rY>r3s*Ihk;NcWuM zagvb=<8|;lB*=KYN~g_us7*2@d%6dVvAJP#Yg-FMIUXIz# zMwwF>+Z;0vk8O^Ty~j5H$lgPgXH1wuk$*JhA0kvJO9%JRXK84CSJ^N^;9uyYB?4;RUkdXtjb>ktQXSdWcEGd zbK|XPRe6d~elNy$ki|cL0UXVfakhK~@>HSxo{a4v zlYh<%I12l*=rg$Is=fJnk3pTmd5nsm-;?n>%%40@T5WY{ z@>I5zbaCkvDYmIn;T_b)I1u)l|bY?JDlKwH~#ksS| z=6*V{^e%LcoEm3}=5U#jqWiU*Kps$xTl6+UGI#iE8N$`RracEKq%G~q=2&-2G?t{I zu-TuCHFdM|Q0W7^u{_)>!ob2+?wbJ)2nsREM>%k!hZo9 zQ_-v{Yy+pAX@Qz?R^WZyxz(3as`=nwTKwgW_S|k3rt)G#M^~P1O*ahbe#LC=k$Y0V z#;62KYR;@OnXD^gR(fY_US=dlhE4XiW?L*-r+5DHR0q>rLN#h%X{hvC^?nMcomQ)s zN+Z=ilGo~)J8MU($1r?&X#DW zj8i(=FN}M7MFu=uW8LS;q-m7-5OcD0Z$xG8(i{J=4v)ah8{L^^bY5??=NbK@8R(BM z6=l56Sdsipjxj=)Dd?s)vW>C48Dn?HnU^EvG_L70qvcV0Z1Qtc<~Y|#s}efb;mX#v z@pwzV3G#Rth&WLvWhBR#L1eRXh5!rl(l*=7%A89)IS6yD_T(TNW8o)<3(8vVz1AGG zIhTWS(B)nkN)s|3PUdSuGo}r0{_Z=b&#>`UQ>l4Lq)BQzTb9LJdJ`S($iT-(olpJ? zCqbL4Y*8L%7A74>My%iHR_;yH#Ol1;-r3xS%p={`ne9DoOMAN!nCL;CF&^8Ra}ulD z7;(t~2o{h=Z#7Lyrr$9w&QZnyka)&@#F9ha$XWY7B;5&-Sx3TwYLm@t!cB7UGO`Mv(C}9*{aW?hc4KZ zHoNJjLiO+z-Z%D`-Gy1?P`?`Sm>qjFXMQmo_L!}}tPit6k69aLjhHoh(rPDWinPl2 z`v+l_-=xnNJj>$67PnY@k;OYLzSiPfExzC4$1Q%&;=fp2dZXbVw0N<_5sN!4{(!|d zTYQhjk6HY@#s9E)s-!EwN{jUe%GFmI^YblkvADtFu*GvMKFZ=L7N1YP%CE}e7q2zn zw_4n4@lh5JU1RL~Ev~irA6Famdo8}i;$;@^{iw0O(c(6X|8bQuzs}+Yi~n+kF`sF_ zfBX?kXK}s7Wfp&{-`MZ8c$vkH#rJ>M*k5GvLW@T}WXykI{d>^jTdnzx_PfvGvn{T) zxWwY;s5A1r*Wzm}?zVWX#d9pa)29Cu7QbxqRO@fG#oZQb<|r*T|1YxmT8jrPe%Rv2 zEq>bKUsyb1@n0-%u=rByqx?E7UTg8O7B93|6QPu4eXstUaUZtWZ?Q}Nw=DV*`#nSI zEx$$>KJPcUXLr#e)_NRw?SlnpCvBG{&xA?&x zLzl4lD2s1O8uMz4i!FYt+n9G+Ty3#q@e}Vf_VUKM{61yBJ1t&r@lh84C1L2kZt)Ek zudq01ak0g}=rZ&VTYQ7X7g`*#_^D1q_W_IREe=|I?^a`fg~jJsJknvzFS7Vpi!Ze4 zr6YyXK>sPfqbz>zUFQ21i!Zde*5c_Fzj&dc8?^XC7B^X3Yw@EtzxpkI?HS6|^t1B2 z-{QwC{;|a`TU_=5LpRgnDvMWGe740GT6~?ww_7}9@oz0ww!PVH`RhQS41Uu1v)AI6 zEOtI+%n!FXWbtx~H(T6k@ue1DXYp=}@3;7Ii+^qL-z+ZqwDIRCi$fN#wz$#aZi_Fq z_-c!HTRdp-;}$<>@n0;SV*CALEnZ~tT8kSk?y>j-7Von7PK&>8@lKmgb8fKdcfG-b z77ti_oy8xp_#%rtEv~nCxy4l$&$M`o#eR!NK5qPZ!s0=TcUyd&#g|yzV(|)#Ll#fB z*s=Jf>kQwgEWXv^8!Y~i#XS~pwm59@u@+CaxX9u^ea!HE*5W5De$e8ZY&xE6@n(xx zTYRd;b1XjG;v$RxWYd4x;>Rq$-{M;>zQN+17WY^jw)h;2OD%r!6DHhyEq>7A0gJD< z_%e%oEI!-fWfsq|c&f$9;-_pmdfw*i*KK;a_45A9jQ=5v%PfBCea3u`#rIxnzOT1< zo5fGP*O+g%IAn3D#SibW?kqn0J?6W@;=O(5`$mg9Ek4!aGK+uxZbLVpxrh9=+3#wL z|8a>if7s$rS$v7bEfz1gc)G>;-f3~0#S1N-YVpWLHXJrQcUgR^4bNKp z?WXf;=0e`zi&^$7Y^;!%RJ6Yj>Vpv%oVI zj}$CfQmt3Q&SJet)=!NOUF=b?IFj!s2(vgQ4mgvV!I%pyHmg7Rms}9*?hK!a!UWBO5YNHXui`Tbrjn~PJ87p;C zjLi9GwReB{y)NmrNy?Y0nWjAJvXxQx=Y^L_?tPa!b#f-9Y{(D_MCYY?@YvK6y+oEx zlqG*LUxWhB@>%JHY~rO@@{-9CBqtDo;S66mn?lP>^x{6L@At`VpEUM~-rUx_%lq0} z2#LB4oNga}?pWEjuvvq36I7NXT|yms^#_;I$Sv+mo{<5@&2Lj~OtNGWyqzTFzX*VQ&4}va~pKhf~H*p+v-h&YLLSHpHUuj*lj!%?~!xt&+rkLF~Za>m) zCis_hBaatL#-nT*>UJ|3&L7S?*^85OnoI4N6>HJ=Bcv{jdX3Gj$RO=|+mzyVupEBn zljBjxHD%PM z^!^LhaL-;2j-*OHnK8w^e7RkWEh?teUfO5wlhY!y`v_f1uPQB{+&Cu}t>`BGIk!cp zztEDWY2i+@eU{N_JVm>=RJK8oAxLy1sNZeGN9~AE@04uJu-{a##hw{oNLzL(aSPcU zyVYG=w%|TJn3R#%iWPL?v+-qKJ}TI}*S{ob#?X0)yu)oM8Q@s@{kwflLVToz#dNSK zky#ymeIpB!105y??yyaH<~Oqv#i(F^M+RFDb-B(1Ok$3y`4Q_uO~8wnjopPbmOv`4W^u@?R8RZ zvc|e3ltWvahgK?N&Bmo0P4P!ME-ck>#;7h3uSl_%8n8$$;U+;6KCd5oFwC8xOOk0} zfc(@ls5uQ<6SKTba|N6`N7@DvH2#R zE{!#}(eBlxE|v5KT%X8ryMYB^FH7n;EZr)htB9KOTgk2VmUy_HgCAYpwTxlIY@2h< z6tk~YTq3wxb-j;yM{9d0l4-e!7%85qi3kEMiGftpAcTZsP9bqlpta?YU>;6yIESqxh_6t6o|?e7N|kC2d_DElJuji9KCHYj{*9 zx?#G=b(J9@JU+LF&m^r!@xn}HKR78>ZsiTI$^zX9T=dOYz#_J}- zrx&i2mdn;($tBb8ZD7e_6Pw!8-hAWIrJv#`?+LH;9#Xmj1eFM$Wyf8yS0%qsnZh_ay zigEhfV1|wLU#bPp8>*$ePkamMmn-h6)Oos))LEKHY-r!o88cUxre=6jal)H=Xcq07 zuq&IdCQ1#!RVqRC{)N4TlSKcsr1|RhrtTP{Xs_LwRk|9nrbt9yL*c-6>gWfnw(0aU zS!-08Uauq>b~G;Gv`T9@D?l}MX8a@DPP$6MbpOk_ap_G9-A2rol+W_eP4solMrNf} zEm#-XcZ^96t|1*aW>99VhvanQp}_U(4A*se{HcRdXHm2{(W{5Q>-vkSLCCxgq`@X_4ASRng(iKsN%W?mxW z!?S8uq{E@96A3dGzlB)TE}5l^p`34^8Oj(~y4zWqcfnG2GpO_$OGj6%CA{CP(#K!8 z56?4&v8=^z!=#VOL8Y+Fj(wzR9WwS7yk zxmu4`K&)+7VqZ(%2Qjwm^-@W^dwHy&?DC2_s*o9bbF6nu8;5;TSJUdI^DJ|l$H$x{(P*7( z#XQ5LuT%m#F4rI}_s6qoY#MrruF0lZ+LhRzBB1rtoIaXO#YQWQ{XTS8XVb~XDc4jB zVRWr8Urbh0BM9qN}$j+SMBEj&*W8BAen0p=gS=q=#C~l(rNZwe!-#=$f8w^mT~5 zbFXl+y{A{b$C$O45}Yzt?=$9dGE7}=^u3+>D^eI;${EKDE^Cb0K3p~$bCKOwryG4! zH;t6Eb4v9PVlB@fU78hZSu;*zyJ0(f=3Io$!G(ky6Y2HD{vp;oySiG!?On+zJ*`a; zy?~+DCbsyNv`THxK_E@BI<;`g1???8Z8_}qeEpPs_BHMbiR^#TPNuqfWI#B|Ho6#y zy9s)^X({*1t-Iv*S+cCv)H|PW3NI=n`bm0LK+bS%}h!R#uO6fWs$-!G?ne;_sY7)vtE z?BTs1IffQ#x%MN7nf@^R8KdztKE-I0+v$|{%Ol-NO!Bt1V4BesX zIShH8O2Ve%U~(_4gx^yUR)iNW=_w8TS z_0C(Q_k)AvK9Q1&Z+f}qg-f`y$(Al{%`gIFre?QVU(P_44)tuL3)f?Ph&7 zx;~!hR+p#0H}rIKIxGFnw6un|S}@i#(96;`Ohdgk?@ZgQU{TiNrXg*!RBmgO4FQ%n znSBlwMkCv!%i<;?+vQ6KeRWsYdA$jxjusOePy({7y}5@;mg0TL$oAUpJ#oEK8r{J9 zgW~l_uUSpf4e{~TJZb5UN`0qk-m;Gw2XHmpn@G8ho$bwCE%B)A!Ma?}_L$cs>pF!% z_An)pwq)3CtY5y+b=A6$jqA^CtL*A>L0YvUKF*{Ls+j>n|FN%56E@vZ9klnGg`O+HO} zNoHsUyjN+ztOHqIFv}RbT${CQjPQn~rPCf=<9^9pmzOWC-LiOxc;Tkle9V_3?$_4U zqCLBc__#MXdW?DP6S`O;*=u99$n#CGB34Rb|Fe-mjzTkwzj%T<7Bj|Gi?f6U4Yd4(`Ig=QU_yZk}PU{>n8*&^}Qv= zAdsferTdWdw=U#LU6AIxGU*l+-ZYJDr$k2OZ7 z@ZD`q)-^ZF)ex*UeaV`bycw`x%jGsDURpKo*R#ODojSb8YQ9fnWff2D%SLO=O+CER zYHZiK7bvADD0jp<^%*5RTEIzWpGqyswp$wOO;V@B`_%=6``btTqer&(b>~8Og zvutK>_f*rQmg_PZ?MWtir$_fK(=$!YO|R;ROw%NVMU`ioZ4t9S6RI^`EmA3KT3T#z zP=8IEn)R~zO8x%-WA9AhqpHroe`Y3xgoI=WI|7-6pePY0By7cHaKR`pQ(UWJh9w9E zVVnRBaT{9O(o%7>YF+Eps-@yGwJx~SPSslKQm3}I{%ftHwTcCG-0E)L?{m&GlZ9xl z?fdz>OQR3p-#z!7bI*RxJ@?KGp{3g#*BWc%5uVvDspK0L#O(n0b5G0{Z{}T(Ua+>+ zCDvr~ZnGDizGhi0?`)3aDuv#7fE3213I`W_+0 zRc2=U%WH0kHXo_`0roe@oMWZNK&JVA8|PM)ca#38VTY<4kGs@mgu1XkSrt>qvFdiz zT7C?U)tvJ>UR~$$5}ACWLw3J-gH&DF4`UW^j+50*GUr^Tt4r6*c>JG_v zn_Fz<+aE)eRpw8jp{wxdtbnEt=g%2JEMdOudH(lJXL; z^Q=z+Wom0Pa}S<(%J@0_`KC@DmEWFuxTI*?idy(O{&PAQZYIz^4|BF=W8bWJMO#M^GT&O7FybzE!Nq61|MdpBI{EV#1U zvk-ozy$@H??iNrb5_d~mv~|(4CH#uCYbK<4Jnm9HUAfI~r>~%~8gtx%^Nv08m_v^^ zxNyhQ-KC|$mH13UNijn#Tin9K4GM)e$M+H1(&K9hTWL=#4ZrZ-@|6Yl^Js>htwpWN zRxBZROYEJ}H%!~=m@TYKYpkJbu5P-_u&LVrA=B zmhwyWeD|T9CHE(M>(Ul>d?XR_ZZ0KAG1$jntt~vqqSRaX+V!=DPej-!3T_z5pK4vX zb{X%5$-5qWM@kPLOh)ZI##k-QQ0iFUCgwVNW=Z}L7VWFGy>*fNsAi8qV@V6Yc}K?~ zBJPt~7Fl6vv+`_1GRGRYpj%qag6OinPixzoZofAn_l~(CdhZntdpGRuZbEzc;0p+s zc0b)ahNZ0wqMvr2qUO%2hmY=`n*Fc-PnCej<6`okQxq5J(}+9yL&hJQA9!NP?>jrE zT=VSSmrXQg&85zR`1M*@;vDo)qszOQ%r+Nq-eD5&D9fWddccBpEp(d}_wdg><6;Nk zll)7^%6lTRzWf!&*MH*AK1J;Q-0c(M(`MeoXJ6NCvul$c-ug_RJk(E_LYu4zbjM=q%Z&6^?uqPkLU7H@jMsmtI2$$@{LkFD#jpLgo+!JNWP8$IAMOm-oHhYr=M1|KxQ*U%v{~ zSA3uF^DOpd=f3=ZmXAF46^`-z6TV0K6Y9x+--Ax3U8D9Vx0k-Y0NeTQySy0@o>^lCHnHeAirbR$0hpk&u1nNnqBz7uYTr~mm3RnZ$2-=KEg91bZepP z^Vsiukf+4Q{tJ_0ejaVGpUNwAroOLtJRgh9Z$Ri+Bq72-Q}4PPN{TXFscW%-!xr}Y>8Xhrvr9^}jXe?$KN zv)TWQ`TmdOr&k^2mq5GJW-ovL@k;!Jeo)w-3Wf67{lA}nW0qh|<^4<)K7d{Mk0O(S z`wTMXT67d%kKKq)!J;Z-ZbJuQ6uTE4he_-~@dGp1qv$NG8_aiAiGTP-?9b>3`~miY z_=kVPcA|4|)lg&JLdW6y5yk}H;k%<@-EQpgqJ!`*>`HVRzC4h11C*WPG({O&kn5BNcLmGahn$2Nz6(;UUSW-(F4&%GA;VIY*i4$&8oq+eLPQ!OF zNsk$4%ydk22!2I%9Nwuq4NK}Y{qO|U5x7}(3ciDVN8%iBOfx1r4A*0~y-Z!e2UKTZ zCuY;IJ7tTB4#FRZ9@NjIU?*nd+=IMglE)~#QjTAv%;CeBgw4Qp6Sy{E<8aSC4Q+%n zfr~InLsZ;U3v;T4Gxj1Z@q}O+OO*7pd=o8U`C1ZlI96B5cY(u-dahf+_m;ybCSiqF zt4_i!CjNzdjgg7$ARahLbr7DSIs!MTPQd$Ar{Sxrb8ze=7f%2lsX7eL!=xM%@aL-2 z@D0~X4?ZyV-lzEBTTN#cD8B4q)tO{6LzMw8JJZqWK-Mr7vUsK$}D>LW0h2x(f_Lo0_(Q(ZaAnAtt$(Lo!3!}- zXA<78Is^T)H6B=xNjyPu!(?m;!=>$o3m$N|G2su$JG}J>m*zAaeWXi60ItAnpN03T zPQ!|$v^?R-sw40=%>N4U!_{Hxn(M~l(O)EOLm8i8<2?M}9)d4nQdT**>1gJ!Dt<2j zK8{H_WZxclKM)lvAk>I|HG zlJgUU=c`V@*Hq`=Q71b;VfZ9=E^%gI>wKr9@NL!R6vh_pUCJ{BPib*F0{@0d*c@CF zaXJnUTcB+KK7>h{GjPsAy)Jx6bq3B^r0(z`)fqTvvAV;DRA=D2R>m>Ls5o4ID)q?t zkbvheV|_9N4KG+u{mC2+e{dT0OMgkhZL3)4Fi&OR_S2bb@RNfdwvjK~&1&j-Ep3kZ zB?#-3OOFNwQ!=uln57TBM@G8vKFMI=&H1iPC%)>-S;H76X zuQTQ(;Zx^0orNRTI~{<{n6!Z~yj70TX;}MJr-Sf9)k*06nsfKV<1yQA;g3|O;5(|# zxyBrViJvgMR&@%#uG)N^cbYKk2VS5$37=D)gZ;nZ#sNPJVMID%1+*FF1>qdb<`M41WZcWajTh=13ccTTV@m`+p*jm^ zU*z1Qu=Y}84q|)@!Xq%7PI&WW&OHT3UCw$QKLL0JChaE;<5xI$;digpwgki9XFOy* zARL`^IsgY>!+ip4vjF_g58W7;fbU=Hw7Jfh1=l+rg+Koh<2?S;u;Rz8Ip`yPc;XGz zwX}1Xz(SOXa4RP1&%yJyFfY>Y5^y^vV{8uA-Q?07g!3@*AAwudJq>?(v-6XIhu`9K z7|u*N9fJLCWnD`Ae)!~1X}i>27S6bhdAEqU0R9lWi1<@**6rN$kj@ZXbqC{stZCuy zcM>Og3Bo6^@1V0Vv6XbAg(Y_pzvLahf^A@~$-&3&b~*!ZyvOMjJo;YJ##j`AzMrvP z$G;!GiAh*)%}oR3Kxi@=L837dp}!6a-J4tdbI2jC)1+@tVU zn1s#1qF*{cet0A%e!}p(s*~^)O#J8Iq=%fJAZ$||hYzUEz;`i8Gp|&c{V<6$1kb}H zo&>xT6F+JA7AEO8+whNB|L|+76Yw6)`iE~}*8i{Zk6Hh4IVOIh@Md*S!K#Pp3yj+V z_*K;j_=alp2%q-G#7`7Hp*jl>c+|Ov;jOCEu=X+M9)uUFPD1Z*oVy>!RL9}-s&jDO zik6D1}udC1WaKv zUkG=o&cQ)1IQIYysTO_}vwagjp*joq_#6G5c!F@9>NuSGBL1nD5PVy;dC8bVG4UUU zcVUvRG%VZh-2Lz{%;pPTsqRU*Q*{o`ewq6R?(f3z3QXcm!q?T^{GDrI_FC`?)k!#O z2ld4LW(c0Kllg`F);PTH6{pj1$7__A)E)fjb<%-0Z?NWkld=8ET;o6VsJES=dZ^OiY8oq`}ndjiFPAxaM6q9xmg?{64_iJz-CSywk zUZXk%|BA`nl!J?Su}8{33ZKS?P|jI6-Rr`J;P+G~;XA60&tvY!QnF_PU&ADwIapof zbQ-Si=P}K!PvUSBb|roi@Fq;soPx#09&;o8)DMrqLb!+FP1w1zeue}4qjC4c(^N;{ zW2!T7?-J)H2+vcUfUl^|!A+(3=_C#C716|*gGUT-It*`8oq|tblCLcEmO1T*Q&flG zLe){2#3cQ~jA|jT8C!SZ^;o;izwo$$PDfyIh0}g`yPxvGPa3ukb~*}&4s$vHug0W3 zBw_V%rvvc&nB6Ob`;GFL?UZ2%{u~qkX*hed$CSTKIm62_+5btxXEE`Uh3{fgM|>fy zsl}umf^Z=w?J5etH;y(#UrWNn_9UO^Fg$yb$LwVMjKiBJd(3F=ms7B*fjIFKf_Gy> zxDQIhhZ`v`)_WP~YjSDu!%@?mdjQVEq--PbLe)w5gz7B3Wjb+6n}O%dAPuag<8a|j z(u|J6W4=K9Cv6dU&@89JuxTHsL$G*Xj|o%$e)tE~S$O?^&OHSy_IG{94_B&=!r!RQ zz)nocjSpIw-yGmEGst@ee*a*nlW_VR%38)@n8j9*Hevh_r-gxLr-i3t(to4yPSt7H zI+y%VhEe!~BPc^DbGQwYdd$FaM|w;NItag}IsyNxItLFt%J~Vy-(j++k%gyxk@%Sl zBJdf^&aH6mJm(&V#Ya2sho@jNwnX6L>YjmPk8$n+_*G2qLldwQllaZC9&;Hcb5jy} zk8^(f@ElD1#Nlr+i9Z8h!kURE2lqK1je7{5uQ~yrRh@+goZ$R~;bp3m@E4dZD>(H_ z?z$m(o$3@UIZ@LCzl6#B7=bsdPQj9s)DPT%NtzRIySnG#kte${48uQS8Cjpt_n0qW z(!bL1s#82>gzS66vs#>v!^g08+%s@y#AElXL-2|P9uv&rAHIfVnAdV}|Aq7&*=vPg z#-vR~;pM7xaL+~VxLj<0sIu3t;iJuhwt?DfN2owM29FJ+lM2Fz%n9U=+8QZ|Uc?$kX z-LufdH7xuBCjLV(rtWcgkGiMfo9b@XGoEACKRi`+6rP7k8WQj()hYNMCjQM=nU}E; zIt(vRorKS+&cTMSIX@w|UUdR)I+uA?#u@l)OybGFF<*CMQ~=(9$(knxXMMxz5Ih@` zG{oUmn4~iacZv~b4%U3r`3b_SuuHkeNx~=8JqrhKaAQ#bUiU5931d+TzJbYJy7@Nq z7$(<>!e{09AG9mDaU=bQHAw=F{0{RpIsg}A)<3)zvwQN;d!EyNxG!ew7kW2Qe_YEC zAI4;E&A?7f;xXq_r&vbjT=-?o#t*MnorDjo&cG8Bq*M0c;g2!#pMq~=zh}*CF2D^l zxQAc|Chf{x=rOCXX2~yH@Lkf3dlWu^$^KOa7GLDz^usx-!|)rb6Yv2{@|c0;7ps3b zM|Bu}3ll#H_>j71pzjjr?uTt2z$vP@RTv zV%Gnq_{Xe&xK?!>-k~}TUs9ceM_*1@o})$J2iOL*xq^EZY$-Yg-^OIVFjo>kX5)uz zRmb5Ss?+dI%%mE89(16PWS*8!hZ%%_z`m%Itagul}McMWy~*e!kW#l%>?1`m|QCY&%z`P zad;CZb7BfE{;~5Dg-tgw?)g|t!<)CLANcf*E)7{|ZgTMmr(w3-;9Avjc!%mVd=r!F znw#;DS^sdY>NvbZbsD~jS^u|CR#;Tl0PveBx3)>ZKVlM377qJ~(*by#>Il3p5ccR3w^$El9MOH?P}?=YK3_}Ja754gY1z(en$@5nPy_$-z}XCaF*S4VIf zCf5qVr?A5rf3k4c&-A+RIMoq&iRvW$9cHf!JMLqg_i}&qbIJj;f?{Qgsx*he?{v!;H_EO+Q?y zItn+aPQcBoQ*fK=4BW0d2TLB&>k3r|VYBKmT&g+>H>gg)&8kyyo9YbQt~v)x9@Y4T zs)KNa>L|QXbqelOorCp{xv)VvS9KV!QXPdGR43pT)hYOh>I{5OwfPPAYnUxJ_$Ad5 z_$}25c$exld|7o4j(psO4Zs5)p>L}cxIsrGUPQh)eGjO}=94!5v#t$c|4#J|}t2>;iItW*&j>0QdC*eJ+)9`84 zS@@o6^CWeLNx22!EKJHa1W#5Sf!|P_fOn`)!xvTOp#KjpY#9C+ll#FG+@|gsxLtJ) zmi$rkB2*oOVb#K=s-y5S)k*lO>Kq*Vl#3?-k5nCoOI1hV{i-wYHPtyd`hPUMBpmyUyPpcc128+s!Ar1a@dN*&Itv&4#rcoI*HxQm+0($}9xVxL zvQ7u#d6>kRfUl^|!M&by?m>9G>InR%>IA$~bsBn}_t<#+@QbP=@D9~!_>$@zJp8ZJ zjXcMJU%?~|ad;Od`;ci^{esJ503LwZyu+C4ID86|a>&9H|K@ZAUZFY(@5dx;299`9 zY-RV^I(l1onjMliXSFsWMjG$v`z!oYToA4XKC;nbIz*SXgV!L^vg z6Nfiq5;g_@tU3$J|L)xVaHi@IT&6k-@5W@WG7Sgqa5@QJ!9uv_;FO*2S|RvV)d~27 z>MVR0n}C1w3gv@|4#Cx!#1n^?s(TXNqdE6RY&0l)hW15bq2nJNj;j^J!TI~bPyh|IszZTWPP52L*LN6!1<~p@GRAF zc$MlT{F&-Bd`7i-(_?;#N&0he;9D+EKipe&5T1icdgAaQOzI^AC%jEv($0f$Ehg@9 z_>k%hd<~PZIk?w9Tzv)Mk(juLVHA^gB77E;{oyQ}{*H??1h0ITy*T#ylQ8g}?$5wB z)p7U?CV9!iecyLF1lM9xFL8L=2b4QwR~laUA!~lt(n;9*k=yr(!bd;G56=QJ@D0qi zpH9|hn4~!bufpV7NqC^~+BAoS97MFP4b3k|)){J`? zK90%rhzy)w>imQTc+IJpgpI^AmuFs1C!J>Nva}vwMv233boHl5!W; z52vdR!8Mqaa{}%?&}%Mbk0=NqR-J)gu5j*AIH1yLKfFP83LfTn?qRqQlQbmYBdRm7 zY>;#J!?~)%@Q12XaEIy~oLJ@j2jOznQMg5Q3hq#ygA)g*c5zGbq?-6+PMee z64g<-L3ILte~j0@%a?@rU=mLn){gbsz8QoUs!l?0wR88wW=z_77;aFVfZJ4OU`v4h zMP4HCHPt!z)f(rXfDd4{|H6f}^ljpdLNm^lk1(NH7^rjZ!jx)ZXuS48m{BdvVYX~x zc6Y`W85?0{52uAcpWt*F-n*x2c*kC<;T;oI!+Yyh!#nm?4WGjvWlYGy;3QX8!l-It zO11DQO!A(EH%)eRnu0k@%0ZY5y0Q`m8(f@mxEZtc4qsPorr;(<*dSc0Itp*cLb#{l z5sj{_!tfc@S$JBLbC1F+Fj>1M;a1GH4fqZw@tCPz(}al*!3(fvbQ1nubq?M-&5f&R zc#73>d(EghUi)rO0IoX3)m;=GekgMdV`dnhjY-<#@Gi`j8~hNH zIl>%9nlaHKco!z`v83Uqxn9<4To>Mn**wDX!(I9K;SH)&Fo#L}fg{KxcAKnE;5kQn zO_IIVIGlWx*A%e_6@<@Xt8vf5@nIJ>2-m2N!(Xb-z>ic1zUVdgV^T+Ep4aS;Nx6k! ztLi9xN_7^Ve6;fufyW&~`H-&&JQq8YunBmZ>NFg3EbiR11mJft$yWjnJB~4e`U=4J zRGZ_y=2%SpMBweJ({ShsS|)If>Nq_7OBm%7hNDk(IskX7&cS<5a_(t(;mJ-X;fncA zN8vuFI30pjElvmEPSrVhZ^XH$;nx>%EOiRsz$DFPp_ljhNr$8do{dR)#^GILmQ7>I{5awK;=+gh~8Cc&zFO{JQD{{PkI^Be_S*z~Zy@y6_zAO!5+k zcdC0D-hB@J1ot%jZk)Aj5$`9!cQCn@Sc;;NT!-2Eh0kCT zHVbRM=G=qubk%Wq8zz3z@EvtG=dxDBB%Tm#!^BS<7Jq|2&pn7A-uF$`QoO&BhJVAv zPY&+2!L^4VJOvZ?2;8JP0q;|tfjcmp7x)pDVoczxyUq1jiuhA-^0#S=jMYJSi|Q0S zbR%sV_b~K*hdD^zCx_=_;y(fRJkR+F!araVXBMv5#Mq9XC@ekS`SHUmF>z1Ao8_4F zr{H6lq%#AbRrf6XSlvy6HjIg%06Yu}p~LW1Og>+bgPSgJIsqqL=yVW1iOIW&S-A8f z($BkCQTT_8-F;0Kj`$v93->!Acp4`2MHF78Itd?Doq@hfogY8kS9J)ssgA=tF=>-& zIPx;qWb~r|TzEP24LS<%$0RQqSa*ffL3jg};Q3byZn={5i-t#E<;K7W^nKrHKm4ic zG@N|3u2*3allX=Hq|?HQ?WM?H1=J01r_ehG(me!<$v7;M1zJa9~Qy0nSn#g3DD$;T5Wr@VA(xISX(6 ziPI^VQ7tsLYMk&eOvZ#TJWF*PK9AXTEIjw8y5@)XVG@5Drte^0pbUi{U{Vg|PR6CX zm;;LG`|x;7_5&m0hDq2syk2!0KCL4R9I|&>v-U^0pXv}?joJ1Oe}hT+WZ=iD%`;vz z6SMKd8!+*cf`3|x zJj9yh{tMpzqSI;ky_Xn+a8JVKms#`3Ivn1FN%^GU!G9-@yz>x-d+gwziuww|?_!eQ zBz%7-^(y`D71{vyp~M5<$1b9cnOA8`nC!Vl;6<3NBltWfY0kmA*Ib%|aDnP5%wShi z4nlv9xfU&4f{A++R=&>I$QbK~*S*1*MwzGJoOie`X%54cm>uWg*_g~naoF@O`K5h^ z;JKKjCjp<7V;N^)&3i6iLAV3kNq%#1{QIt6g78gjGkG^3FkgSVfOR%|u#N8wK}DYrEAdwu3RxC>9lilp4&o0ycj@%hXzUcN&SZ5PcS>q!}rU%7Vlm72l~u8Sb#Yw4xhusPYxbc z;j?SVFkFa98AjnZRVU!ps*~`Kl|Iu>y=37mze__1zEee5#(Xo_XMTjq+>wIwhLASu zA_{+q*}e&@hHAcGQnm0UO!AV0%ZK^gSOFi#Cg2_#?lX5{lK#jDpIL)R8;HX(BYoyA z_DciM*W@#o7Lj*2K@2}ZxB!!W6or>e^O=iiZ%OzZChj?SoKiD<@2$INu`1mG!{xJO_+CjJxf0d>!a+ZTK$Oc@5@LQMQe z;plzQwABE7et#)D+ATa{HfbX-VYm@1kvQQasxz?c0O#(9b5)1o=mVX503L!#dcyE7 zEI{3*;pGRD9^Pe2!pAV#&&aKhbBPBXfX9T1|8?Bq zhJx_c2HN%xp~D!|(#tN%)-V9Beqo z`3b@GsuS?HsSm!4QPgfm>zfhfl)yFwM0k{N{u_X%sfJvXt!a>KowiJZtVv^1T z+=5A7Qt%I|v#{a>%{z=>k{4kdv;N^9v4F%6(z*FY??EQ5Fc3?6G z#o^#n7^hep2jJyckvzYE4`331242uYUZgL;`!Tsz2EMP_M0}h371#Pcc5n;q2u;vz>7}4A)>1 ze;huB$+Mvh{K5*#jWmSdS(sca4mV@A?%@Am5`Pv3PILJc&ceh`2tJ9)9!M7Mxzb%L z2-m5O!(Xe;!1q;~RX#p+=IUK&PFD>F#+>%U?_jcbln^&L=DKOPQ*{mwZ_~8FyH=A1 znZMvmn4~iYUtU9gnSXO|=vvwy?g6+Klkq1GFU7=v626Gpxdy(qj&h*gnhwULGq@Hy z1LvOUv(Li9@Qky3rin6-!xScEF5K&E>YlM82w#j-cid~_;BMzqr`#`u;Ck6Hilbk%YA6V++>x@z-H+Aa1jc?rYQF$o)o zFJksH{?NO@%^iMNqdEw0#BBM%G2hZT0v@3{49`>@hc~HC!Dm$GVEMP*b^WkjbqF4K z9@k}E5r(riQHGRT2(HJPNoNAiIN#M>2%e3}`x0^Z3rzfE#67{dK_269j|;SY!g;DA z@aLHLPs7m{IzIup1e1Omg-LZ6uDOVH0(BRMTQRwA8orN7JmzAb8HY_k2jQkm81LAx zOTdq@6rTs=M=H!Mm-@^LKE?;Q?lR^wi4(q$Wzhc1eP${q@q}QD>IghbbsS!d$sS7* z{sp_1eTgj0T|r)GlS1EBj5V^BgCR`Xr|_`v>(~V^!Q@&=m{u)}TYn;a!QWvLXAT~94gH+H5QZ0F&2la1`=Lui0ItVmk2e9gVX}tF z!UL}JnMWzBFkJm3pZPO74qw3}o*cY?vuh6-=>M@ByM*^(xNuLvqi$uqz K`>D@7 zh`#Z5#;7}3L*r)zT!)>DJ`--m?nQ5do;zJ%D}pmr3y)A8hV80_=c!J>9JZ5qgn_Lt z9$}Me;aQkmD-O%?DaGSadUswGWoOHMIAB2ZsvYu>)BkrM2N0KiX!o*z| z#pGJTjB4TZd!73XxDFHlXToRIJqtflck?r!ISDH#{*&R;*bsCMR@_J1K>Oh=Os;hT z+>A+_!ux(s{_&HB2c=09TDayHk>}GXQ6(4#5*tN8lN%<8T}HA!Ant?ojs}objl055W^K zi8BJ*Rmb6_s*`Z5>NM>Cm<#KN$E%LOAE{2kcT}6-(4H|#Ll}NfbrQa$ItQmeuIYg% zU^YFlU3DC8!)#vQ@Ql*|c&h3syi;`=K7~mNvbq zbrNn>orX`T&cb(9nNI>>bryc0+WemMV>W&`OLYjIpgIE2P#uR?s7}IrRHxz7 zsTy@@2F0|AFEEmM^$IwcGWpp z`i#a8_f#E(&8ow2x#}p~s5${Rt4_gfsxxrA>KrWji^dQ4RvmIA$+ zbqYSBIss1F~v+6Khp*jlBSDk>js7}GhRcGL9s&laN zd5s^|s}91ss>5)V>L}cxIstD|or2p`XW$OiIau;njUP@_9fZxQ!*HqUDBPer0XM5o z!ELHDaJ%XpEPFxYhZ9r>;Sx;d+$g+WbqcC#Vj>Lsf_2GSyMIL3ILdR-J;6sLsG0s&laHWsM(B zP#uJGRfpkH)ls-nbpmctoq~_3&cGe2bFl318b6$emGZk7#6>(?r?wAA-GC)6mG<%{U_jN)hW15bq4NGor5K>X`10g z)j>E{br>#J9fccJC*T&JU6hbp)QFIu5T;orL$OPQ#~FXW<8` z%|D1Av+=`OszdN3)e(4x>NvbYbrRmAIt`yzorNE$Ht!HWX5)vmREOXRsw1#nbsS!$ zItlMlorZr_orNE$Ht!NYX5)v4st&{Ds-tkD>IB@PIt5GL(>@L-U=mLdHmeT9rK+QF zgX#p_tU3iBQJsN1ROevX`z}sDoTxeo=c*3F<*K9bJE{|Ki|Q17M0EzftU3qFKG68# z1l2*-2H`jx!-ELuN47Fq7Jr@&jf3x-BZtb*H&w%a*pKD;Kw$Nh{Y*lHc| zm#td)IeP5B+*k7z_kU9lZ#nVErpA^-k35E-g`a(ROY6+m*2ab@Gg}&_@EfSDG5OWW zHhw|9Wkyr5VZnkajnh^yIKVIGz4kcYxogxejU6KdGNnEaJ00n zT6C)X5IH}qE(2(f2J|nLuKSI3)2k}~wSL(Bf}bskD*)H{acZAmEpe?O^>qt{4w<&s(Ox83{#&zc6D2j^pp!ZWH+y;BS+;d zQFvQv7x&?~d%C%EX#ZAP|7E-Q-)h~-f^k&cPj@?G9N$3y<0jpi+hruDu1{1sFdH6pP;de zmOq~Mo++#>UmxdpzDhdR$0w9n={(B0;4Ue04mrJEkz}Q)iGN%qMSN~W{v_NwErMgB9_%XVGLs+<%QcH$AAE z#=o?lJ=A}`B^BYnpz&YO_?N04um1BbRs0t;{tJ3~kzMC|rjnikCA?K^OvKY{d?TAp z(QxOh)P(rHiJHQ3iJ8&kxWA#cyPYpKPvv_Y7nu!Ot}+wuajgNSxySM6%X?Se zK40Fu!vC|pC&*)fHr_k$G9pAgVe%d>GxMKJ1^@ zTjC!uxfV_1kiQb8UpYbiLKIv7KuI36eh}Wdeln)-|Hk>&e_aO$%nd#->3O)?)Sc;NH=6X>_$HfjIsYX-BCwDL zdW83%Il!Hl^gK~%jwsv`#QDHRr5WB)#)mrzV>l+gHa~MXl=U&k!#f5H=zhK~Hn5YQ zuW{w^P^Br4nE|qNj~>AF{#0oW`Ia$Bdkq)AKjtlys5+nP>`#0pJ;D=zRjjDfo^K>S z9M{DL+4#$(yq~Z1SGAXWr35?E9850zc0iYc75JCJ$Bc$gK zq~|i0ZMKMAdc5tWcFwf7f)7`k@oSk~Bh0%}o|M=4wJa_c>M`-HUsqXS=hTbGmYdCd z{lGfEzi|9KpJ!|@DU+$XGfR4Gd@}#X3-Q%;RFuhYWb$R_dVjO6!vesig#w8G5ie8u^E z2dpiR5}W@EIR65l`HYp=`YN{dHQ9J$rS@JSU%uO7*((p?ql{?)W`dUhT!j?wInl?aODc=PTZ1 z%g5h7Xi%a4?D_sKJ;fD#xRvX7R8&a)+4_Bu<5<5=xfinK&H3WCO1G;Z_s+GODBq!- z>?XMKef}s@v)bNl?{7OT^&gB4=^TKALddI~E`sMsUJ%7kZPTIxb*KGfm{^&23@cH)7XH(=*wg7B< zr8Xlvp4j$J^IfRNFI1MSUsq&rS#5jRTtWLU>8hXSE9fuYLB<<*;}PxI-(EJz_E)Jd zX^-udgIxJ@EbYBx@Sv{#Q0b>X@TCq?zOU=~?)^p15A(YAZ`-G->!_^k7M}VZ+Ess2 zKOHoOuKvSue?F*d>&up3T?fTm=#S!mBoDyle4#(_;&E5~24hvSrgr81g7lxVPD;of zKX^XjM-O)W$EAP$x^i#Phcux={~0sb^&bkH`dhzlcyZypj2|7N{Dtt0r}UT6h5og$ zl>WDyzYw1Bi2hnx@NeS_6w82UuXl~aSJc&?ZG1yV72>hIh6-iLg9kp912+QX=>@#1-Zp??pP{`_E9{~l?_2g%>)Vw)a`e}d1m ze%-K=LU~M(^CP{5{BRSre%*+Yp2uT~3+bD{=U�?P`zK|86CP^R;{ifc{T?x%NeR zq`%AcZGV>XjP<8s+WMC9l=IwB=KXJDJc;?(*XlaY_%lf6D?7hZUd3%gjP0M&|AQ6Y z;TMV(`MobPAzTW<2{9osp!+xZ&eCqDk% z`MR|wop%1v_FoelIMAgpAHJw_w{GFb#rlnv^5~Zj-(F$ccZFR)R2B`24K9)SqSG(y z3C7cQzdi3SoUdenD4eH$1{TSFK|Vg24?-WvjOykGZ#%oTobQ?s7_FD)*O!%M(>mjo z@aMkdFPHXr#5(aTe{McGgyWo@pouSCX+{*zV{ZrbBpqG~mHB1lI@emw?i^33 z;F#K%Kk20QSU1RJJ-<}WQ=alygDtFJ{$h)VD91{Dv3|Gfb=G(J@uaR}kd44=&rg=~ zgJk_3T@vLxvv*y;*R8GUl=Z&7KcRkluiw3~A)Mn+(ocV8eO$%m-Su?6o|O))^~v~G zRakFH`&vD?Q*w}xzpA~e%8rL=+x{3IM%eKof1dd5dOaFpJx+SYucdV{9}jlzvxM=J z1atj)tj`&5d!*0$=W2`9f5xrfyRToEA9}4vhIX5OqudC>q-)k%7{$Xq(e!Cv&Hh(V_|J7aV zi)-b2`V2y@_d>P3pOEp#mbbk>(DjI{A8hz;^Zi1~zoUF~_w)?v7&y9ceH*^K+>UpW zzooLi7+7o@l*G4C=HFV|mF@gpC-ZY%Emz;aYyDo)$sJ>Ueq{Y#S4a8tH(S^5b#<~| zJ5(Blw1pfrTWxXK5 zyQODLkNJu6EG{a>hver4h8@yJc>Y{^B)&jbeB{4!Kw-T!VI0SSl0tbrKY;Z??MNvA zn?LdsjP2HIz8X?f$JrqgKjY)LSZyuyWg-8p9|siHZ@cqZ_}I8w<~#mede*NSGr-n| zTMu%5>NB7I+A`{++-->4`>P7BKd_Vb;L2ahw|&G237_x3l#eVAT=}qm>z#h?C%e`^ zv`6O0T4`T}`f4xH^`gE1_WEPR!&pBQ;;(BTJzTCg+Kn&tNB?kY(#^-*?+l8K8ZPS# zJ3rgwvBPDC(D|DDRM`D|TmQAPJ`8l#$HV7SK7oM&{O8Ljzh7e4UpBu3cp4|^k^Vec z){oNMTz(&t^=g^iALj38f`w3nJ#MYUw0E5EuQP+Y+mx?BI1az1}Odw<() z{UYm`08>VO{wmy`NPF5W_a|wxeuRwYcD!f&8o#z=gpBuH z$LxRQ=L62yb&MDx^M@UeZTa|n-5(B<`O>vdJN|k{xc*pj1+KUE^d_AJ?l{-A}exy7k|F{PM8;&DQa?pS`~sCgTVBt%>#H23G#uG3|?QIOx`1Sgu98 z?T5&ESMqD;6Rua+!T!Zcj&;4TpWB~tYbM$s`x|^gX4iW78?1-rexzvV{~zn&-s|7N z9js1Rui8J#udp93>s^K-Z~H)F$8)=0*mXTB_qVdXF&72wdU)hISG#t7JgLHz6z+fV zvv3miBc8c`vGpVE*KQ9~+V;0?vKbt+S96|dpZ(gjetB;|){)+rw`={&dV=v>?&sq+ z{KJ!3&-5#{>n(|&_OwULXWI|oPTJo*>W<|Kc74nGf&D+9jHh-#o8!8ex4-0|NWybW z|LZU9D;na19jveY*Lc0e@2abEn}@q8icYxWcTCEem9{j#1Wf3AIV?7wixkV1MT{u?-y_+`Ic&d17nocDIH zn&E!J_7B>BT}N4&t#9ogtnWuk`tALSJx_NazA`=U?O=JzaXvibtKE+qsO8K30AEVU z{k!vDzJA>}w_Y5{__l=K7?b|S_^apHf2?uqeLYWm7;f)(^m^;p>H1&JFO={lcE3G; zy}H;CyM18ONB+mJ<#7e$xyEmwSNQD$)7{Vd$zNeSWxZzir^t@EOwz~aP1_j@B>#GU z!~T>TkLCPA-Xv=u$%Z23n|Ad#vVD~7PnaP?D76P|{f)HsSGfPdFXwG}(SGE4MbO@2 zlAfK^PseT}ZGPqXiR`C!a8l||@=yEq$3_nk?{+;veGbz74ZEHvza!hLN89!$Ybw?w zBil#Y<5DlKocDKBRhvjxcz?%Ww_cY0P}0NttjhK+6}HN6#m`uD@;k{h40|jIjH8GQT`u z?e(|swj0~Ko%#Mtd$Hr4J!U-Yv0i%kD8iQvll5?ZeN!4M9V6Qb?5x`R(;n+To(Gi1 z%1Yez+ z8@BGBU*8AX2aoAGp3M27L%WVEIj$`0IzHc!-cng`YkrD?vGJWu1%>r_N!tLsePVAS zUzF=tm3F^A`(-1f{pQmz`-O$`^p^nZo386sGCvIFb0z%ABAEGx>jzx;e0kP(*5$291@e&y}Z?-`Mi*=?}D*mllp0KTF!G*lXlZ;-`MtpByj_>Ef6A9MD;p z&yTIYx?wWa=F4kbtfoZLQ)%n}D5;-7A$}X*xDttPkesK#m$Z%L^H==I^@1|r>iv}* z=j43P@)#EzG*Hqv*2XXS4Ggg1&1@t6t)y*`ds|g*f~XJng9l1`aOs&${R}K4{rtK5 z3B<+@=qew{?|@Q?KVajx?PCP>!yoCj0i?|@^UHqhPucrhakA}4 zpKpl&sI;Hn?YGA6=hxZ#l=d>N*401dM}Hg_8&@vrA1CK~w?}*ZK&>19q`W8Nf3S>K z`TO${JKrqU`HJx4VgVU1jBP(0+xezd=NsE!MvtI=_~UxEK6Lz$OwpcLPw}>>tgl#4 zJ-EyKL-`M_94*1M{jk4LBkifHKa3qy=ns@P?X!~f$Ghrd^yqHKlvlOxFA)j(9Tyur zt}8ydzMBv2d_{T#%-7xPi}5B<7;mVblD5&trcc|~emvi|&j*NJ@*9X%jja}~?U(xR z(f)_ESNr{IjoXvLe|~sL1`IhVl zWd9)(a@X_jkPi=bKfCvOiT1&IP}-k;-b{P!S})OSX+I;|OJt>HIt%NiQhVIBUZUsO z_Q(CaJm2nMJ7%H8C->8G%=S`5j@dt=J^4EZ>iZh!Nc@~HD=n;t>~U$Sz5kc@H|%<= zyla1e@7A~bbEL=a2e6*1iH+*E9;swKGO}yEN`7e1UF%C*pCw)EDejl@?aS8hHo5<< z=z2c+@KMBHKCEy*O?paW{&LqJ8NWHMi4C&vC-AXSTfZ`2Q(y6X|DwN)a^dk$c?Nji z+N1rpSM`_=Wq*K@&yT+z_Jb*uqYCG3{1p|2_$hDZzpnOe+wb76 z`8CDcx%vG!8z1+pF=@^G$)Ih=n<4%hq?><_qCE~S@9J-T&5wiq-RJ87<>B%_oB8p4 zu9ts4Or+OHf3K;O_$%yqG@0j_<0>V5{&^zT8(etaNqzL_&(!a*_CbTCf7v;S{Ci`A z-8=r0e(Hz*RaNM3e22^UcDFvZ>vhHpZ%q0Uf8yVUAL!z;`RD$F&Jq2nt3MH6v81nP zB>NrPE4>^SN&59zuIH8ZVUKBV|F1nCBy0cHpNG)CK4J)|+pWD$ivR9^GPm{eVox($ z*}r>>9P|J74liw2?$>(%rKJ7y$0HE?Cw^qzrq1$74Ex7+inP>S{zkZ)7(RR7^Hi@x zv!^>9?dgl_+C`Bu2|ba|W3TMkM{Ri7CH%V@{?TsXr3v>Af0awnF5w?^H!bGFZsGlX zg?ELwOZc=~RG6o_g}41qkMsB;86&B?gum1c5+v5TEnFX?NsN#8@MO#)0{0odMm=gf zl=e%L*rRv&lQjRkT)$q!YdJ`Gn#6?O;s1=g)Tx~B@wcx#lC*XI+x6JvX~sjEb&~wu z3q!C!{~ad|`Zqg}SB|S&dLHjJkjDgXj-9l&HP*(b!~@eOO&>U5(&`1VQ(IRD+D>bk zF^$6oix&X7aSAc_vuh$YC{Vzv`*-&oudGy8JUk{+Y=?ZCc{q`$5VBmlii_^w)l<%DJv4&uM;G~E6@yz+NPBPUECz&aurkrHFzcrH^W}a-O zR5zY%Jk?VMdm5{oPBz}kA&9An>M3$^nkalssh%zhA5*Glh{DU1YQlTP%alMvaK33< zw08A^Nvl_Vv7~5`scEQRW~v)&OmWqe8dHrGS8=quVb;I_r>$ChIo~{EzRYLc7f)(u zI+FatYiJ_m;$>@V%G&VNY&+(MKWA>O@Kg;ad*m#NYy=iKv&b-)6o;v>wd^Vk7Mopb> z_Lx>b##GKXzS?ONYxT7G{DR-LX0IuTJ*K-;MZQ}+o`;KKZLO=LlUA&*@-#*p>ci}^ zdUM{;_-I35VQfL5?(mkD0}cqb-0$_BCS}|do6ZVZROtzIzdd!d+6}~}JzF>-)&BVqp zoWu7w_4BqH-&SAQm%M`-H+#x|=9$(gY4k6?iaJ@k;@Xm?#u-%6ZVl7T?lbEjDH_xG zlDBG2b*b+%&!EO;Uz9czqpX8VR(I5z2KvGxWB4YxZyC?Gi$=7MHFXX3C-Z~B#@D}J|9}}WC13_O zZ1o_=HP-KL-s>M+(m3aU8M9|M9zcs~ZJ#n}K{Og$wZaT)syAaBTg~{X^*=V1i;ZX8 zwEFQP{DY=*X4=7~Zbtn?Phd$?m2c^i(-usMt*kY5GwW-L`rT-%XY6S_O;dyQulV}y zX$DX4C?0r$@f_pxooxJ5Bc7tPhp#iLZJ4$@cS&uBcqTUn$9X-wH`FKkd&YU|+E$PA z@~GM~%G=N|zkZGxJ?mpn)ztSqzF&Fn_`A8dxr{K!*1_4gj{9@By!d3+>x=AOno zyx}9y_^FIYCAWI^m{Pw_QRS1qeuJBy^m$s#E1vXu7xtgr)>P&7?9nLaXyAjVMvP}! zxz@qtwy9H#eB-B0nOZcqahiAPG`xBDo;KwP&mPm7s?C_j^E{2yrbtm&z@0pcsOUAZvAQ+JtFdquW1S5{vC>G(As{n%B5tHK*6NBUOQPkREg(S<8osh!@ zl5&nOC&WgOgaDNzCLD(3t`g@66~`pu-pGX9N8Q!spS!EnsA_w9dRjei-s^tv*Y8dD zlr~vmLQGqQw#YtIvXwSO$SRaj+d!F?RVEhn;^nBu9>QE@6c3=1uFF89QFIuTu$ZB; ztfjI7j7tWLI)hEyfbQ3&t2=S4ZOW+IHZFQ!vO9r!_G`G?28uwm@>%AU&r-&t&SBhT zmsg`!M?HgO=LJk!9ca*{%K>^wLO+8Gwmpda^B9J-$@p<#pnaU;M}m+=@Mxt`n*(*0 zd(q{f0~~xb_yGr>*u_v4EF2mJICrYfMZmTMIab8%Vcf1N-WT z?0g&vdRh5BmE|{4cUghHsV(xa90LB8b!f?@vkW$ZkSQ4CZJ_$Qfqj&Ib0Cp;y7MID=cuHgKk!veMeByUeh8t2jz~#WUdF_N?RH>M##d3Fs^}evCQL?J%jI>* z|0q0lv7hn#Ln7S~i5k@ye+51>W4o7y%<)b^At=9e)_egQqP zyc~AaS`rqnX8dwS^QLk+!xT$NQc&$H8L%!%$T}Zwg{t-(24mHP=x;+3d@AfD;wci( zYd{}l5Pm*R&Zqp(0NFLwQy?WtYVsNo2pRZFU_wZOc}^1AW|kp+1n728EG0==HDx`* z-^K-9`$QCgZQ4j87;7SG=`xg9e9J<#Wkk;)nv13MTnwB`;y*^4jh^MG`PxrHVidm^ zqszfuJs(5CM66^L6R{wEGLqSPa~wYyVp$YQ#PDw;qy zHo}*)5#C7RLjf%Jf@8Yb8hRJ;93rju&~dD8z4oRpNWSG*3yT|Tv!8UgQ-l3&&da`o z-rr8`_nU)Jw001-1$(N6pM$cIueb2?P*GOTLsn$Jn=i=vy~vUcH!XbTAkmMzt7TIjM>QdnY|ad=yzfR_~jUmtD4-7qQf%U9T} zt!t?Fqo}0WA4MJff&%C*4RW&f%{=AnAfulF5bL-XjRNIsai?rP1SMs+lGO}FP>DAb z*y1q`=Mzu`)DGc;xDyC_9EvVRo%~!%WMv)dENMgQ@?t?P-q{Q;Il72p%kN`Y*@jHQ z_`AIL+a@qS#|pHh8^;Q)q?^YIVr|hu#ck>>E(=)4jDdwZ$g) zl4T+KIe0(jnP9;vnut9eY`^Z>jmZve)x)22R(75^yPzoO|5YmJ2JBUvgBw6n+ zkc8HU#Nn1Y^9>{b!cC$Uy_`hQb#8BIHj8;sY$+^>CvnNtJ{2OGdX+hbh5K8;fZd0f z5W{IyYhMz}Xa`|Pl^al0@5ccfG$s(8R-vy=aneB=j5KQ2in~b9R;`>y1-saYGTQ5; z&bTh#2lJ7RC1RJMzlzBXl>UHr=%6PktaEEYwZ=zq;CpoqS^~wiL?n1Cu2qtlod`$O zpTzl!g8xEO?dqiQ~t{Oap0VRV9xf(HcDvpnfU~N#` zkXs@@i-)%7v52AT13v#^KK$7ld)aNqc=i0!-i@!@FUjVrme}!Gbbo@>G5Z!&GPD&y+1Gv1(MTXjG0ymLqgy>J zJ1a;4>wfCe=H#&;-p2Ts;jN-;0vK+*2}PQr(~q+-eyf4r1JJnxkoi3bjE%B_0D+9G zl}e-9IAi$QAO-Stbp^p_AxO`->=RF5;l&^kQSb{{@=Hpd4W=~>6%;VIY1=@B@?vIy zkn0fkfdHQjf|gnf9=DmGY_CNarYb%_lfA8t7WxtsC-@ul7@!F=(bYuYpm7Mjl!4;>StsL zKi`d(Aog*@SA(lS`+^k2Kk0^6`1hIEYUzz#bY)3)4nl4HO{jf)#;v^>Nw;GBv#ye* z>m^^kT`ZBsAv6Ta*?bw+6}y&CNF$h=N*WJ7+LYBzP#$r{r!%O*DX|p%(DO`O2k5$| z1EYUNq9zMlL+l(9I7#cQ6SPj3IlmN8!?^XB%wTY4ow$@T0Z4O}Q*zj#80y7O+ZNKuLqAX8^Ux+5plb5x1~OHtq%w z2_P77qPU6eCXSnw>oQcf>oTA%>M~K7k3edzMP1=wTM{uaciprFoHZB1H0?svDm|Ty zB-4Ba&GH*kln8)Z)+6W=$XW4Gdr3J^ExC=g~d3 zYC4+jIQnnulUrxfxYc!aq!2BGK}ezCeSXkdQ)yqi{Go=xyY;O~5nPZspl`l@i;cjF zHllXT`2@hh2Jm}0CMsx91OpN^(1gAvAYHpWrF|zINH)}I_b2(WbTFl@1xfBrrH%8E zczzoD9Z;4kzlKlA7GF%s_7MD=^IR_SLMGw#A$|sfZ%K-3b7?-> zRZkCwW%Wop!4D&xD8D0xF5~!N+G&|FyhH`006O(@PU-cveTF&&*M!%Tp;yAVKM9Y^ z*YgN}E!|FjSr7XDwRALUt@1-bhitr-Zox;BY(IiiKW1+$d)u$*->rL}OVjnkjer1hra+bNVRdzh1?@nu1JrZJ^`KOTkO zrCV;ise9wMYE#-_+H9(#Aa;r5`gn_}zCeAT3v>{otX$IwQyu-oqGUC)d&K||@k`6- zjkuuMR1^b^Evnvtf=#&2+>1eg>tHUw7F;aOI^TUH3IIYbL3IXT*ozR%b-qLb4}qp^s)!ZZSLs7qM9g2$lWf74WV~gYbV^FpHib%*40?VOx!vRBQ z_xvy$?6%BLH3h$!5r$II4v~BjinV{LC$t05k+MUiaqU(_mS)5Eie!K+8`YfqM7*1n zHW@cJMPub_5P+5te-$8v{{5x&RTSJQ;-3(@1-iYyJr}yrGCI3n&xC|%Ek7?8$f>95 z{6e;*st<@FewWa%AueXa?8-!+)!F6ChJ>&9q`$=e-(^cy?|oUZRnz`93#g(M-gMkX zxIY_`^}C3Gy9u~S!~|h?fwsMifcjQ$1MebSqj{&;iXwLr@QpU2z+Hp~vX~EpbzDu) zTF1lC)-e2AFe?d{_<0a650Ef^oYEaG&$*QKwW{_2;XA-A-Wh5iRd=4lX^uuhS#ft3 z|85aa7N)k1`1M`+$Ny1iw2YDr7No_?cypoVB5Za##6AaK3 zto2Gj61u*U1QXVV6yzRr+j*NZ@^;vFICmCD^pgnLhZC063Sb*>jc_O9+Ytj4an@|~ z;)L@`8k`ycRKPQL8Z$r<89))AW&GLSqli8tG+D*R0#YAI*?r)TI-QY9rhPi2^KXDF z+36$v#SCD2ry1bF9gM>vUv^>2B3kG+Tx9S_#s@4R_paq%LPWfs!GV!1M^EMX;Y_4t zs4EF3*N5?MT&A>`WpGEpGy;VHDs;F50(7E?LU44U0+i^44oCn2QlK2k5($I|M{s_- zpyA~p_|=RH!S{NPlr`3V1dvt}%4uUwDNV{o;m7v_=D8WZ`IyM{N)me@6gfAvL@6#s zWCR;G-MAU}+=0Y%Ka3G6C_@|k;!YGeo*4`W!$IW!fgk9EFAn6zfxI}77YFjEFAn6zfxI}77YFj< zKwcckivxLaATJK&#euvykQWE~KaT^Q|3TCTR*N_On-j?Xv2h^V^x{BX9LS3Ud2t{w z4&=pwyf}~-2lC=TUL44a19@>EZ#|H=9>`k{+Fu4Yb|bLRTz*3WI28=l)WSDBld zJ1}=>?#SHHxrMox=H8k+HFtW>KVLJSoNt_On^)$Y`8D%{^Xp+RJv7c1H#Z%rj8(=f z+bY{DyDF8+ROLYBP~}MFXl0@DQsqSDt;(s&=?dE8-&3<^*F++Ak?L4=yt=Kry}GMfsZLc7R1Z~;RF75{sxMVfR0k*5Pi~lm`)z0L=a|s& z#K^?hM6%jgZL2C(r`lg#Q~jT{^j7s$^>h_Y`X_59laq~;ZIjBRGuc16X7cPkh9^fR z(TsnlW(HN*;{K}2tTWp`yJmK9_77Yg2fl7!?s9@nL{&2W{%D* z%>3tCYxWQBM@JYcLcOC?3sWbi-kMr7y6Hz}rK>OaC5a&-y*%d#3g*>^Z$Bxwn7s*xsqVC-&A%I1|GY+kva6CfXic|KPR< z4?TF|!5ZN12ypX6wFaJWeDcU-<3r;QEj*-5jZGbbyZfh|>5=Km^uqM%>BgBgGh;Ic sW=_o1?Cak*vahmlVPDPc`0U{R@%@8y$$4}F*0V0s0*Uwc|B?g$3hP#cU;qFB diff --git a/backend-python/wkv_cuda_utils/wkv_cuda40.pyd b/backend-python/wkv_cuda_utils/wkv_cuda40.pyd deleted file mode 100644 index 442bae8b10cf57e682e1761e75eff67315e97c47..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 445440 zcmeFa34B!5`8PhND8ltGE z)D;yguHc5?8WdruTimURqP2QxaI02v&HMd6=egV5nIs^;-~a!4KQAB6+;i@8p8Y&$ zKleAkpB5{M#bU+yzq2zIn~z`qOUd7^|CATRV(V=A=sK~7`n)h`enHI(gGSY#bY{Ha zj6eVBjN?y@pK$zXr~SDx{>L%#Gsc}3Kk2l1^&yAFPyO@gF*^@i_xja#d+c{$ z^~vu4BlmgpQZ6UymJu-%EzM?++MTjr80tZTg*E z4j6j~()%5Fz{wMSU48QP!$%07>xWm5-5>ea4?j?Te|ORe^`x0CVeJ92*yu}o$3EEo zw9)RnFJfE8`xUM;I93QPy;vyOnHQn8a(SEpHP~mo>LpqJZ!Fo{Di`^Mu{1z!A@?Vm zwmIoNN@BD3iN#8rN@AmgfFk~)B=*n=`plnr?^r$PmVZ0*Uqyi({!E zo_KLA-QR_Z{_fZ~rl}Fx74!_UOxf{Y@LxO@OYeBb=;IrYkHuEb1ajbv-HP8{VZapJ zX-5kdYuX4H+Dk0X+4v3oOT}VsJ2qHAkrA@SMgw4e1TMAXY2%MO@%Y9+Chx@iG_-yf zem8Tqk7xW#k?v{Z4?Mnczmozvl=%9Jyuy3#3|SMlNI#WAO)rqVGxpIH!FQeBxmAM!e_0}7ACizzjgiIx~Z<29`V z9OTrrmK?NhEH>;FK(xP95;L1lDj@Wv&Z3yv@{yvL)4ZtendiNKH7GmKnKG)foVZFs z-f>z-RmMrO*8Kc(VXUU5x-!kYgwv8S=a~Y@0h6ODlgv4J1qk4G=&oR^d65%94k{tk zZlRjWw7Ctaofb#{-ifP<6w_v^QmkiQig^2iq5>vr023!e3BGo0k%RsLOkt@!@j#za zp#=WK!mJ5Uh`)gZj1z0Z*~2Lv*Jnv_CF^KMZ#>Y;T8b+{Xh$`?5gNjDs|QPq<_~ z1?PbulMg=VX;$4W6J|qrxIr;EPJmkR;7rTHi`yD|p#1>9j{?aADgjREl$z;z7zom@ zqKv76v-+Zx=&WO7K3OHFoR?R$w43LFhc1d2@mvaX_0pa9X}|TE1Aop2AnN zSQi=yvx+c(Kd&G*&1sohS%;FH70vLs$*92-&Z|t9G?w$HSJ|Y&c(-B?#8v`MK*FT% zMQoTdsd6%KFtdr7Lo{=mW;QT$L(QC_nf1)RncFo} zwDV8Rd`mM$J1=VHa?PZ@%mU3cnkm|urI}x7rfBDKW`e&qWRs5#Wou;e@CKft*G7~0i|`or}*itSkk{Z zHmuF_$oi)d!6z;n2>FgXHYcdjJbP+bWJL@>3ZAKz36wBbL&MXk0;}oYf*yiK1ENKP4H^vy8m9ts z#R~`;R!Ym-FtUcYF@xxUfVt+o*|rkQ)J%$QKGMwDnn^90mo#&(W>VSaQO%sMnUk4$ zyJof_(=uQU{J?sa0dajWz+7{I0<#KOU@O?g3M{oj#Wf=dj08o1>!2frcq3IzI7A{vur>vgDn@qjLTVVfATp_8B#mdcf-z{oZuR_`TA3#EtNx1cb19yJ zdVX$QE9P?I0;K4<6xVQq6r&dbTWZ|ik+}#3r{$TO6*toRIuNjw~ z?Bxf=w3H|glsLwgI0~SMWQ7h(OqCLhWR%!xMPJ9%-(q_`aG0!KPnywC12AVldGpTUXd7G_~GmvTRU@g}l3^7d3RAA@3y6ZqP7 zeqjL#2{J+&3GsfC^t^v25_l1(2t;4gvb3h94Qfq*2J+%0HRmb>nqBckDAC8B%F$fn zLZ+x>Yt5XmnWB<%&1}<5QOOUtTIn;IDJuC`Gux3FRLOT}d8v}O6j)U9tO76BOi{@! zW`Y~x3R3kT^imO`6;k!ILQ<|)NFk?I7F-bw3um%?YsGY1J@%IB@kx;z77X5+$qOVb zltPh_ut3r}sxB%28iObfCNNenhSS$zoF517QdR(1yw>da4vCd~@{I*d0PqJ;Bp^n- z72o370$%;D0ABWk1xy0)Lj_C%@I4Rk$7KKvjLI>VQnj>P><;j%r>+RIQyxR*a7T5R?HEl!;8)9mlvXXxAO6>pCPk6 zE#CpMfmNZ5Ka#-o1))YGa*`pZW9N=_ZPuZhRoAhFo3)Q-r7yL@0OOubeqO?=C`R5Cy= zJiir&KlGiP@HQTJ{EvM%i~qdzzg(6RUtaiImgj)anre{wK6q0tXWC-N*#fmrb6Y9( zi%wa_9K0E;1ZqdpY`Gx;*gtDt2m#v7Iu#5=N4qF$+H;Dl4lT0IK%!+<4(C2e} z0x9C&;z((=Nnc}BSNx@q^^5!bo#XTO6dL#tf2VEWlKwhBe+PN|9g@x81xKt7e{IL) z;_oh>zvHs`dvSf2zv7j7^>=4q7ii9XJ^|6+dxuB)L-Z3RPIevRNEkWAIb}rdRm0lQ zSYguWE@tc;%qW`&2n_n-)n)T)DtBhUuBQ#tqEWLyW|bLtXsOh3Zg>p(KVHHrHI}>2!2ZIN0t>q&RoaK`exRWC?{k%}%T*pm>uLE_TcUqU)dcCIcu&i}wGBbD%EPI9 zRzBHTaU2$ttgNjAitc51Knz)BF7Pn3aUC%*s*tEYfUhio@40&d_;y?~d`}hl{9P5q z_ukg7Hp{lb;NE6w?HH%z#+tIT96OMgO$@3%>) zY%ch^ApG43cLBXc^)C&BWA**py(E57^@shJ<=6@CLv*&ABLn_}(0^W7ZvHda+8P7; zh0-ySj$f&{?#US1Cs-b3D<)PS7AmtLAr_gnc<^-(Qm{LE^Ni6}R_!K|{rN5NauZQ@&l( z7#f)Mo>04{(KL8VmObG2yFEO9F9`B`|DX`Rw`x-Py-Jf={GP{ba0lIW!|zsk_(i<% zT*a+pDN`wUMR};7`-kdGpnB%5-yg-vSG%xg&MIU=k!`n65Lv$XFe_Mo_fYu|Ga;`K zRXL_!%8JF>#2L*sTy2vo*jh2>XAU$nW~goxdt8k9}3PZoY`?`(F;D9^xYL+W^Px8OqFV1AKfTO(Q-c6i6RTt9!<*`+a98 zI^yh+iEb8d3YO*R+b!_qDlKh{CQ&iz6MI0Lw*J`+eIUXg%fno=Ms>`C&dgRhlmoe|6aj)91u}lI9>(pt5;w zn0hKP1GZsTnU=BSa`s;`;@a(2<0BCm0wqHJBu&%wK;%cMTl>120!>MU`x@=y#8HuPDK%S>Rmo&q6 zhtRM{$}HWOQ46Aa9<=*8IV89Ih4cse$fhQmrsy%?v%Jw<)-GNg{MVmo6p%}rJ?{(} zWexFNJlVtd@7a0rseMhPyf>R>{gIG`l>Rwp(KMJon9g!Rd<ZOahBArp~FXp1jywDr2>y4;(&q92|P&b!@hIes=Wof^f0G~ zaUkdP#1i0Y@eV$bxIbAJs&666&=te_O=n1C;d&@48@vU3Ks5AmHjovNTQJk?n963( zI$Q3DWm(-q&*T02or~U%Ye;Xmx7Uo`lc%qld>g)*kDf18&jvI#obly==DC%D0nNOB zM+P**s+CnV_Bx;`y%qzS=e7$EXgp2Z2cIAKh{mV+>yc(To2$6#h}+PHaJaBkn#q@G z<`jqAf8Gdg6W*X0#g{UrKNi{!S>rY|zhjOB0d$~FbOUtC?0Pyt9s#`*r<8VVWTAEz zHrkm;nb+PUPtiU@l%8W&$<%YkBrE9EZaW}6r?3|m;}-g3TzgC^%i#SXXIK_$C>Q?O zF@USV-%$Pj{Y6^jiL+mpD@48u9E0HX|V7>_~25X7|8E$O>1G3SI-vD2a(mSjCv-0Ke?QT$Z}7xr^N zq7Im29XmQL-4)xK#q1_lRBcvrs<-=c|N3VUx9}!WB)N54K@2Oxsj^9nNp%46Bot{v zHX@f@d=Ev%b0SuD*-Son%h{P97I8~jW%rPW0+MW}X%4_6UUgP1LMHNd0WZ7Y7yX|* z3YoDNyL{5G-^(mVmJBo+`vEN8oGAuCe8pa9gM{sIjolM}?E_1^zHHufy9pKa>pLr| zE&UV_r+~HE47T*8%UR8;>6mm*`fE#pL%|soxOMk3l&cI4R3P9noritI$OPl8Y3x(h znm=N5As$9iftXmRsT^f)MphmLc64oG?h8RsypNQ-as})QU>8q-&t4t?1}9BkZCRq@ zq$cb-WhLW6l>Q6U|A{MSEyW<2QM=P0ZK=CWtx8&gBwhNBy z9`t#F4hPl?hIS8pGl7SY1@KYV(Z{R8!74f3GKUB;8Df~j&enJ#sWEdkUW_Xby7KQd zc8DB##Tp3mAY6r#ReoiD0#vem(XGL^;#K!j-T$6ngS%$nazI9EaN+yH)n4FgFL0MK zAKWGmh|0s0B5=TJv&Age0S5%c?kHJtfLDimV5@Q^Ur%R;8%(!*#%)DDQ(rdF7Back z5DFI>;C`{TnwR<2wQFr14@hdO$*ZnSofWSk6M1Bk?`q70w~Dqsr z`{38oSIKT*qwp{whjjn!0X*7n@BGe#@EAMm)R!RwO`0Cn!2`6${sy}PetYh%Zy$9} zfV04OisPjL&$8VMbPKZH?|^0J(EI54ojZlFLdGE{j7S%CCqvn<>W>X`$f~qEg~a|T zVN~N^ZhUd#qwuvdFzY zLdpXhcz`V^JvxU-lfj0BG>|7FyQUhgfixB{a3yc)&LUN{wR<&-Ho&ForB#i}QwMp( zpE@h{MRSg6+x}RL`9lh0m$EIa+0$n0UNlU`{ALs6ScXVy9m#HIcuJXIrF`Y)(dhqEj9ZC)8lHbeG@F~A5li>Hb&0T)Sl~{gn z+&IMVv68L)p57yVJ8E+9yIWpuCM2FiA8d3Q~X)tBR%tawO zFuMQ{&e)D*$OeDV?%bia4b{v-g~Ve)=`MP zvR;$omPe4Fp_|)&#fl@2Q!6WhXO);4{TX7oYwsqMm%2~1azM0yXT?cx7GX;+>cQ9Y z_akD83#R7f%cZi;N6?NrW7{zG$0;g!Q{0)dfj3d5OB3QXzWwLQnbmB4PFy*exDe5% z;n(bayJZOMPtv1)TJ5WZGI_zVP0U=nQZZG|RC;lB# za8)ru>-nzU*~5;8$JelnH#5jauTQ{6t3459(6N@b7Cb@Pnik0Ef|r`jP9r7mg`QHN zM+QaRqL(bPsCTr9k}bPr4oHIOxb03r;fE6d`V^a_I&gNFVO-@)Q6(JB*bMi?l~9Y( zw5i4lawZ_)DAieUB>HC^k6NnGu!v|EccFv^Pe7JnXF>$TTnYq4Fm_rnF;r@XpXv+P zahx;O>a1LJ;S?26otOz!RsELip>VdAcEfHsX+FtIf^$&-sx!ywo)_dG$z5j?7ttL; z<2CpEbFh8@ODUEpkY)XyX{gOhH?JH>N^wzTF?vDx-=5pf_CFrC2V|OZ2pib%(i<@z z`fM0qVp?Urk-K1S!$k@N!hpTxBIuRRDawRWI{t{8pAP^Ks!~JwnJ1{I)Z`Qg0uypc zYG9eD2nm*jl4YAs#!Tv0L|>o-OEx{KD8qGl^w$S5dT)hgdNz!4GauSqVLRL*q&x)- zMW)RsuUlRv^jFD&=61eu%oFG!kn^PF2e|Q+F$c=HP-F@5i!!Kp4wMO-ng;B=4|9F& zF@@xMD{FLf2)E`JiEE9jg!g95sk^vL7jYMBc59rgzg2HF;M$|XY}||Xm{d6dnuAHc zuLMf)(_qSP@lEpiJ9$hp7W1f1^W##w{LBNa+^7u1mfJ8R<|FEY#3-H$%9wxm0cO#A z3Z3&6v*le#?|PO}csfNRswg-aRQ-P2Th)tE5;QBGGUhoCUxXl7yW%T! zgO)zJuNR_88aukCIHhw7Itqujc@0nqKMJuUS0$Wc{#Jh5?a}yQt=4#P!Z>CFl?F+e zr2zQsJqWq*8QGGTG0Ww*eF2;|X8t)fgl`ev>6$*~pu5I5L-9>hd}-plDvYnuz9X}7 zbH`EiDd8{|e~#iyDLz;N=Q+aoD&?Ixay+-ZxCKCG@N3OCvIw3rpUH3g0y?PqvLzsK zoq0vxp&94J00h9r0N^p}L{iT9$C&X9=gb(>5Z`Wn%T^}X&kc=Uq?5VQxY=L?Rd8Nq zeXQ|N4PwOd>5IsA3!xQ&`GXNdB21~d_D`%|(r99N72|~@*p*6_G{dpDZL8bMCE~%^ z{dX83uMj`q;0v(~-53hgg%F3NJR!yggs2Y*K}MTQBNsY>&0(^^j6qVQ6EU>O#JhDJ zG?@cO=NE74I9I&?unQashPRPTwg|}g(;MB&_vd|c%9nWo*`ATCeABs&%9rnA zB=gEw--UedHoEfRq;*dDvgZftY3&9xOc_Wr*<>m;33t=M?s}S&vM~v~Cio=%LTH|z zi#!BJZN?nk$A-)}Gw}z((q_^pk(TWo7(ZQLH3=oNGZPA$x$Ge|6Tp7lZ!(jva%1L5 zM+3Q8K9jmoB3Ehs0$H80U8t|&pc&IBQ`Ycp>2_rd&tiAn;9z;T!D)F&rg+^-4JJAM zu9KP_ia$^YTZS25So^wxI(Og(3QV3wtCHH|Ha&wV^}WeV*CfJ&+T$Y45#5UwzuT7@AnEi#ZM@Tz)YNNBXilCEoH|*`lVw|;aNX$%3(*Lr6EG2zgX3{v4wF2 zM%(w{ILP1e42uW@P)eyyi24KjUwlhi|KhT#dP;9~jwbUH13<{}`d#5V9xSE%lbX!d z*8@&GA9-*vMYImH;UlmVI`&C21ye{0xnQE{(hM zNnEkE;@bW?`swb!Ec-r_l@`^nqY7(FP$ac56rO~_g~FR((J1Xqq|Dp1Fiw8KrPn1s z88_y@cWDscT|LK_&comCL44=*9G^tA9Q=JEdk$6JgL{szG!K8*2JvkY!spH>pk|8d z(LFrLg|->xX4U0v?<9rwcveJO1?Pp3$gW_%m%F_i#|$JwaaJ0TM+HoY&|p)qYw!## zEl#;P?=njZDWQceUMyp2q05*H|NJi__-|(xN=ERXqO!R7MV4;x*G2IkkPrV2En)tP zyW(HMVU=0+g}S$EflDh};c+^Ld2-*O`^eo-4EP_r#u)qxi?=!~bp;{&-jRlz`}%L==BTPw}UFfIstD zL_ZHRE0=zr&*Hx)`l*lNKfI^-yYwH2QKXix2{+&$o#+##)|oFpr4a3Wg`83|O)41@ zvW-bHh5cIdH{F&9u$UL~Bq7`M*#P?cyo1w~;rP)FKj|!fYV-25VHf;J1Y0Zow66^7 zC!Lp{OW2pTYWyYosUwbb9(&cB6#20&HZ|VT$Ly#%QW4VTkjrg%1Or7(VaTL{bW1+b zpp$Lbw0S|dn!AQgc~WExSQH9ohU|pgGF#&6nlo_YO>23L{yLm^Pbt z!7qH=cf_3hcH@7`v1>3`AL0^3a@B`x*jLYE?{tt2^rPLA1-6G4(qA|H)MoLM%*#)& zF8GluoVPwi`*U3ywZS~LB4E^dbB`v)qVS?V<~q%RD{+YZ80Ts51fP^+i8!7pzKPf3 znVvJ`R|KKSS5~UU$RPO9iYJ0wp zS!UD)*z=u>8K90WNy#yb@In?!rJ#-K<|(dnHHU5({e|?83DR#1(67!#{}4N)V`uhq z_}3c9|4Bgp_5l4wcoC8RB`zdeLp&x6oeEZ3T;>j=wUA0?aLto75MFx#N* z@g#5kw2t8e%WIsqkS7f*DW;GW)Tb4ILqM=$1$zd85yb8 zGq0Ykz$SM2(yaWa)pZj=0nd4z*8VS>*O-qiY+!Mfo-h|)0ObYn7H1Th$R&$P6noqt zPIbzpl(3PK1xis(srdmLzdMcy0=P@83h*{}h6@qPXh;k#*G2$~R#!@Ap%k{x@qMi~ zLmu~3HI7>%!kkBzMkRoD;}4%=f`>Ar1zB@{czIP`cd)P09^a>~^u6KqK3#gletN(z z=nXfmm&+R*nP|)94ex!B%Nq_gi!C8H8cmcOWcIgfG03eSyIE8t=VLVE8 z;SEPVndJ?Rw@}`l{hQZ<{Z}gRXQ{o#}WtLG2TFU#!@4>3R6A1*xKHPxE?LrZt~$M}1=^_pzaIi#ua6Gm0jTJ%ZNi zf&?>jH5YTOTr_0jdbYy1$$0e3k7VFEv4|0Pdm;H{`g@e2P6M>`Upk3(-`A zka46|ed!)#bs>%fKp`+!g6TpzbbS-mk@)7b&1n8~USh1ZeU?gpW=32XNBrRsblYC8OMt^T^Ed>}Ip~;Aawom+dIWrO zE_05Q9Nnac@<3w8rH7im_{#bSrW2=DP5^vxi2CE{t{E-e;^?4+O6hWlT6I9QODdX# zJg7>gW+MOud?hSqE>~e`^Feh!SDDGzPv37f?P?61M^0M?@AnAJG)hft!1861N|_9N z&S1B3O_5c6erGxL2AG2eIxWT0?8>f5l;D1$XFdga>^Xp~XvB|BakoL>d?n9}d&aKCLVM+PmcM6*;T4+Wg(kNw5fi zTVu=ZYV()ulE0yA#$Ugj{8i-RFZ;Zo+Z4Em-*wvV;*;F!S+b&7Hk(hf#}VvONrKZ; zHe^Y7u4j5C%*7233*i5qxP7q(U($l&YF=FqkC2-NzHP@-s%)M;rwtG9%gRA;`Kp>! z6O^a|)GpLHj(ot|#+~a@TM#FSA@t;hE|pqKrS?^WQgYZ>sEk7Kz2IO;ew++%{($r- zq*MKD8V(KA$f#?H4PJ=fKfud(8T@j~-}}k(Q3A(9`e~>kSW3Z?5imgn=arKxP>fc5 z1VWt*ML5$V0!{7~N89vdDWcj_1OmA|vr22lB^vRzX?o!}`e8T;x~}bDV|CWdcwlTFUfAh-FpB;_yN_ zhir}=1>T$OvR3wZ2H!Je1DEC;t3(%Ql+k56EuR%#dXwikj@#o0QHon@!i%ApaqTI7 zdrL<+I`j-bC4++v%o&+%;H&o=O9Vph3bKtQm}X{W z{7fuKmC{QjvEojB%0g^g;VQx)B6FWt#vF|=YkTH3@ox}tDm=>@v|p)IsJ2ta}%78Q+K0ZOpa-_h;{MZ454fB%Q*VWQ@LmEN=e z%k*yV5j|Ldp$t*6LIwv+NG~*FijgrwdU&*wdPCU|e$^)E4&nCC$M>xC)C2e)LPiAN z=iBuZ-?(giqjTYVsmJ)X&c?SvE_`!(jPLcAS*ru`e*G*n+HEab@^lX2*NZ9>9v|{| z)u*;xw**l(0%gE5TjL~SmSt=u>H^=PL3|r<8pyZwn&2x4;(Kl3FXr$5MFAB~AJnk`~b4MKfH~rGZg|-{%C(*p8aCmcbJz6EB?Phe@@|52< zTz+464_JPg9$iXFr;rE&v5k*;@NF5!x221(HGr=uN4U&~uPltO%*D4i@r?zq<%%qu zpS$e8zFOW(s(e2^ujL2kDR0%=**i?Xw@d#Cw>-^6P!1;vfAG~_oq4xKIkuE#ddA(rWmxn>=+ldQ90;nDh4hz z10O~mwt;T#f2e$7r(Sx@SbdW52zuWvQO6`kI<^FmVb+k-B}2GM6)1*0lh!gqGQcVx z(Q*HP1;Kr=Ag9VPOZBV#1#|L9Vqf7tnvP9e4g`@*%m)1`sjxLJZH|nH(&$TKVhFR)Pvk28fqG)riby{LXfx0T!k5iT=1 ztti7Ye?)mSFfwa4PwxAa$8)DKpF2ZOe+nGIvT-Ki-wR94vCKk>V^M9L5~k1Z1Ez!< zw3|=ZxCpk^XDv2NgcsPHA9!!IKhDhNFZ-5AcUr`|l5Qp)u_(iZN#3A0`bvz2aU~t- zb_co#L5JJ9le6TS$OLf3I)-&$^k;k`(HWxG`M1xT74lR|oxWG%jfh)m@9 zS4q|#9=S^LYGmcTT21CM$DQ0gqnBQ8$=7kVhAR!6q@69lm}u&(_(b2QiAvrh45k}H zSN{ECZGC^4e-T~R+IkC+)E4f()O#aw@#RSHh&=z|%k4&qwlG^gs!}h${1dYBYO6~% z=|yzoK?Ez|U4jNLqT7d@x_$!g8dw*b%;h6lJDK!WGP374bt<8b4Q_4Oq*9&IlS3=0 zbBz0XFjP_kmn@sKT#UxIpoUn-wrJ2}80YJ)xeS!G)QC)SrOm^*U*2kD4(G9NduJSylU8#f+w;6yt?L9fb=95G8TWf_v2jXk_-1GGMY&2oiP>5I4UYY~kaCkL!tFn^+v%p<{itC)-PG zHVF?>V(^`q>aA(Xcx@BPa6##e>pDAuDQT}c6YDkbO`S9Bd2{=&aF}MSZMf|fZ-uKQ z^CRXLiwGJYQs&8hp|>5WXu!jfvb<98$z$Un>qNvUxQpkZ`A;%nki_cJF2asP-~NoXyAf%!0J_N1PZb@aq%NwWlGWQ(YO++N7gf7 zpg#hLQkE|F=Q`c=L0HAI>t*GEnH@aaN!P$2rU&vI)JAzaHdB`DLVWX-A(`NdIQ3V5CRRWA>NcREX8}?=iA`o*^cd~sb8skvOb`&EAz{49 zeCGFlEJpRRc*jjlH;_W&%mWoKPw%5LHJL}eVng@}+-SqlWNz?Y!Oi%}Z5TT4vuSXT z+D1O{l%L;^Q)tToJwNy-Hc7%(%vAAPt2l^Z$tk&=*vsK`AszaJ3@vubW zmR>zhn0GHh?F=&}EclGBI2lT|e8FFBJ0xC|6n&R9PhmWTLX&1EgjQukb$(;o4IafE zPjnn(!Ab=zsQ2*!u)TTgf}*&(w}oEC3@Hi58>OX7nx~oJHZBzqW5rS0sL8Tna@=-A zW0f~M`T5n3P_*U45=MZ~x+)*s0ijpt}YtmUByf|A?o)+Wg8hr_- zt>NMpi)+ReIJFCAJ$B%upyf&^+6>A|>B&r^FaTUs?ogirYXosK<*ouR?Ki=KShd$t zWxknY$@`AA;>w?(0Jqif!XxT;^m`bI2Lx0ic%Jndt>19r6h)I&NfE1o4ZYh2x|&Uv zwOXg45?ZTc(zg-tw8)A)=Vv}4Nf|H#*Rgl2-H~Gsz^8?^2WnMmThtH^+oF!W-)k8J zTm6YsiVwAAz-|4CKdIB$=sOO#!O%*7b()tKIK!THhCGeY*|=SqE4(!*ZyB>*KExt2 zaZ`dO^Y6KAz%Ad|ZMcA>4L;)#Q(NTGFX}y)32q`T`{k=z-japPIx_{C@|w!REjpY) za9;k7o%%GcJH}WSeJ5FZC5qv^^R3yNh>P}T9LicNsQpmgysGfLgZTm-M(cgBIEnYS z9Pv{UuoM0C`#F>swhIUgU~lqa6Dr&rK->Nq5SO}$OG(CSPx9-d8nSx674Il#v0#o( zqZ-wo|9d{C$S^`@vBMArj@rp@+I)Ys>q^qk@WI$0+kwJZh2sU(>w1rbYp3HcRtvpy zf>F!De@-QHH1OFxyA!SsQ>Cx)58^JIA{7M+}V7hP^8=ZmRENCmXGCSUH;?19a%RCd&5j4z1b0v`g{aGNpLqN>;&&+cx-k^Vxqe%wQH{t>98CEkL0 zT4(`c>+aHf9|q|m%)!LS3s``q}S#v;Df(7D78kTkI`Vy2rVn z;zTu;J{RumgI80L)2K+@Q<~GCdv33v9*o-SAFoiD)wkFE7g-zl_L}osn-jIyxVP;a zwbwQnE_83Npam2B_cJwqza5V^3RmbVd>`3%$u7TmPiah$HUsfVC%XG`AbV zQ2+Y8O!3_PVXO~D8BTWs4QiZ4BLNO(l2$>T`hyA9Mrleb*x#?QIZZ4ko5pRJYcjXL zW#7Q|Y}Af;*xD0p24b9~(r*2$HR(HP&qZ1OCAi;xC`L zRE|`fDg1$}2hRaWr!K{JDRyR?I53lbpPFqbns zVJ0BkqlLDsbj9ry%cD?s$<>+$t1dV2*daUyE$oQ@)d@S}H+Hww(F@^WB)kvC6di{t z4H4h5=sLU|QH*^tLAum-pX+6a938LMTd1!N`bGM- zUF?*vjHtTNF6CA?`}aB2y)OBtamb*kdsddZ?WR2OjIg>bH{EGFzNXrKL_W%vTiZvo zuhH!7YR;?gX6F~_ThG7*6@C~OlAFX=tMsPrgy=8q+>fHioI~41+DG!WePfok-7#bi z=axyN?$ljCbxV5!8u^@CDqTA*?wFK*-!+cKuWYCOntcOb&0a)_>Kj!{2W{zl5%zEZ zZ!!c!hqr)hzwW!Q547e0Yg)D7tSRfxcoNxCZp5i&KK?M)2TmSI6-T9;E)+b}C+V4| zD%MH-f!~;`+2?Sp1FXbWN&^zqt6?t9C*acq4>GyS>W^rpog!QPlV*H5n+b+`l&I%Z zS^>4jf!T7t1HYxFXusHS$i0PHm{YRYxrXOjmt_nweG?_V0w=d1&M zsXlEThWJ_P+%tzdugR%S2{>z_&Np%Cde7ClqFZ$?@zj}PKd`UQSXiIOX9D^h#cRo7 z*I(kFQcI%>?UO^HhqqgD3C0wj)nEiJ%4j#LJDm!f?JSfu=kimsJ$B6<*6YqS$MOWS zxTe&3f4)d=L4^x5a(r_$n~w8XfdhA2CvXF<56`PY30-9}0v!{6xz66~E*9kIpP&wv z5)k7+{AvQdNk?Z4qVcD$rj;ES)Gy35liXji#5!$396`wR-N3GZ#@QXj_cR6*qQUGA zhBRnvymULoC1^GgmB|84_E48s z)w*-Oy2eMmurOOxZ#{|gFB@T2Tw!_4f{A${eUe}@=SlyNSAG=3p>h)7t{ z?n=r$`eA41ijOECbW$6;aHj~ws*m*Rr*&R0CuIgh<`ws%n9SYFY_!AIDe53@KEnM< zO{QzWB;T93K7&m3Uj9>EDMXZXVNxAC!~l(&7h7l;#ym?PCunFR$# z*neEksV_jqX*6fWittRBZNjqs)6?}=*kd2kw1w|#eo8u;&bY0-wi7IKv{H}5O_)^d z+mjW?uJAkT$fnQ2e7H_dy0+z{*ZBU<5-dEyC=1k<>Daiw1g%|hhk9?C{J4e4nuGJ} z1=J2yuPtoK4?FP=`8FulZrcbGx4NWit=T9iT%w&=$HhfEr!q_D+@<7}5h<=htM{p-rVMIW_i;`1=ddq6bt7RbWMQ3eF^Y0APZ z^4@8!CF+PBy_IQ?jlA}NkWlbOoIF8S0uFAz#2^|SEvXblY|T}6={32uFFHqAtWmJ+ z!?vvWOD+Rs)Bv6irw?_vWNeqM-tWH7qh4_w8Jxw(@$BGj*}_8WS@9G>-hPIY-L!TJ z0>)vpNtJCPmChV}GaWSzZ_vF#7$ZY)Er~Ew1<*XpQz#N&8n;EHOvY#vnR$Yt>>${D zc_yptBpHLR*ojl*ctbK~qw^6^p{Jpy^{>VsmFMo2UtnFrDY@7 z#d`q*c*A@se3eh=uq~}=Eji*W?leJhwgdZ~#a+cN)hj^Nt~vb1%RB!j;z73`UDI+7 zxPP3%rbC2ob6Wkblkol0bKnh_5s$tLw2VS)@XZA=-Q|GS5_VMDqo#ufB&dg?b}pcz zcBV4x$pq&oOy&%-DOf&TpZr*zVZ1*ZufrddSTlQme3eymUVHA^1A7hvn=Wr*dtS;O z5cnTzp19T$^6hyya{~6<6jv`r?fLMYlskt#FBuoG=TnVAdoBdkY^bGiZtS_L!L#QJ>vm($;BwtSIAUT-v78q3HrNb%rKPxXL0>q+w_mnA z-Q1R>n&OiSg83-ZyjNvx++w;i`a*#X}c z!}6E3Xu2GV{@r-lIu!iqW&_LDa_i8^>+04a;U`bL$MluPmwq6@ia(gYpo{=+$BX7m z20^^(qi>RS%ROo>C)CX=Zj2IZ{qaU36>&H#>GkCl zD+aOHOW}@A?Qh4J8bky%D5xRR?PNYk)S7~aFdvY$*iY25uAgXJmul7K=$l*6(R`hp z)tbAmv+k>{+@|88Y~6Ic3cY6q8QDsPvJK`v#u2FG8;vJ|9{W!<`+dz8O(7b(Lt}Sh zm7S-N4|0U!v5j$nC?f~kDRZ;OcFKJ9SIf5Se@dA}<*Znok8JklifHgW_O7Gs_@0Y0 zx|_>CO$tUCFrcjlLRFdB50J%8t(n*^xGZ8OSQ_qnnSG80R9t1gnh6y@ZxQ@jc@;0N zG6R)Y%(!NK9WyV#XhO*E&u?)x3Eu{6{1r25%`sQf0Mcv7n3Nna(Y_KSES}tPG-W}x zPK+97$b@lN`hs>_McN z;{?!Nu*@|R9txe<@uzI(Xu(R6bzdmN?*2k)`Y?L!aeGg2+&nyuh})^bDpN%vU$D>H zUu%`Cu5sf4t6a}Dy7u1-h}M_y&ze;EQb>m7yG%M$5&4QlACPZhRKC0C1m!EW)FSe| zetAH?(tvz5*M#JIqW@~icN5Y2@?EY;l`oBCSibwvTZ`Dw4p3>n5mE7;zb`1B%sA#S z-#3>9#9MHKZ@yEn4vF_(zbx^(i9b&guP@&mO(wFPY#VT?L6TmECR=Qn?PZ4q+58YK zi|b{Jwackq_NNG&ZguL0-*v+Lc5>4dSc&jE9uo%vekEe&Q0vXIX-jLxl>vS?4D67c=$I88?E^TypF`=s;m+u0J}_kn=hkJm1*@_W{r^7|wS z`25b%q}J~VNQUk6qV%F7@v9j2!Dx*!x+sG=)u?^`byv_n`TVggDDPGQc`vylWS?() zuc3Wj{7XpQr!=YZPD3&*ukBw&_G1y&oxVm==NnmA~KH zP5b-3k%LbrE#@`6&76h~jYm*YW;@y!q}YRxb6!`x5m+5lFvaQ!iMW~dcXYS2=B^lo z^5APCzV?kn`0k(V;#b6KgeG-s_^?jC)sR!k^mGTjNo#7+#&EFG!{o}`O{IM>e zr5^A%1^B}9UF);fgYpMS`G|a1`bfH$FX*J-{Xl*q_^`7q8h#@2!VO)W6>lHKnZ9te z>rwrR#)~_sakqZuj5k4lUafc$ivK|^tbaPKoXtM`qpSKIir4pl?{V!Lpa=FO|5^Uf zt^7gGx0Uh{`(5cH@!5&kFVJ3fR@__{G2!aS|J~Y>{|;Z1+4ka$k=>7$?5z0HVLABC z=`TTjMg48}06i=S%lFJ%+FAx&xPD1&jMK8g)^%(`1iLW-*e3BhTi+odDK5|L`k7Ho#=i5y5_wncU;hH>j2Xi;jHScZk zNP2JMYd)y&Zy-te2(X={_YAYTp|@pDdKaUl$afym4HtS1cqF~imfo>M-a6oD0zCR( zk#8ikx}i4*Hje6ZYm^jvn-Se}K-au?IUY$bX6gOhPw4%CA9Jv!*TJlA=W1ENIq5CytMp#OgE?2|W$;LP^YFbNwhu^x-h%}A{wtC1 zHfBZX&Bul?|27u?ej;bsQj?yFlCrw67spGuYGW}PL=K3#=wG<~ zLvz&;Ir~@H?K|h&O~M`hXN}jvF~`+rlHbwDRo>~cJl2PH@bPE$Z;nKJj7c(?!C5t| zZ2@*Y#J;+V8H1dbKn7K2rZs^NXhPUZnTbf&V7??S8@2$?B?-G~_)RJM4|b6*d0bFm zzl45yn0`Ti`tz{;?pn4PydnuJ){rwJ91+ecRVALtprKHcNHqBsYCIrnc(T*-Yb zBj!`+M}^atsg=u-FavgHt<9KSAG1?q4GJXWT_CfY1u9pdCIy1i!|N2&GYZu4C^6YR z<_d?~uza}d748{@n~awVhnqHhxHA>*dWGYfzQUCR;YKRlnF_~sc!h(0e0+lyZluD^ z^x?2?--r8QS1asbg`4feVZiLey{d4VDjb#_UAhb^o?h-&xUUdvsF(R39PjE;okoE6QJ@nQ2%aiHuEz+_#tOviLVZsaIM-tY z?#p48_l*_KdaA&=9wTr|74FMNEWT;bg3@(8M&NEyxD^UVODUY|F#>nC!Yx#|nI4?= z7y%lkKsOKwev+n{C2+~KDx5^U2p*zdBu~J zuIUA6e+3$=K=2O%a!oHlTPVkC4pRG>! z*58NW1!8Y6&>XfcX=Ti&3UQS})O!$v6ym$nEUwcPqQQgsJYjLXsSt-KM3V>cghD*5 z5IYEjwXp{@S%Ic0P^khq0~P2=1)`cgaw`;Q zkOD1LpqW9S*A(c-Q?2l`6liu3XiMekO$EA8foMOHdkeqjQF*BmW377*$;0kJb~0&;)6%E2JlLGr`T3~-Sh3to{xYUSCHmNkFXu)TMkh;yT6}B$V`3OgDnnU|p zL;Ua(L}h?mX=9ZDx9i^6zYM;DB=F2Fy@91ynPtg`Pu94(3*rZ??skDEUwQJ=hP#+l zWapRiLslkC!BpF-f(tMe;w|p~Av&7 zmWIignNVrjypVff{ma6OD<{nhC)%1(0#IO2c=uK6`zIz_5&F{)lIAi_py8AfpRd%S zfBBOZRr`eSpK5_RD`xGF*)!w~J5C)rTRW|%R7yAU=g=&^z8 zZ4s7Fm>my9dAx?qHujnG{MSj-d7Z_W;_C@y{1|zSz$2j(l@4)J_X|0uVK}gb!-sAE zqDoQz*HQlC#Jz5h_&)hns1lvINJUxVxI1Gvw!@8~0ObKT8&=SdKbzl3m%qo~S>YGe>qytCqN z+==2B&oRHK{1gR|^a>)WHJ==#LWj@JF<$27w@9(4iM>}4JLFIM@|#Vb3fdzJ zUy$GXZ7HjPhrSGGAZ@OecginSwA%b0rr0H(9TCJ{Yj&4+Xv5$?_CS93{#X>m*DH+g zljA6Vx%J0$IHjB0ADg@|_ApB@?%_yg!6EQK9C)nrqS!{)+SwS?ECxv%011p+m;hKy zLRDtrQ5FZHANzKIa`?hAO{+RP-7jv|n(aQv_EQH9cbz$Inf>fu9Q-xemNHj;mqW02 zurR%{4az5D(wBYjXbP%{FONk6c$4ANo_NhN<`sPnmkS*R)vIHZX5mdlhUQg|d5%a~ ze2^zpfK3F*AbXY5@};;7ya6pon%O{~mEE9B1+uwpQp=O2C?dz0w7FPEdb>Ug>pV%8 z9@o3=AClvzF2NQZ4cl9$+?um``n&1(#IO3h{}?N89{pXZuk-3};*BW(?OQ<={=APx z`M;>bv9{kpsWLSCbB|NR9;;CLXX0>mn2f`C;=yG)o|sgbLXOv?Z!H-#kRd`xSFI)M zX;K`uwWLUsGr@$@TJrTq7I?Nz{!5c{ZSoCG&bP@wX|l~G>oiGufa)Mkw%g>MntaP9 zhiG!SO>U-1W0M0k`GrmXESHr`s$6B0pKCHk(OXNFYqHcPU)Cf}$MNmcnjC184{9=Q zle095rC9=Bt;wWKUZTk=CYMaAoB;a#$LGoR5v5o?5JjX{hEXVoYOQ)S5pgA9+BxTq znw0Hbdgu}`NwZR6YD1NnRWGHXYieZ!xah37a<`)JfNYYypd-=n9Pd5#1_p2w7tu4^ z`sVwKL!g$D=65`Qp|Ame`LY54zQ25^$r+Tpwd7Sz&a}xznpA&zSd;26f77J;%kG+N zvt<*SRDao6lj<*hHL3pc!v>bC<+kifO{%{v)1>;#i<(q_c}kP&FLO1i{&I^Z)nBgE zr20#dWH7WkG#vbt`XT!Vkh~FzrS00gu=N+^zMw~lAHyT=eN8-qwznSP z`-Av)OUar4@ce;tuQb_a%MR3} z`a>U0sy}>>9Vc{4^@k2kF1Kah)uj5v3z}4acv6$<4|6mbqn27rW@=LX;R;QvKU}0q z@rO0G`13n;Z}Eh}bY<~cu-sX3+0cl^`GOXw@fZUhdqpfx>(b&`{Yl)`As*aPa_*a+ z-9=yxBWC0got^m93R}rmorq3suh%3Hso4$y@|NLi?}y||K?+#LSBQgg`Z&&O2qd~3 z?$?#9Uv^epxHEciHgXc~t=FZ;Ws!7NoX)q>kpoY{8^rqPWf6|%k{zx`@@$SN%q;bv&y?WjjARzmxU(srJ)ROYzt|q7kFGU)EF0L3 zJx)=rN9}QjUWpg2A2WWHZI6!+Sp$1uHgcL*kte~GB_;J^w|7^=9vh#P&mKDonl9}T zh*CHy!Ho`_T6kkapyc#;)_E59{WsmbzhtwUbXEg$azCzz5PLLo|NoK8;w05%h+VvV>HPD?kpl@c^XP`GZ!ysejv zJN7P@R(5AV0g-vQ$&ADcsJspu86+2u$qHH>f#;LsBQc0YWhDA{7gIi~ z91KO-_P9e)@seErQGY85vmU)TlB#XDZga(^Is7A39`bSiZiZhfWjZ*I4`$kklTnUSd0@(F2SfUh`s9=2$t*<>X2BHLgI!Ygr zQ#%n8pd5PKTUhlFlh}y{dSZ%_HJA&=Qy&Bmx~X{ku!F5pV86h7tQCpJXUzKg9-VLX zU_f$>>3i9c#nzEaJ5l}sru36$oqB5-?7_GGfLHjw$3e_xS1c&0m^X*MgZQ?{iLbs3 zd?9-eB-OZd`-=5eGl&ENedP96xwz;vcq#_HkRSeGAi zw69fW2EEvgnk|Ewpv9$fhUyPofl;KfdpvW*#d-dvw6o%;E!+_-`U8pBxJ?-a z9J2)%m|H6@-2r23$yoXv`opM-V!pU+7H!=oJn7N(gp`C&y%n4_MbrBj&B62Xe?2CaAka^pL z@SFzSy1)#IOW)th0l5=qjG%^ijFR)wJ;_3o0vdRS&tE-38_Wx}cKiS>iK?WfgTBXX zY4sgv|0}NgPMZ7BKY+eB+TQ9rX^#3k%99rG*IWF7{UHvC)SB4{0j@e~S_ho8q!`!5 zmaO+|v1~lfm^hJ$DtK~Cuhl~ZW?yzC`fDA9L)(BK-ka55Rs2pEe;jMAt2shp7ClSWK0*Z(J z_w@LmnaQu8Q+j|uGa{s)DHeZU4}Yo){5aT^pZ|L89nn)?{cT}v8uSf!uQH3@<3ca* zMI1%P)xCIMc)9y91~~L7EcWOTLfjvIlUL>FS}(U0#_W}zg7d={b@8}k1NNIGpj4UX zHFuCYcql%r|7Fm21pIpn38oU5E2tRl=;fJjzd|v!^ zeBj4daGx@}ewZhuM&q%C|7i8#D?aWLe&R~QZbeJUxTm~XE_jlOU)h)jjUGbahlHdRA zo|E6Bx5%O2BcAs4d(Obz`n`32E`D>(_wr8nnY0Q*T?0)dQxI3~ZlgjyBP1dzJ>QD%4g(7|bKAYF8LI9tFuy^RwSo1IY}PSiY=f(`FDM0i^933k+N+5)WFp~3?l^?Q zBL{B^hA|>nGWMcF6FBV*+DUVUOxAIdW6Fry6%&BaVm)xK<2>xa9(6zwn@`Sg9Uxqks39s=2}TuQ(8z1nvzXY~EPDsg0tqGvGc+4!sOb}6 zG%06VrNoI-M*Kl^QZ_GjROS1UH+jTSe%4c=EGi0PdAq>iFK*}}1@bds+}v0yZBU90 z(Fp$8k99iDO|SSJEt_A@{tFeP{ZmxL33$o$Wd62lkGXnJ8qwW6LR$}I4g-W7MFzoX zfX6}xLi@K*(&gHzXmOQCEEoI3fY@){xPvEC|6*4rSI@8l_1s!9tPM5WYMRuF+rZ4k zMKeJo>wJ22y(qW8ez;52U$@xQ^Vi$Bu+U0UXYb+r>y0D3@YmGesi>g8Zo#4upeO$N z^gmGAb(Mg>UUP3ge{HL>*w>=J?hs@mkH2mkfJw@ERAKvn#9u!x>E2(LUfKhHoz7i^ zA%DGyIrLYafTe5t{`%(QVSnx2OSo7qe|_wUkiU-N8s#tY*P%nB{(9O#&tJcK90bH) zm+nSU5lj(ZKXtnB*J(>oB;>EBv1kwdb-g#KwrG6)_^y2ZTCUjFqQ9Px2@#KpJpQ^M z0F#vSvHx@ax_43c{<`tSJ@D6$xQGz)*H@TBf8`038Pnk>4XS)CBp%^gqXO?!Y+$ zSNh4PU?^(E&z?V*^)I&UUSDiE@%Gv=8Y}kOz>5`2(3!Ha;zk}I*I2Qz+7%;5Uu-mb zyk4v@q_LlzJ%=sXymiwT`}h@<24Hqy?9spFh!ukUBUy!zLg<67mA=@aK_>FViemyW z$@!Q|v23qS+XgAkhMfiMTxEB&fKIwu{B>G&-^g|op#vnOdlg=fbN6PrFINTkYul?a z%wtI#C?*I__E+v?4g-Z{-%d;0j@Ypedx2s(;LVPf&;3z5(Ua~`KkS)@Lvi9fOj9FH z&;@0QAy@sJzp?9Betf9BTk7Yuqw^KVEZk7@ulLHtx@7?9h>^eitpb8XhnV zhl~;bs>>cJa9XEVP9W2mN1X7nsDd@M&U}U1i7#%oYL^>WFrM)4_9@v2>;SGIYeiQ# z=eGf)dHGZk$m+p7I36l%nXXEiTFHfK^Vtd!UstGO2othwTW7{StS@E~)3x#f8kIt4 z-v@RlE~2X4RVIJNMb?9Yb1qfR2|BEblYa<-4!gO12ymPL9mr=CF$d44qU{*!f%?|_5_t6 zq?IxIs6VFft(oZ-`pro=LB=$<_oTcy?QFVQOUZZlqNWX7B)tI1nN*qKZ*yCm8!+K0 zKq-#$2nWib@%{_3m>g7rX)Nf%(qXTdx1PgmIwvbZ8|}Op(?4bNpm?knE)-?!@#M|7 zER0RYkS~kg2pbAPa=DVMR+6V!l7ooaC5ekQK;w@}6}>&By8K_3>L8aYUdf^8JGKz2 z6C+gFNK()*lRBbqKmuh^n-xD|HN3>rwr^Qr*g>idRB)XG`^>2vr6KQf1vF zwK}Cb!BQO@qDo*hUa8i3RA0k{95p=HrHWUi+F7wDIwMdqZj1+V-=rhvbMQ^FHcVyA z2Sf_;eJoj-r|?uPBlXUufPi7s9@$nU%P(4rVe!h6eFl9Hkew7ETcTv+D(%Ti_DV<% z9d8|ymcVARlCAg1Zlh$kcFE$EC0k?3UK}CYOUWjc>{mnzX*acGw+WFYuvu5hBF3ph zJ@rqkWPasSq{S;s_Vr#O?IjViy_IZA$)2ZVdH+77#mVk~9~0PIhqjGMZt}?1DcN7S zWbw+9J=T)FG(xsi$<`{_5+%EXC7TG5C9oMRWKk7a(iQkKObgSi5|*-U+Fjxwl!(le zBgFd<1sp4_#3vIqR6Q3=K;~^j#0hNvwuCb4l#fe%q7vWMB|ceLvyzXt#HU1v_f_Hz zO1z&EA7+Vf7a~qzv#k=JCg|1FuVIfMCEv~^Zr@tspP>H$-^SV~0J_jO&Sz47RFS z?Txe-GgzdhWa%AVdoe+Bu*-_S(e~m}*@TrbbD(K35+_{?D(%JWXHg1bro7oM~PHg!r?+6(cgUP?CFUTkK`j*F1(rDT)x zCS#r#7f*Z0MBuhuPt%ut85E2WJ~iB3EA5*WN()pvL)Ig zWT$VGkUawecp*DiytW!I+4xp3UZ*KD8w&pV@w!cHnK5tTf!6O_DAteDz7|yC^(s9B zkz)QdpfD%CZB-wM*L$xoHn(o&)AQqXyOJEOB&S-E60cp7HeO$(R3%;u)#d-NR3q_v z$$CQdP)Zk&(vR26mFj222|+isR3%<}R0{!Q`YTn5*FyC=rJ5bDJ1g3M6tdXRoJF>X zzB~&DDCWK?1gr9j0WX0sAOR_OJ+9#$b|-$(xLcDcOW2D*@Xh zOJK8wl1+P-ef}-0<7mMC@CT80RD`S>u*;R~6e5L^?+1Cr5((HISqa#8aIpflO#-$^ zJ4VSy1NK-;_V5T6>H?sog?BK^N@j8+~%sdLe+Rd$%EW3ndj;gWmC4{*#b0&SyZWj)A zbD3@G)(^zYxw!$EZFYLu#!QC!95GX(e$3IL^9aK_xoA&6NyM16xIuF_W>%%U02gj#YLDWp?^Z zF8bKYX|1jttj8OW;L=3#NMJ#>YzV-3Ksr&YIBH0GqXurN?7yDXCH8^I-o^+>gIV#C z1xZ;ICUm3#7tyK0M5Fx{upK50qW%VX`gwzQ zKOoamYR-O_TsB_lk<7?@aQG5FML-U5{z7)?lEd3(g2PL84siHa7knT--Jl%8Jhawn zDa#=oz~>N&jG2D7o6)CcAXjnzQaj* z2r8l{OhK*8$LeO8gn_p3Mt2aW%>%gU4<6xd&IQ8>v-eGaL2jCpb`Hmh9n-;xzh^h( z4|Xb`WbR)*e_%Zu!GWf>{dR}4gq}07DMPm6lwEuapr~y|8nb7|&(foclWv^q264YQ z!02{n*oT~~E@!c6OczZ74zUf2>EXjSnhT2LYyNcsnt+v?(w zzEnTlJX11@AaEzlA5WF*Kg`o*77j|+qr3IQ2WCaXdRy$&bu1Lx7}ifP*u=eu_$UkE znuv|D!!!yuH#+ODHzP9L+YB?4z`akRzAEZt&@&K7-W+rWr+MgZzs73xErQR`#409Q zG$FEHpb2WpoURE`XX9-aK^@;5i8s?|1RBnPS2pn%&hXxqkYnN^+FiXe;16=uoi|NB z7{sZAltjYG2ndWm{zSV@;Rni3T&ez)zP2a+bZfsn{J zc24L2lV!LGV+OelN^TfjV!5`O2zpb&O&faEHnw&w)v-i}Br*t+&V8khIUm0kBb(!rGUi-rW z?Hfw858qFMQo|6eW1ihildPKz{6*DdSQ&F*sHg6eM-THs5AiNU z51H%H!?01cT@aMdD9T5CIkbu%pSDlANRQX9>(in3Vr;C%VuGS`{9@eq+eR1T_N|DH z{y;=lX`bCJC^C2ma{+8Ra8*YP%OjEbg`$N}yuhnUyP_(rs`6lj_rw$km^VG^aiSoB zAsYW89Ks8)*3D@(O0uCXmG38|nA0U+KoWnLB6pey6d0RxoZ~sVYAH9Uf&bC2xlvm$ zD#r9%8JWsmG1DTMDAGYtF4_yPLom^+0aJaSM7n0li~DX(=8bx%4jkY3RL} z&pA@Ed1jmuG z<(q&OKn)&~h^b=i&Ir+K@Gh0wXI*5AhGnX@M#ivNSY_m>_X!fHbkuG<88V44R9l;d z{@TEcCn7Gz(jy_jDO*dR#FDhwlsfneO{?!uCjh>>_y;*2pU_H+i5ww~yBv@TQ{==y zt)Eg~3((WJFX$OvIJv)`8=fTek(zCD-J+EI{l%DW-jv&_#7Sw~Z zx^68P0JJ2o)#73%>C)|+Q=Eie_-*jc1oX*9WWElT% zio*}+ukfH=O%@oP1d~wLQLjL7{(WHYDX`{tFW#@g`0E(~9zB~1n?e>80y9GX1*g;c z=uKL6yr)Ghqm6{1>lo}CMz;5pX=rCEV*;u^h(yi&F-t=|ejUc7=qrwtS#9!{O3+Gb zZIAQK+*yrRw{ls7_K&H(DeSzRIaw9C$K2kloL+LCYT+p*i4Z3v*J5g8H5k}-G6lCE_t{9#YFA`>bl7e27kQ} zL@pVW%jINEo!%&^rhE;wJS7HWNh$l3@G+$pM6^c0@6pkH@{~@t=%Z)6LK9(aaf>Zv zPJMM7mJPrPag%0XmafhtD7_+tn-{5w+G`=#TE^m>-iug(1|{45TD(`4c&``V?#*+& znQ5|~WFrAJV_R?sd`<{r=G1_|-%P{j$*vd=n}K%JTn|c_tpzr*Tn7-1Eg235aOERO zkO_u(yB#Jmlgn@R zESHh9zCQ20Ci#aUBwwF*kpZ=anjxBc`C8@Z^D+CD(C3Ov~PBC-ZzfAZgH)mF-&Ay679$GEsn?^q7@xfIXN! z+o2kemy?te1+rOsLYt|`wCCMeSwcTl?qKC_4?e zFKH95otFbIFruC*e>;iOBEtWu&Uj`4mBpw=$+3u$#!|w68dbnt-zl1=kDjOdVnj_BjW}i|gx1x(g!uJUX`SquLR`sz-2nlbRgwXxGNu)?zhe=Eqd^qe_!Q6%tQBe^a#;-x@cYiEt=4MhIK1p#;mHvH_&K>kU8$sseUGji7X$YdGk9e2!=#iv5=wmC3&OjpjUfWU?EtZ zrb?(*tVm>n0__0VUrrc0RrxhI5MY2?JhYi z%hc1W5>>(NJd?z_JL~<)BxMTD=IR)elO)lTnX{IVW5!k}Ks{%uF@J_Cm!;y>ZI_@j zM=$s?b3u!it;XMG=00O~VxI~1CIFJ6`Bk^OW;C2TS16+<8LyrxBR)6X25~w{oHR*;p%u@jUiuc691&G9 zCkYa&_DH~LweMMZ@2wqAXHrZb;3yK{H}uoYf9xpHLh&a_l0+E5-UK%glnE-RAN)cL zw!IjAfIRUjjq1fcdIe@d!l%IW;itu?`sZcf6MT9+u)b)bHmoH|)Mgl-RP*Zvt`{4i zI0*spLA37q&4Xx`*}`HlJH*&}kFhA@#OF2Lp*Ows8Jq8aob}OS_;0QGvI_Y7M_bUf zU=4LJ!783fl=W5`1#|)GN>v{{B#%dFtE9fW`aj~=1-kQs{Gz7?|7-n`6Gqwudq`S) zkZ2cGpw+%{q)D$x4l$59OoTeHU(01?)9Hx+Bu}u*Xap-B0BmVi3wIN?B=xW>cGm2A z@gn*Y;C~}~aQP3&O(TH^{IbXYE&iAM-*4AqhL_L!GW7MEo%3|b+8bn7uQo(EPX_83 zM;oHiv4E7}*ctjXE%QP#W~9{N$dDoiXt<(F@4Ju5OX)ldv=F>nsjB6!0j^;IN(gTt ze#zZe_GdB$WFdH2RKzm@Cql|G4R5B-5ghpeLc#DRKQrZhm9P-J7U{TCw>qkxPM=gC zYqEtWuVV1l6@KMsr_|$K_Je1mg5`p5z=isY00O?G>0!T|+pz3KKCdabe`vD6is1JY zL-Fq8QV+K(_k2O39lR(S{I7()c;njyW0YbtV#2cR-Apgr#KE%jU(31Oe1^qSs zp^CW4-aMX2L~`;$NZ$Mh3smkN;RsC?%#Fyjj$sB1=b?`Dwn>! zzUxM1^Zjj#ILb&35~%ju9HK*7D29FuUxpEQo~(M|GX&`vSU`CvD>tZ9|4Cb>PP9tE zSpdEh@BZr()P}*Q)c~`-9b>XVRhlGF@Sd_hs7fZ$9LBh*H3oM-O)Q*Hw_@Jg^*-As z)R3=a@rFOoJ6W8PyyU9=g7^>nxR^M)NUG1UTF-6d;|ppzC_H$RnRtwqer2dgs7iT1k0`8Mo&6Q`;2H^!w%j_4JAE zTMII)Xv#%5e3}a;(6{QDo$6QH81|sxzfa+e<`Z~?EjM-}b|+#N$<|$mxhcblu1%!? zZ&T12B7VT7oU9s~F~uwdwvY#Z)tlFA8u4ICgnUCEm-&4NfzWTgL7BkOB}Dm zo{9J9bL4BDlBgfvp8T0+?#wilf_l9bMy#p)S+5?Y32Ey~s3#I!!xB~+l^hO4Y9~8y zB*UV}>G^=bfu6r14^Ar~!K}I)UQ=>;EVzEM*22Uy$&JLS6X-HIB_-$}GiCvJ>hJ8g z%GG2}`RNF!%U|4I+*A-Z8Lb_l2kRfQ+S+aov*ST$x zRtH|I13Rw^=x2-|AHIP;%Ro%^Wnpqb7@IMNlffnO`LuJi zn%C>SbuI{qXoZQgauFYM7}+HzkR+~+=9W#;BI;0oBNjF2F2&56iDw4yL5ET6hevt3 zAVtF@ebUojD9qOnV4hXC{>)Q|IamYHPil&?)TY?p>~+Q>NAr{ zU8B#sn+k>ey>JJB>&0E0-Nl63i92ReW^9W+6ll-!Gs&?@Q=+cn>4tDSJzHvEIm)3@}&vRC0|Z&nn-|ACo)pXKBKXrF@kZ{Wjn@W1!zGVmW_@ZYi!{_w0>wY%VN zC(K4kX;v}OBv$!A9;6aY4@^3{)h8U{bu`v04!`d!ZZT*N|9YwndQQ4P%j3|K3&=z0 z`Qs1lvm`wO%BDwF81E~@iV0j_) z1jjrHmgdFW2Zsg!?%J!L1NVZ>ZT|pWLpIk0`mXu?)0RK=)rWV$*$%JMJ8@)e9YYDM zX&q9B+xb~_#xp*5xzeovbH-Qp!Fwp!0MHLAa+nF0=d%vsj+v!skY(}PZ^YineGsRa z9&V2ym-{M7`J}8jWWpLfe#JdnNKwK0Rgp*6sDy&O*?>Si_jY&)FNlvBKh5}!Pd$|y zzk{(gL-NQXZe3+bGFIx7V5&@yzyrpVdj!Ba{FnQAvZ{>n+3f8k88m`V`md2)()1(@kq)ndn3VBBAy)ATUdiAioIz zQv`?UB{wAdV{N1kMCQ68{6(upVI3!v^jeria{ct?CxIJhNkqLtwJoN^;XqggyKK+u z^=eSRhruuHxqwXcy+<{PX>~o$JhE@QYE=+FltTz^q=A~kHgsC|Q;=B(knN{EVefje^%YuA;nS&V1v`6*)@{Nrj(i)Rreq=Z0C+C;z zQHb6SlPh>goP(fuMg}~en_9Ay$Qx>eB!4~tr(tHjw8~eP6*h8rImKJ_ z)E^GmkzBx>uPVN!9{(g+iRcoVj4|VydoUSCUT7N)iyvTmU)UAi{WM?Fh5{yY(MC4V zKq02%lQ@!1tIkB9tYZfQWaNTJ&}G}}j#vZhKgW+iZA?|B4DbJg27y^UzL+7nsE^k4 z!y7O36MlsK4;TKgK3BJD6x%`v-Uxo$!`7i>qQ}gl4QAEGuQSg(H?qn5T0O~m3`2$B z*9&~J-*DZdZ0>h5NcBS}!iE&VZ}K+bNfe?xF)vaKkVO1df%j*vkV&I5f8nZHR z*F!!D*FK_2IQ%t(gUT!oU>_JCjekyhpeh~^B*?NYvz*6g)%I_!v|rBy_Ml3gf*c3a ze0$aNH4o@r_oF^hPCHE9DhR|HC;s~W&e!aTpJD(4I_8A-f^&3YF*1!L>(wm}3R-P- z%$0>cE;xXCNdR}8f;0&rLJ~KZ=`TaPS~C$wt)yP*qQ;?`P9>RQ89XeP_e`0jnze2M zU?SBgo5D9vQKf;I3HH*I8Z4Pu-%XBN6Zaisy}By-r$V;wu!pxfe{9Z2E3PAYXxPNn z;O6QJu}o0`TY2VTJv7ACL>|~|w~vRKFK>@u%vRWsR`R_C92mRWz)8wG9CoRSbWVRm z*9B_kh+`i)V}aCkG#c6d6kYWm*eoX=ae{5&FtX%$Z;pYiW}l{n&@@QXG)QO~1e)Z% zA#z)7#!WQfG(&JQDEDI~gsC%uK~`M=()3#mZFoOJ0Bu2@;8-Mi@YF+=2j{+`c_6X` zjK!#+THYlrHQt5x!b=`j$T?=NkrXZ*fJEoj;>UFG36u3mN~e0hIWhzlK0>jxG{k_naWQe z1XW1hk0h&t;nVbS&dkWVI&~bDtf;n9JJ-RQv)It0@KlRV=uBbbU$`)cdfJl`V#eN4 z$*T2Wb3u|!O7G-A(6xlkw(95H^#Gu$(ji%N!s{-J%w_IsY_brLPA zCQv+8VhF{hoPiM#*JXh4AT@7e!XpG& zYhv_i(yiL;ZGLRH!;^Wv8hID$0RR}me<(+B5Ih@EA-L1X_?5=2>g=tRDn(!^^#FBU z@DPkLG9+cyuVj1@L?z_Xa}l6TZbOJ8p)*=;R`zt;C*a5cwaH2}sofPTYt|T#05pc6 zs`w~PAK;qG+%1U_(f#{rxeZe9Ag-Uvdd(Ofn`HiF-%~)QKeXI>r&d1jH2M`kSOh2! zyr`v>R39(V%rs-HGaM{*;}uO~E}FzE`_gFw!OF5tykLcs>L$p!0w~w3KR-vHV9JzX z&6@eG9b?$<5ic8@bwua(ws4_8>Hz=9Ga4@^3J#3_qAQRrx~rZ!qGfVQyMqZWEfr(+ z76=iRc!C^MLPfvb%->sjFQ^Wi%z@7}sGz&jtm!`Y&S zy<((_=!6sI*XVGW(PjiJ(~AR`pay8AanXN3Y@S-)>Y-oGdh!TN zslsG?SXlZ^*XM@o=Eom->q!gtDgi&-o>C{z5C;g)BmoB6MT2mCf2wXaoUXfFYdW~V z$XHg}vucu`>+^<0y7FrM6y_nW!k+`kIcBJs_cA`Tx?-yOr)>9ddO4(5b=wWdBL~xr zrKQ`fdjE9f;p(I86Pf~OLsmVj3&!x~`d<&iP46i$L!9t4IOABgJzka>Hft{qb(v@l zzpR$2>V>Cda?$5*rPca|)wpGAp=_!yY@L_P)35OP(LDZ#6mpgM zg^5#cs=9S;Vpxh_4>lL!*R{Gp_;p$u{5nEjjnn*Efeepdi#HH{y#%J_qSxIL;7P>g z$v5i;$^!g)3uPtvCH7Ew{-`?eLXkc9-brec>V}otrFip3pMN|_M!6(8wE6Riq4?sZ zZu#u7bP1^PQFT*vmszXj@l8}8GsAL{@4a9yVySex-od{gzz=x+S9 z73}x$;urREwfgd5fFd$h^eiiHo00k5)Gfg|gtG-a86>r{8gvEZiR>EOyG7 zNo=WBW{zGbROOg(Er%Ye(mR;8`@$;FALTS7Yx!hQd(GgR}5EOI{Gb7HiPWza;BB!~j0sM*m;PSymUMOL0DKRA6_>kR1Mp6htv&WwFRbH1FH=S zR6~_vta6*F(#R>G%y8S_g;_FER%gu=7}e8McxVUXE&P&uAB|=@nNTb5p(ffHd~{nG z#0t9;9F)H-1e65ujyp7_}Ycj>Wv}VKK(|{kmZioq7L|i zbu->HMz_?YOx64KGT)vs9(TFN?5a-1;$Te^KPR1X563kdU1q{oCh3E&e0MC+Wg`G^5=66RL;g+^f*g( zBd8v`u24<<8llj&#PQg{hWpHE;(4a58YK=Dr~^|uv$2m z`hpHq6X#})+nD@@In&mds^mf#;mpIK2$2EumYTc?0|hphlA<_dXuz;2SD zi999)Fx!8vS&6J#2^F*IWCwD>d^|1)(P<8ph56WvAXcdIUF7s|8d@uGZY|!{+IJ*>IUx!qNt)-j5?AZJ^&~h zZM{8K=xKCVzpSC=qGw;JCt+Mf`sin#c=_HC#7ouFM=NvDhU;}r7KwJS2tPX24mj@4 zzI{_F-rs-l!m>z^9gZ#8;TQ5YD&Xe#_%`@%ptLebEv~(q=Ke}~fDtwSTfKb^sixQO`v}e`Mpj|LAIdTUetIqkjmlkqUF8W6T3`#j4U+o!l z(f{D#XOBOJ+EqP?|DMn`;!EV2YrVP#n_F6B-4D70@>u zpVw>QL%R2yK!#F)4r-Oak4|-^De>iQ70#D`A^q(G%f%-eMOzVK09ICg_iB0i`(wd# zFFfe^@{#2K*z_g7pLRb*?c;#5;`{2 zbUa%rbetmtt{ImT(Xn*f?HBlce`BCg7(W;Ogf9Z5h>7{}ue&qQ z#9<<7LOg%lBcZA|SF#$}jL>{K$en^|2nVftxQ-U6ZZ7(gb)d~m z>h}Y{DSxkn2!-+Y$CsV)OrGXWrfLU#Bhf4O>xp$Ff0V090;wiJM4(Cm@Dozm>+HV23 zY=?{2i<98-UlU&CNbROVt%}O~eQ`=ujZD?y>#e%Ixa~LWVf)wp*>yg@eR((Idr|rJ za&7U(rR25Dk@rAyy_{fP!As-3 zpUnQu!MYlNoJxx+41`qLOOb?;) z?vL4!^fZ&aMlcz1AmR|5NB{k_|MTo2ptUaGz0|GqBd|GrzhDDpzTZ4VhG)G@MwkyRhz{Y6cEzz)zr9-D+qH=~i0(6G5@)R+^SEwhBV~5B$CwHBZ)e^y(i^ zmxSyQ9gDL`P(|TMBARGta?L|y7<#q-NVO#99*Bd2^MQI!z=x;x2VRz2Ha7q(BL}SWYcR8u;GTMNAS4Ly5z}AS8-r**x<4zRT}+K7gXvqk z`f-;oj-+6t2La**!FJTb%U^dhSuW2yWz`d)&2(q>3q-#D)E^Ba=%7Xm7SeU^!O&AX z{FwM@1AS3@W%FQfJWNVXTV134&8o-sa=poD)pr|agfeR=_RB+=5K309#9B1))yW_UoSKYxJ5Qec5o7A#+z~iF1XsnG!<0HRx zq`4_yJ?VDg@3gum@xEC0!1P#EbV3Coge3V5Eyu9Q;TLnS=%)Q4sSI{5S#QQodSp#& zxu9MH1oJFc2O*dS)@x+Fn5oyudX0qFMEsi6xDHySt#%fq?ZH0PBm|+~BjK@&{Ck39 zI^v0EX6it}C{y*MCqwjLQvvTyV(*E)Hnaa`{%RpTEuiNBA`1i-6VF9lr^0FG2?E6Y zHOHhu*bpJitQw&Wkgrg~O`EBvjr>I$#wV1Q#3_Gy($FbQF5OzwDSft?gdjQ&37<|e zV6*sp?j8hZL}>J4HqQ#8tLdgIPc^2Sf*DSMYG)=8Vr@|n_h5Z4`YB({v>kxMN3%LZ zyGcAqm}>V`q>(Ir=XtW!zA9|~BF*TCVFCYwy`OUt3am=9%~@R0L!Ax*X#OIdWt&Xp z=3M(iKTNVsjBP&rt8UY2+Q?tDStM!m^l;({!(M5!^H#dO^tq);G!j@X5okuXz=nE={VAq4gUg_pMMH12eK7ttCn1^`K&iH+w=!D(57BsOw)fYW!nKap-h^- zu!Ttx3Dtp=N5svK8o>zqw@O7I(1?l@UXp(?QP?g9OeW6bNFx&k8!yyrpDQLll5r~U zMRFrV4w>+?EAZ(uSAb*R|2!J44rjfS9~jI6;~9gh7;7aE3?1_cTG;4DhAkc;u7Ud5 zTi=q&!NZ(2yJH*O2SXd@2Dm?7>d;%T$W*OZ;=u>*92SG{W&Rbzt%0~4!7_EA0v!DJ zHeDe%6rCg9gxAJmuRRoGR;n{q7yBL76zZ^s%8rKmT(*C1UHI)m*{?e&;pvAPFG%R< z#ZM7LmpWlQj*Q}y>s+LLKrGG!Qk$hvZ} z!RFwH#@k^Am{MN+Efi$c7eIWffkVu3AL4B%2Z(0XV4&mG0fJezk)H!vVL>Uu4(!ky z(`U^#L{~qOJkpXD0cFjYEC?NLK@0`A+gSvSAxMuaKhirNCS_$d-`}DO%_77fXH~Tk z3`-M!yVeeP1KLk5@~nN%Po}*Z&m^)jF0PY{j^UH@c?G|> zts~NUT)|5G3ilt@!eq08Q7O7hV(VxAgz3Z`|CyW`3rZuh9T4f-K~){cCd6lGFL_9S z16YMTBoOl&b1yG8S$i@ZT1rN+$Rp2pqa!T)Y-DLtH_CQjU|g{Lier6-pL?Nf+=IMf zf8zPG_vI1~RXS5OuGD(JbpL8v(6#1Ipl$I#o2J?NY5MN`@dsmuEDg(L=7^VQw6Q}v#Sg{Z zGT5GJkx$Kvx9+Ot0(< zA!tt}#}tVc#bWV24FfMUL>1M~K5`UY+xLdcS6#o|_OJ7E@~dbcZu_fBwYU12iB4Ty zNWYIP-q_V|_2;9ttHm0&xPA^e5I62$V%Y<#N58T6KN-Y}%9(-Q`j-~03pH$H)`est zqpfZNl;f;7~m*)x%v~!#!>Uz_?%jgi-$bRHZ0? z#@yxoEBE6uip6xj-)yLLCMJuVDY=-bPUO`!8R_)byVoTtaW%S9H!b*KF<>Zj2^k1Q z=VNkQ0Si39QnYv&?81_WH5<-dnsHSR+S=l4YYG23dYLs#j356ZiqcO=J&XI_%boVS zmuheHL+_O&)F~#UH64lkx?5jPwMV}?4JZoZz8jUdx_qH9#a~3$Ly!FH*cVt@BU~g# zTE}4jXsU%4G4s8!eIt^Qp0!?53nLN~&@_z_I{k&#-$RoC@|X~wIaWs^`4Fr=$92U~ z_gX1=`X0p3iw9mYw?G&V%>QcD<>|J+u!rsUFV)`n7maUgnFiy5{>2-MzbHcg&EDm* zOx2Tbl}$glxSC$326SqUKkC34mk-C$r{Mgu)oJ&O{B~AFoc0CuF@pDt9QvSo-WJ(1 zJJ*8Ud#9XI&MBv7PsF>wGdpl)SU!5kc)Ox&934rq7SsPgHqp`Ob5ZgM_mq79#)r|X z3x`Az-{zt_v3dz`BY*&_CMWB&Y+Nn~t$q}mU&7VW$F^RERUiGS%-%Ah|MD^<|n!6_9t6bpnXyZrbEGDq`~evJs)}!6Q1)3*Z1tX&%4(0 z5#N46>>^Gxt4_uJv)lfv9=1QcRD0i^r!LY85VYrwi#HC~^Wmw9pH84J&WHs4EA-4r zQn-*=n_0jYy#+t%Oz&81woS&=rS!mTmiMRsox;hgxa_b=lQ`dQFvaUic0>)iTaWpW zmib~%y}II$fkDQ;monGH3{#0GH!@|DJgu($0K&>dbK|(I&V9CF8^|Vkn|`lj;=BoH zAZt~=yt^Ot5b@@sN1a5bigSvql_v(!Vo&Gb^R!c`8EJ24wR0IXbN$;*sDymy-hQzE z#)8b&rwB+$i~95us%u0IxQo0rXp9S^(X76D(xB1o^QB&$$8+|X=*MSjzSOI2u)ZeP zqM95Yu=WV%(#A<|&fj|V9yLu&TaCbBWK%C?V=!qiLulda4f@BZZ+anZUAWFa6ZGxNK%)5>bN06dCAiUU3kIx377X?CV86q&%fBiuUfCFycvHH;Z0-bQPHF@Wd7a?K9fyEow|G^aaLFB}lL%*P2 zW)3=m=fac@vBBvlkQ{N533bS^0Z(mQ#tU=AOB%Pp0c4k4bQ`uqP$3^qOsZe;EKVl+ z;OW|DCe_!~C`Y`3kC&>wZW6C<0nBR-Z-9jTDVR|G$U(56th(t<9jl0!oyLQwLD7nm z6>blM$*PHOIqeTAuRTZ{gXiQV+%dZzIQ%sX#z^Uqaar}oSK3^lYI?%I9Oh(k6MU@e zwq)yA4~*wT`edGA)S(}fU4KJkicWtb0k%vBApV0`;I*&Vm|tp>M=qRz#CywmKe66B zyl(wA?tgC2T~*eB?@trmld;m&g*#vlGRXa49TMTUJER}CJ%YZT3=Z=b@A<%!SZY_N z{LO$Q9QfdHF8af9P##ClPPqt$m$f#v$IzL>^PVCv4Ej5Ms)=wMz$&%B@spml|4R?s z-`8z>NbVD^Q+gb7=mY=M4T$vo!hHFejiNJ+$M;nDpKF+Q4m{yEch<9Nzi|CYrRzKI zPiqs27N}2rkLeCy%6m&A;3OtEFb?3e#w7Z(PEYW!P1L?Eq1ML9eCn}PfiOLV;myPe zKo{PbyK!(LfofKpQ{}Po$xB1plh=Q+e(v&BtAG9#-PPwYvx6EP7ZAeNwf6VU5AiDBU*Yv| z=u6m5%-|fsFgXanaQ(@vSbtUFXN$t8u76j`-x_;(jXu>MwX*6T*2xCxv-WeFEiT~S zueGL1b`~g~(}Qw6)_^A@aHkeNtYS_~yh>Kpea#?23nYjy>VJ>H)zoJQOOAZ?Rk8Z1 z_6aRK_iEmLMmQ9G`dMhxlxe*rU#o&&$9n@B%MSH&IO4|T~-OSgZV1eb!}P5EAU!M9ZV zbGs=YMFE%SKhsV5J>8T)U!olH*B<~tl&QyF34`DKier(J@dg8Vi=!joa@$|CYVEJB z1WVlhk65+#Z-(1%RJ=W{d(IE83I^%35%(Iriy|y=%BlY=j<*HcJY4%A3jzORM7kdP z3Y7;;z=7|GU_aIo3w?}E!U`j3kGej_pRC%ztABTQ^(Ef=Xt{;oZT%SG&n&+_W+8b0 zy8QaWv%|`-Ur+d3l~bSc%7obA%7f)MbW?uM|G0cFgmbR=UHru!{lhh2WJI0&f*KsV#pif+mu z?{56U09A^<8~7em3HcmwW9jjE)8ChbpNl>=G8m?W^czZb(dRA-q4LU(bU`AoAXl`ki0P9uPXkueFqeEqa{gUDfxVx4PADm^HP2>%IO_IZJ2Zo+N`V zGxjwT)B$}6QhRps)O#`-))CqrA<_%kCqTUUSqb{hk<#1fR-A5f(L)$#>VAf~{e<`@ z2XeFOOp00Bqr;P2nTP75c;iy?C{+%7&P5;kby0cK-p~D4$YbR9zX!((PbKrpbJ1fD za=HkW2j`}CLmnx8sKfYH{;~98Se1o8Su2b)fq%Sxb@%mqk!Wk9-rWI|IBgjaC= zAQyetx7c8P%=gxxO>sp6owY9ays%TB=W^|tb2gkiZ0 zU(4q0kh8JhrHdqa*V=YBeluPEEEX?mEwb8<*7a)EF8DY)zEvZ8)_Og2-*5D?V9Gta zMy3Vn!+b^9f2McX2Ap2wC2n~Yg2(a$z+a!4PJi`ki}%fm(y;|U{46RWgt!lk3r>cq|11Fzb#DlxIWdFbQtgU?|xtH1Dp*ZI=DaMt!-=hIOng>mz>8T zV_+f8t=dAsC9jRcpVm(yRKCAE3nvu0(Ucy$rpTe|VCJV)Czjqsrn=f(lZ@Ll$hy|E ztz?5aiGuGd%VTfe=Oad*$x^W#IhGvg-{MQnZoL^5eLO|^ME@N6FbMK&RfPCr^&_%H737X{&O z>dl*ssTPNS?cN%QaIe1D1BAuetMz3LCNAD{UXWnX`mgvAPFeN$S;hMBavskE>{3T? z%!V{vH|yDGQG9p#ao|xu_(d@7x)WXCwqT;els!K8MgV5;5Rti#yRl#5va2|M-}X*p z1e|n$K1i9TcY#lGjQmd0BE`LawYN44GHwJfZO zn&pA|kdtv3XxT3{Br>+&{JwP%d_~uq_<_F|Q*ofmKRy+tA2fL9nMWTVCIH}Msmziz zB9a=Cij3_~cPf2EtYhWC9YZgOt^)YE^gO*1qz71=@2}kY>xXAz1?V?2FjJUbI0Y&G zUWgMAbzYHb>=%^M*KX*kHYwK6<9@T;7S}%zDp1VDLkp8F%mtN;MMC(H)Nig;RM&5oy7) zNhR8HIb^usbH?-A!6&mF{N#IK^boZG#SqU<@004bEd3VjwS$AMJ82EHH}ahV7vIkR z8@H6{I~lEN!~F10@#fdhu3~-oh;HDY0(S|vtX&7$d6mQd$-$rXAS8N_BjH&yuyr*4;jJ7dZyB?0=*ymvZc%J;*ofSw*HceRK^i%4gKj+ET1#8_G z{1^2#Q+Wv1R44$OJ|UDI6H2#*(zk@t^FrxYL+P&qX`@<6^$vCroT2LlZg0??={Y*4 z{{2aP(a}6Ko_HTPCyb<;x)GzVP9{?7U|bUb&s^m%#GruwhTwI$@cQ%|-eMQtgY(P8 z&z^?{RoI{9dpYi?spaXd*KT@1!E!JHu18o6=QS*1y^qxg$VVM_9l6J0c1Xde44k0K z(=NZgs3rCv_AA)G;3rNC=M9dA3V0-h@c++?McNzF>~`$d@JKILL<8~P#80`;1$?i7dR{a)TvDA6r8g0H zTK%?Bd+LCDEo|_?kG6T=W#k$`7LsP4IwGz@0+s&W^1L+4e?%9w&uZv;F5WgoBwXaj4!tYPyC$jvZ`t~qs4?>OM~`=osYrzo zK1KFtti54(!{d|--?P^O?Va^EFvTpX*Q+MqUs7KrkxVk92DYg?K~Ouu|X4~~~_pgu7*L=Gge?icHsx`=PO@#OFw)5pVmEb+k;_@4vL_Gh;^>BGX zfU*tsC-N{T#fed<{M>^ym%O!yPhd-pQqXoD< z|DMy>-Ar~`{AW@;+F7!g4>Xc7FWV_othXbcL0MgkW~fvs0`&XgYtiNG{JQR2ue=@1;^_x5|OvOL1} zW$dNT!SvmtkzBUndS?YcbvJi;@6;(BK$Z)F7&`f*J~bFWMAHhyZ-z>lqNAM`4IXP= ze@GDoH~>(=Sda_|!IPUUatQ^Hmh%!y+Pj2u2;Srr1FMgM3EKAX)wWdoQm+b!EHK!) zPZ)xa`37j!nnQbZ@x7E_1;z65Ed}rNa^MC1+25h(+(Fua&Pn(kD6gM|m52(62Y4dp z1G2s3T2^e1p7N! z^>PP@;peM=*UUp6G3%4XgpfmYBS*K7v0I9_^kP1H8z4V?-Biycu%E1tJ<0l&jsG&0 zlV1dy7}?uB`=68JRb$Q=zCj>D-1fog7OANx#*qS;5~pRIv{je7{Ye|uhuz0 z7>h&a_U|SV7jq5kEAD8|ph%#XX?6dZ+~jQcelFh#D z1;+6WR)`5+59jsXpwH0Ajcgy0*6)^A z&y+RTgo;h44#)^@2zUX|EooE?TtLL+j70Z!h&CTzkpZz>eLLa)e0`%%HW7hFBKu3! z@RSbOd@o6>Wl%#wdRFbL&-K-p)b})DYXw&y+~+_?K|tvR8*gsy-2q%ve=VQPtBL&y9XVjG&aRsz3BPGc|H~s$Q>l zyMg$bnFKpgPn#!fS_leAL!5lNxpiKX4)i?vd;ZpU9rfzgJH>Tq=a7qj1kWTU&KEc0 z6=QeKMb~E*oS{sb(rVY6!R}1-hM(IP7T$dork1HX`azsy*&2R>$aChR=B$eZs7(-G z{A2sL%=8h{uX^a&Sc`h=ZvuF@0owJ4ee!+1 z;rb!%gzB>vz_fNiK6C?m<-Yo3;P**?6jXKqRZX!T6F5mtPeLA&&}AljEw)ME5L3U! zgLk4{V}=o_n7T<0Sz&F^?8R(Ge-I-8=`^kq=#EMSX2k4iorRd8cHKb$O)C8m%0ugg zVSU$=YSzd1ew&TI?QK}B9^Zflbe6Se0$$Zb>6NLIjW55UmEbtBb zdy`wr2*mh(e{Wid5sd$xZv;tl!ub0EBef9({QZuvXd@tiJR?Y`#I@w7Q7YC^{{nPN znijP6_q#sp6AjJr=G4|-#c3|k3sMVbomJ}xj*S}Y3F%jSUrcTB&e~%z4O*T0d?_Wk zJU3=tWKZ<6$KebX1j`AC?Qa$R<3C;{CyMyrt8ppRdiql1S0kB~CZ;KM%N^t$t%FHD zFi~HdSw|S~XSYT}YH*ejm7bp%Kz&|HwLi=Hjcl13+2^SdSefa=wR1g$bU3dttDOli zZ#>G^Z$F;t7?l^#R1M45Z}Bo?26_73PaZ5O6wmzrn1FtFKxOisyvbb!{WdK`zpuwT zs@BQ7`n~Ij0{Z=%m$ZHpKmq+uUq$@fcxEMEX@%*MuDx&wK=FsL7lx zC|@kEzpJfG=2Vlp7=P8eCi7}$X3U%QKkQH3+Ulm8OfmD1O{P{Avy}ESGoj3lG%?Zg zCjLq>4+<$k4gvuFg7N8*m5Brsf*lfZCIl-aVoVU5k?jlklQ@iQ#}vta5&-d0giP;1 zK@L-X$FEq2QH1}3Ip7Eg`*Py$Z(V$8jeag@^S|`dOr#@M`z%dGsYL;r|1lLrBGXC_tm_D65xv! z%$aF&grCBhTTG7Vym00+lOz1hveg8WBmC@O&S@q`_$drIT5^1T5^=S^$s~Cxu#PmS zotbG_m*g*DoiI#T2Sqfj6NZV-f?;5t;noCzF`uLm$R4+*qd;?ud~moG9vg2UVMZe3C5*<$)zENHcY0}KEpM+Na8GvSXL)XkjB&z z$Mp4B5a=K984w5%eYk%EFLIBmFFA^$Ebz^n@PN_{yiM2xP^j?X=4wg-4UbP|6ID0` zfpvT6CcizJnuxGoZ8ii0j|1cC?gNKL6 zDpP?gdYpz!FpC<4)Cbtkr#RxNx$mtpXZBgU|25U;sZ^i2CT~e^@AocQ4*V7vR%Zey z4g>gzeqgEN7aQv7gUC9`Uhyld8V{$gkpI5a^7#XzcxZQv-=MPYmD}wDm=t`6X23P0CWdlfJ^-A%+KO91yP!WdcTbG zZf^aj_oRO@F*}HWa^u(oek9Clz1}i#`OQz4dHuDr{V6bt8r=xT6!IFwB#XB9zpq-t zex3fXJYn!Hum15oRU4=(4#GWO<({A4^J1MzbYRFkaDpBi_ z2oD^Omp3bguPV~Djp&Hl8jyVm)eh`^{p$MIX%YMm0e!4A8?ajy&U`|%9ISM6T&ZA zHhwwO%)@WnqWER~>gC{fQvkoXC*QI67Umx}W`g?nd%r_pzP}3L=L3%W)eGaNt!n{} zPq@}q<6-tOFf;tH00R13xV`+E%V>&i9X+KPSy{!B)>{@#wqIgXY00z~4JU?UP>n{mW{< zWKOVuUqRL^tNnsd`+K9+t^l8rSAksy6p-IZq5l1m@xr|33-jCOqBZLW`}YlWzdQi>?Nyln{s6URUh@L| z=e19I@|%BWIr2XOFEsmRRXxW?|H87`FF~kMxc&a+wJ+-5)-0?2Hi-V+{(b$KZ(cty zt^eWnzJI*1toC<=_~*CZzpVC0gz)#}w`N)GuL-sH`=39f9Qof7YTxMT&xK{RUl3|P z!E3*NS?#BU?9Z3qnq{?L@N`iAS9|@>#~b4z{}+hQQ!H%IOWs>^Fu)2FraiLqC==a; zJ>Y@BSyo2-_R{lTo`%$ZQYy=GjvUv)BpB2A@pZAt8U5w&*5i72t(}Y3t;xd?xHa;X z{QI`)g$=PK<4=hl{D23-3oKOxX1qVn^wV1t<5H9}|HwpV@-<(q7L48@yc8qjhpmro zs>*g&1;JQg02ky^CKu;r#iefxCmY-^#!Szo?O^75(lWU~4J+M6(Ov7YOlicF0ZoL!S ztM0$2DmeCJ-4p-p)V$1@{dZKi5?#=Mz`#n>jUUfmn#1U50bVH^&};u*NoAnR)Z z{|2(Z+Y$QevCR52VM0>Q|5#LzlI3|U*Xw8ekr7dMy-33}E4Y9CqgH5?REVq2W!VTx zG?F8latO%yCjMFF{)W2!U2@|8d4JEn;L!KNA6C7;;Cd#NG`lNQ50`}N&#Y9Z{#NT3 z_|2jW)y)%z@F%~0oG-(e_S*&9zf(^8$xi!EWX+%d1b^=T@ZO7N~p zhYZNSG zoeI>mCv!zXWWbi#=9w)iF=}SM|8VR>WRq62_+SViKq(k-9kG-4^s-xqTo1xXp9Y_XAU+%Yukl&V2TXkaya9C(i!4C* zPr{$8gZNzYzs4sO#OH|rH9qQ$fV?*PY4Hif3mnh0>UKPzjtx$Z69vy(h-aqfzl2N3 z)@B+-$!IEY@?xF*eX*k+ZZV+X3fq2O7zGyu_a`7C>P=Hg1{;{Xb5SJp zE>1$-I~knR4dp0Bgo!4)iDtV!SZ}7hlG|c)lh(VIruVPh5?OtEbDU!edL`)n#8k?s zw@)6upYzywQFxB6LZd>IhYJFf z>!_(4`k(3<{e#P(zrU%JPk&{e6?Z`uCTRcYT8XulrYZ|Apn<`3|UW=@Jn#Zon+E7fPe1cMppSc)ya zDOOF}^^t^bs1JQ1I^neAjLmyhEUR9ZF%GKM<8I9Os%5F_ZL`S4I!f0-uqO4=k`&Gn zq}Ix^p-(_tgUcKYo zGy7v@9^{+TZRW`@WPvE0Zp*vh>f9HV^SikIHCRro#q=pX%*N7caxJXD%*F!q9WzQ^ z&9geI8wEA0B}QCLSROfG7!}+szcs1@`CbDc3-|RXdUD}Wv~DYHYUNnb^7E_}%?G3O z?|S<#8)TRCew){(8dc>|Ai~+U)dw6e1N+r6Tt=5)Ph1S8#SAt^IVRNruVCBTQAMEo z`j2EJ)!Kq@eQn61h_iNvWfdI08>F+bygSe!h&s?r)tBH=$3B2$@e!qi#S*5xcncN9 z>lsyST}Yp8{f@}&9R6EcJvWkD&^{L%H1Z)%w1ynu3~(Bjf7I|WAY&js?tkrY|L9PR z()5p|)k65g=qMklLzW3*U4v5-u@P2*Ws%Y5oblYnefY$4{0#X7cr1mq zUQEDDNZV&2d>RViGYz|(#L>g&6ykHqfS&@N>>>I5=^Mspmve#7%JqMId;;@R^L+V7 zQCdIr6!%tZy>Gu!MIHU58sBgzDi+Zvn(aM(VZ~bXOn!baCT6? zkgEL@Hzb0dMDr(~o)n>CH&-@Q#QyOfz;fs@kZF`M8wH5C(&YYaJR*+r1j`#VBl={` zIe93Lv?O_hZD!o7X_wrbt8`d>gt|Q|;Xi(Gpcyj|G9zM4#_f(w+PV!*4F0wHSj);7 zF|)E^FDS%P)sJ^f>03t*+2`8w-Vfk?2J=*C7n~HTGRPrMN!^gXb)$b_nQ9iJaSUEi zOnDarVyRhpPr8D4nRH3@eU|K4*Ifp&{P*G}eyUF{n9E(})17be(XN*ixg5hl1tuqf zN^&+kAhFHWGXb2NW3Mcx2wIqENeyX}TQ+Ah0!>^ASG4UUGSJ+;+2FHW&Op!?zXV8I z-2?(99Nn7Zd5VpuBVX+_nDm$uYPHViE5*iaI@P)F)8UyqTs$*%rVN8>7{Uw^kRcRk zn!rwFFm+cQY6XG`2_a0EpKrr1DM(P3z#U=yPVl#cXV~MXYP^AOX36!N>b9rIa%3jd zEjW~?i~B!9l(Oon8h8jag?frsn=Vc{5KWNXNZ5^Htp;bfQfdOitaM8WLHR%#mJG)<3u&Celb&{x*<_**Jy|l0AaQBS}VrQhtP+2|wiAkVF%ZJhYktC=%w+bzOvA2n?~O1h5H-M>r@ zgRd9icqjiJS6u+CTZ#C^@ z>W5)_d15OqXee26d%44^sB14{C`CvprryL&11g2YjpAB*~^}{Sh@A? zQ2m$}zQgwNRK%*Dv6s^yB8fB#(Zs)bI-h4R^LN&4ENU+o+-kATvzK<$HqT!EX6)oh z9H%CKcP@lg$QHH&oWmVL*mpV#O?ByaHUlsocS65(sSu6FU+RNHE zTg@A|>h|*PeHztA`W3L3m-~k7W!M8G(X*E`Je|+8m#cQ{#$FD?BPRNWY@WRwW!mQ1 z%NfQ_wlzriE`;=^9@55ME`~J1_QLzQrR?Q$(3x*9^7L6lDkFpV@?AP+V}{2!1u<+l zw1xu=5p>N!jaHy8n*I(p>O%yT#ADPiS_V!FS9g+bgDkCwtaa zjM)Wr<@od(n286$kwQ!ySz7(Ty2$H`BvzmMj_rGYudcPwkWB2FzjIJfX2#|6qhsxS zd7YIyA44Xh-6`cvy&jEIKz(AkDJu36cO$CLmaSwzJ?V!$fA-kQZX{y*9BL zF+JyO?+qpw5adaTh=D5!?+m40Nl4jVYRN`0R`;1;)27rGSuJ-8P=nMcbhFI@&v;R|BE~kH3<% zgju#2+9nN0136RUcbK1IvJyiFKA;77lPA?j*QGwbfh)Q^s2Vfq2zrNXOV|d?{LAeF z5?eC3VFjp0tq(-k!^(K=5^ER`lWE=WN>(r3ER~t2pK-B`UcCa1GDS@R(3ktBz}C}U z5--?88yy_}G_g+7ITsU9f@}_0*GP6k@?m0B<&1ewWuKcl3Q!&s=TU&hVyr)`)!Ne9 zQd@IIsgGsY`aqZqme2D;h5VHoI=8G3bmXsp!IVVxvTACD{tf|rW0FAY)!rKsJk6uk z?2(mfpU!LlyG^@kFF^-}T{CrWR6iR+;gWk`Ma7xhut@BpQT37{XgFAo_p1n5B8Ql| z{YfOW#*OE3AKyt~Vhz+&o_h-06I>JOSKd1TgZM7gN3>0d)yxGE=ED_3FX%NPsnBl5}aE-0SgQ z&^Mh-k(%N@d1Ggng=)id&`+zp>}U1C=nt;X$h>{Z+I3jsYVTWz#Vv2`Pmkx}^=F7W ztFBmyse*V5@IJ+yR+AGOHY)}$kd7ExA6M_UN$+~;IS>y)DcRcYgm%dGXrY}@e^V3t zrvf&qHm~GQp+DfO@chrU90_`%9!A{!rgkA#z-Y%B(jL`_(N>Fq5klI)`jU65Z`O@H#J4hx z?}=sL`}fm|PZ;0*w(TC@ zf0w{_jn{pA^?sYN7rqY*J{W`0=O3U$W+4^(Xa7b+MjafCVoIO{_M}4g*B!jiTzC~- zJ%+cvtj`<#7GoDcjM~KV+8BT=3Z}bY-Z)$(0J-dU7Ak z_qWi$ANvxJY!#j*;Gu7Ju|BK&e$leP$2UCblyPE!oww!Jw~x38kqPvvVlsiJ8zT46 zEBqSv|4_09OC*Hw`^#^?^eebu8)%pg!}YzJEKjqt_Id0Gom_gPj!9Oh;d&au1u}x4y%tk)NNefnX>ZeFiDe ze~NX|;)wXG!KjFLNMT3l5z%Q-UsvHR*9Qm)$(XwBYpY;8vLVb@W_zIbJI%G2bL6Fg zrD=5>b|SidDK*t$y)`S*(3_*-H_aXw`thJihk4l(v}To@SCD=;GwtGn28_dFftQ2H zSgMxA4Jrwwr-E`hsAT2MSGW)$#nf@bKwKs|^ex0&l5sAkFQM8PkG=Xk|5EC=zHRFt zP^doFPbBDULF|RVPGBe0KQIT8UIMYa_5kQSsfPbipl1v=N%e(}!CL&#QYdiz9h*J> zt~LDwNOMx{i4#uR7X-dytqXH>P#ktBH2i~y%52n?OUJY&C)Ef){7JNrL+7zov#hRGX5t0jh*D;4C?Z^VN~P?ZSAKQe{izUDGv zHP?0c4*h5?>xO(r=PUu9i$>=|hDamkLif4ohtD8Qbh>IxPM6Iw>F=3FFrnYCZ38tj+1*Nnk5TtN&(!PrPbUZC?!j zLk!P>YRud?jjPS`;NRap{O7MIg8w0T@ORFGKfQbSql&@*ilKG@|6k_8pIQRGIITui zW#GdPan5KZA#*RVhE~L}2>D>BPDwzMq-^8rG!F|JPLZRAN0>^|20)#-TpXbt7_Zq)wlBtwWPyPj!d(u6l=YYTTNIv&|#$P948AH#`4`L<8;bsWqH0ipFH3v*Qf$7gSUdw@o zJT#GSxtX(U82q^6``!T>*SHrt3-tNP+G<1u@~NO%*r4&guVv0#ddm=lFr5!impp$i z=tuyH&WW@Q&)q(B&~p}I{_cBwa5G>M3QonpI_t+RsdlqQ{qQo>#(vRc;IH?KCg7La z0hF;0v(;RF)>21krqzvGA-2cXYkc%^q!KbQ1s9ydcEWaCypWGpn3AJo-cF1;t9vbn zWL-DO#ZLJ&Sdy(8J$U>5nefbcPlFHVp_~?Iq-O)Je$3!0o(3ht+&JUkiZg$$S<3UB zogU0HF7QmjfS~XYzQ?%oj=}4d#cY7Z3H`BM`6+KWwAjVsFh50fSTy14i7a$v)p#NT zXm$rCus&df&iWuX#<}=NsgEYyq0~2jQ?3{T?dAI=01gF3_(2^#Ag3q(wmayTBX5IW zJ7-W{lZ^vS3e^wLlcThX(DNhE|1*JK4*iD{joMWE__|sgcA@O_VV7F-WlSAY1nCjF zU4Ff@B9C9PDrlDj^Y)wUr#IT_Iw>g}Th8Jy_kVTHk{7w{k3(Kn>tO% zZ*bz5VCt}9(7OxX*m8{x_hQNBC;{(0BzW-HRI*kqEIIT~nkBU7C3qRCV#%3M6bnH% zk^W*p9S$dAo|s9D@0BXV}JF2kg?yzG-I0xs40BTpgELl7~8^Z z{3A@eTDoYHPm$1FYNl6sBQ8&t=+9go{48gc;=3bao=V z+U7LPE9P|V(>2L!VwnTH#Jfb6EXA6k*MWD4o;vu!C|xpxV4d%x1njAvpy8ghMy|`_ zigxLG9tJzAqik-)=nhltHK|T{o`7ir#;qVftH!R0>aDV0=lH82{0H81;Ga+o{#pg$ zPg=vk5AT2f6!sj6YhKb*6t(A976;W%#_#{n?D-2E?)`D>dD1DGB|m{ZpZ#u-u}5yI zHS?#k=Q%hRN~W#0Js)wh=2egEd4K^Dw&(waJuJsT33{E9_VDyOL9vbAM2qeJtMLsd zj(;3`7~Axd*~7_i2N}EZ#y_n+Y>X4nJ+g-p+{%Ia^vE9i8X#eNC^5e6^p;^9hq2)s zX|CmsZ|%sic&&_0Akg#ne1jIi!Xp#xQ+I5&fS^$jsO$t5Xgt1K_N^u^S z5Ps?Xx;XJb4-vKT8xWLN%kD45agR|Qje91?jl>+;E9QPDi}s44X#Zf52*s5D<4BZm zUA+8Og03o5ZuXxveiK?+D#orWsB)^DRikjmy-SX=CUnNpH$h2!^cD`E%*{l^9hj$h zk=VpfdEyc0?DjjtoU_yOo9dZua#(dzNA;u^q()oaV)?6Mzf}7)`Jx+H8dHOAHs|oF z+isFdrZ8h}HlDr`b20QRsK@yW8`WfTCi>b)ZVH-`zA3RsHxYGz&Ss2V6_Q|lX+>H z03bD_+Rb|+nTdY*q|Mk7-@??SUh^8zYNs#2X$>w}VJ#u9_URzEle{vpp00Pg z)QhN-u0j3q{uP*Mco$RkU;#aJ9Bt)u!7+8~IaHI5h!@}U%@Tr*sf%opI2K=Pyi@Zk zbvf#`c0fwz%(-6M$*SLf%qu*YRY8H;$oN!-vm3A&LZ7BjtfP+>^^xKW;^> z7-AGH9H6H$+HZjp(-b^Os7>x!As1|rT|6#{S>G5H zNvn5`6-D4_hxLx51Z0B!m^nl2pHtOeJT$j*+#>}(KX8YjUYKUoy?3>MZj*0#a&A3K z5NNAQvR3=lwEQvOZ9MdD%2)_!RFOsUtyH1{iGfusDb&oxlB4QhW#=$yGZRHkS^Plxbl|L7 zmY|R8wUXW4n9HfdF9jM9IMHk3?btY~RdmTj-~03$!AtWg>p_#bidfzHxGbLVn2Ig@ z;S3<61q*p`SbO%W6*=6E>D~1JPiZGfeKSj%RAW1|1jxmNbIkeIAIM74n`U(^))0i) z>b59rfG5rB?iRqOVguC{f@x7Fy`uFC-kr%$K0^(VP8Zl^o@y}-P($Yv{X}&ekF$!i zHQbtvwE2X5TV0BM0?q`fWPI6*mvvvzPLp$v>25Kq>F)G-DBy=m#=mkAVa#B5g*<_( z1ZtaJ_{6Fd({??lch@Xh87_Q{jgmh^i~SKJ)&4OSQ?|ll%95eNHnOyJ7L^Yd4&o+D z3(Xwa<JARHJ>|Qi2K*DQa%kP0(^GTSDAvBF}`{VJ+rNlWPuM<~ApSbE5KT_54}XV^aYFPphM_I^w8& zgcM4_dej6^z#O&jOPU9v)6T1)PZAn+V$qU7}&i-%ZP^(24uCH?<+`w}=gr>yVlBo&>| zsx3v^D>O<3B?*dHDg;jm%1qGqdTrCT(WINUcXWr2PK_Rdyq3}S*hVmt7DR`TvJB0< zMi@&N#%&cQW6caw-~WH^xlcV+-66jB`+fP*Pu+XZJ@?#m&OPVc<++b}_em5Z8&^P& zk6qAzfw#7zcPU2>l{PYWVxr~4ABX26b8(o==CzL(BWH_KiC$A%p?hoYys4NHmc!6j zRO5*ZJ2^x6k*NKC6_{$j3l#LBcdTz6g|%zwEP^*Wqb1ZUB*h?m=4kBiFYdn?yNtKIZI# zn*q7!b!hg9SpT3TkRzTbXAWVz`2Gua@+4auKEQ8yAp4h_Po051!1&NDSGYA2jBuum z;IgU7QJ62K*?LAe30hE7Qs_IfYWN~DG@>DL)l`w2zj{76$M*-(O2$e~D&@=!E*}F; zO1lI#uLk;YCJxRU(E%EFL(A^)S)}X6O5ZHOqOnw&+Aq&><`Fb`q-o2j92Ce>M`iC!E~vXfK{$jgRlfYnsKo^Acv=v%v+# zTjk@sbSmGvED19kLshVh4IP`a^CNVqnmL}Co?prDlNH}%<4A7x&npM`FJlcJ3E{Kr z>7IgEaAVdD!xgMtDnArKz5mls=9wg2&8nfkvuJj`nw>r4~oAlaAS?^n!I(4~qCU9|+&{4^|f+`bR4_8zDb+ z3Z3h<-8ujho`NGkpxk@&$ z15)wH503At%j~2jy!-%vRjF6h$4)fh`Z#3?_wCBy03MLDghKSon~T$qJ;v8W+LlWl zSNsyAF#`moXz7y-^;VULOA!NCD6p)`O?Pd|2%K_--7~C+Md8vPRX` z_#|W@dKKF4r2T9#lPADpC{gA>P0T?Zvj-BLwD;0)j6I$ou#+l%?YaYlBWLd)0M~li z5SF0sz`6oVE_og>xn=%T7Hx8@(Gfm+e_t$Yn8#^YIRuFuBNDN*Om+%~e58*ncrya& z(mM%g7BcKmx)9xijJ0M$SI@f6BMTX z%KEqQA_{)$6!T@I_zAqov_=^5LiE^KR+n(hM?NV9ox&I{L|ZjKrQ~h}ge*e2ya?t6 zBO7EwxeA@LK}cktj9KmQgI|L*^jT8Pjt6SJqzaX=jwT!YsiW;Gd#z&M$;d4+zis z5{Sj06gD!7>+QBzIrGwU7&+<<q|E!DP>C?U z>Z$P5HYPHfz_6cDTZaTkAXiMhj+u6uM=IaGJ1y``yPhXGg|NM-8@D_HSAv?7GWTD> z*r3gV^;>q>>Q3d^od~OAl-zVd|BTgr9T1CB%2}0yaVk z5wNL?*iOIuD=<@CMf;Dz6R&pr-{RU<$g8yd+oxJal(i4P-Jc2U+yL_0kwp9g5v|=% zM8{dm5+x#9Ln#y1mWpT<`s7*qQYHG7HowD%-h{puib8Y%5^L6kC|83fM7ipbpiD*M zph{FWu`lMXq*bo8pK}LpiRB=+lv^`o4r#GG6|kv`NLIl3V|@2#?JDH^mycE{-)5yn z1ShJrfu*HXz6&q^U(5GU67dVNH>dSW(`uE2Vl|BmYm)eVp ze#E|3F#xd_$#oK|wig4AmsJRLB|={eTDufHu<>|uH>+qx;qGp!Di!X>vbozkx# znl@l?<%sAwVLJYON}q0O%zQfg9OjG2akIfye2>O1PhifE^$}Rf1ufpcT7xQ#*P+ zn37Ip6{0^}3%!FF;|!DIx)_==Pp<{F&qlYwH$JFfDO1n*)~QB<0Vup6HT~pkh_RX& z+kw%Hd8%wxLif5_7X^>w!=DXkZx}z5-d;Yjl_cRcP;>a1x_M{{v}V{CN4_~gbLyEB zW6oI3X<$wbbGFo+X695gr&@9xQ*6rTGhc>KI%6g@rsBDt5P#2}~Uj=DXlqDGesU* zor;auOec>es9{EFW*4givX_|~G84=d)(gXh^}=vreadIOaJz+-9k+!L$890NahnY* zx4Y!0<2JcvPO{wY#&4h7tF*_KzpD3cD;4|}&ei?N@*9+^DjuLYP$SjDv6_=%j$!*- zYL3WcDRZhdM`R-0TKSm8EHV*REt4FT3HX;|#kVxG9SJhEhndeZ6HLWwi|TG2e+esv z(O{+Oc!Ix#(NI~{@fZnh0U~@Qfj^16sITO+X|Noejo&^8*+iwzFXb1M9a}WmMOe6E zFraKXZ{3Gejd{~Cv+UsiYvKO@=peYr8X{dy_`$yj{`D0OeD(1AZ|eHhtg~Nz`i8ig ztJNR>%}iIzZ_e8133%I~rhKJ<%egBrg#y?m;e;*;`;m^tA@)S-w+9xYJuge2QP>y; z2!GQ0X0YQp@b53trODJ1#Yw z&A#GN(yIkX(QFeJ2bBQ`*A@>L1us5z5n1JrBr7kcFA`6t%;yiIyL@+kTN0cy?=jsG zy{DZ~YNgrnJf7`o98T_{iBP`M@mUVAKv2Q6t8AeRf7ap0#e>9uso)P`gE3U?aa(=Q zDZi|R5-kN=!}}#v4!=N-VrFOCw5}%hpxB&nh@NDNnZax+8y%(K zDKp4{rzlAQ-^qbv|7iyS4=Z@a`~~D=RaBKpz(Wo^WA3)#@2_u%Knokc-PUj?>$HuR z@rEPtSlevQW<50Dg$nHbkK2hN6`h)suc$p76l}2!a72}{5DtM|O2(zlCwZ*WzVSpQ z(cp{L_-0e={B7;v)`P=wu5dyW@&~%WT3h6>9Xif}aT5!H>@uD}Lodapq|N6z&8-_o zXgnQLlwOUw^JMZPU^zj3pSuQYmTT_)bN>*UHxEpV;S=tq#mF^Oc6%xX+1gFi-DZ>XY}> z6t&Hlj~Thv{49I?9G4WNsJ9&T{l2#-VWRrPe;P;lef%#rURuwRAH?Zk7X@!LLm?7>=91b|^j4(=el}&}kK{)x%Qu%f92||))%Tw$x(VMsz zPQkKsiA*4@#mt$_P9QHSZYig1GcL1c~lTu*TC+i zn3FP_6Csi@o4lA_xz!PVuq_YeG_Zki@#+!dHjVx!XE_bm(-?JCZ2z<5XaCH z^-d=F``D)jbJZxc<1YbP$E<(L<9jTh7Pmf0Wx_nYy+mMk8+8K4TfhA$-dPCHIP^E+2mu7(eqPOo*C2)8-DSsnQhA}n72Sr=(JGW z8^9cEjK-Mzaa%0rrwPtD68M4T=c;YKsR)hlo;Mg(ve98*LO-%qZy>J_ed-(?St@oPc zIE3O^JOWf@^#d03k?1Uv&rqHCWd74&ZsEB=v{6sqiU<(ZZ@W2A_d9wbBQl_W{&}FK zTb2m1ILn%8&$&2CQrL>J)#m*2AK+3xyM;-Mg3A)4oHs%QnyEjWqpPhT1SajS{{*8}69 z0M!UC^|+Jn;qPOUF zu_z!ziUIsVK%SI-*keGz4ze}85Az4>o8Sf)stS4L%4_?nP};n{fgRc<=OkCLzTXa7 zAeT5EKJ`!3BVoShIt0bb&SCBo0yzMY-#QP{M76DKOKO6NSTO-Mu-`bA8#jr+xiOA6 zTs3ead-KES*b;$CU^8l9yj$c+n18WaoQtK3aEnu3Fp&->>Vk<(F`=6o^S_~01$DDg zREX~1WJQg0s+ZIzLrtR#(JL-MmV3i8ZSMRRY;`ue)->uS1#{?DD9=`HkBH6O@FyU{ zAeVK&6n!srgPOHhBc)f(hNuvlaeD{^BG$L2rY>qnabW#gq@OVF$(f&k?XmVt zUNvGaKkt(>yWN6TtZ^zU*T8WDAE3WIiF;>4?`l-HlIEg&ePr+etAa=WdfCOKIcT-u zV?O-Gs|63)EB?+YPtn59qX58AYc~Nym$L95EL^95VUt`!hES33S<6KBYeIfH$466Hh;Bzxi4a+H+Vt;Voq3uU?uJ`{ zL|h;Prwl_sF{g2J=p@H0dMNv*Ow3g71MdK!6b*CCY=o(loeu3}GupBKAm?c@b*pcZ zk4wiip5&CK4@mK4x*_>TWX@JsvO^M(FF0-7?2G=xt2|KS?cJ1 z&sXXI9oiUkllNcgIPdLN2t8TneylAozy;Ad2R9yT4 zyY=m$RH0J{1KFPWY>i-Od6Bu_rt44`H+LezJ~g0Ee^}dg0(Jzf3V%wXA%Jcq!SNnE zc&{Wm4(O;i9K~R0jD5PfE7&I@C`1oC({fw@{1sUE`_CsK6X2v@&?v z{R80l1E!8tGlkLYXS6nB^=Um9Xu0e{8fMDAq0T!j?juHwG-AYlW+8g~Z$P@A;UVYX z(j2(5-Gez?s6e}LKordT*TQ@v8MDa~9|B{LOxrhKVefikTFeQcte0V%mimvNa4#na8GC*O$Tzn;S&tMn|5*LNeg(=d=W^ zlyAT5`kc~-6EEfc%b9Ze{PQ(36D@Iu;8B2n9iW)H#YCFL`J|om6{6Ri4jN^6Kq0Ul zEN@&VOdVee%tjYpK=Z?4GlZx9EqMs74o0`>` zN{jDY#aCw@ixM9z|Fk~Ud9Z{@+YElCEUjQ_eIfc65@2KCm^uG6hDjJRbOaNx1Nk5g z+M)Z<850K4V1He|_aoUo1HZtkN%Onk5f4WA-564xHt!g$I284xEdG-By-X$Rhfk0A z5kxO?7<%A<7|xy+qW37H>)Gpi64%Tl6+<|dD*+)#D?U72E@r`=!)**Wve6U*TLjM) z6914Ei$$N|7`|EhP72@cDxSh znfk+ok(+~sKeD$S;eVXU2jqGLtj?I8zR{Zzn4(bBj|-NZ`SEZRpo!(dCYq&C@;TMO zYSko&!Gg@JZ5T>ikTilh3WtufFbP?)rm{-GESA9c`g(^@)lrW!2 zBW?sX+{Og;WfM%BWB$lQ&V1Ctjtg8&;1t;zHDqubR{L-<&BRTK0VTj-dS&$uqqJE( zF}-suE;wuCtriJ;n{GlT0$jEqVSOYVxbx&vG13;6JIz_y4u1^e$h_1sq?d-DY~=>z z(LE;Y_VwWQwUGDO^n)Av^KkQ0$FX8=;0BuTRc&=ZE|DmMEw+drC~mgVbxt z!Vg{qy|}fVXs3kx)5D_L^O`r%AmZBng<9-c{x2u`%HdLabz@^$%>1dg26eBpAu?D_J5SnO5e^i@drG zRHabM8YBMVKceMtmbNpEV6+88#d22+j2!z7 z9DIfvhqEUYeb_RR^KYfbdh8@8h2eJl=eqxh6efR9DZFr8sj-G{fiud91UwLIo<9*1 z;1p9xM4=@L*MNLUTQ}iC1qp6DoAfO28kU!WSSiCqBQp8(Z z0Sr6KMdLf;Yyf^Ejcdy$^N$L1Z&dHw7lEEE2cRKbeWycU3MQI;pM7tQH+(V-j5G`h za8OV_yjCgsnPX1BW3vQ%f9Z5z6!SlKu*c&AyV{T zF#LTxf!>eM`vH0!dCwA`_pjeVa%>wEq%eE}8H_*5;P(KNACwRGZ6($z*qlt@R@XP>Kh~s+eYHbqGPDl*Il4cO)Vx}-3Ry+>N1C$ef~_r!WST|Nz({lL zAHX7i0Kwt9SqDbVF_yG*rdeU8RhV{~hOi5xTe2$6CEr-ob&|##mwcj$44&pSS$b=e z1eJpA3TxQ@6&U8w$o^6#FDNR7=mqhA`N2 z06B4h(^A`r4FS8z85LHk{n*0i?YHNO{Wy?F4g`{LpAmkl=-04qcq~TT9%*|{a8f%< z9r+YDO2r$`5(cT|8T>5Cz|4Y~_!v7D$qj2{1CwKHxk3Xw)A%ZkX){)SZNvVRn5r zhS^%*u{G+{w-_K``axvKLEsk%* zmbCQaYNTx*IDqK_+GsEj{MKq|9TgqJk!~~8bS`y2qq^U%jI*ISUIs+kOgY-t>ybh% za>-M)yO-tbaefM}=OiC@Rq`n(HOI$5dLsqC@{? zG95aDL*k%Ha!P6WcV7%-fEg=%9VvaxE)m*je#qLJRL6C1)zf1m^p6Y1?zJdG?0-pRQA3BT#5OE7A?7wk2$OD zkhgv+Va+C?FOzF0vA7G8FuS0l?cK$fLl(AT4@>g6jtwMIRYM4v1N!VnK&0(P;Sop7 zUaSpXR4Xs!iT4+p#kvt9FB9Foir6AU_s=7q+r8xzGbHXy>bun;$V(WI0gSv z&#Mvaq+dAw6U(4_(Kz@e9V!h>)0`C2*o{Z@DhzZ_gk-Y&fQ@Y;X@jgf^TL>nx%W~l z6;Er(vo%?PD^bA>q3f=&in-2gG(2d@&V(HL@CLiMX7nu5hDnsA!#a!NX-*8yZj=Ew zlwVU5_#N_2*2{W9^imeexaBV$xAGHXEypa3Tc)R+%F!n!U{?+L|cZW9g7ZwgQ=F?!ClU44D{Tg;UIk$Qo3KI|+rlEy=nk}@NIBS=A(+!{! zkafV7mzY(GTYEChwV<|?tYg|;tgEGfic707=2VSgxKx8xn?;|rkXtY_eYh4Bsg`B5 z3-(#mGC{RDRmEx9_=|0)8?=Tzsb$-7cxKG+Auc%CLpekag*TNrQYhnO8Vd{%$!01N zVeOnj#4ewTY6s4~v+!QLO?EPQyq}!Kz>wheah8_rF9NM1kT+!Z zqU%#VX0ThEz`j6q4cvw*l0mB4)uHhTy?8`X0KA?Lr8h&e7L-MGW0gab2?g2`xswD^y94(q1&yuQ`?UdiU- zl%C>8u*5v5J3fKVgK#ZK6naq1AK(BVB&CxacTUVjNhfO^u%Z zW}w=VJ8ZQY$0t&rfR5+vPy=<2DWb_BJ35atiYtDebCaQxNcPG2qD>5A-7~lFZs+eRNyuwMX+tN z!e|(LiK%Y62?EpRS6DYxvr!3gnnAdyq7y2#Wj_^atis593d8`+dTu88kDi;q*$7Lm z=jJc<-0~CjN()PQ>JGeTW~6m@{zZ53$vU}MeanhGjuJ=jT^Q)dp}Ru=DF=%O_m0Cs zvHyZEg(WDq9GC9t6X}HL5lpD$gIVnE-x)ljFsZ=uL^`Eu`;?w|l7*EtKG9ePPxEU8 zv8I0oWC#=X*-+K3%4lGFz$l^bO$~ey9 zqoLAgVm0^UBZ}-5!TSDN^ND9&0*nRYRNja9^Di#*xGbBB@t;&y* z=&tK7<8;I~?ZCVxXwcj)XA4D{m!iJiujVgGV>y_{h#xY~mpyaA#>|1!@9Q#Qf;fa$iuPQqu-|s94k9@fE5LhO`ZVdJ-D??4g7GsVH zRQWaV15X&RA?||8e(<TWpj?SRE2l$#(f6076MBFLhZ*z&Ydz=V9s|*b$<`-l`4>?q{KK-#r*ljd z`)2iLaxIXPOFlnP|L%e6r}YSr<4wQ%+FvM1jEs+&6N+e%lNPPlB>q_ za^3Y(tR7?D;n3>!6$zeThTKJ`xzcxll2`|hd;VDHCqE$s_0a6(`+QTBH1m3kJuJEX zHz+ST>Ql$yt0R0_9{V<97BLP;NjPRtRGHu2C7z9gBKVDhEMf|pzO0@YxOX`x(v$*9 zY}7wPwx!H_jP6mAW(ln{1E*#*VII4~mSFvJX6;ffXcPV_7 z4&8>;)ytH@2)rVri5L_30voPhg0a^-p0fHT-?HC=HAs+t^BVRc)qgTN@Um@o6Bul0 zXS><(pPz^Dkep+an|Iq}H++Gib?s&%*#e+@_}5^X+A|XEHNAs?1Y_d=yb4S zh1c<~Cu&gi)m@y5bPACB1QPIqm_l?c6ccE7N!txPW}0`u&}Cbp_YK}b?>hKA5)$$V zAAcimdxq5=IMV|^1|3ruQMpQ-DHOPy6X<|O7so4xE5xNmZ_JCl-U4>#6-~= zH@2mxVxI$-(xiXZQOfze;+GeF8`}|@0A69h>G91}0hiWy$mb$*W(}T6v!1quGz=3p zF}NRkCdO>5I|gl;QO=Jb{GGoLe!c4l#9wya5JCsP{I6X6?rJg(91DE9FlOpbWv(cN zr+(G9_x5R_Ojzwe6=7yi&R4*0%e>#{l`!Jt%QFC!8bhdU6+OYxVD z6F53;gE*iNZ#lI)@X{pWXyV%{>>i_CTaf2 z#WrGOm88e-m{OhQ;ZUMpA1`gxn^P~f7QY@MEqS5j|3m+V4cxygO8YlvjRE?%*VtOO}k{3{VN}T9vLQL5{9QKFcBSZ3#b%x+D}MnGeKWWiUSBQIf=9o5$P2^b075O z5b|0LB$zX!K_sz>{L<#vL#^r}i9?$?v)}rbVxZ6}M8~L*!vE%^(jL~CJ1@5IWiEdp z`HAOVA!b~{lNJ2RLJSsLzp^kXD-Kr_P*3|Df%jEBWM|;V2^}wm&d`p4a)sVjpo@16 zRMQ*$>SlJZmouoi`G&ifqWe-&@(#z`T9-K+I0eVGSVYu97H6 zVMUp9<{X@4K$OBN>+_~={>+$Ra*>S%e<78`hzwRO_HgRo2Iygle0}lKprm0rYWx8* z`35_j(_VNvJ82_`M!AL23W^Hl8IH0rlf= zaaLXWD$az+JQqY1I594rmizHjSv&$$-MXyV%Zdwv1^bg|xo-;jJ_op)Dnv&q>oMr< zNGL={i~*Z?-%biY0RJv8N^w?PIqGsP>4McB} z=s^hRZLG~WVnxMuc_Zz<5@P{zilvZFISo$w$R(b%OBEQP6z|!5XZn|{50$Gw;zb}z zoAnRe`@BUnnr7MmOLUXdV&3 z_m7av_yy^rK5AjoLOYdom9O{^^tJX2uLrmvz_WaVXNiFX5-;XCZPsBDl+=OQvg8Ai zfG*r$%$71tKn{mj(r{UPI-6ts(FTq`I`=c_k2MW+P+uhbu51(0fX?>U0#Qp z2>aR=^)_JNr!M0;JKv=Oc|FqRrqA45ZWJunj7K?0XV_`~Gnbb-P@vDCN*~5{@hd_Z zGYi3br(K>FjCD90ppFCMd-P5gqZB=aT|0yW{tHRC-)AI%G3+oCqNadJqQD9BXs6{6 z6qwq3avmMmPdk2xT{%Lg$dkA+UpwZVV>7N&e~Lk>o!o`Hg?PDzr>>ju9!K=SM`(5E zKCeCqw_)N4Nc8EuT-F81V#HFcKMmXY#V(z=V@0U%Tn?v>nZgl^DRe82F<_Y5%SmE3 z^8svzQ7&X0*AVb1jW7K5IouCLn2vDI65kW{A_o$qctNgSzSnNOjIID(Bq_A8oV8^w zIcMJdj~}}(m8$MU7oBn$?9oTDx3J4CJ5y4aCbwAmTW)yb=uEP^w<(Z2e=B5Kl zki?a{G315#GGtnY*~`(WspeZb4(8;)G7=KvtT8^hEDX@@vG)A0q`1+IGH zs`o(`8H-RGW$rO=1_=zld<_q zg^^Pt%Sq(+pQv^e0|v9(b?Ar<4$Ax?qeb82pevoEAjgw= zU}u7orly#=a#m0L_d4<(==D2T<;U_x@$)L4IpJQ>oyGdiL(dFj0G#Eh06q*p^&Uwx zWF+NU%?`G(h5tFSpDSZNsBRAQcA)!z{QJciWk5mwIq&#QO3girQ^1!j!8brWj&Ky3 zhkFs?ei$4pM4v=Tk;4?fCnnGy-1VMt^W_|z@hz}UpUd$?H7`<&X{~Ry=T@Ry=WI^gbTCIIJaO_DUK)$KGQPG$pP+^uSVv+* zM}B(dVEmK`&$V);MeU=Qzgn#*M0eVjJkiYyjL3TG>3rNZSrl6d{R4A?T*X%(p$h}> z06Q<0GNtJVPkcSlbKcY<3bJ*9-RP)#a%=W*$U;V^+wOxbQDA6iBGhymc2$Xvbd1mX^ z$S^Bs<-?;d2l?OxD&tk6P)G!Hho|)V!2~S_nY$EEv)@wZ6~~wx;ofo5cRUT)c>F*g zs-D9(-}vr-vh&&Km9Gg||HPvZy$_tw!w5w58DAe__^>A<(yaiV1e0>Tzj>h7Uc!(k=!Eav#=Q%tLZe zqd81wqaX5t$_-}Q6VNOgg79A%%y%>TB9^fWz#kz%Yn?cnaR5ZPK;SWmRX5VubT|xKj%xv+SG>($o z?l&5J+Bf21GVQWBiW!n3O9Si zZ1h_`z^n!{14G?c9mWIXb_<^as*Zz5pl5~X2h!{4($Q9Rm}#zhoj8OO^(@Ac8rOTX zGyU_*BFq25fTd~AnT_97+n+Wk;?rz>o}Oo8^k)9d$qW=Efpe|Y#l*@Ua3c{}Ab+4I z;}%|e6xu*iTL}rzRbzgOR@lZ37%1O~7!QT#Gk^)h5kbi76ob zk%o|{5raybRI9`cDLbD_`QYB`(gSUCV~wX+g^f}7>8I9x0ne$6I5+|eNG1MCG?eSC zbtHNrObxZQ?z}N3jBrIdP9ER~E%Nw#Xpt`t11cLmWflcn^$qe0(G6o1FQXGd#h0*v zuY!}H*0S>!1H=z59=UIeU_7#w;FMqo`o4p<@c$?tfm8C@1kMX%QvRT=NChxtAli z&YXlTLllg)+6Xf=;QD&e|6_^?OyMax-zs-Wj*a-uolBhmUt@Gn-yJBjUaJ?s%ng`T z3o~!sOw1SA3ohVj`2JQ6z_S{s!Th8+yudiFkmfJ5K0R%dx4i9qdeS6#%t9=2 zmMfMcaQs(&`r4#aGM_rY`gF)v1gvCy!*}i|?H@K3pPq=cy-7KMh#o#cLEB0RZ+|z#o3{YwBO1>Vq+RkNAJ`h;#0+T&{CB>2%56jgy|cdgEGDzj$N4~iGiJ@B zKpC2ZkI{)E`q9wOx}tbyZ5PP!V1$<9S%^+WSAgn@9jtX_rr8d`Zr`iwm_-UAM3|Sy z7NYLdMU-;Dyad<#aEpJ5e!@85Q{mB_B?iI$G~gXj_N2gsw+N!Sn zJ#}>UChV*=H#JoEPZ{JmJw0RUPwA=XCjxHB(;}!Qf^D!4V?FOdHlOSAU?P}|Kblod zjYycA{$<_0j9^+kV328F&ioUu#RBcmMFzixfDv@<%dp0D&0geo503xX_EW5tnk4%I z5(-JHar3dZD#U0#S>er@3$@j_d5yDj))%}qKe~VDA9K<6cJQ`^G_7ua5RMH7VI13%Q;Y&bjksf33MO28$I*|w-J@e24k$PsYrRW|Y?2?q56V}NYs2QF^&G2*gLg@gWIr#L z#1BNC$buKD6=X{X;aL27t*6w+ex&*70f1SLL5(e~K43lZ_tz!<;% zs(#}v{Q6UVgg!^)f*2Q1^9Wld18oNyhm)*WA6C2}JB2xZg@pmN@;MKIiT6Q7F!|Qh z(qi8ktaG|Y1ikET+PsVJ2|}9udF~A0l3&{*G?eLev`W?wbltvX3~jUwTK>VWIiH4X+3hf7+kzZhwi)&Y zC5~{*_A*}_Xpw7)jMQ-LlfJNem;>_}g2@*1sauHX0<&lrcK0fqcJ)tH!R*nD*+Q(3 z2u9tzspY4PP)LK&hp&vHxIf&!x}Gx&w5@y0F+wGFve+631;-7~3v34|HY~PUM{8;Xo1I0H z`WEoaIbA9+ov%3=2OpvGTIE!j5GJIZI!y>^2Wuij3ij4S3llqQqKk=IBsz(_2Puw@ z#}JWXbWkmW^S^=4*K)drkR}GQz&fra93X*|EVLtNXjupRLT{ZB!Dz z0}Bc|*Qlk6$vlB;mX?|`lp44(2)=YaYQ@w&mi&_C9SHX$EdIb+|LmJ2cE;%V2ez@r z4zwQlqzYs8Vs?)iB$seNv4`N;iIz~IQ4w#FcBHUAGNiI~DJMYv-qYzAQp<|$o&U9X zCqJM5Azr|8i_JF+OPmpP5GyEYT7-*Fy_T3L6@SvJV|Kz$@25=|5P!_;> z9YZR-!DzjsYTsa){h;FhJb1mnnTSIBe%YNyJ(9Y#_|Kui)^zn(4&e$&HbWQGk#y#!0G!$M0Hw00e7N=?uH z+xC^87!F_~c{$y^K^xE?0Esn1spt1#z;bxv=?YO=4(wq6B;c7ozFwHH>hcCHx8tWq zUu}j`*A%(8%J*-;&T!wO8alwZ`sNs%Y#;?`voCmoY#Bmiuw>VngKc)4Bg@aun8`MK zR7rNd!RcWlN!V=eF~U`}KZL4kbqDefTyjaswMyar9gS_ZG2{T1wt#g2q|FXHL&TyG zS>66P6JJ2-v}1`&Gwb@p3ezZtZgQE@JyLx+YP;-WYiv+8|>X6+fY(PykNIwq`&U|jY zG(_64djL%8|I_>hO-IM%8xTlKkRsIFE^iWWTv)mRN4nKk()ha`UQmP_U)k3!v1p9*>jcOTvt=1-Vzz+!HJ z{Ae)$7Mo93H4^u?@gWx*kF3Q;AmkmHJMrk24_6M);lmS>(LVlYL0WQM(p5q?WX)Xv z^5w|1b6h?oSx@J9t~;@a;E>c;;KXRB5-CzkF*lFxm`twW!JW;1!>VWTT2bGv+r5;v_ZfzTC&`R<| zC5#g`^pBfuLrk(x7wg=vby}p3M#}G=S#Td9>IgqPl-1xvH?Opvl6rj)suZQ=zq!;> z;{3BqN%x;|sT9i|aj6VbcOoTxVq)kqL<#F)*QD?x3B~1w$8it~`GL-L7@T#gSOFd4a#3HvZj_4ueNWJ(HO%8_{Z6^zdQokj#-W075hGrGbK=?GtLBEJSRNmKtU)b z2KE?pk(_xCi!fpiZiv(C1&mxDn_Xa(#fF z{4&&9#9H8u=zQP<+{3}?Q+!68LPmHeBg^Zcsh2yhbX2E!L(#)cD2G~##ui8NHW%{nTYxro4tPDsA6_bfPWrj2K%;yK#c;JQiRZZT|4 zLQ;}^;Jj(rDl_+0LG6FZOZ`9H9;5w2nMcrl+il!M4c6^11XR=R$`i!GOV=E7PgXI ze$ZtIRI&@!Gwl<=qAy=!=c?1MIuj>RB_*w2=~C>6X?Ce3RKWb_ZiVb&*>SEc&9e7V z*2g7(f3O!FT-*I5RZK@@n)@OH2SGh}aQ93r#THRAyf1Tt6`;k1rE=m#PipuCe^X*? z^$R3R_;L&iW74obkxn~8w2>Hgds(?&8|lLS6IB=9Vx|>&O#z>R5+vMe^Z)wbD{@^rFpdrHodRCo0t|usod*clcC= zAyr>=LM41F+cVN%2%L6WQ!(xaKcLSpD5ia%?y@LIG3){`m#HjVx zxT>dbXv>F)wqEqMU+6qvi7{&3GZ&JrohzV9Cc0OEoc5F1PA|!mK(|}g;DH#kr0RfW z&6x!!v9I(~^DGsoP4|v*5U}RBZz$*15K8kG)k10=(P`LG;M6zz=??`saeO?@IsmEL zB`_%0%f28`GwKIBM~<=I4HWTouYr+B3cR9dJ%))b3xmS3Nn(^8*Vu0HV*57S-D zhxwH!^lAn2P4&4@dm+HzkRBx)($dPlSnt+SuE%ggdX#KP&mj%9C}jr)(Qe$f2MH;s zJ_t0i{)hDn{OC1HA`CHlu<{@u%1)te8GBOqZs=o^|DSjycS`>KqlGm4C`DY{f4X8I`TxFrILnk$J4x%)P-oN@@|*qo3#0wCc231 zStL5?Z&IPiv@piV6ESq!ymdVq1kPFH0HO>&LycQ;BY-~x8vzOf*bvxdzXZfz{AUFN zrDg_(Y-W6S4B0$rCA{-8CV8(9#t$c2KI`r|xQo1O)#lgI&3TntE0tNCss^Rt2-ZAs z#tPK1L4MARdwG$CRt`q_Axjun@5^>@q5zrt^fGBEVV&>H> z5@L0_dERrrMLCR}%0`z@r;KcW>dm^Ce(J0qmkerw_1=pK2q!R$qZmeVx1-oCj259& zpGdD@8gP|jv7qcon`5+7A<4wOM|?UmKi542)-WlKAQrJ5bKb2`^t(w(HhMn3AW@{W z-n{`DX`5C?=^+FJ1#^_r?2(pIa-HK_o}g50D&XRhk2r(v&%t5=yA_j6_&h!xrTZzR ztYL0&3k__?{8*MNn5F6UK>RWWL^sr!&(O&BUfh-ckm|` zaDNEv8@ps;YfhWyPe_zjj-L!)HntT&LM5)(VqQwf^4m|z*4 zByZ41d7^f=q%Gl$7O#>CFP7a@ii@#&$W~HXPo7A7lAdXw5R^gMyfoMvrnCaf6To#$ zL*!kcglu5i+^QnaD6l*M+$m{Xf-55hV~XTw_Os;IE3iBP+<-LAB%e^$w?zkG@>u7G z^({5cT1TF!)51DD_d~^=DM_kYigoI>PP^8TC+f^%9kkthCj7wxKSqBu`YD z%}UvD>LS|^ldKbBor2bxqjltoIz6m|hPovm(oD>GzYGNU=kI)I^4f z+cnX`#7s?eF>#e93|qY#GaPb+%4UA#c?=|a2$lhuwV_t?*YK}Izt%T9=tGyE=%5py z?qTN@Yo63ylrGmE%XnsR#Y2oZM8>8gUou#d@lgXsGtX=$4eHwU@8!f!FUgV8*Lj=? z2{HB`qfl(66uoc{Lhwm4p4q0O2oVLp$5?F(^uvoOzUiw?>?lnL)#GKqSa7#97Y!{X zM_xVvG$CH~&`TI3y4%KyW)~bovs?4>NzIt{LTbGI4tW^eS68hh?|GYdKrrtS}MkN5L={eJ{hzCEKncyHOK)qUg$D?eO21of$i9sppEVx77nWx z!0UBzzYd)W!LFL2CtKmHP~pfEj0>U17%cO|F~U_*hGYb_;f-pm_m#}V;E*TkR7)DB z!U5xHn%-7dCRI0tR*A5r@nO_45LUvpyHJ6d=(wFdNOyjRT>EmaBoFg}xRMlUyB$3& z3H4BWVkKie&IejiY92ho3F-pV;(j-6!4;wHVZ*uBy*(^BJIKJnc#`R;U2UfE3!w$64(V`Rz_gZTLm z_cS@p{QVKcn=3?P%TNzk2~&y{wvl<$J&d{j4s1e^Dy%_|Ue4OqJ(l}mp2T?+3;X8Q z9W4+EH@$EFV8tqI7=42iG7O-pbtxH9|G@Yp?iigqcDU|=0l{GiZeqEEaCUpx-(Qb= z9gGN%-vCP28aTz(|NMd6&y=EmCvBaHIn+US2Yi2a;Qh4LWf|6JAeLrEc+vk@Nh1f+ z;rP+=N&JQRA!q5BnSl)#P>1?`?saEs*v?c#d=glsLRb>$xo=kSH3OZXqnVt!eM5Tr z3D}6368EvTF{yiZwtRfzFTk)qpsa#|PGIfh;jx zu%i*Lk`qTPCQI1LzxjlrncpRG2dlAcm0dR7mH1C7j7i zWP^&U`aO{T$t6wFgM_?y(pcLNxb&!*IuRQ$6QxP$C17|63Pfbg`}|@XxKYpTVmVgi z&I3T6sP%K_E9}>Jfd8k@SMd1x>YlH_`AH$(jD8Ewy&nex`}OjY1zB0o;0ObtMS>&| z7lw&BT(eJ?Cyz0V%l+&r6W4#RQ%^#rjoAUiMfM(s!;y<6%~$dlMhM&~`2_aaQvt3% zREDhPb~R1xsZpC=W|?m$Iq}S@Dq(Bah^ucBZz!jnwu+r`i55r zkM^iwj-F|st6Tf&*4%j)2!-?}%riG?KU5vn-U4?eEnV?Eu(k)Oyd-{Klu7=;oj^pL zqz4Dk+&N5#FL=DlBMht=-mg;}!d3hc90`=xQMVDV#4500gvX%9HTeiP4#^MBS4ble zf)frJDdPulk1W9*kV8q$8-&$i7(+zkm)|zdQXIEAh&Q=!Nq$G0SFFKa!nq>8{TiRNS{t<< zsW8vMP--aG!uA+q;*1C)MA`_6ER|-% zO>LatV(T_Pl9xV6-nrM?ygDoyVSmb{DC|0cwOzOt=MUwr0e5Wiba0x>)KJ6*vpyDW zC1~c!xlh-X(1AZ+!6|BCOVbvG=%=Uw=NbcbEQ{ftAQ`TGFR@FKH%d-9f=62sJ19-r z{PkV_dYYdc^P)K{=7K@-m?_=ROObSJ3u3pA?nV;yKIGFJvUA>8%g*K?JCVTt@@r5vnn{zoYEr5731L;l3`Y-O)0HL20?uVw=20;)SMYA$mMm!8a0U<5zWz zQdAW54rkMmwvRz*A-X9lYC)v!&&Vn=CerqFkdenYLUu+KBEZ#6U^7oYgO;@P2j9e1 zR*rJec_llQ-^4GBxWd1>LfQeSI|4;>?MMCy^vtOdjKD(le8NHhJvLRRqqok3ezt%J zgFA}(YP1|%^~~>YQvJrUVu$Wm&Z&U;Z3xCv^buwH%%jfaf1UZuLqT2M zK}-_!4<1gLRek*tFyl&f1hinn3!w_wEw238%G+3=#brPt0R0^#k?|QwL_U;QXI2a& zN)XB$c(5Fit@`HFTS5`&2gO>ik&<=-ItI#4v0|xAlvr+L2Ov{ z2It_ic+U(kGmu2H7itt0tTUGgG}=W$A?;ZubO*UBMbRpnM#?iJT2>aWNcVX!MXVhh z&O?` z4t+AW^I{Bk*1o zhms*;ai*rZnkA(NXz7lD(hN&QTGD^~iV1yob)@QKv}|6*DvmJ2n?5x}A_%p(Wj1&e z8mdXqvh9iZJ-xjJxdMT#+LI+qcHIPZw(9niB#VC6L7k9cmY0w=@Lt3Rv^bZtdq3cK zq}gXP6lc|pg33B9BaT4& zsenbw)JMpXvumg<>*6{)E4sSZ>{v)~=Z927dy8XX4=vp?Q2H}t;l#f|TF?H9w8Ah1 z#{ymgh$U4%0}PBW{tI=cwckS6ECoZ#Ak`z?+Sarj5kwAq8TS4#2;GjZg ziuh9e)l5T=0jZsAxK7_r^p79e`!EFMR#$=pf=*rsE8d48w1!d22}+S6ld{--<%qF8 zY9?8)&7q-jV_KJ_(Rz+79?ju~gg}1N7Q{ya6pw&zi0i{U^Lr6m(u)BNBct88a2iu! zGFeiHt_wYsq{!jj%dlfKMnV;(R+jvP>-0J1Vys07IJ;ubKH5c?J44$ppCop};8BRa zfja`RFg?EeTw=>cKRQ?nPduu7>d{ zWhED?;HPtD3x1&%9aBVh{Pmj3&!9N2_ZOEoCZQhAe9aW@w%Iq0AbN~5dh)Xgjbz*= zzrel56woY?Xx|HhTjck#1w#ETj-baQZL_5pYBakf5fnek?L1h%!Y!XRm?7)J=8%C# z&Wu{t$1P)6xWUmn5q8cz`Y1aqf{9UiuGo}pRy^0skv^Xj*6j3}z1LAUAT_F?7Ja+P zBvTJGUABt|*Acoz$d)E5*~Cu!+9pI1*2N3Uz8v(x%i7pb^Pt#-CR1RZIMv^l)F7F# z$=OlMMY70p5#d{pj5t8(t|YjaXq4A`5Qn*2EAbS}I2s|O>H65cbi8yRdbQ1Owu6eG zQjY)?ug(#bed#4!-gh-QC4D;h3LP8Lsn0Lhl#H8etaH|JGhQJ5ar4e!gX89|i_s*U z9>^;R_-wp4by+(}$!b^-<1bG*f(y2K!DcchX8ySY=@wlQ#gXt6VQwM1KN@3+-I3Dx zQ&&-s##j8%G+W^Mxqm;rEV`5FGJTBHV{W~xSQs!tT zK`+JxBIsNpnt6ksY@pQwS#i?HRq=OoRRjMnw`_9punaK7^Ek@&RiwH|nk`@4!zwnu z=l+F5vXAa(xbzk-eZT!T+@-Iroyh-&xbzYBU)ZII9{(

C^1Lp)PHU3lW!A{DU;@ z-`9mvz3~dNCucDC7GcH=zHFvqg`tA%IqCTN8^$HMwwecOI_)hilVF|Qff4`GUH2Mp_e}CmX6y$$|kMGGJ`s2HIwedYP%knSk<2!Qz zd`GP|zT*|&n~Ei7HW>il8mo=3O7UIq<9ivO|FW82)W3OOu7)A=pmrATqr(kM=;)#S7SjNGiGhw=%JB$>yg08F*sbgzzr!hT zZje$(z%`V<3)zl1ih=wA?QJ$Zd0d4!HaEUO1cPpFX8itD)CZfuF|O$~ZBvNO#SR0v zDXi=8dgCtO4cA+ zA5aIsO{stPSHkvtk(m2QF$mtw;HrgGrW#sh4l~@GHN8;aMFSyo5`Ka7w1*heMl00s zTaat#W$Sub-Sb%jGX96a%{a9^O=JnZQx+Gex#|65i`e*`XBP(b1C26&^sB=ZBlo?cf3@V~(qc zvl=<9QdGs-EijhxT#G>$@W`k`bgTOwar4!o1RyP-EnYlmm0@N^R`hM3*lFegtq-UN zegG?nOkC%WdbUW7$ao0G;AT8;6xY49c#+Th0yqz(uJSRgyuC8Z5u~M)IwscZ5vu@( zPqNd^JB8)Nc_pK9tc3Q1aZ0In;PQNkV=L^Et0L{}j`RPX_=5U(cdJ z%}%PVHIE=mRJ~~uGwMvn5)tc$koh(b=*Ad!$LBCDV$43yueAZCl=jEg8m1aW`J?AU zU!*1|_-hXa+P|ks({aD)xY^Z5BH>}3`Q&rgf1!K<>Tyh<^Om$D=oMRG>G?ehM%k4% zoc7}{@PcERn#W(9rdBc)V+wy@Cru^$gX3tBfH<(U*;_*UAVAI>1X3(F$whN!Ck-a~ zV88RaCMVbBIC*}P)RnEr4{!-ei9G|+;jhlt6>S(34DT3NAEKIz*MmLk`=R946V0Ax zv*TsZzqQ%^ttSe7l+9Mt!SjY#w!F3;4OybabJ#!5ACR=V0ZAiMzQs*EAhy;v&zz5K zjuOP2H*Aj5!kiwPQ%c5~rDVAFAw~L+v+t!Y<;>%Ga=qNyfn@)_KyJTcbjuvv9w7w^ zo5F{^p%GI~X4>r?LRIX$nrWiH0ap>I%te8UYA?eD5T}f=A2xz)d>fXXP>B@^*#yG( ze*usmSkm?~kjR|LLVj9Tp& zA@&d{@IciR&uMRpIjaxeOlK)5p*P7v;F)e?EfQK~Q2#dWnc**3vW3dll=(O=Z2dp% zeF=P2Me_du8HhkgL?h{tL}kOg!R5SjnCs{75nH!wln-T$t?-SFYndtKet)z#J2 z)#n5qpE|s@i^)|$kA-!Q<#q7tv4vwI6?64_05E^~mR$GXeQjjvTMtJ*9@`xsN-tIu zZ{?k9G4%Ed`ma*=DCxf>>A!LV(=TEAg@N>IN&3xw^i$7E->2VdQVGBRm+do|CB7q< z#r3fHMP(hMamzkUD`_)GesN><2fR`Fi@yZ zCX@PB-VMvfle#jT6`|WNV3NIKL0zBOI*&fEMMFnKzA^VmbSWDw5|rAToX)MvuRs1> z&d?jErK?@GzRQ_y77Bogw&&ia0`co~dPRp|Qc6EP*w-5KF9^t=;(KjRLSm#8ml#Y9 zEIxJ$(GDiPHF|A#As+9l2{3Sh>oLp(*cCtG&@vIQlPIFaZ<^n8_#J3eh3u>F5)FRX zx7O&zYZLD>&&6mck3d_{23E+1a&H1PlowW0L!nl!kt)EiVeu_Os&fLT5-A=g(k4Z< zUjSnHD;fj0Ve!lJWE`nP$MEL0O^}&WD}734-bZHMBc3wv%$XA@*|&Rb*BSHW8BA2G z?R#To%2(-o>TI>M0Aq;5Wcj!mE4YimfE`u#Y>SB{@D*#7*n@+v?Xe#^ov zWgq-h91`6=UFDWFaxD_0KTYmRHL?cTqDEBOndwUOXv`=l=jU*E)IJEZ z&u?&9^w!~dxZOe~Dt!pA?H7FosXGJTtIPeJZO|wg z$Qt%C1#(aM7<2ExahcFq*U^Q-vGdcIgd&)MFW$mC3wO4`CeGxwYlj;y5Vp5^ZC||s z=D0zCqja1KB~-fCcF~(M9KM=fj>V7zVY6w6V7YX<)Oa36=Wsb9_g0XU!&e&R+7XL} zBp=d0>Xmaw15V#S3Tn1hWP%L#$KOA4w^k+6<)GEnqn%Ix(Nn_Qv5k>EyIo$lYJn9^ zr}l_;NJautlF70Ln5yHv;-*2jQ3@ z8aRMY;5iB< zvmXrlN_(z;Q>9e}(%b!=pI(faUc&zfJr-X5g>7wu!&o3(|`g&8Q60ibV-r zM`tXMUcU~p=rU3R%u6`>)C!;ZCCO4YFvt{hM^MU&#z;y#)Za5BBi6t%Lu~8^-a^FH zWz6Su?P;o)7lknxz>i>Z1BP~d2lHph4Ygb)+Q{5^lt8vdjuO~OB?^%SQZD{=vt-qc z=+l1wy~R=D{bwNpXjHV^;~5Qdq$+g3G_7yo6o6RNjCiQA{1;!a{G0z^`LXB!obt=u z{yF83{|)88)L4ET={9mC>pw!3|3fpC1$CkXL4J-bOkzV1!*1YL35ew zZ<0eFUt+c%Ff}*@(Sib_K-@Rq6y=RLuQG7|CG{a%Jk(Um!;bbK7ccCjV2Z1m;Vg?{ z&XEzwQBy-3&QUYwsL33STiZojf?FGLyQo9tn< z9G_7aN~G=LNOyDZan5-e^r!mXW9blx`BaJ++~|WQ8eTS=G{d7`y|1D~!hhjz);vq3 zTnceB%RyeUq1I@qD?o!p?E*wyqgYpBZq zqzz2rM7m@V0|RD}JTBd>5wmWSEW(B$L*JMTC8;z<8-iLV|DPyPF?TQ$=SxdrUchvE zKyq@65e%yEJ^cQJ>Z)V{111q61&JOVjVmjSnI9_zk8)qF&;9U?4pE8-a0|D{flpgXxLuNiwBKEjktRmV0=+$F9*k>3Z8ip7k9q zOQSbrPvhtXhxte6NtK?(Y*vL%Rc|$2X(u&odsFNepk%0DnGGFV5fKV|@?(jGyP}vl zu~4SlPZ;EcJLECdhj|XU2!Y=lWdkaUu+iQJ12L)$IIC|ZvjluTfo|`a;V7~F2HwO) zz4&z)8CIBu1$mv@zPPERh_>s&k}Lu$SY+Y1QkXNPoXS6iT)FyhilAbD8O_bhVQ^3a z_z4TFFdfgiCX4>>>1$LtCuw_kgsc%wyxETw(AIRV9-v{0PZ~q&<+$qk5yx!@e495l zKpmv0u`hWsQX+y~uRPz17B835VB1oS=_@^cV-+79lGDf{k%ku&c}kV3%CNAVU4#u| z#-MUXwx7^fe~}q$#yI$}XVYyLpqq@e^N>4UBiW0@Z_>e;tmW!RptkFhgA0MBIW6wr z557)cV2YQ)q7OcYPvB^uw*{0I0+b!7Q+BeBE7$i)E0`V0irX#4a0_@p0}fG>mZJ=I z{}OU^bxQTszGQSk!IcypKkj~RvM6J+){af6LsDlsI1{&7NORPxb}AVbWXGys0S^7w zY$Q)QCuuih($+zeVjmEv_UV+f-U+*N&4|zECO*%K(KrC(F0IP+^_hWyByp zm_FpA`IoU%eB>yT&QhcvPLqb1bmT!WwdXfDX6JE_|#JTqK;p{RKYlVZ9ORq**s|fxW1Eu+Xky$HJUod zJhe#1L|syaB8~1Jw4#s#il%X^Xk_^2Cc)`B&d@wZmV{qZw`<6aCqc-A+K!|Op@Gj&{GWD`|CMB`#kM>`bpiZIgRaj z9OW3+6KhCb+X$p22MGfoxjS5B-UXWimt+aY=HFLurjvKw<&N@>s?cS!&=9`G~B%RgUaUeBS~D6Ho+S(nm`RbQe)2KNV5Y+inQ_r!?nje`>F2P-S>f8%{CY21hN6mmr=JDH= zOb}WrX7HwV;v0pA&Z@h~AzaE#yxttCkqRZE2j0|85*}M9|4{V(?e-u8(c)wEtClC} z>(@cD^I?9iN2!*;)boC(UZNl{wVWV`+I^){v019L-YTW)AnqK*T{Ls9xU(tll8HO+ zk9qnIGjp~>ElBCqRpV1Xg0Vq1=@c)W-XJBqZU_;W{U_aAzbg6wu#YOcpEDCt_>xzL zq~fyA*j0<_5bw5}%VY*SPz0~-o2O+#3MhvH%Athx0&{z-2}8p>Fi&}rB%WHOt{9Vv z7a$7f$GlV~sN5?+S-f^5>&*S~9_F7$;YPrMYP2eK3<2G6DWXwhpgI*lZYFaB>B+_+ zp1qO;DBS=W0uW8=%E(AkB(CHE_~86GN4qJ<8OV_$vI?T9y%&#!(>IWgD&3_HFVH^1 zNO4iEcMkF?@oWx(A`u~Oh;kNQBPgcV{EWH+{U2vm%0#=&d> ziZe9p01|%1K*gA|hK%4%iy4YgI%!xsRg%L<=ygR1`a()0>s8e-T8s6YR?q^3!n~ul zv7&K^0z9?`S4i!MByJY$CiT;6NiRiUmiAlTzycv#r^oGplagMcWtNduIvFyk)R}6Y?UoQOZ0?Yi-6cw>vN>d@lzHIy|zqvYqRuSf-?N}un% z4=FP{GKTX$^C%3I1yfc=5v)+SSh>;Hb{5k2WIB(4433uWv|!&xIp^;diuj`9M9Qqb)o#b=;>ejW9i?2 ziUbAzbN^U+kDk8sA4@+%Pk)#SQue#c|HG4t|Cjz)`d{_*xqmFZm5NRB>l`$_Z$A<4 zOCa521Dei>o<|~#!PxWF#Kq0Xa@g)zN_L1mlDEj2^Sef@R6czY$FZlC=)nk7d2RjX zpI?4~{Sh$%;sxTP_wIk1;wB- z8}3n{wk!^R3_qbYg|o4YR8s}(PFmH9O|t0tK8jVQW^sB}ni-bzMshTGnz^P~bZ~+R zlXWNS==Nb1%$c%hdi^w$e3ge_^hW<1;=dtjZDiPta{7(&>-?_xJ5TD-e=B|8FKGTo zeyn;1GQN}#j!*dV<9dR~p+uiOZct5~z5##mTs+f*gT-Z%~3yBuA)X-bQqP zARo!(^|+X~K#Sqdd?66xHM*0`qk2+MPo^=PLu%oqpQ+{WMR}x@R}RJA-TKrH9b)Zx z9mSt7jrMtg6;wuC;8Wj3PyNdAwB@RcCU^Z(W^?AzN$5`g|^i+nC%hruMwThGO0J=89hy>SS`_&CxsX&5ya?3S+LcBx9Z;8s<_ zryk}}H3|`DX7w&T#+@<-n2f_UqSk{8xO$X1h^7@xqv+z1EPSP990>9crb8W7mYo=r z*sjgE@=(Q6B?&L`AO_st*ya3-Upy{vjU0`KHduqU(NZ#8B>Etl2T0-W#gE0HbW!q$ zUC=62UUTB8cr)cMXCZJoC4!eUz7OE^4W#G(l?XQfzV-#(G!M<~A{M78D});^FI82# z=zf_#m)%puZ$5YSV`l_;zwgUI+;a`6TEeDYOy2LPxwPXDOUSQ^eoT(C@(olD!o!tE zmU@aGa2GQcisfy{v|fv5}HXf z4P~_LfNd1cQz4WJCRLQ>=s2JJO!Uz1t8BTTs~Gd_O)g<+ zN`GYf;U%m;s!OvrQ+CPD5)qqL>ehzk=TH~vso!#lL_>M?3AtKvgcd`e zE3hU{9Hje)|LiRTO{SY1;_#XKxro%t5Hp+dWPQjMOO~lsWleZ6!y#wPacLcmK@zV4 zA4}zXQU88waqvD4;zg!GuTN>_sd)Vu{WUhLp0~_$?oAA$pFRdhD)DGa zOk`g>HyzbnL8ALTQqgaA`gr$OubxQZ+5{&tmsVseUJ zD>%cT!hXCCg{CcF4v+40Dx>T6Z>lJz)Gst9<&eG}C1+zg`~Y2H8rk^ItuToS`1$>enE=}9E9t4wy6rEhD$A7+8{KR zie#!+W^mhm)Epul9`Z=0t>7+{j11D)sVy>NcJatz(tD0Q^=VZOB}5$UM1X&;fqsrAfgkq-~P+Ilb>rsCk2G(JFS zQ$)@HaI3GGnc91{ZZJry1>(S+x>Pk78JOZX2?(_;q zHUpB%SO7Ut4@lW!|C>;i>7bIK9w~xx(Fga-Hj;A0b{yp#l}JuHpnM4O86770AYBTN zQ0my!YVEn}5r#`UeF~4fNeNKKNO6~*C4EMTE}xUaW1W`DUrK^P2HDlE^q`5pAY7?H zaZ(Xbp>&@d6t-u}YxzuafISR2T83SG0TwYJzOtVu7L)xXXUKpjet#__&1D6Zt`eGh zGsmV7a&J)wn2AyIF1mls>Dl#YhSBl3tKrO+BNi$?l#_VyfrCWoNPU~u6!~WCcNp}= zcV4g}%I6Ye$eq}(vl!ThmxB4)3RE_aBrL9MahVK{g6K5S7#~u}J=PDdg#I8V+PL-R z4)UFtO704-Xkis2CR-)Uzf|F-Yt>!+tXp=k?E{!ziMq3^S|=`pqutKpYK0rI-tmwy z)MWo{Uz4vpQ{K7OiDy4LW4Qx-M|y4ji4OTBD)QS8k$h!js&GC=xaulIQpX4l>&-rS zly>vU1ExmpN2x$T#5<}HYGOE6gryD|B+@8|bc_-&q-Z@&QY2AByp92&EZ40)wv4t?yvbta zYNL2@ek~7K5}V1QWf$`y60`vt!T+VBUnIJ!F*L~M0_*Xxnowa~zuC}YX`)G~ECsN< zd8)ps9}jD`7kpEFAjaNlcqJ1{7j2LvPr=zZxY{3K9YtV4fCTpMA{rFScb`Rlu__J} zVY5Jy3^SI$sMgQ~szlO@q%54AEtW<(MNY0KW+F}XD477c!tIY_6LBL~6I!z$#9!5q0>8AjZq7b<1`x}~Uqdr|1vSTDM&qCIG@hOT%Q6R=D z;Y0@y{`CFzYW%`>0X2(Ms_$ca$t?~oDEeOH4>Yn>!|W97w^D(SlsOI6H-3lY9x4F3 zNt}nG*FX>3CqMC?086MlY(%V89MZG08lz(FwIE1bxn7$qp-QZMoH^Im1@!cgDEVVw z;ATx}2Dr~j7cDPBbShFc^|>9|%O)8m?|1XqeweEV%!wl55{-kQDhcJUsRC1}$s;q^ z>$1aio~9T0LPF(=vAo!bI>a%Wun;kHOaSJu6w$ymsylX-I*oqQ zNfKSy^}o=zh!CXbDYa$3ynGXC^Dwh03*F0sDrCnBwX9e&x&+bUa~yc1BH5{|7{7c> z-03mN?NQd6ny=TkcozE!z2vM8sFQb=gC2p>a0{2SBz_a6*N5J|`1g^-Jx^1Ai|Hb2 z68^||lkp+wP&bde^z^=_(o8u#kvxA>L8_Nbd}!j>G-_)w37B=ks1CA%R*J1_$ig5L z_bI9oUZhw7E6eI`o{xZ#s6dfRWTUkZlb_ES`f@FT zym6ZDaJaFrietzn0L&n6z8k7qQA&Xsfj5AV*Y<7?n-N4~4ah6)mW+)C!Oy zk`yPI>L5~l0Jq-ZdAJ=@Ib5*e(e9= zt~Yr8B?yFK;ZgpgDUB^e7u-TgmGM_L;*~H_?r>NnxAi#`idVnNrsBn;_JbU0M^!`0 z@5(-dy;Rlp4}yTIEw}#&^SIVXq#t26)u7ZD-L+ENn80tc7)q?L0A`EHpGZ#=@K_bzjIc;I1wvX6DzCmo9KMS9Aw-0|B zyR-;T7w%Vyw+2!PQFI6szvmJ@3tQ+Fp`E6Pb5%qsYZf@hc*xN%_S{LDPMpfCt--6E zXdgxEq;T~0rQ|+V3i>rtysryhIe5T9Kt|Tq~1f!8++O@jiL`VmtlNJQ6}rE}~zvX+^y@>t~51 za4ToW*+j|}z({cl0PEYkSFQtfZvL`2ZRW%i6&0Ivz~ilO)SP2OS77P z@l29H#*HZB3aSg;um;)#`u!iY+Z3V-2XCs%G*Ud9MwMAANIh5wzy+>GxrVsGVe4*` z3Sxv*@jGEi#oG|{uyn8OBMvL3ALrC+Ra^OI#Q5YR7D8InOKYUS+8)UtmW?U`m8csE zZv{>9VjSDjhmb6yP<`&8Ld}zf3NYWk{sN*TRs8}B#*~L$T2In93TIh)E9_J_d~0G> z76eFrPoRACmTVo}#xO{Nv1FmLN#+tLhb>vGy%pz$Jl1b%T}@6UM2Slu2X!c5vhZl6 zuVlG~ktL*&oAol3hfUFJ)E2ZygT+yi#vnSl*N90TxBSo&ag8^^1oF#+7NS4*Nf0Ru zbO!?0>uD06By1wvUZ}tEj<*DsL=`JrA%CJ{(c13)EI}Ccll3E|^jv6V) zWd>N&d>@xy&^v6Zb+vj&Uu65D9?F1lXlfNfWVpvC;&ST)Ou~9Cfbh`W2#VYa1GyIs z9oX+hSZgY)NdG4FRvdg^~HwNYYeHE8B!l_5G?gabU}gQ{3S!`g@)9< z!h)rK#E@FgkHomhka`7egp^_oDEBC&1_CE3!(K%Kii5O^DJ^6Z2^B)`VvE7J04V73 zR&fPirx5lLCLKgQdiR7)Cb7Cx$(P!YoOaNG1ChFobhtM-%2S{~*%6NNf@Y)%1)J~ac`ucN+R_Uh1lI6PyoB|_=MK0;TS%kU21k1ACY)p zh-7%xM1NFyq9?UoL4td9M1p%lvjlgMmf-%zchB3ux_ooP|E1-l_Elr~65XFCl#dQg zD4!6KP+o-kn|_!qXROCQEw+5HSSRWwxR1rU_r#VDJnTM%+A84MJB2q_DI%i2@LEEF zsI%Bj1CP>cf@j#C1kYe0-hg7F`tF9z$$L$7Z->~^&RPmW&lX4k!#vpnoTKG~RarffQWBq#u+#65*9HTod!rEE7FXa({51O8!@uAI)uU44qy;+03q~gRu96Eeh|j2AcUz@ z7)_**W)F6Wkxpa& zDIS?$zB#fssWfR9K{`51FkL~1g7`5S4ky`IL^--zROKLeY01du!13@1>g_!7mF4kO z4$p`kK%}FaE|1^iaK|qp)rQX$zx;Uk#qpIK8QG&}qwgJ}8+Y7EcA8QjN>2=qEM+0xAjJwUVoZHwGdzXqDEB;O@or6QN@$uqeaD2ODF|I z(?Dx*$J2&AJGBgGa{LZ@iN?zwdZE^HfLw-3w8Z{R=ga zNW`26h)yIPni(Gr2m_Bk61@x{XOfawo>Ynkv`X)PD)k1GuJ=C$hzq)xcf^bc>>J61 zpvlED^h4~J;hTLjaS_I3ssgE-JIZY$9SwB{xBd|a(11#~{}E_FtMmqht_S@MM@IIl z|2ZIDsm&G0n@V@l_!t#fN_8}S1F>c5e+>0EC*ptgSAKbh_6O<@H{Sje>+f&W5Zp;z ze~oE&B(VOx?YCxPW~lw!{CtI<&+v0DKWFiCGCy}aDSpcEXZZbDe*T4@ z&-3#IelF$b%lv$WpRe-sb$-6d&z1a4NGJL+{OrQdEBJXaKkM-G?`ah0dwzb#&#nA? zo1f3~^KpK*P)Pa$^ z7gJ1_F~LoHuFD*q&eFV7y2Jfp0?o+ST2i&_s@~1h%9t*S#`tN>Xz6E0@DP(%VB|Ox zvxc14(M0!Y+)UMT9}$jr#O!xnme2`#M)0#!xM_T{6aL=Zw+p z!@#}faDPU-8Mg{ACa=IFt^(RBPSHdbKE4M_+UWI;vR}Nex}qLT(l*hWo(E&-xoz9Q zAJI?HIFHWZeZ%8WS8J!k9nwC^Li+Crt+qO1s$5NaRvvYPzZ!O?Wo=9G5p@$bF3J8` zPy4(~TTdLSkBRcj#2kEM4dyP%R7j7?>UyFWAIs$_KgzG?X3-zbcV=byBbDHjCw7^S zX^1zovX~Mz7KQSop2)#V^|=1cp?AZFuo@VVPXz5a_+qj_d#KiL+z$z zztk$e7?WzW3WbTZO}J-iJe3zq510Z*r!usy(LhGzcjHxWd-h%fctG`73&DzD9(xmK;c9yB6GxedDbpdO$nl@W6 zWpv+Xq7K;XxeGPdVCKH!dX9T-U7K_5p&Ys#+aknr-)F9nnY)S`iW8X5yKgJ@RVo!T z%G^=b|LkS8Em$=9thksXsLI|9MY~%Afb( z{2FUM>n+VDYikywUO?q*B5un{;g5J6$@ALor4Lvt)S?YBN;t647wDP>DVNdVJ@;0x zZ7ip>&6VFcap$$&NI|o_E*qH!Sf~Drkq7Gad4TFxnA*hD6)x)V+*HeJyM*GFTZdB& z#BK-WtN7LHok>Q}6LvY(I{iEoHxx`>Qv+GQ8pAY;cr_HtK;@)(T!`mK5#M2gI4%y+hMgTRiJOl+jJ(5`geVAMm{dA_k3;41aOYy~J z4ZPL8R-+5VZ@c1AeW)zP>-w9Ix;v~C@%8)Vn^d{2FV)3&c?dN)8&IqQ4sZK%5R~$9 zTf0y=#ET#E3j2y?As`b(**KO3wi_W{+wH`lJ8X*|J)S7&ZQtRxehCyL2s)R?;UYed zZ8635gw2sjw9z10L?k_7_an$_t1C%P#s{k0{gKz!l0rOT6AdXo8=?42a>v39Zb@+O z6d06PR*}8p%3~hl-OU{Cy|HEAoJ@3Yjy+)sNp$aYT@s5`wxrDDke~M?*4gREtV(RV zGa+;RAa|1phkIZIx)p%X6LGcbgu5DG^|e-=vI9pG>r^H7L^hj8AEX_d)L$FdTF@VB z++o9@-wcM1C%KQpq=PNYio(VxdIpAwT)xMa7OG>#F~a8y9}XOt6IbF=1bYpbLy}tXmf`?&o!0 zZ*i^lj&++O4)*0a3Z=ZO;J(a_U{a* z3dMI9d>?E^e4+hdwz*H)p<1p;dNeIN4KjY&i_@T0Z53Mqup@=*BLVBT)EB{kSdwrt z@FVN^vl%K(^@}c=_r}7Plp|!A=st#)ecLH&+2J#(g^6Q7d%gGV!LqpBzOq(X|5}(_ zv4qchAE-^%V@+WDDN;XV3tMobn1j|XdO+$T18>m1vFQZD1_Ov?=*wWNI*QLg6VWx` zuafu%gr5!kChL>9rEGETih$FiNr$cTp$quU2WK8iwHd4pgtt?nc?~JLR^n11qyC zr-jMHc8D)3ZN0au+`0q^vHOIQj8@kF5av8;7hfZ)Q8IyDL0+*A3Q*R57s9E}xEH+k z+-1M-L-7pWPV8y1uk4#z;s?ZbudiMtnYgzOn5a;C&VDJ}QtX5cRi3Gx+?La*B6M-# zeQfQ(!!+9GbOgfkT<68cn#njfWfbUBe0C{YeICcC8!XKPMPsxY$^NEiL>OaF5!3f9q+!yEG z6+8WCp39jVJNYz<- z=fHku`iH6te|B9_lQ_5a*v06Eu%k1s-1;3}tLq``3%snN-a5|X_$B7(T^)#RFcRk` zfrsd9gs=JLqN4io4b~y;i%7D`dekea$CZcufYF1VUfQhhyS{i7&IBWNSykvJ>qB@!o3NByEASS3 zV!h^yF00ZuS*PHWJk4_JWW24iSnBMEYr7{dW=}CXpSX5fy950W6o-|Cr2uO$zkCP) zReU71V;4O6vJDYw>)9|sd3sj%vBn_=h8^q6+JAz{7Yp|J%5{O4+S%z>5vCumM0W1a zR$q!ty|$4zvcn(Z-ugku)x@CK0E-9T6EjA!t9C%$L?1lp(`!35gv{1c0?{0vNA^&- zdn>wG>`NjHdg*?SvMk*3=kRzAFqR|XI1n$XBtqLg9zh6H1DvR&qx7Qn-0Zb2W74y# zI4>Hel+U7Uh!GBqT{Vw#rQLVsyiz3i&NWQ-F*z)yz(ge5U}^z}la|2W0!>&Fnyv#q zo&}FT8vZaOg84hj^d`y_EyM1$4ZtgGLPNOEp&myac^(6EBs(B1%yE&-FvsQd(J1vL z%rgAZ+IYcv1pcgAb5V$e`Yd)$LMeU3!**f2!eN15pyHvz zqv4vaM#ahLM=q~5fxg|=MFh&NmmoUDwB!9!1Z}eZ(5SYBvfX6;9&c25^k?dA6}lSw zNl6e4JBTnG0jfng`@`K~2hTvM$`eC0JrK7Vhi%gXkuW_R_l0@MG?wv+P&;NAP$WCt z0NitzNSwcZ5=47?p$&3dpnJR5wv#Z>iaxqlV8x9FqUnB(*hum(V14MhnN(Ht+7=<1 zcT<_S_FbaJUa6V-3>iMf|PpvFOV#k9R^gG3$$6jxuL6lryOr^of~@ z$8<=1gx59?U-(Jm?({xWzI(73z~8y_O-qAG@O%zng#hn8vBuGJontkc0{%d~Ek1hs zv%=5^_XJyC{Y5rLhvVlR_d?M=$nB7K` zP>{yo-y=ImS!g@)oqVC&aWH#;I3A8mZJE0Tng_@7yu;If8vyO_tV5DUHd#EXSRd+$ z`Lgh!>?5dqkaj>%`@BrMFd%JYkhJIywt3sX2+j4jPvz$rex~ztEI+gOIgX#%{Jf2y zx%|9`pG)|8A3yKnr;MM%?|1O?IevDfHq}0wpL;oECO;?fa|b_n@$+MTe#*~1{FL#Z zAr{ZwOC-dIqBtB2dnY>H1-qBVn^_z zh>-~Vg>)w+({&METP%LJ$QQ$iw~Ghp7o1o)x`d7Rbg}FL`0X{~lf@A_4HA@ovZy(I zir9WCIDej`DDgmz^rJ*qI;2oT{*7wDdpxC(wHR#9BA~)%EkqouvrK!4!|ma(I#mw$ zy4dBl)}Rjr4u_PUcF^5@9_!}urn)lMm&Y?3@x?wmrQv5eFu&!d{P^qK{JN?9dQPY* zza-TE$(rSRH~r?CpBa;WvHUN9UvLiIRGvgP7UkR{(EB7g#CxS@aL$Zczz7Wik>ab* zaCC7OMvKGYA(8l=e(;>)m>ZCZUSS};V>-PMMeo+RzlUC%Kzgw{y*Ds|B>uu5`91V@ z;dqg$yvO#a@@6S|A3pkf=sgrjue;7)6GiXYdB2C=*Mamd*6D4)0F%mF^u+I?cXvQ} zyX)n3DSFpD`FrR+6iDx4onAXduQ8oSGHI{TK6;>HL#k`i?-DV@*A3|Q(WJi)_pF7c z_Q+r7w#TRAWc%+-cZZl6GSP!NK=ntYL;a(~Yg1WXm`fOeMuNWQnNO8|#9=Okz(G^u0DQFQdZs2k&I?^`DqKMN zWHF=$`MNzcPbxZ@MPT6HL)JvJFn7@gd=mVl8t8jGm6X2I&!!?6(?^Q^$4y8x|ClNx zf5fb!hV-5aqQBMbq8ahglzx8j^j!WsgX7tM9~~UOhw_gJj$cLi7D4c93V>xlBjc;{ z*_Sox;GXfZ=f%U@WBu%87NMd1sGmk=WPp6Dh`+UmOjTe=@7Y243BmD42_IuNafbA; zUtP`0W6WQ-uaP1HS1tKkkRiQi9@l^H{1;OGTZ8JCr-bW2c=}a@A8j5&81ttcQ(|=R z^t8iQ#01ClAg@Jmd?=;=#cTqM`3oxFzTkLSe?jqN-@Y?du2KF_t#pxBgMN??(CDig z(tB9`gM-z-8javyK=Hg^#{TjV{bgq11n_quyq(6EVvvc3kv`#@;r<9?83OU4;*d#z z=i*FKdC>}YG6l{5g7_#ge-mnPhZDv_EJBafq3pyuJIQ$5YIOEQh&D_h^Kj8@mTQm+~jj2O5;-m4q6vniH z^ohR|aYyjhsDc%ad@HPH9S{?uQ$=KMXLdDOnQ zn#`HOK6sK!#4m?8{sDe$CF1U z1_#H}ilyij9Djh)HwcceBz(2mL>u`}C;D51`p|8?f6SB@kap_?5x&7LxJg;CQZs zQNi)0lz#X5`!6ZqTl|U5@{!}^7ACaOJ~195w2R5-qwm2S?e_g9CLB0rWB7RAYMhcQxQu`38SsV!s0zTRrd<^L^-al)yT_isv9{N)gUh=muX#WKNVN8wb**}RfM={ku zbTnJke;3tM9-qz=&7?G|sW&H#9V01Hgj*l@EY`;r`;xT^O z-qKH*q+?1?M<;@lj$x4ffsRSwi-qK~VnWd`|Ue!O8~-Z8swW_>UCJYQ$^e zfg164kyj%g^T}rA2*|%cboCo<5-$Gy^@lY!`2SZ2^^Y0N%ux*a!~Z^8!}tJu zq=>)KpPEEqNbfm`1z!;poc{{!RTEo-^6xo7_@%+|QRrgDVETzdy>~U%ug9~T@STF= zBM9FhIQ}R$=!xp!<%=Tx*5LRGqQ5jazJT!V;P_<1PpAP;`40|`=lQ%C^AOBfKNgbj z{EvgmFPrNxxct)zAKX5BDu} z7s~YXuF@A{0otJYXx#zn9hj4+Hv0YD@IW`sT8O0dhc4}XB;~y~o&Xvq{>9m*mTUsGsa#+!?h0ME^d-Z2Zjj7yiq33@aQxcXoJjG1h00VC_x;VY1BUpBkxmpl znJN162cGB8)|!VPKKvT&iG_+o368yg-tnPFt`}|HZsORN9{XJ@zrudkm@S6XCjKD5 zt~*SL&GPkV#D4>G6a#&ah49tE%eR~)RD*xQ5?@=x_=0ScJ2?Gx!cPc}w-bJFa6I>K zor2>JQ2GYJ@wA;7yzA z+TiIIQu+qL)0Yyy`h%ePrxSi_aC|A{9}_%(I~iqn@bv6|{IWG@{u-qp96Y^+@STF= z4{-T|<2A1TApE=IJInPn*o)3*(c0OY{vdF`Vyv%|)LxdFV;J!ZiQhYew^wCh{3_r3vvOL?g^UzLZW|XaQdZ$9~~T@ zPWYJMcst=+1jlQH|0Ov87Q*ifjz2*1UmF}>N%;A}@w^u14n7_^K=K_OJiX;2;A4X0 z4{-T|;xXTF#9TS%{sQCuBj&+|5x=Gs+Kv7!QnY^8&?EWv+wGwqPq>50-|Yzn-fMfP z1L*h$bpG@>tV8Xm27jZ9@YO;3vvd3b?!Bj)y%$seDfi1)2e*&hf9^9=F!IaeuV2jk z1o)e?7NU=ZlRTS+CudxyyacDPT;4h7exyQTEF`<|yX4Gd&%ro)N;~Hx@#2>is>9$v z($ZHeb7h1Af)YNqcgp)&0YK0c$OiwR~D zM){s3|77qcQxS~#J%sNR93M^jHwcdB`I+jCLFw-Re_q=YbOrNy^wsTILizt<9v>R{ zWB+MiF#W>*2Cr>aTO;uR`5k5YLFGgHoy?tuDLsxBiB{&qoA7Sw0mqp6_vvqx!+pr% zexy>WCs$Qe9f-;;r#)Q-QdMm@^TDcWe~?u*iY_rT`~}d%^y{JcGhUpI(?A$kuY@h# zTnoUO3qW^;REe?MK!|RwDw~QQ!@TL>|Li>z<}!JQQCO^%@9-gM;5EWBND zE(-zq>$L3R*;-7?E@|262hI2Z%IH?85&TF^&oo)U#ya)1^H4X8$4&CW&`r^INcs*R z{Y(^`=zlau(tieInS1Y?bGZ*nKUrkfX8v(}fMnr6KvPH2FH-cQbox*K9{MRd{S-+* z#mE1%S4jRZQS{#hS?2%R-$OrL)L~_!X?fy5-A8{bH8;rTb=<2?wL3(o|K$U}Lq6Fe zQ_{!!57E!|(f>x(?=VGwFUT_g(|!;A0`aVre}SZ5;G^GG@&9A01C)iwcJO(kXf0HP`zpe-bi(6| zgw6CkwvT1vt5xD(N6S+8(i3+yCN|q242|D056*!1(bw4iz+TkHl@p{0?R7yml!C+o z$8+PuY;b)?DckV1&c|-ahfHL!PX>F9y!M+rg8+W9pXNgR8f=raC+I{UACF%o#qPYS zT6{O3b*SDVdr6C8;>lcokL@rfp-6wlumCEj^mcJXwmU_|SRRLHR6}`a{rAZG&;ah+ zeCbN3YL?=riSDEloX&uD^J%4YasL}s3$h!u52F)&kyJtD*k)LcjSg-SJPt)CAv8X6 zdUR&d!~?E|o8n=N_*`ps8&RuBTSd z&C{bXLaPjac%6G49edW#4c>=LiIan6D{?=_b7N&LCaf`P7?(?DA-TIWCd?i@NNYkQgE4q z+Y}TEo>4I3ZAt$s1?>usRIornz0y8?UdCUg;9>=zP;iEVc?za0I9$P1ugdh36pT{v z=yDmpM8R|g8!Nc$6^Wmzpk2YU%Vc<^g7+xcT)}T&miQMHv@5vxB^myNf>{doQ+#Ub zd+$;i?@u9uIE2QeWxfmR>9nWcyV(jKOGdTtKhbW zWcd9G4p*?9f@FkB7AE01_f?wY+!zU>?K*9P7e(sj|#R^u=Qt{RI5CuCZSX;sE_sMwADVVHa zl!6fo9x0dcb}0Cqf-@9MQSj^8GTuW9rYabv;D&o8{$T~jDR}Z889r0Nt_sdj^3wGn z{ZSdOgMx=2k>779I77if3bs=4#KSUPm4c5dc$ju74}J>vqEOjFf-jLJ5Z|*iymMPs{ML3nctp!8HoDRPf2C6deVV6|Aq| z=TA!fg834TQLwp!-##J3mn&GR;7|oSE9g@Eu6bO>FHz90V0{G-%#-*^1s5wgTfr0s zqZF*G;Niz)x-AO6sNfw6+7)c6;K@g2{56WNlL{VHd{n4!UEU8YlJWKLsDSVLl-;bW zV6uYm1jJjoqIM~!g)KU?e_DsX(*vNM&wZary#8%}MSVwDIGrA@!+D#`@%8ZE3SIg6 z+Qf{!DOyfmR!*j~NE?@5sJWbZMF9EZw5*(oc+bht(?D>v#ZoK2q`;Z!a%O2ofH~uG zoLOmzQ>69luSMqMxqA0Xb9F3o=8j7%%DKZ?)MWsH$c~yTUmNSxT!qEX&f3^wmo`4% zrH#wTEpm3%w8(+^#kpCUD?6u1a~2jNckYDK1GV(NMUUo)arW1_15PEyDv&cpHB~KZiUznNQwWFmon9nFE$el7Q zHzzYEZ#>B+*)qj4+)`{Qu;f~tmNd&iOEzGpWr8KrQi!)A;BqZ4OIO5jAtch$5&v+l zU*FV>oYaCTsiTso^zPNAXZNV?-J`}{jSyFMeqNWJUAsqh?cTj>#>A}Ny;5^?#-1F9+>1pY1iNb$+j;?bU8?qijwA_gigdR0G=d}VHwX7NJV%)%AN^&sfd+_w_M023s}j`SPso% z>Z!nHfEvjx8=QxPT=}Qw-@^6%NA-m35al22EX;G}YU7IYGF{Y+GBfh>p!?A5yyD#4 z0#_m8(BH(23C`;a^CynT8=GI^%oL)O2d6BE_FOJztGtY1%rZjlM>NO zmWBqN+cmpQ-|l^Sozu2qD#ja|vi7fNe;4GRCt^(oW^a!kv@heEjjEp(2{MF(g> z_k+$P!$ksp0{=5}3NwpyGq^d=r|*0w;litBVDbn{Lf)i|+?*_}AR{Xatw5WYpXJmd z@tc2of6Bcf@$v8LCyCd8e*c!7ku@X>-5SN^?Iia8Z_7)12bo2MnW-5CIX)Z8!QF5(@rU%S}y((mONk0(%RA=&moo(mPGv1|EPZ$7!cR&B?nR>@AyBjy#HzP&$o=j zzamK91*s3f|0tH5{Sn!}{HO6d!7>?oM&kw=6KT*Dqqb<4yYNq|aWrM7Kw5Hr@(rI@{$-sDLJeSnmE{r7|30Y+N#zwL|rJ)7IT83I~unfakkt8PRjKe5Gmm1ade_H)Y zKF_rR!NUL3`2SDiCz#yQ(%=!gGT>OI;kP;DUxEHok6%JU&EJ+1`Hf&~jI`JlR-Xpa z)VGDUL0f9!KZWT3?dqLUQW(V{NTe8I?=nI5i2+OTU&DWj$PfsIcrdicxZY_Fydg|? zk^B1IwCs!`ZDMh*D+g0Fn28zZEOf$$F0v#g5A(;@9dOVXS)4l|BM*Kd{1(3-J`vMB z+Bgh_b!Xg?m~qDx&J>S$^4H0M9-HsV*2t5=6i}81*AjeaG@+ta2k; ziZC&w2(v?l8CjU-l8!c~k92k&-?g(gIU8xG^WXAR@^3{E+=I0rLS7z|Jb zo#j+jMy9JcBR6*nW~(ytI0MXO4Ry+>YVt{&oC7W~?9ymfEgJxjLd>ca(WoN}{=ta6 zLT6_F_&iLLVG2ynyYYx4A+E1R<^8`rKFI7I)s2VW*I|ClnMb1ph^8RJm6<)*iRn8P zOm#mZZ$e)Fc29-#h>7Misv*|2ATMDM2J z49pP1XsGo5Z<*frEwi*F7v@aD2xUZ(^G2Qo8Zl(>kQ+x0NxLp~;NbWnahB_*@KlCn zV78pYNX(b>f;}kKaEC;Xfc{r5i$zL7xVHome~sOu_8<>fo&-pfBX8sIR76+|8t=K zf#YFrwLXCgYV`Qyf6^g*La7v)0(FYZ@pa|4EC&NMkN zOUunEa-|jLO)kvv4c-46+Y74CR4@2Er#^jhVFBtRTYc*X&ma7z_VKTq-}%@3izJsn zJO93V{J&_Q*uNa%%E>KaXR-*29n22-b}bp^6*SbZPd;*ai1$qSXC5>^{=YzeG(P+- zcGE?2{tf*jIrSH;k0Krrsa5*()`uyEufF~Z>SH|Q{F~|nv#bB8zf*)&XnpPu%h5c& ziR!Zql>41z2sQv_&;~s2$e{q#DHmm8DV3Vdko@7mIt$~}3}6y=2(-4eh7ZxVFKrU) z+t-y_l$KNE&t`&iN}98Tx6)-rX1MzH)stfKuy0>Ut>eFG$CJ#R%D|hOdi4G+o0}l7 zf9m*y?VFL4KX-rRno{(K$8~>nJoIPpzyHkq(^~N#t+x{FLZ!W^h1j7=`&A>cN7cU) z+nr1k|NVRAQ;5~|36@FNO`FO4YBTUQ89ZlTFGm5mE<`D4Ux#nFfw@6iZ>IeNnxzB& zX_hvY8?oM;i~Sm!<$B1gJNEZ#h?R$O%aERy9vnjRlN#q#2+B^xaH({(+nI8q-M+LJ zI15z&obvfI)hEpZPyUU1I^{RU`!kiFG;AVPxdU5jl&PiadR}&ntKIvv>Bpaq-~ZX)3t3fv|8t?3q}P9@`lt5# zTgpU-+Nf;w_h%{}@%#VVUaJEnldvv9i#!u^iYC%_roMFG1IrHDWb8%D#2$(qTB)G3 zmMTH}B6%>9rbKe`T+YIRLMNV@P6)>%aB2_R+(_(>wv#goF*-nGO0R!ki2>KdiP1DKPcThE+P}u% z1pmSLi9+O;O8lnMA`143{c>`HA10Y30x@uW)ncLZsA;aq3E0%%QR}DCXa2azybK&U?WlFrzpqjO|KCr` z$)g^00;Jn*Ij(Lwh21h-I3$C06DQ`V+w{G<=Wm!t!WPFOxw{SR^`CiOwPx<istWepalG+_^op;R|J^-*OmfPD*>mQsh@DsT*XYU<*SeOhedNHJ zsn1)Y zcBdU9j&8arVnW%v^vdulRTqceGN59VSh(Q!j~Xm|dS1!5)AQe%9peq1dh$E>svqa< zs8=-g6S1__s@e5_Iu z@zaAZ9+-29_Nk}Y(>rd+Jo!@HvUR0b9It3Qt^QBTch7L#eCL(bC1*}FYg&8Ren+o+ zDs#)PE^U^3>5^}bzV&A0bK?shul?m+7f(KW@##B9CbWL{fz72yKe+Vv@2_|!ygaT= zY`6R=#h#r0b8_qKDeD{exv9~T zTN4_L{i@Tfo^>u8_G9^TXV&!YIKS+s{U;ym_|@{CYdtuG{6DJmQSG?~addd*kroMr&Ss_OMrh!^*?OkVp?(+%&9a(?!}qWy0)?~?J;gSY*?a?sdjQ=hw~|MG{L&He6leP`>_ zj~&YYuFu5D9Ty(%IQUsht5P?`NWpnQhMCer2eeh zv7N_mo>;cLM|tX+*S~9aWbDolcJ+Ju!F@w7d%OPi6QS^1+D&NR&7^V&;*KKi)a_v6NvQZthg;u{Af$59@O0kRv?{m%e}N zv%Q>4I{)(2wykShw!UoZh_^0TclfE7zB#rpGv$#ckAKu>>Do4}=4>j<==;&ju+Kd$8h=?mQ7nJG)81ElUix_EDtG)fgMU15 z$g|dS!`$)-w|`Xq$)~^6ZG7^(eU~q8dsov>PF>P_S5eRW)NO6gE?S!K?WTu{7WZ7g z@u?xt56m0b;;lC~4}PiR+?O8NX21K%ufB2p?S_5Yw$+6#(+3Q-ow>dChUR^z)$eu7 zq=!cgJaKKvLv<6oRZYwozU8rY3t#{myMajR;ST>(Rc~ zb72d5Z5{klt%u(~a`@F7r?*Ny^2A$bdxv()?Qke-=G$|N``vwK!v-6g+Z(QE@bT@v z7XQ6nLx=k``|0`El&=yOm)<`xdgsS?kT`n=)Frq{eZX5!oV+n=uO zdGvCpEj}dmi|zM^J<-pT^F!S;pB#u@IH`S?*|T1W?Ksf(*ORw?vgNhQRqB z-J_=7)TH?CcM~t_a?z~Ik8OOv-{tH6{=U0646pN{Hyhu0^eyYlUr)^5FyZSFe=o9r z|Mu6ro&D>4@zm-Mt6pw8*|V$i`d^>gUvI>c!CyT4!99=mzp}^hM}Hmi-Cu7Vc<)oU z?!5b{)qNW+nRr>F`1P+oa-x6i(_Ip)TCNW(|J&btZ|nZ~`-|rPWy+hY4*vbN>Zx@G zv|ag^9(UO8zv0ChkDRXm!HK=!w*U0ugii-OJ2B$B)IN=(B2Nri5wc|5yEjd{^PX*c z-g$q-(5xZfzjWlj&iihaeY%(VTR`?h+w|2@`MnsYhQrELiwq(+#yxj%Q>bj6mdI`(+JXyoMhVGHX$wzk77*N*ZuORE^_E$AB3B|oLjmG%Gf zbkVhz*KeHldDH8|tJXi}efj;QS4NHgTZ@kqOP+qZ(={1syE}ZiX-(RZ4r4CyhHts| z;WkALuW9$r?tP+kR8cwt)ft`ApVQTo~Q*M0i+;@wMM{yFTG33I>C{G!|Bc1J&08IyR!$y=wK zy!Xd`PtWN5;*VG6KKK2&5p6eK`rOCYw%osJ-VJBsN6me{`^%GiT|MXEvn^kH_lnpr z7w>)K-Z{Iw*ZcgYMxT8-XJP)Z$a@d}?S>sIJKpNQ?}NR8|D^R-ntz>`F=ebXjhs=w z$0QBMQTXpleOHk>W}t%qx&9Qd&y7Of@Mitn5I8ISe5p42RzAC+=l`E1KjP_MMJ1?Iuk_ ziD@aNfq+XbW!u|^?SoKiX*TJ*^by*$goVh!I6mQR}-RYUH9Iibg~)czyhTXJ)?LZ@0Ti`p|Ok-%ZY(@0&Su=6&XMW}8W*{}=VAJnI`{>3^jDc>5K3FH~;1bVsXib+#E-q0b!e z;Ci&6ge|{xeg0?p==@0Q-iJ~%tAfkoEu5axOQQO8Eo-H&Za(gMJvLoX(hEv?X3O1l z{-RU9ql+7ZG?mq}t4*}7@9N$m%dYX&iDWN|j%oT1{a$77!{fR&mbYi;tF+WPdMLI8*j`A9K4s>%B)3-SjR5G<_3!_pJQzdyne%~% zw#%cHzZ=)D401m-c`NdJ#*AQw!swVI7vwmKCFGoM4w1)E!n9W!9H%=rbv{SC6WdH8^WOI$ z`Eg9!Q>Oe^#uMf~CNncs>d^9$`!_fg)YU7;SBze?C%&6{o9zyrZ2g!*wHGz&5ux~z zU!L+rB{{?Uc#j_#4ChB~tDq-A^L!sh)T;Eef`E$7$$t;|H^{ zE8^WQ-?%f)x>1FH^uAD~&gV0SVW4A0>v-Jm&YRuvSc{L=MJ2e2dV?Iwr%B z)UrXY1M%kfnIZ>XF0(-yhsHXoRjs>&*$a}#H2-x6c4>k82NQ{NOx@KYcC70|>H4d) zlN(!FLuXf=t^2-IPwS#Z<_gF~ixQlRJy_0jNA*k zSnKZDol#x(2BPl+HESr`tjYaxvd-1KTc7%F-OYTUqlNe+wlh%{MQF{NDNd>@!9|)4 z=s{;OTeo>rgDkTdKC^?%w=7>DT)jd^w6#k%Mq4}Nb{8(>l$gpVZBn~;G-M-wnaQLs$(}5i1yOv# zjy7+qw2Rr{-7V1~AsytFNgFwJ2_?l>MS^+-PWkx@6XV4lm@+Es>}Lpj~}}bM?~b?KYDr8R+|m$*Ea_3w^5`m@S&k#sJ*8rnObzt zIbDgwu5R<)l&Vo-+4pm!D?7!xX;9d@sqpKaBeW;z?1Fw^Aecb34(-78GG247?HVQ{O#N&B6up{I>|)sHiC!L$ZI9Yd*|_J1-Eh8-ij7#w2E9Jg_?*e_ z-1-svg)QGHCV&6%u*t{YTHNiOezwG;uE&hSU``%#O`>!(iTh{(- z7C&O~i#GhJ6HWM1PJ7+L0w4SSTko%Nm?0|r713snCb(<#671h8JUJ%PD*@5&Xl_z0 zQPF?Vw~5G?m%5TY$LVGeogZ%e`0mmC^Uoi{7>lobuvJ7-HpbgFbT_njcU`Va&24)x z*VHWtkLj~ciu7HK7|BYV=x2t1W7OeH|ZMqttqd8_gQ)LJ5z4^{%UY=kFg)H z@~phry!ZXZxDVL({^!km(|=ljNkg~qdh=f2Zr+0rnfC+sz5HFqe#oV_{hf*3-&x*d zuBBB2w%&%VoOSDMZ4axS3^?SjLHY7x_jX2{a;{H?g{{FAy*+Ymp*yc!OSFSRQxl}Uh>q;=-r5Xq){AW0Iaq%QWpmhIxX4>g^Z&B)^koYcjlkFA`c+>Wd zWKwd63#D`}wRIt^FBcm#TVXgVgGrlFjB2@G(~jKj2FW2=5z-@K+p-d9O!0Ne?hf)z zbSCA4lQS2Ux>1-NEZxxEv5lc|Cz_UaY?yIY-dsFBvvg1YVd~qoMdvRFkD&`ux=(Zy zYVpf6J?G9HOL`;;>|msNwn%4h666-CnUc1d=G4yMT`EmZU6Qf0&J$WhtVODYOhe7J zpl-Wj@`4zaPeWK1L&B$mIIOdtmCKzhGe5%H|=Z!ThbsEy?B2QjEiKsqztcgf5 zkCZR1x&`vGE05~2-2}J3%gQ^>fv4olwpQ+$T8Ln`zqEX3!oOT(*~>z@a}KzB_r^H7 zP4S&^9j$fdlfK-2pIK2t#G6FY5H(SHk6T=Hf@y6=(!2BAbb4*U9@jw-Wk5`I)b;3( zPQxu16KlPgQlMKhlv&;K%cjB2om)Ld3GA+;EO-0LU<1=U`Nn~$VsK(#%F(X*NQ%?t z$6`vbHP#agA}`&Ye@M~eFxRKbN*A1hCY#b`+P@+5XvoQ!+>iNQbxyp{veUq&}MU~Q2Zs9?J>Rx#qP4qmrT8KsXg_{BrNrs6zj1q@zs)Q zL?R<3HU(Y=GLfN059;B|_?W5vaBOzM8Dn&`wCg2zng7LDm1c&(m?iFBXkij4YU+~u z2At>_*!p11uRkm%OR!q)G$jrl`a1J;+jG5T+Mf(=Qz;vXmyezkDPRJ$I)%O$qBT=J z!ggMHAZI?6>Lj<&IRJ5Z$7`ucmcyYq+iy(Y83Za7Fn- zNHm=tix$yOQcdK9g{@67mYwKM=}A!xt_r3zW%^?TiN# z!Cw0X%mkH|v0i#{(j@3mx@>}I)UAOGiP@-Ns+x-KToB!n+S!u%BCQ^WWC#|K&|W@# zD%qau)hlFe?49i<$L#HLW@o?0oP=uFv?++@pj#9=C9-4QdH5W)mPa=q+-ZwP?U&1f z4QoB@vz!QtbDgPFP|jg=_mZ6|j`}jMp6?p)!q3guJORi_5R)^elhE`R%hI2f4xArg zTrDTP^brD91QQ74E=u9rFP znFwl+Y{P2Cu46%dz4}`Y(m#swzF{vR>Fua#a^Be!ICW1|#a&gn)@>#S`U9f{n; z>x7U4RFmZHs%E}YX+F!WHCF6!!blb!Ea{TrKt?2G=6nl0RZ*pvRQvFQKjy6IvK(JPv9jjMhl6GI|mn%Ri{WNahbNk-WQKhceHgd{Zc z%yj-o<$J|!w79h_Y7Z#6GU2SMdEs!XDoPV(3qMK)ch5W8MxocnnGAGuj!|}LTuPg8 zw-V^IV5j5D5tbM;*XRzqXx(u0CNk|RDJ|@8ghb+R5gTXH_KaL$OlXhKxI*fGYkw=_ z=T64bsjOa>dd7G&in zy4pzrHL#kC0O%qMZB>lZDS3p7+x8@HbNHC@vphxkj8?bV4^P)ucWg~H*3}>%8-mn7 zGk<8XBao)k?LA_fbn)KGSJ$`64ZE`A?ZyKJM(L5U6J5xhRL^ewCeiNH zUrL=L`gv!k>**ZRrJG>{gFghh3SJOo9s3;3FngXMi4qcE`#weaQsbYCX)h@B+VY5?F7(MM>twHzQZR_$b)c_Si{c|HGQDT|ec)Gl!r=2kq zh9T99-!2DBgxaf7eMx_)_SLrEGkMKDxE{p`8Ce7+z|bs8idErRv$mt9yDc@-?&NQ< z8HB`gQ&d%oTb1@DKF!aYBV)>+oM^k*qz^N~vb1zr2cMusCl{S&nKB?Rn_gvaG7Bzt zEBZO99zI~aq^XAMnmo2<_jMO8vr99yZUSQBWww_bHLygNwq$xs`x5o+SZ_}kcTRV- zlzNApqRWlW8$(PK8<++&Au^*+rZL^IU^_BNKBY-)Q6{m`Z7tQEyJ=}spjvI%O@h7A zI^E(-e*W+2Pt9DlnVwL3zq$NnR-s0o*Tv{Xw$YbHwO@Bd%B0_wlOY-W%??&oRizgf zbSrSreKQhqFYjm)(x&v#(eQE2Y`L>2U75{r)BWNaD)iJ##*>Ee7BFDcDKwBzlFEHkT=PqigdowAIk{bksQaQ3QFxctqC8ahOw(O8SOW)x>U zLol5*w@$n16lp}YOJ8sv#39JfbuZ=MOqXalNiMPyv$&gng}c2wL_cG>tozC8l2S$r zm<6IYyzv6hxqWw>?uJYYKw2m4-6>T3m^)+BqidbYj7;a=_R{h2*Fgnad%LAQq^H4p z?#3POauNkiuI!ZP5vTsfc3%t$v)(~4G4IV5qIBeaUtp44w)!S>@ zZ~O;$7`lP$%zOPd^RBG^U7yvz8@$EXCtdx!Qw%*()xRD&)N6K~WPrD5kxsxre8Wa- za&ngA9IeSIaY*ZN$~Eir#;k90?Db0v#meMgR#dTSt&;0U@WagP<(7-{$EZUfDPy&Y zBVF6#rfPWh&rPxwOD*m376iabcVpc+yqwcNrpx+m`WQ{SN!K>%^lr*_(9H#<8piYK zyWiZE!M4uWHp@;K;gL}|Ze`W5yDOuCtgkiG3w!$e6v7OBWf5x7#bY~UByJQj^bUH^ z<*Pa7chV#|`Ufl<*hBqu<_534(t6C;n!u_|P{x{yY(^o+=_VtSQaSIF887LjPwnvX z9pAJtf(f$%->YZ$vctAPuReF^9X6T$d&f62>gmvw;?kk6U|#I50Z2mAU0a-+LU})v zF9H_!lR5R1Td%bKd%kKVootZ)R99tN)vTR_@BY&DDZ8c!B^4#D&|+3VUj`ND*!Kr) zP?=t8b%Ij?ah+kA6EwN`UAP^a%WAWQ8GPxZa?`GA(X%3+&#w(PFPE3?g(2^&zI#Zxhai-NSj`+WD#%~ zH(u)#@1Dsc@=l)==SYVctLV=@b#`pmlP)}?`w+BCtwb;@cr(>)lG7kfH8zSS3eh=B z+;FuV)He{Non!-9$H}@KmnN6*oawYh^XHCHy+11>NV_TL#j^3WBvaZ_UhKTFf<;}H z4CYoYJ#)>s|K5WpBa2Fz?RQqZ7)+Xs?M2Ok6(#zW9kJv&jBdSwK;nkOIx$Owy^QQd zk&$I;TP|mgFxn83d+a>7{CqRNEI-fSfs+mHUukgY4QIV^rYNYbm#>F0KYjx_@y6M~ z93)?lH&=3q+u!-^_{S#`gWUZ12PR#Ybm&F1UXirr=AM%)c`3jBz5@2L9x}IDDZ%sS z$$gD!#o78x>Q=8`Hg|5c?vjS(o1(RwqMI5vu3o>A@uS^4WDV0xzgwI1iJaf>bB)O zJ32b*FKpX1H(GyTLv+)EX#J*n(fYPc)lvSn*OC0MpU;lWf>p83w&iI~^9`qJIKi!L z)50kKE~n_$d0`um9XsU9z|pqk-0F^n5{c?)v>i_M7q)j0o%DabDC6bxH-Db-XMBXC zo7{hVWYhCSMbse2}0zwHqzr9o0s;a&DN2_p)ce7oq3#xZKUf zGT%vpNOH#v~-h6l~Na@ z|5@AFc?gB;KJ`y#UeWY5yYFfphJRMn=i_6USrp8}gZa zeZk@(Tb>6!@sU2jn~nPq6sZHmhr8_x#zN?y@wmU}c@|H?!~IE>e7!$!xcl(4 z>gNH)9eG}otY_(Wnmq2ic+L%5JP0qBw>Ny@@BO(nV%LP9M{(!wJugYtGsUHbyzajU zs6ufNkHS;pEI)5NBpmP08;{@*ic~#(xy#Kup>cohJns`fK>aVtV_>A8 zgyB+n<1gWQf8m2gO3F}vUb_7`bcVvl-{UscIPD@Fx7>x7#G)PuK9e0Eag(3KX`#oh zk>@_(pvO)VTEolVn~q%m`-w-s^1r75Uy1Xe$E{EthTtnd8UK01R|U#R;xN_Y7Vdqb3XYa zyjHPa|=d`m2EAZc%y7`;&SfhQ|HX^IRxRLL)pS59H@fmxSZ}dE*hrPrvZZ zq4Cm-@BJS4A)ZHzi})Ab`Q(X&=lyx(vyi+D|2m+!Gu2Cy^(<*z<#BJ~Ik!y2gYa^B zd-Fs1dw(vC*!2WnO4s`#SmK!UxbNjTm#26VUM_EMISGI7&!rK&7x8ltckUYWl4Lzg z_`@FelHcYJUwFB^z2OUg@6V+XyWr!bAAa1;;U&p>*6H`S?=29%@N#*3!x#SEpGzZl ztMGFWckZ0@l4L#W^n2V(hKVnb`x9O+Z*Ta*-}~DIo%{3BHRjVxLEO0$#7mO(EG`GV?$3`cKEgX+d?Y;Y&l{ieA2I%iKX;IL zNwS{BrQPGcZ;bH~-udDq;dy`F_=J#!L%4J2iI*hnSzKQ9xCehfw)_#^`Qjttd4Jyc z?4r&k;m;i_UXrY5aoOu}9~fhNgm=F9NO<0#H$D;aRK37hJ%`3iFTSTj|yzv>v zPvnmQ#hpA}lB{QOY4o`F@jRCJ2=9FHk?_1fZ+s%Ofdja6*NB%S>se&}GhX-q&ObgB zUM_EM{~-LmKbJ=A4p64`@Z+qBmn7?1!f*Gu_wk(D_lXDLm&Sn_@TSvCCPf0GHv&`@8x*_$T#*A zkHRxwKP};SfAv4kpI--a=v+U-bJ*inC=OZaBaHA%^#4GipQLY<$Gx5BLcE1XA3XD= zPr~v3yz!|2Ns-!*J32pJlB{Qm#{rM~FwY~!LwM$khlJz(dE+tlr?me+(f)x#KcR2* zxVL+r#gFiCf8Kf~eBIwv9-Ka>?=I|;By2M4&LQlczDcPYZl*qcMyam=p z?v;N*`U$1}^^eStf!-GxV*zVlVvQA;{U^qsz=Z!%>Qm)DRW;G4+JK(`QzrS;4Zyr< zK6QG?r<#ESb9`!2l}`oc`_%RYKK0KMnd+>*sxH zTeYK}2F_URsI@>Ra3Ao-4UVd;cht{!I_iX7j(YXAj=JGGM|~SucDv{5w~qREAU5o%J->6*_TK~l?Wk}358?d9QM>=@s3(Cp|IJbRfv0?a zHO=v>Gm89b|2V&T<9NT?dXis7C;Qd7lX*YIuj)?stB(M`0Ipl;SM^K$>Ly?q2-o=4 zUA2DIcA;NAz0R)|Hv82bz;}R9i(hTu=2r{a{pzjT;kU!DRsvt{@~aQ*_N&XU@T(Kv z>{tFhe)TIL@fN=t06wtSubj90)mPr(R~7FhEZ{ESl$-o&FEHU|zk2Wc{pt<<#0$9L z4!`QZ$FIKnVZS=?F~3@IpI@EzMZbDAaKe}U>R%7w5BN6l@VETx1>g(c_N$uj_*Db& z!|(gmD}R9BgMJnHpGOp9T5DnUnPM-pzn8JKs}E;(D!@4 z8hQzuKfwp+7yLhdwGU9m0X0|>Q03#13&1c?UrL(?%1@x(00)5Z1l)lkAaWx1z%URg z3#jLTKLURNe6I+o6M@OV8NgY9?19&O+^_!pFnP}cC3|ih?55q3PFb%^fEW>ElZ21)yGsob_PJ@A-K{(-6%44R z3D9)NWx3PES+VD)7J{q7;g&W|eBH{s#H(uQ#)j#QOKXLuKijTB?C#07+bnhi*>+pS zZhw~DYRzvj+it1QJ(_K&{T|4&TdV!Dm*t7yT1_{UWmm81p2@b;bi-M84Vv!7Y&%V- z<}h#mGX1TsP<|GHL1VX_>ydR>$YjEW52HAG~%AhVEdNouOm@&+zN%)K~C4!#l-Io#;$W z2eU7zsFCJ-Nh zDZY(wcK6_Nq`yA#=sU@$PMQWT_NihTuQl0N##a`$~6?`1w!R)_m|pBnFC z@pqz6o#?(d^DcDYiI(op?;*3CA{FYu=mMOwenb~VImz)A)6dLv>U{zF!#&ooV2M+m zeV+yvcX=24W_a5$g!hBO`w@(8v7c6D_^aJ>}9p#{(2=omp9Y0)ETq zL>_RjbJnPdr})$<4=W}9r;Ga>jQN05rk>W}>~-EueVcfyqfYgmrs>uS-DZr}Xu2U2 zpDUcdlqppq@!n(j_2AsJTWCnvRG*q!2cGPx$v2iKHQorWa8$)7PEP7{^L(BtDOF>ZbI^e zCj7pV=ZDR+_-%!JqxjvyLrVLN3~Rs5&g-!El@j)o#mOT4>iQ&cuB2WO=P>>op{Y2X zytD81;z#PlH0&ojY9jHJv_A^#h2O*YG4N}kxGYFF911&*6Zkprj;~MpWzxCZ@2m>c zj(4tt72!LobY5|NMY*HOeJ5#NlZ)VWCdN#80E){7=++IdxxB+m`gwR=>YVNbP8{z% zs(DGgh`&z99M%f8=3u%VA*JO~8JlGLrL zehfrOpVJ(;U&7b)Hk^Y(|1*rI0jDfz>g+t7=n5&XJ53q6Vd*@6($tfw#YKu>Rp?9} z&Ia)R2VG`YIk#Z1(+91ti+gN3j(`Ut;`L z(+5qxt#radFKzNp%Ew|iFLZrwHu1>bUL4Rh_i?8CZW4MO&5XA)YI^OfB(uLo`_%85YK%j)$ ztehu!q^io$lAQdB6CuqLC7yLA-hakDR8&%=P5@}rCO)lB5IWz@@)alU3TRKm)-AUTc@HkEL(B-qo#G#L%L)T)wUB@9z zvm&mQPB(F=u>Fh6uL7A|@}#Mw4?=z!G@(a$_?q@R_*?C_$Z@uiPLJQGa1K3(@q10z zWa_{?XBqx|lYMG3`QfYZsS3f8?q-v29nU*WJVWE*F&E=Jz$q*LxejZ+^KmD;zSRre zB^XgnH(=-%IwAvgJ!{U?8E-w^EB@Yr@hp=zb)zhEDv7? zipvI>DpU=LAyOHk^b4mka9ihRGu?07c-$s@KZx-pFtu!u`dKw3hI$4gC`Lq#CNcWN=oe#9j3F`9 zFEN4`l}gPBOf9Y}SyMV|!jiI3d1%s-$?KGv9Ha~$g^{}2wl|>2Gya`sS50r z{4n)qovv@wA=^s%hvx9G1t>35 zP+vA1wPFxTB(9Ia{1&~K3@0*-+k$)9uNLo!B*pIVr z1E1&NTR?GH(;uYZ-G;#AA?cju54^7U1{t5|^0jqp81m3d7=HuEC#KQ5e1A+Dbv`|8 z;xhF_=+4AA8*s`7m^MTA6X$w~vuA7>5xUhF8#LV@@qn({k-@nsA2U8_7rNJDT(0T* z2$w8BE8t91Q=J%Ny^!>Q^L3uz3dwb}h0q;5{3lRaHY8aaeu_lai;={*Tft)z8e7gE z12-1mR&tATzEgU$pIp^_l}+yu-WNWH@dw})Ws+A1NQ)WE`~|zfI9&;ovM|69$>e!; z01;}(fwk#zU>NVuOT5nKVYTKp^k-tD{^D0#23eROHcvWV4NP;U1}-Xc78SqNF=edB zY5$_@X%cUt-575I=)dt#etp|njgT<%;Jmav=qD^Q7Ck>5hx_sN35mmJdH90n5iw=h z>Rd~XR-7fW=|U}=Wc>LQBz2T+=s6z#0w6DmX;tLc82!&_dD-Sv6b057izv6U?UiZS z7JLQqor$poC@Cxd8Q#KT)QgeCAki}zflVo1?k`zE)0|?`EOn&L)TLJ8u>)g|<}oBv zDI!Lb7z&|Hg`GjBH5N}Vsj~IUmS4a4{V>Ku+Harsp+%iM4Pscow|V?Fn{b~Hzdy(L z6@V;r>n20|tU4ShhwelSL)Qmgmiu&ZpMxJUF?ru}Tve(pAY{!rw6X1rQ>KXwlSKl1o<$Jlz#@T4i%f5W}F z?DP0F&G?JF2ga4_EiqY#G%dPcmLo~S@0`F|^1+>N+xd<=@6r70Y(9DNCo5mE*JBsx z(ql2tm}J!?D~~VYAqLnnN%@dpy-o}YbdM7%?x#Gx^Xe-h4ZQ{9t$C)0 z`f`%0I7h~nuh#h{?fHI4AJ*-khbMpuWlEZIP>gh=uIl?Mul*S1Y!?(c>jU-0^(AXd zn_X@A@M=vvgd8Eh*E?&3 zR_DidP1}fb9c?HSH>cd zvpRxN)gs`DxOvHai_GNuD7k^J>|`l+9!9G6#1vL zPi;T7v*e?d?Ps(HJ4^1IwrEDJ6KbC|J5)JiqnnNkbUK2QDObtsnLMlm%F2Qw@9X~y zBf`^`CXB24Fz)Wh_{Jc{8QeePR0bX`zQ73-6_v~gTvs|f@X&=M+VWB%8 zqej#1W6VeBAw4e8^H1Wa@tt7Z?@d$_Cpl`;^JWg16#u(1-lY8pO?~@{<1`SCXAXF) z(A|OYAx$@I+VV&Bd_dEE!tfguy02h-Q_}@WBJp`ZmT}ZX9sVi0%|0V^zr*-%P1kS2 zZPfklL|J?KcDbITmQR7M0^^kc<0C_NqT}pw>xLPFEynqADTC!aYz0cn`f0ya$tN)e zZc#WTj+ZjLqMRc zT-Jz&%{*)uV|x*OS@eGv`l6{W{Q;S`3B6Cxb561K_oOU8VaRnD=ktJE7slU(ukQq( zIw6_O|5Blir}>B7_;}Zc`-J9NjE@0jWqrRP4*g;biZLXHVqRYr6eEJcaI4Ch9tahO zNyxcGYQa=$$@IV#l5b}1uiiWNxQS?oZpZilz}$m^$IoKt zqwKQsVAd0nFH)}uAzAo!jPC(20IYka$M-+x-3|Ih(hchu6fXQV+hWJ$xF_I$uo8)O9##MvzG=%XDMQ{5@#g=7Nkf&Kn zy^n5Sdhx-MWa)cK78iee!q&1`6Do^;?v!6q{4xKe*2%LcRTjq_cf2g)I~uoyeUpTJ zJH|agplsl$@H~J)+V?qYnLg4vd&iw$6#B1YdT!gu)4Bb|FwU2e*3e5l zl$|DRtWT;{zZioUTZS-bV;ku{8;chOR+TgcRuPtGy*7-u(D@iO0Am)CN?2Dq^Q2t# z9Q734o-f6jBBK8Da0O6Orv5-LE&bZ-q=&nvAA`!=#0+&|04ScJmXyrKi&^i+vy6A{ z!s{C0^D!Pi3D7s0HP6c&qL$S*AHzBHJq&4^$fkVlZCGet!uT72d?&%gzg~}b%(~7= z>A6p3CA4Q@gn>ZWko3!nIX30+Yp0LB0x1XG$C$O*wL-rcBMwjxWjY?g*wTbiCWHE( zz;egAWPEY2<4h-=(`3!P!pJ|od_vNBGY=os{+r0#taYu;4h-sww?bytnx_etwMqI| zRx0(_a*wG)2VoKV8OGB9=`?xP#Lf@l>bbU&yMfcK+&z(pX+TL?TI5#sWi)%%2O=q5V7jdxj8L97A z<3+FEzJmvMz3tsR%bo?9?xh`YX+P_seZ3Z+vfaH^wjpUDBi#PilFF+pp z3x@v;oi~5fb>J>%4*OXZXURImMC~6@VEtEE8EgD6z}v!FjMYFR!2Sp_%E;J@c$YKHXIC|fwij_Ox4ByE_MB2gt zB-6%IM?S;DXI{x5uxwD*5i!)?kVZj_s)!g(V)Tj8kMTBV5VLwn-sPOtH9?H3h!{;` z^oh|g#vsOX&Jbob`%ILf7!fg=#OTBLaiAabf}%mp>S2uMio{%3T#hKM4r5f+W2_t3 z^de!_lQUIG-nZ<-xV>Zmv-$u=)vy@levGgf^rR@{i%PuRwx2%%YHkJyV>PasZPi~(wdvfKZ<>RJq zJ+-}ZE93slrp*pko>5s@d1~d<%1KX5-#Q~1+B&O!_GMMGXIIXuWGYZOy>iml@QO)4 z96xW-q$}sGU9h3^*}8V? z`u6(vjh8jFZ@H}TvaP$FNl#w7Wzx4__i=wyd-Lq3%GbSi+$353qb}(EwuyFq=X9UD z*LYJ9G0wPR;T=3QpIM|H0k)PY>S(wuh*=d8qe+ZDj4l20G>CEc5Jsp>{WqJ2cQ{jm z;EDb_{9pG+c&cu~sOl4=UyMPFtHk(*80#D{cNB|RJ%mxED7va3#(fct?@LLB%KF5z zUyMO9hAG z-u0OikTds{o){>ZQaru%k`pG(E`3lxUZc&eW#h+{x0k)mDgAI@QqSaBQ<|tC%cwC> zPrAgfAC~)1omJU+YGrYMpi&hDrfr?xKI5{`qS=*ImEp>H70!b8>h^`Rt1A~&&bxY1 zQSrQUoy8Y9OD>CSt?OI1XvOShm35Vo$|aSHZ(Z~Hz}m~!&#s@nzH)8ln#zq^8y0Pu z-B7u)a!uLdd*@ZE1&bRo_ilYqE!ZmUexlv$7WvRW4f!DB_7C$wPxXL|+hzX#q>&$k z;x1!_ukb*}bT0#hMRZIRb7b%FG%Yut!RcZ6hJMF`?^PnN`lK(D9=K`{1Apf`$oDm} zJ3-ylvJ#o;S*r@+Ei@P7Jb?9~gIZ?JbE=_tWv z<|o4QyEx-Tg=9YcV_pdVZGQS4J&U&CKWXHtnr_EK<9V12F!z?BkPI;Jw%7Gl5tRL> ziJHIY7`#Ek)axB4%vE?zhj|e%c%JJGGsnLuzEsG*TeC@<)?2XedLIt^ij&!gd(!N~ zNj(r6lg}p12fblFV!~YK4O92==cMD@B5TQdE$OWET9P`eNTEi-`VS8kKyjHYZBz$Y z#?YGy^ja7AolDhfKdUJbS^+$?ZbU_TJ$f15^tfdm4~+ohBpLVgJ*)E&U4!YSK6>(C z7i2Q_551X(>$RUg*@v-u7{hM9>%G=Jb{|^u<4(ITsq3d1a}DBM?}vYd2iixQ^NFl_ zJ0$ME#4!7j`m$jedPqd%$0q1t=s>p{urfv-h-7hhI ztLgSVq3O0bgQP+4lhk?Ypq}^A?(JT_gcCd)|1&Yp0i3eoA8P+kIs5RhZ7V z5yr)uPMLLx#m)yEH@|0Ux}?zU#(0yain@sv|arO!yZ~BGqml(g*bi*tskhlGG6qfEz6aSJBbQ3X70hn`|bUseyx8d$F zbYY=8AEQRol^Z$On3ClB{{UXkK{QNB*hJ=>=>K_Vb{A>!G%0!CN z{LR|YpCJhq%|!nJpx(LspH0)5y;Y&p^{_5G&lck?v>c-zV7-Mhs0w3{rh4a$fSwfU zvbfXaU8m6ZV!Qz$Jd?*yXj!k<+aGrG(1d#%WIA0R=3&41AzdFA?CTSqVlv~$e&-UC zF5RXw?c*S1p`T%x_EAr?Y#TDNOX}>Og{Ej0=?AFWrrllT%wdeJ`zhireM$)DHRMTX zE)VBv+I?(%5qy)gQSwUa?H-fHdYnU-U_=d_sUIt3Ss|-l?h?8;WBi+@Oa4IfOE?T; zY&|l0(2ujugOBm>pwN;B{~_4dC;HSxVq(DQ)p=m{=Z(LEuD_r2;O5rs)O@UA^-gNuOu@ z(ucE7-`jY&O=wBq&4PV>q8m>7=t}ozD1B`BTt~2fro3eON+#kXC8i0H- z`S6$)BxVisB&Btf?ld||-g^3?_?0~LG3$l*SIIqZxlF}c=iOO6EC3`w<_PxniQYb5 z&Xsa8V@39<^mxeZn>0f9xQr`eJiHYsFB_0mvjbubi!mUZ$HQWj%X)bjW1UluY^Ha7 zoXzAqcB|^q18FZ~$$V@1TKd})(4U!R^k=+zGzk09S1?RHJ3u|Naha>z?o&8jA#r(* zhly3P#;OB9aA*vw>Lo5o43;~{4;zcM6PhI!QZ|(`AD^sbJ}&ZOG3>PbSk40(a*cBV zk^?%=T3(aTbYOG=)K$1c`DdpYI%WZLZ89- z9KibHpLKLUK3)W|={wCj_(7rj8OGC^P8k*7I@VRHoZS-FvyIFay{m;%7PLFjBTm1| zJjL^Fc$M$mjCE|-#=Sw@Uz>Ih61Ht8rQ?z3f)^U-U=W~X&Gew z8m;^1twVPS-N!IKsp$@wIyBq)7UgEz-^r{4I%;b7`%ou8)^nqsl zevTe{N?My_UZBVH^V4>HX}fdNcKgK6&UwUNpX8Z-H)G$mbM4r7GY^*m=!!^74~tQc zaa9tdTu(H=6i61|WlZMKZ6)dAnS`cPjUZsi|^toyFoibjX0{tn_p91|UzDZi= z-_VDkpOL25{3n;F$*q&rVj9OsW+OqyR_Fs_etO>Xa)Pm zq2o~JuV*R1)oTckw>sgjpO!DNF%B%{Jdt3Z<5M!|_33nwmWFYpqgYLzhpxR&N3rOE zCpDe$tteF$J>ykHYl*6WwgTFUc_$A^TqXV!PcBg><9>4MICU~^C(oNWBwsL+6=D+a zEV-P#-~n~AZ@(^2AG>8G-GQXP$qD;T->K2q5S2}uw1&KPajI>#v@61xj%_Vp?q=u-BZ9{s6hN*%!${2 z!g&AZrKV;7K4|Sq&O0<-LLI_>us|5XX(}tP@^i~e!ud#^aLS)Bd6LW%j--Dt_WA0C ztnA9q7mi;49Q{=66+ga}*hoDEzH+y-0^+ys0SxF5J0SO&BL zmja!@UBHz<4^aON#*RSvo5&BK`~c+$3FD^Tl6&Ja2VN~1o}Q}U=XYx zf)0>+EHDU&(qIG__=169aQKV(1rGcGI0)?zd8b|oNc|8j?bx-;hYvgjpP%DD2h=YJ z6L9UDcpmV)yY^$nP4DB3?nk)ql2Y!qDAFy})DpKTE{fxKj zD@4}>OkG(k4Yab&oPP&63#sb& zu$czzdkg2@_hNS~;Q~WI@H*asen8%bz`?g-2OPM8cVPG(#Pwa!yc=G(`f@P6Pi;3Xiusz|*Kcm#O$YL+m7-vhO4*y;g(3Y>jmk@`6BL*Q>f zcrDX-;4Q!>fu91)*A=M*@J`?%;AvpZdh!#v5qKCF1{yc8oC16l_ztk39y?$ca1Zc( z;4eVvq9Sz>a6j-k@MmDg#v-*I7y!Np{1G@~6Y>VQ3-~hd3!tKbuz~*o{tGy9GkF2T zfv*D315>uJGzPQ*Uj}{!oN#fGIv02?@HyZqps0~82cQ;+0#^ff178Q8114Wmq^g07 zft^4<@I~M$plB=U17g6wzTjcHoo10pJDTq^2S@7gz`M05=1l01g2E16&fL zT>zg18k(^K9&Ra8XSE{Bfe!<|IDUbpmTnT&{cpjK@ z8Sw?~175YAIt;wBljUXL2f+Lt<#hgLDgnioA zs7uvr)$91=Rg;RTX4RrvRa~{HZQP6BfnM2mZoS{ZJ^6BX|LeIqKgGTLz1-`6x!T43 z`B$)OdZl`!x=OuC^{F?rU;3}=YV{U%jrupWS6!>FQ*Tw*tGB5e)Z5iQ^$yNkzf;|$ zZdSLbcd1*|yVZNtd(~~~cJ)5>e$}t;PK^rBb+7t}`l$LCDk1-_ z2Gqydll-K*U;T&rl=?KP0iRJ1sQ*+Cs?Vy2)aTTo`aJvLUr=9EkEk!HFRMq@SJYS4 z*VJR`>!|U4Qyoy>VlD1F3=h7izOR0u4yqrb>-!`1gp$u)$cMJ8lHANuGreKcrf{TY zZcR(mBlK6l zy!J?=iq};~9_qRy_qtiW5E9{2A(6k@9W7Dwam2{7%{9~cYI%3G zr#sfslcMBPJ*|-lWU;1gca+tb?}k96>T;NH0DgKL-9)=GD`$^D@G zhVD)Dhw)Y!0o6-)B@dqdnv#ax7TxA$@tqwl);UW~X1Qd^%&iiW4f>PASuwdtVlsU? zXmfZjJ?@Bfwae?qK((1Kzjh_Ixk;H9iC#D|+szmAooI5PfVYhs<87-GJ@M|gn44%F z>qZ^dSP$QsZKh+5#=E;2Hzap+{Z;GSxsgaFXy2P!?>h9-6!r+vA?U{4)d#@|3VlyqMwYsbbF5ax`+9KxT%ypl}VaK}eUA-S>>m zkGG{QC{Tq82Qj*b9R7tX;7Eat$yZiLdFycDAVwq4;yqIRFC4@eePxAIpa~TYWONVN z{PQ-z5!jDLpT#|2>n+TC6zVL_qcr@&fsEZ_R(J))N8ymh;w3ALyyB-Y|Ix^^c;^>g zg_(~^@8upnFHgr;T^C`QEqY!andmy@>GE8s!qU#AC?I$;&V@yWOOvP40oylHbN?Y{OcIP)g9HIgH+Pij#NvmyEPh1(NNgz zPsN%$|9#gc(l10mJ&;^Yin^*I5FY0Flp=R-Vy8BG-~{G4jjb}>ro`{TwNIv zBsK##v~5Br#d(SN2jjSMNx6M(EW#5ldOApz>eN##A(? z3ERkNZ(5;dj1Bk@e{S=omTEqPmr;NDM+a^<3)5w>sk5s!%Rt=S$+E!R7L1)78DEuk8*uQXJ8t@<#P(^l50n!^4H z>3u{NjO?$F7RYG+M$icvgV#bDB4hAcNMB?OUPsp+83U7ox+J6drKTN)+$*r}KZZd6 z?s_JN7YDCi$N=`JEB^VBI=!` zFuBS%?9a?5I#N;H*&UXr9nrRUPfL5e)vSD5J<5?MQ4EEUQJc*?OM8#T60OmMj8jCj z#?vb@;Mo!D-Y%1-5#~e8$uhkWjk!y2{Kq;y3N!!c&NQp@`bP(z)jyhn{@8L+#_Nm~ zDbLg>BXpU9ZfU337|WY6c6Xe4IZ94rnm#jH9-+skJhxUJ>r#_5i@imiiL+xFpgfBsjM=hhu&@*6<+?? z%nU`ZJ7kCptOVtu%)8)|hpxc7PCm-KOFVf9^R4#eAsc1kCyx*6TK>J(JhXY2gYwYj zUm3~}G9FG95}{es2Df|<9n)vIc$=xTycF6bwTaec@z&mCX9qg)@e${ff5%DCmIbyd zk1z|9jw378|7a=qvJtVm>~<$w+R=HW`?|QJr+sN}HwqIy=rhJ+JMvCqwTuy;9DrZ} zX=JHsayq@FVZ80$7)S3hzB8^R>AH9-#e_MZI7=^Y9K|4agdkUsC`s+g%25QJ(-oMS zDQ)`jGp#p~inYZz#&>k>jE^FtMKT-m=#JR-_+bo<5b=^*TAJi94p|sU`d}PA*SkH+0DPmC8c z2UPo;GIr12!hHfWxCi9wjNP<5IYY_)IeXudvAYtx5Z@=><*~aLyQ$dS>aqLc2e^j` zyCF{)mwu4*l7zAFiVVLOKVGB`2KjzEU&W5={5jW{{}&;y z@>_3lo5fdIe7D69S^QOt4_Z8Iv3k((nQZY4ix*p5Z}DXoUuE%)7T;&_V-`PS@t-X& zm;B>5%i@T|TP#jmtcftb{~<%a-{QM1evid_EKXS5WO0MV`zcrX?X);-@$i7LzsKSx ziziw9?SD7!S6dvh_^Ru(_>VdH*<#q%s4zQ@>yY>Cdoug0*+UUvJ+VEpE2BGly;&GFN^xEH1YAu^Wy3 z^%gf+Jl*0K-eKJLTHJ1NoyCsD-`!{E?zFhg;sq8z`F7)ezs2n~{CfL-hQ*KFVCZ&P zJj3FTzRlP#vAEpgC$Bg5J1t&fajC^myw$khZLxf7U4A?5`+AFKSnOE*jq42ET^83{ z9JaXJ;@@0r=nhzXm&JQ5ZnXHxy@u{ai!Ze}Z1ETW&A8uWaht_2USsU9wz$gTJvP5| zVlbNMKjk;W;$Pov-tV_~kHu>&KEvWa-el+=wfGi`FSB@!#ox8%b*tsC1HldN2+vbfITXWj5^gnsd96Ax#_7C&pT`i$W}+2Yw2FR^&7#Z4A>TYQ7XcUpYE#b2`c z`xZZIvGafp$Kn|lFSfYB;&zL#w0NJzcUt@@iyyQ2NsIqz@kedGRod}D{ikd`{)fTO z-*52a7C&b3LlzHMe2c|n!fJxY6Qe z7B8^4(&FJy7=Ayo_<+TqviLTO_gLI+aihhnEM92wbc?50Tx{`IZ9d*+@eLODS)8!= zQj1qvywKul7MEMBEFQMy@e_-`YVoHm?zcE$@g9rA7N2TyvBl5Za(%+$M=k!8#qYED z8jBMaZ?U+};#n4#*m`rS`)>2yHcO@MF!|rzZ}3KoD=q%h`;Gly-)Ha>7C&HdrNy`1 zZuwYTZ*i%`PuynQ-}hdFUuW^D7C-eKV?SVVpT!$3o^5fDjqd~RHuSqJj#yl3@zAZt zebC~&E#7N!qs3v1i!J{3yNtiDT70j?S6CdexYFVmZ!z=_*m%8Y@vx1@KKt(G_YKU6 zyuWbH`*l`s7F%3z@k2RuXFD~St=RRk9rifsB}NzGC}!+gTfB4Sx&>~M`kJq%rdqFv zT_UQ%tlpRqy4b5=k)+U#5N44~Ja9HOi!m2jY*l{@)Lb6xPDC!o;3TctM!!VsT5R6b z)!W??Un{ps$T`4yc5nYE5@u5Jk{qPfTghe(TPI4Cz0hgO4sx#SKpaJQl5gJ+)M)*% zrS0*S?dww8wAO)p^jL25xJ1v(t-j3n#!fBjBq<)!}Y3j4CTUm8~Ze*#H-nVE|=Vw#O#tgAQgdnUo@#fg923;&hkZaDYWiHFaA^do}c{gDdV2(&F{=rdq*!JBT zu{v%Mvli=Ss4PpmggWu+_b#Q;Up$mNs{o9jU#I<;Y{~T`8MlE{&COCEiW;R!6P0x$ zHza>gfG%8^ZnaA{ehhTp0}&5GzhGp(B3*NiQIw7$6shc{ntgW6VPyP_3)FO@mlsRL zqih{&xtWR-j^}dOlaqX!L+h9xYt{E8q%V#-N--J!}iHX_f%SNGsn zxjY=vdnl--#C>%6eH5%h++Ua9MsIFc)c zbjK9;^X2z3y2=+aM_bm-t{h9HF$g-Q{) zt)dg3MJRI%QNiY412ri#hAu$l9d1L(0LRiF-YskrJH1j=yQIpDqbXuD z%R0KAH@r$usiwIn5ztucP4q;(fyfwYoaEBd&C6;xnR=da*IBj2TI-rn0c~9YT4|8A zo0o1j)gLXnh&02gBl({je=x+!7%h-n7cu1QW>Ry!qhUTSqZI4SwW`x zBF^2T>``Cd9ZMEwpNYP%t&BRT=O#fby1RDNcXiQXJ!J~LCl1J|=z5uIAh)%8zCPyT zAsErCo}HkGbFQ0W_O*t~1h=ZLw=nN$ z>qwxXmXAn|Z>isXuZE*BopK4LXfx!YlUdp(O(G*y)9l3O)N9;XOj#35(=oyv4dnmP-ukg`Q3*hgix2=Rf65oA^8bOwvRh@!V9F$s7o(Cz5N0+`XBj zJ~|o7e8ij}krA0_k#?o)D@(_CY<^FjMqZB)gqZ|?c=R0c)RTKgr_L&+*UzS=ju;0s zGLUsbcl{(eRzF!GJ@+<|dA3Don(ij6Fy}oFYU*-^#&pQkzHb^H)a#6+J)@aErsh_F z%~@1tB>O+ta#l>nGf4Zfo~5EQzx@BV^G;f|SW4WA8ZGeEC%#J`l0uu$9%IdOnw82T z%DmF!-14Mo@z!>B#8S2B6z@)Th|XW4Mb@1+F^t+_hfE@N8!ujvMR*}gN0b}DM;b1f}gn!MQ+sA+H5SMDHDy3gj^12U$R$`SrE%%O;=bmy|%kMhN@v( zWOHrzwq6+sF5{&jACAz;QtYFHHOW|aDsEE0p<8b+c+NW8O@=tSl1zFwZ%E3mNnM6a z3(efPkVD#@k~+E}r<9iyB&XdlFGLSKo-=JxB`1>8`6zVHDZR(;@lY7QZnn8GM3@mr z&q7W*eOX~0di!*~0F6K}W_`6eI*~&+SMO}NmaB|bK6?+#I#Nl|=D#q$dlTYlZiGQ3 z5wo**x5Dde#X58HFw4b;FUw~7V&?MG^xRMy zPGn0DeW6`VcD4J}cxeInN@u9v`LOqRlIi~~dA_!zxjV-A+v|4m0^N*QjXb)_f$#L; z5!QP%`DcpOm@+Y6O*ZUls^Pp#TO=n!wad-;M>ecctf(A1I;8~Hk&m0RD6>{l^7`>m;CpSB@8t!;Sq`NxqG(I9R}Y*w1Ztwz zh6x9A%xGsvV!Mc=-8-7QI-_Xx>rpushWiKfh+2yYJKQt-ck^{NxSS0i&HW=j*$GTW zMQEEDjf@X3sa=(cho(*@%_#n(q@u27x~_(DLIN?AHL!Fy!m@9ZrS4`?>D8Ieu2^g2 zuvuje&~T}MY>~?Parm&=NXL-G=iDRmY04wHyQjS?Ve;i#C55mu-m{iW(d!y~q1Cn1 z?~F;AN-5Ru6G|kDS!du271pPHf2VQR6WfbXL|(p=&V#RD2$gyS-0XF8fpK_IC9k8g`LtNt&P#yA?!q)E4~vF)zzuKH`Trr*|xMR z(blo8*IdrWcSNjfSMpFtzA$23H|j-}c=w7}PhCd>p`-ZfjOKcX%ko64;!7mPb#uZj z&v<>FkmGYptedZxR$AsiObhJFQHqII(Nj+Lw zhuvVLu}V~^rf>ANS47Upt=w&?kLdVwHxpM~xs`KN&(cgA&3;5_QB6zCH0)7Dz8+i20^+Obt@i_1u@;@YXTE=A4n{hag65(rqV-#_~Mnimi6&pr3t zbI-orJ2R|O?Y7YZ87Y~a?)m*C6!!@EJ>IEe^vl=PPtP}--}2g#eeA*9Vovrm{N`iY z4bQA{5!!>m?AZQG+wKah^d5vDrB9-y;k!|8V(Hj+8F2S_xRk^d443029QI>Bb2u8I z{>;Ez7x6*LY^m=-;v{=`H?wWQ($f~tS+epoWwD(enX{<1)9sXSPvar=UY<4%BLv=x z-Q(Fq_WCKu&KMsYxxag0Y_=2yS59eJ+$_terSpt=PLC%>T3Yzo@Z#V~*`;o_Uy>$^ zZDU4f-KNQ}Hy9K31XsHK)#hm};fOgue|+q+h%t-u$1F?7@oO(Q+ZE`!`7N8c^99-Fac#BKON;$GoKYHoXQBmhRoxQ4cr!(36{ERX@S{@9I zYZI7X<&UQ?kyT#HeCz&;yt^D*w>$D~g?{b1$()Hj>+*Om$fW5U(UHu~>XBX`+Y4?RBc44ol?NmI z3a0wAH@ZA~Bx_xMJ5$dc-m*w;kP_l^klySvcQ?&3&r1@aNBZC z)z=G-iB@pwLV06|Cy<@pWs{uW9nD{hU$!my#XweV;|knKc`l4YJG*uB6uNThaSIl<72Mabn(Z8s>1^ll3QEg-xlg$8>ha>vuhs7mABu~Y zuyk9%9LHI0uF^R)m3+;7I$;C(gl__hD;yP5#(c za$`=*hkImOwxaN>sKJ%g$Y%M>!i08XHc_Lc7vf2(Em{x;GS_gz2y(yKc=HTxAz zIUm$JQk^?&CO-DnC|P|PGj_*w87+}E9^sj75=*|LLEH{=KN!V$@sF(Qv2#|mxX7Ax)@|IpZ!KRK z$vT@E+$Gl<50lLMn#52>M){B5F=o7AH%05Kb^E6skESkN(n1wjK93ZHYiJ9$_-8CZ#XMFYnjEkJP6hGPO ztNonye34dYtsPcu)U!?ZTtfBldv#Z^XT``c}d z6Y$gX8fY1 zZ7uG3ojI~-f(#1dX$XAx&!S^n?9*do(#1!&P``F^Ket%Eg~R<8;SMKd7!NO8GG5-L zGee5BSKJ}sNu{)9wftCDn7PZ-xg>_x{LbGe7+;2M-m z18$Ft8x>$4EsO`~w36=?#Pisrd1rAF1C9A-$+TsQs5TIE|<4vgWO7E@SNtlHOgO z-biQ?ZKAw1GLi4dCSKW4_mec~prU@vrqbhDB8zz@unceZGGoFe_7O#pdL_LZ-wSO^ zfSZrnIfNtUn}Z5mUH5^b>0Zp;T!oWQC5>w%0rvArX8$4{NXqNO&a*xRl%cK7Or9`p z>V%_5?dP819B7)uB!FtPh~GSK<~1L_pD)|#^NuA`9DWIXM|a&u=J(OvnPt*D@t}XH z{T@NK?0I{9$1_W0%ECL5i(3}+{YQphOBc16+48|w$(;@nC11J2n|UbPKNp$i@$Ky- zga(+6n0%(OS)6T0>8ey@f8KV(b=q- z5_WT2xMkkLR(^-tRTE-76nDv=d}_1X>5HhW#>^N$ZTi&Xj+`}LA#n~NRyQDl9tmJ@LTtMd!ijB$8vqkf@XGn zBogvgFF8my*vDTj%{<2<*PHp`_7#TDQrIU7t{cgoYFV;kA@7CByB@~u-^mAqQ9F+@ zmPs{~GS>HyxsRS{%|61Sezml>%#)wx>~zr};=-@y(QpWebyD*@D-3m3o^43tm;>i@ zbBmc1Ubye#wtZ>#`*P*J5!XfUyJ+OTBlmTS&>ej66^086PwyDQf|fbq-JK_^Svhs` zQTTb=FSN^nq9Vu4_e-K8i#209_(Mh?|EBNvtxt7y?6={?eXrWbnB`YG58~HjWw~?E zLk%wLW)fRnxLJqcyqhc!+Ni&CRyI@Vo86;3_k_yq+P^f5tS18ND%@P}_$mJEGr+>9 zX1fSaD`_X6UER0M&MZ24>oR`wNRBB;Cy!l)C;4Mv;gg3dcBilF^zQsfvul5gmYz&! z*E4#Xmp>p}ozY)0WcuBp^zM0iLXP(T{ausQ(@Z~kR z?3dH#YlA$d+hu&*bRdxTjnAFJrT<}Q=>dOI}$G94>F~X8Xp%JM0v{oxC;-_`LcasIUMcN21_{M}tRpZbV$7v;`%1#h#sQtbLZ%||Y6n;V!^5TmaA zxh5~);d?SW4+WlgaesD}mb65&LoKuVE-viqK0GmLWp2-uc&F&^+#acaY&_?G-l@|U z5p>b-;J=F>E9)v;-qv<6`PzPc*Ykj`eif>#@OJU@EcR99uKfQyA9d{NI{Ncne2?@l z%E^AugAS(oUVGQ|rK>N%cD}pLZ+ZlILBGqHOjqU4r!%|y4RU*nuKdr*?zeSui!S{0 zS;@o4*)~+YACxj_In=W8SnHx7!#?~e;HmC_|8J)Ib4kw*Wt5!)?NpjO`1{-oaTo0% zx4RSwWRv^fPp>hp*#2KCHDUN0>}qeRNx_3_*>OdO;f>hM=madOGv+q5ABM5J(NP%3 zHi;jY!X87XVN*TdQ6>K2H?Zf>A^1b=W$_RHioJ=>z^%QFc~>;NslPG)+ESB*4>a(0 zIXVSr`iLhw1kZ0Ip6D2S1^Wn{fg=V>TuV&=o`qGRWAJxaPjnhK3^8T^+6TXm4fEh1 zo~SwmSE-J|A7eiJB;cP^r{TXaxt5Jr^JPr@_~A^|A^0)2;4AnUYRom5=r}y&OHK#i zMr@OxG{R4?(Gup~#tg?q`{7BdL+}br?iGiBRGo%>hBS75g6;5OA67&^#>DZB==bpSqvNtveM z@X^McA$0~uRSV;ov=`wMnB-LoexjP!D$H=q<{>;ubqHRhItK4morHhGBn}xk@as5SAL;;n5DU=mQt*H=P6yzRR43rbgJ~PRsgv+rO!7Gf-&4)Q5Hk%E_YnM{>I8fe zlXjYhGsekNFZ{sj3C7$}!uQt0{V<7V0M1n%h6zl<6h4MYI4QUdlf25n@`+A+;Y{oX z+(Ym-)k)a<2d_lm>E z)jb72R?XK~nCX~}54;eQYh&=Q>duGG%ot445`Yh3qlrTrmK{Ud_QoBag!v_1aBGlu z-ix%uY13SsL-1})!cW4gV_m*^;cV3*c&q9J{7AJqj`U*QZiEkePN%H7uMhqM8#bbp zf-~k1m>2gnJnwjS-xwTz0^`?Lsn2jJCUFSEnwidz7e+8?FH!h5CiOo9FFVm)8;8$h zO?BLBmN8SX4``b~c#Y~fd`oo(9x~hc3BU_g$KW3@$KOc+>I_^yhi^oZw8PT5PJ7|ys^f6# zJkmnB2jM5G`Pi5_2a{`Kuvd%IKKMFj>i}$Rb#Vy8S5#+U>&fa4Us0Wbtqas0zM?t< zTNkQ3>~{+N7-i&xjf*Hp`UgK8v4r`_a5NklW=?{80FH}LzOOAx9oo}(jX8lnCji^eb~+00ILGNE+=fXV$iQjmx-diV0o5tE-+6j3 zxIuLSzK_{@yT+I!FwsGHh3YtbRdog)e!lvF7psoLmsDrq$nQBn0k~Rq4E_%${Xha{ zuz0IpUYgHNbV!RiZ~A1|D$ItVXP9fMD(PQmI6)jyo7ItUkF(!Rp*-itI3 z;hUJvLs+-Y`SHQSFmVsTe_<)wwTW@xOUX~_Kny;3nJc>#j9<=a#s%Rkm`x+R;QOx5 z$KXC!=okt&VA8jkE9qx2+t0w4RC}+Y9Iqzd>EBZDEzHIfPQBK-2jQ==oAHx`W3F@Y z2|-`nxeG`AQ0o%RY@k0JLjJ?MZ*V#Ze{myoIOb+a_|=WB5A?$y{MhL@EV@biM%a8a z{XG6d@YY+Hx6nos@L#u5)>6-Z3b6orB0K?;_{ZSB3GP9=^TRcmq&)^V{mjKV1>eEM zzxg@WVv@!XZ25)LVfd!%4E)n=&OHr(bUR@YegZDPgRzXVi^6BI1%v2g;rKi0#|b|O zw_tJNnT9oY(GM_h_riy;n@LLwM(!q_=qU8xL)$|OH)1zP+Tn%wk{-sI7+idx(_wgY z(&-@l2eyg6$lPzt_1H%IC*Y+II6rYX{x^h)pCBBy$+-vMzz3c7!Hcl@#32S>#3ZlM zu<;?+-u*D5ItqV+iJt^~8IwHEz)`<-bus`?!NgA(-m2~i_%SB=Z#J{O#N^rl{5B?I zView{?nzkuuygmqlQFqA3~#{X+5~(Xv*Rfo^@zp+UWkeR7<@wAQ?UL~=f?-NxbIoVyoBR7c@f)fqVLN#`d7A5fiw`~BXz2jB+P z30U_B=k9~wQ5}Qts5VbAN5v#=VfcI18FdTNGY0sQ-6P>#{UwW?$AXR4Dhty(zvH4Ps|R7c@&F_~|tV87R$_Q9tynPa5k zCzQ~H(}zQfUjdVUGU(yoO=LXpgIO$!m26=2ab7L{lJScagW2- z)I9^o{7wD9E!f$tH`B2H-|3@SxBB2o?>HTTYu}~6lCp!>yhoX$<8a*vtdYbW&i|13 z(Laab;s0_v2%EM$?T0`6h;eHu^$L#r#OWY>vcu^#l($@@PBKcF2F%t^I97E4HmeT7 zvs6dnb*kg=KGjM1u4==}HTz={W&rjsE;3=_?1RT);vRzQFv;^6+=AIMg5@Peww`!l z0<-Hi_zot0izzKKhhd_F@Vl6dO)*$p=ECv9)3Ki9a}@p&6aQ&=P&cOo@Cwy&cw#yC z8%*5bI!xjjgZHTpRS;fx;>^0t2Zv)<drwRBL7QjyiPN^(1XEV+Q z;Vq(ZPrx2kPJ7{E)nWL!>J;3!hx6ly7pjiI;nleJBM$IF(Gq9)mg)?gQsdl%FoH?C zqHu%i1pK|~H1v3#A1@4u;a?b5EsU!cPOL35?K1wt_f(rY@&Oa~1f1TB{6U9cXC`Z`b;QWW+OTHpgRYN|*uQwLi{ht6l z3zIq)g;!!yMsfH6Ci#$pp24m@c;P|AC}Y}M0KUF&5tDAhgo8#EnK$V_eQ@gjMP>l& zz>oQ}eY zrg+u4wLZX@M|X$fAS;%H(;_>O~8g(dM~&HlW@ZDJZutvV(>Z9xToQW*-rc6S*oLO ztLhB=+Bcn_0Gxx_yn;_-a^EzZaFWwOxKVWieu&w*Ycus66CH#bF&k$%Eac+shifo# zkHNPw$-@jhZ4UK(?@|+of5xQm4b3ewhs-N7eHkkQaNvBWeQ+_>j(ZsX36uTmG#t}f zME{C^xDHDV#Ae$VNC9uf<+6Q_QPW^$&(=LzK}M>9;p{T zg58Zy!S9_?WRC93z2Nbu;@+S5!@U=AFWmib4kqEpVdY{parVNaFbO{fAHrl0HU&$U zIPHag%-~130F$`@>I6(-;$N6iE%dfIccEXka0M2ntfTNoO!jFL@SJ7zt+>bF-!O?!2F_Sc{XvJ| zEtt3`pjknmg%(c0o{UhPVA30c)o0M}Nq)j(G0C40d;r_by)y8`GmGp#ObD*S z)<_)Sli1;uQ5yF9w#zFY{01g@8-izIlJ+S4nYw4-0cUCc!;w+OP3nIb&RI?TM8jJ# z$%h1d4YTQmU;d8Mez;I|82(sw0zQe^eRB9QCUG#|r7d8heQ+EmeuD5@s-y54O#H;* zA5^Dd?b$`P{`=rmYz=GkAY7&HQ5eVUweZjCo`!Yj=zZZVOxjThUaIbK_(ye5!*1s~ zKVG;mCSm&F6xBg^6(-@t;Y*m5T?URk&*=bMtvUvuRGo&6Yn&fH9DY9IEP3vS3o!{N z4DZ3DUrNHGzNd2=xJ7ju4qB^u0uR9?o&k6P_9oZH;Qg4yGX)R1fI2(`KQMxcdldcx zlfEbkCtSq!^hH5<2_}2#aags^`S-%Juti?3h5N*cOpG~+AKr-_hfczhi`74zhS@!N zxIuLSK95P9)8c*!ZH9X#;3=27u{8{@!z7+@_-8C7V=nA|nT8L?st&+YREOceutl;D ze|eEP5|eO(@N(=a=FD-J#EirN?)80Fui~)g3Rlm)u;@zSPkr;kd6?{9h2f2ugqeV^ zsLsGIUFF>Ua2_UU48xn%Jpo@)oq=Eefz~nDib&ithtFV=t~BiZLk%C!P#uDoVm3bTLDebPYl92N2iIZG5`GN6Dei=s zfd^ght_{E#CifDiRSOs2K)>3P_`m@-(SO#L8Xug81@Iq+o3Xy=6dZN4n`;H&`IuM2 zg!f}c{J{4yxtF5~jxz+jc!l$s&1IY(CHKBgs=|6LEh{8B#^8r4C zN%(2l`{z#k;0)CvcqwN6!)Mez4SWAW{lgimL-10}`iHUGSo6pn0Pg)uH@ES_(=fSj z6yB~n3Exv~ZZ9$iV&W$N+b{_;3h%+v%q5cW>0dFolRW|0`woo{oS`}d&%*4r@G0yD z`T_5q)Fn*955luCDeD-#1rz@X_zEU|GVn{kcG?eHRfpmAnDq}|QTGfSewX@(t*T>i z&fS#J-qZ(}!DPNFoN|vVqaZvBv-tyWSDl3KsW$hLKbXxQ*oN6O!o~L$u|_0(_!_oG zo`J%%lB~_qQ5eU>zwjAM!b!vS`7dtbQMZ|LU- zkZCACa+2vU*O5A!|;DpC*Z59GqCbujgL^ZABI#5W2%Ko z)xu|0r{Px~(fh(xs-y4`)hW0QleA=D<)eCEI2MyQ2VkqZhvD~C$KgY&Q}9jI8QAkN z7mg23Q5}S5s*b{2R43r`s?%`0YV$aAT1?_5JVtd8o~k+w*Q$=en^Y&@3z)<|4XdAU z+6%A3B+NK`P<0A=ey3>{s`kT4s)KNW>M&fRItDkYPQcBoQ*fK=3@lG+_(IiwI7xL7 zE>InYYgEVJM%4+pS#=6lJ*i>Bqg4mt1*&84G1V#fhUyHg`Mq8%RPBe8RR`gzs>ARq z)p7W+>JIB@ZIt90>&cO0NYWPCcemF^W5H3(1hHF&E;Ez=&;G?Qj@P=nymJ+T{ zlimv+tU3T^sSd%@RY&1fs^joZO!6%Wx2Ss>ex%y`nSL0vbpXy#9fIeoj=_ZL3_NlR z=^@M@T!6{7VR2I(gBw*R;AYh+m{Bb(e@@c{zot3>FH#+Y_o`09w^e6g`SW@$oU1wv z*Qt)dyD^(~xJBL5@ZuLZC-X!27$)&g!ELHDu;N8`tryP0B)wty8Yb@-WuPzZuJywc zF&WoF@Xwg66VtHlCH&*=gY=4&fev-&P%kMSpSrz3^hy zak&4>&OHF{SDk|U{?&!)hZm}j!ADf5;QOl0D~xxT#K8}bR~>>sRGol-S8ZOUo?te; z@I2Kq_$DUr>1AM_ZO)Gm&cx(7LI~c6Ng9)ID`wLU>t0j;upJXWQTQ*_=5_Y0Fq=Pc zE+*H8;T`Ipgn!56*}8dyc@t*+!>DTEotVTY3C)|%k1(KGcseG2qA;oMLi3jMAA(O~ z8P{*&hMyo@pgIh%z@!}G@VBZ{@O{b`KeD#^kwV z3ike&bN9hk)nT{}lXS)4BiJN#3cjuG8R*@v=;c>=epZA5}enlRWU@jeir5=xcrr?E(it)p^-U>_8bpYVN5)@Y{6V;V3C(+6+Fq>oL&*D)I= zT-w89_tL}gkC^yLS9?si8h5Q122~3$R2_rAQk{gas?NYVuih6<#Oxj;jHr7QZd9Fs ze^i}@y=q-P`{Co*mFy9v;3;)Zhhg`6r@inesuS?2p3Xf8Uss)h`}A^g@WWG8hv7}C z6Yvey892VT#u=_v9fOamPQl7P&W{&PRvmKJ@fbqZGQrT*b$)j_ycbqqeLIt45HtA99Ibr7yq9fOamPQl6n>K{&4 z9fYrAQfEQ~J?1-@w2>GbG6?_74gByos#9=YgL4nVVxQAqI1!U;gYaC{G5C<`6fA3W ze!N1}et76$>L~qi0M5cBoDh5fleV9N`wekA05_;kz-^eckqjKxJ?q24M`l0e4~S0L?d;!X!-Jf%HR|Eq8da>Nwo2It9xQa(=vUIwseK z;9HpFZ3dn>+UY3l@inKta118%-T*uSvvmVrfk`-V_>}529Qk$aW8wEy$KbRvu3rto zrh{Fd>4&p0`z}3v4wHU34Qm1()0aIxFC07G=>S|h(djU}_6U#35VttohDjd!CV9+x zSdcoFhUG`PIQ!wns*~{7N0A=-#w4sc+LfIbzBz@l2KNjcG}Xn;2WMh7-{94lj1h78 zoa!{3d5p)t#}b0WgPfxd_~G%GO(VP+o5ZyVc+@wX4#F5F;V0qJX&!T%%uis$u^tm= zuhj>Chn27gm4as-=P}E0kHQBrxi$rRO?TP{TU3YPHL8JYq5G;5Y5{OXA=K7M#RCTpK0ylR%yarjS6 z;%{bCKQYlE_)FDE*z22`F8C&Pw9FIW-6uJngcmeB9fPw%PKV%rbDR#qs<}>k;d83f z@aOZKdlFtS-{}~f-Qsiz4sGRrQl{_{OyU`bubk|$&%ZKo&;qT8a3&^Y9fB9CdkiL3 z3*T0qfrA$6wQwe8uZ8De@|-jV@4=+JQt-J&^eOWE6b2SM|H8Ox;mu*^Cjswnb2WH_o`%~oY5Qh{$LxiP_Q3-%dABM6r>hRZ<*K9bPE7nK;gELv zi^1d}yjhI=Ou!Dz)`6AGomSCq+2i%Xp5JDUhxWkRd z7apKG09#at;pM91@Ilon_)pb*akLqV$$kBB;deadGS+Becq1n63E1#m>O5)j!Q(M; z55W`9rZ2@k1V^1in_(>(f>&VTCl05bOFoMqcrGS+6@$-V(*LJnzw@+Q;F*|Q8-@3) zdkWUB(K3a{VKxr%3U!aeKdMf{`tzM1AH4B<9`h_~kOVyG0^03R;sejYB%Bz01e5xZ zf}g0nxsbjN6F+`DzJl!XIPeCjk$>%((~PR7~n^5H7|fo?&>F zx<}!)>K=!8s!qb!u>k%v@S@Az8YBjXf8S|8d_r{!F290uk!K>X{3_zlyI5Yh;s~XalFTC&&)d9Fjbr@c+IsuTKy1135FkG3vI26It)9o{b}RgTRmnRCf5ex|6uR(Za@Nd`ziNBdto%e zdaQ}^g|A`$9@GtZ?a%3NhfwF?J(%1p313j1h8?QSFIZ<`axXuer8)%Hs*b_CR43u9 zsxxrVZSGnhoT@qqPgfm1t>w^k;94(ei8^-u@zOhx;aB zaoUCBg@AT7suS=9)oIxCC9VH(2`2dyhS#c&!#`mXPw!UR9VUK)@O!Fb z@b{SbNyEW^aoP_Ts}92rm^}ASz{6j5bNC?qE+%2dVDYQe+iLO+&cf_ID!f8<9Bxsa zhKFo(Wfy?o#^m$bQMdtsJA$Toj%M>nuo%SW`FF5@Tr$g|d zH>n@E2jI51NQ2D7;gq*sdV}zj#6a%~)5QBiEx5M~^X>F(SE@J>v|f+QSKS#0y! z4>zbzz_&4}+Zp(cD(620e^AT4c<&+s8|sRUk1@yxS7Q=p4E_z1Ib;TU>Wgh2dg0!x z{cx=606eWH`9Qfu;TBBFJq@qu!?pDJad<@EVmo#O;X9a=g|}a^8IMWc2I0>!n=TmG z%W2_xn4~2JEBhC_z5Dp;|DnZW?-@Dg-L(ngZE(XvR|5n*N-kXS2D&W;ATv& zO~InCQ5I+~JaA00xs0?2;A%|VWAJTkwA6Wc>cKAjF#Ikibu0$o!fc+yeGYN)@xve{ z_Y%f1$sgf8n8Y&)KT&scDD?;v|9&_R6aQhj84H#YXXqL0+`aJbacJsl5`O3KV!rEw z@Znn%h?|rn+-G93eOKHMPgNa;H>pm*H&kcf-A6e8N%$%z@yWoMlPE9BE(E`R6!GC* zrT|=w$)0K$7EdlVVe;7vr(zQSAncezJ;A@3T5NuG4B?=Y@ZIAGzY_Q9#bzmX2RaOY zh)LZ@!15U`Z+-A!OyZMTMjpr0y9w@&xB609UJy!6#LxVdG2<2YyR+ z6yC2o1#3@qetd9_>M*=dbrP1%Dz<%{7p}mhEl1(cFx&6Jrr9o@emEJkbpx(Y9fh}I zHto=S)7@9-!>oUJnjA~`FchLL$(Rd6b6hb7VJ^M)HXq;}<*+lX%A9 z`Kw4T<8KW95)0y9UBo-4;P*1*fYO+KLG@YE>xLWkiSn0)pv1N(oc z*z8Z9`(O~0d=Nf@rRghD@Z57KJJuR8c-I=rl=VUyo^(F-oV7^^UW`c`;_yk;X;}L` z=k9~wP#uC7tB%7bRi|O?TJ;aVp*jRFRvm|b#XgYu!`chfKRg$c=K?Xf0h6&K0q<9x zf=6Gd`2+94>=*&xQk{VVE^_WZI7M|3o~b$pZ^k6O3HX@mH2llO+?RPp25z~8JS5-J zu<=q?)_(XbCSyk$4!VqbQbPX2IoKq^48uW}({GSQAN(z5>nD6iwfR23<&25{5WE`` zKS@}2g=?o?7{J6`*y}3h36z~rh{=6J@CTTL6NkUSBwZ;ud_C=l{W?Fq7EADXz&JeS z8rn8{xk1?PTE-6v6aD~8p%d_FOu|XS4%Oy5+AJpSK6n5odn^I?ZS019X|pgEC(hJK z;q^Z(HX(EZrZGv2@bwL@95e91>s^=u7*Z{ygSBIX@HA{bbu|iqjY%Am(ElSHFW})f z7Mn9DuOMv0#7_+V4KuWb4BY>JTwVp>^_ZPg!N!f$E&3=wT!2Zvio&fwrhbrD8QAmY zV)LBrjl&DENfJJsdyA_NVVJ-IxC>|9T5PIh{|$Zzv;8XEq&fu;{3-c{{{S5QGi*OJ z?DuoVLbMNlsM`F3F$;T_{Tvnn<`c~7UlyC8w-YYzemMPCPKRK}9pnq)m^+Kjb-#8x z2@k!SxXHcXz4thsglFAH8VM%~$KLPsM-LR6lYc`$i2nufVeD-5W3YS^br; zdmLP&T6m4>INYpS_`K>g3_MtD-jw*jFec#$&sHsb9Fu#c;H-z7o((7UPa0nE6!An0 ze~L-ml5mUam!a=zmxsdHssqmyo5QdG*M?v#HVGYuYq4RqrIt6MXQPGpW0#?Yf5Kiy z3mg7O-9QT$Veg{D@Iq`E@rl8WsuOUt>J;3jIs?m}b^g6@m})I8gLbqeJ;qzyru=7tU3&Tr8)__z2N+K z;V?|X@xv*qgK&ZBFkGWL1~;lsz|E>taGUB3EPqk+0S;5`hm%wX;R4lRxJGphZd9Fs zn^mXaHq{wemDcd#Fx7rIS#=PesyYnUsE)ynsuOUt>J)rUbq1Edq~XJTRQus%)j_yG zbr`Ny9fKQHC*Y&1Q}7Md8Cbbh!-xB*_QT1lgYZ<h7U)p4#1hJLvW?)D7;d29Nwim3Ad7e-L|P&lW{x= z-&1YgVO@ZUdjPIb9fj*L2|o_+Qk{g)s!qcXRGWVgAI!!Fj!_+eC#Vj=(^W^|eVC08 z+@kJjxLvh*moPCKCLFCg08daIg6*oKaJ}j{yi0WwKC3znKTvJ{N%)uzAC6WXfG4O9 z!FJVAc$Mlnyi;`&KC3znKTvJnqm5uTd^kpR0M1e!g6*oKaJ}j{yhC*oZc&|v+f|$Q z2_LiJ!!fD@aF*&2JY97Zu2&t0cd1UoEvnOSyK3_R;bS&@I7W2<&QcwMr>l;_^{V6W zF4al6MRgi(S8YBde9VRq$EXg#S*ks7}K^|8n)p2ai`Bg6mbs z;f(FhPYAA49fS9&PQopi#4Qa!P;EY<&R{mqaE$5zJVA8`o~}9y*Q<`hyHqFP7S(CE zUA6g`@G%=c9IZM4XQ>Xs^Hs;-ovM>?i|RDouG)OUUJ7Qzgkw|(;4IZ4*seMX*Q<`h zyHqFP7S(CEUA5_;95EX{9HTk_XQ>Xs(^W^|dew1wm+B4#Cq^N8x(aad?;NB;2Ap4Y#W{MTC#p_`@-(18|n=5IkLV6s}hthj*z?!q-%1 zV5P@}>4p2K_QT1lgYZ<Rn>kt zS#=PesyYm>QXPkPsZPQzs?%`0YExEXW@0kFgkU=+8hh}z3Mo;OLY=Hr#cO{tLBU5On=OV55J~50B5NV!FJVAc%|w% zyi0WwmQ<9OGuU_a!t+$e;9pf|;DqkZPZ0h{bpn=F>b0;%br?RVIt7PTIX`}Qj_MeE zL3JAT@8SIT;Al+J6@W`HDT^??UUdR)#>7twZd09s2UR=&0eGhBD7+FA|8aPSYT?VO zh2v{VOoF{8;VkTBv@n9nwZgTkh1aPTK8?M}9#ZDDaRx!yQDI zJI*MbVFG3}y73S*afIBGwa zhc5U0&Btc{eOm5)+J(Exw@e{i*3wKr9 z_B4ntb8b%q$>!mXxv)p!>@q)GE=rf}aw1!4s6D~fPsxaX;Z_E;!o0rZznct^nV!Bh+|12|xu1djmopkzk^y}2;=O*|& z-)py5%ThZo)vvO5>B7& zVUB#v1oxWUaaOG4l^=G+vxhkzP$KT$+;LW%W8G`BhdF*|-TUQ^v*LQTc#2J3_Atkn ztb2X#I4fqBiT}N_hdDN~@xmKVx#O(3+Pe489_DyVx%lsuJI;!jbtker{3We>)ef%w zz6>sP-h%SdEP00fuk#kybad>$Q>xTK)53PKF(r5Kp9`Rb zoXh^n^`G^A;cOsTY98%$zu@O{zjXFKS!&+9`$;Zn`uGfT_U}JQUohJkvw>+D?R3A{ zpU?f0yS<;3lBO^8`P|Q+55}diQ~C<-wFRf^A z^`C7~;yBX}Cvpt&l&uRSU^g$(d-?-=i@}zq?`@_bBih`!NU(l5FaTT_^nZytIH0F*g zOrX8h~lnnZa`%H_}hPWj~Xzw7eJ*N@$oPq|4EXD@Ac z$FxXo3l;tMr6t&X=d{R7C{#3?a+#egmnw5Yr+Qmyf*j}b?el48`Eu&Io#oGWU5@$l zpH_}B(&+0V?NrnM6KR)npG;cqpT|?~?LJ~Kn$Al8%9UPaU!_kuOzBsSwq!_Ey6bsn zsK@SlzCBTo-SvD6r5?NM8TsrneWub~&!^4x*j>-J2=L)Y2I36`mIzg8$E-ojiC$bVoP7tlZQ`3P${wYfNby z-(Mpau--KDnSV$4&NexgE%=wYK2pZMt%=&}9~ohqBBf;#pAHFsbB)n&a>^!c+fp|L|Nsr1AvemluRb z{5_H0vZitA<=A^k@7}riO87tFP{JQY_;Nl{+3CDz6?^3~_4S+=Ti!PRiBG%f zC-D!@CBF8&r?;H1)bpNItcEzY`OW#EE6O~k)tDZdze88BvRa_WTiH2Z*}z-L{Pi1D zWm=2y@8QeObLY!fuk3HPb+hNG54|LQw!ElMt5*)>t#Q46u3TSUX->-DuPM^oZb{ns zk^Z48x>uUnx$w;q=NrC9UIK{nr3&N+>9bQ3sb5=bN>;Dz!&}JwT_y43^Na0v%}9BX z|MvcUdDB?W&*h0rd%u2?{~bwJe*N0_>S*A&cTaL{lP&Llw*02K^7cgVE7wc?q&^*S z6~E_hCqu;NP2#g^P(Pbqsc)O8udCPzlk&6mZ5`z|pttOeb)^4c%d<98QL&dCSFfYK z(>~jM1EfB7$j0EyH6{Jp2iW5Z5Aor=cU7Iwg!0#WSJe+Nb2xsPSP`GqD+kx*(zk){ zR&B5DCC7DI-riN^J~KCe{pyuHd0Ux3$*;lP2)|;0nO7KIV?I9zml9rg8(wv`ezaHF z<8l{1`QMb!Umn)2UO9v}va{h+-tAuNzq|OSy{=yA<*nwz^L6%om7J%&uUd z??YjF9$Uxx>U{jSmYK##W#54kp1e!EwZY?EKq+EZg>V0Y`k zNa9mteC>lP$Nc1uvHT*vJ?4HDbhX**IbT~czYYNlb7^&;iWW%@lL3(R?O4Mz7ZGI2zl;7oT-QE6< zc-qP?)~dzznu5j`Lp>~E#=v> zFuoO$`Wp9?SIVP-^w!m}-;}){`CpaKf69;YughI;%da+HevIEkB9&5qxSgb5^1Hlz zh}=I{KIJ{~=OumJYYOkz7^&$iU74nr`}gE?PW;JuK40SBy?gHda;MFO%N(!JX;@~1ZF{igUt`W)X?&7jCH=frhR@SC zM*7zllsEOWrYd(#dMNMyh2iF>+i>htyUOVwOo>7_nJ^ZC)gf9{y}*Qvf}`?dKm>DwUnV?c%U2lct~Y8vYHH|73P zp7lNCe3P7~K6mO5Ydko)!T|{3U<5e|Jfb zz5ljuaxC>F7aw1wy1J2Lm);Sir~d$ny6w*%YT&q`UXH76{xF|lI+uK3W;EZm^BbA( zq&;%I%y+t)PxRLL#NT(APY?_8uU~t4Z@D}_pQy0M`T0auxf>t3p8Drq#lzOQx$%-? z_Ul7({7>dHGXLmN;rh$7Bz(?SR^;Xj_PCS|ug=dOHV)?4HyR^7 z3&!gjkC*vaZazo((BJ0gUzA6-e%SJ5J&~P{xbj*@c-}s)eo~(wk@34m+7Exy-%5L| zt+B^C9@DU+!g!)|L+8gg}%;~nuKJlcm>`XAR``CcC;!>Nz+^_}x&a=mL$H6D&jBz!%V z>pfC_?)n-tWhDpp&#ecJ;5cI|5bY;fWBTUKuYfAwFp{_kmurT$9$ zqrdZB(z9oW#9!8r^cQtM;ZW9(`SpLN^Pge;-?Q`jpQojJwEiFRDeM2n9oPSTISA42kU>oEa+u?w(I}?yIKE_koA90@xJTz zf5QNJ&VuoJfIZHS*DQxZ`RjLH{||KQ<;z*myY;`k>RPz|Z^(`B)VI|ueQy0M{<~cN z4|4fw&&&F~yA7{k{cn%+>;Iu{d>8+A{omx)|GDes7wOnu+;RP%J6~td=hy$bpJC72 z^}juD_cI?`7-PO_Wu2Mx8C_++@0`2y?F-ZH*q*CQ2? zih;THGvk>*;_L8b{XbMwB8)}O3TD+<@+>{s-))m`IHda82g?R=(Sy=ucNSkEwD=Kel+y~{6uWN^pe z?ER^avLBKQkNH$%#OEtGCVmC;BjzJge#}=~{MlbEZ_{@abMu`Z-5bg6{Q9SUP;UJr z>%&OxAU9zeM|q4OzXt8F|5C8OV&`WC`z550^>BAtKW67s{z!Lu4>*^fjgh9_ZoN}v z%X>(qv0UPlo3FLk_zKHwNMwjzzuWaS{eOAeAVd9d^F6y?Q?Ng<<9g4|2WxxCd~1*` zA6cLKy4&mBe30;T|3KGILn7sNJ)YgKsfhHbB>w!l^6^Cmb8_*s_R!Q+a7=hrwm!P}%Kkx*p}GALTYoD%t*`x&nht`_ zeJAn1;^AkZ*P}A;&1D#?pN6PyR@H*{QirLPt(xC z@;Wq9+adRpFZ<=M!3oAgEV`Qi2Ej`7cWxV|ua>Z7;t zxV){2wK{(i9_@+s_z<5QOZhUtJ~ZN${BiNI^{1D6XIK1(ci8W<^|!HH;%C?Y)-OkjUUh zsZaU#Flb<|J&@nB-$eYQ`SKVrpx~JH(%|+}<^H61NMz8Ee0XwyHy+yY%GQqo0}9K_ z7peE<`Ww5R8sI*qqV?<0NPUN#xAjM^kJJrn;MkU@lt=rZPW8WcdxO`z!ni#a{IegP zA5U$4^XB^t;(KVMu0!IR4WIQB1;YBN7yDg&4`c6kU!8>K!%gy=^Hp{~Mb7g*ha7je z>$UuGwfMEy^Zl8`cc829QvR}^(a46X?gtSc>dz24cKbhWe-NAmF>;LS0rK-1EuagVxZ&XCQRjxf!|J{Cv*B+w%}+&-9z$p`)9 zk4dRv|NTdtzv(AG`N~+&MS5OVFZ{`*uC>}+)m}1I`qR?hMddPHUK}YYE#Y|VJL{}} z`#uCsv)De*Y%on{c)3Wxs{_dR=>!M{+RxJl#wF$-ADFuetpU?q4PT zv(L|XK3^jK?fVqZ5TC2qzq48F;^S$raPNmmeW$$)T|sLN8B-zoNq!AoLF<^S$CM}i zZ+<|?d6pehkM>@!;|;Xqq*>!R#n^g5scm6x={Cs!QI__6qE$>6v{R8{HL!H}Ck@Ve2KeMW$+I*Aq)$442ksoe+ zx|s9)K3!9!+?H3FEidU`DoiQ)Vf$0^gZBaQ<5?5mi?9Am#6Qp5nj#hU`fj?vwt8h{ z-v3hb-yG1q06uKGG0KD^E_{`^{d$4 zuej_Yn?GLm<8$R_&v$e2DXaE~fA6a5YI&c;KF_E9^ZrkVq|fFz=gYF=ftPwS_#*PR zm&}iJJUhlTE|bRrW~>_@8#v#ygZ|x}XT8sS+wNZ*nQv47hD3VV@wJbP2OP8C-^+~$ z9CJTE^KsgnyZ(_H56{cJa^5{pck>~ieHl9cyvHl=e(3uxo_2jdDp&vMT|(NQ*!oZP zrT*Lgh4+% zud1(={B-T1#!GwPM^z<%nUC4?h3!So^LR=2E9LzbDW6p}H3ipG-o5hWC*`w>>X2_g z9Cs_uJzuidlfC)-bDr0ATzkp3C%(KtU%r0kyX26|@0X?hRLZ-`?%1}U0rjr^xcIML zS>-8VgUICC&%k=ue#mg{&-+AWx%1M0tlG<)yPp1(_QEIo_>=ZEw}ST9-j7gumVxU8p4x3oVQ zLHK<-m~UwPl>VctRGyz_*GIA+#tU%n{MMQxZ+mb1fvIeJlJH&sGEV!K@{WF4|J0wR z2#-$*ucGWCn|?p#>C2ucKJ53DN`C46yu`1OLyaH(AM;_E4`tT}jmvCzqUU8lrQGfp zWY7B}dv)wp5WX)`Q5Zkw%T$2%1zk2Ykp7NJXPiY6A=gOUz_e}^t z7hiv5KnLr|g7ZCk5MK5?`CF3T&$a0-wdu{y*Zq+l_ggD?;Z6REY<@PaDDSZ2ht~hb zNDsaVG=E;g=Xqqo{YwYQ^r%Utj|w5Uqc5$ zeRbz4pB^Q$|B(&Pu8+6p*UK^<*!bke0~*hQ?EJFET*UKHx&G`|S-+G1|7*`j|No!= z*Y7a@XReNci%8oIOwcLRvpLu46t~xPD;#Ey0Iz@9XkDUVC}D zOo!@ay@3DKD|?jX&d)8Pe=A52@nii^F8#e-KX89pKLlhvk@URFc+p-`FX^%4$JPes zGu8FE`Gllz1+P=fd{Npv$E1e@<>p6>&z}7F%zTUSyu{6an4jADJHunEgm35jybr_m zcK)OD*;V{lq}|`l&G$>)e1-YGy`CRFwey*?_(9|v<`2AHm7TwH?1`}FA<2>UZu6rq zpC21#yzf<>%O5-5S6AC(8SgiglV7E_{QF4y9+Uhm&FA;la`LCLa)3_ETxRD3Zocm3@9z9|$&VsCpOgH|&EMU6f$>P^@AU=cA>&o&`8)lg-H-36uCe}` z7%!zh^(ni^&hN{3zsBy@$avc{mGQL}q2m|hKl5|De`&{)%`)G%`|ED}q5OEh=Jvy+ zzh*ye2=jGWAN17zTi(Bl?l2!u+65Bz>2;|ep8C%7=hFJz{F(Y?*W)r@l{Q6rdE2|u zU?o4Y^XbZb{o7j8llgY%{E+%0*W2}{%%_W^KH9cqJH!GV*YyS zYpKqs-TIRK3On5zTSg2eelnl8^_TfN^^NtDecxW^_v~+!+Wn1^p6qW}{|x_j{h3|g z)^u~n+4}CSb;q_o5MFN%)~^*>|D}C!T-Qy{d)#qt zt)92-MLsWL>w~tBCgy)qUvxi0+DA#*2RO<5?RGuX&y#B(Tjld6-Q3=sTyN*!jjlaO z`?2%efhA=!-<9i`uk-m48@^pXvwqFq&$chSe%{NIyWfUt&X-yLUTJ?ee4i&5o{dj$ z*PiYD^&-(AXga7_=Z2OPCF8va| zO<(`ATzu_uLs>366ekHm1zU-m>4X7F-{@wbR z_BVKd%-`(%@8x6gKOo=#*!}Xo^8MKc-ize=jci!0J>dy1^xw3!Ta|rXbT>EZZ)}QHhw!f%hJwLcg&fEBsKb*JQf$skDJi|BG z$Jy-sRrV{S!?Epy{Sa?^eZA}d+<3+N{=UNZ2YCLvWB83+?($FCQ+|IUyI!d;T(1m? z^zktt&fR}VWZ=LCjcq{T9H8`gGN+D6p4t4iV9a5 zF6zzu_Nl5)C*j`r)_d!%_3mBwtg3VNK4;JS-~a#p|K9r?U#UDLeLFo^dt83iS6BG= zJBl~HuU7Czd5@oX_DQr)ZOfSSvxAKuz50dzAj)@jX@3Ru4@Hc>>Z?opHNgK%{$U?0 z6HEPg2lNs>07<-#k19KOcu@AP>OLPr}bazxvdF zQv7`4`oH9ro^iMIBB)NWe^2x_jyC)~zc({cvS+Un&(R*@5LAls1N(JQG-XTKe(l+- zfp04MG~|t!0P~*r3VU`;*q_%8tRup{e4_n2yCHtmTMt&|M+VO2EL0U z`%->od-Yhn3jIGmieH3$Twilk)WE}YppQ8APk8eAV2S@_2NNEBSDqt%56@Qihk1Aw z^01~Pe*?g?C@}IL-TyVFd`RD;AIN7f{NWP*9BaQ5{85ZI&%WXD`||oXVefcVw6D*z zuT}QPVgFc5X?<3SpJX>Q9QFL=_JVyaj`rJB#uLaN>=W|1TwcO90(g83`HA*>68Lc= zwC`1AzZfqYp{`c?6aMPqP#Mpa?R6W1N696}+K+97HhJ-w_^@Ys{ZlGWHTEN-pkn;u zu1J441p$wv)PB%!n>=&E(LkyF3$RaMe3$fR$OB=2#dt0GJ^LKk*;1revL9kUSvoXo zZ`u_F9`ILt>&v4zmhIbZUim%y+iOCe3t)po)F1QF6?qvPZNEUvX!;ocSf5^s@7bR` z{LwO+z9)ai*soTJ@M_7PD&+MJkze50`f12#VNdtgFNpj-`CPWgiS)hrW9(ha(f%7F zO=HUw+gMfF&o^4XNNN3?$ZwUv2UR|gphSAhh5f0#-)mGpXT`{#TC!&d`*rl#`Y6wS zoz0f`>-4hyI-3>vsk~l8;J?&pc`%-QSpT5nxhMbIXQBO-)M)#7^(*hc6XgRwDh#$B zZLRd@s|tg?Ohg}3o-ym~Krh~YQUS(dI*pCbuqw2_|6Rc z>A+}vycz=jXiJI)sL1mIxBsnD z9_(K|Hox!LGtr-w@?)oZsw2;r~sx-n^+NUyK=-D5=`i8KfM3hJHmtNp2 zG)kqu*#9E&q~7NBax$o70u}pckxnP(6L4OrI5dx!)J1W&GjokE-M+^vU!X{2}al zX+ghbd&{VOUeL2AA4>DX9({ZA0n6Dwig#Rz{zLu1?-PQbd;H$RFM)*M_s9I0|E>_C@_1BUqkJdIS7C1u%~P6B750V7d?4t@!`q=y={fXI@Rv&a z2Sxh}dK;Y|02QIW)#dto;{p9QhTgpJz9G3}Unz}OdYt{=lTXoO(K7B>J_3hZNdouoo{Q}1r{2Ts{@vri~z`r#o;$IlTj>o^H1=J;b6vl@# zI-Uf5tgkxGUJHEd<05|f{RbXi9e_S5=m+%Y;obhSAMxYZ&nm)u>o?%{=11Z8>>sed ziuFr^KYRGf!$;8)z7_g8o?|?tf`WfzeTgU!)=AE(p#1N8^z5y-DB)j> z_l;OT0egTHy@NBY$pRmIaM%KK$vg9$J73f|0aOnLq1z3)SW zA8sno7nbqva1sVG9A){K9Ill;{vu?$DE|ikad^0TV_lg)di)3VYxe2~{nHCyy%Flr zQc?a2o|cXZeqE_Q-rtZcJuk0s*dQMj`1;s$h+{84>}lXvY$FO%!SiU3*hb)rWe6|u zA^geVdhdM|-hLJDc`{k%&t>}=_j==OOyn zlTU%t{ITHgZwUKyy0qXxJZ}^BYDPrU9mC(H^p33;36|n}_8hUE&sVbFjOJgmZ>0oY zFX887_URI&ege#~cA?0G1!5Bmi?jxdU! zGr~KDpJ$HQvmkE--uLWNqP!j+%Vxl`VS~gdeRVeIp=9s&>~CKA+m7iA+l2jU%zW~4 zQU3tuE^&DE7y3#mzBm8w#m^jz|7r;zm+V!@&l`Vu--J{nA;1-^%+DXzTi# zR(MPGf&3KqMU1dg_$=0Yj$JS5wO>{`_M9WVni_$Rz4|>_Zr@6Nh!6Z%lg4uo|9SCi zO8w=vznJe2mh=_jZ;K(lKw9*lx1SsIbHS zjUb=V^aqpmLSORMXNmcm!BoA6mtO;HLj2?Oufi<&GgvE+L(tQ5VJ~fMgKrEywMImT z9K#2#t)jn5^rN7^aw2*;d@uSZ7ZIZASooIQXn%Mb1T+kXTA8S zrD$*9BM$#jvHlExyq}iA&Yh)6DUvQnsMOJ;tz*H++LHUE_%p?<6OI>47q5-PtrO#KE~S4` z{B&ti)^YL0trO$lU8>(n@!v1Ue|Aj%!cKHz{NI);c2fK&MPM9b;|sZXVtlM-5XVXJ zw}`+v{$9Fx^%rvV#Q13sJf0l?1`!y?*!V(joftn1Is*-!l>g?EkRwM#B4dsUIaaQ3 zAb-J6PCQn5pp%fL>~=qYO@9t#|CQv5Y#`Z+28ndSJA z(pBUSmX_O8zFoQ8U+({t;?FCGD4R<}e9&+Erb@)}^{3_jJ1PE>a)2@MLBAa*#{X@( z|4xc8bYO9ei4Xcass3*xh!~Tid}EKw_r^urjy+EFi+lwOt9++?Y*^twQod2SR>r?b zN5rd)^K!7!%YlU@XXQFJp2)LOzR?ILI4b2F8}I+DoRrQ)+V77mW$+HJ)8GsM6`x2a$U-#p9VJAK}+=)tf-)T+vTyE&oNjkP<3Yog3kUx`{s&*!^ z68f2>*EY{29jb9Nu~bu>EOD~M$w3g+6{jLjy;|WT#8!U6?7v1dy(-jI6}k!;46eK= z-Lnp#kWc53_YD2m}KixVXA76{WjaoE<`P6Hh{x_+!jxOtIJGDoZ_{z8(>%o0=8w?l=(yHaZhLW@rUtDh`swEr-9Qga07_J%=RCb4biRhYX@2dWLNH3WNRDETSZszs7x~7W-!ZGqAmG-!;VFR?I&rg_aV> zD&$*;vXsbieJP1qk8)$#rTs&G@*rnZX{~WSp$SyTP7u3Wn>a*ilo{)+d>h()h=zUB z@YNNWUq_SHbu?~0PcxS@`NOz*Nxm-xzX9lNRIe z%(kq2mE34G6Vo>GSyJovFpkwekCE~HdOeBR_-h7}rnB+pgg>+N9K1ICQ+vChj^=SZ8KhbSx%<@;Q!aW)Zp zolhHBvuqRbokB8hK1ZZkMAh+atD53W(rx$cA@n_&wh?kJQKk`|(Y0x0wlTij{VZX0 zxsP@d={z#k)+WxStj0JC!vPMu&NpRn$k;8&NF~Z#8p!F^T-vLf&mduY6Pc=yH=jmo zPQHUkbBHpZN*Ue#2of4^E}$e+L<5kyL^*}DxV{qzRi)FjIpmTpuEOHUe*H2 z*dg9MWbG#Ow;WYf79cMJY4jQf@~T4!rIm5Ru**3upi<6o7SJBIZ@;K;B_B!p6F9~7 zgpt#IEJx(oM7ftq()C1j<{^{6`nrdl9fVw1#gNH7tQq}zgwGQU=MdpJ(}YvgpFxeR zJX!>8g&O5ddp>1nk?;cq70nSck0__2OzvK$?7`xk1Af*)f3x2ejk7T>e-v??geaV+y zXEzc4Xdr3hZ(i@qSQDm@%*xAfa`MZAdah4h-*HyAPOLjD{Qew)xUs62^O+uKQYX#FBgWKjA^n$!B0QQsbN zS#d&DXmI8FTB2SvKUtU2>|kt#y=2L%Wvi7~PFupslsR*kEKVvjNjAD7smvj|s?8yj zqL&Hxc-7cSCaI=4S>j}ilY{4~D^5k6Ae_!6la)3;S$Qy-=r)`OliiN?V6w-u9!zSM zJ>^-c>V>)%qpM!IPD*$8OC)SQ3X(%?> zsR)ww&PB*X)>egL2*)!KF!7(%`u5c+@f7Jb%y>%bHLQ3lL|r0D^(kutQSCU=G#o^y ztz_z;0l(dAXz^44Rfwk=R5<;49RbLvBSl=7z{s)=nv8AE#vOj(I^SosNr zzC#pz8!oOdCMjbo>9F(7q|49_B$XMFV_3MDI(vZKursg$vvgz}j&Iw7F^0z1r%6-4?4QTz>*tpTesmJumo zv^Fct0n5xI3BFvgEy{PqsE>jp6S|CMhs;w*)parY3Cj199?e{m^6w@58x1k*d7p9# zBHDI>UqZqgsQ>u}dY~qHJ;^$Gjqx5Zo;#>qZx1stb(vnvdxzkbG5W`w&M>1}xHJs7 zh0iPSuY*?IVaDlcq;6<|>_@QE5brEST2~DE(44DrS`Xe6wJgGN+ z*~EH4m(J&jEI%kIHxQ{!hXhe3qokAQIMbcXXpqX2nOdBTlAgwaBPTOPzXYI0swOkv zBgtU4MR^njH72uA&IaGWAVTTn+a&f#QdtS)Faj9yMM=DNL$#)|D*ucWokOM$=+;fh zb)%%*M0^m}UID{XXuW$;3d40O_L+a1HR!jd6qWjX*GkHBOm1@!OqtFkdWcONG>g*_ zU^)ZS!mYS-4cHgFj6Ld)4Ml??PTB7oz+$H}MRlj6d5nN$CzNX?Ij6y|+>V*96+w!O z2S_x78f=oP0^dz|g)<5_5hdYb(Xjt5)?H>Pj zA8$ey(?)w)IEZH9Rj+SR^AV(dU_@{nP|MVgv$LwN;yg;BUZ9HOT+fgYq^ zlF&CKx>=RiHYqEJEY$~;71%v|PfKpdfctbP8v5tMA$DcU)FDIviR7cR()@lx2SE*% zdp~0A_A7N%222KCT0_FuQncqf(rXoeB>CwCF|2waP{DbvpJ6YDfCzGm1m^Lut^}jJ z+cGTR@ES?^a4gkW`hyw6c&7)*V#vSouMkqf*5CK z$W!}u%P&#DB}l&>%AS^`NuGq@?>+9wf54L&-F~1bJoB! zJ<*4nl_?|{-QTRtAf3T~PC=)fL9(hogXqzl#m!dLyaDph!VT3HCr6xIaVoB8;;NIf zYZ{CWLicmLuSrVT^!}!9D3=NHt!nkc9-i&#n>ZVUH=DGFpH2AY)j)(I{W)k-$jo+S zQWKDwhNr6eRosIJZC9-aEskLAsy)Y307baECNW%uD=b_{puMZuH>ZA_X8u0uOJamJ zrZ`ye>&f7naJ%M=X#BN~q}>@{6Kcryz{SQU|LLnP?is2tvNm#)qgk%*xUTW{hQwfp zMcaiiN|O4aP7AtR=+;1|u1BqC!1ABj82n|b-ONuWbQj57e5nKe8ekwdZ8d@1nn;2c z$>|Uy0GOmvVlIsv>0pDF4kqy$2C5`S<(&3a9uNo*`^eXOm&19jA`nnpk zc42%dZ%q7U6sjY%+rnaq##sQS+Hw-UzNto!HwO2F==VZ&j?5njP0^f*f3B)?KiaJH zQmMWuH$<=KlJtWIgUdyoJN9ivkQv+&K}lonbU~!LrSIZhGY7QMrBF9j{ljlSH@Re zrm(cEtTNB-E%rH4O0Sd4=y6G2>ePxI-65zopBHLP^!0=^jdbeTI;(0kA?I;me`z-9 zvh~{;`!ubw?_kl_JxH!!P{xsF`5N`EDf5%0wh)z+pGBeHjg!J|I;qd7iQ&KS(zWY* zmfxfXFK1*h(ABLUlsZgpTW<)s!0S+>_+yo~nffbIc(*FazTmAAyGUl2QB}K5^0C>~ zy{5AUm~}Q&Zj(rdiyj(hSwB{5-OEXfjt}Tk( zofleYm*ear@{@!uBDLCj(yXs1_4upSB!k-P0e zOTFV*`SXeIr@ek2mxg)hO|a89G0@(G*nEjM4Zc+Dba12B*54#P`qy~$YqH}QC?e9B zM*JM@pc$&it^rKe!M*|`m0INf#zOII}atc%n^oC?Z+_4 z_uBO>jyWOle31T1=I5e|-B7^OW;E*)&=eh~g>WELD1 zOPXSrW@r4Jwl?GseQX?HNh>I6Npn*FmqZMI<+~6|&kMNl7e9^LSNo%X@h6}t?!CU zhpL}V+~|)=Zu1V}supgjwm3QB~fmuToU>>s(pi5y5npvJ6sr& zAfXW|Y$V};Q3ez<%aByvNXqoW2s@XsDTMt&_kAotKM|qd;d$MFQm`WerlA-~PB%or z)Db&tI7t!)aCY-QH+*W*Nc!3N5k8r;zzIIyrMs(1;N^Ia;Tp-DsaYaA5awNNh^!*b z6ePEi4A&SvHIb`ol=%soa5i|?CTD)4>RRGkN3*sIW;;Kj8hFg!3X4B9s z?D?3oHpJPrjQzKSc&cH5r`DS@oWL~Fp$nc5 zo?35RXTi^}1Z$W^q7C*xJCT1SwaywA{#OEn!$bjv`w-7%t{SH4APeX%naex-(eosYMzIPDO2O^lt{3g@Rrr$q&cBfElvZs^2Pj>pB6YKR-8uTVh*cO|)^lDvn1Lw6;i+x6(|B1sWkkT(k zEw7T5$CEyOyBvNznKT|xHsap{PMYktw8xY5ju<31d!GAKq+M9_R6OBrmT|TZ;iWJ5 zg3QPERI{}-r#sx|Tb0aeMk5Ke(mHD!h1{G+>a|%UXSwr;e-f!JP68jEL>dZ{NFChD z8i})ce&uSPU?Tm*{};->kVcAxu*Gcx~l3uUo)=gP~N3kcgbi&syaXO`B^NW{LDvKTml9F@AyQzD)K02PskFu zv*|*4aV98EsMbEUA$qqoPS>lQ@GU-eBX4w`#fS7bFd`tYPpZ+cC2Xx7{dFDNNV;|R zRv*2SHO*Ai{76_djp1gf?E3`NR8Rk?D1{i!(^(b$RD|CwT8GgE z%FTrN{z<|`V$BM9f1>iH6@_>_m5FirH#KLQwYKl4m|lxOCmbgRNlt?bh;6-|PKY}O z<1@k-mr2)xp)=fT7jgCES$HT*zbU>@7@iHw_)THeJiA$f+-R1P#@Wp({sFpygOjyP z42Pw!M51RzItqDEyHz~9*;i}d%yNakr2!$X&Tg*GXzrc$Z-=h9WW$nywS!kI8|*)G zma#%Bm@}8H9ytB<4L|tEKLP`DjU{u=nzv-Gy=2Z@bIF{6x$YAD7Bnlf$@u8 zs@ei_YIIn*b=81dHO0viCtI8xadO3}h!f0q%`F@(%OTO`7FFwALwXFYrA4(`T4ZWR zm1l?rb!I7`F{VON0OIWDfc?c~#8F|Dfq*Lrr@0Kaw+w|8g+tZFE&HswToDc@nf4q? zYYR+u2;iv2msGB-;x4_pRe1xJ+}DE4oTIc7?oC=@;>w$}N@_3FXEpuBmJ}waQbt`% zXugZNsX%`ngZL{F$d2@&c16(azj#9{%)V3IChd|IwxpJ4N$9)=`e0Hjl4hq!>WdHc zvxN1LW;aU`+EoNTmTBB*ddiouSnE`=WK`IYk8@P^s7J8Y1t(6mD!OhTkix zOJTd3{{_zC3ysPdEn(Fv%*Al)BY_UPP$eYD%@dpVNP@}Kax zXx~ZKm_LuuHH1#1S>5O*sqW!4W_^;R**GHwG&zAJjjwWYF;o7ShH>mKDc?Ly8M%u| zZ*FjM6{~flyg~nK7&px@>cb%BItR zL)arN^~QMBcFT_gguv_bgktZgB1JwFwek z31O7Kn?xaO?rc!{Np;N0lNS4Ql6CXrDZ8w$Dj8$<)zel<-UBY#MKFm`Pyfi}E>h!c z^i|yigGso7x+&=(#UzdY)*1YcdL_e1yO!VLR}F23ME|-%9h_zoc746{6sUS`65-MhI*6*6CGKO(lGhjI#DQgI%#oE2RZ!%!so}rWMB9P z?6KDPGyNU5_79q~FvA;ofF|^VZN2Wqhg!qRszg|RxHSM}*2qti#4IBJNIYiV5K$&2 zsItE`$oAIyzLN;q40OvTv2I&`o_5%#itha!R&~aq{3*{qNbsRGb_X_(-mp}TGS*TrvEM=^wTYhCh;0O znN(CJu#_v|hSY2JrE?N}N@tpc=J_-g z-{&EoNU--$QGS<#qNp%5Lf(?lif(I&y);{CZ}h2dkHp(uvqx$+O=h{UptEk*JWM-W zTO}>lXDPFwhaaYK+Q@ob2aQz+^GZLFjbfovwBm0`Pq6^P((_%kcc2(y%qkj4NffTA zqj0c+HW-~?cs*z^8x2-n?4t^@t|`uK@>~yW*Cg1aj&&1cd92nI6`b=J-=U~(Rc z`2xv@@&NNaFkV@N2Gm|?QKm3LJF4Vml5&B>W7Zo@^wwNeg5@-4v4r+YFhEbEHm*Nk9rE81J-- zI6sIY3y)n$uy(z$kkF@T&d^uK1ROULLWB1a<>wo)zM+Y<6t2|M=8|Q^_gXA(PIz61 zN%qFLGP|2dKa7X!hr>rcf^njM)f6pkdE`)eFy~ck{Ck&4)Ve)odq2|uOEa6Ks ztlI8N@)S+Cp1^b`m7nIyrfO2uj6xfF2zJn51d}5N@!F#wg;nk02#IUMRnpf*Phm1h z*xyB$S?Dq>n;{_PsT7k7%AK6?SrUJaOW%UFf*u1c!Gj*Neksc_Es>B{5IR4uydq2N zDdM{{t2`lt8}{w&s)D7&C_W+c23dJhX6;5X!@{@GP7@_kE|e*KBPu9`qBBg7o}mzd z>ojtdL>es;!%7}g6e%d|7uI*>AWTYLVc_h-j3!F*pv0udB(-p-#K2;OF+rG+yR<^k z2l02GtfUOszn_tl_>&(H-*2Q&XTowOJwKeW`hrxQh|dsY?7qJdC3+=mcly3hm5I65 zx`8vBSs1r+w+^FT246~27->90&uBn=a6EwMz}jw18x-yN{Xr_PPRBX z;)LZUMYV{N@v+QtOWekt=v?;e!6s3t!p4p@IbYPN)F}X za)>465XJd;5al`LtSmO(ysTPhWtWLlXM_3Ch;a0UKZtc0#g}8_3@oIfzl};k8Gu;g z65rX#R%Jt%DD)}9E{Q?5U97O{+F(adljh>zqaB8p1~)I-ungXV(}cjTIo=fy_#VyxzGnTSl;4?CZl&=W zb9WB%7899|F0#z%*v(LL6^cux971FlstGhiQj? zXa*A?&1b7|HY7obZKvQ9`cqj{yTM>tRr?+l^7yfK`Yd!@$OQ(Z@*V<#S0Case=I{U zAXNuuDpzI#@*z&I%uF7z?RRO;b`J3t^IhuuVV2ji(9M)TnN^*qvmq>3qjA15+gyOm zd@2i8ff~c`=*iCNxY?PF0%KS|p)i&0r4YULPbkmjR7(g>2kuU&UiZ=QR{qM2Vxz*J z&H8T3gnj^tT6~IkJMIrC+nCMS?j{feg$PDlyP{sFwkbFI*iWhSaI0DnqEG)B}j; ziu5mS5T7m#q?+y`|30v-E+b!?!Q4Hgx5lJuCJDXAc)Kl>)bm8yC&NAT!R8(-_b5Ow6d*jlJqWB z^`|Lkr-a~Cwzt*5uF$q9X>D&)@o$p4Vz{|b)!%GJx@WaTaxpe7Cbh?+!IpY`B8diC z_%I1&y{P;JTBC0*uvFL!W^8e#5*y6S+hfarh>o`ZV z7}lrT%52+MFZ{;yWY|E|w}&;n`6+fTi8&#FD^ z_6+aYv}ennt$RlH9NdFA1<8B2QRj}$_mC|T5x0))d}HTZJF9mkcBOZ@yQb}$xoddW zrd^wNZP|6luC2R9cJ1GFVAr8thj+cW>y2G+?RtL~*&Wzjy*shHb9Z64yL;it(vejo z>qdr0HjQi^*)np+$kvgOk^LhFMh=b~8aX`j;>eMaH%8tXd4FWEQ_HN$0W$zt(x9%O;yMOP2y$AOm+Ix8Ki+hjkUAS-Q zzE%6kR!MC6B&8JocjCRV_pQC}?OXP_NDiA?kntb_f6Y3bKjhOr8L*=8{W5R z-{yT=_T8~>>%NhF`}ZM({gStzX3n;S+m>#F|Dc2wqj=xkju&?v+407Xw|2b0gX|3K ztlpW}ncmsCv#`_MIc?|6opW|B+_`k82q-~WA>_dJgWC^nKfL|L?VEOP-nnJx9Xq$~ z9ND>l=YgFEch1~3XBRmlkvzhUY}>!>;I>2CI=4^TJ`>^J*#6e`_qUTBfgPK6Y~HbD z#~oOBNe;{823!JLmu}s%^}yC6TLX6$?pk=)rn^S&I(*mrcO|w>+cvyyWZRK#)!W_e ztF~`J-fwO1+_4ZT96+h7QMPr+{RsMaCVKYZuJrCryASU!pwAAVcLI02cdxs9>)nU% ze*fss;tkK@ZK@m$;8qV}i{q3r_!i{D0X3{}XC9 BeQy8& diff --git a/backend-python/wkv_cuda_utils/wkv_cuda_model.py b/backend-python/wkv_cuda_utils/wkv_cuda_model.py deleted file mode 100644 index 7d17727..0000000 --- a/backend-python/wkv_cuda_utils/wkv_cuda_model.py +++ /dev/null @@ -1,734 +0,0 @@ -######################################################################################################## -# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM -######################################################################################################## - -import types, gc, os, time, re -import torch -from torch.nn import functional as F -torch.backends.cudnn.benchmark = True -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True -current_path = os.path.dirname(os.path.abspath(__file__)) - -# https://zhuanlan.zhihu.com/p/612879065 -def LoadPreCompileLibrary(file): - import importlib - import os - - import torch - - # load the custom_op_library and register the custom ops - lib_dir = os.path.dirname(__file__) - if os.name == "nt": - # Register the main torchvision library location on the default DLL path - import ctypes - import sys - - kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) - with_load_library_flags = hasattr(kernel32, "AddDllDirectory") - prev_error_mode = kernel32.SetErrorMode(0x0001) - - if with_load_library_flags: - kernel32.AddDllDirectory.restype = ctypes.c_void_p - - if sys.version_info >= (3, 8): - os.add_dll_directory(lib_dir) - elif with_load_library_flags: - res = kernel32.AddDllDirectory(lib_dir) - if res is None: - err = ctypes.WinError(ctypes.get_last_error()) - err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' - raise ValueError(err) - - kernel32.SetErrorMode(prev_error_mode) - - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES, - ) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) - ext_specs = extfinder.find_spec(file) - if ext_specs is None: - return False - - try: - torch.ops.load_library(ext_specs.origin) - except OSError as exc: - return False - return True - -######################################################################################################## - -if os.environ.get('RWKV_JIT_ON') != '0': - os.environ["RWKV_JIT_ON"] = '1' - MyModule = torch.jit.ScriptModule - MyFunction = torch.jit.script_method - MyStatic = torch.jit.script -else: - MyModule = torch.nn.Module - def __nop(ob): - return ob - MyFunction = __nop - MyStatic = __nop - -if os.environ.get('RWKV_CUDA_ON') == '1': - if LoadPreCompileLibrary('wkv_cuda') is False: - from torch.utils.cpp_extension import load - load( - name=f"wkv_cuda", - sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"], - verbose=True, - extra_cuda_cflags=["-t 4", "-std=c++17", "--use_fast_math", "-O3", "--extra-device-vectorization"], - is_python_module=False) - - @MyStatic - def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): - assert 1 * C % min(C, 32) == 0 - assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32 - assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32 - w = w.contiguous() - u = u.contiguous() - k = k.contiguous() - v = v.contiguous() - y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype) - torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) - return y, aa, bb, pp - @MyStatic - def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): - assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype - assert x.dtype == torch.float32 or x.dtype == torch.float16 - assert w.dtype == torch.uint8 - assert x.shape == [B, N] - assert w.shape == [N, M] - assert rx.shape == mx.shape == [M] - assert ry.shape == my.shape == [N, 1] - y = torch.empty((B, M), device=w.device, dtype=x.dtype) - torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) - return y - @MyStatic - def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): - assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype - assert x.dtype == torch.float32 or x.dtype == torch.float16 - assert w.dtype == torch.uint8 - assert x.shape == [N] - assert w.shape == [N, M] - assert rx.shape == mx.shape == [M] - assert ry.shape == my.shape == [N, 1] - y = torch.zeros((M,), device=w.device, dtype=torch.float32) - torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) - return y.to(dtype=x.dtype) -else: - os.environ["RWKV_CUDA_ON"] = '0' - -######################################################################################################## - -class RWKV(MyModule): - def __init__(self, model, strategy, verbose = True, convert_and_save_and_exit = None): - super().__init__() - if verbose: - prxxx = lambda *args, **kwargs: print(*args, **kwargs) - else: - prxxx = lambda *args, **kwargs: None - - STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" - if not re.match(STRATEGY_REGEX, strategy): - raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") - - strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ') - self.args = types.SimpleNamespace() - args = self.args - args.MODEL_NAME = model - args.strategy_string = strategy - - # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) - self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 - prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n') - - args.MODEL_NAME = args.MODEL_NAME.strip() - if not args.MODEL_NAME.endswith('.pth'): - args.MODEL_NAME += '.pth' - prxxx(f'Loading {args.MODEL_NAME} ...') - with torch.no_grad(): - self.w = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first - gc.collect() - w = self.w - - ALREADY_CONVERTED = False - if '_strategy' in w: - ALREADY_CONVERTED = True - assert convert_and_save_and_exit == None # you should only convert a raw model - prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n") - assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model - assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py - assert w['_rescale_layer'] == self.RESCALE_LAYER - del w['_strategy'] - del w['_version'] - del w['_rescale_layer'] - - args.n_embd = w['emb.weight'].shape[1] - args.n_layer = 0 - keys = list(w.keys()) - for x in keys: - layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 - args.n_layer = max(args.n_layer, layer_id+1) - - ####################### Compute strategy - - s = [x.strip().split(' ') for x in strategy.split('->')] - plan = [0] * len(s) - stream_i = -1 - stream_count = 0 - to_allocate = args.n_layer + 1 - allocated = 0 - free_slots = 0 - for i in range(len(s)): - si = s[i] - si1 = si[1] - if si1.startswith('fp32'): si[1] = [torch.float] - elif si1.startswith('fp16'): si[1] = [torch.float16] - elif si1.startswith('bf16'): si[1] = [torch.bfloat16] - if si1.endswith('i8'): si[1] += [torch.uint8] - else: si[1] += [si[1][0]] - if len(si) > 2: - ss = si[2] - assert ss.startswith('*') - if ss.endswith('+'): - plan[i] = int(ss[1:-1]) - stream_i = i - else: - plan[i] = int(ss[1:]) - allocated += plan[i] - if allocated >= to_allocate: - plan[i] += to_allocate - allocated - break - else: - free_slots += 1 - if stream_i < 0: - if free_slots > 0 and to_allocate > allocated: - for i in range(len(s)): - if plan[i] == 0: - plan[i] = (to_allocate - allocated) // free_slots - allocated += plan[i] - free_slots -= 1 - if to_allocate > allocated: - plan[len(s)-1] += to_allocate - allocated - else: - if to_allocate > allocated: - stream_count = to_allocate - allocated - plan[stream_i] += stream_count - prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') - for i in range(len(s)): - ss = s[i] - if i != stream_i: - prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') - else: - prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') - plan[i] += (0 if i == 0 else plan[i-1]) - self.strategy = [None] * (args.n_layer + 1) - strategy = self.strategy - for n in range(args.n_layer + 1): - for i in range(len(s)): - if n < plan[i]: - strategy[n] = types.SimpleNamespace() - strategy[n].device = s[i][0] - strategy[n].atype = s[i][1][0] - strategy[n].wtype = s[i][1][1] - strategy[n].stream = False - if i == stream_i and n >= (plan[i] - stream_count): - strategy[n].stream = True - break - prxxx(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ') - prxxx() - - ####################### Load weights to self.w - - if not ALREADY_CONVERTED: - try: # precompute embedding - w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias']) - except: - w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float()) - del w['blocks.0.ln0.weight'] - del w['blocks.0.ln0.bias'] - - print_need_newline = False - keys = list(w.keys()) - for x in keys: - w[x].requires_grad = False - layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 - if ('ln_out.' in x) or ('head.' in x): - layer_id = args.n_layer - dd = strategy[layer_id] - DEVICE = dd.device - ATYPE = dd.atype - WTYPE = dd.wtype - - if not ALREADY_CONVERTED: - if self.RESCALE_LAYER > 0: - if 'att.output.weight' in x: - w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) - if 'ffn.value.weight' in x: - w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) - - if '.time_' in x: - w[x] = w[x].squeeze() - if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x: - w[x] = w[x].t() - - if '.time_decay' in x: # need fp32 for this - w[x] = -torch.exp(w[x].float()) - elif '.time_first' in x: # need fp32 for this - w[x] = w[x].float() - else: - if (len(w[x].shape) == 2) and ('emb' not in x): - if WTYPE != torch.uint8: - w[x] = w[x].to(dtype=WTYPE) - else: - w[x] = w[x].float() - - if w[x].shape[0] > w[x].shape[1]: - w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) - w[x] = w[x] - w[x+'_my'] - w[x+'_mx'] = torch.amin(w[x], dim=0) - w[x] = w[x] - w[x+'_mx'] - w[x+'_rx'] = torch.amax(w[x], dim=0) - w[x] = w[x] / w[x+'_rx'] - w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) - w[x] = w[x] / w[x+'_ry'] - else: - w[x+'_mx'] = torch.amin(w[x], dim=0) - w[x] = w[x] - w[x+'_mx'] - w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) - w[x] = w[x] - w[x+'_my'] - w[x+'_rx'] = torch.amax(w[x], dim=0) - w[x] = w[x] / w[x+'_rx'] - w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) - w[x] = w[x] / w[x+'_ry'] - - w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8) - w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() - w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() - w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() - w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() - else: - w[x] = w[x].to(dtype=ATYPE) - - if convert_and_save_and_exit == None: - if 'emb.' in x: - w[x] = w[x].contiguous() - elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): - try: - w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) - except: - print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.') - elif DEVICE != 'cpu': - w[x] = w[x].to(device=DEVICE).contiguous() - - if (dd.stream) or (DEVICE != 'cpu'): - try: - w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous() - w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous() - w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous() - w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous() - except: - pass - - if 'ffn.value.weight' in x: - gc.collect() - if 'cuda' in args.strategy_string: - torch.cuda.empty_cache() - - shape = [i for i in w[x].shape if i != 1] - if len(shape) > 1: - shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" - else: - shape = f" {str(shape[0]).rjust(5)} " - if layer_id == 0 or layer_id >= args.n_layer-1: - if print_need_newline: - prxxx('\n', end = '') - print_need_newline = False - dt = str(w[x].dtype).replace('torch.', '') - dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') - prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') - else: - print_need_newline = True - prxxx('.', end = '', flush = True) - - if convert_and_save_and_exit: - w['_strategy'] = args.strategy_string - w['_rescale_layer'] = self.RESCALE_LAYER - w['_version'] = '0.7' - if not convert_and_save_and_exit.endswith('.pth'): - convert_and_save_and_exit += '.pth' - prxxx(f'Saving to {convert_and_save_and_exit}...') - torch.save(w, convert_and_save_and_exit) - prxxx(f'Converted and saved. Now this will exit.') - exit(0) - - gc.collect() - if 'cuda' in args.strategy_string: - torch.cuda.empty_cache() - - @MyFunction - def torch_mm8_seq(self, x, w, mx, rx, my, ry): - return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) - - @MyFunction - def torch_mm8_one(self, x, w, mx, rx, my, ry): - return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) - - if os.environ.get('RWKV_CUDA_ON') == '1': - @MyFunction - def mm8_seq(self, x, w, mx, rx, my, ry): - if w.device.type == 'cuda' and x.dtype == torch.float16: - B, N, M = x.shape[0], w.shape[0], w.shape[1] - return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) - else: - return self.torch_mm8_seq(x, w, mx, rx, my, ry) - @MyFunction - def mm8_one(self, x, w, mx, rx, my, ry): - if w.device.type == 'cuda': - N, M = w.shape[0], w.shape[1] - return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) - else: - return self.torch_mm8_one(x, w, mx, rx, my, ry) - else: - @MyFunction - def mm8_seq(self, x, w, mx, rx, my, ry): - return self.torch_mm8_seq(x, w, mx, rx, my, ry) - @MyFunction - def mm8_one(self, x, w, mx, rx, my, ry): - return self.torch_mm8_one(x, w, mx, rx, my, ry) - - ######################################################################################################## - - @MyFunction - def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - vx = torch.square(torch.relu(kx @ kw)) - out = r * (vx @ vw) - return x + out, xx - - @MyFunction - def ffn_one_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) - out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) - return x + out, xx - - ######################################################################################################## - - @MyFunction - def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - vx = torch.square(torch.relu(kx @ kw)) - out = r * (vx @ vw) - return x + out, xx[-1,:] - - @MyFunction - def ffn_seq_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) - out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) - return x + out, xx[-1,:] - - ######################################################################################################## - - @MyFunction - def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - k = (kx @ kw).float() - v = (vx @ vw).float() - - ww = t_first + k - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, k) - e1 = torch.exp(ww - p) - e2 = torch.exp(k - p) - - out = (r * wkv) @ ow - return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - - @MyFunction - def att_one_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) - k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() - v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() - - ww = t_first + k - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, k) - e1 = torch.exp(ww - p) - e2 = torch.exp(k - p) - - out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) - return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - - ######################################################################################################## - - @MyFunction - def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - k = (kx @ kw).float() - v = (vx @ vw).float() - - T = x.shape[0] - for t in range(T): - kk = k[t] - vv = v[t] - ww = t_first + kk - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, kk) - e1 = torch.exp(ww - p) - e2 = torch.exp(kk - p) - aa = e1 * aa + e2 * vv - bb = e1 * bb + e2 - pp = p - out = (r * sx) @ ow - return x + out, xx[-1,:], aa, bb, pp - - @MyFunction - def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() - v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() - - T = x.shape[0] - for t in range(T): - kk = k[t] - vv = v[t] - ww = t_first + kk - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) - ww = t_decay + pp - p = torch.maximum(ww, kk) - e1 = torch.exp(ww - p) - e2 = torch.exp(kk - p) - aa = e1 * aa + e2 * vv - bb = e1 * bb + e2 - pp = p - out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) - return x + out, xx[-1,:], aa, bb, pp - - ######################################################################################################## - - if os.environ["RWKV_CUDA_ON"] == '1': - @MyFunction - def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - T, C = x.size() - xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(rx @ rw) - k = kx @ kw - v = vx @ vw - y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) - - out = (r * y) @ ow - return x + out, xx[-1,:], aa, bb, pp - - @MyFunction - def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): - T, C = x.size() - xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) - sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) - kx = xx * k_mix + sx * (1 - k_mix) - vx = xx * v_mix + sx * (1 - v_mix) - rx = xx * r_mix + sx * (1 - r_mix) - - r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) - k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) - v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) - y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) - - out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) - return x + out, xx[-1,:], aa, bb, pp - - ######################################################################################################## - - def forward(self, tokens, state, full_output=False): - with torch.no_grad(): - w = self.w - args = self.args - - if state == None: - state = [None] * args.n_layer * 5 - for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx - dd = self.strategy[i] - dev = dd.device - atype = dd.atype - state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() - state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 - state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() - - seq_mode = len(tokens) > 1 - - x = w['emb.weight'][tokens if seq_mode else tokens[0]] - - for i in range(args.n_layer): - bbb = f'blocks.{i}.' - att = f'blocks.{i}.att.' - ffn = f'blocks.{i}.ffn.' - dd = self.strategy[i] - dev = dd.device - atype = dd.atype - wtype = dd.wtype - if seq_mode: - if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': - ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 - else: - ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 - FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 - else: - ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 - FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 - - x = x.to(dtype=atype, device=dev) - - kw = w[f'{att}key.weight'] - vw = w[f'{att}value.weight'] - rw = w[f'{att}receptance.weight'] - ow = w[f'{att}output.weight'] - if dd.stream: - kw = kw.to(device=dev, non_blocking=True) - vw = vw.to(device=dev, non_blocking=True) - rw = rw.to(device=dev, non_blocking=True) - ow = ow.to(device=dev, non_blocking=True) - kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x - krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x - kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x - kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x - vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x - vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x - vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x - vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x - rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x - rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x - rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x - rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x - omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x - orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x - omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x - ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x - x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( - x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], - w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], - w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], - w[f'{att}time_decay'], w[f'{att}time_first'], - kw, vw, rw, ow, - kmx, krx, kmy, kry, - vmx, vrx, vmy, vry, - rmx, rrx, rmy, rry, - omx, orx, omy, ory, - ) - if dd.stream: - del kw, vw, rw, ow - - kw = w[f'{ffn}key.weight'] - vw = w[f'{ffn}value.weight'] - rw = w[f'{ffn}receptance.weight'] - if dd.stream: - kw = kw.to(device=dev, non_blocking=True) - vw = vw.to(device=dev, non_blocking=True) - rw = rw.to(device=dev, non_blocking=True) - kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x - krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x - kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x - kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x - vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x - vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x - vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x - vry = w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x - rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x - rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x - rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x - rry = w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x - x, state[i*5+4] = FFN( - x, state[i*5+4], - w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'], - w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'], - kw, vw, rw, - kmx, krx, kmy, kry, - vmx, vrx, vmy, vry, - rmx, rrx, rmy, rry, - ) - if dd.stream: - del kw, vw, rw - - if self.RESCALE_LAYER > 0: - if (i+1) % self.RESCALE_LAYER == 0: - x = x / 2 - - dd = self.strategy[args.n_layer] - x = x[-1,:] if (seq_mode and (not full_output)) else x - x = x.to(dtype=dd.atype, device=dd.device) - - x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) - if w['head.weight'].dtype != torch.uint8: - x = x @ w['head.weight'] - else: - if seq_mode and full_output: - x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) - else: - x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) - - return x.float(), state