@@ -963,23 +963,12 @@ def ptx_get_version(cuda_version) -> int:
|
||||
'''
|
||||
assert isinstance(cuda_version, str)
|
||||
major, minor = map(int, cuda_version.split('.'))
|
||||
version = major * 1000 + minor * 10
|
||||
if version >= 11040:
|
||||
return 74
|
||||
if version >= 11030:
|
||||
return 73
|
||||
if version >= 11020:
|
||||
return 72
|
||||
if version >= 11010:
|
||||
return 71
|
||||
if version >= 11000:
|
||||
return 70
|
||||
if version >= 10020:
|
||||
return 65
|
||||
if version >= 10010:
|
||||
return 64
|
||||
if version >= 10000:
|
||||
return 63
|
||||
if major == 12:
|
||||
return 80 + minor
|
||||
if major == 11:
|
||||
return 70 + minor
|
||||
if major == 10:
|
||||
return 63 + minor
|
||||
raise RuntimeError("Triton only support CUDA 10.0 or higher")
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user