[CODEGEN] Some compiler improvements (#349)

This commit is contained in:
Philippe Tillet
2021-10-13 17:49:39 -07:00
committed by GitHub
parent c2e6b90ff1
commit 9b32075062
2 changed files with 13 additions and 10 deletions

View File

@@ -198,21 +198,24 @@ scanline_layout::scanline_layout(size_t num_warps,
bool is_dot = std::any_of(values.begin(), values.end(), bool is_dot = std::any_of(values.begin(), values.end(),
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); }); [&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
ir::value *ptr = nullptr;
std::vector<ir::value*> ptrs;
for(ir::value *v: values) for(ir::value *v: values)
for(ir::user *usr: v->get_users()) for(ir::user *usr: v->get_users())
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){ if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank()) if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
ptr = io->get_pointer_operand(); ptrs.push_back(io->get_pointer_operand());
} }
unsigned i = order_[0]; unsigned i = order_[0];
int contiguous = 1; int contiguous = 1;
if(ptr){ for(ir::value* ptr: ptrs){
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits(); int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits); contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
} }
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i])); nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
size /= shape_[i]; size /= shape_[i];

View File

@@ -100,9 +100,9 @@ def uint32_to_uniform_float(x):
This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly This is originally designed from uint32, but it works with int32 too as long as the int32 uniformly
covers all the possible values it can take. covers all the possible values it can take.
""" """
max = 2147483647. max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
x = tl.where(x < 0, -x - 1, x) x = tl.where(x < 0, -x - 1, x)
return x / max return x * max
@triton.jit @triton.jit
def pair_uniform_to_normal(u1, u2): def pair_uniform_to_normal(u1, u2):