[dnn/shift] now strictly only shifting the interior
This commit is contained in:
@@ -75,7 +75,7 @@ torch::Tensor shift_common(
|
|||||||
|
|
||||||
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
triton::driver::cu_buffer c(ctx, (CUdeviceptr)torchc.storage().data(), false);
|
||||||
// Enqueue
|
// Enqueue
|
||||||
shift.enqueue(&stream, {&a, &b, &c}, true);
|
shift.enqueue(&stream, {&a, &b, &c}, false);
|
||||||
return torchc;
|
return torchc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -122,7 +122,7 @@ public:
|
|||||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
|
triton::driver::cu_buffer da(ctx, (CUdeviceptr)tf_a.flat<Eigen::half>().data(), false);
|
||||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
|
triton::driver::cu_buffer db(ctx, (CUdeviceptr)tf_b.flat<Eigen::half>().data(), false);
|
||||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
|
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)tf_c->flat<Eigen::half>().data(), false);
|
||||||
shift.enqueue(stream, {&da, &db, &dc});
|
shift.enqueue(stream, {&da, &db, &dc}, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@@ -128,6 +128,14 @@ private:
|
|||||||
int32_t CD_;
|
int32_t CD_;
|
||||||
int32_t CH_;
|
int32_t CH_;
|
||||||
int32_t CW_;
|
int32_t CW_;
|
||||||
|
// interior image size
|
||||||
|
int32_t IAD_;
|
||||||
|
int32_t IAH_;
|
||||||
|
int32_t IAW_;
|
||||||
|
// interior activation size
|
||||||
|
int32_t ICD_;
|
||||||
|
int32_t ICH_;
|
||||||
|
int32_t ICW_;
|
||||||
// equivalent matmul
|
// equivalent matmul
|
||||||
int32_t M_;
|
int32_t M_;
|
||||||
int32_t N_;
|
int32_t N_;
|
||||||
|
@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
|
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
|
||||||
*params_.at(i).at("nts.d0") = *tmp;
|
*params_.at(i).at("nts.d0") = *tmp;
|
||||||
}
|
}
|
||||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||||
|
@@ -79,10 +79,15 @@ shift::shift(int B, int C,
|
|||||||
default:
|
default:
|
||||||
throw std::runtime_error("unsupported input layout");
|
throw std::runtime_error("unsupported input layout");
|
||||||
}
|
}
|
||||||
|
IAD_ = AD_ - 2*(BD_/2);
|
||||||
|
IAH_ = AH_ - 2*(BH_/2);
|
||||||
|
IAW_ = AW_ - 2*(BW_/2);
|
||||||
|
ICD_ = IAD_ / stride_d_;
|
||||||
|
ICH_ = IAH_ / stride_h_;
|
||||||
|
ICW_ = IAW_ / stride_w_;
|
||||||
|
|
||||||
// Equivalent matmul
|
// Equivalent matmul
|
||||||
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
|
M_ = B_*ICH_*ICW_;
|
||||||
if(M_ == 0)
|
|
||||||
throw std::runtime_error("unsupported input shapes - no interior !");
|
|
||||||
N_ = F_;
|
N_ = F_;
|
||||||
K_ = C_;
|
K_ = C_;
|
||||||
// transpose
|
// transpose
|
||||||
@@ -247,21 +252,21 @@ void shift::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
|||||||
kernel->setArg(18, ldc_h_);
|
kernel->setArg(18, ldc_h_);
|
||||||
kernel->setArg(19, ldc_f_);
|
kernel->setArg(19, ldc_f_);
|
||||||
kernel->setArg(20, B_);
|
kernel->setArg(20, B_);
|
||||||
kernel->setArg(21, AH_);
|
kernel->setArg(21, IAH_);
|
||||||
kernel->setArg(22, AW_);
|
kernel->setArg(22, IAW_);
|
||||||
kernel->setArg(23, BH_);
|
kernel->setArg(23, BH_);
|
||||||
kernel->setArg(24, BW_);
|
kernel->setArg(24, BW_);
|
||||||
kernel->setArg(25, CH_);
|
kernel->setArg(25, ICH_);
|
||||||
kernel->setArg(26, CW_);
|
kernel->setArg(26, ICW_);
|
||||||
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
|
kernel->setArg(27, (num_locks > max_locks_) ? nullptr : locks_);
|
||||||
kernel->setArg(28, (int32_t)grid[0]);
|
kernel->setArg(28, (int32_t)grid[0]);
|
||||||
kernel->setArg(29, (int32_t)grid[1]);
|
kernel->setArg(29, (int32_t)grid[1]);
|
||||||
kernel->setArg(30, (int32_t)grid[2]);
|
kernel->setArg(30, (int32_t)grid[2]);
|
||||||
if(locks_)
|
if(locks_)
|
||||||
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
|
((driver::cu_buffer*)locks_)->set_zero(stream, 2*max_locks_*4);
|
||||||
if(op_ == BPROP){
|
if(op_ == FPROP || op_ == BPROP){
|
||||||
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
size_t c_nbytes = (c_ty_ == "fp16") ? 2 : 4;
|
||||||
((driver::cu_buffer*)c)->set_zero(stream, AH_*AW_*B_*C_*c_nbytes);
|
((driver::cu_buffer*)c)->set_zero(stream, c_size()*c_nbytes);
|
||||||
}
|
}
|
||||||
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
|
||||||
}
|
}
|
||||||
@@ -290,33 +295,18 @@ void shift::triton_c_src(std::ostream &os) const {
|
|||||||
return R"(
|
return R"(
|
||||||
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
|
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
|
||||||
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
|
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
|
||||||
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
|
int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w;
|
||||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
|
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)";
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return R"(
|
return R"(
|
||||||
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
|
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
|
||||||
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
|
int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w;
|
||||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
|
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h;
|
||||||
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
|
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
auto compute_interior = [&](std::string rx, std::string sz0, std::string sz1) {
|
|
||||||
std::string result;
|
|
||||||
if(shift_edge_h_)
|
|
||||||
result += "int1 interiorh[" + sz0 + "] = 1;\n ";
|
|
||||||
else
|
|
||||||
result += "int1 interiorh[" + sz0 + "] = (" + rx + "h >= pad_h) && (" + rx + "h < (AH - pad_h));\n ";
|
|
||||||
if(shift_edge_w_)
|
|
||||||
result += "int1 interiorw[" + sz0 + "] = 1;";
|
|
||||||
else
|
|
||||||
result += "int1 interiorw[" + sz0 + "] = (" + rx + "w >= pad_w) && (" + rx + "w < (AW - pad_w));";
|
|
||||||
result += R"(
|
|
||||||
int1 interior[)" + sz0 + ", " + sz1 + "] = interiorh[:, newaxis] && interiorw[:, newaxis];";
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string result =
|
std::string result =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {16, 32, 64, 128};
|
const tunable int32 TM = {16, 32, 64, 128};
|
||||||
@@ -506,8 +496,8 @@ if(op_ == WGRAD){
|
|||||||
if(op_ == BPROP){
|
if(op_ == BPROP){
|
||||||
result += R"(
|
result += R"(
|
||||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||||
pc = pc + (*pd)[newaxis, :];
|
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||||
@checkc *pc = c;
|
@checkc *shift_pc = c;
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
|
@@ -43,7 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
|
|||||||
// size_t current = 0;
|
// size_t current = 0;
|
||||||
while(true){
|
while(true){
|
||||||
//Execute function
|
//Execute function
|
||||||
pool.enqueue([values, &f](){ f(values); });
|
// pool.enqueue([values, &f](){ f(values); });
|
||||||
|
f(values);
|
||||||
while(values[i]++ == ranges[i] - 1){
|
while(values[i]++ == ranges[i] - 1){
|
||||||
if(i == 0)
|
if(i == 0)
|
||||||
return;
|
return;
|
||||||
|
Reference in New Issue
Block a user