[RUNTIME] Better support for None (#387)

* regression test fails but it doesn't make sense to me.
This commit is contained in:
Philippe Tillet
2021-12-09 13:21:22 -08:00
committed by GitHub
parent f23bf55f15
commit e31b9b4e66
3 changed files with 33 additions and 15 deletions

View File

@@ -127,6 +127,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
if(PyLong_Check(arg_ptr)){
int overflow;
long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
if(specialize && (value == 1)){
cache_key += '1';
continue;
}
// long and int have different kernels
if(!overflow & (std::abs(value) <= 0xffffffff)){
cache_key += 'I';
@@ -147,10 +151,7 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
if(!specialize)
continue;
// values equal to 1 are specialized
if(value == 1)
cache_key += '1';
else
cache_key += 'x';
cache_key += 'x';
// values divisible by small powers of 2 are specialized
cache_key += pow2_divisor(value);
continue;
@@ -199,6 +200,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
continue;
}
std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
if(ty_str == "NoneType"){
cache_key += "None";
continue;
}
std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+ " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
throw std::runtime_error(err_msg);