[runtime] overhaul of the run-time API

Philippe Tillet
2019-08-14 15:43:50 -07:00
parent b8cd63e0da
commit 38a8b0ab19
13 changed files with 633 additions and 86 deletions


@@ -1,12 +1,12 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/dnn/dot.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "cuda.h"
template<class T>
@@ -19,20 +19,125 @@ void diff(const std::vector<T>& x, const std::vector<T>& y){
std::cout << "Pass!" << std::endl;
}
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,
size_t M, size_t N, size_t K){
for(size_t m = 0; m < M; m++)
for(size_t n = 0; n < N; n++){
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
c[m + n*M] = static_cast<T>(acc);
}
}
template<class T>
void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b) {
if(AT_ && BT_)
cpu_ref<T, true, true>(c, a, b, M, N, K);
else if(AT_ && !BT_)
cpu_ref<T, true, false>(c, a, b, M, N, K);
else if(!AT_ && BT_)
cpu_ref<T, false, true>(c, a, b, M, N, K);
else
cpu_ref<T, false, false>(c, a, b, M, N, K);
}
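The benchmarks below use this with a column-major convention; a minimal sketch with hypothetical sizes:
// Sketch (hypothetical sizes): c is M-by-N column-major (c[m + n*M]);
// with AT=false, BT=true this matches the "NT" configuration used in main().
std::vector<float> a(64*32), b(32*16), c(64*16);
cpu_ref(/*AT=*/false, /*BT=*/true, /*M=*/64, /*N=*/16, /*K=*/32, c, a, b);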
std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::string c_ty, int align_lda, int align_ldb) {
std::string ZS = "1";
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
if(AT){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
std::swap(XAS1, XAS2);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT){
std::swap(BS0, BS1);
std::swap(XBS1, XBS2);
std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
std::string XCS = "TM, TN";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string res =
R"(
const tunable int TM = {128};
const tunable int TN = {128};
const tunable int TK = {32};
void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
restrict read_only align(16) )" + b_ty + R"( *B,
restrict read_only align(16) )" + c_ty + R"( *C,
int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
float xc[)" + XCS + R"(] = 0;
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
}
)";
return res;
}
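For orientation, a hypothetical helper (not part of this commit) that dumps the generated Triton-C for that same NT case:
// Sketch: print the kernel source produced for AT=false, BT=true, half inputs.
void dump_src() {
  std::cout << src(/*AT=*/false, /*BT=*/true, "half", "half", "half",
                   /*align_lda=*/8, /*align_ldb=*/8);
}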
struct perf_t {
double triton;
double cublas;
};
namespace drv = triton::driver;
namespace rt = triton::runtime;
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef half NumericT;
std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context();
drv::context* context = stream->context();
std::vector<NumericT> hc(M*N);
std::vector<NumericT> ha(M*K);
std::vector<NumericT> hb(K*N);
int32_t lda = AT ? K : M;
int32_t ldb = BT ? N : K;
int32_t ldc = M;
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
@@ -40,54 +145,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
for(size_t i = 0; i < hc.size(); i++)
hc[i] = static_cast<NumericT>((double)0);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
drv::buffer* dc = drv::buffer::create(context, hc.size()*dt_nbytes);
drv::buffer* da = drv::buffer::create(context, ha.size()*dt_nbytes);
drv::buffer* db = drv::buffer::create(context, hb.size()*dt_nbytes);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;
// int32_t lda = AT ? K : M;
// int32_t ldb = BT ? N : K;
// int32_t ldc = M;
// cublasGemmAlgo_t fastest;
// cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
// &alpha, da, lda,
// db, ldb, &beta,
// dc, ldc, &fastest);
// double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
// &alpha, da, lda,
// db, ldb, &beta,
// dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
// result
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
// run
rt::function function(src(AT, BT, ty, ty, ty, 8, 8));
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [&](const rt::params_t& x) { return rt::grid_t{ceil(M, x.at("TM")), ceil(N, x.at("TN")), 1}; };
perf_t result;
// result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; };
perf_t res;
res.triton = tflops(triton::tools::bench([&]() { function({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);}, stream));
res.cublas = 0;
// test
stream->read(dc, true, 0, hc);
std::vector<NumericT> rc(hc.size());
dot.cpu_ref(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
// stream->synchronize();
// stream->read(dc, true, 0, hc);
// std::vector<NumericT> rc(hc.size());
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
// clean-up
delete dc;
delete da;
delete db;
return result;
return res;
}
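Note on the tflops lambda above: a GEMM performs 2·M·N·K floating-point operations, and flops per nanosecond equals GFLOP/s, so the 1e-3 factor converts the result to TFLOP/s; for the 8192³ configuration in main() that is 2·8192³ ≈ 1.1e12 flops per launch.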
int main() {
@@ -111,12 +202,11 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
{false, true, 128, 128, 128}
{false, true, 8192, 8192, 8192}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},
// {true, true, 128, 128, 128}
// {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}


@@ -38,16 +38,14 @@ private:
void create_grids(std::vector<ir::value*> &grids,
std::map<unsigned, ir::value*> &references,
ir::function *fn);
unsigned get_req_num_threads(ir::instruction *i);
public:
tune();
tune(size_t num_warps);
std::vector<ir::metaparameter *> get_params(ir::module& mod);
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
unsigned get_param_group(ir::value *value, unsigned ax);
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
void copy(ir::value *dst, ir::value *src);
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
void run(ir::module &mod);
@@ -64,7 +62,7 @@ private:
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
std::vector<ir::value*> grids_;
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
ir::metaparameter* num_warps_;
size_t num_warps_;
};


@@ -185,8 +185,8 @@ private:
public:
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *ax_info, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
selection(analysis::shmem::allocation *alloc, analysis::tune *params, analysis::shmem::info *buffer_info, analysis::alignment_info *alignment, target *tgt)
: alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), tgt_(tgt){ }
void run(ir::module &src, Module &dst);
@@ -197,7 +197,7 @@ private:
analysis::tune *params_;
target *tgt_;
analysis::shmem::info *buffer_info_;
analysis::alignment_info *axis_info_;
analysis::alignment_info *alignment_;
std::map<unsigned, distributed_axis> axes_;
Value *sh_mem_ptr_;
Value *offset_a_i_, *offset_a_k_;


@@ -66,16 +66,18 @@ public:
type *get_pointer_element_ty() const;
// primitive predicates
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
bool is_void_ty() const { return id_ == VoidTyID; }
bool is_half_ty() const { return id_ == HalfTyID; }
bool is_float_ty() const { return id_ == FloatTyID; }
bool is_double_ty() const { return id_ == DoubleTyID; }
bool is_label_ty() const { return id_ == LabelTyID;}
bool is_metadata_ty() const { return id_ == MetadataTyID; }
bool is_token_ty() const { return id_ == TokenTyID; }
bool is_integer_ty() const { return id_ == IntegerTyID; }
bool is_integer_ty(unsigned bitwidth) { return is_integer_ty() &&
get_integer_bitwidth() == bitwidth;}
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
// Composite predicates
bool is_int_or_tileint_ty();


@@ -0,0 +1,80 @@
#ifndef TDL_INCLUDE_ARG_H
#define TDL_INCLUDE_ARG_H
#include <string>
#include <cstdint>
#include <stdexcept>
namespace triton{
namespace driver{
class buffer;
}
namespace runtime {
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T: return 1;
case INT8_T: return 1;
case INT16_T: return 2;
case INT32_T: return 4;
case INT64_T: return 8;
case HALF_T: return 2;
case FLOAT_T: return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
inline bool is_int_type(arg_type ty){
return ty == INT1_T || ty == INT8_T || ty == INT16_T ||
ty == INT32_T || ty == INT64_T;
}
class arg {
private:
union value_t {
bool int1;
int8_t int8;
int16_t int16;
int32_t int32;
int64_t int64;
float fp32;
double fp64;
driver::buffer* buf;
};
public:
// construct from primitive types
arg(int32_t x): ty_(INT32_T) { val_.int32 = x; }
arg(int64_t x): ty_(INT64_T) { val_.int64 = x; }
arg(float x): ty_(FLOAT_T) { val_.fp32 = x; }
arg(double x): ty_(DOUBLE_T) { val_.fp64 = x; }
arg(driver::buffer* x): ty_(BUFFER_T) { val_.buf = x; }
// accessors
arg_type type() const { return ty_; }
void* data() const { return (void*)&val_; }
private:
arg_type ty_;
value_t val_;
};
}
}
#endif
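A minimal usage sketch for the new arg class (hypothetical values): the constructors above let a braced argument list convert implicitly, which is what function({da, db, dc, M, N, K, lda, ldb, ldc}, ...) relies on.
// Sketch: positional arguments are wrapped one-by-one.
using triton::runtime::arg;
std::vector<arg> args = {arg(int32_t(8192)), arg(1.0f), arg(2.0)};
// args[0].type() == INT32_T, args[1].type() == FLOAT_T, args[2].type() == DOUBLE_T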


@@ -0,0 +1,113 @@
#ifndef TDL_INCLUDE_FUNCTION_H
#define TDL_INCLUDE_FUNCTION_H
#include <unordered_map>
#include <map>
#include <array>
#include <vector>
#include <string>
#include <memory>
#include <functional>
#include "arg.h"
// codegen
#include "triton/codegen/selection/selection.h"
#include "triton/codegen/selection/target.h"
#include "triton/codegen/analysis/tune.h"
#include "triton/codegen/analysis/shmem/allocation.h"
#include "triton/codegen/analysis/shmem/liveness.h"
#include "triton/codegen/analysis/shmem/info.h"
#include "triton/codegen/analysis/alignment.h"
#include "triton/codegen/transform/dce.h"
#include "triton/codegen/transform/peephole.h"
#include "triton/codegen/transform/shmem/barriers.h"
#include "triton/codegen/transform/reassociate.h"
#include "triton/codegen/transform/vectorize.h"
namespace llvm {
class Module;
class LLVMContext;
}
namespace triton {
namespace driver{
class module;
class stream;
class kernel;
class context;
class device;
}
namespace lang{
class translation_unit;
}
namespace codegen{
namespace analysis{
class tune;
}
}
namespace ir {
class module;
class function;
class context;
class metaparameter;
}
namespace runtime{
typedef std::array<size_t, 3> grid_t;
typedef std::map<std::string, size_t> params_t;
struct options {
size_t num_warps;
params_t params;
};
class function {
public:
typedef std::function<grid_t(const params_t&)> grid_fn_ty;
private:
class caller {
public:
caller(ir::function *ir, std::shared_ptr<driver::module> program, size_t n_threads);
void operator()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const;
private:
std::shared_ptr<driver::module> parent_;
std::shared_ptr<driver::kernel> bin_;
std::vector<arg_type> param_tys_;
size_t n_threads_;
};
private:
typedef std::pair<driver::device*, std::vector<int64_t>> cache_key_t;
typedef std::pair<options, caller> cache_val_t;
private:
triton::lang::translation_unit *make_ast(const char *src);
std::unique_ptr<ir::module> make_ir(triton::lang::translation_unit *program);
options autotune(lang::translation_unit *ast, driver::stream *stream, const grid_fn_ty& grid, const std::vector<arg> &args);
std::unique_ptr<driver::module> make_bin(ir::module &function, driver::context *context, const options &opt);
public:
function(const std::string& src);
void operator()(const std::vector<arg>& args, const std::array<size_t, 3>& grid, driver::stream* stream);
void operator()(const std::vector<arg>& args, const grid_fn_ty& grid, driver::stream *stream);
private:
// execution context
ir::context ctx_;
// program representations
std::string src_;
lang::translation_unit *ast_;
std::map<cache_key_t, cache_val_t> cache_;
};
}
}
#endif
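The intended call pattern for this header, as exercised by the do_bench rewrite above (a sketch; kernel_src and the names M, N, da, db, dc, lda, ldb, ldc, stream are assumed from that context):
// Sketch: compile once, autotune on first launch, then hit the cache.
namespace rt = triton::runtime;
rt::function f(kernel_src);                 // kernel_src: Triton-C source string
auto ceil = [](size_t x, size_t y) { return (x + y - 1) / y; };
auto grid = [&](const rt::params_t& p) {    // grid shape depends on tuned TM/TN
  return rt::grid_t{ceil(M, p.at("TM")), ceil(N, p.at("TN")), 1};
};
f({da, db, dc, M, N, K, lda, ldb, ldc}, grid, stream);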


@@ -58,7 +58,8 @@ public:
struct passes_wrapper {
passes_wrapper(codegen::target* target)
: shmem_liveness(&shmem_info),
: tune(0),
shmem_liveness(&shmem_info),
shmem_allocation(&shmem_liveness, &shmem_info, &tune),
shmem_barriers(&shmem_allocation, &shmem_info),
vectorize(&tune),


@@ -3,6 +3,7 @@
#include <chrono>
#include <functional>
#include <algorithm>
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
@@ -29,7 +30,7 @@ private:
inline double bench(std::function<void()> const & op, driver::stream * stream)
{
const driver::device * device = stream->context()->device();
// const driver::device * device = stream->context()->device();
timer tmr;
std::vector<size_t> times;
double total_time = 0;


@@ -14,7 +14,7 @@ namespace triton{
namespace codegen{
namespace analysis{
tune::tune() {
tune::tune(size_t num_warps): num_warps_(num_warps){
}
bool is_hmma(ir::value *v){
@@ -183,20 +183,17 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
}
std::vector<ir::metaparameter *> tune::get_params(ir::module &mod) {
std::vector<ir::metaparameter*> result;
std::set<ir::metaparameter*> seen;
for(auto x: mod.globals()) {
if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
if(seen.insert(mp).second && !mp->has_value())
result.push_back(mp);
}
num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
result.push_back(num_warps_);
return result;
}
std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i) {
return params_.at(i);
throw std::runtime_error("remove me");
// std::vector<ir::metaparameter*> result;
// std::set<ir::metaparameter*> seen;
// for(auto x: mod.globals()) {
// if(auto mp = dynamic_cast<ir::metaparameter*>(x.second))
// if(seen.insert(mp).second && !mp->has_value())
// result.push_back(mp);
// }
// num_warps_ = ir::metaparameter::create(mod.get_context(), mod.get_builder().get_int32_ty(), 4, 4);
// result.push_back(num_warps_);
// return result;
}
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
@@ -257,7 +254,6 @@ void tune::init(ir::module &mod) {
}
int num_threads = get_num_threads();
int num_warps = num_warps_->get_value();
auto clamp = [&](int x, int lo, int hi) { return std::min(std::max(x, lo), hi); };
for(ir::value *i: grids_){
@@ -292,9 +288,9 @@ void tune::init(ir::module &mod) {
std::vector<int> wpt_nm1;
do{
wpt_nm1 = wpt;
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[0] = clamp(wpt[0]*2, 1, shape_0 / (fpw[0]*8));
if(wpt[0] * wpt[1] * wpt[2] < num_warps)
if(wpt[0] * wpt[1] * wpt[2] < num_warps_)
wpt[1] = clamp(wpt[1]*2, 1, shape_1 / (fpw[1]*8));
}while(wpt_nm1 != wpt);
// store parameters
@@ -307,7 +303,7 @@ void tune::init(ir::module &mod) {
std::string str_d = std::to_string(d);
effective_num_warps *= params_.at(i).at("wpt.d" + str_d)->get_value();
}
assert(num_warps == effective_num_warps);
assert(num_warps_ == effective_num_warps);
}
/* Scan-line */
@@ -386,7 +382,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
}
unsigned tune::get_num_threads() {
return num_warps_->get_value()*32;
return num_warps_*32;
}


@@ -1208,8 +1208,8 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun
// find vector size
distributed_tile* result = (distributed_tile*)tmap_.at(x);
ir::value *ptr = x->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);
@@ -1280,8 +1280,8 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB
distributed_tile* result = (distributed_tile*)tmap_.at(x);
// find vector size
ir::value *ptr = x->get_pointer_operand();
unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned starting_multiple = alignment_->get_starting_multiple(ptr);
unsigned max_contiguous = alignment_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr);


@@ -54,6 +54,7 @@ size_t buffer::size() {
return size_;
}
buffer* buffer::create(driver::context* ctx, size_t size) {
switch(ctx->backend()){
case CUDA: return new cu_buffer(ctx, size);

lib/runtime/arg.cpp (new file, 0 additions)

lib/runtime/function.cpp (new file, 265 additions)

@@ -0,0 +1,265 @@
#include <string>
#include <cstring>
#include <cmath>
#include <mutex>
#include <functional>
#include "triton/codegen/selection/selection.h"
#include "triton/runtime/function.h"
#include "triton/lang/lang.h"
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
#include "triton/driver/module.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/tools/bench.hpp"
#include "llvm/IR/Module.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
extern YY_BUFFER_STATE yy_scan_string(const char * str);
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
extern triton::lang::translation_unit *ast_root;
namespace triton{
namespace runtime {
// helpers
void _parallel_loop_nest(std::vector<size_t> const & ranges,
std::function<void(std::vector<size_t> const &)> const & f,
size_t nthreads){
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
template<class T>
void _parallel_loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f, size_t nthreads){
//Ranges to iterate over
std::vector<size_t> ranges;
for(auto const & x: iterates)
ranges.push_back(x.size());
//Proxy function
auto proxy = [&](std::vector<size_t> const & idx){
std::vector<T> x(iterates.size());
for(size_t i = 0; i < x.size(); ++i)
x[i] = iterates[i][idx[i]];
f(x);
};
//Iterate
_parallel_loop_nest(ranges, proxy, nthreads);
}
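The nest walks the cartesian product in row-major order; for example (hypothetical values):
// iterates = {{16, 32}, {1, 2}} invokes f on (16,1), (16,2), (32,1), (32,2).
_parallel_loop_nest<unsigned>({{16u, 32u}, {1u, 2u}},
                              [](std::vector<unsigned>){ /* benchmark one point */ }, 1);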
// caller
arg_type convert(ir::type *ty) {
if(ty->is_integer_ty(1))
return INT1_T;
if(ty->is_integer_ty(8))
return INT8_T;
if(ty->is_integer_ty(16))
return INT16_T;
if(ty->is_integer_ty(32))
return INT32_T;
if(ty->is_integer_ty(64))
return INT64_T;
if(ty->is_half_ty())
return HALF_T;
if(ty->is_float_ty())
return FLOAT_T;
if(ty->is_double_ty())
return DOUBLE_T;
if(ty->is_pointer_ty())
return BUFFER_T;
throw std::runtime_error("unknown type");
}
function::caller::caller(ir::function *ir, std::shared_ptr<driver::module> parent, size_t n_threads)
: bin_(driver::kernel::create(&*parent, ir->get_name().c_str())), n_threads_(n_threads), parent_(parent) {
// extract signature
ir::function_type* ty = ir->get_fn_type();
for(int i = 0; i < ty->get_num_params(); i++)
param_tys_.push_back(convert(ty->get_param_ty(i)));
}
void function::caller::operator ()(driver::stream *stream, const std::array<size_t, 3>& grid, const std::vector<arg>& args) const {
if(args.size() != param_tys_.size())
throw std::runtime_error("invalid number of arguments");
for(size_t i = 0; i < args.size(); i++){
arg arg_i = args.at(i);
arg_type ty = arg_i.type();
if(ty != param_tys_.at(i))
throw std::runtime_error("invalid type");
if(ty == BUFFER_T)
bin_->setArg(i, *((driver::buffer**)arg_i.data()));
else
bin_->setArg(i, size_of(ty), arg_i.data());
}
stream->enqueue(&*bin_, grid, {n_threads_, 1, 1});
}
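Note: the caller binds arguments positionally (buffers by pointer via setArg(i, buffer), scalars by value via setArg(i, size, ptr)), checks each argument's type tag against the Triton-C signature, and launches blocks of n_threads_ = num_warps × 32 threads, so a mismatched argument list throws before anything is enqueued.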
// module
triton::lang::translation_unit *function::make_ast(const char *src) {
YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse();
yy_delete_buffer(buffer);
triton::lang::translation_unit *program = ast_root;
return program;
}
std::unique_ptr<ir::module> function::make_ir(triton::lang::translation_unit *program) {
// create Triton-IR from AST
ir::module* module = new ir::module("", ctx_);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
options function::autotune(lang::translation_unit *ast, driver::stream* stream, const grid_fn_ty& grid_fn, const std::vector<arg>& args) {
std::unique_ptr<ir::module> ir = make_ir(ast);
// extract tunable values
std::vector<std::pair<std::string, ir::metaparameter*>> values;
for(auto it: ir->globals())
if(auto *mp = dynamic_cast<ir::metaparameter*>(it.second))
values.push_back({it.first, mp});
// extract search space
std::vector<std::vector<unsigned>> space;
space.push_back({1, 2, 4, 8}); // num warps
for(auto it: values)
space.push_back(it.second->get_space());
// exhaustive search
struct profile_t{
double ts;
std::vector<unsigned> params;
};
profile_t best = { INFINITY };
std::function<void(std::vector<unsigned>)> benchmark =
[&](std::vector<unsigned> params) {
// options
options opt;
unsigned i = 0;
opt.num_warps = params[i++];
for(auto it: values)
opt.params[it.first] = params[i++];
// make binary
auto ir = make_ir(ast);
auto bin = make_bin(*ir, stream->context(), opt);
// benchmark
ir::function *tmp = ir->get_function_list()[0];
caller fn(tmp, std::move(bin), opt.num_warps * 32);
double ts = tools::bench([&]() { fn(stream, grid_fn(opt.params), args); }, stream);
if(ts < best.ts)
best = {ts, params};
};
_parallel_loop_nest<unsigned>(space, benchmark, 1);
// populate options
unsigned current = 0;
options opt;
opt.num_warps = best.params[current++];
for(auto it: values)
opt.params[it.first] = best.params[current++];
return opt;
}
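With the kernel above declaring single-point spaces for TM, TN and TK ({128}, {128}, {32}), the search space is 4 × 1 × 1 × 1 = 4 points, one per num_warps candidate, so autotuning costs four compile-and-benchmark cycles; multi-valued tunable ranges grow this multiplicatively.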
std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::context *context, const options& opt) {
std::unique_ptr<codegen::target> target = context->device()->make_target();
// update metaparameter values
for(auto x: opt.params)
if(auto* mp = dynamic_cast<ir::metaparameter*>(module.globals().at(x.first)))
mp->set_value(x.second);
// create passes
codegen::analysis::tune tune(opt.num_warps);
codegen::analysis::shmem::info shmem_info;
codegen::analysis::shmem::liveness shmem_liveness(&shmem_info);
codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &tune);
codegen::analysis::alignment_info alignment_info;
codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info);
codegen::transform::vectorize vectorize(&tune);
codegen::transform::dce dce;
codegen::transform::peephole peephole;
codegen::transform::reassociate reassociate(&tune);
codegen::selection selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target.get());
// run passes
peephole.run(module);
dce.run(module);
tune.run(module);
tune.init(module);
reassociate.run(module);
peephole.run(module);
if(target->is_gpu()){
shmem_info.run(module);
shmem_liveness.run(module);
shmem_allocation.run();
shmem_barriers.run(module);
}
alignment_info.run(module);
vectorize.run(module);
dce.run(module);
// generate llvm code
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
selection.run(module, *llvm);
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, llvm.get()));
return res;
}
function::function(const std::string &src): src_(src) {
// src -> ast
ast_ = make_ast(src_.c_str());
}
void function::operator()(const std::vector<arg>& args, const grid_fn_ty& grid_fn, driver::stream *stream) {
/* determine if should re-tune or not */
cache_key_t key;
// re-tune if the device is different
key.first = stream->context()->device();
// re-tune if any int argument is different
for(size_t i = 0; i < args.size(); i++){
arg_type ty = args.at(i).type();
if(is_int_type(ty)){
long val = 0;
std::memcpy((void*)&val, args.at(i).data(), size_of(ty));
key.second.push_back(val);
}
}
/* find existing configuration */
auto it = cache_.find(key);
if(it != cache_.end()){
it->second.second(stream, grid_fn(it->second.first.params), args);
return;
}
/* re-tune and re-compile */
options opt = autotune(ast_, stream, grid_fn, args);
std::unique_ptr<ir::module> ir = make_ir(ast_);
std::unique_ptr<driver::module> bin = make_bin(*ir, stream->context(), opt);
ir::function* fn = ir->get_function_list().front();
const caller& run = cache_.insert({key, cache_val_t{opt, caller(fn, std::move(bin), opt.num_warps*32)}}).first->second.second;
run(stream, grid_fn(opt.params), args);
}
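Cache behavior sketch (hypothetical calls, reusing f, grid and stream from the function.h sketch above): the key is (device, integer arguments), so repeated launches are cheap and shape changes re-tune.
f({da, db, dc, 8192, 8192, 8192, lda, ldb, ldc}, grid, stream); // miss: autotune + compile
f({da, db, dc, 8192, 8192, 8192, lda, ldb, ldc}, grid, stream); // hit: launch immediately
f({da, db, dc, 4096, 8192, 8192, lda, ldb, ldc}, grid, stream); // int arg changed: re-tune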
void function::operator()(const std::vector<arg>& args, const grid_t& grid, driver::stream *stream) {
return this->operator()(args, [&grid](const params_t&){ return grid; }, stream);
}
}
}