diff --git a/python/examples/dot.py b/python/examples/dot.py
index da3cb9831..84ae9b6f3 100644
--- a/python/examples/dot.py
+++ b/python/examples/dot.py
@@ -1,121 +1,12 @@
+import numpy as np
 import tensorflow as tf
 import triton
-import numpy as np
-
-
-class dot(triton.function):
-
-  src = """
-void dot(TYPE * A, TYPE * B, TYPE * C,
-         int M, int N, int K,
-         int lda __multipleof(8),
-         int ldb __multipleof(8),
-         int ldc) {
-  // prologue
-  int ridx = get_program_id(0);
-  int ridy = get_program_id(1);
-  int rxa[TM] = ridx * TM + 0 ... TM;
-  int ryb[TN] = ridy * TN + 0 ... TN;
-  int rka[TK] = 0 ... TK;
-  int rkb[TK] = 0 ... TK;
-  float c[TM, TN] = 0;
-  // pointers to operands
-  TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
-  TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
-  // prefetches operands
-  TYPE a[SHAPE_A] = *pa;
-  TYPE b[SHAPE_B] = *pb;
-  // reduction loop
-  for(int k = K; k > 0; k-= TK){
-    c += USE_A @ USE_B;
-    pa = pa + TK * STRIDE_AK;
-    pb = pb + TK * STRIDE_BK;
-    a = *pa;
-    b = *pb;
-  }
-  // epilogue
-  int rxc[TM] = ridx * TM + 0 ... TM;
-  int ryc[TN] = ridy * TN + 0 ... TN;
-  TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
-  bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
-  *?(checkc) pc = c;
-}
-"""
-
-  op = triton.op(src, ['C'])
-
-  @staticmethod
-  def _call(a, b, transpose_a, transpose_b):
-    # extract shapes
-    shape_a = triton.shape(a)
-    shape_b = triton.shape(b)
-    M, Ka = shape_a[0], shape_a[1]
-    Kb, N = shape_b[0], shape_b[1]
-    # transpose shapes
-    if transpose_a:
-      M, Ka = Ka, M
-    if transpose_b:
-      Kb, N = N, Kb
-    # contiguous dimensions
-    lda = M if transpose_a else Ka
-    ldb = Kb if transpose_b else N
-    ldc = N
-    # data-type
-    dtype = a.dtype
-    # allocate output
-    c = triton.empty([M, N], dtype = dtype)
-    # compute
-    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
-    # macros -- not necessary but makes kernel source-code simpler
-    macros = {# handle A transposition
-              'USE_A'       : '^a'         if transpose_a else 'a',
-              'STRIDE_AK'   : 'lda'        if transpose_a else '1',
-              'STRIDE_AM'   : '1'          if transpose_a else 'lda',
-              'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
-              'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
-              'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
-              # handle B transposition
-              'USE_B'       : '^b'         if transpose_b else 'b',
-              'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
-              'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
-              'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
-              'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
-              'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
-    return dot.op(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
-                  AT = transpose_a, BT = transpose_b, TYPE = dtype,
-                  TM = [64, 128], TN = [64, 128], TK = [8], **macros)
-
-  @staticmethod
-  def forward(ctx, a, b, transpose_a = False, transpose_b = False):
-    ctx.save_for_backward(a, b, transpose_a, transpose_b)
-    return dot._call(a, b, transpose_a, transpose_b)
-
-  @staticmethod
-  def backward(ctx, dy):
-    a, b, t_a, t_b = ctx.saved_tensors
-    if not t_a and not t_b:
-      da = dot._call(dy, b, False, True)
-      db = dot._call(a, dy, True, False)
-    elif not t_a and t_b:
-      da = dot._call(dy, b, False, False)
-      db = dot._call(dy, a, True, False)
-    elif t_a and not t_b:
-      da = dot._call(b, dy, False, True)
-      db = dot._call(a, dy, False, False)
-    elif t_a and t_b:
-      da = dot._call(b, dy, True, True)
-      db = dot._call(dy, a, True, True)
-    else:
-      assert False
-    return [da, db, None, None, None, None, None, None, None]
-
-
 def run_dot():
   M, N, K = 128, 128, 128
   a = tf.placeholder(tf.float32, shape=[M, K])
   b = tf.placeholder(tf.float32, shape=[N, K])
-  _dot = dot.apply
+  _dot = triton.ops.dot.apply
   tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
   tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
   tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
diff --git a/python/setup.py b/python/setup.py
index 8a7c9b372..a70aa6c51 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -82,7 +82,8 @@ setup(
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
-    packages=['triton'],
+    packages=['triton',
+              'triton/ops'],
     ext_modules=[CMakeExtension('triton')],
     cmdclass=dict(build_ext=CMakeBuild),
     zip_safe=False,
diff --git a/python/triton/__init__.py b/python/triton/__init__.py
index 18dff0a49..aa05eefe1 100644
--- a/python/triton/__init__.py
+++ b/python/triton/__init__.py
@@ -1 +1,12 @@
-from .ops import *
\ No newline at end of file
+from .kernel import *
+from .function import *
+from .utils import *
+import triton.ops
+
+
+# clean-up libtriton resources
+import atexit
+import libtriton
+@atexit.register
+def cleanup():
+  libtriton.cleanup()
\ No newline at end of file
diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py
new file mode 100644
index 000000000..60c0728f1
--- /dev/null
+++ b/python/triton/frameworks.py
@@ -0,0 +1,46 @@
+import sys
+import os
+import libtriton
+
+torch_id = 'torch'
+tensorflow_id = 'tensorflow'
+
+torch = None
+tensorflow = None
+tf_extra_ops = None
+
+
+def _import_torch():
+  global torch
+  if torch is None:
+    import torch
+
+def _import_tensorflow():
+  global tensorflow
+  if tensorflow is None:
+    import tensorflow
+
+def _import_tf_extra_ops():
+  global tf_extra_ops
+  if tf_extra_ops is None:
+    path = os.path.dirname(libtriton.__file__)
+    path = os.path.join(path, 'libextra_tf_ops.so')
+    _import_tensorflow()
+    tf_extra_ops = tensorflow.load_op_library(path)
+
+
+def _find_framework(default = None):
+  is_tf_imported = 'tensorflow' in sys.modules
+  is_torch_imported = 'torch' in sys.modules
+  if default:
+    if default not in [tensorflow_id, torch_id]:
+      raise ValueError('unsupported framework')
+    else:
+      return default
+  elif is_tf_imported and not is_torch_imported:
+    return tensorflow_id
+  elif is_torch_imported and not is_tf_imported:
+    return torch_id
+  else:
+    raise ValueError('cannot determine imported framework, '
+                     'please provide framework argument')
\ No newline at end of file
diff --git a/python/triton/function.py b/python/triton/function.py
new file mode 100644
index 000000000..8669dbc92
--- /dev/null
+++ b/python/triton/function.py
@@ -0,0 +1,54 @@
+import triton.frameworks as fw
+
+class OpContext(object):
+
+  def save_for_backward(self, *tensors):
+    self.to_save = tensors
+
+  def mark_dirty(self, *args):
+    self.dirty_tensors = args
+
+  @property
+  def saved_tensors(self):
+    return self.to_save
+
+
+class function_meta(type):
+
+  def __init__(cls, name, bases, attrs):
+    cls.contexts = dict()
+    cls.registered = False
+    return super(function_meta, cls).__init__(name, bases, attrs)
+
+class function(metaclass = function_meta):
+
+  def __init__(self, framework = None):
+    self.framework = _find_framework(framework)
+    pass
+
+  @staticmethod
+  def forward(ctx, *args, **kwargs):
+    raise NotImplementedError
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    raise NotImplementedError
+
+  @classmethod
+  def apply(cls, *args, **kwargs):
+    # call forward
+    ctx = OpContext()
+    result = cls.forward(ctx, *args, **kwargs)
+    id = result.op.get_attr('id')
+    cls.contexts[id] = ctx
+    # register backward
+    fw._import_tensorflow()
+    name = result.op.op_def.name
+    if not cls.registered:
+      @fw.tensorflow.RegisterGradient(name)
+      def gradient(op, dy):
+        id = op.get_attr('id')
+        return cls.backward(cls.contexts[id], dy)
+      cls.registered = True
+    # return result tensor
+    return result
\ No newline at end of file
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
new file mode 100644
index 000000000..b3d2be50a
--- /dev/null
+++ b/python/triton/kernel.py
@@ -0,0 +1,215 @@
+# import for cache
+import os
+import tempfile
+import shutil
+import hashlib
+import sysconfig
+import sys
+# import for just-in-time compilation
+import distutils
+import setuptools.command.build_ext
+import setuptools
+# triton
+import triton.frameworks as fw
+import triton.utils
+import libtriton
+
+def _make_framework_src(src, out, grid, framework):
+  if framework == fw.tensorflow_id:
+    return libtriton.make_tensorflow_src(src, out, grid)
+  elif framework == fw.torch_id:
+    return libtriton.make_torch_src(src, out, grid)
+  else:
+    assert False
+
+def _make_cache_path(src):
+  md5 = hashlib.sha1(src.encode())
+  hexhash = md5.hexdigest()
+  home = os.path.expanduser('~')
+  cacheroot = os.path.join(home, '.triton', 'cache')
+  cachepath = os.path.join(cacheroot, str(hexhash))
+  if not os.path.exists(cachepath):
+    os.makedirs(cachepath)
+  return cachepath
+
+def _write_bindings(src, root, framework):
+  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
+  suffix = sysconfig.get_config_var('EXT_SUFFIX')
+  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
+  recompile = False
+  # recompile if .so does not exist
+  if not os.path.exists(cpp) or not os.path.exists(so):
+    recompile = True
+  # recompile if cpp was modified after .so
+  elif max(cpp, so, key=os.path.getctime) == cpp:
+    recompile = True
+  # write cpp file
+  if recompile:
+    with open(cpp, 'w+') as handle:
+      handle.writelines(src)
+  # return path of cpp file
+  return (cpp, so)
+
+def _build(src, path, framework):
+  # include directories
+  triton_include_dirs = ['/home/philippe/development/triton/include']
+  include_dirs = triton_include_dirs
+  # library directories
+  triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
+  library_dirs = triton_library_dirs
+  # libraries
+  libraries = ['triton']
+  # add framework
+  extra_compile_args = []
+  if framework == fw.tensorflow_id:
+    library_dirs += [fw.tensorflow.sysconfig.get_lib()]
+    include_dirs += [fw.tensorflow.sysconfig.get_include()]
+    include_dirs += ['/usr/local/cuda/include/']
+    libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
+    ABI = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
+    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
+  elif framework == fw.torch_id:
+    prefix = os.path.dirname(torch.__file__)
+    library_dirs += [os.path.join(prefix, 'lib')]
+    include_dirs += [os.path.join(prefix, 'lib', 'include'),
+                     os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
+                     os.path.join(prefix, 'include'),
+                     os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
+    libraries += ['torch']
+  else:
+    assert False
+  # extra arguments
+  extra_link_args = []
+  # dependences
+  depends = [os.path.realpath(libtriton.__file__)]
+  # create extension module
+  ext = setuptools.Extension(
+      name = 'tensorflow',
+      language = 'c++',
+      sources = [src],
+      include_dirs = include_dirs,
+      extra_compile_args = extra_compile_args,
+      extra_link_args = extra_link_args,
+      library_dirs = library_dirs,
+      libraries = libraries,
+      depends = depends
+  )
+  # build extension module
+  args = ['build_ext']
+  tmp = tempfile.mkdtemp()
+  args.append('--build-temp=' + tmp)
+  args.append('--build-lib=' + path)
+  args.append('-q')
+  args = dict(
+      name = 'tensorflow',
+      ext_modules = [ext],
+      script_args = args,
+  )
+  setuptools.setup(**args)
+  shutil.rmtree(tmp)
+
+def _cvt_to_def_str(obj, framework):
+  # bool
+  if isinstance(obj, bool):
+    return str(int(obj))
+  # tensorflow type
+  if framework == fw.tensorflow_id:
+    if isinstance(obj, fw.tensorflow.DType):
+      return {fw.tensorflow.int8: 'char',
+              fw.tensorflow.int16: 'short',
+              fw.tensorflow.int32: 'int',
+              fw.tensorflow.int64: 'long',
+              fw.tensorflow.float16: 'half',
+              fw.tensorflow.float32: 'float',
+              fw.tensorflow.float64: 'double'}[obj]
+  # torch type
+  elif framework == fw.torch_id:
+    if isinstance(obj, torch.dtype):
+      return {torch.int8: 'char',
+              torch.int16: 'short',
+              torch.int32: 'int',
+              torch.int64: 'long',
+              torch.float16: 'half',
+              torch.float32: 'float',
+              torch.float64: 'double'}[obj]
+  else:
+    assert False
+  # default
+  return str(obj)
+
+
+def _make_framework_op(src, outputs, options, framework):
+  src, name = _make_framework_src(src, outputs, options, framework)
+  cache_path = _make_cache_path(src)
+  cpp, so = _write_bindings(src, cache_path, framework)
+  _build(cpp, cache_path, framework)
+  if framework == fw.tensorflow_id:
+    return fw.tensorflow.load_op_library(so).__dict__[name]
+  elif framework == fw.torch_id:
+    torch.ops.load_library(so)
+    return torch.ops.triton.__dict__[name]
+  else:
+    assert False
+
+def _make_grid(args) :
+  scalars = [x for x in args[:-1] if isinstance(x, triton.utils.scalar)]
+  def grid(opt):
+    for x in scalars:
+      x.set_assume_initialized()
+    result = args[-1](opt)
+    for x in scalars:
+      x.unset_assume_initialized()
+    return result
+  return grid
+
+
+class kernel:
+
+  def __init__(self, src, outputs, framework = None):
+    self.fw_id = dict()
+    self.fw_grids = dict()
+    self.fw_op = None
+    self.src = src
+    self.outputs = outputs
+    self.framework = fw._find_framework(framework)
+    if self.framework == fw.tensorflow_id:
+      fw._import_tensorflow()
+      fw._import_tf_extra_ops()
+    elif self.framework == fw.torch_id:
+      fw._import_torch()
+    else:
+      assert False
+
+
+  def __call__(self, *args, **kwargs):
+    # create a new framework op when defines are different
+    key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
+    if key not in self.fw_id.keys():
+      # code generation options
+      defines = []
+      for k, v in kwargs.items():
+        cvt = lambda x: _cvt_to_def_str(x, self.framework)
+        if(isinstance(v, list)):
+          values = list(map(cvt, v))
+        else:
+          values = [cvt(v)]
+        defines.append((k, values))
+      opt = libtriton.options_space()
+      opt.defines = defines
+      opt.num_warps = [4]
+      # create unique id for this op
+      op_id = libtriton.make_op_id()
+      self.fw_id[key] = op_id
+      # register function
+      libtriton.register_fn(op_id, self.src, opt)
+      if self.fw_op is None:
+        self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
+
+    # retrieve framework op
+    op_id = self.fw_id[key]
+    # register grid
+    libtriton.register_grid(op_id, _make_grid(args))
+    # create operands
+    op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
+    # call framework function
+    return self.fw_op(*op_args, id=op_id)
\ No newline at end of file
diff --git a/python/triton/ops.py b/python/triton/ops.py
deleted file mode 100644
index 633ceafa7..000000000
--- a/python/triton/ops.py
+++ /dev/null
@@ -1,415 +0,0 @@
-# import for cache
-import os
-import tempfile
-import shutil
-import hashlib
-import sysconfig
-import sys
-# import for just-in-time compilation
-import distutils
-import setuptools.command.build_ext
-import setuptools
-# triton
-import libtriton
-
-
-# clean-up libtriton resources
-import atexit
-@atexit.register
-def cleanup():
-  libtriton.cleanup()
-
-
-torch_id = 'torch'
-tensorflow_id = 'tensorflow'
-
-torch = None
-tensorflow = None
-_gradient_registry = None
-tf_extra_ops = None
-
-
-
-
-def _import_torch():
-  global torch
-  if torch is None:
-    import torch
-
-def _import_tensorflow():
-  global tensorflow
-  if tensorflow is None:
-    import tensorflow
-  global _gradient_registry
-  if _gradient_registry is None:
-    from tensorflow.python.framework.ops import _gradient_registry
-
-def _import_tf_extra_ops():
-  global tf_extra_ops
-  if tf_extra_ops is None:
-    path = os.path.dirname(libtriton.__file__)
-    path = os.path.join(path, 'libextra_tf_ops.so')
-    _import_tensorflow()
-    tf_extra_ops = tensorflow.load_op_library(path)
-
-
-def _find_framework(default = None):
-  is_tf_imported = 'tensorflow' in sys.modules
-  is_torch_imported = 'torch' in sys.modules
-  if default:
-    if default not in [tensorflow_id, torch_id]:
-      raise ValueError('unsupported framework')
-    else:
-      return default
-  elif is_tf_imported and not is_torch_imported:
-    return tensorflow_id
-  elif is_torch_imported and not is_tf_imported:
-    return torch_id
-  else:
-    raise ValueError('cannot determine imported framework, '
-                     'please provide framework argument')
-
-
-def _make_framework_src(src, out, grid, framework):
-  if framework == tensorflow_id:
-    return libtriton.make_tensorflow_src(src, out, grid)
-  elif framework == torch_id:
-    return libtriton.make_torch_src(src, out, grid)
-  else:
-    assert False
-
-def _make_cache_path(src):
-  md5 = hashlib.sha1(src.encode())
-  hexhash = md5.hexdigest()
-  home = os.path.expanduser('~')
-  cacheroot = os.path.join(home, '.triton', 'cache')
-  cachepath = os.path.join(cacheroot, str(hexhash))
-  if not os.path.exists(cachepath):
-    os.makedirs(cachepath)
-  return cachepath
-
-def _write_bindings(src, root, framework):
-  cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
-  suffix = sysconfig.get_config_var('EXT_SUFFIX')
-  so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
-  recompile = False
-  # recompile if .so does not exist
-  if not os.path.exists(cpp) or not os.path.exists(so):
-    recompile = True
-  # recompile if cpp was modified after .so
-  elif max(cpp, so, key=os.path.getctime) == cpp:
-    recompile = True
-  # write cpp file
-  if recompile:
-    with open(cpp, 'w+') as handle:
-      handle.writelines(src)
-  # return path of cpp file
-  return (cpp, so)
-
-def _build(src, path, framework):
-  # include directories
-  triton_include_dirs = ['/home/philippe/development/triton/include']
-  include_dirs = triton_include_dirs
-  # library directories
-  triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
-  library_dirs = triton_library_dirs
-  # libraries
-  libraries = ['triton']
-  # add framework
-  extra_compile_args = []
-  if framework == tensorflow_id:
-    _import_tensorflow()
-    library_dirs += [tensorflow.sysconfig.get_lib()]
-    include_dirs += [tensorflow.sysconfig.get_include()]
-    include_dirs += ['/usr/local/cuda/include/']
-    libraries += [tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
-    ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
-    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
-  elif framework == torch_id:
-    _import_torch()
-    prefix = os.path.dirname(torch.__file__)
-    library_dirs += [os.path.join(prefix, 'lib')]
-    include_dirs += [os.path.join(prefix, 'lib', 'include'),
-                     os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
-                     os.path.join(prefix, 'include'),
-                     os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
-    libraries += ['torch']
-  else:
-    assert False
-  # extra arguments
-  extra_link_args = []
-  # dependences
-  depends = [os.path.realpath(libtriton.__file__)]
-  # create extension module
-  ext = setuptools.Extension(
-      name = 'tensorflow',
-      language = 'c++',
-      sources = [src],
-      include_dirs = include_dirs,
-      extra_compile_args = extra_compile_args,
-      extra_link_args = extra_link_args,
-      library_dirs = library_dirs,
-      libraries = libraries,
-      depends = depends
-  )
-  # build extension module
-  args = ['build_ext']
-  tmp = tempfile.mkdtemp()
-  args.append('--build-temp=' + tmp)
-  args.append('--build-lib=' + path)
-  args.append('-q')
-  args = dict(
-      name = 'tensorflow',
-      ext_modules = [ext],
-      script_args = args,
-  )
-  setuptools.setup(**args)
-  shutil.rmtree(tmp)
-
-def _cvt_to_def_str(obj, framework):
-  # bool
-  if isinstance(obj, bool):
-    return str(int(obj))
-  # tensorflow type
-  if framework == tensorflow_id:
-    _import_tensorflow()
-    if isinstance(obj, tensorflow.DType):
-      return {tensorflow.int8: 'char',
-              tensorflow.int16: 'short',
-              tensorflow.int32: 'int',
-              tensorflow.int64: 'long',
-              tensorflow.float16: 'half',
-              tensorflow.float32: 'float',
-              tensorflow.float64: 'double'}[obj]
-  # torch type
-  elif framework == torch_id:
-    _import_torch()
-    if isinstance(obj, torch.dtype):
-      return {torch.int8: 'char',
-              torch.int16: 'short',
-              torch.int32: 'int',
-              torch.int64: 'long',
-              torch.float16: 'half',
-              torch.float32: 'float',
-              torch.float64: 'double'}[obj]
-  else:
-    assert False
-  # default
-  return str(obj)
-
-
-def _make_framework_op(src, outputs, options, framework):
-  src, name = _make_framework_src(src, outputs, options, framework)
-  cache_path = _make_cache_path(src)
-  cpp, so = _write_bindings(src, cache_path, framework)
-  _build(cpp, cache_path, framework)
-  if framework == tensorflow_id:
-    _import_tensorflow()
-    return tensorflow.load_op_library(so).__dict__[name]
-  elif framework == torch_id:
-    _import_torch()
-    torch.ops.load_library(so)
-    return torch.ops.triton.__dict__[name]
-  else:
-    assert False
-
-def _make_grid(args) :
-  scalars = [x for x in args[:-1] if isinstance(x, scalar)]
-  def grid(opt):
-    for x in scalars:
-      x.set_assume_initialized()
-    result = args[-1](opt)
-    for x in scalars:
-      x.unset_assume_initialized()
-    return result
-  return grid
-
-
-
-class OpContext(object):
-
-  def save_for_backward(self, *tensors):
-    self.to_save = tensors
-
-  def mark_dirty(self, *args):
-    self.dirty_tensors = args
-
-  @property
-  def saved_tensors(self):
-    return self.to_save
-
-
-class function_meta(type):
-
-  def __init__(cls, name, bases, attrs):
-    cls.contexts = dict()
-    cls.registered = False
-    return super(function_meta, cls).__init__(name, bases, attrs)
-
-class function(metaclass = function_meta):
-
-  def __init__(self, framework = None):
-    self.framework = _find_framework(framework)
-    pass
-
-  @staticmethod
-  def forward(ctx, *args, **kwargs):
-    raise NotImplementedError
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    raise NotImplementedError
-
-  @classmethod
-  def apply(cls, *args, **kwargs):
-    # call forward
-    ctx = OpContext()
-    result = cls.forward(ctx, *args, **kwargs)
-    id = result.op.get_attr('id')
-    cls.contexts[id] = ctx
-    # register backward
-    _import_tensorflow()
-    from tensorflow.python.framework.ops import _gradient_registry
-    name = result.op.op_def.name
-    if not cls.registered:
-      @tensorflow.RegisterGradient(name)
-      def gradient(op, dy):
-        id = op.get_attr('id')
-        return cls.backward(cls.contexts[id], dy)
-      cls.registered = True
-    # return result tensor
-    return result
-
-
-
-class op:
-
-  def __init__(self, src, outputs, framework = None):
-    self.fw_id = dict()
-    self.fw_grids = dict()
-    self.fw_op = None
-    self.src = src
-    self.outputs = outputs
-    self.framework = _find_framework(framework)
-
-
-  def __call__(self, *args, **kwargs):
-    # create a new op when defines are different
-    key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
-    if key not in self.fw_id.keys():
-      # code generation options
-      defines = []
-      for k, v in kwargs.items():
-        cvt = lambda x: _cvt_to_def_str(x, self.framework)
-        if(isinstance(v, list)):
-          values = list(map(cvt, v))
-        else:
-          values = [cvt(v)]
-        defines.append((k, values))
-      opt = libtriton.options_space()
-      opt.defines = defines
-      opt.num_warps = [4]
-      # create unique id for this op
-      op_id = libtriton.make_op_id()
-      self.fw_id[key] = op_id
-      # register function
-      libtriton.register_fn(op_id, self.src, opt)
-      if self.fw_op is None:
-        self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
-
-    # retrieve framework op
-    op_id = self.fw_id[key]
-    # register grid
-    libtriton.register_grid(op_id, _make_grid(args))
-    # create operands
-    op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
-    # call framework op
-    return self.fw_op(*op_args, id=op_id)
-
-
-def empty(shapes, dtype, framework = None):
-  framework = _find_framework(framework)
-  if framework == tensorflow_id:
-    _import_tensorflow()
-    _import_tf_extra_ops
-    args = [x.handle if isinstance(x, scalar) else x for x in shapes]
-    args = tensorflow.stack(args)
-    return tf_extra_ops.alloc_empty(args, T = dtype)
-  elif framework == torch_id:
-    _import_torch()
-    return torch.empty(*shapes)
-
-def cdiv(a, b):
-  return -(-a // b)
-
-class scalar:
-
-  def __init__(self, x):
-    _import_tf_extra_ops()
-    self.id = libtriton.make_scalar_id()
-    self.handle = tf_extra_ops.register_scalar(x, id=self.id)
-    self.assume_initialized = False
-
-  def set_assume_initialized(self):
-    self.assume_initialized = True
-
-  def unset_assume_initialized(self):
-    self.assume_initialized = False
-
-  def get_value(self):
-    if self.assume_initialized:
-      return libtriton.retrieve_scalar(self.id)
-    else:
-      return self.handle
-
-  def __add__(self, other):
-    return self.get_value() + other
-
-  def __radd__(self, other):
-    return other + self.get_value()
-
-  def __sub__(self, other):
-    return self.get_value() - other
-
-  def __rsub(self, other):
-    return other - self.get_value()
-
-  def __mul__(self, other):
-    return self.get_value() * other
-
-  def __rmul(self, other):
-    return other * self.get_value()
-
-  def __floordiv__(self, other):
-    return self.get_value() // other
-
-  def __rfloordiv__(self, other):
-    return other // self.get_value()
-
-  def __div__(self, other):
-    return self.get_value() / other
-
-  def __rdiv__(self, other):
-    return other / self.get_value()
-
-  def __truediv__(self, other):
-    self.get_value().__truediv__(other)
-
-  def __rtruediv__(self, other):
-    other.__truediv__(self.get_value())
-
-  def __neg__(self):
-    return -self.get_value()
-
-class lazy_shape:
-
-  def __init__(self, shape):
-    self.shape = shape
-
-  def __getitem__(self, key):
-    return scalar(self.shape[key])
-
-def shape(A) :
-  _import_tensorflow()
-  return lazy_shape(tensorflow.shape(A))
-
diff --git a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py
new file mode 100644
index 000000000..f995b88f1
--- /dev/null
+++ b/python/triton/ops/__init__.py
@@ -0,0 +1 @@
+from .dot import dot
diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py
new file mode 100644
index 000000000..f799be983
--- /dev/null
+++ b/python/triton/ops/dot.py
@@ -0,0 +1,107 @@
+import triton
+
+class dot(triton.function):
+
+  src = """
+void dot(TYPE * A, TYPE * B, TYPE * C,
+         int M, int N, int K,
+         int lda __multipleof(8),
+         int ldb __multipleof(8),
+         int ldc) {
+  // prologue
+  int ridx = get_program_id(0);
+  int ridy = get_program_id(1);
+  int rxa[TM] = ridx * TM + 0 ... TM;
+  int ryb[TN] = ridy * TN + 0 ... TN;
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float c[TM, TN] = 0;
+  // pointers to operands
+  TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
+  TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
+  // prefetches operands
+  TYPE a[SHAPE_A] = *pa;
+  TYPE b[SHAPE_B] = *pb;
+  // reduction loop
+  for(int k = K; k > 0; k-= TK){
+    c += USE_A @ USE_B;
+    pa = pa + TK * STRIDE_AK;
+    pb = pb + TK * STRIDE_BK;
+    a = *pa;
+    b = *pb;
+  }
+  // epilogue
+  int rxc[TM] = ridx * TM + 0 ... TM;
+  int ryc[TN] = ridy * TN + 0 ... TN;
+  TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
+  bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
+  *?(checkc) pc = c;
+}
+"""
+
+  kernel = triton.kernel(src, ['C'])
+
+  @staticmethod
+  def _call(a, b, transpose_a, transpose_b):
+    # extract shapes
+    shape_a = triton.shape(a)
+    shape_b = triton.shape(b)
+    M, Ka = shape_a[0], shape_a[1]
+    Kb, N = shape_b[0], shape_b[1]
+    # transpose shapes
+    if transpose_a:
+      M, Ka = Ka, M
+    if transpose_b:
+      Kb, N = N, Kb
+    # contiguous dimensions
+    lda = M if transpose_a else Ka
+    ldb = Kb if transpose_b else N
+    ldc = N
+    # data-type
+    dtype = a.dtype
+    # allocate output
+    c = triton.empty([M, N], dtype = dtype)
+    # compute
+    grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
+    # macros -- not necessary but makes kernel source-code simpler
+    macros = {# handle A transposition
+              'USE_A'       : '^a'         if transpose_a else 'a',
+              'STRIDE_AK'   : 'lda'        if transpose_a else '1',
+              'STRIDE_AM'   : '1'          if transpose_a else 'lda',
+              'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
+              'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
+              'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
+              # handle B transposition
+              'USE_B'       : '^b'         if transpose_b else 'b',
+              'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
+              'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
+              'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
+              'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
+              'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
+    return dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
+                      AT = transpose_a, BT = transpose_b, TYPE = dtype,
+                      TM = [64, 128], TN = [64, 128], TK = [8], **macros)
+
+  @staticmethod
+  def forward(ctx, a, b, transpose_a = False, transpose_b = False):
+    ctx.save_for_backward(a, b, transpose_a, transpose_b)
+    return dot._call(a, b, transpose_a, transpose_b)
+
+  @staticmethod
+  def backward(ctx, dy):
+    a, b, t_a, t_b = ctx.saved_tensors
+    if not t_a and not t_b:
+      da = dot._call(dy, b, False, True)
+      db = dot._call(a, dy, True, False)
+    elif not t_a and t_b:
+      da = dot._call(dy, b, False, False)
+      db = dot._call(dy, a, True, False)
+    elif t_a and not t_b:
+      da = dot._call(b, dy, False, True)
+      db = dot._call(a, dy, False, False)
+    elif t_a and t_b:
+      da = dot._call(b, dy, True, True)
+      db = dot._call(dy, a, True, True)
+    else:
+      assert False
+    return [da, db, None, None, None, None, None, None, None]
\ No newline at end of file
diff --git a/python/triton/tools/build.py b/python/triton/tools/build.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/triton/tools/checksum.py b/python/triton/tools/checksum.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/triton/utils.py b/python/triton/utils.py
new file mode 100644
index 000000000..98380bf37
--- /dev/null
+++ b/python/triton/utils.py
@@ -0,0 +1,88 @@
+import triton.frameworks as fw
+import libtriton
+
+def cdiv(a, b):
+  return -(-a // b)
+
+def empty(shapes, dtype, framework = None):
+  framework = fw._find_framework(framework)
+  if framework == fw.tensorflow_id:
+    args = [x.handle if isinstance(x, scalar) else x for x in shapes]
+    args = fw.tensorflow.stack(args)
+    return fw.tf_extra_ops.alloc_empty(args, T = dtype)
+  elif framework == fw.torch_id:
+    _import_torch()
+    return fw.torch.empty(*shapes)
+
+class lazy_shape:
+
+  def __init__(self, shape):
+    self.shape = shape
+
+  def __getitem__(self, key):
+    return scalar(self.shape[key])
+
+def shape(A) :
+  fw._import_tensorflow()
+  return lazy_shape(fw.tensorflow.shape(A))
+
+
+class scalar:
+
+  def __init__(self, x):
+    self.id = libtriton.make_scalar_id()
+    self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
+    self.assume_initialized = False
+
+  def set_assume_initialized(self):
+    self.assume_initialized = True
+
+  def unset_assume_initialized(self):
+    self.assume_initialized = False
+
+  def get_value(self):
+    if self.assume_initialized:
+      return libtriton.retrieve_scalar(self.id)
+    else:
+      return self.handle
+
+  def __add__(self, other):
+    return self.get_value() + other
+
+  def __radd__(self, other):
+    return other + self.get_value()
+
+  def __sub__(self, other):
+    return self.get_value() - other
+
+  def __rsub(self, other):
+    return other - self.get_value()
+
+  def __mul__(self, other):
+    return self.get_value() * other
+
+  def __rmul(self, other):
+    return other * self.get_value()
+
+  def __floordiv__(self, other):
+    return self.get_value() // other
+
+  def __rfloordiv__(self, other):
+    return other // self.get_value()
+
+  def __div__(self, other):
+    return self.get_value() / other
+
+  def __rdiv__(self, other):
+    return other / self.get_value()
+
+  def __truediv__(self, other):
+    self.get_value().__truediv__(other)
+
+  def __rtruediv__(self, other):
+    other.__truediv__(self.get_value())
+
+  def __neg__(self):
+    return -self.get_value()
+