[codegen/tune]: added fragmentation types
This commit is contained in:
@@ -20,8 +20,8 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
const char* src =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TM = {16};
|
||||
const tunable int32 TN = {16};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
@@ -126,7 +126,7 @@ class BlockSparseGemmOp : public OpKernel {
|
||||
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
||||
stream->synchronize();
|
||||
// just-in-time compile source-code
|
||||
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
|
||||
jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
|
||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
// launch info
|
||||
|
Reference in New Issue
Block a user