[RUNTIME] Better support for None
(#387)
* regression test fails but it doesn't make sense to me.
This commit is contained in:
@@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
if(PyLong_Check(arg_ptr)){
|
||||
int overflow;
|
||||
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
|
||||
if(specialize && (value == 1)){
|
||||
cache_key += '1';
|
||||
continue;
|
||||
}
|
||||
// long and int have different kernels
|
||||
if(!overflow & (std::abs(value) <= 0xffffffff)){
|
||||
cache_key += 'I';
|
||||
@@ -147,10 +151,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
if(!specialize)
|
||||
continue;
|
||||
// values equal to 1 are specialized
|
||||
if(value == 1)
|
||||
cache_key += '1';
|
||||
else
|
||||
cache_key += 'x';
|
||||
cache_key += 'x';
|
||||
// values divisible by small powers of 2 are specialized
|
||||
cache_key += pow2_divisor(value);
|
||||
continue;
|
||||
@@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
|
||||
continue;
|
||||
}
|
||||
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
|
||||
if(ty_str == "NoneType"){
|
||||
cache_key += "None";
|
||||
continue;
|
||||
}
|
||||
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
|
||||
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
|
||||
throw std::runtime_error(err_msg);
|
||||
|
Reference in New Issue
Block a user