[CI] Added basic CI skeletons (#23)
Includes minor fixes to make things compile and pass static checks properly
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Dict, Union
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
|
||||
|
||||
def str_to_ty(name):
|
||||
if name[0] == "*":
|
||||
ty = str_to_ty(name[1:])
|
||||
ty = str_to_ty(name[1:])
|
||||
return triton.language.pointer_type(ty)
|
||||
tys = {
|
||||
"fp8": triton.language.float8,
|
||||
@@ -26,9 +28,10 @@ def str_to_ty(name):
|
||||
"u32": triton.language.uint32,
|
||||
"u64": triton.language.uint64,
|
||||
"B": triton.language.int1,
|
||||
}
|
||||
}
|
||||
return tys[name]
|
||||
|
||||
|
||||
def mangle_ty(ty):
|
||||
if ty.is_ptr():
|
||||
return 'P' + mangle_ty(ty.element_ty)
|
||||
@@ -62,6 +65,7 @@ def mangle_fn(name, arg_tys, constants):
|
||||
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
return ret
|
||||
|
||||
|
||||
class enter_sub_region:
|
||||
def __init__(self, generator: CodeGenerator):
|
||||
self.generator = generator
|
||||
@@ -79,6 +83,7 @@ class enter_sub_region:
|
||||
self.generator.lscope = self.liveins
|
||||
self.generator.local_defs = self.prev_defs
|
||||
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, module=None, is_kernel=False, function_types=dict()):
|
||||
self.builder = _triton.ir.builder(context)
|
||||
@@ -491,8 +496,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
before_block = self.builder.create_block_with_parent(while_op.get_before(),
|
||||
[ty.to_ir(self.builder) for ty in ret_types])
|
||||
cond_block.merge_block_before(before_block)
|
||||
self.builder.set_insertion_point_to_end(before_block)
|
||||
# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
|
||||
@@ -538,7 +543,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
is_static = all([isinstance(x, triton.language.constexpr) for x in iter_args])
|
||||
if is_static:
|
||||
st_target = ast.Name(id=node.target.id, ctx=ast.Store())
|
||||
iter_args = [arg.value for arg in iter_args]
|
||||
range = iterator(*iter_args)
|
||||
if len(range) <= 10:
|
||||
@@ -597,7 +601,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# replace global uses with block arguments
|
||||
for i, name in enumerate(names):
|
||||
# arg0 is the induction variable
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i+1))
|
||||
for_op.get_body(0).replace_use_in_block_with(init_args[i].handle, for_op.get_body(0).arg(i + 1))
|
||||
|
||||
# update lscope & local_defs (ForOp defines new values)
|
||||
for i, name in enumerate(names):
|
||||
@@ -633,7 +637,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
args = getcallargs(fn.fn, *args, **kws)
|
||||
args = [args[name] for name in fn.arg_names]
|
||||
args = [arg if isinstance(arg, triton.language.tensor)
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
else triton.language.constexpr(arg) for arg in args]
|
||||
# generate function def
|
||||
attributes = dict()
|
||||
constexprs = [i for i, arg in enumerate(args) if isinstance(arg, triton.language.constexpr)]
|
||||
@@ -712,7 +716,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise NotImplementedError("Unsupported node: {}".format(typename))
|
||||
|
||||
|
||||
|
||||
class CompilationError(Exception):
|
||||
def __init__(self, src, node):
|
||||
self.message = f'at {node.lineno}:{node.col_offset}:\n'
|
||||
@@ -742,11 +745,11 @@ class OutOfResources(Exception):
|
||||
return (type(self), (self.required, self.limit, self.name))
|
||||
|
||||
|
||||
def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
|
||||
def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
arg_types = signature.replace(' ','').split(',')
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
@@ -765,6 +768,7 @@ def make_triton_ir(fn, signature, constants = dict(), attributes = dict()):
|
||||
ret.context = context
|
||||
return ret
|
||||
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
@@ -775,6 +779,7 @@ def make_tritongpu_ir(mod, num_warps):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
@@ -785,22 +790,24 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def make_ptx(mod):
    """Return ``mod`` unchanged; PTX generation is not implemented yet.

    Args:
        mod: The module object produced by the earlier compilation stages.

    Returns:
        The very same ``mod`` object that was passed in.
    """
    # TODO: lower the module to PTX; until then this is an identity stub.
    return mod
|
||||
|
||||
def compile(fn, signature, constants=None, attributes=None, num_warps=4, num_stages=3, output="ttgir"):
    """Run the compilation stages for ``fn`` and return the requested one.

    Args:
        fn: Function object forwarded to ``make_triton_ir``.
        signature: Argument-type signature forwarded to ``make_triton_ir``.
        constants: Mapping of constant arguments forwarded to
            ``make_triton_ir``; ``None`` (the default) means an empty mapping.
        attributes: Attributes forwarded to ``make_triton_ir``; ``None``
            (the default) means an empty mapping.
        num_warps: Forwarded to ``make_tritongpu_ir``.
        num_stages: Forwarded to ``optimize_tritongpu_ir``.
        output: Stage to stop at: ``"ttir"``, ``"ttgir"`` or ``"ptx"``.

    Returns:
        ``module.str()`` for ``"ttir"`` and ``"ttgir"``; the result of
        ``make_ptx(module)`` for ``"ptx"``.

    Raises:
        AssertionError: If ``output`` is not one of the supported stages.
    """
    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O`` while callers still see the same exception type.
    if output not in ("ttir", "ttgir", "ptx"):
        raise AssertionError(f"unsupported output stage: {output!r}")
    # Fresh dicts per call: a ``dict()`` default would be shared by all calls.
    if constants is None:
        constants = {}
    if attributes is None:
        attributes = {}
    # triton-ir
    module = make_triton_ir(fn, signature, constants, attributes)
    if output == "ttir":
        return module.str()
    # tritongpu-ir
    module = make_tritongpu_ir(module, num_warps)
    module = optimize_tritongpu_ir(module, num_stages)
    if output == "ttgir":
        return module.str()
    # ptx
    if output == "ptx":
        return make_ptx(module)
    # Unreachable: ``output`` was validated above.
    raise AssertionError("unreachable output stage")
|
||||
|
Reference in New Issue
Block a user