[BACKEND] Alignment pass improvements (#503)
@@ -142,14 +142,17 @@ std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
   auto rhs_multiple_of = populate_starting_multiple(rhs_op);
   for(size_t d = 0; d < x_shapes.size(); d++) {
     cst_info ax = {1, 0};
-    // if lhs (resp. rhs) is a range of M value starting at a multiple of N
-    // and rhs (resp. lhs) is made of M constants that are multiples of N
-    // then comparisons have M constants
-    int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]);
-    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0)
-      ax = {std::min<int>(min_multiple, lhs_max_contiguous[d]), 0};
-    else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0)
-      ax = {std::min<int>(min_multiple, rhs_max_contiguous[d]), 0};
+    // Examples:
+    // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
+    // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
+    // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
+    //
+    // if LHS is a range of N continuous (or equal) elements that starts at M,
+    // and RHS is a set of N constants that start at K
+    // then the result in constant in groups of gcd(M, K)
+    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
+       rhs[d].num_cst % lhs[d].num_cst == 0)
+      ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
     result.push_back(ax);
   }
   return add_to_cache(x, result, is_constant_);
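A minimal standalone sketch of the rule the new comments describe (not part of the patch; all names and values below are illustrative): comparing a contiguous range that starts at M against a constant K gives results that are constant in groups of gcd(M, K), e.g. 16 17 ... 31 < 24 24 ... 24 is constant in groups of 8.

import math

start, rhs_const, n = 16, 24, 16              # illustrative values from the comments
lhs = list(range(start, start + n))           # contiguous range starting at M = 16
group = math.gcd(start, rhs_const)            # expected group size: gcd(16, 24) = 8
for i in range(0, n, group):
    chunk = [x < rhs_const for x in lhs[i:i + group]]
    assert all(c == chunk[0] for c in chunk)  # each group has a single comparison result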
@@ -170,7 +173,6 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
   for(size_t d = 0; d < x_shapes.size(); d++) {
     cst_info ax;
     if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
-      // todo might not be entirely true
       unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
       ax = {num_constants, 0};
     }
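For the surviving gcd(lhs_max_contiguous[d], rhs[d].value) branch above, a small standalone sketch (not part of the patch; the range and divisor are illustrative) of why integer division of a contiguous, suitably aligned range by a constant is constant in groups of that gcd:

import math

contig, divisor = 16, 4                       # illustrative contiguity and divisor
lhs = list(range(8, 8 + contig))              # contiguous range aligned to the divisor
group = math.gcd(contig, divisor)             # expected group size: 4
for i in range(0, contig, group):
    chunk = [x // divisor for x in lhs[i:i + group]]
    assert all(c == chunk[0] for c in chunk)  # division result is constant per group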
@@ -433,7 +435,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
     if(x->is_int_add_sub())
       result[d] = gcd(lhs[d], rhs[d]);
     if(x->is_int_div())
-      result[d] = 1;
+      result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
     if(x->is_int_rem() && rhs[d] > 1){
       result[d] = gcd(lhs[d], rhs[d]);
     }
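The new ternary preserves the 1 << 31 sentinel, which get_max_multiple below assigns to a value known to be 0. The reasoning, in a trivial sketch (not part of the patch): a zero dividend stays zero, hence remains a multiple of everything, while any other dividend may lose all alignment information after division, so its starting multiple falls back to 1.

for divisor in (2, 3, 16, 128):
    assert 0 // divisor == 0   # 0 divided by anything is still 0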
@@ -503,6 +505,15 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
   return add_to_cache(v, {1}, starting_multiple_);
 }
 
+unsigned get_max_multiple(int val){
+  if(val == 0) return 1 << 31;
+  if(val % 16 == 0) return 16;
+  if(val % 8 == 0) return 8;
+  if(val % 4 == 0) return 4;
+  if(val % 2 == 0) return 2;
+  return 1;
+}
+
 std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
   if(starting_multiple_.find(v) != starting_multiple_.end())
     return starting_multiple_.at(v);
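A quick standalone check of what the new helper returns (not part of the patch; the Python transcription below mirrors the C++ above): the largest power-of-two divisor of val, capped at 16, with 1 << 31 acting as a sentinel for 0, which is a multiple of everything.

def get_max_multiple(val):                    # Python transcription of the helper above
    if val == 0: return 1 << 31
    for p in (16, 8, 4, 2):
        if val % p == 0: return p
    return 1

assert get_max_multiple(0) == 1 << 31         # zero: multiple of everything (sentinel)
assert get_max_multiple(48) == 16             # capped at 16
assert get_max_multiple(24) == 8
assert get_max_multiple(10) == 2
assert get_max_multiple(7) == 1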
@@ -518,7 +529,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
   if(auto *x = dynamic_cast<ir::constant_int*>(v))
     return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
   if(auto *x = dynamic_cast<ir::make_range*>(v))
-    return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
+    return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
   if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
     return populate_starting_multiple_gep(x);
   if(auto *x = dynamic_cast<ir::splat_inst*>(v))
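Effect of the make_range change above: a range whose first element is, say, 24 now reports a starting multiple of get_max_multiple(24) == 8 (see the check earlier) rather than 24 itself, so the reported multiple is always a power of two (or the zero sentinel) instead of an arbitrary starting value.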
@@ -785,6 +785,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   int width = std::min(tot_width, max_word_width);
   int n_words = std::max(1, tot_width / width);
   bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
+  has_evict_policy = false; // currently disable until supported in `store`
   // -----
   // create inline asm string
   // -----
@@ -937,13 +937,15 @@ def test_load_cache_modifier(cache):
     assert 'ld.global.ca' in ptx
     assert 'ld.global.cg' not in ptx
 
+
 @pytest.mark.parametrize("N", [8, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')
+
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         x = tl.load(src + offsets, mask=offsets < N)
         tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
@@ -942,9 +942,9 @@ class Kernel:
             assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
             wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
-        for arg in wargs:
-            if hasattr(arg, 'data_ptr'):
-                assert arg.is_cuda, "All tensors must be on GPU!"
+        # for arg in wargs:
+        #     if hasattr(arg, 'data_ptr'):
+        #         assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)