didn't break correctness of existing HMMA
This commit is contained in:
@@ -26,8 +26,8 @@ struct perf_t {
|
||||
|
||||
|
||||
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
|
||||
typedef float NumericT;
|
||||
std::string ty = "float";
|
||||
typedef half NumericT;
|
||||
std::string ty = "half";
|
||||
size_t dt_nbytes = sizeof(NumericT);
|
||||
triton::driver::context* context = stream->context();
|
||||
std::vector<NumericT> hc(M*N);
|
||||
@@ -112,9 +112,9 @@ int main() {
|
||||
std::vector<config_t> configs = {
|
||||
// {false, false, 8192, 512, 512},
|
||||
// {false, true, 8192, 8192, 8192}
|
||||
// {false, true, 128, 128, 128},
|
||||
// {false, false, 128, 128, 128},
|
||||
// {true, false, 128, 128, 128},
|
||||
{false, true, 128, 128, 128},
|
||||
{false, false, 128, 128, 128},
|
||||
{true, false, 128, 128, 128},
|
||||
{true, true, 128, 128, 128}
|
||||
|
||||
// {false, true, 32768, 256, 512}
|
||||
|
@@ -999,6 +999,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
auto A_shapes = A->get_type()->get_tile_shapes();
|
||||
size_t red_axis = dot->is_a_trans() ? 0 : 1;
|
||||
unsigned NK = A_shapes[red_axis]->get_value();
|
||||
// std::cout << red_axis << " " << NK << std::endl;
|
||||
if(NK != 1)
|
||||
{
|
||||
shared_tile *TA = (shared_tile*)tmap_.at(A);
|
||||
|
@@ -62,11 +62,11 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else{
|
||||
// params_t params = heuristics();
|
||||
params_t params = heuristics();
|
||||
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
|
||||
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
|
||||
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
|
||||
params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT
|
||||
// params_t params = {4, 16, 4, 2, 16, 4, 8, 2, 2, 8, 2, 32, 8, 1}; // TT
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
@@ -80,7 +80,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
|
||||
std::string usea = AT_ ? "trans(xa, 2, 0, 1)" : "xa";
|
||||
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
|
||||
if(AT_){
|
||||
std::swap(AS0, AS1);
|
||||
@@ -149,7 +149,6 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
}
|
||||
)";
|
||||
|
||||
// std::cout << res << std::endl;
|
||||
os << res;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user