[codegen/tune]: added fragmentation types
This commit is contained in:
@@ -20,8 +20,8 @@ using GPUDevice = Eigen::GpuDevice;
|
||||
|
||||
const char* src =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TM = {16};
|
||||
const tunable int32 TN = {16};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
@@ -126,7 +126,7 @@ class BlockSparseGemmOp : public OpKernel {
|
||||
triton::driver::cu_buffer dlocks(ctx, (CUdeviceptr)locks.flat<int32_t>().data(), false);
|
||||
stream->synchronize();
|
||||
// just-in-time compile source-code
|
||||
jit.add_module("matmul", src, {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1});
|
||||
jit.add_module("matmul", src, {8, 2, 16, 8, 2, 16, 8, 8, 2, 2, 8, 8, 8, 1});
|
||||
triton::driver::kernel* kernel = jit.get_function("matmul");
|
||||
triton::jit::launch_information info = jit.get_launch_info("matmul");
|
||||
// launch info
|
||||
|
@@ -20,6 +20,6 @@ hresult = np.dot(hb.T, ha)
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||
a: ha,
|
||||
b: hb})
|
||||
a: ha,
|
||||
b: hb})
|
||||
print(result - hresult)
|
||||
|
@@ -70,6 +70,7 @@ private:
|
||||
unsigned vector_size_;
|
||||
};
|
||||
|
||||
// Distribtued tile
|
||||
class distributed_tile: public tile{
|
||||
typedef std::vector<distributed_axis> axes_t;
|
||||
typedef std::vector<indices_t> ordered_indices_vec_t;
|
||||
@@ -98,6 +99,15 @@ private:
|
||||
};
|
||||
|
||||
|
||||
// Fragmented tile
|
||||
class fragmented_tile: public tile{
|
||||
public:
|
||||
|
||||
private:
|
||||
|
||||
};
|
||||
|
||||
// Selection pass
|
||||
class selection{
|
||||
typedef std::map<ir::value *, llvm::Value *> vmap_t;
|
||||
typedef std::map<ir::value *, tile *> tmap_t;
|
||||
@@ -118,9 +128,9 @@ private:
|
||||
|
||||
// grid construction
|
||||
void create_grids(std::vector<ir::value *> &grids,
|
||||
std::map<ir::metaparameter *, ir::value *> &references,
|
||||
std::map<unsigned, ir::value *> &references,
|
||||
ir::function *fn);
|
||||
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<ir::metaparameter *, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
|
||||
void create_tile(ir::value *v, llvm::IRBuilder<> &builder, const std::map<unsigned, ir::value *> &references, std::set<ir::value *> &seen, llvm::Value *sh_mem_ptr);
|
||||
void init_axes(ir::value *i, llvm::IRBuilder<> &builder, llvm::Value *u_thread_id, llvm::Value *u_warp_id);
|
||||
void init_grids(ir::function *fn, llvm::IRBuilder<> &builder, llvm::Value *sh_mem_ptr);
|
||||
|
||||
@@ -143,7 +153,7 @@ private:
|
||||
tune *params_;
|
||||
target *tgt_;
|
||||
shmem_info *buffer_info_;
|
||||
std::map<ir::metaparameter*, distributed_axis> axes_;
|
||||
std::map<unsigned, distributed_axis> axes_;
|
||||
llvm::Value *sh_mem_ptr_;
|
||||
};
|
||||
|
||||
|
@@ -21,11 +21,17 @@ class tune {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
enum fragment_t{
|
||||
STRIDED_SCAN,
|
||||
HMMA_FRAGMENT_C
|
||||
};
|
||||
|
||||
private:
|
||||
void add_constraint(node_t x, node_t y);
|
||||
void init_c_phi(ir::instruction *i);
|
||||
void init_c_graph(ir::instruction *v);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph);
|
||||
fragment_t get_fragmentation_type(node_t x, graph_t &graph);
|
||||
void connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id);
|
||||
void create_grids(std::vector<ir::instruction*> &grids, std::map<ir::metaparameter *, ir::instruction *> &references, ir::function *fn);
|
||||
|
||||
|
||||
@@ -34,7 +40,8 @@ public:
|
||||
std::vector<ir::metaparameter *> get_params(ir::module& mod);
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; groups_[dst] = groups_[src]; }
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
void run(ir::module &mod);
|
||||
void init(ir::module &mod);
|
||||
@@ -46,12 +53,14 @@ private:
|
||||
std::vector<unsigned*> pool_;
|
||||
graph_t dependencies_;
|
||||
std::set<node_t> nodes_;
|
||||
std::map<node_t, fragment_t> fragments_;
|
||||
std::map<node_t, unsigned> static_params_;
|
||||
std::map<ir::value*, std::map<std::string, ir::metaparameter*>> params_;
|
||||
std::map<unsigned, ir::metaparameter*> global_range_sizes_;
|
||||
unsigned num_global_ranges_;
|
||||
unsigned num_threads_;
|
||||
std::vector<ir::instruction*> grids_;
|
||||
std::map<ir::value*, std::map<unsigned, unsigned>> groups_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -459,12 +459,12 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[params_->get_param(v, "nts.d" + str_k)] = distributed_axis{contiguous[k], idx_list};
|
||||
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
|
||||
}
|
||||
}
|
||||
|
||||
void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
std::map<ir::metaparameter*, ir::value*> &references,
|
||||
std::map<unsigned, ir::value*> &references,
|
||||
ir::function *fn) {
|
||||
// get number of dimensions greater than 1
|
||||
auto get_tile_gt1_dim = [&](ir::value *v){
|
||||
@@ -479,7 +479,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
std::function<void(ir::value*)> bind_references = [&](ir::value *v)
|
||||
{
|
||||
// skip
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second || dynamic_cast<ir::mask_inst*>(v))
|
||||
return;
|
||||
// recurse
|
||||
if(auto *user = dynamic_cast<ir::user*>(v))
|
||||
@@ -492,7 +492,7 @@ void selection::create_grids(std::vector<ir::value*> &grids,
|
||||
for(size_t d = 0; d < shapes.size(); d++){
|
||||
if(shapes[d]->get_value() == 1)
|
||||
continue;
|
||||
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
|
||||
unsigned x = params_->get_param_group(v, d);
|
||||
ir::value *&r = references[x];
|
||||
if(!r || get_tile_gt1_dim(v) > get_tile_gt1_dim(r))
|
||||
r = v;
|
||||
@@ -517,7 +517,7 @@ bool static inline has_phi_user(ir::value *v) {
|
||||
return false;
|
||||
}
|
||||
void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
const std::map<ir::metaparameter*, ir::value*>& references,
|
||||
const std::map<unsigned, ir::value*>& references,
|
||||
std::set<ir::value*> &seen, Value *sh_mem_ptr) {
|
||||
if(!v->get_type()->is_tile_ty() || !seen.insert(v).second)
|
||||
return;
|
||||
@@ -576,7 +576,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
std::vector<distributed_axis> axes(cshapes.size());
|
||||
for(size_t d = 0; d < cshapes.size(); d++){
|
||||
if(cshapes[d]->get_value() > 1){
|
||||
ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d));
|
||||
unsigned x = params_->get_param_group(v, d);
|
||||
axes[d] = axes_.at(x);
|
||||
}
|
||||
else{
|
||||
@@ -607,7 +607,7 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
|
||||
Value *u_warp_id = builder.CreateUDiv(u_thread_id, warp_size);
|
||||
// create grid
|
||||
std::vector<ir::value*> grids;
|
||||
std::map<ir::metaparameter*, ir::value*> references;
|
||||
std::map<unsigned, ir::value*> references;
|
||||
create_grids(grids, references, fn);
|
||||
for(ir::value* i: grids){
|
||||
if(auto *instr = dynamic_cast<ir::instruction*>(i))
|
||||
@@ -812,7 +812,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
std::swap(b_idx[0], b_idx[1]);
|
||||
Value *a = TA->get_value(a_idx);
|
||||
Value *b = TB->get_value(b_idx);
|
||||
// res = builder.CreateCall(f_mul_add, {ConstantFP::get(a->getType(), 1), ConstantFP::get(b->getType(), 1), res});
|
||||
res = builder.CreateCall(f_mul_add, {a, b, res});
|
||||
|
||||
}
|
||||
|
@@ -15,6 +15,19 @@ namespace codegen{
|
||||
|
||||
tune::tune(): num_global_ranges_(0){ }
|
||||
|
||||
bool is_hmma(ir::value *v){
|
||||
bool result = false;
|
||||
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
|
||||
ir::value *a = x->get_operand(0);
|
||||
ir::type *a_ty = a->get_type();
|
||||
ir::value *b = x->get_operand(1);
|
||||
ir::type *b_ty = b->get_type();
|
||||
result = !x->is_a_trans() && x->is_b_trans();
|
||||
result = result && a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void tune::add_constraint(node_t x, node_t y) {
|
||||
dependencies_[x].insert(y);
|
||||
dependencies_[y].insert(x);
|
||||
@@ -34,6 +47,7 @@ void tune::init_c_phi(ir::instruction *v) {
|
||||
}
|
||||
|
||||
void tune::init_c_graph(ir::instruction *v) {
|
||||
|
||||
// Reference shape
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(v->get_parent()->get_context());
|
||||
ir::type::tile_shapes_t shapes;
|
||||
@@ -83,20 +97,41 @@ void tune::init_c_graph(ir::instruction *v) {
|
||||
add_constraint({v, 1}, {D, 1});
|
||||
}
|
||||
// Element-wise
|
||||
else if(dynamic_cast<ir::user*>(v)){
|
||||
else if(dynamic_cast<ir::user*>(v)) {
|
||||
for(unsigned k = 0; k < v->get_num_results(); k++)
|
||||
for(unsigned i = 0; i < shapes.size(); i ++)
|
||||
for(ir::value* op: v->ops())
|
||||
for(unsigned i = 0; i < shapes.size(); i ++){
|
||||
for(ir::value* op: v->ops()){
|
||||
add_constraint({v->get_result(k), i}, {op, i});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, std::set<node_t> &nodes, graph_t &graph) {
|
||||
tune::fragment_t tune::get_fragmentation_type(node_t x, graph_t &graph){
|
||||
std::list<node_t> work;
|
||||
std::set<node_t> seen;
|
||||
work.push_back(x);
|
||||
while(!work.empty()){
|
||||
node_t current = work.back();
|
||||
if(is_hmma(current.first))
|
||||
return HMMA_FRAGMENT_C;
|
||||
work.pop_back();
|
||||
seen.insert(current);
|
||||
for(node_t y: graph[current]){
|
||||
if(seen.find(y) == seen.end())
|
||||
work.push_back(y);
|
||||
}
|
||||
}
|
||||
return STRIDED_SCAN;
|
||||
}
|
||||
|
||||
void tune::connected_components(node_t x, const std::vector<ir::metaparameter *> mps, const std::vector<std::string> prefixes, std::set<node_t> &nodes, graph_t &graph, unsigned group_id) {
|
||||
groups_[x.first][x.second] = group_id;
|
||||
if(nodes.find(x) != nodes.end()){
|
||||
nodes.erase(x);
|
||||
std::string suffix = ".d" + std::to_string(x.second);
|
||||
params_[x.first].insert({"nts" + suffix, mps[0]});
|
||||
params_[x.first].insert({"mts" + suffix, mps[1]});
|
||||
for(int i = 0; i < mps.size(); i++)
|
||||
params_[x.first].insert({prefixes[i] + suffix, mps[i]});
|
||||
ir::type *ty = x.first->get_type();
|
||||
if(ty->is_tile_ty()){
|
||||
ir::type::tile_shapes_t::value_type shape = ty->get_tile_shapes().at(x.second);
|
||||
@@ -109,11 +144,11 @@ void tune::connected_components(node_t x, const std::vector<ir::metaparameter *>
|
||||
num_global_ranges_ = std::max(num_global_ranges_, ax + 1);
|
||||
}
|
||||
if(static_params_.find(x) != static_params_.end()){
|
||||
mps[0]->set_value(static_params_.at(x));
|
||||
mps[1]->set_value(static_params_.at(x));
|
||||
for(ir::metaparameter *mp: mps)
|
||||
mp->set_value(static_params_.at(x));
|
||||
}
|
||||
for(const node_t &y: graph[x])
|
||||
connected_components(y, mps, nodes, graph);
|
||||
connected_components(y, mps, prefixes, nodes, graph, group_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,6 +177,10 @@ std::map<std::string, ir::metaparameter *> tune::get_params(ir::instruction* i)
|
||||
return params_.at(i);
|
||||
}
|
||||
|
||||
unsigned tune::get_param_group(ir::value *value, unsigned ax) {
|
||||
unsigned result = groups_.at(value).at(ax);
|
||||
return result;
|
||||
}
|
||||
|
||||
void tune::run(ir::module &mod) {
|
||||
ir::context &ctx = mod.get_context();
|
||||
@@ -159,12 +198,21 @@ void tune::run(ir::module &mod) {
|
||||
if(i->has_tile_result_or_op())
|
||||
init_c_phi(i);
|
||||
// Layout parameters
|
||||
while(!nodes_.empty()){
|
||||
unsigned group_id = 0;
|
||||
while(!nodes_.empty()) {
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
nts->set_value(1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_);
|
||||
node_t node = *nodes_.begin();
|
||||
fragment_t fragment = get_fragmentation_type(node, dependencies_);
|
||||
if(fragment == STRIDED_SCAN) {
|
||||
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
|
||||
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
|
||||
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
|
||||
nts->set_value(1);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw}, {"fpw"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -269,7 +317,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
|
||||
int num_threads = 1;
|
||||
for(size_t k = 0; k < shapes.size(); k++)
|
||||
num_threads *= params_[i]["mts.d" + to_string(k)]->get_value();
|
||||
if(num_threads % 64 != 0)
|
||||
if(num_threads % 32 != 0)
|
||||
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
|
||||
if(num_threads != num_threads_)
|
||||
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
|
||||
|
@@ -62,8 +62,8 @@ std::string gemm::src(bool AT, bool BT) {
|
||||
}
|
||||
std::string res =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TM = {16};
|
||||
const tunable int32 TN = {16};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
|
Reference in New Issue
Block a user