diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 83ee2086d..3d2296aae 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -279,6 +279,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg, // order if(arg->type == SCANLINE) order = arg->order; + else + order = arg->order; ir::value* dot_a = nullptr; ir::value* dot_b = nullptr; ir::value* hmma_dot_a = nullptr; @@ -291,7 +293,6 @@ layout_shared_t::layout_shared_t(const layout_t *arg, } std::vector col = {0, 1}; std::vector row = {1, 0}; - order = col; bool is_nonhmma_dot_a = dot_a && !hmma_dot_a; bool is_nonhmma_dot_b = dot_b && !hmma_dot_b; if(is_nonhmma_dot_a) @@ -303,21 +304,20 @@ layout_shared_t::layout_shared_t(const layout_t *arg, pad = 0; if(hmma_dot_a){ bool row = is_trans(hmma_dot_a) ^ order[0] == 1; - pad = 24 - shapes[row ? 0: 1] % 32; + pad = 24 - shapes[row ? order[0] : order[1]] % 32; } else if(hmma_dot_b){ bool row = is_trans(hmma_dot_b) ^ order[0] == 1; - pad = 24 - shapes[row ? 1 : 0] % 32; + pad = 24 - shapes[row ? order[1] : order[0]] % 32; } else if(order != arg->order) { pad = 4; } + shapes[order[0]] += pad; // size - auto shape = this->shapes; - shape[order[0]] += pad; size = ty->get_primitive_size_in_bits() / 8; - for(auto s: shape) + for(auto s: shapes) size *= s; if(double_buffer) size *= 2; diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc index cf5d9bf33..1c026bfc8 100644 --- a/lib/codegen/selection/machine_layout.cc +++ b/lib/codegen/selection/machine_layout.cc @@ -108,7 +108,6 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, tile* machine_layout_shared_t::create(ir::value *v) { auto order = layout_->order; auto shapes = layout_->shapes; - shapes[order[0]] += layout_->pad; Type* ty = llvm_type(layout_->ty, builder_->getContext()); // double-buffered if(layout_->double_buffer) { diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index c87e1c938..927f0044b 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -34,7 +34,7 @@ int main() { for(const auto& c: configs){ std::tie(ord, AT, BT, M, N, K) = c; std::cout << "// " << c << std::flush; - for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord)) + for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) std::cout << ", " << perf << std::flush; std::cout << std::endl; } diff --git a/tests/common/dot.h b/tests/common/dot.h index e87470edb..d9374de7a 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, if(mode == BENCH) { opt.defines.push_back({"TM", {"64", "128"}}); opt.defines.push_back({"TN", {"64", "128"}}); - opt.defines.push_back({"TK", {"8"}}); + opt.defines.push_back({"TK", {"8", "16"}}); opt.num_warps = {2, 4, 8}; }