[FRONTEND] A couple of bugfixes (#534)
This commit is contained in:
@@ -63,7 +63,7 @@ def mangle_ty(ty):
|
||||
def mangle_fn(name, arg_tys, constants):
|
||||
# doesn't mangle ret type, which must be a function of arg tys
|
||||
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
||||
key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
|
||||
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
|
||||
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
|
||||
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||
@@ -971,6 +971,10 @@ class Kernel:
|
||||
# assert arg.is_cuda, "All tensors must be on GPU!"
|
||||
# set device (i.e., make sure torch has the context initialized)
|
||||
device = torch.cuda.current_device()
|
||||
# torch creates new thread for backward pass that may have uninitlialized context
|
||||
# no way to know if this function should or shouldn't initialize the cuda context
|
||||
# so we're being conservative here
|
||||
torch.cuda.set_device(device)
|
||||
if device not in self.cache_key:
|
||||
cc = torch.cuda.get_device_capability(device)
|
||||
cc = str(cc[0]) + '-' + str(cc[1])
|
||||
|
Reference in New Issue
Block a user