|
|
|
@@ -129,6 +129,7 @@ Value* shared_tile::shared_offset(indices_t idx) {
|
|
|
|
|
|
|
|
|
|
/// Constructs a shared_tile.
///
/// @param ty      element type, forwarded to the `tile` base constructor
/// @param shapes  tile shapes, forwarded to the `tile` base constructor
/// @param ptr     base pointer stored in `ptr_` (used as the GEP base by get_value)
/// @param builder IR builder stored by reference in `builder_`
/// @param offset  value stored in `offset_`
///
/// The vector size starts at 1 and the return mode starts as "not vector";
/// both can be changed later via set_vector_size / set_return_mode.
shared_tile::shared_tile(Type *ty,
                         const shapes_t &shapes,
                         Value *ptr,
                         llvm::IRBuilder<> &builder,
                         Value *offset)
    : tile(ty, shapes),
      ptr_(ptr),
      builder_(builder),
      offset_(offset),
      vector_size_(1)
{
  // Default: get_value is not asked to hand back whole vectors.
  this->return_vector_ = false;
}
|
|
|
|
|
|
|
|
|
|
void shared_tile::set_value(indices_t idx, Value *value) {
|
|
|
|
@@ -142,12 +143,18 @@ void shared_tile::set_vector_size(unsigned vector_size) {
|
|
|
|
|
vector_size_ = vector_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void shared_tile::set_return_mode(bool return_vector){
|
|
|
|
|
return_vector_ = return_vector;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Value* shared_tile::get_value(indices_t idx) {
|
|
|
|
|
indices_t non_cst_idx, cst_idx;
|
|
|
|
|
extract_constant(idx, non_cst_idx, cst_idx);
|
|
|
|
|
Value *&base_ptr = ptr_cache_[non_cst_idx];
|
|
|
|
|
if(base_ptr == nullptr){
|
|
|
|
|
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
|
|
|
|
|
// base_ptr = builder_.CreateBitCast(base_ptr, load_ptr_->getType());
|
|
|
|
|
if(vector_size_ > 1){
|
|
|
|
|
Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
|
|
|
|
|
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
|
|
|
@@ -160,7 +167,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
|
|
|
|
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
|
|
|
|
Value *ptr = builder_.CreateGEP(base_ptr, div);
|
|
|
|
|
Value *result = builder_.CreateLoad(ptr);
|
|
|
|
|
if(vector_size_ > 1) {
|
|
|
|
|
if(return_vector_ == false && vector_size_ > 1) {
|
|
|
|
|
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
|
|
|
|
|
result = builder_.CreateExtractElement(result, rem);
|
|
|
|
|
}
|
|
|
|
@@ -479,19 +486,19 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
|
|
|
|
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
|
|
|
|
// idx_i
|
|
|
|
|
std::vector<Value*> idx_j;
|
|
|
|
|
for(unsigned i = 0; i < 2; i++){
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
|
|
|
|
|
for(unsigned j = 0; j < 2; j++){
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4)));
|
|
|
|
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4 + 1)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// idx_j
|
|
|
|
|
std::vector<Value*> idx_i;
|
|
|
|
|
for(unsigned j = 0; j < 2; j++){
|
|
|
|
|
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
|
|
|
|
|
for(unsigned i = 0; i < 2; i++){
|
|
|
|
|
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*2)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
|
|
|
|
|
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
|
|
|
|
|
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
|
|
|
|
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -855,6 +862,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|
{
|
|
|
|
|
TA->set_vector_size(2);
|
|
|
|
|
TB->set_vector_size(2);
|
|
|
|
|
TA->set_return_mode(true);
|
|
|
|
|
TB->set_return_mode(true);
|
|
|
|
|
Value *_0 = builder.getInt32(0);
|
|
|
|
|
Value *_1 = builder.getInt32(1);
|
|
|
|
|
Value *_2 = builder.getInt32(2);
|
|
|
|
|
Value *_3 = builder.getInt32(3);
|
|
|
|
@@ -864,47 +876,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|
|
|
|
BasicBlock *current = builder.GetInsertBlock();
|
|
|
|
|
Module *module = current->getModule();
|
|
|
|
|
Value *tid = tgt_->get_local_id(module, builder, 0);
|
|
|
|
|
// offset_a_i = (tid & 3)
|
|
|
|
|
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
|
|
|
|
|
Value *offset_a_i = builder.CreateAnd(tid, _3);
|
|
|
|
|
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
|
|
|
|
|
_2),
|
|
|
|
|
// offset_a_i = (tid & 4)*2 + (tid & 16)/4;
|
|
|
|
|
// offset_a_k = (tid & 3)
|
|
|
|
|
Value *offset_a_i = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4), _2),
|
|
|
|
|
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
|
|
|
|
_4));
|
|
|
|
|
// offset_b_i = (tid & 3)
|
|
|
|
|
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
|
|
|
|
|
Value *offset_b_i = builder.CreateAnd(tid, _3);
|
|
|
|
|
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
|
|
|
|
Value *offset_a_k = builder.CreateAnd(tid, _3);
|
|
|
|
|
|
|
|
|
|
// offset_b_i = (tid & 8)*1 + (tid & 16)/4
|
|
|
|
|
// offset_b_k = (tid & 3)
|
|
|
|
|
Value *offset_b_i = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
|
|
|
|
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
|
|
|
|
_4));
|
|
|
|
|
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
|
|
|
|
|
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
|
|
|
|
|
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
|
|
|
|
|
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
|
|
|
|
|
Value *offset_b_k = builder.CreateAnd(tid, _3);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<Value *> fc;
|
|
|
|
|
result->for_each([&](indices_t idx){
|
|
|
|
|
fc.push_back(result->get_value(idx));
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
Type *void_ty = builder.getVoidTy();
|
|
|
|
|
Type *int32_ty = builder.getInt32Ty();
|
|
|
|
|
Type *fp32_ty = builder.getFloatTy();
|
|
|
|
|
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
|
|
|
|
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
|
|
|
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
|
|
|
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
|
|
|
|
|
|
|
|
|
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
|
|
|
|
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
|
|
|
|
"{$8, $9}, "
|
|
|
|
|
"{$10, $11}, "
|
|
|
|
|
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
|
|
|
|
Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]});
|
|
|
|
|
std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl;
|
|
|
|
|
|
|
|
|
|
for(unsigned K = 0; K < NK; K += 4){
|
|
|
|
|
Value *_K = builder.getInt32(K);
|
|
|
|
|
Value *ha0 = TA->get_value({offset_a_i, builder.CreateAdd(offset_a_k, _K)});
|
|
|
|
|
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _2), builder.CreateAdd(offset_a_k, _K)});
|
|
|
|
|
Value *hb0 = TB->get_value({offset_b_i, builder.CreateAdd(offset_b_k, _K)});
|
|
|
|
|
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _2), builder.CreateAdd(offset_b_k, _K)});
|
|
|
|
|
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[0], fc[2], fc[1], fc[3], fc[4], fc[6], fc[5], fc[7]});
|
|
|
|
|
fc[0] = builder.CreateExtractValue(nc, {0});
|
|
|
|
|
fc[2] = builder.CreateExtractValue(nc, {1});
|
|
|
|
|
fc[1] = builder.CreateExtractValue(nc, {2});
|
|
|
|
|
fc[3] = builder.CreateExtractValue(nc, {3});
|
|
|
|
|
fc[4] = builder.CreateExtractValue(nc, {4});
|
|
|
|
|
fc[6] = builder.CreateExtractValue(nc, {5});
|
|
|
|
|
fc[5] = builder.CreateExtractValue(nc, {6});
|
|
|
|
|
fc[7] = builder.CreateExtractValue(nc, {7});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// write back
|
|
|
|
|
unsigned i = 0;
|
|
|
|
|
result->for_each([&](indices_t idx){
|
|
|
|
|
result->set_value(idx, builder.CreateExtractValue(nc, {i++}));
|
|
|
|
|
result->set_value(idx, fc[i++]);
|
|
|
|
|
});
|
|
|
|
|
std::cout << "haha" << std::endl;
|
|
|
|
|
|
|
|
|
|
TA->set_return_mode(false);
|
|
|
|
|
TB->set_return_mode(false);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else
|
|
|
|
|