[dnn/conv] minor cleaning
This commit is contained in:
@@ -10,7 +10,7 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
triton::jit jit(context);
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::BPROP;
|
||||
triton::dnn::conv::type ty = triton::dnn::conv::WGRAD;
|
||||
// initialization
|
||||
int32_t B = 4, NF = 32;
|
||||
int32_t D = 1, H = 24, W = 240;
|
||||
@@ -66,7 +66,7 @@ int main() {
|
||||
return configuration.get_nflops() / ts * 1e-3;
|
||||
};
|
||||
std::string src = configuration.src();
|
||||
jit.autotune("conv", src.c_str(), benchmark);
|
||||
// jit.autotune("conv", src.c_str(), benchmark);
|
||||
jit.add_module("conv", src.c_str(), configuration.default_params());
|
||||
triton::driver::kernel* kernel = jit.get_function("conv");
|
||||
triton::jit::launch_information info = jit.get_launch_info("conv");
|
||||
|
@@ -211,10 +211,18 @@ public:
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, AH_);
|
||||
kernel->setArg(7, AW_);
|
||||
kernel->setArg(8, BH_);
|
||||
kernel->setArg(9, BW_);
|
||||
kernel->setArg(10, CH_);
|
||||
kernel->setArg(11, CW_);
|
||||
if(ty_ == WGRAD){
|
||||
kernel->setArg(8, CH_);
|
||||
kernel->setArg(9, CW_);
|
||||
kernel->setArg(10, BH_);
|
||||
kernel->setArg(11, BW_);
|
||||
}
|
||||
else{
|
||||
kernel->setArg(8, BH_);
|
||||
kernel->setArg(9, BW_);
|
||||
kernel->setArg(10, CH_);
|
||||
kernel->setArg(11, CW_);
|
||||
}
|
||||
kernel->setArg(12, ld_a_[0]);
|
||||
kernel->setArg(13, ld_a_[1]);
|
||||
kernel->setArg(14, ld_a_[2]);
|
||||
@@ -360,8 +368,8 @@ public:
|
||||
fp32 *c,
|
||||
int32 M, int32 N, int32 K,
|
||||
int32 AH, int32 AW,
|
||||
int32 CH, int32 CW,
|
||||
int32 BH, int32 BW,
|
||||
int32 CH, int32 CW,
|
||||
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
|
||||
int32 ldb_n, int32 ldb_k, int32 ldb_m, int32 ldb_p, int32 ldb_q,
|
||||
int32 ldc_c, int32 ldc_t, int32 ldc_r, int32 ldc_s, int32 ldc_k,
|
||||
|
Reference in New Issue
Block a user