Sharedsupport / 2014-11-23-pystan.sagews — Open in CoCalc
Examples for support purposes...


# module import
import pystan
import numpy as np
import pylab as py
import pandas as pd

## data simulation
x = np.arange(1, 100, 5)

# get number of observations
N = len(x)

# linear signal (intercept 2.5, slope 0.5) plus Gaussian noise with sd 10;
# the noise length comes from N so y always matches x if the grid changes
y = 2.5 + .5 * x + np.random.randn(N) * 10

# quick look at the simulated data: circle markers with no connecting line
py.plot(x, y, 'o')

# STAN model (this is the most important part)
# Every Stan program block (data, parameters, transformed parameters, model)
# must be closed with '}', and the triple-quoted Python string must be
# terminated; otherwise neither Python nor the Stan compiler can parse it.
# NOTE(review): '<-' assignment and 'real y[N]' arrays are the Stan 2.x syntax
# used by the pystan of this worksheet; current Stan releases require '=' and
# 'array[N] real' instead.
regress_code = """
data {
 int<lower = 0> N; // number of observations
 real y[N]; // response variable
 real x[N]; // predictor variable
}
parameters {
 real a; // intercept
 real b; // slope
 real<lower=0> sigma; // standard deviation
}
transformed parameters {
 real mu[N]; // fitted values

 for(i in 1:N)
  mu[i] <- a + b*x[i];
}
model {
 y ~ normal(mu, sigma);
}
"""

# make a dictionary containing all data to be passed to STAN;
# the keys must match the names declared in the Stan `data` block
regress_dat = {'x': x,
               'y': y,
               'N': N}

# Fit the model
fit = pystan.stan(model_code=regress_code, data=regress_dat,
                  iter=1000, chains=4)

# model summary (function-call form works on both Python 2 and Python 3)
print(fit)

# show a traceplot of ALL parameters. This is a bear if you have many

# Instead, show a traceplot for single parameter
# NOTE(review): the plotting calls these two comments describe are not present
# here; they appear to have been lost when the worksheet was exported.

##### PREDICTION ####

# make a dataframe of posterior draws (one row per draw, all chains pooled).
# extract(..., permuted=True) returns a dict keyed by parameter name, so index
# it to get the 1-D sample array. Columns are pinned to ['a', 'b'] because
# stanPred reads the intercept and slope by position.
params = pd.DataFrame({'a': fit.extract('a', permuted=True)['a'],
                       'b': fit.extract('b', permuted=True)['b']},
                      columns=['a', 'b'])

# next, make a prediction function. Making a function makes every step following this 10 times easier
def stanPred(p, x=None):
    """Return the regression line ``a + b * x`` for one set of parameters.

    Parameters
    ----------
    p : sequence or pandas.Series
        Parameter values whose first two entries (by position) are the
        intercept ``a`` and the slope ``b``; e.g. one row of ``params`` or
        ``params.median()``.
    x : array-like, optional
        Predictor values to evaluate at. Defaults to the module-level
        ``predX`` (looked up at call time), so the one-argument form still
        works with ``DataFrame.apply``.

    Returns
    -------
    pandas.Series
        A one-element Series whose ``'fitted'`` entry holds the array of
        predicted values.
    """
    if x is None:
        x = predX
    # convert to a plain array for positional access: ``Series[0]`` on a
    # label-indexed Series relies on pandas' deprecated positional fallback
    coef = np.asarray(p)
    fitted = coef[0] + coef[1] * np.asarray(x)
    return pd.Series({'fitted': fitted})

# make a prediction vector (the values of X for which you want to predict)
predX = np.arange(1, 100)

# get the median parameter estimates
medParam = params.median()
# predict
yhat = stanPred(medParam)

# get the predicted line for every posterior draw (one row of `params` each).
# This is convenient in pandas because a single column can hold one array
# per row.
chainPreds = params.apply(stanPred, axis=1)

# randomly pick 50 posterior draws to plot. Sample from the actual row count
# rather than a hard-coded number so every draw can be chosen and the index
# can never run past the end of chainPreds.
idx = np.random.choice(len(chainPreds), 50)
# chainPreds.iloc[draw, 0] holds the predicted values for that draw
for draw in idx:
    py.plot(predX, chainPreds.iloc[draw, 0], color='lightgrey')

# original data
py.plot(x, y, 'ko')
# fitted values (line from the median parameter estimates)
py.plot(predX, yhat['fitted'], 'k')
# supplementals
[<matplotlib.lines.Line2D object at 0x7f2fb9f727d0>]
Error in lines 21-23 Traceback (most recent call last): File "/projects/4a5f0542-5873-4eed-a85c-a18c706e8bcd/.sagemathcloud/", line 865, in execute exec compile(block+'\n', '', 'single') in namespace, locals File "", line 2, in <module> File "/usr/local/sage/sage-6.4/local/lib/python2.7/site-packages/pystan/", line 370, in stan save_dso=save_dso, verbose=verbose) File "/usr/local/sage/sage-6.4/local/lib/python2.7/site-packages/pystan/", line 305, in __init__ File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/", line 337, in run self.build_extensions() File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/", line 446, in build_extensions self.build_extension(ext) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/command/", line 496, in build_extension depends=ext.depends) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/", line 574, in compile self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts) File "/usr/local/sage/sage-6.4/local/lib/python/distutils/", line 124, in _compile raise CompileError, msg CompileError: command 'gcc' failed with exit status 4
# make a function that goes through every predicted value from every posterior draw and returns the requested quantile at each X. For example:

def quantileGet(q, preds=None):
    """Return the q-th percentile of the predictions at every predicted X.

    Parameters
    ----------
    q : float
        Percentile in [0, 100], e.g. 2.5 and 97.5 for a 95% interval.
    preds : pandas.DataFrame, optional
        Table whose first column holds one prediction array per posterior
        draw, all of the same length. Defaults to the module-level
        ``chainPreds`` (looked up at call time).

    Returns
    -------
    list of float
        One percentile per predicted X value, in array order.
    """
    if preds is None:
        preds = chainPreds
    # stack the per-draw arrays into an (n_draws, n_points) matrix so the
    # percentile across draws is one vectorized call per X column, instead
    # of re-walking every row of the DataFrame for each X value
    draws = np.vstack(list(preds.iloc[:, 0]))
    return np.percentile(draws, q, axis=0).tolist()

# pointwise 2.5% (lower) and 97.5% (upper) quantiles -> 95% credible band
lower, upper = quantileGet(2.5), quantileGet(97.5)

# fresh figure with a single set of axes to draw on
fig = py.figure()
ax = fig.add_subplot(111)

# grey band for the credible interval, drawn without an outline
ax.fill_between(predX, lower, upper,
                edgecolor='none', facecolor='lightgrey')
# observed data as black points
ax.plot(x, y, 'ko')
# regression line from the median parameter estimates, in black
ax.plot(predX, yhat['fitted'], 'k')

# supplementals
Error in lines 17-18 Traceback (most recent call last): File "/projects/4a5f0542-5873-4eed-a85c-a18c706e8bcd/.sagemathcloud/", line 865, in execute exec compile(block+'\n', '', 'single') in namespace, locals File "", line 1, in <module> File "", line 5, in quantileGet NameError: global name 'predX' is not defined