trying to work around tensorflow limitations
@@ -50,7 +50,7 @@ endif()
 # Triton
 file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
 add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
-target_link_libraries(triton LLVM ${TF_LIBS})
+target_link_libraries(triton LLVM)

 # Warning level
 #if(MSVC)
@@ -250,10 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
   try{
     dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
   }catch(exception::cuda::base const &){
-#ifdef TRITON_LOG_PTX_ERROR
+//#ifdef TRITON_LOG_PTX_ERROR
     std::cerr << "Compilation Failed! Log: " << std::endl;
     std::cerr << errbuf << std::endl;
-#endif
+//#endif
     throw;
   }
 }
@@ -87,19 +87,18 @@ src = '''
   else {
     int *plock = locks + ridx*nlocks + lockid - 1;
     int *pcount = plock + get_num_program(0)*nlocks;
-    while(__atomic_cas(plock, 0, 1));
+    while(atomic_cas(plock, 0, 1));
     int count = *pcount;
     if(count == 0)
       *?(checkc) pc = c;
     else
       *?(checkc) pc = c + *pc;
-    __atomic_exch(pcount, 1);
-    __atomic_exch(plock, 0);
+    atomic_exch(pcount, 1);
+    atomic_exch(plock, 0);
   }
 }
 '''


 # std::string dot::triton_c_src_dw() const {
 # bool AT = (op_ == WGRAD);
 # bool BT = (op_ == FPROP);
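The renamed builtins above implement a spin-lock-protected reduction across program instances: the first instance to take the lock stores its tile, later ones accumulate into it, and the lock is released afterwards. A minimal single-threaded Python model of that logic, with hypothetical atomic_cas/atomic_exch stand-ins for the hardware atomics used in the Triton-C source:

    import numpy as np

    # Hypothetical stand-ins for the hardware atomics; they return the previous value.
    def atomic_cas(cell, cmp, val):
        old = cell[0]
        if old == cmp:
            cell[0] = val
        return old

    def atomic_exch(cell, val):
        old = cell[0]
        cell[0] = val
        return old

    def commit_tile(c_tile, pc, plock, pcount):
        while atomic_cas(plock, 0, 1):   # spin until the lock is acquired
            pass
        if pcount[0] == 0:
            pc[:] = c_tile               # first writer stores its tile
        else:
            pc[:] += c_tile              # later writers accumulate
        atomic_exch(pcount, 1)           # mark the output tile as initialized
        atomic_exch(plock, 0)            # release the lock

    # Two "program instances" committing partial results to the same output tile.
    out, lock, count = np.zeros(4), [0], [0]
    commit_tile(np.ones(4), out, lock, count)
    commit_tile(np.ones(4), out, lock, count)
    print(out)   # [2. 2. 2. 2.]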
@@ -81,6 +81,7 @@ class dot_op:
     self.transpose_b = transpose_b

   def __call__(self, a, b):
+    dtype = a.dtype
     # extract shapes
     shape_a = triton.shape(a)
     shape_b = triton.shape(b)
@@ -96,13 +97,12 @@ class dot_op:
     ldb = Kb if self.transpose_b else N
     ldc = N
     # allocate output
-    c = triton.empty([M, N])
+    c = triton.empty([M, N], dtype = dtype)
     # compute
     return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
                     lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
-                    AT = self.transpose_a, BT = self.transpose_b, TYPE = tf.float16,
-                    TM = [128], TN = [128], TK = [32])
-
+                    AT = self.transpose_a, BT = self.transpose_b, TYPE = dtype,
+                    TM = [128], TN = [128], TK = [8])

 def dot(a, b, transpose_a = False, transpose_b = False):
   if (transpose_a, transpose_b) not in dot.ops:
@@ -114,20 +114,25 @@ dot.ops = dict()
 def _dot_grad(op, dy):
   a = op.inputs[0]
   b = op.inputs[1]
+  print(op.triton)
   return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]

 def run_dot():
   M, N, K = 128, 128, 128
-  a = tf.placeholder(tf.float16, shape=[M, K])
-  b = tf.placeholder(tf.float16, shape=[N, K])
+  a = tf.placeholder(tf.float32, shape=[M, K])
+  b = tf.placeholder(tf.float32, shape=[N, K])
   c = dot(a, b, transpose_a = False, transpose_b = False)
+  print("LULZ")
+  da, db = tf.gradients(c, [a, b])
+  print(da, db)
+  exit
   # Reference
-  ha = np.random.rand(M, K).astype(np.float16)
-  hb = np.random.rand(K, N).astype(np.float16)
+  ha = np.random.rand(M, K).astype(np.float32)
+  hb = np.random.rand(K, N).astype(np.float32)
   # Run
   sess = tf.InteractiveSession()
   sess.run(tf.global_variables_initializer())
-  result = sess.run([c], feed_dict = {a: ha,
+  result = sess.run([da], feed_dict = {a: ha,
                                       b: hb})[0]
   # Test
   print(result)
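For reference, the gradient registration above returns two transposed dot products for da and db (plus None for the non-differentiable arguments). The standard matmul identities involved can be checked with NumPy alone; this snippet is purely illustrative and independent of the Triton bindings:

    import numpy as np

    M, N, K = 128, 128, 128
    a = np.random.rand(M, K).astype(np.float32)
    b = np.random.rand(K, N).astype(np.float32)
    dy = np.random.rand(M, N).astype(np.float32)   # upstream gradient of c = a @ b

    # Standard backward formulas for a matrix product:
    da = dy @ b.T   # gradient w.r.t. a -- second operand transposed
    db = a.T @ dy   # gradient w.r.t. b -- first operand transposed

    assert da.shape == a.shape and db.shape == b.shape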
@@ -42,17 +42,15 @@ class CMakeBuild(build_ext):
         python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR')
         # tensorflow directories
         import tensorflow as tf
-        tf_include_dirs = tf.sysconfig.get_include()
-        tf_lib_dirs = tf.sysconfig.get_lib()
         tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
-        tf_libs = 'tensorflow_framework'
-
+        tf_include_dirs = tf.sysconfig.get_include()
+        tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
         cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
                       '-DBUILD_TESTS=OFF',
                       '-DBUILD_PYTHON_MODULE=ON',
                       '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
                       '-DTF_INCLUDE_DIRS=' + tf_include_dirs,
-                      '-DTF_LIB_DIRS=' + tf_lib_dirs,
+                      '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
                       '-DTF_LIBS=' + tf_libs,
                       '-DTF_ABI=' + str(tf_abi)]

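The hard-coded 'tensorflow_framework' library name is replaced by whatever the installed TensorFlow reports about itself. Roughly, the sysconfig helpers behave like this; the exact strings and the [-L..., -l...] ordering of get_link_flags() vary across TF releases, so treat the comments as assumptions:

    import tensorflow as tf

    include_dir = tf.sysconfig.get_include()    # header directory shipped with the pip package
    lib_dir     = tf.sysconfig.get_lib()        # directory containing libtensorflow_framework
    link_flags  = tf.sysconfig.get_link_flags() # e.g. ['-L/.../tensorflow', '-ltensorflow_framework']

    # The second entry carries the library name; stripping '-l' yields a value
    # suitable for CMake's -DTF_LIBS / target_link_libraries.
    tf_libs = link_flags[1].replace('-l', '')
    print(include_dir, lib_dir, tf_libs)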
@@ -171,7 +171,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
       break;
     if(idx == args.size())
       throw std::runtime_error("unknown output");
-    os << " .Output(\"out" << i << ": " << to_tf_scalar_ty(args[idx]->get_type()) << "\")\n";
+    os << " .Output(\"out" << i << ": T" << idx << "\")\n";
   }
   os << " .Attr(\"id: int\")" << std::endl;
   os << ";\n";
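With this change the generator emits outputs typed by a per-argument attr ("T<idx>") instead of a concrete scalar type. A hedged sketch of what the emitted registration text ends up looking like (op name and output indices here are hypothetical, not taken from the generator):

    # Sketch: the shape of a generated REGISTER_OP body after the change.
    outputs = [0, 2]   # hypothetical indices of the output arguments
    lines = ['REGISTER_OP("MyTritonOp")']
    for i, idx in enumerate(outputs):
        lines.append('  .Output("out{i}: T{idx}")'.format(i=i, idx=idx))
    lines.append('  .Attr("id: int")')
    print("\n".join(lines) + ";")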
@@ -239,10 +239,6 @@ std::tuple<std::string,
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/util/cuda_kernel_helper.h"
-#include "tensorflow/core/util/padding.h"
-#include "tensorflow/core/util/tensor_format.h"
-#include "tensorflow/core/framework/common_shape_fns.h"

 using namespace tensorflow;
 using GPUDevice = Eigen::GpuDevice;
@@ -26,5 +26,6 @@ class AllocEmptyOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("AllocEmpty").HostMemory("x").Device(DEVICE_CPU).Device(DEVICE_GPU), AllocEmptyOp);
 REGISTER_OP("AllocEmpty")
   .Input("x: int32")
-  .Output("y: float16")
+  .Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
+  .Output("y: T")
 ;
@@ -108,7 +108,7 @@ def _build(src, path, framework):
     library_dirs += [tensorflow.sysconfig.get_lib()]
     include_dirs += [tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
-    libraries += ['tensorflow_framework']
+    libraries += [tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
     ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
     extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
   elif framework == torch_id:
@@ -210,8 +210,20 @@ def _make_grid(args) :
       return result
   return grid


 class op:
+
+  class _definitions_descriptor:
+    def __init__(self):
+      self.values = dict()
+
+    def __set__(self, instance, value):
+      self.values[value[0]] = value[1]
+
+    def __get__(self, instance, owner):
+      return self.values
+
+
   def __init__(self, src, outputs, framework = None):
     self.fw_id = dict()
     self.fw_ops = dict()
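One consequence of storing values on the descriptor itself (rather than on each instance) is that every object seeing the descriptor shares the same dict. A standalone sketch of that behavior; class names here are illustrative, not from the codebase:

    class SharedAttr:
        """Data descriptor keeping (key, value) pairs in one dict shared by all instances."""
        def __init__(self):
            self.values = dict()

        def __set__(self, instance, value):
            self.values[value[0]] = value[1]

        def __get__(self, instance, owner):
            return self.values

    class Node:
        triton = SharedAttr()    # lives on the class, so it is shared

    x, y = Node(), Node()
    x.triton = ('a', 1)
    y.triton = ('b', 2)
    print(x.triton)   # {'a': 1, 'b': 2} -- both assignments land in the same dict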
@@ -219,6 +231,10 @@ class op:
     self.src = src
     self.outputs = outputs
     self.framework = _find_framework(framework)
+    if self.framework == tensorflow_id:
+      _import_tensorflow()
+      tensorflow.Operation.triton = property(op._definitions_descriptor)
+

   def __call__(self, *args, **kwargs):
     # create a new op when defines are different
@@ -253,7 +269,9 @@ class op:
     # create operands
     op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
     # call framework op
-    return op(*op_args, id=op_id)
+    tensor = op(*op_args, id=op_id)
+    tensor.op.triton = ('lol', 1)
+    return tensor


 class register_gradient:
@@ -266,14 +284,14 @@ class register_gradient:
     ops.RegisterGradient(name)(f)


-def empty(shapes, framework = None):
+def empty(shapes, dtype, framework = None):
   framework = _find_framework(framework)
   if framework == tensorflow_id:
     _import_tensorflow()
     _import_tf_extra_ops
     args = [x.handle if isinstance(x, scalar) else x for x in shapes]
     args = tensorflow.stack(args)
-    return tf_extra_ops.alloc_empty(args)
+    return tf_extra_ops.alloc_empty(args, T = dtype)
   elif framework == torch_id:
     _import_torch()
     return torch.empty(*shapes)