[codegen] triton-ir code generation does not crash

This commit is contained in:
Philippe Tillet
2019-08-22 17:27:10 -07:00
parent a6ec807223
commit 87072203c1
7 changed files with 103 additions and 98 deletions

View File

@@ -78,52 +78,53 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string res =
R"(
#define TM 128
#define TN 128
#define TK 32
#define TM 128
#define TN 128
#define TK 32
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
extern int get_program_id(int);
extern int get_program_id(int);
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
int M, int N, int K,
int lda __attribute__((multiple_of(8))),
int ldb __attribute__((multiple_of(8))),
int ldc) {
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
int ryb[{TN}] = ridy * TN + 0 ... TN;
int rka[{TK}] = 0 ... TK;
int rkb[{TK}] = 0 ... TK;
float xc[{)" + XCS + R"(}] = 0;
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = )" + usea + " @ " + useb + R"( + xc;
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
int M, int N, int K,
int lda __attribute__((multiple_of(8))),
int ldb __attribute__((multiple_of(8))),
int ldc) {
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[{TM}] = ridx * TM + 0 ... TM;
int ryb[{TN}] = ridy * TN + 0 ... TN;
int rka[{TK}] = 0 ... TK;
int rkb[{TK}] = 0 ... TK;
float xc[{)" + XCS + R"(}] = 0;
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = )" + usea + " @ " + useb + R"( + xc;
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[{TM}] = ridx * TM + (0 ... TM);
int ryc[{TN}] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[{TM, TN}] = xc;
bool checkc0[{TM}] = rxc < M;
bool checkc1[{TN}] = ryc < N;
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
}
int rxc[{TM}] = ridx * TM + (0 ... TM);
int ryc[{TN}] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[{TM, TN}] = xc;
bool checkc0[{TM}] = rxc < M;
bool checkc1[{TN}] = ryc < N;
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
}
)";
return res;
}

View File

@@ -81,6 +81,8 @@ public:
protected:
// Triton-IR values
// Assignment of `rhs` to `lvalue`; the definition is not shown in this view.
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
// Shape half of a cast. When `dst_ty` is a tile type, `src` is splatted
// (non-tile source), reshaped (number of dimensions differs) or broadcast
// (same number of dimensions) to the destination tile shapes; when `dst_ty`
// is not a tile type, `src` is returned unchanged.
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
// Element-type half of a cast. Converts between the scalar types of `src`
// and `dst_ty` (int <-> floating point, fp extension/truncation, int cast)
// and returns `src` itself when both scalar types are already identical.
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
// Full cast: GenBroadcastOp on the operand followed by GenNumcastOp, both
// with the same destination type.
ir::value* GenCastOp(ir::value* op, ir::type* type);
// Triton-IR types

View File

@@ -221,15 +221,20 @@ ArithmType* BinaryOp::Convert() {
void BinaryOp::Broadcast() {
auto lhsType = lhs_->Type()->ToTile();
auto rhsType = rhs_->Type()->ToTile();
auto eleType = type_->ScalarType();
assert(eleType);
if(!lhsType && !rhsType)
return ;
else if(lhsType && !rhsType){
type_ = lhsType;
rhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
type_ = TileType::New(lhsType->Shape(), eleType);
::Type* rtype = TileType::New(lhsType->Shape(), rhs_->Type()->ScalarType());
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
}
else if(!lhsType && rhsType){
type_ = rhsType;
lhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
type_ = TileType::New(rhsType->Shape(), eleType);
::Type* ltype = TileType::New(rhsType->Shape(), lhs_->Type()->ScalarType());
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
}
else {
auto lhsShape = lhsType->Shape();
@@ -256,12 +261,13 @@ void BinaryOp::Broadcast() {
"for operands of shape %d and %d",
i, lhsShape[i], rhsShape[i]);
}
auto eleType = lhsType->Derived();
::Type* ltype = TileType::New(retShape, lhsType->ScalarType());
::Type* rtype = TileType::New(retShape, rhsType->ScalarType());
type_ = TileType::New(retShape, eleType);
if(retShape != lhsShape)
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
lhs_ = UnaryOp::New(Token::CAST, lhs_, ltype);
if(retShape != rhsShape)
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
rhs_ = UnaryOp::New(Token::CAST, rhs_, rtype);
}
}
@@ -347,18 +353,6 @@ void BinaryOp::CommaOpTypeChecking() {
void BinaryOp::SubScriptingOpTypeChecking() {
assert(false);
auto lhsType = lhs_->Type()->ToTile();
if (!lhsType) {
Error(this, "operator [] can only be used on tiles");
}
if (!rhs_->Type()->IsInteger()) {
Error(this, "the operand of [] should be integer");
}
// The type of [] operator is the derived type
type_ = lhsType->Derived();
}
@@ -401,7 +395,6 @@ void BinaryOp::AdditiveOpTypeChecking() {
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsPtrType = lhsScalType->ToPointer();
auto rhsPtrType = rhsScalType->ToPointer();
std::cout << "adding" << std::endl;
if (lhsPtrType) {
if (op_ == '-') {
if (rhsPtrType) {
@@ -436,7 +429,6 @@ void BinaryOp::AdditiveOpTypeChecking() {
}
void BinaryOp::RangeOpTypeChecking() {
std::cout << "range" << std::endl;
auto lhsType = lhs_->Type()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm();
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
@@ -850,7 +842,6 @@ Declaration* Declaration::New(Object* obj) {
void Declaration::AddInit(Initializer init) {
init.expr_ = Expr::MayCast(init.expr_, init.type_);
auto res = inits_.insert(init);
if (!res.second) {
inits_.erase(res.first);

View File

@@ -18,6 +18,7 @@ inline bool is_terminator(ir::value* x) {
// Expression
void Generator::VisitBinaryOp(BinaryOp* binary) {
Visit(binary->rhs_);
ir::value* rhs = ret_;
@@ -43,6 +44,17 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
case '.': return error_not_implemented();
case ',': return error_not_implemented();
case '@' : {
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
ir::value* _0;
if(ret_scal_ty->is_float_ty())
_0 = ir::constant_fp::get(ret_scal_ty, 0);
else
_0 = ir::constant_int::get(ret_scal_ty, 0);
_0 = bld_->create_splat(_0, ret_ty->get_tile_shapes());
return set_ret(bld_->create_dot(lhs, rhs, _0));
}
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
@@ -51,8 +63,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
return set_ret(ir::constant_range::get(clhs, crhs));
}
case '+':
if(binary->lhs_->Type()->ToPointer())
if(binary->lhs_->Type()->ScalarType()->ToPointer()){
return set_ret(bld_->create_gep(lhs, {rhs}));
}
else if(flt)
return set_ret(bld_->create_fadd(lhs, rhs));
else
@@ -138,10 +151,11 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
}
void Generator::VisitUnaryOp(UnaryOp* unary) {
// recursion
Visit(unary->operand_);
ir::value* op = ret_;
ir::type* type = GenIRType(unary->operand_->Type(), *ctx_);
// return
switch (unary->op_) {
case Token::PREFIX_INC: return error_not_implemented();
@@ -149,13 +163,14 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
case Token::POSTFIX_INC: return error_not_implemented();
case Token::POSTFIX_DEC: return error_not_implemented();
case Token::ADDR: return error_not_implemented();
case Token::DEREF: return error_not_implemented();
case Token::DEREF: return set_ret(bld_->create_load(op));
case Token::PLUS: return error_not_implemented();
case Token::MINUS: return error_not_implemented();
case '~': return set_ret(bld_->create_neg(op));
case '!': return set_ret(bld_->create_not(op));
case Token::CAST: return set_ret(GenCastOp(op, type));
default: assert(false);
case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
case '^': return set_ret(bld_->create_trans(op));
default: error_not_implemented();
}
return error_not_implemented();
}
@@ -176,7 +191,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
}
void Generator::VisitObject(Object* obj) {
return error_not_implemented();
return set_ret(mod_->get_value(obj->Name()));
}
void Generator::VisitEnumerator(Enumerator* enumer) {
@@ -220,14 +235,6 @@ void Generator::VisitDeclaration(Declaration* decl) {
if(inits.size() > 1)
assert(false);
val = inits[0];
std::cout << obj->Name() << " " << val->get_type()->get_type_id() << " " << ty->get_type_id() << std::endl;
if(val->get_type()->is_tile_ty() && ty->is_tile_ty()) {
for(auto s: val->get_type()->get_tile_shapes())
std::cout << s->get_value() << std::endl;
std::cout << "---" << std::endl;
for(auto s: ty->get_tile_shapes())
std::cout << s->get_value() << std::endl;
}
assert(val->get_type() == ty);
// update scope symbols table
const std::string &name = obj->Name();
@@ -351,6 +358,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
args[i]->set_name(name);
mod_->set_value(name, nullptr, args[i]);
mod_->get_scope().types[name] = args[i]->get_type();
i++;
}
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
mod_->seal_block(entry);
@@ -378,60 +386,58 @@ void Generator::Gen(ir::module *mod) {
}
// Triton-IR Values
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
if(dst_ty->is_tile_ty()) {
ir::type *src_ty = src->get_type();
auto dst_shapes = dst_ty->get_tile_shapes();
if(!src->get_type()->is_tile_ty())
if(!src_ty->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src->get_type()->get_tile_shapes();
auto src_shapes = src_ty->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size())
return bld_->create_reshape(src, dst_shapes);
else
return bld_->create_broadcast(src, dst_shapes);
}
return src;
}
ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
bool src_signed = false;
bool dst_signed = false;
if(src->get_type()->is_tile_ty())
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
bool src_signed = false;
bool dst_signed = false;
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_si_to_fp(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_ui_to_fp(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
return bld_->create_fp_to_si(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
return bld_->create_fp_to_ui(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_ext(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_trunc(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return bld_->create_int_cast(src, dst_ty, dst_signed);
else{
should_not_happen();
return nullptr;
}
}
// Lowers a cast in two steps: `src` is first passed through GenBroadcastOp
// with `dst_ty`, and that result is then handed to GenNumcastOp with the
// same `dst_ty`. The value returned by GenNumcastOp is the result.
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
return GenNumcastOp(GenBroadcastOp(src, dst_ty), dst_ty);
}
// Triton-IR Types
ir::type* Generator::GenIRType(::Type* type, ir::context& ctx) {
if(auto T = type->ToVoid())
@@ -504,7 +510,7 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
}
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
assert(false);
error_not_implemented();
return nullptr;
}
@@ -535,12 +541,15 @@ void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
}
void LValAssigner::VisitObject(Object* obj) {
error_not_implemented();
std::string name = obj->Name();
gen_->mod_->set_value(name, rhs_);
ret_ = rhs_;
}
// Assignment to an identifier: registers `rhs_` under the identifier's name
// through the generator's module (`set_value`) and exposes `rhs_` as this
// visitor's result via `ret_`.
void LValAssigner::VisitIdentifier(Identifier* ident) {
std::string name = ident->Name();
// NOTE(review): presumably rebinds the current value for `name` in the
// module's symbol tracking -- confirm against ir::module::set_value.
gen_->mod_->set_value(name, rhs_);
ret_ = rhs_;
}

View File

@@ -461,8 +461,8 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
Error(tok, "only ':' and newaxis are supported in subscripts");
}while(ts_.Try(','));
ts_.Expect(']');
// if(lhsShape.size() > i)
// Error(tok, "broadcasting not using all operand axes");
if(lhsShape.size() > i)
Error(tok, "broadcasting not using all operand axes");
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
@@ -1919,6 +1919,7 @@ void Parser::ParseInitializer(Declaration* decl,
ts_.Expect('=');
}
// std::cout << "parsing initialized " << decl->Obj()->Name() << std::endl;
Expr* expr;
auto arrType = type->ToArray();
auto structType = type->ToStruct();

View File

@@ -318,7 +318,7 @@ bool ArrayType::Compatible(const Type& other) const {
bool TileType::Compatible(const Type& other) const {
// For two tile types to be compatible,
// the element types must be compatible
// and they must have compatible shapes
// and they must have the same shape
auto otherTile = other.ToTile();
if(!otherTile)
return false;

View File

@@ -120,6 +120,7 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
// module
triton::lang::translation_unit *function::make_ast(const char *csrc) {
std::string src(csrc);
std::cout << src << std::endl;
Preprocessor cpp(&src, true);
// for (auto& def: defines)
// DefineMacro(cpp, def);