[FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this yet, but I figure it is a good change anyway.
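For orientation, here is a minimal sketch of the workflow this enables, modelled on the test added below. The kernel name, signature, and constants are made up for illustration; `warm_cache_only` and `cc` are the keyword arguments introduced by this change, and the ad-hoc `namedtuple` specialization descriptor is the same trick the new test uses.

import multiprocessing
from collections import namedtuple

import torch
import triton
import triton.language as tl


@triton.jit
def add_one(x, o, N: tl.constexpr):
    # toy kernel, only used to illustrate the flow
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(x + idx) + 1)


if __name__ == "__main__":
    # The parent queries the compute capability once and passes it down,
    # so the child process never has to touch the CUDA driver.
    major, minor = torch.cuda.get_device_capability(0)

    # Ad-hoc specialization descriptor (both pointer args assumed 16-byte aligned).
    config = namedtuple("instance_descriptor",
                        ["divisible_by_16", "equal_to_1"])((0, 1), ())

    # The child compiles PTX/cubin into the on-disk cache and returns
    # before load_binary(), which would need a CUDA context.
    proc = multiprocessing.Process(
        target=triton.compile,
        kwargs=dict(
            fn=add_one,
            signature={0: "*fp32", 1: "*fp32"},
            device=0,
            constants={2: 32},
            configs=[config],
            warm_cache_only=True,
            cc=major * 10 + minor,
        ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0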
@@ -239,14 +239,12 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
     unlink(_flog);
     throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
   }
-  CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );
   std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
   _cubin.close();
   unlink(_fsrc);
   unlink(_flog);
   unlink(_fbin);
-  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   return cubin;
 }
 
@@ -436,7 +436,7 @@ typedef std::map<std::string, py::object> asm_map_t;
 
 void init_triton_codegen(py::module &&m) {
   m.def("compile_ttir",
-      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
+      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
         py::gil_scoped_release allow_threads;
         std::string name = ir.get_function_list()[0]->get_name();
         // record asm as we generate
@@ -454,10 +454,12 @@ void init_triton_codegen(py::module &&m) {
               name, triton::codegen::create_extern_lib(name, path));
         }
         // device properties
+        if (cc == 0) {
          CUdevice dev = (CUdevice)device;
          size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
          size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-         size_t cc = major*10 + minor;
+         cc = major*10 + minor;
+        }
         int version;
         std::string ptxas_path = drv::path_to_ptxas(version);
         // Triton-IR -> NVPTX LLVM-IR
@@ -1,6 +1,8 @@
+import multiprocessing
 import os
 import re
 import shutil
+from collections import namedtuple
 
 import pytest
 import torch
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
     assert len(kernel_add.cache) == 1
     kernel_add.warmup(*args, grid=(1,))
     assert len(kernel_add.cache) == 1
+
+
+def test_compile_in_subproc() -> None:
+    @triton.jit
+    def kernel_sub(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx,
+                 tl.load(a + idx) - tl.load(b + idx) * 777)
+
+    major, minor = torch.cuda.get_device_capability(0)
+    cc = major * 10 + minor
+    config = namedtuple("instance_descriptor", [
+        "divisible_by_16", "equal_to_1"])(
+        tuple(range(4)),
+        ())
+
+    proc = multiprocessing.Process(
+        target=triton.compile,
+        kwargs=dict(
+            fn=kernel_sub,
+            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
+            device=0,
+            constants={3: 32},
+            configs=[config],
+            warm_cache_only=True,
+            cc=cc,
+        ))
+    proc.start()
+    proc.join()
+    assert proc.exitcode == 0
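The warm-up path only touches the file cache, so the same pattern extends naturally to warming several variants at once, one child process per variant; this is where the (so far modest) speedup mentioned in the commit message would come from. A hypothetical sketch, assuming the kernel_sub, config, and cc names from the test above are in scope, and ignoring questions of cache-directory contention between children:

# Hypothetical: warm several constexpr sizes of kernel_sub in parallel.
procs = []
for n in (32, 64, 128):
    p = multiprocessing.Process(
        target=triton.compile,
        kwargs=dict(
            fn=kernel_sub,
            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
            device=0,
            constants={3: n},     # different N per child
            configs=[config],
            warm_cache_only=True,
            cc=cc,
        ))
    p.start()
    procs.append(p)
for p in procs:
    p.join()
    assert p.exitcode == 0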
@@ -880,7 +880,10 @@ def ptx_get_kernel_name(ptx: str) -> str:
             return line.split()[-1]
 
 
-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
+def _compile(fn, signature: str, device: int = -1, constants=dict(),
+             specialization=_triton.code_gen.instance_descriptor(),
+             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
+             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
 
@@ -894,7 +897,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
     backend = _triton.runtime.backend.CUDA
     if extern_libs is None:
         extern_libs = dict()
-    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs)
+    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
     return asm, shared_mem, name
 
 
@@ -1179,7 +1182,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key
 
 
-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
+            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
     # cache manager
@@ -1208,18 +1212,22 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     if not fn_cache_manager.has_file(cubin_name) or \
        not fn_cache_manager.has_file(data_name) or \
        not fn_cache_manager.has_file(ptx_name):
-        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
+        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
+                                            extern_libs, "cubin", cc)
         metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
         fn_cache_manager.put(asm["cubin"], cubin_name)
         fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
         fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
 
-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    if warm_cache_only:
+        return  # load_binary() requires a valid cuda context
+
+    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
 
 
 class CompiledKernel:
 
-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, fn_name, so_path, cache_dir, device):
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
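With the early return in place, a typical use would be: warm the cache in a child process, then call `triton.compile` again in the parent (which does have a CUDA context) with the same arguments minus `warm_cache_only`; the `has_file` checks above then find the cached cubin/PTX/metadata, skip recompilation, and return a `CompiledKernel` loaded onto the given device. A sketch, assuming `compile_kwargs` is a hypothetical dict holding the same keyword arguments (including `cc`) used in the examples above, with `multiprocessing`, `torch`, and `triton` already imported:

# Child: populate the on-disk cache without a CUDA context.
warm = multiprocessing.Process(target=triton.compile,
                               kwargs=dict(compile_kwargs, warm_cache_only=True))
warm.start()
warm.join()

# Parent: ensure a CUDA context exists (any CUDA work does this), then
# compile again; the cached artifacts make this mostly a load_binary() call.
torch.zeros(1, device="cuda:0")
kernel = triton.compile(**compile_kwargs)  # returns a CompiledKernel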
@@ -1239,7 +1247,6 @@ class CompiledKernel:
         with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
             self.asm["ptx"] = f.read()
 
-        device = torch.cuda.current_device()
         mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func