#include <algorithm>
#include <numeric>
#include <sstream>
#include <string>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "triton/ir/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IR/Type.h"
#ifdef USE_ROCM
#include "llvm/IR/IntrinsicsAMDGPU.h"
#else
#include "llvm/IR/IntrinsicsNVPTX.h"
#endif
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

namespace triton{
namespace codegen{

using namespace llvm;

Value* adder::operator()(Value *x, Value *y, const std::string& name) {
  // (x + cst) + y -> (x + y) + cst
  if(auto* bin = dyn_cast<BinaryOperator>(x))
  if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
  if(dyn_cast<Constant>(bin->getOperand(1))){
    return (*builder_)->CreateAdd((*builder_)->CreateAdd(bin->getOperand(0), y),
                                  bin->getOperand(1));
  }
  // (x + (y + cst)) -> (x + y) + cst
  if(auto* bin = dyn_cast<BinaryOperator>(y))
  if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
  if(dyn_cast<Constant>(bin->getOperand(1))){
    return (*builder_)->CreateAdd((*builder_)->CreateAdd(x, bin->getOperand(0)),
                                  bin->getOperand(1));
  }
  // default
  return (*builder_)->CreateAdd(x, y, name);
}

Value* multiplier::operator()(Value *x, Value *y, const std::string &name) {
  // (x + cst1) * cst2 -> (x * cst2) + (cst1 * cst2)
  if(auto* bin = dyn_cast<BinaryOperator>(x))
  if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
  if(dyn_cast<Constant>(bin->getOperand(1)))
  if(dyn_cast<Constant>(y)){
    return (*builder_)->CreateAdd((*builder_)->CreateMul(bin->getOperand(0), y),
                                  (*builder_)->CreateMul(bin->getOperand(1), y));
  }
  // default
  return (*builder_)->CreateMul(x, y, name);
}

Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
  // (ptr + cst1) + (cst2) -> ptr + (cst1 + cst2)
  if(auto* gep = dyn_cast<GetElementPtrInst>(ptr))
  if(ConstantInt* cst1 = dyn_cast<ConstantInt>(gep->idx_begin()))
  if(ConstantInt* cst2 = dyn_cast<ConstantInt>(off)){
    return (*builder_)->CreateGEP(gep->getPointerOperand()->getType()->getScalarType()->getPointerElementType(),
                                  gep->getPointerOperand(), (*builder_)->CreateAdd(cst1, cst2));
  }
  // ptr + (off + cst) -> (ptr + off) + cst
  if(auto* bin = dyn_cast<BinaryOperator>(off))
  if(bin->getOpcode() == llvm::BinaryOperator::BinaryOps::Add)
  if(ConstantInt* cst = dyn_cast<ConstantInt>(bin->getOperand(1))){
    Value *gep = (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(),
                                        ptr, bin->getOperand(0));
    return (*builder_)->CreateGEP(gep->getType()->getScalarType()->getPointerElementType(),
                                  gep, bin->getOperand(1));
  }
  // default
  return (*builder_)->CreateGEP(ptr->getType()->getScalarType()->getPointerElementType(), ptr, off, name);
}

//Value* geper::operator()(Type *ty, Value *ptr, std::vector<Value*> vals, const std::string &name) {
//  return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}

// types
#define void_ty              builder_->getVoidTy()
#define f16_ty               builder_->getHalfTy()
#define bf16_ty              builder_->getInt16Ty()
#define f32_ty               builder_->getFloatTy()
#define f64_ty               builder_->getDoubleTy()
#define i1_ty                builder_->getInt1Ty()
#define i8_ty                builder_->getInt8Ty()
#define i16_ty               builder_->getInt16Ty()
#define i32_ty               builder_->getInt32Ty()
#define i64_ty               builder_->getInt64Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...)          PointerType::get(__VA_ARGS__)
// constants
#define i16(...)             builder_->getInt16(__VA_ARGS__)
#define i32(...)             builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...)            builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...)      builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...)          builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...)        builder_->CreateBitCast(__VA_ARGS__)
#define br(...)              builder_->CreateBr(__VA_ARGS__)
#define call(...)            builder_->CreateCall(__VA_ARGS__)
#define cast(...)            builder_->CreateCast(__VA_ARGS__)
#define cond_br(...)         builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...)      builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...)     builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...)     builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...)            builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...)            builder_->CreateFCmp(__VA_ARGS__)
#define fcmp_oge(...)        builder_->CreateFCmpOGE(__VA_ARGS__)
#define fcmp_ole(...)        builder_->CreateFCmpOLE(__VA_ARGS__)
#define fmul(...)            builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...)          builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...)            builder_->CreateFSub(__VA_ARGS__)
#define icmp(...)            builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...)         builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...)        builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...)        builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_uge(...)        builder_->CreateICmpUGE(__VA_ARGS__)
#define icmp_ule(...)        builder_->CreateICmpULE(__VA_ARGS__)
#define icmp_ult(...)        builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...)      builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...)       builder_->CreateIntrinsic(__VA_ARGS__)
#define load(ptr)            builder_->CreateLoad(ptr->getType()->getPointerElementType(), ptr)
#define lshr(...)            builder_->CreateLShr(__VA_ARGS__)
#define max_num(...)         builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...)         builder_->CreateMinNum(__VA_ARGS__)
#define neg(...)             builder_->CreateNeg(__VA_ARGS__)
#define phi(...)             builder_->CreatePHI(__VA_ARGS__)
#define ret(...)             builder_->CreateRet(__VA_ARGS__)
#define select(...)          builder_->CreateSelect(__VA_ARGS__)
#define store(...)           builder_->CreateStore(__VA_ARGS__)
#define sub(...)             builder_->CreateSub(__VA_ARGS__)
#define shl(...)             builder_->CreateShl(__VA_ARGS__)
#define udiv(...)            builder_->CreateUDiv(__VA_ARGS__)
#define urem(...)            builder_->CreateURem(__VA_ARGS__)
#define splat(...)           builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...)            builder_->CreateXor(__VA_ARGS__)

/**
 * \brief Convert Triton-IR Type to LLVM-IR Type
 */
Type *generator::cvt(ir::type *ty) {
  // struct
  if(ty->is_struct_ty()){
    std::vector<Type*> tys;
    for(size_t i = 0; i < ty->get_struct_numel(); i++)
      tys.push_back(cvt(ty->get_struct_type(i)));
    return StructType::get(builder_->getContext(), tys, true);
  }
  // function
  if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
    Type *ret_ty = cvt(tt->get_return_ty());
    std::vector<Type*> arg_tys(tt->get_num_params());
    for(size_t i = 0; i < arg_tys.size(); i++)
      arg_tys[i] = cvt(tt->get_param_ty(i));
    return FunctionType::get(ret_ty, arg_tys, false);
  }
  // pointer
  if(ty->is_pointer_ty()){
    Type *elt_ty = cvt(ty->get_pointer_element_ty());
    unsigned addr_space = ty->get_pointer_address_space();
    return ptr_ty(elt_ty, addr_space);
  }
  // integer
  if(ty->is_integer_ty()){
    unsigned bitwidth = ty->get_integer_bitwidth();
    return IntegerType::get(*ctx_, bitwidth);
  }
  // primitive types
  switch(ty->get_type_id()){
    case ir::type::VoidTyID:      return Type::getVoidTy(*ctx_);
    case ir::type::FP8TyID:       return Type::getInt8Ty(*ctx_);
    case ir::type::FP16TyID:      return Type::getHalfTy(*ctx_);
    case ir::type::BF16TyID:      return Type::getInt16Ty(*ctx_); // use int16 as storage type
    case ir::type::FP32TyID:      return Type::getFloatTy(*ctx_);
    case ir::type::FP64TyID:      return Type::getDoubleTy(*ctx_);
    case ir::type::LabelTyID:     return Type::getLabelTy(*ctx_);
    case ir::type::MetadataTyID:  return Type::getMetadataTy(*ctx_);
    case ir::type::TokenTyID:     return Type::getTokenTy(*ctx_);
    default: break;
  }
  // unknown type
  throw std::runtime_error("unknown conversion from ir::type to Type");
}

/**
 * \brief Convert Triton-IR Attribute to LLVM-IR Attribute
 */
llvm::Attribute generator::cvt(ir::attribute attr) {
  switch(attr.get_kind()){
    case ir::noalias:   return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias);
    case ir::readonly:  return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly);
    case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly);
    case ir::aligned:   return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value());
    case ir::retune:    return llvm::Attribute::get(*ctx_, llvm::Attribute::None);
    default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute");
  }
}

/**
 * \brief Constructor of LLVM code generator
 */
generator::generator(analysis::axes *a_axes,
                     analysis::layouts *layouts,
                     analysis::align *alignment,
                     analysis::allocation *alloc,
                     analysis::swizzle *swizzle,
                     target *tgt,
                     unsigned num_warps)
  : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle),
    tgt_(tgt), num_warps_(num_warps), add(&builder_), mul(&builder_), gep(&builder_) {
}

/**
 * \brief Code Generation for `value`
 */
void generator::visit_value(ir::value* v) {
  if(!seen_.insert(v).second)
    return;
  if(v->get_type()->is_block_ty()){
    if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){
      analysis::N_buffer_info_t *n_buffer = layout->get_N_buffer();
      analysis::double_buffer_info_t *double_buffer = layout->get_double_buffer();
      // offset
      Value *offset = nullptr;
      // base pointer
      Value *ptr = shared_ptr_[layout];
      if (n_buffer) {
        // ptr = base (shared_ptr_[layout]) + smem_idx * size
        // read_smem_idx
        if (v == n_buffer->phi) {
          ptr = shared_ptr_[layout];
        }
        // write_smem_idx
        if (std::find(n_buffer->firsts.begin(), n_buffer->firsts.end(), v) != n_buffer->firsts.end()) {
          int write_smem_idx = /*stage_idx*/ n_buffer->firsts_idx.at(v);
          int elements = write_smem_idx * layout->get_per_stage_elements();
          ptr = gep(shared_pre_ptr_[layout],
i32(elements)); } else if (v == n_buffer->latch) { Value* write_smem_idx = write_smem_idx_[layout]; Value* elements = mul(write_smem_idx, i32(layout->get_per_stage_elements())); ptr = gep(shared_pre_ptr_[layout], elements); } } else if (double_buffer) { if(v == double_buffer->phi) offset = shared_off_[layout]; if(v == double_buffer->latch) ptr = shared_next_ptr_[layout]; else if(v == double_buffer->first) ptr = shared_pre_ptr_[layout]; } // else do nothing // what visit_dot & vist_cts & ... see shmems_[v] = ptr; // now only latches have offset (PHINode), only used by finalize_share_layout() shoffs_[v] = offset; } } // visit operands BasicBlock *current = builder_->GetInsertBlock(); auto *inst = dynamic_cast(v); if(inst) for(ir::value *op: inst->ops()){ if(dynamic_cast(op) || !dynamic_cast(v)) visit_value(op); } init_idx(v); // change insert point for phi node builder_->SetInsertPoint(current); auto *phi = dynamic_cast(v); if(phi && !current->empty() && current->getFirstNonPHI()) builder_->SetInsertPoint(&*current->getFirstNonPHI()); // visit user if(auto *usr = dynamic_cast(v)){ if(!dynamic_cast(usr)) usr->accept(this); } // revert insert point if(phi && !current->empty() && current->getFirstNonPHI()) builder_->SetInsertPoint(current); } /** * \brief Code Generation for `phi` */ void generator::visit_phi_node(ir::phi_node* x) { Type *ty = cvt(x->get_type()->get_scalar_ty()); for(indices_t idx: idxs_.at(x)) vals_[x][idx] = phi(ty, x->get_num_operands()); } /** * \brief Code Generation for `call` */ void generator::visit_call_inst(ir::call_inst* call) { throw std::runtime_error("call not supported! Triton should be inlining everything."); } void generator::visit_launch_inst(ir::launch_inst *launch) { ir::function* fn = (ir::function*)launch->get_operand(0); // forward-declare cudaGetParameterBufferV2 std::vector get_param_arg_tys = {PointerType::get(builder_->getInt8Ty(), 0), ArrayType::get(builder_->getInt32Ty(), 3), ArrayType::get(builder_->getInt32Ty(), 3), builder_->getInt32Ty()}; FunctionType* get_param_ty = FunctionType::get(PointerType::get(builder_->getInt8Ty(), 0), get_param_arg_tys, false); Function* get_param_buffer = Function::Create(get_param_ty, Function::ExternalLinkage, "cudaGetParameterBufferV2", mod_); AllocaInst* grid = builder_->CreateAlloca(get_param_arg_tys[1]); AllocaInst* block = builder_->CreateAlloca(get_param_arg_tys[2]); ConstantInt* _0 = builder_->getInt32(0); ConstantInt* _1 = builder_->getInt32(1); ConstantInt* _2 = builder_->getInt32(2); // create basic block BasicBlock* launch_done_bb = BasicBlock::Create(builder_->getContext(), "launch_done", builder_->GetInsertBlock()->getParent()); BasicBlock* launch_bb = BasicBlock::Create(builder_->getContext(), "launch", launch_done_bb->getParent(), launch_done_bb); Value *tid = tgt_->get_local_id(mod_, *builder_, 0); Value *is_first_thread = builder_->CreateICmpEQ(tid, i32(0)); builder_->CreateCondBr(is_first_thread, launch_bb, launch_done_bb); builder_->SetInsertPoint(launch_bb); // builder_->CreateStore(vals_[launch->get_grid()[0]][{}], builder_->CreateGEP(grid, {_0, _0})); builder_->CreateStore(vals_[launch->get_grid()[1]][{}], builder_->CreateGEP(grid, {_0, _1})); builder_->CreateStore(vals_[launch->get_grid()[2]][{}], builder_->CreateGEP(grid, {_0, _2})); Value* num_warps = mul(builder_->getInt32(32), vals_[launch->get_num_warps()][{}]); builder_->CreateStore(num_warps, builder_->CreateGEP(block, {_0, _0})); builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _1})); 
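// Only the x-dimension of the device-side launch's thread block is meaningful:
// block = (32 * num_warps, 1, 1), so the y-dimension stored above and the
// z-dimension stored next are both fixed to 1.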
builder_->CreateStore(builder_->getInt32(1), builder_->CreateGEP(block, {_0, _2})); Function* called_fn = fns_[fn]; Value* callee = ConstantExpr::getCast(Instruction::BitCast, called_fn, get_param_arg_tys[0]); Value* arg_ptr = builder_->CreateCall(get_param_buffer, {callee, builder_->CreateLoad(grid), builder_->CreateLoad(block), builder_->getInt32(0)}); // forwrd-declare cudaLaunchDeviceV2 std::vector launch_device_arg_tys = {get_param_ty->getReturnType(), builder_->getInt64Ty()}; FunctionType* launch_device_ty = FunctionType::get(builder_->getInt32Ty(), launch_device_arg_tys, false); Function* launch_device = Function::Create(launch_device_ty, Function::ExternalLinkage, "cudaLaunchDeviceV2", mod_); // TODO: add branch Value* do_not_launch = builder_->CreateICmpEQ(builder_->CreatePtrToInt(arg_ptr, builder_->getInt64Ty()), builder_->getInt64(0)); BasicBlock* launch2_bb = BasicBlock::Create(builder_->getContext(), "launch2", launch_done_bb->getParent(), launch_done_bb); builder_->CreateCondBr(do_not_launch, launch_done_bb, launch2_bb); builder_->SetInsertPoint(launch2_bb); unsigned addr_space = arg_ptr->getType()->getPointerAddressSpace(); unsigned off = 0; unsigned last_size = 0; for(ir::value* arg: launch->get_values()){ Value* curr_arg = vals_[arg][{}]; Type* curr_arg_ty = curr_arg->getType(); // handle struct alignment off += last_size; unsigned size = curr_arg_ty->isPointerTy() ? 8 : curr_arg_ty->getPrimitiveSizeInBits() / 8; off = (off + size - 1) / size * size; // get pointer to current arg Value* curr_arg_ptr = builder_->CreateGEP(arg_ptr, builder_->getInt32(off)); curr_arg_ptr = builder_->CreateBitCast(curr_arg_ptr, curr_arg_ty->getPointerTo(addr_space)); // store arg builder_->CreateStore(curr_arg, curr_arg_ptr); last_size = size; } builder_->CreateCall(launch_device, {arg_ptr, builder_->getInt64(0)}); builder_->CreateBr(launch_done_bb); // done builder_->SetInsertPoint(launch_done_bb); } /** * \brief Code Generation for `binary_operator` */ void generator::visit_binary_operator(ir::binary_operator*x) { using ll = llvm::Instruction::BinaryOps; using tt = ir::binary_op_t; auto cvt = [](ir::binary_op_t op){ switch(op) { case tt::Add: return ll::Add; case tt::FAdd: return ll::FAdd; case tt::Sub: return ll::Sub; case tt::FSub: return ll::FSub; case tt::Mul: return ll::Mul; case tt::FMul: return ll::FMul; case tt::UDiv: return ll::UDiv; case tt::SDiv: return ll::SDiv; case tt::FDiv: return ll::FDiv; case tt::URem: return ll::URem; case tt::SRem: return ll::SRem; case tt::FRem: return ll::FRem; case tt::Shl: return ll::Shl; case tt::LShr: return ll::LShr; case tt::AShr: return ll::AShr; case tt::And: return ll::And; case tt::Or: return ll::Or; case tt::Xor: return ll::Xor; default: throw std::runtime_error("unreachable switch"); } }; // x->print(std::cout); for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; if (x->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty()) { assert(x->get_operand(1)->get_type()->get_scalar_ty()->is_bf16_ty()); if (x->get_op() == tt::FAdd) { InlineAsm *bf16_add_asm = InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), "v_lshlrev_b32 $1, 16, $1 " "v_lshlrev_b32 $2, 16, $2 " "v_add_f32 $0, $1, $2 " "v_lshrrev_b32 $0, 16, $0 ", "=v,v,v", false); vals_[x][idx] = builder_->CreateCall(bf16_add_asm, {lhs, rhs}); } else if (x->get_op() == tt::FSub) { // a - b = b * (-1.0) + a InlineAsm *bf16_sub_asm = InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), 
"v_lshlrev_b32 $1, 16, $1 " "v_lshlrev_b32 $2, 16, $2 " "v_sub_f32 $0, $1, $2 " "v_lshrrev_b32 $0, 16, $0 ", "=v,v,v", false); vals_[x][idx] = builder_->CreateCall(bf16_sub_asm, {lhs, rhs}); } else if (x->get_op() == tt::FMul) { // a * b = a*b + 0 InlineAsm *bf16_mul_asm = InlineAsm::get(FunctionType::get(bf16_ty, {bf16_ty, bf16_ty}, false), "v_lshlrev_b32 $1, 16, $1 " "v_lshlrev_b32 $2, 16, $2 " "v_mul_f32 $0, $1, $2 " "v_lshrrev_b32 $0, 16, $0 ", "=v,v,v", false); vals_[x][idx] = builder_->CreateCall(bf16_mul_asm, {lhs, rhs}); } else throw std::runtime_error("invalid bin op for bf16"); } else { auto op = cvt(x->get_op()); if(op == ll::Add) vals_[x][idx] = add(lhs, rhs); else if(op == ll::Mul) vals_[x][idx] = mul(lhs, rhs); else vals_[x][idx] = bin_op(op, lhs, rhs); } } } /** * \brief Code Generation for `getelementptr` */ void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) { for(indices_t idx: idxs_.at(x)){ Value *ptr = vals_[x->get_pointer_operand()][idx]; std::vector vals; for(auto it= x->idx_begin(); it != x->idx_end(); it++) vals.push_back(vals_[*it][idx]); assert(vals.size() == 1); vals_[x][idx] = gep(ptr, vals[0]); } } /** * \brief Code Generation for `icmp` */ void generator::visit_icmp_inst(ir::icmp_inst* x) { auto cvt = [](ir::cmp_pred_t pred) { using ll = llvm::CmpInst::Predicate; using tt = ir::cmp_pred_t; switch(pred){ case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE; case tt::ICMP_EQ: return ll::ICMP_EQ; case tt::ICMP_NE: return ll::ICMP_NE; case tt::ICMP_UGT: return ll::ICMP_UGT; case tt::ICMP_UGE: return ll::ICMP_UGE; case tt::ICMP_ULT: return ll::ICMP_ULT; case tt::ICMP_ULE: return ll::ICMP_ULE; case tt::ICMP_SGT: return ll::ICMP_SGT; case tt::ICMP_SGE: return ll::ICMP_SGE; case tt::ICMP_SLT: return ll::ICMP_SLT; case tt::ICMP_SLE: return ll::ICMP_SLE; case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE; default: throw std::runtime_error("unreachable switch"); } }; for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs); } } /** * \brief Code Generation for `fcmp` */ void generator::visit_fcmp_inst(ir::fcmp_inst* x) { auto cvt = [](ir::cmp_pred_t pred) { using ll = llvm::CmpInst::Predicate; using tt = ir::cmp_pred_t; switch(pred){ case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE; case tt::FCMP_FALSE: return ll::FCMP_FALSE; case tt::FCMP_OEQ: return ll::FCMP_OEQ; case tt::FCMP_OGT: return ll::FCMP_OGT; case tt::FCMP_OGE: return ll::FCMP_OGE; case tt::FCMP_OLT: return ll::FCMP_OLT; case tt::FCMP_OLE: return ll::FCMP_OLE; case tt::FCMP_ONE: return ll::FCMP_ONE; case tt::FCMP_ORD: return ll::FCMP_ORD; case tt::FCMP_UNO: return ll::FCMP_UNO; case tt::FCMP_UEQ: return ll::FCMP_UEQ; case tt::FCMP_UGT: return ll::FCMP_UGT; case tt::FCMP_UGE: return ll::FCMP_UGE; case tt::FCMP_ULT: return ll::FCMP_ULT; case tt::FCMP_ULE: return ll::FCMP_ULE; case tt::FCMP_UNE: return ll::FCMP_UNE; case tt::FCMP_TRUE: return ll::FCMP_TRUE; case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE; default: throw std::runtime_error("unreachable switch"); } }; for(indices_t idx: idxs_.at(x)){ Value *lhs = vals_[x->get_operand(0)][idx]; Value *rhs = vals_[x->get_operand(1)][idx]; vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs); } } std::tuple generator::fp32x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3){ in0 = cast(llvm::Instruction::FPTrunc, in0, f16_ty); in1 = cast(llvm::Instruction::FPTrunc, in1, f16_ty); in2 = 
cast(llvm::Instruction::FPTrunc, in2, f16_ty); in3 = cast(llvm::Instruction::FPTrunc, in3, f16_ty); Value *ret0, *ret1, *ret2, *ret3; std::tie(ret0, ret1, ret2, ret3) = fp16x4_to_fp8x4(in0, in1, in2, in3); return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::fp8x4_to_fp32x4(Value *in0, Value *in1, Value *in2, Value *in3){ Value *ret0, *ret1, *ret2, *ret3; std::tie(ret0, ret1, ret2, ret3) = fp8x4_to_fp16x4(in0, in1, in2, in3); ret0 = cast(llvm::Instruction::FPExt, ret0, f32_ty); ret1 = cast(llvm::Instruction::FPExt, ret1, f32_ty); ret2 = cast(llvm::Instruction::FPExt, ret2, f32_ty); ret3 = cast(llvm::Instruction::FPExt, ret3, f32_ty); return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::fp8x4_to_fp16x4(Value *in0, Value *in1, Value *in2, Value *in3){ Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), "{" ".reg .b32 a<2>, b<2>; \n\t" "prmt.b32 a0, 0, $2, 0x5040; \n\t" // If input is 0xdcba set a0 to 0xb0a0 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // If input is 0xdcba set a1 to 0xd0c0 "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // b0 = a0 & 0x7fff7fff (strip sign) "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // b1 = a1 & 0x7fff7fff (strip sign) "shr.b32 b0, b0, 1; \n\t" // b0 >>= 1 (shift into fp16 position) "shr.b32 b1, b1, 1; \n\t" // b1 >>= 1 (shift into fp16 position) "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n\t" // out0 = b0 | (0x80008000 & a0) (restore sign) "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n\t" // out1 = b1 | (0x80008000 & a1) (restore sign) "}", "=r,=r,r", false); Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); packed_in = insert_elt(packed_in, in0, (uint64_t)0); packed_in = insert_elt(packed_in, in1, (uint64_t)1); packed_in = insert_elt(packed_in, in2, (uint64_t)2); packed_in = insert_elt(packed_in, in3, (uint64_t)3); Value *in = bit_cast(packed_in, i32_ty); Value *ret = call(ptx, {in}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::fp16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { /* fp16 bit representation is seeeeemmmmmmmmmm (s=sign, e=exponent, m=mantissa) * fp8 bit representation is seeeemmm * The 4 fp8 exponent bits are the low order 4 exponent bits in fp16. * The 3 fp8 mantissa bits are the high order 3 mantissa bits in fp16. * Note that the low order exponent bits and high order mantissa bits in fp16 are contiguous. * We want to round to nearest fp8 value. To do that add 1 to 4th mantissa bit in fp16 (that's * one more than the number of mantissa bits in fp8). * fp8 = (fp16 & 0x8000) | (((f16 << 1) + 0x0080) & 0x7fff) * * We compute two fp16s in one uint32. The addition could cause bit flips from one fp16 to the * other. To avoid this we zero out the most significant exponent bit. If that bit is set then * the value isn't representable in float8 anyway so we assume it's never set (and give garbage * output if it is). If we were willing to assume the most significant exponent was never set * we could save the first two lop3.b32 instructions below. 
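 *
 * As an illustration of the formula above: for the fp16 value 0x3500 (= 0.3125),
 * the sign mask gives 0x0000, (0x3500 << 1) + 0x0080 = 0x6A80, and
 * 0x6A80 & 0x7fff = 0x6A80; the prmt below then keeps the high byte of each
 * 16-bit result, so the fp8 output byte is 0x6A, which decodes back to 0.3125
 * under the seeeemmm layout described above.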
*/ InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), "{" ".reg .b32 a<2>, b<2>; \n\t" "shl.b32 a0, $1, 1; \n\t" // a0 = input0 << 1 "shl.b32 a1, $2, 1; \n\t" // a1 = input1 << 1 "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n\t" // a0 = (a0 & 0x7fff7fff) "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n\t" // a1 = (a1 & 0x7fff7fff) "add.u32 a0, a0, 0x00800080; \n\t" // a0 += 0x00800080 "add.u32 a1, a1, 0x00800080; \n\t" // a1 += 0x00800080 "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n\t" // b0 = (input0 & 0x80008000) | a0 "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n\t" // b1 = (input1 & 0x80008000) | a1 "prmt.b32 $0, b0, b1, 0x7531; \n\t" // If b0 = 0xabcd and b1=0x0123 sets output to 0xac02 "}", "=r,r,r", false); Value *packed_in0 = UndefValue::get(vec_ty(f16_ty, 2)); Value *packed_in1 = UndefValue::get(vec_ty(f16_ty, 2)); packed_in0 = insert_elt(packed_in0, in0, (int)0); packed_in0 = insert_elt(packed_in0, in1, (int)1); packed_in1 = insert_elt(packed_in1, in2, (int)0); packed_in1 = insert_elt(packed_in1, in3, (int)1); Value *in_arg0 = bit_cast(packed_in0, i32_ty); Value *in_arg1 = bit_cast(packed_in1, i32_ty); Value *ret = call(ptx, {in_arg0, in_arg1}); Value *ret0 = extract_elt(ret, (int)0); Value *ret1 = extract_elt(ret, (int)1); Value *ret2 = extract_elt(ret, (int)2); Value *ret3 = extract_elt(ret, (int)3); return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::fp8x4_to_bf16x4(Value *in0, Value *in1, Value *in2, Value *in3) { // current exp offset: 15 // Add 112 (127-15) to compensate the difference in exponent bias // bf16 = (nosign >> (8-4) + 112 << 7) | sign; // bf16 = (nosign >> 4 + 0x3800) | sign; Type *ret_ty = StructType::get(*ctx_, {vec_ty(bf16_ty, 2), vec_ty(bf16_ty, 2)}); InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty}, false), "{" ".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n\t" "prmt.b32 a0, 0, $2, 0x5040; \n\t" // 0xdcba => 0xb0a0 "prmt.b32 a1, 0, $2, 0x7060; \n\t" // 0xdcba => 0xd0c0 "and.b32 sign0, a0, 0x80008000; \n\t" "and.b32 sign1, a1, 0x80008000; \n\t" "and.b32 nosign0, a0, 0x7fff7fff; \n\t" "and.b32 nosign1, a1, 0x7fff7fff; \n\t" "shr.b32 nosign0, nosign0, 4; \n\t" "shr.b32 nosign1, nosign1, 4; \n\t" "add.u32 nosign0, nosign0, 0x38003800; \n\t" "add.u32 nosign1, nosign1, 0x38003800; \n\t" "or.b32 $0, sign0, nosign0; \n\t" "or.b32 $1, sign1, nosign1; \n\t" "}", "=r,=r,r", false); Value *packed_in = UndefValue::get(vec_ty(i8_ty, 4)); packed_in = insert_elt(packed_in, in0, (uint64_t)0); packed_in = insert_elt(packed_in, in1, (uint64_t)1); packed_in = insert_elt(packed_in, in2, (uint64_t)2); packed_in = insert_elt(packed_in, in3, (uint64_t)3); Value *in = bit_cast(packed_in, i32_ty); Value *ret = call(ptx, {in}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::bf16x4_to_fp8x4(Value *in0, Value *in1, Value *in2, Value *in3) { /* Assuming fp8 exponent offset is 16. bf16 exponent offset is 127. 
Max value in fp8: 0b01111111 (0x7f), bf16: 3ff0 Min value in fp8: 0b00000000 (0x00) bf16: 0x3c00 // @note: +0x8 is for "rounding to nearest zero" fp8 = (nosign(bf16) - (112 << 7) + 0x8) << 4; return fp8 | sign; // also permute bytes */ InlineAsm *ptx = InlineAsm::get(FunctionType::get({vec_ty(i8_ty, 4)}, {i32_ty, i32_ty}, false), "{\n\t" ".reg .u32 sign, sign<2>, nosign, nosign<2>; \n\t" ".reg .u32 fp8_min, fp8_max, rn_, zero; \n\t" "mov.u32 fp8_min, 0x38003800; \n\t" "mov.u32 fp8_max, 0x3ff03ff0; \n\t" "mov.u32 rn_, 0x80008; \n\t" "mov.u32 zero, 0; \n\t" "and.b32 sign0, $1, 0x80008000; \n\t" "and.b32 sign1, $2, 0x80008000; \n\t" "prmt.b32 sign, sign0, sign1, 0x7531; \n\t" "and.b32 nosign0, $1, 0x7fff7fff; \n\t" "and.b32 nosign1, $2, 0x7fff7fff; \n\t" ".reg .u32 nosign_0_<2>, nosign_1_<2>; \n\t" // nosign = clamp(nosign, min, max) "and.b32 nosign_0_0, nosign0, 0xffff0000; \n\t" "max.u32 nosign_0_0, nosign_0_0, 0x38000000; \n\t" "min.u32 nosign_0_0, nosign_0_0, 0x3ff00000; \n\t" "and.b32 nosign_0_1, nosign0, 0x0000ffff; \n\t" "max.u32 nosign_0_1, nosign_0_1, 0x3800; \n\t" "min.u32 nosign_0_1, nosign_0_1, 0x3ff0; \n\t" "or.b32 nosign0, nosign_0_0, nosign_0_1; \n\t" "and.b32 nosign_1_0, nosign1, 0xffff0000; \n\t" "max.u32 nosign_1_0, nosign_1_0, 0x38000000; \n\t" "min.u32 nosign_1_0, nosign_1_0, 0x3ff00000; \n\t" "and.b32 nosign_1_1, nosign1, 0x0000ffff; \n\t" "max.u32 nosign_1_1, nosign_1_1, 0x3800; \n\t" "min.u32 nosign_1_1, nosign_1_1, 0x3ff0; \n\t" "or.b32 nosign1, nosign_1_0, nosign_1_1; \n\t" "add.u32 nosign0, nosign0, rn_; \n\t" // round to nearest zero "add.u32 nosign1, nosign1, rn_; \n\t" "sub.u32 nosign0, nosign0, 0x38003800; \n\t" // compensate offset "sub.u32 nosign1, nosign1, 0x38003800; \n\t" "shr.u32 nosign0, nosign0, 4; \n\t" "shr.u32 nosign1, nosign1, 4; \n\t" "prmt.b32 nosign, nosign0, nosign1, 0x6420; \n\t" "or.b32 $0, nosign, sign; \n\t" "" "}", "=r,r,r", false); Value *packed_in0 = UndefValue::get(vec_ty(bf16_ty, 2)); Value *packed_in1 = UndefValue::get(vec_ty(bf16_ty, 2)); packed_in0 = insert_elt(packed_in0, in0, (int)0); packed_in0 = insert_elt(packed_in0, in1, (int)1); packed_in1 = insert_elt(packed_in1, in2, (int)0); packed_in1 = insert_elt(packed_in1, in3, (int)1); Value *in_arg0 = bit_cast(packed_in0, i32_ty); Value *in_arg1 = bit_cast(packed_in1, i32_ty); Value *ret = call(ptx, {in_arg0, in_arg1}); Value *ret0 = extract_elt(ret, (int)0); Value *ret1 = extract_elt(ret, (int)1); Value *ret2 = extract_elt(ret, (int)2); Value *ret3 = extract_elt(ret, (int)3); return std::make_tuple(ret0, ret1, ret2, ret3); } Value* generator::bf16_to_fp32(Value *in0){ if (tgt_->as_nvidia()->sm() >= 80) { InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false), "cvt.rn.f32.bf16 $0, $1;", "=r,h", false); return call(ptx, {in0}); } else { Value *ret = UndefValue::get(vec_ty(i16_ty, 2)); ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1); ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0); return bit_cast(ret, f32_ty); } } Value* generator::fp32_to_bf16(Value *in0){ if(tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80){ InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false), "cvt.rn.bf16.f32 $0, $1;", "=h,r", false); return call(ptx, {in0}); } return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1); } /** * \brief Code Generation for `cast` */ void generator::visit_cast_inst(ir::cast_inst* x) { ir::value *op = x->get_operand(0); ir::type* ret_sca_ty = x->get_type()->get_scalar_ty(); ir::type* op_sca_ty = 
op->get_type()->get_scalar_ty(); auto x_idxs = idxs_.at(x); auto op_idxs = idxs_.at(op); // <> FP8 if(ret_sca_ty->is_fp8_ty() || op_sca_ty->is_fp8_ty()){ // ensure that conversions can be vectorized int ld = layouts_->get(x)->get_order(0); int contiguous = layouts_->get(x)->to_scanline()->nts(ld); if(contiguous % 4 != 0) throw std::runtime_error("unsupported fp32 -> fp8 conversion"); // run the conversion auto cvt = [&](Value* a, Value* b, Value* c, Value* d){ if(op_sca_ty->is_fp32_ty() && ret_sca_ty->is_fp8_ty()) return fp32x4_to_fp8x4(a, b, c, d); if(op_sca_ty->is_fp16_ty() && ret_sca_ty->is_fp8_ty()) return fp16x4_to_fp8x4(a, b, c, d); if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp16_ty()) return fp8x4_to_fp16x4(a, b, c, d); if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_fp32_ty()) return fp8x4_to_fp32x4(a, b, c, d); // fp8 <> bf16 if(op_sca_ty->is_fp8_ty() && ret_sca_ty->is_bf16_ty()) return fp8x4_to_bf16x4(a, b, c, d); if (op_sca_ty->is_bf16_ty() && ret_sca_ty->is_fp8_ty()) return bf16x4_to_fp8x4(a, b, c, d); throw std::runtime_error("unsupported conversion"); }; for(size_t i = 0; i < x_idxs.size(); i+=4){ std::tie(vals_[x][x_idxs[i+0]], vals_[x][x_idxs[i+1]], vals_[x][x_idxs[i+2]], vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[i+0]], vals_[op][op_idxs[i+1]], vals_[op][op_idxs[i+2]], vals_[op][op_idxs[i+3]]); } return; } // <> BF16 if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){ // FP32 -> BF16 if(op_sca_ty->is_fp32_ty()){ for (indices_t idx: idxs_.at(x)) { Value *arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty); } return; } // BF16 -> FP32 if(ret_sca_ty->is_fp32_ty()){ for(size_t i = 0; i < x_idxs.size(); i++) vals_[x][x_idxs[i + 0]] = bf16_to_fp32(vals_[op][op_idxs[i + 0]]); return; } } Type *ty = cvt(x->get_type()->get_scalar_ty()); auto cvt = [](ir::cast_op_t op){ using ll = llvm::Instruction::CastOps; using tt = ir::cast_op_t; switch(op){ case tt::Trunc: return ll::Trunc; case tt::ZExt: return ll::ZExt; case tt::SExt: return ll::SExt; case tt::FPTrunc: return ll::FPTrunc; case tt::FPExt: return ll::FPExt; case tt::UIToFP: return ll::UIToFP; case tt::SIToFP: return ll::SIToFP; case tt::FPToUI: return ll::FPToUI; case tt::FPToSI: return ll::FPToSI; case tt::PtrToInt: return ll::PtrToInt; case tt::IntToPtr: return ll::IntToPtr; case tt::BitCast: return ll::BitCast; case tt::AddrSpaceCast: return ll::AddrSpaceCast; default: throw std::runtime_error("unreachable switch"); } }; for(indices_t idx: idxs_.at(x)){ Value *arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = cast(cvt(x->get_op()), arg, ty); } } std::tuple generator::int16_to_float16x8( Value *in0, Value *scale_x512, Value *shift ){ /* unpacking 8 int2s packed into an int16 to 8 float16s * the algorithm is similar to * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1492-L1563 */ Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), "{" ".reg .b32 a<2>, b<4>; \n\t" // input is 0xab,cd,ef,gh,ab,cd,ef,gh, each a, b etc occupies two bits. 
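          // $4 below is the packed int16 input replicated into both halves of a
          // 32-bit register; $5 and $6 are scale*512 and shift, each packed as an
          // f16x2 pair by prepare_scale_shift().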
"and.b32 a0, 0x30300303, $4; \n\t" // set a0 to 0x0b,00,0f,00,00,0d,00,0h "and.b32 a1, 0xc0c00c0c, $4; \n\t" // set a1 to 0xa0,00,e0,00,00,c0,00,g0 "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x00,00,00,0d,00,00,00,0h "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00,00,00,c0,00,00,00,g0 "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x00,00,0b,00,00,00,0f,00 "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00,00,a0,00,00,00,e0,00 "mov.b32 a0, 0x78007800; \n\t" // a0 = 32768 "mov.b32 a1, 0x70007000; \n\t" // a1 = 8192 "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 8192. "mov.b32 a0, 0x68006800; \n\t" // a0 = 2048 "mov.b32 a1, 0x60006000; \n\t" // a1 = 512 "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 2048. "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 512. "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out2 = b2 * scale + shift. "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out3 = b3 * scale + shift. "}", "=r,=r,=r,=r,r,r,r", false); Value *packed_in = UndefValue::get(vec_ty(i16_ty, 2)); packed_in = insert_elt(packed_in, in0, (int)0); packed_in = insert_elt(packed_in, in0, (int)1); Value *in = bit_cast(packed_in, i32_ty); Value *ret = call(ptx, {in, scale_x512, shift}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); Value *packed_ret2 = extract_val(ret, {2}); Value *packed_ret3 = extract_val(ret, {3}); Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g Value *ret2 = extract_elt(packed_ret2, (uint64_t)0); // f Value *ret3 = extract_elt(packed_ret3, (uint64_t)0); // e Value *ret4 = extract_elt(packed_ret0, (uint64_t)1); // d Value *ret5 = extract_elt(packed_ret1, (uint64_t)1); // c Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); } std::tuple generator::int32_to_float16x8( Value *in0, Value *scale_x512, Value *shift ){ /* unpacking 8 int4s packed into an int32 to 8 float16s * the algorithm is similar to * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619 */ Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), "{" ".reg .b32 a<2>, b<4>; \n\t" "and.b32 a0, 0x0f0f0f0f, $4; \n\t" // If input is 0xabcdefgh set a to 0x0b0d0f0h "and.b32 a1, 0xf0f0f0f0, $4; \n\t" // If input is 0xabcdefgh set a to 0xa0c0e0g0 "prmt.b32 b0, 0, a0, 0x0504; \n\t" // set b0 to 0x000f000h "prmt.b32 b1, 0, a1, 0x0504; \n\t" // set b1 to 0x00e000g0 "prmt.b32 b2, 0, a0, 0x0706; \n\t" // set b2 to 0x000b000d "prmt.b32 b3, 0, a1, 0x0706; \n\t" // set b3 to 0x00a000c0 "mov.b32 a0, 0x78007800; \n\t" "mov.b32 a1, 0x68006800; \n\t" "mul.f16x2 b0, b0, a0; \n\t" // b0 = b0 * 32768. "mul.f16x2 b1, b1, a1; \n\t" // b1 = b1 * 2048. "mul.f16x2 b2, b2, a0; \n\t" // b2 = b2 * 32768. "mul.f16x2 b3, b3, a1; \n\t" // b3 = b3 * 2048. "fma.rn.f16x2 $0, b0, $5, $6; \n\t" // out0 = b0 * scale + shift. "fma.rn.f16x2 $1, b1, $5, $6; \n\t" // out1 = b1 * scale + shift. "fma.rn.f16x2 $2, b2, $5, $6; \n\t" // out0 = b0 * scale + shift. "fma.rn.f16x2 $3, b3, $5, $6; \n\t" // out1 = b1 * scale + shift. 
"}", "=r,=r,=r,=r,r,r,r", false); Value *ret = call(ptx, {in0, scale_x512, shift}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); Value *packed_ret2 = extract_val(ret, {2}); Value *packed_ret3 = extract_val(ret, {3}); Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // h Value *ret1 = extract_elt(packed_ret1, (uint64_t)0); // g Value *ret2 = extract_elt(packed_ret0, (uint64_t)1); // f Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // e Value *ret4 = extract_elt(packed_ret2, (uint64_t)0); // d Value *ret5 = extract_elt(packed_ret3, (uint64_t)0); // c Value *ret6 = extract_elt(packed_ret2, (uint64_t)1); // b Value *ret7 = extract_elt(packed_ret3, (uint64_t)1); // a return std::make_tuple(ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7); } std::tuple generator::int32_to_float16x4(Value *in0, Value *scale_x512, Value *shift){ /* unpacking 4 int8s packed into an int32 to 4 fp16s * the algorithm is similar to * https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/ fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1622-L1646 */ Type *ret_ty = StructType::get(*ctx_, {vec_ty(f16_ty, 2), vec_ty(f16_ty, 2)}); InlineAsm *ptx = InlineAsm::get(FunctionType::get(ret_ty, {i32_ty, i32_ty, i32_ty}, false), "{" ".reg .b32 a, b<2>; \n\t" "prmt.b32 b0, 0, $2, 0x0504; \n\t" // If input is 0xabcdefgh set b0 to 0x00ef00gh "prmt.b32 b1, 0, $2, 0x0706; \n\t" // If input is 0xabcdefgh set b1 to 0x00ab00cd "mov.b32 a, 0x78007800; \n\t" "mul.f16x2 b0, b0, a; \n\t" // b0 = b0 * 32768. "mul.f16x2 b1, b1, a; \n\t" // b1 = b1 * 32768. "fma.rn.f16x2 $0, b0, $3, $4; \n\t" // out0 = b0 * scale + shift. "fma.rn.f16x2 $1, b1, $3, $4; \n\t" // out1 = b1 * scale + shift. "}", "=r,=r,r,r,r", false); Value *ret = call(ptx, {in0, scale_x512, shift}); Value *packed_ret0 = extract_val(ret, {0}); Value *packed_ret1 = extract_val(ret, {1}); Value *ret0 = extract_elt(packed_ret0, (uint64_t)0); // gh Value *ret1 = extract_elt(packed_ret0, (uint64_t)1); // ef Value *ret2 = extract_elt(packed_ret1, (uint64_t)0); // cd Value *ret3 = extract_elt(packed_ret1, (uint64_t)1); // ab return std::make_tuple(ret0, ret1, ret2, ret3); } std::tuple generator::prepare_scale_shift(Value *scale, Value *shift){ Value *scale_x512 = fmul(scale, bit_cast(i16(0x6000), f16_ty)); Value *p_scale_x512 = UndefValue::get(vec_ty(f16_ty, 2)); p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)0); p_scale_x512 = insert_elt(p_scale_x512, scale_x512, (int)1); p_scale_x512 = bit_cast(p_scale_x512, i32_ty); Value *p_shift = UndefValue::get(vec_ty(f16_ty, 2)); p_shift = insert_elt(p_shift, shift, (int)0); p_shift = insert_elt(p_shift, shift, (int)1); p_shift = bit_cast(p_shift, i32_ty); return std::make_tuple(p_scale_x512, p_shift); } /** * \brief Code Generation for `dequantize` */ void generator::visit_dequantize_inst(ir::dequantize_inst* x) { ir::value *op = x->get_operand(0); auto src_ty_size_in_bits = op->get_type()->get_scalar_ty()->get_primitive_size_in_bits(); auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); auto op_last_dim = (op->get_type()->get_block_shapes()).back(); auto x_idxs = idxs_.at(x); auto op_idxs = idxs_.at(op); ir::value *scale = x->get_operand(1); ir::value *shift = x->get_operand(2); Value *p_scale_x512, *p_shift; std::tie(p_scale_x512, p_shift) = prepare_scale_shift(vals_[scale][{}], vals_[shift][{}]); int ld = layouts_->get(x)->get_order(0); int contiguous = layouts_->get(x)->to_scanline()->nts(ld); int op_ld = layouts_->get(op)->get_order(0); int 
op_contiguous = layouts_->get(op)->to_scanline()->nts(op_ld); std::string err_msg; err_msg = "unsupported dequantization, cannot vectorize properly. x_idxs.size(): " + std::to_string(x_idxs.size()) + "; op_idxs.size(): " + std::to_string(op_idxs.size()) + "; contiguous: " + std::to_string(contiguous) + "; op_contiguous: " + std::to_string(op_contiguous) + ". if the condition " "is not met, please try adjusting block_size, num_warps or " "using tl.multiple_of to hint the input/output ptr address."; if (ret_last_dim == 8 * op_last_dim) { if((x_idxs.size() != 8 * op_idxs.size()) || (contiguous != 8 * op_contiguous)) { throw std::runtime_error(err_msg); } auto cvt = [&]( Value* a, Value* scale, Value* shift ){ if (src_ty_size_in_bits == 16){ // int2 quantization, int16 to 8 fp16s return int16_to_float16x8(a, scale, shift); } else if (src_ty_size_in_bits == 32) { // int4 quantization, int32 to 8 fp16s return int32_to_float16x8(a, scale, shift); } else { throw std::runtime_error("unsupported conversion"); } }; for(size_t j = 0; j < op_idxs.size(); j++){ size_t i = j * 8; std::tie(vals_[x][x_idxs[i+0]], vals_[x][x_idxs[i+1]], vals_[x][x_idxs[i+2]], vals_[x][x_idxs[i+3]], vals_[x][x_idxs[i+4]], vals_[x][x_idxs[i+5]], vals_[x][x_idxs[i+6]], vals_[x][x_idxs[i+7]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); } } else if (ret_last_dim == 4 * op_last_dim && src_ty_size_in_bits == 32) { // int8 quantization, int32 to 4 fp16s if((x_idxs.size() != 4 * op_idxs.size()) || (contiguous != 4 * op_contiguous)) { throw std::runtime_error(err_msg); } auto cvt = [&](Value* a, Value* scale, Value* shift){ return int32_to_float16x4(a, scale, shift); }; for(size_t j = 0; j < op_idxs.size(); j++){ size_t i = j * 4; std::tie(vals_[x][x_idxs[i+0]], vals_[x][x_idxs[i+1]], vals_[x][x_idxs[i+2]], vals_[x][x_idxs[i+3]]) = cvt(vals_[op][op_idxs[j]], p_scale_x512, p_shift); } } else { throw std::runtime_error("unsupported dequantization"); } return; } /** * \brief Code Generation for `return` */ void generator::visit_return_inst(ir::return_inst* rr) { ir::value *ret_val = rr->get_return_value(); ret(ret_val ? 
vals_[ret_val][{}] : nullptr); } /** * \brief Code Generation for `cond_branch` */ void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) { BasicBlock *true_dest = bbs_.at(br->get_true_dest()); BasicBlock *false_dest = bbs_.at(br->get_false_dest()); Value *cond = vals_[br->get_cond()][{}]; cond_br(cond, true_dest, false_dest); } /** * \brief Code Generation for `uncond_branch` */ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { BasicBlock *dest = bbs_.at(br->get_dest()); br(dest); } /** * \brief Code Generation for a (synchronous) `load` */ void generator::visit_load_inst(ir::load_inst* x){ BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); Value *lane = urem(tid, i32(32)); ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); #ifdef USE_ROCM // code generation auto idxs = idxs_.at(x); for(size_t i = 0; i CreateLoad(ty, ptr); // upload to global vals map vals_[x][idx] = _ret; } #else // compute vector width size_t vec = 1; bool is_mma_first_row = false; if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); if(mx){ size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } analysis::distributed_layout* layout = dynamic_cast(layouts_->get(x)); assert(layout); vec = std::min(layout->contig_per_thread(ord[0]), aln); // TODO: generalize is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1)); if(is_mma_first_row) vec = std::min(2, aln); } // code generation auto idxs = idxs_.at(x); for(size_t i = 0; i < idxs.size(); i += vec){ indices_t idx = idxs[i]; // pointer value Value *ptr = vals_[op][idx]; // masked load size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; // input ptr info GetElementPtrInst *in_gep = dyn_cast(ptr); size_t in_off; if(in_gep){ ConstantInt* cst = dyn_cast(in_gep->idx_begin()); in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; ptr = cst ? in_gep->getPointerOperand() : in_gep; } else{ in_off = 0; } Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); // if(!op->get_type()->is_block_ty()){ // pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0))); // } Value *other = mx ? 
vals_[mx->get_false_value_operand()][idx] : nullptr; size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words // each load has width min(nbits*vec, 32/64) // and there are (nbits * vec)/width of them int max_word_width = std::max(32, nbits); int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; has_l2_evict_policy = false; // has_evict_policy = false; // currently disable until supported in `store` // ----- // create inline asm string // ----- std::ostringstream asm_oss; asm_oss << "@$" << n_words; // predicate asm_oss << " ld"; if(x->get_is_volatile()) asm_oss << ".volatile"; asm_oss << ".global"; if (x->get_cache_modifier() == ir::load_inst::CA) asm_oss << ".ca"; if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg"; if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first"; if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last"; if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size asm_oss << " {"; for(int i = 0; i < n_words; i++){ // return values if(i > 0) asm_oss << ","; asm_oss << "$" << i; } asm_oss << "}"; asm_oss << ", [ $" << n_words + 1; // load asm_oss << " + " << in_off << "]"; // constant offset if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; asm_oss << ";"; bool has_other = other && (other != UndefValue::get(other->getType())); std::vector others; // handle `other` values for indices where the mask // is false if(has_other) for(size_t ii = 0; ii < n_words; ii++){ size_t size = width / nbits; Value *v = UndefValue::get(vec_ty(ty, size)); for(size_t s = 0; s < size; s++){ ir::value *false_val = mx->get_false_value_operand(); v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s); } v = bit_cast(v, IntegerType::get(*ctx_, width)); // PTX doesn't support mov.u8, so we need to use mov.u16 auto mov_width = width < 16 ? 16 : width; asm_oss << "\n "; asm_oss << "@!$" << n_words << " mov.u" << mov_width; asm_oss << " $" << ii << ", "; std::ios_base::fmtflags flags(asm_oss.flags()); if(ConstantInt* cst = dyn_cast(v)) asm_oss << "0x" << std::hex << cst->getSExtValue(); else{ asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii; others.push_back(v); } asm_oss.flags(flags); asm_oss << ";"; } // ---- // create inline ASM signature // --- std::vector ret_tys(n_words, IntegerType::get(*ctx_, width)); Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0]; // ret_ty->print(llvm::outs()); std::vector arg_tys = {pred->getType(), ptr->getType()}; for(Value *v: others) arg_tys.push_back(v->getType()); if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); // --- // create inline ASM constraints // --- std::string asm_cstrt; for(int ii = 0; ii < n_words; ii++){ if(ii > 0) asm_cstrt += ","; asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); } asm_cstrt += ",b,l"; for(size_t ii = 0; ii < others.size(); ii++){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? 
"r" : "c"); } if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM // --- InlineAsm *inlineAsm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); std::vector args = {pred, ptr}; for(Value *v: others) args.push_back(v); if (has_l2_evict_policy) args.push_back(policies_.at(x->get_eviction_policy())); Value *_ret = call(inlineAsm, args); // if(!op->get_type()->is_block_ty()){ // Value* cond = icmp_eq(tid, i32(0)); // Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3)); // Instruction* bar = add_barrier(); // Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false); // builder_->SetInsertPoint(term); // store(_ret, shptr); // builder_->SetInsertPoint(bar->getParent()); // _ret = load(shptr); // add_barrier(); // } // --- // extract and store return values // --- std::vector rets; for(unsigned int ii = 0; ii < n_words; ii++){ Value *curr; if(ret_ty->isStructTy()) curr = extract_val(_ret, {ii}); else curr = _ret; rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8)))); } int tmp = (width / (dtsize * 8)); for(size_t ii = 0; ii < vec; ii++) vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp); } #endif } void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { visit_load_inst(x); } void generator::visit_masked_load_inst(ir::masked_load_inst* x) { visit_load_inst(x); } /** * \brief Code Generation for a (synchronous) `store` */ void generator::visit_store_inst(ir::store_inst * x){ ir::masked_store_inst *mx = dynamic_cast(x); // operands ir::value *ptr_op = x->get_pointer_operand(); ir::value *val_op = x->get_value_operand(); #ifdef USE_ROCM auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); for (size_t i = 0; i < idxs.size(); i += 1) { auto idx = idxs[i]; // pointer Value *ptr = vals_[ptr_op][idx]; // value Value *val = vals_.at(val_op)[idxs[i]]; // store value at pointer store(val, ptr); } #else ir::value *msk_op = nullptr; if(auto* msk_st = dynamic_cast(x)) msk_op = msk_st->get_mask_operand(); // vector size size_t vec = 1; if(val_op->get_type()->is_block_ty()){ auto ord = ords_.at(x->get_pointer_operand()); size_t aln = alignment_->get(ptr_op, ord[0]); size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; if(mx){ size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; max_eq = std::max(max_eq, 1); aln = std::min(aln, max_eq); } analysis::distributed_layout* layout = dynamic_cast(layouts_->get(ptr_op)); assert(layout); // vec = std::min(nts, aln); vec = std::min(layout->contig_per_thread(ord[0]), aln); // TODO: generalize bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() && (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1)); if(is_mma_first_row) vec = std::min(2, aln); } bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; has_l2_evict_policy = false; auto idxs = idxs_.at(val_op); Type *ty = cvt(val_op->get_type()->get_scalar_ty()); if(ty->isIntegerTy(1)) ty = builder_->getInt8Ty(); for(size_t i = 0; i < idxs.size(); i += vec){ indices_t idx = idxs[i]; // pointers Value *ptr = vals_[ptr_op][idx]; size_t dtsize = std::max(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8); GetElementPtrInst *in_gep = dyn_cast(ptr); size_t in_off; if(in_gep){ ConstantInt* cst = dyn_cast(in_gep->idx_begin()); in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; ptr = cst ? 
in_gep->getPointerOperand() : in_gep; } else{ in_off = 0; } // mask Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue(); size_t nbits = dtsize*8; // pack sub-words (< 32/64bits) into words // each load has width min(nbits*vec, 32/64) // and there are (nbits * vec)/width of them int max_word_width = std::max(32, nbits); int tot_width = nbits*vec; int width = std::min(tot_width, max_word_width); int n_words = std::max(1, tot_width / width); // ----- // create inline asm string // ----- std::ostringstream asm_oss; asm_oss << "@$0"; // predicate asm_oss << " st.global"; if (has_l2_evict_policy) asm_oss << ".L2::cache_hint"; if(n_words > 1) asm_oss << ".v" << n_words; // vector width asm_oss << ".b" << width; // word size asm_oss << " [ $1 + " << in_off << "]"; asm_oss << " , {"; for(int i = 0; i < n_words; i++){ // return values if(i > 0) asm_oss << ","; asm_oss << "$" << 2 + i; } asm_oss << "}"; if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2; asm_oss << ";"; // ---- // create inline ASM signature // --- Type* val_arg_ty = IntegerType::get(*ctx_, width); std::vector arg_tys = {pred->getType(), ptr->getType()}; for(int ii = 0; ii < n_words; ii++) arg_tys.push_back(val_arg_ty); if (has_l2_evict_policy) arg_tys.push_back(i64_ty); FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false); // --- // create inline ASM constraints // --- std::string asm_cstrt = "b,l"; for(int ii = 0; ii < n_words; ii++){ asm_cstrt += ","; asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); } if (has_l2_evict_policy) asm_cstrt += ",l"; // --- // finally call inline ASM // --- InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); std::vector args = {pred, ptr}; for(unsigned int ii = 0; ii < n_words; ii++){ size_t n_subw = width / nbits; Value* curr = UndefValue::get(vec_ty(ty, n_subw)); for(unsigned int jj = 0; jj < n_subw; jj++){ Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]]; if(new_elt->getType()->isIntegerTy(1)) new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty()); new_elt = bit_cast(new_elt, ty); curr = builder_->CreateInsertElement(curr, new_elt, jj); } args.push_back(bit_cast(curr, val_arg_ty)); } if (has_l2_evict_policy) args.push_back(policies_.at(x->get_eviction_policy())); call(_asm, args); } #endif } void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { visit_store_inst(x); } void generator::visit_masked_store_inst(ir::masked_store_inst* x) { visit_store_inst(x); } // -- void generator::visit_extract_value_inst(ir::extract_value_inst *x) { auto idxs = idxs_.at(x); ir::value* agg = x->get_operand(0); unsigned insert_idx = x->get_idx(); for(size_t i = 0; i < idxs.size(); i++){ auto idx = idxs[i]; vals_[x][idx] = builder_->CreateExtractValue(vals_[agg][idx], {insert_idx}); } } void generator::visit_insert_value_inst(ir::insert_value_inst *x){ auto idxs = idxs_.at(x); ir::value* agg = x->get_operand(0); ir::value* val = x->get_operand(1); unsigned insert_idx = x->get_idx(); for(size_t i = 0; i < idxs.size(); i++){ auto idx = idxs[i]; vals_[x][idx] = builder_->CreateInsertValue(vals_[agg][idx], vals_[val][idx],{insert_idx}); } } // -- /** * \brief Code Generation for `cat` */ void generator::visit_cat_inst(ir::cat_inst* x) { auto idxs = idxs_.at(x); ir::value* lhs = x->get_operand(0); ir::value* rhs = x->get_operand(1); int i = 0; for(size_t j = 0; j < idxs_.at(lhs).size(); j ++){ vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; } for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ 
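    // append the rhs values after all lhs values; `i` keeps counting from where the
    // lhs loop stopped, so the output is the plain concatenation [lhs, rhs]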
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; } } /** * \brief Code Generation for `reshape` */ void generator::visit_reshape_inst(ir::reshape_inst* x) { auto idxs = idxs_.at(x); for(size_t i = 0; i < idxs_.at(x).size(); i ++){ ir::value* op = x->get_operand(0); vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; }; } /** * \brief Code Generation for `splat` */ void generator::visit_splat_inst(ir::splat_inst* x) { for(auto idx: idxs_.at(x)) vals_[x][idx] = vals_[x->get_operand(0)][{}]; } /** * \brief Code Generation for `broadcast` */ void generator::visit_broadcast_inst(ir::broadcast_inst* x) { ir::value* op = x->get_operand(0); const auto& shape = op->get_type()->get_block_shapes(); for(auto out_idx: idxs_.at(x)){ indices_t in_idx = out_idx; for(size_t k = 0; k < in_idx.size(); k++) in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k]; vals_[x][out_idx] = vals_[op][in_idx]; } // for(size_t i = 0; i < idxs_.at(x).size(); i++) // vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; } /** * \brief Code Generation for `downcast` */ void generator::visit_downcast_inst(ir::downcast_inst* x) { vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}]; } /** * \brief Code Generation for `get_program_id` */ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) { Module *module = builder_->GetInsertBlock()->getModule(); Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis()); vals_[pid][{}] = ret; } /** * \brief Code Generation for `get_num_programs` */ void generator::visit_get_num_programs_inst(ir::get_num_programs_inst* np) { Module *module = builder_->GetInsertBlock()->getModule(); Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis()); vals_[np][{}] = ret; } /** * \brief Code Generation for `exp` */ void generator::visit_exp_inst(ir::exp_inst* x){ Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); #ifdef USE_ROCM llvm::Function *ex2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::exp2, tys); #else InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false); #endif for(auto idx: idxs_.at(x)){ Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); // Value *ex2arg = vals_[x->get_operand(0)][idx]; vals_[x][idx] = call(ex2, std::vector{ex2arg}); } } /** * \brief Code Generation for `cos` */ void generator::visit_cos_inst(ir::cos_inst* x){ std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); #ifdef USE_ROCM llvm::Function *cos = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::cos, tys); #else InlineAsm *cos = InlineAsm::get(fn_ty, "cos.approx.f32 $0, $0;", "=f,0", false); #endif for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(cos, std::vector{vals_[x->get_operand(0)][idx]}); } } /** * \brief Code Generation for `umulhi` */ void generator::visit_umulhi_inst(ir::umulhi_inst* x){ std::vector tys = {i32_ty, i32_ty}; FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false); InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false); for(auto idx: idxs_.at(x)){ Value* lhs = vals_[x->get_operand(0)][idx]; Value* rhs = vals_[x->get_operand(1)][idx]; vals_[x][idx] = call(umulhi, std::vector{lhs, rhs}); } } /** * \brief Code Generation for `sin` */ void generator::visit_sin_inst(ir::sin_inst* x){ std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); #ifdef USE_ROCM llvm::Function *sin = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::sin, tys); #else 
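  // NVPTX path: use the hardware approximate-sine instruction via inline asm;
  // the ROCm path above goes through the generic llvm.sin intrinsic instead.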
InlineAsm *sin = InlineAsm::get(fn_ty, "sin.approx.f32 $0, $0;", "=f,0", false); #endif for(auto idx: idxs_.at(x)){ vals_[x][idx] = call(sin, std::vector{vals_[x->get_operand(0)][idx]}); } } /** * \brief Code Generation for `log` */ void generator::visit_log_inst(ir::log_inst* x){ Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453); std::vector tys = {f32_ty}; FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); #ifdef USE_ROCM llvm::Function *lg2 = llvm::Intrinsic::getDeclaration(mod_, Intrinsic::log2, tys); #else InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); #endif for(auto idx: idxs_.at(x)){ Value *lg2arg = call(lg2, std::vector{vals_[x->get_operand(0)][idx]}); vals_[x][idx] = fmul(lg2arg, rcplog2e); } } /** * \brief Code Generation for `atomic_cas` */ #if defined(USE_ROCM) void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); Value *pred = icmp_eq(tid, i32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); add_barrier(); tgt_->add_memfence(module, *builder_); cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); Value *cas_ptr = vals_[cas->get_operand(0)][{}]; Value *cas_cmp = vals_[cas->get_operand(1)][{}]; Value *cas_val = vals_[cas->get_operand(2)][{}]; Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, MaybeAlign(), AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); old = extract_val(old, std::vector{0}); Value *atom_ptr; atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), ""); atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); store(old, atom_ptr); br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); add_barrier(); vals_[cas][{}] = load(atom_ptr); add_barrier(); } #else void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); Value *pred = icmp_eq(tid, i32(0)); // BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); // BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); add_barrier(); tgt_->add_memfence(module, *builder_); Value *atom_ptr; atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), ""); atom_ptr = bit_cast(atom_ptr, ptr_ty(cvt(cas->get_type()->get_scalar_ty()), 3)); // cond_br(pred, tid_0_bb, tid_0_done_bb); // builder_->SetInsertPoint(tid_0_bb); Value *cas_ptr = vals_[cas->get_operand(0)][{}]; Value *cas_cmp = vals_[cas->get_operand(1)][{}]; Value *cas_val = vals_[cas->get_operand(2)][{}]; std::string asm_str = "@$1 atom.global.cas.b32 $0, [$2], $3, $4;"; FunctionType *fn_ty = FunctionType::get(i32_ty, {pred->getType(), cas_ptr->getType(), cas_cmp->getType(), cas_val->getType()}, false); InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, "=r,b,l,r,r", true); add_barrier(); Value *old = call(iasm, {pred, cas_ptr, cas_cmp, cas_val}); add_barrier(); std::string asm2_str = "@$0 st.shared.b32 [$1], $2;"; FunctionType *fn2_ty = FunctionType::get(void_ty, {pred->getType(), atom_ptr->getType(), old->getType()}, false); InlineAsm *iasm2 = InlineAsm::get(fn2_ty, asm2_str, "b,r,r", true); 
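  // Only the thread selected by `pred` (tid 0) performs the global CAS; the old value it
  // obtains is then written to shared memory through `iasm2` so that, after the fence and
  // barriers below, every thread in the block can load the same result.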
add_barrier(); call(iasm2, {pred, atom_ptr, old}); tgt_->add_memfence(module, *builder_); add_barrier(); vals_[cas][{}] = load(atom_ptr); add_barrier(); } #endif // defined(USE_ROCM) /** * \brief Code Generation for `atomic_rmw` */ #if defined(USE_ROCM) void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { if (atom->get_op() == ir::atomic_rmw_op_t::Add || atom->get_op() == ir::atomic_rmw_op_t::Max || atom->get_op() == ir::atomic_rmw_op_t::Min || atom->get_op() == ir::atomic_rmw_op_t::UMax || atom->get_op() == ir::atomic_rmw_op_t::UMin || atom->get_op() == ir::atomic_rmw_op_t::Xchg) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *rmw_ptr = vals_[atom->get_operand(0)][{}]; Value *rmw_val = vals_[atom->get_operand(1)][{}]; Value *tid = tgt_->get_local_id(module, *builder_, 0); Value *pred = icmp_eq(tid, i32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); tgt_->add_memfence(module, *builder_); add_barrier(); cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); AtomicRMWInst::BinOp binop; switch (atom->get_op()) { case ir::atomic_rmw_op_t::Add: binop = AtomicRMWInst::Add; break; case ir::atomic_rmw_op_t::Max: binop = AtomicRMWInst::Max; break; case ir::atomic_rmw_op_t::Min: binop = AtomicRMWInst::Min; break; case ir::atomic_rmw_op_t::UMax: binop = AtomicRMWInst::UMax; break; case ir::atomic_rmw_op_t::UMin: binop = AtomicRMWInst::UMin; break; case ir::atomic_rmw_op_t::Xchg: binop = AtomicRMWInst::Xchg; break; default: throw std::runtime_error("Not supported!"); } atomic_rmw(binop, rmw_ptr, rmw_val, MaybeAlign(), AtomicOrdering::Monotonic, SyncScope::System); br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); return; } throw std::runtime_error("Not supported!"); } #else // defined(USE_ROCM) void generator::visit_atomic_rmw_inst(ir::atomic_rmw_inst *atom) { ir::value *ptr = atom->get_operand(0); ir::value* val = atom->get_operand(1); ir::value* msk = atom->get_operand(2); // vector size int vec = 1; Value *mask = builder_->getInt1(true); if(atom->get_type()->is_block_ty()){ auto shape = atom->get_type()->get_block_shapes(); int ld = ords_.at(ptr)[0]; unsigned alignment = alignment_->get(ptr, ld); vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment); vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_fp16_ty() ? 
2 : 1); // mask out inactive threads analysis::data_layout* layout = layouts_->get(val); auto curr_axes = a_axes_->get(val); auto layt_axes = layout->get_axes(); for(unsigned k = 0; k < layt_axes.size(); k++){ unsigned ax = layt_axes.at(k); distributed_axis dax = axes_.at(ax); // axis is part of the original layout: thread id should be 0 // but not the current layout if(std::find(curr_axes.begin(), curr_axes.end(), ax) == curr_axes.end()) mask = and_(mask, icmp_eq(dax.thread_id, i32(0))); } // last axis may spillover Value *thread_id = tgt_->get_local_id(mod_, *builder_, 0); int per_thread = 1; for(int ax: layt_axes) { per_thread *= axes_.at(ax).contiguous; } int numel = 1; for(int s: layout->get_shape()) { numel *= s; } mask = and_(mask, icmp_ult(mul(thread_id, i32(per_thread)), i32(numel))); } for(int i = 0; i < idxs_.at(val).size(); i += vec){ auto idx = idxs_[val][i]; Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec)); for(int ii = 0; ii < vec; ii++) rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii); Value *rmw_ptr = vals_[ptr][idx]; Value *rmw_msk = vals_[msk][idx]; rmw_msk = and_(rmw_msk, mask); if(vec == 1) rmw_val = extract_elt(rmw_val, i32(0)); Type* ty = rmw_val->getType(); size_t nbits = ty->getScalarSizeInBits(); // extract pointer offset std::string offset = ""; if(GetElementPtrInst *gep = dyn_cast(rmw_ptr)) if(gep->getNumIndices() == 1) if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8); rmw_ptr = gep->getPointerOperand(); } rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1)); // asm argument type std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; // asm function type FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); // asm string std::string s_nbits = std::to_string(nbits); std::string name; std::string s_ty; using tt = ir::atomic_rmw_op_t; switch(atom->get_op()){ case tt::Or: name = "or"; s_ty = "b"; break; case tt::And: name = "and"; s_ty = "b"; break; case tt::Xor: name = "xor", s_ty = "b"; break; case tt::Add: name = "add" , s_ty = "s"; break; case tt::Min: name = "min", s_ty = "s"; break; case tt::Max: name = "max", s_ty = "s"; break; case tt::UMin: name = "min", s_ty = "u"; break; case tt::UMax: name = "max", s_ty = "u"; break; case tt::FAdd: name = "add", s_ty = "f"; break; case tt::Xchg: name = "exch", s_ty = "b"; break; } std::string s_vec = vec == 2 ? "x2" : ""; std::string mod = nbits == 16 ? ".noftz" : ""; std::string asm_str = "@$1 atom.global.gpu." + name + mod + "." + s_ty + s_nbits + s_vec + " $0, [$2" + offset + "], $3;"; std::string ty_id = nbits*vec == 64 ? "l" : (nbits*vec == 32 ? 
"r" : "h"); std::string constraint = "=" + ty_id + ",b,l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); // call asm if(atom->get_type()->is_block_ty()) vals_[atom][idx] = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); else{ Module *mod = builder_->GetInsertBlock()->getModule(); tgt_->add_memfence(mod, *builder_); add_barrier(); Value *tid = tgt_->get_local_id(mod, *builder_, 0); rmw_msk = builder_->CreateAnd(rmw_msk, icmp_eq(tid, i32(0))); Value *old = call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); Value *atom_ptr; atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(atom)))), ""); atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); store(old, atom_ptr); add_barrier(); vals_[atom][idx] = load(atom_ptr); add_barrier(); } } } #endif // defined(USE_ROCM) /** * \brief Code Generation for `mma.884` (V100) */ //TODO: clean-up void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { // shapes auto shape_c = C->get_type()->get_block_shapes(); auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); // order auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); bool is_a_trans = C->is_trans_a(); // is_a_trans = false; if(C->is_trans_a()){ std::swap(ord_a[0], ord_a[1]); std::swap(shape_a[0], shape_a[1]); std::swap(offset_a_m_, offset_a_k_); } // std::cout << "visiting" << std::endl; // if(C->is_trans_b()){ // std::swap(ord_b[0], ord_b[1]); // std::swap(shape_b[0], shape_b[1]); // } // layouts analysis::mma_layout* layout_c = layouts_->get(C)->to_mma(); analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); analysis::shared_layout* layout_b = layouts_->get(B)->to_shared(); // vectorization int vec_a = swizzle_->get_vec(layout_a); int vec_b = swizzle_->get_vec(layout_b); // strides bool is_a_row = ord_a[0] != 0; bool is_b_row = ord_b[0] != 0; int stride_am = is_a_row ? shape_a[1] : 1; int stride_ak = is_a_row ? 1 : shape_a[0]; int stride_a0 = is_a_row ? stride_ak : stride_am; int stride_a1 = is_a_row ? stride_am : stride_ak; int stride_bn = is_b_row ? 1 : shape_b[0]; int stride_bk = is_b_row ? shape_b[1] : 1; int stride_b0 = is_b_row ? stride_bn : stride_bk; int stride_b1 = is_b_row ? stride_bk : stride_bn; int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8; int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8; int stride_rep_k = 1; // swizzling int per_phase_a = swizzle_->get_per_phase(layout_a); int max_phase_a = swizzle_->get_max_phase(layout_a); int step_a0 = is_a_row ? stride_rep_k : stride_rep_m; int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1); int per_phase_b = swizzle_->get_per_phase(layout_b); int max_phase_b = swizzle_->get_max_phase(layout_b); int step_b0 = is_b_row ? 
stride_rep_n : stride_rep_k; int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); // max_phase_a = 4; // vec_a = 8; // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl; // std::cout << vec_a << " " << vec_b << std::endl; /* --------------------------------- */ /* --- pre-compute pointer lanes --- */ /* --------------------------------- */ BasicBlock* curr_bb = builder_->GetInsertBlock(); BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); if(entry != curr_bb) builder_->SetInsertPoint(entry->getTerminator()); Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a)); std::vector off_a(num_ptr_a); for(int i = 0; i < num_ptr_a; i++){ Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m))); off_a0i = exact_udiv(off_a0i, i32(vec_a)); off_a0i = xor_(off_a0i, phase_a); off_a0i = mul(off_a0i, i32(vec_a)); off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); } Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c]; Value* off_b1 = is_b_row ? offset_b_k_[layout_c] : offset_b_n_[layout_c]; Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b)); std::vector off_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++){ Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4))); off_b0i = udiv(off_b0i, i32(vec_b)); off_b0i = xor_(off_b0i, phase_b); off_b0i = mul(off_b0i, i32(vec_b)); off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); } builder_->SetInsertPoint(curr_bb); /* --------------------------------- */ /* --- MMA intrinsic --- */ /* --------------------------------- */ Type *f16x2_ty = vec_ty(f16_ty, 2); Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}); std::vector arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}; InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false), " mma.sync.aligned.m8n8k4." + std::string(is_a_row ? "row" : "col") + "." + std::string(is_b_row ? 
"row" : "col") + ".f32.f16.f16.f32 " "{$0, $1, $2, $3, $4, $5, $6, $7}, " "{$8, $9}, " "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); std::vector ptr_a(num_ptr_a); std::vector ptr_b(num_ptr_b); std::map, std::pair> has, hbs; for(int i = 0; i < num_ptr_a; i++) ptr_a[i] = gep(shmems_[A], off_a[i]); for(int i = 0; i < num_ptr_b; i++) ptr_b[i] = gep(shmems_[B], off_b[i]); // initialize accumulators std::vector acc; for(indices_t idx: idxs_.at(C)) acc.push_back(vals_[D][idx]); unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->shape_per_cta(0); unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->shape_per_cta(1); // create mma & unpack result auto call_mma = [&](unsigned m, unsigned n, unsigned K) { auto ha = has[{m, K}]; auto hb = hbs[{n, K}]; // arguments std::vector idx = { (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m }; std::vector args = {ha.first, ha.second, hb.first, hb.second}; for(unsigned i = 0; i < 8; i++) args.push_back(acc[idx[i]]); // execute mma Value *nc = call(mma, args); // unpack for(unsigned i = 0; i < 8; i++) acc[idx[i]] = extract_val(nc, {i}); }; ir::phi_node* phiA = dynamic_cast(A); ir::phi_node* phiB = dynamic_cast(B); // Cache lds value. If values are prefetched, create phi node // @param inc: incoming block (0 = header, 1 = loop) auto register_lds = [&](decltype(has)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) { if (K == 0 && is_prefetch) { ir::basic_block* inc_block = phiA->get_incoming_block(inc); lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block)); lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block)); } else vals[{m, K}] = {val0, val1}; }; auto load_a = [&](int m, int K, int inc, bool is_prefetch) { int offidx = (is_a_row ? K/4 : m) % num_ptr_a; Value* ptra; if(K==0 && is_prefetch){ if(inc == 0) ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]); else ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]); } else ptra = ptr_a[offidx]; int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); // record lds that needs to be moved if (K == 0 && inc == 1 && is_prefetch) prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(ha); Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); register_lds(has, m, K, inc, ha00, ha01, is_prefetch); if(vec_a > 4){ Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); if(is_a_row) register_lds(has, m, K+4, inc, ha10, ha11, is_prefetch); else register_lds(has, m+1, K, inc, ha10, ha11, is_prefetch); } }; auto load_b = [&](int n, int K, int inc, bool is_prefetch) { int offidx = (is_b_row? n : K/4) % num_ptr_b; Value* ptrb; if(K==0 && is_prefetch){ if(inc == 0) ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]); else ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]); } else ptrb = ptr_b[offidx]; int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; int stepbk = is_b_row ? 
K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); // record lds that needs to be moved if (K == 0 && inc == 1 && is_prefetch) prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hb); Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); register_lds(hbs, n, K, inc, hb00, hb01, is_prefetch); if(vec_b > 4){ Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); if(is_b_row) register_lds(hbs, n+1, K, inc, hb10, hb11, is_prefetch); else register_lds(hbs, n, K+4, inc, hb10, hb11, is_prefetch); } }; // update accumulators if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(curr_bb->getFirstNonPHI()); for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) { has[{m, 0}].first = phi(f16x2_ty, 2); has[{m, 0}].second = phi(f16x2_ty, 2); if (!is_a_row && vec_a>4) { has[{m+1, 0}].first = phi(f16x2_ty, 2); has[{m+1, 0}].second = phi(f16x2_ty, 2); } } for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) { hbs[{n, 0}].first = phi(f16x2_ty, 2); hbs[{n, 0}].second = phi(f16x2_ty, 2); if (is_b_row && vec_b>4) { hbs[{n+1, 0}].first = phi(f16x2_ty, 2); hbs[{n+1, 0}].second = phi(f16x2_ty, 2); } } // insert prefetched lds at the end of loop header builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); for (unsigned m = 0; m < num_m/2; m += is_a_row?1:2) load_a(m, 0, 0, true); for (unsigned n = 0; n < num_n/2; n += is_b_row?2:1) load_b(n, 0, 0, true); // update accumulators builder_->SetInsertPoint(curr_bb); for (unsigned K = 0; K < NK; K += 4) { int NEXTK = (K + 4) % NK; // prefetch A for (unsigned m = 0; m < num_m/2; m+=is_a_row?1:2) load_a(m, NEXTK, 1, true); // prefetch B for (unsigned n = 0; n < num_n/2; n+=is_b_row?2:1) load_b(n, NEXTK, 1, true); // tensor core ops for(unsigned m = 0; m < num_m/2; m++) for(unsigned n = 0; n < num_n/2; n++){ call_mma(m, n, K); } } } else { // not prefetched for(unsigned K = 0; K < NK; K += 4) for(unsigned m = 0; m < num_m/2; m++) for(unsigned n = 0; n < num_n/2; n++) { if(has.find({m, K}) == has.end()) load_a(m, K, /*inc*/0, /*is_prefetch*/false); if(hbs.find({n, K}) == hbs.end()) load_b(n, K, /*inc*/0, /*is_prefetch*/false); call_mma(m, n, K); } } // write back accumulators for(size_t i = 0; i < idxs_.at(C).size(); i++) vals_[C][idxs_[C][i]] = acc[i]; } namespace { class mma16816_smem_loader { public: mma16816_smem_loader(int wpt, std::vector order, int k_order, std::vector tile_shape, std::vector instr_shape, std::vector mat_shape, int per_phase, int max_phase, int dtsize, Builder *builder, adder add, multiplier mul, geper gep) : wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape), instr_shape_(instr_shape), mat_shape_(mat_shape), per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder), add(add), mul(mul), gep(gep) { // compute compile-time constant variables & types c_mat_shape_ = mat_shape[order[0]]; s_mat_shape_ = mat_shape[order[1]]; c_stride_ = tile_shape[order[1]]; s_stride_ = tile_shape[order[0]]; // rule: k must be the fast-changing axis need_trans_ = k_order_ != order_[0]; can_use_ldmatrix_ = dtsize == 2 || (!need_trans_); // we need more pointers at the fast-changing axis, if (can_use_ldmatrix_) num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 
1 : wpt) / instr_shape[order[0]]; else // warning: this only works for tf32 & need transpose num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]]; num_ptr_ = std::max(num_ptr_, 2); // special rule for i8/u8, 4 ptrs for each matrix if (!can_use_ldmatrix_ && dtsize_ == 1) num_ptr_ *= 4; // load_v4 stride (in num of mats) int load_stride_in_mat[2]; load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2 load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]); p_load_stride_in_mat_ = load_stride_in_mat[order[0]]; // stride in mat, used by load_v4 s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]); } std::vector compute_offs(Value *warp_off, Value *lane) { // TODO: this needs to be moved to constructor (and extracted to arr_order) mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_; warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1]; // start matrix logic offset (rename it as base_mat_off?) Value *mat_off[2] = {nullptr, nullptr}; if (can_use_ldmatrix_) { // c: lane idx inside a group (a group is a collection of 8 contiguous threads) // s: group idx (0,1,2,3) inside a warp Value *c = urem(lane, i32(8)); Value *s = udiv(lane, i32(8)); // We can decompose s => s_0, s_1... Value *s0 = urem(s, i32(2)); Value *s1 = udiv(s, i32(2)); // We use different orders for a & b for better performance. Value *k_mat_arr = (k_order_ == 1) ? s1 : s0; Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1; mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)), mul(nk_mat_arr, i32(mat_arr_stride_))); mat_off[k_order_] = k_mat_arr; // physical offset (before swizzling) Value *c_mat_off = mat_off[order_[0]]; Value *s_mat_off = mat_off[order_[1]]; // offset inside a matrix Value *s_off_in_mat = c; std::vector offs(num_ptr_); Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); // pre-compute strided offset Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); for (int i=0; i < num_ptr_; ++i) { Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_)); c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_))); } return offs; } else if (dtsize_ == 4 && need_trans_) { // load tf32 matrices with lds32 Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]] Value *s_off_in_mat = urem(lane, i32(4)); // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); std::vector offs(num_ptr_); for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; if (k_mat_arr_int > 0) // we don't need pointers for k continue; Value *k_mat_arr = i32(k_mat_arr_int); Value *nk_mat_arr = i32(nk_mat_arr_int); // physical offset (before swizzling) Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), mul(nk_mat_arr, i32(mat_arr_stride_))); Value *s_mat_off = k_mat_arr; // always 0? Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_))); // FIXME: (k_order_ == 1?) 
is really dirty hack for (int i = 0; i < num_ptr_/2; ++i) { Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2))); c_mat_off_i = xor_(c_mat_off_i, phase); Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); // TODO: move this out of the loop c_off = urem(c_off, i32(tile_shape_[order_[0]])); s_off = urem(s_off, i32(tile_shape_[order_[1]])); offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_))); } } return offs; // throw std::runtime_error("not implemented"); } else if (dtsize_ == 1 && need_trans_) { // load i8/u8 matrices with lds8 Value *c_off_in_mat = udiv(lane, i32(4)); // Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); std::vector offs(num_ptr_); for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2; int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2; if (k_mat_arr_int > 0) // we don't need pointers for k continue; Value *k_mat_arr = i32(k_mat_arr_int); Value *nk_mat_arr = i32(nk_mat_arr_int); // physical offset (before swizzling) Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)), mul(nk_mat_arr, i32(mat_arr_stride_))); Value *s_mat_off = k_mat_arr; // always 0? for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) { for (int elem_off = 0; elem_off < 4; ++elem_off) { int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off; Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2))); Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off)); // disable swizzling ... // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_)); // c_mat_off_i = xor_(c_mat_off_i, phase); Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_))); Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_))); // To prevent out-of-bound access when the tile is too small c_off = urem(c_off, i32(tile_shape_[order_[0]])); s_off = urem(s_off, i32(tile_shape_[order_[1]])); offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_))); } } } return offs; } else throw std::runtime_error("invalid smem load config"); } std::tuple load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn, Value *pre_ptr, Value *next_ptr, std::vector &off, std::vector &ptrs, FunctionType *ldmatrix_ty, Type *smem_ptr_ty, std::map> &prefetch_latch_to_bb_) { assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned"); int mat_idx[2] = {mat0, mat1}; int k = mat_idx[k_order_]; int ptr_idx = -1; if (can_use_ldmatrix_) ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]); else if (dtsize_ == 4 && need_trans_) // tf32 & trans ptr_idx = mat_idx[order_[0]]; else // i8 & trans ptr_idx = mat_idx[order_[0]] * 4; auto get_ptr = [&](int idx) -> Value* { Value *ptr = nullptr; if (k == 0 && is_prefetch) { if (inc == 0) ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty); else ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty); } else ptr = ptrs.at(idx); return ptr; }; Value *ptr = get_ptr(ptr_idx); Value *res_v4 = nullptr; if (can_use_ldmatrix_) { std::string trans = need_trans_ ? 
".trans" : ""; // the offset (in byte) on the strided axis is a constant int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_; InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty, "ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 " "{$0, $1, $2, $3}, " "[$4 + " + std::to_string(s_offset) + "];", "=r,=r,=r,=r,r", true); assert(ptr); res_v4 = call(ldmatrix_ty, ld_fn, {ptr}); if (k == 0 && inc == 1 && is_prefetch) prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4); return {extract_val(res_v4, std::vector{0}), extract_val(res_v4, std::vector{1}), extract_val(res_v4, std::vector{2}), extract_val(res_v4, std::vector{3})}; } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices Value *ptr2 = get_ptr(ptr_idx+1); assert(s_mat_stride_ == 1); int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; Value *elem0, *elem1, *elem2, *elem3; if (k_order_ == 1) { elem0 = load(gep(ptr, i32(s_offset_elem))); elem1 = load(gep(ptr2, i32(s_offset_elem))); elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); } else { // for b (k first) elem0 = load(gep(ptr, i32(s_offset_elem))); elem2 = load(gep(ptr2, i32(s_offset_elem))); elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem))); elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem))); } if (k == 0 && inc == 1 && is_prefetch) { prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0); prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1); prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2); prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3); } return {elem0, elem1, elem2, elem3}; } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices Value *ptr00 = get_ptr(ptr_idx); Value *ptr01 = get_ptr(ptr_idx+1); Value *ptr02 = get_ptr(ptr_idx+2); Value *ptr03 = get_ptr(ptr_idx+3); Value *ptr10 = get_ptr(ptr_idx+4); Value *ptr11 = get_ptr(ptr_idx+5); Value *ptr12 = get_ptr(ptr_idx+6); Value *ptr13 = get_ptr(ptr_idx+7); assert(s_mat_stride_ == 1); int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_; int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_; Value *i8v4_elems[4]; Value *i32_elems[4]; for (int i=0; i<4; ++i) i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4)); Value *elem00, *elem01, *elem02, *elem03; Value *elem10, *elem11, *elem12, *elem13; Value *elem20, *elem21, *elem22, *elem23; Value *elem30, *elem31, *elem32, *elem33; Value *i8_elems[4*4]; if (k_order_ == 1) { // i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 
0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); for (int m=0; m<4; ++m) { for (int e=0; e<4; ++e) i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); } } else { // for b (k first) i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem))); i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem))); i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem))); i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem))); assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8)); i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem))); i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem))); i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem))); i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem))); i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem))); i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem))); for (int m=0; m<4; ++m) { for (int e=0; e<4; ++e) i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e); i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty); } } if (k == 0 && inc == 1 && is_prefetch) { for (int m = 0; m < 4; ++m) for (int e = 0; e < 4; ++e) prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]); } return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]}; } else throw std::runtime_error("invalid smem load"); } int get_num_ptr() const { return num_ptr_; } private: int wpt_; std::vector order_; int k_order_; std::vector tile_shape_; std::vector instr_shape_; std::vector mat_shape_; int per_phase_, max_phase_; int dtsize_; // generated int c_mat_shape_, s_mat_shape_; int c_stride_, s_stride_; // p_: on the pointer axis int p_load_stride_in_mat_; int s_mat_stride_; // stride when moving to next not-k mat int warp_off_stride_; int mat_arr_stride_; // matrix arrangement (inside a load) stride bool need_trans_, can_use_ldmatrix_; int num_ptr_; Builder *builder_; adder add; multiplier mul; geper gep; }; } /** * \brief Code Generation for `mma.16816` (A100) */ //TODO: clean-up void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { const std::vector& shapes = C->get_type()->get_block_shapes(); std::map, std::vector> fcs; for(indices_t idx: idxs_.at(C)){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); fcs[key].push_back(vals_[D][idx]); }; auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); if(C->is_trans_a()){ std::swap(ord_a[0], ord_a[1]); std::swap(shape_a[0], shape_a[1]); } auto ord_b = layouts_->get(B)->get_order(); if(C->is_trans_b()){ std::swap(ord_b[0], ord_b[1]); std::swap(shape_b[0], shape_b[1]); } NK = shape_a[1]; analysis::mma_layout* layout = 
layouts_->get(C)->to_mma(); std::vector mma_instr_shape = layout->get_mma_instr_shape(); const int mma_instr_m = mma_instr_shape[0]; const int mma_instr_n = mma_instr_shape[1]; const int mma_instr_k = mma_instr_shape[2]; std::vector mat_shape = layout->get_mma_mat_shape(); const int mat_shape_m = mat_shape[0]; const int mat_shape_n = mat_shape[1]; const int mat_shape_k = mat_shape[2]; const int num_rep_m = shapes[0] / layout->shape_per_cta(0); const int num_rep_n = shapes[1] / layout->shape_per_cta(1); const int num_rep_k = std::max(NK/mma_instr_k, 1); // floating point types Type *fp32_ty = f32_ty; Type *fp16x2_ty = vec_ty(f16_ty, 2); Type *bf16x2_ty = vec_ty(bf16_ty, 2); Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty}); Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); // integer types Type *i8x4_ty = vec_ty(i8_ty, 4); Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty}); Type *i32_pack4_ty = StructType::get(*ctx_, std::vector{i32_ty, i32_ty, i32_ty, i32_ty}); FunctionType *ldmatrix_ty = nullptr; FunctionType *mma_ty = nullptr; Type *phi_ty = nullptr; Type *smem_ptr_ty = nullptr; ir::type *A_ir_ty = A->get_type()->get_scalar_ty(); ir::type *B_ir_ty = B->get_type()->get_scalar_ty(); if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(f16_ty, 3); ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = fp16x2_ty; } else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(bf16_ty, 3); ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = bf16x2_ty; } else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) { mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); smem_ptr_ty = ptr_ty(fp32_ty, 3); ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = fp32_ty; } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) { // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8) mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); smem_ptr_ty = ptr_ty(i8_ty, 3); ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector{smem_ptr_ty}, false); phi_ty = i32_ty; // mma_ty = FunctionType::get(i32_pack4_ty, std::vector{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false); // smem_ptr_ty = ptr_ty(i8_ty, 3); // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector{smem_ptr_ty}, false); // phi_ty = i8x4_ty; } else throw std::runtime_error("mma16816 data type not supported"); // left-hand-side values std::map, Value*> ha; std::map, Value*> hb; BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); // if true, this will move pointer declarations to the entry basic block // 
not prefetched cases tend to be more limited in resource usage // so we don't pre-compute ptrs to save registers bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB); if(licm_ptrs) builder_->SetInsertPoint(FirstBB->getTerminator()); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value *lane = urem(thread, i32(32)); Value *warp = udiv(thread, i32(32)); Value *warp_mn = udiv(warp, i32(layout->wpt(0))); Value *warp_m = urem(warp, i32(layout->wpt(0))); Value *warp_n = urem(warp_mn, i32(layout->wpt(1))); std::vector& fc = fcs.begin()->second; size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; ir::phi_node* phiA = dynamic_cast(A); ir::phi_node* phiB = dynamic_cast(B); auto register_lds2 = [&](std::map, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) { if (k < 2 && is_prefetch) { ir::basic_block* inc_block = phiA->get_incoming_block(inc); lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block)); } else vals[{mn, k}] = val; }; // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride // v (s0_0(0), s1_0(2), | *num_rep_k // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2) // ----------- // *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0)) std::function load_a; analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared(); bool is_a_shared = layout_a != nullptr; if(is_a_shared) { const int per_phase_a = swizzle_->get_per_phase(layout_a); const int max_phase_a = swizzle_->get_max_phase(layout_a); mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a, {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k}, per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep); std::vector off_a = a_loader.compute_offs(warp_m, lane); int num_ptr_a = a_loader.get_num_ptr(); // pointers std::vector ptrs_a(num_ptr_a); if(licm_ptrs) builder_->SetInsertPoint(CurrBB); for(int i = 0; i < num_ptr_a; i++) ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty); if(licm_ptrs) builder_->SetInsertPoint(FirstBB->getTerminator()); // loading function load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable { auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a], shared_next_ptr_[layout_a], off_a, ptrs_a, ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); register_lds2(ha, m, k, inc, ha0, is_prefetch); register_lds2(ha, m+1, k, inc, ha1, is_prefetch); register_lds2(ha, m, k+1, inc, ha2, is_prefetch); register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch); }; } else { load_a = [&](int m, int k, int inc, bool is_prefetch) { distributed_axis ax_n = axes_.at(a_axes_->get(A, 1)); int ldm = ax_n.values.size(); if(ldm != num_rep_k*4) throw std::runtime_error("Internal compiler error when trying to fuse matmuls!"); // std::cout << m << " " << k << std::endl; // std::cout << idxs_[A].size() << std::endl; // std::cout << (m+1)*ldm + k*2 + 3 << std::endl; // int ldm = num_rep_k*4; Value* ha0 = UndefValue::get(phi_ty); // e.g., fp16x2 Value* ha1 = UndefValue::get(phi_ty); Value* ha2 = UndefValue::get(phi_ty); Value* ha3 = UndefValue::get(phi_ty); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0)); ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1)); ha1 = builder_->CreateInsertElement(ha1, 
vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0)); ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1)); ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0)); ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1)); ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0)); ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1)); ha[{m, k}] = ha0; ha[{m+1, k}] = ha1; ha[{m, k+1}] = ha2; ha[{m+1, k+1}] = ha3; }; } // | -> n (col-major) // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n // k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1)) // ----------- // *num_rep_k (stride in num of matrices(mat_stride_bk): 2) analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared(); const int per_phase_b = swizzle_->get_per_phase(layout_b); const int max_phase_b = swizzle_->get_max_phase(layout_b); std::vector mma_instr_b{mma_instr_k, mma_instr_n}; std::vector mat_shape_b{mat_shape_k, mat_shape_n}; int k_order_b = 0; // if(C->is_trans_b()){ // std::swap(mma_instr_b[0], mma_instr_b[1]); // std::swap(mat_shape_b[0], mat_shape_b[1]); // k_order_b = k_order_b ^ 1; // std::swap(ord_b[0], ord_b[1]); // std::swap(shape_b[0], shape_b[1]); // } mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b, mma_instr_b, mat_shape_b, per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep); std::vector off_b = b_loader.compute_offs(warp_n, lane); if(licm_ptrs) builder_->SetInsertPoint(CurrBB); // pointers int num_ptr_b = b_loader.get_num_ptr(); std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty); // loading function std::function load_b; load_b = [&](int n, int k, int inc, bool is_prefetch) { auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b], shared_next_ptr_[layout_b], off_b, ptrs_b, ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_); register_lds2(hb, n, k, inc, hb0, is_prefetch); register_lds2(hb, n+1, k, inc, hb2, is_prefetch); register_lds2(hb, n, k+1, inc, hb1, is_prefetch); register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch); }; // create mma & unpack result, m, n, k are offsets in mat auto call_mma = [&](unsigned m, unsigned n, unsigned k) { InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() + " {$0, $1, $2, $3}," " {$4, $5, $6, $7}," " {$8, $9}," " {$10, $11, $12, $13};", "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true); unsigned cols_per_thread = num_rep_n * 2; std::vector idx = { (m + 0)*cols_per_thread + (n*2 + 0), (m + 0)*cols_per_thread + (n*2 + 1), (m + 1)*cols_per_thread + (n*2 + 0), (m + 1)*cols_per_thread + (n*2 + 1) }; Value *nc = call(mma_ty, mma_fn, {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}], hb[{n, k}], hb[{n, k+1}], fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); fc[idx[0]] = extract_val(nc, std::vector{0}); fc[idx[1]] = extract_val(nc, std::vector{1}); fc[idx[2]] = extract_val(nc, std::vector{2}); fc[idx[3]] = extract_val(nc, std::vector{3}); }; if (C->is_prefetched()) { // create phis builder_->SetInsertPoint(CurrBB->getFirstNonPHI()); for(unsigned m = 0; m < num_rep_m; m++){ ha[{2*m, 0}] = phi(phi_ty, 2); ha[{2*m+1, 0}] = phi(phi_ty, 2); ha[{2*m, 1}] = phi(phi_ty, 2); ha[{2*m+1, 1}] = phi(phi_ty, 2); } for(unsigned n = 0; n < num_rep_n; n+=2){ hb[{n, 0}] = phi(phi_ty, 2); hb[{n+1, 0}] = phi(phi_ty, 2); 
hb[{n, 1}] = phi(phi_ty, 2); hb[{n+1, 1}] = phi(phi_ty, 2); } // insert prefetched lds at the end of loop header builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator()); for(unsigned m = 0; m < num_rep_m; m++) load_a(2*m, 0, 0, true); for(unsigned n = 0; n < num_rep_n; n+=2) load_b(n, 0, 0, true); // update accumulators builder_->SetInsertPoint(CurrBB); for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2 int next_k = (k + 1) % num_rep_k; // prefetch A for(unsigned m = 0; m < num_rep_m; m++) load_a(2*m, 2*next_k, 1, true); // prefetch B for(unsigned n = 0; n < num_rep_n; n+=2) load_b(n, 2*next_k, 1, true); // tensor core ops for(unsigned m = 0; m < num_rep_m; m++) for(unsigned n = 0; n < num_rep_n; n++){ call_mma(2*m, n, 2*k); } } } else{ for (unsigned k = 0; k < num_rep_k; k++) { for (unsigned m = 0; m < num_rep_m; m++) load_a(2*m, 2*k, 0, /*is_prefetch*/false); for (unsigned n = 0; n < num_rep_n; n+=2) load_b(n, 2*k, 0, /*is_prefetch*/false); for (unsigned m = 0; m < num_rep_m; m++) for (unsigned n = 0; n < num_rep_n; n++) call_mma(2*m, n, 2*k); } } // write back unsigned i = 0; for(indices_t idx: idxs_.at(C)){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); if(i >= fcs.at(key).size()) i = 0; vals_[C][idx] = fcs.at(key)[i++]; }; } /** * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default) */ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) { auto shape_c = C->get_type()->get_block_shapes(); auto shape_a = A->get_type()->get_block_shapes(); auto shape_b = B->get_type()->get_block_shapes(); auto ord_a = layouts_->get(A)->get_order(); auto ord_b = layouts_->get(B)->get_order(); analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline(); analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); bool is_a_row = ord_a[0] == 1; bool is_b_row = ord_b[0] == 1; std::string a_trans = is_a_row ? "" : ".trans"; std::string b_trans = is_b_row ? ".trans" : ""; int stride_a_m = is_a_row ? shape_a[1] : 1; int stride_a_k = is_a_row ? 1 : shape_a[0]; int stride_b_n = is_b_row ? 1 : shape_b[0]; int stride_b_k = is_b_row ? shape_b[1] : 1; int stride_a0 = is_a_row ? stride_a_k : stride_a_m; int stride_a1 = is_a_row ? stride_a_m : stride_a_k; int stride_b0 = is_b_row ? stride_b_n : stride_b_k; int stride_b1 = is_b_row ? stride_b_k : stride_b_n; int lda = is_a_row ? stride_a_m : stride_a_k; int ldb = is_b_row ? stride_b_k : stride_b_n; int per_phase_a = swizzle_->get_per_phase(layout_a); int max_phase_a = swizzle_->get_max_phase(layout_a); int per_phase_b = swizzle_->get_per_phase(layout_b); int max_phase_b = swizzle_->get_max_phase(layout_b); int num_ptr_a = 8; int num_ptr_b = 8; int vec_a = 2; int vec_b = 4; distributed_axis ax_m = axes_.at(a_axes_->get(C, 0)); distributed_axis ax_n = axes_.at(a_axes_->get(C, 1)); // Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous)); Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0); std::vector off_a(num_ptr_a); for(int i = 0; i < num_ptr_a; i++){ // Value* off_a0i = add(off_a0, i32(is_a_row ? 
vec_a : layout_c->mts(0)*vec_a)); // off_a0i = exact_udiv(off_a0i, i32(vec_a)); // off_a0i = xor_(off_a0i, phase_a); // off_a0i = mul(off_a0i, i32(vec_a)); off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1))); } Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0); Value* off_b1 = is_b_row ? i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous)); std::vector off_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++){ // Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b)); // off_b0i = exact_udiv(off_b0i, i32(vec_b)); // off_b0i = xor_(off_b0i, phase_b); // off_b0i = mul(off_b0i, i32(vec_b)); off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1))); } std::vector ptrs_a(num_ptr_a); for(int i = 0; i < num_ptr_a; i++) ptrs_a[i] = gep(shmems_[A], off_a[i]); std::vector ptrs_b(num_ptr_b); for(int i = 0; i < num_ptr_b; i++) ptrs_b[i] = gep(shmems_[B], off_b[i]); std::map ret = vals_[D]; std::map, Value*> has, hbs; auto ord = layout_c->get_order(); for(unsigned k = 0; k < NK; k++){ int z = 0; for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1])) for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0])) for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++) for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){ unsigned m = (ord[0] == 1) ? i : j; unsigned n = (ord[0] == 1) ? j : i; unsigned mm = (ord[0] == 1) ? ii : jj; unsigned nn = (ord[0] == 1) ? jj : ii; if(has.find({m + mm, k}) == has.end()){ Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); Value* va = load(pa); has[{m + mm, k}] = va; } if(hbs.find({n + nn, k}) == hbs.end()){ Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k)); Value* vb = load(pb); hbs[{n + nn, k}] = vb; } ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]}); z++; } } for(indices_t idx: idxs_.at(C)){ vals_[C][idx] = ret[idx]; } } /** * \brief Code Generation for `dot` * Dispatches to appropriate specialized function */ void generator::visit_dot_inst(ir::dot_inst* dot) { Function *fn = builder_->GetInsertBlock()->getParent(); Module *module = fn->getParent(); ir::value *A = dot->get_operand(0); ir::value *B = dot->get_operand(1); ir::value *D = dot->get_operand(2); Type *c_ty = cvt(D->get_type()->get_scalar_ty()); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector{c_ty}); auto A_shapes = A->get_type()->get_block_shapes(); size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; bool is_outer = NK == 1; #ifdef USE_ROCM bool is_mma = layouts_->get(dot)->to_mma(); #else bool is_mma = false; #endif if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() < 80) return visit_mma884(dot, A, B, D, NK); if(!is_outer && is_mma && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()? 
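  // Fallback: if no tensor-core (mma884 / mma16816) path was taken above, fp32 x fp32 dots
  // are lowered to scalar fused multiply-adds (llvm.fmuladd) in visit_fmadot.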
  if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
      A->get_type()->get_scalar_ty()->is_fp32_ty())
    return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
  throw std::runtime_error("dot has invalid operand type");
}

void generator::visit_trans_inst(ir::trans_inst* trans) {
  throw std::runtime_error("not supported");
}

/**
 * \brief Code Generation for `sqrt`
 */
void generator::visit_sqrt_inst(ir::sqrt_inst* x) {
  for(indices_t idx: idxs_.at(x)){
    Value *val = vals_[x->get_operand(0)][idx];
    Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val});
    vals_[x][idx] = ret;
  }
}

Value* generator::shared_off(const std::vector<unsigned>& shapes, const std::vector<int>& order, indices_t idx){
  // strides
  std::vector<Value*> strides(shapes.size(), builder_->getInt32(0));
  strides[order[0]] = builder_->getInt32(1);
  for(size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]]));
  // result
  Value *result = builder_->getInt32(0);
  for(size_t i = 0; i < idx.size(); i++)
    result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i]));
  return result;
}

inline Value* generator::shfl_sync(Value* acc, int32_t i){
  Type* ty = acc->getType();
#ifdef USE_ROCM
  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#else
  std::string asm_str = "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;";
  InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
#endif
  if(ty->getPrimitiveSizeInBits() <= 32)
    return call(shfl, {acc, i32(i)});
  // wider values are shuffled as two 32-bit halves and re-assembled
  acc = bit_cast(acc, vec_ty(f32_ty, 2));
  Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
  Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
  Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
  ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
  ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
  return bit_cast(ret, ty);
}

/**
 * \brief Code Generation for `reduce` (ND case)
 */
void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral){
  ir::value *arg = x->get_operand(0);
  const auto with_index = x->with_index();
  unsigned axis = x->get_axis();
  analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
  const auto &shapes = layout->get_shape();
  Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
  size_t n_bits = sca_ty->getPrimitiveSizeInBits();
  std::string n_bits_str = std::to_string(n_bits);
  std::string cst = (n_bits == 64) ?
"l" : "r"; FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false); InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true); FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false); InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true); Type *index_ty = IntegerType::get(*ctx_, 32); FunctionType *st_shared_index_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(index_ty, 3), index_ty}, false); InlineAsm *st_shared_index = InlineAsm::get( st_shared_index_ty, "@$0 st.shared.b32 [$1], $2;", "b,r,r", true); FunctionType *ld_shared_index_ty = FunctionType::get(index_ty, {i1_ty, ptr_ty(index_ty, 3)}, false); InlineAsm *ld_shared_index = InlineAsm::get( ld_shared_index_ty, "@$1 ld.shared.b32 $0, [$2];", "=r,b,r", true); Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value* warp = udiv(thread, i32(32)); Value* lane = urem(thread, i32(32)); unsigned shuffle_width = 0; unsigned warps_per_inner = 0; auto arg_vals = vals_.at(arg); std::vector arg_idxs = idxs_.at(arg); size_t n_elts = arg_idxs.size(); unsigned col_per_thread = 0; Value* warp_j = nullptr; if (analysis::scanline_layout *scanline = layout->to_scanline()) { std::vector order = layout->get_order(); unsigned mts = scanline->mts(order[0]); shuffle_width = std::min(mts, 32); warps_per_inner = std::max(mts / 32, 1); col_per_thread = shapes[order[0]] / mts; warp_j = urem(warp, i32(warps_per_inner)); } else if (layout->to_mma()) { shuffle_width = 4; warps_per_inner = layout->to_mma()->wpt(1); col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size(); warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id; } assert(warp_j != nullptr); // unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]); // Value *base = cast_shared_layout_ptr(layouts_->get(layouts_->tmp(x)), cvt(x->get_type()->get_scalar_ty())); Value *index_base = with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), IntegerType::get(*ctx_, 32)) : nullptr; // preds Value* is_lane0 = icmp_eq(lane, i32(0)); Value* is_warp0 = icmp_eq(warp, i32(0)); Value* is_thread0 = icmp_eq(thread, i32(0)); Value* lane_j = urem(lane, i32(shuffle_width)); if(warps_per_inner > 1) add_barrier(); // compute partial sum for each warp, and store to shared memory for(size_t i = 0; i < n_elts/col_per_thread; i++){ std::pair acc; // reduce within thread for(size_t j = 0; j < col_per_thread; j++){ auto arg_idx = arg_idxs[i*col_per_thread + j]; bool is_first = j == 0; do_acc( acc, [&]() -> Value * { return arg_vals[arg_idx]; }, [&]() -> Value * { return arg_idx[axis]; }, is_first); } // reduce within warp for(int k = shuffle_width/2 ; k > 0; k >>= 1) { do_acc( acc, [&]() -> Value * { return shfl_sync(acc.first, k); }, [&]() -> Value * { return shfl_sync(acc.second, k); }, false); } // store partial result to shared memory auto x_idxs = idxs_[x][i]; Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0]; // single warp on the reduce dimension -- no need to use shmem if(warps_per_inner==1){ vals_[x][idxs_[x][i]] = with_index ? 
        acc.second : acc.first;
    }
    else{
      Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
      call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
      if (with_index) {
        call(st_shared_index, {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
      }
    }
  }
  if(warps_per_inner==1)
    return;
  add_barrier();
  // at this point, the partial accumulators are synchronized in shared memory;
  // we just need to reduce `warps_per_inner` partial values per output element
  for(size_t i = 0; i < n_elts/col_per_thread; i++){
    auto x_idxs = idxs_[x][i];
    Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
    Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
    std::pair<Value*, Value*> acc;
    acc.first = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
    acc.second = with_index ? call(ld_shared_index, {builder_->getInt1(true), gep(index_base, ld_off)}) : nullptr;
    for (int k = warps_per_inner / 2; k > 0; k >>= 1) {
      do_acc(
          acc, [&]() -> Value * { return shfl_sync(acc.first, k); },
          [&]() -> Value * { return shfl_sync(acc.second, k); }, false);
    }
    vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
  }
  // add_barrier();
}

void generator::visit_reducend_inst(ir::reduce_inst* x, acc_fn_t do_acc, Value *neutral) {
  ir::value *arg = x->get_operand(0);
  unsigned axis = x->get_axis();
  auto with_index = x->with_index();
  // reduce within thread
  // index -> (accumulated value, accumulated index)
  std::map<indices_t, std::pair<Value*, Value*>> accs;
  for(indices_t idx: idxs_.at(arg)){
    indices_t pidx = idx;
    pidx[axis] = i32(0);
    bool is_first = accs.find(pidx) == accs.end();
    do_acc(
        accs[pidx], [&]() -> Value * { return vals_[arg][idx]; },
        [&]() -> Value * { return idx[axis]; }, is_first);
  };
  // reduce within blocks
  auto *data_layout = layouts_->get(layouts_->tmp(x));
  auto *data_ptr = cast_shared_layout_ptr(data_layout, cvt(x->get_type()->get_scalar_ty()));
  auto *index_ptr = with_index ? cast_shared_layout_ptr(layouts_->get(layouts_->tmp_index(x)), IntegerType::get(*ctx_, 32)) : data_ptr;
  auto shape = data_layout->get_shape();
  auto order = data_layout->get_order();
  Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
  for(auto& x: accs) {
    // current element being computed
    std::pair<Value*, Value*> acc = x.second;
    indices_t write_idx = x.first;
    write_idx[axis] = lane;
    // shared memory write pointer
    Value *write_off = shared_off(shape, order, write_idx);
    Value *write_ptr = gep(data_ptr, write_off);
    Value *index_write_ptr = gep(index_ptr, write_off);
    // initialize shared memory
    add_barrier();
    store(acc.first, write_ptr);
    if (with_index) {
      store(acc.second, index_write_ptr);
    }
    // build result
    indices_t idx(write_idx.size(), i32(0));
    for(size_t i = shape[axis]/2; i > 0; i >>= 1){
      idx[axis] = i32(i);
      // read pointer
      Value *read_msk = icmp_ult(lane, i32(i));
      Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0));
      Value *read_ptr = gep(write_ptr, read_off);
      Value *index_read_ptr = gep(index_write_ptr, read_off);
      add_barrier();
      // update accumulator
      do_acc(
          acc, [&]() -> Value * { return load(read_ptr); },
          [&]() -> Value * { return load(index_read_ptr); }, false);
      add_barrier();
      store(acc.first, write_ptr);
      if (with_index) {
        store(acc.second, index_write_ptr);
      }
    }
  }
  add_barrier();
  // write back
  for(indices_t idx: idxs_.at(x)){
    indices_t read_idx = idx;
    read_idx.insert(read_idx.begin() + axis, i32(0));
    Value *read_off = shared_off(shape, order, read_idx);
    Value *read_ptr = with_index ?
        gep(index_ptr, read_off) : gep(data_ptr, read_off);
    vals_[x][idx] = load(read_ptr);
  };
}

/**
 * \brief Code Generation for `reduce` (generic case)
 */
void generator::visit_reduce_inst(ir::reduce_inst* x) {
  Type *ty = cvt(x->get_type()->get_scalar_ty());
  // accumulation function
  ir::reduce_inst::op_t op = x->get_op();
  auto do_acc_op = [&](Value *x, Value *y) -> Value* {
    switch(op){
    case ir::reduce_inst::ADD: return add(x, y);
    case ir::reduce_inst::SUB: return sub(x, y);
    case ir::reduce_inst::ARGUMAX: return icmp_uge(x, y);
    case ir::reduce_inst::ARGUMIN: return icmp_ule(x, y);
    case ir::reduce_inst::ARGMAX: return icmp_sge(x, y);
    case ir::reduce_inst::ARGMIN: return icmp_sle(x, y);
    case ir::reduce_inst::UMAX: return select(icmp_uge(x, y), x, y);
    case ir::reduce_inst::UMIN: return select(icmp_ule(x, y), x, y);
    case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y);
    case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y);
    case ir::reduce_inst::FADD: return fadd(x, y);
    case ir::reduce_inst::FSUB: return fsub(x, y);
    case ir::reduce_inst::ARGFMAX: return fcmp_oge(x, y);
    case ir::reduce_inst::ARGFMIN: return fcmp_ole(x, y);
    case ir::reduce_inst::FMAX: return max_num(x, y);
    case ir::reduce_inst::FMIN: return min_num(x, y);
    case ir::reduce_inst::XOR: return xor_(x, y);
    default: throw std::runtime_error("unreachable");
    }
  };
  auto do_acc = [&](std::pair<Value*, Value*> &acc, std::function<Value*()> load_value_fn,
                    std::function<Value*()> load_index_fn, bool is_first) -> void {
    auto *val = load_value_fn();
    if (x->with_index()) {
      auto *index = load_index_fn();
      if (is_first) {
        acc.first = val;
        acc.second = index;
      } else {
        Value *ret = do_acc_op(acc.first, val);
        acc.first = select(ret, acc.first, val);
        acc.second = select(ret, acc.second, index);
      }
    } else {
      acc.first = is_first ? val : do_acc_op(acc.first, val);
    }
  };
  // neutral element
  Value *neutral;
  switch(op) {
  case ir::reduce_inst::ADD: neutral = ConstantInt::get(ty, 0); break;
  case ir::reduce_inst::SUB: neutral = ConstantInt::get(ty, 0); break;
  case ir::reduce_inst::ARGUMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
  case ir::reduce_inst::ARGUMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
  case ir::reduce_inst::ARGMAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
  case ir::reduce_inst::ARGMIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
  case ir::reduce_inst::UMAX: neutral = ConstantInt::get(ty, 0); break;
  case ir::reduce_inst::UMIN: neutral = ConstantInt::get(ty, UINT32_MAX); break;
  case ir::reduce_inst::MAX: neutral = ConstantInt::get(ty, INT32_MIN); break;
  case ir::reduce_inst::MIN: neutral = ConstantInt::get(ty, INT32_MAX); break;
  case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break;
  case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break;
  case ir::reduce_inst::ARGFMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
  case ir::reduce_inst::ARGFMIN: neutral = ConstantFP::get(ty, INFINITY); break;
  case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break;
  case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break;
  case ir::reduce_inst::XOR: neutral = ConstantInt::get(ty, 0); break;
  default: throw std::runtime_error("unreachable");
  }
  ir::value *arg = x->get_operand(0);
#ifdef USE_ROCM
  visit_reducend_inst(x, do_acc, neutral);
#else
  bool is_coalesced_scanline = layouts_->is_coalesced_scanline(x);
  bool is_a100_mma = layouts_->is_a100_mma(x);
  if (is_coalesced_scanline || is_a100_mma)
    visit_reducend_inst_fast(x, do_acc, neutral);
  else
    visit_reducend_inst(x, do_acc, neutral);
#endif
}
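// Note on the ARG*MAX/ARG*MIN cases above: for these ops do_acc_op returns an i1
// comparison rather than a reduced value; do_acc then uses that predicate to select
// both the running value (acc.first) and its running index (acc.second).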
/**
 * \brief Code Generation for `select`
 */
void generator::visit_select_inst(ir::select_inst* x) {
  for(indices_t idx: idxs_.at(x)){
    vals_[x][idx] = select(vals_[x->get_operand(0)][idx],
                           vals_[x->get_operand(1)][idx],
                           vals_[x->get_operand(2)][idx]);
  }
}

void generator::visit_layout_convert(ir::value *out, ir::value *in){
  ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
  // pointer to temporary shared memory
  Type *ty = cvt(out->get_type()->get_scalar_ty());
  // Orders
  analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
  analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
  Value *base;
  int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
  // std::cout << off << std::endl;
  base = gep(shmem_, i32(off));
  base = bit_cast(base, ptr_ty(ty, 3));
  std::vector<int> n_reps;
  for(int i = 0; i < shape.size(); i++){
    int in_per_cta = in_layout->shape_per_cta(i);
    int out_per_cta = out_layout->shape_per_cta(i);
    int max_per_cta = std::max(in_per_cta, out_per_cta);
    n_reps.push_back(shape[i]/max_per_cta);
  }
  std::vector<std::vector<Value*>> in_ax;
  std::vector<std::vector<Value*>> out_ax;
  for(int d = 0; d < shape.size(); d++){
    in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
    out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
  }
  auto in_ord = in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
  auto out_ord = out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
  // out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
  // non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
  // and in_ord[0] are contiguous.
  int in_vec = out_ord[0] == 0 ? 1 : in_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
  int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
  int pad = std::max(in_vec, out_vec);
  Value *in_ld = i32(shape[in_ord[0]] + pad);
  Value *out_ld = i32(shape[out_ord[0]] + pad);
  for(int i = 0; i < n_reps[0]; i++)
  for(int j = 0; j < n_reps[1]; j++){
    int max_ii, max_jj;
    add_barrier();
    max_ii = in_ax[0].size()/n_reps[0];
    max_jj = in_ax[1].size()/n_reps[1];
    for(int ii = 0; ii < max_ii; ii++)
    for(int jj = 0; jj < max_jj; jj+=in_vec){
      // shared mem pointer
      indices_t offs = {in_ax[0][ii], in_ax[1][jj]};
      Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
      Value *ptr = gep(base, off);
      // stash value to shared mem
      Value* vals = UndefValue::get(vec_ty(ty, in_vec));
      for(int jjj = 0; jjj < in_vec; jjj++){
        indices_t idxs = {in_ax[0][i*max_ii + ii],
                          in_ax[1][j*max_jj + jj + jjj]};
        Value* val = bit_cast(vals_[in][idxs], ty);
        vals = insert_elt(vals, val, jjj);
      }
      ptr = bit_cast(ptr, ptr_ty(vals->getType(), ptr->getType()->getPointerAddressSpace()));
      store(vals, ptr);
    }
    add_barrier();
    max_ii = out_ax[0].size()/n_reps[0];
    max_jj = out_ax[1].size()/n_reps[1];
    for(int ii = 0; ii < max_ii; ii++)
    for(int jj = 0; jj < max_jj; jj+=out_vec){
      // shared mem pointer
      indices_t offs = {out_ax[0][ii], out_ax[1][jj]};
      Value *off = add(offs[out_ord[0]], mul(out_ld, offs[out_ord[1]]));
      Value *ptr = gep(base, off);
      ptr = bit_cast(ptr, ptr_ty(vec_ty(ty, out_vec), ptr->getType()->getPointerAddressSpace()));
      // load value from shared mem
      Value* vals = load(ptr);
      for(int jjj = 0; jjj < out_vec; jjj++){
        indices_t idxs = {out_ax[0][i*max_ii + ii],
                          out_ax[1][j*max_jj + jj + jjj]};
        vals_[out][idxs] = extract_elt(vals, jjj);
      }
    }
  }
}

void generator::visit_cvt_layout_inst(ir::cvt_layout_inst *rc) {
  visit_layout_convert(rc, rc->get_operand(0));
}

void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
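  // Lowers a masked asynchronous load into `cp.async` copies from global memory
  // into swizzled shared memory. Each copy moves `in_vec*dtsize` bytes; the mask is
  // honored by selecting a source size of 0 for masked-off elements (operand $2 below).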
unsigned in_vec = 1; ir::value *arg = x->get_pointer_operand(); analysis::shared_layout* out_layout = layouts_->get(x)->to_shared(); analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) in_vec = in_layout->nts(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); // int per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); int n_shared_0 = std::max(in_vec / out_vec, 1); auto shapes = x->get_type()->get_block_shapes(); BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); std::map, Value*> tmp; std::vector> shared; for(int i = 0; i < idxs_.at(arg).size(); i++){ unsigned id = i / min_vec; // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(tmp.find(key) == tmp.end()){ if(CurrBB != FirstBB) builder_->SetInsertPoint(FirstBB->getTerminator()); indices_t idx = idxs_.at(arg).at(key.first*in_ld); Value* phase = udiv(idx[in_order[1]], i32(per_phase)); phase = urem(phase, i32(max_phase)); Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); off_0 = udiv(off_0, i32(min_vec)); off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); off_0 = mul(off_0 , i32(min_vec)); Value* off = add(off_0, off_1); if(CurrBB != FirstBB) builder_->SetInsertPoint(CurrBB); tmp[key] = gep(shmems_[x], {off}); } shared.push_back({tmp[key], off}); } size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){ auto idx = idxs_[arg][i]; // input ptr info Value *ptr = vals_[arg][idx]; size_t in_off = 0; GetElementPtrInst *in_gep = dyn_cast(vals_[arg][idx]); if(in_gep){ ConstantInt* cst = dyn_cast(in_gep->idx_begin()); in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; ptr= cst ? in_gep->getPointerOperand() : in_gep; } // output ptr info Value* out_base = shared[i].first; int out_off = shared[i].second*dtsize; // asm std::string mod = (in_vec*dtsize == 16) ? 
".cg" : ".ca"; // Value* false_value = vals_[x->get_false_value_operand()][idx]; // bool is_zero_false_value = false; // if(Constant* cst = dyn_cast(false_value)) // is_zero_false_value = cst->isZeroValue(); Value* src_size = builder_->CreateSelect(vals_[x->get_mask_operand()][idx], i32(in_vec*dtsize), i32(0)); std::string asm_str = "cp.async" + mod + ".shared.global [$0 + " + std::to_string(out_off) + "], [$1 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*dtsize) + ", $2;"; FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), ptr->getType(), builder_->getInt32Ty()}, false); InlineAsm *iasm = InlineAsm::get(ty, asm_str, "r,l,r", true); call(iasm, {out_base, ptr, src_size}); } std::string asm_str = "cp.async.commit_group;"; InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); call(iasm); } void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); analysis::distributed_layout* in_layout = dynamic_cast(layouts_->get(arg)); auto out_order = out_layout->get_order(); auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) in_vec = in_layout->contig_per_thread(in_order[0]); int out_vec = swizzle_->get_vec(out_layout); int min_vec = std::min(out_vec, in_vec); int s = std::max(out_vec / in_vec, 1); // int per_phase = swizzle_->get_per_phase(out_layout); int max_phase = swizzle_->get_max_phase(out_layout); // int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]); int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]); if(in_layout->to_mma()){ mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]); mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]); per_phase = 1; max_phase = 8; } int in_ld = in_layout->get_shape()[in_order[0]] / mts_0; int n_shared_0 = std::max(in_vec / out_vec, 1); int n_shared_1 = std::max(per_phase*max_phase / mts_1, 1); if(in_layout->to_mma()){ n_shared_0 = 8; n_shared_1 = 1; } BasicBlock* CurrBB = builder_->GetInsertBlock(); BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); auto shapes = cts->get_type()->get_block_shapes(); // store to shared Value *current = nullptr; std::map, Value*> ptrs; for(int i = 0; i < idxs_.at(arg).size(); i++){ auto idx = idxs_[arg][i]; Value *in_value = vals_[arg][idx]; if(i % min_vec == 0) current = UndefValue::get(vec_ty(in_value->getType(), min_vec)); current = insert_elt(current, in_value, i % min_vec); if(i % min_vec == min_vec - 1){ unsigned id = i / min_vec; // input ptr info int id_0 = id % (in_ld/min_vec); int id_1 = id / (in_ld/min_vec); // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl; std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ if(FirstBB->getTerminator()) builder_->SetInsertPoint(FirstBB->getTerminator()); else builder_->SetInsertPoint(FirstBB); indices_t idx = idxs_.at(arg).at(key.first*in_ld); Value* phase = udiv(idx[in_order[1]], i32(per_phase)); phase = urem(phase, i32(max_phase)); Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); off_0 = udiv(off_0, i32(min_vec)); off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s))); off_0 = mul(off_0 , i32(min_vec)); Value* off = add(off_0, off_1); builder_->SetInsertPoint(CurrBB); ptrs[key] = gep(shmems_.at(cts), {off}); } 
      int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
      int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
      if(in_layout->to_mma()){
        off_0 = id_0/n_shared_0*n_shared_0*8;
        off_1 = id_1/n_shared_1*n_shared_1*8;
      }
      int off = (off_1*shapes[in_order[0]] + off_0);
      Value* ptr = gep(ptrs[key], {i32(off)});
      ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
      // asm
      store(current, ptr);
    }
  };
}

void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) {
  throw std::runtime_error("TODO");
}

Instruction* generator::add_barrier() {
  Module *module = builder_->GetInsertBlock()->getModule();
  return tgt_->add_barrier(module, *builder_);
}

void generator::visit_barrier_inst(ir::barrier_inst*) {
  add_barrier();
}

void generator::visit_clock_inst(ir::clock_inst* clock){
  InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %clock64;", "=l", true);
  vals_[clock][{}] = call(iasm);
}

void generator::visit_globaltimer_inst(ir::globaltimer_inst* timer){
  InlineAsm *iasm = InlineAsm::get(FunctionType::get(builder_->getInt64Ty(), {}), "mov.u64 $0, %globaltimer;", "=l", true);
  vals_[timer][{}] = call(iasm);
}

void generator::visit_prefetch_s_inst(ir::prefetch_s_inst *i) {
  ir::value *v = i->get_operand(0);
  int inc = i->get_inc();
  if (inc == 0) {
    // If dot has not been visited, do nothing.
  } else {
    // If dot has been visited, insert prefetched lds
    assert(inc == 1);
    assert(prefetch_latch_to_bb_.find(v) != prefetch_latch_to_bb_.end() &&
           "dot hasn't been visited");
    // sink lds & extract element
    // move lds & all uses to current location
    std::stack<Value*> work_stack;
    for (Value *value : prefetch_latch_to_bb_[v])
      work_stack.push(value);
    std::vector<Instruction*> dead_instrs;
    while (!work_stack.empty()) {
      Value *m = work_stack.top();
      work_stack.pop();
      for (auto u : m->users())
        work_stack.push(u);
      assert(isa<Instruction>(m));
      auto m_instr = static_cast<Instruction*>(m);
      m_instr->removeFromParent();
      m_instr->insertAfter(&*std::prev(builder_->GetInsertBlock()->end()));
      assert(m_instr->getParent() == &*builder_->GetInsertBlock());
      builder_->SetInsertPoint(m_instr->getParent());
    }
  }
}

void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
  std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
  InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
  call(iasm);
}

/**
 * \brief Code Generation for `extern_elementwise`
 */
void generator::visit_extern_elementwise_inst(ir::extern_elementwise_inst *i) {
  std::vector<Type*> operand_types;
  for (size_t j = 0; j < i->get_num_operands(); j++) {
    operand_types.push_back(
        cvt(i->get_operand(j)->get_type()->get_scalar_ty()));
  }
  Type *ret_type = cvt(i->get_type()->get_scalar_ty());
  FunctionType *FT =
      FunctionType::get(ret_type, std::move(operand_types), false);
  Function *F = llvm::cast<llvm::Function>(
      mod_->getOrInsertFunction(i->get_symbol_name(), FT).getCallee());
  for (auto idx : idxs_.at(i)) {
    std::vector<Value*> args;
    for (size_t j = 0; j < i->get_num_operands(); j++) {
      args.emplace_back(vals_[i->get_operand(j)][idx]);
    }
    vals_[i][idx] = call(F, std::move(args));
  }
  add_extern_lib(i->get_lib_name(), i->get_lib_path());
}

//void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
//  for(indices_t idx: idxs_.at(x)){
//    assert(idx.size() == 1);
//    if(idx[0] == i32(0))
//      vals_[x][idx] = idx[0];
//    else{
//      BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
//      assert(bin_add);
//      vals_[x][idx] = bin_add->getOperand(0);
//    }
//  }
//}

//void generator::visit_make_range_sta(ir::make_range_sta* x) {
//  for(indices_t idx: idxs_.at(x)){
//    assert(idx.size() == 1);
//    if(idx[0] == i32(0)){
//      vals_[x][idx] = idx[0];
//    }
//    else{
//      BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
//      assert(bin_add);
//      Value *cst = bin_add->getOperand(1);
//      assert(isa<ConstantInt>(cst));
//      vals_[x][idx] = cst;
//    }
//  };
//}
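/**
 * \brief Code Generation for `make_range`
 */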
void generator::visit_make_range(ir::make_range* x) {
  for(indices_t idx: idxs_.at(x)){
    Value* start = ConstantInt::get(idx[0]->getType(), x->get_first()->get_value());
    vals_[x][idx] = add(start, idx[0]);
  }
}

void generator::visit_undef_value(ir::undef_value *x) {
  ir::type* sca_ty = x->get_type()->get_scalar_ty();
  Type* ty = cvt(sca_ty);
  for(indices_t idx: idxs_.at(x))
    vals_[x][idx] = llvm::UndefValue::get(ty);
}

void generator::visit_constant_int(ir::constant_int *x){
  Type *ty = cvt(x->get_type()->get_scalar_ty());
  for(indices_t idx: idxs_.at(x))
    vals_[x][idx] = ConstantInt::get(ty, x->get_value());
}

void generator::visit_constant_fp(ir::constant_fp *x){
  Type *ty = cvt(x->get_type()->get_scalar_ty());
  for(indices_t idx: idxs_.at(x)) {
    // manually select bf16 constant
    if (x->get_type()->get_scalar_ty()->is_bf16_ty()) {
      // highest 16 bits of fp32
      float fp32_value = x->get_value();
      uint16_t bf16_raw = (*reinterpret_cast<uint32_t*>(&fp32_value)
                           & 0xffff0000) >> 16;
      std::stringstream const_str;
      const_str << "0x" << std::hex << bf16_raw << "U"; // unsigned
      InlineAsm *bf16_const = InlineAsm::get(FunctionType::get(bf16_ty, {}, false),
                                             " mov.b16 $0, " + const_str.str() + ";",
                                             "=h", false);
      vals_[x][idx] = builder_->CreateCall(bf16_const, {});
    } else
      vals_[x][idx] = ConstantFP::get(ty, x->get_value());
  }
}

void generator::visit_alloc_const(ir::alloc_const *alloc) {
  unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value();
  Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty());
  Type *array_ty = llvm::ArrayType::get(element_ty, size);
  Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage,
                                          nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4);
  vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4));
}

void generator::forward_declare(ir::function* fn){
  FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type());
  if(!tgt_->is_gpu()){
    Type *fn_ret_ty = fn_ty->getReturnType();
    std::vector<Type*> fn_args_ty;
    for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
      fn_args_ty.push_back(fn_ty->getParamType(i));
    fn_args_ty.push_back(i32_ty);
    fn_args_ty.push_back(i32_ty);
    fn_args_ty.push_back(i32_ty);
    fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false);
  }
  Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_);
  fns_[fn] = ret;
}

Value *generator::cast_shared_layout_ptr(analysis::data_layout *layout, Type *ty) {
  unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
  Value *base = bit_cast(shared_ptr_.at(layout), ptr_ty(ty, addr_space));
  return base;
}

void generator::visit_function(ir::function* fn) {
  std::cout << "generator.cc: generator::visit_function:" << std::endl;
  idxs_.clear();
  vals_.clear();
  seen_.clear();
  LLVMContext &ctx = builder_->getContext();
  Function* ret = fns_[fn];
  // set attributes
  std::cout << "\t// set attributes" << std::endl;
  for(auto attr_pair: fn->attrs()){
    unsigned id = attr_pair.first;
    for(ir::attribute attr: attr_pair.second)
      if(attr.is_llvm_attr()){
        llvm::Attribute llattr = cvt(attr);
        if(llattr.getKindAsEnum() != llvm::Attribute::None)
          ret->addAttribute(id, cvt(attr));
      }
  }
  // set metadata
  std::cout << "\t// set metadata" << std::endl;
  if(tgt_->is_gpu()){
    tgt_->set_kernel(*builder_, ctx, mod_, ret);
#ifndef USE_ROCM
    Metadata
*md_args[] = { ValueAsMetadata::get(ret), MDString::get(ctx, "maxntidx"), ValueAsMetadata::get(i32(num_warps_*32)) }; mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); #endif } // set arguments std::cout << "\t// set arguments" << std::endl; for(unsigned i = 0; i < fn->args().size(); i++) vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); // create blocks std::cout << "\t// create blocks" << std::endl; auto blocks = ir::cfg::reverse_post_order(fn); for(ir::basic_block *block: blocks) { BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); bbs_[block] = dst_block; } builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); // create policies #ifndef USE_ROCM std::cout << "\t// create policies" << std::endl; if(tgt_->as_nvidia()->sm() >= 80) for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){ std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last"; std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;"; InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false); policies_[evict] = call(iasm); } #endif // initialize layouts std::cout << "\t// initialize layouts" << std::endl; for(auto x: layouts_->get_all()){ visit_layout(x.second); } // generate LLVM-IR code std::cout << "\t// generate LLVM-IR code" << std::endl; for(ir::basic_block *block: blocks) visit_basic_block(block); // finalize std::cout << "\t// verifyFunction" << std::endl; finalize_function(fn); // verifyFunction std::cout << "\t// verifyFunction" << std::endl; llvm::verifyFunction(*ret); } void generator::visit_layout_mma(analysis::mma_layout* layout) { ir::value *a = nullptr; ir::value *b = nullptr; for(ir::value* v: layout->get_values()) if(ir::dot_inst* dot = dynamic_cast(v)){ a = dot->get_operand(0); b = dot->get_operand(1); } analysis::data_layout* layout_a = layouts_->get(a); analysis::data_layout* layout_b = layouts_->get(b); const auto& shape = layout->get_shape(); Value *_1 = i32(1); Value *_2 = i32(2); Value *_3 = i32(3); Value *_4 = i32(4); Value *_8 = i32(8); Value *_16 = i32(16); Value *_32 = i32(32); #ifdef USE_ROCM int cc = 1; // generate ir for older CUDA cards #else int cc = tgt_->as_nvidia()->sm(); #endif std::vector idx_m; std::vector idx_n; std::vector idx_z; // Value* thread = tgt_->get_local_id(mod_, *builder_, 0); Value *lane = urem(thread, _32); Value *warp = udiv(thread, _32); /* lane offset */ if(cc < 80){ auto ord_a = layout_a->get_order(); auto ord_b = layout_b->get_order(); bool is_a_row = ord_a[0] != 0; bool is_b_row = ord_b[0] != 0; /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); Value *warp_12 = udiv(warp, i32(layout->wpt(0))); Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); // Quad offset Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0))); Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1))); // Pair offset Value *off_pair_m = udiv(urem(lane, _16), _4); off_pair_m = urem(off_pair_m, i32(layout->fpw(0))); off_pair_m = mul(off_pair_m, i32(4)); Value *off_pair_n = udiv(urem(lane, _16), _4); off_pair_n = udiv(off_pair_n, i32(layout->fpw(0))); off_pair_n = urem(off_pair_n, i32(layout->fpw(1))); off_pair_n = mul(off_pair_n, i32(4)); // scale off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2)); off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2)); 
off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2)); off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2)); // Quad pair offset Value *off_lane_m = add(off_pair_m, off_quad_m); Value *off_lane_n = add(off_pair_n, off_quad_n); // a offset offset_a_m_[layout] = add(off_warp_m, off_lane_m); offset_a_k_[layout] = and_(lane, _3); // b offsets offset_b_n_[layout] = add(off_warp_n, off_lane_n); offset_b_k_[layout] = and_(lane, _3); // i indices Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)) for(unsigned mm = 0; mm < layout->rep(0); mm++) idx_m.push_back(add(offset_c_m, i32(m + mm*2))); // j indices Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)) for(unsigned nn = 0; nn < layout->rep(1); nn++){ idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); } if(is_a_row){ offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4))); offset_a_k_[layout] = i32(0); } if(!is_b_row){ offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4))); offset_b_k_[layout] = i32(0); } /* axes */ axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; } else{ /* warp offset */ Value *warp_0 = urem(warp, i32(layout->wpt(0))); Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1))); Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); Value *off_lane_m = urem(lane, _16); Value *off_lane_n = urem(lane, _8); /* offsets */ // a offset offset_a_m_[layout] = add(off_warp_m, off_lane_m); offset_a_k_[layout] = i32(0); // b offsets offset_b_n_[layout] = add(off_warp_n, off_lane_n); offset_b_k_[layout] = i32(0); // c offset Value *off_c_m = add(udiv(lane, _4), off_warp_m); Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); for(unsigned m = 0; m < shape[0]; m+=layout->shape_per_cta(0)){ idx_m.push_back(add(off_c_m, i32(m))); idx_m.push_back(add(off_c_m, i32(m + 8))); } for(unsigned n = 0; n < shape[1]; n+=layout->shape_per_cta(1)){ idx_n.push_back(add(off_c_n, i32(n))); idx_n.push_back(add(off_c_n, i32(n + 1))); } /* axes */ axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; } } void generator::visit_layout_scanline(analysis::scanline_layout* layout) { Value* thread_id = tgt_->get_local_id(mod_, *builder_, 0); auto order = layout->get_order(); const auto& shape = layout->get_shape(); // Delinearize size_t dim = shape.size(); std::vector thread_ids(dim); for(unsigned k = 0; k < dim - 1; k++){ Constant *dim_k = i32(layout->mts(order[k])); Value *rem = urem(thread_id, dim_k); thread_id = udiv(thread_id, dim_k); thread_ids[order[k]] = rem; } Constant *dim_k = i32(layout->mts(order[dim - 1])); thread_ids[order[dim - 1]] = urem(thread_id, dim_k); // Create axes for(unsigned k = 0; k < dim; k++) { int nts = layout->nts(k); int mts = layout->mts(k); std::string str_k = std::to_string(k); Value *contiguous_k = i32(nts); Value *scaled_thread_ids = mul(thread_ids[k], contiguous_k); unsigned per_cta = layout->shape_per_cta(k); unsigned per_thread = nts * shape[k] / per_cta; std::vector idx_list(per_thread); for(unsigned n = 0 ; n < per_thread; n++){ unsigned offset = n / nts * per_cta + n % nts; idx_list[n] = 
add(scaled_thread_ids, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); } axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_ids[k]}; } } void generator::visit_layout_shared(analysis::shared_layout* layout) { Type* ty = cvt(layout->get_type()); PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace()); if (layout->get_N_buffer()) { // create pointers shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], ptr_ty); BasicBlock *current = builder_->GetInsertBlock(); auto info = *layout->get_N_buffer(); ir::phi_node *phi = info.phi; BasicBlock *parent = bbs_.at(phi->get_parent()); if(parent->empty()) builder_->SetInsertPoint(parent); else if (const Instruction *first_non_phi = &*parent->getFirstNonPHI()) { builder_->SetInsertPoint(&*parent->getFirstNonPHI()); } else builder_->SetInsertPoint(parent); // create smem_idx read_smem_idx_[layout] = phi(i32_ty, 2); write_smem_idx_[layout] = phi(i32_ty, 2); // create pointers // ptr of the current iteration shared_ptr_[layout] = phi(ptr_ty, 2); // ptr of the next iteration shared_next_ptr_[layout] = phi(ptr_ty, 2); builder_->SetInsertPoint(current); } else if(layout->get_double_buffer()) { BasicBlock *current = builder_->GetInsertBlock(); auto info = *layout->get_double_buffer(); ir::phi_node *phi = info.phi; BasicBlock *parent = bbs_.at(phi->get_parent()); if(parent->empty()) builder_->SetInsertPoint(parent); else builder_->SetInsertPoint(&*parent->getFirstNonPHI()); // create pointers shared_ptr_[layout] = phi(ptr_ty, 2); shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType()); shared_off_[layout] = phi(i32_ty, 2); shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr"); builder_->SetInsertPoint(current); } else{ size_t offset = alloc_->offset(layout); shared_ptr_[layout] = gep(shmem_, i32(offset)); shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty); } } void generator::visit_basic_block(ir::basic_block * block) { BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ visit_value(i); // std::cout << "done" << std::endl; } // Update ir bb -> llvm bb mapping bbs_[block] = builder_->GetInsertBlock(); } void generator::visit_argument(ir::argument* arg) { } void generator::init_idx(ir::value *v) { idxs_[v].clear(); if(!v->get_type()->is_block_ty()){ idxs_[v].push_back({}); return; } if(layouts_->get(v)->to_shared()) return; const auto &shapes = v->get_type()->get_block_shapes(); size_t rank = shapes.size(); std::vector axes(rank); std::vector ord(rank); // compute axes // std::cout << "axes" << std::endl; for(size_t d = 0; d < shapes.size(); d++){ // std::cout << d << " " << shapes[d] << std::endl; // std::cout << a_axes_->get(v, d) << std::endl; if(shapes[d] > 1){ unsigned x = a_axes_->get(v, d); axes[d] = axes_.at(x); } else{ axes[d].contiguous = 1; axes[d].values = {i32(0)}; } } // std::cout << "axes ok" << std::endl; // compute order analysis::data_layout* layout = layouts_->get(v); std::iota(ord.begin(), ord.end(), 0); auto cmp = [&](int x, int y) { unsigned axx = a_axes_->get(v, x); unsigned axy = a_axes_->get(v, y); size_t posx = layout->find_axis(axx); size_t posy = layout->find_axis(axy); if(posx < rank && posy < rank) return layout->get_order(posx) < layout->get_order(posy); return false; }; std::sort(ord.begin(), ord.end(), 
cmp); ords_[v] = ord; // indices if(axes.size() == 1) for(Value* x0: axes[ord[0]].values){ idxs_[v].push_back({x0}); } if(axes.size() == 2) for(Value* x1: axes[ord[1]].values) for(Value* x0: axes[ord[0]].values){ indices_t idx(2); idx[ord[0]] = x0; idx[ord[1]] = x1; idxs_[v].push_back(idx); } if(axes.size() == 3) for(Value* x2: axes[ord[2]].values) for(Value* x1: axes[ord[1]].values) for(Value* x0: axes[ord[0]].values){ indices_t idx(3); idx[ord[0]] = x0; idx[ord[1]] = x1; idx[ord[2]] = x2; idxs_[v].push_back(idx); } } void generator::finalize_shared_layout(analysis::shared_layout *shared) { if (auto n_buffer = shared->get_N_buffer()) { // if (*_smem_idx == #stages-1) { // *_smem_idx = 0; // } else *_smem_idx++; auto finalize_smem_idx = [&](auto &smem_idx, int init_stage) { // insert point Value *idx = smem_idx[shared]; builder_->SetInsertPoint(bbs_.at(n_buffer->phi->get_parent())->getTerminator()); Value *cond = icmp_eq(idx, i32(shared->get_num_stages()-1)); PHINode *_ret = phi(i32_ty, 2); Instruction *then_term = nullptr; Instruction *else_term = nullptr; Instruction *dummy = builder_->CreateRet(nullptr); llvm::SplitBlockAndInsertIfThenElse(cond, _ret, &then_term, &else_term, nullptr); dummy->removeFromParent(); builder_->SetInsertPoint(then_term); Value *zero_smem_idx = i32(0); builder_->SetInsertPoint(else_term); Value *inc_smem_idx = add(idx, i32(1)); builder_->SetInsertPoint(_ret->getParent()); _ret->addIncoming(zero_smem_idx, then_term->getParent()); _ret->addIncoming(inc_smem_idx, else_term->getParent()); // update ir::bb -> llvm::bb mapping bbs_.at(n_buffer->phi->get_parent()) = builder_->GetInsertBlock(); // idx = init_stage; // loop: ... if (auto idx_phi = llvm::dyn_cast(smem_idx[shared])) { idx_phi->addIncoming(i32(init_stage), bbs_.at(n_buffer->phi->get_incoming_block(0))); idx_phi->addIncoming(_ret, bbs_.at(n_buffer->phi->get_incoming_block(1))); } else throw std::runtime_error("Should be PHINode"); }; // read_smem_idx is used by next_ptr to compute the next iteration value, so init value is 2 finalize_smem_idx(read_smem_idx_, 2); finalize_smem_idx(write_smem_idx_, shared->get_num_stages()-1); // finalize pointers ir::phi_node *pn = n_buffer->phi; BasicBlock *header = bbs_.at(pn->get_incoming_block(0)); BasicBlock *loop = bbs_.at(pn->get_incoming_block(1)); // %curr_ptr = phi %shared_pre_ptr, %next_ptr // %next_ptr = phi %shared_pre_ptr[+1], (gep(%pre_ptr, read_smem_idx*per_stage_size)) if (auto curr_ptr = dyn_cast(shared_ptr_[shared])) { curr_ptr->addIncoming(shared_pre_ptr_[shared], header); curr_ptr->addIncoming(shared_next_ptr_[shared], loop); } else throw std::runtime_error("Should be PHINode"); BasicBlock *current = builder_->GetInsertBlock(); builder_->SetInsertPoint(header->getTerminator()); Value *next_ptr_header = gep(shared_pre_ptr_[shared], i32(shared->get_per_stage_elements())); builder_->SetInsertPoint(current->getTerminator()); assert(isa(shared_next_ptr_[shared])); static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr_header, header); Value *lds_offset = mul(read_smem_idx_[shared], i32(shared->get_per_stage_elements())); Value *next_ptr = gep(shared_pre_ptr_[shared], lds_offset); static_cast(shared_next_ptr_[shared])->addIncoming(next_ptr, loop); } else if(shared->get_double_buffer()) { auto info = *shared->get_double_buffer(); ir::phi_node *phi = info.phi; PHINode *ptr = (PHINode*)shmems_[phi]; PHINode *offset = (PHINode*)shoffs_[phi]; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::basic_block* inc_block = phi->get_incoming_block(n); 
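// double-buffering: the latch edge flips the read/write halves by negating the
// offset; the entry edge seeds the offset with half the buffer size (in elements).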
ir::value* inc_val = phi->get_incoming_value(n); BasicBlock *llvm_inc_block = bbs_.at(inc_block); if(inc_val == info.latch){ builder_->SetInsertPoint(llvm_inc_block->getTerminator()); Value *next_offset = neg(offset); offset->addIncoming(next_offset, llvm_inc_block); } else { unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8; offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block); } ptr->addIncoming(shmems_[inc_val], llvm_inc_block); } } } void generator::finalize_function(ir::function *fn) { // finalize double-buffering for(const auto& x: layouts_->get_all()) if(auto *shared = dynamic_cast(x.second)) finalize_shared_layout(shared); // finalize phi for(ir::basic_block *block: fn->blocks()) for(ir::instruction *inst: block->get_inst_list()) if(auto *phi = dynamic_cast(inst)) finalize_phi_node(phi); for(auto& x: lazy_phi_incs_) std::get<0>(x)->addIncoming(std::get<1>(x), bbs_[std::get<2>(x)]); } void generator::finalize_phi_node(ir::phi_node *x) { if(shmems_.find(x) != shmems_.end()) return; for(unsigned n = 0; n < x->get_num_incoming(); n++){ ir::basic_block *_block = x->get_incoming_block(n); BasicBlock *block = bbs_.at(_block); for(indices_t idx: idxs_.at(x)){ PHINode *phi = (PHINode*)vals_[x][idx]; Value *inc = vals_[x->get_incoming_value(n)][idx]; // x->print(std::cout); phi->addIncoming(inc, block); } } } void generator::packed_type(ir::value* i){ Type* dtype = cvt(i->get_type()->get_tile_element_ty()); auto* layout = dynamic_cast(layouts_->get(i)); assert(layout); } void generator::visit(ir::module &src, llvm::Module &dst) { std::cout << "generator.cc: generator::visit" << std::endl; mod_ = &dst; ctx_ = &dst.getContext(); builder_ = new Builder(*ctx_); // allocate shared memory if(tgt_->is_gpu()) if(unsigned alloc_size = alloc_->allocated_size()){ Type *int_8_ty = Type::getInt8Ty(*ctx_); Type *int_32_ty = Type::getInt32Ty(*ctx_); ArrayType *array_ty = ArrayType::get(int_32_ty, 0); Type *ptr_ty = ptr_ty(int_8_ty, 3); GlobalVariable *sh_mem_array = new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); shmem_ = bit_cast(sh_mem_array, ptr_ty); } // instantiate device functions // for(ir::function *fn: src.get_function_list()) // for(ir::basic_block *bb: fn->blocks()) // for(ir::instruction *i: bb->get_inst_list()) // if(auto *call = dynamic_cast(i)){ // std::cout << "call??" << std::endl; // } // visit functions for(ir::function *fn: src.get_function_list()) forward_declare(fn); for(ir::function *fn: src.get_function_list()) visit_function(fn); } void generator::add_extern_lib(const std::string &lib_name, const std::string &lib_path) { if (extern_lib_map_.count(lib_name) == 0) { extern_lib_map_[lib_name] = create_extern_lib(lib_name, lib_path); } else if (extern_lib_map_.at(lib_name)->path() != lib_path) { throw std::runtime_error("A library has multiple paths (1) " + lib_path + " (2) " + extern_lib_map_.at(lib_name)->path()); } } } // namespace codegen } // namespace triton