Contact
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
| Download

Think Stats, by Allen B. Downey. Think Stats is an introduction to Probability and Statistics for Python programmers.

This is the accompanying code for this book.

Website: http://greenteapress.com/wp/think-stats-2e/

Views: 7120
License: GPL3
1
2
"""This file contains code for use with "Think Stats",
3
by Allen B. Downey, available from greenteapress.com
4
5
Copyright 2014 Allen B. Downey
6
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
7
"""
8
9
from __future__ import print_function, division
10
11
import numpy as np
12
import pandas as pd
13
14
import nsfg
15
16
import thinkstats2
17
import thinkplot
18
19
from collections import Counter
20
21
# Image formats requested for every saved figure; passed as `formats=` to
# thinkplot.Save by the plotting functions in this module.
FORMATS = ['pdf', 'eps', 'png']
22
23
24
class SurvivalFunction(object):
    """Represents a survival function.

    Attributes:
    ts: sequence of times; np.interp requires these to be increasing
    ss: sequence of survival probabilities, one per entry of ts
    label: string
    """

    def __init__(self, ts, ss, label=''):
        """Initializes the survival function.

        ts: sequence of times
        ss: sequence of survival probabilities S(t), parallel to ts
        label: string
        """
        self.ts = ts
        self.ss = ss
        self.label = label

    def __len__(self):
        """Number of time points."""
        return len(self.ts)

    def __getitem__(self, t):
        """Returns S(t); same as Prob(t)."""
        return self.Prob(t)

    def Prob(self, t):
        """Returns S(t), the probability that corresponds to value t.

        Values between the stored times are linearly interpolated;
        times before the first stored time map to 1.0.

        t: time

        returns: float probability
        """
        return np.interp(t, self.ts, self.ss, left=1.0)

    def Probs(self, ts):
        """Gets probabilities for a sequence of values.

        ts: sequence of times

        returns: NumPy array of probabilities
        """
        return np.interp(ts, self.ts, self.ss, left=1.0)

    def Items(self):
        """Sorted sequence of (t, s) pairs.

        returns: list of (time, survival probability) tuples
        """
        # Materialize the pairs: on Python 3 a bare zip is a one-shot
        # iterator, not the sequence this method promises.
        return list(zip(self.ts, self.ss))

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (sorted times, survival function)
        """
        return self.ts, self.ss

    def MakeHazardFunction(self, label=''):
        """Computes the hazard function.

        This simple version does not take into account the
        spacing between the ts.  If the ts are not equally
        spaced, it is not valid to compare the magnitude of
        the hazard function across different time steps.

        label: string

        returns: HazardFunction object
        """
        # Explicit dtype: an empty Series without one defaults to object
        # dtype on recent pandas.
        lams = pd.Series(index=self.ts, dtype=float)

        prev = 1.0
        for t, s in zip(self.ts, self.ss):
            # Once survival has reached zero the hazard is undefined;
            # record NaN instead of dividing by zero.
            lams[t] = (prev - s) / prev if prev else np.nan
            prev = s

        return HazardFunction(lams, label=label)

    def MakePmf(self, filler=None):
        """Makes a PMF of lifetimes.

        filler: value to replace missing values; if not None, it is
                assigned the probability mass not covered by the Cdf

        returns: Pmf
        """
        cdf = thinkstats2.Cdf(self.ts, 1-self.ss)
        pmf = thinkstats2.Pmf()
        # NOTE(review): presumably Cdf.Items() yields the probability mass
        # at each value rather than the cumulative value -- confirm against
        # thinkstats2.
        for val, prob in cdf.Items():
            pmf.Set(val, prob)

        cutoff = cdf.ps[-1]
        if filler is not None:
            pmf[filler] = 1-cutoff

        return pmf

    def RemainingLifetime(self, filler=None, func=thinkstats2.Pmf.Mean):
        """Computes remaining lifetime as a function of age.

        filler: value passed to MakePmf to stand in for missing values
        func: function from conditional Pmf to expected lifetime

        returns: Series that maps from age to remaining lifetime
        """
        pmf = self.MakePmf(filler=filler)
        d = {}
        # Visit every value except the largest, so the Pmf is never empty
        # when it is renormalized.
        for t in sorted(pmf.Values())[:-1]:
            # Zero out lifetimes equal to t, then renormalize what is left
            # before summarizing it with func.
            pmf[t] = 0
            pmf.Normalize()
            d[t] = func(pmf) - t

        return pd.Series(d)
111
112
113
def MakeSurvivalFromSeq(values, label=''):
    """Makes a survival function based on a complete dataset.

    values: sequence of observed lifespans
    label: string

    returns: SurvivalFunction

    raises: ValueError if values is empty
    """
    counter = Counter(values)
    if not counter:
        raise ValueError('values must not be empty')

    # Distinct lifespans in increasing order, with their frequencies.
    ts, freqs = zip(*sorted(counter.items()))
    ts = np.asarray(ts)

    # Use the builtin float: the np.float alias was removed in NumPy 1.24.
    ps = np.cumsum(freqs, dtype=float)
    ps /= ps[-1]
    ss = 1 - ps
    return SurvivalFunction(ts, ss, label)
127
128
129
def MakeSurvivalFromCdf(cdf, label=''):
    """Makes a survival function based on a CDF.

    The survival probabilities are the complement of the Cdf's
    cumulative probabilities, evaluated at the Cdf's values.

    cdf: Cdf (its xs and ps attributes are read)
    label: string

    returns: SurvivalFunction
    """
    ts = cdf.xs
    ss = 1 - cdf.ps
    return SurvivalFunction(ts, ss, label)
139
140
141
class HazardFunction(object):
    """Represents a hazard function.

    Attributes:
    series: pandas Series that maps from time to hazard
    label: string
    """

    def __init__(self, d, label=''):
        """Initialize the hazard function.

        d: dictionary (or anything that can initialize a series)
        label: string
        """
        self.series = pd.Series(d)
        self.label = label

    def __len__(self):
        """Number of time points."""
        return len(self.series)

    def __getitem__(self, t):
        """Returns the hazard stored for time t (Series label lookup)."""
        return self.series[t]

    def Get(self, t, default=np.nan):
        """Returns the hazard for time t, or default if t is not present.

        t: time
        default: value returned for a missing t
        """
        return self.series.get(t, default)

    def Render(self):
        """Generates a sequence of points suitable for plotting.

        returns: tuple of (sorted times, hazard function)
        """
        return self.series.index, self.series.values

    def MakeSurvival(self, label=''):
        """Makes the survival function.

        The survival at each time is the running product of
        (1 - hazard) up to and including that time.

        label: string

        returns: SurvivalFunction
        """
        ts = self.series.index
        ss = (1 - self.series).cumprod()
        sf = SurvivalFunction(ts, ss, label=label)
        return sf

    def Extend(self, other):
        """Extends this hazard function by copying the tail from another.

        Only the entries of other whose time is strictly greater than the
        last time of this function are appended.  If this function is
        empty, 0 is used as the last time.

        other: HazardFunction
        """
        last_index = self.series.index[-1] if len(self) else 0
        more = other.series[other.series.index > last_index]
        self.series = pd.concat([self.series, more])

    def Truncate(self, t):
        """Truncates this hazard function at the given value of t.

        Entries with time strictly less than t are kept.

        t: number
        """
        self.series = self.series[self.series.index < t]
192
193
194
def ConditionalSurvival(pmf, t0):
    """Computes conditional survival function.

    Probability that duration exceeds t0+t, given that
    duration >= t0.

    pmf: Pmf of durations
    t0: minimum time

    returns: SurvivalFunction over the shifted durations t - t0
    """
    cond = thinkstats2.Pmf()
    for t, p in pmf.Items():
        # Keep only durations that reached t0, measured from t0.
        if t >= t0:
            cond.Set(t-t0, p)
    cond.Normalize()
    return MakeSurvivalFromCdf(cond.MakeCdf())
211
212
213
def PlotConditionalSurvival(durations):
    """Plots conditional survival curves for a range of t0.

    One curve is drawn for each t0 in 8, 16, 24, 32, and the figure
    is then shown.

    durations: list of durations
    """
    pmf = thinkstats2.Pmf(durations)

    times = [8, 16, 24, 32]
    thinkplot.PrePlot(len(times))

    for t0 in times:
        sf = ConditionalSurvival(pmf, t0)
        label = 't0=%d' % t0
        thinkplot.Plot(sf, label=label)

    thinkplot.Show()
229
230
231
def PlotSurvival(complete):
    """Plots survival and hazard curves.

    Draws the survival function and Cdf in the first subplot and the
    hazard function in the second.  Also prints the Cdf and survival
    values at 13 and the hazard at 39.

    complete: list of complete lifetimes
    """
    thinkplot.PrePlot(3, rows=2)

    cdf = thinkstats2.Cdf(complete, label='cdf')
    sf = MakeSurvivalFromCdf(cdf, label='survival')
    print(cdf[13])
    print(sf[13])

    thinkplot.Plot(sf)
    thinkplot.Cdf(cdf, alpha=0.2)
    thinkplot.Config()

    thinkplot.SubPlot(2)
    hf = sf.MakeHazardFunction(label='hazard')
    print(hf[39])
    thinkplot.Plot(hf)
    thinkplot.Config(ylim=[0, 0.75])
252
253
254
def PlotHazard(complete, ongoing):
    """Plots the hazard function and survival function.

    Three curves are drawn: the survival function estimated from the
    complete lifetimes only, the estimated hazard function, and the
    survival function derived from that hazard function.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    """
    # plot S(t) based on only complete pregnancies
    sf = MakeSurvivalFromSeq(complete)
    thinkplot.Plot(sf, label='old S(t)', alpha=0.1)

    thinkplot.PrePlot(2)

    # plot the hazard function
    hf = EstimateHazardFunction(complete, ongoing)
    thinkplot.Plot(hf, label='lams(t)', alpha=0.5)

    # plot the survival function
    sf = hf.MakeSurvival()

    thinkplot.Plot(sf, label='S(t)')
    thinkplot.Show(xlabel='t (weeks)')
275
276
277
def EstimateHazardFunction(complete, ongoing, label='', verbose=False):
    """Estimates the hazard function by Kaplan-Meier.

    http://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string
    verbose: whether to display intermediate results

    returns: HazardFunction

    raises: ValueError if either input contains NaN
    """
    if np.any(np.isnan(complete)):
        raise ValueError("complete contains NaNs")
    if np.any(np.isnan(ongoing)):
        raise ValueError("ongoing contains NaNs")

    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # Every time at which a lifetime ended or was censored, in order.
    ts = sorted(hist_complete | hist_ongoing)

    at_risk = len(complete) + len(ongoing)

    # Explicit dtype: an empty Series without one defaults to object
    # dtype on recent pandas.
    lams = pd.Series(index=ts, dtype=float)
    for t in ts:
        ended = hist_complete[t]
        censored = hist_ongoing[t]

        lams[t] = ended / at_risk
        if verbose:
            print('%0.3g\t%d\t%d\t%d\t%0.2g' %
                  (t, at_risk, ended, censored, lams[t]))
        # Cases that ended or were censored at t are no longer at risk.
        at_risk -= ended + censored

    return HazardFunction(lams, label=label)
312
313
314
def EstimateHazardNumpy(complete, ongoing, label=''):
    """Estimates the hazard function by Kaplan-Meier.

    Just for fun, this is a version that uses NumPy to
    eliminate loops.

    complete: list of complete lifetimes
    ongoing: list of ongoing lifetimes
    label: string

    returns: HazardFunction
    """
    hist_complete = Counter(complete)
    hist_ongoing = Counter(ongoing)

    # The cumulative sums below are only meaningful when the times are
    # visited in increasing order, so sort them; a bare set iterates in
    # arbitrary order.
    ts = sorted(set(hist_complete) | set(hist_ongoing))
    if not ts:
        # No observations at all: return an empty hazard function.
        return HazardFunction(pd.Series(dtype=float), label=label)

    at_risk = len(complete) + len(ongoing)

    ended = np.array([hist_complete[t] for t in ts])
    censored = np.array([hist_ongoing[t] for t in ts])
    ended_c = np.cumsum(ended)
    censored_c = np.cumsum(censored)

    # Cases removed before each time: the cumulative counts shifted right
    # by one step, with nothing removed before the first time.
    not_at_risk = np.roll(ended_c, 1) + np.roll(censored_c, 1)
    not_at_risk[0] = 0

    at_risk_array = at_risk - not_at_risk
    hs = ended / at_risk_array

    lams = dict(zip(ts, hs))

    return HazardFunction(lams, label=label)
343
344
345
def AddLabelsByDecade(groups, **options):
    """Draws fake points in order to add labels to the legend.

    One point is drawn per group, labeled with the group name
    followed by '0s'.

    groups: GroupBy object
    options: keyword arguments passed on to thinkplot.Plot
    """
    thinkplot.PrePlot(len(groups))
    for name, _ in groups:
        label = '%d0s' % name
        thinkplot.Plot([15], [1], label=label, **options)
354
355
356
def EstimateMarriageSurvivalByDecade(groups, **options):
    """Groups respondents by decade and plots survival curves.

    groups: GroupBy object
    options: keyword arguments passed on to thinkplot.Plot
    """
    thinkplot.PrePlot(len(groups))
    for _, group in groups:
        _, sf = EstimateMarriageSurvival(group)
        thinkplot.Plot(sf, **options)
365
366
367
def PlotPredictionsByDecade(groups, **options):
    """Plots predicted survival curves for each decade group.

    Each group's hazard function, except the first, is extended with
    the tail of the previous group's (already extended) hazard
    function before being converted to a survival curve.

    groups: GroupBy object
    options: keyword arguments passed on to thinkplot.Plot
    """
    hfs = []
    for _, group in groups:
        hf, _ = EstimateMarriageSurvival(group)
        hfs.append(hf)

    thinkplot.PrePlot(len(hfs))
    for i, hf in enumerate(hfs):
        if i > 0:
            hf.Extend(hfs[i-1])
        sf = hf.MakeSurvival()
        thinkplot.Plot(sf, **options)
383
384
385
def ResampleSurvival(resp, iters=101):
    """Resamples respondents and estimates the survival function.

    Plots the survival curve for resp, shades the band between the
    5th and 95th percentiles of the resampled curves, and saves the
    figure with root name 'survival3'.

    resp: DataFrame of respondents
    iters: number of resamples
    """
    _, sf = EstimateMarriageSurvival(resp)
    thinkplot.Plot(sf)

    # Evaluate every resampled curve on a common monthly grid of ages.
    min_age, max_age = resp.agemarry.min(), resp.agemarry.max()
    ts = np.arange(min_age, max_age, 1/12.0)

    ss_seq = []
    for _ in range(iters):
        sample = thinkstats2.ResampleRowsWeighted(resp)
        _, sf = EstimateMarriageSurvival(sample)
        ss_seq.append(sf.Probs(ts))

    low, high = thinkstats2.PercentileRows(ss_seq, [5, 95])
    thinkplot.FillBetween(ts, low, high, color='gray', label='90% CI')
    thinkplot.Save(root='survival3',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[12, 46],
                   ylim=[0, 1],
                   formats=FORMATS)
411
412
413
def EstimateMarriageSurvival(resp):
    """Estimates the survival curve.

    Complete cases are the ages at marriage of respondents with
    evrmarry == 1; ongoing cases are the current ages of respondents
    with evrmarry == 0.

    resp: DataFrame of respondents

    returns: pair of HazardFunction, SurvivalFunction
    """
    # NOTE: Filling missing values would be better than dropping them.
    complete = resp[resp.evrmarry == 1].agemarry.dropna()
    ongoing = resp[resp.evrmarry == 0].age

    hf = EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()

    return hf, sf
428
429
430
def PlotMarriageData(resp):
    """Plots hazard and survival functions.

    Saves the two-row figure with root name 'survival2'.

    resp: DataFrame of respondents

    returns: SurvivalFunction estimated from resp
    """
    hf, sf = EstimateMarriageSurvival(resp)

    thinkplot.PrePlot(rows=2)
    thinkplot.Plot(hf)
    thinkplot.Config(ylabel='hazard', legend=False)

    thinkplot.SubPlot(2)
    thinkplot.Plot(sf)
    thinkplot.Save(root='survival2',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   ylim=[0, 1],
                   legend=False,
                   formats=FORMATS)
    return sf
450
451
452
def PlotPregnancyData(preg):
    """Plots survival and hazard curves based on pregnancy lengths.

    Pregnancies with outcome 1, 3 or 4 are treated as complete and
    those with outcome 6 as ongoing.  The figure is saved with root
    name 'survival1'.

    preg: DataFrame of pregnancies (columns outcome and prglngth are read)

    returns: SurvivalFunction estimated from complete and ongoing
             pregnancies

    Outcome codes from http://www.icpsr.umich.edu/nsfg6/Controller?
    displayPage=labelDetails&fileCode=PREG&section=&subSec=8016&srtLabel=611932

    1   LIVE BIRTH              9148
    2   INDUCED ABORTION        1862
    3   STILLBIRTH              120
    4   MISCARRIAGE             1921
    5   ECTOPIC PREGNANCY       190
    6   CURRENT PREGNANCY       352

    """
    complete = preg.query('outcome in [1, 3, 4]').prglngth
    print('Number of complete pregnancies', len(complete))
    ongoing = preg[preg.outcome == 6].prglngth
    print('Number of ongoing pregnancies', len(ongoing))

    PlotSurvival(complete)
    thinkplot.Save(root='survival1',
                   xlabel='t (weeks)',
                   formats=FORMATS)

    hf = EstimateHazardFunction(complete, ongoing)
    sf = hf.MakeSurvival()
    return sf
482
483
484
def PlotRemainingLifetime(sf1, sf2):
    """Plots remaining lifetimes for pregnancy and age at first marriage.

    The left panel shows the mean remaining lifetime from sf1; the
    right panel shows the median remaining lifetime from sf2.  The
    figure is saved with root name 'survival6'.

    sf1: SurvivalFunction for pregnancy length
    sf2: SurvivalFunction for age at first marriage
    """
    def median(pmf):
        """Returns the 50th percentile of pmf."""
        return pmf.Percentile(50)

    thinkplot.PrePlot(cols=2)
    rem_life1 = sf1.RemainingLifetime()
    thinkplot.Plot(rem_life1)
    thinkplot.Config(title='remaining pregnancy length',
                     xlabel='weeks',
                     ylabel='mean remaining weeks')

    thinkplot.SubPlot(2)
    # filler=np.inf supplies a value for the missing lifetimes in sf2.
    rem_life2 = sf2.RemainingLifetime(filler=np.inf, func=median)
    thinkplot.Plot(rem_life2)
    thinkplot.Config(title='years until first marriage',
                     ylim=[0, 15],
                     xlim=[11, 31],
                     xlabel='age (years)',
                     ylabel='median remaining years')

    thinkplot.Save(root='survival6',
                   formats=FORMATS)
509
510
511
512
def PlotResampledByDecade(resps, iters=11, predict_flag=False, omit=None):
    """Plots survival curves for resampled data.

    resps: list of DataFrames
    iters: number of resamples to plot
    predict_flag: whether to also plot predictions
    omit: collection of decade values to leave out, or None to keep all
    """
    for i in range(iters):
        samples = [thinkstats2.ResampleRowsWeighted(resp)
                   for resp in resps]
        sample = pd.concat(samples, ignore_index=True)
        groups = sample.groupby('decade')

        if omit:
            groups = [(name, group) for name, group in groups
                      if name not in omit]

        # TODO: refactor this to collect resampled estimates and
        # plot shaded areas
        if i == 0:
            # Legend entries are added once, on the first iteration only.
            AddLabelsByDecade(groups, alpha=0.7)

        if predict_flag:
            PlotPredictionsByDecade(groups, alpha=0.1)
            EstimateMarriageSurvivalByDecade(groups, alpha=0.1)
        else:
            EstimateMarriageSurvivalByDecade(groups, alpha=0.2)
539
540
541
542
# NOTE: The functions below are copied from marriage.py in
543
# the MarriageNSFG repo.
544
545
def ReadFemResp1995():
    """Reads respondent data from NSFG Cycle 5.

    Reads five fixed-width columns from '1995FemRespData.dat.gz',
    treats timesmar codes 98 and 99 as missing, derives evrmarry
    from timesmar, and applies CleanFemResp.

    returns: DataFrame
    """
    dat_file = '1995FemRespData.dat.gz'
    names = ['cmintvw', 'timesmar', 'cmmarrhx', 'cmbirth', 'finalwgt']
    # Each column is written as (first - 1, last), giving the half-open
    # interval that read_fwf expects.
    colspecs = [(12360-1, 12363),
                (4637-1, 4638),
                (11759-1, 11762),
                (14-1, 16),
                (12350-1, 12359)]
    df = pd.read_fwf(dat_file,
                     compression='gzip',
                     colspecs=colspecs,
                     names=names)

    # Assign the result back instead of using inplace=True on df.timesmar:
    # with pandas Copy-on-Write the chained inplace form does not update df.
    df['timesmar'] = df.timesmar.replace([98, 99], np.nan)
    df['evrmarry'] = (df.timesmar > 0)

    CleanFemResp(df)
    return df
567
568
569
def ReadFemResp2002():
    """Reads respondent data from NSFG Cycle 6.

    Uses ReadFemResp's default files, converts evrmarry to a boolean
    (True where the code is 1), and applies CleanFemResp.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'finalwgt']
    df = ReadFemResp(usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    CleanFemResp(df)
    return df
580
581
582
def ReadFemResp2010():
    """Reads respondent data from NSFG Cycle 7.

    Converts evrmarry to a boolean (True where the code is 1), copies
    wgtq1q16 into finalwgt so the weight column has the same name as
    in the other cycles, and applies CleanFemResp.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'wgtq1q16']
    df = ReadFemResp('2006_2010_FemRespSetup.dct',
                     '2006_2010_FemResp.dat.gz',
                     usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    df['finalwgt'] = df.wgtq1q16
    CleanFemResp(df)
    return df
596
597
598
def ReadFemResp2013():
    """Reads respondent data from NSFG Cycle 8.

    Converts evrmarry to a boolean (True where the code is 1), copies
    wgt2011_2013 into finalwgt so the weight column has the same name
    as in the other cycles, and applies CleanFemResp.

    returns: DataFrame
    """
    usecols = ['caseid', 'cmmarrhx', 'cmdivorcx', 'cmbirth', 'cmintvw',
               'evrmarry', 'parity', 'wgt2011_2013']
    df = ReadFemResp('2011_2013_FemRespSetup.dct',
                     '2011_2013_FemRespData.dat.gz',
                     usecols=usecols)
    df['evrmarry'] = (df.evrmarry == 1)
    df['finalwgt'] = df.wgt2011_2013
    CleanFemResp(df)
    return df
612
613
614
def ReadFemResp(dct_file='2002FemResp.dct',
                dat_file='2002FemResp.dat.gz',
                **options):
    """Reads the NSFG respondent data.

    dct_file: string file name
    dat_file: string file name
    options: keyword arguments passed on to ReadFixedWidth

    returns: DataFrame
    """
    dct = thinkstats2.ReadStataDct(dct_file, encoding='iso-8859-1')
    df = dct.ReadFixedWidth(dat_file, compression='gzip', **options)
    return df
627
628
629
def CleanFemResp(resp):
    """Cleans a respondent DataFrame.

    Modifies resp in place: cmmarrhx codes 9997, 9998 and 9999 are
    replaced with NaN, and the derived columns are added.

    resp: DataFrame of respondents (columns cmmarrhx, cmbirth and
          cmintvw are read)

    Adds columns: agemarry, age, year, decade, fives
    """
    # Assign the result back instead of using inplace=True on
    # resp.cmmarrhx: with pandas Copy-on-Write the chained inplace form
    # does not update resp.
    resp['cmmarrhx'] = resp.cmmarrhx.replace([9997, 9998, 9999], np.nan)

    # The cm* columns are month counts; dividing by 12 gives years.
    resp['agemarry'] = (resp.cmmarrhx - resp.cmbirth) / 12.0
    resp['age'] = (resp.cmintvw - resp.cmbirth) / 12.0

    # Convert the birth month count to a calendar date, counting months
    # from mid-December 1899, and express the year relative to 1900.
    month0 = pd.to_datetime('1899-12-15')
    dates = [month0 + pd.DateOffset(months=cm)
             for cm in resp.cmbirth]
    resp['year'] = (pd.DatetimeIndex(dates).year - 1900)
    resp['decade'] = resp.year // 10
    resp['fives'] = resp.year // 5
647
648
649
def main():
    """Generates the survival-analysis figures.

    Reads the pregnancy file and the Cycle 5, 6 and 7 respondent
    files, then saves figures survival1 through survival6.
    """
    # Fixed seed so the resampled figures are reproducible.
    thinkstats2.RandomSeed(17)

    preg = nsfg.ReadFemPreg()
    sf1 = PlotPregnancyData(preg)

    # make the plots based on Cycle 6
    resp6 = ReadFemResp2002()

    sf2 = PlotMarriageData(resp6)

    ResampleSurvival(resp6)

    PlotRemainingLifetime(sf1, sf2)

    # read Cycles 5 and 7
    resp5 = ReadFemResp1995()
    resp7 = ReadFemResp2010()

    # plot resampled survival functions by decade
    resps = [resp5, resp6, resp7]
    PlotResampledByDecade(resps)
    thinkplot.Save(root='survival4',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[13, 45],
                   ylim=[0, 1],
                   formats=FORMATS)

    # plot resampled survival functions by decade, with predictions
    PlotResampledByDecade(resps, predict_flag=True, omit=[5])
    thinkplot.Save(root='survival5',
                   xlabel='age (years)',
                   ylabel='prob unmarried',
                   xlim=[13, 45],
                   ylim=[0, 1],
                   formats=FORMATS)


if __name__ == '__main__':
    main()
690
691