from __future__ import annotations
from time import time
import numpy as np
import matplotlib.pyplot as plt
import corner, sys, os, warnings, contextlib, functools, inspect
from emcee import EnsembleSampler
from beartype import beartype
from scipy.interpolate import interp1d
from numpy.polynomial import polynomial as P
from .lsd import LSD
from . import utils
from .data import Data
from .utils import IntLike, Scalar
try:
from dynesty.sampler import Sampler
from dynesty import plotting as dyplot
except ImportError:
Sampler = None
dyplot = None
# TODO: use utils.set_dict_defaults for plots
# NOTE(review): this installs a process-wide "ignore" filter for every warning
# category as soon as the module is imported, which also hides warnings raised
# by user code and by other libraries. Consider scoping it with
# warnings.catch_warnings() around the noisy calls instead.
warnings.simplefilter("ignore")
def _require_profiles(method):
# Make sure all results are processed before calling method
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
if not self.data.complete: # complete is flag for if profiles have been made
name = method.__qualname__
if self.sampler is not None:
if self.config.verbose>0:
print(f"Note: The Result object was created without the profiles processed. " \
f"Running {name} requires all results to be processed, " \
"so process_results() will now be called...")
self.process_results()
else:
error = f"Cannot call {name}. The profiles attribute is not available, and no " \
"sampler object is available to process results. Please pass an Acid/Data " \
"instance after running ACID to the results init."
raise ValueError(error)
return method(self, *args, **kwargs)
return wrapper
def _require_sampler(method):
# Make sure sampler object is available before calling method
sig = inspect.signature(method)
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
bound = sig.bind_partial(self, *args, **kwargs)
inputted_sampler = bound.arguments.get("sampler", None)
self.initiate_sampler(inputted_sampler, _method_name=method.__qualname__)
return method(self, *args, **kwargs)
return wrapper
@beartype
class Result:
    """
    Class to handle the results from the Acid MCMC sampling, and results processing. Fundamentally, this
    class requires two objects to run, the Sampler object and the Data object, both of which can be obtained
    from the Acid object. If one or the other is not provided, some methods will not work.
    """

    def __init__(
        self,
        data : Data|object,
        sampler : EnsembleSampler|Sampler|None = None, # type:ignore
        process_results : bool = True,
        verbose : IntLike|bool|str|None = None,
    ) -> None:
        """
        Initialize the Result class

        Parameters
        ----------
        data : :py:class:`Data` | :py:class:`Acid`
            An Acid object or Data object (contained in Acid class). If an Acid
            object is provided, its ``data`` attribute is used. If a Data object is
            provided, a sampler can be provided in the second argument. If a sampler object
            is provided, it will be used as the sampler, but all other attributes will need
            to be set manually for the Result object to be fully functional.
        sampler : :py:class:`emcee.EnsembleSampler` | :py:class:`dynesty.Sampler`, optional
            Sets and overwrites the sampler in the Data object with this if provided, by default None.
        process_results : bool, optional
            Whether to process the results from the Acid object upon initialisation, by default True.
            If False, the profiles attribute will not be available until Result.process_results() is called.
            The process_results functions does a LSD call, which can be skipped to save time and use
            the Result object for methods that do not require the profiles attribute, such as
            continue_sampling() or plot_walkers(). This requires a Data object with the necessary attributes,
            and a sampler object in the initialisation, or an Acid object with the necessary attributes already set.
        verbose : :py:type:`IntLike | bool | str`, optional
            Verbosity level, works exactly the same as :py:class:`Acid`, if not provided
            defaults to provided :py:class:`Acid`/:py:class:`Data` class verbosity (which itself defaults to 2).
            Overwrites any value passed through the Data object.

        Raises
        ------
        ValueError
            If ``data`` is neither an Acid nor a Data object, or if ``process_results`` is True, the
            profiles have not been processed yet and no sampler is available.
        """
        # Handle the different possible cases for 1st argument input
        from .acid import Acid
        if isinstance(data, Acid):
            self.data = data.data
        elif isinstance(data, Data):
            self.data = data
        else:
            raise ValueError(f"First argument must be either an Acid or Data object. Got {type(data)} instead.")

        # Result.config points to Data.config so both stay in sync; the verbose property handles None
        self.config = self.data.config
        self.config.verbose = verbose

        # initiate_sampler() flips this to True once the sampler-derived attributes are set, so later
        # calls without a new sampler can return early.
        self.sampler_initialized = False
        # Always defined so methods branching on it work before any sampler is attached;
        # initiate_sampler() recomputes it whenever a sampler is (re)attached.
        self.dynesty = False

        # Overwrite the sampler stored in Data only if one was provided
        if sampler is not None:
            self.sampler = sampler
        if self.sampler is not None:
            self.initiate_sampler(self.sampler)

        if not self.data.complete:
            if process_results:
                if self.sampler is None:
                    raise ValueError("Cannot process results without a sampler. Please provide a sampler in the initialisation or set process_results=False.")
                self.process_results()
            elif self.config.verbose > 0:
                print("Warning: Results not processed. Profiles attribute will not be available until " \
                      "Result.process_results() is called or passed through a method.")
        elif self.sampler is None and self.config.verbose > 0:
            print(f"Warning: No sampler provided or found in Data object. \n" \
                  f"Some methods will not work unless a sampler is provided as a parameter or if Result.initiate_sampler(sampler) is called.")

    @_require_sampler
    def process_results(self) -> None:
        """
        Processes the sampler results to obtain the final LSD profiles and continuum fit, and errors on both.

        This is effectively the final step in the ACID pipeline, and must be run before the profiles attribute
        is available. It is called from ``Result.__init__`` when ``process_results=True``, and on demand by
        methods that need the profiles. It lives here instead of the Acid class because some methods of this
        class do not need the final profiles.

        Side effects on ``self.data``: sets ``combined_profile``, ``continuum_model``, ``forward_model``,
        ``forward_errors``, ``forward_x``, ``profiles``, ``results_time``, ``total_time`` and ``complete``,
        and calls :py:meth:`save` when ``config.save_path`` is set.
        """
        t0 = time()

        # Obtain flattened posterior samples (burn-in and thinning only apply to emcee chains)
        if self.dynesty:
            flat_samples = self.sampler.results.samples_equal()
        else:
            flat_samples = self.sampler.get_chain(discard=self.burnin, thin=self.thin, flat=True)

        # Parameter layout: the first ``nvel`` entries are profile points (none when the profile is
        # deterministic), followed by the continuum polynomial coefficients. The truthiness test matches
        # the one used in initiate_sampler().
        nvel = 0 if self.config.deterministic_profile else len(self.data.velocities)
        # Median continuum polynomial coefficients
        poly_cos = np.percentile(flat_samples, 50, axis=0)[nvel:]

        if self.config.verbose > 1:
            print('Getting the final profiles...')

        continuum_error = self._continuum_error(flat_samples[:, nvel:])

        # First the combined profile (counter == 0), then one profile per input frame.
        nframes = len(self.data.flux["input"])

        if nframes > 0:
            # The residual masks index the nan-masked grid of the highest-S/N input frame. This nearest-neighbour
            # interpolator maps them onto each frame's grid; it does not depend on the frame, so build it once.
            reference_wave = self.data.wavelengths["input"][np.nanargmax(self.data.sn["input"])]
            reference_wave = reference_wave[self.data.nanmask]
            reference_mask = np.zeros_like(reference_wave, dtype=bool)
            reference_mask[self.data.residual_masks] = True
            reference_interp1d = interp1d(reference_wave, reference_mask.astype(float), kind="nearest", bounds_error=False, fill_value=0.0)

        profiles = []  # one [profile, profile_errors, covariance] entry per input frame
        for counter in range(nframes + 1):
            if counter == 0:
                flux = np.copy(self.data.flux["combined"])
                error = np.copy(self.data.errors["combined"])
                wavelengths = np.copy(self.data.wavelengths["combined"])
                sn = np.copy(self.data.sn["combined"])
                error[self.data.residual_masks] = 1e12  # inflate errors of masked pixels
            else:
                frame = counter - 1
                flux = np.copy(self.data.flux["input"][frame])[self.data.nanmask]
                error = np.copy(self.data.errors["input"][frame])[self.data.nanmask]
                wavelengths = np.copy(self.data.wavelengths["input"][frame])[self.data.nanmask]
                sn = np.copy(self.data.sn["input"][frame])
                interpolated_mask = reference_interp1d(wavelengths) > 0.5
                error[interpolated_mask] = 1e12  # inflate errors of masked pixels

            # Build the continuum model on this spectrum's normalised wavelength grid
            a, b = utils.get_normalisation_coeffs(wavelengths)
            norm_wavelengths = (a * wavelengths) + b
            mdl = P.polyval(norm_wavelengths, poly_cos)

            # Continuum-correct the flux and propagate the continuum uncertainty into the errors
            error = np.sqrt((error / mdl)**2 + (continuum_error / mdl)**2)
            flux /= mdl

            # alpha can only be reused when the wavelength grid is identical to the combined one
            condition = np.array_equal(wavelengths, self.data.wavelengths["combined"])
            alpha = self.data.alpha if condition else None
            LSD_profiles = LSD(self.data)
            LSD_profiles.run_LSD(wavelengths, flux, error, sn, alpha=alpha)
            profile_f = LSD_profiles.profile_F
            profile_errors_f = LSD_profiles.profile_errors_F
            cov_z_f = LSD_profiles.cov_z_F

            if counter == 0:
                self.data.combined_profile = [profile_f, profile_errors_f, cov_z_f]
                self.data.continuum_model = mdl
                # LSD ran on the normalised flux, so scale the forward model back up by the continuum
                self.data.forward_model = LSD_profiles.forward_model * mdl
                self.data.forward_errors = LSD_profiles.forward_model_errors * mdl
                self.data.forward_x = wavelengths
            else:
                profiles.append([profile_f, profile_errors_f, cov_z_f])

        self.data.profiles = profiles
        self.data.results_time = time() - t0
        self.data.total_time = self.data.setup_time + self.data.mcmc_time + self.data.results_time
        self.data.complete = True

        # Now that results are complete, save the data instance if specified (the sampler is saved separately)
        if self.config.save_path is not None:
            self.save()
        return

    def _continuum_error(self, coeffs: np.ndarray) -> np.ndarray:
        """
        Per-pixel standard deviation of the continuum polynomial across posterior samples.

        The polynomials are evaluated on ``data.wavelengths["combined_normalized"]``. If evaluating every
        sample is estimated to exceed 80% of the available memory, a random subset of at most 1000 samples
        is used instead.

        Parameters
        ----------
        coeffs : np.ndarray
            Posterior samples of the polynomial coefficients, shape (n_samples, ncoeffs),
            in increasing-power order.

        Returns
        -------
        np.ndarray
            Standard deviation of the evaluated continuum at each wavelength pixel.
        """
        norm_wl = self.data.wavelengths["combined_normalized"]
        n_samples, ncoeffs = coeffs.shape
        powers = np.vander(norm_wl, N=ncoeffs, increasing=True)
        npix = powers.shape[0]

        available_memory = utils.get_available_memory()  # in bytes
        m_available = available_memory * 0.8 / (1024**3)  # in GB, with a 0.8 safety factor
        matrix_size_gb = (2 * n_samples * npix + n_samples * ncoeffs + npix * ncoeffs) * 8 / (1024**3)

        # If memory would be exceeded, fall back to at most 1000 random samples
        if matrix_size_gb > m_available:
            if self.config.verbose > 1:
                print(f"Warning: Calculating continuum error with all samples may exceed available memory ({matrix_size_gb:.2f} GB required, {m_available:.2f} GB available). "
                      "Calculating with a max of 1000 random samples instead.")
            random_indices = np.random.choice(n_samples, size=min(1000, n_samples), replace=False)
            coeffs = coeffs[random_indices, :]

        conts = coeffs @ powers.T
        return np.std(conts, axis=0)

    @_require_profiles
    def __getitem__(self, item) -> list|np.ndarray:
        """
        Allows indexing into the profiles directly from the Result object.

        - ``int`` (0, 1 or 2; numpy integers accepted): element of the combined profile
          (profile, error, covariance matrix).
        - ``str`` containing 'error', 'cov' or 'profile': the matching combined-profile element.
        - ``(i,)``: ``combined_profile[i]``.
        - ``(frame, j)``: ``profiles[frame][j]``.
        - ``(order, frame, j)``: ``profiles[frame][j]`` (``order`` is ignored).

        Raises
        ------
        ValueError
            For unsupported index types, tuple lengths, integers or strings.
        """
        if isinstance(item, tuple):
            # Tuples allow for array-like indexing of the list
            if len(item) == 3:
                _order, frame, velocity = item
                return self.data.profiles[frame][velocity]
            elif len(item) == 2:
                return self.data.profiles[item[0]][item[1]]
            elif len(item) == 1:
                return self.data.combined_profile[item[0]]
            else:
                raise ValueError(f"Tuple indexing must be of length 1, 2, or 3. Got {len(item)} instead.")
        elif isinstance(item, (int, np.integer)):
            # Return just the profile or error (or cov_mat) for single int input
            if item < 0 or item > 2:
                raise ValueError(f"Integer index must be 0, 1, or 2 to specify whether to return the profile, error, or covariance matrix. Got {item} instead.")
            return self.data.combined_profile[int(item)]
        elif isinstance(item, str):
            key = item.lower()
            if "error" in key:
                return self.data.combined_profile[1]
            elif "cov" in key:
                return self.data.combined_profile[2]
            elif "profile" in key:
                return self.data.combined_profile[0]
            else:
                raise ValueError(f"String index must contain either 'error', 'cov', or 'profile' to specify which to return. Got {item} instead.")
        else:
            raise ValueError(f"Invalid index type. Must be either a tuple, int, or str. Got {type(item)} instead.")

    @_require_profiles
    def __iter__(self):
        """Allows iterating over the per-frame profiles directly from the Result object."""
        return iter(self.data.profiles)

    def __repr__(self):
        # Only print out the sampler and data attributes, and whether profiles is available, to avoid printing large arrays
        return f"Result object with sampler={self.sampler}, data={self.data}, profiles={'available' if self.data.profiles is not None else 'not available'}"

    def __str__(self):
        return self.__repr__()

    def _require_emcee(self, method_name: str) -> None:
        """Raise ValueError if the stored sampler is a dynesty sampler, which has no per-walker chains."""
        if self.dynesty:
            raise ValueError(f"{method_name} is only available for emcee samplers. For dynesty samplers use plot_traceplot() or plot_corner().")

    @_require_sampler
    def continue_sampling(self, process_results:bool=True, sampler:EnsembleSampler|None=None, **kwargs) -> None:
        """
        Continue MCMC sampling for additional steps. Passes the stored sampler into a Acid instance with the saved data. See
        :py:function:`Acid.continue_sampling` for more details on the parameters that can be passed.

        Parameters
        ----------
        process_results : bool, optional
            Whether to process the results after continuing sampling, by default True.
            If False, the profiles attribute will not be updated until Result.process_results() is called.
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to continue sampling from, otherwise,
            takes the sampler from the Result object, by default None
        nsteps : :py:type:`IntLike`, optional
            Number of additional MCMC steps to run. Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        max_steps : :py:type:`IntLike`, optional
            Maximum number of MCMC steps to run, by default None. Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        max_steps_kwargs : dict, optional
            Additional keyword arguments to be passed to the run_mcmc_until_converged function if max_steps is specified, by default None.
            The kwargs description can be found in Acid.ACID(), they are the 4 kwargs appearing after max_steps. Typos for kwargs are silently
            ignored. Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        parallel : bool, optional
            Overwrites config with whether to run the MCMC in parallel. If None, uses already existing configuration. Default is None.
            Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        cores : int, optional
            Overwrites config with the number of cores to use for parallel MCMC. If None, uses already existing configuration. Default is None.
            Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        moves : dict, optional
            Overwrites config with the dictionary specifying the moves to use for MCMC sampling. If None, uses already existing configuration.
            Default is None. See :py:function:`Acid.ACID` for format. Passed to :py:function:`Acid.continue_sampling` through **kwargs.
        """
        # A ``sampler`` argument has already replaced the stored sampler via the @_require_sampler decorator
        from .acid import Acid
        acid = Acid(data=self.data)  # includes config data
        acid.continue_sampling(**kwargs)  # extends the sampler stored on self.data
        self.initiate_sampler(self.sampler)  # recompute tau/burnin/thin for the longer chain
        if process_results:
            self.process_results()
        elif self.config.verbose > 0:
            print("Warning: Results not processed. profiles attribute will not be available until " \
                  "Result.process_results() is called.")

    @_require_sampler
    def plot_walkers(
        self,
        sampler : EnsembleSampler|None = None,
        burnin : IntLike|None = None,
        thin : IntLike|None = None,
        return_fig : bool = False
    ) -> None | tuple:
        """Plots the walker traces of the default parameters: the continuum polynomial coefficients and,
        for non-deterministic profiles, three profile points. The burn-in region is shaded.

        Parameters
        ----------
        sampler : :py:class:`emcee.EnsembleSampler`, optional
            Optionally provide a different sampler to plot from, otherwise,
            takes the sampler from the Result object, by default None
        burnin : :py:type:`IntLike`, optional
            Optionally define the number of burnin steps, by default uses the burnin calculated when the sampler was initiated.
        thin : :py:type:`IntLike`, optional
            Optionally define the number of thinning steps, by default uses the thinning calculated when the sampler was initiated.
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by default False

        Returns
        ----------
        If return_fig is True, returns a tuple (fig, ax) where ax is a 1-D array of axes, else None

        Raises
        ------
        ValueError
            If the stored sampler is a dynesty sampler.
        """
        self._require_emcee("plot_walkers")
        burnin = self.burnin if burnin is None else burnin
        thin = int(self.thin if thin is None else thin)
        samples = self.sampler.get_chain(thin=thin)
        steps = np.arange(samples.shape[0]) * thin

        naxes = len(self.default_params)
        # squeeze=False keeps indexing valid when only one parameter is plotted
        fig, ax = plt.subplots(naxes, 1, figsize=(10, 20), sharex=True, squeeze=False)
        ax = ax[:, 0]
        for axis, param, label in zip(ax, self.default_params, self.default_param_labels):
            axis.plot(steps, samples[:, :, param], "k", alpha=0.3)
            axis.axvspan(0, burnin, color="red", alpha=0.1, label="burn-in")
            axis.set_ylabel(label)
        ax[-1].legend()
        ax[-1].set_xlabel("Step number")
        ax[-1].set_xlim(0, self.data.nsteps)
        ax[0].set_title('MCMC Walkers')
        plt.subplots_adjust(hspace=0.05)
        if return_fig:
            return fig, ax
        plt.show()

    @_require_sampler
    def plot_traceplot(self, return_fig:bool=False) -> None | tuple:
        """Plots the dynesty trace plot of the stored nested-sampling results.

        Parameters
        ----------
        return_fig : bool, optional
            Whether to return the (fig, axes) tuple instead of showing the plot, by default False

        Raises
        ------
        ValueError
            If the stored sampler is not a dynesty sampler.
        """
        if not self.dynesty:
            raise ValueError("Traceplot is only available for dynesty samplers, as emcee traceplots are already plotted in plot_walkers.")
        fig, ax = dyplot.traceplot(self.sampler.results, labels=self.default_param_labels)
        plt.suptitle('Dynesty Traceplot')
        if return_fig:
            return fig, ax
        plt.show()

    @_require_sampler
    def plot_corner(
        self,
        sampler :EnsembleSampler|None = None,
        return_fig :bool = False,
        **kwargs,
    ) -> None | plt.Figure | tuple:
        """Creates a corner plot of the default parameters (continuum polynomial coefficients and, for
        non-deterministic emcee profiles, three profile points).

        Parameters
        ----------
        sampler : emcee.EnsembleSampler | None, optional
            Optionally provide a different sampler to plot from, otherwise,
            takes the sampler from the Result object, by default None
        return_fig : bool, optional
            Whether to return the figure object instead of showing the plot, by default False
        **kwargs:
            Additional keyword arguments passed to corner.corner() (emcee) or dynesty's cornerplot().
            They override the defaults (labels, show_titles, title_fmt, title_kwargs).

        Returns
        ----------
        If return_fig is True, returns the figure (emcee) or a (fig, axes) tuple (dynesty), else None
        """
        plot_kwargs = {
            "labels": self.default_param_labels,
            "show_titles": True,
            "title_fmt": ".3f",
            "title_kwargs": {"fontsize": 16},
        }
        plot_kwargs.update(kwargs)

        if self.dynesty:
            fig, axes = dyplot.cornerplot(self.sampler.results, **plot_kwargs)
            plt.suptitle('Dynesty Corner Plot')
            if return_fig:
                return fig, axes
            plt.show()
            return

        # Use the burn-in and thinning derived in initiate_sampler()
        samples = self.sampler.get_chain(discard=self.burnin, flat=True, thin=self.thin)[:, self.default_params]
        fig = corner.corner(samples, **plot_kwargs)
        plt.suptitle('MCMC Corner Plot')
        if return_fig:
            return fig
        plt.show()

    @_require_profiles
    def plot_profiles(
        self,
        grid :bool = True,
        labels :dict|None = None,
        return_fig :bool = False,
        subplot_kwargs :dict|None = None,
        errorbar_kwargs :dict|None = None,
        fig_ax = None,
    ) -> None | tuple:
        """Plots the per-frame LSD profiles from Acid (profile minus 1, with error bars).

        Parameters
        ----------
        grid : bool, optional
            Show or hide grid, by default True
        labels : dict | None, optional
            Keys: 'xlabel', 'ylabel', and 'title'. Allows label overrides., by default None
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(), by default None
        errorbar_kwargs : dict | None, optional
            Keyword arguments to be passed to ax.errorbar(), by default None. A 'label' given here is
            used for every frame; otherwise frames are labelled "Frame N" when there is more than one.
        fig_ax : tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] | None, optional
            Optionally provide an existing fig/axis tuple to plot on, by default None

        Returns
        ----------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects containing the plot.
        Otherwise, displays the plot and returns None.
        """
        errorbar_defaults = {
            "fmt"      : ".-",
            "ecolor"   : "red",
            "linewidth": 1,
        }
        errorbar_kwargs = utils.set_dict_defaults(errorbar_kwargs, errorbar_defaults)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 6)})
        default_labels = {
            "title" : "Acid Profile",
            "xlabel": "Velocity (km/s)",
            "ylabel": "Normalised Flux"
        }
        labels = utils.set_dict_defaults(labels, default_labels)

        nframes = len(self.data.profiles)
        if fig_ax is None:
            fig, ax = plt.subplots(**subplot_kwargs)
        else:
            fig, ax = fig_ax

        for f, frame in enumerate(self.data.profiles):
            # Per-frame copy so a default label set for one frame does not leak into the next
            frame_kwargs = dict(errorbar_kwargs)
            frame_kwargs.setdefault("label", f"Frame {f+1}" if nframes > 1 else None)
            ax.errorbar(self.data.velocities, frame[0] - 1, yerr=frame[1], **frame_kwargs)

        ax.set_title(labels["title"])
        ax.set_xlabel(labels["xlabel"])
        ax.set_ylabel(labels["ylabel"])
        ax.axhline(0, color='black', linestyle='--', linewidth=1)
        # Only draw a legend when something is labelled (avoids an empty legend for a single frame)
        if ax.get_legend_handles_labels()[1]:
            ax.legend()
        ax.grid(grid)
        if return_fig:
            return fig, ax
        plt.show()

    @_require_profiles
    def plot_forward_model(
        self,
        fig_ax :tuple|None = None,
        grid :bool = True,
        labels :dict|None = None,
        return_fig :bool = False,
        subplot_kwargs :dict|None = None,
    ) -> None | tuple:
        """
        Plots the forward model calculated from the final profiles to the combined input spectrum.

        Parameters
        ----------
        fig_ax: tuple | None
            Optionally provide an existing fig/axis tuple to plot on, by default None and
            creates a new figure and axis. The axis must be a 2 element array of axes,
            where the first axis is for the spectrum and forward model,
            and the second axis is for the residuals.
            If provided, the grid, labels, and titles should be set by you.
        grid : bool, optional
            Show or hide grid, by default True
        labels : dict | None, optional
            Keys: 'xlabel', 'ylabel', 'title', and 'residuals_ylabel'. Allows label overrides, by default None
        return_fig : bool, optional
            Whether to return the figure and axis objects instead of showing the plot, by default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(); merged over the defaults
            (figsize=(10, 8), sharex=True, height ratios 3:1), by default None

        Returns
        ----------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects containing the plot.
        Otherwise, displays the plot and returns None.
        """
        default_labels = {
            "title"           : "Forward Model Fit to Observed Spectrum",
            "xlabel"          : "Wavelength (Angstroms)",
            "ylabel"          : "Normalised Flux",
            "residuals_ylabel": "Residuals",
        }
        labels = utils.set_dict_defaults(labels, default_labels)
        subplot_defaults = {
            "figsize": (10, 8),
            "sharex": True,
            "gridspec_kw": {'height_ratios': [3, 1]}
        }
        # Merge user kwargs over the defaults (previously the user's argument was discarded)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, subplot_defaults)

        # Combined input spectrum and the forward model stored by process_results()
        wavelengths = self.data.wavelengths["combined"]
        flux = self.data.flux["combined"]
        model_flux = self.data.forward_model

        # Drop the profile-distorted edge pixels (exact amount defined by utils.drop_edges)
        wavelengths = utils.drop_edges(wavelengths)
        flux = utils.drop_edges(flux)
        model_flux = utils.drop_edges(model_flux)
        continuum_model = utils.drop_edges(self.data.continuum_model)

        if fig_ax is not None:
            fig, ax = fig_ax
        else:
            fig, ax = plt.subplots(2, 1, **subplot_kwargs)
            ax[0].set_title(labels["title"])
            ax[1].set_xlabel(labels["xlabel"])
            ax[0].set_ylabel(labels["ylabel"])
            ax[1].set_ylabel(labels["residuals_ylabel"])
            ax[0].grid(grid)
            ax[1].grid(grid)
            plt.subplots_adjust(hspace=0.05)

        ax[0].plot(wavelengths, flux, color='black', linewidth=1, label='Observed Spectrum')
        ax[0].plot(wavelengths, model_flux, color='C0', linewidth=1, label='Forward Model Fit')
        ax[0].plot(wavelengths, continuum_model, color='C1', linewidth=1, label='Fitted Continuum', linestyle='--')
        ax[1].plot(wavelengths, model_flux - flux, color='C0', linewidth=1, label='Residuals')
        ax[1].axhline(0, color='black', linestyle='--', linewidth=1)
        ax[0].legend()
        ax[1].legend()
        if return_fig:
            return fig, ax
        plt.show()

    @_require_sampler
    def plot_autocorrelation(
        self,
        sampler : EnsembleSampler|None = None,
        n_grid : IntLike = 12,
        c : float = 5.0,
        return_fig : bool = False,
        subplot_kwargs : dict|None = None,
        min_steps : IntLike = 100
    ) -> None | tuple:
        """
        Plot estimated integrated autocorrelation time as a function of chain length.

        From the emcee docs:
        - For several prefixes of the chain, estimate tau with Sokal windowing.
        - Plot tau(N) and the reference line tau = N/50.

        Parameters
        ----------
        sampler : :py:class:`emcee.EnsembleSampler` | None, optional
            Optionally provide a different sampler to plot from, otherwise,
            takes the sampler from the Result object, by default None
        n_grid : :py:type:`IntLike`, optional
            Number of N values (prefix lengths) to evaluate, by default 12.
        c : float, optional
            Sokal window constant, by default 5.0.
        return_fig : bool, optional
            Whether to return the figure and axes objects, by default False
        subplot_kwargs : dict | None, optional
            Keyword arguments to be passed to plt.subplots(), by default None
        min_steps : :py:type:`IntLike`, optional
            Minimum chain length (steps per walker) required to attempt autocorrelation estimation,
            by default 100. The full chain is used; burn-in is not discarded.
            If you decrease this, you may get unreliable estimates or errors from the autocorrelation time estimation.

        Returns
        ----------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects containing
        the plot. Otherwise, displays the plot and returns None.

        Raises
        ------
        ValueError
            If the sampler is a dynesty sampler, or the chain is shorter than ``min_steps``.
        """
        self._require_emcee("plot_autocorrelation")
        chain = self.sampler.get_chain()  # (nsteps, nwalkers, ndim)
        nsteps = chain.shape[0]
        if nsteps < min_steps:
            raise ValueError("Not enough post-burnin samples to estimate autocorrelation reliably.")

        # Log-spaced prefix lengths between min_steps and nsteps
        Ns = np.unique(np.exp(np.linspace(np.log(min_steps), np.log(nsteps), n_grid)).astype(int))
        Ns = Ns[Ns >= min_steps]

        tau_estimates = {p: np.full(len(Ns), np.nan, dtype=float) for p in self.default_params}
        for i, n in enumerate(Ns):
            for p in self.default_params:
                tau_estimates[p][i] = utils.autocorr_new(chain[:n, :, p].T, c=c)

        subplot_kwargs = {} if subplot_kwargs is None else dict(subplot_kwargs)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 6)})
        fig, ax = plt.subplots(**subplot_kwargs)
        for label, p in zip(self.default_param_labels, self.default_params):
            ax.loglog(Ns, tau_estimates[p], "o-", label=f"{label}")
        # Reference line tau = N/50
        ax.loglog(Ns, Ns / 50.0, "--", label=r"$\tau = N/50$")
        ax.set_xlabel("Number of post-burnin samples per walker (N)")
        ax.set_ylabel(r"Estimated integrated autocorrelation time $\tau$")
        ax.set_title("Autocorrelation time estimates vs chain length")
        ax.legend()
        ax.grid(True, which="both")
        if return_fig:
            return fig, ax
        plt.show()
        return

    @_require_sampler
    def plot_acf(
        self,
        sampler : EnsembleSampler|None = None,
        max_lag : IntLike|None = None,
        return_fig : bool = False,
        subplot_kwargs : dict|None = None,
    ) -> None | tuple:
        """
        Plot the autocorrelation function (ACF) for each default parameter, averaged across walkers.

        Parameters
        ----------
        sampler : :py:class:`emcee.EnsembleSampler`, optional
            Optionally provide a different sampler to plot from, otherwise,
            takes the sampler from the Result object, by default None
        max_lag : :py:type:`IntLike`, optional
            Maximum lag to plot, by default None (plots up to min(5000, nsteps-1)).
            Values larger than nsteps-1 are clamped to nsteps-1.
        return_fig : bool, optional
            Whether to return the figure and axes objects, by default False
        subplot_kwargs : dict, optional
            Keyword arguments to be passed to plt.subplots(), by default None

        Returns
        -------
        If return_fig is True, returns a tuple (fig, ax) of the figure and axes objects containing
        the plot. Otherwise, displays the plot and returns None.

        Raises
        ------
        ValueError
            If the stored sampler is a dynesty sampler.
        """
        self._require_emcee("plot_acf")
        chain = self.sampler.get_chain()  # (nsteps, nwalkers, ndim)
        nsteps, nwalkers, _ndim = chain.shape

        # The ACF only has nsteps lags, so never ask for more than nsteps-1
        max_lag = min(5_000, nsteps - 1) if max_lag is None else min(int(max_lag), nsteps - 1)
        lags = np.arange(max_lag + 1)

        subplot_kwargs = {} if subplot_kwargs is None else dict(subplot_kwargs)
        subplot_kwargs = utils.set_dict_defaults(subplot_kwargs, {"figsize": (10, 5)})
        fig, ax = plt.subplots(**subplot_kwargs)
        for param, label in zip(self.default_params, self.default_param_labels):
            y = chain[:, :, param].T  # (nwalkers, nsteps)
            # Mean ACF across walkers
            f = np.zeros(nsteps)
            for w in range(nwalkers):
                f += utils.autocorr_func_1d(y[w], norm=True)
            f /= nwalkers
            ax.plot(lags, f[: max_lag + 1], label=f"{label}")

        ax.set_xlabel("Lag (steps)")
        ax.set_ylabel("Autocorrelation")
        ax.set_title("Mean ACF across walkers")
        ax.set_xscale("log")
        ax.grid(True)
        ax.axhline(0, color="black", linestyle="--", linewidth=1)
        ax.legend()
        if return_fig:
            return fig, ax
        plt.show()

    def initiate_sampler(self, sampler:EnsembleSampler|Sampler|None, _method_name=None) -> None: # type:ignore
        """
        Initiates the sampler attribute and the sampler-derived attributes.

        Sets ``dynesty``, ``default_params`` and ``default_param_labels``. For emcee samplers it also sets
        ``tau``, ``converged``, ``burnin`` and ``thin``. If the sampler was already initiated and no new
        sampler is passed, this returns immediately.

        Parameters
        ----------
        sampler : :py:class:`emcee.EnsembleSampler` or object, optional
            An emcee EnsembleSampler object or a compatible sampler object to set as the sampler attribute.
        _method_name : str, optional
            Internal parameter used to track which method is calling initiate_sampler, for error messages.
            Not intended for user input, by default None.

        Raises
        ------
        AttributeError
            If no sampler is passed and none is stored.
        """
        if self.sampler_initialized and sampler is None:
            return  # already initiated and nothing new to attach
        if sampler is not None:
            self.sampler = sampler
        if self.sampler is None:
            if _method_name is not None:
                error_msg = f"Cannot run {_method_name} without a sampler, please pass in a sampler to the method or during initialisation."
            else:
                error_msg = "Cannot initiate sampler without a sampler stored in the instance or passed as a parameter, please pass in a sampler "
            raise AttributeError(error_msg)

        # Classify on every (re)initialisation so a newly attached sampler is handled correctly.
        # Only possible when dynesty is installed (Sampler is None otherwise).
        self.dynesty = Sampler is not None and isinstance(self.sampler, Sampler)

        n_poly_params = self.config.poly_ord + 1
        # Labels for the polynomial coefficients: 'a' for the highest-order term (last parameter), then 'b', ...
        poly_labels = [chr(ord('a') + i) for i in range(n_poly_params)]

        if self.dynesty:
            self.default_param_labels = poly_labels
            self.default_params = None
            self.sampler_initialized = True
            return

        nsteps = self.data.nsteps

        # Calculate autocorr time, burnin, thin; suppress output from get_autocorr_time
        with open(os.devnull, "w") as devnull, \
                contextlib.redirect_stdout(devnull), \
                contextlib.redirect_stderr(devnull):
            self.tau = self.sampler.get_autocorr_time(quiet=True)

        self.converged = True
        if nsteps < 50 * np.max(self.tau):
            self.converged = False
            if self.config.verbose > 1:
                print("The number of MCMC steps is less than 50 times the maximum autocorrelation " \
                      "time.\n The sampler may not have converged. Consider running more steps or checking " \
                      f"the walker plots.\n The max autocorrelation time is {np.max(self.tau):.2f}, therefore " \
                      f"the minimum number of steps should be roughly {int(50 * np.max(self.tau))}.\n Disabling burnin " \
                      f"from autocorrelation time, instead using burnin=steps-1000")
        try:
            self.thin = int(np.min(self.tau) / 5)
            if self.converged:
                self.burnin = int(3 * np.max(self.tau))
            else:
                self.burnin = nsteps - 1000  # just the last 1000 steps
        except Exception:
            # Best effort: e.g. NaN/inf tau cannot be converted to int. Catches Exception rather than a bare
            # except so KeyboardInterrupt/SystemExit still propagate.
            if self.config.verbose > 0:
                print(f"Warning: Could not compute autocorrelation time for burnin and thinning.\n This is likely" \
                      f" due to all posterior samples being rejected (possibly by prior constraints).\n The resulting profile is likely" \
                      f" wrong. Setting defaults: burnin=nsteps-1000, and thin=1.")
            self.burnin = nsteps - 1000  # just the last 1000 steps
            self.thin = 1

        # burnin in [0, nsteps-1]; thin in [1, nsteps-1] but never below 1
        self.burnin = int(max(0, min(self.burnin, nsteps - 1)))
        self.thin = int(max(1, min(self.thin, nsteps - 1)))

        # Parameters for the walker and corner plots: polynomial coefficients (last n_poly_params entries)
        poly_params = list(range(-1, -n_poly_params - 1, -1))
        if not self.config.deterministic_profile:
            samples = self.sampler.get_chain(thin=self.thin, discard=self.burnin)
            max_profile_idx = np.argmax(samples[:, :, :-n_poly_params].mean(axis=(0, 1)))
            # Last profile point (the entry directly before the polynomial coefficients; this equals the
            # previously hard-coded -5 only when poly_ord == 3), the profile point with the highest mean,
            # and profile index 1.
            # NOTE(review): "$Z_0$" is taken from index 1, not 0 — confirm whether skipping the first point is intended.
            poly_params.extend([-(n_poly_params + 1), max_profile_idx, 1])
            poly_labels.extend(["$Z_{-1}$", "$Z_{max}$", "$Z_0$"])
        self.default_params = poly_params
        self.default_param_labels = poly_labels
        self.sampler_initialized = True

    @property
    def sampler(self) -> EnsembleSampler|Sampler|None: # type:ignore
        """Returns the sampler stored on the Data object (None if not saved)."""
        return self.data.sampler

    @sampler.setter
    def sampler(self, value: EnsembleSampler|Sampler|None) -> None: # type:ignore
        """Sets the sampler in the data class."""
        self.data.sampler = value

    def save(self, *args, **kwargs) -> None:
        """
        Saves the Data instance which the Result class inherits from Acid.
        See :py:function:`Data.save` for more details on the parameters that can be passed.
        """
        self.data.save(*args, **kwargs)

    @classmethod
    def load(cls, data:str|Result|Data="result.pkl") -> Result:
        """
        Loads a Result object from a pickle file or from a Data/Result object.

        Parameters
        ----------
        data : str | :py:class:`Result` | :py:class:`Data`, optional
            A pickle file name or an object with the same attributes as a saved Result object, by default "result.pkl"

        Returns
        ----------
        :py:class:`Result`
            A Result object loaded from the pickle file or from the provided object.

        Raises
        ------
        TypeError
            If ``data`` is not a str, Result or Data instance.
        """
        if isinstance(data, str):
            return cls(Data.load(data))
        if isinstance(data, Result):
            return cls(data.data)
        if isinstance(data, Data):
            return cls(data)
        raise TypeError(f"Cannot load a Result from an object of type {type(data)}.")