trying 128-bit loads

This commit is contained in:
Philippe Tillet
2019-06-12 21:07:01 -07:00
parent 1c6372711b
commit d487cf31ce
5 changed files with 46 additions and 44 deletions

View File

@@ -23,7 +23,7 @@ const char* src =
R"( R"(
const tunable int32 TM = {64, 128}; const tunable int32 TM = {64, 128};
const tunable int32 TN = {64, 128}; const tunable int32 TN = {64, 128};
const tunable int32 TK = {32}; const tunable int32 TK = {16};
const tunable int32 GZ = {1}; const tunable int32 GZ = {1};
void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B, void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
@@ -117,7 +117,7 @@ class BlockSparseGemmOp : public OpKernel {
return 2.*M*N*K / ts * 1e-3; return 2.*M*N*K / ts * 1e-3;
}; };
// just-in-time compile source-code // just-in-time compile source-code
// jit.autotune("matmul", src, benchmark); jit.autotune("matmul", src, benchmark);
// jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1}); // jit.add_module("matmul", src, {4, 2, 8, 4, 2, 32, 1, 4, 1, 1, 8, 8, 8, 1});
// jit.add_module("matmul", src, {32, 2, 128, 32, 2, 128, 2, 2, 2, 2, 4, 8, 4, 1}); // jit.add_module("matmul", src, {32, 2, 128, 32, 2, 128, 2, 2, 2, 2, 4, 8, 4, 1});
jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1}); jit.add_module("matmul", src, {16, 4, 128, 16, 4, 128, 2, 2, 2, 2, 8, 32, 8, 1});

View File

@@ -6,7 +6,7 @@ data_files_path = tf.resource_loader.get_data_files_path()
library_dir = os.path.dirname(os.path.realpath(__file__)) library_dir = os.path.dirname(os.path.realpath(__file__))
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so')) module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
M, N, K = 128, 128, 128 M, N, K = 8192, 8192, 8192
a = tf.placeholder(tf.float16, shape=[M, K]) a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K]) b = tf.placeholder(tf.float16, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096]) locks = tf.placeholder(tf.int32, shape=[4096])
@@ -30,10 +30,10 @@ result = sess.run([c], feed_dict = {locks: np.zeros(4096),
# min_iters=100) # min_iters=100)
#print(end - start) #print(end - start)
#print(2*M*N*K / (end - start) * 1e-12) #print(2*M*N*K / (end - start) * 1e-12)
hresult = np.dot(ha.T, hb).T #hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult) #dif = np.abs(result - hresult)
print("dif: %f" % np.max(dif)) #print("dif: %f" % np.max(dif))
np.savetxt("dif.txt", dif, fmt="%5.2f") #np.savetxt("dif.txt", dif, fmt="%5.2f")
np.savetxt("gpu.txt", result, fmt="%5.2f") #np.savetxt("gpu.txt", result, fmt="%5.2f")
np.savetxt("cpu.txt", hresult, fmt="%5.2f") #np.savetxt("cpu.txt", hresult, fmt="%5.2f")

View File

@@ -501,8 +501,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0; unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1; unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
// size of each pack (interleaving) // size of each pack (interleaving)
pack_size_0_ = 2; pack_size_0_ = std::min<unsigned>(num_rep_0, 2);
pack_size_1_ = 2; pack_size_1_ = std::min<unsigned>(num_rep_1, 2);
// number of packs (interleaving) // number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_; num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_; num_packs_1_ = num_rep_1 / pack_size_1_;
@@ -922,8 +922,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
} }
else else
{ {
TA->set_vector_size(4); TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4); TB->set_vector_size(4*pack_size_1_);
TA->set_return_mode(true); TA->set_return_mode(true);
TB->set_return_mode(true); TB->set_return_mode(true);
@@ -955,38 +955,40 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned num_rep_j = shapes[1]->get_value() / stride_rep_j; unsigned num_rep_j = shapes[1]->get_value() / stride_rep_j;
unsigned ld_fc = num_rep_i * 2; unsigned ld_fc = num_rep_i * 2;
for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++) for(unsigned pack_i = 0; pack_i < num_packs_0_; pack_i++)
for(unsigned ii = 0; ii < pack_size_0_; ii++) for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++){
for(unsigned pack_j = 0; pack_j < num_packs_1_; pack_j++)
for(unsigned jj = 0; jj < pack_size_1_; jj++)
for(unsigned K = 0; K < NK; K += 4){ for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K); Value *_K = builder.getInt32(K);
Value *current_offset_a_i = builder.CreateAdd(offset_a_i_, builder.getInt32(pack_i*stride_rep_i*pack_size_0_ + ii*4)); Value *current_offset_a_i = builder.CreateAdd(offset_a_i_, builder.getInt32(pack_i*stride_rep_i*pack_size_0_));
Value *current_offset_b_i = builder.CreateAdd(offset_b_j_, builder.getInt32(pack_j*stride_rep_j*pack_size_1_ + jj*4)); Value *current_offset_b_i = builder.CreateAdd(offset_b_j_, builder.getInt32(pack_j*stride_rep_j*pack_size_1_));
Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k_, _K)}); Value *ha = TA->get_value({current_offset_a_i, builder.CreateAdd(offset_a_k_, _K)});
Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k_, _K)}); Value *hb = TB->get_value({current_offset_b_i, builder.CreateAdd(offset_b_k_, _K)});
Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(0)); for(unsigned ii = 0; ii < pack_size_0_; ii++)
Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(1)); for(unsigned jj = 0; jj < pack_size_1_; jj++){
Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(0)); Value *ha0 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 0));
Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(1)); Value *ha1 = builder.CreateExtractElement(ha, builder.getInt32(ii*pack_size_0_ + 1));
std::vector<size_t> idx = { Value *hb0 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 0));
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc, Value *hb1 = builder.CreateExtractElement(hb, builder.getInt32(jj*pack_size_0_ + 1));
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc, std::vector<size_t> idx = {
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 0)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 1)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc, (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
(pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc (pack_i*2*pack_size_0_ + ii*2 + 0) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc,
}; (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 2)*ld_fc,
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]}); (pack_i*2*pack_size_0_ + ii*2 + 1) + (pack_j*4*pack_size_1_ + jj*4 + 3)*ld_fc
fc[idx[0]] = builder.CreateExtractValue(nc, {0}); };
fc[idx[1]] = builder.CreateExtractValue(nc, {1}); Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]});
fc[idx[2]] = builder.CreateExtractValue(nc, {2}); fc[idx[0]] = builder.CreateExtractValue(nc, {0});
fc[idx[3]] = builder.CreateExtractValue(nc, {3}); fc[idx[1]] = builder.CreateExtractValue(nc, {1});
fc[idx[4]] = builder.CreateExtractValue(nc, {4}); fc[idx[2]] = builder.CreateExtractValue(nc, {2});
fc[idx[5]] = builder.CreateExtractValue(nc, {5}); fc[idx[3]] = builder.CreateExtractValue(nc, {3});
fc[idx[6]] = builder.CreateExtractValue(nc, {6}); fc[idx[4]] = builder.CreateExtractValue(nc, {4});
fc[idx[7]] = builder.CreateExtractValue(nc, {7}); fc[idx[5]] = builder.CreateExtractValue(nc, {5});
fc[idx[6]] = builder.CreateExtractValue(nc, {6});
fc[idx[7]] = builder.CreateExtractValue(nc, {7});
}
}
} }
// write back // write back

View File

@@ -24,9 +24,9 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
if(dynamic_cast<ir::dot_inst*>(user)) if(dynamic_cast<ir::dot_inst*>(user))
if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){ if(params_->get_fragment(user, 0) == tune::HMMA_FRAGMENT_C){
if(x == user->get_operand(0)) if(x == user->get_operand(0))
return 8; return 4;
else else
return 16; return 4;
} }
return 0; return 0;
} }

View File

@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
std::cout << source << std::endl; // std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context); cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code // JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};