[codegen] [allocation] fixed issues in HMMA

This commit is contained in:
Philippe Tillet
2019-09-23 17:54:42 -04:00
parent b95ac15d48
commit f0013f8bf1
10 changed files with 94 additions and 50 deletions

View File

@@ -25,6 +25,15 @@ class axes;
class layout;
class align;
// Classification of a tile group's layout, assigned per layout group in
// tiles::run(): HMMA_C when the group contains a dot output eligible for
// HMMA, the A_*/B_* variants when it feeds an HMMA dot as an operand, and
// SCANLINE otherwise.
enum layout_t {
SCANLINE,   // default strided-scan layout (no HMMA involvement)
HMMA_C,     // holds the accumulator/result of an HMMA dot
HMMA_A_COL, // operand 0 of an HMMA dot with is_a_trans() == false
HMMA_A_ROW, // operand 0 of an HMMA dot with is_a_trans() == true
HMMA_B_COL, // operand 1 of an HMMA dot with is_b_trans() == false
HMMA_B_ROW  // operand 1 of an HMMA dot with is_b_trans() == true
};
class tiles {
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
private:
@@ -34,7 +43,7 @@ private:
public:
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
void run(ir::module &mod);
bool hmma(ir::value *value);
layout_t hmma(ir::value *value);
int mts(ir::value *value, unsigned ax);
int nts(ir::value *value, unsigned ax);
int fpw(ir::value *value, unsigned ax);
@@ -52,7 +61,7 @@ private:
// tile properties
std::map<int, ir::value*> largest_;
std::map<int, std::vector<int>> order_;
std::map<int, bool> hmma_;
std::map<int, layout_t> hmma_;
std::map<int, int> fpw_;
std::map<int, int> wpt_;
std::map<int, int> mts_;

View File

@@ -20,24 +20,14 @@ unsigned allocation::is_ld_padded(ir::value *x) {
if(trans->get_perm()[0]->get_value() != 0)
return 4;
}
for(ir::user* user: x->get_users())
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
bool is_hmma = tiles_->hmma(user);
bool is_op_0 = x == dot->get_operand(0);
bool is_op_1 = x == dot->get_operand(1);
if(is_hmma && is_op_0){
if(dot->is_a_trans())
return 8;
else
return 16;
}
if(is_hmma && is_op_1){
if(!dot->is_b_trans())
return 8;
else
return 16;
}
}
if(tiles_->hmma(x) == HMMA_A_ROW)
return 8;
if(tiles_->hmma(x) == HMMA_A_COL)
return 16;
if(tiles_->hmma(x) == HMMA_B_COL)
return 8;
if(tiles_->hmma(x) == HMMA_B_ROW)
return 16;
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
unsigned result = 0;
for(unsigned i = 0; i < phi->get_num_incoming(); i++)

View File

@@ -23,7 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
{ }
bool is_hmma(ir::value *v){
bool is_hmma_c(ir::value *v){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
@@ -36,9 +36,44 @@ bool is_hmma(ir::value *v){
return result;
}
// Returns true iff `v` is consumed as operand 0 of an HMMA-eligible dot
// whose A operand is not transposed (the "A column" case).
bool is_hmma_a_col(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only holds for ir::dot_inst values (it dynamic_casts
      // internally), so a checked-free downcast is safe here; prefer a
      // named cast over the original C-style cast.
      auto* dot = static_cast<ir::dot_inst*>(u);
      if((v == dot->get_operand(0)) && !dot->is_a_trans())
        return true;
    }
  // Bug fix: the original flowed off the end of a non-void function when
  // no matching user was found — undefined behavior, and this function is
  // used as a std::any_of predicate.
  return false;
}
// Returns true iff `v` is consumed as operand 0 of an HMMA-eligible dot
// whose A operand is transposed (the "A row" case).
bool is_hmma_a_row(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only holds for ir::dot_inst values, so the downcast is
      // safe; use a named cast instead of the original C-style cast.
      auto* dot = static_cast<ir::dot_inst*>(u);
      if((v == dot->get_operand(0)) && dot->is_a_trans())
        return true;
    }
  // Bug fix: missing return — falling off the end of a value-returning
  // function is undefined behavior.
  return false;
}
// Returns true iff `v` is consumed as operand 1 of an HMMA-eligible dot
// whose B operand is not transposed (the "B column" case).
bool is_hmma_b_col(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only holds for ir::dot_inst values, so the downcast is
      // safe; use a named cast instead of the original C-style cast.
      auto* dot = static_cast<ir::dot_inst*>(u);
      if((v == dot->get_operand(1)) && !dot->is_b_trans())
        return true;
    }
  // Bug fix: missing return — falling off the end of a value-returning
  // function is undefined behavior.
  return false;
}
// Returns true iff `v` is consumed as operand 1 of an HMMA-eligible dot
// whose B operand is transposed (the "B row" case).
bool is_hmma_b_row(ir::value* v) {
  for(ir::user *u: v->get_users())
    if(is_hmma_c(u)){
      // is_hmma_c only holds for ir::dot_inst values, so the downcast is
      // safe; use a named cast instead of the original C-style cast.
      auto* dot = static_cast<ir::dot_inst*>(u);
      if((v == dot->get_operand(1)) && dot->is_b_trans())
        return true;
    }
  // Bug fix: missing return — falling off the end of a value-returning
  // function is undefined behavior.
  return false;
}
bool tiles::hmma(ir::value *value) {
layout_t tiles::hmma(ir::value *value) {
return hmma_.at(layout_->id(value));
}
@@ -164,7 +199,18 @@ void tiles::run(ir::module &) {
// find out which groups require hmma layout
for(size_t i = 0; i < num_groups; i++) {
const auto& values = layout_->values(i);
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
if(hmma_c) hmma_[i] = HMMA_C;
else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
else hmma_[i] = SCANLINE;
}
// find out which value is the largest in each group
for(size_t i = 0; i < num_groups; i++) {
@@ -197,7 +243,7 @@ void tiles::run(ir::module &) {
if(!i->get_type()->is_tile_ty())
continue;
/* HMMA parameters*/
if(hmma_[x.first])
if(hmma_[x.first] == HMMA_C)
init_hmma_tile(i);
else
init_scanline_tile(i);

View File

@@ -710,7 +710,7 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
if(tiles_->hmma(v))
if(tiles_->hmma(v) == analysis::HMMA_C)
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
else
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
@@ -1241,11 +1241,10 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
if(NK != 1) {
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
if(tiles_->hmma(dot))
if(tiles_->hmma(dot) == analysis::HMMA_C)
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
else
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
}
else {
distributed_tile *TA = (distributed_tile*)tmap_.at(A);

View File

@@ -104,12 +104,12 @@ bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool tr
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
if(auto *T = dynamic_cast<ir::trans_inst*>(B)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
BB = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(BB);
trans_b = true;
}
}
if(!trans_a && !trans_b)

View File

@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -201,17 +201,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
// create passes
codegen::transform::cts cts;
codegen::analysis::align align;
codegen::analysis::liveness shmem_liveness;
codegen::analysis::liveness liveness;
codegen::analysis::axes axes;
codegen::analysis::layout layouts(&axes);
codegen::transform::coalesce coalesce(&align, &layouts);
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
codegen::analysis::allocation allocation(&liveness, &tiles);
codegen::transform::membar barriers(&liveness, &allocation);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&align);
codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
// run passes
peephole.run(module);
dce.run(module);
@@ -226,14 +226,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
tiles.run(module);
reassociate.run(module);
dce.run(module);
peephole.run(module);
dce.run(module);
cts.run(module);
shmem_liveness.run(module);
shmem_allocation.run(module);
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
liveness.run(module);
allocation.run(module);
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
shmem_barriers.run(module);
barriers.run(module);
dce.run(module);
dce.run(module);
axes.run(module);

View File

@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "float";
cublasDataType_t cuty = CUDA_R_32F;
typedef half_float::half NumericT;
std::string ty = "half";
cublasDataType_t cuty = CUDA_R_16F;
size_t dt_nbytes = sizeof(NumericT);
drv::context* context = stream->context();
// leading dimensions
@@ -47,8 +47,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
opt.defines.push_back({"BT", {BT?"1":"0"}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {8};
opt.defines.push_back({"TK", {"16"}});
opt.num_warps = {2, 4, 8};
// create function
rt::function function(src::dot, opt);
// benchmark available libraries
@@ -79,7 +79,10 @@ int main() {
// shapes to benchmark
typedef std::tuple<bool, bool, int, int, int> config_t;
std::vector<config_t> configs;
for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
{false, true},
{true, false},
{true, true}}){
std::vector<config_t> tmp = {
config_t{x[0], x[1], 2048, 2048, 2048}
// config_t{x[0], x[1], 16, 2048, 2048},

View File

@@ -64,7 +64,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
// epilogue
int rxc[TM] = ridx * TM + 0 ... TM;
int ryc[TN] = ridy * TN + 0 ... TN;
TYPE* pc[TM, TN] = C + rxc[:, newaxis] * ldc + ryc[newaxis, :];
TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc;
*pc = c;
}
)";

View File

@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
c[m*N + n] = static_cast<T>(acc);
c[m + n*M] = static_cast<T>(acc);
}
}