"""UNet model for image segmentation in keras."""
import numpy as np
from keras.layers import (
Activation,
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
Input,
MaxPooling2D,
)
from keras.models import Model
from tensorflow.keras.optimizers import Adam
from continunet.constants import CYAN, RESET
class Unet:
    """UNet encoder/decoder model for image segmentation.

    Builds a configurable UNet (variable depth, filter count, dropout and
    batch normalisation) and can run inference with pre-trained weights via
    ``decode_image``.
    """

    def __init__(
        self,
        input_shape: tuple,
        filters: int = 16,
        dropout: float = 0.05,
        batch_normalisation: bool = True,
        trained_model: str = None,
        image: np.ndarray = None,
        layers: int = 4,
        output_activation: str = "sigmoid",
        model: Model = None,
        reconstructed: np.ndarray = None,
    ):
        """
        Initialise the UNet model.

        Parameters
        ----------
        input_shape : tuple
            The shape of the input image.
        filters : int
            The number of filters to use in the convolutional layers, default is 16.
        dropout : float
            The dropout rate, default is 0.05.
        batch_normalisation : bool
            Whether to use batch normalisation, default is True.
        trained_model : str
            The path to a trained model (weights file loadable by ``load_weights``).
        image : np.ndarray
            The image to decode. Image must be 2D given as 4D numpy array, e.g. (1, 256, 256, 1).
            Image must be grayscale, e.g. not (1, 256, 256, 3). Image array row columns must
            be divisible by 2^layers, e.g. 256 % 2^4 == 0.
        layers : int
            The number of encoding and decoding layers, default is 4.
        output_activation : str
            The activation function for the output layer, either sigmoid or softmax.
            Default is sigmoid.
        model : keras.models.Model
            A pre-built model; if omitted, one is built by ``build_model``.
        reconstructed : np.ndarray
            The reconstructed image, created by the decode_image method.
        """
        self.input_shape = input_shape
        self.filters = filters
        self.dropout = dropout
        self.batch_normalisation = batch_normalisation
        self.trained_model = trained_model
        self.image = image
        self.layers = layers
        self.output_activation = output_activation
        self.reconstructed = reconstructed
        # Respect a caller-supplied model. Previously ``self.model = model`` was
        # unconditionally overwritten by ``build_model()``, making the ``model``
        # argument a silent no-op.
        self.model = model if model is not None else self.build_model()

    def convolutional_block(self, input_tensor, filters, kernel_size=3):
        """Apply Conv2D (+ optional BatchNorm), then ReLU, to ``input_tensor``."""
        convolved = Conv2D(
            filters=filters,
            kernel_size=(kernel_size, kernel_size),
            kernel_initializer="he_normal",
            padding="same",
        )(input_tensor)
        if self.batch_normalisation:
            convolved = BatchNormalization()(convolved)
        return Activation("relu")(convolved)

    def encoding_block(self, input_tensor, filters, kernel_size=3):
        """Encoder step: conv block, then 2x2 max-pool followed by dropout.

        Returns
        -------
        tuple
            ``(pre_pool, pooled)`` — the pre-pool tensor is kept for the skip
            connection; the pooled tensor feeds the next (deeper) layer.
        """
        pre_pool = self.convolutional_block(input_tensor, filters, kernel_size)
        pooled = MaxPooling2D((2, 2), padding="same")(pre_pool)
        return pre_pool, Dropout(self.dropout)(pooled)

    def decoding_block(self, input_tensor, concat_tensor, filters, kernel_size=3):
        """Decoder step: upsample, concatenate skip connection, dropout, conv block."""
        # Honour ``kernel_size`` in the transpose convolution; it was previously
        # hard-coded to (3, 3), silently ignoring the parameter. The default (3)
        # preserves the original behaviour.
        upsampled = Conv2DTranspose(
            filters, (kernel_size, kernel_size), strides=(2, 2), padding="same"
        )(input_tensor)
        skip_connection = Concatenate()([upsampled, concat_tensor])
        return self.convolutional_block(
            Dropout(self.dropout)(skip_connection), filters, kernel_size
        )

    def build_model(self):
        """Build and return the (uncompiled) UNet keras ``Model``."""
        input_image = Input(self.input_shape, name="img")
        current = input_image

        # Encoding path: filter count doubles at each layer; keep the pre-pool
        # tensors for the decoder's skip connections.
        skip_tensors = []
        for layer in range(self.layers):
            skip, current = self.encoding_block(current, self.filters * (2**layer))
            skip_tensors.append(skip)

        # Latent (bottleneck) convolutional block.
        current = self.convolutional_block(current, filters=self.filters * 2**self.layers)

        # Decoding path mirrors the encoder in reverse order.
        for layer in reversed(range(self.layers)):
            current = self.decoding_block(
                current, skip_tensors[layer], self.filters * (2**layer)
            )

        # Single-channel output map with the configured activation.
        outputs = Conv2D(1, (1, 1), activation=self.output_activation)(current)
        return Model(inputs=[input_image], outputs=[outputs])

    def compile_model(self):
        """Compile the UNet model and return it."""
        # NOTE(review): "iou_score" is not a built-in keras metric; this relies
        # on it being registered elsewhere (e.g. via segmentation_models) —
        # confirm before deploying.
        self.model.compile(
            optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy", "iou_score"]
        )
        return self.model

    def decode_image(self):
        """Return images decoded by a trained model.

        Validates ``trained_model`` and ``image``, compiles the model, loads
        the trained weights and runs prediction.

        Returns
        -------
        np.ndarray
            The reconstructed (predicted) image array.

        Raises
        ------
        ValueError
            If required arguments are missing or the image shape is invalid.
        TypeError
            If ``image`` is not a numpy array.
        """
        if self.trained_model is None or self.image is None:
            raise ValueError("Trained model and image arguments are required to decode image.")
        if not isinstance(self.image, np.ndarray):
            raise TypeError("Image must be a numpy array.")
        if len(self.image.shape) != 4:
            raise ValueError("Image must be 4D numpy array for example (1, 256, 256, 1).")
        if self.image.shape[3] != 1:
            raise ValueError("Input image must be grayscale.")
        # Rows and columns must each halve cleanly through every pooling layer.
        if (
            self.image.shape[1] % (2**self.layers) != 0
            or self.image.shape[2] % (2**self.layers) != 0
        ):
            raise ValueError(
                f"Image shape {self.image.shape[1:3]} must be divisible by "
                f"2**layers={2**self.layers}."
            )
        # Announce only after validation so error paths are not preceded by a
        # misleading progress message.
        print(f"{CYAN}Predicting source segmentation using pre-trained model...{RESET}")
        self.model = self.compile_model()
        self.model.load_weights(self.trained_model)
        self.reconstructed = self.model.predict(self.image)
        return self.reconstructed