[BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)

Also fixed various errors that surfaced in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues.
Author: Philippe Tillet
Date: 2022-10-28 01:17:55 -07:00
Committed by: GitHub
Parent: 3685194456
Commit: ac0f6793cc

6 changed files with 269 additions and 419 deletions
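In short: `tt.load`, `tt.store`, and the elementwise ops no longer require tensor operands; a bare `!tt.ptr<T>` (or a bare scalar) is now accepted, and result types are derived accordingly. A minimal sketch of the typing rule, mirroring the `getLoadOpResultType` helper added below (the standalone function name and the include paths are assumptions, not part of the commit):

#include "mlir/IR/BuiltinTypes.h"           // RankedTensorType
#include "triton/Dialect/Triton/IR/Types.h" // triton::PointerType (assumed path)

using namespace mlir;

// Hypothetical mirror of the rule this commit introduces:
//   !tt.ptr<f32>             -> f32              (new: scalar load)
//   tensor<128x!tt.ptr<f32>> -> tensor<128xf32>  (existing: tensor load)
static Type inferLoadResultType(Type ptrType) {
  if (auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>()) {
    Type elementType = ptrTensorType.getElementType()
                           .cast<triton::PointerType>()
                           .getPointeeType();
    return RankedTensorType::get(ptrTensorType.getShape(), elementType);
  }
  return ptrType.cast<triton::PointerType>().getPointeeType();
}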

@@ -131,6 +131,16 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
 }
 //-- LoadOp --
+static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
+  auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
+  if (!ptrTensorType)
+    return ptrType.cast<PointerType>().getPointeeType();
+  auto shape = ptrTensorType.getShape();
+  Type elementType =
+      ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
+  return RankedTensorType::get(shape, elementType);
+}
 void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
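A hedged usage sketch of the new overload path (assumes an in-scope `OpBuilder builder`, a `Location loc`, and a `Value scalarPtr` of type `!tt.ptr<f32>`; `builder.create<triton::LoadOp>` forwards to the `build` method shown above, and the `NONE`/`NORMAL` enum values are assumed defaults):

// No tensor-of-pointers wrapper is needed anymore for a single element:
Value loaded = builder.create<triton::LoadOp>(
    loc, scalarPtr, triton::CacheModifier::NONE,
    triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);
// loaded has the pointee type (f32 here), via getLoadOpResultType.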
@@ -150,11 +160,8 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                    ::mlir::triton::CacheModifier cache,
                    ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().cast<TensorType>();
-  Type elementType =
-      ptrType.getElementType().cast<PointerType>().getPointeeType();
-  auto shape = ptrType.getShape();
-  Type resultType = RankedTensorType::get(shape, elementType);
+  Type resultType = getLoadOpResultType(builder, ptr.getType());
   state.addOperands(ptr);
   if (mask) {
     state.addOperands(mask);
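The masked overload now funnels through the same helper, so masked scalar loads get the same result type. A sketch of a call (here a `mask` of type `i1` and a scalar `other` fallback are assumptions for the scalar case):

Value guarded = builder.create<triton::LoadOp>(
    loc, scalarPtr, mask, other, triton::CacheModifier::NONE,
    triton::EvictionPolicy::NORMAL, /*isVolatile=*/false);
// Where mask is false, the load is suppressed and `other` supplies the value.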

@@ -43,7 +43,12 @@ static Type getPointeeType(Type type) {
 namespace gpu {
 // TODO: Inheritance of layout attributes
-unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
+unsigned getElemsPerThread(Type type) {
+  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
+    return 1;
+  auto tensorType = type.cast<RankedTensorType>();
+  auto layout = tensorType.getEncoding();
+  auto shape = tensorType.getShape();
+  size_t rank = shape.size();
   if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
     return blockedLayout.getElemsPerThread(shape);
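With the `Type`-based signature above, callers no longer need to special-case scalars before asking for a per-thread element count. Expected behavior, sketched (assumes the enclosing `mlir::triton::gpu` namespace, an in-scope `OpBuilder builder`, and the `PointerType::get(pointeeType, addressSpace)` signature; `blockedTensorType` is a hypothetical placeholder):

// Scalars and bare pointers occupy exactly one value per thread:
unsigned a = getElemsPerThread(builder.getF32Type());                 // 1
unsigned b = getElemsPerThread(
    triton::PointerType::get(builder.getF32Type(), /*addrSpace=*/1)); // 1
// Ranked tensors delegate to their layout encoding, e.g. BlockedEncodingAttr:
unsigned c = getElemsPerThread(blockedTensorType); // layout-dependent count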