Source code for leicaautomator.viewer

"""
scikit-image viewer plugins and widgets.
"""
from skimage import viewer, draw, filters, exposure, measure, color, morphology
from skimage.measure._regionprops import _RegionProperties

import scipy.ndimage as nd
import numpy as np
from matplotlib.patches import Polygon

##
# Viewer
##
class ImageViewer(viewer.viewers.ImageViewer):
    "override viewer to not emit plugin._update_original_image"

    #copied from scikit-image
    def __add__(self, plugin):
        """Add plugin to ImageViewer"""
        plugin.attach(self)
    # do not emit
    #    self.original_image_changed.connect(plugin._update_original_image)

        if plugin.dock:
            location = self.dock_areas[plugin.dock]
            dock_location = viewer.qt.Qt.DockWidgetArea(location)
            dock = viewer.qt.QtWidgets.QDockWidget()
            dock.setWidget(plugin)
            dock.setWindowTitle(plugin.name)
            self.addDockWidget(dock_location, dock)

            horiz = (self.dock_areas['left'], self.dock_areas['right'])
            dimension = 'width' if location in horiz else 'height'
            self._add_widget_size(plugin, dimension=dimension)

        return self

    def show(self):
        "Call filter_image of first plugin, then super.show()"
        # find first plugin which have image_filter != None
        for plugin in self.plugins:
            if plugin.image_filter:
                plugin.filter_image()
                break
        return super(ImageViewer, self).show()



##
# Plugins
##
class SeriesPlugin(viewer.plugins.Plugin):
    "Attach widgets in series. Output of one plugin is sent to the next one."

    def attach(self, image_viewer):
        """Override attach to link plugins with plugin.image_changed instead
        of all listening to image_viewer.original_image_changed
        """
        self.dock = 'right'
        self.setParent(image_viewer)
        self.setWindowFlags(viewer.qt.QtCore.Qt.Dialog)

        self.image_viewer = image_viewer
        if len(image_viewer.plugins) == 0:
            self.arguments = [image_viewer.image]
            image_viewer.original_image_changed.connect(self._update_original_image)
        else:
            self.arguments = [image_viewer.plugins[-1].arguments[0]]
            image_viewer.plugins[-1].image_changed.connect(self._update_original_image)

        image_viewer.plugins.append(self)
        # do not filter image, wait until plugin.show is called
        #self.filter_image()


class EnablePlugin(SeriesPlugin):
    "Plugin with checkbox for enable/disable"
    def __init__(self, **kwargs):
        super(EnablePlugin, self).__init__(**kwargs)
        enable = viewer.widgets.CheckBox('enabled', value=False, ptype='plugin')
        self.add_widget(enable)
        self.enabled = False

    def update_plugin(self, name, val):
        super(EnablePlugin, self).update_plugin(name,val)
        self.filter_image()

    def filter_image(self, *args, **kwargs):
        "Filter if plugin enabled and we have image."
        if self.enabled and len(self.arguments):
            arguments = [self._get_value(a) for a in self.arguments]
            kwargs = dict([(name, self._get_value(a))
                           for name, a in self.keyword_arguments.items()])
            filtered = self.image_filter(*arguments, **kwargs)
        elif len(self.arguments):
            # not enabled
            filtered = self.arguments[0]

        if self is self.image_viewer.plugins[-1]:
            # last plugin, update view
            self.display_filtered_image(filtered)

        # send to next plugin
        self.image_changed.emit(filtered)



class SelemPlugin(EnablePlugin):
    """Add selem size widget for filters that use selem, instead of defining a
    separate filter-function for each of them.
    """
    selem_size = 3
    def __init__(self, **kwargs):
        super(SelemPlugin, self).__init__(**kwargs)
        size = viewer.widgets.Slider('selem', low=1, high=10,
            value=self.selem_size, value_type='int', ptype='plugin',
            update_on='release')
        self.add_widget(size)
        size.callback = self.update_selem
        self.keyword_arguments['selem'] = morphology.square(self.selem_size)

    def update_selem(self, name, value):
        self.keyword_arguments['selem'] = morphology.square(value)
        self.filter_image()


class CropPlugin(SeriesPlugin):
    "Crop plugin with reset button"
    def __init__(self, maxdist=10, **kwargs):
        super(CropPlugin, self).__init__(**kwargs)
        self.name = 'Crop'
        self.maxdist = maxdist

    def attach(self, image_viewer):
        super(CropPlugin, self).attach(image_viewer)
        self.rect_tool = viewer.canvastools.RectangleTool(image_viewer,
                                       maxdist=self.maxdist,
                                       on_enter=self.crop)
        self.artists.append(self.rect_tool)
        self.add_widget(ResetWidget())

    def crop(self, extents):
        xmin, xmax, ymin, ymax = extents
        cropped = self.arguments[0][ymin:ymax+1, xmin:xmax+1]

        self.display_filtered_image(cropped)
        self.image_changed.emit(cropped)


class EntropyPlugin(SelemPlugin):
    name = "Entropy"

    def image_filter(self, img, selem, **kwargs):
        ent = filters.rank.entropy(img, selem)
        return exposure.rescale_intensity(ent)


class PopBilateralPlugin(SelemPlugin):
    name = "Bilateral population"
    selem_size = 9
    width = 20 # bandwith of intensity values

    def __init__(self, **kwargs):
        super(PopBilateralPlugin, self).__init__(**kwargs)
        self.s0 = viewer.widgets.Slider('s0', low=0, high=10,
            value=self.width//2, value_type='int', update_on='release')
        self.s1 = viewer.widgets.Slider('s1', low=0, high=10,
            value=self.width//2, value_type='int', update_on='release')

        self.add_widget(self.s0)
        self.add_widget(self.s1)

    def image_filter(self, img, **kwargs):
        filtered = filters.rank.pop_bilateral(img, **kwargs)
        return exposure.rescale_intensity(-filtered)


class MeanPlugin(SelemPlugin):
    name = 'Mean'
    selem_size = 9

    def image_filter(self, img, **kwargs):
        return filters.rank.mean(img, **kwargs)


class OtsuPlugin(EnablePlugin):
    name = "Otsu Threshold"

    def image_filter(self, image, **kwargs):
        t = filters.threshold_otsu(image)
        return image >= t


class LiThresholdPlugin(EnablePlugin):
    name = "Li Threshold"
    def __init__(self, **kwargs):
        super(LiThresholdPlugin, self).__init__(**kwargs)
        self._invert = viewer.widgets.CheckBox('invert', value=False, ptype='plugin')
        self.add_widget(self._invert)
        self.invert = False

    def image_filter(self, image, **kwargs):
        t = filters.threshold_li(image)
        if self.invert:
            return image < t
        else:
            return image >= t


class ErosionPlugin(SelemPlugin):
    name = "Erosion"

    def image_filter(self, image, selem, **kwargs):
        from skimage import morphology
        return morphology.erosion(image, selem)


class DilationPlugin(SelemPlugin):
    name = "Dilation"

    def image_filter(self, image, selem, **kwargs):
        from skimage import morphology
        return morphology.dilation(image, selem)


class MinimumAreaPlugin(EnablePlugin):
    def __init__(self, minimum_area=4000, **kwargs):
        super(MinimumAreaPlugin, self).__init__(**kwargs)
        self.name = "Minimum area"
        area = viewer.widgets.Slider('minimum_area', low=1000, high=100000,
                    value=minimum_area, value_type='int')
        self.add_widget(area)


    def image_filter(self, img, minimum_area, **kwargs):
        from numpy import bincount

        labels = measure.label(img)
        counts = bincount(labels.ravel())
        # set background count to zero
        counts[counts.argmax()] = 0
        mask = counts > minimum_area
        return mask[labels]


class FillHolesPlugin(EnablePlugin):
    def __init__(self, **kwargs):
        super(FillHolesPlugin, self).__init__(**kwargs)
        self.name = 'Fill holes'
        self.add_widget(viewer.widgets.CheckBox('clear_border'))
        self.add_widget(viewer.widgets.Slider('zero_border', low=0, high=20,
                            value=3, value_type='int'))

    def image_filter(self, img, clear_border, zero_border):
        cleared = img.copy()
        if zero_border:
            a = zero_border
            cleared[ :a,:] = 0
            cleared[-a:,:] = 0
            cleared[:, :a] = 0
            cleared[:,-a:] = 0
        if clear_border:
            segmentation.clear_border(cleared)
        return nd.morphology.binary_fill_holes(cleared)


class LabelPlugin(EnablePlugin):
    name = 'Label'
    def image_filter(self, img, **kwargs):
        l = measure.label(img, background=0)
        if l.max() > 2**16-1:
            print('more than 2^16 labels, aborting labeling')
            return img
        return color.label2rgb(l, image=self.image_viewer.original_image)
                                #bg_color=(0,0,0))


class RegionPlugin(EnablePlugin):
    name = 'Region'
    def __init__(self, **kwargs):
        super(RegionPlugin, self).__init__(**kwargs)
        self.max_regions = viewer.widgets.Slider('maximum number of regions',
                low=0, high=150, value=129, value_type='int', ptype='plugin')
        self.add_widget(self.max_regions)


    def attach(self, image_viewer):
        super(RegionPlugin, self).attach(image_viewer)
        self.regions = []

        self.move_region = MoveRegion(image_viewer, self)
        image_viewer.add_tool(self.move_region)


    def filter_image(self, *args, **kwargs):
        if self.regions:
            # remove previous regions from canvas
            for r in self.regions:
                try:
                    r._polygon.remove()
                    r._text.remove()
                except ValueError:
                    continue
        super(RegionPlugin, self).filter_image(*args, **kwargs)


    def image_filter(self, img):
        self.labels = measure.label(img, background=0)
        # do not use label 0
        self.labels[self.labels==0] = self.labels.max() + 1
        # sorted by size, largest first
        self.regions = sorted((r for r in measure.regionprops(self.labels)),
                              key=lambda r: -r.area)
        # only keep max_regions
        if len(self.regions) > self.max_regions.val:
            self.regions = self.regions[:self.max_regions.val]

        self.median_area = np.median([r.area for r in self.regions])

        self.set_coordinates()
        self.set_well_positions()
        self.create_polygons()
        # overlay on original image
        return self.image_viewer.original_image


    def set_coordinates(self):
        if not self.regions:
            return

        for r in self.regions:
            r.y, r.x, r.y_end, r.x_end = r.bbox


    def create_polygons(self):
        "Creates region._polygon which can be added to the mpl axes."
        if not self.regions:
            return

        for region in self.regions:
            region = create_polygon(region)


    def display_filtered_image(self, image):
        "Display original image with polygons, instead of segmented image."
        ax = self.image_viewer.ax

        # set image and add polygons if called with args
        self.image_viewer.image = image

        if self.enabled:
            for r in self.regions:
                ax.add_patch(r._polygon)
            self.set_texts()
            self.image_viewer.canvas.draw()


    def set_well_positions(self):
        """Set property well_x/y on region.

        Returns
        -------
        list of skimage.regionprops
            Regions with extra property ``well_x`` and ``well_y`` set.
        """
        for direction in ['x', 'y']:
            regions = sorted(self.regions, key=lambda r: getattr(r, direction))

            # calc dx
            previous = regions[0]
            for region in regions:
                dx = getattr(region, direction) - getattr(previous, direction)
                setattr(region, 'd' + direction, dx)
                previous = region

            dxs = np.array([getattr(r, 'd' + direction) for r in regions])
            min_threshold = dxs.max() * 0.5
            mask = np.index_exp[dxs > min_threshold][0]
            # do not include all high dxs
            max_threshold = dxs.max() * 0.9
            mask &= np.index_exp[dxs < max_threshold][0]
            step = np.median(dxs[mask])

            # add well_x/y property to region
            well = 0
            previous = regions[0]
            for r in regions:
                dx = getattr(r, direction) - getattr(previous, direction)
                # if gradient to prev coordinate is high, we have a new row/column
                if dx > min_threshold:
                    well += 1
                setattr(r, 'well_' + direction, well) # start at 1
                previous = r

        self.regions = regions
        return regions


    def set_texts(self):
        "create _text property of well positions"
        ax = self.image_viewer.ax
        for r in self.regions:
            text = '%s,%s' % (r.well_x+1, r.well_y+1) # (1,1) top left
            x = r.x + (r.x_end - r.x) / 4
            y = r.y_end - (r.y_end - r.y) / 3
            try:
                r._text.set_text(text)
                r._text.set_position((x, y))
            except AttributeError:
                r._text = ax.text(x, y, text, color='w',
                                  fontsize=14, backgroundcolor='k')


    def output(self):
        return (self.labels, self.regions)

##
# Helper functions
##
[docs]def create_polygon(r): r.vertices = ((r.x, r.y), (r.x, r.y_end), (r.x_end, r.y_end), (r.x_end, r.y)) r._polygon = Polygon(r.vertices, fill=False, edgecolor='y', linewidth=2) return r ## # Widgets ##
class ResetWidget(viewer.widgets.BaseWidget): "Reset button which sets image to original_image" def __init__(self): super(ResetWidget, self).__init__(self) self.reset_button = viewer.qt.QtGui.QPushButton('Reset') self.reset_button.clicked.connect(self.reset) self.layout = viewer.qt.QtGui.QHBoxLayout(self) self.layout.addWidget(self.reset_button) def reset(self): img = self.plugin.image_viewer.original_image.copy() self.plugin.display_filtered_image(img) self.plugin.image_changed.emit(img) ## # Canvas tools ## class MoveRegion(viewer.canvastools.base.CanvasToolBase): """Moves regions around by clicking on them. http://matplotlib.org/users/event_handling.html#draggable-rectangle-exercise """ def __init__(self, image_viewer, region_plugin): super(MoveRegion, self).__init__(image_viewer) self.region_plugin = region_plugin self.region = None # selected region self.canvas = self.viewer.canvas def on_mouse_press(self, event): if not event.xdata or not event.ydata: return x = int(event.xdata) y = int(event.ydata) # store position, for calculation dx/dy self.x = x self.y = y # will select first region if two regions overlap self.region = next((r for r in self.region_plugin.regions if x >= r.x and x <= r.x_end and y >= r.y and y <= r.y_end), None) if event.dblclick and self.region: # remove self.region_plugin.regions.remove(self.region) self.region._polygon.remove() self.region._text.remove() self.canvas.draw() self.region = None return elif event.dblclick: # add region where double click is at label = self.region_plugin.labels.max() + 1 # square in label image width = (self.region_plugin.median_area)**0.5 / 2 slice_ = (slice(y - width, y + width), slice(x - width, x + width)) self.region_plugin.labels[slice_] = label # add region r = _RegionProperties(slice_, label, self.region_plugin.labels, intensity_image=None, cache_active=False) r.y, r.x = r.centroid r.x -= width r.y -= width # draw square around regions of interest r.x_end = r.x + 2*width r.y_end = r.y + 2*width r = create_polygon(r) self.ax.add_patch(r._polygon) self.region_plugin.regions.append(r) self.region_plugin.set_well_positions() self.region_plugin.set_texts() self.ax.draw_artist(r._polygon) self.ax.draw_artist(r._text) self.canvas.blit(self.ax.bbox) return elif self.region: self.region._polygon.set_animated(True) self.background = self.canvas.copy_from_bbox(self.ax.bbox) def on_move(self, event): if not event.xdata or not event.ydata: return if not self.region: return x = int(event.xdata) y = int(event.ydata) dx = x - self.x dy = y - self.y if dx == 0 and dy == 0: return vertices = [(v[0]+dx, v[1]+dy) for v in self.region.vertices] self.region._polygon.set_xy(vertices) # draw self.canvas.restore_region(self.background) self.ax.draw_artist(self.region._polygon) self.canvas.blit(self.ax.bbox) def on_mouse_release(self, event): if not self.region: return x = int(event.xdata) y = int(event.ydata) dx = x - self.x dy = y - self.y if dx == 0 and dy == 0: # on release first click in double click return vertices = [(v[0]+dx, v[1]+dy) for v in self.region.vertices] self.region._polygon.set_xy(vertices) self.region.vertices = vertices self.region.x += dx self.region.x_end += dx self.region.y += dy self.region.y_end += dy self.region._polygon.set_animated(False) self.region_plugin.set_well_positions() self.region_plugin.set_texts() self.canvas.draw() self.region = None