[PYTHON][CORE] Deprecating Tensorflow support

Philippe Tillet
2020-02-10 04:19:17 -05:00
committed by Philippe Tillet
parent d7a781dd40
commit 404dd18333
5 changed files with 26 additions and 108 deletions

View File

@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
         tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
-        ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
+        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
     # test and benchmark
-    bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
+    bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
     diff = (tc - rc).abs().max() / rc.abs().max()
     print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
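
For context: this test exercises the new timing idiom introduced by the commit. Where timings used to be looked up in triton.bench_registry, they now live on the autograd context object that einsum's forward() stores in triton.ctx_registry, keyed by the output tensor. A minimal sketch of the pattern (hypothetical shapes; same call structure as the test above):

    import torch
    import triton

    B, M, N, K = 16, 256, 256, 256
    a = torch.rand(B, M, K).cuda()
    b = torch.rand(B, K, N).cuda()
    # bench=True asks the launcher to time the forward kernel
    c = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench=True)
    ctx = triton.ctx_registry[c]   # autograd ctx saved by einsum's forward()
    print(ctx.forward_ms)          # kernel time recorded because bench=True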

View File

@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {
 void gen_torch_signature(std::ostringstream& oss,
                          ir::function* fn,
-                         const std::vector<std::string>& outputs,
                          const std::string& name) {
   const auto& args = fn->args();
-  std::vector<ir::type*> out_types;
-  for(const std::string& out: outputs) {
-    auto it = std::find_if(args.begin(), args.end(),
-                           [&](ir::argument* arg) { return arg->get_name() == out; });
-    if(it == args.end())
-      throw std::runtime_error("unknown argument");
-    out_types.push_back((*it)->get_type());
-  }
-  std::string ret_ty;
-  if(out_types.empty())
-    ret_ty = "void";
-  else{
-    ir::type* ty = out_types[0];
-    ret_ty = to_torch_ty(ty);
-    if(out_types.size() > 1){
-      for(size_t i = 1; i < out_types.size(); i++)
-        if(out_types[i] != ty)
-          throw std::runtime_error("outputs of different types not supported by pytorch");
-      ret_ty = "std::vector<" + ret_ty + ">";
-    }
-  }
+  std::string ret_ty = "void";
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
@@ -555,9 +532,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
 std::tuple<std::string,
            std::string> make_torch_src(const std::string& src,
-                                       const std::vector<std::string>& outputs,
-                                       const std::vector<std::string>& tmp,
-                                       const runtime::function::options_space_t& opt) {
+                                       const runtime::function::options_space_t& opt) {
   // triton-ir code-gen
   ir::context ctx;
   auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;
)";
gen_torch_signature(oss, fn, outputs, name);
gen_torch_signature(oss, fn, name);
oss << " {" << std::endl;
gen_torch_init_driver(oss, fn->args());
gen_torch_make_handles(oss, fn->args());
gen_torch_make_launch_function(oss, fn->args());
gen_torch_ret(oss, outputs);
//gen_torch_ret(oss);
oss << "}" << std::endl;
oss << std::endl;
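
Net effect on the generated binding: the wrapper's return type is hardcoded to "void" and the output-marshalling epilogue (gen_torch_ret) is commented out, so results reach the caller only through the tensors passed in. From the Python side, a call to a generated op now looks roughly like this (a sketch; my_kernel and the tensor arguments are hypothetical):

    # The op is invoked for its side effects and returns nothing;
    # y was allocated by the caller and is written in place by the kernel.
    torch.ops.triton.my_kernel(op_id, bench, bench_id, x, y)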

View File

@@ -17,11 +17,9 @@ import triton.frameworks as fw
 import triton.utils
 import triton._C.libtriton as libtriton

-def _make_framework_src(src, out, tmp, grid):
-  if fw.has_tensorflow():
-    return libtriton.make_tensorflow_src(src, out, tmp, grid)
-  elif fw.has_torch:
-    return libtriton.make_torch_src(src, out, tmp, grid)
+def _make_framework_src(src, grid):
+  if fw.has_torch:
+    return libtriton.make_torch_src(src, grid)
   else:
     assert False
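
With the TensorFlow branch gone, source generation has a single backend. The helper still returns the same pair: the C++ extension source emitted by libtriton.make_torch_src, and the name under which the op will be registered. A hedged usage sketch (kernel_src and grid are placeholders):

    # cpp_src: generated torch-extension source; op_name: registered op name
    cpp_src, op_name = _make_framework_src(kernel_src, grid)
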
@@ -36,9 +34,7 @@ def _make_cache_path(src):
   return cachepath

 def _write_bindings(src, root):
-  if fw.has_tensorflow():
-    name = 'tensorflow'
-  elif fw.has_torch():
+  if fw.has_torch():
     name = 'torch'
   else:
     assert False
@@ -81,15 +77,7 @@ def _build(src, path):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if fw.has_tensorflow():
-    library_dirs += [fw.tensorflow.sysconfig.get_lib()]
-    include_dirs += [fw.tensorflow.sysconfig.get_include()]
-    include_dirs += ['/usr/local/cuda/include/']
-    libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
-    abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
-    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-    name = 'tensorflow'
-  elif fw.has_torch():
+  if fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -138,18 +126,8 @@ def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
-  # tensorflow type
-  if fw.has_tensorflow():
-    if isinstance(obj, fw.tensorflow.DType):
-      return {fw.tensorflow.int8: 'char',
-              fw.tensorflow.int16: 'short',
-              fw.tensorflow.int32: 'int',
-              fw.tensorflow.int64: 'long',
-              fw.tensorflow.float16: 'half',
-              fw.tensorflow.float32: 'float',
-              fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif fw.has_torch():
+  if fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
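
These conversions produce the strings used for the kernel's compile-time defines (hence the name), mapping a framework dtype to the C type name that appears in the Triton-C source. A quick illustration using only the branches visible above (assuming torch is importable):

    import torch
    assert _cvt_to_def_str(torch.int8)  == 'char'
    assert _cvt_to_def_str(torch.int16) == 'short'
    assert _cvt_to_def_str(True)        == '1'   # bool branch: str(int(True))
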
@@ -164,14 +142,12 @@ def _cvt_to_def_str(obj):
   return str(obj)

-def _make_framework_op(src, outputs, tmp, options):
-  src, name = _make_framework_src(src, outputs, tmp, options)
+def _make_framework_op(src, options):
+  src, name = _make_framework_src(src, options)
   cache_path = _make_cache_path(src)
   cpp, so = _write_bindings(src, cache_path)
   _build(cpp, cache_path)
-  if fw.has_tensorflow():
-    return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif fw.has_torch():
+  if fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
     assert False
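
End to end, _make_framework_op now always takes the Torch extension path: generate the binding source, cache it by content, compile it, then load it through torch.ops. Condensed (same helper names as above):

    cpp_src, name = _make_framework_src(src, options)  # generate C++ bindings
    cache_path = _make_cache_path(cpp_src)             # content-keyed cache dir
    cpp, so = _write_bindings(cpp_src, cache_path)     # write .cpp / .so paths
    _build(cpp, cache_path)                            # compile the extension
    fw.torch.ops.load_library(so)                      # register with torch
    op = getattr(fw.torch.ops.triton, name)            # the bound callable
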
@@ -193,13 +169,11 @@ bench_registry = triton.utils.id_dict()
 class kernel:

-  def __init__(self, src, outputs, tmp=[]):
+  def __init__(self, src):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
-    self.outputs = outputs
-    self.tmp = tmp
     self.cst = dict()

   def set_constant(self, name, value):
@@ -245,7 +219,7 @@ class kernel:
     for name, value in self.cst.items():
       libtriton.register_cst(op_id, name, value)
     if self.fw_op is None:
-      self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
+      self.fw_op = _make_framework_op(self.src, opt)
########################
# initialize
@@ -254,45 +228,13 @@ class kernel:
     libtriton.register_grid(op_id, grid)
     bench_id = libtriton.make_scalar_id() if bench > 0 else -1
-    #########################
-    # call framework function
-    #########################
-    if fw.has_tensorflow():
-      empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
-      if len(empty) != len(self.outputs):
-        raise ValueError('Number of empty arguments does not much number of outputs provided')
-      # operands
-      operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
-      # output data types
-      kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
-      for i, x in enumerate(args):
-        if isinstance(x, triton.utils.tf_empty_proxy):
-          kwargs['T' + str(i)] = x.dtype
-      # launch
-      ret = self.fw_op(*operands, **kwargs)
-      ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
-      op_def = ret[0].op.op_def
-      # fill empty tensors with corresponding values
-      for j, y in enumerate(op_def.output_arg):
-        found = False
-        for i, x in enumerate(op_def.input_arg):
-          if y.name + '_shape' == x.name:
-            args[i].tensor = ret[j]
-            found = True
-        assert found
-      # store timing information
-      if bench > 0:
-        for y in ret:
-          bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
     ############################
     # call torch function
     ############################
-    elif fw.has_torch():
-      args = [x if isinstance(x, fw.torch.Tensor) else x for x in args]
-      ret = self.fw_op(op_id, bench, bench_id, *args)
+    if fw.has_torch():
+      self.fw_op(op_id, bench, bench_id, *args)
       if bench > 0:
-        bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
+        return libtriton.retrieve_scalar(bench_id)
     else:
       assert False
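
The torch call path is now: launch the op for its side effects (outputs are written in place) and, when benchmarking, return the time that the launcher stored under bench_id. What a caller of kernel.__call__ sees, roughly (hypothetical kernel k, input x, pre-allocated output y):

    y = torch.empty_like(x)             # outputs are ordinary arguments now
    ms = k(x, y, grid=grid, bench=10)   # bench > 0: returns the measured time
    k(x, y, grid=grid)                  # bench == 0: returns None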

View File

@@ -41,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
   }
 }
 """
-    fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+    fwd_kernel = triton.kernel(fwd_src)

     bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -88,7 +88,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
   }
 }
 """
-    bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
+    bwd_kernel = triton.kernel(bwd_src)

     @staticmethod
     def forward(ctx, x, gamma, beta, eps):
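
Since kernels no longer declare (or return) outputs, forward() is responsible for allocating Y, M and V itself and passing them as ordinary arguments, matching the pointer order of fwdbatchnorm above. A hedged sketch of the new call (shapes and trailing arguments are hypothetical):

    y    = torch.empty_like(x)               # outputs allocated by the caller
    mean = torch.empty(C, device=x.device)   # C: number of channels (assumed)
    var  = torch.empty(C, device=x.device)
    _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, eps, grid=grid)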

View File

@@ -313,7 +313,7 @@ __global__ void {name}(
"""
#print(src)
ret = triton.kernel(src, ['C'])
ret = triton.kernel(src)
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
@@ -563,7 +563,7 @@ __global__ void {name}(
     self.args[self.pos_a] = a
     self.args[self.pos_b] = b
     self.args[self.pos_c] = c
-    self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
+    return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
@@ -591,7 +591,7 @@ __global__ void {name}(
                                a.stride(), b.stride(), c.stride(),
                                a.shape, b.shape, c.shape, arrays)
       instance = cache[key]
-    instance.run(a, b, c, bench)
+    speed = instance.run(a, b, c, bench)
     # save information in context
     ctx.flops = instance.flops
     ctx.sym_a = instance.sym_a
@@ -602,6 +602,7 @@ __global__ void {name}(
     ctx.matmul_N = instance.matmul_N
     ctx.matmul_K = instance.matmul_K
     ctx.bench = bench
+    ctx.forward_ms = speed
     ctx.save_for_backward(a, b)
     return c