# onedim.py
# Simon Hulse
# simon.hulse@chem.ox.ac.uk
# Last Edited: Wed 24 May 2023 10:59:32 BST
from __future__ import annotations
import copy
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import nmrespy as ne
from nmrespy.load import load_bruker
from nmrespy.plot import make_color_cycle
from nmrespy._colors import RED, END, USE_COLORAMA
from nmrespy._files import check_existent_dir, check_saveable_dir
from nmrespy._misc import proc_kwargs_dict
from nmrespy._sanity import (
sanity_check,
funcs as sfuncs,
)
from nmrespy.estimators import logger
from nmrespy.estimators._proc_onedim import _Estimator1DProc
# Initialise colorama only when nmrespy's USE_COLORAMA flag is set.
# NOTE(review): presumably needed so the RED/END escape codes used in error
# messages below render correctly on terminals that require colorama.
if USE_COLORAMA:
    import colorama
    colorama.init()
class Estimator1D(_Estimator1DProc):
    """Estimator class for 1D data. For a tutorial on the basic functionality
    this provides, see :ref:`ESTIMATOR1D`.

    .. note::

        To create an instance of ``Estimator1D``, you are advised to use one of
        the following methods if any are appropriate:

        * :py:meth:`new_bruker`
        * :py:meth:`new_from_parameters`
        * :py:meth:`new_spinach`
        * :py:meth:`from_pickle` (re-loads a previously saved estimator).
    """

    # Class-level configuration, presumably read by the parent estimator
    # classes (not visible in this file).
    dim = 1
    twodim_dtype = None
    proc_dims = [0]
    ft_dims = [0]
    default_mpm_trim = [4096]
    default_nlp_trim = [None]
    default_max_iterations_exact_hessian = 100
    default_max_iterations_gn_hessian = 200

    @classmethod
    def new_bruker(
        cls,
        directory: Union[str, Path],
        convdta: bool = True,
    ) -> Estimator1D:
        """Create a new instance from Bruker-formatted data.

        Parameters
        ----------
        directory
            Absolute path to data directory.

        convdta
            If ``True`` and the data is derived from an ``fid`` file, removal of
            the FID's digital filter will be carried out. This has no effect for
            processed data.

        Notes
        -----
        There are certain file paths expected to be found relative to ``directory``
        which contain the data and parameter files. Here is an extensive list of
        the paths expected to exist, for different data types:

        * Raw FID

          + ``directory/fid``
          + ``directory/acqus``

        * Processed data

          + ``directory/1r``
          + ``directory/../../acqus``
          + ``directory/procs``

        Processed data (i.e. ``directory`` lives inside a ``pdata`` directory)
        is converted back to the time domain: it is inverse Fourier transformed,
        doubled, and truncated to its first half.

        Raises
        ------
        ValueError
            If the loaded data is not one-dimensional.
        """
        sanity_check(
            ("directory", directory, check_existent_dir),
            ("convdta", convdta, sfuncs.check_bool),
        )

        directory = Path(directory).expanduser()
        data, expinfo = load_bruker(directory)

        if data.ndim != 1:
            raise ValueError(f"{RED}Data dimension should be 1.{END}")

        if directory.parent.name == "pdata":
            # Processed spectrum -> time-domain signal.
            slice_ = slice(0, data.shape[0] // 2)
            data = (2 * ne.sig.ift(data))[slice_]
        elif convdta:
            # Raw FID: remove the digital filter using the group delay.
            grpdly = expinfo.parameters["acqus"]["GRPDLY"]
            data = ne.sig.convdta(data, grpdly)

        return cls(data, expinfo, directory)

    @classmethod
    def new_spinach(
        cls,
        shifts: Iterable[float],
        couplings: Optional[Iterable[Tuple[int, int, float]]],
        pts: int,
        sw: float,
        offset: float = 0.,
        field: float = 11.74,
        nucleus: str = "1H",
        snr: Optional[float] = 20.,
        lb: float = 6.91,
    ) -> Estimator1D:
        r"""Create a new instance from a pulse-acquire Spinach simulation.

        See :ref:`SPINACH_INSTALL` for requirements to use this method.

        Parameters
        ----------
        shifts
            A list of chemical shift values, one for each spin.

        couplings
            The scalar couplings present in the spin system. Given ``shifts`` is of
            length ``n``, couplings should be an iterable with entries of the form
            ``(i1, i2, coupling)``, where ``1 <= i1, i2 <= n`` are the indices of
            the two spins involved in the coupling, and ``coupling`` is the value
            of the scalar coupling in Hz. ``None`` will set all spins to be
            uncoupled.

        pts
            The number of points the signal comprises.

        sw
            The sweep width of the signal (Hz).

        offset
            The transmitter offset (Hz).

        field
            The magnetic field strength (T).

        nucleus
            The identity of the nucleus. Should be of the form ``"<mass><sym>"``
            where ``<mass>`` is the atomic mass and ``<sym>`` is the element symbol.
            Examples:

            * ``"1H"``
            * ``"13C"``
            * ``"195Pt"``

        snr
            The signal-to-noise ratio of the resulting signal, in decibels. ``None``
            produces a noiseless signal.

        lb
            Line broadening (exponential damping) to apply to the signal.
            The first point will be unaffected by damping, and the final point will
            be multiplied by ``np.exp(-lb)``. The default results in the final
            point being decreased in value by a factor of roughly 1000.
        """
        sanity_check(
            ("shifts", shifts, sfuncs.check_float_list),
            ("pts", pts, sfuncs.check_int, (), {"min_value": 1}),
            ("sw", sw, sfuncs.check_float, (), {"greater_than_zero": True}),
            ("offset", offset, sfuncs.check_float),
            ("field", field, sfuncs.check_float, (), {"greater_than_zero": True}),
            ("nucleus", nucleus, sfuncs.check_nucleus),
            ("snr", snr, sfuncs.check_float, (), {}, True),
            ("lb", lb, sfuncs.check_float, (), {"greater_than_zero": True})
        )
        nspins = len(shifts)
        # Coupling indices can only be validated once the number of spins is known.
        sanity_check(
            ("couplings", couplings, sfuncs.check_spinach_couplings, (nspins,), {}, True),  # noqa: E501
        )

        if couplings is None:
            couplings = []

        fid, sfo = cls._run_spinach(
            "onedim_sim", shifts, couplings, pts, sw, offset, field, nucleus,
        )
        fid = np.array(fid).flatten()
        if snr is not None:
            fid = ne.sig.add_noise(fid, snr)
        fid = ne.sig.exp_apodisation(fid, lb)

        expinfo = ne.ExpInfo(
            dim=1,
            sw=sw,
            offset=offset,
            sfo=sfo,
            nuclei=nucleus,
            default_pts=fid.shape,
        )

        return cls(fid, expinfo)

    @classmethod
    def new_from_parameters(
        cls,
        params: np.ndarray,
        pts: int,
        sw: float,
        offset: float,
        sfo: float = 500.,
        nucleus: str = "1H",
        snr: Optional[float] = 20.,
    ) -> Estimator1D:
        """Generate an estimator instance with synthetic data created from an
        array of oscillator parameters.

        Parameters
        ----------
        params
            Parameter array with the following structure:

            .. code:: python

               params = numpy.array([
                  [a_1, φ_1, f_1, η_1],
                  [a_2, φ_2, f_2, η_2],
                  ...,
                  [a_m, φ_m, f_m, η_m],
               ])

        pts
            The number of points the signal comprises.

        sw
            The sweep width of the signal (Hz).

        offset
            The transmitter offset (Hz).

        sfo
            The transmitter frequency (MHz).

        nucleus
            The identity of the nucleus. Should be of the form ``"<mass><sym>"``
            where ``<mass>`` is the atomic mass and ``<sym>`` is the element symbol.
            Examples: ``"1H"``, ``"13C"``, ``"195Pt"``

        snr
            The signal-to-noise ratio (dB). If ``None`` then no noise will be added
            to the FID.
        """
        sanity_check(
            ("params", params, sfuncs.check_parameter_array, (1,)),
            ("pts", pts, sfuncs.check_int, (), {"min_value": 1}),
            ("sw", sw, sfuncs.check_float, (), {"greater_than_zero": True}),
            ("offset", offset, sfuncs.check_float, (), {}, True),
            ("nucleus", nucleus, sfuncs.check_nucleus),
            ("sfo", sfo, sfuncs.check_float, (), {"greater_than_zero": True}, True),
            ("snr", snr, sfuncs.check_float, (), {"greater_than_zero": True}, True),
        )

        expinfo = ne.ExpInfo(
            dim=1,
            sw=sw,
            offset=offset,
            sfo=sfo,
            nuclei=nucleus,
            default_pts=pts,
        )

        data = expinfo.make_fid(params, snr=snr)
        return cls(data, expinfo)

    def view_data(
        self,
        domain: str = "freq",
        components: str = "real",
        freq_unit: str = "hz",
    ) -> None:
        """View the data (FID or spectrum) with an interactive matplotlib plot.

        Parameters
        ----------
        domain
            Must be ``"freq"`` or ``"time"``.

        components
            Must be ``"real"``, ``"imag"``, or ``"both"``.

        freq_unit
            Must be ``"hz"`` or ``"ppm"``. If ``domain`` is ``freq``, this
            determines which unit to set chemical shifts to.
        """
        sanity_check(
            ("domain", domain, sfuncs.check_one_of, ("freq", "time")),
            ("components", components, sfuncs.check_one_of, ("real", "imag", "both")),
            ("freq_unit", freq_unit, sfuncs.check_frequency_unit, (self.hz_ppm_valid,)),
        )

        fig = plt.figure()
        ax = fig.add_subplot()

        if domain == "freq":
            x = self.get_shifts(unit=freq_unit)[0]
            y = self.spectrum
            label, = self._axis_freq_labels(freq_unit)
        elif domain == "time":
            x, = self.get_timepoints()
            # Copy so nothing done to ``y`` can touch the stored data.
            y = copy.deepcopy(self._data)
            label = "$t$ (s)"

        if components in ["real", "both"]:
            ax.plot(x, y.real, color="k")
        if components in ["imag", "both"]:
            ax.plot(x, y.imag, color="#808080")

        ax.set_xlabel(label)
        ax.set_xlim((x[0], x[-1]))

        plt.show()

    def write_to_bruker(
        self,
        path: Union[str, Path],
        indices: Optional[Iterable[int]] = None,
        pts: Optional[Iterable[int]] = None,
        expno: Optional[int] = None,
        procno: Optional[int] = None,
        force_overwrite: bool = False,
    ) -> None:
        """Write a signal generated with estimated parameters to Bruker format.

        * ``<path>/<expno>/`` will contain the time-domain data and information
          (``fid``, ``acqus``, ...)
        * ``<path>/<expno>/pdata/<procno>/`` will contain the processed data and
          information (``pdata``, ``procs``, ...)

        .. note::

            There is a known problem that the spectral data has timepoints along
            the x-axis rather than chemical shifts. I will try to figure out why
            and fix this in due course!

        Parameters
        ----------
        path
            The path to the root directory to store the data in.

        indices
            See :ref:`INDICES`.

        pts
            The number of points to construct the signal from.

        expno
            The experiment number. If ``None``, the smallest int ``x`` for which the
            directory ``<path>/<x>/`` doesn't exist will be used.

        procno
            The processed data number. If ``None``, ``1`` will be used.

        force_overwrite
            If ``False`` and the directory ``<path>/<expno>/`` already exists,
            the user will be prompted to confirm whether they are happy to
            overwrite it. If ``True``, said directory will be overwritten.
        """
        # TODO: figure out x-axis issue (see warning above).
        self._check_results_exist()
        sanity_check(
            ("path", path, check_saveable_dir, (True,)),
            self._indices_check(indices),
            self._pts_check(pts),
            ("expno", expno, sfuncs.check_int, (), {"min_value": 1}, True),
            ("procno", procno, sfuncs.check_int, (), {"min_value": 1}, True),
            ("force_overwrite", force_overwrite, sfuncs.check_bool),
        )

        fid = self.make_fid_from_result(indices=indices, pts=pts)
        # Previously ``procno`` was silently ignored and ``1`` always used;
        # keep ``1`` as the fallback so existing calls behave the same.
        procno = 1 if procno is None else procno
        # calls ne.ExpInfo.write_to_bruker()
        super().write_to_bruker(fid, path, expno, procno, force_overwrite)

    @logger
    def plot_result(
        self,
        indices: Optional[Iterable[int]] = None,
        high_resolution_pts: Optional[int] = None,
        axes_left: float = 0.07,
        axes_right: float = 0.96,
        axes_bottom: float = 0.08,
        axes_top: float = 0.96,
        axes_region_separation: float = 0.05,
        xaxis_unit: str = "hz",
        xaxis_label_height: float = 0.02,
        xaxis_ticks: Optional[Iterable[Tuple[int, Iterable[float]]]] = None,
        oscillator_colors: Any = None,
        plot_model: bool = True,
        plot_residual: bool = True,
        model_shift: Optional[float] = None,
        residual_shift: Optional[float] = None,
        label_peaks: bool = True,
        denote_regions: bool = False,
        spectrum_line_kwargs: Optional[Dict] = None,
        oscillator_line_kwargs: Optional[Dict] = None,
        residual_line_kwargs: Optional[Dict] = None,
        model_line_kwargs: Optional[Dict] = None,
        label_kwargs: Optional[Dict] = None,
        **kwargs,
    ) -> Tuple[mpl.figure.Figure, np.ndarray[mpl.axes.Axes]]:
        r"""Generate a figure of the estimation result.

        Parameters
        ----------
        indices
            See :ref:`INDICES`.

        high_resolution_pts
            Indicates the number of points used to generate the individual
            oscillator lines. Should be greater than or equal to
            ``self.default_pts[0]``. If ``None``, ``self.data.size`` will be used.

        axes_left
            The position of the left edge of the axes, in `figure coordinates
            <https://matplotlib.org/stable/tutorials/advanced/\
            transforms_tutorial.html>`_\. Should be between ``0.`` and ``1.``.

        axes_right
            The position of the right edge of the axes, in figure coordinates. Should
            be between ``0.`` and ``1.``.

        axes_top
            The position of the top edge of the axes, in figure coordinates. Should
            be between ``0.`` and ``1.``.

        axes_bottom
            The position of the bottom edge of the axes, in figure coordinates. Should
            be between ``0.`` and ``1.``.

        axes_region_separation
            The extent by which adjacent regions are separated in the figure,
            in figure coordinates.

        xaxis_unit
            The unit to express chemical shifts in. Should be ``"hz"`` or ``"ppm"``.

        xaxis_label_height
            The vertical location of the x-axis label, in figure coordinates. Should
            be between ``0.`` and ``1.``, though you are likely to want this to be
            only slightly larger than ``0.``.

        xaxis_ticks
            Specifies custom x-axis ticks for each region, overwriting the default
            ticks. Should be of the form: ``[(i, (a, b, ...)), (j, (c, d, ...)), ...]``
            where ``i`` and ``j`` are ints indicating the region under consideration,
            and ``a``-``d`` are floats indicating the tick values.

        oscillator_colors
            Describes how to color individual oscillators. See :ref:`COLOR_CYCLE`
            for details.

        plot_model
            If ``True``, plot the model spectrum (the sum of all oscillators
            specified by ``indices``).

        plot_residual
            If ``True``, plot the residual (the data spectrum minus the model
            spectrum). The data, model, and oscillators are then shifted upwards
            so that they sit above the residual.

        model_shift
            The vertical displacement of the model relative to the data. If
            ``None``, 10% of the maximum of the data is used.

        residual_shift
            The vertical displacement of the data, model, and oscillators relative
            to the residual (only used if ``plot_residual`` is ``True``). If
            ``None``, a displacement which places them just above the residual is
            determined automatically.

        label_peaks
            If True, label peaks according to their index. The parameters of a peak
            denoted with the label ``i`` in the figure can be accessed with
            ``self.get_results(indices)[i]``.

        denote_regions
            If ``True``, and there are regions which share a boundary, a
            vertical line will be plotted to show the boundary.

        spectrum_line_kwargs
            Keyword arguments for the spectrum line. All keys should be valid
            arguments for `matplotlib.axes.Axes.plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.

        oscillator_line_kwargs
            Keyword arguments for the oscillator lines. All keys should be valid
            arguments for `matplotlib.axes.Axes.plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.
            If ``"color"`` is included, it is ignored (colors are processed
            based on the ``oscillator_colors`` argument).

        residual_line_kwargs
            Keyword arguments for the residual line (if included). All keys
            should be valid arguments for `matplotlib.axes.Axes.plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.

        model_line_kwargs
            Keyword arguments for the model line (if included). All keys should
            be valid arguments for `matplotlib.axes.Axes.plot
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.plot.html>`_.

        label_kwargs
            Keyword arguments for oscillator labels. All keys should be valid
            arguments for
            `matplotlib.text.Text
            <https://matplotlib.org/stable/api/text_api.html#matplotlib.text.Text>`_
            If ``"color"`` is included, it is ignored (colors are processed
            based on the ``oscillator_colors`` argument).
            ``"horizontalalignment"``, ``"ha"``, ``"verticalalignment"``, and
            ``"va"`` are also ignored, as these are determined internally.

        kwargs
            Keyword arguments provided to `matplotlib.pyplot.figure
            <https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.figure.html\
            #matplotlib.pyplot.figure>`_\.

        Returns
        -------
        fig
            The result figure. This can be saved to various formats using the
            `savefig <https://matplotlib.org/stable/api/figure_api.html\
            #matplotlib.figure.Figure.savefig>`_ method.

        axs
            A ``(1, N)`` NumPy array of the axes generated.
        """
        sanity_check(
            self._indices_check(indices),
            (
                "high_resolution_pts", high_resolution_pts, sfuncs.check_int, (),
                {"min_value": self.default_pts[-1]}, True,
            ),
            (
                "axes_left", axes_left, sfuncs.check_float, (),
                {"min_value": 0., "max_value": 1.},
            ),
            (
                "axes_right", axes_right, sfuncs.check_float, (),
                {"min_value": 0., "max_value": 1.},
            ),
            (
                "axes_bottom", axes_bottom, sfuncs.check_float, (),
                {"min_value": 0., "max_value": 1.},
            ),
            (
                "axes_top", axes_top, sfuncs.check_float, (),
                {"min_value": 0., "max_value": 1.},
            ),
            (
                "axes_region_separation", axes_region_separation, sfuncs.check_float,
                (), {"min_value": 0., "max_value": 1.},
            ),
            self._funit_check(xaxis_unit, "xaxis_unit"),
            (
                "xaxis_label_height", xaxis_label_height, sfuncs.check_float, (),
                {"min_value": 0., "max_value": 1.},
            ),
            ("plot_model", plot_model, sfuncs.check_bool),
            ("plot_residual", plot_residual, sfuncs.check_bool),
            (
                "model_shift", model_shift, sfuncs.check_float, (),
                {"min_value": 0.}, True,
            ),
            (
                "oscillator_colors", oscillator_colors, sfuncs.check_oscillator_colors,
                (), {}, True,
            ),
            ("label_peaks", label_peaks, sfuncs.check_bool),
            ("denote_regions", denote_regions, sfuncs.check_bool),
        )

        spectrum_line_kwargs = proc_kwargs_dict(
            spectrum_line_kwargs,
            default={"color": "#000000"},
        )
        oscillator_line_kwargs = proc_kwargs_dict(
            oscillator_line_kwargs,
            to_pop=("color",)
        )
        if plot_residual:
            residual_line_kwargs = proc_kwargs_dict(
                residual_line_kwargs,
                default={"color": "#808080"},
            )
        if plot_model:
            model_line_kwargs = proc_kwargs_dict(
                model_line_kwargs,
                default={"color": "#808080"},
            )
        if label_peaks:
            label_kwargs = proc_kwargs_dict(
                label_kwargs,
                to_pop=("ha", "horizontalalignment", "va", "verticalalignment", "color"),  # noqa: E501
            )

        indices = self._process_indices(indices)
        merge_indices, merge_regions = self._plot_regions(indices, xaxis_unit)
        n_regions = len(merge_regions)

        fig, axs = plt.subplots(
            nrows=1,
            ncols=n_regions,
            gridspec_kw={
                "left": axes_left,
                "right": axes_right,
                "bottom": axes_bottom,
                "top": axes_top,
                "wspace": axes_region_separation,
                "width_ratios": [r[0] - r[1] for r in merge_regions],
            },
            **kwargs,
        )
        # Normalise to a (1, n_regions) array regardless of the region count.
        if n_regions == 1:
            axs = np.array([axs])
        axs = np.expand_dims(axs, axis=0)

        self._configure_axes(
            fig,
            axs,
            merge_regions,
            xaxis_ticks,
            axes_left,
            axes_right,
            xaxis_label_height,
            xaxis_unit,
        )

        # Configure high-resolution points for the oscillator plots.
        if high_resolution_pts is None:
            high_resolution_pts = self.data.size
        # Work on a copy so that overriding ``default_pts`` cannot leak back into
        # this estimator's experiment information.
        highres_expinfo = copy.deepcopy(self.expinfo)
        highres_expinfo.default_pts = (high_resolution_pts,)

        # Get data which spans the full spectral width.
        # These will be sliced for each region.
        full_spectrum = self.spectrum.real
        full_shifts_highres, = highres_expinfo.get_shifts(unit=xaxis_unit)
        full_shifts, = self.get_shifts(unit=xaxis_unit)

        # Halve the first point before the FT, as is done for every
        # model-derived spectrum below.
        full_model = self.make_fid_from_result(indices)
        full_model[0] *= 0.5
        full_model = ne.sig.ft(full_model).real
        full_residual = full_spectrum - full_model

        slices = [
            slice(*self.convert([region], f"{xaxis_unit}->idx")[0])
            for region in merge_regions
        ]
        highres_slices = [
            slice(*highres_expinfo.convert([region], f"{xaxis_unit}->idx")[0])
            for region in merge_regions
        ]

        params = self.get_params(indices=indices)
        n_oscs = params.shape[0]

        # For each region, find which oscillators (by their position in
        # ``params``) belong to it, so each peak is only labelled in its own
        # region.
        label_ax_idxs = []
        for idx in merge_indices:
            region_params = self.get_params(indices=idx)
            label_ax_idxs.append([
                i for i, p in enumerate(params)
                if len(np.where((region_params == p).all(axis=-1))[0]) == 1
            ])

        # Each oscillator's high-resolution spectrum does not depend on the
        # region, so compute it once rather than once per region.
        osc_spectra_highres = []
        for p in params:
            osc = self.make_fid(np.expand_dims(p, axis=0), pts=high_resolution_pts)
            osc[0] *= 0.5
            osc_spectra_highres.append(ne.sig.ft(osc).real)

        # Store line and text objects; they are shifted vertically later on.
        spectra = []
        oscs = []
        labels = []
        models = []
        residuals = []

        for ax, slce, highres_slice, ax_labels in zip(
            axs[0], slices, highres_slices, label_ax_idxs,
        ):
            shifts = full_shifts[slce]
            shifts_highres = full_shifts_highres[highres_slice]
            spectrum = full_spectrum[slce]
            spectra.append(ax.plot(shifts, spectrum, **spectrum_line_kwargs)[0])

            if plot_residual:
                residual = full_residual[slce]
                residuals.append(
                    ax.plot(shifts, residual, **residual_line_kwargs)[0]
                )

            if plot_model:
                model = full_model[slce]
                models.append(ax.plot(shifts, model, **model_line_kwargs)[0])

            # Restart the colour cycle for each region so oscillator ``i`` takes
            # the same position in the cycle in every region.
            colorcycle = make_color_cycle(oscillator_colors, n_oscs)
            for i, osc_spectrum in enumerate(osc_spectra_highres):
                color = next(colorcycle)
                spec = osc_spectrum[highres_slice]
                oscs.append(
                    ax.plot(
                        shifts_highres, spec, color=color, **oscillator_line_kwargs,
                    )[0]
                )

                if label_peaks and (i in ax_labels):
                    # Label at the point of largest magnitude, placed on the
                    # outward side of the peak.
                    label_idx = np.argmax(np.abs(spec))
                    label_x = shifts_highres[label_idx]
                    label_y = spec[label_idx]
                    label_va, label_ha = (
                        ("bottom", "left") if spec[label_idx] >= 0
                        else ("top", "right")
                    )
                    labels.append(
                        ax.text(
                            label_x, label_y, str(i), color=color, va=label_va,
                            ha=label_ha, **label_kwargs,
                        )
                    )

        # Vertical shifting of plot lines and labels
        if plot_model and (model_shift is None):
            model_shift = 0.1 * max(np.amax(line.get_ydata()) for line in spectra)

        if plot_residual:
            residual_span = self._get_line_span(residuals)
            lines_to_shift = oscs + spectra
            if plot_model:
                lines_to_shift.extend(models)
            lines_to_shift_span = self._get_line_span(lines_to_shift)

            if residual_shift is None:
                # Move the lowest point of the shifted lines to just above the
                # highest point of the residual, leaving a gap of 3% of the
                # scaled combined height.
                combined_height = (
                    (residual_span[1] - residual_span[0])
                    + (lines_to_shift_span[1] - lines_to_shift_span[0])
                ) / 0.91
                line_shift = (
                    residual_span[1] - lines_to_shift_span[0]
                    + (0.03 * combined_height)
                )
            else:
                line_shift = residual_shift

            for line in lines_to_shift:
                line.set_ydata(line.get_ydata() + line_shift)

            # ``labels`` is empty unless ``label_peaks`` is True.
            for label in labels:
                old_x, old_y = label.get_position()
                label.set_position((old_x, old_y + line_shift))

        if plot_model:
            for model in models:
                model.set_ydata(model.get_ydata() + model_shift)

        # Set y-limits with a 3% margin above and below all plotted lines.
        all_lines = oscs + spectra
        if plot_model:
            all_lines.extend(models)
        if plot_residual:
            all_lines.extend(residuals)

        all_lines_span = self._get_line_span(all_lines)
        height = all_lines_span[1] - all_lines_span[0]
        ylim_bottom = all_lines_span[0] - (0.03 * height)
        ylim_top = all_lines_span[1] + (0.03 * height)
        for ax in axs[0]:
            ax.set_ylim(ylim_bottom, ylim_top)
            ax.set_yticks([])

        return fig, axs

    @staticmethod
    def _get_line_span(lines: Iterable[mpl.lines.Line2D]) -> Tuple[float, float]:
        """Return ``(min, max)`` of the y-data across all ``lines``.

        ``lines`` is materialised once, so a one-shot iterator (e.g. a
        generator) is handled correctly.
        """
        ydatas = [line.get_ydata() for line in lines]
        return (
            min(np.amin(ydata) for ydata in ydatas),
            max(np.amax(ydata) for ydata in ydatas),
        )