import glob
import inspect
import logging
import os
import re
import sys
from mrestimator import utility as ut
log = ut.log
from mrestimator import CoefficientResult
from mrestimator import FitResult
from mrestimator import __version__
import numpy as np
import matplotlib
# Fall back to the file-only Agg backend when no display server is reachable.
# $DISPLAY (X11) / $WAYLAND_DISPLAY only signal a display on Linux/BSD-like
# systems; Windows and macOS do not use them, so their absence there must not
# force Agg (which would silently disable interactive figures).
if (
    not sys.platform.startswith(("win32", "darwin"))
    and os.environ.get("DISPLAY", "") == ""
    and os.environ.get("WAYLAND_DISPLAY", "") == ""
):
    log.info(
        "No display found. Using non-interactive Agg backend for plotting. "
        + "Check your $DISPLAY environment variable."
    )
    matplotlib.use("Agg")
import matplotlib.pyplot as plt
class OutputHandler:
    """
    The OutputHandler can be used to export results and to
    create charts with
    timeseries, correlation-coefficients or fits.

    The main concept is to have one handler per plot. It contains
    functions to add content into an existing matplotlib axis (subplot),
    or, if not provided, creates a new figure.
    Most importantly, it also exports plaintext of the respective source
    material so figures are reproducible.

    Note: If you want to have a live preview of the figures that are
    automatically generated with matplotlib, you HAVE to assign the result
    of `mre.OutputHandler()` to a variable. Otherwise, the created figures
    are not retained and vanish instantly.

    Attributes
    ----------
    rks: list
        List of the :obj:`CoefficientResult`. Added with `add_coefficients()`

    fits: list
        List of the :obj:`FitResult`. Added with `add_fit()`

    Example
    -------
    .. code-block:: python

        import numpy as np
        import matplotlib.pyplot as plt
        import mrestimator as mre

        bp = mre.simulate_branching(numtrials=15)
        rk1 = mre.coefficients(bp, method='trialseparated',
            desc='T')
        rk2 = mre.coefficients(bp, method='stationarymean',
            desc='S')

        m1 = mre.fit(rk1)
        m2 = mre.fit(rk2)

        # create a new handler by passing a list of elements
        out = mre.OutputHandler([rk1, m1])

        # manually add elements
        out.add_coefficients(rk2)
        out.add_fit(m2)

        # save the plot and meta to disk
        out.save('~/test')
    ..

    Working with existing figures:

    .. code-block:: python

        # create figure with subplots
        fig = plt.figure()
        ax1 = fig.add_subplot(221)
        ax2 = fig.add_subplot(222)
        ax3 = fig.add_subplot(223)
        ax4 = fig.add_subplot(224)

        # show each chart in its own subplot
        mre.OutputHandler(rk1, ax1)
        mre.OutputHandler(rk2, ax2)
        mre.OutputHandler(m1, ax3)
        mre.OutputHandler(m2, ax4)

        # matplotlib customisations
        myaxes = [ax1, ax2, ax3, ax4]
        for ax in myaxes:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

        plt.show(block=False)

        # hide a legend
        ax1.legend().set_visible(False)
        plt.draw()
    ..
    """

    def __init__(self, data=None, ax=None):
        """
        Construct a new OutputHandler, optionally you can provide
        a list of elements to plot.

        ToDo: Make the OutputHandler talk to each other so that
        when one is written (possibly linked to others via one figure)
        all subfigure meta data is exported, too.

        Parameters
        ----------
        data : list, CoefficientResult, FitResult or ~numpy.ndarray, optional
            List of the elements to plot/export. Can be added later.
            A single element is treated like a one-element list.
            Arrays are added as time series via `add_ts()`.

        ax : ~matplotlib.axes.Axes, optional
            An instance of a matplotlib axes (a subplot) to plot into.
            If not provided, a new figure is created; that figure is
            closed again when this handler is garbage collected.

        Raises
        ------
        TypeError
            If `ax` is neither None nor a matplotlib Axes.
        ValueError
            If an element of `data` has an unsupported type.
        """
        # set first: __del__ reads this even if __init__ raises below
        self.axshared = True
        if isinstance(ax, matplotlib.axes.Axes):
            self.ax = ax
        elif ax is None:
            _, self.ax = plt.subplots()
            # we own this figure, __del__ closes it
            self.axshared = False
            # everything below zorder 0 gets rastered to one layer
            self.ax.set_rasterization_zorder(0)
        else:
            msg = (
                "Argument 'ax' provided to OutputHandler is not "
                + "an instance of matplotlib.axes.Axes\n"
                + "\tIn case you want to add multiple items, pass them in a list "
                + "as the first argument"
            )
            log.error(msg)
            raise TypeError(msg)

        self.rks = []
        self.rklabels = []
        self.rkcurves = []
        self.rkkwargs = []
        self.fits = []
        self.fitlabels = []
        self.fitcurves = []  # list of lists of drawn curves for each fit
        self.fitkwargs = []
        self.dt = 1
        self.dtunit = None
        self.type = None
        self.xdata = None
        self.ydata = []  # list of 1d np arrays
        self.xlabel = None
        self.ylabels = []

        # single argument to list
        if isinstance(data, (CoefficientResult, FitResult, np.ndarray)):
            data = [data]

        for d in data or []:
            if isinstance(d, CoefficientResult):
                self.add_coefficients(d)
            elif isinstance(d, FitResult):
                self.add_fit(d)
            elif isinstance(d, np.ndarray):
                self.add_ts(d)
            else:
                msg = (
                    "Please provide a list containing\n"
                    + "\tCoefficientResults, FitResults and/or numpy arrays\n"
                )
                log.error(msg)
                raise ValueError(msg)

    def __del__(self):
        """
        Close the figure created by this handler when it is no longer used.
        Figures passed in via `ax` are left open.
        """
        # getattr: __init__ may have raised before the attribute existed
        if getattr(self, "axshared", True):
            return
        try:
            plt.close(self.ax.figure)
        except Exception:
            # best effort, a finalizer must not raise
            log.debug("Exception passed", exc_info=True)

    def set_xdata(self, data=None, dt=1, dtunit=None):
        """
        Adjust xdata of the plot, matching the input value.
        Returns an array of indices matching the incoming indices to
        already present ones. Automatically called when adding content.

        If you want to customize the plot range, add all the content
        and use matplotlibs
        :obj:`~matplotlib.axes.Axes.set_xlim` function once at the end.
        (`set_xdata()` also manages meta data and can only *increase* the
        plot range)

        Parameters
        ----------
        data : ~numpy.array, optional
            x-values to plot the fits for. `data` does not need to be
            spaced equally but is assumed to be sorted. If None and no
            x-values are set yet, a default range `0..1500` is used.

        dt : float
            check if existing data can be mapped to the new, provided `dt`
            or the other way around. `set_xdata()` pads
            undefined areas with `nan`.

        dtunit : str
            check if the new `dtunit` matches the one set previously. Any
            padding to match `dt` is only done if `dtunits` are the same,
            otherwise the plot falls back to using generic integer steps.

        Returns
        -------
        : :class:`~numpy.array`
            containing the indices where the `data` given to this function
            coincides with (possibly) already existing data that was
            added/plotted before.

        Example
        -------
        .. code-block:: python

            out = mre.OutputHandler()

            # 100 intervals of 2ms
            out.set_xdata(np.arange(0,100), dt=2, dtunit='ms')

            # increase resolution to 1ms for the first 50ms
            # this changes the existing structure in the meta data. also
            # the axis of `out` is not equally spaced anymore
            fiftyms = np.arange(0,50)
            out.set_xdata(fiftyms, dt=1, dtunit='ms')

            # data with larger intervals is less dense, the returned list
            # tells you which index in `out` belongs to every index
            # in `xdat`
            xdat = np.arange(0,50)
            ydat = np.random.random_sample(50)
            inds = out.set_xdata(xdat, dt=4, dtunit='ms')

            # to pad `ydat` to match the axis of `out`:
            temp = np.full(out.xdata.size, np.nan)
            temp[inds] = ydat
        ..
        """
        log.debug("OutputHandler.set_xdata()")
        # float copy: never alter the caller's array, and allow the in-place
        # rescaling with (float) factors below
        xdata = None if data is None else np.array(data, dtype="float64")

        # nothing set so far, no argument provided, return some default
        if self.xdata is None and xdata is None:
            self.xdata = np.arange(0, 1501, dtype="float64")
            self.dtunit = dtunit
            self.dt = dt
            return np.arange(0, 1501)

        # set x for the first time
        if self.xdata is None:
            self.xdata = xdata
            self.dtunit = dtunit
            self.dt = dt
            return np.arange(0, self.xdata.size)

        # no new data provided, no need to call this
        if xdata is None:
            log.debug(
                "set_xdata() called without argument when "
                + "xdata is already set. Nothing to adjust"
            )
            return np.arange(0, self.xdata.size)

        # compare dtunits: adopt a new unit if none is set yet, warn on mismatch
        if dtunit is not None:
            if self.dtunit is None:
                self.dtunit = dtunit
            elif dtunit != self.dtunit:
                log.warning(
                    "'dtunit' does not match across added elements, "
                    + "adjusting axis label to '[different units]'"
                )
                self._relabel_xunit("[different units]")

        # new data matches old data, nothing to adjust
        if np.array_equal(self.xdata, xdata) and self.dt == dt:
            return np.arange(0, self.xdata.size)

        # compare timescales dt
        if self.dt < dt:
            # new data is coarser: express it in the existing, finer steps
            log.debug("dt does not match,")
            scd = dt / self.dt
            if float(scd).is_integer():
                log.debug(
                    "Changing axis values of new data (dt={}) ".format(dt)
                    + "to match higher resolution of "
                    + "old xaxis (dt={})".format(self.dt)
                )
                xdata *= scd
            else:
                log.warning(
                    "New 'dt={}' is not an integer multiple of ".format(dt)
                    + "the previous 'dt={}'\n".format(self.dt)
                    + "Plotting with '[different units]'\n"
                    + "As a workaround, try adding the data with the "
                    + "smallest 'dt' first"
                )
                self._relabel_xunit("[different units]")
        elif self.dt > dt:
            # new data is finer: rescale the existing axis to the new dt
            scd = self.dt / dt
            if float(scd).is_integer():
                log.debug(
                    "Changing 'dt' to new value 'dt={}'\n".format(dt)
                    + "\tAdjusting existing axis values (dt={})".format(self.dt)
                )
                self.xdata *= scd
                self.dt = dt
                if self.dt == 1:
                    newunit = "[{}]".format(self.dtunit)
                else:
                    newunit = "[{} {}]".format(ut._printeger(self.dt), self.dtunit)
                self._relabel_xunit(newunit)
            else:
                log.warning(
                    "old 'dt={}' is not an integer multiple ".format(self.dt)
                    + "of the new value 'dt={}'\n".format(dt)
                    + "\tPlotting with '[different units]'\n"
                )
                self._relabel_xunit("[different units]")

        # check if new is subset of old; if not, grow the axis and pad
        # already stored columns with nan at the new positions
        temp = np.union1d(self.xdata, xdata)
        if not np.array_equal(self.xdata, temp):
            log.debug("Rearranging present data")
            _, indtemp = ut._intersecting_index(self.xdata, temp)
            self.xdata = temp
            for ydx, col in enumerate(self.ydata):
                coln = np.full(self.xdata.size, np.nan)
                coln[indtemp] = col
                self.ydata[ydx] = coln

        # return list of indices where to place new ydata in the existing
        # (higher-resolution) notation
        indold, _ = ut._intersecting_index(self.xdata, xdata)
        assert len(indold) == len(xdata)
        return indold

    def _relabel_xunit(self, newunit):
        """Replace the bracketed unit (e.g. `[ms]`) in the axis and meta x-labels."""
        regex = r"\[.*?\]"

        # a callable replacement is inserted literally; a plain string would
        # have backslashes (e.g. from LaTeX units) interpreted by re.sub
        def repl(_match):
            return newunit

        self.ax.set_xlabel(re.sub(regex, repl, self.ax.get_xlabel()))
        if self.xlabel is not None:
            self.xlabel = re.sub(regex, repl, self.xlabel)

    def _set_correlation_labels(self, dt, dtunit):
        """Set axis labels, title and meta x-label for a correlation plot."""
        if dt == 1:
            unit = "[{}]".format(dtunit)
        else:
            unit = "[{} {}]".format(ut._printeger(dt, 5), dtunit)
        self.xlabel = "steps" + unit
        self.ax.set_xlabel("k " + unit)
        self.ax.set_ylabel("$r_{k}$")
        self.ax.set_title("Correlation", fontweight="bold")

    def _render_all(self):
        """(Re)draw all stored coefficients and fits, e.g. after the x-range changed."""
        for rk in self.rks:
            self._render_coefficients(rk)
        for fit in self.fits:
            self._render_fit(fit)

    def _hide_ticks_below_one(self):
        """
        Replace all x-ticks at or below 1 by a single tick at k=1 (a tick at
        k=0 is confusing), keeping the current x-limits.
        """
        old_limit = self.ax.get_xlim()
        old_ticks = list(self.ax.get_xticks())
        new_ticks = [1] + [i for i in old_ticks if i > 1]
        self.ax.set_xticks(new_ticks)
        # matplotlib might change xlim to match ticks
        self.ax.set_xlim(old_limit)

    def add_coefficients(self, data, **kwargs):
        """
        Add an individual CoefficientResult. Note that it is not possible
        to add the same data twice, instead it will be redrawn with
        the new arguments/style options provided.

        Parameters
        ----------
        data : CoefficientResult
            Added to the list of plotted elements.

        kwargs
            Keyword arguments passed to
            :obj:`matplotlib.axes.Axes.plot`. Use to customise the
            plots. If a `label` is set via `kwargs`, it will be used to
            overwrite the description of `data` in the meta file.
            A `label` of None or "" hides the entry from the legend.
            If an alpha value is set, the shaded error region is omitted.

        Raises
        ------
        ValueError
            If `data` is not a CoefficientResult or this handler already
            contains a time series.

        Example
        -------
        .. code-block:: python

            rk = mre.coefficients(mre.simulate_branching())

            mout = mre.OutputHandler()
            mout.add_coefficients(rk, color='C1', label='test')
        ..
        """
        if not isinstance(data, CoefficientResult):
            msg = "'data' needs to be of type CoefficientResult"
            log.error(msg)
            raise ValueError(msg)
        if not (self.type is None or self.type == "correlation"):
            msg = (
                "It is not possible to 'add_coefficients()' to "
                + "an OutputHandler containing a time series\n"
                + "\tHave you previously called 'add_ts()' on this handler?"
            )
            log.error(msg)
            raise ValueError(msg)
        self.type = "correlation"

        # description for columns of meta data
        desc = str(data.desc)

        # plot legend label
        if "label" in kwargs:
            label = kwargs.get("label")
            if label is None or label == "":
                # empty label hides the legend entry
                label = None
            else:
                # user wants custom label, apply to meta data, too
                label = str(label)
                desc = label
        else:
            # user has not set anything, copy from desc if set
            label = desc if desc != "" else "Data"

        if desc != "":
            desc += " "

        # dont put errors in the legend. this should become a user choice
        labelerr = ""

        # no previous coefficients present
        if len(self.rks) == 0:
            self.dt = data.dt
            self.dtunit = data.dtunit
            self._set_correlation_labels(data.dt, data.dtunit)

        # we dont support adding duplicates
        oldcurves = []
        if data in self.rks:
            indrk = self.rks.index(data)
            log.warning(
                "Coefficients ({}/{}) ".format(self.rklabels[indrk][0], label)
                + "have already been added\n\tOverwriting with new style"
            )
            del self.rks[indrk]
            del self.rklabels[indrk]
            oldcurves = self.rkcurves[indrk]
            del self.rkcurves[indrk]
            del self.rkkwargs[indrk]
        else:
            # add to meta data
            inds = self.set_xdata(data.steps, dt=data.dt, dtunit=data.dtunit)
            ydata = np.full(self.xdata.size, np.nan)
            ydata[inds] = data.coefficients
            self.ydata.append(ydata)
            self.ylabels.append(desc + "coefficients")
            if data.stderrs is not None:
                ydata = np.full(self.xdata.size, np.nan)
                ydata[inds] = data.stderrs
                self.ydata.append(ydata)
                self.ylabels.append(desc + "stderrs")

        self.rks.append(data)
        self.rklabels.append([label, labelerr])
        self.rkcurves.append(oldcurves)
        self.rkkwargs.append(kwargs)

        self._render_all()

    def _render_coefficients(self, rk):
        """(Re)draw coefficients `rk` over the (possibly) new xrange/dt."""
        indrk = self.rks.index(rk)
        label, labelerr = self.rklabels[indrk]
        kwargs = self.rkkwargs[indrk].copy()

        # reset curves and recover color of the previous main curve
        color = None
        for idx, curve in enumerate(self.rkcurves[indrk]):
            if idx == 0:
                color = curve.get_color()
            curve.remove()
        self.rkcurves[indrk] = []

        kwargs.setdefault("color", color)
        kwargs.setdefault("zorder", 1 + 0.01 * indrk)
        kwargs["label"] = label

        # positions in the x-units (dt) of this handler
        xvals = rk.steps * rk.dt / self.dt
        (p,) = self.ax.plot(xvals, rk.coefficients, **kwargs)
        self.rkcurves[indrk].append(p)

        try:
            if rk.stderrs is not None and "alpha" not in kwargs:
                err1 = rk.coefficients - rk.stderrs
                err2 = rk.coefficients + rk.stderrs
                kwargs.pop("color")
                kwargs.pop("zorder")
                kwargs.update(
                    label=labelerr,
                    alpha=0.2,
                    facecolor=p.get_color(),
                    zorder=p.get_zorder() - 1,
                )
                d = self.ax.fill_between(xvals, err1, err2, **kwargs)
                self.rkcurves[indrk].append(d)
        except AttributeError:
            # not all kwargs are compatible with fill_between
            log.debug("Exception passed", exc_info=True)

        if label is not None:
            self.ax.legend()

        self._hide_ticks_below_one()

    def add_fit(self, data, **kwargs):
        """
        Add an individual FitResult. By default, the part of the fit that
        contributed to the fitting is drawn solid, the remaining range
        is dashed. Note that it is not possible
        to add the same data twice, instead it will be redrawn with
        the new arguments/style options provided.

        Parameters
        ----------
        data : FitResult
            Added to the list of plotted elements.

        kwargs
            Keyword arguments passed to
            :obj:`matplotlib.axes.Axes.plot`. Use to customise the
            plots. A `label` set via `kwargs` replaces the default legend
            label; None or "" hides the legend entry. If `linestyle` is
            set, the dashed plot of the region not contributing to the fit
            is omitted.

        Raises
        ------
        ValueError
            If `data` is not a FitResult or this handler already contains
            a time series.
        """
        if not isinstance(data, FitResult):
            msg = "'data' needs to be of type FitResult"
            log.error(msg)
            raise ValueError(msg)
        if not (self.type is None or self.type == "correlation"):
            msg = (
                "It is not possible to 'add_fit()' to "
                + "an OutputHandler containing a time series\n"
                + "\tHave you previously called 'add_ts()' on this handler?"
            )
            log.error(msg)
            raise ValueError(msg)
        self.type = "correlation"

        if self.xdata is None:
            self.dt = data.dt
            self.dtunit = data.dtunit
            self._set_correlation_labels(data.dt, data.dtunit)
        self.set_xdata(data.steps, dt=data.dt, dtunit=data.dtunit)

        # description for fallback
        desc = str(data.desc)

        # plot legend label
        if "label" in kwargs:
            label = kwargs.get("label")
            # None or "" hides the legend entry (instead of showing "None")
            label = None if label is None or label == "" else str(label)
        else:
            # user has not set anything, copy from desc if set
            label = "Fit " + ut.math_from_doc(data.fitfunc, 0)
            if desc != "":
                label = desc + " " + label

        # we dont support adding duplicates
        oldcurves = []
        if data in self.fits:
            indfit = self.fits.index(data)
            log.warning(
                "Fit was already added ({})\n".format(self.fitlabels[indfit])
                + "\tOverwriting with new style"
            )
            del self.fits[indfit]
            del self.fitlabels[indfit]
            oldcurves = self.fitcurves[indfit]
            del self.fitcurves[indfit]
            del self.fitkwargs[indfit]

        self.fits.append(data)
        self.fitlabels.append(label)
        self.fitcurves.append(oldcurves)
        self.fitkwargs.append(kwargs)

        self._render_all()

    def _render_fit(self, fit):
        """(Re)draw `fit` over the (possibly) new xrange/dt."""
        indfit = self.fits.index(fit)
        label = self.fitlabels[indfit]
        kwargs = self.fitkwargs[indfit].copy()

        # reset curves and recover color of the previous main curve
        color = None
        for idx, curve in enumerate(self.fitcurves[indfit]):
            if idx == 0:
                color = curve.get_color()
            curve.remove()
        self.fitcurves[indfit] = []

        kwargs.setdefault("color", color)
        kwargs.setdefault("zorder", 4 + 0.01 * indfit)
        kwargs["label"] = label

        # fitted range, in the x-units (dt) of this handler
        xfit = fit.steps * fit.dt / self.dt
        (p,) = self.ax.plot(xfit, fit.fitfunc(fit.steps * fit.dt, *fit.popt), **kwargs)
        self.fitcurves[indfit].append(p)

        # draw the not-fitted range dashed, only if it exists and no
        # linestyle is specified. compare in handler units, like self.xdata
        if xfit[0] > self.xdata[0] or xfit[-1] < self.xdata[-1]:
            if "linestyle" not in kwargs and "ls" not in kwargs:
                kwargs.pop("label")
                kwargs.update(ls="dashed", color=p.get_color())
                (d,) = self.ax.plot(
                    self.xdata, fit.fitfunc(self.xdata * self.dt, *fit.popt), **kwargs
                )
                self.fitcurves[indfit].append(d)

        if label is not None:
            self.ax.legend()

        self._hide_ticks_below_one()

    def add_ts(self, data, **kwargs):
        """
        Add timeseries (possibly with trial structure).
        Not compatible with OutputHandlers that have data added via
        `add_fit()` or `add_coefficients()`.

        Parameters
        ----------
        data : ~numpy.ndarray
            The timeseries to plot. If the `ndarray` is two dimensional,
            a trial structure is assumed and all trials are plotted using
            the same style (default or defined via `kwargs`).
            *Not implemented yet*: Providing a ts with its own custom axis

        kwargs
            Keyword arguments passed to
            :obj:`matplotlib.axes.Axes.plot`. Use to customise the
            plots. The `label` (default "ts") also names the meta data
            columns.

        Raises
        ------
        ValueError
            If this handler already contains coefficients or fits.
        NotImplementedError
            For more than two dimensions, or a length that differs from
            previously added time series.

        Example
        -------
        .. code-block:: python

            bp = mre.simulate_branching(numtrials=10)

            tsout = mre.OutputHandler()
            tsout.add_ts(bp, alpha=0.1, label='Trials')
            tsout.add_ts(np.mean(bp, axis=0), label='Mean')

            plt.show()
        ..
        """
        if not (self.type is None or self.type == "timeseries"):
            msg = (
                "Adding time series 'add_ts()' is not "
                + "compatible with an OutputHandler that has coefficients\n"
                + "\tHave you previously called 'add_coefficients()' or "
                + "'add_fit()' on this handler?"
            )
            log.error(msg)
            raise ValueError(msg)
        self.type = "timeseries"

        if not isinstance(data, np.ndarray):
            data = np.array(data)
        if data.ndim < 2:
            # single series -> one trial
            data = data.reshape((1, -1))
        elif data.ndim > 2:
            msg = "Only compatible with up to two dimensions"
            log.error(msg)
            raise NotImplementedError(msg)

        # meta data column name; fall back when no usable label is given
        desc = kwargs.get("label")
        desc = "ts" if desc is None or desc == "" else str(desc)

        # per default, if more than one series provided reduce alpha
        if data.shape[0] > 1 and "alpha" not in kwargs:
            kwargs["alpha"] = 0.1
        kwargs.setdefault("zorder", -1)

        for idx, dat in enumerate(data):
            if self.xdata is None:
                self.set_xdata(np.arange(1, data.shape[1] + 1))
                self.xlabel = "timesteps"
                self.ax.set_xlabel("t")
                self.ax.set_ylabel("$A_{t}$")
                self.ax.set_title("Time Series", fontweight="bold")
            elif len(self.xdata) != len(dat):
                msg = "Time series have different length"
                log.error(msg)
                raise NotImplementedError(msg)

            self.ydata.append(dat)
            if len(data) > 1:
                self.ylabels.append(desc + "[{}]".format(idx))
            else:
                self.ylabels.append(desc)

            (p,) = self.ax.plot(self.xdata, dat, **kwargs)

            # dont plot an empty legend
            if kwargs.get("label") is not None and kwargs.get("label") != "":
                self.ax.legend()

            # only add to legend once, keep the color of the first trial
            if idx == 0:
                kwargs = dict(kwargs, label=None, color=p.get_color())

    def save(self, fname="", ftype="pdf", dpi=300):
        """
        Saves plots (ax element of this handler) and source that it was
        created from to the specified location.

        Parameters
        ----------
        fname : str, optional
            Path where to save, without file extension. Defaults to "./mre"
        ftype : str or list of str, optional
            Plot file format(s), see `save_plot()`.
        dpi : int, optional
            Resolution passed to matplotlib's `savefig`.
        """
        self.save_plot(fname, ftype=ftype, dpi=dpi)
        # NOTE(review): `save_meta` is not defined in this class as shown
        # here; confirm it exists, otherwise this raises AttributeError.
        self.save_meta(fname)

    def save_plot(self, fname="", ftype="pdf", dpi=300):
        """
        Only saves plots (ignoring the source) to the specified location.

        Parameters
        ----------
        fname : str, optional
            Path where to save, without file extension. Defaults to "./mre"
        ftype: str or list of str, optional
            So far, only 'pdf' and 'png' are implemented (case-insensitive).
        dpi : int, optional
            Resolution passed to matplotlib's `savefig`.

        Raises
        ------
        ValueError
            If any requested file type is unsupported. Checked before
            anything is written.
        """
        fname = str(fname)
        if fname == "":
            fname = "./mre"

        if isinstance(ftype, str):
            ftype = [ftype]
        ftypes = [str(t) for t in ftype]

        # validate all types first so we do not write a partial set of files
        supported = ("pdf", "png")
        for t in ftypes:
            if t.lower() not in supported:
                msg = "Unsupported file format '{}'".format(t)
                log.error(msg)
                raise ValueError(msg)

        # try creating enclosing dir if not existing
        tempdir = os.path.abspath(os.path.expanduser(fname + "/../"))
        os.makedirs(tempdir, exist_ok=True)

        fname = os.path.expanduser(fname)
        for t in ftypes:
            ext = t.lower()
            log.info("Saving plot to {}.{}".format(fname, ext))
            self.ax.figure.savefig(fname + "." + ext, dpi=dpi)
def overview(src, rks, fits, **kwargs):
    """
    Create an A5-sized overview panel and return the matplotlib figure.

    No argument checks are done.

    Parameters
    ----------
    src : ~numpy.ndarray
        Two dimensional input, indexed as `src[trial, time]`.
    rks : list of CoefficientResult
        All entries are drawn in the correlation panel; `rks[0]` also
        provides the time unit and the per-trial statistics.
    fits : list of FitResult
        Drawn in the correlation panel and summarised in the legend panel.
    **kwargs
        numsegs : int, default 50
            Number of intervals for the running average when `src` holds a
            single trial (clamped to the number of time steps).
        title : str, optional
            Figure title.
        warning : str, optional
            Shown in red at the bottom of the figure.

    Returns
    -------
    : :class:`matplotlib.figure.Figure`
    """
    ratios = np.ones(5)
    ratios[4] = 0.0001
    # A5 in inches, should check rc params in the future
    # matplotlib changes the figure size when modifying subplots
    fig, axes = plt.subplots(
        nrows=5, figsize=(5.8, 8.3), gridspec_kw={"height_ratios": ratios}
    )

    # avoid huge file size for many trials due to separate layers.
    # everything below 0 gets rastered to the same layer.
    axes[0].set_rasterization_zorder(0)

    # ------------------------------------------------------------------ #
    # Time Series
    # ------------------------------------------------------------------ #

    tsout = OutputHandler(ax=axes[0])
    tsout.add_ts(src, label="Trials")
    if src.shape[0] > 1:
        try:
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
        except Exception:
            prevclr = "navy"
            log.debug("Exception getting color cycle", exc_info=True)
        tsout.add_ts(np.mean(src, axis=0), color=prevclr, label="Average")
    else:
        tsout.ax.legend().set_visible(False)

    tsout.ax.set_title("Time Series", fontweight="bold", loc="center")
    tsout.ax.set_title("(Input Data)", fontsize="medium", color="#646464", loc="right")
    tsout.ax.set_xlabel(
        "t [{}{}]".format(
            ut._printeger(rks[0].dt) + " " if rks[0].dt != 1 else "", rks[0].dtunit
        )
    )

    # ------------------------------------------------------------------ #
    # Mean Trial Activity
    # ------------------------------------------------------------------ #

    if src.shape[0] > 1:
        # average trial activites as function of trial number
        taout = OutputHandler(rks[0].trialactivities, ax=axes[1])
        try:
            err1 = rks[0].trialactivities - np.sqrt(rks[0].trialvariances)
            err2 = rks[0].trialactivities + np.sqrt(rks[0].trialvariances)
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
            taout.ax.fill_between(
                np.arange(1, rks[0].numtrials + 1), err1, err2, color=prevclr, alpha=0.2
            )
        except Exception:
            log.debug("Exception adding std deviation to plot", exc_info=True)
        taout.ax.set_title("Mean Trial Activity and Std. Deviation", fontweight="bold")
        taout.ax.set_xlabel("Trial i")
        taout.ax.set_ylabel("$\\bar{A}_i$")
    else:
        # running average over the one trial to see if stays stationary
        numsegs = kwargs.get("numsegs")
        numsegs = 50 if numsegs is None else int(numsegs)
        # at least one interval, and at least one time step per interval
        numsegs = max(1, min(numsegs, src.shape[1]))
        ravg = np.zeros(numsegs)
        err1 = np.zeros(numsegs)
        err2 = np.zeros(numsegs)
        seglen = src.shape[1] // numsegs
        for s in range(numsegs):
            segment = src[0][s * seglen : (s + 1) * seglen]
            temp = np.mean(segment)
            ravg[s] = temp
            stddev = np.sqrt(np.var(segment))
            err1[s] = temp - stddev
            err2[s] = temp + stddev

        taout = OutputHandler(ravg, ax=axes[1])
        try:
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
            taout.ax.fill_between(
                np.arange(1, numsegs + 1), err1, err2, color=prevclr, alpha=0.2
            )
        except Exception:
            log.debug("Exception adding std deviation to plot", exc_info=True)
        taout.ax.set_title(
            "Average Activity and Stddev for {} Intervals".format(numsegs),
            fontweight="bold",
        )
        taout.ax.set_xlabel("Interval i")
        taout.ax.set_ylabel("$\\bar{A}_i$")

    # ------------------------------------------------------------------ #
    # Coefficients and Fit results
    # ------------------------------------------------------------------ #

    cout = OutputHandler(rks + fits, ax=axes[2])

    fitcurves = []
    fitlabels = []
    for i, f in enumerate(cout.fits):
        fitcurves.append(cout.fitcurves[i][0])
        label = ut.math_from_doc(f.fitfunc, 5)
        label += "\n\n$\\tau={:.2f}${}\n".format(f.tau, f.dtunit)
        if f.tauquantiles is not None:
            label += "$[{:.2f}:{:.2f}]$\n\n".format(
                f.tauquantiles[0], f.tauquantiles[-1]
            )
        else:
            label += "\n\n"
        label += "$m={:.5f}$\n".format(f.mre)
        if f.mrequantiles is not None:
            label += "$[{:.5f}:{:.5f}]$".format(f.mrequantiles[0], f.mrequantiles[-1])
        else:
            label += "\n"
        fitlabels.append(label)

    tempkwargs = {
        # matplotlib cannot lay out zero columns, keep one even without fits
        "ncol": max(1, len(fitlabels)),
        "loc": "upper center",
        "mode": "expand",
        "frameon": True,
        "markerfirst": True,
        "fancybox": False,
        "borderaxespad": 0,
        "edgecolor": "black",
        # hide handles
        "handlelength": 0,
        "handletextpad": 0,
    }
    try:
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)
    except Exception:
        # fallback for matplotlib versions without the `edgecolor` option
        log.debug("Exception passed", exc_info=True)
        del tempkwargs["edgecolor"]
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)

    legend = axes[3].get_legend()
    # `legendHandles` was renamed to `legend_handles` (matplotlib 3.7) and the
    # old name was later removed; support both
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # hide handles
    for handle in handles:
        handle.set_visible(False)

    # center text
    for t in legend.texts:
        t.set_multialignment("center")

    # apply style and fill legend
    legend.get_frame().set_linewidth(0.5)
    axes[3].axis("off")
    axes[3].set_title(
        "Fitresults", fontweight="bold", loc="center",
    )
    axes[3].set_title(
        " (with CI: [$12.5\\%:87.5\\%$])",
        color="#646464",
        fontsize="medium",
        loc="right",
    )

    for a in axes:
        a.xaxis.set_tick_params(width=0.5)
        a.yaxis.set_tick_params(width=0.5)
        for s in a.spines:
            a.spines[s].set_linewidth(0.5)

    # dummy axes for version and warnings
    axes[4].axis("off")

    # act on this figure explicitly, not on pyplot's "current" figure
    fig.tight_layout()
    fig.subplots_adjust(hspace=0.8, top=0.95, bottom=0.0, left=0.1, right=0.99)

    title = kwargs.get("title")
    if title is not None and title != "":
        fig.suptitle(title, fontsize=14)
        fig.subplots_adjust(top=0.91)

    if kwargs.get("warning") is not None:
        s = "\u26A0 {}".format(kwargs.get("warning"))
        fig.text(0.5, 0.01, s, fontsize=13, horizontalalignment="center", color="red")

    fig.text(
        0.995,
        0.005,
        "v{}".format(__version__),
        fontsize="small",
        horizontalalignment="right",
        color="#646464",
    )

    return fig