[GENERAL] Fixed some undefined behavior with GCC-9

This commit is contained in:
Philippe Tillet
2020-05-11 11:07:21 -04:00
committed by Philippe Tillet
parent 0516ea96d0
commit ddd89e1b22
7 changed files with 34 additions and 29 deletions

View File

@@ -39,7 +39,8 @@ enum attribute_kind_t {
writeonly,
noalias,
aligned,
multiple_of
multiple_of,
not_implemented
};
class attribute {

View File

@@ -556,14 +556,14 @@ void generator::visit_exp_inst(ir::exp_inst* x){
// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_);
// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty});
Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634);
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false);
std::vector<llvm::Type*> tys = {builder_->getFloatTy()};
FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false);
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false);
for_each(x, [&](indices_t idx){
Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e);
set_value(x, idx, builder_->CreateCall(ex2, {ex2arg}));
set_value(x, idx, builder_->CreateCall(ex2, std::vector<llvm::Value*>{ex2arg}));
});
}
@@ -584,7 +584,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) {
Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val,
AtomicOrdering::Monotonic,
AtomicOrdering::Monotonic);
old = builder_->CreateExtractValue(old, {0});
old = builder_->CreateExtractValue(old, std::vector<unsigned>{0});
Value *atom_ptr;
atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))));
atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3));
@@ -640,8 +640,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
Type *fp32_ty = builder_->getFloatTy();
Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0);
@@ -720,15 +720,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc
};
Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder_->CreateExtractValue(nc, {0});
fc[idx[1]] = builder_->CreateExtractValue(nc, {1});
fc[idx[2]] = builder_->CreateExtractValue(nc, {2});
fc[idx[3]] = builder_->CreateExtractValue(nc, {3});
fc[idx[4]] = builder_->CreateExtractValue(nc, {4});
fc[idx[5]] = builder_->CreateExtractValue(nc, {5});
fc[idx[6]] = builder_->CreateExtractValue(nc, {6});
fc[idx[7]] = builder_->CreateExtractValue(nc, {7});
Value *nc = builder_->CreateCall(mma_fn, std::vector<llvm::Value*>{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{0});
fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{1});
fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{2});
fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{3});
fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{4});
fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{5});
fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{6});
fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector<unsigned>{7});
}
}
}
@@ -770,7 +770,7 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_ti
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
}
set_value(dot, idx, res);
});
@@ -790,7 +790,7 @@ void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distrib
a = builder_->CreateFPCast(a, c_ty);
if(b->getType() != c_ty)
b = builder_->CreateFPCast(b, c_ty);
res = builder_->CreateCall(f_mul_add, {a, b, res});
res = builder_->CreateCall(f_mul_add, std::vector<llvm::Value*>{a, b, res});
set_value(dot, idx, res);
});
}
@@ -805,7 +805,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
distributed_tile *TD = (distributed_tile*)tmap_.at(D);
Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_);
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty});
Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector<llvm::Type*>{c_ty});
auto A_shapes = A->get_type()->get_tile_shapes();
size_t red_axis = 1;
unsigned NK = A_shapes[red_axis];
@@ -835,8 +835,8 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
for_each(sqt, [&](indices_t idx){
Value *val = get_value(sqt->get_operand(0), idx);
Module* module = builder_->GetInsertBlock()->getModule();
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()});
Value *ret = builder_->CreateCall(sqrt, {val});
Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector<llvm::Type*>{val->getType()});
Value *ret = builder_->CreateCall(sqrt, std::vector<llvm::Value*>{val});
set_value(sqt, idx, ret);
});
}
@@ -849,7 +849,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
unsigned axis = x->get_axis();
Type *fp32_ty = builder_->getFloatTy();
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false);
FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty}, false);
InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false);
InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false);
@@ -871,8 +871,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
}
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y});
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y});
case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector<llvm::Value*>{x, y});
case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector<llvm::Value*>{x, y});
default: assert(false); return nullptr;
}
};
@@ -910,11 +910,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
thread_acc = accumulate(thread_acc, current);
});
// reduce within warp
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false);
FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector<llvm::Type*>{thread_acc->getType(), builder_->getInt32Ty()}, false);
InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false);
Value *warp_acc = thread_acc;
for(int i = 16; i > 0; i >>= 1)
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)}));
warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{warp_acc, builder_->getInt32(i)}));
// shared memory pointer
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
Type *res_ty = arg_tile->get_ty();
@@ -935,7 +935,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
builder_->SetInsertPoint(bb_final_acc);
Value* final_val = builder_->CreateLoad(load_ptr);
for(int i = (num_warps_+1)/2; i > 0; i >>= 1)
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)}));
final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector<llvm::Value*>{final_val, builder_->getInt32(i)}));
builder_->CreateStore(final_val, load_ptr);
builder_->CreateBr(bb_final_acc_done);
// // store first warp done

View File

@@ -81,6 +81,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
add->replace_all_uses_with(new_dot);
return true;
}
return false;
}
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){

View File

@@ -146,7 +146,8 @@ host_module::host_module(driver::context * context, std::unique_ptr<llvm::Module
llvm::Type *void_ty = llvm::Type::getVoidTy(ctx);
llvm::Type *args_ty = llvm::Type::getInt8PtrTy(ctx)->getPointerTo();
llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx);
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false);
std::vector<llvm::Type*> tys = {args_ty, int32_ty, int32_ty, int32_ty};
llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false);
llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src);
llvm::Function* fn = src->getFunction("matmul");
llvm::FunctionType *fn_ty = fn->getFunctionType();

View File

@@ -70,6 +70,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
assert(same != nullptr);
std::set<ir::user*> users = phi->get_users();
phi->replace_all_uses_with(same);
phi->erase_from_parent();

View File

@@ -624,6 +624,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) {
if(attr.kind == ASTNode::Attr::WRITEONLY)
return ir::attribute(ir::writeonly);
error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented");
return ir::attribute(ir::not_implemented);
}
void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) {

View File

@@ -64,7 +64,7 @@ template<> struct to_string<double>{
};
template<class T>
bool triton_dot(drv::stream* stream, bool AT, bool BT,
void triton_dot(drv::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
int32_t TM, int32_t TN, int32_t TK, int32_t nwarp,
const std::vector<int>& a_order, const std::vector<int>& b_order,