From 7d095ec6862346ecb208852e3b8b581ab19b9eb6 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 17 Nov 2020 01:26:53 -0500 Subject: [PATCH] [LANG] Added sqrtf support --- include/triton/lang/token.h | 1 + lib/lang/ast.cc | 1 + lib/lang/code_gen.cc | 6 +----- lib/lang/parser.cc | 1 + lib/lang/token.cc | 2 ++ lib/runtime/function.cc | 1 - 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index e4cb22c0d..178f8c42e 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -168,6 +168,7 @@ public: BITCAST, EXP, LOG, + SQRTF, // KEYWORD END IDENTIFIER, diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index 0dfce31c8..6aaf13408 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -657,6 +657,7 @@ void UnaryOp::TypeChecking() { case Token::EXP: case Token::LOG: + case Token::SQRTF: return IntrinsicOpTypeChecking(); default: diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index a05a7a123..2d72e1794 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -204,6 +204,7 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast case Token::LOG: return set_ret(bld_->create_log(arg)); + case Token::SQRTF: return set_ret(bld_->create_sqrt(arg)); case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); @@ -287,11 +288,6 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { ir::value* msk = ret_; return set_ret(bld_->create_atomic_add(ptr, val, msk)); } - if(name == "sqrtf"){ - VisitExpr(funcCall->Args()->at(0)); - ir::value* ret = ret_; - return set_ret(bld_->create_sqrt(ret)); - } if(name == "calloc"){ VisitExpr(funcCall->Args()->at(0)); ir::value* ret = ret_; diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index adb22b405..ca9e4bb03 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -571,6 +571,7 @@ Expr* Parser::ParseUnaryExpr() { case Token::INC: return ParsePrefixIncDec(tok); case Token::DEC: return ParsePrefixIncDec(tok); case Token::EXP: return ParseUnaryIntrinsicOp(tok, Token::EXP); //FIXME: merge into generic array functions + case Token::SQRTF: return ParseUnaryIntrinsicOp(tok, Token::SQRTF); case Token::LOG: return ParseUnaryIntrinsicOp(tok, Token::LOG); //FIXME: merge into generic array functions case '&': return ParseUnaryOp(tok, Token::ADDR); case '*': return ParseDerefOp(tok); diff --git a/lib/lang/token.cc b/lib/lang/token.cc index aabbd134c..5e9b535b6 100644 --- a/lib/lang/token.cc +++ b/lib/lang/token.cc @@ -47,6 +47,7 @@ const std::unordered_map Token::kwTypeMap_ { { "bitcast", Token::BITCAST }, { "exp", Token::EXP }, { "log", Token::LOG }, + { "sqrtf", Token::SQRTF }, { "_Alignas", Token::ALIGNAS }, { "_Alignof", Token::ALIGNOF }, { "_Atomic", Token::ATOMIC }, @@ -151,6 +152,7 @@ const std::unordered_map Token::tagLexemeMap_ { { Token::BITCAST, "bitcast" }, { Token::EXP, "exp" }, { Token::LOG, "log" }, + { Token::SQRTF, "sqrtf" }, { Token::ALIGNAS, "_Alignas" }, { Token::ALIGNOF, "_Alignof" }, { Token::ATOMIC, "_Atomic" }, diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index bd9d8cb48..579aa46f0 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -371,7 +371,6 @@ extern int atomic_xchg(int*, int); extern float f32_atomic_add(float*, float); extern int get_program_id(int); extern int get_num_programs(int); -extern float sqrtf(float); extern int select(bool, int, int); extern char __constant__ * calloc(int);