[codegen] [layout] fixed padding issue for row-major HMMA
This commit is contained in:
@@ -279,6 +279,8 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
// order
|
// order
|
||||||
if(arg->type == SCANLINE)
|
if(arg->type == SCANLINE)
|
||||||
order = arg->order;
|
order = arg->order;
|
||||||
|
else
|
||||||
|
order = arg->order;
|
||||||
ir::value* dot_a = nullptr;
|
ir::value* dot_a = nullptr;
|
||||||
ir::value* dot_b = nullptr;
|
ir::value* dot_b = nullptr;
|
||||||
ir::value* hmma_dot_a = nullptr;
|
ir::value* hmma_dot_a = nullptr;
|
||||||
@@ -291,7 +293,6 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
}
|
}
|
||||||
std::vector<int> col = {0, 1};
|
std::vector<int> col = {0, 1};
|
||||||
std::vector<int> row = {1, 0};
|
std::vector<int> row = {1, 0};
|
||||||
order = col;
|
|
||||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||||
if(is_nonhmma_dot_a)
|
if(is_nonhmma_dot_a)
|
||||||
@@ -303,21 +304,20 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
pad = 0;
|
pad = 0;
|
||||||
if(hmma_dot_a){
|
if(hmma_dot_a){
|
||||||
bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
|
bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
|
||||||
pad = 24 - shapes[row ? 0: 1] % 32;
|
pad = 24 - shapes[row ? order[0] : order[1]] % 32;
|
||||||
}
|
}
|
||||||
else if(hmma_dot_b){
|
else if(hmma_dot_b){
|
||||||
bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
|
bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
|
||||||
pad = 24 - shapes[row ? 1 : 0] % 32;
|
pad = 24 - shapes[row ? order[1] : order[0]] % 32;
|
||||||
}
|
}
|
||||||
else if(order != arg->order) {
|
else if(order != arg->order) {
|
||||||
pad = 4;
|
pad = 4;
|
||||||
}
|
}
|
||||||
|
shapes[order[0]] += pad;
|
||||||
|
|
||||||
// size
|
// size
|
||||||
auto shape = this->shapes;
|
|
||||||
shape[order[0]] += pad;
|
|
||||||
size = ty->get_primitive_size_in_bits() / 8;
|
size = ty->get_primitive_size_in_bits() / 8;
|
||||||
for(auto s: shape)
|
for(auto s: shapes)
|
||||||
size *= s;
|
size *= s;
|
||||||
if(double_buffer)
|
if(double_buffer)
|
||||||
size *= 2;
|
size *= 2;
|
||||||
|
@@ -108,7 +108,6 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder,
|
|||||||
tile* machine_layout_shared_t::create(ir::value *v) {
|
tile* machine_layout_shared_t::create(ir::value *v) {
|
||||||
auto order = layout_->order;
|
auto order = layout_->order;
|
||||||
auto shapes = layout_->shapes;
|
auto shapes = layout_->shapes;
|
||||||
shapes[order[0]] += layout_->pad;
|
|
||||||
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
Type* ty = llvm_type(layout_->ty, builder_->getContext());
|
||||||
// double-buffered
|
// double-buffered
|
||||||
if(layout_->double_buffer) {
|
if(layout_->double_buffer) {
|
||||||
|
@@ -34,7 +34,7 @@ int main() {
|
|||||||
for(const auto& c: configs){
|
for(const auto& c: configs){
|
||||||
std::tie(ord, AT, BT, M, N, K) = c;
|
std::tie(ord, AT, BT, M, N, K) = c;
|
||||||
std::cout << "// " << c << std::flush;
|
std::cout << "// " << c << std::flush;
|
||||||
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
|
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
|
||||||
std::cout << ", " << perf << std::flush;
|
std::cout << ", " << perf << std::flush;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
@@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
if(mode == BENCH) {
|
if(mode == BENCH) {
|
||||||
opt.defines.push_back({"TM", {"64", "128"}});
|
opt.defines.push_back({"TM", {"64", "128"}});
|
||||||
opt.defines.push_back({"TN", {"64", "128"}});
|
opt.defines.push_back({"TN", {"64", "128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"8", "16"}});
|
||||||
opt.num_warps = {2, 4, 8};
|
opt.num_warps = {2, 4, 8};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user