Introduction

plot_trace is one of the most common plots for assessing the convergence of MCMC runs and, consequently, one of the most used ArviZ functions. It has many parameters that allow creating highly customizable plots, but they may not be straightforward to use. There is no single culprit behind the convoluted arguments and their formats: ArviZ has to integrate with several libraries such as xarray and matplotlib, which provide amazing features and customization power, and we'd like ArviZ users to have access to all of them. However, we also aim to keep ArviZ usage simple and with sensible defaults; plot_xyz(idata) should generate acceptable results in most situations.

This post aims to be an extension to the API section on plot_trace, focusing mostly on arguments where examples may be lacking and arguments that appear often in questions posted to ArviZ issues.

Therefore, the most common arguments such as var_names will not be covered, and arguments that I do not remember appearing in issues or generating confusion will only get a few examples, without an in-depth description.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# the HTML repr does not render correctly in the blog;
# comment out the line below if running in Jupyter
xr.set_options(display_style="text")

rng = np.random.default_rng()
az.style.use("arviz-darkgrid")
idata_centered = az.load_arviz_data("centered_eight")
idata = az.load_arviz_data("rugby")

The kind argument

az.plot_trace generates two columns. The left one calls plot_dist to plot the KDE or histogram of the data, and the right column can contain either the trace itself (which gives the plot its name) or a rank plot, for which two visualizations are available. Rank plots are an alternative to trace plots; see https://arxiv.org/abs/1903.08008 for more details.

fig, axes = plt.subplots(3,2, figsize=(12,6))
for i, kind in enumerate(("trace", "rank_bars", "rank_vlines")):
    az.plot_trace(idata, var_names="home", kind=kind, ax=axes[i,:]);
fig.tight_layout()
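To make the idea behind rank plots more concrete, here is a minimal NumPy sketch of the underlying computation: every draw is ranked within the pooled sample of all chains, and the ranks are then histogrammed separately for each chain. This is an illustration on synthetic data, not ArviZ's implementation; the names n_bins and rank_counts and the number of bins are choices made only for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_chains, n_draws, n_bins = 4, 500, 20

# synthetic, well-mixed "chains" (assumption: stand-in for real MCMC output)
draws = rng.normal(size=(n_chains, n_draws))

# rank every draw within the pooled sample (0 = smallest); continuous
# draws have no ties, so a double argsort is enough
ranks = draws.ravel().argsort().argsort().reshape(draws.shape)

# one histogram of ranks per chain, all sharing the same bin edges
bin_edges = np.linspace(0, draws.size, n_bins + 1)
rank_counts = np.stack(
    [np.histogram(chain_ranks, bins=bin_edges)[0] for chain_ranks in ranks]
)
print(rank_counts)
```

With well-mixed chains every row is roughly flat, at about n_draws / n_bins = 25 counts per bin. A chain stuck in one region of the posterior would instead pile its ranks into a few bins, which is what the rank plots make visible.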

The divergences argument

If present, divergences are indicated as a black rugplot in both columns of the trace plot. By default they are placed at the bottom of the plot, but they can be placed at the top or hidden.

az.plot_trace(idata_centered, var_names="tau");
az.plot_trace(idata_centered, var_names="tau", divergences=None);
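For completeness, the sketch below first counts the divergences stored in the sample_stats group and then places the rug at the top via divergences="top". It assumes the example dataset stores them in a variable called diverging, as InferenceData objects created by ArviZ converters normally do; n_divergences is just a name used here.

```python
import arviz as az

idata_centered = az.load_arviz_data("centered_eight")

# number of divergent transitions, summed over all chains and draws
# (assumes sample_stats contains a boolean "diverging" variable)
n_divergences = int(idata_centered.sample_stats["diverging"].sum())
print(f"{n_divergences} divergent transitions")

# same variable as above, with the divergences rug moved to the top
axes = az.plot_trace(idata_centered, var_names="tau", divergences="top")
```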

The rug argument

rug adds a rug plot with the posterior samples at the bottom of the distribution plot; the trace plot column is unchanged.

ax = az.plot_trace(idata, var_names="home", rug=True, rug_kwargs={"alpha": .4})

But what about having both rug and divergences at the same time? Fear not, ArviZ automatically modifies the default for divergences from bottom to top to prevent rug and divergences from overlapping:

az.plot_trace(idata_centered, var_names="mu", rug=True);

The lines argument

The description of lines in plot_trace's docstring is the following:

lines : list of tuple of (str, dict, array_like), optional

List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as vertical lines on the density and horizontal lines on the trace.

It is possible that the first thought after reading this line is similar to "What is with this weird format?" Well, this format is actually the standard way ArviZ iterates over xarray.Dataset objects, because it contains all the info about the variable and the selected coordinates as well as the values themselves. The main helper function that handles this is arviz.plots.plot_utils.xarray_var_iter.

This section will be a little different from the other ones, and will focus on boosting plot_trace capabilities with internal ArviZ functions. You may want to skip the section altogether or go straight to the end.

Let's see what xarray_var_iter does with a simple dataset. We will create a dataset with two variables: a will be a 2x3 matrix and b will be a scalar. In addition, the dimensions of a will be labeled.

ds = xr.Dataset({
    "a": (("pos", "direction"), rng.normal(size=(2,3))),
    "b": 12, 
    "pos": ["top", "bottom"],
    "direction": ["x", "y", "z"]
})
ds
<xarray.Dataset>
Dimensions:    (direction: 3, pos: 2)
Coordinates:
  * pos        (pos) <U6 'top' 'bottom'
  * direction  (direction) <U1 'x' 'y' 'z'
Data variables:
    a          (pos, direction) float64 -0.5306 0.8029 0.7965 ... 0.4623 -0.128
    b          int64 12
from arviz.plots.plot_utils import xarray_var_iter
for var_name, sel, values in xarray_var_iter(ds):
    print(var_name, sel, values)
a {'pos': 'top', 'direction': 'x'} -0.5306128314326483
a {'pos': 'top', 'direction': 'y'} 0.8029249611338745
a {'pos': 'top', 'direction': 'z'} 0.7965222104405889
a {'pos': 'bottom', 'direction': 'x'} -1.4255055469706215
a {'pos': 'bottom', 'direction': 'y'} 0.4622636712711883
a {'pos': 'bottom', 'direction': 'z'} -0.12804707435886095
b {} 12

xarray_var_iter has iterated over every single scalar value without losing track of where each value came from. We can also modify the behaviour to skip some dimensions (e.g. in ArviZ we generally iterate over data dimensions and skip the chain and draw dims).

for var_name, sel, values in xarray_var_iter(ds, skip_dims={"direction"}):
    print(var_name, sel, values)
a {'pos': 'top'} [-0.53061283  0.80292496  0.79652221]
a {'pos': 'bottom'} [-1.42550555  0.46226367 -0.12804707]
b {} 12
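The iteration itself is simple enough to sketch with plain xarray. The helper below, iter_selections, is a hypothetical and deliberately simplified stand-in written for illustration; it is not ArviZ's actual code and it assumes every dimension has labeled coordinates. It loops over all combinations of coordinate values along the dimensions that are not skipped and yields the same kind of (var_name, selection, values) tuples shown above.

```python
from itertools import product

import numpy as np
import xarray as xr


def iter_selections(ds, skip_dims=frozenset()):
    """Simplified illustration of the idea behind xarray_var_iter.

    Hypothetical helper, not part of ArviZ. Yields (var_name, selection,
    values) for every combination of coordinate values along the
    dimensions that are not in skip_dims.
    """
    for var_name, da in ds.data_vars.items():
        loop_dims = [dim for dim in da.dims if dim not in skip_dims]
        coord_values = [da[dim].values.tolist() for dim in loop_dims]
        # product() over zero iterables yields a single empty tuple, so
        # scalar variables produce one item with an empty selection
        for combo in product(*coord_values):
            selection = dict(zip(loop_dims, combo))
            yield var_name, selection, da.sel(selection).values


demo = xr.Dataset(
    {"a": (("pos", "direction"), np.arange(6.0).reshape(2, 3)), "b": 12},
    coords={"pos": ["top", "bottom"], "direction": ["x", "y", "z"]},
)
items = list(iter_selections(demo, skip_dims={"direction"}))
for var_name, selection, values in items:
    print(var_name, selection, values)
```

Keeping the selection as a plain dict is what makes the format convenient: the same dict can be passed straight back to .sel to recover the subset a value came from.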

Now that we know about xarray_var_iter and what it does, we can use it to generate a list in the required format directly from xarray objects. Let's say for example we were interested in plotting the mean as a line in the trace plot:

var_names = ["home", "atts"]
lines = list(xarray_var_iter(idata.posterior[var_names].mean(dim=("chain", "draw"))))
az.plot_trace(idata, var_names=var_names, lines=lines);
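When only a handful of reference values are needed, the list can also be written by hand following the docstring format. The sketch below uses the centered_eight example data, whose mu and tau parameters are scalars, so the selection is an empty dict; the line positions are arbitrary values picked for illustration, not estimates of anything.

```python
import arviz as az

idata_centered = az.load_arviz_data("centered_eight")

# one tuple per set of lines: (variable name, coordinate selection, positions)
# mu and tau have no extra dimensions, hence the empty selection;
# 4.0, 1.0 and 5.0 are arbitrary illustration values
lines = [
    ("mu", {}, [4.0]),
    ("tau", {}, [1.0, 5.0]),
]
axes = az.plot_trace(idata_centered, var_names=["mu", "tau"], lines=lines)
```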