[dnn/shift] added split-k for shift-conv
This commit is contained in:
@@ -18,8 +18,8 @@ int main() {
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 32, F = 128;
|
||||
int32_t H = 28, W = 28;
|
||||
int32_t B = 128, F = 128;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 128;
|
||||
|
||||
// random shifts
|
||||
@@ -44,9 +44,9 @@ int main() {
|
||||
std::swap(b_size, c_size);
|
||||
std::swap(a_size, b_size);
|
||||
}
|
||||
std::vector<NumericT> ha(B*C*H*W);
|
||||
std::vector<NumericT> hb(C*F);
|
||||
std::vector<float> hc(B*F*H*W);
|
||||
std::vector<NumericT> ha(a_size);
|
||||
std::vector<NumericT> hb(b_size);
|
||||
std::vector<float> hc(c_size);
|
||||
std::vector<float> rc(hc.size());
|
||||
// device buffers
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
|
@@ -58,9 +58,9 @@ def blocksparse_matmul_grad(op, dy):
|
||||
return (dx, dw)
|
||||
|
||||
def run_shift():
|
||||
B, C, H, W = 1, 16, 4, 4
|
||||
B, C, H, W = 2, 16, 4, 4
|
||||
R, S, F = 3, 3, 16
|
||||
stride_h, stride_w = 2, 2
|
||||
stride_h, stride_w = 1, 1
|
||||
np.random.seed(2)
|
||||
a = tf.placeholder(tf.float16, shape=[B, C, H, W])
|
||||
b = tf.placeholder(tf.float16, shape=[C, F])
|
||||
@@ -82,8 +82,8 @@ def run_shift():
|
||||
dx_t, dx_n = grads[0]
|
||||
#import sys
|
||||
#np.set_printoptions(threshold=sys.maxsize)
|
||||
print(dx_t)
|
||||
print(dx_n)
|
||||
print(dw_t)
|
||||
print(dw_n)
|
||||
print(np.max(np.abs(dw_t - dw_n)))
|
||||
print(np.max(np.abs(dx_t - dx_n)))
|
||||
# Run
|
||||
|
@@ -43,6 +43,8 @@ protected:
|
||||
private:
|
||||
// initialize
|
||||
virtual void init_impl(driver::stream *, driver::cu_module *){ }
|
||||
// deinitialize
|
||||
virtual void deinit_impl(){ }
|
||||
// enqueue
|
||||
virtual void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
|
@@ -52,6 +52,7 @@ public:
|
||||
private:
|
||||
// initialize and enqueue
|
||||
void init_impl(driver::stream *stream, driver::cu_module *module);
|
||||
void deinit_impl();
|
||||
void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer*> args,
|
||||
triton::runtime::launch_information info);
|
||||
@@ -163,6 +164,9 @@ private:
|
||||
bool BT_;
|
||||
// layout
|
||||
layout_t layout_;
|
||||
// locks
|
||||
size_t max_locks_;
|
||||
driver::buffer *locks_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -32,6 +32,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
|
||||
double total_time = 0;
|
||||
op();
|
||||
sync();
|
||||
// while(total_time*1e-9 < 1e-3){
|
||||
float norm = 1;
|
||||
// normalize clock if possible to get roughly constant result
|
||||
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
|
||||
@@ -41,6 +42,7 @@ double bench(OP const & op, SYNC const & sync, const triton::driver::device * de
|
||||
sync();
|
||||
times.push_back(norm*tmr.get().count());
|
||||
total_time+=times.back();
|
||||
// }
|
||||
return *std::min_element(times.begin(), times.end());
|
||||
}
|
||||
|
||||
|
@@ -29,19 +29,20 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
rt::jit* jit;
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
jit = m_jit.emplace(this->clone(), new rt::jit(ctx)).first->second.get();
|
||||
base* clone = this->clone();
|
||||
jit = m_jit.emplace(clone, new rt::jit(ctx)).first->second.get();
|
||||
std::ostringstream oss;
|
||||
triton_c_src(oss);
|
||||
clone->triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
auto benchmark = [&](triton::driver::kernel* kernel,
|
||||
rt::launch_information info) {
|
||||
// launch info
|
||||
unsigned nthreads = info.num_threads;
|
||||
init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
enqueue_impl(stream, kernel, args, info);
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
clone->enqueue_impl(stream, kernel, args, info);
|
||||
stream->synchronize();
|
||||
double ts = triton::tools::bench([&](){ enqueue_impl(stream, kernel, args, info); },
|
||||
double ts = triton::tools::bench([&](){ clone->enqueue_impl(stream, kernel, args, info); },
|
||||
[&](){ stream->synchronize(); }, ctx->device());
|
||||
clone->deinit_impl();
|
||||
return num_flops() / ts * 1e-3;
|
||||
};
|
||||
// auto-tune and save result
|
||||
@@ -53,7 +54,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
}
|
||||
/* retrieved compiled template */
|
||||
else
|
||||
@@ -63,7 +64,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
rt::launch_information info = jit->get_launch_info(name_.c_str());
|
||||
/* launch */
|
||||
enqueue_impl(stream, kernel, args, info);
|
||||
auto it = m_jit.find(this);
|
||||
it->first->enqueue_impl(stream, kernel, args, info);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -124,6 +124,9 @@ shift::shift(int B, int C,
|
||||
if(layout_ == NCHW)
|
||||
shapes_c_ = {B, C, AH_, AW_};
|
||||
}
|
||||
// locks
|
||||
max_locks_ = (op_ == WGRAD) ? 8192 : 0;
|
||||
locks_ = nullptr;
|
||||
}
|
||||
|
||||
base* shift::clone() const {
|
||||
@@ -195,11 +198,30 @@ void shift::init_impl(driver::stream *stream, driver::cu_module *module) {
|
||||
build_delta_a();
|
||||
triton::driver::buffer* delta_a = ((triton::driver::cu_module*)module)->symbol("delta_a");
|
||||
stream->write(delta_a, false, 0, h_delta_a.size()*4, h_delta_a.data());
|
||||
// locks
|
||||
if(locks_ == nullptr && max_locks_ > 0){
|
||||
std::vector<int32_t> hlocks(2*max_locks_, 0);
|
||||
locks_ = triton::driver::buffer::create(stream->context(), 2*max_locks_*4);
|
||||
stream->write(locks_, false, 0, hlocks);
|
||||
}
|
||||
}
|
||||
|
||||
void shift::deinit_impl() {
|
||||
if(locks_ != nullptr){
|
||||
delete locks_;
|
||||
locks_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
std::vector<driver::buffer *> args,
|
||||
runtime::launch_information info) {
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned num_locks = grid_0 * grid_1;
|
||||
unsigned grid_2 = num_locks < max_locks_ ? info.globals.at("GZ") : 1;
|
||||
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
|
||||
driver::buffer *a = args[0], *b = args[1], *c = args[2];
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
@@ -228,8 +250,9 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(24, BW_);
|
||||
kernel->setArg(25, CH_);
|
||||
kernel->setArg(26, CW_);
|
||||
unsigned TM = info.global_range_size[0], TN = info.global_range_size[1];
|
||||
std::array<size_t, 3> grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1};
|
||||
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
|
||||
kernel->setArg(28, (int32_t)grid[0]);
|
||||
kernel->setArg(29, (int32_t)grid[1]);
|
||||
if(op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
||||
@@ -256,12 +279,49 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
std::string BS = BS0 + ", " + BS1;
|
||||
bool is_chwn = layout_ == CHWN;
|
||||
|
||||
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
|
||||
if(is_chwn) {
|
||||
return R"(
|
||||
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)";
|
||||
}
|
||||
else {
|
||||
return R"(
|
||||
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
|
||||
}
|
||||
};
|
||||
|
||||
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
|
||||
std::string result;
|
||||
if(shift_edge_h_)
|
||||
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
|
||||
else
|
||||
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
|
||||
if(shift_edge_w_)
|
||||
result += "int1 interiorw[" + sz0 + "] = 1;";
|
||||
else
|
||||
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {)" + std::to_string(TK_) + R"(};
|
||||
const tunable int32 TK = {)" + std::to_string(TK_) + "};";
|
||||
if(op_ == WGRAD)
|
||||
result += "const tunable int32 GZ = {1, 4, 16};";
|
||||
else
|
||||
result += "const tunable int32 GZ = {1};";
|
||||
|
||||
result += R"(
|
||||
__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(];
|
||||
|
||||
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
@@ -275,32 +335,32 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int32 NB,
|
||||
int32 AH, int32 AW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW) {
|
||||
int32 CH, int32 CW,
|
||||
int32* locks, int32 grid0, int32 grid1) {
|
||||
int32 rxa[TM] = get_global_range[TM](0);
|
||||
int32 ryb[TN] = get_global_range[TN](1);
|
||||
int32 rz = get_global_range[1](2);
|
||||
int32 rka[TK] = 0 ... TK;
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
fp32 acc[TM, TN] = 0;
|
||||
int32 pad_h = BH / 2;
|
||||
int32 pad_w = BW / 2;)";
|
||||
int32 pad_w = BW / 2;
|
||||
int32 split = select(locks == 0, 1, GZ);
|
||||
int32 div = K / split;
|
||||
int32 rem = K % split;
|
||||
K = select(rz < rem, div - 1, div);
|
||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);)";
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rka = rka + offk;
|
||||
rkb = rkb + offk;
|
||||
)";
|
||||
}
|
||||
|
||||
/* A offsets */
|
||||
if(op_ == FPROP){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("ra", "TM", "rxa") + R"(
|
||||
raw = raw * stride_w;
|
||||
rah = rah * stride_h;
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
@@ -309,35 +369,12 @@ if(op_ == FPROP){
|
||||
int32 d[TK] = *pd;
|
||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
result += " int1 interiorh[TM] = 1;\n";
|
||||
else
|
||||
result += " int1 interiorh[TM] = (rah >= pad_h) && (rah < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
result += " int1 interiorw[TM] = 1;";
|
||||
else
|
||||
result += " int1 interiorw[TM] = (raw >= pad_w) && (raw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TM, TK] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
)" + compute_interior("ra", "TM", "TK") + R"(
|
||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TM] = rxa / NB;
|
||||
int32 rab[TM] = rxa % NB;
|
||||
int32 raw[TM] = rawh % CW;
|
||||
int32 rah[TM] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TM] = rxa / CW;
|
||||
int32 raw[TM] = rxa % CW;
|
||||
int32 rah[TM] = rabh % CH;
|
||||
int32 rab[TM] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("ra", "TM", "rxa") + R"(
|
||||
int32 offxa[TM] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
|
||||
@@ -348,21 +385,8 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
int32 offa1[TK, TM] = rka[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = rawh % CW;
|
||||
int32 rah[TK] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TK] = rka / CW;
|
||||
int32 raw[TK] = rka % CW;
|
||||
int32 rah[TK] = rabh % CH;
|
||||
int32 rab[TK] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("ra", "TK", "rka") + R"(
|
||||
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
|
||||
int32 offxa[TK] = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
int32 offa1[TK, TM] = offxa[:, newaxis];)";
|
||||
@@ -380,38 +404,15 @@ if(op_ == BPROP){
|
||||
int32 offb1[TK, TN] = rkb[:, newaxis];)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TK] = rkb / CW;
|
||||
int32 rbw[TK] = rkb % CW;
|
||||
int32 rbh[TK] = rbbh % CH;
|
||||
int32 rbb[TK] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("rb", "TK", "rkb") + R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryb;
|
||||
int32 d[TN] = *pd;
|
||||
int32 shift[TK, TN] = d[newaxis, :];
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
result += " int1 interiorh[TK] = 1;\n";
|
||||
else
|
||||
result += " int1 interiorh[TK] = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
result += " int1 interiorw[TK] = 1;";
|
||||
else
|
||||
result += " int1 interiorw[TK] = (rbw >= pad_w) && (rbw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TK, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
)" + compute_interior("rb", "TK", "TN") + R"(
|
||||
int32 incb[TK, TN] = interior ? shift : 0;
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)";
|
||||
@@ -421,8 +422,8 @@ if(op_ == WGRAD){
|
||||
result += R"(
|
||||
)" + a_ty_ + "* pa[" + AS + R"(] = A + offa0 + offa1;
|
||||
)" + b_ty_ + "* pb[" + BS + R"(] = B + offb0 + offb1;
|
||||
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
|
||||
int1 checka[)" + AS + "] = (rka < K + offk)" + bca0 + R"(;
|
||||
int1 checkb[)" + BS + "] = (rkb < K + offk)" + bcb0 + R"(;
|
||||
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
|
||||
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
|
||||
for(int32 k = K; k > 0; k = k - TK){
|
||||
@@ -450,22 +451,8 @@ if(op_ == WGRAD && layout_ == CHWN){
|
||||
}
|
||||
if(op_ == WGRAD && layout_ == NCHW){
|
||||
result += R"(
|
||||
rka = rka + TK;)";
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rawh[TK] = rka / NB;
|
||||
int32 rab[TK] = rka % NB;
|
||||
int32 raw[TK] = rawh % CW;
|
||||
int32 rah[TK] = rawh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rabh[TK] = rka / CW;
|
||||
int32 raw[TK] = rka % CW;
|
||||
int32 rah[TK] = rabh % CH;
|
||||
int32 rab[TK] = rabh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
rka = rka + TK;)"
|
||||
+ compute_bhw("ra", "TK", "rka") + R"(
|
||||
offxa = rab*lda_b + raw*lda_w + rah*lda_h;
|
||||
pa = A + offa0 + offxa[:, newaxis];)";
|
||||
}
|
||||
@@ -475,36 +462,12 @@ if(op_ == WGRAD && layout_ == NCHW){
|
||||
/* Increment B pointers */
|
||||
if(op_ == WGRAD){
|
||||
result += R"(
|
||||
rkb = rkb + TK;)";
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rbwh[TK] = rkb / NB;
|
||||
int32 rbb[TK] = rkb % NB;
|
||||
int32 rbw[TK] = rbwh % CW;
|
||||
int32 rbh[TK] = rbwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rbbh[TK] = rkb / CW;
|
||||
int32 rbw[TK] = rkb % CW;
|
||||
int32 rbh[TK] = rbbh % CH;
|
||||
int32 rbb[TK] = rbbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
rkb = rkb + TK;)"
|
||||
+ compute_bhw("rb", "TK", "rkb") + R"(
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)";
|
||||
if(shift_edge_h_)
|
||||
result += " interiorh = 1;\n";
|
||||
else
|
||||
result += " interiorh = (rbh >= pad_h) && (rbh < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
result += " interiorw = 1;";
|
||||
else
|
||||
result += " interiorw = (rbw >= pad_w) && (rbw < (AW - pad_w));";
|
||||
result += R"(
|
||||
interior = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
)" + compute_interior("rb", "TK", "TN") + R"(
|
||||
incb = interior ? shift : 0;
|
||||
pb = B + offb0 + offkb[:, newaxis] + incb;)";
|
||||
}
|
||||
@@ -524,41 +487,15 @@ if(op_ == BPROP){
|
||||
|
||||
/* C offsets */
|
||||
if(op_ == BPROP){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = rcwh % CW;
|
||||
int32 rch[TM] = rcwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rcbh[TM] = rxc / CW;
|
||||
int32 rcw[TM] = rxc % CW;
|
||||
int32 rch[TM] = rcbh % CH;
|
||||
int32 rcb[TM] = rcbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("rc", "TM", "rxc") + R"(
|
||||
rcw = rcw * stride_w;
|
||||
rch = rch * stride_h;
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
if(is_chwn){
|
||||
result += R"(
|
||||
int32 rcwh[TM] = rxc / NB;
|
||||
int32 rcb[TM] = rxc % NB;
|
||||
int32 rcw[TM] = rcwh % CW;
|
||||
int32 rch[TM] = rcwh / CW;)";
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
int32 rcbh[TM] = rxc / CW;
|
||||
int32 rcw[TM] = rxc % CW;
|
||||
int32 rch[TM] = rcbh % CH;
|
||||
int32 rcb[TM] = rcbh / CH;)";
|
||||
}
|
||||
result += R"(
|
||||
result +=
|
||||
compute_bhw("rc", "TM", "rxc") + R"(
|
||||
int32 offxc[TM] = rcb*ldc_b + rcw*ldc_w + rch*ldc_h;)";
|
||||
}
|
||||
if(op_ == WGRAD){
|
||||
@@ -572,17 +509,8 @@ if(op_ == WGRAD){
|
||||
int1 checkc1[TN] = ryc < N;
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
if(op_ == BPROP){
|
||||
result += "\n";
|
||||
if(shift_edge_h_)
|
||||
result += " int1 interiorh[TM] = 1;\n";
|
||||
else
|
||||
result += " int1 interiorh[TM] = (rch >= pad_h) && (rch < (AH - pad_h));\n";
|
||||
if(shift_edge_w_)
|
||||
result += " int1 interiorw[TM] = 1;";
|
||||
else
|
||||
result += " int1 interiorw[TM] = (rcw >= pad_w) && (rcw < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[TM, TN] = interiorh[:, newaxis] && interiorw[:, newaxis];
|
||||
)" + compute_interior("rc", "TM", "TN") + R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
pc = interior ? shift_pc : pc;
|
||||
@@ -591,12 +519,32 @@ if(op_ == BPROP){
|
||||
}
|
||||
else{
|
||||
result += R"(
|
||||
@checkc *pc = c;)";
|
||||
int1 has_lock = (GZ > 1) && (locks != 0);
|
||||
if(has_lock){
|
||||
int32 ridx = get_range_id(0);
|
||||
int32 ridy = get_range_id(1);
|
||||
int32 *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1));
|
||||
int32 *pcount = plock + grid0*grid1;
|
||||
int32 count = *pcount;
|
||||
int32 countp1 = select(count == split - 1, 0, count + 1);
|
||||
if(count == 0) {
|
||||
@checkc *pc = c;
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
*plock = 0;
|
||||
}
|
||||
else{
|
||||
@checkc *pc = c;
|
||||
})";
|
||||
}
|
||||
result += R"(
|
||||
})";
|
||||
|
||||
// std::cout << result << std::endl;
|
||||
os << result;
|
||||
}
|
||||
|
||||
|
@@ -73,10 +73,14 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
|
||||
return builder.create_icmpSGE(lhs, rhs, name);
|
||||
if(op_ == GE && is_int && !is_signed)
|
||||
return builder.create_icmpUGE(lhs, rhs, name);
|
||||
if(op_ == EQ && is_ptr)
|
||||
return builder.create_icmpEQ(lhs, rhs, name);
|
||||
if(op_ == EQ && is_float)
|
||||
return builder.create_fcmpOEQ(lhs, rhs, name);
|
||||
if(op_ == EQ && is_int)
|
||||
return builder.create_icmpEQ(lhs, rhs, name);
|
||||
if(op_ == NE && is_ptr)
|
||||
return builder.create_icmpNE(lhs, rhs, name);
|
||||
if(op_ == NE && is_float)
|
||||
return builder.create_fcmpONE(lhs, rhs, name);
|
||||
if(op_ == NE && is_int)
|
||||
|
@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
|
||||
size_t current = 0;
|
||||
while(true){
|
||||
//Execute function
|
||||
pool.add_job([values, &f](){ f(values); });
|
||||
// pool.add_job([values, &f](){ f(values); });
|
||||
f(values);
|
||||
//Increment counters
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
@@ -169,7 +170,7 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
ranges.push_back(mp->get_space());
|
||||
// iterate over parameters
|
||||
tune_res_t best;
|
||||
size_t nthreads = 1;
|
||||
size_t nthreads = 4;
|
||||
std::mutex mutex;
|
||||
loop_nest<unsigned>(ranges, [&](const std::vector<unsigned> params){
|
||||
std::map<ir::value*, std::vector<std::string>> errors;
|
||||
|
Reference in New Issue
Block a user