[dnn] better specification of recompilation key

This commit is contained in:
Philippe Tillet
2019-08-02 17:42:48 -07:00
parent 3b92ddf7e6
commit d9945692a9
31 changed files with 418 additions and 428 deletions

View File

@@ -18,12 +18,12 @@ int main() {
int32_t pad_d = 0, pad_h = 0, pad_w = 0; int32_t pad_d = 0, pad_h = 0, pad_w = 0;
int32_t stride_d = 1, stride_h = 1, stride_w = 1; int32_t stride_d = 1, stride_h = 1, stride_w = 1;
int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
// triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0); // triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "float", "float", triton::dnn::conv::FPROP, 0);
triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF, triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
stride_d, stride_h, stride_w, stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w, pad_d, pad_h, pad_w,
upsample_d, upsample_h, upsample_w, upsample_d, upsample_h, upsample_w,
"fp32", "fp32", ty, 0); "float", "float", ty, 0);
// convolution configuration // convolution configuration
std::vector<float> hc(configuration.c_size()); std::vector<float> hc(configuration.c_size());
std::vector<float> rc(configuration.c_size()); std::vector<float> rc(configuration.c_size());

View File

@@ -26,7 +26,7 @@ struct perf_t {
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT; typedef float NumericT;
std::string ty = "fp16"; std::string ty = "half";
size_t dt_nbytes = sizeof(NumericT); size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context(); triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N); std::vector<NumericT> hc(M*N);
@@ -46,7 +46,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->write(db, true, 0, hb); stream->write(db, true, 0, hb);
stream->write(dc, true, 0, hc); stream->write(dc, true, 0, hc);
stream->synchronize(); stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
// benchmark triton // benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream); double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark cublas // benchmark cublas

View File

@@ -134,7 +134,7 @@ int main() {
}; };
for(config_t c: resnet18){ for(config_t c: resnet18){
for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){ for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"}); configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "half"});
} }
} }

View File

@@ -37,7 +37,7 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false); triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false); triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
// create template // create template
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float");
batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b}); batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
stream.synchronize(); stream.synchronize();
return {fw_y, fw_m, fw_v}; return {fw_y, fw_m, fw_v};
@@ -79,7 +79,7 @@ std::vector<torch::Tensor>
triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false); triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false); triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
// create config // create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps); triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", eps);
batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v}); batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
stream.synchronize(); stream.synchronize();
return {fw_dx, fw_dg, fw_db}; return {fw_dx, fw_dg, fw_db};

View File

@@ -30,7 +30,7 @@ torch::Tensor conv_common(
stride_d, stride_h, stride_w, stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w, pad_d, pad_h, pad_w,
1, 1, 1, 1, 1, 1,
"fp32", "fp32", ty, has_bias); "float", "float", ty, has_bias);
// Bind memory // Bind memory
triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false); triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false); triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);

View File

@@ -49,9 +49,9 @@ torch::Tensor shift_common(
std::string dtype; std::string dtype;
at::ScalarType type = torcha.scalar_type(); at::ScalarType type = torcha.scalar_type();
switch(type){ switch(type){
case at::ScalarType::Double: dtype = "fp64"; break; case at::ScalarType::Double: dtype = "double"; break;
case at::ScalarType::Float: dtype = "fp32"; break; case at::ScalarType::Float: dtype = "float"; break;
case at::ScalarType::Half: dtype = "fp16"; break; case at::ScalarType::Half: dtype = "half"; break;
default: AT_ERROR("unknown data-type for shift-conv"); default: AT_ERROR("unknown data-type for shift-conv");
} }
// Get configuration // Get configuration

View File

@@ -58,7 +58,7 @@ public:
triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false); triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false); triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
// create config // create config
triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b}); batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
} }
@@ -126,7 +126,7 @@ public:
triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false); triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false); triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
// create config // create config
triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32"); triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v}); batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
} }

View File

@@ -128,9 +128,9 @@ public:
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false); triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false); triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
// create profile // create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP); triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
// blocksparse matmul // blocksparse matmul
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING); triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks(); triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
Tensor *tmp = nullptr; Tensor *tmp = nullptr;
TensorShape tmp_shapes; TensorShape tmp_shapes;

View File

@@ -61,7 +61,7 @@ public:
stride_d, stride_h, stride_w, stride_d, stride_h, stride_w,
pad_d, pad_h, pad_w, pad_d, pad_h, pad_w,
1, 1, 1, 1, 1, 1,
"fp16", "fp16", "half", "half",
triton::dnn::conv::FPROP, has_bias); triton::dnn::conv::FPROP, has_bias);
// allocate output // allocate output
auto c_shapes = conv.c_shapes(); auto c_shapes = conv.c_shapes();

View File

@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false); triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false); triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
// template // template
triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8); triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
dot.enqueue(stream, {&da, &db, &dc}); dot.enqueue(stream, {&da, &db, &dc});
} }

View File

@@ -105,7 +105,7 @@ def batch_norm_grad(op, dy, mean, var):
def run_batchnorm(): def run_batchnorm():
C, H, W, B = 32, 14, 14, 64 C, H, W, B = 8, 4, 4, 32
np.random.seed(0) np.random.seed(0)
# Placeholders # Placeholders
x = tf.placeholder(tf.float32, shape=[C, H, W, B]) x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -131,6 +131,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n))) print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n))) print(np.max(np.abs(db_t - db_n)))
run_dot() #run_dot()
#run_shift() #run_shift()
#run_batchnorm() run_batchnorm()

View File

@@ -106,7 +106,7 @@ public:
triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F, triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
stride_h_, stride_w_, stride_h_, stride_w_,
shift_h_data, shift_w_data, shift_h_data, shift_w_data,
"fp16", "fp16", OP, has_bias, layout_); "half", "half", OP, has_bias, layout_);
// shapes for c // shapes for c
std::vector<int64> c_shapes; std::vector<int64> c_shapes;

View File

@@ -91,6 +91,7 @@ public:
void set_value(indices_t idx, llvm::Value *v); void set_value(indices_t idx, llvm::Value *v);
llvm::Value* get_value(indices_t idx); llvm::Value* get_value(indices_t idx);
unsigned get_linear_index(indices_t idx); unsigned get_linear_index(indices_t idx);
indices_t get_ordered_indices(unsigned id);
void for_each(std::function<void(indices_t)> fn); void for_each(std::function<void(indices_t)> fn);
const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }

View File

@@ -52,12 +52,15 @@ struct launch_context_t{
typedef std::vector<unsigned> params_t; typedef std::vector<unsigned> params_t;
class base { class base {
friend class cmp_recompile; friend class recompile_hash;
friend class recompile_equal;
protected: protected:
// leading dimensions // leading dimensions
static void set_ld(const std::vector<int32_t>& shapes, static void set_ld(const std::vector<int32_t>& shapes,
std::vector<int32_t>& ld); std::vector<int32_t>& ld);
// list of retuning parameters
virtual std::vector<int64_t> retune_params() const = 0;
private: private:
// initialize // initialize
@@ -70,8 +73,6 @@ private:
triton::runtime::launch_information info) = 0; triton::runtime::launch_information info) = 0;
// number of flops // number of flops
virtual size_t num_flops() const = 0; virtual size_t num_flops() const = 0;
// comparison for maps
virtual bool operator<(const base& other) const = 0;
// default parameters // default parameters
virtual std::vector<params_t> search_space() const; virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const; virtual params_t heuristics() const;
@@ -94,12 +95,21 @@ private:
std::string name_; std::string name_;
}; };
struct cmp_recompile{
struct recompile_equal{
bool operator()(base* x, base* y) const{ bool operator()(base* x, base* y) const{
return *x < *y; return typeid(*x) == typeid(*y) &&
x->retune_params() == y->retune_params();
} }
}; };
// Hash functor for the recompilation cache (keyed on base*).
// Must agree with recompile_equal: two templates that compare equal have
// identical retune_params(), so hashing every retuning parameter yields
// identical hashes for them.
struct recompile_hash{
  // Returns a hash built from all retuning parameters of x.
  // An empty parameter list hashes to 0 instead of indexing out of bounds.
  size_t operator()(base* x) const{
    size_t seed = 0;
    // order-dependent mixing (hash_combine recipe) so that templates which
    // only differ in a later parameter do not all land in the same bucket
    for(int64_t param: x->retune_params())
      seed ^= static_cast<size_t>(param) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
  }
};
} }
} }

View File

@@ -47,15 +47,15 @@ private:
triton::runtime::launch_information info); triton::runtime::launch_information info);
// number of flops // number of flops
size_t num_flops() const; size_t num_flops() const;
// comparison for maps // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// clone // clone
base* clone() const; base* clone() const;
public: public:
// constructor // constructor
batchnorm_forward(int C, int D, int H, int W, int B, batchnorm_forward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5); std::string ty = "float", float eps = 1e-5);
// triton-c source // triton-c source
void triton_c_src(std::ostream &os) const; void triton_c_src(std::ostream &os) const;
@@ -82,15 +82,15 @@ private:
runtime::launch_information info); runtime::launch_information info);
// number of flops // number of flops
size_t num_flops() const; size_t num_flops() const;
// comparison for maps // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// clone // clone
base* clone() const; base* clone() const;
public: public:
// constructor // constructor
batchnorm_backward(int C, int D, int H, int W, int B, batchnorm_backward(int C, int D, int H, int W, int B,
std::string ty = "fp32", float eps = 1e-5); std::string ty = "float", float eps = 1e-5);
// triton-c source // triton-c source
void triton_c_src(std::ostream &os) const; void triton_c_src(std::ostream &os) const;

View File

@@ -20,8 +20,8 @@ private:
triton::runtime::launch_information info); triton::runtime::launch_information info);
// number of flops // number of flops
size_t num_flops() const; size_t num_flops() const;
// comparison for maps // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// default parameters // default parameters
std::vector<params_t> search_space() const; std::vector<params_t> search_space() const;
params_t heuristics() const; params_t heuristics() const;

View File

@@ -37,8 +37,8 @@ private:
triton::runtime::launch_information info); triton::runtime::launch_information info);
// number of flops // number of flops
size_t num_flops() const; size_t num_flops() const;
// comparison for maps // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// clone // clone
base* clone() const; base* clone() const;
@@ -50,7 +50,7 @@ public:
int stride_d, int stride_h, int stride_w, int stride_d, int stride_h, int stride_w,
int pad_d, int pad_h, int pad_w, int pad_d, int pad_h, int pad_w,
int upsample_d, int upsample_h, int upsample_w, int upsample_d, int upsample_h, int upsample_w,
std::string a_ty = "fp32", std::string b_ty = "fp32", std::string a_ty = "float", std::string b_ty = "float",
type ty = FPROP, bool bias = false); type ty = FPROP, bool bias = false);
// accessors // accessors

View File

@@ -16,8 +16,8 @@ private:
void enqueue_impl(driver::stream *stream, driver::kernel *kernel, void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
std::vector<driver::buffer*> args, std::vector<driver::buffer*> args,
triton::runtime::launch_information info); triton::runtime::launch_information info);
// comparison for maps // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// default parameters // default parameters
virtual std::vector<params_t> search_space() const; virtual std::vector<params_t> search_space() const;
virtual params_t heuristics() const; virtual params_t heuristics() const;
@@ -25,7 +25,7 @@ private:
public: public:
dot(int M, int N, int K, bool AT, bool BT, dot(int M, int N, int K, bool AT, bool BT,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb); unsigned align_lda, unsigned align_ldb, unsigned align_ldc);
// number of flops // number of flops
size_t num_flops() const; size_t num_flops() const;
@@ -70,6 +70,7 @@ private:
std::string b_ty_; std::string b_ty_;
unsigned align_lda_; unsigned align_lda_;
unsigned align_ldb_; unsigned align_ldb_;
unsigned align_ldc_;
driver::buffer *locks_; driver::buffer *locks_;
}; };

View File

@@ -64,7 +64,7 @@ public:
int T, int R, int S, int NF, int T, int R, int S, int NF,
int stride_h, int stride_w, int stride_h, int stride_w,
const int32_t* shift_h, const int32_t* shift_w, const int32_t* shift_h, const int32_t* shift_w,
std::string a_ty = "fp32", std::string b_ty = "fp32", std::string a_ty = "float", std::string b_ty = "float",
op_t ty = FPROP, bool bias = false, layout_t layout = CHWN); op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
// look-up table // look-up table
@@ -86,8 +86,8 @@ public:
size_t num_flops() const; size_t num_flops() const;
// source // source
void triton_c_src(std::ostream &os) const; void triton_c_src(std::ostream &os) const;
// comparison // retuning parameters
bool operator<(const base& other) const; std::vector<int64_t> retune_params() const;
// clone // clone
base* clone() const; base* clone() const;
// cpu reference // cpu reference

View File

@@ -30,19 +30,18 @@ using triton::lang::return_void;
"for" { return return_impl(FOR, yytext); } "for" { return return_impl(FOR, yytext); }
"while" { return return_impl(WHILE, yytext); } "while" { return return_impl(WHILE, yytext); }
"void" { return return_impl(VOID, yytext); } "void" { return return_impl(VOID, yytext); }
"uint1" { return return_impl(UINT1, yytext); } "uchar" { return return_impl(UCHAR, yytext); }
"uint8" { return return_impl(UINT8, yytext); } "ushort" { return return_impl(USHORT, yytext); }
"uint16" { return return_impl(UINT16, yytext); } "uint" { return return_impl(UINT, yytext); }
"uint32" { return return_impl(UINT32, yytext); } "ulong" { return return_impl(ULONG, yytext); }
"uint64" { return return_impl(UINT64, yytext); } "bool" { return return_impl(BOOL, yytext); }
"int1" { return return_impl(INT1, yytext); } "char" { return return_impl(CHAR, yytext); }
"int8" { return return_impl(INT8, yytext); } "short" { return return_impl(SHORT, yytext); }
"int16" { return return_impl(INT16, yytext); } "int" { return return_impl(INT, yytext); }
"int32" { return return_impl(INT32, yytext); } "long" { return return_impl(LONG, yytext); }
"int64" { return return_impl(INT64, yytext); } "half" { return return_impl(HALF, yytext); }
"fp16" { return return_impl(FP16, yytext); } "float" { return return_impl(FLOAT, yytext); }
"fp32" { return return_impl(FP32, yytext); } "double" { return return_impl(DOUBLE, yytext); }
"fp64" { return return_impl(FP64, yytext); }
"..." { return return_impl(ELLIPSIS, yytext); } "..." { return return_impl(ELLIPSIS, yytext); }
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); } "get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); } "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }

View File

@@ -78,7 +78,7 @@ public:
void target_dependent(ir::module &module) { void target_dependent(ir::module &module) {
alignment_info.run(module); alignment_info.run(module);
// ir::print(module, std::cout); // ir::print(module, std::cout);
reassociate.run(module); // reassociate.run(module);
if(target_->is_gpu()){ if(target_->is_gpu()){
shmem_info.run(module); shmem_info.run(module);
shmem_liveness.run(module); shmem_liveness.run(module);
@@ -86,7 +86,7 @@ public:
shmem_barriers.run(module); shmem_barriers.run(module);
} }
vectorize.run(module); vectorize.run(module);
optimize_dce.run(module); // optimize_dce.run(module);
// ir::print(module, std::cout); // ir::print(module, std::cout);
} }

View File

@@ -37,7 +37,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
stream->synchronize(); stream->synchronize();
while(total_time*1e-9 < 1e-3){ while(total_time*1e-9 < 1e-3){
float norm = 1; float norm = 1;
// normalize clock if possible to get roughly constant result // normalize clock if possible to reduce noise in auto-tuning
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device)) // if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); // norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start(); tmr.start();

View File

@@ -74,6 +74,11 @@ unsigned distributed_tile::get_linear_index(indices_t idx) {
return indices_[idx]; return indices_[idx];
} }
// Returns the indices stored at position `id` of ordered_indices_.
// Lookup goes through .at(), so an out-of-range `id` is rejected by the
// container rather than read past the end (presumably std::out_of_range
// if ordered_indices_ is a std::vector -- confirm against the header).
indices_t distributed_tile::get_ordered_indices(unsigned id) {
return ordered_indices_.at(id);
}
void distributed_tile::for_each(std::function<void (indices_t)> fn) { void distributed_tile::for_each(std::function<void (indices_t)> fn) {
for(unsigned i = 0; i < ordered_indices_.size(); i++) for(unsigned i = 0; i < ordered_indices_.size(); i++)
if(i % vector_size_ == 0) if(i % vector_size_ == 0)
@@ -779,13 +784,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// store // store
if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){ if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand()); distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *scalars = tmap_.at(x->get_value_operand()); distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
ir::value *mask = x->get_mask_operand(); ir::value *mask = x->get_mask_operand();
distributed_tile* preds = (distributed_tile*)tmap_.at(mask); distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
ptrs->for_each([&](indices_t idx){ ptrs->for_each([&](indices_t idx){
Value *scalar = scalars->get_value(idx); Value *scalar = scalars->get_value(idx);
Value *ptr = ptrs->get_value(idx); Value *ptr = ptrs->get_value(idx);
Value *pred = preds->get_value(idx); Value *pred = preds->get_value(idx);
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
// std::string offset = ""; // std::string offset = "";
// if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr)) // if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
// if(gep->getNumIndices() == 1) // if(gep->getNumIndices() == 1)
@@ -796,14 +809,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;"; // std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true); // InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
// builder.CreateCall(iasm, {pred, ptr, scalar}); // builder.CreateCall(iasm, {pred, ptr, scalar});
BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
builder.SetInsertPoint(mask_then_bb);
builder.CreateStore(scalar, ptr);
builder.CreateBr(mask_done_bb);
builder.SetInsertPoint(mask_done_bb);
}); });
} }
else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) { else if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
@@ -893,11 +898,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
ir::value* in = ins->get_operand(0); ir::value* in = ins->get_operand(0);
distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
result->for_each([&](indices_t out_idx){ result->for_each([&](indices_t out_idx){
indices_t in_idx; unsigned pos = result->get_linear_index(out_idx);
for(size_t k = 0; k < shapes.size(); k++){ indices_t in_idx = in_tile->get_ordered_indices(pos);
if(shapes[k]->get_value() > 1)
in_idx.push_back(out_idx[k]);
}
result->set_value(out_idx, in_tile->get_value(in_idx)); result->set_value(out_idx, in_tile->get_value(in_idx));
}); });
} }

View File

@@ -63,14 +63,19 @@ void tune::init_c_graph(ir::instruction *v) {
else else
shapes = v->get_type()->get_tile_shapes(); shapes = v->get_type()->get_tile_shapes();
// Reshape // Reshape
if(dynamic_cast<ir::reshape_inst*>(v)){ if(dynamic_cast<ir::reshape_inst*>(v)) {
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
unsigned current = 0; unsigned current = 0;
bool is_skewed = false;
for(unsigned i = 0; i < shapes.size(); i ++){ for(unsigned i = 0; i < shapes.size(); i ++){
if(shapes[i] == one) bool is_one = shapes[i] == one;
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
if(is_one)
static_params_.insert({{v, i}, 1}); static_params_.insert({{v, i}, 1});
else else if(!is_skewed && is_same)
add_constraint({v, i}, {op, current++}); add_constraint({v, i}, {op, current++});
else
is_skewed = true;
} }
} }
// Splat // Splat
@@ -81,9 +86,8 @@ void tune::init_c_graph(ir::instruction *v) {
else if(dynamic_cast<ir::trans_inst*>(v)){ else if(dynamic_cast<ir::trans_inst*>(v)){
ir::value *op = v->get_operand(0); ir::value *op = v->get_operand(0);
size_t n_shapes = shapes.size(); size_t n_shapes = shapes.size();
for(unsigned i = 0; i < n_shapes; i++){ for(unsigned i = 0; i < n_shapes; i++)
add_constraint({v, (i + 1) % n_shapes}, {op, i}); add_constraint({v, (i + 1) % n_shapes}, {op, i});
}
} }
// Broadcast // Broadcast
else if(dynamic_cast<ir::broadcast_inst*>(v)){ else if(dynamic_cast<ir::broadcast_inst*>(v)){
@@ -247,14 +251,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space(); size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){ if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8)); std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
*params_.at(i).at("nts.d0") = *tmp; *params_.at(i).at("nts.d0") = *tmp;
} }
} }
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){ if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty(); ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8)); std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8)); std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
*params_.at(i).at("nts.d0") = *tmp1; *params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2; *params_.at(i).at("nts.d1") = *tmp2;
} }

View File

@@ -1,4 +1,5 @@
#include <sstream> #include <sstream>
#include <unordered_map>
#include "triton/dnn/base.h" #include "triton/dnn/base.h"
#include "triton/runtime/jit.h" #include "triton/runtime/jit.h"
#include "triton/tools/bench.hpp" #include "triton/tools/bench.hpp"
@@ -31,7 +32,7 @@ params_t base::heuristics() const {
} }
std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) { std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::vector<driver::buffer *> args, autotuning_t autotune) {
static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit; static std::unordered_map<base*, std::unique_ptr<rt::jit>, recompile_hash, recompile_equal> m_jit;
driver::context* ctx = stream->context(); driver::context* ctx = stream->context();
rt::jit* jit; rt::jit* jit;
/* the current template has not already been compiled */ /* the current template has not already been compiled */

View File

@@ -30,7 +30,7 @@ namespace dnn{
* --------------- */ * --------------- */
batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps) batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
: base("batchnorm"), : base("batchnorm_forward"),
C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) { C_(C), D_(D), H_(H), W_(W), B_(B), ty_(ty), eps_(eps) {
DHWB_ = D_*H_*W_*B_; DHWB_ = D_*H_*W_*B_;
rcpDHWB_ = (float)1 / DHWB_; rcpDHWB_ = (float)1 / DHWB_;
@@ -40,12 +40,9 @@ size_t batchnorm_forward::num_flops() const {
return C_*DHWB_; return C_*DHWB_;
} }
bool batchnorm_forward::operator <(const base& other) const {
auto *y = dynamic_cast<const batchnorm_forward*>(&other); std::vector<int64_t> batchnorm_forward::retune_params() const {
if(!y) return {C_, D_, H_, W_, B_};
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
} }
base* batchnorm_forward::clone() const { base* batchnorm_forward::clone() const {
@@ -74,50 +71,50 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
void batchnorm_forward::triton_c_src(std::ostream &os) const { void batchnorm_forward::triton_c_src(std::ostream &os) const {
os << os <<
R"( R"(
const tunable int32 TM = {32, 64, 128}; const tunable int TM = {32, 64, 128};
void batchnorm(fp32 *Y, fp32 *M, fp32 *V, void batchnorm_forward(float *Y, float *M, float *V,
restrict read_only fp32 *X, restrict read_only float *X,
restrict read_only fp32 *G, restrict read_only float *G,
restrict read_only fp32 *B, restrict read_only float *B,
int32 DHWN, int DHWN,
fp32 rcpDHWN, fp32 eps) { float rcpDHWN, float eps) {
int32 rx[TM] = 0 ... TM; int rx[TM] = 0 ... TM;
fp32 *px[TM]; float *px[TM];
fp32 x[TM]; float x[TM] = 0;
int32 c = get_range_id(1); int c = get_range_id(1);
fp32 g = *(G + c); float g = *(G + c);
fp32 b = *(B + c); float b = *(B + c);
fp32 mean[TM] = 0; float mean[TM] = 0;
px = X + rx + c*DHWN; px = X + rx + c*DHWN;
for(int32 i = 0; i < DHWN; i = i + TM){ for(int i = 0; i < DHWN; i = i + TM){
x = *px; x = *px;
mean = mean + x; mean = mean + x;
px = px + TM; px = px + TM;
} }
fp32 *pm = M + c; float *pm = M + c;
fp32 m = __sum(mean) * rcpDHWN; float m = __sum(mean) * rcpDHWN;
*pm = m; *pm = m;
fp32 var[TM] = 0; float var[TM] = 0;
px = X + rx + c*DHWN; px = X + rx + c*DHWN;
for(int32 i = 0; i < DHWN; i = i + TM){ for(int i = 0; i < DHWN; i = i + TM){
x = *px; x = *px;
x = x - m; x = x - m;
var = var + x*x; var = var + x*x;
px = px + TM; px = px + TM;
} }
fp32 v = __sum(var) * rcpDHWN; float v = __sum(var) * rcpDHWN;
fp32 *pv = V + c; float *pv = V + c;
*pv = v; *pv = v;
fp32 rstdg = 1 / sqrt(v + eps) * g; float rstdg = 1 / sqrt(v + eps) * g;
px = X + rx + c*DHWN; px = X + rx + c*DHWN;
fp32* py[TM] = Y + rx + c*DHWN; float* py[TM] = Y + rx + c*DHWN;
for(int32 i = 0; i < DHWN; i = i + TM){ for(int i = 0; i < DHWN; i = i + TM){
x = *px; x = *px;
fp32 y[TM] = (x - m)*rstdg + b; float y[TM] = (x - m)*rstdg + b;
*py = y; *py = y;
px = px + TM; px = px + TM;
py = py + TM; py = py + TM;
@@ -130,7 +127,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
* --------------- */ * --------------- */
batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps) batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
: base("batchnorm"), : base("batchnorm_backward"),
C_(C), D_(D), H_(H), W_(W), B_(B), C_(C), D_(D), H_(H), W_(W), B_(B),
ty_(ty), eps_(eps) ty_(ty), eps_(eps)
{ } { }
@@ -139,12 +136,8 @@ size_t batchnorm_backward::num_flops() const {
return C_*D_*H_*W_*B_; return C_*D_*H_*W_*B_;
} }
bool batchnorm_backward::operator <(const base& other) const { std::vector<int64_t> batchnorm_backward::retune_params() const {
auto *y = dynamic_cast<const batchnorm_backward*>(&other); return {C_, D_, H_, W_, B_};
if(!y)
return true;
return std::tie(C_, D_, H_, W_, B_, ty_)
< std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
} }
base* batchnorm_backward::clone() const { base* batchnorm_backward::clone() const {
@@ -174,54 +167,54 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
void batchnorm_backward::triton_c_src(std::ostream &os) const { void batchnorm_backward::triton_c_src(std::ostream &os) const {
os << os <<
R"( R"(
const tunable int32 TM = {32, 64, 128}; const tunable int TM = {32, 64, 128};
void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB, void batchnorm_backward(float *DX, float *DG, float *DB,
restrict read_only fp32 *DY, restrict read_only float *DY,
restrict read_only fp32 *X, restrict read_only float *X,
restrict read_only fp32 *G, restrict read_only float *G,
restrict read_only fp32 *M, restrict read_only float *M,
restrict read_only fp32 *V, restrict read_only float *V,
int32 DHWN, fp32 rcpDHWN, fp32 epsilon) { int DHWN, float rcpDHWN, float epsilon) {
int32 rx[TM] = 0 ... TM; int rx[TM] = 0 ... TM;
int32 c = get_range_id(1); int c = get_range_id(1);
int32 offset = c*DHWN; int offset = c*DHWN;
fp32 g = *(G + c); float g = *(G + c);
fp32 mean = *(M + c); float mean = *(M + c);
fp32 var = *(V + c); float var = *(V + c);
fp32 rstd = 1 / sqrt(var + epsilon); float rstd = 1 / sqrt(var + epsilon);
fp32* px[TM]; float* px[TM];
fp32* pdx[TM]; float* pdx[TM];
fp32* pdy[TM]; float* pdy[TM];
px = X + rx + offset; px = X + rx + offset;
pdy = DY + rx + offset; pdy = DY + rx + offset;
fp32 dg[TM] = 0; float dg[TM] = 0;
fp32 db[TM] = 0; float db[TM] = 0;
for(int32 i = 0; i < DHWN; i = i + TM){ for(int i = 0; i < DHWN; i = i + TM){
fp32 x[TM] = *px; float x[TM] = *px;
fp32 dy[TM] = *pdy; float dy[TM] = *pdy;
dg = dg + dy*(x - mean)*rstd; dg = dg + dy*(x - mean)*rstd;
db = db + dy; db = db + dy;
px = px + TM; px = px + TM;
pdy = pdy + TM; pdy = pdy + TM;
} }
fp32 sdg = __sum(dg); float sdg = __sum(dg);
fp32 sdb = __sum(db); float sdb = __sum(db);
fp32 *pdg = DG + c; float *pdg = DG + c;
fp32 *pdb = DB + c; float *pdb = DB + c;
*pdg = sdg; *pdg = sdg;
*pdb = sdb; *pdb = sdb;
px = X + rx + offset; px = X + rx + offset;
pdy = DY + rx + offset; pdy = DY + rx + offset;
pdx = DX + rx + offset; pdx = DX + rx + offset;
for(int32 i = 0; i < DHWN; i = i + TM){ for(int i = 0; i < DHWN; i = i + TM){
fp32 x[TM] = *px; float x[TM] = *px;
fp32 dy[TM] = *pdy; float dy[TM] = *pdy;
fp32 xhat[TM] = (x - mean) * rstd; float xhat[TM] = (x - mean) * rstd;
fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN; float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
fp32 dx[TM] = (dy - xtmp) * rstd * g; float dx[TM] = (dy - xtmp) * rstd * g;
*pdx = dx; *pdx = dx;
px = px + TM; px = px + TM;
pdy = pdy + TM; pdy = pdy + TM;

View File

@@ -10,12 +10,8 @@ size_t dot::num_flops() const {
return 2.*nblocks_*BS_*BS_*N_; return 2.*nblocks_*BS_*BS_*N_;
} }
bool dot::operator <(const base& other) const { std::vector<int64_t> dot::retune_params() const{
auto *y = dynamic_cast<const dot*>(&other); return {N_, S_, C_, BS_, nlocks_, op_};
if(!y)
return true;
return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_, op_)
< std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_, y->op_);
} }
std::vector<params_t> dot::search_space() const { std::vector<params_t> dot::search_space() const {
@@ -92,35 +88,35 @@ void dot::triton_c_src(std::ostream &os) const {
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ; std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
std::string result = std::string result =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int TM = {16, 32, 64, 128};
const tunable int32 TN = {)" + std::to_string(BS_) + R"(}; const tunable int TN = {)" + std::to_string(BS_) + R"(};
const tunable int32 TK = {)" + std::to_string(BS_) + R"(}; const tunable int TK = {)" + std::to_string(BS_) + R"(};
void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A, void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
restrict read_only align(16) )" + ab_ty_ + R"( *B, restrict read_only align(16) )" + ab_ty_ + R"( *B,
)" + c_ty_ + R"(* C, )" + c_ty_ + R"(* C,
int32 lda, int32 ldc, int32 N, int lda, int ldc, int N,
int32* lut, int32* locks, int32 nlocks){ int* lut, int* locks, int nlocks){
int32 ridx = get_range_id(0); int ridx = get_range_id(0);
int32 ridy = get_range_id(1); int ridy = get_range_id(1);
fp32 acc[TM, TN] = 0; float acc[TM, TN] = 0;
int32 rxa[TM] = ridx * TM + (0 ... TM); int rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = 0 ... TN; int ryb[TN] = 0 ... TN;
int32 rka[TK] = 0 ... TK; int rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int rkb[TK] = 0 ... TK;
int1 checka[TM, TK] = (rxa < N)[:, newaxis]; bool checka[TM, TK] = (rxa < N)[:, newaxis];
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda; int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(; int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
int32 *header = lut + ridy * 4; int *header = lut + ridy * 4;
int32 offset = *(header + 0); int offset = *(header + 0);
int32 K = *(header + 1); int K = *(header + 1);
int32 column = *(header + 2); int column = *(header + 2);
int32 lockid = *(header + 3); int lockid = *(header + 3);
int32 *plut = lut + offset * 2; int *plut = lut + offset * 2;
for(int32 k = K; k > 0; k = k - 1) for(int k = K; k > 0; k = k - 1)
{ {
int32 ak = *(plut + 0); int ak = *(plut + 0);
int32 bk = *(plut + 1); int bk = *(plut + 1);
)" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda; )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
)" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN; )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
)" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0; )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
@@ -128,19 +124,19 @@ void dot::triton_c_src(std::ostream &os) const {
acc = dot()" + usea + ", " + useb + R"(, acc); acc = dot()" + usea + ", " + useb + R"(, acc);
plut = plut + 2; plut = plut + 2;
} }
int32 rxc[TM] = ridx * TM + (0 ... TM); int rxc[TM] = ridx * TM + (0 ... TM);
int32 ryc[TN] = column * TN + (0 ... TN); int ryc[TN] = column * TN + (0 ... TN);
)" + c_ty_ + R"(" c[TM, TN] = acc; )" + c_ty_ + R"(" c[TM, TN] = acc;
)" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc; )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
int1 checkc[TM, TN] = (rxc < N)[:, newaxis]; bool checkc[TM, TN] = (rxc < N)[:, newaxis];
if(lockid == 0) if(lockid == 0)
@checkc *pc = c; @checkc *pc = c;
else else
{ {
int32 *plock = locks + ridx*nlocks + lockid - 1; int *plock = locks + ridx*nlocks + lockid - 1;
int32 *pcount = plock + get_num_program(0)*nlocks; int *pcount = plock + get_num_program(0)*nlocks;
while(__atomic_cas(plock, 0, 1)); while(__atomic_cas(plock, 0, 1));
int32 count = *pcount; int count = *pcount;
if(count == 0){ if(count == 0){
@checkc *pc = c; @checkc *pc = c;
} }

View File

@@ -98,20 +98,12 @@ conv::conv(int B, int NC,
} }
// comparison for maps // comparison for maps
bool conv::operator<(const base& other) const { std::vector<int64_t> conv::retune_params() const {
auto *y = dynamic_cast<const conv*>(&other); return {NB_, NC_, AD_, AH_, AW_,
if(!y) NF_, BD_, BH_, BW_,
return true; pad_d_, pad_h_, pad_w_,
return std::tie(NB_, NC_, AD_, AH_, AW_, stride_d_, stride_h_, stride_w_,
NF_, BD_, BH_, BW_, ty_, bias_};
pad_d_, pad_h_, pad_w_,
stride_d_, stride_h_, stride_w_,
a_ty_, b_ty_, ty_, bias_)
< std::tie(y->NB_, y->NC_, y->AD_, y->AH_, y->AW_,
y->NF_, y->BD_, y->BH_, y->BW_,
y->pad_d_, y->pad_h_, y->pad_w_,
y->stride_d_, y->stride_h_, y->stride_w_,
y->a_ty_, y->b_ty_, y->ty_, y->bias_);
} }
// clone // clone
@@ -549,114 +541,114 @@ void conv::triton_c_src(std::ostream &os) const {
os << os <<
R"( R"(
const tunable int32 TM = {16, 32, 64}; const tunable int TM = {16, 32, 64};
const tunable int32 TN = {16, 32, 64}; const tunable int TN = {16, 32, 64};
const tunable int32 TK = {)" << TK_ << R"(}; const tunable int TK = {)" << TK_ << R"(};
const tunable int32 GZ = {1}; const tunable int GZ = {1};
)"; )";
if(is_a_deltas_cst) if(is_a_deltas_cst)
os << "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n"; os << "__constant__ int* delta = alloc_const int[" + std::to_string(h_a_deltas_.size()) + "];\n";
if(b_lut_ && is_b_deltas_cst_) if(b_lut_ && is_b_deltas_cst_)
os << "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n"; os << "__constant__ int* b_delta = alloc_const int[" + std::to_string(h_b_deltas_.size()) + "];\n";
if(is_mask_cst_) if(is_mask_cst_)
os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n"; os << "__constant__ int* masks = alloc_const int[" + std::to_string(h_masks_.size()) + "];\n";
os << R"( os << R"(
void conv(read_only restrict )" << a_ty_ << R"( *a, void conv(read_only restrict )" << a_ty_ << R"( *a,
read_only restrict )" << b_ty_ << R"( *b, read_only restrict )" << b_ty_ << R"( *b,
fp32 *c, float *c,
fp32 *bias, float *bias,
int32 M, int32 N, int32 K, int M, int N, int K,
int32 AH, int32 AW, int AH, int AW,
int32 BH, int32 BW, int BH, int BW,
int32 CH, int32 CW, int CH, int CW,
int32 NC, int NC,
int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w, int lda_n, int lda_c, int lda_d, int lda_h, int lda_w,
int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k, int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k,
int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q, int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q,
int32 pad_h, int32 pad_w, int pad_h, int pad_w,
int32 stride_h, int32 stride_w, int stride_h, int stride_w,
int32 upsample_h, int32 upsample_w, int upsample_h, int upsample_w,
int32 off_uh, int32 off_uw, int off_uh, int off_uw,
int32 off_uah, int32 off_uaw, int off_uah, int off_uaw,
int32 off_uch, int32 off_ucw, int off_uch, int off_ucw,
int32 *locks, int32 grid0, int32 grid1)"; int *locks, int grid0, int grid1)";
if(!is_a_deltas_cst) if(!is_a_deltas_cst)
os << ", int32* delta"; os << ", int* delta";
if(b_lut_ && !is_b_deltas_cst_) if(b_lut_ && !is_b_deltas_cst_)
os << ", int32* b_delta"; os << ", int* b_delta";
if(!is_mask_cst_) if(!is_mask_cst_)
os << ", int32* masks"; os << ", int* masks";
os << R"(){ os << R"(){
int32 rxa[TM] = get_global_range[TM](0); int rxa[TM] = get_global_range[TM](0);
int32 rb0[TN] = get_global_range[TN](1); int rb0[TN] = get_global_range[TN](1);
int32 rz = get_global_range[1](2); int rz = get_global_range[1](2);
int32 rka[TK] = 0 ... TK; int rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int rkb[TK] = 0 ... TK;
fp32 C[TM, TN] = 0; float C[TM, TN] = 0;
int32 ldlut = )" + std::to_string(Luts_) + R"(; int ldlut = )" + std::to_string(Luts_) + R"(;
int32 div = K / GZ; int div = K / GZ;
int32 rem = K % GZ; int rem = K % GZ;
K = select(rz < rem, div, div + rem); K = select(rz < rem, div, div + rem);
int32 offk = rz*div; int offk = rz*div;
rka = rka + offk; rka = rka + offk;
rkb = rkb + offk; rkb = rkb + offk;
int32 rabh[TM] = rxa / CW; int rabh[TM] = rxa / CW;
int32 raw[TM] = rxa % CW; int raw[TM] = rxa % CW;
int32 rab[TM] = rabh / CH; int rab[TM] = rabh / CH;
int32 rah[TM] = rabh % CH; int rah[TM] = rabh % CH;
rah = rah)" + upaw + R"( - off_uah; rah = rah)" + upaw + R"( - off_uah;
raw = raw)" + upah + R"( - off_uaw; raw = raw)" + upah + R"( - off_uaw;
int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w; int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(; int ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(; int ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(; int ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(; int ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
rar = )" + flipr + R"( rar; rar = )" + flipr + R"( rar;
ras = )" + flips + R"( ras; ras = )" + flips + R"( ras;
rar = )" + upar + R"( rar; rar = )" + upar + R"( rar;
ras = )" + upas + R"( ras; ras = )" + upas + R"( ras;
int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
)" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)"; )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
if(b_lut_){ if(b_lut_){
os << R"( os << R"(
int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(; int rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
int32 rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(; int rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
int32 rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(; int rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
int32 rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(; int rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
rbr = rbr*upsample_h + off_uh; rbr = rbr*upsample_h + off_uh;
rbs = rbs*upsample_w + off_uw; rbs = rbs*upsample_w + off_uw;
int32 offdb[TK] = rkb % ldlut; int offdb[TK] = rkb % ldlut;
int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s; int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
)" + b_delta_mem + R"( int32* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w; )" + b_delta_mem + R"( int* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
int32 db[TK] = *pdb;)"; int db[TK] = *pdb;)";
} }
else{ else{
os << R"( os << R"(
int32 rb1[TK] = rkb)" + ldb0 + ";"; int rb1[TK] = rkb)" + ldb0 + ";";
} }
os << R"( os << R"(
)" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k; )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
int32 offda[TK] = rka % ldlut; int offda[TK] = rka % ldlut;
)" + a_delta_mem + R"( int32* pincd[TK] = delta + offda; )" + a_delta_mem + R"( int* pincd[TK] = delta + offda;
)" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w; )" + a_delta_mem + R"( int* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
int32 da[TK] = *pda; int da[TK] = *pda;
int32 incd[TK] = *pincd; int incd[TK] = *pincd;
int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0); int maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0); int maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
int32 offma = offk % ldlut; int offma = offk % ldlut;
)" + masks_mem + R"( int32* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w; )" + masks_mem + R"( int* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
)" + a_delta_mem + R"( int32* pincm[TM] = delta + offma; )" + a_delta_mem + R"( int* pincm[TM] = delta + offma;
int32 incm[TM] = *pincm; int incm[TM] = *pincm;
int32 maska0[TM] = *pm; int maska0[TM] = *pm;
int32 maska1[TK] = 1 << (0 ... TK); int maska1[TK] = 1 << (0 ... TK);
int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0; bool checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
int1 checkb0[TN] = rb0 < N; bool checkb0[TN] = rb0 < N;
int1 checkb)" + BS + " = checkb0" + bcb0 + R"(; bool checkb)" + BS + " = checkb0" + bcb0 + R"(;
)" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0; )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
)" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0; )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
int32 rkamin[TK] = rka - offk + TK; int rkamin[TK] = rka - offk + TK;
for(int32 k = K; k > 0; k = k - TK){ for(int k = K; k > 0; k = k - TK){
C = dot(a, )" + useb + R"(, C); C = dot(a, )" + useb + R"(, C);
pa = pa + da[newaxis, :]; pa = pa + da[newaxis, :];
pb = pb + )" + inc_pb + R"(; pb = pb + )" + inc_pb + R"(;
@@ -673,7 +665,7 @@ if(b_lut_){
pm = pm + incm; pm = pm + incm;
pincm = pincm + incm; pincm = pincm + incm;
incm = *pincm; incm = *pincm;
int1 checka1[TK] = (rkamin < k); bool checka1[TK] = (rkamin < k);
maska0 = *pm; maska0 = *pm;
checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0; checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
checka = checka && checka1[newaxis,:]; checka = checka && checka1[newaxis,:];
@@ -681,31 +673,31 @@ if(b_lut_){
checkb = checkb && (k > TK); checkb = checkb && (k > TK);
@checkb b = *pb; @checkb b = *pb;
} }
int32 rxc[TM] = get_global_range[TM](0); int rxc[TM] = get_global_range[TM](0);
int32 rc1[TN] = get_global_range[TN](1); int rc1[TN] = get_global_range[TN](1);
int32 rcn[TM] = rxc / (CH*CW); int rcn[TM] = rxc / (CH*CW);
int32 rcpq[TM] = rxc % (CH*CW); int rcpq[TM] = rxc % (CH*CW);
int32 rcp[TM] = rcpq / CW; int rcp[TM] = rcpq / CW;
int32 rcq[TM] = rcpq % CW; int rcq[TM] = rcpq % CW;
rcp = rcp * upsample_h + off_uch; rcp = rcp * upsample_h + off_uch;
rcq = rcq * upsample_w + off_ucw; rcq = rcq * upsample_w + off_ucw;
int1 checkc1[TN] = rc1 < N; bool checkc1[TN] = rc1 < N;
int32 rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q; int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
int1 checkc0[TM] = rxc < M; bool checkc0[TM] = rxc < M;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
int32 ridx = get_range_id(0); int ridx = get_range_id(0);
int32 ridy = get_range_id(1); int ridy = get_range_id(1);
int32 *plock = locks + ridx + ridy*grid0; int *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1) == 1); while(__atomic_cas(plock, 0, 1) == 1);
int32 *pcount = plock + grid0*grid1; int *pcount = plock + grid0*grid1;
int32 count = *pcount; int count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1); int countp1 = select(count == GZ - 1, 0, count + 1);
if(count == 0) {)"; if(count == 0) {)";
if(bias_ && ty_==FPROP){ if(bias_ && ty_==FPROP){
os << R"( os << R"(
fp32* pbias[TN] = bias + rc1; float* pbias[TN] = bias + rc1;
fp32 bias[TN] = checkc1 ? *pbias : 0; float bias[TN] = checkc1 ? *pbias : 0;
C = C + bias[newaxis, :];)"; C = C + bias[newaxis, :];)";
} }
os << R"( os << R"(

View File

@@ -10,11 +10,11 @@ namespace dnn{
dot::dot(int M, int N, int K, dot::dot(int M, int N, int K,
bool AT, bool BT, bool AT, bool BT,
std::string a_ty, std::string b_ty, std::string a_ty, std::string b_ty,
unsigned alignment_lda, unsigned alignment_ldb) unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
: base("matmul"), : base("matmul"),
M_(M), N_(N), K_(K), AT_(AT), BT_(BT), M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
a_ty_(a_ty), b_ty_(b_ty), a_ty_(a_ty), b_ty_(b_ty),
align_lda_(alignment_lda), align_ldb_(alignment_ldb), align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
locks_(nullptr) { locks_(nullptr) {
} }
@@ -23,15 +23,10 @@ size_t dot::num_flops() const {
return 2.*M_*N_*K_; return 2.*M_*N_*K_;
} }
// comparison for maps // retune parameters
bool dot::operator<(const base& other) const { std::vector<int64_t> dot::retune_params() const {
auto *y = dynamic_cast<const dot*>(&other); return {M_, N_, K_, AT_, BT_,
if(!y) (int)align_lda_, (int)align_ldb_};
return true;
return std::tie(M_, N_, K_, AT_, BT_,
a_ty_, b_ty_, align_lda_, align_ldb_)
< std::tie(y->M_, y->N_, y->K_, y->AT_, y->BT_,
y->a_ty_, y->b_ty_, y->align_lda_, y->align_ldb_);
} }
// clone // clone
@@ -101,45 +96,45 @@ void dot::triton_c_src(std::ostream &os) const {
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res = std::string res =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128, 256}; const tunable int TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128, 256}; const tunable int TN = {16, 32, 64, 128};
const tunable int32 TK = {32}; const tunable int TK = {32};
const tunable int32 GZ = {1}; const tunable int GZ = {1};
void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
restrict read_only align(16) )" + b_ty_ + R"( *B, restrict read_only align(16) )" + b_ty_ + R"( *B,
fp32 *C, restrict read_only align(16) float *C,
int32 M, int32 N, int32 K, int M, int N, int K,
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc, )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
int32 bound, int32 *locks, int32 grid0, int32 grid1) { int bound, int *locks, int grid0, int grid1) {
int32 ridx = get_range_id(0); int ridx = get_range_id(0);
int32 ridy = get_range_id(1); int ridy = get_range_id(1);
int32 rxa[TM] = ridx * TM + (0 ... TM); int rxa[TM] = ridx * TM + (0 ... TM);
int32 ryb[TN] = ridy * TN + (0 ... TN); int ryb[TN] = ridy * TN + (0 ... TN);
int32 rka[TK] = 0 ... TK; int rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0; float c[TM, TN] = 0;
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(; bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(; bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0; )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0; )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){ for(int k = K; k > 0; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c); c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(; pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(; pb = pb + TK)" + ldb0 + R"(;
int1 checka[)" + AS + R"(] = k > TK; bool checka[)" + AS + R"(] = k > TK;
int1 checkb[)" + BS + R"(] = k > TK; bool checkb[)" + BS + R"(] = k > TK;
a = checka ? *pa : 0; a = checka ? *pa : 0;
b = checkb ? *pb : 0; b = checkb ? *pb : 0;
} }
int32 rxc[TM] = ridx * TM + (0 ... TM); int rxc[TM] = ridx * TM + (0 ... TM);
int32 ryc[TN] = ridy * TN + (0 ... TN); int ryc[TN] = ridy * TN + (0 ... TN);
int1 checkc0[TM] = rxc < M; bool checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N; bool checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
@checkc *pc = c; @checkc *pc = c;
} }
)"; )";

View File

@@ -28,7 +28,7 @@ shift::shift(int B, int C,
layout_(layout){ layout_(layout){
// std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl; // std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
// max number of channels // max number of channels
TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32; TK_ = (ty == FPROP && a_ty_ == "float") ? 8 : 32;
MAX_C_ = 8192 + TK_; MAX_C_ = 8192 + TK_;
// activation sizes // activation sizes
CD_ = AD_ / stride_d_; CD_ = AD_ / stride_d_;
@@ -204,26 +204,15 @@ size_t shift::ldb() const
size_t shift::ldc() const size_t shift::ldc() const
{ return M_; } { return M_; }
bool shift::operator <(const base& other) const{ std::vector<int64_t> shift::retune_params() const {
auto *y = dynamic_cast<const shift*>(&other); return {B_, C_, F_,
if(!y) AD_, AH_, AW_,
return true; BD_, BH_, BW_,
return std::tie(B_, C_, F_, CD_, CH_, CW_,
AD_, AH_, AW_, (int64_t)shift_h_, (int64_t)shift_w_,
BD_, BH_, BW_, stride_h_, stride_w_,
CD_, CH_, CW_, layout_, op_,
shift_h_, shift_w_, bias_};
stride_h_, stride_w_,
layout_, op_,
bias_)
< std::tie(y->B_, y->C_, y->F_,
y->AD_, y->AH_, y->AW_,
y->BD_, y->BH_, y->BW_,
y->CD_, y->CH_, y->CW_,
y->shift_h_, y->shift_w_,
y->stride_h_, y->stride_w_,
y->layout_, y->op_,
y->bias_);
} }
void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) { void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
@@ -325,56 +314,56 @@ void shift::triton_c_src(std::ostream &os) const {
if(is_chwn) { if(is_chwn) {
return R"( return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(; int )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
int32 )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(; int )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
int32 )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w; int )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)"; int )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
} }
else { else {
return R"( return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(; int )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
int32 )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w; int )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h; int )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";"; int )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
} }
}; };
std::string result = std::string result =
R"( R"(
const tunable int32 TM = {16, 32, 64, 128}; const tunable int TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128}; const tunable int TN = {16, 32, 64, 128};
const tunable int32 TK = {)" + std::to_string(TK_) + "};"; const tunable int TK = {)" + std::to_string(TK_) + "};";
if(op_ == WGRAD) if(op_ == WGRAD)
result += "const tunable int32 GZ = {1};"; result += "const tunable int GZ = {1};";
else else
result += "const tunable int32 GZ = {1};"; result += "const tunable int GZ = {1};";
result += R"( result += R"(
__constant__ int32* delta_a = alloc_const int32[)" + std::to_string(MAX_C_) + R"(]; __constant__ int* delta_a = alloc_const int[)" + std::to_string(MAX_C_) + R"(];
void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
restrict read_only align(16) )" + b_ty_ + R"( *B, restrict read_only align(16) )" + b_ty_ + R"( *B,
)" + c_ty_ + R"( *C, )" + c_ty_ + R"( *C,
int32 M, int32 N, int32 K, int M, int N, int K,
int32 stride_h, int32 stride_w, int stride_h, int stride_w,
multiple_of(8) int32 lda_b, multiple_of(8) int32 lda_w, multiple_of(8) int32 lda_h, multiple_of(8) int32 lda_c, multiple_of(8) int lda_b, multiple_of(8) int lda_w, multiple_of(8) int lda_h, multiple_of(8) int lda_c,
multiple_of(8) int32 ldb_b, multiple_of(8) int32 ldb_w, multiple_of(8) int32 ldb_h, multiple_of(8) int32 ldb_c, multiple_of(8) int ldb_b, multiple_of(8) int ldb_w, multiple_of(8) int ldb_h, multiple_of(8) int ldb_c,
multiple_of(8) int32 ldc_b, multiple_of(8) int32 ldc_w, multiple_of(8) int32 ldc_h, multiple_of(8) int32 ldc_c, multiple_of(8) int ldc_b, multiple_of(8) int ldc_w, multiple_of(8) int ldc_h, multiple_of(8) int ldc_c,
int32 NB, int NB,
int32 AH, int32 AW, int AH, int AW,
int32 BH, int32 BW, int BH, int BW,
int32 CH, int32 CW, int CH, int CW,
int32* locks, int32 grid0, int32 grid1, int32 grid2) { int* locks, int grid0, int grid1, int grid2) {
int32 ridx = get_range_id(0); int ridx = get_range_id(0);
int32 ridy = get_range_id(1); int ridy = get_range_id(1);
int32 rz = get_range_id(2); int rz = get_range_id(2);
int32 rxa[TM] = ridx*TM + (0 ... TM); int rxa[TM] = ridx*TM + (0 ... TM);
int32 ryb[TN] = ridy*TN + (0 ... TN); int ryb[TN] = ridy*TN + (0 ... TN);
int32 rka[TK] = 0 ... TK; int rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK; int rkb[TK] = 0 ... TK;
fp32 acc[TM, TN] = 0; float acc[TM, TN] = 0;
int32 pad_h = BH / 2; int pad_h = BH / 2;
int32 pad_w = BW / 2;)"; int pad_w = BW / 2;)";
/* A offsets */ /* A offsets */
if(op_ == FPROP){ if(op_ == FPROP){
@@ -382,49 +371,49 @@ if(op_ == FPROP){
compute_bhw("ra", "TM", "rxa") + R"( compute_bhw("ra", "TM", "rxa") + R"(
raw = raw * )" + stride_w + R"(; raw = raw * )" + stride_w + R"(;
rah = rah * )" + stride_h + R"(; rah = rah * )" + stride_h + R"(;
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int offa0[TM, TK] = offxa[:, newaxis];
__constant__ int32* pd[TK] = delta_a + rka; __constant__ int* pd[TK] = delta_a + rka;
multiple_of(8) int32 d[TK] = *pd; multiple_of(8) int d[TK] = *pd;
int32 offa1[TM, TK] = d[newaxis, :];)"; int offa1[TM, TK] = d[newaxis, :];)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
result += result +=
compute_bhw("ra", "TM", "rxa") + R"( compute_bhw("ra", "TM", "rxa") + R"(
int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa0[TM, TK] = offxa[:, newaxis]; int offa0[TM, TK] = offxa[:, newaxis];
int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; int offa1[TM, TK] = rka[newaxis, :] * lda_c;)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
result += result +=
compute_bhw("ra", "TK", "rka") + R"( compute_bhw("ra", "TK", "rka") + R"(
int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; int offa0[TK, TM] = rxa[newaxis, :] * lda_c;
int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; int offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h;
int32 offa1[TK, TM] = offxa[:, newaxis];)"; int offa1[TK, TM] = offxa[:, newaxis];)";
} }
/* B offsets */ /* B offsets */
if(op_ == FPROP){ if(op_ == FPROP){
result += R"( result += R"(
int32 offb0[TN, TK] = ryb[:, newaxis]; int offb0[TN, TK] = ryb[:, newaxis];
int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; int offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)";
} }
if(op_ == BPROP){ if(op_ == BPROP){
result += R"( result += R"(
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = rkb[:, newaxis];)"; int offb1[TK, TN] = rkb[:, newaxis];)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
result += result +=
compute_bhw("rb", "TK", "rkb") + R"( compute_bhw("rb", "TK", "rkb") + R"(
__constant__ int32* pd[TN] = delta_a + ryb; __constant__ int* pd[TN] = delta_a + ryb;
multiple_of(8) int32 d[TN] = *pd; multiple_of(8) int d[TN] = *pd;
multiple_of(8) int32 shift[TK, TN] = d[newaxis, :]; multiple_of(8) int shift[TK, TN] = d[newaxis, :];
rbw = rbw * )" + stride_w + R"(; rbw = rbw * )" + stride_w + R"(;
rbh = rbh * )" + stride_h + R"(; rbh = rbh * )" + stride_h + R"(;
int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; int offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h;
int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; int offb0[TK, TN] = ryb[newaxis, :] * ldb_c;
int32 offb1[TK, TN] = offkb[:, newaxis]; int offb1[TK, TN] = offkb[:, newaxis];
)" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0; )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0;
)" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift; )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift;
)" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1; )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1;
@@ -439,14 +428,14 @@ else{
/* Main loop */ /* Main loop */
/* Increment A pointers */ /* Increment A pointers */
result += R"( result += R"(
int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; bool checka[)" + AS + "] = (rka < K)" + bca0 + R"(;
int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; bool checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(;
)" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){ for(int k = K; k > 0; k = k - TK){
acc = dot()" + usea + "," + useb + R"(, acc); acc = dot()" + usea + "," + useb + R"(, acc);
int1 checka[)" + AS + R"(] = k > TK; bool checka[)" + AS + R"(] = k > TK;
int1 checkb[)" + BS + R"(] = k > TK;)"; bool checkb[)" + BS + R"(] = k > TK;)";
/* Increment A pointers */ /* Increment A pointers */
if(op_ == FPROP){ if(op_ == FPROP){
@@ -490,8 +479,8 @@ if(op_ == BPROP){
result += R"( result += R"(
b = checkb ? *pb : 0; b = checkb ? *pb : 0;
} }
int32 rxc[TM] = ridx*TM + (0 ... TM); int rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = ridy*TN + (0 ... TN);)"; int ryc[TN] = ridy*TN + (0 ... TN);)";
/* C offsets */ /* C offsets */
if(op_ == BPROP){ if(op_ == BPROP){
@@ -499,26 +488,26 @@ if(op_ == BPROP){
compute_bhw("rc", "TM", "rxc") + R"( compute_bhw("rc", "TM", "rxc") + R"(
rcw = rcw * )" + stride_w + R"(; rcw = rcw * )" + stride_w + R"(;
rch = rch * )" + stride_h + R"(; rch = rch * )" + stride_h + R"(;
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == FPROP){ if(op_ == FPROP){
result += result +=
compute_bhw("rc", "TM", "rxc") + R"( compute_bhw("rc", "TM", "rxc") + R"(
int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)";
} }
if(op_ == WGRAD){ if(op_ == WGRAD){
result += R"( result += R"(
int32 offxc[TM] = rxc;)"; int offxc[TM] = rxc;)";
} }
result += R"(" result += R"("
)" + c_ty_ + R"( c[TM, TN] = acc; )" + c_ty_ + R"( c[TM, TN] = acc;
)" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c; )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c;
int1 checkc0[TM] = rxc < M; bool checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N; bool checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)";
if(op_ == BPROP){ if(op_ == BPROP){
result += R"( result += R"(
__constant__ int32* pd[TN] = delta_a + ryc; __constant__ int* pd[TN] = delta_a + ryc;
)" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :];
@checkc *shift_pc = c; @checkc *shift_pc = c;
)"; )";

View File

@@ -174,8 +174,15 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
for(ir::metaparameter *mp: mps) for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]); mp->set_value(params[i++]);
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << std::endl;
passes_0.tune.init(tt_module_0); passes_0.tune.init(tt_module_0);
passes_0.tune.check_constraints(errors); passes_0.tune.check_constraints(errors);
// for(auto x: errors)
// for(auto e: x.second){
// std::cout << x.first->get_name() << ": " << e << std::endl;
// }
} }
if(!errors.empty()) if(!errors.empty())
return; return;
@@ -212,9 +219,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
best.perf = perf; best.perf = perf;
best.params = params; best.params = params;
} }
for(size_t i = 0; i < params.size(); i++) // for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush; // std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; // std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl;
} }
}; };