[BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
This commit is contained in:
@@ -131,6 +131,16 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
}
|
||||
|
||||
//-- LoadOp --
|
||||
static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
|
||||
auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
|
||||
if (!ptrTensorType)
|
||||
return ptrType.cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrTensorType.getShape();
|
||||
Type elementType =
|
||||
ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
|
||||
return RankedTensorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
@@ -150,11 +160,8 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
|
||||
::mlir::triton::CacheModifier cache,
|
||||
::mlir::triton::EvictionPolicy evict, bool isVolatile) {
|
||||
TensorType ptrType = ptr.getType().cast<TensorType>();
|
||||
Type elementType =
|
||||
ptrType.getElementType().cast<PointerType>().getPointeeType();
|
||||
auto shape = ptrType.getShape();
|
||||
Type resultType = RankedTensorType::get(shape, elementType);
|
||||
Type resultType = getLoadOpResultType(builder, ptr.getType());
|
||||
|
||||
state.addOperands(ptr);
|
||||
if (mask) {
|
||||
state.addOperands(mask);
|
||||
|
Reference in New Issue
Block a user