[triton/python/conv]: Added cache for compiled kernels
@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,11 +10,11 @@ int main() {
   // initialize default compute device
   auto context = triton::driver::backend::contexts::get_default();
   triton::jit jit(context);
-  triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
+  triton::dnn::conv::type ty = triton::dnn::conv::FPROP;
   // initialization
-  int32_t B = 32, NF = 128;
+  int32_t B = 4, NF = 32;
   int32_t D = 1, H = 56, W = 56;
-  int32_t NC = 128, T = 1, R = 3, S = 3;
+  int32_t NC = 32, T = 1, R = 3, S = 3;
   int32_t pad_d = 0, pad_h = 1, pad_w = 1;
   triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, 1, 1, 1, pad_d, pad_h, pad_w, ty);
   // convolution configuration
@@ -45,7 +45,7 @@ int main() {
   unsigned TN = info.global_range_size[1];
   unsigned nthreads = info.num_threads;
   std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-  configuration.init(stream, jit);
+  configuration.init(stream, (triton::driver::cu_module*)kernel->module());
   stream->synchronize();
   configuration.set_arg(kernel, da, db, dc);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -55,7 +55,7 @@ int main() {
     return configuration.get_nflops() / ts * 1e-3;
   };
   std::string src = configuration.src();
-  jit.autotune("conv", src.c_str(), benchmark);
+  // jit.autotune("conv", src.c_str(), benchmark);
   jit.add_module("conv", src.c_str(), configuration.default_params());
   triton::driver::kernel* kernel = jit.get_function("conv");
   triton::jit::launch_information info = jit.get_launch_info("conv");
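The hunks above change the example's driver sequence: autotuning is commented out in favor of a single add_module with default parameters, and the constant lookup tables are now initialized from the compiled module itself. A condensed restatement of that sequence as one function; every call is taken from the diff, but the wrapper run_conv and its parameter list are illustrative only:

```cpp
#include <array>
#include <string>
#include "triton/runtime/jit.h"
#include "triton/driver/stream.h"
#include "triton/dnn/conv.h"

// Hedged sketch: wrapper name and signature are hypothetical, the call
// sequence is the one the example above now uses.
void run_conv(triton::jit& jit, triton::driver::stream* stream,
              triton::dnn::conv& configuration,
              triton::driver::buffer* da, triton::driver::buffer* db,
              triton::driver::buffer* dc) {
  std::string src = configuration.src();
  // compile once with the default parameters instead of autotuning
  jit.add_module("conv", src.c_str(), configuration.default_params());
  triton::driver::kernel* kernel = jit.get_function("conv");
  triton::jit::launch_information info = jit.get_launch_info("conv");
  unsigned TM = info.global_range_size[0];
  unsigned TN = info.global_range_size[1];
  std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
  // constant LUTs are now read from the compiled module's own symbols
  configuration.init(stream, (triton::driver::cu_module*)kernel->module());
  stream->synchronize();
  configuration.set_arg(kernel, da, db, dc);
  stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}
```
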
@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/gemm.h"

@@ -1,7 +1,7 @@
 #include <cstring>
 #include <cstdio>
 #include "common.hpp"
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/backend.h"
 #include "triton/driver/stream.h"
@@ -158,8 +158,8 @@ int main() {
   unsigned TN = info.global_range_size[1];
   unsigned nthreads = info.num_threads;
   // initialize constant memory
-  triton::driver::buffer* delta = jit.get_buffer("delta");
-  triton::driver::buffer* masks = jit.get_buffer("masks");
+  triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta");
+  triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks");
   stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
   stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
   stream->synchronize();

@@ -2,7 +2,7 @@
 #include <torch/script.h>
 #include "ATen/cuda/CUDAContext.h"
 #include <vector>
-#include "triton/jit.h"
+#include "triton/runtime/jit.h"
 #include "triton/driver/stream.h"
 #include "triton/dnn/conv.h"
@@ -10,6 +10,16 @@
 #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
+typedef std::tuple<int32_t, int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   int32_t, int32_t, int32_t,
+                   triton::dnn::conv::type> conv_key_t;
+
+static std::map<CUstream, std::unique_ptr<triton::driver::stream>> m_stream;
+static std::map<conv_key_t, std::unique_ptr<triton::jit>> m_jit;
+static std::map<conv_key_t, std::unique_ptr<triton::dnn::conv>> m_config;
+
 torch::Tensor conv_common(
     int32_t B, int32_t C, int32_t D, int32_t H, int32_t W,
     int32_t T, int32_t R, int32_t S, int32_t NF,
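These statics are the cache the commit title refers to: streams are memoized by their CUstream handle, and both the conv configuration and its compiled jit are memoized by the full problem shape. A minimal standalone sketch of the same find-or-emplace idiom; compiled_kernel, key_t, and get_or_create are illustrative names, not Triton API:

```cpp
#include <map>
#include <memory>
#include <tuple>

// Hypothetical stand-in for a compiled kernel; the real code caches
// triton::jit and triton::dnn::conv instances the same way.
struct compiled_kernel { /* ... */ };
using key_t = std::tuple<int, int, int>; // e.g. (B, C, NF)

static std::map<key_t, std::unique_ptr<compiled_kernel>> cache;

compiled_kernel* get_or_create(const key_t& key) {
  auto it = cache.find(key);
  if (it == cache.end()) // first sighting: compile once, reuse ever after
    it = cache.emplace(key, std::make_unique<compiled_kernel>()).first;
  return it->second.get();
}
```
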
@@ -18,41 +28,59 @@ torch::Tensor conv_common(
     triton::dnn::conv::type ty,
     torch::Tensor torcha, torch::Tensor torchb
     ) {
-  // Configuration
-  triton::dnn::conv configuration(B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty);
+  // Wrap CUDA handles
+  c10::DeviceIndex device = torcha.storage().device().index();
+  // Get stream
+  CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();
+  triton::driver::stream* stream;
+  if(m_stream.find(custream) == m_stream.end())
+    stream = m_stream.emplace(custream, new triton::driver::cu_stream(custream, false)).first->second.get();
+  else
+    stream = m_stream.at(custream).get();
+  // Get context
+  triton::driver::context* ctx = stream->context();
+  // Get configuration
+  conv_key_t key = {B, C, D, H, W, T, R, S, NF, stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, ty};
+  triton::dnn::conv* configuration;
+  if(m_config.find(key) == m_config.end())
+    configuration = m_config.emplace(key, new triton::dnn::conv(
+        B, C, D, H, W, T, R, S, NF,
+        stride_d, stride_h, stride_w,
+        pad_d, pad_h, pad_w, ty)).first->second.get();
+  else
+    configuration = m_config.at(key).get();
+  // Get JIT
+  triton::jit* jit;
+  if(m_jit.find(key) == m_jit.end()){
+    jit = m_jit.emplace(key, new triton::jit(ctx)).first->second.get();
+    std::string src = configuration->src();
+    jit->add_module("conv", src.c_str(), configuration->default_params());
+  }
+  else
+    jit = m_jit.at(key).get();
+  // Get memory
+  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
+  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   // Allocate output
-  std::vector<int32_t> c_shapes = configuration.c_shapes();
+  std::vector<int32_t> c_shapes = configuration->c_shapes();
   torch::Tensor torchc;
   if(ty == triton::dnn::conv::WGRAD)
     torchc = torch::empty({c_shapes[0], c_shapes[2], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
   else
     torchc = torch::empty({c_shapes[0], c_shapes[1], c_shapes[3], c_shapes[4]}, torch::kFloat).cuda();
-  // Wrap CUDA handles
-  c10::DeviceIndex device = torchc.storage().device().index();
-  triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false);
-  triton::driver::stream* stream = &sstream;
-  triton::driver::context* ctx = stream->context();
-  triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
-  triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
   triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
   stream->synchronize();
-  // Create JIT
-  triton::jit jit(ctx);
-  std::string src = configuration.src();
-  jit.add_module("conv", src.c_str(), configuration.default_params());
-  triton::driver::kernel* kernel = jit.get_function("conv");
-  triton::jit::launch_information info = jit.get_launch_info("conv");
+  // Add module to JIT
+  triton::driver::kernel* kernel = jit->get_function("conv");
+  triton::jit::launch_information info = jit->get_launch_info("conv");
   // launch info
   unsigned TM = info.global_range_size[0];
   unsigned TN = info.global_range_size[1];
-  // launch info
-  configuration.init(stream, jit);
+  configuration->init(stream, (triton::driver::cu_module*)kernel->module());
   unsigned nthreads = info.num_threads;
-  std::array<size_t, 3> grid = configuration.get_grid(TM, TN);
-  configuration.set_arg(kernel, &a, &b, &c);
-  stream->synchronize();
+  std::array<size_t, 3> grid = configuration->get_grid(TM, TN);
+  configuration->set_arg(kernel, &a, &b, &c);
   stream->enqueue(kernel, grid, {nthreads, 1, 1});
   stream->synchronize();
   return torchc;
 }
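A detail worth noting about conv_key_t above: std::tuple already provides lexicographic operator<, which is the only requirement std::map places on its key, so a sixteen-field shape key works with no custom comparator. A tiny self-contained demonstration with three fields instead of sixteen:

```cpp
#include <cstdint>
#include <map>
#include <tuple>

// std::tuple's built-in lexicographic operator< is what lets conv_key_t
// key the kernel cache directly.
using key_t = std::tuple<int32_t, int32_t, int32_t>;

int main() {
  std::map<key_t, int> cache;
  cache[{4, 32, 32}] = 1;                  // insert under a shape key
  bool hit = cache.count({4, 32, 32}) > 0; // exact-shape lookup
  return hit ? 0 : 1;
}
```
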
@@ -1,4 +1,5 @@
 import torch
+import time
 torch.manual_seed(0)
 
 class TritonConv(torch.autograd.Function):
@@ -14,9 +15,9 @@ class TritonConv(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         grad_input = grad_weight = None
         if ctx.needs_input_grad[0]:
-            grad_input = torch.ops.triton.conv_bprop(grad_output.contiguous(), weight)
+            grad_input = torch.ops.triton.conv_bprop(grad_output, weight)
         if ctx.needs_input_grad[1]:
-            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output.contiguous())
+            grad_weight = torch.ops.triton.conv_wgrad(input, grad_output)
         return grad_input, grad_weight
@@ -38,6 +39,7 @@ x.grad.zero_()
 w.grad.zero_()
 culoss, cuy, cudx, cudw = run(x, cuw, lambda x, w: torch.nn.functional.conv2d(x, w, padding=1))
 
+
 print((tty - cuy).norm(2))
 print((ttdx - cudx).norm(2))
 print((ttdw.permute(3,0,1,2) - cudw).norm(2))

@@ -4,7 +4,6 @@
 #include <numeric>
 #include "triton/driver/stream.h"
 #include "triton/driver/kernel.h"
-#include "triton/jit.h"
 
 namespace triton{
 namespace dnn{
@@ -34,7 +33,7 @@ public:
   // initialize
   void build_deltas();
   void build_masks();
-  void init(driver::stream *stream, triton::jit &jit);
+  void init(driver::stream *stream, driver::cu_module *module);
   std::array<size_t, 3> get_grid(size_t TM, size_t TN);
   void set_arg(driver::kernel *kernel,
                driver::buffer *a, driver::buffer *b, driver::buffer *c);
@@ -44,7 +43,144 @@ public:
   std::vector<unsigned> default_params();
 
   // source
-  std::string src();
+  std::string src(){
+    bool is_wgrad = ty_ == WGRAD;
+    std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
+    std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
+    std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
+    std::string ldb0 = b_trans_ ? "*ldb_s" : "";
+    std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
+    std::string useb = b_trans_ ? "trans(b)" : "b";
+    std::string flipr = b_trans_ ? "" : "BH - 1 -";
+    std::string flips = b_trans_ ? "" : "BW - 1 -";
+    std::string ax = b_trans_ ? "crs" : "rsc";
+    std::vector<std::string> redax;
+    if(b_trans_)
+      redax = {"C", "BH", "BW"};
+    else
+      redax = {"BH", "BW", "N"};
+    std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
+    std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
+    std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
+    std::string masks_mem = is_mask_cst_? "__constant__" : "";
+
+    std::string res =
+    R"(
+    const tunable int32 TM = {16, 32, 64};
+    const tunable int32 TN = {16, 32, 64};
+    const tunable int32 TK = {8};
+    )";
+    if(is_a_deltas_cst)
+      res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
+    if(is_wgrad && is_b_deltas_cst_)
+      res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
+    if(is_mask_cst_)
+      res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
+    res += R"(
+
+    void conv(read_only restrict fp32 *a,
+              read_only restrict fp32 *b,
+              fp32 *c,
+              int32 M, int32 N, int32 K,
+              int32 AH, int32 AW,
+              int32 BH, int32 BW,
+              int32 CH, int32 CW,
+              int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
+              int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
+              int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
+              int32 pad_h, int32 pad_w)";
+    if(!is_a_deltas_cst)
+      res += ", int32* delta";
+    if(is_wgrad && !is_b_deltas_cst_)
+      res += ", int32* b_delta";
+    if(!is_mask_cst_)
+      res += ", int32* masks";
+    res += R"(){
+      int32 rxa[TM] = get_global_range[TM](0);
+      int32 rb0[TN] = get_global_range[TN](1);
+      int32 rka[TK] = 0 ... TK;
+      int32 rkb[TK] = 0 ... TK;
+      fp32 C[TM, TN] = 0;
+      int32 ldlut = )" + std::to_string(Fs_) + R"(;
+      int32 rabh[TM] = rxa / CW;
+      int32 raw[TM] = rxa % CW - pad_w;
+      int32 rab[TM] = rabh / CH;
+      int32 rah[TM] = rabh % CH - pad_h;
+      int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
+      int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
+      int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
+      int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+      int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+      rar = )" + flipr + R"( rar;
+      ras = )" + flips + R"( ras;
+      int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
+      fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
+    if(ty_ == WGRAD){
+      res += R"(
+      int32 rbcr[TK] = rkb / BW;
+      int32 rbs[TK] = rkb % BW;
+      int32 rbc[TK] = rbcr / BH;
+      int32 rbr[TK] = rbcr % BH;
+      int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
+      )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
+      int32 db[TK] = *pdb;)";
+    }
+    else{
+      res += R"(
+      int32 rb1[TK] = rkb;)";
+    }
+    res += R"(
+      fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
+      )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
+      )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
+      int32 d[TK] = *pd;
+      int32 incd[TK] = *pincd;
+      int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
+      int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
+      )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
+      )" + a_delta_mem + R"( int32* pincm[TM] = delta;
+      int32 incm[TM] = *pincm;
+      int32 checka0[TM] = *pm;
+      int32 checka1[TK] = 1 << rka;
+      int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+      fp32 a[TM, TK] = checka ? *pa : 0;
+      fp32 b)" + BS + R"( = *pb;
+      for(int32 k = K; k > 0; k = k - TK){
+        C = dot(a, )" + useb + R"(, C);
+        pa = pa + d[newaxis, :];
+        pb = pb + )" + inc_pb + R"(;
+        b = *pb;
+        pd = pd + incd;)";
+    if(ty_ == WGRAD){
+      res += R"(
+        pdb = pdb + TK;
+        db = *pdb;)";
+    }
+    res += R"(
+        pincd = pincd + incd;
+        d = *pd;
+        incd = *pincd;
+        pm = pm + incm;
+        pincm = pincm + incm;
+        incm = *pincm;
+        checka0 = *pm;
+        checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
+        checka = checka && (k > TK);
+        a = checka ? *pa : 0;
+      }
+      int32 rxc[TM] = get_global_range[TM](0);
+      int32 rc1[TN] = get_global_range[TN](1);
+      int32 rcn[TM] = rxc / (CH*CW);
+      int32 rcpq[TM] = rxc % (CH*CW);
+      int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
+      fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+      int1 checkc0[TM] = rxc < M;
+      int1 checkc1[TN] = rc1 < N;
+      int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+      @checkc *pc = C;
+    })";
+    return res;
+  }
 
   // cpu check
   template<class IN_DTYPE, class OUT_DTYPE>
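The inlined src() above assembles Triton-C by splicing runtime-selected fragments between raw string literals; once the expression is anchored on a std::string, the chained + calls keep appending, so even single chars like ax[0] concatenate safely. A reduced sketch of that technique; make_src and its parameters are hypothetical:

```cpp
#include <iostream>
#include <string>

// Hypothetical reduction of the fragment-splicing pattern used by src().
std::string make_src(bool use_constant_lut, const std::string& ax) {
  std::string mem = use_constant_lut ? "__constant__" : "";
  std::string res = R"(
const tunable int32 TK = {8};
)";
  // std::string + const char* + char all append because `res +=` anchors
  // the chain on a std::string, mirroring the `+ ax[0] + ax[1]` splices above
  res += mem + R"( int32* lut = delta + ra)" + ax[0] + ax[1] + ";";
  return res;
}

int main() {
  std::cout << make_src(true, "crs") << "\n";
}
```
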
@@ -1,117 +0,0 @@
-#ifndef TDL_INCLUDE_JIT_H
-#define TDL_INCLUDE_JIT_H
-
-#include <string>
-#include <memory>
-#include "llvm/IR/LLVMContext.h"
-#include "triton/ir/context.h"
-#include "triton/ir/print.h"
-#include "triton/driver/module.h"
-#include "triton/driver/kernel.h"
-#include "triton/codegen/selection.h"
-#include "triton/codegen/tune.h"
-#include "triton/codegen/optimize_dot.h"
-#include "triton/codegen/optimize_cse.h"
-#include "triton/codegen/optimize_trans.h"
-#include "triton/codegen/shmem_allocation.h"
-#include "triton/codegen/shmem_liveness.h"
-#include "triton/codegen/shmem_info.h"
-#include "triton/codegen/shmem_barriers.h"
-#include "triton/codegen/target.h"
-#include "triton/codegen/vectorize.h"
-#include <functional>
-
-namespace llvm {
-class Module;
-}
-
-namespace triton {
-
-namespace codegen{
-class tune;
-}
-
-namespace ir {
-class module;
-class context;
-class metaparameter;
-}
-
-class jit {
-public:
-  struct launch_information{
-    std::vector<unsigned> global_range_size;
-    unsigned num_threads;
-  };
-  typedef std::function<double(driver::kernel*, launch_information)> benchmark_t;
-
-  struct passes_wrapper {
-    passes_wrapper(codegen::target* target)
-      : shmem_liveness(&shmem_info),
-        shmem_allocation(&shmem_liveness, &shmem_info),
-        shmem_barriers(&shmem_allocation, &shmem_info),
-        vectorize(&tune),
-        selection(&shmem_allocation, &tune, &shmem_info, target),
-        optimize_dot(&tune),
-        optimize_cse(),
-        optimize_trans(),
-        target_(target) { }
-
-    void target_independent(ir::module &module) {
-      optimize_dot.run(module);
-      optimize_trans.run(module);
-    }
-
-    void target_dependent(ir::module &module) {
-      if(target_->is_gpu()){
-        shmem_info.run(module);
-        shmem_liveness.run(module);
-        shmem_allocation.run();
-        shmem_barriers.run(module);
-      }
-      vectorize.run(module);
-    }
-
-    codegen::tune tune;
-    codegen::shmem_info shmem_info;
-    codegen::shmem_liveness shmem_liveness;
-    codegen::shmem_allocation shmem_allocation;
-    codegen::shmem_barriers shmem_barriers;
-    codegen::vectorize vectorize;
-    codegen::selection selection;
-    codegen::optimize_dot optimize_dot;
-    codegen::optimize_cse optimize_cse;
-    codegen::optimize_trans optimize_trans;
-    codegen::target* target_;
-  };
-
-private:
-  std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true);
-  std::unique_ptr<llvm::Module> make_llvm_module(triton::ir::module &module, passes_wrapper &passes);
-  std::unique_ptr<ir::module> make_triton_module(const char* name, const char* src);
-
-public:
-  jit(driver::context* context);
-  ~jit();
-  void autotune(const char* name, const char* src, benchmark_t benchmark);
-  void add_module(ir::module &module, const std::vector<unsigned>& params = {});
-  void add_module(const char* name, const char* src, const std::vector<unsigned>& params = {});
-  driver::kernel* get_function(const char* name);
-  launch_information get_launch_info(const char* name);
-  unsigned get_int(const char* name);
-  driver::buffer* get_buffer(const char* name);
-
-private:
-  std::vector<driver::module*> modules_;
-  driver::context* driver_context_;
-  llvm::LLVMContext llvm_context_;
-  ir::context triton_context_;
-  std::map<std::string, launch_information> launch_info_map_;
-  std::map<std::string, unsigned> global_ints_;
-  std::unique_ptr<triton::codegen::target> target_;
-};
-
-}
-
-#endif

lib/dnn/conv.cpp

@@ -207,7 +207,7 @@ std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
 size_t conv::get_nflops()
 { return 2.*M_*N_*K_; }
 
-void conv::init(driver::stream *stream, triton::jit &jit) {
+void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
   auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
     if(host.empty())
       return nullptr;
@@ -215,7 +215,7 @@ void conv::init(driver::stream *stream, triton::jit &jit) {
     // get buffer
     triton::driver::buffer* buffer;
     if(is_cst)
-      buffer = jit.get_buffer(name);
+      buffer = module->symbol(name);
    else
      buffer = triton::driver::buffer::create(stream->context(), nbytes);
    // copy
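The effect of the two hunks above: lookup tables destined for __constant__ memory are now resolved from the compiled module's own symbol table instead of asking the jit, while non-constant ones still get a fresh global-memory allocation. A hedged standalone sketch of that upload pattern; the free function init_lut and the includes are assumptions, but symbol, buffer::create, and stream->write are taken from the diff:

```cpp
#include <cstdint>
#include <vector>
#include "triton/driver/module.h"
#include "triton/driver/stream.h"

// Sketch of the LUT-upload pattern conv::init uses after this change.
triton::driver::buffer* init_lut(triton::driver::cu_module* module,
                                 triton::driver::stream* stream,
                                 bool is_cst, const char* name,
                                 const std::vector<int32_t>& host) {
  if(host.empty())
    return nullptr;
  size_t nbytes = host.size() * 4;
  triton::driver::buffer* buffer;
  if(is_cst)
    // constant LUTs resolve to a device symbol baked into the module
    buffer = module->symbol(name);
  else
    // non-constant LUTs live in a plain global-memory allocation
    buffer = triton::driver::buffer::create(stream->context(), nbytes);
  stream->write(buffer, false, 0, nbytes, host.data());
  return buffer;
}
```
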
@@ -306,145 +306,6 @@ std::vector<unsigned> conv::default_params() {
 }
 
 
-std::string conv::src() {
-  bool is_wgrad = ty_ == WGRAD;
-  std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
-  std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
-  std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
-  std::string ldb0 = b_trans_ ? "*ldb_s" : "";
-  std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
-  std::string useb = b_trans_ ? "trans(b)" : "b";
-  std::string flipr = b_trans_ ? "" : "BH - 1 -";
-  std::string flips = b_trans_ ? "" : "BW - 1 -";
-  std::string ax = b_trans_ ? "crs" : "rsc";
-  std::vector<std::string> redax;
-  if(b_trans_)
-    redax = {"C", "BH", "BW"};
-  else
-    redax = {"BH", "BW", "N"};
-  std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
-  std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
-  std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
-  std::string masks_mem = is_mask_cst_? "__constant__" : "";
-
-  std::string res =
-  R"(
-  const tunable int32 TM = {16, 32, 64};
-  const tunable int32 TN = {16, 32, 64};
-  const tunable int32 TK = {8};
-  )";
-  if(is_a_deltas_cst)
-    res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
-  if(is_wgrad && is_b_deltas_cst_)
-    res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
-  if(is_mask_cst_)
-    res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
-  res += R"(
-
-  void conv(read_only restrict fp32 *a,
-            read_only restrict fp32 *b,
-            fp32 *c,
-            int32 M, int32 N, int32 K,
-            int32 AH, int32 AW,
-            int32 BH, int32 BW,
-            int32 CH, int32 CW,
-            int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
-            int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
-            int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-            int32 pad_h, int32 pad_w)";
-  if(!is_a_deltas_cst)
-    res += ", int32* delta";
-  if(is_wgrad && !is_b_deltas_cst_)
-    res += ", int32* b_delta";
-  if(!is_mask_cst_)
-    res += ", int32* masks";
-  res += R"(){
-    int32 rxa[TM] = get_global_range[TM](0);
-    int32 rb0[TN] = get_global_range[TN](1);
-    int32 rka[TK] = 0 ... TK;
-    int32 rkb[TK] = 0 ... TK;
-    fp32 C[TM, TN] = 0;
-    int32 ldlut = )" + std::to_string(Fs_) + R"(;
-    int32 rabh[TM] = rxa / CW;
-    int32 raw[TM] = rxa % CW - pad_w;
-    int32 rab[TM] = rabh / CH;
-    int32 rah[TM] = rabh % CH - pad_h;
-    int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
-    int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
-    int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
-    int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-    int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
-    rar = )" + flipr + R"( rar;
-    ras = )" + flips + R"( ras;
-    int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
-    fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
-  if(ty_ == WGRAD){
-    res += R"(
-    int32 rbcr[TK] = rkb / BW;
-    int32 rbs[TK] = rkb % BW;
-    int32 rbc[TK] = rbcr / BH;
-    int32 rbr[TK] = rbcr % BH;
-    int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
-    )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
-    int32 db[TK] = *pdb;)";
-  }
-  else{
-    res += R"(
-    int32 rb1[TK] = rkb;)";
-  }
-  res += R"(
-    fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
-    )" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
-    )" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
-    int32 d[TK] = *pd;
-    int32 incd[TK] = *pincd;
-    int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
-    int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
-    )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
-    )" + a_delta_mem + R"( int32* pincm[TM] = delta;
-    int32 incm[TM] = *pincm;
-    int32 checka0[TM] = *pm;
-    int32 checka1[TK] = 1 << rka;
-    int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-    fp32 a[TM, TK] = checka ? *pa : 0;
-    fp32 b)" + BS + R"( = *pb;
-    for(int32 k = K; k > 0; k = k - TK){
-      C = dot(a, )" + useb + R"(, C);
-      pa = pa + d[newaxis, :];
-      pb = pb + )" + inc_pb + R"(;
-      b = *pb;
-      pd = pd + incd;)";
-  if(ty_ == WGRAD){
-    res += R"(
-      pdb = pdb + TK;
-      db = *pdb;)";
-  }
-  res += R"(
-      pincd = pincd + incd;
-      d = *pd;
-      incd = *pincd;
-      pm = pm + incm;
-      pincm = pincm + incm;
-      incm = *pincm;
-      checka0 = *pm;
-      checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
-      checka = checka && (k > TK);
-      a = checka ? *pa : 0;
-    }
-    int32 rxc[TM] = get_global_range[TM](0);
-    int32 rc1[TN] = get_global_range[TN](1);
-    int32 rcn[TM] = rxc / (CH*CW);
-    int32 rcpq[TM] = rxc % (CH*CW);
-    int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
-    fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-    int1 checkc0[TM] = rxc < M;
-    int1 checkc1[TN] = rc1 < N;
-    int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-    @checkc *pc = C;
-  })";
-  return res;
-}
 
 template<class IN_DTYPE, class OUT_DTYPE>
 void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
 {

@@ -100,10 +100,10 @@ module* module::create(driver::context* ctx, llvm::Module *src) {
 }
 
 void module::compile_llvm_module(llvm::Module* module, const std::string& triple,
-                                 const std::string &proc, std::string layout,
-                                 llvm::SmallVectorImpl<char> &buffer,
-                                 const std::string& features,
-                                 file_type_t ft) {
+                                 const std::string &proc, std::string layout,
+                                 llvm::SmallVectorImpl<char> &buffer,
+                                 const std::string& features,
+                                 file_type_t ft) {
  init_llvm();
  // debug
  // llvm::legacy::PassManager pm;

lib/jit.cpp

@@ -1,216 +0,0 @@
-#include "triton/jit.h"
-#include <string>
-#include "triton/ast/ast.h"
-#include "triton/codegen/target.h"
-#include "triton/ir/context.h"
-#include "triton/ir/context_impl.h"
-#include "triton/driver/device.h"
-#include "triton/driver/error.h"
-#include "llvm/IR/IRPrintingPasses.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/LLVMContext.h"
-#include "llvm/IR/PassManager.h"
-#include "llvm/Support/raw_ostream.h"
-#include "llvm/Support/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Target/TargetMachine.h"
-#include "llvm/Target/TargetOptions.h"
-#include "llvm/CodeGen/TargetPassConfig.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/Transforms/Scalar/EarlyCSE.h"
-#include "llvm/Analysis/LoopPass.h"
-
-typedef struct yy_buffer_state * YY_BUFFER_STATE;
-extern int yyparse();
-extern YY_BUFFER_STATE yy_scan_string(const char * str);
-extern void yy_delete_buffer(YY_BUFFER_STATE buffer);
-using triton::ast::translation_unit;
-extern translation_unit *ast_root;
-
-namespace triton {
-
-void loop_nest(std::vector<size_t> const & ranges, std::function<void(std::vector<size_t> const &)> const & f){
-  size_t D = ranges.size();
-  std::vector<size_t> values(D, 0);
-  // Start with innermost loop
-  size_t i = D - 1;
-  while(true){
-    //Execute function
-    f(values);
-    //Increment counters
-    while(values[i]++ == ranges[i] - 1){
-      if(i == 0)
-        return;
-      values[i--] = 0;
-    }
-    i = D - 1;
-  }
-}
-
-template<class T>
-void loop_nest(std::vector<std::vector<T>> const & iterates, std::function<void(std::vector<T>)> const & f){
-  //Ranges to iterate over
-  std::vector<size_t> ranges;
-  for(auto const & x: iterates)
-    ranges.push_back(x.size());
-  //Proxy function
-  auto proxy = [&](std::vector<size_t> const & idx){
-    std::vector<T> x(iterates.size());
-    for(size_t i = 0; i < x.size(); ++i)
-      x[i] = iterates[i][idx[i]];
-    f(x);
-  };
-  //Iterate
-  loop_nest(ranges, proxy);
-}
-
-
-std::unique_ptr<llvm::Module> jit::make_llvm_module(ir::module &module, passes_wrapper &passes) {
-  llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_);
-  passes.selection.run(module, *result);
-  // launch information
-  launch_information& info = launch_info_map_[result->getName()];
-  info.global_range_size.clear();
-  for(unsigned i = 0; i < passes.tune.get_num_global_range(); i++)
-    info.global_range_size.push_back(passes.tune.get_global_range_size(i));
-  info.num_threads = passes.tune.get_num_threads();
-  return std::unique_ptr<llvm::Module>(result);
-}
-
-std::unique_ptr<ir::module> jit::make_triton_module(const char *name, const char *src) {
-  // create AST from Triton-C source
-  YY_BUFFER_STATE buffer = yy_scan_string(src);
-  yyparse();
-  yy_delete_buffer(buffer);
-  translation_unit *program = ast_root;
-  // create Triton-IR from AST
-  ir::module* module = new ir::module(name, triton_context_);
-  program->codegen(module);
-  return std::unique_ptr<ir::module>(module);
-}
-
-
-jit::jit(driver::context *context): driver_context_(context),
-                                    target_(context->device()->make_target()) { }
-
-jit::~jit(){ }
-
-void jit::autotune(const char *name, const char *src, benchmark_t benchmark) {
-  // find metaparameters
-  auto ptt_module = make_triton_module(name, src);
-  ir::module &tt_module = *ptt_module;
-  // set parameters
-  passes_wrapper passes(target_.get());
-  passes.target_independent(tt_module);
-  passes.tune.run(tt_module);
-  auto mps = passes.tune.get_params(tt_module);
-  // create parameter ranges
-  std::vector<std::vector<unsigned>> ranges;
-  for(ir::metaparameter *mp: mps)
-    ranges.push_back(mp->get_space());
-  // std::cout << ranges.size() << std::endl;
-  // iterate over parameters
-  unsigned i;
-  double best = 0;
-  loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
-    std::map<ir::value*, std::vector<std::string>> errors;
-    i = 0;
-    for(ir::metaparameter *mp: mps)
-      mp->set_value(params[i++]);
-    passes.target_independent(tt_module);
-    passes.tune.init(tt_module);
-    if(!passes.tune.check_constraints(errors))
-      return;
-    // Deep copy of the module and tuner
-    auto ptt_module = make_triton_module(name, src);
-    ir::module &tt_module = *ptt_module;
-    passes_wrapper passes(target_.get());
-    passes.target_independent(tt_module);
-    passes.tune.run(tt_module);
-    i = 0;
-    for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){
-      mp->set_value(params[i++]);
-    }
-    passes.tune.init(tt_module);
-    passes.target_dependent(tt_module);
-    driver::device* device = driver_context_->device();
-    if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory())
-      return;
-    if(passes.tune.get_num_threads() > device->max_threads_per_block())
-      return;
-    // Compile
-    auto ll_module = make_llvm_module(tt_module, passes);
-    std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
-    std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
-    launch_information info = launch_info_map_.at(name);
-    for(unsigned p: params)
-      std::cout << p << " " << std::flush;
-    // add globals
-    for(auto x: tt_module.globals())
-      global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-    modules_.push_back(module.get());
-    double perf;
-    perf = benchmark(kernel.get(), info);
-    best = std::max(perf, best);
-    std::cout << perf << " [ " << best << " ] " << std::endl;
-    modules_.pop_back();
-  });
-}
-
-void jit::add_module(ir::module &tt_module, const std::vector<unsigned> &params) {
-  // set parameters
-  passes_wrapper passes(target_.get());
-  passes.target_independent(tt_module);
-  passes.tune.run(tt_module);
-  unsigned i = 0;
-  for(ir::metaparameter* mp: passes.tune.get_params(tt_module))
-    mp->set_value(params[i++]);
-  passes.tune.init(tt_module);
-  passes.target_dependent(tt_module);
-  // check constraints
-  std::map<ir::value*, std::vector<std::string>> errors;
-  passes.tune.check_constraints(errors);
-  for(auto x: errors){
-    std::cout << x.first << std::endl;
-    for(auto str: x.second)
-      std::cout << str << std::endl;
-  }
-  if(errors.size())
-    throw std::runtime_error("invalid parameters");
-  // driver::device* device = driver_context_->device();
-  // if(passes.allocation.get_allocated_size() > device->max_shared_memory())
-  //   throw std::runtime_error("invalid parameters");
-  // triton module -> llvm module
-  auto ll_module = make_llvm_module(tt_module, passes);
-  // llvm module -> machine code
-  modules_.push_back(driver::module::create(driver_context_, &*ll_module));
-  // add globals
-  for(auto x: tt_module.globals())
-    global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value();
-}
-
-void jit::add_module(const char *name, const char *src, const std::vector<unsigned> &params) {
-  auto ptt_module = make_triton_module(name, src);
-  add_module(*ptt_module, params);
-}
-
-driver::kernel *jit::get_function(const char *name) {
-  return driver::kernel::create(modules_.front(), name);
-}
-
-jit::launch_information jit::get_launch_info(const char *name) {
-  return launch_info_map_.at(name);
-}
-
-unsigned jit::get_int(const char *name){
-  return global_ints_.at(name);
-}
-
-driver::buffer *jit::get_buffer(const char *name){
-  driver::cu_module *mod = (driver::cu_module*)modules_.front();
-  return mod->symbol(name);
-}
-
-}
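For reference, the loop_nest helper removed with lib/jit.cpp is self-contained enough to run on its own: it enumerates the Cartesian product of the given ranges without recursion, which is how the autotuner walked the metaparameter space. A minimal runnable demo; the function body is copied from the deleted file and only main() is illustrative:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Odometer-style enumeration of the Cartesian product of `ranges`,
// verbatim from the deleted lib/jit.cpp.
void loop_nest(std::vector<size_t> const& ranges,
               std::function<void(std::vector<size_t> const&)> const& f) {
  size_t D = ranges.size();
  std::vector<size_t> values(D, 0);
  size_t i = D - 1;
  while(true) {
    f(values);
    // carry-propagate like an odometer: roll over exhausted digits
    while(values[i]++ == ranges[i] - 1) {
      if(i == 0)
        return;
      values[i--] = 0;
    }
    i = D - 1;
  }
}

int main() {
  // visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) -- the autotuner
  // walked metaparameter value spaces exactly like this
  loop_nest({2, 3}, [](std::vector<size_t> const& idx) {
    std::cout << idx[0] << "," << idx[1] << "\n";
  });
}
```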