diff --git a/examples/matrix.cpp b/examples/matrix.cpp index e16d6b2dd..624872f9c 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -6,9 +6,9 @@ const char* src = R"( -const tunable int32 TM; -const tunable int32 TN; -const tunable int32 TK; +const tunable int32 TM = {16, 32, 64}; +const tunable int32 TN = {16, 32, 64}; +const tunable int32 TK = {8, 16}; void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){ @@ -26,20 +26,8 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, pa = pa + TK*M; pb = pb + TK*K; k = k - TK; - int1 checka[TM, TK] = k > bound; - int1 checkb[TN, TK] = k > bound; - @checka a = *pa; - @checkb b = *pb; - if(k > bound) - continue; - int1 checka0[TM] = rxa < M; - int1 checka1[TK] = rka < k; - int1 checkb0[TN] = ryb < N; - int1 checkb1[TK] = rkb < k; - checka = checka0[:, newaxis] && checka1[newaxis, :]; - checkb = checkb0[:, newaxis] && checkb1[newaxis, :]; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; + a = *pa; + b = *pb; } int32 rxc[TM] = get_global_range[TM](0); int32 ryc[TN] = get_global_range[TN](1); @@ -87,22 +75,17 @@ T min(std::vector x) template -double bench(OP const & op, SYNC const & sync) +double bench(OP const & op, SYNC const & sync, unsigned repeat = 20) { timer tmr; - std::vector times; - double total_time = 0; op(); sync(); - while(total_time*1e-9 < 1e-3){ - float norm = 1; - tmr.start(); + tmr.start(); + for(unsigned i = 0; i < repeat; i++) op(); - sync(); - times.push_back(norm*tmr.get().count()); - total_time+=times.back(); - } - return min(times); + sync(); + double time = tmr.get().count(); + return time / repeat; } int main() { @@ -111,16 +94,16 @@ int main() { triton::jit jit(context); // matrix multiplication parameters - int32_t M = 128, N = 128, K = 128; + int32_t M = 512, N = 512, K = 512; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); std::vector hb(K*N); srand(0); for(size_t i = 0; i < ha.size(); i++) - ha[i] = 1; + ha[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hb.size(); i++) - hb[i] = 1; + hb[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hc.size(); i++) hc[i] = 0; triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); @@ -163,11 +146,10 @@ int main() { stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->synchronize(); // benchmark -// double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, -// [&](){ stream->synchronize(); }); - double ts = 1; + double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, + [&](){ stream->synchronize(); }); ts = ts * 1e-9; - double tflops = 2*M*N*K / ts * 1e-12; + double tflops = 2.*M*N*K / ts * 1e-12; return tflops; }; @@ -177,11 +159,12 @@ int main() { 16, 2, 64, 32, 2, 64, 16, 8, 2, 2, - 8, 1, 8, - 4, 1 + 8, 8, + 4, }; +// params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4}; -// jit.autotune(src, benchmark); + jit.autotune(src, benchmark); jit.add_module(src, params); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index 8ce55f372..ae4b7d4e3 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -94,10 +94,15 @@ abstract_declarator direct_abstract_declarator : '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); } -constant : +constant: CONSTANT { $$ = new constant(atoi(yytext)); } ; - + +constant_list: + constant { $$ = new list((constant*)$1); } + | constant_list ',' constant { $$ = append_ptr_list($1, $3); } + ; + type_name : declaration_specifiers { $$ = new type_name($1, nullptr); } | declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); } @@ -259,7 +264,7 @@ expression /* Initialization */ initialization_expression : assignment_expression { $$ = $1; } - | '{' constant '}' { $$ = $2; } + | '{' constant_list '}' { $$ = $2; } ; diff --git a/include/triton/driver/device.h b/include/triton/driver/device.h index d99e47fe2..a08bd3cc8 100755 --- a/include/triton/driver/device.h +++ b/include/triton/driver/device.h @@ -38,18 +38,24 @@ class context; class device: public polymorphic_resource{ public: using polymorphic_resource::polymorphic_resource; + virtual size_t max_threads_per_block() const = 0; + virtual size_t max_shared_memory() const = 0; }; // Host device class host_device: public device { public: host_device(): device(host_device_t(), true){ } + size_t max_threads_per_block() const { return 1; } + size_t max_shared_memory() const { return 0; } }; // OpenCL device class ocl_device: public device { public: ocl_device(cl_device_id cl, bool take_ownership = true): device(cl, take_ownership) { } + size_t max_threads_per_block() const; + size_t max_shared_memory() const; }; // CUDA device @@ -87,8 +93,6 @@ public: std::string infos() const; size_t address_bits() const; std::vector max_block_dim() const; - size_t max_threads_per_block() const; - size_t max_shared_memory() const; size_t warp_size() const; //Compute Capability void interpret_as(std::pair cc); @@ -99,7 +103,8 @@ public: //Clocks size_t current_sm_clock() const; size_t current_mem_clock() const; - + size_t max_threads_per_block() const; + size_t max_shared_memory() const; size_t max_sm_clock() const; size_t max_mem_clock() const; diff --git a/include/triton/driver/dispatch.h b/include/triton/driver/dispatch.h index 1e0459931..2d06bb397 100755 --- a/include/triton/driver/dispatch.h +++ b/include/triton/driver/dispatch.h @@ -87,7 +87,7 @@ public: static bool cudnninit(); static void release(); - //OpenCL + // OpenCL static cl_int clBuildProgram(cl_program, cl_uint, const cl_device_id *, const char *, void (*)(cl_program, void *), void *); static cl_int clEnqueueNDRangeKernel(cl_command_queue, cl_kernel, cl_uint, const size_t *, const size_t *, const size_t *, cl_uint, const cl_event *, cl_event *); static cl_int clSetKernelArg(cl_kernel, cl_uint, size_t, const void *); diff --git a/include/triton/driver/handle.h b/include/triton/driver/handle.h index f87a8ffa6..d3b6f151c 100755 --- a/include/triton/driver/handle.h +++ b/include/triton/driver/handle.h @@ -105,20 +105,21 @@ public: bool operator<(handle_interface const & y) { return (CUType)(*this) < (CUType)(y); } }; -template +template class handle{ public: template friend class handle_interface; public: //Constructors - handle(CUType cu = CUType(), bool take_ownership = true); + handle(T h, bool take_ownership = true); + handle(); ~handle(); - CUType& operator*() { return *h_; } - CUType const & operator*() const { return *h_; } - CUType* operator->() const { return h_.get(); } + T& operator*() { return *h_; } + T const & operator*() const { return *h_; } + T* operator->() const { return h_.get(); } protected: - std::shared_ptr h_; + std::shared_ptr h_; bool has_ownership_; }; diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 317cba2ff..0c18787ea 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -44,19 +44,19 @@ protected: }; /* Metaparameter int */ -class metaparameter: public constant_int{ - metaparameter(type *ty, unsigned lo, unsigned hi); +class metaparameter: public constant_int { +private: + metaparameter(type *ty, const std::vector& space); public: static metaparameter *create(context &ctx, type *ty, unsigned lo, unsigned hi); + static metaparameter *create(context &ctx, type *ty, const std::vector& space); void set_value(uint64_t value) { has_value_ = true; value_ = value; } bool has_value() { return has_value_; } - unsigned get_lo() { return lo_; } - unsigned get_hi() { return hi_; } + const std::vector& get_space() { return space_; } private: - unsigned lo_; - unsigned hi_; + std::vector space_; bool has_value_; }; diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 04d03aa99..77ba26464 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -410,12 +410,16 @@ ir::value* initializer::codegen(ir::module * mod) const{ std::string name = decl_->id()->name(); ir::value *value = ir::undef_value::get(ty); if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){ - assert(expr_ == nullptr); - //TODO: implement ranges - value = ir::metaparameter::create(mod->get_context(), ty, 8, (name=="TK")?8:64); + auto csts = dynamic_cast*>((node*)expr_); + if(csts == nullptr) + throw std::runtime_error("must specify constant list for metaparameters"); + std::vector values; + for(constant* cst: csts->values()) + values.push_back(cst->value()); + value = ir::metaparameter::create(mod->get_context(), ty, values); mod->register_global(name, value); } - if(expr_){ + else if(expr_){ value = expr_->codegen(mod); value = explicit_cast(mod->get_builder(), value, ty); implicit_broadcast(mod, value, ty); diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index f3a9cedfb..8a9a35aa0 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -144,11 +144,23 @@ void tune::run(ir::module &mod) { // Layout parameters while(!nodes_.empty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 2); + ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_); } } + + // Simplify metaparameters + std::set fixed_io_nts; + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i : block->get_inst_list()) + if(dynamic_cast(i) || dynamic_cast(i)) + if(i->get_type()->is_tile_ty()) + for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++) + fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d))); + for(ir::metaparameter* mp: fixed_io_nts) + mp->set_value(1); } void tune::init(ir::module &mod) { @@ -234,7 +246,7 @@ bool tune::check_constraints(std::map> &er int num_threads = 1; for(size_t k = 0; k < shapes.size(); k++) num_threads *= params_[i]["mts.d" + to_string(k)]->get_value(); - if(num_threads % 32 != 0) + if(num_threads % 64 != 0) errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of 32"); if(num_threads != num_threads_) errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")"); diff --git a/lib/driver/device.cpp b/lib/driver/device.cpp index 0fe875075..62c41fb98 100755 --- a/lib/driver/device.cpp +++ b/lib/driver/device.cpp @@ -25,7 +25,7 @@ #include #include #include - +#include "triton/driver/helpers/CL/infos.hpp" #include "triton/driver/device.h" #include "triton/driver/context.h" @@ -40,6 +40,14 @@ namespace driver // OpenCL // /* ------------------------ */ +// maximum amount of shared memory per block +size_t ocl_device::max_shared_memory() const { + return ocl::info(*cl_); +} + +size_t ocl_device::max_threads_per_block() const { + return ocl::info(*cl_).at(0); +} /* ------------------------ */ // CUDA // diff --git a/lib/driver/handle.cpp b/lib/driver/handle.cpp index 603cf2b0d..c698cb8b5 100755 --- a/lib/driver/handle.cpp +++ b/lib/driver/handle.cpp @@ -60,13 +60,16 @@ inline void _delete(cu_event_t x) { _delete(x.first); _delete(x.second); } inline void _delete(CUPlatform){} //Constructor -template -handle::handle(CUType cu, bool take_ownership): h_(new CUType(cu)), has_ownership_(take_ownership) +template +handle::handle(T cu, bool take_ownership): h_(new T(cu)), has_ownership_(take_ownership) { } +template +handle::handle(): has_ownership_(false){ } -template -handle::~handle(){ + +template +handle::~handle(){ if(has_ownership_ && h_ && h_.unique()) _delete(*h_); } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 4c0018c3d..6e3533983 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -53,6 +53,10 @@ #include "llvm/ExecutionEngine/OrcMCJITReplacement.h" #include #include "llvm/Transforms/Utils/Cloning.h" +#include "lld/Common/Driver.h" +#include "lld/Common/Args.h" +#include "lld/Common/ErrorHandler.h" +#include "lld/Common/LLVM.h" namespace triton { @@ -110,36 +114,17 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple std::string error; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); llvm::TargetOptions opt; -// opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; -// opt.UnsafeFPMath = false; -// opt.NoInfsFPMath = false; -// opt.NoNaNsFPMath = true; + opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; llvm::TargetMachine *machine = target->createTargetMachine(module->getTargetTriple(), proc, "code-object-v3", opt, llvm::Reloc::PIC_, llvm::None, llvm::CodeGenOpt::Aggressive); - // set data layout if(layout.empty()) module->setDataLayout(machine->createDataLayout()); else module->setDataLayout(layout); - - // link - for (std::string& path: paths) { - llvm::SMDiagnostic err; - std::unique_ptr mlib = llvm::parseIRFile(path, err, module->getContext()); - if (mlib.get() == nullptr) { - std::string msg = err.getMessage(); - std::cerr << "Fail to load bitcode file " << path << "\n" - << "line " << err.getLineNo() << ":" << msg; - } - mlib->setTargetTriple(module->getTargetTriple()); - mlib->setDataLayout(module->getDataLayout()); - for (llvm::Function &f : mlib->functions()) { - f.addFnAttr(llvm::Attribute::AlwaysInline); - } - llvm::Linker::linkModules(*module, std::move(mlib)); - } - // emit machine code for (llvm::Function &f : module->functions()) f.addFnAttr(llvm::Attribute::AlwaysInline); @@ -187,12 +172,10 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con init_llvm(); llvm::SmallVector buffer; module::compile_llvm_module(src, "amdgcn-amd-amdhsa-amdgizcl", "gfx902", "", buffer); - - std::ofstream output("tmp.o", std::ios::binary); + std::ofstream output("/tmp/tmp.o", std::ios::binary); std::copy(buffer.begin(), buffer.end(), std::ostreambuf_iterator(output)); - system("ld.lld tmp.o -shared -o test.o"); - - std::ifstream input("test.o", std::ios::in | std::ios::binary ); + system("ld.lld-8 /tmp/tmp.o -shared -o /tmp/tmp.o"); + std::ifstream input("/tmp/tmp.o", std::ios::in | std::ios::binary ); std::vector in_buffer(std::istreambuf_iterator(input), {}); size_t sizes[] = {in_buffer.size()}; const unsigned char* data[] = {(unsigned char*)in_buffer.data()}; @@ -208,7 +191,6 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con char log[2048]; dispatch::clGetProgramBuildInfo(*cl_, *context->device()->cl(), CL_PROGRAM_BUILD_LOG, 1024, log, NULL); std::cout << log << std::endl; - std::cout << "T_T" << std::endl; throw; } } diff --git a/lib/driver/stream.cpp b/lib/driver/stream.cpp index e9818d7bd..937750c23 100755 --- a/lib/driver/stream.cpp +++ b/lib/driver/stream.cpp @@ -111,7 +111,8 @@ void cl_stream::synchronize() { } void cl_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event) { - check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)grid.data(), (const size_t*)block.data(), 0, NULL, NULL)); + std::array global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]}; + check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL)); } void cl_stream::write(driver::buffer* buffer, bool blocking, std::size_t offset, std::size_t size, void const* ptr) { diff --git a/lib/ir/constant.cpp b/lib/ir/constant.cpp index bfb6fdb9b..5df644842 100644 --- a/lib/ir/constant.cpp +++ b/lib/ir/constant.cpp @@ -98,12 +98,22 @@ constant *constant_fp::get(context &ctx, double v){ } // metaparameter -metaparameter::metaparameter(type *ty, unsigned lo, unsigned hi) - : constant_int(ty, 0), lo_(lo), hi_(hi), has_value_(false){ } +metaparameter::metaparameter(type *ty, const std::vector &space) + : constant_int(ty, 0), space_(space), has_value_(false){ } metaparameter* metaparameter::create(context &ctx, type *ty, unsigned lo, unsigned hi) { context_impl *impl = ctx.p_impl.get(); - metaparameter *result = new metaparameter(ty, lo, hi); + std::vector space; + for(unsigned i = lo; i <= hi; i *= 2) + space.push_back(i); + metaparameter *result = new metaparameter(ty, space); + impl->mp_constants_.push_back(result); + return result; +} + +metaparameter* metaparameter::create(context &ctx, type *ty, const std::vector &space) { + context_impl *impl = ctx.p_impl.get(); + metaparameter *result = new metaparameter(ty, space); impl->mp_constants_.push_back(result); return result; } diff --git a/lib/jit.cpp b/lib/jit.cpp index 9162b73c6..f76870a86 100644 --- a/lib/jit.cpp +++ b/lib/jit.cpp @@ -5,6 +5,7 @@ #include "triton/ir/context.h" #include "triton/ir/context_impl.h" #include "triton/driver/device.h" +#include "triton/driver/error.h" #include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Module.h" #include "llvm/IR/LLVMContext.h" @@ -71,6 +72,7 @@ std::unique_ptr jit::make_llvm_module(ir::module &module, passes_w passes.selection.run(module, *result); // launch information auto &launch_info_map = launch_info_map_[result->getName()]; + launch_info_map.global_range_size.clear(); for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++) launch_info_map.global_range_size.push_back(passes.tune.get_global_range_size(i)); launch_info_map.num_threads = passes.tune.get_num_threads(); @@ -104,12 +106,8 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) { auto mps = passes.tune.get_params(tt_module); // create parameter ranges std::vector> ranges; - for(ir::metaparameter *mp: mps){ - std::vector current; - for(unsigned x = mp->get_lo(); x <= mp->get_hi(); x*=2) - current.push_back(x); - ranges.push_back(current); - } + for(ir::metaparameter *mp: mps) + ranges.push_back(mp->get_space()); // iterate over parameters unsigned i; double best = 0; @@ -132,22 +130,23 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) { } passes.tune.init(tt_module); passes.init(tt_module); -// driver::device* device = driver_context_->device(); -// if(passes.allocation.get_allocated_size() > device->max_shared_memory()) -// return; -// if(passes.tune.get_num_threads() > device->max_threads_per_block()) -// return; + driver::device* device = driver_context_->device(); + if(passes.allocation.get_allocated_size() > device->max_shared_memory()) + return; + if(passes.tune.get_num_threads() > device->max_threads_per_block()) + return; // Compile auto ll_module = make_llvm_module(tt_module, passes); - driver::module* module = driver::module::create(driver_context_, &*ll_module); - driver::kernel* kernel = driver::kernel::create(module, "matmul"); + std::unique_ptr module(driver::module::create(driver_context_, &*ll_module)); + std::unique_ptr kernel(driver::kernel::create(module.get(), "matmul")); launch_information info = launch_info_map_.at("matmul"); for(unsigned p: params) std::cout << p << " " << std::flush; // add globals for(auto x: tt_module.globals()) global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); - double perf = benchmark(kernel, info); + double perf; + perf = benchmark(kernel.get(), info); best = std::max(perf, best); std::cout << perf << " [ " << best << " ] " << std::endl; }); @@ -167,9 +166,9 @@ void jit::add_module(ir::module &tt_module, const std::vector ¶ms) passes.tune.check_constraints(errors); if(errors.size()) throw std::runtime_error("invalid parameters"); -// driver::device* device = driver_context_->device(); -// if(passes.allocation.get_allocated_size() > device->max_shared_memory()) -// throw std::runtime_error("invalid parameters"); + driver::device* device = driver_context_->device(); + if(passes.allocation.get_allocated_size() > device->max_shared_memory()) + throw std::runtime_error("invalid parameters"); // triton module -> llvm module auto ll_module = make_llvm_module(tt_module, passes); // llvm module -> machine code