[codegen] [allocation] fixed issues in HMMA
@@ -25,6 +25,15 @@ class axes;
 class layout;
 class align;
 
+enum layout_t {
+  SCANLINE,
+  HMMA_C,
+  HMMA_A_COL,
+  HMMA_A_ROW,
+  HMMA_B_COL,
+  HMMA_B_ROW
+};
+
 class tiles {
   typedef std::map<ir::value*, std::map<int, int>> param_map_t;
 private:
@@ -34,7 +43,7 @@ private:
 public:
   tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
   void run(ir::module &mod);
-  bool hmma(ir::value *value);
+  layout_t hmma(ir::value *value);
   int mts(ir::value *value, unsigned ax);
   int nts(ir::value *value, unsigned ax);
   int fpw(ir::value *value, unsigned ax);
@@ -52,7 +61,7 @@ private:
   // tile properties
   std::map<int, ir::value*> largest_;
   std::map<int, std::vector<int>> order_;
-  std::map<int, bool> hmma_;
+  std::map<int, layout_t> hmma_;
   std::map<int, int> fpw_;
   std::map<int, int> wpt_;
   std::map<int, int> mts_;
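Since hmma() now returns a layout_t instead of a bool, every call site that used to branch on a truthy flag has to compare against an explicit enumerator; the selection and tiles hunks further down do exactly this. A self-contained sketch of the new interface shape — tiles_mock and the group id are made up for illustration, not part of the commit:

    #include <map>

    // Standalone mock of the new interface: a group can now play one of
    // six roles, so callers dispatch on the role instead of a boolean.
    enum layout_t { SCANLINE, HMMA_C, HMMA_A_COL, HMMA_A_ROW, HMMA_B_COL, HMMA_B_ROW };

    struct tiles_mock {
      std::map<int, layout_t> hmma_;
      layout_t hmma(int group) const { return hmma_.at(group); }
    };

    int main() {
      tiles_mock t;
      t.hmma_[0] = HMMA_C;
      // old call sites read `if(t.hmma(0))`; they now compare explicitly:
      bool is_accumulator = (t.hmma(0) == HMMA_C);
      return is_accumulator ? 0 : 1;
    }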
@@ -20,24 +20,14 @@ unsigned allocation::is_ld_padded(ir::value *x) {
     if(trans->get_perm()[0]->get_value() != 0)
       return 4;
   }
-  for(ir::user* user: x->get_users())
-  if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
-    bool is_hmma = tiles_->hmma(user);
-    bool is_op_0 = x == dot->get_operand(0);
-    bool is_op_1 = x == dot->get_operand(1);
-    if(is_hmma && is_op_0){
-      if(dot->is_a_trans())
-        return 8;
-      else
-        return 16;
-    }
-    if(is_hmma && is_op_1){
-      if(!dot->is_b_trans())
-        return 8;
-      else
-        return 16;
-    }
-  }
+  if(tiles_->hmma(x) == HMMA_A_ROW)
+    return 8;
+  if(tiles_->hmma(x) == HMMA_A_COL)
+    return 16;
+  if(tiles_->hmma(x) == HMMA_B_COL)
+    return 8;
+  if(tiles_->hmma(x) == HMMA_B_ROW)
+    return 16;
   if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
     unsigned result = 0;
     for(unsigned i = 0; i < phi->get_num_incoming(); i++)
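The per-user scan over dot instructions is replaced by a direct lookup of the value's own layout; the pad amounts are unchanged (is_a_trans() previously selected the 8-element pad for A, just as HMMA_A_ROW does now, and !is_b_trans() matched what is now HMMA_B_COL). A standalone restatement of the rule — the function name hmma_ld_pad and the bank-conflict rationale are my additions, not from the commit:

    // Extra elements of padding on the shared-memory leading dimension
    // for each HMMA operand layout, mirroring the four cases in
    // is_ld_padded above (presumably to avoid bank conflicts).
    enum layout_t { SCANLINE, HMMA_C, HMMA_A_COL, HMMA_A_ROW, HMMA_B_COL, HMMA_B_ROW };

    unsigned hmma_ld_pad(layout_t l) {
      switch(l) {
        case HMMA_A_ROW:
        case HMMA_B_COL: return 8;
        case HMMA_A_COL:
        case HMMA_B_ROW: return 16;
        default:         return 0;  // SCANLINE / HMMA_C: handled elsewhere
      }
    }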
@@ -23,7 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
   num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
 { }
 
-bool is_hmma(ir::value *v){
+bool is_hmma_c(ir::value *v){
   bool result = false;
   if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
     ir::value *a = x->get_operand(0);
@@ -36,9 +36,48 @@ bool is_hmma(ir::value *v){
   return result;
 }
 
+bool is_hmma_a_col(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(0)) && !dot->is_a_trans())
+        return true;
+    }
+  return false;
+}
+
+bool is_hmma_a_row(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(0)) && dot->is_a_trans())
+        return true;
+    }
+  return false;
+}
+
+bool is_hmma_b_col(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(1)) && !dot->is_b_trans())
+        return true;
+    }
+  return false;
+}
+
+bool is_hmma_b_row(ir::value* v) {
+  for(ir::user *u: v->get_users())
+    if(is_hmma_c(u)){
+      ir::dot_inst* dot = (ir::dot_inst*)u;
+      if((v == dot->get_operand(1)) && dot->is_b_trans())
+        return true;
+    }
+  return false;
+}
+
-bool tiles::hmma(ir::value *value) {
+layout_t tiles::hmma(ir::value *value) {
   return hmma_.at(layout_->id(value));
 }
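The four predicates differ only in which dot operand they inspect and which transpose flag they expect. A hypothetical fold into a single helper — a refactor sketch using only the ir:: calls visible in this diff, not something the commit contains:

    // op: dot operand index (0 = A, 1 = B); trans: expected transpose flag.
    // is_hmma_a_row(v) is then is_hmma_operand(v, 0, true), and so on.
    bool is_hmma_operand(ir::value *v, unsigned op, bool trans) {
      for(ir::user *u: v->get_users())
        if(is_hmma_c(u)){
          ir::dot_inst* dot = (ir::dot_inst*)u;
          bool t = (op == 0) ? dot->is_a_trans() : dot->is_b_trans();
          if(v == dot->get_operand(op) && t == trans)
            return true;
        }
      return false;
    }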
@@ -164,7 +199,18 @@ void tiles::run(ir::module &) {
   // find out which groups require hmma layout
   for(size_t i = 0; i < num_groups; i++) {
     const auto& values = layout_->values(i);
-    hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
+    bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
+    bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
+    bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
+    bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
+    bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
+    if(hmma_c) hmma_[i] = HMMA_C;
+    else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
+    else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
+    else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
+    else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
+    else hmma_[i] = SCANLINE;
   }
   // find out which value is the largest in each group
   for(size_t i = 0; i < num_groups; i++) {
@@ -197,7 +243,7 @@ void tiles::run(ir::module &) {
     if(!i->get_type()->is_tile_ty())
       continue;
     /* HMMA parameters*/
-    if(hmma_[x.first])
+    if(hmma_[x.first] == HMMA_C)
       init_hmma_tile(i);
     else
       init_scanline_tile(i);
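This comparison is where the new enum earns its keep: every layout other than SCANLINE is a nonzero (truthy) value, so keeping the old `if(hmma_[x.first])` test would have routed the A/B operand groups through init_hmma_tile as well, when only the accumulator group uses the HMMA thread arrangement. A self-contained illustration of the pitfall:

    enum layout_t { SCANLINE, HMMA_C, HMMA_A_COL, HMMA_A_ROW, HMMA_B_COL, HMMA_B_ROW };

    int main() {
      layout_t l = HMMA_A_COL;               // an operand group, not the accumulator
      bool old_test = (l != 0);              // old-style truthiness: true (wrong path)
      bool new_test = (l == HMMA_C);         // explicit comparison: false (correct)
      return (old_test == new_test) ? 1 : 0; // exits 0: the two tests disagree here
    }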
@@ -710,7 +710,7 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
 
 
 void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
-  if(tiles_->hmma(v))
+  if(tiles_->hmma(v) == analysis::HMMA_C)
     init_hmma_axes(v, builder, u_thread_id, u_warp_id);
   else
     init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
@@ -1241,11 +1241,10 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
   if(NK != 1) {
     shared_tile *TA = (shared_tile*)tmap_.at(A);
     shared_tile *TB = (shared_tile*)tmap_.at(B);
-    if(tiles_->hmma(dot))
+    if(tiles_->hmma(dot) == analysis::HMMA_C)
       lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
     else
       lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
-
   }
   else {
     distributed_tile *TA = (distributed_tile*)tmap_.at(A);
@@ -104,12 +104,12 @@ bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool tr
     BB = ((ir::trans_inst*)B)->get_operand(0);
   }
   else{
-    if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
+    if(auto *T = dynamic_cast<ir::trans_inst*>(B)){
       std::vector<ir::constant_int*> perm(T->get_perm());
       std::swap(perm[0], perm[1]);
-      AA = builder.create_trans(T->get_operand(0), perm);
-      T->replace_all_uses_with(AA);
-      trans_a = true;
+      BB = builder.create_trans(T->get_operand(0), perm);
+      T->replace_all_uses_with(BB);
+      trans_b = true;
    }
  }
  if(!trans_a && !trans_b)
@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
 cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-  // std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -201,17 +201,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   // create passes
   codegen::transform::cts cts;
   codegen::analysis::align align;
-  codegen::analysis::liveness shmem_liveness;
+  codegen::analysis::liveness liveness;
   codegen::analysis::axes axes;
   codegen::analysis::layout layouts(&axes);
   codegen::transform::coalesce coalesce(&align, &layouts);
   codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
-  codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
-  codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
+  codegen::analysis::allocation allocation(&liveness, &tiles);
+  codegen::transform::membar barriers(&liveness, &allocation);
   codegen::transform::dce dce;
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&align);
-  codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
+  codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
   // run passes
   peephole.run(module);
   dce.run(module);
@@ -226,14 +226,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   tiles.run(module);
   reassociate.run(module);
   dce.run(module);
   peephole.run(module);
   dce.run(module);
   cts.run(module);
-  shmem_liveness.run(module);
-  shmem_allocation.run(module);
-  if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
+  liveness.run(module);
+  allocation.run(module);
+  if(allocation.allocated_size() > context->device()->max_shared_memory())
     return std::unique_ptr<driver::module>();
-  shmem_barriers.run(module);
-  dce.run(module);
-  dce.run(module);
+  barriers.run(module);
   axes.run(module);
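One behavioral detail worth noting from the hunk above: when the requested tiling needs more shared memory than the device offers, make_bin returns an empty unique_ptr rather than throwing. A hypothetical caller sketch — fn, module, context and opt are assumed to be set up as in the surrounding code, which this diff only excerpts:

    std::unique_ptr<driver::module> bin = fn.make_bin(module, context, opt);
    if(!bin) {
      // kernel exceeds context->device()->max_shared_memory(): skip this
      // (TM, TN, TK, num_warps) candidate instead of failing hard
    }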
@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 
 
 std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
-  typedef float NumericT;
-  std::string ty = "float";
-  cublasDataType_t cuty = CUDA_R_32F;
+  typedef half_float::half NumericT;
+  std::string ty = "half";
+  cublasDataType_t cuty = CUDA_R_16F;
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   // leading dimensions
@@ -47,8 +47,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   opt.defines.push_back({"BT", {BT?"1":"0"}});
   opt.defines.push_back({"TM", {"128"}});
   opt.defines.push_back({"TN", {"128"}});
-  opt.defines.push_back({"TK", {"8"}});
-  opt.num_warps = {8};
+  opt.defines.push_back({"TK", {"16"}});
+  opt.num_warps = {2, 4, 8};
   // create function
   rt::function function(src::dot, opt);
   // benchmark available libraries
@@ -79,7 +79,10 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
+  for(auto x: std::vector<std::array<bool, 2>>{{false, false},
+                                               {false, true},
+                                               {true, false},
+                                               {true, true}}){
     std::vector<config_t> tmp = {
       config_t{x[0], x[1], 2048, 2048, 2048}
 //      config_t{x[0], x[1], 16, 2048, 2048},
@@ -64,7 +64,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // epilogue
   int rxc[TM] = ridx * TM + 0 ... TM;
   int ryc[TN] = ridy * TN + 0 ... TN;
-  TYPE* pc[TM, TN] = C + rxc[:, newaxis] * ldc + ryc[newaxis, :];
+  TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc;
   *pc = c;
 }
 )";
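The epilogue swap changes which index carries the leading-dimension stride: C is now written column-major. The same change restated in plain C++ for clarity, with hypothetical names:

    // rx: row index, ry: column index, ldc: leading dimension of C.
    float* pc_before(float *C, int rx, int ry, int ldc) {
      return C + rx * ldc + ry;   // rows ldc apart: row-major
    }
    float* pc_after(float *C, int rx, int ry, int ldc) {
      return C + rx + ry * ldc;   // columns ldc apart: column-major
    }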
@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
   float acc = 0;
   for(size_t k = 0; k < K; k++)
     acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
-  c[m*N + n] = static_cast<T>(acc);
+  c[m + n*M] = static_cast<T>(acc);
   }
 }
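The CPU reference is updated in lockstep with the kernel epilogue, so correctness checks compare like with like: mixing the two conventions would silently compare transposed positions. A self-contained demonstration that the two formulas address different elements:

    #include <cassert>

    int main() {
      int M = 4, N = 3, m = 1, n = 2;
      int row_major = m*N + n;        // old cpu_ref index: 5
      int col_major = m + n*M;        // new cpu_ref index: 9, matches the kernel
      assert(row_major != col_major); // mixing conventions corrupts validation
      return 0;
    }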