89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
#
|
|
# BSD 3-Clause License
|
|
#
|
|
# This file is part of the Basalt project.
|
|
# https://gitlab.com/VladyslavUsenko/basalt.git
|
|
#
|
|
# Copyright (c) 2019-2021, Vladyslav Usenko and Nikolaus Demmel.
|
|
# All rights reserved.
|
|
#
|
|
import numbers
|
|
import os
|
|
import scipy.stats
|
|
import numpy as np
|
|
|
|
from pylatex import Subsection, FootnoteText, Tabular, NoEscape, escape_latex
|
|
from pylatex.utils import italic, bold
|
|
|
|
from .containers import ExperimentsTable
|
|
from .util import best_two_non_repeating
|
|
|
|
|
|
class SummarizeSequencesTable(ExperimentsTable):
    """Table with one aggregate value per metric and experiment, combined over all sequences.

    For every metric, the values of all sequences are reduced to a single number per
    experiment: the geometric mean if ``metric.geometric_mean`` is set, otherwise the
    arithmetic mean. Within each metric row, the best and second-best (rounded) aggregates
    are rendered bold and italic respectively. The table is optionally exported via
    ``Tabular.generate_tex`` and is always appended inside an unnumbered subsection.
    """

    def __init__(self, exps, spec, show_values_failed_runs, seq_displayname_mapping, export_basepath):
        super().__init__(exps, spec, show_values_failed_runs, seq_displayname_mapping, export_basepath)

        self.doit()

    def doit(self):
        """Build the summary table, optionally export it, and append it to this container."""
        values = self._collect_metric_values()
        means = self._aggregate_over_sequences(values)

        t = Tabular('l' + 'c' * self.num_exps)

        t.add_hline()
        self._add_header_row(t)
        t.add_hline()
        self._add_metric_rows(t, means)
        t.add_hline()

        if self.spec.export_latex:
            os.makedirs(self.export_basepath, exist_ok=True)
            t.generate_tex(os.path.join(self.export_basepath, self.spec.export_latex))

        with self.create(Subsection(self.spec.name, numbering=False)) as p:
            p.append(FootnoteText(t))

    def _collect_metric_values(self):
        """Return all metric values as an array of shape (num_metrics, num_seqs, num_exps)."""
        values = np.empty((self.num_metrics, self.num_seqs, self.num_exps))
        for i, seq in enumerate(self.seq_names):
            for j, s in enumerate(self.experiment_specs):
                # get_metrics yields one value per metric for this (experiment, sequence) pair.
                values[:, i, j] = np.array(self.get_metrics(self.exps[s.name], seq, s.it))
        return values

    def _aggregate_over_sequences(self, values):
        """Reduce ``values`` over the sequence axis per metric; returns shape (num_metrics, num_exps)."""
        means = np.empty((self.num_metrics, self.num_exps))
        for i, m in enumerate(self.metrics):
            reduce_fn = scipy.stats.gmean if m.geometric_mean else np.mean
            means[i, :] = reduce_fn(values[i, :, :], axis=0)
        return means

    def _header_cell(self, text):
        """Return ``text`` as-is if header escaping is enabled, otherwise mark it as raw LaTeX."""
        return text if self.spec.escape_latex_header else NoEscape(text)

    def _add_header_row(self, t):
        """Append the header row (spec header + one experiment display name per column) to ``t``."""
        names = [self._header_cell(s.display_name(self.exps[s.name])) for s in self.experiment_specs]
        if self.spec.rotate_header:
            # Rotate each experiment label by 90 degrees to keep wide tables narrow.
            names = [NoEscape(r"\rotatebox{90}{%s}" % escape_latex(name)) for name in names]
        t.add_row([self.spec.header] + names)

    def _add_metric_rows(self, t, means):
        """Append one row per metric to ``t``, highlighting the top two values of each row."""
        for i, m in enumerate(self.metrics):
            # Compare on rounded values so that ties at display precision are treated equally.
            row_values = np.around(means[i, :], m.decimals)
            top_values = best_two_non_repeating(row_values, reverse=m.larger_is_better)
            decimals = m.effective_display_decimals()

            row = [m.display_name]
            for v in row_values:
                # TODO: use NoEscape only if certain flag is enabled?
                row.append(
                    NoEscape(self._render_metric(v, top_values[0], top_values[1], decimals, m.format_string)))
            t.add_row(row)

    @staticmethod
    def _render_metric(value, best, second, decimals, format_string):
        """Format one metric value for a table cell.

        Numeric values are formatted with ``format_string`` (value passed positionally,
        ``decimals`` as the ``prec`` keyword). They are made bold if equal to ``best`` and
        italic if equal to ``second``. Non-numeric values are returned unchanged.
        """
        if not isinstance(value, numbers.Number):
            return value

        rendered = format_string.format(value, prec=decimals)
        # The caller wraps every cell in NoEscape, i.e. the formatted string is raw LaTeX.
        # Highlighting must not escape it either; otherwise LaTeX in format_string would
        # render correctly in normal cells but be mangled in the best/second-best cells.
        if value == best:
            return bold(rendered, escape=False)
        if value == second:
            return italic(rendered, escape=False)
        return rendered
|