[lang] added support for batched matrix multiplication

This commit is contained in:
Philippe Tillet
2019-10-21 15:41:50 -04:00
parent e827d4f467
commit b81734553b
7 changed files with 22 additions and 13 deletions
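As context for the hunks below, a minimal, hypothetical C++ sketch (not part of this commit) of the shape rule that BinaryOp::MatmulOpTypeChecking now enforces: batch dimensions trail the two matrix dimensions, must match pairwise, and are carried through to the result, so {M, K, B...} @ {K, N, B...} yields {M, N, B...}. The helper name and the exception-based error handling are illustrative only.

#include <cstdio>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring the checks added in this commit:
// lhs = {M, K, B0, ...}, rhs = {K, N, B0, ...}  ->  {M, N, B0, ...}
std::vector<int> batched_matmul_shape(const std::vector<int>& lhs,
                                      const std::vector<int>& rhs) {
  if(lhs.size() < 2 || lhs.size() != rhs.size())
    throw std::invalid_argument("operands must have equal rank >= 2");
  if(lhs[1] != rhs[0])
    throw std::invalid_argument("incompatible inner dimensions");
  std::vector<int> ret = {lhs[0], rhs[1]};
  for(size_t d = 2; d < lhs.size(); d++){
    if(lhs[d] != rhs[d])
      throw std::invalid_argument("incompatible batch dimensions");
    ret.push_back(lhs[d]);  // batch dimensions are carried through unchanged
  }
  return ret;
}

int main() {
  // {128, 32, 8} @ {32, 64, 8} -> {128, 64, 8}
  for(int d : batched_matmul_shape({128, 32, 8}, {32, 64, 8}))
    std::printf("%d ", d);
  std::printf("\n");
}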


@@ -96,7 +96,6 @@ void allocation::run(ir::module &mod) {
offsets_[x] = starts[x] + colors[x] * Adj;
}
// Save maximum size of induced memory space
allocated_size_ = 0;
for(layout_t* x: V)


@@ -77,11 +77,6 @@ void axes::update_graph_dot(ir::instruction *i) {
// add edges between result and accumulator
for(unsigned d = 0; d < shapes.size(); d++)
graph_.add_edge({dot, d}, {D, d});
// add edges for batch dimensions
for(unsigned d = 2; d < shapes.size(); d++){
graph_.add_edge({dot, d}, {A, d});
graph_.add_edge({dot, d}, {B, d});
}
}
void axes::update_graph_elementwise(ir::instruction *i) {
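The edges added above constrain every batch axis d >= 2 of the dot result to use the same distribution axis as axis d of both operands, on top of the existing result/accumulator constraint. A toy sketch of that idea (a simplified stand-in for the real axes graph, which works on ir::value pointers; the integer value ids are hypothetical):

#include <map>
#include <utility>

// Toy union-find over (value id, axis) pairs standing in for the axes graph.
struct toy_axes_graph {
  std::map<std::pair<int,int>, std::pair<int,int>> parent_;
  std::pair<int,int> find(std::pair<int,int> x) {
    if(!parent_.count(x)) parent_[x] = x;
    while(parent_[x] != x) x = parent_[x];
    return x;
  }
  // constrain both axes to end up on the same distribution axis
  void add_edge(std::pair<int,int> a, std::pair<int,int> b) {
    parent_[find(a)] = find(b);
  }
};

int main() {
  toy_axes_graph g;
  int A = 0, B = 1, D = 2, dot = 3;  // hypothetical value ids
  int rank = 3;                      // e.g. shapes {M, N, B0}
  // result and accumulator share every axis
  for(int d = 0; d < rank; d++)
    g.add_edge({dot, d}, {D, d});
  // new in this commit: batch axes are also shared with both operands
  for(int d = 2; d < rank; d++){
    g.add_edge({dot, d}, {A, d});
    g.add_edge({dot, d}, {B, d});
  }
}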


@@ -300,6 +300,10 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
 }
 std::vector<int> col = {0, 1};
 std::vector<int> row = {1, 0};
+for(size_t s = 2; s < shapes.size(); s++){
+col.push_back(s);
+row.push_back(s);
+}
 bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
 bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
 if(is_nonhmma_dot_a)
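A small, standalone illustration (not from the commit; the rank and values are made up) of the resulting axis orders: the two matrix axes still select row- versus column-major order, and every batch axis is appended after them, e.g. for a rank-3 tile col = {0, 1, 2} and row = {1, 0, 2}.

#include <cstdio>
#include <vector>

int main() {
  size_t rank = 3;                   // e.g. shapes {M, K, B0}
  std::vector<int> col = {0, 1};
  std::vector<int> row = {1, 0};
  for(size_t s = 2; s < rank; s++){  // batch axes keep their position
    col.push_back(s);
    row.push_back(s);
  }
  for(int d : col) std::printf("%d ", d);  // prints: 0 1 2
  std::printf("| ");
  for(int d : row) std::printf("%d ", d);  // prints: 1 0 2
  std::printf("\n");
}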


@@ -1,4 +1,5 @@
#include <climits>
#include <iostream>
#include "triton/codegen/analysis/liveness.h"
#include "triton/codegen/analysis/layout.h"
#include "triton/ir/function.h"
@@ -37,12 +38,13 @@ void liveness::run(ir::module &mod) {
 }
 // compute intervals
 unsigned start = INT32_MAX;
+for(ir::value *v: layout->values)
+if(indices.find(v) != indices.end())
+start = std::min(start, indices.at(v));
 unsigned end = 0;
 for(ir::user *u: users)
-if(indices.find(u) != indices.end()){
-start = std::min(start, indices.at(u));
-end = std::max(end, indices.at(u));
-}
+if(indices.find(u) != indices.end())
+end = std::max(end, indices.at(u));
 intervals_[layout] = segment{start, end};
 }
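A self-contained sketch of the interval rule after this change (simplified, hypothetical types; the real pass works on ir::value pointers): a layout's live range starts at the first index at which one of its own values is defined and ends at the last index at which one of its users appears.

#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

struct segment { unsigned start, end; };

// values:  indices identifying the instructions that produce the layout's tiles
// users:   indices identifying the instructions that read them
// indices: position of each instruction in the linearized module
segment live_range(const std::vector<int>& values,
                   const std::vector<int>& users,
                   const std::map<int, unsigned>& indices) {
  unsigned start = INT32_MAX;
  for(int v : values)
    if(indices.count(v))
      start = std::min(start, indices.at(v));
  unsigned end = 0;
  for(int u : users)
    if(indices.count(u))
      end = std::max(end, indices.at(u));
  return segment{start, end};
}

int main() {
  std::map<int, unsigned> indices = {{7, 3}, {9, 5}, {11, 8}};
  segment s = live_range({7}, {9, 11}, indices);  // live over [3, 8]
  return (s.start == 3 && s.end == 8) ? 0 : 1;
}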


@@ -1,4 +1,5 @@
#include <algorithm>
#include <iostream>
#include "triton/ir/basic_block.h"
#include "triton/ir/module.h"
#include "triton/ir/type.h"


@@ -471,12 +471,20 @@ void BinaryOp::MatmulOpTypeChecking() {
 auto rhsShape = rhsType->Shape();
 size_t lhsRank = lhsShape.size();
 size_t rhsRank = rhsShape.size();
-if(lhsRank != 2 || rhsRank != 2)
-Error(this, "matrix multiplication operands must have rank 2");
+if(lhsRank != rhsRank)
+Error(this, "matrix multiplication operands have incompatible ranks "
+"%d and %d", lhsRank, rhsRank);
+for(int d = 2; d < lhsRank; d++)
+if(lhsShape[d] != rhsShape[d])
+Error(this, "matrix multiplication operands have incompatible batch dimensions "
+"%d and %d for axis %d", lhsShape[d], rhsShape[d], d);
 if(lhsShape[1] != rhsShape[0])
 Error(this, "matrix multiplication operands have incompatible inner dimension"
 " %d and %d", lhsShape[1], rhsShape[0]);
+// ret shape
 TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
+for(int d = 2; d < lhsRank; d++)
+retShape.push_back(lhsShape[d]);
 QualType retType = lhsType->Derived();
 if(retType != rhsType->Derived())
 Error(this, "matrix multiplication operands have incompatible data types");


@@ -219,7 +219,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
// ir::print(module, std::cout);
peephole.run(module);
dce.run(module);
align.run(module);