[dnn] better specification of recompilation key
@@ -18,12 +18,12 @@ int main() {
   int32_t pad_d = 0, pad_h = 0, pad_w = 0;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-  // triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
+  // triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "float", "float", triton::dnn::conv::FPROP, 0);
   triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
                                   stride_d, stride_h, stride_w,
                                   pad_d, pad_h, pad_w,
                                   upsample_d, upsample_h, upsample_w,
-                                  "fp32", "fp32", ty, 0);
+                                  "float", "float", ty, 0);
   // convolution configuration
   std::vector<float> hc(configuration.c_size());
   std::vector<float> rc(configuration.c_size());
@@ -26,7 +26,7 @@ struct perf_t {

 perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
-  std::string ty = "fp16";
+  std::string ty = "half";
   size_t dt_nbytes = sizeof(NumericT);
   triton::driver::context* context = stream->context();
   std::vector<NumericT> hc(M*N);
@@ -46,7 +46,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
+  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
   // benchmark triton
   double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
   // benchmark cublas
@@ -134,7 +134,7 @@ int main() {
   };
   for(config_t c: resnet18){
     for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
-      configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
+      configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "half"});
     }
   }

@@ -37,7 +37,7 @@ std::vector<torch::Tensor>
   triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
   triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
   // create template
-  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
+  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float");
   batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
   stream.synchronize();
   return {fw_y, fw_m, fw_v};
@@ -79,7 +79,7 @@ std::vector<torch::Tensor>
   triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
   triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
   // create config
-  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
+  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", eps);
   batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
   stream.synchronize();
   return {fw_dx, fw_dg, fw_db};
@@ -30,7 +30,7 @@ torch::Tensor conv_common(
       stride_d, stride_h, stride_w,
       pad_d, pad_h, pad_w,
       1, 1, 1,
-      "fp32", "fp32", ty, has_bias);
+      "float", "float", ty, has_bias);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
@@ -49,9 +49,9 @@ torch::Tensor shift_common(
   std::string dtype;
   at::ScalarType type = torcha.scalar_type();
   switch(type){
-    case at::ScalarType::Double: dtype = "fp64"; break;
-    case at::ScalarType::Float: dtype = "fp32"; break;
-    case at::ScalarType::Half: dtype = "fp16"; break;
+    case at::ScalarType::Double: dtype = "double"; break;
+    case at::ScalarType::Float: dtype = "float"; break;
+    case at::ScalarType::Half: dtype = "half"; break;
     default: AT_ERROR("unknown data-type for shift-conv");
   }
   // Get configuration
@@ -58,7 +58,7 @@ public:
     triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
     triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
     // create config
-    triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
+    triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
     batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
   }

@@ -126,7 +126,7 @@ public:
     triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
     triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
     // create config
-    triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
+    triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
     batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
   }

@@ -128,9 +128,9 @@ public:
     triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
     // create profile
-    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
     // blocksparse matmul
-    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
+    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
     triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
     Tensor *tmp = nullptr;
     TensorShape tmp_shapes;
@@ -61,7 +61,7 @@ public:
         stride_d, stride_h, stride_w,
         pad_d, pad_h, pad_w,
         1, 1, 1,
-        "fp16", "fp16",
+        "half", "half",
         triton::dnn::conv::FPROP, has_bias);
     // allocate output
     auto c_shapes = conv.c_shapes();
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
     triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
     triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
-    triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
+    triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
   }

@@ -105,7 +105,7 @@ def batch_norm_grad(op, dy, mean, var):


 def run_batchnorm():
-  C, H, W, B = 32, 14, 14, 64
+  C, H, W, B = 8, 4, 4, 32
   np.random.seed(0)
   # Placeholders
   x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -131,6 +131,6 @@ def run_batchnorm():
   print(np.max(np.abs(dg_t - dg_n)))
   print(np.max(np.abs(db_t - db_n)))

-run_dot()
+#run_dot()
 #run_shift()
-#run_batchnorm()
+run_batchnorm()
@@ -106,7 +106,7 @@ public:
     triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
                              stride_h_, stride_w_,
                              shift_h_data, shift_w_data,
-                             "fp16", "fp16", OP, has_bias, layout_);
+                             "half", "half", OP, has_bias, layout_);

     // shapes for c
     std::vector<int64> c_shapes;
@@ -91,6 +91,7 @@ public:
   void set_value(indices_t idx, llvm::Value *v);
   llvm::Value* get_value(indices_t idx);
   unsigned get_linear_index(indices_t idx);
+  indices_t get_ordered_indices(unsigned id);
   void for_each(std::function<void(indices_t)> fn);
   const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
@@ -52,12 +52,15 @@ struct launch_context_t{
 typedef std::vector<unsigned> params_t;

 class base {
-  friend class cmp_recompile;
+  friend class recompile_hash;
+  friend class recompile_equal;

 protected:
   // leading dimensions
   static void set_ld(const std::vector<int32_t>& shapes,
                      std::vector<int32_t>& ld);
+  // list of retuning parameters
+  virtual std::vector<int64_t> retune_params() const = 0;

 private:
   // initialize
@@ -70,8 +73,6 @@ private:
                             triton::runtime::launch_information info) = 0;
   // number of flops
   virtual size_t num_flops() const = 0;
-  // comparison for maps
-  virtual bool operator<(const base& other) const = 0;
   // default parameters
   virtual std::vector<params_t> search_space() const;
   virtual params_t heuristics() const;
@@ -94,12 +95,21 @@ private:
   std::string name_;
 };

-struct cmp_recompile{
+struct recompile_equal{
   bool operator()(base* x, base* y) const{
-    return *x < *y;
+    return typeid(*x) == typeid(*y) &&
+           x->retune_params() == y->retune_params();
   }
 };

+struct recompile_hash{
+  unsigned operator()(base* x) const{
+    return x->retune_params()[0];
+  }
+};

 }
 }
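Taken together, `recompile_equal` and `recompile_hash` make the recompilation key the pair (dynamic type, `retune_params()` vector): two templates share a compiled kernel exactly when both components match. A minimal self-contained sketch of the pattern, with illustrative stand-in names (`Op`, `conv_like`) that are not part of the Triton sources:

```cpp
#include <cstdint>
#include <iostream>
#include <typeinfo>
#include <unordered_map>
#include <vector>

struct Op {
  virtual ~Op() = default;
  // flat recompilation key: everything that forces a fresh compilation
  virtual std::vector<int64_t> retune_params() const = 0;
};

struct conv_like : Op {
  int64_t B, C, H, W;
  conv_like(int64_t b, int64_t c, int64_t h, int64_t w): B(b), C(c), H(h), W(w) {}
  std::vector<int64_t> retune_params() const override { return {B, C, H, W}; }
};

struct recompile_equal {
  bool operator()(Op* x, Op* y) const {
    // same concrete type and same key => reuse the compiled kernel
    return typeid(*x) == typeid(*y) &&
           x->retune_params() == y->retune_params();
  }
};

struct recompile_hash {
  size_t operator()(Op* x) const {
    // same cheap hash as the patch: first component of the key
    return static_cast<size_t>(x->retune_params()[0]);
  }
};

int main() {
  std::unordered_map<Op*, int, recompile_hash, recompile_equal> cache;
  conv_like a(128, 256, 14, 14), b(128, 256, 14, 14), c(64, 256, 14, 14);
  cache.emplace(&a, /*jit id*/ 1);
  std::cout << cache.count(&b) << "\n";  // 1: identical key, kernel reused
  std::cout << cache.count(&c) << "\n";  // 0: new batch size, would recompile
}
```

Hashing only the first key component is deliberately cheap; templates that happen to share it (say, the same batch size) merely land in the same bucket, and `recompile_equal` still tells them apart.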
@@ -47,15 +47,15 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;

 public:
   // constructor
   batchnorm_forward(int C, int D, int H, int W, int B,
-                    std::string ty = "fp32", float eps = 1e-5);
+                    std::string ty = "float", float eps = 1e-5);
   // triton-c source
   void triton_c_src(std::ostream &os) const;

@@ -82,15 +82,15 @@ private:
                     runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;

 public:
   // constructor
   batchnorm_backward(int C, int D, int H, int W, int B,
-                     std::string ty = "fp32", float eps = 1e-5);
+                     std::string ty = "float", float eps = 1e-5);
   // triton-c source
   void triton_c_src(std::ostream &os) const;
@@ -20,8 +20,8 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // default parameters
   std::vector<params_t> search_space() const;
   params_t heuristics() const;
@@ -37,8 +37,8 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;

@@ -50,7 +50,7 @@ public:
        int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
        int upsample_d, int upsample_h, int upsample_w,
-       std::string a_ty = "fp32", std::string b_ty = "fp32",
+       std::string a_ty = "float", std::string b_ty = "float",
        type ty = FPROP, bool bias = false);

   // accessors
@@ -16,8 +16,8 @@ private:
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
                     triton::runtime::launch_information info);
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // default parameters
   virtual std::vector<params_t> search_space() const;
   virtual params_t heuristics() const;
@@ -25,7 +25,7 @@ private:
 public:
   dot(int M, int N, int K, bool AT, bool BT,
       std::string a_ty, std::string b_ty,
-      unsigned alignment_lda, unsigned alignment_ldb);
+      unsigned align_lda, unsigned align_ldb, unsigned align_ldc);

   // number of flops
   size_t num_flops() const;
@@ -70,6 +70,7 @@ private:
   std::string b_ty_;
   unsigned align_lda_;
   unsigned align_ldb_;
+  unsigned align_ldc_;
   driver::buffer *locks_;
 };
@@ -64,7 +64,7 @@ public:
         int T, int R, int S, int NF,
         int stride_h, int stride_w,
         const int32_t* shift_h, const int32_t* shift_w,
-        std::string a_ty = "fp32", std::string b_ty = "fp32",
+        std::string a_ty = "float", std::string b_ty = "float",
         op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);

   // look-up table
@@ -86,8 +86,8 @@ public:
   size_t num_flops() const;
   // source
   void triton_c_src(std::ostream &os) const;
-  // comparison
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;
   // cpu reference
@@ -30,19 +30,18 @@ using triton::lang::return_void;
 "for"             { return return_impl(FOR, yytext); }
 "while"           { return return_impl(WHILE, yytext); }
 "void"            { return return_impl(VOID, yytext); }
-"uint1"           { return return_impl(UINT1, yytext); }
-"uint8"           { return return_impl(UINT8, yytext); }
-"uint16"          { return return_impl(UINT16, yytext); }
-"uint32"          { return return_impl(UINT32, yytext); }
-"uint64"          { return return_impl(UINT64, yytext); }
-"int1"            { return return_impl(INT1, yytext); }
-"int8"            { return return_impl(INT8, yytext); }
-"int16"           { return return_impl(INT16, yytext); }
-"int32"           { return return_impl(INT32, yytext); }
-"int64"           { return return_impl(INT64, yytext); }
-"fp16"            { return return_impl(FP16, yytext); }
-"fp32"            { return return_impl(FP32, yytext); }
-"fp64"            { return return_impl(FP64, yytext); }
+"uchar"           { return return_impl(UCHAR, yytext); }
+"ushort"          { return return_impl(USHORT, yytext); }
+"uint"            { return return_impl(UINT, yytext); }
+"ulong"           { return return_impl(ULONG, yytext); }
+"bool"            { return return_impl(BOOL, yytext); }
+"char"            { return return_impl(CHAR, yytext); }
+"short"           { return return_impl(SHORT, yytext); }
+"int"             { return return_impl(INT, yytext); }
+"long"            { return return_impl(LONG, yytext); }
+"half"            { return return_impl(HALF, yytext); }
+"float"           { return return_impl(FLOAT, yytext); }
+"double"          { return return_impl(DOUBLE, yytext); }
 "..."             { return return_impl(ELLIPSIS, yytext); }
 "get_range_id"    { return return_impl(GET_RANGE_ID, yytext); }
 "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
@@ -78,7 +78,7 @@ public:
   void target_dependent(ir::module &module) {
     alignment_info.run(module);
     // ir::print(module, std::cout);
-    reassociate.run(module);
+    // reassociate.run(module);
     if(target_->is_gpu()){
       shmem_info.run(module);
       shmem_liveness.run(module);
@@ -86,7 +86,7 @@ public:
       shmem_barriers.run(module);
     }
     vectorize.run(module);
-    optimize_dce.run(module);
+    // optimize_dce.run(module);
     // ir::print(module, std::cout);
   }
@@ -37,7 +37,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
   stream->synchronize();
   while(total_time*1e-9 < 1e-3){
     float norm = 1;
-    // normalize clock if possible to get roughly constant result
+    // normalize clock if possible to reduce noise in auto-tuning
     // if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
     //   norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
     tmr.start();
@@ -74,6 +74,11 @@ unsigned distributed_tile::get_linear_index(indices_t idx) {
   return indices_[idx];
 }

+indices_t distributed_tile::get_ordered_indices(unsigned id) {
+  return ordered_indices_.at(id);
+}
+
 void distributed_tile::for_each(std::function<void (indices_t)> fn) {
   for(unsigned i = 0; i < ordered_indices_.size(); i++)
     if(i % vector_size_ == 0)
@@ -779,13 +784,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
   // store
   if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
     distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
-    tile *scalars = tmap_.at(x->get_value_operand());
+    distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
     ir::value *mask = x->get_mask_operand();
     distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
     ptrs->for_each([&](indices_t idx){
       Value *scalar = scalars->get_value(idx);
       Value *ptr = ptrs->get_value(idx);
       Value *pred = preds->get_value(idx);
+      BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
+      BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
+      builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
+      builder.SetInsertPoint(mask_then_bb);
+      builder.CreateStore(scalar, ptr);
+      builder.CreateBr(mask_done_bb);
+      builder.SetInsertPoint(mask_done_bb);

 //      std::string offset = "";
 //      if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
 //        if(gep->getNumIndices() == 1)
@@ -796,14 +809,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 //      std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
 //      InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
 //      builder.CreateCall(iasm, {pred, ptr, scalar});

-      BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
-      BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
-      builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
-      builder.SetInsertPoint(mask_then_bb);
-      builder.CreateStore(scalar, ptr);
-      builder.CreateBr(mask_done_bb);
-      builder.SetInsertPoint(mask_done_bb);
     });
   }
   else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
@@ -893,11 +898,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     ir::value* in = ins->get_operand(0);
     distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
     result->for_each([&](indices_t out_idx){
-      indices_t in_idx;
-      for(size_t k = 0; k < shapes.size(); k++){
-        if(shapes[k]->get_value() > 1)
-          in_idx.push_back(out_idx[k]);
-      }
+      unsigned pos = result->get_linear_index(out_idx);
+      indices_t in_idx = in_tile->get_ordered_indices(pos);
       result->set_value(out_idx, in_tile->get_value(in_idx));
     });
   }
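The rewritten reshape lowering above no longer rebuilds the input index from the output shapes; it asks the input tile for whatever index sits at the same linear position, via the new `get_ordered_indices`. A toy round trip of that position/index correspondence (plain containers standing in for the Triton tile classes):

```cpp
#include <cassert>
#include <map>
#include <vector>

using indices_t = std::vector<int>;

int main() {
  std::vector<indices_t> ordered;        // linear position -> n-d index
  std::map<indices_t, unsigned> linear;  // n-d index -> linear position
  for (int i = 0; i < 2; i++)            // build a tiny 2x3 tile in creation order
    for (int j = 0; j < 3; j++) {
      linear[{i, j}] = ordered.size();
      ordered.push_back({i, j});
    }
  // element k of the reshaped output reads element k of the input
  assert((ordered.at(linear.at({1, 2})) == indices_t{1, 2}));
}
```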
@@ -63,14 +63,19 @@ void tune::init_c_graph(ir::instruction *v) {
   else
     shapes = v->get_type()->get_tile_shapes();
   // Reshape
-  if(dynamic_cast<ir::reshape_inst*>(v)){
+  if(dynamic_cast<ir::reshape_inst*>(v)) {
     ir::value *op = v->get_operand(0);
     unsigned current = 0;
+    bool is_skewed = false;
     for(unsigned i = 0; i < shapes.size(); i ++){
-      if(shapes[i] == one)
+      bool is_one = shapes[i] == one;
+      bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
+      if(is_one)
         static_params_.insert({{v, i}, 1});
-      else
+      else if(!is_skewed && is_same)
         add_constraint({v, i}, {op, current++});
+      else
+        is_skewed = true;
     }
   }
   // Splat
@@ -81,9 +86,8 @@ void tune::init_c_graph(ir::instruction *v) {
   else if(dynamic_cast<ir::trans_inst*>(v)){
     ir::value *op = v->get_operand(0);
     size_t n_shapes = shapes.size();
-    for(unsigned i = 0; i < n_shapes; i++){
+    for(unsigned i = 0; i < n_shapes; i++)
       add_constraint({v, (i + 1) % n_shapes}, {op, i});
-    }
   }
   // Broadcast
   else if(dynamic_cast<ir::broadcast_inst*>(v)){
@@ -247,14 +251,14 @@ void tune::run(ir::module &mod) {
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
+        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
         *params_.at(i).at("nts.d0") = *tmp;
       }
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
@@ -1,4 +1,5 @@
 #include <sstream>
+#include <unordered_map>
 #include "triton/dnn/base.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
@@ -31,7 +32,7 @@ params_t base::heuristics() const {
 }

 std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
-  static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
+  static std::unordered_map<base*, std::unique_ptr<rt::jit>, recompile_hash, recompile_equal> m_jit;
   driver::context* ctx = stream->context();
   rt::jit* jit;
   /* the current template has not already been compiled */
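With the unordered map in place, the cache probe in `get_profile_impl` reduces to a `find` keyed on the template itself, compiling only on a miss. A rough outline of that flow, continuing the `Op`/`recompile_*` sketch after the base.h hunk above (add `<memory>` to its includes); `kernel` is a hypothetical stand-in for `rt::jit`, and the real code inserts a clone of the template so the key outlives the caller's object:

```cpp
struct kernel {};                        // stand-in for rt::jit

kernel* get_or_compile(Op* op) {
  static std::unordered_map<Op*, std::unique_ptr<kernel>,
                            recompile_hash, recompile_equal> cache;
  auto it = cache.find(op);
  if (it == cache.end())                 // unseen (type, retune_params) key
    it = cache.emplace(op, std::make_unique<kernel>()).first;
  return it->second.get();               // cached: no recompilation
}
```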
@@ -30,7 +30,7 @@ namespace dnn{
  * --------------- */

 batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_forward"),
     C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
   DHWB_ = D_*H_*W_*B_;
   rcpDHWB_ = (float)1 / DHWB_;
@@ -40,12 +40,9 @@ size_t batchnorm_forward::num_flops() const {
   return C_*DHWB_;
 }

-bool batchnorm_forward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_forward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }

 base* batchnorm_forward::clone() const {
@@ -74,50 +71,50 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};

-void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *B,
-               int32 DHWN,
-               fp32 rcpDHWN, fp32 eps) {
-  int32 rx[TM] = 0 ... TM;
-  fp32 *px[TM];
-  fp32 x[TM];
-  int32 c = get_range_id(1);
-  fp32 g = *(G + c);
-  fp32 b = *(B + c);
+void batchnorm_forward(float *Y, float *M, float *V,
+                       restrict read_only float *X,
+                       restrict read_only float *G,
+                       restrict read_only float *B,
+                       int DHWN,
+                       float rcpDHWN, float eps) {
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_range_id(1);
+  float g = *(G + c);
+  float b = *(B + c);

-  fp32 mean[TM] = 0;
+  float mean[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     mean = mean + x;
     px = px + TM;
   }
-  fp32 *pm = M + c;
-  fp32 m = __sum(mean) * rcpDHWN;
+  float *pm = M + c;
+  float m = __sum(mean) * rcpDHWN;
   *pm = m;

-  fp32 var[TM] = 0;
+  float var[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     x = x - m;
     var = var + x*x;
     px = px + TM;
   }
-  fp32 v = __sum(var) * rcpDHWN;
-  fp32 *pv = V + c;
+  float v = __sum(var) * rcpDHWN;
+  float *pv = V + c;
   *pv = v;
-  fp32 rstdg = 1 / sqrt(v + eps) * g;
+  float rstdg = 1 / sqrt(v + eps) * g;

   px = X + rx + c*DHWN;
-  fp32* py[TM] = Y + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  float* py[TM] = Y + rx + c*DHWN;
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - m)*rstdg + b;
+    float y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
@@ -130,7 +127,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  * --------------- */

 batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_backward"),
     C_(C), D_(D), H_(H), W_(W), B_(B),
     ty_(ty), eps_(eps)
 { }
@@ -139,12 +136,8 @@ size_t batchnorm_backward::num_flops() const {
   return C_*D_*H_*W_*B_;
 }

-bool batchnorm_backward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_backward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_backward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }

 base* batchnorm_backward::clone() const {
@@ -174,54 +167,54 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};

-void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
-               restrict read_only fp32 *DY,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *M,
-               restrict read_only fp32 *V,
-               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
-  int32 rx[TM] = 0 ... TM;
-  int32 c = get_range_id(1);
-  int32 offset = c*DHWN;
-  fp32 g = *(G + c);
-  fp32 mean = *(M + c);
-  fp32 var = *(V + c);
-  fp32 rstd = 1 / sqrt(var + epsilon);
-  fp32* px[TM];
-  fp32* pdx[TM];
-  fp32* pdy[TM];
+void batchnorm_backward(float *DX, float *DG, float *DB,
+                        restrict read_only float *DY,
+                        restrict read_only float *X,
+                        restrict read_only float *G,
+                        restrict read_only float *M,
+                        restrict read_only float *V,
+                        int DHWN, float rcpDHWN, float epsilon) {
+  int rx[TM] = 0 ... TM;
+  int c = get_range_id(1);
+  int offset = c*DHWN;
+  float g = *(G + c);
+  float mean = *(M + c);
+  float var = *(V + c);
+  float rstd = 1 / sqrt(var + epsilon);
+  float* px[TM];
+  float* pdx[TM];
+  float* pdy[TM];

   px = X + rx + offset;
   pdy = DY + rx + offset;
-  fp32 dg[TM] = 0;
-  fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
+  float dg[TM] = 0;
+  float db[TM] = 0;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
     db = db + dy;
     px = px + TM;
     pdy = pdy + TM;
   }
-  fp32 sdg = __sum(dg);
-  fp32 sdb = __sum(db);
-  fp32 *pdg = DG + c;
-  fp32 *pdb = DB + c;
+  float sdg = __sum(dg);
+  float sdb = __sum(db);
+  float *pdg = DG + c;
+  float *pdb = DB + c;
   *pdg = sdg;
   *pdb = sdb;

   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
-    fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    fp32 dx[TM] = (dy - xtmp) * rstd * g;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
+    float xhat[TM] = (x - mean) * rstd;
+    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
+    float dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
     pdy = pdy + TM;
@@ -10,12 +10,8 @@ size_t dot::num_flops() const {
   return 2.*nblocks_*BS_*BS_*N_;
 }

-bool dot::operator <(const base& other) const {
-  auto *y = dynamic_cast<const dot*>(&other);
-  if(!y)
-    return true;
-  return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_, op_)
-       < std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_, y->op_);
+std::vector<int64_t> dot::retune_params() const{
+  return {N_, S_, C_, BS_, nlocks_, op_};
 }

 std::vector<params_t> dot::search_space() const {
@@ -92,35 +88,35 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
   std::string result =
 R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
-const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
+const tunable int TM = {16, 32, 64, 128};
+const tunable int TN = {)" + std::to_string(BS_) + R"(};
+const tunable int TK = {)" + std::to_string(BS_) + R"(};

 void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
            restrict read_only align(16) )" + ab_ty_ + R"( *B,
            )" + c_ty_ + R"(* C,
-           int32 lda, int32 ldc, int32 N,
-           int32* lut, int32* locks, int32 nlocks){
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  fp32 acc[TM, TN] = 0;
-  int32 rxa[TM] = ridx * TM + (0 ... TM);
-  int32 ryb[TN] = 0 ... TN;
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  int1 checka[TM, TK] = (rxa < N)[:, newaxis];
-  int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
-  int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
-  int32 *header = lut + ridy * 4;
-  int32 offset = *(header + 0);
-  int32 K = *(header + 1);
-  int32 column = *(header + 2);
-  int32 lockid = *(header + 3);
-  int32 *plut = lut + offset * 2;
-  for(int32 k = K; k > 0; k = k - 1)
+           int lda, int ldc, int N,
+           int* lut, int* locks, int nlocks){
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  float acc[TM, TN] = 0;
+  int rxa[TM] = ridx * TM + (0 ... TM);
+  int ryb[TN] = 0 ... TN;
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  bool checka[TM, TK] = (rxa < N)[:, newaxis];
+  int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
+  int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+  int *header = lut + ridy * 4;
+  int offset = *(header + 0);
+  int K = *(header + 1);
+  int column = *(header + 2);
+  int lockid = *(header + 3);
+  int *plut = lut + offset * 2;
+  for(int k = K; k > 0; k = k - 1)
   {
-    int32 ak = *(plut + 0);
-    int32 bk = *(plut + 1);
+    int ak = *(plut + 0);
+    int bk = *(plut + 1);
     )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
     )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
     )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
@@ -128,19 +124,19 @@ void dot::triton_c_src(std::ostream &os) const {
     acc = dot()" + usea + ", " + useb + R"(, acc);
     plut = plut + 2;
   }
-  int32 rxc[TM] = ridx * TM + (0 ... TM);
-  int32 ryc[TN] = column * TN + (0 ... TN);
+  int rxc[TM] = ridx * TM + (0 ... TM);
+  int ryc[TN] = column * TN + (0 ... TN);
   )" + c_ty_ + R"(" c[TM, TN] = acc;
   )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
-  int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
+  bool checkc[TM, TN] = (rxc < N)[:, newaxis];
   if(lockid == 0)
     @checkc *pc = c;
   else
   {
-    int32 *plock = locks + ridx*nlocks + lockid - 1;
-    int32 *pcount = plock + get_num_program(0)*nlocks;
+    int *plock = locks + ridx*nlocks + lockid - 1;
+    int *pcount = plock + get_num_program(0)*nlocks;
     while(__atomic_cas(plock, 0, 1));
-    int32 count = *pcount;
+    int count = *pcount;
     if(count == 0){
       @checkc *pc = c;
     }
lib/dnn/conv.cpp
@@ -98,20 +98,12 @@ conv::conv(int B, int NC,
 }

 // comparison for maps
-bool conv::operator<(const base& other) const {
-  auto *y = dynamic_cast<const conv*>(&other);
-  if(!y)
-    return true;
-  return std::tie(NB_, NC_, AD_, AH_, AW_,
-                  NF_, BD_, BH_, BW_,
-                  pad_d_, pad_h_, pad_w_,
-                  stride_d_, stride_h_, stride_w_,
-                  a_ty_, b_ty_, ty_, bias_)
-       < std::tie(y->NB_, y->NC_, y->AD_, y->AH_, y->AW_,
-                  y->NF_, y->BD_, y->BH_, y->BW_,
-                  y->pad_d_, y->pad_h_, y->pad_w_,
-                  y->stride_d_, y->stride_h_, y->stride_w_,
-                  y->a_ty_, y->b_ty_, y->ty_, y->bias_);
+std::vector<int64_t> conv::retune_params() const {
+  return {NB_, NC_, AD_, AH_, AW_,
+          NF_, BD_, BH_, BW_,
+          pad_d_, pad_h_, pad_w_,
+          stride_d_, stride_h_, stride_w_,
+          ty_, bias_};
 }

 // clone
@@ -549,114 +541,114 @@ void conv::triton_c_src(std::ostream &os) const {

   os <<
 R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {)" << TK_ << R"(};
-const tunable int32 GZ = {1};
+const tunable int TM = {16, 32, 64};
+const tunable int TN = {16, 32, 64};
+const tunable int TK = {)" << TK_ << R"(};
+const tunable int GZ = {1};
 )";
   if(is_a_deltas_cst)
-    os << "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
+    os << "__constant__ int* delta = alloc_const int[" + std::to_string(h_a_deltas_.size()) + "];\n";
   if(b_lut_ && is_b_deltas_cst_)
-    os << "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
+    os << "__constant__ int* b_delta = alloc_const int[" + std::to_string(h_b_deltas_.size()) + "];\n";
  if(is_mask_cst_)
-    os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
+    os << "__constant__ int* masks = alloc_const int[" + std::to_string(h_masks_.size()) + "];\n";
   os << R"(

 void conv(read_only restrict )" << a_ty_ << R"( *a,
           read_only restrict )" << b_ty_ << R"( *b,
-          fp32 *c,
-          fp32 *bias,
-          int32 M, int32 N, int32 K,
-          int32 AH, int32 AW,
-          int32 BH, int32 BW,
-          int32 CH, int32 CW,
-          int32 NC,
-          int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
-          int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
-          int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-          int32 pad_h, int32 pad_w,
-          int32 stride_h, int32 stride_w,
-          int32 upsample_h, int32 upsample_w,
-          int32 off_uh, int32 off_uw,
-          int32 off_uah, int32 off_uaw,
-          int32 off_uch, int32 off_ucw,
-          int32 *locks, int32 grid0, int32 grid1)";
+          float *c,
+          float *bias,
+          int M, int N, int K,
+          int AH, int AW,
+          int BH, int BW,
+          int CH, int CW,
+          int NC,
+          int lda_n, int lda_c, int lda_d, int lda_h, int lda_w,
+          int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k,
+          int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q,
+          int pad_h, int pad_w,
+          int stride_h, int stride_w,
+          int upsample_h, int upsample_w,
+          int off_uh, int off_uw,
+          int off_uah, int off_uaw,
+          int off_uch, int off_ucw,
+          int *locks, int grid0, int grid1)";
   if(!is_a_deltas_cst)
-    os << ", int32* delta";
+    os << ", int* delta";
   if(b_lut_ && !is_b_deltas_cst_)
-    os << ", int32* b_delta";
+    os << ", int* b_delta";
   if(!is_mask_cst_)
-    os << ", int32* masks";
+    os << ", int* masks";
   os << R"(){
-  int32 rxa[TM] = get_global_range[TM](0);
-  int32 rb0[TN] = get_global_range[TN](1);
-  int32 rz = get_global_range[1](2);
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  fp32 C[TM, TN] = 0;
-  int32 ldlut = )" + std::to_string(Luts_) + R"(;
-  int32 div = K / GZ;
-  int32 rem = K % GZ;
+  int rxa[TM] = get_global_range[TM](0);
+  int rb0[TN] = get_global_range[TN](1);
+  int rz = get_global_range[1](2);
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float C[TM, TN] = 0;
+  int ldlut = )" + std::to_string(Luts_) + R"(;
+  int div = K / GZ;
+  int rem = K % GZ;
   K = select(rz < rem, div, div + rem);
-  int32 offk = rz*div;
+  int offk = rz*div;
   rka = rka + offk;
   rkb = rkb + offk;
-  int32 rabh[TM] = rxa / CW;
-  int32 raw[TM] = rxa % CW;
-  int32 rab[TM] = rabh / CH;
-  int32 rah[TM] = rabh % CH;
+  int rabh[TM] = rxa / CW;
+  int raw[TM] = rxa % CW;
+  int rab[TM] = rabh / CH;
+  int rah[TM] = rabh % CH;
   rah = rah)" + upaw + R"( - off_uah;
   raw = raw)" + upah + R"( - off_uaw;
-  int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
-  int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
-  int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
-  int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-  int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+  int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
+  int ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
+  int ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
+  int ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+  int ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
   rar = )" + flipr + R"( rar;
   ras = )" + flips + R"( ras;
   rar = )" + upar + R"( rar;
   ras = )" + upas + R"( ras;
-  int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
+  int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
   )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
   if(b_lut_){
     os << R"(
-  int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
-  int32 rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
-  int32 rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-  int32 rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+  int rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
+  int rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
+  int rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+  int rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
   rbr = rbr*upsample_h + off_uh;
   rbs = rbs*upsample_w + off_uw;
-  int32 offdb[TK] = rkb % ldlut;
-  int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
-  )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
-  int32 db[TK] = *pdb;)";
+  int offdb[TK] = rkb % ldlut;
+  int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
+  )" + b_delta_mem + R"( int* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
+  int db[TK] = *pdb;)";
   }
   else{
     os << R"(
-  int32 rb1[TK] = rkb)" + ldb0 + ";";
+  int rb1[TK] = rkb)" + ldb0 + ";";
   }
   os << R"(
   )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
-  int32 offda[TK] = rka % ldlut;
-  )" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
-  )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
-  int32 da[TK] = *pda;
-  int32 incd[TK] = *pincd;
-  int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
-  int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
-  int32 offma = offk % ldlut;
-  )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
-  )" + a_delta_mem + R"( int32* pincm[TM] = delta + offma;
-  int32 incm[TM] = *pincm;
-  int32 maska0[TM] = *pm;
-  int32 maska1[TK] = 1 << (0 ... TK);
-  int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
-  int1 checkb0[TN] = rb0 < N;
-  int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
+  int offda[TK] = rka % ldlut;
+  )" + a_delta_mem + R"( int* pincd[TK] = delta + offda;
+  )" + a_delta_mem + R"( int* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
+  int da[TK] = *pda;
+  int incd[TK] = *pincd;
+  int maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
+  int maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
+  int offma = offk % ldlut;
+  )" + masks_mem + R"( int* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
+  )" + a_delta_mem + R"( int* pincm[TM] = delta + offma;
+  int incm[TM] = *pincm;
+  int maska0[TM] = *pm;
+  int maska1[TK] = 1 << (0 ... TK);
+  bool checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
+  bool checkb0[TN] = rb0 < N;
+  bool checkb)" + BS + " = checkb0" + bcb0 + R"(;
   )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
   )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
-  int32 rkamin[TK] = rka - offk + TK;
-  for(int32 k = K; k > 0; k = k - TK){
+  int rkamin[TK] = rka - offk + TK;
+  for(int k = K; k > 0; k = k - TK){
     C = dot(a, )" + useb + R"(, C);
     pa = pa + da[newaxis, :];
     pb = pb + )" + inc_pb + R"(;
@@ -673,7 +665,7 @@ if(b_lut_){
     pm = pm + incm;
     pincm = pincm + incm;
     incm = *pincm;
-    int1 checka1[TK] = (rkamin < k);
+    bool checka1[TK] = (rkamin < k);
     maska0 = *pm;
     checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
     checka = checka && checka1[newaxis,:];
@@ -681,31 +673,31 @@ if(b_lut_){
     checkb = checkb && (k > TK);
     @checkb b = *pb;
   }
-  int32 rxc[TM] = get_global_range[TM](0);
-  int32 rc1[TN] = get_global_range[TN](1);
-  int32 rcn[TM] = rxc / (CH*CW);
-  int32 rcpq[TM] = rxc % (CH*CW);
-  int32 rcp[TM] = rcpq / CW;
-  int32 rcq[TM] = rcpq % CW;
+  int rxc[TM] = get_global_range[TM](0);
+  int rc1[TN] = get_global_range[TN](1);
+  int rcn[TM] = rxc / (CH*CW);
+  int rcpq[TM] = rxc % (CH*CW);
+  int rcp[TM] = rcpq / CW;
+  int rcq[TM] = rcpq % CW;
   rcp = rcp * upsample_h + off_uch;
   rcq = rcq * upsample_w + off_ucw;
-  int1 checkc1[TN] = rc1 < N;
-  int32 rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
-  fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  int32 *plock = locks + ridx + ridy*grid0;
+  bool checkc1[TN] = rc1 < N;
+  int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
+  float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+  bool checkc0[TM] = rxc < M;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  int *plock = locks + ridx + ridy*grid0;
   while(__atomic_cas(plock, 0, 1) == 1);
-  int32 *pcount = plock + grid0*grid1;
-  int32 count = *pcount;
-  int32 countp1 = select(count == GZ - 1, 0, count + 1);
+  int *pcount = plock + grid0*grid1;
+  int count = *pcount;
+  int countp1 = select(count == GZ - 1, 0, count + 1);
   if(count == 0) {)";
   if(bias_ && ty_==FPROP){
     os << R"(
-    fp32* pbias[TN] = bias + rc1;
-    fp32 bias[TN] = checkc1 ? *pbias : 0;
+    float* pbias[TN] = bias + rc1;
+    float bias[TN] = checkc1 ? *pbias : 0;
     C = C + bias[newaxis, :];)";
   }
   os << R"(
@@ -10,11 +10,11 @@ namespace dnn{
 dot::dot(int M, int N, int K,
          bool AT, bool BT,
          std::string a_ty, std::string b_ty,
-         unsigned alignment_lda, unsigned alignment_ldb)
+         unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
   : base("matmul"),
     M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
     a_ty_(a_ty), b_ty_(b_ty),
-    align_lda_(alignment_lda), align_ldb_(alignment_ldb),
+    align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
     locks_(nullptr) {

 }
@@ -23,15 +23,10 @@ size_t dot::num_flops() const {
   return 2.*M_*N_*K_;
 }

-// comparison for maps
-bool dot::operator<(const base& other) const {
-  auto *y = dynamic_cast<const dot*>(&other);
-  if(!y)
-    return true;
-  return std::tie(M_, N_, K_, AT_, BT_,
-                  a_ty_, b_ty_, align_lda_, align_ldb_)
-       < std::tie(y->M_, y->N_, y->K_, y->AT_, y->BT_,
-                  y->a_ty_, y->b_ty_, y->align_lda_, y->align_ldb_);
+// retune parameters
+std::vector<int64_t> dot::retune_params() const {
+  return {M_, N_, K_, AT_, BT_,
+          (int)align_lda_, (int)align_ldb_};
 }

 // clone
@@ -101,45 +96,45 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
 R"(
-const tunable int32 TM = {16, 32, 64, 128, 256};
-const tunable int32 TN = {16, 32, 64, 128, 256};
-const tunable int32 TK = {32};
-const tunable int32 GZ = {1};
+const tunable int TM = {16, 32, 64, 128};
+const tunable int TN = {16, 32, 64, 128};
+const tunable int TK = {32};
+const tunable int GZ = {1};

 void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
             restrict read_only align(16) )" + b_ty_ + R"( *B,
-            fp32 *C,
-            int32 M, int32 N, int32 K,
-            )" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
-            int32 bound, int32 *locks, int32 grid0, int32 grid1) {
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  int32 rxa[TM] = ridx * TM + (0 ... TM);
-  int32 ryb[TN] = ridy * TN + (0 ... TN);
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  fp32 c[TM, TN] = 0;
+            restrict read_only align(16) float *C,
+            int M, int N, int K,
+            )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
+            int bound, int *locks, int grid0, int grid1) {
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  int rxa[TM] = ridx * TM + (0 ... TM);
+  int ryb[TN] = ridy * TN + (0 ... TN);
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float c[TM, TN] = 0;
   )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
   )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
-  int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
-  int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
+  bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
+  bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
   )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
-  for(int32 k = K; k > 0; k = k - TK){
+  for(int k = K; k > 0; k = k - TK){
     c = dot()" + usea + ", " + useb + R"(, c);
     pa = pa + TK)" + lda0 + R"(;
     pb = pb + TK)" + ldb0 + R"(;
-    int1 checka[)" + AS + R"(] = k > TK;
-    int1 checkb[)" + BS + R"(] = k > TK;
+    bool checka[)" + AS + R"(] = k > TK;
+    bool checkb[)" + BS + R"(] = k > TK;
     a = checka ? *pa : 0;
     b = checkb ? *pb : 0;
   }
-  int32 rxc[TM] = ridx * TM + (0 ... TM);
-  int32 ryc[TN] = ridy * TN + (0 ... TN);
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc1[TN] = ryc < N;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
+  int rxc[TM] = ridx * TM + (0 ... TM);
+  int ryc[TN] = ridy * TN + (0 ... TN);
+  bool checkc0[TM] = rxc < M;
+  bool checkc1[TN] = ryc < N;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+  float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
   @checkc *pc = c;
 }
 )";
@@ -28,7 +28,7 @@ shift::shift(int B, int C,
     layout_(layout){
 //  std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
   // max number of channels
-  TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32;
+  TK_ = (ty == FPROP && a_ty_ == "float") ? 8 : 32;
   MAX_C_ = 8192 + TK_;
   // activation sizes
   CD_ = AD_ / stride_d_;
@@ -204,26 +204,15 @@ size_t shift::ldb() const
 size_t shift::ldc() const
 { return M_; }

-bool shift::operator <(const base& other) const{
-  auto *y = dynamic_cast<const shift*>(&other);
-  if(!y)
-    return true;
-  return std::tie(B_, C_, F_,
-                  AD_, AH_, AW_,
-                  BD_, BH_, BW_,
-                  CD_, CH_, CW_,
-                  shift_h_, shift_w_,
-                  stride_h_, stride_w_,
-                  layout_, op_,
-                  bias_)
-       < std::tie(y->B_, y->C_, y->F_,
-                  y->AD_, y->AH_, y->AW_,
-                  y->BD_, y->BH_, y->BW_,
-                  y->CD_, y->CH_, y->CW_,
-                  y->shift_h_, y->shift_w_,
-                  y->stride_h_, y->stride_w_,
-                  y->layout_, y->op_,
-                  y->bias_);
+std::vector<int64_t> shift::retune_params() const {
+  return {B_, C_, F_,
+          AD_, AH_, AW_,
+          BD_, BH_, BW_,
+          CD_, CH_, CW_,
+          (int64_t)shift_h_, (int64_t)shift_w_,
+          stride_h_, stride_w_,
+          layout_, op_,
+          bias_};
 }

 void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
@@ -325,56 +314,56 @@ void shift::triton_c_src(std::ostream &os) const {

   if(is_chwn) {
     return R"(
-  int32 )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
-  int32 )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
-  int32 )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
-  int32 )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
+  int )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
+  int )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
+  int )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
+  int )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
   }
   else {
     return R"(
-  int32 )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
-  int32 )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
-  int32 )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
-  int32 )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
+  int )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
+  int )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
+  int )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
+  int )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
   }
  };

 std::string result =
 R"(
-const tunable int32 TM = {16, 32, 64, 128};
-const tunable int32 TN = {16, 32, 64, 128};
-const tunable int32 TK = {)" + std::to_string(TK_) + "};";
+const tunable int TM = {16, 32, 64, 128};
+const tunable int TN = {16, 32, 64, 128};
+const tunable int TK = {)" + std::to_string(TK_) + "};";
 if(op_ == WGRAD)
-  result += "const tunable int32 GZ = {1};";
+  result += "const tunable int GZ = {1};";
 else
-  result += "const tunable int32 GZ = {1};";
+  result += "const tunable int GZ = {1};";

 result += R"(
-__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(];
+__constant__ int* delta_a = alloc_const int[)" + std::to_string(MAX_C_) + R"(];

 void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
            restrict read_only align(16) )" + b_ty_ + R"( *B,
            )" + c_ty_ + R"( *C,
-           int32 M, int32 N, int32 K,
-           int32 stride_h, int32 stride_w,
-           multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c,
-           multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c,
-           multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c,
-           int32 NB,
-           int32 AH, int32 AW,
-           int32 BH, int32 BW,
-           int32 CH, int32 CW,
-           int32* locks, int32 grid0, int32 grid1, int32 grid2) {
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  int32 rz = get_range_id(2);
-  int32 rxa[TM] = ridx*TM + (0 ... TM);
-  int32 ryb[TN] = ridy*TN + (0 ... TN);
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  fp32 acc[TM, TN] = 0;
-  int32 pad_h = BH / 2;
-  int32 pad_w = BW / 2;)";
+           int M, int N, int K,
+           int stride_h, int stride_w,
+           multiple_of(8) int lda_b, multiple_of(8) int lda_w, multiple_of(8) int lda_h, multiple_of(8) int lda_c,
+           multiple_of(8) int ldb_b, multiple_of(8) int ldb_w, multiple_of(8) int ldb_h, multiple_of(8) int ldb_c,
+           multiple_of(8) int ldc_b, multiple_of(8) int ldc_w, multiple_of(8) int ldc_h, multiple_of(8) int ldc_c,
+           int NB,
+           int AH, int AW,
+           int BH, int BW,
+           int CH, int CW,
+           int* locks, int grid0, int grid1, int grid2) {
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  int rz = get_range_id(2);
+  int rxa[TM] = ridx*TM + (0 ... TM);
+  int ryb[TN] = ridy*TN + (0 ... TN);
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float acc[TM, TN] = 0;
+  int pad_h = BH / 2;
+  int pad_w = BW / 2;)";

 /* A offsets */
 if(op_ == FPROP){
@@ -382,49 +371,49 @@ if(op_ == FPROP){
   compute_bhw("ra", "TM", "rxa") + R"(
   raw = raw * )" + stride_w + R"(;
   rah = rah * )" + stride_h + R"(;
-  int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
-  int32 offa0[TM, TK] = offxa[:, newaxis];
-  __constant__ int32* pd[TK] = delta_a + rka;
-  multiple_of(8) int32 d[TK] = *pd;
-  int32 offa1[TM, TK] = d[newaxis, :];)";
+  int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
+  int offa0[TM, TK] = offxa[:, newaxis];
+  __constant__ int* pd[TK] = delta_a + rka;
+  multiple_of(8) int d[TK] = *pd;
+  int offa1[TM, TK] = d[newaxis, :];)";
 }
 if(op_ == BPROP){
   result +=
   compute_bhw("ra", "TM", "rxa") + R"(
-  int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
-  int32 offa0[TM, TK] = offxa[:, newaxis];
-  int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
+  int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
+  int offa0[TM, TK] = offxa[:, newaxis];
+  int offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
 }
 if(op_ == WGRAD){
   result +=
   compute_bhw("ra", "TK", "rka") + R"(
-  int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c;
-  int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
-  int32 offa1[TK, TM] = offxa[:, newaxis];)";
+  int offa0[TK, TM] = rxa[newaxis, :] * lda_c;
+  int offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
+  int offa1[TK, TM] = offxa[:, newaxis];)";
 }

 /* B offsets */
 if(op_ == FPROP){
   result += R"(
-  int32 offb0[TN, TK] = ryb[:, newaxis];
-  int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
+  int offb0[TN, TK] = ryb[:, newaxis];
+  int offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
 }
 if(op_ == BPROP){
   result += R"(
-  int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
-  int32 offb1[TK, TN] = rkb[:, newaxis];)";
+  int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
+  int offb1[TK, TN] = rkb[:, newaxis];)";
 }
 if(op_ == WGRAD){
   result +=
   compute_bhw("rb", "TK", "rkb") + R"(
-  __constant__ int32* pd[TN] = delta_a + ryb;
-  multiple_of(8) int32 d[TN] = *pd;
-  multiple_of(8) int32 shift[TK, TN] = d[newaxis, :];
+  __constant__ int* pd[TN] = delta_a + ryb;
+  multiple_of(8) int d[TN] = *pd;
+  multiple_of(8) int shift[TK, TN] = d[newaxis, :];
   rbw = rbw * )" + stride_w + R"(;
   rbh = rbh * )" + stride_h + R"(;
-  int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
-  int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
-  int32 offb1[TK, TN] = offkb[:, newaxis];
+  int offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
+  int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
+  int offb1[TK, TN] = offkb[:, newaxis];
   )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
   )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
   )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
@@ -439,14 +428,14 @@ else{
 /* Main loop */
 /* Increment A pointers */
 result += R"(
-  int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
-  int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
+  bool checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
+  bool checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
   )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
-  for(int32 k = K; k > 0; k = k - TK){
+  for(int k = K; k > 0; k = k - TK){
     acc = dot()" + usea + "," + useb + R"(, acc);
-    int1 checka[)" + AS + R"(] = k > TK;
-    int1 checkb[)" + BS + R"(] = k > TK;)";
+    bool checka[)" + AS + R"(] = k > TK;
+    bool checkb[)" + BS + R"(] = k > TK;)";

 /* Increment A pointers */
 if(op_ == FPROP){
@@ -490,8 +479,8 @@ if(op_ == BPROP){
 result += R"(
     b = checkb ? *pb : 0;
   }
-  int32 rxc[TM] = ridx*TM + (0 ... TM);
-  int32 ryc[TN] = ridy*TN + (0 ... TN);)";
+  int rxc[TM] = ridx*TM + (0 ... TM);
+  int ryc[TN] = ridy*TN + (0 ... TN);)";

 /* C offsets */
 if(op_ == BPROP){
@@ -499,26 +488,26 @@ if(op_ == BPROP){
   compute_bhw("rc", "TM", "rxc") + R"(
   rcw = rcw * )" + stride_w + R"(;
   rch = rch * )" + stride_h + R"(;
-  int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
+  int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
 }
 if(op_ == FPROP){
   result +=
   compute_bhw("rc", "TM", "rxc") + R"(
-  int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
+  int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
 }
 if(op_ == WGRAD){
   result += R"(
-  int32 offxc[TM] = rxc;)";
+  int offxc[TM] = rxc;)";
 }
 result += R"("
   )" + c_ty_ + R"( c[TM, TN] = acc;
   )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc1[TN] = ryc < N;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
+  bool checkc0[TM] = rxc < M;
+  bool checkc1[TN] = ryc < N;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
 if(op_ == BPROP){
   result += R"(
-  __constant__ int32* pd[TN] = delta_a + ryc;
+  __constant__ int* pd[TN] = delta_a + ryc;
   )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
   @checkc *shift_pc = c;
 )";
@@ -174,8 +174,15 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
       std::lock_guard<std::mutex> lock(mutex);
       for(ir::metaparameter *mp: mps)
         mp->set_value(params[i++]);
+//      for(size_t i = 0; i < params.size(); i++)
+//        std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+//      std::cout << std::endl;
       passes_0.tune.init(tt_module_0);
       passes_0.tune.check_constraints(errors);
+//      for(auto x: errors)
+//      for(auto e: x.second){
+//        std::cout << x.first->get_name() << ": " << e << std::endl;
+//      }
     }
     if(!errors.empty())
       return;
@@ -212,9 +219,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
         best.perf = perf;
         best.params = params;
       }
-      for(size_t i = 0; i < params.size(); i++)
-        std::cout << ((i==0)?"":", ") << params[i] << std::flush;
-      std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
+//      for(size_t i = 0; i < params.size(); i++)
+//        std::cout << ((i==0)?"":", ") << params[i] << std::flush;
+//      std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
     }
   };