diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7891fbe43..89ac8f403 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -6,12 +6,11 @@ on: branches: - main - jobs: Integration-Tests: - runs-on: self-hosted + runs-on: ubuntu-20.04 steps: @@ -23,32 +22,23 @@ jobs: rm -r ~/.triton/ continue-on-error: true + - name: Check imports + run: | + pip install isort + isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 ) + + - name: Check style + run: | + pip install autopep8 + autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 ) + + - name: Flake8 + run: | + pip install flake8 + flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 ) + - name: Install Triton run: | alias python='python3' cd python pip3 install -e '.[tests]' - - - name: Check imports - run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )" - - - name: Check style - run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )" - - - name: Flake8 - run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )" - - - name: Unit tests - run: | - cd python/test/unit - pytest -vs . - - - name: Regression tests - run: | - cd python/test/regression - sudo nvidia-smi -i 0 -pm 1 - sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350 - sudo nvidia-smi -i 0 --lock-memory-clocks=877,877 - pytest -vs . - sudo nvidia-smi -i 0 -rgc - sudo nvidia-smi -i 0 -rmc diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 896939e53..68278474b 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,3 +1,6 @@ add_mlir_library(TritonAnalysis AxisInfo.cpp + + DEPENDS + TritonGPUAttrDefsIncGen ) \ No newline at end of file diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py index 8e0db3e15..48e9dbb4e 100644 --- a/python/examples/copy_strided.py +++ b/python/examples/copy_strided.py @@ -1,12 +1,13 @@ - + import triton import triton.language as tl + # triton kernel @triton.jit -def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): +def kernel(X, stride_xm, stride_xn, + Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn @@ -15,4 +16,4 @@ def kernel(X, stride_xm, stride_xn, ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir") -print(ret) \ No newline at end of file +print(ret) diff --git a/python/examples/empty.py b/python/examples/empty.py index 233aff36e..b17d58ca3 100644 --- a/python/examples/empty.py +++ b/python/examples/empty.py @@ -1,8 +1,10 @@ import triton import triton.language as tl + @triton.jit def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): pass -ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir") \ No newline at end of file + +ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir") diff --git a/python/setup.py b/python/setup.py index d8ddb327c..1e17d26f2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -16,15 +16,6 @@ from setuptools.command.build_ext import build_ext def get_llvm(): - # tries to find system LLVM - versions = ['-14.0', '-14', '-14-64'] - supported = ['llvm-config{v}'.format(v=v) for v in versions] - paths = [distutils.spawn.find_executable(cfg) for cfg in supported] - paths = [p for p in paths if p is not None] - if paths: - return '', '' - if platform.system() == "Windows": - return '', '' # download if nothing is installed name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04' dir = '/tmp' diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 6449de421..da85b2ade 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1,15 +1,17 @@ from __future__ import annotations + import ast import sys import warnings from typing import Dict, Union + import triton import triton._C.libtriton.triton as _triton def str_to_ty(name): if name[0] == "*": - ty = str_to_ty(name[1:]) + ty = str_to_ty(name[1:]) return triton.language.pointer_type(ty) tys = { "fp8": triton.language.float8, @@ -26,9 +28,10 @@ def str_to_ty(name): "u32": triton.language.uint32, "u64": triton.language.uint64, "B": triton.language.int1, - } + } return tys[name] + def mangle_ty(ty): if ty.is_ptr(): return 'P' + mangle_ty(ty.element_ty) @@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants): ret = f'{name}__{mangled_arg_names}__{mangled_constants}' return ret + class enter_sub_region: def __init__(self, generator: CodeGenerator): self.generator = generator @@ -79,6 +83,7 @@ class enter_sub_region: self.generator.lscope = self.liveins self.generator.local_defs = self.prev_defs + class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): self.builder = _triton.ir.builder(context) @@ -491,8 +496,8 @@ class CodeGenerator(ast.NodeVisitor): while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], [arg.handle for arg in init_args]) # merge the condition region - before_block = self.builder.create_block_with_parent(while_op.get_before(), - [ty.to_ir(self.builder) for ty in ret_types]) + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) cond_block.merge_block_before(before_block) self.builder.set_insertion_point_to_end(before_block) # create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... @@ -538,7 +543,6 @@ class CodeGenerator(ast.NodeVisitor): iter_args = [self.visit(arg) for arg in node.iter.args] is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) if is_static: - st_target = ast.Name(id=node.target.id, ctx=ast.Store()) iter_args = [arg.value for arg in iter_args] range = iterator(*iter_args) if len(range) <= 10: @@ -597,7 +601,7 @@ class CodeGenerator(ast.NodeVisitor): # replace global uses with block arguments for i, name in enumerate(names): # arg0 is the induction variable - for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) + for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1)) # update lscope & local_defs (ForOp defines new values) for i, name in enumerate(names): @@ -633,7 +637,7 @@ class CodeGenerator(ast.NodeVisitor): args = getcallargs(fn.fn, *args, **kws) args = [args[name] for name in fn.arg_names] args = [arg if isinstance(arg, triton.language.tensor) - else triton.language.constexpr(arg) for arg in args] + else triton.language.constexpr(arg) for arg in args] # generate function def attributes = dict() constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] @@ -712,7 +716,6 @@ class CodeGenerator(ast.NodeVisitor): raise NotImplementedError("Unsupported node: {}".format(typename)) - class CompilationError(Exception): def __init__(self, src, node): self.message = f'at {node.lineno}:{node.col_offset}:\n' @@ -742,11 +745,11 @@ class OutOfResources(Exception): return (type(self), (self.required, self.limit, self.name)) -def make_triton_ir(fn, signature, constants = dict(), attributes = dict()): +def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): context = _triton.ir.context() context.load_triton() # create kernel prototype - arg_types = signature.replace(' ','').split(',') + arg_types = signature.replace(' ', '').split(',') constants = {fn.arg_names.index(name): value for name, value in constants.items()} arg_types = [str_to_ty(x) for x in arg_types] prototype = triton.language.function_type([], arg_types) @@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()): ret.context = context return ret + def make_tritongpu_ir(mod, num_warps): pm = _triton.ir.pass_manager(mod.context) pm.add_inliner_pass() @@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps): pm.run(mod) return mod + def optimize_tritongpu_ir(mod, num_stages): pm = _triton.ir.pass_manager(mod.context) pm.add_tritongpu_pipeline_pass(num_stages) @@ -785,22 +790,24 @@ def optimize_tritongpu_ir(mod, num_stages): pm.run(mod) return mod + def make_ptx(mod): # TODO return mod -def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"): + +def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"): assert output in ["ttir", "ttgir", "ptx"] # triton-ir module = make_triton_ir(fn, signature, constants, attributes) - if output == "ttir": + if output == "ttir": return module.str() # tritongpu-ir module = make_tritongpu_ir(module, num_warps) module = optimize_tritongpu_ir(module, num_stages) - if output == "ttgir": + if output == "ttgir": return module.str() # ptx - if output == "ptx": + if output == "ptx": return make_ptx(module) - assert False \ No newline at end of file + assert False diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index 47b5d5a1e..296d750ee 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,2 +1,2 @@ -from .jit import JITFunction, jit -from .autotuner import Config, autotune, heuristics \ No newline at end of file +from .autotuner import Config, autotune, heuristics # noqa: F401 +from .jit import JITFunction, jit # noqa: F401 diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 63bf2bbf7..010af80b4 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -1,8 +1,10 @@ from __future__ import annotations + import builtins import time from typing import Dict +from ..testing import do_bench class Autotuner: @@ -57,7 +59,7 @@ class Autotuner: config.pre_hook(self.nargs) self.hook(args) self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) - return triton.testing.do_bench(kernel_call) + return do_bench(kernel_call) def __call__(self, *args, **kwargs): self.nargs = dict(zip(self.arg_names, args)) @@ -199,4 +201,3 @@ def heuristics(values): return fn return decorator - diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 758b05341..6c9547b8f 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -8,6 +8,7 @@ import os import subprocess import tempfile import textwrap + import triton import triton._C.libtriton.triton as _triton from ..tools.disasm import extract @@ -16,6 +17,7 @@ from ..tools.disasm import extract # Binary # ----------------------------------------------------------------------------- + class Binary: def __init__(self, backend, name, asm, shared_mem, num_warps): self.backend = backend @@ -63,13 +65,13 @@ class LoadedBinary: # Kernel # ----------------------------------------------------------------------------- + class Kernel: def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs): raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.") - # ----------------------------------------------------------------------------- # Dependencies Finder # ----------------------------------------------------------------------------- @@ -118,6 +120,7 @@ class DependenciesFinder(ast.NodeVisitor): # JITFunction # ----------------------------------------------------------------------------- + @functools.lru_cache() def version_key(): import pkgutil @@ -232,7 +235,7 @@ class JITFunction: def __call__(self, *wargs, **kwargs): return self.kernel(*wargs, **kwargs, grid=self.grid) - + return Launcher(self._init_kernel(), grid) def __repr__(self): @@ -242,6 +245,7 @@ class JITFunction: # `jit` decorator # ----------------------------------------------------------------------------- + def jit(*args, **kwargs): """ Decorator for JIT-compiling a function using the Triton compiler. @@ -265,4 +269,4 @@ def jit(*args, **kwargs): else: def decorator(fn): return JITFunction(fn, **kwargs) - return decorator \ No newline at end of file + return decorator diff --git a/python/triton/utils.py b/python/triton/utils.py index b9db92dfb..f446dd06a 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -1,4 +1,5 @@ from __future__ import annotations + import torch @@ -17,6 +18,7 @@ def next_power_of_2(n): n += 1 return n + class TensorWrapper: def __init__(self, base, dtype): self.dtype = dtype