[FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but I figure it is a good change anyway.
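Taken together, the new cc and warm_cache_only arguments let a parent process that already knows the device's compute capability hand compilation off to a child process that never creates a cuda context of its own. The snippet below is a minimal sketch of that flow, modeled on the test added in this commit; the kernel body, constant values, and names are illustrative rather than part of the change.

import multiprocessing
from collections import namedtuple

import torch

import triton
import triton.language as tl


@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


if __name__ == "__main__":
    # Query the compute capability in the parent so the child never has to.
    major, minor = torch.cuda.get_device_capability(0)

    # Stand-in for _triton.code_gen.instance_descriptor, as in the new test.
    config = namedtuple("instance_descriptor",
                        ["divisible_by_16", "equal_to_1"])(tuple(range(4)), ())

    proc = multiprocessing.Process(
        target=triton.compile,
        kwargs=dict(
            fn=kernel_add,
            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
            device=0,
            constants={3: 32},
            configs=[config],
            warm_cache_only=True,   # populate the on-disk cache, skip load_binary()
            cc=major * 10 + minor,  # lets codegen skip its cuGetInfo device queries
        ))
    proc.start()
    proc.join()
    assert proc.exitcode == 0

With warm_cache_only=True the child writes the cubin, ptx, and metadata into the cache directory and returns before constructing a CompiledKernel, so the parent can load the cached binary later under a live cuda context.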
@@ -239,14 +239,12 @@ std::string ptx_to_cubin(const std::string& ptx, const std::string& ptxas, int c
     unlink(_flog);
     throw std::runtime_error("Internal Triton PTX codegen error: \n" + log);
   }
-  CUmodule ret;
   std::ifstream _cubin(_fbin, std::ios::binary );
   std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
   _cubin.close();
   unlink(_fsrc);
   unlink(_flog);
   unlink(_fbin);
-  dispatch::cuModuleLoadData(&ret, cubin.c_str());
   return cubin;
 }
 
@@ -436,7 +436,7 @@ typedef std::map<std::string, py::object> asm_map_t;
 
 void init_triton_codegen(py::module &&m) {
   m.def("compile_ttir",
-        [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs) {
+        [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
           py::gil_scoped_release allow_threads;
           std::string name = ir.get_function_list()[0]->get_name();
           // record asm as we generate
@@ -454,10 +454,12 @@ void init_triton_codegen(py::module &&m) {
               name, triton::codegen::create_extern_lib(name, path));
           }
           // device properties
-          CUdevice dev = (CUdevice)device;
-          size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
-          size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
-          size_t cc = major*10 + minor;
+          if (cc == 0) {
+            CUdevice dev = (CUdevice)device;
+            size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
+            size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
+            cc = major*10 + minor;
+          }
           int version;
           std::string ptxas_path = drv::path_to_ptxas(version);
           // Triton-IR -> NVPTX LLVM-IR
@@ -1,6 +1,8 @@
+import multiprocessing
 import os
 import re
 import shutil
+from collections import namedtuple
 
 import pytest
 import torch
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
     assert len(kernel_add.cache) == 1
     kernel_add.warmup(*args, grid=(1,))
     assert len(kernel_add.cache) == 1
+
+
+def test_compile_in_subproc() -> None:
+    @triton.jit
+    def kernel_sub(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx,
+                 tl.load(a + idx) - tl.load(b + idx) * 777)
+
+    major, minor = torch.cuda.get_device_capability(0)
+    cc = major * 10 + minor
+    config = namedtuple("instance_descriptor", [
+        "divisible_by_16", "equal_to_1"])(
+        tuple(range(4)),
+        ())
+
+    proc = multiprocessing.Process(
+        target=triton.compile,
+        kwargs=dict(
+            fn=kernel_sub,
+            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
+            device=0,
+            constants={3: 32},
+            configs=[config],
+            warm_cache_only=True,
+            cc=cc,
+        ))
+    proc.start()
+    proc.join()
+    assert proc.exitcode == 0
@@ -880,7 +880,10 @@ def ptx_get_kernel_name(ptx: str) -> str:
             return line.split()[-1]
 
 
-def _compile(fn, signature: str, device: int = -1, constants=dict(), specialization=_triton.code_gen.instance_descriptor(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, output: str = "ttgir") -> Tuple[str, int, str]:
+def _compile(fn, signature: str, device: int = -1, constants=dict(),
+             specialization=_triton.code_gen.instance_descriptor(),
+             num_warps: int = 4, num_stages: int = 3, extern_libs=None,
+             output: str = "ttgir", cc=0) -> Tuple[str, int, str]:
     valid_outputs = ("ttir", "ttgir", "ptx", "cubin")
     assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
 
@@ -894,7 +897,7 @@ def _compile(fn, signature: str, device: int = -1, constants=dict(), specializat
     backend = _triton.runtime.backend.CUDA
     if extern_libs is None:
         extern_libs = dict()
-    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs)
+    name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, module, device, num_warps, num_stages, extern_libs, cc)
     return asm, shared_mem, name
 
 
@@ -1179,7 +1182,8 @@ def make_fn_cache_key(fn_hash, signature, configs, constants, num_warps, num_sta
     return key
 
 
-def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
+def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4,
+            num_stages: int = 3, extern_libs=None, configs=None, cc=0, warm_cache_only=False):
     # we get the kernel, i.e. the first function generated in the module
     assert len(configs) == 1
     # cache manager
@@ -1208,18 +1212,22 @@ def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: i
     if not fn_cache_manager.has_file(cubin_name) or \
        not fn_cache_manager.has_file(data_name) or \
        not fn_cache_manager.has_file(ptx_name):
-        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages, extern_libs, "cubin")
+        asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
+                                            extern_libs, "cubin", cc)
         metadata = {"name": kernel_name, "shared": shared, "num_warps": num_warps, "num_stages": num_stages}
         fn_cache_manager.put(asm["cubin"], cubin_name)
         fn_cache_manager.put(asm["ptx"], ptx_name, binary=False)
         fn_cache_manager.put(json.dumps(metadata), data_name, binary=False)
 
-    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir)
+    if warm_cache_only:
+        return  # load_binary() requires a valid cuda context
+
+    return CompiledKernel(name, so_cache_manager._make_path(so_name), fn_cache_manager.cache_dir, device)
 
 
 class CompiledKernel:
 
-    def __init__(self, fn_name, so_path, cache_dir):
+    def __init__(self, fn_name, so_path, cache_dir, device):
         # initialize launcher
         import importlib.util
         spec = importlib.util.spec_from_file_location("launcher", so_path)
@@ -1239,7 +1247,6 @@ class CompiledKernel:
         with open(os.path.join(cache_dir, f"{fn_name}.ptx"), "r") as f:
             self.asm["ptx"] = f.read()
 
-        device = torch.cuda.current_device()
         mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
         self.cu_module = mod
         self.cu_function = func