[ROCM] enable matmul(dot) and others (#391)

This commit is contained in:
Michael Melesse
2021-12-13 12:28:15 -08:00
committed by GitHub
parent 73b04d71b2
commit 94d5c2e8b5
12 changed files with 251 additions and 52 deletions

View File

@@ -15,10 +15,11 @@ from setuptools.command.test import test as TestCommand
import distutils.spawn
import urllib.request
import tarfile
import torch
def get_llvm():
# tries to find system LLVM
versions = ['-11.0', '-11', '-11-64']
versions = ['-13.0', '-13', '-13-64']
supported = ['llvm-config{v}'.format(v=v) for v in versions]
paths = [distutils.spawn.find_executable(cfg) for cfg in supported]
paths = [p for p in paths if p is not None]
@@ -27,7 +28,7 @@ def get_llvm():
if platform.system() == "Windows":
return '', ''
# download if nothing is installed
name = 'clang+llvm-11.0.1-x86_64-linux-gnu-ubuntu-16.04'
name = 'clang+llvm-13.0.0-x86_64-linux-gnu-ubuntu-16.04'
dir = '/tmp'
llvm_include_dir = '{dir}/{name}/include'.format(dir=dir, name=name)
llvm_library_dir = '{dir}/{name}/lib'.format(dir=dir, name=name)
@@ -36,7 +37,7 @@ def get_llvm():
shutil.rmtree(os.path.join(dir, name))
except:
pass
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-11.0.1/{name}.tar.xz".format(name=name)
url = "https://github.com/llvm/llvm-project/releases/download/llvmorg-13.0.0/{name}.tar.xz".format(name=name)
print('downloading and extracting ' + url + '...')
ftpstream = urllib.request.urlopen(url)
file = tarfile.open(fileobj=ftpstream, mode="r|xz")
@@ -80,7 +81,7 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm()
# self.debug = True
self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'
@@ -90,7 +91,10 @@ class CMakeBuild(build_ext):
if not os.path.exists(llvm_build_dir):
os.makedirs(llvm_build_dir)
# python directories
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
if torch.version.hip is not None:
python_include_dirs= [distutils.sysconfig.get_python_inc()] +['/opt/rocm/include']
else:
python_include_dirs = [distutils.sysconfig.get_python_inc()] + ['/usr/local/cuda/include']
cmake_args = [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
"-DBUILD_TUTORIALS=OFF",
@@ -117,6 +121,9 @@ class CMakeBuild(build_ext):
build_args += ["--", '-j' + str(2 * multiprocessing.cpu_count())]
env = os.environ.copy()
if torch.version.hip is not None:
env["TRITON_USE_ROCM"] = "ON"
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=self.build_temp, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)

View File

@@ -45,6 +45,35 @@ def test_empty_kernel(dtype_x, device='cuda'):
x = triton.testing.random(SIZE, dtype=cvt[dtype_x], device=device)
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
# ---------------
# test load and store op
# ---------------
@pytest.mark.parametrize("dtype,size", [
(dtype, size)
for dtype in dtypes
for size in [128, 256, 512, 1024, 2048, 4096]
])
def test_load_and_store_op(dtype, size, device='cuda'):
SIZE = size
# define the kernel / launch-grid
@triton.jit
def kernel(Z, X, **meta):
off = tl.arange(0, meta['SIZE'])
x = tl.load(X + off)
tl.store(Z + off, x)
# inputs
x = triton.testing.random(SIZE, dtype=cvt[dtype], device=device)
# output tensors
z_ref = x.clone() # reference result
z_tri = torch.empty_like(x) # triton result
# run load and store kernel
kernel[(1, )](z_tri, x, SIZE=SIZE, num_warps=4)
# compare
triton.testing.assert_almost_equal(z_ref, z_tri)
# generic test functions
def _test_unary(dtype_x, expr, torch_expr=None, device='cuda'):
SIZE = 128
@@ -340,18 +369,23 @@ def test_atomic_rmw(op, dtype_x, mode, device='cuda'):
('float32', 'int32', True)
])
def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
x = torch.tensor([43.5], dtype=cvt[dtype_x], device=device)
if torch.version.hip is not None:
assert 'bfloat' not in dtype_x
assert 'bfloat' not in dtype_z
SIZE = 1024
x = triton.testing.random((SIZE, ), dtype=cvt[dtype_x], device=device)
# triton kernel
@triton.jit
def kernel(X, Z, **meta):
x = tl.load(X)
off = tl.arange(0, meta['SIZE'])
x = tl.load(X+ off)
z = x.to(Z.dtype.element_ty, bitcast=meta['BITCAST'])
tl.store(Z, z)
tl.store(Z+ off, z)
# triton result
z_tri = torch.empty((1, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, BITCAST=bitcast)
z_tri = torch.empty((SIZE, ), dtype=cvt[dtype_z], device=device)
kernel[(1, )](x, z_tri, SIZE=SIZE, BITCAST=bitcast)
# torch result
if bitcast:
import numpy as np
@@ -359,7 +393,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
z_ref = torch.from_numpy(z_ref).to(device)
else:
z_ref = x.to(z_tri.dtype)
assert z_tri == z_ref
triton.testing.assert_almost_equal(z_ref, z_tri)
# ---------------
# test reduce
@@ -448,17 +482,23 @@ def test_permute(dtype, shape, perm, device='cuda'):
z_ref = x.permute(*perm).contiguous()
# compare
triton.testing.assert_almost_equal(z_tri, z_ref)
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if torch.version.hip is None:
# parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# ---------------
# test dot
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
@pytest.mark.parametrize("dtype, epilogue", [(dtype, epilogue)\
for dtype in ['float16','float32'] \
for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols']])
def test_dot(dtype, epilogue, device='cuda'):
dtype = cvt[dtype]
torch.manual_seed(0)
# triton kernel
@triton.jit
@@ -486,10 +526,10 @@ def test_dot(epilogue, device='cuda'):
tl.store(Zs, z)
# input
M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
x = triton.testing.random((M, K), dtype=dtype, device=device)
y = triton.testing.random((K, N), dtype=dtype, device=device)
# triton result
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
z = triton.testing.random((M, N), dtype=dtype, device=device)
z_tri = z.clone()
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1),
@@ -508,12 +548,14 @@ def test_dot(epilogue, device='cuda'):
z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16)
# compare
ptx = pgm.asm['ptx']
# print(ptx)
triton.testing.assert_almost_equal(z_tri, z_ref)
# make sure ld/st are vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# print(ptx)
if torch.version.hip is None:
ptx = pgm.asm['ptx']
# make sure ld/st are vectorized
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
@triton.jit
@@ -611,17 +653,18 @@ def test_load_cache_modifier(cache):
tl.store(dst+offsets, x)
pgm = _kernel[(1,)](dst, src, CACHE=cache)
ptx = pgm.asm['ptx']
if torch.version.hip is None:
ptx = pgm.asm['ptx']
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
if cache == '':
assert 'ld.global.ca' not in ptx
assert 'ld.global.cg' not in ptx
if cache == '.cg':
assert 'ld.global.cg' in ptx
assert 'ld.global.ca' not in ptx
if cache == '.ca':
assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx
# ---------------
# test store
@@ -647,4 +690,4 @@ def test_noop(device='cuda'):
def kernel(**meta):
pass
x = triton.testing.random((1,), dtype=torch.int32, device=device)
kernel[(1, )](x)
kernel[(1, )](x)