[Triton-MLIR] Add compute capability (#902)
add compute capability from python frontend to backend. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -13,7 +13,7 @@ std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
|||||||
|
|
||||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
std::unique_ptr<Pass> createTritonGPUCombineOpsPass(int computeCapability = 80);
|
||||||
|
|
||||||
std::unique_ptr<Pass> createTritonGPUVerifier();
|
std::unique_ptr<Pass> createTritonGPUVerifier();
|
||||||
|
|
||||||
|
@@ -63,6 +63,12 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
|||||||
|
|
||||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
|
||||||
"mlir::triton::TritonDialect"];
|
"mlir::triton::TritonDialect"];
|
||||||
|
|
||||||
|
let options = [
|
||||||
|
Option<"computeCapability", "compute-capability",
|
||||||
|
"int32_t", /*default*/"80",
|
||||||
|
"device compute capability">
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||||
|
@@ -15,6 +15,7 @@
|
|||||||
#include "triton/Analysis/Utility.h"
|
#include "triton/Analysis/Utility.h"
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
@@ -560,11 +561,37 @@ public:
|
|||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
static int computeCapabilityToMMAVersion(int computeCapability) {
|
||||||
|
if (computeCapability < 80) {
|
||||||
|
return 1;
|
||||||
|
} else if (computeCapability < 90) {
|
||||||
|
return 2;
|
||||||
|
} else {
|
||||||
|
assert(false && "computeCapability > 90 not supported");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t, 2>
|
||||||
|
mmaVersionToShapePerWarp(int version, const ArrayRef<int64_t> &shape,
|
||||||
|
int numWarps) {
|
||||||
|
if (version == 1) {
|
||||||
|
return {16, 16};
|
||||||
|
} else if (version == 2) {
|
||||||
|
return {16, 8};
|
||||||
|
} else {
|
||||||
|
assert(false && "version not supported");
|
||||||
|
return {0, 0};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class BlockedToMMA : public mlir::RewritePattern {
|
class BlockedToMMA : public mlir::RewritePattern {
|
||||||
|
int computeCapability;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
BlockedToMMA(mlir::MLIRContext *context)
|
BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
|
||||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
|
||||||
|
computeCapability(computeCapability) {}
|
||||||
|
|
||||||
static SmallVector<unsigned, 2>
|
static SmallVector<unsigned, 2>
|
||||||
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
|
getWarpsPerTile(const ArrayRef<int64_t> &shape, int version, int numWarps) {
|
||||||
@@ -572,7 +599,8 @@ public:
|
|||||||
// TODO: Handle one warp per row for fused matmuls
|
// TODO: Handle one warp per row for fused matmuls
|
||||||
// TODO: unsigned -> int64_t to keep things uniform
|
// TODO: unsigned -> int64_t to keep things uniform
|
||||||
SmallVector<unsigned, 2> ret = {1, 1};
|
SmallVector<unsigned, 2> ret = {1, 1};
|
||||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
SmallVector<int64_t, 2> shapePerWarp =
|
||||||
|
mmaVersionToShapePerWarp(version, shape, numWarps);
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
// TODO (@daadaada): double-check.
|
// TODO (@daadaada): double-check.
|
||||||
// original logic in
|
// original logic in
|
||||||
@@ -615,11 +643,12 @@ public:
|
|||||||
auto retShape = oldRetType.getShape();
|
auto retShape = oldRetType.getShape();
|
||||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||||
auto newRetType =
|
int version = computeCapabilityToMMAVersion(computeCapability);
|
||||||
RankedTensorType::get(retShape, oldRetType.getElementType(),
|
auto newRetType = RankedTensorType::get(
|
||||||
|
retShape, oldRetType.getElementType(),
|
||||||
triton::gpu::MmaEncodingAttr::get(
|
triton::gpu::MmaEncodingAttr::get(
|
||||||
oldRetType.getContext(), 2,
|
oldRetType.getContext(), version,
|
||||||
getWarpsPerTile(retShape, 2, numWarps)));
|
getWarpsPerTile(retShape, version, numWarps)));
|
||||||
// convert accumulator
|
// convert accumulator
|
||||||
auto oldAcc = dotOp.getOperand(2);
|
auto oldAcc = dotOp.getOperand(2);
|
||||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||||
@@ -656,6 +685,10 @@ public:
|
|||||||
class TritonGPUCombineOpsPass
|
class TritonGPUCombineOpsPass
|
||||||
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
: public TritonGPUCombineOpsBase<TritonGPUCombineOpsPass> {
|
||||||
public:
|
public:
|
||||||
|
TritonGPUCombineOpsPass() = default;
|
||||||
|
TritonGPUCombineOpsPass(int computeCapability) {
|
||||||
|
this->computeCapability = computeCapability;
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
@@ -667,7 +700,7 @@ public:
|
|||||||
patterns.add<RematerializeBackward>(context);
|
patterns.add<RematerializeBackward>(context);
|
||||||
patterns.add<RematerializeForward>(context);
|
patterns.add<RematerializeForward>(context);
|
||||||
patterns.add<MoveConvertOutOfLoop>(context);
|
patterns.add<MoveConvertOutOfLoop>(context);
|
||||||
patterns.add<BlockedToMMA>(context);
|
patterns.add<BlockedToMMA>(context, computeCapability);
|
||||||
|
|
||||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -675,6 +708,7 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
std::unique_ptr<Pass>
|
||||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
mlir::createTritonGPUCombineOpsPass(int computeCapability) {
|
||||||
|
return std::make_unique<TritonGPUCombineOpsPass>(computeCapability);
|
||||||
}
|
}
|
||||||
|
@@ -1279,8 +1279,9 @@ void init_triton_ir(py::module &&m) {
|
|||||||
self.addPass(mlir::createTritonGPUPrefetchPass());
|
self.addPass(mlir::createTritonGPUPrefetchPass());
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_combine_pass",
|
.def("add_triton_gpu_combine_pass",
|
||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self, int computeCapability) {
|
||||||
self.addPass(mlir::createTritonGPUCombineOpsPass());
|
self.addPass(
|
||||||
|
mlir::createTritonGPUCombineOpsPass(computeCapability));
|
||||||
})
|
})
|
||||||
.def("add_triton_gpu_to_llvm",
|
.def("add_triton_gpu_to_llvm",
|
||||||
[](mlir::PassManager &self) {
|
[](mlir::PassManager &self) {
|
||||||
@@ -1301,7 +1302,7 @@ void init_triton_translation(py::module &m) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"translate_triton_gpu_to_llvmir",
|
"translate_triton_gpu_to_llvmir",
|
||||||
[](mlir::ModuleOp op) {
|
[](mlir::ModuleOp op, int computeCapability) {
|
||||||
llvm::LLVMContext llvmContext;
|
llvm::LLVMContext llvmContext;
|
||||||
auto llvmModule =
|
auto llvmModule =
|
||||||
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
|
::mlir::triton::translateTritonGPUToLLVMIR(&llvmContext, op);
|
||||||
|
@@ -15,10 +15,10 @@ import sysconfig
|
|||||||
import tempfile
|
import tempfile
|
||||||
import warnings
|
import warnings
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from pathlib import Path
|
||||||
from sysconfig import get_paths
|
from sysconfig import get_paths
|
||||||
from typing import Any, Callable, Dict, Tuple, Union
|
from typing import Any, Callable, Dict, Tuple, Union
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import setuptools
|
import setuptools
|
||||||
import torch
|
import torch
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
@@ -876,23 +876,23 @@ def ast_to_ttir(fn, signature, specialization, constants):
|
|||||||
return optimize_triton_ir(mod)
|
return optimize_triton_ir(mod)
|
||||||
|
|
||||||
|
|
||||||
def ttir_to_ttgir(mod, num_warps, num_stages):
|
def ttir_to_ttgir(mod, num_warps, num_stages, compute_capability):
|
||||||
pm = _triton.ir.pass_manager(mod.context)
|
pm = _triton.ir.pass_manager(mod.context)
|
||||||
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
pm.add_convert_triton_to_tritongpu_pass(num_warps)
|
||||||
pm.enable_debug()
|
pm.enable_debug()
|
||||||
# Convert blocked layout to mma layout for dot ops so that pipeline
|
# Convert blocked layout to mma layout for dot ops so that pipeline
|
||||||
# can get shared memory swizzled correctly.
|
# can get shared memory swizzled correctly.
|
||||||
pm.add_coalesce_pass()
|
pm.add_coalesce_pass()
|
||||||
pm.add_triton_gpu_combine_pass()
|
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||||
# Prefetch must be done after pipeline pass because pipeline pass
|
# Prefetch must be done after pipeline pass because pipeline pass
|
||||||
# extracts slices from the original tensor.
|
# extracts slices from the original tensor.
|
||||||
pm.add_tritongpu_prefetch_pass()
|
pm.add_tritongpu_prefetch_pass()
|
||||||
pm.add_canonicalizer_pass()
|
pm.add_canonicalizer_pass()
|
||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
pm.add_triton_gpu_combine_pass()
|
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||||
pm.add_licm_pass()
|
pm.add_licm_pass()
|
||||||
pm.add_triton_gpu_combine_pass()
|
pm.add_triton_gpu_combine_pass(compute_capability)
|
||||||
pm.add_cse_pass()
|
pm.add_cse_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
return mod
|
return mod
|
||||||
@@ -905,13 +905,13 @@ def add_external_libs(mod, libs):
|
|||||||
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
_triton.add_external_libs(mod, list(libs.keys()), list(libs.values()))
|
||||||
|
|
||||||
|
|
||||||
def ttgir_to_llir(mod, extern_libs):
|
def ttgir_to_llir(mod, extern_libs, compute_capability):
|
||||||
if extern_libs:
|
if extern_libs:
|
||||||
add_external_libs(mod, extern_libs)
|
add_external_libs(mod, extern_libs)
|
||||||
return _triton.translate_triton_gpu_to_llvmir(mod)
|
return _triton.translate_triton_gpu_to_llvmir(mod, compute_capability)
|
||||||
|
|
||||||
|
|
||||||
def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = None) -> Tuple[str, int]:
|
def llir_to_ptx(mod: Any, compute_capability: int, ptx_version: int = None) -> Tuple[str, int]:
|
||||||
'''
|
'''
|
||||||
Translate TritonGPU module to PTX code.
|
Translate TritonGPU module to PTX code.
|
||||||
:param mod: a TritonGPU dialect module
|
:param mod: a TritonGPU dialect module
|
||||||
@@ -919,26 +919,20 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
|
|||||||
- PTX code
|
- PTX code
|
||||||
- shared memory allocation size
|
- shared memory allocation size
|
||||||
'''
|
'''
|
||||||
if compute_capability is None:
|
|
||||||
device = torch.cuda.current_device()
|
|
||||||
compute_capability = torch.cuda.get_device_capability(device)
|
|
||||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
|
||||||
if ptx_version is None:
|
if ptx_version is None:
|
||||||
_, cuda_version = path_to_ptxas()
|
_, cuda_version = path_to_ptxas()
|
||||||
ptx_version = ptx_get_version(cuda_version)
|
ptx_version = ptx_get_version(cuda_version)
|
||||||
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
|
return _triton.translate_llvmir_to_ptx(mod, compute_capability, ptx_version)
|
||||||
|
|
||||||
|
|
||||||
def ptx_to_cubin(ptx: str, device: int):
|
def ptx_to_cubin(ptx: str, compute_capability: int):
|
||||||
'''
|
'''
|
||||||
Compile TritonGPU module to cubin.
|
Compile TritonGPU module to cubin.
|
||||||
:param ptx: ptx code
|
:param ptx: ptx code
|
||||||
:param device: CUDA device
|
:param compute_capability: compute capability
|
||||||
:return: str
|
:return: str
|
||||||
'''
|
'''
|
||||||
ptxas, _ = path_to_ptxas()
|
ptxas, _ = path_to_ptxas()
|
||||||
compute_capability = torch.cuda.get_device_capability(device)
|
|
||||||
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
|
||||||
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
|
return _triton.compile_ptx_to_cubin(ptx, ptxas, compute_capability)
|
||||||
|
|
||||||
|
|
||||||
@@ -1304,6 +1298,8 @@ def read_or_execute(cache_manager, force_compile, file_name, metadata,
|
|||||||
return module, md5, True, False
|
return module, md5, True, False
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def make_stub(name, signature, constants):
|
def make_stub(name, signature, constants):
|
||||||
# name of files that are cached
|
# name of files that are cached
|
||||||
so_cache_key = make_so_cache_key(signature, constants)
|
so_cache_key = make_so_cache_key(signature, constants)
|
||||||
@@ -1328,6 +1324,7 @@ def convert_type_repr(x):
|
|||||||
return '*' + convert_type_repr(match.group(1))
|
return '*' + convert_type_repr(match.group(1))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def make_hash(fn, **kwargs):
|
def make_hash(fn, **kwargs):
|
||||||
if isinstance(fn, triton.runtime.JITFunction):
|
if isinstance(fn, triton.runtime.JITFunction):
|
||||||
configs = kwargs["configs"]
|
configs = kwargs["configs"]
|
||||||
@@ -1344,14 +1341,13 @@ def make_hash(fn, **kwargs):
|
|||||||
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
|
return hashlib.md5(Path(fn).read_text().encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
# def compile(fn, signature: str, device: int = -1, constants=dict(), num_warps: int = 4, num_stages: int = 3, extern_libs=None, configs=None):
|
||||||
def compile(fn, **kwargs):
|
def compile(fn, **kwargs):
|
||||||
# we get the kernel, i.e. the first function generated in the module
|
# we get the kernel, i.e. the first function generated in the module
|
||||||
# if fn is not a JITFunction, then it
|
# if fn is not a JITFunction, then it
|
||||||
# has to be a path to a file
|
# has to be a path to a file
|
||||||
context = _triton.ir.context()
|
context = _triton.ir.context()
|
||||||
asm, md5 = dict(), dict()
|
asm = dict()
|
||||||
constants = kwargs.get("constants", dict())
|
constants = kwargs.get("constants", dict())
|
||||||
if isinstance(fn, triton.runtime.JITFunction):
|
if isinstance(fn, triton.runtime.JITFunction):
|
||||||
configs = kwargs.get("configs", None)
|
configs = kwargs.get("configs", None)
|
||||||
@@ -1389,6 +1385,8 @@ def compile(fn, **kwargs):
|
|||||||
num_stages = kwargs.get("num_stages", 3)
|
num_stages = kwargs.get("num_stages", 3)
|
||||||
extern_libs = kwargs.get("extern_libs", dict())
|
extern_libs = kwargs.get("extern_libs", dict())
|
||||||
device = kwargs.get("device", torch.cuda.current_device())
|
device = kwargs.get("device", torch.cuda.current_device())
|
||||||
|
compute_capability = torch.cuda.get_device_capability(device)
|
||||||
|
compute_capability = compute_capability[0] * 10 + compute_capability[1]
|
||||||
# load metadata if any
|
# load metadata if any
|
||||||
metadata = None
|
metadata = None
|
||||||
if fn_cache_manager.has_file(f'{name}.json'):
|
if fn_cache_manager.has_file(f'{name}.json'):
|
||||||
@@ -1398,17 +1396,17 @@ def compile(fn, **kwargs):
|
|||||||
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
metadata = {"num_warps": num_warps, "num_stages": num_stages, "ctime": dict()}
|
||||||
# build compilation stages
|
# build compilation stages
|
||||||
stages = {
|
stages = {
|
||||||
"ast" : (lambda path: fn, None),
|
"ast": (lambda path: fn, None),
|
||||||
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
"ttir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||||
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
|
||||||
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
"ttgir": (lambda path: _triton.ir.parse_mlir_module(path, context),
|
||||||
lambda src: ttir_to_ttgir(src, num_warps, num_stages)),
|
lambda src: ttir_to_ttgir(src, num_warps, num_stages, compute_capability)),
|
||||||
"llir": (lambda path: Path(path).read_bytes(),
|
"llir": (lambda path: Path(path).read_bytes(),
|
||||||
lambda src: ttgir_to_llir(src, extern_libs)),
|
lambda src: ttgir_to_llir(src, extern_libs, compute_capability)),
|
||||||
"ptx": (lambda path: Path(path).read_text(),
|
"ptx": (lambda path: Path(path).read_text(),
|
||||||
llir_to_ptx),
|
lambda src: llir_to_ptx(src, compute_capability)),
|
||||||
"cubin": (lambda path: Path(path).read_bytes(),
|
"cubin": (lambda path: Path(path).read_bytes(),
|
||||||
lambda src: ptx_to_cubin(src, device))
|
lambda src: ptx_to_cubin(src, compute_capability))
|
||||||
}
|
}
|
||||||
first_stage = list(stages.keys()).index(ext)
|
first_stage = list(stages.keys()).index(ext)
|
||||||
asm = dict()
|
asm = dict()
|
||||||
|
@@ -37,20 +37,21 @@ if __name__ == '__main__':
|
|||||||
print(module.str())
|
print(module.str())
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
if not args.sm:
|
||||||
|
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
||||||
|
|
||||||
# triton-ir -> triton-gpu-ir
|
# triton-ir -> triton-gpu-ir
|
||||||
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3)
|
module = triton.compiler.ttir_to_ttgir(module, num_warps=4, num_stages=3, compute_capability=args.sm)
|
||||||
if args.target == 'triton-gpu-ir':
|
if args.target == 'triton-gpu-ir':
|
||||||
print(module.str())
|
print(module.str())
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
# triton-gpu-ir -> llvm-ir
|
# triton-gpu-ir -> llvm-ir
|
||||||
module = triton.compiler.ttgir_to_llir(module, extern_libs=None)
|
module = triton.compiler.ttgir_to_llir(module, extern_libs=None, compute_capability=args.sm)
|
||||||
if args.target == 'llvm-ir':
|
if args.target == 'llvm-ir':
|
||||||
print(module)
|
print(module)
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
if not args.sm:
|
|
||||||
raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation")
|
|
||||||
if not args.ptx_version:
|
if not args.ptx_version:
|
||||||
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation")
|
||||||
|
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir | FileCheck %s
|
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
|
||||||
|
|
||||||
// == LLVM IR check begin ==
|
// == LLVM IR check begin ==
|
||||||
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
|
||||||
|
Reference in New Issue
Block a user