more stuff

This commit is contained in:
Philippe Tillet
2019-06-30 16:55:02 -07:00
parent 9a86bc51e1
commit c172bd518b
9 changed files with 124 additions and 90 deletions

View File

@@ -16,6 +16,7 @@ int main() {
auto context = triton::driver::backend::contexts::get_default();
// initialize just-in-time compiler
triton::jit jit(context);
// initialization
int32_t R = 3, S = 3;
int32_t BS = 32, F = 1024;
@@ -30,7 +31,7 @@ int main() {
shift_w[c] = rand() % S - S/2;
}
// configuration
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str);
triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP);
// host buffers
std::vector<float> hc(shift.c_size());
std::vector<float> rc(shift.c_size());
@@ -58,7 +59,7 @@ int main() {
auto benchmark = [&](triton::driver::kernel* kernel,
triton::jit::launch_information info) {
shift.init(stream, (triton::driver::cu_module*)kernel->module());
// launch info
// launch infoRR
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
@@ -78,7 +79,7 @@ int main() {
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark);
jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");