[RUNTIME] Better support for None (#387)

* Regression test fails, but the failure doesn't make sense to me.
This commit is contained in:
Philippe Tillet
2021-12-09 13:21:22 -08:00
committed by GitHub
parent f23bf55f15
commit e31b9b4e66
3 changed files with 33 additions and 15 deletions

View File

@@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
if(PyLong_Check(arg_ptr)){ if(PyLong_Check(arg_ptr)){
int overflow; int overflow;
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if(specialize && (value == 1)){
cache_key += '1';
continue;
}
// long and int have different kernels // long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){ if(!overflow & (std::abs(value) <= 0xffffffff)){
cache_key += 'I'; cache_key += 'I';
@@ -147,9 +151,6 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
if(!specialize) if(!specialize)
continue; continue;
// values equal to 1 are specialized // values equal to 1 are specialized
if(value == 1)
cache_key += '1';
else
cache_key += 'x'; cache_key += 'x';
// values divisible by small powers of 2 are specialized // values divisible by small powers of 2 are specialized
cache_key += pow2_divisor(value); cache_key += pow2_divisor(value);
@@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
continue; continue;
} }
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>(); std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
if(ty_str == "NoneType"){
cache_key += "None";
continue;
}
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported."; + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
throw std::runtime_error(err_msg); throw std::runtime_error(err_msg);

View File

@@ -112,6 +112,7 @@ class CodeGenerator(ast.NodeVisitor):
else: else:
fn = self.module.get_or_insert_function(node.name, self.prototype) fn = self.module.get_or_insert_function(node.name, self.prototype)
arg_values = [] arg_values = []
idx = 0
for i, arg_name in enumerate(arg_names): for i, arg_name in enumerate(arg_names):
if i in self.constants: if i in self.constants:
cst = self.constants[i] cst = self.constants[i]
@@ -120,13 +121,15 @@ class CodeGenerator(ast.NodeVisitor):
arg_values.append(cst) arg_values.append(cst)
else: else:
if i in self.attributes: if i in self.attributes:
is_ptr = fn.args[i].type.is_ptr() is_ptr = fn.args[idx].type.is_ptr()
attr = 'aligned' if is_ptr else 'multiple_of' attr = 'aligned' if is_ptr else 'multiple_of'
attr = getattr(_triton.ir.attribute_kind, attr) attr = getattr(_triton.ir.attribute_kind, attr)
attr = _triton.ir.attribute(attr, self.attributes[i]) attr = _triton.ir.attribute(attr, self.attributes[i])
fn.add_attr(i + 1, attr) fn.add_attr(idx + 1, attr)
fn.args[i].name = arg_name fn.args[idx].name = arg_name
arg_values.append(fn.args[i]) arg_values.append(fn.args[idx])
idx += 1
for arg_name, arg_value in zip(arg_names, arg_values): for arg_name, arg_value in zip(arg_names, arg_values):
self.set_value(arg_name, arg_value) self.set_value(arg_name, arg_value)
@@ -293,6 +296,10 @@ class CodeGenerator(ast.NodeVisitor):
lhs = lhs.value lhs = lhs.value
if isinstance(rhs, triton.language.core.constexpr): if isinstance(rhs, triton.language.core.constexpr):
rhs = rhs.value rhs = rhs.value
if type(node.ops[0]) == ast.Is:
return triton.language.constexpr(lhs is rhs)
if type(node.ops[0]) == ast.IsNot:
return triton.language.constexpr(lhs is not rhs)
fn = { fn = {
ast.Eq: '__eq__', ast.Eq: '__eq__',
ast.NotEq: '__ne__', ast.NotEq: '__ne__',
@@ -300,8 +307,6 @@ class CodeGenerator(ast.NodeVisitor):
ast.LtE: '__le__', ast.LtE: '__le__',
ast.Gt: '__gt__', ast.Gt: '__gt__',
ast.GtE: '__ge__', ast.GtE: '__ge__',
ast.Is: '__eq__',
ast.IsNot: '__ne__',
}[type(node.ops[0])] }[type(node.ops[0])]
if self.is_triton_object(lhs): if self.is_triton_object(lhs):
return getattr(lhs, fn)(rhs, _builder=self.builder) return getattr(lhs, fn)(rhs, _builder=self.builder)
@@ -313,8 +318,12 @@ class CodeGenerator(ast.NodeVisitor):
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
op = self.visit(node.operand) op = self.visit(node.operand)
if type(node.op) == ast.Not:
assert isinstance(op, triton.language.constexpr), "`not` only supported for constexpr at the moment"
return triton.language.constexpr(not op)
if isinstance(op, triton.language.core.constexpr): if isinstance(op, triton.language.core.constexpr):
op = op.value op = op.value
# print(op)
fn = { fn = {
ast.USub: '__neg__', ast.USub: '__neg__',
ast.UAdd: '__pos__', ast.UAdd: '__pos__',
@@ -592,11 +601,11 @@ class Kernel:
self.fn = fn self.fn = fn
def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages): def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# create IR module # create IR module
context = _triton.ir.context() context = _triton.ir.context()
# get just-in-time proto-type of kernel # get just-in-time proto-type of kernel
arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs] fn_args = [arg for i, arg in enumerate(wargs) if i not in constants]
arg_types = [Kernel._to_triton_ir(context, arg) for arg in fn_args]
ret_type = _triton.ir.type.get_void(context) ret_type = _triton.ir.type.get_void(context)
prototype = _triton.ir.type.make_function(ret_type, arg_types) prototype = _triton.ir.type.make_function(ret_type, arg_types)
# generate Triton-IR # generate Triton-IR
@@ -629,8 +638,9 @@ class Kernel:
if isinstance(a, int) and i not in self.fn.do_not_specialize} if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation # transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1} constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1 and i not in self.fn.do_not_specialize}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)}) constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
constants.update({i: None for i, arg in enumerate(wargs) if arg is None})
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest() hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory # create cache directory

View File

@@ -35,6 +35,9 @@ def _patch(fn):
builder = args[-1] builder = args[-1]
assert isinstance(builder, ir.builder) assert isinstance(builder, ir.builder)
args = [_to_ir(x, builder) for x in args] args = [_to_ir(x, builder) for x in args]
# for i, arg in enumerate(args):
# if arg is None:
# raise ValueError(f"Unexpected `None` at position {i} for function {fn.__name__}")
kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()} kwargs = {k: _to_ir(v, builder) for k, v in kwargs.items()}
ret = fn(*args, **kwargs) ret = fn(*args, **kwargs)
if isinstance(ret, tuple): if isinstance(ret, tuple):
@@ -77,7 +80,7 @@ class pointer_dtype:
def handle(self, builder): def handle(self, builder):
return ir.type.make_ptr(self.element_ty.handle(builder), 1) return ir.type.make_ptr(self.element_ty.handle(builder), 1)
# scalar types
int1 = dtype(ir.type.get_int1) int1 = dtype(ir.type.get_int1)
int8 = dtype(ir.type.get_int8) int8 = dtype(ir.type.get_int8)
int16 = dtype(ir.type.get_int16) int16 = dtype(ir.type.get_int16)
@@ -88,7 +91,7 @@ float16 = dtype(ir.type.get_fp16)
bfloat16 = dtype(ir.type.get_bf16) bfloat16 = dtype(ir.type.get_bf16)
float32 = dtype(ir.type.get_fp32) float32 = dtype(ir.type.get_fp32)
float64 = dtype(ir.type.get_fp64) float64 = dtype(ir.type.get_fp64)
# pointer types
pi32_t = pointer_dtype(int32) pi32_t = pointer_dtype(int32)