diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc
index e17c381cb..555a2b14e 100644
--- a/lib/driver/llvm.cc
+++ b/lib/driver/llvm.cc
@@ -239,14 +239,12 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
     unlink(_flog);
     throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
   }
-  CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );
   std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
   _cubin.close();
   unlink(_fsrc);
   unlink(_flog);
   unlink(_fbin);
-  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   return cubin;
 }
 
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 31bc0445f..2a5052199 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -436,7 +436,7 @@ typedef std::map<std::string, py::object> asm_map_t;
 
 void init_triton_codegen(py::module &&m) {
   m.def("compile_ttir",
-      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
+      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
         py::gil_scoped_release allow_threads;
         std::string name = ir.get_function_list()[0]->get_name();
         // record asm as we generate
@@ -454,10 +454,12 @@ void init_triton_codegen(py::module &&m) {
               name, triton::codegen::create_extern_lib(name, path));
         }
         // device properties
-        CUdevice dev = (CUdevice)device;
-        size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-        size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-        size_t cc = major*10 + minor;
+        if (cc == 0) {
+          CUdevice dev = (CUdevice)device;
+          size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+          size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+          cc = major*10 + minor;
+        }
         int version;
         std::string ptxas_path = drv::path_to_ptxas(version);
         // Triton-IR -> NVPTX LLVM-IR
diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 6d6c0e131..a96c47916 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -1,6 +1,8 @@
+import multiprocessing
 import os
 import re
 import shutil
+from collections import namedtuple
 
 import pytest
 import torch
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
     assert len(kernel_add.cache) == 1
     kernel_add.warmup(*args, grid=(1,))
     assert len(kernel_add.cache) == 1
+
+
+def test_compile_in_subproc() -> None:
+    @triton.jit
+    def kernel_sub(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx,
+                 tl.load(a + idx) - tl.load(b + idx) * 777)
+
+    major, minor = torch.cuda.get_device_capability(0)
+    cc = major * 10 + minor
+    config = namedtuple("instance_descriptor", [
+        "divisible_by_16", "equal_to_1"])(
+        tuple(range(4)),
+        ())
+
+    proc = multiprocessing.Process(
+        target=triton.compile,
+        kwargs=dict(
+            fn=kernel_sub,
+            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
+            device=0,
+            constants={3: 32},
+            configs=[config],
+            warm_cache_only=True,
+            cc=cc,
+        ))
+    proc.start()
+    proc.join()
+    assert proc.exitcode == 0
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index c6523f016..87875d1d2 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -880,7 +880,10 @@ def ptx_get_kernel_name(ptx: str) -> str:
             return line.split()[-1]
 
 
-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
+def _compile(fn, signature: str, device: int = -1, constants=dict(),
+             specialization=_triton.code_gen.instance_descriptor(),
+             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
+             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
 
@@ -894,7 +897,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
         backend = _triton.runtime.backend.CUDA
     if extern_libs is None:
         extern_libs = dict()
-    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs)
+    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
     return asm, shared_mem, name
 
 
@@ -1179,7 +1182,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key
 
 
-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
+            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
     # cache manager
@@ -1208,18 +1212,22 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     if not fn_cache_manager.has_file(cubin_name) or \
             not fn_cache_manager.has_file(data_name) or \
             not fn_cache_manager.has_file(ptx_name):
-        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
+        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
+                                            extern_libs, "cubin", cc)
         metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
         fn_cache_manager.put(asm["cubin"], cubin_name)
         fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
         fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
 
-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    if warm_cache_only:
+        return  # load_binary() requires a valid cuda context
+
+    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
 
 
 class CompiledKernel:
 
-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, fn_name, so_path, cache_dir, device):
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1239,7 +1247,6 @@ class CompiledKernel:
         with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
             self.asm["ptx"] = f.read()
 
-        device = torch.cuda.current_device()
         mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func