[triton/examples/cpp] removed common.hpp helper

This commit is contained in:
Philippe Tillet
2019-05-28 14:14:33 -04:00
parent a9d078c06f
commit 8102efc064
4 changed files with 55 additions and 62 deletions

View File

@@ -1,11 +1,41 @@
#include <cstring>
#include <cstdio>
#include "common.hpp"
#include "triton/runtime/jit.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
// input layout: C, H, W, BS
// filter layout: C, K
// output layout: K, H, W, BS
template<class IN_DTYPE, class OUT_DTYPE>
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
int32_t K,
std::vector<OUT_DTYPE>& O,
const std::vector<IN_DTYPE>& I,
const std::vector<IN_DTYPE>& F,
const std::vector<int32_t> shift_h,
const std::vector<int32_t> shift_w)
{
OUT_DTYPE acc;
for(int32_t p = 0; p < H; ++p)
for(int32_t q = 0; q < W; ++q)
for(int32_t bs = 0; bs < BS; ++bs)
for(int32_t k = 0; k < K; ++k)
{
acc = 0;
for(int32_t c = 0; c < C; ++c){
int32_t h = p + shift_h[c];
int32_t w = q + shift_w[c];
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
IN_DTYPE b = F[k + c*K];
acc = std::fma(a, b, acc);
}
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
}
}
// K = channels
// M = batch * height * width
// N = number of feature maps