more performance optimizations
This commit is contained in:
@@ -14,9 +14,9 @@ int main() {
|
||||
triton::jit jit(context);
|
||||
// initialization
|
||||
int32_t R = 3, S = 3;
|
||||
int32_t BS = 4, F = 128;
|
||||
int32_t BS = 4, F = 512;
|
||||
int32_t H = 32, W = 32;
|
||||
int32_t C = 128;
|
||||
int32_t C = 512;
|
||||
// random shifts
|
||||
std::vector<int32_t> shift_h(C);
|
||||
std::vector<int32_t> shift_w(C);
|
||||
@@ -68,12 +68,12 @@ int main() {
|
||||
|
||||
// shift
|
||||
std::vector<unsigned> params = {
|
||||
4, 2, 16, 8, 2, 64, 4, 8, 2, 2, 4, 8, 8
|
||||
32, 2, 128, 16, 2, 128, 16, 8, 2, 2, 4, 2, 8, 8
|
||||
};
|
||||
std::ostringstream oss;
|
||||
shift.src(oss);
|
||||
std::string src = oss.str();
|
||||
// jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.autotune("shift", src.c_str(), benchmark);
|
||||
jit.add_module("shift", src.c_str(), params);
|
||||
triton::driver::kernel* kernel = jit.get_function("shift");
|
||||
triton::jit::launch_information info = jit.get_launch_info("shift");
|
||||
|
Reference in New Issue
Block a user