This PR merges the `triton-mlir` branch, in which we have been quietly rewriting the Triton backend from scratch to increase maintainability, stability and ultimately performance. Changes to the runtime are minimal, and this new version aims to remain backward-compatible with the previous commit. The legacy backend is now officially deprecated, but can still be accessed via the `legacy-backend` tag. Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Yan Chunwei <yanchunwei@outlook.com> Co-authored-by: goostavz <109190422+goostavz@users.noreply.github.com> Co-authored-by: Shintaro Iwasaki <siwasaki@fb.com> Co-authored-by: Yan Da <dyanab@connect.ust.hk> Co-authored-by: Jun Yang <yangjunpro@gmail.com> Co-authored-by: Ian Bearman <ianb@microsoft.com> Co-authored-by: Jason Ansel <jansel@jansel.net> Co-authored-by: Qingyi Liu <qingyil@nvidia.com> Co-authored-by: ben-zhang-609 <110140741+ben-zhang-609@users.noreply.github.com> Co-authored-by: Chenggang Zhao <lyricz@yeah.net> Co-authored-by: ben-zhang-609 <benzh609@gmail.com> Co-authored-by: dongdongl <dongdongl@nvidia.com>
1166 lines
47 KiB
Python
1166 lines
47 KiB
Python
from __future__ import annotations # remove after python 3.11
|
|
|
|
from typing import List, Optional, Tuple
|
|
|
|
from . import core as tl
|
|
from triton._C.libtriton.triton import ir
|
|
|
|
|
|
class IncompatibleTypeErrorImpl(Exception):
    """Raised when two operand types cannot be combined by an operator.

    The exception message is
    ``"invalid operands of type <repr(type_a)> and <repr(type_b)>"``.

    Args:
        type_a: type of the first operand; stored as ``self.type_a``.
        type_b: type of the second operand; stored as ``self.type_b``.
    """

    def __init__(self, type_a, type_b):
        self.type_a = type_a
        self.type_b = type_b
        # `!r` goes through repr(), which (unlike calling `__repr__()` on the
        # object directly) also works when an operand is itself a class.
        self.message = f"invalid operands of type {type_a!r} and {type_b!r}"
        super().__init__(self.message)
|
|
|
|
|
|
# ===----------------------------------------------------------------------===##
|
|
# Programming Model
|
|
# ===----------------------------------------------------------------------===##
|
|
|
|
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
    """Return the program id along `axis` as an int32 tensor.

    Wraps the handle produced by `builder.create_get_program_id(axis)`.
    """
    pid_handle = builder.create_get_program_id(axis)
    return tl.tensor(pid_handle, tl.int32)
|
|
|
|
|
|
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
    """Return the number of programs along `axis` as an int32 tensor.

    Wraps the handle produced by `builder.create_get_num_programs(axis)`.
    """
    count_handle = builder.create_get_num_programs(axis)
    return tl.tensor(count_handle, tl.int32)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Implicit Casting Utilities
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
    """Choose the common type of two integer operand types.

    Signedness rules follow the "Usual arithmetic conversions" described on
    https://en.cppreference.com/w/c/language/conversion.

    Returns:
        One of `a_ty` / `b_ty` (never a new type object).
    """
    a_bits, b_bits = a_ty.int_bitwidth, b_ty.int_bitwidth
    a_sign, b_sign = a_ty.int_signedness, b_ty.int_signedness
    if a_sign == b_sign:
        # Same signedness: the strictly wider type wins, ties go to b_ty.
        if a_bits > b_bits:
            return a_ty
        return b_ty
    if a_sign == tl.dtype.SIGNEDNESS.UNSIGNED:
        # Unsigned a_ty wins unless the signed b_ty is strictly wider.
        if a_bits >= b_bits:
            return a_ty
        return b_ty
    if b_sign == tl.dtype.SIGNEDNESS.UNSIGNED:
        # Mirror image of the previous case.
        if b_bits >= a_bits:
            return b_ty
        return a_ty
    assert False
|
|
|
|
|
|
def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
    """Return the type in which a binary arithmetic op on `a_ty`/`b_ty` is computed.

    Args:
        a_ty: scalar type of the left operand.
        b_ty: scalar type of the right operand.
        div_or_mod: True when the operation is a division or a modulo.

    Returns:
        The common computation type.

    Raises:
        ValueError: for division/modulo on integers of different signedness.
        AssertionError: when the operands are neither floating nor integer.
    """
    # 1) if one operand is double, the other is implicitly converted to double
    if a_ty.is_fp64() or b_ty.is_fp64():
        return tl.float64
    # 2) if one operand is float, the other is implicitly converted to float
    if a_ty.is_fp32() or b_ty.is_fp32():
        return tl.float32
    # 3) if one operand is half, the other is implicitly converted to half
    #    unless we're doing / or %, which do not exist natively in PTX for fp16.
    #    Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
    if a_ty.is_fp16() or b_ty.is_fp16():
        return tl.float32 if div_or_mod else tl.float16
    # 4) return bf16 only if both operands are bf16 and the op is not / or %
    if a_ty.is_bf16() or b_ty.is_bf16():
        if div_or_mod:
            return tl.float32
        if a_ty.is_bf16() and b_ty.is_bf16():
            return tl.bfloat16
        return tl.float32
    if not a_ty.is_int() or not b_ty.is_int():
        assert False, "unsupported operand types for arithmetic"
    # 5) both operands are integer and undergo integer promotion
    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
        raise ValueError("Cannot use /, //, or % with " + repr(a_ty) + " and " + repr(b_ty)
                         + " because they have different signedness; "
                         "this is unlikely to result in a useful answer. Cast them to the same signedness.")
    return integer_promote_impl(a_ty, b_ty)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Binary Operators
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
    """Validate a binary op whose first operand type may be a pointer.

    Does nothing when `type_a` is not a pointer.

    Raises:
        IncompatibleTypeErrorImpl: if `type_a` is a pointer and pointers are
            not allowed there, or `type_b` is a different pointer type
            (T* op U*), or `type_b` is floating (T* op float).
    """
    if not type_a.is_ptr():
        return
    if not allow_ptr_a:
        raise IncompatibleTypeErrorImpl(type_a, type_b)
    # T* + U* with T != U, or T* + float
    if (type_b.is_ptr() and type_a != type_b) or type_b.is_floating():
        raise IncompatibleTypeErrorImpl(type_a, type_b)
|
|
|
|
|
|
def binary_op_type_checking_impl(lhs: tl.tensor,
                                 rhs: tl.tensor,
                                 builder: ir.builder,
                                 allow_lhs_ptr=False, allow_rhs_ptr=False,
                                 arithmetic_check=True, div_or_mod=False
                                 ) -> Tuple[tl.tensor, tl.tensor]:
    """Broadcast and implicitly cast the operands of a binary operator.

    Args:
        lhs, rhs: the operands.
        builder: IR builder used for broadcasts and casts.
        allow_lhs_ptr, allow_rhs_ptr: whether each side may be a pointer.
        arithmetic_check: when true and neither side is a pointer, both
            operands are cast to their common computation type.
        div_or_mod: forwarded to `computation_type_impl`.

    Returns:
        The (possibly broadcast and cast) `(lhs, rhs)` pair.
    """
    # implicit broadcasting
    lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
    lhs_elem = lhs.type.scalar
    rhs_elem = rhs.type.scalar
    # pointer legality is checked from both sides
    check_ptr_type_impl(lhs_elem, rhs_elem, allow_lhs_ptr)
    check_ptr_type_impl(rhs_elem, lhs_elem, allow_rhs_ptr)
    if not arithmetic_check or lhs_elem.is_ptr() or rhs_elem.is_ptr():
        return lhs, rhs
    # implicit typecasting to the common computation type
    common_ty = computation_type_impl(lhs_elem, rhs_elem, div_or_mod)
    lhs = cast(lhs, common_ty, builder)
    rhs = cast(rhs, common_ty, builder)
    return lhs, rhs
|
|
|
|
|
|
def add(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit `input + other` for pointer, floating or integer operands.

    Pointers are accepted on either side; `offset + ptr` is rewritten to
    `ptr + offset` before emitting the pointer addition.

    Returns:
        A tensor with the type of the (possibly swapped) first operand.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, True, True)
    input_scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar

    # offset + ptr  ->  ptr + offset
    if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
        input, other = other, input
        # Keep the cached scalar types in sync with the swapped operands;
        # otherwise the dispatch below still sees the integer type and emits
        # an integer add on a pointer handle.
        input_scalar_ty, other_scalar_ty = other_scalar_ty, input_scalar_ty

    # ptr + offset
    if input_scalar_ty.is_ptr():
        return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type)
    # float + float
    elif input_scalar_ty.is_floating():
        return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
    # int + int
    elif input_scalar_ty.is_int():
        return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
    assert False
|
|
|
|
|
|
def sub(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit `input - other`.

    Only the left operand may be a pointer; `ptr - offset` is emitted as a
    pointer addition of the negated offset.
    """
    input, other = binary_op_type_checking_impl(input, other, builder,
                                                allow_lhs_ptr=True, allow_rhs_ptr=False)
    elem_ty = input.type.scalar
    # ptr - offset
    if elem_ty.is_ptr():
        negated = minus(other, builder)
        return tl.tensor(builder.create_addptr(input.handle, negated.handle), input.type)
    # float - float
    if elem_ty.is_floating():
        diff = builder.create_fsub(input.handle, other.handle)
        return tl.tensor(diff, input.type)
    # int - int
    if elem_ty.is_int():
        diff = builder.create_sub(input.handle, other.handle)
        return tl.tensor(diff, input.type)
    assert False
|
|
|
|
|
|
def mul(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit `input * other` for floating or integer operands (no pointers)."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float * float
    if elem_ty.is_floating():
        product = builder.create_fmul(input.handle, other.handle)
        return tl.tensor(product, input.type)
    # int * int
    if elem_ty.is_int():
        product = builder.create_mul(input.handle, other.handle)
        return tl.tensor(product, input.type)
    assert False
|
|
|
|
|
|
def truediv(input: tl.tensor,
            other: tl.tensor,
            builder: ir.builder) -> tl.tensor:
    """Emit the floating-point division `input / other`.

    Integer operands are converted to a floating type first (int / int is
    computed in float32); for two floating operands, the one with the
    narrower mantissa is cast to the other's type.
    """
    input, other = binary_op_type_checking_impl(input, other, builder,
                                                allow_lhs_ptr=False, allow_rhs_ptr=False,
                                                arithmetic_check=True, div_or_mod=True)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if lhs_ty.is_floating() and rhs_ty.is_int():
        # float / int
        other = cast(other, lhs_ty, builder)
    elif lhs_ty.is_int() and rhs_ty.is_floating():
        # int / float
        input = cast(input, rhs_ty, builder)
    elif lhs_ty.is_int() and rhs_ty.is_int():
        # int / int (cast to tl.float32)
        input = cast(input, tl.float32, builder)
        other = cast(other, tl.float32, builder)
    elif lhs_ty.is_floating() and rhs_ty.is_floating():
        # float / float: the wider mantissa decides the computation type
        if lhs_ty.fp_mantissa_width > rhs_ty.fp_mantissa_width:
            other = cast(other, lhs_ty, builder)
        else:
            input = cast(input, rhs_ty, builder)
    else:
        # unreachable
        assert False
    quotient = builder.create_fdiv(input.handle, other.handle)
    return tl.tensor(quotient, input.type)
|
|
|
|
|
|
def floordiv(input: tl.tensor,
             other: tl.tensor,
             builder: ir.builder) -> tl.tensor:
    """Emit the integer division `input // other`.

    Both operands must be integers; the signed or unsigned division op is
    chosen from the promoted integer type.
    """
    input, other = binary_op_type_checking_impl(input, other, builder,
                                                allow_lhs_ptr=False, allow_rhs_ptr=False,
                                                arithmetic_check=True, div_or_mod=True)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if lhs_ty.is_int() and rhs_ty.is_int():
        common_ty = integer_promote_impl(lhs_ty, rhs_ty)
        input = cast(input, common_ty, builder)
        other = cast(other, common_ty, builder)
        if common_ty.is_int_signed():
            div_op = builder.create_sdiv
        else:
            div_op = builder.create_udiv
        return tl.tensor(div_op(input.handle, other.handle), input.type)
    assert False
|
|
|
|
|
|
def fdiv(input: tl.tensor,
         other: tl.tensor,
         ieee_rounding: bool,
         builder: ir.builder) -> tl.tensor:
    """Emit a floating-point division of two floating operands.

    `ieee_rounding` is accepted but not used by this implementation.

    Raises:
        ValueError: if either operand does not have a floating scalar type.
    """
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_floating() and rhs_ty.is_floating()):
        raise ValueError("both operands of fdiv must have floating scalar type")
    # Broadcast only: the arithmetic check (implicit casting) is disabled.
    input, other = binary_op_type_checking_impl(input, other, builder,
                                                allow_lhs_ptr=False, allow_rhs_ptr=False,
                                                arithmetic_check=False, div_or_mod=True)
    quotient = builder.create_fdiv(input.handle, other.handle)
    return tl.tensor(quotient, input.type)
|
|
|
|
|
|
def mod(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit `input % other` for floating or integer operands.

    The floating case is computed as `input - floor(input / other) * other`;
    the integer case uses the signed or unsigned remainder op.

    Raises:
        ValueError: if integer operands have different signedness.
    """
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
    scalar_ty = input.type.scalar
    other_scalar_ty = other.type.scalar
    # float % float
    if scalar_ty.is_floating():
        # input - input.div(other, rounding_mode="floor") * other
        quotient = floor(fdiv(input, other, False, builder), builder)
        return sub(input, mul(quotient, other, builder), builder)
    # int % int
    elif scalar_ty.is_int():
        if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
            # The original message ran "signedness;" and "this" together.
            raise ValueError("Cannot mod " + repr(scalar_ty) + " by " + repr(other_scalar_ty) + " "
                             "because they have different signedness; "
                             "this is unlikely to result in a useful answer. Cast them to the same signedness.")
        if scalar_ty.is_int_signed():
            return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
        else:
            return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
    assert False
|
|
|
|
##############
|
|
# bitwise ops
|
|
##############
|
|
|
|
|
|
def bitwise_op_type_checking_impl(input: tl.tensor,
                                  other: tl.tensor,
                                  builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    """Broadcast two integer operands and cast them to their promoted type.

    Raises:
        IncompatibleTypeErrorImpl: if either operand is not an integer.
    """
    # Broadcast and pointer checks only; integer promotion is done below.
    input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
    lhs_ty = input.type.scalar
    rhs_ty = other.type.scalar
    if not (lhs_ty.is_int() and rhs_ty.is_int()):
        raise IncompatibleTypeErrorImpl(lhs_ty, rhs_ty)
    promoted_ty = integer_promote_impl(lhs_ty, rhs_ty)
    # Cast only the side(s) whose type differs from the promoted type.
    if promoted_ty != lhs_ty:
        input = cast(input, promoted_ty, builder)
    if promoted_ty != rhs_ty:
        other = cast(other, promoted_ty, builder)
    return input, other
|
|
|
|
|
|
def and_(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    """Emit the bitwise AND of two integer operands after type checking."""
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_and(lhs.handle, rhs.handle), lhs.type)
|
|
|
|
|
|
def or_(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit the bitwise OR of two integer operands after type checking."""
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_or(lhs.handle, rhs.handle), lhs.type)
|
|
|
|
|
|
def xor_(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    """Emit the bitwise XOR of two integer operands after type checking."""
    lhs, rhs = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_xor(lhs.handle, rhs.handle), lhs.type)
|
|
|
|
|
|
def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit a logical AND: bitcast non-int1 operands to int1, then `and_`."""
    as_bool = []
    for operand in (input, other):
        if not operand.type.is_int1():
            operand = bitcast(operand, tl.dtype("int1"), builder)
        as_bool.append(operand)
    return and_(as_bool[0], as_bool[1], builder)
|
|
|
|
|
|
def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit a logical OR: bitcast non-int1 operands to int1, then `or_`."""
    as_bool = []
    for operand in (input, other):
        if not operand.type.is_int1():
            operand = bitcast(operand, tl.dtype("int1"), builder)
        as_bool.append(operand)
    return or_(as_bool[0], as_bool[1], builder)
|
|
|
|
|
|
def lshr(input: tl.tensor,
         other: tl.tensor,
         builder: ir.builder) -> tl.tensor:
    """Emit `builder.create_lshr` on two type-checked integer operands."""
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_lshr(value.handle, amount.handle), value.type)
|
|
|
|
|
|
def shl(input: tl.tensor,
        other: tl.tensor,
        builder: ir.builder) -> tl.tensor:
    """Emit `builder.create_shl` on two type-checked integer operands."""
    value, amount = bitwise_op_type_checking_impl(input, other, builder)
    return tl.tensor(builder.create_shl(value.handle, amount.handle), value.type)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Unary Operators
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def plus(input: tl.tensor) -> tl.tensor:
    """Unary plus: return the operand itself, unchanged."""
    return input
|
|
|
|
|
|
def minus(input: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    """Unary minus, emitted as `0 - input`.

    Raises:
        ValueError: if the operand's scalar type is a pointer.
    """
    elem_ty = input.type.scalar
    if elem_ty.is_ptr():
        raise ValueError("wrong type argument to unary minus (" + elem_ty.__repr__() + ")")
    # Null value of the operand's scalar type, used as the left-hand side.
    zero_handle = builder.get_null_value(elem_ty.to_ir(builder))
    zero = tl.tensor(zero_handle, elem_ty)
    return sub(zero, input, builder)
|
|
|
|
|
|
def invert(input: tl.tensor,
           builder: ir.builder) -> tl.tensor:
    """Bitwise NOT, emitted as `input ^ all_ones`.

    Args:
        input: integer operand.
        builder: IR builder (it is used through `get_all_ones_value` and
            passed on to `xor_`, so it is annotated as `ir.builder` like in
            the sibling operators).

    Raises:
        ValueError: if the operand's scalar type is a pointer or floating.
    """
    input_sca_ty = input.type.scalar
    if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
        raise ValueError("wrong type argument to unary invert (" + repr(input_sca_ty) + ")")
    # All-ones constant of the operand's scalar type.
    _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
    return xor_(input, _1, builder)
|
|
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Comparison Operators
|
|
# ===----------------------------------------------------------------------===//
|
|
def _bool_like(v: tl.tensor) -> tl.block_type:
    """Return the int1 type matching `v`: `tl.int1` for a non-block value,
    otherwise an int1 block type with `v`'s shape."""
    if v.type.is_block():
        return tl.block_type(tl.int1, v.type.shape)
    return tl.int1
|
|
|
|
|
|
def greater_than(input: tl.tensor,
                 other: tl.tensor,
                 builder: ir.builder) -> tl.tensor:
    """Emit `input > other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float > float
    if elem_ty.is_floating():
        flag = builder.create_fcmpOGT(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int > int: signed or unsigned comparison depending on the operand type
    if elem_ty.is_int():
        if elem_ty.is_int_signed():
            compare = builder.create_icmpSGT
        else:
            compare = builder.create_icmpUGT
        return tl.tensor(compare(input.handle, other.handle), _bool_like(input))
    assert False
|
|
|
|
|
|
def greater_equal(input: tl.tensor,
                  other: tl.tensor,
                  builder: ir.builder) -> tl.tensor:
    """Emit `input >= other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float >= float
    if elem_ty.is_floating():
        flag = builder.create_fcmpOGE(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int >= int: signed or unsigned comparison depending on the operand type
    if elem_ty.is_int():
        if elem_ty.is_int_signed():
            compare = builder.create_icmpSGE
        else:
            compare = builder.create_icmpUGE
        return tl.tensor(compare(input.handle, other.handle), _bool_like(input))
    assert False
|
|
|
|
|
|
def less_than(input: tl.tensor,
              other: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    """Emit `input < other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float < float
    if elem_ty.is_floating():
        flag = builder.create_fcmpOLT(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int < int: signed or unsigned comparison depending on the operand type
    if elem_ty.is_int():
        if elem_ty.is_int_signed():
            compare = builder.create_icmpSLT
        else:
            compare = builder.create_icmpULT
        return tl.tensor(compare(input.handle, other.handle), _bool_like(input))
    assert False
|
|
|
|
|
|
def less_equal(input: tl.tensor,
               other: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit `input <= other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float <= float
    if elem_ty.is_floating():
        flag = builder.create_fcmpOLE(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int <= int: signed or unsigned comparison depending on the operand type
    if elem_ty.is_int():
        if elem_ty.is_int_signed():
            compare = builder.create_icmpSLE
        else:
            compare = builder.create_icmpULE
        return tl.tensor(compare(input.handle, other.handle), _bool_like(input))
    assert False
|
|
|
|
|
|
def equal(input: tl.tensor,
          other: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    """Emit `input == other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float == float
    if elem_ty.is_floating():
        flag = builder.create_fcmpOEQ(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int == int
    if elem_ty.is_int():
        flag = builder.create_icmpEQ(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    assert False
|
|
|
|
|
|
def not_equal(input: tl.tensor,
              other: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    """Emit `input != other`; the result has the int1 type from `_bool_like`."""
    input, other = binary_op_type_checking_impl(input, other, builder)
    elem_ty = input.type.scalar
    # float != float
    if elem_ty.is_floating():
        flag = builder.create_fcmpUNE(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    # int != int
    if elem_ty.is_int():
        flag = builder.create_icmpNE(input.handle, other.handle)
        return tl.tensor(flag, _bool_like(input))
    assert False
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Block Creation
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
    """Create a 1-D int32 block of `end - start` elements via
    `builder.create_make_range(start, end)`.

    Raises:
        ValueError: if `start` or `end` is not a Python `int`.
    """
    if not (isinstance(start, int) and isinstance(end, int)):
        raise ValueError("arange's arguments must be of type tl.constexpr")

    range_ty = tl.block_type(tl.int32, [end - start])
    return tl.tensor(builder.create_make_range(start, end), range_ty)
|
|
|
|
|
|
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    """Create a block of the given `shape` and `dtype` by splatting the
    dtype's null value."""
    zero_handle = builder.get_null_value(dtype.to_ir(builder))
    block_ty = tl.block_type(dtype, shape)
    splat_handle = builder.create_splat(zero_handle, shape)
    return tl.tensor(splat_handle, block_ty)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Shape Manipulation
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def view(input: tl.tensor,
         dst_shape: List[int],
         builder: ir.builder) -> tl.tensor:
    """Reinterpret `input` as a block of shape `dst_shape`.

    Raises:
        ValueError: if the product of `dst_shape` differs from
            `input.type.numel`.
    """
    # TODO: disable when TritonToTritonGPU handles views properly

    # assert len(input.shape) == len(dst_shape)
    requested_numel = 1
    for extent in dst_shape:
        requested_numel = requested_numel * extent
    if input.type.numel != requested_numel:
        raise ValueError("cannot view block of different shape")
    view_ty = tl.block_type(input.type.scalar, dst_shape)
    view_handle = builder.create_view(input.handle, dst_shape)
    return tl.tensor(view_handle, view_ty)
|
|
|
|
|
|
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Insert a dimension of size 1 at position `axis` of `input`'s shape."""
    # Work on a copy so the source type's shape is left untouched.
    new_shape = list(input.type.shape)
    new_shape.insert(axis, 1)
    expanded_ty = tl.block_type(input.type.scalar, new_shape)
    return tl.tensor(builder.create_expand_dims(input.handle, axis), expanded_ty)
|
|
|
|
|
|
def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor:
    """Concatenate two blocks into a 1-D block of the combined length.

    Asserts that `can_reorder` is true and that `lhs` is one-dimensional.
    The result uses `lhs`'s scalar type.
    """
    assert can_reorder, "current implementation of `cat` always may reorder elements"
    assert len(lhs.shape) == 1
    total_len = lhs.shape[0] + rhs.shape[0]
    joined_ty = tl.block_type(lhs.type.scalar, [total_len])
    return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), joined_ty)
|
|
|
|
|
|
def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Transpose a 2-D block; the result shape is `[shape[1], shape[0]]`.

    Raises:
        ValueError: if `input` is not two-dimensional.
    """
    if len(input.shape) != 2:
        raise ValueError("Only 2D tensors can be transposed")
    swapped_shape = [input.shape[1], input.shape[0]]
    swapped_ty = tl.block_type(input.type.scalar, swapped_shape)
    return tl.tensor(builder.create_trans(input.handle), swapped_ty)
|
|
|
|
|
|
def broadcast_impl_shape(input: tl.tensor,
                         shape: List[int],
                         builder: ir.builder) -> tl.tensor:
    """Broadcast `input` to `shape`.

    A non-block value is splatted to `shape`. A block must already have the
    same rank; each dimension must either match or have size 1 in the source.

    Raises:
        ValueError: on rank mismatch or an incompatible dimension.
    """
    if not input.type.is_block():
        splat_ty = tl.block_type(input.type, shape)
        return tl.tensor(builder.create_splat(input.handle, shape), splat_ty)
    src_shape = input.type.get_block_shapes()
    if len(src_shape) != len(shape):
        raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
    if shape == src_shape:
        # Already the requested shape: hand the input back untouched.
        return input
    for i, src_extent in enumerate(src_shape):
        if shape[i] != src_extent and src_extent != 1:
            raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
                             f" must match the existing size ({src_shape[i]}) at non-singleton dimension"
                             f" {i}: {src_shape}, {shape}")
    broadcast_ty = tl.block_type(input.type.scalar, shape)
    return tl.tensor(builder.create_broadcast(input.handle, shape), broadcast_ty)
|
|
|
|
|
|
def broadcast_impl_value(lhs: tl.tensor,
                         rhs: tl.tensor,
                         builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
    """Make two operands shape-compatible for a binary operation.

    * (scalar, scalar): both are returned unchanged.
    * (block, scalar) / (scalar, block): the scalar is splatted to the
      block's shape.
    * (block, block): the lower-rank operand gets size-1 axes appended until
      the ranks match, then size-1 dimensions are broadcast.

    Returns:
        The `(lhs, rhs)` pair (the annotation previously claimed a single
        tensor although a 2-tuple has always been returned).

    Raises:
        ValueError: if a dimension differs and neither side has size 1.
    """
    lhs_ty = lhs.type
    rhs_ty = rhs.type
    lhs_is_block = lhs_ty.is_block()
    rhs_is_block = rhs_ty.is_block()

    # (scalar, scalar) => returns original values
    if not lhs_is_block and not rhs_is_block:
        return lhs, rhs

    # make_shape_compatible(block, scalar)
    if not rhs_is_block:
        rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
        rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
        return lhs, rhs

    # make_shape_compatible(scalar, block)
    if not lhs_is_block:
        lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
        lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
        return lhs, rhs

    # make_shape_compatible(block, block)
    lhs_shape = lhs_ty.get_block_shapes()
    rhs_shape = rhs_ty.get_block_shapes()

    if len(lhs_shape) < len(rhs_shape):
        # Add new axes to lhs
        for dim in range(len(lhs_shape), len(rhs_shape)):
            lhs = tl.tensor(builder.create_expand_dims(lhs.handle, dim),
                            tl.block_type(lhs_ty.scalar, lhs_shape + [1]))
            lhs_ty = lhs.type
            lhs_shape = lhs_ty.get_block_shapes()
    elif len(rhs_shape) < len(lhs_shape):
        # Add new axes to rhs
        for dim in range(len(rhs_shape), len(lhs_shape)):
            rhs = tl.tensor(builder.create_expand_dims(rhs.handle, dim),
                            tl.block_type(rhs_ty.scalar, rhs_shape + [1]))
            rhs_ty = rhs.type
            rhs_shape = rhs_ty.get_block_shapes()
    assert len(rhs_shape) == len(lhs_shape)

    # Per dimension: a size-1 side adopts the other side's extent.
    ret_shape = []
    for i, (left, right) in enumerate(zip(lhs_shape, rhs_shape)):
        if left == 1:
            ret_shape.append(right)
        elif right == 1 or left == right:
            ret_shape.append(left)
        else:
            raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
                             "at index " + str(i) + ": " + str(left) + " and " + str(right))
    if lhs_shape != ret_shape:
        ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
        lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
    if rhs_shape != ret_shape:
        ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
        rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
    return lhs, rhs
|
|
|
|
#######
|
|
# cast
|
|
#######
|
|
|
|
|
|
def bitcast(input: tl.tensor,
            dst_ty: tl.dtype,
            builder: ir.builder) -> tl.tensor:
    """Reinterpret the bits of `input` as `dst_ty`.

    For a block input, `dst_ty` is treated as the element type and the block
    shape is kept. If a pointer type is involved the request is forwarded to
    `cast`.

    Returns:
        `input` itself when the types already match, otherwise a new tensor.

    Raises:
        ValueError: if source and destination scalar types differ in
            `primitive_bitwidth`.
    """
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input
    src_sca_ty = src_ty.scalar
    dst_sca_ty = dst_ty.scalar
    if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
        return cast(input, dst_ty, builder)
    # Bitcast
    src_bits = src_sca_ty.primitive_bitwidth
    dst_bits = dst_sca_ty.primitive_bitwidth
    if src_bits != dst_bits:
        # The original message lacked the space before "to".
        raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
                         "data-type of size " + str(dst_bits))
    return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
                     dst_ty)
|
|
|
|
|
|
def cast(input: tl.tensor,
         dst_ty: tl.dtype,
         builder: ir.builder) -> tl.tensor:
    """Convert `input` to `dst_ty`, emitting the matching conversion op.

    For a block input, `dst_ty` is treated as the element type and the block
    shape is kept. The conversion is chosen from the source/destination
    scalar types, in the order of the checks below.

    Returns:
        `input` itself when the types already match, otherwise a new tensor.

    Raises:
        AssertionError: if no conversion rule applies.
    """
    src_ty = input.type
    if src_ty.is_block():
        dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
    if src_ty == dst_ty:
        return input

    src_elem = src_ty.scalar
    dst_elem = dst_ty.scalar

    def convert(create_op, *extra):
        # Emit `create_op(input, <IR type of dst_ty>, *extra)` wrapped as dst_ty.
        return tl.tensor(create_op(input.handle, dst_ty.to_ir(builder), *extra), dst_ty)

    def is_nonzero():
        # Conversion to bool is `input != 0`, with 0 in the input's own dtype.
        zero_ir_ty = input.dtype.to_ir(builder)
        zero = tl.tensor(builder.get_null_value(zero_ir_ty), input.dtype)
        return not_equal(input, zero, builder)

    # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
    if (src_elem.is_customized_floating() and dst_elem.is_floating()) or \
       (src_elem.is_floating() and dst_elem.is_customized_floating()):
        return convert(builder.create_fp_to_fp)

    # fp16 / bf16 source with a non-fp32 destination: route through fp32
    if (src_elem.is_fp16() or src_elem.is_bf16()) and not dst_elem.is_fp32():
        via_fp32 = cast(input, tl.float32, builder)
        return cast(via_fp32, dst_elem, builder)

    if src_elem.is_floating() and dst_elem.is_floating():
        src_bits = src_elem.primitive_bitwidth
        dst_bits = dst_elem.primitive_bitwidth
        # Standard floating types' casting: truncation
        #   fp64 => fp32, fp16, bf16
        #   fp32 => fp16, bf16
        if src_bits > dst_bits:
            return convert(builder.create_fp_trunc)
        # Standard floating types' casting: extension
        #   fp32 => fp64
        #   fp16 => fp32, fp64
        #   bf16 => fp32, fp64
        if src_bits < dst_bits:
            return convert(builder.create_fp_ext)

    # Casting between integer types
    if src_elem.is_int() and dst_elem.is_int() and \
       (src_elem.int_bitwidth != dst_elem.int_bitwidth or src_elem.int_signedness != dst_elem.int_signedness):
        sign_extend = src_elem.is_int_signed() and not src_elem.is_bool()
        if dst_elem.is_bool():
            return is_nonzero()
        return convert(builder.create_int_cast, sign_extend)

    # Casting standard floating types to integer types
    if src_elem.is_standard_floating() and dst_elem.is_int():
        if dst_elem.is_bool():
            return is_nonzero()
        if dst_elem.is_int_signed():
            return convert(builder.create_fp_to_si)
        return convert(builder.create_fp_to_ui)

    # Casting integer types to standard floating types
    if src_elem.is_int() and dst_elem.is_standard_floating():
        if src_elem.is_bool() or not src_elem.is_int_signed():
            return convert(builder.create_ui_to_fp)
        return convert(builder.create_si_to_fp)

    # Casting pointer types to integer types (only 64-bit and 1-bit targets)
    if src_elem.is_ptr() and dst_elem.is_int():
        target_bits = dst_elem.int_bitwidth
        if target_bits == 64:
            return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
                             dst_ty)
        if target_bits == 1:
            # ptr -> bool is `int64(ptr) != 0`
            return not_equal(cast(input, tl.int64, builder),
                             tl.tensor(builder.get_int64(0), tl.int64),
                             builder)

    # Casting integer types to pointer types
    if src_elem.is_int() and dst_elem.is_ptr():
        return convert(builder.create_int_to_ptr)

    # Casting pointer types to pointer types
    if src_elem.is_ptr() and dst_elem.is_ptr():
        return convert(builder.create_bitcast)

    assert False, f'cannot cast {input} to {dst_ty}'
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Memory Operators
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def load(ptr: tl.tensor,
         mask: Optional[tl.tensor],
         other: Optional[tl.tensor],
         cache_modifier: str,
         eviction_policy: str,
         is_volatile: bool,
         builder: ir.builder) -> tl.tensor:
    """Emit a (possibly masked) load through `ptr`.

    Args:
        ptr: pointer, or block of pointers, to load from.
        mask: optional mask; broadcast to the pointer block's shape.
        other: optional fill value; broadcast and cast to the element type.
        cache_modifier: "", ".ca" or ".cg".
        eviction_policy: "", "evict_last" or "evict_first".
        is_volatile: forwarded to the builder.
        builder: IR builder.

    Raises:
        ValueError: if `ptr` is not a pointer, the cache modifier or eviction
            policy is unknown, or `other` is given without `mask`.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())
    if ptr.type.is_block():
        block_shape = ptr.type.get_block_shapes
        if mask:
            mask = broadcast_impl_shape(mask, block_shape(), builder)
        if other:
            other = broadcast_impl_shape(other, block_shape(), builder)

    ptr_ty = ptr.type.scalar
    elem_ty = ptr_ty.element_ty

    # treat bool* as tl.int8*
    if elem_ty == tl.int1:
        elem_ty = tl.int8
        ptr_ty = tl.pointer_type(elem_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    if other:
        other = cast(other, elem_ty, builder)

    # cache modifier
    cache = ir.CACHE_MODIFIER.NONE  # default
    if cache_modifier == ".ca":
        cache = ir.CACHE_MODIFIER.CA
    elif cache_modifier == ".cg":
        cache = ir.CACHE_MODIFIER.CG
    elif cache_modifier:
        raise ValueError(f"Cache modifier {cache_modifier} not supported")

    # eviction policy
    eviction = ir.EVICTION_POLICY.NORMAL  # default
    if eviction_policy == "evict_last":
        eviction = ir.EVICTION_POLICY.EVICT_LAST
    elif eviction_policy == "evict_first":
        eviction = ir.EVICTION_POLICY.EVICT_FIRST
    elif eviction_policy:
        raise ValueError(f"Eviction policy {eviction_policy} not supported")

    # The result is a block of elements for a pointer block, else one element.
    if ptr.type.is_block():
        dst_ty = tl.block_type(elem_ty, ptr.type.get_block_shapes())
    else:
        dst_ty = elem_ty

    if mask:
        fill_handle = other.handle if other else None
        loaded = builder.create_masked_load(ptr.handle, mask.handle, fill_handle,
                                            cache, eviction, is_volatile)
        return tl.tensor(loaded, dst_ty)
    if other:
        raise ValueError("`other` cannot be provided without `mask`")
    return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
                     dst_ty)
|
|
|
|
|
|
def store(ptr: tl.tensor,
          val: tl.tensor,
          mask: Optional[tl.tensor],
          builder: ir.builder) -> tl.tensor:
    """Emit a store (masked when `mask` is given) of `val` to `ptr`.

    :param ptr: tensor whose scalar type must be a pointer.
    :param val: value to store; it is cast to the pointee element type.
    :param mask: optional mask whose scalar type must be boolean.
    :param builder: IR builder used to emit instructions.
    :return: a `tl.void`-typed tensor wrapping the emitted store.
    :raises ValueError: if `ptr` is not a pointer or `mask` is not boolean.
    """
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())

    # Block shapes are only queried for block pointers, for the value *and*
    # the mask (same guard as in `load`); a scalar pointer has no block shape
    # to broadcast to.
    if ptr.type.is_block():
        val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
        if mask:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)

    ptr_ty = ptr.type.scalar
    elt_ty = ptr_ty.element_ty
    # treat bool* as tl.int8*
    if elt_ty == tl.int1:
        elt_ty = tl.int8
        ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
        ptr = cast(ptr, ptr_ty, builder)

    # cast to target data-type
    val = cast(val, elt_ty, builder)

    if not mask:
        return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void)
    if not mask.type.scalar.is_bool():
        raise ValueError("Mask must have boolean scalar type")
    return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void)
|
|
|
|
#########
|
|
# atomic
|
|
#########
|
|
|
|
|
|
def atomic_cas(ptr: tl.tensor,
               cmp: tl.tensor,
               val: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic compare-and-swap on `ptr` with `cmp` and `val`.

    :return: tensor wrapping the builder result, typed like `val`.
    :raises ValueError: if the pointee's primitive bit width is not 16, 32 or 64.
    """
    pointee_ty = ptr.type.scalar.element_ty
    supported_widths = (16, 32, 64)
    if pointee_ty.primitive_bitwidth not in supported_widths:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    cas_handle = builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle)
    return tl.tensor(cas_handle, val.type)
|
|
|
|
|
|
def atom_red_typechecking_impl(ptr: tl.tensor,
                               val: tl.tensor,
                               mask: tl.tensor,
                               op: str,
                               builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
    """Validate and normalise the operands shared by the atomic RMW helpers.

    :param ptr: tensor whose scalar type must be a pointer.
    :param val: operand; broadcast to the pointer's block shape (block
        pointers only) and cast to the pointee element type.
    :param mask: optional mask; when falsy an all-true int1 mask is created,
        splatted to the pointer's block shape for block pointers.
    :param op: operation suffix ('add', 'max', ...) used for the fp16 check
        and in error messages.
    :param builder: IR builder used to emit instructions.
    :return: the tuple `(ptr, val, mask)` after normalisation.
    :raises ValueError: if `ptr` is not a pointer or the element type is not
        supported by `atomic_<op>`.
    """
    if not ptr.type.scalar.is_ptr():
        # Name the atomic op: this helper is not the store path.
        raise ValueError("Pointer argument of atomic_" + op + " instruction is " + ptr.type.__repr__())

    element_ty = ptr.type.scalar.element_ty
    # Compare dtypes by equality, like the membership test below does, so an
    # equal-but-distinct dtype object is not let through.
    if element_ty == tl.float16 and op != 'add':
        raise ValueError("atomic_" + op + " does not support fp16")
    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
        # str() is required: concatenating the dtype object itself raises
        # TypeError and hides the intended ValueError.
        raise ValueError("atomic_" + op + " does not support " + str(element_ty))

    if ptr.type.is_block():
        if mask:
            mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
        if val:
            val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
    val = cast(val, ptr.type.scalar.element_ty, builder)

    if not mask:
        # No mask supplied: build a constant-true mask matching the pointer.
        mask_ir = builder.get_int1(True)
        mask_ty = tl.int1
        if ptr.type.is_block():
            mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
            mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
        mask = tl.tensor(mask_ir, mask_ty)
    return ptr, val, mask
|
|
|
|
|
|
def atomic_max(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic maximum of `val` into `ptr` under `mask`.

    Integer operands use a single atomic RMW (MAX when signed, UMAX
    otherwise). Other operands are bit-cast to int32 and handled with two
    masked integer atomics, selected by the sign of `val`.

    NOTE(review): the non-integer path hard-codes int32 / `get_float32` and
    its result tensors carry `i_val`'s type - confirm this matches what
    callers expect for float inputs.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
    sca_ty = val.type.scalar

    # direct call to atomic_max for integers; signedness picks the opcode
    if sca_ty.is_int():
        int_op = ir.ATOMIC_OP.MAX if sca_ty.is_int_signed() else ir.ATOMIC_OP.UMAX
        int_handle = builder.create_atomic_rmw(int_op, ptr.handle, val.handle, mask.handle)
        return tl.tensor(int_handle, val.type)

    # for float
    # return atomic_smax(i_ptr, i_val) if val >= 0
    # return atomic_umin(i_ptr, i_val) if val < 0
    i_val = bitcast(val, tl.int32, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)

    pos_mask = and_(mask, pos, builder)
    pos_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, pos_mask.handle)
    pos_ret = tl.tensor(pos_handle, i_val.type)

    neg_mask = and_(mask, neg, builder)
    neg_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, neg_mask.handle)
    neg_ret = tl.tensor(neg_handle, i_val.type)

    return where(pos, pos_ret, neg_ret, builder)
|
|
|
|
|
|
def atomic_min(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic minimum of `val` into `ptr` under `mask`.

    Integer operands use a single atomic RMW (MIN when signed, UMIN
    otherwise). Other operands are bit-cast to int32 and handled with two
    masked integer atomics, selected by the sign of `val`.

    NOTE(review): the non-integer path hard-codes int32 / `get_float32` and
    its result tensors carry `i_val`'s type - confirm this matches what
    callers expect for float inputs.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
    sca_ty = val.type.scalar

    # direct call to atomic_min for integers; signedness picks the opcode
    if sca_ty.is_int():
        int_op = ir.ATOMIC_OP.MIN if sca_ty.is_int_signed() else ir.ATOMIC_OP.UMIN
        int_handle = builder.create_atomic_rmw(int_op, ptr.handle, val.handle, mask.handle)
        return tl.tensor(int_handle, val.type)

    # for float
    # return atomic_smin(i_ptr, i_val) if val >= 0
    # return atomic_umax(i_ptr, i_val) if val < 0
    i_val = bitcast(val, tl.int32, builder)
    i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
    pos = greater_equal(val, tl.tensor(builder.get_float32(0), sca_ty), builder)
    neg = less_than(val, tl.tensor(builder.get_float32(0), sca_ty), builder)

    pos_mask = and_(mask, pos, builder)
    pos_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, pos_mask.handle)
    pos_ret = tl.tensor(pos_handle, i_val.type)

    neg_mask = and_(mask, neg, builder)
    neg_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle, neg_mask.handle)
    neg_ret = tl.tensor(neg_handle, i_val.type)

    return where(pos, pos_ret, neg_ret, builder)
|
|
|
|
|
|
def atomic_add(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic add of `val` into `ptr` under `mask`.

    Uses FADD when the value's scalar type is floating point, ADD otherwise.
    The result is typed like the normalised `val`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
    sca_ty = val.type.scalar
    if sca_ty.is_floating():
        add_op = ir.ATOMIC_OP.FADD
    else:
        add_op = ir.ATOMIC_OP.ADD
    rmw_handle = builder.create_atomic_rmw(add_op, ptr.handle, val.handle, mask.handle)
    return tl.tensor(rmw_handle, val.type)
|
|
|
|
|
|
def atomic_and(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic AND of `val` into `ptr` under `mask`.

    Operands are first normalised by `atom_red_typechecking_impl`; the result
    is typed like the normalised `val`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle)
    return tl.tensor(rmw_handle, val.type)
|
|
|
|
|
|
def atomic_or(ptr: tl.tensor,
              val: tl.tensor,
              mask: tl.tensor,
              builder: ir.builder) -> tl.tensor:
    """Emit an atomic OR of `val` into `ptr` under `mask`.

    Operands are first normalised by `atom_red_typechecking_impl`; the result
    is typed like the normalised `val`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle)
    return tl.tensor(rmw_handle, val.type)
|
|
|
|
|
|
def atomic_xor(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               builder: ir.builder) -> tl.tensor:
    """Emit an atomic XOR of `val` into `ptr` under `mask`.

    Operands are first normalised by `atom_red_typechecking_impl`; the result
    is typed like the normalised `val`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle)
    return tl.tensor(rmw_handle, val.type)
|
|
|
|
|
|
def atomic_xchg(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                builder: ir.builder) -> tl.tensor:
    """Emit an atomic exchange of `val` with the contents of `ptr` under `mask`.

    Operands are first normalised by `atom_red_typechecking_impl`; the result
    is typed like the normalised `val`.
    """
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    rmw_handle = builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle)
    return tl.tensor(rmw_handle, val.type)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Linear Algebra
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
|
|
def dot(lhs: tl.tensor,
        rhs: tl.tensor,
        allow_tf32: bool,
        builder: ir.builder) -> tl.tensor:
    """Emit a dot product of two 2-D blocks with a zero accumulator.

    Both operands must be blocks of rank 2 whose inner dimensions agree, and
    each of the three dimensions involved must be at least 16. The
    accumulator and result are int32 when `lhs` has an integer scalar type
    and float32 otherwise.

    :param allow_tf32: forwarded unchanged to `builder.create_dot`.
    :return: block tensor of shape `[lhs rows, rhs columns]`.
    """
    assert lhs.type.is_block() and rhs.type.is_block()
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
    assert lhs.shape[1].value == rhs.shape[0].value
    assert (lhs.shape[0].value >= 16
            and lhs.shape[1].value >= 16
            and rhs.shape[1].value >= 16), "small blocks not supported!"

    # Pick the scalar zero and the accumulator element type together.
    if lhs.type.scalar.is_int():
        zero, acc_scalar_ty = builder.get_int32(0), tl.int32
    else:
        zero, acc_scalar_ty = builder.get_float32(0), tl.float32

    rows = lhs.type.shape[0]
    cols = rhs.type.shape[1]
    acc_init = builder.create_splat(zero, [rows, cols])
    out_ty = tl.block_type(acc_scalar_ty, [rows, cols])
    dot_handle = builder.create_dot(lhs.handle, rhs.handle, acc_init, allow_tf32)
    return tl.tensor(dot_handle, out_ty)
|
|
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Indexing
|
|
# ===----------------------------------------------------------------------===//
|
|
|
|
def where(condition: tl.tensor,
          x: tl.tensor,
          y: tl.tensor,
          builder: ir.builder) -> tl.tensor:
    """Emit a select that picks from `x` where `condition` holds, else from `y`.

    The condition is cast to int1. When it is a block, the condition and both
    values are mutually broadcast before the values go through
    `binary_op_type_checking_impl`; when it is not a block, it is broadcast
    against the type-checked `x` afterwards.

    :return: tensor typed like the type-checked `x`.
    """
    cond = cast(condition, tl.int1, builder)
    lhs, rhs = x, y

    if cond.type.is_block():
        # The order matters: the condition is aligned with lhs, lhs with rhs,
        # then the condition is re-aligned with the possibly changed lhs.
        cond, lhs = broadcast_impl_value(cond, lhs, builder)
        lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
        cond, lhs = broadcast_impl_value(cond, lhs, builder)

    lhs, rhs = binary_op_type_checking_impl(lhs, rhs, builder, True, True)

    if not cond.type.is_block():
        cond, _ignored = broadcast_impl_value(cond, lhs, builder)

    out_ty = lhs.type
    select_handle = builder.create_select(cond.handle, lhs.handle, rhs.handle)
    return tl.tensor(select_handle, out_ty)
|
|
|
|
# ===----------------------------------------------------------------------===//
|
|
# Reductions
|
|
# ===----------------------------------------------------------------------===
|
|
|
|
|
|
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
                FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
    """Emit a reduction of `input` along `axis`.

    :param input: tensor to reduce.
    :param axis: index of the dimension that is removed from the shape.
    :param builder: IR builder used to emit instructions.
    :param name: human-readable reduction name, used in the error message.
    :param FLOAT_OP: reduce op used when the input scalar type is floating.
    :param INT_OP: reduce op used when the input scalar type is an integer;
        MIN/MAX/ARGMIN/ARGMAX are swapped for their unsigned variants when
        the input is unsigned.
    :return: tensor whose element type is the type that was actually reduced
        (after widening); a scalar type when no dimension remains.
    :raises AssertionError: if the scalar type is neither floating nor integer.
    """
    scalar_ty = input.type.scalar
    # Element type of the value that is actually reduced; it must follow the
    # widening casts below, otherwise the returned tensor's declared type
    # disagrees with the emitted value.
    out_scalar_ty = scalar_ty

    # input is extended to 32-bits if necessary
    # this increases numerical accuracy and can be done pretty much for free
    # on GPUs
    if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
        input = cast(input, tl.int32, builder)
        out_scalar_ty = tl.int32

    # hardware doesn't support FMAX, FMIN, CMP for bfloat16
    # (dtype equality rather than identity, as elsewhere in this file)
    if scalar_ty == tl.bfloat16:
        input = cast(input, tl.float32, builder)
        out_scalar_ty = tl.float32

    # choose the right unsigned operation
    if scalar_ty.is_int_unsigned():
        int_op_to_uint = {
            ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
            ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
            ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
            ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
        }
        INT_OP = int_op_to_uint.get(INT_OP, INT_OP)

    # get result type: the input shape with `axis` dropped
    # NOTE(review): ARG* reductions presumably yield indices rather than
    # values of the input type; the result type does not account for that -
    # confirm against the builder.
    ret_shape = [dim for i, dim in enumerate(input.type.shape) if i != axis]
    if ret_shape:
        res_ty = tl.block_type(out_scalar_ty, ret_shape)
    else:
        # 0d-tensor -> scalar
        res_ty = out_scalar_ty

    if scalar_ty.is_floating():
        return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
    if scalar_ty.is_int():
        return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
    # Raised explicitly: a bare `assert False` is stripped under `python -O`
    # and the function would silently return None.
    raise AssertionError("reduction '" + name + "' does not support " + repr(scalar_ty))
|
|
|
|
|
|
def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce `input` along `axis` with FMIN (floating) or MIN (integer)."""
    return reduce_impl(input, axis, builder, name="min",
                       FLOAT_OP=ir.REDUCE_OP.FMIN, INT_OP=ir.REDUCE_OP.MIN)
|
|
|
|
|
|
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce `input` along `axis` with ARGFMIN (floating) or ARGMIN (integer)."""
    return reduce_impl(input, axis, builder, name="argmin",
                       FLOAT_OP=ir.REDUCE_OP.ARGFMIN, INT_OP=ir.REDUCE_OP.ARGMIN)
|
|
|
|
|
|
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce `input` along `axis` with FMAX (floating) or MAX (integer)."""
    return reduce_impl(input, axis, builder, name="max",
                       FLOAT_OP=ir.REDUCE_OP.FMAX, INT_OP=ir.REDUCE_OP.MAX)
|
|
|
|
|
|
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce `input` along `axis` with ARGFMAX (floating) or ARGMAX (integer)."""
    return reduce_impl(input, axis, builder, name="argmax",
                       FLOAT_OP=ir.REDUCE_OP.ARGFMAX, INT_OP=ir.REDUCE_OP.ARGMAX)
|
|
|
|
|
|
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce `input` along `axis` with FADD (floating) or ADD (integer)."""
    return reduce_impl(input, axis, builder, name="sum",
                       FLOAT_OP=ir.REDUCE_OP.FADD, INT_OP=ir.REDUCE_OP.ADD)
|
|
|
|
|
|
def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
    """Reduce an integer `input` along `axis` with XOR.

    :raises ValueError: if the input's scalar type is not an integer.
    """
    scalar_ty = input.type.scalar
    if not scalar_ty.is_int():
        raise ValueError("xor_sum only supported for integers")
    # Pass this reduction's own name, as every sibling reduction does, rather
    # than the "sum" label copied from `sum`.
    return reduce_impl(input, axis, builder, "xor_sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
|
|
|
|
|
|
# ===----------------------------------------------------------------------===
|
|
# Math
|
|
# ===----------------------------------------------------------------------===
|
|
|
|
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Type-check `x` and `y` as binary operands and forward them to `libdevice.mulhi`."""
    lhs, rhs = binary_op_type_checking_impl(x, y, builder)
    # Imported here, not at module level, exactly as in the original code.
    from . import libdevice
    result = libdevice.mulhi(lhs, rhs, _builder=builder)
    return result
|
|
|
|
|
|
def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Forward `x` to `libdevice.floor` and return its result."""
    # Imported here, not at module level, exactly as in the original code.
    from . import libdevice
    result = libdevice.floor(x, _builder=builder)
    return result
|
|
|
|
|
|
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit `create_exp` on `x` and wrap the result in a tensor typed like `x`."""
    exp_handle = builder.create_exp(x.handle)
    return tl.tensor(exp_handle, x.type)
|
|
|
|
|
|
def log(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit `create_log` on `x` and wrap the result in a tensor typed like `x`."""
    log_handle = builder.create_log(x.handle)
    return tl.tensor(log_handle, x.type)
|
|
|
|
|
|
def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit `create_cos` on `x` and wrap the result in a tensor typed like `x`."""
    cos_handle = builder.create_cos(x.handle)
    return tl.tensor(cos_handle, x.type)
|
|
|
|
|
|
def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit `create_sin` on `x` and wrap the result in a tensor typed like `x`."""
    sin_handle = builder.create_sin(x.handle)
    return tl.tensor(sin_handle, x.type)
|
|
|
|
|
|
def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
    """Emit `create_sqrt` on `x` and wrap the result in a tensor typed like `x`."""
    sqrt_handle = builder.create_sqrt(x.handle)
    return tl.tensor(sqrt_handle, x.type)
|
|
|
|
|
|
##
|
|
|
|
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a "tt.divisibility" attribute built from `values` to `x`'s handle.

    :param values: one entry per dimension of `x`.
    :return: the same `x` object that was passed in.
    :raises ValueError: if `values` does not have one entry per dimension.
    """
    rank = len(x.shape)
    if rank != len(values):
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    divisibility = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.divisibility", divisibility)
    return x
|
|
|
|
|
|
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
    """Attach a "tt.contiguity" attribute built from `values` to `x`'s handle.

    :param values: one entry per dimension of `x`.
    :return: the same `x` object that was passed in.
    :raises ValueError: if `values` does not have one entry per dimension.
    """
    rank = len(x.shape)
    if rank != len(values):
        raise ValueError("Shape of input to max_contiguous does not match the length of values")
    contiguity = ir.make_attr(values, x.handle.get_context())
    x.handle.set_attr("tt.contiguity", contiguity)
    return x
|
|
|
|
|
|
def debug_barrier(builder: ir.builder) -> tl.tensor:
    """Emit a barrier through the builder and wrap it in a `tl.void` tensor."""
    barrier_handle = builder.create_barrier()
    return tl.tensor(barrier_handle, tl.void)
|
|
|
|
|
|
def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
    """Emit a printf with `prefix` and the IR handles of `args`.

    :return: a `tl.void`-typed tensor wrapping the emitted instruction.
    """
    arg_handles = [arg.handle for arg in args]
    return tl.tensor(builder.create_printf(prefix, arg_handles), tl.void)
|