[dnn/conv] minor cleaning

This commit is contained in:
Philippe Tillet
2019-05-15 11:32:47 -04:00
parent be2ba03382
commit 15a967c81e
2 changed files with 15 additions and 7 deletions

View File

@@ -10,7 +10,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::jit jit(context);
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
// initialization
int32_t B = 4, NF = 32;
int32_t D = 1, H = 24, W = 240;
@@ -66,7 +66,7 @@ int main() {
return configuration.get_nflops() / ts * 1e-3;
};
std::string src = configuration.src();
jit.autotune("conv", src.c_str(), benchmark);
// jit.autotune("conv", src.c_str(), benchmark);
jit.add_module("conv", src.c_str(), configuration.default_params());
triton::driver::kernel* kernel = jit.get_function("conv");
triton::jit::launch_information info = jit.get_launch_info("conv");

View File

@@ -211,10 +211,18 @@ public:
kernel->setArg(5, K_);
kernel->setArg(6, AH_);
kernel->setArg(7, AW_);
kernel->setArg(8, BH_);
kernel->setArg(9, BW_);
kernel->setArg(10, CH_);
kernel->setArg(11, CW_);
if(ty_ == WGRAD){
kernel->setArg(8, CH_);
kernel->setArg(9, CW_);
kernel->setArg(10, BH_);
kernel->setArg(11, BW_);
}
else{
kernel->setArg(8, BH_);
kernel->setArg(9, BW_);
kernel->setArg(10, CH_);
kernel->setArg(11, CW_);
}
kernel->setArg(12, ld_a_[0]);
kernel->setArg(13, ld_a_[1]);
kernel->setArg(14, ld_a_[2]);
@@ -360,8 +368,8 @@ public:
fp32 *c,
int32 M, int32 N, int32 K,
int32 AH, int32 AW,
int32 CH, int32 CW,
int32 BH, int32 BW,
int32 CH, int32 CW,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,