[RUNTIME] Better support for None
(#387)
* regression test fails but it doesn't make sense to me.
This commit is contained in:
@@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
if(PyLong_Check(arg_ptr)){
|
||||
int overflow;
|
||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||
if(specialize && (value == 1)){
|
||||
cache_key += '1';
|
||||
continue;
|
||||
}
|
||||
// long and int have different kernels
|
||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||
cache_key += 'I';
|
||||
@@ -147,10 +151,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
if(!specialize)
|
||||
continue;
|
||||
// values equal to 1 are specialized
|
||||
if(value == 1)
|
||||
cache_key += '1';
|
||||
else
|
||||
cache_key += 'x';
|
||||
cache_key += 'x';
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += pow2_divisor(value);
|
||||
continue;
|
||||
@@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
continue;
|
||||
}
|
||||
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
||||
if(ty_str == "NoneType"){
|
||||
cache_key += "None";
|
||||
continue;
|
||||
}
|
||||
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||
throw std::runtime_error(err_msg);
|
||||
|
@@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
else:
|
||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
||||
arg_values = []
|
||||
idx = 0
|
||||
for i, arg_name in enumerate(arg_names):
|
||||
if i in self.constants:
|
||||
cst = self.constants[i]
|
||||
@@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
arg_values.append(cst)
|
||||
else:
|
||||
if i in self.attributes:
|
||||
is_ptr = fn.args[i].type.is_ptr()
|
||||
is_ptr = fn.args[idx].type.is_ptr()
|
||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||
fn.add_attr(i + 1, attr)
|
||||
fn.args[i].name = arg_name
|
||||
arg_values.append(fn.args[i])
|
||||
fn.add_attr(idx + 1, attr)
|
||||
fn.args[idx].name = arg_name
|
||||
arg_values.append(fn.args[idx])
|
||||
idx += 1
|
||||
|
||||
|
||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||
self.set_value(arg_name, arg_value)
|
||||
@@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
lhs = lhs.value
|
||||
if isinstance(rhs, triton.language.core.constexpr):
|
||||
rhs = rhs.value
|
||||
if type(node.ops[0]) == ast.Is:
|
||||
return triton.language.constexpr(lhs is rhs)
|
||||
if type(node.ops[0]) == ast.IsNot:
|
||||
return triton.language.constexpr(lhs is not rhs)
|
||||
fn = {
|
||||
ast.Eq: '__eq__',
|
||||
ast.NotEq: '__ne__',
|
||||
@@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ast.LtE: '__le__',
|
||||
ast.Gt: '__gt__',
|
||||
ast.GtE: '__ge__',
|
||||
ast.Is: '__eq__',
|
||||
ast.IsNot: '__ne__',
|
||||
}[type(node.ops[0])]
|
||||
if self.is_triton_object(lhs):
|
||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||
@@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
def visit_UnaryOp(self, node):
|
||||
op = self.visit(node.operand)
|
||||
if type(node.op) == ast.Not:
|
||||
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||
return triton.language.constexpr(not op)
|
||||
if isinstance(op, triton.language.core.constexpr):
|
||||
op = op.value
|
||||
# print(op)
|
||||
fn = {
|
||||
ast.USub: '__neg__',
|
||||
ast.UAdd: '__pos__',
|
||||
@@ -592,11 +601,11 @@ class Kernel:
|
||||
self.fn = fn
|
||||
|
||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
|
||||
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
|
||||
# create IR module
|
||||
context = _triton.ir.context()
|
||||
# get just-in-time proto-type of kernel
|
||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
|
||||
fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
|
||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
|
||||
ret_type = _triton.ir.type.get_void(context)
|
||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||
# generate Triton-IR
|
||||
@@ -629,8 +638,9 @@ class Kernel:
|
||||
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||
|
||||
# transforms ints whose value is one into constants for just-in-time compilation
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
|
||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
||||
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
|
||||
# create cache directory
|
||||
|
@@ -35,6 +35,9 @@ def _patch(fn):
|
||||
builder = args[-1]
|
||||
assert isinstance(builder, ir.builder)
|
||||
args = [_to_ir(x, builder) for x in args]
|
||||
# for i, arg in enumerate(args):
|
||||
# if arg is None:
|
||||
# raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}")
|
||||
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
|
||||
ret = fn(*args, **kwargs)
|
||||
if isinstance(ret, tuple):
|
||||
@@ -77,7 +80,7 @@ class pointer_dtype:
|
||||
def handle(self, builder):
|
||||
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
||||
|
||||
|
||||
# scalar types
|
||||
int1 = dtype(ir.type.get_int1)
|
||||
int8 = dtype(ir.type.get_int8)
|
||||
int16 = dtype(ir.type.get_int16)
|
||||
@@ -88,7 +91,7 @@ float16 = dtype(ir.type.get_fp16)
|
||||
bfloat16 = dtype(ir.type.get_bf16)
|
||||
float32 = dtype(ir.type.get_fp32)
|
||||
float64 = dtype(ir.type.get_fp64)
|
||||
|
||||
# pointer types
|
||||
pi32_t = pointer_dtype(int32)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user