[general] hmma baseline setup
This commit is contained in:
@@ -16,7 +16,7 @@ int main() {
|
|||||||
triton::jit jit(context);
|
triton::jit jit(context);
|
||||||
|
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
int32_t M = 512, N = 512, K = 512;
|
int32_t M = 2048, N = 2048, K = 2048;
|
||||||
std::vector<float> hc(M*N);
|
std::vector<float> hc(M*N);
|
||||||
std::vector<float> rc(M*N);
|
std::vector<float> rc(M*N);
|
||||||
std::vector<float> ha(M*K);
|
std::vector<float> ha(M*K);
|
||||||
@@ -60,7 +60,7 @@ int main() {
|
|||||||
|
|
||||||
// just-in-time compile source-code
|
// just-in-time compile source-code
|
||||||
std::string src = triton::dnn::gemm::src(AT, BT);
|
std::string src = triton::dnn::gemm::src(AT, BT);
|
||||||
// jit.autotune("matmul",src.c_str(), benchmark);
|
jit.autotune("matmul",src.c_str(), benchmark);
|
||||||
jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT));
|
jit.add_module("matmul", src.c_str(), triton::dnn::gemm::default_params(AT, BT));
|
||||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||||
|
@@ -4,7 +4,7 @@ if(${TensorFlow_FOUND})
|
|||||||
include_directories("${TF_INC}/tensorflow/include")
|
include_directories("${TF_INC}/tensorflow/include")
|
||||||
include_directories("${CUDA_HOME}/include")
|
include_directories("${CUDA_HOME}/include")
|
||||||
link_directories(${TF_LIB})
|
link_directories(${TF_LIB})
|
||||||
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
|
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI})
|
||||||
add_library(tf_blocksparse SHARED dot.cpp)
|
add_library(tf_blocksparse SHARED dot.cpp)
|
||||||
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
target_link_libraries(tf_blocksparse tensorflow_framework triton)
|
||||||
endif()
|
endif()
|
||||||
|
@@ -25,7 +25,8 @@ const tunable int32 TN = {16, 32, 64, 128};
|
|||||||
const tunable int32 TK = {8};
|
const tunable int32 TK = {8};
|
||||||
const tunable int32 GZ = {1};
|
const tunable int32 GZ = {1};
|
||||||
|
|
||||||
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
|
||||||
|
fp32 *C,
|
||||||
int32 M, int32 N, int32 K,
|
int32 M, int32 N, int32 K,
|
||||||
int32 lda, int32 ldb, int32 ldc,
|
int32 lda, int32 ldb, int32 ldc,
|
||||||
int32 *locks, int32 grid0, int32 grid1) {
|
int32 *locks, int32 grid0, int32 grid1) {
|
||||||
@@ -39,10 +40,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
|||||||
int32 rem = K % GZ;
|
int32 rem = K % GZ;
|
||||||
K = select(rz < rem, div - 1, div);
|
K = select(rz < rem, div - 1, div);
|
||||||
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
|
int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem);
|
||||||
fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
|
fp16* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis];
|
||||||
fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
|
fp16* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis];
|
||||||
fp32 a[TM, TK] = *pa;
|
fp16 a[TM, TK] = *pa;
|
||||||
fp32 b[TN, TK] = *pb;
|
fp16 b[TN, TK] = *pb;
|
||||||
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
|
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
|
||||||
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
|
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
|
||||||
last_a = last_a / TK * TK;
|
last_a = last_a / TK * TK;
|
||||||
@@ -60,10 +61,10 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
|||||||
for(int32 k = bound; k > 0; k = k - 1){
|
for(int32 k = bound; k > 0; k = k - 1){
|
||||||
int1 checka[TM, 1] = rxc[:, newaxis] < M;
|
int1 checka[TM, 1] = rxc[:, newaxis] < M;
|
||||||
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
|
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
|
||||||
fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
|
fp16* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis];
|
||||||
fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
|
fp16* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis];
|
||||||
fp32 a[TM, 1] = checka ? *pa : 0;
|
fp16 a[TM, 1] = checka ? *pa : 0;
|
||||||
fp32 b[TN, 1] = checkb ? *pb : 0;
|
fp16 b[TN, 1] = checkb ? *pb : 0;
|
||||||
c = dot(a, trans(b), c);
|
c = dot(a, trans(b), c);
|
||||||
}
|
}
|
||||||
int32 ridx = get_range_id(0);
|
int32 ridx = get_range_id(0);
|
||||||
@@ -89,13 +90,6 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
REGISTER_OP("Dot")
|
|
||||||
.Input("a: T")
|
|
||||||
.Input("b: T")
|
|
||||||
.Input("locks: int32")
|
|
||||||
.Output("c: T")
|
|
||||||
.Attr("T: {float}")
|
|
||||||
;
|
|
||||||
|
|
||||||
class BlockSparseGemmOp : public OpKernel {
|
class BlockSparseGemmOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@@ -126,8 +120,8 @@ class BlockSparseGemmOp : public OpKernel {
|
|||||||
// initialize default compute device
|
// initialize default compute device
|
||||||
triton::jit jit(ctx);
|
triton::jit jit(ctx);
|
||||||
// matrix multiplication parameters
|
// matrix multiplication parameters
|
||||||
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
|
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<Eigen::half>().data(), false);
|
||||||
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
|
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<Eigen::half>().data(), false);
|
||||||
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
|
||||||
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
||||||
stream->synchronize();
|
stream->synchronize();
|
||||||
@@ -160,4 +154,10 @@ class BlockSparseGemmOp : public OpKernel {
|
|||||||
private:
|
private:
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlockSparseGemmOp);
|
REGISTER_KERNEL_BUILDER(Name("Dot").Device(DEVICE_GPU), BlockSparseGemmOp);
|
||||||
|
REGISTER_OP("Dot")
|
||||||
|
.Input("a: float16")
|
||||||
|
.Input("b: float16")
|
||||||
|
.Input("locks: int32")
|
||||||
|
.Output("c: float32")
|
||||||
|
;
|
||||||
|
@@ -3,18 +3,23 @@ import tensorflow as tf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
data_files_path = tf.resource_loader.get_data_files_path()
|
data_files_path = tf.resource_loader.get_data_files_path()
|
||||||
library_dir = '/home/philippe/Development/triton/build/examples/python/tensorflow'
|
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
|
||||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||||
|
|
||||||
M, N, K = 512, 512, 512
|
M, N, K = 512, 512, 512
|
||||||
a = tf.placeholder(tf.float32, shape=[M, K])
|
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||||
b = tf.placeholder(tf.float32, shape=[N, K])
|
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||||
locks = tf.placeholder(tf.int32, shape=[4096])
|
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||||
c = module.block_sparse_mat_mul(a, b, locks)
|
c = module.dot(a, b, locks)
|
||||||
|
# Reference
|
||||||
|
ha = np.random.rand(M, K).astype(np.float16)
|
||||||
|
hb = np.random.rand(N, K).astype(np.float16)
|
||||||
|
hresult = np.dot(hb.T, ha)
|
||||||
|
|
||||||
# Run
|
# Run
|
||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||||
a: np.random.rand(M, K),
|
a: ha,
|
||||||
b: np.random.rand(N, K)})
|
b: hb})
|
||||||
print(result)
|
print(result - hresult)
|
||||||
|
@@ -40,6 +40,7 @@ public:
|
|||||||
type *get_int16_ty();
|
type *get_int16_ty();
|
||||||
type *get_int32_ty();
|
type *get_int32_ty();
|
||||||
type *get_int64_ty();
|
type *get_int64_ty();
|
||||||
|
type *get_half_ty();
|
||||||
type *get_float_ty();
|
type *get_float_ty();
|
||||||
type *get_double_ty();
|
type *get_double_ty();
|
||||||
// Insert
|
// Insert
|
||||||
|
@@ -35,7 +35,7 @@ enum TYPE_T{
|
|||||||
VOID_T,
|
VOID_T,
|
||||||
UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T,
|
UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T,
|
||||||
INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
|
INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
|
||||||
FLOAT32_T, FLOAT64_T
|
FLOAT16_T, FLOAT32_T, FLOAT64_T
|
||||||
};
|
};
|
||||||
|
|
||||||
enum STORAGE_SPEC_T{
|
enum STORAGE_SPEC_T{
|
||||||
|
@@ -52,7 +52,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
|||||||
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
%token AND_OP OR_OP MUL_ASSIGN DIV_ASSIGN MOD_ASSIGN ADD_ASSIGN
|
||||||
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
%token SUB_ASSIGN LEFT_ASSIGN RIGHT_ASSIGN AND_ASSIGN
|
||||||
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
|
%token XOR_ASSIGN OR_ASSIGN TYPE_NAME
|
||||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64
|
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||||
%token IF ELSE FOR CONTINUE WHILE
|
%token IF ELSE FOR CONTINUE WHILE
|
||||||
%token NEWAXIS ELLIPSIS AT
|
%token NEWAXIS ELLIPSIS AT
|
||||||
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ALLOC_CONST
|
%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCHG ALLOC_CONST
|
||||||
@@ -77,6 +77,7 @@ type_specifier
|
|||||||
| INT16 { $$ = new token(INT16_T); }
|
| INT16 { $$ = new token(INT16_T); }
|
||||||
| INT32 { $$ = new token(INT32_T); }
|
| INT32 { $$ = new token(INT32_T); }
|
||||||
| INT64 { $$ = new token(INT64_T); }
|
| INT64 { $$ = new token(INT64_T); }
|
||||||
|
| FP16 { $$ = new token(FLOAT16_T); }
|
||||||
| FP32 { $$ = new token(FLOAT32_T); }
|
| FP32 { $$ = new token(FLOAT32_T); }
|
||||||
| FP64 { $$ = new token(FLOAT64_T); }
|
| FP64 { $$ = new token(FLOAT64_T); }
|
||||||
;
|
;
|
||||||
|
@@ -38,6 +38,7 @@ using triton::lang::return_void;
|
|||||||
"int16" { return return_impl(INT16, yytext); }
|
"int16" { return return_impl(INT16, yytext); }
|
||||||
"int32" { return return_impl(INT32, yytext); }
|
"int32" { return return_impl(INT32, yytext); }
|
||||||
"int64" { return return_impl(INT64, yytext); }
|
"int64" { return return_impl(INT64, yytext); }
|
||||||
|
"fp16" { return return_impl(FP16, yytext); }
|
||||||
"fp32" { return return_impl(FP32, yytext); }
|
"fp32" { return return_impl(FP32, yytext); }
|
||||||
"fp64" { return return_impl(FP64, yytext); }
|
"fp64" { return return_impl(FP64, yytext); }
|
||||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||||
|
@@ -65,6 +65,7 @@ public:
|
|||||||
void target_independent(ir::module &module) {
|
void target_independent(ir::module &module) {
|
||||||
optimize_dot.run(module);
|
optimize_dot.run(module);
|
||||||
optimize_trans.run(module);
|
optimize_trans.run(module);
|
||||||
|
// ir::print(module, std::cout);
|
||||||
}
|
}
|
||||||
|
|
||||||
void target_dependent(ir::module &module) {
|
void target_dependent(ir::module &module) {
|
||||||
|
@@ -247,8 +247,6 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
|||||||
return std::string(buffer.begin(), buffer.end());
|
return std::string(buffer.begin(), buffer.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
|
@@ -56,6 +56,9 @@ type *builder::get_int32_ty()
|
|||||||
type *builder::get_int64_ty()
|
type *builder::get_int64_ty()
|
||||||
{ return type::get_int64_ty(ctx_); }
|
{ return type::get_int64_ty(ctx_); }
|
||||||
|
|
||||||
|
type *builder::get_half_ty()
|
||||||
|
{ return type::get_half_ty(ctx_); }
|
||||||
|
|
||||||
type *builder::get_float_ty()
|
type *builder::get_float_ty()
|
||||||
{ return type::get_float_ty(ctx_); }
|
{ return type::get_float_ty(ctx_); }
|
||||||
|
|
||||||
|
@@ -21,6 +21,7 @@ ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
|||||||
case INT16_T: return ir::type::get_int16_ty(ctx);
|
case INT16_T: return ir::type::get_int16_ty(ctx);
|
||||||
case INT32_T: return ir::type::get_int32_ty(ctx);
|
case INT32_T: return ir::type::get_int32_ty(ctx);
|
||||||
case INT64_T: return ir::type::get_int64_ty(ctx);
|
case INT64_T: return ir::type::get_int64_ty(ctx);
|
||||||
|
case FLOAT16_T: return ir::type::get_half_ty(ctx);
|
||||||
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
||||||
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
||||||
default: throw std::runtime_error("unreachable");
|
default: throw std::runtime_error("unreachable");
|
||||||
|
@@ -69,6 +69,12 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
|||||||
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
|
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
|
||||||
is_float = true;
|
is_float = true;
|
||||||
}
|
}
|
||||||
|
// One operand is half
|
||||||
|
else if(left_ty->is_half_ty() || right_ty->is_half_ty()){
|
||||||
|
ir::value *&to_convert = left_ty->is_half_ty()?rhs:lhs;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, builder.get_half_ty());
|
||||||
|
is_float = true;
|
||||||
|
}
|
||||||
// Both operands are integers
|
// Both operands are integers
|
||||||
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
||||||
is_int = true;
|
is_int = true;
|
||||||
|
Reference in New Issue
Block a user