[codegen] better handling of row/column-major
This commit is contained in:
@@ -27,11 +27,7 @@ class align;
|
||||
|
||||
enum layout_t {
|
||||
SCANLINE,
|
||||
HMMA_C,
|
||||
HMMA_A_COL,
|
||||
HMMA_A_ROW,
|
||||
HMMA_B_COL,
|
||||
HMMA_B_ROW
|
||||
HMMA_C
|
||||
};
|
||||
|
||||
class tiles {
|
||||
|
@@ -89,28 +89,43 @@ void liveness::connected_components(node_t x, std::set<node_t> &nodes, graph_t &
|
||||
}
|
||||
}
|
||||
|
||||
bool is_trans(ir::value *v) {
|
||||
if(dynamic_cast<ir::trans_inst *>(v)) {
|
||||
return true;
|
||||
}
|
||||
if(auto *phi = dynamic_cast<ir::instruction *>(v)) {
|
||||
bool result = true;
|
||||
for(ir::value *op: phi->ops())
|
||||
result = result && is_trans(op);
|
||||
return result;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool liveness::do_pad(ir::value *x) {
|
||||
// alignment for matrix product
|
||||
if(auto* dot = dynamic_cast<ir::dot_inst*>(x)) {
|
||||
// a
|
||||
ir::value *a = dot->get_operand(0);\
|
||||
size_t previous_a = pad_[a];
|
||||
if(tiles_->hmma(a) == HMMA_A_ROW)
|
||||
pad_[a] = 16;
|
||||
else if(tiles_->hmma(a) == HMMA_A_COL)
|
||||
pad_[a] = 8;
|
||||
else
|
||||
pad_[a] = 0;
|
||||
// b
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
size_t previous_b = pad_[b];
|
||||
if(tiles_->hmma(b) == HMMA_B_COL)
|
||||
pad_[b] = 16;
|
||||
if(tiles_->hmma(b) == HMMA_B_ROW)
|
||||
pad_[b] = 8;
|
||||
else
|
||||
pad_[b] = 0;
|
||||
return previous_a != pad_[a] || previous_b != pad_[b];
|
||||
size_t a_previous = pad_[a];
|
||||
size_t b_previous = pad_[b];
|
||||
auto a_order = tiles_->order(a);
|
||||
auto b_order = tiles_->order(b);
|
||||
bool a_row = is_trans(a) ^ (a_order[0] == 1);
|
||||
bool b_row = is_trans(b) ^ (b_order[0] == 1);
|
||||
auto a_shapes = a->get_type()->get_tile_shapes();
|
||||
auto b_shapes = b->get_type()->get_tile_shapes();
|
||||
pad_[a] = std::max<int>(pad_[a], (24 - a_shapes[a_row ? 0 : 1]) % 32);
|
||||
pad_[b] = std::max<int>(pad_[b], (24 - b_shapes[b_row ? 1 : 0]) % 32);
|
||||
return a_previous != pad_[a] || b_previous != pad_[b];
|
||||
}
|
||||
if(auto* trans = dynamic_cast<ir::trans_inst*>(x)) {
|
||||
ir::value *op = trans->get_operand(0);
|
||||
size_t previous = pad_[op];
|
||||
pad_[op] = std::max(pad_[op], pad_[x]);
|
||||
return previous != pad_[op];
|
||||
}
|
||||
if(auto* cts = dynamic_cast<ir::copy_to_shared_inst*>(x)) {
|
||||
auto cts_order = tiles_->order(cts);
|
||||
@@ -118,7 +133,7 @@ bool liveness::do_pad(ir::value *x) {
|
||||
auto arg_order = tiles_->order(arg);
|
||||
size_t previous = pad_[cts];
|
||||
if(cts_order != arg_order)
|
||||
pad_[cts] = 4;
|
||||
pad_[cts] = std::max<int>(pad_[cts], 4);
|
||||
return pad_[cts] != previous;
|
||||
}
|
||||
// padding for phi-nodes
|
||||
|
@@ -215,15 +215,7 @@ void tiles::run(ir::module &) {
|
||||
for(size_t i = 0; i < num_groups; i++) {
|
||||
const auto& values = layout_->values(i);
|
||||
bool hmma_c = std::any_of(values.begin(), values.end(), &is_hmma_c);
|
||||
bool hmma_a_col = std::any_of(values.begin(), values.end(), &is_hmma_a_col);
|
||||
bool hmma_a_row = std::any_of(values.begin(), values.end(), &is_hmma_a_row);
|
||||
bool hmma_b_col = std::any_of(values.begin(), values.end(), &is_hmma_b_col);
|
||||
bool hmma_b_row = std::any_of(values.begin(), values.end(), &is_hmma_b_row);
|
||||
if(hmma_c) hmma_[i] = HMMA_C;
|
||||
else if(hmma_a_col) hmma_[i] = HMMA_A_COL;
|
||||
else if(hmma_a_row) hmma_[i] = HMMA_A_ROW;
|
||||
else if(hmma_b_col) hmma_[i] = HMMA_B_COL;
|
||||
else if(hmma_b_row) hmma_[i] = HMMA_B_ROW;
|
||||
else hmma_[i] = SCANLINE;
|
||||
|
||||
}
|
||||
@@ -254,20 +246,33 @@ void tiles::run(ir::module &) {
|
||||
}
|
||||
order_[i] = order;
|
||||
}
|
||||
// for(size_t i = 0; i < num_groups; i++){
|
||||
// std::vector<ir::dot_inst*> dots;
|
||||
// for(ir::value* v: layout_->values(i))
|
||||
// if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
// dots.push_back(x);
|
||||
// for(ir::dot_inst* dot: dots){
|
||||
// ir::value* a = dot->get_operand(0);
|
||||
// ir::value* b = dot->get_operand(1);
|
||||
// std::vector<int> col = {0, 1};
|
||||
// std::vector<int> row = {1, 0};
|
||||
// order_[layout_->id(a)] = is_trans(a) ? row : col;
|
||||
// order_[layout_->id(b)] = is_trans(b) ? col : row;
|
||||
// }
|
||||
// }
|
||||
// matrix multiplication optimizations
|
||||
for(size_t i = 0; i < num_groups; i++){
|
||||
std::vector<ir::dot_inst*> dots;
|
||||
for(ir::value* v: layout_->values(i))
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v))
|
||||
dots.push_back(x);
|
||||
for(ir::dot_inst* dot: dots){
|
||||
ir::value* a = dot->get_operand(0);
|
||||
ir::value* b = dot->get_operand(1);
|
||||
if(hmma_.at(layout_->id(dot)) == HMMA_C){
|
||||
auto a_val = layout_->values(layout_->id(a));
|
||||
auto b_val = layout_->values(layout_->id(b));
|
||||
for(ir::value *v: a_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->id(a)] = order_[layout_->id(cts->get_operand(0))];
|
||||
for(ir::value *v: b_val)
|
||||
if(auto *cts = dynamic_cast<ir::copy_to_shared_inst*>(v))
|
||||
order_[layout_->id(b)] = order_[layout_->id(cts->get_operand(0))];
|
||||
}
|
||||
else{
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
order_[layout_->id(a)] = is_trans(a) ? row : col;
|
||||
order_[layout_->id(b)] = is_trans(b) ? col : row;
|
||||
}
|
||||
}
|
||||
}
|
||||
// tiling parameters
|
||||
for(auto x: largest_){
|
||||
ir::value *i = x.second;
|
||||
|
@@ -239,7 +239,9 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
axes.run(module);
|
||||
layouts.run(module);
|
||||
align.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
tiles.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
|
@@ -7,32 +7,34 @@ int main() {
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// shapes to benchmark
|
||||
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs;
|
||||
for(auto ord: std::vector<std::vector<int>>{{0, 1}, {1, 0}})
|
||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
|
||||
{true, false}, {true, true}}){
|
||||
std::vector<config_t> tmp = {
|
||||
config_t{x[0], x[1], 2048, 2048, 2048},
|
||||
// config_t{x[0], x[1], 16, 2048, 2048},
|
||||
// config_t{x[0], x[1], 32, 2048, 2048},
|
||||
// config_t{x[0], x[1], 64, 2048, 2048},
|
||||
// config_t{x[0], x[1], 128, 2048, 2048},
|
||||
// config_t{x[0], x[1], 7000, 2048, 2048},
|
||||
// config_t{x[0], x[1], 16, 4096, 4096},
|
||||
// config_t{x[0], x[1], 32, 4096, 4096},
|
||||
// config_t{x[0], x[1], 64, 4096, 4096},
|
||||
// config_t{x[0], x[1], 128, 4096, 4096},
|
||||
// config_t{x[0], x[1], 7000, 4096, 4096}
|
||||
config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 32, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 64, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 128, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 7000, 2048, 2048},
|
||||
// config_t{ord, x[0], x[1], 16, 4096, 4096},
|
||||
// config_t{ord, x[0], x[1], 32, 4096, 4096},
|
||||
// config_t{ord, x[0], x[1], 64, 4096, 4096},
|
||||
// config_t{ord, x[0], x[1], 128, 4096, 4096},
|
||||
// config_t{ord, x[0], x[1], 7000, 4096, 4096}
|
||||
};
|
||||
configs.insert(configs.end(), tmp.begin(), tmp.end());
|
||||
}
|
||||
// does the work
|
||||
std::vector<int> ord;
|
||||
bool AT, BT;
|
||||
int32_t M, N, K;
|
||||
for(const auto& c: configs){
|
||||
std::tie(AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << AT << " " << BT << " " << M << " " << N << " " << K << std::flush;
|
||||
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K))
|
||||
std::tie(ord, AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << c << std::flush;
|
||||
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
|
||||
std::cout << ", " << perf << std::flush;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
@@ -19,7 +19,7 @@ static void cc_dot(std::vector<T> &c, const std::vector<T> &a, const std::vector
|
||||
for(size_t n = 0; n < N; n++){
|
||||
float acc = 0;
|
||||
for(size_t k = 0; k < K; k++)
|
||||
acc = acc + (AT ? a[k*M + m] : a[m*K + k]) * (BT ? b[n*K + k] : b[k*N + n]);
|
||||
acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]);
|
||||
c[m + n*M] = static_cast<T>(acc);
|
||||
}
|
||||
}
|
||||
@@ -67,6 +67,7 @@ template<class T>
|
||||
bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
int32_t TM, int32_t TN, int32_t TK, size_t nwarp,
|
||||
const std::vector<int>& a_order, const std::vector<int>& b_order,
|
||||
run_mode_t mode, std::vector<double>& bench, bool &test){
|
||||
std::string ty = to_string<T>::value;
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
@@ -74,6 +75,8 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
std::vector<std::string> sa = { "1", "lda" };
|
||||
std::vector<std::string> sb = { "1", "ldb" };
|
||||
|
||||
// inputs
|
||||
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
|
||||
@@ -82,20 +85,20 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
|
||||
// macros
|
||||
rt::function::options_space_t opt;
|
||||
// B access patterns
|
||||
opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
|
||||
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
|
||||
opt.defines.push_back({"STRIDE_BK", {BT? "1" : "ldb" }});
|
||||
opt.defines.push_back({"STRIDE_BN", {BT? "ldb" : "1" }});
|
||||
// A access patterns
|
||||
opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
|
||||
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
|
||||
opt.defines.push_back({"STRIDE_AK", {AT? "lda" : "1" }});
|
||||
opt.defines.push_back({"STRIDE_AM", {AT? "1" : "lda" }});
|
||||
opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
|
||||
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
|
||||
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
||||
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
||||
// B access patterns
|
||||
opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
|
||||
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
|
||||
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
|
||||
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
|
||||
opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
|
||||
opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
|
||||
// data-type
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
// tile sizes
|
||||
@@ -164,13 +167,14 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
|
||||
std::vector<double> bench_dot(drv::stream* stream,
|
||||
dtype_t dtype, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K) {
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
const std::vector<int>& a_order, const std::vector<int>& b_order) {
|
||||
std::vector<double> bench;
|
||||
bool test;
|
||||
switch(dtype){
|
||||
case HALF: triton_dot<half_float::half>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
|
||||
case FLOAT: triton_dot<float>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
|
||||
case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, 0, 0, 0, 0, BENCH, bench, test); break;
|
||||
case HALF: triton_dot<half_float::half>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
|
||||
case FLOAT: triton_dot<float>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
|
||||
case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, 0, 0, 0, 0, a_order, b_order, BENCH, bench, test); break;
|
||||
default: break;
|
||||
}
|
||||
return bench;
|
||||
@@ -178,13 +182,14 @@ std::vector<double> bench_dot(drv::stream* stream,
|
||||
bool test_dot(drv::stream* stream,
|
||||
dtype_t dtype, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
const std::vector<int>& a_order, const std::vector<int>& b_order,
|
||||
int32_t TM, int32_t TN, int32_t TK, size_t nwarp) {
|
||||
std::vector<double> bench;
|
||||
bool test = false;
|
||||
switch(dtype){
|
||||
case HALF: triton_dot<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
|
||||
case FLOAT: triton_dot<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
|
||||
case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, TEST, bench, test); break;
|
||||
case HALF: triton_dot<half_float::half>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
|
||||
case FLOAT: triton_dot<float>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
|
||||
case DOUBLE: triton_dot<double>(stream, AT, BT, M, N, K, TM, TN, TK, nwarp, a_order, b_order, TEST, bench, test); break;
|
||||
default: break;
|
||||
}
|
||||
return test;
|
||||
|
@@ -25,7 +25,7 @@ int main() {
|
||||
for(const auto& c: configs){
|
||||
std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
|
||||
std::cout << "Testing " << c << " ... " << std::flush;
|
||||
if(test_dot(stream, dtype, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
|
||||
if(test_dot(stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
|
||||
std::cout << " Pass! " << std::endl;
|
||||
else{
|
||||
std::cout << " Fail! " << std::endl;
|
||||
|
Reference in New Issue
Block a user