[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)

Most notably, this PR:
- changes the traits (and assembly format) of `addptr` so it can handle offsets of arbitrary integer width.
- adds support for `cat`
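
For context, here is a minimal kernel sketch (not part of this PR; the kernel name, shapes, and launch parameters are made up) of the pattern the `addptr` change targets: pointer arithmetic whose offset tensor is wider than 32 bits.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    # Offsets are cast to int64 on purpose: with this change the frontend can
    # emit tt.addptr whose offset operand keeps its integer width instead of
    # being restricted to i32.
    offs = tl.arange(0, BLOCK).to(tl.int64)
    x = tl.load(src_ptr + offs)
    tl.store(dst_ptr + offs, x)


src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
copy_kernel[(1,)](src, dst, BLOCK=1024)
```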
Author: Philippe Tillet
Date: 2022-12-06 23:29:50 -08:00
Committed by: GitHub
Parent: 981aee7f1e
Commit: b2b793dfb5
24 changed files with 199 additions and 132 deletions


@@ -64,12 +64,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
%7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
%8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
%9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
-%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
+%10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
%11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
%12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
%13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
-%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
+%14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
tt.store %14, %13 : tensor<128x128xf16, #dst>
return
}


@@ -371,6 +371,7 @@ class CodeGenerator(ast.NodeVisitor):
# 1. we have an orelse node
# or
# 2. the then block defines new variable
+else_defs = {}
if then_defs or node.orelse:
if node.orelse:
self.lscope = liveins
@@ -381,7 +382,6 @@ class CodeGenerator(ast.NodeVisitor):
else_defs = self.local_defs.copy()
else:
# collect else_defs
-else_defs = {}
for name in then_defs:
if name in liveins:
assert self.is_triton_tensor(then_defs[name])
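
The moved initialization covers kernels whose `if` has no `else` branch but still redefines a live-in value. A hypothetical example (not from this PR, assuming device-side `if` lowering is available for the condition used here):

```python
import triton
import triton.language as tl


@triton.jit
def bias_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    bias = tl.zeros([BLOCK], dtype=tl.float32)
    # An `if` with no `else` whose body redefines `bias`: the code generator
    # must merge `then_defs` against an `else_defs` that is now guaranteed to
    # exist (as an empty dict) even though there is no orelse node.
    if pid > 0:
        bias = bias + 1.0
    tl.store(y_ptr + offs, x + bias)
```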


@@ -55,6 +55,7 @@ from .core import (
printf,
program_id,
ravel,
+reshape,
sigmoid,
sin,
softmax,
@@ -70,6 +71,7 @@ from .core import (
uint64,
uint8,
umulhi,
+view,
void,
where,
xor_sum,
@@ -149,6 +151,7 @@ __all__ = [
"randn",
"randn4x",
"ravel",
"reshape",
"sigmoid",
"sin",
"softmax",
@@ -165,6 +168,7 @@ __all__ = [
"uint64",
"uint8",
"umulhi",
"view",
"void",
"where",
"xor_sum",


@@ -731,7 +731,7 @@ def trans(input, _builder=None):
return semantic.trans(input, _builder)
@builtin
-def cat(input, other, _builder=None):
+def cat(input, other, can_reorder=False, _builder=None):
"""
Concatenate the given blocks
@@ -739,8 +739,12 @@ def cat(input, other, _builder=None):
:type input:
:param other: The second input tensor.
:type other:
+:param can_reorder: Compiler hint. If true, the compiler is
+allowed to reorder elements while concatenating inputs.
+Only use if the order does not matter (e.g., result is
+only used in reduction ops)
"""
-return semantic.cat(input, other, _builder)
+return semantic.cat(input, other, can_reorder, _builder)
@builtin
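
A usage sketch for the new flag (kernel and tensor names are hypothetical): `can_reorder=True` is currently mandatory, and because the backend may interleave the two inputs, the result should only feed order-insensitive ops such as reductions.

```python
import triton
import triton.language as tl


@triton.jit
def cat_sum_kernel(a_ptr, b_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs)
    b = tl.load(b_ptr + offs)
    # The concatenation may reorder/interleave elements of `a` and `b`, which
    # is fine here because the result is immediately reduced to a scalar.
    c = tl.cat(a, b, can_reorder=True)
    tl.store(out_ptr, tl.sum(c, axis=0))
```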
@@ -761,7 +765,8 @@ def view(input, shape, _builder=None):
@builtin
def reshape(input, shape, _builder=None):
# TODO: should be more than just a view
-return view(input, shape, _builder)
+shape = [x.value for x in shape]
+return semantic.view(input, shape, _builder)
# -----------------------
# Linear Algebra
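
A small sketch of the newly re-exported `reshape` (hypothetical kernel, not from this PR). The shape elements must be constexprs, which is why the builtin unwraps them with `.value`; and since `reshape` still lowers to `view`, the element order of the result is not yet guaranteed.

```python
import triton
import triton.language as tl


@triton.jit
def reshape_sum_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    # M and N are assumed to be powers of two so M * N is a valid arange size.
    x = tl.load(x_ptr + tl.arange(0, M * N))
    # Behaves like `view` for now, so only use the result where element order
    # does not matter; here both axes are reduced away.
    y = tl.reshape(x, (M, N))
    rows = tl.sum(y, axis=1)
    tl.store(out_ptr, tl.sum(rows, axis=0))
```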


@@ -498,9 +498,11 @@ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
-def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
-# TODO: check types
-return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
+def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
+assert can_reorder, "current implementation of `cat` always may reorder elements"
+assert len(lhs.shape) == 1
+ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
+return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
if len(input.shape) != 2: