[RUNTIME] support multiple devices in the same process (#757)
This commit is contained in:
committed by
GitHub
parent
9a11a567ce
commit
5d4b26d380
@@ -7,7 +7,7 @@ import inspect
|
||||
import os
|
||||
import subprocess
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from collections import defaultdict, namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -252,7 +252,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
if stream is None and not warmup:
|
||||
stream = get_cuda_stream(device)
|
||||
try:
|
||||
bin = cache[key]
|
||||
bin = cache[device][key]
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
||||
return bin
|
||||
@@ -271,12 +271,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
for i, arg in constants.items():
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||
device = 0
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[key] = bin
|
||||
self.cache[device][key] = bin
|
||||
return bin
|
||||
return None
|
||||
"""
|
||||
@@ -301,7 +300,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||
self.src = self.src[self.src.find("def"):]
|
||||
# cache of just-in-time compiled kernels
|
||||
self.cache = dict()
|
||||
self.cache = defaultdict(dict)
|
||||
self.hash = None
|
||||
# JITFunction can be instantiated as kernel
|
||||
# when called with a grid using __getitem__
|
||||
|
Reference in New Issue
Block a user