[dnn/conv] minor cleaning

This commit is contained in:
Philippe Tillet
2019-05-15 11:32:47 -04:00
parent be2ba03382
commit 15a967c81e
2 changed files with 15 additions and 7 deletions

View File

@@ -10,7 +10,7 @@ int main() {
// initialize default compute device // initialize default compute device
auto context = triton::driver::backend::contexts::get_default(); auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context); triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::BPROP; triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
// initialization // initialization
int32_t B = 4, NF = 32; int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240; int32_t D = 1, H = 24, W = 240;
@@ -66,7 +66,7 @@ int main() {
return configuration.get_nflops() / ts * 1e-3; return configuration.get_nflops() / ts * 1e-3;
}; };
std::string src = configuration.src(); std::string src = configuration.src();
jit.autotune("conv", src.c_str(), benchmark); // jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), configuration.default_params()); jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv"); triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv"); triton::jit::launch_information info = jit.get_launch_info("conv");

View File

@@ -211,10 +211,18 @@ public:
kernel->setArg(5, K_); kernel->setArg(5, K_);
kernel->setArg(6, AH_); kernel->setArg(6, AH_);
kernel->setArg(7, AW_); kernel->setArg(7, AW_);
kernel->setArg(8, BH_); if(ty_ == WGRAD){
kernel->setArg(9, BW_); kernel->setArg(8, CH_);
kernel->setArg(10, CH_); kernel->setArg(9, CW_);
kernel->setArg(11, CW_); kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
}
else{
kernel->setArg(8, BH_);
kernel->setArg(9, BW_);
kernel->setArg(10, CH_);
kernel->setArg(11, CW_);
}
kernel->setArg(12, ld_a_[0]); kernel->setArg(12, ld_a_[0]);
kernel->setArg(13, ld_a_[1]); kernel->setArg(13, ld_a_[1]);
kernel->setArg(14, ld_a_[2]); kernel->setArg(14, ld_a_[2]);
@@ -360,8 +368,8 @@ public:
fp32 *c, fp32 *c,
int32 M, int32 N, int32 K, int32 M, int32 N, int32 K,
int32 AH, int32 AW, int32 AH, int32 AW,
int32 CH, int32 CW,
int32 BH, int32 BW, int32 BH, int32 BW,
int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q, int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k, int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,