[RUNTIME] Better support for None (#387)
* regression test fails but it doesn't make sense to me.
This commit is contained in:
@@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
if(PyLong_Check(arg_ptr)){
|
if(PyLong_Check(arg_ptr)){
|
||||||
int overflow;
|
int overflow;
|
||||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||||
|
if(specialize && (value == 1)){
|
||||||
|
cache_key += '1';
|
||||||
|
continue;
|
||||||
|
}
|
||||||
// long and int have different kernels
|
// long and int have different kernels
|
||||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||||
cache_key += 'I';
|
cache_key += 'I';
|
||||||
@@ -147,9 +151,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
if(!specialize)
|
if(!specialize)
|
||||||
continue;
|
continue;
|
||||||
// values equal to 1 are specialized
|
// values equal to 1 are specialized
|
||||||
if(value == 1)
|
|
||||||
cache_key += '1';
|
|
||||||
else
|
|
||||||
cache_key += 'x';
|
cache_key += 'x';
|
||||||
// values divisible by small powers of 2 are specialized
|
// values divisible by small powers of 2 are specialized
|
||||||
cache_key += pow2_divisor(value);
|
cache_key += pow2_divisor(value);
|
||||||
@@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
||||||
|
if(ty_str == "NoneType"){
|
||||||
|
cache_key += "None";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
||||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||||
throw std::runtime_error(err_msg);
|
throw std::runtime_error(err_msg);
|
||||||
|
@@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
else:
|
else:
|
||||||
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
fn = self.module.get_or_insert_function(node.name, self.prototype)
|
||||||
arg_values = []
|
arg_values = []
|
||||||
|
idx = 0
|
||||||
for i, arg_name in enumerate(arg_names):
|
for i, arg_name in enumerate(arg_names):
|
||||||
if i in self.constants:
|
if i in self.constants:
|
||||||
cst = self.constants[i]
|
cst = self.constants[i]
|
||||||
@@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
arg_values.append(cst)
|
arg_values.append(cst)
|
||||||
else:
|
else:
|
||||||
if i in self.attributes:
|
if i in self.attributes:
|
||||||
is_ptr = fn.args[i].type.is_ptr()
|
is_ptr = fn.args[idx].type.is_ptr()
|
||||||
attr = 'aligned' if is_ptr else 'multiple_of'
|
attr = 'aligned' if is_ptr else 'multiple_of'
|
||||||
attr = getattr(_triton.ir.attribute_kind, attr)
|
attr = getattr(_triton.ir.attribute_kind, attr)
|
||||||
attr = _triton.ir.attribute(attr, self.attributes[i])
|
attr = _triton.ir.attribute(attr, self.attributes[i])
|
||||||
fn.add_attr(i + 1, attr)
|
fn.add_attr(idx + 1, attr)
|
||||||
fn.args[i].name = arg_name
|
fn.args[idx].name = arg_name
|
||||||
arg_values.append(fn.args[i])
|
arg_values.append(fn.args[idx])
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
|
||||||
for arg_name, arg_value in zip(arg_names, arg_values):
|
for arg_name, arg_value in zip(arg_names, arg_values):
|
||||||
self.set_value(arg_name, arg_value)
|
self.set_value(arg_name, arg_value)
|
||||||
@@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
lhs = lhs.value
|
lhs = lhs.value
|
||||||
if isinstance(rhs, triton.language.core.constexpr):
|
if isinstance(rhs, triton.language.core.constexpr):
|
||||||
rhs = rhs.value
|
rhs = rhs.value
|
||||||
|
if type(node.ops[0]) == ast.Is:
|
||||||
|
return triton.language.constexpr(lhs is rhs)
|
||||||
|
if type(node.ops[0]) == ast.IsNot:
|
||||||
|
return triton.language.constexpr(lhs is not rhs)
|
||||||
fn = {
|
fn = {
|
||||||
ast.Eq: '__eq__',
|
ast.Eq: '__eq__',
|
||||||
ast.NotEq: '__ne__',
|
ast.NotEq: '__ne__',
|
||||||
@@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
ast.LtE: '__le__',
|
ast.LtE: '__le__',
|
||||||
ast.Gt: '__gt__',
|
ast.Gt: '__gt__',
|
||||||
ast.GtE: '__ge__',
|
ast.GtE: '__ge__',
|
||||||
ast.Is: '__eq__',
|
|
||||||
ast.IsNot: '__ne__',
|
|
||||||
}[type(node.ops[0])]
|
}[type(node.ops[0])]
|
||||||
if self.is_triton_object(lhs):
|
if self.is_triton_object(lhs):
|
||||||
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
return getattr(lhs, fn)(rhs, _builder=self.builder)
|
||||||
@@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor):
|
|||||||
|
|
||||||
def visit_UnaryOp(self, node):
|
def visit_UnaryOp(self, node):
|
||||||
op = self.visit(node.operand)
|
op = self.visit(node.operand)
|
||||||
|
if type(node.op) == ast.Not:
|
||||||
|
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
|
||||||
|
return triton.language.constexpr(not op)
|
||||||
if isinstance(op, triton.language.core.constexpr):
|
if isinstance(op, triton.language.core.constexpr):
|
||||||
op = op.value
|
op = op.value
|
||||||
|
# print(op)
|
||||||
fn = {
|
fn = {
|
||||||
ast.USub: '__neg__',
|
ast.USub: '__neg__',
|
||||||
ast.UAdd: '__pos__',
|
ast.UAdd: '__pos__',
|
||||||
@@ -592,11 +601,11 @@ class Kernel:
|
|||||||
self.fn = fn
|
self.fn = fn
|
||||||
|
|
||||||
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
|
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
|
||||||
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
|
|
||||||
# create IR module
|
# create IR module
|
||||||
context = _triton.ir.context()
|
context = _triton.ir.context()
|
||||||
# get just-in-time proto-type of kernel
|
# get just-in-time proto-type of kernel
|
||||||
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
|
fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
|
||||||
|
arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
|
||||||
ret_type = _triton.ir.type.get_void(context)
|
ret_type = _triton.ir.type.get_void(context)
|
||||||
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
prototype = _triton.ir.type.make_function(ret_type, arg_types)
|
||||||
# generate Triton-IR
|
# generate Triton-IR
|
||||||
@@ -629,8 +638,9 @@ class Kernel:
|
|||||||
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
if isinstance(a, int) and i not in self.fn.do_not_specialize}
|
||||||
|
|
||||||
# transforms ints whose value is one into constants for just-in-time compilation
|
# transforms ints whose value is one into constants for just-in-time compilation
|
||||||
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
|
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
|
||||||
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
|
||||||
|
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
|
||||||
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
# create cache directory
|
# create cache directory
|
||||||
|
@@ -35,6 +35,9 @@ def _patch(fn):
|
|||||||
builder = args[-1]
|
builder = args[-1]
|
||||||
assert isinstance(builder, ir.builder)
|
assert isinstance(builder, ir.builder)
|
||||||
args = [_to_ir(x, builder) for x in args]
|
args = [_to_ir(x, builder) for x in args]
|
||||||
|
# for i, arg in enumerate(args):
|
||||||
|
# if arg is None:
|
||||||
|
# raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}")
|
||||||
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
|
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
|
||||||
ret = fn(*args, **kwargs)
|
ret = fn(*args, **kwargs)
|
||||||
if isinstance(ret, tuple):
|
if isinstance(ret, tuple):
|
||||||
@@ -77,7 +80,7 @@ class pointer_dtype:
|
|||||||
def handle(self, builder):
|
def handle(self, builder):
|
||||||
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
return ir.type.make_ptr(self.element_ty.handle(builder), 1)
|
||||||
|
|
||||||
|
# scalar types
|
||||||
int1 = dtype(ir.type.get_int1)
|
int1 = dtype(ir.type.get_int1)
|
||||||
int8 = dtype(ir.type.get_int8)
|
int8 = dtype(ir.type.get_int8)
|
||||||
int16 = dtype(ir.type.get_int16)
|
int16 = dtype(ir.type.get_int16)
|
||||||
@@ -88,7 +91,7 @@ float16 = dtype(ir.type.get_fp16)
|
|||||||
bfloat16 = dtype(ir.type.get_bf16)
|
bfloat16 = dtype(ir.type.get_bf16)
|
||||||
float32 = dtype(ir.type.get_fp32)
|
float32 = dtype(ir.type.get_fp32)
|
||||||
float64 = dtype(ir.type.get_fp64)
|
float64 = dtype(ir.type.get_fp64)
|
||||||
|
# pointer types
|
||||||
pi32_t = pointer_dtype(int32)
|
pi32_t = pointer_dtype(int32)
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user