trying to remove interior logic

This commit is contained in:
Philippe Tillet
2019-07-16 18:47:50 -07:00
parent 5f6dd23fc2
commit ec24e1e7df
5 changed files with 20 additions and 30 deletions

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::FPROP; auto op = triton::dnn::shift::BPROP;
// initialization // initialization
int32_t R = 3, S = 3; int32_t R = 3, S = 3;
int32_t B = 128, F = 128; int32_t B = 16, F = 4096;
int32_t H = 16, W = 16; int32_t H = 16, W = 16;
int32_t C = 128; int32_t C = 4096;
// random shifts // random shifts
std::vector<int32_t> shift_h(C); std::vector<int32_t> shift_h(C);

View File

@@ -128,6 +128,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n))) print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n))) print(np.max(np.abs(db_t - db_n)))
run_dot() #run_dot()
#run_shift() run_shift()
#run_batchnorm() #run_batchnorm()

View File

@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
continue; continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp; *params_.at(i).at("nts.d0") = *tmp;
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -51,7 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), best.params); jit->add_module(name_.c_str(), src.c_str(), best.params);
} }
else { else {
jit->add_module(name_.c_str(), src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1}); jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
} }
triton::driver::kernel* kernel = jit->get_function(name_.c_str()); triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module()); clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

View File

@@ -80,7 +80,9 @@ shift::shift(int B, int C,
throw std::runtime_error("unsupported input layout"); throw std::runtime_error("unsupported input layout");
} }
// Equivalent matmul // Equivalent matmul
M_ = B_*CH_*CW_; M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
if(M_ == 0)
throw std::runtime_error("unsupported input shapes - no interior !");
N_ = F_; N_ = F_;
K_ = C_; K_ = C_;
// transpose // transpose
@@ -288,14 +290,14 @@ void shift::triton_c_src(std::ostream &os) const {
return R"( return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB; int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB; int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW; int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)"; int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
} }
else { else {
return R"( return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW; int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW; int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH; int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)"; int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
} }
}; };
@@ -370,10 +372,7 @@ if(op_ == FPROP){
int32 offa0[TM, TK] = offxa[:, newaxis]; int32 offa0[TM, TK] = offxa[:, newaxis];
__constant__ int32* pd[TK] = delta_a + rka; __constant__ int32* pd[TK] = delta_a + rka;
multiple_of(4) int32 d[TK] = *pd; multiple_of(4) int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :]; int32 offa1[TM, TK] = d[newaxis, :];)";
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
)" + compute_interior("ra", "TM", "TK") + R"(
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
result += result +=
@@ -415,10 +414,8 @@ if(op_ == WGRAD){
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)" + compute_interior("rb", "TK", "TN") + R"(
int32 incb[TK, TN] = interior ? shift : 0;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)"; int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
} }
/* Main loop */ /* Main loop */
@@ -439,10 +436,7 @@ if(op_ == FPROP){
result += R"( result += R"(
pd = pd + TK; pd = pd + TK;
d = *pd; d = *pd;
offa_interior = d[newaxis, :]; pa = pa + d[newaxis, :];)";
offa_exterior = TK * lda_c;
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
pa = pa + offa;)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
result += R"( result += R"(
@@ -470,9 +464,7 @@ if(op_ == WGRAD){
rbw = rbw * stride_w; rbw = rbw * stride_w;
rbh = rbh * stride_h; rbh = rbh * stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h; offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)" + compute_interior("rb", "TK", "TN") + R"( pb = B + offb0 + offkb[:, newaxis] + shift;)";
incb = interior ? shift : 0;
pb = B + offb0 + offkb[:, newaxis] + incb;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
result += R"( result += R"(
@@ -513,11 +505,9 @@ if(op_ == WGRAD){
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(op_ == BPROP){ if(op_ == BPROP){
result += R"( result += R"(
)" + compute_interior("rc", "TM", "TN") + R"(
__constant__ int32* pd[TN] = delta_a + ryc; __constant__ int32* pd[TN] = delta_a + ryc;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; pc = pc + (*pd)[newaxis, :];
pc = interior ? shift_pc : pc; @checkc *pc = c;
@checkc __atomic_add(pc, c);
)"; )";
} }
else{ else{