[runtime] overall of the run-time API
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
#include <cstring>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cstdio>
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/dnn/dot.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/external/half.hpp"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "cuda.h"
|
||||
|
||||
template<class T>
|
||||
@@ -19,20 +19,125 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
||||
template<class T, bool AT, bool BT>
|
||||
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
|
||||
size_t M, size_t N, size_t K){
|
||||
for(size_t m = 0; m < M; m++)
|
||||
for(size_t n = 0; n < N; n++){
|
||||
float acc = 0;
|
||||
for(size_t k = 0; k < K; k++)
|
||||
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
|
||||
c[m + n*M] = static_cast<T>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
|
||||
std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
|
||||
if(AT_ && BT_)
|
||||
cpu_ref<T, true, true>(c, a, b, M, N, K);
|
||||
else if(AT_ && !BT_)
|
||||
cpu_ref<T, true, false>(c, a, b, M, N, K);
|
||||
else if(!AT_ && BT_)
|
||||
cpu_ref<T, false, true>(c, a, b, M, N, K);
|
||||
else
|
||||
cpu_ref<T, false, false>(c, a, b, M, N, K);
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::string c_ty, int align_lda, int align_ldb) {
|
||||
std::string ZS = "1";
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
|
||||
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
std::swap(XAS1, XAS2);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(XBS1, XBS2);
|
||||
std::swap(XBS0, XBS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string AS = AS0 + ", " + AS1;
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
std::string XCS = "TM, TN";
|
||||
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
|
||||
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int TM = {128};
|
||||
const tunable int TN = {128};
|
||||
const tunable int TK = {32};
|
||||
|
||||
void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
|
||||
restrict read_only align(16) )" + b_ty + R"( *B,
|
||||
restrict read_only align(16) )" + c_ty + R"( *C,
|
||||
int M, int N, int K,
|
||||
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = c;
|
||||
}
|
||||
)";
|
||||
return res;
|
||||
}
|
||||
|
||||
struct perf_t {
|
||||
double triton;
|
||||
double cublas;
|
||||
};
|
||||
|
||||
namespace drv = triton::driver;
|
||||
namespace rt = triton::runtime;
|
||||
|
||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef half NumericT;
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
triton::driver::context* context = stream->context();
|
||||
drv::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
std::vector<NumericT> ha(M*K);
|
||||
std::vector<NumericT> hb(K*N);
|
||||
int32_t lda = AT ? K : M;
|
||||
int32_t ldb = BT ? N : K;
|
||||
int32_t ldc = M;
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
@@ -40,54 +145,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = static_cast<NumericT>((double)0);
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
|
||||
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
|
||||
drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
|
||||
drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
// int32_t lda = AT ? K : M;
|
||||
// int32_t ldb = BT ? N : K;
|
||||
// int32_t ldc = M;
|
||||
// cublasGemmAlgo_t fastest;
|
||||
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
// dc, ldc, &fastest);
|
||||
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
|
||||
// &alpha, da, lda,
|
||||
// db, ldb, &beta,
|
||||
// dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
|
||||
// result
|
||||
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
|
||||
// run
|
||||
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
|
||||
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
|
||||
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };
|
||||
|
||||
perf_t result;
|
||||
// result.cublas = tflops(cublas_ns);
|
||||
result.triton = tflops(triton_ns);
|
||||
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
|
||||
perf_t res;
|
||||
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
|
||||
res.cublas = 0;
|
||||
|
||||
// test
|
||||
stream->read(dc, true, 0, hc);
|
||||
std::vector<NumericT> rc(hc.size());
|
||||
dot.cpu_ref(rc, ha, hb);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
// stream->synchronize();
|
||||
// stream->read(dc, true, 0, hc);
|
||||
// std::vector<NumericT> rc(hc.size());
|
||||
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
|
||||
// for(size_t i = 0; i < M*N; i++)
|
||||
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
|
||||
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
// exit(EXIT_FAILURE);
|
||||
// }
|
||||
// std::cout << "Pass!" << std::endl;
|
||||
|
||||
// clean-up
|
||||
delete dc;
|
||||
delete da;
|
||||
delete db;
|
||||
return result;
|
||||
return res;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@@ -111,12 +202,11 @@ int main() {
|
||||
// shapes to benchmark
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
{false, true, 128, 128, 128}
|
||||
{false, true, 8192, 8192, 8192}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
// {true, true, 128, 128, 128}
|
||||
|
||||
// {false, true, 32768, 256, 512}
|
||||
// {true, false, 8192, 512, 512},
|
||||
// {true, true, 8192, 512, 512}
|
||||
|
@@ -38,16 +38,14 @@ private:
|
||||
void create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn);
|
||||
unsigned get_req_num_threads(ir::instruction *i);
|
||||
|
||||
|
||||
public:
|
||||
tune();
|
||||
tune(size_t num_warps);
|
||||
std::vector<ir::metaparameter *> get_params(ir::module& mod);
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
void copy(ir::value *dst, ir::value *src);
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
@@ -64,7 +62,7 @@ private:
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
std::vector<ir::value*> grids_;
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
ir::metaparameter* num_warps_;
|
||||
size_t num_warps_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -185,8 +185,8 @@ private:
|
||||
|
||||
|
||||
public:
|
||||
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
|
||||
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *alignment, target *tgt)
|
||||
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), tgt_(tgt){ }
|
||||
|
||||
void run(ir::module &src, Module &dst);
|
||||
|
||||
@@ -197,7 +197,7 @@ private:
|
||||
analysis::tune *params_;
|
||||
target *tgt_;
|
||||
analysis::shmem::info *buffer_info_;
|
||||
analysis::alignment_info *axis_info_;
|
||||
analysis::alignment_info *alignment_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
Value *sh_mem_ptr_;
|
||||
Value *offset_a_i_, *offset_a_k_;
|
||||
|
@@ -66,16 +66,18 @@ public:
|
||||
type *get_pointer_element_ty() const;
|
||||
|
||||
// primitive predicates
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||
bool is_void_ty() const { return id_ == VoidTyID; }
|
||||
bool is_half_ty() const { return id_ == HalfTyID; }
|
||||
bool is_float_ty() const { return id_ == FloatTyID; }
|
||||
bool is_double_ty() const { return id_ == DoubleTyID; }
|
||||
bool is_label_ty() const { return id_ == LabelTyID;}
|
||||
bool is_metadata_ty() const { return id_ == MetadataTyID; }
|
||||
bool is_token_ty() const { return id_ == TokenTyID; }
|
||||
bool is_integer_ty() const { return id_ == IntegerTyID; }
|
||||
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
|
||||
get_integer_bitwidth() == bitwidth;}
|
||||
bool is_pointer_ty() const { return id_ == PointerTyID; }
|
||||
bool is_tile_ty() const { return id_ == TileTyID; }
|
||||
|
||||
// Composite predicates
|
||||
bool is_int_or_tileint_ty();
|
||||
|
80
include/triton/runtime/arg.h
Normal file
80
include/triton/runtime/arg.h
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef TDL_INCLUDE_ARG_H
|
||||
#define TDL_INCLUDE_ARG_H
|
||||
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace triton{
|
||||
|
||||
namespace driver{
|
||||
class buffer;
|
||||
}
|
||||
|
||||
namespace runtime {
|
||||
|
||||
enum arg_type {
|
||||
INT1_T,
|
||||
INT8_T,
|
||||
INT16_T,
|
||||
INT32_T,
|
||||
INT64_T,
|
||||
HALF_T,
|
||||
FLOAT_T,
|
||||
DOUBLE_T,
|
||||
BUFFER_T
|
||||
};
|
||||
|
||||
size_t size_of(arg_type ty){
|
||||
switch(ty){
|
||||
case INT1_T: return 1;
|
||||
case INT8_T: return 1;
|
||||
case INT16_T: return 2;
|
||||
case INT32_T: return 4;
|
||||
case INT64_T: return 8;
|
||||
case HALF_T: return 2;
|
||||
case FLOAT_T: return 4;
|
||||
case DOUBLE_T: return 8;
|
||||
case BUFFER_T: return 8;
|
||||
default: throw std::runtime_error("unknown type");
|
||||
}
|
||||
}
|
||||
|
||||
bool is_int_type(arg_type ty){
|
||||
return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
|
||||
ty == INT32_T || ty == INT64_T;
|
||||
}
|
||||
|
||||
class arg {
|
||||
private:
|
||||
union value_t {
|
||||
bool int1;
|
||||
int8_t int8;
|
||||
int16_t int16;
|
||||
int32_t int32;
|
||||
int64_t int64;
|
||||
float fp32;
|
||||
double fp64;
|
||||
driver::buffer* buf;
|
||||
};
|
||||
|
||||
public:
|
||||
// construct from primitive types
|
||||
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
|
||||
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
|
||||
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
|
||||
arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
|
||||
arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
|
||||
// accessors
|
||||
arg_type type() const { return ty_; }
|
||||
void* data() const { return (void*)&val_; }
|
||||
|
||||
|
||||
private:
|
||||
arg_type ty_;
|
||||
value_t val_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
113
include/triton/runtime/function.h
Normal file
113
include/triton/runtime/function.h
Normal file
@@ -0,0 +1,113 @@
|
||||
#ifndef TDL_INCLUDE_FUNCTION_H
|
||||
#define TDL_INCLUDE_FUNCTION_H
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "arg.h"
|
||||
// codegen
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/codegen/selection/target.h"
|
||||
#include "triton/codegen/analysis/tune.h"
|
||||
#include "triton/codegen/analysis/shmem/allocation.h"
|
||||
#include "triton/codegen/analysis/shmem/liveness.h"
|
||||
#include "triton/codegen/analysis/shmem/info.h"
|
||||
#include "triton/codegen/analysis/alignment.h"
|
||||
#include "triton/codegen/transform/dce.h"
|
||||
#include "triton/codegen/transform/peephole.h"
|
||||
#include "triton/codegen/transform/shmem/barriers.h"
|
||||
#include "triton/codegen/transform/reassociate.h"
|
||||
#include "triton/codegen/transform/vectorize.h"
|
||||
|
||||
namespace llvm {
|
||||
class Module;
|
||||
class LLVMContext;
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
namespace driver{
|
||||
class module;
|
||||
class stream;
|
||||
class kernel;
|
||||
class context;
|
||||
class device;
|
||||
}
|
||||
|
||||
namespace lang{
|
||||
class translation_unit;
|
||||
}
|
||||
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
class tune;
|
||||
}
|
||||
}
|
||||
|
||||
namespace ir {
|
||||
class module;
|
||||
class function;
|
||||
class context;
|
||||
class metaparameter;
|
||||
}
|
||||
|
||||
namespace runtime{
|
||||
|
||||
|
||||
typedef std::array<size_t, 3> grid_t;
|
||||
typedef std::map<std::string, size_t> params_t;
|
||||
|
||||
struct options {
|
||||
size_t num_warps;
|
||||
params_t params;
|
||||
};
|
||||
|
||||
|
||||
class function {
|
||||
public:
|
||||
typedef std::function<grid_t(const params_t&)> grid_fn_ty;
|
||||
|
||||
private:
|
||||
class caller {
|
||||
public:
|
||||
caller(ir::function *ir, std::shared_ptr<driver::module> program, size_t n_threads);
|
||||
void operator()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<driver::module> parent_;
|
||||
std::shared_ptr<driver::kernel> bin_;
|
||||
std::vector<arg_type> param_tys_;
|
||||
size_t n_threads_;
|
||||
};
|
||||
|
||||
private:
|
||||
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
|
||||
typedef std::pair<options, caller> cache_val_t;
|
||||
|
||||
private:
|
||||
triton::lang::translation_unit *make_ast(const char *src);
|
||||
std::unique_ptr<ir::module> make_ir(triton::lang::translation_unit *program);
|
||||
options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
|
||||
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
|
||||
|
||||
|
||||
public:
|
||||
function(const std::string& src);
|
||||
void operator()(const std::vector<arg>& args, const std::array<size_t, 3>& grid, driver::stream* stream);
|
||||
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
|
||||
|
||||
private:
|
||||
// execution context
|
||||
ir::context ctx_;
|
||||
// program representations
|
||||
std::string src_;
|
||||
lang::translation_unit *ast_;
|
||||
std::map<cache_key_t, cache_val_t> cache_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
@@ -58,7 +58,8 @@ public:
|
||||
|
||||
struct passes_wrapper {
|
||||
passes_wrapper(codegen::target* target)
|
||||
: shmem_liveness(&shmem_info),
|
||||
: tune(0),
|
||||
shmem_liveness(&shmem_info),
|
||||
shmem_allocation(&shmem_liveness, &shmem_info, &tune),
|
||||
shmem_barriers(&shmem_allocation, &shmem_info),
|
||||
vectorize(&tune),
|
||||
|
@@ -3,6 +3,7 @@
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
@@ -29,7 +30,7 @@ private:
|
||||
|
||||
inline double bench(std::function<void()> const & op, driver::stream * stream)
|
||||
{
|
||||
const driver::device * device = stream->context()->device();
|
||||
// const driver::device * device = stream->context()->device();
|
||||
timer tmr;
|
||||
std::vector<size_t> times;
|
||||
double total_time = 0;
|
||||
|
@@ -14,7 +14,7 @@ namespace triton{
|
||||
namespace codegen{
|
||||
namespace analysis{
|
||||
|
||||
tune::tune() {
|
||||
tune::tune(size_t num_warps): num_warps_(num_warps){
|
||||
}
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
@@ -183,20 +183,17 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
}
|
||||
|
||||
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
|
||||
std::vector<ir::metaparameter*> result;
|
||||
std::set<ir::metaparameter*> seen;
|
||||
for(auto x: mod.globals()) {
|
||||
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
if(seen.insert(mp).second && !mp->has_value())
|
||||
result.push_back(mp);
|
||||
}
|
||||
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
result.push_back(num_warps_);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
|
||||
return params_.at(i);
|
||||
throw std::runtime_error("remove me");
|
||||
// std::vector<ir::metaparameter*> result;
|
||||
// std::set<ir::metaparameter*> seen;
|
||||
// for(auto x: mod.globals()) {
|
||||
// if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
|
||||
// if(seen.insert(mp).second && !mp->has_value())
|
||||
// result.push_back(mp);
|
||||
// }
|
||||
// num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
|
||||
// result.push_back(num_warps_);
|
||||
// return result;
|
||||
}
|
||||
|
||||
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
@@ -257,7 +254,6 @@ void tune::init(ir::module &mod) {
|
||||
}
|
||||
|
||||
int num_threads = get_num_threads();
|
||||
int num_warps = num_warps_->get_value();
|
||||
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
|
||||
|
||||
for(ir::value *i: grids_){
|
||||
@@ -292,9 +288,9 @@ void tune::init(ir::module &mod) {
|
||||
std::vector<int> wpt_nm1;
|
||||
do{
|
||||
wpt_nm1 = wpt;
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
|
||||
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
|
||||
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
|
||||
}while(wpt_nm1 != wpt);
|
||||
// store parameters
|
||||
@@ -307,7 +303,7 @@ void tune::init(ir::module &mod) {
|
||||
std::string str_d = std::to_string(d);
|
||||
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
|
||||
}
|
||||
assert(num_warps == effective_num_warps);
|
||||
assert(num_warps_ == effective_num_warps);
|
||||
}
|
||||
|
||||
/* Scan-line */
|
||||
@@ -386,7 +382,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
}
|
||||
|
||||
unsigned tune::get_num_threads() {
|
||||
return num_warps_->get_value()*32;
|
||||
return num_warps_*32;
|
||||
}
|
||||
|
||||
|
||||
|
@@ -1208,8 +1208,8 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
|
||||
// find vector size
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
@@ -1280,8 +1280,8 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
// find vector size
|
||||
ir::value *ptr = x->get_pointer_operand();
|
||||
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
|
||||
unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
|
||||
unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
|
||||
unsigned alignment = std::min(starting_multiple, max_contiguous);
|
||||
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
|
||||
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
|
||||
|
@@ -54,6 +54,7 @@ size_t buffer::size() {
|
||||
return size_;
|
||||
}
|
||||
|
||||
|
||||
buffer* buffer::create(driver::context* ctx, size_t size) {
|
||||
switch(ctx->backend()){
|
||||
case CUDA: return new cu_buffer(ctx, size);
|
||||
|
0
lib/runtime/arg.cpp
Normal file
0
lib/runtime/arg.cpp
Normal file
265
lib/runtime/function.cpp
Normal file
265
lib/runtime/function.cpp
Normal file
@@ -0,0 +1,265 @@
|
||||
#include <string>
|
||||
#include <mutex>
|
||||
#include <functional>
|
||||
#include "triton/codegen/selection/selection.h"
|
||||
#include "triton/runtime/function.h"
|
||||
#include "triton/lang/lang.h"
|
||||
#include "triton/driver/device.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include "triton/driver/module.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
|
||||
typedef struct yy_buffer_state * YY_BUFFER_STATE;
|
||||
extern int yyparse();
|
||||
extern YY_BUFFER_STATE yy_scan_string(const char * str);
|
||||
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
|
||||
extern triton::lang::translation_unit *ast_root;
|
||||
|
||||
namespace triton{
|
||||
namespace runtime {
|
||||
|
||||
|
||||
// helpers
|
||||
void _parallel_loop_nest(std::vector<size_t> const & ranges,
|
||||
std::function<void(std::vector<size_t> const &)> const & f,
|
||||
size_t nthreads){
|
||||
size_t D = ranges.size();
|
||||
std::vector<size_t> values(D, 0);
|
||||
// Start with innermost loop
|
||||
size_t i = D - 1;
|
||||
while(true){
|
||||
// Execute function
|
||||
f(values);
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
return;
|
||||
values[i--] = 0;
|
||||
}
|
||||
i = D - 1;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void _parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
|
||||
//Ranges to iterate over
|
||||
std::vector<size_t> ranges;
|
||||
for(auto const & x: iterates)
|
||||
ranges.push_back(x.size());
|
||||
//Proxy function
|
||||
auto proxy = [&](std::vector<size_t> const & idx){
|
||||
std::vector<T> x(iterates.size());
|
||||
for(size_t i = 0; i < x.size(); ++i)
|
||||
x[i] = iterates[i][idx[i]];
|
||||
f(x);
|
||||
};
|
||||
//Iterate
|
||||
_parallel_loop_nest(ranges, proxy, nthreads);
|
||||
}
|
||||
|
||||
// caller
|
||||
|
||||
arg_type convert(ir::type *ty) {
|
||||
if(ty->is_integer_ty(1))
|
||||
return INT1_T;
|
||||
if(ty->is_integer_ty(8))
|
||||
return INT8_T;
|
||||
if(ty->is_integer_ty(16))
|
||||
return INT16_T;
|
||||
if(ty->is_integer_ty(32))
|
||||
return INT32_T;
|
||||
if(ty->is_integer_ty(64))
|
||||
return INT64_T;
|
||||
if(ty->is_half_ty())
|
||||
return HALF_T;
|
||||
if(ty->is_float_ty())
|
||||
return FLOAT_T;
|
||||
if(ty->is_double_ty())
|
||||
return DOUBLE_T;
|
||||
if(ty->is_pointer_ty())
|
||||
return BUFFER_T;
|
||||
throw std::runtime_error("unknown type");
|
||||
}
|
||||
|
||||
function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> parent, size_t n_threads)
|
||||
: bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), n_threads_(n_threads), parent_(parent) {
|
||||
// extract signature
|
||||
ir::function_type* ty = ir->get_fn_type();
|
||||
for(int i = 0; i < ty->get_num_params(); i++)
|
||||
param_tys_.push_back(convert(ty->get_param_ty(i)));
|
||||
}
|
||||
|
||||
|
||||
void function::caller::operator ()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const {
|
||||
if(args.size() != param_tys_.size())
|
||||
throw std::runtime_error("invalid number of arguments");
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
arg arg_i = args.at(i);
|
||||
arg_type ty = arg_i.type();
|
||||
if(ty != param_tys_.at(i))
|
||||
throw std::runtime_error("invalid type");
|
||||
if(ty == BUFFER_T)
|
||||
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
|
||||
else
|
||||
bin_->setArg(i, size_of(ty), arg_i.data());
|
||||
}
|
||||
stream->enqueue(&*bin_, grid, {n_threads_, 1, 1});
|
||||
}
|
||||
|
||||
|
||||
|
||||
// module
|
||||
triton::lang::translation_unit *function::make_ast(const char *src) {
|
||||
YY_BUFFER_STATE buffer = yy_scan_string(src);
|
||||
yyparse();
|
||||
yy_delete_buffer(buffer);
|
||||
triton::lang::translation_unit *program = ast_root;
|
||||
return program;
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::module> function::make_ir(triton::lang::translation_unit *program) {
|
||||
// create Triton-IR from AST
|
||||
ir::module* module = new ir::module("", ctx_);
|
||||
program->codegen(module);
|
||||
return std::unique_ptr<ir::module>(module);
|
||||
}
|
||||
|
||||
options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
|
||||
std::unique_ptr<ir::module> ir = make_ir(ast);
|
||||
// extract tunable values
|
||||
std::vector<std::pair<std::string, ir::metaparameter*>> values;
|
||||
for(auto it: ir->globals())
|
||||
if(auto *mp = dynamic_cast<ir::metaparameter*>(it.second))
|
||||
values.push_back({it.first, mp});
|
||||
// extract search space
|
||||
std::vector<std::vector<unsigned>> space;
|
||||
space.push_back({1, 2, 4, 8}); // num warps
|
||||
for(auto it: values)
|
||||
space.push_back(it.second->get_space());
|
||||
// exhaustive search
|
||||
struct profile_t{
|
||||
double ts;
|
||||
std::vector<unsigned> params;
|
||||
};
|
||||
profile_t best = { INFINITY };
|
||||
std::function<void(std::vector<unsigned>)> benchmark =
|
||||
[&](std::vector<unsigned> params) {
|
||||
// options
|
||||
options opt;
|
||||
unsigned i = 0;
|
||||
opt.num_warps = params[i++];
|
||||
for(auto it: values)
|
||||
opt.params[it.first] = params[i++];
|
||||
// make binary
|
||||
auto ir = make_ir(ast);
|
||||
auto bin = make_bin(*ir, stream->context(), opt);
|
||||
// benchmark
|
||||
ir::function *tmp = ir->get_function_list()[0];
|
||||
caller fn(tmp, std::move(bin), opt.num_warps * 32);
|
||||
double ts = tools::bench([&]() { fn(stream, grid_fn(opt.params), args); }, stream);
|
||||
if(ts < best.ts)
|
||||
best = {ts, params};
|
||||
};
|
||||
_parallel_loop_nest<unsigned>(space, benchmark, 1);
|
||||
// populate options
|
||||
unsigned current = 0;
|
||||
options opt;
|
||||
opt.num_warps = best.params[current++];
|
||||
for(auto it: values)
|
||||
opt.params[it.first] = best.params[current++];
|
||||
return opt;
|
||||
}
|
||||
|
||||
|
||||
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options& opt) {
|
||||
std::unique_ptr<codegen::target> target = context->device()->make_target();
|
||||
// update metaparameter values
|
||||
for(auto x: opt.params)
|
||||
if(auto* mp = dynamic_cast<ir::metaparameter*>(module.globals().at(x.first)))
|
||||
mp->set_value(x.second);
|
||||
// create passes
|
||||
codegen::analysis::tune tune(opt.num_warps);
|
||||
codegen::analysis::shmem::info shmem_info;
|
||||
codegen::analysis::shmem::liveness shmem_liveness(&shmem_info);
|
||||
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tune);
|
||||
codegen::analysis::alignment_info alignment_info;
|
||||
codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info);
|
||||
codegen::transform::vectorize vectorize(&tune);
|
||||
codegen::transform::dce dce;
|
||||
codegen::transform::peephole peephole;
|
||||
codegen::transform::reassociate reassociate(&tune);
|
||||
codegen::selection selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target.get());
|
||||
// run passes
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
tune.run(module);
|
||||
tune.init(module);
|
||||
reassociate.run(module);
|
||||
peephole.run(module);
|
||||
if(target->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
shmem_liveness.run(module);
|
||||
shmem_allocation.run();
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
alignment_info.run(module);
|
||||
vectorize.run(module);
|
||||
dce.run(module);
|
||||
// generate llvm code
|
||||
llvm::LLVMContext ctx;
|
||||
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
|
||||
selection.run(module, *llvm);
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
function::function(const std::string &src): src_(src) {
|
||||
// src -> ast
|
||||
ast_ = make_ast(src_.c_str());
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
|
||||
/* determine if should re-tune or not */
|
||||
cache_key_t key;
|
||||
// re-tune if device is difference
|
||||
key.first = stream->context()->device();
|
||||
// re-tune if any int argument is different
|
||||
for(size_t i = 0; i < args.size(); i++){
|
||||
arg_type ty = args.at(i).type();
|
||||
if(is_int_type(ty)){
|
||||
long val = 0;
|
||||
std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
|
||||
key.second.push_back(val);
|
||||
}
|
||||
}
|
||||
|
||||
/* find existing configuration */
|
||||
auto it = cache_.find(key);
|
||||
if(it != cache_.end()){
|
||||
it->second.second(stream, grid_fn(it->second.first.params), args);
|
||||
return;
|
||||
}
|
||||
|
||||
/* re-tune and re-compile */
|
||||
options opt = autotune(ast_, stream, grid_fn, args);
|
||||
std::unique_ptr<ir::module> ir = make_ir(ast_);
|
||||
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
|
||||
ir::function* fn = ir->get_function_list().front();
|
||||
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
|
||||
run(stream, grid_fn(opt.params), args);
|
||||
}
|
||||
|
||||
void function::operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream *stream) {
|
||||
return this->operator()(args, [&grid](const params_t&){ return grid; }, stream);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user