[codegen] [layout] fixed padding issue for row-major HMMA

This commit is contained in:
Philippe Tillet
2019-10-18 13:42:15 -04:00
parent b43454c9b7
commit cfde3dd766
4 changed files with 8 additions and 9 deletions

View File

@@ -279,6 +279,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
// order
if(arg->type == SCANLINE)
order = arg->order;
else
order = arg->order;
ir::value* dot_a = nullptr;
ir::value* dot_b = nullptr;
ir::value* hmma_dot_a = nullptr;
@@ -291,7 +293,6 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
}
std::vector<int> col = {0, 1};
std::vector<int> row = {1, 0};
order = col;
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
if(is_nonhmma_dot_a)
@@ -303,21 +304,20 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
pad = 0;
if(hmma_dot_a){
bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
pad = 24 - shapes[row ? 0: 1] % 32;
pad = 24 - shapes[row ? order[0] : order[1]] % 32;
}
else if(hmma_dot_b){
bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
pad = 24 - shapes[row ? 1 : 0] % 32;
pad = 24 - shapes[row ? order[1] : order[0]] % 32;
}
else if(order != arg->order) {
pad = 4;
}
shapes[order[0]] += pad;
// size
auto shape = this->shapes;
shape[order[0]] += pad;
size = ty->get_primitive_size_in_bits() / 8;
for(auto s: shape)
for(auto s: shapes)
size *= s;
if(double_buffer)
size *= 2;

View File

@@ -108,7 +108,6 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
tile* machine_layout_shared_t::create(ir::value *v) {
auto order = layout_->order;
auto shapes = layout_->shapes;
shapes[order[0]] += layout_->pad;
Type* ty = llvm_type(layout_->ty, builder_->getContext());
// double-buffered
if(layout_->double_buffer) {

View File

@@ -34,7 +34,7 @@ int main() {
for(const auto& c: configs){
std::tie(ord, AT, BT, M, N, K) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

View File

@@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
if(mode == BENCH) {
opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}});
opt.defines.push_back({"TK", {"8", "16"}});
opt.num_warps = {2, 4, 8};
}