[BACKEND] Compiler improvements (#557)
This PR adds several optimization capabilities to the compiler backend:
- `tl.store` is now lowered through inline PTX, making it possible to use modifiers like `evict_last`.
- On A100, the mma layout can be converted directly to shared memory.
- On A100, an additional "transpose" argument in `dot` allows tensors to be loaded once and used both row- and col-major.
- Fixed liveness analysis, which was broken.
- Values in mma layout can now be loaded/stored directly, without a layout conversion. Useful when a `tl.dot` accumulator is initialized with DRAM data inside an inner loop.
- `tl.dot` can now take its LHS input in registers when it comes from a previous `tl.dot` instruction. Useful for e.g. fused attention.
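For context, a minimal sketch of how these capabilities surface in the Triton DSL. The keyword names used here (`trans_b`, `eviction_policy`) are assumptions based on the description above, not something this diff confirms:

    import triton
    import triton.language as tl

    @triton.jit
    def attention_like(Q, K, V, O, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        q = tl.load(Q + offs[:, None] * BLOCK + offs[None, :])
        k = tl.load(K + offs[:, None] * BLOCK + offs[None, :])
        v = tl.load(V + offs[:, None] * BLOCK + offs[None, :])
        # "transpose" argument: k is loaded once, used col-major here
        s = tl.dot(q, k, trans_b=True)
        # the LHS of the second dot comes straight from the previous
        # dot's accumulator registers -- no shared-memory round-trip
        p = s.to(tl.float16)
        acc = tl.dot(p, v)
        # eviction policies are now honored by stores via inline PTX
        tl.store(O + offs[:, None] * BLOCK + offs[None, :], acc,
                 eviction_policy="evict_last")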
@@ -744,11 +744,13 @@ void generator::visit_load_inst(ir::load_inst* x){
   BasicBlock *current = builder_->GetInsertBlock();
   Module *module = current->getModule();
   Value *tid = tgt_->get_local_id(module, *builder_, 0);
+  Value *lane = urem(tid, i32(32));
   ir::value *op = x->get_pointer_operand();
   ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
   Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());
   // compute vector width
   size_t vec = 1;
+  bool is_mma_first_row = false;
   if(op->get_type()->is_block_ty()){
     auto ord = ords_.at(op);
     size_t aln = alignment_->get(op, ord[0]);
@@ -757,11 +759,15 @@ void generator::visit_load_inst(ir::load_inst* x){
       max_eq = std::max<size_t>(max_eq, 1);
       aln = std::min(aln, max_eq);
     }
-    auto layout = layouts_->get(x)->to_scanline();
-    if(layout){
-      size_t nts = layout->nts(ord[0]);
-      vec = std::min(nts, aln);
-    }
+    analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(x));
+    assert(layout);
+
+    vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
+    // TODO: generalize
+    is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
+                       (a_axes_->get(x, ord[0]) == layouts_->get(x)->get_axis(1));
+    if(is_mma_first_row)
+      vec = std::min<size_t>(2, aln);
   }
   // code generation
   auto idxs = idxs_.at(x);
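The vector width is now derived from any distributed layout's per-thread contiguity instead of only from scanline layouts. A sketch of the selection logic (hypothetical helper, Python for brevity; the cap of 2 on the mma first row mirrors the hard-coded special case flagged by the TODO above):

    def load_vector_width(contig_per_thread, aln, is_mma_first_row):
        # vectorization is bounded by the layout's per-thread contiguity
        # and by the proven pointer alignment
        vec = min(contig_per_thread, aln)
        if is_mma_first_row:
            vec = min(2, aln)  # mma rows: threads own pairs of elements
        return vec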
@@ -795,8 +801,8 @@ void generator::visit_load_inst(ir::load_inst* x){
   int tot_width = nbits*vec;
   int width = std::min(tot_width, max_word_width);
   int n_words = std::max(1, tot_width / width);
-  bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
-  has_evict_policy = false; // currently disable until supported in `store`
+  bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
+  // has_evict_policy = false; // currently disable until supported in `store`
   // -----
   // create inline asm string
   // -----
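The L2 policy was previously force-disabled on loads because stores could not match it; with the new inline-PTX store path it is re-enabled (and renamed `has_l2_evict_policy`). For reference, the word-packing arithmetic above, restated as a sketch (hypothetical helper, Python for brevity):

    def pack_words(nbits, vec):
        # a vectorized access of `vec` elements of `nbits` bits each is
        # emitted as `n_words` machine words of `width` bits
        max_word_width = max(32, nbits)
        tot_width = nbits * vec
        width = min(tot_width, max_word_width)
        n_words = max(1, tot_width // width)
        return width, n_words

    print(pack_words(16, 8))  # (32, 4): eight f16 -> one ld.global.v4.b32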
@@ -810,7 +816,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   if (x->get_cache_modifier() == ir::load_inst::CG) asm_oss << ".cg";
   if (x->get_eviction_policy() == ir::load_inst::EVICT_FIRST) asm_oss << ".L1::evict_first";
   if (x->get_eviction_policy() == ir::load_inst::EVICT_LAST) asm_oss << ".L1::evict_last";
-  if (has_evict_policy) asm_oss << ".L2::cache_hint";
+  if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
   if(n_words > 1)
     asm_oss << ".v" << n_words; // vector width
   asm_oss << ".b" << width; // word size
@@ -822,7 +828,7 @@ void generator::visit_load_inst(ir::load_inst* x){
     asm_oss << "}";
     asm_oss << ", [ $" << n_words + 1; // load
     asm_oss << " + " << in_off << "]"; // constant offset
-    if (has_evict_policy) asm_oss << ", $" << n_words + 2;
+    if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
     asm_oss << ";";
     bool has_other = other && (other != UndefValue::get(other->getType()));
     std::vector<Value *> others;
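Tracing the string construction above: for a 4-word b32 load with `evict_last` on sm_80, the emitted PTX should look like the following reconstruction ($0-$3 results, $4 predicate, $5 pointer, $6 the L2 policy handle; the operand layout is inferred from the offsets in the code, not printed by it):

    # reconstructed example of the load the stream above produces
    ptx = ("@$4 ld.global.L1::evict_last.L2::cache_hint.v4.b32 "
           "{$0,$1,$2,$3}, [ $5 + 0], $6;")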
@@ -844,7 +850,7 @@ void generator::visit_load_inst(ir::load_inst* x){
       if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
         asm_oss << "0x" << std::hex << cst->getSExtValue();
       else{
-        asm_oss << "$" << n_words + has_evict_policy + 2 + ii;
+        asm_oss << "$" << n_words + has_l2_evict_policy + 2 + ii;
         others.push_back(v);
       }
       asm_oss.flags(flags);
@@ -859,7 +865,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
   for(Value *v: others)
     arg_tys.push_back(v->getType());
-  if (has_evict_policy)
+  if (has_l2_evict_policy)
     arg_tys.push_back(i64_ty);
   FunctionType *asm_ty = FunctionType::get(ret_ty, arg_tys, false);
   // ---
@@ -875,7 +881,7 @@ void generator::visit_load_inst(ir::load_inst* x){
     asm_cstrt += ",";
     asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
   }
-  if (has_evict_policy)
+  if (has_l2_evict_policy)
     asm_cstrt += ",l";
   // ---
   // finally call inline ASM
@@ -884,7 +890,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   std::vector<Value*> args = {pred, ptr};
   for(Value *v: others)
     args.push_back(v);
-  if (has_evict_policy)
+  if (has_l2_evict_policy)
     args.push_back(policies_.at(x->get_eviction_policy()));

@@ -935,6 +941,9 @@ void generator::visit_store_inst(ir::store_inst * x){
   // operands
   ir::value *ptr_op = x->get_pointer_operand();
   ir::value *val_op = x->get_value_operand();
+  ir::value *msk_op = nullptr;
+  if(auto* msk_st = dynamic_cast<ir::masked_store_inst*>(x))
+    msk_op = msk_st->get_mask_operand();
   // vector size
   size_t vec = 1;
   if(val_op->get_type()->is_block_ty()){
@@ -946,36 +955,107 @@ void generator::visit_store_inst(ir::store_inst * x){
       max_eq = std::max<size_t>(max_eq, 1);
       aln = std::min(aln, max_eq);
     }
-    vec = std::min(nts, aln);
+    analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(ptr_op));
+    assert(layout);
+    // vec = std::min(nts, aln);
+    vec = std::min<size_t>(layout->contig_per_thread(ord[0]), aln);
+    // TODO: generalize
+    bool is_mma_first_row = (ord.size() >= 1) && layout->to_mma() &&
+                            (a_axes_->get(ptr_op, ord[0]) == layouts_->get(ptr_op)->get_axis(1));
+    if(is_mma_first_row)
+      vec = std::min<size_t>(2, aln);
   }
+  bool has_l2_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
   auto idxs = idxs_.at(val_op);
   Type *ty = cvt(val_op->get_type()->get_scalar_ty());
   if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
     ty = f16_ty;
   if(ty->isIntegerTy(1))
     ty = builder_->getInt8Ty();
   for(size_t i = 0; i < idxs.size(); i += vec){
-    auto idx = idxs[i];
-    // pointer
+    indices_t idx = idxs[i];
+    // pointers
     Value *ptr = vals_[ptr_op][idx];
-    // vectorize
-    Type *v_ty = vec_ty(ty, vec);
-    ptr = bit_cast(ptr, v_ty->getPointerTo(1));
-    // value
-    Value* val = UndefValue::get(v_ty);
-    for(size_t ii = 0; ii < vec; ii++)
-      val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
-    if(mx){
-      Value *msk = vals_[mx->get_mask_operand()][idx];
-      Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
-      builder_->SetInsertPoint(no_op->getParent());
-      Instruction* dummy = builder_->CreateRet(nullptr);
-      Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false);
-      dummy->removeFromParent();
-      builder_->SetInsertPoint(term);
-      store(val, ptr);
-      builder_->SetInsertPoint(no_op);
+    size_t dtsize = std::max<int>(1, val_op->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8);
+    GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(ptr);
+    size_t in_off;
+    if(in_gep){
+      ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
+      in_off = cst ? cst->getValue().getSExtValue()*dtsize : 0;
+      ptr = cst ? in_gep->getPointerOperand() : in_gep;
     }
-    else
-      store(val, ptr);
+    else{
+      in_off = 0;
+    }
+    // mask
+    Value *pred = msk_op ? vals_[msk_op][idx] : builder_->getTrue();
+    size_t nbits = dtsize*8;
+    // pack sub-words (< 32/64bits) into words
+    // each load has width min(nbits*vec, 32/64)
+    // and there are (nbits * vec)/width of them
+    int max_word_width = std::max<int>(32, nbits);
+    int tot_width = nbits*vec;
+    int width = std::min(tot_width, max_word_width);
+    int n_words = std::max(1, tot_width / width);
+    // -----
+    // create inline asm string
+    // -----
+    std::ostringstream asm_oss;
+    asm_oss << "@$0"; // predicate
+    asm_oss << " st.global";
+    if (has_l2_evict_policy) asm_oss << ".L2::cache_hint";
+    if(n_words > 1)
+      asm_oss << ".v" << n_words; // vector width
+    asm_oss << ".b" << width; // word size
+    asm_oss << " [ $1 + " << in_off << "]";
+    asm_oss << " , {";
+    for(int i = 0; i < n_words; i++){ // return values
+      if(i > 0) asm_oss << ",";
+      asm_oss << "$" << 2 + i;
+    }
+    asm_oss << "}";
+    if (has_l2_evict_policy) asm_oss << ", $" << n_words + 2;
+    asm_oss << ";";
+    // ----
+    // create inline ASM signature
+    // ---
+    Type* val_arg_ty = IntegerType::get(*ctx_, width);
+    std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
+    for(int ii = 0; ii < n_words; ii++)
+      arg_tys.push_back(val_arg_ty);
+    if (has_l2_evict_policy)
+      arg_tys.push_back(i64_ty);
+    FunctionType *asm_ty = FunctionType::get(builder_->getVoidTy(), arg_tys, false);
+    // ---
+    // create inline ASM constraints
+    // ---
+    std::string asm_cstrt = "b,l";
+    for(int ii = 0; ii < n_words; ii++){
+      asm_cstrt += ",";
+      asm_cstrt += (width == 64) ? "l" : ((width == 32) ? "r" : "c");
+    }
+    if (has_l2_evict_policy)
+      asm_cstrt += ",l";
+    // ---
+    // finally call inline ASM
+    // ---
+    InlineAsm *_asm = InlineAsm::get(asm_ty, asm_oss.str(), asm_cstrt, true);
+    std::vector<Value*> args = {pred, ptr};
+    for(unsigned int ii = 0; ii < n_words; ii++){
+      size_t n_subw = width / nbits;
+      Value* curr = UndefValue::get(vec_ty(ty, n_subw));
+      for(unsigned int jj = 0; jj < n_subw; jj++){
+        Value* new_elt = vals_[val_op][idxs[i + ii*n_subw + jj]];
+        if(new_elt->getType()->isIntegerTy(1))
+          new_elt = builder_->CreateSExt(new_elt, builder_->getInt8Ty());
+        new_elt = bit_cast(new_elt, ty);
+        curr = builder_->CreateInsertElement(curr, new_elt, jj);
+      }
+      args.push_back(bit_cast(curr, val_arg_ty));
+    }
+    if (has_l2_evict_policy)
+      args.push_back(policies_.at(x->get_eviction_policy()));
+    call(_asm, args);
   }
 }

 void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) {
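The store path now mirrors the load path: masking is done with a PTX predicate instead of LLVM control flow, which is what makes cache/eviction modifiers usable on stores. Tracing the stream above for a two-word store with an L2 eviction policy gives a string like this reconstruction ($0 predicate, $1 pointer, $2-$3 data words, $4 policy handle):

    # reconstructed example of the emitted store string
    ptx = "@$0 st.global.L2::cache_hint.v2.b32 [ $1 + 0] , {$2,$3}, $4;"
    # matching constraints: predicate, pointer, two b32 words, i64 policy
    constraints = "b,l,r,r,l"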
@@ -1098,6 +1178,7 @@ void generator::visit_exp_inst(ir::exp_inst* x){
   InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $0;", "=f,0", false);
   for(auto idx: idxs_.at(x)){
     Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e);
+    // Value *ex2arg = vals_[x->get_operand(0)][idx];
     vals_[x][idx] = call(ex2, std::vector<llvm::Value*>{ex2arg});
   }
 }
@@ -1291,6 +1372,18 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   // order
   auto ord_a = layouts_->get(A)->get_order();
   auto ord_b = layouts_->get(B)->get_order();
+  bool is_a_trans = C->is_trans_a();
+  // is_a_trans = false;
+  if(C->is_trans_a()){
+    std::swap(ord_a[0], ord_a[1]);
+    std::swap(shape_a[0], shape_a[1]);
+    std::swap(offset_a_m_, offset_a_k_);
+  }
+  // std::cout << "visiting" << std::endl;
+  // if(C->is_trans_b()){
+  //   std::swap(ord_b[0], ord_b[1]);
+  //   std::swap(shape_b[0], shape_b[1]);
+  // }
   // layouts
   analysis::mma_layout* layout_c = layouts_->get(C)->to_mma();
   analysis::shared_layout* layout_a = layouts_->get(A)->to_shared();
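This is the sm_70 (mma884) half of the new "transpose" support: rather than materializing a transposed copy, the generator swaps the axis order, the logical shape, and the m/k base offsets of the shared-memory operand, so the same buffer is read in either orientation. Schematically (Python sketch with hypothetical names):

    def view_operand(ord_a, shape_a, off_m, off_k, is_trans):
        # a transposed read of the same shared buffer is just a swap of
        # the axis order, the logical shape, and the m/k base offsets
        if is_trans:
            ord_a = ord_a[::-1]
            shape_a = shape_a[::-1]
            off_m, off_k = off_k, off_m
        return ord_a, shape_a, off_m, off_k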
@@ -1322,6 +1415,12 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
   int step_b0 = is_b_row ? stride_rep_n : stride_rep_k;
   int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1);

+
+  // max_phase_a = 4;
+  // vec_a = 8;
+  // std::cout << per_phase_a << " " << max_phase_a << " " << step_a0 << " " << num_ptr_a << " " << stride_am << " " << stride_ak << " " << stride_a0 << " " << stride_a1 << std::endl;
+  // std::cout << vec_a << " " << vec_b << std::endl;
+
   /* --------------------------------- */
   /* --- pre-compute pointer lanes --- */
   /* --------------------------------- */
@@ -1916,12 +2015,17 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   auto shape_a = A->get_type()->get_block_shapes();
   auto shape_b = B->get_type()->get_block_shapes();
   auto ord_a = layouts_->get(A)->get_order();
+  if(C->is_trans_a()){
+    std::swap(ord_a[0], ord_a[1]);
+    std::swap(shape_a[0], shape_a[1]);
+  }
   auto ord_b = layouts_->get(B)->get_order();
+  if(C->is_trans_b()){
+    std::swap(ord_b[0], ord_b[1]);
+    std::swap(shape_b[0], shape_b[1]);
+  }
   NK = shape_a[1];
   analysis::mma_layout* layout = layouts_->get(C)->to_mma();
-  analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0));
-  analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
-  bool is_a_row = ord_a[0] == 1;
-  bool is_b_row = ord_b[0] == 1;

   std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
   const int mma_instr_m = mma_instr_shape[0];
@@ -1933,10 +2037,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   const int mat_shape_n = mat_shape[1];
   const int mat_shape_k = mat_shape[2];

-  const int per_phase_a = swizzle_->get_per_phase(layout_a);
-  const int max_phase_a = swizzle_->get_max_phase(layout_a);
-  const int per_phase_b = swizzle_->get_per_phase(layout_b);
-  const int max_phase_b = swizzle_->get_max_phase(layout_b);

   const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
   const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
@@ -2001,7 +2101,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::

   BasicBlock* CurrBB = builder_->GetInsertBlock();
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
-  if(FirstBB != CurrBB)
+
+  // if true, this will move pointer declarations to the entry basic block
+  // not prefetched cases tend to be more limited in resource usage
+  // so we don't pre-compute ptrs to save registers
+  bool licm_ptrs = C->is_prefetched() && (FirstBB != CurrBB);
+  if(licm_ptrs)
     builder_->SetInsertPoint(FirstBB->getTerminator());

   Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
@@ -2015,47 +2120,137 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
   size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;

+  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
+  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
+  auto register_lds2 =
+    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
+      if (k < 2 && is_prefetch) {
+        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
+      } else
+        vals[{mn, k}] = val;
+  };
+
   // | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
   // v (s0_0(0), s1_0(2), | *num_rep_k
   // m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
   // -----------
   //   *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
-  mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
-                                {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
-                                per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
-  std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
-  int num_ptr_a = a_loader.get_num_ptr();
+  std::function<void(int,int,int,bool)> load_a;
+  analysis::shared_layout* layout_a = layouts_->get(C->get_operand(0))->to_shared();
+  bool is_a_shared = layout_a != nullptr;
+  if(is_a_shared) {
+    const int per_phase_a = swizzle_->get_per_phase(layout_a);
+    const int max_phase_a = swizzle_->get_max_phase(layout_a);
+    mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
+                                  {mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
+                                  per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
+    std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
+    int num_ptr_a = a_loader.get_num_ptr();
+    // pointers
+    std::vector<Value*> ptrs_a(num_ptr_a);
+    if(licm_ptrs)
+      builder_->SetInsertPoint(CurrBB);
+    for(int i = 0; i < num_ptr_a; i++)
+      ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
+    if(licm_ptrs)
+      builder_->SetInsertPoint(FirstBB->getTerminator());
+    // loading function
+    load_a = [&,a_loader,ptrs_a,off_a](int m, int k, int inc, bool is_prefetch) mutable {
+      auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
+                                                   shared_next_ptr_[layout_a], off_a, ptrs_a,
+                                                   ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
+      register_lds2(ha, m,   k,   inc, ha0, is_prefetch);
+      register_lds2(ha, m+1, k,   inc, ha1, is_prefetch);
+      register_lds2(ha, m,   k+1, inc, ha2, is_prefetch);
+      register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
+    };
+  }
+  else {
+    load_a = [&](int m, int k, int inc, bool is_prefetch) {
+      distributed_axis ax_n = axes_.at(a_axes_->get(A, 1));
+      int ldm = ax_n.values.size();
+      if(ldm != num_rep_k*4)
+        throw std::runtime_error("Internal compiler error when trying to fuse matmuls!");
+      // std::cout << m << " " << k << std::endl;
+      // std::cout << idxs_[A].size() << std::endl;
+      // std::cout << (m+1)*ldm + k*2 + 3 << std::endl;
+      // int ldm = num_rep_k*4;
+      Value* ha0 = UndefValue::get(fp16x2_ty);
+      Value* ha1 = UndefValue::get(fp16x2_ty);
+      Value* ha2 = UndefValue::get(fp16x2_ty);
+      Value* ha3 = UndefValue::get(fp16x2_ty);
+      ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 0]], i32(0));
+      ha0 = builder_->CreateInsertElement(ha0, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 1]], i32(1));
+      ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 0]], i32(0));
+      ha1 = builder_->CreateInsertElement(ha1, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 1]], i32(1));
+      ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 2]], i32(0));
+      ha2 = builder_->CreateInsertElement(ha2, vals_[A][idxs_[A][(m+0)*ldm + k*2 + 3]], i32(1));
+      ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 2]], i32(0));
+      ha3 = builder_->CreateInsertElement(ha3, vals_[A][idxs_[A][(m+1)*ldm + k*2 + 3]], i32(1));
+      ha[{m, k}]     = ha0;
+      ha[{m+1, k}]   = ha1;
+      ha[{m, k+1}]   = ha2;
+      ha[{m+1, k+1}] = ha3;
+    };
+  }

   // | -> n (col-major)
   // v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
   // k s0_1(1), |                   | s1_1(3) | (stride in num of matrices(mat_stride_bn): wpt(1))
   // -----------
   //   *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
-  mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
-                                {mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
+  analysis::shared_layout* layout_b = layouts_->get(C->get_operand(1))->to_shared();
+  const int per_phase_b = swizzle_->get_per_phase(layout_b);
+  const int max_phase_b = swizzle_->get_max_phase(layout_b);
+  std::vector<int> mma_instr_b{mma_instr_k, mma_instr_n};
+  std::vector<int> mat_shape_b{mat_shape_k, mat_shape_n};
+  int k_order_b = 0;
+  // if(C->is_trans_b()){
+  //   std::swap(mma_instr_b[0], mma_instr_b[1]);
+  //   std::swap(mat_shape_b[0], mat_shape_b[1]);
+  //   k_order_b = k_order_b ^ 1;
+  //   std::swap(ord_b[0], ord_b[1]);
+  //   std::swap(shape_b[0], shape_b[1]);
+  // }
+
+  mma16816_smem_loader b_loader(layout->wpt(1), ord_b, k_order_b, shape_b,
+                                mma_instr_b, mat_shape_b,
                                 per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
   std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
-  int num_ptr_b = b_loader.get_num_ptr();

-  builder_->SetInsertPoint(CurrBB);
-  // A pointer
-  std::vector<Value*> ptrs_a(num_ptr_a);
-  for(int i = 0; i < num_ptr_a; i++)
-    ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
-  // B pointer
+  if(licm_ptrs)
+    builder_->SetInsertPoint(CurrBB);
+  // pointers
+  int num_ptr_b = b_loader.get_num_ptr();
   std::vector<Value*> ptrs_b(num_ptr_b);
   for(int i = 0; i < num_ptr_b; i++)
     ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);

-  InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
+  // loading function
+  std::function<void(int,int,int,bool)> load_b;
+  load_b = [&](int n, int k, int inc, bool is_prefetch) {
+    auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
+                                                 shared_next_ptr_[layout_b], off_b, ptrs_b,
+                                                 ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
+    register_lds2(hb, n,   k,   inc, hb0, is_prefetch);
+    register_lds2(hb, n+1, k,   inc, hb2, is_prefetch);
+    register_lds2(hb, n,   k+1, inc, hb1, is_prefetch);
+    register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
+  };

+  // create mma & unpack result, m, n, k are offsets in mat
+  auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
+    InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
                                              " {$0, $1, $2, $3},"
                                              " {$4, $5, $6, $7},"
                                              " {$8, $9},"
                                              " {$10, $11, $12, $13};",
                                              "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);

-  // create mma & unpack result, m, n, k are offsets in mat
-  auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
     unsigned cols_per_thread = num_rep_n * 2;
     std::vector<size_t> idx = {
       (m + 0)*cols_per_thread + (n*2 + 0),
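The `else` branch of `load_a` above is what enables chained `tl.dot`s: when A has no shared-memory layout (because it is a previous dot's accumulator living in registers), the m16n8k16 A-fragments are assembled directly from the values the thread already holds. A sketch of the index arithmetic, mirroring the insert-element code (Python for brevity; `ldm = num_rep_k * 4` values per row pair):

    def a_fragment_indices(m, k, ldm):
        # each haX packs two consecutive f16 values into one 32-bit register
        ha0 = [(m+0)*ldm + k*2 + 0, (m+0)*ldm + k*2 + 1]
        ha1 = [(m+1)*ldm + k*2 + 0, (m+1)*ldm + k*2 + 1]
        ha2 = [(m+0)*ldm + k*2 + 2, (m+0)*ldm + k*2 + 3]
        ha3 = [(m+1)*ldm + k*2 + 2, (m+1)*ldm + k*2 + 3]
        return ha0, ha1, ha2, ha3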
@@ -2072,39 +2267,6 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
     fc[idx[2]] = extract_val(nc, std::vector<unsigned>{2});
     fc[idx[3]] = extract_val(nc, std::vector<unsigned>{3});
   };

-  ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
-  ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
-
-  auto register_lds2 =
-    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
-      if (k < 2 && is_prefetch) {
-        ir::basic_block* inc_block = phiA->get_incoming_block(inc);
-        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
-      } else
-        vals[{mn, k}] = val;
-  };
-
-  auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
-    auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
-                                                 shared_next_ptr_[layout_a], off_a, ptrs_a,
-                                                 ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
-    register_lds2(ha, m,   k,   inc, ha0, is_prefetch);
-    register_lds2(ha, m+1, k,   inc, ha1, is_prefetch);
-    register_lds2(ha, m,   k+1, inc, ha2, is_prefetch);
-    register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
-  };
-
-  auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
-    auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
-                                                 shared_next_ptr_[layout_b], off_b, ptrs_b,
-                                                 ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
-    register_lds2(hb, n,   k,   inc, hb0, is_prefetch);
-    register_lds2(hb, n+1, k,   inc, hb2, is_prefetch);
-    register_lds2(hb, n,   k+1, inc, hb1, is_prefetch);
-    register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
-  };
-
   if (C->is_prefetched()) {
     // create phis
     builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
@@ -2163,6 +2325,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
       i = 0;
     vals_[C][idx] = fcs.at(key)[i++];
   };
+
 }

 /**
@@ -2384,7 +2547,7 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
   } else if (layout->to_mma()) {
     shuffle_width = 4;
     warps_per_inner = layout->to_mma()->wpt(1);
-    col_per_thread = 16;
+    col_per_thread = axes_.at(a_axes_->get(arg, 1)).values.size();
     warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
   }
   assert(warp_j != nullptr);
@@ -2403,7 +2566,8 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
   Value* is_warp0 = icmp_eq(warp, i32(0));
   Value* is_thread0 = icmp_eq(thread, i32(0));
   Value* lane_j = urem(lane, i32(shuffle_width));
-  add_barrier();
+  if(warps_per_inner > 1)
+    add_barrier();
   // compute partial sum for each warp, and store to shared memory
   for(size_t i = 0; i < n_elts/col_per_thread; i++){
     std::pair<Value*, Value*> acc;
@@ -2425,13 +2589,21 @@ void generator::visit_reducend_inst_fast(ir::reduce_inst* x, acc_fn_t do_acc, Va
     // store partial result to shared memory
     auto x_idxs = idxs_[x][i];
     Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
-    Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
-    call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
-    if (with_index) {
-      call(st_shared_index,
-           {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
+    // single warp on the reduce dimension -- no need to use shmem
+    if(warps_per_inner==1){
+      vals_[x][idxs_[x][i]] = with_index ? acc.second : acc.first;
+    }
+    else{
+      Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
+      call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc.first});
+      if (with_index) {
+        call(st_shared_index,
+             {icmp_eq(lane_j, i32(0)), gep(index_base, st_off), acc.second});
+      }
     }
   }
+  if(warps_per_inner==1)
+    return;
   add_barrier();
   // at this point, partial accumulator synchronized in shared memory
   // Just need to reduce `warp_per_inner` numbers in shared memory
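The shape of this new fast path, schematically (Python sketch; `partials` stands for the per-thread accumulators after the intra-warp shuffle):

    def finalize_reduction(partials, warps_per_inner):
        # with a single warp along the reduced dimension the shuffle
        # result is already final: skip shared memory and both barriers
        if warps_per_inner == 1:
            return partials
        # otherwise store partials to shared memory, barrier, and reduce
        # `warps_per_inner` values per output in a second phase
        raise NotImplementedError("cross-warp phase elided in this sketch")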
@@ -2559,6 +2731,7 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
     case ir::reduce_inst::FMAX: return max_num(x, y);
     case ir::reduce_inst::FMIN: return min_num(x, y);
+    case ir::reduce_inst::XOR: return xor_(x, y);

     default: throw std::runtime_error("unreachable");
     }
   };
@@ -2639,7 +2812,9 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
   analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
   Value *base;
-  base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
+  int off = alloc_->offset(layouts_->get(layouts_->tmp(out)));
+  // std::cout << off << std::endl;
+  base = gep(shmem_, i32(off));
   base = bit_cast(base, ptr_ty(ty, 3));
   std::vector<int> n_reps;
   for(int i = 0; i < shape.size(); i++){
@@ -2821,15 +2996,26 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
   //
   int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
   int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
+  if(in_layout->to_mma()){
+    mts_0 = 4 * in_layout->to_mma()->wpt(in_order[0]);
+    mts_1 = 8 * in_layout->to_mma()->wpt(in_order[1]);
+    per_phase = 1;
+    max_phase = 8;
+  }
+

   int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
-  int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
   int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
+  int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
+  if(in_layout->to_mma()){
+    n_shared_0 = 8;
+    n_shared_1 = 1;
+  }

   BasicBlock* CurrBB = builder_->GetInsertBlock();
   BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
   auto shapes = cts->get_type()->get_block_shapes();

   // store to shared
   Value *current = nullptr;
   std::map<std::pair<int, int>, Value*> ptrs;
@@ -2844,9 +3030,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
     // input ptr info
     int id_0 = id % (in_ld/min_vec);
     int id_1 = id / (in_ld/min_vec);
-    int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
-    int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
-    int off = (off_1*shapes[in_order[0]] + off_0);
+    // std::cout << id_0 << " " << id_1 << " " << in_ld << " " << std::endl;
     std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
     if(ptrs.find(key) == ptrs.end()){
       if(FirstBB->getTerminator())
@@ -2865,6 +3049,13 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
         builder_->SetInsertPoint(CurrBB);
       ptrs[key] = gep(shmems_.at(cts), {off});
     }
+    int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
+    int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
+    if(in_layout->to_mma()){
+      off_0 = id_0/n_shared_0*n_shared_0*8;
+      off_1 = id_1/n_shared_1*n_shared_1*8;
+    }
+    int off = (off_1*shapes[in_order[0]] + off_0);
    Value* ptr = gep(ptrs[key], {i32(off)});
    ptr = bit_cast(ptr, current->getType()->getPointerTo(3));
    // asm
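This is the "mma layout directly to shared memory" piece: when the source of the copy is an mma layout, the stride between pointer groups becomes the fixed 8-element fragment footprint instead of the thread tiling. A compact restatement (Python sketch, hypothetical helper; `ld` is `shapes[in_order[0]]`):

    def tile_offset(id_0, id_1, n_shared_0, n_shared_1,
                    mts_0, mts_1, ld, is_mma):
        # mma sources step by the fixed 8-wide fragment footprint,
        # everything else by the thread tiling mts_0/mts_1
        step_0 = 8 if is_mma else mts_0
        step_1 = 8 if is_mma else mts_1
        off_0 = id_0 // n_shared_0 * n_shared_0 * step_0
        off_1 = id_1 // n_shared_1 * n_shared_1 * step_1
        return off_1 * ld + off_0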
@@ -3069,7 +3260,7 @@ void generator::visit_function(ir::function* fn) {
   if(tgt_->as_nvidia()->sm() >= 80)
   for(ir::load_inst::EVICTION_POLICY evict: {ir::load_inst::EVICT_FIRST, ir::load_inst::EVICT_LAST}){
     std::string policy = (evict == ir::load_inst::EVICT_FIRST) ? "evict_first" : "evict_last";
-    std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0;";
+    std::string asm_str = "createpolicy.fractional.L2::" + policy + ".b64 $0, 1.0;";
     InlineAsm* iasm = InlineAsm::get(FunctionType::get(i64_ty, {}), asm_str, "=l", false);
     policies_[evict] = call(iasm);
   }
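The fix adds the fraction operand to `createpolicy.fractional`; passing 1.0 applies the chosen eviction policy to the entire access range, so the per-kernel PTX prologue now reads, e.g. (reconstruction):

    # reconstructed PTX emitted once per kernel for each policy
    ptx = "createpolicy.fractional.L2::evict_last.b64 $0, 1.0;"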
@@ -3299,7 +3490,6 @@ void generator::visit_basic_block(ir::basic_block * block) {
   BasicBlock *parent = bbs_[block];
   builder_->SetInsertPoint(parent);
   for(ir::instruction *i: block->get_inst_list()){
     // i->print(std::cout);
     visit_value(i);
-    // std::cout << "done" << std::endl;
   }
@@ -3324,7 +3514,10 @@ void generator::init_idx(ir::value *v) {
   std::vector<distributed_axis> axes(rank);
   std::vector<int> ord(rank);
   // compute axes
+  // std::cout << "axes" << std::endl;
   for(size_t d = 0; d < shapes.size(); d++){
+    // std::cout << d << " " << shapes[d] << std::endl;
+    // std::cout << a_axes_->get(v, d) << std::endl;
     if(shapes[d] > 1){
       unsigned x = a_axes_->get(v, d);
       axes[d] = axes_.at(x);
@@ -3334,6 +3527,7 @@ void generator::init_idx(ir::value *v) {
       axes[d].values = {i32(0)};
     }
   }
+  // std::cout << "axes ok" << std::endl;
   // compute order
   analysis::data_layout* layout = layouts_->get(v);
   std::iota(ord.begin(), ord.end(), 0);
@@ -3480,6 +3674,7 @@ void generator::finalize_phi_node(ir::phi_node *x) {
   for(indices_t idx: idxs_.at(x)){
     PHINode *phi = (PHINode*)vals_[x][idx];
     Value *inc = vals_[x->get_incoming_value(n)][idx];
+    // x->print(std::cout);
     phi->addIncoming(inc, block);
   }
 }