[triton/python/conv]: Added cache for compiled kernels

Philippe Tillet
2019-05-18 11:51:49 -04:00
parent 600aef72d5
commit b2b55c52c9
10 changed files with 210 additions and 516 deletions

View File

@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,11 +10,11 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::jit jit(context);
-  triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
+  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
-  int32_t B = 32, NF = 128;
+  int32_t B = 4, NF = 32;
   int32_t D = 1, H = 56, W = 56;
-  int32_t NC = 128, T = 1, R = 3, S = 3;
+  int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 1, pad_w = 1;
   triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
   // convolution configuration
@@ -45,7 +45,7 @@ int main() {
     unsigned TN = info.global_range_size[1];
     unsigned nthreads = info.num_threads;
     std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-    configuration.init(stream, jit);
+    configuration.init(stream, (triton::driver::cu_module*)kernel->module());
     stream->synchronize();
     configuration.set_arg(kernel, da, db, dc);
     stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -55,7 +55,7 @@ int main() {
     return configuration.get_nflops() / ts * 1e-3;
   };
   std::string src = configuration.src();
-  jit.autotune("conv", src.c_str(), benchmark);
+  // jit.autotune("conv", src.c_str(), benchmark);
   jit.add_module("conv", src.c_str(), configuration.default_params());
   triton::driver::kernel* kernel = jit.get_function("conv");
   triton::jit::launch_information info = jit.get_launch_info("conv");

View File

@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/gemm.h"

View File

@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
@@ -158,8 +158,8 @@ int main() {
   unsigned TN = info.global_range_size[1];
   unsigned nthreads = info.num_threads;
   // initialize constant memory
-  triton::driver::buffer* delta = jit.get_buffer("delta");
-  triton::driver::buffer* masks = jit.get_buffer("masks");
+  triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta");
+  triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks");
   stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
   stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
   stream->synchronize();
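
Note the pattern here: jit.get_buffer(name) is replaced by a direct query on the compiled module, since a __constant__ array declared in the generated source lives in one specific module. On the CUDA driver API, a by-name lookup like cu_module::symbol is presumably backed by cuModuleGetGlobal; a minimal standalone sketch of that lookup plus the host-to-device copy (hypothetical helper, not Triton's actual implementation):

#include <cuda.h>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Resolve a module-scope __constant__/__device__ symbol by name and fill it
// from host memory -- roughly what module->symbol(...) followed by
// stream->write(...) accomplish together in the hunk above.
void write_symbol(CUmodule mod, const char* name, const int32_t* host, size_t n) {
  CUdeviceptr dptr;
  size_t bytes;
  // Resolves "name" to the device address of the symbol inside the loaded module.
  if(cuModuleGetGlobal(&dptr, &bytes, mod, name) != CUDA_SUCCESS) {
    fprintf(stderr, "symbol %s not found in module\n", name);
    return;
  }
  cuMemcpyHtoD(dptr, host, n * sizeof(int32_t)); // synchronous copy for simplicity
}
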

View File

@@ -2,7 +2,7 @@
 #include <torch/script.h>
 #include "ATen/cuda/CUDAContext.h"
 #include <vector>
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,6 +10,16 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   triton::dnn::conv::type> conv_key_t;
+static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
+static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
+static std::map<conv_key_t, std::unique_ptr<triton::dnn::conv>> m_config;
 torch::Tensor conv_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t NF,
@@ -18,41 +28,59 @@ torch::Tensor conv_common(
     triton::dnn::conv::type ty,
     torch::Tensor torcha, torch::Tensor torchb
     ) {
-  // Configuration
-  triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
+  // Wrap CUDA handles
+  c10::DeviceIndex device = torcha.storage().device().index();
+  // Get stream
+  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
+  triton::driver::stream* stream;
+  if(m_stream.find(custream) == m_stream.end())
+    stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
+  else
+    stream = m_stream.at(custream).get();
+  // Get context
+  triton::driver::context* ctx = stream->context();
+  // Get configuration
+  conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
+  triton::dnn::conv* configuration;
+  if(m_config.find(key) == m_config.end())
+    configuration = m_config.emplace(key, new triton::dnn::conv(
+                                            B, C, D, H, W, T, R, S, NF,
+                                            stride_d, stride_h, stride_w,
+                                            pad_d, pad_h, pad_w, ty)).first->second.get();
+  else
+    configuration = m_config.at(key).get();
+  // Get JIT
+  triton::jit* jit;
+  if(m_jit.find(key) == m_jit.end()){
+    jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
+    std::string src = configuration->src();
+    jit->add_module("conv", src.c_str(), configuration->default_params());
+  }
+  else
+    jit = m_jit.at(key).get();
+  // Get memory
+  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   // Allocate output
-  std::vector<int32_t> c_shapes = configuration.c_shapes();
+  std::vector<int32_t> c_shapes = configuration->c_shapes();
   torch::Tensor torchc;
   if(ty == triton::dnn::conv::WGRAD)
     torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
   else
     torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
-  // Wrap CUDA handles
-  c10::DeviceIndex device = torchc.storage().device().index();
-  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
-  triton::driver::stream* stream = &sstream;
-  triton::driver::context* ctx = stream->context();
-  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
-  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   stream->synchronize();
-  // Create JIT
-  triton::jit jit(ctx);
-  std::string src = configuration.src();
-  jit.add_module("conv", src.c_str(), configuration.default_params());
-  triton::driver::kernel* kernel = jit.get_function("conv");
-  triton::jit::launch_information info = jit.get_launch_info("conv");
+  // Add module to JIT
+  triton::driver::kernel* kernel = jit->get_function("conv");
+  triton::jit::launch_information info = jit->get_launch_info("conv");
   // launch info
   unsigned TM = info.global_range_size[0];
   unsigned TN = info.global_range_size[1];
-  // launch info
-  configuration.init(stream, jit);
+  configuration->init(stream, (triton::driver::cu_module*)kernel->module());
   unsigned nthreads = info.num_threads;
-  std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-  configuration.set_arg(kernel, &a, &b, &c);
-  stream->synchronize();
+  std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
+  configuration->set_arg(kernel, &a, &b, &c);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
   stream->synchronize();
   return torchc;
 }
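
The three caches above share one lookup-or-create shape: probe a std::map of unique_ptr, build the entry on a miss (the expensive step -- for m_jit this is where the kernel source is generated and compiled, exactly once per configuration), and reuse it on every later call. A hypothetical helper that factors out the pattern (not part of the commit):

#include <map>
#include <memory>

// Build-once, reuse-forever: the memoization pattern behind m_stream,
// m_config and m_jit above. 'make' runs only on a cache miss.
template<class K, class V, class F>
V* get_or_create(std::map<K, std::unique_ptr<V>>& cache, const K& key, F&& make) {
  auto it = cache.find(key);
  if(it == cache.end())
    it = cache.emplace(key, std::unique_ptr<V>(make())).first; // miss: build and remember
  return it->second.get();                                     // hit: no recompilation
}

With such a helper, the JIT block in conv_common would shrink to jit = get_or_create(m_jit, key, [&]{ ... }); with the one-time src()/add_module() calls folded into the factory lambda.
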

View File

@@ -1,4 +1,5 @@
 import torch
+import time
 torch.manual_seed(0)
 class TritonConv(torch.autograd.Function):
@@ -14,9 +15,9 @@ class TritonConv(torch.autograd.Function):
     input, weight = ctx.saved_tensors
     grad_input = grad_weight = None
     if ctx.needs_input_grad[0]:
-      grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
+      grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
     if ctx.needs_input_grad[1]:
-      grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
+      grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
     return grad_input, grad_weight
@@ -38,6 +39,7 @@ x.grad.zero_()
 w.grad.zero_()
 culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
 print((tty - cuy).norm(2))
+print((ttdx - cudx).norm(2))
 print((ttdw.permute(3,0,1,2) - cudw).norm(2))

View File

@@ -4,7 +4,6 @@
 #include <numeric>
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/jit.h"
 namespace triton{
 namespace dnn{
@@ -34,7 +33,7 @@ public:
   // initialize
   void build_deltas();
   void build_masks();
-  void init(driver::stream *stream, triton::jit &jit);
+  void init(driver::stream *stream, driver::cu_module *module);
   std::array<size_t, 3> get_grid(size_t TM, size_t TN);
   void set_arg(driver::kernel *kernel,
                driver::buffer *a, driver::buffer *b, driver::buffer *c);
@@ -44,7 +43,144 @@ public:
   std::vector<unsigned> default_params();
   // source
-  std::string src();
+  std::string src(){
+    bool is_wgrad = ty_ == WGRAD;
+    std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
+    std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
+    std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
+    std::string ldb0 = b_trans_ ? "*ldb_s" : "";
+    std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
+    std::string useb = b_trans_ ? "trans(b)" : "b";
+    std::string flipr = b_trans_ ? "" : "BH - 1 -";
+    std::string flips = b_trans_ ? "" : "BW - 1 -";
+    std::string ax = b_trans_ ? "crs" : "rsc";
+    std::vector<std::string> redax;
+    if(b_trans_)
+      redax = {"C", "BH", "BW"};
+    else
+      redax = {"BH", "BW", "N"};
+    std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
+    std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
+    std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
+    std::string masks_mem = is_mask_cst_? "__constant__" : "";
+    std::string res =
+    R"(
+const tunable int32 TM = {16, 32, 64};
+const tunable int32 TN = {16, 32, 64};
+const tunable int32 TK = {8};
+)";
+    if(is_a_deltas_cst)
+      res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
+    if(is_wgrad && is_b_deltas_cst_)
+      res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
+    if(is_mask_cst_)
+      res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
+    res += R"(
+void conv(read_only restrict fp32 *a,
+read_only restrict fp32 *b,
+fp32 *c,
+int32 M, int32 N, int32 K,
+int32 AH, int32 AW,
+int32 BH, int32 BW,
+int32 CH, int32 CW,
+int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
+int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
+int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
+int32 pad_h, int32 pad_w)";
+    if(!is_a_deltas_cst)
+      res += ", int32* delta";
+    if(is_wgrad && !is_b_deltas_cst_)
+      res += ", int32* b_delta";
+    if(!is_mask_cst_)
+      res += ", int32* masks";
+    res += R"(){
+int32 rxa[TM] = get_global_range[TM](0);
+int32 rb0[TN] = get_global_range[TN](1);
+int32 rka[TK] = 0 ... TK;
+int32 rkb[TK] = 0 ... TK;
+fp32 C[TM, TN] = 0;
+int32 ldlut = )" + std::to_string(Fs_) + R"(;
+int32 rabh[TM] = rxa / CW;
+int32 raw[TM] = rxa % CW - pad_w;
+int32 rab[TM] = rabh / CH;
+int32 rah[TM] = rabh % CH - pad_h;
+int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
+int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
+int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
+int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+rar = )" + flipr + R"( rar;
+ras = )" + flips + R"( ras;
+int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
+fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+    if(ty_ == WGRAD){
+      res += R"(
+int32 rbcr[TK] = rkb / BW;
+int32 rbs[TK] = rkb % BW;
+int32 rbc[TK] = rbcr / BH;
+int32 rbr[TK] = rbcr % BH;
+int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
+)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
+int32 db[TK] = *pdb;)";
+    }
+    else{
+      res += R"(
+int32 rb1[TK] = rkb;)";
+    }
+    res += R"(
+fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
+)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
+)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
+int32 d[TK] = *pd;
+int32 incd[TK] = *pincd;
+int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
+int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
+)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
+)" + a_delta_mem + R"( int32* pincm[TM] = delta;
+int32 incm[TM] = *pincm;
+int32 checka0[TM] = *pm;
+int32 checka1[TK] = 1 << rka;
+int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+fp32 a[TM, TK] = checka ? *pa : 0;
+fp32 b)" + BS + R"( = *pb;
+for(int32 k = K; k > 0; k = k - TK){
+C = dot(a, )" + useb + R"(, C);
+pa = pa + d[newaxis, :];
+pb = pb + )" + inc_pb + R"(;
+b = *pb;
+pd = pd + incd;)";
+    if(ty_ == WGRAD){
+      res += R"(
+pdb = pdb + TK;
+db = *pdb;)";
+    }
+    res += R"(
+pincd = pincd + incd;
+d = *pd;
+incd = *pincd;
+pm = pm + incm;
+pincm = pincm + incm;
+incm = *pincm;
+checka0 = *pm;
+checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+checka = checka && (k > TK);
+a = checka ? *pa : 0;
+}
+int32 rxc[TM] = get_global_range[TM](0);
+int32 rc1[TN] = get_global_range[TN](1);
+int32 rcn[TM] = rxc / (CH*CW);
+int32 rcpq[TM] = rxc % (CH*CW);
+int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
+fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+int1 checkc0[TM] = rxc < M;
+int1 checkc1[TN] = rc1 < N;
+int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+@checkc *pc = C;
+})";
+    return res;
+  }
   // cpu check
   template<class IN_DTYPE, class OUT_DTYPE>

View File

@@ -1,117 +0,0 @@
-#ifndef TDL_INCLUDE_JIT_H
-#define TDL_INCLUDE_JIT_H
-#include <string>
-#include <memory>
-#include "llvm/IR/LLVMContext.h"
-#include "triton/ir/context.h"
-#include "triton/ir/print.h"
-#include "triton/driver/module.h"
-#include "triton/driver/kernel.h"
-#include "triton/codegen/selection.h"
-#include "triton/codegen/tune.h"
-#include "triton/codegen/optimize_dot.h"
-#include "triton/codegen/optimize_cse.h"
-#include "triton/codegen/optimize_trans.h"
-#include "triton/codegen/shmem_allocation.h"
-#include "triton/codegen/shmem_liveness.h"
-#include "triton/codegen/shmem_info.h"
-#include "triton/codegen/shmem_barriers.h"
-#include "triton/codegen/target.h"
-#include "triton/codegen/vectorize.h"
-#include <functional>
-namespace llvm {
-  class Module;
-}
-namespace triton {
-namespace codegen{
-  class tune;
-}
-namespace ir {
-  class module;
-  class context;
-  class metaparameter;
-}
-class jit {
-public:
-  struct launch_information{
-    std::vector<unsigned> global_range_size;
-    unsigned num_threads;
-  };
-  typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
-  struct passes_wrapper {
-    passes_wrapper(codegen::target* target)
-      : shmem_liveness(&shmem_info),
-        shmem_allocation(&shmem_liveness, &shmem_info),
-        shmem_barriers(&shmem_allocation, &shmem_info),
-        vectorize(&tune),
-        selection(&shmem_allocation, &tune, &shmem_info, target),
-        optimize_dot(&tune),
-        optimize_cse(),
-        optimize_trans(),
-        target_(target) { }
-    void target_independent(ir::module &module) {
-      optimize_dot.run(module);
-      optimize_trans.run(module);
-    }
-    void target_dependent(ir::module &module) {
-      if(target_->is_gpu()){
-        shmem_info.run(module);
-        shmem_liveness.run(module);
-        shmem_allocation.run();
-        shmem_barriers.run(module);
-      }
-      vectorize.run(module);
-    }
-    codegen::tune tune;
-    codegen::shmem_info shmem_info;
-    codegen::shmem_liveness shmem_liveness;
-    codegen::shmem_allocation shmem_allocation;
-    codegen::shmem_barriers shmem_barriers;
-    codegen::vectorize vectorize;
-    codegen::selection selection;
-    codegen::optimize_dot optimize_dot;
-    codegen::optimize_cse optimize_cse;
-    codegen::optimize_trans optimize_trans;
-    codegen::target* target_;
-  };
-private:
-  std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
-  std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
-  std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
-public:
-  jit(driver::context* context);
-  ~jit();
-  void autotune(const char* name, const char* src, benchmark_t benchmark);
-  void add_module(ir::module &module, const std::vector<unsigned>& params = {});
-  void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
-  driver::kernel* get_function(const char* name);
-  launch_information get_launch_info(const char* name);
-  unsigned get_int(const char* name);
-  driver::buffer* get_buffer(const char* name);
-private:
-  std::vector<driver::module*> modules_;
-  driver::context* driver_context_;
-  llvm::LLVMContext llvm_context_;
-  ir::context triton_context_;
-  std::map<std::string, launch_information> launch_info_map_;
-  std::map<std::string, unsigned> global_ints_;
-  std::unique_ptr<triton::codegen::target> target_;
-};
-}
-#endif

View File

@@ -207,7 +207,7 @@ std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
 size_t conv::get_nflops()
 { return 2.*M_*N_*K_; }
-void conv::init(driver::stream *stream, triton::jit &jit) {
+void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
   auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
     if(host.empty())
       return nullptr;
@@ -215,7 +215,7 @@ void conv::init(driver::stream *stream, triton::jit &jit) {
     // get buffer
     triton::driver::buffer* buffer;
     if(is_cst)
-      buffer = jit.get_buffer(name);
+      buffer = module->symbol(name);
     else
       buffer = triton::driver::buffer::create(stream->context(), nbytes);
     // copy
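
The hunk ends at the copy step; judging from the shift example earlier in this commit, init_lut presumably finishes by streaming the host table into whichever buffer was selected. A sketch of that tail (an assumption, not the verbatim code):

// plausible tail of init_lut: copy the host LUT into the chosen buffer
// (named __constant__ symbol or freshly allocated device memory), then
// hand the buffer back to the caller
stream->write(buffer, false, 0, nbytes, host.data()); // 'false' requests a non-blocking write
return buffer;
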
@@ -306,145 +306,6 @@ std::vector<unsigned> conv::default_params() {
 }
-std::string conv::src() {
-  bool is_wgrad = ty_ == WGRAD;
-  std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
-  std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
-  std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
-  std::string ldb0 = b_trans_ ? "*ldb_s" : "";
-  std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
-  std::string useb = b_trans_ ? "trans(b)" : "b";
-  std::string flipr = b_trans_ ? "" : "BH - 1 -";
-  std::string flips = b_trans_ ? "" : "BW - 1 -";
-  std::string ax = b_trans_ ? "crs" : "rsc";
-  std::vector<std::string> redax;
-  if(b_trans_)
-    redax = {"C", "BH", "BW"};
-  else
-    redax = {"BH", "BW", "N"};
-  std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
-  std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
-  std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
-  std::string masks_mem = is_mask_cst_? "__constant__" : "";
-  std::string res =
-  R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {8};
-)";
-  if(is_a_deltas_cst)
-    res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
-  if(is_wgrad && is_b_deltas_cst_)
-    res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
-  if(is_mask_cst_)
-    res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
-  res += R"(
-void conv(read_only restrict fp32 *a,
-read_only restrict fp32 *b,
-fp32 *c,
-int32 M, int32 N, int32 K,
-int32 AH, int32 AW,
-int32 BH, int32 BW,
-int32 CH, int32 CW,
-int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
-int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
-int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-int32 pad_h, int32 pad_w)";
-  if(!is_a_deltas_cst)
-    res += ", int32* delta";
-  if(is_wgrad && !is_b_deltas_cst_)
-    res += ", int32* b_delta";
-  if(!is_mask_cst_)
-    res += ", int32* masks";
-  res += R"(){
-int32 rxa[TM] = get_global_range[TM](0);
-int32 rb0[TN] = get_global_range[TN](1);
-int32 rka[TK] = 0 ... TK;
-int32 rkb[TK] = 0 ... TK;
-fp32 C[TM, TN] = 0;
-int32 ldlut = )" + std::to_string(Fs_) + R"(;
-int32 rabh[TM] = rxa / CW;
-int32 raw[TM] = rxa % CW - pad_w;
-int32 rab[TM] = rabh / CH;
-int32 rah[TM] = rabh % CH - pad_h;
-int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
-int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
-int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
-int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
-rar = )" + flipr + R"( rar;
-ras = )" + flips + R"( ras;
-int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
-  if(ty_ == WGRAD){
-    res += R"(
-int32 rbcr[TK] = rkb / BW;
-int32 rbs[TK] = rkb % BW;
-int32 rbc[TK] = rbcr / BH;
-int32 rbr[TK] = rbcr % BH;
-int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
-)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
-int32 db[TK] = *pdb;)";
-  }
-  else{
-    res += R"(
-int32 rb1[TK] = rkb;)";
-  }
-  res += R"(
-fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
-)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
-)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
-int32 d[TK] = *pd;
-int32 incd[TK] = *pincd;
-int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
-int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
-)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
-)" + a_delta_mem + R"( int32* pincm[TM] = delta;
-int32 incm[TM] = *pincm;
-int32 checka0[TM] = *pm;
-int32 checka1[TK] = 1 << rka;
-int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-fp32 a[TM, TK] = checka ? *pa : 0;
-fp32 b)" + BS + R"( = *pb;
-for(int32 k = K; k > 0; k = k - TK){
-C = dot(a, )" + useb + R"(, C);
-pa = pa + d[newaxis, :];
-pb = pb + )" + inc_pb + R"(;
-b = *pb;
-pd = pd + incd;)";
-  if(ty_ == WGRAD){
-    res += R"(
-pdb = pdb + TK;
-db = *pdb;)";
-  }
-  res += R"(
-pincd = pincd + incd;
-d = *pd;
-incd = *pincd;
-pm = pm + incm;
-pincm = pincm + incm;
-incm = *pincm;
-checka0 = *pm;
-checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-checka = checka && (k > TK);
-a = checka ? *pa : 0;
-}
-int32 rxc[TM] = get_global_range[TM](0);
-int32 rc1[TN] = get_global_range[TN](1);
-int32 rcn[TM] = rxc / (CH*CW);
-int32 rcpq[TM] = rxc % (CH*CW);
-int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
-fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-int1 checkc0[TM] = rxc < M;
-int1 checkc1[TN] = rc1 < N;
-int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-@checkc *pc = C;
-})";
-  return res;
-}
 template<class IN_DTYPE, class OUT_DTYPE>
 void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
 {

View File

@@ -100,10 +100,10 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
 }
 void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
-                          const std::string &proc, std::string layout,
-                          llvm::SmallVectorImpl<char> &buffer,
-                          const std::string& features,
-                          file_type_t ft) {
+                                 const std::string &proc, std::string layout,
+                                 llvm::SmallVectorImpl<char> &buffer,
+                                 const std::string& features,
+                                 file_type_t ft) {
   init_llvm();
   // debug
   // llvm::legacy::PassManager pm;

View File

@@ -1,216 +0,0 @@
#include "triton/jit.h"
#include <string>
#include "triton/ast/ast.h"
#include "triton/codegen/target.h"
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/driver/device.h"
#include "triton/driver/error.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Analysis/LoopPass.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
extern YY_BUFFER_STATE yy_scan_string(const char * str);
extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
using triton::ast::translation_unit;
extern translation_unit *ast_root;
namespace triton {
void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// Start with innermost loop
size_t i = D - 1;
while(true){
//Execute function
f(values);
//Increment counters
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
}
}
template<class T>
void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
//Ranges to iterate over
std::vector<size_t> ranges;
for(auto const & x: iterates)
ranges.push_back(x.size());
//Proxy function
auto proxy = [&](std::vector<size_t> const & idx){
std::vector<T> x(iterates.size());
for(size_t i = 0; i < x.size(); ++i)
x[i] = iterates[i][idx[i]];
f(x);
};
//Iterate
loop_nest(ranges, proxy);
}
std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
passes.selection.run(module, *result);
// launch information
launch_information& info = launch_info_map_[result->getName()];
info.global_range_size.clear();
for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
info.global_range_size.push_back(passes.tune.get_global_range_size(i));
info.num_threads = passes.tune.get_num_threads();
return std::unique_ptr<llvm::Module>(result);
}
std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
// create AST from Triton-C source
YY_BUFFER_STATE buffer = yy_scan_string(src);
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
// create Triton-IR from AST
ir::module* module = new ir::module(name, triton_context_);
program->codegen(module);
return std::unique_ptr<ir::module>(module);
}
jit::jit(driver::context *context): driver_context_(context),
target_(context->device()->make_target()) { }
jit::~jit(){ }
void jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
// find metaparameters
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
auto mps = passes.tune.get_params(tt_module);
// create parameter ranges
std::vector<std::vector<unsigned>> ranges;
for(ir::metaparameter *mp: mps)
ranges.push_back(mp->get_space());
// std::cout << ranges.size() << std::endl;
// iterate over parameters
unsigned i;
double best = 0;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;
i = 0;
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
passes.target_independent(tt_module);
passes.tune.init(tt_module);
if(!passes.tune.check_constraints(errors))
return;
// Deep copy of the module and tuner
auto ptt_module = make_triton_module(name, src);
ir::module &tt_module = *ptt_module;
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
mp->set_value(params[i++]);
}
passes.tune.init(tt_module);
passes.target_dependent(tt_module);
driver::device* device = driver_context_->device();
if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
return;
if(passes.tune.get_num_threads() > device->max_threads_per_block())
return;
// Compile
auto ll_module = make_llvm_module(tt_module, passes);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
launch_information info = launch_info_map_.at(name);
for(unsigned p: params)
std::cout << p << " " << std::flush;
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
modules_.push_back(module.get());
double perf;
perf = benchmark(kernel.get(), info);
best = std::max(perf, best);
std::cout << perf << " [ " << best << " ] " << std::endl;
modules_.pop_back();
});
}
void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
// set parameters
passes_wrapper passes(target_.get());
passes.target_independent(tt_module);
passes.tune.run(tt_module);
unsigned i = 0;
for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
mp->set_value(params[i++]);
passes.tune.init(tt_module);
passes.target_dependent(tt_module);
// check constraints
std::map<ir::value*, std::vector<std::string>> errors;
passes.tune.check_constraints(errors);
for(auto x: errors){
std::cout << x.first << std::endl;
for(auto str: x.second)
std::cout << str << std::endl;
}
if(errors.size())
throw std::runtime_error("invalid parameters");
// driver::device* device = driver_context_->device();
// if(passes.allocation.get_allocated_size() > device->max_shared_memory())
// throw std::runtime_error("invalid parameters");
// triton module -> llvm module
auto ll_module = make_llvm_module(tt_module, passes);
// llvm module -> machine code
modules_.push_back(driver::module::create(driver_context_, &*ll_module));
// add globals
for(auto x: tt_module.globals())
global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
}
void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
auto ptt_module = make_triton_module(name, src);
add_module(*ptt_module, params);
}
driver::kernel *jit::get_function(const char *name) {
return driver::kernel::create(modules_.front(), name);
}
jit::launch_information jit::get_launch_info(const char *name) {
return launch_info_map_.at(name);
}
unsigned jit::get_int(const char *name){
return global_ints_.at(name);
}
driver::buffer *jit::get_buffer(const char *name){
driver::cu_module *mod = (driver::cu_module*)modules_.front();
return mod->symbol(name);
}
}
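
The deleted loop_nest helpers above (the JIT implementation moved under triton/runtime along with the header) enumerate the cartesian product of candidate values; autotune feeds them the space of each ir::metaparameter and benchmarks every combination. A standalone usage sketch, assuming the helper were exposed in a header (in the file above it is internal to jit.cpp), with ranges matching the tunable TM/TN/TK declarations in the conv source:

#include <functional>
#include <iostream>
#include <vector>

// Visit {16,32,64} x {16,32,64} x {8} -- nine combinations -- the way
// jit::autotune walks candidate (TM, TN, TK) tile sizes.
int main() {
  std::vector<std::vector<unsigned>> ranges = {{16, 32, 64}, {16, 32, 64}, {8}};
  triton::loop_nest<unsigned>(ranges, [](std::vector<unsigned> params) {
    // autotune would set the metaparameters, recompile, and benchmark here
    std::cout << params[0] << " x " << params[1] << " x " << params[2] << "\n";
  });
  return 0;
}
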