[FRONTEND] A couple of bugfixes (#534)
This commit is contained in:
@@ -63,7 +63,7 @@ def mangle_ty(ty):
|
|||||||
def mangle_fn(name, arg_tys, constants):
|
def mangle_fn(name, arg_tys, constants):
|
||||||
# doesn't mangle ret type, which must be a function of arg tys
|
# doesn't mangle ret type, which must be a function of arg tys
|
||||||
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
|
||||||
key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
|
key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
|
||||||
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
|
mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
|
||||||
mangled_constants = mangled_constants.replace('.', '_d_')
|
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||||
mangled_constants = mangled_constants.replace("'", '_sq_')
|
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||||
@@ -971,6 +971,10 @@ class Kernel:
|
|||||||
# assert arg.is_cuda, "All tensors must be on GPU!"
|
# assert arg.is_cuda, "All tensors must be on GPU!"
|
||||||
# set device (i.e., make sure torch has the context initialized)
|
# set device (i.e., make sure torch has the context initialized)
|
||||||
device = torch.cuda.current_device()
|
device = torch.cuda.current_device()
|
||||||
|
# torch creates new thread for backward pass that may have uninitialized context
|
||||||
|
# no way to know if this function should or shouldn't initialize the cuda context
|
||||||
|
# so we're being conservative here
|
||||||
|
torch.cuda.set_device(device)
|
||||||
if device not in self.cache_key:
|
if device not in self.cache_key:
|
||||||
cc = torch.cuda.get_device_capability(device)
|
cc = torch.cuda.get_device_capability(device)
|
||||||
cc = str(cc[0]) + '-' + str(cc[1])
|
cc = str(cc[0]) + '-' + str(cc[1])
|
||||||
|
Reference in New Issue
Block a user