Source code for srttools.interactive_filter

"""Interactive operations."""


import copy
import warnings
import logging

try:
    import matplotlib.pyplot as plt
    from matplotlib import gridspec

    HAS_MPL = True
except ImportError:
    HAS_MPL = False

import numpy as np
from .fit import linear_fit, linear_fun, align

from .utils import compare_anything


__all__ = [
    "TestWarning",
    "PlotWarning",
    "mask",
    "intervals",
    "DataSelector",
    "ImageSelector",
    "create_empty_info",
    "select_data",
]


[docs]class TestWarning(UserWarning): pass
[docs]class PlotWarning(UserWarning): pass
warnings.filterwarnings("always", category=TestWarning) warnings.filterwarnings("always", category=PlotWarning)
[docs]def create_empty_info(keys): info = {} for key in keys: info[key] = {} info[key]["FLAG"] = None info[key]["zap"] = intervals() info[key]["base"] = intervals() info[key]["fitpars"] = np.array([0, 0]) return info
[docs]def mask(xs, border_xs, invert=False): """Create mask from a list of interval borders. Parameters ---------- xs : array the array of values to filter border_xs : array the list of borders. Should be an even number of positions Returns ------- mask : array Array of boolean values, that work as a mask to xs Other Parameters ---------------- invert : bool Mask value is False if invert = False, and vice versa. E.g. for zapped intervals, invert = False. For baseline fit selections, invert = True """ good = np.ones(len(xs), dtype=bool) if len(border_xs) >= 2: intervals = list(zip(border_xs[:-1:2], border_xs[1::2])) for i in intervals: good[np.logical_and(xs >= i[0], xs <= i[1])] = False if invert: good = np.logical_not(good) return good
[docs]class intervals: """A list of xs and ys of the points taken during interactive selection.""" def __init__(self): """Initialize.""" self.xs = [] self.ys = []
[docs] def clear(self): """Clear.""" self.xs = [] self.ys = []
[docs] def add(self, point): """Append points.""" self.xs.append(point[0]) self.ys.append(point[1])
def __eq__(self, other): if isinstance(self, other.__class__): return self.__dict__ == other.__dict__ return False def __ne__(self, other): return not self.__eq__(other)
[docs]class DataSelector: """Plot and process scans interactively.""" def __init__(self, xs, ys, ax1, ax2, masks=None, xlabel=None, title=None, test=False): """Initialize.""" self.instructions = """ ------------------------------------------------------------- Interactive plotter. ------------------------------------------------------------- Choose line to fit: Click on the line Interval selection: Point mouse + <key> z create zap intervals b suggest intervals to use for baseline fit Flagging actions: x flag as bad; v Remove flags and all masks from data; Actions: P print current zap list and fit parameters A align all scans w.r.t. the selected one u update plots with new selections B subtract the baseline; r reset baseline and zapping intervals, and fit parameters; q quit ------------------------------------------------------------- """ if not HAS_MPL: raise ImportError("matplotlib not installed") self.xs = xs self.ys = ys self.test = test if masks is None: masks = dict( list( zip( self.xs.keys(), [np.ones(len(self.xs[k]), dtype=bool) for k in self.xs.keys()], ) ) ) self.masks = masks self.ax1 = ax1 self.ax2 = ax2 self.xlabel = xlabel self.title = title self.starting_info = create_empty_info(self.xs.keys()) self.info = copy.deepcopy(self.starting_info) self.lines = [] if not test: self.print_instructions() self.current = None if not test: ax1.figure.canvas.mpl_connect("button_press_event", self.on_click) ax1.figure.canvas.mpl_connect("key_press_event", self.on_key) ax1.figure.canvas.mpl_connect("pick_event", self.on_pick) ax2.figure.canvas.mpl_connect("button_press_event", self.on_click) ax2.figure.canvas.mpl_connect("key_press_event", self.on_key) ax2.figure.canvas.mpl_connect("pick_event", self.on_pick) self.plot_all() self.zcounter = 0 self.bcounter = 0 if not test: plt.show() plt.gcf().canvas.start_event_loop(timeout=-1)
[docs] def on_click(self, event): """Dummy function, in case I want to do something with a click.""" pass
[docs] def zap(self, event): """Create a zap interval.""" key = self.current if key is None: return self.info[key]["zap"].add([event.xdata, event.ydata]) self.zcounter += 1 color = "r" if self.zcounter % 2 == 1: ls = "-" else: ls = "--" line = self.ax1.axvline(event.xdata, color=color, ls=ls) line = self.ax2.axvline(event.xdata, color=color, ls=ls) self.lines.append(line) plt.draw() if self.test: warnings.warn( "I select a zap interval at {}".format(event.xdata), TestWarning, )
[docs] def base(self, event): """Add an interval to the ones that will be used by baseline sub.""" key = self.current if key is None: return self.info[key]["base"].add([event.xdata, event.ydata]) self.bcounter += 1 color = "b" if self.bcounter % 2 == 1: ls = "-" else: ls = "--" line = self.ax1.axvline(event.xdata, color=color, ls=ls) line = self.ax2.axvline(event.xdata, color=color, ls=ls) self.lines.append(line) plt.draw() if self.test: warnings.warn("I put a baseline mark at {}".format(event.xdata), TestWarning)
[docs] def on_key(self, event): """Do something when the keyboard is used.""" if event.key == "z": self.zap(event) elif event.key == "h": self.print_instructions() elif event.key == "b": self.base(event) elif event.key == "B": self.subtract_baseline() elif event.key == "u": self.plot_all() elif event.key == "x": self.flag() elif event.key == "P": self.print_info() elif event.key == "A": self.align_all() elif event.key == "v": self.flag(value=False) elif event.key == "r": self.reset() elif event.key == "q": plt.gcf().canvas.stop_event_loop() self.quit() else: pass
[docs] def flag(self, value=True): self.info[self.current]["FLAG"] = value logging.info("Scan was {}flagged".format("un" if not value else ""))
[docs] def reset(self): for line in self.lines: line.remove() for current in self.xs.keys(): self.lines = [] self.info[current]["zap"].clear() self.info[current]["base"].clear() self.info[current]["fitpars"] = np.array([0, 0]) self.info[current]["FLAG"] = None self.plot_all(silent=True)
[docs] def quit(self): logging.info("Closing all figures and quitting.") old = copy.deepcopy(self.info) for key in old: if compare_anything(self.info[key], self.starting_info[key]): self.info.pop(key) plt.close(plt.gcf())
[docs] def subtract_baseline(self): """Subtract the baseline based on the selected intervals.""" key = self.current if len(self.info[key]["base"].xs) < 2: self.info[key]["fitpars"] = np.array([np.min(self.ys[key]), 0]) else: base_xs = self.info[key]["base"].xs good = mask(self.xs[key], base_xs, invert=True) self.info[key]["fitpars"] = linear_fit( self.xs[key][good], self.ys[key][good], self.info[key]["fitpars"], ) self.plot_all(silent=True) if self.test: warnings.warn("I subtracted the baseline", TestWarning)
[docs] def subtract_model(self, channel): """Subtract the model from the scan.""" fitpars = list(self.info[channel]["fitpars"]) return self.ys[channel] - linear_fun(self.xs[channel], *fitpars)
[docs] def align_all(self): """Given the selected scan, aligns all the others to that.""" reference = self.current x = np.array(self.xs[reference]) y = np.array(self.subtract_model(reference)) zap_xs = self.info[reference]["zap"].xs good = mask(x, zap_xs) xs = [x[good]] ys = [y[good]] keys = [reference] for key in self.xs.keys(): if key == reference: continue x = np.array(self.xs[key].copy()) y = np.array(self.ys[key].copy()) zap_xs = self.info[key]["zap"].xs good = mask(x, zap_xs) good = good * self.masks[key] if len(x[good]) == 0: continue xs.append(x[good]) ys.append(y[good]) keys.append(key) # ------- Make FIT!!! ----- qs, ms = align(xs, ys) # ------------------------- for ik, key in enumerate(keys): if ik == 0: continue self.info[key]["fitpars"] = np.array([qs[ik - 1], ms[ik - 1]]) self.plot_all(silent=True) if self.test: # warnings.filterwarnings("default") warnings.warn("I aligned all", TestWarning)
[docs] def on_pick(self, event): """Do this when I pick a line in the plot.""" thisline = event.artist self.current = thisline._label self.plot_all(silent=True)
[docs] def plot_all(self, silent=False): """Plot everything.""" update_limits = False if self.lines: xlim_save = self.ax1.get_xlim() ylim_save = self.ax1.get_ylim() update_limits = True for line in self.lines: line.remove() self.lines = [] self.ax1.cla() plt.setp(self.ax1.get_xticklabels(), visible=False) good = {} model = {} if self.current is not None: self.ax1.plot( self.xs[self.current], self.ys[self.current], color="g", lw=3, zorder=10, rasterized=True, ) for key in self.xs.keys(): self.ax1.plot( self.xs[key], self.ys[key], color="k", picker=True, label=key, lw=1, rasterized=True, ) zap_xs = self.info[key]["zap"].xs # Eliminate zapped intervals plt.draw() good[key] = mask(self.xs[key], zap_xs) if self.info[key]["FLAG"] is True: good[key][:] = 0 elif self.info[key]["FLAG"] is False: # "v" eliminates all flags! good[key][:] = 1 self.masks[key][:] = 1 good[key] = good[key] * self.masks[key] fitpars = list(self.info[key]["fitpars"]) if len(fitpars) >= 2: model[key] = linear_fun(self.xs[key], *fitpars) self.ax1.plot(self.xs[key], model[key], color="b", rasterized=True) else: model[key] = np.zeros(len(self.xs[key])) + np.min(self.ys[key]) self.ax2.cla() self.ax2.axhline(0, ls="--", color="k") for key in self.xs.keys(): self.ax2.plot( self.xs[key], self.ys[key] - model[key], color="grey", ls="-", picker=True, label=key, rasterized=True, ) self.ax2.plot( self.xs[key][good[key]], self.ys[key][good[key]] - model[key][good[key]], "k-", lw=1, rasterized=True, ) if self.current is not None: logging.info("Current scan is {}".format(self.current)) key = self.current self.ax2.plot( self.xs[key][good[key]], self.ys[key][good[key]] - model[key][good[key]], color="g", lw=3, zorder=10, rasterized=True, ) if self.xlabel is not None: self.ax2.set_xlabel(self.xlabel) if update_limits: self.ax1.set_xlim(xlim_save) self.ax1.set_ylim(ylim_save) plt.draw() if self.test and not silent: warnings.warn("I plotted all", PlotWarning)
[docs] def print_instructions(self): """Print to terminal some instructions for the interactive window.""" print(self.instructions)
[docs] def print_info(self): """Print info on the current scan. Info includes zapped intervals and fit parameters. """ for key in self.info.keys(): print(key + ":") if len(self.info[key]["zap"].xs) >= 2: print( " Zap intervals: ", list( zip( self.info[key]["zap"].xs[:-1:2], self.info[key]["zap"].xs[1::2], ) ), ) print(" Fit pars: ", self.info[key]["fitpars"])
[docs]def select_data(xs, ys, masks=None, title=None, xlabel=None, test=False): """Open a DataSelector window.""" if not HAS_MPL: raise ImportError("matplotlib not installed") try: xs.keys() except Exception: xs = {"Ch": xs} ys = {"Ch": ys} if title is None: title = 'Data selector (press "h" for help)' plt.figure(title) plt.clf() gs = gridspec.GridSpec(2, 1, height_ratios=[3, 2], hspace=0) ax1 = plt.subplot(gs[0]) ax2 = plt.subplot(gs[1], sharex=ax1) datasel = DataSelector(xs, ys, ax1, ax2, masks=masks, title=title, xlabel=xlabel, test=test) return datasel.info
[docs]class ImageSelector: """Return xs and ys of the image, and the key that was pressed. Attributes ---------- img : array the image ax : pyplot.axis instance the axis where the image will be plotted fun : function the function to call when a key is pressed. It must accept three arguments: ``x``, ``y`` and ``key`` """ def __init__(self, data, ax, fun=None, test=False): """ Initialize an ImageSelector class. Parameters ---------- data : array the image ax : pyplot.axis instance the axis where the image will be plotted fun : function, optional the function to call when a key is pressed. It must accept three arguments: ``x``, ``y`` and ``key`` """ if not HAS_MPL: raise ImportError("matplotlib not installed") self.img = data self.ax = ax self.fun = fun self.plot_img() if not test: ax.figure.canvas.mpl_connect("key_press_event", self.on_key) plt.show()
[docs] def on_key(self, event): """Do this when the keyboard is pressed.""" x, y = event.xdata, event.ydata key = event.key logging.info("Pressed key {} at coords {},{}".format(key, x, y)) if key == "q": plt.close(plt.gcf()) elif x is None or y is None or x != x or y != y: logging.warning("Invalid choice. Is the window under focus?") return elif self.fun is not None: plt.close(plt.gcf()) self.fun(x, y, key) return x, y, key
[docs] def plot_img(self): """Plot the image on the interactive display.""" from .utils import ds9_like_log_scale img_to_plot = ds9_like_log_scale(self.img) self.ax.imshow( img_to_plot, origin="lower", vmin=np.percentile(img_to_plot, 20), interpolation="nearest", cmap="gnuplot2", extent=[0, self.img.shape[1], 0, self.img.shape[0]], )