[LANG] Added log intrinsic
This commit is contained in:
committed by
Philippe Tillet
parent
02a6e81b88
commit
f152150e7d
@@ -112,6 +112,7 @@ public:
|
|||||||
void visit_downcast_inst(ir::downcast_inst*);
|
void visit_downcast_inst(ir::downcast_inst*);
|
||||||
|
|
||||||
void visit_exp_inst(ir::exp_inst*);
|
void visit_exp_inst(ir::exp_inst*);
|
||||||
|
void visit_log_inst(ir::log_inst*);
|
||||||
|
|
||||||
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
void visit_get_program_id_inst(ir::get_program_id_inst*);
|
||||||
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
void visit_get_num_program_inst(ir::get_num_program_inst*);
|
||||||
|
@@ -138,6 +138,7 @@ public:
|
|||||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||||
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
|
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
|
||||||
value *create_exp(value* arg, const std::string &name = "");
|
value *create_exp(value* arg, const std::string &name = "");
|
||||||
|
value *create_log(value* arg, const std::string &name = "");
|
||||||
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
|
||||||
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
|
||||||
value *create_sqrt(value *A, const std::string &name = "");
|
value *create_sqrt(value *A, const std::string &name = "");
|
||||||
|
@@ -129,6 +129,7 @@ enum value_id_t: unsigned {
|
|||||||
INST_ATOMIC_ADD,
|
INST_ATOMIC_ADD,
|
||||||
// math
|
// math
|
||||||
INST_EXP,
|
INST_EXP,
|
||||||
|
INST_LOG,
|
||||||
// array arithmetic
|
// array arithmetic
|
||||||
INST_TRANS,
|
INST_TRANS,
|
||||||
INST_REDUCE,
|
INST_REDUCE,
|
||||||
|
@@ -623,6 +623,18 @@ public:
|
|||||||
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class log_inst: public builtin_inst {
|
||||||
|
private:
|
||||||
|
log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
std::string repr_impl() const { return "log"; }
|
||||||
|
_TRITON_DEFINE_CLONE(log_inst)
|
||||||
|
_TRITON_DEFINE_ACCEPT(log_inst)
|
||||||
|
|
||||||
|
public:
|
||||||
|
static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
class dot_inst: public builtin_inst {
|
class dot_inst: public builtin_inst {
|
||||||
public:
|
public:
|
||||||
enum TransT { NoTrans, Trans };
|
enum TransT { NoTrans, Trans };
|
||||||
|
@@ -49,6 +49,7 @@ class broadcast_inst;
|
|||||||
class downcast_inst;
|
class downcast_inst;
|
||||||
|
|
||||||
class exp_inst;
|
class exp_inst;
|
||||||
|
class log_inst;
|
||||||
|
|
||||||
class get_program_id_inst;
|
class get_program_id_inst;
|
||||||
class get_num_program_inst;
|
class get_num_program_inst;
|
||||||
@@ -117,6 +118,7 @@ public:
|
|||||||
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
virtual void visit_masked_store_inst(masked_store_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_exp_inst(exp_inst*) = 0;
|
virtual void visit_exp_inst(exp_inst*) = 0;
|
||||||
|
virtual void visit_log_inst(log_inst*) = 0;
|
||||||
|
|
||||||
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
virtual void visit_reshape_inst(reshape_inst*) = 0;
|
||||||
virtual void visit_splat_inst(splat_inst*) = 0;
|
virtual void visit_splat_inst(splat_inst*) = 0;
|
||||||
|
@@ -167,6 +167,7 @@ public:
|
|||||||
// function keywords
|
// function keywords
|
||||||
BITCAST,
|
BITCAST,
|
||||||
EXP,
|
EXP,
|
||||||
|
LOG,
|
||||||
// KEYWORD END
|
// KEYWORD END
|
||||||
|
|
||||||
IDENTIFIER,
|
IDENTIFIER,
|
||||||
|
@@ -598,7 +598,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|||||||
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
|
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
|
||||||
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
|
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
|
||||||
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
|
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
|
||||||
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
|
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false);
|
||||||
|
|
||||||
|
|
||||||
for_each(x, [&](indices_t idx){
|
for_each(x, [&](indices_t idx){
|
||||||
@@ -607,6 +607,24 @@ void generator::visit_exp_inst(ir::exp_inst* x){
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void generator::visit_log_inst(ir::log_inst* x){
|
||||||
|
distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0));
|
||||||
|
// Function *fn = builder_->GetInsertBlock()->getParent();
|
||||||
|
// Module *module = fn->getParent();
|
||||||
|
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
|
||||||
|
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
|
||||||
|
Constant *rcplog2e = ConstantFP::get(builder_->getFloatTy(), 0.6931471805599453);
|
||||||
|
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
|
||||||
|
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
|
||||||
|
InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false);
|
||||||
|
|
||||||
|
|
||||||
|
for_each(x, [&](indices_t idx){
|
||||||
|
Value *lg2arg = builder_->CreateCall(lg2, std::vector<llvm::Value*>{arg->get_value(idx)});
|
||||||
|
set_value(x, idx, builder_->CreateFMul(lg2arg, rcplog2e));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
|
||||||
BasicBlock *current = builder_->GetInsertBlock();
|
BasicBlock *current = builder_->GetInsertBlock();
|
||||||
Module *module = current->getModule();
|
Module *module = current->getModule();
|
||||||
|
@@ -315,6 +315,10 @@ value *builder::create_exp(value *arg, const std::string &name){
|
|||||||
return insert(exp_inst::create(arg, name));
|
return insert(exp_inst::create(arg, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
value *builder::create_log(value *arg, const std::string &name){
|
||||||
|
return insert(log_inst::create(arg, name));
|
||||||
|
}
|
||||||
|
|
||||||
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
value *builder::create_dot(value *A, value *B, value *C, const std::string &name) {
|
||||||
return insert(dot_inst::create_nn(A, B, C, name));
|
return insert(dot_inst::create_nn(A, B, C, name));
|
||||||
}
|
}
|
||||||
|
@@ -758,6 +758,17 @@ instruction* exp_inst::create(value *val, const std::string& name, instruction *
|
|||||||
return new exp_inst(val, name, next);
|
return new exp_inst(val, name, next);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// log
|
||||||
|
|
||||||
|
log_inst::log_inst(value *val, const std::string &name, instruction *next)
|
||||||
|
: builtin_inst(val->get_type(), INST_LOG, 1, name, next) {
|
||||||
|
set_operand(0, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
instruction* log_inst::create(value *val, const std::string& name, instruction *next) {
|
||||||
|
return new log_inst(val, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// intrinsic instructions
|
// intrinsic instructions
|
||||||
|
@@ -656,6 +656,7 @@ void UnaryOp::TypeChecking() {
|
|||||||
return ReduceOpTypeChecking();
|
return ReduceOpTypeChecking();
|
||||||
|
|
||||||
case Token::EXP:
|
case Token::EXP:
|
||||||
|
case Token::LOG:
|
||||||
return IntrinsicOpTypeChecking();
|
return IntrinsicOpTypeChecking();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@@ -203,6 +203,7 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
|
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
|
||||||
|
case Token::LOG: return set_ret(bld_->create_log(arg));
|
||||||
case Token::REDUCE: {
|
case Token::REDUCE: {
|
||||||
int ax, tag;
|
int ax, tag;
|
||||||
UnaryOp::decodeRed(unary->info_, ax, tag);
|
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||||
@@ -277,7 +278,10 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
|||||||
ir::value* val = ret_;
|
ir::value* val = ret_;
|
||||||
return set_ret(bld_->create_atomic_exch(ptr, val));
|
return set_ret(bld_->create_atomic_exch(ptr, val));
|
||||||
}
|
}
|
||||||
if(name == "f32_atomic_add" || name == "atomic_add_64x64"){
|
if(name == "f32_atomic_add" ||
|
||||||
|
name == "atomic_add_32x32" || name == "atomic_add_32x64" || name == "atomic_add_32x128" ||
|
||||||
|
name == "atomic_add_64x32" || name == "atomic_add_64x64" || name == "atomic_add_64x128" ||
|
||||||
|
name == "atomic_add_128x32"|| name == "atomic_add_128x64"|| name == "atomic_add_128x128"){
|
||||||
VisitExpr(funcCall->Args()->at(0));
|
VisitExpr(funcCall->Args()->at(0));
|
||||||
ir::value* ptr = ret_;
|
ir::value* ptr = ret_;
|
||||||
VisitExpr(funcCall->Args()->at(1));
|
VisitExpr(funcCall->Args()->at(1));
|
||||||
|
@@ -571,6 +571,7 @@ Expr* Parser::ParseUnaryExpr() {
|
|||||||
case Token::INC: return ParsePrefixIncDec(tok);
|
case Token::INC: return ParsePrefixIncDec(tok);
|
||||||
case Token::DEC: return ParsePrefixIncDec(tok);
|
case Token::DEC: return ParsePrefixIncDec(tok);
|
||||||
case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions
|
case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions
|
||||||
|
case Token::LOG: return ParseUnaryIntrinsicOp(tok, Token::LOG); //FIXME: merge into generic array functions
|
||||||
case '&': return ParseUnaryOp(tok, Token::ADDR);
|
case '&': return ParseUnaryOp(tok, Token::ADDR);
|
||||||
case '*': return ParseDerefOp(tok);
|
case '*': return ParseDerefOp(tok);
|
||||||
case '+': return ParseUnaryOp(tok, Token::PLUS);
|
case '+': return ParseUnaryOp(tok, Token::PLUS);
|
||||||
|
@@ -46,6 +46,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
|
|||||||
{ "while", Token::WHILE },
|
{ "while", Token::WHILE },
|
||||||
{ "bitcast", Token::BITCAST },
|
{ "bitcast", Token::BITCAST },
|
||||||
{ "exp", Token::EXP },
|
{ "exp", Token::EXP },
|
||||||
|
{ "log", Token::LOG },
|
||||||
{ "_Alignas", Token::ALIGNAS },
|
{ "_Alignas", Token::ALIGNAS },
|
||||||
{ "_Alignof", Token::ALIGNOF },
|
{ "_Alignof", Token::ALIGNOF },
|
||||||
{ "_Atomic", Token::ATOMIC },
|
{ "_Atomic", Token::ATOMIC },
|
||||||
@@ -149,6 +150,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
|||||||
{ Token::WHILE, "while" },
|
{ Token::WHILE, "while" },
|
||||||
{ Token::BITCAST, "bitcast" },
|
{ Token::BITCAST, "bitcast" },
|
||||||
{ Token::EXP, "exp" },
|
{ Token::EXP, "exp" },
|
||||||
|
{ Token::LOG, "log" },
|
||||||
{ Token::ALIGNAS, "_Alignas" },
|
{ Token::ALIGNAS, "_Alignas" },
|
||||||
{ Token::ALIGNOF, "_Alignof" },
|
{ Token::ALIGNOF, "_Alignof" },
|
||||||
{ Token::ATOMIC, "_Atomic" },
|
{ Token::ATOMIC, "_Atomic" },
|
||||||
|
@@ -311,7 +311,7 @@ function::caller* function::autotune(driver::stream* stream, const grid_fn_ty& g
|
|||||||
// fast path -- no autotuning necessary
|
// fast path -- no autotuning necessary
|
||||||
if(callers_.size() == 1)
|
if(callers_.size() == 1)
|
||||||
return &*callers_.begin()->second;
|
return &*callers_.begin()->second;
|
||||||
// TODO" copy buffer argument so that auto-tuning doesn't corrupt data
|
// run auto-tuner
|
||||||
double best_ts = INFINITY;
|
double best_ts = INFINITY;
|
||||||
caller* ret = nullptr;
|
caller* ret = nullptr;
|
||||||
for(auto &x : callers_){
|
for(auto &x : callers_){
|
||||||
@@ -354,8 +354,15 @@ std::string function::preheader() {
|
|||||||
#define EVALUATOR(a, b, _) PASTER(a, b, _)
|
#define EVALUATOR(a, b, _) PASTER(a, b, _)
|
||||||
#define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
|
#define atomic_add(TM, TN) EVALUATOR(atomic_add, EVALUATOR(TM, TN, x), _)
|
||||||
extern void atomic_add_64(float*[64], float[64], bool[64]);
|
extern void atomic_add_64(float*[64], float[64], bool[64]);
|
||||||
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
|
extern void atomic_add_32x32(float*[32, 32], float[32, 32], bool[32, 32]);
|
||||||
|
extern void atomic_add_32x64(float*[32, 64], float[32, 64], bool[32, 64]);
|
||||||
|
extern void atomic_add_32x128(float*[32, 128], float[32, 128], bool[32, 128]);
|
||||||
|
extern void atomic_add_64x32(float*[64, 32], float[64, 32], bool[64, 32]);
|
||||||
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
|
extern void atomic_add_64x64(float*[64, 64], float[64, 64], bool[64, 64]);
|
||||||
|
extern void atomic_add_64x128(float*[64, 128], float[64, 128], bool[64, 128]);
|
||||||
|
extern void atomic_add_128x32(float*[128, 32], float[128, 32], bool[128, 32]);
|
||||||
|
extern void atomic_add_128x64(float*[128, 64], float[128, 64], bool[128, 64]);
|
||||||
|
extern void atomic_add_128x128(float*[128, 128], float[128, 128], bool[128, 128]);
|
||||||
|
|
||||||
extern int atomic_cas(int*, int, int);
|
extern int atomic_cas(int*, int, int);
|
||||||
extern int atomic_xchg(int*, int);
|
extern int atomic_xchg(int*, int);
|
||||||
@@ -416,11 +423,6 @@ void function::operator()(void** args, size_t args_size, const grid_fn_ty& grid_
|
|||||||
// pre-compile kernels
|
// pre-compile kernels
|
||||||
if(callers_.empty()){
|
if(callers_.empty()){
|
||||||
precompile(stream, opt_);
|
precompile(stream, opt_);
|
||||||
size_t cumsum = 0;
|
|
||||||
for(arg_type ty: callers_.begin()->second->param_tys()){
|
|
||||||
args_off_.push_back(cumsum);
|
|
||||||
cumsum += size_of(ty);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// re-tuning key
|
// re-tuning key
|
||||||
cache_key_t key;
|
cache_key_t key;
|
||||||
|
@@ -18,8 +18,6 @@ namespace rt = triton::runtime;
|
|||||||
typedef std::pair<int, int> map_key_t;
|
typedef std::pair<int, int> map_key_t;
|
||||||
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
std::map<map_key_t, std::shared_ptr<rt::function::grid_fn_ty>> id_grid_map;
|
||||||
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
std::map<map_key_t, std::shared_ptr<rt::function>> id_fn_map;
|
||||||
std::map<size_t, double> fp64scalar_map;
|
|
||||||
std::map<size_t, int64_t> i64scalar_map;
|
|
||||||
|
|
||||||
/* Grid utilities */
|
/* Grid utilities */
|
||||||
|
|
||||||
@@ -45,6 +43,7 @@ void delete_fn(const map_key_t& key) {
|
|||||||
id_fn_map.erase(key);
|
id_fn_map.erase(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
||||||
pybind11::buffer_info info = data.request();
|
pybind11::buffer_info info = data.request();
|
||||||
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
||||||
@@ -53,7 +52,6 @@ void register_cst(const map_key_t& key, const std::string& name, pybind11::buffe
|
|||||||
void cleanup() {
|
void cleanup() {
|
||||||
id_grid_map.clear();
|
id_grid_map.clear();
|
||||||
id_fn_map.clear();
|
id_fn_map.clear();
|
||||||
i64scalar_map.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t make_op_id() {
|
size_t make_op_id() {
|
||||||
|
@@ -81,8 +81,9 @@ class kernel:
|
|||||||
raise RuntimeError('Must provide grid for kernel launch')
|
raise RuntimeError('Must provide grid for kernel launch')
|
||||||
grid = kwargs['grid']
|
grid = kwargs['grid']
|
||||||
libtriton.register_grid((self.op_id, device), grid)
|
libtriton.register_grid((self.op_id, device), grid)
|
||||||
|
# re-allocate buffers for auto-tuning
|
||||||
|
if 'autotune_buf' in kwargs:
|
||||||
|
pass
|
||||||
# launch
|
# launch
|
||||||
#print(self.tys)
|
|
||||||
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
||||||
torch.cuda.synchronize()
|
|
||||||
torch.ops.triton.launch_kernel(self.op_id, device, params)
|
torch.ops.triton.launch_kernel(self.op_id, device, params)
|
Reference in New Issue
Block a user