[BUILD] fix minor issues with MLIR assert enabled (#46)

This commit is contained in:
Shintaro Iwasaki
2022-08-11 21:20:47 -07:00
committed by GitHub
parent 3a48ca0d4d
commit 2ba9a83465
5 changed files with 53 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSES_H #ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H #define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
@@ -13,4 +14,4 @@ namespace triton {
} // namespace triton } // namespace triton
} // namespace mlir } // namespace mlir
#endif #endif

View File

@@ -35,6 +35,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
let dependentDialects = ["mlir::arith::ArithmeticDialect", let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::StandardOpsDialect", "mlir::StandardOpsDialect",
"mlir::scf::SCFDialect", "mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::triton::TritonDialect", "mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"]; "mlir::triton::gpu::TritonGPUDialect"];
} }

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H #ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H #define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h"

View File

@@ -112,8 +112,19 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>(); auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand) if (!constOperand)
return {}; return {};
auto shapedType = getType().cast<ShapedType>(); auto shapedType = getType().cast<ShapedType>();
return SplatElementsAttr::get(shapedType, {constOperand.getValue()}); auto value = constOperand.getValue();
if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
if (!denseElemsAttr.isSplat())
return {};
return SplatElementsAttr::get(shapedType,
denseElemsAttr.getSplatValue<Attribute>());
} else if (value.getType().isIntOrIndexOrFloat()) {
return SplatElementsAttr::get(shapedType, value);
} else {
return {};
}
} }
} // namespace triton } // namespace triton

View File

@@ -8,6 +8,38 @@
using namespace mlir; using namespace mlir;
using namespace mlir::triton::gpu; using namespace mlir::triton::gpu;
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer type in ")
<< desc;
return failure();
}
if (intAttr.getType().isSignedInteger()) {
int64_t attrVal = intAttr.getSInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else if (intAttr.getType().isSignlessInteger()) {
int64_t attrVal = intAttr.getInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else {
value = intAttr.getUInt();
}
return success();
}
// parse an array of integers // parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser, static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr, const NamedAttribute &attr,
@@ -19,26 +51,17 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
return failure(); return failure();
} }
for (Attribute i : arrayAttr) { for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>(); unsigned value;
if (!intAttr) { if (parseIntAttrValue(parser, i, value, desc).failed())
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
return failure(); return failure();
} res.push_back(value);
res.push_back(intAttr.getUInt());
} }
return success(); return success();
}; };
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
unsigned &value, StringRef desc) { unsigned &value, StringRef desc) {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); return parseIntAttrValue(parser, attr.getValue(), value, desc);
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@@ -214,8 +237,7 @@ public:
os << "blocked"; os << "blocked";
return AliasResult::FinalAlias; return AliasResult::FinalAlias;
} }
OpAsmDialectInterface::getAlias(attr, os); return OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
} }
}; };