Source code for nmrespy.estimators.jres

# jres.py
# Simon Hulse
# simon.hulse@chem.ox.ac.uk
# Last Edited: Fri 21 Jul 2023 12:36:09 BST

from __future__ import annotations
import copy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from nmrespy import ExpInfo, sig
from nmrespy.contour_app import ContourApp
from nmrespy.load import load_bruker
from nmrespy.estimators import logger
from nmrespy.estimators._proc_onedim import _Estimator1DProc
from nmrespy.plot import make_color_cycle
from nmrespy._colors import RED, GRE, END, USE_COLORAMA
from nmrespy._files import check_existent_dir, check_saveable_dir
from nmrespy._sanity import (
    sanity_check,
    funcs as sfuncs,
)


if USE_COLORAMA:
    import colorama
    colorama.init()


[docs] class Estimator2DJ(_Estimator1DProc): """Estimator class for J-Resolved (2DJ) datasets, enabling use of our CUPID method for Pure Shift spectra. For a tutorial on the basic functionailty this provides, see :ref:`ESTIMATOR2DJ`. .. note:: To create an instance of ``Estimator2DJ``, you are advised to use one of the following methods if any are appropriate: * :py:meth:`new_bruker` * :py:meth:`increment=i, new_spinach` * :py:meth:`from_pickle` (re-loads a previously saved estimator). """ dim = 2 twodim_dtype = "hyper" proc_dims = [1] ft_dims = [0, 1] default_mpm_trim = [256] default_nlp_trim = [1024] default_max_iterations_exact_hessian = 40 default_max_iterations_gn_hessian = 80
[docs] @classmethod def new_bruker( cls, directory: Union[str, Path], convdta: bool = True, ) -> Estimator2DJ: """Create a new instance from Bruker-formatted data. Parameters ---------- directory Absolute path to data directory. convdta If ``True``, removal of the FID's digital filter will be carried out, using the ``GRPDLY`` parameter. Notes ----- There are certain file paths expected to be found relative to ``directory`` which contain the data and parameter files: * ``directory/ser`` * ``directory/acqus`` * ``directory/acqu2s`` See also -------- :py:meth:`nmrespy.sig.convdta` """ sanity_check( ("directory", directory, check_existent_dir), ("convdta", convdta, sfuncs.check_bool), ) directory = Path(directory).expanduser() data, expinfo = load_bruker(directory) if data.ndim != 2: raise ValueError(f"{RED}Data dimension should be 2.{END}") if directory.parent.name == "pdata": raise ValueError(f"{RED}Importing pdata is not permitted.{END}") if convdta: grpdly = expinfo.parameters["acqus"]["GRPDLY"] data = sig.convdta(data, grpdly) expinfo._offset = (0., expinfo.offset()[1]) expinfo._sfo = (None, expinfo.sfo[1]) expinfo._default_pts = data.shape return cls(data, expinfo, directory)
[docs] @classmethod def new_spinach( cls, shifts: Iterable[float], couplings: Iterable[Tuple(int, int, float)], pts: Tuple[int, int], sw: Tuple[float, float], offset: float, field: float = 11.74, nucleus: str = "1H", snr: Optional[float] = 20., lb: Optional[Tuple[float, float]] = (6.91, 6.91), ) -> None: r"""Create a new instance from a 2DJ Spinach simulation. Parameters ---------- shifts A list of tuple of chemical shift values for each spin. couplings The scalar couplings present in the spin system. Given ``shifts`` is of length ``n``, couplings should be an iterable with entries of the form ``(i1, i2, coupling)``, where ``1 <= i1, i2 <= n`` are the indices of the two spins involved in the coupling, and ``coupling`` is the value of the scalar coupling in Hz. ``None`` will set all spins to be uncoupled. pts The number of points the signal comprises. sw The sweep width of the signal (Hz). offset The transmitter offset (Hz). field The magnetic field strength (T). nucleus The identity of the nucleus targeted in the pulse sequence. snr The signal-to-noise ratio of the resulting signal, in decibels. ``None`` produces a noiseless signal. lb Line broadening (exponential damping) to apply to the signal. The first point will be unaffected by damping, and the final point will be multiplied by ``np.exp(-lb)``. The default results in the final point being decreased in value by a factor of roughly 1000. """ sanity_check( ("shifts", shifts, sfuncs.check_float_list), ("pts", pts, sfuncs.check_int_list, (), {"length": 2, "min_value": 1}), ( "sw", sw, sfuncs.check_float_list, (), {"length": 2, "must_be_positive": 0.}, ), ("offset", offset, sfuncs.check_float), ("field", field, sfuncs.check_float, (), {"greater_than_zero": True}), ("nucleus", nucleus, sfuncs.check_nucleus), ("snr", snr, sfuncs.check_float), ( "lb", lb, sfuncs.check_float_list, (), {"length": 2, "must_be_positive": True}, ), ) nspins = len(shifts) sanity_check( ("couplings", couplings, sfuncs.check_spinach_couplings, (nspins,)), ) if couplings is None: couplings = [] fid, sfo = cls._run_spinach( "jres_sim", shifts, couplings, pts, sw, offset, field, nucleus, to_int=[2], to_double=[3, 4], ) fid = np.array(fid) fid = sig.phase(fid, (0., np.pi / 2), (0., 0.)) # Apply exponential damping for i, k in enumerate(lb): fid = sig.exp_apodisation(fid, k, axes=[i]) if snr is not None: fid = sig.add_noise(fid, snr) expinfo = ExpInfo( dim=2, sw=sw, offset=(0., offset), sfo=(None, sfo), nuclei=(None, nucleus), default_pts=fid.shape, ) return cls(fid, expinfo)
[docs] def view_data( self, domain: str = "freq", abs_: bool = True, ) -> None: r"""View the data FID or the spectral data with an interactive matplotlib figure. Parameters ---------- domain Must be ``"freq"`` or ``"time"``. abs\_ Whether or not to display frequency-domain data in absolute-value mode, as is conventional with 2DJ data. """ sanity_check( ("domain", domain, sfuncs.check_one_of, ("freq", "time")), ("abs_", abs_, sfuncs.check_bool), ) if domain == "freq": spectrum = np.abs(self.spectrum_sinebell) if abs_ else self.spectrum app = ContourApp(spectrum, self.expinfo) app.mainloop() elif domain == "time": fig = plt.figure() ax = fig.add_subplot(projection="3d") x, y = self.get_timepoints() xlabel, ylabel = [f"$t_{i}$ (s)" for i in range(1, 3)] ax.plot_wireframe(x, y, self.data) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xlim(reversed(ax.get_xlim())) ax.set_ylim(reversed(ax.get_ylim())) ax.set_zticks([]) plt.show()
@property def spectrum_tilt(self) -> np.ndarray: """Generate the spectrum of the data with a 45° tilt.""" spectrum = np.abs(self.spectrum_sinebell).real sw1, sw2 = self.sw() n1, n2 = self.default_pts tilt_factor = (sw1 * n2) / (sw2 * n1) for i, row in enumerate(spectrum): spectrum[i] = np.roll( row, shift=int(tilt_factor * (n1 // 2 - i)), ) return spectrum @property def spectrum_sinebell(self) -> np.ndarray: """Spectrum with sine-bell apodisation. Generated applying sine-bell apodisation to the FID, and applying FT. """ data = copy.deepcopy(self.data) data[0, 0] *= 0.5 data = sig.sinebell_apodisation(data) return sig.ft(data) @property def default_multiplet_thold(self) -> float: r"""The default margin for error when determining oscillators which belong to the same multiplet. Given by :math:`f_{\text{sw}}^{(1)} / 2 N^{(1)}` (i.e. half the spetral resolution in the indirect dimension). """ return 0.5 * (self.sw()[0] / self.default_pts[0])
[docs] @logger def cupid_signal( self, indices: Optional[Iterable[int]] = None, pts: Optional[int] = None, _log: bool = True, ) -> np.ndarray: r"""Generate the signal :math:`y_{-45^{\circ}}(t)`, where :math:`t \geq 0`: .. math:: y_{-45^{\circ}}(t) = \sum_{m=1}^M a_m \exp\left( \mathrm{i} \phi_m \right) \exp\left( \left(2 \mathrm{i} \pi \left(f^{(2)}_m - f^{(1)}_m \right) - \eta^{(2)}_m \right) t \right) Producing this signal from parameters derived from estimation of a 2DJ dataset should generate an absorption-mode 1D homodecoupled spectrum. Parameters ---------- indices See :ref:`INDICES`. pts The number of points to construct the signal from. If ``None``, ``self.default_pts`` will be used. """ self._check_results_exist() sanity_check( self._indices_check(indices), ("pts", pts, sfuncs.check_int, (), {"min_value": 1}, True), ) params = self.get_params(indices) offset = self.offset()[1] if pts is None: pts = self.default_pts[1] tp = self.get_timepoints(pts=(1, pts), meshgrid=False)[1] f1 = params[:, 2] f2 = params[:, 3] signal = np.einsum( "ij,j->i", np.exp( np.outer( tp, 2j * np.pi * (f2 - f1 - offset) - params[:, 5], ) ), params[:, 0] * np.exp(1j * params[:, 1]) ) return signal
[docs] @logger def cupid_spectrum( self, indices: Optional[Iterable[int]] = None, pts: Optional[int] = None, _log: bool = True, ) -> np.ndarray: """Generate a homodecoupled spectrum according to the CUPID method. This generates an FID using :py:meth:`cupid_signal`, halves the first point, and applies FT. Parameters ---------- indices See :ref:`INDICES`. pts The number of points to construct the signal from. If ``None``, ``self.default_pts`` will be used. _log Ignore this! """ self._check_results_exist() sanity_check( self._indices_check(indices), ("pts", pts, sfuncs.check_int, (), {"min_value": 1}, True), ) fid = self.cupid_signal(indices=indices, pts=pts) fid[0] *= 0.5 return sig.ft(fid)
[docs] @logger def predict_multiplets( self, indices: Optional[Iterable[int]] = None, thold: Optional[float] = None, freq_unit: str = "hz", rm_spurious: bool = False, _log: bool = True, **estimate_kwargs, ) -> Dict[float, Iterable[int]]: r"""Predict the estimated oscillators which correspond to each multiplet in the signal. Parameters ---------- indices See :ref:`INDICES`. thold Frequency threshold for multiplet prediction. All oscillators that make up a multiplet are assumed to obey the following expression: .. math:: f_c - f_t < f^{(2)} - f^{(1)} < f_c + f_t where :math:`f_c` is the central frequency of the multiplet, and `f_t` is the threshold. freq_unit Must be ``"hz"`` or ``"ppm"``. rm_spurious If set to ``True``, all oscillators which fall into the following criteria will be purged: * The oscillator is the only member in a multiplet set. * The oscillator's frequency in F1 has a magnitude greater than ``thold`` (i.e. the indirect-dimension frequency is sufficiently far from 0Hz) _log Ignore me! estimate_kwargs If ``rm_suprious`` is ``True``, and oscillators are purged, optimisation isrun. Kwargs are given to :py:meth:estimate:\. Returns ------- Dict[float, Iterable[int]] A dictionary with keys as the multiplet's central frequency, and values as a list of oscillator indices which make up the multiplet. """ self._check_results_exist() sanity_check( self._indices_check(indices), ("thold", thold, sfuncs.check_float, (), {"greater_than_zero": True}, True), ("freq_unit", freq_unit, sfuncs.check_frequency_unit, (self.hz_ppm_valid,)), ("rm_spurious", rm_spurious, sfuncs.check_bool), ) if thold is None: thold = self.default_multiplet_thold params = self.get_params(indices) multiplets = {} in_range = lambda f, g: (g - thold < f < g + thold) for i, osc in enumerate(params): centre_freq = osc[3] - osc[2] assigned = False for freq in multiplets: if in_range(centre_freq, freq): multiplets[freq].append(i) assigned = True break if not assigned: multiplets[centre_freq] = [i] # Set center freqs to average f2 - f1 in the multiplet for old_freq, mp_indices in list(multiplets.items()): new_freq = np.mean(params[mp_indices, 3] - params[mp_indices, 2]) multiplets[new_freq] = multiplets.pop(old_freq) # Remove spurious opscillators, if requested if rm_spurious: spurious = {} for oscs in multiplets.values(): if len(oscs) == 1: osc = oscs[0] f1 = params[osc, 2] if abs(f1) > thold: # osc_loc is a tuple of the form (result_index, osc_index) osc_loc = self.find_osc(params[osc]) if osc_loc[0] in spurious: spurious[osc_loc[0]].append(osc_loc[1]) else: spurious[osc_loc[0]] = [osc_loc[1]] for res_idx, osc_idx in spurious.items(): self.edit_result(index=res_idx, rm_oscs=osc_idx, **estimate_kwargs) factor = 1. if freq_unit == "hz" else self.sfo[-1] multiplets = { freq / factor: indices for freq, indices in sorted(multiplets.items(), key=lambda item: item[0]) } return multiplets
[docs] def get_multiplet_integrals( self, scale: bool = True, **kwargs, ) -> Dict[float, float]: """Get integrals of multiplets assigned using :py:meth:`predict_multiplets`. Parameters ---------- scale If ``True``, the integrals are scaled so that the smallest integral is 1. kwargs Keyword arguments for :py:meth:`predict_multiplet_integrals`. """ self._check_results_exist() sanity_check(("scale", scale, sfuncs.check_bool)) multiplets = self.predict_multiplets(**kwargs) indices = self._process_indices(kwargs.get("indices", None)) params = self.get_params(indices) integrals = { freq: sum(self.oscillator_integrals(params[mp])) for freq, mp in list(multiplets.items()) } if scale: min_integral = min(list(integrals.values())) integrals = { freq: integral / min_integral for freq, integral in list(integrals.items()) } return integrals
[docs] @logger def sheared_signal( self, indices: Optional[Iterable[int]] = None, pts: Optional[Tuple[int, int]] = None, indirect_modulation: Optional[str] = None, ) -> np.ndarray: r"""Return an FID where direct dimension frequencies are perturbed such that: .. math:: f^{(2)}_m = f^{(2)}_m - f^{(1)}_m This should yeild a signal where all components in a multiplet are centered at the spin's chemical shift in the direct dimenion, akin to performing a 45° tilt. Parameters ---------- indices See :ref:`INDICES`. pts The number of points to construct the signal from. If ``None``, ``self.default_pts`` will be used. indirect_modulation Acquisition mode in the indirect dimension. * ``None`` - hypercomplex dataset: .. math:: y \left(t_1, t_2\right) = \sum_{m} a_m e^{\mathrm{i} \phi_m} e^{\left(2 \pi \mathrm{i} f_{1, m} - \eta_{1, m}\right) t_1} e^{\left(2 \pi \mathrm{i} f_{2, m} - \eta_{2, m}\right) t_2} * ``"amp"`` - amplitude modulated pair: .. math:: y_{\mathrm{cos}} \left(t_1, t_2\right) = \sum_{m} a_m e^{\mathrm{i} \phi_m} \cos\left(\left(2 \pi \mathrm{i} f_{1, m} - \eta_{1, m}\right) t_1\right) e^{\left(2 \pi \mathrm{i} f_{2, m} - \eta_{2, m}\right) t_2} .. math:: y_{\mathrm{sin}} \left(t_1, t_2\right) = \sum_{m} a_m e^{\mathrm{i} \phi_m} \sin\left(\left(2 \pi \mathrm{i} f_{1, m} - \eta_{1, m}\right) t_1\right) e^{\left(2 \pi \mathrm{i} f_{2, m} - \eta_{2, m}\right) t_2} * ``"phase"`` - phase-modulated pair: .. math:: y_{\mathrm{P}} \left(t_1, t_2\right) = \sum_{m} a_m e^{\mathrm{i} \phi_m} e^{\left(2 \pi \mathrm{i} f_{1, m} - \eta_{1, m}\right) t_1} e^{\left(2 \pi \mathrm{i} f_{2, m} - \eta_{2, m}\right) t_2} .. math:: y_{\mathrm{N}} \left(t_1, t_2\right) = \sum_{m} a_m e^{\mathrm{i} \phi_m} e^{\left(-2 \pi \mathrm{i} f_{1, m} - \eta_{1, m}\right) t_1} e^{\left(2 \pi \mathrm{i} f_{2, m} - \eta_{2, m}\right) t_2} ``None`` will lead to an array of shape ``(n1, n2)``. ``amp`` and ``phase`` will lead to an array of shape ``(2, n1, n2)``, with ``fid[0]`` and ``fid[1]`` being the two components of the pair. See also -------- * For converting amplitude-modulated data to spectral data, see :py:func:`nmrespy.sig.proc_amp_modulated` * For converting phase-modulated data to spectral data, see :py:func:`nmrespy.sig.proc_phase_modulated` """ self._check_results_exist() sanity_check( ( "indices", indices, sfuncs.check_index, (len(self._results),), {}, True, ), ("pts", pts, sfuncs.check_int, (), {"min_value": 1}, True), ) edited_params = copy.deepcopy(self.get_params(indices)) edited_params[:, 3] -= edited_params[:, 2] return self.make_fid( edited_params, pts=pts, indirect_modulation=indirect_modulation, )
[docs] def construct_multiplet_fids( self, indices: Optional[Iterable[int]] = None, pts: Optional[int] = None, thold: Optional[float] = None, freq_unit: str = "hz", ) -> Iterable[np.ndarray]: """Generate a list of FIDs corresponding to each multiplet structure. Parameters ---------- indices See :ref:`INDICES`. pts The number of points to construct the mutliplets from. thold Frequency threshold for multiplet prediction. All oscillators that make up a multiplet are assumed to obey the following expression: .. math:: f_c - f_t < f^{(2)} - f^{(1)} < f_c + f_t where :math:`f_c` is the central frequency of the multiplet, and `f_t` is ``thold`` freq_unit Must be ``"hz"`` or ``"ppm"``. Returns ------- List of numpy arrays with each FID. """ # TODO: CHECKING multiplets = self.predict_multiplets( indices=indices, thold=thold, freq_unit=freq_unit, ) # Sort by frequency multiplets = sorted(list(multiplets.items()), key=lambda item: item[0]) full_params = self.get_params(indices=indices)[:, [0, 1, 3, 5]] expinfo_direct = self.expinfo_direct fids = [] for (_, idx) in multiplets: params = full_params[idx] fids.append( expinfo_direct.make_fid( params=params, pts=pts, ) ) return fids
[docs] def write_multiplets_to_bruker( self, path: Union[Path, str], expno_prefix: Optional[int] = None, indices: Optional[Iterable[int]] = None, pts: Optional[int] = None, thold: Optional[float] = None, force_overwrite: bool = False, ) -> None: """Write each individual multiplet structure to a Bruker data directory. Each multiplet is saved to a directory of the form ``<path>/<expno_prefix><x>/pdata/1`` where ``<x>`` is iterated from 1 onwards. Parameters ---------- path The path to the root directory to store the data in. expinfo_prefix Prefix to the experiment numbers for storing the multiplets to. If ``None``, experiments will be numbered ``1``, ``2``, etc. indices See :ref:`INDICES`. pts The number of points to construct the mutliplets from. thold Frequency threshold for multiplet prediction. All oscillators that make up a multiplet are assumed to obey the following expression: .. math:: f_c - f_t < f^{(2)} - f^{(1)} < f_c + f_t where :math:`f_c` is the central frequency of the multiplet, and `f_t` is ``thold`` force_overwite If ``False``, if any directories that will be written to already exist, you will be promted if you are happy to overwrite. If ``True``, overwriting will take place without asking. """ self._check_results_exist() sanity_check( ("path", path, check_saveable_dir, (True,)), ( "expno_prefix", expno_prefix, sfuncs.check_int, (), {"min_value": 1}, True, ), self._indices_check(indices), # Not a "normal" pts check as checking for valid 1D value rather than 2D ( "pts", pts, sfuncs.check_int, (), {"min_value": 1}, self.default_pts is not None, ), ("thold", thold, sfuncs.check_float, (), {"greater_than_zero": True}, True), ) path = Path(path).expanduser() fids = self.construct_multiplet_fids( indices=indices, pts=pts, thold=thold, ) # Establish list of expno names n = len(fids) if expno_prefix is None: first = 1 else: ndigits = int(np.log10(n)) + 1 first = expno_prefix * (10 ** ndigits) + 1 expnos = list(range(first, first + n)) if not force_overwrite: for expno in expnos: sanity_check( ( "expno_prefix", path / str(expno), check_saveable_dir, (False,), ), ) expinfo_1d = self.expinfo_direct for (fid, expno) in zip(fids, expnos): expinfo_1d.write_to_bruker( fid, path, expno=expno, procno=1, force_overwrite=True, ) print( f"{GRE}Saved multiplets to folders {path}/[{expnos[0]}-{expnos[-1]}]/" f"{END}" )
[docs] def write_cupid_to_bruker( self, path: Union[Path, str], expno: Optional[int] = None, indices: Optional[Iterable[int]] = None, pts: Optional[int] = None, force_overwrite: bool = False, ) -> None: """Write the signal generated by :py:meth:`cupid_signal` to a Bruker dataset. The dataset is saved to a directory of the form ``<path>/<expno>`` Parameters ---------- path The path to the root directory to store the data in. This must already exist. expno The experiment number. If ``None``, the first directory number ``<x>`` for which ``<path>/<x>/`` isn;t currently a directory will be used. indices See :ref:`INDICES`. pts The number of points to construct the dataset from. force_overwite If ``False``, and ``<path>/<expno>/`` already exists, the user will be asked if they are happy to overwrite. If ``True``, overwriting will take place without asking. """ self._check_results_exist() sanity_check( ("path", path, check_saveable_dir, (True,)), ( "expno", expno, sfuncs.check_int, (), {"min_value": 1}, True, ), self._indices_check(indices), # Not a "normal" pts check as checking for valid 1D value rather than 2D ( "pts", pts, sfuncs.check_int, (), {"min_value": 1}, self.default_pts is not None, ), ("force_overwrite", force_overwrite, sfuncs.check_bool), ) fid = self.cupid_signal(indices=indices, pts=pts, _log=False) expinfo_1d = self.expinfo_direct expinfo_1d.write_to_bruker( fid, path, expno=expno, procno=1, force_overwrite=force_overwrite, ) print(f"{GRE}Saved CUPID signal to {path}/{expno}/{END}")
[docs] @logger def plot_result( self, indices: Optional[Iterable[int]] = None, multiplet_thold: Optional[float] = None, high_resolution_pts: Optional[int] = None, ratio_1d_2d: Tuple[float, float] = (2., 1.), region_unit: str = "hz", axes_left: float = 0.07, axes_right: float = 0.96, axes_bottom: float = 0.08, axes_top: float = 0.96, axes_region_separation: float = 0.05, xaxis_label_height: float = 0.02, xaxis_ticks: Optional[Iterable[Tuple[int, Iterable[float]]]] = None, contour_base: Optional[float] = None, contour_nlevels: Optional[int] = None, contour_factor: Optional[float] = None, contour_lw: float = 0.5, contour_color: Any = "k", jres_sinebell: bool = True, multiplet_colors: Any = None, multiplet_lw: float = 1., multiplet_vertical_shift: float = 0., multiplet_show_center_freq: bool = True, multiplet_show_45: bool = True, marker_size: float = 3., marker_shape: str = "o", label_peaks: bool = False, denote_regions: bool = False, **kwargs, ) -> Tuple[mpl.figure.Figure, np.ndarray[mpl.axes.Axes]]: r"""Generate a figure of the estimation result. The figure includes a contour plot of the 2DJ spectrum, a 1D plot of the first slice through the indirect dimension, plots of estimated multiplets, and a plot of :py:meth:`cupid_spectrum`. Parameters ---------- indices See :ref:`INDICES`. multiplet_thold Frequency threshold for multiplet prediction. All oscillators that make up a multiplet are assumed to obey the following expression: .. math:: f_c - f_t < f^{(2)} - f^{(1)} < f_c + f_t where :math:`f_c` is the central frequency of the multiplet, and `f_t` is the threshold. high_resolution_pts Indicates the number of points used to generate the multiplet structures and :py:meth:`cupid_spectrum`. Should be greater than or equal to ``self.default_pts[1]``. ratio_1d_2d The relative heights of the regions containing the 1D spectra and the 2DJ spectrum. axes_left The position of the left edge of the axes, in `figure coordinates <https://matplotlib.org/stable/tutorials/advanced/\ transforms_tutorial.html>`_\. Should be between ``0.`` and ``1.``. axes_right The position of the right edge of the axes, in figure coordinates. Should be between ``0.`` and ``1.``. axes_top The position of the top edge of the axes, in figure coordinates. Should be between ``0.`` and ``1.``. axes_bottom The position of the bottom edge of the axes, in figure coordinates. Should be between ``0.`` and ``1.``. axes_region_separation The extent by which adjacent regions are separated in the figure. xaxis_label_height The vertical location of the x-axis label, in figure coordinates. Should be between ``0.`` and ``1.``, though you are likely to want this to be only slightly larger than ``0.``. xaxis_ticks See :ref:`XAXIS_TICKS`. contour_base The lowest level for the contour levels in the 2DJ spectrum plot. contour_nlevels The number of contour levels in the 2DJ spectrum plot. contour_factor The geometric scaling factor for adjacent contours in the 2DJ spectrum plot. contour_lw The linewidth of contours in the 2DJ spectrum plot. contour_color The color of the 2DJ spectrum plot. jres_sinebell If ``True``, applies sine-bell apodisation to the 2DJ spectrum. multiplet_colors Describes how to color multiplets. See :ref:`COLOR_CYCLE` for options. multiplet_lw Line width of multiplet spectra multiplet_vertical_shift The vertical displacement of adjacent mutliplets, as a multiple of ``mutliplet_lw``. Set to ``0.`` if you want all mutliplets to lie on the same line. multiplet_show_center_freq If ``True``, lines are plotted on the 2DJ spectrum indicating the central frequency of each mutliplet. multiplet_show_45 If ``True``, lines are plotted on the 2DJ spectrum indicating the 45° line along which peaks lie in each multiplet. marker_size The size of markers indicating positions of peaks on the 2DJ contour plot. marker_shape The `shape of markers <https://matplotlib.org/stable/api/markers_api.html>`_ indicating positions of peaks on the 2DJ contour plot. kwargs Keyword arguments provided to `matplotlib.pyplot.figure <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html\ #matplotlib.pyplot.figure>`_\. Allowed arguments include ``figsize``, ``facecolor``, ``edgecolor``, etc. Returns ------- fig The result figure. This can be saved to various formats using the `savefig <https://matplotlib.org/stable/api/figure_api.html\ #matplotlib.figure.Figure.savefig>`_ method. axs A ``(2, N)`` NumPy array of the axes used for plotting. The first row of axes contain the 1D plots. The second row contain the 2DJ contour plots. """ sanity_check( ( "indices", indices, sfuncs.check_int_list, (), { "len_one_can_be_listless": True, "min_value": -len(self._results), "max_value": len(self._results) - 1, }, True, ), ( "multiplet_thold", multiplet_thold, sfuncs.check_float, (), {"greater_than_zero": True}, True, ), ( "high_resolution_pts", high_resolution_pts, sfuncs.check_int, (), {"min_value": self.default_pts[1]}, True, ), ( "ratio_1d_2d", ratio_1d_2d, sfuncs.check_float_list, (), {"length": 2, "must_be_positive": True}, ), self._funit_check(region_unit, "region_unit"), ( "axes_left", axes_left, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "axes_right", axes_right, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "axes_bottom", axes_bottom, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "axes_top", axes_top, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "axes_region_separation", axes_region_separation, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "xaxis_label_height", xaxis_label_height, sfuncs.check_float, (), {"min_value": 0., "max_value": 1.}, ), ( "contour_base", contour_base, sfuncs.check_float, (), {"min_value": 0.}, True, ), ( "contour_nlevels", contour_nlevels, sfuncs.check_int, (), {"min_value": 1}, True, ), ( "contour_factor", contour_factor, sfuncs.check_float, (), {"min_value": 1.}, True, ), ("contour_lw", contour_lw, sfuncs.check_float, (), {"min_value": 0.}), ("jres_sinebell", jres_sinebell, sfuncs.check_bool), ("marker_size", marker_size, sfuncs.check_float, (), {"min_value": 0.}), ( "multiplet_colors", multiplet_colors, sfuncs.check_oscillator_colors, (), {}, True, ), ("multiplet_lw", multiplet_lw, sfuncs.check_float, (), {"min_value": 0.}), ( "multiplet_vertical_shift", multiplet_vertical_shift, sfuncs.check_float, (), {"min_value": 0.}, ), ( "multiplet_show_center_freq", multiplet_show_center_freq, sfuncs.check_bool, ), ("multiplet_show_45", multiplet_show_45, sfuncs.check_bool), ("denote_regions", denote_regions, sfuncs.check_bool), ) # TODO # contour_color # linewidth # marker_shape: str = "o", indices = self._process_indices(indices) regions = sorted( [ (i, result.get_region(unit=region_unit)[1]) for i, result in enumerate(self.get_results()) if i in indices ], key=lambda x: x[1][0], reverse=True, ) # Megre overlapping/bordering regions merge_indices = [] merge_regions = [] for idx, region in regions: assigned = False for i, reg in enumerate(merge_regions): if max(region) >= min(reg): merge_regions[i] = (max(reg), min(region)) assigned = True elif min(region) >= max(reg): merge_regions[i] = (max(region), min(reg)) assigned = True if assigned: merge_indices[i].append(idx) break if not assigned: merge_indices.append([idx]) merge_regions.append(region) n_regions = len(merge_regions) fig, axs = plt.subplots( nrows=2, ncols=n_regions, gridspec_kw={ "left": axes_left, "right": axes_right, "bottom": axes_bottom, "top": axes_top, "wspace": axes_region_separation, "hspace": 0., "width_ratios": [r[0] - r[1] for r in merge_regions], "height_ratios": ratio_1d_2d, }, **kwargs, ) if n_regions == 1: axs = axs.reshape(2, 1) if all( [isinstance(x, (float, int)) for x in (contour_base, contour_nlevels, contour_factor)] ): contour_levels = [ contour_base * contour_factor ** i for i in range(contour_nlevels) ] else: contour_levels = None if high_resolution_pts is None: high_resolution_pts = self.default_pts[1] expinfo_1d = self.expinfo_direct expinfo_1d_highres = copy.deepcopy(expinfo_1d) expinfo_1d_highres.default_pts = (high_resolution_pts,) full_shifts_1d, = expinfo_1d.get_shifts(unit=region_unit) full_shifts_1d_highres, = expinfo_1d_highres.get_shifts(unit=region_unit) full_shifts_2d_y, full_shifts_2d_x = self.get_shifts(unit=region_unit) sfo = self.sfo[1] shifts_2d = [] shifts_1d = [] shifts_1d_highres = [] spectra_2d = [] spectra_1d = [] neg_45_spectra = [] f1_f2 = [] center_freqs = [] multiplet_spectra = [] multiplet_indices = [] conv = f"{region_unit}->idx" full_spectrum = np.abs( self.spectrum_sinebell if jres_sinebell else self.spectrum ).real for idx, region in zip(merge_indices, merge_regions): slice_ = slice(*expinfo_1d.convert([region], conv)[0]) highres_slice = slice(*expinfo_1d_highres.convert([region], conv)[0]) shifts_2d.append( (full_shifts_2d_x[:, slice_], full_shifts_2d_y[:, slice_]) ) shifts_1d.append(full_shifts_1d[slice_]) shifts_1d_highres.append(full_shifts_1d_highres[highres_slice]) spectra_2d.append(np.abs(full_spectrum).real[:, slice_]) spectra_1d.append(self.spectrum_first_direct.real[slice_]) neg_45_spectra.append( self.cupid_spectrum( indices=idx, pts=high_resolution_pts, _log=False, ).real[highres_slice] ) params = self.get_params(indices=idx) multiplet_indices.append( list( reversed( self.predict_multiplets( indices=idx, thold=multiplet_thold, _log=False, ).values() ) ) ) multiplet_params = [params[i] for i in multiplet_indices[-1]] f1_f2_region = [] center_freq = [] for multiplet_param in multiplet_params: f1, f2 = multiplet_param[:, [2, 3]].T cf = np.mean(f2 - f1) f2 = f2 / sfo if region_unit == "ppm" else f2 cf = cf / sfo if region_unit == "ppm" else cf center_freq.append(cf) f1_f2_region.append((f1, f2)) multiplet = expinfo_1d.make_fid( multiplet_param[:, [0, 1, 3, 5]], pts=high_resolution_pts, ) multiplet[0] *= 0.5 multiplet_spectra.append(sig.ft(multiplet).real) f1_f2.append(f1_f2_region) center_freqs.append(center_freq) print(center_freqs) n_multiplets = len(multiplet_spectra) # Plot individual mutliplets for ax in axs[0]: colors = make_color_cycle(multiplet_colors, n_multiplets) ymax = -np.inf for i, mp_spectrum in enumerate(multiplet_spectra): color = next(colors) x = n_multiplets - 1 - i line = ax.plot( full_shifts_1d_highres, mp_spectrum + multiplet_vertical_shift * x, color=color, lw=multiplet_lw, zorder=i, )[0] line_max = np.amax(line.get_ydata()) if line_max > ymax: ymax = line_max i += 1 # Plot 1D spectrum spec_1d_low_pt = min([np.amin(spec) for spec in spectra_1d]) shift = 1.03 * (ymax - spec_1d_low_pt) ymax = -np.inf for ax, shifts, spectrum in zip(axs[0], shifts_1d, spectra_1d): line = ax.plot(shifts, spectrum + shift, color="k")[0] line_max = np.amax(line.get_ydata()) if line_max > ymax: ymax = line_max # Plot homodecoupled spectrum homo_spec_low_pt = min([np.amin(spec) for spec in neg_45_spectra]) shift = 1.03 * (ymax - homo_spec_low_pt) for ax, shifts, spectrum in zip(axs[0], shifts_1d_highres, neg_45_spectra): ax.plot(shifts, spectrum + shift, color="k") # Plot 2DJ contour for ax, shifts, spectrum in zip(axs[1], shifts_2d, spectra_2d): ax.contour( *shifts, spectrum, colors=contour_color, linewidths=contour_lw, levels=contour_levels, zorder=0, ) # Plot peak positions onto 2DJ colors = make_color_cycle(multiplet_colors, n_multiplets) for ax, f1f2, mp_idxs in zip(axs[1], f1_f2, multiplet_indices): for mp_f1f2, mp_idx in zip(f1f2, mp_idxs): color = next(colors) f1, f2 = mp_f1f2 ax.scatter( x=f2, y=f1, s=marker_size, marker=marker_shape, color=color, edgecolor="none", zorder=100, ) if label_peaks: for f1_, f2_, idx in zip(f1, f2, mp_idx): ax.text( x=f2_, y=f1_, s=str(idx), color=color, fontsize=8, clip_on=True, ) ylim1 = (shifts_2d[0][1][0, 0], shifts_2d[0][1][-1, 0]) # Plot multiplet central frequencies if multiplet_show_center_freq: colors = make_color_cycle(multiplet_colors, n_multiplets) for ax, center_freq in zip(axs[1], center_freqs): for cf in center_freq: color = next(colors) ax.plot( [cf, cf], ylim1, color=color, lw=0.8, zorder=2, ) # Plot 45 lines that multiplets lie along if multiplet_show_45: colors = make_color_cycle(multiplet_colors, n_multiplets) for ax, center_freq in zip(axs[1], center_freqs): for cf in center_freq: color = next(colors) ax.plot( [cf + lim / (sfo if region_unit == "ppm" else 1.) for lim in ylim1], ylim1, color=color, lw=0.8, zorder=2, ls=":", ) # Configure axis appearance ylim0 = ( min([ax.get_ylim()[0] for ax in axs[0]]), max([ax.get_ylim()[1] for ax in axs[0]]), ) if denote_regions: for i, mi in enumerate(merge_indices): if len(mi) > 1: locs_to_plot = [reg[1][0] for reg in regions if reg[0] in mi[1:]] for loc in locs_to_plot: for j, y in enumerate((ylim0, ylim1)): axs[j, i].plot( [loc, loc], y, color="#808080", ls=":", ) axs[0, 0].spines["left"].set_zorder(1000) axs[0, -1].spines["right"].set_zorder(1000) self._configure_axes( fig, axs, merge_regions, xaxis_ticks, axes_left, axes_right, xaxis_label_height, region_unit, ) for ax in axs[0]: ax.set_ylim(ylim0) for ax in axs[1]: ax.set_ylim(ylim1) axs[1, 0].set_ylabel("Hz") return fig, axs
[docs] def edit_result( self, index: int = -1, add_oscs: Optional[np.ndarray] = None, rm_oscs: Optional[Iterable[int]] = None, merge_oscs: Optional[Iterable[Iterable[int]]] = None, split_oscs: Optional[Dict[int, Optional[Dict]]] = None, mirror_oscs: Optional[Iterable[int]] = None, **estimate_kwargs, ) -> None: r"""Manipulate an estimation result. After the result has been changed, it is subjected to optimisation. There are five types of edit that you can make: * *Add* new oscillators with defined parameters. * *Remove* oscillators. * *Merge* multiple oscillators into a single oscillator. * *Split* an oscillator into many oscillators. * **Unique to 2DJ**: *Mirror* an oscillator. This allows you add a new oscillator with the same parameters as an osciallator in the result, except with the following frequencies: .. math:: f^{(1)}_{\text{new}} = -f^{(1)}_{\text{old}} .. math:: f^{(2)}_{\text{new}} = f^{(2)}_{\text{old}} - f^{(1)}_{\text{old}} Parameters ---------- index See :ref:`INDEX`. add_oscs The parameters of new oscillators to be added. Should be of shape ``(n, 2 * (1 + self.dim))``, where ``n`` is the number of new oscillators to add. Even when one oscillator is being added this should be a 2D array, i.e. * 1D data: .. code:: params = np.array([[a, φ, f, η]]) * 2D data: .. code:: params = np.array([[a, φ, f₁, f₂, η₁, η₂]]) rm_oscs An iterable of ints for the indices of oscillators to remove from the result. merge_oscs An iterable of iterables. Each sub-iterable denotes the indices of oscillators to merge together. For example, ``[[0, 2], [6, 7]]`` would mean that oscillators 0 and 2 are merged, and oscillators 6 and 7 are merged. A merge involves removing all the oscillators, and creating a new oscillator with the sum of amplitudes, and the average of phases, freqeuncies and damping factors. split_oscs A dictionary with ints as keys, denoting the oscillators to split. The values should themselves be dicts, with the following permitted key/value pairs: * ``"separation"`` - An list of length equal to ``self.dim``. Indicates the frequency separation of the split oscillators in Hz. If not specified, this will be the spectral resolution in each dimension. * ``"number"`` - An int indicating how many oscillators to split into. If not specified, this will be ``2``. * ``"amp_ratio"`` A list of floats with length equal to the number of oscillators to be split into (see ``"number"``). Specifies the relative amplitudes of the oscillators. If not specified, the amplitudes will be equal. As an example for a 1D estimator: .. code:: split_oscs = { 2: { "separation": 1., # if 1D, don't need a list }, 5: { "number": 3, "amp_ratio": [1., 2., 1.], }, } Here, 2 oscillators will be split. * Oscillator 2 will be split into 2 (default) oscillators with equal amplitude (default). These will be separated by 1Hz. * Oscillator 5 will be split into 3 oscillators with relative amplitudes 1:2:1. These will be separated by ``self.sw()[0] / self.default_pts()[0]`` Hz (default). mirror_oscs An interable of oscillators to mirror (see the description above). estimate_kwargs Keyword arguments to provide to the call to :py:meth:`estimate`. Note that ``"initial_guess"`` and ``"region_unit"`` are set internally and will be ignored if given. """ sanity_check(self._index_check(index)) index, = self._process_indices([index]) result, = self.get_results(indices=[index]) params = result.get_params() max_osc_idx = len(params) - 1 sanity_check( ( "mirror_oscs", mirror_oscs, sfuncs.check_int_list, (), {"min_value": 0, "max_value": max_osc_idx}, True, ), ) if mirror_oscs is not None: to_mirror = params[mirror_oscs] mirrored = copy.deepcopy(to_mirror) mirrored[:, 2] = -mirrored[:, 2] mirrored[:, 3] += mirrored[:, 2] if isinstance(add_oscs, np.ndarray): add_oscs = np.vstack((add_oscs, mirrored)) else: add_oscs = mirrored super().edit_result( index, add_oscs, rm_oscs, merge_oscs, split_oscs, **estimate_kwargs, )