[triton/python/conv]: Added cache for compiled kernels

This commit is contained in:
Philippe Tillet
2019-05-18 11:51:49 -04:00
parent 600aef72d5
commit b2b55c52c9
10 changed files with 210 additions and 516 deletions

View File

@@ -207,7 +207,7 @@ std::array<size_t, 3> conv::get_grid(size_t TM, size_t TN)
size_t conv::get_nflops()
{ return 2.*M_*N_*K_; }
void conv::init(driver::stream *stream, triton::jit &jit) {
void conv::init(driver::stream *stream, triton::driver::cu_module* module) {
auto init_lut = [&](bool is_cst, const char *name, std::vector<int32_t> host) -> triton::driver::buffer*{
if(host.empty())
return nullptr;
@@ -215,7 +215,7 @@ void conv::init(driver::stream *stream, triton::jit &jit) {
// get buffer
triton::driver::buffer* buffer;
if(is_cst)
buffer = jit.get_buffer(name);
buffer = module->symbol(name);
else
buffer = triton::driver::buffer::create(stream->context(), nbytes);
// copy
@@ -306,145 +306,6 @@ std::vector<unsigned> conv::default_params() {
}
std::string conv::src() {
bool is_wgrad = ty_ == WGRAD;
std::string BS = b_trans_ ? "[TN,TK]" : "[TK, TN]";
std::string bcb0 = b_trans_ ? "[:, newaxis]" : "[newaxis, :]";
std::string bcb1 = b_trans_ ? "[newaxis, :]" : "[:, newaxis]";
std::string ldb0 = b_trans_ ? "*ldb_s" : "";
std::string ldb1 = b_trans_ ? "*ldb_k" : "*ldb_c";
std::string useb = b_trans_ ? "trans(b)" : "b";
std::string flipr = b_trans_ ? "" : "BH - 1 -";
std::string flips = b_trans_ ? "" : "BW - 1 -";
std::string ax = b_trans_ ? "crs" : "rsc";
std::vector<std::string> redax;
if(b_trans_)
redax = {"C", "BH", "BW"};
else
redax = {"BH", "BW", "N"};
std::string inc_pb = is_wgrad ? "db[newaxis, :]" : "TK" + ldb0;
std::string a_delta_mem = is_a_deltas_cst ? "__constant__" : "";
std::string b_delta_mem = is_b_deltas_cst_? "__constant__" : "";
std::string masks_mem = is_mask_cst_? "__constant__" : "";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64};
const tunable int32 TK = {8};
)";
if(is_a_deltas_cst)
res += "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
if(is_wgrad && is_b_deltas_cst_)
res += "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
if(is_mask_cst_)
res += "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
res += R"(
void conv(read_only restrict fp32 *a,
read_only restrict fp32 *b,
fp32 *c,
int32 M, int32 N, int32 K,
int32 AH, int32 AW,
int32 BH, int32 BW,
int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
int32 pad_h, int32 pad_w)";
if(!is_a_deltas_cst)
res += ", int32* delta";
if(is_wgrad && !is_b_deltas_cst_)
res += ", int32* b_delta";
if(!is_mask_cst_)
res += ", int32* masks";
res += R"(){
int32 rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0;
int32 ldlut = )" + std::to_string(Fs_) + R"(;
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW - pad_w;
int32 rab[TM] = rabh / CH;
int32 rah[TM] = rabh % CH - pad_h;
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
rar = )" + flipr + R"( rar;
ras = )" + flips + R"( ras;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
if(ty_ == WGRAD){
res += R"(
int32 rbcr[TK] = rkb / BW;
int32 rbs[TK] = rkb % BW;
int32 rbc[TK] = rbcr / BH;
int32 rbr[TK] = rbcr % BH;
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + ras*ldb_s;
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + rkb;
int32 db[TK] = *pdb;)";
}
else{
res += R"(
int32 rb1[TK] = rkb;)";
}
res += R"(
fp32* pb)" + BS + " = b + rb1" + bcb1 + ldb0 + " + rb0" + bcb0 + ldb1 + R"(;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + rka;
)" + a_delta_mem + R"( int32* pd[TK] = delta + ldlut + rka;
int32 d[TK] = *pd;
int32 incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + maskw*ldlut + maskh*ldlut*(2*pad_w + 1);
)" + a_delta_mem + R"( int32* pincm[TM] = delta;
int32 incm[TM] = *pincm;
int32 checka0[TM] = *pm;
int32 checka1[TK] = 1 << rka;
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
fp32 a[TM, TK] = checka ? *pa : 0;
fp32 b)" + BS + R"( = *pb;
for(int32 k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C);
pa = pa + d[newaxis, :];
pb = pb + )" + inc_pb + R"(;
b = *pb;
pd = pd + incd;)";
if(ty_ == WGRAD){
res += R"(
pdb = pdb + TK;
db = *pdb;)";
}
res += R"(
pincd = pincd + incd;
d = *pd;
incd = *pincd;
pm = pm + incm;
pincm = pincm + incm;
incm = *pincm;
checka0 = *pm;
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
checka = checka && (k > TK);
a = checka ? *pa : 0;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CH*CW);
int32 rcpq[TM] = rxc % (CH*CW);
int32 rc0[TM] = rcn * ldc_n + rcpq * ldc_q;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = rc1 < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = C;
})";
return res;
}
template<class IN_DTYPE, class OUT_DTYPE>
void conv::cpu_xprop(OUT_DTYPE* C, IN_DTYPE* A, IN_DTYPE* B)
{