diff --git a/CMakeLists.txt b/CMakeLists.txt index 3875cf348..a5a5d0f1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,8 +33,10 @@ endif() if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") # PyBind11 wrapper source file - set(PYTHON_SRC bindings.cc) + set(PYTHON_SRC bindings.cc launch.cc) + set_source_files_properties(launch.cc PROPERTIES COMPILE_FLAGS "-std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0") include_directories("." ${PYTHON_INCLUDE_DIRS}) + link_directories(${PYTHON_LINK_DIRS}) endif() @@ -43,3 +45,6 @@ file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) target_link_libraries(triton ${LLVM_LIBRARIES}) +if(BUILD_PYTHON_MODULE) + target_link_libraries(triton ${TORCH_LIBRARIES}) +endif() diff --git a/include/triton/driver/stream.h b/include/triton/driver/stream.h index 4b80b62af..b7c5b7e62 100755 --- a/include/triton/driver/stream.h +++ b/include/triton/driver/stream.h @@ -32,7 +32,7 @@ public: driver::context* context() const; // methods virtual void synchronize() = 0; - virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, event *event = NULL) = 0; + virtual void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const * = NULL, event *event = NULL, void **extra = NULL) = 0; virtual void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr) = 0; virtual void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr) = 0; // template helpers @@ -53,7 +53,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -66,7 +66,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; @@ -80,7 +80,7 @@ public: // Overridden void synchronize(); - void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event); + void enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event *event, void **extra); void write(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void const* ptr); void read(driver::buffer* buf, bool blocking, std::size_t offset, std::size_t size, void* ptr); }; diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 1a1b3a0ec..189fdd88b 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -40,6 +40,7 @@ enum attribute_kind_t { noalias, aligned, multiple_of, + retune, not_implemented }; @@ -113,6 +114,7 @@ public: // attributes void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); } const attr_map_t &attrs() { return attrs_; } + bool has_attr(unsigned arg_id) const { return attrs_.find(arg_id) != attrs_.end(); } std::set get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; } // visitor diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h index 2d888efc2..0f57d86cc 100644 --- a/include/triton/lang/ast.h +++ b/include/triton/lang/ast.h @@ -64,7 +64,8 @@ public: ALIGNED, NOALIAS, READONLY, - WRITEONLY + WRITEONLY, + RETUNE, }; KindT kind; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 240e4944f..01ebf0eb6 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -3,7 +3,7 @@ #ifndef _TRITON_RUNTIME_FUNCTION_H_ #define _TRITON_RUNTIME_FUNCTION_H_ - +#include #include #include #include @@ -62,6 +62,7 @@ public: typedef std::pair> define_t; std::vector defines; std::vector num_warps; + std::vector recompile_key; }; struct options_t { @@ -94,19 +95,25 @@ private: // accessors const options_t opt() const { return opt_; } const driver::module* parent() const { return &*parent_; } + const driver::kernel* bin() const { return &*bin_; } + arg_type param_ty(size_t i) const { return param_tys_.at(i);} + const std::vector& param_tys() const { return param_tys_; } + + std::vector retune() const { return retune_; } // entry points - void operator()(driver::stream *stream, const grid_t& grid, const std::vector& args) const; + void operator()(driver::stream *stream, const grid_t& grid, void **args, size_t args_size) const; private: std::shared_ptr bin_; std::shared_ptr parent_; std::vector param_tys_; + std::vector retune_; options_t opt_; std::string name_; }; private: - typedef std::pair> cache_key_t; + typedef std::pair> cache_key_t; private: // cache @@ -118,16 +125,15 @@ private: caller *make(driver::stream *stream, options_t opt); void precompile(driver::stream *stream, const options_space_t& tuning_space); // autotune - function::cache_key_t get_key(driver::stream *stream, const std::vector& args); - caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector &args); + caller* autotune(driver::stream *stream, const grid_fn_ty& grid, void **args, size_t args_size); public: static std::string preheader(); public: function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = ""); - void operator()(const std::vector& args, const grid_t& grid, driver::stream* stream); - void operator()(const std::vector& args, const grid_fn_ty& grid, driver::stream *stream); + void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream); + void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream); void set_cst(const std::string& name, void* data, size_t n_bytes); private: @@ -138,6 +144,8 @@ private: options_space_t opt_; std::set compiled_; std::map> callers_; + std::vector args_off_; + size_t args_size_; // caching std::string cache_ref_; std::string cache_path_; diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 93ea21c44..36f4f1aa6 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -177,6 +177,7 @@ inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly); case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly); case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value()); + case ir::retune: return llvm::Attribute::get(ctx, llvm::Attribute::None); default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); } } diff --git a/lib/driver/stream.cc b/lib/driver/stream.cc index 2ff5746fc..8c397ca1e 100755 --- a/lib/driver/stream.cc +++ b/lib/driver/stream.cc @@ -79,7 +79,7 @@ void host_stream::synchronize() { } -void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event) { +void host_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **extra) { driver::host_kernel* hst_kernel = (host_kernel*)kernel; llvm::ExecutionEngine* engine = kernel->module()->hst()->engine; void (*fn)(char**, int32_t, int32_t, int32_t) = (void(*)(char**, int32_t, int32_t, int32_t))engine->getFunctionAddress("main"); @@ -112,7 +112,7 @@ void cl_stream::synchronize() { check(dispatch::clFinish(*cl_)); } -void cl_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event) { +void cl_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void **extra) { std::array global = {grid[0]*block[0], grid[1]*block[1], grid[2]*block[2]}; check(dispatch::clEnqueueNDRangeKernel(*cl_, *kernel->cl(), grid.size(), NULL, (const size_t*)global.data(), (const size_t*)block.data(), 0, NULL, NULL)); } @@ -149,12 +149,11 @@ void cu_stream::synchronize() { dispatch::cuStreamSynchronize(*cu_); } -void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event) { - driver::cu_kernel* cu_kernel = (driver::cu_kernel*)kernel; +void cu_stream::enqueue(driver::kernel* kernel, std::array grid, std::array block, std::vector const *, event* event, void** extra) { cu_context::context_switcher ctx_switch(*ctx_); if(event) dispatch::cuEventRecord(event->cu()->first, *cu_); - dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_,(void**)cu_kernel->cu_params(), NULL); + dispatch::cuLaunchKernel(*kernel->cu(), grid[0], grid[1], grid[2], block[0], block[1], block[2], 0, *cu_, nullptr, extra); if(event) dispatch::cuEventRecord(event->cu()->second, *cu_); } diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index de11ab646..2cf20d85e 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -630,6 +630,8 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { return ir::attribute(ir::readonly); if(attr.kind == ASTNode::Attr::WRITEONLY) return ir::attribute(ir::writeonly); + if(attr.kind == ASTNode::Attr::RETUNE) + return ir::attribute(ir::retune); error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented"); return ir::attribute(ir::not_implemented); } diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index ba85a04cf..c3f9c5ab7 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -2778,6 +2778,8 @@ ASTNode::Attr Parser::ParseAttribute() { ret.kind = ASTNode::Attr::MULTIPLEOF; else if(name == "noalias") ret.kind = ASTNode::Attr::NOALIAS; + else if(name == "retune") + ret.kind = ASTNode::Attr::RETUNE; else Error(tok, "unknown attribute kind"); // set exprs diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index fb2daea43..a176927ee 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -151,27 +151,23 @@ function::caller::caller(ir::function *ir, bin_.reset(driver::kernel::create(&*parent, name_.c_str())); // extract signature ir::function_type* ty = ir->get_fn_type(); - for(size_t i = 0; i < ty->get_num_params(); i++) + for(size_t i = 0; i < ty->get_num_params(); i++){ param_tys_.push_back(convert(ty->get_param_ty(i))); + if(!ir->has_attr(i+1)) + continue; + for(ir::attribute attr: ir->attrs().at(i + 1)) + if(attr.get_kind() == ir::retune) + retune_.push_back(i); + } } -void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector& args) const { - if(args.size() != param_tys_.size()) - throw std::runtime_error("invalid number of arguments"); - // set arguments - for(size_t i = 0; i < args.size(); i++){ - arg arg_i = args.at(i); - arg_type ty = arg_i.type(); - if(ty != param_tys_.at(i)) - throw std::runtime_error("invalid type for argument " + std::to_string(i)); - if(ty == BUFFER_T){ - driver::buffer* buf = *((driver::buffer**)arg_i.data()); - bin_->setArg(i, buf->size() == 0 ? nullptr : buf); - } - else - bin_->setArg(i, size_of(ty), arg_i.data()); - } +void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, void** args, size_t args_size) const { + void *config[] = { + CU_LAUNCH_PARAM_BUFFER_POINTER, args, + CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, + CU_LAUNCH_PARAM_END + }; // set grid if(_grid.size() > 3) throw std::runtime_error("grid size must be no greater than 3"); @@ -179,7 +175,7 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; // enqueue - stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}); + stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}, NULL, NULL, config); } @@ -251,20 +247,6 @@ std::unique_ptr function::make_bin(ir::module &module, // create Binary from options function::caller* function::make(driver::stream *stream, options_t opt) { - // cache path - std::string cache_path = cache_path_ + opt.to_str() + ".ptx"; - int ref_mtime = tools::mtime(cache_ref_); - int ptx_mtime = tools::mtime(cache_path); - // if cached ptx is newer than reference library - if(!ref_mtime || !ptx_mtime || ref_mtime < ptx_mtime){ - std::ifstream ifs(cache_path); - // file is empty -- invalid - if(ifs && ifs.peek() == std::ifstream::traits_type::eof()) - return nullptr; - // load cached caller - if(ifs) - return new caller(stream->context(), ifs, opt); - } // pre-process TokenSequence tokens; Preprocessor cpp(&src_, true); @@ -281,18 +263,11 @@ function::caller* function::make(driver::stream *stream, options_t opt) { try{ bin = make_bin(*ir, stream->context(), opt); }catch(const std::runtime_error&){ - if(!cache_path_.empty()) - std::ofstream ofs(cache_path); return nullptr; } // create callable ir::function *tmp = ir->get_function_list()[0]; caller* ret = new caller(tmp, std::move(bin), opt); - // serialize callable - if(!cache_path_.empty()){ - std::ofstream ofs(cache_path); - ret->write(ofs); - } return ret; } @@ -330,48 +305,20 @@ void function::precompile(driver::stream* stream, throw std::runtime_error("could not compile kernel"); } -// return auto-tuning key for given function arguments -function::cache_key_t function::get_key(driver::stream *stream, const std::vector& args) { - cache_key_t ret; - ret.first = stream->context()->device(); - for(size_t i = 0; i < args.size(); i++){ - arg_type ty = args.at(i).type(); - if(!is_int_type(ty)) - continue; - long val = 0; - std::memcpy((void*)&val, args.at(i).data(), size_of(ty)); - ret.second.push_back(val); - } - return ret; -} // returns program with best compilation options for given parameter function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, - const std::vector& args) { + void** args, size_t args_size) { // fast path -- no autotuning necessary if(callers_.size() == 1) return &*callers_.begin()->second; - // slow path -- autotuning necessary - // copy buffer argument so that auto-tuning doesn't corrupt data - std::list> copies; - std::vector _args = args; - for(size_t i = 0; i < args.size(); i++) - if(_args[i].type() == BUFFER_T){ - driver::buffer* old = _args[i].buffer(); - size_t size = old->size(); - // only copy scalars - // TODO: change that - if(size != 4 && size != 2) - continue; - copies.push_back(std::make_shared(old->context(), size)); - _args[i] = arg(copies.back().get()); - } + // TODO" copy buffer argument so that auto-tuning doesn't corrupt data double best_ts = INFINITY; caller* ret = nullptr; for(auto &x : callers_){ if(x.second == nullptr) throw std::runtime_error("configuration not compiled"); caller* current = &*x.second; - double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args); }, + double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args, args_size); }, stream, true); ret = (ts < best_ts) ? current : ret; best_ts = std::min(ts, best_ts); @@ -397,6 +344,7 @@ std::string function::preheader() { #define __noalias __attribute__((noalias)) #define __aligned(A) __attribute__((aligned(A))) #define __multipleof(A) __attribute__((multipleof(A))) +#define __retune __attribute__((retune)) #define F32_INFINITY bitcast(0x7F800000) #define F16_INFINITY bitcast((int16)0x7C00) @@ -456,27 +404,35 @@ function::function(const std::string &src, src_ = preheader() + src_; } -void function::operator()(const std::vector& args, - const grid_fn_ty& grid_fn, - driver::stream *stream) { +void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_fn, driver::stream *stream) { // pre-compile kernels - if(callers_.empty()) + if(callers_.empty()){ precompile(stream, opt_); + size_t cumsum = 0; + for(arg_type ty: callers_.begin()->second->param_tys()){ + args_off_.push_back(cumsum); + cumsum += size_of(ty); + } + } + // re-tuning key + cache_key_t key; + key.first = stream->context()->device(); + key.second = callers_.begin()->second->retune(); // auto-tune if necessary - auto key = get_key(stream, args); auto it = cache_.find(key); if(it == cache_.end()){ - auto best = autotune(stream, grid_fn, args); + auto best = autotune(stream, grid_fn, args, args_size); it = cache_.insert({key, best}).first; } // run - (*it->second)(stream, grid_fn(it->second->opt()), args); + (*it->second)(stream, grid_fn(it->second->opt()), args, args_size); } -void function::operator()(const std::vector& args, +void function::operator()(void** args, + size_t args_size, const grid_t& grid, driver::stream *stream) { - return this->operator()(args, [&grid](const options_t&){ return grid; }, stream); + return this->operator()(args, args_size, [&grid](const options_t&){ return grid; }, stream); } diff --git a/python/examples/tutorials/conv2d.py b/python/examples/tutorials/conv2d.py index 34c5d41ed..8e8f38491 100644 --- a/python/examples/tutorials/conv2d.py +++ b/python/examples/tutorials/conv2d.py @@ -8,7 +8,9 @@ class _conv(torch.autograd.Function): TYPE *C __noalias __aligned(16), float alpha, // equivalent matmul - int M, int N, int K, + int M __retune, + int N __retune, + int K __retune, // convolution properties int pad_h, int pad_w, int stride_h, int stride_w, // pointer increment @@ -197,4 +199,4 @@ c = conv(a, b, pad, stride, time) print((cc - c).abs().max() / max(cc.max(), c.max())) print(time[0], 2*Z*H*W*CI*CO*R*S/(time[0]*1e-9)*1e-12) #zc = torch.matmul(a,b) -#zc_ = dot(a,b) \ No newline at end of file +#zc_ = dot(a,b) diff --git a/python/examples/tutorials/mat_copy.py b/python/examples/tutorials/mat_copy.py index 5eeca842f..7c5276242 100644 --- a/python/examples/tutorials/mat_copy.py +++ b/python/examples/tutorials/mat_copy.py @@ -4,7 +4,9 @@ import triton class _copy(torch.autograd.Function): src = """ __global__ void copy(TYPE * X, TYPE * Y, - int M, int N, int ldx __multipleof(8)) { + int M __retune, + int N __retune, + int ldx __multipleof(8)) { // extract program ID int pidm = get_program_id(0); //(1) int pidn = get_program_id(1); //(2) diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py index 419a61f51..78ecb9712 100644 --- a/python/examples/tutorials/mat_mul.py +++ b/python/examples/tutorials/mat_mul.py @@ -7,7 +7,9 @@ class _dot(torch.autograd.Function): TYPE *B __noalias __readonly __aligned(16), TYPE *C __noalias __aligned(16), float alpha, - int M, int N, int K, + int M __retune, + int N __retune, + int K __retune, int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8)) { @@ -128,4 +130,4 @@ b = torch.rand((K, N)).cuda() #zc = torch.matmul(a,b) zc_ = dot(a,b) -#print(torch.allclose(zc, zc_)) \ No newline at end of file +#print(torch.allclose(zc, zc_)) diff --git a/python/examples/tutorials/mat_transpose.py b/python/examples/tutorials/mat_transpose.py index 39f05c902..be31cd2cc 100644 --- a/python/examples/tutorials/mat_transpose.py +++ b/python/examples/tutorials/mat_transpose.py @@ -4,7 +4,9 @@ import triton class _transpose(torch.autograd.Function): src = """ __global__ void transpose(TYPE * X, TYPE * Y, - int M, int N, int ldx __multipleof(8), int ldy __multipleof(8)) { + int M __retune, + int N __retune, + int ldx __multipleof(8), int ldy __multipleof(8)) { // extract program ID int pidm = get_program_id(0); //(1) int pidn = get_program_id(1); //(2) diff --git a/python/setup.py b/python/setup.py index 1d32ad4c4..01f0c3049 100644 --- a/python/setup.py +++ b/python/setup.py @@ -8,9 +8,11 @@ import distutils import glob from distutils.version import LooseVersion from setuptools import setup, Extension, find_packages +from torch.utils.cpp_extension import include_paths, library_paths from setuptools.command.build_ext import build_ext from setuptools.command.test import test as TestCommand import distutils.spawn +import torch def find_llvm(): @@ -58,12 +60,17 @@ class CMakeBuild(build_ext): # python directories python_include_dirs = distutils.sysconfig.get_python_inc() python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR') + torch_include_dirs = include_paths(True) + torch_library_dirs = library_paths(True) + abi = torch._C._GLIBCXX_USE_CXX11_ABI cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, '-DBUILD_TESTS=OFF', '-DBUILD_PYTHON_MODULE=ON', #'-DPYTHON_EXECUTABLE=' + sys.executable, #'-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON, - '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, + '-DPYTHON_INCLUDE_DIRS=' + ';'.join([python_include_dirs] + include_paths(True)), + '-DPYTHON_LINK_DIRS=' + ';'.join(library_paths(True)), + '-DTORCH_LIBRARIES=c10;c10_cuda;torch;torch_cuda;torch_cpu;torch_python;triton', '-DLLVM_CONFIG=' + find_llvm()] # configuration cfg = 'Debug' if self.debug else 'Release' @@ -80,8 +87,6 @@ class CMakeBuild(build_ext): build_args += ['--', '-j4'] env = os.environ.copy() - env['CXXFLAGS'] = '{} -DVERSION_INFO=\\"{}\\"'.format(env.get('CXXFLAGS', ''), - self.distribution.get_version()) if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) sourcedir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) diff --git a/python/src/bindings.cc b/python/src/bindings.cc index ec4c1c6a1..b14f8cad4 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -3,20 +3,13 @@ #include #include #include -#include -#include #include "triton/runtime/function.h" #include "triton/runtime/arg.h" #include "triton/lang/code_gen.h" #include "triton/lang/parser.h" #include "triton/lang/cpp.h" -#include "triton/driver/device.h" -#include "triton/driver/stream.h" -#include "triton/driver/kernel.h" -#include "triton/driver/module.h" #include "triton/ir/module.h" #include "triton/ir/function.h" -#include "triton/tools/bench.hpp" using namespace triton; @@ -83,196 +76,6 @@ int64_t retrieve_scalar(size_t id) { return i64scalar_map.at(id); } -/* TF source-code generation */ - -inline std::string to_tf_ty(ir::type *ty) { - if(ty->is_integer_ty(1)) - return "bool"; - if(ty->is_integer_ty(8)) - return "int8"; - if(ty->is_integer_ty(16)) - return "int16"; - if(ty->is_integer_ty(32)) - return "int32"; - if(ty->is_integer_ty(64)) - return "int64"; - if(ty->is_half_ty()) - return "float16"; - if(ty->is_float_ty()) - return "float"; - if(ty->is_double_ty()) - return "double"; - if(ty->is_pointer_ty()) - return "Tensor"; - throw std::runtime_error("unknown type"); -} - -inline std::string to_tf_scalar_ty(ir::type *ty) { - if(ty->is_pointer_ty()) - return to_tf_ty(ty->get_pointer_element_ty()); - else { - return to_tf_ty(ty); - } -} - -inline std::string ref_to_tf_ty(ir::type *ty) { - std::string res = to_tf_ty(ty); - if(ty->is_pointer_ty()) - res = "const " + res + "&"; - return res; -} - -std::string tf_normalize(const std::string& name) { - std::string ret = name; - auto tolower = [](char c) { return std::tolower(c);}; - std::transform(ret.begin(), ret.end(), ret.begin(), tolower); - return ret; -} - -struct tf_alloc_t{ - enum type_t{ - OUTPUT, - TEMP - }; - - tf_alloc_t(const std::string& _name, type_t _type) - : name(_name), type(_type), tf_name(tf_normalize(_name)){ } - - std::string tf_name; - std::string name; - type_t type; - size_t shape_id; -}; - -typedef std::vector alloc_map_t; - - -void gen_extract_inputs(std::ostream &os, const std::vector& args, const alloc_map_t& allocs) { - for(unsigned i = 0; i < args.size(); i++){ - ir::value *arg = args[i]; - const std::string& name = arg->get_name(); - std::string ty = to_tf_ty(arg->get_type()); - if(!arg->get_type()->is_pointer_ty()) - os << " " << ty << " " << name << " = context->input(" << i << ").scalar<" << ty << ">()();\n "; - else if(std::find_if(allocs.begin(), allocs.end(), - [&](tf_alloc_t x) { - return x.name == name; - }) == allocs.end()) - os << " const Tensor* " << name << " = &context->input(" << i << ");\n "; - else - os << " Tensor* " << name << " = nullptr;\n "; - } -} - -void gen_set_outputs(std::ostream &os, const std::vector& args, const alloc_map_t& allocs) { - // initialize shapes - for(const auto& x: allocs) - os << " TensorShape " << x.name << "_shape;\n "; - for(const auto& x: allocs) - os << " const Tensor& " << x.name << "_shape_tensor = context->input(" << x.shape_id << ");\n "; - for(const auto& x: allocs) - os << " const int32* " << x.name << "_shape_data = (const int32*)" << x.name << "_shape_tensor.tensor_data().data();\n "; - for(const auto& x: allocs) - os << " size_t " << x.name << "_rank = " << x.name << "_shape_tensor.dim_size(0);\n "; - for(const auto& x: allocs) - os << " for(size_t d = 0; d < " << x.name << "_rank ; d++) " - << x.name << "_shape.AddDim(" << x.name << "_shape_data[d]);\n "; - - // allocate - int output = 0; - for(const auto& x: allocs){ - if(x.type == tf_alloc_t::OUTPUT) - os << " OP_REQUIRES_OK(context, context->allocate_output(" << output++ << ", " << x.name << "_shape, &" << x.name << "));\n "; - else - os << " OP_REQUIRES_OK(context, context->allocate_temp(" << x.name << "_type, " << x.name << "_shape, " << x.name << "));\n "; - } -} - -void gen_make_handles(std::ostream &os, const std::vector& args) { - for(unsigned i = 0; i < args.size(); i++){ - ir::argument *arg = args[i]; - if(!arg->get_type()->is_pointer_ty()) - continue; - const std::string& name = arg->get_name(); - os << " drv::cu_buffer cu_" + name + "(ctx, " + name + "->nbytes(), (CUdeviceptr)" + name + "->tensor_data().data(), false);\n "; - } -} - -void gen_make_launch_function(std::ostream &os, const std::vector& args) { - os << " std::function run = [&](){\n "; - os << " (*id_fn_map.at(id_))({"; - for(unsigned i = 0; i < args.size() ; i++){ - ir::argument *arg = args[i]; - std::string name = arg->get_name(); - if(arg->get_type()->is_pointer_ty()) - name = "&cu_" + name; - if(i > 0) - os << ", "; - os << name; - } - os << "}, *id_grid_map.at(id_), stream);\n "; - os << " };\n "; - os << " run();\n "; - os << " if(bench_ > 0)\n "; - os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n "; -} - -void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name, - const std::string &opname, - const std::vector& args, - const alloc_map_t& allocs){ - - os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)"; - for(size_t i = 0; i < args.size(); i++){ - ir::argument *arg = args[i]; - std::string name = tf_normalize(arg->get_name()); - if(!arg->get_type()->is_pointer_ty()) - os << ".HostMemory(\"" + name + "\")"; - } - for(const auto& x: allocs) - os << ".HostMemory(\"" << x.tf_name << "_shape\")"; - os << ", " + opname << ");\n"; -} - -void gen_tf_register_op(std::ostream &os, const std::string &name, - const std::vector& args, - const alloc_map_t& allocs){ - - - os << "REGISTER_OP(\"" << name << "\")\n"; - for(size_t i = 0; i < args.size(); i++) - os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl; - for(size_t i = 0; i < args.size(); i++){ - ir::argument *arg = args[i]; - std::string name = tf_normalize(arg->get_name()); - if(std::find_if(allocs.begin(), allocs.end(), - [&](tf_alloc_t x) { - return name == x.tf_name; - }) == allocs.end()) - os << " .Input(\"" << name << ": T" << i << "\")\n"; - else - os << " .Input(\"" << name << "_shape: int32\")\n"; - } - for(const auto& x: allocs) - if(x.type == tf_alloc_t::OUTPUT) - os << " .Output(\"" << x.tf_name << ": T" << x.shape_id << "\")\n"; - os << " .Attr(\"id: int\")\n"; - os << " .Attr(\"bench: int\")\n"; - os << " .Attr(\"bench_id: int\")\n"; - os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* ctx) {\n"; - size_t current = 0; - for(const auto& x: allocs) - if(x.type == tf_alloc_t::OUTPUT){ - os << " shape_inference::ShapeHandle " << x.tf_name << "_handle;\n"; - os << " ctx->MakeShapeFromShapeTensor(" << x.shape_id << ", &" << x.tf_name << "_handle);\n"; - os << " ctx->set_output(" << current++ << ", " << x.tf_name << "_handle);\n"; - } - os << " return Status::OK();\n"; - os << " })\n"; - - os << ";\n"; -} - void make_module(const std::string& src, ir::module* ir, const runtime::function::options_space_t& opt) { std::string copy = triton::runtime::function::preheader() + src; @@ -290,339 +93,6 @@ void make_module(const std::string& src, ir::module* ir, gen.Gen(ir); } -std::tuple make_tensorflow_src(const std::string& src, - const std::vector& outputs, - const std::vector& tmp, - const runtime::function::options_space_t& opt) -{ - // triton-ir code-gen - ir::context ctx; - auto ir = std::shared_ptr(new ir::module("", ctx)); - make_module(src, &*ir, opt); - - // function - ir::function* fn = ir->get_function_list().front(); - const std::vector& args = fn->args(); - std::string name = fn->get_name(); - std::string cc_name = name; - cc_name[0] = static_cast(std::toupper(cc_name[0])); - std::string opname = cc_name + "Op"; - - // allocation info - alloc_map_t allocs; - for(size_t i = 0; i < outputs.size(); i++) - allocs.push_back(tf_alloc_t(outputs[i], tf_alloc_t::OUTPUT)); - for(size_t i = 0; i < tmp.size(); i++) - allocs.push_back(tf_alloc_t(tmp[i], tf_alloc_t::TEMP)); - - for(auto &x: allocs){ - size_t idx; - for(idx = 0; idx < args.size(); idx++) - if(args[idx]->get_name() == x.name) - break; - if(idx == args.size()) - throw std::runtime_error("unknown output"); - x.shape_id = idx; - } - - std::ostringstream oss; - oss << R"( -#include "triton/driver/buffer.h" -#include "triton/driver/backend.h" -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" -#include "triton/tools/bench.hpp" - -#define EIGEN_USE_GPU -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/op_kernel.h" - -using namespace tensorflow; -using GPUDevice = Eigen::GpuDevice; -namespace rt = triton::runtime; -namespace drv = triton::driver; - -extern std::map> id_grid_map; -extern std::map> id_fn_map; -extern std::map i64scalar_map; - -class )" << opname << R"(: public OpKernel { - public: - explicit )" << opname << R"((OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); - OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_)); - OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_)); - )"; -for(const auto& alloc: allocs) - oss << " OP_REQUIRES_OK(context, context->GetAttr(\"T" << alloc.shape_id << "\", &" << alloc.name << "_type));\n "; - -oss << R"( - } - - void Compute(OpKernelContext* context){ - - // get device/stream - GPUDevice device = context->eigen_device(); - drv::cu_stream sstream(device.stream(), false); - drv::context* ctx = sstream.context(); - drv::stream* stream = &sstream; - - // extract inputs - )"; -gen_extract_inputs(oss, args, allocs); -oss << R"( - // set outputs - )"; -gen_set_outputs(oss, args, allocs); -oss << R"( - // wrap tensors - )"; -gen_make_handles(oss, args); -oss << R"( - )"; -oss << R"( - // launch function - )"; -gen_make_launch_function(oss, args); -oss << R"( - } - -private: - int id_; - int bench_; - int64 bench_id_; - )"; -for(const auto& alloc: allocs) - oss << "DataType " << alloc.name << "_type;\n "; - -oss << R"( -}; - -// register kernel builder -)"; -gen_tf_register_kernel_builder(oss, cc_name, opname, args, allocs); -oss << R"( -// register op -)"; -gen_tf_register_op(oss, cc_name, args, allocs); - - return std::tuple{oss.str(), name}; -} - - -inline std::string to_torch_ty(ir::type *ty) { - if(ty->is_integer_ty()) - return "int64_t"; - if(ty->is_half_ty()) - return "double"; - if(ty->is_float_ty()) - return "double"; - if(ty->is_double_ty()) - return "double"; - if(ty->is_pointer_ty()) - return "torch::Tensor"; - throw std::runtime_error("unknown type"); -} - -inline std::string to_torch_ty(rt::arg_type ty){ - switch(ty){ - case rt::INT1_T: return "int64_t"; - case rt::INT8_T: return "int64_t"; - case rt::INT16_T: return "int64_t"; - case rt::INT32_T: return "int64_t"; - case rt::INT64_T: return "int64_t"; - case rt::HALF_T: return "double"; - case rt::FLOAT_T: return "double"; - case rt::DOUBLE_T: return "double"; - case rt::BUFFER_T: return "torch::Tensor"; - default: return "UNKNOWN"; - } -} - -inline std::string to_c_ty(rt::arg_type ty){ - switch(ty){ - case rt::INT1_T: return "bool"; - case rt::INT8_T: return "int8_t"; - case rt::INT16_T: return "int16_t"; - case rt::INT32_T: return "int32_t"; - case rt::INT64_T: return "int64_t"; - case rt::HALF_T: return "half"; - case rt::FLOAT_T: return "float"; - case rt::DOUBLE_T: return "double"; - case rt::BUFFER_T: return "drv::cu_buffer"; - default: return "UNKNOWN"; - } -} - - -inline std::string to_c_ty(ir::type *ty) { - if(ty->is_integer_ty(1)) - return "bool"; - if(ty->is_integer_ty(8)) - return "int8_t"; - if(ty->is_integer_ty(16)) - return "int16_t"; - if(ty->is_integer_ty(32)) - return "int32_t"; - if(ty->is_integer_ty(64)) - return "int64_t"; - if(ty->is_half_ty()) - return "half"; - if(ty->is_float_ty()) - return "float"; - if(ty->is_double_ty()) - return "double"; - if(ty->is_pointer_ty()) - return "drv::cu_buffer"; - throw std::runtime_error("unknown type"); -} - - - -void gen_torch_signature(std::ostringstream& oss, - const std::string& name, - const std::vector& args) { - std::string ret_ty = "void"; - oss << ret_ty << " " << name << "("; - oss << "int64_t id, "; - oss << "int64_t dev_id, "; - oss << "int64_t bench, "; - oss << "int64_t bench_id, "; - for(size_t i = 0; i < args.size(); i++) { - if(i > 0) - oss << ", "; - oss << to_torch_ty(args[i]) << " " << "th_arg_" << i; - } - oss << ")"; -} - -void gen_torch_init_driver(std::ostringstream &oss, - const std::vector&args) { - // Find index of first buffer - size_t i; - for(i = 0; i < args.size(); i++) - if(args[i] == rt::BUFFER_T) - break; - oss << " // Wrap CUDA handles" << std::endl; - oss << " c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl; - oss << " // Get stream" << std::endl; - oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl; - oss << " triton::driver::cu_stream stream(custream, false);" << std::endl; - oss << " triton::driver::context* ctx = stream.context();" << std::endl; -} - -void gen_torch_make_handles(std::ostream &os, - const std::vector& args) { - for(unsigned i = 0; i < args.size(); i++){ - rt::arg_type arg = args[i]; - const std::string th_name = "th_arg_" + std::to_string(i); - const std::string name = "arg_" + std::to_string(i); - if(arg != rt::BUFFER_T) - os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl; - else{ - os << " CHECK_INPUT(" << th_name << ");" << std::endl; - os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".nbytes(), " - " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl; - } - } -} - -std::string get_val_struct_name(rt::arg_type ty){ - switch(ty){ - case rt::INT1_T: return "int1"; - case rt::INT8_T: return "int8"; - case rt::INT16_T: return "int16"; - case rt::INT32_T: return "int32"; - case rt::INT64_T: return "int64"; - case rt::HALF_T: return "fp16"; - case rt::FLOAT_T: return "fp32"; - case rt::DOUBLE_T: return "fp64"; - case rt::BUFFER_T: return "buf"; - default: return ""; - } -} - -void gen_torch_make_launch_function(std::ostream &os, - const std::vector& args) { - os << " namespace rt = triton::runtime;\n "; - os << " std::vector args;\n "; - for(unsigned i = 0; i < args.size(); i++){ - std::string name = "arg_" + std::to_string(i); - if(args[i] == rt::BUFFER_T) - name = "&" + name; - if(args[i] == rt::HALF_T) - name = "*((uint16_t*)&" + name + ")"; - os << "rt::arg_type ty" << i << " = (rt::arg_type)(" << args[i] << ");\n "; - os << "rt::arg::value_t val" << i << ";\n "; - os << "val" << i << "." << get_val_struct_name(args[i]) << " = " << name << ";\n "; - os << "args.push_back(rt::arg(ty" << i << ", val" << i << "));\n "; - } - os << " std::function run = [&](){\n "; - os << " (*id_fn_map.at({id, dev_id}))(args , *id_grid_map.at({id, dev_id}), &stream);\n"; - os << " };\n"; - os << " run();\n"; - os << " if(bench > 0)\n "; - os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n "; - } - -void gen_torch_ret(std::ostream &os, const std::vector& outputs) { - if(outputs.size() == 1){ - os << " return " << outputs[0] << ";" << std::endl; - return; - } - os << " return {"; - for(size_t i = 0; i < outputs.size(); i++){ - if(i > 0) - os << ", "; - os << outputs[i]; - } - os << "};" << std::endl; -} - -std::tuple make_torch_src(const std::string& name, std::vector args) { - // generate framework code - std::ostringstream oss; - oss << R"( -#include "triton/driver/buffer.h" -#include "triton/driver/stream.h" -#include "triton/runtime/function.h" -#include "triton/tools/bench.hpp" -#include "torch/script.h" -#include "ATen/cuda/CUDAContext.h" - -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); - -namespace rt = triton::runtime; -namespace drv = triton::driver; - -typedef std::pair map_key_t; -extern std::map> id_grid_map; -extern std::map> id_fn_map; -extern std::map i64scalar_map; - -)"; - - gen_torch_signature(oss, name, args); - oss << " {" << std::endl; - gen_torch_init_driver(oss, args); - gen_torch_make_handles(oss, args); - gen_torch_make_launch_function(oss, args); - //gen_torch_ret(oss); - oss << "}" << std::endl; - - oss << std::endl; - oss << std::endl; - oss << "static auto registry = torch::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl; - - return std::tuple{oss.str(), name}; -} - /* Function signature */ std::vector get_fn_signature(const std::string& src, const runtime::function::options_space_t& opt) { @@ -646,13 +116,6 @@ typedef triton::runtime::function::options_space_t options_space_t; PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; - // framework binding source code generation - m.def("make_tensorflow_src", &make_tensorflow_src, - "Creates C++ source code for a custom Tensorflow op " - "corresponding to the specified Triton kernel"); - m.def("make_torch_src", &make_torch_src, - "Creates C++ source code for a custom PyTorch op "); - // bindings for triton classes pybind11::enum_(m, "arg_type") .value("int1", rt::INT1_T) diff --git a/python/src/launch.cc b/python/src/launch.cc new file mode 100644 index 000000000..1911bc7b3 --- /dev/null +++ b/python/src/launch.cc @@ -0,0 +1,27 @@ +#include "triton/driver/buffer.h" +#include "triton/driver/stream.h" +#include "triton/runtime/function.h" +#include "triton/tools/bench.hpp" +#include "torch/script.h" +#include "ATen/cuda/CUDAContext.h" + +#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); + +namespace rt = triton::runtime; +namespace drv = triton::driver; + +typedef std::pair map_key_t; +extern std::map> id_grid_map; +extern std::map> id_fn_map; + +void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){ + CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream(); + triton::driver::cu_stream stream(custream, false); + triton::driver::context* ctx = stream.context(); + (*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream); +} + + +static auto registry = torch::RegisterOperators("triton::launch_kernel", &launch_kernel); diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 535f717ec..afaa70ecd 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,7 +1,6 @@ from .kernel import * -from .utils import * -import triton.ops -import triton.nn +#import triton.ops +#import triton.nn # clean-up libtriton resources diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py deleted file mode 100644 index 9385bc212..000000000 --- a/python/triton/frameworks.py +++ /dev/null @@ -1,28 +0,0 @@ -import sys -import os -import triton._C.libtriton as libtriton - -torch = None -tensorflow = None - -def _import_torch(): - global torch - if torch is None: - import torch - -def _import_tensorflow(): - global tensorflow - if tensorflow is None: - import tensorflow - -def has_tensorflow(): - result = 'tensorflow' in sys.modules - if result: - _import_tensorflow() - return result - -def has_torch(): - result = 'torch' in sys.modules - if result: - _import_torch() - return result \ No newline at end of file diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 019819b7e..51cf08499 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -1,181 +1,71 @@ -# import for cache -import os -import tempfile -import shutil -import hashlib -import sysconfig -import sys -import weakref -import contextlib -import io -import torch.utils.cpp_extension -# import for just-in-time compilation -import distutils -import setuptools.command.build_ext -import setuptools -# triton -import triton.frameworks as fw -import triton.utils import triton._C.libtriton as libtriton import os import time -import platform +from struct import pack +import torch -@contextlib.contextmanager -def quiet(): - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout, sys.stderr = io.StringIO(), io.StringIO() - try: - yield - finally: - sys.stdout, sys.stderr = old_stdout, old_stderr +codes = { + libtriton.arg_type.int1: 'B', + libtriton.arg_type.int8: 'B', + libtriton.arg_type.int32: 'I', + libtriton.arg_type.int64: 'Q', + libtriton.arg_type.half: 'H', + libtriton.arg_type.float: 'f', + libtriton.arg_type.double: 'd', + libtriton.arg_type.buffer: 'P' +} -def _build(src, path, name): - ccdir = os.path.join(libtriton.__file__, os.path.pardir) - ccdir = os.path.realpath(ccdir) - # include / libraries - include_dirs = [os.path.join(ccdir, 'include')] - library_dirs = [ccdir] - libraries = ['triton'] - # create extension module - abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI - extra_compile_args = ['-fPIC', '-Wno-deprecated-declarations', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(abi))}'] - extra_compile_args += ['-DTORCH_EXTENSION_NAME={}'.format(name)] - extra_compile_args += ['-DTORCH_API_INCLUDE_EXTENSION_H'] - - ext = torch.utils.cpp_extension.CUDAExtension( - name = name, - language = 'c++', - sources = [src], - include_dirs = include_dirs, - library_dirs = library_dirs, - libraries = libraries, - extra_compile_args = extra_compile_args, - depends = [os.path.realpath(libtriton.__file__)] - ) - # build extension module - args = ['build_ext'] - tmp = tempfile.mkdtemp() - args.append('--build-temp=' + tmp) - args.append('--build-lib=' + path) - args.append('-q') - args = dict( - name = name, - ext_modules = [ext], - script_args = args, - ) - with quiet(): - setuptools.setup(**args) - shutil.rmtree(tmp) - -def _cvt_to_def_str(obj): - # bool - if isinstance(obj, bool): - return str(int(obj)) - # torch type - if fw.has_torch(): - if isinstance(obj, fw.torch.dtype): - return {fw.torch.int8: 'char', - fw.torch.int16: 'short', - fw.torch.int32: 'int', - fw.torch.int64: 'long', - fw.torch.float16: 'half', - fw.torch.float32: 'float', - fw.torch.float64: 'double'}[obj] - else: - assert False - # default - return str(obj) +sizes = { + libtriton.arg_type.int1: 1, + libtriton.arg_type.int8: 1, + libtriton.arg_type.int32: 4, + libtriton.arg_type.int64: 8, + libtriton.arg_type.half: 2, + libtriton.arg_type.float: 4, + libtriton.arg_type.double: 8, + libtriton.arg_type.buffer: 8 +} -def _encode(arg_types): - codes = { - libtriton.arg_type.int1: 'i1', - libtriton.arg_type.int8: 'i8', - libtriton.arg_type.int32: 'i32', - libtriton.arg_type.int64: 'i64', - libtriton.arg_type.half: 'f16', - libtriton.arg_type.float: 'f32', - libtriton.arg_type.double: 'f64', - libtriton.arg_type.buffer: 'buf' +def th_to_triton(obj): + tys = { + torch.int8: 'char', + torch.int16: 'short', + torch.int32: 'int', + torch.int64: 'long', + torch.float16: 'half', + torch.float32: 'float', + torch.float64: 'double' } - ret = '_'.join(map(codes.get, arg_types)) - return ret + if isinstance(obj, torch.dtype): + return [tys[obj]] + if isinstance(obj, list): + return [th_to_triton(x)[0] for x in obj] + return [str(obj)] -def _make_framework_op(arg_types): - name = _encode(arg_types) - # path of .cpp and .so file - home = os.path.expanduser('~') - root = os.path.join(home, '.triton', 'torch', name) - try: - os.makedirs(root) - except FileExistsError: - pass - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(root, f'op{suffix}') - cpp = os.path.join(root, f'op.cpp') - # handle cached .so file - if os.path.exists(so) and os.stat(so).st_size > 0: - tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime - so_mtime = os.stat(so).st_mtime - # can use cached if libtriton is older than the .so - if tt_mtime < so_mtime: - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) - # create torch source code - print('[TRITON] Compiling op...') - baton = torch.utils.file_baton.FileBaton(os.path.join(root, 'lock')) - if baton.try_acquire(): - try: - src, _ = libtriton.make_torch_src(name, arg_types) - with open(cpp, 'w+') as handle: - handle.writelines(src) - ccdir = os.path.join(libtriton.__file__, os.path.pardir) - ccdir = os.path.realpath(ccdir) - _build(cpp, root, 'op') - finally: - baton.release() - else: - baton.wait() - print('[TRITON] Done compiling...') - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) - - - +def cdiv(a, b): + return (a + b - 1) // b class kernel: def __init__(self, src, defines = dict(), num_warps = [2, 4, 8]): self.src = src - # create constants - self.cst = dict() - # create triton op - macros = [] - for k, v in defines.items(): - cvt = lambda x: _cvt_to_def_str(x) - if(isinstance(v, list)): - values = list(map(cvt, v)) - else: - values = [cvt(v)] - macros.append((k, values)) - opt = libtriton.options_space() - opt.defines = macros - opt.num_warps = num_warps + self.opt = libtriton.options_space() + self.opt.defines = [(k, th_to_triton(v)) for k, v in defines.items()] + self.opt.num_warps = num_warps self.op_id = libtriton.make_op_id() - self.opt = opt self.registered = set() - # create pytorch hook - arg_types = libtriton.get_fn_signature(self.src, opt) - self.fw_op = _make_framework_op(arg_types) + arg_types = libtriton.get_fn_signature(self.src, self.opt) + size = sum([sizes[x] for x in arg_types]) + self.tys = ''.join([codes[x] for x in arg_types]) def set_constant(self, device, name, value): libtriton.register_cst((self.op_id, device), name, value) def __call__(self, *args, **kwargs): for x in args: - if isinstance(x, fw.torch.Tensor): + if isinstance(x, torch.Tensor): device = x.device.index break # lazily register function for device @@ -191,6 +81,6 @@ class kernel: grid = kwargs['grid'] libtriton.register_grid((self.op_id, device), grid) # launch - self.fw_op(self.op_id, device, bench, bench_id, *args) - if bench > 0: - return libtriton.retrieve_scalar(bench_id) \ No newline at end of file + params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) + torch.cuda.synchronize() + torch.ops.triton.launch_kernel(self.op_id, device, params) \ No newline at end of file diff --git a/python/triton/utils.py b/python/triton/utils.py deleted file mode 100644 index 7112870d2..000000000 --- a/python/triton/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -import triton.frameworks as fw -import triton._C.libtriton as libtriton -import numpy as np -import weakref - -def cdiv(a, b): - return (a + b - 1) // b - -class tf_empty_proxy: - - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - self.tensor = None - - def to_tensor(self): - assert self.tensor is not None - return self.tensor - -def empty(shape, dtype): - if fw.has_tensorflow(): - shape = [fw.tensorflow.constant(x) for x in shape] - shape = fw.tensorflow.stack(shape) - return tf_empty_proxy(shape, dtype) - #return fw.tf_extra_ops.alloc_empty(args, T = dtype) - elif fw.has_torch(): - return fw.torch.empty(shape, dtype=dtype, device='cuda:0') - -def shape(A) : - if fw.has_tensorflow(): - return A.shape.as_list() - elif fw.has_torch(): - return A.shape - else: - assert False - - -class id_dict: - - # Lazy entry for e.g., tensorflow, when value of benchmark is - # not known at graph compile time - class lazy_entry: - def __init__(self, id): - self.id = id - - def get(self): - return libtriton.retrieve_scalar(self.id) - - def __init__(self): - self.data = dict() - - def __delitem__(self, key): - del self.data[key] - - @staticmethod - def _get_key(key): - if fw.has_tensorflow(): - if isinstance(key, fw.tensorflow.Tensor): - key = id(key.op) - if fw.has_torch(): - if isinstance(key, fw.torch.Tensor): - key = id(key) - return key - - def __getitem__(self, key): - ret = self.data[id_dict._get_key(key)] - if isinstance(ret, id_dict.lazy_entry): - return ret.get() - return ret - - def __len__(self): - return len(self.data) - - def __setitem__(self, key, value): - self.data[id_dict._get_key(key)] = value \ No newline at end of file diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 9f1260469..9f5a01d9c 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -17,11 +17,11 @@ int main() { // config_t{ord, x[0], x[1], 384, 384, 384}, // config_t{ord, x[0], x[1], 512, 512, 512}, // config_t{ord, x[0], x[1], 768, 768, 768}, -// config_t{ord, x[0], x[1], 1024, 1024, 1024}, + config_t{ord, x[0], x[1], 1024, 1024, 1024}, // config_t{ord, x[0], x[1], 1280, 1280, 1280}, // config_t{ord, x[0], x[1], 1536, 1536, 1536}, // config_t{ord, x[0], x[1], 2048, 2048, 2048}, - config_t{ord, x[0], x[1], 8192, 8192, 8192}, +// config_t{ord, x[0], x[1], 8192, 8192, 8192}, // config_t{ord, x[0], x[1], 256, 16, 256}, // config_t{ord, x[0], x[1], 512, 16, 512}, @@ -65,7 +65,7 @@ int main() { for(const auto& c: configs){ std::tie(ord, AT, BT, M, N, K) = c; std::cout << "// " << c ; - for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) + for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord)) std::cout << ", " << perf << std::flush; std::cout << std::endl; } diff --git a/tests/common/dot.h b/tests/common/dot.h index 6433ec562..74c7d78ee 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -2,6 +2,7 @@ #include #include #include +#include #include "triton/driver/backend.h" #include "triton/driver/stream.h" #include "triton/tools/bench.hpp" @@ -12,6 +13,24 @@ #include "util.h" +//struct dot_arg_t{ +// CUdeviceptr a; +// CUdeviceptr b; +// CUdeviceptr c; +// float alpha; +// int M; +// int N; +// int K; +// int lda; +// int ldb; +// int ldc; +// CUdeviceptr locks; +//}; + +//typedef std::tuple dot_arg_t; + template static void cc_dot(std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K){ @@ -108,6 +127,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, opt.defines.push_back({"TM", {std::to_string(TM)}}); opt.defines.push_back({"TN", {std::to_string(TN)}}); opt.defines.push_back({"TK", {std::to_string(TK)}}); + opt.defines.push_back({"TZ", {"1"}}); opt.num_warps = {nwarp}; } if(mode == BENCH) { @@ -119,9 +139,25 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, } // kernels - rt::function function(src::dot, opt); - std::vector args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc, &*dlocks}; + float alpha = 1; + char args[60]; + memcpy(args + 0, &*da->cu(), 8); + memcpy(args + 8, &*db->cu(), 8); + memcpy(args + 16, &*dc->cu(), 8); + memcpy(args + 24, &alpha, 4); + memcpy(args + 28, &M, 4); + memcpy(args + 32, &N, 4); + memcpy(args + 36, &K, 4); + memcpy(args + 40, &lda, 4); + memcpy(args + 44, &ldb, 4); + memcpy(args + 48, &ldc, 4); + memcpy(args + 52, &*dlocks->cu(), 8); + + +// dot_arg_t args = {*da->cu(), *db->cu(), *dc->cu(), +// 1, M, N, K, lda, ldb, ldc, *dlocks->cu()}; +// std::cout << sizeof(dot_arg_t) << std::endl; auto grid = [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN")), @@ -131,7 +167,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, // metrics if(mode == BENCH){ auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; - double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function((void**)&args, grid, stream);}, stream); bench.push_back(tflops(triton_ns)); // cublas @@ -162,7 +198,7 @@ void triton_dot(drv::stream* stream, bool AT, bool BT, stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); // run kernel - function(args, grid, stream); + function((void**)&args, grid, stream); // write back stream->synchronize(); // compare with CPU diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index e94e03b5c..54e58e13e 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -6,7 +6,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), TYPE * B __noalias __readonly __aligned(16), TYPE * C __noalias __aligned(16), float alpha, - int M, int N, int K __multipleof(16), + int M __retune, + int N __retune, + int K __retune __multipleof(16), int lda __multipleof(8), int ldb __multipleof(8), int ldc __multipleof(8), diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index 6c24386ea..300ca9427 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -16,7 +16,7 @@ int main() { for(int nwarps: std::vector{4}) for(bool AT: std::array{false, true}) for(bool BT: std::array{false, true}){ - configs.push_back(config_t{HALF, AT, BT, TM, TN, TK, TM, TN, TK, nwarps}); + configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps}); } // test dtype_t dtype;