From 1d772cd8433443aa5eee7abc28c739a48472f3ed Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Sat, 8 Oct 2022 18:36:37 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Add SCF lowering in the backend (#750)

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 46 +++++++----
 lib/Target/LLVMIR/LLVMIRTranslation.cpp |  2 +
 python/src/triton.cc                    |  9 ++-
 python/tests/test_vecadd.py             | 79 +++++++++++++++++++
 python/tests/test_vecadd_no_scf.py      | 42 ----------
 5 files changed, 117 insertions(+), 61 deletions(-)
 create mode 100644 python/tests/test_vecadd.py
 delete mode 100644 python/tests/test_vecadd_no_scf.py

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 3501a4c09..b58f38094 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -1287,24 +1288,33 @@ struct AddPtrOpConversion
   matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
-    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
-    auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
-    assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
-    auto resultShape = resultTy.getShape();
-    unsigned elems = resultLayout.getElemsPerThread(resultShape);
-    Type elemTy =
-        this->getTypeConverter()->convertType(resultTy.getElementType());
-    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
-    auto offsets =
-        getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
-    SmallVector<Value> resultVals(elems);
-    for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
+    auto resultTy = op.getType();
+    auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
+    if (resultTensorTy) {
+      auto resultLayout =
+          resultTensorTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
+      assert(resultLayout && "Unexpected resultLayout in AddPtrOpConversion");
+      auto resultShape = resultTensorTy.getShape();
+      unsigned elems = resultLayout.getElemsPerThread(resultShape);
+      Type elemTy =
+          getTypeConverter()->convertType(resultTensorTy.getElementType());
+      SmallVector<Type> types(elems, elemTy);
+      Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
+      auto ptrs = getElementsFromStruct(loc, adaptor.ptr(), elems, rewriter);
+      auto offsets =
+          getElementsFromStruct(loc, adaptor.offset(), elems, rewriter);
+      SmallVector<Value> resultVals(elems);
+      for (unsigned i = 0; i < elems; ++i) {
+        resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
+      }
+      Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
+      rewriter.replaceOp(op, view);
+    } else {
+      assert(resultTy.isa<triton::PointerType>());
+      Type llResultTy = getTypeConverter()->convertType(resultTy);
+      Value result = gep(llResultTy, adaptor.ptr(), adaptor.offset());
+      rewriter.replaceOp(op, result);
     }
-    Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
-    rewriter.replaceOp(op, view);
     return success();
   }
 };
@@ -3066,6 +3076,7 @@ public:
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             patterns);
     mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
+    mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
     mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
@@ -3122,6 +3133,7 @@ TritonLLVMConversionTarget::TritonLLVMConversionTarget(
   // addIllegalDialect<triton::TritonDialect>();
   // addIllegalDialect<triton::gpu::TritonGPUDialect>();
   addIllegalDialect<mlir::gpu::GPUDialect>();
+  addIllegalDialect<mlir::StandardOpsDialect>();
   addLegalOp<mlir::UnrealizedConversionCastOp>();
 }
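Note on the AddPtrOpConversion change above: the tensor branch keeps the old
behavior (one GEP per element, packed into an LLVM struct), while the new
scalar branch handles tt.addptr on a bare pointer, which is exactly what a
loop-carried pointer increment produces once SCF loops reach the backend. A
minimal Triton kernel sketch of the pattern that exercises both branches;
the kernel and parameter names are illustrative, not taken from the patch:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(x_ptr, y_ptr, N, ITER: tl.constexpr):
        for i in range(0, N, ITER):
            # Tensor offsets: x_ptr + idx yields a tensor of pointers and
            # takes the struct-of-GEPs branch of AddPtrOpConversion.
            idx = tl.arange(0, ITER)
            tl.store(y_ptr + idx, tl.load(x_ptr + idx))
            # Scalar pointer bump: a loop-carried tt.addptr on a plain
            # pointer, lowered by the new else-branch to a single GEP.
            x_ptr += ITER
            y_ptr += ITER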
diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index 5837b0973..5967a0e04 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -1,4 +1,5 @@
 #include "triton/Target/LLVMIR/LLVMIRTranslation.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
 #include "mlir/ExecutionEngine/OptUtils.h"
@@ -135,6 +136,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
       /*printAfterOnlyOnChange=*/true,
       /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
 
+  pm.addPass(mlir::createLowerToCFGPass());
   pm.addPass(createConvertTritonGPUToLLVMPass());
   // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
   pm.addPass(mlir::createCanonicalizerPass());
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 424c2a28e..1521d4bd1 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -3,6 +3,7 @@
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
@@ -1185,8 +1186,12 @@ void init_triton_ir(py::module &&m) {
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUVerifier());
           })
-      .def("add_triton_gpu_to_llvm", [](mlir::PassManager &self) {
-        self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
+      .def("add_triton_gpu_to_llvm",
+           [](mlir::PassManager &self) {
+             self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass());
+           })
+      .def("add_scf_to_cfg", [](mlir::PassManager &self) {
+        self.addPass(mlir::createLowerToCFGPass());
       });
 }
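The new add_scf_to_cfg binding mirrors the C++ pipeline in
translateTritonGPUToLLVMIR: SCF is flattened into branch-based control flow
before the TritonGPU-to-LLVM conversion runs, so the LLVM-dialect conversion
only ever sees plain CFG branches. A hedged sketch of how a Python-side
driver might call it; only the two .add_* calls below come from this patch,
while the pass_manager constructor and run() entry point are assumed to
exist in the surrounding bindings and are not shown here:

    # Hypothetical driver snippet -- pass_manager construction and run()
    # are assumptions about the existing triton._C bindings.
    pm = _triton.ir.pass_manager(context)  # assumed constructor
    pm.add_scf_to_cfg()                    # scf.for / scf.if -> CFG branches
    pm.add_triton_gpu_to_llvm()            # TritonGPU + std -> LLVM dialect
    pm.run(module)                         # assumed run entry point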
diff --git a/python/tests/test_vecadd.py b/python/tests/test_vecadd.py
new file mode 100644
index 000000000..1c73979f5
--- /dev/null
+++ b/python/tests/test_vecadd.py
@@ -0,0 +1,79 @@
+import pytest
+import torch
+from torch.testing import assert_allclose
+
+import triton
+import triton.language as tl
+
+
+@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
+    [4, 256],
+    [2, 256],
+    [1, 256],
+])
+def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
+
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+        x = tl.load(x_ptrs)
+        y = tl.load(y_ptrs)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z)
+
+    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
+    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
+    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
+
+    golden_z = x + y
+    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
+    [4, 256, 1],
+    [4, 1024, 256],
+])
+def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
+
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               BLOCK_SIZE,
+               ITER_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        for i in range(0, BLOCK_SIZE, ITER_SIZE):
+            offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+            x = tl.load(x_ptrs)
+            y = tl.load(y_ptrs)
+            z = x + y
+            z_ptrs = z_ptr + offset
+            tl.store(z_ptrs, z)
+            x_ptr += ITER_SIZE
+            y_ptr += ITER_SIZE
+            z_ptr += ITER_SIZE
+
+    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
+    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
+    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
+                 BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
+
+    golden_z = x + y
+    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+
+# TODO: test_vecadd with mask
diff --git a/python/tests/test_vecadd_no_scf.py b/python/tests/test_vecadd_no_scf.py
deleted file mode 100644
index 161b11b1f..000000000
--- a/python/tests/test_vecadd_no_scf.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import torch
-from torch.testing import assert_allclose
-
-import triton
-import triton.language as tl
-
-
-def vecadd_no_scf_tester(num_warps, block_size):
-    @triton.jit
-    def kernel(x_ptr,
-               y_ptr,
-               z_ptr,
-               BLOCK_SIZE_N: tl.constexpr):
-        pid = tl.program_id(axis=0)
-        offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-        x_ptrs = x_ptr + offset
-        y_ptrs = y_ptr + offset
-        x = tl.load(x_ptrs)
-        y = tl.load(y_ptrs)
-        z = x + y
-        z_ptrs = z_ptr + offset
-        tl.store(z_ptrs, z)
-
-    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
-    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
-    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
-
-    grid = lambda EA: (x.shape.numel() // block_size,)
-    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE_N=block_size, num_warps=num_warps)
-
-    golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
-
-
-def test_vecadd_no_scf():
-    vecadd_no_scf_tester(num_warps=4, block_size=256)
-    vecadd_no_scf_tester(num_warps=2, block_size=256)
-    vecadd_no_scf_tester(num_warps=1, block_size=256)
-
-
-if __name__ == '__main__':
-    test_vecadd_no_scf()
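One possible shape for the masked variant left as a TODO in test_vecadd.py
above; this is a hypothetical sketch, not part of the patch, and assumes the
backend supports masked loads and stores. The mask guards the ragged tail
when the tensor length is not a multiple of BLOCK_SIZE:

    # Hypothetical masked vecadd test -- not in this patch.
    import torch
    from torch.testing import assert_allclose

    import triton
    import triton.language as tl

    def test_vecadd_mask_sketch(NUM_WARPS=4, BLOCK_SIZE=256, N=300):
        @triton.jit
        def kernel(x_ptr, y_ptr, z_ptr, n, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)
            offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            mask = offset < n  # guard out-of-bounds lanes in the last block
            x = tl.load(x_ptr + offset, mask=mask)
            y = tl.load(y_ptr + offset, mask=mask)
            tl.store(z_ptr + offset, x + y, mask=mask)

        x = torch.randn((N,), device='cuda', dtype=torch.float32)
        y = torch.randn((N,), device='cuda', dtype=torch.float32)
        z = torch.empty((N,), device=x.device, dtype=x.dtype)

        grid = lambda EA: (triton.cdiv(N, BLOCK_SIZE),)
        kernel[grid](x, y, z, N, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)

        assert_allclose(z, x + y, rtol=1e-7, atol=1e-7)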