[BACKEND] Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp & GEPOp and bugfix for SplatOp, StoreOp, FuncOp (#60)

Add backend support of arith::AddIOp, arith::AddFOp, GetProgramIdOp, GEPOp and bugfix for SplatOp, StoreOp, FuncOp

Co-authored-by: gzhu <gzhu@nvidia.com>
This commit is contained in:
goostavz
2022-08-18 20:46:45 +08:00
committed by GitHub
parent b1673caaf6
commit fc58250a06
9 changed files with 270 additions and 122 deletions

View File

@@ -197,6 +197,7 @@ target_link_libraries(triton
MLIRSupport
MLIRTargetLLVMIRExport
MLIRExecutionEngine
MLIRNVVMToLLVMIRTranslation
)
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

View File

@@ -54,5 +54,6 @@ target_link_libraries(triton-translate PRIVATE
MLIRExecutionEngine
MLIRTransformUtils
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
)
mlir_check_all_link_libraries(triton-translate)

View File

@@ -61,7 +61,7 @@ std::string PtxInstr::Operand::dump() const {
if (repr)
return repr(idx);
if (!isList())
return llvm::formatv("%{0}", idx);
return llvm::formatv("${0}", idx);
llvm::SmallVector<std::string> oprs;
for (auto *opr : list)
oprs.push_back(opr->dump());
@@ -72,7 +72,7 @@ PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
StringRef constraint, int off) {
auto *opr = newOperand(addr, constraint);
opr->repr = [off](int idx) -> std::string {
return llvm::formatv("[ %{0} + {1} ]", idx, off);
return llvm::formatv("[ ${0} + {1} ]", idx, off);
};
return opr;

View File

@@ -46,20 +46,10 @@ template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
// The following code are borrowed from mlir project including the following
// functions or classes:
// - filterFuncAttributes
// - ConvertOpToLLVMPattern
// - FuncOpConversion
//
// The code are hidden in the CPP files in MLIR repo, and we can't call them
// directly. I found such code snippets are refactored and added to LLVMCommon
// in the latest MLIR code, but the v14.0.0 version currently used in Triton
// doesn't contain the code.
// FuncOpConversion/FuncOpConversionBase is borrowed from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276
// since it is not exposed on header files in mlir v14
// TODO(Superjomn) Remove the code when mlir v15.0 is included.
//
// The original code:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L219
// All the rights are reserved by LLVM community.
/// Only retain those attributes that are not constructed by
@@ -79,6 +69,12 @@ static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
}
}
/// Helper function for wrapping all attributes into a single DictionaryAttr
static auto wrapAsStructAttrs(OpBuilder &b, ArrayAttr attrs) {
// Wrap the array of per-result attribute dictionaries under LLVM's
// struct-attrs key, producing one DictionaryAttr. Used below so that a
// multi-result func's result attributes survive conversion to an LLVM
// function returning a single struct.
return DictionaryAttr::get(
b.getContext(), b.getNamedAttr(LLVM::getStructAttrsAttrName(), attrs));
}
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
protected:
using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
@@ -90,25 +86,34 @@ protected:
ConversionPatternRewriter &rewriter) const {
// Convert the original function arguments. They are converted using the
// LLVMTypeConverter provided to this legalization pattern.
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("std.varargs");
auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = getTypeConverter()->convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
if (!llvmType)
return nullptr;
// Propagate argument attributes to all converted arguments obtained after
// converting a given original argument.
// Propagate argument/result attributes to all converted arguments/result
// obtained after converting a given original argument/result.
SmallVector<NamedAttribute, 4> attributes;
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true,
attributes);
if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
assert(!resAttrDicts.empty() && "expected array to be non-empty");
auto newResAttrDicts =
(funcOp.getNumResults() == 1)
? resAttrDicts
: rewriter.getArrayAttr(
{wrapAsStructAttrs(rewriter, resAttrDicts)});
attributes.push_back(rewriter.getNamedAttr(
FunctionOpInterface::getResultDictAttrName(), newResAttrDicts));
}
if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
SmallVector<Attribute, 4> newArgAttrs(
llvmType.cast<LLVM::LLVMFunctionType>().getNumParams());
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
auto mapping = result.getInputMapping(i);
assert(mapping.hasValue() &&
"unexpected deletion of function argument");
assert(mapping && "unexpected deletion of function argument");
for (size_t j = 0; j < mapping->size; ++j)
newArgAttrs[mapping->inputNo + j] = argAttrDicts[i];
}
@@ -136,37 +141,15 @@ protected:
}
linkage = attr.getLinkage();
}
auto oldArgs = funcOp.getArguments();
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
&result)))
return nullptr;
// Convert argument
llvm::DenseMap<Value, Value> argMap;
for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
Value oldArg = oldArgs[i];
Value newArg = newFuncOp.getArgument(i);
argMap.try_emplace(oldArg, newArg);
}
newFuncOp.getBody().walk([&](Operation *op) {
// Convert the function argument types, e.g, from !tt.ptr<fp16> to
// ptr<fp16>
for (int i = 0; i < op->getNumOperands(); i++) {
auto arg = op->getOperand(i);
auto it = argMap.find(arg);
if (it != argMap.end())
op->setOperand(i, it->second);
}
});
return newFuncOp;
}
};
@@ -245,8 +228,13 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
// Number of tensor elements owned by each thread under a blocked layout.
//
// Each dimension i of the tensor is distributed across
// threadsPerWarp[i] * warpsPerCTA[i] lanes; we ceil-divide per dimension so
// shapes that do not divide evenly are still fully covered (each thread owns
// at least one element per dimension), then multiply the per-dimension
// counts together.
//
// Fix: the obsolete whole-shape formula
//   product(shape) / (product(threadsPerWarp) * product(warpsPerCTA))
// that preceded this body made the per-dimension computation unreachable
// dead code (diff residue); it is removed here.
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
                                  ArrayRef<int64_t> shape) {
  size_t rank = shape.size();
  SmallVector<unsigned> elemsPerThreadPerDim(rank);
  for (size_t i = 0; i < rank; ++i) {
    // Total lanes covering dimension i.
    unsigned t = layout.getThreadsPerWarp()[i] * layout.getWarpsPerCTA()[i];
    // Ceil division: dimensions smaller than the lane count still yield 1.
    elemsPerThreadPerDim[i] = (shape[i] + t - 1) / t;
  }
  return product<unsigned>(elemsPerThreadPerDim);
}
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
@@ -257,7 +245,7 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType, Type elemPtrPtrType) {
Type structType) {
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (auto v : llvm::enumerate(resultVals)) {
llvmStruct = rewriter.create<LLVM::InsertValueOp>(
@@ -513,10 +501,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);
auto llElemPtrPtrTy =
LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
auto llStruct =
getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
auto llStruct = getStructFromElements(loc, elems, rewriter, structTy);
return llStruct;
}
@@ -529,29 +514,7 @@ struct SplatOpConversion
matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = op->getOperand(0);
LLVM::ConstantOp arithConstantOp;
if (src.getDefiningOp() &&
(arithConstantOp =
llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
Value constant;
auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();
assert(values.size() == 1);
Attribute val;
if (type::isInt(src.getType())) {
val = values.getValues<IntegerAttr>()[0];
} else if (type::isFloat(src.getType())) {
val = values.getValues<FloatAttr>()[0];
} else {
llvm::errs() << "Constant op type not supported";
return failure();
}
src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
}
auto src = adaptor.src();
auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
getTypeConverter(), rewriter, loc);
rewriter.replaceOp(op, {llStruct});
@@ -618,12 +581,15 @@ struct StoreOpConversion
Value mask = op.mask();
Value value = op.value();
Value llPtr = adaptor.ptr(); // should be LLVM ops
Value llPtr = adaptor.ptr();
Value llMask = adaptor.mask();
Value llValue = adaptor.value();
Type valueElemTy = getTypeConverter()->convertType(
value.getType().cast<RankedTensorType>().getElementType());
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
if (!valueTy)
return failure();
Type valueElemTy =
getTypeConverter()->convertType(valueTy.getElementType());
MLIRContext *ctx = rewriter.getContext();
auto loc = op->getLoc();
@@ -662,6 +628,7 @@ struct StoreOpConversion
auto [maskLayout, maskNumElems] = getLayout(mask);
auto [valueLayout, valueNumElems] = getLayout(value);
auto ptrElems = getLLVMElems(mask, llPtr, maskLayout);
auto valueElems = getLLVMElems(value, llValue, valueLayout);
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
assert(valueElems.size() == maskElems.size());
@@ -718,17 +685,8 @@ struct StoreOpConversion
const int numVecs = ptrNumElems / vec;
for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {
size_t in_off{};
auto ptrProducer = llPtr.getDefiningOp();
auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);
if (in_gep) {
auto indices = in_gep.getIndices();
auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
in_off =
cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
ptr = cst ? in_gep.getBase() : in_gep;
}
// TODO: optimization when ptr is GEP with constant offset
size_t in_off = 0;
// pack sub-words (< 32/64bits) into words
// each load has width min(nbits*vec, 32/64)
@@ -747,7 +705,7 @@ struct StoreOpConversion
const bool hasL2EvictPolicy = false;
PtxIOInstr asmStoreInstr("st");
asmStoreInstr.predicate(llMask, "b");
asmStoreInstr.predicate(maskElems[vecIdx], "b");
asmStoreInstr.global().v(width).b(nWords);
llvm::SmallVector<std::string> asmArgs;
@@ -755,7 +713,8 @@ struct StoreOpConversion
Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = VectorType::get(wordNElems, valueElemTy);
auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
auto *asmAddr =
asmStoreInstr.newAddrOperand(ptrElems[vecIdx], "l", in_off);
auto *asmArgList = asmStoreInstr.newList();
for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
// llWord is a width-len composition
@@ -800,9 +759,8 @@ struct StoreOpConversion
LLVM::AsmDialect::AD_ATT), // asm_dialect
ArrayAttr::get(ctx, {}) // operand_attrs
);
rewriter.replaceOp(op, inlineAsm.getRes());
}
rewriter.eraseOp(op);
return success();
}
@@ -1135,6 +1093,10 @@ struct LoadOpConversion
// finally call inline ASM
// ---
SmallVector<Value> args = {pred, ptr};
for (Value v : others) {
args.push_back(v);
}
// TODO: if (has_l2_evict_policy)
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
auto inlineAsmOp = rewriter.create<LLVM::InlineAsmOp>(
@@ -1177,6 +1139,95 @@ struct LoadOpConversion
}
};
// Lowers triton::GetProgramIdOp to the GPU dialect's block-id op
// (blockIdx on NVIDIA targets), then bridges the result to the LLVM index
// type with an UnrealizedConversionCastOp; a later canonicalization pass is
// expected to fold that cast away.
struct GetProgramIdOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::GetProgramIdOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GetProgramIdOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
// NOTE(review): only Dimension::x is read here; the op's `axis` attribute
// is never consulted — confirm axis values other than 0 are rejected or
// handled upstream.
Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
auto llvmIndexTy = getTypeConverter()->getIndexType();
// Cast the index-typed block id into the converter's LLVM index type.
rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
op, TypeRange{llvmIndexTy}, ValueRange{blockId});
return success();
}
};
// Lowers triton::GEPOp over a tensor of pointers: unpacks the per-thread
// pointer and offset elements from their LLVM structs, emits one
// LLVM::GEPOp per element, and repacks the results into a struct value.
struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::GEPOp>::ConvertTritonGPUOpToLLVMPattern;
LogicalResult
matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
// NOTE(review): resultTy and resultLayout come from dyn_cast and are
// dereferenced without null checks — confirm scalar GEPs and non-blocked
// layouts cannot reach this pattern.
auto resultLayout =
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
// Result struct holds `elems` converted element values for this thread.
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
auto offsets =
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] =
rewriter.create<LLVM::GEPOp>(loc, elemTy, ptrs[i], offsets[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
// Generic elementwise lowering: converts a tensor-typed SourceOp (e.g.
// arith::AddIOp / arith::AddFOp) into one DestOp (e.g. LLVM::AddOp /
// LLVM::FAddOp) per element owned by the current thread, repacking the
// results into an LLVM struct. Scalar (non-tensor) ops are declined so the
// stock arith-to-LLVM patterns handle them.
template <typename SourceOp, typename DestOp>
class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
// ArithmeticToLLVM will handle the lowering of scalar ArithOps
if (!resultTy)
return failure();
Location loc = op->getLoc();
// NOTE(review): resultLayout may be null for non-blocked encodings yet is
// passed straight to getElemsPerThread — confirm only blocked layouts can
// reach this pattern.
auto resultLayout = resultTy.getEncoding()
.template dyn_cast<TritonGPUBlockedEncodingAttr>();
auto resultShape = resultTy.getShape();
unsigned elems = getElemsPerThread(resultLayout, resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
// Unpack per-thread operand elements, combine pairwise, then repack.
auto lhss =
this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
auto rhss =
this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;
@@ -1221,14 +1272,20 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &analysis,
PatternBenefit benefit = 1) {
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<BinaryOpConversion<arith::AddIOp, LLVM::AddOp>>(typeConverter,
benefit);
patterns.add<BinaryOpConversion<arith::AddFOp, LLVM::FAddOp>>(typeConverter,
benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
patterns.add<GEPOpConversion>(typeConverter, benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<LoadOpConversion>(typeConverter, benefit);
patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
}

View File

@@ -8,8 +8,10 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/driver/llvm.h"
#include "llvm/IR/Constants.h"
@@ -82,7 +84,8 @@ std::unique_ptr<llvm::Module>
translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
auto context = module->getContext();
DialectRegistry registry;
registerLLVMDialectTranslation(registry);
mlir::registerLLVMDialectTranslation(registry);
mlir::registerNVVMDialectTranslation(registry);
context->appendDialectRegistry(registry);
llvm::DenseMap<llvm::StringRef, NVVMMetadata> nvvmMetadata;
@@ -123,6 +126,8 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
applyPassManagerCLOptions(pm);
pm.addPass(createConvertTritonGPUToLLVMPass());
// Canonicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";

View File

@@ -1504,7 +1504,7 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUVerifier());
})
.def("triton_gpu_to_llvm", [](mlir::PassManager &self) {
.def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
});
}

View File

@@ -0,0 +1,30 @@
import triton
import triton.language as tl
NUM_WARPS = 4
# triton kernel
@triton.jit
def kernel(x_ptr, stride_xn,
y_ptr, stride_yn,
z_ptr, stride_zn,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
print(ret)
# TODO: base class for python end2end tests,
# runtime execution, correctness comparison etc.

View File

@@ -19,6 +19,8 @@ func @test_splat(%ptr: !tt.ptr<f32>) {
return
}
// -----
func @test_store_splat(%ptr: !tt.ptr<f32>) {
%ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
%a = arith.constant 1.0 : f32
@@ -27,9 +29,8 @@ func @test_store_splat(%ptr: !tt.ptr<f32>) {
%vs = tt.splat %a : (f32) -> tensor<128xf32>
%mask = tt.splat %true : (i1) -> tensor<128xi1>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
// CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$0 st.global.v32.b1 [ $1 + 0 ], { $2 };",
// CHECK-SAME: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>
return

View File

@@ -112,28 +112,81 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
}
}
// #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
// #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
// module attributes {"triton_gpu.num-warps" = 4 : i32} {
// func @debut_kernel(%lb : index, %A : !tt.ptr<f32>, %B : !tt.ptr<f32>, %C : !tt.ptr<f32>) {
// %cst = arith.constant dense<true> : tensor<256xi1, #blocked0>
// %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
// %cst_1 = arith.constant dense<true> : tensor<1024x256xi1, #blocked1>
// %cst_2 = arith.constant dense<true> : tensor<256x2048xi1, #blocked2>
// %a_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
// %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
// %4 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<1x256xf32,#blocked1>
// %5 = tt.broadcast %4 : (tensor<1x256xf32,#blocked1>) -> tensor<1024x256xf32, #blocked1>
// %6 = tt.view %1 : (tensor<256xf32, #blocked0>) -> tensor<256x1xf32,#blocked2>
// %7 = tt.broadcast %6 : (tensor<256x1xf32,#blocked2>) -> tensor<256x2048xf32, #blocked2>
// %b_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<1024x256x!tt.ptr<f32>, #blocked1>
// %c_ptr_init = tt.splat %A : (!tt.ptr<f32>) -> tensor<256x2048x!tt.ptr<f32>, #blocked2>
// tt.store %b_ptr_init, %5, %cst_1, : tensor<1024x256xf32, #blocked1>
// tt.store %c_ptr_init, %7, %cst_2, : tensor<256x2048xf32, #blocked2>
// return
// }
// }
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addf
func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) {
// CHECK: llvm.fadd
// CHECK: llvm.fadd
%1 = arith.addf %arg0, %arg1 : tensor<256xf32,#blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_addi
func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.add
// CHECK: llvm.add
%1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0>
return
}
}
// -----
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
func @basic_program_id() {
// CHECK: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_gep
func @basic_gep(%arg0 : tensor<256x!tt.ptr<f32>,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) {
// CHECK: llvm.getelementptr
// CHECK: llvm.getelementptr
%0 = tt.getelementptr %arg0, %arg1 : tensor<256x!tt.ptr<f32>, #blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: basic_splat
func @basic_splat(%ptr: !tt.ptr<f32>) {
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
%0 = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>,#blocked0>
return
}
}
// -----
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
// CHECK-SAME: st.global.v32.b1 [ ${{.*}} + 0 ], { ${{.*}} };", "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (i1, !llvm.ptr<f32, 1>, i32) -> !llvm.struct<()>
tt.store %ptrs, %vals, %mask, {} : tensor<256xf32, #blocked0>
return
}
}