[codegen] [allocation] fixed issues in HMMA

Philippe Tillet
2019-09-23 17:54:42 -04:00
parent b95ac15d48
commit f0013f8bf1
10 changed files with 94 additions and 50 deletions

View File

@@ -25,6 +25,15 @@ class axes;
 class layout;
 class align;
+enum layout_t {
+  SCANLINE,
+  HMMA_C,
+  HMMA_A_COL,
+  HMMA_A_ROW,
+  HMMA_B_COL,
+  HMMA_B_ROW
+};
 class tiles {
   typedef std::map<ir::value*, std::map<int, int>> param_map_t;
 private:
@@ -34,7 +43,7 @@ private:
 public:
   tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
   void run(ir::module &mod);
-  bool hmma(ir::value *value);
+  layout_t hmma(ir::value *value);
   int mts(ir::value *value, unsigned ax);
   int nts(ir::value *value, unsigned ax);
   int fpw(ir::value *value, unsigned ax);
@@ -52,7 +61,7 @@ private:
   // tile properties
   std::map<int, ir::value*> largest_;
   std::map<int, std::vector<int>> order_;
-  std::map<int, bool> hmma_;
+  std::map<int, layout_t> hmma_;
   std::map<int, int> fpw_;
   std::map<int, int> wpt_;
   std::map<int, int> mts_;
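A minimal sketch of what the header change buys (describe_layout is a hypothetical helper, not part of the commit): callers of tiles::hmma now receive one of six layout_t values instead of a yes/no answer, so a pass can tell the dot accumulator apart from the A and B operands and their row/column variants, as the later hunks do.

    // Sketch only: branching on the new layout_t returned by tiles::hmma()
    // instead of the old bool. analysis::tiles and ir::value are the types above.
    const char* describe_layout(analysis::tiles &tiles, ir::value *v) {
      switch(tiles.hmma(v)) {
        case analysis::HMMA_C:     return "dot accumulator (HMMA C tile)";
        case analysis::HMMA_A_COL:
        case analysis::HMMA_A_ROW: return "A operand of an HMMA dot";
        case analysis::HMMA_B_COL:
        case analysis::HMMA_B_ROW: return "B operand of an HMMA dot";
        default:                   return "scanline layout";
      }
    }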

View File

@@ -20,24 +20,14 @@ unsigned allocation::is_ld_padded(ir::value *x) {
     if(trans->get_perm()[0]->get_value() != 0)
       return 4;
   }
-  for(ir::user* user: x->get_users())
-    if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
-      bool is_hmma = tiles_->hmma(user);
-      bool is_op_0 = x == dot->get_operand(0);
-      bool is_op_1 = x == dot->get_operand(1);
-      if(is_hmma && is_op_0){
-        if(dot->is_a_trans())
-          return 8;
-        else
-          return 16;
-      }
-      if(is_hmma && is_op_1){
-        if(!dot->is_b_trans())
-          return 8;
-        else
-          return 16;
-      }
-    }
+  if(tiles_->hmma(x) == HMMA_A_ROW)
+    return 8;
+  if(tiles_->hmma(x) == HMMA_A_COL)
+    return 16;
+  if(tiles_->hmma(x) == HMMA_B_COL)
+    return 8;
+  if(tiles_->hmma(x) == HMMA_B_ROW)
+    return 16;
   if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
     unsigned result = 0;
     for(unsigned i = 0; i < phi->get_num_incoming(); i++)
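The rewrite above replaces the walk over users of x with direct layout queries. Spelled out as a standalone sketch (my paraphrase, assuming the layout_t values from the header; hmma_ld_pad is hypothetical and not the commit's code), the padding rule becomes:

    // Extra elements appended to the leading dimension of a shared-memory tile,
    // as encoded by the hunk above (sketch, not the commit's code).
    unsigned hmma_ld_pad(analysis::layout_t layout) {
      switch(layout) {
        case analysis::HMMA_A_ROW:
        case analysis::HMMA_B_COL: return 8;
        case analysis::HMMA_A_COL:
        case analysis::HMMA_B_ROW: return 16;
        default:                   return 0;  // HMMA_C / SCANLINE: no HMMA operand padding here
      }
    }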

View File

@@ -23,7 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
   num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
 { }
-bool is_hmma(ir::value *v){
+bool is_hmma_c(ir::value *v){
   bool result = false;
   if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
     ir::value *a = x->get_operand(0);
@@ -36,9 +36,44 @@ bool is_hmma(ir::value *v){
   return result;
 }
+
+bool is_hmma_a_col(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(0)) && !dot->is_a_trans())
+        return true;
+    }
+}
+
+bool is_hmma_a_row(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(0)) && dot->is_a_trans())
+        return true;
+    }
+}
+
+bool is_hmma_b_col(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(1)) && !dot->is_b_trans())
+        return true;
+    }
+}
+
+bool is_hmma_b_row(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(1)) && dot->is_b_trans())
+        return true;
+    }
+}
 
-bool tiles::hmma(ir::value *value) {
+layout_t tiles::hmma(ir::value *value) {
   return hmma_.at(layout_->id(value));
 }
@@ -164,7 +199,18 @@ void tiles::run(ir::module &) {
   // find out which groups require hmma layout
   for(size_t i = 0; i < num_groups; i++) {
     const auto& values = layout_->values(i);
-    hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
+    bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
+    bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
+    bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
+    bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
+    bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
+    if(hmma_c) hmma_[i] = HMMA_C;
+    else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
+    else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
+    else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
+    else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
+    else hmma_[i] = SCANLINE;
   }
   // find out which value is the largest in each group
   for(size_t i = 0; i < num_groups; i++) {
@@ -197,7 +243,7 @@ void tiles::run(ir::module &) {
     if(!i->get_type()->is_tile_ty())
       continue;
     /* HMMA parameters*/
-    if(hmma_[x.first])
+    if(hmma_[x.first] == HMMA_C)
       init_hmma_tile(i);
     else
       init_scanline_tile(i);
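One observation on the four new predicates added in this file (the sketch below is an equivalent consolidated form, not the commit's code): they share a single shape and, as written, fall off the end of a non-void function when no matching user is found, so an explicit return false; is needed for the no-match path.

    // Equivalent, consolidated form of is_hmma_{a,b}_{col,row} above (sketch only).
    // op_idx selects the dot operand (0 = A, 1 = B); want_trans selects row vs. column.
    static bool is_hmma_operand(ir::value* v, unsigned op_idx, bool want_trans) {
      for(ir::user *u: v->get_users())
        if(is_hmma_c(u)){
          ir::dot_inst* dot = (ir::dot_inst*)u;
          bool trans = (op_idx == 0) ? dot->is_a_trans() : dot->is_b_trans();
          if(v == dot->get_operand(op_idx) && trans == want_trans)
            return true;
        }
      return false;  // explicit no-match result; the helpers above omit this
    }
    // e.g. is_hmma_a_col(v) corresponds to is_hmma_operand(v, 0, /*want_trans=*/false)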

View File

@@ -710,7 +710,7 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
-  if(tiles_->hmma(v))
+  if(tiles_->hmma(v) == analysis::HMMA_C)
     init_hmma_axes(v, builder, u_thread_id, u_warp_id);
   else
     init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
@@ -1241,11 +1241,10 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
   if(NK != 1) {
     shared_tile *TA = (shared_tile*)tmap_.at(A);
     shared_tile *TB = (shared_tile*)tmap_.at(B);
-    if(tiles_->hmma(dot))
+    if(tiles_->hmma(dot) == analysis::HMMA_C)
       lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
     else
       lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
   }
   else {
     distributed_tile *TA = (distributed_tile*)tmap_.at(A);

View File

@@ -104,12 +104,12 @@ bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool tr
     BB = ((ir::trans_inst*)B)->get_operand(0);
   }
   else{
-    if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
+    if(auto *T = dynamic_cast<ir::trans_inst*>(B)){
       std::vector<ir::constant_int*> perm(T->get_perm());
       std::swap(perm[0], perm[1]);
-      AA = builder.create_trans(T->get_operand(0), perm);
-      T->replace_all_uses_with(AA);
-      trans_a = true;
+      BB = builder.create_trans(T->get_operand(0), perm);
+      T->replace_all_uses_with(BB);
+      trans_b = true;
     }
   }
   if(!trans_a && !trans_b)

View File

@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-//  std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -201,17 +201,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   // create passes
   codegen::transform::cts cts;
   codegen::analysis::align align;
-  codegen::analysis::liveness shmem_liveness;
+  codegen::analysis::liveness liveness;
   codegen::analysis::axes axes;
   codegen::analysis::layout layouts(&axes);
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
-  codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
-  codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
+  codegen::analysis::allocation allocation(&liveness, &tiles);
+  codegen::transform::membar barriers(&liveness, &allocation);
   codegen::transform::dce dce;
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&align);
-  codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
+  codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
   // run passes
   peephole.run(module);
   dce.run(module);
@@ -226,14 +226,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   tiles.run(module);
   reassociate.run(module);
   dce.run(module);
-  peephole.run(module);
-  dce.run(module);
   cts.run(module);
-  shmem_liveness.run(module);
-  shmem_allocation.run(module);
-  if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
+  liveness.run(module);
+  allocation.run(module);
+  if(allocation.allocated_size() > context->device()->max_shared_memory())
     return std::unique_ptr<driver::module>();
-  shmem_barriers.run(module);
+  barriers.run(module);
   dce.run(module);
   dce.run(module);
   axes.run(module);

View File

@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
-  typedef float NumericT;
-  std::string ty = "float";
-  cublasDataType_t cuty = CUDA_R_32F;
+  typedef half_float::half NumericT;
+  std::string ty = "half";
+  cublasDataType_t cuty = CUDA_R_16F;
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   // leading dimensions
@@ -47,8 +47,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"BT", {BT?"1":"0"}}); opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}}); opt.defines.push_back({"TK", {"16"}});
opt.num_warps = {8}; opt.num_warps = {2, 4, 8};
// create function // create function
rt::function function(src::dot, opt); rt::function function(src::dot, opt);
// benchmark available libraries // benchmark available libraries
@@ -79,7 +79,10 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
+  for(auto x: std::vector<std::array<bool, 2>>{{false, false},
+                                               {false, true},
+                                               {true, false},
+                                               {true, true}}){
     std::vector<config_t> tmp = {
       config_t{x[0], x[1], 2048, 2048, 2048}
       // config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -64,7 +64,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // epilogue
   int rxc[TM] = ridx * TM + 0 ... TM;
   int ryc[TN] = ridy * TN + 0 ... TN;
-  TYPE* pc[TM, TN] = C + rxc[:, newaxis] * ldc + ryc[newaxis, :];
+  TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc;
   *pc = c;
 }
 )";

View File

@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
     float acc = 0;
     for(size_t k = 0; k < K; k++)
       acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
-    c[m*N + n] = static_cast<T>(acc);
+    c[m + n*M] = static_cast<T>(acc);
   }
 }
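The last two hunks are two halves of the same fix: the kernel epilogue now writes C + rxc + ryc*ldc and the CPU reference stores c[m + n*M], i.e. C is treated as column-major on both sides. A self-contained sketch of a reference under that convention (gemm_ref is an illustration only, not the test's code):

    #include <cstddef>
    #include <vector>
    // Column-major reference GEMM matching the updated epilogue/reference above:
    // C is M x N with leading dimension M, so C(m,n) = c[m + n*M].
    template<class T>
    void gemm_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
                  bool AT, bool BT, std::size_t M, std::size_t N, std::size_t K) {
      for(std::size_t m = 0; m < M; m++)
      for(std::size_t n = 0; n < N; n++) {
        float acc = 0;
        for(std::size_t k = 0; k < K; k++)
          acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
        c[m + n*M] = static_cast<T>(acc);  // column-major store, as in the hunk above
      }
    }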