more stuff
This commit is contained in:
@@ -13,7 +13,7 @@ int main() {
|
||||
// initialize default compute device
|
||||
auto context = triton::driver::backend::contexts::get_default();
|
||||
// matrix multiplication parameters
|
||||
int32_t M = 32768, N = 128, K = 128;
|
||||
int32_t M = 8192, N = 8192, K = 8192;
|
||||
std::vector<float> hc(M*N);
|
||||
std::vector<float> rc(M*N);
|
||||
std::vector<float> ha(M*K);
|
||||
|
@@ -30,7 +30,7 @@ void base::enqueue(driver::stream *stream, std::vector<driver::buffer *> args, b
|
||||
/* the current template has not already been compiled */
|
||||
if(m_jit.find(this) == m_jit.end()) {
|
||||
base* clone = this->clone();
|
||||
jit = m_jit.emplace(clone, new rt::jit(ctx)).first->second.get();
|
||||
jit = m_jit.emplace(clone, std::unique_ptr<rt::jit>(new rt::jit(ctx))).first->second.get();
|
||||
std::ostringstream oss;
|
||||
clone->triton_c_src(oss);
|
||||
std::string src = oss.str();
|
||||
|
@@ -106,7 +106,7 @@ void gemm::triton_c_src(std::ostream &os) const {
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
const tunable int32 TN = {16, 32, 64, 128};
|
||||
const tunable int32 TK = {8};
|
||||
const tunable int32 TK = {16};
|
||||
const tunable int32 GZ = {1};
|
||||
|
||||
void matmul(restrict read_only )" + a_ty_ + R"( *A,
|
||||
|
@@ -214,9 +214,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
||||
best.perf = perf;
|
||||
best.params = params;
|
||||
}
|
||||
// for(unsigned p: params)
|
||||
// std::cout << p << " " << std::flush;
|
||||
// std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
for(unsigned p: params)
|
||||
std::cout << p << " " << std::flush;
|
||||
std::cout << perf << " [ " << best.perf << " ] " << std::endl;
|
||||
}
|
||||
}, nthreads_);
|
||||
std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
|
||||
|
Reference in New Issue
Block a user