trying interior shift
This commit is contained in:
@@ -7,37 +7,6 @@
|
||||
#include "triton/tools/bench.hpp"
|
||||
#include "triton/dnn/shift.h"
|
||||
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
/// Reference CPU implementation of a shift-convolution.
///
/// For every output element (k, p, q, bs) this accumulates, over all input
/// channels c, the product
///     I[c, p + shift_h[c], q + shift_w[c], bs] * F[c, k]
/// where a shifted read that falls outside [0, H) x [0, W) contributes zero.
///
/// Flattened layouts (last listed dimension is contiguous):
///   - I: C x H x W x BS
///   - F: C x K
///   - O: K x H x W x BS
///
/// @param C        number of input channels
/// @param H        spatial height (shared by input and output)
/// @param W        spatial width (shared by input and output)
/// @param BS       batch size
/// @param K        number of output channels
/// @param O        output buffer; must already hold at least K*H*W*BS elements
/// @param I        input buffer of at least C*H*W*BS elements
/// @param F        filter buffer of at least C*K elements
/// @param shift_h  per-channel vertical shift, at least C entries
/// @param shift_w  per-channel horizontal shift, at least C entries
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
                int32_t K,
                std::vector<OUT_DTYPE>& O,
                const std::vector<IN_DTYPE>& I,
                const std::vector<IN_DTYPE>& F,
                const std::vector<int32_t>& shift_h,   // by reference: avoids a copy per call
                const std::vector<int32_t>& shift_w)
{
  // Offsets are computed in size_t so that large tensors do not overflow
  // 32-bit signed arithmetic. The strides are only ever used inside loops
  // that run when the corresponding extents are positive.
  const std::size_t stride_w  = static_cast<std::size_t>(BS);
  const std::size_t stride_h  = stride_w * static_cast<std::size_t>(W);
  const std::size_t stride_ch = stride_h * static_cast<std::size_t>(H);
  const std::size_t stride_fc = static_cast<std::size_t>(K);

  for(int32_t p = 0; p < H; ++p)
  for(int32_t q = 0; q < W; ++q)
  for(int32_t bs = 0; bs < BS; ++bs)
  for(int32_t k = 0; k < K; ++k)
  {
    OUT_DTYPE acc = 0;
    for(int32_t c = 0; c < C; ++c){
      const int32_t h = p + shift_h[c];
      const int32_t w = q + shift_w[c];
      const bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
      // Out-of-range taps read as zero; the input is only indexed when the
      // shifted coordinates are valid (and therefore non-negative).
      IN_DTYPE a = 0;
      if(in_bounds)
        a = I[static_cast<std::size_t>(bs)
              + static_cast<std::size_t>(w) * stride_w
              + static_cast<std::size_t>(h) * stride_h
              + static_cast<std::size_t>(c) * stride_ch];
      const IN_DTYPE b = F[static_cast<std::size_t>(k)
                           + static_cast<std::size_t>(c) * stride_fc];
      acc = std::fma(a, b, acc);
    }
    O[static_cast<std::size_t>(bs)
      + static_cast<std::size_t>(q) * stride_w
      + static_cast<std::size_t>(p) * stride_h
      + static_cast<std::size_t>(k) * stride_ch] = acc;
  }
}
|
||||
|
||||
int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
@@ -104,7 +73,7 @@ int main() {
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
|
Reference in New Issue
Block a user