[GENERAL] Removed deprecated driver files and added basic compatibility with ROCm (#268)

- Removed the driver module -- the accelerator runtime is now handled by PyTorch.
- Added basic ROCm support based on @micmelesse's PR -- an empty kernel can now be executed on AMD devices without any compile-time changes.
- PREFER_SHARED is now requested only for kernels that use more than 49KB of shared memory; below that threshold it can hurt L1 performance for broadcast tensors.
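
The PREFER_SHARED change from the summary is not visible in the hunks below. A minimal sketch of the gating it describes, with a hypothetical choose_cache_config helper -- only the 49KB threshold comes from the commit message, and the driver call that applies the preference is not shown:

# Hypothetical sketch of the PREFER_SHARED gating described above; the
# helper name and return values are illustrative, not the actual code.
PREFER_SHARED_THRESHOLD = 49 * 1024  # bytes, per the commit message

def choose_cache_config(shared_mem_bytes):
    # Large shared-memory kernels benefit from carving L1 into shared
    # memory; smaller ones keep the default split so broadcast tensors
    # retain good L1 hit rates.
    if shared_mem_bytes > PREFER_SHARED_THRESHOLD:
        return 'PREFER_SHARED'
    return 'PREFER_NONE'
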
@@ -411,9 +411,9 @@ class CodeGenerator(ast.NodeVisitor):
 
 
 class Binary:
-    def __init__(self, module, kernel, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm):
-        # cache ir asm
-        self.ir_asm = ir_asm
+    def __init__(self, backend, module, kernel, asm, num_warps, num_stages, force_nc_cache, shared_mem):
+        self.asm = asm
         self.module = module
         self.kernel = kernel
         self.shared_mem = shared_mem
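
The constructor now takes the backend handle and the final compiled artifacts up front instead of caching ir_asm for on-demand lowering. The call site in Kernel._compile (reproduced in a later hunk) constructs it as:

# Construction under the new signature, exactly as Kernel._compile
# does further down in this commit.
binary = Binary(backend, mod, ker, asm,
                num_warps, num_stages, force_nc_cache, shared_mem)
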
@@ -421,29 +421,13 @@ class Binary:
         self.num_stages = num_stages
         self.force_nc_cache = force_nc_cache
-        self.sass = None
-
-    def asm(self, mode):
-        if mode == 'ttir':
-            return self.ir_asm
-        if mode == 'ptx':
-            return self.module.ptx()
-        if mode == 'sass':
-            if self.sass is None:
-                cubin = self.module.cubin()
-                # get a temporary file name
-                fd, path = tempfile.mkstemp(suffix='.cubin')
-                f = open(path, 'wb')
-                f.write(cubin)
-                f.close()
-                # extract SASS from cubin
-                self.sass = extract(path, None)
-            return self.sass
-        if mode == 'llir':
-            return self.module.llir()
-        raise ValueError('Unsupported mode ' + mode)
+        self.backend = backend
 
     def __call__(self, stream, args, grid_0, grid_1=1, grid_2=1):
-        stream.enqueue(self.kernel, grid_0, grid_1, grid_2, self.num_warps * 32, 1, 1, args, self.shared_mem)
+        _triton.runtime.enqueue(self.backend, stream, self.kernel,
+                                grid_0, grid_1, grid_2,
+                                self.num_warps * 32, 1, 1,
+                                args, self.shared_mem)
 
 
 class CompilationError(Exception):
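
With asm(mode) gone, nothing is lowered lazily anymore: the tempfile/extract dance for SASS disappears together with the driver module, and whatever assembly the backend produced arrives through the asm constructor argument. A usage sketch, assuming asm is a mapping keyed by IR stage -- the hunk above only shows the argument being stored, so the keying is an assumption:

# Hypothetical inspection of a compiled Binary; that `asm` is a dict
# keyed by stage name is an assumption, not shown in this diff.
# `binary` is any Binary produced by the compile path in this commit.
ttir = binary.asm.get('ttir')   # Triton IR, if the backend recorded it
llir = binary.asm.get('llir')   # LLVM IR, likewise
print(ttir if ttir is not None else '<no ttir recorded>')
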
@@ -548,10 +532,15 @@ class Kernel:
                 raise e
             raise CompilationError(self.fn.src, node, e)
         # Compile to machine code
-        mod, ker, shared_mem, ir_asm = _triton.code_gen.add_passes_to_emit_bin(generator.module, device, num_warps, num_stages, force_nc_cache)
-        if shared_mem > device.max_shared_memory():
-            raise OutOfResources(shared_mem, device.max_shared_memory(), "shared memory")
-        return Binary(mod, ker, num_warps, num_stages, force_nc_cache, shared_mem, ir_asm)
+        if torch.version.hip is None:
+            backend = _triton.runtime.backend.CUDA
+        else:
+            backend = _triton.runtime.backend.ROCM
+        mod, ker, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages, force_nc_cache)
+        max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
+        if shared_mem > max_shared_memory:
+            raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
+        return Binary(backend, mod, ker, asm, num_warps, num_stages, force_nc_cache, shared_mem)
 
     def __call__(self, *wargs, grid, num_warps=4, num_stages=2, force_nc_cache=False, **meta):
         # device inference
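
Backend selection keys off the PyTorch build rather than any compile-time flag: ROCm builds of PyTorch set torch.version.hip to a version string, while CUDA builds leave it as None. The same check in standalone form:

import torch

# Standalone mirror of the dispatch above: torch.version.hip is a
# version string on ROCm builds of PyTorch and None on CUDA builds.
def detect_backend():
    return 'ROCM' if torch.version.hip is not None else 'CUDA'

print(detect_backend())
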
@@ -571,19 +560,20 @@ class Kernel:
                              " Only CUDA is supported at the moment")
 
         device = torch.device('cuda', torch.cuda.current_device())
-        tt_device = _triton.driver.cu_device(device.index, False)
-        if len(set(device_ids)) != 1 or device_ids[0] != device.index:
+        device_ty = device.type
+        device_idx = device.index
+        if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
             # try to enable P2P communication
             for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
-                if dst_idx != device.index:
+                if dst_idx != device_idx:
                     try:
-                        tt_device.enable_peer_access(wargs[arg_idx].data_ptr())
+                        _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
                     except RuntimeError as e:
                         raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
-                                           .format(device.index, dst_idx, str(e)))
+                                           .format(device_idx, dst_idx, str(e)))
 
         # enqueue kernel on the current device
-        torch.cuda.set_device(device.index)
+        torch.cuda.set_device(device_idx)
         # attributes
         args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
         attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) if isinstance(a, int)}
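
The P2P branch only fires when a tensor argument lives on a device other than the current one. A standalone illustration of the guarding condition, with illustrative tensors and the public torch.cuda.can_device_access_peer capability check:

import torch

# Reproduces the condition guarding the peer-access branch above.
def needs_peer_access(device_ids, current_idx):
    return len(set(device_ids)) != 1 or device_ids[0] != current_idx

if torch.cuda.device_count() >= 2:
    a = torch.ones(4, device='cuda:0')
    b = torch.ones(4, device='cuda:1')
    ids = [t.device.index for t in (a, b)]
    print(needs_peer_access(ids, torch.cuda.current_device()))
    print(torch.cuda.can_device_access_peer(0, 1))  # hardware capability
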
@@ -594,12 +584,12 @@ class Kernel:
         attr_key = frozenset(attributes.items())
         meta_key = frozenset(meta.items())
         const_key = frozenset(constants.items())
-        key = (device.type, device.index, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
+        key = (device_ty, device_idx, types_key, attr_key, num_warps, num_stages, meta_key, const_key)
         cache = self.fn.cache
         if key not in cache:
             # compile and cache configuration if necessary
             cache[key] = self._compile(
-                *wargs, device=tt_device, attributes=attributes,
+                *wargs, device=device_idx, attributes=attributes,
                 num_warps=num_warps, num_stages=num_stages, force_nc_cache=force_nc_cache,
                 constants=constants, **meta
             )
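
The frozenset packing in the key matters: dicts are unhashable and their iteration order must not influence cache lookups, so each dict is collapsed into an order-independent, hashable frozenset before going into the key tuple. A quick demonstration:

# Two logically identical attribute dicts produce the same cache key
# regardless of insertion order, because frozenset ignores order.
attrs_a = {0: 16, 1: 8}
attrs_b = {1: 8, 0: 16}
key_a = (frozenset(attrs_a.items()),)
key_b = (frozenset(attrs_b.items()),)
assert key_a == key_b and hash(key_a) == hash(key_b)
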
@@ -608,8 +598,7 @@ class Kernel:
         params = struct.pack(fmt, *args)
         # enqueue cached function into stream
         binary = cache[key]
-        cu_stream = torch.cuda.current_stream(device.index).cuda_stream
-        stream = _triton.driver.cu_stream(cu_stream, False)
+        stream = torch.cuda.current_stream(device_idx).cuda_stream
         grid = grid(meta) if hasattr(grid, '__call__') else grid
         binary(stream, params, *grid)
         return binary
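
With the driver wrappers removed, the launch path hands PyTorch's raw stream handle straight to _triton.runtime.enqueue; no _triton.driver.cu_stream object is constructed anymore. The handle itself is a plain integer exposed by the torch API:

import torch

# torch.cuda.current_stream(idx).cuda_stream is the underlying stream
# handle as an int -- the value Binary.__call__ now receives directly.
if torch.cuda.is_available():
    idx = torch.cuda.current_device()
    handle = torch.cuda.current_stream(idx).cuda_stream
    print(type(handle).__name__, hex(handle))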