[BACKEND] Various bug fixes; making reductions faster (#533)
@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):
 
     def build_extension(self, ext):
         llvm_include_dir, llvm_library_dir = get_llvm()
-        # self.debug = True
+        self.debug = True
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
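A minimal sketch of what flipping this flag changes downstream, based only on the context lines above (the `cfg` name and its use as the CMake build type are assumptions about the surrounding setup.py code):

```python
# `debug` stands in for self.debug; the flag picks the build-directory
# suffix and, it is assumed here, the CMAKE_BUILD_TYPE passed to cmake.
debug = True
build_suffix = 'debug' if debug else 'release'
cfg = 'Debug' if debug else 'Release'  # assumed build-type choice
print(build_suffix, cfg)  # -> debug Debug
```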
@@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):
 
     rs = RandomState(17)
     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    x[:] = 1
     # numpy result
     z_ref = np.sum(x).astype(getattr(np, dtype_str))
     # triton result
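A plain-numpy sketch of the reference path above; `numpy_random` and `RandomState` come from triton's test utilities, so ordinary numpy stands in here, and with `x[:] = 1` the expected sum is exactly the element count:

```python
import numpy as np

shape, dtype_str = 128, 'float32'        # example values; the test is parametrized
x = np.zeros((shape,), dtype=dtype_str)
x[:] = 1                                 # overwrite the random data, as in the diff
z_ref = np.sum(x).astype(getattr(np, dtype_str))
assert z_ref == shape                    # 128 ones sum to 128
```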
@@ -1132,3 +1133,25 @@ def test_constexpr_shape():
     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
     kernel[(1,)](x_tri)
     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+
+# -------------
+# test if
+# -------------
+
+
+def test_if():
+
+    @triton.jit
+    def kernel(Cond, XTrue, XFalse, Ret):
+        pid = tl.program_id(0)
+        cond = tl.load(Cond)
+        if pid % 2:
+            tl.store(Ret, tl.load(XTrue))
+        else:
+            tl.store(Ret, tl.load(XFalse))
+
+    cond = torch.ones(1, dtype=torch.int32, device='cuda')
+    x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
+    x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
+    ret = torch.empty(1, dtype=torch.float32, device='cuda')
+    kernel[(1,)](cond, x_true, x_false, ret)
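The new test launches the kernel but stops short of asserting on the output. A standalone sketch of the expected behavior (assuming triton, torch, and a CUDA device; the final assert is an addition, not part of the commit): with grid `(1,)`, `tl.program_id(0)` is 0, so `pid % 2` is falsy and the `else` branch stores `XFalse`:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def kernel(Cond, XTrue, XFalse, Ret):
    pid = tl.program_id(0)
    cond = tl.load(Cond)               # loaded but unused, as in the test
    if pid % 2:
        tl.store(Ret, tl.load(XTrue))
    else:
        tl.store(Ret, tl.load(XFalse))


cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret)
assert ret.item() == x_false.item()    # pid 0 takes the else branch
```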
@@ -63,7 +63,7 @@ def mangle_ty(ty):
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
-    key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
+    key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
     mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')
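Why key on `cache_key` rather than `__name__`: two distinct JIT functions can share a name, so name-based mangling can collide across specializations. A plain-Python sketch of the hazard (the closures below are hypothetical, not Triton code):

```python
def make_activation(op):
    def activation(x):       # every closure returned here is named 'activation'
        return op(x)
    return activation

relu_like = make_activation(abs)
scale2 = make_activation(lambda x: 2 * x)
assert relu_like.__name__ == scale2.__name__   # same name, different functions
assert relu_like is not scale2                 # distinct objects that name-mangling would conflate
```

`cache_key` is assumed here to be derived from the function's contents, so two different functions mangle differently even when their names match.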
@@ -32,6 +32,8 @@ def _to_tensor(x, builder):
         return _to_tensor(x.value, builder)
     elif isinstance(x, tensor):
         return x
+    elif x is None:
+        return None
     assert False, f'cannot convert {x} to tensor'
 
 
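A small sketch of the new behavior with stand-in logic (not Triton's real conversion): `None` now passes through instead of hitting the assert, so optional arguments can be forwarded unchanged:

```python
def _to_tensor_sketch(x):
    if isinstance(x, (int, float)):
        return ('tensor', x)   # stand-in for the real constant conversion
    elif x is None:
        return None            # new: None is forwarded rather than rejected
    raise AssertionError(f'cannot convert {x} to tensor')

assert _to_tensor_sketch(None) is None
assert _to_tensor_sketch(3) == ('tensor', 3)
```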
@@ -559,7 +559,7 @@ def cast(input: tl.tensor,
          dst_ty: tl.dtype,
          builder: ir.builder) -> tl.tensor:
     src_ty = input.type
-    if src_ty.is_block():
+    if src_ty.is_block() and not dst_ty.is_block():
         dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
     if src_ty == dst_ty:
         return input
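The added `and not dst_ty.is_block()` guard prevents re-wrapping a destination type that is already a block type. A mock sketch (simplified stand-ins, not Triton's real classes) of the nesting the old condition allowed:

```python
class dtype:
    def is_block(self):
        return False

class block_type(dtype):
    def __init__(self, scalar, shape):
        self.scalar, self.shape = scalar, shape
    def is_block(self):
        return True

src_ty = block_type(dtype(), (128,))
dst_ty = block_type(dtype(), (128,))
# Old: `if src_ty.is_block():` would rewrap, yielding a block of blocks.
if src_ty.is_block() and not dst_ty.is_block():
    dst_ty = block_type(dst_ty, src_ty.shape)
assert not isinstance(dst_ty.scalar, block_type)  # no nested block types
```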
@@ -252,6 +252,7 @@ def matmul_kernel(
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
 @triton.jit
 def leaky_relu(x):
+    x = x + 1
     return tl.where(x >= 0, x, 0.01 * x)
 
 
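A numpy sketch of what the jitted function now computes; note that the added `x = x + 1` shifts the input before the piecewise select, so this is no longer a plain leaky ReLU:

```python
import numpy as np

def leaky_relu_ref(x):
    x = x + 1                          # mirrors the line added in the diff
    return np.where(x >= 0, x, 0.01 * x)

print(leaky_relu_ref(np.array([-2.0, 0.0, 3.0])))  # [-0.01  1.    4.  ]
```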
@@ -296,7 +297,7 @@ def matmul(a, b, activation=None):
 torch.manual_seed(0)
 a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-triton_output = matmul(a, b, activation=None)
+triton_output = matmul(a, b, activation=leaky_relu)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")
@@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output):
 else:
     print("❌ Triton and Torch differ")
 
+print(matmul_kernel.cache_key)
+exit()
 # %%
 # Benchmark
 # --------------