[examples] multiple transposition schemes now supported
This commit is contained in:
@@ -45,81 +45,67 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
||||
|
||||
|
||||
|
||||
std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::string c_ty, int align_lda, int align_ldb) {
|
||||
std::string ZS = "1";
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
|
||||
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "^a" : "a";
|
||||
std::string useb = BT ? "^b" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(XAS1, XAS2);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS1, XBS2);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XCS = "TM, TN";
|
||||
std::string align_lda_str = "multipleof(" + std::to_string(align_lda) + ")";
|
||||
std::string align_ldb_str = "multipleof(" + std::to_string(align_ldb) + ")";
|
||||
std::string res =
|
||||
std::string src =
|
||||
R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
#ifdef AT
|
||||
#define USEA ^a
|
||||
#else
|
||||
#define USEA a
|
||||
#endif
|
||||
|
||||
#define __readonly __attribute__((readonly))
|
||||
#define __writeonly __attribute__((writeonly))
|
||||
#define __noalias __attribute__((noalias))
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
#ifdef BT
|
||||
#define USEB ^b
|
||||
#else
|
||||
#define USEB b
|
||||
#endif
|
||||
|
||||
extern int get_program_id(int);
|
||||
|
||||
void matmul()" + a_ty + R"( * A __noalias __readonly __aligned(16),
|
||||
)" + b_ty + R"( * B __noalias __readonly __aligned(16),
|
||||
)" + c_ty + R"( * C __noalias __readonly __aligned(16),
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||
TYPE * B __noalias __readonly __aligned(16),
|
||||
TYPE * C __noalias __readonly __aligned(16),
|
||||
int M, int N, int K,
|
||||
int lda __multipleof(8),
|
||||
int ldb __multipleof(8),
|
||||
int ldc) {
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
|
||||
float xc[TM, TN] = 0;
|
||||
#ifdef AT
|
||||
TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda;
|
||||
TYPE a[TK, TM] = *pa;
|
||||
#else
|
||||
TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis];
|
||||
TYPE a[TM, TK] = *pa;
|
||||
#endif
|
||||
#ifdef BT
|
||||
TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis];
|
||||
TYPE b[TN, TK] = *pb;
|
||||
#else
|
||||
TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb;
|
||||
TYPE b[TK, TN] = *pb;
|
||||
#endif
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
xc = USEA @ USEB + xc;
|
||||
#ifdef AT
|
||||
pa = pa + TK;
|
||||
#else
|
||||
pa = pa + TK*lda;
|
||||
#endif
|
||||
#ifdef BT
|
||||
pb = pb + TK*ldb;
|
||||
#else
|
||||
pb = pb + TK;
|
||||
#endif
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[TM, TN] = xc;
|
||||
TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
TYPE c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@@ -127,9 +113,6 @@ void matmul()" + a_ty + R"( * A __noalias __readonly __aligned(16),
|
||||
}
|
||||
)";
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
struct perf_t {
|
||||
double triton;
|
||||
double cublas;
|
||||
@@ -165,11 +148,16 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
|
||||
stream->synchronize();
|
||||
// run
|
||||
rt::function::options_space_t opt;
|
||||
opt.defines.push_back({"TYPE", {ty}});
|
||||
if(AT)
|
||||
opt.defines.push_back({"AT", {""}});
|
||||
if(BT)
|
||||
opt.defines.push_back({"BT", {""}});
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TK", {"32"}});
|
||||
opt.num_warps = {1, 2, 4, 8};
|
||||
rt::function function(src(AT, BT, ty, ty, ty, 8, 8), opt);
|
||||
rt::function function(src, opt);
|
||||
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [&](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D<int>("TM")), ceil(N, x.D<int>("TN")), 1}; };
|
||||
@@ -220,7 +208,7 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 128, 128, 128}
|
||||
{false, false, 128, 128, 128}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
|
@@ -112,8 +112,12 @@ bool peephole::rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool tr
|
||||
trans_a = true;
|
||||
}
|
||||
}
|
||||
if(!trans_a && !trans_b)
|
||||
return false;
|
||||
|
||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||
dot->replace_all_uses_with(dot_atbt);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -186,8 +190,9 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
if(dot->is_a_trans() || dot->is_b_trans())
|
||||
return false;
|
||||
// hmma
|
||||
if(is_hmma(dot))
|
||||
if(is_hmma(dot)){
|
||||
return rewrite_dot_hmma(dot, builder, trans_a, trans_b, A, B, D);
|
||||
}
|
||||
else
|
||||
return rewrite_dot_fp32(dot, builder, trans_a, trans_b, A, B, D);
|
||||
}
|
||||
|
@@ -206,9 +206,25 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return res;
|
||||
}
|
||||
|
||||
std::string preheader() {
|
||||
return R"(
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
#define __readonly __attribute__((readonly))
|
||||
#define __writeonly __attribute__((writeonly))
|
||||
#define __noalias __attribute__((noalias))
|
||||
#define __aligned(A) __attribute__((aligned(A)))
|
||||
#define __multipleof(A) __attribute__((multipleof(A)))
|
||||
|
||||
extern int get_program_id(int);
|
||||
)";
|
||||
}
|
||||
|
||||
function::function(const std::string &src, const options_space_t& opt): src_(src), opt_space_(opt) {
|
||||
|
||||
src_ = preheader() + src_;
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
|
Reference in New Issue
Block a user