[PYTHON][CORE] Deprecating Tensorflow support

This commit is contained in:
Philippe Tillet
2020-02-10 04:19:17 -05:00
committed by Philippe Tillet
parent d7a781dd40
commit 404dd18333
5 changed files with 26 additions and 108 deletions

View File

@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
a = torch.rand(B, M, K).type(dtype).cuda() a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda() b = torch.rand(B, K, N).type(dtype).cuda()
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True) tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc] ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
cmp_str = f'({ratio:4.2f})' cmp_str = f'({ratio:4.2f})'
else: else:
cmp_str = '' cmp_str = ''
# test and benchmark # test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3 bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max() diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}') print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')

View File

@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {
void gen_torch_signature(std::ostringstream& oss, void gen_torch_signature(std::ostringstream& oss,
ir::function* fn, ir::function* fn,
const std::vector<std::string>& outputs,
const std::string& name) { const std::string& name) {
const auto& args = fn->args(); const auto& args = fn->args();
std::vector<ir::type*> out_types; std::string ret_ty = "void";
for(const std::string& out: outputs) {
auto it = std::find_if(args.begin(), args.end(),
[&](ir::argument* arg) { return arg->get_name() == out; });
if(it == args.end())
throw std::runtime_error("unknown argument");
out_types.push_back((*it)->get_type());
}
std::string ret_ty;
if(out_types.empty())
ret_ty = "void";
else{
ir::type* ty = out_types[0];
ret_ty = to_torch_ty(ty);
if(out_types.size() > 1){
for(size_t i = 1; i < out_types.size(); i++)
if(out_types[i] != ty)
throw std::runtime_error("outputs of different types not supported by pytorch");
ret_ty = "std::vector<" + ret_ty + ">";
}
}
oss << ret_ty << " " << name << "("; oss << ret_ty << " " << name << "(";
oss << "int64_t id, "; oss << "int64_t id, ";
oss << "int64_t bench, "; oss << "int64_t bench, ";
@@ -555,8 +532,6 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
std::tuple<std::string, std::tuple<std::string,
std::string> make_torch_src(const std::string& src, std::string> make_torch_src(const std::string& src,
const std::vector<std::string>& outputs,
const std::vector<std::string>& tmp,
const runtime::function::options_space_t& opt) { const runtime::function::options_space_t& opt) {
// triton-ir code-gen // triton-ir code-gen
ir::context ctx; ir::context ctx;
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;
)"; )";
gen_torch_signature(oss, fn, outputs, name); gen_torch_signature(oss, fn, name);
oss << " {" << std::endl; oss << " {" << std::endl;
gen_torch_init_driver(oss, fn->args()); gen_torch_init_driver(oss, fn->args());
gen_torch_make_handles(oss, fn->args()); gen_torch_make_handles(oss, fn->args());
gen_torch_make_launch_function(oss, fn->args()); gen_torch_make_launch_function(oss, fn->args());
gen_torch_ret(oss, outputs); //gen_torch_ret(oss);
oss << "}" << std::endl; oss << "}" << std::endl;
oss << std::endl; oss << std::endl;

View File

@@ -17,11 +17,9 @@ import triton.frameworks as fw
import triton.utils import triton.utils
import triton._C.libtriton as libtriton import triton._C.libtriton as libtriton
def _make_framework_src(src, out, tmp, grid): def _make_framework_src(src, grid):
if fw.has_tensorflow(): if fw.has_torch:
return libtriton.make_tensorflow_src(src, out, tmp, grid) return libtriton.make_torch_src(src, grid)
elif fw.has_torch:
return libtriton.make_torch_src(src, out, tmp, grid)
else: else:
assert False assert False
@@ -36,9 +34,7 @@ def _make_cache_path(src):
return cachepath return cachepath
def _write_bindings(src, root): def _write_bindings(src, root):
if fw.has_tensorflow(): if fw.has_torch():
name = 'tensorflow'
elif fw.has_torch():
name = 'torch' name = 'torch'
else: else:
assert False assert False
@@ -81,15 +77,7 @@ def _build(src, path):
libraries = ['triton'] libraries = ['triton']
# add framework # add framework
extra_compile_args = [] extra_compile_args = []
if fw.has_tensorflow(): if fw.has_torch():
library_dirs += [fw.tensorflow.sysconfig.get_lib()]
include_dirs += [fw.tensorflow.sysconfig.get_include()]
include_dirs += ['/usr/local/cuda/include/']
libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
name = 'tensorflow'
elif fw.has_torch():
prefix = os.path.dirname(fw.torch.__file__) prefix = os.path.dirname(fw.torch.__file__)
library_dirs += [os.path.join(prefix, 'lib')] library_dirs += [os.path.join(prefix, 'lib')]
include_dirs += ['/usr/local/cuda/include/', include_dirs += ['/usr/local/cuda/include/',
@@ -138,18 +126,8 @@ def _cvt_to_def_str(obj):
# bool # bool
if isinstance(obj, bool): if isinstance(obj, bool):
return str(int(obj)) return str(int(obj))
# tensorflow type
if fw.has_tensorflow():
if isinstance(obj, fw.tensorflow.DType):
return {fw.tensorflow.int8: 'char',
fw.tensorflow.int16: 'short',
fw.tensorflow.int32: 'int',
fw.tensorflow.int64: 'long',
fw.tensorflow.float16: 'half',
fw.tensorflow.float32: 'float',
fw.tensorflow.float64: 'double'}[obj]
# torch type # torch type
elif fw.has_torch(): if fw.has_torch():
if isinstance(obj, fw.torch.dtype): if isinstance(obj, fw.torch.dtype):
return {fw.torch.int8: 'char', return {fw.torch.int8: 'char',
fw.torch.int16: 'short', fw.torch.int16: 'short',
@@ -164,14 +142,12 @@ def _cvt_to_def_str(obj):
return str(obj) return str(obj)
def _make_framework_op(src, outputs, tmp, options): def _make_framework_op(src, options):
src, name = _make_framework_src(src, outputs, tmp, options) src, name = _make_framework_src(src, options)
cache_path = _make_cache_path(src) cache_path = _make_cache_path(src)
cpp, so = _write_bindings(src, cache_path) cpp, so = _write_bindings(src, cache_path)
_build(cpp, cache_path) _build(cpp, cache_path)
if fw.has_tensorflow(): if fw.has_torch():
return fw.tensorflow.load_op_library(so).__dict__[name]
elif fw.has_torch():
fw.torch.ops.load_library(so) fw.torch.ops.load_library(so)
return getattr(fw.torch.ops.triton, name) return getattr(fw.torch.ops.triton, name)
else: else:
@@ -193,13 +169,11 @@ bench_registry = triton.utils.id_dict()
class kernel: class kernel:
def __init__(self, src, outputs, tmp=[]): def __init__(self, src):
self.fw_id = dict() self.fw_id = dict()
self.fw_grids = dict() self.fw_grids = dict()
self.fw_op = None self.fw_op = None
self.src = src self.src = src
self.outputs = outputs
self.tmp = tmp
self.cst = dict() self.cst = dict()
def set_constant(self, name, value): def set_constant(self, name, value):
@@ -245,7 +219,7 @@ class kernel:
for name, value in self.cst.items(): for name, value in self.cst.items():
libtriton.register_cst(op_id, name, value) libtriton.register_cst(op_id, name, value)
if self.fw_op is None: if self.fw_op is None:
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt) self.fw_op = _make_framework_op(self.src, opt)
######################## ########################
# initialize # initialize
@@ -254,45 +228,13 @@ class kernel:
libtriton.register_grid(op_id, grid) libtriton.register_grid(op_id, grid)
bench_id = libtriton.make_scalar_id() if bench > 0 else -1 bench_id = libtriton.make_scalar_id() if bench > 0 else -1
#########################
# call framework function
#########################
if fw.has_tensorflow():
empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
if len(empty) != len(self.outputs):
raise ValueError('Number of empty arguments does not much number of outputs provided')
# operands
operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
# output data types
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
for i, x in enumerate(args):
if isinstance(x, triton.utils.tf_empty_proxy):
kwargs['T' + str(i)] = x.dtype
# launch
ret = self.fw_op(*operands, **kwargs)
ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
op_def = ret[0].op.op_def
# fill empty tensors with corresponding values
for j, y in enumerate(op_def.output_arg):
found = False
for i, x in enumerate(op_def.input_arg):
if y.name + '_shape' == x.name:
args[i].tensor = ret[j]
found = True
assert found
# store timing information
if bench > 0:
for y in ret:
bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
############################ ############################
# call torch function # call torch function
############################ ############################
elif fw.has_torch(): if fw.has_torch():
args = [x if isinstance(x, fw.torch.Tensor) else x for x in args] self.fw_op(op_id, bench, bench_id, *args)
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0: if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(bench_id) return libtriton.retrieve_scalar(bench_id)
else: else:
assert False assert False

View File

@@ -41,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
} }
} }
""" """
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V']) fwd_kernel = triton.kernel(fwd_src)
bwd_src = """ bwd_src = """
void bwdbatchnorm(float *DX, float *DG, float *DB, void bwdbatchnorm(float *DX, float *DG, float *DB,
@@ -88,7 +88,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
} }
} }
""" """
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB']) bwd_kernel = triton.kernel(bwd_src)
@staticmethod @staticmethod
def forward(ctx, x, gamma, beta, eps): def forward(ctx, x, gamma, beta, eps):

View File

@@ -313,7 +313,7 @@ __global__ void {name}(
""" """
#print(src) #print(src)
ret = triton.kernel(src, ['C']) ret = triton.kernel(src)
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
ret.set_constant('AD', delta_a) ret.set_constant('AD', delta_a)
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT: if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
@@ -563,7 +563,7 @@ __global__ void {name}(
self.args[self.pos_a] = a self.args[self.pos_a] = a
self.args[self.pos_b] = b self.args[self.pos_b] = b
self.args[self.pos_c] = c self.args[self.pos_c] = c
self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros) return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
@@ -591,7 +591,7 @@ __global__ void {name}(
a.stride(), b.stride(), c.stride(), a.stride(), b.stride(), c.stride(),
a.shape, b.shape, c.shape, arrays) a.shape, b.shape, c.shape, arrays)
instance = cache[key] instance = cache[key]
instance.run(a, b, c, bench) speed = instance.run(a, b, c, bench)
# save information in context # save information in context
ctx.flops = instance.flops ctx.flops = instance.flops
ctx.sym_a = instance.sym_a ctx.sym_a = instance.sym_a
@@ -602,6 +602,7 @@ __global__ void {name}(
ctx.matmul_N = instance.matmul_N ctx.matmul_N = instance.matmul_N
ctx.matmul_K = instance.matmul_K ctx.matmul_K = instance.matmul_K
ctx.bench = bench ctx.bench = bench
ctx.forward_ms = speed
ctx.save_for_backward(a, b) ctx.save_for_backward(a, b)
return c return c