Contact
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96117
License: OTHER
1
"""This file contains code for use with "Think Stats",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2010 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
import math
9
import matplotlib
10
import matplotlib.pyplot as pyplot
11
import numpy as np
12
13
# customize some matplotlib attributes
14
#matplotlib.rc('figure', figsize=(4, 3))
15
16
#matplotlib.rc('font', size=14.0)
17
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
18
#matplotlib.rc('legend', fontsize=20.0)
19
20
#matplotlib.rc('xtick.major', size=6.0)
21
#matplotlib.rc('xtick.minor', size=3.0)
22
23
#matplotlib.rc('ytick.major', size=6.0)
24
#matplotlib.rc('ytick.minor', size=3.0)
25
26
27
class Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/
    """
    # Color iterator shared through the class (set by InitializeIter,
    # reset by ClearIter); None when no iterator is active.
    color_iter = None

    colors = ['#081D58',
              '#253494',
              '#225EA8',
              '#1D91C0',
              '#41B6C4',
              '#7FCDBB',
              '#C7E9B4',
              '#EDF8B1',
              '#FFFFD9']

    # lists that indicate which colors to use depending on how many are used:
    # which_colors[n] holds the indices into `colors` to use for n lines.
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    ]

    @classmethod
    def Colors(cls):
        """Returns the list of colors.
        """
        return cls.colors

    @classmethod
    def ColorGenerator(cls, n):
        """Returns an iterator of color strings.

        n: how many colors will be used. Values larger than the biggest
           predefined palette fall back to that palette; the iterator is
           then simply exhausted once its colors run out.

        The generator finishes by returning rather than by raising
        StopIteration: since PEP 479 a StopIteration raised inside a
        generator becomes RuntimeError, which would escape for-loops and
        callers that catch StopIteration.
        """
        # Clamp so that asking for more colors than we have a palette for
        # does not raise IndexError.
        n = min(n, len(cls.which_colors) - 1)
        for i in cls.which_colors[n]:
            yield cls.colors[i]

    @classmethod
    def InitializeIter(cls, num):
        """Initializes the color iterator with the given number of colors."""
        cls.color_iter = cls.ColorGenerator(num)

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None

    @classmethod
    def GetIter(cls):
        """Gets the color iterator."""
        return cls.color_iter
88
89
90
def PrePlot(num=None, rows=1, cols=1):
    """Takes hints about the plot that is about to be drawn.

    num: number of lines that will be plotted; when truthy, Brewer's
         color iterator is primed with that many colors
    rows, cols: subplot grid size; a grid is created (and recorded in the
         module globals SUBPLOT_ROWS / SUBPLOT_COLS) only when either
         exceeds 1
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        Brewer.InitializeIter(num)

    # TODO: get sharey and sharex working. probably means switching
    # to subplots instead of subplot.
    # also, get rid of the gray background.

    if not (rows > 1 or cols > 1):
        return

    pyplot.subplots(rows, cols, sharey=True)
    SUBPLOT_ROWS = rows
    SUBPLOT_COLS = cols
107
108
109
def SubPlot(rows, cols, plot_number):
    """Makes subplot number plot_number of a rows x cols grid current.

    rows: int, number of rows in the grid
    cols: int, number of columns in the grid
    plot_number: int, which cell to activate

    The axes object pyplot.subplot returns is discarded.
    """
    pyplot.subplot(rows, cols, plot_number)
117
118
119
class InfiniteList(list):
    """A list that answers every index lookup with one stored value.

    The underlying list itself is left empty; only item lookup is
    overridden.
    """

    def __init__(self, val):
        """Creates an empty list that remembers val.

        val: value to return for every index
        """
        list.__init__(self)
        self.val = val

    def __getitem__(self, index):
        """Ignores index and returns the stored value.

        index: int (its value is not used)

        returns: the stored value
        """
        stored = self.val
        return stored
137
138
139
def Underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.

    d: dictionary (modified in place when it is not None)
    options: keyword args to add to d

    returns: d, or the newly created dictionary
    """
    if d is None:
        d = {}

    # items() works on both Python 2 and 3; iteritems() is Python-2 only.
    for key, val in options.items():
        d.setdefault(key, val)

    return d
154
155
156
def Clf():
    """Resets Brewer's color iterator, then clears the current figure."""
    Brewer.ClearIter()
    pyplot.clf()
160
161
162
def Figure(**options):
    """Creates a new pyplot figure.

    options: keyword args for pyplot.figure; figsize defaults to (6, 8)
    """
    pyplot.figure(**Underride(options, figsize=(6, 8)))
166
167
168
def Plot(xs, ys, style='', **options):
    """Plots a line.

    When a Brewer color iterator is active (see PrePlot), the next color
    is drawn from it and used unless the caller passed `color` explicitly.
    Once the iterator is exhausted a warning is printed and the iterator
    is cleared.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot
    """
    color_iter = Brewer.GetIter()

    if color_iter:
        try:
            # The next() builtin works on Python 2.6+ and 3;
            # the .next() method exists only on Python 2.
            options = Underride(options, color=next(color_iter))
        except StopIteration:
            print('Warning: Brewer ran out of colors.')
            Brewer.ClearIter()

    options = Underride(options, linewidth=3, alpha=0.8)
    pyplot.plot(xs, ys, style, **options)
188
189
190
def Scatter(xs, ys, **options):
    """Draws a scatter plot of ys against xs.

    xs: x values
    ys: y values
    options: keyword args for pyplot.scatter; unless overridden, points
             are blue, alpha 0.2, size 30, with no edge color
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    pyplot.scatter(xs, ys, **Underride(options, **defaults))
200
201
202
def Pmf(pmf, **options):
    """Draws a Pmf or Hist as a line via Plot.

    Args:
      pmf: Hist or Pmf object; its Render() output supplies the points and
           its name (if any) becomes the default legend label
      options: keyword args passed to pyplot.plot
    """
    xs, ps = pmf.Render()
    label = pmf.name
    if label:
        options = Underride(options, label=label)
    Plot(xs, ps, **options)
213
214
215
def Pmfs(pmfs, **options):
    """Draws each PMF in pmfs with Pmf, sharing the same options.

    To give each PMF its own options, call Pmf once per PMF instead.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to pyplot.plot
    """
    for each_pmf in pmfs:
        Pmf(each_pmf, **options)
227
228
229
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is the minimum difference between
    adjacent values in the Hist. If that's too small, you can override
    it by providing a width keyword argument, in the same units as the
    values. When the Hist has fewer than two values there is no
    difference to take, so no width is passed and pyplot.bar's own
    default applies.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to pyplot.bar
    """
    xs, fs = hist.Render()

    if hist.name:
        options = Underride(options, label=hist.name)

    options = Underride(options, align='center', linewidth=0)

    # Derive a width only if the caller did not supply one; computing
    # min() of an empty difference list would raise ValueError.
    if 'width' not in options:
        diffs = Diff(xs)
        if diffs:
            options['width'] = min(diffs)

    pyplot.bar(xs, fs, **options)
254
255
256
def Hists(hists, **options):
    """Draws each Hist or Pmf in hists as a bar plot, one after another.

    The same options are used for every histogram. To give each one its
    own options, call Hist once per histogram instead.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to pyplot.bar (via Hist)
    """
    for each_hist in hists:
        Hist(each_hist, **options)
268
269
270
def Diff(t):
    """Returns the successive differences t[1]-t[0], t[2]-t[1], ...

    Args:
      t: sliceable sequence of numbers

    Returns:
      list of differences, one shorter than t (empty if t has < 2 items)
    """
    return [later - earlier for earlier, later in zip(t, t[1:])]
281
282
283
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line, optionally transformed.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot; xscale/yscale, if
               present, are removed and used as the starting axis scales

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()

    # Axis scales start linear; caller-supplied values replace them and are
    # taken out of options so they are not forwarded to pyplot.plot.
    scale = {'xscale': 'linear', 'yscale': 'linear'}
    for axis in ('xscale', 'yscale'):
        if axis in options:
            scale[axis] = options.pop(axis)

    # Exponential and Pareto are drawn as complementary CDFs, on log-y and
    # log-log axes respectively.
    if transform in ('exponential', 'pareto'):
        complement = True
        scale['yscale'] = 'log'
        if transform == 'pareto':
            scale['xscale'] = 'log'

    # The complement must be taken before the weibull/gumbel transforms below.
    if complement:
        ps = [1.0 - p for p in ps]

    if transform == 'weibull':
        # Drop the last point; presumably its p is 1.0, where log(1 - p)
        # is undefined. NOTE(review): confirm against Cdf.Render.
        xs.pop()
        ps.pop()
        ps = [-math.log(1.0 - p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'
    elif transform == 'gumbel':
        # Drop the first point; presumably its p is 0.0, where log(p)
        # is undefined. NOTE(review): confirm against Cdf.Render.
        xs.pop(0)
        ps.pop(0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    if cdf.name:
        options = Underride(options, label=cdf.name)

    Plot(xs, ps, **options)
    return scale
333
334
335
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Draws each CDF in cdfs with Cdf, sharing the same settings.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to pyplot.plot

    The scale dictionaries returned by Cdf are discarded.
    """
    for each_cdf in cdfs:
        Cdf(each_cdf, complement, transform, **options)
345
346
347
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot (pyplot.pcolormesh)
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolormesh, pyplot.contour
             and/or pyplot.imshow; linewidth defaults to 3 and cmap to
             matplotlib.cm.Blues
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    options = Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    # d.keys() works on Python 2 and 3; d.iterkeys() is Python-2 only.
    xs, ys = zip(*d.keys())
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    # Grid points (x, y) that are absent from d are plotted as 0.
    func = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = func(X, Y)

    # Label the x axis with plain values rather than an offset.
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = pyplot.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=extent, **options)
384
385
386
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor and/or contour plot over the grid xs by ys.

    xs: sequence of x coordinates
    ys: sequence of y coordinates
    zs: values on the grid; presumably shaped like np.meshgrid(xs, ys),
        i.e. (len(ys), len(xs)) — TODO confirm with callers
    pcolor: boolean, whether to make a pseudocolor plot (pyplot.pcolormesh)
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolormesh and/or pyplot.contour;
             linewidth defaults to 3 and cmap to matplotlib.cm.Blues
    """
    options = Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # Label the x axis with plain values rather than an offset.
    pyplot.gca().xaxis.set_major_formatter(
        matplotlib.ticker.ScalarFormatter(useOffset=False))

    if pcolor:
        pyplot.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contours = pyplot.contour(grid_x, grid_y, zs, **options)
        pyplot.clabel(contours, inline=1, fontsize=10)
411
412
413
def Config(**options):
    """Applies plot settings found in options.

    Each of title, xlabel, ylabel, xscale, yscale, xticks, yticks and axis
    that is present is passed to the pyplot function of the same name.
    A legend is drawn unless legend is falsy; its location is taken from
    loc (default 0).
    """
    for name in ('title', 'xlabel', 'ylabel', 'xscale', 'yscale',
                 'xticks', 'yticks', 'axis'):
        if name not in options:
            continue
        setter = getattr(pyplot, name)
        setter(options[name])

    if options.get('legend', True):
        pyplot.legend(loc=options.get('loc', 0))
430
431
432
def Show(**options):
    """Applies Config(**options), then displays the figure with pyplot.show.

    options: keyword args forwarded to Config
    """
    # TODO: figure out how to show more than one plot
    Config(**options)
    pyplot.show()
442
443
444
def Save(root=None, formats=None, **options):
    """Configures the plot, writes it in each format, then clears it.

    Args:
      root: string filename root; nothing is written when it is falsy
      formats: list of string formats, default ['pdf', 'eps']
      options: keyword args forwarded to Config
    """
    Config(**options)

    targets = ['pdf', 'eps'] if formats is None else formats
    if root:
        for fmt in targets:
            SaveFormat(root, fmt)

    # The figure is cleared whether or not anything was written.
    Clf()
463
464
465
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to a file in the given format.

    The file is named '<root>.<fmt>' and saved at 300 dpi.

    Args:
      root: string filename root
      fmt: string format (also used as the file extension)
    """
    filename = '%s.%s' % (root, fmt)
    # A single pre-formatted argument prints identically on Python 2 and 3;
    # the print statement form is a SyntaxError on Python 3.
    print('Writing %s' % filename)
    pyplot.savefig(filename, format=fmt, dpi=300)
475
476
477
# Provide lower-case aliases so the plotting functions can be called
# with either naming convention.

# figure setup
preplot = PrePlot
subplot = SubPlot
figure = Figure
clf = Clf

# drawing
plot = Plot
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor

# helpers and output
diff = Diff
config = Config
show = Show
save = Save
496
497
498
def main():
    """Prints the colors Brewer.ColorGenerator yields for seven lines."""
    color_iter = Brewer.ColorGenerator(7)
    for color in color_iter:
        # Function-call form works on Python 2 and 3 for a single argument.
        print(color)


if __name__ == '__main__':
    main()
505
506