[Triton-MLIR][Backend] Add SCF lowering in the backend (#750)

This commit is contained in:
goostavz
2022-10-08 18:36:37 +08:00
committed by GitHub
parent 498c685b46
commit 1d772cd843
5 changed files with 117 additions and 61 deletions

View File

@@ -6,6 +6,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -1287,24 +1288,33 @@ struct AddPtrOpConversion
matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTy.getShape();
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
auto offsets =
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
auto resultTy = op.getType();
auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
if (resultTensorTy) {
auto resultLayout =
resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
auto resultShape = resultTensorTy.getShape();
unsigned elems = resultLayout.getElemsPerThread(resultShape);
Type elemTy =
getTypeConverter()->convertType(resultTensorTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
auto offsets =
getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
} else {
assert(resultTy.isa<triton::PointerType>());
Type llResultTy = getTypeConverter()->convertType(resultTy);
Value result = gep(llResultTy, adaptor.ptr(), adaptor.offset());
rewriter.replaceOp(op, result);
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, view);
return success();
}
};
@@ -3066,6 +3076,7 @@ public:
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
@@ -3122,6 +3133,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
// addIllegalDialect<triton::TritonDialect>();
// addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addIllegalDialect<mlir::StandardOpsDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}

View File

@@ -1,4 +1,5 @@
#include "triton/Target/LLVMIR/LLVMIRTranslation.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
@@ -135,6 +136,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(mlir::createLowerToCFGPass());
pm.addPass(createConvertTritonGPUToLLVMPass());
// Conanicalize to eliminate the remaining UnrealizedConversionCastOp
pm.addPass(mlir::createCanonicalizerPass());

View File

@@ -3,6 +3,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
@@ -1185,8 +1186,12 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUVerifier());
})
.def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
.def("add_triton_gpu_to_llvm",
[](mlir::PassManager &self) {
self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
})
.def("add_scf_to_cfg", [](mlir::PassManager &self) {
self.addPass(mlir::createLowerToCFGPass());
});
}

View File

@@ -0,0 +1,79 @@
import pytest
import torch
from torch.testing import assert_allclose
import triton
import triton.language as tl
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
[4, 256],
[2, 256],
[1, 256],
])
def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
[4, 256, 1],
[4, 1024, 256],
])
def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
BLOCK_SIZE,
ITER_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
for i in range(0, BLOCK_SIZE, ITER_SIZE):
offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x_ptr += ITER_SIZE
y_ptr += ITER_SIZE
z_ptr += ITER_SIZE
x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
# TODO: test_vecadd with mask

View File

@@ -1,42 +0,0 @@
import torch
from torch.testing import assert_allclose
import triton
import triton.language as tl
def vecadd_no_scf_tester(num_warps, block_size):
@triton.jit
def kernel(x_ptr,
y_ptr,
z_ptr,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x_ptrs = x_ptr + offset
y_ptrs = y_ptr + offset
x = tl.load(x_ptrs)
y = tl.load(y_ptrs)
z = x + y
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
grid = lambda EA: (x.shape.numel() // block_size,)
kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
def test_vecadd_no_scf():
vecadd_no_scf_tester(num_warps=4, block_size=256)
vecadd_no_scf_tester(num_warps=2, block_size=256)
vecadd_no_scf_tester(num_warps=1, block_size=256)
if __name__ == '__main__':
test_vecadd_no_scf()