[RUNTIME] Major code cleanup (#711)

This PR does the following:
- CUDA utilities (e.g., cuGetInfo) won't be compiled as part of libtriton.so anymore.
- Refactoring driver/llvm.cc to split it between PTX codegen and python.
- By extension this will also deprecate include/external so Triton won't have to live with a copy of some CUDA/Hip headers anymore.
- `triton-translate` becomes a `triton.tools.aot` Python utility that re-uses functions from the triton.compile sub-module.
This commit is contained in:
Philippe Tillet
2022-09-26 16:38:06 -07:00
committed by GitHub
parent 8bb09f83ee
commit 1e91ed30d0
28 changed files with 509 additions and 31483 deletions

View File

@@ -7,6 +7,7 @@ import hashlib
import io
import json
import os
import re
import shutil
import subprocess
import sys
@@ -843,7 +844,11 @@ def optimize_tritongpu_ir(mod, num_stages):
return mod
def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
def make_llvm_ir(mod):
return _triton.translate_triton_gpu_to_llvmir(mod)
def make_ptx(mod: Any, compute_capability: int, ptx_version: int) -> Tuple[str, int]:
'''
Translate TritonGPU module to PTX code.
:param mod: a TritonGPU dialect module
@@ -851,17 +856,17 @@ def make_ptx(mod: Any, device: int) -> Tuple[str, int]:
- PTX code
- shared memory alloaction size
'''
return _triton.translate_triton_gpu_to_ptx(mod, device)
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
def make_cubin(ptx, device):
def make_cubin(ptx: str, ptxas: str, compute_capability: int):
'''
Compile TritonGPU module to cubin.
:param ptx: ptx code
:param device: CUDA device
:return: str
'''
return _triton.compile_ptx_to_cubin(ptx, device)
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
def ptx_get_kernel_name(ptx: str) -> str:
@@ -877,6 +882,46 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
@functools.lru_cache
def ptx_get_version(cuda_version) -> int:
'''
Get the highest PTX version supported by the current CUDA driver.
'''
assert isinstance(cuda_version, str)
major, minor = map(int, cuda_version.split('.'))
version = major * 1000 + minor * 10
if version >= 11040:
return 74
if version >= 11030:
return 73
if version >= 11020:
return 72
if version >= 11010:
return 71
if version >= 11000:
return 70
if version >= 10020:
return 65
if version >= 10010:
return 64
if version >= 10000:
return 63
raise RuntimeError("Triton only support CUDA 10.0 or higher")
def path_to_ptxas():
prefixes = [os.environ.get("TRITON_PTXAS_PATH", ""), "", "/usr/local/cuda/"]
for prefix in prefixes:
ptxas = os.path.join(prefix, "bin", "ptxas")
if os.path.exists(ptxas):
result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT)
if result is not None:
version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE)
if version is not None:
return ptxas, version.group(1)
raise RuntimeError("Cannot find ptxas")
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()])
@@ -895,17 +940,24 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
# tritongpu-ir
module = make_tritongpu_ir(module, num_warps)
module = optimize_tritongpu_ir(module, num_stages)
if output == "ttgir":
return module.str()
# llvm-ir
llvm_ir = make_llvm_ir(module)
assert device >= 0, "device should be provided."
ptx, shem_size = make_ptx(module, device)
ptxas, cuda_version = path_to_ptxas()
compute_capability = torch.cuda.get_device_capability(device)
compute_capability = compute_capability[0] * 10 + compute_capability[1]
ptx_version = ptx_get_version(cuda_version)
ptx = make_ptx(llvm_ir, compute_capability, ptx_version)
shem_size = _triton.get_shared_memory_size(module)
kernel_name = ptx_get_kernel_name(ptx)
if output == "ptx":
return ptx, shem_size, kernel_name
cubin = make_cubin(ptx, device)
cubin = make_cubin(ptx, ptxas, compute_capability)
if output == "cubin":
return cubin, ptx, shem_size, kernel_name
@@ -980,6 +1032,7 @@ def generate_launcher(identifier, constants, signature):
src = f"""
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
@@ -993,13 +1046,16 @@ static inline void gpuAssert(CUresult code, const char *file, int line)
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
if (PyLong_Check(obj)) {{
return (CUdeviceptr)PyLong_AsUnsignedLongLong(obj);
@@ -1021,6 +1077,7 @@ static inline CUdeviceptr getPointer(PyObject *obj, int idx) {{
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return (CUdeviceptr)0;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
@@ -1039,10 +1096,12 @@ static PyObject* launch(PyObject* self, PyObject* args) {{
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"launcher\",
@@ -1050,6 +1109,7 @@ static struct PyModuleDef ModuleDef = {{
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
@@ -1251,7 +1311,10 @@ class CompiledKernel:
self.asm["ptx"] = f.read()
device = torch.cuda.current_device()
mod, func, n_regs, n_spills = _triton.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
global cuda_utils
if cuda_utils is None:
cuda_utils = CudaUtils()
mod, func, n_regs, n_spills = cuda_utils.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func
@@ -1261,3 +1324,118 @@ class CompiledKernel:
stream = torch.cuda.current_stream().cuda_stream
self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, *args)
return
class CudaUtils(object):
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(CudaUtils, cls).__new__(cls)
return cls.instance
def _generate_src(self):
return """
#include <cuda.h>
#include \"cuda.h\"
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{
if (code != CUDA_SUCCESS)
{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {0};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}
}
#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }
static PyObject* loadBinary(PyObject* self, PyObject* args) {
const char* name;
const char* data;
Py_ssize_t data_size;
int shared;
int device;
if(!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, &device)) {
return NULL;
}
CUfunction fun;
CUmodule mod;
int32_t n_regs = 0;
int32_t n_spills = 0;
Py_BEGIN_ALLOW_THREADS;
// create driver handles
CUDA_CHECK(cuModuleLoadData(&mod, data));
CUDA_CHECK(cuModuleGetFunction(&fun, mod, name));
// get allocated registers and spilled registers from the function
CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
CUDA_CHECK(cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
n_spills /= 4;
// set dynamic shared memory if necessary
int shared_optin;
CUDA_CHECK(cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device));
if (shared > 49152 && shared_optin > 49152) {
CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
int shared_total, shared_static;
CUDA_CHECK(cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
CUDA_CHECK(cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
CUDA_CHECK(cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static));
}
Py_END_ALLOW_THREADS;
if(PyErr_Occurred()) {
return NULL;
}
return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills);
}
static PyMethodDef ModuleMethods[] = {
{"load_binary", loadBinary, METH_VARARGS, "Load provided cubin into CUDA driver"},
{NULL, NULL, 0, NULL} // sentinel
};
static struct PyModuleDef ModuleDef = {
PyModuleDef_HEAD_INIT,
\"cuda_utils\",
NULL, //documentation
-1, //size
ModuleMethods
};
PyMODINIT_FUNC PyInit_cuda_utils(void) {
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {
return NULL;
}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}
"""
def __init__(self):
src = self._generate_src()
key = hashlib.md5(src.encode("utf-8")).hexdigest()
cache = CacheManager(key)
fname = "cuda_utils.so"
if not cache.has_file(fname):
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build("cuda_utils", src_path, tmpdir)
with open(so, "rb") as f:
cache.put(f.read(), fname, binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location("cuda_utils", cache._make_path(fname))
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
self.load_binary = mod.load_binary
cuda_utils = None