removed shared conflicts for 8x32x4 and 32x8x4 configurations

This commit is contained in:
Philippe Tillet
2019-06-13 17:51:54 -07:00
parent 21a9b92c87
commit 36e3667a9a
3 changed files with 11 additions and 11 deletions

View File

@@ -119,7 +119,7 @@ class BlockSparseGemmOp : public OpKernel {
// just-in-time compile source-code
// jit.autotune("matmul", src, benchmark);
// jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 1, 4, 2, 2, 8, 32, 8, 1});
// jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");

View File

@@ -6,7 +6,7 @@ data_files_path = tf.resource_loader.get_data_files_path()
library_dir = os.path.dirname(os.path.realpath(__file__))
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
M, N, K = 8192, 8192, 8192
M, N, K = 256, 256, 256
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096])
@@ -30,9 +30,9 @@ result = sess.run([c], feed_dict = {locks: np.zeros(4096),
# min_iters=100)
#print(end - start)
#print(2*M*N*K / (end - start) * 1e-12)
#hresult = np.dot(ha.T, hb).T
#dif = np.abs(result - hresult)
#print("dif: %f" % np.max(dif))
hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
print("dif: %f" % np.max(dif))
#np.savetxt("dif.txt", dif, fmt="%5.2f")
#np.savetxt("gpu.txt", result, fmt="%5.2f")

View File

@@ -510,16 +510,16 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
/* intra warp offset */
// offset of quad in pair
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_0_));
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_1_));
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_));
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_));
// Quad pair id
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
// Quad pair offset
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(8 * pack_size_0_));
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(8 * pack_size_1_));
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
/* inter warp offset */
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
@@ -557,8 +557,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
for(unsigned pack = 0; pack < num_packs_1_; pack++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_ + 1)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
}
/* axes */