[FRONTEND] A couple of bugfixes (#534)

Philippe Tillet
2022-06-02 16:57:37 -07:00
committed by GitHub
parent 3e7500dfe6
commit efa04cac1f


@@ -63,7 +63,7 @@ def mangle_ty(ty):
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
-    key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
+    key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
     mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')
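
For context, a minimal runnable sketch of what the constant mangling in mangle_fn does after this change (the JITFunction stub and mangle_constants helper below are hypothetical stand-ins, not the Triton source): a JITFunction passed as a constant is now keyed by its stable __name__ rather than by cache_key.

    class JITFunction:
        # hypothetical stub standing in for Triton's JITFunction
        def __init__(self, fn):
            self.fn = fn
            self.__name__ = fn.__name__

    def mangle_constants(constants):
        # key a JITFunction constant by its stable __name__; everything else by repr()
        key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
        mangled = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
        # escape characters that are not valid in identifiers
        mangled = mangled.replace('.', '_d_')
        mangled = mangled.replace("'", '_sq_')
        return mangled

    def helper():
        pass

    # floats get their '.' escaped; the nested function contributes its name:
    # {0: 3.5, 1: helper} -> '0c3_d_5_1chelper'
    print(mangle_constants({0: 3.5, 1: JITFunction(helper)}))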
@@ -971,6 +971,10 @@ class Kernel:
         # assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
+        # torch creates a new thread for the backward pass that may have an uninitialized context;
+        # there is no way to know whether this function should or shouldn't initialize the cuda context,
+        # so we are being conservative here
+        torch.cuda.set_device(device)
         if device not in self.cache_key:
             cc = torch.cuda.get_device_capability(device)
             cc = str(cc[0]) + '-' + str(cc[1])
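
For the second hunk, a minimal sketch of the scenario being guarded against (assumes a CUDA-capable machine; launch_prologue is a hypothetical name, not Triton API): autograd may invoke the launch path from a fresh worker thread, so the code conservatively calls torch.cuda.set_device to ensure that thread has an initialized CUDA context before launching.

    import threading
    import torch

    def launch_prologue():
        # hypothetical stand-in for the start of Kernel.__call__
        device = torch.cuda.current_device()
        # torch may run backward() on a fresh thread; conservatively make sure
        # this thread's CUDA context is initialized before any raw kernel launch
        torch.cuda.set_device(device)
        return device

    # mimic the worker thread torch spawns for the backward pass
    t = threading.Thread(target=lambda: print('context ready on device', launch_prologue()))
    t.start()
    t.join()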