[FRONTEND] Make triton.compile work without a cuda context (#708)

This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
This commit is contained in:
Jason Ansel
2022-09-24 13:41:47 -07:00
committed by GitHub
parent 3ac929b48b
commit 998fd5f9af
4 changed files with 53 additions and 14 deletions

View File

@@ -239,14 +239,12 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
unlink(_flog);
throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
}
CUmodule ret;
std::ifstream _cubin(_fbin, std::ios::binary );
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
unlink(_fsrc);
unlink(_flog);
unlink(_fbin);
dispatch::cuModuleLoadData(&ret, cubin.c_str());
return cubin;
}

View File

@@ -436,7 +436,7 @@ typedef std::map<std::string, py::object> asm_map_t;
void init_triton_codegen(py::module &&m) {
m.def("compile_ttir",
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
[](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
py::gil_scoped_release allow_threads;
std::string name = ir.get_function_list()[0]->get_name();
// record asm as we generate
@@ -454,10 +454,12 @@ void init_triton_codegen(py::module &&m) {
name, triton::codegen::create_extern_lib(name, path));
}
// device properties
if (cc == 0) {
CUdevice dev = (CUdevice)device;
size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
size_t cc = major*10 + minor;
cc = major*10 + minor;
}
int version;
std::string ptxas_path = drv::path_to_ptxas(version);
// Triton-IR -> NVPTX LLVM-IR

View File

@@ -1,6 +1,8 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple
import pytest
import torch
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
def test_compile_in_subproc() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) - tl.load(b + idx) * 777)
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = namedtuple("instance_descriptor", [
"divisible_by_16", "equal_to_1"])(
tuple(range(4)),
())
proc = multiprocessing.Process(
target=triton.compile,
kwargs=dict(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
))
proc.start()
proc.join()
assert proc.exitcode == 0

View File

@@ -880,7 +880,10 @@ def ptx_get_kernel_name(ptx: str) -> str:
return line.split()[-1]
def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
def _compile(fn, signature: str, device: int = -1, constants=dict(),
specialization=_triton.code_gen.instance_descriptor(),
num_warps: int = 4, num_stages: int = 3, extern_libs=None,
output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
@@ -894,7 +897,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
backend = _triton.runtime.backend.CUDA
if extern_libs is None:
extern_libs = dict()
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs)
name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
return asm, shared_mem, name
@@ -1179,7 +1182,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
return key
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
# we get the kernel, i.e. the first function generated in the module
assert len(configs) == 1
# cache manager
@@ -1208,18 +1212,22 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
if not fn_cache_manager.has_file(cubin_name) or \
not fn_cache_manager.has_file(data_name) or \
not fn_cache_manager.has_file(ptx_name):
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
extern_libs, "cubin", cc)
metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
fn_cache_manager.put(asm["cubin"], cubin_name)
fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
if warm_cache_only:
return # load_binary() requires a valid cuda context
return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
class CompiledKernel:
def __init__(self, fn_name, so_path, cache_dir):
def __init__(self, fn_name, so_path, cache_dir, device):
# initialize launcher
import importlib.util
spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1239,7 +1247,6 @@ class CompiledKernel:
with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
self.asm["ptx"] = f.read()
device = torch.cuda.current_device()
mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
self.cu_module = mod
self.cu_function = func