Contact
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
| Download
Views: 6249
License: OTHER
1
from __future__ import division
2
3
import numpy as np
4
import matplotlib.pyplot as plt
5
6
from scipy.ndimage import grey_dilation
7
8
from skimage import img_as_float
9
from skimage import color
10
from skimage import exposure
11
from skimage.util.dtype import dtype_limits
12
13
14
# Public API of this module (the names exported by ``import *``).
__all__ = [
    'imshow_all',
    'imshow_with_histogram',
    'mean_filter_demo',
    'mean_filter_interactive_demo',
    'plot_cdf',
    'plot_histogram',
]
16
17
18
# Gray-scale images should actually be gray: make 'gray' the colormap
# `imshow` uses whenever no explicit `cmap` is given.
plt.rc('image', cmap='gray')
20
21
22
#--------------------------------------------------------------------------
23
# Custom `imshow` functions
24
#--------------------------------------------------------------------------
25
26
def imshow_rgb_shifted(rgb_image, shift=100, ax=None):
    """Draw each color layer of `rgb_image` as its own semi-transparent image.

    Every successive layer is offset by `shift` in both x and y, so the
    channels fan out diagonally.

    Parameters
    ----------
    rgb_image : array
        Image of shape (height, width, n_channels).
    shift : int or float
        Offset (in data units) added to x and y for each successive layer.
    ax : matplotlib Axes, optional
        Target axes; the current axes are used when omitted.
    """
    target = plt.gca() if ax is None else ax

    rows, cols, depth = rgb_image.shape
    offset = 0
    for index, layer in enumerate(iter_channels(rgb_image)):
        # Only channel `index` is populated, so each layer shows one color.
        single = np.zeros((rows, cols, depth), dtype=layer.dtype)
        single[:, :, index] = layer
        target.imshow(single,
                      extent=[offset, offset + cols, offset, offset + rows],
                      alpha=0.7)
        offset += shift
    # `imshow` fits the extents of the last image shown, so we need to rescale.
    target.autoscale()
    target.set_axis_off()
43
44
45
def imshow_all(*images, **kwargs):
    """ Plot a series of images side-by-side.

    Convert all images to float so that images have a common intensity range.

    Parameters
    ----------
    limits : str
        Control the intensity limits. By default, 'image' is used to set the
        min/max intensities to the min/max of all images. Setting `limits` to
        'dtype' can also be used if you want to preserve the image exposure.
    titles : list of str
        Titles for subplots. If the length of titles is less than the number
        of images, empty strings are appended.
    shape : tuple of int
        ``(nrows, ncols)`` layout of the subplot grid. Defaults to a single
        row holding all images.
    size : float
        Figure height per row of subplots, in inches (default 5); the figure
        is ``nrows * size`` inches tall.
    kwargs : dict
        Additional keyword-arguments passed to `imshow`.
    """
    images = [img_as_float(img) for img in images]

    titles = kwargs.pop('titles', [])
    if len(titles) != len(images):
        titles = list(titles) + [''] * (len(images) - len(titles))

    limits = kwargs.pop('limits', 'image')
    if limits == 'image':
        kwargs.setdefault('vmin', min(img.min() for img in images))
        kwargs.setdefault('vmax', max(img.max() for img in images))
    elif limits == 'dtype':
        # The images were already converted to float above, so these are the
        # limits of the float dtype.
        vmin, vmax = dtype_limits(images[0])
        kwargs.setdefault('vmin', vmin)
        kwargs.setdefault('vmax', vmax)

    # `shape` is a layout option of this function, not an `imshow` argument,
    # so it must be removed before the remaining kwargs are forwarded.
    nrows, ncols = kwargs.pop('shape', (1, len(images)))

    size = nrows * kwargs.pop('size', 5)
    width = size * len(images)
    if nrows > 1:
        width /= nrows * 1.33
    # squeeze=False always yields a 2-D array of axes, even for a single
    # image, so `.ravel()` below is safe.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(width, size),
                             squeeze=False)
    for ax, img, label in zip(axes.ravel(), images, titles):
        ax.imshow(img, **kwargs)
        ax.set_title(label)
87
88
89
def imshow_with_histogram(image, **kwargs):
    """ Show `image` in a left panel and its histogram in a right panel.

    - Plot the image next to the histogram
    - Plot each RGB channel separately (if input is color)
    - Automatically flatten channels
    - Select reasonable bins based on the image's dtype

    See `plot_histogram` for information on how the histogram is plotted.

    Parameters
    ----------
    image : array
        Image to display.
    kwargs : dict
        Keyword arguments for `imshow`; `cmap` defaults to gray.

    Returns
    -------
    ax_image, ax_hist : matplotlib Axes
        Axes holding the image and the histogram, respectively.
    """
    base_w, base_h = plt.rcParams['figure.figsize']
    # Double the default width so each of the two panels keeps a normal size.
    fig, panels = plt.subplots(ncols=2, figsize=(2*base_w, base_h))
    image_ax, hist_ax = panels

    kwargs.setdefault('cmap', plt.cm.gray)
    image_ax.imshow(image, **kwargs)
    plot_histogram(image, ax=hist_ax)

    # Cosmetics: hide the image frame and align the histogram's height to it.
    image_ax.set_axis_off()
    match_axes_height(image_ax, hist_ax)
    return image_ax, hist_ax
110
111
112
#--------------------------------------------------------------------------
113
# Helper functions
114
#--------------------------------------------------------------------------
115
116
117
def match_axes_height(ax_src, ax_dst):
    """ Give `ax_dst` the same vertical extent as `ax_src`.

    The bottom edge and height of `ax_dst` are copied from `ax_src`; the
    left edge and width of `ax_dst` are kept.
    """
    # HACK: axes geometry is only final once the figure has been drawn.
    plt.draw()
    dst_box = ax_dst.get_position()
    src_box = ax_src.get_position()
    new_bounds = [dst_box.xmin, src_box.ymin, dst_box.width, src_box.height]
    ax_dst.set_position(new_bounds)
127
128
129
def plot_cdf(image, ax=None):
    """ Plot the cumulative distribution of `image` intensities on `ax`.

    Computed with `skimage.exposure.cumulative_distribution` and drawn as a
    red line.

    Parameters
    ----------
    image : array
        Input image.
    ax : matplotlib Axes, optional
        Target axes; the current axes are used when omitted.
    """
    # Fall back to the current axes (as `plot_histogram` does) instead of
    # failing on the `None` default.
    ax = ax if ax is not None else plt.gca()
    img_cdf, bins = exposure.cumulative_distribution(image)
    ax.plot(bins, img_cdf, 'r')
    ax.set_ylabel("Fraction of pixels below intensity")
133
134
135
def plot_histogram(image, ax=None, **kwargs):
    """ Plot the histogram of an image (gray-scale or RGB) on `ax`.

    Calculate histogram using `skimage.exposure.histogram` and plot as filled
    line. If an image has a 3rd dimension, assume it's RGB and plot each
    channel separately.

    Parameters
    ----------
    image : array
        2-D gray-scale image, or 3-D image with channels on the last axis.
    ax : matplotlib Axes, optional
        Target axes; the current axes are used when omitted.
    kwargs : dict
        Forwarded to `_plot_histogram` (which takes `alpha` and passes the
        rest to `ax.fill_between`).

    Raises
    ------
    ValueError
        If `image` is neither 2-D nor 3-D.
    """
    # Validate first so an unsupported input neither creates axes nor is
    # silently ignored.
    if image.ndim not in (2, 3):
        raise ValueError("Expected a 2-D or 3-D image, got %d dimension(s)."
                         % image.ndim)
    ax = ax if ax is not None else plt.gca()

    if image.ndim == 2:
        _plot_histogram(ax, image, color='black', **kwargs)
    else:
        # `channel` is the red, green, or blue channel of the image; zip stops
        # after three channels, so any extra channel (e.g. alpha) is skipped.
        for channel, channel_color in zip(iter_channels(image), 'rgb'):
            _plot_histogram(ax, channel, color=channel_color, **kwargs)
150
151
152
def _plot_histogram(ax, image, alpha=0.3, **kwargs):
    """Draw the histogram of `image` on `ax` as a filled area.

    `alpha` sets the fill transparency; other keyword arguments go straight
    to `ax.fill_between`. The axis labels are set on every call.
    """
    # skimage's histogram picks sensible bins for both integer and float
    # images, so no bin selection is needed here.
    counts, centers = exposure.histogram(image)
    ax.fill_between(centers, counts, alpha=alpha, **kwargs)
    ax.set_xlabel('intensity')
    ax.set_ylabel('# pixels')
159
160
161
def iter_channels(color_image):
    """Yield the channels of `color_image` one by one.

    The last axis is treated as the channel axis; the k-th item yielded is
    ``color_image[..., k]``.
    """
    # Bring the last (channel) axis to the front so indexing walks channels.
    channels_first = np.rollaxis(color_image, -1)
    for index in range(channels_first.shape[0]):
        yield channels_first[index]
166
167
168
#--------------------------------------------------------------------------
169
# Convolution Demo
170
#--------------------------------------------------------------------------
171
172
def mean_filter_demo(image, vmax=1):
    """ Build a step-by-step visualization of a 3x3 mean filter.

    Returns a function ``mean_filter_step(i_step)``. It displays, side by
    side, the kernel position at step `i_step` overlaid on `image` and the
    partially filtered image after that step. Steps are computed on demand
    and cached, so revisiting an earlier step only redraws it.

    Parameters
    ----------
    image : array
        Gray-scale image to filter; `iter_kernel` walks its first two axes.
    vmax : float
        Upper display limit passed to `imshow_all`.

    Returns
    -------
    mean_filter_step : callable
        Takes a non-negative step index and shows that step.
    """
    mean_factor = 1.0 / 9.0  # This assumes a 3x3 kernel.
    iter_kernel_and_subimage = iter_kernel(image)

    # image_cache[k] holds (filter_overlay, filtered) for step k.
    image_cache = []

    def mean_filter_step(i_step):
        while i_step >= len(image_cache):
            # Continue from the latest cached result, or from the input image
            # when nothing has been computed yet. This lets any step be
            # requested first.
            filtered = image if not image_cache else image_cache[-1][1]
            filtered = filtered.copy()

            # Builtin `next()` works on Python 2 and 3; `.next()` is Py2-only.
            (i, j), mask, subimage = next(iter_kernel_and_subimage)
            filter_overlay = color.label2rgb(mask, image, bg_label=0,
                                             colors=('yellow', 'red'))
            # `subimage` is taken from the unfiltered `image`. Near the
            # borders it has fewer than 9 pixels but is still scaled by 1/9.
            filtered[i, j] = np.sum(mean_factor * subimage)
            image_cache.append((filter_overlay, filtered))

        imshow_all(*image_cache[i_step], vmax=vmax)
        plt.show()
    return mean_filter_step
192
193
194
def mean_filter_interactive_demo(image):
    """ Step through `mean_filter_demo(image)` with an integer slider.

    Requires the notebook widgets importable as `IPython.html.widgets`.
    """
    from IPython.html import widgets
    mean_filter_step = mean_filter_demo(image)
    # Later widget releases name the slider `IntSlider`; `IntSliderWidget` is
    # the older name, used only as a fallback.
    slider_cls = getattr(widgets, 'IntSlider', None)
    if slider_cls is None:
        slider_cls = widgets.IntSliderWidget
    # One step per pixel visited by `iter_kernel` (rows x columns).
    n_steps = image.shape[0] * image.shape[1]
    step_slider = slider_cls(min=0, max=n_steps - 1, value=0)
    widgets.interact(mean_filter_step, i_step=step_slider)
199
200
201
def iter_kernel(image, size=1):
    """ Yield position, kernel mask, and image for each pixel in the image.

    The kernel mask has a 2 at the center pixel and 1 around it. The actual
    width of the kernel is 2*size + 1.

    Yields
    ------
    (i, j) : tuple of int
        Row and column of the kernel center.
    mask : int16 array with the same shape as `image`
        For a 2-D image: 0 outside the kernel, 1 inside it, 2 at the center.
    subimage : array
        The part of `image` covered by the kernel, clipped at the borders.
    """
    width = 2*size + 1
    for (i, j), _ in iter_pixels(image):
        mask = np.zeros(image.shape, dtype='int16')
        mask[i, j] = 1
        # Dilating the single marked pixel grows it into a `width`-wide square.
        mask = grey_dilation(mask, size=width)
        mask[i, j] = 2
        # NumPy needs a *tuple* of slices for multi-axis indexing; indexing
        # with a list of slices is deprecated and rejected by recent versions.
        window = tuple(bounded_slice((i, j), image.shape[:2], size=size))
        subimage = image[window]
        yield (i, j), mask, subimage
215
216
217
def iter_pixels(image):
    """ Yield ``((row, column), value)`` for every pixel, in row-major order.

    Only the first two axes are iterated; for a color image each value is
    the pixel's channel vector.
    """
    n_rows, n_cols = image.shape[:2]
    for position in np.ndindex(n_rows, n_cols):
        yield position, image[position]
223
224
225
def bounded_slice(center, xy_max, size=1, i_min=0):
    """ Return one slice per axis covering ``center +/- size``, clipped.

    Parameters
    ----------
    center : sequence of int
        Index of the window center along each axis.
    xy_max : sequence of int
        Exclusive upper bound along each axis (e.g. an array shape).
    size : int
        Half-width of the window; unclipped, it spans 2*size + 1 indices.
    i_min : int
        Lower bound applied on every axis.

    Returns
    -------
    list of slice
        One slice per (center, bound) pair.
    """
    return [slice(max(c - size, i_min), min(c + size + 1, upper))
            for c, upper in zip(center, xy_max)]
230
231