@@ -29,7 +29,7 @@ namespace {
 // convert(blocked, dot_operand) ->
 // convert(blocked, mma) + convert(mma, dot_operand)
 // if this value is itself the result of a dot operation
-// this is a heuristic to accomodate some pattern seen in fused attention
+// this is a heuristic to accommodate some pattern seen in fused attention
 // kernels.
 // TODO: replace this by something more generic, i.e. layout-aware CSE
 class DecomposeDotOperand : public mlir::RewritePattern {
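
The comment block above describes the rewrite this pattern performs: a single blocked -> dot_operand layout conversion is split into a blocked -> mma conversion followed by an mma -> dot_operand conversion. As a rough illustration only (not the actual pattern body), the core of such a decomposition inside a matchAndRewrite could look like the sketch below, assuming cvt names the original triton::gpu::ConvertLayoutOp, rewriter is the PatternRewriter, and mmaEncoding is an MMA layout attribute built elsewhere:

// Hypothetical sketch; `cvt`, `rewriter` and `mmaEncoding` are assumed names,
// not taken from the patch above.
auto dstTy = cvt.getType().cast<RankedTensorType>();
// intermediate tensor type that carries the mma encoding
auto tmpTy = RankedTensorType::get(dstTy.getShape(), dstTy.getElementType(),
                                   mmaEncoding);
// convert(blocked, mma)
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(cvt.getLoc(), tmpTy,
                                                         cvt.getOperand());
// convert(mma, dot_operand), replacing the original single conversion
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(cvt, dstTy, tmp);
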
@@ -81,7 +81,7 @@ public:
     auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
     auto dstType = convert.getType().cast<RankedTensorType>();
     // we don't handle conversions to DotOperandEncodingAttr
-    // this is a heuristics to accomodate fused attention
+    // this is a heuristics to accommodate fused attention
     // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
     //   return mlir::failure();
     // convert to the same layout -- we can delete
@@ -265,7 +265,7 @@ public:
         isSharedEncoding(cvt->getOperand(0)))
       return mlir::failure();
     // we don't handle conversions to DotOperandEncodingAttr
-    // this is a heuristics to accomodate fused attention
+    // this is a heuristics to accommodate fused attention
     auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
     if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
       return mlir::failure();
@@ -285,7 +285,7 @@ public:
       // we stop everything
      if (expensive_to_remat(currOp))
        break;
-      // a conversion will be removed here (i.e. transfered to operands)
+      // a conversion will be removed here (i.e. transferred to operands)
      numCvts -= 1;
      // done processing
      processed.insert(currOp);