TP2 - Segmentation sémantique avec un modèle encodeur/décodeur sur Pascal-VOC

Introduction

Contexte

Dans ce second TP, nous nous intéressons au problème de la segmentation sémantique, ou labélisation dense, sur les données Pascal VOC.

La base de données Pascal VOC contient des images de taille variable, RGB, dont les pixels sont étiquetés comme appartenant à l’une des 22 classes représentées ci-dessous avec leur code couleur.

Les 21 classes de PascalVOC avec leur code couleur.

Les 21 classes de PascalVOC avec leur code couleur. Produit par ce code

Des échantillons de la base sont représentés sur l’image ci-dessous, avec l’image et les étiquettes en sur-impression. Le contour des objets est étiquetés avec le label 255 alors que les 21 autres classes sont labelisées de 0 à 20. Les classes sont fortement déséquilibrées puisque la majorité des pixels sont des pixels de la classe « background ».

Des échantillons de la base Pascal VOC

Des échantillons de la base Pascal VOC

Nous allons suivre la même progression que dans le premier TP, en commençant par explorer les données, puis en implémentant un modèle dédié à cette tâche.

Setup

Pour réaliser ce TP, vous disposez de code à trou que vous pouvez récupérer dans le kit _static/lab2.tar.gz.

Vous allez utiliser les ressources de calcul du DCE et coder en utilisant Visual Studio Code. Sur le DCE, vous pouvez récupérer et extraire l’archive du kit par les commandes suivantes :

tar -zxvf lab2.tar.gz

Pour créer des environnements virtuels pour réaliser ce TP :

python3 -m venv lab2
source lab2/bin/activate
python -m pip install -r requirements.txt

Une fois l’environnement virtuel installé, vous pouvez le charger dans un terminal avec la commande :

source lab2/bin/activate

Sous Visual Studio Code, vous pouvez indiquer le chemin vers l’interpréteur python : labs2/bin/python

Assez peu de dépendance sont listées dans ce fichier requirements.txt. La principale est deepcs, un petit framework basé sur pytorch qui vous fourni quelques fonctions de haut niveau assez pratique. Sinon, sachez que pytorch est plutôt considéré comme un framework de bas niveau, avec un certain nombre de fonction à coder soit même.

Chargement et exploration des données

Le chargement des données est géré dans le script data.py. On commence par visualiser quelques échantillons. Dans ce TP, nous utilisons la librairie albumentations, spécifiquement dédiée aux augmentations de données.

Lisez la fonction plot_samples du script data.py. Vous pouvez exécuter ce code pour visualiser quelques échantillons :

python data.py

Dans cette fonction, vous verrez que quelques transformations sont appliquées sur les entrées. Le pipeline qui est construit est assez basique :

  • SmallestMaxSize permet de s’assurer que chaque dimension d’une image est au minimum de 256

  • RandomCrop permet de découper un patch de taille 256×256 choisi aléatoirement

  • Normalize(mean, std, max) applique la transformation (x - max.mean)/(max.std)

  • ToTensorV2 convertit l’image de type PIL en un tenseur pytorch

Complétez le pipeline de transformation des données en ajoutant des augmentations dans la liste augmentation_transforms. Vous disposez d’une liste de transformation sur la page de documentation d’albumentations

Quelques exemples ci-dessous d’échantillons augmentés en utilisant du CoarseDropout, PixelDropout, MaskDropout, HorizontalFlip, ShiftScaleRotate.

Des échantillons augmentés de la base Pascal VOC

Des échantillons augmentés de la base Pascal VOC

Par la suite, vous pourrez rajouter des transformations en évaluant leur impact sur les performances de votre modèle, à savoir l’estimation du risque réel sur le pli de validation.

Encodeur / Décodeur avec un backbone pré-entrainé

On se propose maintenant de pousser un peu plus loin l’exercice qui consiste à coder soit même un modèle. On va s’écrire un modèle de type encodeur/décodeur en partant d’un encodeur pré-entrainé.

Les modèles de classification de scène, par exemple ResNet, sont des modèles tout à fait convenables pour servir de point de départ pour l’encodeur. Pour l’utiliser comme un modèle de ségmentation, il suffit de lui ajouter une tête de décodage, une séquence de modules qui vont produire l’image de label prédit.

Le modèle que nous allons utiliser est le modèle ResNet-18 pour l’encodeur, auquel on va connecter un décodeur. L’image ci-dessous représente un modèle U-Net.

Encodeur/décoder de type U-Net

Encodeur/décoder de type U-Net

L’encodeur vous est déjà fourni par la classe TimmEncoder. On exécute les différents blocs du modèle manuellement :

x = torch.zeros((1, 3, 256, 256))
model = timm.create_model(model_name="resnet18", pretrained=True)
model.eval()

x = model.conv1(x)
x = model.bn1(x)
x = model.act1(x)
x = model.maxpool(x)

f1 = model.layer1(x)
f2 = model.layer2(f1)
f3 = model.layer3(f2)
f4 = model.layer4(f3)

Complétez la fonction test_timm et inspectez les dimensions des tenseurs f_1, f_2, f_3 et f_4.

Complétez la classe DecoderBlock qui est des modules répétés du décodeur. Les opérations d’un DecoderBlock(cin) sont les suivantes :

  • un bloc Conv(cin)-Relu-BN

  • une couche de sur-échantillonnage nn.UpSample, qui double les dimensions spatiales d’un facteur en interpolant son entrée

  • un block Conv(cin//2)-Relu-BN qui divise par 2 le nombre de canaux

  • une concaténation des attributs fournis par l’encodeur (f_3, f_2 ou f_1), en utilisant torch.cat

  • un block Conv(cin//2)-Relu-BN

Les blocs dénotes Conv(cin)-Relu-BN sont une répétition des opérations Conv2d-Relu-BatchNorm2d-Conv2d-Relu-BatchNorm2d. Vous disposez de la fonction conv_relu_bn qui implémente ces blocs. Le paramètre cin est le nombre de canaux de sortie du bloc complet.

Pour tester votre implémentation, faites bon usage de la fonction test_unet dans le script models.py.

Note

Les blocs d’encodeur et de décoder sont alignés de telle sorte que le nombre de canaux fournis par l’encodeur soit le même que le nombre de canaux produit par le décodeur avant la concaténation.

Par exemple, le premier bloc de décodeur DecoderBlock(cin=512) va produire 256 canaux avant la concaténation avec les attributs f_3 fourni par l’encodeur qui contiennent également 256 canaux.

Les dimensions spatiales sont également les mêmes. La concaténation s’opère après le sur-échantillonnage spatial.

Fonction de perte et métrique pour un problème de classification désiquilibré

La fonction de perte focal loss

Nous sommes confrontés à un problème de classification fortement déséquilibré. Il est tentant pour l’apprentissage de ne se focaliser que sur la classe majoritaire. Dans notre cas, c’est la classe background. Il existe plusieurs mécanismes pour contrebalancer l’effet de ce déséquilibre.

On se propose ici d’utiliser une fonction de perte un peu différente de la perte cross-entropique, la focal loss.

Si on note p ∈ [0,1]^K les probabilités affectées par votre modèle aux 21 classes, la perte cross-entropique s’écrit

log(pyi)

La focal loss pondère ce terme :

(1pyi)γlog(pyi)

avec gamma = 2 un facteur à déterminer. Plus la prédiction de votre modèle est bonne p_{y_i} ~ 1, plus le terme (1 - p_{y_i})^γ écrase la pénalité log(p_{y_i}). En d’autres termes, les pixels bien prédit influencent moins la focal loss que la perte cross-entropique.

Comme son implémentation est un peu délicate, la focal loss vous est fournie dans le script losses.py, prenez le temps de lire ce script, en particulier pour voir comment instancier cette fonction de perte.

La métrique F1

Dans un problème de classification déséquilibrée, l’accuracy n’est pas une bonne métrique de qualité puisqu’un modèle prédisant tout le temps la classe majoritaire peut avoir un bon score.

Dans cette situation, on lui préfère la F1 ∈ [0, 1]. On va calculer :

  • une F1 par classe

  • une macro-F1 qui est la moyenne des F1 par classe

La F1 se calcule dans un contexte de classification binaire comme la moyenne harmonique entre la précision et le rappel. Lorsqu’on calcule la F1 pour chaque classe, on est bien dans un contexte de classification binaire: la classe considérée contre le reste des classes.

F1=21precision+1recallprecision=TPTP+FPrecall=TPTP+FN

Note

Pour voir pourquoi la F1 est une meilleure métrique que l’accuracy dans un problème de classification déséquilibré, je vous invite à consulter ce notebook kaggle.

Nos premiers entrainements

Maintenant que les données, le modèle et la fonction de perte sont implémentés, vous pouvez lancer vos premiers entrainements et visualiser des résultats. Pour lancer un entrainement, il vous suffit d’invoquer le script main.py :

python main.py

Certaines options sont paramétrables en ligne de commande. Vous pouvez les découvrir par :

python main.py --help

Pendant l’entrainement, vous pouvez visualiser la progression de vos métriques en utilisant tensorboard :

tensorboard --logdir ./logs

En plus des métriques, le script d’entrainement qui vous est fourni sauvegarde des prédictions réalisées sur le pli de validation qui permettent de mieux apprécier la qualité des prédictions.

A vous de réaliser des expériences pour améliorer les performances de votre modèle et d’interpréter vos résultats. Il est impératif que dans votre rendu, vous décriviez les expériences que vous avez réaliées et que vous interprétiez les résultats (les métriques) et l’impact de vos choix sur ces résultats.

Pour aller plus loin

Utilisation du modèle DeepLab v3+

Comme c’est souvent le cas en deep learning, on a intérêt à utiliser un modèle pré-entrainé. On se propose de commencer avec le modèle DeepLabv3, qui est un modèle de ségmentation pré-entrainé sur la base COCO. Plus tard dans le TP, on implémentera soit même un modèle encodeur/décodeur.

Cette fois-ci, on utilise directement un modèle de segmentation complet, qui a déjà été complètement pré-entrainé sur la base COCO.

Complétez le script models.py pour utiliser l’un des modèles DeepLabV3 mis à disposition par torchvision. Vous devez compléter la classe DeepLabV3 du script models.py. Je vous invite à mettre en place un test dans le script models.py pour vous assurer que la passe forward à travers le modèle se passe bien.

Lancez des entrainements avec ce modèle pré-entrainé et comparez les résultats avec votre modèle n’utilisant que le backbone pré-entrainé.

Rendre plus générique l’architecture encodeur/décodeur

Le code qui vous est fourni reste un petit peu trop spécifique à l’encodeur Resnet18. En particulier, le code de l’encodeur déroule manuellement les différentes étapes de l’encodeur, ce qui n’est pas très pratique si vous voulez tester plusieurs architectures.

Vous pouvez rendre plus générique l’architecture encodeur/décodeur parce que l’API de Timm vous offre une méthode pour extraire attributs extrait d’un backbone de classification.

Adaptant l’exemple de la documentation de Timm:

import torch
import timm

 m = timm.create_model('resnet18', features_only=True, pretrained=True)
 o = m(torch.randn(1, 3, 256, 256))

 for x in o:
     print(f"Feature shape : {x.shape}")

produit la sortie

Feature shape : torch.Size([1, 64, 128, 128])
Feature shape : torch.Size([1, 64, 64, 64])
Feature shape : torch.Size([1, 128, 32, 32])
Feature shape : torch.Size([1, 256, 16, 16])
Feature shape : torch.Size([1, 512, 8, 8])

Sur cet exemple, vous noterez une petite différence par rapport au TP, avec un premier tenseur de taille [1, 64, 128, 128]. Dans tout les cas, votre code générique doit être capable de s’adapter à ces différentes tailles de tenseurs et les différents nombres d’attributs ressortis par le backbone.

Avec un code générique, vous pourriez facilement tester différents backbones comme resnet18, efficientnet_b3, res2net50_26w_4s, etc..

Lorsque l’argument features_only est passé au constructeur du modèle, la passe en avant retourne une liste de tenseurs correspondant aux attributs extrait aux différents étages de l’encodeur. Vous pouvez alors rendre votre réseau UNet plus générique de la manière suivante :

  • vous passez en argument du constructeur de votre modèle le nom du modèle de classification que vous souhaitez utiliser,

  • vous pouvez adapter la passe en avant de l’encodeur pour récupérer l’ensemble des attributs extraits par le backbone,

Pour que le décodeur s’adapte à la structure de l’encodeur, il suffit de modifier le constructeur de votre UNet pour, à la construction, passer un tenseur fictif à l’encodeur, récupérer les attributs extraits et passer le nombre d’attributs et les dimensions des tenseurs au constructeur du décodeur.

Pour mettre en oeuvre le décodeur, compte tenu du fait qu’il doit contenir un nombre variable de bloc de décodeur, ne les stockez pas dans une liste mais dans un torch.nn.ModuleList. Si vous ne les stockez pas dans un ModuleList, les modules définis ne seront pas correctement pris en compte par Pytorch.

Utilisation de la librairie Segmentation Models Pytorch

Pour votre infomation, la librairie segmentation_models_pytorch vous met à disposition plusieurs architectures, pré-entrainées, des fonctions de perte, des métriques, etc.. qui pourraient vous être utiles si vous êtes amenés à travailler sur cette problématique de segmentation sémantique.