From a76efd326d61c4cc70c38fa4c02a0b071550ab06 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 19 Oct 2019 14:47:16 -0400
Subject: [PATCH] [selection] [codegen] added reduction

---
 include/triton/codegen/analysis/layout.h |  15 ++
 include/triton/codegen/instructions.h    |   8 +-
 lib/codegen/analysis/layout.cc           |  20 +++
 lib/codegen/selection/generator.cc       | 169 ++++++++++++-----------
 tests/bench/dot.cc                       |   4 +-
 tests/common/dot.h                       |   2 +-
 6 files changed, 129 insertions(+), 89 deletions(-)

diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h
index 70260542a..e0eee3a38 100644
--- a/include/triton/codegen/analysis/layout.h
+++ b/include/triton/codegen/analysis/layout.h
@@ -126,6 +126,18 @@ private:
   void create(size_t id, const std::vector<ir::value*>& values);
 
+//  size_t shared_tmp_req(ir::instruction* i) {
+//    switch(i->get_id()) {
+//      case ir::INST_REDUCE: {
+//        ir::reduce_inst *red = (ir::reduce_inst*)i;
+//        ir::type *ty = red->get_type();
+
+
+//      }
+//      default: return 0;
+//    }
+//  }
+
 public:
   // constructor
   layout(analysis::axes *axes, analysis::align *align, size_t num_warps);
@@ -134,8 +146,10 @@ public:
   // accessors
   unsigned layout_of(ir::value *value) const;
   const std::vector<ir::value*>& values_of(unsigned id) const;
   size_t num_layouts() const;
+  const layout_t* get(size_t id) const;
   const layout_t* get(ir::value *v) const;
   std::map<size_t, layout_t*> &get_all();
+  size_t tmp(ir::instruction* i);
   // execution
   void run(ir::module &mod);
@@ -148,6 +162,7 @@ private:
   std::map<ir::value*, size_t> groups_;
   std::map<size_t, std::vector<ir::value*>> values_;
   std::map<size_t, layout_t*> layouts_;
+  std::map<ir::instruction*, size_t> tmp_;
 };
 
 }
diff --git a/include/triton/codegen/instructions.h b/include/triton/codegen/instructions.h
index 2e5d6148f..c42abee4a 100644
--- a/include/triton/codegen/instructions.h
+++ b/include/triton/codegen/instructions.h
@@ -6,8 +6,12 @@
 #include <map>
 
 namespace triton{
-namespace codegen{
 
+namespace ir{
+class instruction;
+}
+
+namespace codegen{
 
 enum storage_info_t {
   NONE,
@@ -63,7 +67,6 @@ static const std::map<ir::value_id_t, std::pair<storage_info_t, std::vector<storage_info_t>>> storage_info = {
   { ir::INST_RETURN,           {NONE, {}}},
   { ir::INST_UNCOND_BRANCH,    {NONE, {}}},
   { ir::INST_COND_BRANCH,      {NONE, {REPLICATED}}},
-
   // intrinsics
   { ir::INST_COPY_TO_SHARED,   {SHARED, {DISTRIBUTED}}},
   { ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
@@ -73,6 +76,7 @@ static const std::map<ir::value_id_t, std::pair<storage_info_t, std::vector<storage_info_t>>> storage_info = {
   { ir::INST_MAKE_RANGE,       {DISTRIBUTED, {}}}
 };
 
+
 }
 }
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 3d2296aae..1066a5cae 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -76,6 +76,10 @@ bool is_hmma_c(ir::value *v){
   return result;
 }
 
+const layout_t* layout::get(size_t id) const {
+  return layouts_.at(id);
+}
+
 const layout_t* layout::get(ir::value *v) const {
   return layouts_.at(groups_.at(v));
 }
@@ -84,6 +88,10 @@ std::map<size_t, layout_t*>& layout::get_all() {
   return layouts_;
 }
 
+size_t layout::tmp(ir::instruction* i) {
+  return tmp_.at(i);
+}
+
 void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
   for(ir::user* u: v->get_users()){
     auto i = dynamic_cast<ir::io_inst*>(u);
@@ -323,6 +331,7 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
     size *= 2;
 }
 
+
 // layout factory method
 void layout::create(size_t id, const std::vector<ir::value*>& values) {
   auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
@@ -364,6 +373,17 @@ void layout::run(ir::module &mod) {
   // create layouts
   for(const auto& x: values_)
     create(x.first, x.second);
+
+  // create temporaries
+  size_t id = values_.size();
+  ir::for_each_instruction(mod, [this, &id](ir::instruction* i) {
+    if(auto *red = dynamic_cast<ir::reduce_inst*>(i)) {
+      id++;
+      ir::value *arg = red->get_operand(0);
+      layouts_[id] = new layout_shared_t(get(arg), axes_->get(arg), arg->get_type()->get_tile_shapes(), {red}, red->get_type()->get_scalar_ty(), id, align_);
+      tmp_[red] = id;
+    }
+  });
 }
 
 }
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 41398dc76..8fbdbeded 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -750,96 +750,97 @@ void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) {
 }
 
 void generator::visit_reduce_inst(ir::reduce_inst* x) {
-  throw std::runtime_error("not implemented");
-//  std::map<indices_t, Value*> partial;
-//  ir::value *arg = x->get_operand(0);
-//  distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
-//  ir::reduce_inst::op_t op = x->get_op();
-//  auto accumulate = [&](Value* x, Value *y) -> Value* {
-//    switch(op) {
-//    case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
-//    case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
-//    case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
-//    case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
-//    case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
-//    case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
-//    case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
-//    case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
-//    default: break;
-//    }
-//    assert(false);
-//    return nullptr;
-//  };
+  std::map<indices_t, Value*> partial;
+  ir::value *arg = x->get_operand(0);
+  distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
+  ir::reduce_inst::op_t op = x->get_op();
+  auto accumulate = [&](Value* x, Value *y) -> Value* {
+    switch(op) {
+    case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y);
+    case ir::reduce_inst::SUB: return builder_->CreateSub(x, y);
+    case ir::reduce_inst::MAX: return builder_->CreateMaximum(x, y);
+    case ir::reduce_inst::MIN: return builder_->CreateMinimum(x, y);
+    case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y);
+    case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y);
+    case ir::reduce_inst::FMAX: return builder_->CreateSelect(builder_->CreateFCmpOGT(x, y), x, y);
+    case ir::reduce_inst::FMIN: return builder_->CreateSelect(builder_->CreateFCmpOLT(x, y), x, y);
+    default: break;
+    }
+    assert(false);
+    return nullptr;
+  };
 
-//  unsigned axis = x->get_axis();
+  // reduce within thread
+  unsigned axis = x->get_axis();
+  arg_tile->for_each([&](indices_t idx) {
+    indices_t pidx = idx;
+    pidx[axis] = builder_->getInt32(0);
+    Value *current = arg_tile->get_value(idx);
+    // current partial result is not initialized -- create
+    if(partial.find(pidx) == partial.end())
+      partial[pidx] = current;
+    // current partial result is initialized -- accumulate
+    else
+      partial[pidx] = accumulate(partial[pidx], current);
+  });
 
-//  // reduce within thread
-//  arg_tile->for_each([&](indices_t idx) {
-//    indices_t pidx = idx;
-//    pidx[axis] = builder_->getInt32(0);
-//    Value *current = arg_tile->get_value(idx);
-//    // current partial result is not initialized -- create
-//    if(partial.find(pidx) == partial.end())
-//      partial[pidx] = current;
-//    // current partial result is initialized -- accumulate
-//    else
-//      partial[pidx] = accumulate(partial[pidx], current);
-//  });
+  // depth
+  unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = arg_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
 
-//  // depth
-//  unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
-//  unsigned per_thread = arg_tile->axis(axis).values.size();
-//  unsigned depth = shape_ax / per_thread;
+  // shapes
+  auto shared_shapes = arg_tile->get_shapes();
+  shared_shapes[axis] = depth;
 
-//  // shapes
-//  auto shared_shapes = arg_tile->get_shapes();
-//  shared_shapes[axis] = depth;
+  // reduce within blocks
+  machine_layout_t *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x)));
+  shared_tile *stile = (shared_tile*)slayout->create(x);
 
-//  // reduce within blocks
-//  unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
-//  Type *res_ty = builder_->getFloatTy();
-//  Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
-//  for(auto& x: partial) {
-//    // current element being computed
-//    Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
-//    Value *&result = x.second;
-//    indices_t write_idx = x.first;
-//    write_idx[axis] = lane;
-//    // shared memory write pointer
-//    Value *write_offset = shared_tile::shared_offset(*builder_, shared_shapes, write_idx);
-//    Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset);
-//    // initialize shared memory
-//    tgt_->add_barrier(*mod_, *builder_);
-//    builder_->CreateStore(result, write_ptr);
-//    // build result
-//    for(unsigned i = depth/2; i > 0; i >>= 1){
-//      // current indices
-//      indices_t current(write_idx.size(), builder_->getInt32(0));
-//      current[axis] = builder_->getInt32(i);
-//      // shared memory offset
-//      Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, current);
-//      Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i));
-//      read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0));
-//      // shared memory read pointer
-//      Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset);
-//      tgt_->add_barrier(*mod_, *builder_);
-//      Value *next = builder_->CreateLoad(read_ptr);
-//      // accumulate
-//      result = accumulate(result, next);
-//      // write back
-//      builder_->CreateStore(result, write_ptr);
-//    }
-//  }
-//  tgt_->add_barrier(*mod_, *builder_);
+  unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
+  Type *res_ty = builder_->getFloatTy();
+  Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
+  for(auto& x: partial) {
+    // current element being computed
+    Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id;
+    Value *&result = x.second;
+    indices_t write_idx = x.first;
+    write_idx[axis] = lane;
+    // shared memory write pointer
+    Value *write_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), write_idx);
+    Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset);
+    // initialize shared memory
+    tgt_->add_barrier(mod_, *builder_);
+    builder_->CreateStore(result, write_ptr);
+    // build result
+    for(unsigned i = depth/2; i > 0; i >>= 1){
+      // current indices
+      indices_t current(write_idx.size(), builder_->getInt32(0));
+      current[axis] = builder_->getInt32(i);
+      // shared memory offset
+      Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), current);
+      Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i));
+      read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0));
+      // shared memory read pointer
+      Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset);
+      tgt_->add_barrier(mod_, *builder_);
+      Value *next = builder_->CreateLoad(read_ptr);
+      // accumulate
+      result = accumulate(result, next);
+      // write back
+      builder_->CreateStore(result, write_ptr);
+    }
+  }
+  tgt_->add_barrier(mod_, *builder_);
 
-//  distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
-//  x_tile->for_each([&](indices_t idx) {
-//    indices_t red_idx = idx;
-//    red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0));
-//    Value *read_offset = shared_tile::shared_offset(*builder_, shared_shapes, red_idx);
-//    Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset);
-//    x_tile->set_value(idx, builder_->CreateLoad(read_ptr));
-//  });
+  distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
+  x_tile->for_each([&](indices_t idx) {
+    indices_t red_idx = idx;
+    red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0));
+    Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), red_idx);
+    Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset);
+    x_tile->set_value(idx, builder_->CreateLoad(read_ptr));
+  });
 }
 
 void generator::visit_select_inst(ir::select_inst* select) {
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 9857e9865..c87e1c938 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -13,7 +13,7 @@ int main() {
   for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true}, {true, false}, {true, true}}){
     std::vector<config_t> tmp = {
-      config_t{ord, x[0], x[1], 4096, 4096, 4096},
+      config_t{ord, x[0], x[1], 2048, 2048, 2048},
 //      config_t{ord, x[0], x[1], 16, 2048, 2048},
 //      config_t{ord, x[0], x[1], 32, 2048, 2048},
 //      config_t{ord, x[0], x[1], 64, 2048, 2048},
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index ba6447162..ddbb1c77a 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -111,7 +111,7 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
   if(mode == BENCH) {
     opt.defines.push_back({"TM", {"64", "128"}});
     opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8", "16"}});
+    opt.defines.push_back({"TK", {"8"}});
     opt.num_warps = {2, 4, 8};
   }
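
Note: the lowering in visit_reduce_inst above is the classic shared-memory tree
reduction. Each thread first folds its per-thread values into one partial
result, every lane then stores its partial into the shared-memory temporary
allocated in layout::run, and log2(depth) barrier-separated rounds halve the
number of active lanes until lane 0 holds the final value. For reference, a
minimal CUDA sketch of the same pattern follows; the kernel name, parameters,
and launch shape are illustrative assumptions, not part of this patch, and it
shows only the ADD case:

#include <cuda_runtime.h>

// Illustrative sketch only: reduces `depth` floats per block through shared
// memory, mirroring the emitted loop `for(i = depth/2; i > 0; i >>= 1)`.
// Assumes blockDim.x == depth and depth is a power of two.
__global__ void tree_reduce(const float* in, float* out, int depth) {
  extern __shared__ float smem[];
  int lane = threadIdx.x;
  // every lane writes its partial result (cf. the initial CreateStore + barrier)
  smem[lane] = in[blockIdx.x * depth + lane];
  __syncthreads();
  // halve the number of active lanes each round
  for (int i = depth / 2; i > 0; i >>= 1) {
    if (lane < i)
      smem[lane] += smem[lane + i];  // accumulate (the ADD case)
    __syncthreads();                 // cf. tgt_->add_barrier(...)
  }
  if (lane == 0)
    out[blockIdx.x] = smem[0];       // lane 0 now holds the reduced value
}

// hypothetical launch: one block per reduced row
// tree_reduce<<<num_rows, depth, depth * sizeof(float)>>>(in, out, depth);

One deliberate difference: the generated IR keeps inactive lanes issuing a load
(their read offset is selected down to 0) rather than predicating on lane < i,
so the emitted code stays free of divergent branches.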