[examples/python/tensorflow] bugfix in tensorflow wrapper example

This commit is contained in:
Philippe Tillet
2019-04-30 21:04:30 -04:00
parent d934d8fb40
commit 7b6efc0463
12 changed files with 90 additions and 171 deletions

View File

@@ -1,12 +1,14 @@
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
# OUTPUT_VARIABLE TF_LIB)
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
# OUTPUT_VARIABLE TF_ABI)
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
OUTPUT_VARIABLE TF_LIB OUTPUT_STRIP_TRAILING_WHITESPACE)
execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
OUTPUT_VARIABLE TF_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
set(CUDA_HOME "/usr/local/cuda")
include_directories("${TF_INC}/tensorflow/include")
include_directories("${CUDA_HOME}/include")
link_directories(${TF_LIB})
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
add_library(tf_blocksparse SHARED blocksparse.cpp)
#link_libraries(tf_blocksparse ${TF_LIB})
target_link_libraries(tf_blocksparse tensorflow_framework triton)

View File

@@ -66,35 +66,18 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
fp32 b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
*pc = c;
}
)";
REGISTER_OP("BlockSparseGemm")
REGISTER_OP("BlockSparseMatMul")
.Input("a: T")
.Input("b: T")
.Input("locks: int32")
.Output("c: T")
.Attr("T: {float}")
.Input("A: float")
.Input("B: float")
.Input("locks: int")
.Output("C: float");
;
class BlockSparseGemmOp : public OpKernel {
public:
@@ -104,59 +87,60 @@ class BlockSparseGemmOp : public OpKernel {
void Compute(OpKernelContext* context){
// get device/stream
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream stream(device.stream(), false);
triton::driver::cu_stream sstream(device.stream(), false);
triton::driver::context* ctx = sstream.context();
triton::driver::stream* stream = &sstream;
// get inputs
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
const Tensor& locks = context->input(2);
// get shapes
const int64 M = a.dim_size(0);
const int64 N = b.dim_size(0);
const int64 K = a.dim_size(1);
const int32_t M = a.dim_size(0);
const int32_t N = b.dim_size(0);
const int32_t K = a.dim_size(1);
// allocate output
Tensor* c = nullptr;
TensorShape out_shape({M, N});
TensorShape out_shape({(int64)M, (int64)N});
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &c));
// return early if possible
if (out_shape.num_elements() == 0)
return;
// wraps into buffers
triton::driver::cu_buffer ta(stream.context(), (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer tb(stream.context(), (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer tlocks(stream.context(), (CUdeviceptr)locks.flat<int32_t>().data(), false);
triton::driver::cu_buffer tc(stream.context(), (CUdeviceptr)c->flat<float>().data(), false);
// launch info
triton::jit jit(stream.context());
// initialize default compute device
triton::jit jit(ctx);
// matrix multiplication parameters
triton::driver::cu_buffer da(ctx, (CUdeviceptr)a.flat<float>().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)b.flat<float>().data(), false);
triton::driver::cu_buffer dc(ctx, (CUdeviceptr)c->flat<float>().data(), false);
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
stream->synchronize();
// just-in-time compile source-code
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
triton::driver::kernel* kernel = jit.get_function("matmul");
triton::jit::launch_information info = jit.get_launch_info("matmul");
int64 TM = info.global_range_size[0];
int64 TN = info.global_range_size[1];
// launch info
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned nthreads = info.num_threads;
int64 GZ = jit.get_int("GZ");
std::array<size_t, 3> grid;
grid[0] = (M + TM - 1)/TM;
grid[1] = (N + TN - 1)/TN;
grid[2] = GZ;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// set argument
kernel->setArg(0, &ta);
kernel->setArg(1, &tb);
kernel->setArg(2, &tc);
kernel->setArg(0, *da.cu());
kernel->setArg(1, *db.cu());
kernel->setArg(2, *dc.cu());
kernel->setArg(3, M);
kernel->setArg(4, N);
kernel->setArg(5, K);
kernel->setArg(6, M);
kernel->setArg(7, N);
kernel->setArg(8, M);
kernel->setArg(9, tlocks);
kernel->setArg(9, *dlocks.cu());
kernel->setArg(10, grid[0]);
kernel->setArg(11, grid[1]);
// dry run
stream.enqueue(kernel, grid, {nthreads, 1, 1}, nullptr, nullptr);
return;
stream->enqueue(kernel, grid, {nthreads, 1, 1});
stream->synchronize();
}
private:
};
REGISTER_KERNEL_BUILDER(Name("BlockSparse").Device(DEVICE_GPU), BlockSparseGemmOp);
REGISTER_KERNEL_BUILDER(Name("BlockSparseMatMul").Device(DEVICE_GPU).TypeConstraint<float>("T"), BlockSparseGemmOp);

View File

@@ -0,0 +1,20 @@
import os
import tensorflow as tf
import numpy as np
data_files_path = tf.resource_loader.get_data_files_path()
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
M, N, K = 512, 512, 512
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096])
c = module.block_sparse_mat_mul(a, b, locks)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
a: np.random.rand(M, K),
b: np.random.rand(N, K)})
print(result)

View File

@@ -1,74 +0,0 @@
import os, sys
from os.path import dirname
from distutils.core import setup, Extension
from glob import glob
from build import build_clib_subclass, build_ext_subclass
def recursive_glob(rootdir='.', suffix=''):
return [os.path.join(looproot, filename)
for looproot, _, filenames in os.walk(rootdir)
for filename in filenames if filename.endswith(suffix)]
def main():
path = os.path.join(os.pardir, 'include')
include = [path, os.path.join(path, 'isaac', 'external', 'CUDA')]
src = recursive_glob(os.path.join(os.pardir,'lib'), 'cpp')
flags = ['-std=c++11', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=0']
core = ('core', {'sources': src, 'include_dirs': include, 'cflags': flags})
# Extensions
extensions = []
# Isaac
extensions += [Extension('_isaac',
sources=recursive_glob(os.path.join('src','bind'), 'cpp'),
libraries=[],
library_dirs=[],
extra_compile_args=flags,
extra_link_args=[],
include_dirs=include + [os.path.join('src', 'bind')])]
# Tensorflow
try:
import tensorflow as tf
tf_include = tf.sysconfig.get_include()
extensions += [Extension('_tensorflow',
sources=[os.path.join('src', 'extensions', 'tensorflow.cpp')],
libraries = ['tensorflow_framework'],
extra_compile_args= flags,
include_dirs = include + [tf_include, os.path.join(tf_include, 'external', 'nsync', 'public')],
library_dirs = [tf.sysconfig.get_lib()])]
except ImportError:
pass
# Setup
setup(
name='blocksparse',
version='1.0',
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
packages=['isaac', 'isaac.pytorch', 'isaac.pytorch.models', 'isaac.pytorch.c_lib'],
libraries=[core],
ext_package='isaac',
ext_modules=extensions,
cmdclass={'build_clib': build_clib_subclass, 'build_ext': build_ext_subclass},
classifiers=['Environment :: Console',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Other Audience',
'Intended Audience :: Science/Research',
'Natural Language :: English',
'Programming Language :: C++',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Physics',
'Topic :: Scientific/Engineering :: Machine Learning']
)
if __name__ == "__main__":
main()