[dnn/conv] minor cleaning
This commit is contained in:
@@ -10,7 +10,7 @@ int main() {
|
|||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
auto context = triton::driver::backend::contexts::get_default();
|
auto context = triton::driver::backend::contexts::get_default();
|
||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
|
||||||
// initialization
|
// initialization
|
||||||
int32_t B = 4, NF = 32;
|
int32_t B = 4, NF = 32;
|
||||||
int32_t D = 1, H = 24, W = 240;
|
int32_t D = 1, H = 24, W = 240;
|
||||||
@@ -66,7 +66,7 @@ int main() {
|
|||||||
return configuration.get_nflops() / ts * 1e-3;
|
return configuration.get_nflops() / ts * 1e-3;
|
||||||
};
|
};
|
||||||
std::string src = configuration.src();
|
std::string src = configuration.src();
|
||||||
jit.autotune("conv", src.c_str(), benchmark);
|
// jit.autotune("conv", src.c_str(), benchmark);
|
||||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||||
|
@@ -211,10 +211,18 @@ public:
|
|||||||
kernel->setArg(5, K_);
|
kernel->setArg(5, K_);
|
||||||
kernel->setArg(6, AH_);
|
kernel->setArg(6, AH_);
|
||||||
kernel->setArg(7, AW_);
|
kernel->setArg(7, AW_);
|
||||||
kernel->setArg(8, BH_);
|
if(ty_ == WGRAD){
|
||||||
kernel->setArg(9, BW_);
|
kernel->setArg(8, CH_);
|
||||||
kernel->setArg(10, CH_);
|
kernel->setArg(9, CW_);
|
||||||
kernel->setArg(11, CW_);
|
kernel->setArg(10, BH_);
|
||||||
|
kernel->setArg(11, BW_);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
kernel->setArg(8, BH_);
|
||||||
|
kernel->setArg(9, BW_);
|
||||||
|
kernel->setArg(10, CH_);
|
||||||
|
kernel->setArg(11, CW_);
|
||||||
|
}
|
||||||
kernel->setArg(12, ld_a_[0]);
|
kernel->setArg(12, ld_a_[0]);
|
||||||
kernel->setArg(13, ld_a_[1]);
|
kernel->setArg(13, ld_a_[1]);
|
||||||
kernel->setArg(14, ld_a_[2]);
|
kernel->setArg(14, ld_a_[2]);
|
||||||
@@ -360,8 +368,8 @@ public:
|
|||||||
fp32 *c,
|
fp32 *c,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 AH, int32 AW,
|
int32 AH, int32 AW,
|
||||||
int32 CH, int32 CW,
|
|
||||||
int32 BH, int32 BW,
|
int32 BH, int32 BW,
|
||||||
|
int32 CH, int32 CW,
|
||||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||||
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
|
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
|
||||||
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
|
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
|
||||||
|
Reference in New Issue
Block a user