[general] hmma baseline setup

This commit is contained in:
Philippe Tillet
2019-06-05 14:43:38 -07:00
parent 49fcfd6fc7
commit f58c9a4d2b
14 changed files with 50 additions and 33 deletions

View File

@@ -4,7 +4,7 @@ if(${TensorFlow_FOUND})
include_directories("${TF_INC}/tensorflow/include")
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
add_library(tf_blocksparse SHARED dot.cpp)
target_link_libraries(tf_blocksparse tensorflow_framework triton)
endif()

View File

@@ -25,7 +25,8 @@ const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
fp32 *C,
int32 M, int32 N, int32 K,
int32 lda, int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
@@ -39,10 +40,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 rem = K % GZ;
K = select(rz < rem, div - 1, div);
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp32 a[TM, TK] = *pa;
fp32 b[TN, TK] = *pb;
fp16* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
fp16* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
fp16 a[TM, TK] = *pa;
fp16 b[TN, TK] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
@@ -60,10 +61,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp32 a[TM, 1] = checka ? *pa : 0;
fp32 b[TN, 1] = checkb ? *pb : 0;
fp16* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
fp16* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
fp16 a[TM, 1] = checka ? *pa : 0;
fp16 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
@@ -89,13 +90,6 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
}
)";
REGISTER_OP("Dot")
.Input("a: T")
.Input("b: T")
.Input("locks: int32")
.Output("c: T")
.Attr("T: {float}")
;
class BlockSparseGemmOp : public OpKernel {
public:
@@ -126,8 +120,8 @@ class BlockSparseGemmOp : public OpKernel {
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
@@ -160,4 +154,10 @@ class BlockSparseGemmOp : public OpKernel {
private:
};
REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlockSparseGemmOp);
REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), BlockSparseGemmOp);
REGISTER_OP("Dot")
.Input("a: float16")
.Input("b: float16")
.Input("locks: int32")
.Output("c: float32")
;

View File

@@ -3,18 +3,23 @@ import tensorflow as tf
import numpy as np
data_files_path = tf.resource_loader.get_data_files_path()
library_dir = '/home/philippe/Development/triton/build/examples/python/tensorflow'
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
M, N, K = 512, 512, 512
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096])
c = module.block_sparse_mat_mul(a, b, locks)
c = module.dot(a, b, locks)
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
hresult = np.dot(hb.T, ha)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
a: np.random.rand(M, K),
b: np.random.rand(N, K)})
print(result)
a: ha,
b: hb})
print(result - hresult)