From 4912916c11f68037510fce44bdbdcf1292550cb8 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 13 Jul 2022 15:52:21 -0700 Subject: [PATCH] [FRONTEND] Added support for element-wise function defined in external LLVM bitcode (e.g., libdevice) (#562) --- include/triton/codegen/extern_lib.h | 89 + include/triton/codegen/pass.h | 11 +- include/triton/codegen/selection/generator.h | 14 +- include/triton/ir/builder.h | 6 + include/triton/ir/enums.h | 2 + include/triton/ir/instructions.h | 21 + include/triton/ir/visitor.h | 4 + lib/codegen/extern_lib.cc | 63 + lib/codegen/pass.cc | 89 +- lib/codegen/selection/generator.cc | 39 +- lib/driver/llvm.cc | 7 +- lib/ir/builder.cc | 13 + lib/ir/instructions.cc | 22 + python/setup.py | 2 +- python/src/triton.cc | 61 +- python/test/unit/language/test_core.py | 46 + python/triton/code_gen.py | 19 +- python/triton/language/__init__.py | 2 +- python/triton/language/core.py | 6 +- python/triton/language/extern.py | 107 ++ python/triton/language/libdevice.10.bc | Bin 0 -> 469572 bytes python/triton/language/libdevice.py | 1661 ++++++++++++++++++ python/triton/tools/build_extern.py | 340 ++++ python/tutorials/07-libdevice-function.py | 74 + 24 files changed, 2634 insertions(+), 64 deletions(-) create mode 100644 include/triton/codegen/extern_lib.h create mode 100644 lib/codegen/extern_lib.cc create mode 100644 python/triton/language/extern.py create mode 100644 python/triton/language/libdevice.10.bc create mode 100644 python/triton/language/libdevice.py create mode 100644 python/triton/tools/build_extern.py create mode 100644 python/tutorials/07-libdevice-function.py diff --git a/include/triton/codegen/extern_lib.h b/include/triton/codegen/extern_lib.h new file mode 100644 index 000000000..c161ff142 --- /dev/null +++ b/include/triton/codegen/extern_lib.h @@ -0,0 +1,89 @@ +#ifndef _TRITON_CODE_GEN_EXTERN_LIB_H_ +#define _TRITON_CODE_GEN_EXTERN_LIB_H_ + +#include +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" + +namespace triton { +namespace codegen { + +/// +/// \brief ExternLib is a class that represents a library of external functions. +/// +class ExternLib { + public: + ExternLib(const std::string &name, const std::string &path) + : name_(name), path_(path) {} + + virtual ~ExternLib() = default; + + virtual const std::string &name() const { return name_; } + + virtual const std::string &path() const { return path_; } + + /// + /// \brief Load the library and return the module. + /// + std::unique_ptr load(llvm::LLVMContext &ctx); + + /// + /// \brief Link the module into the given module. + /// + void link(std::unique_ptr &llvm, + std::unique_ptr &mod); + + /// + /// \brief Run load, link, and opt on the module. + /// + virtual void install(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) { + auto mod = load(ctx); + link(llvm, mod); + opt(ctx, llvm); + } + + /// + /// \brief Run opt on the module. + /// + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) = 0; + + private: + std::string name_; + std::string path_; +}; + +/// +/// \brief ExternLibMap is a map of ExternLibs from their names to their paths. +/// +typedef std::map> ExternLibMap; + +/// +/// \brief Concrete class for NVIDIA's libdevice library. +/// +class LibDevice final : public ExternLib { + public: + LibDevice(const std::string &name, const std::string &path) + : ExternLib(name, path) {} + + virtual ~LibDevice() = default; + + virtual void opt(llvm::LLVMContext &ctx, + std::unique_ptr &llvm) override; +}; + +/// +/// \brief Create an ExternLib instance based on the name and path. +/// +std::unique_ptr create_extern_lib(const std::string &lib_name, + const std::string &lib_path); + +} // namespace codegen +} // namespace triton + +#endif diff --git a/include/triton/codegen/pass.h b/include/triton/codegen/pass.h index 0c8f11315..95b00b807 100644 --- a/include/triton/codegen/pass.h +++ b/include/triton/codegen/pass.h @@ -3,6 +3,7 @@ #include +#include "extern_lib.h" namespace llvm{ class Module; @@ -30,12 +31,10 @@ namespace codegen{ // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, - codegen::target* target, - int sm, int num_warps, - int num_stages, int &shared_static); - - +std::unique_ptr add_passes_to_emit_bin( + ir::module &ir, llvm::LLVMContext &ctx, codegen::target *target, + int num_warps, int num_stages, int &shared_static, + const ExternLibMap &extern_libs); } } diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index b408a46ca..7867c356b 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -6,6 +6,7 @@ #include "triton/ir/visitor.h" #include "triton/ir/instructions.h" #include "triton/codegen/analysis/layout.h" +#include "triton/codegen/extern_lib.h" #include // forward @@ -199,6 +200,7 @@ private: void visit_make_range(ir::make_range*); void visit_clock_inst(ir::clock_inst*); void visit_globaltimer_inst(ir::globaltimer_inst*); + void visit_extern_elementwise_inst(ir::extern_elementwise_inst*); // void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); @@ -209,18 +211,26 @@ private: void visit_argument(ir::argument*); void visit(ir::module &, llvm::Module &); - // layouts void visit_layout_mma(analysis::mma_layout*); void visit_layout_scanline(analysis::scanline_layout*); void visit_layout_shared(analysis::shared_layout*); + // Add a new external library based on given name and path if it doesn't exist + void add_extern_lib(const std::string &lib_name, const std::string &lib_path); -private: + // Get all external libraries + const ExternLibMap &get_extern_lib_map() { + return extern_lib_map_; + } + + private: LLVMContext *ctx_; Builder* builder_; Module *mod_; + std::map> extern_lib_map_; + analysis::axes *a_axes_; analysis::swizzle *swizzle_; std::map axes_; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 74028f822..8eb1c2ce3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -169,6 +169,12 @@ public: // Utilities value *create_clock(); value *create_globaltimer(); + // Extern instruction + value *create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty); // Built-in instruction value *create_get_program_id(unsigned axis); value *create_get_num_programs(unsigned axis); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 3fa008606..4e60d3444 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -154,6 +154,8 @@ enum value_id_t: unsigned { INST_COS, INST_SIN, INST_LOG, + // extern + INST_EXTERN_ELEMENTWISE, // array arithmetic INST_TRANS, INST_REDUCE, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 402208a8b..1bad86c33 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -1097,7 +1097,28 @@ public: static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr); }; +class extern_elementwise_inst : public instruction { + extern_elementwise_inst(context &ctx, const std::vector &args, + type *dst_ty, const std::string &lib_name, + const std::string &extern_lib_path, + const std::string &symbol_name, instruction *next); + std::string repr_impl() const { return "extern_elementwise"; } + _TRITON_DEFINE_CLONE(extern_elementwise_inst) + _TRITON_DEFINE_ACCEPT(extern_elementwise_inst) + public: + static extern_elementwise_inst *create( + context &ctx, const std::vector &args, type *dst_ty, + const std::string &lib_name = "", const std::string &lib_path = "", + const std::string &symbol_name = "", instruction *next = nullptr); + + const std::string &get_lib_name() const { return lib_name_; } + const std::string &get_lib_path() const { return lib_path_; } + + private: + std::string lib_name_ = ""; + std::string lib_path_ = ""; +}; } } diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 774f2e172..5f84f414f 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -84,6 +84,8 @@ class prefetch_s_inst; class clock_inst; class globaltimer_inst; +class extern_elementwise_inst; + class make_range_sta; class undef_value; class constant_int; @@ -177,6 +179,8 @@ public: virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_fp(constant_fp*) = 0; virtual void visit_alloc_const(alloc_const*) = 0; + + virtual void visit_extern_elementwise_inst(extern_elementwise_inst*) = 0; }; } diff --git a/lib/codegen/extern_lib.cc b/lib/codegen/extern_lib.cc new file mode 100644 index 000000000..0a1f165ea --- /dev/null +++ b/lib/codegen/extern_lib.cc @@ -0,0 +1,63 @@ +#include "triton/codegen/extern_lib.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Type.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "triton/codegen/pass.h" + +namespace triton { + +namespace codegen { + +std::unique_ptr ExternLib::load(llvm::LLVMContext& ctx) { + llvm::SMDiagnostic err; + auto mod = llvm::parseIRFile(this->path_, err, ctx); + if (!mod) { + throw std::runtime_error("Failed to load extern lib " + this->name_ + + " at " + this->path_); + } + return mod; +} + +void ExternLib::link(std::unique_ptr& llvm, + std::unique_ptr& mod) { + // Set triple and data layout to match the target module + mod->setTargetTriple(llvm->getTargetTriple()); + mod->setDataLayout(llvm->getDataLayout()); + if (llvm::Linker::linkModules(*llvm, std::move(mod))) { + throw std::runtime_error("Failed to link extern lib " + this->name_ + + " at " + this->path_); + } +} + +void LibDevice::opt(llvm::LLVMContext& ctx, std::unique_ptr& llvm) { + // Add nvvm reflect flags to llvm module + // https://llvm.org/docs/LangRef.html#module-flags-metadata + // i32 4: Override the other module. + // i32 1: Emit an error + // If both modules specify Override, but the values differ, an error + // will be emitted. + llvm::Type* I32 = llvm::Type::getInt32Ty(ctx); + llvm::Metadata* md_four = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 4)); + llvm::Metadata* md_name = llvm::MDString::get(ctx, "nvvm-reflect-ftz"); + llvm::Metadata* md_one = + llvm::ConstantAsMetadata::get(llvm::ConstantInt::getSigned(I32, 1)); + llvm::MDNode* reflect = llvm::MDNode::get(ctx, {md_four, md_name, md_one}); + llvm->addModuleFlag(reflect); +} + +std::unique_ptr create_extern_lib(const std::string& lib_name, + const std::string& lib_path) { + if (lib_name == "libdevice") { + return std::make_unique(lib_name, lib_path); + } else { + throw std::runtime_error("Unknown external library: " + lib_name); + } +} + +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/pass.cc b/lib/codegen/pass.cc index 412e2f4c8..645f10978 100644 --- a/lib/codegen/pass.cc +++ b/lib/codegen/pass.cc @@ -1,4 +1,14 @@ #include "triton/codegen/pass.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Linker/Linker.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/axes.h" @@ -9,24 +19,66 @@ #include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/disassociate.h" +#include "triton/codegen/transform/inline.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/pipeline.h" #include "triton/codegen/transform/prefetch.h" -#include "triton/codegen/transform/inline.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/print.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" + namespace triton { namespace codegen { +static void link_extern_libs(const ExternLibMap& user_extern_lib_map, + const ExternLibMap& target_extern_lib_map, + ir::module& ir, llvm::LLVMContext& ctx, + std::unique_ptr& llvm) { + for (const auto& iter : target_extern_lib_map) { + auto &lib_name = iter.first; + if (user_extern_lib_map.count(lib_name) != 0 && + user_extern_lib_map.at(lib_name)->path() != "") { + // If the user specified a path for this library, use it. + user_extern_lib_map.at(lib_name)->install(ctx, llvm); + } else { + // Otherwise, use the default path. + iter.second->install(ctx, llvm); + } + } + + std::set function_names; + for (auto& func : ir.get_function_list()) { + function_names.insert(func->get_name()); + } + llvm::legacy::PassManager pass; + pass.add(llvm::createInternalizePass([&](const llvm::GlobalValue& v) -> bool { + if (function_names.count(v.getName()) != 0) { + // Preserve global functions + return true; + } + // Internalize all device functions + return false; + })); + + llvm::legacy::PassManager pm; + pm.add(llvm::createVerifierPass()); + pm.run(*llvm); + + llvm::PassManagerBuilder builder; + builder.OptLevel = 3; + builder.SizeLevel = 0; + builder.populateModulePassManager(pass); + + pass.run(*llvm); +} + // TODO: // There should be a proper pass manager there! -std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMContext& ctx, codegen::target* target, - int cc, int num_warps, int num_stages, int& shared_static) { +std::unique_ptr add_passes_to_emit_bin( + ir::module& ir, llvm::LLVMContext& ctx, codegen::target* target, + int num_warps, int num_stages, int& shared_static, + const ExternLibMap& extern_lib_map) { // generate llvm code std::string name = ir.get_function_list()[0]->get_name(); std::unique_ptr llvm(new llvm::Module(name, ctx)); @@ -47,8 +99,10 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC codegen::transform::peephole peephole(target, &layouts); codegen::transform::coalesce coalesce(&align, &layouts, has_sm80); codegen::transform::prefetch prefetch_s(target); - codegen::transform::membar barriers(&liveness, &layouts, &allocation, &prefetch_s, target); - codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target, num_warps); + codegen::transform::membar barriers(&liveness, &layouts, &allocation, + &prefetch_s, target); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, + target, num_warps); // run passes inliner.run(ir); dce.run(ir); @@ -56,7 +110,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC peephole.run(ir); dce.run(ir); pipeline.run(ir); - dce.run(ir); + dce.run(ir); disassociate.run(ir); dce.run(ir); align.run(ir); @@ -64,8 +118,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC layouts.run(ir); peephole.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); align.run(ir); axes.run(ir); layouts.run(ir); @@ -73,8 +126,7 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC dce.run(ir); align.run(ir); dce.run(ir); - if (target->is_gpu()) - cts.run(ir); + if (target->is_gpu()) cts.run(ir); dce.run(ir); align.run(ir); axes.run(ir); @@ -97,8 +149,15 @@ std::unique_ptr add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC // ir.print(std::cout); isel.visit(ir, *llvm); shared_static = allocation.allocated_size(); + + if (isel.get_extern_lib_map().size() > 0) { + // If there's any extern lib calls, + // we need to link them in. + link_extern_libs(extern_lib_map, isel.get_extern_lib_map(), ir, ctx, llvm); + } + return llvm; } -} // namespace codegen -} // namespace triton +} // namespace codegen +} // namespace triton diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index e2303b990..b30283ced 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1195,7 +1195,7 @@ void generator::visit_cos_inst(ir::cos_inst* x){ for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); } - } +} /** * \brief Code Generation for `umulhi` @@ -3154,6 +3154,30 @@ void generator::visit_async_wait_inst(ir::async_wait_inst* i) { call(iasm); } +/** + * \brief Code Generation for `extern_elementwise` + */ +void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) { + std::vector operand_types; + for (size_t j = 0; j < i->get_num_operands(); j++) { + operand_types.push_back( + cvt(i->get_operand(j)->get_type()->get_scalar_ty())); + } + Type *ret_type = cvt(i->get_type()->get_scalar_ty()); + FunctionType *FT = + FunctionType::get(ret_type, std::move(operand_types), false); + Function *F = llvm::cast( + mod_->getOrInsertFunction(i->get_name(), FT).getCallee()); + for (auto idx : idxs_.at(i)) { + std::vector args; + for (size_t j = 0; j < i->get_num_operands(); j++) { + args.emplace_back(vals_[i->get_operand(j)][idx]); + } + vals_[i][idx] = call(F, std::move(args)); + } + add_extern_lib(i->get_lib_name(), i->get_lib_path()); +} + //void generator::visit_make_range_dyn(ir::make_range_dyn* x) { // for(indices_t idx: idxs_.at(x)){ // assert(idx.size() == 1); @@ -3741,6 +3765,15 @@ void generator::visit(ir::module &src, llvm::Module &dst) { visit_function(fn); } +void generator::add_extern_lib(const std::string &lib_name, + const std::string &lib_path) { + if (extern_lib_map_.count(lib_name) == 0) { + extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path); + } else if (extern_lib_map_.at(lib_name)->path() != lib_path) { + throw std::runtime_error("A library has multiple paths (1) " + lib_path + + " (2) " + extern_lib_map_.at(lib_name)->path()); + } +} -} -} +} // namespace codegen +} // namespace triton diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index 92a6b75de..c4a13b806 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -358,8 +358,5 @@ hipModule_t amdgpu_to_hipmodule(const std::string& path) { return ret; } - - -} -} - +} // namespace driver +} // namespace triton diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 510994fd8..120b575cf 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -379,6 +379,19 @@ value *builder::create_globaltimer() { return insert(globaltimer_inst::create(ctx_)); } +//===----------------------------------------------------------------------===// +// externs +//===----------------------------------------------------------------------===// + +value *builder::create_extern_elementwise(const std::string &lib_name, + const std::string &lib_path, + const std::string &symbol_name, + const std::vector &args, + type *ret_ty) { + return insert(extern_elementwise_inst::create(ctx_, args, ret_ty, lib_name, + lib_path, symbol_name)); +} + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index dbee5e0ee..7831e1650 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -988,6 +988,28 @@ globaltimer_inst* globaltimer_inst::create(context &ctx, const std::string &name return new globaltimer_inst(ctx, name, next); } +// extern elementwise +extern_elementwise_inst::extern_elementwise_inst( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) + : instruction(ret_ty, INST_EXTERN_ELEMENTWISE, args.size(), symbol_name, + next), + lib_name_(lib_name), + lib_path_(lib_path) { + for (size_t i = 0; i < args.size(); i++) { + set_operand(i, args[i]); + } +} + +extern_elementwise_inst *extern_elementwise_inst::create( + context &ctx, const std::vector &args, type *ret_ty, + const std::string &lib_name, const std::string &lib_path, + const std::string &symbol_name, instruction *next) { + return new extern_elementwise_inst(ctx, args, ret_ty, lib_name, lib_path, + symbol_name, next); +} + // clock clock_inst::clock_inst(context &ctx, const std::string &name, instruction *next) : instruction(type::get_int64_ty(ctx), INST_CLOCK, 0, name, next) { } diff --git a/python/setup.py b/python/setup.py index 7ed6ab444..6c136b6c7 100644 --- a/python/setup.py +++ b/python/setup.py @@ -98,7 +98,7 @@ class CMakeBuild(build_ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) # python directories - python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include'] + python_include_dirs = [distutils.sysconfig.get_python_inc()] cmake_args = [ "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DBUILD_TUTORIALS=OFF", diff --git a/python/src/triton.cc b/python/src/triton.cc index 4e1849733..fcebeeb5f 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -1,5 +1,6 @@ #include "triton/codegen/pass.h" #include "triton/codegen/target.h" +#include "triton/codegen/extern_lib.h" #include "triton/driver/error.h" #include "triton/driver/llvm.h" #include "triton/ir/builder.h" @@ -19,7 +20,6 @@ #include #include #include "llvm/IR/Module.h" -#include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" namespace py = pybind11; @@ -140,7 +140,7 @@ size_t get_pointer_range_size(uint64_t addr){ // Launch void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, - int num_warps, int num_stages) { + int num_warps, int num_stages, py::dict& extern_libs) { size_t len = PyList_Size(args.ptr()); params.reserve(8*len); // 8 max bytes by argument char* params_ptr = ¶ms[0]; @@ -256,6 +256,11 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f throw std::runtime_error(err_msg); } params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); + + for (auto item : extern_libs) { + cache_key += "-" + item.first.cast(); + cache_key += "_" + item.second.cast(); + } } // @@ -288,7 +293,7 @@ void init_triton_runtime(py::module &&m) { // cache key m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages, - py::function add_to_cache, py::object grid){ + py::dict extern_libs, py::function add_to_cache, py::object grid){ // parse arguments to compute cache key, compile-time constants and packed kernel arguments long _num_warps = PyLong_AsLong(num_warps.ptr()); long _num_stages = PyLong_AsLong(num_stages.ptr()); @@ -296,13 +301,14 @@ void init_triton_runtime(py::module &&m) { std::string params; size_t params_size; py::dict constants; - parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, params_size, constants, _num_warps, _num_stages); + parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params, + params_size, constants, _num_warps, _num_stages, extern_libs); // get cached binary py::str key(cache_key); py::bool_ noop = false; if(!bin_cache.contains(key)) { - noop = add_to_cache(key, args, device, num_warps, num_stages); + noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs); } if (noop) return (py::object)py::none(); @@ -467,11 +473,10 @@ std::tuple hip_load_binary(const std::st // --------------------------------------- // CUDA -std::tuple cu_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ - - int n_shared_bytes; +std::tuple cu_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { py::gil_scoped_release allow_threads; llvm::LLVMContext ctx; // device properties @@ -483,7 +488,9 @@ std::tuple cu_compile_ttir(const std::string& name, std::string ptxas_path = drv::path_to_ptxas(version); // Triton-IR -> NVPTX LLVM-IR triton::codegen::nvidia_cu_target target(cc); - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, cc, num_warps, num_stages, n_shared_bytes); + int n_shared_bytes; + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -502,14 +509,16 @@ std::tuple cu_compile_ttir(const std::string& name, } // HIP -std::tuple hip_compile_ttir(const std::string& name, ir::module &ir, - uint64_t device, int num_warps, int num_stages, - asm_map_t &asm_map){ +std::tuple hip_compile_ttir( + const std::string &name, ir::module &ir, uint64_t device, int num_warps, + int num_stages, asm_map_t &asm_map, + const triton::codegen::ExternLibMap &extern_lib_map) { llvm::LLVMContext ctx; // Triton-IR -> NVPTX LLVM-IR triton::codegen::amd_cl_target target; int n_shared_bytes; - auto llvm = triton::codegen::add_passes_to_emit_bin(ir, ctx, &target, 70, num_warps, num_stages, n_shared_bytes); + auto llvm = triton::codegen::add_passes_to_emit_bin( + ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map); std::string tmp; llvm::raw_string_ostream llir(tmp); llir << *llvm; @@ -523,7 +532,9 @@ std::tuple hip_compile_ttir(const std::string& name void init_triton_codegen(py::module &&m) { m.def( - "compile_ttir", [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages) { + "compile_ttir", + [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, + int num_stages, py::dict& extern_libs) { std::string name = ir.get_function_list()[0]->get_name(); // record asm as we generate asm_map_t asm_map; @@ -531,11 +542,20 @@ void init_triton_codegen(py::module &&m) { ir.print(ttir); asm_map["ttir"] = py::cast(ttir.str()); llvm::LLVMContext ctx; + // construct extern lib map + triton::codegen::ExternLibMap extern_lib_map; + for (auto item : extern_libs) { + auto name = item.first.cast(); + auto path = item.second.cast(); + extern_lib_map.emplace( + name, triton::codegen::create_extern_lib(name, path)); + } if(backend == CUDA) - return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); + return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); if(backend == ROCM) - return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map); - }, py::return_value_policy::take_ownership); + return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map); + }, + py::return_value_policy::take_ownership); m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){ py::gil_scoped_release allow_threads; if(backend == CUDA) @@ -931,7 +951,8 @@ void init_triton_ir(py::module &&m) { // Utilities .def("create_clock", &ir::builder::create_clock, ret::reference) .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference) - + // Extern instruction + .def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference) // Built-in instruction .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference) .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d032d1e39..cb2cb9c33 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1300,3 +1300,49 @@ def test_num_warps_pow2(): _kernel[(1,)](dst=dst, num_warps=1) _kernel[(1,)](dst=dst, num_warps=2) _kernel[(1,)](dst=dst, num_warps=4) + +# ------------- +# test extern +# ------------- + + +@pytest.mark.parametrize("dtype_str, expr, lib_path", + [('int32', 'libdevice.ffs', ''), + ('float32', 'libdevice.pow', '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'), + ('float64', 'libdevice.norm4d', '')]) +def test_libdevice(dtype_str, expr, lib_path): + + @triton.jit + def kernel(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + tl.arange(0, BLOCK)) + y = GENERATE_TEST_HERE + tl.store(Y + tl.arange(0, BLOCK), y) + + shape = (128, ) + rs = RandomState(17) + # limit the range of integers so that the sum does not overflow + x = numpy_random(shape, dtype_str=dtype_str, rs=rs) + + if expr == 'libdevice.ffs': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.ffs(x)'}) + y_ref = np.zeros(shape, dtype=x.dtype) + for i in range(shape[0]): + y_ref[i] = (int(x[i]) & int(-x[i])).bit_length() + elif expr == 'libdevice.pow': + # numpy does not allow negative factors in power, so we use abs() + x = np.abs(x) + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.pow(x, x)'}) + y_ref = np.power(x, x) + elif expr == 'libdevice.norm4d': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': 'tl.libdevice.norm4d(x, x, x, x)'}) + y_ref = np.sqrt(4 * np.power(x, 2)) + + x_tri = to_triton(x) + # triton result + y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device='cuda') + kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}) + # compare + if expr == 'libdevice.ffs': + np.testing.assert_equal(y_ref, to_numpy(y_tri)) + else: + np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index 30a79bcc9..3951d8b6b 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -689,7 +689,7 @@ class CodeGenerator(ast.NodeVisitor): ret = triton.language.tensor(ret, self.prototypes[fn_name].ret_type) return ret # built-in function - if sys.modules[fn.__module__] is triton.language.core: + if sys.modules[fn.__module__] is triton.language.core or isinstance(fn, triton.language.extern.ExternalFunction): ret = fn(*args, _builder=self.builder, **kws) if fn in self.value_constructor.builtins.values(): args = [arg.value if isinstance(arg, triton.language.constexpr) else arg @@ -933,7 +933,7 @@ class Kernel: self.fn = fn self.cache_key = {} - def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages): + def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages, extern_libs): tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')] # attributes @@ -953,9 +953,10 @@ class Kernel: constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) arg_types = [Kernel._to_python_ir(arg) for i, arg in enumerate(wargs) if i not in constants] - return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, is_manual_warmup=False) + return self.fn._warmup(key, arg_types=arg_types, device=device_idx, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, + extern_libs=extern_libs, is_manual_warmup=False) - def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs): + def __call__(self, *wargs, grid, num_warps=4, num_stages=2, extern_libs={}, **kwargs): assert num_warps != 0 and (num_warps & (num_warps - 1)) == 0, f"num_warps={num_warps} must be a power of 2." # handle arguments passed by name kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()} @@ -985,7 +986,7 @@ class Kernel: cache_key = self.cache_key[device] stream = current_cuda_stream(device) return _triton.runtime.launch(wargs, self.fn.do_not_specialize, cache_key, self.fn.arg_names, - device, stream, self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, + device, stream, self.fn.bin_cache, num_warps, num_stages, extern_libs, self.add_to_cache, grid) @@ -1242,7 +1243,7 @@ class JITFunction: def warmup(self, compile): return self._warmup(**compile, is_manual_warmup=True) - def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, is_manual_warmup): + def _warmup(self, key, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs, is_manual_warmup): hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory @@ -1264,7 +1265,7 @@ class JITFunction: with open(bin_cache_path, 'rb') as f: binary = pickle.load(f)["binary"] - compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages) + compile = dict(arg_types=arg_types, device=device, attributes=attributes, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs) if JITFunction.cache_hook is not None: name = self.__name__ info = key.split('-')[-3:] @@ -1293,7 +1294,7 @@ class JITFunction: self.bin_cache[key] = LoadedBinary(device, binary) return False - def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages): + def _compile(self, arg_types, device, attributes, constants, num_warps, num_stages, extern_libs): # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel @@ -1316,7 +1317,7 @@ class JITFunction: backend = _triton.runtime.backend.CUDA else: backend = _triton.runtime.backend.ROCM - name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages) + name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, extern_libs) max_shared_memory = _triton.runtime.max_shared_memory(backend, device) if shared_mem > max_shared_memory: raise OutOfResources(shared_mem, max_shared_memory, "shared memory") diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0b04465eb..6b0058dd5 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa: F401 -from . import core, random +from . import core, extern, libdevice, random from .core import * from .random import * diff --git a/python/triton/language/core.py b/python/triton/language/core.py index cc0db5566..4197a3333 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -248,8 +248,10 @@ class block_type(dtype): # while tensor's shape is a list of constexpr self.shape = shape self.numel = 1 - for s in self.shape: - self.numel *= s + for i, s in enumerate(self.shape): + if isinstance(s, constexpr): + self.shape[i] = s.value + self.numel *= self.shape[i] self.name = self.__str__() diff --git a/python/triton/language/extern.py b/python/triton/language/extern.py new file mode 100644 index 000000000..a306a2e9a --- /dev/null +++ b/python/triton/language/extern.py @@ -0,0 +1,107 @@ +from __future__ import annotations # remove after python 3.11 + +from . import core, semantic + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None): + ''' + Dispatch a function to a library + + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, core.tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + ret_type = core.block_type(ret_type, ret_shape) if ret_shape is not None else ret_type + return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type) + + +def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None): + ''' + Dispatch an elementwise function to a library + + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param _builder: the builder + + :return: the return value of the function + ''' + dispatch_args = args.copy() + if len(args) == 1: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + ret_shape = dispatch_args[0].shape + elif len(args) == 2: + dispatch_args[0] = core._to_tensor(dispatch_args[0], _builder) + dispatch_args[1] = core._to_tensor(dispatch_args[1], _builder) + dispatch_args[0], dispatch_args[1] = semantic.binary_op_type_checking_impl( + dispatch_args[0], dispatch_args[1], _builder) + ret_shape = dispatch_args[0].shape + else: + for i in range(len(dispatch_args)): + dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i in range(len(dispatch_args)): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder) + ret_shape = broadcast_arg.shape + func = getattr(_builder, "create_extern_elementwise") + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder) + + +class ExternalFunction: + ''' + A wrapper for external functions + ''' + + def __init__(self, fn): + self.fn = fn + + def __call__(self, *args, **kwargs): + if '_builder' not in kwargs or \ + kwargs['_builder'] is None: + raise ValueError("Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)") + return self.fn(*args, **kwargs) + + +def extern(fn): + ''' + A decorator for external functions + ''' + return ExternalFunction(fn) diff --git a/python/triton/language/libdevice.10.bc b/python/triton/language/libdevice.10.bc new file mode 100644 index 0000000000000000000000000000000000000000..ef3ae8d81946e5401289886e62e415d6e10783e8 GIT binary patch literal 469572 zcmeEv30xCb*Y{)(5HLW{009S48x?C%5Zo{UR4gh{(Yl);Qe4oexYepz5MqlOms+ga z1|QqnR$E-Kw2!t4pi+w#0e4hTs1*V{Ewhh5rIW z=xS9NcOYNfwZ|G(mZ%+1=pW2|D@Bj@W5=%Xh~;{Bl|&1riSAvE<^+obkLz##0Q2t} zTc6-%@ornHRt+`e)|q~IldI~_A1cK#F$=>2&xbZXgo`D{mg6~@c{w_)8s6~%0pzQw84j2_^Eu^Wo_^oEVw>jb}%I9a=z@0O;(#9RBVu8b0BeE z&TLFLe?d-ng}Z^%uC3bhh0?#|CBG&+k?*MdzPz$i}QmH z=X=)H1y$OEuI9_?^F0q6f-3Wa&e(#^+k?*NgARwt%2{|#UCT(B_RG+UFCnJ&;n*FaB6QFGa3^&zq| zP*g1Vd`M7@A^0%V*D5;|5_Bft6B5+qdqRg6+htfW>I-Q9a_C&tp`K?@je?kcs>{(q zv!rw8=ySTK=jg10%%*{H*$`Z;m7O38e;x#BWmoHhpesxBgUjtQBv%t$scLsR@zQJfUCA3qe(- zIdh4^BQ28(US9x%fir8HU=Fk}C|zDoW(5?i0mFZ?2VDkLf~I&}7n@&46*nzhAeduC z!+zcZ)Wwwg=smM;p+@E48F)*~)?GXZ(_q*@up>fbpNyJVe2_5PjqN>1{HCT@BP+|o zH;CoF#>ufZ_jS#mwHxP!&kQWFN0Du8>4L&_5s`plqvH23OWFb9Qz%e^>|h z0M1THoz*bz%?h1T(J#6d-_S-U6A$(e@bW7AQ0Okp8qQtDuk66iVc(SGs|;^k7Fye2 zv(#$uXnlG*S0C*@v{aiKKcHW2qjhT!qgC3kb!&G)VrjCw7rJJ5s+gyeIn^O{S7J_a3eQ>Eo-_<}KNucVTymi!=Exe|t+lT!ym4 z6}pCgX_M2V)A2oBBO=~dJthR|vE95w=-IlyyI8@q2n21T%sOuW=xBNTIlG;8Iw*9o zce_C`%OcjU$TFFxZSAln>)S3Z1$ook>U60F_K%)zZVmelzFE|HjZ;(WKGt3NdRv|p;*rU{weHX;JL6Y~w5Um#du z6}(Zv=?NpdASV-M5n)dEXu*PQVCx05@^U;iYOf57*F2@*b&Vjif`1og7%Lb;L58Nq z99;0aLNL36e}niC-9j4cuS%hR*I}vLQ+6(=LpzwJ;9svit^9+w7zP7LBN#6Fa2*%J z;GcNK@PJDJ+aeO8hE36xaafob6$>t#MZjh&(L1>GA*;BUl6VV$;)30TC`;nX?hQ{+ zH=>0&j&da~>8%7p1aYINU#)lMttpG63(8>#;?V_p&;_-m3kstP($EFj>4HpjK>`k2 zNtA}D(f{j$FukL$*wgwbo!)fE=uP+JK5D^eI(K>>J$W!ZSx^l3AlR4LokfP33VXuVT4OD7EZiJgMR=3NTK6S zQx?isI`0(qAsmbVmFR2fEel{jwj-rzVHu>Ov?059Dn=j7T`{)Okx(Df!XIPdjyqIH zjJnZ>Fbese7d^u?Muy!mwh|*RD8E0wK@Kq*B$Q9@Wp55$cg#>+s>Gs&rHt~YFvNvp zY$YiT3__BCRyAfJgXj&?lipUD^k)2#-oj7r=w}sF0ZQ#C9Srz%jO)CM3y#3pN{qaq zk)>{bTKTLvs1Hs zGj*0;qdbOvC}#}&=QKx?7{fmNKVgo*dedsnoDTjA*6(@6J>xiJ&OT3Tt{2=p{{?Yz zjKL7X7!0m+D}BWN-*1#G49#*y4qT^pB9lwAOHAazG`S^=#?&$zvw&gYPcba~bc`;p zJH}RGF78(q3hZy26F={u9QM)2W zW7aUt)C(4;W7IC1Ub_}DdSeQ_=yp3)K%b;mGiq1H=(~DG-_6FjddzXf1^2+%O2-W} zCPUl+7jg8dYBi(pT*cA*ZZ@NKuHsm@LpWd$FFMo-6)?ror1GyQM>0< zgciZ{+6h|F8`H$-J3C{j)G|h91EVn)FdDNrqcMXq7LJh@pyC$vq2elzK2&NMedj8U zt9N7xeb<{|GhM_nYe&~RJEJjOYe%1e7GO+qY$Zlspm&VM9KdMIe=+**7(?7uM(u7h z8uN8VWA?$=N>UoKnJ-pci`Mj^lEbK738Qv(jM_CYYS+l9-5ZSB#n5Z_g2k<3)b0?Y zc9$5nyUD0sBcpcr8MRx&s9j%r?Ow3BBzaR&Y?hZKbVcxWr<-VhaHP3AhM4qLE}Oon zzv?mpVM<3r!e(N)b~M zO!kt_ij1P$AJ=6veWh8(Xm*CI_ZY>1X$nRoafueTrOSHq0>s1^+l=lAWqiei3QQ1N z&nO`iDlkB7Qb&dj@zY}pM}Z%q?`WO!eFyqeE~FW45We@U{l#3!Ke?gk(-VVTWKs}} zubxsoMrZx)GB5VS??7z~MtfuB{3dAx{CA2Z`v^r2h87}YRbaYtW0LIcvDdY*cP zdeC)YHJBhM7I8^$F^BFGdxX&%Pf^w9DPA0-c=?Rt)i4@xHlqPk84bu#gzsOb_ZO2M zWQzNdA&#kcOmjzXLnc~dKE>2lMuU_x6gdUs>WN-}KXJi*=mUVsgdAe8&tP1V-&1X-uriiWV^Xj;VJy z=;B?QbE`8g5{p5!)iPZh=GJd!(9(f_oc!~wdgiXM z#0%g5s&0Mv?TLT>e&|ljh*jNBJiK-Mbl;VWFV5(<@%?)>?~gs#cx*HOZ157ovQFcj+(dJ$d!-&-fX;|EC20#a|(Zj{CI^qR}>BFV4Q<@3((h{pGj&ckcW$cmMG(j|~=kjXi^>pLNf_)LwX~bB1uATd@J2 z41K4)Fz4mL;+5my#bx^zLdXf1a=OYn-J&@W37p7LobKtIs0EzpQ6XilC{=VNq~8YVM?rc!WcaVmkYO*# za6V+%02y9{1Z9vQ4-%-7Ai;axBjzXtb2Wnbdci`gV3AF*xK^-Kp^Cl+DQh654N|U# zlyUDvU1FeWOCYO1A*=h4;65b4K6Wei9tH`vLV^THum%#eD`dIzhDv$EAY7L@J8j~vi6ax=pj>$I%_qzE&(3ThsP&<&|WA=grcs8$6tm76Ci;N5||-D z$nd~`(RsWv+jwJ(c;m`=<4^G>T;)xKthyD3U^!5i;eqT~a2r_|g2SyAZs)*lRAGn` zZVhldcU$qnw3C3rzt5XnC_l zm$4St#a_4dUfo2o!NvBVqxrHl1&>H(3M}XJDU0iq%jr|TAQz(N76}Cl6CgNI@VbtJ z!q)H{|8j4yYZk9MMT;4V7MB{l8q{9h#HA=s6M`R$wZUnM<-PWqdR4^r*u&|)&(u4M z)2qBeGEXB|lrD%uk(GideU1*IEFnS@qE;JuknMbZ&O~9(EGZE!m}dyP(cpEX*6WVV z>zaa-QicLZ$uLp~){{ZVYN-LjIr%3naXrf7dgKv*i%O!`#YLMrz1Ed<0ipMey{b9A z3Q9zWrsU{G<;>QZAbxWJL~%k8qBY0==UKs3#(bP(2N(o})Up_ikNUlt=^LAuUURpZ>2LpKdIa*R zafrrErdo_ea$hxA`iWj;t?tfkQ-;gXfP2z-g5~C8pVy=34Uyb6Fod7r zL9e;jXnk;*-Uk}g2dM>#{lgO#xl$=6G_w4!Rch(oKp%T8Y5jou2Ku4mV_qCq?`Yw_ zSq?p;9JPP19M9j+L7i#+yypr1JnpZ?Bx7F7Gjwuu=<<_n2{S$GFVi!}hN`2>IAUM~ zW?FlPXHTnoYT7n11|C1=oO;iGRPS*ZbLTpcZEOHTm-&%v8zXFl=Kd(#{l2sM%&YTl z)zv@UP+J8b6E?sx_R2nXi^C`-(BdvwDAW4iuvc>0*x2G)ZmMtpYHU3J*gV&oCV%Pk z$`6Hhe0;z4p4&jP6<6rCf&g{twF7d7f_Vt#3&5i5`Af272gB?NwcXo*quGz#Pepl|3;Y#m}Oe@KcP%|W|E>;GkoV*Z04~ z2I~HFY_RUwKZOl&Mmfy?SUDQv=1Wz|5!OJjB?cfy>dMNyg1&IHZMj# z@4Wc@@;8fU^1Hrx`R6=n^1Hlv`9XB~VJ}|(cz0U;JG^-Lo84&g+rN1E=Y%x*p)X#3 z5M4h0yz-fAAj1QS|DbU7Q@OxARx{HZ{%mIcf0>@S&Qjf{xH<<9@w7GBAJ3jv)r_?E z-CIe>Ef)AHcm^W z-7oIuOnWHi2&7dPha0<1TErl_NXGHuJ zu?BimCDdnH`OZC+LaM6`tnmLx6nci@tnfW&}z2SN@^KWK)!?kAS z|Ci~ZyPg+Q`Pg&HXJEx5)fpGuL($sUBe#qkve5-dK+ z*ZKT*W44AaU-RPSE9mkSFJ3-Imyf-8`4wGh{a?}bh4;UeF5mj%tY2wD$eY$hR43?faUMZ!^-`|NXRl8SC0p zhJSOJu^f830Gn~~POtr__~d)fy3`!4y0A^VOlE$(9D z8&u6mYk#j9`8Ff1eRDJNZAM!Are@^(>}ebA8=pDfMpZM?+TU$PzRgH$-_VSFn~~O@ z*Nl9hn)dnT=eFR3$XB!{=Lg3PF+Z@}<8dJhTP4?le~k~E_9-7i=yKVsT_@@Z+>rmvs#}(j{atJrEJC&rx5YSY+2Y#OE*9_!YErMLt`(qKEzW%e6U| zR<5p3ESH7WUf)D#k8dnpHpoFlKCQjl(faWbzn)}Pu6M1X%TYK~4{vp@X}lwU4N31$0}hYx;uQA$LVj!siyqsXYzreDUsU{r`(1gR4Pz_HR)n}rIa2q>4T}!{056_jVNdqtL@)28UO4^UAQ@ z+=EKq0C6nkCv6z=oi45dDqLcK3iCm%l0PF$w43{_*?_XN7y{dshN-$jlxU6~^@lH% zedHKLOa(upo{WPA+}V1JSzlS1Hr*k-)!nAU@eZpuL>YLIN&jD=8StK+KdtR zN{p|~D`idNtEZ6~U&RHi)m#qcTeb~3sEgA>J8*_l9e-2UEm7lW#Y3S-pcSDJvecqL z@n8jPP=}Q)+8>2%&1=qD*Cr6P=1|R&m@u5E@@Qqk_mb0TtdQQrQ$K?7u`%6VTLmY# z(*4SqID6B(3a&m}E7YUiFeH9@%ox^xpEkzG3d+SQ4ygy=*r6doy?EG*9>agB_t)CK z^3gw}_fgPxj~^?}m`P%+F!obC)fR;q$brC;7F31Th}&uv-zg=q!H<}OK8q;#kvakr zzVr!C)QFOG+z<)5kp*!|kF{TUX)5t3Gz{+e_kMUGnF^LaJ%tRGM|uH64C#gVQ>Ho) zxSHfzRBnb7xQ_RcC6epdujjxZk8mwidOMYVHHjyllpgJMguPjg{Pj=FpJsm_*-t{{ zLCSQ47f!{;_^Fs+ zI{XzAe{sPg&22-E0^(o!9oC87)u!koL+>I}U|tFQK#0`Bk`=&T74g>~yJk2hxNPI2 ze@%O!u0}viN%Zq8tp6COfAxQ`{teGjf7?G; z|JR}Z5;W62P5C)M)yT_9+3r)Z@Z7leoChxWK<0t5C7n>^2e{$M478 zd-$*~@eX~)7Kjf?SwbT&bSDxwM)k*xc-a!-rU4dEV84ldjyJ18$CE(EWhx3AD3ZR{ z#fg{E69BJ| z`1+Jhx&psZFo%pvReYz{+5H4#?H~A!Tf{oFURO$PrqrtjF0Kb<16K$I*qZfU!lTdq7{$BO~CVqk`XVnY15!!%;ZfW(1`A43K@P(h zT2(##4k!{Ea-}kemSY91TdNiC=6(-oFv2)I4ePgAC802Sq@SR~(=J13NWz2dLNHnC z`PcS{rYI8e6^_RXeB-v+!weBnZ;N;Z0gqYItLuWG2^NR}fcpATf=bt}0X%t}Dn;+k z?%|1Q&tzR`TQw!HZy^X&wFZHvxlVz1+X~`jSP zLFm7b&1t2uOeL9cdY4uC8={)v0+s@tGbB+zHy&?s0 z1FbQg{SKNLJHS}$K30k8lUmzA2 ztxX!Mvn%f1>5~An`BFP}pmijquUg_$i4d?td2_%D0K}yem)%CWiPZOpUV^=GKjwrG zh?yf4XC9|dR$Q+NQ(!f7R=vhLCb(te-=fxm_O+m0Xa=>2lXdP6?Bd`yh&d6>ljweI zw(?zgAucVB5oa(^WkBC=h5gC`o)l{*8X9jU#}j@$SsgEC?jt6iM}M?9ZLc8#-|l&% zb&W$dN=L=D3(;Fi>pPeGNofRf(x|W z%6%Lf3O8|Z!|@Y@3mqz-xhplUgK zJJ-j5H@D1zUw_Li_rZ}?2@g5!2P;8a;%R6F8nxVqR5OFjdIYSG^TsorFh%ig zDA59Ow`qiVK?`kGaH%M;+$VwK=gx9F1`n$96)W+{Xp{~e~z8d>3Nm2bxn0MtMVj<>n8fi_Km*TXfauBIuF|&P+Qy+eJi%`ncL!eD3)5fAz z?Wgee!d`eAv4nSuN|(^>&!Lw=Tko0uTR}K5ny(P&20B;pDKATm%_y`=vq&x}0dz%~8yUqgdiED9{K>AhJ6e&R841+8?`(dhvk_#EZq>Mw2&pHppw z%Kz9lg3qaO-xNbmaZQ_}VceHE#{ESY_vD!c#r{P5lxCVb$sK#WAKyOtPHyLM?epI* zpZ#x^ubybX>0%G)2;^LEYCp7=BF0!em@IazbF|F@l7WC!1Wx@zcCE9ZmqUpgh8zZn z0h%ADu$?Ij87S6;3GjI$an z-z@HdK|K*lk{R=1D@++YE0a+KRYlS%2;>Lw6}N>VW_}nBbGZY2y_0l4%#PTp<PE&6#Yz1lJtbC|L1n-?$VBrFwSm-1M z>8WHH*!npLlvMlGNKv{5NEbJmH(wwa%L+U8RQI=x4c(-V5D^b+C*w*lFT&rqX1FPI8lM|F<%P-0HI-}!s z+{6z5uGRpxx47ryXbUs=i!x=g{&=cB7yJ@19*CJt z_v$Wa-HKRf-AxA6TqT{jyG7|1zRGX66;)Lg063T(pm+@`!v}0^;ji>B_t|b&dVJ(z z2}{_bcAJU|Mj`!2+!B9+64xgME@T<7B^g(FTz2E$ovp;e+6#Mue9X)!LT0-AL@?7c zis9`q2rTWH+>)!G(voZMqeO&Cn@A-=zIBJM?sP_|%=uB_j63kq&j%r^!i?KI)5I$S z6x94NanJ@tTndjr3_TJb?^$sFz`A#;Mabt>@t>4}<%x$(bRxRRMhkLxMl$wj9E%Bpa&j<;q^plIlSXowx<2*_&j$tFf znEG;GC-f5erf!ELOf#~9#S2urhZ1Uvg<@OP50KG>t97Q@$T}cZ7qbq?)RL`gJX1Zp z>D#`U>KR+r;&#)HZt8*l#|b{@@7`}HhcLPDv>NO$x*8v{GEtTso={ju*y0k;jV)oo z+WonutSk?&)S{4Dn`XT#+N9wK-;Y*f;+LZ>b=kEX(`bH*(2NON$pEH^`5GX@Qtf>4 z%FMiWa(>4cLxRX1E9MEh)`+bv7!J|_1uSlHP@3DDquHs}*jdX${XR3ixvIy=x`1ub znz~VN6W4O(u%dKK2&f|fnG_?h#4|g<42rdjyPNwRtaXClTvNd!3t;s;s#$NC#UADm zvQHI&G+ODYEUHz8nK|N7NQ0tRAPKx_$itiL7Ks~kyra;|B7&>ed8-M>@IYe#oVdu3gNORk6 zaPMbL$cOSP4Q3=RIdm%>-C~Ew=#`2*1HrF9n@Ga)|O(g9)!*vIu@+RZf=@Fo!`hw<4=Y% zG;sRUUK7=@K)X@l1q(_&KT}-pv(>I#78{tZ=Fe8BZyE#vn+!6r-+K-Gwdl3Tz}Mme zqUi8D!!U$AFXL7=SxjLM=+?Nv?Z`x}%7xV`vm?b{7vqi|ZmxAiKybdAy(d-_1?JOqf5c~8C38pV7!)ivIoLw zGzHZ_sdVu5o3Ec@4TF|2p^rJi_3*I;J)K+2nkDe7F^$xsyRr#ats@wnIUm z*N;{X(%(Eai*+&iz{Sv*j9+=?;eKzeYTR4b3Li2eEau{cgX8qkBX%~LCvMzACR#FM zEuT#Em0tjv=u-9s3qH6|6m#)p_BhA8s`lxQcb7h-zCD{peRgFp_3r1pQyl4j?LX1+ z?!sQ`vtKSxA;o{Qahu8*alC7z`NG7NcONEn^B&3xTgr*}CN|!4Ct9ASje|8Q3ovlT zEuL}d*lms)%{;*4p^7c%8&?G`GSr@JIZ=&6wOXd5mrL^|pqG>DiI=klK`ms0y}^(&k*}}iE0~T_vf4wgyAkf9W;zL3#C>ek_9{=UKhf0>aB))?rHr!Gy z#LgY*=b`U#&9r7uT(rd$SWwcd$R3?bt%CJVdf7nG%XZ=~E9jbmNM90N#Ok?~mKdd{ zh-gGg6Y2FVAf?ybE~!N8=*4)8>D0fhD2`I_!*DJyxp$<}ofq#piQD6=EnCk#(EfCR zLlR_0s{+wo=A}@fqVxB=_pf*m^6$IFQv3Geo=2dTi*}sY71dw)wN~`{BrZ%5zjVp< zwib?g3)%xc_7K{zoA)iJf~y8z#MemduNEzJ;Ph69OXyhQ>J3AVgG`+OWLm>0nX2BAnTm2B1w}C3tgi;P z&ITwYt4$aMS^Itf`p!bsRNN&BfL^a|UGCGHOi^)NB8MB0c1(Cw7gr!C4&x;A9j+n0 zYP10y%D6ko6%VBRMcNVWySY`4MTmAXx`A588#ApogkZ&kqV{k*0(O$LC)<%F-Q(90 zTKW^z(r*K=CPhnR|4r&nmf|pbWQqngL#i(JGc{y&60%A*MkXrx0uP`TJEXgPr1#ro z9I103dZ5EN-u8k>E4yV!Zg`@_Gg&snF!_w~0ISugj{ui{qKE9ywA^IH1GRadj}XeE z!Ksx4q^PFFc5&N->9JUDd!c>0sguw)9o*ypa$_mp#kq&%W$OcD^IICVoHjV$wy6-a zhpFbRGEV=}Akl?LyW{LB*iXB7oe#IHW%EwY=5%a`RodLFs-Ws*Xk7Bh>(NS6OCYI**mEp~jI-4eNTeNV<*;f(- z+YLdVWzV(wu(w?d`YdaKtt(qPqlmLs2pHuTkDuCR<+W>sS=rw0tGetq?O-ZmiHdP~ zYH*SG_1d2}YidUu71C|dTt#ipOzTCZ4f_KUb43NeaF=bvTi@U>lZTx02-wbEV)2Nc zQKb5Exgu|}*c2{2+!`{L!jv;a%n2nY)r;1(Z=p%mcxm}tQqyQ-SIAe`6<_t5)fif4 zi+Io5#Ac}!7{~kP<$$4HLKC{^C@u|Gh{aq^*V+%XdF_-YZw?0%;hYI=%%j;6(rvZE zZsMa>uiF)M+4?YeB*e2CX8uu-eI$ID-qsV3l!nVu4hr?oUD6rCqivuWPgy;U-X@fj z($vs3yCAQf&g8uBDcp(f4el#7oP5up3@w6Y2v4D>&TG8jDWZ@j>7=&+b#zAU<#iDX z8I43lC_mqjo|>{=>D6E4#pT*D&x*I&Z`wf4y({&c5U(`1s{k^*_#LBFUKOccKNX|Q zueK3=7&v`Pdn$8I(5H}j?aPpPzSlKp=H@OX0cmbGG^MplrC3yff4%apymm#XoZC2d zub(*HkX0kfYC+0N&aA?JMp<=1@MS-C2&T?NX@#tCWBU*wJo zFFHCX0^nFBjFjv?wI(XD?VD^*=+NJr?u9QC_t1@Zm9?f6&R5VGcqE9lw}=m+T3@%j z8d`rOj0OJ(n8exoKAkGJM5MXJY8|z*mWm1$D@GgD(0k^-Vx@MyQUa|NBC*Ce&br(H zbJRN;1}9v&sceobJhLIi=$d7f)G=^s?l5* zAVV(`H}0^9--CH|&&PqX4e&wI*jd0@-||8`?FApS$Xn6a@6!>%p zL3IB+AKo4b{6y6TWoJ%0yG_L=-%b41|GtBP>)X$Kb3m!m&G%vxaEM9ymFMnAI`%7I zxBE}C#K!<;xL7qZL)n-)c{&A>TBXXy~94h5>=-T{_9)TdH{Xh4*#>uG8|$K!~gd4{7Vj6FwvsK?db z3U4ZECCVi+(cbbCNKljLy{pD0IWdwQGl8J2Kg_G7OIgMUAMe5QmZ)5E8GZvq_H_BM z3d)i44R26t`z-!tHAr{AW2F`hVB|{Y=Qesm*)?kX2Mrl?IU@x|}Xu)1^Gdvih#A3x0(R#QN zt%nQ1_n<#MM#C*uVoM7(ET4LXdy3ddgtf>aW4N?9hNb*3CIf%v|6;QL#bp1B$^KDH z=5vrkp@IJ8F98}TK)A6R5MHG=ghI-!?+8=YW1js*B7h8f(d!?})7&ZmCN|d<%f1#?)O%A*a&sq@^kA zZICmGq)LCr52AYj14hrR?(&ynGK)ge(NhadJ>V%4vZQYBoE>QV0u?ga?1U`%{`4&y zmIDhy;k1O$-_@lfoCdkEy|SY7+F4Db**yVsaC@lN_KJ#2c| zrii!)Pkrzl_NdBhCtOAq z#kbpuiv&)Cu6>}ymW^f)h1SYa>wm{hI{d$j$#y&oCJW3&Knw*3tTHxW3r7* zOqNDqvW)~L^CvLbWjZDs`#hLzllnIYUhRtulevRRyO7cE=bjC#{T)m;PVEBJ)-izE z|8z{|2e=6^SsB7)r3jM^A~2bl!erdOV>M!v7+bj4hH=|S8>RykyI%<082)ujaV)m= zBs`))`0Ud{n=s3OJ(vnEh%JyQTv{B%mJ@hvE5VTsg3li1$okB^jbg3--{;7FJ~DF> z_j>DH(|`fHRygLn^HMwTVSgYdP-ujpF~z3^Lw+rmvhHPvUtK|RWKM2O60pw6jcu18 zZY+(-joqhnWA`7&jj0AX1Gk#EF;-8{X#^KWMKksEo{m4ljcJC@1a7R9$&FPEDGQ`= zW2#>5bdPXj_94fB8`ChkF+=ajIdpF9R(-{z+?eKV;KuC13~uZuL5DfGv75QvMT*=_Z zS^+m^0B-C>9z1Z=T$6*FH#&+Vc_~0VT!rGrqtDnY=VNKkaaWTn_?WDM|PaeaK zS)R^~xwHKpEEc}=J(&HMAU>+RR^*)h&vjgGXp5kigGbSPF~}$f$U{+N)O?R0CMz@4az^Jl=j?=c`-8`=Ahq0%bCfJ!qSLsh3yX|>;E^8pe& zskGW}gf9{IPAaX{vAS%Gq|zP$uh|F`otM{r+g0>>lhfm>$!PjFqCQa)mLxW0+n0FC zdK@ukEXlgQ@}33W+w@%kfSM7&6E7PobnjvTUUnJcWqBPW*`DUJ3clu)zF>t0Q^!;0nGAqHWqU=6TGPDfcx%BK|VGI8u@B>Pr4B|!exC6c;12=NCeUsE%9K9F}g1I3_~ z!Jx%lZlM>CC-+Q9LV-8v*3rp3Q+0BfQv&mpgB7vQw1S z@GR-lc+omV$dF$<#aD~pB8ad7Yn^cp3R^9zf$&D?dXSnq^Ag{f1e}`>$>#+wHW-xz zS5$SEfPR66z{5pmY~{NForc(_Qhrv_$-N&0v|R%X6qKL`!iksCW~dW@)QhsBsvWq9 ztz8jOMMT0vSJ#&hKnXlv*bY5=A$yulDN)2LyWKNSyzmGr?Bv3n5e(kT zE*~PMF6B?2ureV&;`CRnRtoQsq2Cb&&M~^y+`MX2larGq@KXOU-v~YEVB0cUii^AT zBLmUU{;hK>Tu@~49|<=?yHB*r5C;b>Fgd_M%`M6JgreqVUML{~ejVY%&8tUzv8;?m+67D2?SOd+r$+K&TR5nO~qBb%`w%@G6J5Y^xEnY_c)N!+jc z9&Rfog7q(65ZWLT|FfuXR(S7OTK3s~1qG_?6Fx1xw!+&oQ)+60UFkbcTb%`5vqT2iRC}D~kJ3Zbfw>w>z<~iwu7+OiqPx4RQ-0B`$YeRE=0qAuiGlSJA5k zf&wTX-EA)w<%&!g_@fnq-2G8N0Sam<4`V^3*>WShGgXLAllkCkB;sF?tC5s@F;z5) z7PW7#QdyiaYU^6I=vgD)g9@NbrlQ>N;TeDf_n{Bthf$TW44F~^LBBzJXISpV_-u@>P ziF#^FKv~r+@7>%}(6*zBM@$v_K-q4qq=TZ$R5FL>6ys0&0l8D<2r_=?jP(CKbwNp( zy}P4q)}%uN%6+DTM|HK^8OmooW(D|B)d(xv=2GJAeLVxE^ zz?R%e@IxtXDcC{>vi-I?OcYfSD~1R;R)Dj1^V@kq?K-2{L9m>{;h7c3ebE5u^R=e? zBFqEbL*q}zA?$^PBIAnvm6IHC@cm$nw};ra%*CPSDQoUidEXfNr@MD#=|!;G8)gCr z(AaJe;lFX3(@`@JN#KvXH5-u+XWv42d+zKTcFVmzE8y+svuNMKT~9S-fE*r%4<|tI zTqGpeKNdDtZ&ERKD4Z+3Lj{76HZu8-GVM5*9QLlNIp~rs^N+HHuE~m4Z7Op~1}pI$ ztt8--y4`1#OQtciUm`a>IJ`Yv)$1X#QOIZ z+Tvgv*I8kPqnh`Qv4~1pT;tIC>Z}m#s0LQPlN7M6VG_LIKwjw5;uw}d*flaX7KM*A z#l|-6GrBlg!6qUighp5b!e*lTE4;UMdgaYCQp1f^UETQ9A`vsvmyzWiw^!{`^e;F$ zwezA$-1@+S)~Nmsu&8_hd5rnAu+bw^%_XfU3bde%!5u18miWCjQi zupz)Nt7#QnjaI=^(AsQ}FQsf|-*A3On4z0P-SHtjf7pn06f{i{JA80&M=|uOL4{y( z8c>)>gOaL<{daL`m@2|ik%E|h;fZ3jl_*p>xdxf7OeCsOQtqSP&8=4QKJr)-LG4Yd zPeyx_=2PoLQt@qT!>hpzMmoxZB?+{!uqIJ4EdHPvYKnH8oO(<0@`p`1P zX^8rqRt<;esO?7fKIaUR&^^>C)d(SeM<7PmZ&SgSoL5f_ggljCf6iG13&8#myB(xk zN1!$AN2qM)TT$(#U#Iol&DBs%jtx81uiQtEEYgDFSI;M+vY~t$1WTOxl#b{G`FxJ@ z(c)nA91$_E_@76?y;N_>=e|)IW|yMgk_uC(c|+3_ziUzXJLjLyk<&GbM2+*|&i#;) z)j%MfOvR7&uMV&pq20ekB42BoE>shL06SaYf!dBg#c5|Nl6Sg)iNb%WqC`wCFuq*w zvkleCYOsKvordh}RBQtZ!$;>NkY_7U`xCxEu=$IN6q4%#M|(t9el8vw7lu6Er_qM2 ziKV3ZDkL{p-7^nG_ApQfmak|vygd@Q2;T0`K;|xR4!r%+4v%GAV7nNm)PoS~AGm}7 zy*GIPbqm4!YU>eCcbQjZxLUZvRW+^_4h-Q2*JNM_54*@)_eYt*HQC-(qn4;v-!ms& zSmu(tK6F000Vb|9xvd47ityh$xI)e6U>_}r_*;O8zdeOG0q8Kz4jw;T$kL7R5Giq^ zeGlRX^(P?N-KhqakQGi?$rH#*9tCr`9yXrKns$L`aSVHxu#W*Tz>%@i;gsGPx-5)~LRgTrHI0=67Qne0Fw$jgYNMBJ*kmxSn~-spCsLZO7{Y-g zJhABMO^#cmh#47$S}#GL1CLSL37806QW@4~?cKmhN`| zWvW~@cV3T3pl|Nn;~W05b93k0lfR`4Pd0uVxUOS3Mg0taL60KBXfqiq$ortH&i zLiV*NdkcRmWKTv1-mZwQ1_l&1q&V)yJx-x}G#CPeQx%Cevy*e1{{l$szMJik{l{XDiW~nH!fH zYO6j>N36{;G)bU6C^moQBoXdv)?tz;ak1yAf0PYOr|>V07f9WQ2%rzi^JC`1WHdLU z4#~N@(3v_h!O8EWUWU{^L--i8y!XnE5G7)XO2|#g8Llx#dso#T90v=IrJWdvJGBPM zLdU5Pb4vicdI*)m2WU#UjADJxj?r=T3h^=reIP;;3mjOogEtV9WGSSAiD(Nl?2&v+ z(bhZ*Q>lpQe^I_el`><|5&WB}hI((qX zhh^yDIjsRq8WrDd=;2MqdxNoSI*m&CI3HF4(}!j6`4}IT>Ty0SmWvOo;c-5!n8*0A zeq2Y3@BXnG-5`sH;=5gZSggnSur&RU+lrjOGBrf7H`Eg1yPZC)0v8`v!Q*^bd18hS zOB1_;>BIWX{rP=Zmgn|i(c`yB|FI8C{CFQ$jf)Qpe}WIIw)C++tT$YISZ_Su zhoyPE59_9j59{XReOM`v^UvTx zsPr!EPl?zbw`F={j~gNA}ZVwkM&;9nN870l=BMP zy&s%czq>fEOn=1k!)eYd(;u&xiF=ng@WUb$2j1koBI35bQuKl#`xsJ!>E-eY&3Y?} zR7z>VjR9GQ+vH7dF!dFtxjjGuU%bA$->qk?{lFEKfLu}7GPThCWK}}Ka0w9_K57rKu6>73 zrn%FH#0OmBILdovU%Da|ScrZOPZTbWs%8d3-#zAQiX1WH)-Bc7LTZ`^Y_Km@)h>@IQ)zQq~T z-Fb>c%I`crax$T#C~I8T@<5UP>BMJAN+&+c?0M1!Z{7%PHOB?5or{?*B3s$vn*?m= zPcprZLuJ;k$}2L2-%Ge&)qR?dh@$r-{6nzC4bDq(#BLn-7kmGYrrMKF1MIP_`1!Q~lr| zmRR-acDuR7u>9a=Vz4&SiHlgeV*5yaKVIGhIF&*;i*+d2uBDbp?FxY(7!M0T5&SBl zWBXYm=Of}|7K#eZ{4Pv|{6r>Yb=?8SD%Spr^F$|&j9NJxq7C+jBR)8_ioGb|b77(H zaumy;4ao6=m<5ZRdJ=Zal|f1;|9U09Z7Oj#6yjg0vPKr7g+6g+L6wuvyA}}vTV87^ zPNKlK4%|WpbmT#91;s{Qm7uc=W}{_n3kj97vf5c~I!ZuvA$Zt1_s^pYA;7?K^e-dZCbCHinX4!!rNT;5}`lYNtLPwdZrZsb9>Ht@BNw$-9tU6v!lo*u}1Og$57AhhzM9Pk*cS4 zSV_Qr@a0IU7TYGV^R2|m3x2XXw~4$NxZge}Gk(<%{72kfsEuh%t!mJ>stypZHHpep zJbaGEokrHG4MDcm#dcsBK32e^sd~QsWA>r*9d*|nwPH>A5o%>aY=D}d<|G@gt?MX; zWxPLCwrkZLC6F!bIoHMe3xOokBQ`+Cua=6Uir77XVVGj&+mEmhod|<)C@cEX!WRah zgJVErB3Rq4}Fyxo)LLHG>7@g-0PuV zFh4209(slO$(`$=t(XlHc_VZH^OLzZLJ2$V@a0z^?zSFrx3y@Y{UbWi>lU)dGm$<1 z0&K8)EZAT@a)E6}-0h%gc0a6GEzVhuzg4MN<*QnoJTxCY_0<@ihzDmcjDrXD+rgJ_ zMaS9@mllV8>F~~%t0vC3LVmSO_1$d;{t)nF`#zb9PKns#Mq2PCIZda{<|GBhG_kwj z`hSpJf~@NOK+^q%%PUl0LSA3!u633z@;G#&pn_7k-W`Bm@`YiGb_^5J=x-NKzBkU6wGuB8lVWTf%7) z5m1i5GseymV{8mz!ofJ{1C_lEoI#&_zxVHpVgU$NIu7k?4CFgbBZCug z9=xd}&+WtI2(OlXqjh-tIYH|@9H%H;AmJln=(bFwj)iDCZ|pV~(BUvJRZgFMQ178l zKKq7br_X+FKjs_$forA7t|9oCNbWxL4v5wwIDzaPaajFtNhNS_J83e{^@YeBhwG_E{6;YfL7ERIuW772K6%8{ zcL0IqBZs9Zosc!rc}#;Xg2G2Y+r@Qnf{%H_Rq0OGeY%BuR5kbjz&QnSk|WnWY1~=8 zI&i>SPkCMns$$`12}t~ytdb@yo`Lxlu_O{34-USP)ez@O5m65bL4HEYU8Weh6o6M2 z0$vDUlHm;akhu`44rhsy&2%wUP#VUH1Vp<=LHYize#pQ6A!;6UegZii+#eo{116$; z{|!m6TOba>aehLLTChk_-c>VNw+J1;bsAbnbMS3A7_>8tH>OYeb5Mwrd=#RYaTwWN z$7%4CV>Tw`JbZek)lN!3lEx+7q{99(C6AtupEd!l7p?j8ixhyh$0bUR~)&k1z;SRie4WJ4A)XoYl1m{!wQ^T-%LbmrY-V`W+gy#U>V6`f(m zTfY}U9M10#t7@c%A#b968QbCP5^zpa%5g6ll;c&pUXJgBq5vHb;?7p5Uh3j3p1Vb* zEyo?)3Fo;@mM!u>w9*lyky{PHepzU3Bk4~Ky^3CDINq01*V{wj$17xjMTQC~bn*kN zpBy0`$VVQx)c7;Rg46rIu*nFoBAYQcP#qiIz0hYjEvo%L)AWkE*( zlZMQ?HW*el#~sl~te)3lFNwoL?B}kc9^H2wopQ1RY3T{4mgZHrZd($@qJjf8!>m@& zQZmRevPLJ!2Q5t;0;K(U%2o~D!Olnb&>y;3YT1X6`5M|x9|vs)$LdjLq`q3er#MXIT=vyVRu`dr z1T0_DQ+U)qWkm@DYM}$P&~baz!;RGyrNv<^=en#>QdfrV;o&T)aFUwUmKwu5-Ui32 zop7uYp@n3RDM!Zj+CyL1S^Vm}Zon!TxoC&-jSvlB1>fow{<}Ipc_40o&{u~2$_#Oy zbm#Da5<%x@=ev4Oug_Ny#>Q#bX8W9_Cic=h5@+apf859}P8TDW{e*uRvM9CZZV2yv%xeAR* zLALeh6BE%gGH?kjH}@YSe~kp;5W>eP{ui!SN-s}SL8wfFW39!3$LkDoe zrNuGq61@6y#}XI@O!cGo^y>B=hb|CQme3JoI3ReAfo1u&BpMEI&*Jp9nR@Tz#1t8N zRG84;a)V6I^0%=9;TS1&^73UH|1t${qI10~esDjiQ40KoaH#`jGudiGN`(e6v#EfhYY6>iz=YUbEMuu1o;Y?)BYloQDLCO=3q8fU6uZ43$N1?tUR;6h!H|t(cZX9%lXdef!qE_12jeUZ{?Avt=*MxU zb?|ht4pocAE}!_$fWtWJtD;=B`-klDDB5B#DBA^e(q`u3mM%}O%_u>d*q+$wQU-6Y zyyFUZM$Z8&FpmIN@Owj0UKZ#mgaiQ3-TxN65V*z%USWM|LF>~!CyZHguTU(@j>u}w zNcdd(DetgY%$|)cTRj|t*~)+i(F_JQ`&|VWgO~$-8ffJqC%}Qcv^a(hATUk_!DYkI zC^VfU;|(BC?b-wRT1t(E0$tMB0_YOugV}*byVMQ=!GEss-nMky31F)gYc79fERoum zznkjeHHm9o^l3M=8RF-&ziSlz^=IneGuery#e)?n(_c|s9W^3vQ}l;Z?}B;{ZqmDk zc&FYu#)T?oC`91+!9aBW2Gpxbr(x>?sdwZ1OmVIvYBU;%j$aL@FUGbk?}`u9okPrG zP{5$5>EmFMe-EkGW`}y&_`l79Lu@5L4Z8b|LkHTLeBr%^@3;UNd~V^VE6yolW`hIB zukZo31T>vp>ri-_b047<61;1MZHPkD2#FK&*4+i@0>d19?NoJ|Jo{KbDuBq4xDYl) zz;6+dj}=u90Bxy+F|LeFHmInkYdX5EBpwz@ZEnCL@x{Q+Q`SQ-S=|UWp*zUj zU>#AAz+0W{>rUB3@vGax7!;xPjcz<;4D`b|5p-Kf7_i!Rl4DLXQ(t0#hapV1`gW?^??IA zeJoDvGDsCq+(S#74GJ)oTb&03e~v>~B*HDR{?b%k9Xf_y6%D?H5`U{o{fz>aWaK6X zJPk(H`5rK`E|U0cWM>NM6ZEGkPVdF*|a0UeCR-8|T5>#t|-CC!~mW{x^yb-f&EGmaqZ$aY_YXZ9xut#75qT4=+5chhP2ifky0!5HLKp-Lh>8Ll6cnc# zR1_yn01-toD&njGQE>pHb!x4e04lavu})aAjg~r8w6ubv#Wq1y+G0y9Di*DAwspda ziq_$K);T#R1NOf6zW4hN-}k%yy|3h)ob2iBz4qGATI*Slc1@8FEgW96Lm2$aw<($8 zvdF&=KWi>&a_{`+WPGuXf-NWuRD;k7m=9OlnpF?4%PnK_P{6$%CfqaqDQOhaOpKv0 zM93YC%^;CH7&J^+)ggd+OcQ9gR{ixc!I=XQkQ6r zO5#!)`ui-choDq`yUH#?kfvqeylgV^k-pLoqEklUU>JD(mmn1g&f#wp(L_dXA>ODr zo&zs?YoDK*_7ZlvwlfcOg7%TSXYe5-z?Gwi`;gcx{&s(3C^dVA-R}*V$cR6d?LGqb zLA!7i%ge$Cg_xwnusC>Gc&Sf+vn;$L;$ERVn$r$6DGlIFD`XU!2Mus>o zt)6VfT~{12x|u|hq%TY!8q&c1eYutWA_FJL@u=Ax0wIgBUv%1Rf;$+QRLwDwb5&); z;mG(N)}$Fg5x(VCZULYeJ=d#w0L9RzY2?V$j$EA`Mw2c=Sn_du-n3=xA%`x)qbQmt z9&5-dhXuV33JR->6mJA=u4oQPuwQvyWaAm-CTJ}j+M$fimPmArE#ee4x|?j zGeE15l2&11iR;720p1dKN92J z6K&(1V0Is2``F=i7t?Ik|8>@OW2uAeA)`z*tv!&SISIZa+s&jdW5I@|(MiRM1I?sD zspO=x@T^^gw??Df=P*#CNi4iJ^9s`2^5~TQt(Ls_J9mEHh|FFK8NF?MRIVBfy z`9}i3T=%e7l;n$7;}CN6$FSXd5ptA%WdJT_cKq=Os9qn<>cZwel43=^Y0+`B@UQEb zOG?W%x4rLR^P0QggIjaqgB2US_)~V+Xv|$wLdr~uC=r)LPWfD160>bB7!X(CX|h*W zK3oSwBe!y1H0gO!vH+7%6JR<)FG$>Ie)i?MZZSI7jq0I0yhUvSj=QJ}`+w__6Ing%+3ffuc#S=Ksz))}zg zn*%9jv610eN7g#M#2kWY*6B5|5pUy3PNx_4@o;aT(`$!)kJTr)*69^}-G0TBNgOP5 zFLLM{gqTA0C<|0IjpMLU{ir3-9d6S>B6$M0{KEX^X@hgaY^{+&NPF2j zjGqmg-^>OeMbP*b+sELABRR(@6ee6^ov$M=h@}2zK}0rj+JBkY#7u$Yz1c#clOcty zwTEyB(UVVfC>>5R&)GQ$?GmIe*KY4axSPOk1k7p38DQLxjkOj6K(mb!2?SxQ^+uWW zUxHL1xPiY-M2Umum1KfJ$D_TKO+TPsDBX zZBswjVDlpXObJU0Epp{TDEADaY3J=d#f%-01>)adU$Bc{r?0cx>Bo_tUi1eA-m}t3q9YD6m}+Tqd1t<_)kkoU5-)qVA8ft;_^0mVEtGwM`x@sW;r@a zMfUu1Y~}?@`J^J9lSUyxh9`ljH(5eb!+{6t5U!v?A^0yrDiBO#_Gs4UEfDwq+2^fj zR?&aL=WRyxJ|jHZs=sWs3g*T;KRtW^&AI4{=SEN|Y+2L=j&O$e@ArPT)+6)T>raWV z{^wcyO<4@_GbY)#DOnUp{Z77aTUcMWoIsPW+kEnMOO+HbzHaJ)4qM*f>lVPp!kK;D z3Sy4$vha2LCQ@zob^C?0{%MD3Nrfcu3A3-8B6L}LuNV!7duKH9mSF#-#67Fj+@@b* zHjz-;G1f|#b4CNS&Ct!XQ>$VlyI~iFwQ^8E8n^YyN6o_4S3p~zY#6uo6~smm+V&HN zg!S?G321mI9h;p_^WefUG*x0_sO;zV5%$|iZquxd9E$uq654Ou$~lo3HV}SiT0a5` zd+^Hn;~o=Jk@JY#G^?6ukCDoaPMS@TS)*HFyIo4#t=SZL`;7j3bdpofb;Bcj0kkuj zB2Dm#CVn|OsUd7YHID0K?eGZiY8;q3^GNOR4DX9{jXHC(xh=l%EM-Xioa*%9w|G!g ziUDC%a;?qwnkQb0YwBk^b<;dvfvOlyW*H_b&K#V)r%Eq|79!vyc{>g)R3Hq{6raqQ zOf+K-pwW72G=XZR_J^^s@GGGsgY1k6S@_sL99egx#t5|BjE_P0~wJ=p<5e5|j;fNbUi{ z?Pf+>xlh{4D%c|GQoqB;M^SWl?lqtYHFFql>k0E1H*6rAs-MY;?NEp2;K4b6Zi*6K zy=Ra;1Aor<3b*oHJGyS7qm}2{zDe&+DHiILWKir4uLw7zp>9jhaeUIY1COc+HqVH6iRA2-cI zvPzSWo0NRqY_bDv~?3_eL zaDLK|hQ!L|WDD8+ANwQ!eHC8F0d{0 zmTi5NZ;l-jy3m2oL!9QF8$`&z#$_r6HYZCaLPEnCne&&(Z1xtmjia>bOjGfYvLbee zU2uhGF4sgw?C0U2NJ2g)hd~>)Dl|5kf}Ur{*BaFJoiZcf^^TzC+|KGb--ly^m_=XC z)aVqa!x=Bh|GY2q6Gl!SBsqPKk<&=~HvHNyMovE*SUj~jX8`L;Eq*1lw-M!cj_X&XCM!QIJv<-LNMOr-ifqcjg5^n)E+>zby1u+~Gf z`#~DiWaYZvEukhK+k56J)a3dXA8!CI9h4@ec#)e-;4HD2Azno2{^;rpQ@lu=sT$b& zG*mY2E8LoSo&-dAEc6<}qE&1Lp1!|@5Yrgvytj&j#*hndg0@-HQ}gRLk_P-92?iw(Pgk7*)hV2fOx;%wk8BDJ zm^JBfrg97KzoDQzV+Tq%JR%BDa_PmjVIJcV`Vogt20K@o)O_K1KGefZUIzBCRtO4F zal`XE@;$QfLH)I`=%s9eWHb|-Zsv@zu;n1FY>tT3foWGrWDo_~L9CsR_^5lajy^Ki6-+A5 zcm^QM5J!N;Az2o=JY_XL){x3wxGcab+j`!VOg4cQ=&tS-+MGKzW4;b1FOyl9!&Dr6+?@@ne`LY6yQKX8kV0ENq4t?x2tgdQtA zt?xcEWpd$C!#`f$)cS#^?Z~n*mpM6r-;GWVtI)q5?u12>UnUgjT9V&&B)@Nx0KQN1 zn~CjP0zF$#dUiJH**em*OGr+~HPo}$(*5VNfN2URcZvQ`>m=AAn@?V$Gl@v88=RZu z-X;zQeNP)NC$bu<0B&Z}19(!8W-i$ygD;P&;K@{Q-}m*m*a7x^(HJ|xJ2TJI4v>HQ6WRfCUcY{Z z9iTp+s1x3>1Mm@pb-`7^5`%yI>aon;@N4UT*n;X`{d5q;mrI?V{_ZVY^8BrO6#qEq zvn#!#D!26^_(uElmkfN<#%U+^CDreMZx}E1OeeuK(8Dw*ok}H61b0yZQe?_XDh0z5I!)XXs>yb3Qi@_P`-C9IC%+QU2Z7qDF{;NTJiaqVHy}&i}1i_do7+^=3 zt+QaXycbzu=oX{G!V~#>Ej03On6>8cKXdF%MH63FNSbVJU000qm8fw|>)M~+hZ6Wb z0siHj*s1R!eH*(6_e*1w^_5ugv?sh@!0ib+oA^CpEpMc$J5vcP1Id|LNnbfUBOQ|d zXUNJPX{L3<$n-+mJvjHNXpbh4&&Ve4i)`}To-mQ$6D+ul2D_cdlRTGE z`++J%@F@ur34nt+>*cYK-xDmjj5>#?KeO-;De(GbJ^VNchS{vY4x%`ns z7r#di_nc&YbC2Jv#0Eu20xSTew`DdGyOzo%*+sp}LrF!qJ%*N|_U^tDuELSWK)Xy@ zOX*DIO^LpNoepxYyCRrx?+HlhNPrnLViIO3c?UBbaNnzff!BS&^bwM z(a`&rp{1;@CmRycwZgNU`>`R&<6Af+5u=gNG2EA{S=W33YJNpEzpZlQe1Wy^8nAT= zJZW`0F2vSJhB(EoktaGUTA3~m$DZUbgxaDWUes@cS&M!yWnS<|D4qH4Q*Vr!ROFZR)&Wxb6?c>L>{B#LLRMkfgrBF-EkI4IQi-nB0%9 z(QSiEcqgvgnOi5J+e@h1UkV!W-CpkY`|t|SGQQi(Jz9^YYjnFLP0)zz_KZW#(d{YJ z?amc^w~bSdduN^KEbPLyyuP_YoD>`*Z-mPI8a_w34bpg6}jOB@_iNgepAkzLn`9cgb> zwgg^bY;;p{k-5|~pAMh7M%N_1v($G#7xNQ5^WzlsvY2{ll(_Rl{NV|LFEaUv)epZX z_#%@Cr8Vgi_k*Hqgiu`7kj(6&8H5$?enWz(hX)*cUYBPf;G;JNB)s>MWALDHN6Y5s zYgO*7Z@wNn(CRoV6%$d^x2*3he9*SN&7J#gZ_XqX`b#*x*KEeex|B4>~ z<(ZRg_MgK{yfs7oet189acwh^Y;4~~`GAU2Ux^1I7&S%Xh(F6ywDjlOl$!*Ylnf(E zjt8qKtXPzX3=3?ZTu#^qrLD$=c+&jN`wACnu z+-p0?`VIoIn_J&O*!2vx?XMm$8_37fu0^Ng?tpxQiCo~qz~gGdMqWb@0g=n`n)to8 z3zvIZKXxr_-#tWqh-lx*a zP+Gwhl(vTsOnN#noudO2!*W*(!-Rkg6UIZ$#u@}p86%{tKazcLM!wE##0H|ilaiZ< za(tr?cOY&=y5C`lzo<;*>r^|2d?r^rpO7LYISxQ_cXIO(9i4;r zGUp)b`kpvUQF&eJ^? z0j8%Qn7z_QOy!wcwL^G574j`(bwTa(90N8UPBYr`Gci(|g7IQhE-S@uoDQ$|qBxIK zs?^MAoqI->5AcdRsHkH_2C_V(YFRjUP_c;B>&M5ys2RRHCisI+FHH9@wQ5Rm7V97v zn#7z&*h4_y5aK$abW4msJKd~&HlkaMl5L#gaKsPf%w8It?FoPO(R|2_nP)n46Zv}S zsP*OuyXzagT@x+NF4xa^OW<8^gM`fpAsYd)sWM0y>r5ZeBNLpWJxw~XJiH9)o=h+p zAWuV3q{C;rEvD1yG?z%6QMuQ?zyU`REtwoF4m6!Mx!7ACXgYoF5=qyj{WwqZ7y2Gg zXj;HW8rC-D4Gdsw<#dMbH(Und5R}2~*oL9U(ymxIy>kuNahA?RUD1Ga?=1cRq8XL$ z32Oc%PKnWsDr>`BL-!+u%%uBiXRd(;FQ+q+CCf@(RV{I@^3i;CUcw7>O&SoftgtUw zRWSY*;JG}CBh=w!#I-yzL}|s1oj5S=D2tc*YQ^C|)L4~Q->fg*;eC3C1f7{wT*ElG zaX63^ny)=RX09;_^R;^ZG*i?Rrx`;(>b6e_-)((-dUNP5YOdS*__FrQHM*TD#Sw{_ zjT&w01$0}}7~PgQ&gb>Eva>Sr4dQd<_g$3sDH&E(%-mL-$@tu@;IV zPD7@1$P|ycwOE4h(7|!B)W~<>-K{tYY0M`fR%g$M_suB-M8S7pN1$VSsCE2dXbW?e zy?7ke)+suO4zoww^cEG}>a{_l-Hm;h{%xMxkDJnwW+-}jytRhcPW9uzqND8&eu&3^ zO-I`u)WQ`~O(_mLq!$LMmdTLuUg}ZNEz1sx&O~{5KwBOp*$Es1$ADb9ODiB0IZ%lm zTfMJznc?Z%_oA_}BKSwO%O1%^b%A3u*ct|N!8dzI8{I9~8qPRvy^FmCVV#yb3bc*$ z)_0&^TE$f)k?X?wctDMXVA_F~b^L7E(6E$a;^WuPpJU8tl%wdHcs7~OuHxgJL!iT5 zu%fA@f2IjI8Q(j#PY+XE#fB;{j<$!#?Bok=y*p^hwA zR-YhB_@*!=F+cWT1s%dGEQ3!5KceFyqqN9#L} z!sdJ1#I225=58CZN>1euZ8?XwL|Bb=$-0Ri1S_M<0~m$dh0zU3;U;}(fx-&jz23Ij zY*l1f;@^d*TY9fwSG_!?P#kr=?V-7!4zg0q&cgd5+WHRpS_q*>L9Gx>5ckU_N)~@y zlMe)`I0vEPvFTMQzwoL~+?YxkxgQ8$_5miF2uA%tK6jj3dZJcN9flC;Dh}!-peducI|GM5|@`yvs9iZOK zS8U9klRE=9*HJiZCIQcxP9_|hb*rlfoJ-cf-@Jw2e)s4u;M{IbrN@94fAVc3z`2V4 zH1E(Zw_O>?yz7PPsg zB^G4#qRbkLXh$;p?Uju3afKl^1@Z4;ORTr{ z;8PEuuPn3FtA@X=DyM7o*vwY9aWpx%eMYsL+3IS~=#bZgXRFf>FG~a`<@0@FY*4m2B8HUo-| zHC3H)%u#f*mO5#ss#A_S?wtpCnztiXc`3wS@Kkk|P20U8`6c=(AuUHDPUN@y+f^Dd zi0jSyMEVTO(?aFSqU}CID-}1*=Y0qB;(Mdrx2WAoN?utJj&(~lVzNOaGU0vkWj-Z% zv!p4BS~yMPG$kYG=Hb{$no_RNtSL=v>UbE2%sg(pSB(2!_!c(PnY0#WNOcB;9>9=! zREBD2NgWOQnagj;kO>lM!(+QspLFF7nTpi2P-@hq)ab=2d@uE*JHwEv1@XPfl)b4l z44GQ8Q;K|?Fq3QH29s+P=O9L;s+7D5)e=WcIbS@mOXyezb)*~v(!NZ)MLf`*ACRo` zgb4c+tn{{=hjPCQsqm~aDJcuSLz_IIi1MCWwBiwbqQVe$~H#d};J z)kUhIS{v`k;zv^B~dtr)^aG*n*Xy3lv6M92k z>YsF{cr@*9RbeSxXm=}`a@l4tD+Haa)&pFpHICML3J{+2Bdqq<5Rj=_-zmD^cCq!H zkOyrU-W*$>%3NowEfj2k-6|1Z@?B`uOU7~Xh>hcnL;QD}cH9SKtYxzEy+a#w32n?18s)15qSB!3 zh`~80X6gpLzh@0$561Dt%Rh3qqML7&z{$a5e4~%RQ;`;$br2qraz74xFzvDlQh}hS zg>R%Oj|&)6OIDcoS`)AS#*(l@Y|1`5z5{*0Hr6Z3`0Io3W;0NP8i6C&AFsli`s~h~ z-@wrNb{ERuSI;E#tNwEQ2l$3^D4U#WSd*%(G2-9}$mojk2pMeJ2mC&OO3xm4p9Wy> zi?iJOF7=ya#swZ_Sl|LG4~q4KV>%8DK9Zl5@e;0-*BrXwJ&2Pvsu+6&nZq{5V$6pk zlaSxYw9%dLGhqiR_Fkif5D-;zCzjr84gPeEbst*Ffd>DO%kTI|Jb8vQ zZuIfUf?Ll}^xebxS59=B1Igs0xjMCd-IVaJ0Cy_e_4|>L$OEE+rBM_JlP5dw{TA{d z!%E(<{ZWrNNW5%ri>hdna0H@`wx}(d(Hbs2jCSl}bhMVd>2)bwrHU}Onw~Lw?x)wA z80H-mT|_UFOM9Bm!|M2*80&1>qZUjyiT+7L79JFtHlhU`Zk+af5B>Yad9s6!f43whUl}=4#z7vxv*q9yFF*fb8!1D*iIPl#JIPvqbSN@)jm(F zHSz3(V_x^1C=0L&ZaZvCw{gMImN(;41160tGi;x^vBmbzAE#{Z++$8>DR&&~Z=6Bm z6Xo)hT70a53yks=JfKLv9`3B%V`9^t`~_ts*^WW7X+T#g63ptN7WjK- zmdf1&6Pi`xj%IR~P&>JIX4(PD>L)-xQYQ2!n0!ZwLMIh|Cl_#d$j9MfC=L%A93BcF z-vnFD&Lt$@bQu38B5dOL>xTbZ%Qe3^sZ<(Gq|l5bhut#Qg|3n6yH#I}Dhsxgvr4+z z*YEBE?Ae1YWszT>V6!Z`T~*=fO3TlQvkfZx1!HOw)hW-xaXz-v%)vp-Ijk!mSZscu zSHO}}sGe<<03zG%d$)ySUlvo_9P4aw@LdYRt?Zc& zU>7ub;azIff~1xS5E}Y9t$0qfv6$sT1nSCJaZ>qxcDERT+eUT8PFY6Ve$xy>v_3@C z?V?1ihVdTxNn@lCKkrFADo;?rc(wv}hsLSjrWJ_T^!7tDHzHn?Jo?*BQbuOL;fwim z-Ws9E$7DNyjzK6&@$s@FGp{+bmb1-S{Y(ovQ74#U%g@rB-8D`kS~Q*^qoi{(&6!5; zE1U$sU_#2LbO}lmO6n=jU`dh2NuWJINz?Si__Tm-_72&B6EDp}TPLRksX(v-Wi=5E zXXYU#()|zTA+b{uO}UeK?%hek+>5rHdzX55G0(kQWwvuq*slvfHmkXJD)wi7?)mjq zym9V5ZZP-a`MIZ==`c1l_$YQwCUX-)OCYC8MTz67Ci5n#jzNQvy;8*NX;@ zB5sUJdWJT`lz(|YF@0^DhbY5C4v=8WrZhIk}pWMwqO4r=yj^dB{fMpyt13E2XPw?-Z0Yv$%j{UH`4)Cj&iW212PO=35g>k ziDRY%5)N>~1n07X>W7PZFk(d%5XosZo%z(t7F0m5GL&R#F*SEUAiBm<0ZoFNcD^>g zKs>_GQeLExG*{0F)>Z{Z2c}oaxPS{o>{&EKP0Ehhr&!ys=?N&4oU&8CApxgS6IY>0 z*%8FQ3t?SL!a8X@C#==O2Y0h@D1Mq3RK`&RafYXjW?G_)h~Z;}buJUT&tM-EyU)C$ z!$UyYh?6~Yu{!D`eN=y4?tc+Ut%w2!;c7AkDX5~c%Mmpd}4>4 zNo1=Xry|!|HvEY>k6}UMsp>B%=zlt;q@y5@uk1j#ND1uPZxn{Jicl!#`=QCqx_rX? zIp1E$oZ8ZRFc|1%%jJ1yG$Wc#tTtDs6TFiMGl&vO;Y6Ev5C;%xosorZf|!r65c=yv_d~23YfgaWynkHEf5}R<^^N%m!#{Q1<7iuAeZtLi+6n4 z%)3k?2&V+a{sPgV@hXvCDw#o#Z6^gNmJG_f8I&}U*aK6b@l->$d;!@K4cW2;JB=P1 zj}#h@CKwMuy0ExuOGeI2XXFfVcK?TRhTUr_N4hcF4KHKvHPCe?^$9WWH-B8W-EXuF zb)8$WzwrAF!xm!2TC7qJ{{>6fO2(WXWtK6$hM#+*jHxG1kYqYs@G{0=C1Yf7kTGq~ z;utvvi_I)!lvXmPs-cV#4q1bxEw`32NI7XCV=CVuW4bJ)APnLTw2(2cEoF?sM#dCq z2_dIZX*l-ASABL{6NH z83h!~6wW(o7&c{7tb#C*ldW^6;T$Gp60MVWDJw^kDSRV=1FFX(x|#PY-EId>&q*;t z*s*p&#ITZZG?fVurcebo1o-=O@39H${UlmukZ2+3u(N`|!9EJ{NpEK3AAtFNj+HB7 zgkfb||1ZduMI=|2`Pf)>&ZH8x`Q=p%i~bROyw#4*FR54X@xe`eEL?8ORBEG2ZebJ; zE(fL0VjPYZ!Bqzk7-f?zu@p4^bbOlPMCPEj;CEw0VJi|#tp~G=1!sa;f-mVXMu|Dj za(}6b%z$V{7%YkS&LMM*`yCvlqKG+Nf^d{l%@%rqv$;!lv8tAAWJC!0 zgs^2~c_o~-35PNa{R;BUNb-$T7cHZ%DF*w48mOnK4CTo0d)j({gf^vsFG?J2;Q`Vm z1O)IkTp2MQ;F=~_Xe-Patcp^%k&Hu2IGMsWs>JKB)8x0#cC70ra4(s1@RI6v7pj-z z2tzCRW3yOXV~~C>87boZhEuGCj)D516Ao^1lA@uwLvO%BUmSXEn3%N_dFzcO6La3i zoiNYyjvv`0e*~Es@6680y2ZTPN`wUqvvK5y-7R8TwK4r0gBF}HC!%2HCzRQSlU%?F zN{x2PX&70uDNc+{-^=K2?PHky*w_$G@waofd zmr{l*7{rQB|A)p&oKB9DxZY%w3sEAzmw}cz!3FkIX%g-Wd=|mypeW@fnJ^B2Pg>&{ z`uNyH%gA_9S`u!U#im4a|t=J_-`Nt%C2GcS$A)kuK37B+e}m33WOIc`ng6mQ%vre^bUUWn}#PH|Ugakn#VZ zQ{Hc&Q*Q6m+pJUmS!b(L8XD@9Ph)@Ob;?V9!GLS6Q)(LOl)r2+>y(3rpXYQ+ZaLO% zfP7b3$@gb@X8C^O%sj8{^LDHvu`P zX7LwLpozo4BUzfKaQK*;i&@ROLZ*L9GM(}1Q_yw>7b=+ZQ%pa-GzDG(3Nk;JwRD=J zKE{&m#Tgs*yQFV+LxP{*q@SnOLRB`<}gM}#I(@KTHfK6SY>_1cHb z;hgfgMX@q{(SJ{-Pk6VXOy3Mex<_g~r0nIJj7&d3GQB4$(#+}off$Hwj}^|Ahw&%D zmNNI!3`>PN&VnuHV+Vz5n}b%B>tIk-F*5iWDb$z0w+`EVywW-o`+6%Y?xV(Nk35GMg7TYjR<4t@@p@fRB@Jucd!i&FQb=Jju?z%bYfz;_m>U>WrDON45(a6%DOz2QPwBVR(J@~;Y1EY?h{th8(GCYXP<*OSfnw= zH_8%#?-CHEP(P}n5}!g!JY`hZQ$}@-E7W%h`I^__L%}duNVY+bCmOSv*bRQhw$Ko) zM7tBD0)ZnV^OG5w5A%&6PX?PvTLEHVJ)HE0QI5kIN2N3`sZ{!kGHa55w}3UFSi+Wq zZ4pXOm(Y`@wPolV2>sC$la%tmu~=2?ao(HhQXi}hilY&wA*>J+xU@FdzNzB`*Vf%H zN|ed)wdrAScKNz@Mn|1yDz|WM6Q}4{2nk{uV^__S9!Q~SQqnIpX~5+BdLsaV2~e2H ztjmgP$?%Bo@qH<7cI?FERnFeMGjqE00grThp+Y9Nf2W)4%*{rWl{utfsEaey zJreix_8NL9=SiBk9C}c+bzFcjZJgbVGA84(UX%YqRqQWViTuUY#i}x+TrP2~RDF7A z*tWpiQLgT2p1|$Nnrnx~rKRf9+V<4?1P4hZB70f%z7c555PvkLM25W0)DeN}ITx>0 zK53(z(#Nf8?8TIW^;sm$x{yEcR>D`<@Y#O?#c!yV-YO}%hCuvS# zXgc0fV1=MVx@)NxbGbYIe{#G`8kTaxAp%L1CdmSpTl0@bO_WKOr;t}D>6Bz2Bb`!nc|7tr z$s}JEkxOPFLt>a$WE5nmf--->SCKNOodh!yA+dZ?m(UEAz$I}%S+h(~0P!2jNoL%J z%uqvS$UlZSp|gJyWQJAD%AcsSiD)W+_Ww`G2op0QehS$TEiuamZC`&FOs!{`$3^mIJk)>?7Zb|gb%7(`lvf*q)+3+d@vVmjyW@UrGN;c@NWkX|LHq;i{ zb?sQI(_6}hSd(lR!^#E~D;rc6vLW13HrQ`Xgly1T%Z7;-vf=IjoNTDa6F>AMs*6yl?YZo3D zSi9tlaQ-IwA*Af3wHw8(T?W!yY`u1mOl#(9U+Qbes12-%a1@fL4~&csOF2u1K#|2a zO_uwy!F*g^?rSLgj-#EWwQ=h##n*!Mis%i)w@p{bwyK#$iv%XB$c)IpA9u0V7HI97 za&yu z`T>)PU@Vre21~Ll6`Y&3|5YhW z*Ssy^&jz^O_TC-L=f0bT4bFXi8{m4U+5&$zSttG0L1i~lAPwjC__HIG3VdEEKs3}q z#lRt)YH*#y@6=M>8K|>II#Ry!E}bJ4$f#(_ipXl~D&z3qK%4>a&Vs+zili)_eUb@5SM|FgEav1@AXPptI4W%V~s%#+Q?dYfn;JUFFTp{Rm;vGu?7$m33 zR&C|IumDk!ZBIu=x)r=I<=;yAoaXNgIkU@1hatRj5n;ev2{$?Pu$k+I;WcyNf>U&Z z(pj{&x1q^RD#X0Cx9~k2%=;diY}j1JDkRC#22~TUi(*`tP?toB24<&!vMpK8AF{c> zasI#qo=8;EK%<3VKV-9^Np;yt@LzY;=jYI!2H3&B^nYmRdt)(&CRxlOB{zpI&>RX3 z9APls5QuqtFSA#JB<2Lp-MmN5S4k4HooW+zZWTt)!$#H|N-nuIu&%=wJ%eR%X5b>V zx-^aK&Oyh=-KOf{ytuZBm$wr?fu@P4V*(l2354lv2NxVi4JTamF0L8KhvJgGefxUL+7^rFq-= z|MUErBj?qn1K*&mqrLyW zk0+iVc<-mR9+kZ(Z@A=baNl?AGNLhd|9)maLNxA9ZFvB-s!xS&zlA4%d(y)d$w=`rKMOCve$PnAoaqc*RKz=!6a~&Rk&Q8iUm^GVd!Q& zA*ng|KAwP^?00x#PjJ>5qH#C zxzGpyRxtlo8d|Z>XQlh-BKlZ7OwX4w&saR~t*7S)bY3QHJg3s}FTL3GKlzgGNF*LM zP$@q#%05piufluytkaM$?zP_asz1-Cv_L6Flug0kpH(2B`I`Yh)X~5g-+ff$t@wNO z-QUo|gD!V^&>31&HY6GSult9*`sXqahfG&Po-x^B9{FLVt+rr&OONUXUt{9$_6Ek9FL7d4kYJ&RIix z4;+g4Qj#HX#p5um-66!o(;JvoL;&yZ8x}IL1OvpmvDABMAET2!(8={u4w)=y)(4&Z zlR6p4C+o~QtJh+?=6I}S@sSi(dJczrWjF^;N+#UHz?M)<$HcK*CpGaA-r3;p;CQSx z@fqG5=^Axz0cgqElB&%_KELqWfuoOgYyO63@rk&_ z_j7(ozOZkFCwNd|&p7JFdC*%8zIL#+_FtEsUiRg^mwTDe$ zfD{mQj;@IsL&|J%TRe*xO^$tG?b?BxM>vj(_bum#t#U%=U=RZtz_7%=hPwC|35bDY z6f5lEn8q_Nu#ksa5iU%CeKFTswi@ znV*rbTs=90h=JbW-l2MOMg((>F|!CVR;)b%Vjx*_>YXFeEwL>vI4=b;kgzp>-8>k? zKtf_h%4p@ZlzU;isf+N!j6@7nr)`aJ2zF{dENZePw!-rT*K);s4@HMi`3ibr%cS)t z^h;w{cMp|UQsn}d9em~0=k?Dl`IM@g|MHTqQMo9{X$M!ie7X*mXGWoNiazEl|2rf@ z3nC$&Psy;{BT?4upy+d|{4;4b@AmxhCh+>GiPuNvvYW45Q>_-ifLk2TY9x))DKPtV z#OsrjrsnyWQg?yZM@YOr5gqq%bP$u)t;7olpAv%@o}bz#IHrL3kk-4Tb&NAM8);~_ zvScMEdaJk2!D}z1*PdSqJ}BncU#{%~4E}(rlk;`4=y@mCIA2jGr~PW-YEHV;wn)Wq zkJ4`~NRu^?62an1+Jd@kBP{bqjAT zZxG&^G5jsOHO4_bWxZFG3mLQW+)Ey3&VY#*IgIDZBk+M616j4Rzu~!;$#b7>i{X3;WOv zRKY?ivN5RwpM#>S=Igl<<{AYWr5WgG--*7k4jtSPSRsy=`HVNriEMgsUQ^@Ab8?@t zg&tTez46w9dbp3Hza4wHPom($Lt4DB8!C4Khdow(`psHRS~~V;3Wa^5@lhmZwG-)+ z(r~Ym zr##`Fn0|#p4>gUvl74jiTH8wK;P_d=6fJXk$Uv7B&$x?Z1e+=y~We9n|d`MQv@!mSC6w>LTz0vZ9EB&WiFPv zO$xd)u+HwomL}|1(>K0Fs?)?kV82%ttpM{f*mL>@pQM1O>CQJ3OenJPdS;9D>&K-$nAT>~Ub3=n-3bSia0v40Nz zO6I(YhvHM`P6ibJE;Vw4d=1x2{k{8nUu_tdwvn?RsQdPtcRy*fG!Ug-4G?Y22^_m- zT-qj@9~(98HwFDT7ll`&@YKo;{hYZG(65_<5m-Up*__N5ZoKb0J10=RhMN%I^nELF zvz|&$3(VLY^z#A~<-HX}2{sMji>h3ggrYW3JGW@eMLlqxn-iGH7xi7=u*A*McGSo~ zQ4R|u+RKu^+lYA(PiB&&#CH5V_|>PUEGJOM{n-B3w#Cdf=0Ub|Lc3XR3kI(CvCE+< zKQ2d=CQHaG)%0TjY_xA$!J2LS=*NGVylLUqKs{Gl`q}OiS5BJltj61GJ14vtqm_%% z66I==cMffvJuYo~PGI#Kt&4lfg7ES+acq)Kb$a&3LBSOvXhq$nMISg(+7Y9j=1HZ8MQ< z9MTf|#0m949<3lt~4Ue}|EM~MEf=B_gGoko|-ftT6bJ&d~v4elPn-IwDS zPj2OptwZ1lxnZ3#Eb$~-R8!am{BkMh#o(7KOqo!G7ClM%jlL)z+y-CVdO~09zs%Zp zJ%9CL0^a_h2TIn{FCXk~5chW8_RjydNWG`Fk$Ml?+orjH+%TGNjJL)n5D)LLx((ji zqPdhn<{^4>x$qKtv+VrIxpH{RPcUK}o~EQb;fvxHt?P{8T^tNwND&>(9+w?9PaZ_J6}q#{ObAmWjbHF zx32!ljbWC~_8Zpa)Ek0nwiT{#?YjYwOu4+tY93!e3AYZ3xhXK*Bk_etu(C8G^mNxQ zA3heGA?TCl`U^C4Ny;5G^je)^yMe<@M`d`gh;-d3^N)ShDY8oP0X)pEw|(-CPHLKS zy=PK$z6#YRwLM+O@T_2(&QD}HUiciLVe@0GoNrH$6Oi89{wYKWLC(G)$|TdTor6s+vALB9#{XN)%Ni87`9k$^NuI2 zv4MBv&a6(fdFJfn{SDpUx;J!xTWed?wJlR^i+bj4+w+hW{cMZce|ep4Q4hu>*%tM3 zCNmREadvoxz8~5AD(Ou8d%DgaGV`WzT0hjfRPGJ)1v+~^}kP^JJ0OSS&y=D@Pr0$|!xRd5%df+?UH2&VKvpmf=O5qwVPX)SJPt=09 ziCS=03aZ;j41}ZbuzU|5)>Xg}Lol5jF^DOrnwWADiCVBT5$ZI-W138Q>^nTBA!9wfb`h-X*;D zgZIDq1gPD$>gRnzqNIO5$Q(0Ov}4&v-*(CJc=7bu`NOiP$_<}wpE_0aXh>-)l5`xs z*7y$K%ujE=SB#&1KKoW1Q0iQJ{$l=AQNJ&5r2_e@dwu_hV*pPN9%uw;RQ)>rElN?j zYuiUhO#!<99yGi7tt~;AB6%xb{|sV+nh!b=0eq*|udktgumS;&OnmU`3)EjYt;3gK z?^*ijOX6^U`TF&pUQxP}vRFLP_Rf~SZ~^p^GKL)8nMYm!{K7{sIs?Uguj0{Ea+m?1 zkdxDhw{8qjYU~_FxU5vVl8t=v+rQYUT>jPt+>^LpH;V7px~IzAuNlRU8Ky=ROV)^g zrH*s(uZ;Ou?zohFUgn;vasR{OVR~M|JY(^^NBIo&2xi7=VAPuTiv;ay=DL`MZ^~nD2;wmE@>l)&D1epVN}nCm_#b;vFZ=$ate@%xHsv>0Xwj79gMe~SG!;=9-6JYa*< z?%6~*?H((fc1ApSXk@0t{tO1R<#1Xp;k1!0MS*+5<8(YuD>Y}BGvl<`BlBhX7C3G8 z$o+B#x5ZJ@jMH8`+Zpv2S>v?DRhkkDoYpup(yjqcs~=Mg2=g{!!)BaTHY%dK1x{-i z>fzo1r&T0B#@V};oW0FBtuWT@S3Jq%w4nos7B|3YcLf5c724pm{X-^xV}aAwd1)Uu zz-jZ((INQ{Yn(Roto|iB$>Frp_(;bFIIUp9U39Y88mCnQrv-;GkJGAw(=xa%2G)$z zR>c$H`Xj2-jMEkXr@du?(-r`yCEV5;r#(6i6oEowTerYz=Y`}O`dHw!^9B|$xUCgV zYnXcvIBlgZPFoF}_JJi%yJ%pU@J%@FbqCarV;P=rOgY2ySgBZVExJ1b&2FstHx42le9=??HX`4EQK(rLEZ-++**aLLeKcTWxzVH^paJwB`u2qsGnHn=$4oEW=) zODu}-N5!YYYm`HlQ?{3>K;h1FK`JK%DRP0m5V28)kJDTZF_I+?0AiFuKqbWta9H8d z5NSukWBKcW0iE7}&+5DhpPf#XTi~;Xxo&GM@ma%MkDr)p8+^9QX5h01B5t?DXXTUA z11#}b`Q)-5Z@_1@Sw{eo3T*M&ye-2fh6G!XK>o3?Nsc8WK8J4g>PhtJlW2;zuIIsF z<7E3S3?6GT+ZZQD?tTM4D_nO4mFsNr*^2WLaev~s=Z5Xkz)UT;+%pMg)TBvxh&16G zDmM^!`|+S`o?KEh<#>xK%+d}VxNK~$6+dAaBYlHi>~Np8pyX^N?)E$y_FZZYA6*c# z3sV$)_l(J4(%O5NpTt+6(ZOlAQZYW-VVl>^W*k$Sarzf{?F9r8=cRQt0fkb>tMIJK zHvxs^+EBQ1?vugKaIbVvse>;^^dg^8fZvwWZ!JmhJM}Cd*u#SKzEf1iO1j4Kwj{m3 zK8_gK4^!JM@L6G&|04K=^Q8B}tiZ*f?reb1F5d-wR%naQDyOcs!e^CJf13OTeAck% zC4g;-8J}ga=%+K_)~+_O(pSD52fXzny|1b?6&%48$w;hRTnXxU287lqzSsneS50MP zH#!Q8mnmb+tmz=ddeTmX@YEFqb7K-=WJZtFA}ED^%m(ddkR@2!Iu>Ko`{}87qY2tmV$m@0=`lKrBsf?-wzqF1X zzzxp<`D;>Thz6Rv2mt1bKbn`T0JVmGQty!HH{JO+YHPK&Un&8F)rKCaJ48)?X%sq( zLjvzb0|}IE0v4rSy_mTTx4V><5q*5*_maGo9+oKwmWGcjX zFCK*?Qz4m-@0fr=x%)|AbU7%2#akb?A;cDqxvI!s-6(X9>9@N3iTG^}{Z_a7UA_Th zQ*)Y3+pKTJL6KuW)NP<^g7{ixyvyNom3=1wms8u|v*|ez_mgtW_-uMkOf8UZ9-n<* z=))Bt=sy_+yhjDpB>>>HCxlsQ@7ET1@N(~98Mz>LtNdzbPN0#{-Ac!V)<4L6{8F7x z(odcnTBnUmTago3PhakOuK3H;U4ea7+bwXnoQQcK3C$mw3Qs@WJwIWyDv%rh-|rV^vS@AQMoHYxe&111Q&3dG@H)OI!GPr`>ID@wcBi?h zZ@QNx0G;KEx_`eOgS&+qxmhxvgQRsY!U2jBXgh6*a}GZb>Qk%rD&Vo)kLy1)VsLnT zc2CHjrg#}j2AC}+V3sDE!2U_Z)+?*z(`X#qCG(^2eavoj1psZnw5Xq>H@&?za1T(_ z4|ANg?WQ}=q0yRU_x_r0YvF>g#t^p0l&d-FY`~ZE35oUZAyFq5CrtlsWzMw9HQSwL zjlt9+95%Mnxo!JZK5>EZnd8Kt++Jf$y{YPy*}VT%RVQ83{r=UBLKkh;QEA%tItHVS z07l!qAx697jTmhaVYHu)vck3qx9Uh3?S2NM?SzkqF5u$|9|C5JzsFBfYL?;S;q7e* zuGJA-+k!CKnw|Jqf^TEh8-Uauk8IYq#yA z?^BNe3Os-G83WSxqy!zOYuXXIJGYt#X=|v6%Z!&PL z&lP<6G=+g;i`@WQFT1tH7Ot&+g=Y?@UL&~no4u{^%p)-a*M2(fAK==}X1MkN57(Z1 zkB4h7*uu3>OmOWy6I|Q6AzXX-U&FOyU!tL*%Ti|4n%>}RFaMLRJ?sOAW}bQB;t^7I zm4S69wF+`#C*ZP}--xoMv%Bmqwx9bg%>!qnDwx2avyg;u)VwCUfZxlqr=e1=qdu7>&s6kG=z?yyrYt?jf_0w zr?dfx$234?>vq_pvJ02VY??bN<#Lt{C>nQpMgvr~{zikMmWSIGb-Hx}RQAprP}!HZ zMLk|Vh9zUy#*s+|2<*9~JSyAUjLNc{;w&zk4O}*T9dKFyVb-{;`Z$lv>VeBf2(nZK z$TuOa{m$WgsaK@H_v9?+Ge_K`MEC7?5iH1f@tX&A)8dMj-l&u-NJ@Qi1=T zBMuo7K^J=9UEI`v(`Bqi)EjQ585+v{qNdz0*#yuQ5wvt`B9X~CNRJT9vzT(zhyZ0KdEa*Rc=bfI4Y+>U|cfD`#9O zC&KX+KDd2D$*pwUzW5(o6G@^e%QDfQvo77a{`=G5gQ@!NEpITL=l4BD9;;Pf{F#m- zC;!Hvf1Qrry51{Fdh_)Pf&|0%6+K5DW)w}Xo0P)F#B0XVtdqC6$<_7$cP_yH+kJoY zq+vHPPTNY82zCCd|Gr`tI;DKoEsux3{14rI|4(!FwL+5bz6v^n_y6xA$;QJ!zXDu; ztaLBNDC_o@A0L5bV9Q^>zl$k2d4myCFz@D{pTcQ6_E9g&>D|5mDK>aypfu}p8k!WE&bsAC%vMK*X9(%rlF1f z@wqb=#h)#ITGlhFcK6eve`2aS?|yGDp1fYZ@Ag|+9(r~D=8N8jl%GRRBhzfbAID*_ zFy0(+`zdOjboje_Ydy*yxSj(Yvg@{fR6*$FGcuSlirY^-3b%8?s&_t_F;%p@|LU)L zM(KZg(m8diNH;D$0PV{EeCA1j?2=v68!w(JD!TCF!d_A8Z7D-AQxrcvJ^Clgx{>?S zAHAYv-}U*4mf4GQ2}pNdwe$65B)gRyc>6=xDf-2(zJao`-pjv=CcJw6`T@1P>tX*; z*Z}~_gG9*&ldo((6u?3^0v34RX~!u&fZ&+`3&$w{$#~ys;V@wG8L1*O8EL_kUch*+ zlf$VQkC#s96nB89!;`$_I1M?Xn(=tW61R`w?a1Tt zR|XE98f@z~-SrTqjwBCJ3%}``Arq%#a5zuooaWl?Ac$uDrrmm}lNA-+l%th?;{o!5 zXwccKHAtIt^OXhNwgZ&O8yJ_}?A;TFcqup$mWp_{#23|*F)GVwRLnV>s?XOpLv%9d zF0Gm(n%J3R8XP@nVi|`yZCwf6N=6QuMCo84lt?FBar%T5nHp>O>t7s#=^7|eQW}TV z>89_(nCb|tQ%~aD7bS`IiFopN4lB;|_^p9=W&L<__~>9lh&td?1s3KG+=<6)khry#Mcj=9EP7r#`6J zlDt7*tzL7|A}f>LTf3I7$*yFPmB~2cC^}qA9k$HM6w*_@D;QbJIbb#5PRZr;J|dvtm~*ehuklV4L&@W zKA>lg766`}dv%6&9y4lDo|>F6k(fzgk>vNlu2ujzm(q;1aF|w3xA!RkBLwd-y|KkF zVsLpkaG0+6jS$6$)By`XUi&S;&>lluF8+VC9 z^0sgq2&eL8t>H%6x=x|KVfbovD?qz-^kuE^YsYD=__PYSHJNCAAO#RFujSJ!Tx$kg zTW1>MOU2&<*CvnWjMh#o0BhzC3jR^z@`q$&L8}plcey2F_+~#cr&PAVB$B#(#=GL? z0C2sFWT-qqzw2hJ$ffWYC2Tb_O~wQCb>g;|!CCYgI^^&GeZeeyyVeIqySZU3m=)>C zT;n}{F3m9k=NiSgXHmJjEh^Wz?By%h|5jOfsl< zl+_;a7|UFv+r??Me7A!JwL`bJQnx#o^WCnUZSUCTpy(ji?b_Lq&6#U-`zvWV-|Zhi z|6UzY&<@>}H!9-0on0dl2OvW^&~erY{(4lVVh!B&Y0N9rlqS?%AntYxl?!uG`SC7B zzVhtU^cFgZ+P65PNS@C;affa(*_013#sQv3;2Y#+Hc%u>(ST} zuJOiHkxyZHXJJ>GLT@>J!wtLs?IjrY+r$rEmqxDBY-1nw)C;p1FrF2=h7yl?lC}p$ z-_hDVCgr@UHHzxT=zMK^bY9?oh3~wvM(Nn2Ta475E3ZDPTCo-uAuiP{XO5cQz9Txn zmpUJs4QFds&{WP%Z)HU1dvl$y^ZClLH*y71%GqFdJKSeGpMPmDI=`1Xe^brpxz(1t ztKC47%HZ_nwL|9xj+Kd#5diQ6P|KtJ9XJ-*^z+L!Bb z&L_p_@pbC)Luq49v?_jfE8m8kb$n*upFNy&=^BkIOP4n0hI`-Qx3KXaMdkMWtN6;L ziz2$WKPY;_57(lM2<95^x=vbk+%zV_yI&z>*!UL4MBTWaf7iK}q%V=WF^p?@u3xbB zM_gm9%lbRja}!Ou^dTl%MGPj|?c-tmM4NM}{rOLjtd48Bu;#JzNmTxVCfal9U9NKN z{R^mE(-oChHh#rdUUo_U^83z)0bJ!}e!pmcz%?q59N_edtGssU->AHfD%V6t^OY~^ zm0{R}Sy^UM3AC3y66GBairSEBC-u+Z#H#R_wn7~t?S{%_o&iWY#RkRPy`&cZ4LhYL zH_Mj}(ota2#xe|hR)!q+*>Ipg#&p9w|DN9Y;M_=4qx|k`VVE{Wd@EKj^#|LFC*F2~ zXX)U$+Kv<&qW^$p*iq_t7JK}lTvzYJu9gm}J=o)P=$f)Go{$h1^$i|;&g~sU7 zc6NozmPUVxoi&)#iDgUId#{7u6i% zUriYj-$LTB*usI_%h#s~f5QOGq5YfKh1A_VMxfb{ z!OfTe5%DC~xPQ`bpOtWqA^3gn{r>)ap8LA@+MKi3+G|fI;q3ihA0hti(Gw#!zlz7i)V2M| zB1M=g7jfoH<|94r#Mx=!Mz|rvdlEM?jdQ4lStdm3lDK zK8h~NT0PCBJsMf7vfLlq+zIHb&1&4~QO*+#kIJQ!;;w? zleFX2#ypJ7;5eba6Md2oJM@sZtR2E|2De>ex7MU^QV^Q0(`{zwt)KHzY=SxI3rSiR zufm2V#gp1f>g1%&W_FvF2%-=}HpQS?euRx-yo+e=koET2%Na~zlpytbIH(iaMym78 z!gbwll`>Hf2m1=L&ux{@+DExamQ2OAC1cMUE%@ms#^Wj3KlW8Gk2uA`4FRqVh+n*d zy0-62#IN7EtQ)An&gJsPZB6BkFMbR+IYv8g1(q+jt(bcWTa;w7weQ>R*u-2Ot~idH z)y{RT&h%w^Ce+pYS9o5DVqAA$J7(w7^0ZpCBq~GV&g{$;x_GB{+hA|G663@X+Pa_p zp%-@MsvR(s0&cmpJ9Fy=)o82ySF}~U@)FzDOD+|g1Lk8hb5i4Grnn35#_ag18ESWo zm%=^yQl`CRWw^;@#AjNjK3eAR#TGyw+3*unzxf8+5Am+K6m1%fqeloEZTe|-c1pP7 z3KM<$X)Wz%UV*pN%a$R(=BHCv!0JfSl2g5ytx?fAZ5q&);Fh3Iu12k7&=yHI-bw37 zi?K|ol5WadcJvv;66S{t0*#xHIfB4x*jxV%Oah1UcHXEC~C+}P8!p#;$BMbKU=~T^~@j@oskE=pPC^a0Q*`m zQ(V_~g!g&bCUis`4a+|w)?VF9xz%cGxXDw-zW$6nHXtfPPWHY7lK)}sW#>7YQU5%u z|EV);FLP8%UAKlSo-y{KbMnmLQ5jmY-!{M$*Y)+Mo>#$MF`w#x9?05r`Y3C@Mg1=r zd&xPuIuuJ}AnZ*BCO>EFv(HO@LH&!U{ueUVUe!mLGzay+WbC8P$?L{PWvIwLVxTFm z>+7#LuZ6v03Dy5<0c$UDQrZe8M`dt)n2B)PRa_`b!9$QHLbcz5G-hVj3mPzl^*bh_uVxlr+J_w<^kLkZm2o4v-+Y!QQ3gLAY8-l#Aa-wi< zREC-jp}|#LAzOxL9)(cnm&JzgLJ%}P*H*F@AcmJPwzO{7(${8J_bu{*DO>%^!G z@gRgy>tp(s31Ptn*>4CTgF<*$!G@rAS8i=xh*2NJ^g@uE_@QhC06PldxnBhv!fSyH zA*53Xb)S7;LkM*fJy)ardnN>#2QT8Y5=;aNVcQ_n4_T_uCTFN=jX12t8nL{GHR7;x zo5KA^W{tcT#KCn1xwc$wg{PAqu#0jF`&GI!(jstj<5N!V$@A=f718>`wXc7&6$33= zUR3YE9Lc6R(#B;qyGFc}+}Xc|D*!io)6o?7qVw{=8wic;xA~j0(4+T7*&g)HpVT|; zzGCgwUdpYje+@Tj&)Ctb^><+G zMHl3mk1OVsDi&nGts)5R+XQ;muV;^-vUiTgolm-t|f z;9;o_V20$j8`cOOmNO?~c0~DSt!CDUv#PE8A7cO*2wT_ z8R8)bp*FzOh6%x;Oq_!dE>Z|xD%lX^gRn+Ire(#yDA_O)<*f?Oy1bJj2uw4ZSNj9lB)y);4BxUj7YaTi&JZ?G4m936y0aw%cz^7cus{0m3NQXOq1{TQ}Bzyt}CLm<{&cjJ>j>kbf8UC1iir-9xNQ=v4U+N#o*z-mRm%yI$2kgVzdb0L$?jolN8|z#{}A@q$o{IojImdKCp-mv2eOY{<5GY6O=tL5DL#Sn4h=k!{8kltQTU&sx*+5@ttHBy;vIRQYe`82NT(vvm;nXJ)plzfp4U?Fd%{Gi$lNOdj|avp0km(l@5Ku3yWta?W3v?d@T2 zHlh=2FAGs_t^6b0BN+SCauw_q9jX3rU08e0NZy*CzMh_; z4Pm-E*u#70neCAl!>6FDzaH%TGt&i&g_CFI+N!7vBtM6+T`<~1v~MkDE6V9si!xNm z5Be?AB82RN2M03t$^>By?2E`gtbHhJKiNZ+n}ONN*z>}LsU@%L0uve44bW}WRzo49R=8uUPa@PKaN8t12 z_@jsHOWRLm?H73#!CrO&_FG3SVeM7rs(Y~4vi5Zzv2S7DK=yU*HLQK1XFcpWmtoHx zdjj?}Ih7?-dH5}FxlTV*>XvvSBdy8oPWTHe*6IWfwu)UQW=tGg-&maIEUA)T#$5i6 z=5pNU>zIM~Q&=p^P>*dK2)imR{vIf%fhZd6ux_#2sxC2>3fU3vov=Pg2C;6Q71JqN zPuW3q$`uM~+$Gea|FbXnqFNXbogo>=lzqkZr?G)mIZ7F7A;ziYL!FVbE5 zNKa*(bSi5(*>P9;{@vPBhf+G7+^}C9X;CpnRh#`lHttw_r!4RK{$cBOk4v#m?(Byd z6hC_%B_HN1-hZDzUbj3k=_~zheVciXqmt$Jrokz5dYR!cc*@Id2aug(1XnTA6ogY> zUi)SBPMJsXuBP;8YSPgV$<(kDo5xwjr;y1n<1Ch>Ofhw?jy3O1`Skwj-dDPJ6W)+L zbq$`)|0hC?mzfvi=zU_}=WW{6P90mcW%dir?tdJ)1#noHA57)M&lG3QN&cEx+$yl> zoJGqK+@06kw;bV{-O)X)#g3jGW2<9%f6jrA;Jz>0v@1dmM;U_J6x`RW0B8AOsJWYT zBs~cZV+U+?O2$PkE=sFgElfR!*Id=hr=b-WKMQys9nNj>Fr!+slu! zZ|ja(3;s-*binitvyZJk%5|C156SdVpUvvNiev5CNLuvi#k}rxvdMyoj5;2nsr$Zb z6i7_6NuBPDd@PN^!bM40X4r{4j@FtU=f8{l_`0$*mNTE;uOW@)$>VZ1lUVjKjcgX) zGZ!&$U-!(CR9CqUwRA-c|MWG%b^V^e?bu-0Ypzoo%%T;nJ?A)g+bry%onh?Nv(g`0 za=_b&~*-vt12nFdeGgC;TxH94e|8{PN9r82uuK5hV z=XU`XX8n`u^f%t`8c2K`QqPkc=)gtt&*46c+mbrRPvfH|8Cyy^VI`oLP8BOAhW%Ol7Y^!4h-V??QzXf6;hRkXPN9 zW5_e6e|bIMrWleD!HZF^#LWZs>R*26#>&EwE$#fr&#_}Mpy$yWrsw${-;Mu_yw#C6 zdG&GD+fti{@K#FRV%@x1Zxyk!1wepG7;hD^u}knqv;^J?&-1ivQlGjFVxvi&9h|f2$@w;N=KMvH$UNoTyBm(&3UYXXKgeyUr6)#~<7s(e?KET- zbH=OQN$A-1iDFOR3hb<&m$cY1D8BoO^N!L=*yI%2nN$|H$;9PlqdwTKJ@2TY%ZoCs z6_t>6?)DmaG^5sdYSvprgK#@a^-`^z?WxO^nF088YB!k?pkDdvZ*$|bknxFcxkl$I z*A(1bDmOZ#Ir&wS6t0si`NjI22uO4A&d%AKaw1S#1he((#Yk|++jXo!F3jA~YXF=6eU)`zH)Cpq_;C{h% z*B-0&$F!|l8nRGog!3xQ}s9g(j zHOWR%kQT-?@*C@1U-C9454JC2s#<>kP5In}_==|Ic)#2Hrf6_$5sFsid})Sp>)IyX+-}@4$>r-lH=ai|d31_9y8LwW zJ8!;kHZ}`pJ*LLECuVbvT0_Y(_V(&Q>0gj5#Mf!k#?^3*bI2}T&aM8$+3bcQ599g|{YJG_4e2CrGO&~Z zBQ5MWjp%vn0*t63WrBB2Tnob8}9$N(5w#%C^q3%DN%P-euT$xS2p8S3uz7jURQ#7ViDO$b%ST(-q55a>E8z7;w;bcp3(H>OPR{U+(=Zx~qW4-PfAiqb=RFNT#vjy^!wb8C89`s}@XXCn;uRW;C!m{W8k-=i09yqenV>2rlvXT&YO1tpKV;<@y_ zNky}VPEj-L;%XkP^qW@;u!4qfoQHZACxjcV@Uq8tcQ)gGqV-GMKN@`>aXiUei8u;w zrEI~CV^z=TMvS5*ZWO$r8_#;B7%#z`wCZLZ{mds0GNRyCoSBX;_lejjI9p54aP&xT z<)7aAp8}`*$eKKh=v%5^UrV3!W${^<;uEZYMF{!Ps}aJp!?tweR^(3H*i!X7-H1`q zjiW#J#*Jsc@4}6nVtRU1u#eY+6j4Pf(GFa{ZcKIYq%~baDMCN*9 z?xuN<>`v{0`SCG16v?cs=*m|G3w!E*noY2tdP4Wx?YV`S%-HaW-Q5vQN9~c_d~f|g z`t@0QS+@)oZ&JH_Vd-=I!Mc9wMstR~QFna7Ar}k+*2(E)QUibKo9E%?3z+I!t$P0D zU|XxQ2|9lfM+A7Hy?s>v@_0J2M>x)jaoCWlq|rgZOH}xGv`fd0q-?45rp@K;j#$ zoF<9_Ytri9bTRJugK5>}Mn7vW<^XFb+Sy9~^=EQ2lIoFCjT1PQ6 zl*xitYqw-+CHFk~e@iKSN6)>&t!g^g=qGipUC&i|Om20vzp&;d+{O4cnVwMQTx9%R z6LPn_tMRw;l{X&7llX9&&5~xOuMVxxMrZ5(21{_Gv~BF;x$q59eNA%^qUpVyoOQjy z{_u!nsVNmwc>*-^9lpZ_RtL1W3S6JBP^AcoSoK=VRDH$ z<%D@E#(JAV&&mTV!< z>M|L(1)u98Y7ysPall*nJn-Tsx*l9cPNwK;3?e5FqsWOhIkEheoP@bkct7aEyQmAV z6NNWO*`l-9v{2TrXlsw+TK7`B!IvkGtnk>PF@d7$KLVI*HMr-3MTGN)6iB{_iwkDK=V4+RD}gCW9)LgOynu=G zA{47Xiek5~qgY)Gip`>82gj(VNpyo1xXm%>)_~w1!$PKv3k!{JC9;>c{#7iOhzhB& zD+&-4l;5W2&w|gQ!|+*J1)n#Fki1WXWE_fda{FW}xR(51C9UlzcG!{B_494N zkHdBOKA5Yobg=Q0a*|4?!-E4cmU3b&UlKeSe@JE_N1nZtf$833WCL5b8?naS^849h zAztYznf^?Il66kx&&^?+ikyAxJ=H z*&8@z{TdXjJA+~mNp5To74s(X@r@`Z8HQr9JyA?d#geJm0xH&zidmDl6)46DK(W&9 zC>FdL#Woe8SRfVq;DTcBGvO@}#ne71=G+a%4TRip>Pz*j=AJ4pdh9A*WXWHD}z4 z$^`b8xU2Svb3kxe-;KzCxBKX-xhV1Q@5=HUBIzc48G_qFPBlWuvJdY51?X7)wr#se zvBdk(vEeJNQeS^rdqpIT+OmBXN`6Rq@&bAHp5KnTTOrcq|CGIPrIqS&!%0+7x}#q= zNTNvZUCS$4X{FBfIbJQ|UU-mK1+#a@kH(?a;yn*eKyjn`gQ^wax1;xd-|kdODvz~; z(oc_-aV9@?Ze6c6tKocBFaGR_{IeKY>*p%^|8vCVGKpENnljQYmOM}Dmz~s)GRbW| zXWz2*B-490KRg3k_p#A0PJvJem6Mhb_eKUYNcY@FeD9 zJNYi__|elYF)Kju;cPzGvWL(3z!W4J)32OZ@trV=;7Yay!2V&M4lNn+6d1 zx1w7#)2g3DzrQZx9{+LC4b<9u+@Wk-e49AtJqSke*_@Uj80YWOrF|^=aCA2&y>qS4 zOaz@h;n@t&0GV={04XH7!sQ=wVcz}injV-vw;tQf!Bjr?Bo@B5I|^?B$h5pRBMg}e z-!4>Udj({l>iY*ZYr)CiNa`n{cL}YkhjY&yxdsmTS-(9it#WO{6%eAyzxlWbj-wu3 z=mTGMjtM8Oi=;=NEva5>|cd-H<8WCqMK9c;`Di-S#4QOOy);9E6x}}VotWc;g9ag88gd7k z_V8!6A`G&^`Lh}?w>`#2xC{+un9TH!Nh|--;g?oxBqOzFn0>ny1*p`L(z40MF~{5c zVkgtILSJHT-=6dVO$<8$$=R&IF6Yw!l?E$D`?r{G#RTwWJd(3X#}lM*hz;QBUmZ@4 zE^ZUcg#PqO2h%b5j7D#@sA3wO`48&XQvIb1*+wS_i%^o4r+GEFN36@xdNkU<*8--_ zDIK7hp@~ABJW(>6jZQtm&kclR1k>^A3CVuC&lp^p7y{bTB%q-3qz&sqOv;JXbjD+w z**Z$D5MM_pbz(n@Tx#euwp$8C{K@PdU*v8~s~jRGGoe<-G_a@V*u?-bSYongfNTkU zMuR7soq$vQCSz$QlIr-AX8AM~n|@I9qp*5{fiR8cN2@MT^cg+BJ&7<({kGG>P?Fn- zre<4HenUM|!I^~gl&2h%AV8zvVpEH3&9*~WjVjle`8EpEa77L4N$Vq@Dxbw+Rjql1dKE2aK^Be=NlIi)P4TO5uQ{xw_KQ?{?QdL5a zBY|mrdL&X+qVc3qQ`nWi1gR>mM@VA~QdK66VO-XCPK3+Xy8%`2E2k5rXg8#}Gb zKBTIY5`S6PjjxKWVNzA5w@ign;!N9?paE?$Dd*n0*s{GT^B9rEbO&~nPE4brB27YM zSHFe8gcvMQx@EKZ#IDU1X*5SR1E@?Vn$u8`CiR+p`)MROWO^#K^Xe@6jNmBgAv4UJ zS&pbYmnzriI@F^!J3QSVkq8Yl;dcahKQ<7hVfy6qm=)*~H{#56S~MD3aw_g_;kQeQ z+vGBlb9Y;<(S61kXDf)()Q|D&NN}XwRT$%wU#MB8(Q?L%*%wbcN*$TV3pXUlUbqak zB&PA7mL|BaUtMpjR-y8NRQaFIADPCzLYqWPP^R*vy;kFNpHX>{!kGmsdzcf=nTg8% z`c$x$XV35x)t@%dqj}ib*6bZ7ELA?Dl>$07`pUPDU60DAQsuenY~|cs;ZzBL+#}2c z~16F-mXc5U2rbbB@dl@Cn_sH7e#eKE!d>@Ow>b>rN3l(Mpb*40sQ zy>APL?(JuO8~7um!kBda446s@F_qd@$FaaDt|XhRt2lMi$R3K#cMUFXBR`4b91V*4 zRvlv-AM|qWg&NqZ$kwuTExS02t1RsR`jxV;hwY!ARh~3|tU(3X98J~dI-%{-9 ze*M^IVnfR1BIlv9k1%IO1@s$5!fmjH_A?9{?+<8>Qv(Q_+nKRk>+1qw3Y~yDu;zKtCKvvt z(S(-v9PPVN)*^aa`<$;rF7_Oiv}vX=M)7w~5a1OUwWQH6(y!y+t)pL*UPt0zPA|wn z6I8iLCAOZ}aNOI%!CPc!whXecvd`k$^f=X+lyW9-GhaQ^Pi)#qeVwLsl*N<2p>&jD zGxesUT#aMWz-nT@0zVr>?3bjo{78fU^4hk!lyI`N>$WyHQPuK|weF*3x8LbwiuF$< zfs(SMbeSq}@DKenzVKDIF>yGV*Omt|OqrPfHBhiDj;6OsrhO@rNS-`n?ySbY$|7@N zx33qp-y;p(mlCFFW0JKkbYJoqjjM?$CRe$hTqbCjsiFdO*$%V3lM|euPt4t#OT*%{dXUS9+qlp~m$4EbXYrX5Tv?#a-t; zCbL(!@{Y-s+HQo82J>|Ok7JP^k3;$K8q802i?r|s9Tw}cNX$zIld?e3m7k9Db3Y(f zzQt1+rBUQtl*yL0xit7AGzb z`q}>VhN%&s_RZ^LC}d;yy{z$=FGt9`gy4e+%o}flM29>&aVqAmMxsguWNP&1{w(v5 zs4|sExdj&7o@ZliL12ZF$z+Jugqv(2tr`r()ycja*wDNMVN=n!6k3+WM>U$^&dw>> z>+N&dzHKXL?K?BlB8Hka-8cJ(xS4?v8mdHxm~GOs+BVwO%1XpJohpsuu@%(>54jU;kY=6YS`#8pKW(q54kiz4*-5hwO4TG8w5Qwl!5m*98nsL061 z!Q$|O1q%>uJw;pp!xpw9W1O9lu9D-(93qWz_J$%52cV<(iC)QD94@RNMW8JtF0_7A zsBVh*+ZCugYo;J`bE>DPXTEU#iZ z5n*X{pSQJ7iuI-M=cR_H9F{LvJ*PFF3lGVMovR~qZAWTTaFo4O9)I7s)`v_gF50y| zq{zG2Sy3RLG}YoZytS=0Q7o@9Thz%mB4eUcUH7a-dDiRgRm(A-GF!d;3IDdfaYym^ zc5Sk9uXxrrMgGMv;KGHybAFPW5LK@1oargQer@(gyrlazT+Y))}s7)7QE^T zv_?HOLrWoS^Dr&X5(}B9xx|q-2vbS&Ii`}#&Xbu+m>y=HGOV^rrWr-2ZB(tz2liDwG%HCfsLiWVt8I`U~2@yNjNAU$23mcLmA=#uhrzW z$})CRhJ-*#Td%BZq8Qott}P=w#VfrON9|Vi>D)3RLrym*dRbo23g>UXUpF@^S=@V- z$>)baYl=alg`NhjX?C+TO%--}={A!O%yQvf6f6Pi$05|OL0HfBb(||_YYk}46HGXu zFD<)Ea1j50Pg`<38PS%3f&Ue4xwW4WZ5bQ^b+>|{vyVn zXC+*M`Z>h;Iasx2>yL92ZTAAn#@H*X1*sORzy?wMU;8^V_R6-xQ?Pd+d&%l{to;%< z(Y}6QgBW|>O#y!e>_f;t*guf5=XDU4!akVn!>l^6_NUxLxnF_}V(gXw2vS?8fej-2 zP=6U?uM`Qd!F~bR$FAa|HFfG4=~C%GI&cGPGpBEyxtt_4U`4i4UXx>r}t{ zSFF7}K$%p9`aKwX)kS&Tf@v9&VX!wDmh8^hTV7Ga-lU%D_n64qi-+Lxm58hZ!>Eey1_%ywYi;62Pdc5mKxB)l#1uC>|G(u8TM`ilG*%Dka6U*@q*)eciG?0g$>eZw^MhP$d)#4nK+_0-gd z{<$n6H(vM@v$cY7%A&S}Q|bZVI#9KXH`rI8oc^z&{=(GTuzyMRPyI_7d)|LUEk&Jg zZm_Rs?3DvR)gC7F67~)Lb&S1ojxYd(o&#}$4Q+$j`ro=Y^g07n)0g1{)!*>aH@yF^ zP|KP&JDZ}G>VWD01!}2Ga@!fM_>P%vk}LAM;e=k&Yz;BRb^UD3E|;9b5OgLyF61i~ z)l`jCN(F(_GbAIKE*3t3AFwToI1Rni^&`c}n9CNc9w+2ka}z zKI%8tUKuB}w8E#i8OOT|d&O9)e~eWuW6zn&>ks>fWIuNb+0%QHx+qT$epZjy_mOf{q!QdC3t6JK zI<$@Kv}o}7#^M~_7U6TuMJ>^y^*>L*T+|IjS*03(8><`Z>c3f7BFo^6r3F7Kc|xAc zt%@2mTmJ9XLBlUA!c$^D@7O5NtX9wDjX+R=eGpXPl(Fpl4@+US_7rgFDTQjCHZq>; zUh4hfCUMj$v#m}|$qy|U#`Pjs6EQKypk!o(mEG+{CVDSu0>36`by{QS=vdY$6ctU$Hv2&(~9k)VSM>T zQ`^jFz8e?+FInb}Ea7dFwVm-s?&Qp^;m$K9riGjezj$$`z4JOHWn(&YFjAAvwesf$ zkJ>P8#5A*x^m1_SgYyu)379z3T=#Lwj1tnniS;`XgIC1qPWRh{Z2Ek}>A1oCFkj9T z@A(lbygKym>(dJFAKQz=26GOW-X~)5o?A)=6sn`>F3Ybsp8Ny7PrLxdbET zwE7F>%r7~Kj)DXXI4U1nu_0!|Hr-fc(7xF=5=3g6tQPyLl z*WyV;E)vbrYC-u6|Wb4;R({`u+Vf6yE5{7kby*2-oq;tpBa_>TYi ztn@i}Q(QvOTz=_L=AFLes3v`ei|5ebzRd1OjC^O)K*mu+Jg;`k}|!*Y$U-#dcmg(o(-^lk~5d z#TjT?y(jy(-JcJ2dBJYRdDs*J)pVYZk*n`31^_3*OLA_f{23B-vaMT$jCKqs*YA^# zz?*H_BY(S`<2vQxmeyXT3#irI-CXf<{S1IN+&q0JsVLBnVf9K+N(rJ|OvQ8%W#bmo zJF`>TxC5rW7*Xf^c;jE=dg9z#FC}@)CU3o_e)QJGd_TMu zlefYZ#q8iX+K6r-fi3IJ(Z;45jzf?)D>p?kvk6rLdFL;)RHx-hO*VE! zK2n#3v65J9ipa=b>gId?a_8jDh^Bx#-HQN8-G&oRyqOmunzs~L6l z6MvxF(0a1zQ<D<`LZ#EwFC3iYR z?nR6aXPTQZ-1v|;pvOMVARfX=Ug)7m3mwaIWPc;{xZ#t~W4#_vv!3kxlkUv^H#dJb z9`NLwcbgN9|Asf~`2`8vjXSCB4eMjnl$8@Ea*Rj;5@eBm8Ygry#tB~xG`{lkE8{DN zZHWaeA7!kmUgOpI(|qG?ZZp2I8S+?`cm|~m&W(pW@QJvIQLcLVn@$u)!}&|E+^r)D zA{UWdQ5DIpdq!sy9hX0mK*(2D0l?-Y-24w@<0tsV!+UjP<0pW`ynDElP_E~l%Dv*7 zOd>k$9`}V2b%2xjcl$P-Y=_cp_81TI+9xjwWsJ7|2jRbpK0}~h?iH%6EPf%vUpb}$ zH|8(h$g5_K?JMlw4IdGR9PFfoKA0H92k{B;QLxcL*@!S!ojS%qiO&ZcpSVBx*sfzU zS#8O>V!xXkkpAYQ|64wpZqjLMBW^=fIVy^KVUK6wB597Yt{-6U5_2%Xj;r3^diGZ@kZ<#511}LZ3JVGcs zkqA;l^jSh2bKbcEalCtqO2t&U0z_&^oT-~*99DP7CnA@M9+6`5B;4(}5=BeKxbT2J z#{48Wt0&ASSbM^D!P0e~q#Q-HwCv7z+ix6s&Zk($L-UCDE8Fo2%5g8BhTgvzbQ72t94nDu^icxZdp;9&yGP zRzxbpF(z7EW|oa!d}_@w1W-_xvIz@j{Dfj71S{uFDglxOH}B-*e$1Gqh+}cgCB*S0 z*8=cg!IKQSu}_B!E}pfFh{gPieR~AjSzQdd`SzYM!iEl-|Gm-I+LipYO)-+DFw4QT z%sMn@Ui6qA46<0(;VdHq^&-4g6Ap_&s)#~Rc+YYM6OmZ1LltM~P{p^=$LL0G65Y5P zN~q!=8wmk?8*_jd#{3$(@#^ITpoh_NP>AJIn?0!sPy(2bn>RZ)vgGW#2X2=pnh5kU zPY4b%ZOrE;2lR-golo>KSUfm&hQSsin_!DAbl75)uMS&$*9a>P-`Iz>PfQS9q=!DU zAmiScH2D&^ZI>VYtKbLy+>S0FiZG|r3FKe-kU2(O{Ett>3A%qsM+V*sSe_k7O~U_1 zU!?0tT!-Y)Mv4x$En*65QUfr5!WY*^H z=H6hYd6RF4w9JDT`ge&hKEWiu7}iS9rZZorAogzDJZ-chLf#ZzA8WM!3V;vlSDheO zuM9_Pen7$%ZWG#T`Bs zy$(wB)csSwOr)Z3SYs%0b#o{&yk|S)9yzp7_s^@@#_30V;+FB6Bl!HK-P#y+N61Z| zIIRf{Sz2J=cLeW(FfM9NsORFU&}djzzA2hVEnM0|dQq=i7q{!vEMk1Bo<-c;>@S~r z3HQ0pSi}j7jlVlX?wU9IL+Z+#pN%Kw;k4>cu!v?WSQasQAA!AZ9+#P&!?OMJNn*mK zrt=kaVjtuA3Oez0A7eW4rE%D6#%-Bx^hD1~Ghh>sLa)znFW5xxSg?uOIbai2Cy7m5 z%CL!#!65+l3e&NPw}?&TtkbcH8b_8*4AZlT>fU-b(f4QK-qz~bM2?P4^zCX6Cs{f+ zQT|%TCKeK#D7jAkyLn-j!hgvg-dr3pI5KL3Uv%(Y%?}- z#0IL7qhk}l)YTY3P6~BwVh3^(yN#UqlM~5OvsMlgeyYYH-S4Z+zow~%{8+Dp4i}OW zhYJ)*fsR@9A}2+W)Yz*;3d%Q-lK@@e2~-$7A_~v#^5uw#M%7eQ0Ov)DPG!r?Z)3PL~z?ahyk*(&Kb{8#jLF7ob z-s(q=C^=lv!ICS^jzD0XcKC0(<@1{+zI^#>A_bp-2-Lpp;3!ep1i*uYP{js973&C9 zyifGu3_=wT5PevI!koVms&gEQ)f0Vqo9M$jq7R>t@ZvlwmPf^^Q7oHe9J3F$=xk%x24C_YdA0qFRvRXcXmi`SrPz``3RdbYSg9|-Dt#)fR0moNQM=BqE1M#3&*Wvxr12ArkR66{|-v z&MXv@IH6eZF#N{`^n)89#YmK)2+fEf(({L_za(wQ1zieFxExgz{_x9r58KVwfnfiS zKm7Xt@rT`he?JgtCKrmH{FZZr&A${Xsj<*?p%!kYjCqUD{QVv}j$_3|w5`bM`$hKJ4D=M{vklPoHiEIr;YS zu?kVry_=7Ypv2CN2cAGhbI+5rcflcVzU5K5(&}OMwKsq^WA6{V4z^O`^LtH8aGLWU zmMx@`(chc|2wHk`q*Nf6M92F@t3=mqNE8ICIV&6J}_{3m53X4 zReT1HXH47R0$<5HD^CJtjeZ$-eWg`(_c};`|X$mP__7vRR_>U&hdE_ zsO?~#u1Av}Y`O)$HLYwSwb8B{iPvU>G=?N8cf*7Hg@C-l1&c++oYHTF^qpkHwQldc z5X8wHg2Y&^b=35_k6i1h>Guq|)Se{cXAGy}tN_p<-k6f!Z(*G%#vZh?Cdd4f%d7J9|$hA68 z^|OPVAH%h3Hd&Q`>1xWgN^}1K^r!8H&9{$SEB6=SPYDKG>w14Y=0b0o_lg8gQ+K|w>VweFg@0GfC#*IL)v zMMIyl8E43~%Ht-1TUtvDmm$}B+;gm(s{z+~JU|A4zGhr&j?D(ltm(vX8E~!KJYhA| z`}QzsC^yfl%c@Ust*+%D?=%~!*#=x|dW@gwVUYpXy4~8=Y=CQXuC;m-Ng2+i$_=}I}&`BMU6M$T1#j7b%C(oN3Qj`^+a}eV%#QD# z%_L=*MU6M$S|Me)s2;(xT&p;)u&Wx4|H!pk_8KAG^hvJua+f`D`6szF;96(TH{e>w zRX5>U7oOOEQs`?L&zuPN8M_>uwFT0j2?db*z2I76=0^KLJ{Ee2d`()o85HPG_= zjY!8|#p#Xsqnvy19mEwaNl4I;YkgQPPlhZo%e6kVS)EFsF~JSF*7QZUQF$h5`5AJp zGB5YvTn)HZ*%0sF=`;4ESGQEash;l49ghPkrVpr347gTlrSRKb23)JO(rS(Fvk}*N z;4btNx%n7lL#|a7zc+3_Qd6^BYu4u7vb{)6ZOpZ5Z{(x$7gV_c*Q%)!+UzmlS~XQx z<8+^mxYmMYeZjRJL*<5Ct7N{Ps1%Zi;8*q3cxe|~vxAth#$0RagyX3E166Lowc?CM zQ?50A{-?Rt*u?Ho@=GTrKLf6{a=anedT2)zu2rk!T5}n$Rr6*iR20>&7u-P(&Ub-peh6x|l5Q2Yk+8Bo?$lye@ktrV2_vuD7B^~Zj)H(I9_ zJ8}oIpR24_&}WKk0RNfuP|G$Q&CSP1&qP_+E8DMK5gP0|V(^)G=$tpVM7?o)JYWFy^bBQi0d zTQ@RvYgEAUZJiqF*8XjYZarz+NVodXN6@WR4BfgQpgG;TY1sdXZuPd_BaPmdlB1(r zqxU5rZlqgZFm!9h5FC;uTD6gGjV16-N4Nf9?c^T=x^?dVK)1$xq+1L0bZZUy(9x~4 zn$xZIgZ@Xl)py88x>Y$>5C|N~nTR$+x>e#L3Y|?qJ@BVl+Q5{3X?ab$I!rEJ@ZIneEs z=~NGUMnb|+^#dqZ-1K)&x0u$GIXV8Vs~DnydZojjGg;1)b6PEW3c3vC^yjqIF5dii z#JqwO%d*h)JiCx246hQc7^@(bl@7vkc$tE$XloSVYeg2D*}f`q3EX0q049MMFRcyl z#q3KU_#FxBTHSiaQ2fu&=&4iyA*_Ib6b;J>UrR=rl6OGh3wlgMT7Tl@f3i(|t(P!#W|7oHy@X*g*%$dbEFaD5 zfM_@*VOZmMm6?mqQRg_2FfKh1U zN}^p&56X+Tk%Xb*kWLx2<*IhPPcW`48X4EgfkupLaH;|0>VAnB*Y=k_!MIl1GmPs( zTO-C*pv)7WmN*Ufu4ozI&^y((pjP&UWv$7ofeKdXCd|!y2b69$qw!mv9>OJIAA=pw6lHamlsL zR)LOsP90V6OIo42Nvf<8+{VbMi2JK(b&MUkYFAM@R66UJk>ec(Rs}5UHQB%Pg;Z$c zB)=he4ErLouU~$gwZG#k+TgOmp2NVA$~<6Iu$K_l_13q6u~!;M4Dy8+S^Iid(RN^6 zBF3JV53CAURsh-i`dPA!=(CT|o1sqVYgf@eU|r6Pz492aD%b~+eW0HMW3M#O7xWfh zX6yF{Rt2mpkg*3Jnc5my*0*Fo+)u*T^BSO|ghftRTJXv%to?Aps!%?R;U=~2z^dxN zJrbkp>yZqJQ9a6#b%FCc=K6YCc~@guOjhmzfiEcLX<1A*jIb(@TH`5%(SETjA@IdK zxn8Wybu zr1B#si|HX=3&{N%^Ek;pFKg}p17WY2jhVtE&*AM_$Acx=>5@4mTX>ycvfcaT@E<1) zX@qCLbr$l;{zUFWHws(5OMh-|-=;6(qNqTs~ zcl;mc&HTJ~HoIIZE~-f_kktMR6-B|DjqMx*%w+%*&TY3Mdy{p9y`rxZ{WvUz=q!AR z;jSZu%`)vDcDSMTmkzv0jB+}1NBKsOnhg6QvM=-#G4@IWPR-%7XRLjLn`i^{=)4(w z-g^Px9MDq<*`M`uX6$($!mmLNc@yU*Nqf%P`?`y^L%A)Ou~%9OQ-6T{WwI~z3uNq- zUcwOA%g8?LvzM%Wkh^H#98iReJ+G~hFW3x7i|nuZ$ryWH|Brfrv1zYZ`%rgL?z|25 za>iZ>J-}74H&zq{k;Li&8Y+s)pogbLc~()B6M*HS1*d1MC|Ym8sTnAWssgZF?tsf< z6-B2*4^Iww>LE43Kv7g_APH!wC_2|&bXAM;tfHu72q?sQ%nC+PG>a%i4j}4=AzL8l z`7U?0>h+K<+e{Kz3Nchzg5C`zu1lNtp6#7e?%Ev_!BxFw9tN@?Wn%XRL`$v}gOcAf z4};Vza;W{8klJ6$2hK+dahQ_(FC_3RSuP3^c#`QEVv_T$9b}5@ja<|fG0-%H4qvnC z#M-NdDYte4m(_-`k1CVLf|HUP%ZXN8Q4`wsmg+Z<6O{~CO8bM$YRlNCm&r4O!DSiC ziE_$i$5H=#*c-};YKJRpMx%aKPBgbnt_}mIWh^J^P|kq_potZTDnmI@`8Ud>#i*Z^ z6V;Z<>tsN)jO9dw%T**H*p`Sa136J~2q?rz=-sh$p1vN)IRF(q8KH4w{THx|>U?1c za3mR_aj90F@MBB={TbweoCBnK$AF~j5KxG*P{lJ=6jhaLNkvdZAs8r%%0>dBj-PIz zC>rV^u8{9AP!!cc)QOR~TURdGg`^d$y)DkJT++LD{Z3|f)Rl9-m$sJj;=hE zm)$uLq2MRJKH4Y`{ud^D2zy2A;`~dQ*6~&g7ol|x)H?0zu59b#JVfdH0S=;^{_!gQw*I+zeC=y4C(7F> ztc86h*&E7AHrTu_J(qzw>|LqLiuaV?2wGdmnQkH^lp^IKoyAvYKA`sK&zJ!_O67O?BxyiYPOfw2t$_#?+rH@ zKoH%RR*#vS8>$#y7=oKj!9DdStx?^q!swfAKS(%h;~%Q>DZK8p_l7G55?MIF>N#UC znWY>J`xvrs*v{Hx^izH>F;o(b%S>vi(HX{#!5cVf^eN`>i2CtUX3wyJ&`ilBn8K+#kfQH`zP1*Rb|uy#iq` zyAFHq$ZuJDhboz+ay2F@qXeq_Q<&-^#*=j^U(!b>VnYkM>6xn z-%EnWXlNsxIee}XvZPT|d06|Tk-cReJ7sHpM+4Q?_d>i^HlpC%q=_(cJfhHj$@E)r zQI=^hZxkI(hjcIgdJ2R1BY%|D_`g4{6*{P((I7QoblkEiCZ7z?&OE#~ro%(Kk z>6JWH1gfralF11~UG3EQ$Xe$)VaoI4gdWSFoC+PSToRKu10# zF7`o07L+`Uy~E;G1uXA3(?$@xEam`k*2bK(>Qh{+ z`S*x;#XgBkUy}k-^4!;pc~6{gQ%(AaOA(r^?<-cxQ#}((o?1r9lUocWPwp|aGP!{N z%0S68bx!G0ytMuRh|%v@9ZT7!pVr70W5)u3tbWJh_H0S@4eZrW45TK?y!75QN375w z6Q)9nW*8E#i8kURO$gU+=4T!^Azb&0)C~IgrMZI$@9%rlD|Wr0s7wd7#vb7!8)I^D zYElw%a;9tUnGHrNQbtX5t;=%4I{T89PDT_$p$ANR5bc;KcD=-^jb5Guy#(w`C@IFZ zE6%W8y~jr6_wV^On;0yW`-Qpr#6;p=@p|6gE;Fd#QPT5_)fuV`IM=r})iUe=+?H8l zwtfC9)JHWGAl{*yA1p7u0j`zlx8gY2;^B`7eZ?f$zCxv=Tmidc>SVjl?3NuAev7K8 zM-@R{+8d9M)me|y+INvpL=~wMsS2lOY!%$T<|!V}~dQvC4+g%RY%YAC1pkAG3mvx?5I^)c<$;F@Kf>)Kpx=>%z$kT5)CRH&rF?+70$8TOk zwSahSG@ib=uHUw*n-{zs0~@=X#wTA`%?_+f7lH8p`8I`&z^XLJJb4X5m`x$fvQsl5 zXeO8+p%C(Q>#(|*4MA3>5$}f}>vblCin)@YyqIpZLVA2XgPR-diAoI*`lAwAOBa^-h51wWyjcuRyLjy`b^uBD$SQ>ID9ey3+9~4YcUJ0 z@0sWPR;DU$Q_ogD+T2uRcFvJ|6Ei3I27mE)%N)&I8r`U>qyynbGmzYrdL^=M3g<~u zG|nlQ3OAWx;*HbTHTu5$c22pzg$%S+O?W~K3^PJshz{Og|g=V~Y@=3Um-h+$6I{tRX zMfwfyS8$Q^6}gc0-@&?&$2$kkf(sSnf*WOCXu1sBqW7sw?(ugpF0?nKUsHE{O)k>= z?_yo7jdvb92QKz7E~L}UO#|RULoWR9^LH^WoNp!#rM?Iv7wZ0dSr=LH&aw01BA0QI zG~L{P8C)bipuTv(-^;j=-_!(AUkoD`_5HJ17y0qd$(pWsNbP5!`ub~JTeBtbT3CVy zY5pqCm&|&9*F`%S>%DF=t@l6OzGR;D+-P%Ur?&w6B`B(2q&ZzIr*g3BxsRwj9`X-- zC+*Cvbxo?dPbaJrsMlbvKdSgBSx}`>o6W?NgLQMHi)C~vf;&vJ{GeUMN8tkVN#x@H zW9wYtV%pySKbM-BY8s-T@X^x9_c*`M`Z&wAF{d$0A(v*(}$yRnXz z0JjnC4TtM}Ju!)X5qGI}EfaFH`$At+`T8~P=q74_mpX6OLli{-yFue2M?!Rz;J=#$DeZ+)L`mWhd32f0?x0khW4L9^EO9YTgtAJ<|? z%<;)IG{{n)s32dn1_VZ=ZIlcFzUiDx9>YH61qK#6bH?kP=?*A;rVquojzoU>kmodi z!*aNA>A=sKpGgO7SZ-LuyhH86hWf;NPCQ2Qj5!ajFb`EULtfA-$~cKwfqqLu@R(^_ ze>?Qs&*-iG=Jw+~f)9w_nh2zwllW3CVNS ze2`!+KnddzBQ22=Ckco{CnXZ9wcO%#kdU>2MDZ7%kw_??>ldIo7NUfB=sQ}XAWkwJ z{k)V&=>N;LxDOJ{uV{-F-a8^u`kWb!wunZFz@Z;#iPAWULJJa=M1r}V>+gtHJCyMM z%KJbh#4luv&=!kOLOJvkEm0RINoIBjAOB1`pymm5zi2JqPZcCmq^_;^1qAUfF=8Z-g^GIoEOoF;E~~lW7pu4zAkO03aF~WQnVAr3OG`)&Np$Wz~&|}j=S=CgTM^l#0=t>%uU50p+Jf8!;)-?gfbx) zB}yp?L!f|`IPJ=}9Sjl!iG<=Mv!)Crv?!q(<|iN$=?RSeAR*q5b!iZ`p(P4j`F3d_ zF_=i`UNSvXBEjiV;@B`&8zPaD5Qq{AN}{>cEK%dyG)wUYN@WZu=BROPo~3+G?E;mW zVTH6xL4p=lbd*Z+CMO!?nJZru03fpnh`f<`CuKG`{cw!KOo zo|bs&$`6xnhViaFd0Eb(mC`iXNKlRF2ba3pozm}++uB>)7-zWD#X7^KuBd(1cyedV zc&%4(zJQ0*buqdMBXtndHV-A{cG40F)otzylvs%pYUjfsVS3lVb*>mvU%~mFpd~x!pc#5A9}|tAu|+btHdbb@DcKfa>6hV(wWN{t1gkt!}(T zTLD*FBiMp_XLwSF8M-jiSr@qNh=R_#HwuR@h5ZOJYUoS2>`cxb50!{mV_kwYd7k*D zUCABjX1Yzn5kz9KTP{lEQxZ%4X$j^vE^jO3v68tE32lNVeIBH;Dnp4Io20ZvtXnZk z+@K_4V`vH4b*}9;Ooio2B=Qmrdr`uG66Ko|w8Toca+EMo5-awCgbW{gW0t!qVW5Na z@p(eF{3F=Ot^rT3+oT-eMGW+(n`Sle5<^R#r-*6y~!qWYeipt+-&(v<9`4FMJA6F@H9ftH=NJyk%wco^``{%Jvd$78*s7(SU!{*S{>_-r zdtiuSJYbCGX|0#o9zwRUk%t-vnHE5 zB;Pm5+FJHNX*2SH@?e{{{S7b`?VRIdF{w6nB5Z||3h@+ah}E=xKiOxqmfOpXG(cctM0V>q=7CfldjMf|B@{8I9eUN^ZFD6E3e9dQ**3Ac~Bn$bGxu2Iiu1M|(( zoLi^Q%ik+AJLe-qFoB#d9fmaAcbX*s%!_`nitQmyDf2`AqcritgZE)W?w zDeANed4>*)NZOIS>2@^Am_ecW(JP1A1+ulfEE2YLE`>^uE9siE>smk7 z*3O0l5L62tvTkO1R~f`Wd`2s~x3g>hb9QG<9}L);v@l;56WhNVtkwPJP?Gb`8MfI; zGi=G*Go-F5?F7(69R)aQ?e+VG^w5A`vnH!mlc58$hclehriA`e%0T_9t1SoNS-gX*w8gGeD_Phfi!Ab`S-tTtfh=jP4r7Kdnp@jRsU|K@=YnY}KBtnP;xJ6p` z89?y|Ql#4j6A4|X6bVvzqJ(PSR9d3`*RcF@kWdf_^-O!-A&_8gL5d8!sYF8F*|INE z^h1f7eR5htI3mno0Esjr!Cq{i)@vqEWFo~DJ2{bHbV=!t6kaGH+aEzo_>2hSvE!i& zM@A7p7Fqi7Ux2FiPv{Gp?qT~$Uy2G+K7qcZc|r}B*Y=aX#E!M+^o1k~83y>ODW0Kx zy!e(l6hO8k$TqtL1Vq|3Wgvq1A_()q5*j3JL|6zr5kO)Hh}hLK$@VjVlp=_(yZZqG zk{DEqAO;E~cI^QgWTidFhw%{l41CoD(lB0JE1nG?yAWijT^s>XdZY|RkU#{{A4sG@ z%14Cdur~vUj)2JBEDHsrAS#b>5I~IGYiN+WLE@(XB6*4yDI9f(2GQ7a21@}%eu#ib zgTx;ofiS#n20?P|ivLO)qJd@L4$;D_KWrg7>A3Y_?eNGs>KyqTHR;)!(OG{$&o=yl zJzH|{1l_ZPf+8Qv*dm`uNXEMF015dH zq|h8ZLrcU5IrZ8F67@tv+trfw5+sx;;ddnY43P*Kt49gOPL#;$agLTK4s!Cz1&Kx? zA@#7#wt-8T0VT4JSezpg%CTWPK|)JO=nfXp61Re!3YB=7I#0$opV2BItoTC|Zxmbw zjH5~plB>WYd&2;@3e*k2>;C!LgXAir?3tqHmx8a=!&gl~7^m#h`UU`q_9%daAGNqZ zKuX7E?E(qh@>5*JM$Z|D~I~wdFVR->P$Vd*kc%#+KXmIAcq=3@VvS{|trd{CQ zckBY@lRLJ-DVwmPZKgU}-Z)yNd32bx9%?E*Os>Hf?bG&x1gi|X2=AWVI3l4PY?+QA zf1rf>$zWPSJT@%S2ofPgf_>MX*L4wq{f)h>%r2NnNQb0kAjKAxP@SAgO9YG!(=>yG zf=H;#?9=`P3DyIoxNkR=NEn7#Zbgc1C{c4#PD?1phUH7WVHhS7><9Kdj~JjRM~Vk_ zaw4G|nz9`!cA$jpR0J)NIX27?01{dvp?+web_66?kC5V_T?CO3`&;fpirpxYb!q`E zp&JV$c^Ncu^!(P5oZ$^4c@U`PABJ%&r>FiD89>Jb^IpQZ<@6HAEk%#hWZY6$!$>ZN zy94QIGOWZ;5Z}@+1CS~NskB>iFsUDmTfrS#Z$9MXYYu@EiVvd`8V*;36H;*&!^XSA zy<|cw_CR}wVu+qR4IZ3yI=Omarho)Yd=^v>M5{EkR;o50Ki327cu1oVDSOEgdLc`WVo7zW<+q%O@9q>tqTsa>G*^LdN+ zL?v)S7OEI175UK*w2EJ_Q?K_>mr4QxzQURY5;_D)I-mT3NW@K$odgNxNd$@O-Aqf2 z33l@N2okkKLg#Opy%HqE7eJ!+yhSsS&`zj93GpeE&>a0lOH2xOD#o1r^+ZDct7SMJ z(()@&!tX-zCn8ZYL3Iiwbd*F+?=K)>a-9i!IKPp&PCtj45V(pbXJ4@R@|W}o8-t#n z4b+gaG|(Z=7sk>AUwjN}xpuVm8KX|?ni7^Fy=YRE3SNwdr>Ca*+FG%3bjeY3T7N^Z zyu*IogcACT z-)IS^3CvF~A^3`MjEr%f4tAvih`$ncfbR6&?i-O1PqO@g6dzH-z1ZJ_47Sn>d?n4C>S%OSCXaepXec{w{zl{|+lXNtS=*^px+2CTxRfm!<-m>BWY| z(R$lRUl@Lmkw?^v_SM=`P>dr2w)<)7f(-c;O!S{rA;?Rv5ai+({<7Cr2#&C-1*bTb zf?S0!4<2hTL%M%g*O?ywJqB|isi}@H!=ZAeuNV*y6+Y&&DIO}SM?)-5t3WqnTIhAI zalSmeS*8U4@>BZL?ElI5?}RfmuH2_C7DXI?DPK9D)>R4J`f}rVn0RSetgM!`Mn%o3 zR<)Z`&76)Y{F|RB&*Nm#_{_6hDljT6o!u-jrT*Xi3}+i}IIXvQTCMs&u~Cx*Z+&^@ z?En3|0vG%A<#77IR@is{)ex&sE^^xzm+7`&T;x{F;5U!;e+X-t8EtinBA2EWI7ik$ zVwUXF3URTAYCX(m`K@o?&#N*}8IOm&*UXZ+cXU;_cXCy{hn1Z{qm{ED`bsgYMGjl; zl3AH}O3C^#uh=BNU`?&QO56Gx09~zWx^8ii?b?G}nAy|J8rUxVzg+meZ9TlzylY*i z3Uc)O0%F>GL-=ehXo-HycgljDgOJ;IxN8?Ln1uC>aPBgxa2@wwZvtpIEPn|zu*#Qc zlcvm#^n`Ms9(*s?vN*|T<18|>-KU4p+qIn91qIxhm9)_TqNQS7$ug{}-uk>N%nDz0 zg;`;CX3CV-S;-~x-G!3T;s`Ri`=L8b`Gq5xyi4*!OFCRYdA2a03(F>Fu!@q1+3m>s zP0JyPaRr16q(3CFK)0>^cXO|M$hP9jOQzLq5^+#^mUR{+RA*5_9x<1e&>U7O<6#w7 zOC*>Ntw-i?A#iIONF;Bynmg%6C6`R2O3o#n0~uxkW$cJ3kTLy2IPXt3%xF`!qQbi@1nIi08%vm?_dE z?Pw{?-VZarrhL2oPP!de^k`U7TE9XvI&dy&#}z${D+~?PE@(&3xvZTK2J+BDzY{?H zcGkR;&z-Z6D&;DaQe}89+s!43^!1;uySPCUcXPug-t`okSb@=Gk>aySVYVkM;=Sn4 zg^M8~Egs!2u1npN7028l6szrB`@H`=cX7D$;G7WXzCkSqujX(bc^L2k;lRG z`)*>2oFj^aV)%?LzRN0sw7BwO;SF?}n|Quq?rQq{f+Laz5a`5+BGo?HLpU9#&y|nB zl+9a>b5C+ExB`>*9hxk8>YuHP4-C%t9`y;(JjMiUo@}oux_Mq%|9Xh~8k$4%h=hu~ zvR-qDYovC8Y}z(S6wES1rLFdIR8zMO7oV^=;5Awi1$K^av%gFFBM;_ZP=V$>(+dB* zX}1?px{lU%);{u3;qKV}ll=aIcsAIbI`?a*oJR`wBkm_x2$?!`KT7{kyzbZui^C{> zDB=(ttKdfA@__na&0P2iyU7mF20usrGtxgjO z#^u6lY~;r%p_zY{mZ&@`=@$(W4Mal!)jIhdNa#_*|A^IDA`y6*RSy!XdX&)3KTk^- zk4gfTfJ8HqV1BnQ?)My)(vE_J#ZjyCL_&Mn?;kYBKPX|GZ=fZTb0yPbK_a=L)h%Q^ z8;e+wkmaK-j#?SKDs}r3r8oMdRpeX#{eC65#S?7D+RNM;*uMEFj@AxIuCU@x(_iF{ zhrjqs`bpwqOVLV;RI>C}e=P6|w5;L_wyb8_8M=pidhk2Dbvt44jPyNW5i9JizRzfR zG`O8C+V9NYs*f__Aj;#pC{s$w?C5Zomhtw;Ns#PKfaSkv(qp_m4kbuoFYE%P#nG&@ z1Sb7B;~S`mzadOcg;`~g2S4CCcvb^}(G{^a<%5I{C9!)FS9%!f`oVx zNH7=q6%dKi<3^Mae#eU3=y0BvnD4=#d*{@MzTwa;?uJA2#rM<5(5$^uDCf*^0}m(esdBX`ZlAdd z0*??xzsT?6L=A^@dd)4Cl{It}D{J)fMPi6@HH9fFk4tiy9l*;gTfLmw&Bnh&9(Z{P zdig~wv3Kj3Zn))V4H9~5ln^iJNK2$27jqZ6!P0Ig@7C)?;I%@LZWSzvYASL3+#y-g z)tc)Ii>a$6F!*-fvP6GpweHyD@60P3T1U9b6Jm8Cj8|$uXtB)7tF*0-cxYLkaD^v$ zsvb6f2Ex_4n>QG$OV=Hh9TugiT*tluxLi(mAz-DkaFzZmttk+LYZ@qBs9BT6h47kK}Q&alsgc#Pzdtp%?UGpm8j!Nz4 zdEK~VxnsKUoh%ssNC(Wx3tR*IiZ#fun|_1lukpB%AVEJDFC`tY#-lVr3VUG}`t4G` z8^q}f^0cU;qf~B)Zqh2xJosto`u&N8<+oUyD$#FI!gDF>CXp!3Gopkr6+y}?%@Qv? z`0LU2eTYQvEmlnf`YlR?EiI%a>hh$iAR(tD>UP|sDU2Tc&FK2RM1paf<@psPa#3Q_ z(xh92qB&2461kLwar$jq;=Kod3;OvGA`y6-rCN_pixT>!ez%Flx|0T!V64T8w1~=S ziO(MV9S=bwj7Vs2v+nE!31KWqyjjXBClcC|lC>ZaNJ*qsnI*nM2ZSrC2lFB6fGsDa zY5GT?IvX21yz9Cu5?3+JCI>y`FZ7gQ54X-B9Z-4~I$$((z#Ubj=Zu67SO6dj1Q{Qj z1nQ>icge|81Tj(|+8K}Nx^(Z!pD+FhSEWY;WN)4{7cR{wETRx3vg=*~qMu=N96|CB zr0CHx8pLq7kim@wkURpS8g1h*ssxZy1UVMVsvsc7lZyWWh~~d&)cmRn8l-Pe{xU6q zFe?ejjXY_&z8XLj2ol}3fPly%Z3+iGv1bjwM9W*d_?c5i zh^>l8>A0?SRS&aut;xdb_XTeBgD}1|`f7-pNMS(aOHH=a#=(LNsk(~7!7v`mhTwRp zTfT}O4{w~5e9?=)Kph^FI@~xZ{lfI5cH!;EW2;qUR25c3q7fLFhf)LcaxJZ5I4KdI z0%Imu6Ze{FgM{oQn&XAlAtIrD>^Bz8F%~5Z@keQi;!~1j)n{;|TrzS9D_HKF zm!PVvhU--RE@M22kNMuSelI>P?8T>rV_jF$PYVleirP06dFUHRUCK{MwalU5^e9oh zOT3aubn4ZD63P2eB5yKCn4*8)^;E(v!Yflxka<7$3xMB(;K`-C0#`aT?GZaH2ud4j zO}klb;S{v%RIH9-Wj<5~yO2r*S8j;juQsWcz-1erqZ(nNq4PEV!{nL}*rMv$LivnT|O*L!r=A?e3bt|E>X00r2B%QS~U%Uk*q+3v;Y1eC7 zqOqs=m+o-kzepqu)qck}vmmSsCGMn(g5M_BbDwrgw)xv+|-T*0S(*T`Y8wNx~am zd-sVMCu||lt7(U54@{{!$(X%T=uywQ5`h5myZx-Cn>}+R)}8oWedf^w76)foax}6Y@LQXD7>hN9I~bPte%- zMCZ${j)$}e?Tx*^CtVvWgDKz^m?X7m>D#y{L;4*M2X%S0k^gp)HtETKK6HG$C_|1r z{l1s%X~C)Y?_>6~j53aunKP8Je|IblhN<(*!Dy3jFE5g1NbQ3Z_dzM~_VC;7MXjqz z+6R_?gDh84yI>8m`M3IueykkBP)qXj+sGBnu_>^+F*XHO1e%8m`K{B&8g*$q2&WSA z$wa*0G?<9@lR)COwC_JjCHYD^!BWIb#e}eYcjW}ErN6{UFDA3#fy={u#>=*L7V8P3 zw91n8Ys#DqrG*-1poMy=6CBdwyJB^W=P3gmV77T6T0P5x{S%^4i>d16jv5CCp!o@C zOc4UrGXg?_0cacm8Pe?d7C-UVzM0GAA zpJCqGo}X;bzb2j{gBYSfqzJQQx8o;4LJauv@l-6SIWt5r4X2TH4Jm<$Jf3nCOSPjJ zIni;Qet60hmJ^Nqd4^9~2M5T7fkHe-ezGI~S{PMUFrMztQ9BdkGt{RX11CZg2ji?i zr#(MHLT5P%m?;^b)qQ+ehLWQuD;%30#|10K1j9wiHG_T5F&vAFz~a_);493VR|qf#?>;jWFof0M0AW8Nx9iEa#rY#@Z)k<+Bi`& zeTIZ*D7okrrFytMBwitJvJ3?ivlM8fEH2xPHZgmr;~HCNm_%qxQ{I5m6;$k%fl{g0 zOO1N%eeqa?yGGej!Y{#!nPSX>L%U4|IG@P~S11t{(O%12nPnI+`z>Wkk|)hqI8zcI zTs|QzL(WG;UXfH-1c3_pVVqMD^F{|k?0#mk4O@nXOgLX0!x1Nf!WqKQpH8hH+ph?MhQa#l>>d zE)Ys;!_K0G)L$JJppXKE)(quKJ?U;g8-ALugR_zXX`$Jn;#&sOE)c41!)DskL_$$4 z3dK`IDd3#siyIut6Vh=TZI%tCRYP6i#E zl{A4CiVGFj8ce%DDBGT`r-k62vP^g?Dn#d0^3*88HYSO`3Nt0-y8=>M42GcGO3U^?u9MsmIrLfSAeK4yT z``~2$edvS9wa^FQkF=>s^bCu6LGgz0!_CE1hi3j3ruhw4soN3fP=9o+pCEpQPJape zpxjd!?n(DSJ#!KE!9|qQUsATCy-fKng?S8ww6qZWGxWhBP=%?K)3;-*@Q}&NJVl?8 z=o>{ejUR0$dK>EG{1u#D0cDzo!jjc0#RqhH3TJ_*Fq+2kDU!utAB>@#{*u&@#?gIG zk-Y^Rxg)6;`#bc(p@1`ka(YL$3=f%byhP(B15VIn?1MA;@n)Ri3a4Ie;Yyc4<2dP6 zy5VqrGEz8kp2Bz*L}tVeJ%7EPqrV-r5aO!D-g;*IIEktp8#$Hz?j%exog#mH zST0!5;6htaB2om4qoFC{uArkn@)TCmO`*8tz6zQ`z6zT{c}46>H-(m;%De+YQdc6R z?juT@Jp!t*0-J*6%67*?R8y#@h$@lDKoPCr*PDsT6i&`=;2AGy&j>71DG~uEhr)U0 zDXgb)beF;s0Y{sNIJs90W*mcUs>}{9#*inK+#-65(&mi;oCL%vbY<6=ab}8)h|@&j zB=DQeIQ8JDeF5hKjiW76>23i|Ife7qQ`kh~2(M(V1{}j`#4%jay3y7Vw@Fn^034ke zXOJi@dMx0qMw}!!b`~BoS!b!pKLl_BLeRff^QC?0F}?*HRR?D!X=t}-0Y^2Mc7af> z8=KjQ2(e>D<4`CFg*cn|!_7j&7dzdMj)USi(?Y`+i*LzIyFkdV6FaPvn#y%m*O!b_ znk!i;SiuympyaBd6I778H6OTKR5%Z^7PppMU%6Na;pXX-r8}`}I;l^Ds%E8jNET;C zm(IyrlL9fqL)7Vf5*lWl{O`~(!MJeO9{Lx?@XIf-YzxJffESr5z>c@W0{v^ z!InPb?$$aJKRz%{ax2BPYtjC&EBa3$YiArccbA3Pd>+QA8hMBu%8e8nzyYwCj;5_o~56 zWDup+B!Lj4D^0Xi#PgU0L_4VV=)$fs6KxY2kf@F#+QE-86O}J^@|g+-ilq%CzoM#- z0-SORXGtGn42@HIHERPp#su_~#VXZp2CI%Fe<= zCIjsjH6czjg|nNlLmVo0Z0S{u9V=`tP<^X&>K4crb{$(Pv5!#KCx?-0iQ!{QLg?_Z z6C(d9pcOC$o&OmBGSQL>VndxEr5%AEP`*l)*;x!)wc_GF!pnXB_2BwwiF6J;0tC*% zM}TX6VCjjto%m`PZE_dSkt+Sk$Ver8F|8Cc%ur%HJ;41BvI;kBz1x{`U zk0{KZ>OA@pMV&6v^nu4{)=9b_1}zD=5du%3Wqt4|Fpr;{2Zu~gB#lcnh@+)&?)QCBH-wrykr6?&>MnbNu9UKt8H(z7OPm{v1+j_c=s| zTlh=+<#R~Bwf~fI=avxN+ga`qw09mnQPAY^wol0iINhi0M9Fn+1!jL{%wa;5*@r%c z>tY|7%&U;GXIM8Im`9pbY!JeXM1w3n% z!-6RoKUQRWbu`Q`UN7CE3gyV|9M}50Tz7DXkg@cd%Kxa~SBSFclu+taOWvfmIuH03 zM1tMh!gMmukw9qFs`ily?bD6dS@*@a3jY?~Vr2-W4o^L%TRr!Xw`cObEA=5LZ`{I+ zBo(6xbVEya1U0MX_^Rn6-*?0oi zZX1)cO)tHqN{&#fA`}pJX1rBM$DN69CsA={)-$j)$u>ygAiHw{OPdEL?~rU$xo?A+ zb^MTeu3cPv4YgnNn_-Sr19PNuRF~9q8|Twoc+u#GOJG>^_HUkt*gbt{p?U1 zn8**s5)|8*Q5-Ub&P)`rhuVDcxM8<9f-I&fceC%@W(~DDVeu#R>YsTWRum`2B2YEX zhRlO2_ZLNipllurYPUhQh}H?+5@v4!{2X^a$24&(<`($Cs_!At5hX`)3ZpoVit^UQ zJk@r=pFLsf_h3(4%$v_S(weY>Z43A*Gmr;SOp^Ff1hrlOO#FRFh9ey9<;YVp7t4}VV~vdwj9ql7m-4ToIBb^{*e<2pRg36`)^8Mq_3S6_2|rP$8``)> z9bz*s1RDAxK?fX868S{)XY9LFh)<2m};M^+heZQ$TWKBSJ zM9_q-@PTrS6`W^g$50G(s)i-qU8FskI|Nroz=9akwtd>3_l!1i^9#*mFmdiV2 zShl%Nye7!5lb>&_V0m7b@7Ay4=vAjMob4l?o3x5Ed=&^Scj0@Nyk6vTL+9N6ednrTGwCwA=i-%=8&n$?E5d8!(qgX+A_=HO)2bmIe0f z1@EKFy{A_cNuOYR#%|Rsj;w7m>8;H#HBYczo?yEuBZb-Kb~#wI-bTG1yzyse5)f1N zLivQX_(`FMDh;i}%vK4riJRz41VgYTR(7Z`SybN z)=2b}j?TPJaL6=%Xd5KYfGF@85@kjz%tXe6MWxM9S0$9GS5X=AjtY|B0SpzXwa^gB5+tLR?-meIY6TkFXz7p%m3uARv&yNFjXCwtc7zom>20JfU+)Jo{s+Lhz?d#0zM^`od72^c}ieDgLd87gXO?^~% zv`Bdagp{-p^P<4t?>Sh017-O=Dh0t;svZgUA<+TK@*6k<%|!0L{F^($^1ig?lN+Qy zae&i6;S6x*4Ww}jUTG0WN8uPFAt4IEksmJ-ze3B?II63H;$py2ZKN!}M^#GWFslSd z5$70Xpp6{48K;LY|G{p+38QfiHAv-Jz)7BmemT^cC#P}hUP)g9j{GHBCvBc&KaImU zQKVac7P7(YCuU~e5crRH0oK`yIPUvZ5;!D7msP4B3(g?UIm$X)IT2=@{=WP$1)Qbd zN)LtoeTCsl(=HGizh5P%h4l4;3n*kjA@(-TJhRY1Uw%X`oTb`E3k~!Y&efTAfsksy zN<#}VUkIL{&{N7l+c*o%LPLG|^NnzpY&$J9)K|EG@sipFLdW*2bhMD{rJx>#UZ9YA zJ7)hz{r_a=aD@)$NmL= zxj;EpoHK8Q^Yk+|t%-?-y|C%j8{+cjbZD3-hM|ImSx*}MdIZz)c>2$RF1?|YZoRRg z;gvw!`znwiQ6D5y=WzCzi4OVl#}+`B z+Dmt-+(xO-TfosoVVBBw=IxR$)G102OW#7U2mn{kBaio|~dj)KOKH3*7lzlMhY z6LH4>rIOP)`tO2)h%<;X&|jQC%{Zrg>&_`^KuAjqsa^>__mY1Cq8y4S$C>vhO{6gz zkx2Lfi4H~S%tX28iga&4sFW7cHwye80Fmq<5*_t)@^1ZQ3Pe6gPfygA_H`C z##?avL$u!+ppy$tyFjSsFO`uNQrWbbheGpFXxJgnakJ1a@0+L)gQ)faiF4r7gE2UYMr zOOM`|!szbXByt^VFv>rImXgw{j8f8CON@UNL68X1D$ySl6nn^yfEIW0L6!2L-u%NC z;brz?yuEpBgSR&a^+!0w?U>&L!*1H#@qny8+0oQs&IU;7vu44mlR5J?84aw=)E@?g)9zt}!q@3e87C^ct zR+SB8I$|7^6>#*n*2mj&HoLx9ETVeaGKhOHTW#<?FcvFpbK`7bVkTfF+cWKfHtC^bu0@68k+esS7}-;@s%BLAD&}>bJikT|P07dG zz5sbH7g8R*pcLY5*R6nTG}bO6oamaVUTK@;7awZyO9(Z*V1IOs%P`DTL;9cG2g<)5 z!1SNAp}&qgk-ySbnmtnuxsQ_GI5t|RbaOAaaX!R)=Sbx~l9p1XCD^`&(!N?kX>Xx4 zgGGYvx|Q%7+a@e}!JWtlNeBW4Iz|F?6PBB)0!O(p|EO3DK_^Pe(sO&7@V#k zv!Lu07 zCiseQ7S!L%=^;G}TJk(BSC8R456LX3csI9}lM3NGvvHbF_)zKuhfLu+hC|%(BA7u8 z!q&)|&Hf#e?@-}8?A5M97ZHT-{BBFacUo2i#42DGbetEMHf^)5#qYLj=Ut`4cOJz< z3U`>8y@W`CUIBnxo!gM$T|um&H$=!pfwPL*jjSeMJzyLzT6~HhGKf zwTDz$)9`fn2=z8|x}giM(Tky7mMq48vDG#L+r<>P6cL}J&&tw607nFAX!Tat1qwx! z0wz9-n zvA8wcZ58I?;^L|Bow=ke@E9t5hbk!AOQ`Vbwx}l;B4lpCx&Lt7MDYrG=6C8!=@o2+ zE7%G#Ugc)T5Wl3tcXrXnDT$}Tcjls3ls=Tg3p~i&_^H%bFmV#BAc0rRW$&ga zh3`OHn%1lXUaAc6Ed!-e|4?dtsJ+i z#yld_SCSk0L69SaAv%VL(%@Z*ws*aF*V^DnbQiyEk!QQ^1dK#MDFSYIhHRnwq;1mt zE$YL{cHF`6KEnJW*?isy)`p#(Ln@$se=74*!fDakEx)7KDceV0yuW6CP~dA6VGQwE z*{|HIbo8u+2r948oxksuZ;FLM8-zDTwaCcKmLm9sV)a$6gp-QpUt1R6hT3HsV; z+iG*;&s^y=d@ z&^VemSv>*A&=Ya$D|BWY@%hx6vw&kT<21SQ_>O>c5pl8?vGr!04sLabQ%~WP+WL9Z z<3r_2r~Da!m9V2)LZBWhfLNnys7C0IO5)j<5Z=Lp{*mokm}A) zgtNL;5rfGuaQ*l?(=HGi?kxn3?aKT0dJ zwZKEB<*@PupPSIb18(B*6XG3NMfY&0H!{Ja#c+L_P8-STO_-o^F<29+?Pd#nD3IGU$;y z(=HHtsP#UnLG=RrHF?qk?+C`LORi^+McGZ$m{o0}(Ea{?^bNLw)O zy?akCt0xSOU$7~%6WRJi(iG}MH$@CIMNAC#$`;$V=B9{T<)k?YLQS-g_Pr{g9C%Bq zrZ9R7-_jKqe##nw794>@<&TxC>57XVz@_6lL}RQbj;fyS#^d$^oNtH|zM8GVL#9i| z?{3A2Q$pce1Zc{MfiHFU)pb=>=+HGhKzBWVjVK4taiCiRCg!x@{xZw)(a4QUGYX}6$d z&=ixF;fU%iSYU37QHf5P2Otzp3w8UX3XlVDKGhUaKEefb#WTO?C!z%>A(5^|oJtc( zpQdJQy$SJ0sl+qbf53?91voBQUT zRGA9Esi$yaeS~o|PW>0=WWbS5MjT^}!HlDPmRj=&a2V@ooNX|odIOFGaSBt}HD;VU zZt;kdNa08X#b%r(iB9>C0p~J}Q}RhwuLB&XC)5D%BP^zIe7?%21CDY!;)tJU*U{F| zJqJg<3};E65J%NL2S+V2?E<$b@ey95h2&prP)Ixjg?yec*3&}r25{6{a8>{foz$yy&t+g!U|SX73?o4#|oOlnZAa>ONk$~maM+@o(u1XjHIe| z-$z*P<1@{LBvL!?@?sezQi~d>cp`qNF1L#800GcQA4YUCmuTz;#|n}LQS-g@N1ZV>PUYe8cccqMz;G# znyA@r5)w_Nhz1Lq%|zo?I~}nDq0h9C&sUXhObmEFMfAZ(*h~}Ud}l@gku(B{j86<^ zBIWy3)#y9m`RT;c?E7wM<9-F4K+5wsvTMvZU)`b*Cz`?u6!;IIhv@LtPQ~)Q@P73H z#6W3ZRZi^9@Cx>)==pvFg#K{IG(s#flTq3AKl=sz&d2X7OAgX z9RLA$6A&aRot>4gwl*%)SbKHv6v`z(^YW4TEkLaWsO9sN@dweWliyAS5Bj53OQRm_ z26wH&pe3kHXEXm$n^Uf7Ev@og;Ku%>3p}OE{!mX8&@()5-PUMavZt^n3CK-c$!f8G ztOi0k2jKnqL_zXIIApr9pRI|kfE)X`3cRtOIzaK1zOfJZ&g}jW;@sk%k`68Ynp!*? z9;yw0;Nx)QQ>g+DwN5y3tj-|LS;S#xs6)&+AKkX4yFl386dI>;wIVnc9uH^=laDwh6i&ndE#hDUs0Car_!U(SDR7{u6BMM)K#eU#K$Ub6scJIV(~*j5r34M+VN06cS}tS}s)ZP~CCHODO<*xyaBBg#H5HpICH1aL z61)ymkT5_1|_T!J~mxhd-q&QP!n;mp_CXoGl1j z{#0ss3eOrl)%a>k?OblpthwAnu!aAp=48)JJt$p}dRRY~>#^Gf@{TuyC)P?=7Hv_t z*#YyX!(Q2j3VJCZ51d7=ODR}-hgu;>>VU7QGzKKU`gdBk;#l71aFUiywSfCSY1y11 zEt`6^V7_l_TDHNFgCH%NVT)RgQ^f8A^Ir+!9hS0wjAh4OX@5QTis9I;B%{Nq7)^z{ z)*!d%Twzm*+5+zAqf;@ff2F42H7;A!iqSl7Xcd#a1hl*N7l|OXI^07v7RtuJ)B%W&2KVn+02D`uMX&e<;qfP088T8BxAY zln(M%Z6I}8I;8-w=zAjKd)s#zgsXQ0xX(ovDSI1rDLIY0zqx09jUM^|Tv)ORyTzP% z3nuMVnM;*_;E7JpL`wew+^*gHEsy$EI*_+n9``M=6nR%FOR3W8MSG#NR{|()FPx;e zsuwvu7r|N^D&*JPxz21)?SJFno9jOiesBH~E0a8sTQ~kXu7>0+l{=plrEAWqXX?h2 zWfk|4MkB0OG*cxsi1L=wt2~CEJ&ldh6Rhp`hu9mbib)9C3?jzfCwD2~VK~}I-tXrxu z|8_jT)27Qs5IJ*s7%ZukES21**HQyF6pIVheMU=flSM7r0J~yd8+h59GjI&Ruwgw!#WXxg^#MZ6Cp z``>S~$;w<^C)pK1f^{Q6GAz{~l4%u3@|_LKG4)U{@_87faf8nw(=6hB5~aVnt%kQ> zu0)EWA>0-bygTA6v|AH?Hsw@8FMjv8+8*jRUxyU0PSo8Jx(G>)v9!yv26;*=1b}-Ow!^+ZGOE z7eXxKNV=iF2c#Jd$wmmgxQh*)79(C418bRgA((6=*m;NaF4M}jvrGDWsk!LKK!GF!x{(uf1nGjS`M6aPNl)q#^GtxxL@1zV83_4JH|D(=XKpZ*>7BZ${UQH z8;oCP{l>kF=VJ8Sjd#M^Y-fmOh+*89DP|8+GgP@~RZB z(dHPj+BEjnrR)+! z)*N6D3V12ad8rgM3{2% zv0u_QmJf=4?>f4!(?0Nfl>8 zrxjeW`xMaF(P!qDJQuG{)OVWmI+Ops&uX*ahnm2d{TI4*EBf#~pl{ppY!84@|2DlO z{lSpUuUoR5XaAe;_g5W<1ii~Y_FJd4jmFa(oAb_)|Gls3^P}+D#*PDahhBRBzq-j^ zO15%heqhL_>{PO;O$e68vw@LZ!LCnHJ+%zpL z6vHNO)hp}D@7{lhO3L1jSP%86L!%GbT=8Y6tKx@J&uwfj8T%BhAHI)$@5m2q z8UF+;!n^&Yq^g25X0iW}f(-}u$s>%YFkJKXH7_<)<6 z+j|ATrfPfdOKm@u?D}%@YmjKu#)dPCI(V9PoOy5B@!Rfq2x#0)A8nj0C zz4M2A+Yg!X;MCtqxOsoc`A6Hp*AiZLOkMw@ZK}>}to#(6j~4H@>H@3$2NNCnlJ+;H z+`RCYA0?q@1dMy|;5TCn^f0R#*o%gQZ?j}7cJhC)^iS1a0xrDXe-SJ_<9WLaxVg9G zk(lZS!hjTL)x8gzRX@d0t{Hfp>N>w~G+@<&9-q@w0H?u2(Q-mNe1yZ89fj(Quc_weE?HKp!-m?-Ycj3Jc6vfL^) zXWT4&9P+>ZL-~7%2OhgQtru>prb$c=kU#QY)ol8Yo5?1v!wGg*u*s5&<{9mN?47Z% zIxdTLJ1u>Wo+=ePpo3#C4Cw|y@5gqS4aK~lvAc=&BTC(X=q>-dm4CwpSpLvK&G`;dxH%kpOy<JZXr zVF_;H7T&~7{;Dm^${#$I8hO$-d4n4~Y5u$8k7oMxs^jS+yZ5L}Ryf((p@Add=Dh=d z>3-;4ufW^`4!g}N{Gl`MeK%yu>(hFwkaJaDP}bv?@s9Ez>apT=$Kh4rSnD5m-TA+l zA?P|KRgX_qz;5f`J=petu1LJ2UrKI1xlbK`{%p$l=BBX~ zlYYa^<1H6{^x?nHZLEu5_@1)cA|)C#FFFM`P0<%93orVN#|A7+!uCJ2s*GL!@8PTr zTB+%*r!Q3UdA`*^+?@W3HGlqo{QvR7Wv@Ffo4Nfc_{6d!9gg7ULTE1I|5`8ns~eQH zV$nwA=)Z8I009;i;bukjR^0i1RW_&U|Ko^dEvwe@r~~FJ0(H5m^p0MOC&sL9w5tAJ zr%oNGG<+4O*3b{2?Y)DLInk1K$TOu*;*>%b4M#W4*c3omxMWs@NP`jm~@IU9B zd1p1V-_P&={p@$2ecyS`dCob{bIxbCJ0Q$?9yv%H}~`qmgdfyhfU_INNln` z5l`2CM&=tbRuN66kZ7{@5>56QqRF{NG}mgjV$bUjh^O4eD7%I{A<^XTz!A!`HqFN- zJDzBAW)V$>DgLI&`7j5YYn8;y{2|zs=4BF3*OwDbW-5WDa|hVLAN&oKf!J5xp4x5N zk6SN2G=QzSh?kXmtgCOiV!jAF4&*+bOIYxiwwc1?9^gIThDYiL9Rzq%(Y~OYraBRl z@Bfw3roE+?WJHNoD@na|L-(S-T8_oWkG7v@ly}0Ue463!R2YrUln?kjXHZ~`t)H|2 z#(uDMziDv|Ezc5*8`vC6$;BqBJP(^aDb?6Ktm3r%Xbmg=;dC*6e_exdw)#t9QM4^C zuR)L({x`9?l+uCJ+_2b+zfCXy6`TCyt2r8c9zY8gm%u!SzsLKhVUt?*zN@=HxT1S% z&#j^ukOpF=Sx~2$qi;cs+k=tA_k5x#RQ?fIq0&D7nhIQ*3;v7nji&#i5P<(ud#4?Y z2-vQfY_&^kyr*UCBx^H8fNDaW_Z%OoIr-pDtgFpm-~keD&f*hI_9UXoOeLCYw~6M< z=K-2K?I+lU)|Ix_3{twf>irr7QPCN7O0QBw`5u6fN;U}dbf5ibR^}=|CC6?3%Q%R1 z2Pd%TbM>Yb-gNU;ZXCzzvpSHs=4^Mh*VA=xiF>;cRHHWYE74>32o<>5Whv8(sX_1M1kwmmi2X3dw69-(pTW8vi{tknlt=1FQG0^P*334 z`aAUz2Q()RRc`kY)+PyI?9|5CH`p}SYrEms5=tUN6%1o=VP4v9`cT??3-+w)y%m~^rB7Ev zBP@3zur#<_7fynPPpACorgABL! z!;kCk8~)d}Q6c1?Q$ghjXR`MZZh}vvVqP2$UhDcyy~I=jZc>!y0LfPgI z^#&t&8ZvD!T41GVJhP{#de_=7q z8_d;`Di=-L5Z(wIL*9-cP@{(Y2lMRrg0L<&dhoM*{RATgHRJ~Pt0v_JUxceEQm~&1 z&2_;!Y{v3au^AKEW(iF${{VK`#IL|*AXg=qL$*o=*OVDMii zq~ep{-Vdw>q>MuHANqeoaSN(CG`WE7ZC^w7V$Xp-wq}nHD8B5q*)1n6Q?%{Z8MAx} zTly%ycxN^~zDd{|SvKV&6TY!=ib zbK?Z>vDqSciA^W}Z=w;j62EPg8)-5wgsHCkBorHTv7LqRw(P4=(lrnv;?C;{%t+X} zTrZbyy`G&hs|!VaI*DwnZ)S@9Dm&Jom}TcfEE z&u}TtZqBuKfn>pZh_ydm(MWdik*qvHGQ;qR&qSwJPNDaVKu=+l&Ac~5ISb`ky@r80 z)p@*BcqfC5D-}sZQ~H`{%3TiQ%c@wSsU191D@_BAF**nFrGs8z&-$lC&xS!q@Tsv6 zpYuaZ2`1v9|3-B6s>)Y^Fg894aFulRYY^9q#&f@ETFbii{I`ZvUlY;ROU2VrrBtGt z`@}Wj0JmRK_QyLZRB!apy{GjZD>GfYX>p3q zColcFvn|Q^4=i(O@Na)R>>(rhMc`cvm9U0?ix)~7m z*<-dMivO**aw^^U**21yrH3qi3Hg%QG~i^#MM8Fa^(M_x?KE{R!t2!1E?a#mbezvm z)DQ$%^|n4*Ar~ez+=c$N@@hr{^WmM zs+z13Bh3X}>i@Rje+1QF6e5Ii5c~c|pkB1rMx-`$%1oC+l{cvBG}8)*L82R%RNYVo z#HtJzgp80UEj3t~-eta(fq6&N4oxpdkq=%@>53tm+gD9yLVvbqCNiVkV_$RUnb9&e zm2+a4TPK9Yt2a*8_R90Yo#F4cRBFgrVS}bg zRqO%(yuMFP(|?+bx{bng$ZU4WNX}`59G0R}uH|w&qsy(x+a_>340b%A!>6*=L->UH z^hYhjP2uPut)-9Isr3(1K5iSW+y=_rK9Wlu+j=@>BR?L~r)mQOUoogv++ zIcU(2ws!R>-U7byTfb2?Gli)gaWxX9m2bX2WeuB1D3^WusPMihXL%GS#^2IWZF4Q#wAMK2 zWA=>Wd7OS^ifY27xW6YCatwRCbh_y|(e#RHy{-5|bL)O}Eun6l(B@g!dBQL|2tI2| znqrqF9W84|{12_$SJnA=5N6#c*jP`WMXv01&0h?ME+gYX)ggpObX(hStr(lox<%n~ z4GVFg(PMUF3*Si7M7DphZF|nqu(hd88iK{-?abrfGz0h3FykdoTVK#MHZw(KM^S3q z_B_Hks!vn$0m-q_R^z`puqg1KXoZI`gs}Y-F_}()_vZqt|3e4p2^~eRh86waR^X-W zwFd#N!2b_xKwBe()I^KHGe>L-lzCFzH9V? z+8qkZD0?h^a41y^4D4E6vig7$Vf2U-TseHA#^c%-5_ zu)w_rotRdBhas`~D<(T=;(|o@qA41vP^1;wehduIkheGz!v^||W)ESs{`0Qtqxfj+ ze9YqPc8Kc|s>e2#X1zTz5g2HH*pU@af_)QSf_y<&>-fQTT-0d;4Vy{Ce2Vt*T1g^M z`p;3v&uQ`>18$Z{sTStm_Ai$K4V&SlEw!U1Q0uS*H=P-Zb6UGV4V*1~5pD;4`9Yw{ zO+#39=mDWGKZN1nZZ>RB_U+|5yzb*Mzq; zRzoe|*?6p414$)fU)R=oPs3}EhgmEexAhV1=|8YQtO=hNRp+!5Ectr=Mf{eRNXj2P66k%YziA-xRppI!d-ip@$5I{_z*?ZVTBflbKe zF5PJ^Po}9(NFYP@(bhHK_#O@XghShBY@)^#Kl3Q4->BtLxfH*5Ml?DIM6Qp30)L`9@?#~`R~=ugkz#>RvuJU59(wAe z57PsT>XXLOKf0=%&%^RQ)2r19(;w2Yml?CX9LU(IB%6QEAZBK1=C%P0kHrd_y(gR* zmN!r#HO$YzA97yd5B0A%bK1i%mN(%Kf+GAO@6~qrp>6OkP1!&)xFRCz6Py@#v2FVk znlclPy2~{-E#+1)yQ=fHmKTgiMAPdx-qaibHi){&sB?58CMs6u5M2ewf7hIS+_Z*i zT3Fn$-WuS%{ksujq*%mnXgdZWkZOPHnqZDrd~zSz^;!Yo*;k)E#fss{F2`6P z%AmH-)3v-SBBRj%14m=n@DBRrl+Q4s8=|C|j{Sa!c2etUTRH6+`Q@LWPnJ+@sNpc} zU(!*Oujm~Ged4RedujPDV}$E@M9RzI0?>K?@_8r zbZOR1i)QS9v1%+;9XYG=xrD5cBiQFkpSW*r~t9m zW;xr!m<`^Rp6yOI%2y+<(AyxD+H09!Jwb~y_}D0NPY_MMC7f#TOIz&>x&~|)QHAU@ zH|Llt*qqzf&qBNS!#Q78_gU>3i>n4Pv8i<-o=RQ%;!{H`(KLD=0XiR+2Z6*kq*jdp z-Gm+Lhsy=+KK$lG+F3fV!yrbV_^Q|a!D7zv4*_;I@$Uhj_1Vt`F+5Vsy^n&$JZ2FJ zKe!8S;SagP2g45ycPs$ng~fXOVQY$}vxkX{b&5}1n=Pa%A|ofld1_1bUjw9OO}Em@ znHLh<8vl1QSTh@h`I;)-<5-8!#o9rg$b4IyhP6RIeqU|pd9AY}<+-ykRE}6(1libR z@$;~0<^M@Eg5%gYhZbYwvh}TzwsG|lG(7Eqf%y$d{Nyy(Ara)!$3@`$d_n^}GW2pO~WH=1&{|M95T=Ry5OEgz-EqoDDewj&EZ z&}U(_Rt|M>htIL*r)faZi}Fjzu^`JRW@f{6`A~+(#m3-f#OtU-hV!GwDR!i71d?(W&OG-ZB3Bk z`pkeqI{(ASwB=hZ_fr3+ab~%sbkKRQa5F5<0YACI8nDM*TmguVCTm^uB7EMHit87F z@BI4SzXF1&l=FaKY}DJXrvP&?+3Iq0V9kPRI*}bHeY)Td_1VykOM>CEO&;D zWnAb%lterK7&dN#-Nw)@|1)-L$a*9rGw1BKYvv8c~tx#*^;;!rSqedUXao+RxZSK<)7oV z^1XAwU}3OVn-T8~VR#*w|0w=nIHemxpCWVD1m)O#j!L0wo;ZPp1<;@;AfvX7Nh25@ zm8HSgz*vWtzXsubV#U{k-e&!{wBB_=BJbf5^+nAKL2A%QP$(MxXBds!f`1 zC@j}?iN;l3RX};q&nK#Wqjqco_Z6@6A0zdf1X4JOFOTGB1wVs9tN#-hKF z4|j9%U48aJG%xiz`9zN_K0Y<%{7pPvyGQ8HSL;<8d)Frrc~EvA6GBy7Zggb4S_0Wq zKwQ@Tv2=1swFafc!Sw95mbLxT|0{JiY+?yxm+CrGb5UCN*lZDg(cCdHHo4TWp&i8Y zp#1eu*wohZHEnODM^JOtVWhsfjXy{$9n)yr_ORRm&edIxr;Nbnven6GjqPMv2mM9; z*=aaHb8pCNXFO<|_uJp(4f*NiQNn2In*Y-ko-#mla9BSZ9+mUqyw+{8MJP7>DxnVE z<60VyCyX^2{D1mxn_25DD2Ai#+Hk=ge991f><3LN|5I->wdEJWorT?KB66c*G>yXk zC@>18cP_&u2pW4|tWl-kHZ1OFiPusV$YQO(6yTugSD(g%tn3AlFy^t6OX z>Z0ZW4`w6wDRTJ)NS2srvKPZuK7QM$P2Icr@-W{%RaI}aTns|}!IsYOzY1Ei;c}Io zmZMe+GF+dZl7YV;x4I!BympPe672?XzNfKNEAsmUkzcMM8m{Kt69j>|t7bYSsyuqQ zlm~Yq`E6eh4kK=V{d=_^^rmiuAf?Y^GI?oQtTI;f7fdbA5a}E-8=2(Y>Vqvise2$wJo)##LMQmwEbuAniKfj_EQa%&`V~iI>Kx90=e&v@ zw4HO;3VT5c9zn%D9;&9cBy*;JdkmR+UH_qU<2_A~8puLz7JNo(KJyP?vyoqm8tNka zjjH|wkek&$H|O*;cl@fle;f1`yJYd7B2OCy4cNFuJ=C~3h_{!<3tEV-f`opLRN^Sh z_@Ia-=BhmS!ViY%2LBt5?=FZ#bk_vMjG0Q;2^Iv*{5_cYW6gIi;sR(}akk4|Mx4C0 z$1NAYY*uAC6U!^x45zsEv!zz98@AMq6J-CfTRZ2cxPva{q*{e(I^{8Nh}Uq30*u-eQ7%;-{K9@6dy2J z6ZtZ@ty4y2 zVU^b23dzS84YXWXl2Hw>?_<;EeG}JdX=^BDQq08RN(UJ*499h~=lD|1bp?NvGP^e4 ztgG>B$<(?}1F4O(ewRKk&$mHS2!GuPuq`8j)$hW1xxCmbE4 zxtKxc*EO_V9Q}W}@B7-me%AUkCse?FRV`Gl<4&+X*B)AFouLtbHn#CWz~9kb)f_KY z98+X3)HK{cafy3qh9!fL0F0IccN==qG&KUr2s&n?7u%j8myA>Gu+gR2Gp%^HG!{D` zg4-V_6`wy_bWLjGad%s*nHGz<+wCiDzXrQ~S2e`bp_7V?1+hO27?856h`hc9+1IRm z$vC=_^ytz&$SEsRmZa(X3$y(X8nyA#_Q5(TxtepxJ|;VGWTUTw~68(DBK#o95*v5NziM+xj$32dD3#wXWPyJA9^B!CxrWA39pBu^|COJxXszbv64M4w%a^};XkK_yp3B0<(`wdr|0lMqx@`V~Zc{)ve=qgm zKUV!0-0Ek}P~Ni*D$kx-#ilA3MHf+&xC-ZL9v0@qJRj|#Xv{vYtrDlGRMTQiW&VYm z(dp9GwN_!cddVYU<>pz$FW){)(_X)Yo7(R*Q&e@fo%~YzL(=PpgyPa@(&Tu&EO*_k zDSZ5cpQ7n)PZO%ia2PI2uh&|9sJydFzaG}|YSw?Wy*9(bQI`x2JcR#f`txX*F4h}p zUATll6*}tZ)pax8H1qgP5#jNWW)RQ9Chzzki@ZqKU@*KIw^ohgfue^UMU%8*V=#v6 zXIRzAG`46$%Zq*gVNr>but|siUs;2$mzw(vsT$)IDm>QR z)(8~vR+LcKT_e8^&gOZp5`6Ylr=Z`uGUS>$RN+xSY0W zYp$m4-MAZ_;Uic z;~iPOTpW}UejFS>DS!14#HxPH9=;R)#Grc1$A(xHisJ#BG)-8B0xnOseX!BoRbVx{ zKD4^i<A1xa_VTp#+<&r6 znnwZ#`S)!*%10ge$lBx|yLiCBkUx13Zx!V!k}E9;=mU181&MCy-lczM+&2DC^WGXAfeEO zQveCeYKMLdgYwbGffYkxPsV{R`oMciuO0f6xh9^J-8NTvR2?zXsFt&D%)s7Ge?$+w z!^N>i`8e3eSW|)${n&c4>)nA)-?l&if9V0>($eifTg8`jplyK_3|;*0Rd_uBDF-v1te1eD)G-l^H!k@qX(k;pUKk$0Q%$a`@;(o-`I z-ir=xSMpx*Pr91-xrl^v7pZz00bGdsb~ip2yWjwqqH~hnZJ=RoAlHmiK)Hh*7GejA z5z@@8LA2k-0$SEqq`9~eXco0u0ygOl?EJ?@?A+cBI+s-jiesr_He=gb){fuV@y(pQ zPO}cKfA%q5+BfbcySE)xhkCdY5vCzRUN|Cb|9}YdEPFdp*>wH2OheD?TRMAI4KUbu za{27E(?fLvDeU2+yA|4u$g%9r5HafBqtK2Z0kYqb0G|y=z|C&P`)=9Cap{yu%ihdF zmYyuPo3ROp6)j(kqnvI?m(oRCR%syBjN*A2sV#aONTt%nySf?Ettr#O0SRX_zUMp& zl+B7$x|z24>>Y;h?QbIzDI(!@M*L2}k2Nat|ivHAvfF3Lt8TYE!i(F^Q+ z#~*nN#m+^1?3_V#&cx1XS=jkW@7`hk7EHEba9Gl-2ogqs{E7HO79f5HW8)+WGbY+?w_f1#K=GvMv-QS5 z6f>usa54e1ucZB+E~3Qn_TZp7R}jN!#87>Uz~JYX*xwtiTuyqrj777i%=Yu+QFP2q z>~Uyb%r^Z=6qD%hM#gaz#>e?EoUFpgn9iXW&-3HZsgx{4GGiXfu*k@`H;dct6d;M1 zkHo(l3K*n|a9q)EIIe#$7?<%Ig0#;?RH;Tdt#hz5A3IZ~W9QxHv9ndL-aaA*Lo$oC zeII+5f-`dub{5UT&cm>?Y$A4!NyETw(-BGL5FF73kz6>7NEC>KafBdmHg;Z2kT;beZ-_%N#hOL?(MtbTFe0NT z+LP+j04_R=;OYAi+?xRRFjCnQ(Wah+JrRvHBJ9ZtOB_^sim)fl$!Je>M&ej9F4_~W zF-X{&@{=#nKLYI~Al1!Px$k^qLZC>-}mcO17(4&%P}K;A15B>giFcyz?h zqF31Y=oaL0FLo|&#Li9Qk;gpj9M*=N-wwde71+5mGF_w)ahP!hZiy2YB;VO2^*NIE z$My7Sr4s$!`&ip&!D=+g#{ShM*grWEkxxM6iS>y5^*BH-N~A~3!7VnLqQfmDKsF-Cmu`TcQib@Rj`d5|ld*ah z1`dqdWMJ&DdY{dkwcQGRjLrli+1v6Y`g$zAi(#C|H*(PhN(7!5e|cuNO+9UdcWj2}lBXH3*G%V6i?P4ave`PWjrEb+`;z| zZhq+$L&CCgjgWX&4Py9i0poJE;<&OHlZtU)A;3G#lE>lppG%j`Yrhwq!tA@WWq zq$CxQ#3>NT8bl(ShDc&op=<|WXDFQsotI%}&Munq1s+AO;9ng~Gn|C|kFLZ)znb*! z@5QBvrN^oI`|NVh0;Ds>DX-S>%gXn>9kxE7?g2H zPB?YLCFs~J?+b*n`V$Kvi;0d+*lk3zo(bmc0)!fcP)YvedZzy|ygI`q_rlJF z%rjGt+AwIWBLizW19WVz>91nbsT?4Fpy^fnJw3UP!WZ<~wl*bc{W4!#a2RcAlbFkl z6SsvLcVjVPql>fk4H$GUZe4k_F6>aOCcYe&I;Uq_RyFLu_Vh^qFfjW)3$G7%1G8^b zP};jcc+2`VtBv>IZo;{$znx8?&Psk3ZA$_0KSJc--T5+>WmVHDfvla5{$X@4O3#?& z;chyhaW?mR_oo8a)L`Rofpk%H`xX7k6xcM|%Me^_A4!07dX7_pTl$CTiNMX99-iNXE7rYJh>GZhooB_XxsfqV)xyd9JF}mYq zl5u&u9ywz7VhR5HzPP`L+E+%kNaXf^Pkn6gT1I=#>AzHO^w(@V=5DjNsLcg8f&B09 zpKV^Mx8eKuY#0(4eObli8ogZlE#53+$h0Kr!6?7F6MB6A{{1HONORn+sA9^CzMScd zH+%EVz8QLmF4><43c0JY22KF-Sys)R%-xh?uD*_rj^}}?pS$;za<~UPt$XSYf!CFb zGAPLf9`;kl#^}1|P{-y;EOy^=%YovPNH9}LGyJ=VY61VgPyD-BK0w)jx5Pr~HRgxT z*uRkI^Fx284Es-T5VE?56dU{_qpAMkM2x0{`+7|B<=md5d)k*%0`qWc*iXNHusZu~ zJmW5tN{?pIjlz9tg^u3o9jESZU8)Ig{J5{rz*`@TyUR6HZX(!arXy zsh?vV-^81Bf_zXGPgf8B2Fn33eS@AMAATJS!p=)edv^_lz31ZPIZO3Ue%lk_?9F&S z_-#4N>(;8xLg+c({kgk~H~aFy_rXi`(#C#EAF-R@hOmX?HVV~jU zuB|?QA_peMrKg9&G9iml{`yz9t9pg1ij=>;!Uy<#D7^r_Ihx_+_@{3M*Vl9GMEws8@`I#Vy}Cm+ zOX8!!Yt^5#<$vlX8GGLxKFzHBye(h9dm%oO>oRa?I*X7?V2zzG|$iP>;FM*{Kyi4^4g!wW3!|1$;qz zG<=_Ucd`0IxqW=#`@1$~@a@#@J;|O;f3e=eXMvr9RG-Y9Uzh=3cYO?OkJE8d)e@C1 zv2(0o9GwW?*Y%l0eq5Ka@YXBnS)!_~%Ux+!HsQXxFED9ss34i#e7+FMU7G`wz6;b$ zt};F9bKiUs$>iBd;kJ%b!%TrmJ3o@imQUuud>gP#ua(sp{ z#fHyRV{)S&!digbJiK0|a|ReV+lxC5W>kY>dCVwx>vO))OUvluy-hl?Lj#GI= zHCvh|yjOk%IfJTytFMTChEta*E z51i*s!p}R)|1HS-*11zW^S*gHDMoIjnY0P;h2V#;SFv8wHy6ikVcx@HrFh zH&H*x%vo>L)Z!lGM0zCk0S#FKd`T?wfG=f<-aXf#TrhkCocJQ~xuuukF}}RQ&RAOn*B08FXH8a2jx&<_#Q|i1D+>N?jf3!U}`MsQ!%bK<|NT z@fnDlWe{yi4qmii4Gex62aiZp56<0bYPuM{K~56G-Dx(1c!s&I7AFEyWli0UgDl`f z$9#~pI;ijf@eH!w%xWSH7xj4MoiMLN3s`8?23yq&?LISL z(PmhcMkFVrpY~w}z(>jpVDNTvlX~zIcVKXFC=BjpALmI9CRz&BS+DAy`uBOFK|g$k z!B4Q_3RJ6He16;AFfaB17@T8tMm=~-8H08TR;7CI@!oONRrk%4aqz@wUwE(N{I}o% zzoE3bl?|U#pL&;Sh*H=vIgu2T{AojIJI`Ce(%4N?ZLF^w;?7yZPkyd*v z4K(?JYH3v--v-J$3}<^N{Wo9v8Z7m!V=+pl;-WRpFzqFiVcIjIhrzTXT`U6!XQ=F} zVzP)j)$FR?W%#&1<031k;NJ?*)oVVZc|cg1r-ORQtaqvQcACmECI!FVyDk==fp)5~ zrvpj*?v$-SJIfqszsfNtmx%b7b)MsW^O6FUVGo9AC+8q+` zU~f^4P<+N8syV;MVe6u$SfY6o{i;bGqdU3ma zRHsH9Ez#A+rBdnPkD3(vG(|XN1x#!hR2G&vz`&h`~W>qv-$pG?e_1~YB>uE#OEmD{0Y2+ttO4_e(ogAnW= zW0}6*jcYxWy+4<0{S&=4n)VxOI#F!d_YcHTQDYgsum@CAeGv%dQ7hytw=d9%fHKRx zHT{dK0fE~wrol@n;`0(|trm~l0Nre>MXY$x5WDn~5Iuzj;^MzG-<#8Ssi-tMWB#+pY~uVl5o0VL^&B+X#6>I?0h zZVN-8h7n-EN-l|#-nLNV2#I;eP2c`+ft>2L7J#irSlRY1Y69L1ih3=FWki<7dB3Ol zAU*?QI#w}T$Z|idn_K$|=$?Z>Agy96vb2sQ-z#!C)OJx*pFefWHhjjyE(T(&owF=S z1im--0fHHbptA-L=wkI{mhNoZxCJ_i>S8gOC!svu4uEblsImG(WC`VU{Sd^{b0CyT z-U~Ij;Ky!xo}T7A$bwIw*BpjAddp!QhzI^;Lsl;G8NZZsXo1dt^>e<&hj@m1Kl>Oz zo$O!n7zz#rV;P)*^8iX&)Tn5}5g;-INA$I^Be~QqyLIsOykNt*L50;=!KRqfy6(38 z-~ye4SUuv?J3Nrxsxz5*&C2hI<-Bat)Y(iKp9=eQvhrb-Ii zp)MlSX0ir~3IRVT5ssu{>2LsoR<%V4NF%o~Yc45G4*11@V zFSHmL3^4f1KvJBPTf9&Z=M4}iEuFrLtWPk&l|WT(&Y#`YsSz z=*|VPS2*#~7~u}1Ry$H8#cXCORQC%9Cd3YZsF`TL8P&Iad6Ep!DGOwOdlV2Cb=Z!* z3z%RTpHK}$F?r7B06b?s3c;aR4bPJEI2;hda1!2!&DxiE*v<#x-HfU5URiajj=d7# zs}Y_%{HYq=N6uU66PzSc!_Q{3j_x~b$3u8u<3Kh1M|L8>i#7oKiQyG$_%Jyy(Kk4$ zS`EL5%{qJFuw4YguQZmc;oI3c0MFV8@YE46)$qx3UJgtmX9P)qFq@SQlbDF`B4cZK zucY78u^QlI2p>4&jT&Ai=T-B9lcZ|+FgELY@?pC)gikgWtKp|P(vQO;+5+(M5g*m? zZ{$3>?vcZGvR$Nh3^1WPrX60OlZUF8%lH_iGOxT|y3L@v#e+e0|M48DR$WIE+GyF~ z1=0EBs%G30xk-*5wktytij27R)c8lkyd+e}XlbVFlCUloK zK*CKV;RfRiNrJSuZaa`rha`N{aZ;1uVnTDHq%4RoCQ0Dk62(c<58Kh>frNHrsuM{< zxMLcSP#py%aGktu$^LP-cxk%7L#6w8QVMP+^s)9S3v`|!36B}xOI1>+?xXt^NN7hA ze!BH931dxY<2gWrf<&2qOH?MwIBe&JB-oqEos4>x3>xa2l4hM~#NiHQ>4VJot1PS} z+{9;Pjy3xPD87&BUn<3y4!c|36G|hihQ7Ye#21Q$I)eh+IP3kD^SP$1hGtDaYZ0_# z8x#v08)zC>C}n{0y#8g?HX|Lf@>}v<**2(>#xw$(_%tk@w27QvlDxi6uyCEpg*Vr+ z*uo|>*l@{GQ00YQVIT_J?I^liOoD;9G$c+GU#8yHmN+sQH(;TvMK<3g?lizJLwKGS zWc`%zk|>iffG#e#Nj*NqF6bCiKBzN>(F$u$PpC*WGIp3GnR*pB`VX zhF|B%JZ*NwPDa8@7MaBDhzT|fL--J{91>o#-=q-WS*HNL$h1-ozs1of)p&id;fAHE zJxA6@_b>VE6-hi-6tSybC20qF_o{b;9drT_(q6n#i;&k}B3_!hK&OhFa^_vpn}sWq zWC$Pa8uNmLkGm&|hbeHyDX5|(kW)ahlvtbaRwS{;kyB9Y$SP!Q2sSK07NWc;3FH){ zn7jlQl92^Q!ry8Zo;xxhiUtH57LhDSEp>HF-$5bAaX^CSxR}3563Xt$_5cavkc5BO zK!Oyv*)a!A{(+L^Fg{VYa?U-DRj}bh9Q+S2c7jp3&(N~diI)4wiM(Rc2~<$d02Q2s z7N7#5INb!H$#%XfQWGWR;2V`JMW@osC*fb;h&5byc0U1q+7Qs2r;_n%)lD>sb=(QX zGD!|719hW!pI#8%3VIj!iBU%Coou*dF2K)2_=;LJykwtA;^=-*;+}+;4wUG{gTvE~ zdfUc$LBjJ!=*|at9>Tx$@=%X|(Bz0Sz`rKpId;0+0Y3UOj^ECxBjG6{B?|z4HNw9+ zh44gCzc-^Ok_9>)N-<0JjF3RZ!1M&rcxm?&9w2ac^WeT8C6iR$DY`dJ)~(91E<+K@ z5;|2Plq&v|K`p=*)mUTlO&&tIz6mgeIul|#-R8bjDf<1PN23?Ser@j#ijLDKlQ~Jk zfT%l2RQst3YL|~~QXxtRHtbwlC?o0tQcNVSZ%!@H`Hn*Q#`v-{i8pCPw{BU{6P}oV zqYrDhy;z-aa_5BWMn0?7Q0yFhje?RW6TxwBGWlj}ycZn$agy(EJeTEIGwht8c(Iw8 z;SArzcy{_q_D@uO0}vgWo%iAc(S^+EbkSoy>tsILu1oblCRYD5&792+;pi1olD!tP zzzf~G8B71U>rIuXQ_XyIeGd&=HS3ZdQTH!0E(t;>@4yTC#A5;d5JpmAUcb+Vppb1d z(EPygA+zqTn?(2Xq%3#a7$jx+3w%$78QHrrd$If`QLgF@9Ob*#%~pqI11-b$2)07_q*LFPUE3?Z*zquXn{uMKSnl=utk1wrHB zXdA9|L*cgc9`mi5dOR|h7IAJ{AEi%7hE{x(9t`~r*4ZTQI4WiIDpXWT8FqS`xssnb zGuKh?n!4JDA&wy@ZMe_&L(DjI6x8$^5+p5hkne=TeXskq8}v;p;^^P)@r|2SSt%tnfhk%t-3wcvURhvQji35-LleUYXaf93L_i=Ny+;&&+k2+GIoK z1zP$tz86ltsy7nGGaGJ`<6}*oE1N?mCOBv10j92cCy+F@xveVfE~1X0K^}LvGH@36 zuMUDzc}2h_R~E;x?Gsi;=4Fo1&NxGgb{^y=?$hJdxe4ojCd9>@0+)b1414yZn7J)XmFA zuE#yGg)(;u_s!ek#}${U2|nXV;S`^v3P3p5vD*ZlL1glT8c(~!o_|#@aFi;0%{e9K z7XRhIO~$E8pz~%wOnw{Y)r2Im1_c92%Jf6x>L!24rbNJ$-AYcJLQYxwDY>6>kn>dE zLFAN;?U#6P^gJk@hU_l2NxbJEKYZW(=shU?cEIpgTJY?rM*x4=l0bGPZ{G4=B>s%m zq9w2-A_2edk<)7YoU`IUsFk-`P3>9vT9D;xM84AI^ps7%oQ1s0=50{Z&;R1jWNOy; zH|iGP{U79gO~Jpsr_f?u@0(}bBQx45w7Zk>83cP_hYVOyCXTN6D=7sM13(1TgF-UU>bBytK3W1i5Q$qJ;}Uf%Q1cC` zSvP91dM1w1CM~Ia>Of5;2dG}fd(%tsz`Pg{7VTgUzN)t#7&Q#nkxwBi>}?!1HEb~C z$X?Jk*p%FJpdN!cvkC84kZR(3G&$7ls>!C4gSl6dVC_UA+@_%FdtqL`N_#o}nLc=n zdBsl45dD{DCif4~UvtJ;Z!`7oL-Qt_m-mLJVJLU!3ieXRSRgX$pTTl@FZi{tB+M&* z@1efaTu*>}RT+AUNsDvZm73VIB(#=A^&MbNwT@B zJ-xYkS#r`S~b{`4*c{UKHtlvy8 zt==iD1WF(4MI;FyT@rx=(NiEHa8Z$( zgs^8kk*-;ao&6!zD$+e<8t2BmvLJdLlCajSh$KN2D)9#r)*%V^U;RtMVbj<#Y*31D zl7!@EqI5}*6gv)*!0uC{Cc)FS8c2|p0twcOOVlLDp7Dxxd!^XbAcH8sN?8SXbK#(Bu^ZMNQmNW?h3O1%pBJlC zZEPi_7vwNe)t*lZJM^2gs+@D+$BN;}%Zn`|%%CGxPcF^PoR_99@Gg07TFuX~PRs_E zFs{!g$UiCbAZ0FPSGae8OAH5_dl7w%sk2rtk?B4c^#4@T?Q-!Ez{davLLJCy#^5+K@seCKTJ+R^?FH-$r~sezno-Y zi7TsE3K5$xk}%ENO-;f#(=Z^R7)kKydr*~ylrF7eK}0q$SQW0dnzld&U^xtKvs~}s zc~RBr=;=mpT}dkJYbftQkXMk3q*Xu#5hc{IW=aB+WI%#4T>Hlf>Yy(Qh6v5id*$JO zBQ)MN-K(H_#UnuVme7Jp?brR=Gnp2gkOC}4X@j3dL#+p1th#k zsCyUSqf>DFBW58aJZHP)I>3KK_^?4MfPNxE3vt5;EmbKa$-`KT(25U((UA4EhGenM zkseS^ohDFqr|7b~wZcTa$%ToiK75Rvh}s&A(6*2hF`yenXtEVB5z#ARj=5*nsvAE&uTcC5Lud0YU<5{(7lpu){0yuR- zVjy=Fk%HygS?@Dq(r+0R%os+`O|yA_N#|;Lzrt=)-$S8xR*HD1%T*0s;HZ3aNj@fghcni673)RA6U$h%V9g$W9aE_X6LdTy?f~JoVfI0i{crSU<(~< zRmkS%4QE86(9k|mg`Mh1d7T_ciC&=d`?bmNIz4n9z7|7qc6n-V{f%Mb=#XnQRqsbg zp)&jHSn47SznJd%#t?qpGM&*heE_4ebWbDtc}h?T<`cq~#^JiStV zh8$@DN3k15E5}I#7<*VVIfkt~pO$Q>+?2EwW2=|HSuyiwvxy3ler1BoH+ z4+riGlCs^uaE|iQp}I71KDi=Zjix5pQfks#hXjJF0oisftoeaa#bfS42+TRvEZ#&4VP#RVA zs5#yfXi$i|j)0Q(7y~ieoBQT*4>|hhw(Gp{IiV5&uSp>ufC9HdfwRY){#W3?@JyT` z)Hv{v6!>@pgFuPgf%_K2wm0tI3^#pyE|)Zz3XOXgF!Wi^Hjng&^73H9nRunOg4JR1(o z-!Fz}_6m>}E-|JJXR6i)g|^YMC&G0ss4GH9QnCzeNMKGf1|fgt;6vo%#Vp8S|cgb0__z<$~h$z{Hq&M7IM(E3@ZSle$};T*Fr8MEzE#>Xx5K9%}Ff zF?JFdKPLmmFicde7{bfdmu&XL9Ut2mJGuQfZ`q6sW-N#Nfg{jgMobk#UMm+>r1*C1?9QCxD!wm zt5HZT8QpvU#Ro*uV4F^&5N|dZf+&U|3Q0$v8int3IV*k{g!>anD$ZCkdQ1TnO^D*7 zZ61k2zQte!qF^Hmug)Sh3h8rsSSp}kt5N)F$zc8jD4G$)-?l|03Tl+WXhh*m$l_~> z8b#A{x$GjK;Hgnuwq#gN1r)7_qQ$m^M8Vx^;D#v15VH7MtwzBqmosQ1;g(6=BQo#7 zH;7@%0Bhm|nC!I^biR>>IxeDkKWK5&e$?VOkG`q3_?o3jFnd2N@bBaTuiL;mN}Udb zbRr>NY~PcF$afe7AtB2MaeQxA6H;9+Pp7%!br+Hlaj1cwMKuu8goHev5DnOs>8A9E zup}TPF$oC~`O?+*8EY*~!qR~dF+EHrOvRof z;HoNdU&-~gco^K$QV|lqbU+A&)mG z9C!q#xI3VT>1H^VMB)368iC@7KorGZ9%>X#5xz&{lK@2CofJn?1n&rS9`Opr(cRF4 z#3Fsi-Gt)Ugjguvfod!$%0UKrE3UnQU!@bFkR(`~Q zxMN7|NAhxE6Guvg8_QBek-RrdjpEvXVnk6!pjdSpP;kX?S1dejfT++x4)DZk-2RRW zmzXag@aer$?_SlxtX?6reFQZ1GW$r~igMkG?8vdKQNKaO!qaS&wdJodjj&oJ3{KGwZKb)~Ep{HGL;u)muwSqq?4IeN5~$l`$FA z>~C+K{v?X@&0&Zlkw8)O_+J#K2C!cB1#dN+M8ST?4LSx6D+f_nPcBfSDBLVV6x9TZ zlD(E{D$)ixy+sroNfe@Y;!}ts3{lLQ9BoNbQMQ?J3a!p5WTbl506>AM>Ee5vDKMvz zs_9S7GYEuj5`o2&JqJLJxv4MQK->%mQ7RKiy=mau4=j&+Kip%jKFw0?{gyWj;1Xi> zWEI+)&0&Bx3eDD*fh~^@uasT6W$MY8=9;*s=@gg!XVc_U14+xAxTOh6W1J-{^K5${ zP3iL9u*?xh0h*Ies$&bV_ZIJ0u&1(q0s^T&dD=fP&wFxPE}vI?T%1&VoQf^{pF8*Q zarw_MaZjs{QOG=GSxE1lI(3j6Dj!?VLA(IGM1 zPZEad57!bUaWwAi-GlUGE2A0rT@qKaC|6Cz4AqN8ry#JCKHyD~C~qdFPhO_Wz;U>E zVX@p~krO@EG>+9zbluOX_@tkc$QV!ba{CsoWWZHJqdCtg#Qnt1VBRS%hM@wV+0JxG z^UU!)BVFm_t9eCHbWEN=gA>J5I(om+c}bLma0Xd(fIX*QnwxZvx$c}_t0(I^aHmOz z*-X=QuojK_!zsu7p_6Qs4pWdLSz?MQzk_p!*aw)#z=`x3IJfk%(>)k#VHAcx1s?Pp zVn1-Y?ud!gbuW1oD?}>{{Z>I1WZsV1&{cPV-839-=|PU?!5O}BMlYOlhV?ud!yCCW z-Kfqf(b!FpLtR_Sg5<)&dk3RWDmakBnBOn0I%eJ|zg0=q2>W;Mi91Xdt&&F98qTs= zv?_^`1TSA6Eb>Z#p#J5-Xgn19R*EEJOlR@IrDYN&vya#A>@#ujinb(~T_x4~o6S#LEGFdp@ekUZcVGiP! zm~T11yTTMM&!rJivLSF%uCo?0)aS%TGj8aI?JH0TK*sy7n_>!ipQFT@I9fD>JnBgN zDr3HMPPVf@M<%JCb?@f61n%dp!y(q17o9)aghRX`hPXc^Rh=_3>lx_+^BCoVYx@_^ zbL2d-+9YA7{lZ|0NUELJN*!U(PZ) zYfvH`Zy}sH>?~*4PyoOePMWvoOIn!-A&+DN1R#|pSt5c$GHE;-kU&0 zadmB@-Ox11WP=I_C`~kus4W4(iO`4!g9s{)s0_{ng9gV0>;^QLNWf?ei5OdhvnD8@ zsHok5iV;m4P%vn~IANSHPN*3FXIFI18WDnz3~RIE04(W00i_G`UQ0!L?-No53K>y<4WDOT@eK58UXXQB+ERZKP z8Dd3kS6`))f2WeSWrv_-*3hmW<~p4|@ul?s6?--qK5@;{x0=lNC+#s9JYV=IU^3>M zqP1;TV+YGA={gnWx+Bz3SaRRSWOUCC_C;!*zVZDIC8%w`>fiHrcIrmk(7%$5-Je() z$wDK}C?~|7(Ot03slZnG;CVLQ+Dd(6fm%_Q^3f6_`kxb)J;7KO&c+4t(ZOT zX>lAqYLQ~rgdV1@_*DP#f>n6UciZrG$r_lNdqJ{zFO z&JbbpoDWCY`5$lJ1EaHh`Vv|Jb3)Z>`}d{d3)t)Rwb-@6FSsC9+x;)6hP5NtZB&ix zm|`q;mx@GpI|(@nktO!l+cp+Mh+NG;@zt?j%u}O9-!U!YNYnDKyC^sw=IV*wLI$hL zh3OUEiQ6TDBue8hyh)9|KcVB8fc7^l?=Xkz5vduYz*PnMuG!pQl{@Dn4?(n3gum;U zb-`hKkXLC*bevxvq-Xw&-=D`^l;QWIR!&IM^;AmF2RR+02V?&-N7j3f zgfCk)1|+uAUV8Ow)%L^b6<)DvGL|!RNjP#S`zaWXxrzj2S6+ak4>+ClyHW0$P`!_; zj?S<^4S6x|n2Hh)7zt7^P?8%Yev7WB-}9nX$aYluEan?}$RqX)nq7T34b6_O*%Xbu zC)1pcq0j*_7a(u-vsrDC5zCNHq-B!xiq|oJ(=EsH>DjwWWAW^XtqZ~a+k!ta<&D`7 zVGq&TB|%$f3Q67KL9f@4+QFwilA+~4?~6m_;ST-iaO4S<`t*v05sPe?ju18`{)79tyjoP5KicVi zWDr@NSc6ZFG6wSPAsW84nk#LTloQJT~+N zX2LrB{?%r8G-E}~C3^5!G*W_XdA#N;{P0IuFJwS_p0>F+9kx`v(c!@6ZgdE8I!=e^ z@SZ}JJCOfIRUc2IVxE59deTJ#>*X=ukf=*z3+V7N=1&?JUa?v_TNb;OIYeKkgLh0e zb4&CUx@CDnuCZXU#*`PZ<^C14Hg9x|hp;K&{fw$Pp&n+5EIbhDtR+6M{34&1lP#1CJE&B70ZR^^LZCn_w6$)MZcT(xM$;rFo1^Xpsh zi=6-a%yUb=s`f_TdH%HZ1M&aF?JrJLKQ=~7V~<%;AH<%g9w~^{)1e{yEL~g?T}@}p zW6n}Hy#moG{T7pE7)eJuet>WGX5Jnz9PS<;$s7tym_y!b=5TBtI^aY(rCvjdrUW!^ z$sgfo1;Zu}Ij!in8n2P9)}tfoRZFAGs4+*Qe`(9n<5zt8xZHHGmS^7lcm99MYxu6- zKwFB$_lXI5$-;l*<7}4DuF5O60I%ZltH@T9+@9IKKwlcsu(n{SJ$AV009LB#N!avAEX@)FOAh0L)? zDHgMN6}mNly3>8U%}-;Gv_gg5D(N`xlf0i-5$`P ze8gXLxW5!x>^y>RU_3I@K){3qxnOZc`jcI583IPqR{_y?m-)21 zni}+Y6Y{#n$8LVB<1;UrApP0(n=cx0eXG-L^zPf}8l)(D*|Yuc=a<@+!sNc|3oz>V!o!nI+pOgm|y2x!mFzHQrfXp z^A$rtHDAMI8}cbamf-^7yN6hZ>*7KS^r-wf_^!C+uOe28_G@bOiU;{WH63<%UE% zKrj7Aws;#+KUcN(AKxYF@V~qRUh#Vp&sT_?+(nMyILeg z-%@=yu+`PnguBa<1q0%8LAvv8h~7M72_Pgm$KQ@@1;r>kNa8VxkWN0Trr{FC6X8(N zX=WU-VEsk(*6A&1vhiZG?ueMmxD^Fyg(_A=Z?{PfuCbe^Kq}4OWy#1`W8EKlC^t)! zgI~_Z)n_rUaJ6!PRsRN$*fg7@;P5@dX(p6;TNgKb_EYdk5DQr`qN-osYc&%hpJZ>npdXnmQbpA-OaAqf$edNO zUOo!^_Lm|BqX~U&XLiGo*75ntH_u+R8gPZ_wQnPdq@M3y_LH|C$SQnX-VOG(HTGd< zou=$hU)St>SL;hZf*E_&o4ka+NONH zvV_a6@JAXFwAFNB5~eiGeusm?yd<`-QiIBUq7Rw#iIxh2t8 z*Qn!?rac?p3dAuPvMUB9J`w7;<*bf-GJLAzzd>z3f2)I(8UioUG(|A}U^=SGEplgC zzvWHd@ZN50haC-f`3$`-{Jgi{{nk;45H9;9RljV#>Kse2C`kB1EKp9^FZ^t4uzW1N zYCz1nwrvgaR}4W)sQnyoe*hAS^y%eoFkFqT@k!Vz%wM)4y*B^!0E4|J{&BheBwvB3 zlLserCb73jTv)AT%vWQ}PlRutlWffNKzc>|!}1Zo^VObubNF<#Y7-M0C6v%$E0fNB z=H2|D0bHAYdz_e1XCxFh2;=nb#DE+@OD^Z7S3OVE3GoPQB?A-kjsB~wT#X6C1Xu-T zxzfm1vn^(&S1}jNgA5q6+9#n#xOkLf@rZVLC;Y6S%uKtvtCw(&Vix~HFP9-78P*Fv z_dAoXr($5j5rehB_sK_TWMN=zQwcR!j%sGo?2+{5*+_GSxL%;OAF3IU8Yib)qF0+S zueNedtKoSfJ?piwnV7SKo>saEzrf2#w0`X92VB~jvE_$@T^9HXzyPu(Cc5ktwHVnX z-={I_WahMfy(|dc+$AN6)|a7=Dv!hoU!Bk5B%#7R4yhi~YF@S;)A+WQ#^_8fxoOnp4XHNiBH=qZW`G9Ho zTHdwNP#$+d()CtD3CY2}+PpBC87lb?D_{ImZ%DioaG~5Mp+bnaq32+n<^SN)ONbX{ zmIE&|UTUR=GkRr}ckH{izZ$g3iXjgx&jz)UF*U8)w2$xwUd89|Ccl;C?QnWk{3;uP zK8OXxbh4c3IyQ^f%)2ktZGEYG#!~l;JMgqyY7X8OiQp5hGfy)h;7Xr_e=v=;!0fnX z<*L_=>fAIc@IBQPrVa35)B6N$edEF8X^?^AFl@~?l&9f)WLZ#;Mj8f1F!D5{_D2SV z`YNPYc+sDJ{Yr`y3-a;^q*!?6j1&tu&KTo{)D?ML3}c2bB)C)Z@> zXh04d_YrYk-|yiDhF@pms^0h`G4{q0!*#|N#Bl95hUkDhnOm5{Gu^wrUDxCgnK%wS zL~aTY=&L+HcOV$fM;i zbV)=YE?VgWVr8|j0V%QfN=G4}p9G}Ef`oU&uP zKmPu`-Ey!(uQMLGarL~cU1yR_%6uJ8ub2DEd$l3Yd#IQwcFETPNrb2Rlq3fdt z@@&yMXH)b{+)XNBo-^uIt)qJD;;5e;8J6l|rYK^IRy9g-^V}M`xsD-#{)}#Jh@qRW z5$&=sLoGd>x%n@;xySX0f6JqwjzJy;j)u67VU%9Z6u6M-0Z;l{G@A<75WBP}g9=yu z%=Lh1I@<#|Ke0U!!q7?Ic4YXnwVu=i8T+W&^F4^oS+kw#0UIhLgJE_q{ebTQc+sJq z5e&ccE~0Gq^`LrdXOU1xE%8Y0CZ;HNB4r*$H%~f8J2&k& z{psdGbhE?Rf9-)+QO$aw>U0F!N8*e|8jmrX8mH)yi=cf9W#Fi-OzI|*>b%Y0Q`+4=Y3G3? z4dBjhbEwDDd`%mb(tBM}K*n&_pH1ddegexgW5#@$)RpDDzLGz0(u);SO{I}?>5K>| zD5}9vS&(0AWE9yWwHQxDI^IPyw(tFYieAd{Wh1f5#bS@&B*U)^A3tel^whfv5unsm z(6gYoo|_?^pCdi=FC|n>i4&v2;DRsRZb%t2enFoCeDdahDr{$tG4R*-%D0<}iYP(I zp&kZe?j=(!`jX(toiMrm&PAPjeBEuY?E#Z$-}WSqay(DreP#8m-_I=hgB8bDIxB+2 znWvF1kWI0OQqRkfF#SJla@OZozl*GX(;g$neDm{pY4;1;S{C_Qo>5u;qP4u}Yjs6y z^@qy(FOl`1zSeiN){nGp{}HwQ+qdltRomC1_I}FtBh>9b*0mok?l4x_VS>8D6kUhu z;yk+&qjGs0CJV*qhrH7r_lq5~)k~B}j?!V%Jg3SPE)FSbhjrqP8`T}N#7=o?r$TY3 z3bAvE+PO^K>6+T5QQY~Ky7L3E%QJD0l%yW(jCyWN>X~KKD?6#zexu%CL@X={@3TCa z8A9#vPR<iD4>YS{E;E4H zp(x&Ga;b}@O4}C(CpMk=AR8S9mI}_%E!rRp|(oB>{yxt9?FlB{; zzgycK{M~hJupr}Kq#XY@CjRdGGUD(47}uKky9X_)Gw^qxAqwOFb<*tF{WjhA8lnGl z@zIx{Y+n0q{|9L{sjJeD{o$be?u_@tG@I=H>vC%t+Gp@$KRD?47&O}F?$3CWCIX?d z+$>3>7$GVu^ai7J-?#r9A231$rs>ahEKf8@!B3t13H;yQmF{4YUU!OUrPXiu<~O9- zG=Q3!DOb_f8r0fmNewJbcM_PvW!v`-_M0Jj@zYPsJx0irqXNc)uX^Z@+dW74{`g*- zs2P%DgT7iZV1zdJ*1Z_qHuwi0@INQXQ&%zt)$v57J;#$nG2><~Hw4n?zX^3Lv$ZsM zdVaH!PT+!KV^#z&R|Gy8A2iqP5TCX7vJ``@c|UkES>jkn8RsADew$p3oyoPOiLm)z ziL@9+@r2FyDx?ypR)oz#7vL%y0x#9(gw2DkN^vJo*!;-9YaIOD5#Z8{uzAR8Q}_XQ zMxD(Fo1c1keA0}txvbp{J$#}m{Gf@jIc2;H9C22WD{~WJb8Wf&Iy}L7x7*;76w_P* zVROWU?eJjE@Por`6JfJHxHuno@`TMF+Bmof2%F=*Z+{~Tcq?`YFjI@Xhl6&Oy^y9^jIXscDcEYgJXn8c(a_xka%gi~tV}U$JX_CqLKR>uK6t%yi+GFbZ z+QC9zJphEvT+(-d6JObU?XS|-$*Dz_B%)VitZCqEa5&p6Q{!rjq6J+-e<+ykPqPT@Hv@PJeH z;D~LsV|$~0+o*l3lor&*Z9Au3dc^KK6E~ zR%O12uXba34&MA;uB4FBl71LgRIMf{!yG|qA#e3QwrNsX8 zw(YP$L@e+W#mi4fZh&$5Cu+iec^=mU@w$U(LM=6+-aH3QU|(ZarK+UFO$i=k!@2$X zQaVSUm-IHv;rcve#v)W+MwOQqHIXYn@bvIRSOOS^*~UB=?md9cF{oTEigLHl$C}P1cbIdCT_C^3i~XG%99VKkh@E3m$cz1;QoH)Jo3bA#_5=G%Yw4WC z?rT=aiJdWf2gJ@I8uEssete(H2c#$`gS@-g&~kCAO;iME&1uh9OZO3SI`(f(#yDIQFv)!!w(kjjud0xe4 zjIUfQu)|?f<*5lOIg)SR9=652^^WZ~;;@A|650d~fUsB5pm5zAgrG0y*pEwu9+4JCY6cm<^`<6x@BxlJZ(YPd{^5oA;kt{T!1Nfp6)AXI_(sHs`wOn5xC|Fu*|>GG z;7WFlu;iIr#|iGKaLS`*tr{Z`cwCUnH-cM3E0@g@^z!8w^CkzbGd`$pDDExW4VmY8}=$6 zx4oJYM;g@dPx~P6mdGelUb=uvQ%pw{wFgkewJWG%<^mLZAU^_z0Mesm6;ogq2}j|Y zpHaAWBDH59ZmzqGn;*=_%^yMa>G5oR3>#Bql-;qbg?={7aN6Z zt#q;M8-0rY_jIwmS8M8_1Qp!teH`jn$L;+dBD>AJguiofasQy+pCh47&Kbo;B(#CL zRPQw+=j{AK=w0vsoc_?u2x=HKuO!OQ^Yl-sY6Z{UJpP?g!<>;^iS#S45fT5qKDKm* z7uKejWq82%Em~xATP@)`UOG0H#f#Ob~a{QAZVa~IYKX16? zpqzWYSHCnH)vda(E+aKk@r>oDY{}%ms~ptt-re7yW>e9){QGMTx}bge*N`^~MRle= zcUH%?@zdMl&U=(5ShRC5MtF zYo=s%g_0#BCF_Z~(Z4BKhZN0~tRN^=`esVjSIjw7jnX8=zbIM5`cySjvXU#%@*`Z! zBf?8=G3U^8^5s>Wo+EBjvQ!pFc|GSrWgL_&bze?Td0?-GlBMoTV)r#W%8A`f$$DVa zT*-=?)jv?ri!C1!Kby`W_CX6HXl|^MQ}Eq<^!a}3^ON;WVz0G|%wRM!gSr?JmEy1) ziKe_!c}hRCQ(WbWIX^?gDLq|;` z741%p0t_MMnLEZZQmY!uoTGRd6q-UMjnkInjx4LW!#}8?nNzYpH-|Jb` zJDcfQ<4K)rsb`&Ure}Q~B2+Yv2(MhdE4*^uevII$8y^X;+?G-d-G!P-0%5YZZV+4v zEw^u_JUt&HRG#t^CJGgsgsq!|Mcu4yRn$?$hG`cG73lnQ=W~LG#+Y9VE#Z@%%0UFm};C$kFW_$^e5 zCCS=wi!46v@%yKs2`R$ath94~HaQnb8;g<^OlfCGe<~!EsG2mVg`_#<+qq$nmG&r; zrdU|29E{X?D4~{gsXEf7>PeS+LAumRdhqy^IA*VQY9DVj&%EdWZmu6^@wKn_IwOZv z(zH|qp+ZHFrX?axtBzEtI#QuxNQKIyG_;eTX{mbSHeb@T;z-ky(`_kqTO(zzt-v4K zam5=yAwiK?MP;HJnSYUTkw^hsPCLmQM$hVktF82`hVIr#HRBrX=TI3z&bSY#zy~O5wMO(O;^rZ*FmB(W`meBr? z99&=5hfK?r6dN`%o~B%A+Ej|ODz`eRF~L5R9j1Q?&mi7dKH!=x9#*#y2aX=*&9e)( z&TJj}hwVT{=8Sb@j2X;cW259!_rMHH#3iosVykmR4>!y~rBqzY*{F0F<&`3CB8uxZ zi}t68;UxJGzgFA`bw=nl$3i9>O>Bz_l)laEd|Yw~_+_p~NM~VDWY?>bPzv9UC8K^C z&+4WbcLVDL$5kaB)UhqLOVA?!F5bmfSFs-!U+lP>OV1dg8W&F+G{lPH3z`+y9$G(o zIH=$QIFsVznUjDph-M{2X62=#{jGRoXFkPaME$r=QN}YWV-#Y@4K~$a6VTm-a%H@j z84pviC(791Yvq-t>_I)#Nzb>V**?~!iCh`_;L{F}rOL2yDnINv6AXtD=#R)Ym2t?# z>xVwAQ+7u8vAZPE-vXTlG->=jxV-aF1H{Hw^|vpU1kcD1eTacEVpg8PEqkpHPbI)Lo$Ny zgVO_#RtiI-&~AQ1sebG)L=QuWKvUHUympO-o`92s;Z-W{j`Q2&Cj%mUTcn$yH8&ta zHSIw8bqtu9q4btvTzlgGE!OmS+slI4F`edUBO)GJC96QKidtlf(whm2sY8ef%2=9 z_gLM0wGq;!js_R1`kKxYSuQRxVyhivd%#{X-!Ncrm|TNFaC|Gtp4y< zPqE?tY4uUMxXkhULR`Fc;KUV|CbNuo@aUUIkWBfPw>^+y`Piv1HOR1>x!q;+?+)rI z#eEBr@cP)6q(W?~Us)A|49mKXMfWi=7S@0N1?JIvH9Ou?qi0Ry_SQ7Pm*a-T5oM;7R`h@avNliqyWPiMKm~GoB|E|RQgdr_? zuPgV(_F0Y3K~1w@aS>nPEe{W{>*+bHpV98dM(l&rj`^2bZtg2>n)@#Db6b|Z zxgGP2(Ip%4(dzNtc^7ePrE%jHW46-MY~??e>#5q&#iwy+IS0`P+f2e{|0oD)(`M{d z^@#T+i*>r`E?}%&tTE&u&~kHtZqh3VFEXbux+(Rs&sAl|3NWg9a@`iZEJC^zyL&{Z z)|CBtt|w+JWlx%dnzA3i^VqHI+${U?JOBM^I!AMwWj_wvhN(wG)i%q1JklywkF9?` z`|(JdxLP_#7c^x*mQOr@E-0j3`KIj08Dot*rC`^a&wiZUevIjMFi8rsA5S}nYJF+` zf~yuk;}uk7nC<{)B)<#ZcX*w;C$t`6_T$f%U{hb&mjrH>{a7(B*(JNES@z?y_D@Xr zf-eBU0O|`ISb!|DL zJOyozB^@dHJ5W~`Gl*MZh+>F$6l|X$aHVn)*TiszwieNiSA7->BWtnC;w9W=x#6?8 zo3OM;UA`6GcF$+Ao8UIYXhJYwoZs?Yq^W?`h4E}$39FZ~`yN{IKweCn6|n@l=iv`MR;2H-BK5$EAz)aXO+04R_v6a zbbcVdWO_#{z2nQTTv*mplZu(G1~c;+lKyKi&c@!vm&@-jn<3fvD1A3ptOqOhodMqN z`;X!gma+UVlM1~?_^w&<>l+~Q*cSsHrP*kEuiE;DgE;#2&<5PGC+P7j0|;kb$;_s_ih8f!y`1;5|nrh=2hMej>0U%pAhvgVOzObhg9W8g%mKK@`qplH<)ca z1`7=9SaE9{;Mo}i3AXr^1K@3ctS%`bp(-kZG=&J!xP_Q13TT~h=5sy8jh?Si-#~(m5o}ka0dy8^`y+g!&_Oay#Gw)^nr2T=f7eF1L1RgZ+!Jh+IaUql`G3bXj)D`OrnPTd9fIu73NZ?E2lg@hc!vhlHVuE7igQ-%gI)8dT zcdo~J>4cA|#L;XaC%ioD5ro&$dRd@z;(Ko1bBmw!$kH9ht+yEyYU|A(0W6RY*bQls z6Y4iIOH$hNS;fUMm#}n3v_(J2Oc&#GX1%3K$GDf-mT`RaYUiFMZaVTm^LI8|9CBC2 zWi3<3-ow*cdfL}?37%$qvErM)mcKT0B&)dGSEk1~Va!fuOYkPRp>xZ`7Tp+tZY=Ip z%XeeW#Nr=8zRSHs-EP~xUr#3I1MPSAP+LVX(FVe zcc!!thcVZMKS76(*P6eJSt4FVV=_p0&9}*}Ez0SjO0f%I^I@gH&#SvC>AC5WcA@Np z!0?lT@#a4yWA=6hzkwi*u^BcX&=)Rd>^D#_UiQ12xF;$|>0$O`$k6qtxvz9>8hpG$ z$j2+hd?(-I4V&BE1TacsGVLi3KKM(*;rdb?R1kM)Vh~Gt=;l$ zPFWGB-#`;J_Gn^``F_5MW#78@$dzCxgnbycxP!=+MorwGq&5G>RBh>1?>`9~c0bjd z1EAB+$?iYR{d-$_rH+^kC&C<=g;KW~^&J8UEVc-;L>4357@K+~*5#eizN7MkYqU=( z;@1#k|J)AFN}KZlJPz*xY)aUwW3*q(e3shbRjDUyvoyr+Z(ua)n}?@DxsJI640eAu zROx4lA6#x89!9b8ITW9NnDp$}N;0%B2L|(uzD-Cbe-9y<>X(d9lEIbgR+n9@dk`0} zes)+3Zy}$tVTpv|68#0}3C>Nk7ld>LLmf2v=3#;P=~lI9_(^a73fY&0mI=oW z$F~p=opnUnLcn9`Q_WiR(i5M>ZOWVX65h<%7LS1J#dKVY zE1-VB@M3&{;KCvblMes;u;u%%EATgARs4^zYDT`vX_1+3FOaE<0*2z|SXla+NsBJp zwTBgX!C|dP`kt(aJb{W7Gbmh1VjmVF(aJSst{y>NLH`iEZ6tV`fI@0dpb!mt`Sc*m zi;7V2(goyc6e9b4z5L7H6zT9Zno0OKmGJNWD@h3H+m^>vry=XisVOqCiH%t?1B6o- zJ-ck=CrV;Gep(usOOt_oEg=PqcNaj&b`sW4XC$t*!n>MYoqs`$0fSn_E6PlyzTMj1 z6t#_!S=S?crrx@t7T~t+Yy4MXx>zKgDKcH?D_xXh>hF?i?38KDKua^SK$|7W&Dqce zd5?8AC0z$?Mx;aEKM&NT+3Y{|;dU6EbM78+gp19whT;N%;EKNHe*@xnyfXALAnwVl z)dhgKKQ|^IY9a65%Uax#J>bM$wqP^XXGhLD<7{Ux&a_}JzQdgW?Z2^il3nQ26Uk4! z0Ci`jPbpj#qSfu&D{T&rmpL#G6fqApI5BPhn-KF+pk?_##yl?WF`SGhE#oEDw{I3N zv7QN@z_P={OQ>M)G(KpmY%^EpnJvcpKr8dN00T-jd}US2QTYI!^<=yu8@HM+X6RnG z$12p{u#V;UNI~~atB|WhhXN{s5`H>dt8m5=vxLA!0H1HNpq;bj1vfR9`m%QPB46)Z z$uF3G<$hxq$8qGN%HZ3mHHV;Xj)Gf&#Y3&yVPBS8-u3O2h&iByHf}~7)Rj-5Wl%+M z%$v{AkP%!_UuNseEo!eqZS=4LVYT?Go@;B?Wx}~TsIBXHzWIg?E|v*UAvo~ZQ0^j2 z#5uvvuCRZzR^^4OVWUx#RoQqv4?DYf%D2Hj!9BUI{rXhkF*0y(4B=pB$#^U(lAozg zdpk}jO@)3Vkk$ah@}wG@G6R+igid@k)o6eLTtrMf1R?b#WK-CLer)na+{uMaC;$nMKIAvLE7TT`%wu|op#ba@mCBj<;Zu;z4dohegt3ia-^C+(>B=A6ij(g+T5@ z>SqM#@nTQ5TA)W`7jTV_m>_X#Kqd)C8pw`j^Er^~8&;3rE+vqgqih1X3p}+-Z+OFV z<>{u_ERqBC#t;lyD9_`CkGTxt;|QD9OXGOqwc`gC77am$AMSba)bEVl^oT1Ud|TpV z2Jp-(t3!bv*^zcCzb)XbX|w?g+KuM$h))zUST46q6z8Y_tTXj<3FTqFF5vq8dQ%PGl?8SImhYgM@u>Bu4z?0RRS@j`s8KEA zqbeF#Q&h!h>c1ZAcsQvDGj>N5141$kh%n6Gk2yyVY?rU&;G{D3IV!KG%9|tJ+AI33 z5;t8>?)$B^yP=_T8bm$Kc5;I#A}R{{Kn6uMX#9Hdn$ii+fQbHq8ouQxhL4J*bBL}s z_v^Lsw(VFyq(s?=xA8>X*=(=kn$}LC#6T<>$d8JChwU&b!sxyBE`U_*sE8Q-{uoH{ zbJKaZNah?;9gr9J?HSIaY)?C&Y~56PDQ+cvw`iyLK4^A(o{I*Tp* z+^&s|d+<gre9@&leUFenAg!o_wy4L&B|ZFz0AZt{g5y z>{_GCd<)1fqbP<5iz9f!jK)(5jYkaE9>WQYB;?nV(#Pyb$V^NXGsRrp^JgQ7K|G&& zqHi5vcTV`~DNq-Vb9Lv0pP7PBnK7gL{4ag*5U*| z%CaXp%td$baU9v14hvxzqBssoKeH3ui;kO00Vf}$d-j2vH_yZ*ktS|1x=q|;;SR)pa?h;lj z?teKkI?1#o^vw&a=2X~M+Q9Mjscl+l1J8Y;Zj&M>WuiU8m}rlfFGzVQ zJO>?R+RMX?7g8AGg_PK%rJe&_Xu5z(7M63L#eG}e1_jBh2(G@jmP&F^my1XE#9q6B>(TC{CW7p}2$Fo5-o|ooCMBqzf7}cuKgjB1S z_E$L7s;!w3dh$uT%5l`06lbv$VJ6RlwvK*j2rBu|rbE&AMvj8DBL#pgOOUnbPbgZG zg^JeMmR)-+J*$F??R;*|nOH zf06!G%Z6c)ey5$n&{34LyW_jJxB2?}&ljnNkIX3UVHRmNVu9L7vPR+TuSU$htEhIq z*j1z)9R#(JJcDH98I(aTKZWG-6H1;zyC}F~EX82-z)rcgEpF48;x-5L2&=QrWcDcb zF?$q)aqa(cuUrmoCfnnBC6uk8uiw4H*1`Yw{x_l-5T<~A_Qt0T9ibl0t9u!XTI%Dz zJaz+GTyF02H97~~wciS@@nhHDU%KJPjrEh?z-!3&)R0E_6)8^~ng4eE*{KntnSU#+ z>uaSqL~ZNkZ6B-J#za^@)6joEk();NO5?P^v8JE;K2s0OQ4fTy`HQ%H_xj-9VT$DT z-C0a_D%XnqiVV8b<2V2WT7A@V7?qS>W$v;$GE1*@%86y}?dKR)J+VV(C?yh{L9q3V z4QCKl1xI2dx|G|SXmCI(A{DKMuAIx!)sZ3dWZYV&4BQ4aN<+$9Mh_0Mi3fhC3qdL8 z4cpn`fm_5>>UqGWu1Q=FsC-;BU~eI98YqIewKHS1@z=PM%TlQvWV04}r8gEFo!JF$ z5f3_tI(5tBI~^})v9;?CcBzapD0t+v;7Q@(>HrE0mot`4w@Y`(p&4zP)8 zy}d4)xDR#aQJqcO>v3a~JEk<>Ue9X32m~3z?RDie9jcWJx7Twja-=P`*WHHawAfx> zVo|`uy-#OR+`Mp$T|TY^&QqEE`x@23A90@C{yN19$0hg6OM(Z7bRJS{WCM7eZ9BKt zeh!aOrksFSjU*PkcTVqAf8x9%P-{@+^d~G}?xrHQNsAoY59DY7rcJ}kW=Fstemi5B zQvt7|h$b9SR6#StwgGF{S9RzsJ3>3^DC?3miMA6Mi$g)vc1E0xS5cIcty24z*_Lv$ zDJR>UW6m)Yv}&0lNF6!IZj+ihgrKXlp08XpVc1QmiA%WAp_!2KlsU%;{6wzj)I4#_ z1*mx;q~=AKF5^*S;br++pjR08*_Ib}=$}UC^vQNIgVx3FChEf)(0O&FV^a7!_m#wZ zPQ3#aidU}UhezC{bBMj(G|8}IAa+Qd_#uY>I~x9<>6^r!Zk08NQLA)hv(YH0+gxYP zA$Gl5&x?KceduK3ktDW@ibu`06Nfnzuu<-8%NshlN*|M!h033p8X@|f?em-_wXoSn zf^M$WY&jd~ZqTZxlyoeEQF^ih`AIx6P6&v#1EZ~-~AD|54ZI5RR!O<1x zZa4ykO`>qxEDEQcN?@%gfjiqUxmc}`ftl{(BDc}TL*J3vyKzbgLTI&w*RBv=yMv3W zdNQlWl3D#QV4a+-=5qnr`48@%mOEEo59A){WAZAWb{kxTqOO6A!i-3Fy@lq^U<=MxAjjL7H-im6dZd zkoLsi8_-Jt(mvclzAG**LD~zCfcR!0t*TvrKraDE%eYJ+ZCgu__MklOMl+C>6D--( z3Z!LB$cB)aEkT-Stnt}qAWa-#GYuq0Eqx%>VLO19a$15k-{AWZ|1meCxf?FL8#?h+9WxzVP^FeJp&FQD2GAxJ9(NZXA8#+$qQ z4bM@+M9=#`?yyMUhiBq90BJfQNDFBR(v(51L0WMStq`QCY|CXP{|iVf{Pts1?nae218M4cRSddonuOO|ql0cfL0HmqiT7b0D{|cmq$Za@h9^I;YFtlsQ(9Zg8 zu&zLjzsX_0hTjG|S=R@az{Jn{ZIIJ!7UuYG{aaw2{Pb&D#xAH;7@RFqS#ONJWH=-i znS|>7__M*4Nd?A*jcrJ;T8yhlmAK(R*b@AC{2O>~T-dx6f1bGKXa<8Ad0FJGXDoc1 z2>L~DCOhNv1a|>40oG%(@ojOr*b3{Frp0r3r`;H}$K`k47I)wd`>OQdB!&efp!}5g z8YZ05XR(j4TsTt13xLrRpQxXtr$H%2gK#xo3JtWH!uK{RH~mx9I^ zf*lW36QcB7Wi2o&jw7R@jo>0g0}>Nl1QPREY%jP7rwE_Lj)IHh=wc_q#SL^(F1T2Z zi&5*z=$JQ|jEtFlRRLs;lJ{sz@%0UK9mk-_%}wS(rz-l4U&;}TmBL`fkWN;~@n~Z>{l=(&E9j$#AxB{4 zEGSJy=pF)TdU&tuP1hdG zm7IpBp6nJts6q1$B^QcM2s1U=01JPu1-#>IfVwxbSl$L`VyfgGeNfQ}kFKCc*Ye&M zrQ`fevG>DyU&w<$?u4zdmWID(-oj0WzqsIwov?|<8UA!bhc1V?s3oRUPF@Y%qae6N zhCfzISGF051&Yg_9^!{^LjjZwgty279lpUs;}OTQsciX0AFvI|B|NPno)9UgU1S) zIlIMiW97%^z;;S+68I%=<{UKK>ept`HFcrn#vMUc;veccGpELGAS%2=6-LT=M+RN+ zy`l_Qq(DT;np*0|9rvvR5_5o&$=|^&iZ?P%xd;5RpYThDvWcK_rlc~!Jn9IdF_OfnanxFo-g0e=~&{}8-Q&Zf^E&?YRWFh?Z9B>t%7A8)O+Zh#C9@^ z;l!3re+#i2No=i>_$gV}*5#gY58=Vd+t!v3kB5h{oNQ}iQ$J_!Gs^syU|Uy$HJ|u4 z+1A3XB4t2MO@MP^ibE*|I{6ltIGe?DmFGGm4E!b|>PBC_VhP`Zf5wo_v96KA+f8Zt0xm<9<*L@oV;k`-tRU+O=Sh;X@ zO7DUah5?g(zGPECVG*byQ<^+3R*PW$J?GE$gD7eU`k{gPA+@N9N%Po7Oe*laU@@I0 z&>Y*>WMFhPFX|Ecv4q=;Xy!cunri+Po&W>v2Mi%5`$|vbTSz=`jsnJ*C-uae;6C#& zw~;Gw6DcO!>~)p(a0TZ5h!xW%dCf>Mm)PBub;wB03W|GhRdV(R_VATqV!l>26+N?r zdS;)NH%TgobV8VSZ_b2J(r$4xh{}Hc`r8w^KGDxufj)UaJ|P7~M6=0EeYdNiP#VZm z^tT<5T$>y&ZkQL*6c_0VlGP+P$ufi*CT{bzV;7A>7&UDUM(F%rnRMVLnRyf73!$5b zH{aY3Lljk#aly8Nqi7MQc1Rfz7(x~Fw=n9{e%?%e7i;!Dv@MU^7^;f;@IBsbrzIw1 zo@OPp%CCK{q<{wPVfi&blSujVb0(p3`;ZWkJz5yQ2*fq)D~KPwG*ljN;Y^_aC8-nl zT69| zk+}@sY57YtH;o5Xer~yvlhnjQ`7^)UN(9~yQ92g6757sOLr=q5BI|sk(G5(=Gt`bH zZo|A&hlLz+MSYmYh$gdpO0p~T-B0O9cn z!yw0G7ow65V*6alU24K@dFHTtBfkAr4bs0^}nVYRlu;lQb9kC-E z;513JnUXrS_I8L#pfJ!~_hhutH(&+1#G zo%QbwJ*<||!$#uj1t;>5$WLg`GJdcP^ND&$Hy8*t*N_TyVKeJ)&iaRc!)R~==!X+d zi%{%Kn`w-Liy@=)xGm_Rs?SI%O8mHn3#V`45q_(Sz{1{*wXn0k5?Q3S99~Uo%gw{& zD{){m<#S&ebBlgBqGOyWCJ-Ozh3cQ^hgdkDN9gclO+&!9I04Jd@XEsl7$c#}ShgV^`ar=$iWORAvcKfBC{YVi~j=onN| zZYVKontO|(RDzCjhK}28C?P*rL+Ps)**?c1>3sUNOD%6ZhTHluc6hG*@@b*X-vZo> zO2^OQbY*)>1&Kzj+9TxC?6y2&6h?mD3Mc2a@J$L1_iv&9cdB4c`Rs>g_-@GF&-!}! zAVW^^v;Z(2lFeU<-=KEI9fjJJb1lh8mb8s>b!cKO%ze$)tbo?$dz2KkN=D(zg>yal zV#Y?3^F_`a#<_kl<6N&L9qQE&bact%f7g2+dA z%>rXi){j4ZS>vD!zd83+noVi$q#sehhOEeQ?+ED zr!5=9?s(9RqBv+#rTqDABQ>1^Q9BfI3%+8?Dd-X@q)T-DnA2J`Jye(oa)H-+aCPc}tuEqDZ%!roaQKi~pm18{yLc2^g?c00tZn#nOvtjbr*;@9scXf66+fCQnD8pV=bu+0MF1uFDFBL)B7~sx_=DuU^J_ zYLeZXA{h!hthYYzrlS~$_43QxctZZ#u-A~Up5zM|&&d}u@quDL)CJCyCS>A6-yAx} z(@oI~>nh0!!IgkaQ#3=}WmO@xA}*RiGtu7&T29Mo2G#Vg08yEw8Z<>S)K47t0q*3Y z8Fb+`*3g&OBoJIQgYwfIP>ym)Ickb#5YMyv4R`X<4B5j+HDD70-PbId;qTkz^1!G? z&7v7HrWEI+PUy?5nv^j`a)vpFoXw&cnt7o&k7g*iqJ0Zd5p}>S7X@~g)X5d zY$k_Vwc*~7wvMFD)HNyLxri5rJzK*MjfFN*zOZ-CCDb@up2aDI&AiZ&{)mUnwO1mf z-TMdT@Jf6E;)Ophv)xY$VYXT$?U1eJ5NB4^B)q6u^r9B(c`JfC)@q>^W+kZ223SwsTF%C)++ zMLKIQiAR4_yEwvBcbQeG44;QG?KBiL1;uJAOyha6a_oB{6NHk_1d%hbS9|m+uVe>C zKx54cF(oTmj!_tLIV@9|>Rw6CZVHB{x{7{`GfnMO5iA*dSJZ^sXD2{{Mv{Pawqc$0 z-uOco{gl&1bAnCKpQI3zr%kAy|3uJH!vk6 zlJK9&pD8@YjHTn~N8I-jvEXlXp`LWYYes!}%{h96)m%kj z247C}1eV=%btp52P8mV=m z8PDP*-^dDNU&sZG8(;X62~LveGTTOKuF#z3?Sxl-jJFK$@1f0T`$}76rBSzKIIl3j%x(J1`cJR{8Ut(Y= zY%wtIX?QJJx!0<4IVQPcZlLV7aacj;5IZ}Wv}acAmjk*$>@*U4a2_vq?R?{9opU9R zII(Nz`+v@yL+l6gJYMWmHz2mA3&gIsuII%r`&}lz4_kIyPLkXasp3yD0aE3K$s$~5 zD|f9#<+^lKo;;+CuUxetxntp3H}O!e+iD#bN$X+#Y75C-+ZGMsB+tysgyb_x^5+qt zOE937pRE=sTyiBZxyI`kNSyR%-E>3ggC1)3B1G*}>ZZgvDqE9)ar8>nPI=EPnjc5! z_ngBx>IRuKZQt-RZ467aY=DaL)d%2R{ z{rPkMO6DA5_jKLO^|`ou+b`WW3_krKpryStZGoQ}jg zT+gY?Y!>NJo)_`uc)C{kn|y-swF0gb*~$|rrIAXxl`@(y#oNAHFQaa`lEGXnRG&m0 zdN0yFFo&dA-!5r1Cq?ED1(4zvNipP7lN7&N_4C8?KHT&APbBL$o^w+N)8Gj<8_Lz6 zv+^$L&kjWWjgm=x{ZU?ev+rTOK5cma>M9%0uTfqkN%5hpiX*)cKR9?EYeZkXLb=@x z{uO3j&*-wyJylxkVxd-Zb0hWEVdo4O7j|AAiT1)YT zC!vePJh4-@d4;hZi;ig%6UqD@9^)Dwy$HONnto%{C+j+3M^gr~uZ=Sott?vQ_5qlZ zMxoQS!>96{zROFs&OK0W&vn|ak5bC$972oj{ias9x^r(WVwF1z&{F|rCe!!=4t(TZ zTv}#Z!M!!*{ckLl?%0!|xAuvvo_lM$f3BAy==lQ(9bvzW7rOKhRaTjscskeXrJdJn z=$r=rG_z$~<)S3@=TXYH|U%spKJCb zSGjMp_!=r-OqK75Oc8I7Vta{6D>nrFAe-pZM1$s8}|u%)!YK|dx%%k#%9$C)?=XZ4HP z6#>WH!izoI`=z}03_-)txWpeV{IIm&vc8X{{kYBvTH4RtBvBu!8VvRS5lu3&4I>Qj zt5aubAxOfU*pZ>X5vd=XD>+5&I^o*D^_H@S^aSK;B)L={Er(n*hV^Ts5S*QQ*07JW z-2oE)Nl(_grkssSwt>=hQ`uR!!}{z++lF7_jQ{=snyVrMd9oNU*j;j4c!V&pk!e?_3nhZwr5&iDACyFXHqLE+|| zxmja@gY0%LHto)0=xaXD=)5E|nCpqUsE7}t<^Mw~=E@x4b9T9weJ*}&dG~`rJzvHm zhwQd$pIpg9lIp%|jzLtFuAlS`jlN(SJ!e1V-}6!j%hJ~vJ!iS)qLb9W2+!Bk^N(B$ zx#z_LrE(99z7RZL^7&`{^Pf6cUU&mqEADw+XK9M@hq;oc^!yXoQto->AnE&fK8v2$ zf0V#KuW_*atr02+_q@1=w8S6Jzo6%zxmIw`s|QPa;rSeT{+ekb|NIUI%gVPHRowIc zgS|J8i?VG0hOgNN7+_c(HoX8BM9fi9al;uF#eEdqw*e7Xa5OZ_)LBqbQP9Ynl58|n zEGG_@H(rBZJcR7}co$t5)@F*VKJ?|EHw4NKiUzvp@1&*%B$z4??gb6w}Y9OrQ! z%lA-s77N!P|8tW6X|RpUZyg}+hWwQz|J~C~`9D#}>R*8?!jK+l0v(c-fS?H+6%6%zSq-Q1!B|SlkoFE! zn7b)ThKS^I5qT`ExDc} zq%TvHlGozEND`r8LAEZWle45htca^I0J{=42;0O4F7HukN4%oM-Tg|=15to9%pB6i z(J=E&y>lhJrW~EStJoC-faqdUYEgDE&fl)+rc*@Bjw+p?i+POlTU$ss&qB9F3+@qe z(~Z1*%~+D~3;A_oZ1+kRdr!VE6txM|7&besjJ19ZXn}1AS|AVY;b=kqOl9_Fs6Wqe zLt`B%u2jHp$(n}%74b2cif~M9#YS5KF2Z0M`{L#KnYvK>)20WG`PXbQ zbkDjF!OP_aK0j3_-x)dF5{mV+BeZ-|QhZ`REL_s0#)JYii9L*$_*GuOVLTk&wS^9H zbhkLw5_8u1tf!vCZR8CXcfC~%jF30|9TS)oJj;W6FYg4roDy7#%#8dG&|P#cLGr)kIjgkpj0#Ga$R}1j^sgu^LTg8ChGJYBUj1qu$uU zs-fH&l@1>ZYayFRJ!Bi843x+j=+Vhom{jcr#=Xj``?so2OFZ-G(YV`@cW;U_L$+HB zQI+rm`gMRXk5`T!V;ykM<1buj-gX%N_YrpT!FViP+=&|x-W+jkbXOSOaN|L}&)YE? z+|YJ1o*zix7<)R}!ALF_cj*PAh$IZnt##WS^lZ0ZV~u4rWQSZC+V>5EJpPCrUQC*} zxs!1_*YWDvR3bV)0v#88*ocm^M-t{rudj_Te8|n|zBE<);9kRRB)_#&44$(CcWbr= z`K=?7zv4r^Ljgf;vsUNTMyQLp{FXFTREJ)}<>AQxOD8^_v-zLRHX}c8H1eB@avb?J z+pP4C+6dlWE`M#BYDDK=!&M}IaCo^Rf6g2rOS-Bjl8M_};K*;@X4M_2jj%fM-%V3# zyY?EcBl*XK8}Xd2-#kY(9{rz0`v1u;NB)LwR*Mz=|AedmWtu9pd#~YUlHU*>gXe7i zk~w+EZ=Fo~|4FeUziPWxE{^OqygHma4{rI%Q#6TS2noIUG^ChQ!x(bJ8JI!~iTY6X z&8LnjyyquN(<6N`H-&lId5gq-dkv?$pb6K*%Rf!Z6LVWj{}FE*meNwPln#V+;g*th zxmXZTjIFef>&CI!suDH$)H4L$s&YNug(E}*{ci!By?6pPld4OpV>4CHrX?u|tx)%k zowMMr5WlX3`J-V!S>gMorHgmyiZKC`HevwwKNzMx`u-RhkCZn_%$Tva|S_#qOQWv>~!bW&U(+QR<AjDbionaT85#(&+xBf~5po$8A(9|{3p0eHxs7GHf$drueP;m!Bcx^KJIE{saCKw) z31`#jn_E96SRnWC_3CZ%G{Xdc$;WFI`c~%i*N^2YzA)Nqe)|O0AS0Q$t4q>LknJf6 zs*{;MWhLxGjuPO&X?G5Xx9;fm=#)>KVQW&hIUETdF9pO!<_gBm@=3-$>CjuDZ^Q5< z{XH0VTt5DBM{l@{mKX%mm6k!6KBj1t}cxs(2z6sQmu+geUY&6p}lb{SG?r4n3g1f zgnUG-R7!hwf%(eSHN<>n#S{Yd(L3WXd2#e>V!m?c6ccl3+aB^Fy#|AD^~-e(21uth zt%mptCU9=+M(nyDy||KhrTL0L)fZ}_QwR>@m;9EKdpSDgrtd7QH~&S#$g3wvu9c@-6$?Dg$A^}o*yBLJ0sKhDZ|LB) zL-B24g7bLmn|Sg4w(EFtc;yCO>#-m8qUvjLYI9ti#E~zgxi6}wiM6d-)%Q{8p01h} z%Q|}T&9wn}EgqFV3_bR9=xx**IqKRfRZDA`FLMhQqGQeNY|3&Qj^)P|ArmS1=*=|u6 zK9`S)jUbC?*5BCioMS0eh%H!w&gVRdKlWKb>qG-%boZVZRnT& z8_Ky~`U@)9f7Le|`(Wj^9jKsR#u_PJ9O)W@7x%^M@Z!Xj4&=qcZRExA_yBryZatt= z$EWM#q8Go&o^dVhFY6x^Y*wo^_0IkmRU<(cTCV9g{9Dtkd&hDMKex18Ac#_ZLm&mQ@r)F%`vukFsk`py0D-z|*0o@akgUfw^Ii{_m^W7z6Cd0$xX{1Jf&< zvPM6BGRaQ3?Y4sN*<{ z@lbmFc|*Yk8{)*R99_?5WgMb*Vt1b^BK>4$JDYxQ)+3+Y$z zB6Dd;JC9>;&&X-9{O)0L+qCNNlAlB0eiYhgQHvIWOhqrg!K)^-6S2w5hmdB(sc(kK zmbeinmQ#~kQ<*Hx(KR-o-Dx-c85EZJMr}x?6}cWGRen`Z6C#K z#;`>hFW%pJWNWZMzHhxi`aBAq-A>c8%Rh&%easyj zb}+GSn^ zzxtsOU0NamllUKY-*pb z$ZtJJiJ!fw`mxpW``|@OyH$cw_L~EvY$WXGU^8mFsD;Vtm;Mp3<{WQ={b<>|V%Ur( zxx(sH`H2-asV^fZ!iIIb7>Ug_B=JI$ST)il8n2}I$=+2v+Ml@*SsfZ*o3xvxuOjI! zo5U`gb|89*R%%@7M&zD zey$*W_aPH7D`6SpJ*lJcO>Q0LXy7W8WP-Ch=&gjBjaNj1nZx{TG#v(HR_j>K{$?NWj`R< z*o6j>l-FG$S{&wdFv_(t!K2q^@*G^7oO%P>eCav9NNg{H#d`dU-CL=Rxs5 zM;6;9WYH36yNRX9QuH0N#E~p3h)pvIiK#CnT{ug+kcMOm;{N)K;EFF^VDOQ_c9c zX1_U`7Q3#1Qsk-j!U(c|GUUmdd^lpAVP zquY;6ng!{XYs~KVF7r2qBJ#)P>5EzMqC@_SRe*5wGKS%$ED(a?fZ_Qn@SeCE606i6uydgP+< zY!q>C<`2zZJ-1M@9)r2h<=gM(SMJJZ zuX(dj2pGdpioX4(xOB2y9EX-@0~2;euq~N-XUxUf(8WlP;Ou zAml5(PpQc?jhsZ*>KSXTg{A>El{0_JpwC$&>9PbRzeIv3=?Eh>z}}K|wFpc) zaSqhZk?JMB@3qOY3Ecv^hsotaQ%0fD)1=yDEA}w)d1~d|Go6cq^Ct24d}IE|C2B#X zS*4Q8`YmyO5D=%GE6*}t(&jl!_nPxUxpzmGInR@(ZngLo*_4Clq&aOjIG=O%w!YUv zFAk3ysLJYO($O-E+o7rSvqOjite@*EG<0txeU)AWcc9BHLR6-Hj@Xae78`7rDaPl* z8TnkK2%SqD;a3qy_%A2t1rAnICi&f3qfnZaLc4|TAybNk_)DfKDg&SYov;6Q^F`Kr z^ZBycVX=f`35ddTVEkhGz7ZK`v=frMLpQ;qf#kx%Okcn}h{byS!~DrjW2m*Q=8wXr z!Cx{>QKRtbB^=J|H#dB5z269|X9$A}u*|PoZOtG4++BWJN(Ev6xib3@?jJ9k?ZHR8 zum_2?h+_|a-)~^k9=tDf1iggr3uY~G9|xQDA>2bfsGef_<$Gqet{upvqAz zsm_2Up-7(?(A!CJ;qA$KS=`OIoK_)X_L2}W0SgW4skt|6E={JNl|1y7u5@L`RD}F+ zIZ_x8P1aCWnh9Yl#xwGlOjFc!e0m8-GUF$~_m<-qMDPt>#ftcRMBZ@Fsk@OFBVqyV z7^9%R!A)ayE8jt8gS%7&jqRhR?6GEQY2Y^!8+bH z$CY(SJ;PaGM302Lnh2x`egTo@n9J_tI< z#sOJc%)7oqpu1)rUp#jOTk`OYR)qyP*J4!bm&aoqEehdifzW9h>pEU=1QS2e9pk`_ z`C5wHCbF|4ozYhpT9wc30C0xLsyHGmeP^SuG}@J+MPw`qEy|o{ zsE8QmNtaWBammc+=iwwF6pkH#AFMqu;XDRIqYa+59Dg>_Nck)~hK4#Yl=@r0BDgGJ zof{ed89Koz2*Nu5^UpzK_F~lJ-xRWAxO3+5EzBQLtUKoK`~D-F=I>FQV~zi1BI24R z;xbk}HxbDQc?UIxJZ8)UxX7+mA+8-1EE)0(Io@JPv!d{+*yg1AY@YhYnh->vpie`OdfYfCL`7cxV)3HP@rI z1rwo+!h^sLI>d~|gUqGWxUgdmY>wiKol)CfQd_ftbD_QDZveZ=;lmLc@4bxt#%Lhm zoX%=(8v$~&kBcz7%lZfAAzNTZsWksV>A$rid?+N<0V1!_pEhS~5#3MnXR($lS>tzO%FiY9H&kynMe{X`0h1 zn|xFF-Znq;Hyxj5tD711317Wlddp}UxGbv9xlTP_`Wt`K2OSbj2G3Cu^2UU?V3{zo ziVl>LV_H1jtjg~h9M|*t;d^alz#eK59qR}oeN$Gd`H@?kBC#jQ_B=#YAc?zl%j zARtZcDj8vkKJt%~`lKoqqtx(I5upcMCyU~vWX-d$5F-b4bEZec2whv4lkF?lReXdkHj)5mhyW--iNkt<)~HErqYA$$&MC?;>|HEq zofr$60MDoiv`gk2F=q+fG;>aUKNQMvvTH+CZ369Fw7wbaB0ZAJwPNg9PmMj-lXXym zBUe`Y5s=IUsilKu&O)J;l4T5YTNA+VCpAr4^Nln6(7dg4##WoZox$FHbdqtt*%%sb z_0N$0L>J7eF1G!wpSd<*4f+0swk-IHK#6ZhwPic11j$A~vlHf%ZtU@oken7m5f6&DgH)BEL9^Q|C_anNJps(;$>6r^@3qC&-hiq!yoME& zp^{*Q1moL6r9}Nf1>(FXzk| z3Y)NHmTM*Id$GY-+x?|5G}fWHCYu{t)%LDhS+6w{e?j7`_O2xsu0mFoq|+_lYDsMn z1I%4f?hd!tc3Cl+U@Lr8RPI2xAmMfD>DpMBI!Vl+@sZt#R~1lP)as{l0?yIl0Iakl zcww?kU3ojRp`BTcy|}AfCpd}hg3j47n1d(Cu5>O!_M|=}`z}3ks3qG=&s+69oGBDKA||ScQZ_ugB?7(?l{{N1G1NwfrNgns$}4t>r#-auN3=7*&hwlxu@Yx~ zIE%v4X_qnZuJyf>*716;w<@e#HqDY(& z8*`@+K!;xGR|ud)=@Wlkh@ObD7G6F)lt4(_c_WpRP}>!o&$EyF%T^x0y3^2olAEn& z&0Ot>t%7Hzi~JSu{79WhH;LD`r(C+fG|ZT?(-7ihiIJ_cUA;8NoWL)ZSG0TfG@h{u z;59)2?w-NDYy7&kUT_Et{Xvg}B55h%axWg?F5c(c$X8Cgx<@GCREa|XOng=qcd}jI zq@FL|dE7@R_qfpL<#V1uDx8|HeAav6Z=$SKm!EAW=j)$6dflD?ee`|p{SZJ}#bW!Y7y0sqzLD13#R-TaG-8=k{Ue2c~edRIK zJ_3t;f!y$~m(s!FkK~&O4zho*zW*+d)U+Q5F_Cx-J!I#^HiIX)6quE6DuuNDm>=fx zf`$4~U6p*jB8x#T2#4H8<%Z)@MNp5;_#%hE#{!X#1+8Iu^9VlS%?~i=XCR*8{L}~< zyVrV1E!ITI=<}>3hq+GX6I{F4UnJhIR36?jegyLkSa$G+XV`L;1h{vr2q3^+J0L(d zUu*6R^ljYQik=A9QWP*QJ*vD+xP-{-?guC?KpIP5p%S}F`yFej`)DgSh!B18I;mZkZHcVQoz z&h06~v1^xKefU$F(`N?fO+~`_Wx)+5?N)V#Q6SnBl%X)<`?tpHYj$~C@a;an#dv2; z5@)!qmsL?|PGtt?Q9hB$+d9xDRz9Y`^@(4GG=b-u=69@H+1d*Gfxls|xt7kGWbAHl z(%sK?U0*}_RBN~LWaV_2vrpA0w6K$|&d!cR`6LAj|9YUR7k#d3<482fXJUjbM5&a< zXx>sLh_aG2RSrLz4I|FemW=apn%vr~~2!I~;g#U#%wRYBVO6#Rz8a zdbUIuHa8#?DDGH+-XU<}Q-zR3VErhJHHH zeR+jAYp7753j^*!t?ANGHo&n3pv=T6JVPll<0=DKHqD}HY0%rt$*_poPpkcRaAw0p z^sU$pSHPeKwI-fd;d)XD3weZTe&U|R?az89=0Hz!>NSscH-ASQB2wSwZ2k=MKh8qf z0UXGpAvQyvXG*G`7_&Mg4wn*{%@om+WxuJ!5ZWjknqgOJXF6@>8!pY1Z%)cz&!0}AgqDh6Rbg4c8sR#t-J80^d)q} zd8a)8LR-zSp^(K)QIC2+4oliu5VY3Qxoz08RXAxqcGwKYrx?`FmG#uaN?iY(_sVN0 z#=*3<1lYbh=a0!{@HG#(=<>sbm?tP!nv=*ZO#?0fuqpYeP(Bj^uN5YY{E-oG+Jc{EQ_BJsZ()G7hwl+2Y!5;fx#Wb z>e^{S#&);Oavs?Jl5*AQ>&3>1ea48Qa+PEV_cg+RV9bGS!{z%itZx>nO3In=%zaPPgPA-UWMbSqx085`?+}|3CsTM`pZ_uXjkGh?0!2z)CzCJRMJ0bz|&H5c!~u zU~I5}QA%Gz^2~+RkZfOF1or+mh| zyE-2EzfAX)$$u;;W1id@oWZ0%`W}<|RYsBcRBB)S=&Sn~zU1 zAc;=@m}5;AZu^L1O)d{nGreeHF{Wj4ck&H0>qys<9_@*3#RC!a-GVf=eBviuIoFoL zh;YfYZGgJE_GHEA^agtbuvKwN`H$YKZfVXoj6Qcwu8K>j8J!e&6Rb>-Q*((ku!qQf z9O7n%-}r5TS|Bv3yw@7qZ`d*$k-6mAJE9JXgtU;sTJnchk=odm&Y~tB5M{}!-w0~b z?*v+RdDnz&YnVP%uA9&6G_qC@KjLQ6B_FC>6}Wmc{f%+fcQd-`jO9^zLb*6*lH&Me zjVy+$y~q#a8BdMz9ZVE|^23vB-PKz2%76;0NwS|2|Z=kpSeKiC2VOVE&S6it31OUcxv=ri(5&yY~-R(%{{6 zXTReOGC|;4h#Zi$Cbtb)I|jT1X9~SCEAUi=J=_#)s7rKOoNV;aB)PPYyoI$zM~I&z zk%Sp%(kSR088HPb@nrd`sd(-s9L=PIXdEy$e}w zch#+Yi!&%fAgC9u;4T*G{y|_Qh?A=Tf=WA@#;8SSJiWPNW6w++Q^e6y#%E>rF+9=7 z)qD^ucML=zgpjlviZ2N2heFg3gh!ZPccO=F;cDPT4+SwLVv7-8Tn$y5+MgvrJnX^r zMNi053D%A>rsr@uk>PuykG_+Yao~S}&i1}Yt=%h%y>(eUSqb~sgP;ai@35oDMOm3Z zGlRRjT-p zD5tUI;pD_19G6nq%Nv+LNMh6rXB{I0=!^lB!*#DZ0t2W~sC!`mp?Ey4iVyW(zggWK zj_JbdD%3$LLg$-}$1XA+zfB_Ju^p5T3y17eMt6a3F#Sv=yYBE2&onGjBXyHzH%<0o5?k9(e~d9~>HqP%d#+mzL8$(y?lX zFlWh+;5>K?SlCVUi$xfT80ReL%;?4Kzh^nek>6;`dFeT>BTXSYsCbUKQ;%?!LKJ~O z-)zqA@$L5lAs7-tVf#34*jBB^4aq#lY&VZ{5o!XK^^W~o z1N8Ip-7p1Tkvr5S1gvAJ+_P%sSx+rjMs%F!ps)@@H?e%&`#$5`4EbHPc4dv$GWkcM zNFj7TbuvfymE$WAv1S3++L`fsH4u!@+PIE$xYi!{n1nd%PbT@YqqXzf!|}?{Cq>KX zV~PjQ1hV%s43E`m+D3yQUaU@{)FPp`hssBtq~wsZ)e9Y=G#U&LoUxg8Ll)-&z&W|; z_v>cnH+uge<+vT(XbNPlfSS~v*)rsfY9vl=`n}A?{6^*!I!Ep2TI&%fa!<0UM4}`Q zl^PRc0L^e= z&!!l6wZpg=cUpN$Kjdy8xl0U9QN+vMjf2yh1H#(2H3aD4+?9wsung<|fUHeUxO_^o z$N|^Vxb<~}m2?OxtEadRCIM8iWSU2{>oOSd4Tec%fi?@6LQn5;0GS>5)%5h9h{Q}T zPA#L_<$MfWu3}O%8`X|osg{~Z2E5upM1gs{0b}5R$c41iXqzG_Z#+)IITKVb7QLFy zU6K0lLJoOD+&VFz%Rr1-NpdG~o5G%{rgj?NZiQ=H}BS$;NR>4=12roUHjbZlX)n_I* zebx`GfaIbHB$oezy|`|1F6o;*VtYeW$`ONpz$b~TPbW`N& zi&w@N=r#e2(ZIQgiY{k3J1A{l6EKz~Cc;6Vb>l6dQ#3HDSSqVatAr_rBg!}g?WEUL z{9B;yf2iByfazKX*|f!f2Ol$As&yeVYhB35hP5t&phAs{tgG^%vR)bT6K=Bg{BBSx zeZeEJReF~r@}zb7QndM?@CXqslxbrQE9-S3j}TM+Zux7%4*9r$J}6h648CjJ?7|4h zqW~>e^){pu#~WPiGn%>YR7izuSI?+}vJ-f?SAtBy+LCm7#=qygbdaVgd721+P=txj-9Up%(Q(8NEQONM@fLMJlu@-hdjltZ8@M!k%S$-r2+62 ztJ`|A2JX_g?Zr$x{v#AEeaYhKEB#S37|`1f?nzEhCo&r-#SY5u6a(pqmD*dYeu09( zM!>A}ZuhvQ_;le+U@F+*va0_Gcn|V-0{LV%<+*hVVWu)(6jEN}`e1BFkdglTOx!lL zES?47K_Nk7Dvs@M=OwCl)B`-G%#Qbz8FsoOLD>RfA*fpWr4U9+)nlyd8!}^$mkni5 zwdmTyDkoH0$sNaMj{_n~-U-c=yZ_2Hv(xKKFhX_^s_L%`Kfu}`&8#%h?h`_pW?C~% zU)7^bmHR82+ab^DKfn;aNrsR?&|-(M-yLA)06vScO!(m39{}Ltu(PoL9<0Lvxj#ps z(_0m)lhz^6ETf;|S`j`k2J56ETG&DF^l#4xcolrtv>B4F-M@z7NFCG&=kjVGcy zi{OAeSw8AvTr>6s?L{+olVdP%QS8#XX4?p=T$L zT2-e3!*!mJ0ckxGhSKc;6nYLYSA>dW4bQ4X>rhupwL5+9|wlgq}#hFFBR>o&+eufZKb9C52K)i(N*~`A@|UMAos5W2 z*2cCpDqHgKvY3(FTp}8tBPZe7IPKaY_~2!ZoI+e1d##FPeTf{{*>C_DOx|w{*p2r~ z7MdkhSYOHAC9JgPNquD21Ycrrz2@;b}coMwYnpr zKGSRN3~v7PVrg!{O4N2}MPh3wqMH*Ln>=WxMqhn0u(jtQTGA12lO-#z68!Vt8Ec6H z=aKHjf%E=t;@Nq8+kWz5Arr*?_y_)ULa4|2{HGIG-4b_`C8*nPvVIGk+Y^b-#9fOo zi=`J5;vyfGk|6LaYSKuc_h~ci2;+7?!7FIWZ8W9vgKed#u}{W%;%fP{?-EW$g-G#^ zVDQhgJ%|hDeTlD_dp&5=xR*w+UU7{s`nUw*|MHM&4~I=pVD3}uL|HQ~AN{y3lwX$* zkw&D?*x171hBRX7DiZ4Trud9kUsLY0UfW-kBnL3PF1hwbNQ*`Ddkbrmi)|Fl^`5gS zd!D_wt^CBat4pVu<(F?h!}@JJ&>Fuydvlwx$|q5%^~8$y4v&)!GhbE`A=mS&p_91o z+5K0eQ{I&@vAnK`oEHt;gpD%e+Spe9D+m{xTc2W4yc~34L|iTP?VnNOe+eB)Y}sD& zX^$;fRH--Yy^GxRq6ET3GN+lZ*T8!yHz6UB7}CHubZyHiK7uK}&a|H!Qi8w!9vc9F zwYRw#o$^dP1Z(?7%^RJXja2@ND!-A<>~fxGH;a+ArmrJL7O`1$$yQPGGN2uV{H`nU zWLO1ho$5qaB^9iYzCA@Fw9qA2i5rQw9u_lt;(L%?N}$b+dtrY=*W#gNwu>x9&bOTN z4TQ;&N#`oaY$to$YdWy?>o>W+e@2cX-E-$t2e7$>C@)(H!@=|L;vgZ~KA$I@8 z4<46g$SMK44!eJw>`Y6$|68=%U9H;#GXl!HHPY@XqWGL!YLTIvq@1`(Fu|ug@9g+# z`DY`pmYsWHp=Fg?=%V)f%K7r_i$<9=Yc8E*>_<}WUfecp%RZNVzM8Q&+kf*q?Go!P zmdp6w_iFjW-trX>%fGAgU#ncNri?^IY8eJ0BgKHbyR~`TtDG=Y-o;aDV zGxb*)Q9n_=A}!=C1R@=>oh__oWH$TK?94Q-JSD1HZ8vR&_Ncx$1flbw#K_zRk)jOd zY+yUC8!W{5!vcf)IFEeG$p|%Yikz~#v%-va-d_{gqWDEPoIeE9O{<-f9KN$4aS?c$Giq>YO{m`VXwxmEyu{rY26)r}s$wGePqouUJ0HfI zcm*S;C-NO#GIzJI4s`t&qQs_h4ExhK?d!1JfS>_Jkz6Ndl|r$rW}?I%f}nwA?O#Uf zmO`Y%b+W&C{ks8s(D?&-_KM*1G)q^f1*rv znyJlpFwN6SffwK4j8V{t)4+{#*H{t}4dMaz=C2*S6mpY7a{bOXXd#JvTgRC;tXTizk8ICGkO|At<2M5G3qSB=GhC4k)0vS$l;l-+p-g6KN;G*tV}4_2iFD z1~m0e?xvVG{wLiow3giGD}$uVxQl{r%0o2d7?D`rg`$;P5gENA9$-Z{f^cQTD-o2- zLK*}?pd-M2!`1<)ypdGC!se*lI@MJ^04T>?W2{p>#xTDz1&s!qBM8>@CE(E2L>Dax z%W|km^b=y^f$qMOYq@?x);Q)jmZO=@;&7o2Z~g+{ew#ZcU+GLD&aqp)@uK-Ux{3R{ z{kPA7AIWdh?J(&suG_`agCRc1VbI>on*vwu?Dh6b#O!li&v!yrHwd_4pIU$<_fRL)Ol%^$P=zuhpL^)T|IT@m@QI$m{A*rQ(O+Z8uy3nEml^_3*?mP6H`^u=XF?(&5^a0FJX zAwrg^i-}McTaH0{8$p(CgxSyqYmwaYedp3=Vwej>4$h;0z1~Wj_&huPsHk$u^O9eD z<(GfF+;OUE>c1AwaABO;H{6V4nBRva!C_Z_B)B<#4ibEL^%Gn>j(+be``74eUvnzU zCNtu`D>%MuwAi9l{n90_R!+F_+7IQen61ng2^p_M6+d2v1HR&*cT0m%pmcGF2CqCu~iw-DO<>k=YO=m z797g89nNai#(UA_S2L{i^O7rVWLstg?7-JUJ^`{VtJ=A@8k3|KRq3b1rdv1Hwl3qx z=r5)~`x^+$?S!5lV+0YatdxmpiQ)I$x`T&v15dX~21e4SwQG1;?%hr|xTX&BqbjJA z96E^?)(g=$t-&W)Jlg@Vc=g1|G?h4+b|pTl@Guow=&)~ZMvgGTqf9ksUDt_x;z&s4 z2?=jlJslat2ydt-yx~&9rX3<4XmvziJfE)tOAO)e5mEfYjR7T1(uu*JZF(r2M3)q+^@C1xCE;web9I4sU@h&Ec=1iwwiXmmE87* z8V;~+>&&j1JKH|6y8d8mYopy7)NZ%wc8vzNpS0~B*xVXu_oc>8jav7&jPAc%++SHb zse-+qUGEf^?ESm7oh2lq@lGWoqh%pVvOpSs}TSrJT(KfI;fTDxo1S;$TZQf5x5N9IqGT3EM z>dV8_plYNpqIb9J+m*6a`2?vT?JxRv_#82@dZ9xlf5|jO8Sv>POk(uDWUV*p+50WI z>uvV~=z0-OaWtIaV!#r|coIjs{1YlQqhS&MiI>fKvK!9`(TI)AaB3Toi5F4BH^T|BctY};o! zr8yN|Un6iW*KRe;<=@YfPc_TO3E$WdQ|n=iQ1{gAkjhGANnz4h=Winn_*r0#i4^#( z|5y%yV-;~TYEY?NQz`h!s#b>#$nIIM5--uCADg|C{G|{lyC#Ljq^byGye`H?7@;W9 zOUIIf#vlFPJyjQ%MWSUB8OC`q&I!>suH9xIo>(Cu@uYJTT&o%Hu-y8}r*?Yw=NEBq z7iA6#^iGF13_ej;=A(2d+*YOqEHJ?Emn|vUmCz^R3&>n4tKrLMl0Uczr#xfnWHYag zuT;PyE?lttQYV*MXfas{G2Nv+rF{h_(-cMgK)i&p%vvCiac`q~NxRppWU#!a!TaGT zoh=mRk+^`wc7(I_xIGEd4!chwSz6>O3@Zle7VJsLUouTmBk;-j61MbxcC53QQc11q zNd77n;{=WL9M6A06Wj(71vmMm_3hf^KXRwiVbLlgu&ByZ$Y~LOiQ#xvDIF$yWogR{ z*)Yo7p)@XDz)N?I(ehr)(r&dVilVY&g43J|hJ}>SFe$6>G=P`eA{_UGe~##w*Xk8O zB92`Ln}#ohs||w5ZL%!datzq(%t*X7k7#FjAswEjlzrCjrq8|l=-;Y(pYW}85y#+Gm*!o8bT0CD(tQh^avC5WCR?kWBhL~3Sj^eiWvb% z`8?~uV|>6i@LdkjE?-Af8!spsEQE{3C1ZTjTnY!?CV3i@LUftGaX#0TIt{}GEGizUS zY`>z4e#HVgh`T-}ti%+S!|#kJvW5Kcve@urf&rDL0mbEDue+3$52%hEsK3ZBQ@PMx zYeQ6S+Y;6rl=EUHx(uG=e<-_9GS3Al3lF}?NU7;jXp8>5#!0CVbeEb6X;TlSATr7n zsWJ7^ntJO@ee{9}*@C_Wg8oHzL70iCUtq!y!K7`hPN$}FUlm#i`E{fliaJvE*oHi6-72NaW`)x` z^tvrkPWcm^-c49{P`hrcZr#81>)6BpLm(*Ft=P&b*U?>-*_`|cGs@BhLJPcv_RX4Y zm@H(xY$=ND1+q@Kzhs)lOmL*VYyI?poc4cFF4 zdx$SEmyf(>tZV*sZ;_W=RI;oh$z~AEzbj)PQ$_+)H%KVAzgGtbvUKs5>K9+{H+h}O z@k{<*uA+iihXPQ^SA(iAz+eh{obh2}_P7*uQbr6DDB@si8->^)d3q?<7!w92Swkhr zDKW^BA=(F{rb@{hRo&H~V4ikhvQ?g*Esidzkc{h(Z@T;|uPyZWx}~X*1@I zFnK}V|AV>fl|$utpnSL^_ecKfh%JrKBOhqIUD; zK5swn^W~S%QF}p;pwD{W``6*K-DV1VSS3BlEP9xLSJMh4VocSx*m-cRxJ#yFKE#l)X5{nThlm!??2&2{-|&(J$R{eP zte_kh|9wBu*&ZeVU{5rr#m+xiilP)Ll{SKSQfn*ZN6F{8tygr90!oUYSHd3JEsJ#l z(V0poh1Us!GNiX{@Holupy;k7X%;EnXUo!@wi?zC*5qe>W7D#-L4#jo9$0@O8A8

3)N^|vg5@QZ*Rv7?Px>Oi>Ka!c%dajpQahS-j3%x zFK_f4@;bW|tjG{8GD2#{R{kJS>3G2#B*GBbUo<9a94`!mM5pYsB!mCUM0q?nayuZ| zCYJz1K;9;ev26o;!Kvf@Bvp^o?!Xifvo}EvIL@=G_?z9&_0XKBWr~P*|1f|AF~)7@ z93|}I0_bJf`o%l#vanYtt>lYP%pft6u!%o^ovW_OZ7e^WtIqI{aB@90=fm7CVRo}j zy(+}Q{!*cGfofl@ic+|d%})HY$Inqh?fYD0vemr7K|Y+cLAmC)&`Lee?#poGD~VXc z!yjUet7*L80g0*S!M+edMFD;eOw4pi<1Etu^bS7N&O^{)M0O3MjO;-iNZG2E{m|GkjAJ!o)Sz!dTaN>hm! z7faZ;!!Ecz(z<8RZcA16C-z*$?rdE*V0u#Rj;d3YlI3if|4@;VsfcM0TTy1Simt2N zGSFa~_@>e=L*xEB&D2P{>(C^&g^ms~<&Yl9-EJW<+R!o=J5ct)p_xi4@l7NvmY_$d z=%!7KF=?s=EvMZVH?IhZ@e|J}iM=A57^9S-Bzf6M!P{09vBQp{7PYFVc`mLZHnDxm zP!!Eo#P!Hg3#o{mZnjn2{NlNetLSfsZV}U?))mQ{RrHquM>EK}YE=>1jEX$nZZSjG zykTUBX!qBdv0@v^wu(H*fB$(gq%uqT1v9@G#y_}~fL#A^)v3g!N5$@>k*KuV-M-|Q zivB2WAwfDQ7}zMIgNqNH4|h8hsE#_L9)t$u zS^u$E6DbzWV2iE($6`r44P3G6Ahh;O?7(bOKu$$enR*~&KW}d^I~r`?(bES!McCH^ zYE4l&e@$!{P{|9>tE0%jhNwzwTRmGW%7jgv$v(Vn2dkKYVuulpX-40Uo~!bD6yS*u=EK#9t)G*+3arMD3)U<5fTdm z#u1_aL{?d%MRKuZ6;uJM`yckif2PFPMo%nun%ZK^{*|V3o6nvFwKpSjWFNB~W!25Va2`R^YnA~_yG5kXDS>t+=3AF})>B0B7X`%pqY2_0qrtF}#j!Qc07 zbl5uVslJ5*g`Gqg$S%{fyWTLMiuulNRq~vr1L`R992AVaB?O942IgNTJ=xCxMEwMW z`c<;j&+w1bFP2mUg3MCCe=e3#Y&3%{mPe@Hf1)Q3Or%)QbGF!8(%PszLEpa;G>ik= zYqpOThD>D%`sV#D^^0W=6y{%0Y(o8*V*i$?vIEEt7Y+*$ReP~52{j02RRvQl^DjAr zNU{InWZ@C_C_?=Ng!*A9{$UW{Ma2?Tv;kExM0HRkA(GXES~mTslAQcg(p+1!#y2f#E1qNmjHDVUs3sOJ&=qCv`qP6U6%2|>gy%dv{{94=z(w} zG7KtITWGuU$U)b+G+5>wmB`XJK6se4c)fPEQ=Y+D)NjB{b4;*X z59Kggbj(JrpSZfq4e#PC?YFwsU4nfYjDn(2{q;2({c6AoCFEOD)EFk1DlL+(GvL;H zOD?nMV!OM*cmcQ$*}M?AYh+h;Y2bIGzn&9A<#d!<7gW3^g#!)EInbPRzHd1wJfm72 z6Om$u#-%Dm{Eg~lazykF`{m3#1@;gOcvy9m)ya9xU|k2dIB08D3n#|Gto|JIFOj}I zQ?ycCQrmmD8KUfid8YR<6LHUovAbtLL{M+?JJVx;Mwy}~i9-x-z%Fu3#nTp3;eE;< z=26h%RL1s&1nZY-L*g&;N*G3Nk4aLxPpB9$HRun3F9^mrxNCR`S|465bat-&3F=d% z&Qq!(@~&{)uR6}r(;MUByv|)U^Nf+69wYmv9}pWy2V~lR!H=NQ@T?@rUoN?Ex#L__ zVD`eRMZ87ytMp$)l{cQ11pCW3+`Q~NjVa}wQ2`kS4-e1M5rgLuzsFM|!L6E0GnA^8 z>Fb~quNgm8H(+J@RX9Y7#l$G~${x0sKI!j^q;b4hVdOgmrDYOpX`A_Y>0GvPqXZNW zvN)nIoU&$w(bToVZ^WYFSw*6I?uTR#Xlwb_9c|msVD6wq7U{BHE!77SPv=ct!=N!5 zJYNlsk#4+;68OO7VR^Ab3qH2P_Lk4%c~CK{r=TABS-V~=bG#D-dXwP)t5a6(~p;ZR@3#`%|#GoW%bHn zO8*IRuc)x~-G{C3?)<7k!DEGI`%nED!=j*d=WB%MFDH8lJvdH|tzw1_$)VTPihsR1 zmuS=CLxn3TPBu|A!Y&cZ$m>at%Z+Klds%bj}pjlk%EluHMN7~`mrWc33<&GEn;p$*V zjw9X*#|!gtD|tb5)LL;^etXvH$_zoRm@|hn_&1KFVd-Wi?I<#I6PfIGa*za#RXJYp zqC~45FR~J~O|qjXk=$-(n4#+JkmOyFb|~QG+9kC6Hp-KFd&1G%tY)Rz!&Q*u{9uy` zd|Z3z4Vn-9A$_+H^frxxE5UHp?$&A(FG|d(Bm-Ef*HTR%PRy7=_u*uS>hI|e)sE_J zZG_tf{8TrN)mEzFW4#sJ4CqE%+ZEgGy0XLp8 zM|dC)k*hO$U&xA<5m}Knqbp4XoA--`5J6GNC4OaCHn#AE$vUY(E*H&(sfnt@H;`c0 z8z!It{3X*A)yT-PXEOG>iFno|$7Xd93)P$3c;BM_NeA)&Lpioao*GqqeZa{biEBq~ zj&br<`+3K7_7?PNMp@ujZI^X>EebpoNzCsU4a{pF!;I754-vFUPNbB~L*+?4v@#EA zShlkMT|Ef;5n+)cV!PYrnUpwEB*DW#@=)z~*t-KBs+ot$_J_2%+bl^NNt0qHkE1nt z8|@~LhJc*2ZId(-Y@%8sR}X}bzN!)5JZXW?rxv$0I!37ewdy3#hYydeEZ_qH9h{RD z>ta931MJJUpk!fTo0H5EJN#~jEgEP~R*^teZ!I%hKXEs5!td-oX7`-mO~Hd@Zh$;g zx!&&FAlMPJ!y!9=%lVbc4t&(7Y}?1ir0VNyXzz2xdc)`;Ao6v&l5L!#HQ2-AvQqUywxSag|$#iUf`q<|LKlg?gAstN9?LM$&*g{VxD#pTUWr`4sQTC*?tq&Dy}#5>cmN1K|7ALI1oF2^1ngsRCPUJ*vzfU zFB!wvfn)*paxB&3?=se~h*eo2#H!R{RodEZ)lDO+o-0_DeFKTkW<&_L?gYo3N0f)IzawMp`}^IplbZG6!c*YZKw zlk5`uEmbkAfFK~v^Z;;MBpl!}9?5d@w#CkpuC--%7 zMhJ8ENL8o1jwlWgcIpr|2O2422X7Xa^%{V}cD~!Jum_iD(U)rTpbnWCckALqNx={-$g1F>Q!f zK2mA*Zv*)g->T!^crwgbDi^#`j6|||`nIL`dU&Hhf$uHF*XQvBMewVvNFx)K=U<%L6r-1Up8cLb4$DfB2mv5d zu855@h$3d<4s|dBZOvHg$9XjkiqHOKNTtfj>qg{PF>$U3n407Ly^Zh_RQr2x$$j=@ z_1m1i>+pofE}I&O74Nr@?a{*sQ7Jxm9gf0ivSd= z&N4chv_n8-F$1hnFVCyv^X0}@lXbxH5$acTRUu0jr6YEtC6|q{2yzovH=j5eE$bSl zS4-cTPYg6wli@=oUvqwPF=bUrR~y<7Ay$dBj94X-FWcbgMA%Q}Tz~T4bN&B7KmFfx zU7YscbKOoK|9h^B|9h^tI@c%4+tNj<0kyA!(H>gUzRtDW=;P)aT~=>+h4+G=_jtIo6GP zq39M8x?z%Wp?RmrfL&db4TP6@Ns_=p)W{jJJoe8he=W_aT#Hl|>Uo+|CBE%gmmoFC z$_@V?ob7dgd$w=;-#pvD`@eLyKeM0h%l~hk?Qbw=dm>}YF%;+h-`&}hXqvu_3n}!@ zJ2Vhg*~M%|x(9K*Wd&i7Y!%uzp6Wa5{jTpr5EgZG=dSNwC5Ia;6}%(H7cK`IR()szizAR6la~TF;@=r@A zlmG4^IVWa&h2(aLTh*TsvrWI&aQ51RpKrfV4%U8?S@zwo89(&?D#xR0!1nu@-#%Sd zkH#c#yh=*n_d8^MZtGKWwxr}3A;eSJOaJ97yNj;3TAg=Yv-Yv~+EGFM;+;EagtIi# zk2SL97(|F_+Xb zCQ?+)0)hq$DNv~FsNEOC5MXl%@N_+6&c^yr6#-I=xO5tVp2_U7Q4ETP!8!wfnvVKW z9Lq7`&&Jxs!$cVlHs0303wO;~8EtHpWytg$pAi?#)o|`0^0f(37+I6I zR+aPy5lmDOnF9n9ti51{rB5Yyy}jn$p}`@Z%h_aiEinY>&v9m*%6W=C~{IIeM!Zsp)TDyFz6CD6}yn2fOx$llgnwoNQw22sl~h6q}Rf-&BM7&R)1y zHZev|PXKp=(F5EKrdDK54sjpU`H}~CB$k<|$e+v|VTc6{v2-LuhPqZ5!8Qk;~d83ySo?aWp<7S4t^9D*-;mncG)+Xmv^21 zm4)}~d}QHGwrA`*lw0g9*^UZl26DPT5EcJ<#2^CKRAkH9y&J6LsnUj_)V-n!tnGIq zht~pS?`d-Dco*;lVLmZX9HZMWbf5&Gbowp3C(2a4zfQ@}pmv0eJvDp{^A}6;cr*X4 zPew2yTXm|K8C7fe6MZZi?*S}ckpe|Dsaa%9Hm$1wjts~okG0LvJp#MFgIKbrG zdZ;)|t2b1LV(W5Fa7AMO;MCwxszqLQ3F5>*M6;BOz+gDt+{R$IG+qG)1NBg8wLT*D z`Z&%^%u2PlS|7UoqVMjw8lv2+)<@z|&&_bMYwc`K_IP|6oJ{?P&B-$NZiOC2tQWr` zz>-%wCKRqF9dZ#w&m6Jf^^$Ydv1={JSMM;+$>lW-LpGox{^YZt#-}-?r~9Rd0zCPu z&4cX5ftleDo)Ro}Ks6&s*VlpOA=4oU{tvqT5I6)AI3>@ie8hHs0~N|Mx}H_?IjqWm z8~z{C^>yJpIb9D`-i!~g1?YO7EC%#DZv@isB`DlR&q3if$w4;UR)NE9zp~-Bg%cs% z#sj?%;WlCHU{LNU;H5LdxzS%N#p5+(^gHx5gMase|Ne^M{|oC!($>A&bXEKN%GrR} zi+-uT5dY=Y-Uoo#8_&D&wcC3t@Do&c)T;U9nw&f9^Vi_6+KVqwCGP)pBm>z4nJ*Kw znrxZ;2O#!24L-!zym@^JhU-J_Hcj7HwWqKeqVzSE<_{j_ea?a5m0#B%PZ*Z_{3Q&3 zzd6_E#p^LHD2^U-dsOoao!hMgfM5LR%l@Nf_1~An@N<*=&-gbvu^o{$+dJ_HREF4y$FK7R6)RQ!cCiNr*ZnpI# zmVZ@GGUeNP660U$NkVMvNwWXBo`m{KJxR3DV{jocI4YAI%%p;&D2lp=;hMgqu#gBY zML{$VYpYJe(?n1br6B@+C*fHlxEG}(0)A)Vc_MfmWh4TqE4fGn6;UQ4fVz^)MDRAs za*r0Sz`Nm>Cm_2!;o1f0_E46&$+NG8MlF6QVk|Z;L|KPeMf@vzcn`x2-4Q z@Ii0o8*A)f8bmPOHBC3yG`buX=9={-DoMT|%w@I0rXw~KtoxNI0?)r?@Z}}Z2kXd{3Wkp|m@+#;r!(LoGG* zpH8=`8Z1o+pFNUa`#M)A{+opuJVzmNyR(>k|*wy)i@eN`|#n^gk8 zA)P@S8Z?`G(+6q?kDzk6%9G8$RtJ`!`3N8W;uCVUQ|C81D{9%+7oysBHk@e@`Q5;F zsxh_gJ!%$E)Jw#)UjO~m8MrS{TXnrywgJ#R86xnKgXn^Pkk z@Gd4_LXAyspV&IjNx%gcJMg@G6~cy)9jJk*Zy4Vla@G21!W<_`T4(dQpsBnlT0My( z%OFH(2c0?D8J-q}@(ft>?TJ(pR$m`2Q3^@KF{5^U!8AP0jlJ{=QZe>i>OusjUi=nB z@Iopj0!uHwRq#scJx$%mBiup1x?vE`kWuvJ`<2v=sJ9GX(4DAC@fwGEIJo_`XBwQD zwVLN2ahRi45i|L-w1VF+liU{FMVLTEYs3A>a@DQ=(}AwX2jkqq$#w45;w*fljn_Dk zFt?`t94+crCCHFK%jIs`oyA%GoLM6IWje)Kl-hVrot(xNe4>0DAncUPk2eYJAHo7u zC_2Y~vF6~N6bh$Oo#THIvISmuCe!+$<N4Ub16#R9fkoA(&daX&XPtvumPtc z-j}=_wq8rZr!L(>EHVs^mm_BNf-F}Hvs{6f#&)?{$d_owCU;M4W3yZ(uw29czFa3sTIZAHdMR2B zJs)v6zn#!PmTS7ra((vq<+^WGeL}E4-qcI4wpo=fn0bA?r5C>s5xkO0iGbHzZx#G0 z^_~ICb+m(C>Q9!d`mf8?0?ReWx?I2X%p%Lxc*9|y&2s%Jear8+huoIAi*UiW%Qa=S zyA!)yzlY_z-MU=Avstb>yj*v(%Qdq&YsNn;*Jz^?_MTPtH3VQj08~Hu-B7ipGi2Hh zK`%GnLnV4Pa{YwYm8OF0=VM)I8oYdEUfwlng&Go;%8~PkUl*tzAWI-dYlcO=f@xTT z)-}qv1J_N}LDW~}We(MjiPl^83sc;A!WQBah_UH+#O$e!L1pga(PGcefnYy)ejwPN zaVyD($I+qud!3II9C>kj{N%ky_MmmX^y&V6hhLT7{r5Wm%eMUQbC-`&0kGK=(lv}foC;a8vhRH zH>mz+#UC}ZuR_3$KmWwY7q2GoUI$=Vojj-W$&n3K97K3myS#i?kDbBDz|J`TyxoK= z?E(QT>qFQy_*|Pbc)z1qC0a_^C;=<@XYm*n<3F`KzO<=*W2cRR9| zmpyy_irx6~R$RfChx6!jYn1<6Sibo*5zymrMOaVFZ*fuLV&1`YDb*w1UOGqbsXRt2 zbAaO%8vIPL#lU4Ny$~@fYwBf1>WfKGD_`nD1jUii!xRI|lTs-WR7UEp0*ll;Ud{K# z<4UFek-~YvA=c|1-<#`mlc)ccI10x#0w6OL0hus&sgJ{msS*bQWCnU+yTcTpJjgHt z><)1p@K(9ld4lngs7ux_N%XPA$|uQRB5>m3C!a*U$0ik+~LO%0$Q8VFu*EE8wppHXdNWT%6XJwwo!=E($j&o=hmBgwE@+Cmi} zi^z&r5CI)23?qWPXf+X-BZXawz!a^`P=$Aswiom*%fg_F-&43+-J2TS3V)+E%G^jA zG|HL>%MZdK+Aqt4id2ros;G>noTJYtlH2MmD>kUjvNh0mhn*sl{>e=RqoLnR@x7+4 zDjuo|mvgq(>V6Hx*i&TwSY`%>Kj{0QVk)KHsS0m(5wt+vr&<*hKnFSCw+QvKgJv^u zpp}ofCzCgcaKig>z^MkcLBfZ?ANLTwbdJBz865I?v>8E^T{Isu5H~U}!-t2yspXX^pYl)_NfHevjlV4Dn{h}g-&Hpv>!*(PrhT=|G>6Vu5$V4D;p zj)|wpHqp;RK2GI9;dxY2AU{1G{ZqXkk`yz*-znBX&LScNS27L-xb9er$9n;o!BMV5{qsxIB` z40JYMPF{F3z3Hdi%b;J%d#~J6^C_U`dvHfCTSe7-mayV1xE3U@KBazo;$1Kw&c6Ed z`ON`P1yl~hCqG#0=Zva=KpmxTAB7VLB)5ztJ)tO$E#N`^#Xwd8r+njIgmR@|%-uL) z#y(m7)vZSc`PxBNMVbn9C;BU>aR5DPDEIZ1B6V8b5bhbT{jC|JxlX-#^m$@~Z!e1f zjf->A)q>x!4bVimp)EJPG~Er*(wVj1H?6EQ3ZE}iD*hD@4-kg7;f5Q#^M7JVto5Gz z`l`wBeNs6Y9w__)54UvZuf)Ui`=*iMSyF#8JV>~-4I7`Gq;82PH2<>LZ;DNby8{k! zTSOuD4xoA2Fo=%dGXbg*P)}*5>IoGy6q{{Y#yLb4VJgdea%}}Po`v4=(2Giks&$Gb zP(jlm9n?3}?F5;pj495#=>P^kP)scgk{WPqU)yE@*!0UePU6%xh^JbM@(3b4SdR$( zAqrX{1+zEtdv9n%1q~a2pj1;QqwdryC6qQm9tdg@4UOU6iKZ&CLuf{A(D5vH%=AX0 z)|KWp5*)-a-fH@a%7;O<>~Jw?{VzTJ1^s~BomtqWBIvdPyfS$Olmf-8>%e?K ze!};|6r)ZiIB9Sq^a5zYzCwKIjlKSD^Hr#&!c_=y+@3j9F{*R?OJV5K+fRp5z#^DL z+=N1^XBS_^)mnn~as!+V$TmUo1F%ixD0TQT@+w46pnY8Uiowq50@hAZSB{4T zJ`*i)-GMrku8%~ocw7P|$!Z`mb%BpFf3Xyg_W|BvPAlXOwDANU=w`EagGUM$h_6gr zadqAc!0^_!3%<9{0atJ2l-+Fui!rlHw>{O*_qO^KR#)p@hEbb7uw+sB&s&zSo5oO4 zsb79ScMny+U=Iw>Pl#C?bT4`43m86P`AS@4M-NP7Sfo$|pJY3mTJ$Fj-^XbW8{VvQ zHTLzT7jSRK%$r;Lg9F7>lOneD{h=a6XnyD*6i-M`Q40{zSpc^RYnJgqiQaD?jo*H# z4QqovT2&a2C&l-*SGNrlZnCaV6G@K9EE3+euFnFz;EdhPz(mh3=Y4&ZTr)kRXrEL{ zh6f7I;Nh0;=0kXReqS9Ko+WKeh6f2R;NeuHxd6LG?vN}bbxT$FeVg#(RUCfIqpuX9 zZHZkN8g*1Gq&TIb8ahI-IkX2fal?(4T7(mF$Tcjm0>(F3Y~#^ltA#V0ywY47!oi6M z$Dm9e?_DCOlhwk;maP__Vf8`^-AiAPLl-ignrJUza|0IhV)V^aH}y^p^krmuw(8?w zTAS!(qf}xZLe~Hz>Kb4S;xy9qxYoeJ$qp~xDr2@*ZIW-$A_t~#6SS;`Z(ypHTJ=)T z%yPhqJC?TNaNVszeTN z)M}E6Y4y0Qxr<>H;zLDiysdGHI{ki7T{3W@*yKlh$EZMnW!;SanP6R`hZbp~iZj2hb{nLi#BL!^+@0~mO zx92X9K4^Zmw`1?y+a}+IDqUlBpW> z*v;_Afd+XWKj<)Gr-R`s?YcAZKjiw4b=`v&{|2yT=v3C4IRh#o%O;l*XB*Z4H=aLk zC~~d;4>8y1#k>3j%y}=iAx8 zb9*p3 zR^BPfW-F(AErwO6L>n`3C9uE8^wzNIpc6E+uiunHlR_Ze52M6mMK-H$A6nNPA!$;L zk|rnV+89C2WFx`^V33T|-t%k)O(q}@?~clcYXwV*I)ytu1GNE6l;OfjCI%8PU0|W- z@@_+bi7`XO;_0|dtk^Yhi0Ke}9l(IH21+2!4PZY9x?b+$4&8*d4;&mg+rKD?4-KN- zmt+lTU;{cR)PbWognq(rSj;_deOU&EJakzm?Yc|rC*V>tXX@IAX%=XbjD~(8UJH5f zMa=%-u&>aGmgz5fgaAq@&f~Xa>hn%2#Jw0y{~ys0fBwJehZ#se{KV*o`A9$P|C@f$ zA^nh!(EBh(Kj=U|DA{g6{7*=DmM4w|{V-jfeC$w) zSSCJqd;#)VRJ-0kd~bW}`!+cFIJgZOeFGNG!*+f!@>I=|<5N)TF98eA^;eI4`m7q& z$=-?IePl+h>kLul{@&%!#~rPTe*{cojoYEIZMS!6UE%Z0tUgCwhSA)I~uZ z9Up21*qoWc=UvVyU2x?!h zg@PAA6RQK$on@UI>}wwjr(;h(v5EYQ2J=s@wHrYM+9v))K&`VINd(iIgb;ydo!uxR znAIeL2#o9OMiarjCIinH?YiQxG&PCa@GyC&3&XQ*H|+usa&6B*BSUF}KKgn?gkvzg za+t8)fqp@I4%tDILz8Ye(C)A`B)(5yzdVOL$|@HFF+cv zgCnfButL{KYA1-J)t%s+sDw{+V{j$U-d%b|-}pj7D-|lu zH^Qns*N3N9NSIawN@g#Lr7#D0S0+yg+0)b2)xfY{>&Zv&rHl{pAJ=0GGKUZRtEG7H zf50X!_*Y~ghY^PlbWT8jH4Vx7qZ0e+{=l2^dFVgQD&OyGi&=$l5jCz2DAWrGU=WyP zeKqRtg*!1m&H#MO9@TaaKf8>DeD&V;=DJ7eJu~{60DiKu@;!~UIOD|}FvX1VwyC?G zp!w{+3Ifj$NF^Aan*)WX@o-@1AZWg*ua*oyETzcsAmMpDTp6i9jCb^szFAt)6ygp1 zCPBb6f3H%<*2HJoXZb+ojvXqGw_ocN;F5%EeJXb-{fS~3?qtA9fw#$y)Z~kI>0D9= z8A{I?V;6Z?qh~;GwiURA7qM#Y`|yvt#Ru z=W})_;&w;*DH`G-fQ8tSHKpm-4A#NvxvFIL4Dg)7UKjkabG>Z{_51WNt=KV+1$|Z z=H8>0!#`K49G;|qJXQlnBwX9fnDU9gC%#?!8aATmt3J*7t?7J(c#EG=zwF!wyfpA- zNiROGn|YE&yib2MH026r_?hMJ?{SE?_{0mAj~Rw%&uK^y?|ydQLWtK)h-MUIWmh5p zy-Mn2lYoP`a=1HU%51uPdNaTLqxT{o+GaIzE_S3~@Tv(470J@{tTRRLWlDu+=BzPv zsRso#S7W#^J4MRRgMOwp#1np4DjtKEl_Jt?6=X{#M4*nOW639SrSdUqvkzWqc~XDm zGr>D(%ZnHoH)}CYN1QR8YY$8lQFv;X?O5b`i*r<8(j?op#)7IHO4n_;luqkWLir2y z50f>n=QL|$DFZXB`y`qr;6EQCr((HX(zxl4av)wQ_S2k_$>~CN>8yh3PD_X&VwcG( znBlaX2nO!r4I-zS>9mpv;&v&mf*+mMw$9Z3gfDcK(}pp5TgTZ0;g)}@X;T+)x&uM= z6!F3zqANuM_<>drR28kxpd1DqBu**)gCvzYOoD&R&)HIFs+;Tp>}II#pdCZzu}mvJ zqfPwL2@XP83zP6tMyn|N8J`>neS_5!wJYDH?HztYXtm;8+q)qkbO6lP#7nobiEqsX zcK8{kwSR~KkEuOJev@Y|cWIqu=ucI!2GSjUg@mO7;E}y_xEk>sf5*Js?mv{-OaH|N zS`UHT>?eSChi{z0?xnMB_fp0t27U9GDUH{}I5*Hx2Db5K*G!1agNIuwgJe_x#Q^>UpeWMzo;>;I^4ESSyEmcgw&u>&83R%3-TC_=8y3y4J^AYo2e!WEGtlu#8T=k*eSDv6B z=(uNT(>aTN`2@rFR4igVcnEZJ99*|*?BKV>I;Tl(5p;aVIw$)Q*w3uaF(I8J7zs@> z1wW`qQs;!;1NimMai(Xi$B<&(Xc0k+m-=c%XRsCCbUm#mItECtMq{l6Bv)1ekX#V~ zkX-d4p8%38A^?&rtKfu`A_7w+-B%>@55A{?rj6k2Dj=hJ{=S z;-g|SgMza{pY^^WbR9TB67~n5WTrlox~M+V@F`?aQ=)dV)eXfs(3I6nFClq66;kMx zVb&u~ExLnH&}W-9?X|-il^X2dv1*604B-{h9@p|h@mVdRX-9y@hV^RCY=kZ7lFRDQ z264gy-1VT+6F&G50v!3=CctcRg2imJdDVD2m|B!bE)B@sZ5-f1GJiqa4P zWCNWg0$sF@2@y{=kmyL?kuG81IfnZrj(uX4vK4)-7<{rvmuGg ztnH2>V+oMmqx)sBPOR_4D`J?{DL57L7fbPYLPmG6G5$z6iu}v0HLTkn*nQ87hC6*D zxKh<`{ZDdz#kQwk&Q6|`NVVvS$9{rkRQ37PlGdF*jYErc_oO3pKOdfZ92TkGYstLH zOAe;qfJLf)G%8@__K9y!!0_yD8W&8)1>A ztsGarA|#VLNTT!>nnFn%Rd@F@JYG8k_}`im~DQ}ou2h`uaA#^ET`D54kSlu5d3A*xj9 z+8f@3H;Oxzre!>wKa20#qiqJ?lXD?iMD1$B`ePSPN3w7#E~yK>Y!*%mS~#t6u)t@{ zIQ+Kpj(C|knkMV1w6OzCGW}Acs3N?c`0a$DM6fzaNd)?K!f+zk6r~{oetY2$M351s zBLaPUVK*Y!7iA;@eg|PB5oAZ1h(O;#*pmqIqbx+g?sy{PnC)r)w7CG-oWP)O)L6_#>n2&+ z=HjO!A@G7<=F+;@240x|kkDWC4NbqRn217)izaCOi2w_=bW8wFN&gw)+q8-UcsxfyF`q zh3~fxDA>X+uWABry)E07a0-0v$1lHpaq;CJD5yQj|Mf3TP6qcv)RxOu?T%P}`0^{1 zFJJoR;_dQJzgb{-RrTCG$KsX-euUxAtByWDwfFVJY8c*q=k5umd;#hJb7;Y>#`$ZH?1$mzy=(WsSXJqI5QZ0jYV-5ju`b&yV7O)W z?w4(1&UE?=!}DjIo$OM+<31X`zI;LN*O0uA&R*{Kbvq>QSHkdf-?=c!`zXS`Zy;-g zq5d+YyA82M@b>p6!?mqo1Gt-Uz$(Su(J{vrZ;0yd`eJM!r}wocw(pWkie`iFSSJ7n zumW5}h0r|2!2rur0xnsFT4_yq&+aQ1XfQOebl39;0yrpjAp%|`pHBpbq*5ZFBK20m z5vey3C?olHgWby6pt&BL$^XyJD`f?V1`9B`UcT!aYV0ex~wj&XQS8h zT$wgVBAB~q-JFS0_Q>fSGnwDw=KT1Z!lLNUZJ0fkvHkzDPQHvs>!f%9Dwuo85X26J zUM4zmQtSaZ8*BSFXZFq>%=+jD$G93eXq3rvPfmN9QNst=nlZWFqL+6!P-xOk9(01b zxY_qAOgjDDPbcFJ>{)Rb#eoWA1E0@(ap5scx_eKqoIT`J-3+aOef2*yYqtEW9PL^6 zZoG9{cWTLU7_M8g>uP!a=cA9i!0->LXSN3I{1FZJSx`M{@zIX+(FVF0+qNx}0Eai4wv2B&Ux<3Z6(@Qn*>{9SZTa?{^YrbC6Vd;f z77GmdH8cP)H;^PK-6ksCrmOA$l{@S;xWi3_CLGV}r=FRd7TqAp&J2O_NXLNaaKTK;0@hDy4`3fVzr&;+WK*2mq*C z1t+A6QEIvo2Cu(J@967a@(9BJJ~>J>2wWRd#Lz8YO)!LlFwE-jh-U!*m6d!qtxcGE zkRwcm3Vi|i&Ag_FN;W9 z8+2VJ@~)h5-{j3j#KfY$E^VWyK~%vX_&7D!AQQ&e7}@N_*W`=zc{E<)n%;YLyGa6F zC*gVml`T<5BA`19HxfZvl!*uc!2e7Hl~EQV006&*2&$rZqCNPoDtpuNO#F@f2;m4L z;mq`rrRoEsm7>1eEZ9lJ?aSrIA1Sg&}C}l8y*QPkR1&+p#^JFCZP`wrKypH#vc;5R6{eq2bwp zsuM^hWzD>IA}{3qNEG!7Fpb1DwupI1jcAMGE?#Ps%+4n8%ldTH)m_)%s}c=oEv{c#fz znNZxBt=*TFExix^lFGtHTi#FmZFvIRWlHRYvpbu;o>vA}Y~1j4ZpfMTvys1K@`t$L zPdol`7;sSK%CX{tKbGA;goek*Eq(OI>nfB_G;#jZ7n6>j_#K=V{NCsndAey8{3vS9 zEBvC%w+Vfa5{-X{Z6B=`i*E!xUu+BjI}lc_jdpNvEuxkv{xAWpXxhC*ftLc^^bnmz z)K{)lJOM8$ZKT;M$dgKlAT5$UL_Tp`Dkp-hNV8RtFQtf}IFde0K5?&D;W z6|oy-MX=i4Z?sbkq3=zPHm$*W5D0&V5`i(=LIgngJDdnIp!6 zLQV8&Q&}oaK9mOvm3*kOFI~_2p&r69=tHbsdUzUQuWFtF$ae}{F#BkXjh%LZyF|K` z@W`s8Xz67{?Yt~j44sYL-5C4JydPyRGp@IK>DS>|`@?`V>YTPPV6{D@YXm{ajGYRS zNJC&r2;@Ms|5QXCgwfCgd07_Yz6XDky$_S}$U}Wyk-p9y%INEDU}MdXPZfcUl^gB} zHr4~^fzTs+g}qfz2d9}pk0L$&9efsSt1OhuP;?s^RzJyFE51^S(%2S9$9qgy*VH=Q#*QhcsI{V!~YNd_zu+Lu(XK;sE7;L-{lf z-%cW4p#Mn3%Qh0L^lrp!6?|LH+tl$=XvcH*o$}}qlYPf?g&i!3Ug`Fo%HaP@voW>_ z7IbZ?;9qliY!W?II_wy{zkGrE>iQ!FRHXg?XN?PQp=d0zb6Oyc9H#({%qI)|(j}~% zG%bs?iAGla(e=htJ4NFpTr5>-1^kwcEFC$WreuiOBVNN;RtICQlDL8m8X@bY5BlDswQEsJGotm5_OT4 zG$77mNpF1`(Il^}&SG;%a2EeS$pmVvBkF_Hi4PrZSjVmu~2e#qx=MzQjZ=aEK);D8SeVbgPNMV;2+4? zPwO3CD-8cLv!8oMd=(R9E?wWL2-J948*XHm8=GAh6#7PR6unmp3W5^f2xgSnOFh*3 zFb}9{inOV(NgME5QdEe4m|;RyNBm6{{S-vNA1(|dg17zDM4%ro>`DY*`e}(1 zN&SsMvVn?7um(OB&jgGGD+UMB8*J;=m}Gr93m^ zB;q*`i?ugwhJ~dwjs+qGs_$&5%+5mVq)?;^lq@Ag3O)y_9zEWAX2!Ng|EHwS{ovT?wCly9r5yOcV5d1_o5kM z{$<)0w}R+4K%}rnciAHjG*-_LRhxUobx-j@^?KF4$maB5;wo5*7L?aL2|hcKy; zJZO^v#YSnA1MtL`-pFN`{d5kYk8jkOr3jlJe@f(dI0SLX$6vv9P)t?SFVJ5p=0~O! zYU4Jh6zVSXGJNAHx&;(QJy2OIqUufKDx&ya-bpJU9m^}*OjkjPR8DKIMC$GlGZa-p z&D5V>-cTD3Jn?6&AOCo6Tt~J>{M{qS>>4x%*yK-{q55h?Rcl*Ya|ppO!EjH`hzBB` ztF@@DPaUJT|L z!eE}UaLk&gQ??XAn*Td!WB${=y>APDXAFA%`EvE^D^SYDZ*u?1oa&lx*TLE~T|ZQ7 zRrQt+^T2$y9PMzcx_a###H=+s^Vy>gPyeh!)^4q5ne#7<3U~z8u5v^EW@OBkB5QXf z>y`Xm3!V)&41%|VLCD$&t&fQO_CdIvjj=yYh4yvgiP~>a(Y;ix%#AGzwTTuE70D)1 zPet)oRJ56(qWKsVrTq;R)dy4*>xNO$2saKDJ!Mc)`(mduf{N}isHpw7sOSzsMFTf( zeW|HsgNhylD(dq$RAkpM%r$|aqLH>Cbaih%kD#LWHmIo6-%-)$FekGO#-Rum6kwb} zBB+W|5&=ZDPZNPIT0;a7)jo?*(N8Msi4_&C{wpf-2UN7fii!@_h6-1P1N9ESVxJ+v z1{FmL_bO%02r2@5Fz{Pc6g)bgLq%p);jkbpD(Y#2iYy!|;x}bb(I!AeYyJThokpk# zaur!ya$i!dDr~&+`hQHT^&Dmj>aI5;3utO-l#!4}eNi*2|M&7UN6{p8gh zN%?GAe^naZR5SZ2UHa7`wtd!3k9!!k&@sZZeX5jXFOHF!TXL>t-T<>k+C`v_QoeKM zW5id|8%hH2cVLtIkGNc}?cEXH+q**yTU~_CO8HGKxy21Lx(Hd~w_8-s_SAO%I zp<0YxlFb{#8XT#IHK>+KThtjb(x(o|^-LicMwbkg-*Hg+O_kMy%5N)LO4N^fQg4#H z7=fS0fl&EfYzLL!-vwscHAR)*5zhOmrZaIRcPgs>;f5xSA|<%I5dN`^4R(979$Js`!J>tHETPwd)2J$?kbCStzHLjDA{7vPzc7$jg zQ~8}XbX=>7^<7Ej_r8JF%5OUG54Q3#Q%NG_!Lf_i+RK_E=*)zb)6*_LFR)b6-4|=p2vN40L|~i5Nox zWffw~B)}=zq44nymJ;U%q_HDroURigd3Pa8httlJp~LCKyyVClNLQ_>dfU{76^G_n zaqb>sQB(F%7Df$^naP5r;hM%5d{3Myg;U^QqPp1%m(%nY*sSM$#sR%#u47i@C z%s+A}{|I0O<&qzdZLV3c5;5v!O|K}-sCmXP>Q%j1o^$FH!>C8sEB{!vd^W?VS6SF@ z&lb$6_uWCls8`Ac<$tS&(%Ju%$EtJZf=O4CXI(&0X7SQ@Lx0Pdi=a&9%XW^Z+6OXF z=FXnD=h++AF;FJE``tsAA2Cpd-uUc8I)*Zz%9uz6f-=Qi?apWTDfECWCEz3rS&F=4 z{~`aB%@14E1!Ezg?p`QbDAG$k#eGytytG;&!}=&!wuW+!&Z75>VXY9%DB^=*{R{n9 z!@4qZs>GTg6CzZRJeZ^?1rg{&gkeOG8l@%zNc8DS1goR8MYPqh-W0_e)|L_>CfJyi zWsOeW6&Jb|)JKMOA3^v%{hj)EXku7f2x3a#7`xUe2gBM*5>iMRW5KW%C&-h%M4b`T z3I$h`xHB@WLml^67|OFx%D^&0dB&@c4C^n5RV1x3HmrH|v0-hQoNOoTYhzd|>ia5A ze-DN=D&ZB_)d#~`uld0-U()(54*kjR^;72)ZZL52*g+(a^*?1sTMg^LzZ=&1C4qti zcvt2{=`94#wUD_G0Tso6NCYiqQX)`B>8*m+GVcRGdG^jtFYQ1K>$Ir97}k~1mEwNR z#IR27t#-zS_2X&Ixi*w%v9f!1{oa$?G7l6U`F9QS~BFoI*R(1f;;@F!zXXBKqenl%tt<%Wjmv z3ndX>2hFUEgBoWPJ3xQ26pyzQUTX`&89UVv-ZMIcE%Rs0)JD7vrh*EJs9x3Z|C7$_ zZ$0&xyXiHA6mf5r_n^pd%(S`|`clwj=&!PdPwbRukFnFqxN4fd;J%T2Ub<_mB}0m{ z#`;&(bK+v+@*a+6XgxiGO?%`8)53NAwc*0L1M4Iax}MQl#DDzh2sliANddnSDqgFs^n}vuQ*+CGQb!Kc@Rv+dZbd9lhEYLy)Xrb!_ z2^hCrLmAh0Qc(+ri2}kiN_L)!;Tc2q37TKY%**Yy%*$Pz913z0n-p~?^Ns1ZoXo4- zi{@Lwu)of?B*d~}LR~^E%#@3P0lfV%A+`PCgzy(j@pvoY^?#gxvYmHuee6>GcFy5; zEor@ax!OhJz6R;$=CUG=5JBI(-f8YF9U{bU`5}%R|3{wrlt30en3l<+H3Pt(#5DBH zvR{43pt9Q?p4MHoJC^i7i}p;$#fsFLWS@`&v0N0*#nKI79zL`Wb~%pgzmy9M#(WBE zs9}r6`M9oX;A@z%CZyXgnHD5rHK$ygWb*fuY`ktQ;BTKOqFwbaZD&F1DFr#CdT3@U z(KRGKjrN5<*8`6VyXW$+LUwv6iYeN&38t?8E=$r769|uobQ7LNxq~F)ova>v|2I9B zY}I2f@EK59Vn&Y{kRH?gRgao0g4qsc&LSin|TWRFOG(CQU@F7=L_2%TR=Z#w{splJr(xV?n}sI^03=(_d~MR z;oj+^B?imM>l}Wo>}AGp6}{|&x%3oGFFt;&?xn(&1xf%; z>-(pL3+vsn0w?3>Y-5^oynDtby1P$?SVgA6VhT1YmSACkZwgjOytQIq^tibhk~r4` zMlt(lVhhksY%@rezj%oYnOs?Zz7(zrO^oqP)BukHJFa^KK~_^)WL0iOR+2W87-W?w z*FvlpHZ?&3IF5KbNmx2**^d*g$ZF)jLRM(Mv(-sw21#7%u>0Q_cK@^4?~;3w{mvae z2YN}#;wwZ`pv;CTXF>R?4B@LvgsBSy%m^J0G!flO*Cj(E@wo8Hzx+!M=tT7^G$HXdy1b<1|TRTeQHon8{oDF!_K> z(jbwK_<)ON?nyvlNkBQmPqGvv^p!`YkCvxs*6P5+Ffmq<#PkFGr}@PK_Z#dq`$o_8 z{OkNuF!QSbrP==T{L&-(NuUt$;oU=(1OM2>m zN*BoJTL=Kd&V7@SIzeaJS;x{HA%wsNW*$ksj|dX$|H3m20&)iUe9aOmladG_2G`d7<2TcAQ0} z!hr^=mrAqRBPC%rkD$pMbD}X>ne3I{5w>E_@38ACVq@Eo_MoyEmJDrT4!FOPHQ?Ox z{65`J3!C+{IQe(5S;L$`I`7P21=HdowDiWA7zH&yey{MnU6{zHXXp6y!qC#&V4wDQ z%IeSD7)@AG*yn7Yx^j0P$Z}jSEYvCB6ju8oIV7!!%WdTJX$p(;CLo-bgS2GoEtGF~ zali4r*qDA5;bwDFua)EIKa<(#`q z+r|+LQy}&OTDlizr5|>q()}Xf`Pp4t4qV8}tdVP#EzZ_Q2|jp=Ug|5Qu|4xn8m&D4 zGh!f2qruUp=*f0)%Qnd)8N(1gpE1m{KMIfeMw=mu#`EKIh4!Vl`wvbFQjU^be|!Ty zdk)k0=bpXeU!T2C3my+W4}{Dek$HHI=pfsy+go-Nd#=O<~6{S=_X1aiEF2`)1PcG{|vhf&kj zA{#&WUlaHluR5F`yJSATCLt#0XnqFtR(cH8_fUO8DN$lLMY~F%dy0hh*|5uSCKdcc zAG$Fp&adHi@aQ>Qt8@-7CX1VATae1V-3%K zWdYDzHHWoGTH&-afeYKhE!}n|uMzN%%#ak7BGoxe&9AqqHFA41@X2lBQ^#hOJoZ#& zuN(UZi#gvz`&mxS;!I9<^*HClo%L8hD#s=rUewmccP;kv#XDo!4vU7(%v8Dz3GiLn~m^$9DJDIfYzU1g+HxgtU^_C`u~!dq5;?~D|`EhS!k-NvoSor;~3aOPW{7nm6lOJoE2tnLjFG`8xm zF)VSSy7UOCy#^3(O{V555AgwFPWVUt@Ux zzcq&u(FRnGR|DfR@urG>Nt!ewx#6!k3soMoI$p7<=Xlq}w$vU-+(f1#|2*d%8o11((oJa2?`?+X zeB$czWG_PtXoZ+uz&&=Q+b77T3E?K6+oxps2joR=lejo1Dh09(^Pt+q(W@D|*%O6i>%erKx>k`*wJnO}~ zPKz(N5@SWas}beomWHMN--kPPcO%k}@BN6UrMR1$Aq0O%mX8@Xx>tYiGYe1Qdt?ue z^D^hQ5bGDKBnGF5hpA%XK-RPcO319}`bIn463XJ*fb+aiqA@p}ZaZ+Skn>qq$YR9V z2&Lnf3xZ4xnT^;n3F`a++sqJO;CwXt2X`krG zPuO+yCtf$Vn00fU?f%3tzx}BB3Q?5PrMy+kTHR>a8pXQZt(=S~a}F$`bNhu*TW}CX zI84l5Z@gtt$#uqkbtuFqK{4fI$ghb0VksVPFcT-9C2aa8LAn*q!yIExIBeAh_3SQG{?JI= z%K}YUCn_MqV1I~Ajqgq~ZQHHRUWXv5Cs6)opn6tl8Xl>Gt!JvX*wEr$+I|Fa8+TW$ zyerU@nIc}S#}4GAWnS<#SG9>BP|BMrgiC2kF?aqIxX9}Zt^|cX5+G($VCtXVO7*D4 z{a}XQGu0Quaf@~q>!T7(4mfvC(}or%7$NPvp$@J!%mJTF-)mvPuN?WYxjsvN`~_WS zL?4_{!|T6OI}iD_hOY5L%U`XP(q zpoBTt4k0aV~I9$dk| zIxcM`Iq(irS55~M7x`&`o;66{MmblCk-5Q75D1x9Lm>5cI{Zv= z+F6HO+wM~}hxm2k2cci8{Z)^7aNJVwcWq(2bI>9IDt*9aNHRm6LmiEB9-ad^joiQ; z3n6nfRrvC(I?bV=BOhJVXqXP{YS-x}e1zL9c7`8BEgHqmE56)RCjvV5X@$ zqYrqQXmweo8#KG$7Yg$spYJW`R{_6{$*er@lh1!7>fxop*3V2ymumQtijb*j()Wat zFTFQJ+~-L|DsMDDl;I<5t)v`|2&RMoT0G5E8)kL)JE1U59A>~bW+p|zVwtIhsy%ak z_*`YIL)ZFj-#g?{UT23iT1rx@(}=>d-W0RqV|LHFE~|!n&49PL37#&i7I(-kLPrAz2rJylESbe zIn9v)LNgyk_VwXozF}goTP;o&KXl>MXLf-wXhdLf#Qrhe2o;>>K8 z`u_R|t<{#dLNzU#b9eti7IisHX*V~>U-)@0lD5-TO!xV(rNLpBVd(BRS z*QOd|h);(>+cEW=( z`iO&H86Q;M9HcOyuVBM6*dy2-Y7?nvHMUVXF3{C^OM}+~JY#J4Pqm#Bm?)`AyCB{| zNeRzsFy9~3;L*R!%|JfuG(vGvfjLd`=-&w>nclL*e_w%>svk9(Xws6;QWg4$B9Q6B zuiag_Fw;lI*)T=A{f53mAMB_8(nK;amW#_;wipOUB5${jDi>!McmZG)S&(z6M34)2KD;lW)-m1~~2!8&jzOxh6NZft@DThcaA)YKoo` z!?r};9ydd#M(5;|sd2U5Z}m6#>;F@PumO9|9h6&@jj_$wH>GjcZAk=8_XwmGZ;Mmm zTd*&2>Y>WyE$&O8_rRw1euiMX%tk3rqV#y(VqSW@I7QRShu0O^2CVl_mwamk_tuhp zXPbZz#`StRV`wZ%bNty)0Y=4EbqWr?HVnYgu*6j=gK8@&C&X<3XvVacE6dxuP; zw@yiL)(~W~w2nBau?pTfr4YfugGQ_1y;CX?#2vI)1s|N!h+x)1>R0j|RZh}V_&ZV# zYCbOIEcC?o(3PIGz{?{1$F;VpIe4E!jPpwW6S|QXlUfdZ_@EEfEF@)6ayR#6uQR$<@NS{w7{?7UOOy8pi#2PF z^)lmzjB!uLE`zA{iOwN2*{5kzmOLJj=To^17(%t1xu@N}i`N)O&cjglB{khoCkhBD zaNoS$d3pUw61!4EV*2Qm##?_gHN(qEv-TuD`yjhS&+Cu6i@88R1+^~4^wcD#g)gQg zH?TS;=A7?j2(Vl4mj+%opAq(PYc)``7wzMo!Jdbg7*l)M!(?SEVE#65gDJx}brOd( zf`c07;2v^tXiGS#V-6nY4od6?cUaz<$pA4pyNy_@WUX5Jwh_BVbJLjjbGlP>=prw> zn>6glytdk1Y0MnITl2!=LpBI)p=#cl_+cH zqF}`ju*ZKlQAaHc3#+|+i|8%%&1&RvU7Cnm@O%|P3!e9r6jj*?f_u0<_tKGHidui| z9VSKwUP1!^J6R&}^@y-mPmb?qGN|9uDgUZT&#Dyz_wfy_JluP>n|x2`fGXK;d^bnk zpoOBN&FO?|sx;)n?fnGyLg*^;xm6?BDE@VT;pk!+01ZNL> zagssqHZTVA6!s!7rzQt=O#KnY=tAHiZzDRWL`m7Je}!}Qp9tst4rOyIix_5TpIlxG zL>g^8S*N3tatVv|j<~K^su~W4boxQ$Fs($PMe{EZge+!~ukWB#LH#W7Kq*nSbt@=a zW4x&>w$p$OU0cA~(jCFB23V)AemVE)vycysP|o%7KX$fl`o0~aO>VaF`qw#|-z|O$ zOlcv@n*`q0PAmZG57WBQ=dV9qv1I}DBHOp`(%#M?2fywGY~nxpq@Q!?(T~7NO~1N2 z(JObkr)FoT1FpI^=Rx7Yxw$_gc9YV57c~b9j$A<0tg|<^zqE5|bt$kuR$gpT zzxs9FQN;3=R&i(f)H~0gA)qvJ&|1weuVP;zhM0&|A`Nx|t zK$LUxPr(CA^Z&SovMx82?|u2Y;ff}3tC{bAJ~A`@M1T|A>gcLRtr{F}_ZvJ_e!VSw z6RK-|KM7AY?aY+@xg9&LOMu6;>DkXG;ue1n+*fO)U~0GW*o4os?jWx5Zd;DU)#Oh) z1y5D?d{_RGz4zCChT&<7faAM{l#W18wQ6yBC(j!%W}~O-^FB8jeku6#JI+^~d?l85FD|0%$DtA$0q~yfAKl9 zYSH9Ic7U(YUWm6DnQ?pq3sT9z1=2?0roIy1G5@XHSN&JP316 zeeyNDtEu@6XRq3}e0j6#rP?4kd*g#k=JwdWztfT;xxU%!CzlIO-8t*ivTIg#n+2!# zW!~-rXMg+X%keh>uVy>LaQ~CnALT!oQUU&5dhz}Ciw%z_UWBt-R!pCJeEj~kh*2r) z*>U!p;9?Hl+m_I4yZ>IqUlehn<>W`c2}8=4}zBV%CV;nXXQ zUR`8r9*p0s>Y?9=qa-Y^2?nc5dB9R#5cFkBIXrUXxT?>V-Y1w&CC7d~YWZVVb&vP3 z80MUudle|$Kf~Y& zH$R`wk=MsAyW5E&uZQXCwuF$^D=~TfDb9~ArRlVX>P2s)i{k6&hCBJG1*6sC9cmXF zv-Qg<(BUoJ&&Eb>SJa;3o!iQZ`^N|-gLs4{<}0wBz72&gsbbpkt3>uT6i3XA{kjOg z;4oWSf4zy+x{a5)5J6UdzL^Ln%A`b4++S}MOp8G5U4*6Krd!Np0p55kQ5Y^^SJP{D}Z61g(NyG6fOn z2GiC%+ARwq0^?w_Rj^m4{Gz5C9}&z{+)ERERU`gd_4wj%?r);){T;HoKPtxN{*Kt( zpEAbg{&H;YPZMKve@AWZPZwizf5&X@&lqEKeGf4|t=pC!iT{!ZchgMRhQ zJk^*+et~&f368hm#y{YLmXXzH@-hb_@Il@f3S7Tc%)##5!4YXAw`*${m#PH(($I1h zM+dxqCeZ<}ttw%otUqls@wOA=g%WiiU1mG-&tl@g22hNoEB6A^v;^F+Q zL{Kf05<%r~y;VTf@g@SvhyZ;Cxt0`RH*zhh1C2z$Uo4Cyg4F{}M4(?R>`4Th23m-K zzeL!Z2r>pr)^Kjp6!V7lp~65PnWC>C5?7=^Yu-YUpo73(S)o-li_7x0T9u=Mbc-`C zjUXnjr}GGvf3`oi5>039)=RJvraiyg)O32Dk8v?nt3(Hk^68`6sFo(;{7T($1Kt=+!XdLRWh{l9tA7MtnKB|e8T*eh0<4D2^gGo!v z_=T(0Ubq9bWPnN$DG|YU*op46g?w|3F5!q^DxSi|nE{Y_l_`erTW!YQ#}iFF>kDdV z323O$dTE$HPh*+c3fJ>^vqr33_1L8+b!(i zw#Vv3&_Cp-MjD8o!KLk0b{#`);yw)dV$mWyPgnz;P#)O$EC{JE`DN+mjSCa!U-H<8 z;u4ws|Ha;$hc$I|{oe^g3>Y*}nXw)m86_erpkm@wwGL>l1+^M*LX=TOoFD@N0t!W| z2r2<7qXLRD3kg~WK%iokip7Xl5D+L@WDI^NW*T;py*_n+*^aYP4Y&g(8d0-NSk0Q_J?XAf z0F#jtV@S`Oa*^TQJF9l8HKNW-DTjp4N_N0yx=myR*{0^wx1Z&W*{RkHg~lm%%-7&x z8UyVWnX1^wOvx4k)OnPa00!+~ur`66EtIjGF7UhYiN81GX@P)1jT(iwE~ohuCJP|U z8+3EllAms*O=LKHZ)je|?7eYwBIAdxap+_2E%+E1Hvkz6dv2y;=g2ofdVz}>XIRFo z-2*z^3#oLUosp#Hnl);-j8sei<>@yzU$dAqUW#U zj-bd>8e}7wgc7X6wVrcPe07Dic|zA3Oh#Hh@c*~6*!EBqfBtg>z8tMb1A5eY+<)LU6lwE?<}7j7=u;o&E-}`W<#8?=Rq-04 zCQ;zzxrKkbsJf)#0Ll9t(Z6q=F89MQFQbT}4J;?`ad!w6o;nRF*LHbC``_wMBYi|e=z zRCSp&1BjvLu`26=AR`YKgoBWDzU)62fC22Fda_+sx_*L!b1)d$bVN1Hm5_rJdk!DDK!?WNco# zGatb>aC}f>Y-Syznp459D$hBfx(qFhGW7m90fy=kd*|w(Mot^^z25IXQw~u}faRKl& zdE*@@liLTIAb*2vM3zJTCFAGF-w^tu@L@xv8}c`JmlYSbOf3x=0Q2Y5sxk*G>t+6i z4?|5^fw`Sw$luU+H*$)96d5|9_EySFx{u<`Xp zF5}w)eg(m7=chYEHbqC(qRrTq}H!{ykTkI{V(k$VBB@ z7xm|#xRedo-j;GdcE_z3^RukBsbqh7E>*qjr%&yutBI_>!aUn{m?qZU^h5PWc86KQ z#kB1$evw;zx#c04zl{&y9gsKi6jBwK+vLl*zDNHqJZH0N(+V!n_jsQG=213>KifZ2B~z)p&9U?wH8`Tx z!axe%vPJ6FIAwFSF-lOiH?4HGQ7aeBL;KQ7jx9U)el(cpwG|DiH*{%$ay)0KMdX=l z`{F6I&N?+C9b7l@f5K2bfAbD?Iq=ii{~gzgJkia}|JOKH&~o?}jujaGSx1Y8;u2lw zLICo=@v*Qr^}%Cm^N5e7kEz{;eJp2wbx`Ree@b&V@I~0ou}hnQ3wRr-3AkXl_BdQ1 z+Q63=?9n#D1>z0VL|nr|donJNZs5xcytFNF0p&aDWn9BY+Xfd9-|^)I`?TlvIa0y# za1fqe-a(I8OwPCcianAl_9*H*g&rMN=u!4}3Ozci(4(5~6nYe((4)@p6nZ35=n-?H zLXQ#^dc@kO(4!=U9yxAQ=utA+BXFg7)B1;9DQ?+g*Ge472WW}=yG8PexuQ+SV^gzH zA}`3)X5xa*jU4&JWNYi;0_G-(yx_VvfeTohIP!_Pp>2W-95+ei1-aTRT;RQlBcGUj zZChLrwMlY$tTx;FYsV3!fkK;|JQg@C2lTMuJMvIKq1oPrgA&<60X}H|1`bjeAj?Gj z_%M^E{&cMT@u7x%(h+HbNc9L9vDS1wQwtu9f>hu^1=OOSvyLPlROKGPesn z-A+-O2^qi_gHa`)Bxm({oXpQG&tFJBI*Pv|V#v|w994y<8*daF;!bR(aug)Hj{UTl zhaDq-Y12)}r#$XfiM-&WHWL@{wsQEmMz6LWE)Z>%$P4<9Yu#T?UQ6e9Q5W#F1ge(f zf}ZcPae*XIwGtQboN90ZCsefx7lb->;sQyiY7H)kaw1NWt`z?H9~P;a8f)9E3tB-w z|c?)QaTNg*F^NWE2SqE z`&ZFNE2Z*tR|<1W zU6|aJLdC8W>J}K}#mJRnCv&BQ$z3Vk7Shp|cVbtH&|!4n!qjJYkfqg1rFf7Rx~FlG zD~0m8D~0c6bRD}=tmUp0mfV%%?VhL^M7mO_>8cB{VD@BrVaw&lwBiTCiMDmtl?vcm~uIn&y0Sgjc;Tkt|^!AXlc8&_J zl;_x$A_7-RGIpgnfGg!UHRW{mqr)f?X+5->y}{t`zD<>`LJ{V^@mX z@iT1fN)d1Px2_aNKU46IB>ml$vMUckcO3S22CWv-MJa#spbAX*&J zD0ii-oX*lrIEq~<%4bHe#;z2{t3N3rR|@fWSBi=8z4;2R6alzWB+qyvrQEfxq$_0r zyHe=jN|}UQDcQ)C67_>Na;1bSct6;~u9Sre-Vfm=nJXn+(Ul?|cBPmoxKi|nT`6e_ zjfvZ2u9Tk@8@nNQrLd7JrF}nK4fpS^lxL^zFa6AwB1Ntg?%!M~!hAz_l^y?qD@Afm z?n){CXRZ_ua;3bao#*Sjo~4fMB2t|+UC$C@!Sg~L8!=bUv+jop#yj;taK{R+>qr0K zESL`-8K;2^ZT$brmP+LVy?)1G14~di+Smbd@+3<$QQ%) z8l8A=cg&LIQMYBT9AAKJ*lbT^aL0oR@mTqm{sxzcUB2g6*9IRtinp)x4*$6Nwl2HqR<-eg^hckjzW52I&Zb+l)v zc6^M1=RX^RQ}Yb&nXYaD2gtGat#O%Lqto!*F0Hx|&~UJ~Vj?^)7XP&8)#1U0Th(Ch z=kTgM^wG7*kKkTe^{Ozws{ei&JczTBkE+K{J0J2I%Dt))1MsNeY$-)$}&{cGljK##5}dFynL2XoHsaFI(Il< zztQ`uDc%Zj5%c%yW2^QIwZOHO$9X*v^mq({W2>frW__^bZ~g6He&x^I_1uDR&r&eI zIk(7XOXRGFD0Th!Sq1di_d`vPy576gw>;dUy3h?Ks?E|NJTS6i#_wRB*YL5k_0qfX z;7j7$wp!O^bY4OEme>u+V{+0il;(r^fbYtX(s}KcJzze0Zti8H?!E%#OL96p#i`>+ zT@ef%gwKdlA)2l}0OZus}(YOAd&1UGT07Y*FohKQu}V@+4Sug?sFu5$DTZBK@H zp2~$;W#4Ubj~@!m^MYwg^4XkWxu$vtn$@O5P183v4yuEPkiR+lu+s~NFS}t@y&vtG z{%k>3eFvCJ`%JSM)ARl)0rSZbVeO{p^6Ow$sR4syX9L$UQ2XFKQ6BE#^u0H$jeDL0qOE9ZZci)DEuozR>txhED>bs!z zo=Lx7%G{VS1?8P=D`+1+6KsSQd2aEwCgm;YVn#A9{(p%v(j2FvKjwh#8lTi(6E-9{ zDlKuO9d}g!Sz$-MFU#TD|7PP|tMLLzUo9G(@~rFdiPH%CiWrDI<)5DNS0oJD<-j73 z)q+>*a4&JzWYQ8EHKB0y#`GaGwTuZHsDi!sUz*&fgUioHazVXn{y4(eViBZ zR;}qd$F9Q|OG9!EUQB9suiFj0w>L!>)9`WrrMM676=(*# zJLZ?!8@8YsY)*>lZdsX=i_ERpH2At_*jOO*BX_+k>V1mJ!MwBH+_~djU_3Gx-#njn zz)YBo%q_;(4!+IIk+uFX($1+cEmdYdNZ1eUj=P7VyvXGWa^~?ewy##Vzn2f;~9Jy?SO*)x-s`L)yQs_*9tL zUk>I8$D-#oT)V{DA%x@eF6;5p(Mm#5l6=wQJ=KG`Xjy+EURwci^6^mZEbT#^S zw-gi}g|0qs%_u2665W>ptxt{5J@%r|ssX+`s2cI+YL|1u5zWd=?0MFa>6wkA)58C zE#M3c8aH8!(6 z{V+l_-@YsT^IfekLNxgv4V|~jhY`)}j=ZAM5k#|?RUz9l`=P7X!(-o1C@s(sFmRx< z(m(Z`e+DZK|6>>=fOJgV6Toj72vDEIL}dzO3mD0cS71=I_w_XF2=O;{@lT&#fu!5*I;Q7ZG~t%2Z5Vj@xt zU~ra9oH%u&F;W6U0TABG$FxX}sc-j4?l4CoP8-IcszV&kQeFZmV%ll~dP9IdKi^i< zCT538lql>A$wZ}8ioD6G-sh;Vr{3AZeI`>rVLqBA&d8B>J zM5S+PR3WVS9w>(CInHqOVbV67p-fu4C}JcPK`7_G>Xb>?^;)@+sKGfAuF74nj(ez_ z$C}Gj6;dW}GG8w9#(BL8b?9@*smR_Vkrx!}FmXZ69*z&LQKF-V3p)2md?sm6S!=7c z?8F1?p?#&}xNtuI%rE2x@e5QV9p^{?iP;*}`PQyn%&7}izgw?Gx^k(-`rEY5XhO1h-|dgGznco0y2Hze?RAJ4aePZ@em{-Uh1?N~ zcvEHm@RX? z3$lGB@&XH87A~mq<=n`zPv|sbQ!pSw~s2na68>`xInazFE5y@YlI8L`>65@pRYR^7f3-W8DIDU zT?<@5@uSKw+(p+07Z84YdBHZ_dANY>N4+x@;lfPy+ZwxTrv8j_migil(V;f;`V^^aT`?^p+Kgw~eL z0^V^Y81-;;Mr!MQm|$_H;tsxli?hEj74++M{|37X)qX#0C5w)wOnH_oKMnG1|Pa_2Olkvth#=Cg_@x+cz29zUS{E zf^LNeCg_|>#p|P8<1j(@%*s|f;dwG9=su2iTtX6bp)*sIyut0u{uT&2PW?#WCv%iv znCrH)TnLdC?H{>>z#$P1ws z!+V@|3 zILwPdDhJX^A|m!2ZHxg3R4g;#Q4QA5wuas0j3cbPD^zlkNohk%$rTa0d9C0R*(Cn~ zep@8DkKYze4LKuPK!Ud-UwB(2MyI(>i$SNkW59I>Ejawze3E|KL+?e|bI|%7<%vo0 zGuvzCz)SKMX!A5G#_X+kwB!Y>>to8WuDzL9@rAmKcEj8Fu^bKGPh!W8`|`<;?0p4} zlICXbMdC$pGGf#|Q0ZVgsDRMQJ{H6unCXrZ7|C9tgP?~G+ncBa6&+$J$R9b5@`LLd z&f6^FGRJt&{F*xodgM4Ni8D(3Mieo*#pa2ln>KJy%NweSn4Eg&o&Z@ydn%P zD04cwo@fG&56^2BXFIV+)8G(q#&ogy*M554gE325Cq zWd}^_Zn4(e7;6?{erqqSJ3_f)2qU%}yxE01Ghp$~s75kBM z!<(jDujamM@<9)9ERX0T&c>M9)l~fD+o7SiLexk<&|zS1Sz=E6KlE%SK?L> zp}3OBQpirnXUby;(r~2)YkDfK#H%1uaV289P%N)hzz}5knTfZbWlsn;+$0W?;}g|@5+Xhk*GedoS1MsB$p^2`ntlpb z>Z~BxxRTULNXW;bhM^=MyoEFAXK^LwLxPJdaczVmd8J1TCHdg3m`T5cE3qCDJX}d? zBP8VG@R*?_AG~!l>B+c~<3mD(D{*HDMe<5d7)lv<#5c~QXFf}#oOTu*)$?3tSZ}&h z-_clGu+Fglg!npf2$FCmznSzixRUfC;e;!R<_LN6N-r3KWLzm=CjA1gM5!kH-BXm7 z?NsX=`PLSTz7lRU@T9w|!FK1w(Z(tU6Er+`!yYgN_5kkjZ%2e7ARUCw{hwY2OV5H1 z0$+P0naFQO%`fczo+)yH1Bwpm7ieIyV090B3+6CZrwF#dt|-n#Ittobj3J+2r#fS& z+7=g(IT)~ZQ_MlONPD7W#yc-d4yaPhPy!cQSCXom+%${+1~M_M9d&VAUQ_1F)))(|48m7J=KH> zu0&ZZ6w50yH3YYDrNuVh6O!=1mfFyxyU8YqUkJx%Q7uUoU-}>jHggRU1qDQFDUJL%XZ8Fl zO$PajI3@T7+9*0cBJd$5HVZGnBJ_l%pzoewgcqTYhM#-Bah)lw0I}*Oya>x@HK)N> zYYHqPs;(!w2wApIr70jdc?lY)ZN5)7euuJI_(xd=hFTo0)#FGMU%lx|&tD!gVn~ zmQu{9nJMPFpB3xE*It6|Yo4x|o7Kq7%;~8h^kx%&CVG@-2|^!^@Y_&T5=|n~IW^xb zJy{2Og2N$TNd_toQ`Vo*CJnghoX&5^I8N#c+SjTFLCBNibP+-{t$2$@V(1!KhfnE0 z)Z*QSpZEiU#?e7bM{T;=jk@^^suuk0&=Yy&d;(XRxv%zfm3p==Ox=^5lL z;iRE7@d|lMY@bCB!pdTm4tM_E&2+>Vjf#2bXy)bEgOOE(DFOtNj| z%^Cs@ZrgVjoqmnHpqgI^6I_Y=olwlfm9}UIObkYplR6(jVAQaP61{?Ll5p35xT|d< z3uWWf^cV2+#W4qFhd#;&^sDZ`PmVv!B-9Z25u!GVMF}adL%)OOjD}U{Lxha1!sirA zW!I7Qz_k=622FT1;${@)G33famcnf#Mp8@Vq@Wqtr~r)4@{C~ecNnpJ;D=GDkN6HG zsrz#HMx-$6W)oSk!iv$d9bbu;t#`e+7?y1pTDJRb$*U1lJe5DdJy8jZP`}1_{K!4Q z6yDHqd-?c^d^y*V%UNutyqO%7<_bdBjB5qIV^?TE{l_klshHOI7|zTcPcz%VKPcSK zeEbrGbEa@%P?Q&rav2Yo@GTk;KX15%;dMi-7JA>m>a27S{V?<+H2hMB zj%Sj>rKDaMb5~VkL%I&Ri%808{6qRbyJHOe#$M)ybB>uCut5cldRUEoX1b@@v9NK>}H(>bv``4 zmCYl%K=K=^Xsn77vd8|a15)?TZ(B%Ep{r4y&~a>zN)K$pQVlcuM@UkLEPG-EJUHBk z4l_;Zy&yr04qKcMh=7E)-?i4Gew~%qfba>!C)SZBcB31epjn}J?W+_GVR7SYrwbS z%x?p}b;}y?Ekg_VnENw)42(k`dn9T9K2R()(#aC+3cj@(37@-hocwh2>U4&~F~d*5 zZ1r4o`dhr5gKGu@2(C{8X{}#52rPPDRSO@!(#Bfa5tJmhc4`c}XaB0W*L)=z5*GR4 zO*|a%Fzn}3d+)!ieGDs~ z#7x7C;qY*RMz@0GlFTgEGapxY+I_^41t;%5>@CYa=uv>L>uH_@%O9=LbRE!~;aC=A z`!I{Sc98{=Z-Ta5o8DYt^?V4}{utsGt&C4fm&qojC-g^iro;D58)BfRUguS*+_R(9 zhOf|7pdPgKY2No=q9nMX=Oy`gL3WPfd~m;L2sn1~JrzF7DkM%L=oO9-nx-zwo7p_Bc22A}L7;~gD*HzNUX zQdU#eb)!>;10xH2!IIEFzq-R#`=ZgNbm3x+YmAY>jI8T-Sq7b7FwuC`jNu%Tr$>z* z9?!-}UkU!!ap;ZF?|NjTR>lwEhEbd}U+&({p=C2Q*D6~KjWfPGK~s6k%Fx;Y#lTDl z-*4Rhy^WGF=vM8A`*MWQ+h{u+FG(3a3K4VNbDSJ0LgsU?Y#t>@>?c$OQRbV57J%t_ z|IU@HMdG0;h8FAyfnE@$r=AP4Q~`z-vZ%>KFCh}1@-?(D;c@8CXFFw6By4zUpZwKB z8U43xwyfn3d$Nhc7lKh-q^uF{Q!d0#N(yB*k7hZsRe0W+lHsUKN*6=NNELGCur5g; zuNUPvhR!%Kp^~YOuE(u{3kV@~6I5i=%dJPEKG)Dj`nA{E2{J>hqKzKoLQzT-J%F>Q zuvgGVt83S4kn9*BKC1`d(xco^k2Sn34>m6!)ZN_@~^@|e$mSNeK+ z)MbL<9zC(yna9?<5G*M`DB|shvit{oJvk_1be}9@G^#;<6x!NOgfQQlz3i%q&{Sm_-v!q>CLk%q$XdeMVgcW11uisz26CxmjuY(Wzds zpWtPCwoO~&q}~+ktMEF!q@|?4BGcTV1HJk0-*zD^Bkne^QBe)sS3Q|{^z2@Edw%8R zmiI$N@w4B8d3sy2>DGn=sr>*SIpwy#?e1zr7=kF;%OTx1{4&BnqK@|G-r87gk7&J| zt6f)%lBY8OzvEmj^&j28x2+0@3-9ROnDX2Oi8m2fV7&4=*RH~<7+UX``_%J(d6qte zL-NNb_coNIn-u~=A?};rU*3PQ;}0+wFNhEBx*zQK49xAu$8Ik_cgz9IsVCa)+(0b$ z6fkMLAKrv6#%?%N^*RdWtbVtkV&26!E?`b_ zi-<*OalZb+pv7ksV$#AoN;03eI3pH?ThUfmn`=SLyTHoIEPT*mTngqNPvn>fR^-ZQ=!daG)ey*{Z=wsM87b#?b+3sIG{zi^@9ejcEy8;7D>8V+3w((I1O<${ifGKD%^L;3yqn@W+e|PU=JfcM$vp&8}PG6Oa zXpy#E4gDpfb3}+12|O7m^$%&1(IP9Vr(X)+2DAu!6e$QM*)+=Q|E?g|ox|-F3!a5H z8^KToyp4CSye`@k41C=A_nDqgc9z$}s^DF};+}5id~X_pQ;r?0&;QhQ6)mEKhLpqp zU4aOf;`fAe!-4#`hiH>ndj$3h2B0&|uRm3=>cigI0zM-BR5E2ei0p*}{}fT~wdwNY zf^!ogCN|vOHr8M5DGH%&o3W`SOXHrs9+)QvqYbX3K`mPmytO0^sMGp?SvkD}G1%=1fGqlaYZl$*hzLr-mcIO1qM zn=JkN(iWVCnCz%a9Su3e__6@FunR4g*Ug^?`@Mq;yEd$891F{pj2uQ%eWIsxj=rB$ zRMZO#D#m(z=IpoU@=zGGb!3zM?_Kl11YVir^vVy(jb)(wUYMvyDkhWJ(g&G-WSzEh^vX#ZRjWr~M9t*m3D}XGV^1HC!wIYFO^-qC;ut z5ayL}V|@L*FC(+S{Klbx$h4Ko)@YAPimbBRQ)Oq4MnW!VjU(%cWi_0CuVT%-N2fra zB3^$V-fi;aD%ylER}Nh*d@*x0+Q^3DF7JJGv@IO&Zc2umcI-i$rAD{EF8qD{H1%I?G zn_&;yZ`AVBH%7O1F1rLQ7^UJ$;H4?4Li7c~voi>0y=%964(5BRx;~T!yRShgd}!0) z{C1lMZ@@hJsrg`;&Fc5Z!2Hgu;;Px>?`5KBeLKTfv#RO?{b4hZoNuw~&&)cU1;f{Q z>e{B{m($N>gLzxa!=e1Wt>*8*+%iA>5C7`KSoHN27B;)>-YdHyv`tpGd4=?Qb{C`8 z2X`m$+FCIc?H#XMN?aI&xdXt0Nzw%+Cc$YN8{zKEF3e8PnERUmNkrxw_b2Wed?C{@ zu$X@~@m(PM=T&KsGV>dQVPbTCEQ((+daR_00)yY`+Mte>@T0K0njP z32;MxW@K_ixrs+Qd?SeSj`Te-59>l&0L^`U9eLxQZqI^iPc_MTbD}oQsU6JcSNM9I z>Ma+dUG>dmgJ9x9^&pt1-powCaw1by0p>Rf8e9VE8$AIxn|@3tg>yTlZ?3MMokU$Y$p**1)TSr_dLOd(O_P`l~OH%pxzZ&$1*vEfI}O@nFT! z?1HD+NcJGWJTkAj+B_82A-}5ZWzmx!P?Q9iylCIWra!|fknBOt)#tJN~51+Xi+_qbD9Y%&*&2q2)WYCJbdcku+(a}TaW@lM#Q#rcg(#sX$2}t&U zs-JEe|LD)27U-(u(`}5~d96Yudk{2q^zN-ojn~lq-eY=isPav~NhEvF9`Cc|eeta_ zSTjU$fZD5;V=s{GK~Tz??#`^+g>S(;i!R80o6{GHWDib8Up34qY(?Uk?6{$ogCAq> zy#@0-mM!n?4rLl6*@NQD?twK`r7K`pjO8m*qy2JBmqnDC#)=l>hV@LDi1hSUHi(tjTpK)3mt(7LUm z8~|9?A~TgGOwwb*-r1`uC7DCcc~zwV(ti=-%)$kd^QtPiK;-O*3plZ=G+dDC>@EFD zHG(8SdNx9wbn{m;D%j&lX9qv-ZAjDv(s|;g!{{YFpp;$IA8>(>&Ny5^?BdG{_URbm z0`@LyG_G+#XEH9}?&8Y}4(eFo0^Tm_X#-QGZPS2B?1-3 zXf$|tH6JHUSj!I}9bKH%w_nM7l&C|;by@Bbc|np66BjtTbL2fr*3rWS-tH24L8=ac z3!>aP@*bt@nBanJcZs~5Io%bNNsDk2eB}i;>;$wOh zQYoWotT$p#ng>SV0P6_HFDCQWNPfGz@JG5iUPL9T3r~=v-@J(l_|dR?v0q0~Z1IC( zn94Nr#+7>Uy&;@-*ltLNfeR?UR3BVmtUC@D5WakQ!6aQHT)_67CWs*kRn7y|6Zl$U zHj8k9PcLXx>?EeqJJGg^f4g4O7+pbRSQA__O-TS5t6G8!9$4@}mXChX$Xok4%*Zf)7PJut z`N9$wWxtMWke%=#^F9dX&sebm4>Ic(Z%t6*k9d$Bt(w0e2bulzS@r!(V2~FeCTuJN z(=_XRU%Q^AFc3}i&AddpUVzvOOD|H6h_nVs5RucST86R0`&(GJK+>jq2N%?Aal{3j zcdB=BLERSb(25c0jQRlU1yI|Mub`Nb_ep^+T}0kMq9YP{L6I&K7l@B=sJO;0UA;`Q zhV+PBdXeW$zB6U*T&R{*MoKU4nyT{cr%3Sg_n~Nb7Xs2I@%6xYzeUfM3$Ayk6k~ZsVPhSa8M<&+ zM6-ae?o>p*iZ})%XD4WT^lHkku9T;n$|iLosi`l?eNhts$3j)B%W#iWprvYX1)+@3 z5j6{via(q2oDJyiGL$hJ{z*;XX9LP8sX`0ZFdw6gtb;GTt#VX6(Yp!{w!vOU>U}IE zQmzUb7e_(k;!l~zMfrIw7rK!}Uv@Y3qCUA*NL7WR?LjT)@v?Pg(N``d-=D?Y{7G;P zAX3T>fJk3zNU;o~w`s_NooZ4Z+=vWg&lFOI(ec(RlSPs6 zE;2O^v;%nO&mipp9z!m}$Yls6bUz_4RN?nm}gocn0$D}a|$9Nh-nG$1y z!ZE3akfZJg=#s)Q_84ITY;MDXi*H!;!c1Eyc+t&0P{uuLJmxh{ACrBXcYpD&vuq@u zXgr2#Yr?A0hxdnDrosEeH$1r*<9tT&dFME1C42Nz(2h;&Eyfbc{4wd4kU}FLxUjiU zc=X|zAdp}bBMC+?=1Y)Z?1EvVPFN%c0Y}7|IfgR3ATng-o8V#2Ox-UiW{MPHqIJrP z#s(YuChdaUAr75FQPuXK?PtY^y-OI&a3CjXS%D4tdYx-GOW%`OM$_1-KS&uTa6O%^ zIodEA6O#N2{I+q%oJuVo|i37~#lIUzrG0%H#2?KW8T z36AO@5<0^H#lEQj#9Ysx?X;M3AnD|WsM&MLxQu-$e{AYW{YvWyB~yGWI!q&P#ShON z_bSi%jt^fPvJH3C_ODcuJHLRp(#znjwDD4^7NBuLJ=K-tR`|!KQUgX_M&A)l6^h<` zg~{7}3?-KUau#c5C_pabh5wNLNHkt}OYOr>@eqC4M}Iv>95$6ZNL5kx1ff*3{Gc_N1&Dr& z*gwKXj9EPP=;9Q47NEfy^ra!`I$GrINK6)r_6MnFrcNfi3`Ao@ty(!GjG<_^2OD&I zr@DMKvWgw8?<`(!Pw=H8T#~n7}ss9_T~*iI=n^h(6?KZ!k`UPYU!Q4Zhs7N3B zDbk112MY8dAd3q0As~x#fDVvFIekb`qz?gEWXo|bKo;cy9UzNx`jAkh4*^*e%PRr0 zC@%Iepk! zp+Fx>fpe4(9w3YI!2@JbP9FlYD5nnrSyZ480a=s}9w3Wy`Vf#sk$Z|@1;DXD4SDQM z0TJBe-_cii?44$ojk7m7Jrw&1Fsa9VNHFDi-z*_l4iY|PD9J&>Ei>t1xDxLn!NQe@ z*#PEXu=F>Ek{l!iWHB075Q{aV7CX!W<WKo)SQJpwH zr88tCGm^5Qh7|BY-ZWrvqEtzCkHYY_iAn)Qo=XU_(s3p4VOGsAyi0(E9YP3u28j@) zdDUDUm5nm;Yp;u_k++?bdK|v~lHlu)KLNvIGdh2f9tex)V#rfEldcj#Qc0oJ;M*s; z<6ef6AeF$f$$gB5X>um8kD(+73EO7UX&5BTuO4>RdwaHYCx z@v$h(1WK!kWta&R0l+9{0*7Fez)axaOnMBiMEQkq!<7gC80D2TH3TWR5+I8rTnUiH zsPYjw?v8k+^7}Mte3D>k7{!$Q#&rY&Yc^3gO^@OlF9X_EW1y`jk;7iM{M};=v{@Ob zOqsHN4F=knmpR;Z_&1>iZMEt)0NSW40BBR5{(2z^v`Kw``BbU(8osP2Nd>;6&ny^q zS@7ed(FWP-)Um1$xLE0Ye?)wj~8Qe72q*~$b$t&w+> zLkAR|xCtxM#ppg<*KqzodNsv7!&%7}{oeA4Quav|U&RoX6@swp61Wt%5E+D(0utxL z`J-dTx7Z+_M*SZDEHVZN>&{ttGVwzY)*sw47{YpCx6|fp z?idtvF-F~vX)vD5MJViLsNj5z>xD(?Z6xeR`BuozCb>P~dJTajll;xF9$vN-k++o- zz<#(A?^~f*UTK4dU{c=5ZwE&AOvf|i@0ag1I?eLQwv_==e2Lq3okb7EZAbk|5bfjz zaW@D>@=6;ulrG;OFUWlsJ<*K>`ldomgltJYXHQd32)fls#u>*mf<$IVYGHOtV0K;_ z)fmH%3Yf#Yb4T4n#NTTpvDnB=*%E#*x0u61dBh=sJ#S$)1AfvB1JKtQVOlG~zZ%9x z$JAIeN0zF$@H*oZX>@}TrL_(l)q#<-(gB$KGyBCg=qKYN>Y@;5!LgC8kY!RYErn$} zgjU!=F2)C?dxcCxPdb|cSMxl9yc!DNgG=}38`EdNJ>ea45JS~>W{Zx0aZSFQqsZk9 zaRc7u)4X`PaC55o4m{AXP<>vAKONH4=_mUN72)1!TChDBbSMWAZ8;xtG^aHzz~pNBExLz&S5H$-pi8$Lx`%_B#aR_ ztly+!{dEl>g9rx9_7#sUbEwqn0R~ezCzJR9mJ7gAdRtwv03M|7z}~zlofx=n!C3}} z(}UENmQ`wnqQfa1BIu5{^22M+Afu}_O_MPPLa0@97?am(NeVXJev+;_<7=+g!3~ga zoT~*yB1dcC@=C1>sA^2mBrW|QWsp#(w8Ee7{yBqqLHjcXu^Kao^_W2nRbUV=0&Td) z&EE?<#zZjK63=a$Elj1_Oc-bmI`KSojb=mKweV-Z0OOgb;6NPD6O$ zhadhCKuiJv(Qp6 V(7TPB!4so;-E42P?mOm9nTZ?2j)E&|dRLRJBX-#7I`qxpP* zB`?O!M~IG-%EX23&PxV5@CEg=YPfyYiC~=!4(_SgVsDPP&G4(imWT%8gK9y zAbDRen`9e30x0s-0S}zrQQPewHPUClLdZAi5j9E`@rysuhm!{!Xs<=tw4cPXdk%OQ zT0rlMKUtZaj(m)IaqP%5GC3tAFj&2I*`3!2`FmG}vMHd12Ol1+{S41*ob<-uTXs6v zlOy2cP0!HL_YksIJF4-2qAKB(uBG&cB;e=ZaZPQ>>q!B}mOTV`u+%mfrOwPoha5c0 z)JFLd(CF+q$HfkfFLuiQdUaOzSJPMD;4^)_>^re?-AO9Gb#@e?h{|zL$>n^_Pe?gr> z_KD;FgAGNdA&Byek8pe|NBBgxnv73u9_ACpGComyK*lF#%J@WEl1~(oe4<#!CmtN( z6U#hheBz#AKCw#9Cx%qW_`bebGCr~HZ+xPOj8CK%$ZYP&_(Tfui4&QW9)`~35E&6! zq7OtQ1;huNA0i^M82Cic6MV>}0>tM*gPfe%@R>^Iz-M3;H_UY6=^?7IGg=~$<*e9j z?F6V`O{4^6oFFF1I+d@7Ajfs zf^iT1!(oG0vaIo`KQ^A+h z-pI5RS$RlH(O(8I3M*s)<3<_4DE37FWAKj%gp@&xbqHFlL(rllf))c{nq~jF5EM${ z6r{CiKMrw=@$d%-n*VZ(|K%3{%Po@J;xC)N`Y*TmUv80!xka=f{>v@?ms|WVxASs5OCNlS1L0cY_mRP03&MRRWfCCV_q-F8FMPDS7k&cs$L(!*56ypKgUH_lx86*2s~U<&3C}*2n_69;*P4hDp5^se49=f^Z$A*I z&kw)2opmJlAR>PouT*fe%5ut)(CDE!`^seBL39UoR?HYb*p;6r1@ox3x>LS&txbsh z-R69jd#l5x0}<%UPP-T6%w2<6d~wOK-MJqk!x0mGrYo$%V?wPrQX`%d8=PyGk-Y`T zUrB1t)riPhn@fQFWw!U(HKh!O1UNu&2e&_RL50u+g^uhp5@!D{Om2qI5zgrM+<+&q z3{mJvcFyEgC2eX&17M!ibfi7MrJIX{MthDuY3z+FJc(HJ{Eu%RKD!!+LPw$^r><)G z>qxyX@U747yjnzUk(W{E$cG2_5&{ytFCwAQHRlfW#V0R7q_+cteqN*0dfX(J+>vIy=y|GDRecp`(lot$FQdAj-FRG zUl!UugIT5O$7dWH9QwTkW;N=r&(+*{^UocFX}ZvKvB>R2zxkE1+f;m)x40bZ_pd`~ z?3TwDxD6G1qBM5A)|BF~jE1xQU>?OXD@@tkC`)74bE0LwLG8OkXjTnh>@}*o7LIss z>D}a8!7gS^-Y`vV?ftclhJvGLR_&d0cNtvoJq)wTF-S50veDij%_{4KbWE9LYc$dj zZceMbzRNaiKblqZfwDh)ORbinS@n3IS#qu;>k*pOq87ba&n=ht-v^Cp;%e-%QZx~~<QC_l59$&n9$x+%r8=lCtH=OIS1a_pVIbU3GU5(c;GfORi*8bw5C0 z{Df!jM$W$Cs#e$ME6@rzY+dya)6A#*?QHe0i@4nIB^G zn^lJ$u|=(Wb`_e2r;BfM0qv$MVJF~^iL^Am7tjz6I{_u^U0r0`p7<7I z-uXH~FuL%Xxm< zHUsUE%exBAOaACVcT-f^o6UuRhQd-XfBq_U@aon24%yuVfYz!vrdMDQNYpbD3zM6h zQBug72IDcYS31&BbVP(nOJU>uMwyU**7yCh&L@9#M@b{84m-&C$%s5mnO%QFKIDIlArjb@zc{5wU+_k;6 zG&L(H9A$5?N_SsX(HPf?_Qm!y>Eb_>nNhJv!x}P!?x{^5E!_D%tQqRELTk06qgT z6|#*Z_t5Q+W~XHvhot01`{2mN(OG7JP6``GC8cJa#(x?M@gIgv{O_6#X znyjaC3RgPAWzNHuxB?F4#A?+aDz%y=-a;D~LZOe!tdpzPv+T)Nh1mj0vkI;t7H|lB z7S3pzLNBN5sRZG(6myy5a3zY6!52=}7FQAnb2>v;t4397Nl7xEV^2#2zBm*FP8Gs4vQJUp0*rha!N`?bB)NQJ z8`mDdk++Ckt{s6d3ECUI9?6fZqD_@eeP66+bhFYxp0^}&K z!6%h!z>sS;PqLM1vD$zZWke^O_U0F06BNpq)|WMN_l!J@_>!k+lVpRcTTOZS*Cd%OU&Lv600DEUle5 zz{5vB6RC>r%`=R$ZtUyN$H zZ9*(OeA3zr3es@jPIZ$l!a1-AW3n~osLUM)i}0c@sfJqP!u!2i9=XB8Qm9wBOc?E5 zAn1GXXEfSoA$YXC9}tA`Wj{B}*v2pfT3YELc(kJ)5G-SIw0YfU)$cBY(H>|2dGLnr zzJDIP!42aR0BqKYdG zY-j4>N|du4?<>EIgmCc9X<_6&MV#fUtD2H+v(9peB(iO)GffhO+uotCvKY4=vx7Mf zSMol~;XcQeHqzvQ9I0DJq9HtM0v4)S-XS6GKc}gD`xC;YE~82L7l9n-8Gwy=ulQRa zN6iE>ki$E0iP{|8WdF304NOYwIY?40Z4=g&jUn>^NQVB0Ko03H=GPz$XMxtNLLi4I zUSa|3=Q_ykDh6^;18Fy5*;Z=7;!_Od_)EX0?3Z=&?Gp!b@YQKp&^JoJnZBO)H7s5! zs-G$`KrwE%aApZ84POy#rtKzDC9=pP5Fj2EL@P6Zee^XNEe|QglJJqdRSahV?ImS} zVk`&8>;V}=XGxP}Q#B_@D2_*sB1RirxnoK$TOnv^Zmp?3W|7msWEO2F=VT!)1yl_>#>@( zPJLL@H{lG%mSREEm!Sm=m^dD3`kvcv3L3mWtWeP9?H<Rfxo23@_mRRe0Ow?y zzV8)|DfbU+`i}crp}m1&P2X7x$CL*ZG<}<^k*4qdm2iv3B28c0D6HvAS*$#9Q<}4q z9ij$p*;0|1djYsDaT;)2G0OJ9eGLbMEVI$`KXD0QEELi5hVlcTTqv-^O~6pGAJ+7R zxQ%wC>02%S55qZnP7H^0n5FzJoa4>~GMppm=06GNi20dVMKZm6XE>gIN#k6s0oZOV z+t_6G+G#TmtVkKzo%x1RSJ-HV1J&AGS^q;s2XB`Y!dZoz&qExM$(e<)&1P#GR9c5F zG*$Wb&KEkG?D1Vm+V^Wu4a-cA7j5IgS}vOh=C@6vfJT*tLwWK}@yUjXP$D6L4H7G-Tfo|pKQ@DH-jh#;}W4c7Ga?!yX! zZM-SM8ynNK@E6RoUkR>h`rfbcjrgsG(m5@&{KRM+h_!hZi)&^wiRY}e-nR4qF!mjA zOaR*B+Dr&%mS}+cV3o}@Cky^F%=Zd3>78EO5 z96$oJ;zmJ1L5aAi6BPwT;eT$D=eeQp@BR1fhqn1X&t2!9bH{VeJ?Ac|vN@6BY6?aP zM>ad)IE#HPjD8)gU0kDI$B+{#*>C@MzmERz4SpR5vd#QDJ{_g}IzAoE)E}XoO+?|U zecW|a=Sjm=q6WJ2R={c1BUI;IhSRG1s2&$o^4#f6`f!yNXIcf3VChV^h7u&5Nm9uR zrZY)rt3+?<0^rIaMcivcE!k*5=yUP>NnKbiN zJlqdI*|BivP4Oih=}e-#JSm-tw?L)AnReK*aOX|;C3?`Aq<498I+JFBiii8*XFC?| zyg9x^UpkZgE>E-W812`Qc#QJvNc8Y%K1u~#&Ah8W!lLu!+~tvUMf!y*Dei~ecB~?r zODlW{E}g0DE>A#bk}gu|ai+a?tP(oY8egIlovGn2PeNzXFH%WyKjhf4{-QG#`4Zje zOk_1rMrV?KuhPFg+TO3D!R{R#A1)mN%ZRyON6E0NQ*h*CCn=i<<=HH~MvmhEjeV=u0F`>u#EHs;bTWItuLC`E?YUA2e4LnjbV54Kq7vPTV&;Xf9f1cFzYfjq_I@2Z@@T&f*?Doky}dU&qVCupiC(hhK+$lNg1IK8qZ9^E5SvaMAKRC|tDi zv?gixPdHpO`AAv|enIPj|0%P}XqOE6W>q4bZU2Zu+r~xEXWK1?*tbUK@tcQ%a5h-Z zq5VC|gv42J<$VmosdHLG`+Gy}WTLyW0(xM=+p zi%ix}Xe)>cWx5ohaM9uR5H7k1g^La!*=)&lF?cLN9AF`DE?ESQU+S~aE@#dPXqPcu z^c`PqO|&s$aD>Zak&8#(Fp3N0-S*+4yIiHYASn{(EE{{NhUP*)7xtIF4%cWdsEfp9 z%P21BjzPHS;c!-5R|s4P^vj^bMN79S=8ZkYeu}v8M5DIYe^WeL!PiC#<+dBs##?* zhT3M7KIA=Ssj0Hd9`Z<0jHzsCs!GZ@OKop@*t1xVsRpeP5?ne{!$Y1RoXVt6Rq-}b z`^loACTY;xl@28*@cbwjRZ*pgIL7)@1hY&&ihjIOlnM@~5F@$U8yAWLT`s z^m+oD=Z5pZz+evLNTjv9GZc0;MklnIq3#o8bg616m3s4FsF3>kc&KQCnog7hi56N{ zfEPpva_ImM)nJaT>q6^~CV|uH_gGYbm`ge!7Y+q2=Bfny&}Di2;8D;TS!x_D=1={k z`7j56Kw|FtfD-e;(r0qVm=8EKzsMAiT5A{Jt4*eyPSkcQ_A4tvqV*_0vqv40=oqNk zx(-RS)<1#6>|v1T)kvb_$vTkes(ffd24@W|(e=ZK)bqRgXn>fp8Jt_{RIX?$Mf`M_ zH0{z>aINH*|3+cxcdVtu(C4in)q6wE52IX@m|h(5-`2O+j-iwZr7Oscz0~eobV@f( z*9@Q5@-+~KUee(sd=(fjh!I{`sMnTzDNs1%fKl#^qIatVUYQpe>@HRFG-u*X5u+Ps zoBYcb<`hB1=#%LR_>BMHWL0jL^Ya>A49P7!sDWz`jlLG@l>#uWk>pLYXS-Jldc!lv zTV~HPR|*91%%$AyS?)@~NO;z(!t9xLr63rdx!*Q>c6Fs7lnQ~w9K{h=c5(U%waL>N z0w0)rcgSuqfBZozm?%fCE3yu0qa2*;_d>!!C}E&~J?#dPb3-Rs`(oHc4N9o)K!5|s zOK{)_?EIWpA9#8U5BYY0iwC@X!}-E{S?r}iJ{<{@a{%&GxCDr>ER*9DB zE>Joo?JAOZ?`zT_k>z5+)oPOtNim7T9K4LuAq{b5|4C;mTSYq4nIz?68O}7B!}MOZ z!~HNu;JuI*rw8vE4bwxltg1Ic7(~58wx=DP};lg0Q+i zN{@FPHW3>8ujSmvWzSsc&G_%!4r}rRD;o4tn z6D@iP{v2unDxEywR$wl>1RQHPH#<2`+|Jl2t{^uuO&`t9yW>IqWkp{<_c)d$4mnfZ zULQ-cDS_0@;V}PsF{W-V$IBVSlS(g-BMFR>;z(PJ(&9)Gqs?)o14imN(ix-eaioi< zbLjSn7hdxOAB2J%aSoSwEjDEm*&yIHIh$2tE97U$L&wr(VEzf)(z&gel@4H*__sf zT60>{%xUHKFrzgs!HiZxojI*PnA6&)hZ(II31+m)>da}~XijVY9%i&=CQ!6y4z5vK zC4P6+yM{f5{i{`?Q@F+@pLAA0_jd?Qa=*s;7nUNOw7I!J=l}!ObrR7kj_;0NZf!k zrE{1S@hIM%tebI)WLTZ=QMLdix@-Xh6uTpS>`WU~PH^J7g2FepE12err_=|T7H?_8 z3+5b{>=m=Ez=V?1c_dh#4Qf1R{Tcx`{%0{(JAWym3??HuN2s-5d!6YKn)5SS9@nCx z1E)5_!Xw(e5b9_94~BqO%^nR&Wu1}bL^l=Uxs**I{2_$f?u8^Y@K>gzsex&_JWh^8 z(-q&qS+SZVG;sN%aow+4Z8in&fhjN)2C>nxTD+I@hf~!9YHrxeQApQP<8*Bgw(1o% zP8(K}P=%&&8Sws$N|tM>mZw z;!qFvBs$afHKY%nN%~x@$C-|Dbl=dKGJCM`g`Amd;X=-|g}g3;X(dr`V}?=ZslO+N zC}0<6!L@oZhA3bUY-v?r+6SKS8n1t?Ev<#m_`nlg6S+6o(&nVh54_-O^oo`DxAv%Z zqy;GArwuk`9OI)n5o8Yj12QLbTwszprgPf*oP9c3 zpk^NvOc=Y9XP68y46FG)8&MT64OJ{sP*tRfZ*Z=>T~vneI$TQGa(I$_gwu_N^zEIxJsdh%Th2R0q80QoH4?mfvI8(0(32gA(ZHVUAdCj}WRJl} z(33qDqrpAdoOEhsx7yId{W7&!UgflOWq+o}YuP%Iz^Gy!X^T^PT)DscM1!?+DZMFvE>j&kaGYT+YWR?bMNd5dK~1m zGLNE-cBWw*IK|Q9Aa{fG7>onaF*FWpEa#ZU0i2}hSyr@TiJ82wr5cq}%N$5qVCv)5 zHeC)a(ZCQphKg1@tpRUu@+7KXF$b)HtjRVB{KtPHYhVsEn6t`BIR_eex*L4ASj=sN z21-%HxrO@IVqwW*_%wk;TeMDxPC9j|Urg`Jf9=68wqk#_v=G(D3kF8z{O2~0GBa3I zbYP9P+kEV{lzwu%gL1T8(z)9j`pKOREQuX`&!! z7zIlRD~u*c2m+%?5`v4-6bWIAkwik+V>C@dkQjwZh)x(ymk!65~4RoQ4*prMoT1wJ4WA22oH>6B!m}6G6^wo zhb^+*@Pgn*(1m?ekvsPf4R_1@dB!E4N7yCW7Pm$MKyrza2rVaBgpr<;EX9ann=HqO zXPd0SNMxI=#Yk$Ktj9=ho6HEIYtYyx^DxrdCW|oA+a^mfV%R0iG2+=JYcLYoC2KK~ z+9m5TlG`OSrs4YRl6e?u?UF?p>FtuG7%}XVbEIsToQ}FyKZZ}14WlNyhB16P&X;PggSkCO zC37hCCjRoKl9J>2Bu38T_|6#d#_@R=xsT(!VK;8A+*X6#iW17pw7mWCH35b2L}{ z#VX)enS^?x7|m6Dv7*cq%zDlg%Rw#>@K#i)gE&2pQ({cpU!Hx7YmgMSzx98?_V;m$ zvU~5HI$fqd(IQ{7f1|dRF`=PWghPO7d@Os~gEgmzJqVa$SR5r_0^wxlW)&D}QY}DN zt+)F~tIMZ-1J)(D#Pbo7;lzWKh%i*Q9ZH2D!+!~p4A(Z$GOQVDIRvU;d7~;G4%1Rq z3>Hj|IO}bXHSE%nl+HVDbOQBr)o8g#@FTRm`k?&$ktJ@RVXs4#u->BJNi$uvfkZkA zo+KR!8pQfx?6=U0=R&k7+Z)8=IrgxHVvf4y zmeD6T9m0eKc6bI%?ZzIZ)sOU;XAlD!N+KS0hZ3Vcv?WT$MyYsiw1!qFRlpM{M z(B`Bi648XrM{-JKyg~wz!jvCCE*w5+Z^1^3UtuvZeFm%5RH{H760rw{f^-iW3JK|W zC`5%uJ%gbjeTIg@gW^%tbZ*hT%L6|U!h3#bC^X4I2o2MDiGGabOkupa$oo_6E0xej zCDmYu75iDM?O|W6HWXbK=nqQmfPGMZA2oE#w}Kl?|D`O+7L#{dq%7TSwpa zXFqdsesV)cKY8DuUGG9z-V)P+!#?z9|LsChYX5k5uPzy@;A=)VoC5|XY0oD3P%;{r zO4sTiM+qj%dcyaKwC9BNWm~P_N>Y~&B8hk_XGwtz+ZaBq;R+qBTEDYs`Syp8kitsn znxtD*Jd*R$1uhbG|I1H<%}l#94Q3|w3VZh`_}(!F>GJTf9iP)Z>B(*jvp`^*xP~BN ztXvXbu?XeFrX~<&zq}(P{k|`98#6{`~-v= z^IJZ}dj7{16FBiSUY4BK> z+l`e50mRVbiAVWKrLy7;oMEnTPl;PVFI~Q@3`|MAy}@Gb7MTX`(l{GjkrwcW&67;Z zp2)J_iqhb;pTd^&r;>(?$8*It1_}wMs+~L;Q$^+j3WU8jL0JaB8+%P0wfp- zjs!?C5*-PUVI(;cATO0kez4_)Knu*wN}t_nv(i_0Fe?R2;y_esSugiap$_GWut^+& zD}D4$-_w$^EQxw2N9<&bfvMwhP=z)K)`dqwyv|+0vea2IX+Bt%rj$(PbnY7CykZ2h zEIBOx5w=AcuaIRa^w1ckhitgE7c0~4h($Kr9_*q74ka*WJ+OO(=C=S58Oxp~W?&x){clAxIf|Kd`a%!3 z5+l6_TZPd>4|Wnpbsp?wj2?NgQ!sk$!A_mY>Cz`;sVjThdsi>YGe~|sKxF-aJ#8%b zDi{U&g$Fxba@=)3rElT=96(l@4lGyc!nrxFrLmF#nz#lJ_DWaDZBWWk1-$cM$GQ?- z!b8T>k3M)f$Gfr@umuwtZt->E;vO40b|Xv{9QBES9t5Qv(#_$nkufjAY_Z;zead#= zyI#7}31vCI)*(~sASCz$i*if7n44j&IBX&YO>1@q=MsamHd6FQ(E!k@l9#-6M{S3N z8b}`!;3l5&57Mu4&l@(iK4ZA@LxX?*g|}1xuEcT@i=EzZ9SIvH@tX&<5RV zt^y-w4OfSexQ5Hr(J2)*Trozv8gPZ7-!boVbr^~7bD2f-JH>sj7$e<%t^yC0E4AZh?NYx6EuY6x(K&7-JAu}-4%N?0M?Mgdj31@iRVk{d6+NpR4G5Ng!3Y!=%i;o*#Q`x^JE8N zbl#I4gwX|0_85%vJ=tS1D)3~F!>G`+Vx_!nCADZW3Ify^kp%%7jCchBIT#5F0<;*3 z3IfV7k`x5!F_IPpG+-nv2q1S98@ZY$=Gr`!i1Rqpm~Laz*3zorR~|?eY{$_tNIMQP zO#QaL18ta+&*cl6ct;aE(1t1XT)wbLl)0e;ZJ5%`<%^nV!&J^Ux}%!MNV`xm3K^z) zGJm2BQ~Vyu^{McQCg=)&6g@-MB^3$bN}$}e8YiF7jvlTe>&70z)}~_>Oy`#;MejkR z1n-eZrH#)pi4?<&dckM_h$CfLzE+pIE#D3WYl10aV3_cd%_1@ zgCYNJ>;HhOL&;sJWTLz;ayxsj`h{4uLI-CV4qbvPSa$p%?+Y8@duV^C>AMtKVb}lE5hEBx#FL?n#owDDNccfYF7Mq%%g^lcdX9DAU*_&BrzF z)D~ccSN%Hj@j7bI2rdTjFcMu15MU&^7$Cw(dNDwPk?dlCboY9y6#2yfat!s6^4Y`0 zc)itE528=wIxThnh3`p-zdh>?JQlQ|&3fLYU9e?a278Bun7Z6cqA{vGYv_(!fz_j;i}&cq z!;t|ks-KtRFfuAl?f{1|)lk1aBCt8<1^p&F|1ale3zQ%|k{DIJiL~ zALU;4*uA0Z76r8&s81lT%K@%%_s{@WgTLD2(byupJ6Y3YajD);3w}e>oak zWcMWNo3KSTI{8;tYX|srCffAyRo;APDs7Q%ba2@XAH$#L!^d!Q(B00r{?<#w^xIg3 z2khXtd{mk4AzCt;{(deE6K?g6ng-S`tz|M?MHJj7`l138)0o#BGZRyyqm^5g9qacN zGZPczP_p4RqQ*os6O;4e2q{>tgoQoLtX7j|I`4%dLz>Nt-0o-vUUICVt!72ae@(_0 z63v{6E+jhb0EF@a`9SEG$3UnGx#vkNG91&s=OoWijz+_KwD|KOt+W-*17#AV6q8u* zptMTz!^w)T7M&axS_dqEP7zst`e_C$1ad|7nvx_rdWM+4Qy#j|Iv8Fm#F3%wkYD<^ z?*O$-%#tA9jv1_PAf~@?HZ+f4h8Z zy>U*($TTkeODULg94kf=w?@)x;iY^MBj=@jXN-7D`8%^k*d&<)eFBXb5vraDsyDH;deJ3l^CgR zIcefms7^!YKvijh6<&H{>ROODWzW7~*qY@)kg(IsUes|kZiIM8IBg-AXyW^bz)q({FxDTFag}JW4PHt z<>xr%2EOBUNV~sK1W}=9F&*Apz5qW}MaY7FsV^0IJT&F`6AP9ruCh zctMrZEK>tTwiV$Cdxno^#!(h z2Wn2_m4K$5MxqkXw9`mh0-AOj$xA@fP9seTXxeF{ElJj6q%W}y>7bT$pjsdaC$?cU zEu7enQD``^6Qk+jL)#wa|T*o)Dea6*HTG@Ll}&cx%v>Xo{69+L55 zkElawI~-_j@L{A4E_)B!;9Ktv+Tbxev9`uK2HMyYn}!IW;OfL4wHaswKd6w4q&+|Z~(KA_`3?YJVsuvDPjC6YKn)4P737XE3qK zKm)0n(^0!1kFS}`hgOs<2sdp{1gejYf;?D>|3WKq$r$^hUzHbV?#OcycY>RMI}@4= z-08CmbLYrF_WJ(LV|HQg9399`>`yG%g}L+FK=uaAodE;9x-=XH?uerhcSHl9(A?1+ z*O#poOq2`+?uZ6n&Qs@eIP*+?NBTlbP<0APl5H_GBUbE_%<}! zIw>)-nB~Okh2PC`QetEg?!!t%Wgc4>G zY>d{-Caf@8H=7_ZN}Nq_G1AT^Y^&%}ie}?mY@}JVkAgPKazDG=A;C!YOHsBZKXbx- z#N*~~&=zW?-_O#W`fyM$-w5Bm0(L>~NpCb!_p&$)GL*d?^vrR*R zwD`j*>fy{#Gm%fAyS)1*qfZ7T8mo4?mG(0?C={*;GcFE} zPH?GBmu21F%BA&l+6!bxrZhC2OW)K~dV$=X$*V}`(*93n7f4Me9r2** zOirleuo6CKJUn*83A+3WxUgD^PSDkn#mYxbIxlfPe4?=O-QNdjdb-%!yq?YwTPApB zemFv==~TsB?=BT zvsxX>H#$Of?qX)O`hJk@q@3<-&KOHvbrb8yACG}4ZV(J_>7d|z2j_-NG!JBPktu7?CzuC} zrmV(%K}BZS;UX__0qT?jUK3o{4!P_mAKIiIblXt@_Npsh%tClsH|gaj4O}hz6+V{r zlYb+=cu&-$l+9Ti%iwVvdMr6cw0ciuc*A3_JsTSa72ElHRvi^x^-kJc%6_&pp!j;H zw5wn=OWMrzKE?L)so|Xo_{0Szj>y}~fjaEoP{=z|1XlGkEbxx-?%R2L6&E7FWt7Q^ zA@$%ZGVkop4r1!Sx$cF{JWqC{CqgC@3J5!ZCE)bnj3eM1ql-D*ojC8MUn1kte%r`b-jKptj6c{Pq z*yu3Qy|H0><9y9FVvNMiHVTXs%{DrWbj>!*AvoVB8!<-WPc{mS6rXH#80kLQFn!Jh zAboN3HJK_15vlR^>|6oALV;w8*qlNDm^^WdjRGS@OFPnA+L8XH9qC`%k^ZF}>0jEB z{0>{cHU12Y4Q;@o)lt~t0za!^9*HN@?zE+hZoa1vkuq>`mZxA(0^$+U!m2o zS;FnvpnFnhBi)nWy*H(XRe0Jj@&w(}0#8s7SzIllt;w>Ds-C;=a`E-+q9K%a?83k_ z>YR-3xXnR{;vSkqGYv(14;cbS{`r4E{hh8}g(Lq}!#fU9UNs!~_Zb35{snGVugU_Y ze*{1S>nC5MmAJnPvjXz7wOq;JxNL+2}zB1%7sEUJ46ghS0juPmS&SB zbW$~B#7L78D8MKuB~XNsHYHGkQCUi$6eE2~pbVpil)xG;Q3Embj8)2nY-@~`3)x(Z zRtVWP7_AhtZ82IUWZPjBD`eYa6enbNz$jkGCNWA7n#Gx`uY)*qrP;hLKGRBSxRilL z-xL52D_I92)LiyK89F5%Ih#-GY9ubo^If*Bm>oes32|Xv=#r$Iy`O$E%|!`MsFnLC zr`_D|zk~gObXhW-63_W%7i>L{_rAy_XjaWgzrc#0)AiSdM9*aKS*hhBpOv?N zu(nIiDYuhalXZ$Y(f$b%u*XQi0hjU)X=+)Q&v+lshTe+anq_aOs!Zq3EqiFZ-H z>$16SF6K^x3+sKCq<3>McUHS7;R(f^b51+C!C`3kwjOcEx%=-lcl5Jlm*L`P-Y(#d zbN4ad*Lx+a1{*g?3uP%PI!xWmM;fqb6;SINbF{qzBk^c^9Y%`L_RJsYl)BOOVvLx< z_6m%|!S*^`QISb+tm|%12dfioEB?$?e5Qv`Z$GvYqdtCY6-Isi*hv`m^J6DtRVR^E?Iaos-__pDT;?G86e4s!>^mavN#} ze?I7OfaPj0FoUQKMdw|s25VT5JJdh10lMWkQTmu_B4#DBhH0;XHSFY;245Jf{cT{2 zNe){94K$g4>PY(tudAI(fhxBWyT#hcmP@=SPc5;lJ$Pz~_XY?UL3@h~;V-u*(Qy1LPW3gJZ?zFv1wqPPZ985Vh#$+^d=% zsgK+bUO-=ji8Bb*-&Ug>pN0`}DoN%^Gkw48T7&@f{k-qJ%=G<{E1kE2zJF798*D_h6byzUoZ9x&?2F0XX{9G1BN z4pA+GPI(wZp8us!DdbA&TQ*j^Ysn@N#M@Z7Q;J@m^=cejx@!ij0KS78QWhlT|I()( zQd+GG9jk}Oh45HHEJ!MV(^fytafDEIh0Hyg#u;SfPR1%!_A=YTt}&%QeuG|>b=;%5 zekj}WmpZiv5Mkfy)5`AY62VKQZYYI9JYW7=e5SBu2M$+^sQsb^cj)`e3Ojp ztyT+ZRn$76?J4#l7(l#(7E=zuHp2O!#Shm-e@1M!?>!t27z`O6IV0`5tXNR4gu}Yq zr#wf}JkbWunMsEq?HtO#!a}j`PI*OHp7MHWoSdq}9`?b48eZA&QW!!(#yZt@Zc6)pf5$}S_)lYU$;t7-k*OYIa7SZ(xd|z z%(%_Z-0PS`aZY`cPp+9sp^TgOyuC#?S2%C)rslHDF4uN(y*z^`ok689h1?jG-Q+uC zRCSZj!$^OV?~YOZO}+r5hMRn0>5NSGA3PTE)k~H}P@(TE2D|>6&5Ej`)J=Y_?{MUp=gZ4fHo~|(e%>N}#sbNn+t#mhlT!wOg(9O*7!oUq zYv0&~)?te@2lh+Gh3%P;EP9@6xq5#mW#&0Y*3nhg(w&;5lR1HRqZ&S4x@xD#fVZWj6~B@8J}@|(^GjE$)~4^ zFw#s`2N@9tP82uFs@ssF$rLn|Tj4sC# zYK*SN65BAk9!qS;s4SM)iP5cCA_Jp}SmHkz-H9c3V^kGO?8WF_ETO@uCYCsW(Sumx z5Jvh~;xI;au|(zuM~}?;JEb6!x2%vzisKgcb3gqLk$l-+B=vL5MUpYsTqJpO%|%i) z*IXo}bInCkKG$3%HFM2HQajgNB=vL6MUpY!TqJq(%|%i)-&`c6^UXz4J|BtXepWE} zH<~UN*YZ|zT%AC<>FNZ^MO|D}KT6NYYr?T|lXag`Zf0E5(F@J|l@LF#C>$+|S5eX0JvB(V>4WO$#UthB z_D$6tWzUduvkayai{D-O@u1xFb%NEFnG{DgaB%Sbhq0HhT2@i>q-u{VOLvu0kz4F_ z#c@mP@A7$RIAm$_U04kmN9d$Et}G^vR5`9nj4V#NvRL@tNmnIC&@z?@jh7zJ1J!!VLm^Q9PtRr8}Tl2-F&7)4d{6EKoh^W_*NRP)uB z7nzKOyg1JLqr7XDM>Rk8stBj`zr$TY|63QIEM?TkCmZy?nKrckSO1F4JF5;oYD4RP z>tB%<&hjdb+R*ynhF7HaEUo`-ex?sCdye+;D+VI{@0-Rcl>S$p8hEGxRGn{-s`HKe z`M{T+E}OZ}!^>>C9ZNgosA2m%9ZF(Q0yZ{59}5Cj6^jHc@nSq>NK$p`v2HdPk`iVk zLsISPeUzJLPP%2BkRDZT6Z5R>Epf+jH>cik?Yv+ z3`W(fi@>N_GR}P;jHCz2&sVP(8<)o@id-bV^kkcSYh-$ zo**!K8BcIAdL2*LV$={%Z2b|+GzK5M8YkIKnY_NbmdqxKW>dSYs#-pYk-nDij8T0p zpNCOHE#DoZ=32hs0bL5?AwTVaX`^cqLPf8WsiwImb-FxTZ`x6*LS2(I#{@3GmiErN7!s%_^`YJwl?BB9NUT`-mypwnhi$oN&r!MpW{hJmhQ9_gx+;T9Np?A2Y4#yUL2V*z zx`PvwK6eF!-W;&jrL^dvfu>KpxKx>THGKGNKp!WG!NrLl#9A@+6;c%k(fEjPMbTvEi>k$wy_#x78G z`yY%d&(kB>v%lLt(jExSDMYHTmSP}OhqMlv%z+((d|QjWm~Ak;i&}Y~zmeuN88m8x zkhZ0*8Rh|_@k)E8Sk%ON0iDu{I_0aX;f9wE^2ApE^eIz&K>zM&X zZQv1A1_w8X{qB}j=A|R-LVffWBb3CoPHK4U7I@rkbD5Xe8Xg}j8P2jsmLzz5M0|

v0KHuLiBq5$3cXpPI-H3^r05EDP0s^?# zI%S~(GS%~-6=1@WZd4KW@EKVaFwh}98zof-7TX;Wekl-ue!a}bGX4d!dp#~Q>0fkV z9gYe?6?s$$j$4C!2a12K%h_mgtz|Lv+%l@A%z^J^?w96{t z9p^le$_!g0cMtVa#UNoCEK8hJ7>jv)7w&PZ)RX~TUEl+%bpH&cOSk&PUxt!hM}+bH zqWC^({D&!}jMN%^vNyqDF&GJOkNIDDY96P~&A+bv=Mkh;Te?r@jzhKYY_>Lz7{l-Q z)70?N$8P@0^5$0@jjqk@T2PQ(&X-QNz9`+kayt$!@op|jp5za?r&XRCI{T zc(jk9!ixO#=0&ca4SU|)xvvQFU-3TwME7M@KG6Hk!FyUt*Vp<-|K2%i>Z8-H zZbG$zfobVApJRUPoHV!k5B^Nje@pG6r*zj$?HKu`<6ucqlJMFIG9=K;tQ%6hKq+VY zPlp_jd}oWBePX(=vV2&c7o&277w4bMg@(WHI|*Jq=x%CwLG&q?s|Fqno7$4Q+SK+p zzNUuG?LM6SurfaSITY|LH#&~=-&1LIA0#xUXT7NDe=wrXE&fLJ5YEGLznEJtkmHX5 z@sQ(1%3a8@cUCX-VnZQQ)?gB`HV6J+Jv&6O+^bn2C#vn2JWEVB{CN{ApH)4v*kLitP z)2p8vX*!#{Qe>R+K$_3JE>8g-p59ygyJ3#^|O90 zR(frP;W4MB@#u`U+E5@(>1LRT&~R6Fqig1K@xF)Xt7OoXqOsne!b_xfwdr{Gtp zt9)m-hBmfQg|-G78W%Wg`H z-t0I55=?uMziigQCoccQX>!w5s?wl?yk)0v1`DzhT!vD-f&oO^GpY>{L^$Q%s{} z#rxCueQo{sR1Yw|_K!ISuPYBf5H#n;ytP$Lt$li30PE7hPV~ax@~Zsl177p#+on3I zch>TobGuS~Qo5(o*Hr1$>YO5d`T4*5+4|q#r`NuRnRp(WiRZsK_t4YU|5|liz!20U zu&Aa@o%O6pXP5?NxcN|vYIE9x>A}>T)EvC!y?^wpZni6UNap9ZG%n`k{@Wy6F*JI+ zy^SG<->Y(8+qaYar?nVH{Y(GoKM8xos674Jt!YcOsh>ZMw+aSsx79ftb|MwI(a*cy zG9*KSy`nOI>f9~=b<0)f_Vl=px9Jf~K}~0-b{}e-9Ecm;pI92GOj{3w=8(nzX@OhO zl+W}v?gtHha&3yBQ~|##M!hZ{JHlAYw5GFD=N{31XuqgOK03cYo^-A&Uo=|%Pq8BA zRG&Gg6ZuEJhrjbvyT7oq_>^|7CGFaq#OQ5gIz^=K>++}8Z%yOaKk`vxbTQGv_;OoI z>EuV_l&IiNw}`E#Jt;KICG~&q^6K1D8_}kB(>%itdqop-8MX5LuN#SJm7o5!nGgLx zx)*=)U}|R``Qqaf?yCR!6gH9{FQMh8ed?dcbZT*odtLs~`mJe#f#iwNMNVZDp>g>w zxeMH2ML>NFi{6s|xYu1QR542k+W7p4Z68&3NV#@o&7t{u()Rmv1 zepPG-7ORFyo{#RB}|AX-_{3Bm>x^C)|$LFV3Ypi-3dD~nX{mef)y&GNlAA4F(-=2cH zCHVdem`AHw?EbZ&0X&b_^LeKwf$FB@Z_eyvJG3HVN+U8kEe!T@by`FRGvq7=mck0FGYK#(kY~j8+U1|bg8Q#qHTVLn5-z)C%8~o`N});Lkq=Bt zaouVh=_}rxnOl84$jw;G%L`MhPhv3Fg27<$2QwJ&m<)zZ3M1i7gpOp?kMa_2Ryrp! z(q2U9LP!9!{!S@pv0$kT`%*^)5%lu;QzqMCe#-3%-NZHkZ{SQOTl5hLXv@ zij`?IwzK^VwMXsdq@-FVk&FR`G+_@Tm>mi2b$gp_#qMchLcZPAkvEp%P{Px^%#jlE zpIlKD>X$ikf+3!dYEOdSj3fs9X7kt|hnt8hD>o6fdbpwLazoXk!-yzlM<_&2rq!d2 z($OZO%AT2sl8-{gW<4_!#r@!dYW(ulKvdN=Ow<8CRC3xCL&=(*#mf5S-4=#)dMrRa z!|Mk`G1x*wnw(2O6oWO{-e%;rW10PxhXynH45ssK~2l3^H>(8S!QL052YcFIj4$V%`7~T~&*bC=MUs zB}`%xX}yj>EhJ;$055mBm-5;2uvpV;+!9A%+>*#O)1d^*!30xqqDW~d z^YOQ+5b}7DlI5JF`OZM~IuAs3)i)7LCc)gpwb_b>TNJ=Z)Lo)qxG)%b-$dy6wjz0x(1sWV^kgG z{0EVcbjJaP3>v4I2Xgd=K@{qMn6BS%V7ev{wMS;!q*l z6Dmtf0il=!O%f3+h9j!w?K?cZMI%S}?i`x~7Z+!Y8Lu`@T*K@gNfw#69Ho(OF_5st z#f!&cmQ_o|D-6OTKLa%qa*&Esc2J`fU;rUMY3E2Xc}dYQ-7#|Dx19X)JsL`yg(}N5 z!$f_yo0kwxGK}w3XuKThW*9GHha#eQhYdst3>oXAOajo=&Lk39QK;B1c8)GZ40h0u z2T+ZSHd{vrEf1p1dxgoRyrd%0Hn)+6l5@HmN}e_(j2tzk*lU8)$R_;)AWE(uXh;*= z6^IgLT#+h{%;++7j5>FLixs;_nr!q%aMPBhR?ihnVCv&9+{FdBj?H* zpG8h8%P|Z>HPFgtBZEu{hc+;Iwt3E8?ktgW%-}EyTP~=I%+U~xh|E8w;*N|W(I%Y| zL@2`1n9<9aG0d1T$Cz=BF=L)FW0WzYcB3(47h^^bW5!Tp#!O?zDaMQ$#*FT!8ci7; z4D1{2j40*lq%fHWd>#P5MSmd4{NoHPM!GwQYZdp<95Qp32pYfCCXGSPk4A~JP-6LM zlsF3|j<}5yXQISOp(wFr5lXyniLP95u@{>$*g1I{Z}@2qOc$j`tzj`LOX4$v78a3= z-rf#w12!CYu;r{0utawA*iYUREK1zEay6&ED*y{t;Z>~6tUtN#cT{&;^sV{=D zlFFq6l~2F4`HkM(vF60?cfGEh8h8(quibQ`uJ+N2Nlj3CcGIXY=N7F^Z552=T|a00 z(PQLqS&&@)$I;^NJ)RAE4ao~#${!9&Uu)BP#hH0&>nr!v{KVsJg0X^~3#{+7^!WH1 zlFNOL?i_LceeYIC{<8Y#g%jKI5`a{3{}XvN9|zn045Z2qE{@uDiu+;#w0-mN>dVbx z#ZPv1n8$v2eBa8%=Nsyu6+7!}&$Kn48ktt~*?X+)ufat1sHlwv&@;J$4o?@2DtLPy zl8=~p>+zMnIal5kJB!y8JlpzPOouaHyvNFG7Y*23c(htr2&J#x_aN*3Ilp_5{8QAe zLvsrM7|;Zz?|FaN{&K5uW2^UA(Y|#q&-*Q@JqpR2SIs!mw)pdq*N{9a{3pTJwx(+> zkbLfQx7k~lWc=3V4aac3vu^y>^6nWbeg7a;pX<9~zCv=%(Sqj#+umjXsfs1FeoHIQ z{B{6Hm86WG@@W6AgVE4-_1m-ep0(A!+BIk%d$CL9k}Ksa?;@_QesFfv&b-1R;41Iu zLp|>Nk)E0lJ+r&w^c@UhXsS3~*JxFZ02{!T0wL zgMPp8_S#<$UCs*;S7W}vx77LY3VX!W*9D^o6!1$CS2aH^SbLbi>%O=fvuZl zUQl}P6C2k~I+#6bBQR&+yC>I*3ac*ihm6f>_1#dIm039&N^gAh<=lIZH>+(RxqHE_ zM;B_Bzib1+`r*OwAC!;3zk*nwzH05cCD}V)Lg_ir&KYZh#+{ZXht&m)<-}0B=l>ZF)r2|Ncr|S33*WKP+v2cZ6()0`Hiu!2l zardKVHe@_+^JWRgijM2Si2r5ZbC4R%;^r3p)boQbL)No<{?Pw@q;=9q7~E?L*4|LA zB3ix*#>$TUQF>g~{od(9=%ppuchKKwNN(!&v^L=DoZ5zKwJ%>kZ4r!>US5@V@6g-_ zzL30hbgCbk4N2RMku z>W7ll_YN1QxQ-aZ?IU3?kPyF0Y%2tqy->7zUCWvLnKypbOBFHzO)zN-v-2`nuRz;d=PfgMwoo z=U`-OKM!oV+_LA*6G*Ndy)WfmX-XU#*?zA7nSc4`_5P5Y`NM0!sP_pYx5CJlA3n6~ zz~X{aXCZl7mc5}{(8z9C_IW{7%9>fIF@8pQ{2kb0F{l0z-g^VdcMDD&KX>U3zYvl){cy^2!<92t@1X5p+djS6 zH{e1YjO?6(lS@;F{qQUbI`XBkE6jP%eyf4xLlO&@T~MuY{|u$CZ*9yvE?j*CMz;1* z!GC`Gf9$<|JXG8JKRnL{gQ<{^wn_gf7SoulyYwQ4{~X3LW^>HELAP7>92TfXtU*u4Rs z%&{|7?|kAihM$JkJ4c+0ocsQlyKg{XPG7gnl5JR82J*$ruus1HZe&U)6i+@IUH@g_ zzVKU6e4N$$qLJ@2*R_H4cDxMV_v}p4QIKBcoeS3TmU)RQe3ScBt%`mZ66~1Wy5?$2Nd7DW4}6n>&sA> znxs?jIv?5|aqt0jmHXzxsGg`%89&22-hO;@mdh zE9!Nppe*I+l;l2-#cxY;7p3wJnyeDiDWm&b51V$B*KdXOc{hXqbC4tfA-U#4F zP*+*uIq4)o5oZ<$)VKb9c5ShmL5(ALZv`pOCy_n!GH0W@c@ZiS@HXcnpgKZ8_Q=QF z0s*ZNDiRQxV+i0z3dkOXo7*FR7pWqF#pXN&2qOh#k0Q-I5Fn0JkwBC=zi0-YOYs*T z1ab{_%e5)xTA}`X~IUSO3=Z z>Ywx8UgBMFQmJ! z6G_o>nN!fcNJ13|5EL_~As{JKi2zkG^G5`vg{l!CxWEiVKzb-i$V$DrcLMXJEu;x# ztRo^Esq2{)EdOx~I{&REsF(H|3U90*DM~%&zZgYV#~VJGVBdEH5g{Y0r+r0%jjwN> z;!gx8Ur8#&@=7uY23G2(!5e^m1)Nx^K{ofd%4euHG&7B6GB)z{J8bP3gxn~8@dVOF zH5=sShalpl@|SUNvPp9?LKmY-wP;^}^YD(Rf%~SpnKm$#Y9Xn|Px{~|XQ7ix3)A7G zFMjf&?xfVB<>^Y%eNw`n=p>99GoA=6&1KHy4+9s`G<~=m(Od|mRu}>kz+wm~RZQI@ zSs(^7?oS7LNW7oKH`zQ=sDY~`idbaH}8gYmUIX4kfkZ`kQssSkiA#n zaqTOAM-O?Ob9JAiYVsi;HeBOx?};8V)P2oJgSH+aa+7m0_g+Ik>A z-L4fL@{wSVsoPPgv~U^9@Rzp6w{JVaNg{p1l%*2FeR=^80n}P%0s;ya2oWHyWhNrv z;sUYcCPAy2JA%mSC#aK7X&pkm87up2! zUj(K*w7EsCnr7<+B2Kx;Pd}oZpi0GtF=jqb0Z)CmelWs=EFeBv!~>qx!;RpLTB*lc zxFjkAn4|CD?zeq>YWb6q)jJ_{b9wso9Dro(K9{iyrKwX3A znJh{h?IRLF?Bfpt`)PdtRS=IX>@px@lqEeH>eIf?WGu71#|n>czE#{S=GmX8$;>zX zLa%dEc^cl+d%e#Ic^1c@Zb(V5a=5PplsHYf1L<5})q^I)! zoFq<6q`n&OSv_TioF?A@{l_*!P?oK66ZBt4Tc#4F_5Wk9U78$X7n0}D||^-hFMy z`+6LYJ5+}iAK`pCp1!1=E@a>2e6e$QyoZn-?5b!H4Lu+Gs~8tOAEl{WL=Imthei&c z(1mS)Dg|7{c&HNJOs*y?&E(L?;hW>a?u#l#UB!f`lG02rB8P7_hei&c$c1f*DkWXT z#HbSAT&}*i#{k=F7r#{(p(EeJsRI!e4+0l|i!DN6$J9nohPz527EHYnVrarl{Rw(H zUiuCGCJ5VUr$V&M42)pPyCSip`xdt^B`Fevi7ya)+zjGRO6iPD#U}WSWa1aIX9lw& zp3<8JFIO<0!RyNq<`Gj3KO!7@C}wbffXKJzqfj(TFGP1=d`_I}oGteU%_6>b@>Lv~ zU^5|kH7;x!!GItdD3@Rk?8UklOEHkXPd_qJccdIB=cn%_Ub9~~vam>^4Ydn9X)p0^ zCSAwuQ6;{OTuD}1%aKK(O3g0m1N_m^YZx`uMBa5wxR|)E+D0z=jkqp-1B@D~^btl4 zRjR&@iBTodAi0{Xw21@qN!)CQ3;UQ@_k3O2Z}6=hS_k3LliiaEjb#!KB9)SU*Jnwl zzrd3_t)75FC({>mlZ8Q^ZiF&Uoywx>>laQyQ6ia(xx3QP+`n!*Gwc{>^!f*3K6l6j zo)nRWG*2A+z4|oG>%p zN0K7{i_-*w;F2LQ(jw+UB!SVR6-S>ZN>V@qpPS7?5?Fj&nF3+9ULZG#v&Do2DzUM2 zZ)m&)Z+s!u9DRZR;_NOxO6EsQhhSQ5yj$CP;t^GuWQkHiEFa@{ki?vei{)lqEa&24 zxq%}RByG_omNUl_%b79e;v}ICaGj_la2+~6@9^jgI6T*Vc9~mXLQpV0)3+Er2_V{U zz=wcc3PaGM8G@Np(GXNG;a>S|#z`~;kl0;nt^dbVGz97;+^XM*Ay90d#>|Bw5DbJN zkdLfGLm>QVdUI9ZI(!Ho%YqEuybx8VWFTEYe77j@8a!E(Zh8ijtV__ z1|z6}2(aXFP+}>io;*TV2;Mg|gFFpnxmzD(Y!lU*kY)Kmqp0WnEHdc# zTA1#3D8_UY0U2L`^>@HVfNm6E0QJiS))4edD}*cats!b%Fp)0imOsa%)-AR|)OrU~ zU=6qM5Pu*m?Y9!)k0K|zIO^nF*-GdPV{ez@no&QViY8urbPz0RppS#mtmkv4+7 zhBTcu0T#JhgGKUE?y3bV6enQ$vv~^`K)jAkAfgvVC#b(Y{nJU{1p$jRFYKjILZI6TJjtYNvACdN^1oX>a%U}ki+ zF>PE`--{J^jN@jk7)p9WN%aH0ayWejKkc_3Q&x!O+m4%o0n?<78?sd2Wg5=}9xrAx z$(u8_DZrsmKIfCrL~y}k$_*#Mc>LzZTeR^Y^<0IK%a?FG+nPx`*V@s&#YwtZM+aq* zw+xGz9APv|ugsCm~uL>=(F0 zca#*(<&1YHR;pP7{-+uuB=i(Za4(p6Ajd5klTfAfXs$J?B+e2@$Vyo(!F^QelqF*- zs+1YcbwZW$vII)9QZ`HQ0986;$(W8R6-RTuQKi-_0p)WD&V;GFDwa0|Uv;AUewrZs zJ#Zq^$Bekq0?YQ0YSdVvC4=ni?%mhi-hI8)yRYrN z``Vh_yRXH)`}(|hUz3(WUuROmV7lC;PRR0^S>@VZ;S1E4K4xa^A;j<BR~gAmcj%aknFFUnZsMpO9xr+a7QG8(kTRET3RFq zL~~m6!O;Q9eoPPiqmTmi=~H(3gKewT5^P($t^mtrxy^+ zB@XD)l;*>Aw;;7>j?FXpGvYL&V0W@&_@xprQ$Y-uK}1wbZ?UL?5H4YKqE@kfF?;-{ z?%fhpv81Begvh#DG2{z~HmYK{+%?1t$`{mw&812CXIvjbx+y-Motu zT}9Pm{L#cLs~%53Neg=LJduAJtSMBxpuvPSMOcBAXU~=G!Y9jJ!s)V4uSFr%kvNs* z*?B_R!6{^ltLo*vWkAw*8pgk3Y9=8+B2k&(j1gpC18Ww?@oYcBa7FEA6@hTpfrz*k zGo~To?67Kj0*0$)Ae`#HBav{nF3*F;1v$8IPP<4n9xg0Xtt@#+np|CxaV&yZ)wu7ZWJ`aQaH=54dT68ruan57HOFeN{1j zg!|Gm5By^(Iyq(m?dU+ z8WmDNp?rpy3)p8prO&T`eYS~_*e{>4Oeb(6?oH}r5OWd86~{+~#9U>s7d{qpS`CD| zY4;D*U5LUuZk_fNzQJp7G4JP#K7j2j_=aQGkoH;3oVcBQOEuc2I?cKuRQ&}KEyn9p z%)@YreuGPN7%tIyHeCr6bB#pL!6iD79R?D8aS_}?GIK86f>xqWSS8S_fGFZBEUl86 zOOqxtJ%A_xiC-Z_8u~WlOmMdQ-WR)UP%1RIR zrs>2whWAKxO7tEUF6>g+RBNGa1t-&U3eRyC<_E$lG6SKD z@aP31Qm1$3Si?ox*6W#zwK=MKos$O{Cpr+TPTt>Gb4kNB7a-D7 z>*1NHJq!z=e_BqrAWzKo>3%!@c33uh7+!HQ;Jla)H}3*BzuR_X4ZPY~b8Xb$BGI_} z`|X5y!2mAh#w%W3D})ZWje0>jnKh_aSh*jD$ya@q^wGC|M6Bse?9g z>ws#~8BUwt*LyxP(dr3LG(LQ{t#i|0`%nGCck^+`)){<7<0|{C84@|*h0IyDtwP-b z$**uI%$ak6u(U+I5*b4(aK5oTeD@LOmcjM{W5~N?jPfqE{|yJ2F_CwPE;gpwq9eK3 zSVkVvR~XX``K7V9?@4B)9grWm9dOGqOUfl z4I{DD#xnAVzQ&j~oW#}`%g7`8FUB-`V>xd6odZe6r{Z0I?%X=qe%kk4Pk;G+*Vld5 z^>yEMef@V`U;ka#H+w4OEUH=Vr zT?k!wORIovW|fAQ*i(&~cu5BEk1(A4ZVcA$1|A``{8t#4JPLS)S&&D_)PdWCtAN8(o)pFE0xg$0mDl2@3JJW6_n%^{D{UST5gDE$={MII?$VPf(q z^A#3P9_78lBn`Yw4AMh=(uqGW{Vj|!F8*|Qotw8`B29MV6cwUk+w=+wSROiLqmYTO z*0_ZSR28qESu;+u^*y-d^hZlPvdDP{_ zK0qFId$A9ZNBnW@jFrSkjr(3{?ZCducMu(8&)^Wo9_23F*dw_F#-1-4V^4_%BDzX@ zA`op81YNX1L|1uF1mbN{;bjX%bXEOCAlW7*ESRn3(qO753sjHHaGdDssK64Le6m>g z-oZ?J4WDTf2V2&x*E=ek#lf>dm7CG+rRMZr*`O>*>Xy~7M-gNtmyjt59&S9w=n2a#1?$3-@X?oNm--Nxh#eu{9CWs^h8Yx-IL3pyoE=~PEMLQrH6!$h7 zhDLG-ADZ?u;6?1w42}H&G&JI80)DD=hQk0fG-=NS!c?iyX#g6Uyk`PYDjJ&V>#88> z5BSiiBk-Zg>tjU>4gE4J#Q+qiv!L0r^Nw6(ofV0B1QTAp7Z(`iaU-};-p2%XFX$0}Ds#VLNnpFzAJv%F!S1cpy% z#Fku@^T{K{Rk?^fQeKry$RqVtxq>{RT$3w3ZV;_t*W}nX;v*F4sS^~XQH&>dYV&o; zWGw-p@hVTQEUU>4v>N4j{40~T<1d^@3&!cg;)-A4o5cFfh?lhEroSr`;rrl46g!@v z*4BA~u88lnDUFn#_ESA>poS}JMIHSI`gs&+ygsGi*nFCN_#ryjaa| zFpDu*kPDk733BnU=ko01dCXcZb$NnGtx{w=^~cjpOHf14E2p5dN=PVVl>$v{nWU;?|X7%JCXG<>O4_N(PDuQL_yRf7?EMQa@rOu`GWTtcC+_77K zV#DZtVAE5(V#8QRZ6Jk5+UIDlL?O#FI!8Ne6t^4mftRS;Onua-9e;TjQGMP?eBl(= zT-|5_ryBU$-x;b1s?S4-Z=w~Wzz8f40-uf{#Ya>Q@yE*D_{2Jo@`8SZ?W``=z91C& zLz646@UqpoRvfA{;kk5b{&^12oSuT0ft7A$-vgTD2t z>e~cE5~1F@a?o*U9Ec~i48-%6Vo+xv(}k==)4ob4)h4hE$TkIhbD5?b^yRN56xm!^ zdnT_v;~J)b=V=kOw9;Ei|IB%<@|Ld`Z~e;c7$O&A6rAU8xTPkron&VN=7!kX*@Ief ziV-F^U<*zxbBGJ%HN$oQc}?tAeIT#dorIIu+~*+jn(=QRnCs67Fl~IaJBQc4YWY6y z7Jrw^u_`@zU|S*^=dl^a-W2;|-X4j2q9JE6_Mh9#udH za#+JpVK)YlN7sN!ui)VRZn(`cac?*}60!ThfN{zMrvYS+)le7KLaEr=VrckFoEj$$ z?6PD1Z*c7_f20pLQQ`uZTJU~Si&9Ig_@*Qlo{AGeQ571hCYo)yG;0kvK2!wbO4 zcku99<-g=N32G}{V$Qhj{;%j6+!rkp^S`IQ`CNZ^>%V8SXio|o+>2P}jFr8H{-tKs zrrH0V=B{DL+T|}zCds@D14FyzzhQxIAAMnm4Xq6BJWn0V-4y$5*n0}W8W;5PeBl0W z`Z>xEJc>_j8n=?hQ&HLrXZm4=N(NnB%L1y!&kI6J1q6)+aG_Y>6f)mYC9aFDFaoN> z)xcfWXDo0!)vSK|XjYxXAIK^Tm+xOe@B1xa!K3cjDqW!qcDh21iMeVW)ltC4&VnPp zoTR`-zi8MhzH8q-)vBjL(*F6>oc~IJ5#GwuGQ|LS#ZQ*4|BfjJ7-cpt8i-Ro4KIW@ zag}q-zoQKC34dkcFKhqw$G>6aY%5_pIB7_r)J1`z9yG+XwsCYF!RGlLN9I{89i>me zjztNG|LcFCMXQ|4`nNO|Kp|5zCbrZBz9Jf0zM7(U4^9sFZ#)|Tv5$s_;~#|Y$TzMI zCdlprpFR0s@_KZ{USe<8k>5OGZOHf?-E#TkM$ zAB5oBZRh_YWTY~s6KFq&>X?v@5HCf?rl0;T^NUYLZDbt5RrLAN{t_h60gImP`fmd5 zv$O3)w+OyYodp9O-E?$vzzKNO%35>v#|3^mB9OY`MT25Bq+uF~3aRag4Fe(lj#)6{ z!K7~dWq4n-DD)ftGF*CgIkwRR-k8t8a5%z8M+y^g@V!s$3*G6W-&}cGrd79h4H$4PK#xevRv~HD)ACoa{}EAb8I(FbLjw_)b{_a2B;3;*G+ZBl zYseXf4AQa_NS+aV1)p&022s77xs^wu&Zu0EQzEGAG_(n-Vm(81%z}M{m)0cwjJM1` zq;1*L#GK!Yb07)lf{Y)`9XP$V1Af~Ge7Wd>6jlDLt$X9xT#@&5CP8v!TgY^ zVRgW^YB?o1|AqO1t6_eSi1Di{H*2T^x~!a9-Lf-Iuy>wMu8{C8lFB{O{O6`$Tbol% zFf(*5({eRXHUo8oAEjWKhEgNokq9Rb$ztKO71K2g50osGmU85yn~to5dlDz>P`hc# zIS3eGsGzxDv3~zIZh>U4yY0+ zwn-J5o43~7{C%F5B}QVw((%Tu;(?g)jai1qVH8@15_lD0^bJ|KT*0P@xf7qq^$Wvp zP&eb<*|Z#IC6cT+#7>5p+X=KpF2dzhTZ-z{NFYwJzl5_Z94%4;F%eJ$O)*dKd=V=E z)(Gx$L*SQS#fJ)kli^~m3@8^=Pc>Wr>U|^Q{{lR zN}ahipU3t;`;(zDX2Hnd=QH)V9NLVU$1)!ZEetbgctU zxl6k@k1rn`Aa}*+qs8(u@p5+|eM}PFBaQBrF89ilk5kIMtL5XBxy+3ef;8hJzt01Y zKUPoaSGZ!}?xkTED?+_dBxXV2!eSvSvYHjw%8Jfo#ZdHCV0u6E^;V1Yev#;{Rp_l3 zFEZ=^M=z;;GB_;!EggJfg~OYj!=f9#oF0_(IoFh(N;Rh%GpbZTIy9p?wNZ`8=rP5( zfp7dmV*FZZ{ElMsk#F)vdD*;*#%-7KxSNF*diL0_eGP8Wclcd#%TR~R8ZrOAzd_Bl zgMM3=oKS3$S{y64h#%x2Q4LV29F9>3?2Ty4RO>5J$4ZrTwaT$Z zHSjUj=?cxJNovzV9rQ`Caiql-?X7NEY8{=WX#TArrqqc;k&ohd-7&nQ=F^@-;@=ULQDw*L?^?-kG z!$=Q_h?~?WC5WaMQtCsRO@YyOlCp}@0hj#g0_a=z)$g=!E6Y+KPh zF~|vbOXx_EDac`hM>6f+DdvqC7C^ScpZ=o-P2n?)L(Q}4mq<};2WpO*iUeCqD@F4Y zhEXU3VV%{F5#+>$Gne7tDq;J}JQ=>2d&kUHBFX|h#<#lh@`#R;+n#kn1aU&d z^6sB5mG$^S&~9sYc27Wx!`sv<2=qSWurKlLwuT`tlPGUu?&N-0bmASv_*Q4XyOw|B zW`+k8FCEa9o!tE)7aoG*X}OBOUb<~N_Zo`Vt=#^_b4SMMe@xt5A6ZZicvnsEuB4AI7I-M> z-jx;O)N-qI`uJw}Up}y!YLP~@&V&C2)-BX~dQD1JQ?Ex9g@-}wSRESRGU{jQk+;Ur zYy1Z;qYIr^EQ3zyqL7>SK^Syr$*0cTOskg=Y2Uh4?ftdA>X8q`!gpQ1*|utA$NAgP ziPwQUvwAule*6M~`8$5c2E5#udL079x$irYJN~vV-vq7i>b!BtAz*9XYbd__*={?- zq9Y4{p+lN}bTPqKz3aFb2Z21O@7beXqG_%Qiw z61#>+;~Ri5XnuMI@%5POtnEoc2Apc&jCVe=Gka-tXVKg;2&ON+e}3D;AxodMLa2T& z&$*|$K0Cu3hGd1K?2N~uab9q}7tPmRT9)$uxevWrbmHWu-6cj3Kf|*s_kLDp|M5^tN-TdU-^fUaOA0;mDZ+XzSqQ6oSaz`ToqsQ7F|I&x;H_2*sAka&Y3#+A%1N-Yq&m-?MXl;P6bzsXt!YGBW}fd7G4*MIgfjGK@9sA3T9!I6}mB(E@R3OaB&tN0n`}g z76e?GB|?BShPf31RkI`rpvE$PML^9g1p=h8%!w@>T=78o^W^Nni~uQO zOgaPD&bl5!$xq8bhn0h72P*KjWmMo2iytYJ@Ug^l4{rTFSf%0Ip7>7~Mgi(PQunwB z=T-?NlK&0e1N*)C_1!N{sb4_%NDs5m$e;G2Aiy<(q3+4u5O?^J!A{_$Y5x7=<87re zFNTbsBx%DqMcvOz-b3-Gp6DKIVRwTE6wfor{-?IP>&NZDf27~tSb1mZxZei@*@26QOX<)qXZYB zh*}!T9=@Dc`vzxT$#%$Z_ig;J7iV7a-QzqcwY~Es6qmgI`PrS^r7v2c__6=!9{3FW ze^2)iN+l8r7pPlWU;e%9diIn1BS6zo9v^Y^`Kz4umN2=dMcDb08#w9I zQqVLimv|j=PQBjP48?61X!xS`fmSBy(rO^NJ^#@;AUD_l$2teG!P?ZPZ-`^El$ckF zL#hoJ`z$apjRkA(iI4pzB>BBG5EXwcHGik)MtS})wV6ArDT}|o(3H?WcK5!n_2k5#If=8hhP68M01ixT8=mS?lAK=hd z6!+2xQf)b76Qs_d8R+!|mE)wLQD_8L=_GI!-fO;U3p2U)hT4Kzq@tr*|7(SK0^b|guHIRE92%@!K!jnLFTtiKU)uCl)u#I=@gZ8 z)p#bcy-zj#O&8};tTIGVXU=XlLX&yYJY8KqcYoMEzNh14koun5GLMjvN^>8^S7MlW zp;9r@7Yocd2oQ!+S0LcLnJEIqp;8hkHnT=RS}1iTs!?J#6ajgmQWCgm=7fOiQ0mX9 z#$_`%1hj@qNkDDp4S*&rH&Pu$4GBNnpCI($E!<#>cN2oaSgDjHTw?~V$3({dzo5rx zm6q@H7_Dmauk{!UxR3t_^%y=Uj7`Z?mvC z9L!8C59REq7 z^Zrpgu60B4U8iCjey-oM80R+}QhV3Ms`jl7&JcGo#q&-|erzL5HKLNw^BPy2D)$CW zqRx9t+qiMJoU2AfFiy4PkMbybzXR+As?yRI>D!8u%KrhL#%-@p7hTf7z7l4p_OK^K z?>){J;{1k7s&Y1j?zij4Suv+xc1b?^K7J?8b-HJ4YT}x`FL3ihbJ@4tUv56}YJt|f z&b(>!IiIy0Cjd3bYF~KxZAmt;3i7ru9grWktgQ)dM*TH(tNWqL=Q^SF^&{?H9JsG6 z9ZVyFlG>#fSw-W@!8Ae;USdhy>lP}T><`^Dpgm*^0qe%cVe4+Xrq_>w%Qu~lJJ}VO z<@Oh}H@vH}>}K0|8=xzdI$v|2fAVJy*JEqen)kWb!1638-V#5r+p#Wj&08pb?Cy4# zsoPu(zyKso$fOHh|M+9^QsS!Z+U*cHr2fYg&Y69ET;MF^mJF=r#-Sg<7Q4<=gaPnVmUT5)E?VpsI7XG^VgJW8vM z^qpTVvC~5V>{Y>BJ%T`sI{vDFqDMUM&!%iK@#z!=sz~6gDHj2fKmiri_-1N>07amR z1UgMI1SkUqG(9tW*E&VN75J1VMFxbf-<1JoZoFwcVvrL|pL^spApCU#gUGacF!0Cz zZlb3G?h*rkqS%2D9=pQmmo&W!C;ANdL_(-XgbBli2wXaS$AkE@IE7F)s=X!{ZB&@C z$7w&dYxv{*)9|&S0en-=FA{m*k$o$;*WYv(WQ;VLLuC6HedRL`--2JfmkA=OS;LvY zZ^J_qm-c08^A_0W!eXxFp)9VV;8s@Vlx>RM+}|+;=0kx3o5i*M*&XcK3Roz2%M;15DsiLY)DN3YdnRZpQlYGq zpAn7LfivMMzOM!=tW&k6PTL!f*TUN>xLF_n=7~=({JO9(|`1;y6U!2Ibp<@uq&6~6hr)9pjcG}(Sg(q7s0Y$0j z!JD4nzMe=L6~XZRR@-{%aQw3dXs&AiirAc-`q!{t5Vd(3{k}NnB~Dl^8SB_^Y(&YC zRJi)OgC8&WK5FX3S)8X_I}`icIL zbL=m`6d@=+u_F1h_|=_G*GbBkULE^9Ep>NE*8h{d)<#)-FMW-mI15fFK-$)`hIDyhvJJX_k^8p_?8KLsnmJhBeGt{ zjkpOP@wALhmW@tvdoDw9#RtcCE$pwK@h3Oz(w$R>b3eD>!--)cj7AQ39?oOCGWyv+ z&0B}209&hI;MP(5UW^^p0TYxkJE#8S+4Xl9LvwkJrysruedLWBv}31Sju>%h&1-yu z3b=osQ?fqxK1@)eXKuCGKS%AvC#d8LW%+*xoIG0(#XlPE#x{54i~z<{>C}B6x82G6 zrkS9wMO9_3c0P$uP|1buoyRPHU-t(3w!|YNHneer55CfxUvl|C`Qj(FIPLA>+_}5g zr>0zh;-4;!XuMSN<}=Rn+;yh3wCiP0EnfW4q0HEmD~#~%4av>(#vjF(cMk?5fX;R= z(ub~&2(0+p(vZ^$yq zXBtb`xIMbLWb=^S!|WZdUGD~8;7jWblYKmYc1_|)?*ccFNJ^U1;QuWc9)Plmdz z?ZmoruX*R-$>g2L<&I>Do*g`6p#Wi0%-# z0}<-;u?j;{Ij2T_0CS=mw?cB+gmS(ClyiV4blq@64nys=46XV;JiT+ z0-~}NBv53KjDYxTH3<|Oq#+ zIr0&}FOW4MV3MN{0g?jQ-w2rMC_;d;K=uLw(;dYKpcKkpA|S|7f&hM@tQi5b9McdW zDU`iJz&u9<0+fZa*9chXnD;JhH~-YR$2uaGEBq744}O3FIHqspgAXlTKe~D=F_f7j zH?ltxk4t%;W;=*@j&Y6_2%r?v1|wjlqdfxnMYJIZSmWq{07(&TC<4|w1|UFLM6*M{ zCdWAlpj@O4LqNG>6ax4cX~Pk4%`qMUl8ZEZ1l)8?LV)rjZN#8R{hjANk_>;2gJ>Iw zP6y=jl}NBZ8?X>CCs#rOUknTo5S6PWfv*O}2#C+6s8Cbw27M8bl*=c9ZwBTFNY9m! zK!<@P0y1-zB+zMKgMi{(O8M5gvc-xFTloZjr2a0T)H64*1?D|%qynGwIA%=TAWa0= zHiLF;#&Pu+i2sx>1s~)nEyK+f{+*Q&EReYaqI%ZK`u&=C&`#n!)N;*kiY0_TOpcG# z&wzH9TTu7erGZ=Yk9{%>loPaR$SGX-UXG>wkLLFiaXf!Wu+5s#^eA#U&;N34JN((W z2L9}h?PSUVA-T0ySF?n644tEQ4(^?iNik1nekYtqR>k0^^|T1w<0%?9zq{e)cm6dz zHHY#s*xnb$fW2L|eI)GpU&Rx`gu~i5FPE7Fekek&XD`Jr45Lt%f+4;iW0pp*GXaIK z7f(HwgeE-!@U+##J0E$$i&l1F5Nn3zpw<~tUuPWqbTQXn}(O)`~8}{++JrHy1lpKymI9 z+k^Qp4&;A?;x$XP``kD?`1?2a-U@=fSG(!vz}8-S}~LN5|cQQ@h1iT{~ErYh|dx$S@8dgJU#zGX>_YJI;k5f@aL`* zf+TqEswcS-&PNw_yn$g-oOkwq>20zb1O2M?b#6zeb?s4H)Q0ST7uD7F)frDKu<1~n z-t#evNw%t$UC7JtFo*$w*uSx}6b!x+knd20A+> zG3V}G=eI|PLhIErdk>wfl8?l-w=)LU^PkzO(yd1x7fKEKlih5H3lKcufQ82ak>}`-_=!9VZ z%dfa3H-ZI*-_d*km*fU0UaiYtfeZd-_MM+o?{0Y8Nbjx2tw_d`il`QyYkVckh<^>e zQs$2c%*kP2!!vCt+Q!`+`d3dOJ|YGBAJ5#}*0cy-9O{p`BJZ0WDd+IZ-Fj`bwJ13s z_g3i&n^K!~`)CoaHEwJ!>9P89Iu_|?`xP0`qFMPY)jhn!= zL2DF0;%2N8%RQ4TKvm?Z#1)>&6L4mG!h}j4WYkP7b-7?y7vO#b#niWKyL9?2k#V%t zGeJ>hO9Z3IpXyhI=}6clzf?cBp32VXyQhRk+C<_GUO0%-4QRw9 zotb^dMz-=-%%K=5E6VKk>=xptuw0r7we-os0s*32S~&vR9qbXH$fZ>vpv%Dn0qR`Z z6$J1d0}y}}(5@oD-*FBCLHRZN}X$D(eIP+?0%5DL~Y zWeW)7QRZxtv;K`4*MA_|oGO^TkP)yq5qT56H?t@9p{N^}%ozfVu#|oLvp;na(O}hV ztd;POp>AZ-5l}Omj{xaLCKCa5vqcD?Zer>oplP-Q0n$xO{Y4=P!fl7TnaK$jz%o*6 z8}0t{pNYbvKETUu3om+-faE`kGh-aLLe{=yi$2P0gjdn8M<$Aac}4p{S!oBZF51S1 z4Py{xRKFgdD1kCsz}}o4MXWE=!c`>TV9rHAUbsMtYB-r&AfP&2B{koofHBdM$X%S1 zra%uWnQ1VO0Itl*Lx41y$w5HXoN5G6cQB0+P&21hF_D?Jt8cZF)J^9VlQs?f!!)eo zamr~Yd39JG6WqzLDZ*1q^YECn<^+f1*j772II#l|)CY>LfP2sbm_q5&6zR4Kz9K@= zAmXm{Otea@;en`L)lalyU%=NgSM{?hO`HZxyqf(SB|8SX@VHzCeVlmsP>VQjOF66+ z(}xkVEFMl=L~$^&d^`7o0fSgrD~1o|ZJEjEO{<2zF$R4k(2c7^)AX^v?g?x{es_fm zod*;0t2)S}{YpG6^;}YZ*Fw&Pfzl8zpG>^)>93`X6U3OQf0BwUiNV}#&Orbrni__H zBj%|*w@`fPUf#^%f~vlJ@c3!v-Ek7Q3GWGaP|OZNXuYOJ zIJOd*;)}OTrl5=yjpl4M(Jg+QiUj^P=c;EgaE4il>S3gKpXw)dR`?!tTN+c4pP z1Ui+1@2gGH_$fWpt7v*KW)ts+O|ebzXYHi5LP$XVVA4AL*CX?_SbM<^8`ETneh~~Z zgE)NSG3^ShPi(<2X&hV^ zU;Xfx>k?9`*BEk#P$0GBlQ>up#T4Q0NN@4=(1EA09;%9iPz1qWBQeNndVvY9&$wqi zOHdDMBXykNZk;QW1@2GJkrhMALuzXw6#Mi9&uX~6|0FiyYkczNKi=xNU_Hifd<_$B;<8BR+$ zvFpAk1o~Haa`HJUF3{V(6oSV)Fb!)I-AKyR3A#BePDn1|-m&d6>9Z8Di1SKN)R2p~ zXZkh7B95}*ao9k75x1t=;4xanRl*`JYOB2pE#mm+WlzXOT%`lIfmq%sdM)C3-!0-O zV3+4R3`L7LVXsBp!|xVxQH~qfq^bA}jkJZ=J6Iq9mSm(Yyura90sI^qX$ybh;DG>1 z4vn;hzjg?C*J}~??z=@?W%l$zvYoJqdyE!w>g&m<|-Jzv%MmG*DBPci@&gXPX0C!2khWaSAILBASUT5&xQxR?b1JNrwF|D(bFe(L z;M&flBcNao9|6+sOeO*@&JiJin#|M-H7z$MTtcu2K{;CIjW)NBA)X*VLPY}Z<~#(5 zA_Ntv##nO?1V|!O73OCApsDK3ToW*N=O=b!^dDv^3QRD#g;h8eu zyd=u&V#70{F>7=NFzis8s8~>G2nGz%KApL%4m*S{5^otjDxk@Mt0UTPQ6EZD?wU=7 z=nwE{U)YA&$=K zbBtJw3o@C0(ZpaT&Ep|JmC5u+K-xSZ0tBa+lMs+TPke0B#3?nYjG(Y{#9C6C#oQ$f z#Fn@bH&h+18d8(G7};8aRgkqz8&*;ik}jE|k5*@u^LRr-&isnHRbXC|BvhIGfuvQu zY*HL!34=F{doa$e$BpVf4aHWT0$;3)bb%9xd07WQ2d9qm}d|tHg z-fMG`^%{s{Vf38C#C|>@?R>w#sV#S)Dp3HjAhZTn{{F}8l4by`l|Aq&V z@YnW6{I5y)TF%zWG!+(TzYa15Hid${nqP$nHNVa;u(3`^#g}ofR03enXfNDrV+{*G zIPpnE{RSrrW(?BsJuF&u1?LPImGZn_f7fqGu-a43^v5Xz`k2w~hro)kAKb#fT>s60 z^ROWEC}a9#l%qwM*0U+8l$q3r!gOJai$rWmQ7DAKE7uq%QMfhuqHt-EgpxpkMWKY- zUI&Z9Ry<-_E?f@nBo#5I$*d;{IDV9L&C>8lojF%ZNvMTl%2){QmgK^skgvuUg)L8L zgDH#!p|CQ`8Oeu;c!P3&A+x{x6xfd-W736%c(ly@7}ZW}>kO`(OYwn_ZHb`yEPlhD z_zKb*zdaG`;Ze9z_#WQq^Ad>uzJ_ntMZhib&N6RM7+XcAjHHn~nuyCTo+n}`G4Aq0 zCA5RpMN8r4`2~N{hERZHJ)1d+!l9j=WtdE{VECOex1d7A_dIRI0&T^g3Y8R@0J|0q z*V1eO;(lcKPZ&m_O9^iI3w@L`ve0%U6y!M`6fvo|sUqbHSeV?ve=NyCS zyF$X=p2VDz6JMM0eC+!k3~zZKAAb$JT||%EH!m_Go z>pZeu%W&tI4Jk`4Qbtx~gSShXEbqd8_vn2~U?bv_{j|#)wqM`;6p9zezd12Y?tB;= zRjEJR3J<-z=5_}Z$BG<`4*S0PnhC{=*E(nTKKq*Y8kF173mcld9u|y%^zGCIkH&G@ zoC3eVHi!64e9U&|rN85js>j}KYkco>*aHt0?lL^rQ!4Mg0pU8-<(D@*$}aZc&Sau} z?|eEkpO)g$Jquk_O#?QWqrjH zxf4$Y-`3LYg0UfPS-M$XO1(Pkxh84Z%P)isV@)Y%TL zcRJjCP}p3HJH=-DPRZ(ildW-zb)PO6~IOZVcV+ zq;rZ5PuWR0#R}%gky9-9DCrb?T6e?X6f0zG3Bp%SdP68ocIJ49)oiO?hFa9%Lg@ifw&-XfkrclTmHh~w32jHiQo$Hf$&5MfEMpChW_`KJLD zMgBI2V0pw2h-x~Mj)2w>J^}=ROeO-_LqrHr1v2#zAekvafM5<&9|1`-6_1I~P1PLc zt?{N!282al;)DA1jpME??-g{?$(7EX#CA|@89k59*DasdV7>+1kUjywB{=&3k$@EO z6?RJW?6BGX{2B{}Yn{ljLht#;G z*g9$Kx-59F*F^JwC=9eZO?CNmQ|wvRm5|Q!8hb2me;o`sOi~ow()su?ZJ3VL#MBI? zMas63_KAE-Kn!kk-3&e`k}{l_rTGy)Z+6Xt$j^-`D*-==V<1zV{q=j z`zK?ze&1kP^Cca7**kG30U6<&fIhw$zTvGq8vx&G)!VWo**UMqcYuqh$B6YxuAT0vM>y^(C+FD{j4nS&Z`F+~bG&3-UKX@yat(D_#b6jDHQq z;|^_azWa6jX>fj3CGYX@Dszv;J$3nyJ{wnf-}?O$%mj5_Hnkbko(%m4H*>1Ap}q9( z@DzM+zem>6^G?0#9862f{qLV`!`}DcE;s2XN>hEJzvgs8@siCCa?4)6yNUbI#1@rZ zc=o0&6{e-?2T}8mzG*MTV^q7|_?9|ecs;(?lVN|S4vf}4$fR?lD$#>HhkIM-9%NS! zOsP6r4BRh)c<$}cdO_*jol&_(yl&`fw^9AMx})y(q=yc>WEeS2IX562H+NcD&$f~k z-Ei;74FaCb`)u#PM;Xg&;khfiho0)l&Aq=Ex?1C*zjHTj041 zE*x$-C0oVL26dqIOYGcp5U*JX#gjazUb7!@%^qil2s6@zi0s4X_hhZJ_%*oqJe|z6 zgR}UNUG5Q1gTWV?lIBOHX@1Ik^r{6^+3UfBO$BhOdl1MYF*!yyqIq&@o8E8~;IEqV zj|wPJ+=-fSoW(t@NU;9%y2 zfV?1T8>-=C=7xakASnqrn|ULkHHi8N)!>==BY+z$B>`8n9}obO*k@G3%`5}~KrKN6 zV?G-!mKbYCA7h;~tMOtJOFJNx$OA;X=JPURAdP1_lQCm5Hq%f_j8eO)6=TYc3YeC) zd*Ovc(F7P~1|NLsj8AzcA>@Al;zC2IGf{chOstuh6{T~Tas(Wh$wvTn9&^k8#on8T zHFbUe!kH720AUJY5DudeL4$(ggb9P90tQ7zoB}Gv0gP6yTEz^2h!ZNcSgi&`i&F&! z#5n;(id3UgrAjq`RjSlt9g0KoT|3D+(BHSed++nS?|q-=KF{S3b3ALGX`g-eT6?X} z`mE5VGR~AR;9mfO^l+Q7z!o{%TM0_!nnV7Cg3S0RV%^*|U8uCh@0L8rPB~UM+m6$wW2V*c#<^yvp)PBwF$FeeZR~G-i@`Hkw$Va zxdtN1rU3GsEs&3U3-WPCQ&zs~laG7bJKm*};OEG>h4XXp#sNP^R3Cm0jg5()!{B`* z!js_VAiDsD1Mzd*0e+6CrMw@FQDHw}b$l<*qtoHFM5+|lRz%cT4d+Qj*QGPzu^?zH z{7Ukp?QD1|Sy|>zNMsVioOpK;vN0Z1fsJuFtqs{2Z!8rA8zLQx8S5#!X-&I*B^t$t z-#K$b=j{7JXSnlrz}Ek+)t%qH>Xw3s0(lc6bN;(lyLTj%Np`*ipF}14ndfpA{Va2j zL=FkNd+p$mfF-dp){O2G3=Rca{O`XHR>+@54h-zGUJi^KXaT>l5`xaMM!q+l2b5c` zV=KYkphfUoeKO*}n2s>HVYy%rk*@M*JNy@a0S|aG#H{8eh)8Z7@->8@EUkijK=lPd z?fNBf{&N*@enTZ6Bg^bFMaHDl5}@(o!tELVk}wjf$GDDFF!x{d8m37Qle%R%~I71lmSm3AE-|GW6L|l@$}z208n-euEM=RZsUuK@>u=w{WZ(;pu8go2Vi5xDVr>xqVs zH$cgDjX07NQYkw17LuIU)lbszmMq!>I7$5>|95`|xP4j$%?)iQ_YJZmSt|y8`ZoHJcB;nsw(0%2ZuB=MPcinbAL^ zOx2U^WdY}lV|IZply`pM_A4LOw{^hTy9U?y{9*N?KeB}`zUw*3`m^T|vV|6%|Il%0 z^&jVDX{$l5Sx+6svGwM2lr$3$`Dzc5x@*i?Gt+)l#xwgvC+f z5f(Y*411Uoq2)i?6G~b#)~KX$jl9zSBCcb}QJ2Cr@d2c*2*?xZk%XW)UPcI%6Y1j#!HIYUAs|npPap*4 z@ftz^tXh)@!NqtjAs|Q7V+cWQynzq^wZ&vY&=^m;ir+4{ZsG_*Q#?iplri+FgrFr} zObE!a^l60PalDKWC}XcuoJ!0|VL7_RbhX$UP}fxC6N>sJ48t{f z@>Pap%3oluDiNr-Kr+Yzw^bxF>Eq(wlbCT5R|2P)NWY}z;MRF3VZN07)Wx}s4=qxx zl3*??oo3+R?j*@9i_3)hQuWnf+9lk4>6)(8!&yzXo!o&-2zWEdw+O)s2YW&wnxX6! ztaNZC1kxGg+r$y89lQyFe1>v`Eq;wP4*tS8{a~B`L-GmSUnY0l$DvI;s~KT1yiFiC zuLxR-lShfh^+rQ4_mA`p^ybH zQiKS{f+{6EsQWvpiV9zqV7%1Jqt>51WRe0Tsw7w7@{-pKQqW%_4y!!e^bgG91~`KfM+EE*WR%SX6y_fDSmg_$(# zsR5M`h=^w)1H;KhF!gt@%V}2`b@3l^=283V{}SZuIPo*#>&SOUA%Go5U&jjM>rie5 z-Ht&LP8jK8Gx$0(`N$cvHF*YP7SUC&&BvzNcO+7UGG&X{gE`^gvk)~R*-uwX1p#1*Ad8;GqHW*Vgt|do5UnDf9W;8vSd*#n> zO{p0J=iRAlG^J*cx1RUeYfjB5z4DYu%^HN>W04Q9=_)KiXC^?e1fhtn`s^2^u}S%`ws0*U5CE; z=n-;rOz5kM9nd#dPW?H0&X&t*;F40l-(CFd%H@f>!LE|Fdq>r+ebl}j&R%;V^ial# zO{I_E?7wE5{*Hay;{-~gd2n{&57lX@_rR{go(P=|6$=g-pm}C+iO~syA|-)F9?=SC zl+nXB`9s=1nJEEQtg3hDXYz~mO-_Nd@sdHFVKyj3rQbErjC+M6kH3S!P}_hHS581s z4veLi<<2`Y8b2&VIoTD~-*}{*D!PgC%|`c}i}~PtmHe7+LxLTn!p%s^WtOI@dH5L-V}!Z*R|M)tyAmKMi^? z^4q6FFG6z(Ys=U#pFRHj2{b1icy8nfy#URd`=r~HwCt$7*z@WeHLW*Ix?enEJpj8} z)#DDE-Ds0S^@L3g7DlJIJ$2V#la5bfZJOP0MI`g;Ea_u9(1+x_L$AL7nnTOC?@rl* z7)`f-FFabg5n#1}xIuG#p?*PTF5W zc)MoXL`vxsp{y`!ofQsxDWaA4wJmtiMD!Z}8YS&lTY{*g!y%3~?P5fuEe7XY9)kQJ zuG^OnMkh{$6MIjTQe)suz+c6(P~l24XHPVgFg6=|qC_Q+a48f!a0!7lk!(WL* zqi0Iv0j#glwn@u3`Tuh~{lB6IP2vHgjmdNv{u1dv=#;BCzM?6b-$2ZZIlCq6qL3L>AqM8oR`J=a-v{ci82V! z{K6VlBP|Ks*pq?0*!i^vv6#ngQe5uM{3DQ?jq>BE?$F|=!7%<)&OWb>WN>O!CduMnkg46pJr zr!2Q<_Z02g{h+EhQ=`L&sApRr7|7a+*Z8|Lu})*~A_xK8W;GBQa2hmTPWp_F*#uo(quc{@l2X{;PskOyg;U8bhmJ2csZo@&%y z&eUXh6X%)cm5o1_;wP@B&{N@uZrS;^cP_l0;LU6nz6*bwvs`*=#^ROrRue^t>nl7X z#0$(9F?elkscZ2uOH~SKySL#CZ^?6}A%-cOMrk z%uE~VDo;XmwqcN=*<5yvS`oez>Z8X(620UbbBm6~^a3A^sb@4pMZA!ryRyIZUphu_PMK?O%zv-1}1PTJE-RU5>IlRmT1%Aj6o zXw5z!mKt~w(gb8E!kCg&Rk$E3yI`R}onTIv6m#QmF35I?^<||vK>-KVzTJpeA#0!A z_I~NIo#KLpBK2_A_E;Dpx=0n&dxk;}`1hf?uxHWUJp;a;;8frvMi&X8x(_w|iiWDM zqcdoBLY?F&YA{^P^yR!x$8tR9GEyv>;PyOjDo;#C?4xpW$A_=d_wVNUGsj zX~unqe=h33IP!5vu{rohg7svsv@mNAZpI*DNIaQ5N3{}v4Vo+BsIT$&^PZDU2uiOg zdIcXixr9J_McXU*$k~3^$uwyU4Y}ZrpSAhQDvPFt_{aK*BAJ7a$M=H;gh2a=5=IDm z1`7!RiKfA9@*Tpuk)@3dZ8a zk@5~0_>yNNZinXl!CS4_vQ;oew_0_Efm39TmUYSYi0$!rzkXW=zPavXhTo3$PkRB(ik2MQwiBOFaqK%(RN+S_ z+%!^f)tXiBzoQ)7i7VNmT-rg4V*;Pxy%rI2rK*#Wl!A{!YRkafDbfj0{_?Gb`nT-U zFZarZv#$Oq)r;{r26f4|!oH{*efD*T1iN70^}hRFhy(?&FZO1icCST(1F-K_-+k{z zf@0Vg-`uC2I8bm5pE59ewd?_#yB^(>4I93(&<r$(fTj4d3?ZXJuvi`~t&qFg+ zQgl|SxCgxkn9f$}sUGInA?Ah+5g)?eWgb^G#EY<}7X-3sL!HfxSOr&Mp^0~UB9}r~ zVnX4B@1yi8#%uh{{CybsYMBTIipe0F&81YY$NPRoAd5{Lq`!)Zh=XK~YD4cq zbGejS;-GbbEL-BBuB(`oILP3rmIV@5l5;6_>-|#k=|b$PRuuVZ@c$}Y)wx;Ial@MQ zJMmPXa|bytE;Jj`R5z}Qk3YQl8kSnP)cjOfKWbqXanLGmqEJ?df6=SB6lE6vjLlG3 zm^el1H4IBg#kEU6wWjx=FSrynagZU9Wknp6a}5&{2Z`O(SQ2riwOq=t`S_JO0$Kd{ zAtl=ZZYD(m66qWUAWX#LAxlBTF6K4mjh{j82U!6F63RwlC8X>ls?hV!A3OYj`%H_b z_)f*^o>ZZ4-b~dL7%Am(yr?8iTi!UbM;1OIG`SqsWQ&ijCulK5KZu0~HIyNm26ag8 zD0k4%@L$mJSyPG0W%e?Ba!DOi6#c_j-&2V6uhbo|YK%>SF&s_y%Np_>82?-ntOd9? zme{MdptG$(XA@_|vl^Th?7o&7{mQpEhvKkJy2U`EX0d`7SW$M@%@K zU7Q0%4yrZm6sPG*2i1dvV1%S*%~VZCGto5pO@%{>#ca6n&M6YKdj{;DU1v#~Q2?fj zGPiDhy?9O5Ov+4QMqB>^pFG%=t)P4Gk-?)WDm*CSE$~pw!ThNeA8gUz!)x^qLf&3Xs%>vKQ$XN9I+uu#pLRjyN^T<4sVAIdklm~irs z7(}Vad8FH~y(q6fim|`1KQ^M|cO?8qRt4GG#EeWk!erf!iU|6xQt8ayHTl6&9rmu` z9BeL|)xR2hcec=zyKBe3fv|b*7Hpo{HwZQ_c=&O5-P~6Uo0k^4VNqT-x60d5^@t;D zjkkU0d-ury_UgYt!N0frWD7!zK9v9MVm!F^)r{i5$KH;PXd+#oJp1M;UR=08a;0ZE zb9Ya6Ii?f6tGEQ4%O0G7&9hD}+}%57`BjVFox9}8-F<3S0Bm0P9X4<7)4=8>vgwwZ z?hH%Q@b0yv^nvo_p2n()N*7`6n-I8zjI#+>)W7Lk0u3OK-}^HZ1L0v@L>dJ9EK1Z93Xh z|D#M<<3ABkaDW1NJc=U7SlzG4$v;XV!!oNw!iJhl^67V*i}bA-k|3ffVi=rU35)a% z7v)WK&No&|bT?K@d?hefOYDbNOMDD_GRdr*WLGuBN;2n{8c3NkVGdN*6GKt9tmW9| z)1qzE&ZJUpZy~(tB2=j`Po<+k#KevZa0#(W2jPzw%K-B50`fRWU9+StD(~?_BJyH? z4+U192Vj{b=_UFEG7W&Bn|?!NXRlF0{}%{t@fKx#d~sI6qA7ldowC)?53CFHOde?r zI_J0O4E1n`>=BCBv^-EzSPb1mK4#1TDS4n`nyb+{r^MY(gHv@NaEAu6Lde)wloe8& zhS04*6+rT+Lcoxm`;>Z9VMH=h zS)6ZBc{q`kcif%_iF#uVIMmh6cU8WqP8_PKces)tglfqpxkeTFmse;kAX;QB{I~Zc zJq8t(L*)^-EDcD7N`j^ej{vnLEk`QNGhDN$oV<=JO0_s)UA@G3&@JuQD#(yWVj@#=R9kpsAWwA4Fp>Q}1{ zR4%`prPMhxl0}RKc_>~c%mUO*PIV@9Bl@I_akKtZ6gOL9jGJY^zG4&uLw{YL$mQ`= zfD)q}TITD(rZ_>hZYp!~Ey+v1{aWPiPY#vVut@#v*?mG!eYlT}gl85HQ}?$Q|6KGA z2xSiqoOtYa_alffe{b|6O_eaO>MoBOS{&;u> zayJmW?fz#0r;UW>(7enPd_lNW98=o{C>0!bg7Qu}-dH~R_~VQ8%McDY620ko%==%{ z5Fn?q{aozX^`_c!cdLFuXGf43gGNZ8Ud!>&4hKazFf78xp5|n=DrEBVt|r~oVksq3 z%E*$kulC6es{h1+=nf|PpOB+LH-2JX9hmT8qm&A<14Fx4=i8+@-z|mv&8tdybU5aG zf0X0%$AR5l$Iq;|iRj%==2pgDSbOwq2t|pHm-l;MuHRO^He&~(bbbNNi+9{GQaYpF zFj4)Tto#QE2hg|_$_<2B=}Ru(g;&}OgqiAOoQHAbt%v1>p;5L3a zMu0yKMoFcetTsg{HVXh1VRGPWh_W9aw@F8}vS%m>M8J`70MwvX7{l9kwnpseD*myW z$N}+nXjDzn`4fPNi56!7^I!_WH^&k1hZ6e$+8G@FP&!2^A}|a74qQT@pF$2K1cMyx z2|?EsWv@Ww;0gkn5!PcEKLHQ#V^hm{jX3N?7fr7wKqUHj5h0Mo&}#`nbG(!g=wj$~ zBk-t=A)X|`yXHL6HLtfzs{`;naH!2DS-7^a#*#cH6PR$Sfw{d=9N8YQ9pXP2M2^)c5PWOQYP)Wc( z2rcVSlu8|lV>{9uwBcjT#|TdBapewgl<*Lj5Z*`nNh?XVXJ~rBWscrDbdmy31Z_g) z`b-!sy|thH(yg_lf^`rq5M8m6aSh8l@b@U)QtJBgY~jLBIMPwdR#)u{Ld3Oi_`M@^*D@Frce!lN%H{w^5VuIfc1U8{VluXF*-Zk={8*9g;N>Gwj)P3_;iAV zE)YbsY=D&UaP?_ruKHHAk4w5>dJTAhLPV@=mz*+0`DMjb`9)K=ur%9G>)8#-xW`Cr z0irTn$UWPgtkh45nq+8Nti)d9ASAB z(v(ee#TWeiGy*4=m#owh3|)SF0w<@-HR9wPXi^FB?gy`SR1*Al+61hbz*B9ds|i7U zf|wBKw$if*fj&V-2qfF+o94vlNBk0rEp(fN57x&CVv9GlDdT|H%QV_X^GApMI<5}? z7~Ab%Uo}S{$s>Nxo|n(V$GxP0zER?^o_H)#iX@fgFgtDnF&(C)Y#k>ZM><2BFUM|| zI(k};GKYRODQhv}8}BS;qR@7Do-@qMz;`G2yXa{$dJch5vy=T!XUOehz*~?o!kijy_rviuL)0=jJ{68Z3#MgFRy@lg07zY;K@=^_v^G0;ODxtXC0#IItH5$ z>R_`c?Fob43Izh+qMwYkc7|&5zH+7E-t=Ws5n8#dlR~xkJHw1W_lDm_ps6kMyP*K# zxOBNOMLg}Uslwm^hH-6n&v%h=Z5EpLD}#9&mTh#n2Uko5n@ckH-I2rirAX0PymFytHz7Uc?Kw zAl=LhMp|L!1?xdpd=IkXk;zZS9wRHh3t92HOUQ~(Jxk9qCTg#&GzfDfEX4=}@7fn2 zWO$J7n&Dm0ET@c4!)SiyHa4>q0!TL}m(DX)0@ou8-EdDu<*5p>mL-y6?v9=IwMvrG z5Y%{k7H1d+c6!KG+RXgywfkASeEkpkJ*S+cJE8jIuH#hx6Kh7t(k9qVn|W_uKs~Cp z-X^kL5a;}4xo6ZjB@O34E|xH8VD)d>iEX(p+F>WT9Wm|hPct~Ljis}#yd?a?NpAgB z#hZ_v;Cx99sW0F5YNV*U7UR~-xDDV#Az9uQTi#Vz-q%|8EuG!9Sofc&&04ssLPkYP znVP7)$VCO$He-v5{GL}2IiFYx)v9?{lb~94P-A;UUQpautCD`69nk!rH(VP1J%6bf zm9Q3AoEBQX3bFByiC;J;fj@)uL`38ea3x&_Ja zVpFI&mc_#gS`_QlY^xI?Ri)hT`76=clCVv?L|g8Nci0=xb|fwRUvKqv<-gvl)2x5J z)uku@^;VBe|BttJ4AwD|j?+KakY{Nvelgh0vtcj=g86wlWkNx9NZ#WUd?6HOXDqkRnF$rY6})WE6MR7} zj}(TAb<4O=totAJS6l!$V@I2BT%!yIY#)N`X^oIQA>?8ch#v&u$5#XhqM!H4Flm$b zi2Idb;(z>m&4Ht94_!d$C;vY_Jmx&ZtWjG0h*ufTA3kSY2Mb$!V97azCprhs^ZTHm z(*Hi zL_WRw!}69b!i%RohK>_G8Y^tOc%iPi9tuomT;6oC-?O3s!1Vv+dVSq*dqxCohhpS{ z8_QzOu3dB-%8eSnzO2In(p!xm}H;=y<_-xf^Xzm~V?9lTAuSTHC=$~>A&1*aS z_YRN1ab4f75FB_ud4~p?*RKD)@?*pO0OQ%+cb)wH$dfT}cIEaxL&rW#-h3X;uK%vp z^3%vI_g+Ku#sL?<>pw8}voS-AgW?kcZ;HPDe!`KeMMnZ*WtH<+PW$uUo(;Y^7;3o0 zo({dSBXA@FK5qVQb@v4y|02Yma7VN6Nm1^0v=DneqU6AW9Yw+xXnt{ELtDtxW6@|K z_NaY5`_AKM?a;h*#~0h*r@Nir0L@#L?)ds5_x$<}XkHd|=+qDAhb#pMp^{tVUsC!0 zAgV+ndeR;H_0myamqPQ7ft} zeZLWheud^+51-l-z%{p|fA_ypwbAyt>#=Uzi~gQ}i9{J#9>+iwv@uIKN|tJd5vJO|Bn zH~%{I%ZAOZFU{ePny^!&m-Mm!Y}z_`3@YCoR>8 z!Q8N7ZNs2-)k9HjofjYX1rFM@(gi(R&dsXJ)vqRAhvwMxKd;{{aXS7Qn(GJ6`jB>V z_sPGZd8PX|3-{kCI|0olR%fsLb~7-v9h#HQjQI3*TXb&d6+jOzYp8zn@!^L1(7bp= zZvVBF{!0KosJw7B$IrK$ZcInN7@6gSpf7sXK+CH=gm7^lz2G2P=>l#bRXrLJl=MiS{g`u!hiFT^H zh-K-CY1YBkZwVVL)PePiu-T{f(h?}9iFv);MA-DVye2$l0CM%i#^Cm0-ZT**fSiB` zLQpVGN(gi_>El6QT0Rv|j>0!fcSM2V{+9v_cdCv3gcf!4nC5UIzR@xsx2bauChbu` zBq5JHE)$?|0t~_@B1RMtKNU#E0t`m!-YDBs7^P}usND5Xbs|dT*G)!-dxoC(XEzXxEmpi;yN(gT?^tUu@ROd2*dr2 zgZ*T}!Ug=HmH-Ww^k$(-7XPQ=e)5g2k}$HXrejsO*{p+_+k`+nT}%ihP;;9Q)K8ZY z0v*)co&<*br3maQpJ>rC{l6IQ=N(3e-HIR#_auA0Ijy5vDfl@erq^)SEu#;Ud^Lx7 zEZ$sl>^}|n#|t|mjD~yT?-4a_lY=^On(1}F_NFS8>~x!rfTz02;IL`1ot+T0blW(d zktV@Q7;S<{+I#xrBlvT|6;r=cx@xCmAE(ON2->gCsW$iu#=;AwqMt^dALTcf`|}Ea_E_a zASXde2y{90Y(kKiAg_w_NGEQ~m+zg9Uq_y-Y$NE;gnWNOke5tuCj@@{AVN@@tZe6J z(c;PA8Ug1+IAePkTH%=2x@7;tYw~c+@iTCT#Ym1`Ta1qu<<=h~F@o=6D4zv%dvfIL zSq#){y@Fsqj}QPgn+I{kaK3;L05w~$U<4l{1e*C0PvVGCd?6vw&e!z{#_&aiz%XCp zMI14fA3_L7DY{-k5?@RRuoQ_mam0Ln6d@3&=)4^rq(k~(X`jA!ht=I2nAfP&=s&r8 z5ScrRraO{}wQb2vOi#S(Z|GD)P&!jg2qgRHbV5)yQ$`4M`{+zUpq;581d{!93qnvo zQ$q-J`{^7)pr5HF1d?y*mV}^rrhyOuxwkbTFw7*~!f#h{fNnzwI%Z;oKzD#{M+ksJ zju1!=(jDf+%YN`D3-FowI%I(BmC`;(0!icUv#P8}oC@)dp zdysUoa!2n$N9kjG5BiR;=s& z37ij5UFO^n-PDHBWY zvx@KsHl+8e+l#aQgAY)&Y?ZD42YhC|M5la3M z?B4tjK0wV1Mv0LRuxn09Q0L}}1Rr44>|Q=V@=`M&VE#nWb^Ti1w-alop9elb133D^ zuF22P9>GMIrkc@d3mFOmaOAF#SUmEVpsgn!7BbN!44gUJwu!?83!p?YX%N^${Kep- z*}z3^jVmG9sD9@o#v0_HdBFf*&5<1BpxMjDDt^<;#}oiZO#5U@E>)c>@N%NdDF zMUu-|rOUZOmkTDIN}E5z2lp*?8pdjAn&?)B@v-;g737kiBgoAx&8mup!I6R-WE?AJu@ zS1NHaTw^ejFvff_V_}~A0jK?B=VGPv3Efl1G7@t+#{69C^|6)t3JN9}tUB_n-jmrM zUNX;MoReC1Nglh$hFvUWACa?XVM?rDC>;tR73Pp2)2{ikP0en9oYBI`vjvJnJ!G)}A)DkSg<+!rD98+GmRO zJB{@Rt@Ymq>(C*#fn>YEt+sw;wmwO={gt)@bhZIWp58CLhA8cZOYDZm*$vxiH@4Ld zoX{i7>}0mSG8=xfCx5<#zcA2uK{CI;%F}O`=Vx)g166$gUHp_4zNx!>7nkt^s`+WH z{LdwhD{LH>ksVhhIewvZTwCb4k?fpt%kfK{W08&XA&K*`oz91ooWD~#m%em9u5&)w z>b&^0@6y}8>3{mlZTn|bknAWjQ%E}L8*nyF`lclJ6u(7+C>dSrioEOFgShrFQzGX~ z&rVCcwqdS7A#uth_uHxz6qB8sbWR1(=MtAvolBL(RjYHY*Ub&ks$mq7Lbxe$w#Az@ z*LEJsJ%57xpnUs>iz1Rbkf$Dkse{Dop(6D#nL1cX9kqr!Jck-mKpjy+9jQM~U#9SG z@L)`;A?XBpN(0wAj*%i~q>>`R>Be7uSa)RM+-nYtOBr?77&0CsO)-Px=Rq<@;j`De z3U?j3kSMhvn_&;&DHHDu`uQi-M3%xFkxz1C<`y6Qo*y4ZGRNh?K_>N%Sy`d6#k&Ih zZc5KZ-yElZ#Uq(w^pk2vTiH+8Jv*QTQY=i7dhicwqP}TcEOB6I#RBFO?}qB28`Q|# zv>BZ5#pYY$A2B){kF7XPpMjZA#kphhHkSRQ3Ie|*iPA3KRks%Nox(SH`qJ7*{y|lX zcD|SP7590ZxR>YxPo_ad}K_Xz3!Tw`=5dCuWyOA9CG}!uxZnV zmTO9Pi_?pm7cNX9G3I5g+%H&@DrYPdGv><}OG(a!WsJoJ$BjD~>u)*AWzOHdWTcCo z58q-G1vL7KNW?BIbw4a)7c1CD<>-6) zUhjOY6uvk1$x?QuoKq&^lw+Im1|ooDr2 zYxP9y{#5Exxzp;I!U}4e{oZBOWw3fFcd3fAek-=_&a?hYYyClQ{dcST`#hKPt=6AN zwjN-zNsOOQnajmnE}zP54K{9VWKX}Hw!&7IE3LM_Cb>PmW$Pd4_SnXAu!pB8&@)it zIRri~rCTt3Wl44;ciN3EvkSdtH+HAnvp7!~%se*y-&@_1lRW!}@#n|!KTG1L7V;OB z@t5A>r;{C***FS49miA{tSL!=Sui0?*23!scW`O+gQ^=Jv>3Nm#sy$~qXCffmQ6Ct zW{qWw!P0?8Ewk>{SVtr&1*Bdm&MO#`ad)#8epR{r+rN|&@kWDKj zjPdlbSsSmK8ap&uhOB&kbg)idqlX1nIB)iyQpoay_c*l8Ng?Yw`5b`H=E z()Og04hv>TAktz=fggN_9?4K9Ov}_z{_yL-WTRp}+BYZ)#h3*o zLlQ?;2I3Y#u_s^n{r7h^TFl8z@?zfzF-3FCpYMqi{1ahG+{{v;#fqS2Vj-it{wQFbg5bvBF9{=&!r<4AHs1I>+rTz*7n_pf1${yY zQTu2J%A>I~Qrbb8cJv37cCgR**X?c;f07h#g>&GZkk{4mr2qja%>!p1MtkEwB-w<5 zPz;vPcTCqqC=;TUn9_*^)uU`ev~qCno#EC(NeNbykxW8L12W1w2lgywu|w3Cz{5Ty zl0Nl@kD{0ijbjhcM4iwo;spE${*o{f31PkPZ-)ol=Kk;^L6$a)C4utxSst#VS9L*o z`^RDg_#*!!o83gE6tIni5MbyJ?~6CLE?*zB- z5?}(A;!#Z7Q~nSX7*1= zu5ZO(GWq&e3H<=hi=>SPOqvnV)vu8ohi~e)inzmFljQJ%={DTq_Q_}~(r&*X2!IVX z_Gh?tHm**vtw!7RHcc*7KwKyKf&s`m4GpG&BHAX{o>U54v#{dLD^Q!`|5`kvQ0v^E|V$}GDGa>Jy27`xt=x85U>#u`0tOK{FC9$DD@C$cu3&3ZC zOe=iuB#oS4c~s-T8kHym${h|i%#_xv$L!!wzDhmgSD%0j^FK&(HwZW(k zoKt~B)FM^XK-cWoxzLza@Rp^ttNv8yVyI@JhZ$v@v~fs@#F_S!RVSch(&fLRofE;d z1Iwue7TY+oTwN@g3rY}81J0xl$jj8ch zyJiBxmuXTljEbW_QQ&uM@S3Ku3TY9&MpY9V$^B>x53BJm%$qDFkCOEx$7`*i_5l+v zGlaPsgI{JKGauk{-iHzFYL$#BS@jabPuOV}FSbsIAvx%w-d!{F<)J3>0KE=IOA}e3 z*9~BQsjl9muude0F138Gap06*I2680b(fJy;~xCF3B`dvV@-6nNT*Q zW=t3)zr~B}W<NRIwGdwKA$4dZS!+*>_Y6GbWI)wp+8 z9(EM&BKg;LF?-U!LD#l2+`;V(q&JXmLACmC!*YTLtl2TnGKb(8!a} z58~fR?N&ku>3elhlTim5HIGpT%{S?wmOeUYb{`#N?|KpGAf$J29rSfy9i-{2gFXWt zbi$;ADv%B;_)i^_gmh5)|4auhL^@~%u7mRaO$X%>I*9z>UvR}sT<|~`_2&Sc_w}!W1kc4$zuarnB1ET07$S`9E zHV#=cPrAt_xR_!_dQLDRJtxkXII!NI4F7hSJlD>G(wd>t?;MZ-@$>C^4K+&`k{?OK)VoK@xpGJ*gC{~$E2}S=x-H3gtP=Ec6GlNiS+K6db0*6 z|4nn~px(6+)8B*j4vOcNaJ7uNN(Ap_wfHB<@0{X-s!%c4^5Xxp#_b60SP9)z`4dg5 z1lA;p8}udHEw8#CDg@=G(DmD2R61O^)WWyWp`iAz!ZwjS@x++%EG|1tG)aS0)h^~# zcW67#Eqa4A%*GJ{=ca1yN)FwjDH9e7DM*`ssbxwW*p0SbI&qd~C;1T5fhMzy7?^OM zCY4kz?4qG(T_<7>ieUP%KpksUqU&<@VYK=zgI7~P@q-bNxd;tizx`p-I)4b=_`%Sv z9uRWGkdH6Pbgn1I(bpB?<8M_^KCjE%6%scGeCV#XXPu!Pg%{IN>&hyDE`&R>OsG}n zvCq(s!|swgC)Tk%U_d63#()kuk_LKl<1DDN@Ox<5&qyy;E_bxoDWOb~Bjiv%0K6qB zBdwD@o*H-K@ttw=+shrL!xnf|c`- zlDzW-Et!rwvVaASuG;tjlw|u`P?G1IoLQICe_=$yStCA_FG4D_o6d+B0arYtH2k$X z73;om*qNIV<8oGKxL2{wg`1J~;5=+r9fi#e`@XZ~ZlRp;fj=eU;LxbGOc zFRGD}89BT_ryw2b0M8CO)L3b^{65m5txcdqrHhdceZN%cNu#0+w~NgXG)h1!bpS+w zK&5KT8E!=t2A)DAV11id%G2^#6O@ASX02+{tE(tzZD-P}K5(txwO!u@{r~H=9idmP zvch^%m_`?h%K)Q;02X-R07pcBVX!*P3+vUcxNh~Rz?CMkz7tLPMINsGr*8Fm1I_yV zQ@56*m78f2t%`Pn0g2WnsdAm9*$Ht;nK&N`_6CR zkI_e=F>CT~6OYpO`WWEf14fFSaa8K+H2@EwaO2wtNEAT6WSu#aF@&c^wE)$!JxSaK zOxP9x^bpcbKo0>y8&*ROM_6o~yhp=jjC59Viv#m&Wtw(Tz1W#Lfs6Ob%3dhhw>T_8 z@kK`=Wx1Zs!il*A?w%!T*K*Yi16$^iV$ldZ6J<>10N`>74|Od9=ARp|=Or^+1bb>i z1=ZYa%bX1GQtXjAlk=Ml80^~@;SOoLCz3+0+O^Y$6ZZBe_ zx0eK7FJ&y{xlY9W7HL|B$50Kv5%MXLt1#RQIb?A*_}4cADnZ@A}D2U zb&KbT1NyIU5XdIYF)Mk>ZCuIcR)Uh(1oTnzB^+GIYYu5B7qz)tu+Cm3FZD7j`GReI zm3;k9qmu9VUsCcpbfn~0h}antQ1a61C0%m0m{aXvGaw&AOQNPQ(RCTSgx4hVg}krK z@$g9dLG4e|Sj6w*1Wr=0sr|Oe+4%#chei73g6~LsDe;7*S(fkS(2|W?nvJ$`VOspX zon3%?%Q~iDO=T{Lx{-2vCzxi)0;sy`ij_duDsH;+%$C5mIkn`SU=c9iq2C$?Z3B&p zj9M@q*Y776vWIjkcFXSG9yOZ`Tb7a8abTmmyERSrX3Fe-vVp#ZA|sN_TII3`8nKx z$VhvM8xYG}HsH3#qwJ;9=-}V)@Qn{f3QgNfC(FK}OJw`#O6W49U9lQ|Ic!XjxgH@% zE<@h9p@na+Byn4j{V)&h{2nSSkV`7)xjZfjQs0wPzzz|dfwjB>Rc&04j1A994U|h{ zmC+}f(&JAjqBbY#axEi`gxtw>h5k`(qociIqzmT6XDz&6x9TC2kz@jb$QYj~7>hKF zB?iXl6&k9FbG=+*1~)iA-K?2VrY@Y5 z)HR(`AJ%_0#bDtfVke`xFG>SS-8JZV(XfD;&{OSoVs15wdqc}@P;l>xEpL-7?<*`H zYAp?#{(X~s{uEok)%MR&S-;m>Yul{H0Sl1CK2~WzQD;A$>@Y>*Fjwi2sB=hH@)zs) zDP(@KgrA{wT&HtfO?F%%alD~)ZP2;akzK1Lt`C*{?(6#9CHK26>Gy=})+TX#p>%tu zb9+m6|5M`LqjZ0-b00_X4h!>+Re4XW_Lh43#C!VJCi^5%e3OfOX1Dn)4)aY>`L3w; zO{es~yI4640-6GPnb5+7SAb-uteq&OfAwO?5Ry{P%PO48`G$f~k`lPTig)b?bI$me zgpo*VVapC?Bjkrcae@EL*B7&Wz?X0Gyi1JnYAUq*PS`nsJ<7DoN|wep2EZ~7<#==F z$$1b@v4d$j44=CNC#J591C7Cw#3^fyDd3_kSdEEjS#v~_#6Pt!r6W&=156``2wnV@ zH>|znV6Ul{)4mSIS4g@a*5DI`AF>2>OohJsU9Y&kri?8Gw_Y%5TXHwLZ0 zWuBomh&Oq9uUB;G0$cJJslbfu9|MKo7|F2w7R5-_PN27j>I+%i(^!KXmZ(bbsEdJ7 z)un(O2*qk8mW#pnr-yTl&QS-|SS}|qtNp`+=rxurG{}1$L=P?je@Qg8UoPfhLz#nC zvQ6R%HB8k=-U1nGDmd>@|Kk2H{%kZS&021plYXQlAD)jR%t;SWUK(bgIL zV=l1wbs*w2pWa2LyNKCdMzfxT@Cc}4Df)9?DOHiL3Z&KKtFBV}Vkx##x}sM)cBM4E zSjzKRa$`R*2J9aZ&OAJh9T`Xc)pYPDocN*p6I>QD>z(*8CUj1;Vsa+OYC)BUUJHbm zK$1#j^$U@bc!NUnK+jRZPE2OykXVkc_44p87zEUvaCad$lJ;l>CEYa7K%rSCV?caP z=Bi`xSgQ+sgeewzN_M2f_V_A>U+9yBtRVCBsSnb=wcd>X@PsXVv?3qE(a zAQe3S5NR_;Q<&?76T_I|DuWmd)nKG{>nx(krW?>R7J-{5gd9w2RZ+_fmSUOGHB+|7 zP{_h&%K;y86Wl>Ts-ckljC@SuDWVokF9G*-#1xsjU#3!Kkc0eN&LX)NQ{cjWm%Ivd;_5C|>+;z*}qgPKVcB_wCx_m>Qp7Us|$ z84EGnKB!TUV_@w;o)eR9Uw;8rBQA;2M(}H|f7bULXv8Wq~-UozCWaC3KH3vS=v$HSGB&Aq9 z21++8cqL@3dJKN8L<;Ke&G1~&#es0AXJ{)S7*L{f?YIVRHP=)T^ZjzTeWr6Gc_X}O zc(i!o8X1Fu%&C9=4rh6ynp0r9f=doHdy>SAf}Y@dqup!rU^ntE3H?>f+J5m;>qNQk8bT&0=PnnyC#Z+)1hd5=ndFu`%)x6c7jY}Vl!}mSigKs^*yU5(` zmG3~nJ$0D1u(EdKMm>5c{FeP+fPwLZQ`i@8>vB8!XhloOyOp5tGwd_ zi&qzNt9)mEqGdI7I(o283MHpUPU+ff_a6*0NHcCsQeA?IpBP(yXdg4U@VU(869M9VFx zn~?wO&j=jz7)|rWza$J|9k69blNjv^gdgY=u}762QNTY8cB7aGMI|#Zo(~~>FW2!P zrG8L^W#)7YCpafx0yKuH4vnrGFASHX97sooKC)?$Zbg6Sy`ukvomoG?xX0wj;u!Vc z|3CNSD3qdbo$)^U-q&yFz9LZSvXM?IYteoHucAEj>*d>?U$o{{i@A0GgdP3^k*KwL zL9#}sr~=ucJtSm=jx!peu|^{lDEYt$ooh5dQ;eo4U?0K!Tx~Q(GmPeEozWD%VKhJQ z8corMWMqUsF&d#Sj7I2NqY>IeMn>p3qY)ZQ@tG3lGhOAASnV?x>`SQ9*#IxIPSSc-bs;+kRU)Zr__hp$!-Usp3c zgF3o8d~}_9bVJSP8`Lqk!^hlJkGWqn<{>q-Ej;vzI`mmh=nLxDKf}kqRgZmNGq#5s z5f%|KE;C|cZA2_Ba!N$x^vuY_+Q_-Iamf+mQZmOat{s<78^0oA{OZi{>uSem(4wm& zqU$oF8)~C(&|+>!#N5q{xnCRekQUn(5&I-F_E~N03)8G#@BR6F?(g~c zeSGfU@Avh1^q-OAdtKM{d|t1&558zk(Tz2dMZVQHe5H&0ns4}#i~Kuo_{$ar^xO!L zFA5yG5h!xrnAWrra}LdJ3dNl_6*X;=IEPg?g-M+^H#co2ox?kt!e!18JxvjE=g6U^ zNYUbLX*aiFi{rCz#^Z~(7v0=0S)5RPGeNp|NAt}cZPdVM8}&*dLIL6+s)QfE-RhiJPyJfd#LKAKY4v2P^2Z_S)5J~Y7tgno znb7AvlT!ufM4PtRl(jK}O_kQ$so5L-_e(n0O5LLWcxcEh{=lR2?@#6${0wdqE;jgi zN&_{p+-ypfRxK)r8X#}F7KNpLf3;iK)OKh-m6wX6M;3Av%=mMfgga-X2$69Uas2E) zO8ZWQw508vb3;Z2OKQYYEXZ@*wD`kU{R~n+-^o0hUl?lm@1<}m(j(`G6#kC`;=lFy zpRY$v2mfKGD)`LofO*`lLELQwcgIVvl&5_?NV}QPzWY)eI_b1~D}Q&R)3o(E9YI0I zyVH^MV`uMNyVBcsX`eA*ZVH%9PBEhamfk6rOMp526bou415dFMQR^8g_GF5U1hA`4 zv16z5Wq@N(iX$!64M}x%N_96)-9~#dpQ8}BBsGr0pV{#CqTL=FLC1njB37ZuLZ!-qq;m!kC4 zw+2wzJyi8W^&wONL=`<+)OrayUIu9vQI*iRldCJIwA^cyrK{6HiMsQtqQ?-Y>Yi~+ z*C34&b;pYy6OyWX#>*h%YN`?v&)Wyg+X+~|R-g{YS?|ZsN82la?M1-u5MXmE#rk-P z%{jpSVT#?|6#MHb)Pd;m8gP8J(D7S}_fVrZw#X;@hR;8#geI*y1+phETwXn?Z(3@U zbkacsv-0Li7Al)~o{Br@<6nEKlRD$tylK=lsVSZcl`&uO)adqQYL0YWpk&iBLJg6q z{Q+vAMC}s%LL1*7pk_+cF2QGndU7kj$H$-IKT{FWDC-+O{uEz9jE|C5^ElKf$-9BG z<|RkS;|h3cTj#5V1aUW_+-sZA*|bV(k%{`O=BqCj8!T^}w`$T-_g$gBi+sG)MX36| z`8VSOHE~K_ruKPIjpNkz4yt{Z+T@``W>oiegDyEf1lN(F20JN=i|T!*zSUyGL)Q#3 zvC*k(Mz{j?{WQL!FzdB~$#wB8s>FP+FjKDJJ<}9KRDTXq70*ZZ{j!9e>Z56ROc%a( z{@h8EvJuPu`YC1xDW){Q%o~_%3e1_8GKbpy5dc;jQ>^e5Yb?c9l42tT?3+{UWho9+ z&NdBn8!85)4E}fjO#Yd~{jXCc|MABk?UvYov?tRh9j%n}XXMJ~9rvEgR(AE=>yod0F?8>R$fYN(r3Z6) zo89sjcloWTN}4;?KZWG=luExmG=k3%gVMXQF=?hj$BhO_Sv zCVv2ENiQQmK{C>?D~l`O0reglcOnXVc+tr~=@W^t=#xvN%9 zt5$=n_N`X!R#%OL>MZC<9kbxxZ#;ciphW>fBNThnIS;BI%T z&92qmzO&8#rMtu1HitfU$Kf_dboG+-hf6Y7FU@(llvus2_~Ekh)yr!hE^k=9;?~0z zt*cjdK3w^7waeRwE`6(44L@9kdU&Tl^3L?|$$8{Mc&sUYw5Hs{x8{*=gNNU(M}Dmy z{+*BfUwQ<*eH75=5jgxP5cS-c-o7!@Gc>0?l(arz4&4-c#JMr`&sQO~>8_?|rvA_O*KN@9fzB()+;M zjstz(8N(eJs84o!XLhDfPEKbI;d89G^H{l0ZcS%ygU|6>oyS{!@;W>7UizGP+j*kT zCx5szA6-+F{-h{#O>xeXVq(qN;wNXRIN6#fB@JuN-FkAabxmpKlhT)K&cA(fzHd$0 z@RKssw>tf4b*67k&eIyg_j2*m%jLecHBW0Be6QSkdZpF3uJdW#OW&(+pI+_rtsj0` zkNP#IKWonPyOs0o7U6fh_}T4pzdJS0?lkz_z4h#FtKYrOXZK$EwY+`S(&u-7_}P8b zza#y5N2Y&g&ht*f|4H%lC*}T6Yo0%C@PBse`LkC4=bg`=zx3~V`@E~q|Hbh07id6F zdRI?oz}uXzw?x2i#a+LZ2fVB4de;!}{#Muf)_@P4T_0Ws^uF!t?F;xg-1QL+97=yN zlo>dj^J16?{963tYkA;E&5MzSz;Cx+d}|H--udGD%fQjMFGl+U$A({wp#ruToeb%# zV*e3+gZd4nJRwsk_yHOn| zPF)s+?M8K>1a)~3xf|7klGR0PsiiLis6ZWC%in_p^hO;I z#txu1&^vWWFnIvAg+8iFgV`CV9n`N*2JTzvo)xtjVyEpWT8Q><7c5OAuA2pIxGuy zfowJ8>&Ps0738QPTF*X$dPBh)*n0jE)Cbz2fv?AoplhHo4as`)2Tei5h~EYLioqfP^eG?55bP1o1hX6 zNeFoi4TH)wq#^8FbTd?;L5A>i(Qv3rLl%PNq7hK7hCGDKMI)hl4bcYn33MCuN(0-# zKY_+WZ#3`?*a>ty^iD&vfjogGKp!=v8`$~i4ya#)+`!LA6QMy3*#<11%3&YTkZ&OK z(PU^$L$r~75>12jHL;ESlW00LLlfVKokaIQ#+s6i~!LC-_An(|F#30el#Yl_0y z=h15Dl_nO(KabWxZ#3~R>^yoIdZ#G~BhRC?&__*a7`qI;0`+T>Vf-?*4jR;yg<)mr zRcJ(09!8d-_0X86Xfyi)+6?JyVVn6E&|A<9EqpU}0lf_wYe_bf7tlMBfGgg7NK(<=)&1407A9BK2bF0_BiOa*d#FN-jNsRzAD}8NSp-&#_CmE<@(8jP z{Rq`-i6Yro(IM!S78c3BiVj0>wD3smD*6?ArzMFbuc9N+M=fb2yB_@p^=pxl{Ce~| zG^izu#Ol#eXhcgMN!Fud(3qA;$i9Y7hV`{EA^#dW1)iae3$be`12)!{2+3=xB0O7L zN)by`3AWHCh5Sa830rB)gjgf04BKkUg=8bjf*rL*QS2sE9S+vUqWDdy2E0KVkHVTz zO*l+j5=AzlT5zPcG>Ux_)rOr=sft9HWtlqL9O98+ITe9g4)3Ev?bAG3u+60)RsoG@1u5bzcv}o zzmM9(gW9ra>^|xMk7&!I$@{1yJf0bE`yDABwNV` z=yG_rj&v)#4P60S=#X3aZRkqaN=LR8YeQXNTOIjUvJG7YJL-sH*zKq{9IS)I@Y_)z zc!LfegSDe;;4mFY4B3wQ!jU@C81`e-502I$WB89ze>hG@7K1%T1KP6!LM|%IDQuz55LjD;Tj&_sApM;3>@K$GAR9eEu20!@a;bVS?Ozo2QbzAmdLo~uh0zGQCAet?nblW zU|lSp-;L(L8+7q_tQ$QBhv`b<$!;_kj?|ULv)`b{;b>hlp8p2TgX45%@z@*m1e~BN zk0;-t`EatXXgm8iv)GF|C*_B-?(T%k*D z=f6Ws;VNC(cI+K`9QA!u>^iES_8k)#S^ez^fLTTSCT;X zqP6fxU1B)9r{b&npt0&(<_M`V< zM?Fy@djRc#gY~dP{s7tuZ_vXNu>tf69Hu8pBnQx^aHO6zkv)h$gQNAxME)T99FEhI zC1QhU7o4CcPb3G?7jUwkD2e?Q?STvQuq6Ih^etSdhbLiQ(cj<_JxLPz6@3Sn=}D8= zBj|g$LXS-1kDwpmDm_^eHiGuTwR-X-as>Sd*XxOr*`w$X{7Mf?=8vMo@Ebim85>2v z!teAX$>bl8}Kl}7lK0DPFNp+*{8=@`= z;-`8DQay#K6Z4-`A3Sx9B-K}%>PM#f%Tfd6se!cJLiX+`WcL=+-O&?MsTjfTSmEwC z(e7|4-3+d2-A;>(zCJj z96bG)Bt2J}ew<9tlck@Kr{~l5O4xhLk-Zm8_f|OVy~y8NDcE~SxVK8Qw;J19gYUg8 z*;^~!dxhLvC);~fzPFyXpJeZEMfN{1-QVW4{~>?>Bf>{TN7^6+iUxO@L|A(IBB!R9`I+K$sWjhMl@D|EskmtfxZ`0eXoFi*HiuOg8maRB>}Hf1KxvypHlrG4kygT_DoVJ2DeY*aF>3>^0 zoyy8FJURW*9=^$g>9e-frGffp)AcPD=+9lNZza~ZKB;eeP2XPpeRB5B^z408a)y@t zj}I=?j82*xopLiev&E?BpwZb2M&}+F9V@xQICp)|xx0+ghkHt&G0wl;)*RYfb6oNAslAuaDb`-xTYE+E%Jsch?kd(j+*`MKP&|A1K*m6UPWoUtASfge5E6Yg0DoWofdZATJuvJ`&ReXU}LZemUE30I{hS0Yu zSZH%P*rqVWrntbSq|v7Il}#C7Tdr?gvCy_M*tROgwx+WPacEMdRea?cPRxkMb z)&&NK78qSzVDfOmtj`NfRTrAiSvY6)Ld&fS=N(#Tb8(^F!-Wo?7YbCJ*UoWXx7s;m ztMkS~&YNB*hE0Nkq|V_#LBTTT$VrQ%epffl84P%2QmE>fb1=D$nh?xHDEmpD?l zWm1%snzB-qlwzg-ZBR-ndE3V+K6PoI_1vJzqH&Hh{uYE#|0oEcGHKZa{~4e6|3^mS zL^*MP1|3fpPc~F#Ppq&z5r@VW<~dTY`Gdv0zWD6H4;CYpPq3IdKUmBnY8Ui(VDb-O zL*Wi3%4WkTCB$`tw0JOm-6$FM&y;~?^*OXryL`6(t1B? zlFXgcK?Mw~oml!k=*Kw%@~ZKMkSPbzBM_6?5U5tLp7aw%I%pF~HRG(Kn+oHg^s@ed|EMn8Pe}r=|7T1#xFzSPXt+rJakpi1x0iFNc&Vg5E zDBT_^r-=$_qO#%obmdbgQm$O}2BP$aGRIH;Z>@TxeR{~W>C?@p|Kv9P?^~uD9Gq@+ zVfy&7|7p7EG<|dPiA1?A`tuIz+g#AMd!X;|NuSDATWkLJb#8zEJwxp;IcQtVXKZzw z5xZr^wu3XaUzoAu!HlF&Gw^BuL`mllkn=ZnFor%pN*8`|E|--yX~!{WKd< zF`sT>{*$};-=ob9vdoPt%uU+NXZ4$#s#us?Sj=&^u#C2tmt|p7VPV&1;m~g(P_bNV zVY$xTG9=n^W0vKn3d_xHmJ$7yLKUkm7FJu`tzx6Cwq;puudv$DW|h=$g{#<{w6Hnl zZgVEurYOtiY=zCaHklRx4ym0Pc z4}8@biv0hI$nuW}E$KgUwW!+?sKfc+JJtUm5LvFx9#xpLVBx>fS^lHQvNhOhTY=S% zS5{JeyX&tkh5AmVx-fiADQTk5EZ9lZOaqOGo@NC?Mahx#We%Np0 zPTRQLwsWEYUNmmJjjL}zDE7}Q#$~wcGGohp3q;$J(wTm@uD>pKE}a`(s#mi z`Qf-6ny8R>PE^PfTFitB({y*d^rdLs?(NkSu!AWHW?Y5&A;C)OVcPz2kk`EL z_3}Sh-qJq-*eW&A})ybXX zjBoF%nR2G2NKyK`4JM`@fW+)?Yljq1()Lr09@V=wdLNpkvHxbysmas#hnQqPKorj! zDYi*w-8H#TL0b0qQ=FX=eyF}6SM@u^ypmd*kX(?^qikBpS;}F1~3|3YH z_8ovjPs$}J;64vry(E=#9>#lHp`y)?mbM7+@CH3Mg4Jn?);}8CZF{Y;i9VNV9!XF* z8IDqiYdXMy9x!kSe1yO;-f-;3Ke}CSX52P-TN)fs^_b(b1e{R4dxwy@BU|zEHf9Ia z^QJMAdf-m^?#?0j2`!!KJ+W%$v*!Jo>?u_5c~G{W>VgmR(+@eNXKkEvByCFewkb#d z+6Bwf^T>%VxQ9aLGymGA{V2nzX6>gXzbjb=F2P*I=s%v?eNKrlxm12AG1Z4Ncj?LU zP>tCAl$G_D_U+=X$BDl`!5i-V(Y}f1tN*owOPl`>Pir52`zcuC{|Qy=|Hs^!GCh^G zCJ?jcH)c=v{BZ_dM1OmH{_hz26#CnthTm}dyNeSH`Xc(_*#RwC%eX5{j!PXn9dzAi z^L5{&x}WFwQqIf}+WTJl>F!xDmHmPd&*&KN*6p@}_E@&eSZW$n{jM%1o`gsoFO( zweL;Ud6cR1e5&rRnYtgQ>V3)7`##km^`OE2X@-Xn8s<$iI(^Wn)Xi8zF>A4L-87RM z2TkrxoB8PA%<^mh>)5(M?0Dmv@i&zvsfU*AS6zDe(9%5BWv34WX^` zGq10n*(_dhPj%&^Lo1)Fy8L>`<%8;~FNaorKQvSH2f}`NZD#iWrP%+MV*mFk_N;1w zTufJj%+#9&RJycj3Y`hfRqqf8Q|QW&wR(?0l0s)e_Uc0dc?w+}3Q|uCGTlYjfYz&L z2MKr4HK9=Tq9DmGx)v0!UL7RgMc0O+)SH7$chhyC81;@I;cmJv6tCVBB-u^ZgA&z; zg5JitTlSpNwVmxAbX9W zb@D8_Hx#6iw%+sz-3MB)k-c7cguVs})hJpoIYRe^!ZoVb%a73gpeyWuK3L9$W=E*7kbVM+*X#+E6rwhNq1!^%(0?A_ekUxWUxkJ>d%`4T^BsDCdg#06P?)@o z-VE`y(l(o3px=W2rj@-}c!7Q!GSn*CEV)3x1I^T`-YmaBzYCdZHE%Yppx=Y$YISTD zR?u4@YptHmk_!5L$X;t`v%G@d0R?HLg_~ZYcS7s6vcrX!=ue~l_a*eJC&(y9K%CFItU^8mB(zKDzgy(8^2!)MwW!PG~M<{8evtWDeA)&mHt_}xj zr$w1I(KX=p+SyUUCb}jZs$CQ%X`*Yv;o8+va%z(qj?!+9GQCOHfn&5gqJ%f;x^TR9 zPn6^)T@Oyw9*UCRq#MA5cG?!x+jK+tw08Cu;cdFjUwAjQtMD(1ybl2@U^AWOt)^}C zmGB3d@?4#ct-?0C3v8{^vsKbYUj^Ii3~iOS(Y@gyowOL!cDfI|UMD+7*iK&qhw2o? zNZRSXaJWu&jJ%!h2S@2N$Cy5*`@=Ch9Wlbk^Z+vix>I}umAJaF&giczl z=@WV=d|D?vR``Uz2`<(tij_Q}hry*f)v@v?^v!U&PIIj3Q+haDsnZcFd`gdiYjk>I zB~R&*aQ9!pcLRm;{?pj^XW#&At=qFr@`|1TQ&I2sxW9sc!Xl`ET@daJg>tcGGwCQn*sL zW4rJj{XATw+p}Hrj$Q`W=?-m|zoS>fFLl!rOnd1y@UOZ0NNV zUQeQAkp2Qr)Ei2a57K*JLN6`J^eg=>d|EF%N%)oi8(geclqC5|e+QT9RVT^6(%-}7 zdd*3uBlHh&rCvvpaD?6q*XZ>mNk-@&;X1vcB>4z^2!5%TmTWpoABKO`!*>1!g;Q5w z{A<$i1-7AFklb{ zX(vthC>SzM8)Tmp?olve6dM$slFJl-RN=}gn#mgD94W*~pS>hE83q$f0KTEulLG5u_8Oly!Sz;H4t)cuB znI&Gua5SW_ujn-Uh}fGEY>1ua9})X7HW=cku_NL&j4(sVY4V8Jml0_wJ;myVhJ_JX z#4ixY zS#dF=&Gu|0VO33r#TE<5sX$iYbe1*|(M3(T& z#C428BUuSnCcerTF_M>%W#W3qn33ol`+~Tcp>K?x<6jWpV$3kc&tVtDw;9I9l5^w* z@g2r&W9d0|h4?PR!k9eAuMpp3SQ*RCVHM&QhOM#u99bd0&u}yrm9j61I~c*nSSkOK zxRbHL7%#;xiJvgSj3uSyCGk^6q_MP=T_t|Th&Com`BmcQ42!5Lg)TTR#%bg1^TNvtzcGr9i_S|fE4*Wr8dsl}Usia}C^v3C zZ(6JHfl+DPab8%f(95VX?l~{1RrtuLGafoGuT>agyfjWLGrg)X%=pzfyG(dh;Va{< zaZ#D%s=^54gK>44{Hnq?MxSwWnQ6Vkcg7dvjxu4r!YE_dxTj1~uQ0~=Zah>bP#H58 zN!ZuKlNI$%sP!dG!oMb-qBz3@mtfb#3`JuTiG;i+R#cpABBdg2#Y&16CZvSlC}t{J znY^B)tY~f0Bat*JuoNk#W^XbikvA%+D+ZaQm76vxXeh2X$u1W*DQGH&niQ2wniRAY z!%eEo~p5T1PR?L{^U76zeJ`n8?e?n_@l1WRq`y(KAFB z*tf+7iUlUv1^#WZp<XU+oUOsbBBG@R+IuYD>FJyL5J~` zD3j`ICd~^>DD==ltw&v(DZe(WN5O_#v;QWhR?o&)?)RnGQET>h0o!vaHq^@f#qmWm z$A{F8-9iUhqdf&3erP%P^`+TAFU77e`L|jQ_1CfT?#>m@cDlTtFy`SWL z2=v*PO369nT29fT-?bd+4d3QPe&h|mhhV_7RR5ku0q++DQu>vsDUi}~u%?aK&Ofvq zmznsKA1nGlxvHIs(lt2v_wn1df2_Iw$++hMe=N%W7@Gsu*?>(EU|S8?HUF5I|2{%@ zRR`S+z=`2G_+wP=wGH%61AYG71Z)QVe~iZscK;!R!O`>JtxJB4%HiK9+>2;uOrlCQr#!pK{P(O4ht7hn7q^>^1kE__j$ z)bRU_c&H|G!H?|EzwX^$@?lS0Ai<+D092^=SNtW8_Wug-_&*`;<^7*@;+~QFpUd`l z$tUEv|2ZcvlRY(aDlc=IS*FVTOjXxRPEe*=R3;bARL{)RATl+}GqoBrwOccFUS{g{ zW$MwhOs8g<@v_X#vMlCj&2i0|8{_$zUuHSfWE_l^rOP5 zN27R0x0oG`o_}(Q8?qp?v(dSjU&zDcl@8jivGRyCupa0o4 z|4UH*Kve!9nm?48KTPC*Ezcim$p6-w|NUkDXkY#qonTKRrcNiOnG-4t2vs+NvzAcX zLU6@|`aweDB%yhM(7HxwKOl5|A#^_xdI|-m(+bR{7nqwDSS%=*<5n4EBw-m;S3%4CC zj6Yep{X${FwZa__3KM@RO!`!qtWbJF z=+ zZY$S$Rj%7#t_M_@s#KWiSD0H=SS+lV<6bd0xWY2J!YZX=URH&5L4{34g>7SnU0a3y zs|tty3P+%luTtr%U+HF1>AtXXwR@#UaHVH-rB_O&cUGlOLFJl?O5esxzqU&MSCs+% zm4QH&P^BtLziNv`RrJEDt?pGZ!Bw%*RdFd*+p?B!Ib);toq@C`mYuBBaQXn+Umc*svqsI9|IbYNx)PkXH%pBp9H8V zxtKaNARxd|@-!7R;2^+N3NRHlP*;X(D215f4LAg7DMgq{8xR=KQHnK{HBbrNdP+&A zv_?b;FjHzWMH+D>z(T3T)Tt3+0&|txO$CiO6R=Y1G8Hu<%7C>}w<+F;D+9Jly{6Je zgaz0u4VcOraTegHG-^t_fv5tmO3r4;4O|s)S8_3Px`A*24<%1C!3~@Pcqs*#iEbcj zfR9p$8GZv-1ALVt%%nFEF5s^eYbLvabAdpmBs1DgL>q`wYBED^;@Uv8Qj3|>O+*KX zQEE36+{AT&IHfK#(M?1bh*#=1!*AlcK!Q@Qne-;22P7&Dn8|M9dO)($s2S}R@)MA$ zbHy7N&X9879UFM=Y$Sj~n zsoNaCgU<&H~s8<>_r`<#50If>S7RWt(4$!9LV&QZTnF~Bp z^0W}#!{-8zl>#h8_Yg~~ZU^)$wOBaaN9=(wO6?Yc`?x(YsMKX4x{o*j!%E#2_v8o#!B}_|Jd}(`AlRE3$}69QT|fXvG%+TxP%=Q7hsMXfQ+O z;H|hbpv8=sBW*<%13JvuIkHxKF`&mxnnP1iqG#Mc8y znE{rfP9y}#VTM@Zop=b4%Z#v;b|M>qJZ7w=tP|e=OKsmF?5_yVm1}c~> zmQGKRaG;XeZYg+*hXYm2E=$oiEnm`PT&7sw8vkJ)5}yufz={md3Crx!>f@P*lKC3t}+0)xyhE71!i z2^eN}Tj4M8Bw&QuYbAYwBm>`>16HyZcrq}?9JQjoM0NsGm7V7yFY%p#in7Z*r}11ywV<~hAW_5*X3+vf>h;rjtA<*s?6SI7atTDf~3{t7<;*edtVlfFVS z0DI+ud9qh{2H>bXI*;}%atLr$cD6=-#Sa1Q$}ZMUzaobL4`okl!LRsXz)Ly6TJ$S& z1n^M~vBrPJj{v^P5!TXQk)wdWa;&xNSNtdtsGMX?dxIPYqLiDgkvI5pAX>S_+UX6F z2gE41TMORcc|e?Um$m2(asr4~?zYC?;3t3tIlxBr7C8gtD2LeKZ}Brgu5yHp^es{d zX+3V5y@Yb*PJR{<}SlWb`pkt;x- za+59c5x)ZTE4SD>eMIVjFUsw)B)Q8*vI4n;)K_7kt;IabjM14pTpuq~U z!~5_iK#LV&C+$OS0y?Z%J6Ruo6VPKN+0puuyMP(1$qwnq?*bOA7CWbY_vk}7ZA_tw#NtYE+B!` zYcCx{UI2-#0ejgX{sKs5joQ@L#$N+jEEfl-VdPifFw4_HFpU2S9AyPK zh=!4FAcqy=fDhx{KrSo7K{|}Q0rFU}4zgkV4Uo@Da-fYM?}2hwlLIn>zXvKRF=>wC~7g zpq1t9hA&SrLxX@5mtV zoE7UR`;HF+FIY*Av@zrx(8p?WM8@!MKtHR+(P<3%4t!y?I||0|@4z6d%TY9ji~_@~ zZby6!9|cBOy^hi`WDNMu8gP`2;bXuUYt)f8pEn7ds_M)~<`a`Z6;&6$(|jHXa#TI} zg82jpa#aKPqWL@s)KCrKm3*TuWj|t9IZRZOX5=_uawTmxW$WsQbRlE84LP8m|Rqf?V7xGx3z3Kp8wvb?f zj;f=4+9IAR=&I`MiYy{jL3dRbSEofh4(Ork=_*)6a6m8B09Vl>o*L+*8sdsCBGf=% z)d*MVA|4m?SB-U*Eh4yJplXsUZ81+9j8biKMHUms3eHXe)Vhz*bdfcVs0o2W(Syad%qDn+rZt^>i1kB<6yT zRRi2bD|wb+r)r2hzLKy6pQ=Wi+Uo#W!+ zEy#(4<6=tdJ24q_25xXfTzfZ7Xs#R zLOgLFA_UCkM0iSlcpJbxPOPWQhu8q-bCNu1zP!y~Ij6}J@g+8c6`U4NCtqGTSjlPk z6!;S1U=^p!Q{>Bw0Bbnip13a&0oHPQJ*B?9NU)AG;3@MZBEfpjs3*;zw-s#VIC~-f z#8$A4C+r2f1(@Hr>eOXg3+ zfiE~oUbH~o4zQ2YU47Dyz6W1LYhS`cq1I91Kr8wnzIf+}h*-cCWhRFI?Q=`9E%QbDd- zfVU`!w+qxz3-QK-h+UwTT7Knt}N zZ>M11esHc@ySE^i*biE%b$N?|c?UpiwQg@bm^c91s`YwHgLxUCz1o1cESSgu9o0s? zY3q51Kvy+qA7njo2y|C-@o`$uI}CcLdHM*}6Nf=BwE!Q{dfpMxM=it$Ur!tXebpj- zr0aP{L4UPaAK7~1C>W@g#ELUq ze?biK2=w9x_={qA?Vt}g#2=3#+Cg7#gugU~_Zal&#`?=*h{s?cH_4wC$9o1wahv>+ zIN})?&28~_isL;8W4P`9f;i$i7{~4M7sc_qz<6%AKORSPfeGARe`y@=1(?Vk@R!9A zFTiB(s6Q>9_ZrOPItL)}#A`5%>k{A;&-)cT%=HWq#1p@QN4Wt3qIh06n8OVTz~hN- zFqaz?cn?-^TLPRCcptz@ZhL?rf%pJc zak~OU3A|pghT9#0ClI}0Ew?v7n!x)A)^P^{WC_Gau%0^_KuhF(23xt#fk-0p8EoUa z1Ue=1zJQOoo`Hfy;tTkg8xSZ;k&#kr)7!`%+dA~3K|Fy=4uBBMGZU% z*3b?S;td1@YiUOar42k7*3pg?${Gk5*3(WB(i(Y6lq1z7L>dVtDi)+g=+ww#QnFOL zP|!#)seF(wp{S9kOj%RiLcEbsrh-Cxh0;bIi_)hCgtA70MJ0xe3TZcZs<5lJa};ue zP=(#KU807v;(3j?9%@yBT#8z)+(JLAAsC}pKv405ZEdy}injJ3RoYsWpomop z7_OmKrJxjTYw-d_i}1}%vNz=XIzP_y`1r6hd(X^Tv(~%bwPt5B0ehFTIaPlHcg5Cl zl&Kvz2v;nQqe`XS#5v$ytD02)O@afC7pYH`+{C%yrK{Fd#Z7_>&KbF%s=kR&0dHS* zr|NDJQ^0{EPgC_baUOULt3S2lCcy(Ij|`>KZsF6g{hXjQ{w-n}c90X1Cb@-A#}08; zrYUX_)3GC*wQ1^G_zbL=lboizMa;lH;cQOR-@<2NpK+9F9k+;?*ykKo8m$%=VCOkC zY5ZD30FEiCPm|Q*Uf4xWYnq~#@B-(R+)q>2;54kS7i;6JO;^|9bFp?#a=NaL zn2X)zY);qL;q$P29A$b(9Wf7kz)_{s8t_HfGfqu9zkyf;jytJOmo(u1SRbb~UC}`J zgL6;rr>h(A0PGc~J6+d61b_ojo~G*?a546V)1ThaK#0LfC`0Kq11`m!xIr0w10ltn zxe*x>1HKd+&s~|JFc3>ISMJ&jwE+*pCUKK9bOs^_bK`E#&>QezjLTJKbQp+WjK@`F z(3N6#6 z_!{soT5G1FjaUQDVY#2FZo}7tm(sd3b#262a3IUmOnn<32i{KW&+KR;;=su)Lz%R8 zJQ3T^4N~&kiA3xmH$o|C$CI!_+?7g2JCTGP;jUGx+wo+qn47HBwG+wMC)~|SeLMah z_8C{H>}V(6!#?M#l(f6}``}f!8YTZO@jf`lrCuqyi*Ew&w6!V~cZp5lJeT`Q^<8{3 zc)6`xsk=*T1_!%5RqF5JA7EFw{mPEJ#0TJnmmwwX9-e_2xItO`dqf7-#Er<3+`}`m z7VgR{#XTYuYvZoXQs2XsSUWd4OLvb@Vt2Wlv-J1yEbJavnbmQR$ig0QRavwL_%`er zwwGymb}FAn>i|2-%N-o2 zChj=J;Y)vfD4Q?YQ%F3si>Y1iql#XU7iX;P3@(TZD{f@zerEX%KZp@e9WT>_RqHM> zetqxu@VpjQm?SzjReeD)zH8o!8w1RX1M)7`xO6P|fpw;#T$+bwkk^ubdystDb{9f~^_yeb*eK0=z(9l^rtH^4R~Q zT{KImSpy_=jZS_*Z!M3>dk{0eYff`#a6ZscT#6UH`+N?R2K2zMlSq}l0@M7imL&$9 z*hfFb`w0psz4jtly)oz}?sxqWTs6>+ zT(j%_LJDQ;h^ks_#CC`dPPHxk2uhpnG`4}~TQQkbhHBa(RnaTWaXbyRR?%+o#l!RwAM?Tc)n--`Njoz;I%Ynj4XQm&yys~MygqUX`*DH3SKUq zwkW%3{n5W{q0Fw#s6qM3YQ~=3^ja|GY(>GG@mnWJ#*93ew91e6S0{YXh96$JQ=~n{ z_O_kO2O@}_#@n&C5wMAQk+bLQ#gVDY?H2TYcoMcaHHK;XS)8}n9&lN9idZKSyWEgM zy(Q0rix;pto319!2NaIPa%V6xQ;_Dc>gDVu=^l8C`p#+-mo%JJ6`0Th%@f&UbM5@Q2mlk>O5J4fw~B zA*(8VqW>cTb~{fJ-Sf5ndpqbZuDOh`{_;Vm(-*D2u;`?CjXmktVqnn>#`0ynsEt$M znf00_HumxEbuWW8Y4)WSQY9Vpyt5D81B5&nj)QxO*zE5%vJudv`kxKZ&k~nGq4zKG zT|ykRI%g?dk>|2OEnJ2<&Fi}C0%-r|k8WNnT^&h$&1edSQVMbK9bfG=Y}+vz)F*;u z*&LvTr&e`d25>{ApdWrXXi#)o#ED`QC*DZDHk%Q<(M9t!M*GB~8>iEyjMMWWh!)6X z(Jm7xWFJz;fTh+vVXK$BK~oqmmpZgC7lf2O4n=b`AVl5qj9py>v&nE}R+Xrxnv7;Q-Sn?fUi4lGr* z@iK)*&E-Z7&-bv9bg_VoV7Y<|?z4-;Zjj_;FInvz6(&u5Y%zYZ%3(;2SNdR@t?UD{DS}hC}lD@du2vVarI{o&?rTn8G z95lgSZi2P;QsFx2QkB-sS%%GsleuU_cA>rY?<4A0D_LFelzm~98^rQ#AAa6vwp6^} zAL4ovp%_q9B6E%!4Kc+*`0br-Iw6E+%=Ecu8YJ(hluR}d# z(`8L2U!*%;6)l~$XbWp1@&v#;gNYN2;(4%xN&%z}^X;O{9mE;;@caSWs*;YGC|sv~2;~JUn}pZ9y>8Qh^L~ zJh0~w6+1yx1B^2Hm__(90E$pjY<=s7HH=t1--zlIexPtVA^^gD@n3eS2Xc|ss6}sb z6I?F&tEVhPKji`Vz_loT_==Qg>RIb(@Rne#-!GcD#rV094tXW)t7A#Lo)W;-6spT9 z0kkMo&lQjYsNzvj_H>9ft54{6z!V#ZE7s?_P=KxU3sM}Q><30OWIri?sZ-uH1IN+x zmLvG|KPSntH8WscC1neY>m9+x;@B+ec0ifr-L4zv>jj7g8t6>TVu#DukfVb27t_Ou zT*hT#1CTC`*7y?W$90xp)SqMsXky94TG7F$tQ30`0swQ|D^_ncg#dQrn0=?7F?RMd zAlxEDxfIf&3PTpW-%)z7#Co0C9&+=){6ihE0*Vs}_z|8n&jX>DbKCXiv$C@F%L<{J zIp5(dSfJbc0{(3Z`O)z=DcNAEP%JinNXZ8CrEb?_3)Sybz$bv&Dv3JAFGmrD=8X*$ zji2nN{{x8v0x#H)g)yx)M~T=;6)jto@@f1{I3-tRZ*YDIJ1=1JBIp-g(mZb$S@h#I z8>!Z!1MHqoR~Tq2)n3^W8Tf)aki*r` zf1e6rNbKGGJh*r|mG@D81oOUK)~&~pvjZ)Mztu-e>(awOVQ&N)ndVema7SxHr8F$2 zhUYV+)6pODQNK)CUvU0>%Vrf(Lf?Z^%|&M=>ktk*Ivv_;tJk2_@#eSil9nxWtUCmi zSKa4Q;Exf;P2#i&OoT#AWO1R$GXa1%iD}Jp=lS{dxp_rYu}Lol*@fK==$P z4$<4pydpG8^a`f)-L63+#e^kZkZ{jsKm4~jpTzUiSeNfQh`@Lz&(c<~zuG$)I^l34 zjma<`C;1g)`72ZGpCO?5xd}gku(Rw$pQES+6E(zLZR|)!QveysClAgrvq@6+5fC?|Af~v2|5E#2U0YI_rAZbRjc{~+L z?x69MIEa2jIB6n!F+Gv*p&ZXu!uO06qa+OcIlh8M&a`XMj)I;?w2C|^QZ(kxtPo1b zUgLR)QcrRsjR2IUP%Y;##;ZOCqzpW~F0p{JeU$7$+xJ*<{s!z%T<1;LZE&W0qyt2r1a_q50LuV zjU26iAhpE$N331eBp*ZffKpQvwAVck94Gd4h7ZDdmY{kCFHQ z@^DZWc{0oKbHra!dM9*f46{4Z-~@rjhT$S!>{!i{_gjou)c%cx>_T%z0E#p<-;kxq zEwA=0wI^wp7*#SaZMPJsIGGi5%mI2Cs6bs?VJx@ktU1=Q^PLDgv0#^`;47x^o zlZl%{VAoVQT(+MMdIAMQ?Ns)dT!bCqV)rWhY2S?Gcg^MC$zsOqw;?vi$znj*-ip0{ z5OMe_5H5Xh&)0=shwG{Zu|zs4Sx=gksZIsWcbWZ9w$l5mD?iS~Uws zlF5sLZ0Y*O?KRL^g28mOT=^)&f+)JotD|-#KaU2rrsC(RS!f!gpv5u7ATM>o*1tt+d-@WF0*V(?>mdBiU`Uon?3p^b9>Zq zy3TUkWhLG0?*SUoPGW^(%IVEOBE+donPoBEmL&XQkgg>m zc@8o3Z$m)K28a{XG6{y6vR^FNJ~v z42}JY0v}EO0Cu8*VLXM_>GNjoX$oC=V+(k`LyEnWULhRk$Hgi~muO9t;B3g9jV7l7 z%yEOT@By-HSUdp5CxC@k4$L3Oc>a>n2$J^O11KdHJIZDWz3*Gzp1;68qa5CDT6^Fc z8WsDe_nbN4qwM8h&J0WxRXqce%)HQP?|2Eiw-}ZsU=_*f^k+eok{5f2iD8LLn+FJ-s{{PpKK0uK_VH zTgv|F)(6B#S)_g2ILcJF(4n4VwG{%Y%F6g^NvJ1c;eXX3{;T}xJIJZ=na| zCNxJ9HyJl_sZz=>OX8i$s0HvYdd~@V$5b?95I3wEV|Gm;Y2uK=73?yBT^@=dHS)_#=k7Hs~7B9*9W zg6#ZnT4^7u{J?}NJpTKU!csrjq=AK{3-neTDqL*ZW9(+U=|lq#`%PHhBdEGk76=h_ zPJLAD_R)yyl<;h%9y%9nDIBxkSpfkJ;FKw5Xu?UNDRlg?lV#^d0oYA6h!^sI+S7=97iY23X8JYPoI9%Z)iQ7#MAx1%Yz0hPBILcNO~G`q%1( z|Dbh$j~~Yxmfp{C2=!V zPCI0+ZhUh)eB}b{MeSw(IC}!5FyhvuN-7k=k}b&Q0vyA{i2T9rg6&i;Rw;nHIh!o6 z^R<@O1(}xD#pf`xeh+tle1Ou9g_emoq5eVO&~!gm!e|S(1mc)7EmKP5TAwXNH8p5_ zZi5mMI^o^o*dB}PIHQibb^+T(%^t7xXZ}X^0$&G-h|jYkao!bRU*04t4~Dx5zWp8~ pAi>@Im$4^kpJn7c6z0BW(dL<6VgT=!SiyIu{vZ4u>U{tJ literal 0 HcmV?d00001 diff --git a/python/triton/language/libdevice.py b/python/triton/language/libdevice.py new file mode 100644 index 000000000..226480fa2 --- /dev/null +++ b/python/triton/language/libdevice.py @@ -0,0 +1,1661 @@ +import os + +from . import core, extern + +LIBDEVICE_PATH = os.path.dirname( + os.path.abspath(__file__)) + "/libdevice.10.bc" + + +@extern.extern +def clz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def popc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), + }, _builder) + + +@extern.extern +def min(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def max(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def mulhi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def mul64hi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), + }, _builder) + + +@extern.extern +def mul24(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def brev(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), + }, _builder) + + +@extern.extern +def sad(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def abs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def floor(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcp64h(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rsqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ceil(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def trunc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def exp2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def saturatef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmaf_ieee_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ieee_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fast_fdividef(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fdiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ddiv_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ddiv_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def drcp_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsqrt_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sqrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fadd_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fadd_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fmul_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def double2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), + (core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def float2int_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2int_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2uint_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), + }, _builder) + + +@extern.extern +def int2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def int2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def hiloint2double(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double2loint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def double2hiint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def float2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ll_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), + }, _builder) + + +@extern.extern +def double2ull_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), + }, _builder) + + +@extern.extern +def ll2float_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2float_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), + (core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ll2double_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rz(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_rd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ll2double_ru(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), + (core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def int_as_float(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), + (core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def float_as_int(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), + }, _builder) + + +@extern.extern +def float_as_uint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), + }, _builder) + + +@extern.extern +def longlong_as_double(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def double_as_longlong(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fast_sinf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_cosf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log2f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_logf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_expf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_tanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_exp10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fast_log10f(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def pow(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def rhadd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), + }, _builder) + + +@extern.extern +def fsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def fsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def frsqrt_rn(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, _builder) + + +@extern.extern +def ffs(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), + }, _builder) + + +@extern.extern +def rint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llrint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), + }, _builder) + + +@extern.extern +def nearbyint(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def isnanf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def signbitf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + }, _builder) + + +@extern.extern +def copysign(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def finitef(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinff(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + }, _builder) + + +@extern.extern +def nextafter(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinpi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cospi(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log2(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def exp10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def sinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan2(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atan(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asin(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acos(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log10(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def log1p(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def acosh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def asinh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def atanh(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def expm1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def hypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rhypot(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def rcbrt(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def j1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def y1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def yn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def jn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i0(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def cyl_bessel_i1(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfc(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcx(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def erfcinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdfinv(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def normcdf(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def lgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ldexp(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def scalbn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fmod(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def remainder(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def fma(arg0, arg1, arg2, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def powi(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def tgamma(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def round(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def llround(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), + }, _builder) + + +@extern.extern +def fdim(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def ilogb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), + }, _builder) + + +@extern.extern +def logb(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def signbitd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isfinited(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isinfd(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), + }, _builder) + + +@extern.extern +def isnand(arg0, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, ], + {(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), + }, _builder) + + +@extern.extern +def dsub_rn(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rz(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_ru(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), + }, _builder) + + +@extern.extern +def dsub_rd(arg0, arg1, _builder=None): + return extern.elementwise("libdevice", LIBDEVICE_PATH, [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), + }, _builder) diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py new file mode 100644 index 000000000..6d0a04e8e --- /dev/null +++ b/python/triton/tools/build_extern.py @@ -0,0 +1,340 @@ +import argparse +import subprocess +from abc import ABC, abstractmethod + + +class Symbol: + def __init__(self, name: str, op_name: str, ret_type: str, arg_names: list, arg_types: list) -> None: + ''' + A symbol is a function declaration. + + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = arg_names + self._arg_types = arg_types + + @property + def name(self): + return self._name + + @property + def op_name(self): + return self._op_name + + @property + def ret_type(self): + return self._ret_type + + @property + def arg_names(self): + return self._arg_names + + @property + def arg_types(self): + return self._arg_types + + +def convert_type(type_str): + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str): + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + def __init__(self, name: str, path: str, format: bool = True, grouping: bool = True) -> None: + ''' + Abstract class for extern library. + + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = True + self._grouping = grouping + + @property + def name(self): + return self._name + + @property + def path(self): + return self._path + + @property + def symbols(self): + return self._symbols + + @property + def grouping(self): + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file): + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir): + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], + stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. + + :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + + def _extract_symbol(self, line): + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self): + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + # The following cases are grouped together: + # op_name, op_name + for symbol in self._symbols.values(): + op_name = symbol.op_name + if "max" in op_name: + op_name = "max" + elif "min" in op_name: + op_name = "min" + elif "abs" in op_name: + op_name = "abs" + elif "pow" in op_name and "fast" in op_name: + op_name = "pow" + elif "round" in op_name: + if "llround" in op_name: + op_name = "llround" + else: + op_name = "round" + elif "rint" in op_name: + if "llrint" in op_name: + op_name = "llrint" + else: + op_name = "rint" + elif op_name.startswith("ull"): + if "2" not in op_name: + # e.g., ullmax->max + op_name = op_name[3:] + else: + # e.g., ull2double->ll2double + op_name = op_name[1:] + elif op_name.startswith("u"): + if "2" not in op_name: + # e.g., uhadd->hadd + op_name = op_name[1:] + else: + # e.g., uint2double_rn->int2double_rn + op_name = op_name[1:] + elif op_name.startswith("ll"): + if "2" not in op_name: + # e.g., llmax->max + op_name = op_name[2:] + elif op_name.endswith("ll"): + op_name = op_name[:-2] + elif op_name.endswith("f"): + op_name = op_name[:-1] + if op_name in symbol_set: + # Update op_name only if there's an existing symbol + symbol._op_name = op_name + else: + op_name = symbol._op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file): + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self): + # Generate python functions in the following format: + # @extern.extern + # def (, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return extern.dispatch("libdevice", , , , _builder) + import_str = "from . import core, extern\n" + import_str += "import os\n" + header_str = "LIBDEVICE_PATH = os.path.dirname(os.path.abspath(__file__)) + \"/libdevice.10.bc\"\n" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@extern.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f"core.dtype(\"{arg_type}\")," + ret_type = f"core.dtype(\"{symbol.ret_type}\")" + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += ", _builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + def __init__(self, path): + ''' + Invoke llvm-dis to disassemble the given file. + + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path): + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], + stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self): + return self._ll_file + + @property + def path(self): + return self._path + + +extern_libs = ["libdevice"] + + +def build(llvm_dis_path, lib_path, lib_name, output_dir): + ''' + Interface function to build the library file. + + :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-llvm", dest="llvm_dis_path", help="path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="name of the extern library") + parser.add_argument("-o", dest="output_dir", help="output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/python/tutorials/07-libdevice-function.py b/python/tutorials/07-libdevice-function.py new file mode 100644 index 000000000..bb5f7b26d --- /dev/null +++ b/python/tutorials/07-libdevice-function.py @@ -0,0 +1,74 @@ +""" +Libdevice function +=============== +Triton can invoke a custom function from an external library. +In this example, we will use the `libdevice` library to apply `asin` on a tensor. +Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html regarding the semantics of all available libdevice functions. + +In `trition/language/libdevice.py`, we try to aggregate functions with the same computation but different data types together. +For example, both `__nv_asin` and `__nvasinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. +Using triton, you can simply call `tl.libdevice.asinf`. +triton automatically selects the correct underlying device function to invoke based on input and output types. +""" + +# %% +# asin Kernel +# -------------------------- + +import torch + +import triton +import triton.language as tl + + +@triton.jit +def asin_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = tl.libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) + +# %% +# Using the default libdevice library path +# -------------------------- +# We can use the default libdevice library path encoded in `triton/language/libdevice.py` + + +torch.manual_seed(0) +size = 98432 +x = torch.rand(size, device='cuda') +output_triton = torch.zeros(size, device='cuda') +output_torch = torch.asin(x) +assert x.is_cuda and output_triton.is_cuda +n_elements = output_torch.numel() +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +) + +# %% +# Customize the libdevice library path +# -------------------------- +# We can also customize the libdevice library path by passing the path to the `libdevice` library to the `asin` kernel. + +output_triton = torch.empty_like(x) +asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, + extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) +print(output_torch) +print(output_triton) +print( + f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}' +)