diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h index f523cb835..eb442de11 100644 --- a/include/triton/lang/ast.h +++ b/include/triton/lang/ast.h @@ -431,6 +431,7 @@ public: void DerefOpTypeChecking(); void ReduceOpTypeChecking(); void UnaryArithmOpTypeChecking(); + void BitcastOpTypeChecking(); void CastOpTypeChecking(); protected: diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h index a29cf268b..155706ddc 100644 --- a/include/triton/lang/code_gen.h +++ b/include/triton/lang/code_gen.h @@ -91,7 +91,8 @@ protected: ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs); ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty); ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty); - ir::value* GenCastOp(ir::value* op, ir::type* type); + ir::value* GenSemCastOp(ir::value* op, ir::type* type); + ir::value* GenBitCastOp(ir::value* src, ir::type* dst_ty); // Triton-IR types static ir::type* GenIRType(::Type* type, ir::context &ctx); diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index 1b2868849..2552f1769 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -164,6 +164,7 @@ public: ALIGNOF, // _Alignof GENERIC, // _Generic IMAGINARY, // _Imaginary + BITCAST, // KEYWORD END IDENTIFIER, diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index 1fd8b2dcb..30031c757 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -646,6 +646,9 @@ void UnaryOp::TypeChecking() { case '!': return UnaryArithmOpTypeChecking(); + case Token::BITCAST: + return BitcastOpTypeChecking(); + case Token::CAST: return CastOpTypeChecking(); @@ -722,6 +725,11 @@ void UnaryOp::UnaryArithmOpTypeChecking() { } } +void UnaryOp::BitcastOpTypeChecking() { + auto operandType = Type::MayCast(operand_->Type()); + if(type_->Width() != operandType->Width()) + Error(this, "cannot bitcast to type of different width"); +} void UnaryOp::CastOpTypeChecking() { auto operandType = Type::MayCast(operand_->Type()); diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 8bbf39081..41e8afe98 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -197,7 +197,8 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { case Token::MINUS: return set_ret(GenUnaryMinus(arg)); case '~': return error_not_implemented(); case '!': return error_not_implemented(); - case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_))); + case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_))); + case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_))); case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); @@ -579,10 +580,15 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) { } } -ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) { +ir::value* Generator::GenSemCastOp(ir::value* src, ir::type* dst_ty) { return GenNumcastOp(GenBroadcastOp(src, dst_ty), dst_ty); } +ir::value* Generator::GenBitCastOp(ir::value* src, ir::type* dst_ty) { + return bld_->create_cast(ir::BitCast, GenBroadcastOp(src, dst_ty), dst_ty); +} + + // Triton-IR Attr ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { if(attr.kind == ASTNode::Attr::MULTIPLEOF) { diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index ae37d9567..acf9167b1 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -664,6 +664,17 @@ QualType Parser::ParseTypeName() { Expr* Parser::ParseCastExpr() { auto tok = ts_.Next(); + // bitcast + if (tok->tag_ == Token::BITCAST) { + ts_.Expect('<'); + auto type = ParseTypeName(); + ts_.Expect('>'); + ts_.Expect('('); + auto operand = ParseExpr(); + ts_.Expect(')'); + return UnaryOp::New(Token::BITCAST, operand, type); + } + // semantic cast if (tok->tag_ == '(' && IsTypeName(ts_.Peek())) { auto type = ParseTypeName(); ts_.Expect(')'); diff --git a/lib/lang/token.cc b/lib/lang/token.cc index c4a95c0c4..e5d395f8b 100644 --- a/lib/lang/token.cc +++ b/lib/lang/token.cc @@ -44,6 +44,7 @@ const std::unordered_map Token::kwTypeMap_ { { "void", Token::VOID }, { "volatile", Token::VOLATILE }, { "while", Token::WHILE }, + { "bitcast", Token::BITCAST }, { "_Alignas", Token::ALIGNAS }, { "_Alignof", Token::ALIGNOF }, { "_Atomic", Token::ATOMIC }, @@ -145,6 +146,7 @@ const std::unordered_map Token::tagLexemeMap_ { { Token::VOID, "void" }, { Token::VOLATILE, "volatile" }, { Token::WHILE, "while" }, + { Token::BITCAST, "bitcast" }, { Token::ALIGNAS, "_Alignas" }, { Token::ALIGNOF, "_Alignof" }, { Token::ATOMIC, "_Atomic" }, diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 3cc5ea87b..70e4df12f 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -283,6 +283,15 @@ extern int get_num_programs(int); extern float sqrtf(float); extern int select(bool, int, int); extern char __constant__ * calloc(int); + +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +typedef char int8; +typedef short int16; +typedef int int32; +typedef long int64; )"; }