some cleaning

This commit is contained in:
Philippe Tillet
2019-07-12 20:03:05 -07:00
parent c1c7062914
commit 7512c7ebed
4 changed files with 114 additions and 103 deletions

View File

@@ -9,17 +9,18 @@
#include "triton/external/half.hpp"
int main() {
typedef half_float::half NumericT;
std::string numeric_t_str = "fp16";
typedef float NumericT;
std::string numeric_t_str = "fp32";
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::FPROP;
// initialization
int32_t R = 3, S = 3;
int32_t BS = 4, F = 1024;
int32_t B = 16, F = 4096;
int32_t H = 16, W = 16;
int32_t C = 1024;
int32_t C = 4096;
// random shifts
std::vector<int32_t> shift_h(C);
@@ -29,12 +30,15 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t_str, numeric_t_str,
op, false);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());
std::vector<NumericT> ha(shift.a_size());
std::vector<NumericT> hb(shift.b_size());
std::vector<NumericT> ha(B*C*H*W);
std::vector<NumericT> hb(C*F);
std::vector<float> hc(B*F*H*W);
std::vector<float> rc(hc.size());
// device buffers
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));