[feature] basic tensor core utilization works

This commit is contained in:
Philippe Tillet
2019-06-08 12:14:37 -07:00
parent 5f3d48c1d0
commit d074a166e2
4 changed files with 66 additions and 49 deletions

View File

@@ -54,26 +54,11 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
if(count == 0) {
@checkc *pc = c;
*pcount = countp1;
}
else {
@checkc *pc = c + *pc;
*pcount = countp1;
}
__atomic_cas(plock, 1, 0);
@checkc *pc = c;
}
)";

View File

@@ -6,14 +6,14 @@ data_files_path = tf.resource_loader.get_data_files_path()
# Load the custom-op shared library from a hard-coded local build directory.
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
# Problem sizes; when both assignments are present the later one is in effect.
M, N, K = 16, 16, 16
M, N, K = 256, 256, 256
# Graph inputs: fp16 placeholders a (M x K) and b (N x K), plus an int32
# placeholder of 4096 entries that is passed to the op as `locks`.
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
locks = tf.placeholder(tf.int32, shape=[4096])
c = module.dot(a, b, locks)
# Reference
# Host-side fp16 inputs with the same shapes as the placeholders.
ha = np.ones((M, K)).astype(np.float16)
hb = np.ones((N, K)).astype(np.float16)
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
# NOTE(review): hb.T is (K x N) and ha is (M x K), so this product is only
# shape-conformable when N == M (true for the sizes above) and yields a K x K
# array -- confirm this is the layout module.dot is expected to produce.
hresult = np.dot(hb.T, ha)
# Run
# Run
@@ -22,4 +22,7 @@ sess.run(tf.global_variables_initializer())
# Evaluate c, feeding an all-zero buffer for `locks` and the host arrays for a and b.
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
a: ha,
b: hb})
print(result - hresult)
# Dump the op output and the NumPy reference.
print(result)
print(hresult)
#print(result - hresult)
# Largest absolute element-wise difference between the op output and the reference.
print(np.max(np.abs(result - hresult)))

View File

@@ -57,6 +57,7 @@ private:
public:
// Constructs a tile of element type `ty` and shape `shapes` on top of `ptr`,
// using `builder` for IR emission; `offset` is optional and defaults to nullptr.
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
// Sets the vector width consulted by get_value().
void set_vector_size(unsigned vector_size);
// Sets the flag that tells get_value() whether to hand back the loaded value
// as-is (true) or to extract a single element from it (false) when the vector
// width is greater than 1.
void set_return_mode(bool return_vector);
// NOTE(review): definition not visible here -- presumably stores the value for the given indices.
void set_value(indices_t, llvm::Value *);
// Returns the value loaded for `idx`.
llvm::Value* get_value(indices_t idx);
// Returns the stored base pointer.
llvm::Value* get_pointer() { return ptr_; }
@@ -64,6 +65,7 @@ public:
private:
llvm::Value *ptr_;
bool return_vector_;
llvm::Value *offset_;
llvm::IRBuilder<> &builder_;
std::map<indices_t, llvm::Value*> ptr_cache_;

View File

@@ -129,6 +129,7 @@ Value* shared_tile::shared_offset(indices_t idx) {
// Constructs a shared_tile of element type `ty` and shape `shapes` on top of the
// base pointer `ptr`. IR is emitted through `builder`; `offset` is stored as given.
// The tile starts with a vector size of 1 and with vector-return mode disabled.
// Every member is set in the initializer list (nothing is assigned in the body),
// listed in the order ptr_, return_vector_, offset_, builder_ to match the class
// declaration. NOTE(review): vector_size_'s declaration position is not visible
// from here, so it is kept last -- confirm against the header.
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
tile(ty, shapes), ptr_(ptr), return_vector_(false), offset_(offset), builder_(builder), vector_size_(1){
}
void shared_tile::set_value(indices_t idx, Value *value) {
@@ -142,12 +143,18 @@ void shared_tile::set_vector_size(unsigned vector_size) {
vector_size_ = vector_size;
}
void shared_tile::set_return_mode(bool return_vector){
return_vector_ = return_vector;
}
Value* shared_tile::get_value(indices_t idx) {
indices_t non_cst_idx, cst_idx;
extract_constant(idx, non_cst_idx, cst_idx);
Value *&base_ptr = ptr_cache_[non_cst_idx];
if(base_ptr == nullptr){
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
// base_ptr = builder_.CreateBitCast(base_ptr, load_ptr_->getType());
if(vector_size_ > 1){
Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
@@ -160,7 +167,7 @@ Value* shared_tile::get_value(indices_t idx) {
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
Value *ptr = builder_.CreateGEP(base_ptr, div);
Value *result = builder_.CreateLoad(ptr);
if(vector_size_ > 1) {
if(return_vector_ == false && vector_size_ > 1) {
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
result = builder_.CreateExtractElement(result, rem);
}
@@ -479,19 +486,19 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
// idx_j
std::vector<Value*> idx_j;
for(unsigned i = 0; i < 2; i++){
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
for(unsigned j = 0; j < 2; j++){
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4)));
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4 + 1)));
}
// idx_i
std::vector<Value*> idx_i;
for(unsigned j = 0; j < 2; j++){
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
for(unsigned i = 0; i < 2; i++){
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*2)));
}
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
}
}
@@ -855,6 +862,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
}
else
{
TA->set_vector_size(2);
TB->set_vector_size(2);
TA->set_return_mode(true);
TB->set_return_mode(true);
Value *_0 = builder.getInt32(0);
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
@@ -864,47 +876,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
BasicBlock *current = builder.GetInsertBlock();
Module *module = current->getModule();
Value *tid = tgt_->get_local_id(module, builder, 0);
// offset_a_i = (tid & 3)
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
Value *offset_a_i = builder.CreateAnd(tid, _3);
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
_2),
// offset_a_i = (tid & 4)*2 + (tid & 16)/4;
// offset_a_k = (tid & 3)
Value *offset_a_i = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4), _2),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
// offset_b_i = (tid & 3)
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
Value *offset_b_i = builder.CreateAnd(tid, _3);
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
Value *offset_a_k = builder.CreateAnd(tid, _3);
// offset_b_i = (tid & 8)*1 + (tid & 16)/4
// offset_b_k = (tid & 3)
Value *offset_b_i = builder.CreateAdd(builder.CreateAnd(tid, _8),
builder.CreateUDiv(builder.CreateAnd(tid, _16),
_4));
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
Value *offset_b_k = builder.CreateAnd(tid, _3);
std::vector<Value *> fc;
result->for_each([&](indices_t idx){
fc.push_back(result->get_value(idx));
});
Type *void_ty = builder.getVoidTy();
Type *int32_ty = builder.getInt32Ty();
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
"{$8, $9}, "
"{$10, $11}, "
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]});
std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl;
for(unsigned K = 0; K < NK; K += 4){
Value *_K = builder.getInt32(K);
Value *ha0 = TA->get_value({offset_a_i, builder.CreateAdd(offset_a_k, _K)});
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _2), builder.CreateAdd(offset_a_k, _K)});
Value *hb0 = TB->get_value({offset_b_i, builder.CreateAdd(offset_b_k, _K)});
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _2), builder.CreateAdd(offset_b_k, _K)});
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[0], fc[2], fc[1], fc[3], fc[4], fc[6], fc[5], fc[7]});
fc[0] = builder.CreateExtractValue(nc, {0});
fc[2] = builder.CreateExtractValue(nc, {1});
fc[1] = builder.CreateExtractValue(nc, {2});
fc[3] = builder.CreateExtractValue(nc, {3});
fc[4] = builder.CreateExtractValue(nc, {4});
fc[6] = builder.CreateExtractValue(nc, {5});
fc[5] = builder.CreateExtractValue(nc, {6});
fc[7] = builder.CreateExtractValue(nc, {7});
}
// write back
unsigned i = 0;
result->for_each([&](indices_t idx){
result->set_value(idx, builder.CreateExtractValue(nc, {i++}));
result->set_value(idx, fc[i++]);
});
std::cout << "haha" << std::endl;
TA->set_return_mode(false);
TB->set_return_mode(false);
}
}
else