[RUNTIME] Major code cleanup (#711)
This PR does the following:
- CUDA utilities (e.g., cuGetInfo) are no longer compiled into libtriton.so.
- driver/llvm.cc is refactored and split between PTX codegen and the Python bindings.
- By extension, this also deprecates include/external, so Triton no longer has to carry a copy of some CUDA/HIP headers.
- `triton-translate` becomes a `triton.tools.aot` Python utility that reuses functions from the `triton.compile` sub-module.
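For orientation, below is a minimal sketch of how the refactored PTX path chains together after this change. The `_triton` entry points and the `path_to_ptxas` / `ptx_get_version` helpers are the ones added in the diff below; the wrapper function and the binding import path are illustrative assumptions, not part of this PR.

```python
# Sketch only: mirrors the flow introduced in _compile() below.
import torch
import triton._C.libtriton.triton as _triton  # assumed binding module name

def lower_to_cubin(ttgir_module, device: int):
    # TritonGPU IR -> LLVM IR (new make_llvm_ir step)
    llvm_ir = _triton.translate_triton_gpu_to_llvmir(ttgir_module)
    # Locate ptxas and derive the PTX ISA version from its CUDA release
    ptxas, cuda_version = path_to_ptxas()
    ptx_version = ptx_get_version(cuda_version)
    major, minor = torch.cuda.get_device_capability(device)
    compute_capability = major * 10 + minor
    # LLVM IR -> PTX, then PTX -> cubin through the external ptxas binary
    ptx = _triton.translate_llvmir_to_ptx(llvm_ir, compute_capability, ptx_version)
    return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
```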
@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -843,7 +844,11 @@ def optimize_tritongpu_ir(mod, num_stages):
    return mod


def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
def make_llvm_ir(mod):
    return _triton.translate_triton_gpu_to_llvmir(mod)


def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
    '''
    Translate TritonGPU module to PTX code.
    :param mod: a TritonGPU dialect module
@@ -851,17 +856,17 @@ def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
        - PTX code
        - shared memory alloaction size
    '''
    return _triton.translate_triton_gpu_to_ptx(mod, device)
    return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)


def make_cubin(ptx, device):
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
    '''
    Compile TritonGPU module to cubin.
    :param ptx: ptx code
    :param device: CUDA device
    :return: str
    '''
    return _triton.compile_ptx_to_cubin(ptx, device)
    return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)


def ptx_get_kernel_name(ptx: str) -> str:
@@ -877,6 +882,46 @@ def ptx_get_kernel_name(ptx: str) -> str:
            return line.split()[-1]


@functools.lru_cache
def ptx_get_version(cuda_version) -> int:
    '''
    Get the highest PTX version supported by the current CUDA driver.
    '''
    assert isinstance(cuda_version, str)
    major, minor = map(int, cuda_version.split('.'))
    version = major * 1000 + minor * 10
    if version >= 11040:
        return 74
    if version >= 11030:
        return 73
    if version >= 11020:
        return 72
    if version >= 11010:
        return 71
    if version >= 11000:
        return 70
    if version >= 10020:
        return 65
    if version >= 10010:
        return 64
    if version >= 10000:
        return 63
    raise RuntimeError("Triton only support CUDA 10.0 or higher")


def path_to_ptxas():
    prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
    for prefix in prefixes:
        ptxas = os.path.join(prefix, "bin", "ptxas")
        if os.path.exists(ptxas):
            result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
            if result is not None:
                version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
                if version is not None:
                    return ptxas, version.group(1)
    raise RuntimeError("Cannot find ptxas")


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])

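For reference, a small usage sketch of the two helpers added above; the concrete paths and versions are illustrative.

```python
# "11.4" -> 11 * 1000 + 4 * 10 = 11040, which hits the ">= 11040" branch,
# so the highest supported PTX ISA is reported as 74 (PTX 7.4).
assert ptx_get_version("11.4") == 74
assert ptx_get_version("10.2") == 65

# path_to_ptxas() checks TRITON_PTXAS_PATH, the relative "bin/ptxas" path,
# and /usr/local/cuda in that order, returning the binary plus its release,
# e.g. ("/usr/local/cuda/bin/ptxas", "11.4") on a typical installation.
ptxas, cuda_version = path_to_ptxas()
```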
@@ -895,17 +940,24 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
    # tritongpu-ir
    module = make_tritongpu_ir(module, num_warps)
    module = optimize_tritongpu_ir(module, num_stages)

    if output == "ttgir":
        return module.str()

    # llvm-ir
    llvm_ir = make_llvm_ir(module)

    assert device >= 0, "device should be provided."
    ptx, shem_size = make_ptx(module, device)
    ptxas, cuda_version = path_to_ptxas()
    compute_capability = torch.cuda.get_device_capability(device)
    compute_capability = compute_capability[0] * 10 + compute_capability[1]
    ptx_version = ptx_get_version(cuda_version)
    ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
    shem_size = _triton.get_shared_memory_size(module)
    kernel_name = ptx_get_kernel_name(ptx)
    if output == "ptx":
        return ptx, shem_size, kernel_name

    cubin = make_cubin(ptx, device)
    cubin = make_cubin(ptx, ptxas, compute_capability)
    if output == "cubin":
        return cubin, ptx, shem_size, kernel_name

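A usage sketch of the staged early returns in `_compile` after this change; the kernel, signature string, and device are placeholders, not taken from this PR. The shapes of the return values match the hunk above.

```python
# Stop after TritonGPU IR:
ttgir = _compile(my_kernel, "*fp32,*fp32,i32", device=0, output="ttgir")
# Stop after PTX (also yields the shared-memory size and kernel name):
ptx, shem_size, kernel_name = _compile(my_kernel, "*fp32,*fp32,i32", device=0, output="ptx")
# Go all the way to a cubin:
cubin, ptx, shem_size, kernel_name = _compile(my_kernel, "*fp32,*fp32,i32", device=0, output="cubin")
```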
@@ -980,6 +1032,7 @@ def generate_launcher(identifier, constants, signature):
    src = f"""
#include \"cuda.h\"
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{{
  if (code != CUDA_SUCCESS)
@@ -993,13 +1046,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
    PyErr_SetString(PyExc_RuntimeError, err);
  }}
}}

#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}

void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
  void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
  if(gridX*gridY*gridZ > 0){{
    CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
  }}
}}

static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
  if (PyLong_Check(obj)) {{
    return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
@@ -1021,6 +1077,7 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  return (CUdeviceptr)0;
}}

static PyObject* launch(PyObject* self, PyObject* args) {{
  int gridX, gridY, gridZ;
  uint64_t _stream;
@@ -1039,10 +1096,12 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
  Py_INCREF(Py_None);
  return Py_None;
}}

static PyMethodDef ModuleMethods[] = {{
  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
  {{NULL, NULL, 0, NULL}} // sentinel
}};

static struct PyModuleDef ModuleDef = {{
  PyModuleDef_HEAD_INIT,
  \"launcher\",
@@ -1050,6 +1109,7 @@ static struct PyModuleDef ModuleDef = {{
  -1, //size
  ModuleMethods
}};

PyMODINIT_FUNC PyInit_launcher(void) {{
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {{
@@ -1251,7 +1311,10 @@ class CompiledKernel:
            self.asm["ptx"] = f.read()

        device = torch.cuda.current_device()
        mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        global cuda_utils
        if cuda_utils is None:
            cuda_utils = CudaUtils()
        mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
        self.cu_module = mod
        self.cu_function = func

@@ -1261,3 +1324,118 @@ class CompiledKernel:
        stream = torch.cuda.current_stream().cuda_stream
        self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
        return


class CudaUtils(object):

    def __new__(cls):
        if not hasattr(cls, 'instance'):
            cls.instance = super(CudaUtils, cls).__new__(cls)
        return cls.instance

    def _generate_src(self):
        return """
#include <cuda.h>

#include \"cuda.h\"
#include <Python.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{
  if (code != CUDA_SUCCESS)
  {
    const char* prefix = "Triton Error [CUDA]: ";
    const char* str;
    cuGetErrorString(code, &str);
    char err[1024] = {0};
    strcat(err, prefix);
    strcat(err, str);
    PyErr_SetString(PyExc_RuntimeError, err);
  }
}

#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }

static PyObject* loadBinary(PyObject* self, PyObject* args) {
  const char* name;
  const char* data;
  Py_ssize_t data_size;
  int shared;
  int device;
  if(!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
    return NULL;
  }
  CUfunction fun;
  CUmodule mod;
  int32_t n_regs = 0;
  int32_t n_spills = 0;
  Py_BEGIN_ALLOW_THREADS;
  // create driver handles
  CUDA_CHECK(cuModuleLoadData(&mod, data));
  CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
  // get allocated registers and spilled registers from the function
  CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  n_spills /= 4;
  // set dynamic shared memory if necessary
  int shared_optin;
  CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
  if (shared > 49152 && shared_optin > 49152) {
    CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
    int shared_total, shared_static;
    CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
    CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
    CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
  }
  Py_END_ALLOW_THREADS;

  if(PyErr_Occurred()) {
    return NULL;
  }
  return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
}

static PyMethodDef ModuleMethods[] = {
  {"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  \"cuda_utils\",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit_cuda_utils(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
"""

    def __init__(self):
        src = self._generate_src()
        key = hashlib.md5(src.encode("utf-8")).hexdigest()
        cache = CacheManager(key)
        fname = "cuda_utils.so"
        if not cache.has_file(fname):
            with tempfile.TemporaryDirectory() as tmpdir:
                src_path = os.path.join(tmpdir, "main.c")
                with open(src_path, "w") as f:
                    f.write(src)
                so = _build("cuda_utils", src_path, tmpdir)
                with open(so, "rb") as f:
                    cache.put(f.read(), fname, binary=True)
        import importlib.util
        spec = importlib.util.spec_from_file_location("cuda_utils", cache._make_path(fname))
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        self.load_binary = mod.load_binary


cuda_utils = None
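To illustrate how the pieces above fit together at runtime, here is a minimal sketch of the lazy-singleton pattern used by `CompiledKernel` in the earlier hunk. The wrapper function and the `kernel_name` / `cubin_bytes` / `shared_mem` values are placeholders, not part of this PR.

```python
def _load_cubin(name, cubin, shared, device):
    # Build cuda_utils.so once (cached via CacheManager) and reuse the singleton afterwards.
    global cuda_utils
    if cuda_utils is None:
        cuda_utils = CudaUtils()
    # Returns the CUmodule/CUfunction handles plus register and spill counts.
    return cuda_utils.load_binary(name, cubin, shared, device)

# Placeholder arguments for illustration only:
mod, func, n_regs, n_spills = _load_cubin(kernel_name, cubin_bytes, shared_mem, torch.cuda.current_device())
```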