Source code for Stoner.plot.core

"""Provides the a class to facilitte easier plotting of Stoner Data.

Classes:
    PlotMixin:
        A class that uses matplotlib to plot data
"""

# pylint: disable=C0413
from __future__ import division

__all__ = ["PlotMixin"]
import copy
from collections.abc import Mapping
from functools import wraps

import numpy as np
from matplotlib import figure as mplfig
from matplotlib import pyplot as plt

from ..compat import int_types
from ..tools import (
    TypedList,
    all_type,
    fix_signature,
    get_option,
    isiterable,
)
from ..tools.decorators import class_modifier
from . import functions
from .formats import DefaultPlotStyle
from .utils import errorfill

try:  # Check we've got 3D plotting
    import mpl_toolkits.axisartist as AA  # noqa: F401 pylint: disable=unused-import
    from mpl_toolkits.axes_grid1 import (  # noqa: F401 pylint: disable=unused-import
        host_subplot,
        inset_locator,
    )
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 pylint: disable=unused-import

    _3D = True
except ImportError:
    _3D = False


[docs] @class_modifier([functions], adaptor=None, no_long_names=True, overload=True) class PlotMixin: r"""A mixin class that works with :py:class:`Stoner.Core.DataFile` to add additional plotting functionality. Args: args(tuple): Arguments to pass to :py:meth:`Stoner.Core.DataFile.__init__` kwargs (dict): keyword arguments to pass to \b DataFile.__init__ Attributes: ax (matplotlib.Axes): The current axes on the current figure. axes (list of matplotlib.Axes): A list of all the axes on the current figure fig (matplotlib.figure): The current figure object being worked with fignum (int): The current figure's number. labels (list of string): List of axis labels as aternates to the column_headers showfig (bool or None): Controls whether plot functions return a copy of the figure (True), the DataFile (False) or Nothing (None) subplots (list of matplotlib.Axes): essentially the same as :py:attr:`PlotMixin.axes` but ensures that the list of subplots is synchronised to the number fo Axes. template (:py:class:`Sonter.plot.formats.DefaultPlotStyle` or instance): A plot style template subclass or object that determines the format and appearance of plots. """ positional_fmt = [plt.plot, plt.semilogx, plt.semilogy, plt.loglog] no_fmt = [errorfill] _template = DefaultPlotStyle() _legend = True multiple = "common" def __init__(self, *args, **kwargs): # Do the import of plt here to speed module load """Import plt and then calls the parent constructor.""" self._figure = None self._showfig = kwargs.pop("showfig", True) # Retains previous behaviour self._subplots = [] self._public_attrs = { "fig": (int, mplfig.Figure), "labels": list, "template": (DefaultPlotStyle, type(DefaultPlotStyle)), "xlim": tuple, "ylim": tuple, "title": str, "xlabel": str, "ylabel": str, "_showfig": bool, } super().__init__(*args, **kwargs) self._labels = TypedList(str, []) if self.debug: print("Done PlotMixin init") # =========================================================================================================== # Properties of PlotMixin # =========================================================================================================== @property def ax(self): """Return the current axis number.""" return self.axes.index(self.fig.gca()) @ax.setter def ax(self, value): """Change the current axis number.""" if isinstance(value, int) and 0 <= value < len(self.axes): self.fig.sca(self.axes[value]) else: raise IndexError("Figure doesn't have enough axes") @property def axes(self): """Return the current axes object.""" if isinstance(self._figure, mplfig.Figure): ret = self._figure.axes else: ret = None return ret @property def cmap(self): """Get the current cmap.""" return plt.get_cmap() @cmap.setter def cmap(self, cmap): """Set the plot cmap.""" return plt.set_cmap(cmap) @property def fig(self): """Get the current figure.""" return self._figure @fig.setter def fig(self, value): """Set the current figure.""" if isinstance(value, plt.Figure): self._figure = self.template.new_figure(value.number)[0] elif isinstance(value, int): value = plt.figure(value) self.fig = value elif value is None: self._figure = None else: raise NotImplementedError("fig should be a number of matplotlib figure") @property def fignum(self): """Return the current figure number.""" return self._figure.number @property def labels(self): """Return the labels for the plot columns.""" if len(self._labels) == 0: return self.column_headers if len(self._labels) < len(self.column_headers): self._labels.extend(copy.deepcopy(self.column_headers[len(self._labels) :])) return self._labels @labels.setter def labels(self, value): """Set the labels for the plot columns.""" if value is None: self._labels = TypedList(str, self.column_headers) elif isiterable(value) and all_type(value, str): self._labels = TypedList(str, value) else: raise TypeError(f"labels should be iterable and all strings, or None, not {type(value)}") @property def showfig(self): """Return either the current figure or self or None. The return value depends on whether the attribute is True or False or None. """ if self._showfig is None or get_option("no_figs"): return None if self._showfig: return self._figure return self @showfig.setter def showfig(self, value): """Force a figure to be displayed.""" if not (value is None or isinstance(value, bool)): raise AttributeError(f"showfig should be a boolean value not a {type(value)}") self._showfig = value @property def subplots(self): """Return the subplot instances.""" if self._figure is not None and len(self._figure.axes) > len(self._subplots): self._subplots = self._figure.axes return self._subplots @property def template(self): """Return the current plot template.""" if not hasattr(self, "_template"): self.template = DefaultPlotStyle() return self._template @template.setter def template(self, value): """Set the current template.""" if isinstance(value, type) and issubclass(value, DefaultPlotStyle): value = value() if isinstance(value, DefaultPlotStyle): self._template = value else: raise ValueError(f"Template is not of the right class:{type(value)}") self._template.apply() def _span_slice(self, col, num): """Create a slice that covers the range of a given column.""" v1, v2 = self.span(col) v = np.linspace(v1, v2, num) delta = v[1] - v[0] if isinstance(delta, int_types): v2 = v2 + delta else: v2 = v2 + delta / 2 return slice(v1, v2, delta) def _col_label(self, index, join=False): """Look up a column and see if it exists in self._lables, otherwise get from self.column_headers. Args: index (column index type): Column to return label for Returns: String type representing the column label. """ ix = self.find_col(index) if isinstance(ix, list): if join: return ",".join([self._col_label(i) for i in ix]) return [self._col_label(i) for i in ix] return self.labels[ix] def __dir__(self): """Handle the local attributes as well as the inherited ones.""" attr = dir(type(self)) attr.extend(super().__dir__()) attr.extend(list(self.__dict__.keys())) attr.extend(["fig", "axes", "labels", "subplots", "template"]) attr.extend(("xlabel", "ylabel", "title", "xlim", "ylim")) attr = list(set(attr)) return sorted(attr) def __getattr__(self, name): """Wrap attribute access with extra renaming logic. Args: name (string): Name of attribute the following attributes are supported: - fig - the current plt figure reference - axes - the plt axes object for the current plot - xlim - the X axis limits - ylim - the Y axis limits All other attributes are passed over to the parent class """ func = None o_name = name mapping = dict( [ ("plt_", (plt, "pyplot")), # Need to be explicit in 2.7! ("ax_", (plt.Axes, "axes")), ("fig_", (plt.Figure, "figure")), ] ) try: return object.__getattribute__(self, o_name) except AttributeError: pass if plt.get_fignums(): # Store the current figure and axes tfig = plt.gcf() tax = tfig.gca() # protect the current axes and figure else: tfig = None tax = None # First look for a function in the pyplot package for prefix, (obj, key) in mapping.items(): name = o_name[len(prefix) :] if o_name.startswith(prefix) else o_name if name in dir(obj): try: return self._pyplot_proxy(name, key) except AttributeError: pass # Nowcheck for prefixed access on axes and figures with get_ if name.startswith("ax_") and f"get_{name[3:]}" in dir(plt.Axes): name = name[3:] if f"get_{name}" in dir(plt.Axes) and self._figure: ax = self.fig.gca() func = ax.__getattribute__(f"get_{name}") if name.startswith("fig_") and f"get_{name[4:]}" in dir(plt.Figure): name = name[4:] if f"get_{name}" in dir(plt.Figure) and self._figure: fig = self.fig func = fig.__getattribute__(f"get_{name}") if func is None: # Ok Fallback to lookinf again at parent class return super().__getattr__(o_name) # If we're still here then we're calling a proxy from that figure or axes ret = func() if tfig is not None: # Restore the current figures plt.figure(tfig.number) plt.sca(tax) return ret def _pyplot_proxy(self, name, what): """Proxy for accessing :py:module:`matplotlib.pyplot` functions.""" if what not in ["axes", "figure", "pyplot"]: raise SyntaxError("pyplot proxy can't figure out what to get proxy from.") if what == "pyplot": obj = plt elif what == "figure" and self._figure: obj = self._figure elif what == "axes" and self._figure: obj = self._figure.gca() else: raise AttributeError( "Attempting to manipulate the methods on a figure or axes before a" + " figure has been created for this Data." ) func = getattr(obj, name) if not callable(func): # Bug out if this isn't a callable proxy! return func @wraps(func) def _proxy(*args, **kwargs): ret = func(*args, **kwargs) return ret return fix_signature(_proxy, func) def __setattr__(self, name, value): """Set the specified attribute. Args: name (string): The name of the attribute to set. The cuirrent attributes are supported: - fig - set the plt figure instance to use - xlabel - set the X axis label text - ylabel - set the Y axis label text - title - set the plot title - subtitle - set the plot subtitle - xlim - set the x-axis limits - ylim - set the y-axis limits Only "fig" is supported in this class - everything else drops through to the parent class value (any): The value of the attribute to set. """ if plt.get_fignums(): tfig = plt.gcf() if len(tfig.axes): tax = tfig.gca() # protect the current axes and figure else: tax = None else: tfig = None tax = None func = None o_name = name if name.startswith("ax_") and f"set_{name[3:]}" in dir(plt.Axes): name = name[3:] if f"set_{name}" in dir(plt.Axes) and self.fig: if self.fig is None: # oops we need a figure first! self.figure() ax = self.fig.gca() func = ax.__getattribute__(f"set_{name}") elif name.startswith("fig_") and f"set_{name[4:]}" in dir(plt.Figure): name = f"set_{name[4:]}" func = getattr(self.fig, name) elif f"set_{name}" in dir(plt.Figure) and self.fig: if self.fig is None: # oops we need a figure first! self.figure() fig = self.fig func = fig.__getattribute__(f"set_{name}") if o_name in dir(type(self)) or func is None: try: super().__setattr__(o_name, value) return except AttributeError: pass if func is None: raise AttributeError(f"Unable to set attribute {o_name}") if isinstance(value, str) and "$" not in value: value = value.format(**self) if not isiterable(value) or isinstance(value, str): value = (value,) if isinstance(value, Mapping): func(**value) else: func(*value) if tfig is not None: plt.figure(tfig.number) if tax is not None: plt.sca(tax)