samreay / chainconsumer
Corner plots, LaTeX tables and plotting walks.
Home Page: https://samreay.github.io/ChainConsumer
License: MIT License
Hi Sam,
Really nice package!
I was hoping to plot walks with multiple chains, but it is not currently possible. I would think this would be easy enough -- the way I usually do it is to plot each chain with the same colour and a pretty low alpha. The resulting plot is "messy", but shows pretty clearly if there are any chains doing strange things compared to the rest.
What do you think?
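The overlay described above can be sketched with plain matplotlib, independent of ChainConsumer. This is only an illustration under assumptions: `plot_walks` is a hypothetical helper name, each chain is assumed to be an array of shape (steps, parameters), and the data are synthetic.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_walks(chains, names, color="k", alpha=0.15):
    """Overlay every chain's walk per parameter, same colour, low alpha.

    Hypothetical helper, not part of ChainConsumer.
    """
    n_params = len(names)
    fig, axes = plt.subplots(n_params, 1, sharex=True, squeeze=False,
                             figsize=(6, 2 * n_params))
    for i, name in enumerate(names):
        ax = axes[i, 0]
        for chain in chains:
            # one faint line per chain; an outlying chain stands out visually
            ax.plot(chain[:, i], color=color, alpha=alpha, lw=0.8)
        ax.set_ylabel(name)
    axes[-1, 0].set_xlabel("step")
    return fig, axes[:, 0]


# Synthetic example: eight chains, one of them offset in the first parameter
rng = np.random.default_rng(0)
chains = [rng.normal(size=(500, 2)) for _ in range(8)]
chains[3][:, 0] += 3.0
fig, axes = plot_walks(chains, ["x", "y"])
```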
Hi,
is there a way to control the tick positions when using plot_summary?
I've tried to increase "max_ticks" in "configure", but it does not seem to change anything.
Thanks a lot
Matteo
I find that changing the tick_font_size option does not actually change the font size on the figure. The label_font_size option is responding correctly.
`import numpy as np
from numpy.random import normal, multivariate_normal
from chainconsumer import ChainConsumer
np.random.seed(0)
data = multivariate_normal([0, 1, 2], np.eye(3) + 0.2, size=100000)
c = ChainConsumer()
c.add_chain(data, parameters=["$x$", "$y^2$", r"$\Omega_\beta$"])
c.configure(diagonal_tick_labels=False, tick_font_size=2, label_font_size=16, max_ticks=4)
fig = c.plot(figsize="column")`
and
`import numpy as np
from numpy.random import normal, multivariate_normal
from chainconsumer import ChainConsumer
np.random.seed(0)
data = multivariate_normal([0, 1, 2], np.eye(3) + 0.2, size=100000)
c = ChainConsumer()
c.add_chain(data, parameters=["$x$", "$y^2$", r"$\Omega_\beta$"])
c.configure(diagonal_tick_labels=False, tick_font_size=10, label_font_size=16, max_ticks=4)
fig = c.plot(figsize="column")`
produce the same figure.
Hi Sam,
It would be nice if there was an option to scale up the font size when the figsize is scaled. For papers, it's nice to have the same font size regardless of the plot size, but for Jupyter notebooks etc., it's nice to scale the font with the figure.
Unless I am missing something, to draw the contours you use (from plotter.py)
levels = 1.0 - np.exp(-0.5 * self.parent.config["sigmas"] ** 2)
This assumes that our model parameters are always Gaussian distributed.
To draw contours more generally, you need to find the level that contains the volume you are looking for. This can be done by gridding the height of the map from 0 to zmax and then summing the volume of each bin with height greater than the current grid value (zthresh). When zthresh is 0, you expect to recover 1 (if normalized to 1). When zthresh is zmax, you expect to recover 0. If we calculate the area over a grid of zthresh, we can then interpolate to find zthresh that contains the volume that we are looking for.
I have attached an example file that does just this.
draw_contours.txt
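The attachment is not reproduced here. As a rough sketch of the procedure described above (grid the height threshold, sum the volume above each threshold, interpolate), something like the following could work. `level_for_mass` is a hypothetical name, the input is assumed to be a 2D histogram of samples, and the returned level is in units of the normalised histogram.

```python
import numpy as np


def level_for_mass(hist, mass, n_grid=200):
    """Height threshold whose super-level set encloses `mass` of the total volume.

    Hypothetical helper. The threshold refers to hist / hist.sum().
    """
    h = np.asarray(hist, dtype=float)
    h = h / h.sum()                       # normalise the total volume to 1
    zs = np.linspace(0.0, h.max(), n_grid)
    # volume of all bins strictly higher than each threshold: 1 at z=0, 0 at zmax
    enclosed = np.array([h[h > z].sum() for z in zs])
    # np.interp expects increasing x values, so reverse both arrays
    return np.interp(mass, enclosed[::-1], zs[::-1])


# Example: thresholds enclosing roughly 39.3% and 86.5% of the samples
rng = np.random.default_rng(0)
xy = rng.normal(size=(200000, 2))
hist, _, _ = np.histogram2d(xy[:, 0], xy[:, 1], bins=50,
                            range=[[-4, 4], [-4, 4]])
levels = [level_for_mass(hist, m) for m in (0.393, 0.865)]
```

The resulting thresholds could then be passed to a contouring routine that is given the same normalised histogram.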
ChainConsumer.plot somehow interfaces with matplotlib in a way that requires dvipng. I did not have it installed, so the examples failed. Installing dvipng fixed this:
sudo apt-get install dvipng
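A likely cause is matplotlib's LaTeX text rendering, which needs a LaTeX installation and dvipng when drawing to raster output. A small check like the one below can be run before plotting. This is only a sketch: `missing_usetex_tools` is a hypothetical helper, and the assumption that ChainConsumer switched on LaTeX rendering in this case is not confirmed by the report.

```python
import shutil


def missing_usetex_tools(tools=("latex", "dvipng")):
    """Return the external programs from `tools` that are not on PATH.

    Hypothetical helper; the default list is an assumption about what
    matplotlib's LaTeX rendering needs for PNG output.
    """
    return [tool for tool in tools if shutil.which(tool) is None]


missing = missing_usetex_tools()
if missing:
    print("LaTeX rendering will likely fail; missing: " + ", ".join(missing))
```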
Right now, the top google results for "chain consumer" with a space is the github.io page, and this repo isn't on the first page of google results. This is fine, but the installation instructions aren't on the github.io page. Either these should be added or a link back to the repo's installation instructions should be added to facilitate installation for someone making that search.
Hi, I am new to chainconsumer. I have installed it successfully on my system. I get the following error while testing any example given on the chainconsumer website:
TypeError: __init__() got an unexpected keyword argument 'gridspec_kw'
How to rectify this error?
Thanks.
First of all, thanks for the excellent package! I just started using ChainConsumer, so I am probably doing something unusual/naive.
c = ChainConsumer()
for f, nburn, lbl in zip(chainfiles, burn_in_steps, chainlabel):
    chain = load_chain_data(f)
    chain_param_dist = chain[nburn:, 0:nparams]
    lnlike = chain[nburn:, nparams]
    # since we have flat priors, posterior \propto likelihood
    lnpost = lnlike
    c = c.add_chain(chain_param_dist, parameters=param_names,
                    posterior=lnpost, walkers=nwalkers,
                    name=lbl)
for i, param_idx in enumerate(plot_param_combos):
    outfile = "{0}{1}.{2}".format(filebase, i, outputfiletype)
    params = [param_names[idx] for idx in param_idx]
    print("Generating plot for params names = {0}".format(params))
    c.configure(kde=[False] * nchains,
                shade=[True] * nchains,
                sigmas=[0, 1, 2, 3],
                bar_shade=[True] * nchains,
                plot_hists=[False] * nchains,
                diagonal_tick_labels=False)
    c.plotter.plot(filename=outfile, figsize="column",
                   parameters=params)
    print("Generating plot for params names = {0}...done".format(params))
This results in the following:
Generating plot for params names = ['$\\log M_{min}$', '$\\sigma_{\\log M}$']
WARNING:chainconsumer.analysis:Parameter $\log M_{min}$ is not constrained
WARNING:chainconsumer.analysis:Parameter $\sigma_{\log M}$ is not constrained
Generating plot for params names = ['$\\log M_{min}$', '$\\sigma_{\\log M}$']...done
Generating plot for params names = ['$\\log M_1$', '$\\alpha$']
error occurred for key = max_ticks val = 5
All keys = ['max_ticks', 'linestyles', 'kde', 'tick_font_size', 'colors', 'plot_hists', 'legend_color_text', 'plot_color_params', 'cloud', 'statistics', 'legend_kwargs', 'bar_shade', 'label_font_size', 'watermark_text_kwargs', 'usetex', 'contour_labels', 'legend_location', 'legend_artists', 'shade_alpha', 'num_cloud', 'spacing', 'linewidths', 'serif', 'cmaps', 'sigmas', 'sigma2d', 'shade', 'smooth', 'flip', 'summary', 'contour_label_font_size', 'color_params', 'diagonal_tick_labels', 'shade_gradient', 'bins']
Traceback (most recent call last):
File "xx.py", line 140, in <module>
main()
File "xx.py", line 134, in main
diagonal_tick_labels=False)
File "/Users/msinha/anaconda/lib/python2.7/site-packages/chainconsumer/chain.py", line 597, in configure
assert len(val) >= num_chains, \
TypeError: object of type 'int' has no len()
I added the following lines to chain.py to help debug (which produces the additional info in the preceding block):
for key in self.config.keys():
    val = self.config[key]
    try:
        assert len(val) >= num_chains, \
            "Only have %d options for %s, but have %d chains!" % (len(val), key, num_chains)
    except:
        print("error occurred for key = {0} val = {1}".format(
            key, val))
        print("All keys = {0}".format(self.config.keys()))
        raise
My read is that max_ticks is already defined from the first configure call and therefore does not pass the check. It might be good to add a boolean flag per key option that records whether or not the comparison with nchains should run.
Thanks again!
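One way to express the suggestion above is to run the length check only for options that are known to be per-chain, leaving scalar options such as max_ticks alone. The sketch below is not ChainConsumer's code: `check_per_chain_options` and the `per_chain_keys` argument are hypothetical names used for illustration.

```python
def check_per_chain_options(config, num_chains, per_chain_keys):
    """Validate only the options that need one entry per chain.

    Hypothetical helper: `per_chain_keys` plays the role of the per-key
    flag suggested above, so scalar options are never passed to len().
    """
    for key in per_chain_keys:
        val = config[key]
        if not isinstance(val, (list, tuple)):
            raise TypeError("%s must be a list with one entry per chain" % key)
        if len(val) < num_chains:
            raise ValueError("Only have %d options for %s, but have %d chains!"
                             % (len(val), key, num_chains))


# max_ticks is a plain int and is skipped, so a second configure call is safe
config = {"max_ticks": 5, "kde": [False, False], "shade": [True, True]}
check_per_chain_options(config, 2, per_chain_keys=("kde", "shade"))
```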
We would like to connect to the DES multiprobe pipeline system.
I've been playing around with ChainConsumer to examine MCMC data. It's a great utility for visualizing and processing chains, but I have noticed two issues in the Gelman-Rubin diagnostic function.
The first is that when calculating the within-sequence variance (W) you should take np.std()**2 to get the variance rather than the standard deviation. The variance is what Gelman & Rubin (1992: https://projecteuclid.org/euclid.ss/1177011136) use to calculate W. This distinction is important because unless your variance is very close to 1, you will either overestimate or underestimate W, which will cause you to either oversample or undersample your PDF (because the ratio B/W that is most important in determining convergence will be systematically overestimated or underestimated). This gets really bad for chains which have a small sample variance (of order 0.001); such chains can appear to converge in as few as 110 MCMC steps, when they actually need ~10000.
The second is that when calculating R, you should not be taking the square root of Var/W. The Gelman & Rubin (1992) paper actually defines sqrt(R) = sqrt(V/W df/(df-2)), where df is some measure of the degrees of freedom (the way it is calculated is in their paper). There are two implications of this: 1) ChainConsumer is actually calculating sqrt(R), which is always less than R, so it will always make chains appear more converged than they actually are. 2) The value of R can decrease below 1 as long as V/W is sufficiently small (because it actually approaches (df-2)/df). This is bad, because by definition, R cannot be less than 1, although the effect on determining convergence is negligible.
The fix should be pretty easy. Just change line 1322 to be:
chain_std = np.array([np.std(c, axis=0)**2 for c in chains])
and get the formula for df from the Gelman & Rubin paper, and adjust line 1326 to be
R = var / w * df / (df-2.)
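For reference, the quantities discussed above can be written out for a set of equal-length chains. This is a minimal sketch rather than ChainConsumer's implementation: `gelman_rubin` is a hypothetical function name, and the df / (df - 2) correction is deliberately left out because the df formula has to be taken from the paper.

```python
import numpy as np


def gelman_rubin(chains):
    """Return V / W per parameter for chains of shape (m, n_steps, n_params).

    Hypothetical helper. The df / (df - 2) factor from Gelman & Rubin (1992)
    is omitted, so this is the uncorrected ratio.
    """
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape[0], chains.shape[1]
    chain_means = chains.mean(axis=1)            # shape (m, n_params)
    chain_vars = chains.var(axis=1, ddof=1)      # variances, not standard deviations
    w = chain_vars.mean(axis=0)                  # within-chain variance W
    b_over_n = chain_means.var(axis=0, ddof=1)   # between-chain variance B / n
    v = (n - 1.0) / n * w + (1.0 + 1.0 / m) * b_over_n
    return v / w


# Four chains from the same distribution should give values close to 1
rng = np.random.default_rng(1)
good = rng.normal(size=(4, 2000, 2))
r_good = gelman_rubin(good)
```

Shifting one chain away from the others (for example `good[0] + 5`) pushes the ratio well above 1, which is the behaviour a convergence check relies on.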
I cannot get the output from getdist and chainconsumer to match, even using the same cosmomc chains. I attach figures showing what I mean.
testkids_2D.pdf
Cosmomc chains have a weight (multiplicity) and a likelihood. I include both of these in my 'add_chain' statement
fig = ChainConsumer().add_chain(final,weights=final_weights,posterior=final_post,parameters=paramnames)
Any suggestion what I might be doing wrong?
Hi again,
since you're so fast, I thought I'd ask another favour. For what I'm doing, I'm using CC to do some convergence tests on a series of chains, and if they fail, I just want to skip that chain. The best way to do this, I think, is to have a remove_chain() method (or even just a remove_last_chain method) which undoes the add_chain() method. I've hacked one myself, but probably missing something important...
def remove_last_chain(cc):
    del cc._chains[-1]
    del cc._names[-1]
    del cc._walkers[-1]
    del cc._posteriors[-1]
    del cc._weights[-1]
    del cc._parameters[-1]
    del cc._num_data[-1]
    del cc._num_free[-1]
Hi,
I've updated ChainConsumer to the latest version (0.23.1) and I'm having a problem with Plotter.plot() (with version 0.21.7 I didn't experience this problem).
For instance, if I try to run the example "plot_introduction.ipynb" I get the following error:
`---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/usr/local/lib/python2.7/dist-packages/IPython/core/formatters.pyc in __call__(self, obj)
335 pass
336 else:
--> 337 return printer(obj)
338 # Finally look for special method names
339 method = _safe_get_formatter_method(obj, self.print_method)
/usr/local/lib/python2.7/dist-packages/IPython/core/pylabtools.pyc in <lambda>(fig)
205
206 if 'png' in formats:
--> 207 png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
208 if 'retina' in formats or 'png2x' in formats:
209 png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))
/usr/local/lib/python2.7/dist-packages/IPython/core/pylabtools.pyc in print_figure(fig, fmt, bbox_inches, **kwargs)
115
116 bytes_io = BytesIO()
--> 117 fig.canvas.print_figure(bytes_io, **kw)
118 data = bytes_io.getvalue()
119 if fmt == 'svg':
/usr/local/lib/python2.7/dist-packages/matplotlib/backend_bases.pyc in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)
2257 orientation=orientation,
2258 bbox_inches_restore=_bbox_inches_restore,
-> 2259 **kwargs)
2260 finally:
2261 if bbox_inches and restore_bbox:
/usr/local/lib/python2.7/dist-packages/matplotlib/backends/backend_agg.pyc in print_png(self, filename_or_obj, *args, **kwargs)
505
506 def print_png(self, filename_or_obj, *args, **kwargs):
--> 507 FigureCanvasAgg.draw(self)
508 renderer = self.get_renderer()
509 original_dpi = renderer.dpi
/usr/local/lib/python2.7/dist-packages/matplotlib/backends/backend_agg.pyc in draw(self)
428 if toolbar:
429 toolbar.set_cursor(cursors.WAIT)
--> 430 self.figure.draw(self.renderer)
431 finally:
432 if toolbar:
/usr/local/lib/python2.7/dist-packages/matplotlib/artist.pyc in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python2.7/dist-packages/matplotlib/figure.pyc in draw(self, renderer)
1293
1294 mimage._draw_list_compositing_images(
-> 1295 renderer, self, artists, self.suppressComposite)
1296
1297 renderer.close_group('figure')
/usr/local/lib/python2.7/dist-packages/matplotlib/image.pyc in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together
/usr/local/lib/python2.7/dist-packages/matplotlib/artist.pyc in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python2.7/dist-packages/matplotlib/axes/_base.pyc in draw(self, renderer, inframe)
2397 renderer.stop_rasterizing()
2398
-> 2399 mimage._draw_list_compositing_images(renderer, self, artists)
2400
2401 renderer.close_group('axes')
/usr/local/lib/python2.7/dist-packages/matplotlib/image.pyc in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together
/usr/local/lib/python2.7/dist-packages/matplotlib/artist.pyc in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python2.7/dist-packages/matplotlib/axis.pyc in draw(self, renderer, *args, **kwargs)
1145 self._update_label_position(ticklabelBoxes, ticklabelBoxes2)
1146
-> 1147 self.label.draw(renderer)
1148
1149 self._update_offset_text_position(ticklabelBoxes, ticklabelBoxes2)
/usr/local/lib/python2.7/dist-packages/matplotlib/artist.pyc in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python2.7/dist-packages/matplotlib/text.pyc in draw(self, renderer)
761 posy = float(textobj.convert_yunits(textobj._y))
762 if not np.isfinite(posx) or not np.isfinite(posy):
--> 763 raise ValueError("posx and posy should be finite values")
764 posx, posy = trans.transform_point((posx, posy))
765 canvasw, canvash = renderer.get_canvas_width_height()
ValueError: posx and posy should be finite values
`
The new class "plotter.plot_summary" works very nicely instead :)
Do you have any suggestion on how to fix this error?
Thanks a lot.
Matteo
Paper by Andrew Liddle (2009) comparing AIC, BIC and DIC:
arXiv:0903.4210
Also
Hi Admin
Can I limit the x-axis range of the triangle subplots for individual parameters? E.g. I have a truth value which does not appear in the subplot, and it would help if I could set the x-axis limits for that particular parameter. I haven't found such an option in the documentation; if I missed it, could you please point me to it?
Thanks
vinu
Hey Sam, just started using this spiffing package today after hearing about how awesome it is from Caitlin for ages.
I'm trying to plot multiple contours with different colours. However, for some reason legend_kwargs is not recognised. I've tried copy-pasting the code directly from the documentation but it still generates the error.
# G1 and G2 are arrays of galaxy information.
c = ChainConsumer()
x_label = r"$\mathrm{M}_\mathrm{H} \: [\mathrm{M}_\odot h^{-1}]$"
y_label = r"$\mathrm{M}_* \: [\mathrm{M}_\odot h^{-1}]$"
w1 = np.where((G1.Mvir > 0.0) & (G1.StellarMass > 0.1))[0]
w2 = np.where((G2.Mvir > 0.0) & (G2.StellarMass > 0.1))[0]
halomass1 = np.log10(G1.Mvir[w1] * 1.0e10 / self.Hubble_h)
stellarmass1 = np.log10(G1.StellarMass[w1] * 1.0e10 / self.Hubble_h)
halomass2 = np.log10(G2.Mvir[w2] * 1.0e10 / self.Hubble_h)
stellarmass2 = np.log10(G2.StellarMass[w2] * 1.0e10 / self.Hubble_h)
c.add_chain([halomass1, stellarmass1], parameters=[x_label, y_label], name=r"$\eta_\mathrm{SN} = 5 \: \mathrm{x} \: 10^{-3}$")
c.add_chain([halomass2, stellarmass2], parameters=[x_label, y_label], name=r"$\eta_\mathrm{SN} = 5 \: \mathrm{x} \: 10^{-4}$")
c.configure(linestyles=["-", "--"], sigmas=[0, 1, 2, 3],
legend_kwargs={"loc": "upper left", "fontsize": 10},
legend_color_text=False, legend_location=(0, 0))
fig = c.plotter.plot(filename = OutputDir + 'XX.StellarMass_HaloMass' + OutputFormat, figsize = 1.0)
Yields
Traceback (most recent call last):
File "allresults.py", line 1443, in <module>
res.StellarMass_HaloMass(G1, G2, G4)
File "allresults.py", line 1320, in StellarMass_HaloMass
legend_color_text=False, legend_location=(0, 0))
TypeError: configure() got an unexpected keyword argument 'legend_kwargs'
I'm also getting the error for legend_color_text and legend_artists (haven't tried any other esoteric options).
Thanks!
I was hoping to make a figure that had two subfigures, and in each subfigure it would have an output from ChainConsumer. However, matplotlib subfigures cannot be created from matplotlib figure objects. Thus, I am unable to do this with ChainConsumer, since it creates and returns a figure object to the user.
A straightforward (I think) fix would be to optionally take in a figure or axes from the user, clear it, then do the plotting, then return the figure to the user.
The call signature in this case would look like
import numpy as np
import matplotlib.pyplot as plt
from chainconsumer import ChainConsumer
fig, ax = plt.subplots(1)
mean = [0.0, 4.0]
data = np.random.multivariate_normal(mean, [[1.0, 0.7], [0.7, 1.5]], size=100000)
c = ChainConsumer()
c.add_chain(data, parameters=["$x_1$", "$x_2$"])
# `axis` is the proposed new argument, not something plot() accepts today
c.plotter.plot(axis=ax, filename="example.png", figsize="column", truth=mean)
This way, ax can be a subplot, and we could use ChainConsumer figures inside a pyplot.subplots() figure.
For some reason, the joint plots stopped shading the contours:
import numpy as np
from numpy.random import normal, multivariate_normal
from chainconsumer import ChainConsumer
np.random.seed(0)
cov = normal(size=(3, 3))
cov2 = normal(size=(4, 4))
data = multivariate_normal(normal(size=3), 0.5 * (cov + cov.T), size=100000)
data2 = multivariate_normal(normal(size=4), 0.5 * (cov2 + cov2.T), size=100000)
c = ChainConsumer()
c.add_chain(data, parameters=["$x$", "$y$", r"$\alpha$"])
c.add_chain(data2, parameters=["$x$", "$y$", r"$\alpha$", r"$\gamma$"])
fig = c.plot()
Is it something in the code?
Can chain consumer output the correlation/covariance matrix between the parameters? Or flag the most correlated combination?
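Until such an option exists, both quantities can be computed directly from the raw samples with numpy. The sketch below assumes an unweighted chain of shape (steps, parameters); `most_correlated` is a hypothetical helper name and is not part of ChainConsumer.

```python
import numpy as np


def most_correlated(chain, names):
    """Return the correlation matrix and the most strongly correlated pair.

    Hypothetical helper for an unweighted chain of shape (steps, parameters).
    """
    corr = np.corrcoef(np.asarray(chain, dtype=float), rowvar=False)
    # zero the diagonal so a parameter is never paired with itself
    off_diag = np.abs(corr - np.diag(np.diag(corr)))
    i, j = np.unravel_index(np.argmax(off_diag), off_diag.shape)
    return corr, (names[i], names[j], corr[i, j])


# Synthetic example: y closely follows x, z is independent
rng = np.random.default_rng(2)
x = rng.normal(size=5000)
chain = np.column_stack([x, x + 0.1 * rng.normal(size=5000),
                         rng.normal(size=5000)])
corr, pair = most_correlated(chain, ["x", "y", "z"])
cov = np.cov(chain, rowvar=False)  # covariance matrix of the same samples
```

For weighted chains, `np.cov` accepts an `aweights` argument, although `np.corrcoef` does not.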
The figure found here has a label chopped off at the top. This is purely a cosmetic issue with the documentation.
Hi, I was trying to put labels on each histogram, where the label is something like x_\mathrm{y}. Each time I get an error saying that LaTeX cannot process x_\\mathrm_{y}. Apparently it adds an extra "\" there. Could you kindly fix this?
Any plans on adding simple covariance ellipses to the plots?
I used chainconsumer to generate the figures for a recent publication (https://arxiv.org/abs/1607.01884), which was submitted to JCAP. The referee response included the following comments:
- Figure-4; Axes labels are very small and the resolution of the figure needs to be improved.
And these comments were repeated for figures 5, 6 and 7. Though the figures look great on the screen, would it be possible to have a different set of defaults for printing quality figures?
When trying to make a corner plot from a chain that is not saved in a numpy array, it crashes.
trans_chain = [list(transform_chain(link)) for link in chain[:100]]
c = ChainConsumer()
c.add_chain(trans_chain, parameters=['p1', 'p2'])
fig = c.plotter.plot(figsize="column")
The error message I get is:
File "/pythonpath/lib/python2.7/site-packages/chainconsumer/chain.py", line 51, in validate_chain
(self.name, len(self.parameters), self.chain.shape[1])
AssertionError: Chain Chain 0 has 2 parameters but data has 100 columns
To fix it, I simply recast trans_chain as a numpy array. Opening issue because I felt that this would be easy to check for and then correct for when adding a chain.
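The kind of check suggested above might look like the following. It is only a sketch: `as_chain_array` is a hypothetical helper, and it assumes the chain is a sequence of samples with one value per parameter in each.

```python
import numpy as np


def as_chain_array(chain, n_params):
    """Coerce a list-of-lists chain to a 2D float array and check its width.

    Hypothetical helper, not ChainConsumer's own validation.
    """
    arr = np.asarray(chain, dtype=float)
    if arr.ndim == 1:
        arr = arr[:, None]              # a single parameter becomes one column
    if arr.ndim != 2 or arr.shape[1] != n_params:
        raise ValueError("Expected shape (steps, %d) but got %s"
                         % (n_params, arr.shape))
    return arr


# A list of 100 two-element samples becomes an array of shape (100, 2)
trans_chain = [[float(i), 2.0 * i] for i in range(100)]
arr = as_chain_array(trans_chain, n_params=2)
```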
Hi,
I'm trying to smooth my results but I get the following error:
File "plot_kde_extents.py", line 29, in
fig = c.plotter.plot(extents=[(-2, 4), (0, 9)])
File "/home/simone/anaconda2/lib/python2.7/site-packages/chainconsumer/plotter.py", line 128, in plot
fit_values = self.parent.analysis.get_summary(squeeze=False, parameters=parameters)
File "/home/simone/anaconda2/lib/python2.7/site-packages/chainconsumer/analysis.py", line 137, in get_summary
summary = self._get_parameter_summary(chain[:, i], weights, p, ind, grid=g)
File "/home/simone/anaconda2/lib/python2.7/site-packages/chainconsumer/analysis.py", line 149, in _get_parameter_summary
return method(data, weights, parameter, chain_index, desired_area=desired_area, **kwargs)
File "/home/simone/anaconda2/lib/python2.7/site-packages/chainconsumer/analysis.py", line 392, in get_parameter_summary_max
xs, ys, cs = self._get_smoothed_histogram(data, weights, chain_index, grid)
File "/home/simone/anaconda2/lib/python2.7/site-packages/chainconsumer/analysis.py", line 275, in _get_smoothed_histogram
area = simps(ys, x=kde_xs)
File "/home/simone/anaconda2/lib/python2.7/site-packages/scipy/integrate/quadrature.py", line 392, in simps
x = x.reshape(tuple(shapex))
TypeError: 'numpy.float64' object cannot be interpreted as an index
This happens both when I run my script and when I run the example https://samreay.github.io/ChainConsumer/examples/customisations/plot_kde_extents.html#sphx-glr-examples-customisations-plot-kde-extents-py
Thank you for your attention,
Simone
PSIS-LOO
Pareto smoothed importance sampling. Aki Vehtari, Andrew Gelman and Jonah Gabry (2016).
https://arxiv.org/abs/1507.04544
see also https://github.com/avehtari/PSIS
Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC. Aki Vehtari, Andrew Gelman, Jonah Gabry 29 June 2016
http://www.stat.columbia.edu/~gelman/research/unpublished/loo_stan.pdf
WAIC
http://watanabe-www.math.dis.titech.ac.jp/users/swatanab/waicwbic_e.html
see also http://www.stat.columbia.edu/~gelman/research/published/waic_understand3.pdf
Both methods use the MCMC chain for evaluation.
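As a concrete illustration of the WAIC part, the sketch below evaluates it from a matrix of pointwise log-likelihoods, one row per posterior sample and one column per data point, using the variance-based penalty from the references above. `waic` is a hypothetical function name, and having such a pointwise matrix stored alongside the chain is an assumption.

```python
import numpy as np


def waic(log_lik):
    """WAIC on the deviance scale from pointwise log-likelihoods.

    Hypothetical helper. `log_lik` has shape (n_samples, n_data).
    """
    ll = np.asarray(log_lik, dtype=float)
    # log pointwise predictive density, using a max shift for numerical stability
    ll_max = ll.max(axis=0)
    lppd_i = ll_max + np.log(np.mean(np.exp(ll - ll_max), axis=0))
    # effective number of parameters: per-point variance over the samples
    p_waic_i = ll.var(axis=0, ddof=1)
    return -2.0 * (lppd_i.sum() - p_waic_i.sum())


# With identical log-likelihoods in every sample the penalty term vanishes
example = waic(np.full((10, 3), -1.0))
```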
I need this package in conda for another package I'm working on. Is it possible to have this done here, or would you like for me to get it pushed to conda?
Hi Sam, any ideas why I'd get the following matplotlib error all of a sudden?:
/home/steven/miniconda3/envs/HIHOD/lib/python2.7/site-packages/matplotlib/text.pyc in draw(self, renderer)
761 posy = float(textobj.convert_yunits(textobj._y))
762 if not np.isfinite(posx) or not np.isfinite(posy):
--> 763 raise ValueError("posx and posy should be finite values")
764 posx, posy = trans.transform_point((posx, posy))
765 canvasw, canvash = renderer.get_canvas_width_height()
ValueError: posx and posy should be finite values
I'm running the same code as I always have. May have upgraded versions of either matplotlib or chainconsumer since last time, I'm not sure. Current version of MPL: 2.1.0 and ChainConsumer: 0.23.1