[triton-mlir][BACKEND] Support masked load/store (#657)

This PR does

- fix some bugs to support masked load/store,
- refine the frontend and support the `and` and `or` syntax in masks (by extending `BoolOp` handling in the Python AST visitor), e.g. `tl.store(..., mask=offset < n and other_conditions)` (see the sketch below),
- add `arith.cmpI` and `arith.cmpF` op conversions in the backend (required by masks),
- add more test cases in vecadd.
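
For example, the new syntax admits kernels like the following (a minimal sketch adapted from the vecadd tests added in this PR; all names are illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offset, mask=offset < n_elements)
    y = tl.load(y_ptr + offset, mask=offset < n_elements)
    z = x + y
    # `and`/`or` in a mask now lower through tensor.logical_and/logical_or;
    # the comparisons go through the new cmpI/cmpF conversions.
    tl.store(z_ptr + offset, z, mask=offset < n_elements and (z < 0. or z > 1.))

x = torch.randn(300, device='cuda')
y = torch.randn(300, device='cuda')
z = torch.zeros_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, z, x.numel(), BLOCK_SIZE=256)
```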
Author: Yan Chunwei
Date: 2022-10-10 13:29:53 +08:00
Committed by: GitHub
Parent: ccc5ab6ac9
Commit: 555f94f9b9
9 changed files with 396 additions and 74 deletions


@@ -145,7 +145,7 @@ std::string PTXInstrExecution::dump() const {
     if (!pred->repr)
       os << "@" << pred->dump() << " ";
     else
-      os << pred->repr(pred->idx);
+      os << pred->repr(pred->idx) << " ";
   }
   std::string instrRepr = strJoin(instr->instrParts, ".");


@@ -114,16 +114,16 @@ namespace type = mlir::triton::type;

 class TritonGPUToLLVMTypeConverter;

-// TODO: keep these before we have better debug log utilities
+// TODO[goostavz]: Remove these methods after we have better debug log utilities
 template <typename T>
-void print_array(ArrayRef<T> array, const std::string &str) {
-  std::cout << str << ": ";
+void printArray(ArrayRef<T> array, const std::string &info) {
+  std::cout << info << ": ";
   for (const T &e : array)
     std::cout << e << ",";
   std::cout << std::endl;
 }

-template <typename T> void print_scalar(const T &e, const std::string &str) {
-  std::cout << str << ": " << e << std::endl;
+template <typename T> void printScalar(const T &e, const std::string &info) {
+  std::cout << info << ": " << e << std::endl;
 }

 // FuncOpConversion/FuncOpConversionBase is borrowed from
@@ -808,14 +808,22 @@ struct StoreOpConversion
     auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
     assert(ptrElems.size() == valueElems.size());

+    // Determine the vectorization size
+    size_t vec = getVectorizeSize(ptr, layout);
     SmallVector<Value> maskElems;
     if (llMask) {
       maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(valueElems.size() == maskElems.size());
-    }
+      auto maskOrder = mask.getType()
+                           .cast<RankedTensorType>()
+                           .getEncoding()
+                           .cast<BlockedEncodingAttr>()
+                           .getOrder();

-    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
+      auto maskAxis = getAxisInfo(mask);
+      size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
+      vec = std::min(vec, maskAlign);
+    }

     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
@@ -1376,13 +1384,15 @@ struct ExtractSliceOpConversion
   }
 };

-template <typename SourceOp, typename DestOp>
-class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
+// A CRTP style of base class.
+template <typename SourceOp, typename DestOp, typename ConcreteT>
+class BinaryOpConversionBase
+    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;

-  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
-                              PatternBenefit benefit = 1)
+  explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
+                                  PatternBenefit benefit = 1)
       : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

   LogicalResult
@@ -1403,13 +1413,16 @@ public:
         this->getTypeConverter()->convertType(resultTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);

-    auto lhss =
-        this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
-    auto rhss =
-        this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
+    auto *concreteThis = static_cast<const ConcreteT *>(this);
+    auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
+                                            elems, rewriter);
+    auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
+                                            elems, rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
+      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
+                                                 rhss[i], loc);
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
@@ -1417,6 +1430,123 @@ public:
   }
 };

+template <typename SourceOp, typename DestOp>
+struct BinaryOpConversion
+    : public BinaryOpConversionBase<SourceOp, DestOp,
+                                    BinaryOpConversion<SourceOp, DestOp>> {
+
+  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
+                              PatternBenefit benefit = 1)
+      : BinaryOpConversionBase<SourceOp, DestOp,
+                               BinaryOpConversion<SourceOp, DestOp>>(
+            typeConverter, benefit) {}
+
+  using OpAdaptor = typename SourceOp::Adaptor;
+  // An interface to support variant DestOp builders.
+  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
+                      Type elemTy, Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
+};
+
+struct CmpIOpConversion
+    : public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
+                                    CmpIOpConversion> {
+  explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builders.
+  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<LLVM::ICmpOp>(
+        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::ICmpPredicate
+  ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__)                                                    \
+  case arith::CmpIPredicate::item__:                                           \
+    return LLVM::ICmpPredicate::item__
+
+      __PRED_ENUM(eq);
+      __PRED_ENUM(ne);
+      __PRED_ENUM(sgt);
+      __PRED_ENUM(sge);
+      __PRED_ENUM(slt);
+      __PRED_ENUM(sle);
+      __PRED_ENUM(ugt);
+      __PRED_ENUM(uge);
+      __PRED_ENUM(ult);
+      __PRED_ENUM(ule);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::ICmpPredicate::eq;
+  }
+};
+
+struct CmpFOpConversion
+    : public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
+                                    CmpFOpConversion> {
+  explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builders.
+  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<LLVM::FCmpOp>(
+        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::FCmpPredicate
+  ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__, item1__)                                           \
+  case arith::CmpFPredicate::item__:                                           \
+    return LLVM::FCmpPredicate::item1__
+
+      __PRED_ENUM(OEQ, oeq);
+      __PRED_ENUM(ONE, one);
+      __PRED_ENUM(OGT, ogt);
+      __PRED_ENUM(OGE, oge);
+      __PRED_ENUM(OLT, olt);
+      __PRED_ENUM(OLE, ole);
+      __PRED_ENUM(ORD, ord);
+      __PRED_ENUM(UEQ, ueq);
+      __PRED_ENUM(UGT, ugt);
+      __PRED_ENUM(ULT, ult);
+      __PRED_ENUM(ULE, ule);
+      __PRED_ENUM(UNE, une);
+      __PRED_ENUM(UNO, uno);
+      __PRED_ENUM(AlwaysTrue, _true);
+      __PRED_ENUM(AlwaysFalse, _false);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::FCmpPredicate::_true;
+  }
+};
+
 struct ConvertLayoutOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
 public:
@@ -3011,6 +3141,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                                                 benefit);
   patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
                                                                 benefit);
+  patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
+                                                               benefit);
+  patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
+                                                             benefit);
+
+  patterns.add<CmpIOpConversion>(typeConverter, benefit);
+  patterns.add<CmpFOpConversion>(typeConverter, benefit);
+
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);


@@ -1210,6 +1210,8 @@ void init_triton_translation(py::module &m) {
         llvm::LLVMContext llvmContext;
         auto llvmModule =
             ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+        if (!llvmModule)
+          llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

         std::string str;
         llvm::raw_string_ostream os(str);


@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -49,4 +49,4 @@ def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
                  num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
     torch.set_printoptions(profile="full")
-    assert_allclose(c, golden, rtol=1e-3, atol=1e-3)
+    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)


@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -44,4 +44,4 @@ def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
     z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
     kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
     golden_z = torch.t(x)
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7, check_dtype=False)


@@ -1,79 +1,215 @@
+import math
+import random
+
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl


-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
-    [4, 256],
-    [2, 256],
-    [1, 256],
-])
-def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
-    @triton.jit
-    def kernel(x_ptr,
-               y_ptr,
-               z_ptr,
-               BLOCK_SIZE: tl.constexpr):
-        pid = tl.program_id(axis=0)
-        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        x_ptrs = x_ptr + offset
-        y_ptrs = y_ptr + offset
-        x = tl.load(x_ptrs)
-        y = tl.load(y_ptrs)
-        z = x + y
-        z_ptrs = z_ptr + offset
-        tl.store(z_ptrs, z)
-
-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
-
-    grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
-    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
-
-    golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
-
-
-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
     [4, 256, 1],
     [4, 1024, 256],
 ])
-def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
+def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):
     @triton.jit
     def kernel(x_ptr,
                y_ptr,
                z_ptr,
-               BLOCK_SIZE,
-               ITER_SIZE: tl.constexpr):
+               block_size,
+               iter_size: tl.constexpr):
         pid = tl.program_id(axis=0)
-        for i in range(0, BLOCK_SIZE, ITER_SIZE):
-            offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
             x_ptrs = x_ptr + offset
             y_ptrs = y_ptr + offset
             x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
             z = x + y
             z_ptrs = z_ptr + offset
             tl.store(z_ptrs, z)
-            x_ptr += ITER_SIZE
-            y_ptr += ITER_SIZE
-            z_ptr += ITER_SIZE
-
-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size

-    grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
     kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
-                 BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

     golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)

-# TODO: test_vecadd with mask
+
+@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
+    [(127, 3), 2, 128, 1],
+    [(127, 3), 2, 128, 32],
+])
+def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               num_elements,
+               block_size: tl.constexpr,
+               iter_size: tl.constexpr
+               ):
+        '''
+        @block_size: size of a block
+        @iter_size: size of one iteration; a block runs multiple iterations
+        @num_elements: number of elements
+        '''
+        pid = tl.program_id(axis=0)
+        for i in range(math.ceil(block_size / iter_size)):
+            # TODO: a bug here; if the offset is hoisted outside the for loop, there is a GPU mis-aligned error.
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+            x = tl.load(x_ptrs, mask=offset < num_elements)
+            y = tl.load(y_ptrs, mask=offset < num_elements)
+            z = x + y
+            z_ptrs = z_ptr + offset
+            tl.store(z_ptrs, z, mask=offset < num_elements)
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
+                 num_elements=x.numel())
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_no_scf_tester(num_warps, block_size, shape):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        mask = offset < n_elements
+        x = tl.load(x_ptrs, mask=mask)
+        y = tl.load(y_ptrs, mask=mask)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
+    '''
+    vecadd tester with a float comparison as the load/store mask.
+    '''
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        io_mask = offset < n_elements
+        x = tl.load(x_ptrs, mask=io_mask)
+        y = tl.load(y_ptrs, mask=io_mask)
+        z = x + y
+        val_mask = offset < n_elements and (z < 0. or z > 1.)
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=val_mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.zeros(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z: torch.Tensor = x + y
+    gz_data = torch.flatten(golden_z)
+    for i in range(golden_z.numel()):
+        gz_data[i] = gz_data[i] if gz_data[i] < 0. or gz_data[i] > 1. else 0.
+
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [4, 256, (256,)],
+    [2, 256, (256,)],
+    [1, 256, (256,)],
+    [4, 16, (256,)],
+    [2, 64, (256,)],
+    [1, 128, (256,)],
+])
+def test_vecadd_no_scf(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd__no_scf_masked(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+def test_vecadd_no_scf_masked_randomly():
+    random.seed(0)  # fix the seed to make the random test reproducible
+    for i in range(10):
+        num_elements = random.randint(128, 2048)
+        shape = (num_elements,)
+        max_warps = num_elements // 32  # floor div
+        for num_warps in range(1, max_warps):
+            is_power2 = num_warps & (num_warps - 1) == 0 and num_warps != 0
+            if not is_power2:
+                continue
+            block_size = min(32, num_warps * 32)
+            vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
+    vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)


@@ -699,6 +699,28 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_Constant(self, node):
         return triton.language.constexpr(node.value)

+    def visit_BoolOp(self, node: ast.BoolOp):
+        assert len(node.values) == 2
+        lhs = self.visit(node.values[0])
+        rhs = self.visit(node.values[1])
+        if isinstance(lhs, triton.language.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.constexpr):
+            rhs = rhs.value
+
+        fn = {
+            ast.And: 'logical_and',
+            ast.Or: 'logical_or',
+        }[type(node.op)]
+
+        if self.is_triton_tensor(lhs):
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_tensor(rhs):
+            fn = fn[:2] + 'r' + fn[2:]
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)
+
     if sys.version_info < (3, 8):
         def visit_NameConstant(self, node):
             return triton.language.constexpr(node.value)
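
For reference, the AST shape this visitor handles; a quick plain-Python check, separate from the diff and purely illustrative:

```python
import ast

# visit_BoolOp above accepts exactly two operands and maps
# ast.And -> tensor.logical_and, ast.Or -> tensor.logical_or.
tree = ast.parse("offset < n and other_cond", mode="eval")
op = tree.body
assert isinstance(op, ast.BoolOp) and isinstance(op.op, ast.And)
# A chained `a and b and c` parses to a single BoolOp with three
# values and would trip the visitor's `len(node.values) == 2` assert.
assert len(op.values) == 2
```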


@@ -361,8 +361,6 @@ class constexpr:
     def __rfloordiv__(self, other):
         return other.value // self.value

-    #
-
     def __gt__(self, other):
         return self.value > other.value
@@ -557,6 +555,16 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.not_equal(self, other, _builder)

+    @builtin
+    def logical_and(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_and(self, other, _builder)
+
+    @builtin
+    def logical_or(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_or(self, other, _builder)
+
     @builtin
     def __getitem__(self, slices, _builder=None):
         if isinstance(slices, slice):


@@ -285,6 +285,22 @@ def xor_(input: tl.tensor,
     return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)


+def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return and_(input, other, builder)
+
+
+def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return or_(input, other, builder)
+
+
 def lshr(input: tl.tensor,
          other: tl.tensor,
          builder: ir.builder) -> tl.tensor:
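
Both helpers normalize non-`int1` operands with a bitcast and then reuse the existing bitwise `and_`/`or_` lowerings, so comparison results (already `int1`) flow through unchanged. A plain-Python model of the dispatch (illustrative stand-ins only, not the real `tl.tensor`/builder API):

```python
def logical_and_model(lhs: bool, rhs: bool) -> bool:
    # Models semantic.logical_and for operands already in the int1 domain:
    # no bitcast is needed, and the op reduces to the bitwise and_.
    return lhs & rhs

def logical_or_model(lhs: bool, rhs: bool) -> bool:
    return lhs | rhs

# e.g. the pieces of `offset < n and (z < 0. or z > 1.)` are all int1:
assert logical_and_model(True, logical_or_model(False, True)) is True
```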