[CORE] Fixed several issues that arose in the development of the
torch-blocksparse package: * Now using warp shuffle in reductions when possible * Various bugfixes in layout inference * Added INFINITY, exponential and select * Better error messages for unimplemented constructs
This commit is contained in:
committed by
Philippe Tillet
parent
ac26fbdc1f
commit
3304629de9
@@ -42,8 +42,8 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
case '^': return set_ret(bld_->create_xor(lhs, rhs));
|
||||
case Token::LEFT: return set_ret(bld_->create_shl(lhs, rhs));
|
||||
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
||||
case '.': return error_not_implemented();
|
||||
case ',': return error_not_implemented();
|
||||
case '.': return error_not_implemented(". binary operator not implemented");
|
||||
case ',': return error_not_implemented(", binary operator not implemented");
|
||||
case '@' : {
|
||||
ir::type* ret_ty = GenIRType(binary->Type(), *ctx_);
|
||||
ir::type* ret_scal_ty = ret_ty->get_scalar_ty();
|
||||
@@ -66,7 +66,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
||||
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
||||
if(!clhs || !crhs)
|
||||
should_not_happen();
|
||||
error_not_implemented("ellipsis between variables not implemented");
|
||||
return set_ret(bld_->insert(ir::make_range::create(clhs, crhs)));
|
||||
}
|
||||
case '+':
|
||||
@@ -97,7 +97,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else if(!sign)
|
||||
return set_ret(bld_->create_udiv(lhs, rhs));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("/ should not encounter type not in {float, int}");
|
||||
case '%':
|
||||
if(flt)
|
||||
return set_ret(bld_->create_frem(lhs, rhs));
|
||||
@@ -113,7 +113,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else if(!sign)
|
||||
return set_ret(bld_->create_icmpULT(lhs, rhs));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("< should not encounter type not in {float, int}");
|
||||
case '>':
|
||||
if(flt)
|
||||
return set_ret(bld_->create_fcmpOGT(lhs, rhs));
|
||||
@@ -122,7 +122,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else if(!sign)
|
||||
return set_ret(bld_->create_icmpUGT(lhs, rhs));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("> should not encounter type not in {float, int}");
|
||||
case Token::LE:
|
||||
if(flt)
|
||||
return set_ret(bld_->create_fcmpOLE(lhs, rhs));
|
||||
@@ -131,7 +131,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else if(!sign)
|
||||
return set_ret(bld_->create_icmpULE(lhs, rhs));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("<= should not encounter type not in {float, int}");
|
||||
case Token::GE:
|
||||
if(flt)
|
||||
return set_ret(bld_->create_fcmpOGE(lhs, rhs));
|
||||
@@ -140,7 +140,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else if(!sign)
|
||||
return set_ret(bld_->create_icmpUGE(lhs, rhs));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen(">= should not encounter type not in {float, int}");
|
||||
case Token::EQ:
|
||||
if(flt)
|
||||
return set_ret(bld_->create_fcmpOEQ(lhs, rhs));
|
||||
@@ -152,9 +152,9 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
||||
else
|
||||
return set_ret(bld_->create_icmpNE(lhs, rhs));
|
||||
default:
|
||||
error_not_implemented();
|
||||
return error_not_implemented("binary operator " + std::to_string(binary->op_) + " not implemented");
|
||||
}
|
||||
error_not_implemented();
|
||||
should_not_happen("");
|
||||
}
|
||||
|
||||
ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
||||
@@ -166,7 +166,7 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
|
||||
case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
|
||||
default: break;
|
||||
}
|
||||
should_not_happen();
|
||||
error_not_implemented("reduction operator " + std::to_string(tag) + " not implemented");
|
||||
return reduce_inst::op_t();
|
||||
}
|
||||
|
||||
@@ -176,7 +176,10 @@ ir::value* Generator::GenUnaryMinus(ir::value* arg) {
|
||||
ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty);
|
||||
if(ty->is_tile_ty())
|
||||
_0 = bld_->create_splat(_0, ty->get_tile_shapes());
|
||||
return bld_->create_sub(_0, arg);
|
||||
if(sca_ty->is_floating_point_ty())
|
||||
return bld_->create_fsub(_0, arg);
|
||||
else
|
||||
return bld_->create_sub(_0, arg);
|
||||
}
|
||||
|
||||
void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
@@ -187,18 +190,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
|
||||
// return
|
||||
switch (unary->op_) {
|
||||
case Token::PREFIX_INC: return error_not_implemented();
|
||||
case Token::PREFIX_DEC: return error_not_implemented();
|
||||
case Token::POSTFIX_INC: return error_not_implemented();
|
||||
case Token::POSTFIX_DEC: return error_not_implemented();
|
||||
case Token::ADDR: return error_not_implemented();
|
||||
case Token::PREFIX_INC: return error_not_implemented("prefix increment not implemented");
|
||||
case Token::PREFIX_DEC: return error_not_implemented("prefix decrement not implemented");
|
||||
case Token::POSTFIX_INC: return error_not_implemented("postfix increment not implemented");
|
||||
case Token::POSTFIX_DEC: return error_not_implemented("postfix decrement not implemented");
|
||||
case Token::ADDR: return error_not_implemented("unary & not implemented");
|
||||
case Token::DEREF: return set_ret(bld_->create_load(arg));
|
||||
case Token::PLUS: return error_not_implemented();
|
||||
case Token::PLUS: return error_not_implemented("unary + not implemented");
|
||||
case Token::MINUS: return set_ret(GenUnaryMinus(arg));
|
||||
case '~': return error_not_implemented();
|
||||
case '!': return error_not_implemented();
|
||||
case '~': return error_not_implemented("unary ~ not implemented");
|
||||
case '!': return error_not_implemented("unary ! not implemented");
|
||||
case Token::BITCAST: return set_ret(GenBitCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||
case Token::CAST: return set_ret(GenSemCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||
case Token::EXP: return set_ret(bld_->create_exp(arg)); //FIXME cast
|
||||
case Token::REDUCE: {
|
||||
int ax, tag;
|
||||
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||
@@ -206,9 +210,9 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
||||
ir::reduce_inst::op_t op = reduce_op(tag, is_float);
|
||||
return set_ret(bld_->create_reduce(arg, op, ax));
|
||||
}
|
||||
default: error_not_implemented();
|
||||
default: error_not_implemented("unary " + std::to_string(unary->op_) + " not implemented");
|
||||
}
|
||||
return error_not_implemented();
|
||||
return should_not_happen("");
|
||||
}
|
||||
|
||||
void Generator::VisitTransOp(TransOp *trans) {
|
||||
@@ -225,7 +229,9 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||
ir::value* true_val = ret_;
|
||||
VisitExpr(condOp->exprFalse_);
|
||||
ir::value* false_val = ret_;
|
||||
if(ir::load_inst* ld = dynamic_cast<ir::load_inst*>(true_val)) {
|
||||
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
|
||||
if(!false_val->get_type()->is_tile_ty())
|
||||
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
|
||||
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
||||
cond,
|
||||
false_val);
|
||||
@@ -233,7 +239,8 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||
ld->erase_from_parent();
|
||||
return set_ret(new_ld);
|
||||
}
|
||||
return error_not_implemented();
|
||||
return set_ret(bld_->create_select(cond, true_val, false_val));
|
||||
// return error_not_implemented();
|
||||
}
|
||||
|
||||
void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
@@ -244,7 +251,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||
return set_ret(bld_->create_get_program_id(axis->get_value()));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("get_program_id argument should be constant");
|
||||
}
|
||||
if(name == "get_num_programs"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
@@ -252,7 +259,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
if(auto axis = dynamic_cast<ir::constant_int*>(ret))
|
||||
return set_ret(bld_->create_get_num_program(axis->get_value()));
|
||||
else
|
||||
return should_not_happen();
|
||||
return should_not_happen("get_num_programs argument should be constant");
|
||||
}
|
||||
if(name == "atomic_cas"){
|
||||
VisitExpr(funcCall->Args()->at(0));
|
||||
@@ -294,7 +301,7 @@ void Generator::VisitFuncCall(FuncCall* funcCall) {
|
||||
ir::value* false_val = ret_;
|
||||
return set_ret(bld_->create_select(cond, true_val, false_val));
|
||||
}
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("function calls not implemented");
|
||||
}
|
||||
|
||||
void Generator::VisitObject(Object* obj) {
|
||||
@@ -302,7 +309,7 @@ void Generator::VisitObject(Object* obj) {
|
||||
}
|
||||
|
||||
void Generator::VisitEnumerator(Enumerator* enumer) {
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("enumeration not implemented");
|
||||
}
|
||||
|
||||
void Generator::VisitIdentifier(Identifier* ident) {
|
||||
@@ -316,31 +323,36 @@ void Generator::VisitConstant(Constant* cons) {
|
||||
return set_ret(ir::constant_int::get(type, cons->IVal()));
|
||||
if(ctype->IsFloat() && ctype->IsReal())
|
||||
return set_ret(ir::constant_fp::get(type, cons->FVal()));
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("constant of type not in {int, float} not implemented");
|
||||
}
|
||||
|
||||
void Generator::VisitTempVar(TempVar* tempVar) {
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("temporary variable not implemented");
|
||||
}
|
||||
|
||||
// Statement
|
||||
void Generator::VisitDeclaration(Declaration* decl) {
|
||||
auto obj = decl->obj_;
|
||||
// initialize to undef
|
||||
|
||||
ir::type* ty = GenIRType(obj->Type(), *ctx_);
|
||||
ir::value* val = ir::undef_value::get(ty);
|
||||
//obj->GetAttrList()
|
||||
// compute initializers
|
||||
std::vector<ir::value*> inits;
|
||||
for (const Initializer& init: decl->Inits()) {
|
||||
VisitExpr(init.expr_);
|
||||
inits.push_back(ret_);
|
||||
ir::value *val = ret_;
|
||||
for(const auto& attr: obj->GetAttrList())
|
||||
SetIRMetadata(attr, val);
|
||||
inits.push_back(val);
|
||||
}
|
||||
// initialize declaration
|
||||
ir::type::id_t id = ty->get_type_id();
|
||||
if(id == ir::type::StructTyID)
|
||||
should_not_happen();
|
||||
error_not_implemented("struct not implemented");
|
||||
if(inits.size() > 1)
|
||||
should_not_happen();
|
||||
error_not_implemented("initializer list > 1 element not implemented");
|
||||
if(inits.size() > 0)
|
||||
val = inits[0];
|
||||
assert(val->get_type() == ty);
|
||||
@@ -427,20 +439,20 @@ void Generator::VisitForStmt(ForStmt *forStmt) {
|
||||
}
|
||||
|
||||
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("jump not implemented");
|
||||
}
|
||||
|
||||
void Generator::VisitReturnStmt(ReturnStmt* returnStmt) {
|
||||
ir::value *ret;
|
||||
if(returnStmt->expr_)
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("non-void return not implemented");
|
||||
else
|
||||
ret = bld_->create_ret_void();
|
||||
return set_ret(ret);
|
||||
}
|
||||
|
||||
void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("label not implemented");
|
||||
}
|
||||
|
||||
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
||||
@@ -458,7 +470,7 @@ void Generator::VisitFuncDef(FuncDef* funcDef) {
|
||||
FuncType* type = funcDef->FuncType();
|
||||
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
|
||||
if(!prototype)
|
||||
should_not_happen();
|
||||
should_not_happen("could not parse function prototype");
|
||||
ir::function *fn = mod_->get_or_insert_function(name, prototype);
|
||||
std::vector<ir::argument*> args = fn->args();
|
||||
size_t i = 0;
|
||||
@@ -529,7 +541,7 @@ ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
|
||||
for(size_t d = 0; d < padded_shapes.size(); d++){
|
||||
if(dst_shapes[d] != padded_shapes[d] &&
|
||||
padded_shapes[d] != 1)
|
||||
should_not_happen();
|
||||
should_not_happen("broadcast should not happen between these shapes");
|
||||
}
|
||||
// pad and broadcast
|
||||
ir::value *padded = bld_->create_reshape(src, padded_shapes);
|
||||
@@ -555,6 +567,9 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
|
||||
bool dst_signed = false;
|
||||
if(src_scalar_ty == dst_scalar_ty)
|
||||
return src;
|
||||
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_bool_ty())
|
||||
return bld_->create_icmpNE(bld_->create_ptr_to_int(src, ir::tile_type::get_same_shapes(bld_->get_int64_ty(), src->get_type())),
|
||||
bld_->create_splat(bld_->get_int64(0), src->get_type()->get_tile_shapes()));
|
||||
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
return bld_->create_si_to_fp(src, dst_ty);
|
||||
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||
@@ -575,7 +590,7 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) {
|
||||
else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty())
|
||||
return bld_->create_cast(ir::BitCast, src, dst_ty);
|
||||
else{
|
||||
should_not_happen();
|
||||
error_not_implemented("cast type not implemented");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
@@ -594,7 +609,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
|
||||
if(attr.kind == ASTNode::Attr::MULTIPLEOF) {
|
||||
VisitExpr(attr.vals[0]);
|
||||
auto cst = dynamic_cast<ir::constant_int*>(ret_);
|
||||
if(!cst) should_not_happen();
|
||||
if(!cst) should_not_happen("multipleof only works on constants");
|
||||
return ir::attribute(ir::multiple_of, cst->get_value());
|
||||
}
|
||||
if(attr.kind == ASTNode::Attr::ALIGNED) {
|
||||
@@ -608,7 +623,15 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
|
||||
return ir::attribute(ir::readonly);
|
||||
if(attr.kind == ASTNode::Attr::WRITEONLY)
|
||||
return ir::attribute(ir::writeonly);
|
||||
should_not_happen();
|
||||
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
|
||||
}
|
||||
|
||||
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {
|
||||
auto *i = dynamic_cast<ir::instruction*>(v);
|
||||
if(!i)
|
||||
return;
|
||||
if(attr.kind == ASTNode::Attr::MULTIPLEOF)
|
||||
i->set_metadata(ir::metadata::multiple_of, GenIRAttr(attr).get_value());
|
||||
}
|
||||
|
||||
// Triton-IR Types
|
||||
@@ -684,12 +707,12 @@ ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) {
|
||||
}
|
||||
|
||||
ir::type* Generator::GenIRStructType(StructType* type, ir::context& ctx) {
|
||||
error_not_implemented();
|
||||
error_not_implemented("struct not implemented");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Generator::AllocObjects(Scope* scope, const FuncDef::ParamList& params) {
|
||||
return error_not_implemented();
|
||||
return error_not_implemented("alloc not implemented");
|
||||
}
|
||||
|
||||
// SSA
|
||||
@@ -704,7 +727,7 @@ void Generator::popScope() {
|
||||
// LValue Generator
|
||||
void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
||||
if(binary->op_ != Token::MASKED_DEREF)
|
||||
error_not_implemented();
|
||||
error_not_implemented("lvalue for binary non masked-deref not implemented");
|
||||
gen_->VisitExpr(binary->lhs_);
|
||||
ir::value* mask = gen_->ret_;
|
||||
gen_->VisitExpr(binary->rhs_);
|
||||
@@ -714,7 +737,7 @@ void LValAssigner::VisitBinaryOp(BinaryOp* binary) {
|
||||
|
||||
void LValAssigner::VisitUnaryOp(UnaryOp* unary) {
|
||||
if(unary->op_ != Token::DEREF)
|
||||
should_not_happen();
|
||||
error_not_implemented("lvalue for unary non deref not implemented");
|
||||
gen_->VisitExpr(unary->operand_);
|
||||
ir::value* addr = gen_->ret_;
|
||||
ret_ = gen_->bld_->create_store(addr, rhs_);
|
||||
|
Reference in New Issue
Block a user