[BACKEND] Various bug fixes; making reductions faster (#533)

commit 3e7500dfe6 (parent 37037bb3be)
Author: Philippe Tillet
Date:   2022-05-31 17:14:44 -07:00
Committed by: GitHub
12 changed files with 174 additions and 66 deletions

View File

@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):
     def build_extension(self, ext):
         llvm_include_dir, llvm_library_dir = get_llvm()
-        # self.debug = True
+        self.debug = True
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
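For context, `self.debug` drives both the build directory suffix shown above and, typically, the CMake build type. A condensed sketch of that effect (everything except the `build_suffix` line is an assumption about the surrounding setup.py code, not part of this diff):

    import os

    debug = True                                     # now forced on by this change
    build_suffix = 'debug' if debug else 'release'   # from the hunk above
    cfg = 'Debug' if debug else 'Release'            # assumed mapping to CMake
    cmake_args = [f'-DCMAKE_BUILD_TYPE={cfg}']       # assumed CMake flag
    build_dir = os.path.abspath(os.path.join('build', f'cmake.{build_suffix}'))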

View File

@@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
rs = RandomState(17)
x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
x[:] = 1
# numpy result
z_ref = np.sum(x).astype(getattr(np, dtype_str))
# triton result
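Worked example (illustration only, not part of the commit): overwriting the random input with ones makes the reference sum deterministic, equal to the element count, so floating-point rounding of random data drops out of the comparison:

    import numpy as np

    x = np.empty((128,), dtype=np.float32)
    x[:] = 1
    assert np.sum(x) == 128.0   # expected reduction value is exactly the length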
@@ -1132,3 +1133,25 @@ def test_constexpr_shape():
     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
     kernel[(1,)](x_tri)
     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+
+
+# -------------
+# test if
+# -------------
+
+
+def test_if():
+    @triton.jit
+    def kernel(Cond, XTrue, XFalse, Ret):
+        pid = tl.program_id(0)
+        cond = tl.load(Cond)
+        if pid % 2:
+            tl.store(Ret, tl.load(XTrue))
+        else:
+            tl.store(Ret, tl.load(XFalse))
+
+    cond = torch.ones(1, dtype=torch.int32, device='cuda')
+    x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
+    x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
+    ret = torch.empty(1, dtype=torch.float32, device='cuda')
+    kernel[(1,)](cond, x_true, x_false, ret)
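As committed, the test only verifies that the kernel compiles and runs. Since the grid is (1,), the single program has pid == 0, so `if pid % 2` is false and the else-branch stores XFalse; a follow-up assertion (an assumption, not in the commit) could pin that down:

    assert ret.item() == x_false.item()   # expected: 1.51 (else-branch taken)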

View File

@@ -63,7 +63,7 @@ def mangle_ty(ty):
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
-    key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
+    key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
     mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')
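A plain-Python illustration (not Triton code) of why keying on __name__ is fragile: distinct functions passed as constants can share a name, so mangled signatures built from __name__ can collide, whereas a content-derived cache_key stays unique per function:

    def make_adder(c):
        def f(x):
            return x + c
        return f

    f1, f2 = make_adder(1), make_adder(2)
    assert f1.__name__ == f2.__name__   # both named 'f': name-based keys collide
    assert f1(0) != f2(0)               # yet the functions differ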

View File

@@ -32,6 +32,8 @@ def _to_tensor(x, builder):
         return _to_tensor(x.value, builder)
     elif isinstance(x, tensor):
         return x
+    elif x is None:
+        return None
     assert False, f'cannot convert {x} to tensor'
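A minimal sketch (simplified, hypothetical signature) of the new behavior: None is now passed through instead of hitting the assertion, so callers with optional operands can route them through conversion unchanged:

    def to_tensor_sketch(x):
        if x is None:
            return None                # new: 'no value' is a valid passthrough
        if isinstance(x, (int, float)):
            return ('tensor', x)       # stand-in for real tensor construction
        raise AssertionError(f'cannot convert {x} to tensor')

    assert to_tensor_sketch(None) is None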

View File

@@ -559,7 +559,7 @@ def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
if src_ty.is_block() and not dst_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
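A toy reproduction (hypothetical classes, not Triton's real type objects) of the bug the extra guard fixes: when dst_ty is already a block type, the old condition wrapped it in block_type a second time, nesting the shape:

    class ScalarTy:
        def is_block(self): return False

    class BlockTy:
        def __init__(self, elt, shape): self.elt, self.shape = elt, shape
        def is_block(self): return True

    def resolve_dst(src_ty, dst_ty, shape):
        # old guard: `if src_ty.is_block():` alone -> BlockTy(BlockTy(...), shape)
        if src_ty.is_block() and not dst_ty.is_block():
            dst_ty = BlockTy(dst_ty, shape)
        return dst_ty

    blk = BlockTy(ScalarTy(), (16,))
    assert resolve_dst(blk, blk, (16,)) is blk   # block dst no longer re-wrapped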

View File

@@ -252,6 +252,7 @@ def matmul_kernel(
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
 @triton.jit
 def leaky_relu(x):
+    x = x + 1
     return tl.where(x >= 0, x, 0.01 * x)
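Quick numeric check (plain Python, illustration only): with the added `x = x + 1` the function now computes leaky_relu(x + 1) rather than a plain leaky ReLU:

    def leaky_relu_py(x):
        x = x + 1
        return x if x >= 0 else 0.01 * x

    assert leaky_relu_py(-0.5) == 0.5     # -0.5 + 1 = 0.5, positive branch
    assert leaky_relu_py(-2.0) == -0.01   # -2.0 + 1 = -1.0, scaled by 0.01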
@@ -296,7 +297,7 @@ def matmul(a, b, activation=None):
 torch.manual_seed(0)
 a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-triton_output = matmul(a, b, activation=None)
+triton_output = matmul(a, b, activation=leaky_relu)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")
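Note that `torch_output` is still a plain matmul; with the activation now fused into the Triton path, a matching reference would need the same post-op (a sketch, assuming the leaky_relu defined above, including its added +1):

    torch_ref = torch.matmul(a, b)
    torch_ref = torch_ref + 1
    torch_ref = torch.where(torch_ref >= 0, torch_ref, 0.01 * torch_ref)

Without it, the allclose check below is expected to report a mismatch, which lines up with the debug print and early exit added in the next hunk.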
@@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output):
 else:
     print("❌ Triton and Torch differ")
+print(matmul_kernel.cache_key)
+exit()
 
 # %%
 # Benchmark
 # --------------