[FRONTEND][BACKEND] Fixes for cat / reshape / addptr (#959)
Most notably, this PR:
- changes the traits (and assembly format) of `addptr` so it can handle offsets that have arbitrary integer width.
- adds support for `cat`
@@ -64,12 +64,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src>
   %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src>
   %9 = arith.addi %8, %7 : tensor<128x128xi32, #src>
-  %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>
+  %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr<f16>, #src>, tensor<128x128xi32, #src>
   %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src>
   %3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<128x128x!tt.ptr<f16>, #dst>
   %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst>
   %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst>
-  %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>
+  %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr<f16>, #dst>, tensor<128x128xi32, #dst>
   tt.store %14, %13 : tensor<128x128xf16, #dst>
   return
 }
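The extra trailing type in the new `tt.addptr` form is the offset tensor's type, so offsets are no longer pinned to i32. As a hedged illustration (the kernel, its names, and the int64 cast are invented here, not taken from the diff), a Triton kernel whose pointer arithmetic uses wider-than-i32 offsets:

    import triton
    import triton.language as tl

    @triton.jit
    def gather_rows(src_ptr, dst_ptr, stride, BLOCK: tl.constexpr):
        # Widen the offsets to int64 before the pointer add; with the new
        # assembly format the tt.addptr op carries the i64 offset tensor
        # type explicitly in its signature.
        offs = tl.arange(0, BLOCK).to(tl.int64) * stride
        tl.store(dst_ptr + tl.arange(0, BLOCK), tl.load(src_ptr + offs))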
@@ -371,6 +371,7 @@ class CodeGenerator(ast.NodeVisitor):
         # 1. we have an orelse node
         # or
         # 2. the then block defines new variable
+        else_defs = {}
         if then_defs or node.orelse:
             if node.orelse:
                 self.lscope = liveins
@@ -381,7 +382,6 @@ class CodeGenerator(ast.NodeVisitor):
                 else_defs = self.local_defs.copy()
             else:
                 # collect else_defs
-                else_defs = {}
                 for name in then_defs:
                     if name in liveins:
                         assert self.is_triton_tensor(then_defs[name])
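The fix hoists the `else_defs = {}` initialization above the conditional so the merge logic always sees a defined dict, even when neither branch contributes definitions. A minimal standalone sketch of that pattern (function and parameter names invented here, not the actual compiler code):

    def merge_if_defs(then_defs, liveins, has_orelse, visit_orelse):
        # Hoisted default: defined on every path, including when the
        # whole conditional below is skipped.
        else_defs = {}
        if then_defs or has_orelse:
            if has_orelse:
                else_defs = visit_orelse()
            else:
                # fall back to live-in values for names the then-block defined
                for name in then_defs:
                    if name in liveins:
                        else_defs[name] = liveins[name]
        return else_defs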
@@ -55,6 +55,7 @@ from .core import (
     printf,
     program_id,
     ravel,
+    reshape,
     sigmoid,
     sin,
     softmax,
@@ -70,6 +71,7 @@ from .core import (
     uint64,
     uint8,
     umulhi,
+    view,
     void,
     where,
     xor_sum,
@@ -149,6 +151,7 @@ __all__ = [
     "randn",
     "randn4x",
     "ravel",
+    "reshape",
     "sigmoid",
     "sin",
     "softmax",
@@ -165,6 +168,7 @@ __all__ = [
     "uint64",
     "uint8",
     "umulhi",
+    "view",
     "void",
     "where",
     "xor_sum",
@@ -731,7 +731,7 @@ def trans(input, _builder=None):
     return semantic.trans(input, _builder)
 
 
 @builtin
-def cat(input, other, _builder=None):
+def cat(input, other, can_reorder=False, _builder=None):
     """
     Concatenate the given blocks
@@ -739,8 +739,12 @@ def cat(input, other, _builder=None):
     :type input:
     :param other: The second input tensor.
     :type other:
+    :param can_reorder: Compiler hint. If true, the compiler is
+        allowed to reorder elements while concatenating inputs.
+        Only use if the order does not matter (e.g., the result is
+        only used in reduction ops).
     """
-    return semantic.cat(input, other, _builder)
+    return semantic.cat(input, other, can_reorder, _builder)
 
 
 @builtin
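Because `can_reorder=True` lets the compiler shuffle elements, the result should only feed order-insensitive consumers. A hedged usage sketch (kernel name and sizes invented here):

    import triton
    import triton.language as tl

    @triton.jit
    def sum_two_blocks(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
        a = tl.load(x_ptr + tl.arange(0, BLOCK))
        b = tl.load(y_ptr + tl.arange(0, BLOCK))
        # Safe despite can_reorder=True: a sum is order-insensitive.
        total = tl.sum(tl.cat(a, b, can_reorder=True), axis=0)
        tl.store(out_ptr, total)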
@@ -761,7 +765,8 @@ def view(input, shape, _builder=None):
 @builtin
 def reshape(input, shape, _builder=None):
-    # TODO: should be more than just a view
-    return view(input, shape, _builder)
+    shape = [x.value for x in shape]
+    return semantic.view(input, shape, _builder)
 
 
 # -----------------------
 # Linear Algebra
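With this change, `reshape` unwraps the constexpr shape entries itself (`[x.value for x in shape]`) and calls `semantic.view` directly instead of silently aliasing the `view` builtin. A hedged usage sketch (kernel name and the 64-element input are assumptions):

    import triton
    import triton.language as tl

    @triton.jit
    def reshape_sum(x_ptr, out_ptr):
        x = tl.load(x_ptr + tl.arange(0, 64))
        # The (8, 8) literals arrive as constexprs and are unwrapped above.
        y = tl.reshape(x, (8, 8))
        s = tl.sum(y, axis=1)  # reduce each row of the reshaped block
        tl.store(out_ptr + tl.arange(0, 8), s)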
@@ -498,9 +498,11 @@ def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
     return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)
 
 
-def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
-    # TODO: check types
-    return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), lhs.type)
+def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
+    assert can_reorder, "the current implementation of `cat` may reorder elements"
+    assert len(lhs.shape) == 1
+    ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
+    return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)
 
 
 def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
     if len(input.shape) != 2:
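Note that the result type is now recomputed: the old code reused `lhs.type`, typing the concatenation of two blocks as a block of only `lhs`'s length. A trivial sketch of the corrected shape arithmetic (the shapes are invented):

    # cat of two length-64 blocks must be typed as a length-128 block;
    # reusing lhs.type (length 64) was the old bug.
    lhs_shape, rhs_shape = [64], [64]
    ret_shape = [lhs_shape[0] + rhs_shape[0]]
    assert ret_shape == [128]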