added fragmented axis

This commit is contained in:
Philippe Tillet
2019-06-07 10:32:56 -07:00
parent 781b6d377d
commit 6fce9f28ae
3 changed files with 59 additions and 29 deletions

View File

@@ -6,14 +6,14 @@ data_files_path = tf.resource_loader.get_data_files_path()
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
M, N, K = 512, 512, 512
M, N, K = 16, 16, 16
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096])
c = module.dot(a, b, locks)
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
ha = np.ones((M, K)).astype(np.float16)
hb = np.ones((N, K)).astype(np.float16)
hresult = np.dot(hb.T, ha)
# Run

View File

@@ -21,6 +21,7 @@ class tune {
typedef std::pair<ir::value*, unsigned> node_t;
typedef std::map <node_t, std::set<node_t>> graph_t;
public:
enum fragment_t{
STRIDED_SCAN,
HMMA_FRAGMENT_C
@@ -41,6 +42,7 @@ public:
std::vector<ir::metaparameter *> get_params(ir::module& mod);
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
// Returns the metaparameter stored under `key` for `value`.
// Both levels of params_ are indexed with operator[].
// NOTE(review): if params_ is a std::map, an unknown value/key gets
// default-inserted here rather than reported — confirm.
ir::metaparameter* get_param(ir::value *value, const std::string &key) {
  auto &per_value = params_[value];
  return per_value[key];
}
// Returns the fragment kind recorded for axis `ax` of `value`.
// The lookup goes through fragments_.at(), so the (value, ax) entry is
// expected to exist already.
fragment_t get_fragment(ir::value *value, unsigned ax) {
  fragment_t kind = fragments_.at({value, ax});
  return kind;
}
unsigned get_param_group(ir::value *value, unsigned ax);
// Gives `dst` the same params_ and groups_ entries as `src`.
// NOTE(review): fragments_ is not copied here, while get_fragment() reads it
// with .at(); a value set up only through copy() may have no fragment entry
// unless one is registered elsewhere — confirm.
void copy(ir::value *dst, ir::value *src) {
  params_[dst] = params_[src];
  groups_[dst] = groups_[src];
}
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);

View File

@@ -433,33 +433,61 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
// Builds the index values for the axes of tile `v` and stores them in axes_,
// keyed by params_->get_param_group(v, axis). Which layout is generated depends
// on the fragment kind params_ reports for axis 0 of `v`.
// NOTE(review): this listing is a rendered diff with the +/- markers stripped.
// The un-nested strided code directly below looks like the pre-change copy of
// the STRIDED_SCAN branch further down, so braces do not balance as shown —
// confirm against the real file.
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
// Per-axis tuning values: "nts.d<i>" goes to contiguous, "mts.d<i>" to block_size.
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim);
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
}
to_warps(block_size, n_warps, warp_size);
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *warp_size_k = builder.getInt32(warp_size[k]);
Value *contiguous_k = builder.getInt32(contiguous[k]);
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
thread_id = builder.CreateMul(thread_id, contiguous_k);
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
// Strided layout: taken when axis 0 of `v` is tagged tune::STRIDED_SCAN.
if(params_->get_fragment(v, 0) == tune::STRIDED_SCAN){
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim);
// Read the per-axis "nts.d<i>" and "mts.d<i>" metaparameters.
for(unsigned i = 0; i < shapes.size(); i++){
std::string str_i = std::to_string(i);
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
}
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
// Presumably fills n_warps and warp_size from block_size; both are read below.
to_warps(block_size, n_warps, warp_size);
// One Value per axis for the thread id and for the warp id.
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *warp_size_k = builder.getInt32(warp_size[k]);
Value *contiguous_k = builder.getInt32(contiguous[k]);
// thread_id = (thread_id_in_warp[k] + warp_id[k] * warp_size[k]) * contiguous[k]
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
thread_id = builder.CreateMul(thread_id, contiguous_k);
// per_block: extent covered along axis k by one pass of all threads;
// per_thread: number of indices generated for each thread along axis k.
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
// Runs of contiguous[k] consecutive offsets, successive runs per_block apart.
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
}
}
// Any other fragment kind (presumably tune::HMMA_FRAGMENT_C): indices are built
// from fixed bit masks of u_thread_id only; u_warp_id and the tile shape are
// not used, and only axes 0 and 1 are populated.
else {
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_4 = builder.getInt32(4);
Value *_8 = builder.getInt32(8);
Value *_16 = builder.getInt32(16);
// offset_i = (tid & 2) + (tid & 8)
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAnd(u_thread_id, _8));
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
// idx_i: two indices along axis 0, at offset_i + 0 and offset_i + 4
std::vector<Value*> idx_i;
for(unsigned i = 0; i < 2; i++)
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
// idx_j: two indices along axis 1, at offset_j + 0 and offset_j + 2
std::vector<Value*> idx_j;
for(unsigned j = 0; j < 2; j++)
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
// Both axes are stored with 1 as the first distributed_axis field.
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
}
}