This tutorial is available as a Jupyter notebook.

🩺 Image Segmentation#

In this notebook you will use Composer and PyTorch to segment pneumothorax (air around or outside of the lungs) from chest radiographic images. This dataset was originally released for a Kaggle competition by the Society for Imaging Informatics in Medicine (SIIM).

Disclaimer: This example represents a minimal working baseline. To get competitive results, this notebook must run for a long time.

We will cover:

  • installing relevant packages

  • downloading the SIIM dataset from Kaggle

  • cleaning and resampling the dataset

  • splitting data for validation

  • visualizing model inputs

  • training a baseline model with Composer

  • using Composer methods

  • next steps


Let's get started and configure our environment.

Install Dependencies#

If you haven't already, let's install the following dependencies, which are needed for this example:

[ ]:
%pip install mosaicml kaggle pydicom git+https://github.com/qubvel/segmentation_models.pytorch opencv-python-headless jupyterlab-widgets

Kaggle Authentication#

To access the data you need a Kaggle account.

  • Accept the competition terms at https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/data

  • Download kaggle.json from https://www.kaggle.com/yourusername/account by clicking "Create new API token".

  • Upload the kaggle.json file using the following code cells.

[ ]:
from ipywidgets import FileUpload
from IPython.display import display
uploader = FileUpload(accept='.json', multiple=True)
display(uploader)  # render the upload widget in the notebook
[ ]:
import os

kaggle_folder = os.path.join(os.path.expanduser("~"), ".kaggle")
os.makedirs(kaggle_folder, exist_ok=True)
kaggle_config_file = os.path.join(kaggle_folder, "kaggle.json")
with open(kaggle_config_file, 'wb+') as output_file:
    for uploaded_filename in uploader.value:
        content = uploader.value[uploaded_filename]['content']
        output_file.write(content)  # save the uploaded token as ~/.kaggle/kaggle.json

Download and unzip the data#

[ ]:
!kaggle datasets download -d seesee/siim-train-test
!unzip -q siim-train-test.zip -d .

Flatten Image Directories#

The original dataset is oddly nested. We flatten it so the images are easier to access in our PyTorch dataset, moving each file from /siim/dicom-images-train/id/id/id.dcm to /siim/dicom-images-train/id.dcm.

[ ]:
from pathlib import Path
from tqdm.auto import tqdm

train_images = list(Path('siim/dicom-images-train').glob('*/*/*.dcm'))
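for image in tqdm(train_images):
    # move siim/dicom-images-train/id/id/id.dcm up to siim/dicom-images-train/id.dcm
    image.replace(Path('siim/dicom-images-train') / image.name)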
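To see what the flattening step does without touching the real dataset, here is a minimal sketch that builds a tiny nested tree in a temporary directory and flattens it the same way. The file names `a.dcm` and `b.dcm` are made up for illustration.

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / 'dicom-images-train'
    # Build a toy tree shaped like id/id/id.dcm (hypothetical ids 'a' and 'b').
    for name in ('a', 'b'):
        nested = root / name / name
        nested.mkdir(parents=True)
        (nested / f'{name}.dcm').write_bytes(b'')

    # Materialize the glob first, then move every file up to the root.
    for image in list(root.glob('*/*/*.dcm')):
        image.replace(root / image.name)

    flattened = sorted(p.name for p in root.glob('*.dcm'))
    print(flattened)
```

The now-empty nested folders are left in place; they are harmless for a dataset that looks files up by `ImageId`.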

Project setup#


[ ]:
import itertools
from ipywidgets import interact, fixed, IntSlider

import numpy as np
import pandas as pd
import torch
from torch import nn
import matplotlib.pyplot as plt
import cv2

from pydicom.filereader import dcmread
# model
import segmentation_models_pytorch as smp

# data
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import draw_segmentation_masks, make_grid

from sklearn.model_selection import StratifiedKFold

# transforms
from albumentations import ShiftScaleRotate, Resize, Compose

from torchmetrics import Metric
from torchmetrics.collections import MetricCollection

from composer import Trainer
from composer.models import ComposerModel
from composer.optim import DecoupledAdamW
from composer.metrics.metrics import Dice


Here we define some utility functions to help with logging, decoding/encoding targets and visualization.

[ ]:
class LossMetric(Metric):
    """Turns any torch.nn Loss Module into distributed torchmetrics Metric."""

    def __init__(self, loss, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.loss = loss
        self.add_state("sum_loss", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        """Update the state with new predictions and targets."""
        # Loss calculated over samples/batch, accumulate loss over all batches
        self.sum_loss += self.loss(preds, target)
        self.total_batches += 1

    def compute(self):
        """Aggregate state over all processes and compute the metric."""
        # Return average loss over entire validation dataset
        return self.sum_loss / self.total_batches

def rle2mask(rle, height=1024, width=1024, fill_value=1):
    mask = np.zeros((height, width), np.float32)
    mask = mask.reshape(-1)
    rle = np.array([int(s) for s in rle.strip().split(' ')])
    rle = rle.reshape(-1, 2)
    start = 0
    for index, length in rle:
        start = start+index
        end = start+length
        mask[start: end] = fill_value
        start = end
    mask = mask.reshape(width, height).T
    return mask

def mask2rle(mask):
    mask = mask.T.flatten()
    start = np.where(mask[1:] > mask[:-1])[0]+1
    end = np.where(mask[:-1] > mask[1:])[0]+1
    length = end-start
    rle = []
    for i in range(len(length)):
        if i == 0:
            rle.extend([start[0], length[0]])
        else:
            # later runs are stored relative to the end of the previous run
            rle.extend([start[i]-end[i-1], length[i]])
    rle = ' '.join([str(r) for r in rle])
    return rle
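The encoding used here is relative and column-major: each run is stored as (offset from the end of the previous run, length), and pixels are counted down the columns. The sketch below is a self-contained round trip on a 4x4 mask. `decode_rle` and `encode_rle` are illustrative stand-ins written for this example, not the helpers above; the encoder pads the flattened mask so that runs touching the first or last pixel are also handled.

```python
import numpy as np

def decode_rle(rle, height, width):
    """Decode a relative, column-major RLE string into a (height, width) mask."""
    flat = np.zeros(height * width, dtype=np.float32)
    pairs = np.array([int(s) for s in rle.split()]).reshape(-1, 2)
    pos = 0
    for offset, length in pairs:
        pos += offset                  # offset counts from the end of the previous run
        flat[pos:pos + length] = 1
        pos += length
    return flat.reshape(width, height).T

def encode_rle(mask):
    """Encode a binary mask as relative, column-major RLE."""
    flat = mask.T.flatten()
    diff = np.diff(np.concatenate([[0], flat, [0]]))
    starts = np.where(diff == 1)[0]    # first pixel of each run
    ends = np.where(diff == -1)[0]     # one past the last pixel of each run
    out, prev_end = [], 0
    for s, e in zip(starts, ends):
        out.extend([s - prev_end, e - s])
        prev_end = e
    return ' '.join(str(v) for v in out)

mask = np.zeros((4, 4), dtype=np.float32)
mask[1:3, 0] = 1   # column 0, rows 1-2 -> flat positions 1-2
mask[0:2, 2] = 1   # column 2, rows 0-1 -> flat positions 8-9

rle = encode_rle(mask)
print(rle)  # 1 2 5 2
restored = decode_rle(rle, 4, 4)
```

The second run starts at flat position 8, but it is stored as 5 because the previous run ended at position 3.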

Preprocessing and Data Science#

SIIM Dataset#

The SIIM dataset consists of:

  • dicom-images-train - 12954 labeled images in DICOM format.

  • dicom-images-test - 3205 unlabeled DICOM images for testing.

  • train-rle.csv - a label file mapping ImageId to EncodedPixels.

    • ImageIds map to image paths for DICOM format images.

    • EncodedPixels are run-length-encoded segmentation masks representing areas where pneumothorax has been labeled by an expert. A label of "-1" indicates the image was examined and no pneumothorax was found.

[ ]:
!ls siim
[ ]:
labels_df = pd.read_csv('siim/train-rle.csv')

Clean Data#

Of the ~13,000 samples, only about 3,600 have masks. We will throw out some of the negative samples to better balance our dataset and speed up training.

[ ]:
labels_df[labels_df[" EncodedPixels"] != "-1"].shape, labels_df[labels_df[" EncodedPixels"] == "-1"].shape
[ ]:
def balance_labels(labels_df, extra_samples_without_mask=1500, random_state=1337):
    """
    Drop duplicates and mark samples with masks.
    Sample len(df_with_mask)+extra_samples_without_mask unmasked samples to balance dataset.
    """
    df = labels_df.drop_duplicates('ImageId')
    df_with_mask = df[df[" EncodedPixels"] != "-1"].copy(deep=True)
    df_with_mask['has_mask'] = 1
    df_without_mask = df[df[" EncodedPixels"] == "-1"].copy(deep=True)
    df_without_mask['has_mask'] = 0
    df_without_mask_sampled = df_without_mask.sample(len(df_with_mask)+extra_samples_without_mask, random_state=random_state)
    df = pd.concat([df_with_mask, df_without_mask_sampled])
    return df
[ ]:
df = balance_labels(labels_df)
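To make the balancing logic concrete, here is a small sketch on a toy label table. The `balance` helper and the toy rows are hypothetical and only mirror the steps above; unlike `balance_labels`, the sketch also caps the sample size so it never asks for more negatives than exist.

```python
import pandas as pd

COL = " EncodedPixels"  # note the leading space in the column name

# Toy labels: image "a" has two annotated regions, "d" has one, the rest have none.
toy = pd.DataFrame({
    "ImageId": ["a", "a", "b", "c", "d", "e", "f"],
    COL: ["1 2", "9 3", "-1", "-1", "4 4", "-1", "-1"],
})

def balance(df, extra_negatives, random_state=1337):
    df = df.drop_duplicates("ImageId")                 # keeps the first row per image
    pos = df[df[COL] != "-1"].assign(has_mask=1)
    neg = df[df[COL] == "-1"].assign(has_mask=0)
    n_neg = min(len(neg), len(pos) + extra_negatives)  # cap at the available negatives
    return pd.concat([pos, neg.sample(n_neg, random_state=random_state)])

balanced = balance(toy, extra_negatives=1)
print(balanced["has_mask"].value_counts().to_dict())
```

One consequence worth knowing: because duplicates are dropped by `ImageId`, an image with several annotated regions keeps only its first run-length string.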

Create Cross Validation Splits#

Once cleaned and balanced, we're left with only 6838 images. This will leave us with rather small training and validation sets once we split the data. To reduce the chance of validating on a poorly sampled validation set (one that is not representative of our unlabeled test data), we use StratifiedKFold to create 5 different 80%/20% train/eval splits.


For datasets of this size, it's good practice to train and evaluate on each split, but due to runtime constraints in this notebook we will only train on the first split, which contains 5470 training and 1368 eval samples.

[ ]:
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1337)
train_idx, eval_idx = list(kfold.split(df["ImageId"], df["has_mask"]))[0]
train_df, eval_df = df.iloc[train_idx], df.iloc[eval_idx]
train_df.shape, eval_df.shape
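As a quick check of what stratification buys us, the sketch below runs the same StratifiedKFold settings on synthetic labels (40 positives, 60 negatives, made up for this example) and prints the size and positive fraction of every eval fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic stand-in for the has_mask column: 40 positives, 60 negatives.
labels = np.array([1] * 40 + [0] * 60)
ids = np.arange(len(labels))

kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1337)
folds = list(kfold.split(ids, labels))

for fold, (train_idx, eval_idx) in enumerate(folds):
    # Every eval fold keeps the overall 40% positive rate.
    print(fold, len(train_idx), len(eval_idx), labels[eval_idx].mean())
```

Each sample lands in exactly one eval fold, so training on all five splits would evaluate every image once.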


PyTorch Dataset#

SIIMDataset is a standard PyTorch dataset that reads images and decodes labels from the SIIM label CSV. DICOM images are loaded as grayscale NumPy arrays, converted to RGB, and scaled. Labels are converted from RLE strings to binary segmentation masks.

[ ]:
class SIIMDataset(Dataset):
    def __init__(self,
                 labels_df,
                 transforms=None,
                 image_dir=Path('siim/dicom-images-train')):
        self.labels_df = labels_df
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        image_id = row.ImageId
        image_path = self.image_dir / f'{image_id}.dcm'
        image = dcmread(image_path).pixel_array # load dicom image
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert rgb so we can keep imagenet first layer weights
        image = (image / 255.).astype('float32') # scale (0.- 1.)

        rle = row[' EncodedPixels']
        if rle != '-1':
            mask = rle2mask(rle, 1024, 1024).astype('float32')
        else:
            # no pneumothorax labeled: use an empty mask
            mask = np.zeros([1024, 1024]).astype('float32')

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

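        return (
            torch.from_numpy(image).permute(2, 0, 1),  # HWC -> CHW
            torch.from_numpy(mask).unsqueeze(0)  # add a channel axis to match the 1-channel model output
        )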

    def __len__(self):
        return len(self.labels_df)
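The shape and value changes an image goes through in `__getitem__` can be traced with NumPy alone. This is a sketch under stated assumptions: `np.repeat` stands in for `cv2.cvtColor(..., cv2.COLOR_GRAY2RGB)`, the toy image is 4x4 with 8-bit values, and the mask gets a leading channel axis on the assumption that it will be compared against a single-channel model output.

```python
import numpy as np

# Toy 4x4 "radiograph" with 8-bit pixel values 0, 17, ..., 255.
gray = np.arange(16, dtype=np.uint8).reshape(4, 4) * 17

rgb = np.repeat(gray[:, :, None], 3, axis=2)   # (H, W) -> (H, W, 3), channels identical
scaled = (rgb / 255.).astype('float32')        # 0..255 -> 0.0..1.0
chw = np.transpose(scaled, (2, 0, 1))          # (H, W, 3) -> (3, H, W), the PyTorch layout

mask = np.zeros((4, 4), dtype='float32')[None, :, :]   # (H, W) -> (1, H, W)

print(chw.shape, mask.shape, chw.min(), chw.max())
```

Dividing by 255 only maps pixel values into [0, 1] when the DICOM pixel data is 8-bit, which is what the dataset code assumes.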


We use the albumentations library to resize and randomly scale/rotate our training images.

[ ]:
image_size = 512

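train_transforms = Compose(
    [
        Resize(image_size, image_size),
        ShiftScaleRotate(
            shift_limit=0,  # no shift; the text mentions only scale/rotate (assumed value)
            scale_limit=0.1,  # scale (assumed value, the albumentations default)
            rotate_limit=10, # rotate
        ),
    ]
)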

eval_transforms = Compose([Resize(image_size, image_size)])


[ ]:

train_batch_size = 32
val_batch_size = 32

train_dataloader = DataLoader(SIIMDataset(train_df, transforms=train_transforms),
                              batch_size=train_batch_size, shuffle=True, num_workers=2)
eval_dataloader = DataLoader(SIIMDataset(eval_df, transforms=eval_transforms),
                             batch_size=val_batch_size, shuffle=False, num_workers=2)
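For orientation, the number of batches per epoch follows directly from the split sizes quoted earlier (5470 training and 1368 eval samples) and the batch size of 32. The arithmetic below assumes the DataLoader default of keeping the last, smaller batch.

```python
import math

train_samples, eval_samples = 5470, 1368   # split sizes quoted in the text
batch_size = 32

train_batches = math.ceil(train_samples / batch_size)
eval_batches = math.ceil(eval_samples / batch_size)
last_train_batch = train_samples - (train_batches - 1) * batch_size

print(train_batches, eval_batches, last_train_batch)  # 171 43 30
```

So one epoch is 171 training steps, and the final training batch holds 30 images instead of 32.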

Visualize batch#

Areas of pneumothorax are highlighted in red. Drag the slider to iterate through batches.

[ ]:
@interact(data_loader=fixed(train_dataloader), batch=IntSlider(min=0, max=len(train_dataloader)-1, step=1, value=0))
def show_batch(data_loader, batch):
    plt.rcParams['figure.figsize'] = [20, 15]

    images, masks = list(itertools.islice(data_loader, batch, batch+1))[0]
    masks_list = []
    for image, mask in zip(images, masks):
        masked = draw_segmentation_masks((image * 255).byte(),
                                    mask.bool(), alpha=0.5, colors='red')
        masks_list.append(masked)  # collect the overlay for the grid

    grid  = make_grid(masks_list, nrow=6)
    plt.imshow(grid.permute(1, 2, 0));



Here we define a Composer model that wraps the segmentation_models_pytorch (smp) package. This lets us quickly create many different segmentation models built on common pre-trained PyTorch encoders.

  • We set defaults to create a Unet from an ImageNet pre-trained ResNet34 with 3 input channels for our RGB (converted) inputs and 1 output channel.

  • We set the default loss to nn.BCEWithLogitsLoss() to classify each pixel of the output.
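To illustrate what "classify each pixel" means numerically, here is a NumPy sketch of binary cross-entropy on raw logits together with a plain binary Dice score. Both functions are illustrative stand-ins: `nn.BCEWithLogitsLoss` and Composer's `Dice` metric are the real implementations and may differ in details such as averaging and smoothing terms.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Mean per-pixel binary cross-entropy on raw logits (numerically stable form)."""
    # max(x, 0) - x * t + log(1 + exp(-|x|)) equals -[t*log(s) + (1-t)*log(1-s)], s = sigmoid(x)
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

def dice_score(logits, targets, eps=1e-7):
    """Overlap between thresholded predictions and the target mask."""
    preds = (logits > 0).astype(np.float32)   # logit > 0 is the same as sigmoid > 0.5
    intersection = (preds * targets).sum()
    return (2 * intersection + eps) / (preds.sum() + targets.sum() + eps)

targets = np.array([[0, 1], [1, 1]], dtype=np.float32)        # 2x2 toy mask
logits = np.array([[-2.0, 3.0], [0.5, -1.0]], dtype=np.float32)

loss = bce_with_logits(logits, targets)
dice = dice_score(logits, targets)
print(loss, dice)
```

In the toy example the prediction covers two of the three positive pixels and adds no false positives, so Dice is 2*2 / (2 + 3) = 0.8.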

[ ]:
class SMPUNet(ComposerModel):
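    def __init__(self,
                 encoder_name='resnet34',
                 encoder_weights='imagenet',
                 in_channels=3, classes=1,
                 loss=nn.BCEWithLogitsLoss()):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,           # encoder backbone (ResNet34 by default)
            encoder_weights=encoder_weights,     # use `imagenet` pre-trained weights for encoder initialization
            in_channels=in_channels,        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=classes         # model output channels (number of classes in your dataset)
        )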

        self.criterion = loss
        self.train_loss = LossMetric(loss)
        self.val_loss = LossMetric(loss)
        self.val_dice = Dice(num_classes=classes)

    def forward(self, batch):
        images, targets = batch
        return self.model(images)

    def loss(self, outputs, batch):
        _, targets = batch
        return self.criterion(outputs, targets)

    def metrics(self, train: bool = False):
        # use the `train` argument (not `self.train`, which is a method of nn.Module)
        return self.train_loss if train else MetricCollection([self.val_loss, self.val_dice])

    def validate(self, batch):
        images, targets = batch
        return self.model(images), targets
[ ]:
model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)


[ ]:
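trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    max_duration='2ep',  # assumed short placeholder run; increase for better results
)
trainer.fit()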


  • Composer allows us to quickly experiment with algorithms that can speed up training or improve the quality of our model. Here we add CutOut and LabelSmoothing.

  • Additionally, the Composer trainer has built-in support for automatic mixed precision training and gradient accumulation to help train quickly and simulate larger batch sizes.
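For intuition about the two methods named above, here is a rough NumPy sketch of the underlying ideas. It is not Composer's implementation: `cutout` and `smooth_labels` are hypothetical helpers, the patch placement is simplified, and the example assumes `length` is read as a fraction of the image side.

```python
import numpy as np

def cutout(image, length, rng):
    """Zero one square patch of a (C, H, W) image; `length` is a fraction of H and W."""
    _, h, w = image.shape
    half_h, half_w = int(h * length) // 2, int(w * length) // 2
    cy, cx = int(rng.integers(0, h)), int(rng.integers(0, w))   # random patch centre
    out = image.copy()
    # The patch is clipped at the borders, so it can be smaller than length * side.
    out[:, max(cy - half_h, 0):min(cy + half_h, h),
           max(cx - half_w, 0):min(cx + half_w, w)] = 0
    return out

def smooth_labels(one_hot, smoothing):
    """Blend one-hot targets with a uniform distribution over the classes."""
    num_classes = one_hot.shape[-1]
    return one_hot * (1 - smoothing) + smoothing / num_classes

rng = np.random.default_rng(0)
image = np.ones((3, 8, 8), dtype=np.float32)
patched = cutout(image, length=0.5, rng=rng)
smoothed = smooth_labels(np.array([[0.0, 1.0]]), smoothing=0.1)

print(int((patched == 0).sum()), smoothed)
```

With `smoothing=0.1` and two classes, the target `[0, 1]` becomes `[0.05, 0.95]`: the model is no longer pushed toward fully confident outputs.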

[ ]:
from composer.algorithms import CutOut, LabelSmoothing

model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

algorithms = [CutOut(length=0.5), LabelSmoothing(smoothing=0.1)]

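trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=optimizer,
    algorithms=algorithms,  # enable CutOut and LabelSmoothing
    max_duration='2ep',  # assumed short placeholder run; increase for better results
)
trainer.fit()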

Next steps#

  • train longer

  • try different loss functions, architectures, transformations

  • try different combinations of Composer methods!