src.superphot_plus.samplers.numpyro_sampler
MCMC sampling using numpyro.
Module Contents
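Classes
- NumpyroSampler – MCMC sampling using numpyro.
Functions
- prior_helper – Helper function to sample prior values. If aux_b is not None, appends aux_b to value names.
- lax_helper_function – Helper function using LAX to speed up SVI state updates.
- trunc_norm – Provides keyword parameters to numpyro’s TruncatedNormal, using the fields in PriorFields.
- create_jax_model – Create a JAX model for MCMC.
- create_jax_guide – JAX guide function for MCMC.
- _svi_helper_no_recompile – Helper function to run SVI on a single light curve with an already compiled SVI sampler object.
- run_mcmc – Runs MCMC using numpyro on the lightcurve to get a set of equally weighted posteriors (sets of fit parameters).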
- class NumpyroSampler(sampler='svi')
Bases: superphot_plus.samplers.sampler.Sampler
MCMC sampling using numpyro.
- run_single_curve(lightcurve: superphot_plus.lightcurve.Lightcurve, priors: superphot_plus.surveys.fitting_priors.MultibandPriors, rng_seed, ref_params=None, **kwargs) → superphot_plus.posterior_samples.PosteriorSamples
Run the sampler on a single light curve.
- Parameters:
lightcurve (Lightcurve) – The lightcurve to sample.
priors (MultibandPriors) – The curve priors to use.
rng_seed (int or None) – The random seed to use (for testing purposes). The user should pass None in cases where they want a fully random run.
sampler (str) – The numpyro sampler to use. Either “NUTS” or “svi”
- Returns:
eq_wt_samples – The resulting samples.
- Return type:
superphot_plus.posterior_samples.PosteriorSamples
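The `rng_seed` convention above (an integer for a reproducible run, None for a fully random one) can be sketched as a small helper that turns the seed into a JAX PRNG key. `make_rng_key` is a hypothetical name, and drawing a fresh seed from the operating system when `rng_seed` is None is an assumption about the behaviour, not the module's documented implementation.

```python
import secrets
from typing import Optional

import jax


def make_rng_key(rng_seed: Optional[int]):
    """Hypothetical helper: map an optional integer seed to a JAX PRNG key."""
    if rng_seed is None:
        # Assumption: a fully random run draws a fresh 31-bit seed from the OS.
        rng_seed = secrets.randbits(31)
    return jax.random.PRNGKey(rng_seed)


# The same integer seed always yields the same key, so test runs are reproducible.
key_a = make_rng_key(42)
key_b = make_rng_key(42)
# None gives a key of the same shape, but its value differs from run to run.
random_key = make_rng_key(None)
```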
- run_multi_curve(lightcurves, priors: superphot_plus.surveys.fitting_priors.MultibandPriors, rng_seed, sampler='svi', ref_params=None, **kwargs) → List[superphot_plus.posterior_samples.PosteriorSamples]
Not yet implemented.
- prior_helper(priors, aux_b=None)
Helper function to sample prior values. If aux_b is not None, appends aux_b to value names.
- Parameters:
priors (CurvePriors) – The priors for one band.
max_flux (float) – Max flux of the light curve.
aux_b (str, optional) – The name of the auxiliary band, if it is auxiliary. Defaults to None, which assumes it’s the base band.
- lax_helper_function(svi, svi_state, num_iters, *args, **kwargs)
Helper function using LAX to speed up SVI state updates.
- trunc_norm(fields: superphot_plus.surveys.fitting_priors.PriorFields, low=None, high=None)
Provides keyword parameters to numpyro’s TruncatedNormal, using the fields in PriorFields.
- Parameters:
fields (PriorFields) – The (low, high, mean, standard deviation) fields of the truncated normal distribution.
- Returns:
A truncated normal distribution.
- Return type:
numpyro.distributions.TruncatedDistribution
- create_jax_model(priors, t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None)
Create a JAX model for MCMC.
- Parameters:
priors (MultibandPriors) – Priors for all bands in the light curves.
t (array-like, optional) – Time values. Defaults to None.
obsflux (array-like, optional) – Observed flux values. Defaults to None.
uncertainties (array-like, optional) – Flux uncertainties. Defaults to None.
max_flux (float, optional) – Maximum flux value. Defaults to None.
- create_jax_guide(priors, t=None, obsflux=None, uncertainties=None, max_flux=None, ref_params=None)
JAX guide function for MCMC.
- Parameters:
priors (MultibandPriors) – Priors for all bands in the light curves.
- _svi_helper_no_recompile(lc_single, max_flux, priors, svi, svi_state, lax_jit, num_iter, seed, ref_params=None)
Helper function to run SVI on a single light curve with an already compiled SVI sampler object.
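Reusing an already compiled object pays off because a function wrapped in `jax.jit` is only retraced and recompiled when the shapes or dtypes of its array inputs change. The stand-alone sketch below shows this general JAX behaviour with a hypothetical Gaussian likelihood. It does not reproduce the module's helper, and the trace counter is purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def neg_log_likelihood(flux, model_flux, uncertainties):
    """Gaussian negative log-likelihood (up to a constant) of one light curve."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function.
    return 0.5 * jnp.sum(((flux - model_flux) / uncertainties) ** 2)


# Two light curves of the same length reuse the compiled function...
neg_log_likelihood(jnp.ones(16), jnp.zeros(16), jnp.ones(16))
neg_log_likelihood(jnp.zeros(16), jnp.ones(16), jnp.ones(16))
traces_same_shape = trace_count
# ...while a different length triggers a new trace and compilation.
neg_log_likelihood(jnp.ones(20), jnp.zeros(20), jnp.ones(20))
traces_new_shape = trace_count
```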
- run_mcmc(lc, rng_seed, sampler='NUTS', priors=Survey.ZTF().priors, ref_params=None)
Runs MCMC using numpyro on the light curve to get a set of equally weighted posterior samples (sets of fit parameters).
- Parameters:
lc (Lightcurve) – The Lightcurve object on which to run MCMC.
rng_seed (int or None) – The random seed to use (for testing purposes). The user should pass None in cases where they want a fully random run.
sampler (str, optional) – The MCMC sampler to use. Defaults to “NUTS”.
priors (MultibandPriors, optional) – The prior set to use for fitting. Defaults to ZTF’s priors.
- Returns:
A set of equally weighted posterior samples (sets of fit parameters) as a NumPy array. If the light curve does not contain any valid points, None is returned.
- Return type:
np.ndarray or None