[codegen] [layout] fixed padding issue for row-major HMMA
@@ -279,6 +279,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
   // order
   if(arg->type == SCANLINE)
     order = arg->order;
+  else
+    order = arg->order;
   ir::value* dot_a = nullptr;
   ir::value* dot_b = nullptr;
   ir::value* hmma_dot_a = nullptr;
@@ -291,7 +293,6 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
   }
   std::vector<int> col = {0, 1};
   std::vector<int> row = {1, 0};
-  order = col;
   bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
   bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
   if(is_nonhmma_dot_a)
@@ -303,21 +304,20 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
   pad = 0;
   if(hmma_dot_a){
     bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
-    pad = 24 - shapes[row ? 0: 1] % 32;
+    pad = 24 - shapes[row ? order[0] : order[1]] % 32;
   }
   else if(hmma_dot_b){
     bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
-    pad = 24 - shapes[row ? 1 : 0] % 32;
+    pad = 24 - shapes[row ? order[1] : order[0]] % 32;
   }
   else if(order != arg->order) {
     pad = 4;
   }
+  shapes[order[0]] += pad;

   // size
-  auto shape = this->shapes;
-  shape[order[0]] += pad;
   size = ty->get_primitive_size_in_bits() / 8;
-  for(auto s: shape)
+  for(auto s: shapes)
     size *= s;
   if(double_buffer)
     size *= 2;
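For reference, a minimal standalone sketch of the padding rule in the hunk above (the hmma_pad helper, the example shapes and the main() driver are hypothetical, not Triton code; only the arithmetic mirrors the diff). The contiguous dimension shapes[order[0]] receives 24 - shapes[d] % 32 extra elements, where d is selected through `order` according to whether the operand is transposed and whether the layout is row-major; indexing through `order` instead of the hard-coded 0/1 is what makes the rule hold for row-major HMMA operands.

// Sketch only: names and values below are illustrative, not Triton API.
#include <cstdio>
#include <vector>

// Padding (in elements) for a half-precision HMMA operand held in shared
// memory, computed the same way as the fixed lines in the diff above.
int hmma_pad(const std::vector<int>& shapes, const std::vector<int>& order,
             bool trans, bool is_operand_a) {
  bool row = trans ^ (order[0] == 1);                   // row-major operand?
  int dim = is_operand_a ? (row ? order[0] : order[1])  // operand A
                         : (row ? order[1] : order[0]); // operand B
  return 24 - shapes[dim] % 32;
}

int main() {
  std::vector<int> shapes = {128, 8};  // e.g. a 128x8 half-precision tile of A
  std::vector<int> order  = {0, 1};    // column-major: dimension 0 is contiguous
  int pad = hmma_pad(shapes, order, /*trans=*/false, /*is_operand_a=*/true);
  // The layout constructor now applies the pad to its own shapes:
  std::printf("pad = %d, padded leading dim = %d\n", pad, shapes[order[0]] + pad);
  return 0;
}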
@@ -108,7 +108,6 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
 tile* machine_layout_shared_t::create(ir::value *v) {
   auto order = layout_->order;
   auto shapes = layout_->shapes;
-  shapes[order[0]] += layout_->pad;
   Type* ty = llvm_type(layout_->ty, builder_->getContext());
   // double-buffered
   if(layout_->double_buffer) {
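A rough sketch of why the pad is no longer re-applied here: after this commit the layout constructor folds the pad into its own `shapes`, so any consumer that multiplies the shapes out or materializes the tile already sees the padded extent, and adding `layout_->pad` again would count it twice. The shared_layout_sketch struct and the values below are illustrative, not the real Triton classes.

// Sketch only: a stand-in for the shared layout after this commit, where
// `shapes` already contains the pad along the contiguous dimension.
#include <cstdio>
#include <vector>

struct shared_layout_sketch {
  std::vector<int> shapes;   // padded along shapes[order[0]]
  std::vector<int> order;
  int pad = 0;
  bool double_buffer = false;

  // Allocation size in bytes, mirroring the size loop in the diff:
  // the pad is counted exactly once, through `shapes`.
  size_t size_in_bytes(size_t elem_bytes) const {
    size_t size = elem_bytes;
    for (int s : shapes)
      size *= s;
    if (double_buffer)
      size *= 2;
    return size;
  }
};

int main() {
  // 128x8 half tile, padded to 144 along dimension 0, double-buffered.
  shared_layout_sketch l{{144, 8}, {0, 1}, /*pad=*/16, /*double_buffer=*/true};
  // A consumer such as machine_layout_shared_t::create() should use
  // l.shapes as-is and must not add l.pad again.
  std::printf("shared memory: %zu bytes\n", l.size_in_bytes(sizeof(short)));
  return 0;
}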
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
@@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
   if(mode == BENCH) {
     opt.defines.push_back({"TM", {"64", "128"}});
     opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8"}});
+    opt.defines.push_back({"TK", {"8", "16"}});
     opt.num_warps = {2, 4, 8};
   }