| Download
Think Stats by Allen B. Downey Think Stats is an introduction to Probability and Statistics for Python programmers.
This is the accompanying code for this book.
Project: Support and Testing
Views: 7089
License: GPL3
"""This file contains code for use with "Think Stats",1by Allen B. Downey, available from greenteapress.com23Copyright 2014 Allen B. Downey4License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html5"""67from __future__ import print_function89import math10import matplotlib11import matplotlib.pyplot as plt12import numpy as np13import pandas as pd1415import warnings1617# customize some matplotlib attributes18#matplotlib.rc('figure', figsize=(4, 3))1920#matplotlib.rc('font', size=14.0)21#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)22#matplotlib.rc('legend', fontsize=20.0)2324#matplotlib.rc('xtick.major', size=6.0)25#matplotlib.rc('xtick.minor', size=3.0)2627#matplotlib.rc('ytick.major', size=6.0)28#matplotlib.rc('ytick.minor', size=3.0)293031class _Brewer(object):32"""Encapsulates a nice sequence of colors.3334Shades of blue that look good in color and can be distinguished35in grayscale (up to a point).3637Borrowed from http://colorbrewer2.org/38"""39color_iter = None4041colors = ['#f7fbff', '#deebf7', '#c6dbef',42'#9ecae1', '#6baed6', '#4292c6',43'#2171b5','#08519c','#08306b'][::-1]4445# lists that indicate which colors to use depending on how many are used46which_colors = [[],47[1],48[1, 3],49[0, 2, 4],50[0, 2, 4, 6],51[0, 2, 3, 5, 6],52[0, 2, 3, 4, 5, 6],53[0, 1, 2, 3, 4, 5, 6],54[0, 1, 2, 3, 4, 5, 6, 7],55[0, 1, 2, 3, 4, 5, 6, 7, 8],56]5758current_figure = None5960@classmethod61def Colors(cls):62"""Returns the list of colors.63"""64return cls.colors6566@classmethod67def ColorGenerator(cls, num):68"""Returns an iterator of color strings.6970n: how many colors will be used71"""72for i in cls.which_colors[num]:73yield cls.colors[i]74raise StopIteration('Ran out of colors in _Brewer.')7576@classmethod77def InitIter(cls, num):78"""Initializes the color iterator with the given number of colors."""79cls.color_iter = cls.ColorGenerator(num)80fig = plt.gcf()81cls.current_figure = fig8283@classmethod84def ClearIter(cls):85"""Sets the color iterator to 
None."""86cls.color_iter = None87cls.current_figure = None8889@classmethod90def GetIter(cls, num):91"""Gets the color iterator."""92fig = plt.gcf()93if fig != cls.current_figure:94cls.InitIter(num)95cls.current_figure = fig9697if cls.color_iter is None:98cls.InitIter(num)99100return cls.color_iter101102103def _UnderrideColor(options):104"""If color is not in the options, chooses a color.105"""106if 'color' in options:107return options108109# get the current color iterator; if there is none, init one110color_iter = _Brewer.GetIter(5)111112try:113options['color'] = next(color_iter)114except StopIteration:115# if you run out of colors, initialize the color iterator116# and try again117warnings.warn('Ran out of colors. Starting over.')118_Brewer.ClearIter()119_UnderrideColor(options)120121return options122123124def PrePlot(num=None, rows=None, cols=None):125"""Takes hints about what's coming.126127num: number of lines that will be plotted128rows: number of rows of subplots129cols: number of columns of subplots130"""131if num:132_Brewer.InitIter(num)133134if rows is None and cols is None:135return136137if rows is not None and cols is None:138cols = 1139140if cols is not None and rows is None:141rows = 1142143# resize the image, depending on the number of rows and cols144size_map = {(1, 1): (8, 6),145(1, 2): (12, 6),146(1, 3): (12, 6),147(1, 4): (12, 5),148(1, 5): (12, 4),149(2, 2): (10, 10),150(2, 3): (16, 10),151(3, 1): (8, 10),152(4, 1): (8, 12),153}154155if (rows, cols) in size_map:156fig = plt.gcf()157fig.set_size_inches(*size_map[rows, cols])158159# create the first subplot160if rows > 1 or cols > 1:161ax = plt.subplot(rows, cols, 1)162global SUBPLOT_ROWS, SUBPLOT_COLS163SUBPLOT_ROWS = rows164SUBPLOT_COLS = cols165else:166ax = plt.gca()167168return ax169170171def SubPlot(plot_number, rows=None, cols=None, **options):172"""Configures the number of subplots and changes the current plot.173174rows: int175cols: int176plot_number: int177options: passed to 
subplot178"""179rows = rows or SUBPLOT_ROWS180cols = cols or SUBPLOT_COLS181return plt.subplot(rows, cols, plot_number, **options)182183184def _Underride(d, **options):185"""Add key-value pairs to d only if key is not in d.186187If d is None, create a new dictionary.188189d: dictionary190options: keyword args to add to d191"""192if d is None:193d = {}194195for key, val in options.items():196d.setdefault(key, val)197198return d199200201def Clf():202"""Clears the figure and any hints that have been set."""203global LOC204LOC = None205_Brewer.ClearIter()206plt.clf()207fig = plt.gcf()208fig.set_size_inches(8, 6)209210211def Figure(**options):212"""Sets options for the current figure."""213_Underride(options, figsize=(6, 8))214plt.figure(**options)215216217def Plot(obj, ys=None, style='', **options):218"""Plots a line.219220Args:221obj: sequence of x values, or Series, or anything with Render()222ys: sequence of y values223style: style string passed along to plt.plot224options: keyword args passed to plt.plot225"""226options = _UnderrideColor(options)227label = getattr(obj, 'label', '_nolegend_')228options = _Underride(options, linewidth=3, alpha=0.7, label=label)229230xs = obj231if ys is None:232if hasattr(obj, 'Render'):233xs, ys = obj.Render()234if isinstance(obj, pd.Series):235ys = obj.values236xs = obj.index237238if ys is None:239plt.plot(xs, style, **options)240else:241plt.plot(xs, ys, style, **options)242243244def Vlines(xs, y1, y2, **options):245"""Plots a set of vertical lines.246247Args:248xs: sequence of x values249y1: sequence of y values250y2: sequence of y values251options: keyword args passed to plt.vlines252"""253options = _UnderrideColor(options)254options = _Underride(options, linewidth=1, alpha=0.5)255plt.vlines(xs, y1, y2, **options)256257258def Hlines(ys, x1, x2, **options):259"""Plots a set of horizontal lines.260261Args:262ys: sequence of y values263x1: sequence of x values264x2: sequence of x values265options: keyword args passed to 
plt.vlines266"""267options = _UnderrideColor(options)268options = _Underride(options, linewidth=1, alpha=0.5)269plt.hlines(ys, x1, x2, **options)270271272def axvline(x, **options):273"""Plots a vertical line.274275Args:276x: x location277options: keyword args passed to plt.axvline278"""279options = _UnderrideColor(options)280options = _Underride(options, linewidth=1, alpha=0.5)281plt.axvline(x, **options)282283284def axhline(y, **options):285"""Plots a horizontal line.286287Args:288y: y location289options: keyword args passed to plt.axhline290"""291options = _UnderrideColor(options)292options = _Underride(options, linewidth=1, alpha=0.5)293plt.axhline(y, **options)294295296def tight_layout(**options):297"""Adjust subplots to minimize padding and margins.298"""299options = _Underride(options,300wspace=0.1, hspace=0.1,301left=0, right=1,302bottom=0, top=1)303plt.tight_layout()304plt.subplots_adjust(**options)305306307def FillBetween(xs, y1, y2=None, where=None, **options):308"""Fills the space between two lines.309310Args:311xs: sequence of x values312y1: sequence of y values313y2: sequence of y values314where: sequence of boolean315options: keyword args passed to plt.fill_between316"""317options = _UnderrideColor(options)318options = _Underride(options, linewidth=0, alpha=0.5)319plt.fill_between(xs, y1, y2, where, **options)320321322def Bar(xs, ys, **options):323"""Plots a line.324325Args:326xs: sequence of x values327ys: sequence of y values328options: keyword args passed to plt.bar329"""330options = _UnderrideColor(options)331options = _Underride(options, linewidth=0, alpha=0.6)332plt.bar(xs, ys, **options)333334335def Scatter(xs, ys=None, **options):336"""Makes a scatter plot.337338xs: x values339ys: y values340options: options passed to plt.scatter341"""342options = _Underride(options, color='blue', alpha=0.2,343s=30, edgecolors='none')344345if ys is None and isinstance(xs, pd.Series):346ys = xs.values347xs = xs.index348349plt.scatter(xs, ys, 
**options)350351352def HexBin(xs, ys, **options):353"""Makes a scatter plot.354355xs: x values356ys: y values357options: options passed to plt.scatter358"""359options = _Underride(options, cmap=matplotlib.cm.Blues)360plt.hexbin(xs, ys, **options)361362363def Pdf(pdf, **options):364"""Plots a Pdf, Pmf, or Hist as a line.365366Args:367pdf: Pdf, Pmf, or Hist object368options: keyword args passed to plt.plot369"""370low, high = options.pop('low', None), options.pop('high', None)371n = options.pop('n', 101)372xs, ps = pdf.Render(low=low, high=high, n=n)373options = _Underride(options, label=pdf.label)374Plot(xs, ps, **options)375376377def Pdfs(pdfs, **options):378"""Plots a sequence of PDFs.379380Options are passed along for all PDFs. If you want different381options for each pdf, make multiple calls to Pdf.382383Args:384pdfs: sequence of PDF objects385options: keyword args passed to plt.plot386"""387for pdf in pdfs:388Pdf(pdf, **options)389390391def Hist(hist, **options):392"""Plots a Pmf or Hist with a bar plot.393394The default width of the bars is based on the minimum difference395between values in the Hist. 
If that's too small, you can override396it by providing a width keyword argument, in the same units397as the values.398399Args:400hist: Hist or Pmf object401options: keyword args passed to plt.bar402"""403# find the minimum distance between adjacent values404xs, ys = hist.Render()405406# see if the values support arithmetic407try:408xs[0] - xs[0]409except TypeError:410# if not, replace values with numbers411labels = [str(x) for x in xs]412xs = np.arange(len(xs))413plt.xticks(xs+0.5, labels)414415if 'width' not in options:416try:417options['width'] = 0.9 * np.diff(xs).min()418except TypeError:419warnings.warn("Hist: Can't compute bar width automatically."420"Check for non-numeric types in Hist."421"Or try providing width option."422)423424options = _Underride(options, label=hist.label)425options = _Underride(options, align='center')426if options['align'] == 'left':427options['align'] = 'edge'428elif options['align'] == 'right':429options['align'] = 'edge'430options['width'] *= -1431432Bar(xs, ys, **options)433434435def Hists(hists, **options):436"""Plots two histograms as interleaved bar plots.437438Options are passed along for all PMFs. 
If you want different439options for each pmf, make multiple calls to Pmf.440441Args:442hists: list of two Hist or Pmf objects443options: keyword args passed to plt.plot444"""445for hist in hists:446Hist(hist, **options)447448449def Pmf(pmf, **options):450"""Plots a Pmf or Hist as a line.451452Args:453pmf: Hist or Pmf object454options: keyword args passed to plt.plot455"""456xs, ys = pmf.Render()457low, high = min(xs), max(xs)458459width = options.pop('width', None)460if width is None:461try:462width = np.diff(xs).min()463except TypeError:464warnings.warn("Pmf: Can't compute bar width automatically."465"Check for non-numeric types in Pmf."466"Or try providing width option.")467points = []468469lastx = np.nan470lasty = 0471for x, y in zip(xs, ys):472if (x - lastx) > 1e-5:473points.append((lastx, 0))474points.append((x, 0))475476points.append((x, lasty))477points.append((x, y))478points.append((x+width, y))479480lastx = x + width481lasty = y482points.append((lastx, 0))483pxs, pys = zip(*points)484485align = options.pop('align', 'center')486if align == 'center':487pxs = np.array(pxs) - width/2.0488if align == 'right':489pxs = np.array(pxs) - width490491options = _Underride(options, label=pmf.label)492Plot(pxs, pys, **options)493494495def Pmfs(pmfs, **options):496"""Plots a sequence of PMFs.497498Options are passed along for all PMFs. 
If you want different499options for each pmf, make multiple calls to Pmf.500501Args:502pmfs: sequence of PMF objects503options: keyword args passed to plt.plot504"""505for pmf in pmfs:506Pmf(pmf, **options)507508509def Diff(t):510"""Compute the differences between adjacent elements in a sequence.511512Args:513t: sequence of number514515Returns:516sequence of differences (length one less than t)517"""518diffs = [t[i+1] - t[i] for i in range(len(t)-1)]519return diffs520521522def Cdf(cdf, complement=False, transform=None, **options):523"""Plots a CDF as a line.524525Args:526cdf: Cdf object527complement: boolean, whether to plot the complementary CDF528transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'529options: keyword args passed to plt.plot530531Returns:532dictionary with the scale options that should be passed to533Config, Show or Save.534"""535xs, ps = cdf.Render()536xs = np.asarray(xs)537ps = np.asarray(ps)538539scale = dict(xscale='linear', yscale='linear')540541for s in ['xscale', 'yscale']:542if s in options:543scale[s] = options.pop(s)544545if transform == 'exponential':546complement = True547scale['yscale'] = 'log'548549if transform == 'pareto':550complement = True551scale['yscale'] = 'log'552scale['xscale'] = 'log'553554if complement:555ps = [1.0-p for p in ps]556557if transform == 'weibull':558xs = np.delete(xs, -1)559ps = np.delete(ps, -1)560ps = [-math.log(1.0-p) for p in ps]561scale['xscale'] = 'log'562scale['yscale'] = 'log'563564if transform == 'gumbel':565xs = np.delete(xs, 0)566ps = np.delete(ps, 0)567ps = [-math.log(p) for p in ps]568scale['yscale'] = 'log'569570options = _Underride(options, label=cdf.label)571Plot(xs, ps, **options)572return scale573574575def Cdfs(cdfs, complement=False, transform=None, **options):576"""Plots a sequence of CDFs.577578cdfs: sequence of CDF objects579complement: boolean, whether to plot the complementary CDF580transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'581options: 
keyword args passed to plt.plot582"""583for cdf in cdfs:584Cdf(cdf, complement, transform, **options)585586587def Contour(obj, pcolor=False, contour=True, imshow=False, **options):588"""Makes a contour plot.589590d: map from (x, y) to z, or object that provides GetDict591pcolor: boolean, whether to make a pseudocolor plot592contour: boolean, whether to make a contour plot593imshow: boolean, whether to use plt.imshow594options: keyword args passed to plt.pcolor and/or plt.contour595"""596try:597d = obj.GetDict()598except AttributeError:599d = obj600601_Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)602603xs, ys = zip(*d.keys())604xs = sorted(set(xs))605ys = sorted(set(ys))606607X, Y = np.meshgrid(xs, ys)608func = lambda x, y: d.get((x, y), 0)609func = np.vectorize(func)610Z = func(X, Y)611612x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)613axes = plt.gca()614axes.xaxis.set_major_formatter(x_formatter)615616if pcolor:617plt.pcolormesh(X, Y, Z, **options)618if contour:619cs = plt.contour(X, Y, Z, **options)620plt.clabel(cs, inline=1, fontsize=10)621if imshow:622extent = xs[0], xs[-1], ys[0], ys[-1]623plt.imshow(Z, extent=extent, **options)624625626def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):627"""Makes a pseudocolor plot.628629xs:630ys:631zs:632pcolor: boolean, whether to make a pseudocolor plot633contour: boolean, whether to make a contour plot634options: keyword args passed to plt.pcolor and/or plt.contour635"""636_Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)637638X, Y = np.meshgrid(xs, ys)639Z = zs640641x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)642axes = plt.gca()643axes.xaxis.set_major_formatter(x_formatter)644645if pcolor:646plt.pcolormesh(X, Y, Z, **options)647648if contour:649cs = plt.contour(X, Y, Z, **options)650plt.clabel(cs, inline=1, fontsize=10)651652653def Text(x, y, s, **options):654"""Puts text in a figure.655656x: number657y: number658s: string659options: keyword args 
passed to plt.text660"""661options = _Underride(options,662fontsize=16,663verticalalignment='top',664horizontalalignment='left')665plt.text(x, y, s, **options)666667668LEGEND = True669LOC = None670671def Config(**options):672"""Configures the plot.673674Pulls options out of the option dictionary and passes them to675the corresponding plt functions.676"""677names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',678'xticks', 'yticks', 'axis', 'xlim', 'ylim']679680for name in names:681if name in options:682getattr(plt, name)(options[name])683684global LEGEND685LEGEND = options.get('legend', LEGEND)686687# see if there are any elements with labels;688# if not, don't draw a legend689ax = plt.gca()690handles, labels = ax.get_legend_handles_labels()691692if LEGEND and len(labels) > 0:693global LOC694LOC = options.get('loc', LOC)695frameon = options.get('frameon', True)696697try:698plt.legend(loc=LOC, frameon=frameon)699except UserWarning:700pass701702# x and y ticklabels can be made invisible703val = options.get('xticklabels', None)704if val is not None:705if val == 'invisible':706ax = plt.gca()707labels = ax.get_xticklabels()708plt.setp(labels, visible=False)709710val = options.get('yticklabels', None)711if val is not None:712if val == 'invisible':713ax = plt.gca()714labels = ax.get_yticklabels()715plt.setp(labels, visible=False)716717def set_font_size(title_size=16, label_size=16, ticklabel_size=14, legend_size=14):718"""Set font sizes for the title, labels, ticklabels, and legend.719"""720def set_text_size(texts, size):721for text in texts:722text.set_size(size)723724ax = plt.gca()725726# TODO: Make this function more robust if any of these elements727# is missing.728729# title730ax.title.set_size(title_size)731732# x axis733ax.xaxis.label.set_size(label_size)734set_text_size(ax.xaxis.get_ticklabels(), ticklabel_size)735736# y axis737ax.yaxis.label.set_size(label_size)738set_text_size(ax.yaxis.get_ticklabels(), ticklabel_size)739740# legend741legend = 
ax.get_legend()742if legend is not None:743set_text_size(legend.texts, legend_size)744745746def bigger_text():747sizes = dict(title_size=16, label_size=16, ticklabel_size=14, legend_size=14)748set_font_size(**sizes)749750751def Show(**options):752"""Shows the plot.753754For options, see Config.755756options: keyword args used to invoke various plt functions757"""758clf = options.pop('clf', True)759Config(**options)760plt.show()761if clf:762Clf()763764765def Plotly(**options):766"""Shows the plot.767768For options, see Config.769770options: keyword args used to invoke various plt functions771"""772clf = options.pop('clf', True)773Config(**options)774import plotly.plotly as plotly775url = plotly.plot_mpl(plt.gcf())776if clf:777Clf()778return url779780781def Save(root=None, formats=None, **options):782"""Saves the plot in the given formats and clears the figure.783784For options, see Config.785786Note: With a capital S, this is the original save, maintained for787compatibility. New code should use save(), which works better788with my newer code, especially in Jupyter notebooks.789790Args:791root: string filename root792formats: list of string formats793options: keyword args used to invoke various plt functions794"""795clf = options.pop('clf', True)796797save_options = {}798for option in ['bbox_inches', 'pad_inches']:799if option in options:800save_options[option] = options.pop(option)801802# TODO: falling Config inside Save was probably a mistake, but removing803# it will require some work804Config(**options)805806if formats is None:807formats = ['pdf', 'png']808809try:810formats.remove('plotly')811Plotly(clf=False)812except ValueError:813pass814815if root:816for fmt in formats:817SaveFormat(root, fmt, **save_options)818if clf:819Clf()820821822def save(root, formats=None, **options):823"""Saves the plot in the given formats and clears the figure.824825For options, see plt.savefig.826827Args:828root: string filename root829formats: list of string formats830options: 
keyword args passed to plt.savefig831"""832if formats is None:833formats = ['pdf', 'png']834835try:836formats.remove('plotly')837Plotly(clf=False)838except ValueError:839pass840841for fmt in formats:842SaveFormat(root, fmt, **options)843844845def SaveFormat(root, fmt='eps', **options):846"""Writes the current figure to a file in the given format.847848Args:849root: string filename root850fmt: string format851"""852_Underride(options, dpi=300)853filename = '%s.%s' % (root, fmt)854print('Writing', filename)855plt.savefig(filename, format=fmt, **options)856857858# provide aliases for calling functions with lower-case names859preplot = PrePlot860subplot = SubPlot861clf = Clf862figure = Figure863plot = Plot864vlines = Vlines865hlines = Hlines866fill_between = FillBetween867text = Text868scatter = Scatter869pmf = Pmf870pmfs = Pmfs871hist = Hist872hists = Hists873diff = Diff874cdf = Cdf875cdfs = Cdfs876contour = Contour877pcolor = Pcolor878config = Config879show = Show880881882def main():883color_iter = _Brewer.ColorGenerator(7)884for color in color_iter:885print(color)886887888if __name__ == '__main__':889main()890891892