more improvements and regressions

This commit is contained in:
Philippe Tillet
2019-08-06 16:21:20 -07:00
parent 26c9849462
commit 5efdb7978e
14 changed files with 138 additions and 69 deletions

View File

@@ -74,22 +74,24 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
std::string XBS0 = "TK/4", XBS1 = "TN", XBS2 = "4";
std::string XAS0 = "TM", XAS1 = "TK/1", XAS2 = "1";
std::string XBS0 = "TK/1", XBS1 = "1", XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(xa)" : "xa";
std::string useb = BT_ ? "trans(xb)" : "xb";
std::string usea = AT_ ? "trans(xa, 0, 2, 1)" : "xa";
std::string useb = BT_ ? "trans(xb, 1, 0, 2)" : "trans(xb, 0, 2, 1)";
if(AT_){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
std::swap(XAS1, XAS2);
std::swap(bca0, bca1);
std::swap(lda0, lda1);
}
if(BT_){
std::swap(BS0, BS1);
std::swap(XBS1, XBS2);
std::swap(XBS0, XBS1);
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
@@ -98,7 +100,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string BS = BS0 + ", " + BS1;
std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
std::string XCS = "TM, TN, 4";
std::string XCS = "TM, TN, 1";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
@@ -146,7 +148,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
)";
std::cout << res << std::endl;
// std::cout << res << std::endl;
os << res;
}