diff --git a/examples/matrix.cpp b/examples/matrix.cpp
index 60ba87318..bbe9e25bf 100644
--- a/examples/matrix.cpp
+++ b/examples/matrix.cpp
@@ -87,7 +87,7 @@ T min(std::vector<T> x)
 template<class OP, class SYNC>
-double bench(OP const & op, SYNC const & sync, triton::driver::cu_device const & device)
+double bench(OP const & op, SYNC const & sync)
 {
   timer tmr;
   std::vector<double> times;
@@ -95,7 +95,7 @@ double bench(OP const & op, SYNC const & sync, triton::driver::cu_device const &
   op();
   sync();
   while(total_time*1e-9 < 1e-3){
-    float norm = (float)device.current_sm_clock()/device.max_sm_clock();
+    float norm = 1;
     tmr.start();
     op();
     sync();
@@ -108,7 +108,6 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
-  exit(EXIT_SUCCESS);
   triton::jit jit(context);
   // matrix multiplication parameters
@@ -124,14 +123,14 @@ int main() {
     hb[i] = 1;
   for(size_t i = 0; i < hc.size(); i++)
     hc[i] = 0;
-  triton::driver::cu_buffer dc(context, hc.size()*4);
-  triton::driver::cu_buffer da(context, ha.size()*4);
-  triton::driver::cu_buffer db(context, hb.size()*4);
-  triton::driver::cu_stream stream(context);
-  stream.write(da, true, 0, ha);
-  stream.write(db, true, 0, hb);
-  stream.write(dc, true, 0, hc);
-  stream.synchronize();
+  triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
+  triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
+  triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
+  triton::driver::stream* stream = triton::driver::stream::create(context);
+  stream->write(da, true, 0, ha);
+  stream->write(db, true, 0, hb);
+  stream->write(dc, true, 0, hc);
+  stream->synchronize();
   // benchmark a given matrix multiplication kernel
@@ -161,12 +160,11 @@ int main() {
     kernel->setArg(5, K);
     kernel->setArg(6, bound);
     // dry run
-    stream.enqueue(kernel, grid, {nthreads, 1, 1});
-    stream.synchronize();
+    stream->enqueue(kernel, grid, {nthreads, 1, 1});
+    stream->synchronize();
     // benchmark
-    double ts = bench([&](){stream.enqueue(kernel, grid, {nthreads, 1, 1});},
-                      [&](){ stream.synchronize(); },
-                      (triton::driver::cu_device&)*context->device());
+    double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
+                      [&](){ stream->synchronize(); });
     ts = ts * 1e-9;
     double tflops = 2*M*N*K / ts * 1e-12;
     return tflops;
@@ -184,10 +182,10 @@ int main() {
 //  jit.autotune(src, benchmark);
   jit.add_module(src, params);
-  triton::driver::cu_kernel kernel = jit.get_function("matmul");
+  triton::driver::kernel* kernel = jit.get_function("matmul");
   triton::jit::launch_information info = jit.get_launch_info("matmul");
-  std::cout << benchmark(&kernel, info) << std::endl;
-  stream.read(dc, true, 0, hc);
+  std::cout << benchmark(kernel, info) << std::endl;
+  stream->read(dc, true, 0, hc);
   simple_gemm(rc, ha, hb, M, N, K);
   for(size_t i = 0; i < M*N; i++)
     if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
diff --git a/include/triton/driver/buffer.h b/include/triton/driver/buffer.h
index c4ca53650..08bfede1d 100755
--- a/include/triton/driver/buffer.h
+++ b/include/triton/driver/buffer.h
@@ -38,6 +38,7 @@ class buffer : public polymorphic_resource<CUdeviceptr, cl_mem> {
 public:
   buffer(driver::context* ctx, CUdeviceptr cl, bool take_ownership);
   buffer(driver::context* ctx, cl_mem cl, bool take_ownership);
+  static buffer* create(driver::context* ctx, size_t size);
   driver::context* context();
 
 protected:
diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h
index 723edbc13..18bedbce0 100755
--- a/include/triton/driver/stream.h
+++ b/include/triton/driver/stream.h
@@ -49,7 +49,16 @@ public:
   static driver::stream* create(driver::context* ctx);
   // accessors
   driver::context* context() const;
+  // methods
   virtual void synchronize() = 0;
+  virtual void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL) = 0;
+  virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0;
+  virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0;
+  // template helpers
+  template<class T> void write(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T> const & x)
+  { write(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
+  template<class T> void read(driver::buffer* buf, bool blocking, std::size_t offset, std::vector<T>& x)
+  { read(buf, blocking, offset, x.size()*sizeof(T), x.data()); }
 
 protected:
   driver::context *ctx_;
@@ -61,32 +70,25 @@ public:
   // Constructors
   cl_stream(driver::context *ctx);
-  // Synchronize
+  // Overridden
   void synchronize();
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event *event);
+  void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
+  void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
 
 // CUDA
 class cu_stream: public stream {
 public:
-  //Constructors
+  // Constructors
   cu_stream(CUstream str, bool take_ownership);
   cu_stream(driver::context* context);
-  //Synchronize
+  // Overridden
   void synchronize();
-
-  //Enqueue
-  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const * = NULL, Event *event = NULL);
-
-  // Write
-  void write(driver::cu_buffer const & cu_buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
-  template<class T> void write(driver::cu_buffer const & buffer, bool blocking, std::size_t offset, std::vector<T> const & x)
-  { write(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
-
-  // Read
-  void read(driver::cu_buffer const & cu_buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr);
-  template<class T> void read(driver::cu_buffer const & buffer, bool blocking, std::size_t offset, std::vector<T>& x)
-  { read(buffer, blocking, offset, x.size()*sizeof(T), x.data()); }
+  void enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event *event);
+  void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr);
+  void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr);
 };
diff --git a/include/triton/jit.h b/include/triton/jit.h
index a2c63bbf8..ecf22daf0 100644
--- a/include/triton/jit.h
+++ b/include/triton/jit.h
@@ -39,14 +39,14 @@ public:
     std::vector<unsigned> global_range_size;
     unsigned num_threads;
   };
-  typedef std::function<double(driver::cu_kernel*, launch_information)> benchmark_t;
+  typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
   struct passes_wrapper {
     passes_wrapper(): shared(&buffer_info),
                       liveness(&buffer_info),
-                       allocation(&liveness, &buffer_info),
-                       barriers(&allocation, &buffer_info),
-                       vectorize(&tune),
-                       selection(&allocation, &tune, &buffer_info){ }
+                      allocation(&liveness, &buffer_info),
+                      barriers(&allocation, &buffer_info),
+                      vectorize(&tune),
+                      selection(&allocation, &tune, &buffer_info){ }
 
     void init(ir::module &module) {
       // generate ptx
@@ -78,12 +78,12 @@ public:
   void autotune(const std::string &src, benchmark_t benchmark);
   void add_module(ir::module &module, const std::vector<unsigned>& params = {});
   void add_module(const std::string &src, const std::vector<unsigned>& params = {});
-  driver::cu_kernel get_function(const std::string &name);
+  driver::kernel* get_function(const std::string &name);
   launch_information get_launch_info(const std::string &name);
   unsigned get_int(const std::string &name);
 
 private:
-  std::vector<driver::cu_module> modules_;
+  std::vector<driver::module*> modules_;
   driver::context* driver_context_;
   llvm::LLVMContext llvm_context_;
   ir::context triton_context_;
diff --git a/lib/driver/buffer.cpp b/lib/driver/buffer.cpp
index 520347c7d..433d33b2e 100755
--- a/lib/driver/buffer.cpp
+++ b/lib/driver/buffer.cpp
@@ -46,6 +46,14 @@ driver::context* buffer::context() {
   return context_;
 }
 
+buffer* buffer::create(driver::context* ctx, size_t size) {
+  if(dynamic_cast<driver::cu_context*>(ctx))
+    return new cu_buffer(ctx, size);
+  if(dynamic_cast<driver::ocl_context*>(ctx))
+    return new ocl_buffer(ctx, size);
+  throw std::runtime_error("unknown context");
+}
+
 //
 
 ocl_buffer::ocl_buffer(driver::context* context, size_t size)
diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp
index 03793945c..5796cc7e5 100755
--- a/lib/driver/module.cpp
+++ b/lib/driver/module.cpp
@@ -77,7 +77,7 @@ driver::context* module::context() const {
 module* module::create(driver::context* ctx, llvm::Module *src) {
   if(dynamic_cast<driver::cu_context*>(ctx))
     return new cu_module(ctx, src);
-  if(dynamic_cast<driver::cu_context*>(ctx))
+  if(dynamic_cast<driver::ocl_context*>(ctx))
     return new ocl_module(ctx, src);
   throw std::runtime_error("unknown context");
 }
@@ -100,11 +100,13 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple
   layout = module->getDataLayoutStr();
   module->setDataLayout(layout);
 
+  std::cout << "compiling" << std::endl;
   // emit machine code
   llvm::legacy::PassManager pass;
   llvm::raw_svector_ostream stream(buffer);
   machine->addPassesToEmitFile(pass, stream, nullptr, llvm::TargetMachine::CGFT_AssemblyFile);
   pass.run(*module);
+  std::cout << "compiled" << std::endl;
 }
 
 /* ------------------------ */
@@ -115,7 +117,6 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
   init_llvm();
   llvm::SmallVector<char, 0> buffer;
   module::compile_llvm_module(src, "amdgcn-amd-amdpal", "gfx902", "", buffer);
-  throw std::runtime_error("need to implement opencl module creation");
 }
diff --git a/lib/driver/stream.cpp b/lib/driver/stream.cpp
index fa7d25621..35e369716 100755
--- a/lib/driver/stream.cpp
+++ b/lib/driver/stream.cpp
@@ -79,6 +79,17 @@ void cl_stream::synchronize() {
   dispatch::clFinish(*cl_);
 }
 
+void cl_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std::array<size_t, 3> block, std::vector<Event> const *, Event* event) {
+  cl_int err = dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL);
+}
+
+void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
+  cl_int err = dispatch::clEnqueueWriteBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL);
+}
+
+void cl_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
+  cl_int err = dispatch::clEnqueueReadBuffer(*cl_, *buffer->cl(), blocking?CL_TRUE:CL_FALSE, offset, size, ptr, 0, NULL, NULL);
+}
 
 /* ------------------------ */
 //          CUDA            //
@@ -114,20 +125,20 @@ void cu_stream::enqueue(driver::kernel* kernel, std::array<size_t, 3> grid, std:
     dispatch::cuEventRecord(event->cu()->second, *cu_);
 }
 
-void cu_stream::write(driver::cu_buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
+void cu_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) {
   cu_context::context_switcher ctx_switch(*ctx_);
   if(blocking)
-    dispatch::cuMemcpyHtoD(*buffer.cu() + offset, ptr, size);
+    dispatch::cuMemcpyHtoD(*buffer->cu() + offset, ptr, size);
   else
-    dispatch::cuMemcpyHtoDAsync(*buffer.cu() + offset, ptr, size, *cu_);
+    dispatch::cuMemcpyHtoDAsync(*buffer->cu() + offset, ptr, size, *cu_);
 }
 
-void cu_stream::read(driver::cu_buffer const & buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
+void cu_stream::read(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void* ptr) {
   cu_context::context_switcher ctx_switch(*ctx_);
   if(blocking)
-    dispatch::cuMemcpyDtoH(ptr, *buffer.cu() + offset, size);
+    dispatch::cuMemcpyDtoH(ptr, *buffer->cu() + offset, size);
   else
-    dispatch::cuMemcpyDtoHAsync(ptr, *buffer.cu() + offset, size, *cu_);
+    dispatch::cuMemcpyDtoHAsync(ptr, *buffer->cu() + offset, size, *cu_);
 }
diff --git a/lib/jit.cpp b/lib/jit.cpp
index b4f93049a..38da020a4 100644
--- a/lib/jit.cpp
+++ b/lib/jit.cpp
@@ -131,15 +131,15 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) {
     }
     passes.tune.init(tt_module);
     passes.init(tt_module);
-    driver::cu_device* device = (driver::cu_device*)driver_context_->device();
-    if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-      return;
-    if(passes.tune.get_num_threads() > device->max_threads_per_block())
-      return;
+//    driver::device* device = driver_context_->device();
+//    if(passes.allocation.get_allocated_size() > device->max_shared_memory())
+//      return;
+//    if(passes.tune.get_num_threads() > device->max_threads_per_block())
+//      return;
     // Compile
     auto ll_module = make_llvm_module(tt_module, passes);
-    driver::cu_module module(driver_context_, &*ll_module);
-    driver::cu_kernel kernel(&module, "matmul");
+    driver::module* module = driver::module::create(driver_context_, &*ll_module);
+    driver::kernel* kernel = driver::kernel::create(module, "matmul");
     launch_information info = launch_info_map_.at("matmul");
     for(unsigned p: params)
       std::cout << p << " " << std::flush;
@@ -166,13 +166,13 @@ void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params)
   passes.tune.check_constraints(errors);
   if(errors.size())
     throw std::runtime_error("invalid parameters");
-  driver::cu_device* device = (driver::cu_device*)driver_context_->device();
-  if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-    throw std::runtime_error("invalid parameters");
+//  driver::device* device = driver_context_->device();
+//  if(passes.allocation.get_allocated_size() > device->max_shared_memory())
+//    throw std::runtime_error("invalid parameters");
   // triton module -> llvm module
   auto ll_module = make_llvm_module(tt_module, passes);
   // llvm module -> machine code
-  modules_.push_back(driver::cu_module(driver_context_, &*ll_module));
+  modules_.push_back(driver::module::create(driver_context_, &*ll_module));
   // add globals
   for(auto x: tt_module.globals())
     global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
@@ -183,8 +183,8 @@ void jit::add_module(const std::string &src, const std::vector<unsigned> &params
   add_module(*ptt_module, params);
 }
 
-driver::cu_kernel jit::get_function(const std::string &name) {
-  return driver::cu_kernel(&modules_.front(), name.c_str());
+driver::kernel *jit::get_function(const std::string &name) {
+  return driver::kernel::create(modules_.front(), name.c_str());
 }
 
 jit::launch_information jit::get_launch_info(const std::string &name) {