From 1e844ba78d1a14ab21eca6ce492ae1da9266248c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 9 May 2021 21:59:25 -0400 Subject: [PATCH] [CODEGEN] Switching to predicated inline PTX for LDGs (#103) --- lib/codegen/selection/generator.cc | 143 +++++++++++++++++++++++------ lib/driver/kernel.cc | 2 +- 2 files changed, 114 insertions(+), 31 deletions(-) diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 91f940841..abd7fe9fd 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -1,4 +1,6 @@ #include +#include +#include #include "triton/codegen/selection/generator.h" #include "triton/codegen/target.h" #include "triton/codegen/analysis/axes.h" @@ -530,8 +532,6 @@ void generator::visit_load_inst(ir::load_inst* x){ ir::value *op = x->get_pointer_operand(); ir::masked_load_inst *mx = dynamic_cast(x); Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); - int space = op->get_type()->get_scalar_ty()->get_pointer_address_space(); - // compute vector width size_t vec = 1; if(op->get_type()->is_block_ty()){ @@ -540,43 +540,123 @@ void generator::visit_load_inst(ir::load_inst* x){ size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); vec = std::min(nts, aln); } - // code generation auto idxs = idxs_.at(x); for(size_t i = 0; i < idxs.size(); i += vec){ indices_t idx = idxs[i]; // pointer value - Value *ptr = bit_cast(vals_[op][idx], ptr_ty(vec_ty(ty, vec), space)); + Value *ptr = vals_[op][idx]; // masked load - Value *ret = nullptr; - if(mx){ - // if mask: - // ret = load(ptr) - // else: - // ret = false_value - PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2); - Instruction *then_term; - Instruction *else_term; - builder_->SetInsertPoint(_ret->getParent()); - Instruction* dummy = builder_->CreateRet(nullptr); - llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term); - dummy->removeFromParent(); - builder_->SetInsertPoint(then_term); - Value* then_ret = load(ptr); - builder_->SetInsertPoint(else_term); - Value* else_ret = splat(vec, vals_[mx->get_false_value_operand()][idx]); - builder_->SetInsertPoint(_ret->getParent()); - _ret->addIncoming(then_ret, then_term->getParent()); - _ret->addIncoming(else_ret, else_term->getParent()); - ret = (Value*)_ret; + size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + // input ptr info + GetElementPtrInst *in_gep = dyn_cast(ptr); + Value *in_base = in_gep->getPointerOperand(); + ConstantInt* cst = dyn_cast(in_gep->idx_begin()); + size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0; + in_base = cst ? in_base : in_gep; + + Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue(); + Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr; + size_t nbits = dtsize*8; + // pack sub-words (< 32/64bits) into words + // each load has width min(nbits*vec, 32/64) + // and there are (nbits * vec)/width of them + int max_word_width = std::max(32, nbits); + int tot_width = nbits*vec; + int width = std::min(tot_width, max_word_width); + int n_words = std::max(1, tot_width / width); + // ----- + // create inline asm string + // ----- + std::ostringstream asm_oss; + asm_oss << "@$" << n_words; // predicate + asm_oss << " ld.global.cg"; + if(n_words > 1) + asm_oss << ".v" << n_words; // vector width + asm_oss << ".b" << width; // word size + asm_oss << " {"; + for(int i = 0; i < n_words; i++){ // return values + if(i > 0) asm_oss << ","; + asm_oss << "$" << i; } - else - ret = load(ptr); - // write back + asm_oss << "}"; + asm_oss << ", [ $" << n_words + 1; // load + asm_oss << " + " << in_off << "];"; // constant offset + bool has_other = other && (other != UndefValue::get(other->getType())); + std::vector others; + // handle `other` values for indices where the mask + // is false + if(has_other) + for(size_t ii = 0; ii < n_words; ii++){ + size_t size = width / nbits; + Value *v = UndefValue::get(vec_ty(ty, size)); + for(size_t s = 0; s < size; s++){ + ir::value *false_val = mx->get_false_value_operand(); + v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s); + } + v = bit_cast(v, IntegerType::get(*ctx_, width)); + asm_oss << "\n "; + asm_oss << "@!$" << n_words << " mov.u" << width; + asm_oss << " $" << ii << ", "; + std::ios_base::fmtflags flags(asm_oss.flags()); + if(ConstantInt* cst = dyn_cast(v)) + asm_oss << "0x" << std::hex << cst->getSExtValue(); + else{ + asm_oss << "$" << n_words + 2 + ii; + others.push_back(v); + } + asm_oss.flags(flags); + asm_oss << ";"; + } + // ---- + // create inline ASM signature + // --- + std::vector ret_tys(n_words, IntegerType::get(*ctx_, width)); + Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0]; + std::vector arg_tys = {pred->getType(), ptr->getType()}; + for(Value *v: others) + arg_tys.push_back(v->getType()); + FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false); + // --- + // create inline ASM constraints + // --- + std::string asm_cstrt; + for(int ii = 0; ii < n_words; ii++){ + if(ii > 0) asm_cstrt += ","; + asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c"); + } + asm_cstrt += ",b,l"; + for(size_t ii = 0; ii < others.size(); ii++){ + asm_cstrt += ","; + asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c"); + } + // --- + // finally call inline ASM + // --- + InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true); + std::vector args = {pred, in_base}; + for(Value *v: others) + args.push_back(v); + Value *_ret = call(_asm, args); + // --- + // extract and store return values + // --- + std::vector rets; + for(unsigned int ii = 0; ii < n_words; ii++){ + Value *curr; + if(ret_ty->isStructTy()) + curr = extract_val(_ret, {ii}); + else + curr = _ret; +// std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl; + rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8)))); + } + int tmp = (width / (dtsize * 8)); for(size_t ii = 0; ii < vec; ii++) - vals_[x][idxs[i+ii]] = extract_elt(ret, ii); + vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp); } } + void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { visit_load_inst(x); } @@ -1703,7 +1783,10 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { int off = (off_1*shapes[in_order[0]] + off_0); std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; if(ptrs.find(key) == ptrs.end()){ - builder_->SetInsertPoint(FirstBB->getTerminator()); + if(FirstBB->getTerminator()) + builder_->SetInsertPoint(FirstBB->getTerminator()); + else + builder_->SetInsertPoint(FirstBB); indices_t idx = idxs_.at(arg).at(key.first*in_ld); Value* phase = udiv(idx[in_order[1]], i32(per_phase)); phase = urem(phase, i32(max_phase)); diff --git a/lib/driver/kernel.cc b/lib/driver/kernel.cc index 4d02ac469..4771191ba 100755 --- a/lib/driver/kernel.cc +++ b/lib/driver/kernel.cc @@ -81,7 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_); dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_); dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_); - std::cout << n_reg << std::endl; +// std::cout << n_reg << std::endl; if (shared_optin > 49152){ // std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl; dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);