diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 183b3f492..cf88693e4 100644 --- a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -119,7 +119,7 @@ class BlockSparseGemmOp : public OpKernel { // just-in-time compile source-code // jit.autotune("matmul", src, benchmark); // jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1}); - jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1}); + jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 1, 4, 2, 2, 8, 32, 8, 1}); // jit.add_module("matmul", src, {8, 8, 128, 16, 8, 128, 2, 2, 2, 2, 16, 32, 8, 1 }); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 86c0bc999..94764e515 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -6,7 +6,7 @@ data_files_path = tf.resource_loader.get_data_files_path() library_dir = os.path.dirname(os.path.realpath(__file__)) module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so')) -M, N, K = 8192, 8192, 8192 +M, N, K = 256, 256, 256 a = tf.placeholder(tf.float16, shape=[M, K]) b = tf.placeholder(tf.float16, shape=[N, K]) locks = tf.placeholder(tf.int32, shape=[4096]) @@ -30,9 +30,9 @@ result = sess.run([c], feed_dict = {locks: np.zeros(4096), # min_iters=100) #print(end - start) #print(2*M*N*K / (end - start) * 1e-12) -#hresult = np.dot(ha.T, hb).T -#dif = np.abs(result - hresult) -#print("dif: %f" % np.max(dif)) +hresult = np.dot(ha.T, hb).T +dif = np.abs(result - hresult) +print("dif: %f" % np.max(dif)) #np.savetxt("dif.txt", dif, fmt="%5.2f") #np.savetxt("gpu.txt", result, fmt="%5.2f") diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 5c1d67bf3..9ba1fa870 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -510,16 +510,16 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id /* intra warp offset */ // offset of quad in pair - Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_0_)); - Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(pack_size_1_)); + Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_0 * pack_size_0_)); + Value *in_pair_off_b = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)), builder.getInt32(fpw_1 * pack_size_1_)); // Quad pair id Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4); pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0)); pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)); // Quad pair offset - Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(8 * pack_size_0_)); - Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(8 * pack_size_1_)); + Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_)); + Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_)); /* inter warp offset */ Value *warp_id_0 = builder.CreateURem(u_warp_id, builder.getInt32(wpt_0)); @@ -557,8 +557,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id for(unsigned pack = 0; pack < num_packs_1_; pack++) for(unsigned jj = 0; jj < pack_size_1_; jj++) for(unsigned j = 0; j < 2; j++){ - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_))); - idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*pack_size_1_ + 1))); + idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); + idx_j.push_back(builder.CreateAdd(offset_c_j, builder.getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); } /* axes */