# Source code for remote_sensing_processor.segmentation.semantic.analysis

"""Analysis of semantic segmentation models."""

from pydantic import InstanceOf, NonNegativeInt, PositiveInt, validate_call
from typing import Any, Literal, Optional, Union

import warnings
from pathlib import Path

import joblib
import pandas as pd

import numpy as np

import lightning as l
import shap
import shap.maskers
import torch
from torchmetrics import ConfusionMatrix

from remote_sensing_processor.common.torch_test import cuda_test
from remote_sensing_processor.common.types import FilePath
from remote_sensing_processor.segmentation.segmentation import sklearn_load_dataset
from remote_sensing_processor.segmentation.semantic.models import pytorch_models, sklearn_models
from remote_sensing_processor.segmentation.semantic.segmentation import (
    SemanticSegmentationDataModule,
    SemanticSegmentationModel,
    SklearnSemanticSegmentationModel,
)


class TorchChannelMasker(shap.maskers.Image):
    """
    A custom SHAP masker that perturbs entire image channels.

    Every SHAP feature corresponds to one input channel (band). A masked-out channel is replaced
    by the per-channel mean of the background data. Inherits from ``shap.maskers.Image`` for
    compatibility with ``shap.PartitionExplainer``.
    """

    def __init__(self, background_data: torch.Tensor) -> None:
        """
        Initialize the masker from background images.

        Parameters
        ----------
        background_data : torch.Tensor
            Background images of shape (N, C, H, W). Their mean over the batch and spatial
            dimensions is used as the replacement value for masked channels.
        """
        # Integer imagery (e.g. raw digital numbers) cannot be averaged by torch directly.
        if not torch.is_floating_point(background_data):
            background_data = background_data.float()
        # We use the per-channel mean of the background data as the baseline: shape (1, C, 1, 1).
        self.background = background_data.mean(dim=(0, 2, 3), keepdim=True)
        super().__init__(self.background.detach().cpu().numpy())

    def mask_shapes(self, x: Union[np.ndarray, torch.Tensor]) -> list[tuple[int]]:
        """
        Explicitly set mask shapes.

        For a single sample ``x`` of shape (C, H, W) the mask has one entry per channel, i.e.
        shape ``(C,)``.
        """
        # NOTE(review): SHAP expects one shape per masker argument; presumably only the first
        # entry is read here, so the repetition is harmless. TODO confirm.
        return [(x.shape[0],)] * x.shape[0]

    def __call__(
        self,
        mask: Union[np.ndarray, torch.Tensor],
        x: Union[np.ndarray, torch.Tensor],
    ) -> torch.Tensor:
        """
        Apply the channel mask to the input ``x``.

        Parameters
        ----------
        mask : np.ndarray or torch.Tensor
            Channel mask whose truthy entries keep the original channel. It can be 1-D
            ``(num_channels,)`` or 2-D ``(num_samples, num_channels)``.
        x : np.ndarray or torch.Tensor
            Input image(s), either a single sample (C, H, W) or a batch (N, C, H, W).

        Returns
        -------
        torch.Tensor
            The input with masked channels replaced by the background channel means.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x)

        # SHAP passes a numpy mask, but tensors are accepted too. torch.where needs a boolean
        # condition on the same device as x.
        channel_mask = torch.as_tensor(mask, device=x.device).bool()
        # Make the mask broadcastable to (N, C, H, W); reshape also handles non-contiguous masks.
        if channel_mask.dim() <= 1:
            channel_mask = channel_mask.reshape(1, -1, 1, 1)
        else:
            channel_mask = channel_mask.reshape(channel_mask.shape[0], -1, 1, 1)

        # self.mask_value holds the per-channel background means (set by the base class).
        background_values = torch.as_tensor(self.mask_value, device=x.device).float().reshape(1, -1, 1, 1)

        # Keep the original values where the mask is True, use the background elsewhere.
        return torch.where(channel_mask, x, background_values)


def _print_band_importance(values: np.ndarray, variables: Any) -> None:
    """Print the mean absolute SHAP value of every band, one ``name: value`` line per band."""
    # Averaging |SHAP| over the explained samples gives one global score per band. atleast_1d
    # keeps the result indexable when squeeze() drops every axis (single-band datasets).
    # `variables` is presumably a sequence of band names (dm.variables) - TODO confirm.
    importance = np.atleast_1d(np.mean(np.abs(values), axis=0).squeeze())
    for i, var in enumerate(variables):
        print(f"{var}: {importance[i]:.4f}")


@validate_call
def band_importance(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[SemanticSegmentationModel], InstanceOf[SklearnSemanticSegmentationModel]],
    target_class: Optional[PositiveInt] = None,
    num_init_images: Optional[PositiveInt] = 100,
    num_images: Optional[PositiveInt] = 500,
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
) -> None:
    """
    Explain the band importance for a pre-trained model using SHAP.

    The first `num_init_images` samples of the dataset initialize the SHAP explainer and the
    following `num_images` samples are explained, so the two sets do not overlap.
    The mean absolute SHAP value of every band is printed.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        Dataset can contain 3 elements:
        `path`: a path to a dataset.
        `sub`: subdataset name, list of subdataset names or 'all'.
        If not defined, prediction for the whole dataset will be performed.
        `y`: if there is more than one target variable in dataset, then the name of the variable
        that should be used for original data reconstruction should be defined.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function or file (*.ckpt or *.joblib)
        where model is stored.
    target_class : int (optional)
        Index of the class to analyze band importance. If not set, will analyze all the classes.
    num_init_images : int (default = 100)
        Number of images that will be used to initialize the SHAP explainer.
        We strongly recommend using very small `num_init_images` with sklearn models.
    num_images : int (default = 500)
        Number of images that will be used to explain band importance.
        If the dataset is too small, it is reduced to the number of remaining samples.
        We strongly recommend using very small `num_images` with sklearn models.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can increase processing speed, but can also cause errors (e.g. pickling errors).

    Raises
    ------
    ValueError
        If the model file extension or the model name is not supported, or if the dataset
        does not contain more than `num_init_images` samples.

    Examples
    --------
    >>> import remote_sensing_processor as rsp
    >>> x, y, out_file = ...
    >>> ds = rsp.semantic.generate_tiles(
    ...     x,
    ...     y,
    ...     out_file,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> model = rsp.semantic.train(
    ...     {"path": ds, "sub": "train"},
    ...     {"path": ds, "sub": "val"},
    ...     model="UperNet",
    ...     backbone="ConvNeXTV2",
    ...     model_file="/home/rsp_test/model/upernet.ckpt",
    ...     batch_size=32,
    ... )
    >>> rsp.semantic.band_importance({"path": ds, "y": "landcover"}, model)
    PartitionExplainer explainer: 100%|██████████████████████████████████████████▉| 499/500 [42:07<00:05, 5.07s/it]
    Landsat-B1: 0.0162
    Landsat-B2: 0.0493
    Landsat-B3: 0.0875
    Landsat-B4: 0.0243
    Landsat-B5: 0.0319
    Landsat-B7: 0.0194
    NDVI: 0.0353
    NBR: 0.0281
    slope: 0.0134
    curvature: 0.0239
    aspect: 0.0311
    dem-norm: 0.0236

    >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
    >>> model = "/home/rsp_test/model/xgboost.joblib"
    >>> rsp.semantic.band_importance(ds, model, num_init_images=1, num_images=1)
    PartitionExplainer explainer: 16385it [1:12:01, 3.78it/s]
    coastal: 0.0266
    blue: 0.0266
    green: 0.0253
    red: 0.0542
    rededge071: 0.0194
    rededge075: 0.0194
    rededge078: 0.0196
    nir: 0.0034
    nir08: 0.0111
    nir09: 0.0758
    swir16: 0.2894
    swir22: 0.0309
    NDVI: 0.0483
    canopyheight_norm: 0.0005
    dem_norm: 0.0741
    """
    dataset["predict"] = True
    if "sub" not in dataset:
        dataset["sub"] = "all"
    # The signature lets pydantic accept None, which would break the arithmetic below;
    # fall back to the documented defaults.
    if num_init_images is None:
        num_init_images = 100
    if num_images is None:
        num_images = 500
    # Setting datamodule
    dm = SemanticSegmentationDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)
    # Loading model
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = SemanticSegmentationModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")
    # Reading variables names
    variables = dm.variables
    dm.setup(stage="predict")
    n_samples = len(dm.ds_pred)
    # At least one sample must remain after the background set, otherwise nothing is explained.
    if num_init_images >= n_samples:
        raise ValueError("Not enough samples in dataset")
    if num_init_images + num_images > n_samples:
        num_images = n_samples - num_init_images
        warnings.warn(f"Not enough samples in dataset, num_images set to {num_images}", stacklevel=2)
    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=1)
        # Read background and explained images in ONE pass over the dataloader: iterating it a
        # second time would restart at the first sample and the explained images would overlap
        # the background images.
        n_needed = num_init_images + num_images
        chunks = []
        n_loaded = 0
        for batch in dm.predict_dataloader():
            chunks.append(batch["x"])
            n_loaded += batch["x"].shape[0]
            if n_loaded >= n_needed:
                break
        images = torch.cat(chunks, dim=0)
        background_data = images[:num_init_images]
        images_to_explain = images[num_init_images:n_needed]
        # Initialize masker (channel-level perturbation) and explainer
        masker = TorchChannelMasker(background_data)
        explainer = shap.PartitionExplainer(lambda x: torch_model_wrapper(x, model, target_class=target_class), masker)
        # Evaluate in inference mode so dropout / batch-norm layers (if any) behave deterministically
        # and batch-norm running statistics are not updated by the perturbed SHAP inputs.
        # The caller's train/eval mode is restored afterwards.
        was_training = model.training
        model.eval()
        try:
            values = explainer(images_to_explain).values
        finally:
            model.train(was_training)
        _print_band_importance(values, variables)
    # Sklearn models
    elif model.model_name in sklearn_models:
        x_pred, _, _ = sklearn_load_dataset(dm, "predict", model.generate_features)
        # x_pred rows are pixels; one tile contributes input_shape**2 rows.
        pixels_per_image = dm.input_shape**2
        # Load the background data
        background_data = x_pred[: num_init_images * pixels_per_image]
        explainer = shap.PartitionExplainer(
            lambda x: sklearn_model_wrapper(x, model, target_class=target_class),
            background_data,
        )
        # Run SHAP explainer on the rows that follow the background rows
        images_to_explain = x_pred[num_init_images * pixels_per_image : (num_init_images + num_images) * pixels_per_image]
        values = explainer(images_to_explain).values
        _print_band_importance(values, variables)
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")
def torch_model_wrapper(
    x: Union[np.ndarray, torch.Tensor],
    model: SemanticSegmentationModel,
    target_class: Optional[int] = None,
) -> np.ndarray:
    """
    Predicts and returns the average probability of a target class across the entire image.

    Parameters
    ----------
    x : np.ndarray or torch.Tensor
        Batch of (masked) images passed in by SHAP.
    model : SemanticSegmentationModel
        Network whose ``forward`` returns the class scores as its second output.
    target_class : int (optional)
        Class whose probability is averaged. If not set, the per-pixel maximum class
        probability is averaged instead.

    Returns
    -------
    np.ndarray
        Array of shape (batch_size, 1) with one averaged probability per image.
    """
    # SHAP may hand over numpy arrays; the network needs a tensor on its own device.
    tensor = x if isinstance(x, torch.Tensor) else torch.from_numpy(x)
    tensor = tensor.to(model.device)

    # forward() expects a batch dict; "key" is a dummy sample identifier.
    with torch.no_grad():
        _, scores, _, _ = model.forward({"key": [0] * tensor.shape[0], "x": tensor})
    probabilities = torch.nn.functional.softmax(scores, dim=1)

    if target_class is not None:
        # Keep the class axis (size 1) so the result stays (batch_size, 1, H, W).
        per_pixel = probabilities[:, target_class : target_class + 1, :, :]
    else:
        per_pixel = probabilities.max(dim=1, keepdim=True).values

    # Spatial average -> (batch_size, 1)
    return per_pixel.mean(dim=(2, 3)).detach().cpu().numpy()


def sklearn_model_wrapper(
    x: np.ndarray,
    model: SklearnSemanticSegmentationModel,
    target_class: Optional[int] = None,
) -> np.ndarray:
    """
    Predicts the class for a target, as predict_proba is not always available.

    Parameters
    ----------
    x : np.ndarray
        Pixel feature rows; a single 1-D row is promoted to a 2-D batch of one.
    model : SklearnSemanticSegmentationModel
        Model exposing ``predict``.
    target_class : int (optional)
        If set, return 1.0 where the prediction equals this class and 0.0 elsewhere;
        otherwise return the predicted labels themselves.

    Returns
    -------
    np.ndarray
        float32 array with one value per input row.
    """
    rows = x[np.newaxis, :] if x.ndim == 1 else x
    labels = model.predict(rows)
    if target_class is None:
        return labels.astype("float32")
    return (labels == target_class).astype("float32")
@validate_call
def confusion_matrix(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[SemanticSegmentationModel], InstanceOf[SklearnSemanticSegmentationModel]],
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
) -> pd.DataFrame:
    """
    Generate confusion matrix.

    Row indices of the confusion matrix correspond to the true class labels and column indices
    correspond to the predicted class labels.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        Dataset can contain 3 elements:
        `path`: a path to a dataset.
        `sub`: subdataset name, list of subdataset names or 'all'.
        If not defined, prediction for the whole dataset will be performed.
        `y`: if there is more than one target variable in dataset, then the name of the variable
        that should be used for original data reconstruction should be defined.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function or file (*.ckpt or *.joblib)
        where model is stored.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can increase processing speed, but can also cause errors (e.g. pickling errors).

    Returns
    -------
    pandas.DataFrame
        Square (num_classes x num_classes) confusion matrix; rows are true classes,
        columns are predicted classes.

    Raises
    ------
    ValueError
        If the model file extension or the model name is not supported, or if there are fewer
        predictions than test samples.

    Examples
    --------
    >>> import remote_sensing_processor as rsp
    >>> x, y, out_file = ...
    >>> ds = rsp.semantic.generate_tiles(
    ...     x,
    ...     y,
    ...     out_file,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> model = rsp.semantic.train(
    ...     {"path": ds, "sub": "train"},
    ...     {"path": ds, "sub": "val"},
    ...     model="UperNet",
    ...     backbone="ConvNeXTV2",
    ...     model_file="/home/rsp_test/model/upernet.ckpt",
    ...     batch_size=32,
    ... )
    >>> rsp.semantic.confusion_matrix({"path": ds, "y": "landcover"}, model)
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [03:35<00:00, 6.71s/it]
    GPU available: True (cuda), used: True
    TPU available: False, using: 0 TPU cores
    HPU available: False, using: 0 HPUs
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [06:22<00:00, 16.10s/it]
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Predicting DataLoader 0: 100%|██████████████████████████████████████████▉|77/77 [11:27<00:00, 0.11it/s]
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [01:04<00:00, 2.73s/it]
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    |   | 0 | 1        | 2       | 3       | 4        | 5       | 6       | 7     | 8 | 9 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 0 | 0 | 0        | 0       | 0       | 0        | 0       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 1 | 0 | 57715149 | 616341  | 457550  | 1406527  | 508368  | 92479   | 6481  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 2 | 0 | 2844286  | 1599082 | 1083022 | 1371587  | 174298  | 40392   | 2411  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 3 | 0 | 1831478  | 558398  | 2673250 | 5310441  | 555055  | 94868   | 8320  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 4 | 0 | 1981444  | 188564  | 1732960 | 18768424 | 5257066 | 408550  | 8889  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 5 | 0 | 798174   | 24353   | 120365  | 4698820  | 7825820 | 1432090 | 20451 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 6 | 0 | 246685   | 9658    | 13940   | 280521   | 2202678 | 2090811 | 57364 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 7 | 0 | 12279    | 1288    | 2347    | 11409    | 112254  | 345227  | 57887 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 8 | 0 | 0        | 0       | 0       | 585      | 8       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 9 | 0 | 27       | 0       | 0       | 643      | 0       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+

    >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
    >>> model = "/home/rsp_test/model/xgboost.joblib"
    >>> rsp.semantic.confusion_matrix(ds, model)
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [03:35<00:00, 6.71s/it]
    GPU available: True (cuda), used: True
    TPU available: False, using: 0 TPU cores
    HPU available: False, using: 0 HPUs
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [06:22<00:00, 16.10s/it]
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Predicting DataLoader 0: 100%|██████████████████████████████████████████▉|77/77 [11:27<00:00, 0.11it/s]
    Loading dataset from disk: 100%|██████████████████████████████████████████▉|37/37 [01:04<00:00, 2.73s/it]
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    |   | 0 | 1        | 2       | 3       | 4        | 5       | 6       | 7     | 8 | 9 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 0 | 0 | 0        | 0       | 0       | 0        | 0       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 1 | 0 | 57715149 | 616341  | 457550  | 1406527  | 508368  | 92479   | 6481  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 2 | 0 | 2844286  | 1599082 | 1083022 | 1371587  | 174298  | 40392   | 2411  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 3 | 0 | 1831478  | 558398  | 2673250 | 5310441  | 555055  | 94868   | 8320  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 4 | 0 | 1981444  | 188564  | 1732960 | 18768424 | 5257066 | 408550  | 8889  | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 5 | 0 | 798174   | 24353   | 120365  | 4698820  | 7825820 | 1432090 | 20451 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 6 | 0 | 246685   | 9658    | 13940   | 280521   | 2202678 | 2090811 | 57364 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 7 | 0 | 12279    | 1288    | 2347    | 11409    | 112254  | 345227  | 57887 | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 8 | 0 | 0        | 0       | 0       | 585      | 8       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    | 9 | 0 | 27       | 0       | 0       | 643      | 0       | 0       | 0     | 0 | 0 |
    +---+---+----------+---------+---------+----------+---------+---------+-------+---+---+
    """
    dataset["predict"] = True
    if "sub" not in dataset:
        dataset["sub"] = "all"
    # Setting datamodule
    dm = SemanticSegmentationDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)
    # Loading model
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = SemanticSegmentationModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")
    dm.setup(stage="predict")
    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=1)
        # Predict
        trainer = l.Trainer(precision=model.precision, enable_checkpointing=False)
        predictions, _ = zip(*trainer.predict(model, dm), strict=True)
        predictions = torch.cat(predictions, dim=0)
    # Sklearn models
    elif model.model_name in sklearn_models:
        # dm is already set up above; calling setup() again would reload the dataset from disk.
        x_pred, _, _ = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Predict everything, then restore the per-tile layout (tiles, height, width)
        predictions = model.predict(x_pred)
        predictions = predictions.reshape(len(dm.ds_pred), dm.input_shape, dm.input_shape)
        # Convert once instead of checking on every loop iteration below
        predictions = torch.as_tensor(np.asarray(predictions))
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")
    cm = ConfusionMatrix(task="multiclass", num_classes=model.num_classes, ignore_index=model.y_nodata)
    # Reload the same dataset with its targets to compare them against the predictions
    dataset["predict"] = False
    dmt = SemanticSegmentationDataModule(test_datasets=[dataset], batch_size=batch_size, num_workers=num_workers)
    dmt.setup(stage="test")
    n_test = len(dmt.ds_test)
    if len(predictions) < n_test:
        raise ValueError(
            f"Got {len(predictions)} predictions for {n_test} test samples; cannot build a confusion matrix"
        )
    # Predictions, targets and metric state must share one device (trainer output may be on the GPU).
    predictions = predictions.to(cm.device)
    for i in range(n_test):
        y = dmt.ds_test[i]["y"]
        cm.update(predictions[i], y.to(cm.device))
    return pd.DataFrame(cm.compute().detach().cpu().numpy(), columns=range(model.num_classes))