[general] creation of dnn module for gemm/conv triton routines
This commit is contained in:
@@ -4,87 +4,7 @@
|
||||
#include "triton/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
const char* src =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[18];
|
||||
__constant__ int32* masks = alloc_const int32[1024];
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AN, int32 AH, int32 AW,
|
||||
int32 CN, int32 CK, int32 CP, int32 CQ,
|
||||
int32 AC, int32 AR, int32 AS,
|
||||
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w,
|
||||
int32 bound){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rb1[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 ranh[TM] = rxa / CQ;
|
||||
int32 raw[TM] = rxa % CQ - pad_w;
|
||||
int32 ran[TM] = ranh / CP;
|
||||
int32 rah[TM] = ranh % CP - pad_h;
|
||||
int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 racr[TK] = rka / AS;
|
||||
int32 ras[TK] = rka % AS;
|
||||
int32 rac[TK] = racr / AR;
|
||||
int32 rar[TK] = racr % AR;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis];
|
||||
__constant__ int32* pincd[TK] = delta + rka;
|
||||
__constant__ int32* pd[TK] = delta + AR*AS + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
|
||||
__constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1);
|
||||
__constant__ int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
fp32 a[TM, TK] = checka ? *pa : 0;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*CK;
|
||||
pa = pa + d[newaxis, :];
|
||||
b = *pb;
|
||||
pd = pd + incd;
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
incm = *pincm;
|
||||
checka0 = *pm;
|
||||
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
a = checka ? *pa : 0;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (CP*CQ);
|
||||
int32 rcpq[TM] = rxc % (CP*CQ);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = rc1 < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
|
||||
|
||||
#include "triton/dnn/conv.h"
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
@@ -92,28 +12,28 @@ int main() {
|
||||
// initialize just-in-time compiler
|
||||
triton::jit jit(context);
|
||||
// initialization
|
||||
int32_t AN = 4, CK = 32;
|
||||
int32_t AD = 1, AH = 24, AW = 240;
|
||||
int32_t BC = 64, BT = 1, BR = 3, BS = 3;
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
int32_t NC = 64, T = 1, R = 3, S = 3;
|
||||
int32_t pad_d = 0, pad_h = 1, pad_w = 1;
|
||||
int32_t stride_d = 1, stride_h = 1, stride_w = 1;
|
||||
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
|
||||
int32_t CM = (AD*upsample_d - BT + 1 + 2*pad_d + stride_d - 1)/stride_d;
|
||||
int32_t CP = (AH*upsample_h - BR + 1 + 2*pad_h + stride_h - 1)/stride_h;
|
||||
int32_t CQ = (AW*upsample_w - BS + 1 + 2*pad_w + stride_w - 1)/stride_w;
|
||||
int32_t RD = (D*upsample_d - T + 1 + 2*pad_d + stride_d - 1)/stride_d;
|
||||
int32_t RH = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h;
|
||||
int32_t RW = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w;
|
||||
// equivalent matmul dimensions
|
||||
int32_t M = AN*CM*CP*CQ;
|
||||
int32_t N = CK;
|
||||
int32_t K = BC*BT*BR*BS;
|
||||
std::vector<float> hc(AN*CP*CQ*CK);
|
||||
std::vector<float> rc(AN*CP*CQ*CK);
|
||||
std::vector<float> ha(AN*BC*AH*AW);
|
||||
std::vector<float> hb(BC*BR*BS*CK);
|
||||
int32_t M = B*RD*RH*RW;
|
||||
int32_t N = NF;
|
||||
int32_t K = NC*T*R*S;
|
||||
std::vector<float> hc(B*RH*RW*NF);
|
||||
std::vector<float> rc(B*RH*RW*NF);
|
||||
std::vector<float> ha(B*NC*H*W);
|
||||
std::vector<float> hb(NC*R*S*NF);
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = 1;
|
||||
ha[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = 1;
|
||||
hb[i] = (float)rand()/RAND_MAX;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = 0;
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
@@ -126,36 +46,25 @@ int main() {
|
||||
stream->synchronize();
|
||||
// memory strides for data
|
||||
int32_t stride_i_w = 1;
|
||||
int32_t stride_i_h = AW*stride_i_w;
|
||||
int32_t stride_i_d = AH*stride_i_h;
|
||||
int32_t stride_i_c = AD*stride_i_d;
|
||||
int32_t stride_i_n = BC*stride_i_c;
|
||||
// memory strides for filters
|
||||
int32_t stride_f_k = 1;
|
||||
int32_t stride_f_s = CK*stride_f_k;
|
||||
int32_t stride_f_r = BS*stride_f_s;
|
||||
int32_t stride_f_t = BR*stride_f_r;
|
||||
int32_t stride_f_c = BT*stride_f_t;
|
||||
int32_t stride_i_h = W*stride_i_w;
|
||||
int32_t stride_i_d = H*stride_i_h;
|
||||
int32_t stride_i_c = D*stride_i_d;
|
||||
int32_t stride_i_n = NC*stride_i_c;
|
||||
// memory stride for activations
|
||||
int32_t stride_o_q = 1;
|
||||
int32_t stride_o_p = CQ*stride_o_q;
|
||||
int32_t stride_o_m = CP*stride_o_p;
|
||||
int32_t stride_o_k = CM*stride_o_m;
|
||||
int32_t stride_o_n = CK*stride_o_k;
|
||||
int32_t stride_o_p = RW*stride_o_q;
|
||||
int32_t stride_o_m = RH*stride_o_p;
|
||||
int32_t stride_o_k = RD*stride_o_m;
|
||||
int32_t stride_o_n = NF*stride_o_k;
|
||||
// look-up table
|
||||
int TK = 8;
|
||||
int F = BT * BR * BS;
|
||||
int nlut = (TK + F - 1) / F * F;
|
||||
std::vector<int> h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut);
|
||||
std::vector<int> h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
|
||||
build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, BT, BR, BS, h_delta, h_masks);
|
||||
std::vector<int> h_delta, h_masks;
|
||||
triton::dnn::conv::init_cst(stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
|
||||
// benchmark a given convolution kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned TK = jit.get_int("TK");
|
||||
// initialize constant memory
|
||||
triton::driver::buffer* delta = jit.get_buffer("delta");
|
||||
triton::driver::buffer* masks = jit.get_buffer("masks");
|
||||
@@ -165,15 +74,6 @@ int main() {
|
||||
// launch info
|
||||
unsigned nthreads = info.num_threads;
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
|
||||
// fast bounds-checking
|
||||
unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1;
|
||||
unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1;
|
||||
unsigned lastk = TK - 1;
|
||||
bool AT = false;
|
||||
bool BT = true;
|
||||
unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk;
|
||||
unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk;
|
||||
int32_t bound = std::max<unsigned>(1, std::max(K - last_safe_a, K - last_safe_b));
|
||||
// set arguments
|
||||
kernel->setArg(0, da);
|
||||
kernel->setArg(1, db);
|
||||
@@ -181,52 +81,41 @@ int main() {
|
||||
kernel->setArg(3, M);
|
||||
kernel->setArg(4, N);
|
||||
kernel->setArg(5, K);
|
||||
kernel->setArg(6, AN);
|
||||
kernel->setArg(7, AH);
|
||||
kernel->setArg(8, AW);
|
||||
kernel->setArg(9, AN);
|
||||
kernel->setArg(10, CK);
|
||||
kernel->setArg(11, CP);
|
||||
kernel->setArg(12, CQ);
|
||||
kernel->setArg(13, BC);
|
||||
kernel->setArg(14, BR);
|
||||
kernel->setArg(15, BS);
|
||||
kernel->setArg(16, stride_i_n);
|
||||
kernel->setArg(17, stride_i_c);
|
||||
kernel->setArg(18, stride_i_h);
|
||||
kernel->setArg(19, stride_i_w);
|
||||
kernel->setArg(20, stride_o_n);
|
||||
kernel->setArg(21, stride_o_k);
|
||||
kernel->setArg(22, stride_o_p);
|
||||
kernel->setArg(23, stride_o_q);
|
||||
kernel->setArg(24, pad_h);
|
||||
kernel->setArg(25, pad_w);
|
||||
kernel->setArg(26, bound);
|
||||
kernel->setArg(6, B);
|
||||
kernel->setArg(7, H);
|
||||
kernel->setArg(8, W);
|
||||
kernel->setArg(9, NF);
|
||||
kernel->setArg(10, RH);
|
||||
kernel->setArg(11, RW);
|
||||
kernel->setArg(12, NC);
|
||||
kernel->setArg(13, R);
|
||||
kernel->setArg(14, S);
|
||||
kernel->setArg(15, stride_i_n);
|
||||
kernel->setArg(16, stride_i_c);
|
||||
kernel->setArg(17, stride_i_h);
|
||||
kernel->setArg(18, stride_i_w);
|
||||
kernel->setArg(19, stride_o_n);
|
||||
kernel->setArg(20, stride_o_k);
|
||||
kernel->setArg(21, stride_o_p);
|
||||
kernel->setArg(22, stride_o_q);
|
||||
kernel->setArg(23, pad_h);
|
||||
kernel->setArg(24, pad_w);
|
||||
// dry run
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream->synchronize(); }, *context->device());
|
||||
ts = ts * 1e-9;
|
||||
double tflops = 2.*M*N*K / ts * 1e-12;
|
||||
return tflops;
|
||||
return 2.*M*N*K / ts * 1e-3;
|
||||
};
|
||||
// run
|
||||
std::vector<unsigned> params = {
|
||||
16, 2, 64,
|
||||
32, 2, 64,
|
||||
16, 8, 2, 2,
|
||||
8, 1, 8,
|
||||
4
|
||||
};
|
||||
// jit.autotune("conv", src, benchmark);
|
||||
jit.add_module("conv", src, params);
|
||||
std::string src = triton::dnn::conv::src();
|
||||
// jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.add_module("conv", src.c_str(), triton::dnn::conv::default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
stream->read(dc, true, 0, hc);
|
||||
cpp_conv_nchw(BC, AN, CK, AD, AH, AW, BT, BR, BS, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, CM, CP, CQ, rc, ha, hb);
|
||||
cpp_conv_nchw(NC, B, NF, D, H, W, T, R, S, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, RD, RH, RW, rc, ha, hb);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
|
@@ -4,99 +4,7 @@
|
||||
#include "triton/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
|
||||
|
||||
std::string triton_source(bool AT, bool BT) {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 lda, int32 ldb, int32 ldc,
|
||||
int32 *locks, int32 grid0, int32 grid1) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rz = get_global_range[1](2);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 c[TM, TN] = 0;
|
||||
int32 div = K / GZ;
|
||||
int32 rem = K % GZ;
|
||||
K = select(rz < rem, div - 1, div);
|
||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
|
||||
fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa;
|
||||
fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb;
|
||||
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
|
||||
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
|
||||
last_a = last_a / TK * TK;
|
||||
last_b = last_b / TK * TK;
|
||||
int32 bound = K - max(last_a, last_b);
|
||||
for(int32 k = K; k > bound; k = k - TK){
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
for(int32 k = bound; k > 0; k = k - 1){
|
||||
int1 checka[TM, 1] = rxc[:, newaxis] < M;
|
||||
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
|
||||
fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
|
||||
fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
|
||||
fp32 a[TM, 1] = checka ? *pa : 0;
|
||||
fp32 b[TN, 1] = checkb ? *pb : 0;
|
||||
c = dot(a, trans(b), c);
|
||||
}
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
int32 *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 *pcount = plock + grid0*grid1;
|
||||
int32 count = *pcount;
|
||||
int32 countp1 = select(count == GZ - 1, 0, count + 1);
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
__atomic_cas(plock, 1, 0);
|
||||
}
|
||||
)";
|
||||
return res;
|
||||
}
|
||||
#include "triton/dnn/gemm.h"
|
||||
|
||||
|
||||
int main() {
|
||||
@@ -129,51 +37,31 @@ int main() {
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
triton::dnn::gemm::init(stream, dlocks);
|
||||
stream->synchronize();
|
||||
|
||||
|
||||
// benchmark a given matrix multiplication kernel
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
unsigned GZ = jit.get_int("GZ");
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
|
||||
// init locks
|
||||
stream->write(dlocks, true, 0, hlocks);
|
||||
// set argument
|
||||
kernel->setArg(0, da);
|
||||
kernel->setArg(1, db);
|
||||
kernel->setArg(2, dc);
|
||||
kernel->setArg(3, M);
|
||||
kernel->setArg(4, N);
|
||||
kernel->setArg(5, K);
|
||||
kernel->setArg(6, M);
|
||||
kernel->setArg(7, N);
|
||||
kernel->setArg(8, M);
|
||||
kernel->setArg(9, dlocks);
|
||||
kernel->setArg(10, grid[0]);
|
||||
kernel->setArg(11, grid[1]);
|
||||
// dry run
|
||||
triton::dnn::gemm::set_arg(kernel, da, db, dc, M, N, K, dlocks, grid[0], grid[1]);
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
[&](){ stream->synchronize(); }, *context->device());
|
||||
ts = ts * 1e-9;
|
||||
double tflops = 2.*M*N*K / ts * 1e-12;
|
||||
return tflops;
|
||||
return 2.*M*N*K / ts * 1e-3;
|
||||
};
|
||||
|
||||
|
||||
// just-in-time compile source-code
|
||||
std::string src = triton_source(AT, BT);
|
||||
std::string src = triton::dnn::gemm::src(AT, BT);
|
||||
// jit.autotune("matmul",src.c_str(), benchmark);
|
||||
jit.add_module("matmul", src.c_str(), {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
|
||||
// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1});
|
||||
// jit.add_module("matmul", src.c_str(), {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1});
|
||||
jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT));
|
||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
|
@@ -266,24 +266,23 @@ torch::Tensor conv_forward(
|
||||
kernel->setArg(6, B);
|
||||
kernel->setArg(7, H);
|
||||
kernel->setArg(8, W);
|
||||
kernel->setArg(9, B);
|
||||
kernel->setArg(10, NF);
|
||||
kernel->setArg(11, P);
|
||||
kernel->setArg(12, Q);
|
||||
kernel->setArg(13, Ci);
|
||||
kernel->setArg(14, R);
|
||||
kernel->setArg(15, S);
|
||||
kernel->setArg(16, stride_i_n);
|
||||
kernel->setArg(17, stride_i_c);
|
||||
kernel->setArg(18, stride_i_h);
|
||||
kernel->setArg(19, stride_i_w);
|
||||
kernel->setArg(20, stride_o_n);
|
||||
kernel->setArg(21, stride_o_k);
|
||||
kernel->setArg(22, stride_o_p);
|
||||
kernel->setArg(23, stride_o_q);
|
||||
kernel->setArg(24, pad_h);
|
||||
kernel->setArg(25, pad_w);
|
||||
kernel->setArg(26, bound);
|
||||
kernel->setArg(9, NF);
|
||||
kernel->setArg(10, P);
|
||||
kernel->setArg(11, Q);
|
||||
kernel->setArg(12, Ci);
|
||||
kernel->setArg(13, R);
|
||||
kernel->setArg(14, S);
|
||||
kernel->setArg(15, stride_i_n);
|
||||
kernel->setArg(16, stride_i_c);
|
||||
kernel->setArg(17, stride_i_h);
|
||||
kernel->setArg(18, stride_i_w);
|
||||
kernel->setArg(19, stride_o_n);
|
||||
kernel->setArg(20, stride_o_k);
|
||||
kernel->setArg(21, stride_o_p);
|
||||
kernel->setArg(22, stride_o_q);
|
||||
kernel->setArg(23, pad_h);
|
||||
kernel->setArg(24, pad_w);
|
||||
kernel->setArg(25, bound);
|
||||
// // dry run
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
return output;
|
||||
|
20
examples/python/tensorflow/run.py
Normal file
20
examples/python/tensorflow/run.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
data_files_path = tf.resource_loader.get_data_files_path()
|
||||
library_dir = '/home/philippe/Development/triton/build/examples/python/tensorflow'
|
||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||
|
||||
M, N, K = 512, 512, 512
|
||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
||||
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||
c = module.block_sparse_mat_mul(a, b, locks)
|
||||
# Run
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||
a: np.random.rand(M, K),
|
||||
b: np.random.rand(N, K)})
|
||||
print(result)
|
197
include/triton/dnn/conv.h
Normal file
197
include/triton/dnn/conv.h
Normal file
@@ -0,0 +1,197 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class conv {
|
||||
public:
|
||||
enum type {
|
||||
FPROP,
|
||||
BPROP,
|
||||
WGRAD
|
||||
};
|
||||
|
||||
static void build_lut(int TK,
|
||||
int stride_d, int stride_h, int stride_w, int stride_c,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int T, int R, int S,
|
||||
std::vector<int>& res, std::vector<int>& masks) {
|
||||
/* convolution parameters */
|
||||
int F = T * R * S;
|
||||
int Nlut = (TK + F - 1) / F * F;
|
||||
int upsample_w = 1;
|
||||
int upsample_h = 1;
|
||||
int upsample_d = 1;
|
||||
/* unpack index wrt filters */
|
||||
auto unpack = [&](int32_t trs){
|
||||
int32_t tr = trs / S;
|
||||
int32_t s = trs - tr*S;
|
||||
int32_t t = tr / R;
|
||||
int32_t r = tr - t*R;
|
||||
return std::make_tuple(t, r, s);
|
||||
};
|
||||
/* increments */
|
||||
for(size_t i = 0; i < Nlut; ++i)
|
||||
res[i] = (((i + TK) % Nlut) - i);
|
||||
/* deltas */
|
||||
size_t Ds0 = Nlut;
|
||||
size_t Ds1 = upsample_w;
|
||||
size_t Ds2 = upsample_h;
|
||||
size_t Ds3 = upsample_d;
|
||||
for(size_t pd = 0; pd < Ds3; ++pd)
|
||||
for(size_t ph = 0; ph < Ds2; ++ph)
|
||||
for(size_t pw = 0; pw < Ds1; ++pw){
|
||||
int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2];
|
||||
// cumulative increments
|
||||
for(size_t i = 0; i < Ds0; ++i){
|
||||
int32_t ctrs = i;
|
||||
int32_t c = ctrs / F;
|
||||
int32_t t, r, s;
|
||||
std::tie(t, r, s) = unpack(ctrs % F);
|
||||
// next indices
|
||||
int32_t nextctrs = ctrs + TK;
|
||||
int32_t nextc = nextctrs / F;
|
||||
int32_t nextt, nextr, nexts;
|
||||
std::tie(nextt, nextr, nexts) = unpack(nextctrs % F);
|
||||
// diffs
|
||||
int32_t cdiff = nextc - c;
|
||||
int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d;
|
||||
int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h;
|
||||
int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w;
|
||||
// delta pointers
|
||||
deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d;
|
||||
}
|
||||
}
|
||||
|
||||
/* Masks */
|
||||
size_t Ms0 = Nlut;
|
||||
size_t Ms1 = 2*pad_w + 1;
|
||||
size_t Ms2 = 2*pad_h + 1;
|
||||
size_t Ms3 = 2*pad_d + 1;
|
||||
for(size_t pd = 0; pd < Ms3; ++pd)
|
||||
for(size_t ph = 0; ph < Ms2; ++ph)
|
||||
for(size_t pw = 0; pw < Ms1; ++pw){
|
||||
int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2];
|
||||
for(size_t i = 0; i < Ms0; ++i){
|
||||
int32_t t, r, s;
|
||||
int32_t mask = 0x0;
|
||||
for(size_t j = 0; j < TK; ++j){
|
||||
std::tie(t, r, s) = unpack((i + j) % F);
|
||||
bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d);
|
||||
bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h);
|
||||
bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w);
|
||||
mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j;
|
||||
}
|
||||
masks_ptr[i] = mask;
|
||||
}
|
||||
}
|
||||
for(size_t i = 0; i < Nlut; ++i)
|
||||
masks[i] = 0x0;
|
||||
|
||||
}
|
||||
|
||||
static std::vector<unsigned> default_params() {
|
||||
return {16, 2, 64, 32, 2, 64, 16, 8, 2, 2, 8, 1, 8, 4 };
|
||||
}
|
||||
|
||||
static void init_cst(int stride_d, int stride_h, int stride_w, int stride_c,
|
||||
int pad_d, int pad_h, int pad_w,
|
||||
int T, int R, int S,
|
||||
std::vector<int> &h_delta, std::vector<int> &h_masks) {
|
||||
int upsample_d = 1;
|
||||
int upsample_h = 1;
|
||||
int upsample_w = 1;
|
||||
int TK = 8;
|
||||
int F = T * R * S;
|
||||
int nlut = (TK + F - 1) / F * F;
|
||||
h_delta.resize(nlut + upsample_d*upsample_h*upsample_w*nlut);
|
||||
h_masks.resize(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut);
|
||||
build_lut(TK, stride_d, stride_h, stride_w, stride_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks);
|
||||
}
|
||||
|
||||
static std::string src(type ty = FPROP) {
|
||||
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64};
|
||||
const tunable int32 TN = {16, 32, 64};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[18];
|
||||
__constant__ int32* masks = alloc_const int32[1024];
|
||||
|
||||
void conv(read_only restrict fp32 *a,
|
||||
read_only restrict fp32 *b,
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 B, int32 H, int32 W,
|
||||
int32 NF, int32 RH, int32 RW,
|
||||
int32 NC, int32 R, int32 S,
|
||||
int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w,
|
||||
int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q,
|
||||
int32 pad_h, int32 pad_w){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 rb0[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rb1[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
int32 rabh[TM] = rxa / RW;
|
||||
int32 raw[TM] = rxa % RW - pad_w;
|
||||
int32 rab[TM] = rabh / RH;
|
||||
int32 rah[TM] = rabh % RH - pad_h;
|
||||
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
|
||||
int32 racr[TK] = rka / S;
|
||||
int32 ras[TK] = rka % S;
|
||||
int32 rac[TK] = racr / R;
|
||||
int32 rar[TK] = racr % R;
|
||||
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
|
||||
fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rb1[newaxis, :]*NF + rb0[:, newaxis];
|
||||
__constant__ int32* pincd[TK] = delta + rka;
|
||||
__constant__ int32* pd[TK] = delta + R*S + rka;
|
||||
int32 d[TK] = *pd;
|
||||
int32 incd[TK] = *pincd;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + R - H, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + S - W, 0);
|
||||
__constant__ int32* pm[TM] = masks + R*S + maskw*R*S + maskh*R*S*(2*pad_w + 1);
|
||||
__constant__ int32* pincm[TM] = delta;
|
||||
int32 incm[TM] = *pincm;
|
||||
int32 checka0[TM] = *pm;
|
||||
int32 checka1[TK] = 1 << rka;
|
||||
int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
fp32 a[TM, TK] = checka ? *pa : 0;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*NF;
|
||||
pa = pa + d[newaxis, :];
|
||||
b = *pb;
|
||||
pd = pd + incd;
|
||||
pincd = pincd + incd;
|
||||
d = *pd;
|
||||
incd = *pincd;
|
||||
pm = pm + incm;
|
||||
pincm = pincm + incm;
|
||||
incm = *pincm;
|
||||
checka0 = *pm;
|
||||
checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0;
|
||||
a = checka ? *pa : 0;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 rc1[TN] = get_global_range[TN](1);
|
||||
int32 rcn[TM] = rxc / (RH*RW);
|
||||
int32 rcpq[TM] = rxc % (RH*RW);
|
||||
int32 rc0[TM] = rcn * ldc_n + rcpq;
|
||||
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = rc1 < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
})";
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
139
include/triton/dnn/gemm.h
Normal file
139
include/triton/dnn/gemm.h
Normal file
@@ -0,0 +1,139 @@
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
#include <string>
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class gemm {
|
||||
public:
|
||||
|
||||
static void init(driver::stream* stream, driver::buffer* locks) {
|
||||
std::vector<int32_t> hlocks(2048, 0);
|
||||
stream->write(locks, false, 0, hlocks);
|
||||
}
|
||||
|
||||
static void set_arg(driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
driver::buffer *locks, int32_t grid_0, int32_t grid_1) {
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M);
|
||||
kernel->setArg(4, N);
|
||||
kernel->setArg(5, K);
|
||||
kernel->setArg(6, M);
|
||||
kernel->setArg(7, N);
|
||||
kernel->setArg(8, M);
|
||||
kernel->setArg(9, locks);
|
||||
kernel->setArg(10, grid_0);
|
||||
kernel->setArg(11, grid_1);
|
||||
}
|
||||
|
||||
static std::vector<unsigned> default_params(bool AT, bool BT) {
|
||||
if(AT && BT)
|
||||
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
||||
else if(AT && !BT)
|
||||
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
|
||||
else if(!AT && BT)
|
||||
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
|
||||
else
|
||||
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
|
||||
}
|
||||
|
||||
static std::string src(bool AT, bool BT) {
|
||||
std::string AS0 = "TM", AS1 = "TK";
|
||||
std::string BS0 = "TK", BS1 = "TN";
|
||||
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(bca0, bca1);
|
||||
std::swap(lda0, lda1);
|
||||
}
|
||||
if(BT){
|
||||
std::swap(BS0, BS1);
|
||||
std::swap(bcb0, bcb1);
|
||||
std::swap(ldb0, ldb1);
|
||||
}
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 lda, int32 ldb, int32 ldc,
|
||||
int32 *locks, int32 grid0, int32 grid1) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rz = get_global_range[1](2);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 c[TM, TN] = 0;
|
||||
int32 div = K / GZ;
|
||||
int32 rem = K % GZ;
|
||||
K = select(rz < rem, div - 1, div);
|
||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
|
||||
fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa;
|
||||
fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb;
|
||||
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
|
||||
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
|
||||
last_a = last_a / TK * TK;
|
||||
last_b = last_b / TK * TK;
|
||||
int32 bound = K - max(last_a, last_b);
|
||||
for(int32 k = K; k > bound; k = k - TK){
|
||||
c = dot()" + usea + ", " + useb + R"(, c);
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
for(int32 k = bound; k > 0; k = k - 1){
|
||||
int1 checka[TM, 1] = rxc[:, newaxis] < M;
|
||||
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
|
||||
fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
|
||||
fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
|
||||
fp32 a[TM, 1] = checka ? *pa : 0;
|
||||
fp32 b[TN, 1] = checkb ? *pb : 0;
|
||||
c = dot(a, trans(b), c);
|
||||
}
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
int32 *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 *pcount = plock + grid0*grid1;
|
||||
int32 count = *pcount;
|
||||
int32 countp1 = select(count == GZ - 1, 0, count + 1);
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
__atomic_cas(plock, 1, 0);
|
||||
}
|
||||
)";
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
@@ -204,6 +204,7 @@ public:
|
||||
static int initializeLLVMToSPIRVPass(llvm::PassRegistry &);
|
||||
static bool writeSpirv(llvm::Module *M, std::ostream &OS, std::string &ErrMsg);
|
||||
|
||||
|
||||
private:
|
||||
|
||||
// Libraries
|
||||
|
0
lib/frontend/jit.cpp
Normal file
0
lib/frontend/jit.cpp
Normal file
Reference in New Issue
Block a user