[RUNTIME] Added option to print LLVM-IR
Also includes the corresponding driver code changes.
This commit is contained in:
@@ -131,15 +131,14 @@ template<> struct to_string<double>{
|
||||
};
|
||||
|
||||
template<class T>
|
||||
void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
const std::vector<int>& a_order, const std::vector<int>& b_order,
|
||||
std::vector<double>& bench, bool &test){
|
||||
float triton_dot(drv::context* context, drv::stream* stream,
|
||||
bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K){
|
||||
std::string ty = to_string<T>::value;
|
||||
size_t dt_nbytes = sizeof(T);
|
||||
drv::device* device = context->device();
|
||||
int32_t lda = (AT ^ a_order[0]==1) ? K : M;
|
||||
int32_t ldb = (BT ^ b_order[0]==1) ? N : K;
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = N;
|
||||
std::vector<std::string> sa = { "1", "lda" };
|
||||
std::vector<std::string> sb = { "1", "ldb" };
|
||||
@@ -156,18 +155,16 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
||||
ha[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = (float)rand()/RAND_MAX;
|
||||
// copy buffer
|
||||
stream->write(&*da, true, 0, ha);
|
||||
stream->write(&*db, true, 0, hb);
|
||||
|
||||
// macros
|
||||
rt::options_space_t opts;
|
||||
// A access patterns
|
||||
opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
||||
opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
||||
opts.defines.push_back({"STRIDE_AK", {AT? "1" : "lda" }});
|
||||
opts.defines.push_back({"STRIDE_AM", {AT? "lda" : "1" }});
|
||||
// B access patterns
|
||||
opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
|
||||
opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
|
||||
opts.defines.push_back({"STRIDE_BK", {BT? "ldb" : "1" }});
|
||||
opts.defines.push_back({"STRIDE_BN", {BT? "1" : "ldb" }});
|
||||
// data-type
|
||||
opts.defines.push_back({"TYPE", {ty}});
|
||||
// tile sizes
|
||||
@@ -190,8 +187,9 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
||||
rt::add_arg(oss, ldb);
|
||||
rt::add_arg(oss, ldc);
|
||||
rt::add_arg(oss, *dlocks->cu());
|
||||
// kernel
|
||||
// function
|
||||
rt::function function(src::dot, opts, device);
|
||||
// std::cout << function.get_kernels()[0].second->get_asm(rt::ASM_LLIR) << std::endl;
|
||||
// grid
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [ceil, M, N](const rt::options_t& x) {
|
||||
@@ -203,43 +201,37 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
|
||||
// metrics
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream);}, stream);
|
||||
bench.push_back(tflops(triton_ns));
|
||||
return tflops(triton_ns);
|
||||
}
|
||||
|
||||
// bench_dot: selects the triton_dot<T> instantiation that matches `dtype`
// (HALF -> half_float::half, FLOAT -> float, DOUBLE -> double) and forwards
// the context, stream, transposition flags and problem sizes to it.
//
// NOTE(review): this region looks like a rendered diff hunk with the +/-
// markers stripped. Two signatures are interleaved below: one returning
// std::vector<double> and taking a_order/b_order, and one returning float
// with `dtype` as the last parameter. Likewise there are two sets of switch
// cases. Presumably the float-returning variant is the current one -- confirm
// against the real file before relying on either.
std::vector<double> bench_dot(drv::context* context, drv::stream* stream,
|
||||
dtype_t dtype, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
const std::vector<int>& a_order, const std::vector<int>& b_order) {
|
||||
std::vector<double> bench;
|
||||
bool test;
|
||||
// Second signature: same leading arguments, no a_order/b_order, float result.
float bench_dot(drv::context* context, drv::stream* stream,
|
||||
bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
dtype_t dtype) {
|
||||
switch(dtype){
|
||||
// Variant that passes a_order/b_order and collects results through `bench`.
case HALF: triton_dot<half_float::half>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
||||
case FLOAT: triton_dot<float>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
||||
case DOUBLE: triton_dot<double>(context, stream, AT, BT, M, N, K, a_order, b_order, bench, test); break;
|
||||
default: break;
|
||||
// Variant that returns triton_dot's result directly; unknown dtypes yield 0.
case HALF: return triton_dot<half_float::half>(context, stream, AT, BT, M, N, K);
|
||||
case FLOAT: return triton_dot<float>(context, stream, AT, BT, M, N, K);
|
||||
case DOUBLE: return triton_dot<double>(context, stream, AT, BT, M, N, K);
|
||||
default: return 0;
|
||||
}
|
||||
return bench;
|
||||
}
|
||||
|
||||
|
||||
// Benchmark driver entry point: obtains the default context, creates a stream
// on that context's backend, then runs bench_dot once per entry in `configs`
// and prints the configuration together with the value bench_dot returns.
//
// NOTE(review): this region looks like a rendered diff hunk with the +/-
// markers stripped. `config_t` is typedef'd twice (a 6-tuple with a leading
// std::vector<int>, and a 5-tuple without it), the initializer list has one
// row for each shape, and the loop body contains two unpack/call/print
// sequences. Presumably the 5-tuple / float-returning path is the current
// one -- confirm against the real file.
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context->backend());
|
||||
// shapes to benchmark
|
||||
// Tuple layout with a leading order vector: (order, AT, BT, M, N, K).
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
|
||||
// Tuple layout without the order vector: (AT, BT, M, N, K).
typedef std::tuple<bool, bool, int, int, int> config_t;
|
||||
std::vector<config_t> configs = {
|
||||
{{1, 0}, false, false, 8192, 8192, 8192}
|
||||
{false, false, 8192, 8192, 8192}
|
||||
};
|
||||
// does the work
|
||||
std::vector<int> ord;
|
||||
bool AT, BT;
|
||||
int32_t M, N, K;
|
||||
// Element type handed to bench_dot in the float-returning call below.
dtype_t dtype = HALF;
|
||||
for(const auto& c: configs){
|
||||
// Sequence for the 6-tuple layout: prints the config, then each value
// in the vector returned by bench_dot (the same `ord` is passed for both
// order arguments).
std::tie(ord, AT, BT, M, N, K) = c;
|
||||
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K ;
|
||||
for(auto perf: bench_dot(context, stream, HALF, AT, BT, M, N, K, ord, ord))
|
||||
std::cout << ", " << perf << std::flush;
|
||||
std::cout << std::endl;
|
||||
// Sequence for the 5-tuple layout: a single bench_dot call whose float
// result is printed on the same line as the config.
std::tie(AT, BT, M, N, K) = c;
|
||||
float tflops = bench_dot(context, stream, AT, BT, M, N, K, dtype);
|
||||
std::cout << "// " << AT << ", " << BT << ", " << M << ", " << N << ", " << K << ", " << tflops << std::endl;
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user