[FRONTEND][BACKEND] Clean-up transpositions (#953)

This commit is contained in:
Philippe Tillet
2022-12-06 09:32:13 -08:00
committed by GitHub
parent 16e973edf2
commit 532e10cf87
12 changed files with 31 additions and 53 deletions

View File

@@ -1184,11 +1184,10 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot",
[](mlir::OpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, bool allowTF32, bool transA,
bool transB) -> mlir::Value {
mlir::Value &c, bool allowTF32) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::DotOp>(loc, c.getType(), a, b, c,
allowTF32, transA, transB);
allowTF32);
})
.def("create_exp",
[](mlir::OpBuilder &self, mlir::Value &val) -> mlir::Value {

View File

@@ -1126,7 +1126,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
if CHAIN_DOT:
# tl.store(Zs, z)
# tl.debug_barrier()
z = tl.dot(tl.trans(z.to(tl.float16)), tl.load(Ws))
z = tl.dot(z.to(tl.float16), tl.load(Ws))
tl.store(Zs, z)
# input
rs = RandomState(17)
@@ -1173,7 +1173,7 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
denom = np.sum(num, axis=-1, keepdims=True)
z_ref = num / denom
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref.T, w)
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)

View File

@@ -762,6 +762,9 @@ class CodeGenerator(ast.NodeVisitor):
def visit_Attribute(self, node):
lhs = self.visit(node.value)
if isinstance(lhs, triton.language.tensor):
if node.attr == "T":
return triton.language.semantic.trans(lhs, builder=self.builder)
return getattr(lhs, node.attr)
def visit_Expr(self, node):

View File

@@ -603,10 +603,9 @@ class tensor:
assert False, "unsupported"
return ret
# x[:, None, :, None]
# x = expand_dims(x, axis=1)
# x = expand_dims(x, axis=2)
@property
def T(self):
assert False, "Transposition must be created by the AST Visitor"
@builtin
def to(self, dtype, bitcast=False, _builder=None):
@@ -766,7 +765,7 @@ def view(input, shape, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=None):
def dot(input, other, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
@@ -778,7 +777,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
allow_tf32 = _constexpr_to_value(allow_tf32)
return semantic.dot(input, other, allow_tf32, trans_a, trans_b, _builder)
return semantic.dot(input, other, allow_tf32, _builder)
# -----------------------

View File

@@ -976,8 +976,6 @@ def atomic_xchg(ptr: tl.tensor,
def dot(lhs: tl.tensor,
rhs: tl.tensor,
allow_tf32: bool,
trans_a: bool,
trans_b: bool,
builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
if lhs.type.scalar.is_int():
@@ -986,11 +984,11 @@ def dot(lhs: tl.tensor,
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = lhs.type.shape[1 if trans_a else 0]
N = rhs.type.shape[0 if trans_b else 1]
M = lhs.type.shape[0]
N = rhs.type.shape[1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, trans_a, trans_b),
return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
ret_ty)