[BACKEND] Add LLVM-translation for store and splat ops (#47)
39 include/triton/Conversion/MLIRTypes.h Normal file
@@ -0,0 +1,39 @@
#ifndef TRITON_CONVERSION_MLIR_TYPES_H_
#define TRITON_CONVERSION_MLIR_TYPES_H_

#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// This file defines shorthands for some commonly used MLIR types.
namespace mlir {
namespace triton {
namespace type {

// Integer types
Type i32Ty(MLIRContext *ctx) {
  return IntegerType::get(ctx, 32, IntegerType::Signed);
}
Type i8Ty(MLIRContext *ctx) {
  return IntegerType::get(ctx, 8, IntegerType::Signed);
}
Type u32Ty(MLIRContext *ctx) {
  return IntegerType::get(ctx, 32, IntegerType::Signless);
}
Type u1Ty(MLIRContext *ctx) { return IntegerType::get(ctx, 1); }

// Float types
Type f16Ty(MLIRContext *ctx) { return FloatType::getF16(ctx); }
Type f32Ty(MLIRContext *ctx) { return FloatType::getF32(ctx); }
Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }

static bool isFloat(Type type) {
  return type.isF32() || type.isF64() || type.isF16() || type.isF128();
}

static bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

} // namespace type
} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_MLIR_TYPES_H_
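A quick sketch of how these helpers are intended to be used (editorial example, not part of the commit; assumes an MLIRContext is in scope):

    MLIRContext ctx;
    Type f16 = type::f16Ty(&ctx);
    assert(type::isFloat(f16) && !type::isInt(f16)); // f16/f32/f64/f128 are floats
    assert(type::isInt(type::i32Ty(&ctx)));          // signed i32 counts as int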
191 include/triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h Normal file
@@ -0,0 +1,191 @@
#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_

#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <functional>
#include <memory>
#include <string>

namespace mlir {
namespace triton {
using llvm::StringRef;

// TODO(Superjomn) Move to a global utility file?
std::string strJoin(llvm::ArrayRef<std::string> strs,
                    llvm::StringRef delimiter);

// A helper for building a single inline ASM instruction. PtxInstr is a thin
// encapsulation that makes the ASM code emitted for the MLIR LLVM dialect
// clearer; it reduces the need to mix string manipulation with C++ if-else
// code.
// Usage:
// To build: asm("mul.lo.u32 %0, %1, %2;" : "=r"(i) : "r"(j), "k"(k));
//
//   PtxInstr mulr("mul");
//   mulr.o("lo").o("u32");
//   mulr.addOperand(valueI, "=r"); // %0 is bound to valueI
//   mulr.addOperand(valueJ, "r");  // %1 is bound to valueJ
//   mulr.addOperand(valueK, "k");  // %2 is bound to valueK
//
//   mulr.getConstrains()  // get "=r,r,k"
//   mulr.getAllMlirArgs() // get {valueI, valueJ, valueK}
//
// TODO(Superjomn) Add multi-line ASM code support and register support later.
struct PtxInstr {
  explicit PtxInstr(const std::string &name) { o(name); }

  struct Operand {
    std::string constraint;
    Value value;
    int idx{-1};
    llvm::SmallVector<Operand *> list;
    std::function<std::string(int idx)> repr;

    // Default constructor; used for list operands.
    Operand() = default;
    Operand(Value value, StringRef constraint)
        : value(value), constraint(constraint) {}

    bool isList() const { return !value; }

    Operand *listAppend(Operand *arg) {
      list.push_back(arg);
      return this;
    }

    std::string dump() const;
  };

  // Create a new operand. It is not added to the operand list.
  // @value: the MLIR value bound to this operand.
  // @constraint: the ASM operand constraint, e.g. "=r"
  // @formatter: an extra formatter to represent this operand in the ASM code;
  // the default is "%{0}".format(operand.idx).
  Operand *newOperand(mlir::Value value, StringRef constraint,
                      std::function<std::string(int idx)> formatter = nullptr);

  // Append an operand to the instruction's operand list.
  Operand *addOperand(Operand *opr) {
    assert(std::find(argsInOrder.begin(), argsInOrder.end(), opr) ==
           argsInOrder.end());
    argsInOrder.push_back(opr);
    return opr;
  }

  // Create an operand and add it to the instruction's operand list.
  Operand *addOperand(mlir::Value value, StringRef constraint) {
    auto *opr = newOperand(value, constraint);
    return addOperand(opr);
  }

  // Prefix the instruction with a predicate.
  PtxInstr &predicate(mlir::Value value, StringRef constraint) {
    pred = newOperand(value, constraint);
    return *this;
  }

  // Append a suffix to the instruction.
  // e.g. PtxInstr("add").o("s32") yields an add.s32.
  // The predicate tells whether to apply the suffix, so no if-else code is
  // needed. e.g. `PtxInstr("add").o("s32", isS32).o("u32", !isS32);` yields
  // an `add.s32` if isS32 is true, and an `add.u32` otherwise.
  PtxInstr &o(const std::string &suffix, bool predicate = true) {
    if (predicate)
      instrParts.push_back(suffix);
    return *this;
  }

  PtxInstr &addListOperation(llvm::ArrayRef<Operand *> list) {
    auto *opr = newList();
    for (auto *v : list)
      opr->listAppend(v);
    addOperand(opr);
    return *this;
  }

  // Create a list of operands.
  Operand *newList() {
    argArchive.emplace_back(std::make_unique<Operand>());
    return argArchive.back().get();
  }

  std::string dump() const;

  llvm::SmallVector<Operand *, 4> getArgList() const;
  llvm::SmallVector<Operand *, 4> getAllArgs() const {
    llvm::SmallVector<Operand *, 4> res;
    for (auto &x : argArchive)
      if (!x->isList())
        res.push_back(x.get());
    return res;
  }

  std::string getConstrains() const {
    auto args = getAllArgs();
    llvm::SmallVector<std::string, 4> argReprs;
    for (auto arg : args)
      argReprs.push_back(arg->constraint);
    return strJoin(argReprs, ",");
  }

  llvm::SmallVector<Value, 4> getAllMlirArgs() const {
    llvm::SmallVector<Value, 4> res;
    for (auto &arg : argArchive) {
      if (!arg->isList())
        res.push_back(arg->value);
    }
    return res;
  }

protected:
  Operand *pred{};
  int oprCounter{};
  llvm::SmallVector<std::string, 4> instrParts;
  llvm::SmallVector<std::unique_ptr<Operand>, 6> argArchive;
  llvm::SmallVector<Operand *> argsInOrder;
  std::string argStr;
};

// A helper for PTX ld/st instructions.
// Usage:
//   PtxIOInstr store("st");
//   store.predicate(pValue, "b");  // @%0
//   store.global().v(32).b(1);     // st.global.v32.b1
//   store.addAddr(addrValue, "l", off);
struct PtxIOInstr : public PtxInstr {
  PtxIOInstr(const std::string &name) : PtxInstr(name) {}

  // Add the ".global" suffix to the instruction.
  PtxIOInstr &global(bool predicate = true) {
    o("global", predicate);
    return *this;
  }

  // Add a ".v" suffix to the instruction.
  PtxIOInstr &v(int vecWidth, bool predicate = true) {
    if (vecWidth > 1) {
      o(llvm::formatv("v{0}", vecWidth), predicate);
    }
    return *this;
  }

  // Add a ".b" suffix to the instruction.
  PtxIOInstr &b(int width) {
    o(llvm::formatv("b{0}", width));
    return *this;
  }

  PtxIOInstr &addAddr(mlir::Value addr, StringRef constraint, int off = 0) {
    auto *operand = newAddrOperand(addr, constraint, off);
    addOperand(operand);
    return *this;
  }

  Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0);
};

} // namespace triton
} // namespace mlir

#endif // TRITON_CONVERSION_TRITON_GPU_TO_LLVM_ASM_FORMAT_H_
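To make the builder API concrete, here is a hedged sketch of the mul example from the header comment (valueI, valueJ, and valueK stand for MLIR Values already in scope):

    PtxInstr mulr("mul");
    mulr.o("lo").o("u32");         // instruction parts: mul.lo.u32
    mulr.addOperand(valueI, "=r"); // %0 binds to valueI
    mulr.addOperand(valueJ, "r");  // %1 binds to valueJ
    mulr.addOperand(valueK, "k");  // %2 binds to valueK

    mulr.dump();          // "mul.lo.u32 %0, %1, %2;"
    mulr.getConstrains(); // "=r,r,k"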
lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_conversion_library(TritonGPUToLLVM
  TritonGPUToLLVM.cpp
  PtxAsmFormat.cpp

  ADDITIONAL_HEADER_DIRS
  ${PROJECT_SOURCE_DIR}/include/triton/Conversion/TritonGPUToLLVM
81 lib/Conversion/TritonGPUToLLVM/PtxAsmFormat.cpp Normal file
@@ -0,0 +1,81 @@
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace triton {

std::string strJoin(llvm::ArrayRef<std::string> strs,
                    llvm::StringRef delimiter) {
  std::string osStr;
  llvm::raw_string_ostream os(osStr);
  for (size_t i = 0; !strs.empty() && i < strs.size() - 1; i++)
    os << strs[i] << delimiter;
  if (!strs.empty())
    os << strs.back();
  os.flush();
  return osStr;
}
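// e.g. strJoin({"st", "global", "b32"}, ".") == "st.global.b32";
// an empty list yields an empty string.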

std::string PtxInstr::dump() const {
  std::string osStr;
  llvm::raw_string_ostream os(osStr);
  if (pred)
    os << "@" << pred->dump() << " ";

  std::string instrRepr = strJoin(instrParts, ".");

  llvm::SmallVector<std::string, 4> argReprs;
  for (auto *arg : argsInOrder) {
    argReprs.push_back(arg->dump());
  }

  std::string argsRepr = strJoin(argReprs, ", ");

  os << instrRepr << " " << argsRepr << ";";
  os.flush();
  return osStr;
}

llvm::SmallVector<PtxInstr::Operand *, 4> PtxInstr::getArgList() const {
  SmallVector<Operand *> args;
  for (auto *arg : argsInOrder) {
    if (arg->isList())
      args.insert(args.end(), arg->list.begin(), arg->list.end());
    else
      args.push_back(arg);
  }
  return args;
}

PtxInstr::Operand *
PtxInstr::newOperand(mlir::Value value, StringRef constraint,
                     std::function<std::string(int)> formatter) {
  argArchive.emplace_back(std::make_unique<Operand>(value, constraint));
  auto *opr = argArchive.back().get();
  opr->repr = formatter;
  opr->idx = oprCounter++;
  return opr;
}

std::string PtxInstr::Operand::dump() const {
  if (repr)
    return repr(idx);
  if (!isList())
    return llvm::formatv("%{0}", idx);
  llvm::SmallVector<std::string> oprs;
  for (auto *opr : list)
    oprs.push_back(opr->dump());
  return "{ " + strJoin(oprs, ", ") + " }";
}

PtxInstr::Operand *PtxIOInstr::newAddrOperand(mlir::Value addr,
                                              StringRef constraint, int off) {
  auto *opr = newOperand(addr, constraint);
  opr->repr = [off](int idx) -> std::string {
    return llvm::formatv("[ %{0} + {1} ]", idx, off);
  };

  return opr;
}
} // namespace triton
} // namespace mlir
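An end-to-end illustration (editorial sketch; pred, addr, and word stand for MLIR Values in scope) of the sequence StoreOpConversion below drives through this API; the resulting strings match the FileCheck expectations in the new test file:

    PtxIOInstr store("st");
    store.predicate(pred, "b");          // operand %0
    store.global().v(32).b(1);           // suffixes -> st.global.v32.b1
    store.addAddr(addr, "l", /*off=*/0); // operand %1, rendered as [ %1 + 0 ]
    store.addListOperation({store.newOperand(word, "r")}); // { %2 }

    store.dump();          // "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };"
    store.getConstrains(); // "b,l,r"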
lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -1,5 +1,4 @@
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
@@ -10,10 +9,17 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include <string>

using namespace mlir;
using namespace mlir::triton;
@@ -31,8 +37,15 @@ static StringRef getStructAttrsAttrName() { return "llvm.struct_attrs"; }

namespace {

namespace type = mlir::triton::type;

class TritonGPUToLLVMTypeConverter;

// TODO(Superjomn) Move this to a general utilities file.
template <typename Int> size_t product(llvm::ArrayRef<Int> arr) {
  return std::accumulate(arr.begin(), arr.end(), 1, std::multiplies{});
}
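// e.g. product<int64_t>({2, 3, 4}) == 24; an empty array yields 1.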

// The following code is borrowed from the MLIR project, including these
// functions or classes:
// - filterFuncAttributes
@@ -81,7 +94,6 @@ protected:
    TypeConverter::SignatureConversion result(funcOp.getNumArguments());
    auto llvmType = getTypeConverter()->convertFunctionSignature(
        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
    assert(llvmType);
    if (!llvmType)
      return nullptr;

@@ -124,6 +136,8 @@ protected:
      }
      linkage = attr.getLinkage();
    }

    auto oldArgs = funcOp.getArguments();
    auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
        funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
        /*dsoLocal*/ false, attributes);
@@ -134,6 +148,25 @@ protected:
                                           &result)))
      return nullptr;

    // Map each old function argument to its corresponding new argument.
    llvm::DenseMap<Value, Value> argMap;
    for (int i = 0, n = funcOp.getNumArguments(); i < n; i++) {
      Value oldArg = oldArgs[i];
      Value newArg = newFuncOp.getArgument(i);
      argMap.try_emplace(oldArg, newArg);
    }

    newFuncOp.getBody().walk([&](Operation *op) {
      // Rewrite uses of the old function arguments, whose types have been
      // converted, e.g. from !tt.ptr<fp16> to ptr<fp16>.
      for (int i = 0; i < op->getNumOperands(); i++) {
        auto arg = op->getOperand(i);
        auto it = argMap.find(arg);
        if (it != argMap.end())
          op->setOperand(i, it->second);
      }
    });

    return newFuncOp;
  }
};
@@ -143,8 +176,9 @@ protected:
/// information.
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface";
struct FuncOpConversion : public FuncOpConversionBase {
-  FuncOpConversion(LLVMTypeConverter &converter, int numWarps)
-      : FuncOpConversionBase(converter), NumWarps(numWarps) {}
  FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
                   PatternBenefit benefit)
      : FuncOpConversionBase(converter, benefit), NumWarps(numWarps) {}

  LogicalResult
  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
@@ -154,11 +188,11 @@ struct FuncOpConversion : public FuncOpConversionBase {
      return failure();

    auto ctx = funcOp->getContext();
-    auto i32 = IntegerType::get(ctx, 32);
    // Set an attribute for maxntid; it can be used later in LLVM codegen
    // for the `nvvm.annotation` metadata.
-    newFuncOp->setAttr(NVVMMetadataField::MaxNTid,
-                       rewriter.getIntegerAttr(i32, 32 * NumWarps));
    newFuncOp->setAttr(
        NVVMMetadataField::MaxNTid,
        rewriter.getIntegerAttr(type::i32Ty(ctx), 32 * NumWarps));

    rewriter.eraseOp(funcOp);
    return success();
@@ -190,22 +224,47 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
  }
};

-// Extract numWarps information from TritonGPU module, return 0 if failed.
-// This is a naive implementation, it assumes that all the blocked layout should
-// have the same numWarps setting in a module, it just find a blocked layout
-// encoding and return the warpsPerCTA field.
-int extractNumWarps(mlir::ModuleOp module) {
-  int numWarps{};
-  if (module->hasAttr(AttrNumWarpsName))
-    numWarps = module->getAttr(AttrNumWarpsName)
-                   .dyn_cast<IntegerAttr>()
-                   .getValue()
-                   .getZExtValue();
-  else
-    llvm::report_fatal_error(
-        "TritonGPU module should contain a triton_gpu.num-warps attribute");
-  return numWarps;
-}

static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
                              ArrayRef<int64_t> shape) {
  assert(multidim_index.size() == shape.size());
  // sizes {a, b, c, d} -> acc_mul {b*c*d, c*d, d, 1}
  int64_t rank = shape.size();
  int64_t acc_mul = 1;
  for (int64_t i = 1; i < rank; ++i) {
    acc_mul *= shape[i];
  }
  int64_t linear_index = 0;
  for (int64_t i = 0; i < rank; ++i) {
    linear_index += multidim_index[i] * acc_mul;
    if (i != (rank - 1)) {
      acc_mul = acc_mul / shape[i + 1];
    }
  }
  return linear_index;
}
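// Worked example: with shape {2, 3, 4} and multidim_index {1, 2, 3},
// acc_mul starts at 3 * 4 = 12, and the loop accumulates
// 1*12 + 2*4 + 3*1 = 23, i.e. the row-major linear index.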

static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
                                  ArrayRef<int64_t> shape) {
  return product(shape) / (product(layout.getThreadsPerWarp()) *
                           product(layout.getWarpsPerCTA()));
}
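// e.g. shape {128} with threadsPerWarp {32} and warpsPerCTA {2} gives
// 128 / (32 * 2) = 2 elements per thread (cf. the splat test below).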

static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                     Type resultType, int64_t value) {
  return builder.create<LLVM::ConstantOp>(
      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}

Value getStructFromElements(Location loc, ValueRange resultVals,
                            ConversionPatternRewriter &rewriter,
                            Type structType, Type elemPtrPtrType) {
  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
  for (auto v : llvm::enumerate(resultVals)) {
    llvmStruct = rewriter.create<LLVM::InsertValueOp>(
        loc, structType, llvmStruct, v.value(),
        rewriter.getI64ArrayAttr(v.index()));
  }
  return llvmStruct;
}

template <typename T>
@@ -247,24 +306,6 @@ static T getLinearIndex(ArrayRef<T> multidim_index, ArrayRef<T> shape) {
  return linear_index;
}

-static unsigned getElemsPerThread(const TritonGPUBlockedEncodingAttr &layout,
-                                  ArrayRef<int64_t> shape) {
-  unsigned elems = 1;
-  size_t rank = shape.size();
-  assert(rank == layout.getThreadsPerWarp().size());
-  for (size_t d = 0; d < rank; ++d) {
-    elems *=
-        shape[d] / (layout.getThreadsPerWarp()[d] * layout.getWarpsPerCTA()[d]);
-  }
-  return elems;
-}
-
-static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
-                                     Type resultType, int64_t value) {
-  return builder.create<LLVM::ConstantOp>(
-      loc, resultType, builder.getIntegerAttr(resultType, value));
-}

template <typename SourceOp>
class ConvertTritonGPUOpToLLVMPattern
    : public ConvertOpToLLVMPattern<SourceOp> {
@@ -443,6 +484,358 @@ public:
  }
};

// Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a
// LLVM::StructType value.
//
// @elemType: the element type in operand.
// @resType: the return type of the Splat-like op.
// @constVal: a LLVM::ConstantOp or other scalar value.
Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
                         TypeConverter *typeConverter,
                         ConversionPatternRewriter &rewriter, Location loc) {
  auto tensorTy = resType.cast<RankedTensorType>();
  auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
  auto srcType = typeConverter->convertType(elemType);
  auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);

  auto numElems = layout.getSizePerThread();
  size_t totalElems =
      std::accumulate(tensorTy.getShape().begin(), tensorTy.getShape().end(), 1,
                      std::multiplies<>{});
  size_t numThreads =
      product(layout.getWarpsPerCTA()) * product(layout.getThreadsPerWarp());
  // TODO(Superjomn) add numElemsPerThread to the layout encodings.
  size_t numElemsPerThread = totalElems / numThreads;

  llvm::SmallVector<Value, 4> elems(numElemsPerThread, llSrc);
  llvm::SmallVector<Type, 4> elemTypes(elems.size(), srcType);
  auto structTy =
      LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);

  auto llElemPtrPtrTy =
      LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(srcType));
  auto llStruct =
      getStructFromElements(loc, elems, rewriter, structTy, llElemPtrPtrTy);
  return llStruct;
}
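// e.g. splatting one f32 across tensor<128xf32> with 64 threads (2 warps of
// 32) yields a per-thread !llvm.struct<(f32, f32)>; see the FileCheck test
// at the end of this commit.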

struct SplatOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::SplatOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::SplatOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto src = op->getOperand(0);

    LLVM::ConstantOp arithConstantOp;
    if (src.getDefiningOp() &&
        (arithConstantOp =
             llvm::dyn_cast<LLVM::ConstantOp>(src.getDefiningOp()))) {
      Value constant;
      auto values = arithConstantOp.getValue().dyn_cast<DenseElementsAttr>();

      assert(values.size() == 1);
      Attribute val;
      if (type::isInt(src.getType())) {
        val = values.getValues<IntegerAttr>()[0];
      } else if (type::isFloat(src.getType())) {
        val = values.getValues<FloatAttr>()[0];
      } else {
        llvm::errs() << "Constant op type not supported";
        return failure();
      }

      src = rewriter.create<LLVM::ConstantOp>(loc, val.getType(), val);
    }

    auto llStruct = convertSplatLikeOp(src.getType(), op.getType(), src,
                                       getTypeConverter(), rewriter, loc);
    rewriter.replaceOp(op, {llStruct});
    return success();
  }
};

// This pattern converts an arith::ConstantOp carrying a SplatElementsAttr;
// the logic is the same as for triton::SplatOp, so the underlying
// implementation is reused.
struct ArithConstantSplatOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<arith::ConstantOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      arith::ConstantOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto value = op.getValue();
    if (!value.dyn_cast<SplatElementsAttr>())
      return failure();

    auto loc = op->getLoc();

    LLVM::ConstantOp arithConstantOp;
    auto values = op.getValue().dyn_cast<SplatElementsAttr>();
    auto elemType = values.getElementType();

    Attribute val;
    if (type::isInt(elemType)) {
      val = values.getValues<IntegerAttr>()[0];
    } else if (type::isFloat(elemType)) {
      val = values.getValues<FloatAttr>()[0];
    } else {
      llvm::errs() << "ArithConstantSplatOpConversion got an unsupported type: "
                   << value.getType() << "\n";
      return failure();
    }

    auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
    auto llStruct = convertSplatLikeOp(elemType, op.getType(), constOp,
                                       getTypeConverter(), rewriter, loc);
    rewriter.replaceOp(op, llStruct);

    return success();
  }
};

struct StoreOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;

  StoreOpConversion(LLVMTypeConverter &converter,
                    AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
        AxisAnalysisPass(axisAnalysisPass) {}

  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Value ptr = op.ptr();
    Value mask = op.mask();
    Value value = op.value();

    Value llPtr = adaptor.ptr(); // should be LLVM ops
    Value llMask = adaptor.mask();
    Value llValue = adaptor.value();

    Type valueElemTy = getTypeConverter()->convertType(
        value.getType().cast<RankedTensorType>().getElementType());

    MLIRContext *ctx = rewriter.getContext();
    auto loc = op->getLoc();

    auto getLLVMElems = [&](Value value, Value llValue,
                            const TritonGPUBlockedEncodingAttr &layout)
        -> SmallVector<Value, 4> {
      auto ty = value.getType().cast<RankedTensorType>();
      auto shape = ty.getShape();
      // Here we assume that all inputs have a blocked layout.

      unsigned valueElems = getElemsPerThread(layout, shape);

      auto llvmElemTy = getTypeConverter()->convertType(ty.getElementType());
      auto llvmElemPtrPtrTy =
          LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(llvmElemTy));

      auto valueVals =
          getElementsFromStruct(loc, llValue, valueElems, rewriter);
      return valueVals;
    };

    auto getLayout =
        [&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
      auto ty = val.getType().cast<RankedTensorType>();
      auto shape = ty.getShape();
      // Here we assume that all inputs have a blocked layout.
      auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();

      unsigned valueElems = getElemsPerThread(layout, shape);

      return std::make_tuple(layout, valueElems);
    };

    auto [ptrLayout, ptrNumElems] = getLayout(ptr);
    auto [maskLayout, maskNumElems] = getLayout(mask);
    auto [valueLayout, valueNumElems] = getLayout(value);

    auto valueElems = getLLVMElems(value, llValue, valueLayout);
    auto maskElems = getLLVMElems(mask, llMask, maskLayout);
    assert(valueElems.size() == maskElems.size());

    auto getAlign =
        [this](Value val,
               const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
      auto axisInfo = getAxisInfo(val);
      assert(axisInfo.hasValue());

      auto order = layout.getOrder();

      unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
      unsigned maxContig = axisInfo->getContiguity(order[0]);
      unsigned alignment = std::min(maxMultiple, maxContig);
      return alignment;
    };

    // Determine the vector width of each access from the alignment and the
    // per-thread contiguity.
    auto getVec = [this, &getAlign](
                      Value val,
                      const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
      auto axisInfo = getAxisInfo(val);
      auto contig = axisInfo->getContiguity();
      // `order` is sorted by contiguity, so the first dimension has the
      // largest contiguity.
      auto order = layout.getOrder();
      unsigned align = getAlign(val, layout);

      assert(!order.empty());
      // Is this right?
      unsigned contigPerThread = layout.getSizePerThread()[order[0]];
      unsigned vec = std::min(align, contigPerThread);

      // TODO(Superjomn) Consider the is_mma_first_row in the legacy code
      bool isMMAFirstRow = false;

      if (isMMAFirstRow)
        vec = std::min<size_t>(2, align);

      return vec;
    };

    // Determine the vectorization size
    size_t vec = getVec(ptr, ptrLayout);

    const size_t dtsize = value.getType()
                              .cast<RankedTensorType>()
                              .getElementType()
                              .getIntOrFloatBitWidth() /
                          8;
    const size_t valueElemNbits = dtsize * 8;

    const int numVecs = ptrNumElems / vec;
    for (size_t vecIdx = 0; vecIdx < ptrNumElems; vecIdx += vec) {

      size_t in_off{};
      auto ptrProducer = llPtr.getDefiningOp();
      auto in_gep = llvm::dyn_cast<LLVM::GEPOp>(ptrProducer);

      if (in_gep) {
        auto indices = in_gep.getIndices();
        auto cst = dyn_cast<LLVM::ConstantOp>(indices.front().getDefiningOp());
        in_off =
            cst ? cst.getValue().dyn_cast<IntegerAttr>().getInt() * dtsize : 0;
        ptr = cst ? in_gep.getBase() : in_gep;
      }

      // Pack sub-words (< 32/64 bits) into words: each store has width
      // min(nbits * vec, 32/64), and there are (nbits * vec) / width of them.
      const int maxWordWidth = std::max<int>(32, valueElemNbits);
      const int totalWidth = valueElemNbits * vec;
      const int width = std::min(totalWidth, maxWordWidth);
      const int nWords = std::max(1, totalWidth / width);
      const int wordNElems = width / valueElemNbits;
      const int vecNElems = totalWidth / valueElemNbits;
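      // Worked example for f32 stores with vec = 2: valueElemNbits = 32,
      // totalWidth = 64, maxWordWidth = 32, so width = 32, nWords = 2,
      // wordNElems = 1, vecNElems = 2: two 32-bit words per vectorized store.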

      assert(wordNElems * nWords * numVecs == valueElems.size());

      // TODO(Superjomn) Add a cache policy to store and handle it here.
      const bool hasL2EvictPolicy = false;

      PtxIOInstr asmStoreInstr("st");
      asmStoreInstr.predicate(llMask, "b");
      asmStoreInstr.global().v(width).b(nWords);

      llvm::SmallVector<std::string> asmArgs;

      Type valArgTy = IntegerType::get(ctx, width);
      auto wordTy = VectorType::get(wordNElems, valueElemTy);

      auto *asmAddr = asmStoreInstr.newAddrOperand(llPtr, "l", in_off);
      auto *asmArgList = asmStoreInstr.newList();
      for (int wordIdx = 0; wordIdx < nWords; wordIdx++) {
        // llWord is a composition of wordNElems elements, each
        // valueElemNbits wide.
        Value llWord = rewriter.create<LLVM::UndefOp>(loc, wordTy);
        // Insert each value element into the composition.
        for (int elemIdx = 0; elemIdx < wordNElems; elemIdx++) {
          Value elem =
              valueElems[vecIdx * vecNElems + wordIdx * wordNElems + elemIdx];
          if (elem.getType().isInteger(1))
            elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
          elem = rewriter.create<LLVM::BitcastOp>(loc, valueElemTy, elem);

          llWord = rewriter.create<LLVM::InsertElementOp>(
              loc, wordTy, llWord, elem,
              rewriter.create<LLVM::ConstantOp>(
                  loc, type::u32Ty(ctx),
                  IntegerAttr::get(type::u32Ty(ctx), elemIdx)));
        }
        llWord = rewriter.create<LLVM::BitcastOp>(loc, valArgTy, llWord);
        std::string constraint =
            (width == 64) ? "l" : ((width == 32) ? "r" : "c");
        asmArgList->listAppend(asmStoreInstr.newOperand(llWord, constraint));
      }

      asmStoreInstr.addOperand(asmAddr);
      asmStoreInstr.addOperand(asmArgList);

      llvm::SmallVector<Type, 4> argTys({mask.getType(), ptr.getType()});
      for (int i = 0; i < nWords; i++)
        argTys.push_back(valArgTy);

      auto ASMReturnTy = LLVM::LLVMStructType::getLiteral(ctx, /*returnTy*/ {});

      auto inlineAsm = rewriter.create<LLVM::InlineAsmOp>(
          loc, ASMReturnTy, asmStoreInstr.getAllMlirArgs(), // operands
          asmStoreInstr.dump(),                             // asm_string
          asmStoreInstr.getConstrains(),                    // constraints
          // TODO(Superjomn) determine the side effect.
          true,  // has_side_effects
          false, // is_align_stack
          LLVM::AsmDialectAttr::get(ctx,
                                    LLVM::AsmDialect::AD_ATT), // asm_dialect
          ArrayAttr::get(ctx, {})                              // operand_attrs
      );

      rewriter.replaceOp(op, inlineAsm.getRes());
    }
    return success();
  }

  llvm::Optional<AxisInfo> getAxisInfo(Value val) const {
    if (auto it = AxisAnalysisPass.lookupLatticeElement(val)) {
      return it->getValue();
    }

    return llvm::Optional<AxisInfo>{};
  }

private:
  AxisInfoAnalysis &AxisAnalysisPass;
};

// Extract the numWarps information from a TritonGPU module: read the
// module-level triton_gpu.num-warps attribute, and report a fatal error if
// the attribute is missing.
int extractNumWarps(mlir::ModuleOp module) {
  int numWarps{};
  if (module->hasAttr(AttrNumWarpsName))
    numWarps = module->getAttr(AttrNumWarpsName)
                   .dyn_cast<IntegerAttr>()
                   .getValue()
                   .getZExtValue();
  else
    llvm::report_fatal_error(
        "TritonGPU module should contain a triton_gpu.num-warps attribute");

  return numWarps;
}
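// e.g. a module with attributes {"triton_gpu.num-warps" = 2 : i32} (as
// produced by -convert-triton-to-tritongpu=num-warps=2) yields 2.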

struct BroadcastOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::BroadcastOp> {
  using ConvertTritonGPUOpToLLVMPattern<
@@ -647,8 +1040,7 @@ struct LoadOpConversion
    for (size_t i = 0; i < numElems; i += vecWidth) {
      Value ptr = ptrVals[i];
      // TODO: Handle the optimization if ptr is from GEP and the idx is
-      // constant
-      // This should be a canonicalization pattern in LLVM Dialect
      // constant. This should be a canonicalization pattern in LLVM Dialect
      unsigned in_off = 0;
      Value pred = maskVals[i];

@@ -826,13 +1218,18 @@ public:
};

void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                  RewritePatternSet &patterns, int numWarps) {
-  patterns.add<BroadcastOpConversion>(typeConverter);
-  patterns.add<FuncOpConversion>(typeConverter, numWarps);
-  patterns.add<LoadOpConversion>(typeConverter);
-  patterns.add<MakeRangeOpConversion>(typeConverter);
-  patterns.add<ReturnOpConversion>(typeConverter);
-  patterns.add<ViewOpConversion>(typeConverter);
                                  RewritePatternSet &patterns, int numWarps,
                                  AxisInfoAnalysis &analysis,
                                  PatternBenefit benefit = 1) {
  patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
  patterns.add<SplatOpConversion>(typeConverter, benefit);
  patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
  patterns.add<StoreOpConversion>(typeConverter, analysis, benefit);
  patterns.add<ReturnOpConversion>(typeConverter, benefit);
  patterns.add<BroadcastOpConversion>(typeConverter, benefit);
  patterns.add<LoadOpConversion>(typeConverter, benefit);
  patterns.add<MakeRangeOpConversion>(typeConverter, benefit);
  patterns.add<ViewOpConversion>(typeConverter, benefit);
}

class ConvertTritonGPUToLLVM
@@ -851,20 +1248,33 @@ public:
    TritonLLVMConversionTarget target(*context, typeConverter);

    RewritePatternSet patterns(context);
-    // TODO: (goostavz) Temporarily disable this, since the lowering of
-    // arithmetic ops in tensor format is not complete yet.
-    // Add arith's patterns to help convert scalar expression to LLVM.
-    // mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
-    //                                                         patterns);

    int numWarps = extractNumWarps(mod);

-    populateTritonToLLVMPatterns(typeConverter, patterns, numWarps);
    auto axisAnalysis = runAxisAnalysis(mod);

    // We set a higher benefit here to ensure that triton's patterns run
    // before the arith patterns, for encodings the community patterns do
    // not support.
    populateTritonToLLVMPatterns(typeConverter, patterns, numWarps,
                                 *axisAnalysis, 10 /*benefit*/);

    // Add arith's patterns to help convert scalar expressions to LLVM.
    mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                            patterns);

    mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(mod, target, std::move(patterns))))
      return signalPassFailure();
  }

protected:
  std::unique_ptr<AxisInfoAnalysis> runAxisAnalysis(ModuleOp module) {
    auto axisAnalysisPass =
        std::make_unique<AxisInfoAnalysis>(module->getContext());
    axisAnalysisPass->run(module);
    return axisAnalysisPass;
  }
};

} // namespace
lib/Dialect/Triton/IR/Ops.cpp
@@ -102,6 +102,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  auto constOperand = src().getDefiningOp<arith::ConstantOp>();
  if (!constOperand)
    return {};

  auto shapedType = getType().cast<ShapedType>();
  auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()});
  return ret;
36 test/Conversion/triton_to_llvm.mlir Normal file
@@ -0,0 +1,36 @@
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 -convert-triton-gpu-to-llvm | FileCheck %s

func @test_splat(%ptr: !tt.ptr<f32>) {
  // Here: 128 elements and 64 (2*32) threads, so each thread needs to
  // process 2 elements.
  //
  // CHECK: %0 = llvm.bitcast %arg0 : !llvm.ptr<f32, 1> to !llvm.ptr<f32, 1>
  // CHECK: %1 = llvm.mlir.undef : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  // CHECK: %2 = llvm.insertvalue %0, %1[0] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  // CHECK: %3 = llvm.insertvalue %0, %2[1] : !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %a = arith.constant 1.0 : f32
  %true = arith.constant 1 : i1
  %b = tt.splat %a : (f32) -> tensor<128xf32>

  // Here each thread processes only 1 element.
  // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(i1)>
  %mask = tt.splat %true : (i1) -> tensor<64xi1>

  return
}

func @test_store_splat(%ptr: !tt.ptr<f32>) {
  %ptrs = tt.splat %ptr : (!tt.ptr<f32>) -> tensor<128x!tt.ptr<f32>>
  %a = arith.constant 1.0 : f32
  %true = arith.constant 1 : i1

  %vs = tt.splat %a : (f32) -> tensor<128xf32>
  %mask = tt.splat %true : (i1) -> tensor<128xi1>

  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@%0 st.global.v32.b1 [ %1 + 0 ], { %2 };",
  // CHECK: "b,l,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.struct<(i1, i1)>, !llvm.struct<(ptr<f32, 1>, ptr<f32, 1>)>, i32) -> !llvm.struct<()>
  tt.store %ptrs, %vs, %mask, {} : tensor<128xf32>

  return
}
@@ -1,11 +1,10 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s


module attributes {"triton_gpu.num-warps" = 4 : i32} {

// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32.
-// CHECK: attributes {nvvm.maxntid = 128 : i32} {{.*}}
// CHECK: attributes {nvvm.maxntid = 128 : si32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {

  // CHECK: llvm.return