assembly format
This commit is contained in:
@@ -58,6 +58,34 @@ void TritonGPUDialect::initialize() {
|
||||
>();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
// Type inference
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding());
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type getPointeeType(Type type) {
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
// Tensor of pointers
|
||||
auto shape = tensorType.getShape();
|
||||
auto ptrType = tensorType.getElementType().dyn_cast<PointerType>();
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return RankedTensorType::get(shape, pointeeType, tensorType.getEncoding());
|
||||
} else if (auto ptrType = type.dyn_cast<PointerType>()) {
|
||||
// scalar pointer
|
||||
Type pointeeType = ptrType.getPointeeType();
|
||||
return pointeeType;
|
||||
}
|
||||
return Type();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||
|
Reference in New Issue
Block a user