"""Analysis of semantic segmentation models."""
from pydantic import InstanceOf, NonNegativeInt, PositiveInt, validate_call
from typing import Any, Literal, Optional, Union
import warnings
from pathlib import Path
import joblib
import pandas as pd
import numpy as np
import lightning as l
import shap
import shap.maskers
import torch
from torchmetrics import ConfusionMatrix
from remote_sensing_processor.common.torch_test import cuda_test
from remote_sensing_processor.common.types import FilePath
from remote_sensing_processor.segmentation.segmentation import sklearn_load_dataset
from remote_sensing_processor.segmentation.semantic.models import pytorch_models, sklearn_models
from remote_sensing_processor.segmentation.semantic.segmentation import (
SemanticSegmentationDataModule,
SemanticSegmentationModel,
SklearnSemanticSegmentationModel,
)
class TorchChannelMasker(shap.maskers.Image):
    """
    SHAP masker that perturbs entire image channels (bands).

    Each SHAP feature is one input channel. A masked-out channel is replaced everywhere by the
    per-channel mean of the background data. Inherits from ``shap.maskers.Image`` so that SHAP
    explainers accept it.
    """

    def __init__(self, background_data: torch.Tensor) -> None:
        """
        Initialize the masker from background images.

        Parameters
        ----------
        background_data : torch.Tensor
            Background images of shape (batch_size, num_channels, H, W). Their mean over the batch
            and spatial dimensions is used as the replacement value for each masked channel.
        """
        # Per-channel baseline, shape (1, num_channels, 1, 1)
        self.background = background_data.mean(dim=(0, 2, 3), keepdim=True)
        # NOTE(review): relies on shap.maskers.Image storing a flattened copy of this array in
        # `self.mask_value`, which `__call__` reads; confirm against the installed SHAP version.
        super().__init__(self.background.detach().cpu().numpy())

    def mask_shapes(self, x: Union[np.ndarray, torch.Tensor]) -> list[tuple[int]]:
        """
        Return the mask shape of each input argument.

        The explainer passes a single image argument, so exactly one shape is returned: one mask
        entry per channel (``x.shape[0]``, assuming ``x`` is a single (num_channels, H, W) image).
        """
        return [(x.shape[0],)]

    def __call__(self, mask: Union[np.ndarray, torch.Tensor], x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Apply a channel keep-mask to ``x``.

        Parameters
        ----------
        mask : np.ndarray or torch.Tensor
            Per-channel keep-mask: truthy entries keep the original channel and falsy entries replace
            it with the background channel mean. Either 1D (num_channels,) or
            2D (num_samples, num_channels).
        x : np.ndarray or torch.Tensor
            Input image(s) with channels on the axis that lines up with axis 1 of the reshaped mask,
            e.g. (num_channels, H, W) or (batch_size, num_channels, H, W).

        Returns
        -------
        torch.Tensor
            ``x`` with masked channels replaced, broadcast against the (·, num_channels, 1, 1) mask.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(x)
        # torch.where needs a boolean condition on the same device as x. as_tensor also accepts
        # masks that are already tensors or are stored as 0/1 integers.
        channel_mask = torch.as_tensor(mask, dtype=torch.bool, device=x.device)
        if channel_mask.dim() == 1:
            # (num_channels,) -> (1, num_channels, 1, 1)
            channel_mask = channel_mask.reshape(1, -1, 1, 1)
        else:
            # (num_samples, num_channels) -> (num_samples, num_channels, 1, 1)
            channel_mask = channel_mask.reshape(channel_mask.shape[0], -1, 1, 1)
        # Background channel means, reshaped for broadcasting against x
        background_values = torch.from_numpy(self.mask_value).float().view(1, -1, 1, 1).to(x.device)
        # Keep original values where the mask is True and use the channel mean elsewhere
        return torch.where(channel_mask, x, background_values)
@validate_call
def band_importance(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[SemanticSegmentationModel], InstanceOf[SklearnSemanticSegmentationModel]],
    target_class: Optional[PositiveInt] = None,
    num_init_images: Optional[PositiveInt] = 100,
    num_images: Optional[PositiveInt] = 500,
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
) -> None:
    """
    Explain the band importance for a pre-trained model using SHAP.

    The first `num_init_images` samples of the dataset initialize the SHAP explainer. The next
    `num_images` samples (no overlap with the first group) are explained. The mean absolute SHAP
    value of every band is printed.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        The dict passed by the caller is not modified.
        Dataset can contain 3 elements:
            `path`: a path to a dataset.
            `sub`: subdataset name, list of subdataset names or 'all'.
                If not defined, the whole dataset will be used.
            `y`: if there is more than one target variable in dataset,
                then the name of the target variable to use should be defined.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function or file (*.ckpt or *.joblib) where model is stored.
    target_class : int (optional)
        Index of the class to analyze band importance. If not set, will analyze all the classes.
    num_init_images : int (default = 100)
        Number of images that will be used to initialize the SHAP explainer.
        Must be smaller than the number of samples in the dataset.
        We strongly recommend using very small `num_init_images` with sklearn models.
    num_images : int (default = 500)
        Number of images that will be used to explain band importance.
        Reduced (with a warning) if the dataset does not hold `num_init_images + num_images` samples.
        We strongly recommend using very small `num_images` with sklearn models.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can speed up data loading, but can also cause errors (e.g. pickling errors).

    Returns
    -------
    None
        The importance of each band is printed as `<band name>: <mean |SHAP value|>`.

    Raises
    ------
    ValueError
        If the dataset has no samples left to explain after taking `num_init_images`,
        if the model file extension is not .ckpt or .joblib, or if the model name is not supported.

    Examples
    --------
        >>> import remote_sensing_processor as rsp
        >>> x, y, out_file = ...
        >>> ds = rsp.semantic.generate_tiles(
        ...     x,
        ...     y,
        ...     out_file,
        ...     tile_size=256,
        ...     shuffle=True,
        ...     split={"train": 3, "val": 1, "test": 1},
        ... )
        >>> model = rsp.semantic.train(
        ...     {"path": ds, "sub": "train"},
        ...     {"path": ds, "sub": "val"},
        ...     model="UperNet",
        ...     backbone="ConvNeXTV2",
        ...     model_file="/home/rsp_test/model/upernet.ckpt",
        ...     batch_size=32,
        ... )
        >>> rsp.semantic.band_importance({"path": ds, "y": "landcover"}, model)
        PartitionExplainer explainer: 100%|██████████████████████████████████████████▉| 499/500 [42:07<00:05,  5.07s/it]
        Landsat-B1: 0.0162
        Landsat-B2: 0.0493
        Landsat-B3: 0.0875
        Landsat-B4: 0.0243
        Landsat-B5: 0.0319
        Landsat-B7: 0.0194
        NDVI: 0.0353
        NBR: 0.0281
        slope: 0.0134
        curvature: 0.0239
        aspect: 0.0311
        dem-norm: 0.0236

        >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
        >>> model = "/home/rsp_test/model/xgboost.joblib"
        >>> rsp.semantic.band_importance(ds, model, num_init_images=1, num_images=1)
        PartitionExplainer explainer: 16385it [1:12:01,  3.78it/s]
        coastal: 0.0266
        blue: 0.0266
        green: 0.0253
        red: 0.0542
        rededge071: 0.0194
        rededge075: 0.0194
        rededge078: 0.0196
        nir: 0.0034
        nir08: 0.0111
        nir09: 0.0758
        swir16: 0.2894
        swir22: 0.0309
        NDVI: 0.0483
        canopyheight_norm: 0.0005
        dem_norm: 0.0741
    """
    # Work on a copy so the caller's dict is not modified
    dataset = {**dataset, "predict": True}
    dataset.setdefault("sub", "all")
    # Setting datamodule
    dm = SemanticSegmentationDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)
    # Loading model
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = SemanticSegmentationModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")
    # Reading variables names
    variables = dm.variables
    dm.setup(stage="predict")
    num_samples = len(dm.ds_pred)
    # At least one sample must remain to be explained after the background samples are taken
    if num_init_images >= num_samples:
        raise ValueError("Not enough samples in dataset")
    if num_init_images + num_images > num_samples:
        num_images = num_samples - num_init_images
        warnings.warn(f"Not enough samples in dataset, num_images set to {num_images}", stacklevel=2)
    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=2)
        # Read background and explained images in ONE pass over the dataloader. Iterating the
        # dataloader a second time would restart from the first batch and make both sets overlap.
        num_needed = num_init_images + num_images
        image_batches = []
        count = 0
        for batch in dm.predict_dataloader():
            image_batches.append(batch["x"])
            count += batch["x"].shape[0]
            if count >= num_needed:
                break
        images = torch.cat(image_batches, dim=0)[:num_needed]
        background_data = images[:num_init_images]
        images_to_explain = images[num_init_images:]
        # Initialize masker with the background data
        masker = TorchChannelMasker(background_data)
        # Initialize explainer
        explainer = shap.PartitionExplainer(lambda x: torch_model_wrapper(x, model, target_class=target_class), masker)
        # The wrapper calls model.forward directly (not through a Lightning Trainer), so inference
        # mode (e.g. dropout disabled) must be set here. The previous mode is restored afterwards
        # because the model object may belong to the caller.
        was_training = model.training
        model.eval()
        try:
            values = explainer(images_to_explain).values
        finally:
            model.train(was_training)
    # Sklearn models
    elif model.model_name in sklearn_models:
        x_pred, _, _ = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Pixels per image; x_pred holds one row per pixel
        pixels_per_image = dm.input_shape**2
        # Load the background data
        background_data = x_pred[: num_init_images * pixels_per_image]
        explainer = shap.PartitionExplainer(
            lambda x: sklearn_model_wrapper(x, model, target_class=target_class),
            background_data,
        )
        # Run SHAP explainer on the pixels that follow the background pixels
        images_to_explain = x_pred[num_init_images * pixels_per_image : (num_init_images + num_images) * pixels_per_image]
        values = explainer(images_to_explain).values
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")
    # Mean absolute SHAP value per band. atleast_1d keeps single-band results indexable after squeeze.
    global_channel_importance = np.atleast_1d(np.mean(np.abs(values), axis=0).squeeze())
    for i, var in enumerate(variables):
        print(f"{var}: {global_channel_importance[i]:.4f}")
def torch_model_wrapper(
    x: Union[np.ndarray, torch.Tensor],
    model: SemanticSegmentationModel,
    target_class: Optional[int] = None,
) -> np.ndarray:
    """
    Score each image by its mean per-pixel class probability.

    Parameters
    ----------
    x : np.ndarray or torch.Tensor
        Batch of images, moved to ``model.device`` and passed to ``model.forward`` under key "x".
    model : SemanticSegmentationModel
        Model whose ``forward`` returns a 4-tuple; its second item is softmax-ed over axis 1
        (assumed to be the class axis of a (batch_size, num_classes, H, W) output).
    target_class : int, optional
        Class whose probability is averaged. If None, the per-pixel maximum probability over all
        classes is averaged instead.

    Returns
    -------
    np.ndarray
        Array of shape (batch_size, 1) with one score per image.
    """
    images = x if isinstance(x, torch.Tensor) else torch.from_numpy(x)
    images = images.to(model.device)
    model_input = {"key": [0] * images.shape[0], "x": images}
    with torch.no_grad():
        _first, logits, _third, _fourth = model.forward(model_input)
        probabilities = torch.softmax(logits, dim=1)
        if target_class is not None:
            # Slice (rather than index) to keep the class axis with size 1
            pixel_scores = probabilities[:, target_class : target_class + 1, :, :]
        else:
            # Most likely class probability at every pixel
            pixel_scores = probabilities.max(dim=1, keepdim=True).values
        # Average over the two spatial axes -> (batch_size, 1)
        image_scores = pixel_scores.mean(dim=(2, 3))
    return image_scores.detach().cpu().numpy()
def sklearn_model_wrapper(
    x: np.ndarray,
    model: SklearnSemanticSegmentationModel,
    target_class: Optional[int] = None,
) -> np.ndarray:
    """
    Return float32 model predictions for SHAP.

    Hard predictions from ``model.predict`` are used because predict_proba is not always available.

    Parameters
    ----------
    x : np.ndarray
        Samples to predict. A single 1D sample is treated as a batch of one.
    model : SklearnSemanticSegmentationModel
        Fitted model exposing ``predict``.
    target_class : int, optional
        If given, return 1.0 where the prediction equals this class and 0.0 elsewhere.
        If None, return the predicted labels cast to float32.

    Returns
    -------
    np.ndarray
        float32 array with one value per sample.
    """
    samples = x.reshape(1, -1) if x.ndim == 1 else x
    predicted = model.predict(samples)
    if target_class is None:
        return predicted.astype("float32")
    hits = predicted == target_class
    return hits.astype("float32")
@validate_call
def confusion_matrix(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[SemanticSegmentationModel], InstanceOf[SklearnSemanticSegmentationModel]],
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
) -> pd.DataFrame:
    """
    Generate confusion matrix.

    Row indices of the confusion matrix correspond to the true class labels
    and column indices correspond to the predicted class labels.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        The dict passed by the caller is not modified.
        Dataset can contain 3 elements:
            `path`: a path to a dataset.
            `sub`: subdataset name, list of subdataset names or 'all'.
                If not defined, prediction for the whole dataset will be performed.
            `y`: if there is more than one target variable in dataset,
                then the name of the target variable to compare predictions against should be defined.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function or file (*.ckpt or *.joblib) where model is stored.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can speed up data loading, but can also cause errors (e.g. pickling errors).

    Returns
    -------
    pd.DataFrame
        Square (num_classes x num_classes) matrix of pixel counts, with columns labelled
        0 .. num_classes - 1. Pixels whose true label equals the model's `y_nodata` value are ignored.

    Raises
    ------
    ValueError
        If the model file extension is not .ckpt or .joblib, or if the model name is not supported.

    Examples
    --------
        >>> import remote_sensing_processor as rsp
        >>> x, y, out_file = ...
        >>> ds = rsp.semantic.generate_tiles(
        ...     x,
        ...     y,
        ...     out_file,
        ...     tile_size=256,
        ...     shuffle=True,
        ...     split={"train": 3, "val": 1, "test": 1},
        ... )
        >>> model = rsp.semantic.train(
        ...     {"path": ds, "sub": "train"},
        ...     {"path": ds, "sub": "val"},
        ...     model="UperNet",
        ...     backbone="ConvNeXTV2",
        ...     model_file="/home/rsp_test/model/upernet.ckpt",
        ...     batch_size=32,
        ... )
        >>> rsp.semantic.confusion_matrix({"path": ds, "y": "landcover"}, model)
        Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [03:35<00:00,  6.71s/it]
        GPU available: True (cuda), used: True
        TPU available: False, using: 0 TPU cores
        HPU available: False, using: 0 HPUs
        Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [06:22<00:00, 16.10s/it]
        LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
        Predicting DataLoader 0: 100%|██████████████████████████████████████████▉|77/77 [11:27<00:00,  0.11it/s]
        Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [01:04<00:00,  2.73s/it]
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        |   | 0 |    1     |    2    |    3    |    4     |    5    |    6    |   7   | 8 | 9 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 0 | 0 |    0     |    0    |    0    |    0     |    0    |    0    |   0   | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 1 | 0 | 57715149 | 616341  | 457550  | 1406527  | 508368  |  92479  | 6481  | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 2 | 0 | 2844286  | 1599082 | 1083022 | 1371587  | 174298  |  40392  | 2411  | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 3 | 0 | 1831478  | 558398  | 2673250 | 5310441  | 555055  |  94868  | 8320  | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 4 | 0 | 1981444  | 188564  | 1732960 | 18768424 | 5257066 | 408550  | 8889  | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 5 | 0 |  798174  |  24353  | 120365  | 4698820  | 7825820 | 1432090 | 20451 | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 6 | 0 |  246685  |  9658   |  13940  |  280521  | 2202678 | 2090811 | 57364 | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 7 | 0 |  12279   |  1288   |  2347   |  11409   | 112254  | 345227  | 57887 | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 8 | 0 |    0     |    0    |    0    |   585    |    8    |    0    |   0   | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
        | 9 | 0 |    27    |    0    |    0    |   643    |    0    |    0    |   0   | 0 | 0 |
        +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+

        >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
        >>> model = "/home/rsp_test/model/xgboost.joblib"
        >>> cm = rsp.semantic.confusion_matrix(ds, model)  # same output format as above
    """
    # Work on a copy so the caller's dict is not modified
    pred_dataset = {**dataset, "predict": True}
    pred_dataset.setdefault("sub", "all")
    # Setting datamodule
    dm = SemanticSegmentationDataModule(pred_dataset=pred_dataset, batch_size=batch_size, num_workers=num_workers)
    # Loading model
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = SemanticSegmentationModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")
    dm.setup(stage="predict")
    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=2)
        # Predict
        trainer = l.Trainer(precision=model.precision, enable_checkpointing=False)
        # Each per-batch output holds two items; only the first (the predictions) is kept
        predictions, _ = zip(*trainer.predict(model, dm), strict=True)
        predictions = torch.cat(predictions, dim=0)
    # Sklearn models
    elif model.model_name in sklearn_models:
        # The datamodule was already set up above; setting it up again would reload the dataset
        x_pred, _, _ = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Predict everything, then restore the per-image layout (one row per pixel -> tiles)
        predictions = model.predict(x_pred)
        predictions = predictions.reshape(len(dm.ds_pred), dm.input_shape, dm.input_shape)
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")
    # Sklearn predictions are NumPy arrays; convert once instead of re-checking inside the loop
    if not isinstance(predictions, torch.Tensor):
        predictions = torch.from_numpy(predictions)
    cm = ConfusionMatrix(task="multiclass", num_classes=model.num_classes, ignore_index=model.y_nodata)
    # Ground truth comes from the same dataset description loaded in test mode
    test_dataset = {**pred_dataset, "predict": False}
    dmt = SemanticSegmentationDataModule(test_datasets=[test_dataset], batch_size=batch_size, num_workers=num_workers)
    dmt.setup(stage="test")
    # NOTE(review): assumes predictions and dmt.ds_test list samples in the same order; confirm.
    for i in range(len(dmt.ds_test)):
        y = dmt.ds_test[i]["y"]
        cm.update(predictions[i].to(y.device), y)
    return pd.DataFrame(cm.compute().detach().cpu().numpy(), columns=range(model.num_classes))