This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability, and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
370 lines · 16 KiB · C++
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h"

// Shortcuts for some commonly used LLVM ops to keep code simple and intuitive
// Operators
#define inttoptr(...) rewriter.create<LLVM::IntToPtrOp>(loc, __VA_ARGS__)
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
#define sub(...) rewriter.create<LLVM::SubOp>(loc, __VA_ARGS__)
#define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
#define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
#define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
#define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
#define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
#define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__)                                                 \
  rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define gep(...) rewriter.create<LLVM::GEPOp>(loc, __VA_ARGS__)
#define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__)
#define insert_val(...) rewriter.create<LLVM::InsertValueOp>(loc, __VA_ARGS__)
#define extract_val(...) rewriter.create<LLVM::ExtractValueOp>(loc, __VA_ARGS__)
#define insert_element(...)                                                    \
  rewriter.create<LLVM::InsertElementOp>(loc, __VA_ARGS__)
#define extract_element(...)                                                   \
  rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
#define store(val, ptr) rewriter.create<LLVM::StoreOp>(loc, val, ptr)
#define fcmp_ogt(lhs, rhs)                                                     \
  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                    \
                                LLVM::FCmpPredicate::ogt, lhs, rhs)
#define fcmp_olt(lhs, rhs)                                                     \
  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                    \
                                LLVM::FCmpPredicate::olt, lhs, rhs)
#define icmp_eq(...)                                                           \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...)                                                           \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__)
#define icmp_slt(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__)
#define icmp_sle(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__)
#define icmp_sgt(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__)
#define icmp_sge(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__)
#define icmp_ult(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__)
#define icmp_ule(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__)
#define icmp_ugt(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__)
#define icmp_uge(...)                                                          \
  rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__)
#define select(...) rewriter.create<LLVM::SelectOp>(loc, __VA_ARGS__)
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)

// Types
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

// Constants
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val)                                                    \
  LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...)                                                           \
  LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(),          \
                            __VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
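
// Illustrative only (not part of the original header): these shortcuts assume
// a `ConversionPatternRewriter &rewriter` and a `Location loc` are in scope,
// typically inside a conversion pattern's matchAndRewrite. A minimal sketch:
//
//   Value a = i32_val(3);                    // 32-bit integer constant 3
//   Value b = i32_val(4);
//   Value c = add(a, b);                     // LLVM::AddOp
//   Value isNeg = icmp_slt(c, i32_val(0));   // LLVM::ICmpOp, slt predicate
//   Value r = select(isNeg, i32_val(0), c);  // LLVM::SelectOp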

namespace mlir {
namespace triton {

// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
                                          llvm::ArrayRef<T> shape) {
  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
  size_t rank = shape.size();
  T accMul = product(shape.drop_back());
  T linearRemain = linearIndex;
  llvm::SmallVector<T> multiDimIndex(rank);
  for (int i = rank - 1; i >= 0; --i) {
    multiDimIndex[i] = linearRemain / accMul;
    linearRemain = linearRemain % accMul;
    if (i != 0) {
      accMul = accMul / shape[i - 1];
    }
  }
  return multiDimIndex;
}

template <typename T>
llvm::SmallVector<T> getMultiDimIndex(T linearIndex, llvm::ArrayRef<T> shape,
                                      llvm::ArrayRef<unsigned> order) {
  size_t rank = shape.size();
  assert(rank == order.size());
  auto reordered = reorder(shape, order);
  auto reorderedMultiDim = getMultiDimIndexImpl<T>(linearIndex, reordered);
  llvm::SmallVector<T> multiDim(rank);
  for (unsigned i = 0; i < rank; ++i) {
    multiDim[order[i]] = reorderedMultiDim[i];
  }
  return multiDim;
}

// Linearize supposing order is [0, 1, .. , n]
template <typename T>
static T getLinearIndexImpl(llvm::ArrayRef<T> multiDimIndex,
                            llvm::ArrayRef<T> shape) {
  assert(multiDimIndex.size() == shape.size());
  // shape: {a, b, c, d} -> accMul: {1, a, a*b, a*b*c}
  size_t rank = shape.size();
  T accMul = product(shape.drop_back());
  T linearIndex = 0;
  for (int i = rank - 1; i >= 0; --i) {
    linearIndex += multiDimIndex[i] * accMul;
    if (i != 0) {
      accMul = accMul / shape[i - 1];
    }
  }
  return linearIndex;
}

template <typename T>
static T getLinearIndex(llvm::ArrayRef<T> multiDimIndex,
                        llvm::ArrayRef<T> shape,
                        llvm::ArrayRef<unsigned> order) {
  assert(shape.size() == order.size());
  return getLinearIndexImpl<T>(reorder(multiDimIndex, order),
                               reorder(shape, order));
}
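
// Worked example (illustrative, not part of the original header): with
// shape = {2, 4, 8} and order = {0, 1, 2} (dim 0 fastest-varying),
// getMultiDimIndex(13, shape, order) decomposes 13 = 1 + 2*2 + 1*(2*4) and
// returns {1, 2, 1}; getLinearIndex({1, 2, 1}, shape, order) maps it back
// to 13.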

} // namespace triton

namespace LLVM {
using namespace mlir::triton;

static Value getStructFromElements(Location loc, ValueRange resultVals,
                                   ConversionPatternRewriter &rewriter,
                                   Type structType) {
  if (!structType.isa<LLVM::LLVMStructType>()) {
    return *resultVals.begin();
  }

  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
  for (const auto &v : llvm::enumerate(resultVals)) {
    assert(v.value() && "can not insert null values");
    llvmStruct = insert_val(structType, llvmStruct, v.value(),
                            rewriter.getI64ArrayAttr(v.index()));
  }
  return llvmStruct;
}

static SmallVector<Value>
getElementsFromStruct(Location loc, Value llvmStruct,
                      ConversionPatternRewriter &rewriter) {
  if (llvmStruct.getType().isIntOrIndexOrFloat() ||
      llvmStruct.getType().isa<triton::PointerType>() ||
      llvmStruct.getType().isa<LLVM::LLVMPointerType>())
    return {llvmStruct};
  ArrayRef<Type> types =
      llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
  SmallVector<Value> results(types.size());
  for (unsigned i = 0; i < types.size(); ++i) {
    Type type = types[i];
    results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i));
  }
  return results;
}
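
// Illustrative round trip (not part of the original header): pack two i32
// values into a literal struct and unpack them again. Assumes `ctx`,
// `rewriter`, and `loc` are in scope, as the struct_ty and i32_val macros
// require.
//
//   SmallVector<Value> vals = {i32_val(1), i32_val(2)};
//   Type packedTy = struct_ty(SmallVector<Type>{i32_ty, i32_ty});
//   Value packed = getStructFromElements(loc, vals, rewriter, packedTy);
//   SmallVector<Value> unpacked = getElementsFromStruct(loc, packed, rewriter);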

// Create a 32-bit integer constant.
static Value createConstantI32(Location loc, PatternRewriter &rewriter,
                               int32_t v) {
  auto i32ty = rewriter.getIntegerType(32);
  return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
                                           IntegerAttr::get(i32ty, v));
}

static Value createConstantF32(Location loc, PatternRewriter &rewriter,
                               float v) {
  auto type = type::f32Ty(rewriter.getContext());
  return rewriter.create<LLVM::ConstantOp>(loc, type,
                                           rewriter.getF32FloatAttr(v));
}

static Value createConstantF64(Location loc, PatternRewriter &rewriter,
                               float v) {
  auto type = type::f64Ty(rewriter.getContext());
  return rewriter.create<LLVM::ConstantOp>(loc, type,
                                           rewriter.getF64FloatAttr(v));
}

// Create an index type constant.
static Value createIndexConstant(OpBuilder &builder, Location loc,
                                 TypeConverter *converter, int64_t value) {
  Type ty = converter->convertType(builder.getIndexType());
  return builder.create<LLVM::ConstantOp>(loc, ty,
                                          builder.getIntegerAttr(ty, value));
}

// Create an integer constant of \param width bits.
static Value createLLVMIntegerConstant(OpBuilder &builder, Location loc,
                                       short width, int64_t value) {
  Type ty = builder.getIntegerType(width);
  return builder.create<LLVM::ConstantOp>(loc, ty,
                                          builder.getIntegerAttr(ty, value));
}

/// Helper function to get strides from a given shape and its order
static SmallVector<Value>
getStridesFromShapeAndOrder(ArrayRef<int64_t> shape, ArrayRef<unsigned> order,
                            Location loc, ConversionPatternRewriter &rewriter) {
  auto rank = shape.size();
  SmallVector<Value> strides(rank);
  int64_t stride = 1;
  for (auto idx : order) {
    strides[idx] = i32_val(stride);
    stride *= shape[idx];
  }
  return strides;
}
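
// Worked example (illustrative, not part of the original header): for
// shape = {64, 32} with order = {1, 0} (dim 1 fastest-varying), the loop
// visits dim 1 first and yields strides = {32, 1} as i32 constants, i.e.
// stride 1 along dim 1 and stride 32 along dim 0.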

struct SharedMemoryObject {
  Value base; // i32 ptr. The start address of the shared memory object.
  // We need to store strides as Values, not integers, because the
  // extract_slice instruction can take a slice at arbitrary offsets.
  // Take $a[16:32, 16:32] as an example: even though we know the stride of
  // $a[0] is 32, we need to let the instruction that uses $a be aware of that.
  // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
  // we stored strides in an attribute array of integers, the information
  // could not pass through block argument assignment because attributes are
  // associated with operations, not Values.
  // TODO(Keren): We may need to figure out a way to store strides as integers
  // if we want to support more optimizations.
  SmallVector<Value>
      strides; // i32 int. The strides of the shared memory object.
  SmallVector<Value> offsets; // i32 int. The offsets of the shared memory
                              // object from the originally allocated object.

  SharedMemoryObject(Value base, ArrayRef<Value> strides,
                     ArrayRef<Value> offsets)
      : base(base), strides(strides.begin(), strides.end()),
        offsets(offsets.begin(), offsets.end()) {}

  SharedMemoryObject(Value base, ArrayRef<int64_t> shape,
                     ArrayRef<unsigned> order, Location loc,
                     ConversionPatternRewriter &rewriter)
      : base(base) {
    strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);

    for (auto idx : order) {
      offsets.emplace_back(i32_val(0));
    }
  }

  SmallVector<Value> getElems() const {
    SmallVector<Value> elems;
    elems.push_back(base);
    elems.append(strides.begin(), strides.end());
    elems.append(offsets.begin(), offsets.end());
    return elems;
  }

  SmallVector<Type> getTypes() const {
    SmallVector<Type> types;
    types.push_back(base.getType());
    types.append(strides.size(), IntegerType::get(base.getContext(), 32));
    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
    return types;
  }

  Value getCSwizzleOffset(int order) const {
    assert(order >= 0 && order < strides.size());
    return offsets[order];
  }

  Value getBaseBeforeSwizzle(int order, Location loc,
                             ConversionPatternRewriter &rewriter) const {
    Value cSwizzleOffset = getCSwizzleOffset(order);
    Value offset = sub(i32_val(0), cSwizzleOffset);
    Type type = base.getType();
    return gep(type, base, offset);
  }
};

static SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
                                ConversionPatternRewriter &rewriter) {
  auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
  auto rank = (elems.size() - 1) / 2;
  return {/*base=*/elems[0],
          /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
          /*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}

static Value storeShared(ConversionPatternRewriter &rewriter, Location loc,
                         Value ptr, Value val, Value pred) {
  MLIRContext *ctx = rewriter.getContext();
  unsigned bits = val.getType().getIntOrFloatBitWidth();
  const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r");

  PTXBuilder builder;
  auto *ptrOpr = builder.newAddrOperand(ptr, "r");
  auto *valOpr = builder.newOperand(val, c);
  auto &st = builder.create<>("st")->shared().b(bits);
  st(ptrOpr, valOpr).predicate(pred, "b");
  return builder.launch(rewriter, loc, void_ty(ctx));
}
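
// Illustrative only (not part of the original header): for a 32-bit value the
// builder assembles a predicated shared-memory store, roughly
//   @pred st.shared.b32 [ptr], val;
// with the pred, ptr, and val operands bound above.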

static Value shflSync(Location loc, ConversionPatternRewriter &rewriter,
                      Value val, int i) {
  unsigned bits = val.getType().getIntOrFloatBitWidth();

  if (bits == 64) {
    Type vecTy = vec_ty(f32_ty, 2);
    Value vec = bitcast(val, vecTy);
    Value val0 = extract_element(f32_ty, vec, i32_val(0));
    Value val1 = extract_element(f32_ty, vec, i32_val(1));
    val0 = shflSync(loc, rewriter, val0, i);
    val1 = shflSync(loc, rewriter, val1, i);
    vec = undef(vecTy);
    vec = insert_element(vecTy, vec, val0, i32_val(0));
    vec = insert_element(vecTy, vec, val1, i32_val(1));
    return bitcast(vec, val.getType());
  }

  PTXBuilder builder;
  auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32");
  auto *dOpr = builder.newOperand("=r");
  auto *aOpr = builder.newOperand(val, "r");
  auto *bOpr = builder.newConstantOperand(i);
  auto *cOpr = builder.newConstantOperand("0x1f");
  auto *maskOpr = builder.newConstantOperand("0xffffffff");
  shfl(dOpr, aOpr, bOpr, cOpr, maskOpr);
  return builder.launch(rewriter, loc, val.getType(), false);
}
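
// Illustrative only (not part of the original header): shflSync emits a
// shfl.sync.bfly.b32 exchange across the warp, so a full warp reduction is
// typically written as a butterfly loop. A minimal sketch for a 32-bit float
// `acc` (assuming `rewriter` and `loc` are in scope, as the macros require):
//
//   for (int offset = 16; offset > 0; offset /= 2) {
//     Value other = shflSync(loc, rewriter, acc, offset);
//     acc = fmax(acc, other); // or fadd(acc, other) for a sum reduction
//   }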

} // namespace LLVM
} // namespace mlir

#endif