Contact
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96164
License: OTHER
1
"""This file contains code for use with "Think Stats",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2014 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
from __future__ import print_function
9
10
import logging
11
import math
12
import matplotlib
13
import matplotlib.pyplot as pyplot
14
import numpy as np
15
import pandas
16
17
# customize some matplotlib attributes
18
#matplotlib.rc('figure', figsize=(4, 3))
19
20
#matplotlib.rc('font', size=14.0)
21
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
22
#matplotlib.rc('legend', fontsize=20.0)
23
24
#matplotlib.rc('xtick.major', size=6.0)
25
#matplotlib.rc('xtick.minor', size=3.0)
26
27
#matplotlib.rc('ytick.major', size=6.0)
28
#matplotlib.rc('ytick.minor', size=3.0)
29
30
31
class _Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/
    """
    # Shared iterator handed out by GetIter; None means "not initialized".
    color_iter = None

    colors = ['#081D58',
              '#253494',
              '#225EA8',
              '#1D91C0',
              '#41B6C4',
              '#7FCDBB',
              '#C7E9B4',
              '#EDF8B1',
              '#FFFFD9']

    # lists that indicate which colors to use depending on how many are used:
    # which_colors[n] holds the indices into `colors` used for n lines
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6, 7],
                    [0, 1, 2, 3, 4, 5, 6, 7, 8],
                    ]

    @classmethod
    def Colors(cls):
        """Returns the list of colors."""
        return cls.colors

    @classmethod
    def ColorGenerator(cls, n):
        """Returns an iterator of color strings.

        n: how many colors will be used.  If n is larger than the number
           of predefined selections, every color in `colors` is used once.

        The iterator simply ends when the colors run out, so next() raises
        StopIteration, which callers use to detect exhaustion.
        """
        if n < len(cls.which_colors):
            indices = cls.which_colors[n]
        else:
            indices = range(len(cls.colors))

        for i in indices:
            yield cls.colors[i]
        # Falling off the end is deliberate: an explicit
        # `raise StopIteration` inside a generator is turned into
        # RuntimeError since Python 3.7 (PEP 479).

    @classmethod
    def InitializeIter(cls, num):
        """Initializes the color iterator with the given number of colors."""
        cls.color_iter = cls.ColorGenerator(num)

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None

    @classmethod
    def GetIter(cls):
        """Gets the color iterator, initializing it for 7 colors if unset."""
        if cls.color_iter is None:
            cls.InitializeIter(7)

        return cls.color_iter
95
96
97
def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    num: number of lines that will be plotted; (re)initializes the
         Brewer color iterator for that many lines
    rows: number of rows of subplots
    cols: number of columns of subplots

    If only one of rows/cols is given, the other defaults to 1.  Known
    layouts resize the current figure, the first subplot is selected when
    the grid has more than one cell, and the layout is recorded for later
    SubPlot calls.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitializeIter(num)

    if rows is None and cols is None:
        return

    # at least one of them was given; the missing one defaults to 1
    if rows is None:
        rows = 1
    if cols is None:
        cols = 1

    # resize the image, depending on the number of rows and cols
    size_map = {(1, 1): (8, 6),
                (1, 2): (14, 6),
                (1, 3): (14, 6),
                (2, 2): (10, 10),
                (2, 3): (16, 10),
                (3, 1): (8, 10),
                }

    if (rows, cols) in size_map:
        fig = pyplot.gcf()
        fig.set_size_inches(*size_map[rows, cols])

    # create the first subplot
    if rows > 1 or cols > 1:
        pyplot.subplot(rows, cols, 1)

    # Record the layout even for a 1x1 grid so that SubPlot never finds
    # these globals undefined after PrePlot has been called.
    SUBPLOT_ROWS = rows
    SUBPLOT_COLS = cols
135
136
137
def SubPlot(plot_number, rows=None, cols=None):
    """Configures the number of subplots and changes the current plot.

    rows: int, number of rows; defaults to the layout recorded by the
          last PrePlot call, or 1 if none was recorded
    cols: int, number of columns; same defaulting as rows
    plot_number: int, index of the subplot to select (PrePlot uses 1
                 for the first subplot)
    """
    # globals().get avoids a NameError when PrePlot never recorded a layout
    rows = rows or globals().get('SUBPLOT_ROWS', 1)
    cols = cols or globals().get('SUBPLOT_COLS', 1)
    pyplot.subplot(rows, cols, plot_number)
147
148
149
def _Underride(d, **options):
    """Fill in defaults: copy each keyword into d unless d already has it.

    d: dictionary updated in place, or None to start from a new empty dict
    options: default key-value pairs

    Returns d (or the newly created dictionary).
    """
    target = {} if d is None else d
    for key in options:
        if key not in target:
            target[key] = options[key]
    return target
164
165
166
def Clf():
    """Clears the figure and any hints that have been set.

    Forgets the Brewer color iterator, clears the current pyplot figure
    and resets its size to 8x6 inches.
    """
    _Brewer.ClearIter()
    pyplot.clf()
    pyplot.gcf().set_size_inches(8, 6)
172
173
174
def Figure(**options):
    """Calls pyplot.figure with the given options.

    figsize defaults to (6, 8) unless the caller supplies one.
    """
    options.setdefault('figsize', (6, 8))
    pyplot.figure(**options)
178
179
180
def _UnderrideColor(options):
    """Assigns the next Brewer color to options['color'] if none is set.

    options: dictionary of plotting keyword args; modified in place

    When the color iterator is exhausted, prints a warning, resets the
    iterator and leaves 'color' unset.

    Returns options.
    """
    if 'color' not in options:
        color_iter = _Brewer.GetIter()
        if color_iter:
            try:
                options['color'] = next(color_iter)
            except StopIteration:
                print('Warning: Brewer ran out of colors.')
                _Brewer.ClearIter()
    return options
193
194
195
def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values; when omitted they come from obj.Render(),
          or from a Series' values (with its index as the x values)
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot; defaults are the next
               Brewer color, linewidth 3, alpha 0.8 and obj.name as label
    """
    options = _UnderrideColor(options)
    default_label = getattr(obj, 'name', '_nolegend_')
    options = _Underride(options, linewidth=3, alpha=0.8, label=default_label)

    xs = obj
    if ys is None:
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            xs, ys = obj.index, obj.values

    # without ys, pyplot.plot treats the single sequence as y values
    positional = (xs, style) if ys is None else (xs, ys, style)
    pyplot.plot(*positional, **options)
220
221
222
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the area between two curves.

    Args:
      xs: sequence of x values
      y1: sequence of y values for one edge of the region
      y2: sequence of y values (or a scalar) for the other edge; defaults
          to 0, i.e. the region between y1 and the x axis
      where: sequence of boolean selecting which x ranges to fill
      options: keyword args passed to pyplot.fill_between; defaults are
               the next Brewer color, linewidth 0 and alpha 0.5
    """
    if y2 is None:
        # pyplot.fill_between's own default is y2=0; forwarding None would
        # make matplotlib try (and fail) to treat None as data.
        y2 = 0
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.5)
    pyplot.fill_between(xs, y1, y2, where=where, **options)
235
236
237
def Bar(xs, ys, **options):
    """Makes a bar plot.

    Args:
      xs: sequence of x values (bar positions)
      ys: sequence of y values (bar heights)
      options: keyword args passed to pyplot.bar; defaults are the next
               Brewer color, linewidth 0 and alpha 0.6
    """
    options = _Underride(_UnderrideColor(options), linewidth=0, alpha=0.6)
    pyplot.bar(xs, ys, **options)
248
249
250
def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    xs: x values, or a pandas Series (plotted as index vs. values when ys
        is omitted)
    ys: y values
    options: options passed to pyplot.scatter; defaults are blue markers
             with alpha 0.2, size 30 and no edge color
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    options = _Underride(options, **defaults)

    if ys is None and isinstance(xs, pandas.Series):
        xs, ys = xs.index, xs.values

    pyplot.scatter(xs, ys, **options)
265
266
267
def HexBin(xs, ys, **options):
    """Makes a hexagonal binning plot.

    xs: x values
    ys: y values
    options: options passed to pyplot.hexbin; cmap defaults to Blues
    """
    pyplot.hexbin(xs, ys, **_Underride(options, cmap=matplotlib.cm.Blues))
276
277
278
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to pyplot.plot, except low, high and n
               (default 101), which are consumed here and forwarded to
               pdf.Render
    """
    low = options.pop('low', None)
    high = options.pop('high', None)
    n = options.pop('n', 101)

    xs, ps = pdf.Render(low=low, high=high, n=n)
    Plot(xs, ps, **_Underride(options, label=pdf.name))
290
291
292
def Pdfs(pdfs, **options):
    """Plots a sequence of PDFs.

    Options are passed along for all PDFs. If you want different
    options for each pdf, make multiple calls to Pdf.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to Pdf (and from there to pyplot.plot)
    """
    for each in pdfs:
        Pdf(each, **options)
304
305
306
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is 0.9 times the minimum difference
    between values in the Hist. If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to pyplot.bar; align may also be
               'left' or 'right', which are translated into matplotlib's
               'edge' alignment (with a negative width for 'right')
    """
    xs, ys = hist.Render()

    if 'width' not in options:
        try:
            # find the minimum distance between adjacent values
            options['width'] = 0.9 * np.diff(xs).min()
        except (TypeError, ValueError):
            # TypeError: non-numeric values.  ValueError: fewer than two
            # values, so there is no spacing to measure.  Either way leave
            # width unset so pyplot.bar uses its own default.
            logging.warning("Hist: Can't compute bar width automatically. "
                            "Check for non-numeric types in Hist. "
                            "Or try providing width option.")

    options = _Underride(options, label=hist.name)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        options['align'] = 'edge'
        # A negative width with align='edge' draws each bar to the left of
        # its x.  0.8 is pyplot.bar's default width, used if none was set.
        options['width'] = -options.get('width', 0.8)

    Bar(xs, ys, **options)
339
340
341
def Hists(hists, **options):
    """Plots several histograms as bar plots.

    Options are passed along for all histograms. If you want different
    options for each one, make multiple calls to Hist.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to Hist (and from there to pyplot.bar)
    """
    for h in hists:
        Hist(h, **options)
353
354
355
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line.

    Each value x is drawn as a flat step of the given width at height
    p(x); where consecutive steps do not touch, the outline drops to 0.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to pyplot.plot, plus
        width: width of each step; defaults to the minimum spacing between
               values, or 1.0 when there are fewer than two values
        align: 'center' (default), 'left' or 'right' placement of each
               step relative to its value

    Raises:
      TypeError: if no width is given and it cannot be computed because
        the values are not numeric.
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            width = np.diff(xs).min()
        except ValueError:
            # fewer than two values: there is no spacing to measure
            width = 1.0
        except TypeError:
            # the step outline below needs arithmetic on the values, so
            # fail here with a clear message instead of later on x + None
            raise TypeError("Pmf: Can't compute bar width automatically. "
                            "Check for non-numeric types in Pmf. "
                            "Or try providing width option.")

    points = []

    # lastx starts as NaN, so on the first pass the gap test below is
    # False and the outline does not start with a drop to 0
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        if (x - lastx) > 1e-5:
            # gap since the previous step: go down to 0 and back up
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x + width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width / 2.0
    elif align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.name)
    Plot(pxs, pys, **options)
399
400
401
def Pmfs(pmfs, **options):
    """Plots a sequence of PMFs.

    Options are passed along for all PMFs. If you want different
    options for each pmf, make multiple calls to Pmf.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to Pmf (and from there to pyplot.plot)
    """
    for each in pmfs:
        Pmf(each, **options)
413
414
415
def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
      t: indexable sequence of numbers

    Returns:
      list of t[k] - t[k-1] for k = 1 .. len(t)-1 (one shorter than t;
      empty when t has fewer than two elements)
    """
    size = len(t)
    result = []
    for k in range(1, size):
        result.append(t[k] - t[k - 1])
    return result
426
427
428
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot; 'xscale' and 'yscale',
               if present, are removed and used as the starting scales

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0-p for p in ps]

    if transform == 'weibull':
        # drop the last point; presumably its p is 1.0, where
        # log(1 - p) is undefined
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point; presumably its p is 0, where log(p) is
        # undefined
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.name)
    Plot(xs, ps, **options)
    return scale
479
480
481
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots a sequence of CDFs with shared settings.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to Cdf (and from there to pyplot.plot)

    The scale dictionaries returned by Cdf are not returned from here.
    """
    for c in cdfs:
        Cdf(c, complement, transform, **options)
491
492
493
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolormesh, pyplot.contour
             and/or pyplot.imshow (linewidth is not passed to imshow, and
             imshow's origin defaults to 'lower')
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    # (x, y) pairs missing from d are plotted as z == 0
    func = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = func(X, Y)

    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = pyplot.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        image_options = dict(options)
        # images have no line width; imshow rejects it as a property
        image_options.pop('linewidth', None)
        # Z[0] is the row for ys[0] (the smallest y), so put row 0 at the
        # bottom to agree with extent = (left, right, bottom, top)
        image_options.setdefault('origin', 'lower')
        pyplot.imshow(Z, extent=extent, **image_options)
530
531
532
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x coordinates
    ys: sequence of y coordinates
    zs: array of values on the (xs, ys) grid; presumably shaped
        (len(ys), len(xs)) to match np.meshgrid(xs, ys) -- TODO confirm
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolormesh and/or
             pyplot.contour; defaults are linewidth 3 and the Blues cmap
    """
    options = _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)
    X, Y = np.meshgrid(xs, ys)

    # show plain tick values on the x axis, without an offset
    pyplot.gca().xaxis.set_major_formatter(
        matplotlib.ticker.ScalarFormatter(useOffset=False))

    if pcolor:
        pyplot.pcolormesh(X, Y, zs, **options)

    if contour:
        cs = pyplot.contour(X, Y, zs, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
557
558
559
def Text(x, y, s, **options):
    """Puts text in a figure.

    x: number
    y: number
    s: string
    options: keyword args passed to pyplot.text; by default the text's
             top-left corner is placed at (x, y)
    """
    defaults = {'verticalalignment': 'top', 'horizontalalignment': 'left'}
    pyplot.text(x, y, s, **_Underride(options, **defaults))
570
571
572
def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding pyplot functions.

    Recognized options:
      title, xlabel, ylabel, xscale, yscale, xticks, yticks, axis,
      xlim, ylim: each is passed to the pyplot function of the same name
      legend: boolean, whether to draw a legend (default True)
      loc: legend location, passed to pyplot.legend (default 0, which
           matplotlib treats as 'best'); text specs such as 'upper left'
           are accepted directly by matplotlib
    """
    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(pyplot, name)(options[name])

    loc = options.get('loc', 0)

    legend = options.get('legend', True)
    if legend:
        pyplot.legend(loc=loc)
604
605
606
def Show(**options):
    """Shows the plot.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions; clf
             (default True) controls whether the figure is cleared after
             it is shown
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    pyplot.show()
    if clear_after:
        Clf()
618
619
620
def Plotly(**options):
    """Hands the current matplotlib figure to plotly.plot_mpl.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions; clf
             (default True) controls whether the figure is cleared
             afterwards

    Returns:
      whatever plotly.plot_mpl returns, presumably the URL of the plot
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    # imported here so plotly is only needed when this function is used
    import plotly.plotly as plotly
    url = plotly.plot_mpl(pyplot.gcf())
    if clear_after:
        Clf()
    return url
634
635
636
def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    Args:
      root: string filename root; if empty or None no files are written
      formats: sequence of string formats, default ['pdf', 'eps'].  The
               special entry 'plotly' is handled by Plotly() instead of
               writing a file.  The caller's sequence is not modified.
      options: keyword args used to invoke various pyplot functions; clf
               (default True) controls whether the figure is cleared
    """
    clf = options.pop('clf', True)
    Config(**options)

    if formats is None:
        formats = ['pdf', 'eps']
    else:
        # copy, so removing 'plotly' below does not mutate the caller's list
        formats = list(formats)

    # explicit membership test instead of try/except ValueError, so a
    # ValueError raised inside Plotly() is not silently swallowed
    if 'plotly' in formats:
        formats.remove('plotly')
        Plotly(clf=False)

    if root:
        for fmt in formats:
            SaveFormat(root, fmt)
    if clf:
        Clf()
663
664
665
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to the file root.fmt.

    Args:
      root: string filename root
      fmt: string format, also used as the file extension

    Prints the filename before writing; the figure is saved at 300 dpi.
    """
    filename = '{}.{}'.format(root, fmt)
    print('Writing', filename)
    pyplot.savefig(filename, format=fmt, dpi=300)
675
676
677
# provide aliases for calling functions with lower-case names
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
fill_between = FillBetween
bar = Bar
scatter = Scatter
hexbin = HexBin
pdf = Pdf
pdfs = Pdfs
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
text = Text
config = Config
show = Show
save = Save
696
697
698
def main():
    """Prints the colors _Brewer.ColorGenerator yields for seven lines."""
    for color in _Brewer.ColorGenerator(7):
        print(color)


# run the color demo only when executed as a script, not on import
if __name__ == "__main__":
    main()
706
707