[CODEGEN] Fixed bug that caused conditional operator to not always properly mask load operations.
Also includes minor improvements to the benchmarking infrastructure.
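At the IR level, the fix means a conditional expression of the form cond ? *px : y is no longer lowered as an unconditional load followed by a select: loads emitted while visiting the two branches are turned into masked loads, and a new peephole folds the resulting select/masked-load pair. A minimal sketch of the intended builder sequence, written against the ir::builder calls that appear in this diff and assuming illustrative names bld, px, cond, y and elt_ty (this is not the committed code, which is in the hunks below):

    // Old lowering (could read out-of-bounds memory whenever cond is false):
    //   %v   = load %px
    //   %res = select %cond, %v, %y
    // New lowering: the load is predicated on the condition as it is generated.
    // bld, px, cond, y, elt_ty are illustrative names, not identifiers from this commit.
    ir::value* else_val = ir::undef_value::get(elt_ty);                // filler for masked-off elements
    ir::value* loaded   = bld->create_masked_load(px, cond, else_val); // load masked with the condition
    ir::value* result   = bld->create_select(cond, loaded, y);
    // peephole::rewrite_select_masked_load (added below) then folds the pair into
    //   bld->create_masked_load(px, cond, y)
    // whenever the select predicate and the load mask are the same value.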
@@ -32,6 +32,7 @@ private:
   bool rewrite_mult(ir::instruction *value, ir::builder& builder);
   bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);
+  bool rewrite_select_masked_load(ir::instruction *value, ir::builder& builder);
   bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder);
 
 private:
@@ -748,6 +748,9 @@ private:
 
 public:
   static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
+  value* get_pred_op() { return get_operand(0); }
+  value* get_if_value_op() { return get_operand(1); }
+  value* get_else_value_op() { return get_operand(2); }
 };
 
 //===----------------------------------------------------------------------===//
@@ -1616,20 +1616,22 @@ void generator::visit_make_range(ir::make_range* x) {
   }
 }
 
-void generator::visit_undef_value(ir::undef_value *ud) {
-  vals_[ud][{}] = llvm::UndefValue::get(cvt(ud->get_type()));
+void generator::visit_undef_value(ir::undef_value *x) {
+  Type* ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = llvm::UndefValue::get(ty);
 }
 
-void generator::visit_constant_int(ir::constant_int *cst){
-  Type *ty = cvt(cst->get_type()->get_scalar_ty());
-  vals_[cst][{}] = ConstantInt::get(ty, cst->get_value());
+void generator::visit_constant_int(ir::constant_int *x){
+  Type *ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = ConstantInt::get(ty, x->get_value());
 }
 
-void generator::visit_constant_fp(ir::constant_fp *cst){
-  Type *ty = cvt(cst->get_type()->get_scalar_ty());
-  vals_[cst][{}] = ConstantFP::get(ty, cst->get_value());
+void generator::visit_constant_fp(ir::constant_fp *x){
+  Type *ty = cvt(x->get_type()->get_scalar_ty());
+  for(indices_t idx: idxs_.at(x))
+    vals_[x][idx] = ConstantFP::get(ty, x->get_value());
 }
 
 void generator::visit_alloc_const(ir::alloc_const *alloc) {
@@ -193,6 +193,22 @@ bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder){
   return false;
 }
 
+bool peephole::rewrite_select_masked_load(ir::instruction *value, ir::builder& builder){
+  auto select = dynamic_cast<ir::select_inst*>(value);
+  if(!select)
+    return false;
+  auto if_value = dynamic_cast<ir::masked_load_inst*>(select->get_if_value_op());
+  if(!if_value)
+    return false;
+  if(select->get_pred_op() != if_value->get_mask_operand())
+    return false;
+  builder.set_insert_point(select);
+  ir::value* new_load = builder.create_masked_load(if_value->get_pointer_operand(),
+                                                   if_value->get_mask_operand(),
+                                                   select->get_else_value_op());
+  select->replace_all_uses_with(new_load);
+  return true;
+}
 
 void peephole::run(ir::module &mod) {
   ir::builder &builder = mod.get_builder();
@@ -230,6 +246,7 @@ void peephole::run(ir::module &mod) {
       // was_modified = was_modified || rewrite_trans_phi(i, builder);
       was_modified = was_modified || rewrite_unit_red(i, builder);
       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
+      was_modified = was_modified || rewrite_select_masked_load(i, builder);
       if(tgt_->as_nvidia()->sm() >= 80)
         was_modified = was_modified || rewrite_load_to_shared(i, builder);
       if(was_modified)
@@ -267,23 +267,56 @@ void Generator::VisitTransOp(TransOp *trans) {
 }
 
 void Generator::VisitConditionalOp(ConditionalOp* condOp) {
-// auto &instructions = bld_->get_insert_block()->get_inst_list();
+  auto &instructions = bld_->get_insert_block()->get_inst_list();
   VisitExpr(condOp->cond_);
-  ir::value* cond = ret_;
+  ir::value* true_cond = ret_;
+  ir::instruction* start = instructions.back();
   VisitExpr(condOp->exprTrue_);
   ir::value* true_val = ret_;
   VisitExpr(condOp->exprFalse_);
   ir::value* false_val = ret_;
-  if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
-    if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
-      false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
-    ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
-    ld->replace_all_uses_with(new_ld);
-    ld->erase_from_parent();
-    return set_ret(new_ld);
+  auto begin = std::find(instructions.begin(), instructions.end(), start);
+  bool is_in_true_cond = true;
+  for(auto it = begin; it != instructions.end(); it++){
+    ir::instruction* instr = *it;
+    // we mask load with `cond` when used to compute true_value
+    // we mask load with `!cond` when used to compute false_value
+    if(auto ld = dynamic_cast<ir::unmasked_load_inst*>(instr)){
+      bld_->set_insert_point(ld);
+      ir::type* ty = ld->get_type();
+      ir::value* cond = is_in_true_cond ? true_cond : true_cond;
+      ir::value* ptr = ld->get_pointer_operand();
+      ir::value* else_val = ir::undef_value::get(ty);
+      ir::value* masked_ld = bld_->create_masked_load(ptr, cond, else_val);
+      ld->replace_all_uses_with(masked_ld);
+      ld->erase_from_parent();
+      if(true_val == ld)
+        true_val = masked_ld;
+      if(false_val == ld)
+        false_val = masked_ld;
+      it = std::find(instructions.begin(), instructions.end(), masked_ld);
+    }
+    if(instr == true_val)
+      is_in_true_cond = false;
   }
-  return set_ret(bld_->create_select(cond, true_val, false_val));
-// return error_not_implemented();
+  bld_->set_insert_point(bld_->get_insert_block());
+  return set_ret(bld_->create_select(true_cond, true_val, false_val));
+
+// VisitExpr(condOp->cond_);
+// ir::value* cond = ret_;
+// VisitExpr(condOp->exprTrue_);
+// ir::value* true_val = ret_;
+// VisitExpr(condOp->exprFalse_);
+// ir::value* false_val = ret_;
+// if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
+//   if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
+//     false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
+//   ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val);
+//   ld->replace_all_uses_with(new_ld);
+//   ld->erase_from_parent();
+//   return set_ret(new_ld);
+// }
+// return set_ret(bld_->create_select(cond, true_val, false_val));
 }
 
 void Generator::VisitFuncCall(FuncCall* funcCall) {
@@ -156,6 +156,7 @@ std::tuple<std::shared_ptr<driver::module>,
   layouts.run(ir);
   peephole.run(ir);
   dce.run(ir);
+//  ir::print(ir, std::cout);
   if(target->is_gpu())
     cts.run(ir);
   align.run(ir);
@@ -2,33 +2,39 @@ import triton
 import torch
 import os
 
 
 def rounded_linspace(low, high, steps, div):
     ret = torch.linspace(low, high, steps)
     ret = (ret.int() + div - 1) // div * div
     ret = torch.unique(ret)
     return list(map(int, ret))
 
 
 # Square benchmarks
 nt = {False: "n", True: "t"}
 square_confs = [
     triton.testing.Benchmark(
         x_names=["M", "N", "K"],
-        x_vals=rounded_linspace(512, 8192, 17, 128),
+        x_vals=rounded_linspace(512, 8192, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
         ylabel="TFLOPS",
         loglog=False,
         plot_name=f"matmul-square-{nt[AT]}{nt[BT]}",
-        args={"AT": AT, "BT": BT, "dtype": torch.float16},
-    ) for AT in [False, True] for BT in [False, True]
+        args={
+            "AT": AT,
+            "BT": BT,
+            "dtype": torch.float16
+        },
+    ) for AT in [False] for BT in [False]
 ]
 
 # Transformer training benchmarks
 transformer_confs = [
     triton.testing.Benchmark(
         x_names=[x],
-        x_vals = rounded_linspace(NK//16, NK, 33, 128),
+        x_vals = rounded_linspace(NK//16, NK, 32, 128),
         y_name="provider",
         y_vals=["torch", "triton", "cutlass"],
         y_lines=["Torch", "Triton", "CUTLASS"],
@@ -41,21 +47,21 @@ transformer_confs = [
     for M in [2048]
 ]
 
-@triton.testing.perf_report(square_confs)
-def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
+@triton.testing.perf_report(transformer_confs)
+def bench_op(M, N, K, AT, BT, dtype, provider, warmup=10, rep=50):
     a = torch.rand((K, M) if AT else (M, K), device="cuda", dtype=dtype)
     b = torch.rand((N, K) if BT else (K, N), device="cuda", dtype=dtype)
     if AT: a = a.t()
     if BT: b = b.t()
     num_flops = 2 * M * N * K
+    tflops = lambda ms: 2. * M * N * K / ms * 1e-9
     if provider == "torch":
-        torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
-        torch_tflops = num_flops / torch_ms * 1e-9
-        return torch_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "triton":
-        triton_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
-        triton_tflops = num_flops / triton_ms * 1e-9
-        return triton_tflops
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton.ops.matmul(a, b), warmup=warmup, rep=rep)
+        return tflops(ms), tflops(max_ms), tflops(min_ms)
     if provider == "cutlass" and "CUTLASS_PROFILER" in os.environ:
         import subprocess
         import tempfile
@@ -87,6 +93,6 @@ def bench_op(M, N, K, AT, BT, dtype, provider, warmup=5, rep=40):
         subprocess.run(cmd, stdout=subprocess.PIPE)
         # read CSV output
         df_c = pd.read_csv(f"{fname}.gemm.csv")
-        cutlass_tflops = max(df_c["GFLOPs"]) / 1e3
-        return cutlass_tflops
+        tflops = max(df_c["GFLOPs"]) / 1e3
+        return tflops
     return None
@@ -26,7 +26,7 @@ __global__ void backward(TYPE *neg_logprobs, long *indices, TYPE *dneg_logprobs,
   TYPE local_dn = *(dneg_logprobs + row);
   // We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
   // and we have -log(p[k]) stored, so this is easy
-  TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * ? (check)px) : 0;
+  TYPE intermediate[TILE] = check ? exp(-(float[TILE]) * px) : 0;
   // selected_logit_idx is selected logit index for our token
   bool find_one[TILE] = ((0 ... TILE) == local_ind);
   intermediate = intermediate - ((TYPE[TILE])find_one);
@@ -26,35 +26,34 @@ def allclose(x, y):
     return err < tol
 
 
-def do_bench(fn, warmup=10, rep=50, grad_to_none=None, clear_l2=False):
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    # warmup to put the clock in a stable regime
-    ret = fn()
-    for i in range(warmup):
-        fn()
-    torch.cuda.synchronize()
-    total_ms = 0
+def do_bench(fn, warmup=10, rep=50, grad_to_none=None):
     # We maintain a buffer of 256 MB that we clear
     # before each kernel call to make sure that the L2
     # doesn't contain any input data before the run
+    start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
+    end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
     cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
-    for i in range(rep):
+    for i in range(warmup + rep):
         # we don't want `fn` to accumulate gradient values
         # if it contains a backward pass. So we clear the
        # provided gradients
         if grad_to_none is not None:
             grad_to_none.grad = None
-        # reset L2
+        # we clear the L2 cache before each run
         cache.zero_()
         # record time of `fn`
-        start_event.record()
+        if i >= warmup:
+            start_event[i - warmup].record()
         fn()
-        end_event.record()
-        torch.cuda.synchronize()
-        total_ms += start_event.elapsed_time(end_event)
-    # return the average runtime of `fn`
-    return total_ms / rep
+        if i >= warmup:
+            end_event[i - warmup].record()
+    torch.cuda.synchronize()
+    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
+    q = torch.quantile(times, torch.tensor([0.1, 0.5, 0.9]))
+    min_ms = q[0].item()
+    mean_ms = q[1].item()
+    max_ms = q[2].item()
+    return mean_ms, min_ms, max_ms
 
 
 class Benchmark:
@@ -79,22 +78,42 @@ class Mark:
         import matplotlib.pyplot as plt
         import pandas as pd
         import os
-        df = pd.DataFrame(columns=[bench.x_names[0]] + bench.y_lines)
+        y_mean = bench.y_lines
+        y_min = [f'{x}-min' for x in bench.y_lines]
+        y_max = [f'{x}-max' for x in bench.y_lines]
+        df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max)
         for x in bench.x_vals:
             x_args = {x_name: x for x_name in bench.x_names}
-            row = [self.fn(**x_args, **{bench.y_name: y}, **bench.args) for y in bench.y_vals]
-            df.loc[len(df)] = [x] + row
+            row_mean, row_min, row_max = [], [], []
+            for y in bench.y_vals:
+                ret = self.fn(**x_args, **{bench.y_name: y}, **bench.args)
+                try:
+                    y_mean, y_min, y_max = ret
+                except TypeError:
+                    y_mean, y_min, y_max = ret, None, None
+                row_mean += [y_mean]
+                row_min += [y_min]
+                row_max += [y_max]
+            df.loc[len(df)] = [x] + row_mean + row_min + row_max
         if with_plot and bench.plot_name:
+            plt.figure()
+            ax = plt.subplot()
             xlabel = " = ".join(bench.x_names)
-            plot = df.plot(x=bench.x_names[0], y=bench.y_lines)
-            plot.set_xlabel(xlabel)
-            plot.set_ylabel(bench.ylabel)
-            plot.set_title(bench.plot_name)
-            plot.set_xscale("log" if bench.loglog else "linear")
-            plot.set_yscale("log" if bench.loglog else "linear")
+            x = bench.x_names[0]
+            for y in bench.y_lines:
+                y_min, y_max = df[y + '-min'], df[y + '-max']
+                ax.plot(df[x], df[y], label=y)
+                if y_min is not None and y_max is not None:
+                    ax.fill_between(df[x], y_min, y_max, alpha=0.5)
+            ax.legend()
+            ax.set_xlabel(xlabel)
+            ax.set_ylabel(bench.ylabel)
+            ax.set_title(bench.plot_name)
+            ax.set_xscale("log" if bench.loglog else "linear")
+            ax.set_yscale("log" if bench.loglog else "linear")
             plt.savefig(os.path.join(result_path, f"{bench.plot_name}.png"))
-        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"))
+        df = df[[bench.x_names[0]] + bench.y_lines]
+        df.to_csv(os.path.join(result_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False)
 
     def run(self, result_path, with_plot):
         with open(os.path.join(result_path, "results.html"), "w") as html: