[CI] Added basic CI skeletons (#23)

Includes minor fixes so that the code compiles and passes the static checks (isort, autopep8, flake8).
This commit is contained in:
Philippe Tillet
2022-07-26 14:16:30 -07:00
committed by GitHub
parent 3265e0df5a
commit 25357083e6
10 changed files with 64 additions and 63 deletions

View File

@@ -6,12 +6,11 @@ on:
branches: branches:
- main - main
jobs: jobs:
Integration-Tests: Integration-Tests:
runs-on: self-hosted runs-on: ubuntu-20.04
steps: steps:
@@ -23,32 +22,23 @@ jobs:
rm -r ~/.triton/ rm -r ~/.triton/
continue-on-error: true continue-on-error: true
- name: Check imports
run: |
pip install isort
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
- name: Check style
run: |
pip install autopep8
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
- name: Flake8
run: |
pip install flake8
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
- name: Install Triton - name: Install Triton
run: | run: |
alias python='python3' alias python='python3'
cd python cd python
pip3 install -e '.[tests]' pip3 install -e '.[tests]'
- name: Check imports
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
- name: Check style
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
- name: Flake8
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
- name: Unit tests
run: |
cd python/test/unit
pytest -vs .
- name: Regression tests
run: |
cd python/test/regression
sudo nvidia-smi -i 0 -pm 1
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
pytest -vs .
sudo nvidia-smi -i 0 -rgc
sudo nvidia-smi -i 0 -rmc

View File

@@ -1,3 +1,6 @@
add_mlir_library(TritonAnalysis add_mlir_library(TritonAnalysis
AxisInfo.cpp AxisInfo.cpp
DEPENDS
TritonGPUAttrDefsIncGen
) )

View File

@@ -2,11 +2,12 @@
import triton import triton
import triton.language as tl import triton.language as tl
# triton kernel # triton kernel
@triton.jit @triton.jit
def kernel(X, stride_xm, stride_xn, def kernel(X, stride_xm, stride_xn,
Z, stride_zm, stride_zn, Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
off_m = tl.arange(0, BLOCK_M) off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N) off_n = tl.arange(0, BLOCK_N)
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn

View File

@@ -1,8 +1,10 @@
import triton import triton
import triton.language as tl import triton.language as tl
@triton.jit @triton.jit
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
pass pass
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir") ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")

View File

@@ -16,15 +16,6 @@ from setuptools.command.build_ext import build_ext
def get_llvm(): def get_llvm():
# tries to find system LLVM
versions = ['-14.0', '-14', '-14-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
if paths:
return '', ''
if platform.system() == "Windows":
return '', ''
# download if nothing is installed # download if nothing is installed
name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04' name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04'
dir = '/tmp' dir = '/tmp'

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import ast import ast
import sys import sys
import warnings import warnings
from typing import Dict, Union from typing import Dict, Union
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
@@ -29,6 +31,7 @@ def str_to_ty(name):
} }
return tys[name] return tys[name]
def mangle_ty(ty): def mangle_ty(ty):
if ty.is_ptr(): if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty) return 'P' + mangle_ty(ty.element_ty)
@@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants):
ret = f'{name}__{mangled_arg_names}__{mangled_constants}' ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret return ret
class enter_sub_region: class enter_sub_region:
def __init__(self, generator: CodeGenerator): def __init__(self, generator: CodeGenerator):
self.generator = generator self.generator = generator
@@ -79,6 +83,7 @@ class enter_sub_region:
self.generator.lscope = self.liveins self.generator.lscope = self.liveins
self.generator.local_defs = self.prev_defs self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()): def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context) self.builder = _triton.ir.builder(context)
@@ -492,7 +497,7 @@ class CodeGenerator(ast.NodeVisitor):
[arg.handle for arg in init_args]) [arg.handle for arg in init_args])
# merge the condition region # merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before(), before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types]) [ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block) cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block) self.builder.set_insertion_point_to_end(before_block)
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
@@ -538,7 +543,6 @@ class CodeGenerator(ast.NodeVisitor):
iter_args = [self.visit(arg) for arg in node.iter.args] iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args]) is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static: if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args] iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args) range = iterator(*iter_args)
if len(range) <= 10: if len(range) <= 10:
@@ -597,7 +601,7 @@ class CodeGenerator(ast.NodeVisitor):
# replace global uses with block arguments # replace global uses with block arguments
for i, name in enumerate(names): for i, name in enumerate(names):
# arg0 is the induction variable # arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1)) for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))
# update lscope & local_defs (ForOp defines new values) # update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names): for i, name in enumerate(names):
@@ -633,7 +637,7 @@ class CodeGenerator(ast.NodeVisitor):
args = getcallargs(fn.fn, *args, **kws) args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names] args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor) args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args] else triton.language.constexpr(arg) for arg in args]
# generate function def # generate function def
attributes = dict() attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)] constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
@@ -712,7 +716,6 @@ class CodeGenerator(ast.NodeVisitor):
raise NotImplementedError("Unsupported node: {}".format(typename)) raise NotImplementedError("Unsupported node: {}".format(typename))
class CompilationError(Exception): class CompilationError(Exception):
def __init__(self, src, node): def __init__(self, src, node):
self.message = f'at {node.lineno}:{node.col_offset}:\n' self.message = f'at {node.lineno}:{node.col_offset}:\n'
@@ -742,11 +745,11 @@ class OutOfResources(Exception):
return (type(self), (self.required, self.limit, self.name)) return (type(self), (self.required, self.limit, self.name))
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()): def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
context = _triton.ir.context() context = _triton.ir.context()
context.load_triton() context.load_triton()
# create kernel prototype # create kernel prototype
arg_types = signature.replace(' ','').split(',') arg_types = signature.replace(' ', '').split(',')
constants = {fn.arg_names.index(name): value for name, value in constants.items()} constants = {fn.arg_names.index(name): value for name, value in constants.items()}
arg_types = [str_to_ty(x) for x in arg_types] arg_types = [str_to_ty(x) for x in arg_types]
prototype = triton.language.function_type([], arg_types) prototype = triton.language.function_type([], arg_types)
@@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
ret.context = context ret.context = context
return ret return ret
def make_tritongpu_ir(mod, num_warps): def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context) pm = _triton.ir.pass_manager(mod.context)
pm.add_inliner_pass() pm.add_inliner_pass()
@@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps):
pm.run(mod) pm.run(mod)
return mod return mod
def optimize_tritongpu_ir(mod, num_stages): def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context) pm = _triton.ir.pass_manager(mod.context)
pm.add_tritongpu_pipeline_pass(num_stages) pm.add_tritongpu_pipeline_pass(num_stages)
@@ -785,11 +790,13 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.run(mod) pm.run(mod)
return mod return mod
def make_ptx(mod): def make_ptx(mod):
# TODO # TODO
return mod return mod
def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):
def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
assert output in ["ttir", "ttgir", "ptx"] assert output in ["ttir", "ttgir", "ptx"]
# triton-ir # triton-ir
module = make_triton_ir(fn, signature, constants, attributes) module = make_triton_ir(fn, signature, constants, attributes)

View File

@@ -1,2 +1,2 @@
from .jit import JITFunction, jit from .autotuner import Config, autotune, heuristics # noqa: F401
from .autotuner import Config, autotune, heuristics from .jit import JITFunction, jit # noqa: F401

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import builtins import builtins
import time import time
from typing import Dict from typing import Dict
from ..testing import do_bench
class Autotuner: class Autotuner:
@@ -57,7 +59,7 @@ class Autotuner:
config.pre_hook(self.nargs) config.pre_hook(self.nargs)
self.hook(args) self.hook(args)
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
return triton.testing.do_bench(kernel_call) return do_bench(kernel_call)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args)) self.nargs = dict(zip(self.arg_names, args))
@@ -199,4 +201,3 @@ def heuristics(values):
return fn return fn
return decorator return decorator

View File

@@ -8,6 +8,7 @@ import os
import subprocess import subprocess
import tempfile import tempfile
import textwrap import textwrap
import triton import triton
import triton._C.libtriton.triton as _triton import triton._C.libtriton.triton as _triton
from ..tools.disasm import extract from ..tools.disasm import extract
@@ -16,6 +17,7 @@ from ..tools.disasm import extract
# Binary # Binary
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Binary: class Binary:
def __init__(self, backend, name, asm, shared_mem, num_warps): def __init__(self, backend, name, asm, shared_mem, num_warps):
self.backend = backend self.backend = backend
@@ -63,13 +65,13 @@ class LoadedBinary:
# Kernel # Kernel
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Kernel: class Kernel:
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs): def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.") raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Dependencies Finder # Dependencies Finder
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -118,6 +120,7 @@ class DependenciesFinder(ast.NodeVisitor):
# JITFunction # JITFunction
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@functools.lru_cache() @functools.lru_cache()
def version_key(): def version_key():
import pkgutil import pkgutil
@@ -242,6 +245,7 @@ class JITFunction:
# `jit` decorator # `jit` decorator
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def jit(*args, **kwargs): def jit(*args, **kwargs):
""" """
Decorator for JIT-compiling a function using the Triton compiler. Decorator for JIT-compiling a function using the Triton compiler.

View File

@@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import torch import torch
@@ -17,6 +18,7 @@ def next_power_of_2(n):
n += 1 n += 1
return n return n
class TensorWrapper: class TensorWrapper:
def __init__(self, base, dtype): def __init__(self, base, dtype):
self.dtype = dtype self.dtype = dtype