[codegen] [allocation] fixed issues in HMMA
This commit is contained in:
@@ -25,6 +25,15 @@ class axes;
|
|||||||
class layout;
|
class layout;
|
||||||
class align;
|
class align;
|
||||||
|
|
||||||
|
enum layout_t {
|
||||||
|
SCANLINE,
|
||||||
|
HMMA_C,
|
||||||
|
HMMA_A_COL,
|
||||||
|
HMMA_A_ROW,
|
||||||
|
HMMA_B_COL,
|
||||||
|
HMMA_B_ROW
|
||||||
|
};
|
||||||
|
|
||||||
class tiles {
|
class tiles {
|
||||||
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
typedef std::map<ir::value*, std::map<int, int>> param_map_t;
|
||||||
private:
|
private:
|
||||||
@@ -34,7 +43,7 @@ private:
|
|||||||
public:
|
public:
|
||||||
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
tiles(size_t num_warps, analysis::align* align, analysis::axes* axes, analysis::layout* layout);
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
bool hmma(ir::value *value);
|
layout_t hmma(ir::value *value);
|
||||||
int mts(ir::value *value, unsigned ax);
|
int mts(ir::value *value, unsigned ax);
|
||||||
int nts(ir::value *value, unsigned ax);
|
int nts(ir::value *value, unsigned ax);
|
||||||
int fpw(ir::value *value, unsigned ax);
|
int fpw(ir::value *value, unsigned ax);
|
||||||
@@ -52,7 +61,7 @@ private:
|
|||||||
// tile properties
|
// tile properties
|
||||||
std::map<int, ir::value*> largest_;
|
std::map<int, ir::value*> largest_;
|
||||||
std::map<int, std::vector<int>> order_;
|
std::map<int, std::vector<int>> order_;
|
||||||
std::map<int, bool> hmma_;
|
std::map<int, layout_t> hmma_;
|
||||||
std::map<int, int> fpw_;
|
std::map<int, int> fpw_;
|
||||||
std::map<int, int> wpt_;
|
std::map<int, int> wpt_;
|
||||||
std::map<int, int> mts_;
|
std::map<int, int> mts_;
|
||||||
|
@@ -20,24 +20,14 @@ unsigned allocation::is_ld_padded(ir::value *x) {
|
|||||||
if(trans->get_perm()[0]->get_value() != 0)
|
if(trans->get_perm()[0]->get_value() != 0)
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
for(ir::user* user: x->get_users())
|
if(tiles_->hmma(x) == HMMA_A_ROW)
|
||||||
if(auto dot = dynamic_cast<ir::dot_inst*>(user)){
|
|
||||||
bool is_hmma = tiles_->hmma(user);
|
|
||||||
bool is_op_0 = x == dot->get_operand(0);
|
|
||||||
bool is_op_1 = x == dot->get_operand(1);
|
|
||||||
if(is_hmma && is_op_0){
|
|
||||||
if(dot->is_a_trans())
|
|
||||||
return 8;
|
return 8;
|
||||||
else
|
if(tiles_->hmma(x) == HMMA_A_COL)
|
||||||
return 16;
|
return 16;
|
||||||
}
|
if(tiles_->hmma(x) == HMMA_B_COL)
|
||||||
if(is_hmma && is_op_1){
|
|
||||||
if(!dot->is_b_trans())
|
|
||||||
return 8;
|
return 8;
|
||||||
else
|
if(tiles_->hmma(x) == HMMA_B_ROW)
|
||||||
return 16;
|
return 16;
|
||||||
}
|
|
||||||
}
|
|
||||||
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
|
||||||
unsigned result = 0;
|
unsigned result = 0;
|
||||||
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
for(unsigned i = 0; i < phi->get_num_incoming(); i++)
|
||||||
|
@@ -23,7 +23,7 @@ tiles::tiles(size_t num_warps, analysis::align *align, analysis::axes *axes, ana
|
|||||||
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
num_warps_(num_warps), align_(align), axes_(axes), layout_(layout)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
bool is_hmma(ir::value *v){
|
bool is_hmma_c(ir::value *v){
|
||||||
bool result = false;
|
bool result = false;
|
||||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||||
ir::value *a = x->get_operand(0);
|
ir::value *a = x->get_operand(0);
|
||||||
@@ -36,9 +36,44 @@ bool is_hmma(ir::value *v){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_hmma_a_col(ir::value* v) {
|
||||||
|
for(ir::user *u: v->get_users())
|
||||||
|
if(is_hmma_c(u)){
|
||||||
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
|
if((v == dot->get_operand(0)) && !dot->is_a_trans())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_hmma_a_row(ir::value* v) {
|
||||||
|
for(ir::user *u: v->get_users())
|
||||||
|
if(is_hmma_c(u)){
|
||||||
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
|
if((v == dot->get_operand(0)) && dot->is_a_trans())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_hmma_b_col(ir::value* v) {
|
||||||
|
for(ir::user *u: v->get_users())
|
||||||
|
if(is_hmma_c(u)){
|
||||||
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
|
if((v == dot->get_operand(1)) && !dot->is_b_trans())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_hmma_b_row(ir::value* v) {
|
||||||
|
for(ir::user *u: v->get_users())
|
||||||
|
if(is_hmma_c(u)){
|
||||||
|
ir::dot_inst* dot = (ir::dot_inst*)u;
|
||||||
|
if((v == dot->get_operand(1)) && dot->is_b_trans())
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
bool tiles::hmma(ir::value *value) {
|
layout_t tiles::hmma(ir::value *value) {
|
||||||
return hmma_.at(layout_->id(value));
|
return hmma_.at(layout_->id(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,7 +199,18 @@ void tiles::run(ir::module &) {
|
|||||||
// find out which groups require hmma layout
|
// find out which groups require hmma layout
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
const auto& values = layout_->values(i);
|
const auto& values = layout_->values(i);
|
||||||
hmma_[i] = std::any_of(values.begin(), values.end(), &is_hmma);
|
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||||
|
bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
|
||||||
|
bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
|
||||||
|
bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
|
||||||
|
bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
|
||||||
|
if(hmma_c) hmma_[i] = HMMA_C;
|
||||||
|
else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
|
||||||
|
else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
|
||||||
|
else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
|
||||||
|
else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
|
||||||
|
else hmma_[i] = SCANLINE;
|
||||||
|
|
||||||
}
|
}
|
||||||
// find out which value is the largest in each group
|
// find out which value is the largest in each group
|
||||||
for(size_t i = 0; i < num_groups; i++) {
|
for(size_t i = 0; i < num_groups; i++) {
|
||||||
@@ -197,7 +243,7 @@ void tiles::run(ir::module &) {
|
|||||||
if(!i->get_type()->is_tile_ty())
|
if(!i->get_type()->is_tile_ty())
|
||||||
continue;
|
continue;
|
||||||
/* HMMA parameters*/
|
/* HMMA parameters*/
|
||||||
if(hmma_[x.first])
|
if(hmma_[x.first] == HMMA_C)
|
||||||
init_hmma_tile(i);
|
init_hmma_tile(i);
|
||||||
else
|
else
|
||||||
init_scanline_tile(i);
|
init_scanline_tile(i);
|
||||||
|
@@ -710,7 +710,7 @@ void selection::init_hmma_axes(ir::value *v, IRBuilder<> &builder, Value *u_thre
|
|||||||
|
|
||||||
|
|
||||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||||
if(tiles_->hmma(v))
|
if(tiles_->hmma(v) == analysis::HMMA_C)
|
||||||
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
|
init_hmma_axes(v, builder, u_thread_id, u_warp_id);
|
||||||
else
|
else
|
||||||
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
|
init_strided_scan_axes(v, builder, u_thread_id, u_warp_id);
|
||||||
@@ -1241,11 +1241,10 @@ void selection::lower_dot(ir::dot_inst *dot, LLVMContext &ctx, Function *fn, IRB
|
|||||||
if(NK != 1) {
|
if(NK != 1) {
|
||||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||||
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
shared_tile *TB = (shared_tile*)tmap_.at(B);
|
||||||
if(tiles_->hmma(dot))
|
if(tiles_->hmma(dot) == analysis::HMMA_C)
|
||||||
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
lower_hmma_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK);
|
||||||
else
|
else
|
||||||
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
lower_scanline_dot(dot, ctx, fn, builder, TC, TA, TB, TD, NK, c_ty, f_mul_add);
|
||||||
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
|
distributed_tile *TA = (distributed_tile*)tmap_.at(A);
|
||||||
|
@@ -104,12 +104,12 @@ bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool tr
|
|||||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
|
if(auto *T = dynamic_cast<ir::trans_inst*>(B)){
|
||||||
std::vector<ir::constant_int*> perm(T->get_perm());
|
std::vector<ir::constant_int*> perm(T->get_perm());
|
||||||
std::swap(perm[0], perm[1]);
|
std::swap(perm[0], perm[1]);
|
||||||
AA = builder.create_trans(T->get_operand(0), perm);
|
BB = builder.create_trans(T->get_operand(0), perm);
|
||||||
T->replace_all_uses_with(AA);
|
T->replace_all_uses_with(BB);
|
||||||
trans_a = true;
|
trans_b = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(!trans_a && !trans_b)
|
if(!trans_a && !trans_b)
|
||||||
|
@@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr<llvm::Module> module,
|
|||||||
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
// std::cout << source << std::endl;
|
|
||||||
cu_context::context_switcher ctx_switch(*context);
|
cu_context::context_switcher ctx_switch(*context);
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
|
@@ -201,17 +201,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
// create passes
|
// create passes
|
||||||
codegen::transform::cts cts;
|
codegen::transform::cts cts;
|
||||||
codegen::analysis::align align;
|
codegen::analysis::align align;
|
||||||
codegen::analysis::liveness shmem_liveness;
|
codegen::analysis::liveness liveness;
|
||||||
codegen::analysis::axes axes;
|
codegen::analysis::axes axes;
|
||||||
codegen::analysis::layout layouts(&axes);
|
codegen::analysis::layout layouts(&axes);
|
||||||
codegen::transform::coalesce coalesce(&align, &layouts);
|
codegen::transform::coalesce coalesce(&align, &layouts);
|
||||||
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
codegen::analysis::tiles tiles(opt.num_warps, &align, &axes, &layouts);
|
||||||
codegen::analysis::allocation shmem_allocation(&shmem_liveness, &tiles);
|
codegen::analysis::allocation allocation(&liveness, &tiles);
|
||||||
codegen::transform::membar shmem_barriers(&shmem_liveness, &shmem_allocation);
|
codegen::transform::membar barriers(&liveness, &allocation);
|
||||||
codegen::transform::dce dce;
|
codegen::transform::dce dce;
|
||||||
codegen::transform::peephole peephole;
|
codegen::transform::peephole peephole;
|
||||||
codegen::transform::reassociate reassociate(&align);
|
codegen::transform::reassociate reassociate(&align);
|
||||||
codegen::selection selection(&shmem_liveness, &shmem_allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
codegen::selection selection(&liveness, &allocation, &tiles, &align, &axes, &layouts, &coalesce, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
peephole.run(module);
|
peephole.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
@@ -226,14 +226,12 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
|||||||
tiles.run(module);
|
tiles.run(module);
|
||||||
reassociate.run(module);
|
reassociate.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
peephole.run(module);
|
|
||||||
dce.run(module);
|
|
||||||
cts.run(module);
|
cts.run(module);
|
||||||
shmem_liveness.run(module);
|
liveness.run(module);
|
||||||
shmem_allocation.run(module);
|
allocation.run(module);
|
||||||
if(shmem_allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
return std::unique_ptr<driver::module>();
|
return std::unique_ptr<driver::module>();
|
||||||
shmem_barriers.run(module);
|
barriers.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
axes.run(module);
|
axes.run(module);
|
||||||
|
@@ -27,9 +27,9 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
|
|||||||
|
|
||||||
|
|
||||||
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||||
typedef float NumericT;
|
typedef half_float::half NumericT;
|
||||||
std::string ty = "float";
|
std::string ty = "half";
|
||||||
cublasDataType_t cuty = CUDA_R_32F;
|
cublasDataType_t cuty = CUDA_R_16F;
|
||||||
size_t dt_nbytes = sizeof(NumericT);
|
size_t dt_nbytes = sizeof(NumericT);
|
||||||
drv::context* context = stream->context();
|
drv::context* context = stream->context();
|
||||||
// leading dimensions
|
// leading dimensions
|
||||||
@@ -47,8 +47,8 @@ std::vector<double> do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
|
|||||||
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
opt.defines.push_back({"BT", {BT?"1":"0"}});
|
||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"128"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"16"}});
|
||||||
opt.num_warps = {8};
|
opt.num_warps = {2, 4, 8};
|
||||||
// create function
|
// create function
|
||||||
rt::function function(src::dot, opt);
|
rt::function function(src::dot, opt);
|
||||||
// benchmark available libraries
|
// benchmark available libraries
|
||||||
@@ -79,7 +79,10 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, true}}){
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false},
|
||||||
|
{false, true},
|
||||||
|
{true, false},
|
||||||
|
{true, true}}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{x[0], x[1], 2048, 2048, 2048}
|
config_t{x[0], x[1], 2048, 2048, 2048}
|
||||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||||
|
@@ -64,7 +64,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
// epilogue
|
// epilogue
|
||||||
int rxc[TM] = ridx * TM + 0 ... TM;
|
int rxc[TM] = ridx * TM + 0 ... TM;
|
||||||
int ryc[TN] = ridy * TN + 0 ... TN;
|
int ryc[TN] = ridy * TN + 0 ... TN;
|
||||||
TYPE* pc[TM, TN] = C + rxc[:, newaxis] * ldc + ryc[newaxis, :];
|
TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :] * ldc;
|
||||||
*pc = c;
|
*pc = c;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
|
|||||||
float acc = 0;
|
float acc = 0;
|
||||||
for(size_t k = 0; k < K; k++)
|
for(size_t k = 0; k < K; k++)
|
||||||
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
||||||
c[m*N + n] = static_cast<T>(acc);
|
c[m + n*M] = static_cast<T>(acc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user