diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 289bd7874..42bf34e9a 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -6,6 +6,7 @@ #include #include #include "triton/tools/graph.h" +#include "triton/codegen/target.h" namespace triton{ @@ -113,7 +114,8 @@ struct scanline_layout: public data_layout { const std::vector& axes, const std::vector& shape, const std::vector &values, - analysis::align* align); + analysis::align* align, + target* tgt); void accept(layout_visitor* vst) { vst->visit_layout_scanline(this); } // accessor int mts(size_t k) { return mts_.at(k); } @@ -172,7 +174,7 @@ private: public: // constructor - layouts(analysis::axes *axes, analysis::align *align, size_t num_warps); + layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt); // accessors unsigned layout_of(ir::value *value) const { return groups_.at(value); } @@ -190,6 +192,7 @@ private: analysis::axes* axes_; analysis::align* align_; size_t num_warps_; + target* tgt_; tools::graph graph_; std::map groups_; std::map> values_; diff --git a/include/triton/driver/buffer.h b/include/triton/driver/buffer.h index 282f98bfb..3817ca4dd 100755 --- a/include/triton/driver/buffer.h +++ b/include/triton/driver/buffer.h @@ -19,6 +19,7 @@ public: buffer(driver::context* ctx, size_t size, CUdeviceptr cl, bool take_ownership); buffer(driver::context* ctx, size_t size, cl_mem cl, bool take_ownership); buffer(driver::context* ctx, size_t size, host_buffer_t hst, bool take_ownership); + uintptr_t addr_as_uintptr_t(); static buffer* create(driver::context* ctx, size_t size); driver::context* context(); size_t size(); diff --git a/include/triton/driver/handle.h b/include/triton/driver/handle.h index eac14dca2..2e512ddde 100755 --- a/include/triton/driver/handle.h +++ b/include/triton/driver/handle.h @@ -9,6 +9,15 @@ #include #include #include "triton/driver/dispatch.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "triton/tools/thread_pool.h" namespace llvm { @@ -42,13 +51,21 @@ struct host_context_t{ }; struct host_stream_t{ - + std::shared_ptr pool; }; struct host_module_t{ std::string error; llvm::ExecutionEngine* engine; std::map functions; + void(*fn)(char**, int32_t, int32_t, int32_t); + llvm::orc::ExecutionSession* ES; + llvm::orc::RTDyldObjectLinkingLayer* ObjectLayer; + llvm::orc::IRCompileLayer* CompileLayer; + llvm::DataLayout* DL; + llvm::orc::MangleAndInterner* Mangle; + llvm::orc::ThreadSafeContext* Ctx; + llvm::orc::JITDylib *MainJD; }; struct host_function_t{ diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index b7c5b7e62..7b70fd584 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -32,7 +32,7 @@ public: driver::context* context() const; // methods virtual void synchronize() = 0; - virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, event *event = NULL, void **extra = NULL) = 0; + virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, event *event = NULL, void **args = NULL, size_t args_size = 0) = 0; virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0; virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0; // template helpers @@ -53,7 +53,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **args, size_t args_size); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -66,7 +66,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **args, size_t args_size); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -80,7 +80,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **args, size_t args_size); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; diff --git a/include/triton/tools/thread_pool.h b/include/triton/tools/thread_pool.h index 143ef30ab..fbcf2b684 100644 --- a/include/triton/tools/thread_pool.h +++ b/include/triton/tools/thread_pool.h @@ -15,11 +15,65 @@ class ThreadPool { public: - ThreadPool(size_t); + ThreadPool(size_t threads) + : stop(false) { + for(size_t i = 0;i < threads;++i) + workers.emplace_back( + [this] { + for(;;){ + std::function task; + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this]{ return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + task(); + } + } + ); + } + + template auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - ~ThreadPool(); + -> std::future::type> + { + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if(stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task](){ (*task)(); }); + } + condition.notify_one(); + return res; + } + + + ~ThreadPool() { + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for(std::thread &worker: workers) + worker.join(); + } + + private: // need to keep track of threads so we can join them std::vector< std::thread > workers; @@ -32,69 +86,5 @@ private: bool stop; }; -// the constructor just launches some amount of workers -inline ThreadPool::ThreadPool(size_t threads) - : stop(false) -{ - for(size_t i = 0;i task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait(lock, - [this]{ return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - } - ); -} - -// add new work item to the pool -template -auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> -{ - using return_type = typename std::result_of::type; - - auto task = std::make_shared< std::packaged_task >( - std::bind(std::forward(f), std::forward(args)...) - ); - - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); - - // don't allow enqueueing after stopping the pool - if(stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); - - tasks.emplace([task](){ (*task)(); }); - } - condition.notify_one(); - return res; -} - -// the destructor joins all threads -inline ThreadPool::~ThreadPool() -{ - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for(std::thread &worker: workers) - worker.join(); -} #endif diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 14b207eec..5397642bc 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -168,9 +168,9 @@ scanline_layout::scanline_layout(size_t num_warps, const std::vector& axes, const std::vector& shape, const std::vector &values, - analysis::align* align): data_layout(SCANLINE, axes, shape, values, align){ + analysis::align* align, target *tgt): data_layout(SCANLINE, axes, shape, values, align){ unsigned size = std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); - unsigned num_threads = num_warps * 32; + unsigned num_threads = tgt->is_gpu() ? num_warps * 32 : 1; nts_.resize(shape_.size()); mts_.resize(shape_.size()); bool is_dot = std::any_of(values.begin(), values.end(), @@ -324,8 +324,8 @@ shared_layout::shared_layout(const data_layout *arg, * ---- Layouts Inference Pass ---- * * -------------------------------- */ -layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps) - : axes_(axes), align_(align), num_warps_(num_warps) { } +layouts::layouts(analysis::axes *axes, analysis::align *align, size_t num_warps, target* tgt) + : axes_(axes), align_(align), num_warps_(num_warps), tgt_(tgt){ } void layouts::connect(ir::value *x, ir::value *y) { @@ -382,7 +382,7 @@ void layouts::create(size_t id, const std::vector& values) { layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); } else - layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_); + layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); } void layouts::run(ir::module &mod) { diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 36f4f1aa6..3dd61cfc3 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -488,41 +488,47 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* st) { ptr = gep->getPointerOperand(); } ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1)); - // asm argument type - std::vector arg_ty = {pred->getType(), ptr->getType()}; - for(int v = 0; v < vector_size; v++) - arg_ty.push_back(ty->getScalarType()); - // asm function type - FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false); - // asm string - std::string asm_str; - asm_str += "@$0 st.global"; - if(vector_size > 1) - asm_str += ".v" + std::to_string(vector_size); - asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],"; - if(vector_size > 1) - asm_str += "{"; - for(int v = 0; v < vector_size; v++){ - if(v > 0) - asm_str += ", "; - asm_str += "$" + std::to_string(2 + v); + if(tgt_->is_gpu()){ + // asm argument type + std::vector arg_ty = {pred->getType(), ptr->getType()}; + for(int v = 0; v < vector_size; v++) + arg_ty.push_back(ty->getScalarType()); + // asm function type + FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false); + // asm string + std::string asm_str; + asm_str += "@$0 st.global"; + if(vector_size > 1) + asm_str += ".v" + std::to_string(vector_size); + asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],"; + if(vector_size > 1) + asm_str += "{"; + for(int v = 0; v < vector_size; v++){ + if(v > 0) + asm_str += ", "; + asm_str += "$" + std::to_string(2 + v); + } + if(vector_size > 1) + asm_str += "}"; + asm_str += ";"; + // asm constraint + std::string constraint = "b,l"; + for(int v = 0; v < vector_size; v++){ + constraint += ","; + constraint += (nbits == 32 ? "r" : "h"); + } + // create inline asm + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); + // call asm + std::vector args = {pred, ptr}; + for(int v = 0; v < vector_size; v++) + args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v))); + builder_->CreateCall(iasm, args); } - if(vector_size > 1) - asm_str += "}"; - asm_str += ";"; - // asm constraint - std::string constraint = "b,l"; - for(int v = 0; v < vector_size; v++){ - constraint += ","; - constraint += (nbits == 32 ? "r" : "h"); + else{ + builder_->CreateMaskedStore(elt, ptr, alignment, builder_->CreateVectorSplat(vector_size, pred)); } - // create inline asm - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); - // call asm - std::vector args = {pred, ptr}; - for(int v = 0; v < vector_size; v++) - args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v))); - builder_->CreateCall(iasm, args); + } }); } @@ -1302,17 +1308,22 @@ void generator::visit_function(ir::function* fn) { for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; for(ir::attribute attr: attr_pair.second) - if(attr.is_llvm_attr()) - ret->addAttribute(id, llvm_attr(ctx, attr)); + if(attr.is_llvm_attr()){ + llvm::Attribute llattr = llvm_attr(ctx, attr); + if(llattr.getKindAsEnum() != llvm::Attribute::None) + ret->addAttribute(id, llvm_attr(ctx, attr)); + } } // set metadata - tgt_->set_kernel(*builder_, ctx, mod_, ret); - Metadata *md_args[] = { - ValueAsMetadata::get(ret), - MDString::get(ctx, "maxntidx"), - ValueAsMetadata::get(builder_->getInt32(num_warps_*32)) - }; - mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); + if(tgt_->is_gpu()){ + tgt_->set_kernel(*builder_, ctx, mod_, ret); + Metadata *md_args[] = { + ValueAsMetadata::get(ret), + MDString::get(ctx, "maxntidx"), + ValueAsMetadata::get(builder_->getInt32(num_warps_*32)) + }; + mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); + } // set arguments for(unsigned i = 0; i < fn->args().size(); i++) vmap_[fn->args()[i]] = &*(ret->arg_begin() + i); diff --git a/lib/driver/backend.cc b/lib/driver/backend.cc index 2c64936ef..036697fba 100755 --- a/lib/driver/backend.cc +++ b/lib/driver/backend.cc @@ -47,6 +47,12 @@ void backend::platforms::init() { if(dispatch::cuinit()){ cache_.push_back(new cu_platform()); } + //if host should be added + bool host_visible = true; + if(host_visible){ + cache_.push_back(new host_platform()); + } + // //if OpenCL is here // if(dispatch::clinit()){ // cl_uint num_platforms; @@ -56,11 +62,7 @@ void backend::platforms::init() { // for(cl_platform_id id: ids) // cache_.push_back(new cl_platform(id)); // } -// //if host is here -// bool host_visible = true; -// if(host_visible){ -// cache_.push_back(new host_platform()); -// } + if(cache_.empty()) throw std::runtime_error("Triton: No backend available. Make sure CUDA is available in your library path"); } diff --git a/lib/driver/buffer.cc b/lib/driver/buffer.cc index 1f499e5f3..f188d7483 100755 --- a/lib/driver/buffer.cc +++ b/lib/driver/buffer.cc @@ -53,6 +53,14 @@ size_t buffer::size() { return size_; } +uintptr_t buffer::addr_as_uintptr_t() { + switch(backend_){ + case CUDA: return *cu_; + case Host: return (uintptr_t)hst_->data; + default: return 0; + } +} + buffer* buffer::create(driver::context* ctx, size_t size) { switch(ctx->backend()){ diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 99a1e5405..57f206c8c 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -135,12 +135,6 @@ void module::compile_llvm_module(std::unique_ptr module, const std host_module::host_module(driver::context * context, std::unique_ptr src): module(context, host_module_t(), true) { init_llvm(); - // host info -// std::string triple = llvm::sys::getDefaultTargetTriple(); -// std::string cpu = llvm::sys::getHostCPUName(); -// llvm::SmallVector buffer; -// module::compile_llvm_module(src, triple, cpu, "", buffer, "", Assembly); - // create kernel wrapper llvm::LLVMContext &ctx = src->getContext(); llvm::Type *void_ty = llvm::Type::getVoidTy(ctx); @@ -148,37 +142,72 @@ host_module::host_module(driver::context * context, std::unique_ptr tys = {args_ty, int32_ty, int32_ty, int32_ty}; llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false); - llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src); - llvm::Function* fn = src->getFunction("matmul"); + llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "_main", &*src); + llvm::Function* fn = &*src->getFunctionList().begin(); llvm::FunctionType *fn_ty = fn->getFunctionType(); std::vector fn_args(fn_ty->getNumParams()); std::vector ptrs(fn_args.size() - 3); llvm::BasicBlock* entry = llvm::BasicBlock::Create(ctx, "entry", main); llvm::IRBuilder<> ir_builder(ctx); ir_builder.SetInsertPoint(entry); - for(unsigned i = 0; i < ptrs.size(); i++) - ptrs[i] = ir_builder.CreateGEP(main->arg_begin(), ir_builder.getInt32(i)); + auto get_size = [](llvm::Type* ty) { return ty->isPointerTy() ? sizeof(char*) : ty->getPrimitiveSizeInBits() / 8; }; + llvm::Value* base = main->arg_begin(); + llvm::Value* args_base = ir_builder.CreateBitCast(base, base->getType()->getPointerElementType()); + + size_t offset = 0; for(unsigned i = 0; i < ptrs.size(); i++){ - llvm::Value* addr = ir_builder.CreateBitCast(ir_builder.CreateLoad(ptrs[i]), fn_ty->getParamType(i)->getPointerTo()); - fn_args[i] = ir_builder.CreateLoad(addr); + ptrs[i] = ir_builder.CreateGEP(args_base, ir_builder.getInt32(offset)); + size_t nbytes = get_size(fn_ty->getParamType(i)); + offset += nbytes; + if(i < ptrs.size() - 1){ + size_t np1bytes = get_size(fn_ty->getParamType(i+1)); + offset = (offset + np1bytes - 1) / np1bytes * np1bytes; + } } + for(unsigned i = 0; i < ptrs.size(); i++) + ptrs[i] = ir_builder.CreateBitCast(ptrs[i], fn_ty->getParamType(i)->getPointerTo()); + for(unsigned i = 0; i < ptrs.size(); i++) + fn_args[i] = ir_builder.CreateLoad(ptrs[i]); + fn_args[fn_args.size() - 3] = main->arg_begin() + 1; fn_args[fn_args.size() - 2] = main->arg_begin() + 2; fn_args[fn_args.size() - 1] = main->arg_begin() + 3; ir_builder.CreateCall(fn, fn_args); ir_builder.CreateRetVoid(); +// llvm::legacy::PassManager pm; +// pm.add(llvm::createPrintModulePass(llvm::outs())); +// pm.add(llvm::createVerifierPass()); +// pm.run(*src); - // create execution engine +// create execution engine for(llvm::Function& fn: src->functions()) hst_->functions[fn.getName()] = &fn; + +// llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); +// auto DL = JTMB.getDefaultDataLayoutForTarget(); +// auto CIRC = std::unique_ptr(new llvm::orc::ConcurrentIRCompiler(JTMB)); +// hst_->ES = new llvm::orc::ExecutionSession(); +// hst_->ObjectLayer = new llvm::orc::RTDyldObjectLinkingLayer(*hst_->ES, []() { return std::unique_ptr(new llvm::SectionMemoryManager()); }); +// hst_->CompileLayer = new llvm::orc::IRCompileLayer(*hst_->ES, *hst_->ObjectLayer, *CIRC); +// hst_->DL = new llvm::DataLayout(std::move(*DL)); +// hst_->Mangle = new llvm::orc::MangleAndInterner(*hst_->ES, *hst_->DL); +// hst_->Ctx = new llvm::orc::ThreadSafeContext(std::unique_ptr(new llvm::LLVMContext())); +// hst_->MainJD = &hst_->ES->createJITDylib("
"); +// hst_->MainJD->setGenerator(llvm::cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( +// hst_->DL->getGlobalPrefix()))); +// llvm::cantFail(hst_->CompileLayer->add(*hst_->MainJD, llvm::orc::ThreadSafeModule(std::move(src), *hst_->Ctx))); +// hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->ES->lookup({hst_->MainJD}, (*hst_->Mangle)("_main"))->getAddress()); + + + llvm::EngineBuilder builder(std::move(src)); builder.setErrorStr(&hst_->error); builder.setMCJITMemoryManager(llvm::make_unique()); builder.setOptLevel(llvm::CodeGenOpt::Aggressive); builder.setEngineKind(llvm::EngineKind::JIT); - builder.setUseOrcMCJITReplacement(true); hst_->engine = builder.create(); + hst_->fn = (void(*)(char**, int32_t, int32_t, int32_t))(hst_->engine->getFunctionAddress("_main")); } std::unique_ptr host_module::symbol(const char *name) const { diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc index 8c397ca1e..f1501bb2c 100755 --- a/lib/driver/stream.cc +++ b/lib/driver/stream.cc @@ -72,21 +72,20 @@ driver::context* stream::context() const { /* ------------------------ */ host_stream::host_stream(driver::context *ctx): stream(ctx, host_stream_t(), true) { - + hst_->pool.reset(new ThreadPool(8)); } void host_stream::synchronize() { - + hst_->pool.reset(new ThreadPool(8)); } -void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **extra) { - driver::host_kernel* hst_kernel = (host_kernel*)kernel; - llvm::ExecutionEngine* engine = kernel->module()->hst()->engine; - void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main"); +void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **args, size_t args_size) { + ThreadPool pool(4); + auto hst = kernel->module()->hst(); for(size_t i = 0; i < grid[0]; i++) for(size_t j = 0; j < grid[1]; j++) for(size_t k = 0; k < grid[2]; k++) - fn((char**)hst_kernel->params().data(), int32_t(i), int32_t(j), int32_t(k)); + hst_->pool->enqueue(hst->fn, (char**)args, int32_t(i), int32_t(j), int32_t(k)); } void host_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) { @@ -112,7 +111,7 @@ void cl_stream::synchronize() { check(dispatch::clFinish(*cl_)); } -void cl_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **extra) { +void cl_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **args, size_t args_size) { std::array global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]}; check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL)); } @@ -149,11 +148,16 @@ void cu_stream::synchronize() { dispatch::cuStreamSynchronize(*cu_); } -void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void** extra) { +void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void** args, size_t args_size) { cu_context::context_switcher ctx_switch(*ctx_); + void *config[] = { + CU_LAUNCH_PARAM_BUFFER_POINTER, args, + CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, + CU_LAUNCH_PARAM_END + }; if(event) dispatch::cuEventRecord(event->cu()->first, *cu_); - dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra); + dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, config); if(event) dispatch::cuEventRecord(event->cu()->second, *cu_); } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index a176927ee..db514a4a8 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -163,11 +163,6 @@ function::caller::caller(ir::function *ir, void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const { - void *config[] = { - CU_LAUNCH_PARAM_BUFFER_POINTER, args, - CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, - CU_LAUNCH_PARAM_END - }; // set grid if(_grid.size() > 3) throw std::runtime_error("grid size must be no greater than 3"); @@ -175,7 +170,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; // enqueue - stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, config); + stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, args, args_size); } @@ -203,7 +198,7 @@ std::unique_ptr function::make_bin(ir::module &module, codegen::analysis::align align; codegen::analysis::axes axes; codegen::transform::disassociate disassociate; - codegen::analysis::layouts layouts(&axes, &align, opt.num_warps); + codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); codegen::analysis::liveness liveness(&layouts); codegen::analysis::allocation allocation(&liveness); codegen::transform::membar barriers(&liveness, &layouts, &allocation); @@ -220,15 +215,18 @@ std::unique_ptr function::make_bin(ir::module &module, peephole.run(module); dce.run(module); align.run(module); - cts.run(module); + if(target->is_gpu()) + cts.run(module); axes.run(module); layouts.run(module); coalesce.run(module); dce.run(module); align.run(module); dce.run(module); - reassociate.run(module); - cts.run(module); + if(target->is_gpu()){ + reassociate.run(module); + cts.run(module); + } peephole.run(module); dce.run(module); align.run(module); @@ -260,11 +258,11 @@ function::caller* function::make(driver::stream *stream, options_t opt) { auto ir = make_ir(parser); // triton-ir -> binary std::unique_ptr bin; - try{ +// try{ bin = make_bin(*ir, stream->context(), opt); - }catch(const std::runtime_error&){ - return nullptr; - } +// }catch(const std::runtime_error&){ +// return nullptr; +// } // create callable ir::function *tmp = ir->get_function_list()[0]; caller* ret = new caller(tmp, std::move(bin), opt); diff --git a/python/setup.py b/python/setup.py index 55aa03281..d4c06b305 100644 --- a/python/setup.py +++ b/python/setup.py @@ -74,7 +74,7 @@ class CMakeBuild(build_ext): '-DLLVM_CONFIG=' + find_llvm()] # configuration cfg = 'Debug' if self.debug else 'Release' - cfg = 'Release' + cfg = 'Debug' build_args = ['--config', cfg] if platform.system() == "Windows": diff --git a/python/src/bindings.cc b/python/src/bindings.cc index b14f8cad4..2ed4aba46 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -15,7 +15,7 @@ using namespace triton; namespace rt = triton::runtime; -typedef std::pair map_key_t; +typedef std::pair map_key_t; std::map> id_grid_map; std::map> id_fn_map; std::map fp64scalar_map; diff --git a/python/src/launch.cc b/python/src/launch.cc index 24d7da6f1..75fd0e66a 100644 --- a/python/src/launch.cc +++ b/python/src/launch.cc @@ -8,22 +8,31 @@ #include "torch/script.h" #include "ATen/cuda/CUDAContext.h" -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); - namespace rt = triton::runtime; namespace drv = triton::driver; -typedef std::pair map_key_t; +typedef std::pair map_key_t; extern std::map> id_grid_map; extern std::map> id_fn_map; +std::shared_ptr host_device; +std::shared_ptr host_context; +std::shared_ptr host_stream; void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){ - CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream(); - triton::driver::cu_stream stream(custream, false); - triton::driver::context* ctx = stream.context(); - (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream); + if(dev_id == -1){ + if(!host_stream){ + host_device.reset(new drv::host_device()); + host_context.reset(drv::context::create(&*host_device)); + host_stream.reset(drv::stream::create(&*host_context)); + } + (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream); + } + else{ + CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream(); + triton::driver::cu_stream stream(custream, false); + triton::driver::context* ctx = stream.context(); + (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream); + } } diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 51cf08499..9fee937a6 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -67,6 +67,7 @@ class kernel: for x in args: if isinstance(x, torch.Tensor): device = x.device.index + device = -1 if device is None else device break # lazily register function for device if device not in self.registered: diff --git a/tests/common/dot.h b/tests/common/dot.h index 9c3f21091..35f24395c 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -14,9 +14,9 @@ struct dot_arg_t{ - CUdeviceptr a; - CUdeviceptr b; - CUdeviceptr c; + uintptr_t a; + uintptr_t b; + uintptr_t c; float alpha; int M; int N; @@ -24,7 +24,7 @@ struct dot_arg_t{ int lda; int ldb; int ldc; - CUdeviceptr locks; + uintptr_t locks; }; template @@ -98,7 +98,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, auto da = std::shared_ptr(drv::buffer::create(context, M*K*dt_nbytes)); auto db = std::shared_ptr(drv::buffer::create(context, K*N*dt_nbytes)); auto dlocks = std::shared_ptr(drv::buffer::create(context, 1024*1024*2*4)); - ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size()); +// ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size()); // macros rt::function::options_space_t opt; @@ -127,17 +127,17 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, opt.num_warps = {nwarp}; } if(mode == BENCH) { - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); - opt.defines.push_back({"TK", {"16"}}); + opt.defines.push_back({"TM", {"64", "128"}}); + opt.defines.push_back({"TN", {"64", "128"}}); + opt.defines.push_back({"TK", {"8"}}); opt.defines.push_back({"TZ", {"1"}}); opt.num_warps = {4}; } // kernels rt::function function(src::dot, opt); - dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(), - 1, M, N, K, lda, ldb, ldc, *dlocks->cu()}; + dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(), + 1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()}; auto grid = [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")),