[dnn/shift] now strictly only shifting the interior
This commit is contained in:
@@ -75,7 +75,7 @@ torch::Tensor shift_common(
|
||||
|
||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||
// Enqueue
|
||||
shift.enqueue(&stream, {&a, &b, &c}, true);
|
||||
shift.enqueue(&stream, {&a, &b, &c}, false);
|
||||
return torchc;
|
||||
}
|
||||
|
||||
|
@@ -122,7 +122,7 @@ public:
|
||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
|
||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
|
||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
|
||||
shift.enqueue(stream, {&da, &db, &dc});
|
||||
shift.enqueue(stream, {&da, &db, &dc}, false);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@@ -128,6 +128,14 @@ private:
|
||||
int32_t CD_;
|
||||
int32_t CH_;
|
||||
int32_t CW_;
|
||||
// interior image size
|
||||
int32_t IAD_;
|
||||
int32_t IAH_;
|
||||
int32_t IAW_;
|
||||
// interior activation size
|
||||
int32_t ICD_;
|
||||
int32_t ICH_;
|
||||
int32_t ICW_;
|
||||
// equivalent matmul
|
||||
int32_t M_;
|
||||
int32_t N_;
|
||||
|
@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
|
||||
continue;
|
||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
|
@@ -79,10 +79,15 @@ shift::shift(int B, int C,
|
||||
default:
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
IAD_ = AD_ - 2*(BD_/2);
|
||||
IAH_ = AH_ - 2*(BH_/2);
|
||||
IAW_ = AW_ - 2*(BW_/2);
|
||||
ICD_ = IAD_ / stride_d_;
|
||||
ICH_ = IAH_ / stride_h_;
|
||||
ICW_ = IAW_ / stride_w_;
|
||||
|
||||
// Equivalent matmul
|
||||
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
|
||||
if(M_ == 0)
|
||||
throw std::runtime_error("unsupported input shapes - no interior !");
|
||||
M_ = B_*ICH_*ICW_;
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// transpose
|
||||
@@ -247,21 +252,21 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(18, ldc_h_);
|
||||
kernel->setArg(19, ldc_f_);
|
||||
kernel->setArg(20, B_);
|
||||
kernel->setArg(21, AH_);
|
||||
kernel->setArg(22, AW_);
|
||||
kernel->setArg(21, IAH_);
|
||||
kernel->setArg(22, IAW_);
|
||||
kernel->setArg(23, BH_);
|
||||
kernel->setArg(24, BW_);
|
||||
kernel->setArg(25, CH_);
|
||||
kernel->setArg(26, CW_);
|
||||
kernel->setArg(25, ICH_);
|
||||
kernel->setArg(26, ICW_);
|
||||
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
|
||||
kernel->setArg(28, (int32_t)grid[0]);
|
||||
kernel->setArg(29, (int32_t)grid[1]);
|
||||
kernel->setArg(30, (int32_t)grid[2]);
|
||||
if(locks_)
|
||||
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
|
||||
if(op_ == BPROP){
|
||||
if(op_ == FPROP || op_ == BPROP){
|
||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
||||
((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
|
||||
}
|
||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||
}
|
||||
@@ -290,33 +295,18 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
return R"(
|
||||
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
|
||||
int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)";
|
||||
}
|
||||
else {
|
||||
return R"(
|
||||
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
|
||||
int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
|
||||
}
|
||||
};
|
||||
|
||||
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
|
||||
std::string result;
|
||||
if(shift_edge_h_)
|
||||
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
|
||||
else
|
||||
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
|
||||
if(shift_edge_w_)
|
||||
result += "int1 interiorw[" + sz0 + "] = 1;";
|
||||
else
|
||||
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
|
||||
result += R"(
|
||||
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
|
||||
return result;
|
||||
};
|
||||
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
@@ -506,8 +496,8 @@ if(op_ == WGRAD){
|
||||
if(op_ == BPROP){
|
||||
result += R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||
pc = pc + (*pd)[newaxis, :];
|
||||
@checkc *pc = c;
|
||||
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
@checkc *shift_pc = c;
|
||||
)";
|
||||
}
|
||||
else{
|
||||
|
@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
|
||||
// size_t current = 0;
|
||||
while(true){
|
||||
//Execute function
|
||||
pool.enqueue([values, &f](){ f(values); });
|
||||
// pool.enqueue([values, &f](){ f(values); });
|
||||
f(values);
|
||||
while(values[i]++ == ranges[i] - 1){
|
||||
if(i == 0)
|
||||
return;
|
||||
|
Reference in New Issue
Block a user