diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 6df877486..f13a3305b 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -112,6 +112,7 @@ public:
   void visit_downcast_inst(ir::downcast_inst*);
 
   void visit_exp_inst(ir::exp_inst*);
+  void visit_log_inst(ir::log_inst*);
 
   void visit_get_program_id_inst(ir::get_program_id_inst*);
   void visit_get_num_program_inst(ir::get_num_program_inst*);
diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index f204ee0b5..7a7ef80bc 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -138,6 +138,7 @@ public:
   value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
   value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
   value *create_exp(value* arg, const std::string &name = "");
+  value *create_log(value* arg, const std::string &name = "");
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h
index 10b017919..86e4b925f 100644
--- a/include/triton/ir/enums.h
+++ b/include/triton/ir/enums.h
@@ -129,6 +129,7 @@ enum value_id_t: unsigned {
   INST_ATOMIC_ADD,
   // math
   INST_EXP,
+  INST_LOG,
   // array arithmetic
   INST_TRANS,
   INST_REDUCE,
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index 83255d215..755fb172d 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -623,6 +623,18 @@ public:
   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
 };
 
+class log_inst: public builtin_inst {
+private:
+  log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
+  std::string repr_impl() const { return "log"; }
+  _TRITON_DEFINE_CLONE(log_inst)
+  _TRITON_DEFINE_ACCEPT(log_inst)
+
+public:
+  static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
+};
+
+
 class dot_inst: public builtin_inst {
 public:
   enum TransT { NoTrans, Trans };
diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h
index e53331af4..e612ee889 100644
--- a/include/triton/ir/visitor.h
+++ b/include/triton/ir/visitor.h
@@ -49,6 +49,7 @@ class broadcast_inst;
 class downcast_inst;
 
 class exp_inst;
+class log_inst;
 
 class get_program_id_inst;
 class get_num_program_inst;
@@ -117,6 +118,7 @@ public:
   virtual void visit_masked_store_inst(masked_store_inst*) = 0;
 
   virtual void visit_exp_inst(exp_inst*) = 0;
+  virtual void visit_log_inst(log_inst*) = 0;
 
   virtual void visit_reshape_inst(reshape_inst*) = 0;
   virtual void visit_splat_inst(splat_inst*) = 0;
diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h
index 621c555fb..e4cb22c0d 100644
--- a/include/triton/lang/token.h
+++ b/include/triton/lang/token.h
@@ -167,6 +167,7 @@ public:
   // function keywords
   BITCAST,
   EXP,
+  LOG,
   // KEYWORD END
 
   IDENTIFIER,
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 3de7b2c99..18f2dafcd 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -598,7 +598,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
   Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
   std::vector<Type*> tys = {builder_->getFloatTy()};
   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
-  InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
+  InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
 
   for_each(x, [&](indices_t idx){
@@ -607,6 +607,24 @@ void generator::visit_exp_inst(ir::exp_inst* x){
   });
 }
 
+void generator::visit_log_inst(ir::log_inst* x){
+  distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0));
+//  Function *fn = builder_->GetInsertBlock()->getParent();
+//  Module *module = fn->getParent();
+//  Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
+//  Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
+  Constant *rcplog2e = ConstantFP::get(builder_->getFloatTy(), 0.6931471805599453);
+  std::vector<Type*> tys = {builder_->getFloatTy()};
+  FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
+  InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
+
+
+  for_each(x, [&](indices_t idx){
+    Value *lg2arg = builder_->CreateCall(lg2, std::vector<Value*>{arg->get_value(idx)});
+    set_value(x, idx, builder_->CreateFMul(lg2arg, rcplog2e));
+  });
+}
+
 void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
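
The visit_log_inst lowering above relies on the identity ln(x) = log2(x) * ln(2): PTX only provides an approximate base-2 logarithm (lg2.approx.f32), so the result is multiplied by the rcplog2e constant (ln 2 = 0.6931471805599453), mirroring how visit_exp_inst pre-scales its argument by log2(e) before ex2.approx.f32. A minimal standalone C++ sketch of that identity, with the standard <cmath> functions standing in for the PTX instructions (illustrative only, not part of the patch):

    // Checks ln(x) == log2(x) * ln(2), the identity visit_log_inst relies on.
    #include <cassert>
    #include <cmath>

    int main() {
      const float ln2 = 0.6931471805599453f;  // same constant as rcplog2e above
      const float xs[] = {0.5f, 1.0f, 2.0f, 10.0f, 1000.0f};
      for (float x : xs) {
        float via_lg2 = std::log2(x) * ln2;   // what the emitted lg2 + fmul computes
        assert(std::fabs(via_lg2 - std::log(x)) <= 1e-4f);
      }
      return 0;
    }
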
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index c100f461a..eceb694d2 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -315,6 +315,10 @@ value *builder::create_exp(value *arg, const std::string &name){
   return insert(exp_inst::create(arg, name));
 }
 
+value *builder::create_log(value *arg, const std::string &name){
+  return insert(log_inst::create(arg, name));
+}
+
 value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
   return insert(dot_inst::create_nn(A, B, C, name));
 }
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index 6ede70001..828d7081e 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -758,6 +758,17 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction *
   return new exp_inst(val, name, next);
 }
 
+// log
+
+log_inst::log_inst(value *val, const std::string &name, instruction *next)
+  : builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
+  set_operand(0, val);
+}
+
+instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
+  return new log_inst(val, name, next);
+}
+
 //===----------------------------------------------------------------------===//
 //                               intrinsic instructions
diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc
index c0887574b..0dfce31c8 100644
--- a/lib/lang/ast.cc
+++ b/lib/lang/ast.cc
@@ -656,6 +656,7 @@ void UnaryOp::TypeChecking() {
     return ReduceOpTypeChecking();
 
   case Token::EXP:
+  case Token::LOG:
     return IntrinsicOpTypeChecking();
 
   default:
diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc
index 323808ac7..823062d45 100644
--- a/lib/lang/code_gen.cc
+++ b/lib/lang/code_gen.cc
@@ -203,6 +203,7 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
   case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
   case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
   case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
+  case Token::LOG: return set_ret(bld_->create_log(arg));
   case Token::REDUCE: {
     int ax, tag;
     UnaryOp::decodeRed(unary->info_, ax, tag);
@@ -277,7 +278,10 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
     ir::value* val = ret_;
     return set_ret(bld_->create_atomic_exch(ptr, val));
   }
-  if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
+  if(name == "f32_atomic_add" ||
+     name == "atomic_add_32x32" || name == "atomic_add_32x64" || name == "atomic_add_32x128" ||
+     name == "atomic_add_64x32" || name == "atomic_add_64x64" || name == "atomic_add_64x128" ||
+     name == "atomic_add_128x32"|| name == "atomic_add_128x64"|| name == "atomic_add_128x128"){
     VisitExpr(funcCall->Args()->at(0));
     ir::value* ptr = ret_;
     VisitExpr(funcCall->Args()->at(1));
diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc
index 8025ca563..adb22b405 100644
--- a/lib/lang/parser.cc
+++ b/lib/lang/parser.cc
@@ -571,6 +571,7 @@ Expr* Parser::ParseUnaryExpr() {
   case Token::INC: return ParsePrefixIncDec(tok);
   case Token::DEC: return ParsePrefixIncDec(tok);
   case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions
+  case Token::LOG: return ParseUnaryIntrinsicOp(tok, Token::LOG); //FIXME: merge into generic array functions
   case '&': return ParseUnaryOp(tok, Token::ADDR);
   case '*': return ParseDerefOp(tok);
   case '+': return ParseUnaryOp(tok, Token::PLUS);
diff --git a/lib/lang/token.cc b/lib/lang/token.cc
index d158062e8..aabbd134c 100644
--- a/lib/lang/token.cc
+++ b/lib/lang/token.cc
@@ -46,6 +46,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
   { "while", Token::WHILE },
   { "bitcast", Token::BITCAST },
   { "exp", Token::EXP },
+  { "log", Token::LOG },
   { "_Alignas", Token::ALIGNAS },
   { "_Alignof", Token::ALIGNOF },
   { "_Atomic", Token::ATOMIC },
@@ -149,6 +150,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
   { Token::WHILE, "while" },
   { Token::BITCAST, "bitcast" },
   { Token::EXP, "exp" },
+  { Token::LOG, "log" },
   { Token::ALIGNAS, "_Alignas" },
   { Token::ALIGNOF, "_Alignof" },
   { Token::ATOMIC, "_Atomic" },
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index a7f25168b..e4109b628 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -311,7 +311,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
   // fast path -- no autotuning necessary
   if(callers_.size() == 1)
     return &*callers_.begin()->second;
-  // TODO" copy buffer argument so that auto-tuning doesn't corrupt data
+  // run auto-tuner
   double best_ts = INFINITY;
   caller* ret = nullptr;
   for(auto &x : callers_){
@@ -354,8 +354,15 @@ std::string function::preheader() {
 #define EVALUATOR(a, b, _)  PASTER(a, b, _)
 #define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
 extern void atomic_add_64(float*[64], float[64], bool[64]);
-extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
+extern void atomic_add_32x32(float*[32, 32], float[32, 32], bool[32, 32]);
+extern void atomic_add_32x64(float*[32, 64], float[32, 64], bool[32, 64]);
+extern void atomic_add_32x128(float*[32, 128], float[32, 128], bool[32, 128]);
+extern void atomic_add_64x32(float*[64, 32], float[64, 32], bool[64, 32]);
 extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
+extern void atomic_add_64x128(float*[64, 128], float[64, 128], bool[64, 128]);
+extern void atomic_add_128x32(float*[128, 32], float[128, 32], bool[128, 32]);
+extern void atomic_add_128x64(float*[128, 64], float[128, 64], bool[128, 64]);
+extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
 extern int atomic_cas(int*, int, int);
 extern int atomic_xchg(int*, int);
@@ -416,11 +423,6 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_
   // pre-compile kernels
   if(callers_.empty()){
     precompile(stream, opt_);
-    size_t cumsum = 0;
-    for(arg_type ty: callers_.begin()->second->param_tys()){
-      args_off_.push_back(cumsum);
-      cumsum += size_of(ty);
-    }
   }
   // re-tuning key
   cache_key_t key;
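
The new extern declarations in preheader() above give the atomic_add(TM, TN) macro a shape-specialized symbol for every tile size the auto-tuner may pick; Generator::VisitFuncCall then dispatches on exactly these names. A small host-side C++ sketch of the intended expansion follows; the PASTER definition is an assumption (it is defined earlier in preheader(), outside the hunk shown), so treat this as illustrative only:

    #include <cassert>
    #include <cstring>

    // Assumed token-pasting helper; the real PASTER lives outside the hunk above.
    #define PASTER(a, b, _) a ## _ ## b
    #define EVALUATOR(a, b, _) PASTER(a, b, _)
    #define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)

    #define STR_IMPL(s) #s
    #define STR(s) STR_IMPL(s)

    int main() {
      // atomic_add(64, 128) in Triton-C source resolves to the
      // atomic_add_64x128 declaration added by this patch.
      assert(std::strcmp(STR(atomic_add(64, 128)), "atomic_add_64x128") == 0);
      return 0;
    }
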
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index c97cf5fc0..ac1c68fc9 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -18,8 +18,6 @@ namespace rt = triton::runtime;
 typedef std::pair map_key_t;
 std::map> id_grid_map;
 std::map> id_fn_map;
-std::map fp64scalar_map;
-std::map i64scalar_map;
 
 /* Grid utilities */
@@ -45,6 +43,7 @@ void delete_fn(const map_key_t& key) {
   id_fn_map.erase(key);
 }
 
+
 void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
   pybind11::buffer_info info = data.request();
   id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
@@ -53,7 +52,6 @@ void register_cst(const map_key_t& key, const std::string& name, pybind11::buffe
 void cleanup() {
   id_grid_map.clear();
   id_fn_map.clear();
-  i64scalar_map.clear();
 }
 
 size_t make_op_id() {
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index efd683878..a406b61c3 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -81,8 +81,9 @@ class kernel:
             raise RuntimeError('Must provide grid for kernel launch')
         grid = kwargs['grid']
         libtriton.register_grid((self.op_id, device), grid)
+        # re-allocate buffers for auto-tuning
+        if 'autotune_buf' in kwargs:
+            pass
         # launch
-        #print(self.tys)
         params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
-        torch.cuda.synchronize()
         torch.ops.triton.launch_kernel(self.op_id, device, params)