[python] modularized triton package

Philippe Tillet
2019-09-04 21:55:47 -04:00
parent f6e9c24fe8
commit 945b5d0de9
12 changed files with 527 additions and 528 deletions


@@ -1,121 +1,12 @@
-import numpy as np
 import tensorflow as tf
 import triton
+import numpy as np
-
-class dot(triton.function):
-
-    src = """
-    void dot(TYPE * A, TYPE * B, TYPE * C,
-             int M, int N, int K,
-             int lda __multipleof(8),
-             int ldb __multipleof(8),
-             int ldc) {
-      // prologue
-      int ridx = get_program_id(0);
-      int ridy = get_program_id(1);
-      int rxa[TM] = ridx * TM + 0 ... TM;
-      int ryb[TN] = ridy * TN + 0 ... TN;
-      int rka[TK] = 0 ... TK;
-      int rkb[TK] = 0 ... TK;
-      float c[TM, TN] = 0;
-      // pointers to operands
-      TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
-      TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
-      // prefetches operands
-      TYPE a[SHAPE_A] = *pa;
-      TYPE b[SHAPE_B] = *pb;
-      // reduction loop
-      for(int k = K; k > 0; k -= TK){
-        c += USE_A @ USE_B;
-        pa = pa + TK * STRIDE_AK;
-        pb = pb + TK * STRIDE_BK;
-        a = *pa;
-        b = *pb;
-      }
-      // epilogue
-      int rxc[TM] = ridx * TM + 0 ... TM;
-      int ryc[TN] = ridy * TN + 0 ... TN;
-      TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
-      bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
-      *?(checkc) pc = c;
-    }
-    """
-    op = triton.op(src, ['C'])
-
-    @staticmethod
-    def _call(a, b, transpose_a, transpose_b):
-        # extract shapes
-        shape_a = triton.shape(a)
-        shape_b = triton.shape(b)
-        M, Ka = shape_a[0], shape_a[1]
-        Kb, N = shape_b[0], shape_b[1]
-        # transpose shapes
-        if transpose_a:
-            M, Ka = Ka, M
-        if transpose_b:
-            Kb, N = N, Kb
-        # contiguous dimensions
-        lda = M if transpose_a else Ka
-        ldb = Kb if transpose_b else N
-        ldc = N
-        # data-type
-        dtype = a.dtype
-        # allocate output
-        c = triton.empty([M, N], dtype = dtype)
-        # compute
-        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
-        # macros -- not necessary but makes kernel source-code simpler
-        macros = {# handle A transposition
-                  'USE_A'       : '^a'         if transpose_a else 'a',
-                  'STRIDE_AK'   : 'lda'        if transpose_a else '1',
-                  'STRIDE_AM'   : '1'          if transpose_a else 'lda',
-                  'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
-                  'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
-                  'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
-                  # handle B transposition
-                  'USE_B'       : '^b'         if transpose_b else 'b',
-                  'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
-                  'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
-                  'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
-                  'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
-                  'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
-        return dot.op(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
-                      AT = transpose_a, BT = transpose_b, TYPE = dtype,
-                      TM = [64, 128], TN = [64, 128], TK = [8], **macros)
-
-    @staticmethod
-    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
-        ctx.save_for_backward(a, b, transpose_a, transpose_b)
-        return dot._call(a, b, transpose_a, transpose_b)
-
-    @staticmethod
-    def backward(ctx, dy):
-        a, b, t_a, t_b = ctx.saved_tensors
-        if not t_a and not t_b:
-            da = dot._call(dy, b, False, True)
-            db = dot._call(a, dy, True, False)
-        elif not t_a and t_b:
-            da = dot._call(dy, b, False, False)
-            db = dot._call(dy, a, True, False)
-        elif t_a and not t_b:
-            da = dot._call(b, dy, False, True)
-            db = dot._call(a, dy, False, False)
-        elif t_a and t_b:
-            da = dot._call(b, dy, True, True)
-            db = dot._call(dy, a, True, True)
-        else:
-            assert False
-        return [da, db, None, None, None, None, None, None, None]
-
 def run_dot():
     M, N, K = 128, 128, 128
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    _dot = dot.apply
+    _dot = triton.ops.dot.apply
     tr_c = _dot(a, b, transpose_a = False, transpose_b = True)
     tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False)
     tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
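
The hunk is truncated by the viewer at this point. As a rough sketch (not part of the commit), checking the Triton result against tf.matmul with a TF1 session might look like:

# Sketch only -- assumes TensorFlow 1.x and an installed triton package.
import numpy as np
import tensorflow as tf
import triton

M, N, K = 128, 128, 128
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
tr_c = triton.ops.dot.apply(a, b, transpose_a = False, transpose_b = True)
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
with tf.Session() as sess:
    ha = np.random.rand(M, K).astype(np.float32)
    hb = np.random.rand(N, K).astype(np.float32)
    out_tr, out_tf = sess.run([tr_c, tf_c], feed_dict={a: ha, b: hb})
    print(np.max(np.abs(out_tr - out_tf)))  # expect ~1e-4 for float32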

python/setup.py

@@ -82,7 +82,8 @@ setup(
   author_email='ptillet@g.harvard.edu',
   description='A language and compiler for custom Deep Learning operations',
   long_description='',
-  packages=['triton'],
+  packages=['triton',
+            'triton/ops'],
   ext_modules=[CMakeExtension('triton')],
   cmdclass=dict(build_ext=CMakeBuild),
   zip_safe=False,

python/triton/__init__.py

@@ -1 +1,12 @@
-from .ops import *
+from .kernel import *
+from .function import *
+from .utils import *
+import triton.ops
+
+# clean-up libtriton resources
+import atexit
+import libtriton
+
+@atexit.register
+def cleanup():
+    libtriton.cleanup()
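
For orientation (editor's sketch, not part of the diff), the public surface after this commit splits across the new modules roughly as follows. A framework must already be imported, since kernel construction calls _find_framework:

# Sketch of the post-commit module layout (illustrative).
import tensorflow as tf   # pick the framework before importing triton
import triton

triton.kernel      # kernel.py   : JIT compilation and framework-op dispatch
triton.function    # function.py : forward/backward wrapper (autograd-style)
triton.empty       # utils.py    : helpers (also triton.cdiv, triton.shape)
triton.ops.dot     # ops/dot.py  : built-in matrix-multiplication op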

46
python/triton/frameworks.py Normal file

@@ -0,0 +1,46 @@
import sys
import os
import libtriton

torch_id = 'torch'
tensorflow_id = 'tensorflow'

torch = None
tensorflow = None
tf_extra_ops = None

def _import_torch():
    global torch
    if torch is None:
        import torch

def _import_tensorflow():
    global tensorflow
    if tensorflow is None:
        import tensorflow

def _import_tf_extra_ops():
    global tf_extra_ops
    if tf_extra_ops is None:
        path = os.path.dirname(libtriton.__file__)
        path = os.path.join(path, 'libextra_tf_ops.so')
        _import_tensorflow()
        tf_extra_ops = tensorflow.load_op_library(path)

def _find_framework(default = None):
    is_tf_imported = 'tensorflow' in sys.modules
    is_torch_imported = 'torch' in sys.modules
    if default:
        if default not in [tensorflow_id, torch_id]:
            raise ValueError('unsupported framework')
        else:
            return default
    elif is_tf_imported and not is_torch_imported:
        return tensorflow_id
    elif is_torch_imported and not is_tf_imported:
        return torch_id
    else:
        raise ValueError('cannot determine imported framework, '
                         'please provide framework argument')
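
A quick sketch (not in the commit) of how the resolver behaves, assuming only tensorflow has been imported:

# Sketch: _find_framework infers the backend from sys.modules unless overridden.
import tensorflow
import triton.frameworks as fw

assert fw._find_framework() == fw.tensorflow_id          # inferred
assert fw._find_framework(fw.torch_id) == fw.torch_id    # explicit choice wins
# with both tensorflow and torch imported, the no-argument call raises
# ValueError('cannot determine imported framework, ...')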

54
python/triton/function.py Normal file

@@ -0,0 +1,54 @@
import triton.frameworks as fw

class OpContext(object):

    def save_for_backward(self, *tensors):
        self.to_save = tensors

    def mark_dirty(self, *args):
        self.dirty_tensors = args

    @property
    def saved_tensors(self):
        return self.to_save

class function_meta(type):

    def __init__(cls, name, bases, attrs):
        cls.contexts = dict()
        cls.registered = False
        return super(function_meta, cls).__init__(name, bases, attrs)

class function(metaclass = function_meta):

    def __init__(self, framework = None):
        self.framework = fw._find_framework(framework)

    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

    @classmethod
    def apply(cls, *args, **kwargs):
        # call forward
        ctx = OpContext()
        result = cls.forward(ctx, *args, **kwargs)
        id = result.op.get_attr('id')
        cls.contexts[id] = ctx
        # register backward
        fw._import_tensorflow()
        name = result.op.op_def.name
        if not cls.registered:
            @fw.tensorflow.RegisterGradient(name)
            def gradient(op, dy):
                id = op.get_attr('id')
                return cls.backward(cls.contexts[id], dy)
            cls.registered = True
        # return result tensor
        return result
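
OpContext mirrors the contract of torch.autograd.Function: whatever forward stores is handed back to backward verbatim. A minimal sketch (not in the commit) of that round trip:

# Sketch: tensors and plain flags alike are stored and retrieved as-is.
from triton.function import OpContext

ctx = OpContext()
ctx.save_for_backward('a', 'b', False)
a, b, flag = ctx.saved_tensors   # what backward() receives
assert flag is False

triton/ops/dot.py below is the first real subclass.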

215
python/triton/kernel.py Normal file

@@ -0,0 +1,215 @@
# imports for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
# imports for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import triton.frameworks as fw
import triton.utils
import libtriton

def _make_framework_src(src, out, grid, framework):
    if framework == fw.tensorflow_id:
        return libtriton.make_tensorflow_src(src, out, grid)
    elif framework == fw.torch_id:
        return libtriton.make_torch_src(src, out, grid)
    else:
        assert False

def _make_cache_path(src):
    sha1 = hashlib.sha1(src.encode())
    hexhash = sha1.hexdigest()
    home = os.path.expanduser('~')
    cacheroot = os.path.join(home, '.triton', 'cache')
    cachepath = os.path.join(cacheroot, str(hexhash))
    if not os.path.exists(cachepath):
        os.makedirs(cachepath)
    return cachepath

def _write_bindings(src, root, framework):
    cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
    recompile = False
    # recompile if the .so does not exist
    if not os.path.exists(cpp) or not os.path.exists(so):
        recompile = True
    # recompile if the .cpp was modified after the .so
    elif max(cpp, so, key=os.path.getctime) == cpp:
        recompile = True
    # write cpp file
    if recompile:
        with open(cpp, 'w+') as handle:
            handle.writelines(src)
    # return paths of the cpp and so files
    return (cpp, so)

def _build(src, path, framework):
    # include directories
    triton_include_dirs = ['/home/philippe/development/triton/include']
    include_dirs = triton_include_dirs
    # library directories
    triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
    library_dirs = triton_library_dirs
    # libraries
    libraries = ['triton']
    # add framework
    extra_compile_args = []
    if framework == fw.tensorflow_id:
        library_dirs += [fw.tensorflow.sysconfig.get_lib()]
        include_dirs += [fw.tensorflow.sysconfig.get_include()]
        include_dirs += ['/usr/local/cuda/include/']
        libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
        ABI = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0
        extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
    elif framework == fw.torch_id:
        prefix = os.path.dirname(fw.torch.__file__)
        library_dirs += [os.path.join(prefix, 'lib')]
        include_dirs += [os.path.join(prefix, 'lib', 'include'),
                         os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
                         os.path.join(prefix, 'include'),
                         os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
        libraries += ['torch']
    else:
        assert False
    # extra arguments
    extra_link_args = []
    # dependencies
    depends = [os.path.realpath(libtriton.__file__)]
    # create extension module
    # the extension name must match the basename used in _write_bindings,
    # otherwise the built module cannot be found for the torch backend
    ext = setuptools.Extension(
        name = framework,
        language = 'c++',
        sources = [src],
        include_dirs = include_dirs,
        extra_compile_args = extra_compile_args,
        extra_link_args = extra_link_args,
        library_dirs = library_dirs,
        libraries = libraries,
        depends = depends
    )
    # build extension module
    args = ['build_ext']
    tmp = tempfile.mkdtemp()
    args.append('--build-temp=' + tmp)
    args.append('--build-lib=' + path)
    args.append('-q')
    args = dict(
        name = framework,
        ext_modules = [ext],
        script_args = args,
    )
    setuptools.setup(**args)
    shutil.rmtree(tmp)

def _cvt_to_def_str(obj, framework):
    # bool
    if isinstance(obj, bool):
        return str(int(obj))
    # tensorflow type
    if framework == fw.tensorflow_id:
        if isinstance(obj, fw.tensorflow.DType):
            return {fw.tensorflow.int8:    'char',
                    fw.tensorflow.int16:   'short',
                    fw.tensorflow.int32:   'int',
                    fw.tensorflow.int64:   'long',
                    fw.tensorflow.float16: 'half',
                    fw.tensorflow.float32: 'float',
                    fw.tensorflow.float64: 'double'}[obj]
    # torch type
    elif framework == fw.torch_id:
        if isinstance(obj, fw.torch.dtype):
            return {fw.torch.int8:    'char',
                    fw.torch.int16:   'short',
                    fw.torch.int32:   'int',
                    fw.torch.int64:   'long',
                    fw.torch.float16: 'half',
                    fw.torch.float32: 'float',
                    fw.torch.float64: 'double'}[obj]
    else:
        assert False
    # default
    return str(obj)

def _make_framework_op(src, outputs, options, framework):
    src, name = _make_framework_src(src, outputs, options, framework)
    cache_path = _make_cache_path(src)
    cpp, so = _write_bindings(src, cache_path, framework)
    _build(cpp, cache_path, framework)
    if framework == fw.tensorflow_id:
        return fw.tensorflow.load_op_library(so).__dict__[name]
    elif framework == fw.torch_id:
        fw.torch.ops.load_library(so)
        return fw.torch.ops.triton.__dict__[name]
    else:
        assert False

def _make_grid(args):
    scalars = [x for x in args[:-1] if isinstance(x, triton.utils.scalar)]
    def grid(opt):
        for x in scalars:
            x.set_assume_initialized()
        result = args[-1](opt)
        for x in scalars:
            x.unset_assume_initialized()
        return result
    return grid

class kernel:

    def __init__(self, src, outputs, framework = None):
        self.fw_id = dict()
        self.fw_grids = dict()
        self.fw_op = None
        self.src = src
        self.outputs = outputs
        self.framework = fw._find_framework(framework)
        if self.framework == fw.tensorflow_id:
            fw._import_tensorflow()
            fw._import_tf_extra_ops()
        elif self.framework == fw.torch_id:
            fw._import_torch()
        else:
            assert False

    def __call__(self, *args, **kwargs):
        # create a new framework op when defines are different
        key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
        if key not in self.fw_id.keys():
            # code generation options
            defines = []
            for k, v in kwargs.items():
                cvt = lambda x: _cvt_to_def_str(x, self.framework)
                if isinstance(v, list):
                    values = list(map(cvt, v))
                else:
                    values = [cvt(v)]
                defines.append((k, values))
            opt = libtriton.options_space()
            opt.defines = defines
            opt.num_warps = [4]
            # create unique id for this op
            op_id = libtriton.make_op_id()
            self.fw_id[key] = op_id
            # register function
            libtriton.register_fn(op_id, self.src, opt)
            if self.fw_op is None:
                self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
        # retrieve framework op
        op_id = self.fw_id[key]
        # register grid
        libtriton.register_grid(op_id, _make_grid(args))
        # create operands
        op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
        # call framework function
        return self.fw_op(*op_args, id=op_id)
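
A sketch (not in the commit) of the define-string conversion that feeds opt.defines, assuming TensorFlow is the active framework. List-valued keywords such as TM = [64, 128] become multiple candidate values in the auto-tuning space:

# Sketch: keyword arguments to kernel.__call__ become preprocessor defines.
import tensorflow as tf
import triton.frameworks as fw
from triton.kernel import _cvt_to_def_str

fw._import_tensorflow()
print(_cvt_to_def_str(True, fw.tensorflow_id))        # '1'     -- bools map to 0/1
print(_cvt_to_def_str(tf.float32, fw.tensorflow_id))  # 'float' -- dtypes map to C types
print(_cvt_to_def_str(128, fw.tensorflow_id))         # '128'   -- fallback is str()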

python/triton/ops.py

@@ -1,415 +0,0 @@
# imports for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
# imports for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import libtriton

# clean-up libtriton resources
import atexit

@atexit.register
def cleanup():
    libtriton.cleanup()

torch_id = 'torch'
tensorflow_id = 'tensorflow'

torch = None
tensorflow = None
_gradient_registry = None
tf_extra_ops = None

def _import_torch():
    global torch
    if torch is None:
        import torch

def _import_tensorflow():
    global tensorflow
    if tensorflow is None:
        import tensorflow
    global _gradient_registry
    if _gradient_registry is None:
        from tensorflow.python.framework.ops import _gradient_registry

def _import_tf_extra_ops():
    global tf_extra_ops
    if tf_extra_ops is None:
        path = os.path.dirname(libtriton.__file__)
        path = os.path.join(path, 'libextra_tf_ops.so')
        _import_tensorflow()
        tf_extra_ops = tensorflow.load_op_library(path)

def _find_framework(default = None):
    is_tf_imported = 'tensorflow' in sys.modules
    is_torch_imported = 'torch' in sys.modules
    if default:
        if default not in [tensorflow_id, torch_id]:
            raise ValueError('unsupported framework')
        else:
            return default
    elif is_tf_imported and not is_torch_imported:
        return tensorflow_id
    elif is_torch_imported and not is_tf_imported:
        return torch_id
    else:
        raise ValueError('cannot determine imported framework, '
                         'please provide framework argument')

def _make_framework_src(src, out, grid, framework):
    if framework == tensorflow_id:
        return libtriton.make_tensorflow_src(src, out, grid)
    elif framework == torch_id:
        return libtriton.make_torch_src(src, out, grid)
    else:
        assert False

def _make_cache_path(src):
    sha1 = hashlib.sha1(src.encode())
    hexhash = sha1.hexdigest()
    home = os.path.expanduser('~')
    cacheroot = os.path.join(home, '.triton', 'cache')
    cachepath = os.path.join(cacheroot, str(hexhash))
    if not os.path.exists(cachepath):
        os.makedirs(cachepath)
    return cachepath

def _write_bindings(src, root, framework):
    cpp = os.path.join(root, '{framework}.cpp'.format(framework=framework))
    suffix = sysconfig.get_config_var('EXT_SUFFIX')
    so = os.path.join(root, '{framework}{suffix}'.format(framework=framework, suffix=suffix))
    recompile = False
    # recompile if the .so does not exist
    if not os.path.exists(cpp) or not os.path.exists(so):
        recompile = True
    # recompile if the .cpp was modified after the .so
    elif max(cpp, so, key=os.path.getctime) == cpp:
        recompile = True
    # write cpp file
    if recompile:
        with open(cpp, 'w+') as handle:
            handle.writelines(src)
    # return paths of the cpp and so files
    return (cpp, so)

def _build(src, path, framework):
    # include directories
    triton_include_dirs = ['/home/philippe/development/triton/include']
    include_dirs = triton_include_dirs
    # library directories
    triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
    library_dirs = triton_library_dirs
    # libraries
    libraries = ['triton']
    # add framework
    extra_compile_args = []
    if framework == tensorflow_id:
        _import_tensorflow()
        library_dirs += [tensorflow.sysconfig.get_lib()]
        include_dirs += [tensorflow.sysconfig.get_include()]
        include_dirs += ['/usr/local/cuda/include/']
        libraries += [tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')]
        ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
        extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
    elif framework == torch_id:
        _import_torch()
        prefix = os.path.dirname(torch.__file__)
        library_dirs += [os.path.join(prefix, 'lib')]
        include_dirs += [os.path.join(prefix, 'lib', 'include'),
                         os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'),
                         os.path.join(prefix, 'include'),
                         os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')]
        libraries += ['torch']
    else:
        assert False
    # extra arguments
    extra_link_args = []
    # dependencies
    depends = [os.path.realpath(libtriton.__file__)]
    # create extension module
    ext = setuptools.Extension(
        name = 'tensorflow',
        language = 'c++',
        sources = [src],
        include_dirs = include_dirs,
        extra_compile_args = extra_compile_args,
        extra_link_args = extra_link_args,
        library_dirs = library_dirs,
        libraries = libraries,
        depends = depends
    )
    # build extension module
    args = ['build_ext']
    tmp = tempfile.mkdtemp()
    args.append('--build-temp=' + tmp)
    args.append('--build-lib=' + path)
    args.append('-q')
    args = dict(
        name = 'tensorflow',
        ext_modules = [ext],
        script_args = args,
    )
    setuptools.setup(**args)
    shutil.rmtree(tmp)

def _cvt_to_def_str(obj, framework):
    # bool
    if isinstance(obj, bool):
        return str(int(obj))
    # tensorflow type
    if framework == tensorflow_id:
        _import_tensorflow()
        if isinstance(obj, tensorflow.DType):
            return {tensorflow.int8:    'char',
                    tensorflow.int16:   'short',
                    tensorflow.int32:   'int',
                    tensorflow.int64:   'long',
                    tensorflow.float16: 'half',
                    tensorflow.float32: 'float',
                    tensorflow.float64: 'double'}[obj]
    # torch type
    elif framework == torch_id:
        _import_torch()
        if isinstance(obj, torch.dtype):
            return {torch.int8:    'char',
                    torch.int16:   'short',
                    torch.int32:   'int',
                    torch.int64:   'long',
                    torch.float16: 'half',
                    torch.float32: 'float',
                    torch.float64: 'double'}[obj]
    else:
        assert False
    # default
    return str(obj)

def _make_framework_op(src, outputs, options, framework):
    src, name = _make_framework_src(src, outputs, options, framework)
    cache_path = _make_cache_path(src)
    cpp, so = _write_bindings(src, cache_path, framework)
    _build(cpp, cache_path, framework)
    if framework == tensorflow_id:
        _import_tensorflow()
        return tensorflow.load_op_library(so).__dict__[name]
    elif framework == torch_id:
        _import_torch()
        torch.ops.load_library(so)
        return torch.ops.triton.__dict__[name]
    else:
        assert False

def _make_grid(args):
    scalars = [x for x in args[:-1] if isinstance(x, scalar)]
    def grid(opt):
        for x in scalars:
            x.set_assume_initialized()
        result = args[-1](opt)
        for x in scalars:
            x.unset_assume_initialized()
        return result
    return grid

class OpContext(object):

    def save_for_backward(self, *tensors):
        self.to_save = tensors

    def mark_dirty(self, *args):
        self.dirty_tensors = args

    @property
    def saved_tensors(self):
        return self.to_save

class function_meta(type):

    def __init__(cls, name, bases, attrs):
        cls.contexts = dict()
        cls.registered = False
        return super(function_meta, cls).__init__(name, bases, attrs)

class function(metaclass = function_meta):

    def __init__(self, framework = None):
        self.framework = _find_framework(framework)

    @staticmethod
    def forward(ctx, *args, **kwargs):
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError

    @classmethod
    def apply(cls, *args, **kwargs):
        # call forward
        ctx = OpContext()
        result = cls.forward(ctx, *args, **kwargs)
        id = result.op.get_attr('id')
        cls.contexts[id] = ctx
        # register backward
        _import_tensorflow()
        from tensorflow.python.framework.ops import _gradient_registry
        name = result.op.op_def.name
        if not cls.registered:
            @tensorflow.RegisterGradient(name)
            def gradient(op, dy):
                id = op.get_attr('id')
                return cls.backward(cls.contexts[id], dy)
            cls.registered = True
        # return result tensor
        return result

class op:

    def __init__(self, src, outputs, framework = None):
        self.fw_id = dict()
        self.fw_grids = dict()
        self.fw_op = None
        self.src = src
        self.outputs = outputs
        self.framework = _find_framework(framework)

    def __call__(self, *args, **kwargs):
        # create a new op when defines are different
        key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
        if key not in self.fw_id.keys():
            # code generation options
            defines = []
            for k, v in kwargs.items():
                cvt = lambda x: _cvt_to_def_str(x, self.framework)
                if isinstance(v, list):
                    values = list(map(cvt, v))
                else:
                    values = [cvt(v)]
                defines.append((k, values))
            opt = libtriton.options_space()
            opt.defines = defines
            opt.num_warps = [4]
            # create unique id for this op
            op_id = libtriton.make_op_id()
            self.fw_id[key] = op_id
            # register function
            libtriton.register_fn(op_id, self.src, opt)
            if self.fw_op is None:
                self.fw_op = _make_framework_op(self.src, self.outputs, opt, self.framework)
        # retrieve framework op
        op_id = self.fw_id[key]
        # register grid
        libtriton.register_grid(op_id, _make_grid(args))
        # create operands
        op_args = [x.handle if isinstance(x, scalar) else x for x in args[:-1]]
        # call framework op
        return self.fw_op(*op_args, id=op_id)

def empty(shapes, dtype, framework = None):
    framework = _find_framework(framework)
    if framework == tensorflow_id:
        _import_tensorflow()
        _import_tf_extra_ops()
        args = [x.handle if isinstance(x, scalar) else x for x in shapes]
        args = tensorflow.stack(args)
        return tf_extra_ops.alloc_empty(args, T = dtype)
    elif framework == torch_id:
        _import_torch()
        return torch.empty(*shapes)

def cdiv(a, b):
    return -(-a // b)

class scalar:

    def __init__(self, x):
        _import_tf_extra_ops()
        self.id = libtriton.make_scalar_id()
        self.handle = tf_extra_ops.register_scalar(x, id=self.id)
        self.assume_initialized = False

    def set_assume_initialized(self):
        self.assume_initialized = True

    def unset_assume_initialized(self):
        self.assume_initialized = False

    def get_value(self):
        if self.assume_initialized:
            return libtriton.retrieve_scalar(self.id)
        else:
            return self.handle

    def __add__(self, other):
        return self.get_value() + other

    def __radd__(self, other):
        return other + self.get_value()

    def __sub__(self, other):
        return self.get_value() - other

    def __rsub__(self, other):
        return other - self.get_value()

    def __mul__(self, other):
        return self.get_value() * other

    def __rmul__(self, other):
        return other * self.get_value()

    def __floordiv__(self, other):
        return self.get_value() // other

    def __rfloordiv__(self, other):
        return other // self.get_value()

    def __div__(self, other):
        return self.get_value() / other

    def __rdiv__(self, other):
        return other / self.get_value()

    def __truediv__(self, other):
        return self.get_value() / other

    def __rtruediv__(self, other):
        return other / self.get_value()

    def __neg__(self):
        return -self.get_value()

class lazy_shape:

    def __init__(self, shape):
        self.shape = shape

    def __getitem__(self, key):
        return scalar(self.shape[key])

def shape(A):
    _import_tensorflow()
    return lazy_shape(tensorflow.shape(A))

1
python/triton/ops/__init__.py Normal file

@@ -0,0 +1 @@
from .dot import dot

107
python/triton/ops/dot.py Normal file

@@ -0,0 +1,107 @@
import triton

class dot(triton.function):

    src = """
    void dot(TYPE * A, TYPE * B, TYPE * C,
             int M, int N, int K,
             int lda __multipleof(8),
             int ldb __multipleof(8),
             int ldc) {
      // prologue
      int ridx = get_program_id(0);
      int ridy = get_program_id(1);
      int rxa[TM] = ridx * TM + 0 ... TM;
      int ryb[TN] = ridy * TN + 0 ... TN;
      int rka[TK] = 0 ... TK;
      int rkb[TK] = 0 ... TK;
      float c[TM, TN] = 0;
      // pointers to operands
      TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
      TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
      // prefetches operands
      TYPE a[SHAPE_A] = *pa;
      TYPE b[SHAPE_B] = *pb;
      // reduction loop
      for(int k = K; k > 0; k -= TK){
        c += USE_A @ USE_B;
        pa = pa + TK * STRIDE_AK;
        pb = pb + TK * STRIDE_BK;
        a = *pa;
        b = *pb;
      }
      // epilogue
      int rxc[TM] = ridx * TM + 0 ... TM;
      int ryc[TN] = ridy * TN + 0 ... TN;
      TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc;
      bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :];
      *?(checkc) pc = c;
    }
    """
    kernel = triton.kernel(src, ['C'])

    @staticmethod
    def _call(a, b, transpose_a, transpose_b):
        # extract shapes
        shape_a = triton.shape(a)
        shape_b = triton.shape(b)
        M, Ka = shape_a[0], shape_a[1]
        Kb, N = shape_b[0], shape_b[1]
        # transpose shapes
        if transpose_a:
            M, Ka = Ka, M
        if transpose_b:
            Kb, N = N, Kb
        # contiguous dimensions
        lda = M if transpose_a else Ka
        ldb = Kb if transpose_b else N
        ldc = N
        # data-type
        dtype = a.dtype
        # allocate output
        c = triton.empty([M, N], dtype = dtype)
        # compute
        grid = lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))]
        # macros -- not necessary but makes kernel source-code simpler
        macros = {# handle A transposition
                  'USE_A'       : '^a'         if transpose_a else 'a',
                  'STRIDE_AK'   : 'lda'        if transpose_a else '1',
                  'STRIDE_AM'   : '1'          if transpose_a else 'lda',
                  'BROADCAST_AK': ':, newaxis' if transpose_a else 'newaxis, :',
                  'BROADCAST_AM': 'newaxis, :' if transpose_a else ':, newaxis',
                  'SHAPE_A'     : 'TK, TM'     if transpose_a else 'TM, TK',
                  # handle B transposition
                  'USE_B'       : '^b'         if transpose_b else 'b',
                  'STRIDE_BK'   : '1'          if transpose_b else 'ldb',
                  'STRIDE_BN'   : 'ldb'        if transpose_b else '1',
                  'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis',
                  'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :',
                  'SHAPE_B'     : 'TN, TK'     if transpose_b else 'TK, TN'}
        return dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid,
                          AT = transpose_a, BT = transpose_b, TYPE = dtype,
                          TM = [64, 128], TN = [64, 128], TK = [8], **macros)

    @staticmethod
    def forward(ctx, a, b, transpose_a = False, transpose_b = False):
        ctx.save_for_backward(a, b, transpose_a, transpose_b)
        return dot._call(a, b, transpose_a, transpose_b)

    @staticmethod
    def backward(ctx, dy):
        a, b, t_a, t_b = ctx.saved_tensors
        if not t_a and not t_b:
            da = dot._call(dy, b, False, True)
            db = dot._call(a, dy, True, False)
        elif not t_a and t_b:
            da = dot._call(dy, b, False, False)
            db = dot._call(dy, a, True, False)
        elif t_a and not t_b:
            da = dot._call(b, dy, False, True)
            db = dot._call(a, dy, False, False)
        elif t_a and t_b:
            da = dot._call(b, dy, True, True)
            db = dot._call(dy, a, True, True)
        else:
            assert False
        return [da, db, None, None, None, None, None, None, None]
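
Because apply() installs a RegisterGradient hook, TensorFlow differentiation flows through backward() above. A sketch (not in the commit):

# Sketch: tf.gradients dispatches to dot.backward via the registered gradient.
import tensorflow as tf
import triton

a = tf.placeholder(tf.float32, shape=[128, 64])
b = tf.placeholder(tf.float32, shape=[128, 64])
c = triton.ops.dot.apply(a, b, transpose_a = False, transpose_b = True)
da, db = tf.gradients(c, [a, b])   # the "not t_a and t_b" branch of backward()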

88
python/triton/utils.py Normal file

@@ -0,0 +1,88 @@
import triton.frameworks as fw
import libtriton

def cdiv(a, b):
    return -(-a // b)

def empty(shapes, dtype, framework = None):
    framework = fw._find_framework(framework)
    if framework == fw.tensorflow_id:
        # make sure the extra-ops library is loaded (as in the pre-split code)
        fw._import_tf_extra_ops()
        args = [x.handle if isinstance(x, scalar) else x for x in shapes]
        args = fw.tensorflow.stack(args)
        return fw.tf_extra_ops.alloc_empty(args, T = dtype)
    elif framework == fw.torch_id:
        fw._import_torch()
        return fw.torch.empty(*shapes)

class lazy_shape:

    def __init__(self, shape):
        self.shape = shape

    def __getitem__(self, key):
        return scalar(self.shape[key])

def shape(A):
    fw._import_tensorflow()
    return lazy_shape(fw.tensorflow.shape(A))

class scalar:

    def __init__(self, x):
        # make sure the extra-ops library is loaded (as in the pre-split code)
        fw._import_tf_extra_ops()
        self.id = libtriton.make_scalar_id()
        self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id)
        self.assume_initialized = False

    def set_assume_initialized(self):
        self.assume_initialized = True

    def unset_assume_initialized(self):
        self.assume_initialized = False

    def get_value(self):
        if self.assume_initialized:
            return libtriton.retrieve_scalar(self.id)
        else:
            return self.handle

    def __add__(self, other):
        return self.get_value() + other

    def __radd__(self, other):
        return other + self.get_value()

    def __sub__(self, other):
        return self.get_value() - other

    def __rsub__(self, other):
        return other - self.get_value()

    def __mul__(self, other):
        return self.get_value() * other

    def __rmul__(self, other):
        return other * self.get_value()

    def __floordiv__(self, other):
        return self.get_value() // other

    def __rfloordiv__(self, other):
        return other // self.get_value()

    def __div__(self, other):
        return self.get_value() / other

    def __rdiv__(self, other):
        return other / self.get_value()

    def __truediv__(self, other):
        return self.get_value() / other

    def __rtruediv__(self, other):
        return other / self.get_value()

    def __neg__(self):
        return -self.get_value()
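
A sketch (not in the commit) of how lazy shapes defer evaluation: shape() wraps tf.shape so each dimension is a scalar that only materializes inside the grid callback, via set_assume_initialized:

# Sketch: dimensions stay symbolic until the launch grid is evaluated.
import tensorflow as tf
import triton

a = tf.placeholder(tf.float32, shape=[None, None])
M = triton.shape(a)[0]                             # a triton.utils.scalar, not an int
grid = lambda opt: [triton.cdiv(M, opt.d('TM'))]   # evaluated lazily at each launch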