removed shared conflicts for 8x32x4 and 32x8x4 configurations
This commit is contained in:
@@ -119,7 +119,7 @@ class BlockSparseGemmOp : public OpKernel {
|
||||
// just-in-time compile source-code
|
||||
// jit.autotune("matmul", src, benchmark);
|
||||
// jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
|
||||
jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});
|
||||
jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 1, 4, 2, 2, 8, 32, 8, 1});
|
||||
// jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 });
|
||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
|
@@ -6,7 +6,7 @@ data_files_path = tf.resource_loader.get_data_files_path()
|
||||
library_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||
|
||||
M, N, K = 8192, 8192, 8192
|
||||
M, N, K = 256, 256, 256
|
||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||
@@ -30,9 +30,9 @@ result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||
# min_iters=100)
|
||||
#print(end - start)
|
||||
#print(2*M*N*K / (end - start) * 1e-12)
|
||||
#hresult = np.dot(ha.T, hb).T
|
||||
#dif = np.abs(result - hresult)
|
||||
#print("dif: %f" % np.max(dif))
|
||||
hresult = np.dot(ha.T, hb).T
|
||||
dif = np.abs(result - hresult)
|
||||
print("dif: %f" % np.max(dif))
|
||||
|
||||
#np.savetxt("dif.txt", dif, fmt="%5.2f")
|
||||
#np.savetxt("gpu.txt", result, fmt="%5.2f")
|
||||
|
@@ -510,16 +510,16 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
|
||||
/* intra warp offset */
|
||||
// offset of quad in pair
|
||||
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_0_));
|
||||
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_1_));
|
||||
Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_));
|
||||
Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_));
|
||||
// Quad pair id
|
||||
Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
|
||||
pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
|
||||
pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
|
||||
// Quad pair offset
|
||||
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(8 * pack_size_0_));
|
||||
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(8 * pack_size_1_));
|
||||
Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
|
||||
Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
|
||||
|
||||
/* inter warp offset */
|
||||
Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0));
|
||||
@@ -557,8 +557,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
for(unsigned pack = 0; pack < num_packs_1_; pack++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++)
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_ + 1)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
||||
idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
||||
}
|
||||
|
||||
/* axes */
|
||||
|
Reference in New Issue
Block a user