[BUILD] fix minor issues with MLIR assert enabled (#46)

This commit is contained in:
Shintaro Iwasaki
2022-08-11 21:20:47 -07:00
committed by GitHub
parent 3a48ca0d4d
commit 2ba9a83465
5 changed files with 53 additions and 17 deletions

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSES_H
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"

View File

@@ -35,6 +35,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
let dependentDialects = ["mlir::arith::ArithmeticDialect",
"mlir::StandardOpsDialect",
"mlir::scf::SCFDialect",
"mlir::LLVM::LLVMDialect",
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect"];
}

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_PASSDETAIL_H
#define TRITON_CONVERSION_PASSDETAIL_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

View File

@@ -112,8 +112,19 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
if (!constOperand)
return {};
auto shapedType = getType().cast<ShapedType>();
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
auto value = constOperand.getValue();
if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
if (!denseElemsAttr.isSplat())
return {};
return SplatElementsAttr::get(shapedType,
denseElemsAttr.getSplatValue<Attribute>());
} else if (value.getType().isIntOrIndexOrFloat()) {
return SplatElementsAttr::get(shapedType, value);
} else {
return {};
}
}
} // namespace triton

View File

@@ -8,6 +8,38 @@
using namespace mlir;
using namespace mlir::triton::gpu;
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer type in ")
<< desc;
return failure();
}
if (intAttr.getType().isSignedInteger()) {
int64_t attrVal = intAttr.getSInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else if (intAttr.getType().isSignlessInteger()) {
int64_t attrVal = intAttr.getInt();
if (attrVal < 0) {
parser.emitError(parser.getNameLoc(),
"expected an unsigned integer value in ")
<< desc;
return failure();
}
value = attrVal;
} else {
value = intAttr.getUInt();
}
return success();
}
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
@@ -19,26 +51,17 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
return failure();
}
for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
unsigned value;
if (parseIntAttrValue(parser, i, value, desc).failed())
return failure();
}
res.push_back(intAttr.getUInt());
res.push_back(value);
}
return success();
};
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
unsigned &value, StringRef desc) {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
return parseIntAttrValue(parser, attr.getValue(), value, desc);
};
//===----------------------------------------------------------------------===//
@@ -214,8 +237,7 @@ public:
os << "blocked";
return AliasResult::FinalAlias;
}
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
return OpAsmDialectInterface::getAlias(attr, os);
}
};