[dnn/shift] now strictly only shifting the interior

This commit is contained in:
Philippe Tillet
2019-07-16 20:18:48 -07:00
parent ec24e1e7df
commit 07c964919c
6 changed files with 34 additions and 35 deletions

View File

@@ -75,7 +75,7 @@ torch::Tensor shift_common(
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false); triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
// Enqueue // Enqueue
shift.enqueue(&stream, {&a, &b, &c}, true); shift.enqueue(&stream, {&a, &b, &c}, false);
return torchc; return torchc;
} }

View File

@@ -122,7 +122,7 @@ public:
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false); triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
shift.enqueue(stream, {&da, &db, &dc}); shift.enqueue(stream, {&da, &db, &dc}, false);
} }
private: private:

View File

@@ -128,6 +128,14 @@ private:
int32_t CD_; int32_t CD_;
int32_t CH_; int32_t CH_;
int32_t CW_; int32_t CW_;
// interior image size
int32_t IAD_;
int32_t IAH_;
int32_t IAW_;
// interior activation size
int32_t ICD_;
int32_t ICH_;
int32_t ICW_;
// equivalent matmul // equivalent matmul
int32_t M_; int32_t M_;
int32_t N_; int32_t N_;

View File

@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
} }
else { else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4); ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
} }
} }
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
continue; continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
*params_.at(i).at("nts.d0") = *tmp; *params_.at(i).at("nts.d0") = *tmp;
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -79,10 +79,15 @@ shift::shift(int B, int C,
default: default:
throw std::runtime_error("unsupported input layout"); throw std::runtime_error("unsupported input layout");
} }
IAD_ = AD_ - 2*(BD_/2);
IAH_ = AH_ - 2*(BH_/2);
IAW_ = AW_ - 2*(BW_/2);
ICD_ = IAD_ / stride_d_;
ICH_ = IAH_ / stride_h_;
ICW_ = IAW_ / stride_w_;
// Equivalent matmul // Equivalent matmul
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2); M_ = B_*ICH_*ICW_;
if(M_ == 0)
throw std::runtime_error("unsupported input shapes - no interior !");
N_ = F_; N_ = F_;
K_ = C_; K_ = C_;
// transpose // transpose
@@ -247,21 +252,21 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(18, ldc_h_); kernel->setArg(18, ldc_h_);
kernel->setArg(19, ldc_f_); kernel->setArg(19, ldc_f_);
kernel->setArg(20, B_); kernel->setArg(20, B_);
kernel->setArg(21, AH_); kernel->setArg(21, IAH_);
kernel->setArg(22, AW_); kernel->setArg(22, IAW_);
kernel->setArg(23, BH_); kernel->setArg(23, BH_);
kernel->setArg(24, BW_); kernel->setArg(24, BW_);
kernel->setArg(25, CH_); kernel->setArg(25, ICH_);
kernel->setArg(26, CW_); kernel->setArg(26, ICW_);
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_); kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
kernel->setArg(28, (int32_t)grid[0]); kernel->setArg(28, (int32_t)grid[0]);
kernel->setArg(29, (int32_t)grid[1]); kernel->setArg(29, (int32_t)grid[1]);
kernel->setArg(30, (int32_t)grid[2]); kernel->setArg(30, (int32_t)grid[2]);
if(locks_) if(locks_)
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4); ((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
if(op_ == BPROP){ if(op_ == FPROP || op_ == BPROP){
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4; size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes); ((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
} }
stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
} }
@@ -290,33 +295,18 @@ void shift::triton_c_src(std::ostream &os) const {
return R"( return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB; int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB; int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w; int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)"; int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)";
} }
else { else {
return R"( return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW; int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w; int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h; int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)"; int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
} }
}; };
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
std::string result;
if(shift_edge_h_)
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
else
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
if(shift_edge_w_)
result += "int1 interiorw[" + sz0 + "] = 1;";
else
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
result += R"(
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
return result;
};
std::string result = std::string result =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TM = {16, 32, 64, 128};
@@ -506,8 +496,8 @@ if(op_ == WGRAD){
if(op_ == BPROP){ if(op_ == BPROP){
result += R"( result += R"(
__constant__ int32* pd[TN] = delta_a + ryc; __constant__ int32* pd[TN] = delta_a + ryc;
pc = pc + (*pd)[newaxis, :]; )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
@checkc *pc = c; @checkc *shift_pc = c;
)"; )";
} }
else{ else{

View File

@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
// size_t current = 0; // size_t current = 0;
while(true){ while(true){
//Execute function //Execute function
pool.enqueue([values, &f](){ f(values); }); // pool.enqueue([values, &f](){ f(values); });
f(values);
while(values[i]++ == ranges[i] - 1){ while(values[i]++ == ranges[i] - 1){
if(i == 0) if(i == 0)
return; return;