[Triton-MLIR][BACKEND] Make mmav1 works on basic cases (#944)
TODO: - Add more cases - Currently, we just set vec to 4 to make the basic cases pass Issue: - the vec in shared layout is different compared to master branch - when vec=1, it encounters CUDA misalignment error, it doesn't work in master branch as well - when setting vec to the value identical to master branch, the MMA works
This commit is contained in:
@@ -756,6 +756,7 @@ public:
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||
|
||||
auto newRetType = RankedTensorType::get(
|
||||
retShape, oldRetType.getElementType(),
|
||||
triton::gpu::MmaEncodingAttr::get(
|
||||
|
Reference in New Issue
Block a user