[RUNTIME] support multiple devices in the same process (#757)

This commit is contained in:
Felipe Petroski Such
2022-10-09 20:30:04 -07:00
committed by GitHub
parent 9a11a567ce
commit 5d4b26d380

View File

@@ -7,7 +7,7 @@ import inspect
 import os
 import subprocess
 import textwrap
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 import torch
@@ -252,7 +252,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
 if stream is None and not warmup:
 stream = get_cuda_stream(device)
 try:
-bin = cache[key]
+bin = cache[device][key]
 if not warmup:
 bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
 return bin
@@ -271,12 +271,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
 for i, arg in constants.items():
 if callable(arg):
 raise TypeError(f"Callable constexpr at index {i} is not supported")
-device = 0
 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
 bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
 if not warmup:
 bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
-self.cache[key] = bin
+self.cache[device][key] = bin
 return bin
 return None
 """
@@ -301,7 +300,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
 self.src = textwrap.dedent(inspect.getsource(fn))
 self.src = self.src[self.src.find("def"):]
 # cache of just-in-time compiled kernels
-self.cache = dict()
+self.cache = defaultdict(dict)
 self.hash = None
 # JITFunction can be instantiated as kernel
 # when called with a grid using __getitem__