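[RUNTIME] support multiple devices in the same process (#757)
Summary: the JIT kernel cache is now keyed per device (`cache[device][key]`, backed by `defaultdict(dict)`), and the hard-coded `device = 0` in the compile-on-miss path is removed so compilation uses the device the caller passed.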
This commit is contained in:
committed by
GitHub
parent
9a11a567ce
commit
5d4b26d380
@@ -7,7 +7,7 @@ import inspect
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import textwrap
|
import textwrap
|
||||||
from collections import namedtuple
|
from collections import defaultdict, namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -252,7 +252,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
if stream is None and not warmup:
|
if stream is None and not warmup:
|
||||||
stream = get_cuda_stream(device)
|
stream = get_cuda_stream(device)
|
||||||
try:
|
try:
|
||||||
bin = cache[key]
|
bin = cache[device][key]
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
|
||||||
return bin
|
return bin
|
||||||
@@ -271,12 +271,11 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
for i, arg in constants.items():
|
for i, arg in constants.items():
|
||||||
if callable(arg):
|
if callable(arg):
|
||||||
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
raise TypeError(f"Callable constexpr at index {i} is not supported")
|
||||||
device = 0
|
|
||||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||||
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
bin = triton.compile(self, signature, device, constants, num_warps, num_stages, extern_libs=extern_libs, configs=configs)
|
||||||
if not warmup:
|
if not warmup:
|
||||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||||
self.cache[key] = bin
|
self.cache[device][key] = bin
|
||||||
return bin
|
return bin
|
||||||
return None
|
return None
|
||||||
"""
|
"""
|
||||||
@@ -301,7 +300,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
|||||||
self.src = textwrap.dedent(inspect.getsource(fn))
|
self.src = textwrap.dedent(inspect.getsource(fn))
|
||||||
self.src = self.src[self.src.find("def"):]
|
self.src = self.src[self.src.find("def"):]
|
||||||
# cache of just-in-time compiled kernels
|
# cache of just-in-time compiled kernels
|
||||||
self.cache = dict()
|
self.cache = defaultdict(dict)
|
||||||
self.hash = None
|
self.hash = None
|
||||||
# JITFunction can be instantiated as kernel
|
# JITFunction can be instantiated as kernel
|
||||||
# when called with a grid using __getitem__
|
# when called with a grid using __getitem__
|
||||||
|
Reference in New Issue
Block a user