[BUILD] fix minor issues with MLIR assert enabled (#46)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
#ifndef TRITON_CONVERSION_PASSES_H
|
#ifndef TRITON_CONVERSION_PASSES_H
|
||||||
#define TRITON_CONVERSION_PASSES_H
|
#define TRITON_CONVERSION_PASSES_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h"
|
||||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
|
|
||||||
@@ -13,4 +14,4 @@ namespace triton {
|
|||||||
} // namespace triton
|
} // namespace triton
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -35,6 +35,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
|||||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||||
"mlir::StandardOpsDialect",
|
"mlir::StandardOpsDialect",
|
||||||
"mlir::scf::SCFDialect",
|
"mlir::scf::SCFDialect",
|
||||||
|
"mlir::LLVM::LLVMDialect",
|
||||||
"mlir::triton::TritonDialect",
|
"mlir::triton::TritonDialect",
|
||||||
"mlir::triton::gpu::TritonGPUDialect"];
|
"mlir::triton::gpu::TritonGPUDialect"];
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
#ifndef TRITON_CONVERSION_PASSDETAIL_H
|
||||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
|
@@ -112,8 +112,19 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||||||
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
auto constOperand = src().getDefiningOp<arith::ConstantOp>();
|
||||||
if (!constOperand)
|
if (!constOperand)
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto shapedType = getType().cast<ShapedType>();
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
return SplatElementsAttr::get(shapedType, {constOperand.getValue()});
|
auto value = constOperand.getValue();
|
||||||
|
if (auto denseElemsAttr = value.dyn_cast<DenseElementsAttr>()) {
|
||||||
|
if (!denseElemsAttr.isSplat())
|
||||||
|
return {};
|
||||||
|
return SplatElementsAttr::get(shapedType,
|
||||||
|
denseElemsAttr.getSplatValue<Attribute>());
|
||||||
|
} else if (value.getType().isIntOrIndexOrFloat()) {
|
||||||
|
return SplatElementsAttr::get(shapedType, value);
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace triton
|
} // namespace triton
|
||||||
|
@@ -8,6 +8,38 @@
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton::gpu;
|
using namespace mlir::triton::gpu;
|
||||||
|
|
||||||
|
static LogicalResult parseIntAttrValue(AsmParser &parser, const Attribute &attr,
|
||||||
|
unsigned &value, StringRef desc) {
|
||||||
|
auto intAttr = attr.dyn_cast<IntegerAttr>();
|
||||||
|
if (!intAttr) {
|
||||||
|
parser.emitError(parser.getNameLoc(), "expected an integer type in ")
|
||||||
|
<< desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (intAttr.getType().isSignedInteger()) {
|
||||||
|
int64_t attrVal = intAttr.getSInt();
|
||||||
|
if (attrVal < 0) {
|
||||||
|
parser.emitError(parser.getNameLoc(),
|
||||||
|
"expected an unsigned integer value in ")
|
||||||
|
<< desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
value = attrVal;
|
||||||
|
} else if (intAttr.getType().isSignlessInteger()) {
|
||||||
|
int64_t attrVal = intAttr.getInt();
|
||||||
|
if (attrVal < 0) {
|
||||||
|
parser.emitError(parser.getNameLoc(),
|
||||||
|
"expected an unsigned integer value in ")
|
||||||
|
<< desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
value = attrVal;
|
||||||
|
} else {
|
||||||
|
value = intAttr.getUInt();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// parse an array of integers
|
// parse an array of integers
|
||||||
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||||
const NamedAttribute &attr,
|
const NamedAttribute &attr,
|
||||||
@@ -19,26 +51,17 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
for (Attribute i : arrayAttr) {
|
for (Attribute i : arrayAttr) {
|
||||||
auto intAttr = i.dyn_cast<IntegerAttr>();
|
unsigned value;
|
||||||
if (!intAttr) {
|
if (parseIntAttrValue(parser, i, value, desc).failed())
|
||||||
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
|
||||||
<< desc;
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
res.push_back(value);
|
||||||
res.push_back(intAttr.getUInt());
|
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
|
|
||||||
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||||
unsigned &value, StringRef desc) {
|
unsigned &value, StringRef desc) {
|
||||||
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
return parseIntAttrValue(parser, attr.getValue(), value, desc);
|
||||||
if (!intAttr) {
|
|
||||||
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
value = intAttr.getUInt();
|
|
||||||
return success();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -214,8 +237,7 @@ public:
|
|||||||
os << "blocked";
|
os << "blocked";
|
||||||
return AliasResult::FinalAlias;
|
return AliasResult::FinalAlias;
|
||||||
}
|
}
|
||||||
OpAsmDialectInterface::getAlias(attr, os);
|
return OpAsmDialectInterface::getAlias(attr, os);
|
||||||
return AliasResult::FinalAlias;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user