[CODEGEN] Reverted to old launch method (memory leak?)

Philippe Tillet
2021-11-16 01:21:03 -08:00
parent 791b953b21
commit 5b7ba3eb96


@@ -493,6 +493,184 @@ class OutOfResources(Exception):
super().__init__(self.message)
# class Kernel:
# @staticmethod
# def _type_name(obj):
# type_names = {
# triton.language.float8: 'f8',
# torch.bfloat16: 'bf16',
# torch.float16: 'f16',
# torch.float32: 'f32',
# torch.float64: 'f64',
# torch.bool: 'i1',
# torch.int8: 'i8',
# torch.int16: 'i16',
# torch.int32: 'i32',
# torch.int64: 'i64',
# }
# if hasattr(obj, 'data_ptr'):
# return type_names[obj.dtype]
# if isinstance(obj, triton.language.core.constexpr):
# obj = obj.value
# if isinstance(obj, int):
# if abs(obj) <= 0xffffffff:
# return 'I'
# return 'L'
# if isinstance(obj, float):
# return 'f'
# if isinstance(obj, bool):
# return 'B'
# if isinstance(obj, str):
# return 'str'
# assert False
# @staticmethod
# def _to_triton_ir(context, obj):
# type_map = {
# 'I': _triton.ir.type.get_int32,
# 'L': _triton.ir.type.get_int64,
# 'f': _triton.ir.type.get_fp32,
# 'B': _triton.ir.type.get_int1,
# 'f8': _triton.ir.type.get_fp8,
# 'f16': _triton.ir.type.get_fp16,
# 'bf16': _triton.ir.type.get_bf16,
# 'f32': _triton.ir.type.get_fp32,
# 'f64': _triton.ir.type.get_fp64,
# 'i1': _triton.ir.type.get_int1,
# 'i8': _triton.ir.type.get_int8,
# 'i16': _triton.ir.type.get_int16,
# 'i32': _triton.ir.type.get_int32,
# 'i64': _triton.ir.type.get_int64,
# }
# # convert torch.Tensor to Triton IR pointers
# if hasattr(obj, 'data_ptr'):
# name = Kernel._type_name(obj)
# elt_ty = type_map[name](context)
# return _triton.ir.type.make_ptr(elt_ty, 1)
# # default path returns triton.ir.type directly
# name = Kernel._type_name(obj)
# return type_map[name](context)
# @staticmethod
# def pow2_divisor(N):
# if N % 16 == 0: return 16
# if N % 8 == 0: return 8
# if N % 4 == 0: return 4
# if N % 2 == 0: return 2
# return 1
# def __init__(self, fn):
# self.fn = fn
# def _compile(self, *wargs, device, attributes, constants, num_warps, num_stages):
# wargs = [arg for arg in wargs if not isinstance(arg, triton.language.constexpr)]
# # create IR module
# context = _triton.ir.context()
# # get just-in-time proto-type of kernel
# arg_types = [Kernel._to_triton_ir(context, arg) for arg in wargs]
# ret_type = _triton.ir.type.get_void(context)
# prototype = _triton.ir.type.make_function(ret_type, arg_types)
# # generate Triton-IR
# # export symbols visible from self.fn into code-generator object
# gscope = sys.modules[self.fn.module].__dict__
# generator = CodeGenerator(context, prototype, gscope=gscope, attributes=attributes, constants=constants, kwargs=dict())
# try:
# generator.visit(self.fn.parse())
# except Exception as e:
# node = generator.last_node
# if node is None or isinstance(e, (NotImplementedError, CompilationError)):
# raise e
# raise CompilationError(self.fn.src, node, e)
# # Compile to machine code
# if torch.version.hip is None:
# backend = _triton.runtime.backend.CUDA
# else:
# backend = _triton.runtime.backend.ROCM
# name, asm, shared_mem = _triton.code_gen.compile_ttir(backend, generator.module, device, num_warps, num_stages)
# max_shared_memory = _triton.runtime.max_shared_memory(backend, device)
# if shared_mem > max_shared_memory:
# raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
# return Binary(backend, name, asm, shared_mem, num_warps)
# def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
# tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# # attributes
# args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
# attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
# if isinstance(a, int) and i not in self.fn.do_not_specialize}
# # transforms ints whose value is one into constants for just-in-time compilation
# constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
# constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
# hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# # create cache directory
# cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
# if cache_dir and not os.path.exists(cache_dir):
# os.makedirs(cache_dir, exist_ok=True)
# if cache_dir:
# bin_cache_path = os.path.join(cache_dir, hashed_key)
# bin_lock_path = bin_cache_path + ".lock"
# else:
# bin_cache_path = None
# bin_lock_path = None
# binary = None
# if bin_cache_path and os.path.exists(bin_cache_path):
# assert bin_lock_path is not None
# with FileLock(bin_lock_path):
# with open(bin_cache_path, 'rb') as f:
# binary = pickle.load(f)["binary"]
# if binary is None:
# binary = self._compile(
# *wargs, device=device_idx, attributes=attributes,
# num_warps=num_warps, num_stages=num_stages,
# constants=constants,
# )
# if bin_cache_path:
# assert bin_lock_path is not None
# with FileLock(bin_lock_path):
# with open(bin_cache_path + ".tmp", "wb") as f:
# pickle.dump({"binary": binary, "key": key}, f)
# os.rename(bin_cache_path + ".tmp", bin_cache_path)
# if JITFunction.cache_hook is not None:
# JITFunction.cache_hook(key=key, binary=binary)
# self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
# def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# # handle arguments passed by name
# kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
# wargs = list(wargs)
# for i, pos in enumerate(sorted(kwargs)):
# wargs.insert(pos + i, kwargs[pos])
# if len(wargs) != len(self.fn.arg_names):
# raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# # handle annotations
# for pos, _type in self.fn.annotations.items():
# wargs[pos] = _type(wargs[pos])
# # query device index and cuda stream
# device = torch.cuda.current_device()
# torch.cuda.set_device(device)
# cc = torch.cuda.get_device_capability(device)
# cc = str(cc[0]) + '-' + str(cc[1])
# # # query stream
# # # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
# # # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
# # # building a C wrapper to re-use the unpack function would add a build-time torch dependency
# # # and require different wheels for different torch versions -- undesirable!
# # bits = torch._C._cuda_getCurrentStream(device)
# # mask = 1 << 47
# # stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
# stream = torch.cuda.current_stream(device).cuda_stream
# # make key for cache
# return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
# self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
class Kernel:
@staticmethod
def _type_name(obj):
@@ -553,6 +731,16 @@ class Kernel:
name = Kernel._type_name(obj)
return type_map[name](context)
@staticmethod
def _types_key(*wargs, tensor_idxs):
# type inference
types_key = [None] * len(wargs)
for i, arg in enumerate(wargs):
prefix = 'P' if i in tensor_idxs else ''
suffix = Kernel._type_name(arg) if i in tensor_idxs else Kernel._type_name(arg)
types_key[i] = prefix + suffix
return tuple(types_key)
@staticmethod
def pow2_divisor(N):
if N % 16 == 0: return 16
@@ -594,53 +782,6 @@ class Kernel:
raise OutOfResources(shared_mem, max_shared_memory, "shared memory")
return Binary(backend, name, asm, shared_mem, num_warps)
def add_to_cache(self, key, wargs, device_idx, num_warps, num_stages):
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
if cache_dir and not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
if cache_dir:
bin_cache_path = os.path.join(cache_dir, hashed_key)
bin_lock_path = bin_cache_path + ".lock"
else:
bin_cache_path = None
bin_lock_path = None
binary = None
if bin_cache_path and os.path.exists(bin_cache_path):
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path, 'rb') as f:
binary = pickle.load(f)["binary"]
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path + ".tmp", "wb") as f:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary)
self.fn.bin_cache[key] = LoadedBinary(device_idx, binary)
def __call__(self, *wargs, grid, num_warps=4, num_stages=2, **kwargs):
# handle arguments passed by name
kwargs = {self.fn.arg_names.index(name): value for name, value in kwargs.items()}
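For reference, the add_to_cache method removed above guards the on-disk binary cache with a file lock and publishes entries via an atomic rename. A minimal, self-contained sketch of that pattern (hypothetical helper name cache_binary; assumes the third-party filelock package, which this launcher already uses as FileLock):

import hashlib
import os
import pickle
from filelock import FileLock  # hypothetical standalone use of the same dependency

def cache_binary(cache_dir, key, binary):
    # Hash the human-readable cache key into a stable file name.
    hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, hashed_key)
    # Serialize under a lock, then publish with os.rename so concurrent
    # processes never observe a half-written cache file.
    with FileLock(path + ".lock"):
        with open(path + ".tmp", "wb") as f:
            pickle.dump({"binary": binary, "key": key}, f)
        os.rename(path + ".tmp", path)

For example, cache_binary('/tmp/triton/', 'some-key', b'cubin bytes') writes one entry that a later process can read back under the same lock.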
@@ -650,25 +791,112 @@ class Kernel:
if len(wargs) != len(self.fn.arg_names):
raise TypeError(f"Function takes {len(self.fn.arg_names)} positional arguments but {len(wargs)} were given")
# handle annotations
for pos, _type in self.fn.annotations.items():
wargs[pos] = _type(wargs[pos])
# query device index and cuda stream
device = torch.cuda.current_device()
torch.cuda.set_device(device)
cc = torch.cuda.get_device_capability(device)
cc = str(cc[0]) + '-' + str(cc[1])
# # query stream
# # this is hacky but much faster than `torch.cuda.current_stream(device).cuda_stream`
# # https://github.com/pytorch/pytorch/blob/master/c10/core/Stream.h#L154
# # building a C wrapper to re-use the unpack function would add a build-time torch dependency
# # and require different wheels for different torch versions -- undesirable!
# bits = torch._C._cuda_getCurrentStream(device)
# mask = 1 << 47
# stream = ((bits & 0xFFFFFFFFFFFF) ^ mask) - mask
stream = torch.cuda.current_stream(device).cuda_stream
# make key for cache
return _triton.runtime.launch(wargs, self.fn.do_not_specialize, self.fn.cache_key + cc, self.fn.arg_names, device, stream,
self.fn.bin_cache, num_warps, num_stages, self.add_to_cache, grid)
for name, type in self.fn.__annotations__.items():
pos = self.fn.arg_names.index(name)
assert type == triton.language.core.constexpr
wargs[pos] = type(wargs[pos])
# device inference
tensor_idxs = [i for i, arg in enumerate(wargs) if hasattr(arg, 'data_ptr')]
if len(tensor_idxs) == 0:
raise ValueError("No Tensor argument found.")
invalid_args = []
device_ids = []
for idx in tensor_idxs:
curr = wargs[idx]
if not curr.is_cuda:
invalid_args.append(idx)
else:
device_ids.append(curr.device.index)
if invalid_args:
raise ValueError("Arguments at index {invalid_args} are on the wrong device.".format(invalid_args=invalid_args) +
" Only CUDA is supported at the moment")
device = torch.device('cuda', torch.cuda.current_device())
device_idx = device.index
# if len(set(device_ids)) != 1 or device_ids[0] != device_idx:
# # try to enable P2P communication
# for arg_idx, dst_idx in zip(tensor_idxs, device_ids):
# if dst_idx != device_idx:
# try:
# _triton.runtime.enable_peer_access(self.backend, wargs[arg_idx].data_ptr())
# except RuntimeError as e:
# raise RuntimeError("Cannot enable P2P access from device {} to device {}: {}"
# .format(device_idx, dst_idx, str(e)))
# enqueue kernel on the current device
torch.cuda.set_device(device_idx)
# attributes
args = [arg.data_ptr() if i in tensor_idxs else arg for i, arg in enumerate(wargs)]
attributes = {i: Kernel.pow2_divisor(a) for i, a in enumerate(args) \
if isinstance(a, int) and i not in self.fn.do_not_specialize}
# transforms ints whose value is one into constants for just-in-time compilation
constants = {i: arg for i, arg in enumerate(wargs) if isinstance(arg, int) and arg == 1}
constants.update({i: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.constexpr)})
# compute hash for caching this kernel
types_key = Kernel._types_key(*wargs, tensor_idxs=tensor_idxs)
attr_key = tuple(attributes.items())
const_key = tuple(constants.items())
compute_capability = torch.cuda.get_device_capability(device)
key = (
self.fn.cache_key, version_key(), compute_capability,
types_key, attr_key, num_warps, num_stages, const_key
)
key = repr(key)
# get cached binary
bin_cache = self.fn.bin_cache
if key not in bin_cache:
hashed_key = hashlib.md5(key.encode("utf-8")).hexdigest()
# create cache directory
cache_dir = os.environ.get('TRITON_CACHE_DIR', '/tmp/triton/')
if cache_dir and not os.path.exists(cache_dir):
os.makedirs(cache_dir, exist_ok=True)
if cache_dir:
bin_cache_path = os.path.join(cache_dir, hashed_key)
bin_lock_path = bin_cache_path + ".lock"
else:
bin_cache_path = None
bin_lock_path = None
binary = None
if bin_cache_path and os.path.exists(bin_cache_path):
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path, 'rb') as f:
binary = pickle.load(f)["binary"]
if binary is None:
binary = self._compile(
*wargs, device=device_idx, attributes=attributes,
num_warps=num_warps, num_stages=num_stages,
constants=constants,
)
if bin_cache_path:
assert bin_lock_path is not None
with FileLock(bin_lock_path):
with open(bin_cache_path + ".tmp", "wb") as f:
pickle.dump({"binary": binary, "key": key}, f)
os.rename(bin_cache_path + ".tmp", bin_cache_path)
if JITFunction.cache_hook is not None:
JITFunction.cache_hook(key=key, binary=binary)
bin_cache[key] = LoadedBinary(device_idx, binary)
# pack arguments
fmt = ''.join(['P' if i in tensor_idxs else Kernel._type_name(arg) for i, arg in enumerate(wargs) if not isinstance(arg, triton.language.core.constexpr)])
params = struct.pack(fmt, *[arg for arg in args if not isinstance(arg, triton.language.core.constexpr)])
# enqueue cached function into stream
callable = bin_cache[key]
stream = torch.cuda.current_stream(device_idx).cuda_stream
csts = {self.fn.arg_names[i]: arg.value for i, arg in enumerate(wargs) if isinstance(arg, triton.language.core.constexpr)}
grid = grid(csts) if hasattr(grid, '__call__') else grid
if isinstance(grid, int):
grid = tuple(grid)
callable(stream, params, *grid)
return callable
class Launcher:
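The restored Python-side launch path marshals kernel arguments itself: non-constexpr arguments are encoded into a bytes buffer with struct.pack, using 'P' for tensor data pointers and the single-character codes returned by Kernel._type_name for scalars. A minimal, self-contained sketch of that step (hypothetical argument values; the real format string is derived per call from the wargs list):

import struct

# Hypothetical launch arguments: two tensor pointers, an int32 count, a float scale.
# 'P' = pointer-sized field (tensor.data_ptr()), 'I' = 32-bit int, 'f' = float32;
# 'L' and 'B' cover 64-bit ints and bools in the same scheme.
fmt = 'PPIf'
params = struct.pack(fmt, 0x7F0000000000, 0x7F0000001000, 1024, 0.5)
print(struct.calcsize(fmt))  # 24 on a typical 64-bit build (8 + 8 + 4 + 4)

The resulting params buffer, together with the CUDA stream and the launch grid, is what callable(stream, params, *grid) receives at the end of __call__.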