[PYTHON][CORE] Deprecating Tensorflow support
This commit is contained in:
committed by
Philippe Tillet
parent
d7a781dd40
commit
404dd18333
@@ -187,11 +187,11 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
||||
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
|
||||
ratio = triton.ctx_registry[tmmc].forward_ms / ctx.forward_ms
|
||||
cmp_str = f'({ratio:4.2f})'
|
||||
else:
|
||||
cmp_str = ''
|
||||
# test and benchmark
|
||||
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
|
||||
bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3
|
||||
diff = (tc - rc).abs().max() / rc.abs().max()
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
|
||||
|
@@ -449,32 +449,9 @@ inline std::string to_c_ty(ir::type *ty) {
|
||||
|
||||
void gen_torch_signature(std::ostringstream& oss,
|
||||
ir::function* fn,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::string& name) {
|
||||
const auto& args = fn->args();
|
||||
std::vector<ir::type*> out_types;
|
||||
for(const std::string& out: outputs) {
|
||||
auto it = std::find_if(args.begin(), args.end(),
|
||||
[&](ir::argument* arg) { return arg->get_name() == out; });
|
||||
if(it == args.end())
|
||||
throw std::runtime_error("unknown argument");
|
||||
out_types.push_back((*it)->get_type());
|
||||
}
|
||||
|
||||
std::string ret_ty;
|
||||
if(out_types.empty())
|
||||
ret_ty = "void";
|
||||
else{
|
||||
ir::type* ty = out_types[0];
|
||||
ret_ty = to_torch_ty(ty);
|
||||
if(out_types.size() > 1){
|
||||
for(size_t i = 1; i < out_types.size(); i++)
|
||||
if(out_types[i] != ty)
|
||||
throw std::runtime_error("outputs of different types not supported by pytorch");
|
||||
ret_ty = "std::vector<" + ret_ty + ">";
|
||||
}
|
||||
}
|
||||
|
||||
std::string ret_ty = "void";
|
||||
oss << ret_ty << " " << name << "(";
|
||||
oss << "int64_t id, ";
|
||||
oss << "int64_t bench, ";
|
||||
@@ -555,9 +532,7 @@ void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
|
||||
|
||||
std::tuple<std::string,
|
||||
std::string> make_torch_src(const std::string& src,
|
||||
const std::vector<std::string>& outputs,
|
||||
const std::vector<std::string>& tmp,
|
||||
const runtime::function::options_space_t& opt) {
|
||||
const runtime::function::options_space_t& opt) {
|
||||
// triton-ir code-gen
|
||||
ir::context ctx;
|
||||
auto ir = std::shared_ptr<ir::module>(new ir::module("", ctx));
|
||||
@@ -588,12 +563,12 @@ extern std::map<size_t, int64_t> i64scalar_map;
|
||||
|
||||
)";
|
||||
|
||||
gen_torch_signature(oss, fn, outputs, name);
|
||||
gen_torch_signature(oss, fn, name);
|
||||
oss << " {" << std::endl;
|
||||
gen_torch_init_driver(oss, fn->args());
|
||||
gen_torch_make_handles(oss, fn->args());
|
||||
gen_torch_make_launch_function(oss, fn->args());
|
||||
gen_torch_ret(oss, outputs);
|
||||
//gen_torch_ret(oss);
|
||||
oss << "}" << std::endl;
|
||||
|
||||
oss << std::endl;
|
||||
|
@@ -17,11 +17,9 @@ import triton.frameworks as fw
|
||||
import triton.utils
|
||||
import triton._C.libtriton as libtriton
|
||||
|
||||
def _make_framework_src(src, out, tmp, grid):
|
||||
if fw.has_tensorflow():
|
||||
return libtriton.make_tensorflow_src(src, out, tmp, grid)
|
||||
elif fw.has_torch:
|
||||
return libtriton.make_torch_src(src, out, tmp, grid)
|
||||
def _make_framework_src(src, grid):
|
||||
if fw.has_torch:
|
||||
return libtriton.make_torch_src(src, grid)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -36,9 +34,7 @@ def _make_cache_path(src):
|
||||
return cachepath
|
||||
|
||||
def _write_bindings(src, root):
|
||||
if fw.has_tensorflow():
|
||||
name = 'tensorflow'
|
||||
elif fw.has_torch():
|
||||
if fw.has_torch():
|
||||
name = 'torch'
|
||||
else:
|
||||
assert False
|
||||
@@ -81,15 +77,7 @@ def _build(src, path):
|
||||
libraries = ['triton']
|
||||
# add framework
|
||||
extra_compile_args = []
|
||||
if fw.has_tensorflow():
|
||||
library_dirs += [fw.tensorflow.sysconfig.get_lib()]
|
||||
include_dirs += [fw.tensorflow.sysconfig.get_include()]
|
||||
include_dirs += ['/usr/local/cuda/include/']
|
||||
libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
|
||||
abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
|
||||
extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)]
|
||||
name = 'tensorflow'
|
||||
elif fw.has_torch():
|
||||
if fw.has_torch():
|
||||
prefix = os.path.dirname(fw.torch.__file__)
|
||||
library_dirs += [os.path.join(prefix, 'lib')]
|
||||
include_dirs += ['/usr/local/cuda/include/',
|
||||
@@ -138,18 +126,8 @@ def _cvt_to_def_str(obj):
|
||||
# bool
|
||||
if isinstance(obj, bool):
|
||||
return str(int(obj))
|
||||
# tensorflow type
|
||||
if fw.has_tensorflow():
|
||||
if isinstance(obj, fw.tensorflow.DType):
|
||||
return {fw.tensorflow.int8: 'char',
|
||||
fw.tensorflow.int16: 'short',
|
||||
fw.tensorflow.int32: 'int',
|
||||
fw.tensorflow.int64: 'long',
|
||||
fw.tensorflow.float16: 'half',
|
||||
fw.tensorflow.float32: 'float',
|
||||
fw.tensorflow.float64: 'double'}[obj]
|
||||
# torch type
|
||||
elif fw.has_torch():
|
||||
if fw.has_torch():
|
||||
if isinstance(obj, fw.torch.dtype):
|
||||
return {fw.torch.int8: 'char',
|
||||
fw.torch.int16: 'short',
|
||||
@@ -164,14 +142,12 @@ def _cvt_to_def_str(obj):
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _make_framework_op(src, outputs, tmp, options):
|
||||
src, name = _make_framework_src(src, outputs, tmp, options)
|
||||
def _make_framework_op(src, options):
|
||||
src, name = _make_framework_src(src, options)
|
||||
cache_path = _make_cache_path(src)
|
||||
cpp, so = _write_bindings(src, cache_path)
|
||||
_build(cpp, cache_path)
|
||||
if fw.has_tensorflow():
|
||||
return fw.tensorflow.load_op_library(so).__dict__[name]
|
||||
elif fw.has_torch():
|
||||
if fw.has_torch():
|
||||
fw.torch.ops.load_library(so)
|
||||
return getattr(fw.torch.ops.triton, name)
|
||||
else:
|
||||
@@ -193,13 +169,11 @@ bench_registry = triton.utils.id_dict()
|
||||
|
||||
class kernel:
|
||||
|
||||
def __init__(self, src, outputs, tmp=[]):
|
||||
def __init__(self, src):
|
||||
self.fw_id = dict()
|
||||
self.fw_grids = dict()
|
||||
self.fw_op = None
|
||||
self.src = src
|
||||
self.outputs = outputs
|
||||
self.tmp = tmp
|
||||
self.cst = dict()
|
||||
|
||||
def set_constant(self, name, value):
|
||||
@@ -245,7 +219,7 @@ class kernel:
|
||||
for name, value in self.cst.items():
|
||||
libtriton.register_cst(op_id, name, value)
|
||||
if self.fw_op is None:
|
||||
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
|
||||
self.fw_op = _make_framework_op(self.src, opt)
|
||||
|
||||
########################
|
||||
# initialize
|
||||
@@ -254,45 +228,13 @@ class kernel:
|
||||
libtriton.register_grid(op_id, grid)
|
||||
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
||||
|
||||
#########################
|
||||
# call framework function
|
||||
#########################
|
||||
if fw.has_tensorflow():
|
||||
empty = [x for x in args if isinstance(x, triton.utils.tf_empty_proxy)]
|
||||
if len(empty) != len(self.outputs):
|
||||
raise ValueError('Number of empty arguments does not much number of outputs provided')
|
||||
# operands
|
||||
operands = [x.shape if isinstance(x, triton.utils.tf_empty_proxy) else x for x in args]
|
||||
# output data types
|
||||
kwargs = {'id': op_id, 'bench': bench, 'bench_id': bench_id}
|
||||
for i, x in enumerate(args):
|
||||
if isinstance(x, triton.utils.tf_empty_proxy):
|
||||
kwargs['T' + str(i)] = x.dtype
|
||||
# launch
|
||||
ret = self.fw_op(*operands, **kwargs)
|
||||
ret = [ret] if isinstance(ret, fw.tensorflow.Tensor) else ret
|
||||
op_def = ret[0].op.op_def
|
||||
# fill empty tensors with corresponding values
|
||||
for j, y in enumerate(op_def.output_arg):
|
||||
found = False
|
||||
for i, x in enumerate(op_def.input_arg):
|
||||
if y.name + '_shape' == x.name:
|
||||
args[i].tensor = ret[j]
|
||||
found = True
|
||||
assert found
|
||||
# store timing information
|
||||
if bench > 0:
|
||||
for y in ret:
|
||||
bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
|
||||
|
||||
############################
|
||||
# call torch function
|
||||
############################
|
||||
elif fw.has_torch():
|
||||
args = [x if isinstance(x, fw.torch.Tensor) else x for x in args]
|
||||
ret = self.fw_op(op_id, bench, bench_id, *args)
|
||||
if fw.has_torch():
|
||||
self.fw_op(op_id, bench, bench_id, *args)
|
||||
if bench > 0:
|
||||
bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
|
||||
return libtriton.retrieve_scalar(bench_id)
|
||||
|
||||
else:
|
||||
assert False
|
@@ -41,7 +41,7 @@ void fwdbatchnorm(float *Y, float *M, float *V,
|
||||
}
|
||||
}
|
||||
"""
|
||||
fwd_kernel = triton.kernel(fwd_src, ['Y', 'M', 'V'])
|
||||
fwd_kernel = triton.kernel(fwd_src)
|
||||
|
||||
bwd_src = """
|
||||
void bwdbatchnorm(float *DX, float *DG, float *DB,
|
||||
@@ -88,7 +88,7 @@ void bwdbatchnorm(float *DX, float *DG, float *DB,
|
||||
}
|
||||
}
|
||||
"""
|
||||
bwd_kernel = triton.kernel(bwd_src, ['DX', 'DG', 'DB'])
|
||||
bwd_kernel = triton.kernel(bwd_src)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, gamma, beta, eps):
|
||||
|
@@ -313,7 +313,7 @@ __global__ void {name}(
|
||||
"""
|
||||
|
||||
#print(src)
|
||||
ret = triton.kernel(src, ['C'])
|
||||
ret = triton.kernel(src)
|
||||
if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT:
|
||||
ret.set_constant('AD', delta_a)
|
||||
if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT:
|
||||
@@ -563,7 +563,7 @@ __global__ void {name}(
|
||||
self.args[self.pos_a] = a
|
||||
self.args[self.pos_b] = b
|
||||
self.args[self.pos_c] = c
|
||||
self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
|
||||
return self.kernel(*self.args, grid=self.grid, bench=bench, defines=self.macros)
|
||||
|
||||
|
||||
|
||||
@@ -591,7 +591,7 @@ __global__ void {name}(
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape, arrays)
|
||||
instance = cache[key]
|
||||
instance.run(a, b, c, bench)
|
||||
speed = instance.run(a, b, c, bench)
|
||||
# save information in context
|
||||
ctx.flops = instance.flops
|
||||
ctx.sym_a = instance.sym_a
|
||||
@@ -602,6 +602,7 @@ __global__ void {name}(
|
||||
ctx.matmul_N = instance.matmul_N
|
||||
ctx.matmul_K = instance.matmul_K
|
||||
ctx.bench = bench
|
||||
ctx.forward_ms = speed
|
||||
ctx.save_for_backward(a, b)
|
||||
return c
|
||||
|
||||
|
Reference in New Issue
Block a user