# Download
# Think Stats by Allen B. Downey. Think Stats is an introduction to Probability and Statistics for Python programmers.
# This is the accompanying code for this book.
# Project: Support and Testing
# Views: 7089  License: GPL3
"""This file contains code for use with "Think Stats",1by Allen B. Downey, available from greenteapress.com23Copyright 2014 Allen B. Downey4License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html5"""67from __future__ import print_function, division89import pandas10import numpy as np11import statsmodels.formula.api as smf12import statsmodels.tsa.stattools as smtsa1314import matplotlib.pyplot as pyplot1516import thinkplot17import thinkstats21819FORMATS = ['png']2021def ReadData():22"""Reads data about cannabis transactions.2324http://zmjones.com/static/data/mj-clean.csv2526returns: DataFrame27"""28transactions = pandas.read_csv('mj-clean.csv', parse_dates=[5])29return transactions303132def tmean(series):33"""Computes a trimmed mean.3435series: Series3637returns: float38"""39t = series.values40n = len(t)41if n <= 3:42return t.mean()43trim = max(1, n//10)44return np.mean(sorted(t)[trim:n-trim])454647def GroupByDay(transactions, func=np.mean):48"""Groups transactions by day and compute the daily mean ppg.4950transactions: DataFrame of transactions5152returns: DataFrame of daily prices53"""54groups = transactions[['date', 'ppg']].groupby('date')55daily = groups.aggregate(func)5657daily['date'] = daily.index58start = daily.date[0]59one_year = np.timedelta64(1, 'Y')60daily['years'] = (daily.date - start) / one_year6162return daily636465def GroupByQualityAndDay(transactions):66"""Divides transactions by quality and computes mean daily price.6768transaction: DataFrame of transactions6970returns: map from quality to time series of ppg71"""72groups = transactions.groupby('quality')73dailies = {}74for name, group in groups:75dailies[name] = GroupByDay(group)7677return dailies787980def PlotDailies(dailies):81"""Makes a plot with daily prices for different qualities.8283dailies: map from name to DataFrame84"""85thinkplot.PrePlot(rows=3)86for i, (name, daily) in enumerate(dailies.items()):87thinkplot.SubPlot(i+1)88title = 'price per gram ($)' if i == 0 else 
''89thinkplot.Config(ylim=[0, 20], title=title)90thinkplot.Scatter(daily.ppg, s=10, label=name)91if i == 2:92pyplot.xticks(rotation=30)93else:94thinkplot.Config(xticks=[])9596thinkplot.Save(root='timeseries1',97formats=FORMATS)9899100def RunLinearModel(daily):101"""Runs a linear model of prices versus years.102103daily: DataFrame of daily prices104105returns: model, results106"""107model = smf.ols('ppg ~ years', data=daily)108results = model.fit()109return model, results110111112def PlotFittedValues(model, results, label=''):113"""Plots original data and fitted values.114115model: StatsModel model object116results: StatsModel results object117"""118years = model.exog[:, 1]119values = model.endog120thinkplot.Scatter(years, values, s=15, label=label)121thinkplot.Plot(years, results.fittedvalues, label='model')122123124def PlotResiduals(model, results):125"""Plots the residuals of a model.126127model: StatsModel model object128results: StatsModel results object129"""130years = model.exog[:, 1]131thinkplot.Plot(years, results.resid, linewidth=0.5, alpha=0.5)132133134def PlotResidualPercentiles(model, results, index=1, num_bins=20):135"""Plots percentiles of the residuals.136137model: StatsModel model object138results: StatsModel results object139index: which exogenous variable to use140num_bins: how many bins to divide the x-axis into141"""142exog = model.exog[:, index]143resid = results.resid.values144df = pandas.DataFrame(dict(exog=exog, resid=resid))145146bins = np.linspace(np.min(exog), np.max(exog), num_bins)147indices = np.digitize(exog, bins)148groups = df.groupby(indices)149150means = [group.exog.mean() for _, group in groups][1:-1]151cdfs = [thinkstats2.Cdf(group.resid) for _, group in groups][1:-1]152153thinkplot.PrePlot(3)154for percent in [75, 50, 25]:155percentiles = [cdf.Percentile(percent) for cdf in cdfs]156label = '%dth' % percent157thinkplot.Plot(means, percentiles, label=label)158159160def SimulateResults(daily, iters=101, 
func=RunLinearModel):161"""Run simulations based on resampling residuals.162163daily: DataFrame of daily prices164iters: number of simulations165func: function that fits a model to the data166167returns: list of result objects168"""169_, results = func(daily)170fake = daily.copy()171172result_seq = []173for _ in range(iters):174fake.ppg = results.fittedvalues + thinkstats2.Resample(results.resid)175_, fake_results = func(fake)176result_seq.append(fake_results)177178return result_seq179180181def SimulateIntervals(daily, iters=101, func=RunLinearModel):182"""Run simulations based on different subsets of the data.183184daily: DataFrame of daily prices185iters: number of simulations186func: function that fits a model to the data187188returns: list of result objects189"""190result_seq = []191starts = np.linspace(0, len(daily), iters).astype(int)192193for start in starts[:-2]:194subset = daily[start:]195_, results = func(subset)196fake = subset.copy()197198for _ in range(iters):199fake.ppg = (results.fittedvalues +200thinkstats2.Resample(results.resid))201_, fake_results = func(fake)202result_seq.append(fake_results)203204return result_seq205206207def GeneratePredictions(result_seq, years, add_resid=False):208"""Generates an array of predicted values from a list of model results.209210When add_resid is False, predictions represent sampling error only.211212When add_resid is True, they also include residual error (which is213more relevant to prediction).214215result_seq: list of model results216years: sequence of times (in years) to make predictions for217add_resid: boolean, whether to add in resampled residuals218219returns: sequence of predictions220"""221n = len(years)222d = dict(Intercept=np.ones(n), years=years, years2=years**2)223predict_df = pandas.DataFrame(d)224225predict_seq = []226for fake_results in result_seq:227predict = fake_results.predict(predict_df)228if add_resid:229predict += thinkstats2.Resample(fake_results.resid, 
n)230predict_seq.append(predict)231232return predict_seq233234235def GenerateSimplePrediction(results, years):236"""Generates a simple prediction.237238results: results object239years: sequence of times (in years) to make predictions for240241returns: sequence of predicted values242"""243n = len(years)244inter = np.ones(n)245d = dict(Intercept=inter, years=years, years2=years**2)246predict_df = pandas.DataFrame(d)247predict = results.predict(predict_df)248return predict249250251def PlotPredictions(daily, years, iters=101, percent=90, func=RunLinearModel):252"""Plots predictions.253254daily: DataFrame of daily prices255years: sequence of times (in years) to make predictions for256iters: number of simulations257percent: what percentile range to show258func: function that fits a model to the data259"""260result_seq = SimulateResults(daily, iters=iters, func=func)261p = (100 - percent) / 2262percents = p, 100-p263264predict_seq = GeneratePredictions(result_seq, years, add_resid=True)265low, high = thinkstats2.PercentileRows(predict_seq, percents)266thinkplot.FillBetween(years, low, high, alpha=0.3, color='gray')267268predict_seq = GeneratePredictions(result_seq, years, add_resid=False)269low, high = thinkstats2.PercentileRows(predict_seq, percents)270thinkplot.FillBetween(years, low, high, alpha=0.5, color='gray')271272273def PlotIntervals(daily, years, iters=101, percent=90, func=RunLinearModel):274"""Plots predictions based on different intervals.275276daily: DataFrame of daily prices277years: sequence of times (in years) to make predictions for278iters: number of simulations279percent: what percentile range to show280func: function that fits a model to the data281"""282result_seq = SimulateIntervals(daily, iters=iters, func=func)283p = (100 - percent) / 2284percents = p, 100-p285286predict_seq = GeneratePredictions(result_seq, years, add_resid=True)287low, high = thinkstats2.PercentileRows(predict_seq, percents)288thinkplot.FillBetween(years, low, high, alpha=0.2, 
color='gray')289290291def Correlate(dailies):292"""Compute the correlation matrix between prices for difference qualities.293294dailies: map from quality to time series of ppg295296returns: correlation matrix297"""298df = pandas.DataFrame()299for name, daily in dailies.items():300df[name] = daily.ppg301302return df.corr()303304305def CorrelateResid(dailies):306"""Compute the correlation matrix between residuals.307308dailies: map from quality to time series of ppg309310returns: correlation matrix311"""312df = pandas.DataFrame()313for name, daily in dailies.items():314_, results = RunLinearModel(daily)315df[name] = results.resid316317return df.corr()318319320def TestCorrelateResid(dailies, iters=101):321"""Tests observed correlations.322323dailies: map from quality to time series of ppg324iters: number of simulations325"""326327t = []328names = ['high', 'medium', 'low']329for name in names:330daily = dailies[name]331t.append(SimulateResults(daily, iters=iters))332333corr = CorrelateResid(dailies)334335arrays = []336for result_seq in zip(*t):337df = pandas.DataFrame()338for name, results in zip(names, result_seq):339df[name] = results.resid340341opp_sign = corr * df.corr() < 0342arrays.append((opp_sign.astype(int)))343344print(np.sum(arrays))345346347def RunModels(dailies):348"""Runs linear regression for each group in dailies.349350dailies: map from group name to DataFrame351"""352rows = []353for daily in dailies.values():354_, results = RunLinearModel(daily)355intercept, slope = results.params356p1, p2 = results.pvalues357r2 = results.rsquared358s = r'%0.3f (%0.2g) & %0.3f (%0.2g) & %0.3f \\'359row = s % (intercept, p1, slope, p2, r2)360rows.append(row)361362# print results in a LaTeX table363print(r'\begin{tabular}{|c|c|c|}')364print(r'\hline')365print(r'intercept & slope & $R^2$ \\ \hline')366for row in rows:367print(row)368print(r'\hline')369print(r'\end{tabular}')370371372def FillMissing(daily, span=30):373"""Fills missing values with an exponentially weighted 
moving average.374375Resulting DataFrame has new columns 'ewma' and 'resid'.376377daily: DataFrame of daily prices378span: window size (sort of) passed to ewma379380returns: new DataFrame of daily prices381"""382dates = pandas.date_range(daily.index.min(), daily.index.max())383reindexed = daily.reindex(dates)384385ewma = pandas.ewma(reindexed.ppg, span=span)386387resid = (reindexed.ppg - ewma).dropna()388fake_data = ewma + thinkstats2.Resample(resid, len(reindexed))389reindexed.ppg.fillna(fake_data, inplace=True)390391reindexed['ewma'] = ewma392reindexed['resid'] = reindexed.ppg - ewma393return reindexed394395396def AddWeeklySeasonality(daily):397"""Adds a weekly pattern.398399daily: DataFrame of daily prices400401returns: new DataFrame of daily prices402"""403frisat = (daily.index.dayofweek==4) | (daily.index.dayofweek==5)404fake = daily.copy()405fake.ppg[frisat] += np.random.uniform(0, 2, frisat.sum())406return fake407408409def PrintSerialCorrelations(dailies):410"""Prints a table of correlations with different lags.411412dailies: map from category name to DataFrame of daily prices413"""414filled_dailies = {}415for name, daily in dailies.items():416filled_dailies[name] = FillMissing(daily, span=30)417418# print serial correlations for raw price data419for name, filled in filled_dailies.items():420corr = thinkstats2.SerialCorr(filled.ppg, lag=1)421print(name, corr)422423rows = []424for lag in [1, 7, 30, 365]:425row = [str(lag)]426for name, filled in filled_dailies.items():427corr = thinkstats2.SerialCorr(filled.resid, lag)428row.append('%.2g' % corr)429rows.append(row)430431print(r'\begin{tabular}{|c|c|c|c|}')432print(r'\hline')433print(r'lag & high & medium & low \\ \hline')434for row in rows:435print(' & '.join(row) + r' \\')436print(r'\hline')437print(r'\end{tabular}')438439filled = filled_dailies['high']440acf = smtsa.acf(filled.resid, nlags=365, unbiased=True)441print('%0.3f, %0.3f, %0.3f, %0.3f, %0.3f' %442(acf[0], acf[1], acf[7], acf[30], 
acf[365]))443444445def SimulateAutocorrelation(daily, iters=1001, nlags=40):446"""Resample residuals, compute autocorrelation, and plot percentiles.447448daily: DataFrame449iters: number of simulations to run450nlags: maximum lags to compute autocorrelation451"""452# run simulations453t = []454for _ in range(iters):455filled = FillMissing(daily, span=30)456resid = thinkstats2.Resample(filled.resid)457acf = smtsa.acf(resid, nlags=nlags, unbiased=True)[1:]458t.append(np.abs(acf))459460high = thinkstats2.PercentileRows(t, [97.5])[0]461low = -high462lags = list(range(1, nlags+1))463thinkplot.FillBetween(lags, low, high, alpha=0.2, color='gray')464465466def PlotAutoCorrelation(dailies, nlags=40, add_weekly=False):467"""Plots autocorrelation functions.468469dailies: map from category name to DataFrame of daily prices470nlags: number of lags to compute471add_weekly: boolean, whether to add a simulated weekly pattern472"""473thinkplot.PrePlot(3)474daily = dailies['high']475SimulateAutocorrelation(daily)476477for name, daily in dailies.items():478479if add_weekly:480daily = AddWeeklySeasonality(daily)481482filled = FillMissing(daily, span=30)483484acf = smtsa.acf(filled.resid, nlags=nlags, unbiased=True)485lags = np.arange(len(acf))486thinkplot.Plot(lags[1:], acf[1:], label=name)487488489def MakeAcfPlot(dailies):490"""Makes a figure showing autocorrelation functions.491492dailies: map from category name to DataFrame of daily prices493"""494axis = [0, 41, -0.2, 0.2]495496thinkplot.PrePlot(cols=2)497PlotAutoCorrelation(dailies, add_weekly=False)498thinkplot.Config(axis=axis,499loc='lower right',500ylabel='correlation',501xlabel='lag (day)')502503thinkplot.SubPlot(2)504PlotAutoCorrelation(dailies, add_weekly=True)505thinkplot.Save(root='timeseries9',506axis=axis,507loc='lower right',508xlabel='lag (days)',509formats=FORMATS)510511512def PlotRollingMean(daily, name):513"""Plots rolling mean and EWMA.514515daily: DataFrame of daily prices516"""517dates = 
pandas.date_range(daily.index.min(), daily.index.max())518reindexed = daily.reindex(dates)519520thinkplot.PrePlot(cols=2)521thinkplot.Scatter(reindexed.ppg, s=15, alpha=0.1, label=name)522roll_mean = pandas.rolling_mean(reindexed.ppg, 30)523thinkplot.Plot(roll_mean, label='rolling mean')524pyplot.xticks(rotation=30)525thinkplot.Config(ylabel='price per gram ($)')526527thinkplot.SubPlot(2)528thinkplot.Scatter(reindexed.ppg, s=15, alpha=0.1, label=name)529ewma = pandas.ewma(reindexed.ppg, span=30)530thinkplot.Plot(ewma, label='EWMA')531pyplot.xticks(rotation=30)532thinkplot.Save(root='timeseries10',533formats=FORMATS)534535536def PlotFilled(daily, name):537"""Plots the EWMA and filled data.538539daily: DataFrame of daily prices540"""541filled = FillMissing(daily, span=30)542thinkplot.Scatter(filled.ppg, s=15, alpha=0.3, label=name)543thinkplot.Plot(filled.ewma, label='EWMA', alpha=0.4)544pyplot.xticks(rotation=30)545thinkplot.Save(root='timeseries8',546ylabel='price per gram ($)',547formats=FORMATS)548549550def PlotLinearModel(daily, name):551"""Plots a linear fit to a sequence of prices, and the residuals.552553daily: DataFrame of daily prices554name: string555"""556model, results = RunLinearModel(daily)557PlotFittedValues(model, results, label=name)558thinkplot.Save(root='timeseries2',559title='fitted values',560xlabel='years',561xlim=[-0.1, 3.8],562ylabel='price per gram ($)',563formats=FORMATS)564565PlotResidualPercentiles(model, results)566thinkplot.Save(root='timeseries3',567title='residuals',568xlabel='years',569ylabel='price per gram ($)',570formats=FORMATS)571572#years = np.linspace(0, 5, 101)573#predict = GenerateSimplePrediction(results, years)574575576def main(name):577thinkstats2.RandomSeed(18)578transactions = ReadData()579580dailies = GroupByQualityAndDay(transactions)581PlotDailies(dailies)582RunModels(dailies)583PrintSerialCorrelations(dailies)584MakeAcfPlot(dailies)585586name = 'high'587daily = dailies[name]588589PlotLinearModel(daily, 
name)590PlotRollingMean(daily, name)591PlotFilled(daily, name)592593years = np.linspace(0, 5, 101)594thinkplot.Scatter(daily.years, daily.ppg, alpha=0.1, label=name)595PlotPredictions(daily, years)596xlim = years[0]-0.1, years[-1]+0.1597thinkplot.Save(root='timeseries4',598title='predictions',599xlabel='years',600xlim=xlim,601ylabel='price per gram ($)',602formats=FORMATS)603604name = 'medium'605daily = dailies[name]606607thinkplot.Scatter(daily.years, daily.ppg, alpha=0.1, label=name)608PlotIntervals(daily, years)609PlotPredictions(daily, years)610xlim = years[0]-0.1, years[-1]+0.1611thinkplot.Save(root='timeseries5',612title='predictions',613xlabel='years',614xlim=xlim,615ylabel='price per gram ($)',616formats=FORMATS)617618619if __name__ == '__main__':620import sys621main(*sys.argv)622623624