diff --git a/python/src/triton.cc b/python/src/triton.cc index a606c051a..d56ff8430 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -226,11 +226,16 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f // copy param std::memcpy(params_ptr, &value, 8); params_ptr += 8; - // udpate cache key + // update cache key cache_key += dtype_cache_key_part(arg.attr("dtype")); cache_key += "*"; cache_key += "[multipleof("; - size_t range_size = get_pointer_range_size(value); + size_t range_size; + try { + range_size = get_pointer_range_size(value); + } catch (...) { + throw std::runtime_error("argument tensor #" + std::to_string(i) + " is not on cuda! " + std::string(py::str(arg))); + } cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size))); cache_key += ")]"; continue;