[LANG] Added log intrinsic
This commit is contained in:
		
				
					committed by
					
						 Philippe Tillet
						Philippe Tillet
					
				
			
			
				
	
			
			
			
						parent
						
							02a6e81b88
						
					
				
				
					commit
					f152150e7d
				
			| @@ -112,6 +112,7 @@ public: | |||||||
|   void visit_downcast_inst(ir::downcast_inst*); |   void visit_downcast_inst(ir::downcast_inst*); | ||||||
|  |  | ||||||
|   void visit_exp_inst(ir::exp_inst*); |   void visit_exp_inst(ir::exp_inst*); | ||||||
|  |   void visit_log_inst(ir::log_inst*); | ||||||
|  |  | ||||||
|   void visit_get_program_id_inst(ir::get_program_id_inst*); |   void visit_get_program_id_inst(ir::get_program_id_inst*); | ||||||
|   void visit_get_num_program_inst(ir::get_num_program_inst*); |   void visit_get_num_program_inst(ir::get_num_program_inst*); | ||||||
|   | |||||||
| @@ -138,6 +138,7 @@ public: | |||||||
|   value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); |   value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); | ||||||
|   value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = ""); |   value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = ""); | ||||||
|   value *create_exp(value* arg, const std::string &name = ""); |   value *create_exp(value* arg, const std::string &name = ""); | ||||||
|  |   value *create_log(value* arg, const std::string &name = ""); | ||||||
|   value *create_dot(value *A, value *B, value *C, const std::string &name = ""); |   value *create_dot(value *A, value *B, value *C, const std::string &name = ""); | ||||||
|   value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = ""); |   value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = ""); | ||||||
|   value *create_sqrt(value *A, const std::string &name = ""); |   value *create_sqrt(value *A, const std::string &name = ""); | ||||||
|   | |||||||
| @@ -129,6 +129,7 @@ enum value_id_t: unsigned { | |||||||
|   INST_ATOMIC_ADD, |   INST_ATOMIC_ADD, | ||||||
|   // math |   // math | ||||||
|   INST_EXP, |   INST_EXP, | ||||||
|  |   INST_LOG, | ||||||
|   // array arithmetic |   // array arithmetic | ||||||
|   INST_TRANS, |   INST_TRANS, | ||||||
|   INST_REDUCE, |   INST_REDUCE, | ||||||
|   | |||||||
| @@ -623,6 +623,18 @@ public: | |||||||
|   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); |   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | class log_inst: public builtin_inst { | ||||||
|  | private: | ||||||
|  |   log_inst(value *val, const std::string &name = "", instruction *next = nullptr); | ||||||
|  |   std::string repr_impl() const { return "log"; } | ||||||
|  |   _TRITON_DEFINE_CLONE(log_inst) | ||||||
|  |   _TRITON_DEFINE_ACCEPT(log_inst) | ||||||
|  |  | ||||||
|  | public: | ||||||
|  |   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); | ||||||
|  | }; | ||||||
|  |  | ||||||
|  |  | ||||||
| class dot_inst: public builtin_inst { | class dot_inst: public builtin_inst { | ||||||
| public: | public: | ||||||
|   enum TransT { NoTrans, Trans }; |   enum TransT { NoTrans, Trans }; | ||||||
|   | |||||||
| @@ -49,6 +49,7 @@ class broadcast_inst; | |||||||
| class downcast_inst; | class downcast_inst; | ||||||
|  |  | ||||||
| class exp_inst; | class exp_inst; | ||||||
|  | class log_inst; | ||||||
|  |  | ||||||
| class get_program_id_inst; | class get_program_id_inst; | ||||||
| class get_num_program_inst; | class get_num_program_inst; | ||||||
| @@ -117,6 +118,7 @@ public: | |||||||
|   virtual void visit_masked_store_inst(masked_store_inst*) = 0; |   virtual void visit_masked_store_inst(masked_store_inst*) = 0; | ||||||
|  |  | ||||||
|   virtual void visit_exp_inst(exp_inst*) = 0; |   virtual void visit_exp_inst(exp_inst*) = 0; | ||||||
|  |   virtual void visit_log_inst(log_inst*) = 0; | ||||||
|  |  | ||||||
|   virtual void visit_reshape_inst(reshape_inst*) = 0; |   virtual void visit_reshape_inst(reshape_inst*) = 0; | ||||||
|   virtual void visit_splat_inst(splat_inst*) = 0; |   virtual void visit_splat_inst(splat_inst*) = 0; | ||||||
|   | |||||||
| @@ -167,6 +167,7 @@ public: | |||||||
|     // function keywords |     // function keywords | ||||||
|     BITCAST, |     BITCAST, | ||||||
|     EXP, |     EXP, | ||||||
|  |     LOG, | ||||||
|     // KEYWORD END |     // KEYWORD END | ||||||
|  |  | ||||||
|     IDENTIFIER, |     IDENTIFIER, | ||||||
|   | |||||||
| @@ -598,7 +598,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){ | |||||||
|   Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); |   Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); | ||||||
|   std::vector<llvm::Type*> tys = {builder_->getFloatTy()}; |   std::vector<llvm::Type*> tys = {builder_->getFloatTy()}; | ||||||
|   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); |   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); | ||||||
|   InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false); |   InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false); | ||||||
|  |  | ||||||
|  |  | ||||||
|   for_each(x, [&](indices_t idx){ |   for_each(x, [&](indices_t idx){ | ||||||
| @@ -607,6 +607,24 @@ void generator::visit_exp_inst(ir::exp_inst* x){ | |||||||
|   }); |   }); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void generator::visit_log_inst(ir::log_inst* x){ | ||||||
|  |   distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0)); | ||||||
|  | //  Function *fn = builder_->GetInsertBlock()->getParent(); | ||||||
|  | //  Module *module = fn->getParent(); | ||||||
|  | //  Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); | ||||||
|  | //  Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); | ||||||
|  |   Constant *rcplog2e = ConstantFP::get(builder_->getFloatTy(), 0.6931471805599453); | ||||||
|  |   std::vector<llvm::Type*> tys = {builder_->getFloatTy()}; | ||||||
|  |   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); | ||||||
|  |   InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   for_each(x, [&](indices_t idx){ | ||||||
|  |     Value *lg2arg = builder_->CreateCall(lg2, std::vector<llvm::Value*>{arg->get_value(idx)}); | ||||||
|  |     set_value(x, idx, builder_->CreateFMul(lg2arg, rcplog2e)); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
| void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { | void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { | ||||||
|   BasicBlock *current = builder_->GetInsertBlock(); |   BasicBlock *current = builder_->GetInsertBlock(); | ||||||
|   Module *module = current->getModule(); |   Module *module = current->getModule(); | ||||||
|   | |||||||
| @@ -315,6 +315,10 @@ value *builder::create_exp(value *arg, const std::string &name){ | |||||||
|   return insert(exp_inst::create(arg, name)); |   return insert(exp_inst::create(arg, name)); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | value *builder::create_log(value *arg, const std::string &name){ | ||||||
|  |   return insert(log_inst::create(arg, name)); | ||||||
|  | } | ||||||
|  |  | ||||||
| value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { | value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { | ||||||
|   return insert(dot_inst::create_nn(A, B, C, name)); |   return insert(dot_inst::create_nn(A, B, C, name)); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -758,6 +758,17 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction * | |||||||
|   return new exp_inst(val, name, next); |   return new exp_inst(val, name, next); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // log | ||||||
|  |  | ||||||
|  | log_inst::log_inst(value *val, const std::string &name, instruction *next) | ||||||
|  |   : builtin_inst(val->get_type(), INST_LOG, 1, name, next) { | ||||||
|  |   set_operand(0, val); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | instruction* log_inst::create(value *val, const std::string& name, instruction *next) { | ||||||
|  |   return new log_inst(val, name, next); | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| //                               intrinsic instructions | //                               intrinsic instructions | ||||||
|   | |||||||
| @@ -656,6 +656,7 @@ void UnaryOp::TypeChecking() { | |||||||
|     return ReduceOpTypeChecking(); |     return ReduceOpTypeChecking(); | ||||||
|  |  | ||||||
|   case Token::EXP: |   case Token::EXP: | ||||||
|  |   case Token::LOG: | ||||||
|     return IntrinsicOpTypeChecking(); |     return IntrinsicOpTypeChecking(); | ||||||
|  |  | ||||||
|   default: |   default: | ||||||
|   | |||||||
| @@ -203,6 +203,7 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { | |||||||
|     case Token::BITCAST:     return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); |     case Token::BITCAST:     return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); | ||||||
|     case Token::CAST:        return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); |     case Token::CAST:        return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); | ||||||
|     case Token::EXP:         return set_ret(bld_->create_exp(arg)); //FIXME cast |     case Token::EXP:         return set_ret(bld_->create_exp(arg)); //FIXME cast | ||||||
|  |     case Token::LOG:         return set_ret(bld_->create_log(arg)); | ||||||
|     case Token::REDUCE: { |     case Token::REDUCE: { | ||||||
|       int ax, tag; |       int ax, tag; | ||||||
|       UnaryOp::decodeRed(unary->info_, ax, tag); |       UnaryOp::decodeRed(unary->info_, ax, tag); | ||||||
| @@ -277,7 +278,10 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { | |||||||
|     ir::value* val = ret_; |     ir::value* val = ret_; | ||||||
|     return set_ret(bld_->create_atomic_exch(ptr, val)); |     return set_ret(bld_->create_atomic_exch(ptr, val)); | ||||||
|   } |   } | ||||||
|   if(name == "f32_atomic_add" || name == "atomic_add_64x64"){ |   if(name == "f32_atomic_add" ||  | ||||||
|  |      name == "atomic_add_32x32" || name == "atomic_add_32x64" || name == "atomic_add_32x128" || | ||||||
|  |      name == "atomic_add_64x32" || name == "atomic_add_64x64" || name == "atomic_add_64x128" || | ||||||
|  |      name == "atomic_add_128x32"|| name == "atomic_add_128x64"|| name == "atomic_add_128x128"){ | ||||||
|     VisitExpr(funcCall->Args()->at(0)); |     VisitExpr(funcCall->Args()->at(0)); | ||||||
|     ir::value* ptr = ret_; |     ir::value* ptr = ret_; | ||||||
|     VisitExpr(funcCall->Args()->at(1)); |     VisitExpr(funcCall->Args()->at(1)); | ||||||
|   | |||||||
| @@ -571,6 +571,7 @@ Expr* Parser::ParseUnaryExpr() { | |||||||
|   case Token::INC: return ParsePrefixIncDec(tok); |   case Token::INC: return ParsePrefixIncDec(tok); | ||||||
|   case Token::DEC: return ParsePrefixIncDec(tok); |   case Token::DEC: return ParsePrefixIncDec(tok); | ||||||
|   case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions |   case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions | ||||||
|  |   case Token::LOG: return ParseUnaryIntrinsicOp(tok, Token::LOG); //FIXME: merge into generic array functions | ||||||
|   case '&': return ParseUnaryOp(tok, Token::ADDR); |   case '&': return ParseUnaryOp(tok, Token::ADDR); | ||||||
|   case '*': return ParseDerefOp(tok); |   case '*': return ParseDerefOp(tok); | ||||||
|   case '+': return ParseUnaryOp(tok, Token::PLUS); |   case '+': return ParseUnaryOp(tok, Token::PLUS); | ||||||
|   | |||||||
| @@ -46,6 +46,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ { | |||||||
|   { "while", Token::WHILE }, |   { "while", Token::WHILE }, | ||||||
|   { "bitcast", Token::BITCAST }, |   { "bitcast", Token::BITCAST }, | ||||||
|   { "exp", Token::EXP }, |   { "exp", Token::EXP }, | ||||||
|  |   { "log", Token::LOG }, | ||||||
|   { "_Alignas", Token::ALIGNAS }, |   { "_Alignas", Token::ALIGNAS }, | ||||||
|   { "_Alignof", Token::ALIGNOF }, |   { "_Alignof", Token::ALIGNOF }, | ||||||
|   { "_Atomic", Token::ATOMIC }, |   { "_Atomic", Token::ATOMIC }, | ||||||
| @@ -149,6 +150,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ { | |||||||
|   { Token::WHILE, "while" }, |   { Token::WHILE, "while" }, | ||||||
|   { Token::BITCAST, "bitcast" }, |   { Token::BITCAST, "bitcast" }, | ||||||
|   { Token::EXP, "exp" }, |   { Token::EXP, "exp" }, | ||||||
|  |   { Token::LOG, "log" }, | ||||||
|   { Token::ALIGNAS, "_Alignas" }, |   { Token::ALIGNAS, "_Alignas" }, | ||||||
|   { Token::ALIGNOF, "_Alignof" }, |   { Token::ALIGNOF, "_Alignof" }, | ||||||
|   { Token::ATOMIC, "_Atomic" }, |   { Token::ATOMIC, "_Atomic" }, | ||||||
|   | |||||||
| @@ -311,7 +311,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g | |||||||
|   // fast path -- no autotuning necessary |   // fast path -- no autotuning necessary | ||||||
|   if(callers_.size() == 1) |   if(callers_.size() == 1) | ||||||
|     return &*callers_.begin()->second; |     return &*callers_.begin()->second; | ||||||
|   // TODO" copy buffer argument so that auto-tuning doesn't corrupt data |   // run auto-tuner | ||||||
|   double best_ts = INFINITY; |   double best_ts = INFINITY; | ||||||
|   caller* ret = nullptr; |   caller* ret = nullptr; | ||||||
|   for(auto &x : callers_){ |   for(auto &x : callers_){ | ||||||
| @@ -354,8 +354,15 @@ std::string function::preheader() { | |||||||
| #define EVALUATOR(a, b, _)  PASTER(a, b, _) | #define EVALUATOR(a, b, _)  PASTER(a, b, _) | ||||||
| #define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _) | #define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _) | ||||||
| extern void atomic_add_64(float*[64], float[64], bool[64]); | extern void atomic_add_64(float*[64], float[64], bool[64]); | ||||||
| extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]); | extern void atomic_add_32x32(float*[32, 32], float[32, 32], bool[32, 32]); | ||||||
|  | extern void atomic_add_32x64(float*[32, 64], float[32, 64], bool[32, 64]); | ||||||
|  | extern void atomic_add_32x128(float*[32, 128], float[32, 128], bool[32, 128]); | ||||||
|  | extern void atomic_add_64x32(float*[64, 32], float[64, 32], bool[64, 32]); | ||||||
| extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]); | extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]); | ||||||
|  | extern void atomic_add_64x128(float*[64, 128], float[64, 128], bool[64, 128]); | ||||||
|  | extern void atomic_add_128x32(float*[128, 32], float[128, 32], bool[128, 32]); | ||||||
|  | extern void atomic_add_128x64(float*[128, 64], float[128, 64], bool[128, 64]); | ||||||
|  | extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]); | ||||||
|  |  | ||||||
| extern int atomic_cas(int*, int, int); | extern int atomic_cas(int*, int, int); | ||||||
| extern int atomic_xchg(int*, int); | extern int atomic_xchg(int*, int); | ||||||
| @@ -416,11 +423,6 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_ | |||||||
|   // pre-compile kernels |   // pre-compile kernels | ||||||
|   if(callers_.empty()){ |   if(callers_.empty()){ | ||||||
|     precompile(stream, opt_); |     precompile(stream, opt_); | ||||||
|     size_t cumsum = 0; |  | ||||||
|     for(arg_type ty: callers_.begin()->second->param_tys()){ |  | ||||||
|       args_off_.push_back(cumsum); |  | ||||||
|       cumsum += size_of(ty); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|   // re-tuning key |   // re-tuning key | ||||||
|   cache_key_t key; |   cache_key_t key; | ||||||
|   | |||||||
| @@ -18,8 +18,6 @@ namespace rt = triton::runtime; | |||||||
| typedef std::pair<int, int> map_key_t; | typedef std::pair<int, int> map_key_t; | ||||||
| std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map; | std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map; | ||||||
| std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map; | std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map; | ||||||
| std::map<size_t, double> fp64scalar_map; |  | ||||||
| std::map<size_t, int64_t> i64scalar_map; |  | ||||||
|  |  | ||||||
| /* Grid utilities */ | /* Grid utilities */ | ||||||
|  |  | ||||||
| @@ -45,6 +43,7 @@ void delete_fn(const map_key_t& key) { | |||||||
|   id_fn_map.erase(key); |   id_fn_map.erase(key); | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) { | void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) { | ||||||
|   pybind11::buffer_info info = data.request(); |   pybind11::buffer_info info = data.request(); | ||||||
|   id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize); |   id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize); | ||||||
| @@ -53,7 +52,6 @@ void register_cst(const map_key_t& key, const std::string& name, pybind11::buffe | |||||||
| void cleanup() { | void cleanup() { | ||||||
|   id_grid_map.clear(); |   id_grid_map.clear(); | ||||||
|   id_fn_map.clear(); |   id_fn_map.clear(); | ||||||
|   i64scalar_map.clear(); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| size_t make_op_id() { | size_t make_op_id() { | ||||||
|   | |||||||
| @@ -81,8 +81,9 @@ class kernel: | |||||||
|       raise RuntimeError('Must provide grid for kernel launch') |       raise RuntimeError('Must provide grid for kernel launch') | ||||||
|     grid = kwargs['grid'] |     grid = kwargs['grid'] | ||||||
|     libtriton.register_grid((self.op_id, device), grid) |     libtriton.register_grid((self.op_id, device), grid) | ||||||
|  |     # re-allocate buffers for auto-tuning | ||||||
|  |     if 'autotune_buf' in kwargs: | ||||||
|  |       pass | ||||||
|     # launch |     # launch | ||||||
|     #print(self.tys) |  | ||||||
|     params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) |     params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) | ||||||
|     torch.cuda.synchronize() |  | ||||||
|     torch.ops.triton.launch_kernel(self.op_id, device, params) |     torch.ops.triton.launch_kernel(self.op_id, device, params) | ||||||
		Reference in New Issue
	
	Block a user