[RUNTIME] now decoupling entry point from cubin (#696)
This commit is contained in:
@@ -430,150 +430,90 @@ void init_triton_runtime(py::module &&m) {
|
|||||||
/*****************************************************************************/
|
/*****************************************************************************/
|
||||||
typedef std::map<std::string, py::object> asm_map_t;
|
typedef std::map<std::string, py::object> asm_map_t;
|
||||||
|
|
||||||
// ---------------------------------------
|
|
||||||
// Load provided assembly code into driver
|
|
||||||
// ---------------------------------------
|
|
||||||
|
|
||||||
// CUDA
|
|
||||||
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> cu_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
|
||||||
// load assembly
|
|
||||||
std::string assembly;
|
|
||||||
if(asm_map.find("cubin") != asm_map.end())
|
|
||||||
assembly = py::cast<std::string>(asm_map["cubin"]);
|
|
||||||
else
|
|
||||||
assembly = py::cast<std::string>(asm_map["ptx"]);
|
|
||||||
// create driver handles
|
|
||||||
CUfunction fun;
|
|
||||||
CUmodule mod;
|
|
||||||
drv::dispatch::cuModuleLoadData(&mod, assembly.c_str());
|
|
||||||
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
|
||||||
// get allocated registers and spilled registers from the function
|
|
||||||
int n_regs = 0;
|
|
||||||
int n_spills = 0;
|
|
||||||
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
|
||||||
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
|
||||||
n_spills /= 4;
|
|
||||||
// set dynamic shared memory if necessary
|
|
||||||
int shared_optin;
|
|
||||||
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, dev);
|
|
||||||
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
|
||||||
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
|
||||||
int shared_total, shared_static;
|
|
||||||
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, dev);
|
|
||||||
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
|
||||||
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
|
||||||
}
|
|
||||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ROCM
|
|
||||||
std::tuple<uint64_t, uint64_t, uint64_t, uint64_t> hip_load_binary(const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
|
||||||
py::bytes _assembly = asm_map["hsaco"];
|
|
||||||
std::string assembly = py::cast<std::string>(_assembly);
|
|
||||||
// HSA-CO -> hipModule
|
|
||||||
hipModule_t mod = drv::amdgpu_to_hipmodule(assembly);
|
|
||||||
// Handle to the kernel
|
|
||||||
hipFunction_t fun;
|
|
||||||
drv::dispatch::hipModuleGetFunction(&fun, mod, name.c_str());
|
|
||||||
// record asm
|
|
||||||
return std::make_tuple((uint64_t)mod, (uint64_t)fun, 0, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
// Compile Triton-IR to assembly
|
// Compile Triton-IR to assembly
|
||||||
// ---------------------------------------
|
// ---------------------------------------
|
||||||
|
|
||||||
// CUDA
|
|
||||||
std::tuple<std::string, asm_map_t, int> cu_compile_ttir(
|
|
||||||
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
|
|
||||||
int num_stages, asm_map_t &asm_map,
|
|
||||||
const triton::codegen::ExternLibMap &extern_lib_map) {
|
|
||||||
py::gil_scoped_release allow_threads;
|
|
||||||
llvm::LLVMContext ctx;
|
|
||||||
// device properties
|
|
||||||
CUdevice dev = (CUdevice)device;
|
|
||||||
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
|
|
||||||
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
|
||||||
size_t cc = major*10 + minor;
|
|
||||||
int version;
|
|
||||||
std::string ptxas_path = drv::path_to_ptxas(version);
|
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
|
||||||
triton::codegen::nvidia_cu_target target(cc);
|
|
||||||
int n_shared_bytes;
|
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
|
||||||
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
|
||||||
std::string tmp;
|
|
||||||
llvm::raw_string_ostream llir(tmp);
|
|
||||||
llir << *llvm;
|
|
||||||
llir.flush();
|
|
||||||
asm_map["llir"] = py::cast(tmp);
|
|
||||||
// LLVM-IR -> PTX
|
|
||||||
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
|
||||||
asm_map["ptx"] = py::cast(ptx);
|
|
||||||
// PTX -> Binary
|
|
||||||
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
|
||||||
if(!cubin.empty()){
|
|
||||||
py::bytes bytes(cubin);
|
|
||||||
asm_map["cubin"] = bytes;
|
|
||||||
}
|
|
||||||
return std::make_tuple(name, asm_map, n_shared_bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
// HIP
|
|
||||||
std::tuple<std::string, asm_map_t, int> hip_compile_ttir(
|
|
||||||
const std::string &name, ir::module &ir, uint64_t device, int num_warps,
|
|
||||||
int num_stages, asm_map_t &asm_map,
|
|
||||||
const triton::codegen::ExternLibMap &extern_lib_map) {
|
|
||||||
llvm::LLVMContext ctx;
|
|
||||||
// Triton-IR -> NVPTX LLVM-IR
|
|
||||||
triton::codegen::amd_cl_target target;
|
|
||||||
int n_shared_bytes;
|
|
||||||
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
|
||||||
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
|
||||||
std::string tmp;
|
|
||||||
llvm::raw_string_ostream llir(tmp);
|
|
||||||
llir << *llvm;
|
|
||||||
llir.flush();
|
|
||||||
asm_map["llir"] = py::cast(tmp);
|
|
||||||
// LLVM-IR -> HSA-CO
|
|
||||||
std::string path = drv::llir_to_amdgpu(llvm.get(), "gfx908");
|
|
||||||
asm_map["hsaco"] = py::cast(path);
|
|
||||||
return std::make_tuple(name, asm_map, n_shared_bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
void init_triton_codegen(py::module &&m) {
|
void init_triton_codegen(py::module &&m) {
|
||||||
m.def(
|
m.def("compile_ttir",
|
||||||
"compile_ttir",
|
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
|
||||||
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps,
|
py::gil_scoped_release allow_threads;
|
||||||
int num_stages, py::dict& extern_libs) {
|
std::string name = ir.get_function_list()[0]->get_name();
|
||||||
std::string name = ir.get_function_list()[0]->get_name();
|
// record asm as we generate
|
||||||
// record asm as we generate
|
asm_map_t asm_map;
|
||||||
asm_map_t asm_map;
|
std::ostringstream ttir;
|
||||||
std::ostringstream ttir;
|
ir.print(ttir);
|
||||||
ir.print(ttir);
|
asm_map["ttir"] = py::cast(ttir.str());
|
||||||
asm_map["ttir"] = py::cast(ttir.str());
|
llvm::LLVMContext ctx;
|
||||||
llvm::LLVMContext ctx;
|
// construct extern lib map
|
||||||
// construct extern lib map
|
triton::codegen::ExternLibMap extern_lib_map;
|
||||||
triton::codegen::ExternLibMap extern_lib_map;
|
for (auto item : extern_libs) {
|
||||||
for (auto item : extern_libs) {
|
auto name = item.first.cast<std::string>();
|
||||||
auto name = item.first.cast<std::string>();
|
auto path = item.second.cast<std::string>();
|
||||||
auto path = item.second.cast<std::string>();
|
extern_lib_map.emplace(
|
||||||
extern_lib_map.emplace(
|
name, triton::codegen::create_extern_lib(name, path));
|
||||||
name, triton::codegen::create_extern_lib(name, path));
|
}
|
||||||
}
|
// device properties
|
||||||
if(backend == CUDA)
|
CUdevice dev = (CUdevice)device;
|
||||||
return cu_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
|
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
|
||||||
assert(backend == ROCM);
|
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
|
||||||
return hip_compile_ttir(name, ir, device, num_warps, num_stages, asm_map, extern_lib_map);
|
size_t cc = major*10 + minor;
|
||||||
|
int version;
|
||||||
|
std::string ptxas_path = drv::path_to_ptxas(version);
|
||||||
|
// Triton-IR -> NVPTX LLVM-IR
|
||||||
|
triton::codegen::nvidia_cu_target target(cc);
|
||||||
|
int n_shared_bytes;
|
||||||
|
auto llvm = triton::codegen::add_passes_to_emit_bin(
|
||||||
|
ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
|
||||||
|
std::string tmp;
|
||||||
|
llvm::raw_string_ostream llir(tmp);
|
||||||
|
llir << *llvm;
|
||||||
|
llir.flush();
|
||||||
|
asm_map["llir"] = py::cast(tmp);
|
||||||
|
// LLVM-IR -> PTX
|
||||||
|
std::string ptx = drv::llir_to_ptx(llvm.get(), cc, version);
|
||||||
|
asm_map["ptx"] = py::cast(ptx);
|
||||||
|
// PTX -> Binary
|
||||||
|
std::string cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
|
||||||
|
if(!cubin.empty()){
|
||||||
|
py::bytes bytes(cubin);
|
||||||
|
asm_map["cubin"] = bytes;
|
||||||
|
}
|
||||||
|
return std::make_tuple(name, asm_map, n_shared_bytes);
|
||||||
},
|
},
|
||||||
py::return_value_policy::take_ownership);
|
py::return_value_policy::take_ownership);
|
||||||
m.def("load_binary", [](backend_t backend, const std::string& name, asm_map_t &asm_map, size_t n_shared_bytes, uint64_t dev){
|
|
||||||
py::gil_scoped_release allow_threads;
|
|
||||||
if(backend == CUDA)
|
// ---------------------------------------
|
||||||
return cu_load_binary(name, asm_map, n_shared_bytes, dev);
|
// Load provided assembly code into driver
|
||||||
assert(backend == ROCM);
|
// ---------------------------------------
|
||||||
return hip_load_binary(name, asm_map, n_shared_bytes, dev);
|
m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
|
||||||
}, py::return_value_policy::take_ownership);
|
py::gil_scoped_release allow_threads;
|
||||||
|
// create driver handles
|
||||||
|
CUfunction fun;
|
||||||
|
CUmodule mod;
|
||||||
|
drv::dispatch::cuModuleLoadData(&mod, data.c_str());
|
||||||
|
drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
|
||||||
|
// get allocated registers and spilled registers from the function
|
||||||
|
int n_regs = 0;
|
||||||
|
int n_spills = 0;
|
||||||
|
drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
|
||||||
|
drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
|
||||||
|
n_spills /= 4;
|
||||||
|
// set dynamic shared memory if necessary
|
||||||
|
int shared_optin;
|
||||||
|
drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
|
||||||
|
if(n_shared_bytes > 49152 && shared_optin > 49152){
|
||||||
|
drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
|
||||||
|
int shared_total, shared_static;
|
||||||
|
drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
|
||||||
|
drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
|
||||||
|
drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
|
||||||
|
}
|
||||||
|
return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
|
||||||
|
},
|
||||||
|
py::return_value_policy::take_ownership
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
struct InstanceDescriptor
|
struct InstanceDescriptor
|
||||||
|
@@ -5,6 +5,7 @@ import contextlib
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -926,23 +927,7 @@ def binary_name_to_header_name(name):
|
|||||||
return f"{name}.h"
|
return f"{name}.h"
|
||||||
|
|
||||||
|
|
||||||
def generate_torch_glue(kernel_name, constants, signature, num_warps, binaries, tmpdir):
|
def generate_launcher(identifier, constants, signature):
|
||||||
headers = dict()
|
|
||||||
|
|
||||||
# write all cubins to header files
|
|
||||||
assert len(binaries) == 1, "AoT compilation not yet supported"
|
|
||||||
|
|
||||||
for bin, shmem_size, name in binaries:
|
|
||||||
assert len(name) < 1024
|
|
||||||
initializer = f"""
|
|
||||||
const char* {name}_ptx = R"({bin["ptx"]})";
|
|
||||||
unsigned char {name}_bin[] = {{ {','.join(map(hex, bin["cubin"]))} }};
|
|
||||||
unsigned int {name}_shmem = {shmem_size};"""
|
|
||||||
headers[name] = os.path.join(tmpdir, binary_name_to_header_name(name))
|
|
||||||
with open(headers[name], "w") as f:
|
|
||||||
f.write(initializer)
|
|
||||||
|
|
||||||
func_init = '\n '.join(f"init_function(\"{name}\", {name}_bin, {name}_shmem, device);" for _, _, name in binaries)
|
|
||||||
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
|
||||||
|
|
||||||
def _extracted_type(ty):
|
def _extracted_type(ty):
|
||||||
@@ -970,13 +955,10 @@ unsigned int {name}_shmem = {shmem_size};"""
|
|||||||
"int64_t": "L",
|
"int64_t": "L",
|
||||||
}[ty]
|
}[ty]
|
||||||
|
|
||||||
format = "iiiK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
format = "iiiiiKK" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()])
|
||||||
|
|
||||||
# generate glue code
|
# generate glue code
|
||||||
src = ""
|
src = f"""
|
||||||
for bin, shmem_size, name in binaries:
|
|
||||||
src += f"#include \"{headers[name]}\"\n"
|
|
||||||
src += f"""
|
|
||||||
#include \"cuda.h\"
|
#include \"cuda.h\"
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
|
||||||
@@ -995,50 +977,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
|
|||||||
}}
|
}}
|
||||||
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
|
||||||
|
|
||||||
static CUmodule module = 0;
|
|
||||||
static CUfunction function = 0;
|
|
||||||
|
|
||||||
static inline void init_function(const char* name, const unsigned char* src, size_t n_shared_bytes, int64_t device){{
|
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
|
||||||
CUmodule mod;
|
|
||||||
CUfunction fun;
|
|
||||||
CUDA_CHECK(cuModuleLoadData(&mod, src));
|
|
||||||
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
|
|
||||||
// set dynamic shared memory if necessary
|
|
||||||
int shared_optin;
|
|
||||||
CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
|
|
||||||
if (n_shared_bytes > 49152 && shared_optin > 49152) {{
|
|
||||||
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
|
|
||||||
int shared_total, shared_static;
|
|
||||||
int n_spills, n_reg;
|
|
||||||
CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
|
|
||||||
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
|
|
||||||
CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
|
|
||||||
CUDA_CHECK(cuFuncGetAttribute(&n_reg, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
|
|
||||||
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
|
|
||||||
}}
|
|
||||||
module = mod;
|
|
||||||
function = fun;
|
|
||||||
}}
|
|
||||||
|
|
||||||
static inline void init_module(CUdevice device) {{
|
|
||||||
{func_init}
|
|
||||||
}}
|
|
||||||
|
|
||||||
|
|
||||||
void _{kernel_name}(int gridX, int gridY, int gridZ, CUstream stream, {arg_decls}) {{
|
|
||||||
// TODO: machine may have heterogeneous devices
|
|
||||||
if(function == 0){{
|
|
||||||
CUdevice device;
|
|
||||||
CUDA_CHECK(cuCtxGetDevice(&device));
|
|
||||||
init_module(device);
|
|
||||||
}}
|
|
||||||
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
|
||||||
if(gridX*gridY*gridZ > 0){{
|
if(gridX*gridY*gridZ > 0){{
|
||||||
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*{num_warps}, 1, 1, {name}_shmem, stream, params, 0));
|
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
|
|
||||||
CUdeviceptr getPointer(PyObject *obj, int idx) {{
|
|
||||||
|
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
|
||||||
if (PyLong_Check(obj)) {{
|
if (PyLong_Check(obj)) {{
|
||||||
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
|
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
|
||||||
}}
|
}}
|
||||||
@@ -1061,15 +1009,18 @@ CUdeviceptr getPointer(PyObject *obj, int idx) {{
|
|||||||
}}
|
}}
|
||||||
|
|
||||||
|
|
||||||
static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{
|
static PyObject* launch(PyObject* self, PyObject* args) {{
|
||||||
int gridX, gridY, gridZ;
|
int gridX, gridY, gridZ;
|
||||||
uint64_t stream;
|
uint64_t _stream;
|
||||||
|
uint64_t _function;
|
||||||
|
int num_warps;
|
||||||
|
int shared_memory;
|
||||||
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
|
||||||
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &stream, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
|
||||||
return NULL;
|
return NULL;
|
||||||
}}
|
}}
|
||||||
|
|
||||||
_{kernel_name}(gridX, gridY, gridZ, (CUstream)stream, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"getPointer(_arg{i},{i})" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
|
||||||
|
|
||||||
|
|
||||||
if(PyErr_Occurred()) {{
|
if(PyErr_Occurred()) {{
|
||||||
@@ -1081,38 +1032,26 @@ static PyObject* {kernel_name}(PyObject* self, PyObject* args) {{
|
|||||||
}}
|
}}
|
||||||
|
|
||||||
static PyMethodDef ModuleMethods[] = {{
|
static PyMethodDef ModuleMethods[] = {{
|
||||||
{{"{kernel_name}", {kernel_name}, METH_VARARGS, "Call {kernel_name} kernel"}},
|
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
|
||||||
{{NULL, NULL, 0, NULL}} // sentinel
|
{{NULL, NULL, 0, NULL}} // sentinel
|
||||||
}};
|
}};
|
||||||
|
|
||||||
static struct PyModuleDef ModuleDef = {{
|
static struct PyModuleDef ModuleDef = {{
|
||||||
PyModuleDef_HEAD_INIT,
|
PyModuleDef_HEAD_INIT,
|
||||||
\"{kernel_name}\",
|
\"launcher\",
|
||||||
NULL, //documentation
|
NULL, //documentation
|
||||||
-1, //size
|
-1, //size
|
||||||
ModuleMethods
|
ModuleMethods
|
||||||
}};
|
}};
|
||||||
|
|
||||||
PyMODINIT_FUNC PyInit_{kernel_name}(void) {{
|
PyMODINIT_FUNC PyInit_launcher(void) {{
|
||||||
PyObject *m = PyModule_Create(&ModuleDef);
|
PyObject *m = PyModule_Create(&ModuleDef);
|
||||||
if(m == NULL) {{
|
if(m == NULL) {{
|
||||||
return NULL;
|
return NULL;
|
||||||
}}
|
}}
|
||||||
PyModule_AddFunctions(m, ModuleMethods);
|
PyModule_AddFunctions(m, ModuleMethods);
|
||||||
PyObject *ptx = PyDict_New();
|
|
||||||
"""
|
|
||||||
|
|
||||||
for _, _, name in binaries:
|
|
||||||
src += f"""
|
|
||||||
PyObject *py_{name}_ptx = PyUnicode_FromString({name}_ptx);
|
|
||||||
PyDict_SetItemString(ptx, "{name}", py_{name}_ptx);
|
|
||||||
Py_DECREF(py_{name}_ptx);
|
|
||||||
"""
|
|
||||||
|
|
||||||
src += """
|
|
||||||
PyModule_AddObject(m, "ptx", ptx);
|
|
||||||
return m;
|
return m;
|
||||||
}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return src
|
return src
|
||||||
@@ -1126,35 +1065,34 @@ class CacheManager:
|
|||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.bin_path = None
|
|
||||||
self.lock_path = None
|
self.lock_path = None
|
||||||
# if caching is enabled, get the lock and bin path
|
# create cache directory if it doesn't exist
|
||||||
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||||
if self.cache_dir:
|
if self.cache_dir:
|
||||||
|
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||||
|
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||||
os.makedirs(self.cache_dir, exist_ok=True)
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
if self.cache_dir:
|
|
||||||
self.bin_path = os.path.join(self.cache_dir, self.key + ".so")
|
|
||||||
self.lock_path = self.bin_path + ".lock"
|
|
||||||
|
|
||||||
def has_file(self):
|
def _make_path(self, filename):
|
||||||
return self.bin_path and os.path.exists(self.bin_path)
|
return os.path.join(self.cache_dir, filename)
|
||||||
|
|
||||||
def put(self, binary):
|
def has_file(self, filename):
|
||||||
if self.bin_path:
|
if not self.cache_dir:
|
||||||
assert self.lock_path is not None
|
return False
|
||||||
with FileLock(self.lock_path):
|
return os.path.exists(self._make_path(filename))
|
||||||
with open(self.bin_path + ".tmp", "wb") as f:
|
|
||||||
f.write(binary)
|
|
||||||
os.rename(self.bin_path + ".tmp", self.bin_path)
|
|
||||||
|
|
||||||
|
def put(self, data, filename, binary=True):
|
||||||
|
if not self.cache_dir:
|
||||||
|
return
|
||||||
|
assert self.lock_path is not None
|
||||||
|
filepath = self._make_path(filename)
|
||||||
|
with FileLock(self.lock_path):
|
||||||
|
# use tempfile to be robust against program interruptions
|
||||||
|
mode = "wb" if binary else "w"
|
||||||
|
with open(filepath + ".tmp", mode) as f:
|
||||||
|
f.write(data)
|
||||||
|
os.rename(filepath + ".tmp", filepath)
|
||||||
|
|
||||||
def make_cache_key(fn, signature, configs, constants, num_warps, num_stages):
|
|
||||||
# Get unique key for the compiled code
|
|
||||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
|
||||||
configs_key = [get_conf_key(conf) for conf in configs]
|
|
||||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
|
|
||||||
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
|
||||||
return key
|
|
||||||
|
|
||||||
# utilties for generating and compiling C wrappers
|
# utilties for generating and compiling C wrappers
|
||||||
|
|
||||||
@@ -1224,54 +1162,91 @@ def _build(name, src, srcdir):
|
|||||||
return so
|
return so
|
||||||
|
|
||||||
|
|
||||||
|
def make_so_cache_key(signature, constants):
|
||||||
|
# Get unique key for the compiled code
|
||||||
|
signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
|
||||||
|
key = f"{''.join(signature.values())}{constants}"
|
||||||
|
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_stages):
|
||||||
|
# Get unique key for the compiled code
|
||||||
|
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
||||||
|
configs_key = [get_conf_key(conf) for conf in configs]
|
||||||
|
key = f"{fn_hash}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}"
|
||||||
|
key = hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||||
# we get the kernel, i.e. the first function generated in the module
|
# we get the kernel, i.e. the first function generated in the module
|
||||||
if configs is None:
|
|
||||||
assert False, "automatic specialization is not supported yet"
|
|
||||||
ref, _ = make_triton_ir(fn, signature, _triton.code_gen.instance_descriptor(), constants)
|
|
||||||
fns = ref.get_functions()
|
|
||||||
configs = _triton.infer_specialization_configs(fns[0])
|
|
||||||
assert len(configs) == 1
|
assert len(configs) == 1
|
||||||
# cache manager
|
# cache manager
|
||||||
cache_key = make_cache_key(fn, signature, configs, constants, num_warps, num_stages)
|
name = fn.__name__
|
||||||
cache_manager = CacheManager(cache_key)
|
# name of files that are cached
|
||||||
# retrieve cached shared object if it exists
|
so_cache_key = make_so_cache_key(signature, constants)
|
||||||
if cache_manager.has_file():
|
so_cache_manager = CacheManager(so_cache_key)
|
||||||
return CompiledKernel(fn.__name__, cache_manager.bin_path)
|
so_name = f"{name}.so"
|
||||||
# compile all the configs
|
# retrieve stub from cache if it exists
|
||||||
binaries = []
|
if not so_cache_manager.has_file(so_name):
|
||||||
for config in configs:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
binaries.append(_compile(fn, signature, device, constants, config, num_warps, num_stages, extern_libs, "cubin"))
|
src = generate_launcher(name, constants, signature)
|
||||||
# generate and compile glue code into shared object
|
src_path = os.path.join(tmpdir, "main.c")
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with open(src_path, "w") as f:
|
||||||
all_constants = set(constants.keys())
|
f.write(src)
|
||||||
all_constants.update(configs[0].equal_to_1)
|
so = _build(fn.__name__, src_path, tmpdir)
|
||||||
src = generate_torch_glue(fn.__name__, constants, signature, num_warps, binaries, tmpdir)
|
with open(so, "rb") as f:
|
||||||
src_path = os.path.join(tmpdir, "main.c")
|
so_cache_manager.put(f.read(), so_name, binary=True)
|
||||||
with open(src_path, "w") as f:
|
|
||||||
f.write(src)
|
|
||||||
so = _build(fn.__name__, src_path, tmpdir)
|
|
||||||
with open(so, "rb") as f:
|
|
||||||
cache_manager.put(f.read())
|
|
||||||
|
|
||||||
return CompiledKernel(fn.__name__, cache_manager.bin_path)
|
# retrieve cached shared object if it exists
|
||||||
|
fn_cache_key = make_fn_cache_key(fn.cache_key, signature, configs, constants, num_warps, num_stages)
|
||||||
|
fn_cache_manager = CacheManager(fn_cache_key)
|
||||||
|
ptx_name = f"{name}.ptx"
|
||||||
|
cubin_name = f"{name}.cubin"
|
||||||
|
data_name = f"{name}.json"
|
||||||
|
if not fn_cache_manager.has_file(cubin_name) or \
|
||||||
|
not fn_cache_manager.has_file(data_name) or \
|
||||||
|
not fn_cache_manager.has_file(ptx_name):
|
||||||
|
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
|
||||||
|
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
|
||||||
|
fn_cache_manager.put(asm["cubin"], cubin_name)
|
||||||
|
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
|
||||||
|
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
|
||||||
|
|
||||||
|
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
|
||||||
|
|
||||||
|
|
||||||
class CompiledKernel:
|
class CompiledKernel:
|
||||||
|
|
||||||
def __init__(self, fn_name, data_path):
|
def __init__(self, fn_name, so_path, cache_dir):
|
||||||
|
# initialize launcher
|
||||||
import importlib.util
|
import importlib.util
|
||||||
spec = importlib.util.spec_from_file_location(fn_name, data_path)
|
spec = importlib.util.spec_from_file_location("launcher", so_path)
|
||||||
mod = importlib.util.module_from_spec(spec)
|
mod = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(mod)
|
spec.loader.exec_module(mod)
|
||||||
self.c_wrapper = getattr(mod, fn_name)
|
self.c_wrapper = getattr(mod, "launch")
|
||||||
ptx = getattr(mod, "ptx")
|
# initialize metadata
|
||||||
if len(ptx) == 1:
|
with open(os.path.join(cache_dir, f"{fn_name}.json")) as f:
|
||||||
self.asm = {"ptx": list(ptx.values())[0]}
|
metadata = json.load(f)
|
||||||
|
self.shared = metadata["shared"]
|
||||||
|
self.num_warps = metadata["num_warps"]
|
||||||
|
self.num_stages = metadata["num_stages"]
|
||||||
|
# initialize asm dict
|
||||||
|
self.asm = dict()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.cubin"), "rb") as f:
|
||||||
|
self.asm["cubin"] = f.read()
|
||||||
|
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
|
||||||
|
self.asm["ptx"] = f.read()
|
||||||
|
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
|
||||||
|
self.cu_module = mod
|
||||||
|
self.cu_function = func
|
||||||
|
|
||||||
def __getitem__(self, grid):
|
def __getitem__(self, grid):
|
||||||
def runner(*args, stream=None):
|
def runner(*args, stream=None):
|
||||||
if stream is None:
|
if stream is None:
|
||||||
stream = torch.cuda.current_stream().cuda_stream
|
stream = torch.cuda.current_stream().cuda_stream
|
||||||
self.c_wrapper(grid[0], grid[1], grid[2], stream, *args)
|
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
|
||||||
return runner
|
return runner
|
||||||
|
@@ -253,7 +253,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
try:
|
try:
|
||||||
bin = cache[key]
|
bin = cache[key]
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, stream, {args})
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, {args})
|
||||||
return bin
|
return bin
|
||||||
# kernel not cached -- compile
|
# kernel not cached -- compile
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@@ -274,7 +274,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, stream, *args)
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, *args)
|
||||||
self.cache[key] = bin
|
self.cache[key] = bin
|
||||||
return bin
|
return bin
|
||||||
return None
|
return None
|
||||||
|
Reference in New Issue
Block a user