[BACKEND] Alignment pass improvements (#503)

This commit is contained in:
Philippe Tillet
2022-04-25 21:16:00 -07:00
committed by GitHub
parent 7d544799a0
commit ae2a1ab225
4 changed files with 29 additions and 15 deletions

View File

@@ -142,14 +142,17 @@ std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
auto rhs_multiple_of = populate_starting_multiple(rhs_op); auto rhs_multiple_of = populate_starting_multiple(rhs_op);
for(size_t d = 0; d < x_shapes.size(); d++) { for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax = {1, 0}; cst_info ax = {1, 0};
// if lhs (resp. rhs) is a range of M value starting at a multiple of N // Examples:
// and rhs (resp. lhs) is made of M constants that are multiples of N // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
// then comparisons have M constants // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]); // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0) //
ax = {std::min<int>(min_multiple, lhs_max_contiguous[d]), 0}; // if LHS is a range of N continuous (or equal) elements that starts at M,
else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0) // and RHS is a set of N constants that start at K
ax = {std::min<int>(min_multiple, rhs_max_contiguous[d]), 0}; // then the result in constant in groups of gcd(M, K)
if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
rhs[d].num_cst % lhs[d].num_cst == 0)
ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
result.push_back(ax); result.push_back(ax);
} }
return add_to_cache(x, result, is_constant_); return add_to_cache(x, result, is_constant_);
@@ -170,7 +173,6 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
for(size_t d = 0; d < x_shapes.size(); d++) { for(size_t d = 0; d < x_shapes.size(); d++) {
cst_info ax; cst_info ax;
if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
// todo might not be entirely true
unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value); unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
ax = {num_constants, 0}; ax = {num_constants, 0};
} }
@@ -433,7 +435,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
if(x->is_int_add_sub()) if(x->is_int_add_sub())
result[d] = gcd(lhs[d], rhs[d]); result[d] = gcd(lhs[d], rhs[d]);
if(x->is_int_div()) if(x->is_int_div())
result[d] = 1; result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
if(x->is_int_rem() && rhs[d] > 1){ if(x->is_int_rem() && rhs[d] > 1){
result[d] = gcd(lhs[d], rhs[d]); result[d] = gcd(lhs[d], rhs[d]);
} }
@@ -503,6 +505,15 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
return add_to_cache(v, {1}, starting_multiple_); return add_to_cache(v, {1}, starting_multiple_);
} }
unsigned get_max_multiple(int val){
if(val == 0) return 1 << 31;
if(val % 16 == 0) return 16;
if(val % 8 == 0) return 8;
if(val % 4 == 0) return 4;
if(val % 2 == 0) return 2;
return 1;
}
std::vector<unsigned> align::populate_starting_multiple(ir::value *v){ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(starting_multiple_.find(v) != starting_multiple_.end()) if(starting_multiple_.find(v) != starting_multiple_.end())
return starting_multiple_.at(v); return starting_multiple_.at(v);
@@ -518,7 +529,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::constant_int*>(v)) if(auto *x = dynamic_cast<ir::constant_int*>(v))
return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_); return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
if(auto *x = dynamic_cast<ir::make_range*>(v)) if(auto *x = dynamic_cast<ir::make_range*>(v))
return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
return populate_starting_multiple_gep(x); return populate_starting_multiple_gep(x);
if(auto *x = dynamic_cast<ir::splat_inst*>(v)) if(auto *x = dynamic_cast<ir::splat_inst*>(v))

View File

@@ -785,6 +785,7 @@ void generator::visit_load_inst(ir::load_inst* x){
int width = std::min(tot_width, max_word_width); int width = std::min(tot_width, max_word_width);
int n_words = std::max(1, tot_width / width); int n_words = std::max(1, tot_width / width);
bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80; bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
has_evict_policy = false; // currently disable until supported in `store`
// ----- // -----
// create inline asm string // create inline asm string
// ----- // -----

View File

@@ -937,10 +937,12 @@ def test_load_cache_modifier(cache):
assert 'ld.global.ca' in ptx assert 'ld.global.ca' in ptx
assert 'ld.global.cg' not in ptx assert 'ld.global.cg' not in ptx
@pytest.mark.parametrize("N", [8, 10, 11, 1024]) @pytest.mark.parametrize("N", [8, 10, 11, 1024])
def test_vectorization(N): def test_vectorization(N):
src = torch.empty(1024, device='cuda') src = torch.empty(1024, device='cuda')
dst = torch.empty(1024, device='cuda') dst = torch.empty(1024, device='cuda')
@triton.jit @triton.jit
def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

View File

@@ -942,9 +942,9 @@ class Kernel:
assert _type == triton.language.constexpr, "only constexpr annotations are supported for now" assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
wargs[pos] = _type(wargs[pos]) wargs[pos] = _type(wargs[pos])
# check that tensors are on GPU. # check that tensors are on GPU.
for arg in wargs: # for arg in wargs:
if hasattr(arg, 'data_ptr'): # if hasattr(arg, 'data_ptr'):
assert arg.is_cuda, "All tensors must be on GPU!" # assert arg.is_cuda, "All tensors must be on GPU!"
# set device (i.e., make sure torch has the context initialized) # set device (i.e., make sure torch has the context initialized)
device = torch.cuda.current_device() device = torch.cuda.current_device()
torch.cuda.set_device(device) torch.cuda.set_device(device)