"""MCMC sampling using numpyro."""
from os import urandom
from typing import List
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random, lax, jit
from jax.config import config
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.initialization import init_to_uniform
from superphot_plus.constants import PAD_SIZE
from superphot_plus.lightcurve import Lightcurve
from superphot_plus.posterior_samples import PosteriorSamples
from superphot_plus.samplers.sampler import Sampler
from superphot_plus.surveys.fitting_priors import MultibandPriors, PriorFields
from superphot_plus.surveys.surveys import Survey
from superphot_plus.utils import calculate_neg_chi_squareds, get_numpyro_cube
# Request 64-bit floating point in JAX at import time. This is a process-wide
# JAX flag, so importing this module changes precision for all JAX code.
# numpyro.enable_x64() asks for the same setting through numpyro, so one of the
# two calls is presumably redundant.
# NOTE(review): the `from jax.config import config` import used here is
# deprecated in newer JAX releases (use `jax.config.update`) — TODO confirm
# against the pinned JAX version.
config.update("jax_enable_x64", True)
numpyro.enable_x64()
class NumpyroSampler(Sampler):
    """MCMC sampling using numpyro."""

    def __init__(self):
        pass

    def run_single_curve(
        self, lightcurve: Lightcurve, priors: MultibandPriors, rng_seed, ref_params=None, sampler="svi", **kwargs
    ) -> PosteriorSamples:
        """Run the sampler on a single light curve.

        Parameters
        ----------
        lightcurve : Lightcurve
            The lightcurve to sample. It is padded with ``pad_bands`` to
            ``PAD_SIZE`` points per band (in ``priors.ordered_bands`` order)
            before sampling. The return value of ``pad_bands`` is not used, so
            this presumably modifies ``lightcurve`` in place.
        priors : MultibandPriors
            The curve priors to use.
        rng_seed : int or None
            The random seed to use (for testing purposes). The user should pass None in
            cases where they want a fully random run.
        ref_params : tuple, optional
            Fixed reference-band parameters forwarded to ``run_mcmc``. Defaults to
            None, in which case the reference-band parameters are sampled.
        sampler : str
            The numpyro sampler to use. Either "NUTS" or "svi". Defaults to "svi".
        **kwargs
            Unused here (presumably accepted for compatibility with the
            ``Sampler`` interface).

        Returns
        -------
        eq_wt_samples : PosteriorSamples or None
            The resulting samples, or None when ``run_mcmc`` returns None
            (e.g. NUTS on a light curve missing data in some band).
        """
        lightcurve.pad_bands(priors.ordered_bands, PAD_SIZE)
        eq_wt_samples = run_mcmc(
            lightcurve,
            rng_seed=rng_seed,
            sampler=sampler,
            priors=priors,
            ref_params=ref_params,
        )
        if eq_wt_samples is None:
            return None
        return PosteriorSamples(
            eq_wt_samples[0],
            name=lightcurve.name,
            sampling_method=sampler,
            sn_class=lightcurve.sn_class,
        )

    def run_multi_curve(
        self, lightcurves, priors: MultibandPriors, rng_seed, sampler="svi", ref_params=None, **kwargs
    ) -> List[PosteriorSamples]:
        """Run the sampler on a batch of light curves in a single ``run_mcmc`` call.

        Each light curve is padded with ``pad_bands(..., in_place=False)``, so
        the padding is applied to a copy rather than to the caller's objects.

        Parameters
        ----------
        lightcurves : list of Lightcurve
            The light curves to sample.
        priors : MultibandPriors
            The curve priors to use.
        rng_seed : int or None
            The random seed to use. Pass None for a fully random run.
        sampler : str
            The numpyro sampler to use. Defaults to "svi". ``run_mcmc`` raises
            ``ValueError`` for "NUTS" in batch mode.
        ref_params : tuple, optional
            Fixed reference-band parameters forwarded to ``run_mcmc``.
        **kwargs
            Unused here (presumably accepted for compatibility with the
            ``Sampler`` interface).

        Returns
        -------
        list of PosteriorSamples
            One entry per light curve that produced samples, in input order.
            Light curves whose result is None are skipped.
        """
        if len(lightcurves) == 0:
            return []
        padded_lcs = [lc.pad_bands(priors.ordered_bands, PAD_SIZE, in_place=False) for lc in lightcurves]
        eq_wt_samples = run_mcmc(
            padded_lcs,
            rng_seed=rng_seed,
            sampler=sampler,
            priors=priors,
            ref_params=ref_params,
        )
        # run_mcmc returns one entry per input light curve, in order.
        return [
            PosteriorSamples(posts, name=lc.name, sampling_method=sampler, sn_class=lc.sn_class)
            for lc, posts in zip(lightcurves, eq_wt_samples)
            if posts is not None
        ]
def prior_helper(priors, max_flux, aux_b=None):
    """Draw one band's seven light-curve parameters from their truncated-normal priors.

    There are two sampling conventions:

    * Base band (``aux_b is None``). The sample sites are ``logA``, ``beta``,
      ``log_gamma``, ``t0``, ``log_tau_rise``, ``log_tau_fall`` and
      ``log_extra_sigma``. The ``log*`` draws are turned into ``10 ** draw``.
      The amplitude is additionally multiplied by ``max_flux``.
    * Auxiliary band. Each value is returned exactly as drawn. The site names
      are ``A``, ``beta``, ``gamma``, ``t0``, ``tau_rise``, ``tau_fall`` and
      ``extra_sigma``, each with an ``"_" + str(aux_b)`` suffix.
      ``create_jax_model`` multiplies these values onto the base-band values as
      ratios.

    Parameters
    ----------
    priors : CurvePriors
        The priors for one band. The fields ``amp``, ``beta``, ``gamma``,
        ``t_0``, ``tau_rise``, ``tau_fall`` and ``extra_sigma`` are read.
    max_flux : float
        Max flux of the light curve. Only used for the base band.
    aux_b : str, optional
        The name of the auxiliary band, if it is auxiliary. Defaults to None,
        meaning the base band.

    Returns
    -------
    tuple
        ``(amp, beta, gamma, t_0, tau_rise, tau_fall, extra_sigma)``.
    """

    def draw(site, field):
        return numpyro.sample(site, trunc_norm(field))

    if aux_b is not None:
        suffix = "_" + str(aux_b)
        # Tuple elements evaluate left to right, so the site order matches the base band.
        return (
            draw("A" + suffix, priors.amp),
            draw("beta" + suffix, priors.beta),
            draw("gamma" + suffix, priors.gamma),
            draw("t0" + suffix, priors.t_0),
            draw("tau_rise" + suffix, priors.tau_rise),
            draw("tau_fall" + suffix, priors.tau_fall),
            draw("extra_sigma" + suffix, priors.extra_sigma),
        )

    amp = max_flux * 10 ** draw("logA", priors.amp)
    beta = draw("beta", priors.beta)
    gamma = 10 ** draw("log_gamma", priors.gamma)
    t_0 = draw("t0", priors.t_0)
    tau_rise = 10 ** draw("log_tau_rise", priors.tau_rise)
    tau_fall = 10 ** draw("log_tau_fall", priors.tau_fall)
    extra_sigma = 10 ** draw("log_extra_sigma", priors.extra_sigma)
    return amp, beta, gamma, t_0, tau_rise, tau_fall, extra_sigma
def lax_helper_function(svi, svi_state, num_iters, *args, **kwargs):
    """Advance an SVI state by ``num_iters`` update steps inside one ``lax.scan``.

    The caller wraps this in ``jit`` with ``svi`` and ``num_iters`` static, so
    the whole optimisation loop compiles to a single XLA computation.

    Parameters
    ----------
    svi : numpyro.infer.SVI
        Object whose ``stable_update(state, *args, **kwargs)`` returns a
        ``(new_state, loss)`` pair.
    svi_state : SVIState
        Starting optimiser state.
    num_iters : int
        Number of update steps. It must be a concrete Python int because it
        fixes the scan length.
    *args, **kwargs
        Forwarded unchanged to every ``svi.stable_update`` call (the
        model/guide data).

    Returns
    -------
    SVIState
        The state after ``num_iters`` updates. The per-step losses are discarded.
    """

    def _step(state, _):
        # lax.scan carries `state` and stacks the per-step loss. No inner jit is
        # needed because scan already traces and compiles this body.
        return svi.stable_update(state, *args, **kwargs)

    final_state, _ = lax.scan(_step, svi_state, None, length=num_iters)
    return final_state
def trunc_norm(fields: PriorFields):
    """Build a validated numpyro ``TruncatedNormal`` from a ``PriorFields`` record.

    Parameters
    ----------
    fields : PriorFields
        Supplies ``mean`` (loc), ``std`` (scale), ``clip_a`` (lower truncation
        bound) and ``clip_b`` (upper truncation bound).

    Returns
    -------
    numpyro.distributions.TruncatedDistribution
        ``dist.TruncatedNormal`` constructed with ``validate_args=True``.
    """
    normal_kwargs = {
        "loc": fields.mean,
        "scale": fields.std,
        "low": fields.clip_a,
        "high": fields.clip_b,
    }
    return dist.TruncatedNormal(validate_args=True, **normal_kwargs)
def _band_flux(phase, amp, beta, gamma, tau_rise, tau_fall):
    """Evaluate the rise / linear-plateau / exponential-decay flux model at ``phase`` (time minus t0)."""
    flux_const = amp / (1.0 + jnp.exp(-phase / tau_rise))
    # Smooth switch (sharpness 10.0) from the plateau branch to the decay branch at phase == gamma.
    sigmoid = 1 / (1 + jnp.exp(10.0 * (gamma - phase)))
    return flux_const * (
        (1 - sigmoid) * (1 - beta * phase)
        + sigmoid * (1 - beta * gamma) * jnp.exp(-(phase - gamma) / tau_fall)
    )


def create_jax_model(
    priors,
    t=None,
    obsflux=None,
    uncertainties=None,
    max_flux=None,
    ref_params=None,
):  # pylint: disable=too-many-locals
    """Numpyro model: multi-band light-curve flux with a Normal likelihood.

    The reference band's parameters are either fixed (``ref_params``) or
    sampled by ``prior_helper``. Every other band in ``priors.ordered_bands``
    samples multiplicative ratios that are applied to the reference-band values.

    Parameters
    ----------
    priors : MultibandPriors
        Priors for all bands in the light curves.
    t : array-like, optional
        Time values. Defaults to None.
    obsflux : array-like, optional
        Observed flux values, used as the observations of the ``"obs"`` site.
        Defaults to None.
    uncertainties : array-like, optional
        Flux uncertainties. Defaults to None.
    max_flux : float, optional
        Maximum flux value. It scales the sampled amplitude and the extra-noise
        term. Defaults to None.
    ref_params : tuple, optional
        ``(amp, beta, gamma, t_0, tau_rise, tau_fall, extra_sigma)`` for the
        reference band, in the same units that ``prior_helper`` returns
        (amplitude already scaled by ``max_flux``, no log10). When given, these
        values are used as-is instead of being sampled. Defaults to None.
    """
    ref_priors = priors.bands[priors.reference_band]
    if ref_params is not None:
        amp, beta, gamma, t_0, tau_rise, tau_fall, extra_sigma = ref_params
    else:
        amp, beta, gamma, t_0, tau_rise, tau_fall, extra_sigma = prior_helper(ref_priors, max_flux)

    es_scaled = max_flux * extra_sigma
    # Evaluate the reference-band model everywhere. The auxiliary-band slices are overwritten below.
    flux = _band_flux(t - t_0, amp, beta, gamma, tau_rise, tau_fall)
    sigma_tot = jnp.sqrt(uncertainties**2 + es_scaled**2)

    # auxiliary bands
    for b_idx, uniq_b in enumerate(priors.ordered_bands):
        if uniq_b == priors.reference_band:
            continue
        (
            amp_ratio,
            beta_ratio,
            gamma_ratio,
            t0_ratio,
            tau_rise_ratio,
            tau_fall_ratio,
            extra_sigma_ratio,
        ) = prior_helper(priors.bands[uniq_b], max_flux, uniq_b)
        # Assumes the data are laid out band by band in priors.ordered_bands
        # order with PAD_SIZE points per band (see the pad_bands calls in
        # NumpyroSampler).
        inc_band_ix = np.arange(b_idx * PAD_SIZE, (b_idx + 1) * PAD_SIZE)
        # Slice before subtracting so only this band's times are processed.
        phase_b = t[inc_band_ix] - t_0 * t0_ratio
        flux = flux.at[inc_band_ix].set(
            _band_flux(
                phase_b,
                amp * amp_ratio,
                beta * beta_ratio,
                gamma * gamma_ratio,
                tau_rise * tau_rise_ratio,
                tau_fall * tau_fall_ratio,
            )
        )
        sigma_tot = sigma_tot.at[inc_band_ix].set(
            jnp.sqrt(uncertainties[inc_band_ix] ** 2 + extra_sigma_ratio**2 * es_scaled**2)
        )

    numpyro.sample("obs", dist.Normal(flux, sigma_tot), obs=obsflux)
def create_jax_guide(priors, t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None):
    """Mean-field Normal variational guide over the latent sites of ``create_jax_model``.

    Each latent site ``<site>`` gets two learnable parameters:

    * ``<site>_mu``, constrained to the prior's ``[clip_a, clip_b]`` interval
      and initialised at the prior mean;
    * ``<site>_sigma``, constrained to be positive and initialised to a small
      fixed value.

    The site is then drawn from ``Normal(<site>_mu, <site>_sigma)``. The
    reference-band sites are always registered, even when the model is given
    ``ref_params``.

    NOTE(review): the guide's Normal is unbounded while the model's priors are
    truncated normals, so guide draws can fall outside the prior support.
    Confirm this is acceptable for the ELBO.

    Parameters
    ----------
    priors : MultibandPriors
        Priors for all bands in the light curves.
    t, obsflux, uncertainties, max_flux, ref_params : optional
        Unused by the guide body.
    """

    def register_site(site: str, fields: PriorFields, init_sigma: float):
        loc = numpyro.param(
            f"{site}_mu",
            fields.mean,
            constraint=constraints.interval(fields.clip_a, fields.clip_b),
        )
        scale = numpyro.param(f"{site}_sigma", init_sigma, constraint=constraints.positive)
        numpyro.sample(site, dist.Normal(loc, scale))

    ref = priors.bands[priors.reference_band]
    reference_sites = (
        ("logA", ref.amp, 1e-3),
        ("beta", ref.beta, 1e-5),
        ("log_gamma", ref.gamma, 1e-3),
        ("t0", ref.t_0, 1e-3),
        ("log_tau_rise", ref.tau_rise, 1e-3),
        ("log_tau_fall", ref.tau_fall, 1e-3),
        ("log_extra_sigma", ref.extra_sigma, 1e-3),
    )
    for site, fields, init_sigma in reference_sites:
        register_site(site, fields, init_sigma)

    # aux bands: ratio sites named "<stem>_<band>", all initialised with sigma 1e-3
    aux_stems = ("A", "beta", "gamma", "t0", "tau_rise", "tau_fall", "extra_sigma")
    for band in priors.aux_bands:
        band_priors = priors.bands[band]
        band_fields = (
            band_priors.amp,
            band_priors.beta,
            band_priors.gamma,
            band_priors.t_0,
            band_priors.tau_rise,
            band_priors.tau_fall,
            band_priors.extra_sigma,
        )
        for stem, fields in zip(aux_stems, band_fields):
            register_site(stem + "_" + band, fields, 1e-3)
def _svi_helper_no_recompile(
    lc_single,
    max_flux,
    priors,
    svi,
    svi_state,
    lax_jit,
    num_iter,
    seed,
    ref_params=None,
):
    """Run SVI on one light curve with an already compiled SVI sampler object.

    Parameters
    ----------
    lc_single : Lightcurve
        Padded light curve to fit.
    max_flux : float
        Reference-band maximum flux passed to the model.
    priors : MultibandPriors
        Priors for all bands.
    svi : numpyro.infer.SVI
        The SVI object to optimise.
    svi_state : SVIState or None
        Starting state. None triggers a fresh ``svi.init`` with the fixed key
        ``PRNGKey(1)``. Otherwise the fit warm-starts from the given state.
    lax_jit : callable
        Jitted ``lax_helper_function`` that performs ``num_iter`` updates.
    num_iter : int
        Number of SVI update steps.
    seed : array-like
        Indexable. ``int(seed[0])`` seeds the NumPy RNG used to draw posterior
        samples (the caller passes a JAX PRNG key).
    ref_params : tuple, optional
        Fixed reference-band parameters forwarded to the model.

    Returns
    -------
    posterior_cube : np.ndarray
        First element of ``get_numpyro_cube`` built from 100 draws per guide site.
    red_neg_chisq : np.ndarray
        Output of ``calculate_neg_chi_squareds`` on the un-padded points.
    svi_state : SVIState
        Final SVI state, so the caller can warm-start the next fit.
    """
    data_kwargs = {
        "obsflux": lc_single.fluxes,
        "t": lc_single.times,
        "uncertainties": lc_single.flux_errors,
        "max_flux": max_flux,
        "ref_params": ref_params,
    }
    if svi_state is None:
        svi_state = svi.init(random.PRNGKey(1), **data_kwargs)
    svi_state = lax_jit(svi, svi_state, num_iter, **data_kwargs)

    params = svi.get_params(svi_state)
    # Create the RNG once. The previous code re-seeded it for every parameter,
    # which gave every site the same standard-normal draws and made the samples
    # perfectly correlated across parameters.
    rng = np.random.RandomState(int(seed[0]))
    posterior_samples = {}
    for name, loc in params.items():
        if not name.endswith("_mu"):
            continue
        site = name[: -len("_mu")]
        posterior_samples[site] = rng.normal(loc=loc, scale=params[site + "_sigma"], size=100)

    posterior_cube = get_numpyro_cube(
        posterior_samples, max_flux,
        priors.reference_band, priors.ordered_bands
    )[0]
    # Padded points are identified by a flux error of exactly 1e10 (presumably
    # the padding sentinel). NOTE(review): the NUTS path in run_mcmc uses
    # `> 1e5` instead; confirm the two should agree.
    padded_idxs = lc_single.flux_errors == 1e10
    red_neg_chisq = calculate_neg_chi_squareds(
        posterior_cube,
        lc_single.times[~padded_idxs],
        lc_single.fluxes[~padded_idxs],
        lc_single.flux_errors[~padded_idxs],
        lc_single.bands[~padded_idxs],
        ordered_bands=priors.ordered_bands,
        ref_band=priors.reference_band,
    )
    return posterior_cube, red_neg_chisq, svi_state
def _append_chisq_column(posterior_cube, red_neg_chisq):
    """Append the per-sample reduced negative chi-squared as the last column of the sample cube."""
    return np.hstack((posterior_cube, red_neg_chisq[np.newaxis, :].T))


def _run_nuts(lc, rng_key, priors, jax_model):
    """Fit one light curve with NUTS.

    Returns a one-element list of sample cubes, or None if any band has no data.
    """
    # Require data in all bands.
    if any(lc.obs_count(unique_band) == 0 for unique_band in priors.ordered_bands):
        return None
    max_flux, _ = lc.find_max_flux(band=priors.reference_band)
    kernel = NUTS(jax_model, init_strategy=init_to_uniform)
    mcmc = MCMC(
        kernel,
        num_warmup=1000,
        num_samples=300,
        num_chains=1,
        chain_method="parallel",
        jit_model_args=True,
    )
    mcmc.run(
        rng_key,
        obsflux=lc.fluxes,
        t=lc.times,
        uncertainties=lc.flux_errors,
        max_flux=max_flux,
    )
    posterior_cube, _ = get_numpyro_cube(
        mcmc.get_samples(),
        max_flux,
        priors.reference_band,
        priors.ordered_bands,
    )
    padded_idxs = lc.flux_errors > 1e5
    red_neg_chisq = calculate_neg_chi_squareds(
        posterior_cube,
        lc.times[~padded_idxs],
        lc.fluxes[~padded_idxs],
        lc.flux_errors[~padded_idxs],
        lc.bands[~padded_idxs],
        ordered_bands=priors.ordered_bands,
        ref_band=priors.reference_band,
    )
    return [_append_chisq_column(posterior_cube, red_neg_chisq)]


def _run_svi(lcs, seed, priors, jax_model, jax_guide, ref_params):
    """Fit each light curve in ``lcs`` with SVI, reusing one compiled update loop.

    Returns one sample cube per light curve, in input order.
    """
    optimizer = numpyro.optim.Adam(step_size=0.001)
    svi = SVI(jax_model, jax_guide, optimizer, loss=Trace_ELBO())
    num_iter = 10_000
    # svi (arg 0) and the iteration count (arg 2) are static, so the loop
    # compiles once and is reused across light curves.
    lax_jit = jit(lax_helper_function, static_argnums=(0, 2))

    posterior_cubes = []
    svi_state = None
    bad_prev_fit = True
    for i, lc_single in enumerate(lcs):
        if i % 100 == 0:
            print(i)
        # After a poor fit, start from scratch. Otherwise warm-start from the
        # previous light curve's state.
        if bad_prev_fit:
            svi_state = None
        max_flux, _ = lc_single.find_max_flux(band=priors.reference_band)
        posterior_cube, red_neg_chisq, svi_state = _svi_helper_no_recompile(
            lc_single,
            max_flux,
            priors,
            svi,
            svi_state,
            lax_jit,
            num_iter,
            seed,
            ref_params,
        )
        # A mean reduced negative chi-squared below -6 counts as a poor fit.
        bad_prev_fit = np.mean(red_neg_chisq) < -6
        posterior_cubes.append(_append_chisq_column(posterior_cube, red_neg_chisq))
    return posterior_cubes


def run_mcmc(lc, rng_seed, sampler="NUTS", priors=None, ref_params=None):
    """Runs MCMC using numpyro on the lightcurve to get set
    of equally weighted posteriors (sets of fit parameters).

    Parameters
    ----------
    lc : Lightcurve object or list of Lightcurve
        The Lightcurve object on which to run MCMC. A list selects batch mode,
        which only the "svi" sampler supports.
    rng_seed : int or None
        The random seed to use (for testing purposes). The user should pass None in
        cases where they want a fully random run.
    sampler : str, optional
        The MCMC sampler to use, "NUTS" or "svi". Defaults to "NUTS".
    priors : MultibandPriors, optional
        The prior set to use for fitting. Defaults to None, meaning ZTF's
        priors (``Survey.ZTF().priors``), resolved at call time.
    ref_params : tuple, optional
        Fixed reference-band parameters for the model. Only the "svi" sampler
        passes them on; the NUTS path does not use them.

    Returns
    -------
    list of np.ndarray or None
        One array of equally weighted posterior samples per light curve. Each
        array has an extra last column holding the reduced negative
        chi-squared. None is returned for NUTS when some band has no
        observations.

    Raises
    ------
    ValueError
        If ``sampler`` is not "NUTS" or "svi", or if NUTS is requested in batch mode.
    """
    if priors is None:
        priors = Survey.ZTF().priors

    batch = isinstance(lc, list)  # check if one LightCurve or multiple
    if rng_seed is None:
        rng_seed = int.from_bytes(urandom(4), "big")
    print(f"Running numpyro with seed={rng_seed}")
    rng_key, seed2 = random.split(random.PRNGKey(rng_seed))

    def jax_model(t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None):
        create_jax_model(priors, t, obsflux, uncertainties, max_flux, ref_params)

    def jax_guide(**kwargs):  # pylint: disable=unused-argument
        create_jax_guide(priors)

    if sampler == "NUTS":
        if batch:
            raise ValueError("Batch mode not implemented for NUTS.")
        return _run_nuts(lc, rng_key, priors, jax_model)
    if sampler == "svi":
        return _run_svi(lc if batch else [lc], seed2, priors, jax_model, jax_guide, ref_params)
    raise ValueError("'sampler' must be 'NUTS' or 'svi'")