[BACKEND][CODEGEN] vectorization bugfix (#502)
@@ -12,6 +12,7 @@ namespace ir {
 class phi_node;
 class splat_inst;
 class cast_inst;
+class cmp_inst;
 class reshape_inst;
 class broadcast_inst;
 class binary_operator;
@@ -35,6 +36,7 @@ private:
   std::vector<cst_info> populate_is_constant_reshape(ir::reshape_inst* x);
   std::vector<cst_info> populate_is_constant_broadcast(ir::broadcast_inst* x);
   std::vector<cst_info> populate_is_constant_binop(ir::binary_operator* x);
+  std::vector<cst_info> populate_is_constant_cmp(ir::cmp_inst* x);
   std::vector<cst_info> populate_is_constant_gep(ir::getelementptr_inst* x);
   std::vector<cst_info> populate_is_constant_default(ir::value* v);
   std::vector<cst_info> populate_is_constant(ir::value *v);
@@ -65,6 +67,7 @@ public:
   void run(ir::module &mod);
   unsigned get(ir::value* v, unsigned ax) const;
   std::vector<unsigned> contiguous(ir::value* v) const;
+  std::vector<cst_info> get_cst_info(ir::value* v) const;
 
 private:
   std::map<ir::value*, std::vector<cst_info>> is_constant_;
@@ -129,6 +129,33 @@ std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast
   return add_to_cache(x, result, is_constant_);
 }
 
+std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
+  auto x_shapes = get_shapes(x);
+  std::vector<cst_info> result;
+  ir::value* lhs_op = x->get_operand(0);
+  ir::value* rhs_op = x->get_operand(1);
+  auto lhs = populate_is_constant(lhs_op);
+  auto rhs = populate_is_constant(rhs_op);
+  auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
+  auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
+  auto lhs_multiple_of = populate_starting_multiple(lhs_op);
+  auto rhs_multiple_of = populate_starting_multiple(rhs_op);
+  for(size_t d = 0; d < x_shapes.size(); d++) {
+    cst_info ax = {1, 0};
+    // if lhs (resp. rhs) is a range of M values starting at a multiple of N
+    // and rhs (resp. lhs) is made of M constants that are multiples of N
+    // then comparisons have M constants
+    int min_multiple = std::min(lhs_multiple_of[d], rhs_multiple_of[d]);
+    if(rhs[d].num_cst % lhs_max_contiguous[d] == 0)
+      ax = {std::min<int>(min_multiple, lhs_max_contiguous[d]), 0};
+    else if(lhs[d].num_cst % rhs_max_contiguous[d] == 0)
+      ax = {std::min<int>(min_multiple, rhs_max_contiguous[d]), 0};
+    result.push_back(ax);
+  }
+  return add_to_cache(x, result, is_constant_);
+}
+
+
 std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
   auto x_shapes = get_shapes(x);
   std::vector<cst_info> result;
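
The rule encoded in `populate_is_constant_cmp` is what the rest of the patch relies on: when one comparison operand is a contiguous range starting at a multiple of N and the other is a constant multiple of N, the resulting predicate is uniform over aligned runs of N lanes. A minimal standalone sketch of that invariant (the concrete values N = 64 and 128 lanes are hypothetical, not taken from the patch):

```cpp
// Brute-force check: comparing an aligned contiguous range against an
// aligned constant yields a mask that is constant over runs of N lanes.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int N = 64;        // starting multiple shared by both operands
  const int base = 2 * N;  // lhs: contiguous range starting at a multiple of N
  const int c = 3 * N;     // rhs: constant that is a multiple of N
  std::vector<bool> mask;
  for (int i = 0; i < 128; i++)    // 128 contiguous lanes
    mask.push_back(base + i < c);  // the comparison the pass reasons about
  // The predicate never changes inside an aligned run of N lanes, which is
  // what permits an N-wide vectorized access under this mask.
  for (size_t i = 0; i < mask.size(); i++)
    assert(mask[i] == mask[i / N * N]);
  std::printf("mask uniform over runs of %d lanes\n", N);
  return 0;
}
```

This is the shape of a typical Triton bounds mask (`offsets < N`), where `offsets = pid * BLOCK + tl.arange(0, BLOCK)` is contiguous and starts at a multiple of `BLOCK`.
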
@@ -136,12 +163,15 @@ std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operat
   ir::value* rhs_op = x->get_operand(1);
   auto lhs = populate_is_constant(lhs_op);
   auto rhs = populate_is_constant(rhs_op);
-  auto max_contiguous = populate_max_contiguous(lhs_op);
+  auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
+  auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
+  auto lhs_multiple_of = populate_starting_multiple(lhs_op);
+  auto rhs_multiple_of = populate_starting_multiple(rhs_op);
   for(size_t d = 0; d < x_shapes.size(); d++) {
     cst_info ax;
     if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
       // todo might not be entirely true
-      unsigned num_constants = gcd(max_contiguous[d], rhs[d].value);
+      unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
       ax = {num_constants, 0};
     }
     else
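
The renamed `lhs_max_contiguous` also feeds the integer-division case just above: dividing a contiguous, aligned range by a constant `c` produces values that are constant over runs of `gcd(max_contiguous, c)` lanes, with the caveats the code's own `todo` comment acknowledges. A quick self-contained check of that arithmetic, using assumed toy values (contiguous = 8, c = 12):

```cpp
// Verifies that an aligned contiguous range divided by a constant c is
// piecewise-constant over runs of gcd(contiguous, c) lanes. C++17 for std::gcd.
#include <cassert>
#include <numeric>

int main() {
  const unsigned contiguous = 8;  // analogous to lhs_max_contiguous[d]
  const unsigned c = 12;          // analogous to rhs[d].value
  const unsigned run = std::gcd(contiguous, c);  // = 4, the inferred num_cst
  for (unsigned base = 0; base < 10 * contiguous; base += contiguous)
    for (unsigned i = 0; i < contiguous; i++)
      // every lane in an aligned run of `run` lanes gets the same quotient
      assert((base + i) / c == (base + (i / run) * run) / c);
  return 0;
}
```
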
@@ -184,6 +214,8 @@ std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
     return populate_is_constant_broadcast(x);
   if(auto *x = dynamic_cast<ir::binary_operator*>(v))
     return populate_is_constant_binop(x);
+  if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
+    return populate_is_constant_cmp(x);
   if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
     return populate_is_constant_gep(x);
   return populate_is_constant_default(v);
@@ -511,12 +543,15 @@ std::vector<unsigned> align::contiguous(ir::value* v) const {
   return max_contiguous_.at(v);
 }
 
+std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
+  return is_constant_.at(v);
+}
 
 void align::populate(ir::value *v) {
   populate_is_constant(v);
   populate_starting_multiple(v);
   populate_max_contiguous(v);
 
 }
 
 void align::run(ir::module &mod) {
@@ -744,6 +744,11 @@ void generator::visit_load_inst(ir::load_inst* x){
   if(op->get_type()->is_block_ty()){
     auto ord = ords_.at(op);
     size_t aln = alignment_->get(op, ord[0]);
+    if(mx){
+      size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
+      max_eq = std::max<size_t>(max_eq, 1);
+      aln = std::min(aln, max_eq);
+    }
     auto layout = layouts_->get(x)->to_scanline();
     if(layout){
       size_t nts = layout->nts(ord[0]);
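
This `if(mx)` clause is the core of the bugfix, and it reappears in `visit_store_inst` below: pointer alignment alone may justify a wide access, but a masked access can only be vectorized as wide as the mask is provably uniform, otherwise lanes the mask meant to disable would be touched. A sketch of the clamp with hypothetical numbers:

```cpp
// With 8-element pointer alignment and 4 contiguous elements per thread,
// a mask that is only uniform over 2 lanes must cap the vector width at 2.
#include <algorithm>
#include <cstdio>

int main() {
  size_t aln = 8;           // pointer alignment in elements (alignment_->get)
  size_t nts = 4;           // contiguous elements per thread (layout->nts)
  size_t mask_num_cst = 2;  // cst_info::num_cst of the mask along ord[0]
  size_t max_eq = std::max<size_t>(mask_num_cst, 1);  // never below 1
  aln = std::min(aln, max_eq);      // the fix: cap alignment by mask uniformity
  size_t vec = std::min(nts, aln);  // resulting vector width: 2 instead of 4
  std::printf("vectorize %zu-wide\n", vec);
  return 0;
}
```
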
@@ -912,6 +917,11 @@ void generator::visit_store_inst(ir::store_inst * x){
     auto ord = ords_.at(x->get_pointer_operand());
     size_t aln = alignment_->get(ptr_op, ord[0]);
     size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous;
+    if(mx){
+      size_t max_eq = alignment_->get_cst_info(mx->get_mask_operand())[ord[0]].num_cst;
+      max_eq = std::max<size_t>(max_eq, 1);
+      aln = std::min(aln, max_eq);
+    }
     vec = std::min(nts, aln);
   }
   auto idxs = idxs_.at(val_op);
@@ -1,98 +0,0 @@
-import subprocess
-
-import numpy as np
-import pytest
-import torch
-
-import triton
-import triton.language as tl
-
-
-def get_p2p_matrix():
-    try:
-        stdout = subprocess.check_output(["nvidia-smi", "topo", "-p2p", "n"]).decode("ascii")
-    except subprocess.CalledProcessError:
-        return pytest.skip("No multi-GPU topology", allow_module_level=True)
-
-    lines = stdout.split("Legend")[0].split('\n')[1:]
-    matrix = np.array([line.split('\t')[1:-1] for line in lines][:-2])
-    if matrix.size <= 1:
-        return pytest.skip("No multi-GPU topology", allow_module_level=True)
-    else:
-        return matrix
-
-
-def get_p2p_devices():
-    matrix = get_p2p_matrix()
-    idx = np.where(matrix == "OK")
-    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
-
-
-def get_non_p2p_devices():
-    matrix = get_p2p_matrix()
-    idx = np.where(matrix == "NS")
-    return [f"cuda:{idx[0][0]}", f"cuda:{idx[1][0]}"] if len(idx[0]) > 0 else []
-
-
-p2p_devices = get_p2p_devices()
-non_p2p_devices = get_non_p2p_devices()
-
-
-@triton.jit
-def _copy(from_ptr, to_ptr, N, **meta):
-    pid = tl.program_id(0)
-    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
-    values = tl.load(from_ptr + offsets, mask=offsets < N)
-    tl.store(to_ptr + offsets, values, mask=offsets < N)
-
-
-@pytest.mark.skipif(not p2p_devices, reason="No pair of device with P2P support")
-@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
-                         [(device_kernel, device_from, device_to, stream_from, stream_to)
-                          for device_kernel in p2p_devices
-                          for device_from in p2p_devices
-                          for device_to in p2p_devices
-                          for stream_from in ['default', 'custom']
-                          for stream_to in ['default', 'custom']
-                          ])
-def test_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
-    if device_to == device_from:
-        return pytest.skip()
-
-    torch.cuda.set_device(device_kernel)
-    N = 512
-    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
-
-    with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
-        x_from = torch.randn(N, dtype=torch.float32, device=device_from)
-    with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
-        x_to = torch.empty(N, dtype=torch.float32, device=device_to)
-
-    _copy[grid](x_from, x_to, N, BLOCK=1024)
-    assert torch.allclose(x_from, x_to.to(device_from))
-
-
-@pytest.mark.skipif(not non_p2p_devices, reason="No pair of device with no P2P support")
-@pytest.mark.parametrize("device_kernel, device_from, device_to, stream_from, stream_to",
-                         [(device_kernel, device_from, device_to, stream_from, stream_to)
-                          for device_kernel in non_p2p_devices
-                          for device_from in non_p2p_devices
-                          for device_to in non_p2p_devices
-                          for stream_from in ['default', 'custom']
-                          for stream_to in ['default', 'custom']
-                          ])
-def test_non_p2p(device_kernel, device_from, device_to, stream_from, stream_to):
-    if device_to == device_from:
-        return pytest.skip()
-
-    with pytest.raises(RuntimeError):
-        torch.cuda.set_device(device_kernel)
-        N = 512
-        grid = lambda meta: (triton.cdiv(N, meta['BLOCK']),)
-
-        with torch.cuda.stream(None if stream_from == 'default' else torch.cuda.Stream(device_from)):
-            x_from = torch.randn(N, dtype=torch.float32, device=device_from)
-        with torch.cuda.stream(None if stream_to == 'default' else torch.cuda.Stream(device_to)):
-            x_to = torch.empty(N, dtype=torch.float32, device=device_to)
-
-        _copy[grid](x_from, x_to, N, BLOCK=1024)