From 5d4b26d380889f6a1993da45d8884f3c9fc7771e Mon Sep 17 00:00:00 2001 From: Felipe Petroski Such Date: Sun, 9 Oct 2022 20:30:04 -0700 Subject: [PATCH] [RUNTIME] support multiple devices in the same process (#757) --- python/triton/runtime/jit.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index bfa0edcc9..89ad3e2ca 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -7,7 +7,7 @@ import inspect import os import subprocess import textwrap -from collections import namedtuple +from collections import defaultdict, namedtuple import torch @@ -252,7 +252,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage if stream is None and not warmup: stream = get_cuda_stream(device) try: - bin = cache[key] + bin = cache[device][key] if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}) return bin @@ -271,12 +271,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage for i, arg in constants.items(): if callable(arg): raise TypeError(f"Callable constexpr at index {i} is not supported") - device = 0 if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs): bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs) if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args) - self.cache[key] = bin + self.cache[device][key] = bin return bin return None """ @@ -301,7 +300,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.src = textwrap.dedent(inspect.getsource(fn)) self.src = self.src[self.src.find("def"):] # cache of just-in-time compiled kernels - self.cache = dict() + self.cache = defaultdict(dict) self.hash = None # JITFunction can be instantiated as kernel # when called with a grid using __getitem__