added fragmented axis
This commit is contained in:
@@ -6,14 +6,14 @@ data_files_path = tf.resource_loader.get_data_files_path()
|
||||
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
|
||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||
|
||||
M, N, K = 512, 512, 512
|
||||
M, N, K = 16, 16, 16
|
||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||
c = module.dot(a, b, locks)
|
||||
# Reference
|
||||
ha = np.random.rand(M, K).astype(np.float16)
|
||||
hb = np.random.rand(N, K).astype(np.float16)
|
||||
ha = np.ones((M, K)).astype(np.float16)
|
||||
hb = np.ones((N, K)).astype(np.float16)
|
||||
hresult = np.dot(hb.T, ha)
|
||||
|
||||
# Run
|
||||
|
@@ -21,6 +21,7 @@ class tune {
|
||||
typedef std::pair<ir::value*, unsigned> node_t;
|
||||
typedef std::map <node_t, std::set<node_t>> graph_t;
|
||||
|
||||
public:
|
||||
enum fragment_t{
|
||||
STRIDED_SCAN,
|
||||
HMMA_FRAGMENT_C
|
||||
@@ -41,6 +42,7 @@ public:
|
||||
std::vector<ir::metaparameter *> get_params(ir::module& mod);
|
||||
std::map<std::string, ir::metaparameter *> get_params(ir::instruction* i);
|
||||
ir::metaparameter* get_param(ir::value *value, const std::string &key) { return params_[value][key]; }
|
||||
fragment_t get_fragment(ir::value *value, unsigned ax) { return fragments_.at({value, ax}); }
|
||||
unsigned get_param_group(ir::value *value, unsigned ax);
|
||||
void copy(ir::value *dst, ir::value *src) { params_[dst] = params_[src]; groups_[dst] = groups_[src]; }
|
||||
bool check_constraints(std::map<ir::value *, std::vector<std::string>> &errors);
|
||||
|
@@ -433,33 +433,61 @@ inline void to_warps(const std::vector<unsigned> &bs, std::vector<unsigned> &nw,
|
||||
void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
|
||||
const auto& shapes = v->get_type()->get_tile_shapes();
|
||||
size_t dim = shapes.size();
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> block_size(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
|
||||
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
|
||||
}
|
||||
to_warps(block_size, n_warps, warp_size);
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
|
||||
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
||||
thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
if(params_->get_fragment(v, 0) == tune::STRIDED_SCAN){
|
||||
std::vector<unsigned> contiguous(dim);
|
||||
std::vector<unsigned> block_size(dim);
|
||||
std::vector<unsigned> warp_size(dim);
|
||||
std::vector<unsigned> n_warps(dim);
|
||||
for(unsigned i = 0; i < shapes.size(); i++){
|
||||
std::string str_i = std::to_string(i);
|
||||
contiguous[i] = params_->get_param(v, "nts.d" + str_i)->get_value();
|
||||
block_size[i] = params_->get_param(v, "mts.d" + str_i)->get_value();
|
||||
}
|
||||
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
|
||||
to_warps(block_size, n_warps, warp_size);
|
||||
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, warp_size, builder);
|
||||
std::vector<Value*> warp_id = delinearize(u_warp_id, n_warps, builder);
|
||||
// Create axes
|
||||
for(unsigned k = 0; k < dim; k++) {
|
||||
std::string str_k = std::to_string(k);
|
||||
Value *warp_size_k = builder.getInt32(warp_size[k]);
|
||||
Value *contiguous_k = builder.getInt32(contiguous[k]);
|
||||
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
|
||||
thread_id = builder.CreateMul(thread_id, contiguous_k);
|
||||
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
|
||||
unsigned per_thread = contiguous[k] * shapes[k]->get_value() / per_block;
|
||||
std::vector<Value*> idx_list(per_thread);
|
||||
for(unsigned n = 0 ; n < per_thread; n++){
|
||||
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
|
||||
idx_list[n] = builder.CreateAdd(thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
|
||||
}
|
||||
axes_[params_->get_param_group(v, k)] = distributed_axis{contiguous[k], idx_list};
|
||||
}
|
||||
}
|
||||
else {
|
||||
Value *_1 = builder.getInt32(1);
|
||||
Value *_2 = builder.getInt32(2);
|
||||
Value *_4 = builder.getInt32(4);
|
||||
Value *_8 = builder.getInt32(8);
|
||||
Value *_16 = builder.getInt32(16);
|
||||
// offset_i = tid & 2 + tid & 8
|
||||
Value *offset_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
|
||||
builder.CreateAnd(u_thread_id, _8));
|
||||
// offset_j = (tid & 1) + (tid & 4)*2 + (tid & 16)/4
|
||||
Value *offset_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1),
|
||||
builder.CreateAdd(builder.CreateMul(builder.CreateAnd(u_thread_id, _4), _2),
|
||||
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
||||
// idx_i
|
||||
std::vector<Value*> idx_i;
|
||||
for(unsigned i = 0; i < 2; i++)
|
||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*4)));
|
||||
|
||||
// idx_j
|
||||
std::vector<Value*> idx_j;
|
||||
for(unsigned j = 0; j < 2; j++)
|
||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*2)));
|
||||
|
||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user