more performance optimizations

This commit is contained in:
Philippe Tillet
2019-06-28 17:04:07 -07:00
parent a567f3f8a8
commit ab1afbf082
3 changed files with 37 additions and 30 deletions

View File

@@ -14,9 +14,9 @@ int main() {
triton::jit jit(context); triton::jit jit(context);
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t BS = 4, F = 128; int32_t BS = 4, F = 512;
int32_t H = 32, W = 32; int32_t H = 32, W = 32;
int32_t C = 128; int32_t C = 512;
// random shifts // random shifts
std::vector<int32_t> shift_h(C); std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C); std::vector<int32_t> shift_w(C);
@@ -68,12 +68,12 @@ int main() {
// shift // shift
std::vector<unsigned> params = { std::vector<unsigned> params = {
4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8 32, 2, 128, 16, 2, 128, 16, 8, 2, 2, 4, 2, 8, 8
}; };
std::ostringstream oss; std::ostringstream oss;
shift.src(oss); shift.src(oss);
std::string src = oss.str(); std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark); jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params); jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift"); triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift"); triton::jit::launch_information info = jit.get_launch_info("shift");

View File

@@ -106,6 +106,7 @@ public:
} }
private: private:
int32_t MAX_C_;
// image size // image size
int32_t NB_; int32_t NB_;
int32_t NC_; int32_t NC_;

View File

@@ -28,6 +28,8 @@ shift::shift(int B, int NC,
shift_h_(shift_h), shift_w_(shift_w), shift_h_(shift_h), shift_w_(shift_w),
a_ty_(a_ty), b_ty_(b_ty), a_ty_(a_ty), b_ty_(b_ty),
ty_(ty), bias_(bias) { ty_(ty), bias_(bias) {
// max number of channels
MAX_C_ = 1024;
// equivalent matmul // equivalent matmul
M_ = NB_*AH_*AW_; M_ = NB_*AH_*AW_;
N_ = NF_; N_ = NF_;
@@ -52,16 +54,12 @@ void shift::build_deltas() {
}; };
// allocate look-up table // allocate look-up table
size_t TK = 8; size_t TK = 8;
h_deltas_ = std::vector<int32_t>(512, 0); h_deltas_.resize(MAX_C_);
// populate look-up table // populate look-up table
for(unsigned c = 0; c < TK; c++){ for(unsigned c = 0; c < TK; c++)
h_deltas_[c] = offset(c); // init (shift) h_deltas_[c] = offset(c);
h_deltas_[c + 256] = c*ld_a_[0]; // init (no shift) for(unsigned c = 0; c < NC_; c++)
} h_deltas_[TK + c] = offset(c + TK) - offset(c);
for(unsigned c = 0; c < NC_; c++){
h_deltas_[TK + c] = offset(c + TK) - offset(c); // deltas (shift)
h_deltas_[TK + c + 256] = TK*ld_a_[0]; // deltas (shift)
}
} }
size_t shift::a_size(){ size_t shift::a_size(){
@@ -102,11 +100,12 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_); kernel->setArg(3, M_);
kernel->setArg(4, N_); kernel->setArg(4, N_);
kernel->setArg(5, K_); kernel->setArg(5, K_);
kernel->setArg(6, NB_); kernel->setArg(6, NB_*AH_*AW_);
kernel->setArg(7, AH_); kernel->setArg(7, NB_);
kernel->setArg(8, AW_); kernel->setArg(8, AH_);
kernel->setArg(9, BH_); kernel->setArg(9, AW_);
kernel->setArg(10, BW_); kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
// dry run // dry run
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
stream->enqueue(kernel, grid, {nthreads, 1, 1}); stream->enqueue(kernel, grid, {nthreads, 1, 1});
@@ -119,19 +118,19 @@ const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8}; const tunable int32 TK = {8};
__constant__ int32* delta = alloc_const int32[512]; __constant__ int32* delta = alloc_const int32[)" << MAX_C_ << R"(];
void shift(restrict read_only align(16) fp32 *a, void shift(restrict read_only align(16) )" << a_ty_ << R"( *a,
restrict read_only align(16) fp32 *b, restrict read_only align(16) )" << b_ty_ << R"( *b,
fp32 *c, fp32 *c,
int32 M, int32 N, int32 K, int32 M, int32 N, int32 K,
int32 lda,
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) { int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS) {
int32 rxa[TM] = get_global_range[TM](0); int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1); int32 ryb[TN] = get_global_range[TN](1);
int32 rka[TK] = 0 ... TK; int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0; fp32 C[TM, TN] = 0;
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
int32 pad_h = AR/2; int32 pad_h = AR/2;
int32 pad_w = AS/2; int32 pad_w = AS/2;
int32 rawhc[TM] = rxa / ABS; int32 rawhc[TM] = rxa / ABS;
@@ -140,17 +139,24 @@ void shift(restrict read_only align(16) fp32 *a,
int32 rah[TM] = rahc % AH; int32 rah[TM] = rahc % AH;
int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));
int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));
int1 mask[TM] = maskh && maskw; int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis];
int32 offd[TM] = mask ? 0 : 256; __constant__ int32* pd[TK] = delta + rka;
__constant__ int32* pd[TM, TK] = delta + rka[newaxis, :] + offd[:, newaxis]; int32 d[TK] = *pd;
fp32* pa[TM, TK] = a + rxa[:, newaxis] + (*pd); int32 offa1[TK] = rka*lda;
for(int32 k = K; k > 0; k = k - TK){ int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :];
fp32 a[TM, TK] = *pa; )" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc;
fp32 b[TN, TK] = *pb; )" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
)" << a_ty_ << R"( a[TM, TK] = *pa;
)" << b_ty_ << R"( b[TN, TK] = *pb;
for(int32 k = K; k > TK; k = k - TK){
C = dot(a, trans(b), C); C = dot(a, trans(b), C);
pb = pb + TK*N; pb = pb + TK*N;
pd = pd + TK; pd = pd + TK;
pa = pa + (*pd); d = *pd;
inc = mask ? d[newaxis, :] : TK*lda;
pa = pa + inc;
a = *pa;
b = *pb;
} }
int32 rxc[TM] = get_global_range[TM](0); int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1); int32 ryc[TN] = get_global_range[TN](1);