[FRONTEND] Make triton.compile work without a cuda context (#708)
This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
|
||||
assert len(kernel_add.cache) == 1
|
||||
kernel_add.warmup(*args, grid=(1,))
|
||||
assert len(kernel_add.cache) == 1
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
@triton.jit
|
||||
def kernel_sub(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.store(o + idx,
|
||||
tl.load(a + idx) - tl.load(b + idx) * 777)
|
||||
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = namedtuple("instance_descriptor", [
|
||||
"divisible_by_16", "equal_to_1"])(
|
||||
tuple(range(4)),
|
||||
())
|
||||
|
||||
proc = multiprocessing.Process(
|
||||
target=triton.compile,
|
||||
kwargs=dict(
|
||||
fn=kernel_sub,
|
||||
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
|
||||
device=0,
|
||||
constants={3: 32},
|
||||
configs=[config],
|
||||
warm_cache_only=True,
|
||||
cc=cc,
|
||||
))
|
||||
proc.start()
|
||||
proc.join()
|
||||
assert proc.exitcode == 0
|
||||
|
Reference in New Issue
Block a user