[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)

commit c86ad9c9ab (parent 1296eb877b)
Author: Philippe Tillet
Date:   2021-11-29 19:11:26 -08:00
Committed via GitHub

4 changed files with 149 additions and 122 deletions


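The change in behavior, in plain terms: a non-kernel @triton.jit'd function may now declare default argument values, and callers inside a kernel may omit those arguments. Below is a minimal sketch of the kind of usage this enables; the helper, kernel, and tensor names are illustrative, not taken from the commit.

import torch
import triton
import triton.language as tl

# Non-kernel @triton.jit'd helper with a default argument: callers may omit `scale`.
@triton.jit
def scaled_add(x, y, scale=2.0):
    return x + y * scale

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    # `scale` is not passed, so the default value of 2.0 is used
    tl.store(out_ptr + offs, scaled_add(x, y), mask=mask)

x = torch.randn(1024, device='cuda')
y = torch.randn(1024, device='cuda')
out = torch.empty_like(x)
add_kernel[(triton.cdiv(1024, 256),)](x, y, out, 1024, BLOCK=256)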
@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       continue;
     }
     // argument is `constexpr`
-    py::object value = arg.attr("value");
-    if(value){
+    if(py::hasattr(arg, "value")){
+      py::object value = arg.attr("value");
       py::object name = arg_names[i];
       constants[name] = value;
       py::object repr = py::repr(value);
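For context on the hasattr check above: triton.language.constexpr wraps its payload in a .value attribute, which is how the launcher tells constexpr arguments apart from plain scalars. A tiny illustration from the Python side (the values are made up):

import triton.language as tl

c = tl.constexpr(128)
print(hasattr(c, "value"), c.value)  # True 128: takes the constexpr branch
print(hasattr(3.14, "value"))        # False: a plain float falls through to the other cases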
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       cache_key += std::string(start, len);
       continue;
     }
-    assert(false);
+    std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
+    std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+                          + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
+    throw std::runtime_error(err_msg);
   }
   cache_key += std::to_string(num_warps);
   cache_key += std::to_string(num_stages);
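The new error path replaces a bare assert(false) with a message naming the offending Python type. A hedged sketch of how one might have tripped it from Python at the time of this change; the kernel and the exact wording of the raised error are illustrative:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))

dst = torch.empty(128, device='cuda')
try:
    # A str is none of int, float, bool, torch.Tensor, or tl.constexpr,
    # so argument parsing should now raise instead of asserting.
    copy_kernel[(1,)]("oops", dst, BLOCK=128)
except Exception as e:
    print(e)  # e.g. "Received type 'str' for argument 0. Only int, float, bool, ..."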
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
       CU_LAUNCH_PARAM_END
     };
     uint64_t _stream = PyLong_AsLong(stream.ptr());
-    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
-                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
-                                  nullptr, config);
+    if(grid_0*grid_1*grid_2 > 0)
+      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
+                                    nullptr, config);
     return bin;
   });
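The added guard makes a launch with any zero grid dimension a no-op instead of forwarding a zero-sized launch to the CUDA driver, which is exactly what happens when the grid is derived from an empty input. A rough sketch of that situation, with illustrative names:

import torch
import triton
import triton.language as tl

@triton.jit
def fill_kernel(out_ptr, value, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, value)

out = torch.empty(0, device='cuda')       # empty input
grid = (triton.cdiv(out.numel(), 256),)   # == (0,)
# With the check above, this launch is expected to be a no-op rather than
# a zero-sized cuLaunchKernel call handed to the driver.
fill_kernel[grid](out, 1.0, BLOCK=256)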
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
                       const std::string &args, int64_t shared_mem){
     void* args_ptr = (void*)args.data();
     size_t args_size = args.size();
+    // release the gil in case the enqueue blocks
+    // cuda will block if too many ops are enqueued
+    Py_BEGIN_ALLOW_THREADS
     if(backend == HOST)
       host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == CUDA)
       cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == ROCM)
       hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
+    Py_END_ALLOW_THREADS
   });
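Releasing the GIL around the enqueue means a Python thread that blocks inside the CUDA driver (because too many launches are queued) no longer stalls other Python threads. A rough, hedged sketch of the scenario the comment describes; the kernel, iteration counts, and thread structure are all illustrative:

import threading
import torch
import triton
import triton.language as tl

@triton.jit
def spin_kernel(x_ptr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    for _ in range(10000):  # make each launch take a while on the GPU
        x = x * 1.0001
    tl.store(x_ptr + offs, x)

x = torch.randn(1024, device='cuda')

def flood():
    # Enqueue far more launches than the driver will buffer; once its queue
    # fills up, the enqueue call blocks inside cuLaunchKernel.
    for _ in range(50000):
        spin_kernel[(4,)](x, BLOCK=256)

t = threading.Thread(target=flood)
t.start()
# Because the launcher drops the GIL around the enqueue, this thread keeps
# running even while the flooding thread is blocked in the driver.
for i in range(5):
    print("main thread still responsive:", i)
t.join()
torch.cuda.synchronize()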