[BACKEND] Compiler improvements (#557)

This PR adds several optimization capabilities in the compiler backend:
- Now using inline PTX for `tl.store`, making it possible to use things like evict_last
- For A100, mma layout can be directly converted to shared memory
- For A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col- major.
- Fixed liveness analysis; this was broken.
- Now can load/store directly mma layout without converting. Useful for when tl.dot accumulator is initialized with DRAM data inside of an inner loop.
- `tl.dot` can now take LHS inputs in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
This commit is contained in:
Philippe Tillet
2022-06-27 11:49:19 -07:00
committed by GitHub
parent 87413bc925
commit 5b4c8f221e
25 changed files with 882 additions and 284 deletions

View File

@@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){
BasicBlock *current = builder_->GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, *builder_, 0);
Value *lane = urem(tid, i32(32));
ir::value *op = x->get_pointer_operand();
ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
// compute vector width
size_t vec = 1;
bool is_mma_first_row = false;
if(op->get_type()->is_block_ty()){
auto ord = ords_.at(op);
size_t aln = alignment_->get(op, ord[0]);
@@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
auto layout = layouts_->get(x)->to_scanline();
if(layout){
size_t nts = layout->nts(ord[0]);
vec = std::min(nts, aln);
}
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
assert(layout);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
// code generation
auto idxs = idxs_.at(x);
@@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
has_evict_policy = false; // currently disable until supported in `store`
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
// has_evict_policy = false; // currently disable until supported in `store`
// -----
// create inline asm string
// -----
@@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){
if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
if (has_evict_policy) asm_oss << ".L2::cache_hint";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
@@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_oss << "}";
asm_oss << ", [ $" << n_words + 1; // load
asm_oss << " + " << in_off << "]"; // constant offset
if (has_evict_policy) asm_oss << ", $" << n_words + 2;
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
bool has_other = other && (other != UndefValue::get(other->getType()));
std::vector<Value *> others;
@@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
asm_oss << "0x" << std::hex << cst->getSExtValue();
else{
asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
others.push_back(v);
}
asm_oss.flags(flags);
@@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
if (has_evict_policy)
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
// ---
@@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_evict_policy)
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
@@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){
std::vector<Value*> args = {pred, ptr};
for(Value *v: others)
args.push_back(v);
if (has_evict_policy)
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
@@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){
// operands
ir::value *ptr_op = x->get_pointer_operand();
ir::value *val_op = x->get_value_operand();
ir::value *msk_op = nullptr;
if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
msk_op = msk_st->get_mask_operand();
// vector size
size_t vec = 1;
if(val_op->get_type()->is_block_ty()){
@@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){
max_eq = std::max<size_t>(max_eq, 1);
aln = std::min(aln, max_eq);
}
vec = std::min(nts, aln);
analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
assert(layout);
// vec = std::min(nts, aln);
vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
// TODO: generalize
bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
(a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
if(is_mma_first_row)
vec = std::min<size_t>(2, aln);
}
bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
if(ty->isIntegerTy(1))
ty = builder_->getInt8Ty();
for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i];
// pointer
indices_t idx = idxs[i];
// pointers
Value *ptr = vals_[ptr_op][idx];
// vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value
Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
builder_->SetInsertPoint(no_op->getParent());
Instruction* dummy = builder_->CreateRet(nullptr);
Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
dummy->removeFromParent();
builder_->SetInsertPoint(term);
store(val, ptr);
builder_->SetInsertPoint(no_op);
size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
size_t in_off;
if(in_gep){
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
ptr = cst ? in_gep->getPointerOperand() : in_gep;
}
else
store(val, ptr);
else{
in_off = 0;
}
// mask
Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
size_t nbits = dtsize*8;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
// and there are (nbits * vec)/width of them
int max_word_width = std::max<int>(32, nbits);
int tot_width = nbits*vec;
int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width);
// -----
// create inline asm string
// -----
std::ostringstream asm_oss;
asm_oss << "@$0"; // predicate
asm_oss << " st.global";
if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
if(n_words > 1)
asm_oss << ".v" << n_words; // vector width
asm_oss << ".b" << width; // word size
asm_oss << " [ $1 + " << in_off << "]";
asm_oss << " , {";
for(int i = 0; i < n_words; i++){ // return values
if(i > 0) asm_oss << ",";
asm_oss << "$" << 2 + i;
}
asm_oss << "}";
if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
asm_oss << ";";
// ----
// create inline ASM signature
// ---
Type* val_arg_ty = IntegerType::get(*ctx_, width);
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(int ii = 0; ii < n_words; ii++)
arg_tys.push_back(val_arg_ty);
if (has_l2_evict_policy)
arg_tys.push_back(i64_ty);
FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
// ---
// create inline ASM constraints
// ---
std::string asm_cstrt = "b,l";
for(int ii = 0; ii < n_words; ii++){
asm_cstrt += ",";
asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
}
if (has_l2_evict_policy)
asm_cstrt += ",l";
// ---
// finally call inline ASM
// ---
InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
std::vector<Value*> args = {pred, ptr};
for(unsigned int ii = 0; ii < n_words; ii++){
size_t n_subw = width / nbits;
Value* curr = UndefValue::get(vec_ty(ty, n_subw));
for(unsigned int jj = 0; jj < n_subw; jj++){
Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
if(new_elt->getType()->isIntegerTy(1))
new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
new_elt = bit_cast(new_elt, ty);
curr = builder_->CreateInsertElement(curr, new_elt, jj);
}
args.push_back(bit_cast(curr, val_arg_ty));
}
if (has_l2_evict_policy)
args.push_back(policies_.at(x->get_eviction_policy()));
call(_asm, args);
}
}
void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
@@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
for(auto idx: idxs_.at(x)){
Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
// Value *ex2arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
}
}
@@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
// order
auto ord_a = layouts_->get(A)->get_order();
auto ord_b = layouts_->get(B)->get_order();
bool is_a_trans = C->is_trans_a();
// is_a_trans = false;
if(C->is_trans_a()){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
std::swap(offset_a_m_, offset_a_k_);
}
// std::cout << "visiting" << std::endl;
// if(C->is_trans_b()){
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
// layouts
analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
@@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);
// max_phase_a = 4;
// vec_a = 8;
// std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
// std::cout << vec_a << " " << vec_b << std::endl;
/* --------------------------------- */
/* --- pre-compute pointer lanes --- */
/* --------------------------------- */
@@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
auto shape_a = A->get_type()->get_block_shapes();
auto shape_b = B->get_type()->get_block_shapes();
auto ord_a = layouts_->get(A)->get_order();
if(C->is_trans_a()){
std::swap(ord_a[0], ord_a[1]);
std::swap(shape_a[0], shape_a[1]);
}
auto ord_b = layouts_->get(B)->get_order();
if(C->is_trans_b()){
std::swap(ord_b[0], ord_b[1]);
std::swap(shape_b[0], shape_b[1]);
}
NK = shape_a[1];
analysis::mma_layout* layout = layouts_->get(C)->to_mma();
analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
@@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
@@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
if(FirstBB != CurrBB)
// if true, this will move pointer declarations to the entry basic block
// not prefetched cases tend to be more limited in resource usage
// so we don't pre-compute ptrs to save registers
bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
@@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{mn, k}] = val;
};
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
std::function<void(int,int,int,bool)> load_a;
analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
bool is_a_shared = layout_a != nullptr;
if(is_a_shared) {
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
// pointers
std::vector<Value*> ptrs_a(num_ptr_a);
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
if(licm_ptrs)
builder_->SetInsertPoint(FirstBB->getTerminator());
// loading function
load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
}
else {
load_a = [&](int m, int k, int inc, bool is_prefetch) {
distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
int ldm = ax_n.values.size();
if(ldm != num_rep_k*4)
throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
// std::cout << m << " " << k << std::endl;
// std::cout << idxs_[A].size() << std::endl;
// std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
// int ldm = num_rep_k*4;
Value* ha0 = UndefValue::get(fp16x2_ty);
Value* ha1 = UndefValue::get(fp16x2_ty);
Value* ha2 = UndefValue::get(fp16x2_ty);
Value* ha3 = UndefValue::get(fp16x2_ty);
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
ha[{m, k}] = ha0;
ha[{m+1, k}] = ha1;
ha[{m, k+1}] = ha2;
ha[{m+1, k+1}] = ha3;
};
}
// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
int k_order_b = 0;
// if(C->is_trans_b()){
// std::swap(mma_instr_b[0], mma_instr_b[1]);
// std::swap(mat_shape_b[0], mat_shape_b[1]);
// k_order_b = k_order_b ^ 1;
// std::swap(ord_b[0], ord_b[1]);
// std::swap(shape_b[0], shape_b[1]);
// }
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
mma_instr_b, mat_shape_b,
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();
builder_->SetInsertPoint(CurrBB);
// A pointer
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer
if(licm_ptrs)
builder_->SetInsertPoint(CurrBB);
// pointers
int num_ptr_b = b_loader.get_num_ptr();
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
// loading function
std::function<void(int,int,int,bool)> load_b;
load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_n * 2;
std::vector<size_t> idx = {
(m + 0)*cols_per_thread + (n*2 + 0),
@@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
};
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{mn, k}] = val;
};
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
@@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
i = 0;
vals_[C][idx] = fcs.at(key)[i++];
};
}
/**
@@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
} else if (layout->to_mma()) {
shuffle_width = 4;
warps_per_inner = layout->to_mma()->wpt(1);
col_per_thread = 16;
col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
}
assert(warp_j != nullptr);
@@ -2403,7 +2566,8 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
Value* is_warp0 = icmp_eq(warp, i32(0));
Value* is_thread0 = icmp_eq(thread, i32(0));
Value* lane_j = urem(lane, i32(shuffle_width));
add_barrier();
if(warps_per_inner > 1)
add_barrier();
// compute partial sum for each warp, and store to shared memory
for(size_t i = 0; i < n_elts/col_per_thread; i++){
std::pair<Value*, Value*> acc;
@@ -2425,13 +2589,21 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
// store partial result to shared memory
auto x_idxs = idxs_[x][i];
Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
call(st_shared_index,
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
// single warp on the reduce dimension -- no need to use shmem
if(warps_per_inner==1){
vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
}
else{
Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
if (with_index) {
call(st_shared_index,
{icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
}
}
}
if(warps_per_inner==1)
return;
add_barrier();
// at this point, partial accumulator synchronized in shared memory
// Just need to reduce `warp_per_inner` numbers in shared memory
@@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
case ir::reduce_inst::FMAX: return max_num(x, y);
case ir::reduce_inst::FMIN: return min_num(x, y);
case ir::reduce_inst::XOR: return xor_(x, y);
default: throw std::runtime_error("unreachable");
}
};
@@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
Value *base;
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
// std::cout << off << std::endl;
base = gep(shmem_, i32(off));
base = bit_cast(base, ptr_ty(ty, 3));
std::vector<int> n_reps;
for(int i = 0; i < shape.size(); i++){
@@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
//
int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
if(in_layout->to_mma()){
mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
per_phase = 1;
max_phase = 8;
}
int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
if(in_layout->to_mma()){
n_shared_0 = 8;
n_shared_1 = 1;
}
BasicBlock* CurrBB = builder_->GetInsertBlock();
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
auto shapes = cts->get_type()->get_block_shapes();
// store to shared
Value *current = nullptr;
std::map<std::pair<int, int>, Value*> ptrs;
@@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
// input ptr info
int id_0 = id % (in_ld/min_vec);
int id_1 = id / (in_ld/min_vec);
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
int off = (off_1*shapes[in_order[0]] + off_0);
// std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
if(ptrs.find(key) == ptrs.end()){
if(FirstBB->getTerminator())
@@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
builder_->SetInsertPoint(CurrBB);
ptrs[key] = gep(shmems_.at(cts), {off});
}
int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
if(in_layout->to_mma()){
off_0 = id_0/n_shared_0*n_shared_0*8;
off_1 = id_1/n_shared_1*n_shared_1*8;
}
int off = (off_1*shapes[in_order[0]] + off_0);
Value* ptr = gep(ptrs[key], {i32(off)});
ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
// asm
@@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) {
if(tgt_->as_nvidia()->sm() >= 80)
for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
policies_[evict] = call(iasm);
}
@@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) {
BasicBlock *parent = bbs_[block];
builder_->SetInsertPoint(parent);
for(ir::instruction *i: block->get_inst_list()){
// i->print(std::cout);
visit_value(i);
// std::cout << "done" << std::endl;
}
@@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) {
std::vector<distributed_axis> axes(rank);
std::vector<int> ord(rank);
// compute axes
// std::cout << "axes" << std::endl;
for(size_t d = 0; d < shapes.size(); d++){
// std::cout << d << " " << shapes[d] << std::endl;
// std::cout << a_axes_->get(v, d) << std::endl;
if(shapes[d] > 1){
unsigned x = a_axes_->get(v, d);
axes[d] = axes_.at(x);
@@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) {
axes[d].values = {i32(0)};
}
}
// std::cout << "axes ok" << std::endl;
// compute order
analysis::data_layout* layout = layouts_->get(v);
std::iota(ord.begin(), ord.end(), 0);
@@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) {
for(indices_t idx: idxs_.at(x)){
PHINode *phi = (PHINode*)vals_[x][idx];
Value *inc = vals_[x->get_incoming_value(n)][idx];
// x->print(std::cout);
phi->addIncoming(inc, block);
}
}