From e31b9b4e660b037d5680ef7b9924fd012566a22b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 9 Dec 2021 13:21:22 -0800 Subject: [PATCH] [RUNTIME] Better support for `None` (#387) * regression test fails but it doesn't make sense to me. --- python/src/triton.cc | 13 +++++++++---- python/triton/code_gen.py | 28 +++++++++++++++++++--------- python/triton/language/core.py | 7 +++++-- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/python/src/triton.cc b/python/src/triton.cc index 7b6c1ce81..aca171dae 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(PyLong_Check(arg_ptr)){ int overflow; long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); + if(specialize && (value == 1)){ + cache_key += '1'; + continue; + } // long and int have different kernels if(!overflow & (std::abs(value) <= 0xffffffff)){ cache_key += 'I'; @@ -147,10 +151,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f if(!specialize) continue; // values equal to 1 are specialized - if(value == 1) - cache_key += '1'; - else - cache_key += 'x'; + cache_key += 'x'; // values divisible by small powers of 2 are specialized cache_key += pow2_divisor(value); continue; @@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f continue; } std::string ty_str = arg.attr("__class__").attr("__name__").cast(); + if(ty_str == "NoneType"){ + cache_key += "None"; + continue; + } std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; throw std::runtime_error(err_msg); diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index b2fded136..e4b7d9f0f 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor): else: fn = self.module.get_or_insert_function(node.name, self.prototype) arg_values = [] + idx = 0 for i, arg_name in enumerate(arg_names): if i in self.constants: cst = self.constants[i] @@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor): arg_values.append(cst) else: if i in self.attributes: - is_ptr = fn.args[i].type.is_ptr() + is_ptr = fn.args[idx].type.is_ptr() attr = 'aligned' if is_ptr else 'multiple_of' attr = getattr(_triton.ir.attribute_kind, attr) attr = _triton.ir.attribute(attr, self.attributes[i]) - fn.add_attr(i + 1, attr) - fn.args[i].name = arg_name - arg_values.append(fn.args[i]) + fn.add_attr(idx + 1, attr) + fn.args[idx].name = arg_name + arg_values.append(fn.args[idx]) + idx += 1 + for arg_name, arg_value in zip(arg_names, arg_values): self.set_value(arg_name, arg_value) @@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor): lhs = lhs.value if isinstance(rhs, triton.language.core.constexpr): rhs = rhs.value + if type(node.ops[0]) == ast.Is: + return triton.language.constexpr(lhs is rhs) + if type(node.ops[0]) == ast.IsNot: + return triton.language.constexpr(lhs is not rhs) fn = { ast.Eq: '__eq__', ast.NotEq: '__ne__', @@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor): ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__', - ast.Is: '__eq__', - ast.IsNot: '__ne__', }[type(node.ops[0])] if self.is_triton_object(lhs): return getattr(lhs, fn)(rhs, _builder=self.builder) @@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor): def visit_UnaryOp(self, node): op = self.visit(node.operand) + if type(node.op) == ast.Not: + assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment" + return triton.language.constexpr(not op) if isinstance(op, triton.language.core.constexpr): op = op.value + # print(op) fn = { ast.USub: '__neg__', ast.UAdd: '__pos__', @@ -592,11 +601,11 @@ class Kernel: self.fn = fn def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): - wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)] # create IR module context = _triton.ir.context() # get just-in-time proto-type of kernel - arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] + fn_args = [arg for i, arg in enumerate(wargs) if i not in constants] + arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args] ret_type = _triton.ir.type.get_void(context) prototype = _triton.ir.type.make_function(ret_type, arg_types) # generate Triton-IR @@ -629,8 +638,9 @@ class Kernel: if isinstance(a, int) and i not in self.fn.do_not_specialize} # transforms ints whose value is one into constants for just-in-time compilation - constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} + constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize} constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) + constants.update({i: None for i, arg in enumerate(wargs) if arg is None}) hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() # create cache directory diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7875a30f6..d7240fcf8 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -35,6 +35,9 @@ def _patch(fn): builder = args[-1] assert isinstance(builder, ir.builder) args = [_to_ir(x, builder) for x in args] + # for i, arg in enumerate(args): + # if arg is None: + # raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}") kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} ret = fn(*args, **kwargs) if isinstance(ret, tuple): @@ -77,7 +80,7 @@ class pointer_dtype: def handle(self, builder): return ir.type.make_ptr(self.element_ty.handle(builder), 1) - +# scalar types int1 = dtype(ir.type.get_int1) int8 = dtype(ir.type.get_int8) int16 = dtype(ir.type.get_int16) @@ -88,7 +91,7 @@ float16 = dtype(ir.type.get_fp16) bfloat16 = dtype(ir.type.get_bf16) float32 = dtype(ir.type.get_fp32) float64 = dtype(ir.type.get_fp64) - +# pointer types pi32_t = pointer_dtype(int32)