[dnn/blocksparse] FPROP test passes!

This commit is contained in:
Philippe Tillet
2019-07-29 17:06:20 -07:00
parent 17cb2db356
commit dc11f70fad
20 changed files with 360 additions and 118 deletions

View File

@@ -32,6 +32,11 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un
return result;
}
Instruction* amd_cl_target::add_memfence(Module *module, IRBuilder<>& builder) {
throw std::runtime_error("not implemented");
}
Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workgroup_id_x,
@@ -43,6 +48,16 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne
return group_id;
}
Value* amd_cl_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::r600_read_ngroups_x,
Intrinsic::r600_read_ngroups_y,
Intrinsic::r600_read_ngroups_z
};
Value* get_num_group = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_num_group, {});
}
Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::amdgcn_workitem_id_x,
@@ -70,6 +85,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder)
return builder.CreateCall(barrier, {});
}
Instruction* nvidia_cu_target::add_memfence(Module *module, IRBuilder<>& builder) {
Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::nvvm_membar_gl);
return builder.CreateCall(barrier, {});
}
Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* group_id = get_block_id(module, builder, ax);
Value* result = builder.CreateMul(builder.getInt32(stride), group_id);
@@ -82,39 +103,39 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
bool z_order = true;
if(z_order && ax < 2){
static std::array<Intrinsic::ID, 3> n_cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// global block ID
Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// helper for minimum
auto Min = [&](Value *x, Value *y){
return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
};
// super-tile size
Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// number of CTAs per super-block
Value *nscta = builder.CreateMul(n_cta_id_0, sts);
Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
if(ax == 0)
return bid0;
else
return bid1;
}
else{
// bool z_order = true;
// if(z_order && ax < 2){
// static std::array<Intrinsic::ID, 3> n_cta_ids = {
// Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
// Intrinsic::nvvm_read_ptx_sreg_nctaid_z
// };
// Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
// Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
// Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
// Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// // global block ID
// Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// // helper for minimum
// auto Min = [&](Value *x, Value *y){
// return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
// };
// // super-tile size
// Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// // number of CTAs per super-block
// Value *nscta = builder.CreateMul(n_cta_id_0, sts);
// Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
// Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
// if(ax == 0)
// return bid0;
// else
// return bid1;
// }
// else{
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
// }
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {
@@ -127,6 +148,16 @@ Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsi
return builder.CreateCall(get_local_id, {});
}
Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]);
return builder.CreateCall(get_nctaid, {});
}
// CPU
void cpu_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *module, Function* fn) {
@@ -138,6 +169,12 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) {
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Instruction* cpu_target::add_memfence(Module *module, IRBuilder<>& builder) {
// no barrier on CPU
return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0));
}
Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) {
const Function *fn = builder.GetInsertBlock()->getParent();
size_t num_params = fn->getFunctionType()->getNumParams();
@@ -149,6 +186,11 @@ Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsi
return (Argument*)ids[ax];
}
Value* cpu_target::get_num_blocks(Module *module, IRBuilder<>& builder, unsigned ax) {
throw std::runtime_error("not implemented");
}
Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) {
Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax));
return result;