"""Mapping regression model predictions."""
from pydantic import InstanceOf, NonNegativeInt, PositiveInt, validate_call
from typing import Any, Literal, Optional, Union
import math
import warnings
from pathlib import Path
import joblib
import numpy as np
import lightning as l
from remote_sensing_processor.common.common_functions import create_path, persist
from remote_sensing_processor.common.common_raster import prepare_nodata, write
from remote_sensing_processor.common.torch_test import cuda_test
from remote_sensing_processor.common.types import FilePath, NewPath
from remote_sensing_processor.segmentation.mapping import load_reference, post_process_raster_dataset
from remote_sensing_processor.segmentation.regression.models import pytorch_models, sklearn_models
from remote_sensing_processor.segmentation.regression.segmentation import (
RegressionDataModule,
RegressionModel,
SklearnRegressionModel,
)
from remote_sensing_processor.segmentation.segmentation import sklearn_load_dataset
[docs]
@validate_call
def generate_map(
    dataset: dict[str, Any],
    model: Union[FilePath, InstanceOf[RegressionModel], InstanceOf[SklearnRegressionModel]],
    output: NewPath,
    reference_dataset: Optional[dict[str, Any]] = None,
    batch_size: Optional[PositiveInt] = 32,
    num_workers: Optional[Union[NonNegativeInt, Literal["auto"]]] = 0,
    write_stac: Optional[bool] = True,
) -> NewPath:
    """
    Create a map using pre-trained model.

    Parameters
    ----------
    dataset : dict
        Dataset generated by generate_tiles() function that will be used for prediction.
        Dataset can contain 3 elements:
        `path` (path as str): a path to a dataset. Required parameter.
        `sub` (str): subdataset name, list of subdataset names or 'all'. Optional parameter.
        If not defined, prediction for the whole dataset will be performed.
        `y` (str): if there is more than one target variable in dataset,
        then the name of the variable that should be used for original data reconstruction should be defined.
        Optional parameter.
    model : torch.nn model or SklearnModel or path to a model file
        Pre-trained model to predict target values.
        You can pass the model object returned by `train()` function or file (*.ckpt or *.joblib) where model is stored.
    output : path as a string
        Path where to write an output map.
    reference_dataset : dict (optional)
        Dataset generated by generate_tiles() function that is used to read the nodata value and the data type
        of the target variable if prediction dataset has no target variable ('y').
        Dataset can contain 2 elements:
        `path`: a path to a dataset.
        `y`: if there is more than one target variable in dataset,
        then the name of the variable that should be used for reconstruction should be defined.
        If `sub` is not defined, 'all' is used.
    batch_size : int (default = 32)
        Number of samples used in one iteration. Only works for neural networks.
    num_workers: int or 'auto' (default = 0)
        Number of parallel workers that will load the data.
        Set 'auto' to let RSP choose the optimal number of workers, set 0 to disable multiprocessing.
        It can increase prediction speed, but can also cause errors (e.g., pickling errors).
    write_stac : bool (default = True)
        If True, then output metadata is saved to a STAC file.

    Returns
    -------
    pathlib.Path
        Path to the saved STAC JSON metadata file if `write_stac` is True,
        otherwise the path where output raster is saved.

    Raises
    ------
    ValueError
        If the model file extension is neither .ckpt nor .joblib,
        if the dataset that is used for reconstruction has no nodata value for the target variable,
        or if the model name is not one of the supported PyTorch or scikit-learn models.

    Examples
    --------
        >>> import remote_sensing_processor as rsp
        >>> x, y, out_file = ...
        >>> ds = rsp.regression.generate_tiles(
        ...     x,
        ...     y,
        ...     out_file,
        ...     tile_size=256,
        ...     shuffle=True,
        ...     split={"train": 3, "val": 1, "test": 1},
        ... )
        >>> model = rsp.regression.train(
        ...     {"path": ds, "sub": "train"},
        ...     {"path": ds, "sub": "val"},
        ...     model="UperNet",
        ...     backbone="ConvNeXTV2",
        ...     model_file="/home/rsp_test/model/upernet.ckpt",
        ...     batch_size=32,
        ... )
        >>> output_map = "/home/rsp_test/prediction.tif"
        >>> rsp.regression.generate_map({"path": ds, "y": "nitrogen"}, model, output_map)
        Predicting: 100% #################### 372/372 [32:16, 1.6s/it]

        >>> ds = {"path": "/home/rsp_test/model/ds.rspds"}
        >>> model = "/home/rsp_test/model/upernet.ckpt"
        >>> output_map = "/home/rsp_test/prediction.tif"
        >>> rsp.regression.generate_map(ds, model, output_map)
        Predicting: 100% #################### 372/372 [32:16, 1.6s/it]

        >>> # Train model on data from Montana
        >>> x_montana_files = "/home/rsp_test/mosaics/landsat_montana/landsat.json"
        >>> y_montana_files = {"name": "nitrogen", "path": "/home/rsp_test/mosaics/chem_montana/nitrogen.tif"}
        >>> ds_montana = rsp.regression.generate_tiles(
        ...     x_montana_files, y_montana_files, tile_size=256, shuffle=True, split={"train": 3, "val": 1, "test": 1}
        ... )
        >>> train_ds = {"path": ds_montana, "sub": "train"}
        >>> val_ds = {"path": ds_montana, "sub": "val"}
        >>> model_montana = rsp.regression.train(
        ...     train_ds,
        ...     val_ds,
        ...     model="UperNet",
        ...     backbone="ConvNeXTV2",
        ...     model_file="/home/rsp_test/model/upernet.ckpt",
        ...     epochs={"max_epochs": 10, "early_stopping": False},
        ...     batch_size=32,
        ... )
        >>> # Use model to map crop nitrogen content in Idaho
        >>> x_idaho_files = "/home/rsp_test/mosaics/landsat_idaho/landsat.json"
        >>> ds_idaho = rsp.regression.generate_tiles(x_idaho_files, None, tile_size=256)
        >>> output_map = "/home/rsp_test/prediction_idaho.tif"
        >>> pred_ds = {"path": ds_idaho}
        >>> ref_ds = {"path": ds_montana, "y": "nitrogen"}
        >>> rsp.regression.generate_map(pred_ds, model_montana, output_map, reference_dataset=ref_ds)
        Predicting: 100% #################### 372/372 [32:16, 1.6s/it]
    """
    # Work on shallow copies so that the dictionaries passed by the caller are never modified
    dataset = {**dataset, "predict": True}
    dataset.setdefault("sub", "all")
    if reference_dataset is not None:
        reference_dataset = dict(reference_dataset)
        reference_dataset.setdefault("sub", "all")

    # Setting datamodule
    dm = RegressionDataModule(pred_dataset=dataset, batch_size=batch_size, num_workers=num_workers)

    # Loading reference raster
    reference, stac = load_reference(dm)

    # Loading model
    # NOTE: both loaders unpickle arbitrary Python objects, so only model files from trusted sources should be used
    if isinstance(model, Path):
        if ".ckpt" in model.suffixes:
            model = RegressionModel.load_from_checkpoint(model, weights_only=False)
        elif ".joblib" in model.suffixes:
            model = joblib.load(model)
        else:
            raise ValueError("Wrong model extension. Should be .ckpt or .joblib")

    # Reading nodata
    if reference_dataset is not None:
        nodata_dataset = reference_dataset
    else:
        # The prediction dataset itself is opened in non-predict mode to read the target variable metadata.
        # A separate dict is used so that the dict already handed to `dm` keeps its `predict` flag.
        nodata_dataset = {**dataset, "predict": False}
    rdm = RegressionDataModule(pred_dataset=nodata_dataset, batch_size=batch_size, num_workers=num_workers)
    nodata = rdm.y_nodata
    dtype = rdm.y_dtype
    if nodata is None:
        raise ValueError("Nodata information is absent in the input dataset.")

    # Setting reference nodata
    # NOTE(review): adding nodata and then calling prepare_nodata(..., 0) appears to convert the reference
    # raster to the nodata value of the target variable - confirm against prepare_nodata
    reference = reference.astype(dtype)
    reference = reference + nodata
    reference, nodata = prepare_nodata(reference, nodata, 0)
    reference = persist(reference)

    # Neural networks
    if model.model_name in pytorch_models:
        if not cuda_test():
            warnings.warn("CUDA or MPS is not available. Predicting on CPU could be very slow.", stacklevel=1)
        # Predict
        trainer = l.Trainer(precision=model.precision, enable_checkpointing=False)
        predictions, keys = zip(*trainer.predict(model, dm), strict=True)
        predictions = np.concatenate(predictions, axis=0)
        keys = np.concatenate(keys, axis=0).tolist()
    # Sklearn models
    elif model.model_name in sklearn_models:
        dm.setup(stage="predict")
        x_pred, _, keys = sklearn_load_dataset(dm, "predict", model.generate_features)
        # Predict everything
        predictions = model.predict(x_pred)
        predictions = predictions.reshape(len(dm.ds_pred), 1, dm.input_shape, dm.input_shape)
    else:
        raise ValueError("Wrong model name. Check spelling or read a documentation and choose a supported model")

    # Cutting the border off every predicted tile
    inner_end = dm.input_shape - dm.border
    predictions = predictions[:, :, dm.border : inner_end, dm.border : inner_end]
    predictions = predictions.astype(dtype)

    # Mapping
    # Tile size without borders and number of tiles along axis 1 do not change between tiles
    tile_size = dm.input_shape - (dm.border * 2)
    tiles_along_axis = math.ceil(reference.shape[1] / tile_size)
    for i, prediction in enumerate(predictions):
        key = keys[i]
        # Getting coordinates of a tile in a resulting array
        pos1 = tile_size * (key % tiles_along_axis)
        pos2 = tile_size * (key // tiles_along_axis)
        pos3 = min(reference.shape[1], pos1 + tile_size)
        pos4 = min(reference.shape[2], pos2 + tile_size)
        # Changing prediction shape if the tile sticks out of the reference raster
        tile = prediction
        if (pos3 - pos1) != tile_size:
            tile = tile[:, : (pos3 - pos1), :]
        if (pos4 - pos2) != tile_size:
            tile = tile[:, :, : (pos4 - pos2)]
        # Writing predicted tile to its position in array, pixels that are nodata in reference stay nodata
        reference.data[:, pos1:pos3, pos2:pos4] = np.where(
            reference.data[:, pos1:pos3, pos2:pos4] == nodata,
            nodata,
            tile,
        )
    reference = persist(reference)

    # Creating an output folder
    create_path(output)

    # Post-processing STAC dataset
    stac, json_path = post_process_raster_dataset(stac, reference, output)

    # Writing to file
    write(reference, output)
    if write_stac:
        # Writing JSON metadata file
        stac.save_object(dest_href=json_path.as_posix())
        return json_path
    return output