[FRONTEND] Semantic analysis refactor (#491)

Moved dispatch.cc to semantic.py (@ptillet)
Moved integer signedness analysis from C++ to Python (@daadaada); a sketch of the promotion rule follows below
Cleaner frontend types (@daadaada)
Moved SSA construction to a separate object (@ptillet)
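
For context, the integer promotion rule being moved follows C's "usual arithmetic conversions" (see the deleted integer_promote in dispatch.cc further down). Below is a minimal Python sketch of that rule as it might read once hosted in the Python frontend; the function and enum names are illustrative assumptions, not the actual semantic.py API.

from enum import Enum

class Signedness(Enum):
    SIGNED = 0
    UNSIGNED = 1

def integer_promote(a_bits, a_sn, b_bits, b_sn):
    # Same signedness: the wider type wins.
    if a_sn == b_sn:
        return (a_bits, a_sn) if a_bits > b_bits else (b_bits, b_sn)
    # Mixed signedness: the unsigned operand wins on ties or when it is wider,
    # mirroring the C conversion rules cited in the deleted C++ integer_promote.
    if a_sn == Signedness.UNSIGNED:
        return (a_bits, a_sn) if a_bits >= b_bits else (b_bits, b_sn)
    return (b_bits, b_sn) if b_bits >= a_bits else (a_bits, a_sn)

# int32 op uint32 promotes to uint32; int64 op uint32 promotes to int64
assert integer_promote(32, Signedness.SIGNED, 32, Signedness.UNSIGNED) == (32, Signedness.UNSIGNED)
assert integer_promote(64, Signedness.SIGNED, 32, Signedness.UNSIGNED) == (64, Signedness.SIGNED)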


Co-authored-by: Yan Da <dyanab@connect.ust.hk>
Author: Philippe Tillet
Date: 2022-04-06 16:13:53 -07:00 (committed by GitHub)
Commit: 9f08ecd684 (parent 2bed6fc850)
19 changed files with 2174 additions and 1745 deletions

View File

@@ -1,3 +1,4 @@
#include <bits/types/clock_t.h>
#include <string>
#include <algorithm>
#include <iostream>
@@ -48,18 +49,12 @@ void builder::set_insert_point(basic_block *block){
value *builder::get_int1(bool val)
{ return constant_int::get(type::get_int1_ty(ctx_), val); }
value *builder::get_int32(int32_t val)
value *builder::get_int32(uint32_t val)
{ return constant_int::get(type::get_int32_ty(ctx_), val);}
value *builder::get_uint32(uint32_t val)
{ return constant_int::get(type::get_uint32_ty(ctx_), val);}
value *builder::get_int64(int64_t val)
value *builder::get_int64(uint64_t val)
{ return constant_int::get(type::get_int64_ty(ctx_), val);}
value *builder::get_uint64(uint64_t val)
{ return constant_int::get(type::get_uint64_ty(ctx_), val);}
value *builder::get_float16(float val)
{ return constant_fp::get(type::get_fp16_ty(ctx_), val); }
@@ -90,21 +85,15 @@ type *builder::get_int32_ty()
type *builder::get_int64_ty()
{ return type::get_int64_ty(ctx_); }
type *builder::get_uint8_ty()
{ return type::get_uint8_ty(ctx_); }
type *builder::get_uint16_ty()
{ return type::get_uint16_ty(ctx_); }
type *builder::get_uint32_ty()
{ return type::get_uint32_ty(ctx_); }
type *builder::get_uint64_ty()
{ return type::get_uint64_ty(ctx_); }
type *builder::get_fp8_ty()
{ return type::get_fp8_ty(ctx_); }
type *builder::get_half_ty()
{ return type::get_fp16_ty(ctx_); }
type *builder::get_bf16_ty()
{ return type::get_bf16_ty(ctx_); }
type *builder::get_float_ty()
{ return type::get_fp32_ty(ctx_); }
@@ -140,6 +129,8 @@ value *builder::create_ret(value* val) {
return create_cast(OPCODE, src, dst_ty);\
}
DEFINE_CAST_INSTR(bitcast, cast_op_t::BitCast)
DEFINE_CAST_INSTR(int_to_ptr, cast_op_t::IntToPtr)
DEFINE_CAST_INSTR(ptr_to_int, cast_op_t::PtrToInt)
DEFINE_CAST_INSTR(si_to_fp, cast_op_t::SIToFP)
DEFINE_CAST_INSTR(ui_to_fp, cast_op_t::UIToFP)
@@ -358,6 +349,37 @@ value *builder::create_downcast(value *arg) {
return insert(downcast_inst::create(arg));
}
//
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
#define DEFINE_ATOMIC_RMW_INSTR(SUFFIX, OPCODE)\
value *builder::create_ ## SUFFIX(value *ptr, value *val, value *mask){\
return create_atomic_rmw(OPCODE, ptr, val, mask);\
}
DEFINE_ATOMIC_RMW_INSTR(atomic_max, ir::atomic_rmw_op_t::Max)
DEFINE_ATOMIC_RMW_INSTR(atomic_umax, ir::atomic_rmw_op_t::UMax)
DEFINE_ATOMIC_RMW_INSTR(atomic_min, ir::atomic_rmw_op_t::Min)
DEFINE_ATOMIC_RMW_INSTR(atomic_umin, ir::atomic_rmw_op_t::UMin)
DEFINE_ATOMIC_RMW_INSTR(atomic_fadd, ir::atomic_rmw_op_t::FAdd)
DEFINE_ATOMIC_RMW_INSTR(atomic_add, ir::atomic_rmw_op_t::Add)
DEFINE_ATOMIC_RMW_INSTR(atomic_and, ir::atomic_rmw_op_t::And)
DEFINE_ATOMIC_RMW_INSTR(atomic_or, ir::atomic_rmw_op_t::Or)
DEFINE_ATOMIC_RMW_INSTR(atomic_xor, ir::atomic_rmw_op_t::Xor)
DEFINE_ATOMIC_RMW_INSTR(atomic_xchg, ir::atomic_rmw_op_t::Xchg)
// Utilities
value *builder::create_clock() {
return insert(clock_inst::create(ctx_));
}
value *builder::create_globaltimer() {
return insert(globaltimer_inst::create(ctx_));
}
//===----------------------------------------------------------------------===//
// built-in instructions
//===----------------------------------------------------------------------===//
@@ -374,9 +396,6 @@ value *builder::create_atomic_cas(value *ptr, value *cmp, value *val){
return insert(atomic_cas_inst::create(ptr, cmp, val));
}
value *builder::create_atomic_rmw(ir::atomic_rmw_op_t op, value *ptr, value *val, value *msk){
return insert(atomic_rmw_inst::create(op, ptr, val, msk));
}
value *builder::create_exp(value *arg){
return insert(exp_inst::create(arg));

View File

@@ -19,18 +19,12 @@ context_impl::context_impl(context &ctx)
fp32_ty(ctx, type::FP32TyID),
fp64_ty(ctx, type::FP64TyID),
// integers
int1_ty(ctx, 1, signedness::SIGNED),
int8_ty(ctx, 8, signedness::SIGNED),
int16_ty(ctx, 16, signedness::SIGNED),
int32_ty(ctx, 32, signedness::SIGNED),
int64_ty(ctx, 64, signedness::SIGNED),
int128_ty(ctx, 128, signedness::SIGNED),
uint8_ty(ctx, 8, signedness::UNSIGNED),
uint16_ty(ctx, 16, signedness::UNSIGNED),
uint32_ty(ctx, 32, signedness::UNSIGNED),
uint64_ty(ctx, 64, signedness::UNSIGNED){
}
int1_ty(ctx, 1),
int8_ty(ctx, 8),
int16_ty(ctx, 16),
int32_ty(ctx, 32),
int64_ty(ctx, 64),
int128_ty(ctx, 128) {}
//===----------------------------------------------------------------------===//
// context

View File

@@ -1,895 +0,0 @@
#include "triton/ir/dispatch.h"
namespace triton {
namespace ir {
[[ noreturn ]] void throw_unreachable(std::string key) {
throw std::runtime_error("Encountered unimplemented code path in `" + key + "`. "
"This is likely a bug on our side.");
}
//===----------------------------------------------------------------------===//
// Programming Model
//===----------------------------------------------------------------------===//
ir::value *dispatch::program_id(int axis, ir::builder *builder) {
return builder->create_get_program_id(axis);
}
ir::value *dispatch::num_programs(int axis, ir::builder *builder) {
return builder->create_get_num_programs(axis);
}
//===----------------------------------------------------------------------===//
// Implicit Casting Utilities
//===----------------------------------------------------------------------===//
ir::type *integer_promote(ir::type* a_ty, ir::type* b_ty){
int a_rank = a_ty->get_integer_bitwidth();
int b_rank = b_ty->get_integer_bitwidth();
auto a_sn = a_ty->get_integer_signedness();
auto b_sn = b_ty->get_integer_signedness();
// Rules for signedness taken from "Usual arithmetic conversions" on
// https://en.cppreference.com/w/c/language/conversion.
if (a_sn == b_sn) {
return a_rank > b_rank ? a_ty : b_ty;
} else if (a_sn == signedness::UNSIGNED) {
return a_rank >= b_rank ? a_ty : b_ty;
} else if (b_sn == signedness::UNSIGNED) {
return b_rank >= a_rank ? b_ty : a_ty;
} else {
throw_unreachable("integer_promote");
}
}
enum class DivOrMod { NO, YES };
ir::type *computation_type(ir::type* a_ty, ir::type* b_ty, DivOrMod div_or_mod) {
context &ctx = a_ty->get_context();
// 1) if one operand is double, the other is implicitly
// converted to double
if (a_ty->is_fp64_ty() || b_ty->is_fp64_ty())
return type::get_fp64_ty(ctx);
// 2) if one operand is float, the other is implicitly
// converted to float
if (a_ty->is_fp32_ty() || b_ty->is_fp32_ty())
return type::get_fp32_ty(ctx);
// 3 ) if one operand is half, the other is implicitly converted to half
// unless we're doing / or %, which do not exist natively in PTX for fp16.
if (a_ty->is_fp16_ty() || b_ty->is_fp16_ty()) {
if (div_or_mod == DivOrMod::YES) {
return type::get_fp32_ty(ctx);
} else {
return type::get_fp16_ty(ctx);
}
}
if (!a_ty->is_integer_ty() || !b_ty->is_integer_ty())
throw_unreachable("computation_type");
// 4 ) both operands are integer and undergo
// integer promotion
if (div_or_mod == DivOrMod::YES && a_ty->get_integer_signedness() != b_ty->get_integer_signedness()) {
throw semantic_error("Cannot use /, //, or % with " + a_ty->repr() + " and " + b_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
return integer_promote(a_ty, b_ty);
}
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
void throw_incompatible_types(ir::type* type_a, ir::type* type_b) {
throw semantic_error("invalid operands of type " + type_a->repr() + " and " + type_b->repr());
}
void check_ptr_type(ir::type* type_a, ir::type* type_b, bool allow_ptr_a){
if(type_a->is_pointer_ty()){
if(!allow_ptr_a)
throw_incompatible_types(type_a, type_b);
// T* + U* with T != U
if(type_b->is_pointer_ty() && (type_a != type_b))
throw_incompatible_types(type_a, type_b);
// T* + float
if(type_b->is_floating_point_ty())
throw_incompatible_types(type_a, type_b);
}
}
void binary_op_type_checking(ir::value*& lhs, ir::value*& rhs, ir::builder* builder,
bool allow_lhs_ptr = false, bool allow_rhs_ptr = false,
bool arithmetic_check = true, DivOrMod div_or_mod = DivOrMod::NO) {
// implicit broadcasting
std::tie(lhs, rhs) = dispatch::broadcast(lhs, rhs, builder);
// implicit typecasting
ir::type *lhs_sca_ty = lhs->get_type()->get_scalar_ty();
ir::type *rhs_sca_ty = rhs->get_type()->get_scalar_ty();
check_ptr_type(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr);
check_ptr_type(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr);
if (arithmetic_check && !lhs_sca_ty->is_pointer_ty() && !rhs_sca_ty->is_pointer_ty()) {
ir::type *ret_sca_ty = computation_type(lhs_sca_ty, rhs_sca_ty, div_or_mod);
lhs = dispatch::cast(lhs, ret_sca_ty, builder);
rhs = dispatch::cast(rhs, ret_sca_ty, builder);
}
}
ir::value *dispatch::add(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, true);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// offset + ptr
// ptr + offset
if(other_scalar_ty->is_pointer_ty() && !input_scalar_ty->is_pointer_ty())
std::swap(input, other);
if (input_scalar_ty->is_pointer_ty())
return builder->create_gep(input, {other});
// float + float
else if (input_scalar_ty->is_floating_point_ty())
return builder->create_fadd(input, other);
// int + int
else if (input_scalar_ty->is_integer_ty())
return builder->create_add(input, other);
throw_unreachable("add");
}
ir::value *dispatch::sub(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, true, false);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// ptr - offset
if (scalar_ty->is_pointer_ty())
return builder->create_gep(input, {dispatch::minus(other, builder)});
// float - float
if (scalar_ty->is_floating_point_ty())
return builder->create_fsub(input, other);
// int - int
else if (scalar_ty->is_integer_ty())
return builder->create_sub(input, other);
throw_unreachable("sub");
}
ir::value *dispatch::mul(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float * float
if (scalar_ty->is_floating_point_ty())
return builder->create_fmul(input, other);
// int * int
else if (scalar_ty->is_integer_ty())
return builder->create_mul(input, other);
throw_unreachable("mul");
}
ir::value *dispatch::truediv(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float / int
if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_integer_ty())
other = cast(other, input_scalar_ty, builder);
// int / float
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_floating_point_ty())
input = cast(input, other_scalar_ty, builder);
// int / int (cast to float32)
else if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
input = cast(input, builder->get_float_ty(), builder);
other = cast(other, builder->get_float_ty(), builder);
}
// float / float (cast to highest exponent type)
else if(input_scalar_ty->is_floating_point_ty() && other_scalar_ty->is_floating_point_ty()){
if(input_scalar_ty->get_fp_mantissa_width() > other_scalar_ty->get_fp_mantissa_width())
other = cast(other, input_scalar_ty, builder);
else
input = cast(input, other_scalar_ty, builder);
}
// unreachable
else
throw_unreachable("div");
return builder->create_fdiv(input, other);
}
ir::value *dispatch::floordiv(ir::value *input, ir::value *other, ir::builder *builder){
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(input_scalar_ty->is_integer_ty() && other_scalar_ty->is_integer_ty()){
ir::type *ret_ty = integer_promote(input_scalar_ty, other_scalar_ty);
input = dispatch::cast(input, ret_ty, builder);
other = dispatch::cast(other, ret_ty, builder);
if (ret_ty->is_integer_signed()) {
return builder->create_sdiv(input, other);
} else {
return builder->create_udiv(input, other);
}
}
throw_unreachable("floordiv");
}
ir::value *dispatch::fdiv(ir::value *input, ir::value *other, constant_int *ieee_rounding, ir::builder *builder){
ir::type *input_scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
if(!input_scalar_ty->is_floating_point_ty() || !other_scalar_ty->is_floating_point_ty())
throw semantic_error("both operands of fdiv must have floating point scalar type");
binary_op_type_checking(input, other, builder, false, false, false, DivOrMod::YES);
ir::value* ret = builder->create_fdiv(input, other);
if(ir::binary_operator* binop = dynamic_cast<ir::binary_operator*>(ret))
binop->set_fdiv_ieee_rounding(ieee_rounding->get_value());
return ret;
}
ir::value *dispatch::mod(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, true, DivOrMod::YES);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
ir::type *other_scalar_ty = other->get_type()->get_scalar_ty();
// float % float
if (scalar_ty->is_floating_point_ty())
return builder->create_frem(input, other);
// int % int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->get_integer_signedness() != other_scalar_ty->get_integer_signedness()) {
throw semantic_error("Cannot mod " + scalar_ty->repr() + " by " + other_scalar_ty->repr() + " because they have different signedness; this is unlikely to result in a useful answer. Cast them to the same signedness.");
}
if (scalar_ty->is_integer_signed()) {
return builder->create_srem(input, other);
} else {
return builder->create_urem(input, other);
}
}
throw_unreachable("mod");
}
void bitwise_op_type_checking(ir::value *&input, ir::value *&other, ir::builder *builder) {
binary_op_type_checking(input, other, builder, false, false, false);
ir::type *input_sca_ty = input->get_type()->get_scalar_ty();
ir::type *other_sca_ty = other->get_type()->get_scalar_ty();
if(!input_sca_ty->is_integer_ty() || !other_sca_ty->is_integer_ty())
throw_incompatible_types(input_sca_ty, other_sca_ty);
ir::type *ret_sca_ty = integer_promote(input_sca_ty, other_sca_ty);
if (ret_sca_ty != input_sca_ty)
input = dispatch::cast(input, ret_sca_ty, builder);
if (ret_sca_ty != other_sca_ty)
other = dispatch::cast(other, ret_sca_ty, builder);
}
ir::value *dispatch::and_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_and(input, other);
}
ir::value *dispatch::or_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_or(input, other);
}
ir::value *dispatch::xor_(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_xor(input, other);
}
ir::value *dispatch::lshr(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_lshr(input, other);
}
ir::value *dispatch::shl(ir::value *input, ir::value *other, ir::builder *builder) {
bitwise_op_type_checking(input, other, builder);
return builder->create_shl(input, other);
}
//===----------------------------------------------------------------------===//
// Unary Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::plus(ir::value *input, ir::builder *) {
return input;
}
ir::value *dispatch::minus(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty())
throw semantic_error("wrong type argument to unary minus (" + input_sca_ty->repr() + ")");
ir::value *_0 = ir::constant::get_null_value(input_sca_ty);
return dispatch::sub(_0, input, builder);
}
ir::value *dispatch::invert(ir::value *input, ir::builder *builder) {
ir::type* input_sca_ty = input->get_type()->get_scalar_ty();
if(input_sca_ty->is_pointer_ty() || input_sca_ty->is_floating_point_ty())
throw semantic_error("wrong type argument to unary invert (" + input_sca_ty->repr() + ")");
ir::value *_1 = ir::constant::get_all_ones_value(input_sca_ty);
return dispatch::xor_(input, _1, builder);
}
//===----------------------------------------------------------------------===//
// Comparison Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::greater_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float > float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGT(input, other);
// int > int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGT(input, other);
} else {
return builder->create_icmpUGT(input, other);
}
}
throw_unreachable("greater_than");
}
ir::value *dispatch::greater_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float >= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOGE(input, other);
// int >= int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSGE(input, other);
} else {
return builder->create_icmpUGE(input, other);
}
}
throw_unreachable("greater_equal");
}
ir::value *dispatch::less_than(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float < float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLT(input, other);
// int < int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLT(input, other);
} else {
return builder->create_icmpULT(input, other);
}
}
throw_unreachable("less_than");
}
ir::value *dispatch::less_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float <= float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOLE(input, other);
// int <= int
else if (scalar_ty->is_integer_ty()) {
if (scalar_ty->is_integer_signed()) {
return builder->create_icmpSLE(input, other);
} else {
return builder->create_icmpULE(input, other);
}
}
throw_unreachable("less_equal");
}
ir::value *dispatch::equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float == float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpOEQ(input, other);
// int == int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpEQ(input, other);
throw_unreachable("equal");
}
ir::value *dispatch::not_equal(ir::value *input, ir::value *other, ir::builder *builder) {
binary_op_type_checking(input, other, builder);
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// float != float
if (scalar_ty->is_floating_point_ty())
return builder->create_fcmpUNE(input, other);
// int != int
else if (scalar_ty->is_integer_ty())
return builder->create_icmpNE(input, other);
throw_unreachable("equal");
}
//===----------------------------------------------------------------------===//
// Block Creation
//===----------------------------------------------------------------------===//
ir::value* dispatch::arange(int start, int end, ir::builder *builder) {
return builder->get_range(start, end);
}
ir::value* dispatch::zeros(shape_t shape, ir::type *dtype, ir::builder *builder) {
ir::value *_0 = ir::constant::get_null_value(dtype);
return builder->create_splat(_0, shape);
}
//===----------------------------------------------------------------------===//
// Shape Manipulation
//===----------------------------------------------------------------------===//
ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *builder) {
unsigned numel = 1;
for(unsigned s: dst_shape) numel *= s;
if(input->get_type()->get_tile_num_elements() != numel)
throw semantic_error("cannot reshape block of different shape");
return builder->create_reshape(input, dst_shape);
}
ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
return builder->create_cat(lhs, rhs);
}
ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) {
if (!input->get_type()->is_block_ty())
return builder->create_splat(input, shape);
auto src_shape = input->get_type()->get_block_shapes();
if (src_shape.size() != shape.size())
throw std::runtime_error("Cannot broadcast");
if(shape == src_shape)
return input;
return builder->create_broadcast(input, shape);
}
std::tuple<ir::value*, ir::value*> dispatch::broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
// make_shape_compatible(block, scalar)
if (lhs_ty->is_block_ty() && !rhs_ty->is_block_ty())
rhs = builder->create_splat(rhs, lhs_ty->get_block_shapes());
// make_shape_compatible(scalar, block)
else if (!lhs_ty->is_block_ty() && rhs_ty->is_block_ty())
lhs = builder->create_splat(lhs, rhs_ty->get_block_shapes());
// make_shape_compatible(block, block)
else if (lhs_ty->is_block_ty() && rhs_ty->is_block_ty()) {
auto lhs_shape = lhs_ty->get_block_shapes();
auto rhs_shape = rhs_ty->get_block_shapes();
if (lhs_shape.size() != rhs_shape.size())
throw std::runtime_error("Cannot make_shape_compatible: blocks must have the same rank");
ir::type::block_shapes_t ret_shape;
for (size_t i = 0; i < lhs_shape.size(); ++i) {
unsigned left = lhs_shape[i];
unsigned right = rhs_shape[i];
if (left == 1)
ret_shape.push_back(right);
else if (right == 1)
ret_shape.push_back(left);
else if (left == right)
ret_shape.push_back(left);
else
throw std::runtime_error("Cannot make_shape_compatible: incompatible dimensions at index " + std::to_string(i) +
": " + std::to_string(left) + " and " + std::to_string(right));
}
if (lhs_shape != ret_shape)
lhs = builder->create_broadcast(lhs, ret_shape);
if (rhs_shape != ret_shape)
rhs = builder->create_broadcast(rhs, ret_shape);
}
return std::make_tuple(lhs, rhs);
}
ir::value *dispatch::bitcast(ir::value *input, ir::type *dst_ty, ir::builder *builder){
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
if(src_sca_ty->is_pointer_ty() || dst_sca_ty->is_pointer_ty())
return cast(input, dst_ty, builder);
// Bitcast
int src_bits = src_sca_ty->get_primitive_size_in_bits();
int dst_bits = dst_sca_ty->get_primitive_size_in_bits();
if( src_bits!= dst_bits)
throw std::runtime_error("Cannot bitcast data-type of size " + std::to_string(src_bits) +
"to data-type of size " + std::to_string(dst_bits));
return builder->create_cast(ir::BitCast, input, dst_ty);
}
ir::value *dispatch::cast(ir::value *input, ir::type *dst_ty, ir::builder *builder) {
ir::type *src_ty = input->get_type();
if (src_ty->is_block_ty())
dst_ty = ir::block_type::get(dst_ty, input->get_type()->get_block_shapes());
if(src_ty == dst_ty)
return input;
ir::type *src_sca_ty = src_ty->get_scalar_ty();
ir::type *dst_sca_ty = dst_ty->get_scalar_ty();
//
if((src_sca_ty->is_bf16_ty() && !dst_sca_ty->is_fp32_ty()) ||
(dst_sca_ty->is_bf16_ty() && !src_sca_ty->is_fp32_ty())){
return cast(cast(input, builder->get_float_ty(), builder), dst_sca_ty, builder);
}
// FP Truncation
bool truncate_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() > dst_sca_ty->get_fp_mantissa_width();
if (truncate_fp)
return builder->create_fp_trunc(input, dst_ty);
// FP Extension
bool ext_fp = src_sca_ty->is_floating_point_ty() &&
dst_sca_ty->is_floating_point_ty() &&
src_sca_ty->get_fp_mantissa_width() < dst_sca_ty->get_fp_mantissa_width();
if (ext_fp)
return builder->create_fp_ext(input, dst_ty);
// Int cast
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_integer_ty() &&
(src_sca_ty->get_integer_bitwidth() != dst_sca_ty->get_integer_bitwidth() ||
src_sca_ty->get_integer_signedness() != dst_sca_ty->get_integer_signedness())) {
bool sign_extend = src_sca_ty->is_integer_signed() && src_sca_ty != builder->get_int1_ty();
return builder->create_int_cast(input, dst_ty, sign_extend);
}
// Float -> Int
if (src_sca_ty->is_floating_point_ty() && dst_sca_ty->is_integer_ty()){
if(dst_sca_ty->is_bool_ty())
return builder->create_fp_to_ui(input, dst_ty);
else
return builder->create_fp_to_si(input, dst_ty);
}
// int -> Float
if (src_sca_ty->is_integer_ty() && dst_sca_ty->is_floating_point_ty()){
if (src_sca_ty->is_bool_ty() || !src_sca_ty->is_integer_signed())
return builder->create_ui_to_fp(input, dst_ty);
else
return builder->create_si_to_fp(input, dst_ty);
}
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_integer_ty()){
int bitwidth = dst_sca_ty->get_integer_bitwidth();
if(bitwidth == 64)
return builder->create_cast(ir::PtrToInt, input, dst_ty);
if(bitwidth == 1)
return dispatch::not_equal(dispatch::cast(input, builder->get_int64_ty(), builder),
builder->get_int64(0),
builder);
}
if (!src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::IntToPtr, input, dst_ty);
// Ptr -> Ptr
if (src_sca_ty->is_pointer_ty() && dst_sca_ty->is_pointer_ty())
return builder->create_cast(ir::BitCast, input, dst_ty);
// * -> Bool
if (dst_sca_ty->is_bool_ty()) {
if (src_sca_ty->is_pointer_ty())
input = cast(input, builder->get_int64_ty(), builder);
ir::value *other = builder->get_int64(0);
if (src_ty->is_bool_ty())
other = builder->create_splat(other, src_ty->get_block_shapes());
return builder->create_icmpNE(input, other);
}
throw_unreachable("casting from " + src_sca_ty->repr() + " to " + dst_sca_ty->repr());
}
//===----------------------------------------------------------------------===//
// Memory Operators
//===----------------------------------------------------------------------===//
ir::value *dispatch::load(ir::value* ptr, ir::value* mask, ir::value* other, const std::string &cache_modifier, const std::string& eviction_policy, int is_volatile, ir::builder* builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of load instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
if(other)
other = dispatch::broadcast(other, ptr->get_type()->get_block_shapes(), builder);
}
if(other)
other = dispatch::cast(other, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cache modifier
load_inst::CACHE_MODIFIER cache = load_inst::NONE; // default
if (!cache_modifier.empty()) {
if (cache_modifier == ".ca")
cache = load_inst::CA;
else if (cache_modifier == ".cg")
cache = load_inst::CG;
else
throw std::runtime_error(std::string("Cache modifier ") + cache_modifier + " not supported");
}
// eviction policy
load_inst::EVICTION_POLICY eviction = load_inst::NORMAL; //default
if(!eviction_policy.empty()){
if (eviction_policy == "evict_last")
eviction = load_inst::EVICT_LAST;
else if(eviction_policy == "evict_first")
eviction = load_inst::EVICT_FIRST;
else
throw std::runtime_error(std::string("Eviction policy") + eviction_policy + " not supported");
}
if (!mask && !other)
return builder->create_load(ptr, cache, eviction, is_volatile);
if (!mask)
throw std::runtime_error("`other` cannot be provided without `mask`");
auto shape = ptr->get_type()->get_block_shapes();
if(!other){
other = ir::undef_value::get(elt_ty);
if(ptr->get_type()->is_block_ty())
other = builder->create_splat(other, ptr->get_type()->get_block_shapes());
}
return builder->create_masked_load(ptr, mask, other, cache, eviction, is_volatile);
}
ir::value *dispatch::store(ir::value* ptr, ir::value *val, ir::value* mask, ir::builder *builder) {
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty())
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
if(mask)
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
ir::type *ptr_ty = ptr->get_type()->get_scalar_ty();
ir::type *elt_ty = ptr_ty->get_pointer_element_ty();
// treat bool* as int8*
if(elt_ty == builder->get_int1_ty()){
elt_ty = builder->get_int8_ty();
ptr_ty = pointer_type::get(elt_ty, ptr_ty->get_pointer_address_space());
ptr = dispatch::cast(ptr, ptr_ty, builder);
}
// cast to target data-type
val = dispatch::cast(val, elt_ty, builder);
if (!mask)
return builder->create_store(ptr, val);
if(!mask->get_type()->get_scalar_ty()->is_bool_ty())
throw semantic_error("Mask must have boolean scalar type");
return builder->create_masked_store(ptr, val, mask);
}
ir::value *dispatch::atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder){
return builder->create_atomic_cas(ptr, cmp, val);
}
void atom_red_typechecking(ir::value*& ptr, ir::value *&val, ir::value *&mask, ir::builder *builder){
if(!ptr->get_type()->get_scalar_ty()->is_pointer_ty())
throw semantic_error("Pointer argument of store instruction is " + ptr->get_type()->repr());
if(ptr->get_type()->is_block_ty()){
if(mask){
mask = dispatch::broadcast(mask, ptr->get_type()->get_block_shapes(), builder);
}
if(val){
val = dispatch::broadcast(val, ptr->get_type()->get_block_shapes(), builder);
}
}
val = dispatch::cast(val, ptr->get_type()->get_scalar_ty()->get_pointer_element_ty(), builder);
if(!mask){
mask = builder->get_int1(true);
if(ptr->get_type()->is_block_ty())
mask = builder->create_splat(mask, ptr->get_type()->get_block_shapes());
}
}
ir::value *dispatch::atomic_max(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_max for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, ptr, val, mask);
}
}
// for float
// return atomic_smax(i_ptr, i_val) if val >= 0
// return atomic_umin(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Max, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
ir::value *dispatch::atomic_min(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
// direct call to atomic_min for integers
if(sca_ty->is_integer_ty()) {
if (sca_ty->is_integer_signed()) {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, ptr, val, mask);
} else {
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMin, ptr, val, mask);
}
}
// for float
// return atomic_smin(i_ptr, i_val) if val >= 0
// return atomic_umax(i_ptr, i_val) if val < 0
ir::value* i_val = bitcast(val, builder->get_int32_ty(), builder);
ir::value* i_ptr = bitcast(ptr, pointer_type::get(builder->get_int32_ty(), 1), builder);
ir::value* pos = greater_equal(val, constant_fp::get(sca_ty, 0), builder);
ir::value* neg = less_than(val, constant_fp::get(sca_ty, 0), builder);
ir::value* pos_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::Min, i_ptr, i_val, and_(mask, pos, builder));
ir::value* neg_ret = builder->create_atomic_rmw(ir::atomic_rmw_op_t::UMax, i_ptr, i_val, and_(mask, neg, builder));
return where(pos, pos_ret, neg_ret, builder);
}
ir::value *dispatch::atomic_add(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
auto op = sca_ty->is_floating_point_ty() ? ir::atomic_rmw_op_t::FAdd : ir::atomic_rmw_op_t::Add;
return builder->create_atomic_rmw(op, ptr, val, mask);
}
ir::value *dispatch::atomic_and(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::And, ptr, val, mask);
}
ir::value *dispatch::atomic_or(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Or, ptr, val, mask);
}
ir::value *dispatch::atomic_xor(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xor, ptr, val, mask);
}
ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask, ir::builder *builder){
atom_red_typechecking(ptr, val, mask, builder);
ir::type* sca_ty = val->get_type()->get_scalar_ty();
return builder->create_atomic_rmw(ir::atomic_rmw_op_t::Xchg, ptr, val, mask);
}
//===----------------------------------------------------------------------===//
// Linear Algebra
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
ir::value *_0 = nullptr;
if (lhs->get_type()->is_int_or_tileint_ty())
_0 = builder->get_int32(0);
else
_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
bool _allow_tf32 = allow_tf32->get_value() != 0;
return builder->create_dot(lhs, rhs, _0, _allow_tf32);
}
//===----------------------------------------------------------------------===//
// Indexing
//===----------------------------------------------------------------------===//
ir::value *dispatch::where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder){
condition = dispatch::cast(condition, builder->get_int1_ty(), builder);
if(condition->get_type()->is_block_ty()){
x = dispatch::broadcast(x, condition->get_type()->get_block_shapes(), builder);
y = dispatch::broadcast(y, condition->get_type()->get_block_shapes(), builder);
}
ir::type* x_ty = x->get_type()->get_scalar_ty();
ir::type* y_ty = y->get_type()->get_scalar_ty();
ir::type* ty = computation_type(x_ty, y_ty, DivOrMod::NO);
x = dispatch::cast(x, ty, builder);
y = dispatch::cast(y, ty, builder);
return builder->create_select(condition, x, y);
}
//===----------------------------------------------------------------------===//
// Reductions
//===----------------------------------------------------------------------===//
ir::value *reduce_impl(ir::value *input, unsigned int axis, ir::builder *builder, const std::string &name,
ir::reduce_inst::op_t FLOAT_OP, ir::reduce_inst::op_t INT_OP) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
// input is extended to 32-bits if necessary
// this increases numerical accuracy and can be done pretty much for free
// on GPUs
if(scalar_ty->is_integer_ty() && scalar_ty->get_integer_bitwidth() <= 32)
input = dispatch::cast(input, type::get_int32_ty(scalar_ty->get_context()), builder);
if (scalar_ty->is_floating_point_ty())
return builder->create_reduce(input, FLOAT_OP, axis);
else if (scalar_ty->is_integer_ty())
return builder->create_reduce(input, INT_OP, axis);
throw_unreachable(name);
}
ir::value *dispatch::min(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "min", ir::reduce_inst::FMIN, ir::reduce_inst::MIN);
}
ir::value *dispatch::max(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "max", ir::reduce_inst::FMAX, ir::reduce_inst::MAX);
}
ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *builder) {
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::FADD, ir::reduce_inst::ADD);
}
ir::value *dispatch::xor_sum(ir::value *input, unsigned int axis, ir::builder *builder) {
ir::type *scalar_ty = input->get_type()->get_scalar_ty();
if (!scalar_ty->is_integer_ty())
throw semantic_error("xor_sum only supported for integers");
return reduce_impl(input, axis, builder, "sum", ir::reduce_inst::XOR, ir::reduce_inst::XOR);
}
//===----------------------------------------------------------------------===//
// Math
//===----------------------------------------------------------------------===//
ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) {
binary_op_type_checking(x, y, builder);
return builder->insert(umulhi_inst::create(x, y));
}
ir::value *dispatch::exp(ir::value *x, ir::builder *builder) {
return builder->create_exp(x);
}
ir::value *dispatch::log(ir::value *x, ir::builder *builder) {
return builder->create_log(x);
}
ir::value *dispatch::cos(ir::value *x, ir::builder *builder) {
return builder->create_cos(x);
}
ir::value *dispatch::sin(ir::value *x, ir::builder *builder) {
return builder->create_sin(x);
}
ir::value *dispatch::sqrt(ir::value *x, ir::builder *builder) {
return builder->create_sqrt(x);
}
//
ir::value *dispatch::globaltimer(ir::builder *builder) {
return builder->insert(globaltimer_inst::create(builder->get_context()));
}
ir::value *dispatch::clock(ir::builder *builder) {
return builder->insert(clock_inst::create(builder->get_context()));
}
//===----------------------------------------------------------------------===//
// Control Flow
//===----------------------------------------------------------------------===//
//
ir::value *dispatch::multiple_of(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("multiple_of");
i->set_metadata(ir::metadata::multiple_of, value);
return i;
}
ir::value *dispatch::max_contiguous(ir::value *x, int value, ir::builder *){
ir::instruction* i = dynamic_cast<ir::instruction*>(x);
if(!i)
throw_unreachable("max_contiguous");
i->set_metadata(ir::metadata::max_contiguous, value);
return i;
}
ir::value *dispatch::debug_barrier(ir::builder *builder) {
return builder->create_barrier();
}
}
}
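
The file deleted above (dispatch.cc) held the frontend's semantic rules, which this commit re-hosts on the Python side. As a rough, self-contained sketch only, here is the common-dtype rule for floating-point binary operations that computation_type implemented; the dtype strings and the function name are assumptions for illustration, not the actual semantic.py code.

def computation_dtype(a, b, div_or_mod=False):
    """Common result dtype for a binary op whose operands include a float.

    Mirrors the deleted C++ computation_type: fp64 dominates, then fp32;
    fp16 is preserved except for / and %, which fall back to fp32 because
    PTX has no native fp16 division or remainder.
    """
    order = {"fp16": 0, "fp32": 1, "fp64": 2}
    if a in order or b in order:
        widest = ["fp16", "fp32", "fp64"][max(order.get(a, -1), order.get(b, -1))]
        return "fp32" if (widest == "fp16" and div_or_mod) else widest
    # Both operands are integers: defer to the signedness-aware promotion
    # sketched near the top of this commit message.
    raise NotImplementedError("integer operands use integer_promote instead")

print(computation_dtype("fp16", "int32", div_or_mod=True))  # fp32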

View File

@@ -377,8 +377,8 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
unsigned arg_bits = arg_ty->get_scalar_ty()->get_integer_bitwidth();
unsigned dst_bits = ty->get_scalar_ty()->get_integer_bitwidth();
cast_op_t op = (arg_bits == dst_bits ? cast_op_t::BitCast :
(arg_bits > dst_bits ? cast_op_t::Trunc :
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
(arg_bits > dst_bits ? cast_op_t::Trunc :
(is_signed ? cast_op_t::SExt : cast_op_t::ZExt)));
return create(op, arg, ty, name, next);
}

View File

@@ -9,154 +9,10 @@
namespace triton{
namespace ir{
/* */
value_constructor::value_constructor(ir::builder& builder): builder_(builder){
sealed_blocks_.insert(nullptr);
}
void value_constructor::set_value(const std::string& name, ir::basic_block *block, ir::value *value){
values_[val_key_t{name, block}] = value;
auto it = metadatas_.find(name);
if(auto *x = dynamic_cast<ir::instruction*>(value))
if(it != metadatas_.end()){
x->set_metadata(it->second.first, it->second.second);
}
// value->set_name(name);
}
void value_constructor::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
ir::phi_node* value_constructor::make_phi(ir::type *ty, unsigned num_values, ir::basic_block *block){
basic_block::iterator insert = block->get_first_non_phi();
if(insert != block->end()){
builder_.set_insert_point(insert);
}
ir::phi_node *res = builder_.create_phi(ty, num_values);
if(insert != block->end())
builder_.set_insert_point(block);
return res;
}
ir::value *value_constructor::try_remove_trivial_phis(ir::phi_node *&phi){
// find non-self references
std::set<ir::value*> non_self_ref;
std::copy_if(phi->ops().begin(), phi->ops().end(), std::inserter(non_self_ref, non_self_ref.begin()),
[phi](ir::value* op){ return op != phi && op; });
// non-trivial
if(non_self_ref.size() != 1)
return phi;
// unique value or self-reference
ir::value *same = *non_self_ref.begin();
assert(same != nullptr);
phi->replace_all_uses_with(same);
phi->erase_from_parent();
std::vector<ir::user*> users = phi->get_users();
for(ir::user* u: users)
if(auto *uphi = dynamic_cast<ir::phi_node*>(u))
if(uphi != phi)
try_remove_trivial_phis(uphi);
return same;
}
ir::value *value_constructor::add_phi_operands(const std::string& name, ir::phi_node *&phi){
// already initialized
if(phi->get_num_operands())
return phi;
ir::basic_block *block = phi->get_parent();
for(ir::basic_block *pred: block->get_predecessors()){
ir::value *value = get_value(name, pred);
phi->add_incoming(value, pred);
}
return phi;
}
ir::value *value_constructor::get_value_recursive(const std::string& name, ir::basic_block *block) {
ir::value *result;
auto preds = block->get_predecessors();
ir::type *ty = types_.at(name);
if(block && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
result = (ir::value*)incomplete_phis_[block][name];
}
else if(preds.size() <= 1){
bool has_pred = preds.size();
result = get_value(name, has_pred?preds.front():nullptr);
}
else{
ir::phi_node* phi = make_phi(ty, 1, block);
set_value(name, block, phi);
result = add_phi_operands(name, phi);
if(auto *phi = dynamic_cast<ir::phi_node*>(result))
result = try_remove_trivial_phis(phi);
}
if(auto *phi = dynamic_cast<ir::phi_node*>(result)){
result = try_remove_trivial_phis(phi);
}
set_value(name, block, result);
return result;
}
ir::value *value_constructor::get_value(const std::string& name, ir::basic_block *block) {
ir::basic_block* save_block = builder_.get_insert_block();
ir::basic_block::iterator save_pt = builder_.get_insert_point();
val_key_t key(name, block);
// std::cout << values_.size() << std::endl;
// std::cout << name << " " << block << " " << values_.begin()->first.first << " " << values_.begin()->first.second << std::endl;
if(values_.find(key) != values_.end()){
return values_.at(key);
}
ir::value *result = get_value_recursive(name, block);
builder_.set_insert_point(save_block);
if(save_pt != save_block->end())
builder_.set_insert_point(save_pt);
return result;
}
ir::value *value_constructor::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
void value_constructor::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
if(get_value(x.first) == x.second)
set_value(x.first, try_remove_trivial_phis(x.second));
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();
}
/* Module */
module::module(const std::string &name, builder &builder)
: name_(name), builder_(builder) {
}
void module::reset_ret_ty(const std::string& name, type* ty) {
get_function(name)->get_fn_type()->reset_ret_ty(ty);
}
ir::builder& module::get_builder() {
return builder_;
}
void module::set_continue_fn(std::function<ir::value*()> fn) {
continue_fn_ = fn;
}
std::function<ir::value*()> module::get_continue_fn() {
return continue_fn_;
}
const std::string& module::get_name() {
return name_;
}
/* functions */
function *module::get_or_insert_function(const std::string &name, function_type *ty) {
function *&fn = (function*&)symbols_[name];

View File

@@ -36,16 +36,6 @@ unsigned type::get_primitive_size_in_bits() const {
unsigned type::get_integer_bitwidth() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
signedness type::get_integer_signedness() const
{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_signedness(); }
bool type::is_integer_signed() const {
if (id_ != IntegerTyID) {
throw std::logic_error("type is " + repr() + ", not integer");
}
return ((integer_type*)(this))->get_signedness() == signedness::SIGNED;
}
unsigned type::get_tile_bitwidth() const
{ return ((block_type*)(this))->get_bitwidth(); }
@@ -145,10 +135,6 @@ integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
integer_type *type::get_uint8_ty(context &ctx) { return &ctx.p_impl->uint8_ty; }
integer_type *type::get_uint16_ty(context &ctx) { return &ctx.p_impl->uint16_ty; }
integer_type *type::get_uint32_ty(context &ctx) { return &ctx.p_impl->uint32_ty; }
integer_type *type::get_uint64_ty(context &ctx) { return &ctx.p_impl->uint64_ty; }