Source code for continunet.network.unet

"""UNet model for image segmentation in keras."""

import numpy as np

from keras.layers import (
    Activation,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    Dropout,
    Input,
    MaxPooling2D,
)
from keras.models import Model
from tensorflow.keras.optimizers import Adam

from continunet.constants import CYAN, RESET


class Unet:
    """UNet model for image segmentation.

    The network is assembled from ``layers`` encoding blocks, one latent
    convolutional block and ``layers`` decoding blocks joined by skip
    connections. The number of filters doubles at every encoding level and
    halves again on the way back up.
    """

    def __init__(
        self,
        input_shape: tuple,
        filters: int = 16,
        dropout: float = 0.05,
        batch_normalisation: bool = True,
        trained_model: str = None,
        image: np.ndarray = None,
        layers: int = 4,
        output_activation: str = "sigmoid",
        model: "Model" = None,
        reconstructed: np.ndarray = None,
    ):
        """
        Initialise the UNet model.

        Parameters
        ----------
        input_shape : tuple
            The shape of the input image.
        filters : int
            The number of filters to use in the first convolutional layer,
            default is 16. Deeper levels use ``filters * 2**level``.
        dropout : float
            The dropout rate, default is 0.05.
        batch_normalisation : bool
            Whether to use batch normalisation, default is True.
        trained_model : str
            The path to a trained model (passed to ``load_weights``).
        image : np.ndarray
            The image to decode. Image must be 2D given as 4D numpy array,
            e.g. (1, 256, 256, 1). Image must be grayscale, e.g. not
            (1, 256, 256, 3). Image rows and columns must be divisible by
            2^layers, e.g. 256 % 2^4 == 0.
        layers : int
            The number of encoding and decoding layers, default is 4.
        output_activation : str
            The activation function for the output layer, either sigmoid or
            softmax. Default is sigmoid.
        model : keras.models.Model
            A pre-built model. When omitted, one is created by the
            build_model method.
        reconstructed : np.ndarray
            The reconstructed image, created by the decode_image method.
        """
        self.input_shape = input_shape
        self.filters = filters
        self.dropout = dropout
        self.batch_normalisation = batch_normalisation
        self.trained_model = trained_model
        self.image = image
        self.layers = layers
        self.output_activation = output_activation
        self.reconstructed = reconstructed
        # Respect a caller-supplied model; previously it was always
        # overwritten by a freshly built network, making the argument dead.
        self.model = model if model is not None else self.build_model()

    def convolutional_block(self, input_tensor, filters, kernel_size=3):
        """Convolutional block for UNet.

        Applies a same-padded, he_normal-initialised convolution, then batch
        normalisation when ``self.batch_normalisation`` is set, then ReLU.

        Parameters
        ----------
        input_tensor
            The tensor to convolve.
        filters : int
            The number of convolution filters.
        kernel_size : int
            Side length of the square kernel, default is 3.

        Returns
        -------
        The activated output tensor.
        """
        tensor = Conv2D(
            filters=filters,
            kernel_size=(kernel_size, kernel_size),
            kernel_initializer="he_normal",
            padding="same",
        )(input_tensor)
        # Only create the normalisation layer when it is actually wired in.
        if self.batch_normalisation:
            tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    def encoding_block(self, input_tensor, filters, kernel_size=3):
        """Encoding block for UNet.

        Runs a convolutional block, then 2x2 max pooling followed by dropout.

        Parameters
        ----------
        input_tensor
            The tensor entering this encoder level.
        filters : int
            The number of convolution filters for this level.
        kernel_size : int
            Side length of the square kernel, default is 3.

        Returns
        -------
        tuple
            ``(skip_tensor, downsampled_tensor)``: the convolutional output
            before pooling (kept for the skip connection) and the pooled,
            dropout-regularised tensor fed to the next level.
        """
        skip_tensor = self.convolutional_block(input_tensor, filters, kernel_size)
        pooled = MaxPooling2D((2, 2), padding="same")(skip_tensor)
        return skip_tensor, Dropout(self.dropout)(pooled)

    def decoding_block(self, input_tensor, concat_tensor, filters, kernel_size=3):
        """Decoding block for UNet.

        Upsamples with a stride-2 transposed convolution, concatenates the
        result with the matching encoder tensor, applies dropout and finishes
        with a convolutional block.

        Parameters
        ----------
        input_tensor
            The tensor coming from the deeper level.
        concat_tensor
            The encoder tensor to concatenate as the skip connection.
        filters : int
            The number of filters for the transposed and regular convolutions.
        kernel_size : int
            Kernel side length for the final convolutional block, default is
            3. The transposed convolution always uses a 3x3 kernel.

        Returns
        -------
        The output tensor of this decoder level.
        """
        upsampled = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding="same")(input_tensor)
        merged = Concatenate()([upsampled, concat_tensor])
        return self.convolutional_block(Dropout(self.dropout)(merged), filters, kernel_size)

    def build_model(self):
        """Build the UNet model.

        Returns
        -------
        keras.models.Model
            A model mapping the ``"img"`` input to a single-channel output
            produced by a 1x1 convolution with ``self.output_activation``.
        """
        input_image = Input(self.input_shape, name="img")
        current = input_image

        # Encoding path: keep every pre-pooling tensor for the skip connections.
        skip_tensors = []
        for depth in range(self.layers):
            skip_tensor, current = self.encoding_block(current, self.filters * 2**depth)
            skip_tensors.append(skip_tensor)

        # Latent convolutional block at the bottom of the "U".
        current = self.convolutional_block(current, filters=self.filters * 2**self.layers)

        # Decoding path: walk back up, pairing each level with its encoder tensor.
        for depth in reversed(range(self.layers)):
            current = self.decoding_block(current, skip_tensors[depth], self.filters * 2**depth)

        # NOTE(review): the output has a single channel, so a softmax over
        # that channel would always evaluate to 1.0 -- confirm before using
        # output_activation="softmax".
        outputs = Conv2D(1, (1, 1), activation=self.output_activation)(current)
        return Model(inputs=[input_image], outputs=[outputs])

    def compile_model(self):
        """Compile the UNet model.

        Compiles ``self.model`` with the Adam optimiser and binary
        cross-entropy loss.

        Returns
        -------
        The compiled model (the same object as ``self.model``).
        """
        # NOTE(review): "iou_score" does not appear to be a built-in Keras
        # metric identifier; confirm it is registered before calling
        # fit/evaluate on the compiled model.
        self.model.compile(
            optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy", "iou_score"]
        )
        return self.model

    def decode_image(self):
        """Returns images decoded by a trained model.

        Validates ``self.image``, compiles the model, loads the weights from
        ``self.trained_model`` and stores the prediction in
        ``self.reconstructed``.

        Returns
        -------
        The prediction returned by ``self.model.predict``.

        Raises
        ------
        ValueError
            If ``trained_model`` or ``image`` is missing, the image is not
            4D, is not single-channel, or its rows/columns are not divisible
            by ``2**layers``.
        TypeError
            If the image is not a numpy array.
        """
        # Validate everything first so the progress message is only printed
        # when a prediction is actually going to run.
        if self.trained_model is None or self.image is None:
            raise ValueError("Trained model and image arguments are required to decode image.")
        if not isinstance(self.image, np.ndarray):
            raise TypeError("Image must be a numpy array.")
        if self.image.ndim != 4:
            raise ValueError("Image must be 4D numpy array for example (1, 256, 256, 1).")
        if self.image.shape[3] != 1:
            raise ValueError("Input image must be grayscale.")

        # Each of the `layers` pooling steps halves the spatial size, so both
        # spatial dimensions must divide evenly by 2**layers.
        divisor = 2**self.layers
        rows, columns = self.image.shape[1:3]
        if rows % divisor != 0 or columns % divisor != 0:
            raise ValueError(
                f"Image shape {self.image.shape[1:3]} must be divisible by "
                f"2**layers={divisor}."
            )

        print(f"{CYAN}Predicting source segmentation using pre-trained model...{RESET}")
        self.model = self.compile_model()
        self.model.load_weights(self.trained_model)
        self.reconstructed = self.model.predict(self.image)
        return self.reconstructed