# JT 2017-19
from importlib import import_module
from copy import deepcopy
import logging
from numbers import Number
import numpy as np
import os
from typing import Mapping
# Conventions
_label = "label"
_prior = "prior"
_theory = "theory"
_params = "params"
_likelihood = "likelihood"
_sampler = "sampler"
_p_label = "latex"
_p_dist = "dist"
_p_value = "value"
_p_derived = "derived"
_p_renames = "renames"
_separator = "__"
_minuslogprior = "minuslogprior"
_prior_1d_name = "0"
_chi2 = "chi2"
_weight = "weight"
_minuslogpost = "minuslogpost"
_post = "post"
def cobaya_params_file(root):
file = root + ('' if root.endswith((os.sep, "/")) else '.') + 'updated.yaml'
if os.path.exists(file):
return file
else:
file = root + ('' if root.endswith((os.sep, "/")) else '__') + 'full.yaml'
if os.path.exists(file):
return file
return None
def yaml_file_or_dict(file_or_dict) -> Mapping:
if isinstance(file_or_dict, str):
from getdist.yaml_tools import yaml_load_file
return yaml_load_file(file_or_dict)
elif isinstance(file_or_dict, Mapping):
return file_or_dict
else:
raise ValueError('Cobya parameter input must be a dictionary or filename')
[docs]def MCSamplesFromCobaya(info, collections, name_tag=None,
ignore_rows=0, ini=None, settings=None):
"""
Creates a set of samples from Cobaya's output.
Parameter names, ranges and labels are taken from the "info" dictionary
(always use the "updated" one generated by `cobaya.run`).
For a description of the various analysis settings and default values see
`analysis_defaults.ini <https://getdist.readthedocs.org/en/latest/analysis_settings.html>`_.
:param collections: collection(s) of samples from Cobaya
:param info: info dictionary, common to all collections
(use the "updated" one, returned by `cobaya.run`)
:param name_tag: name for this sample to be shown in the plots' legend
:param ignore_rows: initial samples to skip, number (`int>=1`) or fraction (`float<1`)
:param ini: The name of a .ini file with analysis settings to use
:param settings: dictionary of analysis settings to override defaults
:return: The :class:`MCSamples` instance
"""
if hasattr(collections, "data"):
collections = [collections]
# Check consistency between collections
try:
columns = list(collections[0].data)
except AttributeError:
raise TypeError(
"The second argument does not appear to be a (list of) samples `Collection`.")
if not all(list(c.data) == columns for c in collections[1:]):
raise ValueError("The given collections don't have the same columns.")
# Check consistency with info
info_params = get_info_params(info)
# if skip burn in *has already been done*
skip = info.get(_post, {}).get("skip", 0)
if ignore_rows != 0 and skip != 0:
logging.warning("You are asking for rows to be ignored (%r), but some (%r) were "
"already ignored in the original chain.", ignore_rows, skip)
var_params = [k for k, v in info_params.items() if is_sampled_param(v) or is_derived_param(v)]
assert set(columns[2:]) == set(var_params), (
"Info and collection(s) are not compatible, because their parameters differ: "
"the collection(s) have %r and the info has %r. " % (columns[2:], var_params) +
"Are you sure that you are using an *updated* info dictionary "
"(i.e. the output of `cobaya.run`)?")
# We need to use *collection* sorting, not info sorting!
names = [p + ("*" if is_derived_param(info_params[p]) else "")
for p in columns[2:]]
labels = [(info_params[p] or {}).get(_p_label, p) for p in columns[2:]]
ranges = {p: get_range(info_params[p]) for p in info_params} # include fixed parameters not in columns
renames = {p: info_params.get(p, {}).get(_p_renames, []) for p in columns[2:]}
samples = [c[c.data.columns[2:]].values for c in collections]
weights = [c[_weight].values for c in collections]
loglikes = [-c[_minuslogpost].values for c in collections]
sampler = get_sampler_type(info)
label = get_sample_label(info)
from getdist.mcsamples import MCSamples
return MCSamples(samples=samples, weights=weights, loglikes=loglikes, sampler=sampler,
names=names, labels=labels, ranges=ranges, renames=renames,
ignore_rows=ignore_rows, name_tag=name_tag, label=label, ini=ini,
settings=settings)
[docs]def get_info_params(info):
"""
Extracts parameter info from the new yaml format.
"""
info = yaml_file_or_dict(info)
# Prune fixed parameters
info_params = info.get(_params)
info_params_full = dict()
for p, pinfo in info_params.items():
info_params_full[p] = info_params[p]
# Add prior and likelihoods
priors = [_prior_1d_name] + list(info.get(_prior) or [])
likes = list(info.get(_likelihood))
# Account for post
remove = info.get(_post, {}).get("remove", {})
for param in remove.get(_params, []) or []:
info_params_full.pop(param, None)
for like in remove.get(_likelihood, []) or []:
likes.remove(like)
for prior in remove.get(_prior, []) or []:
priors.remove(prior)
add = info.get(_post, {}).get("add", {})
# Adding derived params and updating 1d priors
for param, pinfo in add.get(_params, {}).items():
pinfo_old = info_params_full.get(param, {})
pinfo_old.update(pinfo)
info_params_full[param] = pinfo_old
likes += list(add.get(_likelihood, []))
priors += list(add.get(_prior, []))
# Add the prior and the likelihood as derived parameters
info_params_full[_minuslogprior] = {_p_label: r"-\log\pi"}
for prior in priors:
info_params_full[_minuslogprior + _separator + prior] = {
_p_label: r"-\log\pi_\mathrm{" + prior.replace("_", r"\ ") + r"}"}
info_params_full[_chi2] = {_p_label: r"\chi^2"}
for like in likes:
info_params_full[_chi2 + _separator + like] = {
_p_label: r"\chi^2_\mathrm{" + like.replace("_", r"\ ") + r"}"}
return info_params_full
# noinspection PyUnboundLocalVariable
def get_range(param_info):
# Sampled
if is_sampled_param(param_info):
info_lims = dict((tag, param_info[_prior].get(tag)) for tag in ["min", "max", "loc", "scale"])
if info_lims["min"] is not None or info_lims["max"] is not None:
lims = [param_info[_prior].get("min"), param_info[_prior].get("max")]
elif info_lims["loc"] is not None or info_lims["scale"] is not None:
dist = param_info[_prior].pop(_p_dist, "uniform")
pdf_dist = getattr(import_module("scipy.stats", dist), dist)
lims = pdf_dist.interval(1, **param_info[_prior])
# Derived
elif is_derived_param(param_info):
lims = (lambda i: [i.get("min", -np.inf), i.get("max", np.inf)])(param_info or {})
# Fixed
else:
value = fixed_value(param_info)
try:
value = float(value)
except ValueError:
# e.g. lambda function values
lims = (lambda i: [i.get("min", -np.inf), i.get("max", np.inf)])(param_info or {})
else:
lims = (value, value)
return lims[0] if lims[0] != -np.inf else None, lims[1] if lims[1] != np.inf else None
[docs]def fixed_value(info_param):
"""
Returns True if the parameter has been fixed to a value or through a function.
"""
return expand_info_param(info_param).get(_p_value, None)
[docs]def is_fixed_param(info_param):
"""
Returns True if the parameter has been fixed to a value or through a function.
"""
return fixed_value(info_param) is not None
def is_parameter_with_range(info_param):
value = fixed_value(info_param)
return value is None or isinstance(value, Number) or is_derived_param(info_param)
[docs]def is_sampled_param(info_param):
"""
Returns True if the parameter has a prior.
"""
return _prior in expand_info_param(info_param)
[docs]def is_derived_param(info_param):
"""
Returns True if the parameter is saved as a derived one.
"""
return expand_info_param(info_param).get(_p_derived, False)
[docs]def expand_info_param(info_param):
"""
Expands the info of a parameter, from the user friendly, shorter format
to a more unambiguous one.
"""
if not isinstance(info_param, Mapping):
if info_param is None:
info_param = {}
else:
info_param = {_p_value: info_param}
else:
info_param = deepcopy(info_param)
if all((f not in info_param) for f in [_prior, _p_value, _p_derived]):
info_param[_p_derived] = True
# Dynamical input parameters: save as derived by default
value = info_param.get(_p_value, None)
if isinstance(value, str) or callable(value):
info_param[_p_derived] = info_param.get(_p_derived, True)
return info_param
def get_sampler_type(filename_or_info, default_sampler_for_chain_type="mcmc"):
sampler = list(yaml_file_or_dict(filename_or_info).get(_sampler, [default_sampler_for_chain_type]))[0]
return {"mcmc": "mcmc", "polychord": "nested", "minimize": "minimize"}[sampler]
def get_sample_label(filename_or_info):
return yaml_file_or_dict(filename_or_info).get(_label, None)
def get_burn_removed(filename_or_info):
info = get_info_params(filename_or_info)
# if skip burn in *has already been done*
return info.get(_post, {}).get("skip", 0)