[GENERAL] Merged v1.0alpha into master. Added features:
- A100 support via mma.16816
- Thread swizzling for conflict-free shared memory accesses without padding
- Complete overhaul of the LLVM code generation in codegen/selection/generator.cc to remove overengineering
- Added debugging capabilities in the Python binding
- Compilation error for kernels that spill
@@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) {
       graph_.add_edge({dot, d}, {D, d});
 }
 
-void axes::update_graph_elementwise(ir::instruction *i) {
+void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) {
   if(i->get_num_operands() == 0)
     return;
   ir::value *op = i->get_operand(0);
@@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) {
   for(unsigned d = 0; d < rank; d++)
   for(ir::value* opx: i->ops())
   for(ir::value* opy: i->ops()){
-    if(!i->get_type()->is_void_ty())
+    if(connect_ret && !i->get_type()->is_void_ty())
       graph_.add_edge({i, d}, {opx, d});
     graph_.add_edge({opx, d}, {opy, d});
   }
@@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) {
     case ir::INST_TRANS: return update_graph_trans(i);
     case ir::INST_BROADCAST: return update_graph_broadcast(i);
     case ir::INST_DOT: return update_graph_dot(i);
-    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);;
+    case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);
+    case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false);
     case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i);
     case ir::INST_RECOALESCE: return update_graph_no_edge(i);
     default: return update_graph_elementwise(i);
@@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::dot_inst*>(u);
     if(i && is_hmma_c(i) && i->get_operand(n) == v)
-      result = v;
+      result = i;
   }
 }
@@ -115,8 +115,10 @@ data_layout::data_layout(id_t id,
   }
 }
 
-size_t data_layout::find_axis(int to_find) const {
+int data_layout::find_axis(int to_find) const {
   auto it = std::find(axes_.begin(), axes_.end(), to_find);
+  if(it == axes_.end())
+    return -1;
   return std::distance(axes_.begin(), it);
 }
@@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const {
 *                MMA Layout                 *
 * -------------------------------- */
 
-mma884_layout::mma884_layout(size_t num_warps,
-                             const std::vector<int>& axes,
-                             const std::vector<unsigned>& shape,
-                             const std::vector<ir::value *> &values,
-                             analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) {
+mma_layout::mma_layout(size_t num_warps,
+                       const std::vector<int>& axes,
+                       const std::vector<unsigned>& shape,
+                       const std::vector<ir::value *> &values,
+                       analysis::align* align, target* tgt,
+                       shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) {
   /* fragments per warp */
   // try to make things as square as possible to maximize data re-use
-  fpw_ = {1, 1, 1};
-  std::vector<int> fpw_nm1;
-  unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
-  do {
-    fpw_nm1 = fpw_;
-    if(fpw_[0]*fpw_[1] < num_fragments)
-      fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
-    if(fpw_[0]*fpw_[1] < num_fragments)
-      fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
-  }while(fpw_nm1 != fpw_);
+  if(tgt->as_nvidia()->sm() < 80){
+    fpw_ = {1, 1, 1};
+    std::vector<int> fpw_nm1;
+    unsigned num_fragments = std::min<unsigned>((shape_[0]/8)*(shape_[1]/8), 4);
+    do {
+      fpw_nm1 = fpw_;
+      if(fpw_[0]*fpw_[1] < num_fragments)
+        fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8);
+      if(fpw_[0]*fpw_[1] < num_fragments)
+        fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8);
+    }while(fpw_nm1 != fpw_);
+    auto ord_a = layout_a->get_order();
+    auto ord_b = layout_b->get_order();
+    bool is_a_row = ord_a[0] != 0;
+    bool is_b_row = ord_b[0] != 0;
+    bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16);
+    bool is_b_vec4 =  is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16);
+    int pack_size_0 = (is_a_row ||  is_a_vec4) ? 1 : 2;
+    int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1;
+    rep_ = {2*pack_size_0, 2*pack_size_1, 1};
+    spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1};
+  }
+  else{
+    fpw_ = {1, 1, 1};
+    spw_ = {16, 8, 1};
+    rep_ = {2, 2, 1};
+  }
 
   /* warps per tile */
   // try to make things as square as possible to maximize data re-use
@@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps,
   do{
     wpt_nm1 = wpt_;
     if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
-      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8));
+      wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]);
     if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps)
-      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8));
+      wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]);
   }while(wpt_nm1 != wpt_);
 
-  /* sanity check */
-  unsigned effective_num_warps = 1;
-  for(size_t d = 0; d < shape.size(); d++)
-    effective_num_warps *= wpt_[d];
-  // if(num_warps != effective_num_warps)
-  //   throw std::runtime_error("cannot create a kernel with this amount of warps");
+  /* shape per block */
+  spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1};
 }
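Note: the fpw_ and wpt_ loops above share one heuristic: repeatedly double whichever tile dimension keeps the partition closest to square, clamped to what the shape allows, until the budget (fragments or warps) is met or nothing changes. A minimal standalone sketch of that heuristic, with hypothetical names (not code from this commit):

#include <array>
#include <algorithm>

// "As square as possible" partition of a budget across two dimensions.
static int clamp(int x, int lo, int hi) { return std::max(lo, std::min(x, hi)); }

std::array<int, 2> partition(int budget, int max0, int max1) {
  std::array<int, 2> p = {1, 1}, prev;
  do {
    prev = p;
    if (p[0] * p[1] < budget) p[0] = clamp(p[0] * 2, 1, max0); // grow dim 0 first
    if (p[0] * p[1] < budget) p[1] = clamp(p[1] * 2, 1, max1); // then dim 1
  } while (prev != p); // fixed point: budget met or both dims saturated
  return p; // e.g. partition(4, 8, 8) == {2, 2}
}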
@@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps,
   ir::value *ptr = nullptr;
   for(ir::value *v: values)
     for(ir::user *usr: v->get_users())
-      if(auto *st = dynamic_cast<ir::store_inst*>(usr))
+      if(auto *st = dynamic_cast<ir::io_inst*>(usr))
         ptr = st->get_pointer_operand();
 
   unsigned i = order_[0];
-  int contiguous = 4;
-  if(ptr)
-    contiguous = std::min<int>(align->contiguous(ptr)[i], 4);
+  int contiguous = 1;
+  if(ptr){
+    int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits();
+    contiguous = std::min<int>(align->contiguous(ptr)[i], 128 / nbits);
+  }
 
   nts_[i] = clamp(size / num_threads, 1, std::min<int>(contiguous, shape_[i]));
   mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
@@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps,
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     num_threads = num_threads / mts_[i];
   }
-  /* sanity check */
-  unsigned effective_num_threads = 1;
-  for(size_t d = 0; d < shape_.size(); d++)
-    effective_num_threads *= mts_[d];
-
-  // std::cout <<values.size() << " " << num_warps << " " << effective_num_threads << std::endl;
-  // if(num_warps * 32 != effective_num_threads)
-  //   throw std::runtime_error("cannot create a kernel with this amount of warps");
 }
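Note: the contiguity change above caps the per-thread vector width by the widest GPU global access, 128 bits, instead of a hard-coded 4 elements: fp16 loads may now use 8-element vectors while fp32 stays at 4. A minimal sketch of the computation, with a hypothetical free function:

#include <algorithm>
#include <cassert>

// Widest usable vector width, in elements: limited both by the proven
// alignment/contiguity of the pointer and by a 128-bit access.
int vector_width(int align_contig, int nbits) {
  return std::min(align_contig, 128 / nbits);
}

int main() {
  assert(vector_width(8, 16) == 8); // fp16: 8 contiguous elements fit in 128 bits
  assert(vector_width(8, 32) == 4); // fp32: capped at 4 by the 128-bit limit
}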
@@ -246,9 +256,9 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
   ir::value *value_1 = phi->get_incoming_value(1);
   ir::instruction *i_0 = dynamic_cast<ir::instruction*>(value_0);
   ir::instruction *i_1 = dynamic_cast<ir::instruction*>(value_1);
-  if(!i_0 || !i_1 ||
-     !dynamic_cast<ir::copy_to_shared_inst*>(i_0) ||
-     !dynamic_cast<ir::copy_to_shared_inst*>(i_1) )
+  if(!(i_0 && !i_1) &&
+     !(dynamic_cast<ir::copy_to_shared_inst*>(i_0) && dynamic_cast<ir::copy_to_shared_inst*>(i_1)) &&
+     !(dynamic_cast<ir::masked_load_async_inst*>(i_0) && dynamic_cast<ir::masked_load_async_inst*>(i_1)))
     return;
   if(is_latch_1)
     res.reset(new double_buffer_info_t{value_0, value_1, phi});
@@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr<doub
 }
 
 
-shared_layout::shared_layout(const data_layout *arg,
+shared_layout::shared_layout(data_layout *arg,
                              const std::vector<int>& axes,
                              const std::vector<unsigned>& shape,
                              const std::vector<ir::value *> &values,
@@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg,
                              analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
 
   size_ = 0;
+  arg_layout_ = arg;
 
   // double-buffering
   for(ir::value *v: values)
@@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg,
     extract_hmma_dot_use(v, hmma_dot_a, 0);
     extract_hmma_dot_use(v, hmma_dot_b, 1);
   }
 
-  // non-mma ordering
-  std::vector<int> col = {0, 1};
-  std::vector<int> row = {1, 0};
-  for(size_t s = 2; s < get_rank(); s++){
-    col.push_back(s);
-    row.push_back(s);
-  }
-  bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
-  bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
-  if(is_nonhmma_dot_a)
-    order_ = is_trans(dot_a) ? row : col;
-  else if(is_nonhmma_dot_b)
-    order_ = is_trans(dot_b) ? col : row;
-
-  // padding
-  size_t pad = 0;
-  if(hmma_dot_a){
-    bool row = is_trans(hmma_dot_a) ^ order_[0] != 0;
-    pad = 24 - shape_[row ? 0 : 1] % 32;
-  }
-  else if(hmma_dot_b){
-    bool row = is_trans(hmma_dot_b) ^ order_[0] != 0;
-    pad = 24 - shape_[row ? 1 : 0] % 32;
-  }
-  else if(order_ != arg_order) {
-    pad = 4;
-  }
-  shape_[order_[0]] += pad;
+  hmma_dot_a_ = hmma_dot_a;
+  hmma_dot_b_ = hmma_dot_b;
 
   // size
   size_ = ty_->get_primitive_size_in_bits() / 8;
@@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) {
 }
 
 void layouts::create(size_t id, const std::vector<ir::value*>& values) {
+//  if(layouts_.find(id) != layouts_.end())
+//    return;
   auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
   auto cmp = [](ir::value* x, ir::value *y) {
     std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
@@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
   const auto& axes = axes_->get(largest);
   const auto& shapes = largest->get_type()->get_tile_shapes();
   auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) {
-    return dynamic_cast<ir::copy_to_shared_inst*>(v);
+    return dynamic_cast<ir::copy_to_shared_inst*>(v) ||
+           dynamic_cast<ir::masked_load_async_inst*>(v);
   });
   // type
-  if(it_hmma_c != values.end())
-    layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_);
+  if(it_hmma_c != values.end()){
+    ir::instruction *dot = (ir::instruction*)*it_hmma_c;
+    ir::value *a = dot->get_operand(0);
+    ir::value *b = dot->get_operand(1);
+    create(groups_.at(a), values_.at(groups_.at(a)));
+    create(groups_.at(b), values_.at(groups_.at(b)));
+    layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
+  }
   else if(it_cts != values.end()){
-    ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts;
+    ir::instruction *cts = (ir::instruction*)*it_cts;
     ir::value *arg = cts->get_operand(0);
     create(groups_.at(arg), values_.at(groups_.at(arg)));
     layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
   }
-  else
+  else{
     layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
+  }
 }
 
 void layouts::run(ir::module &mod) {
@@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) {
     }
     if(auto *recoalasce = dynamic_cast<ir::recoalesce_inst*>(i)){
      ir::value *val = recoalasce->get_operand(0);
-      mma884_layout* in_layout = get(val)->to_mma884();
+      mma_layout* in_layout = get(val)->to_mma();
      scanline_layout* out_layout = get(i)->to_scanline();
      if(!in_layout || !out_layout)
        return;
@@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) {
      shape[ld] = in_shape[ld];
      for(size_t k = 0; k < in_shape.size(); k++)
        if(k != ld)
-          shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k);
+          shape[k] = in_layout->to_mma()->spt(k);
      // create layout
      layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_);
      tmp_[recoalasce] = id;
lib/codegen/analysis/swizzle.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
+#include "triton/codegen/analysis/swizzle.h"
+#include "triton/codegen/analysis/layout.h"
+#include "triton/codegen/target.h"
+#include "triton/ir/type.h"
+#include <iostream>
+
+namespace triton{
+namespace codegen{
+namespace analysis{
+
+
+void swizzle::run(ir::module &) {
+  per_phase_.clear();
+  max_phase_.clear();
+
+  for(auto &x: layouts_->get_all()){
+    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
+    if(!layout)
+      continue;
+    ir::value* mma_dot_a = layout->hmma_dot_a();
+    ir::value* mma_dot_b = layout->hmma_dot_b();
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[layout] = 1;
+      max_phase_[layout] = 1;
+      vec_[layout] = 1;
+      continue;
+    }
+    auto ord = layout->get_order();
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    if(!in_layout)
+      continue;
+    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+    if(tgt_->as_nvidia()->sm() < 80){
+      int inner = mma_dot_a ? 0 : 1;
+      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      if(mma_dot_a)
+        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+      else
+        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+    }
+    else{
+      per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = 8 / per_phase_[layout];
+      vec_[layout] = 8;
+    }
+  }
+}
+
+}
+}
+}
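Note: per_phase, max_phase, and vec computed here feed the shared-memory addressing in the generator: rows are grouped into phases of per_phase rows, and within each group the vector-column index is XOR-ed with the phase, so column accesses that would otherwise hit the same bank are permuted apart without padding. An illustrative sketch of how such a swizzled offset can be derived (assuming a row-major [rows x cols] tile; this is not the generator's exact code):

// Bank-conflict-free shared-memory offset via XOR swizzling (sketch).
// per_phase, max_phase, vec correspond to the per-layout values above.
int swizzled_offset(int row, int col, int cols,
                    int per_phase, int max_phase, int vec) {
  int phase = (row / per_phase) % max_phase; // which permutation this row uses
  int vec_col = (col / vec) ^ phase;         // XOR permutes vec-sized column groups
  return row * cols + vec_col * vec + col % vec;
}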
(File diff suppressed because it is too large.)
@@ -1,325 +0,0 @@
|
||||
#include <numeric>
|
||||
#include "triton/codegen/selection/machine_layout.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/ir/instructions.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) {
|
||||
// function
|
||||
if(auto* tt = dynamic_cast<ir::function_type*>(ty)){
|
||||
Type *return_ty = llvm_type(tt->get_return_ty(), ctx);
|
||||
std::vector<Type*> param_tys;
|
||||
std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys),
|
||||
[&ctx](ir::type* t){ return llvm_type(t, ctx);});
|
||||
return FunctionType::get(return_ty, param_tys, false);
|
||||
}
|
||||
// pointer
|
||||
if(ty->is_pointer_ty()){
|
||||
Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx);
|
||||
unsigned addr_space = ty->get_pointer_address_space();
|
||||
return PointerType::get(elt_ty, addr_space);
|
||||
}
|
||||
// integer
|
||||
if(ty->is_integer_ty()){
|
||||
unsigned bitwidth = ty->get_integer_bitwidth();
|
||||
return IntegerType::get(ctx, bitwidth);
|
||||
}
|
||||
// primitive types
|
||||
switch(ty->get_type_id()){
|
||||
case ir::type::VoidTyID: return Type::getVoidTy(ctx);
|
||||
case ir::type::HalfTyID: return Type::getHalfTy(ctx);
|
||||
case ir::type::FloatTyID: return Type::getFloatTy(ctx);
|
||||
case ir::type::DoubleTyID: return Type::getDoubleTy(ctx);
|
||||
case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx);
|
||||
case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx);
|
||||
case ir::type::LabelTyID: return Type::getLabelTy(ctx);
|
||||
case ir::type::MetadataTyID: return Type::getMetadataTy(ctx);
|
||||
case ir::type::TokenTyID: return Type::getTokenTy(ctx);
|
||||
default: break;
|
||||
}
|
||||
// unknown type
|
||||
throw std::runtime_error("unknown conversion from ir::type to Type");
|
||||
}
|
||||
|
||||
// Grid construction
|
||||
inline std::vector<Value*> delinearize(Value *trailing, const std::vector<int>& order, std::vector<int> &shapes, IRBuilder<> &builder){
|
||||
size_t dim = shapes.size();
|
||||
std::vector<Value*> result(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = builder.getInt32(shapes[order[k]]);
|
||||
Value *rem = builder.CreateURem(trailing, dim_k);
|
||||
trailing = builder.CreateUDiv(trailing, dim_k);
|
||||
result[order[k]] = rem;
|
||||
}
|
||||
result[order[dim - 1]] = trailing;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline int32_t ceil(int32_t num, int32_t div){
|
||||
return (num + div - 1)/div;
|
||||
}
|
||||
|
||||
|
||||
|
||||
machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc,
|
||||
Value *&sh_mem_ptr, analysis::shared_layout *layout,
|
||||
std::map<ir::value *, Value *>& vmap,
|
||||
std::map<ir::value *, tile *>& tmap)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) {
|
||||
|
||||
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
|
||||
PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace());
|
||||
// double-buffered
|
||||
if(layout_->get_double_buffer()) {
|
||||
BasicBlock *current = builder_->GetInsertBlock();
|
||||
auto info = *layout_->get_double_buffer();
|
||||
ir::phi_node *phi = info.phi;
|
||||
BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent()));
|
||||
if(parent->empty())
|
||||
builder_->SetInsertPoint(parent);
|
||||
else
|
||||
builder_->SetInsertPoint(&*parent->getFirstNonPHI());
|
||||
// create pointers
|
||||
ptr_ = builder_->CreatePHI(ptr_ty, 2);
|
||||
pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_)));
|
||||
pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType());
|
||||
offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2);
|
||||
next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr");
|
||||
builder_->SetInsertPoint(current);
|
||||
}
|
||||
else{
|
||||
size_t offset = alloc_->offset(layout_);
|
||||
ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset));
|
||||
ptr_ = builder_->CreateBitCast(ptr_, ptr_ty);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
tile* machine_shared_layout::create(ir::value *v) {
|
||||
Type* ty = llvm_type(layout_->get_type(), builder_->getContext());
|
||||
auto double_buffer = layout_->get_double_buffer();
|
||||
// offset
|
||||
Value *offset = nullptr;
|
||||
if(double_buffer && v == double_buffer->phi)
|
||||
offset = offset_;
|
||||
// base pointer
|
||||
Value *ptr = ptr_;
|
||||
if(double_buffer && v == double_buffer->latch)
|
||||
ptr = next_ptr_;
|
||||
else if(double_buffer && v == double_buffer->first)
|
||||
ptr = pre_ptr_;
|
||||
// create tile
|
||||
return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset);
|
||||
}
|
||||
|
||||
machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::data_layout *layout)
|
||||
: mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) {
|
||||
|
||||
}
|
||||
|
||||
tile *machine_distributed_layout::create(ir::value *v) {
|
||||
Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext());
|
||||
const auto &shapes = v->get_type()->get_tile_shapes();
|
||||
size_t rank = shapes.size();
|
||||
std::vector<distributed_axis> axes(rank);
|
||||
std::vector<int> order(rank);
|
||||
// compute axes
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d] > 1){
|
||||
unsigned x = a_axes_->get(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
axes[d].contiguous = 1;
|
||||
axes[d].values = {builder_->getInt32(0)};
|
||||
}
|
||||
}
|
||||
// compute order
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
auto cmp = [&](int x, int y) {
|
||||
unsigned axx = a_axes_->get(v, x);
|
||||
unsigned axy = a_axes_->get(v, y);
|
||||
size_t posx = layout_->find_axis(axx);
|
||||
size_t posy = layout_->find_axis(axy);
|
||||
if(posx < rank && posy < rank)
|
||||
return layout_->get_order(posx) < layout_->get_order(posy);
|
||||
return false;
|
||||
};
|
||||
std::sort(order.begin(), order.end(), cmp);
|
||||
return new distributed_tile(ty, shapes, order, axes, *builder_);
|
||||
}
|
||||
|
||||
machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder,
|
||||
target *tgt, analysis::axes *a_axes,
|
||||
std::map<unsigned, distributed_axis>& axes,
|
||||
analysis::mma884_layout* layout)
|
||||
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||
|
||||
const auto& shape = layout->get_shape();
|
||||
if(shape.size() > 3)
|
||||
throw std::runtime_error("unsupported");
|
||||
bool is_batched = shape.size() >= 3;
|
||||
|
||||
Value *_1 = builder_->getInt32(1);
|
||||
Value *_2 = builder_->getInt32(2);
|
||||
Value *_3 = builder_->getInt32(3);
|
||||
Value *_4 = builder_->getInt32(4);
|
||||
Value *_16 = builder_->getInt32(16);
|
||||
|
||||
// fragments per warp
|
||||
unsigned fpw_0 = layout->fpw(0);
|
||||
unsigned fpw_1 = layout->fpw(1);
|
||||
unsigned fpw_2 = is_batched ? layout->fpw(2) : 1;
|
||||
// warps per tile
|
||||
unsigned wpt_0 = layout->wpt(0);
|
||||
unsigned wpt_1 = layout->wpt(1);
|
||||
unsigned wpt_2 = is_batched ? layout->wpt(2) : 1;
|
||||
// mma warp tile size
|
||||
unsigned hmma_wts_0 = fpw_0 * 8;
|
||||
unsigned hmma_wts_1 = fpw_1 * 8;
|
||||
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
|
||||
// mma block tile size
|
||||
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
|
||||
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
|
||||
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
|
||||
// number of repetition
|
||||
unsigned num_rep_0 = shape[0] / hmma_bts_0;
|
||||
unsigned num_rep_1 = shape[1] / hmma_bts_1;
|
||||
unsigned num_rep_2 = is_batched ? shape[2] / hmma_bts_2 : 1;
|
||||
// size of each pack (interleaving)
|
||||
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
|
||||
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
|
||||
// number of packs (interleaving)
|
||||
num_packs_0_ = num_rep_0 / pack_size_0_;
|
||||
num_packs_1_ = num_rep_1 / pack_size_1_;
|
||||
|
||||
/* intra warp offset */
|
||||
// offset of quad in pair
|
||||
Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
|
||||
builder_->getInt32(fpw_0 * pack_size_0_));
|
||||
Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)),
|
||||
builder_->getInt32(fpw_1 * pack_size_1_));
|
||||
|
||||
// Quad pair id
|
||||
Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
|
||||
Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4);
|
||||
pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0));
|
||||
pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0));
|
||||
pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1));
|
||||
// Quad pair offset
|
||||
Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_));
|
||||
Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_));
|
||||
|
||||
/* inter warp offset */
|
||||
Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0));
|
||||
Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0));
|
||||
Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1));
|
||||
Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1));
|
||||
Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_));
|
||||
Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_));
|
||||
|
||||
/* offsets */
|
||||
// a offset
|
||||
offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a));
|
||||
offset_a_k_ = builder_->CreateAnd(u_thread_id, _3);
|
||||
// b offsets
|
||||
offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b));
|
||||
offset_b_k_ = builder_->CreateAnd(u_thread_id, _3);
|
||||
|
||||
// c offsets
|
||||
Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_);
|
||||
Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2),
|
||||
builder_->CreateAdd(warp_offset_j, pair_b_off));
|
||||
|
||||
/* indices */
|
||||
// i indices
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned pack = 0; pack < num_packs_0_; pack++)
|
||||
for(unsigned ii = 0; ii < pack_size_0_; ii++)
|
||||
for(unsigned i = 0; i < 2; i++){
|
||||
idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2)));
|
||||
}
|
||||
// j indices
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned pack = 0; pack < num_packs_1_; pack++)
|
||||
for(unsigned jj = 0; jj < pack_size_1_; jj++)
|
||||
for(unsigned j = 0; j < 2; j++){
|
||||
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_)));
|
||||
idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1)));
|
||||
}
|
||||
// z indices
|
||||
std::vector<Value*> idx_z;
|
||||
for(unsigned pack = 0; pack < num_rep_2; pack++)
|
||||
idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2)));
|
||||
|
||||
|
||||
/* axes */
|
||||
axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0};
|
||||
axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1};
|
||||
if(is_batched)
|
||||
axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2};
|
||||
}
|
||||
|
||||
|
||||
machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder,
|
||||
target *tgt,
|
||||
analysis::axes *a_axes, std::map<unsigned, distributed_axis> &axes,
|
||||
analysis::scanline_layout* layout)
|
||||
: machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) {
|
||||
|
||||
Value *warp_size = builder_->getInt32(32);
|
||||
Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0);
|
||||
Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size);
|
||||
Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size);
|
||||
|
||||
auto order = layout->get_order();
|
||||
const auto& shape = layout->get_shape();
|
||||
Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id);
|
||||
// Delinearize
|
||||
size_t dim = shape.size();
|
||||
std::vector<Value*> thread_id(dim);
|
||||
for(unsigned k = 0; k < dim - 1; k++){
|
||||
Constant *dim_k = builder_->getInt32(layout->mts(order[k]));
|
||||
Value *rem = builder_->CreateURem(full_thread_id, dim_k);
|
||||
full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k);
|
||||
thread_id[order[k]] = rem;
|
||||
}
|
||||
thread_id[order[dim - 1]] = full_thread_id;
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
int nts = layout->nts(k);
|
||||
int mts = layout->mts(k);
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *contiguous_k = builder_->getInt32(nts);
|
||||
Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k);
|
||||
unsigned per_block = nts * mts;
|
||||
unsigned per_thread = nts * shape[k] / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / nts * per_block + n % nts;
|
||||
idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -1,214 +0,0 @@
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "triton/codegen/selection/machine_value.h"
|
||||
|
||||
namespace triton{
|
||||
namespace codegen{
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
/* Distributed Tile */
|
||||
void distributed_tile::init_indices() {
|
||||
std::vector<size_t> id(axes_.size(), 0);
|
||||
// build
|
||||
size_t k = 0;
|
||||
while(true) {
|
||||
indices_t current;
|
||||
for(size_t d = 0; d < id.size(); d++)
|
||||
current.push_back(axes_[d].values[id[d]]);
|
||||
size_t sz = indices_.size();
|
||||
indices_[current] = sz;
|
||||
values_[current] = nullptr;
|
||||
ordered_indices_.push_back(current);
|
||||
id[order_[0]]++;
|
||||
while(id[order_[k]] == axes_[order_[k]].values.size()){
|
||||
if(k == id.size() - 1)
|
||||
return;
|
||||
id[order_[k++]] = 0;
|
||||
id[order_[k]]++;
|
||||
}
|
||||
k = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, const axes_t &axes, llvm::IRBuilder<> &builder)
|
||||
: tile(ty, shapes), axes_(axes), order_(order), builder_(builder) {
|
||||
init_indices();
|
||||
}
|
||||
|
||||
void distributed_tile::set_value(indices_t idx, Value *x) {
|
||||
assert(x->getType() == ty_ && "cannot set a value of different type");
|
||||
Value *&result = values_[idx];
|
||||
assert(!result && "value cannot be set twice");
|
||||
result = x;
|
||||
}
|
||||
|
||||
Value* distributed_tile::get_value(indices_t idx) {
|
||||
Value *result = values_.at(idx);
|
||||
assert(result && "value has not been set");
|
||||
return result;
|
||||
}
|
||||
|
||||
unsigned distributed_tile::get_linear_index(indices_t idx) {
|
||||
return indices_[idx];
|
||||
}
|
||||
|
||||
indices_t distributed_tile::get_ordered_indices(unsigned id) {
|
||||
return ordered_indices_.at(id);
|
||||
}
|
||||
|
||||
|
||||
void distributed_tile::for_each(std::function<void (indices_t)> fn, int start, int end) {
|
||||
if(end < 0)
|
||||
end = ordered_indices_.size() + end + 1;
|
||||
for(unsigned i = start; i < end; i++)
|
||||
fn(ordered_indices_[i]);
|
||||
}
|
||||
|
||||
void distributed_tile::for_each(std::function<void(indices_t)> fn, std::vector<int> starts, std::vector<int> sizes){
|
||||
int rank = sizes.size();
|
||||
int len = 1;
|
||||
for(int s: sizes)
|
||||
len *= s;
|
||||
|
||||
for(int i = 0; i < len; i++){
|
||||
indices_t idx(rank);
|
||||
int current = i;
|
||||
for(int k = 0; k < rank; k++){
|
||||
idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]);
|
||||
current = current / sizes[k];
|
||||
}
|
||||
fn(idx);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Shared Tile */
|
||||
void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) {
|
||||
BinaryOperator *bin_op = dyn_cast<BinaryOperator>(arg);
|
||||
Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0);
|
||||
if(dyn_cast<Constant>(arg)){
|
||||
cst = arg;
|
||||
non_cst = _0;
|
||||
return;
|
||||
}
|
||||
if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){
|
||||
non_cst = arg;
|
||||
cst = _0;
|
||||
return;
|
||||
}
|
||||
Constant *cst_lhs = dyn_cast<Constant>(bin_op->getOperand(0));
|
||||
Constant *cst_rhs = dyn_cast<Constant>(bin_op->getOperand(1));
|
||||
if(cst_lhs && cst_rhs){
|
||||
cst = arg;
|
||||
non_cst = _0;
|
||||
}
|
||||
else if(cst_lhs){
|
||||
cst = cst_lhs;
|
||||
non_cst = bin_op->getOperand(1);
|
||||
}
|
||||
else if(cst_rhs){
|
||||
cst = cst_rhs;
|
||||
non_cst = bin_op->getOperand(0);
|
||||
}
|
||||
else{
|
||||
non_cst = arg;
|
||||
cst = _0;
|
||||
}
|
||||
}
|
||||
|
||||
void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) {
|
||||
non_cst_idx.clear();
|
||||
cst_idx.clear();
|
||||
for(Value *idx: arg_idx){
|
||||
Value *non_cst, *cst;
|
||||
extract_constant(idx, non_cst, cst);
|
||||
non_cst_idx.push_back(non_cst);
|
||||
cst_idx.push_back(cst);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes,
|
||||
const std::vector<int>& perm, const std::vector<int>& order,
|
||||
indices_t idx) {
|
||||
// strides
|
||||
std::vector<Value*> strides(shapes.size(), builder.getInt32(0));
|
||||
strides[order[0]] = builder.getInt32(1);
|
||||
for(size_t i = 1; i < idx.size(); i++)
|
||||
strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]]));
|
||||
// result
|
||||
Value *result = builder.getInt32(0);
|
||||
for(size_t i = 0; i < idx.size(); i++)
|
||||
result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i]));
|
||||
return result;
|
||||
}
|
||||
|
||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector<int>& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector<int>& perm):
|
||||
tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){
|
||||
return_vector_ = false;
|
||||
if(perm_.empty()){
|
||||
perm_.resize(shapes.size());
|
||||
std::iota(perm_.begin(), perm_.end(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||
Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx));
|
||||
unsigned addr_space = ptr->getType()->getPointerAddressSpace();
|
||||
ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space));
|
||||
builder_.CreateStore(value, ptr);
|
||||
}
|
||||
|
||||
void shared_tile::set_vector_size(unsigned vector_size) {
|
||||
vector_size_ = vector_size;
|
||||
}
|
||||
|
||||
void shared_tile::set_return_mode(bool return_vector){
|
||||
return_vector_ = return_vector;
|
||||
}
|
||||
|
||||
|
||||
Value* shared_tile::get_value(indices_t idx) {
|
||||
indices_t non_cst_idx, cst_idx;
|
||||
extract_constant(idx, non_cst_idx, cst_idx);
|
||||
Value *&base_ptr = ptr_cache_[non_cst_idx];
|
||||
unsigned vector_size = vector_size_;
|
||||
Type *ty = ty_;
|
||||
if(ty->isHalfTy() && (vector_size % 2 == 0)){
|
||||
ty = IntegerType::get(ty->getContext(), 32);
|
||||
vector_size = vector_size / 2;
|
||||
}
|
||||
if(base_ptr == nullptr){
|
||||
// BasicBlock* store = builder_.GetInsertBlock();
|
||||
// if(!non_cst_idx.empty())
|
||||
// if(isa<Instruction>(non_cst_idx.front())){
|
||||
// builder_.SetInsertPoint((Instruction*)non_cst_idx.front());
|
||||
// }
|
||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx));
|
||||
if(vector_size_ > 1){
|
||||
Type *vec_ty = VectorType::get(ty, vector_size);
|
||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||
base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty);
|
||||
}
|
||||
// builder_.SetInsertPoint(store);
|
||||
}
|
||||
Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx);
|
||||
Value *div = offset;
|
||||
if(vector_size_ > 1)
|
||||
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
||||
Value *ptr = builder_.CreateGEP(base_ptr, div);
|
||||
Value *result = builder_.CreateLoad(ptr);
|
||||
if(return_vector_ == false && vector_size_ > 1) {
|
||||
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
|
||||
result = builder_.CreateExtractElement(result, rem);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
@@ -14,6 +14,12 @@ namespace triton{
 namespace codegen{
 
 // base
 
+
+nvidia_cu_target* target::as_nvidia() {
+  return dynamic_cast<nvidia_cu_target*>(this);
+}
+
 bool target::is_gpu() const {
   return is_gpu_;
 }
@@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m
 
 Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) {
-  Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier);
-  return builder.CreateCall(barrier, {});
+  return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {});
 }
 
 Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
@@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
     Intrinsic::amdgcn_workgroup_id_y,
     Intrinsic::amdgcn_workgroup_id_z
   };
-  Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
-  Value* group_id = builder.CreateCall(get_group_id, {});
+  Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {});
   return group_id;
 }
@@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
     Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
     Intrinsic::nvvm_read_ptx_sreg_ctaid_z
   };
-  Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
-  Value* cta_id = builder.CreateCall(get_cta_id, {});
+  Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {});
   return cta_id;
 }
@@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un
     Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
     Intrinsic::nvvm_read_ptx_sreg_nctaid_z
   };
-  Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
-  return builder.CreateCall(get_nctaid, {});
+  return builder.CreateIntrinsic(ids[ax], {}, {});
 }
 
 // CPU
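Note: the target.cc changes above collapse the two-step getDeclaration + CreateCall pattern into IRBuilder's CreateIntrinsic, which materializes the declaration on demand. A minimal sketch of the equivalence (illustrative; the sreg choice here is hypothetical, but both calls are standard LLVM API):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
using namespace llvm;

Value *read_tid_x(Module *module, IRBuilder<> &builder) {
  // before: explicit declaration, then a plain call
  Function *decl = Intrinsic::getDeclaration(module, Intrinsic::nvvm_read_ptx_sreg_tid_x);
  Value *a = builder.CreateCall(decl, {});
  (void)a;
  // after: one builder call emits the same "call i32 @llvm.nvvm.read.ptx.sreg.tid.x()"
  return builder.CreateIntrinsic(Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {});
}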
@@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) {
 
 
   for(size_t id = 0; id < num_groups; id++) {
-    if(!layout_->get(id)->to_mma884())
+    if(!layout_->get(id)->to_mma())
       continue;
     // extract memory stores
     const auto& values = layout_->values_of(id);
@@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){
     return true;
   if(i->get_id() == ir::INST_COPY_TO_SHARED)
     return true;
+  if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC)
+    return true;
   return false;
 }
 
 
 // run pass on module
-void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
+void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) {
   auto *i = dynamic_cast<ir::instruction*>(x);
   // not an instruction
   if(!i) {
@@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool
   // copy
   builder.set_insert_point_after(i);
   ir::value *copy;
-  if(to_shared)
+  if(to_shared){
     copy = builder.create_copy_to_shared(x);
+  }
   else
     copy = builder.create_copy_from_shared(x);
   parent->replace_uses_of_with(x, copy);
@@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){
   add_reference(i, res);
 }
 
-void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
+void membar::insert_barrier(ir::instruction *instr, std::pair<bool, bool> type, ir::builder &builder) {
   if(auto *phi = dynamic_cast<ir::phi_node*>(instr)) {
     std::set<ir::value*> incoming;
     for(unsigned n = 0; n < phi->get_num_incoming(); n++){
@@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) {
       if(incoming.insert(inc_val).second){
         ir::basic_block *block = inc_val->get_parent();
         builder.set_insert_point(block->get_inst_list().back());
-        builder.create_barrier();
+        if(type.first)
+          builder.create_async_wait();
+        if(type.second)
+          builder.create_barrier();
       }
     }
   }
@@ -85,8 +88,9 @@ std::pair<membar::interval_vec_t,
           membar::interval_vec_t> membar::transfer(ir::basic_block *block,
                                                    const interval_vec_t &written_to,
                                                    const interval_vec_t &read_from,
-                                                   std::set<ir::instruction*>& insert_loc,
-                                                   std::set<ir::value*>& safe_war) {
+                                                   std::map<ir::instruction*, std::pair<bool,bool>>& insert_loc,
+                                                   std::set<ir::value*>& safe_war,
+                                                   std::vector<ir::instruction*>& to_sync) {
   ir::basic_block::inst_list_t instructions = block->get_inst_list();
   interval_vec_t new_written_to = written_to;
   interval_vec_t new_read_from = read_from;
@@ -95,6 +99,8 @@ std::pair<membar::interval_vec_t,
     interval_vec_t read, written;
     get_read_intervals(i, read);
     get_written_intervals(i, written);
+    if(written.size())
+      to_sync.push_back(i);
     bool read_after_write = intersect(new_written_to, read);
     bool write_after_read = intersect(new_read_from, written);
     // double buffering
@@ -104,9 +110,14 @@ std::pair<membar::interval_vec_t,
     }
     // record hazards
     if(read_after_write || write_after_read) {
-      insert_loc.insert(i);
+      auto is_load_async = [&](ir::instruction *i){ return dynamic_cast<ir::masked_load_async_inst*>(i);};
+      auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast<ir::copy_to_shared_inst*>(i);};
+      bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async);
+      bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared);
+      insert_loc.insert({i, {copy_async_wait, barrier}});
       new_written_to.clear();
       new_read_from.clear();
+      to_sync.clear();
     }
     std::copy(written.begin(), written.end(), std::back_inserter(new_written_to));
     std::copy(read.begin(), read.end(), std::back_inserter(new_read_from));
@@ -125,17 +136,17 @@ void membar::run(ir::module &mod) {
     if(!layout || !layout->get_double_buffer())
       continue;
     for(ir::value *v: layout->get_values())
-      if(v != layout->get_double_buffer()->phi)
+      if(v != layout->get_double_buffer()->phi){
         safe_war.insert(v);
+      }
   }
 
 
 
   for(ir::function *fn: mod.get_function_list()){
     std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
     std::map<ir::basic_block*, interval_vec_t> written_to;
     std::map<ir::basic_block*, interval_vec_t> read_from;
-    std::set<ir::instruction*> insert_locs;
+    std::vector<ir::instruction*> to_sync;
+    std::map<ir::instruction*, std::pair<bool,bool>> insert_locs;
     size_t n_inserted_im1 = 0;
     bool done = false;
     do{
@@ -150,7 +161,7 @@ void membar::run(ir::module &mod) {
       for(ir::basic_block* pred: block->get_predecessors())
        pred_read_from.push_back(read_from[pred]);
      // apply transfer function
-      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war);
+      auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync);
      written_to[block] = result.first;
      read_from[block] = result.second;
     }
@@ -158,8 +169,9 @@ void membar::run(ir::module &mod) {
       done = (n_inserted_im1 == n_inserted_i);
       n_inserted_im1 = n_inserted_i;
     }while(!done);
-    for(ir::instruction* i: insert_locs)
-      insert_barrier(i, builder);
+    for(auto x: insert_locs){
+      insert_barrier(x.first, x.second, builder);
+    }
   }
 }
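Note: the membar rework above tags each sync point with which fences it needs rather than always emitting a barrier: asynchronous copies (masked_load_async) require an async-wait before their data is visible, while plain copies to shared memory only need a barrier. A condensed sketch of that bookkeeping, with hypothetical types (not the pass itself):

#include <algorithm>
#include <utility>
#include <vector>

// Classify a hazard point by the kinds of writes pending since the last
// synchronization: .first = needs async-wait, .second = needs barrier.
enum class write_kind { async_copy, copy_to_shared };

std::pair<bool,bool> classify(const std::vector<write_kind>& pending) {
  bool wait = std::any_of(pending.begin(), pending.end(),
                          [](write_kind k){ return k == write_kind::async_copy; });
  bool bar  = std::any_of(pending.begin(), pending.end(),
                          [](write_kind k){ return k == write_kind::copy_to_shared; });
  return {wait, bar};
}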
@@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
 
 //}
 
+bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){
+  auto copy_to_shared = dynamic_cast<ir::copy_to_shared_inst*>(value);
+  if(!copy_to_shared)
+    return false;
+  ir::value *arg = copy_to_shared->get_operand(0);
+  ir::masked_load_inst* ld = dynamic_cast<ir::masked_load_inst*>(arg);
+  if(!ld)
+    return false;
+  builder.set_insert_point(copy_to_shared);
+  ir::value *ptr = ld->get_pointer_operand();
+  ir::value *msk = ld->get_mask_operand();
+  ir::value *val = ld->get_false_value_operand();
+  ir::value* new_load = builder.create_masked_load_async(ptr, msk, val);
+  copy_to_shared->replace_all_uses_with(new_load);
+  return true;
+
+}
+
 bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
   auto x = dynamic_cast<ir::reduce_inst*>(value);
   if(!x)
@@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) {
       continue;
     bool was_modified = false;
     was_modified = was_modified || rewrite_mult(i, builder);
-    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
+    // was_modified = was_modified || rewrite_cts_cfs(i, builder);
     was_modified = was_modified || rewrite_trans_phi(i, builder);
     was_modified = was_modified || rewrite_unit_red(i, builder);
     was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
+    // if(tgt_->as_nvidia()->sm() >= 80)
+    //   was_modified = was_modified || rewrite_load_to_shared(i, builder);
     if(was_modified)
       seen.insert(i);
   }
lib/codegen/transform/reorder.cc (new file, 51 lines)
@@ -0,0 +1,51 @@
+#include <iostream>
+#include <algorithm>
+#include "triton/ir/module.h"
+#include "triton/ir/function.h"
+#include "triton/ir/basic_block.h"
+#include "triton/ir/instructions.h"
+#include "triton/codegen/transform/reorder.h"
+
+namespace triton {
+namespace codegen{
+namespace transform{
+
+void reorder::run(ir::module& mod){
+  ir::builder &builder = mod.get_builder();
+  std::vector<std::pair<ir::instruction*, ir::value*>> to_replace;
+
+  for(ir::function *fn: mod.get_function_list())
+  for(ir::basic_block *block: fn->blocks())
+  for(ir::instruction* i: block->get_inst_list()){
+    if(auto* ld = dynamic_cast<ir::masked_load_inst*>(i)){
+      ir::value* _ptr = ld->get_pointer_operand();
+      ir::value* _msk = ld->get_mask_operand();
+      ir::value* _val = ld->get_false_value_operand();
+      auto ptr = std::find(block->begin(), block->end(), _ptr);
+      auto msk = std::find(block->begin(), block->end(), _msk);
+      auto val = std::find(block->begin(), block->end(), _val);
+      if(ptr == block->end() || msk == block->end() || val == block->end())
+        continue;
+      auto it = std::find(block->begin(), block->end(), i);
+      int dist_ptr = std::distance(ptr, it);
+      int dist_msk = std::distance(msk, it);
+      int dist_val = std::distance(val, it);
+      if(dist_ptr < dist_msk && dist_ptr < dist_val)
+        builder.set_insert_point(++ptr);
+      if(dist_msk < dist_ptr && dist_msk < dist_val)
+        builder.set_insert_point(++msk);
+      if(dist_val < dist_ptr && dist_val < dist_msk)
+        builder.set_insert_point(++val);
+      ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val);
+      to_replace.push_back(std::make_pair(ld, new_ld));
+    }
+  }
+
+  for(auto& x: to_replace)
+    x.first->replace_all_uses_with(x.second);
+
+}
+
+}
+}
+}