diff --git a/include/triton/Conversion/Passes.h b/include/triton/Conversion/Passes.h index 07aff36a4..e1669c01d 100644 --- a/include/triton/Conversion/Passes.h +++ b/include/triton/Conversion/Passes.h @@ -1,6 +1,7 @@ #ifndef TRITON_CONVERSION_PASSES_H #define TRITON_CONVERSION_PASSES_H +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" @@ -13,4 +14,4 @@ namespace triton { } // namespace triton } // namespace mlir -#endif \ No newline at end of file +#endif diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td index 0f37c76a7..504dc9cb9 100644 --- a/include/triton/Conversion/Passes.td +++ b/include/triton/Conversion/Passes.td @@ -35,6 +35,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" let dependentDialects = ["mlir::arith::ArithmeticDialect", "mlir::StandardOpsDialect", "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", "mlir::triton::TritonDialect", "mlir::triton::gpu::TritonGPUDialect"]; } diff --git a/lib/Conversion/PassDetail.h b/lib/Conversion/PassDetail.h index e60f15a84..ca6b9d9c4 100644 --- a/lib/Conversion/PassDetail.h +++ b/lib/Conversion/PassDetail.h @@ -1,6 +1,7 @@ #ifndef TRITON_CONVERSION_PASSDETAIL_H #define TRITON_CONVERSION_PASSDETAIL_H +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Pass/Pass.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3d9204183..77d74aa8a 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -112,8 +112,19 @@ OpFoldResult BroadcastOp::fold(ArrayRef operands) { auto constOperand = src().getDefiningOp(); if (!constOperand) return {}; + auto shapedType = getType().cast(); - return SplatElementsAttr::get(shapedType, {constOperand.getValue()}); + auto value = constOperand.getValue(); + if (auto denseElemsAttr = value.dyn_cast()) { + if (!denseElemsAttr.isSplat()) + return {}; + return SplatElementsAttr::get(shapedType, + denseElemsAttr.getSplatValue()); + } else if (value.getType().isIntOrIndexOrFloat()) { + return SplatElementsAttr::get(shapedType, value); + } else { + return {}; + } } } // namespace triton diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index eefd13633..16b74587c 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -8,6 +8,38 @@ using namespace mlir; using namespace mlir::triton::gpu; +static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr, + unsigned &value, StringRef desc) { + auto intAttr = attr.dyn_cast(); + if (!intAttr) { + parser.emitError(parser.getNameLoc(), "expected an integer type in ") + << desc; + return failure(); + } + if (intAttr.getType().isSignedInteger()) { + int64_t attrVal = intAttr.getSInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else if (intAttr.getType().isSignlessInteger()) { + int64_t attrVal = intAttr.getInt(); + if (attrVal < 0) { + parser.emitError(parser.getNameLoc(), + "expected an unsigned integer value in ") + << desc; + return failure(); + } + value = attrVal; + } else { + value = intAttr.getUInt(); + } + return success(); +} + // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, @@ -19,26 +51,17 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser, return failure(); } for (Attribute i : arrayAttr) { - auto intAttr = i.dyn_cast(); - if (!intAttr) { - parser.emitError(parser.getNameLoc(), "expected an integer value in ") - << desc; + unsigned value; + if (parseIntAttrValue(parser, i, value, desc).failed()) return failure(); - } - res.push_back(intAttr.getUInt()); + res.push_back(value); } return success(); }; static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, unsigned &value, StringRef desc) { - auto intAttr = attr.getValue().dyn_cast(); - if (!intAttr) { - parser.emitError(parser.getNameLoc(), "expected an integer ") << desc; - return failure(); - } - value = intAttr.getUInt(); - return success(); + return parseIntAttrValue(parser, attr.getValue(), value, desc); }; //===----------------------------------------------------------------------===// @@ -214,8 +237,7 @@ public: os << "blocked"; return AliasResult::FinalAlias; } - OpAsmDialectInterface::getAlias(attr, os); - return AliasResult::FinalAlias; + return OpAsmDialectInterface::getAlias(attr, os); } };