# Source code for remote_sensing_processor.segmentation.semantic.mapping

"""Mapping semantic segmentation model predictions."""

from pydantic import InstanceOf, NonNegativeInt, PositiveInt, validate_call
from typing import Any, Literal, Optional, Union

import math
import warnings
from pathlib import Path

import joblib

import numpy as np

import lightning as l

from remote_sensing_processor.common.common_functions import create_path, persist
from remote_sensing_processor.common.common_raster import prepare_nodata, write
from remote_sensing_processor.common.torch_test import cuda_test
from remote_sensing_processor.common.types import FilePath, NewPath
from remote_sensing_processor.segmentation.mapping import load_reference, post_process_raster_dataset
from remote_sensing_processor.segmentation.segmentation import sklearn_load_dataset
from remote_sensing_processor.segmentation.semantic.models import pytorch_models, sklearn_models
from remote_sensing_processor.segmentation.semantic.segmentation import (
    SemanticSegmentationDataModule,
    SemanticSegmentationModel,
    SklearnSemanticSegmentationModel,
)


@validate_call
def generate_map(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[SemanticSegmentationModel], InstanceOf[SklearnSemanticSegmentationModel]],
    output: NewPath,
    reference_dataset: Optional[dict[str, Any]] = None,
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
    write_stac: Optional[bool] = True,
) -> NewPath:
    """
    Create a map using pre-trained model.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        Dataset can contain 3 elements:
        `path` (path as str): a path to a dataset. Required parameter.
        `sub` (str): subdataset name, list of subdataset names or 'all'. Optional parameter.
        If not defined, prediction for the whole dataset will be performed.
        `y` (str): if there is more than one target variable in dataset,
        then the name of the variable that should be used for original data reconstruction should be defined.
        Optional parameter.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values. You can pass the model object returned by `train()` function
        or file (*.ckpt or *.joblib) where model is stored.
        Model files are unpickled while loading, so only load files from a trusted source.
    output : path as a string
        Path where to write an output map.
    reference_dataset : dict (optional)
        Dataset generated by generate_tiles() function that will be used to reconstruct original class values
        and nodata if prediction dataset has no target variable ('y').
        Dataset can contain 2 elements:
        `path`: a path to a dataset.
        `y`: if there is more than one target variable in dataset,
        then the name of the variable that should be used for reconstruction should be defined.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can increase prediction speed, but can also cause errors (e.g. pickling errors).
    write_stac : bool (default = True)
        If True, then output metadata is saved to a STAC file.

    Returns
    -------
    pathlib.Path
        Path to the STAC JSON metadata file if `write_stac` is True,
        otherwise the path where output raster is saved.

    Raises
    ------
    ValueError
        If the model file extension is neither .ckpt nor .joblib, if classes or nodata information
        is absent in the dataset, or if the model name is not a supported one.

    Examples
    --------
    >>> import remote_sensing_processor as rsp
    >>> x, y, out_file = ...
    >>> ds = rsp.semantic.generate_tiles(
    ...     x,
    ...     y,
    ...     out_file,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> model = rsp.semantic.train(
    ...     {"path": ds, "sub": "train"},
    ...     {"path": ds, "sub": "val"},
    ...     model="UperNet",
    ...     backbone="ConvNeXTV2",
    ...     model_file="/home/rsp_test/model/upernet.ckpt",
    ...     batch_size=32,
    ... )
    >>> output_map = "/home/rsp_test/prediction.tif"
    >>> rsp.semantic.generate_map({"path": ds, "y": "landcover"}, model, output_map)
    Predicting: 100% #################### 372/372 [32:16, 1.6s/it]

    >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
    >>> model = "/home/rsp_test/model/upernet.ckpt"
    >>> output_map = "/home/rsp_test/prediction.tif"
    >>> rsp.semantic.generate_map(ds, model, output_map)
    Predicting: 100% #################### 372/372 [32:16, 1.6s/it]

    >>> # Train model on data from Montana
    >>> x_montana_files = "/home/rsp_test/mosaics/landsat_montana/landsat.json"
    >>> y_montana_files = {"name": "landcover", "path": "/home/rsp_test/mosaics/landcover_montana/landcover.tif"}
    >>> ds_montana = rsp.semantic.generate_tiles(
    ...     x_montana_files,
    ...     y_montana_files,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> train_ds = {"path": ds_montana, "sub": "train"}
    >>> val_ds = {"path": ds_montana, "sub": "val"}
    >>> model_montana = rsp.semantic.train(
    ...     train_ds,
    ...     val_ds,
    ...     model="UperNet",
    ...     backbone="ConvNeXTV2",
    ...     model_file="/home/rsp_test/model/upernet.ckpt",
    ...     epochs={"max_epochs": 10, "early_stopping": False},
    ...     batch_size=32,
    ... )
    >>> # Use model to map landcover of Idaho
    >>> x_idaho_files = "/home/rsp_test/mosaics/landsat_idaho/landsat.json"
    >>> ds_idaho = rsp.semantic.generate_tiles(x_idaho_files, None, tile_size=256)
    >>> output_map = "/home/rsp_test/prediction_idaho.tif"
    >>> pred_ds = {"path": ds_idaho}
    >>> ref_ds = {"path": ds_montana, "y": "landcover"}
    >>> rsp.semantic.generate_map(pred_ds, model_montana, output_map, reference_dataset=ref_ds)
    Predicting: 100% #################### 372/372 [32:16, 1.6s/it]
    """
    dataset["predict"] = True
    dataset.setdefault("sub", "all")
    if reference_dataset is not None:
        reference_dataset.setdefault("sub", "all")

    # Setting datamodule
    dm = SemanticSegmentationDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)

    # Loading reference raster
    reference, stac = load_reference(dm)

    # Loading model
    # NOTE: both loaders below unpickle the file content (checkpoint is loaded with weights_only=False),
    # which can execute arbitrary code. Only load model files that come from a trusted source.
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = SemanticSegmentationModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")

    # Reading classes and nodata
    if reference_dataset is not None:
        rdm = SemanticSegmentationDataModule(
            pred_dataset=reference_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
        )
    else:
        # No separate reference: re-open the prediction dataset with the "predict" flag switched off
        dataset["predict"] = False
        rdm = SemanticSegmentationDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)
    classes = rdm.classes
    nodata = rdm.y_nodata
    dtype = rdm.y_dtype
    if classes is None or nodata is None:
        raise ValueError("Classes or nodata information is absent in the input dataset.")

    # Setting reference nodata
    reference = reference.astype(dtype)
    reference = reference + nodata
    # noinspection PyTypeChecker
    reference, nodata = prepare_nodata(reference, nodata, 0)
    reference = persist(reference)

    if model.model_name in pytorch_models:
        # Neural networks
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=1)
        # Predict
        trainer = l.Trainer(precision=model.precision, enable_checkpointing=False)
        predictions, keys = zip(*trainer.predict(model, dm), strict=True)
        # Joining the batches and inserting a new axis at position 1
        predictions = np.concatenate(predictions, axis=0)[:, np.newaxis, :, :]
        keys = np.concatenate(keys, axis=0).tolist()
    elif model.model_name in sklearn_models:
        # Sklearn models
        dm.setup(stage="predict")
        x_pred, _, keys = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Predict everything
        predictions = model.predict(x_pred)
        predictions = predictions.reshape(len(dm.ds_pred), 1, dm.input_shape, dm.input_shape)
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")

    # Cutting the border off both spatial axes of every predicted tile
    inner = slice(dm.border, dm.input_shape - dm.border)
    predictions = predictions[:, :, inner, inner]
    predictions = predictions.astype(dtype)

    # Mapping
    # Tile size without borders and the number of tiles along axis 1 of the reference do not change
    # between iterations, so they are computed once.
    tile_size = dm.input_shape - (dm.border * 2)
    tiles_along_axis1 = math.ceil(reference.shape[1] / tile_size)
    for idx, prediction in enumerate(predictions):
        key = keys[idx]
        # Getting coordinates of a tile in a resulting array
        a1_start = tile_size * (key % tiles_along_axis1)
        a2_start = tile_size * (key // tiles_along_axis1)
        a1_stop = min(reference.shape[1], a1_start + tile_size)
        a2_stop = min(reference.shape[2], a2_start + tile_size)
        # Changing prediction shape if needed (tiles at the array edge can be only partially inside it)
        tile = prediction[:, : (a1_stop - a1_start), : (a2_stop - a2_start)]
        # Writing predicted tile to its position in array, keeping nodata cells of the reference untouched
        window = reference.data[:, a1_start:a1_stop, a2_start:a2_stop]
        reference.data[:, a1_start:a1_stop, a2_start:a2_stop] = np.where(window == nodata, nodata, tile)

    # Recreating original classes values
    # NOTE(review): the call below is disabled in the original code - confirm whether it should be re-enabled.
    # reference, nodata = restore_classes(reference, classes, nodata)
    reference = persist(reference)

    # Creating an output folder
    create_path(output)

    # Post-processing STAC dataset
    stac, json_path = post_process_raster_dataset(stac, reference, output)

    # Writing to file
    write(reference, output)

    if write_stac:
        # Writing JSON metadata file
        stac.save_object(dest_href=json_path.as_posix())
        return json_path
    return output