[lang] added support for batched matrix multiplication
This commit is contained in:
@@ -96,7 +96,6 @@ void allocation::run(ir::module &mod) {
|
||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||
}
|
||||
|
||||
|
||||
// Save maximum size of induced memory space
|
||||
allocated_size_ = 0;
|
||||
for(layout_t* x: V)
|
||||
|
@@ -77,11 +77,6 @@ void axes::update_graph_dot(ir::instruction *i) {
|
||||
// add edges between result and accumulator
|
||||
for(unsigned d = 0; d < shapes.size(); d++)
|
||||
graph_.add_edge({dot, d}, {D, d});
|
||||
// add edge for batch dimension
|
||||
for(unsigned d = 2; d < shapes.size(); d++){
|
||||
graph_.add_edge({dot, d}, {A, d});
|
||||
graph_.add_edge({dot, d}, {B, d});
|
||||
}
|
||||
}
|
||||
|
||||
void axes::update_graph_elementwise(ir::instruction *i) {
|
||||
|
@@ -300,6 +300,10 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
||||
}
|
||||
std::vector<int> col = {0, 1};
|
||||
std::vector<int> row = {1, 0};
|
||||
for(size_t s = 2; s < shapes.size(); s++){
|
||||
col.push_back(s);
|
||||
row.push_back(s);
|
||||
}
|
||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||
if(is_nonhmma_dot_a)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
#include "triton/codegen/analysis/liveness.h"
|
||||
#include "triton/codegen/analysis/layout.h"
|
||||
#include "triton/ir/function.h"
|
||||
@@ -37,12 +38,13 @@ void liveness::run(ir::module &mod) {
|
||||
}
|
||||
// compute intervals
|
||||
unsigned start = INT32_MAX;
|
||||
for(ir::value *v: layout->values)
|
||||
if(indices.find(v) != indices.end())
|
||||
start = std::min(start, indices.at(v));
|
||||
unsigned end = 0;
|
||||
for(ir::user *u: users)
|
||||
if(indices.find(u) != indices.end()){
|
||||
start = std::min(start, indices.at(u));
|
||||
end = std::max(end, indices.at(u));
|
||||
}
|
||||
if(indices.find(u) != indices.end())
|
||||
end = std::max(end, indices.at(u));
|
||||
intervals_[layout] = segment{start, end};
|
||||
}
|
||||
|
||||
|
@@ -1,4 +1,5 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include "triton/ir/basic_block.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/type.h"
|
||||
|
@@ -471,12 +471,20 @@ void BinaryOp::MatmulOpTypeChecking() {
|
||||
auto rhsShape = rhsType->Shape();
|
||||
size_t lhsRank = lhsShape.size();
|
||||
size_t rhsRank = rhsShape.size();
|
||||
if(lhsRank != 2 || rhsRank != 2)
|
||||
Error(this, "matrix multiplication operands must have rank 2");
|
||||
if(lhsRank != rhsRank)
|
||||
Error(this, "matrix multiplication operands have incompatible rank"
|
||||
"%d and %d", lhsRank, rhsRank);
|
||||
for(int d = 2; d < lhsRank; d++)
|
||||
if(lhsShape[d] != rhsShape[d])
|
||||
Error(this, "matrix multiplication operands have incompatible batch dimension"
|
||||
"%d and %d for axis %d", lhsShape[d], rhsShape[d], d);
|
||||
if(lhsShape[1] != rhsShape[0])
|
||||
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
||||
" %d and %d", lhsShape[1], rhsShape[0]);
|
||||
// ret shape
|
||||
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
||||
for(int d = 2; d < lhsRank; d++)
|
||||
retShape.push_back(lhsShape[d]);
|
||||
QualType retType = lhsType->Derived();
|
||||
if(retType != rhsType->Derived())
|
||||
Error(this, "matrix multiplication operands have incompatible data types");
|
||||
|
@@ -219,7 +219,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
codegen::transform::cts cts;
|
||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||
// run passes
|
||||
// ir::print(module, std::cout);
|
||||
// ir::print(module, std::cout);
|
||||
peephole.run(module);
|
||||
dce.run(module);
|
||||
align.run(module);
|
||||
|
Reference in New Issue
Block a user