[PYTHON] Better packaging
commit 3d769b57e2 (parent dfb844bf41), committed by Philippe Tillet
@@ -63,19 +63,7 @@ class CMakeBuild(build_ext):
                       '-DBUILD_PYTHON_MODULE=ON',
                       '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
                       '-DLLVM_CONFIG=' + find_llvm()]
-        # tensorflow compatibility
-        try:
-            import tensorflow as tf
-            tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
-            tf_include_dirs = tf.sysconfig.get_include()
-            tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '')
-            cmake_args += ['-DTF_INCLUDE_DIRS=' + tf_include_dirs,
-                           '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(),
-                           '-DTF_LIBS=' + tf_libs,
-                           '-DTF_ABI=' + str(tf_abi)]
-        except ModuleNotFoundError:
-            pass
-
+        # configuration
         cfg = 'Debug' if self.debug else 'Release'
         cfg = 'Release'
         build_args = ['--config', cfg]
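Note: the deleted branch probed an installed TensorFlow for its build settings and forwarded them to CMake. For context, a minimal standalone sketch of that probing logic, using TensorFlow's real `tf.sysconfig` API (the hard-coded `[1]` index into the link flags is an assumption about their ordering, which is part of what made this fragile):

    # Sketch of the removed TensorFlow detection; requires TensorFlow installed.
    import tensorflow as tf

    include_dir = tf.sysconfig.get_include()    # header directory for TF ops
    lib_dir = tf.sysconfig.get_lib()            # directory holding libtensorflow_framework
    link_flags = tf.sysconfig.get_link_flags()  # e.g. ['-L/...', '-ltensorflow_framework']
    lib_name = link_flags[1].replace('-l', '')  # assumes the '-l<name>' flag comes second
    abi = getattr(tf, '__cxx11_abi_flag__', 0)  # C++11 ABI flag the extension must match
    print(include_dir, lib_dir, lib_name, abi)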
@@ -104,8 +92,10 @@ find_llvm()
 directories = [x[0] for x in os.walk(os.path.join(os.path.pardir, 'include'))]
 data = []
 for d in directories:
-    files = glob.glob(os.path.join(d, '*.h'), recursive=False)
-    data += [os.path.relpath(f, os.path.pardir) for f in files]
+    for htype in ['h', 'hpp']:
+        files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
+        data += [os.path.relpath(f, os.path.pardir) for f in files]
+print(data)
 
 setup(
     name='triton',
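The header-collection loop now gathers both `.h` and `.hpp` files for packaging. The same logic as a reusable sketch (`collect_headers` is a hypothetical name; the real script runs this at module level):

    import glob
    import os

    def collect_headers(root):
        # Walk every directory under `root` and record C/C++ headers
        # relative to the parent directory, for use as setuptools package data.
        data = []
        for d, _, _ in os.walk(root):
            for htype in ['h', 'hpp']:
                files = glob.glob(os.path.join(d, f'*.{htype}'), recursive=False)
                data += [os.path.relpath(f, os.path.pardir) for f in files]
        return data

    print(collect_headers(os.path.join(os.path.pardir, 'include')))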
@@ -114,7 +104,8 @@ setup(
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
-    packages=['triton', 'triton/_C', 'triton/ops'],
+    packages=['triton', 'triton/_C', 'triton/ops', 'triton/nn'],
+    install_requires=['numpy', 'torch', 'sympy'],
     package_data={'': data},
     ext_modules=[CMakeExtension('triton', 'triton/_C/')],
     cmdclass=dict(build_ext=CMakeBuild),
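With `install_requires` declared, `pip install .` resolves numpy, torch, and sympy automatically instead of failing on first import. A quick sanity check of the declared runtime dependencies (a sketch, not part of the package):

    import importlib

    # Each of these must be importable after installation; pip guarantees
    # it by reading install_requires from the package metadata.
    for dep in ['numpy', 'torch', 'sympy']:
        importlib.import_module(dep)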
@@ -1,6 +1,7 @@
 from .kernel import *
 from .utils import *
 import triton.ops
+import triton.nn
 
 
 # clean-up libtriton resources
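Because the package `__init__.py` now imports `triton.nn` eagerly, the new subpackage is available after a plain top-level import (a usage sketch):

    import triton

    # No separate `import triton.nn` needed; __init__.py already pulled it in.
    print(triton.nn)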
@@ -1,15 +1,18 @@
-import numpy as np
-import torch
 from math import ceil, log2
 from enum import IntEnum
-import triton
 from functools import reduce
 from operator import mul
-from sympy.parsing.sympy_parser import parse_expr
-import sympy as sp
 from collections import OrderedDict
 from collections import namedtuple
 import re
+import triton
+# torch
+import torch
+# numpy -- ideally removed in a future release
+import numpy as np
+# sympy -- ideally removed in a future release
+import sympy as sp
+from sympy.parsing.sympy_parser import parse_expr
 from sympy.printing.ccode import C89CodePrinter
 
 
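The regrouped imports flag numpy and sympy as dependencies to drop eventually. sympy is presumably used here to manipulate symbolic einsum index expressions and emit them as C source; a minimal sketch of that round-trip with the two imported names:

    from sympy.parsing.sympy_parser import parse_expr
    from sympy.printing.ccode import C89CodePrinter

    # Parse a symbolic offset expression and print it back as C89 code,
    # the kind of step a code-generating kernel builder performs.
    expr = parse_expr('i*stride + j')
    print(C89CodePrinter().doprint(expr))  # -> i*stride + j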
@@ -612,7 +615,7 @@ __global__ void {name}(
         TM, TN, TB, TZ = 64, 128, 1, 1
         self.macros = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype}
         self.num_warps = [4]
-        if mask:
+        if mask is not None:
             self.macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)
         # save information on the operation
         self.expr_a = expr_a
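The change from `if mask:` to `if mask is not None:` matters because `0` is falsy in Python: a legal all-zero bitmask would silently skip defining the MASK macro. A minimal demonstration of the pitfall:

    def set_mask_macro(mask, macros):
        # The old test `if mask:` drops mask == 0; the identity check keeps it.
        if mask is not None:
            macros['MASK'] = '{0:#0{1}x}'.format(mask, 10)

    macros = {}
    set_mask_macro(0, macros)
    print(macros)  # {'MASK': '0x00000000'}; the old truthiness test produced {}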
@@ -696,7 +699,7 @@ __global__ void {name}(
         if ctx.needs_input_grad[2]:
             db = torch.empty_like(b)
             einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db)
-        return None, da, db, None, None, None, None
+        return None, da, db, None, None, None, None, None
 
 
 def einsum(expr, a, b, output,
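`torch.autograd.Function.backward` must return exactly one value per argument of `forward`, so the extra trailing `None` suggests `forward` gained an eighth argument elsewhere in this commit (not shown here). A self-contained sketch of the rule:

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor):  # two inputs ...
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, dy):
            # ... so backward returns two values: the gradient for `x`
            # and None for the non-differentiable `factor`.
            return dy * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])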