[dnn/shift] added split-k for shift-conv

This commit is contained in:
Philippe Tillet
2019-07-15 21:03:58 -07:00
parent 434f65737f
commit aa8bcf6bde
9 changed files with 166 additions and 203 deletions

View File

@@ -18,8 +18,8 @@ int main() {
// initialization
int32_t R = 3, S = 3;
int32_t B = 32, F = 128;
int32_t H = 28, W = 28;
int32_t B = 128, F = 128;
int32_t H = 16, W = 16;
int32_t C = 128;
// random shifts
@@ -44,9 +44,9 @@ int main() {
std::swap(b_size, c_size);
std::swap(a_size, b_size);
}
std::vector<NumericT> ha(B*C*H*W);
std::vector<NumericT> hb(C*F);
std::vector<float> hc(B*F*H*W);
std::vector<NumericT> ha(a_size);
std::vector<NumericT> hb(b_size);
std::vector<float> hc(c_size);
std::vector<float> rc(hc.size());
// device buffers
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);

View File

@@ -58,9 +58,9 @@ def blocksparse_matmul_grad(op, dy):
return (dx, dw)
def run_shift():
B, C, H, W = 1, 16, 4, 4
B, C, H, W = 2, 16, 4, 4
R, S, F = 3, 3, 16
stride_h, stride_w = 2, 2
stride_h, stride_w = 1, 1
np.random.seed(2)
a = tf.placeholder(tf.float16, shape=[B, C, H, W])
b = tf.placeholder(tf.float16, shape=[C, F])
@@ -82,8 +82,8 @@ def run_shift():
dx_t, dx_n = grads[0]
#import sys
#np.set_printoptions(threshold=sys.maxsize)
print(dx_t)
print(dx_n)
print(dw_t)
print(dw_n)
print(np.max(np.abs(dw_t - dw_n)))
print(np.max(np.abs(dx_t - dx_n)))
# Run

View File

@@ -43,6 +43,8 @@ protected:
private:
// initialize
virtual void init_impl(driver::stream *, driver::cu_module *){ }
// deinitialize
virtual void deinit_impl(){ }
// enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,

View File

@@ -52,6 +52,7 @@ public:
private:
// initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module);
void deinit_impl();
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
triton::runtime::launch_information info);
@@ -163,6 +164,9 @@ private:
bool BT_;
// layout
layout_t layout_;
// locks
size_t max_locks_;
driver::buffer *locks_;
};
}

View File

@@ -32,6 +32,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
double total_time = 0;
op();
sync();
// while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to get roughly constant result
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
@@ -41,6 +42,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
sync();
times.push_back(norm*tmr.get().count());
total_time+=times.back();
// }
return *std::min_element(times.begin(), times.end());
}

View File

@@ -29,19 +29,20 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
rt::jit* jit;
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get();
base* clone = this->clone();
jit = m_jit.emplace(clone, new rt::jit(ctx)).first->second.get();
std::ostringstream oss;
triton_c_src(oss);
clone->triton_c_src(oss);
std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel,
rt::launch_information info) {
// launch info
unsigned nthreads = info.num_threads;
init_impl(stream, (triton::driver::cu_module*)kernel->module());
enqueue_impl(stream, kernel, args, info);
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->enqueue_impl(stream, kernel, args, info);
stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); },
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
[&](){ stream->synchronize(); }, ctx->device());
clone->deinit_impl();
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
@@ -53,7 +54,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
init_impl(stream, (triton::driver::cu_module*)kernel->module());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
}
/* retrieved compiled template */
else
@@ -63,7 +64,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
driver::kernel* kernel = jit->get_function(name_.c_str());
rt::launch_information info = jit->get_launch_info(name_.c_str());
/* launch */
enqueue_impl(stream, kernel, args, info);
auto it = m_jit.find(this);
it->first->enqueue_impl(stream, kernel, args, info);
}
}

View File

@@ -124,6 +124,9 @@ shift::shift(int B, int C,
if(layout_ == NCHW)
shapes_c_ = {B, C, AH_, AW_};
}
// locks
max_locks_ = (op_ == WGRAD) ? 8192 : 0;
locks_ = nullptr;
}
base* shift::clone() const {
@@ -195,11 +198,30 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
build_delta_a();
triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a");
stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data());
// locks
if(locks_ == nullptr && max_locks_ > 0){
std::vector<int32_t> hlocks(2*max_locks_, 0);
locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*4);
stream->write(locks_, false, 0, hlocks);
}
}
void shift::deinit_impl() {
if(locks_ != nullptr){
delete locks_;
locks_ = nullptr;
}
}
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args,
runtime::launch_information info) {
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned num_locks = grid_0 * grid_1;
unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
driver::buffer *a = args[0], *b = args[1], *c = args[2];
kernel->setArg(0, a);
kernel->setArg(1, b);
@@ -228,8 +250,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(24, BW_);
kernel->setArg(25, CH_);
kernel->setArg(26, CW_);
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
kernel->setArg(28, (int32_t)grid[0]);
kernel->setArg(29, (int32_t)grid[1]);
if(op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
@@ -256,12 +279,49 @@ void shift::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1;
bool is_chwn = layout_ == CHWN;
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
if(is_chwn) {
return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)";
}
else {
return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
}
};
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
std::string result;
if(shift_edge_h_)
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
else
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
if(shift_edge_w_)
result += "int1 interiorw[" + sz0 + "] = 1;";
else
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
result += R"(
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
return result;
};
std::string result =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {)" + std::to_string(TK_) + R"(};
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
if(op_ == WGRAD)
result += "const tunable int32 GZ = {1, 4, 16};";
else
result += "const tunable int32 GZ = {1};";
result += R"(
__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(];
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
@@ -275,32 +335,32 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
int32 NB,
int32 AH, int32 AW,
int32 BH, int32 BW,
int32 CH, int32 CW) {
int32 CH, int32 CW,
int32* locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2;
int32 pad_w = BW / 2;)";
int32 pad_w = BW / 2;
int32 split = select(locks == 0, 1, GZ);
int32 div = K / split;
int32 rem = K % split;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
if(op_ == WGRAD){
result += R"(
rka = rka + offk;
rkb = rkb + offk;
)";
}
/* A offsets */
if(op_ == FPROP){
if(is_chwn){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"(
result +=
compute_bhw("ra", "TM", "rxa") + R"(
raw = raw * stride_w;
rah = rah * stride_h;
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
@@ -309,35 +369,12 @@ if(op_ == FPROP){
int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :];
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
)";
if(shift_edge_h_)
result += " int1 interiorh[TM] = 1;\n";
else
result += " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TM] = 1;";
else
result += " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
result += R"(
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
)" + compute_interior("ra", "TM", "TK") + R"(
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
}
if(op_ == BPROP){
if(is_chwn){
result += R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"(
result +=
compute_bhw("ra", "TM", "rxa") + R"(
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
@@ -348,21 +385,8 @@ if(op_ == WGRAD && layout_ == CHWN){
int32 offa1[TK, TM] = rka[:, newaxis];)";
}
if(op_ == WGRAD && layout_ == NCHW){
if(is_chwn){
result += R"(
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
result +=
compute_bhw("ra", "TK", "rka") + R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)";
@@ -380,38 +404,15 @@ if(op_ == BPROP){
int32 offb1[TK, TN] = rkb[:, newaxis];)";
}
if(op_ == WGRAD){
if(is_chwn){
result += R"(
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
result +=
compute_bhw("rb", "TK", "rkb") + R"(
__constant__ int32* pd[TN] = delta_a + ryb;
int32 d[TN] = *pd;
int32 shift[TK, TN] = d[newaxis, :];
rbw = rbw * stride_w;
rbh = rbh * stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)";
if(shift_edge_h_)
result += " int1 interiorh[TK] = 1;\n";
else
result += " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TK] = 1;";
else
result += " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));";
result += R"(
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
)" + compute_interior("rb", "TK", "TN") + R"(
int32 incb[TK, TN] = interior ? shift : 0;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)";
@@ -421,8 +422,8 @@ if(op_ == WGRAD){
result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){
@@ -450,22 +451,8 @@ if(op_ == WGRAD && layout_ == CHWN){
}
if(op_ == WGRAD && layout_ == NCHW){
result += R"(
rka = rka + TK;)";
if(is_chwn){
result += R"(
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
rka = rka + TK;)"
+ compute_bhw("ra", "TK", "rka") + R"(
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)";
}
@@ -475,36 +462,12 @@ if(op_ == WGRAD && layout_ == NCHW){
/* Increment B pointers */
if(op_ == WGRAD){
result += R"(
rkb = rkb + TK;)";
if(is_chwn){
result += R"(
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
rkb = rkb + TK;)"
+ compute_bhw("rb", "TK", "rkb") + R"(
rbw = rbw * stride_w;
rbh = rbh * stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)";
if(shift_edge_h_)
result += " interiorh = 1;\n";
else
result += " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
result += " interiorw = 1;";
else
result += " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));";
result += R"(
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
)" + compute_interior("rb", "TK", "TN") + R"(
incb = interior ? shift : 0;
pb = B + offb0 + offkb[:, newaxis] + incb;)";
}
@@ -524,41 +487,15 @@ if(op_ == BPROP){
/* C offsets */
if(op_ == BPROP){
if(is_chwn){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"(
result +=
compute_bhw("rc", "TM", "rxc") + R"(
rcw = rcw * stride_w;
rch = rch * stride_h;
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
}
if(op_ == FPROP){
if(is_chwn){
result += R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"(
result +=
compute_bhw("rc", "TM", "rxc") + R"(
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
}
if(op_ == WGRAD){
@@ -572,17 +509,8 @@ if(op_ == WGRAD){
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(op_ == BPROP){
result += "\n";
if(shift_edge_h_)
result += " int1 interiorh[TM] = 1;\n";
else
result += " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TM] = 1;";
else
result += " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));";
result += R"(
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
)" + compute_interior("rc", "TM", "TN") + R"(
__constant__ int32* pd[TN] = delta_a + ryc;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
pc = interior ? shift_pc : pc;
@@ -591,12 +519,32 @@ if(op_ == BPROP){
}
else{
result += R"(
@checkc *pc = c;)";
int1 has_lock = (GZ > 1) && (locks != 0);
if(has_lock){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == split - 1, 0, count + 1);
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
*plock = 0;
}
else{
@checkc *pc = c;
})";
}
result += R"(
})";
// std::cout << result << std::endl;
os << result;
}

View File

@@ -73,10 +73,14 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
return builder.create_icmpSGE(lhs, rhs, name);
if(op_ == GE && is_int && !is_signed)
return builder.create_icmpUGE(lhs, rhs, name);
if(op_ == EQ && is_ptr)
return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == EQ && is_float)
return builder.create_fcmpOEQ(lhs, rhs, name);
if(op_ == EQ && is_int)
return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == NE && is_ptr)
return builder.create_icmpNE(lhs, rhs, name);
if(op_ == NE && is_float)
return builder.create_fcmpONE(lhs, rhs, name);
if(op_ == NE && is_int)

View File

@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
size_t current = 0;
while(true){
//Execute function
pool.add_job([values, &f](){ f(values); });
// pool.add_job([values, &f](){ f(values); });
f(values);
//Increment counters
while(values[i]++ == ranges[i] - 1){
if(i == 0)
@@ -169,7 +170,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
ranges.push_back(mp->get_space());
// iterate over parameters
tune_res_t best;
size_t nthreads = 1;
size_t nthreads = 4;
std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors;