trying to work around tensorflow limitations

Philippe Tillet
2019-09-04 01:54:43 -04:00
parent 2ccc915011
commit b747959a57
8 changed files with 48 additions and 31 deletions

View File

@@ -50,7 +50,7 @@ endif()
# Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
-target_link_libraries(triton LLVM ${TF_LIBS})
+target_link_libraries(triton LLVM)
# Warning level
#if(MSVC)

View File

@@ -250,10 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
try{
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
}catch(exception::cuda::base const &){
-#ifdef TRITON_LOG_PTX_ERROR
+//#ifdef TRITON_LOG_PTX_ERROR
std::cerr << "Compilation Failed! Log: " << std::endl;
std::cerr << errbuf << std::endl;
-#endif
+//#endif
throw;
}
}

View File

@@ -87,19 +87,18 @@ src = '''
else {
int *plock = locks + ridx*nlocks + lockid - 1;
int *pcount = plock + get_num_program(0)*nlocks;
-while(__atomic_cas(plock, 0, 1));
+while(atomic_cas(plock, 0, 1));
int count = *pcount;
if(count == 0)
*?(checkc) pc = c;
else
*?(checkc) pc = c + *pc;
-__atomic_exch(pcount, 1);
-__atomic_exch(plock, 0);
+atomic_exch(pcount, 1);
+atomic_exch(plock, 0);
}
}
'''
# std::string dot::triton_c_src_dw() const {
# bool AT = (op_ == WGRAD);
# bool BT = (op_ == FPROP);
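
The locked epilogue above is a split-style reduction: the first program instance to take the per-tile lock writes its tile directly, later instances read the partial result back and accumulate into it, and a separate counter records whether the tile has been initialized (the commit itself only renames the builtins from __atomic_cas/__atomic_exch to atomic_cas/atomic_exch). A rough, single-threaded Python sketch of the same control flow; cas/exch below are stand-ins for the hardware atomics, not part of this repo:

    import numpy as np

    def cas(arr, i, cmp, val):
        # stand-in for atomic compare-and-swap: returns the old value
        old = arr[i]
        if old == cmp:
            arr[i] = val
        return old

    def exch(arr, i, val):
        # stand-in for atomic exchange: returns the old value
        old, arr[i] = arr[i], val
        return old

    def write_tile(pc, c_tile, plock, pcount, i):
        while cas(plock, i, 0, 1):      # spin while the old value is 1 (lock held)
            pass
        if pcount[i] == 0:
            pc[:] = c_tile              # first writer initializes the output tile
        else:
            pc[:] = c_tile + pc         # later writers accumulate the partial result
        exch(pcount, i, 1)              # remember that the tile now holds valid data
        exch(plock, i, 0)               # release the lock

    lock, count, out = np.zeros(1, int), np.zeros(1, int), np.zeros(4)
    for partial in (np.ones(4), 2 * np.ones(4)):
        write_tile(out, partial, lock, count, 0)
    print(out)   # [3. 3. 3. 3.] -- both partial tiles ended up accumulated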

View File

@@ -81,6 +81,7 @@ class dot_op:
self.transpose_b = transpose_b
def __call__(self, a, b):
+dtype = a.dtype
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
@@ -96,13 +97,12 @@ class dot_op:
ldb = Kb if self.transpose_b else N
ldc = N
# allocate output
-c = triton.empty([M, N])
+c = triton.empty([M, N], dtype = dtype)
# compute
return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
-AT = self.transpose_a, BT = self.transpose_b, TYPE = tf.float16,
-TM = [128], TN = [128], TK = [32])
+AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
+TM = [128], TN = [128], TK = [8])
def dot(a, b, transpose_a = False, transpose_b = False):
if (transpose_a, transpose_b) not in dot.ops:
@@ -114,20 +114,25 @@ dot.ops = dict()
def _dot_grad(op, dy):
a = op.inputs[0]
b = op.inputs[1]
+print(op.triton)
return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
def run_dot():
M, N, K = 128, 128, 128
-a = tf.placeholder(tf.float16, shape=[M, K])
-b = tf.placeholder(tf.float16, shape=[N, K])
+a = tf.placeholder(tf.float32, shape=[M, K])
+b = tf.placeholder(tf.float32, shape=[N, K])
c = dot(a, b, transpose_a = False, transpose_b = False)
print("LULZ")
da, db = tf.gradients(c, [a, b])
print(da, db)
exit
# Reference
-ha = np.random.rand(M, K).astype(np.float16)
-hb = np.random.rand(K, N).astype(np.float16)
+ha = np.random.rand(M, K).astype(np.float32)
+hb = np.random.rand(K, N).astype(np.float32)
# Run
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
-result = sess.run([c], feed_dict = {a: ha,
+result = sess.run([da], feed_dict = {a: ha,
b: hb})[0]
# Test
print(result)
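
For reference, the _dot_grad registration above is wiring up the standard matmul backward pass: for C = A @ B with upstream gradient dY, dA = dY @ B^T and dB = A^T @ dY, which is why the gradient reuses dot ops with different transpose settings. A small NumPy check of those identities, independent of the Triton/TensorFlow code above:

    import numpy as np

    # Finite-difference check of dL/dA = dY @ B.T where L = sum((A @ B) * dY).
    M, N, K = 4, 5, 3
    A, B, dY = np.random.rand(M, K), np.random.rand(K, N), np.random.rand(M, N)
    dA, dB = dY @ B.T, A.T @ dY            # analytic gradients of L w.r.t. A and B
    eps, i, j = 1e-6, 1, 2
    A_pert = A.copy()
    A_pert[i, j] += eps
    numeric = np.sum((A_pert @ B - A @ B) * dY) / eps
    print(np.allclose(numeric, dA[i, j], atol=1e-4))   # True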

View File

@@ -42,17 +42,15 @@ class CMakeBuild(build_ext):
python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
# tensorflow directories
import tensorflow as tf
-tf_include_dirs = tf.sysconfig.get_include()
-tf_lib_dirs = tf.sysconfig.get_lib()
-tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
-tf_libs = 'tensorflow_framework'
+tf_include_dirs = tf.sysconfig.get_include()
+tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
'-DBUILD_TESTS=OFF',
'-DBUILD_PYTHON_MODULE=ON',
'-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
'-DTF_INCLUDE_DIRS=' + tf_include_dirs,
-'-DTF_LIB_DIRS=' + tf_lib_dirs,
+'-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
'-DTF_LIBS=' + tf_libs,
'-DTF_ABI=' + str(tf_abi)]
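
The new tf_libs line asks TensorFlow itself which framework library to link against instead of hard-coding 'tensorflow_framework'. tf.sysconfig.get_link_flags() returns the linker flags for building extensions against the installed TensorFlow; the exact strings vary by TF version, so the values in this sketch are illustrative only:

    import tensorflow as tf

    flags = tf.sysconfig.get_link_flags()
    # Typically something like
    #   ['-L/.../site-packages/tensorflow', '-l:libtensorflow_framework.so.1']
    # (older releases return plain '-ltensorflow_framework'); index 1 is the -l flag.
    tf_lib_dir  = flags[0].replace('-L', '')   # roughly the same info as tf.sysconfig.get_lib()
    tf_lib_name = flags[1].replace('-l', '')   # what the commit passes on as -DTF_LIBS
    print(tf_lib_dir, tf_lib_name)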

View File

@@ -171,7 +171,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
os << " .Output(\"out" << i << ": " << to_tf_scalar_ty(args[idx]->get_type()) << "\")\n";
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
}
os << " .Attr(\"id: int\")" << std::endl;
os << ";\n";
@@ -239,10 +239,6 @@ std::tuple<std::string,
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/common_shape_fns.h"
using namespace tensorflow;
using GPUDevice = Eigen::GpuDevice;

View File

@@ -26,5 +26,6 @@ class AllocEmptyOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("AllocEmpty").HostMemory("x").Device(DEVICE_CPU).Device(DEVICE_GPU), AllocEmptyOp);
REGISTER_OP("AllocEmpty")
.Input("x: int32")
.Output("y: float16")
.Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
.Output("y: T")
;
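
With the "T" attribute, AllocEmpty is no longer pinned to a float16 output: the dtype is picked per call through the attr, which is what the alloc_empty(args, T = dtype) call later in this commit relies on. A rough usage sketch from the Python side, assuming the custom-op module has already been loaded as tf_extra_ops (as done elsewhere in this repo):

    import tensorflow as tf

    # tf_extra_ops is assumed to be the already-loaded custom-op module.
    shape = tf.stack([tf.constant(128), tf.constant(512)])   # matches Input("x: int32")
    y16 = tf_extra_ops.alloc_empty(shape, T=tf.float16)      # uninitialized fp16 buffer
    y32 = tf_extra_ops.alloc_empty(shape, T=tf.float32)      # same op, fp32 output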

View File

@@ -108,7 +108,7 @@ def _build(src, path, framework):
library_dirs += [tensorflow.sysconfig.get_lib()]
include_dirs += [tensorflow.sysconfig.get_include()]
include_dirs += ['/usr/local/cuda/include/']
-libraries += ['tensorflow_framework']
+libraries += [tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
elif framework == torch_id:
@@ -210,8 +210,20 @@ def _make_grid(args) :
return result
return grid
class op:
+class _definitions_descriptor:
+def __init__(self):
+self.values = dict()
+def __set__(self, instance, value):
+self.values[value[0]] = value[1]
+def __get__(self, instance, owner):
+return self.values
def __init__(self, src, outputs, framework = None):
self.fw_id = dict()
self.fw_ops = dict()
@@ -219,6 +231,10 @@ class op:
self.src = src
self.outputs = outputs
self.framework = _find_framework(framework)
+if self.framework == tensorflow_id:
+_import_tensorflow()
+tensorflow.Operation.triton = property(op._definitions_descriptor)
def __call__(self, *args, **kwargs):
# create a new op when defines are different
@@ -253,7 +269,9 @@ class op:
# create operands
op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
# call framework op
-return op(*op_args, id=op_id)
+tensor = op(*op_args, id=op_id)
+tensor.op.triton = ('lol', 1)
+return tensor
class register_gradient:
@@ -266,14 +284,14 @@ class register_gradient:
ops.RegisterGradient(name)(f)
-def empty(shapes, framework = None):
+def empty(shapes, dtype, framework = None):
framework = _find_framework(framework)
if framework == tensorflow_id:
_import_tensorflow()
_import_tf_extra_ops
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
args = tensorflow.stack(args)
-return tf_extra_ops.alloc_empty(args)
+return tf_extra_ops.alloc_empty(args, T = dtype)
elif framework == torch_id:
_import_torch()
return torch.empty(*shapes)
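
The _definitions_descriptor above relies on Python's descriptor protocol: an object with __get__/__set__ placed on a class intercepts attribute access on all instances, which is how a triton attribute gets bolted onto tensorflow.Operation without TensorFlow's cooperation. Note that a descriptor written this way keeps its dict on the descriptor object itself, so state is shared across every instance rather than stored per operation. A minimal stand-alone sketch of that behavior (the commit additionally routes the class through property(), which this sketch leaves out):

    class SharedDict:
        # Same shape as _definitions_descriptor: __set__ takes a (key, value) pair,
        # __get__ returns the whole dict; `values` lives on the descriptor, so it is
        # shared by every instance of the owning class.
        def __init__(self):
            self.values = dict()
        def __set__(self, instance, value):
            self.values[value[0]] = value[1]
        def __get__(self, instance, owner):
            return self.values

    class Host:
        triton = SharedDict()

    a, b = Host(), Host()
    a.triton = ('TM', 128)
    b.triton = ('TN', 64)
    print(a.triton)   # {'TM': 128, 'TN': 64} -- both writes land in the same dict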