more performance optimizations

This commit is contained in:
Philippe Tillet
2019-06-28 17:04:07 -07:00
parent a567f3f8a8
commit ab1afbf082
3 changed files with 37 additions and 30 deletions

View File

@@ -14,9 +14,9 @@ int main() {
triton::jit jit(context);
// initialization
int32_t R = 3, S = 3;
int32_t BS = 4, F = 128;
int32_t BS = 4, F = 512;
int32_t H = 32, W = 32;
int32_t C = 128;
int32_t C = 512;
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
@@ -68,12 +68,12 @@ int main() {
// shift
std::vector<unsigned> params = {
4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8
32, 2, 128, 16, 2, 128, 16, 8, 2, 2, 4, 2, 8, 8
};
std::ostringstream oss;
shift.src(oss);
std::string src = oss.str();
// jit.autotune("shift", src.c_str(), benchmark);
jit.autotune("shift", src.c_str(), benchmark);
jit.add_module("shift", src.c_str(), params);
triton::driver::kernel* kernel = jit.get_function("shift");
triton::jit::launch_information info = jit.get_launch_info("shift");