#include "mlir/Analysis/DataFlowAnalysis.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "llvm/Support/raw_ostream.h" #include #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace mlir { //===----------------------------------------------------------------------===// // AxisInfo //===----------------------------------------------------------------------===// // Function for extended Euclidean Algorithm static int gcd_impl(int a, int b, int *x, int *y) { // Base Case if (a == 0) { *x = 0; *y = 1; return b; } int x1, y1; // To store results of recursive call int gcd = gcd_impl(b % a, a, &x1, &y1); // Update x and y using results of // recursive call *x = y1 - (b / a) * x1; *y = x1; return gcd; } static int gcd(int a, int b) { int x, y; return gcd_impl(a, b, &x, &y); } AxisInfo AxisInfo::getPessimisticValueState(Value value) { size_t rank = 1; if (TensorType ty = value.getType().dyn_cast()) rank = ty.getRank(); int divHint = 1; BlockArgument blockArg = value.dyn_cast(); if (blockArg && blockArg.getOwner()->isEntryBlock()) { Operation *op = blockArg.getOwner()->getParentOp(); if (FuncOp fun = dyn_cast(op)) { Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); if (attr) divHint = attr.cast().getValue().getZExtValue(); } else if (auto fun = dyn_cast(op)) { Attribute attr = fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); if (attr) divHint = attr.cast().getValue().getZExtValue(); } } DimVectorT contiguity(rank, 1); DimVectorT divisibility(rank, divHint); DimVectorT constancy(rank, 1); return AxisInfo(contiguity, divisibility, constancy); } // The gcd of both arguments for each dimension AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) { DimVectorT retContiguity; DimVectorT retDivisibility; DimVectorT retConstancy; for (int d = 0; d < lhs.getRank(); ++d) { retContiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d))); retDivisibility.push_back( gcd(lhs.getDivisibility(d), rhs.getDivisibility(d))); retConstancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d))); } return AxisInfo(retContiguity, retDivisibility, retConstancy); } //===----------------------------------------------------------------------===// // AxisInfoAnalysis //===----------------------------------------------------------------------===// AxisInfo AxisInfoAnalysis::visitBinaryOp( Operation *op, AxisInfo lhsInfo, AxisInfo rhsInfo, const std::function &getContiguity, const std::function &getDivisibility, const std::function &getConstancy) { int rank = lhsInfo.getRank(); AxisInfo::DimVectorT newContiguity; AxisInfo::DimVectorT newDivisibility; AxisInfo::DimVectorT newConstancy; for (int d = 0; d < rank; ++d) { newContiguity.push_back(getContiguity(lhsInfo, rhsInfo, d)); newDivisibility.push_back(getDivisibility(lhsInfo, rhsInfo, d)); newConstancy.push_back(getConstancy(lhsInfo, rhsInfo, d)); } return AxisInfo(newContiguity, newDivisibility, newConstancy); } ChangeResult AxisInfoAnalysis::visitOperation( Operation *op, ArrayRef *> operands) { AxisInfo curr; // This preserves the input axes (e.g., cast): if (llvm::isa(op)) curr = operands[0]->getValue(); // Constant ranges if (triton::MakeRangeOp make_range = llvm::dyn_cast(op)) { int start = make_range.start(); int end = make_range.end(); AxisInfo::DimVectorT contiguity = {end - start}; AxisInfo::DimVectorT divisibility = {highestPowOf2Divisor(start)}; AxisInfo::DimVectorT constancy = {1}; curr = AxisInfo(contiguity, 
  // Constant
  if (arith::ConstantOp constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
    auto intAttr = constant.getValue().dyn_cast<IntegerAttr>();
    if (intAttr) {
      size_t val = intAttr.getValue().getZExtValue();
      curr = AxisInfo({1}, {highestPowOf2Divisor(val)}, {1});
    }
    // TODO: generalize to dense attr
    auto splatAttr = constant.getValue().dyn_cast<SplatElementsAttr>();
    if (splatAttr && splatAttr.getElementType().isInteger(32)) {
      auto value = splatAttr.getSplatValue<int>();
      TensorType ty = splatAttr.getType().cast<TensorType>();
      curr = AxisInfo(
          AxisInfo::DimVectorT(ty.getRank(), 1),
          AxisInfo::DimVectorT(ty.getRank(), highestPowOf2Divisor(value)),
          AxisInfo::DimVectorT(ty.getShape().begin(), ty.getShape().end()));
    }
  }
  // TODO: refactor & complete binary ops
  // Addition
  if (llvm::isa<arith::AddIOp, triton::AddPtrOp>(op)) {
    // Adding a constant run to a contiguous run preserves contiguity.
    auto newContiguity = [&](AxisInfo lhs, AxisInfo rhs, int d) {
      return std::max(gcd(lhs.getContiguity(d), rhs.getConstancy(d)),
                      gcd(lhs.getConstancy(d), rhs.getContiguity(d)));
    };
    auto newConstancy = [&](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
    };
    auto newDivisibility = [&](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
    };
    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
                         newContiguity, newDivisibility, newConstancy);
  }
  // Multiplication
  if (llvm::isa<arith::MulIOp>(op)) {
    auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
    auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
    };
    // A product of multiples of 2^m and 2^n is a multiple of 2^(m+n).
    auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return lhs.getDivisibility(d) * rhs.getDivisibility(d);
    };
    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
                         newContiguity, newDivisibility, newConstancy);
  }
  // Remainder
  if (llvm::isa<arith::RemSIOp, arith::RemUIOp>(op)) {
    auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getContiguity(d), rhs.getDivisibility(d));
    };
    auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getDivisibility(d), rhs.getDivisibility(d));
    };
    auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
    };
    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
                         newContiguity, newDivisibility, newConstancy);
  }
  // TODO: all other binary ops
  if (llvm::isa<arith::AndIOp, arith::OrIOp>(op)) {
    auto newContiguity = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
    auto newDivisibility = [](AxisInfo lhs, AxisInfo rhs, int d) { return 1; };
    auto newConstancy = [](AxisInfo lhs, AxisInfo rhs, int d) {
      return gcd(lhs.getConstancy(d), rhs.getConstancy(d));
    };
    curr = visitBinaryOp(op, operands[0]->getValue(), operands[1]->getValue(),
                         newContiguity, newDivisibility, newConstancy);
  }
  // Splat
  if (llvm::isa<triton::SplatOp>(op)) {
    Type _retTy = *op->result_type_begin();
    TensorType retTy = _retTy.cast<TensorType>();
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity;
    AxisInfo::DimVectorT divisibility;
    AxisInfo::DimVectorT constancy;
    for (int d = 0; d < retTy.getRank(); ++d) {
      contiguity.push_back(1);
      divisibility.push_back(opInfo.getDivisibility(0));
      constancy.push_back(retTy.getShape()[d]);
    }
    curr = AxisInfo(contiguity, divisibility, constancy);
  }
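  // Worked example (illustrative): splatting a scalar whose divisibility is
  // 16 into a 64x128 tensor yields constancy = [64, 128] (every element is
  // equal along both axes), contiguity = [1, 1], and the scalar's alignment
  // carries over to every dimension, divisibility = [16, 16].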
  // expandDims
  if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
    AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
    AxisInfo::DimVectorT constancy = opInfo.getConstancy();
    // The inserted unit dimension carries no information yet.
    contiguity.insert(contiguity.begin() + expandDims.axis(), 1);
    divisibility.insert(divisibility.begin() + expandDims.axis(), 1);
    constancy.insert(constancy.begin() + expandDims.axis(), 1);
    curr = AxisInfo(contiguity, divisibility, constancy);
  }
  // Broadcast
  if (llvm::isa<triton::BroadcastOp>(op)) {
    Type _retTy = *op->result_type_begin();
    Type _opTy = *op->operand_type_begin();
    TensorType retTy = _retTy.cast<TensorType>();
    TensorType opTy = _opTy.cast<TensorType>();
    ArrayRef<int64_t> retShape = retTy.getShape();
    ArrayRef<int64_t> opShape = opTy.getShape();
    AxisInfo opInfo = operands[0]->getValue();
    AxisInfo::DimVectorT contiguity;
    AxisInfo::DimVectorT divisibility;
    AxisInfo::DimVectorT constancy;
    for (int d = 0; d < retTy.getRank(); ++d) {
      // A broadcast dimension becomes constant over its full extent.
      contiguity.push_back(opShape[d] == 1 ? 1 : opInfo.getContiguity(d));
      divisibility.push_back(opInfo.getDivisibility(d));
      constancy.push_back(opShape[d] == 1 ? retShape[d]
                                          : opInfo.getConstancy(d));
    }
    curr = AxisInfo(contiguity, divisibility, constancy);
  }
  // CmpI
  if ((llvm::dyn_cast<arith::CmpIOp>(op) ||
       llvm::dyn_cast<triton::gpu::CmpIOp>(op)) &&
      op->getResult(0).getType().dyn_cast<TensorType>()) {
    auto resTy = op->getResult(0).getType().cast<TensorType>();
    short rank = resTy.getRank();
    auto lhsInfo = operands[0]->getValue();
    auto rhsInfo = operands[1]->getValue();
    auto shape = resTy.getShape();
    AxisInfo::DimVectorT contiguity, divisibility, constancy;
    for (short d = 0; d < rank; ++d) {
      // The comparison result is constant over a run when the rhs constant
      // run covers a whole lhs contiguous or constant run.
      if (rhsInfo.getConstancy(d) % lhsInfo.getContiguity(d) == 0 ||
          rhsInfo.getConstancy(d) % lhsInfo.getConstancy(d) == 0)
        constancy.push_back(
            gcd(lhsInfo.getDivisibility(d), rhsInfo.getDivisibility(d)));
      else
        constancy.push_back(1);
      divisibility.push_back(shape[d]);
      contiguity.push_back(1);
    }
    curr = AxisInfo(contiguity, divisibility, constancy);
  }
  // UnrealizedConversionCast
  // This is needed by TritonGPUToLLVM to get AxisInfo while the graph is in
  // the process of a PartialConversion, where UnrealizedConversionCast ops
  // may exist.
  if (llvm::isa<mlir::UnrealizedConversionCastOp>(op)) {
    curr = operands[0]->getValue();
  }
  if (curr.getRank() == 0) {
    return markAllPessimisticFixpoint(op->getResults());
  }
  // Join all lattice elements.
  ChangeResult result = ChangeResult::NoChange;
  for (Value value : op->getResults()) {
    result |= getLatticeElement(value).join(curr);
  }
  return result;
}

unsigned AxisInfoAnalysis::getPtrVectorSize(Value ptr) {
  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  auto layout = tensorTy.getEncoding();
  auto shape = tensorTy.getShape();
  // `order` is sorted most-contiguous dimension first, so order[0] is the
  // dimension with the largest contiguity.
  auto order = triton::gpu::getOrder(layout);
  unsigned align = getPtrAlignment(ptr);
  unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
  unsigned vec = std::min(align, contigPerThread);
  vec = std::min<unsigned>(shape[order[0]], vec);
  return vec;
}

unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
  auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  auto axisInfo = lookupLatticeElement(ptr)->getValue();
  auto layout = tensorTy.getEncoding();
  auto order = triton::gpu::getOrder(layout);
  unsigned maxMultiple = axisInfo.getDivisibility(order[0]);
  unsigned maxContig = axisInfo.getContiguity(order[0]);
  unsigned alignment = std::min(maxMultiple, maxContig);
  return alignment;
}

unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
  auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
  if (!tensorTy)
    return 1;
  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
  auto maskAxis = lookupLatticeElement(mask)->getValue();
  auto alignment = std::max(maskAxis.getConstancy(maskOrder[0]), 1);
  return alignment;
}

} // namespace mlir
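// Usage sketch (illustrative, not part of this file): a pass typically runs
// the analysis to a fixpoint and then queries alignment / vector width for a
// pointer tensor. The calls below follow the ForwardDataFlowAnalysis API this
// class derives from; treat the exact invocation as an assumption.
//
//   AxisInfoAnalysis axisInfo(op->getContext()); // op: e.g. a ModuleOp
//   axisInfo.run(op);                            // propagate lattice values
//   unsigned align = axisInfo.getPtrAlignment(ptrTensor);
//   unsigned vec = axisInfo.getPtrVectorSize(ptrTensor);
//
// getPtrVectorSize() bounds the vector width by the proven alignment, the
// per-thread contiguity of the layout, and the tensor shape along the
// most-contiguous dimension.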