[codegen] shift: added sketch for shift-convolution backpropagation
@@ -8,7 +8,7 @@
 
 
 int main() {
-  bool AT = false;
+  bool AT = true;
   bool BT = true;
 
   // initialize default compute device
@@ -16,7 +16,7 @@ int main() {
   triton::jit jit(context);
 
   // matrix multiplication parameters
-  int32_t M = 32768, N = 1024, K = 1024;
+  int32_t M = 1024, N = 1024, K = 1024;
   std::vector<float> hc(M*N);
   std::vector<float> rc(M*N);
   std::vector<float> ha(M*K);
@@ -59,9 +59,9 @@ int main() {
 
 
   // just-in-time compile source-code
-  std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 1, 1);
-  jit.autotune("matmul", src.c_str(), benchmark);
-  jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT));
+  std::string src = triton::dnn::gemm::src(AT, BT, "fp32", "fp32", 4, 4);
+  // jit.autotune("matmul", src.c_str(), benchmark);
+  jit.add_module("matmul", src.c_str(), {8, 16, 4, 2, 16, 8, 4, 2, 2, 4, 2, 8, 8, 1});
   triton::driver::kernel* kernel = jit.get_function("matmul");
   triton::jit::launch_information info = jit.get_launch_info("matmul");
   std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl;
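The hunk above only touches the GEMM example; the shift-convolution backpropagation sketch named in the commit title is not shown here. For reference, below is a minimal CPU sketch of what shift-convolution backpropagation computes. This is not Triton's kernel or API: the NCHW layout, the zero-padding convention, and all names (shift_conv_forward, shift_conv_backward, sh/sw) are assumptions made purely for illustration. The weight gradient reduces to a transposed GEMM over shifted images, which is consistent with the example now exercising AT = true.

// Hedged sketch: naive CPU reference for shift-convolution and its backward
// pass. Layout (NCHW), names, and zero-padding behavior are assumptions;
// this is illustration only, not Triton's implementation.
#include <algorithm>
#include <cstdint>
#include <vector>

// Flat index into an NCHW tensor.
inline int64_t idx(int n, int c, int h, int w, int C, int H, int W) {
  return ((int64_t(n)*C + c)*H + h)*W + w;
}

// Forward: shift each input channel by (sh[c], sw[c]), then mix channels with
// a 1x1 convolution (a K x C matrix multiply at every pixel):
//   y[n,k,p,q] = sum_c w[k,c] * x[n,c,p+sh[c],q+sw[c]]   (zero outside bounds)
void shift_conv_forward(const std::vector<float>& x, const std::vector<float>& w,
                        std::vector<float>& y,
                        const std::vector<int>& sh, const std::vector<int>& sw,
                        int N, int C, int K, int H, int W) {
  for (int n = 0; n < N; ++n)
  for (int k = 0; k < K; ++k)
  for (int p = 0; p < H; ++p)
  for (int q = 0; q < W; ++q) {
    float acc = 0;
    for (int c = 0; c < C; ++c) {
      int i = p + sh[c], j = q + sw[c];
      if (i >= 0 && i < H && j >= 0 && j < W)
        acc += w[k*C + c] * x[idx(n, c, i, j, C, H, W)];
    }
    y[idx(n, k, p, q, K, H, W)] = acc;
  }
}

// Backward: given dy, accumulate gradients.
//   dw[k,c] += dy[n,k,p,q] * x[n,c,p+sh[c],q+sw[c]]   (transposed GEMM)
//   dx[n,c,p+sh[c],q+sw[c]] += w[k,c] * dy[n,k,p,q]   (shift applied in reverse)
void shift_conv_backward(const std::vector<float>& x, const std::vector<float>& w,
                         const std::vector<float>& dy,
                         std::vector<float>& dx, std::vector<float>& dw,
                         const std::vector<int>& sh, const std::vector<int>& sw,
                         int N, int C, int K, int H, int W) {
  std::fill(dx.begin(), dx.end(), 0.f);
  std::fill(dw.begin(), dw.end(), 0.f);
  for (int n = 0; n < N; ++n)
  for (int k = 0; k < K; ++k)
  for (int p = 0; p < H; ++p)
  for (int q = 0; q < W; ++q) {
    float g = dy[idx(n, k, p, q, K, H, W)];
    for (int c = 0; c < C; ++c) {
      int i = p + sh[c], j = q + sw[c];
      if (i >= 0 && i < H && j >= 0 && j < W) {
        dw[k*C + c] += g * x[idx(n, c, i, j, C, H, W)];
        dx[idx(n, c, i, j, C, H, W)] += g * w[k*C + c];
      }
    }
  }
}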