[triton/examples/cpp] removed common.hpp helper
This commit is contained in:
@@ -1,59 +0,0 @@
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include "triton/driver/device.h"
|
||||
#include <algorithm>
|
||||
|
||||
template<class T, bool AT, bool BT>
|
||||
void simple_gemm(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
|
||||
for(size_t m = 0; m < M; m++)
|
||||
for(size_t n = 0; n < N; n++){
|
||||
T acc = 0;
|
||||
for(size_t k = 0; k < K; k++)
|
||||
acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
|
||||
c[m + n*M] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void simple_gemm(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
|
||||
if(AT && BT)
|
||||
simple_gemm<T, true, true>(c, a, b, M, N, K);
|
||||
else if(AT && !BT)
|
||||
simple_gemm<T, true, false>(c, a, b, M, N, K);
|
||||
else if(!AT && BT)
|
||||
simple_gemm<T, false, true>(c, a, b, M, N, K);
|
||||
else
|
||||
simple_gemm<T, false, false>(c, a, b, M, N, K);
|
||||
}
|
||||
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
|
||||
int32_t K,
|
||||
std::vector<OUT_DTYPE>& O,
|
||||
const std::vector<IN_DTYPE>& I,
|
||||
const std::vector<IN_DTYPE>& F,
|
||||
const std::vector<int32_t> shift_h,
|
||||
const std::vector<int32_t> shift_w)
|
||||
{
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < H; ++p)
|
||||
for(int32_t q = 0; q < W; ++q)
|
||||
for(int32_t bs = 0; bs < BS; ++bs)
|
||||
for(int32_t k = 0; k < K; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < C; ++c){
|
||||
int32_t h = p + shift_h[c];
|
||||
int32_t w = q + shift_w[c];
|
||||
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
|
||||
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
|
||||
IN_DTYPE b = F[k + c*K];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
|
||||
}
|
||||
}
|
@@ -1,6 +1,5 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include "common.hpp"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
@@ -67,7 +66,7 @@ int main() {
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
|
||||
stream->read(dc, true, 0, hc);
|
||||
simple_gemm<float>(AT, BT, rc, ha, hb, M, N, K);
|
||||
triton::dnn::gemm::cpu_ref<float>(AT, BT, rc, ha, hb, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
|
||||
|
@@ -1,11 +1,41 @@
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
#include "common.hpp"
|
||||
#include "triton/runtime/jit.h"
|
||||
#include "triton/driver/backend.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
// input layout: C, H, W, BS
|
||||
// filter layout: C, K
|
||||
// output layout: K, H, W, BS
|
||||
template<class IN_DTYPE, class OUT_DTYPE>
|
||||
void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS,
|
||||
int32_t K,
|
||||
std::vector<OUT_DTYPE>& O,
|
||||
const std::vector<IN_DTYPE>& I,
|
||||
const std::vector<IN_DTYPE>& F,
|
||||
const std::vector<int32_t> shift_h,
|
||||
const std::vector<int32_t> shift_w)
|
||||
{
|
||||
OUT_DTYPE acc;
|
||||
for(int32_t p = 0; p < H; ++p)
|
||||
for(int32_t q = 0; q < W; ++q)
|
||||
for(int32_t bs = 0; bs < BS; ++bs)
|
||||
for(int32_t k = 0; k < K; ++k)
|
||||
{
|
||||
acc = 0;
|
||||
for(int32_t c = 0; c < C; ++c){
|
||||
int32_t h = p + shift_h[c];
|
||||
int32_t w = q + shift_w[c];
|
||||
bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W);
|
||||
IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0;
|
||||
IN_DTYPE b = F[k + c*K];
|
||||
acc = std::fma(a, b, acc);
|
||||
}
|
||||
O[bs + q*BS + p*BS*W + k*BS*H*W] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
// K = channels
|
||||
// M = batch * height * width
|
||||
// N = number of feature maps
|
||||
|
@@ -14,6 +14,29 @@ public:
|
||||
driver::buffer *locks, int32_t grid_0, int32_t grid_1);
|
||||
static std::vector<unsigned> default_params(bool AT, bool BT);
|
||||
static std::string src(bool AT, bool BT);
|
||||
|
||||
template<class T, bool AT, bool BT>
|
||||
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K){
|
||||
for(size_t m = 0; m < M; m++)
|
||||
for(size_t n = 0; n < N; n++){
|
||||
T acc = 0;
|
||||
for(size_t k = 0; k < K; k++)
|
||||
acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]);
|
||||
c[m + n*M] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static void cpu_ref(bool AT, bool BT, std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b, size_t M, size_t N, size_t K) {
|
||||
if(AT && BT)
|
||||
gemm::cpu_ref<T, true, true>(c, a, b, M, N, K);
|
||||
else if(AT && !BT)
|
||||
gemm::cpu_ref<T, true, false>(c, a, b, M, N, K);
|
||||
else if(!AT && BT)
|
||||
gemm::cpu_ref<T, false, true>(c, a, b, M, N, K);
|
||||
else
|
||||
gemm::cpu_ref<T, false, false>(c, a, b, M, N, K);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user