[CODEGEN] Switching to predicated inline PTX for LDGs (#103)
This commit is contained in:
committed by
Philippe Tillet
parent
ac57812bdc
commit
1e844ba78d
@@ -1,4 +1,6 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
@@ -530,8 +532,6 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
ir::value *op = x->get_pointer_operand();
|
||||
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
|
||||
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
|
||||
int space = op->get_type()->get_scalar_ty()->get_pointer_address_space();
|
||||
|
||||
// compute vector width
|
||||
size_t vec = 1;
|
||||
if(op->get_type()->is_block_ty()){
|
||||
@@ -540,43 +540,123 @@ void generator::visit_load_inst(ir::load_inst* x){
|
||||
size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]);
|
||||
vec = std::min(nts, aln);
|
||||
}
|
||||
|
||||
// code generation
|
||||
auto idxs = idxs_.at(x);
|
||||
for(size_t i = 0; i < idxs.size(); i += vec){
|
||||
indices_t idx = idxs[i];
|
||||
// pointer value
|
||||
Value *ptr = bit_cast(vals_[op][idx], ptr_ty(vec_ty(ty, vec), space));
|
||||
Value *ptr = vals_[op][idx];
|
||||
// masked load
|
||||
Value *ret = nullptr;
|
||||
if(mx){
|
||||
// if mask:
|
||||
// ret = load(ptr)
|
||||
// else:
|
||||
// ret = false_value
|
||||
PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2);
|
||||
Instruction *then_term;
|
||||
Instruction *else_term;
|
||||
builder_->SetInsertPoint(_ret->getParent());
|
||||
Instruction* dummy = builder_->CreateRet(nullptr);
|
||||
llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term);
|
||||
dummy->removeFromParent();
|
||||
builder_->SetInsertPoint(then_term);
|
||||
Value* then_ret = load(ptr);
|
||||
builder_->SetInsertPoint(else_term);
|
||||
Value* else_ret = splat(vec, vals_[mx->get_false_value_operand()][idx]);
|
||||
builder_->SetInsertPoint(_ret->getParent());
|
||||
_ret->addIncoming(then_ret, then_term->getParent());
|
||||
_ret->addIncoming(else_ret, else_term->getParent());
|
||||
ret = (Value*)_ret;
|
||||
size_t dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
// input ptr info
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
|
||||
Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
|
||||
Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
|
||||
size_t nbits = dtsize*8;
|
||||
// pack sub-words (< 32/64bits) into words
|
||||
// each load has width min(nbits*vec, 32/64)
|
||||
// and there are (nbits * vec)/width of them
|
||||
int max_word_width = std::max<int>(32, nbits);
|
||||
int tot_width = nbits*vec;
|
||||
int width = std::min(tot_width, max_word_width);
|
||||
int n_words = std::max(1, tot_width / width);
|
||||
// -----
|
||||
// create inline asm string
|
||||
// -----
|
||||
std::ostringstream asm_oss;
|
||||
asm_oss << "@$" << n_words; // predicate
|
||||
asm_oss << " ld.global.cg";
|
||||
if(n_words > 1)
|
||||
asm_oss << ".v" << n_words; // vector width
|
||||
asm_oss << ".b" << width; // word size
|
||||
asm_oss << " {";
|
||||
for(int i = 0; i < n_words; i++){ // return values
|
||||
if(i > 0) asm_oss << ",";
|
||||
asm_oss << "$" << i;
|
||||
}
|
||||
else
|
||||
ret = load(ptr);
|
||||
// write back
|
||||
asm_oss << "}";
|
||||
asm_oss << ", [ $" << n_words + 1; // load
|
||||
asm_oss << " + " << in_off << "];"; // constant offset
|
||||
bool has_other = other && (other != UndefValue::get(other->getType()));
|
||||
std::vector<Value *> others;
|
||||
// handle `other` values for indices where the mask
|
||||
// is false
|
||||
if(has_other)
|
||||
for(size_t ii = 0; ii < n_words; ii++){
|
||||
size_t size = width / nbits;
|
||||
Value *v = UndefValue::get(vec_ty(ty, size));
|
||||
for(size_t s = 0; s < size; s++){
|
||||
ir::value *false_val = mx->get_false_value_operand();
|
||||
v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
|
||||
}
|
||||
v = bit_cast(v, IntegerType::get(*ctx_, width));
|
||||
asm_oss << "\n ";
|
||||
asm_oss << "@!$" << n_words << " mov.u" << width;
|
||||
asm_oss << " $" << ii << ", ";
|
||||
std::ios_base::fmtflags flags(asm_oss.flags());
|
||||
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
|
||||
asm_oss << "0x" << std::hex << cst->getSExtValue();
|
||||
else{
|
||||
asm_oss << "$" << n_words + 2 + ii;
|
||||
others.push_back(v);
|
||||
}
|
||||
asm_oss.flags(flags);
|
||||
asm_oss << ";";
|
||||
}
|
||||
// ----
|
||||
// create inline ASM signature
|
||||
// ---
|
||||
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
|
||||
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
|
||||
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
|
||||
for(Value *v: others)
|
||||
arg_tys.push_back(v->getType());
|
||||
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
|
||||
// ---
|
||||
// create inline ASM constraints
|
||||
// ---
|
||||
std::string asm_cstrt;
|
||||
for(int ii = 0; ii < n_words; ii++){
|
||||
if(ii > 0) asm_cstrt += ",";
|
||||
asm_cstrt += (width == 64) ? "=l" : ((width == 32) ? "=r" : "=c");
|
||||
}
|
||||
asm_cstrt += ",b,l";
|
||||
for(size_t ii = 0; ii < others.size(); ii++){
|
||||
asm_cstrt += ",";
|
||||
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
|
||||
}
|
||||
// ---
|
||||
// finally call inline ASM
|
||||
// ---
|
||||
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
|
||||
std::vector<Value*> args = {pred, in_base};
|
||||
for(Value *v: others)
|
||||
args.push_back(v);
|
||||
Value *_ret = call(_asm, args);
|
||||
// ---
|
||||
// extract and store return values
|
||||
// ---
|
||||
std::vector<Value *> rets;
|
||||
for(unsigned int ii = 0; ii < n_words; ii++){
|
||||
Value *curr;
|
||||
if(ret_ty->isStructTy())
|
||||
curr = extract_val(_ret, {ii});
|
||||
else
|
||||
curr = _ret;
|
||||
// std::cout << n_words << " " << vec << " " << width << " " << dtsize << " " << nbits << std::endl;
|
||||
rets.push_back(bit_cast(curr, vec_ty(ty, width / (dtsize*8))));
|
||||
}
|
||||
int tmp = (width / (dtsize * 8));
|
||||
for(size_t ii = 0; ii < vec; ii++)
|
||||
vals_[x][idxs[i+ii]] = extract_elt(ret, ii);
|
||||
vals_[x][idxs[i+ii]] = extract_elt(rets[ii/tmp], ii % tmp);
|
||||
}
|
||||
}
|
||||
|
||||
void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
||||
visit_load_inst(x);
|
||||
}
|
||||
@@ -1703,7 +1783,10 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
int off = (off_1*shapes[in_order[0]] + off_0);
|
||||
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
|
||||
if(ptrs.find(key) == ptrs.end()){
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
if(FirstBB->getTerminator())
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
else
|
||||
builder_->SetInsertPoint(FirstBB);
|
||||
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
|
||||
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
|
||||
phase = urem(phase, i32(max_phase));
|
||||
|
@@ -81,7 +81,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
|
||||
dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, *cu_);
|
||||
dispatch::cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, *cu_);
|
||||
std::cout << n_reg << std::endl;
|
||||
// std::cout << n_reg << std::endl;
|
||||
if (shared_optin > 49152){
|
||||
// std::cout << "dynamic shared memory " << shared_optin << " " << shared_static << std::endl;
|
||||
dispatch::cuFuncSetAttribute(*cu_, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||
|
Reference in New Issue
Block a user