[dnn/shift] now strictly only shifting the interior

This commit is contained in:
Philippe Tillet
2019-07-16 20:18:48 -07:00
parent ec24e1e7df
commit 07c964919c
6 changed files with 34 additions and 35 deletions

View File

@@ -75,7 +75,7 @@ torch::Tensor shift_common(
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Enqueue
shift.enqueue(&stream, {&a, &b, &c}, true);
shift.enqueue(&stream, {&a, &b, &c}, false);
return torchc;
}

View File

@@ -122,7 +122,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
shift.enqueue(stream, {&da, &db, &dc});
shift.enqueue(stream, {&da, &db, &dc}, false);
}
private:

View File

@@ -128,6 +128,14 @@ private:
int32_t CD_;
int32_t CH_;
int32_t CW_;
// interior image size
int32_t IAD_;
int32_t IAH_;
int32_t IAW_;
// interior activation size
int32_t ICD_;
int32_t ICH_;
int32_t ICW_;
// equivalent matmul
int32_t M_;
int32_t N_;

View File

@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -79,10 +79,15 @@ shift::shift(int B, int C,
default:
throw std::runtime_error("unsupported input layout");
}
IAD_ = AD_ - 2*(BD_/2);
IAH_ = AH_ - 2*(BH_/2);
IAW_ = AW_ - 2*(BW_/2);
ICD_ = IAD_ / stride_d_;
ICH_ = IAH_ / stride_h_;
ICW_ = IAW_ / stride_w_;
// Equivalent matmul
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
if(M_ == 0)
throw std::runtime_error("unsupported input shapes - no interior !");
M_ = B_*ICH_*ICW_;
N_ = F_;
K_ = C_;
// transpose
@@ -247,21 +252,21 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(18, ldc_h_);
kernel->setArg(19, ldc_f_);
kernel->setArg(20, B_);
kernel->setArg(21, AH_);
kernel->setArg(22, AW_);
kernel->setArg(21, IAH_);
kernel->setArg(22, IAW_);
kernel->setArg(23, BH_);
kernel->setArg(24, BW_);
kernel->setArg(25, CH_);
kernel->setArg(26, CW_);
kernel->setArg(25, ICH_);
kernel->setArg(26, ICW_);
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
kernel->setArg(28, (int32_t)grid[0]);
kernel->setArg(29, (int32_t)grid[1]);
kernel->setArg(30, (int32_t)grid[2]);
if(locks_)
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
if(op_ == BPROP){
if(op_ == FPROP || op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
}
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}
@@ -290,33 +295,18 @@ void shift::triton_c_src(std::ostream &os) const {
return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)";
}
else {
return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
}
};
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
std::string result;
if(shift_edge_h_)
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
else
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
if(shift_edge_w_)
result += "int1 interiorw[" + sz0 + "] = 1;";
else
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
result += R"(
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
return result;
};
std::string result =
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -506,8 +496,8 @@ if(op_ == WGRAD){
if(op_ == BPROP){
result += R"(
__constant__ int32* pd[TN] = delta_a + ryc;
pc = pc + (*pd)[newaxis, :];
@checkc *pc = c;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
@checkc *shift_pc = c;
)";
}
else{

View File

@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
// size_t current = 0;
while(true){
//Execute function
pool.enqueue([values, &f](){ f(values); });
// pool.enqueue([values, &f](){ f(values); });
f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;