diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 18032f247..60ba87318 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -135,7 +135,7 @@ int main() { // benchmark a given matrix multiplication kernel - auto benchmark = [&](triton::driver::cu_kernel kernel, + auto benchmark = [&](triton::driver::kernel* kernel, triton::jit::launch_information info) { // launch info unsigned TM = info.global_range_size[0]; @@ -153,20 +153,20 @@ int main() { unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk; int32_t bound = std::max(1, std::max(K - last_safe_a, K - last_safe_b)); // set argument - kernel.setArg(0, da); - kernel.setArg(1, db); - kernel.setArg(2, dc); - kernel.setArg(3, M); - kernel.setArg(4, N); - kernel.setArg(5, K); - kernel.setArg(6, bound); + kernel->setArg(0, da); + kernel->setArg(1, db); + kernel->setArg(2, dc); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, bound); // dry run stream.enqueue(kernel, grid, {nthreads, 1, 1}); stream.synchronize(); // benchmark double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});}, [&](){ stream.synchronize(); }, - context->device()); + (triton::driver::cu_device&)*context->device()); ts = ts * 1e-9; double tflops = 2*M*N*K / ts * 1e-12; return tflops; @@ -186,7 +186,7 @@ int main() { jit.add_module(src, params); triton::driver::cu_kernel kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); - std::cout << benchmark(kernel, info) << std::endl; + std::cout << benchmark(&kernel, info) << std::endl; stream.read(dc, true, 0, hc); simple_gemm(rc, ha, hb, M, N, K); for(size_t i = 0; i < M*N; i++) diff --git a/include/triton/driver/backend.h b/include/triton/driver/backend.h index d830df391..a91fa7c7a 100755 --- a/include/triton/driver/backend.h +++ b/include/triton/driver/backend.h @@ -28,6 +28,10 @@ #include #include "triton/driver/context.h" +namespace llvm +{ +class Module; +} namespace triton { @@ -81,7 +85,7 @@ struct backend public: static void release(); - static driver::module* get(driver::stream* stream, std::string const & name, std::string const &src); + static driver::module* get(driver::stream* stream, std::string const & name, llvm::Module *src); private: static std::map, driver::module*> cache_; diff --git a/include/triton/driver/context.h b/include/triton/driver/context.h index 842d0a82c..379fe6962 100755 --- a/include/triton/driver/context.h +++ b/include/triton/driver/context.h @@ -40,6 +40,8 @@ public: context(driver::device *dev, cl_context cl, bool take_ownership); driver::device* device() const; std::string const & cache_path() const; + // factory methods + static context* create(driver::device *dev); protected: driver::device* dev_; diff --git a/include/triton/driver/device.h b/include/triton/driver/device.h index 2945ab766..97071ec27 100755 --- a/include/triton/driver/device.h +++ b/include/triton/driver/device.h @@ -32,6 +32,8 @@ namespace triton namespace driver { +class context; + // Base device class device: public polymorphic_resource{ public: diff --git a/include/triton/driver/handle.h b/include/triton/driver/handle.h index 3bffea395..6de493722 100755 --- a/include/triton/driver/handle.h +++ b/include/triton/driver/handle.h @@ -81,6 +81,7 @@ class polymorphic_resource { public: polymorphic_resource(CUType cu, bool take_ownership): cu_(cu, take_ownership){} polymorphic_resource(CLType cl, bool take_ownership): cl_(cl, take_ownership){} + virtual ~polymorphic_resource() { } handle cu() { return cu_; } handle cl() { return cl_; } diff --git a/include/triton/driver/kernel.h b/include/triton/driver/kernel.h index 6a8f114f4..0657e775f 100755 --- a/include/triton/driver/kernel.h +++ b/include/triton/driver/kernel.h @@ -41,14 +41,27 @@ class kernel: public polymorphic_resource { public: kernel(driver::module* program, CUfunction fn, bool has_ownership); kernel(driver::module* program, cl_kernel fn, bool has_ownership); + // Getters driver::module* module(); - + // Factory methods + static kernel* create(driver::module* program, const char* name); + // Arguments setters + virtual void setArg(unsigned int index, std::size_t size, void* ptr) = 0; + virtual void setArg(unsigned int index, buffer *) = 0; + template void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); } private: driver::module* program_; }; // OpenCL class ocl_kernel: public kernel { +public: + //Constructors + ocl_kernel(driver::module* program, const char* name); + // Arguments setters + void setArg(unsigned int index, std::size_t size, void* ptr); + void setArg(unsigned int index, driver::buffer* buffer); + }; // CUDA @@ -56,10 +69,9 @@ class cu_kernel: public kernel { public: //Constructors cu_kernel(driver::module* program, const char * name); - //Arguments setters + // Arguments setters void setArg(unsigned int index, std::size_t size, void* ptr); - void setArg(unsigned int index, cu_buffer const &); - template void setArg(unsigned int index, T value) { setArg(index, sizeof(T), (void*)&value); } + void setArg(unsigned int index, driver::buffer* buffer); //Arguments getters void* const* cu_params() const; diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index ef45243fd..92a237d3d 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -31,6 +31,8 @@ namespace llvm { class Module; + template + class SmallVectorImpl; } namespace triton @@ -42,20 +44,34 @@ namespace driver class cu_context; class cu_device; +// Base class module: public polymorphic_resource { +protected: + void init_llvm(); + public: module(driver::context* ctx, CUmodule mod, bool has_ownership); module(driver::context* ctx, cl_program mod, bool has_ownership); + static module* create(driver::context* ctx, llvm::Module *src); driver::context* context() const; + void compile_llvm_module(llvm::Module* module, const std::string& triple, + const std::string &proc, std::string layout, + llvm::SmallVectorImpl &buffer); protected: driver::context* ctx_; }; +// OpenCL +class ocl_module: public module{ + +public: + ocl_module(driver::context* context, llvm::Module *module); +}; + +// CUDA class cu_module: public module { - static std::string header(driver::cu_device const & device); std::string compile_llvm_module(llvm::Module* module); - void init_llvm(); public: cu_module(driver::context* context, llvm::Module *module); diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index cb2ae7d4d..723edbc13 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -35,7 +35,7 @@ namespace triton namespace driver { -class cu_kernel; +class kernel; class Event; class Range; class cu_buffer; @@ -45,6 +45,9 @@ class stream: public polymorphic_resource { public: stream(driver::context *ctx, CUstream, bool has_ownership); stream(driver::context *ctx, cl_command_queue, bool has_ownership); + // factory + static driver::stream* create(driver::context* ctx); + // accessors driver::context* context() const; virtual void synchronize() = 0; @@ -73,7 +76,7 @@ public: void synchronize(); //Enqueue - void enqueue(cu_kernel const & cu_kernel, std::array grid, std::array block, std::vector const * = NULL, Event *event = NULL); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, Event *event = NULL); // Write void write(driver::cu_buffer const & cu_buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr); diff --git a/lib/driver/backend.cpp b/lib/driver/backend.cpp index 726548c1c..628f0c225 100755 --- a/lib/driver/backend.cpp +++ b/lib/driver/backend.cpp @@ -99,10 +99,11 @@ void backend::modules::release(){ cache_.clear(); } -driver::module* backend::modules::get(driver::stream* stream, std::string const & name, std::string const & src){ +driver::module* backend::modules::get(driver::stream* stream, std::string const & name, llvm::Module* src){ std::tuple key(stream, name); - if(cache_.find(key)==cache_.end()) - return &*cache_.insert(std::make_pair(key, new driver::cu_module(((driver::cu_stream*)stream)->context(), src))).first->second; + if(cache_.find(key)==cache_.end()){ + return &*cache_.insert({key, driver::module::create(stream->context(), src)}).first->second; + } return &*cache_.at(key); } @@ -120,8 +121,9 @@ void backend::kernels::release(){ driver::kernel* backend::kernels::get(driver::module *mod, std::string const & name){ std::tuple key(mod, name); - if(cache_.find(key)==cache_.end()) - return &*cache_.insert(std::make_pair(key, new driver::cu_kernel((driver::cu_module*)mod, name.c_str()))).first->second; + if(cache_.find(key)==cache_.end()){ + return &*cache_.insert({key, driver::kernel::create(mod, name.c_str())}).first->second; + } return cache_.at(key); } @@ -134,7 +136,7 @@ std::map, driver::kernel*> backend::ker void backend::streams::init(std::list const & contexts){ for(driver::context* ctx : contexts) if(cache_.find(ctx)==cache_.end()) - cache_.insert(std::make_pair(ctx, std::vector{new driver::cu_stream(ctx)})); + cache_.insert(std::make_pair(ctx, std::vector{driver::stream::create(ctx)})); } void backend::streams::release(){ @@ -168,7 +170,7 @@ std::map> backend::streams::cache void backend::contexts::init(std::vector const & devices){ for(driver::device* dvc: devices) - cache_.push_back(new cu_context(dvc)); + cache_.push_back(driver::context::create(dvc)); } void backend::contexts::release(){ diff --git a/lib/driver/context.cpp b/lib/driver/context.cpp index 56654c19d..6e1618713 100755 --- a/lib/driver/context.cpp +++ b/lib/driver/context.cpp @@ -50,6 +50,15 @@ context::context(driver::device *dev, cl_context cl, bool take_ownership): } +context* context::create(driver::device *dev){ + if(dynamic_cast(dev)) + return new cu_context(dev); + if(dynamic_cast(dev)) + return new ocl_context(dev); + throw std::runtime_error("unknown context"); +} + + driver::device* context::device() const { return dev_; } diff --git a/lib/driver/device.cpp b/lib/driver/device.cpp index 0b9852e7b..0fe875075 100755 --- a/lib/driver/device.cpp +++ b/lib/driver/device.cpp @@ -27,6 +27,7 @@ #include #include "triton/driver/device.h" +#include "triton/driver/context.h" namespace triton { @@ -35,11 +36,16 @@ namespace driver { +/* ------------------------ */ +// OpenCL // +/* ------------------------ */ + + /* ------------------------ */ // CUDA // /* ------------------------ */ -// Architecture +// architecture cu_device::Architecture cu_device::nv_arch(std::pair sm) const { switch(sm.first) { case 7: diff --git a/lib/driver/handle.cpp b/lib/driver/handle.cpp index 3396b1b2b..c9534c4af 100755 --- a/lib/driver/handle.cpp +++ b/lib/driver/handle.cpp @@ -33,6 +33,12 @@ namespace driver //OpenCL inline void _delete(cl_platform_id) { } inline void _delete(cl_device_id x) { dispatch::clReleaseDevice(x); } +inline void _delete(cl_context x) { dispatch::clReleaseContext(x); } +inline void _delete(cl_program x) { dispatch::clReleaseProgram(x); } +inline void _delete(cl_kernel x) { dispatch::clReleaseKernel(x); } +inline void _delete(cl_command_queue x) { dispatch::clReleaseCommandQueue(x); } +inline void _delete(cl_mem x) { dispatch::clReleaseMemObject(x); } + //CUDA inline void _delete(CUcontext x) { dispatch::cuCtxDestroy(x); } inline void _delete(CUdeviceptr x) { dispatch::cuMemFree(x); } @@ -67,6 +73,11 @@ template class handle; template class handle; template class handle; +template class handle; +template class handle; +template class handle; +template class handle; +template class handle; } } diff --git a/lib/driver/kernel.cpp b/lib/driver/kernel.cpp index d0656a230..1490ad21d 100755 --- a/lib/driver/kernel.cpp +++ b/lib/driver/kernel.cpp @@ -45,6 +45,14 @@ kernel::kernel(driver::module *program, cl_kernel fn, bool has_ownership): polymorphic_resource(fn, has_ownership), program_(program){ } +kernel* kernel::create(driver::module* program, const char* name) { + if(dynamic_cast(program)) + return new cu_kernel(program, name); + if(dynamic_cast(program)) + return new ocl_kernel(program, name); + throw std::runtime_error("unknown program"); +} + driver::module* kernel::module() { return program_; } @@ -53,6 +61,19 @@ driver::module* kernel::module() { // OpenCL // /* ------------------------ */ +ocl_kernel::ocl_kernel(driver::module* program, const char* name): kernel(program, cl_kernel(), true) { + cl_int err; + *cl_ = dispatch::clCreateKernel(*program->cl(), name, &err); +} + +void ocl_kernel::setArg(unsigned int index, std::size_t size, void* ptr) { + dispatch::clSetKernelArg(*cl_, index, size, ptr); +} + +void ocl_kernel::setArg(unsigned int index, driver::buffer* buffer) { + dispatch::clSetKernelArg(*cl_, index, sizeof(cl_mem), (void*)&*buffer->cl()); +} + /* ------------------------ */ // CUDA // @@ -74,8 +95,8 @@ void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){ cu_params_[index] = cu_params_store_[index].get(); } -void cu_kernel::setArg(unsigned int index, cu_buffer const & data) -{ return setArg(index, data.cu());} +void cu_kernel::setArg(unsigned int index, driver::buffer* data) +{ return kernel::setArg(index, *data->cu());} void* const* cu_kernel::cu_params() const { return cu_params_.data(); } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 98779e918..03793945c 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -50,6 +50,18 @@ namespace driver // Base // /* ------------------------ */ +void module::init_llvm() { + static bool init = false; + if(!init){ + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + init = true; + } +} + module::module(driver::context* ctx, CUmodule mod, bool has_ownership) : polymorphic_resource(mod, has_ownership), ctx_(ctx) { } @@ -62,26 +74,56 @@ driver::context* module::context() const { return ctx_; } +module* module::create(driver::context* ctx, llvm::Module *src) { + if(dynamic_cast(ctx)) + return new cu_module(ctx, src); + if(dynamic_cast(ctx)) + return new ocl_module(ctx, src); + throw std::runtime_error("unknown context"); +} + +void module::compile_llvm_module(llvm::Module* module, const std::string& triple, + const std::string &proc, std::string layout, + llvm::SmallVectorImpl &buffer) { + init_llvm(); + // create machine + module->setTargetTriple(triple); + std::string error; + auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "", + llvm::TargetOptions(), llvm::Reloc::Model(), + llvm::None, llvm::CodeGenOpt::Aggressive); + + + // set data layout + if(layout.empty()) + layout = module->getDataLayoutStr(); + module->setDataLayout(layout); + + // emit machine code + llvm::legacy::PassManager pass; + llvm::raw_svector_ostream stream(buffer); + machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); + pass.run(*module); +} /* ------------------------ */ // OpenCL // /* ------------------------ */ +ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(context, cl_program(), true) { + init_llvm(); + llvm::SmallVector buffer; + module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer); + throw std::runtime_error("need to implement opencl module creation"); +} + /* ------------------------ */ // CUDA // /* ------------------------ */ std::string cu_module::compile_llvm_module(llvm::Module* module) { - init_llvm(); - // create machine - module->setTargetTriple("nvptx64-nvidia-cuda"); - std::string error; - auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); - llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), "sm_52", "", - llvm::TargetOptions(), llvm::Reloc::Model(), - llvm::None, llvm::CodeGenOpt::Aggressive); - // set data layout std::string layout = "e"; bool is_64bit = true; @@ -91,28 +133,13 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { else if (use_short_pointers) layout += "-p3:32:32-p4:32:32-p5:32:32"; layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64"; - module->setDataLayout(layout); - // emit machine code - llvm::legacy::PassManager pass; + // create llvm::SmallVector buffer; - llvm::raw_svector_ostream stream(buffer); - machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile); - pass.run(*module); - // done + module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_52", layout, buffer); return std::string(buffer.begin(), buffer.end()); } -void cu_module::init_llvm() { - static bool init = false; - if(!init){ - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); - init = true; - } -} + cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } diff --git a/lib/driver/stream.cpp b/lib/driver/stream.cpp index a8d8f5c43..fa7d25621 100755 --- a/lib/driver/stream.cpp +++ b/lib/driver/stream.cpp @@ -52,6 +52,15 @@ stream::stream(driver::context *ctx, cl_command_queue cl, bool has_ownership) } +driver::stream* stream::create(driver::context* ctx) { + if(dynamic_cast(ctx)) + return new cu_stream(ctx); + if(dynamic_cast(ctx)) + return new cl_stream(ctx); + throw std::runtime_error("unknown context"); +} + + driver::context* stream::context() const { return ctx_; } @@ -61,6 +70,10 @@ driver::context* stream::context() const { // OpenCL // /* ------------------------ */ +cl_stream::cl_stream(driver::context *ctx): stream(ctx, cl_command_queue(), true) { + cl_int err; + *cl_ = dispatch::clCreateCommandQueue(*ctx->cl(), *ctx->device()->cl(), 0, &err); +} void cl_stream::synchronize() { dispatch::clFinish(*cl_); @@ -91,11 +104,12 @@ void cu_stream::synchronize() { dispatch::cuStreamSynchronize(*cu_); } -void cu_stream::enqueue(driver::cu_kernel const & kernel, std::array grid, std::array block, std::vector const *, Event* event) { +void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, Event* event) { + driver::cu_kernel* cu_kernel = (driver::cu_kernel*)kernel; cu_context::context_switcher ctx_switch(*ctx_); if(event) dispatch::cuEventRecord(event->cu()->first, *cu_); - dispatch::cuLaunchKernel(*kernel.cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_,(void**)kernel.cu_params(), NULL); + dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_,(void**)cu_kernel->cu_params(), NULL); if(event) dispatch::cuEventRecord(event->cu()->second, *cu_); }