Source code for remote_sensing_processor.segmentation.regression.analysis

"""Analysis of semantic segmentation models."""

from pydantic import InstanceOf, NonNegativeInt, PositiveInt, validate_call
from typing import Any, Literal, Optional, Union

import warnings
from pathlib import Path

import joblib

import numpy as np

import shap
import shap.maskers
import torch

from remote_sensing_processor.common.torch_test import cuda_test
from remote_sensing_processor.common.types import FilePath
from remote_sensing_processor.segmentation.regression.models import pytorch_models, sklearn_models
from remote_sensing_processor.segmentation.regression.segmentation import (
    RegressionDataModule,
    RegressionModel,
    SklearnRegressionModel,
)
from remote_sensing_processor.segmentation.segmentation import sklearn_load_dataset


class TorchChannelMasker(shap.maskers.Image):
    """
    SHAP masker that hides whole image channels (bands) at once.

    A masked-out channel is replaced by the mean value of that channel over the background data;
    kept channels are passed through unchanged. Inherits from ``shap.maskers.Image`` so that SHAP
    explainers accept it.
    """

    def __init__(self, background_data: torch.Tensor) -> None:
        """
        Initialize the masker from background images.

        Parameters
        ----------
        background_data : torch.Tensor
            Background images of shape (N, C, H, W), or a single image of shape (C, H, W).
            Only the per-channel mean over all images and pixels is kept as the baseline.
        """
        if background_data.dim() == 3:
            # A single (C, H, W) image: add a batch axis so the reduction below is the same.
            background_data = background_data.unsqueeze(0)
        # Per-channel baseline of shape (1, C, 1, 1).
        self.background = background_data.mean(dim=(0, 2, 3), keepdim=True)
        # The parent class exposes this baseline as ``self.mask_value``, which ``__call__`` reads.
        super().__init__(self.background.detach().cpu().numpy())

    def mask_shapes(self, x: Union[np.ndarray, torch.Tensor]) -> list[tuple[int]]:
        """
        Return the mask shape for the explained input ``x``.

        SHAP expects one shape per explained argument. There is a single argument here, and its
        mask has one entry per channel (assumes ``x`` is one (C, H, W) sample).
        """
        return [(x.shape[0],)]

    def __call__(self, mask: np.ndarray, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        """
        Apply a channel mask to the input ``x``.

        Parameters
        ----------
        mask : np.ndarray
            Channel mask, either 1D (num_channels,) or 2D (num_samples, num_channels).
            True keeps a channel; False replaces it with the background value.
        x : np.ndarray or torch.Tensor
            Image data to mask; channels are expected on the axis that aligns with dim 1
            of the reshaped mask (e.g. (C, H, W) or (N, C, H, W)).

        Returns
        -------
        torch.Tensor
            The masked images, with a leading batch axis from the mask.
        """
        x = torch.as_tensor(x)

        # Build the boolean mask on x's device so torch.where does not mix devices, and
        # reshape it to be broadcastable against (N, C, H, W).
        channel_mask = torch.as_tensor(mask, dtype=torch.bool, device=x.device)
        if channel_mask.dim() == 1:
            # 1D mask -> (1, num_channels, 1, 1)
            channel_mask = channel_mask.reshape(1, -1, 1, 1)
        else:
            # 2D mask -> (num_samples, num_channels, 1, 1)
            channel_mask = channel_mask.reshape(channel_mask.shape[0], -1, 1, 1)

        # Per-channel baseline values, reshaped for broadcasting.
        background_values = torch.as_tensor(self.mask_value, dtype=torch.float32, device=x.device).reshape(
            1, -1, 1, 1
        )

        # Keep original values where the mask is True, use the baseline elsewhere.
        return torch.where(channel_mask, x, background_values)


@validate_call
def band_importance(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[RegressionModel], InstanceOf[SklearnRegressionModel]],
    num_init_images: Optional[PositiveInt] = 100,
    num_images: Optional[PositiveInt] = 500,
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
) -> None:
    """
    Explain the band importance for a pre-trained model using SHAP.

    The first ``num_init_images`` images of the dataset initialize the SHAP explainer, and the
    following ``num_images`` images (disjoint from the first ones) are explained. The mean
    absolute SHAP value of each band is printed.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        Dataset can contain 3 elements:
            `path`: a path to a dataset.
            `sub`: subdataset name, list of subdataset names or 'all'.
            If not defined, prediction for the whole dataset will be performed.
            `y`: if there is more than one target variable in dataset, then the name of the variable
            that should be used for original data reconstruction should be defined.
        The caller's dict is not modified.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function
        or file (*.ckpt or *.joblib) where model is stored.
        A neural network is temporarily switched to evaluation mode during the analysis
        and its previous mode is restored afterwards.
    num_init_images : int (default = 100)
        Number of images that will be used to initialize the SHAP explainer.
        We strongly recommend using very small `num_init_images` with sklearn models.
    num_images : int (default = 500)
        Number of images that will be used to explain band importance.
        We strongly recommend using very small `num_images` with sklearn models.
        Reduced (with a warning) if the dataset does not contain enough images.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers : int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can increase training speed, but can also cause errors (e.g. pickling errors).

    Raises
    ------
    ValueError
        If the model file extension is not supported, if the dataset does not contain more than
        `num_init_images` images, or if the model name is not supported.

    Examples
    --------
    >>> import remote_sensing_processor as rsp
    >>> x, y, out_file = ...
    >>> ds = rsp.regression.generate_tiles(
    ...     x,
    ...     y,
    ...     out_file,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> model = rsp.regression.train(
    ...     {"path": ds, "sub": "train"},
    ...     {"path": ds, "sub": "val"},
    ...     model="UperNet",
    ...     backbone="ConvNeXTV2",
    ...     model_file="/home/rsp_test/model/upernet.ckpt",
    ...     batch_size=32,
    ... )
    >>> rsp.regression.band_importance({"path": ds, "y": "nitrogen"}, model)
    PartitionExplainer explainer: 100%|██████████████████████████████████████████▉| 499/500 [42:07<00:05, 5.07s/it]
    Landsat-B1: 0.0162
    Landsat-B2: 0.0493
    Landsat-B3: 0.0875
    Landsat-B4: 0.0243
    Landsat-B5: 0.0319
    Landsat-B7: 0.0194
    NDVI: 0.0353
    NBR: 0.0281
    slope: 0.0134
    curvature: 0.0239
    aspect: 0.0311
    dem-norm: 0.0236

    >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
    >>> model = "/home/rsp_test/model/xgboost.joblib"
    >>> rsp.regression.band_importance(ds, model, num_init_images=1, num_images=1)
    PartitionExplainer explainer: 16385it [1:12:01, 3.78it/s]
    coastal: 0.0266
    blue: 0.0266
    green: 0.0253
    red: 0.0542
    rededge071: 0.0194
    rededge075: 0.0194
    rededge078: 0.0196
    nir: 0.0034
    nir08: 0.0111
    nir09: 0.0758
    swir16: 0.2894
    swir22: 0.0309
    NDVI: 0.0483
    canopyheight_norm: 0.0005
    dem_norm: 0.0741
    """
    # Work on a copy so the caller's dataset description is not mutated.
    dataset = dict(dataset)
    dataset["predict"] = True
    if "sub" not in dataset:
        dataset["sub"] = "all"

    # Setting datamodule
    dm = RegressionDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)

    # Loading model
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = RegressionModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")

    # Reading variables names
    variables = dm.variables
    dm.setup(stage="predict")

    n_samples = len(dm.ds_pred)
    # At least one image must remain to be explained after the background images.
    if num_init_images >= n_samples:
        raise ValueError("Not enough samples in dataset")
    if num_init_images + num_images > n_samples:
        num_images = n_samples - num_init_images
        warnings.warn(f"Not enough samples in dataset, num_images set to {num_images}", stacklevel=2)

    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=1)
        # Read background and explained images in a single pass so the two sets are disjoint
        # (the same split the sklearn branch uses).
        images = _collect_images(dm.predict_dataloader(), num_init_images + num_images)
        background_data = images[:num_init_images]
        images_to_explain = images[num_init_images:]

        # Initialize masker with the background data
        masker = TorchChannelMasker(background_data)

        # Evaluate the network in eval mode (deterministic dropout / batch-norm behaviour, if the
        # model uses them) and restore the caller's mode afterwards.
        was_training = model.training
        model.eval()
        try:
            explainer = shap.PartitionExplainer(lambda x: torch_model_wrapper(x, model), masker)
            values = explainer(images_to_explain).values
        finally:
            model.train(was_training)
        _print_band_importance(values, variables)

    # Sklearn models
    elif model.model_name in sklearn_models:
        x_pred, _, _ = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Each image contributes input_shape**2 rows (one per pixel).
        pixels_per_image = dm.input_shape**2
        split = num_init_images * pixels_per_image

        # Background data: pixels of the first num_init_images images
        background_data = x_pred[:split]
        explainer = shap.PartitionExplainer(
            lambda x: sklearn_model_wrapper(x, model),
            background_data,
        )

        # Run SHAP explainer on the pixels of the following num_images images
        images_to_explain = x_pred[split : (num_init_images + num_images) * pixels_per_image]
        values = explainer(images_to_explain).values
        _print_band_importance(values, variables)

    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")


def _collect_images(data_loader: Any, count: int) -> torch.Tensor:
    """Concatenate the ``"x"`` tensors of consecutive batches and return the first ``count`` images."""
    batches = []
    collected = 0
    for batch in data_loader:
        batches.append(batch["x"])
        collected += batch["x"].shape[0]
        if collected >= count:
            break
    return torch.cat(batches, dim=0)[:count]


def _print_band_importance(values: np.ndarray, variables: Any) -> None:
    """Print the mean absolute SHAP value of each band, one ``name: value`` line per variable."""
    global_channel_importance = np.mean(np.abs(values), axis=0).squeeze()
    for i, var in enumerate(variables):
        print(f"{var}: {global_channel_importance[i]:.4f}")
def torch_model_wrapper(x: Union[np.ndarray, torch.Tensor], model: RegressionModel) -> np.ndarray:
    """
    Run a regression network on a batch of images and return its spatially averaged output.

    Used as the prediction function passed to ``shap.PartitionExplainer``.

    Parameters
    ----------
    x : np.ndarray or torch.Tensor
        Batch of (possibly masked) images; the first axis is the batch axis.
    model : RegressionModel
        Network whose ``forward`` receives a ``{"key", "x"}`` batch dict and returns a 4-tuple;
        the second element is used as the prediction.

    Returns
    -------
    np.ndarray
        Sigmoid of the network output averaged over axes 2 and 3 — assumes the output is
        (batch, n_outputs, H, W), giving a (batch, n_outputs) array.
    """
    # as_tensor wraps NumPy input and keeps tensors as they are; then move to the model's device.
    x = torch.as_tensor(x).to(model.device)
    batch = {"key": [0] * x.shape[0], "x": x}
    with torch.no_grad():
        _, output, _, _ = model.forward(batch)
        # NOTE(review): a sigmoid is applied to a regression output here; presumably this mirrors how
        # RegressionModel post-processes its predictions — confirm against the model code.
        output = torch.sigmoid(output)
        # Average over the spatial axes -> one value per sample and output channel.
        mean_output = torch.mean(output, dim=(2, 3))
    return mean_output.detach().cpu().numpy()


def sklearn_model_wrapper(x: np.ndarray, model: SklearnRegressionModel) -> np.ndarray:
    """
    Predict target values with a scikit-learn-style regression model for SHAP.

    Parameters
    ----------
    x : np.ndarray
        Feature rows; a single 1D row is promoted to a one-row 2D array.
    model : SklearnRegressionModel
        Model exposing ``predict``.

    Returns
    -------
    np.ndarray
        The model predictions cast to float32.
    """
    if x.ndim == 1:
        # A single feature row -> (1, n_features) batch.
        x = x[np.newaxis, :]
    return model.predict(x).astype("float32")