some cleaning
This commit is contained in:
@@ -9,17 +9,18 @@
|
||||
#include "triton/external/half.hpp"
|
||||
|
||||
int main() {
|
||||
typedef half_float::half NumericT;
|
||||
std::string numeric_t_str = "fp16";
|
||||
typedef float NumericT;
|
||||
std::string numeric_t_str = "fp32";
|
||||
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
auto op = triton::dnn::shift::FPROP;
|
||||
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t BS = 4, F = 1024;
|
||||
int32_t B = 16, F = 4096;
|
||||
int32_t H = 16, W = 16;
|
||||
int32_t C = 1024;
|
||||
int32_t C = 4096;
|
||||
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
@@ -29,12 +30,15 @@ int main() {
|
||||
shift_w[c] = rand() % S - S/2;
|
||||
}
|
||||
// configuration
|
||||
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, 1, 1, shift_h.data(), shift_w.data(), numeric_t_str, numeric_t_str, triton::dnn::shift::BPROP);
|
||||
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
|
||||
shift_h.data(), shift_w.data(),
|
||||
numeric_t_str, numeric_t_str,
|
||||
op, false);
|
||||
// host buffers
|
||||
std::vector<float> hc(shift.c_size());
|
||||
std::vector<float> rc(shift.c_size());
|
||||
std::vector<NumericT> ha(shift.a_size());
|
||||
std::vector<NumericT> hb(shift.b_size());
|
||||
std::vector<NumericT> ha(B*C*H*W);
|
||||
std::vector<NumericT> hb(C*F);
|
||||
std::vector<float> hc(B*F*H*W);
|
||||
std::vector<float> rc(hc.size());
|
||||
// device buffers
|
||||
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
|
||||
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*sizeof(NumericT));
|
||||
|
Reference in New Issue
Block a user