"""Generating tiles for regression."""
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt, TypeAdapter, validate_call
from typing import Annotated, Any, Dict, Generator, Optional, Union
import gc
import xbatcher
import dask
import numpy as np
import xarray as xr
import datasets
from remote_sensing_processor.common.common_functions import persist, write_json
from remote_sensing_processor.common.common_raster import assert_equal_shapes
from remote_sensing_processor.common.types import (
DirectoryPath,
DType,
FilePath,
ListOfDict,
ListOfPath,
ListOfPystacItem,
NewRSPDS,
PystacItem,
SingleOrList,
)
from remote_sensing_processor.segmentation.tiles import (
border_pad,
check_dtype,
clean_cache,
create_folders,
filter_nodata_raster,
filter_samples,
get_cache,
pad,
prepare_images,
prepare_raster_sm,
prepare_vector_sm,
split_samples,
write_reference,
)
class Y(BaseModel):
    """
    Y variable class for user input.

    One instance describes a single target variable passed to `generate_tiles`
    through its `y` argument.

    Attributes
    ----------
    name : str
        A name of the target variable that will be used further to call it.
    path : file path, directory path or pystac item
        Raster or vector file to use as the target variable.
    burn_value : str (optional)
        A field to use for a burn-in value.
        When it is set the target is processed as a vector, otherwise as a raster.
    """

    # Identifier of the target; used as a key in the dataset metadata and features.
    name: str
    # Location of the target data.
    path: Union[FilePath, DirectoryPath, PystacItem]
    # Name of the attribute field holding the values to burn in (vector targets only).
    burn_value: Optional[str] = None
# Type used to validate the user-supplied `y` argument of `generate_tiles`.
# NOTE(review): `SingleOrList` is declared in `common.types`; presumably it accepts
# either one `Y` or a list of `Y` and normalises to a list - confirm there.
ListOfY = SingleOrList[Y]
# [docs]
@validate_call
def generate_tiles(
    x: Union[ListOfPath, ListOfPystacItem],
    y: Union[ListOfDict, None],
    output: NewRSPDS,
    tile_size: Optional[Annotated[int, Field(strict=True, ge=8)]] = 128,
    shuffle: Optional[bool] = False,
    split: Optional[Dict[str, Union[PositiveInt, PositiveFloat]]] = None,
    filter_nodata: Optional[str] = "x",
    x_dtype: Optional[DType] = None,
    y_dtype: Optional[DType] = None,
    x_nodata: Optional[Union[int, float]] = None,
    y_nodata: Optional[Union[int, float]] = None,
) -> NewRSPDS:
    """
    Cut rasters into tiles.

    Parameters
    ----------
    x : list of paths as strings
        Rasters to use as training data.
    y : dict or list of dicts
        Target variable or multiple target variables. Can be set to None if target value is not needed.
        Dict or multiple dicts.
        It should contain:
        `name`: a name of a target variable that will be used further to call it.
        `path`: raster or vector file to use as target variable.
        `burn_value` (optional): a field to use for a burn-in value. Field should be numeric.
    output : path as a string
        Path to save the generated dataset.
        Data is saved in a .rspds format (custom dataset format): a `meta.json` file with the dataset
        description plus one sub-directory per split written with `datasets.Dataset.save_to_disk`.
    tile_size : int (default = 128)
        Size of tiles to generate (tile_size x tile_size). Must be at least 8.
    shuffle : bool (default = False)
        Is a random shuffling of samples needed.
        Shuffling uses the global NumPy random state.
    split : dict (optional)
        Splitting data in subsets.
        Is a dict, where keys are the names of split subsets and
        values are numbers defining proportions of every subset.
        For example, `{"train": 3, "validation": 1, "test": 1}` will generate
        3 subsets (train, validation, and test) in proportion 3 to 1 to 1.
        If not defined, a single `"train"` subset is generated.
    filter_nodata : str (default = "x")
        How the nodata values should be treated. Only applied when `y` is provided.
        `None`: do not filter nodata.
        `"x"`: filter out pixels that are nodata in x.
        `"y"`: filter out pixels that are nodata in y.
        `"x_or_y"`: filter out pixels that are nodata in x or y.
        `"x_and_y"`: filter out pixels that are nodata in x and y.
    x_dtype : dtype definition as a string (optional)
        If you run out of memory, you can try to convert your data to less memory consuming format.
    y_dtype : dtype definition as a string (optional)
        If you run out of memory, you can try to convert your data to less memory consuming format.
    x_nodata : int or float (optional)
        You can define which value in x raster corresponds to nodata
        and areas that contain nodata in x raster will be ignored while training and testing.
        Tiles that contain only nodata in both x and y will be omitted.
        If not defined, then the most common nodata value amongst x files will be used.
        If there are no nodata values, will be set to 0.
    y_nodata : int or float (optional)
        You can define which value will be used to fill nodata.
        If there are polygons with the same value as `y_nodata`, they will be ignored while training and testing.
        Tiles that contain only nodata in both x and y will be omitted.
        If not defined, then it will be set to 0.

    Returns
    -------
    pathlib.Path
        Path to the output dataset.

    Examples
    --------
        >>> import remote_sensing_processor as rsp
        >>> x = ["/home/rsp_test/mosaics/sentinel/sentinel.json", "/home/rsp_test/mosaics/dem/dem.tif"]
        >>> y = [
        ...     {"name": "nitrogen", "path": "/home/rsp_test/mosaics/nitrogen.tif"},
        ...     {"name": "phosphorus", "path": "/home/rsp_test/vectors/phosphorus.gpkg", "burn_value": "P"},
        ... ]
        >>> out_file = "/home/rsp_test/model/chem_dataset.rspds"
        >>> out_dataset = rsp.regression.generate_tiles(
        ...     x,
        ...     y,
        ...     out_file,
        ...     tile_size=256,
        ...     shuffle=True,
        ...     split={"train": 3, "val": 1, "test": 1},
        ... )
        >>> print(out_dataset)
        /home/rsp_test/model/chem_dataset.rspds
    """
    if y is not None:
        y = TypeAdapter(ListOfY).validate_python(y)
    if split is None:
        split = {"train": 1}

    # The fingerprint is a hash of the (normalised) call arguments.
    # Keep this line before any other local variable is created, otherwise the hash input changes.
    unique_id = datasets.fingerprint.Hasher.hash(locals())

    create_folders(output, split)

    # Metadata that is written to meta.json next to the splits
    data: dict[str, Any] = {
        "task": "regression",
    }

    # Initially load and preprocess data
    x_img, x_nodata = prepare_images(img=x, nodata=x_nodata, dtype=x_dtype)

    # Write a reference file
    write_reference(x_img, output, x_nodata)

    if y is not None:
        y_img, y_nodata = prepare_seg_maps(y=y, y_nodata=y_nodata, ref=x_img[0], dtype=y_dtype, dtype_class=np.floating)
        assert_equal_shapes([x_img, y_img])
    else:
        y_img, y_nodata = None, None

    border, padding = border_pad(x_img, tile_size)
    data["tile_size"] = tile_size
    data["border"] = border
    data["pad"] = padding

    # Padding
    x_img = pad(x_img, padding, x_nodata)
    if y_img is not None:
        y_img = pad(y_img, padding, y_nodata)

    # Filtering
    if y_img is not None and filter_nodata is not None:
        x_img, y_img = filter_nodata_raster(x_img, y_img, filter_nodata, x_nodata, y_nodata)

    data["x"] = {
        "dtype": x_img.dtype,
        "nodata": x_nodata,
        "bands": x_img.shape[0],
        # Band labels are stored without their first "_"-separated component
        "variables": ["_".join(band.split("_")[1:]) for band in x_img["band"].values.tolist()],
    }
    if y_img is not None:
        data["y"] = {}
        # Targets and bands of y_img are matched by position
        for target, target_band in zip(y, y_img):
            data["y"][target.name] = {"dtype": target_band.dtype, "nodata": y_nodata}

    # Setting up tiles generators
    # Batches are cut without the border; the border is added back when a sample is generated
    inner_size = tile_size - (border * 2)
    x_batches = xbatcher.BatchGenerator(
        ds=x_img,
        input_dims={"x": inner_size, "y": inner_size},
    )
    if y_img is not None:
        y_batches = xbatcher.BatchGenerator(
            ds=y_img,
            input_dims={"x": inner_size, "y": inner_size},
        )
    else:
        y_batches = None

    # Getting samples
    samples = list(range(len(x_batches)))
    samples_x = filter_samples(x_batches, samples, x_nodata)
    if y_img is not None:
        samples_y = filter_samples(y_batches, samples, y_nodata)
        # Union of the tiles kept for x and for y.
        # Sorted so that the order does not depend on set iteration order.
        samples = sorted(set(samples_x + samples_y))
    else:
        samples = samples_x

    # Shuffling samples
    if shuffle:
        np.random.shuffle(samples)

    # Splitting samples
    samples = split_samples(samples, split)
    data["samples"] = samples

    write_json(data, output / "meta.json")

    # Calculating optimal batch size
    # Target size is 256MB
    target_size = 256 * 1024 * 1024
    x_channels = x_img.shape[0]
    y_channels = len(y_img) if y_img is not None else 0
    # 4 bytes per value: features are stored as float32
    bytes_per_sample = (x_channels + y_channels) * (tile_size**2) * 4
    writer_batch_size = max(1, int(target_size / bytes_per_sample))

    # Generate features (identical for every split, so built once)
    feat = {
        "key": datasets.Value(dtype="int64"),
        "x": datasets.Array3D(dtype="float32", shape=(data["x"]["bands"], tile_size, tile_size)),
    }
    if y_img is not None:
        for target in y:
            feat["y_" + target.name] = datasets.Array2D(dtype="float32", shape=(tile_size, tile_size))
    feat = datasets.Features(feat)

    border_width = ((0, 0), (border, border), (border, border))

    # All three parameters of Generator are spelled out: the one-parameter form
    # is only accepted since Python 3.13 and fails when the annotation is evaluated on older versions.
    def dataset_generator(samples: dict[str, list[int]], name: str) -> Generator[dict, None, None]:
        """Yield the tiles of one split, each one padded with a symmetric border."""
        for index in samples[name]:
            data_dict = {
                "key": index,
                "x": np.pad(x_batches[index].data, border_width, "symmetric"),
            }
            if y is not None:
                y_index = np.pad(y_batches[index].data, border_width, "symmetric")
                for j, target in enumerate(y):
                    data_dict["y_" + target.name] = y_index[j]
            yield data_dict

    for name in split:
        # Create dataset
        ds = datasets.Dataset.from_generator(
            dataset_generator,
            features=feat,
            cache_dir=(output / ".cache").as_posix(),
            fingerprint=unique_id,
            gen_kwargs={"samples": samples, "name": name},
            writer_batch_size=writer_batch_size,
        )

        # Save dataset
        ds.save_to_disk(output / name)

        # Cleaning the cache
        # The dataset object is released before its cache files are removed
        cache = get_cache(ds)
        del ds
        gc.collect()
        clean_cache(cache)
    return output
def prepare_seg_maps(
    y: ListOfY,
    ref: xr.DataArray,
    dtype: Optional[type] = None,
    dtype_class: Optional[type] = np.floating,
    y_nodata: Optional[Union[int, float]] = 0,
) -> tuple[xr.DataArray, Union[int, float]]:
    """
    Load all target variables and stack them into one array.

    Parameters
    ----------
    y : list of Y
        Target variables to load.
    ref : xr.DataArray
        Reference raster that is passed to the raster and vector loaders.
    dtype : dtype (optional)
        Data type that is passed to the dtype check.
    dtype_class : dtype class (default = np.floating)
        Class of data types that is passed to the dtype check.
    y_nodata : int or float (default = 0)
        Nodata value for the targets. None is replaced with 0.

    Returns
    -------
    tuple of xr.DataArray and int or float
        Targets stacked along the "band" dimension and the nodata value that was used.
    """
    nodata = 0 if y_nodata is None else y_nodata

    def _schedule(target: Y) -> Any:
        """Build a delayed loading task: raster loader without a burn value, vector loader with it."""
        if target.burn_value is None:
            return dask.delayed(prepare_raster_sm)(target.path, ref, target.name, nodata)
        return dask.delayed(prepare_vector_sm)(target.path, ref, target.burn_value, target.name, nodata)

    tasks = [_schedule(target) for target in y]

    # All targets are computed in a single dask call
    layers = list(dask.compute(*tasks))

    merged = xr.merge(layers)
    merged = check_dtype(img=merged, dtype=dtype, dtype_class=dtype_class)

    # Every variable of the merged dataset becomes one entry of the "band" dimension
    stacked = merged.squeeze().to_array("band").chunk("auto")
    return persist(stacked), nodata