[ROCM] enable matmul(dot) and others (#391)
This commit is contained in:
@@ -15,10 +15,11 @@ from setuptools.command.test import test as TestCommand
|
||||
import distutils.spawn
|
||||
import urllib.request
|
||||
import tarfile
|
||||
import torch
|
||||
|
||||
def get_llvm():
|
||||
# tries to find system LLVM
|
||||
versions = ['-11.0', '-11', '-11-64']
|
||||
versions = ['-13.0', '-13', '-13-64']
|
||||
supported = ['llvm-config{v}'.format(v=v) for v in versions]
|
||||
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
|
||||
paths = [p for p in paths if p is not None]
|
||||
@@ -27,7 +28,7 @@ def get_llvm():
|
||||
if platform.system() == "Windows":
|
||||
return '', ''
|
||||
# download if nothing is installed
|
||||
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
|
||||
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04'
|
||||
dir = '/tmp'
|
||||
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
|
||||
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
|
||||
@@ -36,7 +37,7 @@ def get_llvm():
|
||||
shutil.rmtree(os.path.join(dir, name))
|
||||
except:
|
||||
pass
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
|
||||
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name)
|
||||
print('downloading and extracting ' + url + '...')
|
||||
ftpstream = urllib.request.urlopen(url)
|
||||
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
|
||||
@@ -80,7 +81,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_extension(self, ext):
|
||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||
# self.debug = True
|
||||
self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
@@ -90,7 +91,10 @@ class CMakeBuild(build_ext):
|
||||
if not os.path.exists(llvm_build_dir):
|
||||
os.makedirs(llvm_build_dir)
|
||||
# python directories
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
if torch.version.hip is not None:
|
||||
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
|
||||
else:
|
||||
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
|
||||
cmake_args = [
|
||||
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
|
||||
"-DBUILD_TUTORIALS=OFF",
|
||||
@@ -117,6 +121,9 @@ class CMakeBuild(build_ext):
|
||||
build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]
|
||||
|
||||
env = os.environ.copy()
|
||||
|
||||
if torch.version.hip is not None:
|
||||
env["TRITON_USE_ROCM"] = "ON"
|
||||
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
|
||||
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
|
||||
|
||||
|
Reference in New Issue
Block a user