trying to remove interior logic

This commit is contained in:
Philippe Tillet
2019-07-16 18:47:50 -07:00
parent 5f6dd23fc2
commit ec24e1e7df
5 changed files with 20 additions and 30 deletions

View File

@@ -14,13 +14,13 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::FPROP;
auto op = triton::dnn::shift::BPROP;
// initialization
int32_t R = 3, S = 3;
int32_t B = 128, F = 128;
int32_t B = 16, F = 4096;
int32_t H = 16, W = 16;
int32_t C = 128;
int32_t C = 4096;
// random shifts
std::vector<int32_t> shift_h(C);

View File

@@ -128,6 +128,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))
run_dot()
#run_shift()
#run_dot()
run_shift()
#run_batchnorm()

View File

@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -51,7 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

View File

@@ -80,7 +80,9 @@ shift::shift(int B, int C,
throw std::runtime_error("unsupported input layout");
}
// Equivalent matmul
M_ = B_*CH_*CW_;
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
if(M_ == 0)
throw std::runtime_error("unsupported input shapes - no interior !");
N_ = F_;
K_ = C_;
// transpose
@@ -288,14 +290,14 @@ void shift::triton_c_src(std::ostream &os) const {
return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)";
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
}
else {
return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH;
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
}
};
@@ -370,10 +372,7 @@ if(op_ == FPROP){
int32 offa0[TM, TK] = offxa[:, newaxis];
__constant__ int32* pd[TK] = delta_a + rka;
multiple_of(4) int32 d[TK] = *pd;
int32 offa_interior[TM, TK] = d[newaxis, :];
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
)" + compute_interior("ra", "TM", "TK") + R"(
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
int32 offa1[TM, TK] = d[newaxis, :];)";
}
if(op_ == BPROP){
result +=
@@ -415,10 +414,8 @@ if(op_ == WGRAD){
rbw = rbw * stride_w;
rbh = rbh * stride_h;
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)" + compute_interior("rb", "TK", "TN") + R"(
int32 incb[TK, TN] = interior ? shift : 0;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)";
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
}
/* Main loop */
@@ -439,10 +436,7 @@ if(op_ == FPROP){
result += R"(
pd = pd + TK;
d = *pd;
offa_interior = d[newaxis, :];
offa_exterior = TK * lda_c;
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
pa = pa + offa;)";
pa = pa + d[newaxis, :];)";
}
if(op_ == BPROP){
result += R"(
@@ -470,9 +464,7 @@ if(op_ == WGRAD){
rbw = rbw * stride_w;
rbh = rbh * stride_h;
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
)" + compute_interior("rb", "TK", "TN") + R"(
incb = interior ? shift : 0;
pb = B + offb0 + offkb[:, newaxis] + incb;)";
pb = B + offb0 + offkb[:, newaxis] + shift;)";
}
if(op_ == FPROP){
result += R"(
@@ -513,11 +505,9 @@ if(op_ == WGRAD){
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(op_ == BPROP){
result += R"(
)" + compute_interior("rc", "TM", "TN") + R"(
__constant__ int32* pd[TN] = delta_a + ryc;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
pc = interior ? shift_pc : pc;
@checkc __atomic_add(pc, c);
pc = pc + (*pd)[newaxis, :];
@checkc *pc = c;
)";
}
else{