[FRONTEND] Add scalar type support for some ops (#661)
This PR adds basic support for scalar-type inputs to some ops (cast and pointer arithmetics) for Triton-MLIR. Also renames getelementptr -> addptr
This commit is contained in:
@@ -29,18 +29,22 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [SameOperandsAndResultShape, NoSideEffec
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast int64 to pointer";
|
||||
|
||||
let arguments = (ins I64Tensor:$from);
|
||||
let arguments = (ins TT_I64Like:$from);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [SameOperandsAndResultShape, NoSideEffect,
|
||||
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
|
||||
let summary = "Cast pointer to int64";
|
||||
|
||||
let arguments = (ins TT_PtrTensor:$from);
|
||||
let arguments = (ins TT_PtrLike:$from);
|
||||
|
||||
let results = (outs I64Tensor:$result);
|
||||
let results = (outs TT_I64Like:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
}
|
||||
|
||||
def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
@@ -54,9 +58,11 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
||||
BF8 <-> F8, FP16, FP32
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_FloatTensor:$from);
|
||||
let arguments = (ins TT_FloatLike:$from);
|
||||
|
||||
let results = (outs TT_FloatTensor:$result);
|
||||
let results = (outs TT_FloatLike:$result);
|
||||
|
||||
let assemblyFormat = "$from attr-dict `:` type($from) `->` type($result)";
|
||||
|
||||
// TODO: We need a verifier here.
|
||||
}
|
||||
@@ -127,16 +133,16 @@ def TT_StoreOp : TT_Op<"store",
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_GEPOp : TT_Op<"getelementptr",
|
||||
def TT_AddPtrOp : TT_Op<"addptr",
|
||||
[NoSideEffect, SameOperandsAndResultShape,
|
||||
TypesMatchWith<"result type matches ptr type",
|
||||
"result", "ptr", "$_self">,
|
||||
TypesMatchWith<"result shape matches offset shape",
|
||||
"result", "offset",
|
||||
"getI32SameShape($_self)">]> {
|
||||
let arguments = (ins TT_PtrTensor:$ptr, I32Tensor:$offset);
|
||||
let arguments = (ins TT_PtrLike:$ptr, TT_I32Like:$offset);
|
||||
|
||||
let results = (outs TT_PtrTensor:$result);
|
||||
let results = (outs TT_PtrLike:$result);
|
||||
|
||||
let assemblyFormat = "$ptr `,` $offset attr-dict `:` type($result)";
|
||||
}
|
||||
@@ -278,7 +284,7 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas"> {
|
||||
return $old
|
||||
}];
|
||||
|
||||
let arguments = (ins TT_Pointer:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
let arguments = (ins TT_Ptr:$ptr, TT_Type:$cmp, TT_Type:$val);
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
}
|
||||
@@ -318,7 +324,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
|
||||
|
||||
let arguments = (ins I32Attr:$start, I32Attr:$end);
|
||||
|
||||
let results = (outs TT_IntegerTensor:$result);
|
||||
let results = (outs TT_IntTensor:$result);
|
||||
|
||||
let assemblyFormat = "attr-dict `:` type($result)";
|
||||
}
|
||||
|
@@ -12,18 +12,36 @@ class TritonTypeDef<string name, string _mnemonic>
|
||||
let mnemonic = _mnemonic;
|
||||
}
|
||||
|
||||
// Floating-point Type
|
||||
def F8 : TritonTypeDef<"Float8", "f8">;
|
||||
def BF8 : TritonTypeDef<"BFloat8", "bf8">;
|
||||
|
||||
def TT_Float : AnyTypeOf<[F16, BF16, F32, F64], "floating-point">;
|
||||
def TT_FloatTensor : TensorOf<[TT_Float]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
// IntegerType
|
||||
// Boolean Type
|
||||
// TT_Bool -> I1
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
|
||||
// Integer Type
|
||||
def TT_Int : AnyTypeOf<[I1, I8, I16, I32, I64], "integer">;
|
||||
def TT_IntegerTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntTensor : TensorOf<[TT_Int]>;
|
||||
def TT_IntLike : AnyTypeOf<[TT_Int, TT_IntTensor]>;
|
||||
|
||||
// PointerType
|
||||
def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
// I32 Type
|
||||
// TT_I32 -> I32
|
||||
// TT_I32Tensor -> I32Tensor
|
||||
def TT_I32Like: AnyTypeOf<[I32, I32Tensor]>;
|
||||
|
||||
// I64 Type
|
||||
// TT_I64 -> I64
|
||||
// TT_I64Tensor -> I64Tensor
|
||||
def TT_I64Like: AnyTypeOf<[I64, I64Tensor]>;
|
||||
|
||||
// Pointer Type
|
||||
def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> {
|
||||
let summary = "pointer type";
|
||||
|
||||
let description = [{
|
||||
@@ -43,12 +61,12 @@ def TT_Pointer : TritonTypeDef<"Pointer", "ptr"> {
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
def TT_PtrTensor : TensorOf<[TT_Pointer]>;
|
||||
def TT_PtrTensor : TensorOf<[TT_Ptr]>;
|
||||
def TT_PtrLike : AnyTypeOf<[TT_Ptr, TT_PtrTensor]>;
|
||||
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntegerTensor]>;
|
||||
def TT_FpIntTensor : AnyTypeOf<[TT_FloatTensor, TT_IntTensor]>;
|
||||
def TT_Tensor : AnyTypeOf<[TT_FpIntTensor, TT_PtrTensor]>;
|
||||
|
||||
def TT_Type : AnyTypeOf<[TT_Float, TT_FloatTensor, TT_Int, TT_IntegerTensor,
|
||||
TT_Pointer, TT_PtrTensor]>;
|
||||
def TT_Type : AnyTypeOf<[TT_FloatLike, TT_IntLike, TT_PtrLike]>;
|
||||
|
||||
#endif
|
||||
|
@@ -8,7 +8,7 @@ def TritonCombineOps : Pass</*cli-arg*/"triton-combine", /*Op*/"mlir::ModuleOp">
|
||||
let description = [{
|
||||
dot(a, b, 0) + c => dot(a, b, c)
|
||||
|
||||
gep(gep(ptr, idx0), idx1) => gep(ptr, AddI(idx0, idx1))
|
||||
addptr(addptr(ptr, idx0), idx1) => addptr(ptr, AddI(idx0, idx1))
|
||||
|
||||
select(cond, load(ptrs, broadcast(cond), ???), other) =>
|
||||
load(ptrs, broadcast(cond), other)
|
||||
|
@@ -10,12 +10,6 @@ include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
|
||||
|
||||
def TT_BoolTensor : TensorOf<[I1]>;
|
||||
|
||||
def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
|
||||
def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
|
||||
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
|
||||
|
||||
class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
@@ -48,8 +42,8 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
|
||||
let description = [{}];
|
||||
|
||||
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
|
||||
TT_IntegerLike:$lhs,
|
||||
TT_IntegerLike:$rhs);
|
||||
TT_IntLike:$lhs,
|
||||
TT_IntLike:$rhs);
|
||||
|
||||
let results = (outs TT_BoolLike:$result);
|
||||
}
|
||||
|
@@ -126,7 +126,7 @@ ChangeResult AxisInfoAnalysis::visitOperation(
|
||||
}
|
||||
}
|
||||
// Addition
|
||||
if (llvm::isa<arith::AddIOp, triton::GEPOp>(op)) {
|
||||
if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
|
||||
auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
|
||||
return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
|
||||
gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
|
||||
|
@@ -701,7 +701,7 @@ struct StoreOpConversion
|
||||
|
||||
const int numVecs = numElems / vec;
|
||||
for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
|
||||
// TODO: optimization when ptr is GEP with constant offset
|
||||
// TODO: optimization when ptr is AddPtr with constant offset
|
||||
size_t in_off = 0;
|
||||
|
||||
const int maxWordWidth = std::max<int>(32, valueElemNbits);
|
||||
@@ -1173,12 +1173,13 @@ struct GetProgramIdOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
|
||||
struct AddPtrOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::AddPtrOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
triton::AddPtrOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1298,7 +1299,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
|
||||
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
|
||||
patterns.add<GEPOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<LoadOpConversion>(typeConverter, analysis, benefit);
|
||||
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
|
||||
|
@@ -323,7 +323,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
patterns.add< // TODO: view should have custom pattern that views the layout
|
||||
TritonGenericPattern<triton::ViewOp>,
|
||||
TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonReducePattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern>(
|
||||
typeConverter, context);
|
||||
|
@@ -15,7 +15,7 @@ static Type getI1SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i1Type;
|
||||
}
|
||||
|
||||
static Type getI32SameShape(Type type) {
|
||||
@@ -23,7 +23,7 @@ static Type getI32SameShape(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i32Type,
|
||||
tensorType.getEncoding());
|
||||
return Type();
|
||||
return i32Type;
|
||||
}
|
||||
|
||||
static Type getPointerTypeFromTensor(Type type) {
|
||||
|
@@ -194,7 +194,7 @@ public:
|
||||
patterns.add<CombineDotAddFRevPattern>(context);
|
||||
// %}
|
||||
patterns.add<CombineSelectMaskedLoadPattern>(context);
|
||||
patterns.add<CombineGEPPattern>(context);
|
||||
patterns.add<CombineAddPtrPattern>(context);
|
||||
patterns.add<CombineBroadcastConstantPattern>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
|
@@ -30,12 +30,12 @@ def CombineDotAddFRevPattern : Pat<
|
||||
[(Constraint<CPred<"isZero($0)">> $c)]>;
|
||||
|
||||
|
||||
// gep(gep(%ptr, %idx0), %idx1) => gep(%ptr, AddI(%idx0, %idx1))
|
||||
// addptr(addptr(%ptr, %idx0), %idx1) => addptr(%ptr, AddI(%idx0, %idx1))
|
||||
// Note: leave (sub %c0, %c0) canceling to ArithmeticDialect
|
||||
// (ref: ArithmeticCanonicalization.td)
|
||||
def CombineGEPPattern : Pat<
|
||||
(TT_GEPOp (TT_GEPOp $ptr, $idx0), $idx1),
|
||||
(TT_GEPOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||
def CombineAddPtrPattern : Pat<
|
||||
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
|
||||
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1))>;
|
||||
|
||||
// broadcast(cst) => cst
|
||||
def getConstantValue : NativeCodeCall<"getConstantValue($_builder, $0, $1)">;
|
||||
|
@@ -203,7 +203,7 @@ bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
|
||||
targetType.getEncoding());
|
||||
};
|
||||
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
|
||||
isa<triton::SplatOp, triton::GEPOp>(op);
|
||||
isa<triton::SplatOp, triton::AddPtrOp>(op);
|
||||
if (hasSameTypes) {
|
||||
// replace argument types
|
||||
for (auto arg : llvm::enumerate(op->getOperands())) {
|
||||
|
@@ -70,7 +70,7 @@ private:
|
||||
if (auto storeOp = llvm::dyn_cast<triton::StoreOp>(op)) {
|
||||
// TODO: fill this
|
||||
}
|
||||
if (auto gepOp = llvm::dyn_cast<triton::GEPOp>(op)) {
|
||||
if (auto addptrOp = llvm::dyn_cast<triton::AddPtrOp>(op)) {
|
||||
// TODO: fill this
|
||||
}
|
||||
// Triton builtin Ops
|
||||
|
@@ -1245,12 +1245,12 @@ void init_triton_ir(py::module &&m) {
|
||||
return mlir::Value(
|
||||
self.create<mlir::arith::ShRSIOp>(loc, lhs, rhs));
|
||||
})
|
||||
// GEP
|
||||
.def("create_gep",
|
||||
// AddPtr (similar to GEP)
|
||||
.def("create_addptr",
|
||||
[](mlir::OpBuilder &self, mlir::Value &ptr,
|
||||
mlir::Value &offset) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::GEPOp>(loc, ptr.getType(), ptr,
|
||||
return self.create<mlir::triton::AddPtrOp>(loc, ptr.getType(), ptr,
|
||||
offset);
|
||||
})
|
||||
// Comparison (int)
|
||||
|
@@ -121,7 +121,7 @@ def add(input: tl.tensor,
|
||||
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
|
||||
input, other = other, input
|
||||
if input_scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, other.handle), input.type)
|
||||
return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
|
||||
# float + float
|
||||
elif input_scalar_ty.is_floating():
|
||||
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
|
||||
@@ -138,7 +138,7 @@ def sub(input: tl.tensor,
|
||||
scalar_ty = input.type.scalar
|
||||
# ptr - offset
|
||||
if scalar_ty.is_ptr():
|
||||
return tl.tensor(builder.create_gep(input.handle, minus(other, builder).handle),
|
||||
return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
|
||||
input.type)
|
||||
# float - float
|
||||
if scalar_ty.is_floating():
|
||||
|
@@ -26,8 +26,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -18,7 +18,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1]
|
||||
%6 = tt.getelementptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||
%6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
@@ -26,13 +26,13 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [128, 1]
|
||||
%9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1]
|
||||
%10 = tt.getelementptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||
%10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [65536, 1] ; Constancy: [1, 1]
|
||||
%11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1]
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%13 = tt.getelementptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||
%13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 65536] ; Constancy: [1, 1]
|
||||
%14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128]
|
||||
@@ -44,7 +44,7 @@ func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {t
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 1048576] ; Constancy: [128, 1]
|
||||
%18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32>
|
||||
// CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1]
|
||||
%19 = tt.getelementptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||
%19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1]
|
||||
%20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32>
|
||||
tt.store %19, %20, %cst : tensor<128x128xf32>
|
||||
|
@@ -30,8 +30,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -28,8 +28,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
// CHECK: Membar 13
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
55
test/Conversion/triton_ops.mlir
Normal file
55
test/Conversion/triton_ops.mlir
Normal file
@@ -0,0 +1,55 @@
|
||||
// RUN: triton-opt %s | FileCheck %s
|
||||
|
||||
func @cast_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_f32: f32, %scalar_i64: i64) {
|
||||
// scalar -> scalar
|
||||
// CHECK: i64 -> !tt.ptr<f32>
|
||||
%0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr<f32>
|
||||
// CHECK: !tt.ptr<f32> -> i64
|
||||
%1 = tt.ptr_to_int %scalar_ptr : !tt.ptr<f32> -> i64
|
||||
// CHECK: f32 -> f16
|
||||
%2 = tt.fp_to_fp %scalar_f32 : f32 -> f16
|
||||
|
||||
// 0D tensor -> 0D tensor
|
||||
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
|
||||
%tensor_f32_0d = tt.splat %scalar_f32 : (f32) -> tensor<f32>
|
||||
%tensor_i64_0d = tt.splat %scalar_i64 : (i64) -> tensor<i64>
|
||||
|
||||
// CHECK: tensor<i64> -> tensor<!tt.ptr<f32>>
|
||||
%3 = tt.int_to_ptr %tensor_i64_0d : tensor<i64> -> tensor<!tt.ptr<f32>>
|
||||
// CHECK: tensor<!tt.ptr<f32>> -> tensor<i64>
|
||||
%4 = tt.ptr_to_int %tensor_ptr_0d : tensor<!tt.ptr<f32>> -> tensor<i64>
|
||||
// CHECK: tensor<f32> -> tensor<f16>
|
||||
%5 = tt.fp_to_fp %tensor_f32_0d : tensor<f32> -> tensor<f16>
|
||||
|
||||
// 1D tensor -> 1D tensor
|
||||
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
|
||||
%tensor_f32_1d = tt.splat %scalar_f32 : (f32) -> tensor<16xf32>
|
||||
%tensor_i64_1d = tt.splat %scalar_i64 : (i64) -> tensor<16xi64>
|
||||
|
||||
// CHECK: tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
|
||||
%6 = tt.int_to_ptr %tensor_i64_1d : tensor<16xi64> -> tensor<16x!tt.ptr<f32>>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
|
||||
%7 = tt.ptr_to_int %tensor_ptr_1d : tensor<16x!tt.ptr<f32>> -> tensor<16xi64>
|
||||
// CHECK: tensor<16xf32> -> tensor<16xf16>
|
||||
%8 = tt.fp_to_fp %tensor_f32_1d : tensor<16xf32> -> tensor<16xf16>
|
||||
return
|
||||
}
|
||||
|
||||
func @addptr_ops(%scalar_ptr: !tt.ptr<f32>, %scalar_i32: i32) {
|
||||
// scalar -> scalar
|
||||
// CHECK: !tt.ptr<f32>
|
||||
%0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr<f32>
|
||||
|
||||
// 0D tensor -> 0D tensor
|
||||
%tensor_ptr_0d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<!tt.ptr<f32>>
|
||||
%tensor_i32_0d = tt.splat %scalar_i32 : (i32) -> tensor<i32>
|
||||
// CHECK: tensor<!tt.ptr<f32>>
|
||||
%1 = tt.addptr %tensor_ptr_0d, %tensor_i32_0d : tensor<!tt.ptr<f32>>
|
||||
|
||||
// 1D tensor -> 1D tensor
|
||||
%tensor_ptr_1d = tt.splat %scalar_ptr : (!tt.ptr<f32>) -> tensor<16x!tt.ptr<f32>>
|
||||
%tensor_i32_1d = tt.splat %scalar_i32 : (i32) -> tensor<16xi32>
|
||||
// CHECK: tensor<16x!tt.ptr<f32>>
|
||||
%2 = tt.addptr %tensor_ptr_1d, %tensor_i32_1d : tensor<16x!tt.ptr<f32>>
|
||||
return
|
||||
}
|
@@ -82,9 +82,9 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%3 = tt.splat %1 : (i32) -> tensor<256xi32, #blocked0>
|
||||
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
|
||||
%5 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.getelementptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
|
||||
// CHECK: ld.global.v4.b32
|
||||
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
@@ -92,7 +92,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
|
||||
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.getelementptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
|
||||
// Store 4 elements to global
|
||||
// CHECK: st.global.b32.v4
|
||||
@@ -104,7 +104,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
|
||||
|
||||
|
||||
// TODO: Add a testcase to verify the optimization when ptr of the LoadOp
|
||||
// is from a GEP with const idx
|
||||
// is from an addptr with const idx
|
||||
|
||||
// -----
|
||||
|
||||
@@ -187,11 +187,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: basic_gep
|
||||
func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK-LABEL: basic_addptr
|
||||
func @basic_addptr(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
|
||||
// CHECK: llvm.getelementptr
|
||||
// CHECK: llvm.getelementptr
|
||||
%0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
%0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@@ -22,8 +22,8 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
|
||||
return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_combine_gep_pattern
|
||||
func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
// CHECK-LABEL: @test_combine_addptr_pattern
|
||||
func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%off0 = arith.constant 10 : i32
|
||||
%off1 = arith.constant 15 : i32
|
||||
|
||||
@@ -37,9 +37,9 @@ func @test_combine_gep_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
|
||||
%idx0 = tt.broadcast %off0 : (i32) -> tensor<8xi32>
|
||||
%idx1 = tt.broadcast %off1 : (i32) -> tensor<8xi32>
|
||||
|
||||
// CHECK-NEXT: %1 = tt.getelementptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.getelementptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.getelementptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
// CHECK-NEXT: %1 = tt.addptr %[[tmp0]], %[[cst]] : tensor<8x!tt.ptr<f32>>
|
||||
%ptr0 = tt.addptr %base_, %idx0 : tensor<8x!tt.ptr<f32>>
|
||||
%ptr1 = tt.addptr %ptr0, %idx1 : tensor<8x!tt.ptr<f32>>
|
||||
|
||||
return %ptr1 : tensor<8x!tt.ptr<f32>>
|
||||
}
|
||||
|
@@ -11,9 +11,9 @@ module {
|
||||
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32>
|
||||
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32>
|
||||
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%10 = tt.addptr %9, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%11 = tt.broadcast %cst : (f32) -> tensor<256xf32>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
@@ -31,13 +31,13 @@ module {
|
||||
%22 = arith.addf %19, %21 : tensor<256xf32>
|
||||
%23 = arith.addf %arg7, %22 : tensor<256xf32>
|
||||
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%25 = tt.getelementptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
|
||||
%25 = tt.addptr %arg8, %24 : tensor<256x!tt.ptr<f32>>
|
||||
%26 = tt.broadcast %arg5 : (i32) -> tensor<256xi32>
|
||||
%27 = tt.getelementptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
|
||||
%27 = tt.addptr %arg9, %26 : tensor<256x!tt.ptr<f32>>
|
||||
scf.yield %23, %25, %27 : tensor<256xf32>, tensor<256x!tt.ptr<f32>>, tensor<256x!tt.ptr<f32>>
|
||||
}
|
||||
%16 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.getelementptr %16, %4 : tensor<256x!tt.ptr<f32>>
|
||||
%17 = tt.addptr %16, %4 : tensor<256x!tt.ptr<f32>>
|
||||
tt.store %17, %15#0, %6 : tensor<256xf32>
|
||||
return
|
||||
}
|
||||
@@ -57,9 +57,9 @@ module {
|
||||
// %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %8 = tt.getelementptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %8 = tt.addptr %7, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %10 = tt.getelementptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %10 = tt.addptr %9, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %11 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %12 = arith.index_cast %arg4 : i32 to index
|
||||
// %13 = arith.cmpi slt, %c0, %12 : index
|
||||
@@ -72,9 +72,9 @@ module {
|
||||
// %20 = arith.andi %6, %19 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %21 = triton_gpu.copy_async %10, %20, %18 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %22 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %23 = tt.getelementptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %23 = tt.addptr %8, %22, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %25 = tt.getelementptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %25 = tt.addptr %10, %24, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %26 = arith.cmpi slt, %c32, %12 : index
|
||||
// %27 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %28 = tt.broadcast %26 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -85,9 +85,9 @@ module {
|
||||
// %33 = arith.andi %6, %32 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %34 = triton_gpu.copy_async %25, %33, %31 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %35 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %36 = tt.getelementptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %36 = tt.addptr %23, %35, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %37 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %38 = tt.getelementptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %38 = tt.addptr %25, %37, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %39 = arith.cmpi slt, %c64, %12 : index
|
||||
// %40 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %41 = tt.broadcast %39 : (i1) -> tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -98,16 +98,16 @@ module {
|
||||
// %46 = arith.andi %6, %45 : tensor<256xi1, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %47 = triton_gpu.copy_async %38, %46, %44 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %48 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %49 = tt.getelementptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %49 = tt.addptr %36, %48, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %50 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %51 = tt.getelementptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %51 = tt.addptr %38, %50, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %52:12 = scf.for %arg6 = %c0 to %12 step %c32 iter_args(%arg7 = %11, %arg8 = %8, %arg9 = %10, %arg10 = %17, %arg11 = %30, %arg12 = %43, %arg13 = %21, %arg14 = %34, %arg15 = %47, %arg16 = %51, %arg17 = %49, %arg18 = %c64) -> (tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index) {
|
||||
// %55 = arith.addf %arg10, %arg13 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %56 = arith.addf %arg7, %55 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %57 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %58 = tt.getelementptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %58 = tt.addptr %arg8, %57, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %59 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %60 = tt.getelementptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %60 = tt.addptr %arg9, %59, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %61 = arith.addi %arg18, %c32 : index
|
||||
// %62 = arith.cmpi slt, %61, %12 : index
|
||||
// %63 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
@@ -117,13 +117,13 @@ module {
|
||||
// %67 = tt.broadcast %cst : (f32) -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %68 = triton_gpu.copy_async %arg16, %65, %67 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">> -> tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %69 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %70 = tt.getelementptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %70 = tt.addptr %arg17, %69, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %71 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %72 = tt.getelementptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %72 = tt.addptr %arg16, %71, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// scf.yield %56, %58, %60, %arg11, %arg12, %66, %arg14, %arg15, %68, %72, %70, %61 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>, index
|
||||
// }
|
||||
// %53 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %54 = tt.getelementptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// %54 = tt.addptr %53, %4, : tensor<256x!tt.ptr<f32>, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// tt.store %54, %52#0, %6 : tensor<256xf32, #triton_gpu<"coalesced encoding<threadTileSize = 1, blockTileSize = 32, order = 0>">>
|
||||
// return
|
||||
// }
|
||||
|
@@ -28,20 +28,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
|
||||
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%19 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, #blocked1>
|
||||
tt.store %18, %19, %cst : tensor<64x64xf32, #blocked1>
|
||||
return
|
||||
|
@@ -70,7 +70,7 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
||||
// CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
|
||||
// CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
|
||||
// CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
|
||||
// CHECK: %8 = tt.getelementptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: %8 = tt.addptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
|
||||
// CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
|
||||
@@ -78,13 +78,13 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
||||
// CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
|
||||
// CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
|
||||
// CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
|
||||
// CHECK: %16 = tt.getelementptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
|
||||
// CHECK: %16 = tt.addptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
|
||||
// CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
|
||||
// CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
|
||||
// CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
|
||||
// CHECK: %20 = tt.getelementptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: %20 = tt.addptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
|
||||
// CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
|
||||
// CHECK: %22 = tt.addptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
|
||||
// CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
|
||||
// CHECK: tt.store %22, %23, %cst_1 : tensor<64x64xf32, [[col_layout]]>
|
||||
// CHECK: return
|
||||
@@ -95,20 +95,20 @@ func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%12 = tt.addptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
|
||||
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%18 = tt.addptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
|
||||
@@ -127,7 +127,7 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
||||
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
|
||||
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = tt.getelementptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK-NEXT: {{.*}} = tt.addptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]>
|
||||
@@ -143,12 +143,12 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
|
||||
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%5 = tt.addptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
|
||||
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%10 = tt.addptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
|
||||
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
|
||||
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
|
||||
@@ -156,17 +156,17 @@ func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %ar
|
||||
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3>
|
||||
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
|
||||
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
|
||||
%29 = tt.getelementptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%29 = tt.addptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
}
|
||||
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.getelementptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%13 = tt.addptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
|
||||
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
|
||||
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
|
||||
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
|
||||
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
|
||||
%19 = tt.getelementptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%19 = tt.addptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
|
||||
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
|
||||
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
|
||||
@@ -190,16 +190,16 @@ func @vecadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f3
|
||||
%9 = arith.addi %6, %7 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%10 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%11 = arith.addi %4, %5 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%12 = tt.getelementptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%12 = tt.addptr %8, %9 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%14 = triton_gpu.convert_layout %13 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%15 = tt.getelementptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%15 = tt.addptr %10, %11 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%16 = tt.load %15 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%17 = triton_gpu.convert_layout %16 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%18 = arith.addf %14, %17 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>
|
||||
%19 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%20 = arith.addi %2, %3 : tensor<256xi32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%21 = tt.getelementptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%21 = tt.addptr %19, %20 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
%22 = triton_gpu.convert_layout %18 : (tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>>) -> tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
tt.store %21, %22 : tensor<256xf32, #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>>
|
||||
return
|
||||
|
@@ -55,8 +55,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
@@ -112,8 +112,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f
|
||||
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
|
||||
%next_a_ptr = tt.getelementptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@ func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %A :
|
||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
|
||||
%next_b_ptr = tt.getelementptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>
|
||||
scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
|
||||
}
|
||||
return
|
||||
|
@@ -46,7 +46,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%31 = tt.broadcast %29 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%32 = arith.addi %30, %31 : tensor<64x64xi32>
|
||||
%33 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%34 = tt.getelementptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
|
||||
%34 = tt.addptr %33, %32 : tensor<64x64x!tt.ptr<f32>>
|
||||
%35 = tt.expand_dims %23 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%36 = tt.splat %arg8 : (i32) -> tensor<64x1xi32>
|
||||
%37 = arith.muli %35, %36 : tensor<64x1xi32>
|
||||
@@ -57,7 +57,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%42 = tt.broadcast %40 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%43 = arith.addi %41, %42 : tensor<64x64xi32>
|
||||
%44 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%45 = tt.getelementptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
|
||||
%45 = tt.addptr %44, %43 : tensor<64x64x!tt.ptr<f32>>
|
||||
%46 = arith.index_cast %arg5 : i32 to index
|
||||
%47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>) {
|
||||
%76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32>
|
||||
@@ -66,10 +66,10 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%79 = arith.addf %arg13, %78 : tensor<64x64xf32>
|
||||
%80 = arith.muli %arg7, %c64_i32 : i32
|
||||
%81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
|
||||
%82 = tt.getelementptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
|
||||
%82 = tt.addptr %arg14, %81 : tensor<64x64x!tt.ptr<f32>>
|
||||
%83 = arith.muli %arg8, %c64_i32 : i32
|
||||
%84 = tt.splat %83 : (i32) -> tensor<64x64xi32>
|
||||
%85 = tt.getelementptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
|
||||
%85 = tt.addptr %arg15, %84 : tensor<64x64x!tt.ptr<f32>>
|
||||
scf.yield %79, %82, %85 : tensor<64x64xf32>, tensor<64x64x!tt.ptr<f32>>, tensor<64x64x!tt.ptr<f32>>
|
||||
}
|
||||
%48 = arith.muli %12, %c64_i32 : i32
|
||||
@@ -90,7 +90,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
|
||||
%63 = tt.broadcast %61 : (tensor<1x64xi32>) -> tensor<64x64xi32>
|
||||
%64 = arith.addi %62, %63 : tensor<64x64xi32>
|
||||
%65 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x64x!tt.ptr<f32>>
|
||||
%66 = tt.getelementptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
|
||||
%66 = tt.addptr %65, %64 : tensor<64x64x!tt.ptr<f32>>
|
||||
%67 = tt.expand_dims %51 {axis = 1 : i32} : (tensor<64xi32>) -> tensor<64x1xi32>
|
||||
%68 = tt.splat %arg3 : (i32) -> tensor<64x1xi32>
|
||||
%69 = arith.cmpi slt, %67, %68 : tensor<64x1xi32>
|
||||
|
Reference in New Issue
Block a user