from __future__ import annotations # remove after python 3.11
from typing import List, Optional, Tuple
from . import core as tl
from triton._C.libtriton.triton import ir
import torch
# Custom exception raised when an operation receives operands of incompatible types.
class IncompatibleTypeErrorimpl(Exception):
def __init__(self, type_a, type_b):
self.type_a = type_a
self.type_b = type_b
self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
super(IncompatibleTypeErrorimpl, self).__init__(self.message)
# ===----------------------------------------------------------------------===//
# Programming Model
# ===----------------------------------------------------------------------===//
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_get_program_id(axis), tl.int32)
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
# ===----------------------------------------------------------------------===//
# Implicit Casting Utilities
# ===----------------------------------------------------------------------===//
def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype:
a_rank = a_ty.int_bitwidth
b_rank = b_ty.int_bitwidth
a_sn = a_ty.int_signedness
b_sn = b_ty.int_signedness
# Rules for signedness taken from "Usual arithmetic conversions" on
# https://en.cppreference.com/w/c/language/conversion.
if a_sn == b_sn:
return a_ty if a_rank > b_rank else b_ty
elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
return a_ty if a_rank >= b_rank else b_ty
elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED:
return b_ty if b_rank >= a_rank else a_ty
assert False
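# A quick sketch of the promotion rules above, using dtypes from triton.language:
#   integer_promote_impl(tl.int32, tl.int64)   -> tl.int64   (same signedness: wider rank wins)
#   integer_promote_impl(tl.uint32, tl.int32)  -> tl.uint32  (unsigned wins at equal rank)
#   integer_promote_impl(tl.int64, tl.uint32)  -> tl.int64   (strictly wider type wins)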
def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype:
# 1) if one operand is double, the other is implicitly
# converted to double
if a_ty.is_fp64() or b_ty.is_fp64():
return tl.float64
# 2) if one operand is float, the other is implicitly
# converted to float
if a_ty.is_fp32() or b_ty.is_fp32():
return tl.float32
    # 3) if one operand is half, the other is implicitly converted to half
# unless we're doing / or %, which do not exist natively in PTX for fp16.
# Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp
if a_ty.is_fp16() or b_ty.is_fp16():
if div_or_mod:
return tl.float32
else:
return tl.float16
    # 4) return bf16 only if both operands are bf16
if a_ty.is_bf16() or b_ty.is_bf16():
if div_or_mod:
return tl.float32
if a_ty.is_bf16() and b_ty.is_bf16():
return tl.bfloat16
return tl.float32
if not a_ty.is_int() or not b_ty.is_int():
assert False
    # 5) both operands are integer and undergo
    #    integer promotion
    if div_or_mod and a_ty.int_signedness != b_ty.int_signedness:
        raise ValueError("Cannot use /, //, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() +
                         " because they have different signedness; "
                         "this is unlikely to result in a useful answer. Cast them to the same signedness.")
return integer_promote_impl(a_ty, b_ty)
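# Illustrative results of computation_type_impl (div_or_mod=False unless noted):
#   (tl.float16, tl.float32)   -> tl.float32    # rule 2
#   (tl.float16, tl.int32)     -> tl.float16    # rule 3 (tl.float32 if div_or_mod)
#   (tl.bfloat16, tl.bfloat16) -> tl.bfloat16   # rule 4: bf16 only when both are bf16
#   (tl.bfloat16, tl.int32)    -> tl.float32    # rule 4: mixed bf16 falls back to fp32
#   (tl.int32, tl.uint32)      -> tl.uint32     # rule 5: integer promotion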
# ===----------------------------------------------------------------------===//
# Binary Operators
# ===----------------------------------------------------------------------===//
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
if type_a.is_ptr():
if not allow_ptr_a:
raise IncompatibleTypeErrorimpl(type_a, type_b)
# T* + U* with T != U
if type_b.is_ptr() and (type_a != type_b):
raise IncompatibleTypeErrorimpl(type_a, type_b)
# T* + float
if type_b.is_floating():
raise IncompatibleTypeErrorimpl(type_a, type_b)
def binary_op_type_checking_impl(lhs: tl.tensor,
rhs: tl.tensor,
builder: ir.builder,
allow_lhs_ptr=False, allow_rhs_ptr=False,
arithmetic_check=True, div_or_mod=False
) -> Tuple[tl.tensor, tl.tensor]:
# implicit broadcasting
lhs, rhs = broadcast_impl_value(lhs, rhs, builder)
# implicit typecasting
lhs_sca_ty = lhs.type.scalar
rhs_sca_ty = rhs.type.scalar
check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr)
check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr():
ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod)
lhs = cast(lhs, ret_sca_ty, builder)
rhs = cast(rhs, ret_sca_ty, builder)
return lhs, rhs
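# Sketch of the full pipeline for, e.g., add(int16 block, float32 scalar):
#   1. broadcast_impl_value splats the scalar to the block shape
#   2. computation_type_impl picks tl.float32 as the common type
#   3. both operands are cast to tl.float32 before the IR op is emitted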
def add(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
# offset + ptr
# ptr + offset
if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr():
input, other = other, input
if input_scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [other.handle]), input.type)
# float + float
elif input_scalar_ty.is_floating():
return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
# int + int
elif input_scalar_ty.is_int():
return tl.tensor(builder.create_add(input.handle, other.handle), input.type)
assert False
def sub(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, True, False)
scalar_ty = input.type.scalar
# ptr - offset
if scalar_ty.is_ptr():
return tl.tensor(builder.create_gep(input.handle, [minus(other, builder).handle]),
input.type)
# float - float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type)
# int - int
elif scalar_ty.is_int():
return tl.tensor(builder.create_sub(input.handle, other.handle), input.type)
assert False
def mul(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float * float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type)
    # int * int
elif scalar_ty.is_int():
return tl.tensor(builder.create_mul(input.handle, other.handle), input.type)
assert False
def truediv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
# float / int
if input_scalar_ty.is_floating() and other_scalar_ty.is_int():
other = cast(other, input_scalar_ty, builder)
# int / float
elif input_scalar_ty.is_int() and other_scalar_ty.is_floating():
input = cast(input, other_scalar_ty, builder)
# int / int (cast to tl.float32)
elif input_scalar_ty.is_int() and other_scalar_ty.is_int():
input = cast(input, tl.float32, builder)
other = cast(other, tl.float32, builder)
# float / float (cast to highest exponent type)
elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating():
if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width:
other = cast(other, input_scalar_ty, builder)
else:
input = cast(input, other_scalar_ty, builder)
# unreachable
else:
assert False
return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type)
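# Note the int/int case above: true division always computes in tl.float32,
# so e.g. truediv on two int64 blocks yields fp32, not fp64.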
def floordiv(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if input_scalar_ty.is_int() and other_scalar_ty.is_int():
ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty)
input = cast(input, ret_ty, builder)
other = cast(other, ret_ty, builder)
if ret_ty.is_int_signed():
return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type)
else:
return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type)
assert False
def fdiv(input: tl.tensor,
other: tl.tensor,
ieee_rounding: bool,
builder: ir.builder) -> tl.tensor:
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
raise ValueError("both operands of fdiv must have floating poscalar type")
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
ret = builder.create_fdiv(input.handle, other.handle)
ret.set_fdiv_ieee_rounding(ieee_rounding)
return tl.tensor(ret, input.type)
def mod(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True)
scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
# float % float
if scalar_ty.is_floating():
return tl.tensor(builder.create_frem(input.handle, other.handle), input.type)
    # int % int
elif scalar_ty.is_int():
if scalar_ty.int_signedness != other_scalar_ty.int_signedness:
raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " "
"because they have different signedness;"
"this is unlikely to result in a useful answer. Cast them to the same signedness.")
if scalar_ty.is_int_signed():
return tl.tensor(builder.create_srem(input.handle, other.handle), input.type)
else:
return tl.tensor(builder.create_urem(input.handle, other.handle), input.type)
assert False
##############
# bitwise ops
##############
def bitwise_op_type_checking_impl(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False)
input_sca_ty = input.type.scalar
other_sca_ty = other.type.scalar
if not input_sca_ty.is_int() or not other_sca_ty.is_int():
raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
if ret_sca_ty != input_sca_ty:
input = cast(input, ret_sca_ty, builder)
if ret_sca_ty != other_sca_ty:
other = cast(other, ret_sca_ty, builder)
return input, other
def and_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_and(input.handle, other.handle), input.type)
def or_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_or(input.handle, other.handle), input.type)
def xor_(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_xor(input.handle, other.handle), input.type)
def lshr(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type)
def shl(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = bitwise_op_type_checking_impl(input, other, builder)
return tl.tensor(builder.create_shl(input.handle, other.handle), input.type)
# ===----------------------------------------------------------------------===//
# Unary Operators
# ===----------------------------------------------------------------------===//
def plus(input: tl.tensor) -> tl.tensor:
return input
def minus(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr():
raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")")
_0 = tl.tensor(ir.constant.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return sub(_0, input, builder)
def invert(input: tl.tensor,
           builder: ir.builder) -> tl.tensor:
input_sca_ty = input.type.scalar
if input_sca_ty.is_ptr() or input_sca_ty.is_floating():
raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")")
_1 = tl.tensor(ir.constant.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty)
return xor_(input, _1, builder)
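# Illustration: for integers, ~x is computed as x ^ all_ones; e.g. for an int8 x,
#   invert(0b00001111) == 0b00001111 ^ 0b11111111 == 0b11110000 (i.e. ~15 == -16)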
# ===----------------------------------------------------------------------===//
# Comparison Operators
# ===----------------------------------------------------------------------===//
def _bool_like(v: tl.tensor) -> tl.block_type:
if not v.type.is_block():
return tl.int1
shape = v.type.shape
return tl.block_type(tl.int1, shape)
def greater_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float > float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input))
    # int > int
elif scalar_ty.is_int():
if scalar_ty.is_int_signed():
return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input))
else:
return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input))
assert False
def greater_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float >= float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input))
    # int >= int
elif scalar_ty.is_int():
if scalar_ty.is_int_signed():
return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input))
else:
return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input))
assert False
def less_than(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float < float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input))
    # int < int
elif scalar_ty.is_int():
if scalar_ty.is_int_signed():
return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input))
else:
return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input))
assert False
def less_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
    # float <= float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input))
    # int <= int
elif scalar_ty.is_int():
if scalar_ty.is_int_signed():
return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input))
else:
return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input))
assert False
def equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
# float == float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input))
    # int == int
elif scalar_ty.is_int():
return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input))
assert False
def not_equal(input: tl.tensor,
other: tl.tensor,
builder: ir.builder) -> tl.tensor:
input, other = binary_op_type_checking_impl(input, other, builder)
scalar_ty = input.type.scalar
    # float != float
if scalar_ty.is_floating():
return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input))
    # int != int
elif scalar_ty.is_int():
return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input))
assert False
# ===----------------------------------------------------------------------===//
# Block Creation
# ===----------------------------------------------------------------------===//
def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:
if not isinstance(start, int) or not isinstance(end, int):
raise ValueError("arange's arguments must be of type tl.constexpr")
shape = [end - start]
ret_ty = tl.block_type(tl.int32, shape)
return tl.tensor(builder.get_range(start, end), ret_ty)
def zeros(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
_0 = ir.constant.get_null_value(dtype.to_ir(builder))
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(_0, shape), ret_ty)
# ===----------------------------------------------------------------------===//
# Shape Manipulation
# ===----------------------------------------------------------------------===//
def reshape(input: tl.tensor,
dst_shape: List[int],
builder: ir.builder) -> tl.tensor:
numel = 1
for s in dst_shape:
numel *= s
    if input.type.numel != numel:
        raise ValueError("cannot reshape block: the total number of elements must be preserved")
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_reshape(input.handle, dst_shape), ret_ty)
def cat(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.type.shape[1:] == rhs.type.shape[1:]
ret_shape = [lhs.type.shape[0] + rhs.type.shape[0]]
ret_ty = tl.block_type(lhs.type.scalar, ret_shape)
return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_ty)
def broadcast_impl_shape(input: tl.tensor,
shape: List[int],
builder: ir.builder) -> tl.tensor:
if not input.type.is_block():
ret_ty = tl.block_type(input.type, shape)
return tl.tensor(builder.create_splat(input.handle, shape), ret_ty)
src_shape = input.type.get_block_shapes()
if len(src_shape) != len(shape):
raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
if shape == src_shape:
return input
for i in range(len(src_shape)):
if shape[i] != src_shape[i] and src_shape[i] != 1:
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
f" must match the existing size ({src_shape[1]}) at non-singleton dimension"
f" {i}: {src_shape}, {shape}")
ret_ty = tl.block_type(input.type.scalar, shape)
return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty)
def broadcast_impl_value(lhs: tl.tensor,
rhs: tl.tensor,
                         builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]:
lhs_ty = lhs.type
rhs_ty = rhs.type
# make_shape_compatible(block, scalar)
if lhs_ty.is_block() and not rhs_ty.is_block():
rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape)
rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty)
# make_shape_compatible(scalar, block)
elif not lhs_ty.is_block() and rhs_ty.is_block():
lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape)
lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty)
# make_shape_compatible(block, block)
elif lhs_ty.is_block() and rhs_ty.is_block():
lhs_shape = lhs_ty.get_block_shapes()
rhs_shape = rhs_ty.get_block_shapes()
if len(lhs_shape) != len(rhs_shape):
raise ValueError("Cannot make_shape_compatible: blocks must have the same rank")
ret_shape = []
for i in range(len(lhs_shape)):
left = lhs_shape[i]
right = rhs_shape[i]
if left == 1:
ret_shape.append(right)
elif right == 1:
ret_shape.append(left)
elif left == right:
ret_shape.append(left)
else:
raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
"at index " + str(i) + ": " + str(left) + " and " + str(right))
if lhs_shape != ret_shape:
ret_ty = tl.block_type(lhs_ty.scalar, ret_shape)
lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty)
if rhs_shape != ret_shape:
ret_ty = tl.block_type(rhs_ty.scalar, ret_shape)
rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty)
# (scalar, scalar) => returns original blocks
return lhs, rhs
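# Worked example of the block/block rule above (numpy-style broadcasting):
#   lhs [128, 1] with rhs [1, 64]  -> both broadcast to [128, 64]
#   lhs [128, 2] with rhs [1, 64]  -> ValueError at index 1 (2 vs 64)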
#######
# dequantize
#######
def dequantize(input: tl.tensor,
scale: tl.tensor,
shift: tl.tensor,
nbit: int,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
input_ty = input.type
assert input_ty.is_block()
assert input_ty.element_ty.is_int32() or input_ty.element_ty.is_int16()
assert nbit in [2, 4, 8]
assert dst_ty == tl.float16
shape = input_ty.get_block_shapes()
factor = input_ty.element_ty.primitive_bitwidth // nbit
dst_shape = shape[:-1] + [factor * shape[-1]]
dst_ty = tl.block_type(dst_ty, dst_shape)
return tl.tensor(builder.create_dequantize(input.handle, scale.handle, shift.handle, dst_ty.to_ir(builder)), dst_ty)
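# Shape arithmetic: each input element packs primitive_bitwidth // nbit values,
# e.g. an int32 block of shape [M, K] with nbit=4 dequantizes to fp16 [M, 8 * K].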
#######
# cast
#######
def bitcast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr():
return cast(input, dst_ty, builder)
# Bitcast
src_bits = src_sca_ty.primitive_bitwidth
dst_bits = dst_sca_ty.primitive_bitwidth
if src_bits != dst_bits:
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to "
"data-type of size " + str(dst_bits))
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
dst_ty)
def cast(input: tl.tensor,
dst_ty: tl.dtype,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block() and not dst_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
dst_sca_ty = dst_ty.scalar
# fp8 <=> bf16/fp16
if (src_sca_ty.is_bf16() or src_sca_ty.is_fp16()) and dst_sca_ty.is_fp8():
return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if src_sca_ty.is_fp8() and (dst_sca_ty.is_bf16() or dst_sca_ty.is_fp16()):
return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)),
dst_ty)
# bf16 <=> (not fp32)
if (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()) or \
(dst_sca_ty.is_bf16() and not src_sca_ty.is_fp32()):
return cast(cast(input, tl.float32, builder), dst_sca_ty, builder)
# FP Truncation
truncate_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width > dst_sca_ty.fp_mantissa_width
if truncate_fp:
return tl.tensor(builder.create_fp_trunc(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# FP Extension
ext_fp = src_sca_ty.is_floating() and \
dst_sca_ty.is_floating() and \
src_sca_ty.fp_mantissa_width < dst_sca_ty.fp_mantissa_width
if ext_fp:
return tl.tensor(builder.create_fp_ext(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# Int cast
if src_sca_ty.is_int() and dst_sca_ty.is_int() and \
(src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness):
sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool()
return tl.tensor(builder.create_int_cast(input.handle,
dst_ty.to_ir(builder), sign_extend),
dst_ty)
# Float to Int
if src_sca_ty.is_floating() and dst_sca_ty.is_int():
# TODO: is this correct?
if dst_sca_ty.is_bool():
return not_equal(input, tl._to_tensor(0, builder), builder)
else:
return tl.tensor(builder.create_fp_to_si(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# int => float
if src_sca_ty.is_int() and dst_sca_ty.is_floating():
if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed():
return tl.tensor(builder.create_ui_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
else:
return tl.tensor(builder.create_si_to_fp(input.handle,
dst_ty.to_ir(builder)),
dst_ty)
# ptr => int
if src_sca_ty.is_ptr() and dst_sca_ty.is_int():
bitwidth = dst_sca_ty.int_bitwidth
if bitwidth == 64:
return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)),
dst_ty)
if bitwidth == 1:
return not_equal(cast(input, tl.int64, builder),
tl.tensor(builder.get_int64(0), tl.int64),
builder)
if not src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty)
    # ptr => ptr
if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr():
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty)
    # * => bool
if dst_sca_ty.is_bool():
if src_sca_ty.is_ptr():
input = cast(input, tl.int64, builder)
other = builder.get_int64(0)
        if src_ty.is_block():  # splat the scalar zero to the block shape
other = builder.create_splat(other, src_ty.get_block_shapes())
return tl.tensor(builder.create_icmpNE(input.handle, other), dst_ty)
assert False, f'cannot cast {input} to {dst_ty}'
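# Summary of the resolution order implemented above: fp8 <-> fp16/bf16 first,
# then bf16 <-> non-fp32 (routed through fp32), fp truncation/extension ranked
# by mantissa width, int <-> int (sign- or zero-extended), float -> int,
# int -> float, ptr -> int, int -> ptr, ptr -> ptr bitcasts, and finally
# anything -> bool as a comparison against zero.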
# ===----------------------------------------------------------------------===//
# Memory Operators
# ===----------------------------------------------------------------------===//
def _parse_eviction_policy(eviction_policy):
eviction = ir.EVICTION_POLICY.NORMAL # default
if eviction_policy:
if eviction_policy == "evict_last":
eviction = ir.EVICTION_POLICY.EVICT_LAST
elif eviction_policy == "evict_first":
eviction = ir.EVICTION_POLICY.EVICT_FIRST
else:
raise ValueError(f"Eviction policy {eviction_policy} not supported")
return eviction
def load(ptr: tl.tensor,
mask: Optional[tl.tensor],
other: Optional[tl.tensor],
cache_modifier: str,
eviction_policy: str,
is_volatile: bool,
builder: ir.builder) -> tl.tensor:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of load instruction is " + ptr.type.__repr__())
if ptr.type.is_block():
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if other:
other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
if other:
other = cast(other, ptr.type.scalar.element_ty, builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# cache modifier
cache = ir.CACHE_MODIFIER.NONE # default
if cache_modifier:
if cache_modifier == ".ca":
cache = ir.CACHE_MODIFIER.CA
elif cache_modifier == ".cg":
cache = ir.CACHE_MODIFIER.CG
else:
raise ValueError(f"Cache modifier {cache_modifier} not supported")
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
dst_ty = tl.block_type(elt_ty, shape)
else:
dst_ty = elt_ty
if not mask and not other:
return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile),
dst_ty)
if not mask:
raise ValueError("`other` cannot be provided without `mask`")
if not other:
other_ir = ir.undef.get(elt_ty.to_ir(builder))
if ptr.type.is_block():
other_ir = builder.create_splat(other_ir, ptr.type.get_block_shapes())
other = tl.tensor(other_ir, dst_ty)
return tl.tensor(builder.create_masked_load(ptr.handle,
mask.handle,
other.handle,
cache, eviction, is_volatile),
dst_ty)
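# Typical kernel-level usage that lowers to the masked path above (a sketch via
# the public @triton.jit API, which routes tl.load calls through this function):
#   offs = tl.arange(0, BLOCK)
#   mask = offs < n
#   x = tl.load(x_ptr + offs, mask=mask, other=0.0)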
def store(ptr: tl.tensor,
val: tl.tensor,
mask: Optional[tl.tensor],
eviction_policy: str,
builder: ir.builder) -> tl.tensor:
if not ptr.type.scalar.is_ptr():
raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__())
if ptr.type.is_block():
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty
# treat bool* as tl.int8*
if elt_ty == tl.int1:
# convert to bool first and then store as int8
val = cast(val, tl.int1, builder)
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)
# eviction policy
eviction = _parse_eviction_policy(eviction_policy)
# cast to target data-type
val = cast(val, elt_ty, builder)
if not mask:
return tl.tensor(builder.create_store(ptr.handle, val.handle, eviction), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, eviction), tl.void)
#########
# atomic
#########
def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
builder: ir.builder) -> tl.tensor:
element_ty = ptr.type.scalar.element_ty
if element_ty.primitive_bitwidth not in [16, 32, 64]:
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
op: str,
builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]:
    if not ptr.type.scalar.is_ptr():
        raise ValueError("Pointer argument of atomic operation is " + ptr.type.__repr__())
element_ty = ptr.type.scalar.element_ty
if element_ty is tl.float16 and op != 'add':
raise ValueError("atomic_" + op + " does not support fp16")
    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
        raise ValueError("atomic_" + op + " does not support " + str(element_ty))
if ptr.type.is_block():
if mask:
mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder)
if val:
val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder)
val = cast(val, ptr.type.scalar.element_ty, builder)
if not mask:
mask_ir = builder.get_int1(True)
mask_ty = tl.int1
if ptr.type.is_block():
mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes())
mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes())
mask = tl.tensor(mask_ir, mask_ty)
return ptr, val, mask
def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sca_ty = val.type.scalar
# direct call to atomic_max for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX,
ptr.handle,
val.handle,
mask.handle),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
ptr.handle,
val.handle,
mask.handle),
val.type)
# for float
# return atomic_smax(i_ptr, i_val) if val >= 0
# return atomic_umin(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type)
return where(pos, pos_ret, neg_ret, builder)
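# Why the bit trick above is valid: for non-negative floats the IEEE-754 bit
# pattern grows monotonically when read as a signed int32, so a signed max on
# the bits gives the float max; for negative floats the pattern grows as the
# value shrinks, and reading it as unsigned reverses the order, so an unsigned
# min on the bits also gives the float max. atomic_min below mirrors this with
# the (MIN, UMAX) pair.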
def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sca_ty = val.type.scalar
# direct call to atomic_min for integers
if sca_ty.is_int():
if sca_ty.is_int_signed():
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
ptr.handle,
val.handle,
mask.handle),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
ptr.handle,
val.handle,
mask.handle),
val.type)
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
# return atomic_umax(i_ptr, i_val) if val < 0
i_val = bitcast(val, tl.int32, builder)
i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder)
pos = greater_equal(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
neg = less_than(val, tl.tensor(ir.constant_float.get(sca_ty.to_ir(builder), 0), sca_ty), builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN,
i_ptr.handle,
i_val.handle,
and_(mask, pos, builder).handle),
i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
i_ptr.handle,
i_val.handle,
and_(mask, neg, builder).handle),
i_val.type)
return where(pos, pos_ret, neg_ret, builder)
def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sca_ty = val.type.scalar
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type)
def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type)
def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type)
def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type)
def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra
# ===----------------------------------------------------------------------===//
def dot(a: tl.tensor,
b: tl.tensor,
trans_a: bool,
trans_b: bool,
allow_tf32: bool,
builder: ir.builder) -> tl.tensor:
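    # ROCm fallback: cast both operands to fp32 before building the dot.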
if torch.version.hip is not None:
a = cast(a, tl.float32, builder)
b = cast(b, tl.float32, builder)
in_a = 1 if not trans_a else 0
in_b = 1 if trans_b else 0
assert a.type.is_block() and b.type.is_block()
assert len(a.shape) == 2 and len(b.shape) == 2
assert a.shape[in_a] == b.shape[in_b]
assert a.shape[0] >= 16 and a.shape[1] >= 16 and b.shape[1] >= 16,\
"small blocks not supported!"
if a.type.scalar.is_int():
_0 = builder.get_int32(0)
ret_scalar_ty = tl.int32
else:
_0 = builder.get_float32(0)
ret_scalar_ty = tl.float32
M = a.type.shape[in_a ^ 1]
N = b.type.shape[in_b ^ 1]
_0 = builder.create_splat(_0, [M, N])
ret_ty = tl.block_type(ret_scalar_ty, [M, N])
ret = builder.create_dot(a.handle, b.handle, _0, trans_a, trans_b, allow_tf32)
return tl.tensor(ret, ret_ty)
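# Shape sketch: with trans_a=False and trans_b=False, a: [M, K] and b: [K, N]
# must satisfy a.shape[1] == b.shape[0]; the result is an [M, N] block seeded
# with a zero accumulator of tl.float32 (tl.int32 for integer inputs).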
# ===----------------------------------------------------------------------===//
# Indexing
# ===----------------------------------------------------------------------===//
def where(condition: tl.tensor,
x: tl.tensor,
y: tl.tensor,
builder: ir.builder) -> tl.tensor:
condition = cast(condition, tl.int1, builder)
if condition.type.is_block():
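        # broadcast condition against x, x against y, then condition against x
        # again so that all three operands end up with a common block shape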
condition, x = broadcast_impl_value(condition, x, builder)
x, y = broadcast_impl_value(x, y, builder)
condition, x = broadcast_impl_value(condition, x, builder)
x, y = binary_op_type_checking_impl(x, y, builder, True, True)
if not condition.type.is_block():
condition, _ = broadcast_impl_value(condition, x, builder)
ret_ty = x.type
return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty)
# ===----------------------------------------------------------------------===//
# Reductions
# ===----------------------------------------------------------------------===//
def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,
FLOAT_OP: ir.REDUCE_OP, INT_OP: ir.REDUCE_OP) -> tl.tensor:
scalar_ty = input.type.scalar
# input is extended to 32-bits if necessary
# this increases numerical accuracy and can be done pretty much for free
# on GPUs
if scalar_ty.is_int() and scalar_ty.int_bitwidth <= 32:
input = cast(input, tl.int32, builder)
# hardware doesn't support FMAX, FMIN, CMP for bfloat16
if scalar_ty is tl.bfloat16:
input = cast(input, tl.float32, builder)
# choose the right unsigned operation
if scalar_ty.is_int_unsigned():
int_op_to_unit = {
ir.REDUCE_OP.MIN: ir.REDUCE_OP.UMIN,
ir.REDUCE_OP.MAX: ir.REDUCE_OP.UMAX,
ir.REDUCE_OP.ARGMIN: ir.REDUCE_OP.ARGUMIN,
ir.REDUCE_OP.ARGMAX: ir.REDUCE_OP.ARGUMAX,
}
if INT_OP in int_op_to_unit:
INT_OP = int_op_to_unit[INT_OP]
# get result type
shape = input.type.shape
ret_shape = []
for i, s in enumerate(shape):
if i != axis:
ret_shape.append(s)
if len(ret_shape) == 0:
res_ty = scalar_ty
else:
res_ty = tl.block_type(scalar_ty, ret_shape)
if scalar_ty.is_floating():
return tl.tensor(builder.create_reduce(input.handle, FLOAT_OP, axis), res_ty)
elif scalar_ty.is_int():
return tl.tensor(builder.create_reduce(input.handle, INT_OP, axis), res_ty)
assert False
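# Shape sketch: reducing an [M, N] block along axis=1 yields an [M] block;
# reducing a 1-D block collapses to a plain scalar (res_ty == scalar_ty).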
def min(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "min", ir.REDUCE_OP.FMIN, ir.REDUCE_OP.MIN)
def argmin(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmin", ir.REDUCE_OP.ARGFMIN, ir.REDUCE_OP.ARGMIN)
def max(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "max", ir.REDUCE_OP.FMAX, ir.REDUCE_OP.MAX)
def argmax(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "argmax", ir.REDUCE_OP.ARGFMAX, ir.REDUCE_OP.ARGMAX)
def sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.FADD, ir.REDUCE_OP.ADD)
def xor_sum(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
scalar_ty = input.type.scalar
if not scalar_ty.is_int():
raise ValueError("xor_sum only supported for integers")
return reduce_impl(input, axis, builder, "sum", ir.REDUCE_OP.XOR, ir.REDUCE_OP.XOR)
# -----------------------
# Utilities
# -----------------------
def clock(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_clock(), tl.int64)
def globaltimer(builder: ir.builder) -> tl.tensor:
    return tl.tensor(builder.create_globaltimer(), tl.int64)
# ===----------------------------------------------------------------------===
# Math
# ===----------------------------------------------------------------------===
def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor:
x, y = binary_op_type_checking_impl(x, y, builder)
return tl.tensor(builder.create_umulhi(x.handle, y.handle), x.type)
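# umulhi returns the high half of the full-width product; e.g. for two uint32
# inputs x and y it computes (x * y) >> 32 over the exact 64-bit product.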
def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_exp(x.handle), x.type)
def log(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_log(x.handle), x.type)
def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_cos(x.handle), x.type)
def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_sin(x.handle), x.type)
def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_sqrt(x.handle), x.type)
# -----------------------
# Compiler Hints
# -----------------------
def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.multiple_of(values)
return x
def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
raise ValueError("Shape of input to max_contiguous does not match the length of values")
x.handle.max_contiguous(values)
return x
def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(''), tl.void)