[CODEGEN] Some compiler improvements (#349)
This commit is contained in:
@@ -198,21 +198,24 @@ scanline_layout::scanline_layout(size_t num_warps,
|
|||||||
bool is_dot = std::any_of(values.begin(), values.end(),
|
bool is_dot = std::any_of(values.begin(), values.end(),
|
||||||
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
[&](ir::value* v) { return dynamic_cast<ir::dot_inst*>(v); });
|
||||||
|
|
||||||
ir::value *ptr = nullptr;
|
|
||||||
|
|
||||||
|
std::vector<ir::value*> ptrs;
|
||||||
for(ir::value *v: values)
|
for(ir::value *v: values)
|
||||||
for(ir::user *usr: v->get_users())
|
for(ir::user *usr: v->get_users())
|
||||||
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
|
if(auto *io = dynamic_cast<ir::io_inst*>(usr)){
|
||||||
if(!ptr || ptr->get_type()->get_tile_rank() < io->get_pointer_operand()->get_type()->get_tile_rank())
|
if(ptrs.empty() || ptrs[0]->get_type()->get_tile_rank() <= io->get_pointer_operand()->get_type()->get_tile_rank())
|
||||||
ptr = io->get_pointer_operand();
|
ptrs.push_back(io->get_pointer_operand());
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned i = order_[0];
|
unsigned i = order_[0];
|
||||||
int contiguous = 1;
|
int contiguous = 1;
|
||||||
if(ptr){
|
for(ir::value* ptr: ptrs){
|
||||||
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
|
int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
|
||||||
contiguous = std::min<int>(align->get(ptr, i), 128 / nbits);
|
contiguous = std::max<int>(contiguous, std::min<int>(align->get(ptr, i), 128 / nbits));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
|
||||||
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
|
||||||
size /= shape_[i];
|
size /= shape_[i];
|
||||||
|
@@ -100,9 +100,9 @@ def uint32_to_uniform_float(x):
|
|||||||
This is originally designed for uint32, but it works with int32 too as long as the int32 uniformly
|
This is originally designed for uint32, but it works with int32 too as long as the int32 uniformly
|
||||||
covers all the possible values it can take.
|
covers all the possible values it can take.
|
||||||
"""
|
"""
|
||||||
max = 2147483647.
|
max = 4.656613e-10 # = 1/MAX_INT = 1/2147483647.
|
||||||
x = tl.where(x < 0, -x - 1, x)
|
x = tl.where(x < 0, -x - 1, x)
|
||||||
return x / max
|
return x * max
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def pair_uniform_to_normal(u1, u2):
|
def pair_uniform_to_normal(u1, u2):
|
||||||
|
Reference in New Issue
Block a user