[BUILD] fix minor issues with MLIR assert enabled (#46)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_CONVERSION_PASSES_H
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
|
||||
|
@@ -35,6 +35,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||
"mlir::StandardOpsDialect",
|
||||
"mlir::scf::SCFDialect",
|
||||
"mlir::LLVM::LLVMDialect",
|
||||
"mlir::triton::TritonDialect",
|
||||
"mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
@@ -112,8 +112,19 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||
if (!constOperand)
|
||||
return {};
|
||||
|
||||
auto shapedType = getType().cast<ShapedType>();
|
||||
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
||||
auto value = constOperand.getValue();
|
||||
if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
|
||||
if (!denseElemsAttr.isSplat())
|
||||
return {};
|
||||
return SplatElementsAttr::get(shapedType,
|
||||
denseElemsAttr.getSplatValue<Attribute>());
|
||||
} else if (value.getType().isIntOrIndexOrFloat()) {
|
||||
return SplatElementsAttr::get(shapedType, value);
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
|
@@ -8,6 +8,38 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
|
||||
unsigned &value, StringRef desc) {
|
||||
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer type in ")
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
if (intAttr.getType().isSignedInteger()) {
|
||||
int64_t attrVal = intAttr.getSInt();
|
||||
if (attrVal < 0) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an unsigned integer value in ")
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
value = attrVal;
|
||||
} else if (intAttr.getType().isSignlessInteger()) {
|
||||
int64_t attrVal = intAttr.getInt();
|
||||
if (attrVal < 0) {
|
||||
parser.emitError(parser.getNameLoc(),
|
||||
"expected an unsigned integer value in ")
|
||||
<< desc;
|
||||
return failure();
|
||||
}
|
||||
value = attrVal;
|
||||
} else {
|
||||
value = intAttr.getUInt();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// parse an array of integers
|
||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
const NamedAttribute &attr,
|
||||
@@ -19,26 +51,17 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||
return failure();
|
||||
}
|
||||
for (Attribute i : arrayAttr) {
|
||||
auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||
<< desc;
|
||||
unsigned value;
|
||||
if (parseIntAttrValue(parser, i, value, desc).failed())
|
||||
return failure();
|
||||
}
|
||||
res.push_back(intAttr.getUInt());
|
||||
res.push_back(value);
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||
unsigned &value, StringRef desc) {
|
||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||
if (!intAttr) {
|
||||
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
|
||||
return failure();
|
||||
}
|
||||
value = intAttr.getUInt();
|
||||
return success();
|
||||
return parseIntAttrValue(parser, attr.getValue(), value, desc);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -214,8 +237,7 @@ public:
|
||||
os << "blocked";
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
return OpAsmDialectInterface::getAlias(attr, os);
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user