"""Generating tiles for semantic segmentation."""
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, TypeAdapter, validate_call
from typing import Annotated, Any, Dict, Generator, Optional, Union
import gc
import warnings
import xbatcher
import dask
import numpy as np
import xarray as xr
import datasets
from remote_sensing_processor.common.common_functions import persist, write_json
from remote_sensing_processor.common.common_raster import assert_equal_shapes
from remote_sensing_processor.common.types import (
DirectoryPath,
DType,
FilePath,
ListOfDict,
ListOfPath,
ListOfPystacItem,
NewRSPDS,
PystacItem,
SingleOrList,
)
from remote_sensing_processor.segmentation.tiles import (
border_pad,
check_classes,
check_dtype,
clean_cache,
create_folders,
filter_nodata_raster,
filter_samples,
get_cache,
pad,
prepare_images,
prepare_raster_sm,
prepare_vector_sm,
split_samples,
write_reference,
)
class Y(BaseModel):
    """
    One target variable as entered by the user.

    Attributes
    ----------
    name : str
        Identifier of the target; `generate_tiles` stores it in the dataset as ``y_<name>``.
    path : file path, directory path or STAC item
        Where the target data is read from.
    burn_value : str, optional
        Attribute to burn in. When it is set, ``path`` is handled as a vector source;
        when it is ``None``, ``path`` is handled as a raster.
    """

    # Target identifier.
    name: str
    # Source of the target data.
    path: Union[FilePath, DirectoryPath, PystacItem]
    # Vector attribute used as burn-in value; None means "raster target".
    burn_value: Union[str, None] = None


# Type alias accepting one ``Y`` or several of them (see ``SingleOrList``).
ListOfY = SingleOrList[Y]
@validate_call
def generate_tiles(
    x: Union[ListOfPath, ListOfPystacItem],
    y: Union[ListOfDict, None],
    output: NewRSPDS,
    tile_size: Optional[Annotated[int, Field(strict=True, ge=8)]] = 128,
    shuffle: Optional[bool] = False,
    split: Optional[Dict[str, Union[PositiveInt, PositiveFloat]]] = None,
    filter_nodata: Optional[str] = "x",
    x_dtype: Optional[DType] = None,
    y_dtype: Optional[DType] = None,
    x_nodata: Optional[Union[int, float]] = None,
    y_nodata: Optional[int] = None,
) -> NewRSPDS:
    """
    Cut rasters into tiles and save them as a dataset.

    Parameters
    ----------
    x : list of paths or STAC items
        Rasters to use as training data.
    y : dict or list of dicts (optional)
        Target variable or multiple target variables.
        Set it to None if a target variable is not needed in the dataset.
        Each dict should contain:
            `name`: a name of the target variable that will be used further to call it.
            `path`: raster or vector file to use as the target variable.
            `burn_value` (optional): a field to use for a burn-in value. Field should be numeric.
        If there is a `burn_value` key in a dict, the target variable is treated as a vector file;
        if there is only a `path` key, it is treated as a raster file.
        Class values must be non-negative; the classes of a target are taken to be
        0, 1, ..., <maximum value found in the target>.
        We strongly recommend changing class values to 0, 1, 2, ..., n (where 0 is nodata) before generating tiles.
    output : path as a string
        Path of the output dataset (.rspds folder).
        Every split is saved with `datasets.Dataset.save_to_disk` into a subfolder named after the split,
        and the dataset metadata is written to `meta.json` inside `output`.
    tile_size : int (default = 128)
        Size of tiles to generate (tile_size x tile_size). Must be at least 8.
    shuffle : bool (default = False)
        Whether the samples should be randomly shuffled.
    split : dict (optional)
        Splitting data in subsets.
        Is a dict, where keys are the names of split subsets and
        values are numbers defining proportions of every subset.
        For example, `{"train": 3, "validation": 1, "test": 1}` will generate
        3 subsets (train, validation, and test) in proportion 3 to 1 to 1.
        If not defined, all samples go to a single `"train"` subset.
    filter_nodata : str (default = "x")
        How the nodata values should be treated. Only applied when `y` is provided.
        `None`: do not filter nodata.
        `"x"`: filter out pixels that are nodata in x.
        `"y"`: filter out pixels that are nodata in y.
        `"x_or_y"`: filter out pixels that are nodata in x or y.
        `"x_and_y"`: filter out pixels that are nodata in x and y.
    x_dtype : dtype definition as a string (optional)
        If you run out of memory, you can try to convert your data to a less memory consuming format.
    y_dtype : dtype definition as a string (optional)
        If you run out of memory, you can try to convert your data to a less memory consuming format.
    x_nodata : int or float (optional)
        You can define which value in x raster corresponds to nodata
        and areas that contain nodata in x raster will be ignored while training and testing.
        Tiles that contain only nodata in both x and y will be omitted.
        If not defined, then the most common nodata value amongst x files will be used.
        If there are no nodata values, will be set to 0.
    y_nodata : int (optional)
        You can define which value will be used to fill nodata.
        If there are polygons with the same value as `y_nodata`, they will be ignored while training and testing.
        Tiles that contain only nodata in both x and y will be omitted.
        If not defined, then it will be set to 0.

    Returns
    -------
    pathlib.Path
        Path to the output dataset.

    Raises
    ------
    ValueError
        If a target variable contains negative class values.

    Examples
    --------
    >>> import remote_sensing_processor as rsp
    >>> x = ["/home/rsp_test/mosaics/sentinel/sentinel.json", "/home/rsp_test/mosaics/dem/dem.json"]
    >>> y = [
    ...     {"name": "landcover", "path": "/home/rsp_test/mosaics/landcover.tif"},
    ...     {"name": "forest_types", "path": "/home/rsp_test/mosaics/forest_types.gpkg", "burn_value": "class"},
    ... ]
    >>> out_file = "/home/rsp_test/model/landcover_dataset.rspds"
    >>> out_dataset = rsp.semantic.generate_tiles(
    ...     x,
    ...     y,
    ...     out_file,
    ...     tile_size=256,
    ...     shuffle=True,
    ...     split={"train": 3, "val": 1, "test": 1},
    ... )
    >>> print(out_dataset)
    PosixPath('/home/rsp_test/model/landcover_dataset.rspds')
    """
    if y is not None:
        y = TypeAdapter(ListOfY).validate_python(y)
    if split is None:
        split = {"train": 1}
    # Fingerprint of the normalised call arguments, used to identify the generated datasets.
    # Listed explicitly (in parameter order) rather than via ``locals()`` so that introducing
    # a new local variable above this line cannot silently change the fingerprint.
    unique_id = datasets.fingerprint.Hasher.hash(
        {
            "x": x,
            "y": y,
            "output": output,
            "tile_size": tile_size,
            "shuffle": shuffle,
            "split": split,
            "filter_nodata": filter_nodata,
            "x_dtype": x_dtype,
            "y_dtype": y_dtype,
            "x_nodata": x_nodata,
            "y_nodata": y_nodata,
        },
    )
    if y_nodata is not None and y_nodata != 0:
        warnings.warn(
            "Recommended class values format is 0, 1, 2, ..., n (where 0 is nodata), but y_nodata is " + str(y_nodata),
            stacklevel=1,
        )
    create_folders(output, split)
    data: dict[str, Any] = {
        "task": "semantic",
    }

    # Initially load and preprocess data
    x_img, x_nodata = prepare_images(img=x, nodata=x_nodata, dtype=x_dtype)
    # Write a reference file
    write_reference(x_img, output, x_nodata)
    if y is not None:
        y_img, y_nodata = prepare_seg_maps(y=y, ref=x_img[0], dtype=y_dtype, y_nodata=y_nodata)
        assert_equal_shapes([x_img, y_img])
    else:
        y_img, y_nodata = None, None

    border, padding = border_pad(x_img, tile_size)
    data["tile_size"] = tile_size
    data["border"] = border
    data["pad"] = padding

    # Padding
    x_img = pad(x_img, padding, x_nodata)
    if y_img is not None:
        y_img = pad(y_img, padding, y_nodata)
    # Filtering needs both rasters, so it only happens when a target is present.
    if y_img is not None and filter_nodata is not None:
        x_img, y_img = filter_nodata_raster(x_img, y_img, filter_nodata, x_nodata, y_nodata)

    data["x"] = {
        "dtype": x_img.dtype,
        "nodata": x_nodata,
        "bands": x_img.shape[0],
        # Band names without their first "_"-separated prefix.
        "variables": ["_".join(band_name.split("_")[1:]) for band_name in x_img["band"].values.tolist()],
    }

    if y_img is not None:
        data["y"] = {}
        # y_img is stacked along its first ("band") dimension, one entry per target.
        for target, band in zip(y, y_img):
            unique_values = np.unique(band)  # sorted ascending
            # Validate the observed values: after the expansion to ``range`` below the
            # minimum is always 0, so a check there could never detect negative classes.
            if unique_values[0] < 0:
                raise ValueError("Class values must be >= 0")
            # Classes are the contiguous range 0..max, so values missing from the
            # raster still get a class slot.
            classes = list(range(int(unique_values[-1]) + 1))
            data["y"][target.name] = {
                "dtype": band.dtype,
                "nodata": y_nodata,
                "classes": classes,
                "num_classes": len(classes),
            }

    # Setting up tiles generators. Tiles are cut without the border, which is
    # re-added later by symmetric padding.
    window = tile_size - (border * 2)
    x_batches = xbatcher.BatchGenerator(
        ds=x_img,
        input_dims={"x": window, "y": window},
    )
    if y_img is not None:
        y_batches = xbatcher.BatchGenerator(
            ds=y_img,
            input_dims={"x": window, "y": window},
        )
    else:
        y_batches = None

    # Getting samples
    samples = list(range(len(x_batches)))
    samples_x = filter_samples(x_batches, samples, x_nodata)
    if y_batches is not None:
        samples_y = filter_samples(y_batches, samples, y_nodata)
        # Keep a tile if either x or y keeps it. Sorted so that the order of
        # unshuffled samples is deterministic rather than set-iteration order.
        samples = sorted(set(samples_x) | set(samples_y))
    else:
        samples = samples_x
    # Shuffling samples
    if shuffle:
        np.random.shuffle(samples)
    # Splitting samples
    samples = split_samples(samples, split)
    data["samples"] = samples
    write_json(data, output / "meta.json")

    # Calculating optimal batch size: target ~256MB per writer batch,
    # 4 bytes per value (float32 x / int32 y).
    target_size = 256 * 1024 * 1024
    x_channels = x_img.shape[0]
    y_channels = len(y_img) if y_img is not None else 0
    bytes_per_sample = (x_channels + y_channels) * (tile_size**2) * 4
    writer_batch_size = max(1, int(target_size / bytes_per_sample))

    # Features are identical for every split, so they are built once.
    feat = {
        "key": datasets.Value(dtype="int64"),
        "x": datasets.Array3D(dtype="float32", shape=(data["x"]["bands"], tile_size, tile_size)),
    }
    if y_img is not None:
        for target in y:
            # noinspection PyTypeChecker
            feat["y_" + target.name] = datasets.Array2D(dtype="int32", shape=(tile_size, tile_size))
    feat = datasets.Features(feat)

    pad_width = ((0, 0), (border, border), (border, border))

    def dataset_generator(samples: dict[str, list[int]], name: str) -> Generator[dict, None, None]:
        """Yield one record (key, padded x tile, padded y tiles) per sample of split ``name``."""
        for index in samples[name]:
            record = {
                "key": index,
                "x": np.pad(x_batches[index].data, pad_width, "symmetric"),
            }
            if y_batches is not None:
                y_tile = np.pad(y_batches[index].data, pad_width, "symmetric")
                for j, target in enumerate(y):
                    record["y_" + target.name] = y_tile[j]
            yield record

    for name in split:
        # Create dataset
        ds = datasets.Dataset.from_generator(
            dataset_generator,
            features=feat,
            cache_dir=(output / ".cache").as_posix(),
            # A distinct fingerprint per split: sharing one fingerprint between all
            # splits risks one split being served from another split's cache entry.
            fingerprint=datasets.fingerprint.Hasher.hash((unique_id, name)),
            gen_kwargs={"samples": samples, "name": name},
            writer_batch_size=writer_batch_size,
        )
        # Save dataset
        ds.save_to_disk(output / name)
        # Cleaning the cache; the dataset object is released first so its files can be removed.
        cache = get_cache(ds)
        del ds
        gc.collect()
        clean_cache(cache)
    return output
def prepare_seg_maps(
    y: ListOfY,
    ref: xr.DataArray,
    dtype: Optional[type] = None,
    dtype_class: Optional[type] = np.integer,
    y_nodata: Optional[int] = 0,
) -> tuple[xr.DataArray, int]:
    """
    Prepare segmentation maps: rasterize vectors and match rasters.

    Targets with a ``burn_value`` go through ``prepare_vector_sm``, the others through
    ``prepare_raster_sm``; both receive ``ref`` and the nodata value. The per-target
    results are computed together with ``dask.compute``, merged into one dataset,
    converted with ``check_dtype`` and validated with ``check_classes``.

    Parameters
    ----------
    y : list of Y
        Target variables.
    ref : xarray.DataArray
        Reference raster passed to the per-target preparation functions.
    dtype : type (optional)
        Requested dtype, forwarded to ``check_dtype``.
    dtype_class : type (default = np.integer)
        Dtype class forwarded to ``check_dtype``.
    y_nodata : int (optional)
        Nodata value; ``None`` is treated as 0.

    Returns
    -------
    tuple of (xarray.DataArray, int)
        Persisted array whose merged variables are stacked along a ``band`` dimension,
        and the nodata value that was actually used.
    """
    nodata = 0 if y_nodata is None else y_nodata
    # One lazy task per target; vectors need the burn-in attribute, rasters do not.
    tasks = [
        dask.delayed(prepare_vector_sm)(item.path, ref, item.burn_value, item.name, nodata)
        if item.burn_value is not None
        else dask.delayed(prepare_raster_sm)(item.path, ref, item.name, nodata)
        for item in y
    ]
    merged = xr.merge(list(dask.compute(*tasks)))
    merged = check_dtype(img=merged, dtype=dtype, dtype_class=dtype_class)
    check_classes(merged, nodata)
    stacked = merged.squeeze().to_array("band").chunk("auto")
    return persist(stacked), nodata