Merge remote-tracking branch 'origin/master' into phil/fused-attention-perf-fixup
@@ -141,10 +141,10 @@ class CMakeBuild(build_ext):
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
             "-DTRITON_BUILD_TUTORIALS=OFF",
             "-DTRITON_BUILD_PYTHON_MODULE=ON",
-            # '-DPYTHON_EXECUTABLE=' + sys.executable,
-            '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON',
+            "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable,
+            "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON",
             "-DPYTHON_INCLUDE_DIRS=" + python_include_dir,
-            "-DLLVM_EXTERNAL_LIT=" + lit_dir
+            "-DLLVM_EXTERNAL_LIT=" + lit_dir,
         ] + thirdparty_cmake_args

         # configuration
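For context, a flag list like the one above is normally handed to CMake by the extension's build step. Below is a minimal sketch of that pattern, assuming illustrative sourcedir/build_temp names; it is not the project's actual build_extension body.

import subprocess

def configure_and_build(sourcedir, build_temp, cmake_args):
    # Configure: generate the build system with the requested cache variables.
    subprocess.check_call(["cmake", sourcedir] + cmake_args, cwd=build_temp)
    # Build: compile the generated targets (e.g. the Triton python module).
    subprocess.check_call(["cmake", "--build", "."], cwd=build_temp)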
@@ -491,10 +491,9 @@ def make_ptr_str(name, shape):
 # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>``
 @pytest.mark.parametrize("expr, dtype_str", [
     (f'x[{s}]', d)
-    for s in ['None, :', ':, None']
-    # FIXME: 3d indexing doesn't work
-    #'None, :, :',
-    # ':, :, None']
+    for s in ['None, :', ':, None',
+              'None, :, :',
+              ':, :, None']
     for d in ['int32', 'uint32', 'uint16']
 ])
 def test_index1d(expr, dtype_str, device='cuda'):
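The newly enabled index strings follow NumPy-style None/slice semantics, where each None inserts a unit axis. A small NumPy illustration of those shapes, independent of the test's own setup:

import numpy as np

x = np.arange(32).reshape(4, 8)   # illustrative 2-D input
print(x[None, :, :].shape)        # (1, 4, 8): leading None prepends a unit axis
print(x[:, :, None].shape)        # (4, 8, 1): trailing None appends a unit axis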
@@ -1228,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
     elif dtype == 'int8':
         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx

-# FIXME: Unsupported layout found in ConvertSplatLikeOp
-# def test_dot_without_load():
-#     @triton.jit
-#     def kernel(out):
-#         pid = tl.program_id(axis=0)
-#         a = tl.zeros((32, 32), tl.float32)
-#         b = tl.zeros((32, 32), tl.float32)
-#         c = tl.zeros((32, 32), tl.float32)
-#         c = tl.dot(a, b)
-#         pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
-#         tl.store(pout, c)
-#
-#     out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
-#     kernel[(1,)](out)
+
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)

 # ---------------
 # test arange
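Since a and b are zero tiles, tl.dot(a, b) is a zero tile, so the ones-initialized buffer should come back as zeros. A possible sanity check, not part of the diff:

assert torch.all(out == 0)  # the store of c must have overwritten the initial ones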
@@ -20,6 +20,8 @@ from .core import (
     atomic_xor,
     bfloat16,
     block_type,
+    broadcast,
+    broadcast_to,
     cat,
     cdiv,
     constexpr,
@@ -105,6 +107,8 @@ __all__ = [
     "atomic_xor",
     "bfloat16",
     "block_type",
+    "broadcast",
+    "broadcast_to",
     "builtin",
     "cat",
     "cdiv",
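With broadcast and broadcast_to re-exported, they can be called directly as tl.broadcast_to inside a kernel. A minimal usage sketch; the kernel below is illustrative only and not part of this change:

import torch
import triton
import triton.language as tl

@triton.jit
def broadcast_to_kernel(out_ptr):
    row = tl.arange(0, 32)[None, :]         # shape (1, 32)
    tile = tl.broadcast_to(row, (32, 32))   # explicit broadcast to (32, 32)
    offs = tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
    tl.store(out_ptr + offs, tile)

out = torch.empty((32, 32), dtype=torch.int32, device="cuda")
broadcast_to_kernel[(1,)](out)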
@@ -596,11 +596,9 @@ class tensor:
         if isinstance(slices, slice):
             slices = [slices]
         ret = self
-        n_inserted = 0
         for dim, sl in enumerate(slices):
             if isinstance(sl, constexpr) and sl.value is None:
-                ret = semantic.expand_dims(ret, dim + n_inserted, _builder)
-                n_inserted += 1
+                ret = semantic.expand_dims(ret, dim, _builder)
             elif sl == slice(None, None, None):
                 pass
             else:
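The simplified indexing path maps each None entry in the index tuple to a single expand_dims at that position, while : leaves the dimension untouched. A rough illustration of the resulting shapes; this kernel is an example, not part of the diff:

import torch
import triton
import triton.language as tl

@triton.jit
def none_index_kernel(out_ptr):
    x = tl.arange(0, 32)          # shape (32,)
    col = x[:, None]              # shape (32, 1) -> expand_dims(x, 1)
    row = x[None, :]              # shape (1, 32) -> expand_dims(x, 0)
    offs = col * 32 + row         # broadcasts to (32, 32)
    tl.store(out_ptr + offs, offs)

out = torch.empty((32, 32), dtype=torch.int32, device="cuda")
none_index_kernel[(1,)](out)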