[CODEGEN] Major performance improvements on A100 (#70)
Improved handling of asynchronous copies, scheduling, and synchronization on A100. Large square dense matrix multiplications now achieve CUTLASS-like performance.
This commit is contained in:
committed by
Philippe Tillet
parent
045ab5d62a
commit
5b83259592
@@ -1416,59 +1416,80 @@ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) {
|
||||
}
|
||||
|
||||
void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){
|
||||
unsigned vector = 1;
|
||||
ir::value *ptrs = x->get_pointer_operand();
|
||||
ir::value *msks = x->get_mask_operand();
|
||||
unsigned in_vec = 1;
|
||||
ir::value *arg = x->get_pointer_operand();
|
||||
analysis::shared_layout* out_layout = layouts_->get(x)->to_shared();
|
||||
analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline();
|
||||
analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
|
||||
auto out_order = out_layout->get_order();
|
||||
auto in_order = in_layout->get_order();
|
||||
// tiles
|
||||
if(out_order == in_order)
|
||||
vector = in_layout->nts(in_order[0]);
|
||||
in_vec = in_layout->nts(in_order[0]);
|
||||
int out_vec = swizzle_->get_vec(out_layout);
|
||||
int min_vec = std::min<int>(out_vec, in_vec);
|
||||
int s = std::max<int>(out_vec / in_vec, 1);
|
||||
//
|
||||
int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
|
||||
int num_per_phase = std::max<int>(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1);
|
||||
Value *max_phase = i32(8 / num_per_phase);
|
||||
int per_phase = swizzle_->get_per_phase(out_layout);
|
||||
int max_phase = swizzle_->get_max_phase(out_layout);
|
||||
//
|
||||
int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
||||
int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
|
||||
int n_shared_0 = std::max<int>(in_vec / out_vec, 1);
|
||||
auto shapes = x->get_type()->get_tile_shapes();
|
||||
//
|
||||
int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
|
||||
int n_shared = std::max<int>(8 / in_layout->mts(in_order[1]), 1);
|
||||
std::vector<Value*> shared;
|
||||
for(size_t i = 0; i < n_shared; i++){
|
||||
indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld);
|
||||
// phase
|
||||
Value* phase = udiv(idx[in_order[1]], i32(num_per_phase));
|
||||
phase = urem(phase, max_phase);
|
||||
// off
|
||||
Value* off_0 = idx[in_order[0]];
|
||||
off_0 = udiv(off_0, i32(vector));
|
||||
off_0 = xor_(off_0, phase);
|
||||
off_0 = mul(off_0 , i32(vector));
|
||||
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
|
||||
Value* off = add(off_0, off_1);
|
||||
//
|
||||
shared.push_back(gep(shmems_[x], {off}));
|
||||
}
|
||||
//
|
||||
for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){
|
||||
auto idx = idxs_[ptrs][i];
|
||||
BasicBlock* CurrBB = builder_->GetInsertBlock();
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
std::map<std::pair<int, int>, Value*> tmp;
|
||||
std::vector<std::pair<Value*, int>> shared;
|
||||
for(int i = 0; i < idxs_.at(arg).size(); i++){
|
||||
unsigned id = i / min_vec;
|
||||
// input ptr info
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[ptrs][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
size_t in_off = dyn_cast<ConstantInt>(in_gep->idx_begin())->getValue().getSExtValue()*2*vector;
|
||||
Value* out_base = shared[(i / per_thread_ld) % n_shared];
|
||||
int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]);
|
||||
int out_off_1 = i % per_thread_ld;
|
||||
int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2;
|
||||
// asm
|
||||
FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false);
|
||||
std::string mod = (vector*2 == 16) ? ".cg" : ".ca";
|
||||
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";";
|
||||
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
|
||||
call(iasm, {vals_[msks][idx], out_base, in_base});
|
||||
int id_0 = id % (in_ld/min_vec);
|
||||
int id_1 = id / (in_ld/min_vec);
|
||||
int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
|
||||
int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
|
||||
int off = (off_1*shapes[in_order[0]] + off_0);
|
||||
std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
|
||||
if(tmp.find(key) == tmp.end()){
|
||||
if(CurrBB != FirstBB)
|
||||
builder_->SetInsertPoint(FirstBB->getTerminator());
|
||||
indices_t idx = idxs_.at(arg).at(key.first*in_ld);
|
||||
Value* phase = udiv(idx[in_order[1]], i32(per_phase));
|
||||
phase = urem(phase, i32(max_phase));
|
||||
Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]]));
|
||||
Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec));
|
||||
off_0 = udiv(off_0, i32(min_vec));
|
||||
off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, i32(s)));
|
||||
off_0 = mul(off_0 , i32(min_vec));
|
||||
Value* off = add(off_0, off_1);
|
||||
if(CurrBB != FirstBB)
|
||||
builder_->SetInsertPoint(CurrBB);
|
||||
tmp[key] = gep(shmems_[x], {off});
|
||||
}
|
||||
shared.push_back({tmp[key], off});
|
||||
}
|
||||
|
||||
for(size_t i = 0; i < idxs_.at(arg).size(); i += in_vec){
|
||||
auto idx = idxs_[arg][i];
|
||||
// input ptr info
|
||||
GetElementPtrInst *in_gep = dyn_cast<GetElementPtrInst>(vals_[arg][idx]);
|
||||
Value *in_base = in_gep->getPointerOperand();
|
||||
ConstantInt* cst = dyn_cast<ConstantInt>(in_gep->idx_begin());
|
||||
size_t in_off = cst ? cst->getValue().getSExtValue()*2*in_vec : 0;
|
||||
in_base = cst ? in_base : in_gep;
|
||||
// output ptr info
|
||||
Value* out_base = shared[i].first;
|
||||
int out_off = shared[i].second*2;
|
||||
// asm
|
||||
FunctionType *ty = FunctionType::get(void_ty, {builder_->getInt1Ty(), out_base->getType(), in_base->getType()}, false);
|
||||
std::string mod = (in_vec*2 == 16) ? ".cg" : ".ca";
|
||||
std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(in_vec*2) + ";";
|
||||
InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true);
|
||||
call(iasm, {vals_[x->get_mask_operand()][idx], out_base, in_base});
|
||||
}
|
||||
|
||||
std::string asm_str = "cp.async.commit_group;";
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
|
||||
call(iasm);
|
||||
}
|
||||
|
||||
void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
@@ -1496,7 +1517,7 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
|
||||
BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock();
|
||||
auto shapes = cts->get_type()->get_tile_shapes();
|
||||
|
||||
// default implementation
|
||||
// store to shared
|
||||
Value *current = nullptr;
|
||||
std::map<std::pair<int, int>, Value*> ptrs;
|
||||
for(int i = 0; i < idxs_.at(arg).size(); i++){
|
||||
@@ -1549,11 +1570,10 @@ void generator::visit_barrier_inst(ir::barrier_inst*) {
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
void generator::visit_async_wait_inst(ir::async_wait_inst*) {
|
||||
std::string asm_str = "cp.async.wait_all;";
|
||||
void generator::visit_async_wait_inst(ir::async_wait_inst* i) {
|
||||
std::string asm_str = "cp.async.wait_group " + std::to_string(i->get_N()) + ";";
|
||||
InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true);
|
||||
call(iasm);
|
||||
add_barrier();
|
||||
}
|
||||
|
||||
void generator::visit_make_range_dyn(ir::make_range_dyn* x) {
|
||||
@@ -1993,10 +2013,10 @@ void generator::visit(ir::module &src, llvm::Module &dst) {
|
||||
if(unsigned alloc_size = alloc_->allocated_size()){
|
||||
Type *int_8_ty = Type::getInt8Ty(*ctx_);
|
||||
Type *int_32_ty = Type::getInt32Ty(*ctx_);
|
||||
ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4);
|
||||
ArrayType *array_ty = ArrayType::get(int_32_ty, 0);
|
||||
Type *ptr_ty = ptr_ty(int_8_ty, 3);
|
||||
GlobalVariable *sh_mem_array =
|
||||
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage,
|
||||
new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage,
|
||||
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
|
||||
shmem_ = bit_cast(sh_mem_array, ptr_ty);
|
||||
}
|
||||
|
Reference in New Issue
Block a user