From 404dd18333cc5e5df343e95023e423d2ea5ff2e9 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 10 Feb 2020 04:19:17 -0500
Subject: [PATCH] [PYTHON][CORE] Deprecating Tensorflow support

---
 python/examples/einsum.py      |  4 +-
 python/src/bindings.cc         | 33 ++-----------
 python/triton/kernel.py        | 86 ++++++----------------------------
 python/triton/ops/batchnorm.py |  4 +-
 python/triton/ops/einsum.py    |  7 +--
 5 files changed, 26 insertions(+), 108 deletions(-)

diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index d86edf847..340298be5 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
         a = torch.rand(B, M, K).type(dtype).cuda()
         b = torch.rand(B, K, N).type(dtype).cuda()
         tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
-        ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
+        ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
         cmp_str = f'({ratio:4.2f})'
     else:
         cmp_str = ''
     # test and benchmark
-    bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
+    bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
     diff = (tc - rc).abs().max() / rc.abs().max()
     print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')

diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 8b3ee2971..0d9d545bc 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {

 void gen_torch_signature(std::ostringstream& oss,
                          ir::function* fn,
-                         const std::vector<std::string>& outputs,
                          const std::string& name) {
   const auto& args = fn->args();
-  std::vector<ir::type*> out_types;
-  for(const std::string& out: outputs) {
-    auto it = std::find_if(args.begin(), args.end(),
-        [&](ir::argument* arg) { return arg->get_name() == out; });
-    if(it == args.end())
-      throw std::runtime_error("unknown argument");
-    out_types.push_back((*it)->get_type());
-  }
-
-  std::string ret_ty;
-  if(out_types.empty())
-    ret_ty = "void";
-  else{
-    ir::type* ty = out_types[0];
-    ret_ty = to_torch_ty(ty);
-    if(out_types.size() > 1){
-      for(size_t i = 1; i < out_types.size(); i++)
-        if(out_types[i] != ty)
-          throw std::runtime_error("outputs of different types not supported by pytorch");
-      ret_ty = "std::vector<" + ret_ty + ">";
-    }
-  }
-
+  std::string ret_ty = "void";
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
@@ -555,9 +532,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {

 std::tuple<std::string,
            std::string> make_torch_src(const std::string& src,
-                                       const std::vector<std::string>& outputs,
-                                       const std::vector<std::string>& tmp,
                                        const runtime::function::options_space_t& opt) {
   // triton-ir code-gen
   ir::context ctx;
   auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;

)";

-  gen_torch_signature(oss, fn, outputs, name);
+  gen_torch_signature(oss, fn, name);
   oss << " {" << std::endl;
   gen_torch_init_driver(oss, fn->args());
   gen_torch_make_handles(oss, fn->args());
   gen_torch_make_launch_function(oss, fn->args());
-  gen_torch_ret(oss, outputs);
+  //gen_torch_ret(oss);
   oss << "}" << std::endl;
   oss << std::endl;

diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 71b79bb99..edec0af12 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -17,11 +17,9 @@ import triton.frameworks as fw
 import triton.utils
 import triton._C.libtriton as libtriton

-def _make_framework_src(src, out, tmp, grid):
-  if fw.has_tensorflow():
-    return libtriton.make_tensorflow_src(src, out, tmp, grid)
-  elif fw.has_torch:
-    return libtriton.make_torch_src(src, out, tmp, grid)
+def _make_framework_src(src, grid):
+  if fw.has_torch:
+    return libtriton.make_torch_src(src, grid)
   else:
     assert False

@@ -36,9 +34,7 @@ def _make_cache_path(src):
   return cachepath

 def _write_bindings(src, root):
-  if fw.has_tensorflow():
-    name = 'tensorflow'
-  elif fw.has_torch():
+  if fw.has_torch():
     name = 'torch'
   else:
     assert False
@@ -81,15 +77,7 @@ def _build(src, path):
   libraries = ['triton']
   # add framework
   extra_compile_args = []
-  if fw.has_tensorflow():
-    library_dirs += [fw.tensorflow.sysconfig.get_lib()]
-    include_dirs += [fw.tensorflow.sysconfig.get_include()]
-    include_dirs += ['/usr/local/cuda/include/']
-    libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
-    abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
-    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
-    name = 'tensorflow'
-  elif fw.has_torch():
+  if fw.has_torch():
     prefix = os.path.dirname(fw.torch.__file__)
     library_dirs += [os.path.join(prefix, 'lib')]
     include_dirs += ['/usr/local/cuda/include/',
@@ -138,18 +126,8 @@ def _cvt_to_def_str(obj):
   # bool
   if isinstance(obj, bool):
     return str(int(obj))
-  # tensorflow type
-  if fw.has_tensorflow():
-    if isinstance(obj, fw.tensorflow.DType):
-      return {fw.tensorflow.int8: 'char',
-              fw.tensorflow.int16: 'short',
-              fw.tensorflow.int32: 'int',
-              fw.tensorflow.int64: 'long',
-              fw.tensorflow.float16: 'half',
-              fw.tensorflow.float32: 'float',
-              fw.tensorflow.float64: 'double'}[obj]
   # torch type
-  elif fw.has_torch():
+  if fw.has_torch():
     if isinstance(obj, fw.torch.dtype):
       return {fw.torch.int8: 'char',
               fw.torch.int16: 'short',
@@ -164,14 +142,12 @@ def _cvt_to_def_str(obj):
   return str(obj)

-def _make_framework_op(src, outputs, tmp, options):
-  src, name = _make_framework_src(src, outputs, tmp, options)
+def _make_framework_op(src, options):
+  src, name = _make_framework_src(src, options)
   cache_path = _make_cache_path(src)
   cpp, so = _write_bindings(src, cache_path)
   _build(cpp, cache_path)
-  if fw.has_tensorflow():
-    return fw.tensorflow.load_op_library(so).__dict__[name]
-  elif fw.has_torch():
+  if fw.has_torch():
     fw.torch.ops.load_library(so)
     return getattr(fw.torch.ops.triton, name)
   else:
     assert False
@@ -193,13 +169,11 @@ bench_registry = triton.utils.id_dict()

 class kernel:

-  def __init__(self, src, outputs, tmp=[]):
+  def __init__(self, src):
     self.fw_id = dict()
     self.fw_grids = dict()
     self.fw_op = None
     self.src = src
-    self.outputs = outputs
-    self.tmp = tmp
     self.cst = dict()

   def set_constant(self, name, value):
@@ -245,7 +219,7 @@ class kernel:
     for name, value in self.cst.items():
       libtriton.register_cst(op_id, name, value)
     if self.fw_op is None:
-      self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
+      self.fw_op = _make_framework_op(self.src, opt)

     ########################
     # initialize
@@ -254,45 +228,13 @@ class kernel:
     ########################
     libtriton.register_grid(op_id, grid)
     bench_id = libtriton.make_scalar_id() if bench > 0 else -1

-    #########################
-    # call framework function
-    #########################
-    if fw.has_tensorflow():
-      empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
-      if len(empty) != len(self.outputs):
-        raise ValueError('Number of empty arguments does not much number of outputs provided')
-      # operands
-      operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
-      # output data types
-      kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
-      for i, x in enumerate(args):
-        if isinstance(x, triton.utils.tf_empty_proxy):
-          kwargs['T' + str(i)] = x.dtype
-      # launch
-      ret = self.fw_op(*operands, **kwargs)
-      ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
-      op_def = ret[0].op.op_def
-      # fill empty tensors with corresponding values
-      for j, y in enumerate(op_def.output_arg):
-        found = False
-        for i, x in enumerate(op_def.input_arg):
-          if y.name + '_shape' == x.name:
-            args[i].tensor = ret[j]
-            found = True
-        assert found
-      # store timing information
-      if bench > 0:
-        for y in ret:
-          bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
-
     ############################
     # call torch function
     ############################
-    elif fw.has_torch():
-      args = [x if isinstance(x, fw.torch.Tensor) else x for x in args]
-      ret = self.fw_op(op_id, bench, bench_id, *args)
+    if fw.has_torch():
+      self.fw_op(op_id, bench, bench_id, *args)
       if bench > 0:
-        bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
+        return libtriton.retrieve_scalar(bench_id)
     else:
       assert False
\ No newline at end of file
diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py
index 117cca3b1..3d415c7bb 100644
--- a/python/triton/ops/batchnorm.py
+++ b/python/triton/ops/batchnorm.py
@@ -41,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
     }
   }
 }
 """
-  fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
+  fwd_kernel = triton.kernel(fwd_src)

   bwd_src = """
 void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -88,7 +88,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
     }
   }
 }
 """
-  bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
+  bwd_kernel = triton.kernel(bwd_src)

   @staticmethod
   def forward(ctx, x, gamma, beta, eps):
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index 005ea8812..8ebf9f439 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -313,7 +313,7 @@ __global__ void {name}(
 """
     #print(src)
-    ret = triton.kernel(src, ['C'])
+    ret = triton.kernel(src)
     if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
       ret.set_constant('AD', delta_a)
     if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
@@ -563,7 +563,7 @@ __global__ void {name}(
     self.args[self.pos_a] = a
     self.args[self.pos_b] = b
     self.args[self.pos_c] = c
-    self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
+    return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)

@@ -591,7 +591,7 @@ __global__ void {name}(
                                     a.stride(), b.stride(), c.stride(),
                                     a.shape, b.shape, c.shape, arrays)
       instance = cache[key]
-      instance.run(a, b, c, bench)
+      speed = instance.run(a, b, c, bench)
       # save information in context
       ctx.flops = instance.flops
       ctx.sym_a = instance.sym_a
@@ -602,6 +602,7 @@ __global__ void {name}(
       ctx.matmul_N = instance.matmul_N
       ctx.matmul_K = instance.matmul_K
       ctx.bench = bench
+      ctx.forward_ms = speed
       ctx.save_for_backward(a, b)
      ctx.save_for_backward(a, b)
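
Usage note: with this patch, triton.kernel is constructed from Triton-C source
alone (the `outputs`/`tmp` lists are gone), and a kernel call made with
bench > 0 now returns the benchmark scalar read back through
libtriton.retrieve_scalar(bench_id) instead of populating
triton.bench_registry; with bench == 0 the call returns None. Below is a
minimal sketch of the new call pattern, not code from this patch: the copy
kernel, the TILE/TYPE macros, and the opt.d('TILE') grid callback are
illustrative assumptions about this revision's API, modeled on the
self.kernel(*args, grid=..., bench=..., defines=...) call in
triton/ops/einsum.py above.

    import torch
    import triton

    # Illustrative Triton-C source: outputs are ordinary pointer arguments;
    # they are no longer declared separately when the kernel is created.
    src = """
    __global__ void copy(TYPE *X, TYPE *Y) {
      int pid = get_program_id(0);
      int idx[TILE] = pid * TILE + 0 ... TILE;
      *(Y + idx) = *(X + idx);
    }
    """

    kernel = triton.kernel(src)  # before this patch: triton.kernel(src, ['Y'])

    N = 4096
    x = torch.rand(N, device='cuda')
    y = torch.empty_like(x)

    # With bench > 0, __call__ returns the timing recorded under bench_id;
    # otherwise it returns None and y is simply written in place.
    # opt.d('TILE') (reading a macro value from the autotuning options) is
    # an assumed accessor at this revision.
    time = kernel(x, y,
                  grid=lambda opt: [N // opt.d('TILE')],
                  bench=100,
                  defines={'TILE': '128', 'TYPE': 'float'})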