From 555f94f9b9073b26539db6a71d9c0fbcd773aaf3 Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Mon, 10 Oct 2022 13:29:53 +0800
Subject: [PATCH] [triton-mlir][BACKEND] Support masked load/store (#657)

This PR does
- fix some bugs to support masked load/store,
- refine frontend, and support the `and` and `or` syntax in mask (by
  extending the BoolOp in python ast.visitor), e.g. `tl.store(..., mask=offset

repr)
       os << "@" << pred->dump() << " ";
     else
-      os << pred->repr(pred->idx);
+      os << pred->repr(pred->idx) << " ";
   }

   std::string instrRepr = strJoin(instr->instrParts, ".");

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index b58f38094..ac7111bf2 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -114,16 +114,16 @@ namespace type = mlir::triton::type;

 class TritonGPUToLLVMTypeConverter;

-// TODO: keep these before we have better debug log utilities
+// TODO[goostavz]: Remove these methods after we have better debug log utilities
 template <typename T>
-void print_array(ArrayRef<T> array, const std::string &str) {
-  std::cout << str << ": ";
+void printArray(ArrayRef<T> array, const std::string &info) {
+  std::cout << info << ": ";
   for (const T &e : array)
     std::cout << e << ",";
   std::cout << std::endl;
 }
-template <typename T> void print_scalar(const T &e, const std::string &str) {
-  std::cout << str << ": " << e << std::endl;
+template <typename T> void printScalar(const T &e, const std::string &info) {
+  std::cout << info << ": " << e << std::endl;
 }

 // FuncOpConversion/FuncOpConversionBase is borrowed from
@@ -808,14 +808,22 @@ struct StoreOpConversion
     auto valueElems = getLLVMElems(value, llValue, layout, rewriter, loc);
     assert(ptrElems.size() == valueElems.size());

+    // Determine the vectorization size
+    size_t vec = getVectorizeSize(ptr, layout);
     SmallVector maskElems;
     if (llMask) {
       maskElems = getLLVMElems(mask, llMask, layout, rewriter, loc);
       assert(valueElems.size() == maskElems.size());
-    }
+      auto maskOrder = mask.getType()
+                           .cast()
+                           .getEncoding()
+                           .cast()
+                           .getOrder();
+
-    // Determine the vectorization size
-    size_t vec = getVectorizeSize(ptr, layout);
+      auto maskAxis = getAxisInfo(mask);
+      size_t maskAlign = std::max(maskAxis->getConstancy(maskOrder[0]), 1);
+      vec = std::min(vec, maskAlign);
+    }

     const size_t dtsize =
         std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8);
@@ -1376,13 +1384,15 @@ struct ExtractSliceOpConversion
   }
 };

-template
-class BinaryOpConversion : public ConvertTritonGPUOpToLLVMPattern {
+// A CRTP style of base class.
+template
+class BinaryOpConversionBase
+    : public ConvertTritonGPUOpToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;

-  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
-                              PatternBenefit benefit = 1)
+  explicit BinaryOpConversionBase(LLVMTypeConverter &typeConverter,
+                                  PatternBenefit benefit = 1)
       : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {}

   LogicalResult
@@ -1403,13 +1413,16 @@ public:
         this->getTypeConverter()->convertType(resultTy.getElementType());
     SmallVector types(elems, elemTy);
     Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-    auto lhss =
-        this->getElementsFromStruct(loc, adaptor.getLhs(), elems, rewriter);
-    auto rhss =
-        this->getElementsFromStruct(loc, adaptor.getRhs(), elems, rewriter);
+
+    auto *concreteThis = static_cast(this);
+    auto lhss = this->getElementsFromStruct(loc, concreteThis->getLhs(adaptor),
+                                            elems, rewriter);
+    auto rhss = this->getElementsFromStruct(loc, concreteThis->getRhs(adaptor),
+                                            elems, rewriter);
     SmallVector resultVals(elems);
     for (unsigned i = 0; i < elems; ++i) {
-      resultVals[i] = rewriter.create(loc, elemTy, lhss[i], rhss[i]);
+      resultVals[i] = concreteThis->createDestOp(op, rewriter, elemTy, lhss[i],
+                                                 rhss[i], loc);
     }
     Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
     rewriter.replaceOp(op, view);
@@ -1417,6 +1430,123 @@ public:
   }
 };

+template
+struct BinaryOpConversion
+    : public BinaryOpConversionBase> {
+
+  explicit BinaryOpConversion(LLVMTypeConverter &typeConverter,
+                              PatternBenefit benefit = 1)
+      : BinaryOpConversionBase>(
+            typeConverter, benefit) {}
+
+  using OpAdaptor = typename SourceOp::Adaptor;
+  // An interface to support variant DestOp builder.
+  DestOp createDestOp(SourceOp op, ConversionPatternRewriter &rewriter,
+                      Type elemTy, Value lhs, Value rhs, Location loc) const {
+    return rewriter.create(loc, elemTy, lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.getLhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.getRhs(); }
+};
+
+struct CmpIOpConversion
+    : public BinaryOpConversionBase {
+  explicit CmpIOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builder.
+  LLVM::ICmpOp createDestOp(triton::gpu::CmpIOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create(
+        loc, elemTy, ArithCmpIPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::ICmpPredicate
+  ArithCmpIPredicteToLLVM(arith::CmpIPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__)                                                    \
+  case arith::CmpIPredicate::item__:                                           \
+    return LLVM::ICmpPredicate::item__
+
+      __PRED_ENUM(eq);
+      __PRED_ENUM(ne);
+      __PRED_ENUM(sgt);
+      __PRED_ENUM(sge);
+      __PRED_ENUM(slt);
+      __PRED_ENUM(sle);
+      __PRED_ENUM(ugt);
+      __PRED_ENUM(uge);
+      __PRED_ENUM(ult);
+      __PRED_ENUM(ule);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::ICmpPredicate::eq;
+  }
+};
+
+struct CmpFOpConversion
+    : public BinaryOpConversionBase {
+  explicit CmpFOpConversion(LLVMTypeConverter &typeConverter,
+                            PatternBenefit benefit = 1)
+      : BinaryOpConversionBase(typeConverter, benefit) {}
+
+  // An interface to support variant DestOp builder.
+  LLVM::FCmpOp createDestOp(triton::gpu::CmpFOp op,
+                            ConversionPatternRewriter &rewriter, Type elemTy,
+                            Value lhs, Value rhs, Location loc) const {
+    return rewriter.create(
+        loc, elemTy, ArithCmpFPredicteToLLVM(op.predicate()), lhs, rhs);
+  }
+
+  // Get the left operand of the op.
+  Value getLhs(OpAdaptor adaptor) const { return adaptor.lhs(); }
+  // Get the right operand of the op.
+  Value getRhs(OpAdaptor adaptor) const { return adaptor.rhs(); }
+
+  static LLVM::FCmpPredicate
+  ArithCmpFPredicteToLLVM(arith::CmpFPredicate predicate) {
+    switch (predicate) {
+#define __PRED_ENUM(item__, item1__)                                           \
+  case arith::CmpFPredicate::item__:                                           \
+    return LLVM::FCmpPredicate::item1__
+
+      __PRED_ENUM(OEQ, oeq);
+      __PRED_ENUM(ONE, one);
+      __PRED_ENUM(OGT, ogt);
+      __PRED_ENUM(OGE, oge);
+      __PRED_ENUM(OLT, olt);
+      __PRED_ENUM(OLE, ole);
+      __PRED_ENUM(ORD, ord);
+      __PRED_ENUM(UEQ, ueq);
+      __PRED_ENUM(UGT, ugt);
+      __PRED_ENUM(ULT, ult);
+      __PRED_ENUM(ULE, ule);
+      __PRED_ENUM(UNE, une);
+      __PRED_ENUM(UNO, uno);
+      __PRED_ENUM(AlwaysTrue, _true);
+      __PRED_ENUM(AlwaysFalse, _false);
+
+#undef __PRED_ENUM
+    }
+    return LLVM::FCmpPredicate::_true;
+  }
+};
+
 struct ConvertLayoutOpConversion
     : public ConvertTritonGPUOpToLLVMPattern {
 public:
@@ -3011,6 +3141,14 @@ void populateTritonToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
                                            benefit);
   patterns.add>(typeConverter, benefit);
+
+  patterns.add>(typeConverter,
+                                                            benefit);
+  patterns.add>(typeConverter,
+                                                          benefit);
+
+  patterns.add(typeConverter, benefit);
+  patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, benefit);
   patterns.add(typeConverter, allocation, smem, benefit);
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 1521d4bd1..6512546c8 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1210,6 +1210,8 @@ void init_triton_translation(py::module &m) {
         llvm::LLVMContext llvmContext;
         auto llvmModule =
             ::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
+        if (!llvmModule)
+          llvm::report_fatal_error("Failed to translate TritonGPU to LLVM IR.");

         std::string str;
         llvm::raw_string_ostream os(str);
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 2d3e88170..3e8c1173d 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -49,4 +49,4 @@ def test_gemm_impl(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS):
                  num_warps=NUM_WARPS)
     golden = torch.matmul(a, b)
     torch.set_printoptions(profile="full")
-    assert_allclose(c, golden, rtol=1e-3, atol=1e-3)
+    assert_close(c, golden, rtol=1e-3, atol=1e-3, check_dtype=False)
diff --git a/python/tests/test_transpose.py b/python/tests/test_transpose.py
index 8875b7feb..b7a1a09d5 100644
--- a/python/tests/test_transpose.py
+++ b/python/tests/test_transpose.py
@@ -1,6 +1,6 @@
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl
@@ -44,4 +44,4 @@ def test_convert_layout_impl(NUM_WARPS, SIZE_M, SIZE_N):
     z = torch.empty((SIZE_N, SIZE_M), device=x.device, dtype=x.dtype)
     kernel[grid](x_ptr=x, stride_xm=x.stride(0), z_ptr=z, stride_zn=z.stride(0), SIZE_M=SIZE_M, SIZE_N=SIZE_N, num_warps=NUM_WARPS)
     golden_z = torch.t(x)
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7, check_dtype=False)
diff --git a/python/tests/test_vecadd.py b/python/tests/test_vecadd.py
index 1c73979f5..187dc115f 100644
--- a/python/tests/test_vecadd.py
+++ b/python/tests/test_vecadd.py
@@ -1,79 +1,215 @@
+import math
+import random
+
 import pytest
 import torch
-from torch.testing import assert_allclose
+from torch.testing import assert_close

 import triton
 import triton.language as tl


-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE', [
-    [4, 256],
-    [2, 256],
-    [1, 256],
-])
-def test_vecadd_no_mask(NUM_WARPS, BLOCK_SIZE):
-
-    @triton.jit
-    def kernel(x_ptr,
-               y_ptr,
-               z_ptr,
-               BLOCK_SIZE: tl.constexpr):
-        pid = tl.program_id(axis=0)
-        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        x_ptrs = x_ptr + offset
-        y_ptrs = y_ptr + offset
-        x = tl.load(x_ptrs)
-        y = tl.load(y_ptrs)
-        z = x + y
-        z_ptrs = z_ptr + offset
-        tl.store(z_ptrs, z)
-
-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
-
-    grid = lambda EA: (x.shape.numel() // BLOCK_SIZE,)
-    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, BLOCK_SIZE=BLOCK_SIZE, num_warps=NUM_WARPS)
-
-    golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
-
-
-@pytest.mark.parametrize('NUM_WARPS, BLOCK_SIZE, ITER_SIZE', [
+@pytest.mark.parametrize('num_warps, block_size, iter_size', [
     [4, 256, 1],
     [4, 1024, 256],
 ])
-def test_vecadd_scf_no_mask(NUM_WARPS, BLOCK_SIZE, ITER_SIZE):
+def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):

     @triton.jit
     def kernel(x_ptr,
                y_ptr,
                z_ptr,
-               BLOCK_SIZE,
-               ITER_SIZE: tl.constexpr):
+               block_size,
+               iter_size: tl.constexpr):
         pid = tl.program_id(axis=0)
-        for i in range(0, BLOCK_SIZE, ITER_SIZE):
-            offset = pid * BLOCK_SIZE + tl.arange(0, ITER_SIZE)
+        for i in range(0, block_size, iter_size):
+            offset = pid * block_size + tl.arange(0, iter_size)
             x_ptrs = x_ptr + offset
             y_ptrs = y_ptr + offset
+
             x = tl.load(x_ptrs)
             y = tl.load(y_ptrs)
             z = x + y
             z_ptrs = z_ptr + offset
             tl.store(z_ptrs, z)
-            x_ptr += ITER_SIZE
-            y_ptr += ITER_SIZE
-            z_ptr += ITER_SIZE
-    x = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    y = torch.randn((BLOCK_SIZE,), device='cuda', dtype=torch.float32)
-    z = torch.empty((BLOCK_SIZE,), device=x.device, dtype=x.dtype)
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size

-    grid = lambda EA: (x.shape.numel() // (BLOCK_SIZE),)
+    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
+    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (x.shape.numel() // (block_size),)
     kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
-                 BLOCK_SIZE=x.shape[0], ITER_SIZE=ITER_SIZE, num_warps=NUM_WARPS)
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

     golden_z = x + y
-    assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)

-# TODO: test_vecadd with mask
+
+
+@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
+    [(127, 3), 2, 128, 1],
+    [(127, 3), 2, 128, 32],
+])
+def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               num_elements,
+               block_size: tl.constexpr,
+               iter_size: tl.constexpr
+               ):
+        '''
+        @block_size: size of a block
+        @iter_size: size of the iteration, a block has multiple iterations
+        @num_elements: number of elements
+        '''
+        pid = tl.program_id(axis=0)
+        for i in range(math.ceil(block_size / iter_size)):
+            # TODO: a bug here, if put the offset outside the forloop, there will be a GPU mis-aligned error.
+            offset = pid * block_size + tl.arange(0, iter_size)
+            x_ptrs = x_ptr + offset
+            y_ptrs = y_ptr + offset
+
+            x = tl.load(x_ptrs, mask=offset < num_elements)
+            y = tl.load(y_ptrs, mask=offset < num_elements)
+            z = x + y
+            z_ptrs = z_ptr + offset
+            tl.store(z_ptrs, z, mask=offset < num_elements)
+
+            x_ptr += iter_size
+            y_ptr += iter_size
+            z_ptr += iter_size
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
+                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
+                 num_elements=x.numel())
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_no_scf_tester(num_warps, block_size, shape):
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        mask = offset < n_elements
+
+        x = tl.load(x_ptrs, mask=mask)
+        y = tl.load(y_ptrs, mask=mask)
+        z = x + y
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.empty(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z = x + y
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
+    '''
+    vecadd tester with float comparation as load/store mask.
+    '''
+    @triton.jit
+    def kernel(x_ptr,
+               y_ptr,
+               z_ptr,
+               n_elements,
+               block_size_N: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        offset = pid * block_size_N + tl.arange(0, block_size_N)
+        x_ptrs = x_ptr + offset
+        y_ptrs = y_ptr + offset
+
+        io_mask = offset < n_elements
+        x = tl.load(x_ptrs, mask=io_mask)
+        y = tl.load(y_ptrs, mask=io_mask)
+
+        z = x + y
+        val_mask = offset < n_elements and (z < 0. or z > 1.)
+
+        z_ptrs = z_ptr + offset
+        tl.store(z_ptrs, z, mask=val_mask)
+
+    x = torch.randn(shape, device='cuda', dtype=torch.float32)
+    y = torch.randn(shape, device='cuda', dtype=torch.float32)
+    z = torch.zeros(shape, device=x.device, dtype=x.dtype)
+
+    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
+    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)
+
+    golden_z: torch.Tensor = x + y
+    gz_data = torch.flatten(golden_z)
+    for i in range(golden_z.numel()):
+        gz_data[i] = gz_data[i] if gz_data[i] < 0. or gz_data[i] > 1. else 0.
+
+    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [4, 256, (256,)],
+    [2, 256, (256,)],
+    [1, 256, (256,)],
+    [4, 16, (256,)],
+    [2, 64, (256,)],
+    [1, 128, (256,)],
+])
+def test_vecadd_no_scf(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd__no_scf_masked(num_warps, block_size, shape):
+    vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+def test_vecadd_no_scf_masked_randomly():
+    random.seed(0)  # fix seed to make random test reproducible
+    for i in range(10):
+        num_elements = random.randint(128, 2048)
+        shape = (num_elements,)
+        max_warps = num_elements // 32  # floor div
+        for num_warps in range(1, max_warps):
+            is_power2 = num_warps & (num_warps - 1) == 0 and num_warps != 0
+            if not is_power2: continue
+            block_size = min(32, num_warps * 32)
+            vecadd_no_scf_tester(num_warps, block_size, shape)
+
+
+@pytest.mark.parametrize('num_warps, block_size, shape', [
+    [1, 128, (256 + 1,)],
+    [1, 256, (256 + 1,)],
+    [2, 256, (3, 256 + 7)],
+    [4, 256, (3, 256 + 7)],
+])
+def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
+    vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)
diff --git a/python/triton/compiler.py b/python/triton/compiler.py
index 562d59526..46ca7fd16 100644
--- a/python/triton/compiler.py
+++ b/python/triton/compiler.py
@@ -699,6 +699,28 @@ class CodeGenerator(ast.NodeVisitor):
     def visit_Constant(self, node):
         return triton.language.constexpr(node.value)

+    def visit_BoolOp(self, node: ast.BoolOp):
+        assert len(node.values) == 2
+        lhs = self.visit(node.values[0])
+        rhs = self.visit(node.values[1])
+        if isinstance(lhs, triton.language.constexpr):
+            lhs = lhs.value
+        if isinstance(rhs, triton.language.constexpr):
+            rhs = rhs.value
+
+        fn = {
+            ast.And: 'logical_and',
+            ast.Or: 'logical_or',
+        }[type(node.op)]
+
+        if self.is_triton_tensor(lhs):
+            return getattr(lhs, fn)(rhs, _builder=self.builder)
+        elif self.is_triton_tensor(rhs):
+            fn = fn[:2] + 'r' + fn[2:]
+            return getattr(rhs, fn)(lhs, _builder=self.builder)
+        else:
+            return getattr(lhs, fn)(rhs)
+
     if sys.version_info < (3, 8):
         def visit_NameConstant(self, node):
             return triton.language.constexpr(node.value)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 33f3e7d41..bf1d57ba4 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -361,8 +361,6 @@ class constexpr:
     def __rfloordiv__(self, other):
         return other.value // self.value

-    #
-
     def __gt__(self, other):
         return self.value > other.value

@@ -557,6 +555,16 @@ class tensor:
         other = _to_tensor(other, _builder)
         return semantic.not_equal(self, other, _builder)

+    @builtin
+    def logical_and(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_and(self, other, _builder)
+
+    @builtin
+    def logical_or(self, other, _builder=None):
+        other = _to_tensor(other, _builder)
+        return semantic.logical_or(self, other, _builder)
+
     @builtin
     def __getitem__(self, slices, _builder=None):
         if isinstance(slices, slice):
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 4084a2cc5..1c8caaefe 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -285,6 +285,22 @@ def xor_(input: tl.tensor,
     return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)


+def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return and_(input, other, builder)
+
+
+def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
+    if not input.type.is_int1():
+        input = bitcast(input, tl.dtype("int1"), builder)
+    if not other.type.is_int1():
+        other = bitcast(other, tl.dtype("int1"), builder)
+    return or_(input, other, builder)
+
+
 def lshr(input: tl.tensor,
          other: tl.tensor,
          builder: ir.builder) -> tl.tensor:
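
Note (not part of the patch): a minimal usage sketch of the masked load/store and the `and`/`or` mask syntax this change enables, modeled on the new tests in python/tests/test_vecadd.py. The kernel name, tensors, and launch parameters below are illustrative assumptions only.

import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy(x_ptr, z_ptr, n_elements, BLOCK: tl.constexpr):
    offset = tl.program_id(axis=0) * BLOCK + tl.arange(0, BLOCK)
    # Boundary mask: out-of-range lanes are neither loaded nor stored.
    x = tl.load(x_ptr + offset, mask=offset < n_elements)
    # `and` in the mask goes through the new visit_BoolOp and is lowered via
    # tl.tensor.logical_and (bitcast both operands to int1, then and_).
    tl.store(z_ptr + offset, x, mask=(offset < n_elements) and (x > 0.))


x = torch.randn(1000, device='cuda')
z = torch.zeros_like(x)
# 1000 elements with BLOCK=256 -> 4 program instances.
masked_copy[(4,)](x, z, x.numel(), BLOCK=256)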