[LANG] Added log intrinsic
This commit is contained in:
		
				
					committed by
					
						 Philippe Tillet
						Philippe Tillet
					
				
			
			
				
	
			
			
			
						parent
						
							02a6e81b88
						
					
				
				
					commit
					f152150e7d
				
			| @@ -112,6 +112,7 @@ public: | ||||
|   void visit_downcast_inst(ir::downcast_inst*); | ||||
|  | ||||
|   void visit_exp_inst(ir::exp_inst*); | ||||
|   void visit_log_inst(ir::log_inst*); | ||||
|  | ||||
|   void visit_get_program_id_inst(ir::get_program_id_inst*); | ||||
|   void visit_get_num_program_inst(ir::get_num_program_inst*); | ||||
|   | ||||
| @@ -138,6 +138,7 @@ public: | ||||
|   value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); | ||||
|   value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = ""); | ||||
|   value *create_exp(value* arg, const std::string &name = ""); | ||||
|   value *create_log(value* arg, const std::string &name = ""); | ||||
|   value *create_dot(value *A, value *B, value *C, const std::string &name = ""); | ||||
|   value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = ""); | ||||
|   value *create_sqrt(value *A, const std::string &name = ""); | ||||
|   | ||||
| @@ -129,6 +129,7 @@ enum value_id_t: unsigned { | ||||
|   INST_ATOMIC_ADD, | ||||
|   // math | ||||
|   INST_EXP, | ||||
|   INST_LOG, | ||||
|   // array arithmetic | ||||
|   INST_TRANS, | ||||
|   INST_REDUCE, | ||||
|   | ||||
| @@ -623,6 +623,18 @@ public: | ||||
|   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); | ||||
| }; | ||||
|  | ||||
| class log_inst: public builtin_inst { | ||||
| private: | ||||
|   log_inst(value *val, const std::string &name = "", instruction *next = nullptr); | ||||
|   std::string repr_impl() const { return "log"; } | ||||
|   _TRITON_DEFINE_CLONE(log_inst) | ||||
|   _TRITON_DEFINE_ACCEPT(log_inst) | ||||
|  | ||||
| public: | ||||
|   static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr); | ||||
| }; | ||||
|  | ||||
|  | ||||
| class dot_inst: public builtin_inst { | ||||
| public: | ||||
|   enum TransT { NoTrans, Trans }; | ||||
|   | ||||
| @@ -49,6 +49,7 @@ class broadcast_inst; | ||||
| class downcast_inst; | ||||
|  | ||||
| class exp_inst; | ||||
| class log_inst; | ||||
|  | ||||
| class get_program_id_inst; | ||||
| class get_num_program_inst; | ||||
| @@ -117,6 +118,7 @@ public: | ||||
|   virtual void visit_masked_store_inst(masked_store_inst*) = 0; | ||||
|  | ||||
|   virtual void visit_exp_inst(exp_inst*) = 0; | ||||
|   virtual void visit_log_inst(log_inst*) = 0; | ||||
|  | ||||
|   virtual void visit_reshape_inst(reshape_inst*) = 0; | ||||
|   virtual void visit_splat_inst(splat_inst*) = 0; | ||||
|   | ||||
| @@ -167,6 +167,7 @@ public: | ||||
|     // function keywords | ||||
|     BITCAST, | ||||
|     EXP, | ||||
|     LOG, | ||||
|     // KEYWORD END | ||||
|  | ||||
|     IDENTIFIER, | ||||
|   | ||||
| @@ -598,7 +598,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){ | ||||
|   Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); | ||||
|   std::vector<llvm::Type*> tys = {builder_->getFloatTy()}; | ||||
|   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); | ||||
|   InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false); | ||||
|   InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false); | ||||
|  | ||||
|  | ||||
|   for_each(x, [&](indices_t idx){ | ||||
| @@ -607,6 +607,24 @@ void generator::visit_exp_inst(ir::exp_inst* x){ | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void generator::visit_log_inst(ir::log_inst* x){ | ||||
|   distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0)); | ||||
| //  Function *fn = builder_->GetInsertBlock()->getParent(); | ||||
| //  Module *module = fn->getParent(); | ||||
| //  Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); | ||||
| //  Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); | ||||
|   Constant *rcplog2e = ConstantFP::get(builder_->getFloatTy(), 0.6931471805599453); | ||||
|   std::vector<llvm::Type*> tys = {builder_->getFloatTy()}; | ||||
|   FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); | ||||
|   InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); | ||||
|  | ||||
|  | ||||
|   for_each(x, [&](indices_t idx){ | ||||
|     Value *lg2arg = builder_->CreateCall(lg2, std::vector<llvm::Value*>{arg->get_value(idx)}); | ||||
|     set_value(x, idx, builder_->CreateFMul(lg2arg, rcplog2e)); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { | ||||
|   BasicBlock *current = builder_->GetInsertBlock(); | ||||
|   Module *module = current->getModule(); | ||||
|   | ||||
| @@ -315,6 +315,10 @@ value *builder::create_exp(value *arg, const std::string &name){ | ||||
|   return insert(exp_inst::create(arg, name)); | ||||
| } | ||||
|  | ||||
| value *builder::create_log(value *arg, const std::string &name){ | ||||
|   return insert(log_inst::create(arg, name)); | ||||
| } | ||||
|  | ||||
| value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { | ||||
|   return insert(dot_inst::create_nn(A, B, C, name)); | ||||
| } | ||||
|   | ||||
| @@ -758,6 +758,17 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction * | ||||
|   return new exp_inst(val, name, next); | ||||
| } | ||||
|  | ||||
| // log | ||||
|  | ||||
| log_inst::log_inst(value *val, const std::string &name, instruction *next) | ||||
|   : builtin_inst(val->get_type(), INST_LOG, 1, name, next) { | ||||
|   set_operand(0, val); | ||||
| } | ||||
|  | ||||
| instruction* log_inst::create(value *val, const std::string& name, instruction *next) { | ||||
|   return new log_inst(val, name, next); | ||||
| } | ||||
|  | ||||
|  | ||||
| //===----------------------------------------------------------------------===// | ||||
| //                               intrinsic instructions | ||||
|   | ||||
| @@ -656,6 +656,7 @@ void UnaryOp::TypeChecking() { | ||||
|     return ReduceOpTypeChecking(); | ||||
|  | ||||
|   case Token::EXP: | ||||
|   case Token::LOG: | ||||
|     return IntrinsicOpTypeChecking(); | ||||
|  | ||||
|   default: | ||||
|   | ||||
| @@ -203,6 +203,7 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { | ||||
|     case Token::BITCAST:     return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); | ||||
|     case Token::CAST:        return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); | ||||
|     case Token::EXP:         return set_ret(bld_->create_exp(arg)); //FIXME cast | ||||
|     case Token::LOG:         return set_ret(bld_->create_log(arg)); | ||||
|     case Token::REDUCE: { | ||||
|       int ax, tag; | ||||
|       UnaryOp::decodeRed(unary->info_, ax, tag); | ||||
| @@ -277,7 +278,10 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { | ||||
|     ir::value* val = ret_; | ||||
|     return set_ret(bld_->create_atomic_exch(ptr, val)); | ||||
|   } | ||||
|   if(name == "f32_atomic_add" || name == "atomic_add_64x64"){ | ||||
|   if(name == "f32_atomic_add" ||  | ||||
|      name == "atomic_add_32x32" || name == "atomic_add_32x64" || name == "atomic_add_32x128" || | ||||
|      name == "atomic_add_64x32" || name == "atomic_add_64x64" || name == "atomic_add_64x128" || | ||||
|      name == "atomic_add_128x32"|| name == "atomic_add_128x64"|| name == "atomic_add_128x128"){ | ||||
|     VisitExpr(funcCall->Args()->at(0)); | ||||
|     ir::value* ptr = ret_; | ||||
|     VisitExpr(funcCall->Args()->at(1)); | ||||
|   | ||||
| @@ -571,6 +571,7 @@ Expr* Parser::ParseUnaryExpr() { | ||||
|   case Token::INC: return ParsePrefixIncDec(tok); | ||||
|   case Token::DEC: return ParsePrefixIncDec(tok); | ||||
|   case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions | ||||
|   case Token::LOG: return ParseUnaryIntrinsicOp(tok, Token::LOG); //FIXME: merge into generic array functions | ||||
|   case '&': return ParseUnaryOp(tok, Token::ADDR); | ||||
|   case '*': return ParseDerefOp(tok); | ||||
|   case '+': return ParseUnaryOp(tok, Token::PLUS); | ||||
|   | ||||
| @@ -46,6 +46,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ { | ||||
|   { "while", Token::WHILE }, | ||||
|   { "bitcast", Token::BITCAST }, | ||||
|   { "exp", Token::EXP }, | ||||
|   { "log", Token::LOG }, | ||||
|   { "_Alignas", Token::ALIGNAS }, | ||||
|   { "_Alignof", Token::ALIGNOF }, | ||||
|   { "_Atomic", Token::ATOMIC }, | ||||
| @@ -149,6 +150,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ { | ||||
|   { Token::WHILE, "while" }, | ||||
|   { Token::BITCAST, "bitcast" }, | ||||
|   { Token::EXP, "exp" }, | ||||
|   { Token::LOG, "log" }, | ||||
|   { Token::ALIGNAS, "_Alignas" }, | ||||
|   { Token::ALIGNOF, "_Alignof" }, | ||||
|   { Token::ATOMIC, "_Atomic" }, | ||||
|   | ||||
| @@ -311,7 +311,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g | ||||
|   // fast path -- no autotuning necessary | ||||
|   if(callers_.size() == 1) | ||||
|     return &*callers_.begin()->second; | ||||
|   // TODO" copy buffer argument so that auto-tuning doesn't corrupt data | ||||
|   // run auto-tuner | ||||
|   double best_ts = INFINITY; | ||||
|   caller* ret = nullptr; | ||||
|   for(auto &x : callers_){ | ||||
| @@ -354,8 +354,15 @@ std::string function::preheader() { | ||||
| #define EVALUATOR(a, b, _)  PASTER(a, b, _) | ||||
| #define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _) | ||||
| extern void atomic_add_64(float*[64], float[64], bool[64]); | ||||
| extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]); | ||||
| extern void atomic_add_32x32(float*[32, 32], float[32, 32], bool[32, 32]); | ||||
| extern void atomic_add_32x64(float*[32, 64], float[32, 64], bool[32, 64]); | ||||
| extern void atomic_add_32x128(float*[32, 128], float[32, 128], bool[32, 128]); | ||||
| extern void atomic_add_64x32(float*[64, 32], float[64, 32], bool[64, 32]); | ||||
| extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]); | ||||
| extern void atomic_add_64x128(float*[64, 128], float[64, 128], bool[64, 128]); | ||||
| extern void atomic_add_128x32(float*[128, 32], float[128, 32], bool[128, 32]); | ||||
| extern void atomic_add_128x64(float*[128, 64], float[128, 64], bool[128, 64]); | ||||
| extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]); | ||||
|  | ||||
| extern int atomic_cas(int*, int, int); | ||||
| extern int atomic_xchg(int*, int); | ||||
| @@ -416,11 +423,6 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_ | ||||
|   // pre-compile kernels | ||||
|   if(callers_.empty()){ | ||||
|     precompile(stream, opt_); | ||||
|     size_t cumsum = 0; | ||||
|     for(arg_type ty: callers_.begin()->second->param_tys()){ | ||||
|       args_off_.push_back(cumsum); | ||||
|       cumsum += size_of(ty); | ||||
|     } | ||||
|   } | ||||
|   // re-tuning key | ||||
|   cache_key_t key; | ||||
|   | ||||
| @@ -18,8 +18,6 @@ namespace rt = triton::runtime; | ||||
| typedef std::pair<int, int> map_key_t; | ||||
| std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map; | ||||
| std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map; | ||||
| std::map<size_t, double> fp64scalar_map; | ||||
| std::map<size_t, int64_t> i64scalar_map; | ||||
|  | ||||
| /* Grid utilities */ | ||||
|  | ||||
| @@ -45,6 +43,7 @@ void delete_fn(const map_key_t& key) { | ||||
|   id_fn_map.erase(key); | ||||
| } | ||||
|  | ||||
|  | ||||
| void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) { | ||||
|   pybind11::buffer_info info = data.request(); | ||||
|   id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize); | ||||
| @@ -53,7 +52,6 @@ void register_cst(const map_key_t& key, const std::string& name, pybind11::buffe | ||||
| void cleanup() { | ||||
|   id_grid_map.clear(); | ||||
|   id_fn_map.clear(); | ||||
|   i64scalar_map.clear(); | ||||
| } | ||||
|  | ||||
| size_t make_op_id() { | ||||
|   | ||||
| @@ -81,8 +81,9 @@ class kernel: | ||||
|       raise RuntimeError('Must provide grid for kernel launch') | ||||
|     grid = kwargs['grid'] | ||||
|     libtriton.register_grid((self.op_id, device), grid) | ||||
|     # re-allocate buffers for auto-tuning | ||||
|     if 'autotune_buf' in kwargs: | ||||
|       pass | ||||
|     # launch | ||||
|     #print(self.tys) | ||||
|     params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args]) | ||||
|     torch.cuda.synchronize() | ||||
|     torch.ops.triton.launch_kernel(self.op_id, device, params) | ||||
		Reference in New Issue
	
	Block a user