2021-04-20 22:29:40 -04:00
# include "triton/ir/dispatch.h"
2022-01-05 15:27:17 -08:00
namespace triton {
namespace ir {
2021-04-20 22:29:40 -04:00
2022-01-05 15:27:17 -08:00
// Raise std::runtime_error for a code path the dispatch logic considers
// impossible. `key` names the operation (or conversion) that reached it and is
// embedded in the message. Never returns.
[[noreturn]] void throw_unreachable(std::string key) {
  const std::string message = " Encountered unimplemented code path in ` " + key + " `. "
                              " This is likely a bug on our side. ";
  throw std::runtime_error(message);
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
// Id of the current program instance along `axis` (builder->create_get_program_id).
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
  ir::value *pid = builder->create_get_program_id(axis);
  return pid;
}
// Number of program instances along `axis` (builder->create_get_num_programs).
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
  ir::value *count = builder->create_get_num_programs(axis);
  return count;
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
// Choose the common integer type of two integer types.
// Same signedness: the strictly wider type wins (ties pick b_ty).
// Mixed signedness: the unsigned type wins unless the signed type is strictly
// wider. Rules for signedness taken from "Usual arithmetic conversions" on
// https://en.cppreference.com/w/c/language/conversion.
ir::type *integer_promote(ir::type *a_ty, ir::type *b_ty) {
  const int a_width = a_ty->get_integer_bitwidth();
  const int b_width = b_ty->get_integer_bitwidth();
  const auto a_sign = a_ty->get_integer_signedness();
  const auto b_sign = b_ty->get_integer_signedness();
  if (a_sign == b_sign)
    return (a_width > b_width) ? a_ty : b_ty;
  if (a_sign == signedness::UNSIGNED)
    return (a_width >= b_width) ? a_ty : b_ty;
  if (b_sign == signedness::UNSIGNED)
    return (b_width >= a_width) ? b_ty : a_ty;
  throw_unreachable(" integer_promote ");
}
2021-12-21 09:46:05 -08:00
// Marks whether the arithmetic being type-checked is a division or modulo,
// which get stricter/different promotion rules in computation_type.
enum class DivOrMod { NO, YES };

// Scalar type in which a binary arithmetic op on (a_ty, b_ty) is computed:
//   1) fp64 if either side is fp64;
//   2) else fp32 if either side is fp32;
//   3) else, if either side is fp16: fp16, or fp32 for / and % (which do not
//      exist natively in PTX for fp16);
//   4) else both must be integers and are combined with integer_promote.
// Throws semantic_error for / or % on integers of differing signedness, and
// hits throw_unreachable for any other non-integer combination.
ir::type *computation_type(ir::type *a_ty, ir::type *b_ty, DivOrMod div_or_mod) {
  context &ctx = a_ty->get_context();
  if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
    return type::get_fp64_ty(ctx);
  if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
    return type::get_fp32_ty(ctx);
  if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
    if (div_or_mod == DivOrMod::YES)
      return type::get_fp32_ty(ctx);
    return type::get_fp16_ty(ctx);
  }
  const bool both_integer = a_ty->is_integer_ty() && b_ty->is_integer_ty();
  if (!both_integer)
    throw_unreachable(" computation_type ");
  if (div_or_mod == DivOrMod::YES &&
      a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
    throw semantic_error(" Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness ; this is unlikely to result in a useful answer . Cast them to the same signedness . ");
  }
  return integer_promote(a_ty, b_ty);
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
void throw_incompatible_types ( ir : : type * type_a , ir : : type * type_b ) {
throw semantic_error ( " invalid operands of type " + type_a - > repr ( ) + " and " + type_b - > repr ( ) ) ;
}
void check_ptr_type ( ir : : type * type_a , ir : : type * type_b , bool allow_ptr_a ) {
if ( type_a - > is_pointer_ty ( ) ) {
if ( ! allow_ptr_a )
throw_incompatible_types ( type_a , type_b ) ;
// T* + U* with T != U
if ( type_b - > is_pointer_ty ( ) & & ( type_a ! = type_b ) )
throw_incompatible_types ( type_a , type_b ) ;
// T* + float
if ( type_b - > is_floating_point_ty ( ) )
throw_incompatible_types ( type_a , type_b ) ;
}
}
void binary_op_type_checking ( ir : : value * & lhs , ir : : value * & rhs , ir : : builder * builder ,
bool allow_lhs_ptr = false , bool allow_rhs_ptr = false ,
2021-12-21 09:46:05 -08:00
bool arithmetic_check = true , DivOrMod div_or_mod = DivOrMod : : NO ) {
2021-04-20 22:29:40 -04:00
// implicit broadcasting
std : : tie ( lhs , rhs ) = dispatch : : broadcast ( lhs , rhs , builder ) ;
// implicit typecasting
ir : : type * lhs_sca_ty = lhs - > get_type ( ) - > get_scalar_ty ( ) ;
ir : : type * rhs_sca_ty = rhs - > get_type ( ) - > get_scalar_ty ( ) ;
check_ptr_type ( lhs_sca_ty , rhs_sca_ty , allow_lhs_ptr ) ;
check_ptr_type ( rhs_sca_ty , lhs_sca_ty , allow_rhs_ptr ) ;
2021-12-21 09:46:05 -08:00
if ( arithmetic_check & & ! lhs_sca_ty - > is_pointer_ty ( ) & & ! rhs_sca_ty - > is_pointer_ty ( ) ) {
ir : : type * ret_sca_ty = computation_type ( lhs_sca_ty , rhs_sca_ty , div_or_mod ) ;
2021-04-20 22:29:40 -04:00
lhs = dispatch : : cast ( lhs , ret_sca_ty , builder ) ;
rhs = dispatch : : cast ( rhs , ret_sca_ty , builder ) ;
}
}
// Addition. Supports ptr + offset and offset + ptr (lowered to a GEP with the
// pointer as base), float + float and int + int on the type-checked operands.
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, true, true);
  ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
  ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
  // offset + ptr -> ptr + offset. The cached scalar types must be swapped along
  // with the values; otherwise the dispatch below still sees the offset's
  // integer type and emits an integer add on a pointer instead of a GEP.
  if (other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty()) {
    std::swap(input, other);
    std::swap(input_scalar_ty, other_scalar_ty);
  }
  // ptr + offset
  if (input_scalar_ty->is_pointer_ty())
    return builder->create_gep(input, {other});
  // float + float
  if (input_scalar_ty->is_floating_point_ty())
    return builder->create_fadd(input, other);
  // int + int
  if (input_scalar_ty->is_integer_ty())
    return builder->create_add(input, other);
  throw_unreachable(" add ");
}
// Subtraction. ptr - offset is lowered to a GEP with the negated offset;
// otherwise a float or integer subtract on the type-checked operands.
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, true, false);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_pointer_ty()) {
    // ptr - offset == ptr + (-offset)
    ir::value *neg_offset = dispatch::minus(other, builder);
    return builder->create_gep(input, {neg_offset});
  }
  if (elt_ty->is_floating_point_ty())  // float - float
    return builder->create_fsub(input, other);
  if (elt_ty->is_integer_ty())         // int - int
    return builder->create_sub(input, other);
  throw_unreachable(" sub ");
}
// Multiplication of two type-checked (broadcast + promoted) operands.
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())  // float * float
    return builder->create_fmul(input, other);
  if (elt_ty->is_integer_ty())         // int * int
    return builder->create_mul(input, other);
  throw_unreachable(" mul ");
}
// True (floating-point) division; always emits an fdiv.
//   float / int, int / float : the integer side is cast to the float type;
//   int / int                : both sides are cast to builder->get_float_ty();
//   float / float            : the side with the smaller mantissa width is cast
//                              to the other side's type.
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *lhs_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_ty = other->get_type()->get_scalar_ty();
  const bool lhs_fp = lhs_ty->is_floating_point_ty();
  const bool rhs_fp = rhs_ty->is_floating_point_ty();
  const bool lhs_int = lhs_ty->is_integer_ty();
  const bool rhs_int = rhs_ty->is_integer_ty();
  if (lhs_fp && rhs_int) {
    other = cast(other, lhs_ty, builder);
  } else if (lhs_int && rhs_fp) {
    input = cast(input, rhs_ty, builder);
  } else if (lhs_int && rhs_int) {
    input = cast(input, builder->get_float_ty(), builder);
    other = cast(other, builder->get_float_ty(), builder);
  } else if (lhs_fp && rhs_fp) {
    if (lhs_ty->get_fp_mantissa_width() > rhs_ty->get_fp_mantissa_width())
      other = cast(other, lhs_ty, builder);
    else
      input = cast(input, rhs_ty, builder);
  } else {
    throw_unreachable(" div ");
  }
  return builder->create_fdiv(input, other);
}
// Integer division of two integer operands, computed in integer_promote's
// common type: sdiv for signed results, udiv for unsigned ones. Non-integer
// operands hit throw_unreachable.
// NOTE(review): sdiv presumably truncates toward zero, which differs from
// Python floor division for negative operands — confirm intended semantics.
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *lhs_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_ty = other->get_type()->get_scalar_ty();
  if (!lhs_ty->is_integer_ty() || !rhs_ty->is_integer_ty())
    throw_unreachable(" floordiv ");
  ir::type *common = integer_promote(lhs_ty, rhs_ty);
  input = dispatch::cast(input, common, builder);
  other = dispatch::cast(other, common, builder);
  if (!common->is_integer_signed())
    return builder->create_udiv(input, other);
  return builder->create_sdiv(input, other);
}
2022-01-29 18:29:29 -08:00
// Floating-point division with an explicit IEEE-rounding flag.
// Both operands must already have floating-point scalar types (semantic_error
// otherwise); operands are broadcast but not arithmetically promoted. If the
// builder returns an ir::binary_operator, its fdiv IEEE-rounding flag is set
// from ieee_rounding->get_value().
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder) {
  const bool lhs_fp = input->get_type()->get_scalar_ty()->is_floating_point_ty();
  if (!lhs_fp || !other->get_type()->get_scalar_ty()->is_floating_point_ty())
    throw semantic_error(" both operands of fdiv must have floating point scalar type ");
  binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
  ir::value *quotient = builder->create_fdiv(input, other);
  auto *as_binop = dynamic_cast<ir::binary_operator *>(quotient);
  if (as_binop != nullptr)
    as_binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
  return quotient;
}
2021-04-20 22:29:40 -04:00
// Remainder. Floats use frem; integers use srem/urem according to the
// signedness of the promoted type. Mixed-signedness integers raise
// semantic_error.
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
  ir::type *lhs_ty = input->get_type()->get_scalar_ty();
  ir::type *rhs_ty = other->get_type()->get_scalar_ty();
  if (lhs_ty->is_floating_point_ty())  // float % float
    return builder->create_frem(input, other);
  if (!lhs_ty->is_integer_ty())
    throw_unreachable(" mod ");
  // Defensive: computation_type already rejects mixed signedness for DivOrMod::YES.
  if (lhs_ty->get_integer_signedness() != rhs_ty->get_integer_signedness())
    throw semantic_error(" Cannot mod " + lhs_ty->repr() + " by " + rhs_ty->repr() + " because they have different signedness ; this is unlikely to result in a useful answer . Cast them to the same signedness . ");
  if (lhs_ty->is_integer_signed())
    return builder->create_srem(input, other);
  return builder->create_urem(input, other);
}
2021-12-21 09:46:05 -08:00
void bitwise_op_type_checking ( ir : : value * & input , ir : : value * & other , ir : : builder * builder ) {
2021-04-20 22:29:40 -04:00
binary_op_type_checking ( input , other , builder , false , false , false ) ;
ir : : type * input_sca_ty = input - > get_type ( ) - > get_scalar_ty ( ) ;
ir : : type * other_sca_ty = other - > get_type ( ) - > get_scalar_ty ( ) ;
if ( ! input_sca_ty - > is_integer_ty ( ) | | ! other_sca_ty - > is_integer_ty ( ) )
throw_incompatible_types ( input_sca_ty , other_sca_ty ) ;
2022-01-05 15:27:17 -08:00
ir : : type * ret_sca_ty = integer_promote ( input_sca_ty , other_sca_ty ) ;
if ( ret_sca_ty ! = input_sca_ty )
input = dispatch : : cast ( input , ret_sca_ty , builder ) ;
if ( ret_sca_ty ! = other_sca_ty )
other = dispatch : : cast ( other , ret_sca_ty , builder ) ;
2021-04-20 22:29:40 -04:00
}
// Bitwise AND of two integer operands (after bitwise_op_type_checking).
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  ir::value *result = builder->create_and(input, other);
  return result;
}
// Bitwise OR of two integer operands (after bitwise_op_type_checking).
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  ir::value *result = builder->create_or(input, other);
  return result;
}
// Bitwise XOR of two integer operands (after bitwise_op_type_checking).
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  ir::value *result = builder->create_xor(input, other);
  return result;
}
// Right shift via builder->create_lshr, emitted regardless of operand signedness.
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  ir::value *shifted = builder->create_lshr(input, other);
  return shifted;
}
// Left shift of two integer operands (after bitwise_op_type_checking).
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
  bitwise_op_type_checking(input, other, builder);
  ir::value *shifted = builder->create_shl(input, other);
  return shifted;
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
// Unary plus: the identity; the builder is unused.
ir::value *dispatch::plus(ir::value *input, ir::builder * /*builder*/) {
  ir::value *unchanged = input;
  return unchanged;
}
// Unary negation, lowered as (zero of the scalar type) - input.
// Pointer operands raise semantic_error.
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (scalar_ty->is_pointer_ty())
    throw semantic_error(" wrong type argument to unary minus ( " + scalar_ty->repr() + " ) ");
  ir::value *zero = ir::constant::get_null_value(scalar_ty);
  return dispatch::sub(zero, input, builder);
}
// Bitwise NOT, lowered as input ^ all-ones. Pointer and floating-point
// operands raise semantic_error.
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  const bool unsupported = scalar_ty->is_pointer_ty() || scalar_ty->is_floating_point_ty();
  if (unsupported)
    throw semantic_error(" wrong type argument to unary invert ( " + scalar_ty->repr() + " ) ");
  ir::value *all_ones = ir::constant::get_all_ones_value(scalar_ty);
  return dispatch::xor_(input, all_ones, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
// Elementwise input > other. Floats use create_fcmpOGT; integers use the
// signed (SGT) or unsigned (UGT) compare per the promoted type's signedness.
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())
    return builder->create_fcmpOGT(input, other);
  if (!elt_ty->is_integer_ty())
    throw_unreachable(" greater_than ");
  if (elt_ty->is_integer_signed())
    return builder->create_icmpSGT(input, other);
  return builder->create_icmpUGT(input, other);
}
// Elementwise input >= other. Floats use create_fcmpOGE; integers use the
// signed (SGE) or unsigned (UGE) compare per the promoted type's signedness.
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())
    return builder->create_fcmpOGE(input, other);
  if (!elt_ty->is_integer_ty())
    throw_unreachable(" greater_equal ");
  if (elt_ty->is_integer_signed())
    return builder->create_icmpSGE(input, other);
  return builder->create_icmpUGE(input, other);
}
// Elementwise input < other. Floats use create_fcmpOLT; integers use the
// signed (SLT) or unsigned (ULT) compare per the promoted type's signedness.
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())
    return builder->create_fcmpOLT(input, other);
  if (!elt_ty->is_integer_ty())
    throw_unreachable(" less_than ");
  if (elt_ty->is_integer_signed())
    return builder->create_icmpSLT(input, other);
  return builder->create_icmpULT(input, other);
}
// Elementwise input <= other. Floats use create_fcmpOLE; integers use the
// signed (SLE) or unsigned (ULE) compare per the promoted type's signedness.
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())
    return builder->create_fcmpOLE(input, other);
  if (!elt_ty->is_integer_ty())
    throw_unreachable(" less_equal ");
  if (elt_ty->is_integer_signed())
    return builder->create_icmpSLE(input, other);
  return builder->create_icmpULE(input, other);
}
// Elementwise input == other: create_fcmpOEQ for floats, create_icmpEQ for
// integers (signedness is irrelevant for equality).
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *elt_ty = input->get_type()->get_scalar_ty();
  if (elt_ty->is_floating_point_ty())
    return builder->create_fcmpOEQ(input, other);
  if (elt_ty->is_integer_ty())
    return builder->create_icmpEQ(input, other);
  throw_unreachable(" equal ");
}
// Elementwise input != other. Floats use create_fcmpUNE (the unordered
// variant, so presumably NaN != x yields true, matching IEEE semantics);
// integers use create_icmpNE.
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
  binary_op_type_checking(input, other, builder);
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  // float != float
  if (scalar_ty->is_floating_point_ty())
    return builder->create_fcmpUNE(input, other);
  // int != int
  if (scalar_ty->is_integer_ty())
    return builder->create_icmpNE(input, other);
  // Report this operation's own name so bug reports point at the right op.
  throw_unreachable(" not_equal ");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
// Integer range block from builder->get_range(start, end).
// NOTE(review): presumably the half-open range [start, end) — confirm in builder.
ir::value *dispatch::arange(int start, int end, ir::builder *builder) {
  ir::value *range = builder->get_range(start, end);
  return range;
}
// Block of the given shape filled with dtype's null (zero) constant.
ir::value *dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
  return builder->create_splat(ir::constant::get_null_value(dtype), shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
// Reshape `input` to dst_shape; the total element count must be unchanged,
// otherwise semantic_error is raised.
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
  unsigned dst_numel = 1;
  for (unsigned extent : dst_shape)
    dst_numel *= extent;
  const bool same_numel = input->get_type()->get_tile_num_elements() == dst_numel;
  if (!same_numel)
    throw semantic_error(" cannot reshape block of different shape ");
  return builder->create_reshape(input, dst_shape);
}
2021-10-24 02:30:46 -07:00
// Concatenate lhs and rhs via builder->create_cat; no shape or type
// validation is performed at this level.
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
  ir::value *joined = builder->create_cat(lhs, rhs);
  return joined;
}
2021-04-20 22:29:40 -04:00
// Broadcast `input` to `shape`. Scalars are splatted; blocks must have the
// same rank as `shape` (std::runtime_error otherwise) and are returned as-is
// when already of that shape.
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  if (!src_ty->is_block_ty())
    return builder->create_splat(input, shape);
  auto src_shape = src_ty->get_block_shapes();
  if (src_shape.size() != shape.size())
    throw std::runtime_error(" Cannot broadcast ");
  if (shape == src_shape)
    return input;
  return builder->create_broadcast(input, shape);
}
std : : tuple < ir : : value * , ir : : value * > dispatch : : broadcast ( ir : : value * lhs , ir : : value * rhs , ir : : builder * builder ) {
ir : : type * lhs_ty = lhs - > get_type ( ) ;
ir : : type * rhs_ty = rhs - > get_type ( ) ;
2021-08-30 11:50:35 -07:00
2021-04-20 22:29:40 -04:00
// make_shape_compatible(block, scalar)
if ( lhs_ty - > is_block_ty ( ) & & ! rhs_ty - > is_block_ty ( ) )
rhs = builder - > create_splat ( rhs , lhs_ty - > get_block_shapes ( ) ) ;
// make_shape_compatible(scalar, block)
else if ( ! lhs_ty - > is_block_ty ( ) & & rhs_ty - > is_block_ty ( ) )
lhs = builder - > create_splat ( lhs , rhs_ty - > get_block_shapes ( ) ) ;
// make_shape_compatible(block, block)
else if ( lhs_ty - > is_block_ty ( ) & & rhs_ty - > is_block_ty ( ) ) {
auto lhs_shape = lhs_ty - > get_block_shapes ( ) ;
auto rhs_shape = rhs_ty - > get_block_shapes ( ) ;
if ( lhs_shape . size ( ) ! = rhs_shape . size ( ) )
throw std : : runtime_error ( " Cannot make_shape_compatible: blocks must have the same rank " ) ;
ir : : type : : block_shapes_t ret_shape ;
for ( size_t i = 0 ; i < lhs_shape . size ( ) ; + + i ) {
unsigned left = lhs_shape [ i ] ;
unsigned right = rhs_shape [ i ] ;
if ( left = = 1 )
ret_shape . push_back ( right ) ;
else if ( right = = 1 )
ret_shape . push_back ( left ) ;
else if ( left = = right )
ret_shape . push_back ( left ) ;
else
throw std : : runtime_error ( " Cannot make_shape_compatible: incompatible dimensions at index " + std : : to_string ( i ) +
" : " + std : : to_string ( left ) + " and " + std : : to_string ( right ) ) ;
}
if ( lhs_shape ! = ret_shape )
lhs = builder - > create_broadcast ( lhs , ret_shape ) ;
if ( rhs_shape ! = ret_shape )
rhs = builder - > create_broadcast ( rhs , ret_shape ) ;
}
return std : : make_tuple ( lhs , rhs ) ;
}
2021-05-21 02:47:53 -04:00
// Reinterpret input's bits as dst_ty (for block inputs the block shape is
// kept). Returns input unchanged when the types already match. Conversions
// involving pointers are delegated to dispatch::cast; otherwise source and
// destination scalar types must have the same primitive bit width
// (std::runtime_error otherwise).
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, src_ty->get_block_shapes());
  if (src_ty == dst_ty)
    return input;
  ir::type *src_elt = src_ty->get_scalar_ty();
  ir::type *dst_elt = dst_ty->get_scalar_ty();
  if (src_elt->is_pointer_ty() || dst_elt->is_pointer_ty())
    return cast(input, dst_ty, builder);
  const int src_bits = src_elt->get_primitive_size_in_bits();
  const int dst_bits = dst_elt->get_primitive_size_in_bits();
  if (src_bits == dst_bits)
    return builder->create_cast(ir::BitCast, input, dst_ty);
  throw std::runtime_error(" Cannot bitcast data-type of size " + std::to_string(src_bits) +
                           " to data-type of size " + std::to_string(dst_bits));
}
2021-04-20 22:29:40 -04:00
// Convert `input` to scalar type dst_ty (for block inputs the block shape is
// kept). Returns `input` unchanged when the types already match. Dispatch on
// (source, destination) scalar kinds:
//   bf16 <-> anything but fp32 : two-step conversion through fp32
//   fp  -> fp                  : fp_trunc / fp_ext by mantissa width
//   int -> int                 : int cast; sign-extends only signed, non-int1 sources
//   fp  -> int                 : fp_to_ui for bool destinations, fp_to_si otherwise
//   int -> fp                  : ui_to_fp for bool/unsigned sources, si_to_fp otherwise
//   ptr -> int64 / int1        : PtrToInt / (ptr as int64) != 0
//   non-ptr -> ptr             : IntToPtr
//   ptr -> ptr                 : BitCast
// Any other combination hits throw_unreachable.
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
  ir::type *src_ty = input->get_type();
  if (src_ty->is_block_ty())
    dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
  if (src_ty == dst_ty)
    return input;
  ir::type *src_sca_ty = src_ty->get_scalar_ty();
  ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
  // bf16 is only converted directly to/from fp32; everything else goes via fp32.
  if ((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
      (dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())) {
    return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
  }
  // FP Truncation
  bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
                     dst_sca_ty->is_floating_point_ty() &&
                     src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
  if (truncate_fp)
    return builder->create_fp_trunc(input, dst_ty);
  // FP Extension
  bool ext_fp = src_sca_ty->is_floating_point_ty() &&
                dst_sca_ty->is_floating_point_ty() &&
                src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
  if (ext_fp)
    return builder->create_fp_ext(input, dst_ty);
  // Int cast (width or signedness differs)
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
      (src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
       src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
    // int1 is never sign-extended, so true widens to 1 rather than all-ones.
    bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
    return builder->create_int_cast(input, dst_ty, sign_extend);
  }
  // Float -> Int
  if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()) {
    if (dst_sca_ty->is_bool_ty())
      return builder->create_fp_to_ui(input, dst_ty);
    else
      return builder->create_fp_to_si(input, dst_ty);
  }
  // Int -> Float
  if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()) {
    if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
      return builder->create_ui_to_fp(input, dst_ty);
    else
      return builder->create_si_to_fp(input, dst_ty);
  }
  // Ptr -> Int (only int64 and int1 are supported)
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()) {
    int bitwidth = dst_sca_ty->get_integer_bitwidth();
    if (bitwidth == 64)
      return builder->create_cast(ir::PtrToInt, input, dst_ty);
    if (bitwidth == 1)
      return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
                                 builder->get_int64(0),
                                 builder);
  }
  // Non-ptr -> Ptr
  if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::IntToPtr, input, dst_ty);
  // Ptr -> Ptr
  if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
    return builder->create_cast(ir::BitCast, input, dst_ty);
  // * -> Bool
  if (dst_sca_ty->is_bool_ty()) {
    if (src_sca_ty->is_pointer_ty())
      input = cast(input, builder->get_int64_ty(), builder);
    ir::value *other = builder->get_int64(0);
    // Splat the zero to the source's block shape so both icmp operands have
    // the same shape (the check must be on block-ness, not bool-ness).
    if (src_ty->is_block_ty())
      other = builder->create_splat(other, src_ty->get_block_shapes());
    return builder->create_icmpNE(input, other);
  }
  throw_unreachable(" casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
2022-02-24 14:56:24 -08:00
// Load through `ptr` (a pointer or block of pointers).
//   mask            optional; broadcast to ptr's block shape. Required when `other` is given.
//   other           optional value for masked-off lanes; broadcast to ptr's block shape
//                   and cast to the pointee type. Defaults to undef.
//   cache_modifier  empty (no modifier) or one of the literals compared below;
//                   anything else throws std::runtime_error.
//   eviction_policy empty (normal) or one of the literals compared below;
//                   anything else throws std::runtime_error.
//   is_volatile     forwarded to the builder.
// bool* pointers are accessed as int8*, so such loads produce int8 values.
// Throws semantic_error if ptr does not have a pointer scalar type.
ir::value *dispatch::load(ir::value *ptr, ir::value *mask, ir::value *other, const std::string &cache_modifier, const std::string &eviction_policy, int is_volatile, ir::builder *builder) {
  if (!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error(" Pointer argument of load instruction is " + ptr->get_type()->repr());
  if (ptr->get_type()->is_block_ty()) {
    if (mask)
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
    if (other)
      other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
  }
  if (other)
    other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
  // treat bool* as int8*
  if (elt_ty == builder->get_int1_ty()) {
    elt_ty = builder->get_int8_ty();
    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
    ptr = dispatch::cast(ptr, ptr_ty, builder);
    // The fallback must match the (now int8) loaded element type; the value
    // was already normalized to bool above, so this widens 0/1 to int8.
    if (other)
      other = dispatch::cast(other, elt_ty, builder);
  }
  // cache modifier
  load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
  if (!cache_modifier.empty()) {
    if (cache_modifier == " .ca ")
      cache = load_inst::CA;
    else if (cache_modifier == " .cg ")
      cache = load_inst::CG;
    else
      throw std::runtime_error(std::string(" Cache modifier ") + cache_modifier + " not supported ");
  }
  // eviction policy
  load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; // default
  if (!eviction_policy.empty()) {
    if (eviction_policy == " evict_last ")
      eviction = load_inst::EVICT_LAST;
    else if (eviction_policy == " evict_first ")
      eviction = load_inst::EVICT_FIRST;
    else
      throw std::runtime_error(std::string(" Eviction policy ") + eviction_policy + " not supported ");
  }
  if (!mask && !other)
    return builder->create_load(ptr, cache, eviction, is_volatile);
  if (!mask)
    throw std::runtime_error(" `other` cannot be provided without `mask` ");
  if (!other) {
    other = ir::undef_value::get(elt_ty);
    if (ptr->get_type()->is_block_ty())
      other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
  }
  return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
// Store `val` through `ptr` (a pointer or block of pointers), optionally masked.
// For block pointers, val and mask are broadcast to ptr's block shape. val is
// cast to the pointee type; bool* pointers are written as int8*. A provided
// mask must have a boolean scalar type.
// Throws semantic_error for a non-pointer ptr or a non-boolean mask.
ir::value *dispatch::store(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  if (!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
    throw semantic_error(" Pointer argument of store instruction is " + ptr->get_type()->repr());
  if (ptr->get_type()->is_block_ty()) {
    val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
    // Only a block of pointers has block shapes to broadcast the mask to;
    // a scalar ptr must not query get_block_shapes() (load guards this too).
    if (mask)
      mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
  }
  ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
  ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
  // treat bool* as int8*
  if (elt_ty == builder->get_int1_ty()) {
    elt_ty = builder->get_int8_ty();
    ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
    ptr = dispatch::cast(ptr, ptr_ty, builder);
  }
  // cast to target data-type
  val = dispatch::cast(val, elt_ty, builder);
  if (!mask)
    return builder->create_store(ptr, val);
  if (!mask->get_type()->get_scalar_ty()->is_bool_ty())
    throw semantic_error(" Mask must have boolean scalar type ");
  return builder->create_masked_store(ptr, val, mask);
}
// Emits an atomic compare-and-swap. Operands are forwarded unchanged:
// no broadcasting or implicit casting is performed here.
ir::value *dispatch::atomic_cas(ir::value *ptr, ir::value *cmp, ir::value *val, ir::builder *builder) {
  ir::value *result = builder->create_atomic_cas(ptr, cmp, val);
  return result;
}
2021-05-25 18:31:48 -04:00
void atom_red_typechecking ( ir : : value * & ptr , ir : : value * & val , ir : : value * & mask , ir : : builder * builder ) {
2021-05-01 16:14:58 -04:00
if ( ! ptr - > get_type ( ) - > get_scalar_ty ( ) - > is_pointer_ty ( ) )
throw semantic_error ( " Pointer argument of store instruction is " + ptr - > get_type ( ) - > repr ( ) ) ;
if ( ptr - > get_type ( ) - > is_block_ty ( ) ) {
if ( mask ) {
mask = dispatch : : broadcast ( mask , ptr - > get_type ( ) - > get_block_shapes ( ) , builder ) ;
}
if ( val ) {
val = dispatch : : broadcast ( val , ptr - > get_type ( ) - > get_block_shapes ( ) , builder ) ;
}
}
val = dispatch : : cast ( val , ptr - > get_type ( ) - > get_scalar_ty ( ) - > get_pointer_element_ty ( ) , builder ) ;
2021-04-29 09:13:45 -04:00
if ( ! mask ) {
mask = builder - > get_int1 ( true ) ;
if ( ptr - > get_type ( ) - > is_block_ty ( ) )
mask = builder - > create_splat ( mask , ptr - > get_type ( ) - > get_block_shapes ( ) ) ;
}
2021-05-25 18:31:48 -04:00
}
// Emits an atomic max of `val` into `*ptr` (under `mask`).
// Integers map directly onto Max / UMax according to signedness.
// Floats are handled by reinterpreting the bits as int32:
//  - if the sign bit of `val` is clear, IEEE bit patterns order like signed
//    ints, so smax on the bits yields the float max;
//  - if the sign bit is set, more negative floats have larger unsigned bit
//    patterns, so umin on the bits yields the float max.
// Routing is by sign bit (not `val >= 0`) so that -0.0 takes the umin path:
// as a signed int -0.0 is INT_MIN, and smax would never store it (e.g.
// max(-1.0, -0.0) would wrongly stay -1.0). The float comparisons keep NaN
// values masked off in both paths, as before.
// NOTE(review): the int32 reinterpretation assumes a 32-bit float payload;
// other widths are rejected by `bitcast`.
ir::value *dispatch::atomic_max(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type *sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_max for integers
  if (sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed())
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
  }
  ir::type *i32_ty = builder->get_int32_ty();
  // keep the original pointer's address space instead of assuming global (1)
  unsigned addr_space = ptr->get_type()->get_scalar_ty()->get_pointer_address_space();
  ir::value *i_val = bitcast(val, i32_ty, builder);
  ir::value *i_ptr = bitcast(ptr, pointer_type::get(i32_ty, addr_space), builder);
  ir::value *f_zero = constant_fp::get(sca_ty, 0);
  ir::value *i_zero = builder->get_int32(0);
  // non-NaN with sign bit clear (+0.0 and positives)
  ir::value *pos = and_(greater_equal(val, f_zero, builder), greater_equal(i_val, i_zero, builder), builder);
  // non-NaN with sign bit set (-0.0 and negatives)
  ir::value *neg = and_(less_equal(val, f_zero, builder), less_than(i_val, i_zero, builder), builder);
  ir::value *pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
  ir::value *neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
  return where(pos, pos_ret, neg_ret, builder);
}
// Emits an atomic min of `val` into `*ptr` (under `mask`).
// Integers map directly onto Min / UMin according to signedness.
// Floats are handled by reinterpreting the bits as int32:
//  - if the sign bit of `val` is clear, IEEE bit patterns order like signed
//    ints, so smin on the bits yields the float min;
//  - if the sign bit is set, more negative floats have larger unsigned bit
//    patterns, so umax on the bits yields the float min.
// Routing is by sign bit (not `val >= 0`) so that -0.0 takes the umax path:
// as a signed int -0.0 is INT_MIN, and smin would always store it (e.g.
// min(-1.0, -0.0) would wrongly become -0.0). The float comparisons keep NaN
// values masked off in both paths, as before.
// NOTE(review): the int32 reinterpretation assumes a 32-bit float payload;
// other widths are rejected by `bitcast`.
ir::value *dispatch::atomic_min(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  ir::type *sca_ty = val->get_type()->get_scalar_ty();
  // direct call to atomic_min for integers
  if (sca_ty->is_integer_ty()) {
    if (sca_ty->is_integer_signed())
      return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
    return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
  }
  ir::type *i32_ty = builder->get_int32_ty();
  // keep the original pointer's address space instead of assuming global (1)
  unsigned addr_space = ptr->get_type()->get_scalar_ty()->get_pointer_address_space();
  ir::value *i_val = bitcast(val, i32_ty, builder);
  ir::value *i_ptr = bitcast(ptr, pointer_type::get(i32_ty, addr_space), builder);
  ir::value *f_zero = constant_fp::get(sca_ty, 0);
  ir::value *i_zero = builder->get_int32(0);
  // non-NaN with sign bit clear (+0.0 and positives)
  ir::value *pos = and_(greater_equal(val, f_zero, builder), greater_equal(i_val, i_zero, builder), builder);
  // non-NaN with sign bit set (-0.0 and negatives)
  ir::value *neg = and_(less_equal(val, f_zero, builder), less_than(i_val, i_zero, builder), builder);
  ir::value *pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
  ir::value *neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
  return where(pos, pos_ret, neg_ret, builder);
}
// Emits an atomic add after operand normalization; floating-point payloads
// use the FAdd flavour, everything else uses the integer Add.
ir::value *dispatch::atomic_add(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  ir::atomic_rmw_op_t op = ir::atomic_rmw_op_t::Add;
  if (val->get_type()->get_scalar_ty()->is_floating_point_ty())
    op = ir::atomic_rmw_op_t::FAdd;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
// Emits an atomic bitwise AND after broadcasting/casting the operands and
// defaulting the mask (see atom_red_typechecking).
ir::value *dispatch::atomic_and(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  const auto op = ir::atomic_rmw_op_t::And;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
// Emits an atomic bitwise OR after broadcasting/casting the operands and
// defaulting the mask (see atom_red_typechecking).
ir::value *dispatch::atomic_or(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  const auto op = ir::atomic_rmw_op_t::Or;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
2021-05-01 16:14:58 -04:00
2021-05-25 18:31:48 -04:00
// Emits an atomic bitwise XOR after broadcasting/casting the operands and
// defaulting the mask (see atom_red_typechecking).
ir::value *dispatch::atomic_xor(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  const auto op = ir::atomic_rmw_op_t::Xor;
  return builder->create_atomic_rmw(op, ptr, val, mask);
}
2021-08-17 16:33:23 -07:00
// Emits an atomic exchange of `val` into `*ptr` (under `mask`), after the
// shared atomic operand normalization. Exchange is type-agnostic, so the
// scalar type of `val` is not inspected (the former unused local is removed).
ir::value *dispatch::atomic_xchg(ir::value *ptr, ir::value *val, ir::value *mask, ir::builder *builder) {
  atom_red_typechecking(ptr, val, mask, builder);
  return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
2021-04-20 22:29:40 -04:00
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
2022-01-12 02:20:31 +08:00
// Emits a matrix product lhs @ rhs accumulated into a zero-initialized block.
// The accumulator is int32 when `lhs` is an integer (block) type and fp32
// otherwise; its shape is {rows of lhs, cols of rhs}.
// `allow_tf32` is a constant flag; any non-zero value enables TF32.
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
  ir::value *acc_init;
  if (!lhs->get_type()->is_int_or_tileint_ty())
    acc_init = builder->get_float32(0);
  else
    acc_init = builder->get_int32(0);

  unsigned rows = lhs->get_type()->get_block_shapes()[0];
  unsigned cols = rhs->get_type()->get_block_shapes()[1];
  ir::value *acc = builder->create_splat(acc_init, {rows, cols});

  return builder->create_dot(lhs, rhs, acc, allow_tf32->get_value() != 0);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
// Element selection: `condition ? x : y`.
// The condition is cast to i1. When it is a block, both branches are
// broadcast to its shape. Both branches are then cast to their common
// computation type before the select is emitted.
ir::value *dispatch::where(ir::value *condition, ir::value *x, ir::value *y, ir::builder *builder) {
  condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
  ir::type *cond_ty = condition->get_type();
  if (cond_ty->is_block_ty()) {
    const auto &shape = cond_ty->get_block_shapes();
    x = dispatch::broadcast(x, shape, builder);
    y = dispatch::broadcast(y, shape, builder);
  }

  ir::type *common_ty = computation_type(x->get_type()->get_scalar_ty(),
                                         y->get_type()->get_scalar_ty(),
                                         DivOrMod::NO);
  x = dispatch::cast(x, common_ty, builder);
  y = dispatch::cast(y, common_ty, builder);
  return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
// Shared lowering for reductions along `axis`.
// Integer inputs of at most 32 bits are first widened to int32; this improves
// numerical accuracy and is essentially free on GPUs. Floating-point inputs
// reduce with FLOAT_OP and integer inputs with INT_OP. Any other scalar type
// reaches throw_unreachable(name).
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
                       ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  const bool is_int = scalar_ty->is_integer_ty();

  if (is_int && scalar_ty->get_integer_bitwidth() <= 32)
    input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);

  if (scalar_ty->is_floating_point_ty())
    return builder->create_reduce(input, FLOAT_OP, axis);
  if (is_int)
    return builder->create_reduce(input, INT_OP, axis);
  throw_unreachable(name);
}
// Minimum reduction along `axis` (FMIN for floats, MIN for integers).
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
  const ir::reduce_inst::op_t fp_op = ir::reduce_inst::FMIN;
  const ir::reduce_inst::op_t int_op = ir::reduce_inst::MIN;
  return reduce_impl(input, axis, builder, "min", fp_op, int_op);
}
// Maximum reduction along `axis` (FMAX for floats, MAX for integers).
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
  const ir::reduce_inst::op_t fp_op = ir::reduce_inst::FMAX;
  const ir::reduce_inst::op_t int_op = ir::reduce_inst::MAX;
  return reduce_impl(input, axis, builder, "max", fp_op, int_op);
}
// Sum reduction along `axis` (FADD for floats, ADD for integers).
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  const ir::reduce_inst::op_t fp_op = ir::reduce_inst::FADD;
  const ir::reduce_inst::op_t int_op = ir::reduce_inst::ADD;
  return reduce_impl(input, axis, builder, "sum", fp_op, int_op);
}
2021-12-16 17:55:35 -08:00
// XOR reduction along `axis`. Only integer inputs are accepted; other scalar
// types raise semantic_error. The diagnostic label passed to reduce_impl is
// now "xor_sum" rather than the copy-pasted "sum".
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
  ir::type *scalar_ty = input->get_type()->get_scalar_ty();
  if (!scalar_ty->is_integer_ty())
    throw semantic_error("xor_sum only supported for integers");
  return reduce_impl(input, axis, builder, "xor_sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
}
2021-04-20 22:29:40 -04:00
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
2021-10-24 02:30:46 -07:00
// High half of the product of `x` and `y`, after the standard binary-operand
// type checking has been applied to both operands.
ir::value *dispatch::umulhi(ir::value *x, ir::value *y, ir::builder *builder) {
  binary_op_type_checking(x, y, builder);
  auto *inst = umulhi_inst::create(x, y);
  return builder->insert(inst);
}
2021-04-20 22:29:40 -04:00
// Emits an exp instruction on `x` as-is (no implicit casting here).
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
  ir::value *out = builder->create_exp(x);
  return out;
}
// Emits a log instruction on `x` as-is (no implicit casting here).
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
  ir::value *out = builder->create_log(x);
  return out;
}
2021-07-14 17:16:48 -07:00
// Emits a cos instruction on `x` as-is (no implicit casting here).
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
  ir::value *out = builder->create_cos(x);
  return out;
}
// Emits a sin instruction on `x` as-is (no implicit casting here).
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
  ir::value *out = builder->create_sin(x);
  return out;
}
2021-04-20 22:29:40 -04:00
// Emits a sqrt instruction on `x` as-is (no implicit casting here).
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
  ir::value *out = builder->create_sqrt(x);
  return out;
}
//
// Attaches `multiple_of = value` metadata to `x` and returns it.
// Only instructions can carry metadata; any other value reaches
// throw_unreachable("multiple_of").
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *) {
  if (auto *inst = dynamic_cast<ir::instruction *>(x)) {
    inst->set_metadata(ir::metadata::multiple_of, value);
    return inst;
  }
  throw_unreachable("multiple_of");
}
2021-08-07 16:41:44 -07:00
// Attaches `max_contiguous = value` metadata to `x` and returns it.
// Only instructions can carry metadata; any other value reaches
// throw_unreachable("max_contiguous").
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *) {
  if (auto *inst = dynamic_cast<ir::instruction *>(x)) {
    inst->set_metadata(ir::metadata::max_contiguous, value);
    return inst;
  }
  throw_unreachable("max_contiguous");
}
2021-04-20 22:29:40 -04:00
// Emits a barrier instruction through the builder and returns it.
ir::value *dispatch::debug_barrier(ir::builder *builder) {
  ir::value *barrier = builder->create_barrier();
  return barrier;
}
}
}