From efa04cac1ff746701dd1087aca4dae418473413b Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Thu, 2 Jun 2022 16:57:37 -0700
Subject: [PATCH] [FRONTEND] A couple of bugfixes (#534)

---
 python/triton/code_gen.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 4f0c75f8d..b64c7eb86 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -63,7 +63,7 @@ def mangle_ty(ty):
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
-    key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
+    key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
     mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')
@@ -971,6 +971,10 @@ class Kernel:
         # assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
+        # torch creates a new thread for the backward pass that may have an uninitialized context
+        # there is no way to know whether this function should or shouldn't initialize the cuda context,
+        # so we're being conservative here
+        torch.cuda.set_device(device)
         if device not in self.cache_key:
             cc = torch.cuda.get_device_capability(device)
             cc = str(cc[0]) + '-' + str(cc[1])
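
Note (not part of the patch): a minimal sketch of the situation the new
torch.cuda.set_device(device) call guards against. PyTorch's autograd engine
runs backward() on its own per-device worker thread, and a kernel launched
through the CUDA driver API (as Triton's generated kernels are) needs a
context that is current on that thread. The Scale function below is
hypothetical and stands in for any op that launches a Triton kernel from
backward.

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        # autograd runs this on a worker thread, not the main thread;
        # before the patch, a driver-API kernel launch here could fail
        # with an uninitialized-context error
        device = torch.cuda.current_device()
        torch.cuda.set_device(device)  # binds/initializes the context on this thread
        return grad * 2                # stand-in for a Triton kernel launch

x = torch.ones(4, device='cuda', requires_grad=True)
Scale.apply(x).sum().backward()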