trying to remove interior logic
This commit is contained in:
@@ -14,13 +14,13 @@ int main() {
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
auto op = triton::dnn::shift::BPROP;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t B = 128, F = 128;
|
||||
int32_t B = 16, F = 4096;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 128;
|
||||
int32_t C = 4096;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
|
@@ -128,6 +128,6 @@ def run_batchnorm():
|
||||
print(np.max(np.abs(dg_t - dg_n)))
|
||||
print(np.max(np.abs(db_t - db_n)))
|
||||
|
||||
run_dot()
|
||||
#run_shift()
|
||||
#run_dot()
|
||||
run_shift()
|
||||
#run_batchnorm()
|
||||
|
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
|
||||
continue;
|
||||
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
|
@@ -51,7 +51,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else {
|
||||
jit->add_module(name_.c_str(), src.c_str(), {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 16, 8, 1});
|
||||
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());
|
||||
|
@@ -80,7 +80,9 @@ shift::shift(int B, int C,
|
||||
throw std::runtime_error("unsupported input layout");
|
||||
}
|
||||
// Equivalent matmul
|
||||
M_ = B_*CH_*CW_;
|
||||
M_ = B_*(CH_ - BH_ / 2)*(CW_ - BW_/2);
|
||||
if(M_ == 0)
|
||||
throw std::runtime_error("unsupported input shapes - no interior !");
|
||||
N_ = F_;
|
||||
K_ = C_;
|
||||
// transpose
|
||||
@@ -288,14 +290,14 @@ void shift::triton_c_src(std::ostream &os) const {
|
||||
return R"(
|
||||
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW;)";
|
||||
int32 )" + rx + "w[" + sz + "] = " + rx + R"(wh % CW + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(wh / CW + pad_h;)";
|
||||
}
|
||||
else {
|
||||
return R"(
|
||||
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH;
|
||||
int32 )" + rx + "w[" + sz + "] = " + rkx + R"( % CW + pad_w;
|
||||
int32 )" + rx + "h[" + sz + "] = " + rx + R"(bh % CH + pad_h;
|
||||
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
|
||||
}
|
||||
};
|
||||
@@ -370,10 +372,7 @@ if(op_ == FPROP){
|
||||
int32 offa0[TM, TK] = offxa[:, newaxis];
|
||||
__constant__ int32* pd[TK] = delta_a + rka;
|
||||
multiple_of(4) int32 d[TK] = *pd;
|
||||
int32 offa_interior[TM, TK] = d[newaxis, :];
|
||||
int32 offa_exterior[TM, TK] = rka[newaxis, :] * lda_c;
|
||||
)" + compute_interior("ra", "TM", "TK") + R"(
|
||||
int32 offa1[TM, TK] = interior ? offa_interior : offa_exterior;)";
|
||||
int32 offa1[TM, TK] = d[newaxis, :];)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
result +=
|
||||
@@ -415,10 +414,8 @@ if(op_ == WGRAD){
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
int32 offkb[TK] = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)" + compute_interior("rb", "TK", "TN") + R"(
|
||||
int32 incb[TK, TN] = interior ? shift : 0;
|
||||
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis] + incb;)";
|
||||
int32 offb1[TK, TN] = offkb[:, newaxis] + shift;)";
|
||||
}
|
||||
|
||||
/* Main loop */
|
||||
@@ -439,10 +436,7 @@ if(op_ == FPROP){
|
||||
result += R"(
|
||||
pd = pd + TK;
|
||||
d = *pd;
|
||||
offa_interior = d[newaxis, :];
|
||||
offa_exterior = TK * lda_c;
|
||||
int32 offa[TM, TK] = interior ? offa_interior : offa_exterior;
|
||||
pa = pa + offa;)";
|
||||
pa = pa + d[newaxis, :];)";
|
||||
}
|
||||
if(op_ == BPROP){
|
||||
result += R"(
|
||||
@@ -470,9 +464,7 @@ if(op_ == WGRAD){
|
||||
rbw = rbw * stride_w;
|
||||
rbh = rbh * stride_h;
|
||||
offkb = rbb*ldb_b + rbw*ldb_w + rbh*ldb_h;
|
||||
)" + compute_interior("rb", "TK", "TN") + R"(
|
||||
incb = interior ? shift : 0;
|
||||
pb = B + offb0 + offkb[:, newaxis] + incb;)";
|
||||
pb = B + offb0 + offkb[:, newaxis] + shift;)";
|
||||
}
|
||||
if(op_ == FPROP){
|
||||
result += R"(
|
||||
@@ -513,11 +505,9 @@ if(op_ == WGRAD){
|
||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
|
||||
if(op_ == BPROP){
|
||||
result += R"(
|
||||
)" + compute_interior("rc", "TM", "TN") + R"(
|
||||
__constant__ int32* pd[TN] = delta_a + ryc;
|
||||
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
|
||||
pc = interior ? shift_pc : pc;
|
||||
@checkc __atomic_add(pc, c);
|
||||
pc = pc + (*pd)[newaxis, :];
|
||||
@checkc *pc = c;
|
||||
)";
|
||||
}
|
||||
else{
|
||||
|
Reference in New Issue
Block a user