[CODEGEN] Switching to predicated inline PTX for LDGs (#103)

This commit is contained in:
Philippe Tillet
2021-05-09 21:59:25 -04:00
committed by Philippe Tillet
parent ac57812bdc
commit 1e844ba78d
2 changed files with 114 additions and 31 deletions

View File

@@ -1,4 +1,6 @@
#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
@@ -530,8 +532,6 @@ void generator::visit_load_inst(ir::load_inst* x){
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
int space = op->get_type()->get_scalar_ty()->get_pointer_address_space();
// compute vector width
size_t vec = 1;
if(op->get_type()->is_block_ty()){
@@ -540,43 +540,123 @@ void generator::visit_load_inst(ir::load_inst* x){
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
vec = std::min(nts, aln);
}
// code generation
auto idxs = idxs_.at(x);
for(size_t i = 0; i < idxs.size(); i += vec){
indices_t idx = idxs[i];
// pointer value
Value *ptr = bit_cast(vals_[op][idx], ptr_ty(vec_ty(ty, vec), space));
Value *ptr = vals_[op][idx];
// masked load
Value *ret = nullptr;
if(mx){
// if mask:
// ret = load(ptr)
// else:
// ret = false_value
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
Instruction *then_term;
Instruction *else_term;
builder_->SetInsertPoint(_ret->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
dummy->removeFromParent();
builder_->SetInsertPoint(then_term);
Value* then_ret = load(ptr);
builder_->SetInsertPoint(else_term);
Value* else_ret = splat(vec, vals_[mx->get_false_value_operand()][idx]);
builder_->SetInsertPoint(_ret->getParent());
_ret->addIncoming(then_ret, then_term->getParent());
_ret->addIncoming(else_ret, else_term->getParent());
ret = (Value*)_ret;
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
// input ptr info
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
Value *in_base = in_gep->getPointerOperand();
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
in_base = cst ? in_base : in_gep;
Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
size_t nbits = dtsize*8;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// -----
// create inline asm string
// -----
std::ostringstream asm_oss;
asm_oss << "@$" << n_words; // predicate
asm_oss << " ld.global.cg";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
asm_oss << " {";
for(int i = 0; i < n_words; i++){ // return values
if(i > 0) asm_oss << ",";
asm_oss << "$" << i;
}
else
ret = load(ptr);
// write back
asm_oss << "}";
asm_oss << ", [ $" << n_words + 1; // load
asm_oss << " + " << in_off << "];"; // constant offset
bool has_other = other && (other != UndefValue::get(other->getType()));
std::vector<Value *> others;
// handle `other` values for indices where the mask
// is false
if(has_other)
for(size_t ii = 0; ii < n_words; ii++){
size_t size = width / nbits;
Value *v = UndefValue::get(vec_ty(ty, size));
for(size_t s = 0; s < size; s++){
ir::value *false_val = mx->get_false_value_operand();
v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
}
v = bit_cast(v, IntegerType::get(*ctx_, width));
asm_oss << "\n ";
asm_oss << "@!$" << n_words << " mov.u" << width;
asm_oss << " $" << ii << ", ";
std::ios_base::fmtflags flags(asm_oss.flags());
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
asm_oss << "0x" << std::hex << cst->getSExtValue();
else{
asm_oss << "$" << n_words + 2 + ii;
others.push_back(v);
}
asm_oss.flags(flags);
asm_oss << ";";
}
// ----
// create inline ASM signature
// ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
// ---
// create inline ASM constraints
// ---
std::string asm_cstrt;
for(int ii = 0; ii < n_words; ii++){
if(ii > 0) asm_cstrt += ",";
asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
}
asm_cstrt += ",b,l";
for(size_t ii = 0; ii < others.size(); ii++){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
// ---
// finally call inline ASM
// ---
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, in_base};
for(Value *v: others)
args.push_back(v);
Value *_ret = call(_asm, args);
// ---
// extract and store return values
// ---
std::vector<Value *> rets;
for(unsigned int ii = 0; ii < n_words; ii++){
Value *curr;
if(ret_ty->isStructTy())
curr = extract_val(_ret, {ii});
else
curr = _ret;
// std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl;
rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
}
int tmp = (width / (dtsize * 8));
for(size_t ii = 0; ii < vec; ii++)
vals_[x][idxs[i+ii]] = extract_elt(ret, ii);
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
}
}
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
visit_load_inst(x);
}
@@ -1703,7 +1783,10 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
int off = (off_1*shapes[in_order[0]] + off_0);
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(ptrs.find(key) == ptrs.end()){
builder_->SetInsertPoint(FirstBB->getTerminator());
if(FirstBB->getTerminator())
builder_->SetInsertPoint(FirstBB->getTerminator());
else
builder_->SetInsertPoint(FirstBB);
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
phase = urem(phase, i32(max_phase));

View File

@@ -81,7 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
std::cout << n_reg << std::endl;
// std::cout << n_reg << std::endl;
if (shared_optin > 49152){
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);