[CI] Added basic CI skeletons (#23)
Includes minor fixes to make things compile and pass static checks properly
This commit is contained in:
42
.github/workflows/integration-tests.yml
vendored
42
.github/workflows/integration-tests.yml
vendored
@@ -6,12 +6,11 @@ on:
|
||||
branches:
|
||||
- main
|
||||
|
||||
|
||||
jobs:
|
||||
|
||||
Integration-Tests:
|
||||
|
||||
runs-on: self-hosted
|
||||
runs-on: ubuntu-20.04
|
||||
|
||||
steps:
|
||||
|
||||
@@ -23,32 +22,23 @@ jobs:
|
||||
rm -r ~/.triton/
|
||||
continue-on-error: true
|
||||
|
||||
- name: Check imports
|
||||
run: |
|
||||
pip install isort
|
||||
isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )
|
||||
|
||||
- name: Check style
|
||||
run: |
|
||||
pip install autopep8
|
||||
autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )
|
||||
|
||||
- name: Flake8
|
||||
run: |
|
||||
pip install flake8
|
||||
flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )
|
||||
|
||||
- name: Install Triton
|
||||
run: |
|
||||
alias python='python3'
|
||||
cd python
|
||||
pip3 install -e '.[tests]'
|
||||
|
||||
- name: Check imports
|
||||
run: "isort -c ./python || ( echo '::error title=Imports not sorted::Please run \"isort ./python\"' ; exit 1 )"
|
||||
|
||||
- name: Check style
|
||||
run: "autopep8 -a -r -d --exit-code ./python || ( echo '::error title=Style issues::Please run \"autopep8 -a -r -i ./python\"' ; exit 1 )"
|
||||
|
||||
- name: Flake8
|
||||
run: "flake8 --config ./python/setup.cfg ./python || ( echo '::error::Flake8 failed; see logs for errors.' ; exit 1 )"
|
||||
|
||||
- name: Unit tests
|
||||
run: |
|
||||
cd python/test/unit
|
||||
pytest -vs .
|
||||
|
||||
- name: Regression tests
|
||||
run: |
|
||||
cd python/test/regression
|
||||
sudo nvidia-smi -i 0 -pm 1
|
||||
sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
|
||||
sudo nvidia-smi -i 0 --lock-memory-clocks=877,877
|
||||
pytest -vs .
|
||||
sudo nvidia-smi -i 0 -rgc
|
||||
sudo nvidia-smi -i 0 -rmc
|
||||
|
@@ -1,3 +1,6 @@
|
||||
add_mlir_library(TritonAnalysis
|
||||
AxisInfo.cpp
|
||||
|
||||
DEPENDS
|
||||
TritonGPUAttrDefsIncGen
|
||||
)
|
@@ -1,12 +1,13 @@
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
def kernel(X, stride_xm, stride_xn,
|
||||
Z, stride_zm, stride_zn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
||||
off_m = tl.arange(0, BLOCK_M)
|
||||
off_n = tl.arange(0, BLOCK_N)
|
||||
Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
|
||||
@@ -15,4 +16,4 @@ def kernel(X, stride_xm, stride_xn,
|
||||
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32,*fp32,i32,i32", constants={"BLOCK_M": 128, "BLOCK_N": 128}, output="ttgir")
|
||||
print(ret)
|
||||
print(ret)
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr):
|
||||
pass
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,i32", constants={"BLOCK": 256}, output="ttgir")
|
||||
|
@@ -16,15 +16,6 @@ from setuptools.command.build_ext import build_ext
|
||||
|
||||
|
||||
def get_llvm():
|
||||
# tries to find system LLVM
|
||||
versions = ['-14.0', '-14', '-14-64']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
if paths:
|
||||
return '', ''
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04'
|
||||
dir = '/tmp'
|
||||
|
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, Union
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
ty = str_to_ty(name[1:])
|
||||
ty = str_to_ty(name[1:])
|
||||
return triton.language.pointer_type(ty)
|
||||
tys = {
|
||||
"fp8": triton.language.float8,
|
||||
@@ -26,9 +28,10 @@ def str_to_ty(name):
|
||||
"u32": triton.language.uint32,
|
||||
"u64": triton.language.uint64,
|
||||
"B": triton.language.int1,
|
||||
}
|
||||
}
|
||||
return tys[name]
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
@@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants):
|
||||
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
return ret
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
def __init__(self, generator: CodeGenerator):
|
||||
self.generator = generator
|
||||
@@ -79,6 +83,7 @@ class enter_sub_region:
|
||||
self.generator.lscope = self.liveins
|
||||
self.generator.local_defs = self.prev_defs
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
@@ -491,8 +496,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
@@ -538,7 +543,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
@@ -597,7 +601,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))
|
||||
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
@@ -633,7 +637,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
@@ -712,7 +716,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise NotImplementedError("Unsupported node: {}".format(typename))
|
||||
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||
@@ -742,11 +745,11 @@ class OutOfResources(Exception):
|
||||
return (type(self), (self.required, self.limit, self.name))
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
|
||||
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
arg_types = signature.replace(' ','').split(',')
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
@@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
|
||||
ret.context = context
|
||||
return ret
|
||||
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
@@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
@@ -785,22 +790,24 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def make_ptx(mod):
|
||||
# TODO
|
||||
return mod
|
||||
|
||||
def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):
|
||||
|
||||
def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
|
||||
assert output in ["ttir", "ttgir", "ptx"]
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
if output == "ttir":
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
# tritongpu-ir
|
||||
module = make_tritongpu_ir(module, num_warps)
|
||||
module = optimize_tritongpu_ir(module, num_stages)
|
||||
if output == "ttgir":
|
||||
if output == "ttgir":
|
||||
return module.str()
|
||||
# ptx
|
||||
if output == "ptx":
|
||||
if output == "ptx":
|
||||
return make_ptx(module)
|
||||
assert False
|
||||
assert False
|
||||
|
@@ -1,2 +1,2 @@
|
||||
from .jit import JITFunction, jit
|
||||
from .autotuner import Config, autotune, heuristics
|
||||
from .autotuner import Config, autotune, heuristics # noqa: F401
|
||||
from .jit import JITFunction, jit # noqa: F401
|
||||
|
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from ..testing import do_bench
|
||||
|
||||
|
||||
class Autotuner:
|
||||
@@ -57,7 +59,7 @@ class Autotuner:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.kernel(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
return triton.testing.do_bench(kernel_call)
|
||||
return do_bench(kernel_call)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
@@ -199,4 +201,3 @@ def heuristics(values):
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
@@ -8,6 +8,7 @@ import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..tools.disasm import extract
|
||||
@@ -16,6 +17,7 @@ from ..tools.disasm import extract
|
||||
# Binary
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Binary:
|
||||
def __init__(self, backend, name, asm, shared_mem, num_warps):
|
||||
self.backend = backend
|
||||
@@ -63,13 +65,13 @@ class LoadedBinary:
|
||||
# Kernel
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Kernel:
|
||||
|
||||
def __call__(self, *args, grid, num_warps=4, num_stages=3, **kwargs):
|
||||
raise RuntimeError("Not implemented. Public repo implementation will be rewritten to reduce latency.")
|
||||
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Dependencies Finder
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -118,6 +120,7 @@ class DependenciesFinder(ast.NodeVisitor):
|
||||
# JITFunction
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def version_key():
|
||||
import pkgutil
|
||||
@@ -232,7 +235,7 @@ class JITFunction:
|
||||
|
||||
def __call__(self, *wargs, **kwargs):
|
||||
return self.kernel(*wargs, **kwargs, grid=self.grid)
|
||||
|
||||
|
||||
return Launcher(self._init_kernel(), grid)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -242,6 +245,7 @@ class JITFunction:
|
||||
# `jit` decorator
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def jit(*args, **kwargs):
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
@@ -265,4 +269,4 @@ def jit(*args, **kwargs):
|
||||
else:
|
||||
def decorator(fn):
|
||||
return JITFunction(fn, **kwargs)
|
||||
return decorator
|
||||
return decorator
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -17,6 +18,7 @@ def next_power_of_2(n):
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
class TensorWrapper:
|
||||
def __init__(self, base, dtype):
|
||||
self.dtype = dtype
|
||||
|
Reference in New Issue
Block a user