diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h
index d62142c7d..2e6a390e9 100755
--- a/include/triton/driver/module.h
+++ b/include/triton/driver/module.h
@@ -74,6 +74,7 @@ public:
   cu_module(driver::context* context, std::unique_ptr module);
   cu_module(driver::context* context, const std::string& source);
   std::unique_ptr symbol(const char * name) const;
+  const std::string& source() const { return source_; }

 private:
   std::string source_;
diff --git a/include/triton/runtime/arg.h b/include/triton/runtime/arg.h
index 6e255f0e7..1a741077c 100644
--- a/include/triton/runtime/arg.h
+++ b/include/triton/runtime/arg.h
@@ -7,6 +7,9 @@
 #include

 namespace triton{
+namespace ir{
+  class type;
+}

 namespace driver{
   class buffer;
@@ -26,6 +29,9 @@ enum arg_type {
   BUFFER_T
 };

+arg_type convert(ir::type *ty);
+
+
 inline size_t size_of(arg_type ty){
   switch(ty){
     case INT1_T: return 1;
diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 26253b7ee..240e4944f 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 // codegen
 #include "triton/ir/context.h"
 #include "triton/codegen/target.h"
@@ -68,6 +69,11 @@ public:
     T D(const std::string& name) const {
       return convert(defines.at(name));
     }
+    bool operator<(const options_t& other) const {
+      return std::make_pair(defines, num_warps) <
+             std::make_pair(other.defines, other.num_warps);
+    }
+    std::string to_str() const;
     std::map defines;
     size_t num_warps;
@@ -79,41 +85,63 @@ public:
 private:
   class caller {
   public:
-    caller(ir::function *ir, std::shared_ptr program, const options_t& opt_);
-    void operator()(driver::stream *stream, const grid_t& grid, const std::vector& args) const;
+    // constructors
+    caller(driver::context* ctx, std::ifstream& ifs, const options_t& opt);
+    caller(ir::function *ir, std::shared_ptr program, const options_t& opt);
+    // serialization
+    void write(std::ofstream& ofs);
+    void read(driver::context* ctx, std::ifstream& ifs);
+    // accessors
     const options_t opt() const { return opt_; }
+    const driver::module* parent() const { return &*parent_; }
+    // entry points
+    void operator()(driver::stream *stream, const grid_t& grid, const std::vector& args) const;
   private:
     std::shared_ptr bin_;
     std::shared_ptr parent_;
     std::vector param_tys_;
     options_t opt_;
+    std::string name_;
   };

 private:
   typedef std::pair> cache_key_t;

 private:
+  // cache
+  static std::string get_cache_prefix();
+  // make
   triton::lang::translation_unit *make_ast(const std::string &src);
   std::unique_ptr make_ir(Parser &parser);
   std::unique_ptr make_bin(ir::module &function, driver::context *context, const options_t &opt);
-  caller autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector &args);
+  caller *make(driver::stream *stream, options_t opt);
+  void precompile(driver::stream *stream, const options_space_t& tuning_space);
+  // autotune
+  function::cache_key_t get_key(driver::stream *stream, const std::vector& args);
+  caller* autotune(driver::stream *stream, const grid_fn_ty& grid, const std::vector &args);

 public:
   static std::string preheader();

 public:
-  function(const std::string& src, const options_space_t& opt = options_space_t());
+  function(const std::string& src, const options_space_t& opt, const std::string &cache_ref = "");
   void operator()(const std::vector& args, const grid_t& grid, driver::stream* stream);
   void operator()(const std::vector& args, const grid_fn_ty& grid, driver::stream *stream);
   void
       set_cst(const std::string& name, void* data, size_t n_bytes);

 private:
+  std::map> cst_;
+  // pre-compilation
   ir::context ctx_;
   std::string src_;
-  options_space_t opt_space_;
-  std::map cache_;
-  std::map> cst_;
+  options_space_t opt_;
+  std::set compiled_;
+  std::map> callers_;
+  // caching
+  std::string cache_ref_;
+  std::string cache_path_;
+  std::map cache_;
 };

 }
diff --git a/include/triton/tools/sys/mkdir.hpp b/include/triton/tools/sys/mkdir.hpp
index e6c289535..5198a0098 100755
--- a/include/triton/tools/sys/mkdir.hpp
+++ b/include/triton/tools/sys/mkdir.hpp
@@ -61,6 +61,14 @@ namespace tools
       return (status==0 || errno==EEXIST)?0:-1;
     }

+    inline int mtime(std::string const & path)
+    {
+      struct stat st;
+      if(stat(path.c_str(), &st) != 0)
+        return 0;
+      return st.st_mtime;
+    }
+
   }
 }
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index 28940f563..51b722ecf 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -253,11 +253,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module,
   return result;
 }
+
 cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
   cu_context::context_switcher ctx(*context);
-// std::cout << source << std::endl;
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
   unsigned int errbufsize = 8096;
@@ -266,10 +266,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
   try{
     dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
   }catch(exception::cuda::base const &){
-//#ifdef TRITON_LOG_PTX_ERROR
+#ifdef TRITON_LOG_PTX_ERROR
     std::cerr << "Compilation Failed!
Log: " << std::endl; std::cerr << errbuf << std::endl; -//#endif +#endif throw; } } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 25d355f2c..c5b109f6a 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -28,24 +28,28 @@ #include "triton/ir/function.h" #include "triton/ir/print.h" #include "triton/tools/bench.hpp" +#include "triton/tools/sha1.hpp" +#include "triton/tools/sys/getenv.hpp" +#include "triton/tools/sys/mkdir.hpp" #include "llvm/IR/Module.h" #include +#include std::mutex mut; namespace triton{ namespace runtime { -// helpers -void _parallel_loop_nest(std::vector const & ranges, - std::function const &)> const & f, - size_t nthreads){ +/* --------------------- */ +/* HELPERS */ +/* --------------------- */ + +void _loop_nest(std::vector const & ranges, + std::function const &)> const & f){ size_t D = ranges.size(); std::vector values(D, 0); - // Start with innermost loop size_t i = D - 1; while(true){ - // Execute function f(values); while(values[i]++ == ranges[i] - 1){ if(i == 0) @@ -56,24 +60,31 @@ void _parallel_loop_nest(std::vector const & ranges, } } -template -void _parallel_loop_nest(std::vector> const & iterates, std::function)> const & f, size_t nthreads){ - //Ranges to iterate over - std::vector ranges; - for(auto const & x: iterates) - ranges.push_back(x.size()); - //Proxy function - auto proxy = [&](std::vector const & idx){ - std::vector x(iterates.size()); - for(size_t i = 0; i < x.size(); ++i) - x[i] = iterates[i][idx[i]]; - f(x); - }; - //Iterate - _parallel_loop_nest(ranges, proxy, nthreads); + +/* --------------------- */ +/* OPTIONS */ +/* --------------------- */ + +std::string function::options_t::to_str() const{ + std::string ret = "nw-" + std::to_string(num_warps); + for(const auto& x : defines){ + ret += '-'; + ret += x.first; + ret += '-'; + ret += x.second; + } + // legalize + for(char& x: ret){ + if(x == ' ' || x == '^' || x == ',' || x == ':') + x = '_'; + } + return ret; } -// caller + +/* --------------------- */ +/* CALLER OBJECT */ +/* --------------------- */ arg_type convert(ir::type *ty) { if(ty->is_integer_ty(1)) @@ -97,8 +108,46 @@ arg_type convert(ir::type *ty) { throw std::runtime_error("unknown type"); } -function::caller::caller(ir::function *ir, std::shared_ptr parent, const options_t& opt) - : bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), parent_(parent), opt_(opt) { +void function::caller::write(std::ofstream &ofs) { + // write name + ofs << name_ << std::endl; + // write signature + for(size_t i = 0; i < param_tys_.size(); i++) + ofs << param_tys_[i] << " "; + ofs << std::endl; + // write module + std::string source = ((driver::cu_module*)(&*parent_))->source(); + ofs << source; +} + +void function::caller::read(driver::context* ctx, std::ifstream &ifs) { + // read name + std::getline(ifs, name_); + // read signature + std::string line; + std::getline(ifs, line); + std::istringstream current(line); + int param; + param_tys_.clear(); + while(current >> param) + param_tys_.push_back((arg_type)param); + // read module + std::string src((std::istreambuf_iterator(ifs)), + std::istreambuf_iterator()); + parent_.reset(new driver::cu_module(ctx, src)); + bin_.reset(driver::kernel::create(&*parent_, name_.c_str())); + +} + +function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt) + : opt_(opt) { + read(ctx, ifs); +} + +function::caller::caller(ir::function *ir, + std::shared_ptr parent, const options_t& opt) + : parent_(parent), opt_(opt), 
name_(ir->get_name()) { + bin_.reset(driver::kernel::create(&*parent, name_.c_str())); // extract signature ir::function_type* ty = ir->get_fn_type(); for(size_t i = 0; i < ty->get_num_params(); i++) @@ -109,6 +158,7 @@ function::caller::caller(ir::function *ir, std::shared_ptr paren void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, const std::vector& args) const { if(args.size() != param_tys_.size()) throw std::runtime_error("invalid number of arguments"); + // set arguments for(size_t i = 0; i < args.size(); i++){ arg arg_i = args.at(i); arg_type ty = arg_i.type(); @@ -119,99 +169,33 @@ void function::caller::operator ()(driver::stream *stream, const grid_t& _grid, else bin_->setArg(i, size_of(ty), arg_i.data()); } - // sanity check + // set grid if(_grid.size() > 3) throw std::runtime_error("grid size must be no greater than 3"); std::array grid; for(size_t i = 0; i < 3; i++) grid[i] = (i < _grid.size()) ? _grid[i] : 1; + // enqueue stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}); } +/* --------------------- */ +/* FUNCTION */ +/* --------------------- */ + +// create Triton-IR from AST std::unique_ptr function::make_ir(Parser& parser) { - // create Triton-IR from AST ir::module* module = new ir::module("", ctx_); Generator gen(&parser); gen.Gen(module); return std::unique_ptr(module); } - -function::caller function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, - const std::vector& args) { - - // all tuning parameters are strings - std::vector num_warps; - for(size_t i: opt_space_.num_warps) - num_warps.push_back(std::to_string(i)); - std::vector> space; - space.push_back(num_warps); - for(const auto& i: opt_space_.defines) - space.push_back(i.second); - - // exhaustive search - double best_ts = INFINITY; - std::unique_ptr ret; - - auto benchmark = [&](std::vector params) { - // extract options - options_t opt; - unsigned i = 0; - opt.num_warps = std::stoi(params[i++]); - for(auto it: opt_space_.defines){ - opt.defines[it.first] = params[i++]; - } - // pre-process - TokenSequence tokens; - Preprocessor cpp(&src_, true); - for(auto it: opt_space_.defines) - cpp.AddMacro(it.first, &opt.defines.at(it.first)); - cpp.Process(tokens); - - // parse - Parser parser(tokens); - parser.Parse(); - // triton-ir code-gen - auto ir = make_ir(parser); - // binary code-gen - std::unique_ptr bin; - try{ - bin = make_bin(*ir, stream->context(), opt); - }catch(const std::runtime_error& e){ - return; - } - // kernel uses too much resources - if(!bin) - return; - // copy constants - std::unique_ptr buffer; - for(ir::alloc_const* alloc: ir->allocs()){ - std::string name = alloc->get_name(); - auto it = cst_.find(name); - if(it == cst_.end()) - throw std::runtime_error("constant not set before execution"); - buffer = bin->symbol(name.c_str()); - stream->write(&*buffer, true, 0, it->second); - } - // benchmark - ir::function *tmp = ir->get_function_list()[0]; - caller call(tmp, std::move(bin), opt); - double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream, true); - // save best - if(ts < best_ts) { - best_ts = ts; - ret.reset(new caller(call)); - } - }; - _parallel_loop_nest(space, benchmark, 1); - if(!ret) - throw std::runtime_error("could not find valid option in provided space"); - return *ret; -} - - -std::unique_ptr function::make_bin(ir::module &module, driver::context *context, const options_t& opt) { +// create Binary from Triton-IR +std::unique_ptr function::make_bin(ir::module &module, + driver::context *context, + const 
options_t& opt) { std::unique_ptr target = context->device()->make_target(); // generate llvm code llvm::LLVMContext ctx; @@ -236,8 +220,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); peephole.run(module); dce.run(module); -// ir::print(module, std::cout); -// exit(EXIT_FAILURE); align.run(module); cts.run(module); axes.run(module); @@ -258,16 +240,135 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c return std::unique_ptr(); barriers.run(module); isel.visit(module, *llvm); - // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); - // done -// exit(EXIT_FAILURE); return res; } + +// create Binary from options +function::caller* function::make(driver::stream *stream, options_t opt) { + // cache path + std::string cache_path = cache_path_ + opt.to_str() + ".ptx"; + int ref_mtime = tools::mtime(cache_ref_); + int ptx_mtime = tools::mtime(cache_path); + // if cached ptx is newer than reference library + if(!ref_mtime || !ptx_mtime || ref_mtime < ptx_mtime){ + std::ifstream ifs(cache_path); + // file is empty -- invalid + if(ifs && ifs.peek() == std::ifstream::traits_type::eof()) + return nullptr; + // load cached caller + if(ifs) + return new caller(stream->context(), ifs, opt); + } + // pre-process + TokenSequence tokens; + Preprocessor cpp(&src_, true); + for(auto it: opt.defines) + cpp.AddMacro(it.first, &it.second); + cpp.Process(tokens); + // src -> ast + Parser parser(tokens); + parser.Parse(); + // ast -> triton-ir + auto ir = make_ir(parser); + // triton-ir -> binary + std::unique_ptr bin; + try{ + bin = make_bin(*ir, stream->context(), opt); + }catch(const std::runtime_error&){ + if(!cache_path_.empty()) + std::ofstream ofs(cache_path); + return nullptr; + } + // create callable + ir::function *tmp = ir->get_function_list()[0]; + caller* ret = new caller(tmp, std::move(bin), opt); + // serialize callable + if(!cache_path_.empty()){ + std::ofstream ofs(cache_path); + ret->write(ofs); + } + return ret; +} + +// precompile all kernels spanned by given options space +void function::precompile(driver::stream* stream, + const options_space_t& space) { + // all ranges + std::vector ranges; + ranges.push_back(space.num_warps.size()); + for(const auto& x: space.defines) + ranges.push_back(x.second.size()); + // functor for source with given option + auto do_make = [&](std::vector params) { + // compilation options + unsigned i = 0; + options_t opt; + opt.num_warps = space.num_warps[params[i++]]; + for(auto D: space.defines) + opt.defines[D.first] = D.second[params[i++]]; + // compile + caller* call = make(stream, opt); + if(!call) + return; + // copy constants + std::unique_ptr buffer; + for(const auto& cst: cst_){ + buffer = call->parent()->symbol(cst.first.c_str()); + stream->write(&*buffer, true, 0, cst.second); + } + callers_[opt].reset(call); + }; + // multi-threaded compilation + _loop_nest(ranges, do_make); + if(callers_.empty()) + throw std::runtime_error("could not find valid option in provided space"); +} + +// return auto-tuning key for given function arguments +function::cache_key_t function::get_key(driver::stream *stream, const std::vector& args) { + cache_key_t ret; + ret.first = stream->context()->device(); + for(size_t i = 0; i < args.size(); i++){ + arg_type ty = args.at(i).type(); + if(!is_int_type(ty)) + continue; + long val = 0; + std::memcpy((void*)&val, args.at(i).data(), size_of(ty)); + ret.second.push_back(val); + } + return ret; +} +// returns program with best compilation 
options for given parameter +function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& grid_fn, + const std::vector& args) { + // fast path -- no autotuning necessary + if(callers_.size() == 1) + return &*callers_.begin()->second; + // slow path -- autotuning necessary + double best_ts = INFINITY; + caller* ret = nullptr; + for(auto &x : callers_){ + if(x.second == nullptr) + throw std::runtime_error("configuration not compiled"); + caller* current = &*x.second; + double ts = tools::bench([&]() { (*current)(stream, grid_fn(x.first), args); }, + stream, true); + ret = (ts < best_ts) ? current : ret; + best_ts = std::min(ts, best_ts); + } + return ret; +} + +// set copy host buffer "data" into constant memory buffer "name" +void function::set_cst(const std::string& name, void* data, size_t n_bytes) { + cst_[name] = std::vector((char*)data, (char*)data + n_bytes); +} + + std::string function::preheader() { -return -R"( + return R"( #define bool _Bool #define true 1 #define false 0 @@ -297,47 +398,65 @@ typedef long int64; )"; } -function::function(const std::string &src, const options_space_t& opt): src_(src), opt_space_(opt) { +std::string function::get_cache_prefix() { + //user-specified cache path + std::string result = tools::getenv("TRITON_CACHE_PATH"); + if(!result.empty()){ + if(tools::mkpath(result)==0) + return result; + } + //create in home + result = tools::getenv("HOME"); + if(!result.empty()) + { + result = result + "/.triton/cache/"; + if(tools::mkpath(result)==0) + return result; + } + return ""; +} + +function::function(const std::string &src, + const options_space_t& opt, + const std::string &cache_ref): + src_(src), opt_(opt), cache_ref_(cache_ref) { + // hash source code + unsigned char hash[20]; + sha1::calc((void*)src_.data(), src_.size(), hash); + // create cache path + char _hex[40]; + sha1::toHexString(hash, _hex); + std::string hex(_hex, _hex + 40); + cache_path_ = get_cache_prefix() + hex + "/"; + tools::mkpath(cache_path_); + // append pre-header to source src_ = preheader() + src_; } -void function::operator()(const std::vector& args, const grid_fn_ty& grid_fn, driver::stream *stream) { - cache_key_t key; - - /* figure out if the kernel should be re-tuned */ - // re-tune if device is different - key.first = stream->context()->device(); - // re-tune if any int argument is different - for(size_t i = 0; i < args.size(); i++){ - arg_type ty = args.at(i).type(); - if(is_int_type(ty)){ - long val = 0; - std::memcpy((void*)&val, args.at(i).data(), size_of(ty)); - key.second.push_back(val); - } - } - - /* find existing configuration */ +void function::operator()(const std::vector& args, + const grid_fn_ty& grid_fn, + driver::stream *stream) { + // pre-compile kernels + if(callers_.empty()) + precompile(stream, opt_); + // auto-tune if necessary + auto key = get_key(stream, args); auto it = cache_.find(key); - if(it != cache_.end()){ - it->second(stream, grid_fn(it->second.opt()), args); - return; - } - - /* re-tune and re-compile */ - { - std::lock_guard lock(mut); - cache_.insert({key, autotune(stream, grid_fn, args)}); + if(it == cache_.end()){ + auto best = autotune(stream, grid_fn, args); + it = cache_.insert({key, best}).first; } + // run + (*it->second)(stream, grid_fn(it->second->opt()), args); } -void function::operator()(const std::vector& args, const grid_t& grid, driver::stream *stream) { +void function::operator()(const std::vector& args, + const grid_t& grid, + driver::stream *stream) { return this->operator()(args, [&grid](const 
options_t&){ return grid; }, stream); } -void function::set_cst(const std::string& name, void* data, size_t n_bytes) { - cst_[name] = std::vector((char*)data, (char*)data + n_bytes); -} + } } diff --git a/python/examples/einsum.py b/python/examples/einsum.py index a971c6683..ce6d49210 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -17,14 +17,14 @@ MNK = [ (2048, 2048, 2048), #(8192, 8192, 8192), - # (64, 64, 64000), - # (64, 64, 128000), - # (256, 256, 64000), - # (256, 256, 128000), + (64, 64, 64000), + (64, 64, 128000), + (256, 256, 64000), + (256, 256, 128000), - # (1536, 16, 1536), - # (1536, 32, 1536), - # (1536, 64, 1536), + (1536, 16, 1536), + (1536, 32, 1536), + (1536, 64, 1536), # (1536, 128, 1536), # (4096, 16, 4096), # (4096, 32, 4096), @@ -33,9 +33,9 @@ MNK = [ # (127008, 768, 576) ] -#for M, N, K in MNK: -# matmul = lambda a, b: torch.matmul(a, b) -# configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] +for M, N, K in MNK: + matmul = lambda a, b: torch.matmul(a, b) + configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] #for M, N, K in MNK: # matmul = lambda a, b: torch.matmul(a.t(), b) # configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] @@ -175,15 +175,15 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: a = torch.rand(*a_shape).type(dtype).cuda() b = torch.rand(*b_shape).type(dtype).cuda() # triton output - print(a.size(), b.size()) - tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True) + tc = torch.empty(c_shape, device=a.device) + triton.ops.einsum(expr, a, b, tc, arrays = arrays, bench = True) # reference output if torch_fn: rc = torch_fn(a, b, **arrays) else: rc = torch.einsum(expr, a, b) # performance relative to equivalent matrix multiplication - ctx = triton.ctx_registry[tc] + ctx = triton.ops._einsum.registry[tc] B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K cmp_eqbmm = False if cmp_eqbmm: diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 0d9d545bc..af281ff83 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -6,6 +6,7 @@ #include #include #include "triton/runtime/function.h" +#include "triton/runtime/arg.h" #include "triton/lang/code_gen.h" #include "triton/lang/parser.h" #include "triton/lang/cpp.h" @@ -40,9 +41,10 @@ void delete_grid(size_t id) { /* Function map */ void register_fn(size_t id, - const std::string& src, - const rt::function::options_space_t& opt) { - id_fn_map[id].reset(new rt::function(src, opt)); + const std::string& src, + const rt::function::options_space_t& opt, + const std::string &cache_ref) { + id_fn_map[id].reset(new rt::function(src, opt, cache_ref)); } void delete_fn(size_t id) { @@ -64,6 +66,7 @@ size_t make_op_id() { return id_fn_map.size(); } + /* TF scalar wrapper */ size_t make_scalar_id() { size_t ret = i64scalar_map.size(); @@ -423,6 +426,37 @@ inline std::string to_torch_ty(ir::type *ty) { throw std::runtime_error("unknown type"); } +inline std::string to_torch_ty(rt::arg_type ty){ + switch(ty){ + case rt::INT1_T: return "int64_t"; + case rt::INT8_T: return "int64_t"; + case rt::INT16_T: return "int64_t"; + case rt::INT32_T: return "int64_t"; + case rt::INT64_T: return "int64_t"; + case rt::HALF_T: return "double"; + case rt::FLOAT_T: return "double"; + case rt::DOUBLE_T: return "double"; + case rt::BUFFER_T: return "torch::Tensor"; + default: return "UNKNOWN"; + } +} + +inline std::string to_c_ty(rt::arg_type ty){ + switch(ty){ + case rt::INT1_T: return "bool"; + 
case rt::INT8_T: return "int8_t"; + case rt::INT16_T: return "int16_t"; + case rt::INT32_T: return "int32_t"; + case rt::INT64_T: return "int64_t"; + case rt::HALF_T: return "half"; + case rt::FLOAT_T: return "float"; + case rt::DOUBLE_T: return "double"; + case rt::BUFFER_T: return "drv::cu_buffer"; + default: return "UNKNOWN"; + } +} + + inline std::string to_c_ty(ir::type *ty) { if(ty->is_integer_ty(1)) return "bool"; @@ -448,33 +482,30 @@ inline std::string to_c_ty(ir::type *ty) { void gen_torch_signature(std::ostringstream& oss, - ir::function* fn, - const std::string& name) { - const auto& args = fn->args(); + const std::string& name, + const std::vector& args) { std::string ret_ty = "void"; oss << ret_ty << " " << name << "("; oss << "int64_t id, "; oss << "int64_t bench, "; oss << "int64_t bench_id, "; for(size_t i = 0; i < args.size(); i++) { - ir::argument* arg = args[i]; if(i > 0) oss << ", "; - oss << to_torch_ty(arg->get_type()) << " " << arg->get_name(); + oss << to_torch_ty(args[i]) << " " << "th_arg_" << i; } oss << ")"; } void gen_torch_init_driver(std::ostringstream &oss, - const std::vector&args) { - ir::argument* tensor = nullptr; - for(ir::argument* arg: args) - if(arg->get_type()->is_pointer_ty()){ - tensor = arg; + const std::vector&args) { + // Find index of first buffer + size_t i; + for(i = 0; i < args.size(); i++) + if(args[i] == rt::BUFFER_T) break; - } oss << " // Wrap CUDA handles" << std::endl; - oss << " c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl; + oss << " c10::DeviceIndex device = th_arg_" << i << ".storage().device().index();" << std::endl; oss << " // Get stream" << std::endl; oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl; oss << " triton::driver::cu_stream stream(custream, false);" << std::endl; @@ -482,28 +513,28 @@ void gen_torch_init_driver(std::ostringstream &oss, } void gen_torch_make_handles(std::ostream &os, - const std::vector& args) { + const std::vector& args) { for(unsigned i = 0; i < args.size(); i++){ - ir::argument *arg = args[i]; - const std::string& name = arg->get_name(); - ir::type* ty = arg->get_type(); - if(!ty->is_pointer_ty()) - os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl; + rt::arg_type arg = args[i]; + const std::string th_name = "th_arg_" + std::to_string(i); + const std::string name = "arg_" + std::to_string(i); + if(arg != rt::BUFFER_T) + os << " " << to_c_ty(arg) << " " << name << " = " << th_name << ";" << std::endl; else{ - os << " CHECK_INPUT(" << name << ");" << std::endl; - os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), " - " (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl; + os << " CHECK_INPUT(" << th_name << ");" << std::endl; + os << " drv::cu_buffer " + name + "(ctx, " + th_name + ".storage().size(), " + " (CUdeviceptr)((char*)" + th_name + ".storage().data() + " + th_name + ".storage_offset() * " + th_name + ".itemsize()), false);" << std::endl; } } } -void gen_torch_make_launch_function(std::ostream &os, const std::vector& args) { +void gen_torch_make_launch_function(std::ostream &os, + const std::vector& args) { os << " std::function run = [&](){\n "; os << " (*id_fn_map.at(id))({"; for(unsigned i = 0; i < args.size() ; i++){ - ir::argument *arg = args[i]; - std::string name = "arg_" + arg->get_name(); - if(arg->get_type()->is_pointer_ty()) + std::string 
name = "arg_" + std::to_string(i); + if(args[i] == rt::BUFFER_T) name = "&" + name; if(i > 0) os << ", "; @@ -531,15 +562,7 @@ void gen_torch_ret(std::ostream &os, const std::vector& outputs) { } std::tuple make_torch_src(const std::string& src, - const runtime::function::options_space_t& opt) { - // triton-ir code-gen - ir::context ctx; - auto ir = std::shared_ptr(new ir::module("", ctx)); - make_module(src, &*ir, opt); - // function - ir::function* fn = ir->get_function_list().front(); - std::string name = fn->get_name(); + std::string> make_torch_src(const std::string& name, std::vector args) { // generate framework code std::ostringstream oss; oss << R"( @@ -563,11 +586,11 @@ extern std::map i64scalar_map; )"; - gen_torch_signature(oss, fn, name); + gen_torch_signature(oss, name, args); oss << " {" << std::endl; - gen_torch_init_driver(oss, fn->args()); - gen_torch_make_handles(oss, fn->args()); - gen_torch_make_launch_function(oss, fn->args()); + gen_torch_init_driver(oss, args); + gen_torch_make_handles(oss, args); + gen_torch_make_launch_function(oss, args); //gen_torch_ret(oss); oss << "}" << std::endl; @@ -578,6 +601,22 @@ extern std::map i64scalar_map; return {oss.str(), name}; } +/* Function signature */ +std::vector get_fn_signature(const std::string& src, + const runtime::function::options_space_t& opt) { + // triton-ir code-gen + ir::context ctx; + auto ir = std::shared_ptr(new ir::module("", ctx)); + make_module(src, &*ir, opt); + // function + ir::function* fn = ir->get_function_list().front(); + // extract signature + std::vector ret; + ir::function_type* ty = fn->get_fn_type(); + for(size_t i = 0; i < ty->get_num_params(); i++) + ret.push_back(rt::convert(ty->get_param_ty(i))); + return ret; +} typedef triton::runtime::function::options_t options_t; typedef triton::runtime::function::options_space_t options_space_t; @@ -593,6 +632,17 @@ PYBIND11_MODULE(libtriton, m) { "Creates C++ source code for a custom PyTorch op "); // bindings for triton classes + pybind11::enum_(m, "arg_type") + .value("int1", rt::INT1_T) + .value("int8", rt::INT8_T) + .value("int16", rt::INT16_T) + .value("int32", rt::INT32_T) + .value("int64", rt::INT64_T) + .value("half", rt::HALF_T) + .value("float", rt::FLOAT_T) + .value("double", rt::DOUBLE_T) + .value("buffer", rt::BUFFER_T); + pybind11::class_(m, "options") .def(pybind11::init<>()) .def("d", &options_t::D) @@ -604,6 +654,7 @@ PYBIND11_MODULE(libtriton, m) { .def_readwrite("num_warps", &options_space_t::num_warps); // hooks into triton constructs since frameworks may not use pybind11 + m.def("get_fn_signature", &get_fn_signature); m.def("register_grid", ®ister_grid); m.def("delete_grid", &delete_grid); m.def("register_fn", ®ister_fn); diff --git a/python/triton/__init__.py b/python/triton/__init__.py index cb4097e72..388484039 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,5 +1,4 @@ from .kernel import * -from .function import * from .utils import * import triton.ops diff --git a/python/triton/function.py b/python/triton/function.py deleted file mode 100644 index 2c9cb60fe..000000000 --- a/python/triton/function.py +++ /dev/null @@ -1,127 +0,0 @@ -import triton.frameworks as fw -import triton.utils as utils - -class OpContext(object): - - def __init__(self): - self.to_save = [] - - def save_for_backward(self, *tensors): - self.to_save = [x.to_tensor() if isinstance(x, utils.tf_empty_proxy) else x - for x in tensors] - - @property - def saved_tensors(self): - return self.to_save - -class function_meta(type): - - 
def __init__(cls, name, bases, attrs): - cls.registered = False - return super(function_meta, cls).__init__(name, bases, attrs) - -ctx_registry = utils.id_dict() - -class function(metaclass = function_meta): - - @staticmethod - def forward(ctx, *args, **kwargs): - raise NotImplementedError - - @staticmethod - def backward(ctx, grad_output): - raise NotImplementedError - - @classmethod - def apply_torch(cls, *args, **kwargs): - class TorchFunction(fw.torch.autograd.Function): - @staticmethod - def forward(ctx, *targs): - y = cls.forward(ctx, *targs, **cls.torch_kwargs) - ctx_registry[y] = ctx - return y - @staticmethod - def backward(ctx, grad_output): - return cls.backward(ctx, grad_output) - cls.torch_kwargs = kwargs - return TorchFunction.apply(*args) - torch_kwargs = 0 - - @classmethod - def extract_tf_tensors(cls, lst, err): - ret = [] - for x in lst: - if x is None: - ret += [None] - elif isinstance(x, fw.tensorflow.Tensor): - ret += [x] - elif isinstance(x, utils.tf_empty_proxy): - if x.tensor is None: - raise ValueError('Empty tensor never filled during ' + err) - else: - ret += [x.tensor] - else: - raise ValueError('Unsupported return type', type(x)) - return ret - - @classmethod - def map_in_to_args(cls, op, args): - ret = dict() - for i, ix in enumerate(op.inputs): - for j, jx in enumerate(args): - if ix is jx: - ret[j] = i - return ret - - @classmethod - def map_res_to_out(cls, op, result): - ret = [] - for i, ix in enumerate(result): - for j, jx in enumerate(op.outputs): - if ix is jx: - ret.append(j) - return ret - - @classmethod - def apply_tensorflow(cls, *args, **kwargs): - ctx = OpContext() - - # run forward pass - result = cls.forward(ctx, *args, **kwargs) - result = result if isinstance(result, tuple) else (result, ) - result = function.extract_tf_tensors(result, 'forward') - - # Register backward pass - op = result[0].op - ctx_registry[op] = ctx - if not cls.registered: - remap_in = cls.map_in_to_args(op, args) - remap_out = cls.map_res_to_out(op, result) - @fw.tensorflow.RegisterGradient(op.op_def.name) - def gradient(op, *dy): - # Remap gradient inputs in the right order - dy = [dy[i] for i in remap_out] - dy = dy if len(dy) > 1 else dy[0] - # Execute gradient function - grad = cls.backward(ctx_registry[op], dy) - grad = function.extract_tf_tensors(grad, 'backward') - # Remap gradient in the right order - ret = [None] * len(op.inputs) - for i in range(len(grad)): - if i in remap_in: - ret[remap_in[i]] = grad[i] - # Return - return ret - cls.registered = True - - # Return tensor - return result[0] if len(result)==1 else result - - @classmethod - def apply(cls, *args, **kwargs): - if fw.has_tensorflow(): - return cls.apply_tensorflow(*args, **kwargs) - elif fw.has_torch(): - return cls.apply_torch(*args, **kwargs) - else: - assert False diff --git a/python/triton/kernel.py b/python/triton/kernel.py index ab7c3f49e..fe5e0aabd 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -17,43 +17,6 @@ import triton.frameworks as fw import triton.utils import triton._C.libtriton as libtriton -def _make_framework_src(src, grid): - if fw.has_torch: - return libtriton.make_torch_src(src, grid) - else: - assert False - -def _make_cache_path(src): - md5 = hashlib.sha1(src.encode()) - hexhash = md5.hexdigest() - home = os.path.expanduser('~') - cacheroot = os.path.join(home, '.triton', 'cache') - cachepath = os.path.join(cacheroot, str(hexhash)) - if not os.path.exists(cachepath): - os.makedirs(cachepath) - return cachepath - -def _write_bindings(src, root): - if 
fw.has_torch(): - name = 'torch' - else: - assert False - cpp = os.path.join(root, '{name}.cpp'.format(name=name)) - suffix = sysconfig.get_config_var('EXT_SUFFIX') - so = os.path.join(root, '{name}{suffix}'.format(name=name, suffix=suffix)) - recompile = False - # recompile if .so does not exist - if not os.path.exists(cpp) or not os.path.exists(so): - recompile = True - # recompile if cpp was modified after .so - elif max(cpp, so, key=os.path.getctime) == cpp: - recompile = True - # write cpp file - if recompile: - with open(cpp, 'w+') as handle: - handle.writelines(src) - # return path of cpp file - return (cpp, so) @contextlib.contextmanager def quiet(): @@ -64,7 +27,7 @@ def quiet(): finally: sys.stdout, sys.stderr = old_stdout, old_stderr -def _build(src, path): +def _build(src, path, name): ccdir = os.path.join(libtriton.__file__, os.path.pardir) ccdir = os.path.realpath(ccdir) # include directories @@ -88,7 +51,6 @@ def _build(src, path): libraries += ['torch'] abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] - name = 'torch' else: assert False # extra arguments @@ -142,30 +104,47 @@ def _cvt_to_def_str(obj): return str(obj) -def _make_framework_op(src, options): - src, name = _make_framework_src(src, options) - cache_path = _make_cache_path(src) - cpp, so = _write_bindings(src, cache_path) - _build(cpp, cache_path) - if fw.has_torch(): - fw.torch.ops.load_library(so) - return getattr(fw.torch.ops.triton, name) - else: - assert False +def _encode(arg_types): + codes = { + libtriton.arg_type.int1: 'i1', + libtriton.arg_type.int8: 'i8', + libtriton.arg_type.int32: 'i32', + libtriton.arg_type.int64: 'i64', + libtriton.arg_type.half: 'f16', + libtriton.arg_type.float: 'f32', + libtriton.arg_type.double: 'f64', + libtriton.arg_type.buffer: 'buf' + } + ret = '_'.join(map(codes.get, arg_types)) + return ret -def _make_grid(grid, args) : - scalars = [x for x in args if isinstance(x, triton.utils.scalar)] - def grid(opt): - for x in scalars: - x.set_assume_initialized() - result = grid(opt) - for x in scalars: - x.unset_assume_initialized() - return result - return grid - - -bench_registry = triton.utils.id_dict() +def _make_framework_op(arg_types): + name = _encode(arg_types) + # path of .cpp and .so file + home = os.path.expanduser('~') + root = os.path.join(home, '.triton', 'torch', name) + if not os.path.exists(root): + os.makedirs(root) + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(root, f'op{suffix}') + cpp = os.path.join(root, f'op.cpp') + # handle cached .so file + if os.path.exists(so): + tt_mtime = os.stat(os.path.realpath(libtriton.__file__)).st_mtime + so_mtime = os.stat(so).st_mtime + # can use cached if libtriton is older than the .so + if tt_mtime < so_mtime: + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name) + # create torch source code + src, _ = libtriton.make_torch_src(name, arg_types) + with open(cpp, 'w+') as handle: + handle.writelines(src) + # compile torch source code + _build(cpp, root, 'op') + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name) + class kernel: @@ -180,9 +159,8 @@ class kernel: self.cst[name] = value def __call__(self, *args, **kwargs): - ######################## - # keyword arguments + # JIT Options ######################## num_warps = kwargs['num_warps'] if 'num_warps' in kwargs else [2, 4, 8] defines = kwargs['defines'] if 'defines' in kwargs else dict() @@ -195,7 +173,6 @@ class kernel: ######################### 
# cache ######################## - # create a new framework op when defines are different key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in defines.items()]) if key not in self.fw_id.keys(): @@ -211,15 +188,16 @@ class kernel: opt = libtriton.options_space() opt.defines = macros opt.num_warps = num_warps - # create unique id for this op + # create triton function for this op op_id = libtriton.make_op_id() self.fw_id[key] = op_id - # register function - libtriton.register_fn(op_id, self.src, opt) + libtriton.register_fn(op_id, self.src, opt, os.path.realpath(libtriton.__file__)) for name, value in self.cst.items(): libtriton.register_cst(op_id, name, value) + # create pytorch hook for this op + arg_types = libtriton.get_fn_signature(self.src, opt) if self.fw_op is None: - self.fw_op = _make_framework_op(self.src, opt) + self.fw_op = _make_framework_op(arg_types) ######################## # initialize diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py index 3d415c7bb..5b1cc806d 100644 --- a/python/triton/ops/batchnorm.py +++ b/python/triton/ops/batchnorm.py @@ -1,7 +1,8 @@ import triton +import torch import math -class _batchnorm(triton.function): +class _batchnorm(torch.autograd.Function): fwd_src = """ void fwdbatchnorm(float *Y, float *M, float *V, diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 1bd318b13..6c74dc92c 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -73,7 +73,7 @@ class _einsum(torch.autograd.Function): stride_a_last, stride_b_last, stride_c_last, lut_mode_a, lut_mode_b, delta_a, delta_b, - subscripted): + subscripted, varnames): use_lut_a = True use_lut_b = True @@ -123,6 +123,8 @@ __global__ void {name}( src += "\n" for ptr in subscripted: src += f", int* {ptr}" + for name in varnames: + src += f", int {name}" src += """) { // re-order outer program ids @@ -274,6 +276,9 @@ __global__ void {name}( TYPE c[TM, TN, TB] = acc; // re-materialize ranges + pid_mn = get_program_id(0) / div_m; + pid_n = pid_mn % grid_n; + pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m); """ for axes, tile, off in zip([axes_m, axes_n, axes_b], ['TM', 'TN', 'TB'], @@ -410,11 +415,8 @@ __global__ void {name}( batch = [d for d in sym_a if d in sym_b and d in sym_c] outer = [d for d in sym_a if d not in sym_b and d in sym_c] inner = [d for d in sym_a if d in sym_b and d not in sym_c] - illegal = [d for d in sym_a if d not in sym_b and d not in sym_c] - if illegal: - raise ValueError(f"einsum labels {illegal} ({expr_a}) "\ - f"not present in {expr_b} or {expr_c}") - return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner) + variables = [d for d in sym_a if d not in sym_b and d not in sym_c] + return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner), variables def replace_subscript(expr, arrays): @@ -467,7 +469,33 @@ __global__ void {name}( locks = None kernel_cache = dict() - def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c): + @staticmethod + def _tile(M, N, B, TMs, TNs, TBs, TZs, TK): + smp = 15 + # occupancy estimation + grid = lambda TM, TN, TB, TZ: \ + triton.cdiv(M, TM)* \ + triton.cdiv(N, TN)* \ + triton.cdiv(B, TB)* \ + TZ + occupancy = lambda TM, TN, TB, TZ: \ + min(grid(TM, TN, TB, TZ), 4*smp) + # arithmetic intensity estimation + intensity = lambda TM, TN: \ + TM * TN * TK / (TM*TK + TK*TN) + # occupancy/intensity for all configurations + estimates = {(TM, TN, TB, TZ): (occupancy(TM, TN, TB, TZ), 
intensity(TM, TN)) \ + for TM in TMs \ + for TN in TNs \ + for TB in TBs \ + for TZ in TZs } + # returns configuration that maximizes occupancy subject to maximizing intensity + estimates = sorted(estimates.items(), + key=lambda item: item[1], + reverse=True) + return estimates[0][0] + + def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, arrays, mask, shape_c, varnames): # parse symbols expr_a, expr_bc = einsum.split(",") expr_b, expr_c = expr_bc.split("->") @@ -476,9 +504,13 @@ __global__ void {name}( sym_b = _einsum.parse_expr(expr_b, subscripted) sym_c = _einsum.parse_expr(expr_c, subscripted) # parse axes - axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted) - _, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted) + axes_b, axes_m, axes_k, var = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted) + _, axes_n, _, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted) axes = axes_b + axes_m + axes_n + axes_k + # unresolved symbols + unresolved = [x for x in map(str, var) if x not in varnames] + if unresolved: + raise ValueError(f'unresolved symbols: {unresolved}') # check dimensions dims_a = dict(zip(sym_a, shape_a)) dims_b = dict(zip(sym_b, shape_b)) @@ -520,7 +552,7 @@ __global__ void {name}( stride_a_last, stride_b_last, stride_c_last, lut_mode_a, lut_mode_b, delta_a, delta_b, - subscripted) + subscripted, varnames) self.kernel = cache[name] # Initialize locks if _einsum.instance.locks is None: @@ -565,19 +597,21 @@ __global__ void {name}( self.pos_a = 0 self.pos_b = 1 self.pos_c = 2 - # pre-processor macros - TM = [16] + [x for x in [32, 64, 128] if x <= M] - TN = [16] + [x for x in [32, 64, 128] if x <= N] - TB = [x for x in [1, 2, 4] if x <= B] - MAX_GZ = K // 2048 - MIN_GM = M // max(TM) - MIN_GN = N // max(TN) - MIN_GB = B // max(TB) - TZ = [x for x in [1, 2, 4, 8, 16, 32] \ - if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] - TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] - TM, TN, TB, TZ = 64, 64, 1, 1 - self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} + # user-provided variables + self.pos_vars = len(self.args) + self.varnames = varnames + self.args += [None] * len(varnames) + # tile size ranges + MAX_GZ = triton.cdiv(K, 2048) + TMs = [16] + [x for x in [32, 64, 128] if x <= M] + TNs = [16] + [x for x in [32, 64, 128] if x <= N] + TBs = [x for x in [1, 2, 4, 8] if x <= B] + TZs = [x for x in [1, 2, 4, 8, 16, 32] if x <= MAX_GZ] + # tile sizes + TM, TN, TB, TZ = _einsum.instance._tile(M, N, B, TMs, TNs, TBs, TZs, TK) + TM, TN, TB, TZ = 64, 128, 1, 1 + self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} + self.num_warps = [4] if mask: self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10) # save information on the operation @@ -589,12 +623,15 @@ __global__ void {name}( self.matmul_N = N self.matmul_K = K self.is_extended = any([not x.is_symbol for x in sym_a + sym_b]) + - def run(self, a, b, c, bench): + def run(self, a, b, c, values, bench): self.args[self.pos_a] = a self.args[self.pos_b] = b self.args[self.pos_c] = c - return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros) + for i, name in enumerate(self.varnames): + self.args[self.pos_vars + i] = values[name] + return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros, num_warps=self.num_warps) @@ -604,8 +641,9 @@ __global__ void {name}( ############################ instance_cache = dict() + registry = triton.utils.id_dict() @staticmethod - def 
forward(ctx, expr, a, b, output, mask=None, arrays=dict(), bench=False):
+    def forward(ctx, expr, a, b, output, mask, arrays, bench, values):
         # compile einsum instance
         cache = _einsum.instance_cache
         key = (expr, a.dtype,
@@ -615,10 +653,10 @@ __global__ void {name}(
             cache[key] = _einsum.instance(expr, a.dtype,
                                           a.stride(), b.stride(), output.stride(),
                                           a.shape, b.shape, arrays,
-                                          mask, output.shape)
+                                          mask, output.shape, values.keys())
         instance = cache[key]
         # run and mark as dirty output modified in-place
-        perf = instance.run(a, b, output, bench)
+        perf = instance.run(a, b, output, values, bench)
         ctx.mark_dirty(output)
         # save information in context
         ctx.is_extended = instance.is_extended
@@ -629,8 +667,9 @@ __global__ void {name}(
         ctx.matmul_M = instance.matmul_M
         ctx.matmul_N = instance.matmul_N
         ctx.matmul_K = instance.matmul_K
-        ctx.perf = perf
+        ctx.forward_ms = perf
         ctx.save_for_backward(a, b)
+        _einsum.registry[output] = ctx

         return output
@@ -662,5 +701,5 @@ __global__ void {name}(
 def einsum(expr, a, b, output, mask=None, arrays=dict(),
-           bench=False):
-    return _einsum.apply(expr, a, b, output, mask, arrays, bench)
\ No newline at end of file
+           bench=False, values=dict()):
+    return _einsum.apply(expr, a, b, output, mask, arrays, bench, values)
\ No newline at end of file
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 131083e72..67a09cfef 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -10,10 +10,10 @@ int main() {
   typedef std::tuple, bool, bool, int, int, int> config_t;
   std::vector configs;
   for(auto ord: std::vector>{{1, 0}})
-  for(auto x: std::vector>{{false, false}, {true, false}}){
+  for(auto x: std::vector>{{false, false}}){
     std::vector tmp = {
 //      config_t{ord, x[0], x[1], 512, 512, 512},
-      config_t{ord, x[0], x[1], 8192, 8192, 8192},
+      config_t{ord, x[0], x[1], 2048, 2048, 2048},
 //      config_t{ord, x[0], x[1], 127008, 768, 576},
 //      config_t{ord, x[0], x[1], 8192, 8192, 8192}
 //      config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -36,7 +36,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
      std::cout << ", " << perf << std::flush;
    std::cout << std::endl;
  }
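
Note on the on-disk cache introduced in lib/runtime/function.cc: each runtime::function hashes its source with SHA-1 and keeps one serialized kernel per compilation option under <cache prefix>/<sha1-of-source>/<options>.ptx. The prefix comes from the TRITON_CACHE_PATH environment variable and falls back to ~/.triton/cache/; a cached file is only reused when it is newer than the reference library whose path is passed as cache_ref (kernel.py passes the path of libtriton, so rebuilding libtriton invalidates the cache). The sketch below is illustrative only and not part of the patch; it assumes the variable must be set before the first kernel is registered, since that is when function::get_cache_prefix() reads it.

import os

# Hypothetical usage sketch: redirect the PTX cache added by this patch.
# function::get_cache_prefix() reads TRITON_CACHE_PATH, so exporting it before
# any kernel is registered makes the <sha1>/<options>.ptx files land here
# instead of ~/.triton/cache/.
os.environ["TRITON_CACHE_PATH"] = "/tmp/triton-cache/"

import triton  # kernels registered from here on use the directory above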
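For reference, function::caller::write() stores each compiled kernel as plain text: the kernel name on the first line, the space-separated integer values of its rt::arg_type signature on the second line, and the raw PTX module for the remainder of the file; caller::read() consumes the same layout. A minimal reader matching that format (illustrative sketch only; the path and helper name are made up):

def read_cached_caller(path):
    # Parse one <options>.ptx file as laid out by function::caller::write().
    with open(path) as f:
        name = f.readline().strip()                         # kernel name
        arg_types = [int(x) for x in f.readline().split()]  # rt::arg_type codes
        ptx = f.read()                                      # rest of file: PTX source
    return name, arg_types, ptx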
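The change to python/examples/einsum.py reflects the new einsum calling convention: the caller allocates the output tensor, triton.ops.einsum() writes into it in place, and per-call metadata is looked up through triton.ops._einsum.registry rather than the removed triton.ctx_registry. Roughly (shapes and dtype are arbitrary, for illustration only):

import torch
import triton

a = torch.rand(64, 256).cuda()
b = torch.rand(256, 128).cuda()
c = torch.empty(64, 128, device=a.device)             # output allocated by the caller
triton.ops.einsum('mk,kn->mn', a, b, c, bench=True)   # runs in place, marks c dirty
ctx = triton.ops._einsum.registry[c]                  # e.g. ctx.matmul_M, ctx.forward_ms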