triton/python/triton/language/semantic.py
Philippe Tillet 20100a7254 Merge triton-mlir branch - Complete rewrite of the backend from scratch (#1004)
This PR merges the `triton-mlir` branch, in which we have been quietly
rewriting the Triton backend from scratch to increase maintainability,
stability and ultimately performance. Changes to the runtime are
minimal, and this new version aims to remain backward-compatible with
the previous commit. The legacy backend is now officially deprecated,
but can still be accessed via the `legacy-backend` tag.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com>
Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com>
Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Co-authored-by: Jun Yang <yangjunpro@gmail.com>
Co-authored-by: Ian Bearman <ianb@microsoft.com>
Co-authored-by: Jason Ansel <jansel@jansel.net>
Co-authored-by: Qingyi Liu <qingyil@nvidia.com>
Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <lyricz@yeah.net>
Co-authored-by: ben-zhang-609 <benzh609@gmail.com>
Co-authored-by: dongdongl <dongdongl@nvidia.com>
2022-12-21 01:30:50 -08:00


from __future__ import annotations  # remove after python 3.11

from typing import List, Optional, Tuple

from . import core as tl
from triton._C.libtriton.triton import ir


# Raised when the operands of a binary operation have incompatible types.
class IncompatibleTypeErrorImpl(Exception):
    def __init__(self, type_a, type_b):
        self.type_a = type_a
        self.type_b = type_b
        self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
        super(IncompatibleTypeErrorImpl, self).__init__(self.message)
# ===----------------------------------------------------------------------===//
# Programming Model
# ===----------------------------------------------------------------------===//


def program_id(axis: int, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_get_program_id(axis), tl.int32)


def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
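# For example, in a kernel launched over a 1D grid of n program instances,
# program_id(0, builder) yields this instance's index in [0, n) and
# num_programs(0, builder) yields n; these builders back the tl.program_id
# and tl.num_programs frontend functions.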
# ===----------------------------------------------------------------------===//
# Implicit Casting Utilities
# ===----------------------------------------------------------------------===//


def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
    a_rank = a_ty.int_bitwidth
    b_rank = b_ty.int_bitwidth
    a_sn = a_ty.int_signedness
    b_sn = b_ty.int_signedness
    # Rules for signedness taken from "Usual arithmetic conversions" on
    # https://en.cppreference.com/w/c/language/conversion.
    if a_sn == b_sn:
        return a_ty if a_rank > b_rank else b_ty
    elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
        return a_ty if a_rank >= b_rank else b_ty
    elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
        return b_ty if b_rank >= a_rank else a_ty
    assert False
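# A few consequences of the C-style rules above:
#   integer_promote_impl(tl.int16,  tl.int32) -> tl.int32   (same signedness: wider rank wins)
#   integer_promote_impl(tl.uint32, tl.int32) -> tl.uint32  (equal rank: unsigned wins)
#   integer_promote_impl(tl.uint16, tl.int32) -> tl.int32   (the wider signed type wins)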
def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
    # 1) if one operand is double, the other is implicitly
    #    converted to double
    if a_ty.is_fp64() or b_ty.is_fp64():
        return tl.float64
    # 2) if one operand is float, the other is implicitly
    #    converted to float
    if a_ty.is_fp32() or b_ty.is_fp32():
        return tl.float32
    # 3) if one operand is half, the other is implicitly converted to half,
    #    unless we're doing / or %, which do not exist natively in PTX for fp16.
    #    Supported PTX ops: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
    if a_ty.is_fp16() or b_ty.is_fp16():
        if div_or_mod:
            return tl.float32
        else:
            return tl.float16
    # 4) return bf16 only if both operands are of bf16
    if a_ty.is_bf16() or b_ty.is_bf16():
        if div_or_mod:
            return tl.float32
        if a_ty.is_bf16() and b_ty.is_bf16():
            return tl.bfloat16
        return tl.float32
    if not a_ty.is_int() or not b_ty.is_int():
        assert False
    # 5) both operands are integer and undergo
    #    integer promotion
    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
        raise ValueError("Cannot use /, //, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " "
                         "because they have different signedness; "
                         "this is unlikely to result in a useful answer. Cast them to the same signedness.")
    return integer_promote_impl(a_ty, b_ty)
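# Worked examples of the dispatch above:
#   (tl.float64,  tl.float16)                   -> tl.float64  (rule 1)
#   (tl.float32,  tl.int32)                     -> tl.float32  (rule 2)
#   (tl.float16,  tl.float16), div_or_mod=False -> tl.float16  (rule 3)
#   (tl.float16,  tl.float16), div_or_mod=True  -> tl.float32  (no native fp16 div/rem)
#   (tl.bfloat16, tl.float64)                   -> tl.float64  (rule 1 fires before the bf16 rule)
#   (tl.int16,    tl.uint32),  div_or_mod=False -> tl.uint32   (rule 5, integer promotion)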
# ===----------------------------------------------------------------------===//
# Binary Operators
# ===----------------------------------------------------------------------===//


def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
    if type_a.is_ptr():
        if not allow_ptr_a:
            raise IncompatibleTypeErrorImpl(type_a, type_b)
        # T* + U* with T != U
        if type_b.is_ptr() and (type_a != type_b):
            raise IncompatibleTypeErrorImpl(type_a, type_b)
        # T* + float
        if type_b.is_floating():
            raise IncompatibleTypeErrorImpl(type_a, type_b)


def binary_op_type_checking_impl(lhs: tl.tensor,
                                 rhs: tl.tensor,
                                 builder: ir.builder,
                                 allow_lhs_ptr=False, allow_rhs_ptr=False,
                                 arithmetic_check=True, div_or_mod=False
                                 ) -> Tuple[tl.tensor, tl.tensor]:
    # implicit broadcasting
    lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
    # implicit typecasting
    lhs_sca_ty = lhs.type.scalar
    rhs_sca_ty = rhs.type.scalar
    check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
    check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
    if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
        ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod)
        lhs = cast(lhs, ret_sca_ty, builder)
        rhs = cast(rhs, ret_sca_ty, builder)
    return lhs, rhs
def add(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder, True, True)
    input_scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar

    # offset + ptr
    # ptr + offset
    if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
        input, other = other, input
    if input_scalar_ty.is_ptr():
        return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
    # float + float
    elif input_scalar_ty.is_floating():
        return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
    # int + int
    elif input_scalar_ty.is_int():
        return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
    assert False


def sub(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder, True, False)
    scalar_ty = input.type.scalar
    # ptr - offset
    if scalar_ty.is_ptr():
        return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle),
                         input.type)
    # float - float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
    # int - int
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
    assert False
def mul(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float * float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
    # int * int
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
    assert False


def truediv(input: tl.tensor,
            other: tl.tensor,
            builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    input_scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    # float / int
    if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
        other = cast(other, input_scalar_ty, builder)
    # int / float
    elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
        input = cast(input, other_scalar_ty, builder)
    # int / int (cast to tl.float32)
    elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
        input = cast(input, tl.float32, builder)
        other = cast(other, tl.float32, builder)
    # float / float (cast to the highest-exponent type)
    elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
        if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
            other = cast(other, input_scalar_ty, builder)
        else:
            input = cast(input, other_scalar_ty, builder)
    # unreachable
    else:
        assert False
    return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
def floordiv(input: tl.tensor,
             other: tl.tensor,
             builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    input_scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    if input_scalar_ty.is_int() and other_scalar_ty.is_int():
        ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
        input = cast(input, ret_ty, builder)
        other = cast(other, ret_ty, builder)
        if ret_ty.is_int_signed():
            return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
        else:
            return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
    assert False


def fdiv(input: tl.tensor,
         other: tl.tensor,
         ieee_rounding: bool,
         builder: ir.builder) -> tl.tensor:
    input_scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
        raise ValueError("both operands of fdiv must have floating scalar type")
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
    ret = builder.create_fdiv(input.handle, other.handle)
    return tl.tensor(ret, input.type)
def mod(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    # float % float
    if scalar_ty.is_floating():
        # input - input.div(other, rounding_mode="floor") * other
        ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder),
                             other, builder),
                  builder)
        return ret
    # int % int
    elif scalar_ty.is_int():
        if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
            raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
                             "because they have different signedness; "
                             "this is unlikely to result in a useful answer. Cast them to the same signedness.")
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
        else:
            return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
    assert False
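# Note the asymmetry above: the float path computes x - floor(x / y) * y,
# i.e. Python-style modulo whose result takes the sign of the divisor
# (mod(-7.0, 3.0) == 2.0), while the integer path lowers to srem/urem,
# i.e. C-style remainder whose result takes the sign of the dividend
# (mod(-7, 3) == -1 for signed ints).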
##############
# bitwise ops
##############


def bitwise_op_type_checking_impl(input: tl.tensor,
                                  other: tl.tensor,
                                  builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
    input_sca_ty = input.type.scalar
    other_sca_ty = other.type.scalar
    if not input_sca_ty.is_int() or not other_sca_ty.is_int():
        raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
    ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
    if ret_sca_ty != input_sca_ty:
        input = cast(input, ret_sca_ty, builder)
    if ret_sca_ty != other_sca_ty:
        other = cast(other, ret_sca_ty, builder)
    return input, other


def and_(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    input, other = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_and(input.handle, other.handle), input.type)


def or_(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_or(input.handle, other.handle), input.type)


def xor_(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    input, other = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)


def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    if not input.type.is_int1():
        input = bitcast(input, tl.dtype("int1"), builder)
    if not other.type.is_int1():
        other = bitcast(other, tl.dtype("int1"), builder)
    return and_(input, other, builder)


def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    if not input.type.is_int1():
        input = bitcast(input, tl.dtype("int1"), builder)
    if not other.type.is_int1():
        other = bitcast(other, tl.dtype("int1"), builder)
    return or_(input, other, builder)


def lshr(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    input, other = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)


def shl(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    input, other = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
# ===----------------------------------------------------------------------===//
# Unary Operators
# ===----------------------------------------------------------------------===//


def plus(input: tl.tensor) -> tl.tensor:
    return input


def minus(input: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    input_sca_ty = input.type.scalar
    if input_sca_ty.is_ptr():
        raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
    _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
    return sub(_0, input, builder)


def invert(input: tl.tensor,
           builder: ir.builder) -> tl.tensor:
    input_sca_ty = input.type.scalar
    if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
        raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
    _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
    return xor_(input, _1, builder)
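# Both unary ops are lowered through existing binaries: minus(x) is built as
# (0 - x) and invert(x) as (x ^ all_ones), so e.g. inverting an int32 value
# of 5 yields -6, the usual two's-complement bitwise NOT.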
# ===----------------------------------------------------------------------===//
# Comparison Operators
# ===----------------------------------------------------------------------===//


def _bool_like(v: tl.tensor) -> tl.block_type:
    if not v.type.is_block():
        return tl.int1
    shape = v.type.shape
    return tl.block_type(tl.int1, shape)
def greater_than(input: tl.tensor,
                 other: tl.tensor,
                 builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float > float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input))
    # int > int
    elif scalar_ty.is_int():
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input))
        else:
            return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input))
    assert False


def greater_equal(input: tl.tensor,
                  other: tl.tensor,
                  builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float >= float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input))
    # int >= int
    elif scalar_ty.is_int():
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input))
        else:
            return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input))
    assert False


def less_than(input: tl.tensor,
              other: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float < float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input))
    # int < int
    elif scalar_ty.is_int():
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input))
        else:
            return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input))
    assert False


def less_equal(input: tl.tensor,
               other: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float <= float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input))
    # int <= int
    elif scalar_ty.is_int():
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input))
        else:
            return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input))
    assert False


def equal(input: tl.tensor,
          other: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float == float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input))
    # int == int
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input))
    assert False


def not_equal(input: tl.tensor,
              other: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    input, other = binary_op_type_checking_impl(input, other, builder)
    scalar_ty = input.type.scalar
    # float != float
    if scalar_ty.is_floating():
        return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input))
    # int != int
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
    assert False
# ===----------------------------------------------------------------------===//
# Block Creation
# ===----------------------------------------------------------------------===//


def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
    if not isinstance(start, int) or not isinstance(end, int):
        raise ValueError("arange's arguments must be of type tl.constexpr")
    shape = [end - start]
    ret_ty = tl.block_type(tl.int32, shape)
    return tl.tensor(builder.create_make_range(start, end), ret_ty)


def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    _0 = builder.get_null_value(dtype.to_ir(builder))
    ret_ty = tl.block_type(dtype, shape)
    return tl.tensor(builder.create_splat(_0, shape), ret_ty)
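# For example, arange(0, 8, builder) materializes a 1D int32 block holding
# [0, 1, ..., 7] (the block shape is end - start, so both bounds must be
# compile-time ints), and zeros([4, 4], tl.float32, builder) splats 0.0 into
# a 4x4 block; these back tl.arange and tl.zeros.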
# ===----------------------------------------------------------------------===//
# Shape Manipulation
# ===----------------------------------------------------------------------===//


def view(input: tl.tensor,
         dst_shape: List[int],
         builder: ir.builder) -> tl.tensor:
    # TODO: disable when TritonToTritonGPU handles views properly
    # assert len(input.shape) == len(dst_shape)
    numel = 1
    for s in dst_shape:
        numel *= s
    if input.type.numel != numel:
        raise ValueError("cannot view block of different shape")
    ret_ty = tl.block_type(input.type.scalar, dst_shape)
    return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty)


def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    dst_shape = [s for s in input.type.shape]
    dst_shape.insert(axis, 1)
    ret_ty = tl.block_type(input.type.scalar, dst_shape)
    return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)


def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
    assert can_reorder, "the current implementation of `cat` may reorder elements"
    assert len(lhs.shape) == 1
    ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
    return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type)


def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
    if len(input.shape) != 2:
        raise ValueError("Only 2D tensors can be transposed")
    ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]])
    return tl.tensor(builder.create_trans(input.handle), ret_type)


def broadcast_impl_shape(input: tl.tensor,
                         shape: List[int],
                         builder: ir.builder) -> tl.tensor:
    if not input.type.is_block():
        ret_ty = tl.block_type(input.type, shape)
        return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
    src_shape = input.type.get_block_shapes()
    if len(src_shape) != len(shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
    if shape == src_shape:
        return input
    for i in range(len(src_shape)):
        if shape[i] != src_shape[i] and src_shape[i] != 1:
            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                             f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
                             f" {i}: {src_shape}, {shape}")
    ret_ty = tl.block_type(input.type.scalar, shape)
    return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
def broadcast_impl_value(lhs: tl.tensor,
                         rhs: tl.tensor,
                         builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    lhs_ty = lhs.type
    rhs_ty = rhs.type

    # make_shape_compatible(block, scalar)
    if lhs_ty.is_block() and not rhs_ty.is_block():
        rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
        rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
    # make_shape_compatible(scalar, block)
    elif not lhs_ty.is_block() and rhs_ty.is_block():
        lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
        lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
    # make_shape_compatible(block, block)
    elif lhs_ty.is_block() and rhs_ty.is_block():
        lhs_shape = lhs_ty.get_block_shapes()
        rhs_shape = rhs_ty.get_block_shapes()

        if len(lhs_shape) < len(rhs_shape):
            # Add new axes to lhs
            for dim in range(len(lhs_shape), len(rhs_shape)):
                lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim), tl.block_type(lhs_ty.scalar, lhs_shape + [1]))
                lhs_ty = lhs.type
                lhs_shape = lhs_ty.get_block_shapes()
        elif len(rhs_shape) < len(lhs_shape):
            # Add new axes to rhs
            for dim in range(len(rhs_shape), len(lhs_shape)):
                rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim), tl.block_type(rhs_ty.scalar, rhs_shape + [1]))
                rhs_ty = rhs.type
                rhs_shape = rhs_ty.get_block_shapes()
        assert len(rhs_shape) == len(lhs_shape)

        ret_shape = []
        for i in range(len(lhs_shape)):
            left = lhs_shape[i]
            right = rhs_shape[i]
            if left == 1:
                ret_shape.append(right)
            elif right == 1:
                ret_shape.append(left)
            elif left == right:
                ret_shape.append(left)
            else:
                raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                                 "at index " + str(i) + ": " + str(left) + " and " + str(right))
        if lhs_shape != ret_shape:
            ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
            lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
        if rhs_shape != ret_shape:
            ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
            rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
    # (scalar, scalar) => returns original blocks
    return lhs, rhs
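# Unlike NumPy, which aligns shapes from the trailing dimension, the
# rank-extension loop above appends the new size-1 axes on the right:
# combining a (32,) block with a (32, 16) block first expands the former to
# (32, 1), then broadcasts both sides to the common shape (32, 16).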
#######
# cast
#######


def bitcast(input: tl.tensor,
            dst_ty: tl.dtype,
            builder: ir.builder) -> tl.tensor:
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input
    src_sca_ty = src_ty.scalar
    dst_sca_ty = dst_ty.scalar
    if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
        return cast(input, dst_ty, builder)
    # Bitcast
    src_bits = src_sca_ty.primitive_bitwidth
    dst_bits = dst_sca_ty.primitive_bitwidth
    if src_bits != dst_bits:
        raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
                         "data-type of size " + str(dst_bits))
    return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
                     dst_ty)
def cast(input: tl.tensor,
         dst_ty: tl.dtype,
         builder: ir.builder) -> tl.tensor:
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input

    src_sca_ty = src_ty.scalar
    dst_sca_ty = dst_ty.scalar

    # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
    if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
       (src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
        return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
                         dst_ty)

    # fp16 or bf16 to any type other than fp32: go through fp32 first
    if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \
       (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()):
        return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)

    # Standard floating types' casting: truncation
    #   fp64 => fp32, fp16, bf16
    #   fp32 => fp16, bf16
    truncate_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth
    if truncate_fp:
        return tl.tensor(builder.create_fp_trunc(input.handle,
                                                 dst_ty.to_ir(builder)),
                         dst_ty)

    # Standard floating types' casting: extension
    #   fp32 => fp64
    #   fp16 => fp32, fp64
    #   bf16 => fp32, fp64
    ext_fp = src_sca_ty.is_floating() and \
        dst_sca_ty.is_floating() and \
        src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth
    if ext_fp:
        return tl.tensor(builder.create_fp_ext(input.handle,
                                               dst_ty.to_ir(builder)),
                         dst_ty)

    # Casting between integer types
    if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
       (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
        sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
        if dst_sca_ty.is_bool():
            ty = input.dtype.to_ir(builder)
            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
            return not_equal(input, _0, builder)
        else:
            return tl.tensor(builder.create_int_cast(input.handle,
                                                     dst_ty.to_ir(builder), sign_extend),
                             dst_ty)

    # Casting standard floating types to integer types
    if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int():
        if dst_sca_ty.is_bool():
            ty = input.dtype.to_ir(builder)
            _0 = tl.tensor(builder.get_null_value(ty), input.dtype)
            return not_equal(input, _0, builder)
        elif dst_sca_ty.is_int_signed():
            return tl.tensor(builder.create_fp_to_si(input.handle,
                                                     dst_ty.to_ir(builder)),
                             dst_ty)
        else:
            return tl.tensor(builder.create_fp_to_ui(input.handle,
                                                     dst_ty.to_ir(builder)),
                             dst_ty)

    # Casting integer types to standard floating types
    if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating():
        if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
            return tl.tensor(builder.create_ui_to_fp(input.handle,
                                                     dst_ty.to_ir(builder)),
                             dst_ty)
        else:
            return tl.tensor(builder.create_si_to_fp(input.handle,
                                                     dst_ty.to_ir(builder)),
                             dst_ty)

    # Casting pointer types to integer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
        bitwidth = dst_sca_ty.int_bitwidth
        if bitwidth == 64:
            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                             dst_ty)
        if bitwidth == 1:
            return not_equal(cast(input, tl.int64, builder),
                             tl.tensor(builder.get_int64(0), tl.int64),
                             builder)

    # Casting integer types to pointer types
    if src_sca_ty.is_int() and dst_sca_ty.is_ptr():
        return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)

    # Casting pointer types to pointer types
    if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
        return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)

    assert False, f'cannot cast {input} to {dst_ty}'
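# Summary of the dispatch above:
#   fp32 -> fp16 / bf16 : create_fp_trunc     fp32 -> fp64 : create_fp_ext
#   fp16 / bf16 -> anything but fp32 : routed through fp32 first
#   fp8 <-> other floats : create_fp_to_fp
#   intN -> intM : create_int_cast (sign-extends for signed, non-bool sources)
#   float or int -> int1 : lowered to a != 0 comparison (not_equal)
#   ptr -> int64 : create_ptr_to_int          int -> ptr : create_int_to_ptr
#   ptr -> ptr   : create_bitcast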
# ===----------------------------------------------------------------------===//
# Memory Operators
# ===----------------------------------------------------------------------===//


def load(ptr: tl.tensor,
         mask: Optional[tl.tensor],
         other: Optional[tl.tensor],
         cache_modifier: str,
         eviction_policy: str,
         is_volatile: bool,
         builder: ir.builder) -> tl.tensor:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())

    if ptr.type.is_block():
        if mask:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
        if other:
            other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)

    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty
    # treat bool* as tl.int8*
    if elt_ty == tl.int1:
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)
    if other:
        other = cast(other, elt_ty, builder)

    # cache modifier
    cache = ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier:
        if cache_modifier == ".ca":
            cache = ir.CACHE_MODIFIER.CA
        elif cache_modifier == ".cg":
            cache = ir.CACHE_MODIFIER.CG
        else:
            raise ValueError(f"Cache modifier {cache_modifier} not supported")

    # eviction policy
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy:
        if eviction_policy == "evict_last":
            eviction = ir.EVICTION_POLICY.EVICT_LAST
        elif eviction_policy == "evict_first":
            eviction = ir.EVICTION_POLICY.EVICT_FIRST
        else:
            raise ValueError(f"Eviction policy {eviction_policy} not supported")

    if ptr.type.is_block():
        shape = ptr.type.get_block_shapes()
        dst_ty = tl.block_type(elt_ty, shape)
    else:
        dst_ty = elt_ty

    if not mask:
        if other:
            raise ValueError("`other` cannot be provided without `mask`")
        return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
                         dst_ty)
    else:
        return tl.tensor(builder.create_masked_load(ptr.handle,
                                                    mask.handle,
                                                    other.handle if other else None,
                                                    cache, eviction, is_volatile),
                         dst_ty)
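# A typical call, mirroring what the tl.load frontend passes down (the
# argument values here are just one illustrative configuration):
#   load(ptrs, mask, other, ".ca", "evict_last", False, builder)
# emits create_masked_load because a mask is given; with mask=None (and
# therefore other=None) it emits a plain create_load instead.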
def store(ptr: tl.tensor,
          val: tl.tensor,
          mask: Optional[tl.tensor],
          builder: ir.builder) -> tl.tensor:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())

    if ptr.type.is_block():
        val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
        if mask:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)

    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty
    # treat bool* as tl.int8*
    if elt_ty == tl.int1:
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # cast to target data-type
    val = cast(val, elt_ty, builder)
    if not mask:
        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
#########
# atomic
#########


def atomic_cas(ptr: tl.tensor,
               cmp: tl.tensor,
               val: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    element_ty = ptr.type.scalar.element_ty
    if element_ty.primitive_bitwidth not in [16, 32, 64]:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)


def atom_red_typechecking_impl(ptr: tl.tensor,
                               val: tl.tensor,
                               mask: tl.tensor,
                               op: str,
                               builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of atomic_" + op + " is " + ptr.type.__repr__())

    element_ty = ptr.type.scalar.element_ty
    if element_ty is tl.float16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support fp16")
    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
        raise ValueError("atomic_" + op + " does not support " + str(element_ty))
    if ptr.type.is_block():
        if mask:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
        if val:
            val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
    val = cast(val, ptr.type.scalar.element_ty, builder)
    if not mask:
        mask_ir = builder.get_int1(True)
        mask_ty = tl.int1
        if ptr.type.is_block():
            mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
            mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
        mask = tl.tensor(mask_ir, mask_ty)
    return ptr, val, mask
def atomic_max(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
    sca_ty = val.type.scalar
    # direct call to atomic_max for integers
    if sca_ty.is_int():
        if sca_ty.is_int_signed():
            return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
                                                       ptr.handle,
                                                       val.handle,
                                                       mask.handle),
                             val.type)
        else:
            return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
                                                       ptr.handle,
                                                       val.handle,
                                                       mask.handle),
                             val.type)
    # for float
    #   return atomic_smax(i_ptr, i_val) if val >= 0
    #   return atomic_umin(i_ptr, i_val) if val < 0
    i_val = bitcast(val, tl.int32, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
    neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
    return where(pos, pos_ret, neg_ret, builder)
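# The float path above exploits IEEE-754 bit patterns: for non-negative
# floats the bits order like signed ints (so a signed integer MAX works),
# while for negative floats a more negative value has a larger unsigned bit
# pattern (so an unsigned integer MIN selects the float maximum). The two
# masked RMWs cover the disjoint pos/neg cases and `where` merges them.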
def atomic_min(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
    sca_ty = val.type.scalar
    # direct call to atomic_min for integers
    if sca_ty.is_int():
        if sca_ty.is_int_signed():
            return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
                                                       ptr.handle,
                                                       val.handle,
                                                       mask.handle),
                             val.type)
        else:
            return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
                                                       ptr.handle,
                                                       val.handle,
                                                       mask.handle),
                             val.type)
    # for float
    #   return atomic_smin(i_ptr, i_val) if val >= 0
    #   return atomic_umax(i_ptr, i_val) if val < 0
    i_val = bitcast(val, tl.int32, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
                                                  i_ptr.handle,
                                                  i_val.handle,
                                                  and_(mask, pos, builder).handle),
                        i_val.type)
    neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
                                                  i_ptr.handle,
                                                  i_val.handle,
                                                  and_(mask, neg, builder).handle),
                        i_val.type)
    return where(pos, pos_ret, neg_ret, builder)


def atomic_add(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
    sca_ty = val.type.scalar
    op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
    return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)


def atomic_and(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)


def atomic_or(ptr: tl.tensor,
              val: tl.tensor,
              mask: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)


def atomic_xor(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)


def atomic_xchg(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra
# ===----------------------------------------------------------------------===//


def dot(lhs: tl.tensor,
        rhs: tl.tensor,
        allow_tf32: bool,
        builder: ir.builder) -> tl.tensor:
    assert lhs.type.is_block() and rhs.type.is_block()
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
    assert lhs.shape[1].value == rhs.shape[0].value
    assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \
        and rhs.shape[1].value >= 16, \
        "small blocks not supported!"
    if lhs.type.scalar.is_int():
        _0 = builder.get_int32(0)
        ret_scalar_ty = tl.int32
    else:
        _0 = builder.get_float32(0)
        ret_scalar_ty = tl.float32
    M = lhs.type.shape[0]
    N = rhs.type.shape[1]
    _0 = builder.create_splat(_0, [M, N])
    ret_ty = tl.block_type(ret_scalar_ty, [M, N])
    return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32),
                     ret_ty)
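# For example, with an fp16 lhs of shape (M, K) and rhs of shape (K, N)
# (each dimension at least 16), this splats a zero fp32 accumulator of shape
# (M, N) and emits a single create_dot; integer inputs accumulate in int32.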
# ===----------------------------------------------------------------------===//
# Indexing
# ===----------------------------------------------------------------------===//


def where(condition: tl.tensor,
          x: tl.tensor,
          y: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    condition = cast(condition, tl.int1, builder)
    if condition.type.is_block():
        condition, x = broadcast_impl_value(condition, x, builder)
        x, y = broadcast_impl_value(x, y, builder)
        condition, x = broadcast_impl_value(condition, x, builder)

    x, y = binary_op_type_checking_impl(x, y, builder, True, True)
    if not condition.type.is_block():
        condition, _ = broadcast_impl_value(condition, x, builder)
    ret_ty = x.type
    return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reductions
# ===----------------------------------------------------------------------===//


def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
                FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
    scalar_ty = input.type.scalar
    # input is extended to 32-bits if necessary
    # this increases numerical accuracy and can be done pretty much for free
    # on GPUs
    if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
        input = cast(input, tl.int32, builder)

    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
    if scalar_ty is tl.bfloat16:
        input = cast(input, tl.float32, builder)

    # choose the right unsigned operation
    if scalar_ty.is_int_unsigned():
        int_op_to_unit = {
            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
        }
        if INT_OP in int_op_to_unit:
            INT_OP = int_op_to_unit[INT_OP]

    # get result type
    shape = input.type.shape
    ret_shape = []
    for i, s in enumerate(shape):
        if i != axis:
            ret_shape.append(s)
    if ret_shape:
        res_ty = tl.block_type(scalar_ty, ret_shape)
    else:
        # 0d-tensor -> scalar
        res_ty = scalar_ty

    if scalar_ty.is_floating():
        return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
    elif scalar_ty.is_int():
        return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
    assert False
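# For example, an unsigned MAX reduction is remapped to UMAX above, int8 and
# int16 inputs are first widened to int32 (an essentially free accuracy win
# on GPUs), and reducing a (4, 8) block along axis=1 leaves a (4,) block
# while reducing a 1D block collapses to a scalar.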
def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)


def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)


def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)


def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)


def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)


def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    scalar_ty = input.type.scalar
    if not scalar_ty.is_int():
        raise ValueError("xor_sum only supported for integers")
    return reduce_impl(input, axis, builder, "xor_sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
# ===----------------------------------------------------------------------===//
# Math
# ===----------------------------------------------------------------------===//


def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
    x, y = binary_op_type_checking_impl(x, y, builder)
    from . import libdevice
    return libdevice.mulhi(x, y, _builder=builder)


def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    from . import libdevice
    return libdevice.floor(x, _builder=builder)


def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_exp(x.handle), x.type)


def log(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_log(x.handle), x.type)


def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_cos(x.handle), x.type)


def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_sin(x.handle), x.type)


def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_sqrt(x.handle), x.type)


##


def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
    if len(x.shape) != len(values):
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
    return x


def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
    if len(x.shape) != len(values):
        raise ValueError("Shape of input to max_contiguous does not match the length of values")
    x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context()))
    return x


def debug_barrier(builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_barrier(), tl.void)


def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
    new_args = []
    for arg in args:
        new_args.append(arg.handle)
    return tl.tensor(builder.create_printf(prefix, new_args), tl.void)