[Triton-MLIR][Backend] Add SCF lowering in the backend (#750)
This commit is contained in:
@@ -6,6 +6,7 @@
|
|||||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||||
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
||||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
@@ -1287,24 +1288,33 @@ struct AddPtrOpConversion
|
|||||||
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
auto resultTy = op.getType();
|
||||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
|
||||||
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
|
if (resultTensorTy) {
|
||||||
auto resultShape = resultTy.getShape();
|
auto resultLayout =
|
||||||
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||||
Type elemTy =
|
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
|
||||||
this->getTypeConverter()->convertType(resultTy.getElementType());
|
auto resultShape = resultTensorTy.getShape();
|
||||||
SmallVector<Type> types(elems, elemTy);
|
unsigned elems = resultLayout.getElemsPerThread(resultShape);
|
||||||
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
|
Type elemTy =
|
||||||
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
|
getTypeConverter()->convertType(resultTensorTy.getElementType());
|
||||||
auto offsets =
|
SmallVector<Type> types(elems, elemTy);
|
||||||
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
|
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
|
||||||
SmallVector<Value> resultVals(elems);
|
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
|
||||||
for (unsigned i = 0; i < elems; ++i) {
|
auto offsets =
|
||||||
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
|
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
|
||||||
|
SmallVector<Value> resultVals(elems);
|
||||||
|
for (unsigned i = 0; i < elems; ++i) {
|
||||||
|
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
|
||||||
|
}
|
||||||
|
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
||||||
|
rewriter.replaceOp(op, view);
|
||||||
|
} else {
|
||||||
|
assert(resultTy.isa<triton::PointerType>());
|
||||||
|
Type llResultTy = getTypeConverter()->convertType(resultTy);
|
||||||
|
Value result = gep(llResultTy, adaptor.ptr(), adaptor.offset());
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
}
|
}
|
||||||
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
|
|
||||||
rewriter.replaceOp(op, view);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -3066,6 +3076,7 @@ public:
|
|||||||
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
||||||
patterns);
|
patterns);
|
||||||
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
@@ -3122,6 +3133,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
|
|||||||
// addIllegalDialect<triton::TritonDialect>();
|
// addIllegalDialect<triton::TritonDialect>();
|
||||||
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
|
||||||
addIllegalDialect<mlir::gpu::GPUDialect>();
|
addIllegalDialect<mlir::gpu::GPUDialect>();
|
||||||
|
addIllegalDialect<mlir::StandardOpsDialect>();
|
||||||
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
addLegalOp<mlir::UnrealizedConversionCastOp>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
|
||||||
|
#include "mlir/Conversion/Passes.h"
|
||||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||||
@@ -135,6 +136,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
|||||||
/*printAfterOnlyOnChange=*/true,
|
/*printAfterOnlyOnChange=*/true,
|
||||||
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
||||||
|
|
||||||
|
pm.addPass(mlir::createLowerToCFGPass());
|
||||||
pm.addPass(createConvertTritonGPUToLLVMPass());
|
pm.addPass(createConvertTritonGPUToLLVMPass());
|
||||||
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
|
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/Verifier.h"
|
#include "mlir/IR/Verifier.h"
|
||||||
|
|
||||||
|
#include "mlir/Conversion/Passes.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
@@ -1185,8 +1186,12 @@ void init_triton_ir(py::module &&m) {
|
|||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self) {
|
||||||
self.addPass(mlir::createTritonGPUVerifier());
|
self.addPass(mlir::createTritonGPUVerifier());
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
|
.def("add_triton_gpu_to_llvm",
|
||||||
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
[](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
|
||||||
|
})
|
||||||
|
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
|
||||||
|
self.addPass(mlir::createLowerToCFGPass());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
79
python/tests/test_vecadd.py
Normal file
79
python/tests/test_vecadd.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.testing import assert_allclose
|
||||||
|
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
|
||||||
|
[4, 256],
|
||||||
|
[2, 256],
|
||||||
|
[1, 256],
|
||||||
|
])
|
||||||
|
def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel(x_ptr,
|
||||||
|
y_ptr,
|
||||||
|
z_ptr,
|
||||||
|
BLOCK_SIZE: tl.constexpr):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
x_ptrs = x_ptr + offset
|
||||||
|
y_ptrs = y_ptr + offset
|
||||||
|
x = tl.load(x_ptrs)
|
||||||
|
y = tl.load(y_ptrs)
|
||||||
|
z = x + y
|
||||||
|
z_ptrs = z_ptr + offset
|
||||||
|
tl.store(z_ptrs, z)
|
||||||
|
|
||||||
|
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||||
|
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||||
|
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
|
||||||
|
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
|
||||||
|
|
||||||
|
golden_z = x + y
|
||||||
|
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
|
||||||
|
[4, 256, 1],
|
||||||
|
[4, 1024, 256],
|
||||||
|
])
|
||||||
|
def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def kernel(x_ptr,
|
||||||
|
y_ptr,
|
||||||
|
z_ptr,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
ITER_SIZE: tl.constexpr):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
for i in range(0, BLOCK_SIZE, ITER_SIZE):
|
||||||
|
offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
|
||||||
|
x_ptrs = x_ptr + offset
|
||||||
|
y_ptrs = y_ptr + offset
|
||||||
|
x = tl.load(x_ptrs)
|
||||||
|
y = tl.load(y_ptrs)
|
||||||
|
z = x + y
|
||||||
|
z_ptrs = z_ptr + offset
|
||||||
|
tl.store(z_ptrs, z)
|
||||||
|
x_ptr += ITER_SIZE
|
||||||
|
y_ptr += ITER_SIZE
|
||||||
|
z_ptr += ITER_SIZE
|
||||||
|
|
||||||
|
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||||
|
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
|
||||||
|
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
|
||||||
|
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
|
||||||
|
BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
|
||||||
|
|
||||||
|
golden_z = x + y
|
||||||
|
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||||
|
|
||||||
|
# TODO: test_vecadd with mask
|
@@ -1,42 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch.testing import assert_allclose
|
|
||||||
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
def vecadd_no_scf_tester(num_warps, block_size):
|
|
||||||
@triton.jit
|
|
||||||
def kernel(x_ptr,
|
|
||||||
y_ptr,
|
|
||||||
z_ptr,
|
|
||||||
BLOCK_SIZE_N: tl.constexpr):
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
x_ptrs = x_ptr + offset
|
|
||||||
y_ptrs = y_ptr + offset
|
|
||||||
x = tl.load(x_ptrs)
|
|
||||||
y = tl.load(y_ptrs)
|
|
||||||
z = x + y
|
|
||||||
z_ptrs = z_ptr + offset
|
|
||||||
tl.store(z_ptrs, z)
|
|
||||||
|
|
||||||
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
|
||||||
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
|
|
||||||
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
grid = lambda EA: (x.shape.numel() // block_size,)
|
|
||||||
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
|
|
||||||
|
|
||||||
golden_z = x + y
|
|
||||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
|
||||||
|
|
||||||
|
|
||||||
def test_vecadd_no_scf():
|
|
||||||
vecadd_no_scf_tester(num_warps=4, block_size=256)
|
|
||||||
vecadd_no_scf_tester(num_warps=2, block_size=256)
|
|
||||||
vecadd_no_scf_tester(num_warps=1, block_size=256)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_vecadd_no_scf()
|
|
Reference in New Issue
Block a user