[codegen] some more optimizations

This commit is contained in:
Philippe Tillet
2019-07-18 19:39:40 -07:00
parent 71594da66f
commit 5215fb0424
12 changed files with 108 additions and 73 deletions

View File

@@ -10,37 +10,40 @@
int main() {
bool AT = true;
bool BT = false;
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 64, N = 128, K = 128;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
std::vector<float> hb(K*N);
int32_t M = 65536, N = 2048, K = 2048;
std::vector<T> hc(M*N);
std::vector<T> rc(M*N);
std::vector<T> ha(M*K);
std::vector<T> hb(K*N);
srand(0);
for(size_t i = 0; i < ha.size(); i++)
ha[i] = (float)rand()/RAND_MAX;
ha[i] = (T)rand()/RAND_MAX;
for(size_t i = 0; i < hb.size(); i++)
hb[i] = (float)rand()/RAND_MAX;
hb[i] = (T)rand()/RAND_MAX;
for(size_t i = 0; i < hc.size(); i++)
hc[i] = 0;
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
triton::driver::stream* stream = triton::driver::stream::create(context);
stream->write(da, true, 0, ha);
stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc);
stream->synchronize();
triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4);
gemm.enqueue(stream, {da, db, dc}, false);
stream->read(dc, true, 0, hc);
gemm.cpu_ref<float>(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
triton::dnn::gemm gemm(M, N, K, AT, BT, ty, ty, 4, 4);
gemm.enqueue(stream, {da, db, dc}, true);
// stream->read(dc, true, 0, hc);
// gemm.cpu_ref<T>(rc, ha, hb);
// for(size_t i = 0; i < M*N; i++)
// if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
// exit(EXIT_FAILURE);
// }
// std::cout << "Pass!" << std::endl;
}

View File

@@ -18,22 +18,22 @@ int main() {
// initialization
int32_t R = 3, S = 3;
int32_t B = 128, F = 128;
int32_t B = 16, F = 4096;
int32_t H = 16, W = 16;
int32_t C = 128;
int32_t C = 4096;
// random shifts
std::vector<int32_t> shift_h(C);
std::vector<int32_t> shift_w(C);
for(int32_t c = 0; c < C; c++){
shift_h[c] = rand() % R - R/2;
shift_w[c] = rand() % S - S/2;
shift_h[c] = 0;
shift_w[c] = 0;
}
// configuration
triton::dnn::shift shift(B, C, 1, H, W, 1, R, S, F, 1, 1,
shift_h.data(), shift_w.data(),
numeric_t_str, numeric_t_str,
op, false, triton::dnn::shift::NCHW);
op, false, triton::dnn::shift::CHWN);
// host buffers
size_t a_size = B*C*H*W;
size_t b_size = C*F;

View File

@@ -159,6 +159,9 @@ public:
static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
static CUresult cuCtxGetDevice(CUdevice* result);
static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
static CUresult cuFuncSetCacheConfig (CUfunction hfunc, CUfunc_cache config);
// NVML
static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
@@ -252,6 +255,9 @@ private:
static void* cuMemsetD8Async_;
static void* cuCtxPushCurrent_v2_;
static void* cuCtxPopCurrent_v2_;
static void* cuFuncGetAttribute_;
static void* cuFuncSetAttribute_;
static void* cuFuncSetCacheConfig_;
// NVML
static void* nvmlInit_v2_;
static void* nvmlDeviceGetHandleByPciBusId_v2_;

View File

@@ -529,8 +529,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
pack_size_0_ = std::min<unsigned>(num_rep_0, 2);
pack_size_1_ = std::min<unsigned>(num_rep_1, 2);
// number of packs (interleaving)
num_packs_0_ = num_rep_0 / pack_size_0_;
num_packs_1_ = num_rep_1 / pack_size_1_;
@@ -1148,7 +1148,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
unsigned alignment = std::min(starting_multiple, max_contiguous);
unsigned vector_size = std::min<unsigned>(result->axis(0).contiguous, alignment);
// vector_size = result->axis(0).contiguous;
vector_size = result->axis(0).contiguous;
// vector_size = 1;
std::map<unsigned, Value*> packets;
distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
@@ -1251,6 +1251,14 @@ void selection::run(ir::module &src, Module &dst) {
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
}
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(dst_fn),
MDString::get(dst_ctx, "maxntidx"),
ValueAsMetadata::get(dst_builder.getInt32(params_->get_num_threads()))
};
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args));
// map parameters
for(unsigned i = 0; i < fn->args().size(); i++)

View File

@@ -21,13 +21,13 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
bool is_op_1 = x == dot->get_operand(1);
if(is_hmma && is_op_0){
if(dot->is_a_trans())
return 4;
return 8;
else
return 16;
}
if(is_hmma && is_op_1){
if(!dot->is_b_trans())
return 4;
return 8;
else
return 16;
}

View File

@@ -77,14 +77,44 @@ Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder,
}
Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) {
static std::array<Intrinsic::ID, 3> ids = {
static std::array<Intrinsic::ID, 3> cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
Intrinsic::nvvm_read_ptx_sreg_ctaid_y,
Intrinsic::nvvm_read_ptx_sreg_ctaid_z
};
Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]);
Value* group_id = builder.CreateCall(get_group_id, {});
return group_id;
bool z_order = true;
if(z_order && ax < 2){
static std::array<Intrinsic::ID, 3> n_cta_ids = {
Intrinsic::nvvm_read_ptx_sreg_nctaid_x,
Intrinsic::nvvm_read_ptx_sreg_nctaid_y,
Intrinsic::nvvm_read_ptx_sreg_nctaid_z
};
Value* cta_id_0 = builder.CreateIntrinsic(cta_ids[0], {}, {});
Value* cta_id_1 = builder.CreateIntrinsic(cta_ids[1], {}, {});
Value* n_cta_id_0 = builder.CreateIntrinsic(n_cta_ids[0], {}, {});
Value* n_cta_id_1 = builder.CreateIntrinsic(n_cta_ids[1], {}, {});
// global block ID
Value* bid = builder.CreateAdd(cta_id_0, builder.CreateMul(cta_id_1, n_cta_id_0));
// helper for minimum
auto Min = [&](Value *x, Value *y){
return builder.CreateSelect(builder.CreateICmpSGE(x, y), y, x);
};
// super-tile size
Value* sts = Min(builder.getInt32(16), n_cta_id_1);
// number of CTAs per super-block
Value *nscta = builder.CreateMul(n_cta_id_0, sts);
Value *bid0 = builder.CreateURem(builder.CreateUDiv(bid, sts), n_cta_id_0);
Value *bid1 = builder.CreateAdd(builder.CreateMul(builder.CreateUDiv(bid, nscta), sts),builder.CreateURem(bid, sts));
if(ax == 0)
return bid0;
else
return bid1;
}
else{
Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]);
Value* cta_id = builder.CreateCall(get_cta_id, {});
return cta_id;
}
}
Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) {

View File

@@ -215,7 +215,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
@@ -235,13 +235,13 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 4));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 4));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 4));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 8, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}

View File

@@ -30,7 +30,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
/* the current template has not already been compiled */
if(m_jit.find(this) == m_jit.end()) {
base* clone = this->clone();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx, 8))).first->second.get();
std::ostringstream oss;
clone->triton_c_src(oss);
std::string src = oss.str();
@@ -51,7 +51,8 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else {
jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
// jit->add_module(name_.c_str(), src.c_str(), jit->get_valid(name_.c_str(), src.c_str()));
jit->add_module(name_.c_str(), src.c_str(), {32, 128, 16, 128, 2, 2, 2, 2, 4, 4, 32, 8, 4, 1});
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
clone->init_impl(stream, (triton::driver::cu_module*)kernel->module());

View File

@@ -49,8 +49,8 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args,
runtime::launch_information info) {
driver::buffer *a = args[0], *b = args[1], *c = args[2];
unsigned TM = info.global_range_size[0];
unsigned TN = info.global_range_size[1];
unsigned TM = info.globals.at("TM");
unsigned TN = info.globals.at("TN");
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
@@ -109,7 +109,7 @@ void gemm::triton_c_src(std::ostream &os) const {
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {16};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
@@ -127,12 +127,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda;
int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb;
last_a = last_a / TK * TK;
last_b = last_b / TK * TK;
int32 bound = K - max(last_a, last_b);
for(int32 k = K; k > bound; k = k - TK){
for(int32 k = K; k > TK; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
@@ -141,22 +136,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
for(int32 k = bound; k > 0; k = k - 1){
int1 checka[TM, 1] = rxc[:, newaxis] < M;
int1 checkb[TN, 1] = ryc[:, newaxis] < N;
)" + a_ty_ + R"(* pa[TM, 1] = A + (K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(;
)" + b_ty_ + R"(* pb[TN, 1] = B + (K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(;
)" + a_ty_ + R"( a[TM, 1] = checka ? *pa : 0;
)" + b_ty_ + R"( b[TN, 1] = checkb ? *pb : 0;
c = dot(a, trans(b), c);
}
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
@checkc *pc = c;
*pc = c;
}
)";
os << res;

View File

@@ -175,6 +175,9 @@ CUDA_DEFINE1(CUresult, cuCtxSetCurrent, CUcontext)
CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
CUDA_DEFINE1(CUresult, cuCtxPushCurrent_v2, CUcontext)
CUDA_DEFINE1(CUresult, cuCtxPopCurrent_v2, CUcontext*)
CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
@@ -316,6 +319,9 @@ void* dispatch::cuCtxGetDevice_;
void* dispatch::cuMemsetD8Async_;
void* dispatch::cuCtxPushCurrent_v2_;
void* dispatch::cuCtxPopCurrent_v2_;
void* dispatch::cuFuncGetAttribute_;
void* dispatch::cuFuncSetAttribute_;
void* dispatch::cuFuncSetCacheConfig_;
void* dispatch::nvmlInit_v2_;
void* dispatch::nvmlDeviceGetHandleByPciBusId_v2_;

View File

@@ -124,6 +124,7 @@ cu_kernel::cu_kernel(driver::module *program, const char * name) : kernel(progra
cu_params_store_.reserve(64);
cu_params_.reserve(64);
dispatch::cuModuleGetFunction(&*cu_, *program->cu(), name);
// dispatch::cuFuncSetCacheConfig(*cu_, CU_FUNC_CACHE_PREFER_SHARED);
}
void cu_kernel::setArg(unsigned int index, std::size_t size, void* ptr){

View File

@@ -40,17 +40,16 @@ void loop_nest(std::vector<size_t> const & ranges,
ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
// size_t current = 0;
while(true){
//Execute function
pool.enqueue([values, &f](){ f(values); });
// f(values);
// Execute function
pool.enqueue(f,values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
values[i--] = 0;
}
i = D - 1;
// Small sleep so that the thread pool doesn't grow too big
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
}
@@ -201,19 +200,19 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
launch_information info;
llvm::LLVMContext llvm_context;
auto ll_module = make_llvm_module(tt_module_1, passes_1, llvm_context, info);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
double perf;
{
std::lock_guard<std::mutex> lock(mutex);
std::unique_ptr<driver::module> module(driver::module::create(driver_context_, &*ll_module));
std::unique_ptr<driver::kernel> kernel(driver::kernel::create(module.get(), name));
perf = benchmark(kernel.get(), info);
if(perf > best.perf){
best.perf = perf;
best.params = params;
}
for(unsigned p: params)
std::cout << p << " " << std::flush;
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
}
}, nthreads_);
std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;