[codegen] triton-ir code generation does not crash
This commit is contained in:
@@ -78,52 +78,53 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
#define TM 128
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
#define TM 128
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
extern int get_program_id(int);
|
||||
extern int get_program_id(int);
|
||||
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
||||
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
||||
int M, int N, int K,
|
||||
int lda __attribute__((multiple_of(8))),
|
||||
int ldb __attribute__((multiple_of(8))),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
|
||||
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
|
||||
int M, int N, int K,
|
||||
int lda __attribute__((multiple_of(8))),
|
||||
int ldb __attribute__((multiple_of(8))),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM}] = ridx * TM + 0 ... TM;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
}
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@@ -81,6 +81,8 @@ public:
|
||||
protected:
|
||||
// Triton-IR values
|
||||
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
|
||||
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
|
||||
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
|
||||
ir::value* GenCastOp(ir::value* op, ir::type* type);
|
||||
|
||||
// Triton-IR types
|
||||
|
@@ -221,15 +221,20 @@ ArithmType* BinaryOp::Convert() {
|
||||
void BinaryOp::Broadcast() {
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
auto rhsType = rhs_->Type()->ToTile();
|
||||
auto eleType = type_->ScalarType();
|
||||
assert(eleType);
|
||||
if(!lhsType && !rhsType)
|
||||
return ;
|
||||
else if(lhsType && !rhsType){
|
||||
type_ = lhsType;
|
||||
rhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
||||
type_ = TileType::New(lhsType->Shape(), eleType);
|
||||
::Type* rtype = TileType::New(lhsType->Shape(), rhs_->Type()->ScalarType());
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
|
||||
}
|
||||
else if(!lhsType && rhsType){
|
||||
type_ = rhsType;
|
||||
lhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
||||
type_ = TileType::New(rhsType->Shape(), eleType);
|
||||
::Type* ltype = TileType::New(rhsType->Shape(), lhs_->Type()->ScalarType());
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
|
||||
|
||||
}
|
||||
else {
|
||||
auto lhsShape = lhsType->Shape();
|
||||
@@ -256,12 +261,13 @@ void BinaryOp::Broadcast() {
|
||||
"for operands of shape %d and %d",
|
||||
i, lhsShape[i], rhsShape[i]);
|
||||
}
|
||||
auto eleType = lhsType->Derived();
|
||||
::Type* ltype = TileType::New(retShape, lhsType->ScalarType());
|
||||
::Type* rtype = TileType::New(retShape, rhsType->ScalarType());
|
||||
type_ = TileType::New(retShape, eleType);
|
||||
if(retShape != lhsShape)
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
|
||||
if(retShape != rhsShape)
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,18 +353,6 @@ void BinaryOp::CommaOpTypeChecking() {
|
||||
|
||||
void BinaryOp::SubScriptingOpTypeChecking() {
|
||||
assert(false);
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
|
||||
if (!lhsType) {
|
||||
Error(this, "operator [] can only be used on tiles");
|
||||
}
|
||||
|
||||
if (!rhs_->Type()->IsInteger()) {
|
||||
Error(this, "the operand of [] should be integer");
|
||||
}
|
||||
|
||||
// The type of [] operator is the derived type
|
||||
type_ = lhsType->Derived();
|
||||
}
|
||||
|
||||
|
||||
@@ -401,7 +395,6 @@ void BinaryOp::AdditiveOpTypeChecking() {
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
auto lhsPtrType = lhsScalType->ToPointer();
|
||||
auto rhsPtrType = rhsScalType->ToPointer();
|
||||
std::cout << "adding" << std::endl;
|
||||
if (lhsPtrType) {
|
||||
if (op_ == '-') {
|
||||
if (rhsPtrType) {
|
||||
@@ -436,7 +429,6 @@ void BinaryOp::AdditiveOpTypeChecking() {
|
||||
}
|
||||
|
||||
void BinaryOp::RangeOpTypeChecking() {
|
||||
std::cout << "range" << std::endl;
|
||||
auto lhsType = lhs_->Type()->ToArithm();
|
||||
auto rhsType = rhs_->Type()->ToArithm();
|
||||
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
||||
@@ -850,7 +842,6 @@ Declaration* Declaration::New(Object* obj) {
|
||||
|
||||
void Declaration::AddInit(Initializer init) {
|
||||
init.expr_ = Expr::MayCast(init.expr_, init.type_);
|
||||
|
||||
auto res = inits_.insert(init);
|
||||
if (!res.second) {
|
||||
inits_.erase(res.first);
|
||||
|
@@ -18,6 +18,7 @@ inline bool is_terminator(ir::value* x) {
|
||||
// Expression
|
||||
|
||||
void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
|
||||
Visit(binary->rhs_);
|
||||
ir::value* rhs = ret_;
|
||||
|
||||
@@ -43,6 +44,17 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
||||
case '.': return error_not_implemented();
|
||||
case ',': return error_not_implemented();
|
||||
case '@' : {
|
||||
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
|
||||
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
|
||||
ir::value* _0;
|
||||
if(ret_scal_ty->is_float_ty())
|
||||
_0 = ir::constant_fp::get(ret_scal_ty, 0);
|
||||
else
|
||||
_0 = ir::constant_int::get(ret_scal_ty, 0);
|
||||
_0 = bld_->create_splat(_0, ret_ty->get_tile_shapes());
|
||||
return set_ret(bld_->create_dot(lhs, rhs, _0));
|
||||
}
|
||||
case Token::ELLIPSIS: {
|
||||
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
||||
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
||||
@@ -51,8 +63,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
return set_ret(ir::constant_range::get(clhs, crhs));
|
||||
}
|
||||
case '+':
|
||||
if(binary->lhs_->Type()->ToPointer())
|
||||
if(binary->lhs_->Type()->ScalarType()->ToPointer()){
|
||||
return set_ret(bld_->create_gep(lhs, {rhs}));
|
||||
}
|
||||
else if(flt)
|
||||
return set_ret(bld_->create_fadd(lhs, rhs));
|
||||
else
|
||||
@@ -138,10 +151,11 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
}
|
||||
|
||||
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
|
||||
// recursion
|
||||
Visit(unary->operand_);
|
||||
ir::value* op = ret_;
|
||||
ir::type* type = GenIRType(unary->operand_->Type(), *ctx_);
|
||||
|
||||
// return
|
||||
switch (unary->op_) {
|
||||
case Token::PREFIX_INC: return error_not_implemented();
|
||||
@@ -149,13 +163,14 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
case Token::POSTFIX_INC: return error_not_implemented();
|
||||
case Token::POSTFIX_DEC: return error_not_implemented();
|
||||
case Token::ADDR: return error_not_implemented();
|
||||
case Token::DEREF: return error_not_implemented();
|
||||
case Token::DEREF: return set_ret(bld_->create_load(op));
|
||||
case Token::PLUS: return error_not_implemented();
|
||||
case Token::MINUS: return error_not_implemented();
|
||||
case '~': return set_ret(bld_->create_neg(op));
|
||||
case '!': return set_ret(bld_->create_not(op));
|
||||
case Token::CAST: return set_ret(GenCastOp(op, type));
|
||||
default: assert(false);
|
||||
case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
|
||||
case '^': return set_ret(bld_->create_trans(op));
|
||||
default: error_not_implemented();
|
||||
}
|
||||
return error_not_implemented();
|
||||
}
|
||||
@@ -176,7 +191,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
}
|
||||
|
||||
void Generator::VisitObject(Object* obj) {
|
||||
return error_not_implemented();
|
||||
return set_ret(mod_->get_value(obj->Name()));
|
||||
}
|
||||
|
||||
void Generator::VisitEnumerator(Enumerator* enumer) {
|
||||
@@ -220,14 +235,6 @@ void Generator::VisitDeclaration(Declaration* decl) {
|
||||
if(inits.size() > 1)
|
||||
assert(false);
|
||||
val = inits[0];
|
||||
std::cout << obj->Name() << " " << val->get_type()->get_type_id() << " " << ty->get_type_id() << std::endl;
|
||||
if(val->get_type()->is_tile_ty() && ty->is_tile_ty()) {
|
||||
for(auto s: val->get_type()->get_tile_shapes())
|
||||
std::cout << s->get_value() << std::endl;
|
||||
std::cout << "---" << std::endl;
|
||||
for(auto s: ty->get_tile_shapes())
|
||||
std::cout << s->get_value() << std::endl;
|
||||
}
|
||||
assert(val->get_type() == ty);
|
||||
// update scope symbols table
|
||||
const std::string &name = obj->Name();
|
||||
@@ -351,6 +358,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
|
||||
args[i]->set_name(name);
|
||||
mod_->set_value(name, nullptr, args[i]);
|
||||
mod_->get_scope().types[name] = args[i]->get_type();
|
||||
i++;
|
||||
}
|
||||
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
|
||||
mod_->seal_block(entry);
|
||||
@@ -378,60 +386,58 @@ void Generator::Gen(ir::module *mod) {
|
||||
}
|
||||
|
||||
|
||||
// Triton-IR Values
|
||||
|
||||
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
|
||||
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
|
||||
if(dst_ty->is_tile_ty()) {
|
||||
ir::type *src_ty = src->get_type();
|
||||
auto dst_shapes = dst_ty->get_tile_shapes();
|
||||
if(!src->get_type()->is_tile_ty())
|
||||
if(!src_ty->is_tile_ty())
|
||||
return bld_->create_splat(src, dst_shapes);
|
||||
auto src_shapes = src->get_type()->get_tile_shapes();
|
||||
auto src_shapes = src_ty->get_tile_shapes();
|
||||
if(src_shapes.size() != dst_shapes.size())
|
||||
return bld_->create_reshape(src, dst_shapes);
|
||||
else
|
||||
return bld_->create_broadcast(src, dst_shapes);
|
||||
}
|
||||
return src;
|
||||
}
|
||||
|
||||
ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
|
||||
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
||||
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
||||
bool src_signed = false;
|
||||
bool dst_signed = false;
|
||||
|
||||
if(src->get_type()->is_tile_ty())
|
||||
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
|
||||
|
||||
bool src_signed = false;
|
||||
bool dst_signed = false;
|
||||
if(src_scalar_ty == dst_scalar_ty)
|
||||
return src;
|
||||
|
||||
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
return bld_->create_si_to_fp(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
return bld_->create_ui_to_fp(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
|
||||
return bld_->create_fp_to_si(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
|
||||
return bld_->create_fp_to_ui(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
|
||||
return bld_->create_fp_ext(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
|
||||
return bld_->create_fp_trunc(src, dst_ty);
|
||||
|
||||
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
|
||||
src_scalar_ty->get_integer_bitwidth())
|
||||
return bld_->create_int_cast(src, dst_ty, dst_signed);
|
||||
|
||||
else{
|
||||
should_not_happen();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
|
||||
return GenNumcastOp(GenBroadcastOp(src, dst_ty), dst_ty);
|
||||
}
|
||||
|
||||
// Triton-IR Types
|
||||
ir::type* Generator::GenIRType(::Type* type, ir::context& ctx) {
|
||||
if(auto T = type->ToVoid())
|
||||
@@ -504,7 +510,7 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
|
||||
}
|
||||
|
||||
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
|
||||
assert(false);
|
||||
error_not_implemented();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -535,12 +541,15 @@ void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
|
||||
}
|
||||
|
||||
void LValAssigner::VisitObject(Object* obj) {
|
||||
error_not_implemented();
|
||||
std::string name = obj->Name();
|
||||
gen_->mod_->set_value(name, rhs_);
|
||||
ret_ = rhs_;
|
||||
}
|
||||
|
||||
void LValAssigner::VisitIdentifier(Identifier* ident) {
|
||||
std::string name = ident->Name();
|
||||
gen_->mod_->set_value(name, rhs_);
|
||||
ret_ = rhs_;
|
||||
}
|
||||
|
||||
|
||||
|
@@ -461,8 +461,8 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
Error(tok, "only ':' and newaxis are supported in subscripts");
|
||||
}while(ts_.Try(','));
|
||||
ts_.Expect(']');
|
||||
// if(lhsShape.size() > i)
|
||||
// Error(tok, "broadcasting not using all operand axes");
|
||||
if(lhsShape.size() > i)
|
||||
Error(tok, "broadcasting not using all operand axes");
|
||||
// create ret tile
|
||||
TileType *retType = TileType::New(shape, lhsQual);
|
||||
return UnaryOp::New(Token::CAST, lhs, retType);
|
||||
@@ -1919,6 +1919,7 @@ void Parser::ParseInitializer(Declaration* decl,
|
||||
ts_.Expect('=');
|
||||
}
|
||||
|
||||
// std::cout << "parsing initialized " << decl->Obj()->Name() << std::endl;
|
||||
Expr* expr;
|
||||
auto arrType = type->ToArray();
|
||||
auto structType = type->ToStruct();
|
||||
|
@@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const {
|
||||
bool TileType::Compatible(const Type& other) const {
|
||||
// For two tile type to be compatible,
|
||||
// the element types must be compatible
|
||||
// and they must have compatible shapes
|
||||
// and they must have the same shape
|
||||
auto otherTile = other.ToTile();
|
||||
if(!otherTile)
|
||||
return false;
|
||||
|
@@ -120,6 +120,7 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
|
||||
// module
|
||||
triton::lang::translation_unit *function::make_ast(const char *csrc) {
|
||||
std::string src(csrc);
|
||||
std::cout << src << std::endl;
|
||||
Preprocessor cpp(&src, true);
|
||||
// for (auto& def: defines)
|
||||
// DefineMacro(cpp, def);
|
||||
|
Reference in New Issue
Block a user