From ae2a1ab225335b054e04f6baaf537184bb337a15 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 25 Apr 2022 21:16:00 -0700
Subject: [PATCH] [BACKEND] Alignment pass improvements (#503)

---
 lib/codegen/analysis/align.cc          | 33 +++++++++++++++++---------
 lib/codegen/selection/generator.cc     |  1 +
 python/test/unit/language/test_core.py |  4 +++-
 python/triton/code_gen.py              |  6 ++---
 4 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc
index bd68755f1..8dabbaf21 100644
--- a/lib/codegen/analysis/align.cc
+++ b/lib/codegen/analysis/align.cc
@@ -142,14 +142,17 @@ std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
   auto rhs_multiple_of = populate_starting_multiple(rhs_op);
   for(size_t d = 0; d < x_shapes.size(); d++) {
     cst_info ax = {1, 0};
-    // if lhs (resp. rhs) is a range of M value starting at a multiple of N
-    // and rhs (resp. lhs) is made of M constants that are multiples of N
-    // then comparisons have M constants
-    int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]);
-    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0)
-      ax = {std::min(min_multiple, lhs_max_contiguous[d]), 0};
-    else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0)
-      ax = {std::min(min_multiple, rhs_max_contiguous[d]), 0};
+    // Examples:
+    //   16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
+    //   16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
+    //   16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
+    //
+    // if LHS is a range of N continuous (or equal) elements that starts at M,
+    // and RHS is a set of N constants that start at K
+    // then the result in constant in groups of gcd(M, K)
+    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
+       rhs[d].num_cst % lhs[d].num_cst == 0)
+      ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
     result.push_back(ax);
   }
   return add_to_cache(x, result, is_constant_);
@@ -170,7 +173,6 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
   for(size_t d = 0; d < x_shapes.size(); d++) {
     cst_info ax;
     if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
-      // todo might not be entirely true
       unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
       ax = {num_constants, 0};
     }
@@ -433,7 +435,7 @@ std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operato
     if(x->is_int_add_sub())
       result[d] = gcd(lhs[d], rhs[d]);
     if(x->is_int_div())
-      result[d] = 1;
+      result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
     if(x->is_int_rem() && rhs[d] > 1){
       result[d] = gcd(lhs[d], rhs[d]);
     }
@@ -503,6 +505,15 @@ std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
   return add_to_cache(v, {1}, starting_multiple_);
 }
 
+unsigned get_max_multiple(int val){
+  if(val == 0) return 1 << 31;
+  if(val % 16 == 0) return 16;
+  if(val % 8 == 0) return 8;
+  if(val % 4 == 0) return 4;
+  if(val % 2 == 0) return 2;
+  return 1;
+}
+
 std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
   if(starting_multiple_.find(v) != starting_multiple_.end())
     return starting_multiple_.at(v);
@@ -518,7 +529,7 @@ std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
   if(auto *x = dynamic_cast<ir::constant_int*>(v))
     return add_to_cache(x, {std::min<unsigned>(x->get_value(), 128)}, starting_multiple_);
   if(auto *x = dynamic_cast<ir::make_range*>(v))
-    return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_);
+    return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
   if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
     return populate_starting_multiple_gep(x);
   if(auto *x = dynamic_cast(v))
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 03533e559..c6f064ea8 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -785,6 +785,7 @@ void generator::visit_load_inst(ir::load_inst* x){
   int width = std::min(tot_width, max_word_width);
   int n_words = std::max(1, tot_width / width);
   bool has_evict_policy = (x->get_eviction_policy() != ir::load_inst::NORMAL) && tgt_->as_nvidia()->sm() >= 80;
+  has_evict_policy = false; // currently disable until supported in `store`
   // -----
   // create inline asm string
   // -----
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a7f27eaba..9a997d661 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -937,13 +937,15 @@ def test_load_cache_modifier(cache):
     assert 'ld.global.ca' in ptx
     assert 'ld.global.cg' not in ptx
 
+
 @pytest.mark.parametrize("N", [8, 10, 11, 1024])
 def test_vectorization(N):
     src = torch.empty(1024, device='cuda')
     dst = torch.empty(1024, device='cuda')
+
     @triton.jit
     def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr):
-        offsets = tl.program_id(0)*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
         x = tl.load(src + offsets, mask=offsets < N)
         tl.store(dst + offsets, x, mask=offsets < N)
     pgm = _kernel[(1,)](dst, src, N=N, BLOCK_SIZE=src.shape[0])
diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py
index 5fd1c1be6..711cc87ac 100644
--- a/python/triton/code_gen.py
+++ b/python/triton/code_gen.py
@@ -942,9 +942,9 @@ class Kernel:
                 assert _type == triton.language.constexpr, "only constexpr annotations are supported for now"
                 wargs[pos] = _type(wargs[pos])
         # check that tensors are on GPU.
-        for arg in wargs:
-            if hasattr(arg, 'data_ptr'):
-                assert arg.is_cuda, "All tensors must be on GPU!"
+        # for arg in wargs:
+        #     if hasattr(arg, 'data_ptr'):
+        #         assert arg.is_cuda, "All tensors must be on GPU!"
         # set device (i.e., make sure torch has the context initialized)
         device = torch.cuda.current_device()
         torch.cuda.set_device(device)
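
The rule described in the new populate_is_constant_cmp comment and the get_max_multiple helper can be sanity-checked outside of Triton. Below is a minimal, standalone Python sketch of that reasoning; it is an illustration only: get_max_multiple mirrors the C++ helper added above, while constant_group_size and the driver loop are hypothetical names used for this example, not Triton APIs.

from math import gcd


def get_max_multiple(val: int) -> int:
    # Mirrors the C++ helper added in align.cc: the largest power of two
    # (capped at 16) that divides `val`; 0 is treated as maximally aligned.
    if val == 0:
        return 1 << 31
    for m in (16, 8, 4, 2):
        if val % m == 0:
            return m
    return 1


def constant_group_size(lhs_start: int, rhs_value: int) -> int:
    # If LHS is the contiguous range [lhs_start, lhs_start + n) and RHS is the
    # constant rhs_value broadcast n times, then LHS < RHS changes value only at
    # lane rhs_value - lhs_start, which is a multiple of gcd(lhs_start, rhs_value),
    # so the comparison is constant within aligned groups of that size.
    return gcd(lhs_start, rhs_value)


# The three examples from the patch comment (range written as 16..31 here):
#   16 17 18 ... 31 < 24 24 24 ... 24  => constant in groups of 8
#   16 17 18 ... 31 < 20 20 20 ... 20  => constant in groups of 4
#   16 17 18 ... 31 < 16 16 16 ... 16  => constant in groups of 16
for rhs in (24, 20, 16):
    group = constant_group_size(16, rhs)
    result = [lhs < rhs for lhs in range(16, 32)]
    # every aligned group of `group` lanes holds a single constant value
    assert all(len(set(result[i:i + group])) == 1
               for i in range(0, len(result), group))
    print(f"rhs={rhs}: comparison constant in groups of {group}")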