[python][tensorflow] basic op generation is working
@@ -48,9 +48,16 @@ endif()
 # Python module
 if(BUILD_PYTHON_MODULE)
     message(STATUS "Adding Python module")
-    file(GLOB_RECURSE PYTHON_SRC python/src/*.cpp)
-    include_directories(python/src/ ${PYTHON_INCLUDE_DIRS})
-    set(PYTHON_LIBS )
+    # PyBind11 wrapper source file
+    file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cpp)
+    # update include directory
+    include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS})
+    # update link directories
+    link_directories(${TF_LIB_DIRS})
+    # extra tensorflow ops (e.g., alloc_empty)
+    file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cpp)
+    add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
+    target_link_libraries(extra_tf_ops ${TF_LIBS})
 endif()
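The three new CMake cache variables (TF_INCLUDE_DIRS, TF_LIB_DIRS, TF_LIBS) are supplied by the caller; the setup.py change further down derives them from the local TensorFlow installation. A minimal sketch of that derivation, assuming a TensorFlow 1.x install is importable (the loop and print are illustrative only):

import tensorflow as tf

# Values for the new CMake options, queried from the installed TensorFlow
# (this mirrors what build_extension does in the setup.py hunk below).
tf_include_dirs = tf.sysconfig.get_include()   # headers used by the generated op wrappers
tf_lib_dirs = tf.sysconfig.get_lib()           # directory containing libtensorflow_framework
tf_libs = 'tensorflow_framework'

for flag, value in [('TF_INCLUDE_DIRS', tf_include_dirs),
                    ('TF_LIB_DIRS', tf_lib_dirs),
                    ('TF_LIBS', tf_libs)]:
    print('-D%s=%s' % (flag, value))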
@@ -164,16 +164,17 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int
   res.cublas = 0;

   // test
-  // stream->synchronize();
-  // stream->read(dc, true, 0, hc);
-  // std::vector<NumericT> rc(hc.size());
-  // cpu_ref(AT, BT, M, N, K, rc, ha, hb);
-  // for(size_t i = 0; i < M*N; i++)
-  //   if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
-  //     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-  //     exit(EXIT_FAILURE);
-  //   }
-  // std::cout << "Pass!" << std::endl;
+  stream->synchronize();
+  stream->read(dc, true, 0, hc);
+  std::vector<NumericT> rc(hc.size());
+  cpu_ref(AT, BT, M, N, K, rc, ha, hb);
+  for(size_t i = 0; i < M*N; i++)
+    if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
+      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
+      exit(EXIT_FAILURE);
+    }
+  std::cout << hc[0] << " " << std::endl;
+  std::cout << "Pass!" << std::endl;

   // clean-up
   delete dc;
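The re-enabled check reads the device result back, computes a CPU reference, and rejects any element that is non-finite or whose relative error exceeds 1e-2. A NumPy analog of the same acceptance test, for illustration only (hc and rc stand for the device and reference buffers, as in the C++ above):

import numpy as np

def check(hc, rc, tol=1e-2):
    # Reject non-finite values and elements whose relative error exceeds tol,
    # matching the per-element test re-enabled in the C++ benchmark.
    hc = np.asarray(hc, dtype=np.float32)
    rc = np.asarray(rc, dtype=np.float32)
    bad = ~np.isfinite(hc) | (np.abs(hc - rc) / np.maximum(hc, rc) > tol)
    if bad.any():
        i = int(np.flatnonzero(bad)[0])
        raise AssertionError('mismatch at %d: %f vs %f' % (i, hc.flat[i], rc.flat[i]))
    print('Pass!')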
include/triton/ir/enums.h (new file, 84 lines)
@@ -0,0 +1,84 @@
+#ifndef TRITON_IR_ENUMS_H
+#define TRITON_IR_ENUMS_H
+
+namespace triton{
+namespace ir{
+
+
+enum binary_op_t {
+  Add,
+  FAdd,
+  Sub,
+  FSub,
+  Mul,
+  FMul,
+  UDiv,
+  SDiv,
+  FDiv,
+  URem,
+  SRem,
+  FRem,
+  Shl,
+  LShr,
+  AShr,
+  And,
+  Or,
+  Xor
+};
+
+enum cast_op_t {
+  Trunc,
+  ZExt,
+  SExt,
+  FPTrunc,
+  FPExt,
+  UIToFP,
+  SIToFP,
+  FPToUI,
+  FPToSI,
+  PtrToInt,
+  IntToPtr,
+  BitCast,
+  AddrSpaceCast
+};
+
+enum cmp_pred_t {
+  FIRST_FCMP_PREDICATE,
+  FCMP_FALSE,
+  FCMP_OEQ,
+  FCMP_OGT,
+  FCMP_OGE,
+  FCMP_OLT,
+  FCMP_OLE,
+  FCMP_ONE,
+  FCMP_ORD,
+  FCMP_UNO,
+  FCMP_UEQ,
+  FCMP_UGT,
+  FCMP_UGE,
+  FCMP_ULT,
+  FCMP_ULE,
+  FCMP_UNE,
+  FCMP_TRUE,
+  LAST_FCMP_PREDICATE,
+  FIRST_ICMP_PREDICATE,
+  ICMP_EQ,
+  ICMP_NE,
+  ICMP_UGT,
+  ICMP_UGE,
+  ICMP_ULT,
+  ICMP_ULE,
+  ICMP_SGT,
+  ICMP_SGE,
+  ICMP_SLT,
+  ICMP_SLE,
+  LAST_ICMP_PREDICATE
+};
+
+
+}
+}
+
+#endif
@@ -4,6 +4,7 @@ import distutils
 import distutils.log
 import setuptools.command.build_ext
 import setuptools
+import numpy as np
 import os

 src = """
@@ -45,23 +46,25 @@ void matmul(restrict read_only align(16) half *A,
 }
 """

+extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
+
+
 with open('test.cpp', 'w+') as test:
     src = libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1')
     test.writelines(src)

 triton_include_dirs = ['/home/philippe/development/triton/include']
 tensorflow_include_dirs = [tf.sysconfig.get_include()]
-llvm_include_dirs = ['/usr/include/llvm-8/', '/usr/include/llvm-c-8/']
 cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']

-triton_library_dirs = [os.path.realpath(libtriton.__file__)]
+triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
 tensorflow_library_dirs = [tf.sysconfig.get_lib()]

 include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
 extra_compile_args = []
 extra_link_args = []
-library_dirs = tensorflow_library_dirs
-libraries = ['tensorflow_framework']
+library_dirs = triton_library_dirs + tensorflow_library_dirs
+libraries = ['tensorflow_framework', 'triton']

 ext = setuptools.Extension(
     name = 'test',
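Reading this together with the code-generation changes further down: the first argument to libtriton.make_tensorflow_src is the Triton-C source, the second lists which kernel arguments become op outputs (index 2 is the C pointer in matmul(A, B, C, M, N, K, lda, ldb, ldc)), and the third is the launch grid expressed with the tile sizes #TM/#TN. This reading is inferred from the generator code below rather than from documented API; a sketch with the roles spelled out:

# Same call as above, with the (inferred) argument roles named explicitly.
kernel_src = src                                         # Triton-C source of `matmul`
output_arg_indices = [2]                                 # argument 2 (C) becomes the op output
launch_grid = '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1'  # grid over output tiles
cpp_wrapper = libtriton.make_tensorflow_src(kernel_src, output_arg_indices, launch_grid)
with open('test.cpp', 'w+') as test:
    test.writelines(cpp_wrapper)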
@@ -92,4 +95,46 @@ args = dict(
 setuptools.setup(**args)
 library_dir = os.path.dirname(os.path.realpath(__file__))
 module = tf.load_op_library(os.path.join(library_dir, 'build/lib.linux-x86_64-3.6/test.cpython-36m-x86_64-linux-gnu.so'))
-print(module.matmul)
+
+class dot:
+
+    def __init__(self):
+        trans_a = True
+        trans_b = False
+
+    def __call__(self, a, b):
+        shape_a = tf.shape(a)
+        shape_b = tf.shape(b)
+        M = shape_a[0]
+        K = shape_a[1]
+        N = shape_b[0]
+        lda = M
+        ldb = K
+        ldc = M
+        c = extra_ops.alloc_empty(tf.stack([M, N]))
+        return module.matmul(a, b, c, M, N, K, lda, ldb, ldc)
+
+dot_nt = dot()
+
+def run_dot():
+    M, N, K = 128, 128, 128
+    a = tf.placeholder(tf.float16, shape=[M, K])
+    b = tf.placeholder(tf.float16, shape=[N, K])
+    # c = tf.matmul(a, b, transpose_a=True)
+    c = dot_nt(a, b)
+    # Reference
+    ha = np.random.rand(M, K).astype(np.float16)
+    hb = np.random.rand(N, K).astype(np.float16)
+    # Run
+    sess = tf.InteractiveSession()
+    sess.run(tf.global_variables_initializer())
+    result = sess.run([c], feed_dict = {a: ha,
+                                        b: hb})[0]
+    # Test
+    hresult = np.dot(ha.T, hb).T
+    dif = np.abs(result - hresult)
+    np.savetxt('dif.dat', dif, '%2.4f')
+    print(hresult)
+    print(result)
+    print("dif: %f" % np.max(dif))
+
+run_dot()
@@ -35,12 +35,22 @@ class CMakeBuild(build_ext):

     def build_extension(self, ext):
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
+        # python directors
         python_include_dirs = distutils.sysconfig.get_python_inc()
         python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
+        # tensorflow directories
+        import tensorflow as tf
+        tf_include_dirs = tf.sysconfig.get_include()
+        tf_lib_dirs = tf.sysconfig.get_lib()
+        tf_libs = 'tensorflow_framework'

         cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                       '-DBUILD_EXAMPLES=OFF',
                       '-DBUILD_PYTHON_MODULE=ON',
-                      '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs]
+                      '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
+                      '-DTF_INCLUDE_DIRS=' + tf_include_dirs,
+                      '-DTF_LIB_DIRS=' + tf_lib_dirs,
+                      '-DTF_LIBS=' + tf_libs]

         cfg = 'Debug' if self.debug else 'Release'
         build_args = ['--config', cfg]
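Put together, build_extension now forwards the TensorFlow locations to the CMake configure step alongside the existing options. A standalone sketch of the cache entries it assembles (extdir is a placeholder here; setup.py computes it from the extension output path):

import distutils.sysconfig
import tensorflow as tf

extdir = '/tmp/triton-build'  # placeholder for the real extension output directory
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
              '-DBUILD_EXAMPLES=OFF',
              '-DBUILD_PYTHON_MODULE=ON',
              '-DPYTHON_INCLUDE_DIRS=' + distutils.sysconfig.get_python_inc(),
              '-DTF_INCLUDE_DIRS=' + tf.sysconfig.get_include(),
              '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
              '-DTF_LIBS=tensorflow_framework']
print(cmake_args)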
@@ -161,7 +161,7 @@ result += R"(
 // extract outputs)";
 for(unsigned i = 0; i < n_outputs; i++)
   result += R"(
   context->set_output()" + str_i[i] + ", " + arg_names[outputs[i]] + ");";

 result += R"(
@@ -201,15 +201,26 @@ private:
   rt::function fn_;
 };

-REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU), " + classname + R"();
+REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU)";
+  for(size_t i = 0; i < tf_scalar_tys.size(); i++){
+    std::string arg_name = arg_names[i];
+    std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);});
+    if(!fn_ty->get_param_ty(i)->is_pointer_ty())
+      result += ".HostMemory(\"" + arg_name + "\")";
+  }
+  result += ", " + classname + R"();
+

 REGISTER_OP(")" + name + "\")\n";
   for(size_t i = 0; i < tf_scalar_tys.size(); i++){
     bool is_output = std::find(outputs.begin(), outputs.end(), i) != outputs.end();
-    std::string mode = is_output ? "Output" : "Input" ;
+    std::string mode = is_output ? "Input" : "Input" ;
     std::string arg_name = arg_names[i];
     std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);});
-    result += " ." + mode + "(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n";
+    result += " .Input(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n";
+  }
+  for(size_t i = 0; i < outputs.size(); i++){
+    result += " .Output(\"out: " + tf_scalar_tys[outputs[i]] + "\")\n";
   }
   result += ";\n";
python/src/tensorflow/alloc_empty.cpp (new file, 30 lines)
@@ -0,0 +1,30 @@
+#include "tensorflow/core/framework/op_kernel.h"
+
+using namespace tensorflow;
+
+class AllocEmptyOp : public OpKernel {
+ public:
+  explicit AllocEmptyOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // fetch input
+    const Tensor& x = context->input(0);
+    const int32* x_data = (const int32*)x.tensor_data().data();
+    // allocate output
+    Tensor* y = NULL;
+    int32 x_rank = x.dims();
+    OP_REQUIRES(context, x_rank == 1, errors::InvalidArgument("Input tensor must be 1D"));
+    int32 y_rank = x.dim_size(0);
+    TensorShape y_shapes;
+    for(size_t i = 0; i < y_rank; i++)
+      y_shapes.AddDim(x_data[i]);
+    OP_REQUIRES_OK(context, context->allocate_output(0, y_shapes, &y));
+  }
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("AllocEmpty").HostMemory("x").Device(DEVICE_CPU).Device(DEVICE_GPU), AllocEmptyOp);
+REGISTER_OP("AllocEmpty")
+  .Input("x: int32")
+  .Output("y: float16")
+;
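AllocEmpty reads a 1-D int32 shape tensor on the host (hence HostMemory("x")) and allocates an uninitialized float16 output of that shape; the dot wrapper above uses it to create the C buffer that the generated matmul op writes into. A minimal usage sketch, assuming the library has been built to the path used in the test script (adjust the .so path for your build):

import tensorflow as tf

# Load the extra-ops library produced by the new extra_tf_ops CMake target.
extra_ops = tf.load_op_library(
    'python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')

# Request an uninitialized 128x128 float16 tensor; the op only allocates,
# it does not fill the buffer.
c = extra_ops.alloc_empty(tf.constant([128, 128], dtype=tf.int32))
print(c)  # symbolic float16 tensor; evaluate it in a Session to materialize the buffer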
python/triton/tools/build.py (new empty file)
python/triton/tools/checksum.py (new empty file)