[BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
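The re-enabled tests in `test_core.py` exercise this path with plain (0-d) pointers. As a rough illustration of what "scalar support" means at the language level — a minimal sketch, not code from this commit, with a hypothetical kernel name `copy_scalar` — a load/store of a single element through an unranked pointer should now lower through LoadOp/StoreOp without requiring a tensor layout:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_scalar(X, Z):
    # X and Z are scalar pointers: tl.load / tl.store on them now lower
    # directly instead of requiring a ranked tensor with a blocked layout.
    x = tl.load(X)
    tl.store(Z, x)

x = torch.tensor([1.3], device='cuda', dtype=torch.float32)
z = torch.empty_like(x)
copy_scalar[(1,)](x, z)
assert z.item() == x.item()
```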
@@ -21,7 +21,7 @@ namespace mlir {
namespace triton {
namespace gpu {

-unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape);
+unsigned getElemsPerThread(Type type);

SmallVector<unsigned> getSizePerThread(Attribute layout);

@@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Analysis/AxisInfo.h"

@@ -310,6 +311,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
Value getStructFromElements(Location loc, ValueRange resultVals,
                            ConversionPatternRewriter &rewriter,
                            Type structType) {
+  if (!structType.isa<LLVM::LLVMStructType>()) {
+    return *resultVals.begin();
+  }
+
  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
  for (auto v : llvm::enumerate(resultVals)) {
    llvmStruct = insert_val(structType, llvmStruct, v.value(),

@@ -369,6 +374,10 @@ struct ConvertTritonGPUOpToLLVMPatternBase {
  static SmallVector<Value>
  getElementsFromStruct(Location loc, Value llvmStruct,
                        ConversionPatternRewriter &rewriter) {
+    if (llvmStruct.getType().isIntOrIndexOrFloat() ||
+        llvmStruct.getType().isa<triton::PointerType>() ||
+        llvmStruct.getType().isa<LLVM::LLVMPointerType>())
+      return {llvmStruct};
    ArrayRef<Type> types =
        llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
    SmallVector<Value> results(types.size());

@@ -678,7 +687,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
  auto layout = tensorTy.getEncoding();
  auto srcType = typeConverter->convertType(elemType);
  auto llSrc = bitcast(srcType, constVal);
-  size_t elemsPerThread = getElemsPerThread(layout, tensorTy.getShape());
+  size_t elemsPerThread = getElemsPerThread(tensorTy);
  llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
  llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
  auto structTy =
@@ -760,64 +769,49 @@ struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase {

  // Get corresponding LLVM element values of \param value.
  SmallVector<Value> getLLVMElems(Value value, Value llValue,
-                                 const BlockedEncodingAttr &layout,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc) const {
    if (!value)
      return {};
-    auto shape = value.getType().cast<RankedTensorType>().getShape();
+    if (!llValue.getType().isa<LLVM::LLVMStructType>())
+      return {llValue};
    // Here, we assume that all inputs should have a blockedLayout
    auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
    return valueVals;
  }

-  // Get the blocked layout.
-  std::tuple<BlockedEncodingAttr, unsigned> getLayout(Value val) const {
-    auto ty = val.getType().cast<RankedTensorType>();
-    // Here, we assume that all inputs should have a blockedLayout
-    auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    assert(layout && "unexpected layout in getLayout");
-    auto shape = ty.getShape();
-    unsigned valueElems = layout.getElemsPerThread(shape);
-    return {layout, valueElems};
-  }
-
-  unsigned getAlignment(Value val, const BlockedEncodingAttr &layout) const {
-    auto axisInfo = getAxisInfo(val);
-    auto order = layout.getOrder();
-    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
-    unsigned maxContig = axisInfo->getContiguity(order[0]);
-    unsigned alignment = std::min(maxMultiple, maxContig);
-    return alignment;
-  }
-
-  unsigned getVectorizeSize(Value ptr,
-                            const BlockedEncodingAttr &layout) const {
+  unsigned getVectorSize(Value ptr) const {
+    auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+    if (!tensorTy)
+      return 1;
+    auto layout = tensorTy.getEncoding();
+    auto shape = tensorTy.getShape();
    auto axisInfo = getAxisInfo(ptr);
    // Here order should be ordered by contiguous first, so the first element
    // should have the largest contiguous.
-    auto order = layout.getOrder();
+    auto order = getOrder(layout);
    unsigned align = getAlignment(ptr, layout);

-    auto ty = ptr.getType().dyn_cast<RankedTensorType>();
-    assert(ty);
-    auto shape = ty.getShape();
-
-    unsigned contigPerThread = layout.getSizePerThread()[order[0]];
+    unsigned contigPerThread = getSizePerThread(layout)[order[0]];
    unsigned vec = std::min(align, contigPerThread);
    vec = std::min<unsigned>(shape[order[0]], vec);

    return vec;
  }

-  unsigned getMaskAlignment(Value mask) const {
-    auto maskOrder = mask.getType()
-                         .cast<RankedTensorType>()
-                         .getEncoding()
-                         .cast<BlockedEncodingAttr>()
-                         .getOrder();
+  unsigned getAlignment(Value val, const Attribute &layout) const {
+    auto axisInfo = getAxisInfo(val);
+    auto order = getOrder(layout);
+    unsigned maxMultiple = axisInfo->getDivisibility(order[0]);
+    unsigned maxContig = axisInfo->getContiguity(order[0]);
+    unsigned alignment = std::min(maxMultiple, maxContig);
+    return alignment;
+  }
+
+  unsigned getMaskAlignment(Value mask) const {
+    auto tensorTy = mask.getType().cast<RankedTensorType>();
+    auto maskOrder = getOrder(tensorTy.getEncoding());
    auto maskAxis = getAxisInfo(mask);
    return std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
  }
@@ -848,46 +842,39 @@ struct LoadOpConversion
  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *ctx = rewriter.getContext();
+    auto loc = op->getLoc();
+
+    // original values
    Value ptr = op.ptr();
    Value mask = op.mask();
    Value other = op.other();

+    // adaptor values
    Value llPtr = adaptor.ptr();
    Value llMask = adaptor.mask();
    Value llOther = adaptor.other();

-    auto loc = op->getLoc();
-    MLIRContext *ctx = rewriter.getContext();
-
-    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
-    if (!valueTy)
-      return failure();
-    Type valueElemTy =
-        getTypeConverter()->convertType(valueTy.getElementType());
-
-    auto [layout, numElems] = getLayout(ptr);
-
-    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
-    assert(ptrElems.size() == numElems);
    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
+    Type valueTy = op.getResult().getType();
+    Type valueElemTy = getElementTypeOrSelf(valueTy);
+    unsigned vec = getVectorSize(ptr);
+    unsigned numElems = getElemsPerThread(ptr.getType());
+    if (llMask)
+      vec = std::min<size_t>(vec, getMaskAlignment(mask));
+
+    // Get the LLVM values for pointers
+    auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
+    assert(ptrElems.size() == numElems);
+
+    // Get the LLVM values for mask
    SmallVector<Value> maskElems;
    if (llMask) {
-      unsigned maskAlignment = getMaskAlignment(mask);
-      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
-      assert(ptrElems.size() == maskElems.size());
-
-      size_t maskAlign = getMaskAlignment(mask);
-      vec = std::min(vec, maskAlign);
+      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
+      assert(maskElems.size() == numElems);
    }

-    const size_t dtsize =
-        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
-    const size_t valueElemNbits = dtsize * 8;
-
-    const int numVecs = numElems / vec;
-
+    // Get the LLVM values for `other`
    // TODO: (goostavz) handle when other is const but not splat, which
    // should be rarely seen
    bool otherIsSplatConstInt = false;

@@ -898,8 +885,12 @@ struct LoadOpConversion
      otherIsSplatConstInt = true;
      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
    }
+    auto otherElems = getLLVMElems(other, llOther, rewriter, loc);

-    auto otherElems = getLLVMElems(other, llOther, layout, rewriter, loc);
+    // vectorized iteration through all the pointer/mask/other elements
+    const int valueElemNbits =
+        std::max(8u, valueElemTy.getIntOrFloatBitWidth());
+    const int numVecs = numElems / vec;

    SmallVector<Value> loadedVals;
    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
@@ -1060,30 +1051,23 @@ struct StoreOpConversion
    auto loc = op->getLoc();
    MLIRContext *ctx = rewriter.getContext();

-    auto valueTy = value.getType().dyn_cast<RankedTensorType>();
-    if (!valueTy) {
-      store(llValue, llPtr);
-      rewriter.eraseOp(op);
-      return success();
-    }
-
-    Type valueElemTy =
-        getTypeConverter()->convertType(valueTy.getElementType());
+    auto valueTy = value.getType();
+    Type valueElemTy = getElementTypeOrSelf(valueTy);

-    auto [layout, numElems] = getLayout(ptr);
+    unsigned vec = getVectorSize(ptr);
+    unsigned numElems = getElemsPerThread(ptr.getType());

-    auto ptrElems = getLLVMElems(ptr, llPtr, layout, rewriter, loc);
-    auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
+    auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
+    auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
    assert(ptrElems.size() == valueElems.size());

    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
    SmallVector<Value> maskElems;
    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
+      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
      assert(valueElems.size() == maskElems.size());

-      size_t maskAlign = getMaskAlignment(mask);
+      unsigned maskAlign = getMaskAlignment(mask);
      vec = std::min(vec, maskAlign);
    }

@@ -1146,7 +1130,7 @@ struct StoreOpConversion
    ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

    auto &ptxStoreInstr =
-        ptxBuilder.create<PTXIOInstr>("st")->global().b(width).v(nWords);
+        ptxBuilder.create<PTXIOInstr>("st")->global().v(nWords).b(width);
    ptxStoreInstr(asmAddr, asmArgList).predicate(maskVal, "b");

    Type boolTy = getTypeConverter()->convertType(rewriter.getIntegerType(1));
@@ -1223,8 +1207,9 @@ struct BroadcastOpConversion
    for (auto it : llvm::enumerate(broadcastDims)) {
      // Incase there are multiple indices in the src that is actually
      // calculating the same element, srcLogicalShape may not need to be 1.
-      // Such as the case when src of shape [256, 1], and with a blocked layout:
-      // sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA: [1, 2]
+      // Such as the case when src of shape [256, 1], and with a blocked
+      // layout: sizePerThread: [1, 4]; threadsPerWarp: [1, 32]; warpsPerCTA:
+      // [1, 2]
      int64_t d = resultLogicalShape[it.value()] / srcLogicalShape[it.value()];
      broadcastSizes[it.index()] = d;
      duplicates *= d;

@@ -1234,10 +1219,10 @@ struct BroadcastOpConversion
      duplicates *= d;
    }

-    unsigned srcElems = srcLayout.getElemsPerThread(srcShape);
+    unsigned srcElems = getElemsPerThread(srcTy);
    auto elemTy = resultTy.getElementType();
    auto srcVals = getElementsFromStruct(loc, src, rewriter);
-    unsigned resultElems = resultLayout.getElemsPerThread(resultShape);
+    unsigned resultElems = getElemsPerThread(resultTy);
    SmallVector<Value> resultVals(resultElems);
    for (unsigned i = 0; i < srcElems; ++i) {
      auto srcMultiDim = getMultiDimIndex<int64_t>(i, srcLogicalShape);

@@ -1256,8 +1241,10 @@ struct BroadcastOpConversion
      }
    }
    auto llvmStructTy = getTypeConverter()->convertType(resultTy);
+
    Value resultStruct =
        getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
+
    rewriter.replaceOp(op, {resultStruct});
    return success();
  }

@@ -1389,7 +1376,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(

  auto smemShape = getScratchConfigForReduce(op);

-  unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+  unsigned srcElems = getElemsPerThread(srcTy);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
  auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

@@ -1446,7 +1433,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteBasic(
  auto resultLayout = resultTy.getEncoding();
  auto resultShape = resultTy.getShape();

-  unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+  unsigned resultElems = getElemsPerThread(resultTy);
  auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
  assert(resultIndices.size() == resultElems);
@@ -1498,7 +1485,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
  unsigned sizeIntraWarps = threadsPerWarp[axis];
  unsigned sizeInterWarps = warpsPerCTA[axis];

-  unsigned srcElems = getElemsPerThread(srcLayout, srcShape);
+  unsigned srcElems = getElemsPerThread(srcTy);
  auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
  auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter);

@@ -1586,7 +1573,7 @@ LogicalResult ReduceOpConversion::matchAndRewriteFast(
  auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
  auto resultShape = resultTy.getShape();

-  unsigned resultElems = getElemsPerThread(resultLayout, resultShape);
+  unsigned resultElems = getElemsPerThread(resultTy);
  auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultShape);
  assert(resultIndices.size() == resultElems);

@@ -1633,7 +1620,7 @@ struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
    Location loc = op->getLoc();
    auto resultTy = op.getType().template cast<RankedTensorType>();
    auto resultShape = resultTy.getShape();
-    unsigned elems = getElemsPerThread(resultTy.getEncoding(), resultShape);
+    unsigned elems = getElemsPerThread(resultTy);
    Type elemTy =
        this->getTypeConverter()->convertType(resultTy.getElementType());
    SmallVector<Type> types(elems, elemTy);

@@ -1712,7 +1699,7 @@ struct AddPtrOpConversion
        resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
    assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
    auto resultShape = resultTensorTy.getShape();
-    unsigned elems = resultLayout.getElemsPerThread(resultShape);
+    unsigned elems = getElemsPerThread(resultTy);
    Type elemTy =
        getTypeConverter()->convertType(resultTensorTy.getElementType());
    SmallVector<Type> types(elems, elemTy);

@@ -1769,8 +1756,8 @@ struct ExtractSliceOpConversion
    auto srcLayout = srcTy.getEncoding().dyn_cast<SharedEncodingAttr>();
    assert(srcLayout && "Unexpected resultLayout in ExtractSliceOpConversion");

-    // axis > 0 will result in non-contiguous memory access if the result tensor
-    // is an alias of the source tensor.
+    // axis > 0 will result in non-contiguous memory access if the result
+    // tensor is an alias of the source tensor.
    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
    assert(axis == 0 && "extract_slice: Only axis=0 is supported for now");
@@ -1806,22 +1793,14 @@ public:
  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-    auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
-    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
-    if (!resultTy)
-      return failure();
-
+    auto resultTy = op.getType();
    Location loc = op->getLoc();
-    auto resultLayout =
-        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
-    auto resultShape = resultTy.getShape();
-    assert(resultLayout &&
-           "Unexpected resultLayout in ElementwiseOpConversionBase");
-    unsigned elems = resultLayout.getElemsPerThread(resultShape);
-    Type elemTy =
-        this->getTypeConverter()->convertType(resultTy.getElementType());
+    unsigned elems = getElemsPerThread(resultTy);
+    auto resultElementTy = getElementTypeOrSelf(resultTy);
+    Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+    Type structTy = this->getTypeConverter()->convertType(resultTy);

    auto *concreteThis = static_cast<const ConcreteT *>(this);
    auto operands = getOperands(rewriter, adaptor, elems, loc);
@@ -1874,152 +1853,6 @@ struct ElementwiseOpConversion
  }
};

-//
-// Ternary
-//
-
-template <typename SourceOp, typename DestOp, typename ConcreteT>
-class TernaryOpConversionBase
-    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  explicit TernaryOpConversionBase(LLVMTypeConverter &typeConverter,
-                                   PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
-    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
-    if (!resultTy)
-      return failure();
-
-    Location loc = op->getLoc();
-    auto resultLayout =
-        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
-    auto resultShape = resultTy.getShape();
-    assert(resultLayout && "Unexpected resultLayout in TernaryOpConversion");
-    unsigned elems = resultLayout.getElemsPerThread(resultShape);
-    Type elemTy =
-        this->getTypeConverter()->convertType(resultTy.getElementType());
-    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-
-    auto *concreteThis = static_cast<const ConcreteT *>(this);
-    auto lhss =
-        this->getElementsFromStruct(loc, adaptor.getOperands()[0], rewriter);
-    auto rhss =
-        this->getElementsFromStruct(loc, adaptor.getOperands()[1], rewriter);
-    auto thss =
-        this->getElementsFromStruct(loc, adaptor.getOperands()[2], rewriter);
-    SmallVector<Value> resultVals(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
-                                                 rhss[i], thss[i], loc);
-    }
-    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
-    rewriter.replaceOp(op, view);
-    return success();
-  }
-};
-
-template <typename SourceOp, typename DestOp>
-struct TernaryOpConversion
-    : public TernaryOpConversionBase<SourceOp, DestOp,
-                                     TernaryOpConversion<SourceOp, DestOp>> {
-
-  explicit TernaryOpConversion(LLVMTypeConverter &typeConverter,
-                               PatternBenefit benefit = 1)
-      : TernaryOpConversionBase<SourceOp, DestOp,
-                                TernaryOpConversion<SourceOp, DestOp>>(
-            typeConverter, benefit) {}
-
-  using OpAdaptor = typename SourceOp::Adaptor;
-  // An interface to support variant DestOp builder.
-  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
-                      Type elemTy, Value lhs, Value rhs, Value th,
-                      Location loc) const {
-    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs, th);
-  }
-};
-
-//
-// Unary
-//
-
-template <typename SourceOp, typename DestOp, typename ConcreteT>
-class UnaryOpConversionBase : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
-
-public:
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  explicit UnaryOpConversionBase(LLVMTypeConverter &typeConverter,
-                                 PatternBenefit benefit = 1)
-      : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
-
-  LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultTy = op.getType().template dyn_cast<RankedTensorType>();
-
-    // ArithmeticToLLVM will handle the lowering of scalar ArithOps
-    if (!resultTy)
-      return failure();
-
-    Location loc = op->getLoc();
-    auto resultLayout =
-        resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
-    auto resultShape = resultTy.getShape();
-    assert(resultLayout && "Unexpected resultLayout in UnaryOpConversion");
-    unsigned elems = resultLayout.getElemsPerThread(resultShape);
-    Type elemTy =
-        this->getTypeConverter()->convertType(resultTy.getElementType());
-    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-
-    auto *concreteThis = static_cast<const ConcreteT *>(this);
-    auto srcs = this->getElementsFromStruct(loc, concreteThis->getSrc(adaptor),
-                                            rewriter);
-    SmallVector<Value> resultVals(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] =
-          concreteThis->createDestOp(op, rewriter, elemTy, srcs[i], loc);
-    }
-    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
-    rewriter.replaceOp(op, view);
-    return success();
-  }
-};
-
-template <typename SourceOp, typename DestOp>
-struct UnaryOpConversion
-    : public UnaryOpConversionBase<SourceOp, DestOp,
-                                   UnaryOpConversion<SourceOp, DestOp>> {
-
-  explicit UnaryOpConversion(LLVMTypeConverter &typeConverter,
-                             PatternBenefit benefit = 1)
-      : UnaryOpConversionBase<SourceOp, DestOp,
-                              UnaryOpConversion<SourceOp, DestOp>>(
-            typeConverter, benefit) {}
-
-  using OpAdaptor = typename SourceOp::Adaptor;
-  // An interface to support variant DestOp builder.
-  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
-                      Type elemTy, Value src, Location loc) const {
-    return rewriter.create<DestOp>(loc, elemTy, src);
-  }
-
-  // Get the source operand of the op.
-  Value getSrc(OpAdaptor adaptor) const {
-    auto operands = adaptor.getOperands();
-    if (operands.size() > 1)
-      llvm::report_fatal_error("unary operator has more than one operand");
-    return operands.front();
-  }
-};
-
//
// comparisons
//
@@ -2367,13 +2200,13 @@ LogicalResult ConvertLayoutOpConversion::lowerDistributedToDistributed(
  }
  // Potentially we need to store for multiple CTAs in this replication
  unsigned accumNumReplicates = product<unsigned>(numReplicates);
-  unsigned elems = getElemsPerThread(srcLayout, srcTy.getShape());
+  unsigned elems = getElemsPerThread(srcTy);
  auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
  unsigned inVec = 0;
  unsigned outVec = 0;
  auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);

-  unsigned outElems = getElemsPerThread(dstLayout, shape);
+  unsigned outElems = getElemsPerThread(dstTy);
  auto outOrd = getOrder(dstLayout);
  SmallVector<Value> outVals(outElems);

@@ -2431,7 +2264,7 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
  unsigned minVec = std::min(outVec, inVec);
  unsigned perPhase = dstSharedLayout.getPerPhase();
  unsigned maxPhase = dstSharedLayout.getMaxPhase();
-  unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+  unsigned numElems = getElemsPerThread(srcTy);
  auto inVals = getElementsFromStruct(loc, adaptor.src(), rewriter);
  unsigned srcAccumSizeInThreads =
      product<unsigned>(srcBlockedLayout.getSizePerThread());

@@ -2609,7 +2442,8 @@ public:
    Value c = urem(lane, i32_val(8));
    Value s = udiv(lane, i32_val(8)); // sub-warp-id

-    // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a warp
+    // Decompose s => s_0, s_1, that is the coordinate in 2x2 matrices in a
+    // warp
    Value s0 = urem(s, i32_val(2));
    Value s1 = udiv(s, i32_val(2));

@@ -2756,8 +2590,8 @@ public:
      llvm::report_fatal_error("unsupported mma type found");

    // The main difference with the original triton code is we removed the
-    // prefetch-related logic here for the upstream optimizer phase should take
-    // care with it, and that is transparent in dot conversion.
+    // prefetch-related logic here for the upstream optimizer phase should
+    // take care with it, and that is transparent in dot conversion.
    auto getPtr = [&](int idx) { return ptrs[idx]; };

    Value ptr = getPtr(ptrIdx);

@@ -2768,7 +2602,8 @@ public:
        matIdx[order[1]] * sMatStride * sMatShape * sTileStride * elemBytes;
    PTXBuilder builder;

-    // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a thread.
+    // ldmatrix.m8n8.x4 returns 4x2xfp16(that is 4xb32) elements for a
+    // thread.
    auto resArgs = builder.newListOperand(4, "=r");
    auto addrArg = builder.newAddrOperand(ptr, "r", sOffset);

@@ -3067,7 +2902,8 @@ struct DotOpConversionHelper {

  // Get the M and N of mma instruction shape.
  static std::tuple<int, int> getInstrShapeMN() {
-    // According to DotOpConversionHelper::mmaInstrShape, all the M,N are {16,8}
+    // According to DotOpConversionHelper::mmaInstrShape, all the M,N are
+    // {16,8}
    return {16, 8};
  }
@@ -3808,8 +3644,7 @@ public:
    if (layout &&
        (layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
         layout.isa<MmaEncodingAttr>())) {
-      unsigned numElementsPerThread =
-          getElemsPerThread(layout, type.getShape());
+      unsigned numElementsPerThread = getElemsPerThread(type);
      SmallVector<Type, 4> types(numElementsPerThread,
                                 convertType(type.getElementType()));
      return LLVM::LLVMStructType::getLiteral(ctx, types);

@@ -3927,7 +3762,7 @@ struct InsertSliceAsyncOpConversion
    Value llIndex = adaptor.index();

    // %src
-    auto srcElems = getLLVMElems(src, llSrc, srcBlockedLayout, rewriter, loc);
+    auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);

    // %dst
    auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();

@@ -3943,7 +3778,7 @@ struct InsertSliceAsyncOpConversion
    // %mask
    SmallVector<Value> maskElems;
    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, srcBlockedLayout, rewriter, loc);
+      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
      assert(srcElems.size() == maskElems.size());
    }

@@ -3954,15 +3789,14 @@ struct InsertSliceAsyncOpConversion
      // It's not necessary for now because the pipeline pass will skip
      // generating insert_slice_async if the load op has any "other" tensor.
      assert(false && "insert_slice_async: Other value not supported yet");
-      otherElems =
-          getLLVMElems(other, llOther, srcBlockedLayout, rewriter, loc);
+      otherElems = getLLVMElems(other, llOther, rewriter, loc);
      assert(srcElems.size() == otherElems.size());
    }

-    unsigned inVec = getVectorizeSize(src, srcBlockedLayout);
+    unsigned inVec = getVectorSize(src);
    unsigned outVec = resSharedLayout.getVec();
    unsigned minVec = std::min(outVec, inVec);
-    unsigned numElems = getElemsPerThread(srcBlockedLayout, srcShape);
+    unsigned numElems = getElemsPerThread(srcTy);
    unsigned perPhase = resSharedLayout.getPerPhase();
    unsigned maxPhase = resSharedLayout.getMaxPhase();
    auto sizePerThread = srcBlockedLayout.getSizePerThread();

@@ -4212,6 +4046,7 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  patterns.add<CmpFOpConversion>(typeConverter, benefit);
#define POPULATE_UNARY_OP(SRC_OP, DST_OP) \
  patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
+
  POPULATE_UNARY_OP(arith::TruncIOp, LLVM::TruncOp)
  POPULATE_UNARY_OP(arith::TruncFOp, LLVM::FPTruncOp)
  POPULATE_UNARY_OP(arith::ExtSIOp, LLVM::SExtOp)

@@ -4221,14 +4056,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
  POPULATE_UNARY_OP(arith::UIToFPOp, LLVM::UIToFPOp)
  POPULATE_UNARY_OP(arith::SIToFPOp, LLVM::SIToFPOp)
  POPULATE_UNARY_OP(arith::ExtFOp, LLVM::FPExtOp)
-  POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
-  POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
-  POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
  POPULATE_UNARY_OP(math::LogOp, math::LogOp)
  POPULATE_UNARY_OP(math::CosOp, math::CosOp)
  POPULATE_UNARY_OP(math::SinOp, math::SinOp)
  POPULATE_UNARY_OP(math::SqrtOp, math::SqrtOp)
  POPULATE_UNARY_OP(math::ExpOp, math::ExpOp)
+  POPULATE_UNARY_OP(triton::BitcastOp, LLVM::BitcastOp)
+  POPULATE_UNARY_OP(triton::IntToPtrOp, LLVM::IntToPtrOp)
+  POPULATE_UNARY_OP(triton::PtrToIntOp, LLVM::PtrToIntOp)
#undef POPULATE_UNARY_OP

  patterns.add<FDivOpConversion>(typeConverter, benefit);
@@ -131,6 +131,16 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
}

//-- LoadOp --
+static Type getLoadOpResultType(::mlir::OpBuilder &builder, Type ptrType) {
+  auto ptrTensorType = ptrType.dyn_cast<RankedTensorType>();
+  if (!ptrTensorType)
+    return ptrType.cast<PointerType>().getPointeeType();
+  auto shape = ptrTensorType.getShape();
+  Type elementType =
+      ptrTensorType.getElementType().cast<PointerType>().getPointeeType();
+  return RankedTensorType::get(shape, elementType);
+}
+
void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {

@@ -150,11 +160,8 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                   ::mlir::Value ptr, ::mlir::Value mask, ::mlir::Value other,
                   ::mlir::triton::CacheModifier cache,
                   ::mlir::triton::EvictionPolicy evict, bool isVolatile) {
-  TensorType ptrType = ptr.getType().cast<TensorType>();
-  Type elementType =
-      ptrType.getElementType().cast<PointerType>().getPointeeType();
-  auto shape = ptrType.getShape();
-  Type resultType = RankedTensorType::get(shape, elementType);
+  Type resultType = getLoadOpResultType(builder, ptr.getType());
  state.addOperands(ptr);
  if (mask) {
    state.addOperands(mask);
@@ -43,7 +43,12 @@ static Type getPointeeType(Type type) {
namespace gpu {

// TODO: Inheritation of layout attributes
-unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
+unsigned getElemsPerThread(Type type) {
+  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
+    return 1;
+  auto tensorType = type.cast<RankedTensorType>();
+  auto layout = tensorType.getEncoding();
+  auto shape = tensorType.getShape();
  size_t rank = shape.size();
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getElemsPerThread(shape);
@@ -556,45 +556,45 @@ def make_ptr_str(name, shape):
# # ---------------


-# @triton.jit
-# def fn(a, b):
-#     return a + b, \
-#         a - b, \
-#         a * b
+@triton.jit
+def fn(a, b):
+    return a + b, \
+        a - b, \
+        a * b


-# def test_tuples():
-#     device = 'cuda'
+def test_tuples():
+    device = 'cuda'

-#     @triton.jit
-#     def with_fn(X, Y, A, B, C):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         a, b, c = fn(x, y)
-#         tl.store(A, a)
-#         tl.store(B, b)
-#         tl.store(C, c)
+    @triton.jit
+    def with_fn(X, Y, A, B, C):
+        x = tl.load(X)
+        y = tl.load(Y)
+        a, b, c = fn(x, y)
+        tl.store(A, a)
+        tl.store(B, b)
+        tl.store(C, c)

-#     @triton.jit
-#     def without_fn(X, Y, A, B, C):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         a, b, c = x + y, x - y, x * y
-#         tl.store(A, a)
-#         tl.store(B, b)
-#         tl.store(C, c)
+    @triton.jit
+    def without_fn(X, Y, A, B, C):
+        x = tl.load(X)
+        y = tl.load(Y)
+        a, b, c = x + y, x - y, x * y
+        tl.store(A, a)
+        tl.store(B, b)
+        tl.store(C, c)

-#     x = torch.tensor([1.3], device=device, dtype=torch.float32)
-#     y = torch.tensor([1.9], device=device, dtype=torch.float32)
-#     a_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     b_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     c_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     for kernel in [with_fn, without_fn]:
-#         kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
-#         a_ref, b_ref, c_ref = x + y, x - y, x * y
-#         assert a_tri == a_ref
-#         assert b_tri == b_ref
-#         assert c_tri == c_ref
+    x = torch.tensor([1.3], device=device, dtype=torch.float32)
+    y = torch.tensor([1.9], device=device, dtype=torch.float32)
+    a_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    b_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    c_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    for kernel in [with_fn, without_fn]:
+        kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
+        a_ref, b_ref, c_ref = x + y, x - y, x * y
+        assert a_tri == a_ref
+        assert b_tri == b_ref
+        assert c_tri == c_ref


# # ---------------

@@ -709,75 +709,77 @@ def make_ptr_str(name, shape):
# # ---------------


-# @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-#     (dtype_x, dtype_z, False)
-#     for dtype_x in dtypes
-#     for dtype_z in dtypes
-# ] + [
+@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
+    (dtype_x, dtype_z, False)
+    for dtype_x in dtypes
+    for dtype_z in dtypes
+] + [
+    # TODO:
    # ('float32', 'bfloat16', False),
    # ('bfloat16', 'float32', False),
-#     ('float32', 'int32', True),
+    ('float32', 'int32', True),
+    # TODO:
    # ('float32', 'int1', False),
-# ] + [
-#     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
-# ] + [
-#     (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
-# ])
-# def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
-#     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
-#     x0 = 43 if dtype_x in int_dtypes else 43.5
-#     if dtype_x in float_dtypes and dtype_z == 'int1':
-#         x0 = 0.5
-#     if dtype_x.startswith('bfloat'):
-#         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
-#     else:
-#         x = np.array([x0], dtype=getattr(np, dtype_x))
-#         x_tri = to_triton(x)
+] + [
+    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
+] + [
+    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
+])
+def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
+    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
+    x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
+    if dtype_x.startswith('bfloat'):
+        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+    else:
+        x = np.array([x0], dtype=getattr(np, dtype_x))
+        x_tri = to_triton(x)

-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BITCAST: tl.constexpr):
-#         x = tl.load(X)
-#         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
-#         tl.store(Z, z)
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BITCAST: tl.constexpr):
+        x = tl.load(X)
+        z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
+        tl.store(Z, z)

-#     dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
-#     # triton result
-#     if dtype_z.startswith('bfloat'):
-#         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
-#     else:
-#         z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
-#     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
-#     # torch result
-#     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
-#         assert bitcast is False
-#         z_ref = x_tri.to(z_tri.dtype)
-#         assert z_tri == z_ref
-#     else:
-#         if bitcast:
-#             z_ref = x.view(getattr(np, dtype_z_np))
-#         else:
-#             z_ref = x.astype(getattr(np, dtype_z_np))
-#         assert to_numpy(z_tri) == z_ref
+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
+    # triton result
+    if dtype_z.startswith('bfloat'):
+        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
+    else:
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+    # torch result
+    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
+        assert bitcast is False
+        z_ref = x_tri.to(z_tri.dtype)
+        assert z_tri == z_ref
+    else:
+        if bitcast:
+            z_ref = x.view(getattr(np, dtype_z_np))
+        else:
+            z_ref = x.astype(getattr(np, dtype_z_np))
+        assert to_numpy(z_tri) == z_ref


-# def test_store_bool():
-#     """Tests that boolean True is stored as 1"""
-#     @triton.jit
-#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         mask = offsets < n_elements
-#         input = tl.load(input_ptr + offsets, mask=mask)
-#         output = input
-#         tl.store(output_ptr + offsets, output, mask=mask)
+def test_store_bool():
+    """Tests that boolean True is stored as 1"""
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)

-#     src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
-#     n_elements = src.numel()
-#     dst = torch.empty_like(src)
-#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-#     copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
+    src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
+    n_elements = src.numel()
+    dst = torch.empty_like(src)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)

-#     assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
+    assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])

@@ -990,48 +992,49 @@ def make_ptr_str(name, shape):
# # ---------------


-# @pytest.mark.parametrize("dtype_str, shape, perm",
-#                          [(dtype, shape, perm)
-#                           for dtype in ['bfloat16', 'float16', 'float32']
-#                           for shape in [(64, 64), (128, 128)]
-#                           for perm in [(1, 0)]])
-# def test_permute(dtype_str, shape, perm, device='cuda'):
-#     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
+@pytest.mark.parametrize("dtype_str, shape, perm",
+                         [(dtype, shape, perm)
+                          # TODO: bfloat16
+                          for dtype in ['float16', 'float32']
+                          for shape in [(64, 64), (128, 128)]
+                          for perm in [(1, 0)]])
+def test_permute(dtype_str, shape, perm, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, stride_xm, stride_xn,
-#                Z, stride_zm, stride_zn,
-#                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-#         off_m = tl.arange(0, BLOCK_M)
-#         off_n = tl.arange(0, BLOCK_N)
-#         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
-#         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-#         tl.store(Zs, tl.load(Xs))
-#     # input
-#     x = numpy_random(shape, dtype_str=dtype_str)
-#     # triton result
-#     z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
-#     z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
-#     x_tri = to_triton(x, device=device, dst_type=dtype_str)
-#     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
-#                          z_tri, z_tri.stride(1), z_tri.stride(0),
-#                          BLOCK_M=shape[0], BLOCK_N=shape[1])
-#     pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
-#                                     z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
-#                                     BLOCK_M=shape[0], BLOCK_N=shape[1])
-#     # numpy result
-#     z_ref = x.transpose(*perm)
-#     # compare
-#     triton.testing.assert_almost_equal(z_tri, z_ref)
-#     triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
-#     # parse ptx to make sure ld/st are vectorized
-#     ptx = pgm.asm['ptx']
-#     assert 'ld.global.v4' in ptx
-#     assert 'st.global.v4' in ptx
-#     ptx = pgm_contiguous.asm['ptx']
-#     assert 'ld.global.v4' in ptx
-#     assert 'st.global.v4' in ptx
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn,
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        tl.store(Zs, tl.load(Xs))
+    # input
+    x = numpy_random(shape, dtype_str=dtype_str)
+    # triton result
+    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    x_tri = to_triton(x, device=device, dst_type=dtype_str)
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
+                                    z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
+                                    BLOCK_M=shape[0], BLOCK_N=shape[1])
+    # numpy result
+    z_ref = x.transpose(*perm)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
+    # parse ptx to make sure ld/st are vectorized
+    ptx = pgm.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
+    ptx = pgm_contiguous.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx

# # ---------------
# # test dot
@@ -153,7 +153,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
    %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

    // Store 4 elements to global with single one vectorized store instruction
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
    tt.store %13, %11 : tensor<256xf32, #blocked0>
    return
  }

@@ -222,8 +222,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
    %13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

    // Store 8 elements to global with two vectorized store instruction
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
-    // CHECK: @$5 st.global.b32.v4 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
+    // CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
    tt.store %13, %11 : tensor<256xf32, #blocked0>
    return
  }