# Review notes for this patch. This commentary sits ahead of the first
# `diff --git` header; every diff line below it is unchanged.
#
# NOTE(review): the patch text below has its line breaks collapsed, so the
# added (+), removed (-) and context lines of each hunk run together. The
# summaries here are read from the +/- markers and the @@ line counts.
#
# python/setup.py
#   * hunk at old line 63 (CMakeBuild): removes the `try: import tensorflow`
#     block that appended -DTF_INCLUDE_DIRS, -DTF_LIB_DIRS, -DTF_LIBS and
#     -DTF_ABI to cmake_args, and adds a `# configuration` comment.
#     NOTE(review): the unchanged context still assigns `cfg = 'Release'`
#     right after `cfg = 'Debug' if self.debug else 'Release'`, so the
#     debug choice is overwritten -- confirm this is intended.
#   * hunk at old line 104: header files are now collected with both the
#     `*.h` and `*.hpp` glob patterns instead of `*.h` only.
#     NOTE(review): this hunk also adds a module-level `print(data)`, which
#     looks like leftover debug output -- confirm before merging.
#   * hunk at old line 114: adds 'triton/nn' to `packages` and adds
#     install_requires=['numpy', 'torch', 'sympy'].
#
# python/triton/__init__.py
#   * adds `import triton.nn` after `import triton.ops`.
#
# python/triton/ops/einsum.py
#   * hunk at old line 1: regroups the imports. The numpy, torch, triton and
#     sympy imports are removed from their old positions and re-added below
#     `import re`, each with a short comment. One blank line before
#     `class _einsum` is replaced (presumably a whitespace-only change).
#   * hunk at old line 612: `if mask:` becomes `if mask is not None:`, so a
#     falsy but non-None mask (e.g. 0) now also sets self.macros['MASK'].
#   * hunk at old line 696: the `return` shown gains one more trailing None
#     (7 values -> 8).
#     NOTE(review): this is presumably the autograd backward of `_einsum`;
#     the enclosing def and the forward signature are outside the hunk, so
#     verify that the number of returned values matches the forward inputs.
#   * the text ends inside this hunk, after `def einsum(expr, a, b, output,`.
diff --git a/python/setup.py b/python/setup.py index 060a1c450..81395135a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -63,19 +63,7 @@ class CMakeBuild(build_ext): '-DBUILD_PYTHON_MODULE=ON', '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, '-DLLVM_CONFIG=' + find_llvm()] - # tensorflow compatibility - try: - import tensorflow as tf - tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0 - tf_include_dirs = tf.sysconfig.get_include() - tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '') - cmake_args += ['-DTF_INCLUDE_DIRS=' + tf_include_dirs, - '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(), - '-DTF_LIBS=' + tf_libs, - '-DTF_ABI=' + str(tf_abi)] - except ModuleNotFoundError: - pass - + # configuration cfg = 'Debug' if self.debug else 'Release' cfg = 'Release' build_args = ['--config', cfg] @@ -104,8 +92,10 @@ find_llvm() directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))] data = [] for d in directories: - files = glob.glob(os.path.join(d, '*.h'), recursive=False) - data += [os.path.relpath(f, os.path.pardir) for f in files] + for htype in ['h', 'hpp']: + files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False) + data += [os.path.relpath(f, os.path.pardir) for f in files] +print(data) setup( name='triton', @@ -114,7 +104,8 @@ setup( author_email='ptillet@g.harvard.edu', description='A language and compiler for custom Deep Learning operations', long_description='', - packages=['triton', 'triton/_C', 'triton/ops'], + packages=['triton', 'triton/_C', 'triton/ops', 'triton/nn'], + install_requires=['numpy', 'torch', 'sympy'], package_data={'': data}, ext_modules=[CMakeExtension('triton', 'triton/_C/')], cmdclass=dict(build_ext=CMakeBuild), diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 388484039..535f717ec 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,6 +1,7 @@ from .kernel import * from .utils import * import triton.ops +import triton.nn # 
clean-up libtriton resources diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 6c74dc92c..a24e422dc 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -1,18 +1,21 @@ -import numpy as np -import torch from math import ceil, log2 from enum import IntEnum -import triton from functools import reduce from operator import mul -from sympy.parsing.sympy_parser import parse_expr -import sympy as sp from collections import OrderedDict from collections import namedtuple import re +import triton +# torch +import torch +# numpy -- ideally removed in a future release +import numpy as np +# sympy -- ideally removed in a future release +import sympy as sp +from sympy.parsing.sympy_parser import parse_expr from sympy.printing.ccode import C89CodePrinter - + class _einsum(torch.autograd.Function): @@ -612,7 +615,7 @@ __global__ void {name}( TM, TN, TB, TZ = 64, 128, 1, 1 self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} self.num_warps = [4] - if mask: + if mask is not None: self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10) # save information on the operation self.expr_a = expr_a @@ -696,7 +699,7 @@ __global__ void {name}( if ctx.needs_input_grad[2]: db = torch.empty_like(b) einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db) - return None, da, db, None, None, None, None + return None, da, db, None, None, None, None, None def einsum(expr, a, b, output,