[FRONTEND] Added default arguments to non-kernel @triton.jit'd function (#379)
@@ -188,8 +188,8 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       continue;
     }
     // argument is `constexpr`
-    py::object value = arg.attr("value");
-    if(value){
+    if(py::hasattr(arg, "value")){
+      py::object value = arg.attr("value");
       py::object name = arg_names[i];
       constants[name] = value;
       py::object repr = py::repr(value);
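The first hunk swaps a direct attribute read for a `py::hasattr` probe: `arg.attr("value")` throws when the attribute is absent, so probing first lets arguments without a `.value` attribute fall through to the regular type handling. A minimal standalone sketch of the pattern, using an assumed embedded-interpreter setup rather than the actual parse_args call site:

#include <pybind11/embed.h>
namespace py = pybind11;

int main() {
  py::scoped_interpreter guard{};
  // Hypothetical stand-in for a triton.language.constexpr-style argument.
  py::object arg = py::eval("type('Constexpr', (), {'value': 42})()");
  if (py::hasattr(arg, "value")) {          // probe first: no exception if the attribute is missing
    py::object value = arg.attr("value");   // safe to read once we know it exists
    py::print("constexpr value:", value);
  }
  return 0;
}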
@@ -198,7 +198,10 @@ void parse_args(py::list& args, py::list do_not_specialize, const std::string& f
       cache_key += std::string(start, len);
       continue;
     }
-    assert(false);
+    std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
+    std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
+                          + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
+    throw std::runtime_error(err_msg);
   }
   cache_key += std::to_string(num_warps);
   cache_key += std::to_string(num_stages);
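The second hunk replaces the bare `assert(false)` with a descriptive `std::runtime_error`. Since pybind11 translates `std::runtime_error` thrown from bound C++ into a Python `RuntimeError`, the caller now sees the message instead of an aborted process. A hedged sketch of the same pattern in an assumed standalone module (`check_arg` and the module name are illustrative, not part of the patch):

#include <pybind11/pybind11.h>
#include <stdexcept>
#include <string>
namespace py = pybind11;

void check_arg(py::object arg, int i) {
  if (py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) || py::isinstance<py::bool_>(arg))
    return;  // supported scalar types
  std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
  throw std::runtime_error("Received type '" + ty_str + "' for argument " + std::to_string(i) + ".");
}

PYBIND11_MODULE(arg_check_demo, m) {
  m.def("check_arg", &check_arg, "Reject arguments of unsupported types");
}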
@@ -269,9 +272,10 @@ void init_triton_runtime(py::module &&m) {
       CU_LAUNCH_PARAM_END
     };
     uint64_t _stream = PyLong_AsLong(stream.ptr());
-    drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
-                                  _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
-                                  nullptr, config);
+    if(grid_0*grid_1*grid_2 > 0)
+      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
+                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
+                                    nullptr, config);
     return bin;
   });
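The third hunk makes the launch conditional on a non-empty grid, so an empty launch becomes a no-op instead of being handed to the driver, which rejects zero-sized grid dimensions. A minimal sketch of the guard against the raw driver API; `launch_if_nonempty` is an illustrative helper, not part of the patch, and the Triton dispatch wrapper is omitted:

#include <cuda.h>

CUresult launch_if_nonempty(CUfunction kernel,
                            unsigned grid_0, unsigned grid_1, unsigned grid_2,
                            unsigned num_warps, unsigned shared_mem,
                            CUstream stream, void **config) {
  if (grid_0 * grid_1 * grid_2 == 0)
    return CUDA_SUCCESS;  // empty grid: nothing to launch
  return cuLaunchKernel(kernel, grid_0, grid_1, grid_2,
                        num_warps * 32, 1, 1,   // 1-D thread block, 32 threads per warp
                        shared_mem, stream,
                        /*kernelParams=*/nullptr, /*extra=*/config);
}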
@@ -293,12 +297,16 @@ void init_triton_runtime(py::module &&m) {
               const std::string &args, int64_t shared_mem){
     void* args_ptr = (void*)args.data();
     size_t args_size = args.size();
+    // release the gil in case the enqueue blocks
+    // cuda will block if too many ops are enqueued
+    Py_BEGIN_ALLOW_THREADS
     if(backend == HOST)
       host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == CUDA)
       cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
     if(backend == ROCM)
       hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
+    Py_END_ALLOW_THREADS
   });
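The last hunk wraps the backend enqueue calls in `Py_BEGIN_ALLOW_THREADS` / `Py_END_ALLOW_THREADS`, releasing the GIL while the potentially blocking enqueue runs and re-acquiring it afterwards, so other Python threads are not stalled. A standalone sketch of the pattern, with `blocking_native_call` and `enqueue_demo` standing in for the real enqueue path (assumed, not from the patch); no CPython API may be used inside the region:

#include <Python.h>
#include <chrono>
#include <thread>

// Stand-in for a blocking enqueue such as cu_enqueue / hip_enqueue.
static void blocking_native_call() {
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
}

static PyObject* enqueue_demo(PyObject* self, PyObject* args) {
  Py_BEGIN_ALLOW_THREADS          // release the GIL; other Python threads may run
  blocking_native_call();         // pure native work only: no Python objects in here
  Py_END_ALLOW_THREADS            // re-acquire the GIL before returning to Python
  Py_RETURN_NONE;
}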