[FRONTEND] Make triton.compile work without a cuda context (#708)

This allows compiling in a subprocess. I'm not seeing a ton of speedup from this, but figure it is a good change anyway.
This commit is contained in:
Jason Ansel
2022-09-24 13:41:47 -07:00
committed by GitHub
parent 3ac929b48b
commit 998fd5f9af
4 changed files with 53 additions and 14 deletions

View File

@@ -1,6 +1,8 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple
import pytest
import torch
@@ -172,3 +174,33 @@ def test_jit_warmup_cache() -> None:
assert len(kernel_add.cache) == 1
kernel_add.warmup(*args, grid=(1,))
assert len(kernel_add.cache) == 1
def test_compile_in_subproc() -> None:
    """Verify that triton.compile can run in a child process.

    The parent never touches the CUDA context beyond querying the device
    capability; the actual compilation happens inside a spawned
    multiprocessing.Process, and success is checked via its exit code.
    """

    @triton.jit
    def kernel_sub(a, b, o, N: tl.constexpr):
        idx = tl.arange(0, N)
        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)

    # Fold the device capability into a single integer, e.g. (8, 0) -> 80.
    major, minor = torch.cuda.get_device_capability(0)
    compute_capability = 10 * major + minor

    # Specialization descriptor: the first four arguments are assumed
    # divisible by 16, and none are specialized as equal to 1.
    instance_descriptor = namedtuple(
        "instance_descriptor", ["divisible_by_16", "equal_to_1"]
    )
    descriptor = instance_descriptor(tuple(range(4)), ())

    compile_kwargs = dict(
        fn=kernel_sub,
        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
        device=0,
        constants={3: 32},
        configs=[descriptor],
        warm_cache_only=True,
        cc=compute_capability,
    )
    worker = multiprocessing.Process(target=triton.compile, kwargs=compile_kwargs)
    worker.start()
    worker.join()
    # A zero exit code means the child compiled the kernel without error.
    assert worker.exitcode == 0