From ddd89e1b22b3b477777e6211c887e0d63378a4d5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 11 May 2020 11:07:21 -0400 Subject: [PATCH] [GENERAL] Fixed some undefined behavior with GCC-9 --- include/triton/ir/function.h | 3 +- lib/codegen/selection/generator.cc | 52 +++++++++++++++--------------- lib/codegen/transform/peephole.cc | 1 + lib/driver/module.cc | 3 +- lib/ir/module.cc | 1 + lib/lang/code_gen.cc | 1 + tests/common/dot.h | 2 +- 7 files changed, 34 insertions(+), 29 deletions(-) diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index d3ebe199b..1a1b3a0ec 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -39,7 +39,8 @@ enum attribute_kind_t { writeonly, noalias, aligned, - multiple_of + multiple_of, + not_implemented }; class attribute { diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 0a79481cb..92674b7db 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -556,14 +556,14 @@ void generator::visit_exp_inst(ir::exp_inst* x){ // Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); // Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); - - FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), {builder_->getFloatTy()}, false); + std::vector tys = {builder_->getFloatTy()}; + FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.ftz.f32 $0, $1;", "=f,f", false); for_each(x, [&](indices_t idx){ Value *ex2arg = builder_->CreateFMul(arg->get_value(idx), log2e); - set_value(x, idx, builder_->CreateCall(ex2, {ex2arg})); + set_value(x, idx, builder_->CreateCall(ex2, std::vector{ex2arg})); }); } @@ -584,7 +584,7 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); - old = builder_->CreateExtractValue(old, {0}); + old = builder_->CreateExtractValue(old, std::vector{0}); Value *atom_ptr; atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas))))); atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3)); @@ -640,8 +640,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile * Type *fp32_ty = builder_->getFloatTy(); Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2); - Type *fp32_pack8_ty = StructType::get(*ctx_, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); - FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); + FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); @@ -720,15 +720,15 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile * (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc, (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc }; - Value *nc = builder_->CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]}); - fc[idx[0]] = builder_->CreateExtractValue(nc, {0}); - fc[idx[1]] = builder_->CreateExtractValue(nc, {1}); - fc[idx[2]] = builder_->CreateExtractValue(nc, {2}); - fc[idx[3]] = builder_->CreateExtractValue(nc, {3}); - fc[idx[4]] = builder_->CreateExtractValue(nc, {4}); - fc[idx[5]] = builder_->CreateExtractValue(nc, {5}); - fc[idx[6]] = builder_->CreateExtractValue(nc, {6}); - fc[idx[7]] = builder_->CreateExtractValue(nc, {7}); + Value *nc = builder_->CreateCall(mma_fn, std::vector{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]}); + fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector{0}); + fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector{1}); + fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector{2}); + fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector{3}); + fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector{4}); + fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector{5}); + fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector{6}); + fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector{7}); } } } @@ -770,7 +770,7 @@ void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_ti a = builder_->CreateFPCast(a, c_ty); if(b->getType() != c_ty) b = builder_->CreateFPCast(b, c_ty); - res = builder_->CreateCall(f_mul_add, {a, b, res}); + res = builder_->CreateCall(f_mul_add, std::vector{a, b, res}); } set_value(dot, idx, res); }); @@ -790,7 +790,7 @@ void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distrib a = builder_->CreateFPCast(a, c_ty); if(b->getType() != c_ty) b = builder_->CreateFPCast(b, c_ty); - res = builder_->CreateCall(f_mul_add, {a, b, res}); + res = builder_->CreateCall(f_mul_add, std::vector{a, b, res}); set_value(dot, idx, res); }); } @@ -805,7 +805,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { distributed_tile *TD = (distributed_tile*)tmap_.at(D); Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_); - Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); + Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector{c_ty}); auto A_shapes = A->get_type()->get_tile_shapes(); size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; @@ -835,8 +835,8 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) { for_each(sqt, [&](indices_t idx){ Value *val = get_value(sqt->get_operand(0), idx); Module* module = builder_->GetInsertBlock()->getModule(); - Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, {val->getType()}); - Value *ret = builder_->CreateCall(sqrt, {val}); + Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector{val->getType()}); + Value *ret = builder_->CreateCall(sqrt, std::vector{val}); set_value(sqt, idx, ret); }); } @@ -849,7 +849,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { unsigned axis = x->get_axis(); Type *fp32_ty = builder_->getFloatTy(); - FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, {fp32_ty, fp32_ty}, false); + FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector{fp32_ty, fp32_ty}, false); InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false); InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false); @@ -871,8 +871,8 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { } case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y); case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y); - case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, {x, y}); - case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, {x, y}); + case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector{x, y}); + case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector{x, y}); default: assert(false); return nullptr; } }; @@ -910,11 +910,11 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { thread_acc = accumulate(thread_acc, current); }); // reduce within wrap - FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), {thread_acc->getType(), builder_->getInt32Ty()}, false); + FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector{thread_acc->getType(), builder_->getInt32Ty()}, false); InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false); Value *warp_acc = thread_acc; for(int i = 16; i > 0; i >>= 1) - warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, {warp_acc, builder_->getInt32(i)})); + warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector{warp_acc, builder_->getInt32(i)})); // shared memory pointer unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace(); Type *res_ty = arg_tile->get_ty(); @@ -935,7 +935,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) { builder_->SetInsertPoint(bb_final_acc); Value* final_val = builder_->CreateLoad(load_ptr); for(int i = (num_warps_+1)/2; i > 0; i >>= 1) - final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, {final_val, builder_->getInt32(i)})); + final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector{final_val, builder_->getInt32(i)})); builder_->CreateStore(final_val, load_ptr); builder_->CreateBr(bb_final_acc_done); // // store first warp done diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index c6ff03fbe..d8062be6b 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -81,6 +81,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ add->replace_all_uses_with(new_dot); return true; } + return false; } //bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){ diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 3a48dfdd1..c09fa5ef1 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -146,7 +146,8 @@ host_module::host_module(driver::context * context, std::unique_ptrgetPointerTo(); llvm::Type *int32_ty = llvm::Type::getInt32Ty(ctx); - llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, {args_ty, int32_ty, int32_ty, int32_ty}, false); + std::vector tys = {args_ty, int32_ty, int32_ty, int32_ty}; + llvm::FunctionType *main_ty = llvm::FunctionType::get(void_ty, tys, false); llvm::Function* main = llvm::Function::Create(main_ty, llvm::Function::ExternalLinkage, "main", &*src); llvm::Function* fn = src->getFunction("matmul"); llvm::FunctionType *fn_ty = fn->getFunctionType(); diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 4a4655fb6..67617478b 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -70,6 +70,7 @@ ir::value *module::try_remove_trivial_phis(ir::phi_node *&phi){ return phi; // unique value or self-reference ir::value *same = *non_self_ref.begin(); + assert(same != nullptr); std::set users = phi->get_users(); phi->replace_all_uses_with(same); phi->erase_from_parent(); diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 6b82b8b26..517553e97 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -624,6 +624,7 @@ ir::attribute Generator::GenIRAttr(ASTNode::Attr attr) { if(attr.kind == ASTNode::Attr::WRITEONLY) return ir::attribute(ir::writeonly); error_not_implemented("attribute " + std::to_string(attr.kind) + " not implemented"); + return ir::attribute(ir::not_implemented); } void Generator::SetIRMetadata(ASTNode::Attr attr, ir::value *v) { diff --git a/tests/common/dot.h b/tests/common/dot.h index 5556f750f..5dc7ef74e 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -64,7 +64,7 @@ template<> struct to_string{ }; template -bool triton_dot(drv::stream* stream, bool AT, bool BT, +void triton_dot(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, int32_t nwarp, const std::vector& a_order, const std::vector& b_order,