[triton-mlir][BACKEND] Support masked load/store (#657)
This PR does the following:
- fixes some bugs to support masked load/store;
- refines the frontend and supports the `and` and `or` syntax in masks (by extending `BoolOp` handling in the Python `ast` visitor), e.g. `tl.store(..., mask=offset < n and other_conditions)`;
- adds `arith.cmpI` and `arith.cmpF` op conversions in the backend (required by masks);
- adds more test cases to vecadd.
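A minimal sketch of the newly supported mask syntax, mirroring the new vecadd tests below (the kernel and its names here are illustrative, not part of the diff):

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK: tl.constexpr):
    # hypothetical example kernel; only the mask syntax is the point
    offset = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    io_mask = offset < n_elements  # guard out-of-bounds lanes
    x = tl.load(x_ptr + offset, mask=io_mask)
    y = tl.load(y_ptr + offset, mask=io_mask)
    z = x + y
    # `and`/`or` in a mask expression now lower to logical_and/logical_or
    tl.store(z_ptr + offset, z, mask=io_mask and (z < 0. or z > 1.))

Before this change `visit_BoolOp` was not handled by the code generator, so compound masks had to be spelled with bitwise operators; the frontend now maps `ast.And`/`ast.Or` onto the new `logical_and`/`logical_or` tensor builtins.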
@@ -145,7 +145,7 @@ std::string PTXInstrExecution::dump() const {
     if (!pred->repr)
       os << "@" << pred->dump() << " ";
     else
-      os << pred->repr(pred->idx);
+      os << pred->repr(pred->idx) << " ";
   }

   std::string instrRepr = strJoin(instr->instrParts, ".");
@@ -114,16 +114,16 @@ namespace type = mlir::triton::type;

 class TritonGPUToLLVMTypeConverter;

-// TODO: keep these before we have better debug log utilities
+// TODO[goostavz]: Remove these methods after we have better debug log utilities
 template <typename T>
-void print_array(ArrayRef<T> array, const std::string &str) {
-  std::cout << str << ": ";
+void printArray(ArrayRef<T> array, const std::string &info) {
+  std::cout << info << ": ";
   for (const T &e : array)
     std::cout << e << ",";
   std::cout << std::endl;
 }
-template <typename T> void print_scalar(const T &e, const std::string &str) {
-  std::cout << str << ": " << e << std::endl;
+template <typename T> void printScalar(const T &e, const std::string &info) {
+  std::cout << info << ": " << e << std::endl;
 }

 // FuncOpConversion/FuncOpConversionBase is borrowed from
@@ -808,14 +808,22 @@ struct StoreOpConversion
     auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
     assert(ptrElems.size() == valueElems.size());

+    // Determine the vectorization size
+    size_t vec = getVectorizeSize(ptr, layout);
     SmallVector<Value> maskElems;
     if (llMask) {
       maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(valueElems.size() == maskElems.size());
-    }
+      auto maskOrder = mask.getType()
+                           .cast<RankedTensorType>()
+                           .getEncoding()
+                           .cast<BlockedEncodingAttr>()
+                           .getOrder();

-    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
+      auto maskAxis = getAxisInfo(mask);
+      size_t maskAlign = std::max<int>(maskAxis->getConstancy(maskOrder[0]), 1);
+      vec = std::min(vec, maskAlign);
+    }

     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
@@ -1376,13 +1384,15 @@ struct ExtractSliceOpConversion
   }
 };

-template <typename SourceOp, typename DestOp>
-class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
+// A CRTP style of base class.
+template <typename SourceOp, typename DestOp, typename ConcreteT>
+class BinaryOpConversionBase
+    : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;

-  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
+  explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
                               PatternBenefit benefit = 1)
       : ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

   LogicalResult
@@ -1403,13 +1413,16 @@ public:
         this->getTypeConverter()->convertType(resultTy.getElementType());
     SmallVector<Type> types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-    auto lhss =
-        this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
-    auto rhss =
-        this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
+    auto *concreteThis = static_cast<const ConcreteT *>(this);
+    auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
+                                            elems, rewriter);
+    auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
+                                            elems, rewriter);
     SmallVector<Value> resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = rewriter.create<DestOp>(loc, elemTy, lhss[i], rhss[i]);
+      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
+                                                 rhss[i], loc);
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
@@ -1417,6 +1430,123 @@ public:
   }
 };

+template <typename SourceOp, typename DestOp>
+struct BinaryOpConversion
+    : public BinaryOpConversionBase<SourceOp, DestOp,
+                                    BinaryOpConversion<SourceOp, DestOp>> {
+
+  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
+                              PatternBenefit benefit = 1)
+      : BinaryOpConversionBase<SourceOp, DestOp,
+                               BinaryOpConversion<SourceOp, DestOp>>(
+            typeConverter, benefit) {}
+
+  using OpAdaptor = typename SourceOp::Adaptor;
+  // An interface to support variant DestOp builder.
+  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
+                      Type elemTy, Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<DestOp>(loc, elemTy, lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
+};
+
+struct CmpIOpConversion
+    : public BinaryOpConversionBase<triton::gpu::CmpIOp, LLVM::ICmpOp,
+                                    CmpIOpConversion> {
+  explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builder.
+  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<LLVM::ICmpOp>(
+        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::ICmpPredicate
+  ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__)                                                    \
+  case arith::CmpIPredicate::item__:                                           \
+    return LLVM::ICmpPredicate::item__
+
+      __PRED_ENUM(eq);
+      __PRED_ENUM(ne);
+      __PRED_ENUM(sgt);
+      __PRED_ENUM(sge);
+      __PRED_ENUM(slt);
+      __PRED_ENUM(sle);
+      __PRED_ENUM(ugt);
+      __PRED_ENUM(uge);
+      __PRED_ENUM(ult);
+      __PRED_ENUM(ule);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::ICmpPredicate::eq;
+  }
+};
+
+struct CmpFOpConversion
+    : public BinaryOpConversionBase<triton::gpu::CmpFOp, LLVM::FCmpOp,
+                                    CmpFOpConversion> {
+  explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builder.
+  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create<LLVM::FCmpOp>(
+        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::FCmpPredicate
+  ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__, item1__)                                           \
+  case arith::CmpFPredicate::item__:                                           \
+    return LLVM::FCmpPredicate::item1__
+
+      __PRED_ENUM(OEQ, oeq);
+      __PRED_ENUM(ONE, one);
+      __PRED_ENUM(OGT, ogt);
+      __PRED_ENUM(OGE, oge);
+      __PRED_ENUM(OLT, olt);
+      __PRED_ENUM(OLE, ole);
+      __PRED_ENUM(ORD, ord);
+      __PRED_ENUM(UEQ, ueq);
+      __PRED_ENUM(UGT, ugt);
+      __PRED_ENUM(ULT, ult);
+      __PRED_ENUM(ULE, ule);
+      __PRED_ENUM(UNE, une);
+      __PRED_ENUM(UNO, uno);
+      __PRED_ENUM(AlwaysTrue, _true);
+      __PRED_ENUM(AlwaysFalse, _false);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::FCmpPredicate::_true;
+  }
+};
+
 struct ConvertLayoutOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
 public:
@@ -3011,6 +3141,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                                                 benefit);
   patterns.add<BinaryOpConversion<arith::MulFOp, LLVM::FMulOp>>(typeConverter,
                                                                 benefit);
+
+  patterns.add<BinaryOpConversion<arith::AndIOp, LLVM::AndOp>>(typeConverter,
+                                                               benefit);
+  patterns.add<BinaryOpConversion<arith::OrIOp, LLVM::OrOp>>(typeConverter,
+                                                             benefit);
+
+  patterns.add<CmpIOpConversion>(typeConverter, benefit);
+  patterns.add<CmpFOpConversion>(typeConverter, benefit);
   patterns.add<BroadcastOpConversion>(typeConverter, benefit);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
                                           benefit);
@@ -1210,6 +1210,8 @@ void init_triton_translation(py::module &m) {
          llvm::LLVMContext llvmContext;
          auto llvmModule =
              ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+         if (!llvmModule)
+           llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

          std::string str;
          llvm::raw_string_ostream os(str);
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -49,4 +49,4 @@ def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
                  num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
     torch.set_printoptions(profile="full")
-    assert_allclose(c, golden, rtol=1e-3, atol=1e-3)
+    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -44,4 +44,4 @@ def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
     z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
     kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
     golden_z = torch.t(x)
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7, check_dtype=False)
@@ -1,79 +1,215 @@
+import math
+import random
+
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl


-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
-    [4, 256],
-    [2, 256],
-    [1, 256],
-])
-def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
-
-    @triton.jit
-    def kernel(x_ptr,
-               y_ptr,
-               z_ptr,
-               BLOCK_SIZE: tl.constexpr):
-        pid = tl.program_id(axis=0)
-        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        x_ptrs = x_ptr + offset
-        y_ptrs = y_ptr + offset
-        x = tl.load(x_ptrs)
-        y = tl.load(y_ptrs)
-        z = x + y
-        z_ptrs = z_ptr + offset
-        tl.store(z_ptrs, z)
-
-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
-
-    grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
-    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
-
-    golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
-
-
-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
     [4, 256, 1],
     [4, 1024, 256],
 ])
-def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
+def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):

     @triton.jit
     def kernel(x_ptr,
                y_ptr,
                z_ptr,
-               BLOCK_SIZE,
-               ITER_SIZE: tl.constexpr):
+               block_size,
+               iter_size: tl.constexpr):
         pid = tl.program_id(axis=0)
-        for i in range(0, BLOCK_SIZE, ITER_SIZE):
-            offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
             x_ptrs = x_ptr + offset
             y_ptrs = y_ptr + offset

             x = tl.load(x_ptrs)
             y = tl.load(y_ptrs)
             z = x + y
             z_ptrs = z_ptr + offset
             tl.store(z_ptrs, z)
-            x_ptr += ITER_SIZE
-            y_ptr += ITER_SIZE
-            z_ptr += ITER_SIZE

-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size

-    grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
     kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
-                 BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

     golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)

-# TODO: test_vecadd with mask
+
+@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
+    [(127, 3), 2, 128, 1],
+    [(127, 3), 2, 128, 32],
+])
+def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               num_elements,
+               block_size: tl.constexpr,
+               iter_size: tl.constexpr
+               ):
+        '''
+        @block_size: size of a block
+        @iter_size: size of the iteration, a block has multiple iterations
+        @num_elements: number of elements
+        '''
+        pid = tl.program_id(axis=0)
+        for i in range(math.ceil(block_size / iter_size)):
+            # TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+
+            x = tl.load(x_ptrs, mask=offset < num_elements)
+            y = tl.load(y_ptrs, mask=offset < num_elements)
+            z = x + y
+            z_ptrs = z_ptr + offset
+            tl.store(z_ptrs, z, mask=offset < num_elements)
+
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
+                 num_elements=x.numel())
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_no_scf_tester(num_warps, block_size, shape):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        mask = offset < n_elements
+
+        x = tl.load(x_ptrs, mask=mask)
+        y = tl.load(y_ptrs, mask=mask)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
+    '''
+    vecadd tester with float comparison as load/store mask.
+    '''
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        io_mask = offset < n_elements
+        x = tl.load(x_ptrs, mask=io_mask)
+        y = tl.load(y_ptrs, mask=io_mask)
+
+        z = x + y
+        val_mask = offset < n_elements and (z < 0. or z > 1.)
+
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=val_mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.zeros(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z: torch.Tensor = x + y
+    gz_data = torch.flatten(golden_z)
+    for i in range(golden_z.numel()):
+        gz_data[i] = gz_data[i] if gz_data[i] < 0. or gz_data[i] > 1. else 0.
+
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [4, 256, (256,)],
+    [2, 256, (256,)],
+    [1, 256, (256,)],
+    [4, 16, (256,)],
+    [2, 64, (256,)],
+    [1, 128, (256,)],
+])
+def test_vecadd_no_scf(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd__no_scf_masked(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+def test_vecadd_no_scf_masked_randomly():
+    random.seed(0)  # fix seed to make random test reproducible
+    for i in range(10):
+        num_elements = random.randint(128, 2048)
+        shape = (num_elements,)
+        max_warps = num_elements // 32  # floor div
+        for num_warps in range(1, max_warps):
+            is_power2 = num_warps & (num_warps - 1) == 0 and num_warps != 0
+            if not is_power2: continue
+            block_size = min(32, num_warps * 32)
+            vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
+    vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)
@@ -699,6 +699,28 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_Constant(self, node):
         return triton.language.constexpr(node.value)

+    def visit_BoolOp(self, node: ast.BoolOp):
+        assert len(node.values) == 2
+        lhs = self.visit(node.values[0])
+        rhs = self.visit(node.values[1])
+        if isinstance(lhs, triton.language.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.constexpr):
+            rhs = rhs.value
+
+        fn = {
+            ast.And: 'logical_and',
+            ast.Or: 'logical_or',
+        }[type(node.op)]
+
+        if self.is_triton_tensor(lhs):
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_tensor(rhs):
+            fn = fn[:2] + 'r' + fn[2:]
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)
+
     if sys.version_info < (3, 8):
         def visit_NameConstant(self, node):
             return triton.language.constexpr(node.value)
@@ -361,8 +361,6 @@ class constexpr:
     def __rfloordiv__(self, other):
         return other.value // self.value

-    #
-
     def __gt__(self, other):
         return self.value > other.value

@@ -557,6 +555,16 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.not_equal(self, other, _builder)

+    @builtin
+    def logical_and(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_and(self, other, _builder)
+
+    @builtin
+    def logical_or(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_or(self, other, _builder)
+
     @builtin
     def __getitem__(self, slices, _builder=None):
         if isinstance(slices, slice):
@@ -285,6 +285,22 @@ def xor_(input: tl.tensor,
     return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)


+def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return and_(input, other, builder)
+
+
+def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return or_(input, other, builder)
+
+
 def lshr(input: tl.tensor,
          other: tl.tensor,
          builder: ir.builder) -> tl.tensor: