[CI] Added basic CI skeletons (#23)

Includes minor fixes to make things compile and pass static checks properly
This commit is contained in:
Philippe Tillet
2022-07-26 14:16:30 -07:00
committed by GitHub
parent 3265e0df5a
commit 25357083e6
10 changed files with 64 additions and 63 deletions

View File

@@ -1,15 +1,17 @@
from __future__ import annotations
import ast
import sys
import warnings
from typing import Dict, Union
import triton
import triton._C.libtriton.triton as _triton
def str_to_ty(name):
if name[0] == "*":
ty = str_to_ty(name[1:])
ty = str_to_ty(name[1:])
return triton.language.pointer_type(ty)
tys = {
"fp8": triton.language.float8,
@@ -26,9 +28,10 @@ def str_to_ty(name):
"u32": triton.language.uint32,
"u64": triton.language.uint64,
"B": triton.language.int1,
}
}
return tys[name]
def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
@@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants):
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
class enter_sub_region:
def __init__(self, generator: CodeGenerator):
self.generator = generator
@@ -79,6 +83,7 @@ class enter_sub_region:
self.generator.lscope = self.liveins
self.generator.local_defs = self.prev_defs
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
self.builder = _triton.ir.builder(context)
@@ -491,8 +496,8 @@ class CodeGenerator(ast.NodeVisitor):
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
[arg.handle for arg in init_args])
# merge the condition region
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
before_block = self.builder.create_block_with_parent(while_op.get_before(),
[ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
@@ -538,7 +543,6 @@ class CodeGenerator(ast.NodeVisitor):
iter_args = [self.visit(arg) for arg in node.iter.args]
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
if is_static:
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
iter_args = [arg.value for arg in iter_args]
range = iterator(*iter_args)
if len(range) <= 10:
@@ -597,7 +601,7 @@ class CodeGenerator(ast.NodeVisitor):
# replace global uses with block arguments
for i, name in enumerate(names):
# arg0 is the induction variable
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))
# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
@@ -633,7 +637,7 @@ class CodeGenerator(ast.NodeVisitor):
args = getcallargs(fn.fn, *args, **kws)
args = [args[name] for name in fn.arg_names]
args = [arg if isinstance(arg, triton.language.tensor)
else triton.language.constexpr(arg) for arg in args]
else triton.language.constexpr(arg) for arg in args]
# generate function def
attributes = dict()
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
@@ -712,7 +716,6 @@ class CodeGenerator(ast.NodeVisitor):
raise NotImplementedError("Unsupported node: {}".format(typename))
class CompilationError(Exception):
def __init__(self, src, node):
self.message = f'at {node.lineno}:{node.col_offset}:\n'
@@ -742,11 +745,11 @@ class OutOfResources(Exception):
return (type(self), (self.required, self.limit, self.name))
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
context = _triton.ir.context()
context.load_triton()
# create kernel prototype
arg_types = signature.replace(' ','').split(',')
arg_types = signature.replace(' ', '').split(',')
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
arg_types = [str_to_ty(x) for x in arg_types]
prototype = triton.language.function_type([], arg_types)
@@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
ret.context = context
return ret
def make_tritongpu_ir(mod, num_warps):
pm = _triton.ir.pass_manager(mod.context)
pm.add_inliner_pass()
@@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps):
pm.run(mod)
return mod
def optimize_tritongpu_ir(mod, num_stages):
pm = _triton.ir.pass_manager(mod.context)
pm.add_tritongpu_pipeline_pass(num_stages)
@@ -785,22 +790,24 @@ def optimize_tritongpu_ir(mod, num_stages):
pm.run(mod)
return mod
def make_ptx(mod):
# TODO
return mod
def compile(fn, signature, constants = dict(), attributes = dict(), num_warps=4, num_stages=3, output = "ttgir"):
def compile(fn, signature, constants=dict(), attributes=dict(), num_warps=4, num_stages=3, output="ttgir"):
assert output in ["ttir", "ttgir", "ptx"]
# triton-ir
module = make_triton_ir(fn, signature, constants, attributes)
if output == "ttir":
if output == "ttir":
return module.str()
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
if output == "ttgir":
return module.str()
# ptx
if output == "ptx":
if output == "ptx":
return make_ptx(module)
assert False
assert False