[codegen/tune]: added fragmentation types

This commit is contained in:
Philippe Tillet
2019-06-06 16:48:32 -07:00
parent f58c9a4d2b
commit cdf5a0d011
7 changed files with 101 additions and 35 deletions

View File

@@ -20,8 +20,8 @@ using GPUDevice = Eigen::GpuDevice;
const char* src =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {16};
const tunable int32 TN = {16};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
@@ -126,7 +126,7 @@ class BlockSparseGemmOp : public OpKernel {
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
// just-in-time compile source-code
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
// launch info

View File

@@ -20,6 +20,6 @@ hresult = np.dot(hb.T, ha)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
a: ha,
b: hb})
a: ha,
b: hb})
print(result - hresult)