[dnn] added shift in the DNN libs
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/dnn/shift.h"
|
||||
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
@@ -36,96 +38,6 @@ void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
|
||||
}
|
||||
}
|
||||
|
||||
// K = channels
|
||||
// M = batch * height * width
|
||||
// N = number of feature maps
|
||||
|
||||
const char* src =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[256];
|
||||
__constant__ int32* masks = alloc_const int32[8192];
|
||||
|
||||
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
int32 pad_h = AR/2;
|
||||
int32 pad_w = AS/2;
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 raw[TM] = rawhc % AW - pad_w;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH - pad_h;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
|
||||
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
|
||||
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 delta[TK] = *pd;
|
||||
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
|
||||
int1 m[TM, TK] = *pm > 0;
|
||||
fp32 a[TM, TK] = m ? *pa : 0;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
pm = pm + TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
}
|
||||
)";
|
||||
|
||||
std::vector<int32_t> shift_deltas(// strides
|
||||
int32_t stride_w, int32_t stride_h, int32_t stride_c,
|
||||
// shift
|
||||
int32_t C,
|
||||
const std::vector<int32_t>& shift_h,
|
||||
const std::vector<int32_t>& shift_w) {
|
||||
std::vector<int32_t> res(C);
|
||||
for(unsigned c = 0; c < C; c++){
|
||||
res[c] = c*stride_c;
|
||||
res[c] += shift_h[c]*stride_h;
|
||||
res[c] += shift_w[c]*stride_w;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<int32_t> shift_masks(int32_t C,
|
||||
const std::vector<int32_t>& shift_h,
|
||||
const std::vector<int32_t>& shift_w,
|
||||
int32_t R, int32_t S) {
|
||||
size_t S0 = C;
|
||||
size_t S1 = R;
|
||||
size_t S2 = S;
|
||||
std::vector<int32_t> res(S0*S1*S2);
|
||||
for(size_t ph = 0; ph < S1; ++ph)
|
||||
for(size_t pw = 0; pw < S2; ++pw){
|
||||
int32_t* ptr = &res[ph*S0 + pw*S0*S1];
|
||||
for(size_t i = 0; i < S0; ++i){
|
||||
bool in_bounds_h = shift_h[i] + ph >= 0 && shift_h[i] + ph < R;
|
||||
bool in_bounds_w = shift_w[i] + pw >= 0 && shift_w[i] + pw < S;
|
||||
ptr[i] = in_bounds_h && in_bounds_w;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
@@ -136,20 +48,6 @@ int main() {
|
||||
int32_t BS = 4, F = 128;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t C = 128;
|
||||
// equivalent matmul dimensions
|
||||
int32_t M = BS*H*W;
|
||||
int32_t N = F;
|
||||
int32_t K = C;
|
||||
std::cout << M << " " << N << " " << K << std::endl;
|
||||
std::vector<float> hc(BS*H*W*F);
|
||||
std::vector<float> rc(BS*H*W*F);
|
||||
std::vector<float> ha(BS*C*H*W);
|
||||
std::vector<float> hb(F*C);
|
||||
// strides
|
||||
int32_t stride_i_bs = 1;
|
||||
int32_t stride_i_w = BS*stride_i_bs;
|
||||
int32_t stride_i_h = W*stride_i_w;
|
||||
int32_t stride_i_c = H*stride_i_h;
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
std::vector<int32_t> shift_w(C);
|
||||
@@ -157,83 +55,63 @@ int main() {
|
||||
shift_h[c] = rand() % R - R/2;
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// initialize buffers
|
||||
srand(0);
|
||||
for(int c = 0 ; c < C; c++)
|
||||
for(int h = 0 ; h < H; h++)
|
||||
for(int w = 0 ; w < W; w++)
|
||||
for(int bs = 0 ; bs < BS; bs++){
|
||||
float value = (float)rand() / RAND_MAX;
|
||||
size_t idx = bs + w*stride_i_w + h*stride_i_h + c*stride_i_c;
|
||||
ha[idx] = value;
|
||||
}
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = (float)rand() / RAND_MAX;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = 0;
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
std::vector<float> ha(shift.a_size());
|
||||
std::vector<float> hb(shift.b_size());
|
||||
// device buffers
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
|
||||
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
|
||||
triton::driver::stream* stream = triton::driver::stream::create(context);
|
||||
// initialize host
|
||||
srand(0);
|
||||
for(size_t i = 0; i < ha.size(); i++)
|
||||
ha[i] = (float)rand() / RAND_MAX;
|
||||
for(size_t i = 0; i < hb.size(); i++)
|
||||
hb[i] = (float)rand() / RAND_MAX;
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
hc[i] = 0;
|
||||
// initialize device
|
||||
stream->write(da, true, 0, ha);
|
||||
stream->write(db, true, 0, hb);
|
||||
stream->write(dc, true, 0, hc);
|
||||
stream->synchronize();
|
||||
std::vector<int32_t> h_delta = shift_deltas(stride_i_w, stride_i_h, stride_i_c, C, shift_h, shift_w);
|
||||
std::vector<int32_t> h_masks = shift_masks(C, shift_h, shift_w, R, S);
|
||||
// benchmark a given matrix multiplication kernel
|
||||
// benchmark
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
triton::jit::launch_information info) {
|
||||
shift.init(stream, (triton::driver::cu_module*)kernel->module());
|
||||
// launch info
|
||||
unsigned TM = info.global_range_size[0];
|
||||
unsigned TN = info.global_range_size[1];
|
||||
unsigned nthreads = info.num_threads;
|
||||
// initialize constant memory
|
||||
triton::driver::buffer* delta = ((triton::driver::cu_module*)kernel->module())->symbol("delta");
|
||||
triton::driver::buffer* masks = ((triton::driver::cu_module*)kernel->module())->symbol("masks");
|
||||
stream->write(delta, false, 0, h_delta.size()*4, h_delta.data());
|
||||
stream->write(masks, false, 0, h_masks.size()*4, h_masks.data());
|
||||
stream->synchronize();
|
||||
// set argument
|
||||
kernel->setArg(0, da);
|
||||
kernel->setArg(1, db);
|
||||
kernel->setArg(2, dc);
|
||||
kernel->setArg(3, M);
|
||||
kernel->setArg(4, N);
|
||||
kernel->setArg(5, K);
|
||||
kernel->setArg(6, BS);
|
||||
kernel->setArg(7, H);
|
||||
kernel->setArg(8, W);
|
||||
kernel->setArg(9, R);
|
||||
kernel->setArg(10, S);
|
||||
// dry run
|
||||
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1};
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);
|
||||
stream->synchronize();
|
||||
// benchmark
|
||||
double ts = triton::tools::bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});},
|
||||
double ts = triton::tools::bench([&](){shift.enqueue(stream, kernel, da, db, dc, TM, TN, nthreads);},
|
||||
[&](){ stream->synchronize(); }, context->device());
|
||||
ts = ts * 1e-9;
|
||||
double tflops = 2.*M*N*K / ts * 1e-12;
|
||||
return tflops;
|
||||
return shift.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
16, 2, 64,
|
||||
32, 2, 64,
|
||||
16, 8, 2, 2,
|
||||
8, 8,
|
||||
4
|
||||
8, 2, 16, 8, 2, 32, 8, 4, 2, 2, 4, 2, 8, 4
|
||||
};
|
||||
jit.autotune("shift", src, benchmark);
|
||||
jit.add_module("shift", src, params);
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
stream->read(dc, true, 0, hc);
|
||||
shift_conv(C, H, W, BS, F, rc, ha, hb, shift_h, shift_w);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
shift.cpu_ref(rc.data(), ha.data(), hb.data());
|
||||
for(size_t i = 0; i < hc.size(); i++)
|
||||
if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
|
151
include/triton/dnn/shift.h
Normal file
151
include/triton/dnn/shift.h
Normal file
@@ -0,0 +1,151 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#ifndef TDL_INCLUDE_DNN_SHIFT_H
|
||||
#define TDL_INCLUDE_DNN_SHIFT_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/kernel.h"
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
class shift {
|
||||
|
||||
public:
|
||||
enum type {
|
||||
FPROP
|
||||
};
|
||||
|
||||
private:
|
||||
void set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld);
|
||||
|
||||
public:
|
||||
|
||||
shift(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S, int NF,
|
||||
const std::vector<int32_t> &shift_h, const std::vector<int32_t> &shift_w,
|
||||
std::string a_ty = "fp32", std::string b_ty = "fp32",
|
||||
type ty = FPROP, bool bias = false);
|
||||
|
||||
// look-up table
|
||||
void build_deltas();
|
||||
void build_masks();
|
||||
|
||||
// accessors
|
||||
size_t a_size();
|
||||
size_t b_size();
|
||||
size_t c_size();
|
||||
std::vector<int32_t> c_shapes();
|
||||
|
||||
// device function
|
||||
void init(driver::stream *stream, driver::cu_module *module);
|
||||
void enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads);
|
||||
|
||||
// utils
|
||||
size_t get_nflops();
|
||||
|
||||
// source
|
||||
void src(std::ostream &os);
|
||||
|
||||
// cpu_ref
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void cpu_ref(OUT_DTYPE* O,
|
||||
const IN_DTYPE* I,
|
||||
const IN_DTYPE* F)
|
||||
{
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < AH_; ++p)
|
||||
for(int32_t q = 0; q < AW_; ++q)
|
||||
for(int32_t bs = 0; bs < NB_; ++bs)
|
||||
for(int32_t k = 0; k < NF_; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < NC_; ++c){
|
||||
int32_t h = p + shift_h_[c];
|
||||
int32_t w = q + shift_w_[c];
|
||||
bool in_bounds = (h >= 0 && w >= 0 && h < AH_ && w < AW_);
|
||||
IN_DTYPE a = in_bounds?I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]:0;
|
||||
IN_DTYPE b = F[k + c*NF_];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// image size
|
||||
int32_t NB_;
|
||||
int32_t NC_;
|
||||
int32_t AD_;
|
||||
int32_t AH_;
|
||||
int32_t AW_;
|
||||
// filter size
|
||||
int32_t BD_;
|
||||
int32_t BH_;
|
||||
int32_t BW_;
|
||||
int32_t NF_;
|
||||
// activation size
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
int32_t CW_;
|
||||
// equivalent matmul
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
int32_t K_;
|
||||
// shapes
|
||||
std::vector<int32_t> shapes_a_;
|
||||
std::vector<int32_t> shapes_b_;
|
||||
std::vector<int32_t> shapes_c_;
|
||||
// memory strides
|
||||
std::vector<int32_t> ld_a_;
|
||||
std::vector<int32_t> ld_b_;
|
||||
std::vector<int32_t> ld_c_;
|
||||
// shift values
|
||||
std::vector<int32_t> shift_h_;
|
||||
std::vector<int32_t> shift_w_;
|
||||
// look-up tables
|
||||
std::vector<int32_t> h_deltas_;
|
||||
std::vector<int32_t> h_masks_;
|
||||
driver::buffer* d_deltas_;
|
||||
driver::buffer* d_masks_;
|
||||
// data types
|
||||
std::string a_ty_;
|
||||
std::string b_ty_;
|
||||
// convolution type
|
||||
type ty_;
|
||||
bool bias_;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
176
lib/dnn/shift.cpp
Normal file
176
lib/dnn/shift.cpp
Normal file
@@ -0,0 +1,176 @@
|
||||
#include "triton/dnn/shift.h"
|
||||
|
||||
|
||||
namespace triton{
|
||||
namespace dnn{
|
||||
|
||||
void shift::set_ld(const std::vector<int32_t>& shapes,
|
||||
std::vector<int32_t>& ld) {
|
||||
size_t size = shapes.size();
|
||||
ld.resize(size);
|
||||
ld[4] = 1;
|
||||
ld[3] = shapes[4]*ld[4];
|
||||
ld[2] = shapes[3]*ld[3];
|
||||
ld[1] = shapes[2]*ld[2];
|
||||
ld[0] = shapes[1]*ld[1];
|
||||
}
|
||||
|
||||
shift::shift(int B, int NC,
|
||||
int D, int H, int W,
|
||||
int T, int R, int S,
|
||||
int NF,
|
||||
const std::vector<int32_t>& shift_h, const std::vector<int32_t>& shift_w,
|
||||
std::string a_ty, std::string b_ty,
|
||||
type ty, bool bias)
|
||||
: NB_(B), NC_(NC),
|
||||
AD_(D), AH_(H), AW_(W),
|
||||
BD_(T), BH_(R), BW_(S),
|
||||
NF_(NF),
|
||||
shift_h_(shift_h), shift_w_(shift_w),
|
||||
a_ty_(a_ty), b_ty_(b_ty),
|
||||
ty_(ty), bias_(bias) {
|
||||
// equivalent matmul
|
||||
M_ = NB_*AH_*AW_;
|
||||
N_ = NF_;
|
||||
K_ = NC_;
|
||||
// shapes
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
shapes_a_ = {NC, H, W, B};
|
||||
shapes_b_ = {NC, NF};
|
||||
shapes_c_ = {NF, H, W, B};
|
||||
// memory strides
|
||||
set_ld(shapes_a_, ld_a_);
|
||||
}
|
||||
|
||||
void shift::build_deltas() {
|
||||
h_deltas_.resize(NC_);
|
||||
for(unsigned c = 0; c < NC_; c++){
|
||||
h_deltas_[c] = c*ld_a_[0];
|
||||
h_deltas_[c] += shift_h_[c]*ld_a_[1];
|
||||
h_deltas_[c] += shift_w_[c]*ld_a_[2];
|
||||
}
|
||||
}
|
||||
|
||||
void shift::build_masks() {
|
||||
size_t S0 = NC_;
|
||||
size_t S1 = BH_;
|
||||
size_t S2 = BW_;
|
||||
h_masks_.resize(S0*S1*S2);
|
||||
for(size_t ph = 0; ph < S1; ++ph)
|
||||
for(size_t pw = 0; pw < S2; ++pw){
|
||||
int32_t* ptr = &h_masks_[ph*S0 + pw*S0*S1];
|
||||
for(size_t i = 0; i < S0; ++i){
|
||||
bool in_bounds_h = shift_h_[i] + ph >= 0 && shift_h_[i] + ph < BH_;
|
||||
bool in_bounds_w = shift_w_[i] + pw >= 0 && shift_w_[i] + pw < BW_;
|
||||
ptr[i] = in_bounds_h && in_bounds_w;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_t shift::a_size(){
|
||||
return std::accumulate(shapes_a_.begin(), shapes_a_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
size_t shift::b_size(){
|
||||
return std::accumulate(shapes_b_.begin(), shapes_b_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
size_t shift::c_size(){
|
||||
return std::accumulate(shapes_c_.begin(), shapes_c_.end(),
|
||||
1, std::multiplies<int>());
|
||||
}
|
||||
|
||||
std::vector<int32_t> shift::c_shapes(){
|
||||
return shapes_c_;
|
||||
}
|
||||
|
||||
size_t shift::get_nflops() {
|
||||
return 2 * M_ * N_ * K_;
|
||||
}
|
||||
|
||||
|
||||
void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
triton::driver::buffer* delta = ((triton::driver::cu_module*)module)->symbol("delta");
|
||||
triton::driver::buffer* masks = ((triton::driver::cu_module*)module)->symbol("masks");
|
||||
stream->write(delta, false, 0, h_deltas_.size()*4, h_deltas_.data());
|
||||
stream->write(masks, false, 0, h_masks_.size()*4, h_masks_.data());
|
||||
}
|
||||
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, NB_);
|
||||
kernel->setArg(7, AH_);
|
||||
kernel->setArg(8, AW_);
|
||||
kernel->setArg(9, BH_);
|
||||
kernel->setArg(10, BW_);
|
||||
// dry run
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
stream->enqueue(kernel, grid, {nthreads, 1, 1});
|
||||
}
|
||||
|
||||
void shift::src(std::ostream &os) {
|
||||
os <<
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
|
||||
__constant__ int32* delta = alloc_const int32[256];
|
||||
__constant__ int32* masks = alloc_const int32[8192];
|
||||
|
||||
void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 C[TM, TN] = 0;
|
||||
fp32* pxa[TM, TK] = a + rxa[:, newaxis];
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta + rka;
|
||||
int32 pad_h = AR/2;
|
||||
int32 pad_w = AS/2;
|
||||
int32 rawhc[TM] = rxa / ABS;
|
||||
int32 raw[TM] = rawhc % AW - pad_w;
|
||||
int32 rahc[TM] = rawhc / AW;
|
||||
int32 rah[TM] = rahc % AH - pad_h;
|
||||
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0);
|
||||
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0);
|
||||
__constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1);
|
||||
__constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :];
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
int32 delta[TK] = *pd;
|
||||
fp32 *pa[TM, TK] = pxa + delta[newaxis, :];
|
||||
int1 m[TM, TK] = *pm > 0;
|
||||
fp32 a[TM, TK] = m ? *pa : 0;
|
||||
fp32 b[TN, TK] = *pb;
|
||||
C = dot(a, trans(b), C);
|
||||
pb = pb + TK*N;
|
||||
pd = pd + TK;
|
||||
pm = pm + TK;
|
||||
}
|
||||
int32 rxc[TM] = get_global_range[TM](0);
|
||||
int32 ryc[TN] = get_global_range[TN](1);
|
||||
fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis];
|
||||
int1 checkc0[TM] = rxc < M;
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
@checkc *pc = C;
|
||||
}
|
||||
)";
|
||||
}
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user