#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H #define TRITON_CONVERSION_TRITONGPU_TO_LLVM_UTILITY_H #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PtxAsmFormat.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" #include #include // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators #define inttoptr(...) rewriter.create(loc, __VA_ARGS__) #define ptrtoint(...) rewriter.create(loc, __VA_ARGS__) #define zext(...) rewriter.create(loc, __VA_ARGS__) #define udiv(...) rewriter.create(loc, __VA_ARGS__) #define urem(...) rewriter.create(loc, __VA_ARGS__) #define add(...) rewriter.create(loc, __VA_ARGS__) #define sub(...) rewriter.create(loc, __VA_ARGS__) #define fadd(...) rewriter.create(loc, __VA_ARGS__) #define mul(...) rewriter.create(loc, __VA_ARGS__) #define fmul(...) rewriter.create(loc, __VA_ARGS__) #define smax(...) rewriter.create(loc, __VA_ARGS__) #define umax(...) rewriter.create(loc, __VA_ARGS__) #define fmax(...) rewriter.create(loc, __VA_ARGS__) #define smin(...) rewriter.create(loc, __VA_ARGS__) #define umin(...) rewriter.create(loc, __VA_ARGS__) #define fmin(...) rewriter.create(loc, __VA_ARGS__) #define and_(...) rewriter.create(loc, __VA_ARGS__) #define xor_(...) rewriter.create(loc, __VA_ARGS__) #define bitcast(val__, type__) \ rewriter.create(loc, type__, val__) #define gep(...) rewriter.create(loc, __VA_ARGS__) #define ptr_ty(...) LLVM::LLVMPointerType::get(__VA_ARGS__) #define insert_val(...) rewriter.create(loc, __VA_ARGS__) #define extract_val(...) rewriter.create(loc, __VA_ARGS__) #define insert_element(...) \ rewriter.create(loc, __VA_ARGS__) #define extract_element(...) \ rewriter.create(loc, __VA_ARGS__) #define load(...) rewriter.create(loc, __VA_ARGS__) #define store(val, ptr) rewriter.create(loc, val, ptr) #define fcmp_ogt(lhs, rhs) \ rewriter.create(loc, rewriter.getI1Type(), \ LLVM::FCmpPredicate::ogt, lhs, rhs) #define fcmp_olt(lhs, rhs) \ rewriter.create(loc, rewriter.getI1Type(), \ LLVM::FCmpPredicate::olt, lhs, rhs) #define icmp_eq(...) \ rewriter.create(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) #define icmp_ne(...) \ rewriter.create(loc, LLVM::ICmpPredicate::ne, __VA_ARGS__) #define icmp_slt(...) \ rewriter.create(loc, LLVM::ICmpPredicate::slt, __VA_ARGS__) #define icmp_sle(...) \ rewriter.create(loc, LLVM::ICmpPredicate::sle, __VA_ARGS__) #define icmp_sgt(...) \ rewriter.create(loc, LLVM::ICmpPredicate::sgt, __VA_ARGS__) #define icmp_sge(...) \ rewriter.create(loc, LLVM::ICmpPredicate::sge, __VA_ARGS__) #define icmp_ult(...) \ rewriter.create(loc, LLVM::ICmpPredicate::ult, __VA_ARGS__) #define icmp_ule(...) \ rewriter.create(loc, LLVM::ICmpPredicate::ule, __VA_ARGS__) #define icmp_ugt(...) \ rewriter.create(loc, LLVM::ICmpPredicate::ugt, __VA_ARGS__) #define icmp_uge(...) \ rewriter.create(loc, LLVM::ICmpPredicate::uge, __VA_ARGS__) #define select(...) rewriter.create(loc, __VA_ARGS__) #define address_of(...) rewriter.create(loc, __VA_ARGS__) #define barrier() rewriter.create(loc) #define undef(...) rewriter.create(loc, __VA_ARGS__) // Types #define i32_ty rewriter.getIntegerType(32) #define ui32_ty rewriter.getIntegerType(32, false) #define f16_ty rewriter.getF16Type() #define bf16_ty rewriter.getBF16Type() #define i8_ty rewriter.getIntegerType(8) #define f32_ty rewriter.getF32Type() #define f64_ty rewriter.getF64Type() #define vec_ty(type, num) VectorType::get(num, type) #define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__) #define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__) #define void_ty(ctx) LLVM::LLVMVoidType::get(ctx) #define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__) #define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count) // Constants #define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__) #define int_val(width, val) \ LLVM::createLLVMIntegerConstant(rewriter, loc, width, val) #define idx_val(...) \ LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \ __VA_ARGS__) #define tid_val() getThreadId(rewriter, loc) namespace mlir { namespace LLVM { using namespace mlir::triton; Value getStructFromElements(Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter, Type structType) { if (!structType.isa()) { return *resultVals.begin(); } Value llvmStruct = rewriter.create(loc, structType); for (const auto &v : llvm::enumerate(resultVals)) { assert(v.value() && "can not insert null values"); llvmStruct = insert_val(structType, llvmStruct, v.value(), rewriter.getI64ArrayAttr(v.index())); } return llvmStruct; } SmallVector getElementsFromStruct(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { if (llvmStruct.getType().isIntOrIndexOrFloat() || llvmStruct.getType().isa() || llvmStruct.getType().isa()) return {llvmStruct}; ArrayRef types = llvmStruct.getType().cast().getBody(); SmallVector results(types.size()); for (unsigned i = 0; i < types.size(); ++i) { Type type = types[i]; results[i] = extract_val(type, llvmStruct, rewriter.getI64ArrayAttr(i)); } return results; } // Create a 32-bit integer constant. Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) { auto i32ty = rewriter.getIntegerType(32); return rewriter.create(loc, i32ty, IntegerAttr::get(i32ty, v)); } Value createConstantF32(Location loc, PatternRewriter &rewriter, float v) { auto type = type::f32Ty(rewriter.getContext()); return rewriter.create(loc, type, rewriter.getF32FloatAttr(v)); } Value createConstantF64(Location loc, PatternRewriter &rewriter, float v) { auto type = type::f64Ty(rewriter.getContext()); return rewriter.create(loc, type, rewriter.getF64FloatAttr(v)); } // Create an index type constant. Value createIndexConstant(OpBuilder &builder, Location loc, TypeConverter *converter, int64_t value) { Type ty = converter->convertType(builder.getIndexType()); return builder.create(loc, ty, builder.getIntegerAttr(ty, value)); } // Create an integer constant of \param width bits. Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, int64_t value) { Type ty = builder.getIntegerType(width); return builder.create(loc, ty, builder.getIntegerAttr(ty, value)); } /// Helper function to get strides from a given shape and its order SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, Location loc, ConversionPatternRewriter &rewriter) { auto rank = shape.size(); SmallVector strides(rank); auto stride = 1; for (auto idx : order) { strides[idx] = i32_val(stride); stride *= shape[idx]; } return strides; } struct SharedMemoryObject { Value base; // i32 ptr. The start address of the shared memory object. // We need to store strides as Values but not integers because the // extract_slice instruction can take a slice at artibary offsets. // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is // 32, we need to let the instruction that uses $a to be aware of that. // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If // we store strides into an attribute array of integers, the information // cannot pass through block argument assignment because attributes are // associated with operations but not Values. // TODO(Keren): We may need to figure out a way to store strides as integers // if we want to support more optimizations. SmallVector strides; // i32 int. The strides of the shared memory object. SmallVector offsets; // i32 int. The offsets of the shared memory // objects from the originally allocated object. SharedMemoryObject(Value base, ArrayRef strides, ArrayRef offsets) : base(base), strides(strides.begin(), strides.end()), offsets(offsets.begin(), offsets.end()) {} SharedMemoryObject(Value base, ArrayRef shape, ArrayRef order, Location loc, ConversionPatternRewriter &rewriter) : base(base) { strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter); for (auto idx : order) { offsets.emplace_back(i32_val(0)); } } SmallVector getElems() const { SmallVector elems; elems.push_back(base); elems.append(strides.begin(), strides.end()); elems.append(offsets.begin(), offsets.end()); return elems; } SmallVector getTypes() const { SmallVector types; types.push_back(base.getType()); types.append(strides.size(), IntegerType::get(base.getContext(), 32)); types.append(offsets.size(), IntegerType::get(base.getContext(), 32)); return types; } Value getCSwizzleOffset(int order) const { assert(order >= 0 && order < strides.size()); return offsets[order]; } Value getBaseBeforeSwizzle(int order, Location loc, ConversionPatternRewriter &rewriter) const { Value cSwizzleOffset = getCSwizzleOffset(order); Value offset = sub(i32_val(0), cSwizzleOffset); Type type = base.getType(); return gep(type, base, offset); } }; SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter) { auto elems = getElementsFromStruct(loc, llvmStruct, rewriter); auto rank = (elems.size() - 1) / 2; return {/*base=*/elems[0], /*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank}, /*offsets=*/{elems.begin() + 1 + rank, elems.end()}}; } Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred) { MLIRContext *ctx = rewriter.getContext(); unsigned bits = val.getType().getIntOrFloatBitWidth(); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); PTXBuilder builder; auto *ptrOpr = builder.newAddrOperand(ptr, "r"); auto *valOpr = builder.newOperand(val, c); auto &st = builder.create<>("st")->shared().b(bits); st(ptrOpr, valOpr).predicate(pred, "b"); return builder.launch(rewriter, loc, void_ty(ctx)); } Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { Type vecTy = vec_ty(f32_ty, 2); Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); val0 = shflSync(loc, rewriter, val0, i); val1 = shflSync(loc, rewriter, val1, i); vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); return bitcast(vec, val.getType()); } PTXBuilder builder; auto &shfl = builder.create("shfl.sync")->o("bfly").o("b32"); auto *dOpr = builder.newOperand("=r"); auto *aOpr = builder.newOperand(val, "r"); auto *bOpr = builder.newConstantOperand(i); auto *cOpr = builder.newConstantOperand("0x1f"); auto *maskOpr = builder.newConstantOperand("0xffffffff"); shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); return builder.launch(rewriter, loc, val.getType(), false); } } // namespace LLVM } // namespace mlir #endif