From f58c9a4d2baf635b53f278ef0115687b22bfacf4 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 5 Jun 2019 14:43:38 -0700 Subject: [PATCH] [general] hmma baseline setup --- examples/cpp/dot.cpp | 4 +-- examples/python/tensorflow/CMakeLists.txt | 2 +- examples/python/tensorflow/dot.cpp | 38 +++++++++++------------ examples/python/tensorflow/run.py | 19 +++++++----- include/triton/ir/builder.h | 1 + include/triton/lang/ops.h | 2 +- include/triton/lang/parser.y | 3 +- include/triton/lang/scanner.l | 1 + include/triton/runtime/jit.h | 1 + lib/driver/module.cpp | 2 -- lib/ir/builder.cpp | 3 ++ lib/ir/ir.cpp | 0 lib/lang/declaration.cpp | 1 + lib/lang/node.cpp | 6 ++++ 14 files changed, 50 insertions(+), 33 deletions(-) delete mode 100644 lib/ir/ir.cpp diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index abaed5ff3..8b7559f55 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -16,7 +16,7 @@ int main() { triton::jit jit(context); // matrix multiplication parameters - int32_t M = 512, N = 512, K = 512; + int32_t M = 2048, N = 2048, K = 2048; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); @@ -60,7 +60,7 @@ int main() { // just-in-time compile source-code std::string src = triton::dnn::gemm::src(AT, BT); -// jit.autotune("matmul",src.c_str(), benchmark); + jit.autotune("matmul",src.c_str(), benchmark); jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT)); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); diff --git a/examples/python/tensorflow/CMakeLists.txt b/examples/python/tensorflow/CMakeLists.txt index 1ce055203..6c8a6f008 100644 --- a/examples/python/tensorflow/CMakeLists.txt +++ b/examples/python/tensorflow/CMakeLists.txt @@ -4,7 +4,7 @@ if(${TensorFlow_FOUND}) include_directories("${TF_INC}/tensorflow/include") include_directories("${CUDA_HOME}/include") link_directories(${TF_LIB}) - add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}) add_library(tf_blocksparse SHARED dot.cpp) target_link_libraries(tf_blocksparse tensorflow_framework triton) endif() diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 70ab8c386..c87b054fa 100644 --- a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -25,7 +25,8 @@ const tunable int32 TN = {16, 32, 64, 128}; const tunable int32 TK = {8}; const tunable int32 GZ = {1}; -void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, +void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B, + fp32 *C, int32 M, int32 N, int32 K, int32 lda, int32 ldb, int32 ldc, int32 *locks, int32 grid0, int32 grid1) { @@ -39,10 +40,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, int32 rem = K % GZ; K = select(rz < rem, div - 1, div); int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); - fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; - fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; - fp32 a[TM, TK] = *pa; - fp32 b[TN, TK] = *pb; + fp16* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; + fp16* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; + fp16 a[TM, TK] = *pa; + fp16 b[TN, TK] = *pb; int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; last_a = last_a / TK * TK; @@ -60,10 +61,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, for(int32 k = bound; k > 0; k = k - 1){ int1 checka[TM, 1] = rxc[:, newaxis] < M; int1 checkb[TN, 1] = ryc[:, newaxis] < N; - fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; - fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; - fp32 a[TM, 1] = checka ? *pa : 0; - fp32 b[TN, 1] = checkb ? *pb : 0; + fp16* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; + fp16* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; + fp16 a[TM, 1] = checka ? *pa : 0; + fp16 b[TN, 1] = checkb ? *pb : 0; c = dot(a, trans(b), c); } int32 ridx = get_range_id(0); @@ -89,13 +90,6 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, } )"; -REGISTER_OP("Dot") - .Input("a: T") - .Input("b: T") - .Input("locks: int32") - .Output("c: T") - .Attr("T: {float}") -; class BlockSparseGemmOp : public OpKernel { public: @@ -126,8 +120,8 @@ class BlockSparseGemmOp : public OpKernel { // initialize default compute device triton::jit jit(ctx); // matrix multiplication parameters - triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); - triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); + triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat().data(), false); + triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat().data(), false); triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat().data(), false); triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat().data(), false); stream->synchronize(); @@ -160,4 +154,10 @@ class BlockSparseGemmOp : public OpKernel { private: }; -REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU).TypeConstraint("T"), BlockSparseGemmOp); +REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), BlockSparseGemmOp); +REGISTER_OP("Dot") + .Input("a: float16") + .Input("b: float16") + .Input("locks: int32") + .Output("c: float32") +; diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py index 5a721def9..194e6e9ed 100644 --- a/examples/python/tensorflow/run.py +++ b/examples/python/tensorflow/run.py @@ -3,18 +3,23 @@ import tensorflow as tf import numpy as np data_files_path = tf.resource_loader.get_data_files_path() -library_dir = '/home/philippe/Development/triton/build/examples/python/tensorflow' +library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow' module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so')) M, N, K = 512, 512, 512 -a = tf.placeholder(tf.float32, shape=[M, K]) -b = tf.placeholder(tf.float32, shape=[N, K]) +a = tf.placeholder(tf.float16, shape=[M, K]) +b = tf.placeholder(tf.float16, shape=[N, K]) locks = tf.placeholder(tf.int32, shape=[4096]) -c = module.block_sparse_mat_mul(a, b, locks) +c = module.dot(a, b, locks) +# Reference +ha = np.random.rand(M, K).astype(np.float16) +hb = np.random.rand(N, K).astype(np.float16) +hresult = np.dot(hb.T, ha) + # Run sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) result = sess.run([c], feed_dict = {locks: np.zeros(4096), - a: np.random.rand(M, K), - b: np.random.rand(N, K)}) -print(result) + a: ha, + b: hb}) +print(result - hresult) diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 51dd656d3..48b1d172d 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -40,6 +40,7 @@ public: type *get_int16_ty(); type *get_int32_ty(); type *get_int64_ty(); + type *get_half_ty(); type *get_float_ty(); type *get_double_ty(); // Insert diff --git a/include/triton/lang/ops.h b/include/triton/lang/ops.h index 9328be921..38fc200bf 100644 --- a/include/triton/lang/ops.h +++ b/include/triton/lang/ops.h @@ -35,7 +35,7 @@ enum TYPE_T{ VOID_T, UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T, INT1_T, INT8_T, INT16_T, INT32_T, INT64_T, - FLOAT32_T, FLOAT64_T + FLOAT16_T, FLOAT32_T, FLOAT64_T }; enum STORAGE_SPEC_T{ diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 66d7c1770..18fc3bbed 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -52,7 +52,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN %token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN %token XOR_ASSIGN OR_ASSIGN TYPE_NAME -%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 +%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 %token IF ELSE FOR CONTINUE WHILE %token NEWAXIS ELLIPSIS AT %token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ALLOC_CONST @@ -77,6 +77,7 @@ type_specifier | INT16 { $$ = new token(INT16_T); } | INT32 { $$ = new token(INT32_T); } | INT64 { $$ = new token(INT64_T); } + | FP16 { $$ = new token(FLOAT16_T); } | FP32 { $$ = new token(FLOAT32_T); } | FP64 { $$ = new token(FLOAT64_T); } ; diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index b1160fb1c..a2cd50922 100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -38,6 +38,7 @@ using triton::lang::return_void; "int16" { return return_impl(INT16, yytext); } "int32" { return return_impl(INT32, yytext); } "int64" { return return_impl(INT64, yytext); } +"fp16" { return return_impl(FP16, yytext); } "fp32" { return return_impl(FP32, yytext); } "fp64" { return return_impl(FP64, yytext); } "..." { return return_impl(ELLIPSIS, yytext); } diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index 476d25f5a..424a00e6d 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -65,6 +65,7 @@ public: void target_independent(ir::module &module) { optimize_dot.run(module); optimize_trans.run(module); +// ir::print(module, std::cout); } void target_dependent(ir::module &module) { diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 3f595b318..19c9baccb 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -247,8 +247,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { return std::string(buffer.begin(), buffer.end()); } - - cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index d82ee2c3b..5de366045 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -56,6 +56,9 @@ type *builder::get_int32_ty() type *builder::get_int64_ty() { return type::get_int64_ty(ctx_); } +type *builder::get_half_ty() +{ return type::get_half_ty(ctx_); } + type *builder::get_float_ty() { return type::get_float_ty(ctx_); } diff --git a/lib/ir/ir.cpp b/lib/ir/ir.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/lang/declaration.cpp b/lib/lang/declaration.cpp index d4a73ef00..46fa6b597 100644 --- a/lib/lang/declaration.cpp +++ b/lib/lang/declaration.cpp @@ -21,6 +21,7 @@ ir::type* typed_declaration_specifier::type(ir::module *mod) const { case INT16_T: return ir::type::get_int16_ty(ctx); case INT32_T: return ir::type::get_int32_ty(ctx); case INT64_T: return ir::type::get_int64_ty(ctx); + case FLOAT16_T: return ir::type::get_half_ty(ctx); case FLOAT32_T: return ir::type::get_float_ty(ctx); case FLOAT64_T: return ir::type::get_double_ty(ctx); default: throw std::runtime_error("unreachable"); diff --git a/lib/lang/node.cpp b/lib/lang/node.cpp index f25a5fdf5..418a86fca 100644 --- a/lib/lang/node.cpp +++ b/lib/lang/node.cpp @@ -69,6 +69,12 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, to_convert = explicit_cast(builder, to_convert, builder.get_float_ty()); is_float = true; } + // One operand is half + else if(left_ty->is_half_ty() || right_ty->is_half_ty()){ + ir::value *&to_convert = left_ty->is_half_ty()?rhs:lhs; + to_convert = explicit_cast(builder, to_convert, builder.get_half_ty()); + is_float = true; + } // Both operands are integers else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){ is_int = true;