[EXAMPLES] Added C++ example for Conv2d
This commit is contained in:
committed by
Philippe Tillet
parent
ba9955ae39
commit
f4f216b88a
39
tests/bench/conv.cc
Normal file
39
tests/bench/conv.cc
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#include "triton/driver/backend.h"
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "conv.h"
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
// initialize default compute device
|
||||||
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
|
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||||
|
// shapes to benchmark
|
||||||
|
typedef std::tuple<int, int, int, int, int, int, int, int, int, int, int> config_t;
|
||||||
|
std::vector<config_t> configs = {
|
||||||
|
// {1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 128, 128, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 256, 256, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 384, 384, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 512, 512, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 768, 768, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 56, 56, 1024, 1024, 3, 3, 1, 1, 1, 1},
|
||||||
|
|
||||||
|
// {1, 8, 8, 256, 256, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 16, 16, 256, 256, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 32, 32, 256, 256, 3, 3, 1, 1, 1, 1},
|
||||||
|
// {1, 64, 64, 256, 256, 3, 3, 1, 1, 1, 1},
|
||||||
|
{1, 64, 64, 4096, 4096, 1, 1, 0, 0, 1, 1},
|
||||||
|
// {1, 256, 256, 256, 256, 3, 3, 1, 1, 1, 1}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
int Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w;
|
||||||
|
for(const auto& c: configs){
|
||||||
|
std::tie(Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w) = c;
|
||||||
|
std::cout << "// " << c ;
|
||||||
|
for(auto perf: bench_conv(stream, HALF, Z, H, W, CO, CI, R, S, pad_h, pad_w, stride_h, stride_w))
|
||||||
|
std::cout << ", " << perf << std::flush;
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
124
tests/common/conv.h
Normal file
124
tests/common/conv.h
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
#include <iomanip>
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
#include <cstdio>
|
||||||
|
#include "triton/driver/backend.h"
|
||||||
|
#include "triton/driver/stream.h"
|
||||||
|
#include "triton/tools/bench.hpp"
|
||||||
|
#include "triton/external/half.hpp"
|
||||||
|
#include "triton/runtime/function.h"
|
||||||
|
#include "src/conv.h"
|
||||||
|
#include "cuda/cublas.h"
|
||||||
|
#include "util.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
enum run_mode_t {
|
||||||
|
BENCH,
|
||||||
|
TEST
|
||||||
|
};
|
||||||
|
|
||||||
|
enum dtype_t {
|
||||||
|
FLOAT,
|
||||||
|
HALF,
|
||||||
|
DOUBLE
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
struct to_string;
|
||||||
|
|
||||||
|
template<> struct to_string<half_float::half>{
|
||||||
|
static constexpr const char* value = "half";
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct to_string<float>{
|
||||||
|
static constexpr const char* value = "float";
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct to_string<double>{
|
||||||
|
static constexpr const char* value = "double";
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
void triton_conv(drv::stream* stream,
|
||||||
|
int Z, int CI, int H, int W, int CO, int R, int S,
|
||||||
|
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||||
|
run_mode_t mode, std::vector<double>& bench, bool &test){
|
||||||
|
std::string ty = to_string<T>::value;
|
||||||
|
size_t dt_nbytes = sizeof(T);
|
||||||
|
drv::context* context = stream->context();
|
||||||
|
|
||||||
|
int P = (H + 2*pad_h - R)/stride_h + 1;
|
||||||
|
int Q = (W + 2*pad_w - S)/stride_w + 1;
|
||||||
|
|
||||||
|
// inputs
|
||||||
|
auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CO*P*Q*dt_nbytes));
|
||||||
|
auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, Z*CI*H*W*dt_nbytes));
|
||||||
|
auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*CO*dt_nbytes));
|
||||||
|
auto ddelta = std::shared_ptr<drv::buffer>(drv::buffer::create(context, CI*R*S*4));
|
||||||
|
auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
|
||||||
|
((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
|
||||||
|
|
||||||
|
std::vector<int32_t> hdelta(CI*R*S);
|
||||||
|
int TK = 16;
|
||||||
|
for(int i = 0; i < hdelta.size(); i++){
|
||||||
|
int s = i % S;
|
||||||
|
int cr = i / S;
|
||||||
|
int r = cr % R;
|
||||||
|
int c = cr / R;
|
||||||
|
int nexti = i + TK;
|
||||||
|
int nexts = nexti % S;
|
||||||
|
int nextcr = nexti / S;
|
||||||
|
int nextr = nextcr % R;
|
||||||
|
int nextc = nextcr / R;
|
||||||
|
hdelta[i] = (nextc - c)*W*H + (nextr - r)*W + (nexts - s);
|
||||||
|
}
|
||||||
|
stream->write(&*ddelta, true, 0, hdelta);
|
||||||
|
|
||||||
|
// macros
|
||||||
|
rt::function::options_space_t opt;
|
||||||
|
opt.defines.push_back({"TYPE", {ty}});
|
||||||
|
opt.defines.push_back({"TM", {"128"}});
|
||||||
|
opt.defines.push_back({"TN", {"128"}});
|
||||||
|
opt.defines.push_back({"TK", {std::to_string(TK)}});
|
||||||
|
opt.defines.push_back({"TZ", {"1"}});
|
||||||
|
opt.defines.push_back({"RR", {std::to_string(R)}});
|
||||||
|
opt.defines.push_back({"SS", {std::to_string(S)}});
|
||||||
|
opt.defines.push_back({"PP", {std::to_string(P)}});
|
||||||
|
opt.defines.push_back({"QQ", {std::to_string(Q)}});
|
||||||
|
opt.defines.push_back({"HH", {std::to_string(H)}});
|
||||||
|
opt.defines.push_back({"WW", {std::to_string(W)}});
|
||||||
|
|
||||||
|
opt.num_warps = {2, 4};
|
||||||
|
|
||||||
|
// kernels
|
||||||
|
rt::function function(src::conv, opt);
|
||||||
|
std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, Z*P*Q, CO, CI*R*S,
|
||||||
|
pad_h, pad_w, stride_h, stride_w,
|
||||||
|
&*ddelta,
|
||||||
|
W*H*CI, W*H, W, 1,
|
||||||
|
CO*S*R , CO*S, CO, 1,
|
||||||
|
Q*P*CO, Q*P, Q, 1};
|
||||||
|
auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
|
||||||
|
return rt::grid_t{ceil(Z*P*Q, x.D<int>("TM")),
|
||||||
|
ceil(CO , x.D<int>("TN")),
|
||||||
|
(size_t)x.D<int>("TZ")};
|
||||||
|
};
|
||||||
|
auto tflops = [&](double nanosec) { return 2.*Z*P*Q*CI*CO*R*S / nanosec * 1e-3; };
|
||||||
|
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
|
||||||
|
bench.push_back(tflops(triton_ns));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> bench_conv(drv::stream* stream, dtype_t dtype,
|
||||||
|
int32_t Z, int32_t H, int32_t W, int32_t CO, int32_t CI, int32_t R, int32_t S,
|
||||||
|
int32_t pad_h, int32_t pad_w, int32_t stride_h, int32_t stride_w) {
|
||||||
|
std::vector<double> bench;
|
||||||
|
bool test;
|
||||||
|
switch(dtype){
|
||||||
|
case HALF: triton_conv<half_float::half>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
|
||||||
|
case FLOAT: triton_conv<float>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
|
||||||
|
case DOUBLE: triton_conv<double>(stream, Z, CI, H, W, CO, R, S, pad_h, pad_w, stride_h, stride_w, BENCH, bench, test); break;
|
||||||
|
default: break;
|
||||||
|
}
|
||||||
|
return bench;
|
||||||
|
}
|
132
tests/common/src/conv.h
Normal file
132
tests/common/src/conv.h
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
namespace src {
|
||||||
|
|
||||||
|
const char *conv =
|
||||||
|
R"(
|
||||||
|
__global__ void conv(TYPE *A __noalias __readonly __aligned(16),
|
||||||
|
TYPE *B __noalias __readonly __aligned(16),
|
||||||
|
TYPE *C __noalias __aligned(16),
|
||||||
|
float alpha,
|
||||||
|
// equivalent matmul
|
||||||
|
int M, int N, int K,
|
||||||
|
// convolution properties
|
||||||
|
int pad_h, int pad_w, int stride_h, int stride_w,
|
||||||
|
// pointer increment
|
||||||
|
int *ADELTA,
|
||||||
|
// memory strides
|
||||||
|
int lda_z __multipleof(8), int lda_ci __multipleof(8), int lda_h __multipleof(8), int lda_w __multipleof(8),
|
||||||
|
int ldb_ci __multipleof(8), int ldb_r __multipleof(8), int ldb_s __multipleof(8), int ldb_co __multipleof(8),
|
||||||
|
int ldc_z __multipleof(8), int ldc_co __multipleof(8), int ldc_p __multipleof(8), int ldc_q __multipleof(8)) {
|
||||||
|
// prologue
|
||||||
|
int ridx = get_program_id(0);
|
||||||
|
int ridy = get_program_id(1);
|
||||||
|
int ridz = get_program_id(2);
|
||||||
|
int rm[TM] = ridx * TM + 0 ... TM;
|
||||||
|
int rn[TN] = ridy * TN + 0 ... TN;
|
||||||
|
// reduction splitting
|
||||||
|
K = K / TZ;
|
||||||
|
int rk[TK] = ridz * K + 0 ... TK;
|
||||||
|
|
||||||
|
// unpack aggregate rows
|
||||||
|
// m = (z, p, q)
|
||||||
|
int rq[TM] = rm % QQ;
|
||||||
|
int rzp[TM] = rm / QQ;
|
||||||
|
int rp[TM] = rzp % PP;
|
||||||
|
int rz[TM] = rzp / PP;
|
||||||
|
// unpack aggregate reduction
|
||||||
|
// k = (ci, r, s)
|
||||||
|
int rs [TK] = rk % SS;
|
||||||
|
int rcir[TK] = rk / SS;
|
||||||
|
int rr [TK] = rcir % RR;
|
||||||
|
int rci [TK] = rcir / RR;
|
||||||
|
|
||||||
|
// padding / striding
|
||||||
|
int rh_0[TM] = rp * stride_h - pad_h;
|
||||||
|
int rw_0[TM] = rq * stride_w - pad_w;
|
||||||
|
int rh[TM, TK] = rh_0[:, newaxis] + rr[newaxis, :];
|
||||||
|
int rw[TM, TK] = rw_0[:, newaxis] + rs[newaxis, :];
|
||||||
|
|
||||||
|
// pointers to lhs
|
||||||
|
int offa[TM, TK] = rz [:, newaxis] * lda_z +
|
||||||
|
rci[newaxis, :] * lda_ci +
|
||||||
|
rh * lda_h +
|
||||||
|
rw * 1;
|
||||||
|
TYPE* pa[TM, TK] = A + offa;
|
||||||
|
// pointers to rhs
|
||||||
|
int offb[TK, TN] = rci[:, newaxis] * ldb_ci +
|
||||||
|
rr [:, newaxis] * ldb_r +
|
||||||
|
rs [:, newaxis] * ldb_s +
|
||||||
|
rn [newaxis, :] * 1;
|
||||||
|
TYPE* pb[TK, TN] = B + offb;
|
||||||
|
// pointers to delta
|
||||||
|
int* padelta[TK] = ADELTA + rk;
|
||||||
|
int adelta[TK] = *padelta;
|
||||||
|
|
||||||
|
// prefetches operands
|
||||||
|
bool checkam[TM, TK] = rm[:, newaxis] < M;
|
||||||
|
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||||
|
bool checkbn[TK, TN] = rn[newaxis, :] < N;
|
||||||
|
bool checkbk[TK, TN] = rk[:, newaxis] < K;
|
||||||
|
bool checkb[TK, TN] = checkbn && checkbk;
|
||||||
|
TYPE a[TM, TK] = checka ? *pa : 0;
|
||||||
|
TYPE b[TK, TN] = checkb ? *pb : 0;
|
||||||
|
|
||||||
|
// reduction loop
|
||||||
|
float acc[TM, TN] = 0;
|
||||||
|
for(int k = K; k > 0; k -= TK){
|
||||||
|
acc += a @ b;
|
||||||
|
// increment A
|
||||||
|
pa += adelta[newaxis, :];
|
||||||
|
// bounds-checking A
|
||||||
|
rk += TK;
|
||||||
|
rs = rk % SS;
|
||||||
|
rcir = rk / SS;
|
||||||
|
rr = rcir % RR;
|
||||||
|
rh = rh_0[:, newaxis] + rr[newaxis, :];
|
||||||
|
rw = rw_0[:, newaxis] + rs[newaxis, :];
|
||||||
|
bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
|
||||||
|
// increment B
|
||||||
|
pb += TK * ldb_s;
|
||||||
|
// bounds-checking B
|
||||||
|
bool checkbk = k > TK;
|
||||||
|
bool checkb[TK, TN] = checkbn && checkbk;
|
||||||
|
a = checka ? *pa : 0;
|
||||||
|
b = *?(checkb)pb;
|
||||||
|
padelta += TK;
|
||||||
|
adelta = *padelta;
|
||||||
|
}
|
||||||
|
acc = acc * alpha;
|
||||||
|
TYPE c[TM, TN] = acc;
|
||||||
|
|
||||||
|
// epilogue
|
||||||
|
rm = ridx * TM + 0 ... TM;
|
||||||
|
rn = ridy * TN + 0 ... TN;
|
||||||
|
rq = rm % QQ;
|
||||||
|
rzp = rm / QQ;
|
||||||
|
rp = rzp % PP;
|
||||||
|
rz = rzp / PP;
|
||||||
|
int offc[TM, TN] = rz [:, newaxis] * ldc_z +
|
||||||
|
rn [newaxis, :] * ldc_co +
|
||||||
|
rp [:, newaxis] * ldc_p +
|
||||||
|
rm [:, newaxis] * 1;
|
||||||
|
TYPE* pc[TM, TN] = C + offc;
|
||||||
|
bool checkc[TM, TN] = rm[:, newaxis] < M && rn[newaxis, :] < N;
|
||||||
|
|
||||||
|
#if (TZ==1)
|
||||||
|
*?(checkc) pc = c;
|
||||||
|
#else
|
||||||
|
// accumulate partial result using spin-locks
|
||||||
|
int *plock = locks + rid;
|
||||||
|
int *pcount = plock + get_num_programs(0) * get_num_programs(1);
|
||||||
|
for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
|
||||||
|
int count = *pcount;
|
||||||
|
if(count == 0)
|
||||||
|
*?(checkc) pc = c;
|
||||||
|
else
|
||||||
|
*?(checkc) pc = c + *?(checkc)pc;
|
||||||
|
atomic_xchg(pcount, (count + 1) % TZ);
|
||||||
|
atomic_xchg(plock, 0);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user