From bda209002e2bb758b78b74f61b70083a44f0e695 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 23 Apr 2022 13:18:33 -0700 Subject: [PATCH] [BACKEND][CODEGEN] vectorization bugfix (#502) --- include/triton/codegen/analysis/align.h | 3 + lib/codegen/analysis/align.cc | 41 ++++++++++- lib/codegen/selection/generator.cc | 10 +++ python/test/unit/runtime/test_comm.py | 98 ------------------------- 4 files changed, 51 insertions(+), 101 deletions(-) delete mode 100644 python/test/unit/runtime/test_comm.py diff --git a/include/triton/codegen/analysis/align.h b/include/triton/codegen/analysis/align.h index 2393603cb..513868aea 100644 --- a/include/triton/codegen/analysis/align.h +++ b/include/triton/codegen/analysis/align.h @@ -12,6 +12,7 @@ namespace ir { class phi_node; class splat_inst; class cast_inst; + class cmp_inst; class reshape_inst; class broadcast_inst; class binary_operator; @@ -35,6 +36,7 @@ private: std::vector populate_is_constant_reshape(ir::reshape_inst* x); std::vector populate_is_constant_broadcast(ir::broadcast_inst* x); std::vector populate_is_constant_binop(ir::binary_operator* x); + std::vector populate_is_constant_cmp(ir::cmp_inst* x); std::vector populate_is_constant_gep(ir::getelementptr_inst* x); std::vector populate_is_constant_default(ir::value* v); std::vector populate_is_constant(ir::value *v); @@ -65,6 +67,7 @@ public: void run(ir::module &mod); unsigned get(ir::value* v, unsigned ax) const; std::vector contiguous(ir::value* v) const; + std::vector get_cst_info(ir::value* v) const; private: std::map> is_constant_; diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index e92d3b6ee..bd68755f1 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -129,6 +129,33 @@ std::vector align::populate_is_constant_broadcast(ir::broadcast return add_to_cache(x, result, is_constant_); } +std::vector align::populate_is_constant_cmp(ir::cmp_inst* x) { + auto x_shapes = get_shapes(x); + std::vector result; + ir::value* lhs_op = x->get_operand(0); + ir::value* rhs_op = x->get_operand(1); + auto lhs = populate_is_constant(lhs_op); + auto rhs = populate_is_constant(rhs_op); + auto lhs_max_contiguous = populate_max_contiguous(lhs_op); + auto rhs_max_contiguous = populate_max_contiguous(rhs_op); + auto lhs_multiple_of = populate_starting_multiple(lhs_op); + auto rhs_multiple_of = populate_starting_multiple(rhs_op); + for(size_t d = 0; d < x_shapes.size(); d++) { + cst_info ax = {1, 0}; + // if lhs (resp. rhs) is a range of M value starting at a multiple of N + // and rhs (resp. 
lhs) is made of M constants that are multiples of N + // then comparisons have M constants + int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]); + if(rhs[d].num_cst % lhs_max_contiguous[d] == 0) + ax = {std::min(min_multiple, lhs_max_contiguous[d]), 0}; + else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0) + ax = {std::min(min_multiple, rhs_max_contiguous[d]), 0}; + result.push_back(ax); + } + return add_to_cache(x, result, is_constant_); +} + + std::vector align::populate_is_constant_binop(ir::binary_operator* x) { auto x_shapes = get_shapes(x); std::vector result; @@ -136,12 +163,15 @@ std::vector align::populate_is_constant_binop(ir::binary_operat ir::value* rhs_op = x->get_operand(1); auto lhs = populate_is_constant(lhs_op); auto rhs = populate_is_constant(rhs_op); - auto max_contiguous = populate_max_contiguous(lhs_op); + auto lhs_max_contiguous = populate_max_contiguous(lhs_op); + auto rhs_max_contiguous = populate_max_contiguous(rhs_op); + auto lhs_multiple_of = populate_starting_multiple(lhs_op); + auto rhs_multiple_of = populate_starting_multiple(rhs_op); for(size_t d = 0; d < x_shapes.size(); d++) { cst_info ax; if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ // todo might not be entirely true - unsigned num_constants = gcd(max_contiguous[d], rhs[d].value); + unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value); ax = {num_constants, 0}; } else @@ -184,6 +214,8 @@ std::vector align::populate_is_constant(ir::value *v) { return populate_is_constant_broadcast(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_binop(x); + if(auto *x = dynamic_cast(v)) + return populate_is_constant_cmp(x); if(auto *x = dynamic_cast(v)) return populate_is_constant_gep(x); return populate_is_constant_default(v); @@ -511,12 +543,15 @@ std::vector align::contiguous(ir::value* v) const { return max_contiguous_.at(v); } +std::vector align::get_cst_info(ir::value* v) const { + return is_constant_.at(v); +} + void align::populate(ir::value *v) { populate_is_constant(v); populate_starting_multiple(v); populate_max_contiguous(v); - } void align::run(ir::module &mod) { diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index c60350060..e4723d86b 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -744,6 +744,11 @@ void generator::visit_load_inst(ir::load_inst* x){ if(op->get_type()->is_block_ty()){ auto ord = ords_.at(op); size_t aln = alignment_->get(op, ord[0]); + if(mx){ + size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; + max_eq = std::max(max_eq, 1); + aln = std::min(aln, max_eq); + } auto layout = layouts_->get(x)->to_scanline(); if(layout){ size_t nts = layout->nts(ord[0]); @@ -912,6 +917,11 @@ void generator::visit_store_inst(ir::store_inst * x){ auto ord = ords_.at(x->get_pointer_operand()); size_t aln = alignment_->get(ptr_op, ord[0]); size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; + if(mx){ + size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst; + max_eq = std::max(max_eq, 1); + aln = std::min(aln, max_eq); + } vec = std::min(nts, aln); } auto idxs = idxs_.at(val_op); diff --git a/python/test/unit/runtime/test_comm.py b/python/test/unit/runtime/test_comm.py deleted file mode 100644 index ae3fb69d7..000000000 --- a/python/test/unit/runtime/test_comm.py +++ /dev/null @@ -1,98 +0,0 @@ -import subprocess - -import numpy as np -import pytest -import torch - -import triton -import 
triton.language as tl - - -def get_p2p_matrix(): - try: - stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii") - except subprocess.CalledProcessError: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - - lines = stdout.split("Legend")[0].split('\n')[1:] - matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2]) - if matrix.size <= 1: - return pytest.skip("No multi-GPU topology", allow_module_level=True) - else: - return matrix - - -def get_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "OK") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -def get_non_p2p_devices(): - matrix = get_p2p_matrix() - idx = np.where(matrix == "NS") - return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else [] - - -p2p_devices = get_p2p_devices() -non_p2p_devices = get_non_p2p_devices() - - -@triton.jit -def _copy(from_ptr, to_ptr, N, **meta): - pid = tl.program_id(0) - offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK']) - values = tl.load(from_ptr + offsets, mask=offsets < N) - tl.store(to_ptr + offsets, values, mask=offsets < N) - - -@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in p2p_devices - for device_from in p2p_devices - for device_to in p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024) - assert torch.allclose(x_from, x_to.to(device_from)) - - -@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support") -@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to", - [(device_kernel, device_from, device_to, stream_from, stream_to) - for device_kernel in non_p2p_devices - for device_from in non_p2p_devices - for device_to in non_p2p_devices - for stream_from in ['default', 'custom'] - for stream_to in ['default', 'custom'] - ]) -def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to): - if device_to == device_from: - return pytest.skip() - - with pytest.raises(RuntimeError): - torch.cuda.set_device(device_kernel) - N = 512 - grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),) - - with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)): - x_from = torch.randn(N, dtype=torch.float32, device=device_from) - with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)): - x_to = torch.empty(N, dtype=torch.float32, device=device_to) - - _copy[grid](x_from, x_to, N, BLOCK=1024)
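
Note (illustrative sketch, not part of the patch): the fix above threads the constancy of comparison results through the alignment analysis (populate_is_constant_cmp / get_cst_info), and visit_load_inst / visit_store_inst now cap the vectorization width of masked accesses by how many consecutive lanes the mask is proven constant over. The Python kernel below, written in the same `**meta` Triton style as the test removed by this patch, is a made-up example of the kind of code this affects; the kernel and tensor names are hypothetical.

    import torch
    import triton
    import triton.language as tl


    # Hypothetical kernel mirroring the removed _copy kernel: the mask
    # `offsets < N` is a cmp_inst, so the new populate_is_constant_cmp()
    # reports over how many consecutive lanes it stays constant, and the
    # generator clamps the load/store vector width to that number
    # (aln = min(aln, num_cst)).
    @triton.jit
    def _masked_copy(from_ptr, to_ptr, N, **meta):
        pid = tl.program_id(0)
        offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
        mask = offsets < N
        values = tl.load(from_ptr + offsets, mask=mask)
        tl.store(to_ptr + offsets, values, mask=mask)


    # With N = 1000 and BLOCK = 256, the last program's mask flips mid-block.
    # Since N is a runtime scalar, the mask may change at any lane, so the
    # backend now only vectorizes these masked accesses as far as the mask's
    # proven constancy allows instead of using the pointer alignment alone.
    x = torch.randn(1000, device='cuda')
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']),)
    _masked_copy[grid](x, y, x.numel(), BLOCK=256)
    assert torch.allclose(x, y)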