[dnn/shift] added split-k for shift-conv

This commit is contained in:
Philippe Tillet
2019-07-15 21:03:58 -07:00
parent 434f65737f
commit aa8bcf6bde
9 changed files with 166 additions and 203 deletions

View File

@@ -18,8 +18,8 @@ int main() {
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t B = 32, F = 128; int32_t B = 128, F = 128;
int32_t H = 28, W = 28; int32_t H = 16, W = 16;
int32_t C = 128; int32_t C = 128;
// random shifts // random shifts
@@ -44,9 +44,9 @@ int main() {
std::swap(b_size, c_size); std::swap(b_size, c_size);
std::swap(a_size, b_size); std::swap(a_size, b_size);
} }
std::vector<NumericT> ha(B*C*H*W); std::vector<NumericT> ha(a_size);
std::vector<NumericT> hb(C*F); std::vector<NumericT> hb(b_size);
std::vector<float> hc(B*F*H*W); std::vector<float> hc(c_size);
std::vector<float> rc(hc.size()); std::vector<float> rc(hc.size());
// device buffers // device buffers
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);

View File

@@ -58,9 +58,9 @@ def blocksparse_matmul_grad(op, dy):
return (dx, dw) return (dx, dw)
def run_shift(): def run_shift():
B, C, H, W = 1, 16, 4, 4 B, C, H, W = 2, 16, 4, 4
R, S, F = 3, 3, 16 R, S, F = 3, 3, 16
stride_h, stride_w = 2, 2 stride_h, stride_w = 1, 1
np.random.seed(2) np.random.seed(2)
a = tf.placeholder(tf.float16, shape=[B, C, H, W]) a = tf.placeholder(tf.float16, shape=[B, C, H, W])
b = tf.placeholder(tf.float16, shape=[C, F]) b = tf.placeholder(tf.float16, shape=[C, F])
@@ -82,8 +82,8 @@ def run_shift():
dx_t, dx_n = grads[0] dx_t, dx_n = grads[0]
#import sys #import sys
#np.set_printoptions(threshold=sys.maxsize) #np.set_printoptions(threshold=sys.maxsize)
print(dx_t) print(dw_t)
print(dx_n) print(dw_n)
print(np.max(np.abs(dw_t - dw_n))) print(np.max(np.abs(dw_t - dw_n)))
print(np.max(np.abs(dx_t - dx_n))) print(np.max(np.abs(dx_t - dx_n)))
# Run # Run

View File

@@ -43,6 +43,8 @@ protected:
private: private:
// initialize // initialize
virtual void init_impl(driver::stream *, driver::cu_module *){ } virtual void init_impl(driver::stream *, driver::cu_module *){ }
// deinitialize
virtual void deinit_impl(){ }
// enqueue // enqueue
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel, virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args, std::vector<driver::buffer*> args,

View File

@@ -52,6 +52,7 @@ public:
private: private:
// initialize and enqueue // initialize and enqueue
void init_impl(driver::stream *stream, driver::cu_module *module); void init_impl(driver::stream *stream, driver::cu_module *module);
void deinit_impl();
void enqueue_impl(driver::stream *stream, driver::kernel *kernel, void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args, std::vector<driver::buffer*> args,
triton::runtime::launch_information info); triton::runtime::launch_information info);
@@ -163,6 +164,9 @@ private:
bool BT_; bool BT_;
// layout // layout
layout_t layout_; layout_t layout_;
// locks
size_t max_locks_;
driver::buffer *locks_;
}; };
} }

View File

@@ -32,15 +32,17 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
double total_time = 0; double total_time = 0;
op(); op();
sync(); sync();
float norm = 1; // while(total_time*1e-9 < 1e-3){
// normalize clock if possible to get roughly constant result float norm = 1;
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device)) // normalize clock if possible to get roughly constant result
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
tmr.start(); norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
op(); tmr.start();
sync(); op();
times.push_back(norm*tmr.get().count()); sync();
total_time+=times.back(); times.push_back(norm*tmr.get().count());
total_time+=times.back();
// }
return *std::min_element(times.begin(), times.end()); return *std::min_element(times.begin(), times.end());
} }

View File

@@ -29,19 +29,20 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
rt::jit* jit; rt::jit* jit;
/* the current template has not already been compiled */ /* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) { if(m_jit.find(this) == m_jit.end()) {
jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get(); base* clone = this->clone();
jit = m_jit.emplace(clone, new rt::jit(ctx)).first->second.get();
std::ostringstream oss; std::ostringstream oss;
triton_c_src(oss); clone->triton_c_src(oss);
std::string src = oss.str(); std::string src = oss.str();
auto benchmark = [&](triton::driver::kernel* kernel, auto benchmark = [&](triton::driver::kernel* kernel,
rt::launch_information info) { rt::launch_information info) {
// launch info // launch info
unsigned nthreads = info.num_threads; clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
init_impl(stream, (triton::driver::cu_module*)kernel->module()); clone->enqueue_impl(stream, kernel, args, info);
enqueue_impl(stream, kernel, args, info);
stream->synchronize(); stream->synchronize();
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); }, double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
[&](){ stream->synchronize(); }, ctx->device()); [&](){ stream->synchronize(); }, ctx->device());
clone->deinit_impl();
return num_flops() / ts * 1e-3; return num_flops() / ts * 1e-3;
}; };
// auto-tune and save result // auto-tune and save result
@@ -53,7 +54,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str())); jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
} }
triton::driver::kernel* kernel = jit->get_function(name_.c_str()); triton::driver::kernel* kernel = jit->get_function(name_.c_str());
init_impl(stream, (triton::driver::cu_module*)kernel->module()); clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
} }
/* retrieved compiled template */ /* retrieved compiled template */
else else
@@ -63,7 +64,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
driver::kernel* kernel = jit->get_function(name_.c_str()); driver::kernel* kernel = jit->get_function(name_.c_str());
rt::launch_information info = jit->get_launch_info(name_.c_str()); rt::launch_information info = jit->get_launch_info(name_.c_str());
/* launch */ /* launch */
enqueue_impl(stream, kernel, args, info); auto it = m_jit.find(this);
it->first->enqueue_impl(stream, kernel, args, info);
} }
} }

View File

@@ -124,6 +124,9 @@ shift::shift(int B, int C,
if(layout_ == NCHW) if(layout_ == NCHW)
shapes_c_ = {B, C, AH_, AW_}; shapes_c_ = {B, C, AH_, AW_};
} }
// locks
max_locks_ = (op_ == WGRAD) ? 8192 : 0;
locks_ = nullptr;
} }
base* shift::clone() const { base* shift::clone() const {
@@ -195,11 +198,30 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
build_delta_a(); build_delta_a();
triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a"); triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a");
stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data()); stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data());
// locks
if(locks_ == nullptr && max_locks_ > 0){
std::vector<int32_t> hlocks(2*max_locks_, 0);
locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*4);
stream->write(locks_, false, 0, hlocks);
}
}
void shift::deinit_impl() {
if(locks_ != nullptr){
delete locks_;
locks_ = nullptr;
}
} }
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel, void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer *> args, std::vector<driver::buffer *> args,
runtime::launch_information info) { runtime::launch_information info) {
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned num_locks = grid_0 * grid_1;
unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
driver::buffer *a = args[0], *b = args[1], *c = args[2]; driver::buffer *a = args[0], *b = args[1], *c = args[2];
kernel->setArg(0, a); kernel->setArg(0, a);
kernel->setArg(1, b); kernel->setArg(1, b);
@@ -228,8 +250,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(24, BW_); kernel->setArg(24, BW_);
kernel->setArg(25, CH_); kernel->setArg(25, CH_);
kernel->setArg(26, CW_); kernel->setArg(26, CW_);
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1]; kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; kernel->setArg(28, (int32_t)grid[0]);
kernel->setArg(29, (int32_t)grid[1]);
if(op_ == BPROP){ if(op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4; size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes); ((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
@@ -256,12 +279,49 @@ void shift::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1; std::string BS = BS0 + ", " + BS1;
bool is_chwn = layout_ == CHWN; bool is_chwn = layout_ == CHWN;
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
if(is_chwn) {
return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)";
}
else {
return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
}
};
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
std::string result;
if(shift_edge_h_)
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
else
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
if(shift_edge_w_)
result += "int1 interiorw[" + sz0 + "] = 1;";
else
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
result += R"(
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
return result;
};
std::string result = std::string result =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {)" + std::to_string(TK_) + R"(}; const tunable int32 TK = {)" + std::to_string(TK_) + "};";
if(op_ == WGRAD)
result += "const tunable int32 GZ = {1, 4, 16};";
else
result += "const tunable int32 GZ = {1};";
result += R"(
__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(]; __constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(];
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
@@ -275,32 +335,32 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
int32 NB, int32 NB,
int32 AH, int32 AW, int32 AH, int32 AW,
int32 BH, int32 BW, int32 BH, int32 BW,
int32 CH, int32 CW) { int32 CH, int32 CW,
int32* locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0); int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1); int32 ryb[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK; int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK;
fp32 acc[TM, TN] = 0; fp32 acc[TM, TN] = 0;
int32 pad_h = BH / 2; int32 pad_h = BH / 2;
int32 pad_w = BW / 2;)"; int32 pad_w = BW / 2;
int32 split = select(locks == 0, 1, GZ);
int32 div = K / split;
int32 rem = K % split;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
if(op_ == WGRAD){
result += R"(
rka = rka + offk;
rkb = rkb + offk;
)";
}
/* A offsets */ /* A offsets */
if(op_ == FPROP){ if(op_ == FPROP){
if(is_chwn){ result +=
result += R"( compute_bhw("ra", "TM", "rxa") + R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"(
raw = raw * stride_w; raw = raw * stride_w;
rah = rah * stride_h; rah = rah * stride_h;
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
@@ -309,35 +369,12 @@ if(op_ == FPROP){
int32 d[TK] = *pd; int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :]; int32 offa_interior[TM, TK] = d[newaxis, :];
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c; int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
)"; )" + compute_interior("ra", "TM", "TK") + R"(
if(shift_edge_h_)
result += " int1 interiorh[TM] = 1;\n";
else
result += " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TM] = 1;";
else
result += " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
result += R"(
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)"; int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
if(is_chwn){ result +=
result += R"( compute_bhw("ra", "TM", "rxa") + R"(
int32 rawh[TM] = rxa / NB;
int32 rab[TM] = rxa % NB;
int32 raw[TM] = rawh % CW;
int32 rah[TM] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW;
int32 rah[TM] = rabh % CH;
int32 rab[TM] = rabh / CH;)";
}
result += R"(
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
@@ -348,21 +385,8 @@ if(op_ == WGRAD && layout_ == CHWN){
int32 offa1[TK, TM] = rka[:, newaxis];)"; int32 offa1[TK, TM] = rka[:, newaxis];)";
} }
if(op_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
if(is_chwn){ result +=
result += R"( compute_bhw("ra", "TK", "rka") + R"(
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h; int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)"; int32 offa1[TK, TM] = offxa[:, newaxis];)";
@@ -380,38 +404,15 @@ if(op_ == BPROP){
int32 offb1[TK, TN] = rkb[:, newaxis];)"; int32 offb1[TK, TN] = rkb[:, newaxis];)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
if(is_chwn){ result +=
result += R"( compute_bhw("rb", "TK", "rkb") + R"(
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
__constant__ int32* pd[TN] = delta_a + ryb; __constant__ int32* pd[TN] = delta_a + ryb;
int32 d[TN] = *pd; int32 d[TN] = *pd;
int32 shift[TK, TN] = d[newaxis, :]; int32 shift[TK, TN] = d[newaxis, :];
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)"; )" + compute_interior("rb", "TK", "TN") + R"(
if(shift_edge_h_)
result += " int1 interiorh[TK] = 1;\n";
else
result += " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TK] = 1;";
else
result += " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));";
result += R"(
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
int32 incb[TK, TN] = interior ? shift : 0; int32 incb[TK, TN] = interior ? shift : 0;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)"; int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)";
@@ -421,8 +422,8 @@ if(op_ == WGRAD){
result += R"( result += R"(
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1; )" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1; )" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){ for(int32 k = K; k > 0; k = k - TK){
@@ -450,22 +451,8 @@ if(op_ == WGRAD && layout_ == CHWN){
} }
if(op_ == WGRAD && layout_ == NCHW){ if(op_ == WGRAD && layout_ == NCHW){
result += R"( result += R"(
rka = rka + TK;)"; rka = rka + TK;)"
if(is_chwn){ + compute_bhw("ra", "TK", "rka") + R"(
result += R"(
int32 rawh[TK] = rka / NB;
int32 rab[TK] = rka % NB;
int32 raw[TK] = rawh % CW;
int32 rah[TK] = rawh / CW;)";
}
else{
result += R"(
int32 rabh[TK] = rka / CW;
int32 raw[TK] = rka % CW;
int32 rah[TK] = rabh % CH;
int32 rab[TK] = rabh / CH;)";
}
result += R"(
offxa = rab*lda_b + raw*lda_w + rah*lda_h; offxa = rab*lda_b + raw*lda_w + rah*lda_h;
pa = A + offa0 + offxa[:, newaxis];)"; pa = A + offa0 + offxa[:, newaxis];)";
} }
@@ -475,36 +462,12 @@ if(op_ == WGRAD && layout_ == NCHW){
/* Increment B pointers */ /* Increment B pointers */
if(op_ == WGRAD){ if(op_ == WGRAD){
result += R"( result += R"(
rkb = rkb + TK;)"; rkb = rkb + TK;)"
if(is_chwn){ + compute_bhw("rb", "TK", "rkb") + R"(
result += R"(
int32 rbwh[TK] = rkb / NB;
int32 rbb[TK] = rkb % NB;
int32 rbw[TK] = rbwh % CW;
int32 rbh[TK] = rbwh / CW;)";
}
else{
result += R"(
int32 rbbh[TK] = rkb / CW;
int32 rbw[TK] = rkb % CW;
int32 rbh[TK] = rbbh % CH;
int32 rbb[TK] = rbbh / CH;)";
}
result += R"(
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)"; )" + compute_interior("rb", "TK", "TN") + R"(
if(shift_edge_h_)
result += " interiorh = 1;\n";
else
result += " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
if(shift_edge_w_)
result += " interiorw = 1;";
else
result += " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));";
result += R"(
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
incb = interior ? shift : 0; incb = interior ? shift : 0;
pb = B + offb0 + offkb[:, newaxis] + incb;)"; pb = B + offb0 + offkb[:, newaxis] + incb;)";
} }
@@ -524,41 +487,15 @@ if(op_ == BPROP){
/* C offsets */ /* C offsets */
if(op_ == BPROP){ if(op_ == BPROP){
if(is_chwn){ result +=
result += R"( compute_bhw("rc", "TM", "rxc") + R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"(
rcw = rcw * stride_w; rcw = rcw * stride_w;
rch = rch * stride_h; rch = rch * stride_h;
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
if(is_chwn){ result +=
result += R"( compute_bhw("rc", "TM", "rxc") + R"(
int32 rcwh[TM] = rxc / NB;
int32 rcb[TM] = rxc % NB;
int32 rcw[TM] = rcwh % CW;
int32 rch[TM] = rcwh / CW;)";
}
else{
result += R"(
int32 rcbh[TM] = rxc / CW;
int32 rcw[TM] = rxc % CW;
int32 rch[TM] = rcbh % CH;
int32 rcb[TM] = rcbh / CH;)";
}
result += R"(
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)"; int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
@@ -572,17 +509,8 @@ if(op_ == WGRAD){
int1 checkc1[TN] = ryc < N; int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(op_ == BPROP){ if(op_ == BPROP){
result += "\n"; result += R"(
if(shift_edge_h_) )" + compute_interior("rc", "TM", "TN") + R"(
result += " int1 interiorh[TM] = 1;\n";
else
result += " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
if(shift_edge_w_)
result += " int1 interiorw[TM] = 1;";
else
result += " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));";
result += R"(
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
__constant__ int32* pd[TN] = delta_a + ryc; __constant__ int32* pd[TN] = delta_a + ryc;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
pc = interior ? shift_pc : pc; pc = interior ? shift_pc : pc;
@@ -591,12 +519,32 @@ if(op_ == BPROP){
} }
else{ else{
result += R"( result += R"(
@checkc *pc = c;)"; int1 has_lock = (GZ > 1) && (locks != 0);
if(has_lock){
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == split - 1, 0, count + 1);
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
*plock = 0;
}
else{
@checkc *pc = c;
})";
} }
result += R"( result += R"(
})"; })";
// std::cout << result << std::endl;
os << result; os << result;
} }

View File

@@ -73,10 +73,14 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
return builder.create_icmpSGE(lhs, rhs, name); return builder.create_icmpSGE(lhs, rhs, name);
if(op_ == GE && is_int && !is_signed) if(op_ == GE && is_int && !is_signed)
return builder.create_icmpUGE(lhs, rhs, name); return builder.create_icmpUGE(lhs, rhs, name);
if(op_ == EQ && is_ptr)
return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == EQ && is_float) if(op_ == EQ && is_float)
return builder.create_fcmpOEQ(lhs, rhs, name); return builder.create_fcmpOEQ(lhs, rhs, name);
if(op_ == EQ && is_int) if(op_ == EQ && is_int)
return builder.create_icmpEQ(lhs, rhs, name); return builder.create_icmpEQ(lhs, rhs, name);
if(op_ == NE && is_ptr)
return builder.create_icmpNE(lhs, rhs, name);
if(op_ == NE && is_float) if(op_ == NE && is_float)
return builder.create_fcmpONE(lhs, rhs, name); return builder.create_fcmpONE(lhs, rhs, name);
if(op_ == NE && is_int) if(op_ == NE && is_int)

View File

@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
size_t current = 0; size_t current = 0;
while(true){ while(true){
//Execute function //Execute function
pool.add_job([values, &f](){ f(values); }); // pool.add_job([values, &f](){ f(values); });
f(values);
//Increment counters //Increment counters
while(values[i]++ == ranges[i] - 1){ while(values[i]++ == ranges[i] - 1){
if(i == 0) if(i == 0)
@@ -169,7 +170,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
ranges.push_back(mp->get_space()); ranges.push_back(mp->get_space());
// iterate over parameters // iterate over parameters
tune_res_t best; tune_res_t best;
size_t nthreads = 1; size_t nthreads = 4;
std::mutex mutex; std::mutex mutex;
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){ loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
std::map<ir::value*, std::vector<std::string>> errors; std::map<ir::value*, std::vector<std::string>> errors;