| Download
Think Stats, by Allen B. Downey. Think Stats is an introduction to probability and statistics for Python programmers.
This is the accompanying code for this book.
Project: Support and Testing
Views: 7120 | License: GPL3
"""This file contains code for use with "Think Stats",
by Allen B. Downey, available from greenteapress.com

Copyright 2014 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""

from __future__ import print_function, division

import numpy as np
import pandas as pd

import nsfg

import thinkstats2
import thinkplot

from collections import Counter

FORMATS = ['pdf', 'eps', 'png']


class SurvivalFunction(object):
    """Represents a survival function."""

    def __init__(self, ts, ss, label=''):
        """Initializes the survival function.

        ts: sorted sequence of times
        ss: sequence of survival probabilities, one per time in ts
        label: string
        """
        self.ts = ts
        self.ss = ss
        self.label = label

    def __len__(self):
        return len(self.ts)

    def __getitem__(self, t):
        return self.Prob(t)

    def Prob(self, t):
        """Returns S(t), the probability that corresponds to value t.

        Values between stored times are linearly interpolated; times
        before the first stored time get 1.0.

        t: time
        returns: float probability
        """
        return np.interp(t, self.ts, self.ss, left=1.0)

    def Probs(self, ts):
        """Gets probabilities for a sequence of values."""
        return np.interp(ts, self.ts, self.ss, left=1.0)

    def Items(self):
        """Sorted sequence of (t, s) pairs."""
        return zip(self.ts, self.ss)

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (sorted times, survival function)
        """
        return self.ts, self.ss

    def MakeHazardFunction(self, label=''):
        """Computes the hazard function.

        This simple version does not take into account the
        spacing between the ts.  If the ts are not equally
        spaced, it is not valid to compare the magnitude of
        the hazard function across different time steps.

        label: string

        returns: HazardFunction object
        """
        # Explicit float dtype: newer pandas defaults a data-less Series
        # to object dtype, which would leak into later arithmetic.
        lams = pd.Series(index=self.ts, dtype=float)

        prev = 1.0
        for t, s in zip(self.ts, self.ss):
            # fraction of those surviving the previous step that end at t
            lams[t] = (prev - s) / prev
            prev = s

        return HazardFunction(lams, label=label)

    def MakePmf(self, filler=None):
        """Makes a PMF of lifetimes.

        filler: value to replace missing values; if not None, it receives
                the probability mass left over after the last observed
                time (1 minus the final cumulative probability)

        returns: Pmf
        """
        cdf = thinkstats2.Cdf(self.ts, 1-self.ss)
        pmf = thinkstats2.Pmf()
        for val, prob in cdf.Items():
            pmf.Set(val, prob)

        cutoff = cdf.ps[-1]
        if filler is not None:
            pmf[filler] = 1-cutoff

        return pmf

    def RemainingLifetime(self, filler=None, func=thinkstats2.Pmf.Mean):
        """Computes remaining lifetime as a function of age.

        filler: passed to MakePmf
        func: function from conditional Pmf to expected lifetime

        returns: Series that maps from age to remaining lifetime
        """
        pmf = self.MakePmf(filler=filler)
        d = {}
        # Zero out each time in increasing order and renormalize, so at
        # step t the Pmf only covers lifetimes greater than t.  The last
        # value is skipped because removing it would leave nothing.
        for t in sorted(pmf.Values())[:-1]:
            pmf[t] = 0
            pmf.Normalize()
            d[t] = func(pmf) - t

        return pd.Series(d)


def MakeSurvivalFromSeq(values, label=''):
    """Makes a survival function based on a complete dataset.

    values: sequence of observed lifespans
    label: string

    returns: SurvivalFunction
    """
    counter = Counter(values)
    ts, freqs = zip(*sorted(counter.items()))
    ts = np.asarray(ts)
    # np.float was removed in NumPy 1.24; the builtin float is equivalent.
    ps = np.cumsum(freqs, dtype=float)
    ps /= ps[-1]
    ss = 1 - ps
    return SurvivalFunction(ts, ss, label)


def MakeSurvivalFromCdf(cdf, label=''):
    """Makes a survival function based on a CDF.

    cdf: Cdf
    label: string

    returns: SurvivalFunction
    """
    ts = cdf.xs
    ss = 1 - cdf.ps
    return SurvivalFunction(ts, ss, label)


class HazardFunction(object):
    """Represents a hazard function."""

    def __init__(self, d, label=''):
        """Initialize the hazard function.

        d: dictionary (or anything that can initialize a series)
        label: string
        """
        self.series = pd.Series(d)
        self.label = label

    def __len__(self):
        return len(self.series)

    def __getitem__(self, t):
        return self.series[t]

    def Get(self, t, default=np.nan):
        """Returns the hazard at t, or default if t is not present."""
        return self.series.get(t, default)

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (sorted times, hazard function)
        """
        return self.series.index, self.series.values

    def MakeSurvival(self, label=''):
        """Makes the survival function.

        S(t) is the running product of (1 - hazard) up to and including t.

        label: string

        returns: SurvivalFunction
        """
        ts = self.series.index
        ss = (1 - self.series).cumprod()
        sf = SurvivalFunction(ts, ss, label=label)
        return sf

    def Extend(self, other):
        """Extends this hazard function by copying the tail from another.

        Entries of other with times greater than this function's last
        time (or greater than 0, if this function is empty) are appended.

        other: HazardFunction
        """
        last_index = self.series.index[-1] if len(self) else 0
        more = other.series[other.series.index > last_index]
        self.series = pd.concat([self.series, more])

    def Truncate(self, t):
        """Truncates this hazard function at the given value of t.

        Keeps only entries with times strictly less than t.

        t: number
        """
        self.series = self.series[self.series.index < t]


def ConditionalSurvival(pmf, t0):
    """Computes conditional survival function.

    Probability that duration exceeds t0+t, given that
    duration >= t0.

    pmf: Pmf of durations
    t0: minimum time

    returns: SurvivalFunction whose times are measured from t0
    """
    cond = thinkstats2.Pmf()
    for t, p in pmf.Items():
        if t >= t0:
            cond.Set(t-t0, p)
    cond.Normalize()
    return MakeSurvivalFromCdf(cond.MakeCdf())


def PlotConditionalSurvival(durations):
    """Plots conditional survival curves for a range of t0.

    durations: list of durations
    """
    pmf = thinkstats2.Pmf(durations)

    times = [8, 16, 24, 32]
    thinkplot.PrePlot(len(times))

    for t0 in times:
        sf = ConditionalSurvival(pmf, t0)
        label = 't0=%d' % t0
        thinkplot.Plot(sf, label=label)

    thinkplot.Show()


def PlotSurvival(complete):
    """Plots survival and hazard curves.

    complete: list of complete lifetimes
    """
    thinkplot.PrePlot(3, rows=2)

    cdf = thinkstats2.Cdf(complete, label='cdf')
    sf = MakeSurvivalFromCdf(cdf, label='survival')
    print(cdf[13])
    print(sf[13])

    thinkplot.Plot(sf)
    thinkplot.Cdf(cdf, alpha=0.2)
    thinkplot.Config()

    thinkplot.SubPlot(2)
    hf = sf.MakeHazardFunction(label='hazard')
    print(hf[39])
    thinkplot.Plot(hf)
    thinkplot.Config(ylim=[0, 0.75])


def PlotHazard(complete, ongoing):
    """Plots the hazard function and survival function.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    """
    # plot S(t) based on only complete pregnancies
    sf = MakeSurvivalFromSeq(complete)
    thinkplot.Plot(sf, label='old S(t)', alpha=0.1)

    thinkplot.PrePlot(2)

    # plot the hazard function
    hf = EstimateHazardFunction(complete, ongoing)
    thinkplot.Plot(hf, label='lams(t)', alpha=0.5)

    # plot the survival function
    sf = hf.MakeSurvival()

    thinkplot.Plot(sf, label='S(t)')
    thinkplot.Show(xlabel='t (weeks)')


def EstimateHazardFunction(complete, ongoing, label='', verbose=False):
    """Estimates the hazard function by Kaplan-Meier.

    http://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string
    verbose: whether to display intermediate results

    returns: HazardFunction

    raises: ValueError if either input contains NaN
    """
    if np.sum(np.isnan(complete)):
        raise ValueError("complete contains NaNs")
    if np.sum(np.isnan(ongoing)):
        raise ValueError("ongoing contains NaNs")

    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # distinct times from either group, in increasing order
    ts = list(hist_complete | hist_ongoing)
    ts.sort()

    at_risk = len(complete) + len(ongoing)

    # explicit float dtype; see SurvivalFunction.MakeHazardFunction
    lams = pd.Series(index=ts, dtype=float)
    for t in ts:
        ended = hist_complete[t]
        censored = hist_ongoing[t]

        lams[t] = ended / at_risk
        if verbose:
            print('%0.3g\t%d\t%d\t%d\t%0.2g' %
                  (t, at_risk, ended, censored, lams[t]))
        # both ended and censored cases leave the risk set after time t
        at_risk -= ended + censored

    return HazardFunction(lams, label=label)


def EstimateHazardNumpy(complete, ongoing, label=''):
    """Estimates the hazard function by Kaplan-Meier.

    Just for fun, this is a version that uses NumPy to
    eliminate loops.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string

    returns: HazardFunction
    """
    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # The cumulative sums below only give the right at-risk counts when
    # the times are visited in increasing order; a bare set has no order.
    ts = sorted(set(hist_complete) | set(hist_ongoing))
    if not ts:
        return HazardFunction({}, label=label)

    at_risk = len(complete) + len(ongoing)

    ended = np.array([hist_complete[t] for t in ts])
    ended_c = np.cumsum(ended)
    censored_c = np.cumsum([hist_ongoing[t] for t in ts])

    # number removed from the risk set strictly before each time
    not_at_risk = np.roll(ended_c, 1) + np.roll(censored_c, 1)
    not_at_risk[0] = 0

    at_risk_array = at_risk - not_at_risk
    hs = ended / at_risk_array

    lams = dict(zip(ts, hs))

    return HazardFunction(lams, label=label)


def AddLabelsByDecade(groups, **options):
    """Draws fake points in order to add labels to the legend.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for name, _ in groups:
        label = '%d0s' % name
        thinkplot.Plot([15], [1], label=label, **options)


def EstimateMarriageSurvivalByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    groups: GroupBy object
    """
    thinkplot.PrePlot(len(groups))
    for _, group in groups:
        _, sf = EstimateMarriageSurvival(group)
        thinkplot.Plot(sf, **options)


def PlotPredictionsByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    Each group's hazard function is extended with the tail of the
    previous group's before computing its survival curve.

    groups: GroupBy object
    """
    hfs = []
    for _, group in groups:
        hf, sf = EstimateMarriageSurvival(group)
        hfs.append(hf)

    thinkplot.PrePlot(len(hfs))
    for i, hf in enumerate(hfs):
        if i > 0:
            hf.Extend(hfs[i-1])
        sf = hf.MakeSurvival()
        thinkplot.Plot(sf, **options)


def ResampleSurvival(resp, iters=101):
    """Resamples respondents and estimates the survival function.

    resp: DataFrame of respondents
    iters: number of resamples
    """
    _, sf = EstimateMarriageSurvival(resp)
    thinkplot.Plot(sf)

    low, high = resp.agemarry.min(), resp.agemarry.max()
    ts = np.arange(low, high, 1/12.0)

    ss_seq = []
    for _ in range(iters):
        sample = thinkstats2.ResampleRowsWeighted(resp)
        _, sf = EstimateMarriageSurvival(sample)
        ss_seq.append(sf.Probs(ts))

    low, high = thinkstats2.PercentileRows(ss_seq, [5, 95])
    thinkplot.FillBetween(ts, low, high, color='gray', label='90% CI')
    thinkplot.Save(root='survival3',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[12, 46],
                   ylim=[0, 1],
                   formats=FORMATS)


def EstimateMarriageSurvival(resp):
    """Estimates the survival curve.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    # NOTE: Filling missing values would be better than dropping them.
    complete = resp[resp.evrmarry == 1].agemarry.dropna()
    ongoing = resp[resp.evrmarry == 0].age

    hf = EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()

    return hf, sf


def PlotMarriageData(resp):
    """Plots hazard and survival functions.

    resp: DataFrame of respondents

    returns: SurvivalFunction
    """
    hf, sf = EstimateMarriageSurvival(resp)

    thinkplot.PrePlot(rows=2)
    thinkplot.Plot(hf)
    thinkplot.Config(ylabel='hazard', legend=False)

    thinkplot.SubPlot(2)
    thinkplot.Plot(sf)
    thinkplot.Save(root='survival2',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   ylim=[0, 1],
                   legend=False,
                   formats=FORMATS)
    return sf


def PlotPregnancyData(preg):
    """Plots survival and hazard curves based on pregnancy lengths.

    preg: DataFrame of pregnancies

    returns: SurvivalFunction estimated from complete and ongoing
             pregnancies


    Outcome codes from http://www.icpsr.umich.edu/nsfg6/Controller?
    displayPage=labelDetails&fileCode=PREG&section=&subSec=8016&srtLabel=611932

    1 LIVE BIRTH 9148
    2 INDUCED ABORTION 1862
    3 STILLBIRTH 120
    4 MISCARRIAGE 1921
    5 ECTOPIC PREGNANCY 190
    6 CURRENT PREGNANCY 352

    """
    complete = preg.query('outcome in [1, 3, 4]').prglngth
    print('Number of complete pregnancies', len(complete))
    ongoing = preg[preg.outcome == 6].prglngth
    print('Number of ongoing pregnancies', len(ongoing))

    PlotSurvival(complete)
    thinkplot.Save(root='survival1',
                   xlabel='t (weeks)',
                   formats=FORMATS)

    hf = EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()
    return sf


def PlotRemainingLifetime(sf1, sf2):
    """Plots remaining lifetimes for pregnancy and age at first marriage.

    sf1: SurvivalFunction for pregnancy length
    sf2: SurvivalFunction for age at first marriage
    """
    thinkplot.PrePlot(cols=2)
    rem_life1 = sf1.RemainingLifetime()
    thinkplot.Plot(rem_life1)
    thinkplot.Config(title='remaining pregnancy length',
                     xlabel='weeks',
                     ylabel='mean remaining weeks')

    thinkplot.SubPlot(2)
    func = lambda pmf: pmf.Percentile(50)
    rem_life2 = sf2.RemainingLifetime(filler=np.inf, func=func)
    thinkplot.Plot(rem_life2)
    thinkplot.Config(title='years until first marriage',
                     ylim=[0, 15],
                     xlim=[11, 31],
                     xlabel='age (years)',
                     ylabel='median remaining years')

    thinkplot.Save(root='survival6',
                   formats=FORMATS)


def PlotResampledByDecade(resps, iters=11, predict_flag=False, omit=None):
    """Plots survival curves for resampled data.

    resps: list of DataFrames
    iters: number of resamples to plot
    predict_flag: whether to also plot predictions
    omit: sequence of decade values to leave out, or None
    """
    for i in range(iters):
        samples = [thinkstats2.ResampleRowsWeighted(resp)
                   for resp in resps]
        sample = pd.concat(samples, ignore_index=True)
        groups = sample.groupby('decade')

        if omit:
            groups = [(name, group) for name, group in groups
                      if name not in omit]

        # TODO: refactor this to collect resampled estimates and
        # plot shaded areas
        if i == 0:
            AddLabelsByDecade(groups, alpha=0.7)

        if predict_flag:
            PlotPredictionsByDecade(groups, alpha=0.1)
            EstimateMarriageSurvivalByDecade(groups, alpha=0.1)
        else:
            EstimateMarriageSurvivalByDecade(groups, alpha=0.2)


# NOTE: The functions below are copied from marriage.py in
# the MarriageNSFG repo.

def ReadFemResp1995():
    """Reads respondent data from NSFG Cycle 5.

    returns: DataFrame
    """
    dat_file = '1995FemRespData.dat.gz'
    names = ['cmintvw', 'timesmar', 'cmmarrhx', 'cmbirth', 'finalwgt']
    colspecs = [(12360-1, 12363),
                (4637-1, 4638),
                (11759-1, 11762),
                (14-1, 16),
                (12350-1, 12359)]
    df = pd.read_fwf(dat_file,
                     compression='gzip',
                     colspecs=colspecs,
                     names=names)

    # Assign the column back rather than using inplace=True on an
    # attribute-accessed column, which does not write through under
    # pandas Copy-on-Write.
    df['timesmar'] = df.timesmar.replace([98, 99], np.nan)
    df['evrmarry'] = (df.timesmar > 0)

    CleanFemResp(df)
    return df


def ReadFemResp2002():
    """Reads respondent data from NSFG Cycle 6.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'finalwgt']
    df = ReadFemResp(usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    CleanFemResp(df)
    return df


def ReadFemResp2010():
    """Reads respondent data from NSFG Cycle 7.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'wgtq1q16']
    df = ReadFemResp('2006_2010_FemRespSetup.dct',
                     '2006_2010_FemResp.dat.gz',
                     usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    df['finalwgt'] = df.wgtq1q16
    CleanFemResp(df)
    return df


def ReadFemResp2013():
    """Reads respondent data from NSFG Cycle 8.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'wgt2011_2013']
    df = ReadFemResp('2011_2013_FemRespSetup.dct',
                     '2011_2013_FemRespData.dat.gz',
                     usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    df['finalwgt'] = df.wgt2011_2013
    CleanFemResp(df)
    return df


def ReadFemResp(dct_file='2002FemResp.dct',
                dat_file='2002FemResp.dat.gz',
                **options):
    """Reads the NSFG respondent data.

    dct_file: string file name
    dat_file: string file name
    options: passed along to ReadFixedWidth (e.g. usecols)

    returns: DataFrame
    """
    dct = thinkstats2.ReadStataDct(dct_file, encoding='iso-8859-1')
    df = dct.ReadFixedWidth(dat_file, compression='gzip', **options)
    return df


def CleanFemResp(resp):
    """Cleans a respondent DataFrame.

    Replaces codes 9997, 9998 and 9999 in cmmarrhx with NaN.
    Modifies resp in place.

    resp: DataFrame of respondents

    Adds columns: agemarry, age, year, decade, fives
    """
    # Assign back instead of replace(..., inplace=True) on a column, which
    # does not propagate to resp under pandas Copy-on-Write.
    resp['cmmarrhx'] = resp.cmmarrhx.replace([9997, 9998, 9999], np.nan)

    resp['agemarry'] = (resp.cmmarrhx - resp.cmbirth) / 12.0
    resp['age'] = (resp.cmintvw - resp.cmbirth) / 12.0

    # cm* values are month counts; month cm is cm months after month0
    month0 = pd.to_datetime('1899-12-15')
    dates = [month0 + pd.DateOffset(months=cm)
             for cm in resp.cmbirth]
    resp['year'] = (pd.DatetimeIndex(dates).year - 1900)
    resp['decade'] = resp.year // 10
    resp['fives'] = resp.year // 5


def main():
    thinkstats2.RandomSeed(17)

    preg = nsfg.ReadFemPreg()
    sf1 = PlotPregnancyData(preg)

    # make the plots based on Cycle 6
    resp6 = ReadFemResp2002()

    sf2 = PlotMarriageData(resp6)

    ResampleSurvival(resp6)

    PlotRemainingLifetime(sf1, sf2)

    # read Cycles 5 and 7
    resp5 = ReadFemResp1995()
    resp7 = ReadFemResp2010()

    # plot resampled survival functions by decade
    resps = [resp5, resp6, resp7]
    PlotResampledByDecade(resps)
    thinkplot.Save(root='survival4',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[13, 45],
                   ylim=[0, 1],
                   formats=FORMATS)

    # plot resampled survival functions by decade, with predictions
    PlotResampledByDecade(resps, predict_flag=True, omit=[5])
    thinkplot.Save(root='survival5',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[13, 45],
                   ylim=[0, 1],
                   formats=FORMATS)


if __name__ == '__main__':
    main()