[RUNTIME] Better support for None (#387)

* Regression test fails, but the failure doesn't make sense to me.
This commit is contained in:
Philippe Tillet
2021-12-09 13:21:22 -08:00
committed by GitHub
parent f23bf55f15
commit e31b9b4e66
3 changed files with 33 additions and 15 deletions

View File

@@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor):
else:
fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names):
if i in self.constants:
cst = self.constants[i]
@@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(cst)
else:
if i in self.attributes:
is_ptr = fn.args[i].type.is_ptr()
is_ptr = fn.args[idx].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(i + 1, attr)
fn.args[i].name = arg_name
arg_values.append(fn.args[i])
fn.add_attr(idx + 1, attr)
fn.args[idx].name = arg_name
arg_values.append(fn.args[idx])
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value)
@@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor):
lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value
if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs)
if type(node.ops[0]) == ast.IsNot:
return triton.language.constexpr(lhs is not rhs)
fn = {
ast.Eq: '__eq__',
ast.NotEq: '__ne__',
@@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor):
ast.LtE: '__le__',
ast.Gt: '__gt__',
ast.GtE: '__ge__',
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])]
if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder)
@@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor):
def visit_UnaryOp(self, node):
op = self.visit(node.operand)
if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr):
op = op.value
# print(op)
fn = {
ast.USub: '__neg__',
ast.UAdd: '__pos__',
@@ -592,11 +601,11 @@ class Kernel:
self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# create IR module
context = _triton.ir.context()
# get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
ret_type = _triton.ir.type.get_void(context)
prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR
@@ -629,8 +638,9 @@ class Kernel:
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory