[selection] [codegen] added reduction
This commit is contained in:
@@ -126,6 +126,18 @@ private:
|
|||||||
|
|
||||||
void create(size_t id, const std::vector<ir::value*>& values);
|
void create(size_t id, const std::vector<ir::value*>& values);
|
||||||
|
|
||||||
|
// size_t shared_tmp_req(ir::instruction* i) {
|
||||||
|
// switch(i->get_id()) {
|
||||||
|
// case ir::INST_REDUCE: {
|
||||||
|
// ir::reduce_inst *red = (ir::reduce_inst*)i;
|
||||||
|
// ir::type *ty = red->get_type();
|
||||||
|
|
||||||
|
|
||||||
|
// }
|
||||||
|
// default: return 0;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// constructor
|
// constructor
|
||||||
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
|
||||||
@@ -134,8 +146,10 @@ public:
|
|||||||
unsigned layout_of(ir::value *value) const;
|
unsigned layout_of(ir::value *value) const;
|
||||||
const std::vector<ir::value*>& values_of(unsigned id) const;
|
const std::vector<ir::value*>& values_of(unsigned id) const;
|
||||||
size_t num_layouts() const;
|
size_t num_layouts() const;
|
||||||
|
const layout_t* get(size_t id) const;
|
||||||
const layout_t* get(ir::value *v) const;
|
const layout_t* get(ir::value *v) const;
|
||||||
std::map<size_t, layout_t*> &get_all();
|
std::map<size_t, layout_t*> &get_all();
|
||||||
|
size_t tmp(ir::instruction* i);
|
||||||
|
|
||||||
// execution
|
// execution
|
||||||
void run(ir::module &mod);
|
void run(ir::module &mod);
|
||||||
@@ -148,6 +162,7 @@ private:
|
|||||||
std::map<ir::value*, size_t> groups_;
|
std::map<ir::value*, size_t> groups_;
|
||||||
std::map<size_t, std::vector<ir::value*>> values_;
|
std::map<size_t, std::vector<ir::value*>> values_;
|
||||||
std::map<size_t, layout_t*> layouts_;
|
std::map<size_t, layout_t*> layouts_;
|
||||||
|
std::map<ir::value*, size_t> tmp_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -6,8 +6,12 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace triton{
|
namespace triton{
|
||||||
namespace codegen{
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class instruction;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace codegen{
|
||||||
|
|
||||||
enum storage_info_t {
|
enum storage_info_t {
|
||||||
NONE,
|
NONE,
|
||||||
@@ -63,7 +67,6 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
|
|||||||
{ ir::INST_RETURN, {NONE, {}}},
|
{ ir::INST_RETURN, {NONE, {}}},
|
||||||
{ ir::INST_UNCOND_BRANCH, {NONE, {}}},
|
{ ir::INST_UNCOND_BRANCH, {NONE, {}}},
|
||||||
{ ir::INST_COND_BRANCH, {NONE, {REPLICATED}}},
|
{ ir::INST_COND_BRANCH, {NONE, {REPLICATED}}},
|
||||||
|
|
||||||
// intrinsics
|
// intrinsics
|
||||||
{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
|
{ ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
|
||||||
{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
|
{ ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
|
||||||
@@ -73,6 +76,7 @@ static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
|
|||||||
{ ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}}
|
{ ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -76,6 +76,10 @@ bool is_hmma_c(ir::value *v){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const layout_t* layout::get(size_t id) const {
|
||||||
|
return layouts_.at(id);
|
||||||
|
}
|
||||||
|
|
||||||
const layout_t* layout::get(ir::value *v) const {
|
const layout_t* layout::get(ir::value *v) const {
|
||||||
return layouts_.at(groups_.at(v));
|
return layouts_.at(groups_.at(v));
|
||||||
}
|
}
|
||||||
@@ -84,6 +88,10 @@ std::map<size_t, layout_t*>& layout::get_all() {
|
|||||||
return layouts_;
|
return layouts_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t layout::tmp(ir::instruction* i) {
|
||||||
|
return tmp_.at(i);
|
||||||
|
}
|
||||||
|
|
||||||
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
|
||||||
for(ir::user* u: v->get_users()){
|
for(ir::user* u: v->get_users()){
|
||||||
auto i = dynamic_cast<ir::io_inst*>(u);
|
auto i = dynamic_cast<ir::io_inst*>(u);
|
||||||
@@ -323,6 +331,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
size *= 2;
|
size *= 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// layout factory method
|
// layout factory method
|
||||||
void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
void layout::create(size_t id, const std::vector<ir::value*>& values) {
|
||||||
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
|
||||||
@@ -364,6 +373,17 @@ void layout::run(ir::module &mod) {
|
|||||||
// create layouts
|
// create layouts
|
||||||
for(const auto& x: values_)
|
for(const auto& x: values_)
|
||||||
create(x.first, x.second);
|
create(x.first, x.second);
|
||||||
|
|
||||||
|
// create temporaries
|
||||||
|
size_t id = values_.size();
|
||||||
|
ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
|
||||||
|
if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
|
||||||
|
id++;
|
||||||
|
ir::value *arg = red->get_operand(0);
|
||||||
|
layouts_[id] = new layout_shared_t(get(arg), axes_->get(arg), arg->get_type()->get_tile_shapes(), {red}, red->get_type()->get_scalar_ty(), id, align_);
|
||||||
|
tmp_[red] = id;
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -750,96 +750,97 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
void generator::visit_reduce_inst(ir::reduce_inst* x) {
|
||||||
throw std::runtime_error("not implemented");
|
std::map<indices_t, Value*> partial;
|
||||||
// std::map<indices_t, Value*> partial;
|
ir::value *arg = x->get_operand(0);
|
||||||
// ir::value *arg = x->get_operand(0);
|
distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
||||||
// distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
|
ir::reduce_inst::op_t op = x->get_op();
|
||||||
// ir::reduce_inst::op_t op = x->get_op();
|
auto accumulate = [&](Value* x, Value *y) -> Value* {
|
||||||
// auto accumulate = [&](Value* x, Value *y) -> Value* {
|
switch(op) {
|
||||||
// switch(op) {
|
case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
|
||||||
// case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
|
case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
|
||||||
// case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
|
case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
|
||||||
// case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
|
case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
|
||||||
// case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
|
case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
|
||||||
// case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
|
case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
|
||||||
// case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
|
case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
|
||||||
// case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
|
case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
|
||||||
// case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
|
default: break;
|
||||||
// default: break;
|
}
|
||||||
// }
|
assert(false);
|
||||||
// assert(false);
|
return nullptr;
|
||||||
// return nullptr;
|
};
|
||||||
// };
|
|
||||||
|
|
||||||
// unsigned axis = x->get_axis();
|
// reduce within thread
|
||||||
|
unsigned axis = x->get_axis();
|
||||||
|
arg_tile->for_each([&](indices_t idx) {
|
||||||
|
indices_t pidx = idx;
|
||||||
|
pidx[axis] = builder_->getInt32(0);
|
||||||
|
Value *current = arg_tile->get_value(idx);
|
||||||
|
// current partial result is not initialized -- create
|
||||||
|
if(partial.find(pidx) == partial.end())
|
||||||
|
partial[pidx] = current;
|
||||||
|
// current partial result is initialized -- accumulate
|
||||||
|
else
|
||||||
|
partial[pidx] = accumulate(partial[pidx], current);
|
||||||
|
});
|
||||||
|
|
||||||
// // reduce within thread
|
// depth
|
||||||
// arg_tile->for_each([&](indices_t idx) {
|
unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
|
||||||
// indices_t pidx = idx;
|
unsigned per_thread = arg_tile->axis(axis).values.size();
|
||||||
// pidx[axis] = builder_->getInt32(0);
|
unsigned depth = shape_ax / per_thread;
|
||||||
// Value *current = arg_tile->get_value(idx);
|
|
||||||
// // current partial result is not initialized -- create
|
|
||||||
// if(partial.find(pidx) == partial.end())
|
|
||||||
// partial[pidx] = current;
|
|
||||||
// // current partial result is initialized -- accumulate
|
|
||||||
// else
|
|
||||||
// partial[pidx] = accumulate(partial[pidx], current);
|
|
||||||
// });
|
|
||||||
|
|
||||||
// // depth
|
// shapes
|
||||||
// unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
|
auto shared_shapes = arg_tile->get_shapes();
|
||||||
// unsigned per_thread = arg_tile->axis(axis).values.size();
|
shared_shapes[axis] = depth;
|
||||||
// unsigned depth = shape_ax / per_thread;
|
|
||||||
|
|
||||||
// // shapes
|
// reduce within blocks
|
||||||
// auto shared_shapes = arg_tile->get_shapes();
|
machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
|
||||||
// shared_shapes[axis] = depth;
|
shared_tile *stile = (shared_tile*)slayout->create(x);
|
||||||
|
|
||||||
// // reduce within blocks
|
unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
||||||
// unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
|
Type *res_ty = builder_->getFloatTy();
|
||||||
// Type *res_ty = builder_->getFloatTy();
|
Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
||||||
// Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
|
for(auto& x: partial) {
|
||||||
// for(auto& x: partial) {
|
// current element being computed
|
||||||
// // current element being computed
|
Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
|
||||||
// Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
|
Value *&result = x.second;
|
||||||
// Value *&result = x.second;
|
indices_t write_idx = x.first;
|
||||||
// indices_t write_idx = x.first;
|
write_idx[axis] = lane;
|
||||||
// write_idx[axis] = lane;
|
// shared memory write pointer
|
||||||
// // shared memory write pointer
|
Value *write_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), write_idx);
|
||||||
// Value *write_offset = shared_tile::shared_offset(*builder_, shared_shapes, write_idx);
|
Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset);
|
||||||
// Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset);
|
// initialize shared memory
|
||||||
// // initialize shared memory
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
// tgt_->add_barrier(*mod_, *builder_);
|
builder_->CreateStore(result, write_ptr);
|
||||||
// builder_->CreateStore(result, write_ptr);
|
// build result
|
||||||
// // build result
|
for(unsigned i = depth/2; i > 0; i >>= 1){
|
||||||
// for(unsigned i = depth/2; i > 0; i >>= 1){
|
// current indices
|
||||||
// // current indices
|
indices_t current(write_idx.size(), builder_->getInt32(0));
|
||||||
// indices_t current(write_idx.size(), builder_->getInt32(0));
|
current[axis] = builder_->getInt32(i);
|
||||||
// current[axis] = builder_->getInt32(i);
|
// shared memory offset
|
||||||
// // shared memory offset
|
Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), current);
|
||||||
// Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, current);
|
Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i));
|
||||||
// Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i));
|
read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0));
|
||||||
// read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0));
|
// shared memory read pointer
|
||||||
// // shared memory read pointer
|
Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset);
|
||||||
// Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset);
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
// tgt_->add_barrier(*mod_, *builder_);
|
Value *next = builder_->CreateLoad(read_ptr);
|
||||||
// Value *next = builder_->CreateLoad(read_ptr);
|
// accumulate
|
||||||
// // accumulate
|
result = accumulate(result, next);
|
||||||
// result = accumulate(result, next);
|
// write back
|
||||||
// // write back
|
builder_->CreateStore(result, write_ptr);
|
||||||
// builder_->CreateStore(result, write_ptr);
|
}
|
||||||
// }
|
}
|
||||||
// }
|
tgt_->add_barrier(mod_, *builder_);
|
||||||
// tgt_->add_barrier(*mod_, *builder_);
|
|
||||||
|
|
||||||
// distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
|
distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
|
||||||
// x_tile->for_each([&](indices_t idx) {
|
x_tile->for_each([&](indices_t idx) {
|
||||||
// indices_t red_idx = idx;
|
indices_t red_idx = idx;
|
||||||
// red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0));
|
red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0));
|
||||||
// Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, red_idx);
|
Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), red_idx);
|
||||||
// Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset);
|
Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset);
|
||||||
// x_tile->set_value(idx, builder_->CreateLoad(read_ptr));
|
x_tile->set_value(idx, builder_->CreateLoad(read_ptr));
|
||||||
// });
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_select_inst(ir::select_inst* select) {
|
void generator::visit_select_inst(ir::select_inst* select) {
|
||||||
|
@@ -13,7 +13,7 @@ int main() {
|
|||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
|
for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
|
||||||
{true, false}, {true, true}}){
|
{true, false}, {true, true}}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
config_t{ord, x[0], x[1], 4096, 4096, 4096},
|
config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
||||||
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
// config_t{ord, x[0], x[1], 16, 2048, 2048},
|
||||||
// config_t{ord, x[0], x[1], 32, 2048, 2048},
|
// config_t{ord, x[0], x[1], 32, 2048, 2048},
|
||||||
// config_t{ord, x[0], x[1], 64, 2048, 2048},
|
// config_t{ord, x[0], x[1], 64, 2048, 2048},
|
||||||
@@ -34,7 +34,7 @@ int main() {
|
|||||||
for(const auto& c: configs){
|
for(const auto& c: configs){
|
||||||
std::tie(ord, AT, BT, M, N, K) = c;
|
std::tie(ord, AT, BT, M, N, K) = c;
|
||||||
std::cout << "// " << c << std::flush;
|
std::cout << "// " << c << std::flush;
|
||||||
for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
|
for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
|
||||||
std::cout << ", " << perf << std::flush;
|
std::cout << ", " << perf << std::flush;
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
}
|
}
|
||||||
|
@@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
if(mode == BENCH) {
|
if(mode == BENCH) {
|
||||||
opt.defines.push_back({"TM", {"64", "128"}});
|
opt.defines.push_back({"TM", {"64", "128"}});
|
||||||
opt.defines.push_back({"TN", {"64", "128"}});
|
opt.defines.push_back({"TN", {"64", "128"}});
|
||||||
opt.defines.push_back({"TK", {"8", "16"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {2, 4, 8};
|
opt.num_warps = {2, 4, 8};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user