[CORE] Fixed several issues that arose in the development of the

torch-blocksparse package:

* Now using warp shuffle in reductions when possible
* Various bugfixes in layout inference
* Added INFINITY, exponential and select
* Better error messages for unimplemented constructs
This commit is contained in:
Philippe Tillet
2020-03-31 18:55:31 -04:00
committed by Philippe Tillet
parent ac26fbdc1f
commit 3304629de9
33 changed files with 374 additions and 135 deletions

View File

@@ -42,8 +42,8 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
case '^': return set_ret(bld_->create_xor(lhs, rhs));
case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs));
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
case '.': return error_not_implemented();
case ',': return error_not_implemented();
case '.': return error_not_implemented(". binary operator not implemented");
case ',': return error_not_implemented(", binary operator not implemented");
case '@' : {
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
@@ -66,7 +66,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
if(!clhs || !crhs)
should_not_happen();
error_not_implemented("ellipsis between variables not implemented");
return set_ret(bld_->insert(ir::make_range::create(clhs, crhs)));
}
case '+':
@@ -97,7 +97,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else if(!sign)
return set_ret(bld_->create_udiv(lhs, rhs));
else
return should_not_happen();
return should_not_happen("/ should not encounter type not in {float, int}");
case '%':
if(flt)
return set_ret(bld_->create_frem(lhs, rhs));
@@ -113,7 +113,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else if(!sign)
return set_ret(bld_->create_icmpULT(lhs, rhs));
else
return should_not_happen();
return should_not_happen("< should not encounter type not in {float, int}");
case '>':
if(flt)
return set_ret(bld_->create_fcmpOGT(lhs, rhs));
@@ -122,7 +122,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else if(!sign)
return set_ret(bld_->create_icmpUGT(lhs, rhs));
else
return should_not_happen();
return should_not_happen("> should not encounter type not in {float, int}");
case Token::LE:
if(flt)
return set_ret(bld_->create_fcmpOLE(lhs, rhs));
@@ -131,7 +131,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else if(!sign)
return set_ret(bld_->create_icmpULE(lhs, rhs));
else
return should_not_happen();
return should_not_happen("<= should not encounter type not in {float, int}");
case Token::GE:
if(flt)
return set_ret(bld_->create_fcmpOGE(lhs, rhs));
@@ -140,7 +140,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else if(!sign)
return set_ret(bld_->create_icmpUGE(lhs, rhs));
else
return should_not_happen();
return should_not_happen(">= should not encounter type not in {float, int}");
case Token::EQ:
if(flt)
return set_ret(bld_->create_fcmpOEQ(lhs, rhs));
@@ -152,9 +152,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
else
return set_ret(bld_->create_icmpNE(lhs, rhs));
default:
error_not_implemented();
return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented");
}
error_not_implemented();
should_not_happen("");
}
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
@@ -166,7 +166,7 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
default: break;
}
should_not_happen();
error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented");
return reduce_inst::op_t();
}
@@ -176,7 +176,10 @@ ir::value* Generator::GenUnaryMinus(ir::value* arg) {
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
if(ty->is_tile_ty())
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
return bld_->create_sub(_0, arg);
if(sca_ty->is_floating_point_ty())
return bld_->create_fsub(_0, arg);
else
return bld_->create_sub(_0, arg);
}
void Generator::VisitUnaryOp(UnaryOp* unary) {
@@ -187,18 +190,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
// return
switch (unary->op_) {
case Token::PREFIX_INC: return error_not_implemented();
case Token::PREFIX_DEC: return error_not_implemented();
case Token::POSTFIX_INC: return error_not_implemented();
case Token::POSTFIX_DEC: return error_not_implemented();
case Token::ADDR: return error_not_implemented();
case Token::PREFIX_INC: return error_not_implemented("prefix increment not implemented");
case Token::PREFIX_DEC: return error_not_implemented("prefix decrement not implemented");
case Token::POSTFIX_INC: return error_not_implemented("postfix increment not implemented");
case Token::POSTFIX_DEC: return error_not_implemented("postfix decrement not implemented");
case Token::ADDR: return error_not_implemented("unary & not implemented");
case Token::DEREF: return set_ret(bld_->create_load(arg));
case Token::PLUS: return error_not_implemented();
case Token::PLUS: return error_not_implemented("unary + not implemented");
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
case '~': return error_not_implemented();
case '!': return error_not_implemented();
case '~': return error_not_implemented("unary ~ not implemented");
case '!': return error_not_implemented("unary ! not implemented");
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
case Token::REDUCE: {
int ax, tag;
UnaryOp::decodeRed(unary->info_, ax, tag);
@@ -206,9 +210,9 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
return set_ret(bld_->create_reduce(arg, op, ax));
}
default: error_not_implemented();
default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented");
}
return error_not_implemented();
return should_not_happen("");
}
void Generator::VisitTransOp(TransOp *trans) {
@@ -225,7 +229,9 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
ir::value* true_val = ret_;
VisitExpr(condOp->exprFalse_);
ir::value* false_val = ret_;
if(ir::load_inst* ld = dynamic_cast<ir::load_inst*>(true_val)) {
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
if(!false_val->get_type()->is_tile_ty())
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
cond,
false_val);
@@ -233,7 +239,8 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
ld->erase_from_parent();
return set_ret(new_ld);
}
return error_not_implemented();
return set_ret(bld_->create_select(cond, true_val, false_val));
// return error_not_implemented();
}
void Generator::VisitFuncCall(FuncCall* funcCall) {
@@ -244,7 +251,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_program_id(axis->get_value()));
else
return should_not_happen();
return should_not_happen("get_program_id argument should be constant");
}
if(name == "get_num_programs"){
VisitExpr(funcCall->Args()->at(0));
@@ -252,7 +259,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
return set_ret(bld_->create_get_num_program(axis->get_value()));
else
return should_not_happen();
return should_not_happen("get_num_programs argument should be constant");
}
if(name == "atomic_cas"){
VisitExpr(funcCall->Args()->at(0));
@@ -294,7 +301,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
ir::value* false_val = ret_;
return set_ret(bld_->create_select(cond, true_val, false_val));
}
return error_not_implemented();
return error_not_implemented("function calls not implemented");
}
void Generator::VisitObject(Object* obj) {
@@ -302,7 +309,7 @@ void Generator::VisitObject(Object* obj) {
}
void Generator::VisitEnumerator(Enumerator* enumer) {
return error_not_implemented();
return error_not_implemented("enumeration not implemented");
}
void Generator::VisitIdentifier(Identifier* ident) {
@@ -316,31 +323,36 @@ void Generator::VisitConstant(Constant* cons) {
return set_ret(ir::constant_int::get(type, cons->IVal()));
if(ctype->IsFloat() && ctype->IsReal())
return set_ret(ir::constant_fp::get(type, cons->FVal()));
return error_not_implemented();
return error_not_implemented("constant of type not in {int, float} not implemented");
}
void Generator::VisitTempVar(TempVar* tempVar) {
return error_not_implemented();
return error_not_implemented("temporary variable not implemented");
}
// Statement
void Generator::VisitDeclaration(Declaration* decl) {
auto obj = decl->obj_;
// initialize to undef
ir::type* ty = GenIRType(obj->Type(), *ctx_);
ir::value* val = ir::undef_value::get(ty);
//obj->GetAttrList()
// compute initializers
std::vector<ir::value*> inits;
for (const Initializer& init: decl->Inits()) {
VisitExpr(init.expr_);
inits.push_back(ret_);
ir::value *val = ret_;
for(const auto& attr: obj->GetAttrList())
SetIRMetadata(attr, val);
inits.push_back(val);
}
// initialize declaration
ir::type::id_t id = ty->get_type_id();
if(id == ir::type::StructTyID)
should_not_happen();
error_not_implemented("struct not implemented");
if(inits.size() > 1)
should_not_happen();
error_not_implemented("initializer list > 1 element not implemented");
if(inits.size() > 0)
val = inits[0];
assert(val->get_type() == ty);
@@ -427,20 +439,20 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
}
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
return error_not_implemented();
return error_not_implemented("jump not implemented");
}
void Generator::VisitReturnStmt(ReturnStmt* returnStmt) {
ir::value *ret;
if(returnStmt->expr_)
return error_not_implemented();
return error_not_implemented("non-void return not implemented");
else
ret = bld_->create_ret_void();
return set_ret(ret);
}
void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
return error_not_implemented();
return error_not_implemented("label not implemented");
}
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
@@ -458,7 +470,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
FuncType* type = funcDef->FuncType();
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
if(!prototype)
should_not_happen();
should_not_happen("could not parse function prototype");
ir::function *fn = mod_->get_or_insert_function(name, prototype);
std::vector<ir::argument*> args = fn->args();
size_t i = 0;
@@ -529,7 +541,7 @@ ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
for(size_t d = 0; d < padded_shapes.size(); d++){
if(dst_shapes[d] != padded_shapes[d] &&
padded_shapes[d] != 1)
should_not_happen();
should_not_happen("broadcast should not happen between these shapes");
}
// pad and broadcast
ir::value *padded = bld_->create_reshape(src, padded_shapes);
@@ -555,6 +567,9 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
bool dst_signed = false;
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty())
return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())),
bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes()));
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_si_to_fp(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
@@ -575,7 +590,7 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
return bld_->create_cast(ir::BitCast, src, dst_ty);
else{
should_not_happen();
error_not_implemented("cast type not implemented");
return nullptr;
}
}
@@ -594,7 +609,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
VisitExpr(attr.vals[0]);
auto cst = dynamic_cast<ir::constant_int*>(ret_);
if(!cst) should_not_happen();
if(!cst) should_not_happen("multipleof only works on constants");
return ir::attribute(ir::multiple_of, cst->get_value());
}
if(attr.kind == ASTNode::Attr::ALIGNED) {
@@ -608,7 +623,15 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
return ir::attribute(ir::readonly);
if(attr.kind == ASTNode::Attr::WRITEONLY)
return ir::attribute(ir::writeonly);
should_not_happen();
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
}
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
auto *i = dynamic_cast<ir::instruction*>(v);
if(!i)
return;
if(attr.kind == ASTNode::Attr::MULTIPLEOF)
i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value());
}
// Triton-IR Types
@@ -684,12 +707,12 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
}
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
error_not_implemented();
error_not_implemented("struct not implemented");
return nullptr;
}
void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) {
return error_not_implemented();
return error_not_implemented("alloc not implemented");
}
// SSA
@@ -704,7 +727,7 @@ void Generator::popScope() {
// LValue Generator
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
if(binary->op_ != Token::MASKED_DEREF)
error_not_implemented();
error_not_implemented("lvalue for binary non masked-deref not implemented");
gen_->VisitExpr(binary->lhs_);
ir::value* mask = gen_->ret_;
gen_->VisitExpr(binary->rhs_);
@@ -714,7 +737,7 @@ void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
if(unary->op_ != Token::DEREF)
should_not_happen();
error_not_implemented("lvalue for unary non deref not implemented");
gen_->VisitExpr(unary->operand_);
ir::value* addr = gen_->ret_;
ret_ = gen_->bld_->create_store(addr, rhs_);