Contact
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
| Download

📚 The CoCalc Library - books, templates and other resources

Views: 96160
License: OTHER
1
"""This file contains code used in "Think Bayes",
2
by Allen B. Downey, available from greenteapress.com
3
4
Copyright 2012 Allen B. Downey
5
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
6
"""
7
8
import matplotlib.pyplot as pyplot
9
import thinkplot
10
import numpy
11
12
import csv
13
import random
14
import shelve
15
import sys
16
import time
17
18
import thinkbayes
19
20
import warnings
21
22
warnings.simplefilter('error', RuntimeWarning)
23
24
25
FORMATS = ['pdf', 'eps', 'png']
26
27
28
class Locker(object):
    """Encapsulates a shelf for storing key-value pairs.

    Keys are coerced to strings because the underlying shelve
    module only supports string keys.
    """

    def __init__(self, shelf_file):
        """Opens (or creates) the backing shelf file.

        shelf_file: string filename passed to shelve.open
        """
        self.shelf = shelve.open(shelf_file)

    def Close(self):
        """Closes the shelf."""
        self.shelf.close()

    def Add(self, key, value):
        """Adds a key-value pair."""
        self.shelf[str(key)] = value

    def Lookup(self, key):
        """Looks up a key; returns None if it is absent."""
        return self.shelf.get(str(key))

    def Keys(self):
        """Returns an iterator of keys."""
        # iter() works under both Python 2 and 3; the original called
        # the Python-2-only Shelf.iterkeys().
        return iter(self.shelf)

    def Read(self):
        """Returns the contents of the shelf as a map."""
        return dict(self.shelf)
54
55
56
class Subject(object):
    """Represents a subject from the belly button study."""

    def __init__(self, code):
        """
        code: string ID
        species: sequence of (int count, string species) pairs
        """
        self.code = code
        # list of (count, species-name) pairs; sorted by Done()
        self.species = []
        # filled in by Process(): posterior suite over n and prevalences
        self.suite = None
        # filled in by Done(): totals for this (possibly rarefied) sample
        self.num_reads = None
        self.num_species = None
        # filled in by Match(): totals from the matching complete subject
        self.total_reads = None
        self.total_species = None
        self.prev_unseen = None
        self.pmf_n = None
        self.pmf_q = None
        # filled in by MakePrediction(): distribution of additional species
        self.pmf_l = None

    def Add(self, species, count):
        """Add a species-count pair.

        It is up to the caller to ensure that species names are unique.

        species: string species/genus name
        count: int number of individuals
        """
        self.species.append((count, species))

    def Done(self, reverse=False, clean_param=0):
        """Called when we are done adding species counts.

        reverse: which order to sort in
        clean_param: if nonzero, passed to Clean() before sorting
        """
        if clean_param:
            self.Clean(clean_param)

        self.species.sort(reverse=reverse)
        counts = self.GetCounts()
        self.num_species = len(counts)
        self.num_reads = sum(counts)

    def Clean(self, clean_param=50):
        """Identifies and removes bogus data.

        clean_param: parameter that controls the number of legit species
        """
        def prob_bogus(k, r):
            """Compute the probability that a species is bogus."""
            q = clean_param / r
            p = (1-q) ** k
            return p

        print self.code, clean_param

        counts = self.GetCounts()
        r = 1.0 * sum(counts)

        # keep each species with probability 1 - prob_bogus(count, total)
        species_seq = []
        for k, species in sorted(self.species):

            if random.random() < prob_bogus(k, r):
                continue
            species_seq.append((k, species))
        self.species = species_seq

    def GetM(self):
        """Gets number of observed species."""
        return len(self.species)

    def GetCounts(self):
        """Gets the list of species counts.

        Should be in increasing order, if Sort() has been invoked.
        """
        return [count for count, _ in self.species]

    def MakeCdf(self):
        """Makes a CDF of total prevalence vs rank."""
        counts = self.GetCounts()
        counts.sort(reverse=True)
        cdf = thinkbayes.MakeCdfFromItems(enumerate(counts))
        return cdf

    def GetNames(self):
        """Gets the names of the seen species."""
        return [name for _, name in self.species]

    def PrintCounts(self):
        """Prints the counts and species names."""
        for count, name in reversed(self.species):
            print count, name

    def GetSpecies(self, index):
        """Gets the count and name of the indicated species.

        Returns: count-species pair
        """
        return self.species[index]

    def GetCdf(self):
        """Returns cumulative prevalence vs number of species.
        """
        counts = self.GetCounts()
        items = enumerate(counts)
        cdf = thinkbayes.MakeCdfFromItems(items)
        return cdf

    def GetPrevalences(self):
        """Returns a sequence of prevalences (normalized counts).
        """
        counts = self.GetCounts()
        total = sum(counts)
        prevalences = numpy.array(counts, dtype=numpy.float) / total
        return prevalences

    def Process(self, low=None, high=500, conc=1, iters=100):
        """Computes the posterior distribution of n and the prevalences.

        Sets attribute: self.suite

        low: minimum number of species
        high: maximum number of species
        conc: concentration parameter
        iters: number of iterations to use in the estimator
        """
        counts = self.GetCounts()
        m = len(counts)
        if low is None:
            # n cannot be less than the number of species already seen
            low = max(m, 2)
        ns = range(low, high+1)

        #start = time.time()
        self.suite = Species5(ns, conc=conc, iters=iters)
        self.suite.Update(counts)
        #end = time.time()

        #print 'Processing time' end-start

    def MakePrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Precondition: Process has run

        num_sims: how many simulations to run for predictions

        Adds attributes
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        curves = self.RunSimulations(num_sims, add_reads)
        self.pmf_l = self.MakePredictive(curves)

    def MakeQuickPrediction(self, num_sims=100):
        """Make predictions for the given subject.

        Faster than MakePrediction: only tracks the final number of new
        species, not the whole rarefaction curve.

        Precondition: Process has run

        num_sims: how many simulations to run for predictions

        Adds attribute:
        pmf_l: predictive distribution of additional species
        """
        add_reads = self.total_reads - self.num_reads
        pmf = thinkbayes.Pmf()
        _, seen = self.GetSeenSpecies()

        for _ in range(num_sims):
            _, observations = self.GenerateObservations(add_reads)
            all_seen = seen.union(observations)
            l = len(all_seen) - len(seen)
            pmf.Incr(l)

        pmf.Normalize()
        self.pmf_l = pmf

    def DistL(self):
        """Returns the distribution of additional species, l.
        """
        return self.pmf_l

    def MakeFigures(self):
        """Makes figures showing distribution of n and the prevalences."""
        self.PlotDistN()
        self.PlotPrevalences()

    def PlotDistN(self):
        """Plots distribution of n."""
        pmf = self.suite.DistN()
        print '90% CI for N:', pmf.CredibleInterval(90)
        pmf.name = self.code

        thinkplot.Clf()
        thinkplot.PrePlot(num=1)

        thinkplot.Pmf(pmf)

        root = 'species-ndist-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Number of species',
                       ylabel='Prob',
                       formats=FORMATS,
                       )

    def PlotPrevalences(self, num=5):
        """Plots dist of prevalence for several species.

        num: how many species (starting with the highest prevalence)
        """
        thinkplot.Clf()
        thinkplot.PrePlot(num=5)

        for rank in range(1, num+1):
            self.PlotPrevalence(rank)

        root = 'species-prev-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 1],
                       )

    def PlotPrevalence(self, rank=1, cdf_flag=True):
        """Plots dist of prevalence for one species.

        rank: rank order of the species to plot.
        cdf_flag: whether to plot the CDF
        """
        # convert rank to index (species are sorted in increasing order,
        # so rank 1 is the last element)
        index = self.GetM() - rank

        _, mix = self.suite.DistOfPrevalence(index)
        count, _ = self.GetSpecies(index)
        mix.name = '%d (%d)' % (rank, count)

        print '90%% CI for prevalence of species %d:' % rank,
        print mix.CredibleInterval(90)

        if cdf_flag:
            cdf = mix.MakeCdf()
            thinkplot.Cdf(cdf)
        else:
            thinkplot.Pmf(mix)

    def PlotMixture(self, rank=1):
        """Plots dist of prevalence for all n, and the mix.

        rank: rank order of the species to plot
        """
        # convert rank to index
        index = self.GetM() - rank

        print self.GetSpecies(index)
        print self.GetCounts()[index]

        metapmf, mix = self.suite.DistOfPrevalence(index)

        thinkplot.Clf()
        for pmf in metapmf.Values():
            thinkplot.Pmf(pmf, color='blue', alpha=0.2, linewidth=0.5)

        thinkplot.Pmf(mix, color='blue', alpha=0.9, linewidth=2)

        root = 'species-mix-%s' % self.code
        thinkplot.Save(root=root,
                       xlabel='Prevalence',
                       ylabel='Prob',
                       formats=FORMATS,
                       axis=[0, 0.3, 0, 0.3],
                       legend=False)

    def GetSeenSpecies(self):
        """Makes a set of the names of seen species.

        Returns: number of species, set of string species names
        """
        names = self.GetNames()
        m = len(names)
        seen = set(SpeciesGenerator(names, m))
        return m, seen

    def GenerateObservations(self, num_reads):
        """Generates a series of random observations.

        num_reads: number of reads to generate

        Returns: number of species, sequence of string species names
        """
        # draw (n, prevalences) from the posterior, then sample reads
        # from the implied distribution over species names
        n, prevalences = self.suite.SamplePosterior()

        names = self.GetNames()
        name_iter = SpeciesGenerator(names, n)

        items = zip(name_iter, prevalences)

        cdf = thinkbayes.MakeCdfFromItems(items)
        observations = cdf.Sample(num_reads)

        #for ob in observations:
        #    print ob

        return n, observations

    def Resample(self, num_reads):
        """Choose a random subset of the data (without replacement).

        num_reads: number of reads in the subset

        Returns: new Subject with the resampled counts
        """
        # expand counts into one entry per read, then shuffle and truncate
        t = []
        for count, species in self.species:
            t.extend([species]*count)

        random.shuffle(t)
        reads = t[:num_reads]

        subject = Subject(self.code)
        hist = thinkbayes.MakeHistFromList(reads)
        for species, count in hist.Items():
            subject.Add(species, count)

        subject.Done()
        return subject

    def Match(self, match):
        """Match up a rarefied subject with a complete subject.

        match: complete Subject

        Assigns attributes:
        total_reads:
        total_species:
        prev_unseen:
        """
        self.total_reads = match.num_reads
        self.total_species = match.num_species

        # compute the prevalence of unseen species (at least approximately,
        # based on all species counts in match
        _, seen = self.GetSeenSpecies()

        seen_total = 0.0
        unseen_total = 0.0
        for count, species in match.species:
            if species in seen:
                seen_total += count
            else:
                unseen_total += count

        self.prev_unseen = unseen_total / (seen_total + unseen_total)

    def RunSimulation(self, num_reads, frac_flag=False, jitter=0.01):
        """Simulates additional observations and returns a rarefaction curve.

        k is the number of additional observations
        num_new is the number of new species seen

        num_reads: how many new reads to simulate
        frac_flag: whether to convert to fraction of species seen
        jitter: size of jitter added if frac_flag is true

        Returns: list of (k, num_new) pairs
        """
        m, seen = self.GetSeenSpecies()
        n, observations = self.GenerateObservations(num_reads)

        curve = []
        for i, obs in enumerate(observations):
            seen.add(obs)

            if frac_flag:
                frac_seen = len(seen) / float(n)
                # jitter separates overlapping curves when plotted
                frac_seen += random.uniform(-jitter, jitter)
                curve.append((i+1, frac_seen))
            else:
                num_new = len(seen) - m
                curve.append((i+1, num_new))

        return curve

    def RunSimulations(self, num_sims, num_reads, frac_flag=False):
        """Runs simulations and returns a list of curves.

        Each curve is a sequence of (k, num_new) pairs.

        num_sims: how many simulations to run
        num_reads: how many samples to generate in each simulation
        frac_flag: whether to convert num_new to fraction of total
        """
        curves = [self.RunSimulation(num_reads, frac_flag)
                  for _ in range(num_sims)]
        return curves

    def MakePredictive(self, curves):
        """Makes a predictive distribution of additional species.

        curves: list of (k, num_new) curves

        Returns: Pmf of num_new
        """
        pred = thinkbayes.Pmf(name=self.code)
        for curve in curves:
            # only the final point of each curve matters here
            _, last_num_new = curve[-1]
            pred.Incr(last_num_new)
        pred.Normalize()
        return pred
463
464
465
def MakeConditionals(curves, ks):
    """Makes Cdfs of the distribution of num_new conditioned on k.

    Prints the 90% credible interval for each k.

    curves: list of (k, num_new) curves
    ks: list of values of k

    Returns: list of Cdfs
    """
    joint = MakeJointPredictive(curves)

    cdfs = []
    for k in ks:
        # condition the joint (k, num_new) distribution on k; axis 1 is
        # num_new, axis 0 is k
        pmf = joint.Conditional(1, 0, k)
        pmf.name = 'k=%d' % k
        cdf = pmf.MakeCdf()
        cdfs.append(cdf)
        print '90%% credible interval for %d' % k,
        print cdf.CredibleInterval(90)
    return cdfs
484
485
486
def MakeJointPredictive(curves):
    """Makes a joint distribution of k and num_new.

    Each (k, num_new) point on every curve contributes one increment.

    curves: list of (k, num_new) curves

    Returns: joint Pmf of (k, num_new)
    """
    joint = thinkbayes.Joint()
    pairs = [pair for curve in curves for pair in curve]
    for pair in pairs:
        joint.Incr(pair)
    joint.Normalize()
    return joint
499
500
501
def MakeFracCdfs(curves, ks):
    """Makes Cdfs of the fraction of species seen.

    curves: list of (k, frac) curves
    ks: collection of values of k to keep

    Returns: map from k to Cdf of fraction seen after k samples
    """
    # a set gives O(1) membership tests inside the loop
    ks = set(ks)

    d = {}
    for curve in curves:
        for k, frac in curve:
            if k in ks:
                d.setdefault(k, []).append(frac)

    cdfs = {}
    # items() is portable; the original used Python-2-only iteritems()
    for k, fracs in d.items():
        cdf = thinkbayes.MakeCdfFromList(fracs)
        cdfs[k] = cdf

    return cdfs
520
521
def SpeciesGenerator(names, num):
    """Generates a series of names, starting with the given names.

    Additional names are 'unseen' plus a serial number.

    names: list of strings
    num: total number of species names to generate

    Returns: string iterator
    """
    count = 0
    for name in names:
        count += 1
        yield name

    # pad out to num with placeholder names for unseen species
    for serial in range(count, num):
        yield 'unseen-%d' % serial
539
540
541
def ReadRarefactedData(filename='journal.pone.0047712.s001.csv',
                       clean_param=0):
    """Reads a data file and returns a list of Subjects.

    Data from http://www.plosone.org/article/
    info%3Adoi%2F10.1371%2Fjournal.pone.0047712#s4

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: map from code to Subject
    """
    subject = Subject('')
    subject_map = {}

    # 'with' closes the file even on error; the original leaked the handle
    with open(filename) as fp:
        reader = csv.reader(fp)
        # skip the header row; next() is portable, reader.next() is py2-only
        _ = next(reader)

        i = 0
        for t in reader:
            code = t[0]
            if code != subject.code:
                # start a new subject
                subject = Subject(code)
                subject_map[code] = subject

            # append a number to the species names so they're unique
            species = t[1]
            species = '%s-%d' % (species, i)
            i += 1

            count = int(t[2])
            subject.Add(species, count)

    for code, subject in subject_map.items():
        subject.Done(clean_param=clean_param)

    return subject_map
580
581
582
def ReadCompleteDataset(filename='BBB_data_from_Rob.csv', clean_param=0):
    """Reads a data file and returns a list of Subjects.

    Data from personal correspondence with Rob Dunn, received 2-7-13.
    Converted from xlsx to csv.

    filename: string filename to read
    clean_param: parameter passed to Clean

    Returns: map from code to Subject
    """
    # 'with' closes the file even on error; the original leaked the handle
    with open(filename) as fp:
        reader = csv.reader(fp)
        # skip the first header row, keep the second (subject codes);
        # next() is portable, reader.next() is py2-only
        header = next(reader)
        header = next(reader)

        subject_codes = header[1:-1]
        subject_codes = ['B'+code for code in subject_codes]

        # create the subject map
        uber_subject = Subject('uber')
        subject_map = {}
        for code in subject_codes:
            subject_map[code] = Subject(code)

        # read lines
        i = 0
        for t in reader:
            otu_code = t[0]
            if otu_code == '':
                continue

            # pull out a species name and give it a number
            otu_names = t[-1]
            taxons = otu_names.split(';')
            species = taxons[-1]
            species = '%s-%d' % (species, i)
            i += 1

            counts = [int(x) for x in t[1:-1]]

            # print otu_code, species

            for code, count in zip(subject_codes, counts):
                if count > 0:
                    subject_map[code].Add(species, count)
                    uber_subject.Add(species, count)

    uber_subject.Done(clean_param=clean_param)
    for code, subject in subject_map.items():
        subject.Done(clean_param=clean_param)

    return subject_map, uber_subject
635
636
637
def JoinSubjects():
    """Reads both datasets and computes their inner join.

    Finds all subjects that appear in both datasets.

    For subjects in the rarefacted dataset, looks up the total
    number of reads and stores it as total_reads.  num_reads
    is normally 400.

    Returns: map from code to Subject
    """

    # read the rarefacted dataset
    sampled_subjects = ReadRarefactedData()

    # read the complete dataset
    all_subjects, _ = ReadCompleteDataset()

    # items() is portable; the original used Python-2-only iteritems()
    for code, subject in sampled_subjects.items():
        if code in all_subjects:
            match = all_subjects[code]
            subject.Match(match)

    return sampled_subjects
661
662
663
def JitterCurve(curve, dx=0.2, dy=0.3):
    """Adds random noise to the pairs in a curve.

    dx and dy control the amplitude of the noise in each dimension.
    """
    jittered = []
    for x, y in curve:
        # draw x-noise first, then y-noise, to keep the RNG call order
        jittered.append((x + random.uniform(-dx, dx),
                         y + random.uniform(-dy, dy)))
    return jittered
671
672
673
def OffsetCurve(curve, i, n, dx=0.3, dy=0.3):
    """Adds a deterministic offset to the pairs in a curve.

    i is the index of the curve
    n is the number of curves

    dx and dy control the amplitude of the offset in each dimension.
    """
    # offsets are spread evenly from -dx..dx (and -dy..dy) across curves
    shift_x = -dx + 2 * dx * i / (n-1)
    shift_y = -dy + 2 * dy * i / (n-1)

    shifted = []
    for x, y in curve:
        shifted.append((x + shift_x, y + shift_y))
    return shifted
685
686
687
def PlotCurves(curves, root='species-rare'):
    """Plots a set of curves.

    curves is a list of curves; each curve is a list of (x, y) pairs.
    """
    thinkplot.Clf()
    color = '#225EA8'

    total = len(curves)
    for index, curve in enumerate(curves):
        # offset each curve slightly so overlapping curves stay visible
        shifted = OffsetCurve(curve, index, total)
        xs, ys = zip(*shifted)
        thinkplot.Plot(xs, ys, color=color, alpha=0.3, linewidth=0.5)

    thinkplot.Save(root=root,
                   xlabel='# samples',
                   ylabel='# species',
                   formats=FORMATS,
                   legend=False)
706
707
708
def PlotConditionals(cdfs, root='species-cond'):
    """Plots cdfs of num_new conditioned on k.

    cdfs: list of Cdf
    root: string filename root
    """
    num = len(cdfs)

    thinkplot.Clf()
    thinkplot.PrePlot(num=num)
    thinkplot.Cdfs(cdfs)

    thinkplot.Save(root=root,
                   xlabel='# new species',
                   ylabel='Prob',
                   formats=FORMATS)
723
724
725
def PlotFracCdfs(cdfs, root='species-frac'):
    """Plots complementary CDFs of the fraction of species seen.

    cdfs: map from k to CDF of fraction of species seen after k samples
    root: string filename root
    """
    thinkplot.Clf()
    color = '#225EA8'

    # items() is portable; the original used Python-2-only iteritems()
    for k, cdf in cdfs.items():
        xs, ys = cdf.Render()
        # plot the complementary CDF: prob of seeing MORE than this fraction
        ys = [1-y for y in ys]
        thinkplot.Plot(xs, ys, color=color, linewidth=1)

        # label each curve with its k near x=0.9
        x = 0.9
        y = 1 - cdf.Prob(x)
        pyplot.text(x, y, str(k), fontsize=9, color=color,
                    horizontalalignment='center',
                    verticalalignment='center',
                    bbox=dict(facecolor='white', edgecolor='none'))

    thinkplot.Save(root=root,
                   xlabel='Fraction of species seen',
                   ylabel='Probability',
                   formats=FORMATS,
                   legend=False)
750
751
752
class Species(thinkbayes.Suite):
    """Represents hypotheses about the number of species.

    Each hypothesis is a Dirichlet distribution with a different n.
    """

    def __init__(self, ns, conc=1, iters=1000):
        """
        ns: sequence of hypothetical values of n
        conc: concentration parameter of the Dirichlet priors
        iters: number of samples used to estimate likelihoods
        """
        hypos = [thinkbayes.Dirichlet(n, conc) for n in ns]
        thinkbayes.Suite.__init__(self, hypos)
        self.iters = iters

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies
        """
        # call Update in the parent class, which calls Likelihood
        thinkbayes.Suite.Update(self, data)

        # update the next level of the hierarchy
        for hypo in self.Values():
            hypo.Update(data)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        hypo: Dirichlet object
        data: list of observed frequencies
        """
        dirichlet = hypo

        # draw sample Likelihoods from the hypothetical Dirichlet dist
        # and add them up
        like = 0
        for _ in range(self.iters):
            like += dirichlet.Likelihood(data)

        # correct for the number of ways the observed species
        # might have been chosen from all species
        m = len(data)
        like *= thinkbayes.BinomialCoef(dirichlet.n, m)

        return like

    def DistN(self):
        """Computes the distribution of n.

        Returns: Pmf mapping n to its posterior probability
        """
        pmf = thinkbayes.Pmf()
        for hypo, prob in self.Items():
            pmf.Set(hypo.n, prob)
        return pmf
799
800
801
class Species2(object):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions
    """

    def __init__(self, ns, conc=1, iters=1000):
        """
        ns: sequence of hypothetical values of n (number of species)
        conc: concentration parameter of the Dirichlet prior
        iters: number of samples used to estimate likelihoods
        """
        self.ns = ns
        self.conc = conc
        # probs[i] is the (unnormalized, then renormalized) prob of ns[i]
        self.probs = numpy.ones(len(ns), dtype=numpy.float)
        # one Dirichlet parameter per species, up to the largest n
        self.params = numpy.ones(self.ns[-1], dtype=numpy.float) * conc
        self.iters = iters
        self.num_reads = 0
        # m is the number of species seen so far
        self.m = 0

    def Preload(self, data):
        """Change the initial parameters to fit the data better.

        Just an experiment.  Doesn't work.
        """
        m = len(data)
        singletons = data.count(1)
        num = m - singletons
        print m, singletons, num
        addend = numpy.ones(num, dtype=numpy.float) * 1
        print len(addend)
        print len(self.params[singletons:m])
        self.params[singletons:m] += addend
        print 'Preload', num

    def Update(self, data):
        """Updates the distribution based on data.

        data: numpy array of counts
        """
        self.num_reads += sum(data)

        # sample the likelihoods and add them up
        like = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        self.m = len(data)
        #self.params[:self.m] += data * self.conc
        self.params[:self.m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data for all values of n.

        Draws one sample from the distribution of prevalences.

        data: sequence of observed counts

        Returns: numpy array of m likelihoods
        """
        # a normalized gamma sample is a draw from the Dirichlet
        gammas = numpy.random.gamma(self.params)

        m = len(data)
        row = gammas[:m]
        col = numpy.cumsum(gammas)

        # compute the multinomial log-likelihood for each hypothetical n
        log_likes = []
        for n in self.ns:
            ps = row / col[n-1]
            terms = numpy.log(ps) * data
            log_like = terms.sum()
            log_likes.append(log_like)

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        # correct for the number of ways m species could be chosen from n
        coefs = [thinkbayes.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes

    def DistN(self):
        """Computes the distribution of n.

        Returns: new Pmf object
        """
        pmf = thinkbayes.MakePmfFromItems(zip(self.ns, self.probs))
        return pmf

    def RandomN(self):
        """Returns a random value of n."""
        return self.DistN().Random()

    def DistQ(self, iters=100):
        """Computes the distribution of q based on distribution of n.

        q is the prevalence of unseen species.

        Returns: pmf of q
        """
        cdf_n = self.DistN().MakeCdf()
        sample_n = cdf_n.Sample(iters)

        pmf = thinkbayes.Pmf()
        for n in sample_n:
            q = self.RandomQ(n)
            pmf.Incr(q)

        pmf.Normalize()
        return pmf

    def RandomQ(self, n):
        """Returns a random value of q.

        Based on n, self.num_reads and self.conc.

        n: number of species

        Returns: q
        """
        # generate random prevalences
        dirichlet = thinkbayes.Dirichlet(n, conc=self.conc)
        prevalences = dirichlet.Random()

        # generate a simulated sample
        pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
        cdf = pmf.MakeCdf()
        sample = cdf.Sample(self.num_reads)
        seen = set(sample)

        # add up the prevalence of unseen species
        q = 0
        for species, prev in enumerate(prevalences):
            if species not in seen:
                q += prev

        return q

    def MarginalBeta(self, n, index):
        """Computes the conditional distribution of the indicated species.

        n: conditional number of species
        index: which species

        Returns: Beta object representing a distribution of prevalence.
        """
        alpha0 = self.params[:n].sum()
        alpha = self.params[index]
        return thinkbayes.Beta(alpha, alpha0-alpha)

    def DistOfPrevalence(self, index):
        """Computes the distribution of prevalence for the indicated species.

        index: which species

        Returns: (metapmf, mix) where metapmf is a MetaPmf and mix is a Pmf
        """
        metapmf = thinkbayes.Pmf()

        # for each n, compute the marginal prevalence, weighted by P(n)
        for n, prob in zip(self.ns, self.probs):
            beta = self.MarginalBeta(n, index)
            pmf = beta.MakePmf()
            metapmf.Set(pmf, prob)

        mix = thinkbayes.MakeMixture(metapmf)
        return metapmf, mix

    def SamplePosterior(self):
        """Draws random n and prevalences.

        Returns: (n, prevalences)
        """
        n = self.RandomN()
        prevalences = self.SamplePrevalences(n)

        #print 'Peeking at n_cheat'
        #n = n_cheat

        return n, prevalences

    def SamplePrevalences(self, n):
        """Draws a sample of prevalences given n.

        n: the number of species assumed in the conditional

        Returns: numpy array of n prevalences
        """
        if n == 1:
            return [1.0]

        q_desired = self.RandomQ(n)
        # avoid a zero that would break the Unbias rescaling
        q_desired = max(q_desired, 1e-6)

        params = self.Unbias(n, self.m, q_desired)

        # normalized gamma sample == Dirichlet sample
        gammas = numpy.random.gamma(params)
        gammas /= gammas.sum()
        return gammas

    def Unbias(self, n, m, q_desired):
        """Adjusts the parameters to achieve desired prev_unseen (q).

        n: number of species
        m: seen species
        q_desired: prevalence of unseen species
        """
        params = self.params[:n].copy()

        # no unseen species, so nothing to rescale
        if n == m:
            return params

        x = sum(params[:m])
        y = sum(params[m:])
        a = x + y
        #print x, y, a, x/a, y/a

        # scale unseen params by g, seen params by f, keeping the total
        g = q_desired * a / y
        f = (a - g * y) / x
        params[:m] *= f
        params[m:] *= g

        return params
1022
1023
1024
class Species3(Species2):
    """Represents hypotheses about the number of species.

    Like Species2, but evaluates all hypothetical ns in one
    vectorized pass per sample.
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observations
        """
        # sample the likelihoods and add them up
        like = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            like += self.SampleLikelihood(data)

        self.probs *= like
        self.probs /= self.probs.sum()

        m = len(data)
        self.params[:m] += data

    def SampleLikelihood(self, data):
        """Computes the likelihood of the data under all hypotheses.

        data: list of observations

        Returns: numpy array of likelihoods, one per hypothetical n
        """
        # get a random sample
        gammas = numpy.random.gamma(self.params)

        # row is just the first m elements of gammas
        m = len(data)
        row = gammas[:m]

        # col is the cumulative sum of gammas
        col = numpy.cumsum(gammas)[self.ns[0]-1:]

        # each row of the array is a set of ps, normalized
        # for each hypothetical value of n
        array = row / col[:, numpy.newaxis]

        # computing the multinomial PDF under a log transform
        # take the log of the ps and multiply by the data
        terms = numpy.log(array) * data

        # add up the rows
        log_likes = terms.sum(axis=1)

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        # correct for the number of ways we could see m species
        # out of a possible n
        coefs = [thinkbayes.BinomialCoef(n, m) for n in self.ns]
        likes *= coefs

        return likes
1079
1080
1081
class Species4(Species):
    """Represents hypotheses about the number of species.

    Like Species, but performs the update one species at a time.
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies
        """
        # update one species at a time, zero-padding so the observed
        # count lands in the right slot
        for index, count in enumerate(data):
            one = numpy.zeros(index+1)
            one[index] = count

            # call the parent class
            Species.Update(self, one)

    def Likelihood(self, data, hypo):
        """Computes the likelihood of the data under this hypothesis.

        Note: this only works correctly if we update one species at a time.

        hypo: Dirichlet object
        data: list of observed frequencies
        """
        dirichlet = hypo
        total = sum(dirichlet.Likelihood(data) for _ in range(self.iters))

        # correct for the number of unseen species the new one
        # could have been
        num_unseen = dirichlet.n - len(data) + 1
        return total * num_unseen
1119
1120
1121
class Species5(Species2):
    """Represents hypotheses about the number of species.

    Combines two layers of the hierarchy into one object.

    ns and probs represent the distribution of N

    params represents the parameters of the Dirichlet distributions
    """

    def Update(self, data):
        """Updates the suite based on the data.

        data: list of observed frequencies in increasing order
        """
        # loop through the species and update one at a time
        m = len(data)
        for i in range(m):
            self.UpdateOne(i+1, data[i])
            self.params[i] += data[i]

    def UpdateOne(self, i, count):
        """Updates the suite based on the data.

        Evaluates the likelihood for all values of n.

        i: which species was observed (1..n)
        count: how many were observed
        """
        # how many species have we seen so far
        self.m = i

        # how many reads have we seen
        self.num_reads += count

        # iters == 0 is a special case: track state without updating probs
        if self.iters == 0:
            return

        # sample the likelihoods and add them up
        likes = numpy.zeros(len(self.ns), dtype=numpy.float)
        for _ in range(self.iters):
            likes += self.SampleLikelihood(i, count)

        # correct for the number of unseen species the new one
        # could have been
        unseen_species = [n-i+1 for n in self.ns]
        likes *= unseen_species

        # multiply the priors by the likelihoods and renormalize
        self.probs *= likes
        self.probs /= self.probs.sum()

    def SampleLikelihood(self, i, count):
        """Computes the likelihood of the data under all hypotheses.

        i: which species was observed
        count: how many were observed

        Returns: numpy array of likelihoods, one per hypothetical n
        """
        # get a random sample of p
        gammas = numpy.random.gamma(self.params)

        # sums is the cumulative sum of p, for each value of n
        sums = numpy.cumsum(gammas)[self.ns[0]-1:]

        # get p for the mth species, for each value of n
        ps = gammas[i-1] / sums
        log_likes = numpy.log(ps) * count

        # before exponentiating, scale into a reasonable range
        log_likes -= numpy.max(log_likes)
        likes = numpy.exp(log_likes)

        return likes
1194
1195
1196
def MakePosterior(constructor, data, ns, conc=1, iters=1000):
    """Makes a suite, updates it and returns the posterior suite.

    Prints the elapsed time.

    constructor: which Species* class to instantiate
    data: observed species and their counts
    ns: sequence of hypothetical ns
    conc: concentration parameter
    iters: how many samples to draw

    Returns: posterior suite of the given type
    """
    suite = constructor(ns, conc=conc, iters=iters)

    # print constructor.__name__
    start = time.time()
    suite.Update(data)
    end = time.time()
    print 'Processing time', end-start

    return suite
1217
1218
1219
def PlotAllVersions():
    """Makes a graph of posterior distributions of N.

    Runs every estimator on a small dataset and plots the results.
    """
    data = [1, 2, 3]
    ns = range(len(data), 20)

    constructors = [Species, Species2, Species3, Species4, Species5]
    for constructor in constructors:
        posterior = MakePosterior(constructor, data, ns)
        dist = posterior.DistN()
        dist.name = '%s' % (constructor.__name__)
        thinkplot.Pmf(dist)

    thinkplot.Save(root='species3',
                   xlabel='Number of species',
                   ylabel='Prob')
1235
1236
1237
def PlotMedium():
    """Makes a graph of posterior distributions of N.

    Same comparison as PlotAllVersions but with a medium-sized
    dataset, shown interactively instead of saved.
    """
    data = [1, 1, 1, 1, 2, 3, 5, 9]
    ns = range(len(data), 20)

    for constructor in [Species, Species2, Species3, Species4, Species5]:
        suite = MakePosterior(constructor, data, ns)
        pmf = suite.DistN()
        pmf.name = constructor.__name__
        thinkplot.Pmf(pmf)

    thinkplot.Show()
def SimpleDirichletExample():
1254
"""Makes a plot showing posterior distributions for three species.
1255
1256
This is the case where we know there are exactly three species.
1257
"""
1258
thinkplot.Clf()
1259
thinkplot.PrePlot(3)
1260
1261
names = ['lions', 'tigers', 'bears']
1262
data = [3, 2, 1]
1263
1264
dirichlet = thinkbayes.Dirichlet(3)
1265
for i in range(3):
1266
beta = dirichlet.MarginalBeta(i)
1267
print 'mean', names[i], beta.Mean()
1268
1269
dirichlet.Update(data)
1270
for i in range(3):
1271
beta = dirichlet.MarginalBeta(i)
1272
print 'mean', names[i], beta.Mean()
1273
1274
pmf = beta.MakePmf(name=names[i])
1275
thinkplot.Pmf(pmf)
1276
1277
thinkplot.Save(root='species1',
1278
xlabel='Prevalence',
1279
ylabel='Prob',
1280
formats=FORMATS,
1281
)
1282
1283
1284
def HierarchicalExample():
    """Shows the posterior distribution of n for lions, tigers and bears.
    """
    ns = range(3, 30)
    suite = Species(ns, iters=8000)

    suite.Update([3, 2, 1])

    thinkplot.Clf()
    thinkplot.PrePlot(num=1)

    thinkplot.Pmf(suite.DistN())
    thinkplot.Save(root='species2',
                   xlabel='Number of species',
                   ylabel='Prob',
                   formats=FORMATS,
                   )
def CompareHierarchicalExample():
    """Makes a graph of posterior distributions of N.

    Compares the exact hierarchical model (Species) with the
    approximate one (Species5), using fewer iterations for the
    more expensive implementation.
    """
    data = [3, 2, 1]
    m = len(data)
    n = 30
    ns = range(m, n)

    constructors = [Species, Species5]
    # renamed from `iters` so the loop variable doesn't shadow the list
    iter_counts = [1000, 100]

    for constructor, iters in zip(constructors, iter_counts):
        # BUG FIX: iters was previously passed positionally as the
        # fourth argument of MakePosterior, which is conc, not iters.
        suite = MakePosterior(constructor, data, ns, iters=iters)
        pmf = suite.DistN()
        pmf.name = '%s' % (constructor.__name__)
        thinkplot.Pmf(pmf)

    thinkplot.Show()
def ProcessSubjects(codes):
1325
"""Process subjects with the given codes and plot their posteriors.
1326
1327
code: sequence of string codes
1328
"""
1329
thinkplot.Clf()
1330
thinkplot.PrePlot(len(codes))
1331
1332
subjects = ReadRarefactedData()
1333
pmfs = []
1334
for code in codes:
1335
subject = subjects[code]
1336
1337
subject.Process()
1338
pmf = subject.suite.DistN()
1339
pmf.name = subject.code
1340
thinkplot.Pmf(pmf)
1341
1342
pmfs.append(pmf)
1343
1344
print 'ProbGreater', thinkbayes.PmfProbGreater(pmfs[0], pmfs[1])
1345
print 'ProbLess', thinkbayes.PmfProbLess(pmfs[0], pmfs[1])
1346
1347
thinkplot.Save(root='species4',
1348
xlabel='Number of species',
1349
ylabel='Prob',
1350
formats=FORMATS,
1351
)
1352
1353
1354
def RunSubject(code, conc=1, high=500):
    """Runs the analysis for the subject with the given code.

    code: string code
    conc: concentration parameter
    high: upper bound on the number of species
    """
    subject = JoinSubjects()[code]

    subject.Process(conc=conc, high=high, iters=300)
    subject.MakeQuickPrediction()

    # report the posterior and the prediction of additional species
    PrintSummary(subject)
    actual_l = subject.total_species - subject.num_species
    cdf_l = subject.DistL().MakeCdf()
    PrintPrediction(cdf_l, actual_l)

    subject.MakeFigures()

    # rarefaction curves
    curves = subject.RunSimulations(100, 400)
    PlotCurves(curves, root='species-rare-%s' % subject.code)

    # conditional distributions of additional species after more reads
    curves = subject.RunSimulations(500, 800)
    cdfs = MakeConditionals(curves, [100, 200, 400, 800])
    PlotConditionals(cdfs, root='species-cond-%s' % subject.code)

    # fraction of species seen as a function of additional reads
    curves = subject.RunSimulations(500, 1000, frac_flag=True)
    cdfs = MakeFracCdfs(curves, [10, 100, 200, 400, 600, 800, 1000])
    PlotFracCdfs(cdfs, root='species-frac-%s' % subject.code)
def PrintSummary(subject):
1393
"""Print a summary of a subject.
1394
1395
subject: Subject
1396
"""
1397
print subject.code
1398
print 'found %d species in %d reads' % (subject.num_species,
1399
subject.num_reads)
1400
1401
print 'total %d species in %d reads' % (subject.total_species,
1402
subject.total_reads)
1403
1404
cdf = subject.suite.DistN().MakeCdf()
1405
print 'n'
1406
PrintPrediction(cdf, 'unknown')
1407
1408
1409
def PrintPrediction(cdf, actual):
1410
"""Print a summary of a prediction.
1411
1412
cdf: predictive distribution
1413
actual: actual value
1414
"""
1415
median = cdf.Percentile(50)
1416
low, high = cdf.CredibleInterval(75)
1417
1418
print 'predicted %0.2f (%0.2f %0.2f)' % (median, low, high)
1419
print 'actual', actual
1420
1421
1422
def RandomSeed(x):
    """Seeds both pseudo-random generators used in this module.

    Seeding random.random and numpy.random together keeps the
    simulations reproducible.

    x: int seed
    """
    random.seed(x)
    numpy.random.seed(x)
def GenerateFakeSample(n, r, tr, conc=1):
    """Generates fake data with the given parameters.

    n: number of species
    r: number of reads in subsample
    tr: total number of reads
    conc: concentration parameter

    Returns: hist of all reads, hist of subsample, prev_unseen
    """
    # draw random prevalences from a Dirichlet distribution
    dirichlet = thinkbayes.Dirichlet(n, conc=conc)
    prevalences = dirichlet.Random()
    prevalences.sort()

    # simulate tr reads drawn from the prevalence distribution
    pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
    cdf = pmf.MakeCdf()
    sample = cdf.Sample(tr)

    # species counts for the full sample
    hist = thinkbayes.MakeHistFromList(sample)

    # species counts for a random subsample of r reads
    if tr > r:
        random.shuffle(sample)
        subhist = thinkbayes.MakeHistFromList(sample[:r])
    else:
        subhist = hist

    # total prevalence of species not seen in the subsample
    prev_unseen = sum(prev for species, prev in enumerate(prevalences)
                      if species not in subhist)

    return hist, subhist, prev_unseen
def PlotActualPrevalences():
1472
"""Makes a plot comparing actual prevalences with a model.
1473
"""
1474
# read data
1475
subject_map, _ = ReadCompleteDataset()
1476
1477
# for subjects with more than 50 species,
1478
# PMF of max prevalence, and PMF of max prevalence
1479
# generated by a simulation
1480
pmf_actual = thinkbayes.Pmf()
1481
pmf_sim = thinkbayes.Pmf()
1482
1483
# concentration parameter used in the simulation
1484
conc = 0.06
1485
1486
for code, subject in subject_map.iteritems():
1487
prevalences = subject.GetPrevalences()
1488
m = len(prevalences)
1489
if m < 2:
1490
continue
1491
1492
actual_max = max(prevalences)
1493
print code, m, actual_max
1494
1495
# incr the PMFs
1496
if m > 50:
1497
pmf_actual.Incr(actual_max)
1498
pmf_sim.Incr(SimulateMaxPrev(m, conc))
1499
1500
# plot CDFs for the actual and simulated max prevalence
1501
cdf_actual = pmf_actual.MakeCdf(name='actual')
1502
cdf_sim = pmf_sim.MakeCdf(name='sim')
1503
1504
thinkplot.Cdfs([cdf_actual, cdf_sim])
1505
thinkplot.Show()
1506
1507
1508
def ScatterPrevalences(ms, actual):
    """Makes a scatter plot of actual prevalences and expected values.

    ms: sorted sequence of m (number of species)
    actual: sequence of actual max prevalence
    """
    # expected max prevalence curves for several concentration parameters
    for conc in [1, 0.5, 0.2, 0.1]:
        expected = [ExpectedMaxPrev(m, conc) for m in ms]
        thinkplot.Plot(ms, expected)

    thinkplot.Scatter(ms, actual)
    thinkplot.Show(xscale='log')
def SimulateMaxPrev(m, conc=1):
    """Returns random max prevalence from a Dirichlet distribution.

    m: int number of species
    conc: concentration parameter of the Dirichlet distribution

    Returns: float max of m prevalences
    """
    prevalences = thinkbayes.Dirichlet(m, conc).Random()
    return max(prevalences)
def ExpectedMaxPrev(m, conc=1, iters=100):
    """Estimates expected max prevalence by Monte Carlo.

    m: number of species
    conc: concentration parameter
    iters: how many iterations to run

    Returns: expected max prevalence
    """
    dirichlet = thinkbayes.Dirichlet(m, conc)

    # average the max over repeated random draws
    maxes = [max(dirichlet.Random()) for _ in range(iters)]
    return numpy.mean(maxes)
class Calibrator(object):
1555
"""Encapsulates the calibration process."""
1556
1557
def __init__(self, conc=0.1):
1558
"""
1559
"""
1560
self.conc = conc
1561
1562
self.ps = range(10, 100, 10)
1563
self.total_n = numpy.zeros(len(self.ps))
1564
self.total_q = numpy.zeros(len(self.ps))
1565
self.total_l = numpy.zeros(len(self.ps))
1566
1567
self.n_seq = []
1568
self.q_seq = []
1569
self.l_seq = []
1570
1571
def Calibrate(self, num_runs=100, n_low=30, n_high=400, r=400, tr=1200):
1572
"""Runs calibrations.
1573
1574
num_runs: how many runs
1575
"""
1576
for seed in range(num_runs):
1577
self.RunCalibration(seed, n_low, n_high, r, tr)
1578
1579
self.total_n *= 100.0 / num_runs
1580
self.total_q *= 100.0 / num_runs
1581
self.total_l *= 100.0 / num_runs
1582
1583
def Validate(self, num_runs=100, clean_param=0):
1584
"""Runs validations.
1585
1586
num_runs: how many runs
1587
"""
1588
subject_map, _ = ReadCompleteDataset(clean_param=clean_param)
1589
1590
i = 0
1591
for match in subject_map.itervalues():
1592
if match.num_reads < 400:
1593
continue
1594
num_reads = 100
1595
1596
print 'Validate', match.code
1597
subject = match.Resample(num_reads)
1598
subject.Match(match)
1599
1600
n_actual = None
1601
q_actual = subject.prev_unseen
1602
l_actual = subject.total_species - subject.num_species
1603
self.RunSubject(subject, n_actual, q_actual, l_actual)
1604
1605
i += 1
1606
if i == num_runs:
1607
break
1608
1609
self.total_n *= 100.0 / num_runs
1610
self.total_q *= 100.0 / num_runs
1611
self.total_l *= 100.0 / num_runs
1612
1613
def PlotN(self, root='species-n'):
1614
"""Makes a scatter plot of simulated vs actual prev_unseen (q).
1615
"""
1616
xs, ys = zip(*self.n_seq)
1617
if None in xs:
1618
return
1619
1620
high = max(xs+ys)
1621
1622
thinkplot.Plot([0, high], [0, high], color='gray')
1623
thinkplot.Scatter(xs, ys)
1624
thinkplot.Save(root=root,
1625
xlabel='Actual n',
1626
ylabel='Predicted')
1627
1628
def PlotQ(self, root='species-q'):
1629
"""Makes a scatter plot of simulated vs actual prev_unseen (q).
1630
"""
1631
thinkplot.Plot([0, 0.2], [0, 0.2], color='gray')
1632
xs, ys = zip(*self.q_seq)
1633
thinkplot.Scatter(xs, ys)
1634
thinkplot.Save(root=root,
1635
xlabel='Actual q',
1636
ylabel='Predicted')
1637
1638
def PlotL(self, root='species-n'):
1639
"""Makes a scatter plot of simulated vs actual l.
1640
"""
1641
thinkplot.Plot([0, 20], [0, 20], color='gray')
1642
xs, ys = zip(*self.l_seq)
1643
thinkplot.Scatter(xs, ys)
1644
thinkplot.Save(root=root,
1645
xlabel='Actual l',
1646
ylabel='Predicted')
1647
1648
def PlotCalibrationCurves(self, root='species5'):
1649
"""Plots calibration curves"""
1650
print self.total_n
1651
print self.total_q
1652
print self.total_l
1653
1654
thinkplot.Plot([0, 100], [0, 100], color='gray', alpha=0.2)
1655
1656
if self.total_n[0] >= 0:
1657
thinkplot.Plot(self.ps, self.total_n, label='n')
1658
1659
thinkplot.Plot(self.ps, self.total_q, label='q')
1660
thinkplot.Plot(self.ps, self.total_l, label='l')
1661
1662
thinkplot.Save(root=root,
1663
axis=[0, 100, 0, 100],
1664
xlabel='Ideal percentages',
1665
ylabel='Predictive distributions',
1666
formats=FORMATS,
1667
)
1668
1669
def RunCalibration(self, seed, n_low, n_high, r, tr):
1670
"""Runs a single calibration run.
1671
1672
Generates N and prevalences from a Dirichlet distribution,
1673
then generates simulated data.
1674
1675
Runs analysis to get the posterior distributions.
1676
Generates calibration curves for each posterior distribution.
1677
1678
seed: int random seed
1679
"""
1680
# generate a random number of species and their prevalences
1681
# (from a Dirichlet distribution with alpha_i = conc for all i)
1682
RandomSeed(seed)
1683
n_actual = random.randrange(n_low, n_high+1)
1684
1685
hist, subhist, q_actual = GenerateFakeSample(
1686
n_actual,
1687
r,
1688
tr,
1689
self.conc)
1690
1691
l_actual = len(hist) - len(subhist)
1692
print 'Run low, high, conc', n_low, n_high, self.conc
1693
print 'Run r, tr', r, tr
1694
print 'Run n, q, l', n_actual, q_actual, l_actual
1695
1696
# extract the data
1697
data = [count for species, count in subhist.Items()]
1698
data.sort()
1699
print 'data', data
1700
1701
# make a Subject and process
1702
subject = Subject('simulated')
1703
subject.num_reads = r
1704
subject.total_reads = tr
1705
1706
for species, count in subhist.Items():
1707
subject.Add(species, count)
1708
subject.Done()
1709
1710
self.RunSubject(subject, n_actual, q_actual, l_actual)
1711
1712
def RunSubject(self, subject, n_actual, q_actual, l_actual):
1713
"""Runs the analysis for a subject.
1714
1715
subject: Subject
1716
n_actual: number of species
1717
q_actual: prevalence of unseen species
1718
l_actual: number of new species
1719
"""
1720
# process and make prediction
1721
subject.Process(conc=self.conc, iters=100)
1722
subject.MakeQuickPrediction()
1723
1724
# extract the posterior suite
1725
suite = subject.suite
1726
1727
# check the distribution of n
1728
pmf_n = suite.DistN()
1729
print 'n'
1730
self.total_n += self.CheckDistribution(pmf_n, n_actual, self.n_seq)
1731
1732
# check the distribution of q
1733
pmf_q = suite.DistQ()
1734
print 'q'
1735
self.total_q += self.CheckDistribution(pmf_q, q_actual, self.q_seq)
1736
1737
# check the distribution of additional species
1738
pmf_l = subject.DistL()
1739
print 'l'
1740
self.total_l += self.CheckDistribution(pmf_l, l_actual, self.l_seq)
1741
1742
def CheckDistribution(self, pmf, actual, seq):
1743
"""Checks a predictive distribution and returns a score vector.
1744
1745
pmf: predictive distribution
1746
actual: actual value
1747
seq: which sequence to append (actual, mean) onto
1748
"""
1749
mean = pmf.Mean()
1750
seq.append((actual, mean))
1751
1752
cdf = pmf.MakeCdf()
1753
PrintPrediction(cdf, actual)
1754
1755
sv = ScoreVector(cdf, self.ps, actual)
1756
return sv
1757
1758
1759
def ScoreVector(cdf, ps, actual):
    """Checks whether the actual value falls in each credible interval.

    cdf: predictive distribution
    ps: percentages to check (0-100)
    actual: actual value

    Returns: numpy array of 0, 0.5, or 1
    """
    intervals = [cdf.CredibleInterval(p) for p in ps]
    return numpy.array([Score(low, high, actual)
                        for low, high in intervals])
def Score(low, high, n):
    """Scores whether the actual value falls in the range.

    Hitting the posts counts as 0.5, -1 is invalid.

    low: low end of range
    high: high end of range
    n: actual value

    Returns: -1, 0, 0.5 or 1
    """
    # unknown actual value: score is invalid
    if n is None:
        return -1
    # landing exactly on an endpoint counts as half a hit
    if n == low or n == high:
        return 0.5
    return 1 if low < n < high else 0
def FakeSubject(n=300, conc=0.1, num_reads=400, prevalences=None):
    """Makes a fake Subject.

    If prevalences is provided, n and conc are ignored.

    n: number of species
    conc: concentration parameter
    num_reads: number of reads
    prevalences: numpy array of prevalences (overrides n and conc)

    Returns: Subject
    """
    # draw random prevalences unless they were provided
    if prevalences is None:
        dirichlet = thinkbayes.Dirichlet(n, conc=conc)
        prevalences = dirichlet.Random()
        prevalences.sort()

    # generate a simulated sample of num_reads reads
    pmf = thinkbayes.MakePmfFromItems(enumerate(prevalences))
    cdf = pmf.MakeCdf()
    sample = cdf.Sample(num_reads)

    # collect the species counts
    hist = thinkbayes.MakeHistFromList(sample)

    # NOTE: a dead `data = sorted counts` computation was removed here;
    # it was never used.

    # make a Subject and process
    subject = Subject('simulated')

    for species, count in hist.Items():
        subject.Add(species, count)
    subject.Done()

    return subject
def PlotSubjectCdf(code=None, clean_param=0):
1837
"""Checks whether the Dirichlet model can replicate the data.
1838
"""
1839
subject_map, uber_subject = ReadCompleteDataset(clean_param=clean_param)
1840
1841
if code is None:
1842
subjects = subject_map.values()
1843
subject = random.choice(subjects)
1844
code = subject.code
1845
elif code == 'uber':
1846
subject = uber_subject
1847
else:
1848
subject = subject_map[code]
1849
1850
print subject.code
1851
1852
m = subject.GetM()
1853
1854
subject.Process(high=m, conc=0.1, iters=0)
1855
print subject.suite.params[:m]
1856
1857
# plot the cdf
1858
options = dict(linewidth=3, color='blue', alpha=0.5)
1859
cdf = subject.MakeCdf()
1860
thinkplot.Cdf(cdf, **options)
1861
1862
options = dict(linewidth=1, color='green', alpha=0.5)
1863
1864
# generate fake subjects and plot their CDFs
1865
for _ in range(10):
1866
prevalences = subject.suite.SamplePrevalences(m)
1867
fake = FakeSubject(prevalences=prevalences)
1868
cdf = fake.MakeCdf()
1869
thinkplot.Cdf(cdf, **options)
1870
1871
root = 'species-cdf-%s' % code
1872
thinkplot.Save(root=root,
1873
xlabel='rank',
1874
ylabel='CDF',
1875
xscale='log',
1876
formats=FORMATS,
1877
)
1878
1879
1880
def RunCalibration(flag='cal', num_runs=100, clean_param=50):
    """Runs either the calibration or validation process.

    flag: string 'cal' or 'val'
    num_runs: how many runs
    clean_param: parameter used for data cleaning
    """
    cal = Calibrator(conc=0.1)

    if flag == 'val':
        cal.Validate(num_runs=num_runs, clean_param=clean_param)
    else:
        cal.Calibrate(num_runs=num_runs)

    # write one set of plots per flag
    cal.PlotN(root='species-n-%s' % flag)
    cal.PlotQ(root='species-q-%s' % flag)
    cal.PlotL(root='species-l-%s' % flag)
    cal.PlotCalibrationCurves(root='species5-%s' % flag)
def RunTests():
    """Runs calibration code and generates some figures."""
    # validation first, then calibration
    RunCalibration(flag='val')
    RunCalibration(flag='cal')

    # one named subject, then a randomly chosen one
    PlotSubjectCdf('B1558.G', clean_param=50)
    PlotSubjectCdf(None)
def main(script):
    """Runs the analyses for the species chapter.

    Each analysis is reseeded so the figures are reproducible.

    script: name of the script (unused)
    """
    RandomSeed(17)
    RunSubject('B1242', conc=1, high=100)

    RandomSeed(17)
    SimpleDirichletExample()

    RandomSeed(17)
    HierarchicalExample()
# script entry point: pass the command-line arguments to main
if __name__ == '__main__':
    main(*sys.argv)