Contact
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96161
License: OTHER
1
"""This file contains code for use with "Think Stats",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2010 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
import math
9
import matplotlib
10
import matplotlib.pyplot as pyplot
11
import numpy as np
12
13
# customize some matplotlib attributes
14
#matplotlib.rc('figure', figsize=(4, 3))
15
16
#matplotlib.rc('font', size=14.0)
17
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
18
#matplotlib.rc('legend', fontsize=20.0)
19
20
#matplotlib.rc('xtick.major', size=6.0)
21
#matplotlib.rc('xtick.minor', size=3.0)
22
23
#matplotlib.rc('ytick.major', size=6.0)
24
#matplotlib.rc('ytick.minor', size=3.0)
25
26
27
class Brewer(object):
    """Encapsulates a nice sequence of colors.

    Shades of blue that look good in color and can be distinguished
    in grayscale (up to a point).

    Borrowed from http://colorbrewer2.org/

    The class is used as a namespace: all state lives in class attributes
    and every method is a classmethod.
    """
    # Color iterator shared by all callers; None means no hint has been set.
    color_iter = None

    colors = ['#081D58',
              '#253494',
              '#225EA8',
              '#1D91C0',
              '#41B6C4',
              '#7FCDBB',
              '#C7E9B4',
              '#EDF8B1',
              '#FFFFD9']

    # lists that indicate which colors to use depending on how many are used;
    # which_colors[n] holds the indices into `colors` to use for n lines.
    which_colors = [[],
                    [1],
                    [1, 3],
                    [0, 2, 4],
                    [0, 2, 4, 6],
                    [0, 2, 3, 5, 6],
                    [0, 2, 3, 4, 5, 6],
                    [0, 1, 2, 3, 4, 5, 6],
                    ]

    @classmethod
    def Colors(cls):
        """Returns the list of colors."""
        return cls.colors

    @classmethod
    def ColorGenerator(cls, n):
        """Returns an iterator of color strings.

        n: how many colors will be used.  Values larger than the biggest
           entry in which_colors use that entry, so the iterator runs out
           early (which Plot reports as a warning) instead of raising
           IndexError.

        The generator simply ends when its colors are exhausted; a caller
        using next() then gets StopIteration.
        """
        # Clamp large n so the table lookup cannot raise IndexError.
        index = min(n, len(cls.which_colors) - 1)
        for i in cls.which_colors[index]:
            yield cls.colors[i]
        # Falling off the end (rather than raising StopIteration) is required
        # by PEP 479: an explicit raise inside a generator is turned into
        # RuntimeError on Python 3.7+.

    @classmethod
    def InitializeIter(cls, num):
        """Initializes the color iterator with the given number of colors."""
        cls.color_iter = cls.ColorGenerator(num)

    @classmethod
    def ClearIter(cls):
        """Sets the color iterator to None."""
        cls.color_iter = None

    @classmethod
    def GetIter(cls):
        """Gets the color iterator (None if none has been initialized)."""
        return cls.color_iter
88
89
90
def PrePlot(num=None, rows=1, cols=1):
    """Records hints about the plot that is about to be drawn.

    num: number of lines that will be plotted; when truthy, the Brewer
         color iterator is initialized with that many colors
    rows: number of subplot rows
    cols: number of subplot columns

    When the layout has more than one cell, a grid of subplots is created
    and its shape is stored in the module globals SUBPLOT_ROWS and
    SUBPLOT_COLS for SubPlot to use.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        Brewer.InitializeIter(num)

    # TODO: get sharey and sharex working.  probably means switching
    # to subplots instead of subplot.
    # also, get rid of the gray background.

    single_cell = rows <= 1 and cols <= 1
    if single_cell:
        return

    pyplot.subplots(rows, cols, sharey=True)
    SUBPLOT_ROWS, SUBPLOT_COLS = rows, cols
107
108
109
def SubPlot(plot_number):
    """Makes subplot `plot_number` of the grid set up by PrePlot current.

    plot_number: 1-based index of the subplot within the grid

    PrePlot only defines the SUBPLOT_ROWS/SUBPLOT_COLS globals when it is
    asked for more than one cell.  If they were never set, fall back to a
    1x1 grid instead of failing with NameError.
    """
    rows = globals().get('SUBPLOT_ROWS', 1)
    cols = globals().get('SUBPLOT_COLS', 1)
    pyplot.subplot(rows, cols, plot_number)
111
112
113
class InfiniteList(list):
    """A list that returns the same value for all indices.

    Only indexing is overridden.  The underlying list storage stays empty,
    so len() is 0 and iterating yields nothing.
    """

    def __init__(self, val):
        """Creates the list.

        val: the value returned for every index
        """
        super(InfiniteList, self).__init__()
        self.val = val

    def __getitem__(self, index):
        """Returns the stored value; `index` is ignored."""
        return self.val
131
132
133
def Underride(d, **options):
    """Add key-value pairs to d only if key is not in d.

    If d is None, create a new dictionary.

    d: dictionary (modified in place), or None
    options: keyword args to add to d

    returns: d, or the newly created dictionary
    """
    if d is None:
        d = {}

    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    for key, val in options.items():
        d.setdefault(key, val)

    return d
148
149
150
def Clf():
    """Resets plotting state.

    Drops any Brewer color hint set by PrePlot, then clears the current
    pyplot figure.
    """
    Brewer.ClearIter()
    pyplot.clf()
154
155
156
def Figure(**options):
    """Creates a new figure via pyplot.figure.

    options: keyword args passed to pyplot.figure; figsize defaults to
             (6, 8) unless the caller supplies one
    """
    pyplot.figure(**Underride(options, figsize=(6, 8)))
160
161
162
def Plot(xs, ys, style='', **options):
    """Plots a line.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot

    If PrePlot set up a Brewer color iterator, the next color from it is
    used unless the caller passed `color`.  When the iterator is exhausted,
    a warning is printed and the iterator is cleared, so later calls inject
    no color.  linewidth defaults to 3 and alpha to 0.8.
    """
    color_iter = Brewer.GetIter()

    if color_iter is not None:
        try:
            # The next() builtin works on Python 2.6+ and 3; the .next()
            # method exists only on Python 2.
            options = Underride(options, color=next(color_iter))
        except StopIteration:
            # Single-argument print gives the same output on Python 2 and 3.
            print('Warning: Brewer ran out of colors.')
            Brewer.ClearIter()

    options = Underride(options, linewidth=3, alpha=0.8)
    pyplot.plot(xs, ys, style, **options)
182
183
184
def Scatter(xs, ys, **options):
    """Makes a scatter plot.

    xs: x values
    ys: y values
    options: options passed to pyplot.scatter; color defaults to 'blue',
             alpha to 0.2, s to 30 and edgecolors to 'none'
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    options = Underride(options, **defaults)
    pyplot.scatter(xs, ys, **options)
194
195
196
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line.

    Args:
      pmf: Hist or Pmf object; must provide Render() returning (xs, ps)
           and a `name` attribute
      options: keyword args passed to Plot (and on to pyplot.plot); a
               non-empty pmf.name becomes the default label
    """
    points = pmf.Render()
    label = pmf.name
    if label:
        options = Underride(options, label=label)
    Plot(*points, **options)
207
208
209
def Pmfs(pmfs, **options):
    """Plots each PMF in a sequence as a line.

    The same options are applied to every PMF.  For per-PMF options,
    call Pmf once per object instead.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to pyplot.plot (via Pmf)
    """
    for each_pmf in pmfs:
        Pmf(each_pmf, **options)
221
222
223
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    Bars are centered on the values.  Unless the caller passes `width`,
    the bar width is the smallest gap between adjacent values, or 1 when
    there are fewer than two values and so no gap to measure.

    Args:
      hist: Hist or Pmf object; must provide Render() returning (xs, fs)
            and a `name` attribute
      options: keyword args passed to pyplot.bar
    """
    xs, fs = hist.Render()

    if hist.name:
        options = Underride(options, label=hist.name)

    if 'width' not in options:
        # find the minimum distance between adjacent values (presumably
        # Render returns xs sorted ascending -- TODO confirm).  min() of an
        # empty list raises ValueError, so a lone value gets width 1.
        gaps = Diff(xs)
        options['width'] = min(gaps) if gaps else 1

    options = Underride(options, align='center', linewidth=0)

    pyplot.bar(xs, fs, **options)
243
244
245
def Hists(hists, **options):
    """Plots a sequence of histograms as bar plots, one Hist call each.

    The bars are drawn on the same axes at the histograms' own values,
    with no offset, so histograms that share values overlap rather than
    interleave.

    Options are passed along for all histograms.  If you want different
    options for each one, make multiple calls to Hist.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed to pyplot.bar (via Hist)
    """
    for hist in hists:
        Hist(hist, **options)
257
258
259
def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
      t: sequence of numbers

    Returns:
      list of t[i+1] - t[i]; one element shorter than t, or empty when t
      has fewer than two elements
    """
    return [later - earlier for earlier, later in zip(t, t[1:])]
270
271
272
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object; must provide Render() returning (xs, ps) and a
           `name` attribute
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull',
                 'gumbel'; any other value (including None) leaves both
                 axes linear and the values untransformed
      options: keyword args passed to pyplot.plot

    'exponential' and 'pareto' force the complementary CDF.  'weibull'
    plots -log(1 - p) and 'gumbel' plots -log(p); both are applied after
    the complement step, if there is one.

    Returns:
      dictionary with the scale options that should be passed to
      myplot.Save or myplot.Show
    """
    xs, ps = cdf.Render()
    scale = {'xscale': 'linear', 'yscale': 'linear'}

    if transform in ('exponential', 'pareto'):
        # Both plot the complementary CDF on a log y axis; Pareto also
        # puts the x axis on a log scale.
        complement = True
        scale['yscale'] = 'log'
        if transform == 'pareto':
            scale['xscale'] = 'log'

    if complement:
        ps = [1.0 - p for p in ps]

    if transform == 'weibull':
        # Drop the last point before taking logs -- presumably it has
        # p == 1.0, where log(1 - p) is undefined (TODO confirm Render).
        xs.pop()
        ps.pop()
        ps = [-math.log(1.0 - p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'
    elif transform == 'gumbel':
        # Drop the first point -- presumably p == 0.0, where log(p) is
        # undefined (TODO confirm Render).
        xs.pop(0)
        ps.pop(0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    if cdf.name:
        options = Underride(options, label=cdf.name)

    Plot(xs, ps, **options)
    return scale
318
319
320
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots a sequence of CDFs.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to pyplot.plot

    Returns:
      the scale dictionary from the last Cdf call, suitable for passing
      to Save or Show; linear scales if cdfs is empty.  (Previously the
      scale was discarded and None returned.)
    """
    scale = dict(xscale='linear', yscale='linear')
    for cdf in cdfs:
        scale = Cdf(cdf, complement, transform, **options)
    return scale
330
331
332
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolormesh, pyplot.contour
             and/or pyplot.imshow; linewidth defaults to 3 and cmap to
             matplotlib.cm.Blues

    Grid points (x, y) that are missing from the map are plotted as 0.

    Raises:
      ValueError: if the map is empty
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        d = obj

    if not d:
        raise ValueError('Contour requires a non-empty map')

    Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    # Iterating a dict yields its keys on both Python 2 and 3
    # (iterkeys() is Python 2 only).
    xs, ys = zip(*d)
    xs = sorted(set(xs))
    ys = sorted(set(ys))

    X, Y = np.meshgrid(xs, ys)
    func = np.vectorize(lambda x, y: d.get((x, y), 0))
    Z = func(X, Y)

    # Label x ticks with raw values rather than relative to an offset.
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    axes = pyplot.gca()
    axes.xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        cs = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(cs, inline=1, fontsize=10)
    if imshow:
        extent = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=extent, **options)
369
370
371
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: sequence of x coordinates (grid columns)
    ys: sequence of y coordinates (grid rows)
    zs: 2-D array of values, presumably shaped (len(ys), len(xs)) to
        match np.meshgrid(xs, ys) -- TODO confirm with callers
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolormesh and/or
             pyplot.contour; linewidth defaults to 3 and cmap to
             matplotlib.cm.Blues
    """
    Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    grid_x, grid_y = np.meshgrid(xs, ys)

    # Label x ticks with raw values rather than relative to an offset.
    plain_ticks = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(plain_ticks)

    if pcolor:
        pyplot.pcolormesh(grid_x, grid_y, zs, **options)

    if contour:
        contours = pyplot.contour(grid_x, grid_y, zs, **options)
        pyplot.clabel(contours, inline=1, fontsize=10)
396
397
398
def Config(**options):
    """Applies figure-level settings from keyword options.

    Always sets the title, xlabel and ylabel (each defaults to '').
    If present, xscale, xticks, yscale, yticks and axis are applied in
    that order; each scale is set before the matching ticks.  A legend
    is drawn at position `loc` (default 0) unless legend is falsy.
    """
    pyplot.title(options.get('title', ''))
    pyplot.xlabel(options.get('xlabel', ''))
    pyplot.ylabel(options.get('ylabel', ''))

    # Order matters: each axis scale is applied before its ticks.
    for setting in ('xscale', 'xticks', 'yscale', 'yticks', 'axis'):
        if setting in options:
            getattr(pyplot, setting)(options[setting])

    loc = options.get('loc', 0)
    if options.get('legend', True):
        pyplot.legend(loc=loc)
433
434
435
def Show(**options):
    """Configures the current figure and displays it.

    options: keyword args forwarded to Config (title, xlabel, ylabel,
             xscale, yscale, xticks, yticks, axis, legend, loc)
    """
    # TODO: figure out how to show more than one plot
    Config(**options)
    pyplot.show()
445
446
447
def Save(root=None, formats=None, **options):
    """Configures the current figure, writes it to files, then clears it.

    For options, see Config.

    Args:
      root: string filename root; if falsy, no files are written but the
            figure is still cleared
      formats: list of string formats; defaults to ['pdf', 'eps']
      options: keyword args used to invoke various pyplot functions
    """
    Config(**options)

    if root:
        chosen = ['pdf', 'eps'] if formats is None else formats
        for fmt in chosen:
            SaveFormat(root, fmt)

    Clf()
466
467
468
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to a file in the given format.

    The file is named '<root>.<fmt>' and saved at 300 dpi; a progress
    line is printed before writing.

    Args:
      root: string filename root
      fmt: string format
    """
    filename = '%s.%s' % (root, fmt)
    # One pre-formatted argument prints identical text under the Python 2
    # print statement and the Python 3 print function.
    print('Writing %s' % filename)
    pyplot.savefig(filename, format=fmt, dpi=300)
478
479
480
# Provide aliases for calling functions with lower-case names.

# Figure setup and teardown.
preplot, subplot, clf, figure = PrePlot, SubPlot, Clf, Figure

# Lines, points and distributions.
plot, scatter = Plot, Scatter
pmf, pmfs = Pmf, Pmfs
hist, hists, diff = Hist, Hists, Diff
cdf, cdfs = Cdf, Cdfs

# Two-dimensional plots.
contour, pcolor = Contour, Pcolor

# Configuration and output.
config, show, save = Config, Show, Save
499
500
501
def main():
    """Prints, one per line, the Brewer colors used for a seven-line plot."""
    color_iter = Brewer.ColorGenerator(7)
    for color in color_iter:
        # Single-argument print gives the same output on Python 2 and 3.
        print(color)


if __name__ == '__main__':
    main()
508
509