From 0063c171f3efb5107e936ee739dbffa697d2c755 Mon Sep 17 00:00:00 2001 From: josc146 Date: Fri, 24 Nov 2023 17:55:16 +0800 Subject: [PATCH] upgrade to rwkv 0.8.22 (rwkv6 support) --- .github/workflows/release.yml | 2 + .gitignore | 2 +- backend-python/rwkv_pip/cuda/rwkv6.cu | 87 +++ backend-python/rwkv_pip/cuda/rwkv6_op.cpp | 34 + backend-python/rwkv_pip/model.py | 890 +++++++++++++++++++--- backend-python/rwkv_pip/rwkv5.pyd | Bin 419328 -> 419328 bytes backend-python/rwkv_pip/rwkv6.pyd | Bin 0 -> 419328 bytes backend-python/rwkv_pip/wkv_cuda.pyd | Bin 474112 -> 474112 bytes 8 files changed, 909 insertions(+), 106 deletions(-) create mode 100644 backend-python/rwkv_pip/cuda/rwkv6.cu create mode 100644 backend-python/rwkv_pip/cuda/rwkv6_op.cpp create mode 100644 backend-python/rwkv_pip/rwkv6.pyd diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5d7d7e2..55485bf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -109,6 +109,7 @@ jobs: go install github.com/wailsapp/wails/v2/cmd/wails@latest rm ./backend-python/rwkv_pip/wkv_cuda.pyd rm ./backend-python/rwkv_pip/rwkv5.pyd + rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py make @@ -141,6 +142,7 @@ jobs: go install github.com/wailsapp/wails/v2/cmd/wails@latest rm ./backend-python/rwkv_pip/wkv_cuda.pyd rm ./backend-python/rwkv_pip/rwkv5.pyd + rm ./backend-python/rwkv_pip/rwkv6.pyd rm ./backend-python/rwkv_pip/beta/wkv_cuda.pyd rm ./backend-python/get-pip.py make diff --git a/.gitignore b/.gitignore index b6ebf31..246fbdf 100644 --- a/.gitignore +++ b/.gitignore @@ -18,7 +18,7 @@ __pycache__ /cmd-helper.bat /install-py-dep.bat /backend-python/wkv_cuda -/backend-python/rwkv5 +/backend-python/rwkv* *.exe *.old .DS_Store diff --git a/backend-python/rwkv_pip/cuda/rwkv6.cu b/backend-python/rwkv_pip/cuda/rwkv6.cu new file mode 100644 index 0000000..91d9cd3 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv6.cu @@ -0,0 +1,87 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +template +__global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, + const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, + F *__restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h*_N_; + _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! + + __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; + + float state[_N_]; + #pragma unroll + for (int j = 0; j < _N_; j++) + state[j] = _state[j]; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) + { + __syncthreads(); + w[i] = _w[t]; + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + + #pragma unroll + for (int j = 0; j < _N_; j+=4) + { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } + #pragma unroll + for (int j = 0; j < _N_; j++) + _state[j] = state[j]; +} + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) +{ + assert(H*_N_ == C); + kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); +} diff --git a/backend-python/rwkv_pip/cuda/rwkv6_op.cpp b/backend-python/rwkv_pip/cuda/rwkv6_op.cpp new file mode 100644 index 0000000..3701ea0 --- /dev/null +++ b/backend-python/rwkv_pip/cuda/rwkv6_op.cpp @@ -0,0 +1,34 @@ +#include +#include "ATen/ATen.h" +#include +typedef at::BFloat16 bf16; +typedef at::Half fp16; +typedef float fp32; + +void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); +void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); +void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); + +void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} +void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); + cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16"); + m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16"); + m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32"); +} +TORCH_LIBRARY(rwkv6, m) { + m.def("forward_bf16", forward_bf16); + m.def("forward_fp16", forward_fp16); + m.def("forward_fp32", forward_fp32); +} diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index 4173c4f..cdfec6c 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -5,6 +5,7 @@ from typing import Optional import types, gc, os, time, re import torch +import torch.nn as nn from torch.nn import functional as F torch.backends.cudnn.benchmark = True @@ -80,7 +81,7 @@ else: if os.environ.get("RWKV_CUDA_ON") == "1": DISABLE_CUBLAS_GEMM = False - from torch.utils.cpp_extension import load # L581 + from torch.utils.cpp_extension import load if LoadPreCompileLibrary("wkv_cuda") is False: try: @@ -374,6 +375,11 @@ class RWKV(MyModule): if len(w[x].shape) > 1: if w[x].shape[1] > 1: self.version = max(5.2, self.version) + if "time_maa" in x: + self.version = max(6, self.version) + if int(self.version) == 6 and "time_faaaa" in x: + args.n_head = w[x].shape[0] + prxxx(f"Model detected: v{self.version:.1f}") ####################### Compute strategy @@ -524,22 +530,24 @@ class RWKV(MyModule): ): w[x] = w[x].t() - if ".time_decay" in x: # need fp32 for this + if ".time_decay" in x and "_w" not in x: # need fp32 for this if self.version == 4: w[x] = -torch.exp(w[x].float()) elif int(self.version) == 5: w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1) if self.version == 5.2: w[x] = w[x].reshape(args.n_head, -1, 1) + elif self.version == 6.0: + w[x] = w[x].float().reshape(args.n_head, -1, 1) elif ".time_first" in x: # need fp32 for this if self.version == 4: w[x] = w[x].float() - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: if REAL_TIME_FIRST: w[x] = w[x].float().reshape(-1, 1, 1) else: w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1) - if self.version == 5.2: + if self.version in [5.2, 6.0]: w[x] = w[x].reshape(args.n_head, -1, 1) elif ".ln_x" in x: # need fp32 for group_norm w[x] = w[x].float() @@ -695,22 +703,85 @@ class RWKV(MyModule): assert u.is_contiguous() assert state.is_contiguous() - y = torch.empty( - (B, T, C), - device=w.device, - dtype=r.dtype, - memory_format=torch.contiguous_format, - ) - if r.dtype == torch.bfloat16: - rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) - elif r.dtype == torch.float16: - rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) - elif r.dtype == torch.float32: - rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) - return y, state + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float16: + rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) + elif r.dtype == torch.float32: + rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) + return y, state self.RWKV_5 = RWKV_5 + if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == "1": + HEAD_SIZE = args.n_att // args.n_head + if LoadPreCompileLibrary("rwkv6") is True: + rwkv6 = torch.ops.rwkv6 + else: + rwkv6 = load( + name="rwkv6", + sources=[ + f"{current_path}/cuda/rwkv6_op.cpp", + f"{current_path}/cuda/rwkv6.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3" if os.name != "nt" else "", + "--extra-device-vectorization", + f"-D_N_={HEAD_SIZE}", + f"-D_T_={4096}", + ], + ) + + class RWKV_6(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, H, state, r, k, v, w, u): + with torch.no_grad(): + assert HEAD_SIZE == C // H + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert state.dtype == torch.float32 + assert w.dtype == torch.float32 + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + eew = torch.exp(-torch.exp(w.float())).contiguous() + + y = torch.empty( + (B, T, C), + device=w.device, + dtype=r.dtype, + memory_format=torch.contiguous_format, + ) + if r.dtype == torch.bfloat16: + rwkv6.forward_bf16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float16: + rwkv6.forward_fp16( + B, T, C, H, state, r, k, v, eew, u, y + ) + elif r.dtype == torch.float32: + rwkv6.forward_fp32( + B, T, C, H, state, r, k, v, eew, u, y + ) + return y, state + + self.RWKV_6 = RWKV_6 + gc.collect() if "cuda" in args.strategy_string: torch.cuda.empty_cache() @@ -718,6 +789,9 @@ class RWKV(MyModule): def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) + def RUN_RWKV_6(self, B, T, C, H, state, r, k, v, w, u): + return self.RWKV_6.apply(B, T, C, H, state, r, k, v, w, u) + ######################################################################################################## @MyFunction @@ -750,12 +824,10 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx - ######################################################################################################## - @MyFunction def ffn_seq( self, @@ -787,7 +859,78 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) - vx = torch.square(torch.relu(matmul(kx, kw, kmx, krx, kmy, kry))) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx[-1, :] + + @MyFunction + def ffn_one_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 + out = r * matmul(vx, vw, vmx, vrx, vmy, vry) + return x + out, xx + + @MyFunction + def ffn_seq_v6( + self, + x, + sx, + ln_w, + ln_b, + k_maa, + r_maa, + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) + sx = sx - xx + kx = xx + sx * k_maa + rx = xx + sx * r_maa + + r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry)) + vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2 out = r * matmul(vx, vw, vmx, vrx, vmy, vry) return x + out, xx[-1, :] @@ -851,8 +994,6 @@ class RWKV(MyModule): out = matmul(r * wkv, ow, omx, orx, omy, ory) return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p - ######################################################################################################## - @MyFunction def att_seq( self, @@ -962,11 +1103,11 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H - r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) - k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) - v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) a = matmul(k, v) out = r @ (t_first * a + s) @@ -974,7 +1115,7 @@ class RWKV(MyModule): out = out.flatten() out = F.group_norm( - out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 ).squeeze(0) out = out.to(dtype=x.dtype) out = matmul(out, ow, omx, orx, omy, ory) @@ -1024,7 +1165,7 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] w = t_decay.reshape(-1, 1) @@ -1042,26 +1183,25 @@ class RWKV(MyModule): r = ( matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v - out = out.transpose(0, 1).contiguous().reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) out = matmul(out, ow, omx, orx, omy, ory) @@ -1118,11 +1258,11 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H - r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, S) - k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, S, 1) - v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, S) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) a = matmul(k, v) @@ -1131,7 +1271,7 @@ class RWKV(MyModule): out = out.flatten() out = F.group_norm( - out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 ).squeeze(0) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1188,7 +1328,7 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] w = t_decay.reshape(-1, 1) @@ -1206,18 +1346,17 @@ class RWKV(MyModule): r = ( matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) @@ -1225,8 +1364,8 @@ class RWKV(MyModule): out = ((r @ k) * w) @ v + (r @ s) * wb s = ws * s + (k * wk) @ v - out = out.transpose(0, 1).contiguous().reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.transpose(0, 1).contiguous().reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1284,28 +1423,27 @@ class RWKV(MyModule): gx = xx * g_mix + sx * (1 - g_mix) H = t_decay.shape[0] - S = x.shape[-1] // H + N = x.shape[-1] // H T = x.shape[0] r = ( matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) k = ( matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) - .view(T, H, S) - .transpose(0, 1) - .transpose(-2, -1) + .view(T, H, N) + .permute(1, 2, 0) ) v = ( matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) - .view(T, H, S) + .view(T, H, N) .transpose(0, 1) ) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) - out = torch.empty((T, H, S), dtype=r.dtype, device=r.device) + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) for t in range(T): rt = r[:, t : t + 1, :] kt = k[:, :, t : t + 1] @@ -1314,8 +1452,198 @@ class RWKV(MyModule): out[t] = (rt @ (t_first * at + s)).squeeze(1) s = at + t_decay * s - out = out.reshape(T, H * S) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx[-1, :], s + + ######################################################################################################## + + @MyFunction + def att_one_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + + sx = sx - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1) + xxx = torch.bmm(xxx, tm_w2).view(5, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + H = t_decay.shape[0] + N = x.shape[-1] // H + + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay + (torch.tanh(wx @ td_w1) @ td_w2).float().view(H, N, 1) + w = torch.exp(-torch.exp(w.float())) + + a = matmul(k, v) + out = r @ (t_first * a + s) + s = a + w * s + + out = out.flatten() + out = F.group_norm( + out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5 + ).squeeze(0) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xx, s + + @MyFunction + def att_seq_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + + r = ( + matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + k = ( + matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + .view(T, H, N) + .permute(1, 2, 0) + ) + v = ( + matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + .view(T, H, N) + .transpose(0, 1) + ) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view( + T, H, N, 1 + ) + w = torch.exp(-torch.exp(w.float())) + out = torch.empty((T, H, N), dtype=r.dtype, device=r.device) + for t in range(T): + rt = r[:, t : t + 1, :] + kt = k[:, :, t : t + 1] + vt = v[:, t : t + 1, :] + at = matmul(kt, vt) + out[t] = (rt @ (t_first * at + s)).squeeze(1) + s = at + w[t] * s + + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) out = out.to(dtype=x.dtype) * g out = matmul(out, ow, omx, orx, omy, ory) @@ -1376,8 +1704,8 @@ class RWKV(MyModule): out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory) return x + out, xx[-1, :], aa, bb, pp - # NOTE: decorate with @MyFunction causes JIT error - def cuda_att_seq_v5_2( + @MyFunction + def v5_2_before( self, x, sx, @@ -1425,35 +1753,303 @@ class RWKV(MyModule): rx = xx * r_mix + sx * (1 - r_mix) gx = xx * g_mix + sx * (1 - g_mix) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) + k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) + v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) + g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) + + return r, k, v, g, xx[-1, :], s.transpose(-1, -2).contiguous() + + @MyFunction + def v5_2_after( + self, t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ): H = t_decay.shape[0] N = x.shape[-1] // H T = x.shape[0] + s = s.transpose(-1, -2) + out = out.reshape(T, H * N) + out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5) + out = out.to(dtype=x.dtype) * g + out = matmul(out, ow, omx, orx, omy, ory) + + return x + out, xxx, s + + def cuda_att_seq_v5_2( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, xxx, ss = self.v5_2_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + k_mix, + v_mix, + r_mix, + g_mix, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) + + out, s = self.RUN_RWKV_5( + 1, T, self.args.n_att, H, ss, r, k, v, w=t_decay, u=t_first + ) + + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) + + @MyFunction + def v6_0_before( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx + xxx = xx + sx * x_maa + xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, tm_w2).view(5, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + wx = xx + sx * (w_maa + mw) + kx = xx + sx * (k_maa + mk) + vx = xx + sx * (v_maa + mv) + rx = xx + sx * (r_maa + mr) + gx = xx + sx * (g_maa + mg) + r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32) k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32) v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32) g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry)) - out, s = self.RUN_RWKV_5( - 1, - T, - self.args.n_att, - H, - s.transpose(-1, -2).contiguous(), - r, - k, - v, - w=t_decay, - u=t_first, + w = t_decay.view(1, H, N, 1) + ( + torch.tanh(wx @ td_w1) @ td_w2 + ).float().view(T, H, N, 1) + + return r, k, v, g, w, xx[-1, :], s.transpose(-1, -2).contiguous() + + def cuda_att_seq_v6_0( + self, + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ): + H = t_decay.shape[0] + N = x.shape[-1] // H + T = x.shape[0] + + r, k, v, g, w, xxx, ss = self.v6_0_before( + x, + sx, + s, + ln_w, + ln_b, + lx_w, + lx_b, + x_maa, + w_maa, + k_maa, + v_maa, + r_maa, + g_maa, + tm_w1, + tm_w2, + td_w1, + td_w2, + t_decay, + t_first, + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, ) - s = s.transpose(-1, -2) - out = out.reshape(T, H * N) - out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) - out = out.to(dtype=x.dtype) * g - out = matmul(out, ow, omx, orx, omy, ory) - - return x + out, xx[-1, :], s + out, s = self.RUN_RWKV_6( + 1, T, self.args.n_att, H, ss, r, k, v, w=w, u=t_first + ) + return self.v5_2_after( + t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory + ) ######################################################################################################## @@ -1498,7 +2094,7 @@ class RWKV(MyModule): state[i * 5 + 4] = torch.zeros( args.n_embd, dtype=atype, requires_grad=False, device=dev ).contiguous() - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: state = [None] * args.n_layer * 3 for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx dd = self.strategy[i] @@ -1549,7 +2145,13 @@ class RWKV(MyModule): ATT = self.att_seq_v5_2 if cuda_applicable: ATT = self.cuda_att_seq_v5_2 + elif self.version == 6.0: + ATT = self.att_seq_v6_0 + if cuda_applicable: + ATT = self.cuda_att_seq_v6_0 FFN = self.ffn_seq + if self.version >= 6.0: + FFN = self.ffn_seq_v6 else: ATT = self.att_one if self.version == 5: @@ -1558,7 +2160,11 @@ class RWKV(MyModule): ATT = self.att_one_v5_1 elif self.version == 5.2: ATT = self.att_one_v5_1 # same as v5.1 + elif self.version == 6.0: + ATT = self.att_one_v6_0 FFN = self.ffn_one + if self.version >= 6.0: + FFN = self.ffn_one_v6 x = x.to(dtype=atype, device=dev) @@ -1587,7 +2193,7 @@ class RWKV(MyModule): orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x - if self.version == 5.1 or self.version == 5.2: + if self.version in [5.1, 5.2, 6.0]: gw = w[f"{att}gate.weight"] if dd.stream: gw = gw.to(device=dev, non_blocking=True) @@ -1671,7 +2277,7 @@ class RWKV(MyModule): omy, ory, ) - elif self.version == 5.1 or self.version == 5.2: + elif self.version in [5.1, 5.2]: x, state[i * 3 + 0], state[i * 3 + 1] = ATT( x, state[i * 3 + 0], @@ -1712,8 +2318,57 @@ class RWKV(MyModule): omy, ory, ) + elif self.version == 6.0: + x, state[i * 3 + 0], state[i * 3 + 1] = ATT( + x, + state[i * 3 + 0], + state[i * 3 + 1], + w[f"{bbb}ln1.weight"], + w[f"{bbb}ln1.bias"], + w[f"{att}ln_x.weight"], + w[f"{att}ln_x.bias"], + w[f"{att}time_maa_x"], + w[f"{att}time_maa_w"], + w[f"{att}time_maa_k"], + w[f"{att}time_maa_v"], + w[f"{att}time_maa_r"], + w[f"{att}time_maa_g"], + w[f"{att}time_maa_w1"], + w[f"{att}time_maa_w2"], + w[f"{att}time_decay_w1"], + w[f"{att}time_decay_w2"], + w[f"{att}time_decay"], + w[f"{att}time_first"], + kw, + vw, + rw, + gw, + ow, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + gmx, + grx, + gmy, + gry, + omx, + orx, + omy, + ory, + ) if dd.stream: del kw, vw, rw, ow + if self.version in [5.1, 5.2, 6.0]: + del gw kw = w[f"{ffn}key.weight"] vw = w[f"{ffn}value.weight"] @@ -1736,31 +2391,56 @@ class RWKV(MyModule): rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x if self.version == 4: offset = i * 5 + 4 - elif int(self.version) == 5: + elif int(self.version) in [5, 6]: offset = i * 3 + 2 - x, state[offset] = FFN( - x, - state[offset], - w[f"{bbb}ln2.weight"], - w[f"{bbb}ln2.bias"], - w[f"{ffn}time_mix_k"], - w[f"{ffn}time_mix_r"], - kw, - vw, - rw, - kmx, - krx, - kmy, - kry, - vmx, - vrx, - vmy, - vry, - rmx, - rrx, - rmy, - rry, - ) + if self.version < 6.0: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_mix_k"], + w[f"{ffn}time_mix_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) + else: + x, state[offset] = FFN( + x, + state[offset], + w[f"{bbb}ln2.weight"], + w[f"{bbb}ln2.bias"], + w[f"{ffn}time_maa_k"], + w[f"{ffn}time_maa_r"], + kw, + vw, + rw, + kmx, + krx, + kmy, + kry, + vmx, + vrx, + vmy, + vry, + rmx, + rrx, + rmy, + rry, + ) if dd.stream: del kw, vw, rw diff --git a/backend-python/rwkv_pip/rwkv5.pyd b/backend-python/rwkv_pip/rwkv5.pyd index 24d7c1ff78bfbf99a68f3737df1ed5c68a21ecb3..5b59dde9ff671096526b87216f197825bded1e40 100644 GIT binary patch delta 48 ycmZoTBiV39a)SgT^SZ)>W@*NDX+{uc0%GRv(u^#UnIMUF^DLI_=2@&Ke**wxO%BWe delta 48 ycmZoTBiV39a)SgT^K5?CW@*NDX+{uc0%GRv(u^#UnIMUF^DLI_=2@&Ke**wV%?+#o diff --git a/backend-python/rwkv_pip/rwkv6.pyd b/backend-python/rwkv_pip/rwkv6.pyd new file mode 100644 index 0000000000000000000000000000000000000000..480a03ec940b8088a2e72aad8cb834400b679c3a GIT binary patch literal 419328 zcmeGF34GmC`Uj5Rv(>{#_ngmjp8Y)M zIm_qVLyw;o=^2Scit&F(MY$P>L@8MvUR`jvsBYEM5e zT6gZm)6N}pc698R2@@vPM^8C5dTzsn=;;%pRfiuHJ$vG~Q-}5KyI!TDUh&cnTikW; zeuumNKlFav;SZvG{#*MVb`L%yZy#{jGxB-hVXTk5ojmMbe5PM_pYMCvA5lJLjV+Hr zyZcbS>)w42A9v4whtFMpfZ&{Uecam#-Un{+#rw z?z68VgQC5A^s9_4fTkYnA=Qb?@x5HWPJtTicRX50s{GGLD{81*{Ob`(1JpKhKXS{q zW<^_ZQDvl(B zc-neKQhi*g+|Ob4r#9B3x@5k}N!e8n7dje=q=%h5ZcP1{NMwWQKn|RdEAjaZJ_G-f zaHnAwEV6tfJV=+oaqxKnJ_G-fkx1*XItwT=Le|JB02q}6mmD@>(&+JH>Q9N!!Sk>1 z{C<4S%L{kFnEHKB?~lik*{}u8mv*hhO7ZpNuyap6YvS05=#n}WSx?D<+t;@Ef6rg1 zX;qEWQdOCBW>i%=W%DXVzS=Y5G*?$TrfBb?$kM9H_y9=jwDkSs-|=8Zb!DtM0?=r6 zOW$#S=^3eRDOuPn5*hwFAW|olM9d+F7ZJLCUeAcRb^@9-EvZ?&%==f3w*8$MqbkdZ zs}$rNr)5-Slq6H;o6csR54VWkyit({yo7{I3FomV^ z!~=awg%VAhL3t-eQlZ?;c%RJ`(%ctV&K72@KHu}M5qzaAUJO#`HX z zU?508K^c=p=dXuWqOAbeGxy>vC9=a$VM4^Bjg}L&p{#E(IUtd@Xe#}*9 zT{bUc-RccP)LA&HVc%$0Cp_j@XX>g9+B?nHRYu?%PD^!Vomq4S{2brXPV)~=^H&QQ z)W5ll9Rb~QW+k;~embxyGRtY6Ss6z&^XWK-h^op)*2GxzrqoQUoC080((%kft(>lv z{aJaxR?gB&CKu-ST6v>Zma=k+R?gN+a%@gzC5+if1*J%83Lm3@UOTf=NTHBqK}etg z9nP#2)F{A=Gb?9no0yoV3N@=yQ`J1J@+*8oJBA-KD_7vd3{3Zg?5-xxy9!}r>XHne zj6T_S+n=HDBTnyH-z^6O^?lLUY<-_0HCg&TNGn77PHJUH-$S)Br0f4P+7{yCw{}(}yze0RN?caR9FI0N-5-jDaiw2}L6)y!~g|rK&Q~ zQE~2b_G^ro6Xwu*Ta0s9N2k|JAl6aQ>uDJ&Q6~>LDECAtT~zj5q-k~hB7}}S^}Amp zNdo?W7&x1W9N1|txs^nlG5kL9Y|qH>)|E5RcTDPW6iul5naol!dbIDu1r7Eft%3^%GGK)O@<3MmKVyZoi*L*$dGKYW9+LlqW{R12>rS>o%4upXrTjIlefjH%?HX~@@j3j9 z@xf_Y1O>I30UMwKjvX`O-m-l~e1`}yxXZL6V`gA}Vu40q>Jvy3_ZLGXypz7h5P!U$ z?keh7Fqz~W36Wv-@YaP$HzMEM&x)Ob6=m}PF&*ApRW`4>awi6aTJn$&5Y@@^?5 zN4wn*`gO_hA*8OZ#LNhvC=kfw0!%<;=EesNwwS<9#{9TxSjl`lB-S(bnW&@;f#Ky})?Te+o_;kk1-z02gZ?=`+UYS9j zYTC>p$czJe$mbt72I1TP6@t&zmvmgm>ie+urGHZO-z>IHmNEwsogF45p+5-I%|+#A zXQhJq`XU)vWZ*0{n?1q!0ZZi_QF7_K_8;g8S#FNuNGBe!%&Tr%1pAkpJ>`+-2mfdB zqbC;YCIE9f&b`bXL<>71PGMeVuF)dw%h_d<7L$U+X;v$q#cBW`mLSdHHwg77ekOgp zrjc5iYaUa(rlFA*BeE1lvewv=e&2GW=2Zxvm zd4;ITF>|D?Sgci?(Ol2f5$S?06&GS9nIqRJn_H&B-O?Kly-_zB9Axb<45z|$p}Mlh zoc$w_WI8HVJ?^DfF>~+gjt-2%l1Y|R1u0XxjX%;SA`N|!hK@P7IAWF|&6DV${i={Y z-rSOe;FsT&l}ElQTfb1m_5Cl8QLc88_+x-$ejBRH9s_)QAx$GbA{0oDq1`=W*8fjO z2j&p6JVV_SZ3dP_OS_STtF*L7W@i2Dk{2tY(G>I@Xu=h_KoX=h^sB&uOO{1T*Go_<|3P{(`Vl$jD!cOOmBe^cz$J z)&j;3 zpU4w8HGK)j#f@P;{BC>d8yP;|>Q7vfZnc>O--o)C zVzwXo`F%z+A*q|MH0EH=u`v`7X*qPc)Oobf~?KSD6YK_5vMco=yV%IAFbP*|MtX?syZzq?}9!CPvBY z<&*TWs4)YR9vi`yCXo#j!W~_@!Z%Fu9pK}G8r|W`jzIT^0%Tyk((&w!^<46t1q7bc z`TJ_`cWXwe&=uXG8B;l*O5AP6ui|)fCgl5dDUP11&;<;+Q6lV~;@uD4DQUL|$ zv(oYn36KBWuH%~n;|2CdnX5)El4k5;VvUH)A6O$2!-1<~diq47{^C!lKZGblc81-D z9qIYpQA1~dcT`o@X=vi=5<3XDsMAaYWwX1+wp(X$RvPBIV_R%n*n@wR!D?X2ipi%KQT zzea3HP66NNdSG~iGZ$ad#4LQ%gB)#I3TK>!KBcsIZ#&Cw$OG9O)a{EX7}O)S(0 z>w9jJG~>9mEp)r0hkRBVX3Xq%wsGT_@SMb6I`U}RE2<;D%Cfj0^c%)a73Ip~n9MS$ z#O2*FKWt?+1Xn;--bUt=rvAtr)riJJvl2!y(RtkPSPsYsm+rIRr>>d~7;hOLo1wCm zg)y<#Ou`W*86IoR16LMVWtUC6L8gt-MYctcYW6Fq?v}F1n<*1LQTJ z5osUhGHbt~*X?dbyTPqk!);ofESvT`sr6<1m?DeG(S@?h?#2`D z*T*7dSKY+#U2-(?ww@8J|5kc3HurMI9t4q9Q$69hFLzW70vK0_E6e7g0zGRda0K4t zU?PssUCi?Hdawf7=cvkE*b5S-qTY4k-gKR&OfdkWEB8Z!Cj2&>-8Fyr-AcmO*&S~Z zUI0#jW-nme3!h0d+7hlJVXSypS5}!@P$kKDeUR|4rLult2`U@hh04CX#j0$utFkp{ z2aahsn!VcGeXgi1Mb4|qdCGiO0a zS(xccKac#OTRn0Xz(PIpcCY~<6A&PQ&%Mk6b>d#=6Qe5kqGgk&302m;`Ud1|FIidb z^-08~t~jg!X$}T3cA#Tkjaq710-G$*)P%a9hNxexpQ~@`R`;s|81>_Ize~U$*sB8L zRrqwwHWNiZ=y^Pk2K`iLmgvg;CXLetc01Md4F6wK|{=Q-wA%k0f)}d(W|)CYUsoug^68L z$i-g^dsbQfa~tN&7cAnQM2v0%jp=^PIefzI?A3hvchn4J4MvXCc_n7=H4N@ml_Oz= zqzBemVH--ga)=@x%*b@ekeNzjHG1e2a6@AO19yUM!l^!pUS;iG#in&|sajd^%9W=E z@28Bp~Ujr z@=1u_lcieuef0Y7@Y_2dzq=OXH#?q*?K0+*V)7%wFk>2L`2inK5(Gh)xWb zk}!w<{W~`qSdUgK6y!8^@S^}fM+f;?0IX2F%vB1&4-UdF=hB|ScgFu$J+t{=1IfD) zq`T(-_dxCG)48U5{G@aEx#}0_(_KI2d2TM9Pvv>*@<`;|K6LcB8H6s?zA+Ol9K+VS zz3(oHwD-VSHE+mp#}fnz8oH@5 z)BK@}oJFgP35w9cKf6z>oDHIrs+rvJo5#lZ`Cy|(6aEX(EjZSKCrDe}3^`r!QnSTG zHse{u8AU6x2rXoxZqZ8?S@b*FM9G$2{s>5d>8MR7pzxy!0DX#0lI?g`fMHzaOHw5_ zl1r)3?|=Zc2u+);ukf=}1W?3ZB!=xzS*p;mh-fzlp@ce5KqB=7#M%c0L@;)mG5swy zdrt5LY#;5MX?0d^CG1&5bz)^pRdpEIqwHB)x_&*|bD=`vSDM$D=k?5l=nD%3rS2@u z@37XN(W*4qKUl_+II8^T(1KTPTGgMFqN2)TWI_1fzNeAxe@&}u3F)rd$$0k$VmypA z1;R`!F`s=0U!X!v1fD>7kwKRn4b}2{k^vR`>x4XN_S_OcP0vW{^7GD8f$b3P%y1w% z;g=@+MAFDO-MC#3`o%yzwp@_B9(H90KafD@m4HhOc2Uy{N3m9Z+3=169;XN$H)r$kE54z~ z#St@tM^CK=vPA?G`m6&(u;DIfm?P-bLqyO)qTjeFC-R;uG8lpXXG}A8x2>E5;Y6^a z%{yEV>4rW3^jgTCoBbOCgQq$ZKi`T@dtP=!0ehagIeh~mdaEy0x=Deh% zHrIQ_^XMvAjpMOEzl;+{Gn~@7MeRLyZuJJZ5PmdbqPSd~Fk{-}v(3J0bJSR?t|(y~ z^MpJ@cae}#1bp_@oqYJFE50d;4_zT*dWP|h!!zBZ#+r2J_~d*xOchsrY2thQln{R# z*=J-nYOaXW=Y+$2{2AHg$(R-Lsq&5u<69!n;29J-FMoKeNKbgA%=L=#D#b_#HtPp5 z)|hkU8PWv4YeWFNYXlyHFVmn?=JrSNqqJGfy>bRjbhiH2+0F+0ag0W-thh@Ul|B8a z`Ibi>k(AU%>W^SoLX_sb;q%*2?KZ+F1M@Qg*C~jE;!?9)E&GHF%GfSMW*Y~)5=G)> z^Fx6_;^u25;=0-M11PMZ5Npo$g*Xan5?Sv=h#h?)&I|}q8xn$yHkystKY`6*vCfpB zDAI`-Iwa%S`VNNV1>*{fcNX@dt$6=pik&_E3GDBNJWGu_&@J4gj>2AzC#A#NJS4^TVl*qN(zd+W$PG=VBV>oEW)XPy7iGMP1SDlQR z!z~WBL(zAhkoWK0P7NkGzps&=?Q>=0IfsNBu{}dk;*=QJ^9OFAzzi_AE2%y9`_9@! z`g^0rXf8&y4JU3HE| z^AiI=$nl1r<~klMrRQxL&ERtY2kxS*1B;5NAFH8as9n75A18ph8@u*0H=D~@U%(_d zA(P-TO9(oOo9SF%fW|?W#iO_95m76yoey9h;Qq_9@8Ya7ow~iBDy$WgI8h_D4u!{| zaG`MiH#ABc3rX|$YcLOg#iiFNKAD%Q#{}_h8^m{i_wmV4nFrtB{uSWwx827l5iJkC zD}wkA?mj*wY&x# z_@kZKQ~f>jMUH-+WK}->WaCeF1AlEU{-X=we=v*x?&zm97yrL;peZl^+4wv4ABQfU zVvvFx@Q*Geic)LLEFmKO05zp%mULnaa&oLm(hOlgW$w}0G{7QW%#(r~Bj*6JyXX8%zkqGvy@Kd*nM)38MF38X2dZ_;w>ZgWQN*AzKty$4-{X8z7 zfjrc_sx{IP(&q3}Z90O9;v&MJTGEoG8g#M^n>MfL(P@|QBss`u*d&G?&KyJWuM2+T zIs6U?^LzWa{QO>Vi49`8{BGF^zwmM25%cofh5t?YENt(n1gU)e;d#C%%ct;+sP+E@bPSbW{e`2q}v7q`;{{z}AlXCUfABe}XBUK%p7^ zI=GRd?$UfvxwCtz&3qkmu~C&X^PP(op!Tgu$uZ;bKsI6}mWoc3=6UX-HR;Nl&A@*l z{h$6Bpx+vx|1loq&_CR+=-8FLUxog~LHcb0`s4D^{|mRTe_i=M4am>+K3D%A<3W!6 zqwQksuPJ|!emT}CUHaql(f`t(UilU2|MW>f{$zmu$9Rw<|JB$D`E~Vwagcs0K!035 z`aAZ`p`RO1{Gi)qCzOFLZB~B_CAcA`#;jz~Mi4}fHhxW+D)P1~zo$(`;Uxyun1u?@ z7j?5#mNGMO&mUCgxKrkAezW#xf5K{mj>psNxnb}n5fa;XoHiq@V1>)>&n8(yRfOvU zvU6!y%oeTgsp812OV^M9COfq9EX^jfxC;IBLhjxm4Zz!Gm%PK+$&|ouQPgo20Dly; zb*5!5#cKTF4)@HDY+cj_JgUuUsmjPqwU%|YWCb>{+oOQDx^851G~&8WOP|+FJ?l}0 z12V4CV`g*%lm{Gu6THP4MJ95|CL+ad7|5kgS(Fksl6Vb<&7hi6bMsMHPBM0uN_%|&G~f4zbKbTx6!3<9n=NZWZ@BG4TL(o1)ab}UTRv}?{6Ri% zIKqr)eYQ6Y%=QN7=X=B1M|R;2)c`2q4Ni}O-mo!Rb>a=j@@}+%Mv5J8Bfcy9A7>T% zi|qfi;=uz?`W`UxZF`pY0E$>s?c=y*~`?YJ6C4t;UBpJC6uQ z&#;<=T3kkw4tDFvjiM?|&#bOoSlP-e|GaW4>?f_lA|IfN$J&nPoKS z8?$*%4*n8#&cxy~L|{hHQdJaVWi7i)R`PK8JFat_39X?455~!BMP?U10)kNf@CXl& z1#sJaDIj-0_n49~W`wv8 z_(b0Q@u{n1_UO*h8V|QRmY^}4T^|asA+udgLgx_JPC~mk=^%D9{Q=)4eq09HdBcQ& zzB1XemZBZ;hf~lYf3_Nsm)Z3hlQz_HY%Go-@Lmv&kHSn7UXfb0cxfZi}!PnLE!yNxj~!q*|QdnR#+TX zw2zi1c|0y6g_Q^EneZ%vRf|f+5bf(({FMxJLg`9+H%{#Ic*9z@&ObKWX107)-%6`L zl*~D{$!YtpC36pf&6OA8O*ir?X-$g9W*O zy+`}bwsPrVR)l?C_e5xQny;&z0=(wS6uFg$!cb`FBeuQW(EK?}Zv zybE=XS4!Y*4Hvr97DmCFGnAgZ$fc69R8nskw31i*g~}*2-wO^F=Z_uW$%!bBLOI#n zmSuR0p`xZavh^Z-o(M1BW$V|PU+yj2M=`v9(py6f!IBCV&jBN&-1(Ng2a3^-k3gvN z8H6)AQH#lnd&JQWJz0vV@)S`8F(nGnT2ToHZqTyp(=oG0e|rM2Rvg|{kkiDY;m5ee zeELdTeWgWTb8*(s-A=n5N*HeFa3Y|>%RY8_@`|Bt$gwShx!ZU6S0x@O-ytPbQ%mE=3*{!HU$nvhTx@V z!Byei_*n+k0m^p9ic}O%jLb`VpS%w`K)Ztg#Mu4NH+U4F1UvohZEjEuED`@7(Zg%k z|5JJ|{*US1-7R{s07Ds~Vuef&Sdc6h0#pO5cF$t=e~oc2N<;Rod^ z6Mi4^ch#qkTqA=hHv-ClWd`F_$1Ka(PSgp$ua^W^+-Mw!EYH$wgYS+YzWPPKn7<=} z_^unu#y9R>^F; z2`O8F_WOj}@598M?cclA&YqHz6cRxow()BLd|QR_ZRO%?3E=CQCtMc7R~E)s=HeSc zd}o5!az&QS&;9n_dfMJAs(e4)qV4+^Xm8cqv2K|Dx-R`^-S)`OL^;{?$J>AFXnU`y z^e=YXV+NA9{n7SciMIEO+J1!FK3<^xHI)bPKb_nx@QSYdH>>;zoO$W1!el}aa~60- zZU3y>KAqJbU0=pugkU_ui%x@+QrJQ`bI}*Tzl7wYYXI1 z48~p#%{^PGoVbU98_mFnS%)2@LDN3IzPnZH>BEKZ<3X2Y^-w(-MQ`m|p!OtYTHw$^_h97*ye3dVJsX3cYq z%62ZQqf^YROa)8{H)u2e#z>=mc-KwuVAG+tsSN|VbVz|} z%>6}rmtBo{e>YwA$$yufa9>cEok7OGXtMBKZI`=u*>#TYqeJ|=>@1{ryQCELcF83G z3*BY+Mz!44=HD*)+krW^OU_4C!Mobz=9Y%BUF-EAkmj-skJWKCcDj7RtgU!JMI(sl zQ$#zbK6F3eFV@fEgZ;bgI?>Nv01Mq^_c)jXZ~i5i$45m!!Amf2KvhBgbgCG=%WhW? z!B#}4puxNBHU^g-Ft|%u*L;5ZNz&%Y5$<~E?nKJ)cwitB8M&ZXL>N_x zRWnYl*Yxmz5Lk8;O5FhM-#X}K7rbvg`+awe@G^ShZG=#2kz_sc;?-73&`BaP`B-KqhobtaNF&Da)bzXP(?NDfpW|G}0F zomrXztz(M!bRD)VFz8M=$KD$xV|R$|?%IHeQc0EZ7IEUbRP{A+v@3=ZoINqI7$VpH#j_)-zxr-vC4@ z8LrxgM0Xt(4j|V4m+a8_GxEMhAvQbBp%~ZpKIMz|j7%ErWaXHsZs# zxP3Ps@jh@Ffly7HcTpe!3e}M=>iXjLKeg9*Od*(1+(CiI$#xl*#KU6sTY3FBW-j9K zmM~-TdcM!-$}^y3@11=+4v81VMc-vjGZ;^y(6~7qp;g&Xo!`2v8$62JpKU+Mg2e7XXtg7Jp!h{h_fVL){Y zl`jJ)SWTO8SK9$Fz4CiuAD%<(@rL8rN);tx4A!Kx90dwtMR{6`nqBoFn6`$CS}d*^ zSGFh-Wd}aWaI6xFHiPm~dNNZl3;-9E`}&z41J(%QX3E`@cxZo$R=# zSSzpj2?}sW4KF-~*Vnj6R76(6+>-th^#bwJ`@Nyeqk>8ze1ZKM?f-D`B=OaXv-4XW z69IROMA-(xvKH%9R8mV-L~d5Xs#sOJj-|X7%w1cvk9|SYpVkKMF?Rov$sz(#apsaS!0Hv zQXW%TxJ8Ex2+nIi@W$CRu2aUoD{?E@dL@cIc`00zw-A?`?>Lk-C8+&S-n6>MMF;Z% z5=!fxusDf7_=sqMKP>@M(vSbZ8YDF7fUp4e79TdF!o3N!&EJ5y)J0rMGNtYiTv#8R zUwOV2Z7*lDV2v%Ko24#3p5K#X7~?>(!w?3JB!s2S-*<9dN%9aMjQQFt4|_(&1E|;a z9tr18`;}G;_v8hmmWBUZC5!gbX7Z)>BwQ_)N?*U*PIlmC$^Poh-aTC}`scnG-Co)j zuuyq|g2|GW4U`Ep^C9~r>OBFNmB~cFCHX}WK7#w*?q^vXJ7YG)3_=SWv-N-wn6Zb4 z`iu6=vkx&>hg(yJ@LthjotXq3a-}U^elBu+xUdRaN*=UNlHLm9BBVvxD0bW`+0} za}xpr^)2fpfupH1DCO-R-M8_uD5BP%ZD< z_a^C9=p$<1)9-cd`-o)X>>f~ZLHi!Z_5u6W=dtf^goZuN$`Bf(Ptw>dAmZn<_7Q5C zTx+lU4mGBHiM3@I1e-sP!k62-cwF1xig;1b_H=*(_P2dg&zwOQqG$ZE{abqAY)ETH zTx{)+8{6%HH1zKov&n^fK8;sGN3tIk@85hW7H5 zIf^LUGs@`bwdG}(PNRWsOX0MRZR$)+sb%;W_7}UgYC$|>ffmFwF#feV&Gz)D)1pVB zzul2aV@&r#fnaYmpDnjfgh4oMY5NJv!_3NGLh(|ih-Ci7*}^AEly>^E-ZqZDNcf6ynm5`YElp}e_Tl1>OkLm*3NT&~*_Su2! z(b34xq)pu;fAms#APMh{IZ^vjN<+kVEP6cE!-eeoINRm!C$zDlV;b;#SNs;mJ_b2Q zGvs!ETW-42Hhv^=?X*3A!@}Brn%TOm?R@J?wbm*hZTXAzt?%dr6@C;qx*Nq; z6OO%UJSMIbcJ9rbI}a}eLk`2FV&|;j9i7Do=2fi-fqdo z7*qHz2_vwVW_6`gi6U)m6gL-c{Bv} zSSd$VV2-;MGA@iFx|*y03X7w+M)iG~tqDqDDhRbOJF z(utefKgY`Q3qbAbtHAyh#L$G;y#^ZHyJ>aw5}sg@u%ziq()@G;U)2HeVj(jyqPK^U^7oJBwMaC*wDA>}a_ z^(6*O%%XP}+?+iZUA&Et9j(0Mz3b>AqnIPvUYP|2MkIga(j0h*(rC`gl`J-SD|L1V z%l1!C*N^i~yP&2md}r%tl4v>&gXkWj3$V-uisdg9@n@UX<34FuJEooCcbLeg&&PVc z%)B^;S%=}G+@wCvQf!>TC=1n==}3c|m`<&{SG_k)e%wZ6&B6QiMbr*duOn>I4?EFz zSq#SBbF1w%qD!jQ(xsemi8f-5nnC}$HPXYjmHe>h@nMzOrj<{c0k{(5t9VHWLXJt4-Sq>*C9%%MyKpj(6zjzJTyqxG;oCnK6}=VEs{^)2O z+oyQ+=`;SER+IbMc7b@0+|>nSU^jKqcR4JQ4JbW~o-cuhpU{^d5I=?5?FUdFl+?S{37HLe{T?>iT}Jz7dhg@7~L~ zO@-e^tlyRuDRa{AXaLC?GG_K->q{vt!hH?7 zuZc_&?$Z|R-?xMNzuWIC@e3RjIckf0d}6t0XLsy=Z61EJ=~i}+?z5PwxpdD7(|rMR z7@EM+731Ph=k`5$7Bu3rUTY55A|edEH;SU}@kve8yn2(>V3Yuw1z7YHJ3*!SM&^n)b6QY5+2hD#zVBKB;%$11Z@yP}V>%d&TWb!; z7OxBce2#d1`5w_?EZfOGZSsO7y$nOP*f8764iB=qFGd#E%eD!#dG7~VR&aY$SN*A9 znBR_>mi-)lcMbC^5i^fkZ)bbASNtoW)=8Qot6H~RTYgVg1|@W%|GWHZ|3@+GKd=6t z4CVBX68zSzbeTFkcguC7{Vxr=QM7=49%&B9drCmw%TEcp(cj)*OZ!Y90pFiSYca<@ z*G~=kQ@Q}Z9ajeU-5|j4Up3`b`^;ZkexD-&pWjEcsQr5iiedWWQ{rU zb{4On#=4eGrx+;;#p}VO?Do9XfAA&!adb%1L*L7iw7^iG4x|{ngM$Jo27Za1NyrtL zgtfnKzc(vE%`Wy*g6=3Enm<5KO!(x|7R-AXc472{e_q`!`{&hBgP(NThSxkcaalEl z4=rW2qitalvk)`;O9*QtE9QWvXu8!ANJP!Yv+*o@aS>s&0KP`zONH@0dWDa#qvGbR zFpT`Xc3tyoH_zBAMHpqolDMzkdd9vT92{yfroVo^B z)MM@Jd7RMCvF9K2p)2GcLH+g;rWkzf!KzL)J=xL-yw%>hz8~hHY**gjunj-Am5CuQ)P!2dAa)ID9c%J=-JFG@=GyI~w|%eer7p z<7bo{fFc0)BfxFX;AP|aCBs=2>>1W5+=ch(U6PmHzvn2uRrq2a0lKCUEAW-{mf>DE z(0dg{(EBq1hFN+~v#JYvWqIkng_a`UGNRi@=(XZ2>D_4QJxt^+eIFpeb5DzWH?yh> zdbi)7PoGDlrO>M;y1xKj(}*O#lHL}U-Y_C>={uAFhgo`?v#JYvJLaYL`MpZ-V|+2A zh2C;}CB4V+>j2Q_ODKZgVgmf*DUt7SR&_z|dDuADUw)64BHv7+`x5AyMoht1(i>&z zO(61?zGo6(f~9u?tGb|fcwTxtprz2OB)Ut4UVnTgz0VNUA>TSg-qN=x0oLFLbKsxf z+|xC^_o4G#`C8CY=uN{HQzrBp@s;$BvGm3hc}w511h~r5JBn3Z$Tu-Bz4z``dT-*3 zxgF@5Mr81n^yWP(^36jL^d2X`088&qR&_zI1Nz7#A6i=Zh;B!r*NCsAce157p2#5| z0Ull;`aFtNUC_HaFTD!16ncG#?s=eV8WF))(lfY$9`^gR4i|c}@s;$ZS$bCyc}w4B0=$Z!)PcWT$f_>rJpvo&>hof>6ngbU_dd`y zji|*}(%ak8JBG+x`W{7qlP$gdSk(o+ee%-#@MuZ`BJK*;;$GGIU&XdLA0Y`T74$AvW=%>Q;M;4~PPDwZELwrYW`uK%U$m3%ss!xw6mWN#;j#G<6C6N4ky7 zhq4c0TL0qN6gkoafBH?m)#48N`3Y>nt_qp-oY9?IP)QlnTOk%JM6CyLjzX+F&*Hj6 zA?iGceH3D;LR_K{jUGf_g}7TGP7(;~ogUEg%`Gca6eyuUbWIQFP6awafi_eiQNAnG z1O?hvfd(j$7{vwJU4hnDpdT2V;Ni1_>iO)=7W7lUy1t^Hmh^Qz;$pA&F=F0=t%c`64wzv9NGj_Js z!WT_6+%+fX16ytHs{_{3c-dI-pY1ccC6IEalUU98wsOXl*kBN050 z+Pi+Dw&@YC5JJvIvy0;B%a{Cpq40T#5I?*GNEzUO75{;~UDi!;{snvmN#Hw2tH9Eo zjM9bh$r>pY^k1wFc7mr+d-5|ymQX}w=a=$BR^lfkby{6C8AGwRIDBF*d&zb=T;Z6A zPd{mT?w_NNWL|yP`$XYDQ~_)nHT{1>!{B@dR9ZGK%-*t| zpM8=Qp^pW}26zYG3TM+Jwdr~Xf*PBueM0zmus|IZ<4Td8c~E}Bl|}ET36`JBj3O;; zFUPbAu-GeUJ~|}F?)mjKif8{-X+eK<%#zx0aZ4=A?7l&u_f06$< zZr;BjoBxyKF#=~k{%0MWkN@oTSM`q#m31an=K426{_*{p-mMlhr$CN6na;l;^h$`AqabZsh94I9#>)Fc|KF#{E&IO|IM&oJxdv2FrUN5H2Q5@Km%!Wv^-OO3GlP|-B7Vh zJi9)KJ!QVHr7Rf37}B!+zd(OlENwJC9uUTNuRJSP`8g5tdQ`|C8yTl0j5V+Y;~tLm z@3am)5C`UKyePI?9p_|>kH|D6X@ZFfj2HLqD+o-jrk-MPAo?-41H3~H9MibEqr?40 zsg(Kct2n;EfxE_xe%Ai1OceYz+9;T>zspM$%s-vx>ndC2)6QYcZR4moEj)S}3cwo= zzb$~rY-6s`$8fpO%LICL8t&V?m&h=@>T3dv4z~CpPpAQh2#`VbYNz?@h3X9$Int~K z`mE|YWhzk3Ws4NwOh6Mk2CKuxFv;D07G7OSmY$KHCvuvfkk=Png3b0wR`NXu!Bx<}2s-2i2EhP~x$|bffC13ZoOyboMKK)RO{cZ6LEk^c$sanf*l~ec(aV2KH+B?5@>7jkV#7&=lEKDlYiP`m%8oFjy zBDzI7Do*)tkMM+Snml0lpy4^+>%bWf^f!8j-`~RVdFm>WxVhwzDW2bm!W50^jV_0) zhS^J6^sbZIz4SF3s>6K7Fv_Q7|LhjI3|j&Jq7gP~AGQi4LXP``kTWZ(-;RpI{{ucH zUJH^N@kZk;ms&@~NIsQ74xWTJh*oJuBlE#@H2--EUH<|9G?p=cgnB?OL97XWOWGm8 zyEVaGC=xK7bHb>=w-v-?F^@};Ioj2Hj$im7f*K;g?K851#7l7iFq;6?_gY#=h_AD& zPk1tiPma9XKB?mq+Zj}C{h?o?YNRb!5ccvh)bY$p=@oXl!#!2wfVn!&b{j>ppwXJC z{vw<|eat!!z$HZvy?;dq#u$n@X=r?pxzE4r;`#MIbl!f~|8)NRuRnJF{5zb$>WshZ z3-Cwj;-=x43pn;& znHE_Jll2rGr*ArtO0a{)N{PNMfwB#$C}s4ICtG$>E`52Ca|`qgY6YEekw%%c`F>9g zKpc+D?!XdwZ&T`C#Qm0ZaJ-3Q-`lJLc9pZ$kIE|O`M(2ccr-*0Wlns>=Fh;SubAPn zUg9A%ing=dvp`-oyKFR>9jr!bppk*{LZ_@S)ljenLBreOJ-1U-_ z>B&X{Cpf-6k(2|moC|cfdmB!H{GCfVutX^=QNjKmT7P=>K!_6LIhQ^nr>tSRBWcDo z9!pzh7_g>L_i)-Id8sp#r;{s!2i;V>efaj)D6!CUtVM{&j22O`yKr@6ke)}3A;?!eAcmv-yu z30OW64~dfRxY_GU@ZGe;aXVQSJ@`A+BdElv*#9}2&elp6l44pjV2J_=m40{+o07&yRn@?&1&pK8r3nV#C}HF_?^}FUy{c-=7Y2yn>VQsv|89 z=$L~GJWH&OKbo`+$Fc~$Ik3f%GXHqao0~6w$Ocm8H!!eH&$IFink``x`jhYb1B1b#eCc~ZYx&K1{}Jg zSP;XFjF;4T)lS4%PIEG5n_YGao70WuzU6E{V23oDoCO!Mu`_G@=*uU*_5AZIHU7kx zffv%n%-jz0&_=;x=A$uSU4#ach#U+q6OniaeGK2zpj#7|VYu}DtuiPgW=<8<5RXyP zAU(-KlLQ)A-{()lS)Dm)n5$Q-DjDgZ?@vCn^-<A*QiKv1nbM&?48n%s_RhVG4 z_t?1=_~G4I@3$y^gAf+Lzp!5Le~1r@f5mNu@z3lA{^?=-$65T3o#*18(h2_UyMaFz z#{UBlLq9K{zAmn|Z^XP=0!qUCh#z|54ORoy!9)3t`d=-^j)4E|4#$T5{!{1L zFMUg4{q{k}?aF?oVf=5yF@XP4um40k!T%ukle@w{=ZFyh7g+q~PIUR7eRE;{t2@IV zqj3<0sqJr;BlzX!pCNq);Xf(_Z}WFQ{z=GdXfF`*Y`mPhN>(nKOFlW|*f-{~b&0{!$z&aUK@MmV_SHNK286KuAu=hNuc>&HW9&DgJ3@?Oc zE7IlS*oRl6e(Dl37Krcp)}whR(rmUuB3wrlb5St+Zep!cOC_VHP=dq2%q)JbGk1ao zTe;HR#C){K{Ir96r1BIBi5WpFMA2Q5IC=mjY$OetC|<>;P%vunq-Z#MTqXM!2CTqo zLlBIc$IvGUi@|PWYULELX!oM8yu$K{(;64y$~v@nvBwp=W3iJV>9aNV61*Cqh}a)= zKC@L(LsYDCjVi*VfY%gdF|m``GROpBcFu;`+3X&GiOZQkVSD_Hktd2y%H}1Ht9)PT zrjI<%uX;XIMMYt*W*5l(#W`JMLwyF^n^#N`V@#5NM)J?T^tm)Q{qZ?YHvfm5K?Uht z5{@zj4_TH>vPCnZ;=|~h?&3LZ{V%Mq$q7g~3JHR<(2l3nT1wWv5aTlw4Rw`&S}yj3 z0dd}>eg{vcKE_uUU#ElzmIJeS_z8+Gvn!@|We~A!zGOA{ zeX`vmk;jC+q`Ce`s7x}NeDqC^T>ABve#-@^O|&==R|KEEsdjYnW8wBl#4m&i^GJ5GV#RAiTs7v26~)S_#0o;k0e(6y zx-P%Rc<2Cy>8uQ|2fAzD?!(o={ZxD8(7eMfL^qm@5hHphvmix=^Cp=(-L5GF9%CDEst_nXCV%U@|Xe!*(mit zn+nPmCeVG*81y?<;z4B?5Q=T=XO0JxKXi1hXi0rk@XtV zPnjn<-344-mIvwPc`Iex>0}Ye>cK2N0V-?0&PvxjvyuztW&(}~A)n9{^10YclI)_! z48B_*h?WjS{D5k-(M`9B3sBYWDwi&<;X~G=f|trAoUz)Kp>@3$Drzp{$=$IS1LT@q ztO0PSXgw5_qCW}!f;7v&{De-|45MZbpaK@VwUaa#hLO1QY2 zM`?q~iTl`R>?hBQ$UzdA#)2*?9sas$d>W7GoHSW0?R>x&>_}Km$7xLQmPY_>Lg2bP>8CGs576kvZg$$e_v&(4sxmDkx*TPv<0$a^9+*sx*fTI zfq_)nH!0(4;>2gfQr$X4mB40WrHbCDveqlrtzD{kBvcQvRGV_BGB+eu2{)j2sZzZX zGDE|XMtd5TL~=Ivt!dY^N41|)<%Pv!S3Dxsj*5?wfu*|?Iujz}T(Eg%nRw+WFYuz75OmA1|! zn^dyA#95@pBTM!>oRWvMm*)0dtxFb< zEZIve*(-9$mMYnllD&sWA??E;4{2i|vII6uWNwC@$d;~CO;pM;OWBU>F7Y0g`1Bm& z^6nBGE3L$bDDj;v@gX7N1U3gMahsB;NB2)q>u$8UseoqjP)#iqe2xzl4d$%Ef zP>nR=j4U^@siXjMhhrQ&L=b%_mzjp*-tOX^TifxAp>vC;r0ip#R!7_Dw+?1)R}`m zv>-{VO8L1wX@PN7QmI(Iy$jRoaiA>Md3sz1JqGr*1LL`SP@$yl1aDv!9l_pnper+L zZxC*dAQVJO<=Upv{W>}>Yb~30DK+Urogq+IGe74|mG@A9O#5`uz0yFcvNMvr&$~%G_m=WZ&V(RP6;7ZMm zKIF3g5|3m?o`b_<-?ALW$RT!7$l=pxg2PDz103Gug3rZGwaOvPLu>U@4&eaq`&g0? zWz0s(VcL9lEjg4m$7YR4@ZWDQ0tNj|C_fv_TF}pUrWPDTL&>fz%xdxBn5y1_*n0W<#f%zAl4_?WRbS4?`_f$guU}I%T>v5hxG@`$aWKTq| z#aG=yVy848uGU#h*=2X|5e!_z`lmLu?iwZCC>;XC{pJ8;|1@{L&Sbuv%^17Pe$#-% zb!%OGM8rx6Ri>sJNF0<22mb&xsRv;yuD~xH5mA5AKX?~Q-ywbJemFT>GKwH@rxf8< zli&+jDuU5}2{5+9fq4bPM_VP@ve77AUH{S)ww!1hPEm_!!db)~Rgu({)te!7cT zm{!Q(0MfHb`oMA2SHtusRfaBm=+Q+oEfME@uw$lwg_tYWL-ST>red}bNxy8E5M&&XhR&0 znt2;Q3-+3yWJiVh3tHuOr|k7T>;$kUxq)$E9s5VD4(-L%64vXFcUrVZ!?xPxmT2yh zWX+9i^%!gHZgirfqO#565Gy0FfFZhmGXrnpaO+BOyDlF`YvEVcE>bqC0ITTfmVLre zn4ON`?(c?XpGUF3_7Ik*W@4xEv;v>`R`UjtX8~(Il!Fv)WEt{oM4h@`nG?#VNk zSeNacBMpaAxajA)UDosmMV*b*t!SULPp4qEzNfx9$EE_i`}EfvSh;M{jH zJ~}G?{SsWkwK?Q-2S%>weU14gr9r24HAM;F8BD0p+r-N1_qQp)Z_dL0E5PGcsJPQY zOsCBPru-6R_dg7BuCtO(%G|7JLClh!$s(hGduE=A>bgIL_dLOXI}~<_f6p-f?epUQ z@7m*kM-vw1e;JF}{GVSKe_;HUqj+4spDs8gvc%0$`W!V5SCR8l-gy`ARfZ{td*gjM z`akq*odz#%z%UV1!3T622z#gO(j@kuk+^%_E<;d1<-P_t2}l~+=Co))Kac&*=8qXV z!q$B`CV^HArAYQv39STb%e~-#n0m&_nmDbYe<8KkhbiXpm>AYRpyw#cE+fBu-iA-p z)=Q2(YZIyZ>25rS47DDnuLyblfi~i$si`yp>i7)_W-9^fh;l9g52X|>6<<7SM{cPb ze4&oqg6ifg0vz0SIfPs}5c~A9rbus3nzy$ETb>g`v6PfFW%zQ;pETt_zhXx^_-(7@ zM!Pi6E$~u8OSfyJauQ|JkUc^b(hW5*2&u_LWoLxgImJc3hekNpYQR3mU$6uRNwzPc zBz<4OM@Pl`t+J1SUOq?g+3AqEaR9UfJ;#MHccYi046yjQ@*U-3{B&uy0h}_!A76=f zw1Uy-$|@dIz+!olF)Yp@rr~Ib;lVJFG!e|iW_i}V<}!*jP2ky%uE`wuc#gJHm&8y+ zZ#e zITw`W&vVzw@#h`b1T1Iy^Gj9T^yiByMN7qirvv_cG?GFwgZT5GInQ%@)T>Q-{JGcW zo%{0-J6N@bIh6)|#8P zr&1)`$2uySmRO*0CdG9bi3r=5d&bbMp|r`Ag@-QzycT`gM#HAyfzCjG3-T8wf2l0~ z23Y=n80zz9ZdupT53|SN$%ZsC1^SL@0#mF@dO1bfY?aHXW3~%18se`%1&ERhjIO<$ z-Am8w%^Iv$6jvgtlZW`~G~<7eDrF7MAqUU+=i&n@iFUsj9}amjq&1gB94}riM~!QA z2EElNn+5&l(vq^P<_r8ZspZiH@C>tJenhB$6`k)Rh^aeLz+X-8@v+hVH=->#w=HTg;lW zU|iw^l5Dh6?hDXYA;Na^Y|N9oG3RN_TTO_+@rnn4TYj`4B0Q^w8ge>U3ymz?sfF1@ zu(>xag5JrS~^4Z45%US;t z=25deR5wb9_AmPSO85Z?lQol=7&EM|KjogxA;H$Kj#11?YqT)?)Mj~^wanExF_eDcf888e~T zvK%)bOEq<>C#|CGdk`J^mCjcM-9s+AhJGA0eQ1`D2ij9wv6V++hQ+0%IB{S=*Go?v zQ!N>-j2t4~%qFrUOHg_$&;D+B8m8cTJ6Zd)c%8_h>Y;xww)A5VlW$AX_y)zrS*$Tn zY7v^15uc03MRgiTUT%QLNf>}HTy?E)p0{_EElFH5N3jV{ zv3N$sPO=2^C^k3#bc6q;Rxdw;e}8KhvI^I52h9XVI*r zkN)rU|LG2eI$o~oyWy`lY@g#xZhu28OPKk1L7gXCY>s0)qWW0K$_Q{wdgpB6rDCjb z%nhiJM@-Og1eejbN$VV2=OWFAuSP(v-7UfcT!b{U1Y z767gm{0^Lto&yZH$8oWOCm1zyy3DUh@2g*ho<^noBAtyBTadD_ zNCo8aIwf_4fTtUcNi)+T42E^LD~e&d#)t7N8|f8=kru$mc&PU6RQ_AnQvRBJ;{hV8 zAK~Fl{@?)R+nk@0HBw!5bg8syvGjxCK8;?)Ows=Ph4PKSFa303-|V5zA&)XrLj=10 z>66*VIKnmAyLMz4LF7Tc{ZAz67+4^A8-Z)gUJ(miqhon)JV(ap<+}qm0yH{2-Y3)T z%qWXBW`8Y$1x>Z%<_Ou*0weO%#R>as)?k#gfc-IhZvV)BfBSP;ykYg4GbAX$;6`phe~y6H*OBy18-gY1G2h-C$GK0b1BtR zt_S|5fCR(A7kv~(soX2YQ0gxwk1_R!KsiXau%ls9k`uFwW#62IxhX<)tbI#)fOk(A3}HGPs?V6gIG^ri z8Q`|cP7V?}2&M5tR4Mfa8Xw;4(GUUQ-*$s7j;($8j%15*5|K6PB)=sCcx3lAtA zy8X$@Sb8-tR zB~dmn1*f22Z-WtQ7QZIUe>Y@@YqRxV4CWe^s8U_B3XJrh?ydpx_6qxWDKK!Lr`BPQ zvU%dt8MFOPTFYZ?iRso0<1ZyQl@JgI$>a=2*g<8~Mc}!AV82j~9?R4bfpE4E%a|<( z2;*AO)CzerX1T1kHtT_wKh$~W-ndvy7fH!E;3s4)H*i`wW!7A66FbIQAp!r466MiP zQY(tl{6v}TJ~=Kz0?js4k9RWJY7R-zSXP0rY=CwsLQB;1fOD*yC+yxZ7X-w#qC{D_ zh}9oMB}x}alGn7JSx&l)x~ze#eXyv(a4Fraj(m=gtJ%V=?fz|1q#agNk?ct+_mFq4 ze}MChd2tKRA&$oyh+QEOXGxd~hTGl&z&7*lqq4^-X3sjW;Q0pdHTd|}%f>e(h;Lb! z_!6Q@(oPSh9X!$!(Q@VS;UidL>jWN{5I=)ab!oB`=9D(28aRE|CpewD+@94`fZXv( zI|w_Mglq4frE~hyr^4)Gews60aQsXL#)zA+c9hdj!+G%j9luMu%mc{FckQ8&xcM`N z+yJ&*`p}=!cjy`aEwS)P3A;8?%b1eQ#9s&~olx?(jQQfLHMlQ8+AmGU*Hv96ou$bH zHA`(L&8Aq{{-&(FcGf;r}=UCG!8%UWNES7J)}(4e`JJTJe8jabErR&&&UWFT0`to360@k0{9h z4d3VE|A_yGxi5jUb2|T@Niy6-#*F3aptzC2C6-H+iKV^K;6{V%SFN#6OKT$NR1$`Y zPC8LZC$~~UZE2&npw!k#(qcj?*AgaZ8(XENdSVh%lu&X1-|zFB^KN&UiS*asKOg44 z?^&Mn?B_Y_d#~SS|M`!d3jql`M%R47hZW%eJ@51b|Cg|7hy0&%a0&jy`ts(_AF@sH z%L&y+C@IYR(BvC>jvJjZ+|%+##uoQz5?WX7-`#`xC(n zF9iR->YHx@_mb6J`v6_fw(CP}*PrfvU!|$;gl;(7F^$+w0k2KQHZqh*khQ~T5r_TL z=PqZO^u9B83 zoxNcdtkL6F(xZhGl^kFBk&ElJMQ&6=!QL)Fq|7}vQBx_2Pq_cabqSUF?{I9*kbK^9 zE?wnGGG^-2;ExQCzyrpV2Zg`|_%G*6WL6pDtC+JQ#7zy<9!xkSz!hOjMmlZR+g|r; zS9Zqe1}p)-e3Qm|Y4hkH#T>Tid`h%EvIe@8FkAIfeaJGjc$?K%_WWvgxp=1swzU|n z)wZ(o$&d#aesaNb2uQ)2CI;*=K9MM7Hfq#uPjg{@^=;kD0F&f|z@V2&IvflgA6DU> z()n!QdX&XA8^HC?y!MZ!&F=>)3*1CqN&P|d#Y*fu4MIu;mDr=!#b*K_5WfKx-@oF; z^JAXhaUe?QA6X?-Yy1h3m~F5#kyv1?wiuZ9nD~*Nf#`k04@v^do^+0vg#f=#_z-^D z%on>LeD=B+x(f;%?7*>{=(pvM^Bn%g{`+jiV75Mr7w3OcoUgqMy%`2i6~1CA8V`hr znOh(5H9NE=!H@Uh6zr&$R`Js&}vbQw3DshKWD zh$-j6S7ES?B9XaHOm2W_XTnBsZ->399->PZ;yqW|(K&1kw^^jIWfQHF5*I+i`(CAwG?i0H($ErsFXWXz!~4N|dvdrv>ukFc4>4 z`1Spjuh~;quMoN08|P;HfHb*_$0XKjGE4prbBQQcd^pDJ0MI1&jIcdUfzg^9pt-XS&=K zAolwN2iV>sX)eKyHU2g*BF2OFZbFLGp2QM@EudUW@B^{G<|)FQi??xIkI~@{!V)8C zNU~@%JA+j}@kSD0eg}&Bl%<5SVM^IBp==l^lbb=vcJs%J6s>uzW@Tr(nTQun1QL1k z8mQB)>eH@6QF3NDIfMO=WW&4HIX1lbsj@+Y>HzNZ5&;U*G4s6L9N5AXVS^IQ!=fG3;3+nHlrt=g9z(Tn|%P%%^)4@j)+0v7Eo;1>}Rj>b#_;-3Oy$KNxLKtTDgE zh-e&iki{0&2@}}MzjGT9hPq&)xn?WU21ppS3J&1>Y~8(?8`s!q-p33TDQUF=T_8X# z78kTLQ}^ZCPa1#qGX>oitki8=ZZi*lM*hu|9jvy-RD2pzxtEjmn!NcX%v4jH5_oTz z#pYZom+;3E6cC?ffoIG|JwZ0wlacq^3$C3;=56t;$)h+y!CvgASUinnA*+$+v%at* zIYeBLwe_h>R-ljznoIeQ&eg3XwFAY>xC_~q`0AZa-Cs#DXt2I|xS8^@bt7|8n3aSm z&#Vz2Lc}URZd8EiASV976E=m_9P@Ju5e|Y{ge=a}0Ow(5|LqBniU>!Ls_VqT2cfym z9JHJtIOz7YB5hXOf;<2LKln#)&VWZDvu9qQ_3$bNbIPY$HpA?Ud5vZvxJdgDBr@{m zPj;0`fN7u8WPbLslQ*iSUAzaRDG5lI`l>Q@2XPqe_wGi6a5)@*2zOXmd1jpoTJ1(r z@pMW!;M&3&;FTa=j|@ zX>-CGWD=R}jRRa46O4b9)IvPlMF3=Vu6Eoz6ktqWm!?}GrYa(C-e01OPt6!DCWXMN zdlH*H+nS9=T7ItufUzeOk3bY0>d%3INEz_!=8tWko@tkaj`sGd$+`r>1_nz5Jt7zL z%=s8cY{z0v%vJC;ptioW77uQoI|)xX73>S`(hZ$HkFaKesO$JG#_*a zhEzrM92$;(m9=>-TDBkXcmHw)_|f{nAG``q(klbqq(Zd3Ka{r%Qa8Mynoa!zB$WeFlfUDihdAMLu*a)*4iZK$IJ6pdbMX#dqMoZA z-^jC5|3>%_Yy%?&J;Ml)2gX8q^})RDkvB^#b}mU5=~r}rX%YV;3hC=(Klyb72JW-W zZ>x!6C4P-;E5ontG(-6Hmto-7sO3Jt&fPh{uk0$}*C_C?5Py0`h$pciPd42&lo#Sx z?1N`*LM*LIG zTB4UbA`X* z`C_mB*R5wMDzEp&SXDVSS}MR#+TLDKilg{aPI*}=2%z``W-FCXxGwCd`3z{O-04^` z1+F~=_nbAJLQv4TBwriNG5?^hL~e&T@vo{tA;0IO8iBW07~W@Zm4G+nZ~ce&8Ydfl z{O&0Q@1hOBFBgV)S}AyUYyf_ThT&~q3f}H-RHQecPwgCz?F(N30ooxrgHntx_N67a z%?2Pa12g4g8b?KnVfpU(#!)Ob))s4KvC*MoC^DKwcC|%XI3%2oB4!Z0Fmukq3N{lv zkVuGGDj>9j@fMyC6UqKgm^-HN%{?nv+D;qg;tA>F^Jdf4noP(ZFgTQVWsSlX@06Qa zLi$~*|2_FR!eWn|i+4{~`@UV^XfN(Bm5ECFPiO=G7gZhKe z>BR_z@DV53WQ?efL~JQ@!<)1o&&~dNl_KNZtS=uZX2RUFtA zI?&j=05B`?{*@Q)G`(xRxId#4&#VlgeBBvp|9&OsJn;nOD>`w`e|!zTE|WTo6}Z@e~+E4f~i9E#*Ml6ZR$5&FL5XE`@rE> z(_oYC#C8u$Mor45T_b5S)`eXs2)WJtRaCu)FA4J^Hky0>0C=)dhX4;R+xfX3GvQzS zVlwKqnhx|4^4Mm+W7hz|r@-cP=%23-*zd^^5RgFo>v&;qv$qA3eGa++WZ?U(e^ip_ zRhrzkMo*)Ccsf7g3>V+c!8JeHuLt5~G^1NP3-RPsEy*m=4rbw@$2^M@ z6$+!%Mx^E2dkl9$3ZFR*woAnZAMLP3i(7mK8}Ka@X|lZjp6}YD{{D82>TgJ-t!8sW z?duVd7UCUfHXu{sgPcKx${z1fqGp5kyg45lAeh(`JEuM3RVVL4{6GR!194I=##fwd zst|wdq5ywtEWVIKlWF~j+K4ZaYj9h!{mtl<1r<|5Eyky>I*k1X$wh@5L|nPk;YY@PhGw zAwG^bN&+7$WjAj0_U2|Sk9>N-A5s2RlfG|QfFg-m)w#m@Hh3NSCGv(=c zdTB8o4+$M#P&%Iav(RysbdY=XMMp>;GD?Ddyy@%%syX6?kV}e`7yKK7<-Z1I!Ei`d zRf&n%X#JplbUej&sXM+wMavWJyf*^)eJxxmT0RIrQDQ#9kC%cg4Xx}Z(l+GtS5s2v zT!ubQBik^Jm|*GYh@4P z-Xju}W>0M6`%>_!ZS|HN-~bmu0_?CUfjxen@j80gNDPJ!1Cj>uJk)y*Pe_O41GY;* zjg>sxd|Euu7!GFSkGG8`Y4;jX5jyOBu?f9+eNsIOCs>oT7qYWv*pU9yzp~`+h?p#G@J|qb-EwDbOKQ4sx~{q|j+@7yct) zZoY3)dkD;Ll&gQ>M%Mr6tqSeC{x=uDmv;yY^0#?=ICHg&%2yWs58`jXlj9FZ@7~hq z!i;B2U{RKTH*ewBJhoI#wL&+(P1>e&@%>m>s7M7GA6f=P+4eCgwVg?>_M>GbFpsp4 z2w)|PNz2$1=2c!)zH$-&0%Pf|cJS%Hf`W1Gy1)7psDJF|s6QNTSuHqM`;YP@c>Gtw zYXefduT-tF{KgW#*WRXYZaBTc`b+)MaWEF3#QN3#<(2CD_R@h1mW%D>sPdI7wf}xV zUPu0I!|4y!-*hAEFa7I=w||D+@DgEofZyInX|o@B9gj{Y*jxDOu^C53e|FbdNcEqv3D9 zjc^4W6#Op#;3>u43@$tVBE25UU5o1GADn~Kj*Pt-o^U{+W?On?xH|M7yhQ5Tworz^ z(4Vt6MQ)b4Ev6Bj9>AXZ9nL;6VV~g)qNL`$eVAw4y1~A1a>o1=FFpsl^a?)qc~>*E zDf*Z$^efWebIt;2u)*`?NE~jt!MCB3-Q04OVT(H(vuVzQa4pi z`U2@`BY7=gav{C~aR~Oie~~bf=G*M>a`ElYR0@;k=$&nth}uiCb|v4BmwnA3k}Z^K z1Tz~xVD8aQ_IGK}OU&8}@lk?g58smJ`1?OAxV7whCbgO6uiGJb8?na*8=%=_)()OI1r!CKxZ=S-!IPKUW_zN|_yG^*wL)s3j;0JlwrS?-T1M z|Nr!w_P?e0kM;>AT|war@f{`*AFEsIeiPBHteK1kW)b16=vFO)ydS~un;o!J?b2oU zO2i(~usE9(RTQ2irbGuO33w6=U9ako6iah+9?Pf9-sf?A6=YMO|79|!IDWFk#UACM z;?xVm8ylj8IPP1eqUy$vW-Cl%Xw$jcD^_~?Q8YbiK3PN`>~Jg#?{UlqJ9uBI1n*9m zaS)#B54G9 z0w(b)*k~uf<{(&{$;{P!4YyL#?+I@=|D9|Xm0Vpi68ZXrQdEqgfqE4zq#Hb($B|%w zA5&L14`iz~>9t;en3j}jjV?M{(SI1Y@>1Lflw$3ajqK^Ec4&xI_M&gvOJZR~ZD#z`n^`*@*_C*jq5HCKR zc3wR}f*;0`h$lj;7Z-8G6HU#vO^I$&HCk-Oyg-KenyioN#Fi*XaKC>czT~?k&exGI zvN%J#%eYaIWIGjUAxjTkMwYr;jqNYq867b!;4j$wQ%*vGRcY24l~;Vb1)!ZuK%G46 z``qE;P4P*^GRoq#`~eSbl#gFIzcmshydUo86cYS8RRX4u?^K z)8yOYJ^~*OU~^TdYwiHDETfB=_80F>EDOry*Woq}xCv!a^{2z_ z8<8*_c=L$(I#w^Fppf((DF_5wP>{k)^A`h!bJz_|UBaG5CL&$D0H-bCDP8|K#;H7a z$+3)VV&P{?;L{}{hhyLE49Y;mSvQf2zKs9;{)|CYf~8UjhR*#xYPje|x-AYl>Y;w_ z)VU-w2C{qER>idqMu346=wbLN#UfW*>%#}`>=wiDW&D-EVT7ct;@C1!0e1eE+NO{i ziOz9`vLC zOm24uXV{&5lN4*pv`0%>d_+{?s{mO1n4#1j1*$^uSrG|cy*s%W0fPUi?nIe+85RvF z#p{%N+qj%LX;%jiJ*IIpZJ9Y4mxG5MGsMg3N@q1^#%MgqymGqPrQo568PL!E#ot0f z-t6;AV9^%3jQ$Xl-?J=aVw^W4fR2{}1oLKFKSfMwMn`Sm)x;SFd% z4M=nLxip>iN<5Kn7DTbDCVgV#XF|jD(I*N9lC-`!>EFdk9~CFPTb%TIanhs3Nk4Q+ z_!2yI4;QkX;y<2dnUikuXn-_YDXj?)sWd2_HYS|pB`i(12a2l~a`t5XrO)oP+dhhi z!PRE#<6PGcu_QBH)2q<872;D)rdOkb)rl0i985)QUWlJ^TB(fFi8pKG5Cm$UUa}Cs zV*B@3G1$EKPSu|@mVTCY%q8ylM-WJFT(4l42g%a)1?%6w448YpSJ1t}AM5xNt6}ku`If-Y z{KwhlD_6vSmD5XnPmdGJ<=QWYf1v)68(DwTO7#Qu|F#tUPd!~0dFQ_ZT?wXsutC6c zsL-&yX~zL0n^MyOtz_CDzf?`PO`8=+tCC5Gb3c0ox)$sD+2KIkc{8I&RF6*QohKQ@i^|z3&iR*i%nK#|V&{cqB4b)(2jofs z>o9MES~fu~e1bE$KFRa9017)sKnqJF|rUym!O(>#9KgBP-Rn`k`wjDKo)_ zwAKzpetpd^>HZ zsb}in@iMx_1fB6J+M!vt}!>I1XCQ-}+ z6ynt^UIE--Ai$#O>AIJV(*=>)k5c1HxLW$yuVSeHlr4fa4v-#w%PxLulIsx6Zq#gd z%S@to8%-Z-(>Y=jQh;EBSY$J~L3u&PA4&inXJ23~<8U?qLVV3Jjunzjf*DZiQi$Bh zoR(OA8Oi}E4k_+Ozh+MTg&ML)Ye;@##LzEU;h_h?SoM|Jq|-C~Xn9Yc2YSmm@Xw2h zT`W+|sZ)7-2-ZJxBkONkslIQ|bMI6I2;1{h-DQyv+4Bkbc6|qZaZV%{Tym$RAY907 znw`KGy#+t%Oket=T{fB2Xgg**y`R0E!pR%6Kvat$lDr$jW-piIm{~O!#h6EHvm*P_ zRiRGCUeWTKZW}YqBAy(X$|8BzT*H$=x%jbTbdfx3rsDnpSgT!kn7V|7$y}8$@4h=d zM7)Lg2x_KVS^dS8A+)%nx4q&dogWb5%4!QdsCoPb8!BaChJL3QZoj1@vt<(^25j7= z@#}}kyO>#mJ=Pu?Ux=d7W@bNX-BX*-mnL&5ci`vZ2Y%IU9vF}LHSuC7CYuMGJ%aDF z2}tW@Z@j=R&&M}1e*L<4CMZ|?g?VLF%&#x|8Cwx1R=?wwzhfD9kL!SO>*fa~lcnZ8!#dB}m$`}I;YQX!uu^?rU)pXdr7xc>Dyyua{-Q)^ zEnNhL`XBw80$_(*w5>iezd$G4=1d+lx)|JXMi<78%_d^;*7e_VKm4F{5;=Rq`AVjs z5qK^P=`hcC+*p#+BMqAmW`;bqbs6U^vK>eVRfHYLNZcFAdI&1y!;NWk)L+DB4*atE z%(VH^r|=qxHy(voUvss1b?a9;m$XF}K%)K>OsIS4I9O2L9QPbuxb?Dc))RryRGlbU z>H093yxF%WQ2&@)SECb)4CA!yRrl}Wazyio1b^pz&a(=z72XomY?CZ(kD1Wj42%f~3#D?eZ!7ap| zBm?!q;X?eGBcWVgZc+PkSZfEj4PA6Yk3&E>F?0G~m58KuhQ;_-uK(QY8(06)jjaEA zpYhEXaVX3`^Dw?`;AfVgpDX2ggG=YEY1uiet2eRr0`Sxbr2G-l@`qP0zfH9K zA8U%?L*8aeU-| zym9QDh~TH|KM%ZJBk;cK!CPG(9-G@?AqerYFVcU|{t1TxUXd_;2rBB~nYZ>CrhaR1 zq24kjT%Im{Qmx0v*MAbxo}%`{<#)<$SovS#`!PQ1bT7sJ1?OfD54eB`zTDldT|&&< zG%v^qffj{_+smOZQ8$t27{PKd48L^wc^g=ML*bjU@Tu!ZRr0si9(MJPWmtQTJ0}~a z&)Lt1UXK;YXM4`HOUMV8ExZELR{y&wAy%F@^Zv?CUKtvOhw@!$%9dvcN{ZUcmd|u` zv~%C8y;~XsnD}OMpbax-S4+Bz(l6uPF%9=WLu=|i)2SKLhD;(Psj*jwZTT4a6~|`U zfl49%!GQr~3+p58v33?w76ronGoJvUXL+h{VR{kvymABag_g4V-j}gDo(l2Do5Bqh z>nkffJU3z?0r374o4a`hfsMiK1e<9JJlPoElo&9QA;7|QK^;DNp zY#-JNKdIyo(f*ww69~W$$mGW?BNqC5ItjEU!i8sFKmB+U#DwmgdY1jUjPmpIH zu5Y57j?5QvbPr)!)(UKzM{?UEppV~+MnvMtuBJk%-t(<@Kk1FS{?<4>2&dT-AomM%&ul{CPAN@t{ z$Ccq9ily3${Qt1BeD0oIUOy?nHO&>_7vhKP7Y-W&@}Yjiwh!X16`qPj@@q02Euz$! z-ic-J(xokE%hEA;cQc*lm#VTAC-|k-vhxG11%vp@=Dq8pZuKkXZK(m?0%OerO)A8n zU=U>cqBep$njL{n5YKrCqfd>Pc8`!QJwk*&fcVKf73jBHG;gC@2{coPPZCTZfxCe7^j=mlM_x$WPJF5Zfoj8EOKQ zCXf5W9s1ToyeSc%mJk9T2$`XkPH4QPh5dL@7!bklVGK9~6T3}!Q`{4{50@mJP6f&X z9;fx+4tWrlpD_bQfK+_GBrCfC4#C>DIrdT{fzB)^PY=7Of&R-wy>oV5M}`)RzU zs0B$gAKh{oFNf8{VPr5KP&X&t2D9(tb?S( zKhkNoj6ZW*Mjpccx4gauvRD=DgB`t)ALV;DPfa^s5n?2NFYwoAW+1=Gd;!Y`x<94l zSAK@^2FhWWYzyq1{^f4aJ+zxPS8Kce{tc)R$yH}V77oX@Fm)@!(LoG!kQ-PA5FsQ) z`nn7tpNT+%6jWYfy?u9V{C>tt5&L7686F=4^s2r!Q+*E*d)5SDXpQ4xV7(0p*B7m4 zQ+_MzDB>RFFb+pJ+a-U4_j%$Y7}(4M7VV`>NnINtIc=tnr(!^bx(08S$7h!KIHS7P z79MB|iy+hHcohCfnoM>L9&U>#y-oMt#+_jM!SAOXKabMn0!mODs1JetOK3~hC&?IC z#L9NH2>6k}d~%>Xgv$GDn|=3_0Y)+7Gp&WC0n_W;O2FfOFRAdxO?3Gd&Y!?-P5xp8P6!{%zjaqf_=iuDlYec{ zvjG3TWnUxwyB>=f0scKLm%qsuySS+eJDA3w7iOXG#r-LTPiq-`R&vdZHmdSmZeJt# zERru0dGb8C2%neE-9UU+a0^yYzs4(nvSEBiJA4A%nDR49*S}B2?n_7n?o!6x#kd-T zg^*gWkPo<0t7Rd5&y@(vO7W{Hil3F=y|p2J-DF=Q{CfE!vcu!oW`A=SQwJ!&JI?7Z zznuNT9t!d2ccC*Di)RpnG@M_CvBz1r^K*ji z4g1?5e^)N`Km{v==!2BG`xW@4C&_b~7Afxa_z7wjNWu0}ny<#V8Mn#*jQ#%CN_BpyGLmw<~mAt+zIW7Lih#gdEf((-beozEWcHB|&tq0H@8z;t&Z`uelr#Of{n3gKecv z6899KI5vW@_I7x~Lu3Ci2=Nc}2|UcGYouG5_aq@yKgiEb{s8^%ga5)$aE``>hMg_F z{O4+HQtluBg7IC0swd|l`qWFy{xJT)$J5V&@cjK?LWus2t#1hcFThFKSKoA$ZT}4{ zyKD^pFMN*rA^gWSz>yK!d#KjtT>sL0RbJFrqD zkPl}jdJGReDDUb=_UmE0UbFLJx^+wKT<%%pFz}M^zzrkK_B(ijQjNoPZg2bqS77jn zAKnGaRx@i_p@4+o_nKQ_a?P&oBKs6j8FMW_F~oDj`Lx*;-k;;h&Te+PW#fA*-x+Z6 zUHrd_s9~e^py{a95aoxr3%0m?egn(HNAv-I!kES+*n&;41FOBlQU4UcpXDLUqVmP^ z!WKK(HBHa>?3L_K?InK+bH#4@YKJ^kGFLqY-LcQZ67W3vcLin~DcP(S_GZl1okg+@ zf_3f-{)_sUt2zE#pA{$o_kL#NeNyCoTIBtf$oqq#cli_?88Bsfpw9z8P9y`M13CuB z(0&N1E8AN~`U~#^<%E$obL$uq_3$NQMn2gu1_bO6NcyZLyq+MuH?S=qJPsO5DTY>p zw>Aim_rv!K9_1rjl@a2ayLq2Xuvb)hdHH)Gj4qLnhTvL+)o?zeW4?vE`c3A&S>zt* z#wQ>96hlORy!La4lx$z}#VO(bgdal%JQ7^{=&(rNHvF!*m+9%F=L-7ln-w1E7P$12m#^qByWzCXaeM*|`TsSsaHbA$SvG*}6K41&o z0NrBt=f-8=-2`;yAtet2ap-06#Xv1ZF9ZJ!l&XOeENNIKXLo9^yH^0-S5JU8ZRNB4 zs<4Pe>dlnx`wp5`&<;IkcRYO$4-Gy*x0MY)Ec@;rwk#1# z2Evq1?}S!_zlXZ!CPSVdsyv^(Rpfa-PTLgXH*V?6a~8pexa0aWCl6p8llgAB<%zP5 zv?uaJX>aZzl0W5Z%GGlDebt|{9i41>`0GpOJN)sfZy`bFse|}K>KYA%&H+vSl1twPXXrv)5`unc2h{4A2LNC6}7!5TNSm_3STihlF05o-x2q6u9O zkU?moS8mKqh|5yAn3u;=DjH`ieyAe`ru;^Z$O(<$H;t7GC$NY-_BwEbpR5KSu)pILq=^?_*bTXs`zh9JNJ zfC|TYWIzO-ocxeO8o*#3KO&X&4rv^ZJ1gOz6Q+q1*=l?Es#&IMtyhFy78o2jZx?}& z$O5#g=FlF|iWi;!j9>-D{^MH--uL?hFYM3mhuFD7v@4yH@H>>>v;Z>=RS*yGM9c?d ztL4~IVv(d1ZF3es-{~t^sD4tQ7vhhiQ&8~b&f_()uLNu53=G4|L3b!~kj9o}K*uh) zg}v9NSbXz7Y%k(&hJ5h#P#pp1u0*jb2Z&J0S)xraCPsgtNY7badi(NqY%HZ zR$N`!!IhLJFe=0^uETlwN4g-u?Zx~c&sxbd zVaoF;;V}ityw6727Up1r9Q?zNEQ2bwpCxP+fSDA8j*O5{1myV+q;}-?Z%7{+sP0LO zVxdJCpK{Dt%->cT9lqCb0sCWn!MLO>|FVlcwq3*#43TIzy@($Pwtdw~Nt@Z&fnW1r zyc%&x=ep`i+hh=0F^vvKng$NNn91}~z50dtvs*yVfgyO$J7Om9 zSBSq~uf`xjan|gQBejU-*3@WFoHaWwLHoJd)t6)6!JkhhF7}fCw+G_|uu{Yqe_R{; zLHn7W&|3{O+HNA-xcl^EK8Ux5ru2yHJV5z z%wi5E0KLV|rSC|4OAvrGnlvqFPNE$oyWJJ5(fRk@K>+xk;jfWjG~c7+%_fopIC@qX z5135{Kk=|7MVc7L(WKkD6f={niHp{N6=_saFEdSpP16xjIO~Fae}~fG3whK^-i-Vw zAan+oGQ>gy`h!6vK5`EwsJ{}DL>oAs^mXfNKuxM zGG_mq$ty~auj-h{&o-731h)*9XwHntV+zm4M;6dP2>B_NZ(+^MxPgz2#mI0V=jMuU zNi~oCN{|drmKFE^xD$fIEy^ovy1y?*V?VYeNprmM&K}n1*n!-n!c4TG!0R9{if`;ToQR6JlDi6SAq5oRBDsIHoikvNgmB{rY&{38ksOa>Z2wS~CUGvzj$0!5gq_Z-i*DbNpG+siUPvMKFMSx4;JB*i&p`;n59ZQ^uc z_Ds2**&JaLK_Fy}Yl>NYUz>QXOYBk`!Ny_0Cnw)5PE1xpt z6BJTR6de?wAV2l|d`a;|u)|A|FM<_b5_};x;}{*gE4vY*xhuQj4`LKYH_8++eu^5j|vL+5@ z3&eeV9H6r}+~vZ%Bs(IqL?EG&cxD1>VoxCze59}|)%zrg<;(@zasPt+!sivM>T;I> zA#}4NJ@zy75v|G-b2OoFraY==EGcO{sjb+Xl{VDW$)iY4=AVWgiKf&In z&XogTk0;>d@_iDAK@qSEechrK7`{l@EW8C9E~I2NO75&e8J}J!CG6e?uyGYkrN_*@ z)#Iy)_3o|^P5@*dLR|p%c)z%Y)@ATJ9{6p6$xCTm@N2V%n=;#m@Y_l8n~+}DhR_{2 zQNt~A1O3)~pO0Td5q@X*4f^;!6~WI3yr*B> z8<$kT@B93m3oT0c=Yx6aSN+BBaZq9z;^Vg+zy^iu1D@#@?+m10LrWFU>z~{|;RE9we~d z!)&u+03rP-UEder)1WNEpc3_WUl?vb;kDn@ullQgSFZlo`mE1x%U+6*H$Awoz-NTS zoi=xo5uEboW_`oPkZ$z^sD|1L@z#B zb&Grj8PTu$<0JJGUi-arX`t)+%hA6rZek9$AM@(pwUPBF^sD~Ie+ai9pua!$3-Rfn z1r0PPz_;!7pneokLVnjp+HdiC#=DBDhrZ{H@C%(!A>JKn-#5^+ivW~pKgxfc14>pI?7Mzv|D3;P1VF)m-{IAt(69OnBlhRZZ$!W9uhN}Xh;$jh)-Af$YrpsUBKkw|d4|ks)+xu`Xw1!a ztdwOM#&PGD;wJTzvh7W#b7Z+zM)^Eg8#{l9%-QUA;eg(uh4_o_ zajFx$fa7m{5N@)H(TNC zyc=u;1{kh9rcrCES(IE7V&u>bCr4$*<;_bOjB@(-;+OG+Qcd0zQ{hGqa%B3!}sXn`;g~vEx^%-Yz4C42#Ps-yr#>4OBmx}Sbrlll)`;^3Q z+tT>$+$VmEKQ52oEmIu-4l04)uBGw2;*3)KyU3e~Kqa0{Jq}CW^(w#igov5LSTcEp~t!Gq0S;&Dqu|Af;Xbo%vJ3wd-5r^Y8*8 z@Y5MN-UDoBM^S@QhW1uo@2#GD4BLoUf(4Rh|HHA4xV=q{;&1|219f3BB5#m=(gBjm z#X0tT#GT@Fvsqn6B4rj*$>Ecj4A)i$~u4f+uA`2y&{#RTuKCgjZA&-3!mg ze~$JJ4z~AGIqCYp+TPPI2IzbC>J6tayuXP`+MI&w;SgBU1vTcJ|55z{znPVzx&iQ4c@7EK^` zTFLYml!dUNyy@Z=bCiUxOJ&>R<3(6dWD3(&fiPz)3LN6c=fr27d?(5-BY*9r4Z|n% zpU0;)h|kl_ed99${z?2u(%-R@W~in%E{1IXg!l!XQ_0%*#{YIEWOdpiC^4aqVK2M{8b9RNOS*BVeRNRssfj{y zb8n5ZlEFiW-xiRZ-rSCzvsmC{KT&f31`fs1=v&NBIfd+OrbU=YQ-PDW;f>B2<4wE@ z-~5gvY2KCd8?qLQBMe7)VH66({RxPedD#|{&L+V~1CC;ZY!r^%dyq1ZU?Zrl$UqAt zO+aZSu5E645Zh*Cc~DZ@U258S*UI$%0&5o=L2rR$OreWMHp4$!wZ3{)^U3`xTzIQ#f6y*{cm7+YF5u#iJratKBw}3vs zzVCE$KjpE{p0 z7%n7bzSAk|N&cdd7L~sh=L=*#IByz;>3B0Sncu0qxE(vhdwzr@M)UotKRpqwe-P`> z+K~F|b_&-2FykCYvcnhMhD=yvKeqD+DZA+l z(Fung19q&gdcST9Mms3lgo7lP)V`mo+ar%e%%k+~gh(p0d{qY1^qI}_%vt_^UGwvG z$)~h{LOF-SWmMgi1-kjws7R$JQR&ILH+ory3_LRC0V!6;tPggX1)=FdfnaD zhzHr-#={h{>#Xef4m5y9kcUtT6ee_+9!6tXV|IiU*wI*MykmRGYq?j)aixV%!h-%q zeJ$o9EACcIj~qRk3T~(0TFf!rVh?!2eeFfhF5QbZep5}YKUNg`fwQ8ebP zSw z`@NJOA5QCsp5mNLgZFJ5s%Stzsm60~T)&(?(QF^Cf)%Ui`TEY+K`G!utz--F-da7N zK@Prj0%>#f=BUu$yN?vVKz}4k4JFK*Z+yXHG4VKpufPMRHE=gF5#q58{H8Nq#~}gt z@9M{M0+^|RhV)60%$Nm*5b#xUp{u}@Z_tYsc!0ML11|1AMDf#^0)X7hE0{?0I>3lqM%OmV9I_qzVza_g9nA@qmYDe z0|X$|`EOj12zwIcPcc0iLdDrUSyU1Gx8G*~GA)wlR!Kx$S#tk2ZV^X5$RDE~180@< zcAtELiKt2P^wbhcCcTnwN=c7;qa4?GViGaFCf$lm*0~KO27jGC zHn4C4v$nEeFDN83bq{yX7}Q7(x$Af^{u|TS5K@wcjJ*o37fuRAxwD--C3T|)HBS7B zW2&8u#x{6CG2B2(h$Nn`OfZal7a+fDV!n1bBfOro?#CHRpaCP z?ZL#lX^)ZRNKBbq9swefJ>+4clsDf_13IcUQ^WAbHZkECqGDCMMXc503|Fd_%;*RW z7iU^)&RRY}!OhIyxKollu?{K%MQgdDyL?%-KP1U)q)6k~0>MSiui)z~MD?Ms(48w9 zVgx|gJb%z<_CmBHwDe6ROxPur-@~;ndJ1W48zOp(r5}As+2|vHPxeUNB=Ku9f4axj z9UiUwJ6qSIr1xZP{&{j3$o>S|JNY}H6w*6*NJC;IaWzPcHqneGdTUml1E#Fnj))F;C}<>}BbG%Eq$xa@$=F>mqyU zvUQ8>x}GhPa&{Yy8l`2Rfi zvfwO+-ro<{%lx-KhrR5(OW*dg!)s*T=d_oVTNl~O@P63KbxuWt_R>fxioN_~zRGg| zTwRI&WwXqG)L!B6~U6)-AG^^R1oiVUhk7?3oUEF8TqlQ+#-&L)zNQOh_YYFFc=H$zHApoyGPd zSKl;ea?**f!?&q1!sEB8(QUVNrQ7se6(Vf+n!x7M3)!8v72@yU0~@?a}+;N@+2CK#-ubQfUkCJk}avL1lWIeJloH@HN(! zYPVR7vhQuSEp++}z=4m$;s6Uz64409EyzZIC1oyQRJm3ea0_$W5kNP~tv;d98DUI( zvlDSch6EMMJ)z6!ACog!>_o@%(g74?-Ymy&?*JV2ESldK*&UvLTo426am^0q6*&6s z8suhwo_WU3-!|Bo9k?ZAAMoVCU4t4s^Mhs}-j_6-g0pD@ytEiZq~&5>8W?EP>Y0Yf z-y17%kfGxXfRy>GohAn7ueudI9v=1QuYfuDA>%YmzH|?=WgG==ydpB`_xzQ%-yK_9 z@NdG?Ay5<^_R&I+;5+_wRm}iT4r#2KG!k?bcr_20ISv^h^!Ucu+PeF;K-wU@Vz}l|@f4hr#6dZ;RVs>MPfMW2u6I}^xE1T*1IB-oB zuxMd2WB^qD8Ni}Ps@fZ=TFeh$g65($3jV%v0rLWPT}{&zbtK-6SEcC3anEK71Vrc< zUxeQx!)iAQo-tZvh6Q;cHg6t4WWY)=y};h+IW`p#~TQkG+cDUZeVoi41BJD;%iAbs)`k*pd4+t4fWVks<9~Xh%&Z&umyrOL^DxOnn zO0LwD!k9yhO|E5OuP$H2ikk|sK+xv3_)K4*z znME~hGGD)p1UMrmNy~AaChGY&9@|ECmy+T>d17beGv>ha&`zh^{N?n)>JJ>O%yZ2c z+K4Hxu0f5M+=}$)(fMbH8u!2M4efN8XP7eFrX)r$NPr8ZBSF?D%|Onqb=qZ>P&@=S zFo<$1q201PT52WKU)03{kl_Xp1~6m z{(v9&rG>zZ|2lA62^LWI%-dO5hnBOCEh_fG!Z8!mtg@uQgxYcZ|2LDbqZ6&p`b3p=rgcKO6U4(Y-`{E>08C`OEST>*>` z(uU@jyj^_@ZrwrO!y4l~@Jv99Jlwr+@I{8XAxdIzc$Tm%9vZ!Mc=y z{lk)C`dKIZ$j|~8poi>VEIDKHKtg7!CtphQ%a{W(@uXANSP5goChq|=uvYjCuFD{} zfs#W({jYyc= z{^b;GAA*2d$!M>~oVtfnk?^*^c))NfvCw;%lr-Pm&lRHs&@mc-(Cl%b9~Z84%L2J{ zWO>snXN)+PzN4*|Bp6h{iqecs15xl&EK`Vd8c5Fq<+4%9@t?PFAwWu)!!jT)7oTts z;w?#-^F4_>f~BFqUimG@O8EzV@|xF1MHOf#%g^6&Y_XAojp&R z#`6UF!xo#gsZ;3ferPEa1pJ-E=dU}{KbkB_n?ta}Nqs@+tE8G)po3BmD;&dCE*Etb zvI({1v>A)Y1zb}%8)#457WN&VOMa&_U3y3n4>5DVFv^qdpA)>LFKrGLD+0Uy{RLQM z&zoDnLFYm+p>0yW<|4z>+@Qk`<0JEiJf(h-v|fn+0Ux{KL?L1>U{i=+&38J!UYMh# z+~88|`(1oTRwA^PE=6ZaR^GI)x75MiOcE^}$BWH+%nz|~YQZR14mp2?jGpw|UGKf8_O#WM}TifW*?)KB>K1V=QD%Pbf>8El-N3baN-L z6Qt9BJHaPjHEs4P2me8a=Rh@KPn;&r&PDL=?Hm5G8_M8+d=dQpir~-f8~(ZF;D5kS zJB0sXMet`;fZt5vw3u28A3oQH)^Iem9oLy+`Xdlp^9>G2%!5Oj^E@oLVZY2)!xL>G zSQ6E+;sKOH|3~e{_pj^@LpDSuqrXS+9ndFB&!L=~2Y#QEo?q4|ltuH7hlKJ&U?%4fV?xS@&! z4m_>0g`)J7I1qE%hLU5)uij8^+0QqMp1ni#w3Vi3_cHX9Kk)LpY9?Y0Who zMc*;I_e1}#vxTDaDz%}13U}X#+xvX>w8O^Hb7`_4^o+NKis>mnpz`pzjicx2t2atM z_v{v=r;PrMEkjTF11D!4y@6Df)xUFRZxnsc?i!@8O#igdGND*rn@FJU2MH;&vGSU8 z)ke|t*)IK{XH;2wiucuZsH==?a1C3Mdkx@2w^zu`*d z09>}j=y|idS92#z9%2yrmXBSwmzquHpJNr*q!&61gr(_*Iz$BWsg4C$pz*%1V)t9x z$??1yhT(j4xa9eBK}P~m8vwcRJlF;)NSccf^Y`7_13LnfNN_6tH86kNp6Rl4)Q7)| z(pWE=4*Ye!r~^-CZ&1cEOg4+(R?Y!ITWg=45!++wH97Iyq!Kc*8P{CMdNMM=!3+6l zfXz82;jP5T$Te@SXP2yH)12&-Po<<;tHp!&O*uzp_j_7=I5Oq1K#^VuxO$nvW84i& zgt>6W->M6~rn2V#&K?iul`}n4uplTrgzpK?yu;{TSjh?)UFydcdW^Xq)CP!T$Q-H zhRpPvD=}XKZRQBdjysyn#i1%;_$BDUrpI6deU{%q`oAdf!Gb>NKY?g8WV$YCY`|t0 z%1$@CG@Kp~PlSGsUSRnKwR zA1~d^w(I4YKmuk|WcXrBK=$6+Ka?{^17ENoky6MCVwznH6~XGRl&^Agn5r|IGGE^r z?E!QP*`iy2F~D`p;n_&4T@G(%GV(YhhE?p)l!jc5jtkS}f@Ks0bZ3C}7na{7b%s^< z*1#hTlcQ_HP^VD|ITkyz2&NG;1_OHGjUCrm@lfXMj2!UJLjrsLT_I=1!jj7-D@$n4 ztImP2S6u)_aS&uNUdVvW00O7lRDoigmylQYSq8rWqF*;{y~ruR*tfR~Gxn2&GPad~ zTBG|6+9IizvF%L8A7R?H(!@0R6a?L+GHtjrZMZOPIGDDJGHpBT2qR&;2-7g37i8ML zgo`#QurROYvJv6ciIbF9OzHhfqKE$UK4#e3v3* z?@a^^_oNzm-xii}oLE_c&W`FRi(3g!>SWkz+Dur+CTRk6dzzj%lQ%{2PFb%D^iN^< z#|7b^Sq}bYCE;JjPX!Zxbp6|X%EU4Gf5V>5-pXZqin8`Ru|BMJ(trPdX3y{9QuzOj zJ^vP;=Aq*K7wmbvx-er;+Ceq*zh%#hu#sn@?0H6;@@k{(d5i@Twdel}dss1|ANFuY zZCGp*w^yl-ue_Ilt2k=dzjl*YH z5TC8f;1l&n_Ddtc3etHW=LmY=Qu%iMN%8AwZ?Bq+1wo+7p>p1w`ztIG%T`vK zaLP?k5+A*1#!QD&MTIbRHFrE>pWTOzcAuS&Uh3vglg+A^b=O_?tdyA6xKe&K_RDn5 zmhZKZr3tg|FYP|Ox@k8_A)A@ACmZkMSPUW-)RTO|LN%G3i*K_F7lrKy6E?4vTI<7Nv=Ocn(d;{JV%8+;jzIKF>H9x>75M>Kri-?LeH&TT2rFq~=VQ{p8#uViZ^lm*rPSEWRnfMsD)ft)txWclZ&&PB)GQN2hK- z8s_uO->$J9#eQ#dXVCrSZFDib$=tFJagE2XatX&~0Z~XDWs!hsFUDZ9n{x?C^Y!18 z+i4yd*iy@FA?0Fb+dqPT(e*1Z)AA1AM+OV%p_8a9pA$`(Tc=V@7_j$Hm{?E87>r;O z<|i(T?kzCh_IAyi<||OPvl~*fd(L%kCvP5I!_zPrRY8HKk@1-vM>mvwCO2F6r=yK_ zbHuxJK0R!(*}Swy<)5Hv;Q&32QNIOBY*iqdGB>=wUJlqGxt{g2X7~S5JxjAbya@FJ zt#z<{RLPjk`&LD=rnXfSfz2LVAGUN+Fbm?`U1|z4CuX&yR_4!zKuDprsWRc3FH#OIi|exE%H$G zrp-*n9D)#AHw|ahMSI)KuP+6BDmGAUCzy6KnB?LW-kq;+uR#frP8Zl~uWGRsP(ss* zeyVO7x3h|~wcMJHcl!zFX^k1S2{;{6$ojH51GFt@r%m~RZEhusX>+n zm^6Y-Ax)qffvN*4&Pws!&F2j0T|g_tfv?GN@`q@#Jz`|Ge*=>#yJ9kB)xktiOrmoE zl@A9Fk~T?&X1DC;tlYjelpfb^^*=a2dAp3u`RH|pglc(m65^cJD6l~PI%%Df^L3nQ z=AW-v9oLU$nyIgdn%j93)SShd5OYng9xzV*Xy`seHYZm@OYR^(c)04Q)W>Evj#w0sif{UIEM{X8WIp zm#Ir=@gavFiZ9XAaniO~jsxdS%X)Flk$598A7aZ-x8+m)EI-eduj*&{ye;4JY5(w9 zVaw0$XZd+Nsg3@$t)Jx=SorCFmXBHZ_5Ccr?jXfy<-hw!Z>z=U-hP%}XUlh#EYEl{ zay0;Nf;xkH1J&rMBOgF5`I=^Qkgk%O+U~^{y3J^3yUXs8>@R((sHP@b@RujI{vzTNWN}9F00YB>7&KTX?8l*gO1;>t-1ZC;sclu zM_*l!UnJOjuYp(uZ|scPske9)hnz7dVttRV&BflB zu^CXm>$q6`nWk_dxeU*MNz0EAzk!~UW^bC#M+RW9A1uct!M?G6;&E)V=0n2lqn{oxn z!=cqKhthB-e`UB3h_`fqt@H?a^M_~ADBxl-jR*63_?$A263@=ifTqlOKX@8}BqxdT z<^qh_C4~Vwq+*j#=LssX3n{kUquDOL|ALh~Nwy6i;Wb>!|E-u$oq<1qeb^;ebZbNy z;Ykq~{}6ySnh>+mLblq zDChHr@s+^99dqC&PFwpLK)Gi`?nsv4hOXzuu_E`Ble<#yj~A1p7vPwte|KKW>^vS^ zK)e;TZ!VGszv9*eH=96Js4g4&_eVND0zvgm$(X_KeHK2)MoaGW_A3YC5wSJ&NC=-D zrH2S$fsIi&99OV1hzOv9EXY@i!g8f(efa!^@Ezj8{c6CH*x_ZcxCl&YP|bGeW!FvT zW62J{FFG`qFq;dp-K>%D=)LGIh2azFm>tHq1iVK+v>F$Lx9acxg$Md*r1KE+Q>Q>& zkCJs82F4+>Jd8@wr@dT47Pqabb)t@Q^h=?W* zSQyjPT(~XzbmzO9F$r}I6LRs}<%efNALa>slxSf+^fU-9#Ltgja3ZnB=gYQZjoLMD zSA(--{$A*Y?Q#Er4yx=}tP940`>kXwJyVD3 zn1Xt2BuJRLggg(7NFx2IOxom$T35~BEj+QXVQ!~kW9!^$GIt9F(w~8&;q2U==+XgS=y<3vOWa)8C8GT>%=mozaIXVCQ z2<|mG6VNT|FTz0-JX$yt@WxNz)m?^xW1$e=Y?{*qy@DA!1sg8JN80qXMfVj@!Yt&= zL9k3PvOyx0t1xRG5UG1eY}OKd2)IY;h2xOg(ER{r>Jo}XPO60LSM86zhHzxA9xMrs zA0zU0K_c{+Mk@|=EFP-pe}EMr6|&9FnMVCzq`mId&M6N^T&fPZ5`I}u+)11LE5EiR ze2$lZEdHcrzA-!>pEplE4$HBI(g5=MmEzln2>ARy9y$@K^lS3+Lr5s3{D(LLj%B3fOj zs)BZpa<$nPQLa|JP^K+Rgep(^F9y|jKcSX@MKqDr@YCR_~5 zcQNPUHcq}jH5(@1&swAM9jcgx<-423tpD;Yqkr%G!{SH5ZK%JvsR-HAu{o_s887BM zCD>R3S&?T+dr1PH5>DgG#Ft)-icwm)CtIw1;Xe6k3E{purG#+H<5Pxz_aGGC5c~b? ztf+kVROG_)J<=jqf@Njww+z2FCw`Ur^~#xwlFzT}{#}A!4_;E5U+kdGZg3l3f9myb zbmiE)1V_NW!}f42Ux=TkpG!c`@2@{=SaJw+0u`e9H+nT8Ec|o%i0HVuH|I@K!*1$@ zRqVX;=6BB`#f=8be2+$_Coty6{0OwJ3A6NYTMQH7@)F~6x~MfSXIIS-iOrP~9DV-# zdB8_8abT!Gb*@mM)jY)tMA6>){RzKu=}nSF{afeic4r4Pq7JqTs^UUjPX?wi3rU6e z2S3y>FUh&pq@B!5o7c96*mCh}CTUQaHhIQpje_Yaj8UGRA9{yq}M8ZRWrj{9e_S6kBW;RlD&8{}3mCYxZvb9ZVV@d;4 z2H2E#rqnZK?LP>|F=eK%guG%CNd!w)vE&mrk-W-4I$^Qo0jSw6M#DZ?jBM2V<%wTv$aeU-(0fdo{7e1Rl%ea431 z3BJ!oTjO{JRIkb%*M*4(krG*IYPl~_!)@&+2`6UbO?g61Of+Ivaq+r=D5CkDs~cD* zk(q%=vy1-^#s7h#17I=@gxJsW12?Gr@^!9w|KNKk_q=Li`$7BfgPuXIX20n1PHMSY zUURDNW3c-~johS-bLpB|W)%4Bl1!RK~?H0g4za3p@` zWan)@Tny%A$`bciX6TYRyzlg{G>H0-sjGFUUDI3bjgG0)&Z)2yjlo3VB(D3AuEDOM zpZS{FCu=3Ksa#SdsF88P$heO&9G zG=$tm6Z(;y1Ri96N&7#|egM*KY$Oyky;8!gI1QR)U;P$XG9~9qtPg*J4Clz%&)ME15ydb$Vs!lNp%nsqQe6# zH{%(rPJ7zi`46{iz2-Zvb-VDCi$CyBlt+6B^DS-K&MtJLb#vO$c4te%T*a1h@tL-G z+U)0wr!!sl&bxhF@w7Qpi+{%!Z!rU1aiD68SGnRX=1W@q%Tio@wOOv)R@_b-Tj^<> zcmjiLne_*fe)ZC%zLuY+3LHQ1^heYleSM{yC;5k$+oj&$~?=GJU-+i7d!uJ_kqL1%hW$}Gu+Q#5J`JK}8TMWS|=H^1-Ha1xTDUd(BRenI+8SOtUp(WyW)RCy-t$P|>5 zfgA9#Q|V2nVOw1M+oQ5|@)ZA^*6M(qFmJp~&qW_Iifu40qZ6?&#TlSl(3&nEHz`We z?Wnk_Ax*ud)w}Ocl~{V1=(YFc3IC1u272wC7-?_pHva+oV2SLpr(8wX&JAf>Yr>7S zdIeVATse^{r>SpY(*unysrUQotS0Lp8m)gy$@;gItG^z1lZM*=qF4WZ?3Huvf48W< zwQo@;5iFjdL!w>qBzOabk|c5IU>>&{(0hpJiMgGjIOxxs_22OHALZ+9m6*9&N(c~! zLYz{Li9)JAPF-0Fmq{ z-^Wp7sV7@oc3F%4(JAa#AP17Nkwo{Ma5po}ZU2c5DiXxD&=WAHkZoFZ2qx+k>cbG9 zfWUR6SS7&rJzWqf#J~D|h8XT`2fm5ObXI>~%Xvvm(?O)GCT*jM?&b&D`kFh|%j&@7 z=cE-P+HD58Ju=wvs12An*%cgDxS++!PXkmz&yN41=nd#%KNG5-A^nwoQAmU5bFifbO@($OXE zK?7w)7f@b*iFzqe+RXEq08~~gDx%G3ua)inc8X>;iKyNs^P9L?1zLx(8%Og(l2SuTA5_L+ihq=C(d&+ z>AizAwVB%wKs(-)fMdPCQl`$9(7VVGc9cC!nY@S`A^d<4u zYSylY5P(37nQzn4YY@ZB`HCk@ZwN1}xq@u;Icqa@_YGkN%$ynA&m>oy@1$)Uh03FK z^Oe(reSx7|cb+La-w1j}FRX!5MX4G|8|e(~1tLR2@ogZWm0?kL17ux@C)=<{Cr)|~ zq6kqi!yf0~?3D0q7Ya1@9@Yy5a3z6tk~%Z>G>QCdK-_Hp`up|80r5DBjxfaAx%d}U zF)@XfTM25{tRu$Rlc61T)g1nTjn1uSgoSZhLNT_Ay)D-3XPNhhu!7z3e+`(U&S<3h zWZ`E8j89YCvy#FCrlV^Aw7j}%^n!VF~X&LL_f@_ym%4$7zE0cn7h5&;`wDRk0KlDXWe_2%gY=I z{u?L)xjlhZCTG_&3lG6cQM_3>il-JnwJ?oj9gD)F7d)(Fy|+xrx*2c^@hPpcJZj;b zuqUtrP?<703NN#K^cRpaz?1P$oB92|DnE$HUy42u`>$b-lu=~(7Sum)e?Z53AtCys ztl7Hm?J#M>OyDf`XT*z!hh^D8=UT)C*I)4DAv4SeIFq0{j4jbjcpoCUe}L%sS{yL9 zF|^YMEaQVw!LeG(EVjoORlGiRCvF{}PW2IWW4MM9#P-AAoPiJ1QY_H!BJ(U!KeV4G zV@AA)Eh%ujS-Y7eZ5HC<*dX#ykW3-oMIw=m&Mi2IF2*kc|7mQ$?+^T|O2Z!m)d)`Y zxSg*xH>~x9+yU`H$c^}^f>}^P5J_VswZv7p8=D$;?F=_F0eH#V60Zvdzh=* z+cdJfb6xeMePISib6zA@d(RRo^NR0O!yV>FFF+5A)Oh|q$vh5a3h}#;3123A66PV_ z9630+&j6d_;zudGopXb~=;1n!XXbhCbCa;>>-YNi1$pkwoH=vmY;$IL=E1x= z^<>h#e0}hQ;qGJQDd|^c6%f!W9#v!0;4^+~Vi;T;R4D5<`tdLoYIZkn=}o)r+|);` zU(bRxB`Jad?Yj#*$eSom)o0H_?`nf?HM%pkkQsZrkAkOv%hW~A zyM8M8bPPZCQ^6zqfd3DF_#}E@H#`?Dca^~^ej15>^f$w&gBe^x6Bm@KeGL5Q82+G- zhZjFHwbusnBTmX9lHyC$qXn(5Q8`mK{p);7F<7=4DG?aAq_SIpk zX7G>48gDS;eln!WaDBtl zgMDS)VK7`Fw_0hgey~$Mtt9`=2w5kR6HjfL)kxOOuKm3zCV!ICTekya>H=8 zvJ)gW1S+t}DfBaUTJQGxX2dIvQ2yk6+O5_`5kMI^=Cs>>?`XwiC!f)Apqyx?N>jJi z>GrYNbciRl((xil`BK%G!Fz1Z)>0M%C*&)Cd}nuJ2tyYBE$-J(p#jhHp<8D&RUP*IJCdPS-Bd_b(f)1M;M2|UlunWmEZnJV zT&dq-@{+Hdtz(xYHKHu_k>c_Pc89kGGR0ma47jXY!(k+#>gf}UQ{Q%re7c3gdbfataJ12q<3j!y$AM0Dyyi(xt51VO040v&i0IRsA}>G3pA5L1_mlhqS8<%C8{B@L zj)~AUQFLwV(gL=%?txQPEn7tpB{uli;)qU+0L zbO}K*bxfek0RA;TqN^3U6w&0Ur}%dDXTURKi4^_uq52=M3=BEnYA;wI2|M7bh&7y& zE$tf5SSbaToyg1Z^6SfL4l_{_bMBBZI^vA34pyHwi~_AFyV#7I@^7r_o%Y+TYJ>vc z|KV34J;s?OXRthnq3p$Ej+RzvmjOtE`Q{V2Pm*b$JUqOA)$HGd!{7C%L$M}+wtmJ- zfr6(JKF$8n==s)!r9{K&BnKaj=?ZOT)JVp?_e2}8bQLlo5Ev(sbKE}6cV}-5R z)whpVOZ9@v=~oUI{C~mi`!4%GPnO#0$W^rO9w&PgG-PX;pB8r)VkUd zt*b>$qzrLhSeIoj@L?oIiA-w@2C8~(KG<)gfD#E1{ggxpHdq5h90x~(VA=R(qd>;V zVErg0e(XRa)eqtmQK`jx%|PuHQ(rztL>n>5K_~X-#niikh!rn17?-OqQfAbUr0ph` zo{C~~>yR2A9@r|PhfT`Pds+R0`-7|{Zf1CYp!H=uu3!AvR^w7SbfI64*oR>mxBvJ2 z(3jx_u8pc6Ol|))6^a1|>jW4SWN2mR@EYb5fxg;LLtlKp#%7{rgCRUh&>swxX=ECK zR3u&(^>e;rYNz0iF$9fa{te6qSFhV6@~Kx2^HZq?x8XSTC2WjK8H%ay7P8P6)AF=S zU5V10Tb{mcF?F@@MpLsExBMT5?|T7Xi+g-M;Zx<0*Qb^&p0Lq3!>^2frNz|NVrrTQ z@XSxU*REHXgw1Fmxb{ah2P8@n_MvmGQVi(gR6v41@g%O5aY0oZ-4>4t4;H=~7T)Xg zzOCJ$SU-luU-tfot5W^SbzY7mh-Ks`^x)LUI^xzwQe{YLRj*t1t3yXB#>gtGfUu(# zGrngRHL&M!g946xs`n*o<;U&$C#|z=c%vt6BZRc<&16r`)KR$6ZaWh{xL%gAsZp*%jj@!X-`V^e8?W41gwx)S5ZvY5 z35u@sT)Q9`duuOEEp8eGUU_ca^jyVf z%^#BhK1-P}+A@2d(J7d!?;0bwfQ6q|Z+j}wHxG$iD-pz`{()V_i#mxSpgzE`6x@fp z|DPsSC7Y=r z%>4ONy&!X(=#}GFkv1TY?RS;Yz8=GVDAfHH!{9#i$#4rYC#Yg!-r@K z1Y%}Kp(?uDuB6ZlqFHK;!)+wL#EEzYMvSEVpdT*HQXZTp0^$$ua>J-Zy&I$OdJ+lZ z_v{fX{Aw%{kWDN3p0wVMgsKd+yfe}-{S#XKR_QFI2}Wyhrxopo+5Al~?fVPLmonqv zqJxm~t9}A^OIQ0A8r5%b=vmP|Ry`Rtz?PAne=B#^&-Fx780)ux{`z~9!kiZ*h0~hK zoi$?%gHgdGkcD8k+YgX{rkF+|g`OzvrZ(1=>064)2EIE}Ff=72P064s;a-YkrHcX( zpb6=h2$gdF+7N+@>WW_rmN^Eg;^_%o;BILwts&+*PB&V7|d@h z$b8lGpt_3?VYn7~UlW1+9Ofm|D5%Tr4PKbFfiE>YT_v+8F~c%LbUkWUGr#@R*Zt_L zamxVglw3GOO8*VR@7Mr_Vfzbuf1oFe_bQ2bf5i17$G0It8pEfG!TbpfJ`tF~wfI>5 z*5;j({h2o2Gv!xY>D?Q%eRF1;1Jm9U+k@ZqCUg{9hkoudoJVC~9J(EWMVIa5S!6Rc zkp?5(U*^H0*no&s!Zb{$O^@@W6&HG8HV0u^ec2J#ZCO?B!nZx@7EAL6mYy_`gUm-R|uDm;W$&ec_V!r!1;PGq- z!{&?BLF8yK^*m-UV>+#wxCm|$ingSup9 z4{VB=t@j_bCY^c>1A!rjh#`l-FAytx-_IEvzohIhv0%N6XNU193Sy=Q{G-e8A1(ZM zydzqMakaX80UThuKpSnY=O(XDEmCyMQNw1%^mC>AxuE+C1LJ(C&dWih-CqfMrQQ8j z7(Gq9m#O^Z?|LZ|LS=hei%%+GgmDJeR12YVXj4p^XH4^+{=CVekRV>jF4gcwW6J|F zdm6F}12O|EZVWOm-85u2T=e)2z3AI^-TnDJI zx2g*EvELn(Cd9y1MqV(Naf`qLQ!R7y;tTjzp9iSud$;a=$ii3bM@H^_cCjH*s+!&?BCX5U!|MX# zO-hpUqdBvTbNP|BDyDz1L4o#6`U3i6ql8}C=k$z}Q z`g>%Hz82T4#y#l;k)8<~@bP@5KFP|*2q(l!%fx}%b%8NxLjX1I9P_9Lwam$j{r(rH zn%rH+Hs0@I`smjGCDSdriY@UVc0kgwvAFDMGN7iPh*}QjgbAnNe_=^*g1s6S_T(N0 z@vzl2j&UiNN}JMhC&gxVLBJ8wbf9J=T>bOts9_SLs+UI zO%<#Ud(4j&qgIBtHm8`0$B zzIx{VmM?}sakpo9^JpOHhCN|;^~?zdL(nK;!jGKNXrZ3&?oiDJ>nQW`A3`ONE#Mlo znAM8A`g6+lpuRM^*D|_jGn4{VZ>{3Yr5Z2l@*FZdOrJubwqWMQaGewdEz5;2_-9GW z4AnkTl}O8;xyyIDE!425v}}8?g)?wL46=vGL=Iy#wK&o!(`1?j21LT1hKa;jyNn5X zuWm(UQUlI`Z}6V1PxdN#Vw}82$;0AgpOX8M)F9d|ol0CMeQ&t0k?YD^oKh3q|P$cp+Vme_tu6=tJbDc#Vx+l@=(d1R9c}*`O0gb^bIL#JyD&P zsAJBM5$h( z9Yc!2Z<6PX!t|hR)!p-?Nwnk{)NG2qJdd^{G=yP$Y7yyE;`Q;o8$5$n~)( z=NnDxDq!qTv%=A!=Bw^17Cp}_dB$YA1QMZ?tF~_JSg<<}4T%*DMe_KaUSLccL$9W= z>MP$j3E#Jg5H2%>USQVJPPB9{*AE6dR_&#s_d{do9Zz_M8r<5V2;X*1B^1GzbkMD6 zf}pf}tV{mNw4n-dx+8hhs}~j8icSo1SmUrfqM-`F{K&0@{%PdaZ#AM)KXU6=7`g2y z^jZtK`^ks*%B=3%Q-A3$pRlrvmbV%;$Wh|d*H{dUd$;W()6UU98dSVACLgPIl$YK zbc+FT7nU>x=_~EtQDjMBp9BEeRv3wG9hP~`sFmx!v_^~YQ7eBM8_GnxtM)MmdkJ*g z{gV@o45kCuYU4Fpn5phQThk1LJue8m%?Ktp7gsQnqZ33ZiaTg$*h+Q zqs%$*fH9l?3)byzD=Pse$}BK-_T+U>3y2$pI4W=5+8fTU#MN2}-BocAcl}?VAoJgQ z80il^?(_P@loe5+V;m#k(G%)Bc{Zsx*2vttYpk4Dk*9mV$q&V|ER)DXPb&55N$=(d zJy6>H0}VyEAs2cZB=Ra~F_8<_9U`%+FKWiU#2n93M&Z+zVyJ2i{aq-StY$PcIZsVu z9QGvbl%}y|fdr-9S;0u=ht61GM9Iawht`zq`ri+fSSv@WZPYC%byEe&aBnizM1z{h zH`N{PemxP7o_c`G&$YwL@$S&co9_dXi4FUpdr`(V+>{>b;7R%+OaBCFcI?&;R0ylv z3o6trfel=1@(w7QW0Vi#JZxX~0JRIw!cd%Z%W%H-p2sP$fddSu!tiPqQ-j`{ug#}^ zv>1HQ)h0@EH;NxP!;Wl^_rjKa+iso*N_ZYDK^qI9vl!OxZ&Q22bG>+xA*0q z2g?%l8q%qOLKO_$9Jpg8$Hhqc8$+)TS4;GI`A{rt-u3-d!C-EfAGd}56_Vp_lY_-i z|GcPOPJ?b|MJ5WhiQuU3_p9|w(pW_;Mf}J-=hhC881XWK51lExPOrT+_frX5vki9b#H!(D&8c=|C2hP7Aw>DptatxwN=TXD%h`m2f2ct8np9F8C{YSym z3CUIX^L9Gqzxp1lXR|-R!(oPSp>3o{@p+%4>L^q%OGnlC{$XEyk!(YJST>{Wqy()( zF;f4StnD^Z{XQ*`eq8;~Uz8**#;4ssZivYsr`fb7%lBJ6o)Of!oStF$emo++!WEwy zpBIBP>+yN=Mf~Ee;aZ75t4ZpI8Xf3jOj?OQx*D8La4wU}w$oUZJ`bxyF?-%xVoQ4X z{mKdX(8N@|NQF-^7U>Va>XmKB+pNO^2--EY!5?ipBY-|f3DSw)MZeHw)8r@}lIb*2chv}ak^=VP? z)l<2}Bk@t6jD=(VM3uYvZW~c$6lF^c#1t?zn-3=j_SlB@U<`XC;nOg9)(#LkZoRRR zW|hcGD{!)sn}6aD0L=cT1GHe2ADR%w5gF!&?ym4qU0-= z_Nh6vPFK#Ulu7wtm6uO_?nWNuicYu3cf|j&zey)5tTSOc*t0s@^UsHBfqUx~O+|Y7 zC1N6=PJoN4=e}TQkKWF2;3?Csd^IfFW|TVVU!HS0C5|THgFdx?nGM@>{Ej(;=`nw# zi5W|zTxEkPI=GNXqnOa>lD%TC*fpK4UPRA0_*T7W9hR*>W~>@!55B%t?pf#K$GKn* z<^C&CpP`cL;Ip-#)c#b&Q(M-FMzx=H_sx`NTBbzCZOd!C{sh>jzVuaA+_{vd`4b09 zJD*o}eY<#+!R1WliNS^Ao0%e)SohfV@ddXr&4rP+j5N$7xESuID-bPpMaQ5cc9ioY z3IF~75`OfPBjO)&-VmX~ulUyci?0yV;8@~QW@(o_L%F6D*)Xog{(d+8!%Zx)Hi(GCQ8uwQ=^C&)T|gPdciZ~vNpW3CXZ z^Hy{v=U!x$O%xbm0M{tH!QJ}`45GT}3->8)5a0X}!8E!bYk^J}Ri#mPeuz2lpliY~ z9olGh4_t1cLzaz{0xApoX^lyui`*%9$xx#hX#=}gwkuCzBQtQMkhA(iO2BDzaS;2^QRA zwr3pV*X;JM@v4hBIkZ`DlXmwMgCbNBIz@#N{!A(-d)VUMn(5((4EZC;&qnUm=Eg0K zSgl*v)41hu>v~gO*?0yR5Kj9Xa2nK+onssqbbUE=PS8*upf@(?(p>{Ja{2q+pbjp} z=5;3dRqb9HE3KmCeVw^=SmwAl5Z7kV*5*IsE8(Nh(}l*yz9Z87_!f?s)=>Q|j+jP+ zl^r*O?g&Q{?X@*g+-dDZ7rBSiAA5ctPUcI2h*97D<3N1?s#6E#o&yzsyHHe=A`z=} zIF+Oe$dD~Mm>|JkZZ-X}CI0bK|N`r!v{nK@Y4 ziaAILo5elPgFb$`w%+m?tku3wr(Fme1R-v9>d6>CjW6KPeT6iM;e054znou2UxbMD z{YIUFnjvIpP)?ZpNVv_b$cb&#$oAHlPaV-6^=JWy6;%|>nqTN{AhFBeMM4JWVe4#R zMohRx*H?=lId`pHWaGh~$TVV+$*QFue(%{4dRQjkSiEvM&3KO5*g%$W>4ML#T%}I> zNT9{dSvriz_3moDPymV{wy#0?w4nh;Ec!F=E@v$!G@q7=sT)~LFgu(+$?ZLT_~j+X z-m6M7gE7*~0pW+yjLQ{vb)7RR`8f@-SUwU|0iXczSxv*Wf7@R_Jn!gk~+|+keu#Nj@_TIl@VCFKjASj znglcMh1X+YKE+SFKcQPgrCqd)XPMFt<8NtJ?q~Ao7>g!^ zTh$AqAfl*=acRjRj-B*lNIZh8=EmV{-gFfYcE4$P>@28B)9kH^#`}N3_&|JxaWPM?3zj^`W8aPrL}C zW;a)RF@|z49TekNmWz`XBbdbr_t7FEy65x!~@}(L+t2`6!(A6J$~Squ;Xn^&{LL1$|nQ zm~=25@;+C@odPl=M*=36vKjYKs^0u(psHjdL-=Q*jGt@UfEPxAA6yM~91;BsIb+_8 zaYZ7V5}hkT^?V=>vP->B~(3?xq0zOxQrYmCh^jUzmtjgr);%sRMwppVV-`F2(q|x z4~4w&J13^)l)Z{ZUHUDFHo&~KLW$H-SBc4?BbNX$xJNe6{*`FArjFkPIB>NJS8ELV zTv)9Up!+GfZ0O9o3pspeyNf#}5pNq>UM4bTzaA;Lb|E)Q!mR--8Diek~;bppHuH1QI1V`>cwA1X%=wrA6Jn%wN4$6jQ(5NIZ$WN)y8Yditx!lVU5Qe`HRGt9%6xT^NA} z?7UP-N$L4OU!v!_sl^mz8?Iu3h;hfK)Wb206`gLkFIlF*XlFyk^wZc`+3?$xlb6=l zT~w33jcVm_M$sN(*Jsa^2NM}a3%@c7H8WEd+FKa;)ZX7xjd52{qnKJvlg6;Dd*DSe z%*#3O;jzb(e25}avsa5k3ufPy8%DjtuLO%E>SrX?@AmiAGz(xq%3zk=R*I1o!q7Ne45a1iy*TnmF8nuNSuds`fJmtk_T(2wf7~c_~0lHHbxmpl_`b6MkS^-dd64qX#VQVjSdN7IK+~6i}mC zO6F6G^`LT_yFlNtq9I1_Xe%$Q36yV+8ANFuSq1&Y)Wf17pPIa$d}tUUp8XY3J+Idh z^T>(8Inx!a?Sy{%JY$KJ$)>?W{JUvdl|@LExBV#Cj}N7A`xzR9~>)E<=P8w2m2-(HJaU<&#x&CWKla2q3F&1NmCmXq=9p-(#3`~{cxk9dXfD> zPbM$E=os1{8Ct0@o`+U0rWLhuY2+SsgH=$!^R6|{PSzR^=5|yaZwQB0f0{tiW*xbg0P&e^b*pHB{Gi00#ey{Wg4n= z))o;x1<6WVTh5#0j?N8V=Un_9dF21A1S+4Zn=ZkQJG)CBc|osfIFA4=E${qwpC84a0*bER>IzCzeLTP#uL%mE)pdYyu_|HT$b&T*L#*NL%6o9 z*GAlsU!5#JPUF`n%_}dI9`;mjcunk;`z*&t{O;!E(fGgiF8u=>3khTTGJd(A=#@74 zxxFDr1}Me0l5fmRF5ral{h&#J=R=%^`@vMg-Zb}Xea8>n&FV$QUH@V8@j#-!| zMvjecXO91dIX6OvGs-Z>>5eeL++I(p^V($NdO zjvnw0(Kc=$RJEAOt}6`Q-P8)z3kq_gg`>UZe4L(&bqcs+*HG`JuB(+~G%!SZ-uuxVex6&YGQ(5YHv2qe}donX|ynKjcj{}TTX{u#BO zQLRjqO{OdqvR3Qe-$JWKOsHooyahKqv|8`pw5heK32&FCf+M%LZQLHWpLqBx+a*ml z`N{;ZNh375QQYK(!<&qz0`@HPSMgLdY21~WqHgpCTblI?hdZS>q-K5(L0(26SvO%^?2+zM?@d%hWW6MkN6)gf;AN?Q40Dk(fkGdnj5V3kJylaFc|9QDa zelV_+g(I&sWGmKUqhZ6YQ|i(m%Kh{M;3g@k@uj0E-9V*lRI1Lo=bsdE>g$E#XMm4+ zLHK99&{GjqF6PA)?Mp+Y#gGtgPtwb+dN@8w1FJ~8T~Cmo=cKz8h9_xfY5yfY_wtzl zp=1sAjl)oLFM{-55vrEmLk@WQ2j$SH0G`f zICEO~<<`gv!kOb>uS)&F-khsIgN3-m-}%@9MaK6I^{hf2p=pee_mPyCUB$YYw3*Ub zfEH65ukBNvp(3!<3Gx1HyO&unE~&Vpombop8xY~f<`vvaiYhP}8mGMx*9NajK+$KG z0`Mcc#5Bg>0|UR&H_Oit&~EetWx8B69<9G&|{_jkT-tv1!=*4%6*vcmW!kcpdu{2%~_nG$E7~Y&1F9tRevUt3BA|WgTQaSgbt6|G8ZDr30 z-#Dv0cTOCk&o{XcxdxLU{ zUD2|=%@;>n5Q2p9(1gcF?4BU*oV7uL>c44cuLU8j8 zo1aCB`c8Nj4Z2cb4ZiNL-1U#j`?1-*z1VBLwc|pv;Tjo->w6EtB)VkDLhds|yAy=kHDq^?{Z8)D zxHLJwCf{YSA*9iDoFVM+!M||U;9lQf85Um-9Y7*~W*V&7EK{;&qQ2f8ae+)z)9i5M zvf!SMhsSWkWTl;6P-LDay5G`Qto@}A*0(HYwGCSz)7~rRN9}cm`+Q%C`~@4nWpo@E z(P8g-?% zwyDJ)yj9|pDvH%h+5K)UyMzP99>cMdP?AMcsi;Gd8P-R}RCTS;1ZdcMIz40RR8d`t zzZM^z2Lu(>qi?bF3nl6BEtb;Y1<4V;^$&?@ZoFIvwK23R*T4I4H8FU*09pauSHC|b zfXnub8I>JP>pkQ4OP2W$DjuMN*X#Z2;MrrUVicPt!Bh&0suGD2eX%~)$Pd6D%vmq7 zsJvdnBK26Uq-ihvH6P4es#C+ixXF<|J?rh`Dw?j@z0bwe>+jnr!VN*vrllE`e!f>!H3CRfkJH{a5EpknX zP4Rs6pvc%wC_zp^bk-&g6T}Lbvy6^3grVJ5EPuDQr|B(TP1t zXDl58qN_N)Xah}<;xBDIT{b=&&O*0Px_uwA$w62HdWK;|?_Ow(JzM7#ynL$M5@v*W z%n9V5n#{k$=WD3if&1I^u#1g%sAD4#c~?Jp@Ufk*tQlM2!xNS z%N56Y-Zcl_*%RKikR}w|&U!aDz6^7P5dQF3Undvp*poV~RtLTVB%N1Uob6YoTXhZ) zptexSo}@dhQswSERYytH=~JDZLqnZbCq&9mol}fHhtpHJ?>lOWG4$#GBTH11n6y5a2i}8L&CR(|^9u?3;!7Ae zfu2*eGZoSEM=Wt*iE6b?iobn2E+5{8e3(EPd%nVxLU4h+b3f_!9kr z;~8B-$F`}SIhmAfp+&X;$_@gpf=MF?w}^16u5yYhmCn)5gX^wu z4?YBLm;74tp1=D+ZMpD)LeGlo?t4s7(K5R^z7-U>lc}YfsV(l?J#U&Cp<~v^!(aYe;_pwkN==NYlU}IA+ypRR5y}4hRhtLLGZjXPMFf z_k{+Oc2%K)QG3lRB$a2Pud+RZ?1&a;o&( zBYh>m{1BEQsN@%{=lUnWrZ0b48MND~*4-OY*7`_?4<#&o8w-=byUxczQ`aWXoUxFw|ORanvxaVt0#E@dR#1(3nx9P)f4_^ z%-M#|namsT-A-G>-TveXz|9dMp^?tBeIu=*kNSK$(xDNh-TtAG#sF+j zfHS1&K*EsYBP z^DSyOCl?dEe~7F42R_sJ3enb+*&P>0BVU;_>iS}#!nSsvMU~uD+RccV)N#7nu^w2U zJ5ReTJs>7HtVwv*g4^dL^;P540#C(hv$MxV5P0)~`{J|F8bW#gQ3`;1QmMZSL6v)8 zTir6FkEi*O*y=VI4CUdnF9hnY-$nHVM8Df4;@R0Y1t1xGoTo*4-xnvk`VH6p&%`ly z?dS+Flod4O^`vI3^t7B|v{ZKwn#)Mg(`RrrlBwapt`%u0{;cl(a%!7@{wiU0jNqqOU=IyI8# ze5I>)Mt(Ff%YegE1dxW66a9&ziK0ze0$ruv>=W(Nm0n1nCw#3Jdwkc#^|d)|^_3OX z7e3)px=;DIUwcB=YVzHR|3mGKfOtcCqHRc9D<^64;Y+n1qYdeawjphOno6lUXo!|J z!6Zn?MCwDJ$@(Ab75rG{mUA^@|7nD76vM*e&%Rk5eJAe^sH##S{niSnb-&jFu+#8F8?JU z{nEeKWR-!*=B6x6ldXeR#=C6JW%WvtPwB&np3h}Vg{S32uE*T63e)Qf>fD8iH$ck~2+a0fN zVJ${yc#>YNG;mesu~2r>?vT)_3d`i)BRwNBKMZ>Ys^N0c^j+<^{pO+Qw~CT{>WS7Q zrLEB$pw-0Ybt2K1gezKnA?TZ~+mya7Xz^0agnIU*Ue?lc{S!es((a8$@0gkc zuss28Q5uo=L5bL)w95yP=K`=j0q(Ulm*ASEU{fN$_fSuMYXG(U643zVdlfY*xgqDzZ_X)d%@HvqBwvQs)}gp`D00rPWYcmF_#x zSLzFu>`A40s+6zn|B5V#kK=?o6{_>y(9rx)$DY*bR~;H!X!%U1LZub*yN85^3Zarc zskBU$@;FoCKlA*dQokyl8Y(RhmF!8S6_&>H^K>~PQL`$ma^9`ps;tpdL&C^z2?^`> zmqH?=4($^XIVE-uiB2WzLZVNJjYGnz)!Uijh>KD-_ZywZAkiuD4dS!29a-2=FWsedRp_zVvsUkQh>@dUYWk9TWah`=S`^Jy*IR>O?RtGX zv9nC%Sm_Zu&P2kT{V5b0udL`b0{p=yvuCv#QA9+c?=jZ+g7D$RjNbGOP24qEvNEct zKy|=f6@qKs-Riz{1T+z^X6YpkNyBaX!R`TQ@muruNzHWVh1GacXIb;9E6s=zt*|`c z^L~)byOBIr;C3{3XCxC6h0hwZg5p0D$(1ZR#8%y{x3b(mQM1+E9PO!~C|hxHczxB{ z8LxJHOwg_29uLQ=1$eVP+OMOB75G(m=}BHVtAlXt3F9iD6=PXjWMV8zs?G{&!%5nM z1*S@w7>7Mc*H{`?5drgbr?NzD;<2hfy(y%u@oPyXU84~`mkP?H<9_yFxfhe`U(U7U zzA=zilB&CJq=#jp9<>jyQmm&1%PvaY#G@lYUEnIB-_s{(MQBHSxYoLN#B(Pa`2+8V zxaT=C=H;6#`h~qZ;)%3Xccbb>Rk&ru3Rh4xFwssUg%M4o(2!%JdIR0QJDvMuV7EBf z#eJim0-Lg=5x@R%Pm=^9h#ALf2rCp*{i+96!ALUmfYYQ$YHoS6v)J750c>hK&T0+% z*_Tzz5068~V4cKSi-k-#{tF&RO}MsA!T~)a=dr#>$T&cEf0Qz7eFytAYBZ>`!}Z{7 zo->wc6Dv9h7Y=Df`ClKwqNQooQNN9tY*?}2v!rN=Cq+CT&2@9XvcvGC z+?=$APOHbL4hf>F(aoF|!mYS&SpJAG1zzYedX`J@V z1PgAbfVFK9*TyFV%Vrc2zvu+Y6go#Dtv7f(?B!)DM5khP)UD#jrfj*SN{*m!>QTns zY*AssO(ur0k6n~2;apuJ`ziI0*MIba1wCj8mmakxE6!&5_UrBHO|f#d zomArfAq$dBT$qV_C1k&APaR{JT=>LOCQ<)kXE+Jf>{>sM8rt5&SO#+Qq=QxZg?i%U z#V6S3PX$EvQ5lKD_KR?_cm7Ez@Iz5K_9)W#4m2LE=j6&o9QOQp5~KQFKAjyMqdxwI z*9e~Wq+k&}+dz+v4j1f9Htw4;*&l`) zMAR#KIDi(N!;J7Hk5^>{(=6AvQNUqb<&WWLptME0O}vsG#exx!wc6L!qq2R>;ON0o z)(Ap~ghL}k_yO*Rby!Oz8Aa8$IhnTa#7V5>w>#*4#i?cY<#LISIjPF)a`8!uWtJ0O zMgVg|HXjzQws!=qvaul)6 zb#G1w75eMs+#cgCbkNVAKN4|xLNRqUH5gpe#^rjKVMWBub%593L*CD)2oF4Z`~|Y0 zD&tY!66N0)@*5{T>3++lgFdFR5i=7u^fDryZ9#VXG~8H1qmRTsU+jE+AJ5K?Bs)p4 zzsze&Fgu^Ho2WZm3SgN4iHxoXlE_F( zX%1mU~2Z55qph@jG9SN4iITDdg($>JMV2q(Bd2D}s&SQJ-vt|HIiI(~3 zYoL2=gEezu@`Jhj-rAnNT@3N+zOj02?N+n&1m)|{?ro^*V4K8X0SNRoSP$8jm)YE9 z)vDa3Pf*>6oodCK`{M8VV!kJ?K3`+vCi1bW)1tDyR3=66S|KblyZYkvR53LUR_Kj{ zv6$;e9iucAh2F7hy1M&saH*Jj3ap`^y8Aa|l^9dq{b-VrZ|oB^`W31O!3|cb)6dY- zZ2jSzxPg^FkD#v$WM>99iwkdv3NMWcEvm})==|DRMDM`@xvyjsEU=i`R&eNlR!y}8 z>ur&Kc0z>X4OHDbV#ijq?jDv0On(PA+NFPR0|`Y?!jA4&%^!vPZA->d8Y712vk!c9 z#QAZ~tCV#4^X8J2|Hr*W=D2QMhRT)d1hg>WiHcG9Ev@`XczmzG106F2r~;tBLXxP= zArbpfVvCz)&=9H{Hd&6ykK3IUVRz?c#_ySTs}sn#K61w`?lbJ(+Wa*w6U1ng3j%-D{sbGs+&!B0Sf?{*lD`9kyyUG-8)bt(fSz%f>BwVrX z>t2fVEI2%fb)LsXxV!X6gJdbblB;0kvN5}~&~9;WT56D$pfaxK1Nhsuq#+3EzplU8 zJ0Dx$+PL73h<&`YrqJ%d1N`(VJD`=MfxI{lM>2{Gy85dguC$DTyX!v;h_MA$Okdwu zUpuNk_!U*FmmgO{dWC~Cul6bo(Vsi#yqE{j^{z|IVY>+aSYAuwB#tF3`IzwxGb=<=>vw%%k8=*h}cl z;%78#HOWe!2&MZZN^`c>Bxzaykz3m6W5KW%0;bWjd-`Zm0W-3*XH%pE{AVsw%|o_MOSr69Rp zRZ8}Y`M2Q3(%`CisFOc#S(|z<31~d4U;{&fbFqsAaH>uMCU zk(Y@_kydxdW*?TlSYm+xnd_-)g(}|3Oa6=OsZm;ICJVlj#=$?y(}* z*PpR+wGmnQ>QQmj?v#_Syif8n!t>Dvr%`N{a0Mf8Yzd<_NLa%C8w5*Wh13+!?~0Zv z3rw>kUtg25Y_OTylMt5w*4sieC$@~GXm80D9uB3uBuYQQ7S7?2oqOy^(K^HyIE0^W z4DuPlF#hp%>dc662})KVUN7y$JygisSZmx1_LNx#=N8x$Alp^l zS(q@stM&;L7|I3LtSLTR zkqJF^EtMAeZa!t<+u~9_wHPLX*3eAJVIsXUa&$U6D{f#~GL zwW41-IrTSQCx5={NIJP~DE)Dw^b>S)M_P7YXciSibaJ?VzS{d|+dPDCX?VHBXKA;_ z5D>S}>)?DZ5bcj`EZT$Q;FFy14=7N7T3tJ5PsF{14oV~=Ip6J>dPL6mY$1^R{dSQs zJ*>CRj4D$lYAe`cGy?RWth4AxDov zhE7%xhFFF8_>aGj!cacm3kQ5DyDy3RD2&!yDV4317}Aty_Y>Ah@oAt#WF3Bw98F}? zbL%=)PGNzkdramQL0NinqxdT2@G z`O4WW(M+;XB~WWiZbqZV9Cy~fp|-l>{>~nZpBU_ny8dE&^BeU$$Z!1XaV^ahrYt;H z*z&2ZeBsSj_)7{S65Q$r*(d{#L^NaZLiM_5SuxdXHXoDG>hELQ)O0n;N-Y#IP8Zw_ z#Muc%-u{-k)-Avv^}eUPF$?uLb5*C11d)h;5l#ZpPJm1+I%A>@sJC(82yPaijpY_8qU^V7nfvoqiHMek}=(AHyxDlbtjck*B zm73UTKi`BY!Ul&+OtEwrj|%Ry>j;3oQPBufa8H_Q?;&bPHuz*Mw=+x4#Ujr|!nc02 zI3U9Q9b9Z^x7TZk!=hUw!-;RwUP;3G)B+8>T3HKM%N(;S!UU+_=r3JeV=DXbtK6S{ zxj1Eg8uSBjSm4ZG9xTP|Xs4)b1&G`$mWu7>y(f}(bN{6@$)HEa%L2YAUYjLS@*ylp z>sQt|!Ub1{o0>DIxR)D6x9O58j)k9$X2sO0G^P^2BNXD#aLvK7@4-%ggTB_UW8?um|Bi8hczMT zm4E9c%@4ogLo=mS=_?t@7pPMk#t}6Vhm!m@Vnsh9wK-8M;CwL@rCf)B8x0xnB566}4vAaV)Ff&T|EayChL>6%a5-?Ib0qhoxJeBdRM4$3@YDA9qJZzJ7~ikoAKu@~J`KLeCtMOds??i)`EJJ^ zWVB8Kq`7J(1|R%P_AIi}-ZZCTRJG+!F$#V$m}Ym5;(EWSlXd58Gwk|^><|ktLuJM+ zKPlsDb`={&UA5#C!h_&vc^{l%V5)Zib&V*+-nDZ}I37M3@^W4(v0FGierHh7+#nMM z0!@{^b*|d6fnucbK>jvgc|RdU$^V8AH0GFE6jM(wqCRYbV_MUjzHKq}8Y}wRCSWOE z&ThLe<5lvh72&&6^=^aEq?la3Koa>Q-#<-V#)?1jNHwIK#ZnNzek7|Pc&)43lAz|1 z%J?0Ss*)F3qr32UIXn%tzKIBbqkC4Dd&B9>o}=3B#kow6zGqQ4xV2kAkv2G_3FQNx zq#+=l`hs|4JU-C0cs*`gKR8T0F;ME;100gZo%1lcx`T%cpZtz=?VvSHq1Ng%wcsls z-{Y0ANfp98Oo~bH4h2`8u(K!A)@5^;v2Nw;VsYSJh4O#l7o<05i7|b&BAwivuKECf zgKR$Rxh%)|4qU8YonT1#imA`vPm0>O=_u+YQe7c^~)xJsPWF0;x= z73Xd(?J2m;6?B2etd8h5>^s!EX@?6yv_M;Z>7Y|inK@h0_ncIxl?U1wFdVo6)(&|@ zEWhC15;Z2{G1%aEo%g8^ds+1*KA!~eAV{Oi-&W-v1G6H6LTOXa!L8wlQvySJXmgtQ z@}`9*PIcy@Ll3$Zx6A3!)n#r+Zi0w7CK?NNCcP6G%%0Mg%?A~Gdn;25+Zt9`G4=g> zyjpBIt174f_SNxtQol{i&BfGwK*e(}=8E?7MH2aa%GbQ&{8QY&FEwrF>s4~7`KG!$ zxASE|)u%TpgZunF5$T1pao(nZZfrO@zQ(jjG5aGR^80ohNlMMJZLRHkDH=TXVf4jn zLcu%pF|=>_rfmSXrt95wG)PIm+h*f|RqL z&k3|Br@`lxlX0B_iV_)-JdvWtkH7y^y2SUgCWz~aaWNt}zOPT)pBb+N;RvTGLbFB@ z8A1KDlf7d2y+c&Rzbgqq!F}dm2-z&ibjQR!x}b0SnYjS*+&2EhLdYieVfm?5tWek{ zP{G{?z~ICz6J3nBZ2uM0vR&9M^2zYaQP*Ii4Q^e%-tzJI z=Oy9)_$m8)UpZW)PbK>8lxC~si&}Zd?w-B~b(9+%^5I~qgwI}`_-ZhQ=rhrU#+E|Zhoe;m^#}KRkAx>%9N3g`N1bSF`pEfdwwRit zP=OJPeUJHmiL!bdTfZ#uVEytAzwcq^%H-+KWE)?oW3@PKcE1B zurBHCSze^#NdJW-0-JOTaa{YE8Wc2E_>ED0GwHATI0oHpao5(`npkSf$v&siU8WrS zW>UWzFrNS@TkSu(%N7`=<@cUz)r8@v{M5Of>DLTax?_*97dMzFjFrpO*lMRy&3S4l z<6RBiK7mXA?r~s`R<=ga2lHqRJ~r0K?7tMN(Sq+M<;+S-`cpb(jNYJ@?RBluDP!|l zbby7nr|wjT?B{Lzi6<^o$`U^G@nhR>8`8eu!Ouu8rP$D@^iI7M7I2njgHl#8bq8>E zs>OD-BWoOSKIf6V#Y@9NkWsSHeX8}nY$|L+qD%;#CpI;nYBVqqxJ66*VSk^FpV0& zW8{?5r(vs#VABFJF9gm30hLDTq55h1);Ps3GgdVwR_b5nN$P~5;>9L?3$+F&>k3nw zDlIAddT%M7vqm$;)I>E}isLlJEGY-S68ae*?d{(+_=nQHV_Ly4re1(~*&cs5cdse6 z)!nPeR)#NLNHOzJK!FL_*!UX9d=0!LLY`4rH_W?z%psb+ta#t-eIok>|x-L+WFQVrpwK7IYYMd%)Xh8B(Mo6MY9)#u) zQe@qdLu}M+p?*t6`>Al_?y>4$xU=uP^=-UMy*J5cnueXIOdBb$HtI_^9;>*RIy+*D zL0@vFp$G`iC1powXP#v<$Zp%a&h`)jalEvc8g(ngnL$8Nj<+&Zlq;sT5BbDuQ9quB z4=Qcd3BlPu=nC_Bx^DU1b)BN9w~xQo3zB@@etlK;pT*Uv{4UW0`YQ!38yd|ICHzlm zd^Dd~a5+C|EgR8B(tjW6nUB$`7cpKK+3}(|*&}u+Ct?1Bylk<#8AoiMaL9QDbHaEX z^S7o{>Se4{vhBQ$R?mrPW&}!Gfbh=*7!(kb6jMhFtZPY2-0>mpePmEjT68V;iUbP| z%wKDL)mTNl?(gekQcSNZ?QB-TeSQT4Qz_a{jt30XX%HbAvt@Hr1=sXodE8qj4+atA zs;N!dCzc8aNUXl1>o*297RnmM(!~7JTz0`-zN^LRL5gMghhK#}-ui>GUMNdKSy_nC z*}Lcp~7}LEvPK5eR~Cl=#J!w zSVmvHq}bpJ+tUNBL!Q5l&v>10;`)qA_x{@ardbptZ`N1@nTd*qExE$e;OGd11}N1ZLDNxHZ2L@^BgQNTI!!dKKfRnGEm| zO2dzKeC9^gk6S~^1$X5})|$-{TL}ijq(0zs{e<@?TjlVlRJB zQT5E#dNa8v~MwawB79-H(;x!K9C<_ zJShzw-dla^6m#+30e=ekhuDmcExycuHTwCPhESnqfAqBS(PP;ADvYPm*<&9Nxa-u5 z$=s-eiB@;bl4j(itM3EIJmMD{X3JNJLP+6`H9q)DF2Wv-h62T9NbTKh&%fea zV2hBt;(uXMH1< z(Vwf0!;g5qzuqf(IO=n$Z3c24OZ7|6KB337{o7kF;F zjz`cCHKW&7TyH3?3Z)Fefr#@%$v>*7M@!t{&Qc%i9QHPRp_n@H1_{%5;R$wZRhUz- z+LSd8#4$nS?~ULFchbIgh)YqyU%nxN%a{efhTE5!N{hSsak^q&hNk!)%pWwyf+r|v zh)?0GiT?=e{+nl;d1Vs;$`E?S4HZ~aZml-EM<0uJn|Mz#HDMPo5mptlQoz)MS4J80 zS!MIiP@^0j3IC>=?Lc?=p$62!vzqmjA5as3OvKN4N6WTL;U_~e%U2QnE&(126*L-4 z`~b{CUQlIY%%0iSX&e|!X3fsFmJp2!h|)Rvb zUXQifHyr4rRBKu`z7Mdz8+bsC#(`T%Ql@j>6=3S)(Y0c9*Xq&j--wgi!Q5V6JX0~9 zu-wiw@Dj7MoIYeqTgZ)rErYwo)pS9?b^_<)>cguwB&%(iU5X5427Yqpj$2on6Lz&v zkYk}aQ~h2jEM0TaeEzVi6OvW8{D`WYgEU0kg{Nml`=K$}>aCHRAzF%a)fwXfof+2T z9b*&JvK9rk(Hj6b<s9O!O>mk2RvHC(G)MR=gN|;BBN#zxjSbHX>wa&s7MD zVQvH_Y=94~em+Inp3?B6;mEzaPKjjfJ9Y(T5g%a>v0A#b!s{2fDx7mM^<_1IK4m1P zwsPArwHw6L4RVtJi7X{h>O}E@eGjrtU}=tfv`+i1;qqxS39J9lD}77#aVft_#-Lnc zbwZo-H;Z+ONvd1N`pfBc|AMkbvvZUo|J15Xg zF?AalZK*Zm79C>?Si*20KW)#aTO@X>4#V-I8J@`({1W^strz)KAB3!r+}LUq`#H5X zHkZY(n7@{SE$CS$Y8u>ApEU|(pJs3fwAEd0;m|UZuDhJGqbjRc_(&T;EpzS_X0g_f z26yQR7QCqLuTKG-6`wGMCfc8|vXvUdQMvK4e6^bx%)PyMQYI_{oz-*$swQ0zDp5G zx10vt;Ge^#RjZX6$7Rnky_KsZC5YPiTQN~=d{8%>zXxe%D70}k<8}?yWkp?9)UhcE zvOrxbP}eN#nnj&0-?1{6+96O^t6m%mbt=ty$?^*>>FSP%{PpbR6;cL%*wH5KIo6oV zkn&e|oWNz9Qx@LjcU^CL$9I1jT`8tUGYG1MHnq^EUmem5*7jcfE+q<7AzrK!Cq7*H z0YM?n`O1p`v81tb*-@c&P;wvbai%kK{5{BXPvS9$yJD(~Xb{DrPBC?fBEv4*@QCJz^`Z}c8(MDG zs50MDIn(rnAm#*^Yy)VR2S$%isDEfwfAjyX{@zjjm&~~iJFYc!{bDHn&*#M`=)eAd ztN){@{t5qE{Zpg*oBwb1Kb#Zje@Tq!`mg&xE^7b!|E>P>QTr!+Wc~O&ksa!{YqD(o zh3PL-iFtuVzG?UMt>rmVn|4KO3{6w4Pf=)ODi!OKJMS`0axGRba883Q;gR}3bUYFf zd4=PDeJ3&uk~#OBlClSj&W%b4{4DH=2&{MgTN#1h`i>{C-u)~yj6iL*x!3t}FbiEe zqgtP^Wp%*Hms`r)*IAKfWo3pGBO9hKpSB4#%W_RtAbgL_@;!Nrb$A`vKwmx3aVA4~0@cMB?r2&+sC%ir- zOgm7J(vet&GI&?2a5e%!xBSak4!VRTaOZxq6< z_NtqfdCXhhRvTsS{0aT2IfJTtLGPU#hz5OnHuXz=P?#wTe$Nady;`}emV7O_iV0IaL_yq{D9zNP_vuZjt zwVc%TB8cle1Vm2d8qdlB?h>GCa0iXr!A!F5`O{(Zf{N2XXdf&`9Yd3kSV>deR;SIb z3MBQL;~8iIBD%wstyP z-Rt&Wt|D{3DiJ09r-A2;CDQZKPK3Gt*uKzBjUv$$STlj0d5f~wt;@M^i7Myr*(6%a zp5#)`4PX8E=ie3m{?bq4{^n>hKkll3DEhsiOD7K5eY?l~oCw-C6;D)-CBpTuEKPEc z{?73F+la%e${FC8vhWFycaF!qx8a=-;hh&yHmJj}q%oUU56*$bDrv;Kgi$ZC9uS9q zgniqr@)p6&Lq@cFNT&Zk_O1mWs-pW}lvR8X)F*!QzP>RPQOpoui$X36NPB|^Ck+DOS)vkH=Jk^5iL+Y6`u*j zX#`{Eh#F=OW>NT~?|$S^lW+dKrFpF}iM*<52a!QFVoMYejb|>_e>3eAg+|Sv4t_)0 zOsiwglx8A$o%pNPyz|oDv?-_`YqQ@|O&ZUBFMI};>sjtES2M5YAB32Uc88itgUVVm zJ76+uz0w`7lrdgK86Qo}@lQt&?cT7ugOI{tQjk0sFo?m!L?jjNbQqK@CP@b-UDdq? z@hN9Mi}Gg8-_wbt23iK8SzSo}0XXvb z;H5!{-x3T1&hmKh(xBWyVQRHXQfPVuhBDn=BVK8}vMQqU&$4DxqP|8wh7x4-45ipi z$`$BTsF|_IjY<`L0)fxBsNS#G1^?iKi5!iTuAV$Hv>qLO@Q;R{%P?uflq?S-3?$IS zs02Q7G+uzYdKUilrY6%t`plHfr1sw7?8mRV6h(`YhkdYxUeSznA~HxXtf?4Fb5kx% z`N2)DkIdV%jMX2}r5WOdTmmzw>b_65QQ#lTzfoE79%rYxOCbT(JVBsIsbG3(1~HVm zB1?iPCS`6@Hc*r%<@DwkqF5U zDME{(&z>ShMZw`e)0TmzUV;?jnEBgT7O9mXD`gZXs}E($vd6@#vaI}y84f;UjzjBc zR4cUgbEUOAVU9bzi2v@B9iq=@3xuh#U{o>TpC971swrn&nd0%x{~0z;fd)Y7>O zLZx4vOXbQE_~I&R4zVUbX(Z!Uu`ZN^48jq z4XhEElxHabXB`Qzo8AH?d{%LBB*S;CK+|#n#F6ZoOLSy%g-Bx$zG}x7rEn1M`BZ{S zMSQt5;RdhhZ$xDWjTBEO6Y4sXkH~dO)Aj7sqAY}?!X`T+juY#dDGCmvY-%*@Dubxl z^$Cp+klO;KT_#eiu9@lDdzmUQh^tPe2upyx8PZ`~19uYlqQ?Q?{GckfpqjR#S0uT{ z-S9C(u4Tv+ASsVd@Bv+614WynT=tzPOcUh{d7Es^$cx8D+KqQCALw4)}DMWlJ5>|X65;4TTke_X~p}sMn z8a$=<^rxuivI@#oB{cPT)4Is+U>M3Ek`9fC4L1l^-w?4*T#Gi`UGkiP6S0QRC1!}p{tLb)?Q}ky2x;uYWOpgmiL#K-36(l*P!1^0H`zva$x}gs`^d=}A$` z%I9qSsLszoAK2hgCpqTdK!xe&NotC_ix6zZeP%poGtX(6lEc3fkkIbR7<>>*lfcq8 zg1z9P;UVH8u^Q|((E=b}xAw=K2=$E8lB-;?M$2C8Ut5GIIjV1peiI4W0F6M-?h<5) zI+SBj!=K++kB8I@=%;jj3o?PFi9t@D3t*Y?RC{G%2Bg`C#1Zb}X9r$M#L__<#L0G_ z9%|b4e9*(ERTy^=x{NlRb|2rZ^8q6w^5PJf0f3k%L9mdv!> z*=?0bglU}210YstzGQmFk~~ku2jR`WtMN^XgdhMp7~E8h+$EA%?rR)-U9XJeJ2Xt9 z(Q4W|Yvr)4U?;f1uEQ9G8H@Jp9vkJEm0VOI(NPKto@FIztuaap5ilBhBtPExgz7i8 z-*-k3&VSAPJqm$(DzWr**KYJ6V03}&Ec64dzhwP@OY{Sm89=}v=YO9-^yLhh1Z0o` z8Ui?g^UM^QY@;l+Z2}i`N4?{_5irX`bG5|q>c1Au!HS7TYlvfob31ycAdv0%J za(yF4#jMu?BXQ(9mdTNM+-GzPOS!r(poTA?phe2m>(vEY*5?!{x3zl&!jDEqlvsbX|TyJb^qj|JiiOsGY{ zl!wc{K_(O-Q>;+RB1u}6AVukW67mgx7%8tt#s3a5UG=$wWu?IcX-&*DPM&f9@Ehf0<1RIPpsBkwNkkiT^Vy#Ec?U=FF~0EDa)$6em(+B zq70>c#2alaL0XaH7w#OQoU3GAhUy2RNklJ6dGtcH!Lyqn;EIJu z8jkUK3FT3oLP>=&kqPm7NTl*Tlu5(W$P`=sx~>v0s`tbFsc5LOT)*rPY^fUC?_$ST zva;|6w=f?-mjzL)Apcrs2kNV;S_v~Iq&HXDLP=o^`Nr&ZUOM3ccP) zu1-q$#W6B6ft1?Kdk!U+dP+H+N=Ghuj zAL|ul^U8A0AvnmHYhKm#f!M=L6R6~q9Cjho0P9H zOr*~u3S?3q8_XjqqJ$-slgb*N25_-2U?nV$Iw5$xe?+}jQ<5xJy%36FeE;X_n$2bW@ zd9?!OoDmdFgUYIik_PuJK|!AF&Y9g z?+N9P+L8@(_6vDQ(*O{bjHyf{vm6M8Ems*j4f}=s29gb;-6ko$-T-yf$Xw<2K2%>x za;^THQr3X{^hSP|C|ZHig2p4Gj}xBk@Q&2Wlt?yixug^PYEOg-O)nd?5dG_J1CiQL zbs%)To+jam!$x8~9;|5?M%swOvBaHF3P;ArIBimj!9kQL=H94<5~um(66#`r^OO@{ zSd1u1(TWkR;SDqyUgl~=6^+DHSJ*v@&hmnwjD9nzwWL_uLsLU*GYp_%YaI?VfTiGl zIAVd`AyW-Ei+A)zacl|=?n#hOghyT58enmqp)(_4_ymV3DqTQVCAUaJ?lKLzy`(vE zavK>rm6q~(yI|)Nbh|#wq0L6feFgl1K0R;hJA2*b=Yu`p- zWu=nh^F{Zglux90fAqIBwdE@$jbu%(dDo36ND|2x@Nc5YU(tOzrZxGzrQz{!L|x7R z?-wj*)K7;x1g>T*h08K>smoY`*QNY*FfJvgZRRq7#0#6_CYXmi?BH|ukiGXDS5--R-U%&!_%CBun zlzGRZ7UL9_>>c}KV$n$>p8C3s^(Ozs1|Z6)5Hhx!d>PAFb;ir`{V^wI)HBbB9`&5Q z8EJKCiA4*GWczD*li?>QxH#WXFHwL)o(* zd!EUjC$eX2_WS{vZ|d{zE(i7Jjb37vgfVr>1WSN+4Dm7EN0J%>^X)#hp=Zi_KaiCJ0}yh2=@GL z61{)Hp6|2g>+Jagdp^pZce3Zn%zp>={FS+U!=9hA=iBW08hbv+o)6-=$>dv;RcDg3 z>Tn8fU2bgNq{(TMpk1eA`@YTeL8aNh*0>&B;?J2ej~_I({>fBV6=m|RHZ?(76+&Hu zhJUiD&tcMHi%h;ZGkt3@DqTXwvZe8THq;~iJ#CtK+7zzaaIwo~wGdMp)8tuP`?7pL z-U*ro`~Lup`drKIGqv)SN^Dm zy@!Wwt9*&EMB~Yd58s##4WjPl?(B~uEvFJ6XL_I)HwN1nqMDtLp{>TW^Y+K z=D^?ldM^b!NwM>fwu%)mje9iVl{~m8Id}+1N(D%Y~(B#{t9747s zOFGy765%~D%eR(hhTbsw)-}a#HdB5v9e{x#255yeZmK}(;IL35anZc%e$rxzPtP$T z9w_V4`@V&^L)v#e0Hkt31G^l+pxnp2(i$1h=Y!&V*HXOja*KJ^@Gt@|c>M+NK#ECU zeWl@6d@J{3M59#0o8UkvrGkI71HUYIZQQ*OWT{w@UrwuLcoBk1z7M=JYJZBDJyVI} zDNIy+vV^u3FHdlJFr6aa9@>ZFOqjT|so->+2 z>kadUg<=XEmGz8Y_8{Z`C#X>CmYjx|~SjvGvMow{cZ zeajQ?8{08!=qgux->9h8tFngHyD;DXYC_{DD8B_;SWW3vdr^1k5~?%pQ7_v7wqn&Z zel_RO!Vi=tur_%Mnnj2C+q`W}SHyKiR(L*CKK>gn(!visk@L5BJ1YAzj$F7=X#`L+ z9N@XBcKCjbB{7llfw!&l7`#w3d>h>T2+qniaG~)>G?$ z+zOj=R`XfCB?%#~TuREgKl(i)9=9)-eSCWeVkYii=ivjA3Z)naAKiZ?$HU}ezD7!8X9;EbA{%^$+khHf7m18s9-vao2K^<@i96``&JfTShMaXp1I03~tEuy`^9rBCQ~K#R`w5A;TAM@_m?A`}x5v-|J}yq{uAaTb_TVVdg8lI5+ag zty!(#G8M1S>h@M<@#?X@PO&E6=vZ_s08t0hBU;CJI$%aLqIK<#eOay7XC)(=*Qe~d z22oIdZK^APG%rJ4b)wI%l39lhfqp&`v_IRo4M9{;Y#zX~`HjSh!{Ypz?&YX|r z!0*uL3qMBp7i(IzDz@~;RZUIah}EW5ugC66Ma5TTF@|H`{X~R6z)c7TUXAdyUm1_U z%EE)xT2#pmUiYJsbW@LQ2RqQh!g(dn#4~+vcf5v$T{$=^TMXR>|M6EhzKD938 z2PnS_w_=*wXnZcBcKA=b5jF$h}-s@}@Spb76X@DFjkkMLIlzs_(yh@u~mj02k>XeZ!0JLCK&??|OD5rkB$ zg%JhDg}C35CDVfrW-VK<&!n+p>nvY%XAnzkdaKFsEl4pQHMXKVYgq^Gp$%n|4F!lYrglvGG!@bco21y7G=YP5J%NEMR*5;=s>0*QLt>9 zELv~?wMtow{37Ci^N2kssD7H_u7Z+{-o;Sc)Y@wd#EAPD12KAVl#k)#13by?`~?xv z){+{rzk;W(DJ~Nvcp4hpByS@uS`3Pk>;{dbT2reG)@RmXW?~uxy4|;EBwmB^BMLXm zw+p#EtWu0j>Mu8C$VWaSn*a>5_?a3+aMa_}ObB=ZFRXO8cdq2ebJepfX; zc{)&JY-*);Pb~GFF?T}%XC!Y+_iag=x6kgeR;10_Z?ju7G3VjiM}6z)h>Q*mF?XlU z`*w`g?yZaLp3&hgz61`FT5c`-~2&5cU=XS!bx(%BiQdRPYJZjy=%gYuCh@SFa$ppz7z09~g41uI=&K zmvDyN4WEa8q`Q$$DW12d9?|(ec>doAg8R(IkR0gGyldjGI4WU(EI32qxnw(of-mT80thBfLT;bcE_0 zFn*7I^$SSg2Pu0f@D!+}52D~G)hRr`h9UIrGh9oc&d?g($)^$T zZQ-=e@bS(FiK1O+cpq<67uFd*6>lrh)zB4BAQ=4t++dO$K`{s2ebLqL65*^!S#`uI zgVBYmBLc?mM?GjF(OAoNXc#fOfGioI1~kRFGbrUx9R$(jCe#M8YQGu#gb0JY4pD0b zJC7S{`oS8pmQAd(`jGrORaF`HHX~Tmy4t2z7e7IQkoL_@alxuTI~MhcB?>D}bs4C*3HZVlg9GTj6My$-$Mf$)>A+=R zNeZ=b)ehN|jQ3=$W+ksOJ&&e`mq6@XVEA6)v0V; zOaeO=!>En>z!)z^1y3O&)==RGvjNbgZpIosnuLJRi4kQybHA4` zztomlw_x(0{;v}doqu>o0J6AfO`JwFJDJA6I?6Qk6W1=@M^N{`<$xOYt7sN>2Y@#0 zvM^!M9c*gqF_!9ik3Cd6J@&HaKK9(tp5L?Q0rot|o`WY13Q*@iuf*mF92mau0jdzP^$_urSjpTVA-ek;6h+N3dwN3MeU zZ_^ISW#~_X@3D}E41u3TpyWh{5AV-TQtoJxev&ex zbNKM!k8(x}ct1*T+7a|a3XHGJ!@jH77-{nP7gOQ1UXP2*(w0Z8M6Zdu0Hfh^(&#K5 zzhM@)dO=-O*&SK|j*i|}MozEc0bSV>(I9{Q*+wO?WBB|N|E3vQ#@l}bKlcc{jz3vG z7!mj;mZ3G6lviIm1iNmk;SS{Cff0_oYOW&3kL-HG;(wi^ny~(VINi+qj zLksk-4W*Z+(iHq5C8os}D02g1UM&G#d769+YFHaq+c*00PMOTZ2gjO5dy#j-;pMZ;G;+4ocH8B!FKjGtMI4Uwpwd zSYC_)JTyNijEF;H(Bl`ih|g7~(XNQ*;|G1Caz~5!6lHFU_ypx}`22x>in9GT@Oys) zpZFW`{}%ldqFTy7`qvie=PDy;cSSh<`In;vo8qRA)~FY{{wGOU)T*WMS0cP}E7|o2 zqtKMEzlQKH;6|h7@znq81Bdi~Mamh!wp4#olxZt_o_x`x^+~Av}H`;kSjwLw-%u zR~~Krs(ejQ7T{=uKn>Bxzl4>4So;2@6#p+_>zBWZm4Eo~D+upvt*e6C_{oGz`9!%jOJX%7VZw9xRh9;WnaK73MB+-ulgt7EV> zJeCbs#bZ5csK$>LUxx5>Ua+j|fno799@wmtH6iue&&t*pEL%O^?_Wag*QX~DOus+7O0nya zA^a&)67|GG@J{7~7V){t%i-tCRvM8yh%s4tcWbl#wa-sC$tVx&nvb@9GdzkYO1UEF zCzwAT9{5;gMvM4(WyXPk7zl><`JN5^@M*oJidtN`@-Xmgf9<|mkB>EJieareZu3H623!td=23n^(uib|AgNd9&cp% z5096r{`kVU7J zA3h!~*P(UopYq07d`@%p{g{JYv|n2`y7s#=ibB)8>_j?4LG5F-`NMp2OZp@qFLuxm z`GetUzWIdk@ngJ2=?_0%rldU<>{u!B_-W}UD2rOerzj`r`mWIMm@nEFmOpCx#(k^v zqK!XANz@||fo?t#^TU`67t?*OZWPmf-kp3pt(9G?I8oy83m<;}1`+;j5#IIt!r!gN z{{uF$({$cvzc2hiHT*Lo`~(~)7N|Eo^}6z5;}xZYp2Zse@$zr{CMlYjKUUcP$Il(gJqV-@?T}Ba`2&PHG zA*3`ArU9v=f_MpMHLHx&q@T^VDI52O;}36t4~NCmdeB)x%9MINf6zW##OEq|d$%0E zrSY&>kL?gPeSE&9k&fR|rH?*+Wy)qfG9>+w4=v(lWl@WGqhfCn5BmDfJ0yOL|5}o7 z*(Qvi`iIqTu}0Pum)#N`>y=;Xn4^o|hpN)Ka=viV^6^!feENq>1*??qF;q2DjkBlYzC3Z?H>oSM^o{qP^ef~&GI*!P8` z&%d}0@KeI3-@k+KL&M{f&_ycaa0*hG_?HvDPk4MR;X8!K?`sQuWBB|f5q@KM{9>a2 zaCp3v@V@Z)T*8;PfT#G!g~zk`wvA!sr(YueZd1bYuZ)#n`1)By`0(=CznIeZ@UF1* z+1R{K`1CC$`W?dKmk_>DF9J06*T0bWwVKz2!x z?33|FyIvw5O-k3RXj@37@~U;h=RRfFTAWnscsBnPZhW$morXe- z65h4Vj_>?@J*Puh`q8;q&{1C5pb4$X4>VZl*E>V9tE*o!rN4vTi#Gid;Tyxo&-LF^ z^&HZMXB;2b!uWz@lP^5|BEq-8UnAkig%8jAw?5(VJ1BgI@OUz3RvPI_z-IMdzeMJN5 zH5C7x@bMc-DEq>PXZpv`@Zn_&KQ4TDiST{G<9D$9hsVpT{KLfW%jnD3yC5(69)>hv z+VoVtw9~am|3Q){L-jti?PV#Yw*`LH5PnLS^pHFvd<*>I`qK8W^~HvCeeKPNo>8p2Nrk1rzp(C~O8;ZF^ZmkIw%`1DJJ-xeOf zgZTeKcziwKZw-%UYgwm;ACK%H{!R%WUTO#Y(D3*jEdOEgm=Bz*FC2Y;iy4u*`oV@4 zzp@6k+eDLs_r0R&k%INx=cgV|@rC2R&mRSRQ{3HWf=)o-2E${s47Hyr;q{I6gr5_p zKRZGnV7+&=o?0{QbxwHs$okK6Jq2xg+4#$+Pfv)xxnL>!SSZO?GUo9amtD_5DXbHg z$2TnI7>Rh-G@s85cFV`sPf~&7@e{KAdlZ?Nq@s2*&U8FQfQ<`te~9j~;{aS01i@!5+RT?uOH}#6$SEkI~n& zSVJGceR%! zNjSkkO#`ai)wnGOn>Jv=OOgn8#>did??S}PTFaw1ADM)Z_(guey4L*|EhDF84BV(Y z1v>&UGF7_amJXucN~J!&8S0wxxQ<^~`VsU^oW3bQznh@H8t26l{cAv$rT&{GN2We3 z{aj@=Pe082i5I~QEUP>7^cM^IvP%Ez-$TDZrC-457X;E@Ea~hklV# z%qkO2%TxM`0`xaha|8eHDdqebq0)czrr*JzGG#TVkM$p-UlyQ$kDxzV(BA;EEd7q( zL*J>i=JcJMzB53-o1p)>m8XA_N^_1#pzcC=x?MGQo)~ELI1NGMfxAd z3*=Nt18lzj^Q9AGUqs=Ld)`kJ4K6|WEJ4_-5}u?btf%LXyO#$(TLk{Pgy*`u8u(;w zV2FjF{Gy?8j(%_ke1N{ppO#x^g5M0^uqke^lD^N6g&MEO4X1#xSU#*DfIcl z{1&?^1ysJfNKa*#+Ag;9b|59({Wt}lk&)X zV^<}19qK3aM-$y^O5O{ilvd0Jci>h7cfP@gXePh zhXgDYu)Ba?W^nu?0vZKuFJMDD$CnBy2{>J(t9~@cFBI@x0VM%n9L4cf0uB^#Uz&(l zz?lN}6mX}JO>TI;ujxpPt0yUnXFMfU5tJknaNVT{D@7^9VRoz(N6U6aHQnutC7x0(KPs z5(G>aaG=1e@q5JgTmi2X@Y)c6kDSHRvq(U@fP)15zBk9eDBx8B77Dm8f#Vwld`Q5V z0uB?fy?~p0asRgoI99;Z1Z)uTuIS1AT`u4t0eAG^?u!KMB;ftsxqFp>V+Gvajl173 zpi97X0TWK=_!GqU_OATWg$mbw3W0mlmX z^O@Xzjetu8>?`1*GdO;;fJFlC5$V5Cz}xW6p|J}KZZ0o8a9rE-5ghjKVrz;gvG6tGIbg#z9t;Bo=i2>6+Rl8`H> z2$&*ZzJO%{#;X2Bf{f#NIxY}!rhu0TxJ1Cm1bkJ%_XOM{pp?t$#0uC`z*GSz2v{PZ zOTfzoTq5Ap0CE#=cFA{K(fOiY{tbpqU{7At40!E4c z{xks-1so+{N0IN`ES?UdfC&P|3n&S=e;oI>Q^5BGd{My11iVeag#unAV1}%2%j4T5;Bo<%2sl^3nF3xQV4{GX1&kE%Kqja6k$|rXxJ?~lUfcvs}d}{^# zNWf(RUN7J*0VfG)6tJ&=@dCCIaIYxuj|5yJ;3EQFFW|)jmI*jY@L81;?P;ii-^J%} zSSVmm0V4&B8qD!up3UJJ0ecE~{~$p}zzYQIB;c3H9KUQJhh+kG7jW+Y?!HpMg#s1| zI9H_e=OpfLjewU4I99+;0`Bh5{cR9%xq#OSI9O`kjTTmBjD2l&KIyqz%?R0 z?T2vxgGD-T5$-B~pB8xads4{vav`Vt3OGx^ogx0D*{y0YEh?#SxxY30ZU}*DJU?H> z@#?qi5Z|#9r=ubS?0Bi(e+c@hK6^`2glu!mc85o{$g?aJwi0@Jt#YZuCA+L8-eQl< zVV7b-Eu+e5E%sPT?s3jtIGva zl~%XgGQ&ziY?V%zW0tjKq%3Ea%5Fzr*;!$=fTGLlbhwZXyKJ?WS?tBuN~_%i*9*N? zw};}(t8lXB1w_Wgm2qd{%i4bs?@&2pU2PA3REQU)mdRHws@@IQ7?by%iSR z{3%JU+2ymapHPx}Nm8-H?)F&h9<4i^rA#}ouTD!ziOoJku5^@G6KEYLh#fLe=a-=rJ0fgFi-N}-3(YLO^_x_i=+Hfkh*ebj(Yoa6_YkC=W%k&CTET%g=W%B5W>1j|+ zOJq}Ku57c*)bpwupxf&t&DC1cSGKsBhf0eZTC2-aVs#5enT1cVCeBFgE6*-N*lLGY zc3Y|$;l8rN4*i=*PxsoOIXg;O1bw**Z7KG6Efp2j(4sAN76EkZ30Ce^D8=$@8&c`8 zTV-h5o-zPDTylxcja*nuz@IT*<~U_E7E0mJ{%blU7Tj}0T4i!3yBVHLz$g1AK- zvoK^J_ez_)(&8yDOC`n8iv-O^UWy%+P8-<8hCd=e4>up+5L0b9L)+zaS@D!rLKx#< zIDvqO4;f39d$_@Ep~DiP*%nucd3xypD72J+y-TRSQYZNvojWn$kGmY{FPRQR$jsIF zBNzU20<^3@O~-)9YA@;MWSlP~h1=}3ISY%uB^LHI*z9m77CVU&er?$^GKzt~w^YTy z;oK^a4xZCsKGpQA`7U*;=`;#zH4*$p7T*$?H%jD4F6y6Ayu(zAxse|ke`JrwAZ!u| z#^R4)jetaqz6rq~9t?6Psp(P^-r)8->i47!jelz4qJ9^k+*+#NBn4PxcpNMFoF&&9EHL8BEtG4D`5V@MjCrFG;(c&Rk2Nfev6Arv=nh&64!ta2kn= z(f2}EA>U3sUFgGm;~(RX?54pxB!X;l&G53}yu~9s9T<{R6YM5sf?qOCH3U3bEBLTq zGhQpP+G)N4Pk2; zbA`?AF?;Q^T^80r3$oIv{s!3?hTqw_d13h-U~P$YmJLPbsdidrxIrWubZETQoq`MQxYrLWHK5q457J{1g1i zccim^I3L|GYmcG)`(NfqmgS;qsuQSVB+{kxCZZ=zfqch%p3~xTTWJIZcR15%qcrHw zH1OiUH;teq397%ZUDh;@kh7bSv6Y;oRJSxpMSMsyVD@yEk zj--AFLD=s+FBUUMNQiL3ZywH#RVK9!VY!J7-znf|i=CF~xC<8EEX6cALBoD18ZRLu_DKo zP(4wM(K z3R14iHVZRSSc@EQr^qMfjLR86ImbLIZS=T|oOI|LB@V0G-U~%n1+85gU1o7n3Ipxe z=8?M$hqhQ^tuiNKOpaNt!3)$(|*2&F?0 zg1=FWBF(}`5_u*_zd(6)!Gdz78%0HdkAGpQ^(}97T2|VG@cQkaU7k!P52|8wQmULr zv2j^Fuxb!vDjr&WQkA!a3-J4IG+z8Y?IGMaMawTY7^><8d2E=-)CI=c|8@TT9rKSa zoa@`c!e$diB)7!OSsqji$%A^)OSV+Ftd^2$OxI`!x;A&7Z3f1@R8yGYPhk0i%|u+N zswZs1;b`iY=(&m=Ub`m&nnZ7T_;5Mtxc6ayw!90}w=Om}>7B&i1kdsB%m3(lN-w_r zfx)fS2D2@h2^Tz$1%JcU=NX`{t)~=cfWLpE{vd4r(e46s<_YfJiE_HN7?Xf(NeiFk%6(O@6yq#8sB$pr5|j^rqxu?7 z9>b1T@;C%1go^z~`TPH6{r8XN^FM9RuoACyc&x;2H@9vHs-`f)o(?+>b;cOh;w3b; zueN%I%B8j$FbI+zut}d?hQ5xb#vsz6zrdIrh8}L&UB=AQr^9mHQCf-_7ZTVWiialZ zF`L8Y4FdAyf1rMo9Muszm^k&E9DW7$B`G@RDf@nbq1z&tuESd0P)hI1#e6CG68whg^J*> zF}s)pgF@mh{f-%b3&ZTUvD^D|7P2V~DO|l4yBlCHS%O-gK>LoEW*y)kV;^QbN)mb{ z_F|I}Zcj;SD&Go{nresr(Q%)gI1c;_t*8HAuj-7!CR&{?NpG&}^CP%;rso#I8@do9i(o#L$Y9 zv6h!?{5%;^fN82(pc;n?|L*OBWHp)Bl7dc>K4@QCY=NpHth80*U8+#b=ilDG_U|s= zpi?=yHVaL5239i2=-z{&o_vwj<ot?Dax%Hy`CzTg1{&e{dzFVi|7F<(y_0`X&EpdO5Qh#8i zXW0uk>{xlpy=NW#Jo3u3Bm14XzKiJFO-x@ESGilfLKi~XjoBhQf6fR7AW_Fjd z>eVL=Iq!sJQ!_hE|FqA7%ro3kQLzT)R0KR!7B!d1Usc=v-peYE32S5%+dUhlN{ z!l;gIY9mJO>@voC^|J$hy#A-*hdU{j#|_s!dByD?pEtvsx&A=Td*2@XdO+W+*L}V! z>GoZdA`4%-rmW+<4=#E4%Da|ceEvPv2XC_ZqT20|F1-8v(N%BV=-Yj$-Ko2tKY4%g z>qFkY*;4H2XTRvB4=?(p^Sx8XdLO#`zS~C{PU$yv!&y6@O*-p~aVO>OxV8Uhn@5g1 z^X8LI+0thGE!LCj(l48~egE!q>B1*(eC({nr@Vd9N$mzt@DyJ#zU!)U7oMQ3I?;9W z`=8Cb#?*b(c@L!ia)Wi~!i(OX`ShMWC#-z(o+mfYn4P-zPrFRr;|;OiH(`|Rv*ANu+E%iOld9{>67gx!Z?d)=~O zX2HPAJGEcXDy{GI*DGt653DO(`P7&3d#Atk#+G4sUbAh&$uG4(w-Ob8L-+C?o#ynK z(klNc^PM}YJ6!PEwuUclSwpXVru9RgrRRP3{HnY!ZhfP4^@UIP?u?z3a$Rrv^H|#_ z(f53sJ9K7z$2P6!t{=U~vLHG6*%h~bzx#~J;=R4Un{(Yw1$jNs?-Oy;$}zV__nVWm zH`(>@Yg6wTWL?(xm)kdOeBrb%CvTki;=fkyx&4vPzui_`a6_k?-yZVt3ti8+dR?t0 z_3cMzwAugEgJV0qU)<;GIhT!?vgDMzw~iZ|Q~TL9LoPq{q)!?vmE}+M`S9`Nf8SiZ z!k2N*xbJrC_P^jCb6s8eg>N^$``#~YPCEGIw$tzLc5&x-fBM(pE$(DT;ihhf?|wLQ z=eld%_b0Djdwb3UqwS+lees#s$34>fx<_uEdH{8^H=@UEiOIm$(=)?MR%xiAHbpD9vU!M8> z#F)bLfjtJ@8-2&1jpH7Pxc;@hd!86S?~KB|x4d|Ga8#d)p1VutzjU2<*xWfCJFGdy z*zws8?_4GW5IKkOO%>!ukm#6JC^rS7(b zS8two)q_8!PsnWb$w%Lo_5C=>w|mc!`0pP`EBG|){+g>sr@Zyf#rA#qL!KJDud^KMx4f~E5pUBAtKIq7H9GZz*07?A0{WWkfMzmChE zHa+|1l1C#4y+8Sq^E!FwzLM3hU%LgTf4lazVW+S9;Wgi22wv;8&zv-V-;0JvKdUTT zQ~ue+AKZqoUi$2P>+see-~RlY>mThr+rOp$++T0s-ge@$aUb9F#^pB-?=>+0#$R*3 z{9@|pD{i0q*4*2lPd#B-<;f>xtbXE#1H;qq?3cCvwAImd-~2Fm(}0g&yZhGrs-Ic0 z>xY?*m$V+y?Ya8~UKDrLn1?RA;pg^m9Qbf&kN2+6d~fVMm9bwI4mlwy;Xux_k;_V7 zIq%Xrmv7qo@@o?(l;nK%$lis0x1Be)|3`!OJz$JUoSi-G((SLOp7F}?%MFj8GT>$T zzQdoscU|Nn<=vjQPK~-?`||XhPTl7@_ErtL(JnvY?3rg=+vWTo8J%xUI%CR~mrp%+ z^o}|L~(fMn7JD-B-mQ_n+N;-y6>j z%^GuXYW2Y@z8iMuWqlv|u2;pqUzJYmw)WrmzBBT)?JJgyIg~N^x(5b4I(yLBSMR## zv?pIVGwqZ6KfK|JtKT2c_M`Jo_~4VPmpbwiuGsU`6C3L4HX=i@wa1xCY@@N)DOcmSvAOaz=f|KkaEFHS7b?|0j5oyC&Jw@ zjL-*7gg5r%dLFE8U@b`V=ClYQMzYgk0#o8<%VUUBjB-V=;P|ym0*``O$@J`rJ-U_L zVg~PSYelJ@^<{KNRG|?Us^uSE5RRj^6B(hnpfp(Jd|lL6~DlI9_QAG z;c`EQ1edrApTj=xQ+e9@z!EMV|LFhxWRm}Me&RnJ_y4Ey{eKxxVSfOL!@&7Rn&!l2 zKGZtMmo5Lab)06KcGuKFLsI_4^_ylBAsE6Ji7}CuY0su33)mYUQ!^hPheNKk*sDpp zU=Yv6xac;}8yq3f1Mnh;MKB<1P&b{Wk6%j~kx;F-i8CYRk>A*%;e$;Eu156Z#r zt*CH%*k+Mq$$w_wHpYrm^5|HVz-;Qn!~+sCYa zhT%_Wxlw`9#vM^gfAa1qmdNQ2ni1^JYtXss31_s$i*1=2GU#YTZ$w0^$SC}ejEaqj zh=^ZwjcF(n;1<&uuWaTFJO zVQhqrM%H4B7h@PY=m%7>W3QUo898Kfbvt0I!3e1hAGWui>fli`+DO?+jGiP+x#Pp2RkvI#-$;LNEI%vjd@feNgraKoL;n@ zBkS`>)?gJC*ws*hb9m$+j>?QSJ2+v{g+oVBe6p=nn=6LlJHM*gI#HR`hNDM#-oUBz zQBsh;GHdA+mKCkOP>Cm0jKppt#S)E!t0-zzBDWQX1W`cjtmixkqz6&j<+IF`@5Bpn%b8K09=Ad5CeZ2{v>Iibwsak^9c_jlOsRW9EC5-qmgJc-NZ@^ib2K8dIy z5fNL)&!1K8M6rL=>SU+I^7brW%Kb0huJ=FPQi7c+6%HKzNk_{{ONGB3&G~b=a6}e5 z9J#`3pW(soa@a93>l?7cff@`IW@&Bn@yPR=ci$)6i$dZf+dukME#Tqg9r|>0EY2pn zT;N6crKfT}EIoz8hFSdFU*j?kmkKz8dR$wvbfpNd@@4T19u9ml&$e31&7#ZiszxYa zKtKAM*&}*!TnA33z)8Yx>@T6*bCsmv3eM-ydXZ2g%JoBDKDLWE|B3}HI-P$TWe%qb z_j(Keo-Dr46L6e>&Y>J%GlIhm;XYEpAp#y0@t3{J)*frUMhJGiKq; zZ=ND?$`l)Wkc+Cnp)(^@i0L~eV$zTwz{vL{@X(s^ij%rd=sU;hta6lF2qSn3m)G%f zTRe@Ilk_C-FZR96p)`uSFBWpkxt6={c!|T{8I*JsU5LawMUlb3R{GB#a5U+sq+>?L?U_WqJI{z(bbM%8 z%9~9YlyIsXA$3O8fhi>81zLcmn5qGt(a5*WsO<=Y5r|(p%2v=6)h56ntlvRyM!-1Z%~?86b#ko;fr)<|%Mi z17TQ3(VDpAR@-E-f8q`4xGFVN&#WGnbf(s!6Hx0LPc@tT2-D4QpSo`LxHqr?0|p3T;k4DHpgB6r$me|64C4-|R;EZIV5XBXSh6e+VIaTCJEw(p`L>4OE^ z9Vw@7w@AcaV_zNR3pv05Cew|5ZV;ux>X9VIg>uMv-dzE1!R7@3hmQ&4pyMw$2PGZl z22GluLE0(n>=$Ui&1UTQ%QdgvZ7H=*u+l9BYQp7yBVT)qVem$QPq?c2@HGOzA`IR*j?*{!n-4Du{MIn|noJ(P zY*Co-BL0J6@J12-%B!0XUoY_Si(AGMzv=hawa=k@nmjmd--dJW`{Q042k!W(WVbb` zcQaA?yReZhE3q^Z5@;Q0$;Wh)LW^A=r%NGti%W?^@t!wDm0WDfCtlt3lo2U92SJaryy0%&dkeB%O7o`bO+OnyI6jU06WiN{DjOg znRIHYnyv{MxmjtWad7FV^T};kc1B)a+L#P^Vous5oM@UhDl0=ypO_)%kC)SO#^;+d zCde5Dqcd{zGsowQXx=}ZrC}M_x%ubIg4zhFefu;IO^3*cq&JlK_ASH{psyi=@F;SnbD~s7s22EkM>&#*`A12>|}SYRRhv( zUD|d`_JzOQiX~}0yTypLHPS(8%P9NtyD^#S!&gCJ48WH(`2`DI!uz8o45&QfnjPUW za84v9dbGPBdlg%p(3-4j#KC=&=!sD+7?P<})+#Egmh2`gA+h-F9;%drJZ6We=<-1) z$wS!C?M+?S9WIoWC^bJ8?~F2by{A{41a3zqg@6i79`X#0KEj{ggW0UsIHQNYmC zZJssATrD$kx$n`ky%-|aA8L&Zle_I6!}>i@;Ve^!Z~klx6Ib#aQmtZJ<(f5f*Ij4%!s(@ssg-Gzp;n83$m6ig%7WTJjbN-f3IR8^i)iP7*6 zOfGe#^1(1!Uj%NLQiJqU&oV&mq4iNlSVe}V=M7vuMTbLh@)Q$WHtI_ELWm2l7`y0O zy%<*|%_38K&l%`4?Q)T(T-1E1bW&3(M=a}1K9@NvaN|NM!q89kr7?yOU2Q6}cnIK$ z1$GBYhTo4Cnt7HHZ~)cBS)Di?ryjFwKaqa)*y?Tx)+SlBa91Z9xMT>*1zxZu8!%I$ zJGPF@LwM+5J!IyKwqi3ZOK?LQok+-kWXBbXW5t?)UXbfCpwG-dXyAwP=FP0daN?Rb z6cMsS!)fybI@&a$cb5SF&2%3c7DCioW)=sss4}5P;=rUNccK}!3v26URB(1w4pxLj zqhNE{Q1)H0Go@7&)u+gp%72y@xUux&XjKbFrshg?$*kcvo5=u{avDBd3zM?i?r|-T~phBv#F+K=Lsg%M@YZM#j!wT-U)5?g2_|YsAYCI=s9DO%gW0 zRhPB(Q`MbOZCSB5M+Q-{A0Hp&L%SfI!z`H?m(ImI<(ftcR*Vun0^b*$q8{6fySZ|u zb8(Mv8oNQ0!lmJc*$iODV9XTkugzwhpH@NoKW`CcbMT!Kz^)G{FvI6#jQ6c^z#f#0*O^DxGpTSl}W|4L%2IAL3_<)K-y*6h4v8xJ}16bUoz%7K+ z_j)3xkYb?!iU@nJr0_^I+FU`hnLugKl+{XLzam zM6s&1rI#DW5f!@=VScHrU^*h|i z5E>o2jDBuUHNrY!P08*?>qqu*y#n>LAMe7s+5nLb!7a(dWz0wSXBpFvBPf$Z40mL( zm32~0-8c{0!dBB02J{XYcl?I`fxV@!e4Wn{P$JY0l1=AQ9pn5R&F$Jrz>>-WwGYQ; zQ5rB$+)p)v722Z=<2T$rXjYw}lW4xEiZ!Ess-6W69xKLQkiKP(5?TFL+d6$yHGoMT z`>6?|Bw^Z_2V%EsvI{`DS2mfgQV^a3#}Eb%(jF@MInHaGXQnY^piW$-2ougiN#SnE zv(74qRqPCRZ?UpIUN8t^NzGAWd*A{a5Ls&f1O9|0TV)eF91Aqtw@2u0Yy!iJcJ z2v8r2x``{J@9*}&x_m^@NL-tsapkL}RlxLZAM~qcz(PU#YQ4ie%kE%~7`~!? zk;$^f>v7;tPFrz%t%^yu)38BPj zwOrl*3e7@2#JGsO76nv6QSk)O%aznXW##IldV>v2(oyfAAEG+XChI)0`<8@+s$#mC zqzTSLGi{mWGAOfgZ*BW_pwFU`OS~?s?}1@0 z8;jwNTlJD&T~3$=$?oD5C#&3r|guPpluyxzd$(Kgryh5cTOiQ>^nx}*v>0wsyGJq$U*Nb9e{#pI)0zS0WI43!^TXao&>aW^{5wSyMF?3ye#LC5apCStG;CSeADR!b#m*xXQp zt#r#OT}H7PRA7CPhr+NQ>*LhM`wfsUzE}hZy7EX8wVqLNxmIzUEQfgxX}?TR|6x;? z4DnadSBi0GbP{IFy=-hwD`4ZB&({TDhW4=KTu!vHGzDfBB0)gszDCs zzd*$}X7U~DM7-05l{ZPJ!Z;zabt}-{Q66`ZzitVsZfDU9%^xDZt)J)7f zj=N+-w$h+4O*xwVB$NfJ7wYf-QU1<3n*4R)_1qn>i1~9n>5y8N$7DQHd6>GBd^A~5 zi-y7fR(sH;3^8G7oF2HDl}z5KbHb!EEf)l=kYs%G(Ld)hzJ*6_<3)CDRWt_D$RVs& z3(rY{7R#6V2$4_1-c2KvxRhdz@h{MANwcsDh_Pf~pJreZmgW{K`;)1Th(Ion2Agk1G|J+r?N9WXY_ZrjzJ|P$&?kZ2AzB za_P8%lkND{4E7-3*e+~rdEhz*JBGNzR?cP-5E!-wvY7fr0(H^zp=lBAmM9u2lE{-_ z?l(e>m&v^(I*KBIq6J)*CV~4?n_P1nExV>B4>&q=cs>7-XI{R;Xu)Yju#|)6XrV3> zV{zI#8(}uLUhIK|ASQRjf-z`f<_XFlXQ&tE44BzL|50gi_J=mAjRaoQ`3lU}mf>t9u)gZ_jTao$%HvQF6m8%Upbn?@DWGC98`%* z`pV4%xbjjucl*D(r7cbIn510hJw}kg%un%1HSMm_#Y`r zk!{;|kaqMJ&>p$na{&T+T|IYcYqyRWoEubxX ziBxUcB>pW`aXg$`;myduq?!o+B8zWs3rCLRqW&4hJB1`SqD_!UF$e{PU@ZO!)(A)* z>6;J?;=v$ylA10x;Z2glotwpwc>Iz1x#^fAE0)SJLq_&3{lKSc%$$|68MQ)Kg3JbP z7St^#bX=!Y$SlcA5GIHlj9E9bFPC5sau$5%;7^u@|`T zrX;g8W0noS+|t}}`R2T2`ky;gOk1|J0tp8lNV6tg;lG7l_C9x#1~WuQ65?3!IZYcQ4`%qfp=Sd1}HY!@e%9qT&c2Dn9Buu4gw~Cf3CSBG+N4iNO0l0gd$WU-%^y@pwW> z=YAyC^v$?{?_h=AdAeMm!X)hj_NF+DO%+wHgnvT91*U zVr%VctwpOgrM9(J+tv29+FJYl?wNU#XTxq1!1ukr>-xUk>z@BSGxywcKF>3c@W1P^ z|HPvgM)V zmq*OXaI+nI?hf;^r0B^G)#bHs4X8|E5QidD!Yn^;IK&_C&l`?#b(V@CmpiGvEGc@C zS>v&9MZd-K>@tfiSGL!Gk?;Mr!{`3Ie2HFsE|183*JIb)Z?WG5i7!cH7hFI-dH69- z`HAn-;c1G_;|vHbvIw{@VRLtJc)AZKjGr; zKY2|1ll1BG*nfgPUHc=lT-o0EhW(oP*^)X8?Pj?OX1}%KG*p!C}E6y>^GtBEmItcELXNSpG3a*=kkcz z*xD?$1AFdN^RlGq#s8$oKIhW({);SEw%31=@BO(vViv|tDe}0xz{`@N*YWq*Z_41m z$Z}DpuQ&-?R+N3F?Ho!E0{f|n&lFE%^8_LmPVJR&g z^;me^Ul{#B%CpE$7asA?`}2mU4L5DrbH|34B}FebTRip&^aBZx$W9j?@z4A7hNlxZ z1y|AE_TusKi|gs|xW6*=0|}4FP8S~W&-+UWPbK3Y?72h4%aWoOo9!O^B>Dk`2ifVu zBmQ}R-tZ)G6K)GA?kw@Lr0B(_)MMX@ejwoy+3CU~{&|1i@Pw%Y3GBI>!ON1Omp1=h zul?2O$A==zmF?{xM85au@`zcIJdGfav;AI{6utOg#&WJKSGL!Gk?;MvJYv?4{}JSI2E)seqL)0a@z`%d z-vOl?`-wx5nXaD}f4skVUHbIe(TmS@BQld7yG&t7i64H1uBHEnGW~>qzQ?}QqZc>r z$Z&t&bQQVoZ!8+8)oHyBvk(ckjdj#Q%=Wufsi&aQdzJbTRQj?~4?$brQtGzfFu#Y+ z>r!goNAN+;zgdHIeCn$KpSmZ@r{0C8XZzHdP)UwYJq}$n#HW4`^r?9xed=q_m@z(e z#8{uYWq+S~4SGG_rz$FZYQIXK`gM&@?X$+G{@mbGQ_uD(|3;ts8#M89pSliuX@%M>Z>yy_07eOdV7hZ#+~G->mrT{pWQAf=$ zbJR`HbI`Trj=E-zqyBKFqiWYW>IcvX^^R(5aMbp=qq^2PYRlJg`>vyg-QlPO z-^cB}jym!_M^!`nb~x%8Xy^To`sstvBaSM10)L)z)ay{m(~kNrG`7=GbGABaBedaJ zp0_*dtrr~i*e@LQ^sA0q{Y##I<*3bXIclG`9W@Vn?-NILeD0{fLy`XwE-2vm)x3TD z>iJ=QHFlg|T>w1@tzUJl^{WZ>e)X-d`PF^U*Ejgp-@f5jpFky>{A$XjepLk>eHr%9HfX_> zesvWTZo_}5>}upd>KZ@w+OOK7Bow|Dd#K<#zluYhQ1E*6Ti}P1Q0Ok6@Aj)i2X^=4 z9%>VOfG|SIC*XY=*_}K?T~PQL{D9O}{DwNAf*)fKC7{qY%%LtQ^ekxu{Q>#_`Vjgz z6nM_Ba-q?X>{&ioLpaWbp9MemfOxT!`Sca!3H7N9JnUFB@`E=H6gBf6*Z8VhF^Xdr zI^S>2*4B$zJ8r_D;agijJ`~b&lDy<)%+{#{ejn4#v%U%~o)h<>@YKqsm*d4+u?M8& z-LRrisIr*RJm_;<-j#M+v*Og+nT=Tt~YNq*;n6=$T8Z=5evF5|R zi?n^_dAyFl?C-ofs@tql;=uluij%#T@9c)U=CzXWvIlihkNLUho?BHRFX6<$O*;-2 zH8xj{Ur?;Wvo6JKl@3p<*KE1r+mvEv_}H&9@|x>7{Y)P5OmYB!b>|^P*T>ZSt zxY>eP;beH$nYbo?8c;#H2lG5$Z})d40_t{NnJkJKo{RQn%;n?Rl%px)*ESrH9In~X zVzvV_PkEavW`zf4srP1d^KlxIUs;wa1B#4zCHp>I5(ar5->76CrZv0!S;wuDK5Q(B zS5(oYpH*v?nu(C2xD}G5pWMQ8{n^p#@@C%g#)I~3CH)#^^U9k`YU`K9R`X5RW{mdq z*D4Nu^jRavf%IwFRAAq}wEIipaJ~h}yaW~AThPlrbJPRK< z-13c}z6ZZUKINet3^}=NPtwdNHNx>_J1TpM6Y`R=5YWRNi+gk}K z?q3!Aw@`iuIk^$z&nD;h)gMuQh5^YjYi&h ztji7;9^%D#VQdLVA3y4vU|i!;FwaqW4~&lMc;n8O>!{quA#t4tn|U7LQzLMfCuxnn zgj3}D7?zcx*tM*e(R7JhWV~)2Cr8$roaFwd6V?z|Dd zmS)Fwp26>$50`pmpToCOh8CbKhXT2sw~#+clpLCKry}r>_D|A#vrTVFo1G?qTXBLI zRd^|ys~{&gc#lr&4bBY`PVom`U4}N=aLTiUcPl7mXB=hv07o6*J4}Z!ftbSAQGR3O z{XomBb~Yey_!vixv3#Gne7^7MI(C{VR%NdJt;UOqrf}GqW!y9U1wR4k%OX@2A z>$+WH(q+EzpNw*f=8u^2mhXh%*Xd{ZbeWgBwHb4ZZo2AnSz*#iw@o|Y5x+L}qRc(!uO7Z?E{p*>9Dn#dS2%|;}8cA-_UmWMxz+MIDC|M%0Odw zT&HJr&E9sxq``0rXNd`SGxmk&qg*JAxg91v zRZanXzR`^5HcCxKKUV7X!Ma}Sw6XQnx4&QQzi~`_G(33+2GoJ@$WwQkXtR;p%*Z%E z%h36?(S+d>34`ypfEoljx$Rf!u;x25z{s=xgexzPwl!y`DYH{yUjt9!acE|1x5+!T z+bqXfNjyDnPs6$}hH|FnYcpkHiZciIzEM6kiuCa1`Ba{u#Jj@8TZePA31_>=xC`Yz zEu-va?bi~AX--P{+b(=Bqx?$qDXJi8Ki!eGOqa8Yo-*Stw`w!)2BYi?Ik~My{z_AB z50&=nK;0Hk!ul17_c3VZLfN?;)QO^mDAKyp$o#=64cwf4nacCHtMa;A%L9CeIF`{3 zn~m$`SnKliEi_V|h@*2GdJ2=wGlq{0sFAdJQbQzfp1`W`S(F46$ZaR=MV%FEnMiJw=a+z2u5V0uITg&l&e@^rg3OV{}htfxPT@-Fl*XjE?3;1v}9N4tE8a~ufLv7y^zMXI{KsgCwtigDRaJTBZ>@CMN!gm(R zc~Bs?1HPgJ3dJGcDGHn|>0!#vVqM;DhHVx3Uw8+am!M#7QrrZeLJ5fyLD?WmQCyTZ zdF((L$D)WcBd{ZT>X6~i&`9otq4`5L2E#)RBJlZYT;NVrTX%MvG7-3gunj{Q3DMu1 zv@dgxq5aD{NaAYhKnQDH2aZLv0Lspd&|ejGR|F;japy39;Ee1mWt^e&*OsXo*bC1> zIR_%0p0@eRGJ~ndTTNKngzsA@-_?BWOn*twx1Gx+%$_mgqr&$L%JZ79!}Oce92sPr z{28swOPBEd3FWVvFOI(?`C$QPoEqz}?vht1ed5tN&4YIa)MeC#!aOubLql>ClC-Vg zL1`DI6Xnh%W+Rx?yMR94G+C&NQNP(R|qLpTk&ndFy%jVzMEx0`dWg3)|+o4@Z z1D6n862-c`+2gjt_`6)(o{mxp(U!Solc9P_89q<=F7D=w!Wzf#oR#nZMytj`2UTPbC)^+#8Zz+T`aNBg3|ww>-FsNfy4FKB(HnC(n=5tFGaZua&l#vigLNd z`H;Hu1V&pSckbES?Ty;Wzr0R96QjGSdXrG5Ld&2`?th})k6{(i^^$PA^IY3E={dnp zldr#kKRb6a{H7YuM0Q|MP;bS^dZMY(?Xs+hzz?0kB9A;fZ*k`@I{iv)I(gD3C0#Ms zV;A_k$6}tbMJN92_V@)fuS0fBQg)+XogoSty3r}jZYMvz^W(q6TKI1i{|_Y1<5UE~ z+ok&~Q(i`@yrX1X`4yc{QlBTmx>eVIG|Qo(xgn{_5mCA;by57kq%5HXm3kq|SrUk3 zM{*VosTlg0ghA@&W)qinh$y@OHcInGI?AnLRvN>T6*y}Vh&uxn!W#GLz`ArZRX>YQ%}D@N3e zs>{jpkEt13vwvOAJ^3{U)fCj_Y#ujbLa|d=GjUR3{)DA&JdV@xxQU>@B58danrETh z+@Q4g?cbI_NuwV@*$_v$vkm2?4wQqGYIO1g+p~{z0$Eu(69ShFnG|?x=%m1(oZLx) zRnj<`zTBkiXNcA5n)O4*77$|+ma&V1G#)hPcRC9LQ}vjHdCTCPX6zq=twQob=Bvl@ zK&k0)CZ&wyA{ZB*ic$s9wn(bCp&+x+nHi{dX4%9iY`b{=$ke=Z&Ude=3GvA?AnyHAxWI*aDR$M*ZL(&hl8~&u@BydeICle5PiIf z*YOU?YvX0cRnvv9807@bmo(v>?p*D_XXnf!uT1zPJ!{e6-%X^XH@#b3`%8`e6WH4} zGCq8eqXzj#>U7*9eD|X~0kO_Pz7(}n2q;&VIOTyrg;S`!{5xF!-=fh>9~dU$Vd}xb zx~>H}0_r_U&tYiBK{-Zt!OeIOXZ2)5D-9{F*<;RK0TOhPQOPQxfPzv~#=OVgLa-z@y+ck>sx{60O;+0T~WktuGrh#T?ae#i=b6}i5B*r$xA z%70aO-$F6rX?5j$*Lgn?9^d@|br_VJ+d=zPln^B;N{}giQAm`CC~*{qTSdJByvUP?87`kK! z&Gf)Y*+_!{FUlR+Q1i9QmKvpCETv?8;9N;JDWglXBYIr4)wD4~A0SOeqZ|M^xr#w1 zZe}|7sBsP&vyQdd@J@$i`dpO7P#u(=8{Fo`cLmROU=Wv$JQLlrGPKmp5rJlh^nmyc zAM96yebbaf9!h;_!)f8SP+o*a31h8 zbt~P#`0O1y@gd*NnVJ2{&{eq;hvsMh-U*(Y{X_qVs!@|hDtP-LmQC0?NKB!|$={`%dX9VWwlm_PGmuEe83(gAfM|lLIT_URZb)GXt@4 z_tW+HRjkP(%0HTqpq$(!wRoEJYiCFgcS#!xg}ICw>hu7VJwY9xGYKbV-5SR--pPL` zpcaap!_XW7(Kng3)HM!4OR1Z4ur6GTB6X8>m~{26MtIIbIR|1+K&%OWM2~mOy3WY% zxzBaNdke}ID3F_wemRLkKD_U=vgR-Op!*oJR=Z93e}?iFL^+ftsSwJ_2uiLD>YD>| z9p}`+*)5JUo_LOvwevjF{{0v8I*I2nGzUV==}Fs^b*&W+0!p%1!DiN;#|ct|Sp$(3 zJ3Y4CXv$Co5rwCsL?Pz+Ce6y&!NFfW*Ea3$d04+E{$7mc+fYt!cavMxhO43u6!Pg| zClXjnKQ<$Kj9$RNjUGpq82>tPTKEFW+ge6QTCa#GaZ%b(sBu*^T;rLi-A)+nd`8Oq zXE@R8ZG#`~S#KMTp7pj6>Fwn$#8~ql^t0?WFbc)i2a^O%!%A7y60Cidz zrKk+$UZw7FKFA8J%9d8~QzvKAkeZ=2xod)V=jOB8#!{g(ZN#V9qiV)X8kIj{?x3-& z_OHob#khaXxJd>12j%DI@1H+5f5e;PS51f)u9{dgX-(0jN%<4=nF{2O&mXZWG;hS~ zgQv_Gao&_gQ)fbaQw9JyCtpRSLd{b-#b2k#F74a@%gJ3u3EHe zNlm0?>6+y=E7z2+S+(97vGcT*BVIY;uPtw>asjb*AIVftyca=7PFrJfz^MATH4k(qb> zXdl{{dxs1Q-NvjZV%+EtiZTc<7+0YDV#AWzbHSHKPAtZS~IO?`lM<3Q}d@>G$Si} z$}!H&lbz$&gjbcc&Y3Z9(wzK~{BZv9`7>`?cvfK1nkADWla}N!%3qkjbk*`1D<>__ zUz)!#cjnzw^3~LtrKp=$J)x$ql6pVf?paIwF!piU2N|~?f`*>zF&VeZ{C%fsKc~*ZK8CT4((Wn_I;u3PEa;= zTS=SgS*vQtS>atM_d%==U8UR1Db6(b-8NU+s28z*1v5R@`Xvv9sW#i0&J^(jM|clp zoI_xQ_Yt&LP4C6pXT1-HdERLD;dYvRI4K9h zW7669IrT}8pU3flpNqYI>OTG`ohIV8Qr41>lC`A6y4RA{;f&)D#2?MoPdbIMd^nXVD}e#z`{n>0k;>kk(6(bb0io zK@v6@`xkyH{7;FS_>;J?=P8nIy6e5xjdmYe(qpsTm(=CcjJc-cT<=e^W>yO%Hi+$8RAL;1dTAHG_1@}oA7lC-wP-| z*L)$f4l&cY!*SF5aLxCw@O_B#55s51rPa=hBG2P*uoJ#K6n)+&w~Tfn#XcnV$D+)F zoZJG_ZrtE(5;@-Z3tu_PYRwlW|4G|+Itt6T*@VAU_%21cO7kfb&o?Oi*58eWZ;SBV zkMfA-Yc=g)8OK^|dhRrQ39P5Th4Opo?+|0e>)d(i2Rt)Il(Vdqiv#_hb2a8{h#6;& zdIrA^L^%{<2ua-}?5mtZ#4pcyb|%((OuPV1sqnI2eX<}Orpa&$)9a)lanX5f){dHC zDLfzLLgCHr`j(adUN@gPBP4vf9G0Y}*+b&~NtEpn>n)VQqE-~*6mbp;=t-f@i_IoY z-Vy!}Q2qwtUz_&tO=k}N>Gk%l#LJdJDO01i+ISs;CIm^m4j1&b%GqEN<1N2)s)?7b zQ$6(~0$bs!D5ie26D(VYOxq>$n}z3mlnWszH++*8xxqP-F}Cif2)DHP?N~1)jD>fh zxli+U8QwBysic*Z+l?lU+p#Wu8Rb`+PnkH*lVydJa+wsqPf`A(`MR#ue%Ct;V^Y#! z?2m049EN5pBx!J@AZHOw8VnCOEjkUnX%}(bosPy$yO?n&XRN)v=efL>dU&rkymAg( z=V673-xiDu??-t=>~obV|MN_}^VHvj@V$=m8_ic>_#)0kgU9BVDR9F3cHwc^ItD(A|* zY0&EOUV-Lj;UOH?8(ukkPB?z+chlJu{zqMR&!TzNy8D@NC+E{~SL}DTq~!hkSQq{` z%D;p!x6_0_;#5f)@ucCXXW^4J?m#pMy2iQ8*l#syH&yJ9LzxYc7e>yD`hcEU!yKvf z8KuqUjFPvU#>K6q;d(U4{t2n)&6hTeb=uv6<{n7W;|@XQwkBVWk$f>@MfR!mc*yLV zBw%|(#ucxl@joYLCe(Lm8wx}Ti=t?1io&9liPDO)*eSz?-ti4KlS|mGif|63CYL4i zRl!BH5BqR_W}G=c<4vRKh%YQgG36{wIkRDztn2P+Se+|jiJ`d^%FWfD7wj5>qINOs zL}9su^su2=G_*oOA$gN8^YKwi=Ht?SJcKyiemseWG`z$)j+O&H&syFN;rT7fA0f)B zX~RErD&X^;=?FY;$IZjgjD+aZ-1cal$n=zpNmv&igJSx=cJe*Nev#O(=(cYadv95< z!%FAvdNi9L$=eGAeXa8D1yT8>+pe1RB{Np>hM@!Y!XKfSaX|sjYHaEK`V7Zg zXnflx-CMiMK)dKVMd=bn(J2?nsuStn**o_o?w$KfzC8!h= zzvlw+?CYF~ImZU{GN`T7ta%lXZN4MA&!dF#)3;xa+OKMq+Hd(twcpecYQK+y@w|d8 zHDc-{{F|t<8CXepH180G)KiD5A=hLRKhE!0jOrQ@A6^y=DK(M$gE;lQbN%XNXv?!i zlHXts`H!s%sTtfp?CcGw5@CN(O(vQ8{W<6|{Vi8Z|Xrjhe#QaUI@lIoBN5d?Gh*h{|gotn%Q^ zgEtS}yeXr*B&-tt;iGfZXzWK<4N{}A8$D%sm-tQAayO1&?RydLVd1y@8quHcc^#Z3U7N3eels<3YELH`%%Yxemc=$(su?B&_dvui(i|7^8CX}5pln66Rrk!5Z= z91G9Em+-C?as>7Hilw|quy3~7m$2_k*!Pup7W6avsaa~@5j>C1B3_TF(Y{A@-uiMV zN{xZIzsw2wh6U8HsYBE-_HQJ9k{TDkSF0Q-6m(quSONQ`&(kon#63CI#;P?9FS+h-P zIRHYLLtRkhIm!?edY--s3T~$zfRa$z3$!;-@)Vy+oa$5IRX)`Lm3`Hxx=!<{ zcG?)BE~tQZwG2woE(it5{|M9uNm&pI^4@j?N|0AV!8)FyHYiRxma-h=o$kbWhJxo~ z54AxFOTi78L2Xb+tDz(qx&Rqa*F}a3XqU^NcBo^Mp(Gf(6h27Gu}~5UU4a?Yb)}(# zt8fdoLy^tU_sEso;l0CEMJ$B8~Vfa2|Kw%E?dnwMdkZYDeUvPiVvui$3T=y03 zdYm6nozrQbPNb|u;f1szVBtvtRRFf~+zz#(F9W-vPAGH|dCt4-9Z*8Bp8e+r-fNFL zs`cxRDr5#piYBsgkFJ${hIgiperH&n;esb{sVpWEtWu_k#7goS< zK6EzpL+IB~Nf&nlLF0bQJHyaZ(9oUu1>Fu6{Vt$ZLgzt0gx-OicgQp7G-xaIXDIl4 zmOY><=y~V^Xx~3D6@=D8zl78unL0ob=zi#z(0@aRy~{aV=v?Ry=p|_KpLkUaS`U2> z`YCked&C)Pg1!Si10C|`fSLnULN`FqLsQt{>Qvh2AvCC z4|PC4g+73W{)4=M7DH!2ZP2~YPoNK>5udPC1ucfwL9apohC-hP)XUJif0Aea;*1RR zTWG>(0kr}85IXeV96N$a|3lhBS^STHPKGXpo`F7vruwo}19Ug^SLg_bqgv2-e-=|} z>Kn8?z?m&5H!Dk>2(>~xp~68~YAJLX^azxMCg)_SYoQ-QNofDUS*ir8gRY03fc^sQ zKLm5=Oz1`EjD51yrlDDCEA%N;Hw?d_qTDQXI`nPmdFXTKVBRI454AutQ*e~bd$O3s z58^o8U?%DNu-P(<<1+g)J7xV@jpRM^(JTz@r^c#0wZA%mx0w#)m@Vg5lzgG^P&Hm1 zrY5K_sX}$Qnn+4bQb)3=a1<{UP3Cs*qt#SyMwq6)%sl^CHA8(x&EyXDi7B z%~o^NTs2Rfz&-5?Sm!-aEmDis5_J-H(4NdYxXaXXwL+~_r|?egsq9yNRh_0zS7&gW zdzmU%6{=EIshFx(tGRc)mi@#txmA2Ew~IG$uI?;uCU4^2_!jPpKZm>Y*R${SHFcgk zUu{rdSFP#-wNYKDE>ah(OVl^iCUW;O^-XoTxf7o&>SlF|x>bEwwX56IX7xRFyShW&slKnasJqnN>Ido`^+S$E+^0I!{p`&>s2)=P zr5;v~utokO^_Y5GJ)xddPpPL>CnagC`mx%^AmKUnyxOi_P(M*WRWGWash8BxRYJYY zUhXf{tLm5PHT5gCL%q(K*I%nQmA>D-D89Z*%^JUa*|Jc0*5p~UW{r=^r_UD0n#zql`VC+wlSg?s2OGPih7gNoDnd+zJvJ$Jf7 zuKEvi=YKf7=R%@O_R#IhS>r3pn`$eg<_lxtIV)z3=Tp6nQ9cY_+uTIX^GW+~7`F1d zx`s;mP&j7fzc$`;Ph0A1`CdRY-rN|*N5rnJtqc#+GGP4`G8 zLnnot(7cxBmd04P_U!VymYC^-i{};>OLw=-{rd7rjmzi{V^uN&ib!`Q8c%;sPQz}q zZp9q^HCtny(oUw>q_mlDs&}4etKKY90mJEPv1)rO2s9)_y zWlA`DV&7s{oWe&|Nr4Q`mM)D|FQ{*hHCC5b+BDIjF4bWzZ|38c6?Cl8SYspOhWL8! za;lmq%_yr_F zUtaX2Kwn&L^lR&yYwM%2b(OJrb8SO?KgFS1pRWGK-8{{H#80`?fYU) z@wz|8FChK<@}egP`r@+Dudj-&8&JukHO|DD8u`BYzKFPI_Q4bN%uFWcXU5hKDBcwqdqX2z5R=$q{eh?FSU=J~&gN-4EXv5br*G=!tkAOgiS1)68yG$ehC1rkQbgY}1VFJ+_%g z_8z7TW5N`U%%dp}TiV9WW8c@Xz{8fI=k#)DptDbw5p#hcjbs_3)%E*dF`z z`QO7ZHBAo7#ETiOo+|d5TK30I73i5^PR(2btQXSdr1m`nbK|XPQ!^By%wF{GpqKp2 z1+Xtq2IMNmr?h1_vlsobr^xQB{Ac!JfUZ(}%20(ed(yvyRQYKuU?1Z9<4=*EuJ&e@ z-4Az)rw95`whNlWX);`xXBD$L z+oTtE%_ijR9*3+ZH;30iY<5=Guq~_i_J9`4#UOap|s@9~}-K9qzS8e6(F7{=X=I>DHuEgSmk8FK1qyKAUXr_av6yg-(-G z15D90E>lu;w{{cA1B!8#-bQFjAN~@Ca3y=S=OBf&spA}NzwNo*EbPvUWpxc1x;5P}^z9y+JJe+nyiJFz!U8Zv% z;i-MQ%XA8)KX-kc2^oOXOeaJJ;55^DkpVdE|MbWJgk(HR(w|#O+FrDK8TS1L;OYPP zIg{NBgZEs?-kmxbfWLcu24&!~cmG7nz-90L`ILdn-jmZRR(ex3wl?0(j;KCJ5%tbe zm|UeB_V>&tYMY|Ev)e64Yopb%=E|B_m09_==P3K0L@^Yijap&mS=xEDyuK=0FXNQD z+H+%`UXcOM+VaLTWzy8ge26(&Pj5tJZqOV5<#nC|GymJ2X-enyzwLQS|7ZsK14~62 zuQOI8Kb!g)q01C>WewTJSl*1W8)MANIpj2;=`*9{K6-5Ob7kr{*PK=*bgsh%)k|Wr zs!S8)fe4V}MD>)BG-C!io0T>MSdf>t+2)w?nPG3688$;Ff|z;egy86jJ%v?9^zhql zqs+@cE103^b%!423@brtIMXipq~Xi3u9J>4?GjHK#&oMaY1sN%_(|h}vX*|YH4Sgt z<)Act=~sq&7#R;YWoklGrVVcX?mDJVvGG<@sd-JDCMmA3niH#PiPzO~20qs3eDZ&A z60~xvEy{h&!ldI!iS_?>EBAj*6RY!XeSKvOXCCRkPN{9KncdRJfr)0$Gseo-rk%v< zHbz|X0t5?4eYct>rMs6jjI;GiW1M@8ogLFH>Ec*Z6BFii%~|*I#$F6^`|xr>pDn3f zS=kG(dvyh-WJ;ah{X}c2Zz`{jEsd>hI6Kx)8!aufVUMmYKQp!)LnE4aNiE%)q|Xjn z7<5|#vyfz%ce@ar!@Dz?ywHIV2AY)R6wPz`W~?c1?lV258I(0Sq83@dfR*d0exr*w z@)Y-dxWB1K@SfnSyf-_h$82d4@B6;ySFH#1n9V;nOKo|BZyMzHm~F;v?3=jtm?bdF z!A$W@Soind4E(^XbiW?6jWc;q5kESH_m~~DfOm-#e)X=0uXHKz9si2&Z1UY%_xHeZ z^zgNh>@gd%goPv*{kQ1ZHKJ?eLf- zF;k>fx?eTnlwY&O%Prn+@kNX8SWH^{+~SyJ#{C3~Gb}E$SYh#Oi)|KfxA>^VS1o>I zF?YF-XixG>}7SFSIgGD!vpB-=9zishFi=7s?SiH{SB^J-M`1X9`{$`6SEsnAH z(h0`?5{uIrxQN9%}zg+wL_BiwWu*EAauCX}H;!ul^9%A@!vbfITN{iDi4z>7R zf#L78xX|K14>so8EZ$^slEu#sGWIW7EVKB&4gU=m-?QoDmIFoKDZlMU8obita*O#E z-h{ro@Q~K#mg+-Y_Y@Q zHj6tgD%-B?Z*jWCv~3Z=p)nQ*MKxW?i}i#J)k*Wy--uUUNGqKX(fL5mYC z&ak-DVwJ^4iKv6w_AL};``3}t1OPP_=ls7{gx>PYb=hp_?OAX{4tA{Sv<|+ z1{=$fo$>YFudnqK@mRZao3P-A>(U|WLqkW809nDsDmz-y=}g1NI|mHId^>zwk& z`tT_zyfkZe%um(FDpo9OXlbmBEs~2P*DxQ zs+W^xvtcj%lXQD6FiRf=n_UyDJaci=YJI@Keal!b?l@IX!R^tQUk;pGPL8Noz5Mga zYwO&77MoLS%9P9rEih|(dl{2o49qVbu3uSRcO9v?sVUajoH>9Y?dtX2b+y8ryo)Wb zT^pM#n`A7`Nij07JgXhMEAIeR8RoU#W)C^#)I z)q{^L&C;u4y_tF)xtJp708jPX0}H*GXJg41JZ5oNf>STNwioWrp^r(l;JztZ*;2)w zJ)SqwyEhY?jD5T%y*1yX-TM|oqHaCc?ZfXKdz>w7mR;Qhl@&;rQAb|=&E+&l6?bJ% z$pGW#{cblVS$aE?ggc2?O-?y06gB4~O;8p~yCLbj4}77_d<$H@!2{s)zWeY5{8Rgu zE3Ip8*u!kNQ4SGM0IosOh-j4!X! zS0Z%Z5w(}itQ4w*=nJNp>7?^3)6pGuSUT=ZS3+^GSo(N)&VP9K5%=!w1^1Zs z>l;4&4_t+$LK%GC|Di>@MdlrB{`fPcE60yH#|P{*K$C(_iSrrak95&5l`0r||~u-fY?V;MhRCk+b=Y z=ESHS5$es66)E;Bm+O?Yl!hSjTa-!qt$k=OkBbm9~7WpXAi#9SDd z)nvxd8JN7oZ8#a=SpMC+a7{vdpoGPAu-&I*b@cTmEF2c7Gcj=A1T2*p4uD?|aq!@j z-QhRm9DUN{OzD9J=nu)EUncZ8S4a70$tw$H-}QvZisgC(wkJNzWRRVjmEkFRo32tq z+)6J=Ql48AQ3vTkLkHvLwTb&p^BL3`_$>?f1m>iW*(Pyr%B^TFw3f$3=;+{5K zf;)ipu+u9=u`8;otK&w@S!nejF3O3YN{%_pDJwZBZ)Q7Xb_Ie^?s6moSvscV1US`U9kG)Q+mDX67ggkglGVn@;EM75tg(?0VhzmS@ml`lb&EoPv5^4u~_k@9VD!N3b8noN`B))#5$HUO4Q$0ta zkHV!r`>1JAU4caIaf>Bx3JXMOWb)!8kU}^hfZ`U zRlm-mgxi<0w4vJu*ooX5Y3<#&s3nZ=97H>_iv*E9?{M12RGSmrH?h0#_|mh*6Z+yb z&86nM=r62En23eb30`UKBuyj)5#gts=*3K8EyVhtY2zmDjy;4pk&B*d!+H`2hH1k| zwM_Qj#8LN*f-~K_TRrmhbz?FUkimeNJ>NY5SE=3`cXu?*DsK|sTycH)tY@NLUOhB8 z|C(7f4Ruvb?x9^dsN1Iy-RPX>x=K0lJ21EVk0Y)7@WM<}@9IBAJmq9d|J*6Lbk-zF zYM)^+eYdRSxhov;fx1cY>4PhzRkHP0a>=xR%k1ej@5`B{Ecfu3E|b#tb0dRt-D7EA znCyOM%q;*bQn<`m^Iwbo$ z-ziP=YAJ<9PS#4FaKlbHQd?YCTi#U60pImawQ>NjzEW0(mod;;Ydb(`c1us0n!a!22>%8Oi2GE|$xLq<8g&nAnd+dF5)%Nv_wCiW*a>J0$Ti^p~wAjZK( zrX?#*ipv#74TeoOn!Uq93Tbn4YVQVxl3z{%Ded}sB8ReLy(ShaXBa7+jzjmH(tE5Q z42N;+CYu{V8Z*M^S!(HyUy5J5UIm@bL(?D_H8Rb-$Iy#!vfg`e4dzC0SO%5~6BVcxZ?G|{QtQhya2AE=F{g-Ni^U`UOeuEDt{nEv~J9Xa7B6ViR zwLx`}=gGmY`xD-x`E>^giT2N#o$%TqX4l+3GZC~uxT%|i|6-*LvzRDW(ntH|8JfgZdMOvC_X4CG?Zm6%WUEN}C(c^0$*0v$Ot0mt7F}6$f z5=pFaUU_p#Z9NT=%F1Ljx4ug8y%6JWMZH;E@w(inm&=vqjeKjQmp`^ACnbn=WAn7B z(aCmMMIDsIOue$aWpxej^>p7ttDDtR&BYxzIxjcxKr#=a|M(^GF+iy{c?nEiYVnlx8 z-?iMWsn6>0bBPjnQn`h*U&m5SOU-V=kgQqFwY%k1?+bL#OZq#JX7=vB!TE(MX}NYI zh^hWC{OPChGd`8mCfCv_?UqNnl^ANqq^7;=HwoQV){bP2tNs#0cc^-PL!NghVgFIH z#{w_$mXHDY<(85;l+Tlzs@HWV%Ya<#VnWFr%C6irH$eyNQsxiLVD*wxjr5d^PIupb zkzzK0KjuUfy-)1J0ZN7l)nk|*jEGX6LMAn7XxaXgISauGAr7VgpJ*=G0a;Gs#kX2Qsq0WPNi?ucAhmvHqa=N~G7U zxXBIi;nuv#@*ULu+{1J5E@r%ItHs_<%B`rct!$`@MP={Rm3o@Tytrv@y)ek$r6kho z6uTA6=S_EARqtZs`g1>dxQH#YXmOK(dZg!hMZL`OVpZDw!5(vY)@G0Pn0a}%-HF)g zbtnYSX=vfbKF$*nisjr)5Q{E}ouejpKbdqqtoxZ1U(g&|s}AorA=vxH3(zEx@Yv`5#wPcql#3zybzPOMJ6a8qmw=Ce8GSJP0XJ-deZ zxR*70jQOf3eC6?`78|P>o@as;X{eR+KI~4Wn#%$@UkB}R3opwSeAXTmKA-gz2WtRoia%^pNmBz_vReg^sVs!yh-{>)O3zZs%nMty!^{tH%EYMe!7=wZy z4qdwU?RnOPT&Z(<U$0*^^2b*N}sno3AcC*V{nyAy^UCiB5e7_q?eS4%*Z`WB#-V}#-y57>h(i)K&ojIg+=A1nynVIe-Nt04OLPpi>s<^ zaZvy2F*WODQND7j2Q%u39co#-$A=9(@mKRadutGZmX;>xpH zY8zc+wGv;tYh#Hd-mwxw=@~`gYs2&xtftuVs?Gje_eV<_>SGjvmP%squ{CL4;2x>c z&AhhD7YW_HkQShJX#RQSb!^gE5!uqmEpKe;HaJ;ACVK-WAe`8w{c$&Etl72jxW(7D zbJ~Lv;kE6Y@$p(pe5A2~tv;>`(?*53rTj$Mp^j7A?E34cxD-?Ane|83OcSmCZOu#c z=f;fTzE%Q}<|-RwHo`TN`e@ybwq|UFS(CBWgz-s3qcuFpnw@ku-+f~=XIKxjrb{_I zCU0`c>K7+Rttsna^a954C~GE>GbUfQCOwDnSP8af<{BvbZjmzU>_^)r-7D4Xx`w8h z*~_Dz&6Q4JHWh(Odg_+(l~pNv>NWqO80Bjh^Yi_3T}G|G+N_tQbmQ^b`q^?4P95O0 ztzs|1Gqb46W=URk1!yO_pZwXfT0gI$(VRgrMUsZ8Fxw_( z6o5)<7@OMz8S4=MVbVgAU8D5 z<1%Z)ak8C?Wlx|PHtkHQs({|2h)}Mib>jkOodbA&#I6z+RH>sp zrl$G8Q2G4$W(u6VQ&ikU1oYcQYLbu5N;&*%nC%XsbZt%Q#JNkC%w2T!6tjmjO-18G zfMQg~m&v0X0pcP(x6$W>Btsm&SKiy)E+g|5b7RvitrLd;8uW#KEZKARySGif3|TnQ zxHh(yi;5KAG#A}kM$0p`PJB8}baFQi=jzbu_SvVRv)8R7B2+*(VDdI%RE%{)X^K?i zRK@AaXmsiFIqUJOwsN*CVIP2r?gi^Tdm(&{J&h@0Hv^~~ksWP{$0}>9`CzmuCWQDf z%q4w#V(Tua*HKuNS~+{^iY3bzES}pveS-~ z!WiX=x_pv~o>Q*hLsP75RmN)Ts)=2-o+(|TnwlGR${nuc%oewG@db7?>`y4Z!j7(9 ztXC^Uo07+O>iKk5U0r?boEpAJ&ozg-mz;(BSWT1_9|?q<$t49z2EF|ii?WYJsz)6I`Kqk{+wpSPppAc%QVv{GvnWmfh!Byx;_yY-!>4yQz59U~Fi3 z8NMy*mVmVI}UtgZ>%b$8?igMqs+Z(%SEN;^IQ`NAxTt0~FO~7<~ zZnsE(qfZw0CwJHF`1JlWp3vj>)DV~$rVB^f>q>2cq~m8c-q+VDwaaww7zDKJ`kBJa z;gWLgo=1z8SIP*qyCCf%Uk(Sl_v7}I2HAkj7#=gw3HNZrxrAU zqAApUT(;cftpsuxIngEBmtD+7esp0d6zX-}YO?-ij_>&K#RGAa;Dd9J3Ps?Ydm+6-^YQ40eI0*`1{!t+%?2~$&;-u5(Kr6J2v_WbAR&t zkergfeaY!}b242yCSP#pImcphF}*+01r1FV;rX5rrQds z-L8o7d>V)8{9&)8no&K2;a!5Cw*0Z=-17ydlnVsX)JxJO&G*F8$>j{UIr^*Hz0Cb& zx(_#{lb7KOD*6C(I{tyYqL|Jd?Oleeo-^c=zMsK%Ha)Fc=JJy6j%vf5KD}uXJ~2e)&bwfDpvi*PD(-a6Q zRSn(qR6xbSA48+=3aA7)bGTCZ=;Pp>(0KG6U|@t&lhB92I5ZV~E7%Ur6do`E%|YJ< zjvuMiBH;&5hL)o*18;*)6MpdbP!;+lc*R((5SmH z2S3dtoaoj5tYJaT=*z(GLg%6H0RIYIBs>SOmjhjnJ_23^U5CC6yc^o_ZTP{TSbYL~ z&+3z4UOsVyrvRJ*34a)@wdQef*nvuQZ4Ia(xDNXJ_JC>!UxMC+KLLJj_39v{iXbr$ zfvX{r9|ym0^&Q|RR<90L>I7)!cL+atlhwC_FGE!~ljh*a0;4YgPl7~d1pK($u*T}+U>77Z1rI#Z<`=jel6))!?}a82h7OR+ZnR#o0Fv@7xCGK^2yU?YR`6b{ z?*QMn`Yv$aqii~W$60+CY=PSDqile`haRO~B*C*nv>jp&J_kv>HF4xeJ`O92=^MyZExp=^L3O(!kTbJ#$QgCxyE;8LqE1KT0-OYjLu>Rl(eLkN8r z_=(lC*sm5q%cZV?w_AM&nEe%F9t5kbJ`O%-^$GCinUqV)VLLeIIPQ<5JV(I4TRocu z>PSe36TBD_zuLejt$8Q-fz>C$xwC9|!1Ez-+Y0{DnkT^{!X_>u@L{O^R?-&yxP%2g z%0_55`3rU7HUf^CYrxDc4!lE z?*OY7kWW&7!PyH5&-3trhn@(Jga`Z;Bx%(JZdgPdpCw-4(~!im6O1l4JZ0ckNa{-h zJbsBO!(s3pNZcmDBTpg^q@4$Ef#Uep4t`+uN$|LcF%N@RSbZD#meqHGhn{TH6TAwN zIJSX*wfZD@(o!3D@H$As+z$T7nyY1mAJS(x!t8WE=WA$C&tQCeQ z3_c4@mAHc|RvLX7_3dj{o$HYX0ARYn^}-_Z(wy@Iku zd1wVMsG{G;ybZj%nsEkgMLYOt4fSUm;Q@E7p}&Bq3w*zhcwwFdM>S9es9y!(sI!#n zpv)J5KY~W#Zzp(aqfIODQ>#}^l*tytMwy9#S3o*{!JT4`+ay?Ww$VqxP3I^Qm~evI z*BN~R%sJQSgWz;X%0L+042jGRaKzV)z5v{4^=;rsR*NWPlrV!ASbZD#vekEi2elfW z0`Ls0F9UD2`gZVbtM392y1?>-XIOn1c&pX7gO5Tx@V67py-?}24T3WvorYiqBs_8O z8f)GGF1$#o`PAzO`1m&%AJEQpg8#kA8BmyhL2%4vHay@Hkd*mO@T=c6ZAcmT z3M6ey=yLK75`70a{0gHFgCDiwp8BP(X6y;gfu{gG4$@%;Z@tFww1Y#gWjMrmAP9aF zl5*7n#;-Hxf_GnU_&dOF+(ymd4E9d)N2eD8aVtEKLM1-FwXghy~bB;jub54?kRT7^od11e9syTuG5Rv_O{?z$U%1PdC&24}fJ~|5VC;v)hkT5K+o7(9$uIEe zdrjJg!Jk4A$w%-%_i@KMV^q~aI|Gd;j-6n|{j?kC<6yx9jA76V-UQ_n-*)i)2PtEe zlUDHbhm5`qocUiy9|nI1&7@sQg4aHbEcn~O3m-8&ZQ#+5A`_l4c<7Iec?cZ#n9&Ep z4bYqT+Y0_1lC(;I!ymWp09ax5aq!!a@U(-kLXzfP;Gs{L@)-hGLBdl8-fGR;!4DuQ zV@bg$joT1d56O5h4&H9fJHWqLbM+K+E=b&#f!9IeZ#(!lByPLFL!Y)`0MCbnzZHDS znsCbvgEzlK-n>9wgTGD~eHVD_%O+pK z;B(MBkCXo3h*xa61y@;p8MxW%JHTlb*Y)yKhoeocGz0PR0`;+y0h zVTgcNzGd`n;N!nxu0&pRg4cD?&OSxHfQRg4UV}aao&@PK4>nkR9K6cv+ray*z61QZ z)hEEeT744S_jg845S(N65%40bZv(4;Pk({GaqxCXmpAY?kmO4jc*Gw}e8XTVB>7tg zsz2I32rPl5ZbrastiB!m1tk4u7r5+QBc}}fDRepMlK{W;C*w8*e#7e9z&~1j61@98 z)*l|HE`Yy*B%EE~zJIo12CLqu$)~T6gBzel_}dEJ0!f(L!H=QSs4waR`f*6gZ5X@- z`iy?C9sD;m6}Rdy^c|4s%fKhCz7ss~ug1IpJm2bD!MCiw3*3-I=KX{LdI^t&qek0e)ol>Lbz^l5h&XCVJLky1ZLP7ZyUo_F zxV6@)y`W<2*w$*R-OgIuxV4wr))se7x6!s@Ya0jag{=tx_xE|ehY1l~+ui@`_51x^ zOZ)ZxIk)FL_jAr?&Sa{bE`<+I@tMcvz6I{B_L&TI!XFJaRfN=M7;YsVMJM6UR1d+{ zX>NTK-ljSQr`J-SJzf@jQd`^4by2zlnp!Lw$%ed6#ZggqCWHp`_GhMh!?d$}mQm5}!Bhrd!i z1m`!n^%3|<)hYNp)j2q!QR@nRjIhsT@O9N@w$H2}WPKdoN67r2h6hyhM`O(j!tMi~ zB_z*7uxqYcABCSKmeT)Iu%U^5n4z!2zagZbmmcjieYcN2H8eh5zJbm_<7W1`7(2L73l^rKxq zb0uN@J@~Nd3>+fT>?5pN;Kpeh{)Uh|=Y+@j-2E85oUmnq8N!wq&RwW|0DeSuKb&{0 zYnKcRALou0UamR`Q-tJGm{Bb(Jf1w!A4*{dAs%29mM>!Lus#fLAf!(w;fq9%y;XZF*dqp*3Y&paabVOck22y+^u>D zzD>wJ=1j&ZAvz2%Cgi*c_z;mpXQ1yar%U0Ps^jo;s?%_G3*w&-_!;cV>PCxvq>LECOwX54w_*3F4_8Eeo?seDA6#NMx`OL!NHS_`YDTT)n zvOWqgQ=NdHR-J-#;^Y&*Hv+FFWd9`m9D~V8w;HKU_)3K5=-vu1~?A>G~m9 zagpv1&r=ZlPf;C%H>vK2-z98*;OnZ*B|dX9 zA?d{6cZeA0%E0PN-JBMN8wfcz0W*ZOi?H-EJukeAh;q-Lg-fqwoa`b^cpH&p{}kLu zJcTyfnJb7#B~AEEA|q+Sy@d1`a~0QpLe3R|rxTKg82k`nuY<7dYPU}m2Cw1Rhv{$d z`fGI`xa&HXhcwI*wjN;Nhcy526xA_!lj?r>UBd1U3$NGx;VG(P@Fvy$@VkWFA7*Y~ zPUe~-T%L6IGjX_+kaDHri>h;Q)`#8t2)syj0)C#5G}G{p#0u^oaM|@au%!4`kpk3E6)LR(#Cu6Ncxhj>Fqkr{K>B zyFaYBN%x1ns^jo>)fxEreYDYc=npV_v)-e^4-q-Z4}6fY^#fm2orAMJ?$$@(MXD3< z^MoxUyz6h!^urXKd<$cT@e+m)5{>9I%o4J{u<%xwW+}Xv*w0uwo7@ns(4sTbTfSq&f>f^-1PF?rBr-F+%pwz~@vC!M9YK+o%&l_6ft~gnaHChdWfK;BM7Ju<}!G zpD4y)i&cIhyn@>}A!sZ!vsgA-Gs^jn~)k%1V>J)rbb;i=A zGlbr!>&=~{PuTQfm+B~7p*jw)Qk{f%s7}F0RcGLS)#fvD4RL8k;dQE$@Nv}{_?+q? zcu=*ui~JBaKX9SyC|spF4nMBCAAVDH2L4WU4o?28rUOq^9fMb^PQp7>r{JTiGeXrv zaGz>(H+ds$-r!=@F?f^ee)xUWSr|&Wb}xmEgd7`zi&e+qM%4*;v+90$uj({>N_7^# zq&f$SKIhUbg^j8saIxwb+^9MMZ&uw8?^T_KPpQtrmsICq(LI_zoUJ+nm#B`x8K2j3 z!6m9=@K)9R@Nv}{_?+q?_?Bw(1=@?Sbqf~|GG3x^m9CG&>r^MIA%5 zbw7MTbq~EAQtX zQl4#LKOx7a;13A7wr1gJyWBl_3|>jdGiDOrPuS}(ocDlxK8nDz2w5M8Un2IQ({NaI z4qpCcw|^4)zT$K#{DSHXr}>?7n{{qSE1@ei|b{?}X?Bk(N3mK}aXbw7NGkb20$ z4}RV01l*-M4S!0=u|u%&VNDaRCFHZ$1WfCCVNSJh_BUL9A}~S7xrAxe!j}oja}Gux z(ezB|eq|*-{C*;@+{DbNotp283ABLx@j=`%{C*cnW@hY>h?psdx z!(S3PDKD&lO!EV`s7}HsRA=E!#Ix+5gGG-!9f9i!Nhbks*7f~xm+CZpN_7^#P00Hd zrQfEk|HQKl_fZM>MMBa^!z>}^66RD3+rQ)1$Kh3~Q}9vM8Mud#cFe)bjMHJ*tvUuD zBjg@C0}K9H%K}eV9fOysPQbra-4DO4It_oMItL$qLhBR0O-TODckxdM(P4NQA$d!{ zZxGTh8CdlsUYU5I@FK!`%kUejGw@f096JQ3e9yI47#>f^`Y3#mm_fPH@I}J57wrDN zOEU(ugzO`nKHzi&#t7+S!W*A*o_Rl9_b-eGo;MS){RhsMh{C&w9Oq5JCkWeK@TjL< zeoEoj2su{Ov$n#MQUPefo3HSyf-Vr}{YK|fHu|5jFLx^{sfek-(`$T@m{Wc-T#$ige@RzEG zVCWgQPbutF9fj*uC*bXb%+)FQgs#uReX7mRIWHmSiogpAslz0!`32AI(jVZrRA=D% zyWRRY{IhEFES?V``={VB|K@ZQUQ5{P1$<0(2EL`*{5$t@gzOWAx2aCS=Tr~Dsn5B6 z!f=)9IJ{GJ3O=WL2u>Z;{oyLrarhb4Dfpb~Az1xO-5;(}9fxLEDwSGqr3r8*ApRGos)sUCt;hjf3qN_8CGsX7IZdEWJzH2f(cZ8rqB zzTmFuNjT}(PKV*wRA=DXztQVHyq%C^Q}7q6hhW8T-TE**Pjwt#r8)`!?nUNSo}>HW zE<(~t!-;#`*bKvusP2casx~iiPeDlEi^8i^C*h;2Gw_Vxx$z!@zfwH}xBT9%Pr|Pe zl1>Jm^|G61;xI?ZxDzJ-;Iy#xk4_6ys)dnZw_cc0EgT|j-NNiETnpuS4Q6sq3m@F; zbQ(VJs%rS9KdFXadQCNa;C0pTOK+%#|4LlPoG=8#Z)zRFxN2caweTrI%ASRH?Q?CK zfR|JkLPfOinK-Qn}9bFk_yZFjgrbsTEz&y15QWahlzcx?}z`cdI;Wg(9Nr9xXt+OoSB5X2)T}@ zVUy2q=ivxk6ZD&Y{IodyM9ArW_?=>>GjMpU-^97^&B3m5ep`nLxSM#C^^x&@vx1O3 zC*kv|OONuK;u7}fno-M4D<~t;B-cB5S*71p$$cZNt@4}e z@PESaC&VA{pt7)hvO6{mFI1g?53A0=KdLUB;x}I<1b-C*XaAJr@3)kk2cJ;P`p&*f4Br_M4koAA$XZ>@Uoz7Vd9z`s>|t8m>^_- z;Tx*W0>8P5u<5{=$7tV%ONeLLKL$UjIsrdM$U8s%@QbR`@O!GWaLht?Y$@D&Ec-u3 ze}Mam9>$0{&Tq~jq&{PC`SE^}#P^QF^(Xqx*U<_14}{ch3Vu&@7QU|9e1JR?vQGq_ zt2z!prn(HHwZbG>1Lb~O_U`8cM`Hb z4ewv!;hZoH=kiFr{Nw#_Q}C}PNSY>ANUJG>Ujt@E%Dp=KLRf%Y`ef85prx6 zPFU*Jhv8b)3HSxV<^k@}^*PvZy4ybjFCb)}1bkze-^7?(%o%?3l^Dl=oBY6^6SB_` zoN}h?4`H~Bko7Tmz3L?Vis}q}j<98cet7_M96HyhCj_yi$& z%fcCPrz7xwqKI*nhP&31Pd-aZ!?V}9IV}zY7dTxCzot3^=dbr$e<>m)Bz)xr!Ro|SOxM&?h(MH244*o`mqLCQ+Vbv6RCgkgW7f0J8Z3co^p zT|PsHKOv+Yv+%>0_)YZ}X$!dKQs!|!zfQo<5t3#ahAyML99s%E5GSD%a45lbivBM= z^>Xe>NGAplY@tn=yZC_^BuDzj=-GhT#Ro zK8{Vm&k`~=Q}D;Cv#{iIIv(Nis-tilVf#P)F(LJng$GoddziZjn?AgskbTnd_o{QS z<@0WzDEx@(e)uERS?K=)_hB4c3co4EMA#_+>)sCk;=2kb4*LDq%yKx%~mgJG_mM@}}UQAL9Q0aoX{# z^Z{ZU9z_iPnHWbOGhd@G5jV0v4nIZMHi9n@lII+p@pYHy2wbZ=0ked7IKtA0x#vI& zFCt`p0+xJ({4mFs!Y@8T-}@5zho?NwdC7AOUP{<`9$rt#^(6`0zD>WRA4cJggyg3m z{+!rPxrSifKe=*6;LF5%*5}~N@3?k}z&{aT%4jlv)B6Nv+3@jbs9-Vp<_7K-ey}>^bQs?F=`p>^G7tpWF512y;+ZJ&0({A1l z!%q-W&;8K^n?Qbzo>2E3cJR$3IaM^ypNz2%QPZN@U7XF%$`pLo5-_|;Zw-Az_ zewZfYT*90LT3B|AA&ndNC%yQ&y)pBiudt`;2Grs6Zrz^zz-AB2m0Yl z#6GmC2$*U@@{@tVN|$~qyq=I}mL&XVU7v*$s@y(dxK?!nKBPJWeUmjEc&6$&{G94E z+^gD5378WIsplBHTXhP)NZ2_Io;NjM$7USueIAwak zTqk1zet?kt#NdP(l!ar%@I*q+6@%BPPQrUsr{TY=9)iU)-TtNU7}ZhuLDdQPN!2Mh zVOGF=fpHXuYYA!V1pFQ$<0T7cHaH!Dvl|0uka;o!{_Y=~wePALq@&*9keViO~NE`5Yt$ zzeQA&hYWn1kaL;lfN3ObTf>!vq#uVVLdq+AjF5dY@Z6Sw^)BM@K0?lwhX1MBvxPg#- zCg8=#P+sPQ1pEiWkWLD|LC7`QETsMj**^>~BBV_daNF?#b1ymxmqY{h`xG(wG$G&V z$ilZy2$%;RC!G^nM@V^v_Y*6aE7I_ylW9A20^YeeVD|G|kb%daN?1{Hf|8 zxad5$e-wU(kbP3{MMB1D4u+Sz^}^RzpvjL}$vGuBZxmib$a$0S^MsTu4UbyIK76OF z6kbm>au1V)^H(#+at(^W|0I%hpwOPZQK!`UIhP_1f z_ZhP=5oavWCxstfOFHO&m?5N$!XK=2?U;oX7r1kUVN|tnh>-g?;qvv&Y4p`N{4^nX zOTp3$^?CtkTof>`(a$6BY(n-)z@HFTaShGFvJI}T!tkSnjI|W>ZDhWqPnN>Pg!Gv> zeBxrR3)EE>zP5=uA8j@V%q7G=NgtkgiQY57ej>+u;lfJ;=0^69!gYk5SK${_r(wlq z?wS~ey$R|^zDEWRUQRzln=M>Jw>q7Ii>`1w4!2%OTI`dAw{3Si1;2Kc(}`;W<~d@J zeFoui*K&RR0r`Q~5yzr$h2JA&tUU$ys}@eU&aE$l$Ep^dsrnrFbz(E=2!EsNh5J+s zPxz2K*Gcdvsu;o9Sf7N!kI^5{!uFde@6$ok1+O5sp|65>5KSB_d{lJ?ZoZlR!#+2|pM9J>pr3(7 zf8)+8{C7h7+aNss7TTSCVsJl^Mhk0ib$N@xC92PX{e;w^@Q13?e;Y7gB{CeFg})%4 zLJz^Wh}3;S%Zk5ajG%?h#JKx|mZze7(85m;>(RnNVl#ROzD>-ayyg>J`v}ot*rhrO zSE!D|t5hf99ja6CQPmmvtm+}SPqpbMeZr;>yHrQv&BO!DUH$M`)kAO}Ve|I)981_^ zVVCMCT%kG+cM_6L8osI8d@^9V30WV5J5;CO9@RP6d7IlO3j0*|!`-Tf;6B3U=Tqc| zu=#--2}wTzA5oowp?`4eOW|_Wad;IW$0p%js#EY$)fxD#>LIvKwYi=1685~XOLY{k zP#uR?sZPQ>RHxvhsx$Ce)kAQfYJSYbR1r3P*rhrOdsN5aRjQNlPSq*+nCc9CR`n3v zr`p^>`h-m%E>InXD^$nfwW^cwPSq*+sOk)SPW2Gnr`r5uz)U4<`f#D@C|spF4zE?6 zgmNvbsbrRmGIt3q7oq^A)9)fSHHg}SL z!ln-ws*b`{s^jol)k%1#>J)rTbp}4CdI-Ly+I)ue37bA#s5%N)sE)&{R43sbs#EYW z)fxCJ)kAQfYI7IuOW5?`v8to6M|B)tr#cBgqdEm2SDk^ws&jDsXI+}5uu*jcE>RtW z8&xOZ&8qw1y{gmjDb-o{JJmT@bhoAt8&yZ(sj6dev+4xARdqj1s}}x5br!y&ItNQq znm(MbIs(s79fOytPQZ_=?uTDioreFSIt%}(ItRynPSc08RY%}b)iJnHbpn1|bw9jM zbs9dUItyP`or7cV(ez=X>Iht{ItDkYPQaU0_rou%PQ#~FXW=WVbFk#|nm(MbIs%ud zj=@c;6Yv(*{cxA+G<;fh7QU=H2giOv(}#0aN8l3GF}O)}0^Xv!AMR3}hEJ=`!k1O& z;Mgx}`f#r52wbW<1~;otz*|-K!v|ES;Zv%!@DHkUu;fdcKAf*Q0?$w#gO{mJz)z^| zhYzSu!>3he;VY_huxO{I506$Ifs0kg;3m}xc(dw$xJz{!KBYPfUs9ceMfYm@uu*jc zE>RtW8&xOZmk4_uhcBwm!CCjY^$~cH>IA%nkU6Ix?oyqGPpQtrmsICq(fyhqI9GK9 zE><0bn^Y&@gM`fwd|KCM;mfLXuxOX22^&>M;9}J=xKVWi-lDo6?oyqGPpQtrmsICq z(F2-3Y*ZbAi&e+qM%4-Uan=3sKGkXXl@iRu{Kq&fj_QQZ%BsZPVERcGPLs&jDcgPJ~^t2zRg zsE)x+suS=Q)%|dn>NI>>br!y?ItRz5HGMc&bp$R^9fO-xC*Uor`{6FtY5272EDSy5 z`c*0HRvm-4sP2cSeAVp}gEy#7!Ut8S;nRfl|15k-bq*GNP4f)rs*b?Ls$+1I>IA$+ zbwAvtIt`yzorN!}&cULuYx=NJbp$R^9fMb^PQv?Cr{UA8v+!lrIXM1dO%u*l9f3BG6IBXEi87~G^f0dG;=4|l0f!>3he;mfLX zaO@+RKAfvM0+*aEa;|+@v}IZ&BS3 zcd1Upe^WgK-%@S9$^9W=@894;)ls-gbsS!+ItlMooq~_4&cNqX55c!on{SamVbg~T zRY&0})p2;O>Lk2VbqYSFIs>0mJp|uUZ5|_i!ln-ws*b`{s^jol)k%1V>J)rbbq4-Y z^$>hZwRs$0jj-v%g{q@)mFhVBxaxkmOLZDPtvU-|R-J>Vep|;l+(=0J33#*Wet56y zG<-^R7QUo92gm-CJGKVCLObsGLqbr!y?ItRyoN7IM1RY%|w z)iJnHbpqb3x*zURorbST;Dwq@z-%W(C*jXk55e|-cI%_?A5^E{>#EHYeC|NVv2pmY z>I^LZu3I05SEx?HAFIy7u}`{vN?{`*bsmAegk0<6@HW*c_$VR!WZ<)^hhW3^-2M@G ziRuKrnUMYa;k~Mb&!`r*exJ_?eh{==f?kgnt|R1F;kBxTpHMCQ9cU+jhC*!%FMg; z%6ameqh_A`X#Bi2=Qp=)x^T;-7j9m|PttB*yJh{U7oB|N#r%t16kBr6nzLKw|FI<( z$$!soUUTU9zwXGb8~Ksg^?%*58#Z6IZtZ590e-4<>$bI*Zj1aCDZN{ak@}4#Ez7u-SI_Wym{`hxu$A7b4{(^M= zi@Ny>Qu?p!=PyWX%Y~cgt<@jh{)@-kUp~^_tyTUHAE0Z_U)I{S=G0~9@PpGQEM3#G zzO{MXg0(Gcy4GyjzUj&iewKMn*V>l0mW~S-bZpzmFOF^NaK8lFvHrqsYcJm19NBu| z$o%@|>bA*ofKQl4>Uo~02OD}kzHDP;GM{CQwHR$_JhHll-+V4^MPx&4Fli)?L5eFvvIEn9*UN9YM7H2X^`>vVrFMcMz5X8%K){UtYE z`d-dt|3jMn59!XJeZKNVXtz1!qqxYBbrbb1Hi5~Dc_-MFx!6ScZaQXRHz`(kdwq4V*rF1F{FeTI%OVwo)4<-pW~CRoEe zHY^6Kd7pyz2)2dIy;3OtNq>x&@ijR@{&~y6Y2<*m3?89Aq&NJ(QXh9{eWXX!N1D30 zcSL>2@_W^XjGF&f>LYuE`WT}8&yJ`MS@!B10u)ELB!1!^%cw zMA>CY|5Li-nLBma9nW7u&}DZ#fB8X|-SKo;yUh4dy5spay)L`s`DTMIyW{z$kzFSL zWlj8tnK?hXvHZXP{GXnl@bB8aoswnU#H^CEo_J(F1$Yg7)0)EbagIN`XH^V|$ zBHWO@osrOn^-m0^24-{iE%^VX$3K4=+U^>!$8z&u-P zWV?~${hfU0EYrg`{q+1#>^D!=l0Nz6D|22x{hH6$L*WA^&340f$`_LT0-KjqlK*62 z{i?m8)B*E#()UNW1DLRsUrdrFt}fC2o)4uC%K6!5*Yjf#1t!$Jl($vA^B;(E{(^uz zzh8I*^%Q6@claZ#UM~snH_z0T78Jh8rEPXsFtCAi7I(#kfhy{+{s7yJKj^lHf~-%L z_Qd0EzrCT%ena~e_$hx;X^%g&r@(x8_mV*9P0Cw2Np7@3!=|FO2uu+fc(-M1ywB20zb!ssmz(3Ev9e!FyxY>-K0n@P zmm~23%cJ9=aU;(^FWzmJC&l|Lo8kk9ZJ!VijhFpPDaPSRNG*m5e<9_;|Nn9vkno91|a~JSgM;hcdoHCOP%qoRD8Pe1%^{ z>*N6{w>LEOwt2kPj4KR<`8BIoOo?ny$2|T%G;~nLueUrpowqC9@k|!{+m}1EpQiH? z^P_)%{&;`T^=~g7JAXY}%e|@QE4asw>ncC}ELpnr)sUa@HO)*myY0BH@Ncl|r&qCm z>X~8LZrEOFw%hGxe)=iL&8TqOUkHWYkom3L?6BKsmbvv||4vOm^@h|>oq5D=9~XAp z`6aP7S3k@@Y@c2j;(@p4l^KO>mwGaly4}^!%)$^8tE@k~o+pIDFPf8T_?|J_OP9S; zeUN&FWd}(IULFd+YVKGRtU5^gEYCO?njoxi9ANv#p1qZ{-;;Nny7~cMa$ag>^=!93 zTvq8%-jG@;^Lr8V+w|(){IS>6%GFiW(}dK@^72l;n4Mb5`Q`di8a^1RlXQD0?{9nDb!Cf$4zD)Cm++sF+Y*N zw`Kd2i%fY%pD)Sw^2%<@ifJK#vMIIlb=t>&Qnjfb$Mu^18yh#+?UcV+&O`e9WquBi z+hDg(sph)LdDo5GZkP9l;xhjm>ZM-x@0mJKkQ|p-=$G>2nTvK#@`A%Cev=8T9_R5@b zebmFo9@{_3U)8uBcDZ~`w?8>2wQ_mI_5!KjX1RXexX4hyEpl9073CyfbE}v8lI2o< zu7BjedK~#9y|NnWhxO&v8~n-HQoqRqQvTWF9W6+Rq%2{zH9D zlJqCfrd~Kd^)X$Re_P4+C#%hQbsGu{^-^0)J}DRVfL*uj6=|Ono~bQ5H@v*SaNfD& zy6t%y$B{ny+1n64VESuK&0NaWm|96Y*!2~Sltc1o`yb`6+U-k{Zh4q|H<12h_Tzli z{|hEtYtF9k^y;rw&O2w^cAHKv6n#<3S2cz6u|MS~lkIhLCsanuY>tPw5}gC+{|AOna04Q!Bl4O_TD>O09gs z^&emL$Z|>ktM>Sk?C-U&ZI7Axa-B21yC6yW6GBPXKM#i5BtNsYTwZxv?ERl@pK{q= zBkjX^DbGw@@X zXU3@ZA-{XOQ?`HE^3bns`%KT5<9XYEs!gnx^K!l=^^XK){L|0dB_Ds1?LVwGUVC}v zqMa&S{xiRq_OtE9`6<^dSzo8^^db6}@5wcWauv$@Q>%B_a(U%>p6$$^MW@y5wB_oa zw*9bn+S?Gd^Q$diJ^32hPL$*Mk>#VkY&+HE%jfmKrPBw3(q4?i4#|htPDki}m9pKo z6XzSz{~GN2d^?3?+3SBNhnF8#u2bqZ*z&x__zu!Pz4o%@2-%mHGZh{=N1(v8KV6YtghfnLko1e@*^^+@Bv-zG?Y(`qRkrQ66tx z*>VRnwsx-(b|x0 zC)@wHKUi4Li?D87Es%V8?UW0}N6&v#WV>xA&Of66&9UqA^Ix$nd;M=g^^U{Jbxh4p zTOM!z>zuaTw%1|pH0539KPktN=D*4L@*O*UATVnFJ4W&|x}8SMe^X_)ZLl`jO_pj=BwpwbNnqUsZlw+3`evk2wET=F8#rzpmOoUcyYRY_0F~ z+G(NW<2~lTYM1{J^B?D@Turjx8&4zp-#A&{TwQrsJ9+b8OHG3<*ZgU3a(|KUhlkC7 z75Q;x%SSnmIRA~wm(QF3+NO8g_G11!R`TJs(-G#s8rg2!iSv!(>k|U(@sfzl@7Uy-wx2x&L8)@WvDM zf4rn$F=M$u8It3fUv8B9J2$_TH*WAJefjo#UGo2IwV5gR4brap`=nZz|M2f6{h3p# zKX+d=U)E2kuC(`&Qh&Z})uyPn&q)6+4(~K_eWCtJWWP5m!~28W|9SKKths&i11~G9 zy!$`yujKx@TJA?DrdH0X+TlwUrdF2Jy@~y$_jlyKx@Lo4+ME8rNYb6G<=6Ss1cMFu zOQlQ4dF7nE)0cGhC(jSY^E=A?p+dv+nVfsWJx2QH_?f%;WFxh5?5syj64L$$&1XmHqGpV{vGp}}G23k?oC zA1J@w`9Fih&i5G{c0P}8=lbCMoxx$}>kJMXGthD9`I7XVzcV=O{GCDdck<~we`j#m z`8$Kd&fm#zcmB?x`a6Te&fk&kaa%s;@5p%@?0KEflRw`1J%hu}=NTM!K2LtT^Lqw| zozF8k?0la5cIWpD4m+P`aM<}ggTwjy?c;UeT+aDEgTv16$sg~0pF#C~28W&Bli%)q zpTS}0`wR{{zel%Izs~m=9Cp5slygu1`JMW8{?DNLKZC>0|H-HC{GY*L=l=|<|C8VD z{GY*L=l={2JO4+vyZUwhkDRyLp4a(6`Qx1*G&t;hph5M4^4pysG&t;hpuu711Le0n zKWK2+`9Ooi&Igk1b*}$gKWK2+`9Ooi&Igk1_!Iaq&JP+KcKzP`y2SJylyNW=GW$(3 z@6U|(Z>;~${5#rzFI4}1HU7FUvxw1DY5nfeKK!dPzcD}Y?-cRlD@OV7drO`;5XFD@ z{CLmb_I!BGe<1T@ZI$Q8Tc3S|5C2-|`_{iNP(R*oS06rwkG{83{d=>)`tJo${CoZ~ zD$M)u{8c8C$_9W&SwwqDM%)FFR}iZ93OnMAn7Za60kn|Plt;t zDy`2h{(X|?&vLt5dPbq7gO43I5P$y#cmG^fRI%Jw6^yL%eEq_r-N(|`R|K8|6R9x{`)DT{de7d(kTD^M9+WM?VkT09qqsC z_TxwS@5g%nyKZ;$@iWZ3qx^T>-Zjd9@9_M0-R}ACZKM5n-QF_Fe~)<-h{LG9ie*Wur{#eAZA>xi#JRe@>Ip(9DSH$1H z`Q7LDocjBP1(h7Xn3ud+{KD)iH zg8Za@I{bpYznx~v13T>Ys)7OgIbFT=;eU#M@BI16lXu(gVfFQ=6z)9CpPw3d#BT4M z;?ni}`3a%a9`V=e3g!B9Z%=(;gJoSv>Z6C}JLW(9DbKGDhn$aJA98-Z{d|PykJ07! z{Q746%`pBe_tRmy@9izz!#sX(5B?a_(6U!}{@d^4{+j0s>)%iB?X$l1tlk}#o_{}+ zACnUQzJAtj>z`Mb@qA$Ydvlb`)vC;z^|qLxc-W-pR~R{<>UUB z=Zrcz|Jzc&4ChzB-t*zBdLOarohhIHaQ*On{X;%}743=7{#$$hyvEe_cH8|A`S-P< zL;ih`{7vg!Znq!Nzh{1*GL8CgB4525?0)A@rTy&vci4WuFng3gUwz2GPub_2#@_Ar zIG!)+rCil>s4wpKDc^qEzKxyMSFfAPc__d5_u@y}=U>XVqP)-g`Lq-1EPKWC=k4>| zjdz>5y*&5echBkF;rwm!hw;(hmggs_|2fO8uU=8T!TQ3~ul4f}`S-m$?Q!DYo3ZLc zI{)7KbF+INvD?q9++cla&%bY$_H+K6_3x|XIgjHgr}gj2|FD#=wwC;j@a@m3>$ZNa zy+62d%^`nXe0+J%nbS)??B^5Cf2{9aZkKE1^9k|sdpFo+d487uRUyyctoM9;>dE@_ z*5Bg!Z?vyZd)VX2-)R5J`S*>K10UV`2T#7cf4^7W58(Mr{Cjymjn#zg=Tn|vFYP1# zoL3(DnLTg4lmj1KpPzlzTAqq&l!NtERSo|4^6$@@-r)K7)?cMPoqw<8@}9Ft_*Z%V z9$%gI$@}+wzB9tVI^^F=d8~g=|MT*{v~D0M?dbXJe-eL~`PXZoX;Lon@wGgM{8MQk z)_cCA=i`s|N0!y@vHmRWXZ`!%NdLk=_1eoT*E{@sX)pYA%5^0FUgobu{=JmTE02Bt z<@tSaOI5e$Frc{m62iRM&0W$@5P=zx}Xw!bg8qp6_m~^?B`O#}Vtjb~^O= zk$f&D{feuUrxE?n`S;RJq(eL9{d@Y~2>}*Tu1Wn zrT-oB??<+i*Z+>K8<0O*wbJw1y>{~YpB-2D>4&wG`1|ttk?ntb@}Ikm@bT^SgX`Dv zwR@~TOFLQr-fO2L_@~almv$nZ5&iFo{=L`##J?X|uCAIszV0CP8@KbH=ieXJPWb3v z|MS|*&VQ`;+UZdL6aQ|6e}ByMot}U1&413nAJtAH`d{9^Cmq@;@82Ij{~hx0kJSI# z>bh+^c|N<>|Gf6H?Ib?_=>8}3AM3qwW#>Qq@6r9QqjrPWPS(Ho`k&WcBj&%u{d;LI z(xY5Q^6#bpS^qv#6SD0xSFf+``g6DOeEUP~Bli#Z=eB?1KZ)-z^>J7`vHjI<#{cN~ z?+E^V%k+lB{QEiYKL0uYUOqoJhxNZB`uAS{YpNUI#~$#y(z3xt{qg?pI z^J@24Ki!s(>zC*E&+z8=FztnpZT)Ya-@SHY{8-=Kn-A>$x7=UMpnr=0uk)kzyYc^d zKW_B=Zhd{v?>+L)o)@W>-``17yghDjOR>==NrdGbe z=j-Bs*J!z?)o`DM58SBb&6R{-koK9zk3ow6P5-d|z1KhO_MdA^vg`Jc^mg!uQ>n*Y5GVeMZ#_|Yt$FL)0qN%?BD zonEugKfFKVjr++B_H%^t8p?}5Z}ZRn;p7=_@?*c8H~;yA_4$7*{yW#Nf`SI??@yh& z!@kE5D(JI*OS7gwbqwX>xu~izWc~dyGj~}3p83tb-#Kws1NU{Qm8A_Kdw)GS;C%kQ z6XgAYS~FFi`@;V|d{o&U-p|FyFW+r_emnkY-?0_=`P}yxSMIc*OO%+RiQ-?rZT;6p z#`E(>`}ns1lD_kc28TT#-}ys%f7SEzoj;WK@x9}nKQuV(`S{Ks8tHF4Ki~85ollfc z|4jS7rug>Ek9=<7`S@P?&Mz7q_I!Nj6Y25HFP@+8e4@OM?;Y>?`OY`W`}p2==Nk5M}xzjkMI1WeEOcB@BE{@kMAAt{G-8P z&&PNE(a8FBe!l19J0B^ZzUSvVA1UwSd+9qrX>i!{@tu#9uV2s4cRo_y$M=r+{Cwvt z;iFqWDSv+FD~Z1^?{y9iJ3lGE-T6v`!_HR{|6bnbblb!9PvkKmL#B z=R031@8e@bSpS~>Z~gp3{*b9@6WsU==1))(f8+z%}$dN|GxO0{(X}7=STbZa($QgcRc?-P|iHad*vbZ z=l>wrAFf}XpYQqj#l_CY|D()*%ugk{e{r$%@r%{R7ymsf@3(4x3S>X>Vb90s2rtTZ z{CE5;`#$|x_3y`3R;vFVKaga5g{%cmS~{r$15PnOE}8V>XKt-nV3 zJb$0hU#%}Uu5yE~sxY#O>y!9HRb%CNeD&Vm4R~RTyT*>&X_5tb|Gt=I>)-e8vGvp2 z+iiWk7xeld+wUldtXiYK{n+vbn~vw-uj=Kz`1Rx1PtJ3GZ};K;eKFTV-iKe=+h^CW z=pC^1{Q2d*A)g$7o?Hjvx#H(QZ~HmD19tiB-jIDi{w(pcMr=P*d^wh5y#tnK^m3m! zV*9dQnU|j#KE1cka%u0tVcVDVhV1+Cr}cKb_vCx~EWPtD?j5kp-Mu0Ee*7s?KO@pV zxwp?QpVT{G`GH=plOwjD*xPNFPw4HljP?#3wtZ1Ah9>(T-`j0@TyLMHcm8902ki2~ zUf%Cwzhip4-TU+EPcP{0v&&t*&cE;M<#}tw`8#^M?Q(l>pJiL`z+u~4dqej9_?F&o z_nv%jpL>5^(~tBH*yZ`Xyx&8*^Lo48`}4Z}=-xiN+|)Z@IjNWX?Gfh>^mg0jF};13 zq27VRw)=ZS_Wk&x-fqiaugtUg_Mg~0V3$jJdB2181-;#tUi}pI%6$9Gu<0GJoY2et z)rjLud%NxOQN4YZ<9i3**ZcGGev7<6Use`kTrBP?8#~&+&-|DEy^rHPzuoiiy?OkV zka^p1|L84yEGyn)$D#GF?fels?EU_#V)_~1huItY2>uZMY_aqArxv?;V{+MU>!)-7 z{f2x#P%&kL_1mk|f3KXnUAI5+2FDkbO`-qt-bRJ``sGu1*zFJT`KY8%`R%+iWy(&w z-_)s(*zFJN=i^gG`Sw*~oe!ULpC20D*SGcO`Sdf!k{dW$dt7x5$8*2h zBIV=z2jr_x&R;fbyH9-b@@b^UcKiq4uU%9$yO#61=LYVRm%XyOr_cJJe7?c;WsUep z_Vb0to*j1moYf)w{`>Rp^Iz(Jr+qJJP9x>u{qyqjo%|R$&p-H&a{Vgn*=difDr>O5 z_ne+b?6T+Y9{Sus-=D9TLpkur&y>$MxPFlTddb(S>PHIXIqCeGodrogr?B_ui>uAb zY4iuyuc+1$E`E%xW>&w&b?mhV#loP+(dp?_~<()N)`r~s6+Rys$ z&R?i8%d58+*zZ@&q}&`wd3m35*(>(*HT;kBYBpG3e^$eQPd=Yn-P3LBPd*>;^L*6U zv)nG1ukQ27`|;=0c3Z#O`hQQ}E$wGNPpInIVD}4$*`IXm=im5Gvv|G~KfSzar%!zE zIX&C$ey>XZA^mXO2I@1la`K!6fcJjfY%T9xE$0+zCwy`lKlG2H zGHEZCr_3P#)b9)}=WE^W{HIUb9h7=6V|UhTx9IH!Y{Yu3VP$|dHR=pf5;nGCr?`* zlzR8Z6X$pH$H9>G-6zYq!vFThlbt`PpAqBA`taS=+ig3s{rC3$LK#;aM>*|#^{3Qq zuS{ytS^7Uv`X7fY`5=!%KP9tuB<=*f$C1% zPVW8FHAV7Xy^bsU9=|u9Wc;}KkAB7V$Brv}ayy>Hf7khM^td{1+HU@2N`C$uF|Mrd zUM}r~|DB)z=zpZ|jVsD$eRyyFTU7mK;Lx}dA6@-;Z~i;JW{>q<^X=4U)636)JM6L@ zSNP(`)^h(${~R&@+0P5SapjFCJO9!D?6~4MJFf7zz46p0&likO#+CKmy>_$XiT&jB zTbchxk1O(T{djNwTQCj(o$(~^fonV2_v>x>`TTNW^&{4QmG|nko$TiT%zp>XlXsiZ z6)lZ~n963jf<1SKfGPmhngZP=4#j zcUEuDw-fuHDch+Z%4>c3j+za&o!V-r8OYf89K2`kpy?XyA z<0(n|G{|+6e1+?_+xrEnUzy+Q^uB3Qb*Fu=BChvCw*N2X{d@K812vu2x8;2i>%-gk zvnhXh-3Fh*uP=*lxBmB-X*=}&@aT(je;caZZtVP_KD{>{>3=VqXBQPspU(A`{!cyF z_wTtLlRx=>EA`L(UB7z3mlR*SXNO(yuWB&T{(OH1fBtS$G`Y`6|D&C)uU;nivG((` zuzkPK`rhQPre~);J~*YD#}(WE?E5EWbu{<3wI_WKUcbN!7{ z{}uYZ*Kp4Ud)(~Vv>)%0jJUqq&mF3GFUC+0Jw4m)etYHbBTzq8bNXz%P3HMR{QbFe zci8*bz4HC%ET2E28W%$l=okquf%)#)>p#cw|-K7 zyYrRs_pPtQd-&E*((U-~&Q}^7cD|DM!-v1W@BAhFed{mb?^}N|RO>Iv_P8A%&R@#={?2F0AMgAo{C(>)4Gue>DZkzMO}v+HeJ1>U z>oeuIJHKgg*!fJnhi`qRk^Z~${hjYLIPCnU{PE6r;=O$9JK^tJzbU`n`A+!z)_3AP zeCs#qcIwypPJ_eFcam}*{{Fu6pLj3d`cL@#)_=;U@BAnHed|9Bs{fSV?)<00Vdp;$ z4mqGH5fc2qd`;dKq-}zB_-~au-zt5QcZ~bGC@*4L$|3!6GV$??2 zf8;;;+i|0Iv76JnGR<%w`ET^`li0`y6YuyJ*FWW$)B|K6`5TAcA32l12Ppq$J0u+P z{~gl5e%7z6ygc$(1`qppv0Hk_@d^EdilusfIbM<+ef$UA_IDhgJjx|J@_0#d^zk2Y zg@4EK{mv*fBafFPM<2hIj8RqZI6hVC&OhXE=)Ye5dHg+FqzI$iAZb75Fy>}0za(&^ z^1sIRBOPBP4a>ig#~-QuW#s=z$M=`L^Y|l`U;d7#9M8kEJ3#*%u5cn*=_u{T`7a2& z#Y6v{t@$5)eClYoOJ>CJoPXWu%U>on<|?f( zVU*<%|2ypWCBm#6@$c)44ll`^@x^{~Fg)+d3%6{&__9mGo%1@!7thJAS9)J^z?9zuv-+Rur0petxr|kuoLLZdrRt zc-|GQEvK8N32RO>i`GP%&Nb26=5x&nwJrRos`XqmzoYG3(^|`aPN;2{|8&TII^{oI z@}C9#=e2Rok@3Z|#}{8qvRf|P5T2LdH!Z?56UVk4zi^VdeO+@$c;=R^$1a#;9=WKs zB|P(@W1AxjS{iv@$8>Kwwyk}Vxu>IzUDhAl+Hw54V1&Oiz19S?!ATwe8Z7yNnez$$ z-1DlI_Kwx-%&E1T)|q9sSMZ;T?dwea=JlugeRl?Ln7O4n(iC2GhQIab)eCz~?Sj>& zy|#tl2ASX7+-usWwDubRxBT;)`Mrwv+O{>uFaK#TU9iTqU2@r#mA=<~EluH>7x3ry z*U!6Z;rQb97f&`bTAON3WJ_~%o7r;3=6Rbho40kEnbp=b!e+slVcQnR$f`2dM-yVLM!QQy_XhY_ckHu}=lX4ZnHrRHm0 zoallpx2>JG?Xv1I1xtLslxbbx-1?09!1jyRUwERKwElDEqyOuIIE+5x+!_U6jZdqFO&Fk zUv2Ywfr42bEv$Hu4b683EF*R~IIDeu6Xoqa?7E6w+dJKwiuPV#$?8CXWv_CLGOk>! zT&KJMwsG8g<%PGMBoblf~c6Lk0@us1*X_hf}o3k(6mYBC;?ATd}=4r;9?|Ul{otbD3uiLUVJY(sa zH7A}JS<_y7l>f)R`K|3c$F(o&T)l7&*U-Z89kpFD@a+`PE~gVn>}RYoF-zf3V=gZB60Umd;4qP#`!QnBQ_t zd*RrLo%4bvXO);4Ev=^p$DUQbfEunPfSI z|3>D&ZNAphE3PyfFS=yyye*eb@R?svEWF<5|1ce9QpwSLtN0_nK#k8VG_8?~BD}EK z%#G|vCl!9nH@%DN1$zXmnz#7^3(fql1x+T<(9*Qm=bz)Fw;XK>#(ZD4l>C}M6273V zsgYNZCizx!QqrB&UgfW3qMpB?^;f=v`USN2#7^?G)i-~^f?xSsYPbK&=l9PjFDx>_ zk8*YGnrDIyktUYzXKBGaa#rFGviPrNerH5BcCoRwWx;}reI@gJ!MfJS8NQa*NIMnK zK?QWq^R-Oc^k8}MwBUr+lPZJPn38DF*V)?K`VUfdfy%bDW#Ff#f~!D@&*%4T9CzYH ztkC_Jx3k(md$QEsO>$|I z>gkZW`!Y+dveYh1&#=@cOKnnht!=HFeI?DlLjPI5y0!%&Q{Nt;*v&Tw+x)+qIJQ1$ z=2a9>Z6^i&4egy;YhC$TYj(BfzirIK)~krsV`^KAi`%HODJ54kk$%_sk1v|o`iIb> z)=rzKt+CS9)l&+oF&fxk5(x6Uds7OhyE?J`sY~jsV*y2Pk@{L->#MeNd_lnfjWI#$ zt24s3X0~aCb$*pvUoj>~jdeD&iGCQE)H>O9wytJsCU8R_@by68cY*m`ofNL*m%gBX zc8yk6dxum=i&j>9rz|yVWwjsbgd|u{(iaH%A1$8P`Xl0!;zg|s+NG*~TTxsetZ$3l zN!M;UDOk|i7HQVXYROlY+*SB|zKLx+#!PIxj#xXUwrzr~t!Z>H|7}GR+fE=}2`y@C zw&}Y5)wZEBc%v_HBP|!XKM?O(O6Uk|eQ$;#kkzQ8Ydk)u{dZ&j=7UoFk_ zFYd@*^#vMbY(5+eOl-S(OwjhR`CNq9+Ce94Ebj^u$--b?lS#8U&LIwQl5_d_y z-vnFQTD@N3Dy~i|E^=rPN-v8vOT~4%imQ!m9uxFGT*z?caJJO8Hv0-DwKkf-YUBST z{dj@Y+|p^%$<8VBI$2x0)>}uu-kA4fX!_q-ZS|w8t=%T-I$7J#>gZ%;?TpHH%H~F; z46hbHU2aQ#A)Tz<8(&qOu9J0qj3RW&m~8K2O!fphPOGf8LweaN+soR$&N->s^|IQI z^ZbSWb06o0UwbgcL+(~VUJ3>eTgv`voZvg|;B2bizmqUwO zRLD#+Gc*EHMZrrwl~%2^$6Bh^5(5Y!25;q{Rg0HnsaA^+y!2wxR*u!9A8oC-&+2*B zyJsfBL(lpC{m!rae*4Yd@80{}dp+-3&suw}H-sZL(@uKLT+xIJ8@V6YBuZmN1#AvN z5>`x7xtc+hepV&Bb}Gdo!w_V&4Mw17!Y}!)Grx3Y~ul&{0*Vi@=7Ks-vc)MY`82<=O_Dd(l4o?W-qG5af*=pNkp>` zV(X0Pb_+*tJSiIiJDzEEWqw$UVv21hK1U&P&=^D2?bRv0cR191i%fkTgMZh zVS)#H0CJG20S90et*qbxbo=+^l4d{18O=J;1K_TL2e677TIm<~%q9o$cQBh&P&M-e z`ZrIYf3u{995LIp1OAuaP~=lo2DqxSy_2G270HUrR|xEj;C|P=HSn0;eSyb3EL0Qs zOEAkVz-a@vi-46Dv4*eiYeC>8NKXj}wB$dPCH&O46brb=>LKp)qm*wTj_^_vLnJ0W zH%b}Dncprj|D!lrG_D%|Yuq3DLsY5KnH-Vd6UX{R>M#E}X(j>(C>XeVqBPbYp#Bko zemTk1tD>Sg=T~Zug6)1ycP=LddlepULW&gZdT60r;(xi14oeO|K0=2j_0yaSDsk)^ z!2hQa&FROsHqz>(>>CwZ`m8?EfM@-+ia?vCh?+BQm|TF(7=Vqu0Ks-gSaYty))sDc z5MN>dNd#_^^n-HEITPDo%UT_ij~qbf93tqXe-?xbO>Ujv66FB9B#$JFW?*L)ZLG!& zu(Or{<)4HVKxbNDw1;qbb6EK~2iCeDYUU7jW>Gl@T!3|~Ox2vP0Jd^$d$m@_PT@5f z5S1p_q-Yceoe@){SHl#P_2KJzzyln?QAr6kXL%*0Gy$AaasakZLQglM4T|(&H~>wO zlt>XXX8<$u0EF3(fY)`{+JTu%FQb0Jj03aV0M{&R?*lTI4_kojMS`r#O%+@~5i#8& z-8(!=0b*{^z-nBM7#m%W$o>(Im`grjIWFB?K#Ywahm#kF6+q?|)wD7-f=3#!oRS`D&I$nLxCXS>eJY~4uVA}8(&{En z0&Al~Pb^z6zVVs8Vx>V0%n!bxoSa0J^CK=#nKv1azzwe_dc2KLzAc175(*R6@oS z+2~;4EN)`GQ%T#J)S|XsrWUo$8UT$x1wIcNS5~#}^v{Fl@L>y}E#{y(7C__Q0Y!CV z_Yr2l9_6~WiGu9+TX~pf2kLgGj&p&!IYlSHA!EDha3;_-r>LZrexPd!*CoU*23kY0 z{S=6KG=BkN(+h~%B?}~Fvj^M&AhSz~2(|aQ#`Jwq2E=Hu4ItKiDha>{D9E>hHURy{hKDjjb-Ed#7+uIVyRnw&%`O^S!6C2_BXBOzH@oN%O7zw+ z5+?(dvx^>~(k-Z*W1k3{?rkdh(%A)#vr2|Y8fTs053G_la>zW3{6VE#atHx)zj30% zNk++up4=gn6fV%$z=_JYmCLir<@XVM3~~qlG42lh0e2vc<3WVu5;-~KK>ebn4G-Kw z!cPXD1QdI~9q1-F4Y3CWcYZQZG?{{r!QOy^Tbx}ashSTH^`T}y)XeA98>rb?fts~? zUdawg%{;H<5h(cZyaAWsL(M!7E`bx94>j}Q5&VF*d(R?|;1}&w)chiDlJw5!J%S}k z@&r*{Cd#)^9uQ`(_MKhOG#{Gg!%Tf^kUQ|dR7Fu6zG?@%13~*jHMz)W6D{& zK*Ilgf2-DQAr_{Ez3z56%(3Dh^#-FNmMaXRj7I-TrcR--E&2?3dN3uixLQ3#CwLJ zCm0s>j)7q>mQFUTj%tTmg}O&+C*i;`EA>?PK%_KQ2r0c3Kw(Fv&%&{ItMlcAV^YT+ zM5edP{Aa3gi~>kV*+Xemx^l6cu)SgriVLc`1U-CIx`vO6LHMZ4ai%otOUg+{wfLwj z<%Crn)#9V3(Ni4NI$_ifoGFd^OI1~~==l80QYD0+F^j`eB_!x64r`t;EJiU)!_F%^ z>B}00QN>&|wR+{+z2*9bxAL{i_+Mx~mQTfJ3+(lLh%FGN=&1$5RNX)me=(1aX5FF= zn;M52ml}^69}RmruMRh_%1bxJ6H}ar^J4kydFkPNqnWA{6dX1R?3P>^ds=9S@6EHN z!el+WRG6b*iiY*%V{Eb%W;^oJ<6wll@^<_y0=qMRif$|xQo8j~US^Ah7`rw9S(W$` zp>;*dcxDNGVBeaLf;DRG?H5SHd>rZlf z!o_oA$W+sLKez?2P(}}B^iYQ27vuNwz!*Dw@E|;#(Zl1yMPozg8;g@kx@wOgL3ima zp3YF+(ycvsx=rKI(;2-y?{`oR<3l8Dyg)o*;|2H$mu|4=_JLQRfqc4w?jH2?XKVk< z(@lCh>TDzf(tN?o4Uz+|j{rZAqws!(_b~Z;n0Sv0X+t>SOujH3O{4&WD6S5p$cD*u zmxRlz{lBa$pFKjpCqK%YGeX>7r+QS8{cNQoezZ^y8^SXQ} zVKrZuFKaN9@+#?8l+dg~35{KshXEz8%R`fv4HMYYIe~qr!qD9(a-@DK-F_51_r?u9 zc>yY8Xi~bfSU_dT$Jq(ql8~%BA-S35^5QT@mCM?}eeFUs%j8Ac5p|D_R^$coho?&I z0()PP$074v2R%h;sCl%6b1^+dje%ykSzb3Zb#^}b$q?N*f2bRKS_u!!R~hFj6*uQ= z%$xHgjX$Z3JdGy8&Ky#1a(PHx-y{|-$|TqIL`0K#qXBcocw6%{1#@pt7MTVOVqFgl&5r+_dOsw*c=a9mb zYur`&gy+1Mk5+m^r1H+3?w^G(cBMr#%RSB9>$aLn9{lxnDxMjmN)uA<(uJshH%bZSYSXXV3nnu(h2_$`ACzM{Ha)F z9_7H)rk3XQBzncA6NI5bGQ&Ds3XQQ)m!(=FuQJUKC0I*7j4~9G7uOQ?QYwQ=a1=PC zc%;g+{|9lzzm}-}q?u=Uz`&b?xb+?&au&9WfQoHhrX-~eT2xb-Y+qAaYy=kL9Yhs6 zltrk;5yc2xb=ak9NIHi4Qn2axnPd(5h;%uMOJlGdqyiHk@KPZ*Oj7QEMv9@~Ird?v zEPZ2!GJ-=;2W;-h4No2t4|AacG}ud6pljIQ)Bv z3Qd*f2V|!N$}K1*?ZP&j%1Iu9(q0|W6wgYU8pbag3$0OsmjfzA{Q)Yf@*f@2J17yn zgA&0z0Hrz~)$BAce?n0bDLx*@M7=Q=Da+y8`KDZ?f!BDPhP65lH0h9yBaY2)PQ)g}jY$5XARR`9&Eew(&|H~|evsAO4Y{z)H3c5v z%3N4_KtyF&zQc-75~Rg56mCbTO5;?cqT>wIc5?(Qgad%OPXMk}J(%}H(+ErBXu?Umz zquK*ivFD=Va}1-sl>^;EOKhWPE+?T|ljxT0#|}0fKa=cOj>uLThYA@K#kR@9i>QQW z#9qp>ePEUH0;|9hxNC{$F?IVoBTcp!D2nY|09ahW^1u{mP^|;LwKsDMYDhIhPcZT= zfu7jQs)<<~(v#1_@+x2{GOk%N5@KBQ1jaQ>j9ZDROqfKt^79JjO!ibpko+nas_`q8 z7MF~~J2}y=1JS+7nLtBmmACFBgm&%>y@OuPKeb=%G{i^QIxL{WM^6o zi$b)V^s_A0WE^8IBu^d|3Cp5h8#SHpp23ObaRhOlfx`IBV1kK7+v|4F$ z+-h)7iEhbB574bkbh{pU^FH1y*ieO`7f~-Z7?!q4YsE@W$YSYH!sCN-WYp$d^lp&u zB#`bX(gToNGf4Nx;8j7MJBvuSZIr@QhkYD8-fO&HWOb8}8H2LHK}~ZvqI6)km}erRLA;nIM_~p3J!ihOo?FpXoVdTGJR7&kBrMl z4UnnGMX%55ehdZ`EQ<0fAio2}X`WB=*LZ>6T(8S&MNWDwhxt4M+R+(SIOnh(V`cT4 zto&F`Pf|a$#(iQ4kM0PYvB2g9U~?6fRcJyuHe;p{HcEgE4|{=4uy%-ty}%||JH*3Y zcN`T}yw79flVM0(rx7;3U#(2?!N&4O9qLKrUMLOn{-~$~8>%YwHtL}y$eL8qiVk(IxeB{O+*qO5+B<&-?~T=uuD8rT)Snc=%?%|0-)H5i`r zz{q+|?#p4Z7@Yt<-WpChMgboW_;|p_<0Q!O(as=z-plguB_03+pJ0g)4}f35s>$!E z(jUut(FE{#NlwU69S=t3orj{yhoJZMxsdp@KnCTR{{tD%>5$Zr=ADLu$@}nB>v9z! z8>tA24I3-?D)m3Ik#rUbws*1t2eaB{8J*IdIZWZ)&WN^a5Gj-95h7NB zgrxuwB>XvJUOkKS#UmP~bZhXjNG^PUhW$Y_Tv8XVZR*C^M8lfDt3$rGL&Ks$eGm=* zCL8@H(C}}wVQCa!>vLJ7to$b1=y|ZFPZkpmuNx%?9^Cgf$Uausd^@M7s`8PV6gMW+ zH3Bt4w}I}0wh*QJDe~V7vYHQd@vqW!pTJql{UCIuKhPDj7xLd;u+oQIJNa+V8{l9+ z6>u;o8Xw^JpkO~B%g(t(!9Q;&jX^o&zo*~{9r8Ct`7wN+Z~|U+R>6`GVrf1s$+rd( z@Gr8FGqcF2J_n`Yb?TJG?PVkAw;4Jk))xHJmTFn4&dM8em{6r;NDH#;C#T42Wmejl zYqXq2%Mm5Xf4yJxpD827u}N_d%li{5EB&h^1PDc@d3_#P(o+FCd|e_eDOl4}WZ_^< z53z8SKLvkhNL2QpDHy_-lw%0u6??@MqVnXK@$dh*+Nz%w|9ywlb&<4nZ^-(RxKWTp z!X(3grelLJ_;;& zg~6eRmf@xFNTe>&uzwQO`c#UqprAfH7s*R%u=*clFeDqN6u*NfGkP*(OU0176p4Jg zn_d_>(?T~%8{CRFM9P%6#cF{?HxGI)hN0`}-jhC*lKAvcJWs_Nvmx0KCH77J(x2<& z#RZrGHoH#vAP;q*n#; z-`a;QKGq(JM#dV+Xe8oAT8z}&D_L1A#vW?_wQfArj*Rf)ibm9f_&ANaJsUfx$D8!z zH0&(IBT!e6sHD)C@RnSU4Cw|of^BF|vSmV;tttF$Z=Z-nD&+W9du)u=yi8*4?Z}#< z8E)J?S~uv_{X?2!IfX52_dwap+J~Z@^AZ*UT_9BG{+yRU@m-nt58HL~0wIcTkvrSj z-!f{mlWY~^o7!u!NIy1bjI!#8)NX}-mWj_9!`_i9^(8L|YK{M=_Q=d};|+IaKbt>I zQ2&xS#b?>_F+r|sSi8uxoR)^L=*)?+morQg>)FpTDa6(LACZDL-@teB-D7n2AKI+w~N61^G?0(3PYv{yE{>FW^y~t(tMk z#L!b2_5vRFKxP=eSqsbLi9+N%1om2{j4y1)PMj~S55JaSPcDM}ZM#KIhNMf!celP9Svma%m(o(TvM}}pE`uO)U6|0U! zhaJ)VCo+-Y-l`+I@j;iYKAsVOB!=G2uqQH69viXeOj!JSOa)(>r-3i$JDNFTV(5+x`z7A=hD?og8A4KCL+2qG{?qXw zFG(R8dkAupcL_g~VOTOPjf0$CbLhJnHeYFwYKKPOl7$(O&|4seWOt@a_q#J-Dm2w- zNKWbQaBLI5TCLFHSwS+(V;h6JvqJ0xQFV?)f7`Fyw`K4>YdA;P`L3+|AR|H8r!%^B z8>WcXoWN`U0B|Hv5#@yw*lm~wi+a5jdZSNg&=chou6vMCK2NS#%bx)m6)7=-Vg*PU zEJdflihT!iQB1{mXm-Qj2?RVX66*+F@3&Qvw`6(hVFkkZ#nd~yg68oGlJ5$VPgIt- zRV)wWJ2;HwOO?mYA1vRe{gLhcY&9)ApE*P=6Qxi4b@NI^RxjZ*dS^yYt}g1KVP2An zMDEP64T?P8GIsMk>P_?M+9)pc=)xdN4Jn0x8Or8n1eypg<_Dt-Vj?KSmlt_9)}Dz; z^C&}zy*SS_YgV}_wUex}+$2P21g}~#5)%mz`<9HXpmz>UQlU6 zFr5d=W2Xdfcu)fp}9Ru|MTMLAE(38_jxwKrkPxW zbmghoemPXP*H<>MC(|{R#_y0Dz(f|=C&olhI*J(^EnrbxcJuuKM;n@B(Tc9+A_DziE1lcUF*qdexl?FK{ z#6A|)(f!eXJFMG>(%~l07~PLBvKe;$LR#t)L!YK~>rgu2Aolxtvid?g`e~YNPb=m3 zkabF<``HWWGU_{&j`GPz&@=@xdw)9Cn`Y<0I1j?N)U=S&o|Qp`u9y~KKTj7ph;}#! z{XEU;p$yX_vAfdjT%}PubzJa#y8(_skKgXwgR9Y(QW7pC*Yw4Ou{_krvu5DuL$rrGT^uvLqT$M*4F zy3(U7Ir&&qFq7-G*me8~5Kobg(o%2~NDE7rIrns>V@j-2g4T9HKK_M<1fSAO|TrOXSrC>Hb_FOEFT#= z@38J(i&=D&eTR|wza||X;q;~1c~H(pBh|h%`NpMVWaXN)^s*T0OY1hSmIFcCkI#w1seUO;SjsE#^}q}Er`5&Uuz zoFIPD3%{`XjkIo{D+qhc)R!^8=t@x+g*_~Tnf!8EWKsy-N$L)2SR~`MTFe@dv6Gfs zvz3QqQkF6eYyrgLUrj$NX0kX1EeMo#W+XNtc*8&VG?BK2iL{)_B;sTe5luCkqKI{Y zh%cWOh#0FLB%XTu;piKObl)J>IsGuA{uIPmnr23kb*e|pN{UY(dR0tpU0B86ej^%2 zrs}~%xL9`#kj_OswcwI;U)N$41JZeX$unj!Siehp$;LxM9`4W%4|m3wgp-2z!@_3> z=X(&&Pgas}l_VURYBUf|RNh}v@nIX8t-AH}f^hHhV<|Ag(BbGkhjcrVE>}NWNv~Yl zD9hn=%z@5DFqstRFkD0(COUYC!SEi`V8vhX-uOa)*txr+yVt!vMgUNu)iLn z*a$+2kcEdhw8QB|ED7~1y0JlCI~_tv-xf$MJ+nkFZ?}P7-fn|rZ@2Mj>r`V0`5mFm zjDloO2O|WMJ%1>2=^?f(A*-*q1={w_39|ZHn+OJtL)%_&E0D~ywbUJnr}5is)JfG+ z_98ydjq;wNf4KD!d#EjhU%PGflc}59qP6C+HhDQ> zbhwB|+G4lxr|=ov%u`x+8Vf8+am{^@ic<^iaB8`a#FM0%NZ>4vguWHLq%xc6<@SR# zkG^5tPD0*JLZYcgLqf_U#fq!>>*+p^$@Z)-^FYY z!|18+v_+$jwdn@+@TDVP8z!shL|=;fIKy09O?|{L=uT2M|G2X4m2DPgf^jJq&BZLL zi`&>)hZJe1h^eE4Z?>r=iV4y*G`%>Zn_upes;vvY7@xYtFx(|S!g>#J7JeTZIT+@& zjh%RhGiHz07|7cQE~|7V@!>eVOy4?F)Gr2B^q|04ni zWr6;;{rSC%dYAWRdOLfs?%mwGrT5-mbRNZjNsJ#acW3Svl8QiGb=^(f$?o~xi@KM0 zU)Ftf_onWB-TS)_bRX<~x%+7Mo88B|KkP2wHe_4Pwx(^#ZQizN+h%Q>zirXB%(l*L zmu=g;earT(+jnlicl*BW`?nw1esKHC+mCL4bNlh_A8r?RlQ6*leMymQM=uV-4%te*Kji+Yy#WP1MJ#wp)DWOvQ(y4_8? zle@j$({|6=J%9J2-OG1p3QyRyd-LwPJxzO(dxYhRVBuNIyE0v!U6*xT-L$|@_?{2<;Nnh2n1Ua|*7B`Gw(i@yf9rv*i@KI~ z9q2vSd$jk>-dXo7zbA9g@q5bqhV<3+)%7*?UEQ~-Z*$+4K6()T90PuoZ(Y81GwwaQ zwY)3YwFuAMiRXOSRdd(0yDqzH=Uqqd8q)1`XSz3Ic8+&9VFosB+qdoLwjtXWZSTYk z9Nj)`$JQOkcg(+g%iS;IPr)?pTm-cC1L>Mw({^QcZP|5jm(Y_0W}AWY@t)-Ft-CvW WxAqEKLtH1yx8f!K@$dh81pX6`)$IrX literal 0 HcmV?d00001 diff --git a/backend-python/rwkv_pip/wkv_cuda.pyd b/backend-python/rwkv_pip/wkv_cuda.pyd index 8e668ed22c20899a9f82a4972ebda4146be5d208..9d7a2812127c77bb445cfa5b5e1f7da24b0f8035 100644 GIT binary patch delta 52 zcmZqpA=B_fW`hPJ^VGtGW^KlHZAK7g0%B$$X4$UI$eQT{l5C&u%({KLGh5~l0I@$1 AN&o-= delta 52 zcmZqpA=B_fW`hPJ^AUd6W^KlHZAK7g0%B$$X4$UI$eQT{l5C&u%({KLGh5~l0HLi9 At^fc4