[examples] added skeleton for tensorflow op

This commit is contained in:
Philippe Tillet
2019-04-30 10:50:54 -04:00
parent 93f53501c6
commit 8e809a9536
6 changed files with 127 additions and 2 deletions

View File

@@ -1 +1,2 @@
add_subdirectory(cpp)
add_subdirectory(python)

View File

@@ -103,6 +103,7 @@ int main() {
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->write(dlocks, true, 0, hlocks);
stream->synchronize();
@@ -115,8 +116,6 @@ int main() {
unsigned nthreads = info.num_threads;
unsigned GZ = jit.get_int("GZ");
std::array<size_t, 3> grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ};
// init locks
stream->write(dlocks, true, 0, hlocks);
// set argument
kernel->setArg(0, da);
kernel->setArg(1, db);

View File

@@ -0,0 +1 @@
add_subdirectory(tensorflow)

View File

@@ -0,0 +1,12 @@
execute_process(COMMAND python -c "from os.path import dirname; import tensorflow as tf; print(dirname(dirname(tf.sysconfig.get_include())))"
OUTPUT_VARIABLE TF_INC OUTPUT_STRIP_TRAILING_WHITESPACE)
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.sysconfig.get_lib())"
# OUTPUT_VARIABLE TF_LIB)
#execute_process(COMMAND python -c "import tensorflow as tf; print(tf.__cxx11_abi_flag__ if \"__cxx11_abi_flag__\" in tf.__dict__ else 0)"
# OUTPUT_VARIABLE TF_ABI)
set(CUDA_HOME "/usr/local/cuda")
include_directories("${TF_INC}/tensorflow/include")
include_directories("${CUDA_HOME}/include")
add_library(tf_blocksparse SHARED blocksparse.cpp)
#link_libraries(tf_blocksparse ${TF_LIB})

View File

@@ -0,0 +1,38 @@
#include <iostream>
#include "triton/driver/buffer.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;
REGISTER_OP("BlockSparseGemm")
.Attr("T: {float}")
.Input("A: float")
.Input("B: float")
.Output("C: float");
class BlockSparseGemmOp : public OpKernel {
public:
explicit BlockSparseGemmOp(OpKernelConstruction* context) : OpKernel(context) {
}
void Compute(OpKernelContext* context){
GPUDevice device = context->eigen_device<GPUDevice>();
triton::driver::cu_stream stream(device.stream(), false);
}
private:
};
REGISTER_KERNEL_BUILDER(Name("BlockSparse").Device(DEVICE_GPU), BlockSparseGemmOp);

View File

@@ -0,0 +1,74 @@
import os, sys
from os.path import dirname
from distutils.core import setup, Extension
from glob import glob
from build import build_clib_subclass, build_ext_subclass
def recursive_glob(rootdir='.', suffix=''):
return [os.path.join(looproot, filename)
for looproot, _, filenames in os.walk(rootdir)
for filename in filenames if filename.endswith(suffix)]
def main():
path = os.path.join(os.pardir, 'include')
include = [path, os.path.join(path, 'isaac', 'external', 'CUDA')]
src = recursive_glob(os.path.join(os.pardir,'lib'), 'cpp')
flags = ['-std=c++11', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=0']
core = ('core', {'sources': src, 'include_dirs': include, 'cflags': flags})
# Extensions
extensions = []
# Isaac
extensions += [Extension('_isaac',
sources=recursive_glob(os.path.join('src','bind'), 'cpp'),
libraries=[],
library_dirs=[],
extra_compile_args=flags,
extra_link_args=[],
include_dirs=include + [os.path.join('src', 'bind')])]
# Tensorflow
try:
import tensorflow as tf
tf_include = tf.sysconfig.get_include()
extensions += [Extension('_tensorflow',
sources=[os.path.join('src', 'extensions', 'tensorflow.cpp')],
libraries = ['tensorflow_framework'],
extra_compile_args= flags,
include_dirs = include + [tf_include, os.path.join(tf_include, 'external', 'nsync', 'public')],
library_dirs = [tf.sysconfig.get_lib()])]
except ImportError:
pass
# Setup
setup(
name='blocksparse',
version='1.0',
author='Philippe Tillet',
author_email='ptillet@g.harvard.edu',
packages=['isaac', 'isaac.pytorch', 'isaac.pytorch.models', 'isaac.pytorch.c_lib'],
libraries=[core],
ext_package='isaac',
ext_modules=extensions,
cmdclass={'build_clib': build_clib_subclass, 'build_ext': build_ext_subclass},
classifiers=['Environment :: Console',
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Other Audience',
'Intended Audience :: Science/Research',
'Natural Language :: English',
'Programming Language :: C++',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Scientific/Engineering :: Physics',
'Topic :: Scientific/Engineering :: Machine Learning']
)
if __name__ == "__main__":
main()