More on type inference & assembly format

This commit is contained in:
Yan Da
2022-04-08 19:37:57 +08:00
parent 13aead4808
commit 62f7609612
2 changed files with 77 additions and 22 deletions

View File

@@ -6,6 +6,37 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
namespace mlir {
namespace triton {
// Type inference
static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type);
return Type();
}
static Type getI32SameShape(Type type) {
auto i32Type = IntegerType::get(type.getContext(), 32);
if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i32Type);
return Type();
}
static Type getPointerTypeFromTensor(Type type) {
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
Type elementType = tensorType.getElementType();
auto shape = tensorType.getShape();
PointerType ptrType = PointerType::get(elementType, 1);
return RankedTensorType::get(shape, ptrType);
}
return Type();
}
}
}
#define GET_OP_CLASSES
#include "triton/ir/Ops.cpp.inc"