[EXAMPLES] Added C++ example for Conv2d

This commit is contained in:
Philippe Tillet
2020-08-08 14:51:29 -04:00
committed by Philippe Tillet
parent ba9955ae39
commit f4f216b88a
3 changed files with 295 additions and 0 deletions

39
tests/bench/conv.cc Normal file
View File

@@ -0,0 +1,39 @@
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "conv.h"
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int> config_t;
std::vector<config_t> configs = {
// {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1},
// {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1},
// {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1},
// {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1},
{1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1},
// {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1}
};
int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w;
for(const auto& c: configs){
std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
std::cout << "// " << c ;
for(auto perf: bench_conv(stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}
}

124
tests/common/conv.h Normal file
View File

@@ -0,0 +1,124 @@
#include <iomanip>
#include <cstring>
#include <sstream>
#include <cstdio>
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/conv.h"
#include "cuda/cublas.h"
#include "util.h"
enum run_mode_t {
BENCH,
TEST
};
enum dtype_t {
FLOAT,
HALF,
DOUBLE
};
template<class T>
struct to_string;
template<> struct to_string<half_float::half>{
static constexpr const char* value = "half";
};
template<> struct to_string<float>{
static constexpr const char* value = "float";
};
template<> struct to_string<double>{
static constexpr const char* value = "double";
};
template<class T>
void triton_conv(drv::stream* stream,
int Z, int CI, int H, int W, int CO, int R, int S,
int pad_h, int pad_w, int stride_h, int stride_w,
run_mode_t mode, std::vector<double>& bench, bool &test){
std::string ty = to_string<T>::value;
size_t dt_nbytes = sizeof(T);
drv::context* context = stream->context();
int P = (H + 2*pad_h - R)/stride_h + 1;
int Q = (W + 2*pad_w - S)/stride_w + 1;
// inputs
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes));
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CI*H*W*dt_nbytes));
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*CO*dt_nbytes));
auto ddelta = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*4));
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
std::vector<int32_t> hdelta(CI*R*S);
int TK = 16;
for(int i = 0; i < hdelta.size(); i++){
int s = i % S;
int cr = i / S;
int r = cr % R;
int c = cr / R;
int nexti = i + TK;
int nexts = nexti % S;
int nextcr = nexti / S;
int nextr = nextcr % R;
int nextc = nextcr / R;
hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s);
}
stream->write(&*ddelta, true, 0, hdelta);
// macros
rt::function::options_space_t opt;
opt.defines.push_back({"TYPE", {ty}});
opt.defines.push_back({"TM", {"128"}});
opt.defines.push_back({"TN", {"128"}});
opt.defines.push_back({"TK", {std::to_string(TK)}});
opt.defines.push_back({"TZ", {"1"}});
opt.defines.push_back({"RR", {std::to_string(R)}});
opt.defines.push_back({"SS", {std::to_string(S)}});
opt.defines.push_back({"PP", {std::to_string(P)}});
opt.defines.push_back({"QQ", {std::to_string(Q)}});
opt.defines.push_back({"HH", {std::to_string(H)}});
opt.defines.push_back({"WW", {std::to_string(W)}});
opt.num_warps = {2, 4};
// kernels
rt::function function(src::conv, opt);
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, Z*P*Q, CO, CI*R*S,
pad_h, pad_w, stride_h, stride_w,
&*ddelta,
W*H*CI, W*H, W, 1,
CO*S*R , CO*S, CO, 1,
Q*P*CO, Q*P, Q, 1};
auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
ceil(CO , x.D<int>("TN")),
(size_t)x.D<int>("TZ")};
};
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
bench.push_back(tflops(triton_ns));
}
std::vector<double> bench_conv(drv::stream* stream, dtype_t dtype,
int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
std::vector<double> bench;
bool test;
switch(dtype){
case HALF: triton_conv<half_float::half>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case FLOAT: triton_conv<float>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
case DOUBLE: triton_conv<double>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
default: break;
}
return bench;
}

132
tests/common/src/conv.h Normal file
View File

@@ -0,0 +1,132 @@
namespace src {
const char *conv =
R"(
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
TYPE *B __noalias __readonly __aligned(16),
TYPE *C __noalias __aligned(16),
float alpha,
// equivalent matmul
int M, int N, int K,
// convolution properties
int pad_h, int pad_w, int stride_h, int stride_w,
// pointer increment
int *ADELTA,
// memory strides
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
// prologue
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int ridz = get_program_id(2);
int rm[TM] = ridx * TM + 0 ... TM;
int rn[TN] = ridy * TN + 0 ... TN;
// reduction splitting
K = K / TZ;
int rk[TK] = ridz * K + 0 ... TK;
// unpack aggregate rows
// m = (z, p, q)
int rq[TM] = rm % QQ;
int rzp[TM] = rm / QQ;
int rp[TM] = rzp % PP;
int rz[TM] = rzp / PP;
// unpack aggregate reduction
// k = (ci, r, s)
int rs [TK] = rk % SS;
int rcir[TK] = rk / SS;
int rr [TK] = rcir % RR;
int rci [TK] = rcir / RR;
// padding / striding
int rh_0[TM] = rp * stride_h - pad_h;
int rw_0[TM] = rq * stride_w - pad_w;
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
// pointers to lhs
int offa[TM, TK] = rz [:, newaxis] * lda_z +
rci[newaxis, :] * lda_ci +
rh * lda_h +
rw * 1;
TYPE* pa[TM, TK] = A + offa;
// pointers to rhs
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
rr [:, newaxis] * ldb_r +
rs [:, newaxis] * ldb_s +
rn [newaxis, :] * 1;
TYPE* pb[TK, TN] = B + offb;
// pointers to delta
int* padelta[TK] = ADELTA + rk;
int adelta[TK] = *padelta;
// prefetches operands
bool checkam[TM, TK] = rm[:, newaxis] < M;
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
bool checkbn[TK, TN] = rn[newaxis, :] < N;
bool checkbk[TK, TN] = rk[:, newaxis] < K;
bool checkb[TK, TN] = checkbn && checkbk;
TYPE a[TM, TK] = checka ? *pa : 0;
TYPE b[TK, TN] = checkb ? *pb : 0;
// reduction loop
float acc[TM, TN] = 0;
for(int k = K; k > 0; k -= TK){
acc += a @ b;
// increment A
pa += adelta[newaxis, :];
// bounds-checking A
rk += TK;
rs = rk % SS;
rcir = rk / SS;
rr = rcir % RR;
rh = rh_0[:, newaxis] + rr[newaxis, :];
rw = rw_0[:, newaxis] + rs[newaxis, :];
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
// increment B
pb += TK * ldb_s;
// bounds-checking B
bool checkbk = k > TK;
bool checkb[TK, TN] = checkbn && checkbk;
a = checka ? *pa : 0;
b = *?(checkb)pb;
padelta += TK;
adelta = *padelta;
}
acc = acc * alpha;
TYPE c[TM, TN] = acc;
// epilogue
rm = ridx * TM + 0 ... TM;
rn = ridy * TN + 0 ... TN;
rq = rm % QQ;
rzp = rm / QQ;
rp = rzp % PP;
rz = rzp / PP;
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
rn [newaxis, :] * ldc_co +
rp [:, newaxis] * ldc_p +
rm [:, newaxis] * 1;
TYPE* pc[TM, TN] = C + offc;
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
#if (TZ==1)
*?(checkc) pc = c;
#else
// accumulate partial result using spin-locks
int *plock = locks + rid;
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *?(checkc)pc;
atomic_xchg(pcount, (count + 1) % TZ);
atomic_xchg(plock, 0);
#endif
}
)";
}