[python][tensorflow] basic op generation is working

This commit is contained in:
Philippe Tillet
2019-08-16 20:50:18 -07:00
parent c7cb5f82ad
commit 11a6a92598
9 changed files with 211 additions and 23 deletions

View File

@@ -48,9 +48,16 @@ endif()
# Python module
if(BUILD_PYTHON_MODULE)
message(STATUS "Adding Python module")
file(GLOB_RECURSE PYTHON_SRC python/src/*.cpp)
include_directories(python/src/ ${PYTHON_INCLUDE_DIRS})
set(PYTHON_LIBS )
# PyBind11 wrapper source file
file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cpp)
# update include directory
include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS})
# update link directories
link_directories(${TF_LIB_DIRS})
# extra tensorflow ops (e.g., alloc_empty)
file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cpp)
add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
target_link_libraries(extra_tf_ops ${TF_LIBS})
endif()

View File

@@ -164,16 +164,17 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
res.cublas = 0;
// test
// stream->synchronize();
// stream->read(dc, true, 0, hc);
// std::vector<NumericT> rc(hc.size());
// cpu_ref(AT, BT, M, N, K, rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
stream->synchronize();
stream->read(dc, true, 0, hc);
std::vector<NumericT> rc(hc.size());
cpu_ref(AT, BT, M, N, K, rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << hc[0] << " " << std::endl;
std::cout << "Pass!" << std::endl;
// clean-up
delete dc;

84
include/triton/ir/enums.h Normal file
View File

@@ -0,0 +1,84 @@
#ifndef TRITON_IR_ENUMS_H
#define TRITON_IR_ENUMS_H
namespace triton{
namespace ir{
enum binary_op_t {
Add,
FAdd,
Sub,
FSub,
Mul,
FMul,
UDiv,
SDiv,
FDiv,
URem,
SRem,
FRem,
Shl,
LShr,
AShr,
And,
Or,
Xor
};
enum cast_op_t {
Trunc,
ZExt,
SExt,
FPTrunc,
FPExt,
UIToFP,
SIToFP,
FPToUI,
FPToSI,
PtrToInt,
IntToPtr,
BitCast,
AddrSpaceCast
};
enum cmp_pred_t {
FIRST_FCMP_PREDICATE,
FCMP_FALSE,
FCMP_OEQ,
FCMP_OGT,
FCMP_OGE,
FCMP_OLT,
FCMP_OLE,
FCMP_ONE,
FCMP_ORD,
FCMP_UNO,
FCMP_UEQ,
FCMP_UGT,
FCMP_UGE,
FCMP_ULT,
FCMP_ULE,
FCMP_UNE,
FCMP_TRUE,
LAST_FCMP_PREDICATE,
FIRST_ICMP_PREDICATE,
ICMP_EQ,
ICMP_NE,
ICMP_UGT,
ICMP_UGE,
ICMP_ULT,
ICMP_ULE,
ICMP_SGT,
ICMP_SGE,
ICMP_SLT,
ICMP_SLE,
LAST_ICMP_PREDICATE
};
}
}
#endif

View File

@@ -4,6 +4,7 @@ import distutils
import distutils.log
import setuptools.command.build_ext
import setuptools
import numpy as np
import os
src = """
@@ -45,23 +46,25 @@ void matmul(restrict read_only align(16) half *A,
}
"""
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
with open('test.cpp', 'w+') as test:
src = libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1')
test.writelines(src)
triton_include_dirs = ['/home/philippe/development/triton/include']
tensorflow_include_dirs = [tf.sysconfig.get_include()]
llvm_include_dirs = ['/usr/include/llvm-8/', '/usr/include/llvm-c-8/']
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
triton_library_dirs = [os.path.realpath(libtriton.__file__)]
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
extra_compile_args = []
extra_link_args = []
library_dirs = tensorflow_library_dirs
libraries = ['tensorflow_framework']
library_dirs = triton_library_dirs + tensorflow_library_dirs
libraries = ['tensorflow_framework', 'triton']
ext = setuptools.Extension(
name = 'test',
@@ -92,4 +95,46 @@ args = dict(
setuptools.setup(**args)
library_dir = os.path.dirname(os.path.realpath(__file__))
module = tf.load_op_library(os.path.join(library_dir, 'build/lib.linux-x86_64-3.6/test.cpython-36m-x86_64-linux-gnu.so'))
print(module.matmul)
class dot:
def __init__(self):
trans_a = True
trans_b = False
def __call__(self, a, b):
shape_a = tf.shape(a)
shape_b = tf.shape(b)
M = shape_a[0]
K = shape_a[1]
N = shape_b[0]
lda = M
ldb = K
ldc = M
c = extra_ops.alloc_empty(tf.stack([M, N]))
return module.matmul(a, b, c, M, N, K, lda, ldb, ldc)
dot_nt = dot()
def run_dot():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
# c = tf.matmul(a, b, transpose_a=True)
c = dot_nt(a, b)
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)
print(result)
print("dif: %f" % np.max(dif))
run_dot()

View File

@@ -35,12 +35,22 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
# python directors
python_include_dirs = distutils.sysconfig.get_python_inc()
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
# tensorflow directories
import tensorflow as tf
tf_include_dirs = tf.sysconfig.get_include()
tf_lib_dirs = tf.sysconfig.get_lib()
tf_libs = 'tensorflow_framework'
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_EXAMPLES=OFF',
'-DBUILD_PYTHON_MODULE=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs]
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DTF_INCLUDE_DIRS=' + tf_include_dirs,
'-DTF_LIB_DIRS=' + tf_lib_dirs,
'-DTF_LIBS=' + tf_libs]
cfg = 'Debug' if self.debug else 'Release'
build_args = ['--config', cfg]

View File

@@ -201,15 +201,26 @@ private:
rt::function fn_;
};
REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU), " + classname + R"();
REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < tf_scalar_tys.size(); i++){
std::string arg_name = arg_names[i];
std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);});
if(!fn_ty->get_param_ty(i)->is_pointer_ty())
result += ".HostMemory(\"" + arg_name + "\")";
}
result += ", " + classname + R"();
REGISTER_OP(")" + name + "\")\n";
for(size_t i = 0; i < tf_scalar_tys.size(); i++){
bool is_output = std::find(outputs.begin(), outputs.end(), i) != outputs.end();
std::string mode = is_output ? "Output" : "Input" ;
std::string mode = is_output ? "Input" : "Input" ;
std::string arg_name = arg_names[i];
std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);});
result += " ." + mode + "(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n";
result += " .Input(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n";
}
for(size_t i = 0; i < outputs.size(); i++){
result += " .Output(\"out: " + tf_scalar_tys[outputs[i]] + "\")\n";
}
result += ";\n";

View File

@@ -0,0 +1,30 @@
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class AllocEmptyOp : public OpKernel {
public:
explicit AllocEmptyOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// fetch input
const Tensor& x = context->input(0);
const int32* x_data = (const int32*)x.tensor_data().data();
// allocate output
Tensor* y = NULL;
int32 x_rank = x.dims();
OP_REQUIRES(context, x_rank == 1, errors::InvalidArgument("Input tensor must be 1D"));
int32 y_rank = x.dim_size(0);
TensorShape y_shapes;
for(size_t i = 0; i < y_rank; i++)
y_shapes.AddDim(x_data[i]);
OP_REQUIRES_OK(context, context->allocate_output(0, y_shapes, &y));
}
};
REGISTER_KERNEL_BUILDER(Name("AllocEmpty").HostMemory("x").Device(DEVICE_CPU).Device(DEVICE_GPU), AllocEmptyOp);
REGISTER_OP("AllocEmpty")
.Input("x: int32")
.Output("y: float16")
;

View File

View File