Source code for remote_sensing_processor.mosaic.mosaic

"""Mosaic rasters."""

from pydantic import PositiveInt, validate_call
from typing import Optional, Union

import dask
import dask.array as da
import xarray as xr

import rasterio as rio
import rioxarray as rxr
from rioxarray.merge import merge_datasets

from pystac import Item

from remote_sensing_processor.common.common_functions import create_folder, persist
from remote_sensing_processor.common.common_raster import (
    check_dtype,
    clipf,
    load_dataset,
    prepare_nodata,
    reproject,
    reproject_match,
    write_dataset,
)
from remote_sensing_processor.common.dataset import is_multiband, read_dataset
from remote_sensing_processor.common.fill import fillnodata
from remote_sensing_processor.common.match_hist import histogram_match
from remote_sensing_processor.common.types import CRS, DirectoryPath, FilePath, NewPath, PystacItem
from remote_sensing_processor.mosaic.dataset import postprocess_mosaic_dataset


@validate_call
def mosaic(
    inputs: list[Union[FilePath, DirectoryPath, PystacItem]],
    output_dir: Union[DirectoryPath, NewPath],
    fill_nodata: Optional[bool] = False,
    fill_distance: Optional[PositiveInt] = 250,
    clip: Optional[FilePath] = None,
    crs: Optional[CRS] = None,
    nodata: Optional[Union[int, float]] = None,
    reference_raster: Optional[Union[FilePath, DirectoryPath, PystacItem]] = None,
    resample: Optional[str] = "average",
    nodata_order: Optional[bool] = False,
    match_hist: Optional[bool] = False,
    keep_all_channels: Optional[bool] = True,
    write_stac: Optional[bool] = True,
) -> Union[DirectoryPath, NewPath]:
    """
    Creates mosaic from several rasters.

    Parameters
    ----------
    inputs : list of strings or list of STAC Items
        List of paths to rasters to be merged or to folders where multiband imagery data is stored
        or to STAC Items in order from images that should be on top to images that should be on bottom.
    output_dir: path to output directory as a string
        Path where mosaic raster or rasters will be saved.
    fill_nodata : bool (default = False)
        Is filling the gaps in the raster needed.
    fill_distance : int (default = 250)
        Fill distance for `fill_nodata` function.
    clip : string (optional)
        Path to a vector file to be used to crop the image.
    crs : string (optional)
        CRS in which output data should be.
        It is overridden by the CRS of `reference_raster` when a reference raster is given.
    nodata : int or float (default = None)
        Nodata value. If not set, then is read from inputs.
    reference_raster : string or STAC Item (optional)
        Reference raster is needed to bring output mosaic raster to the same resolution and projection
        as another data source. It is useful when you need to use data from different sources together.
    resample : resampling method from rasterio as a string (default = 'average')
        Resampling method that will be used to reproject and reshape to a reference raster shape.
        You can read more about resampling methods
        `here <https://rasterio.readthedocs.io/en/latest/topics/resampling.html>`_.
        Use 'nearest' if you want to keep only the same values that exist in the input raster.
    nodata_order : bool (default = False)
        Is it needed to merge images in order from images with less nodata on top (they are usually clear)
        to images with more nodata values on bottom (they are usually the most distorted and cloudy).
    match_hist : bool (default = False)
        Is it needed to match histograms of merged images.
        Improve mosaic uniformity, but change the original data.
    keep_all_channels : bool (default = True)
        Is needed only when you are merging images that have different number of channels
        (e.g., Landsat images from different generations). If True, all bands are processed,
        if False, only bands that are present in all input images are processed and others are omitted.
    write_stac : bool (default = True)
        If True, then output metadata is saved to a STAC file.

    Returns
    -------
    pathlib.Path
        Path to the written STAC JSON file if `write_stac` is True, otherwise `output_dir`.

    Raises
    ------
    ValueError
        If `inputs` is empty, or if `nodata` is not set and the inputs disagree on their nodata value.

    Examples
    --------
        >>> # Mosaic multiple Sentinel 2 multi-band products
        >>> import remote_sensing_processor as rsp
        >>> input_sentinels = [
        ...     "/home/rsp_test/sentinels/Sentinel1/meta.json",
        ...     "/home/rsp_test/sentinels/Sentinel2/meta.json",
        ...     "/home/rsp_test/sentinels/Sentinel3/meta.json",
        ...     "/home/rsp_test/sentinels/Sentinel4/meta.json",
        ...     "/home/rsp_test/sentinels/Sentinel5/meta.json",
        ...     "/home/rsp_test/sentinels/Sentinel6/meta.json",
        ... ]
        >>> border = "/home/rsp_test/border.gpkg"
        >>> mosaic_sentinel = rsp.mosaic(
        ...     inputs=input_sentinels,
        ...     output_dir="/home/rsp_test/mosaics/sentinel/",
        ...     clip=border,
        ...     crs="EPSG:4326",
        ...     nodata_order=True,
        ... )
        Processing completed
        >>> print(mosaic_sentinel)
        '/home/rsp_test/mosaics/sentinel/S2A_MSIL2A_20210821T064626_N0209_R063_T42VWR_20210821T064626_mosaic.json'

        >>> from glob import glob
        >>> # Mosaic multiple land cover files and matching it with a reference raster (Sentinel band)
        >>> lcs = glob("/home/rsp_test/landcover/*.tif")
        >>> print(lcs)
        ['/home/rsp_test/landcover/ESA_WorldCover_10m_2020_v100_N60E075_Map.tif',
         '/home/rsp_test/landcover/ESA_WorldCover_10m_2020_v100_N63E072_Map.tif',
         '/home/rsp_test/landcover/ESA_WorldCover_10m_2020_v100_N63E075_Map.tif']
        >>> mosaic_landcover = rsp.mosaic(
        ...     inputs=lcs,
        ...     output_dir="/home/rsp_test/mosaics/landcover/",
        ...     clip=border,
        ...     reference_raster="/home/rsp_test/mosaics/sentinel/B1.tif",
        ...     nodata=-1,
        ... )
        Processing completed
        >>> print(mosaic_landcover)
        '/home/rsp_test/mosaics/landcover/ESA_WorldCover_10m_2020_v100_N60E075_Map_mosaic.json'
    """
    # Fail early with a clear message instead of an IndexError on datasets[0] below
    if not inputs:
        raise ValueError("At least one input is required to create a mosaic.")

    # Reading datasets
    datasets = [read_dataset(i) for i in inputs]
    mb = is_multiband(datasets[0])

    # If datasets are single-band, then we should give the same name to their assets
    if not mb:
        bname = next(iter(datasets[0].assets.keys())) + "_mosaic"
        for dataset in datasets:
            dataset.assets[bname] = dataset.assets.pop(next(iter(dataset.assets.keys())))
            dataset.assets[bname].ext.eo.bands[0].name = bname

    # Getting nodata value
    if nodata is None:
        nodata = get_nodata(datasets, nodata)

    # Sorting in nodata order
    if nodata_order:
        datasets = order(datasets, nodata, clip)

    # Reading reference raster dataset; its CRS takes precedence over the `crs` argument
    if reference_raster is not None:
        reference_raster = read_dataset(reference_raster)
        crs = rio.crs.CRS.from_user_input(reference_raster.ext.proj.code)

    # Getting only the bands we need
    bands = get_bands(datasets, keep_all_channels)

    # Creating an output folder if not exists
    create_folder(output_dir, clean=False)

    # Step 1. Pre-process the data
    futures = [
        dask.delayed(initial_process_dataset)(
            dataset=dataset,
            bands=bands,
            clip=clip,
            crs=crs,
            nodata=nodata,
        )
        for dataset in datasets
    ]
    files = list(dask.compute(*futures))

    # Step 2: If histogram matching is needed, apply it
    if match_hist:
        # Apply histogram matching to the rest of the files in parallel,
        # using the first (top) file as the reference
        hist_match_futures = [histogram_match(file, files[0], nodata) for file in files[1:]]
        files[1:] = list(dask.compute(*hist_match_futures))

    # Merging files
    final = merge_datasets(files, method="first", nodata=nodata).chunk("auto")
    final = persist(final)

    # Resampling to the same shape and resolution as another raster
    if reference_raster is not None:
        ref = load_dataset(reference_raster)
        final = reproject_match(final, ref, resample)

    # Clipping mosaic with vector mask
    if clip is not None:
        final = clipf(final, clip)

    # Filling nodata
    if fill_nodata:
        mask = xr.ones_like(final)
        # nodata -> 0, data -> 1
        mask = mask.where(final != nodata, 0)
        # outside -> 1
        for band in mask:
            mask[band] = mask[band].rio.write_nodata(1)
        # Without a clip geometry there is no "outside" area to mark
        if clip is not None:
            mask = clipf(mask, clip)
        # Fill nodata
        final = fillnodata(final, mask, fill_distance, nodata)

    final = check_dtype(final)

    # Creating final STAC dataset
    stac, json_path = postprocess_mosaic_dataset(datasets, final, output_dir, bands)

    # Write
    write_dataset(final, stac, json_path)

    if write_stac:
        # Writing JSON metadata file
        stac.save_object(dest_href=json_path.as_posix())
        return json_path
    return output_dir
def initial_process_dataset(
    dataset: Item,
    bands: list[str],
    clip: Optional[FilePath] = None,
    crs: Optional[CRS] = None,
    nodata: Optional[Union[int, float]] = None,
) -> xr.Dataset:
    """Prepare a single dataset for mosaicking, without histogram matching.

    The dataset is loaded, its nodata is prepared, and it is optionally reprojected to `crs`
    and clipped with `clip`. Every name in `bands` that is missing from the dataset's assets
    is added as a band filled entirely with nodata, so that all prepared datasets expose
    the same set of bands.
    """
    img = load_dataset(dataset, bands, clip)
    img, nodata = prepare_nodata(img, nodata=nodata)
    if crs is not None:
        img = reproject(img, crs)
    if clip is not None:
        img = clipf(img, clip)
    img = check_dtype(img)
    # Adding empty data arrays for bands that are absent in the current dataset
    for band in bands:
        if band not in dataset.assets:
            # Reuse the first band as a template, then overwrite its values with nodata
            img[band] = img[next(iter(img.keys()))]
            img[band].data = da.full_like(img[band], nodata)
    return persist(img)


def get_bands(datasets: list[Item], keep_all_channels: bool) -> list[str]:
    """Read band names.

    If keep_all_channels == True then will read all the names,
    if False then will read only the bands that are present in every dataset.

    Each band name is returned exactly once, in the order in which it is first seen
    while walking through the datasets, so the result is deterministic.
    """
    band_lists = [list(dataset.assets.keys()) for dataset in datasets]
    # dict.fromkeys removes duplicates while keeping first-seen order (a set would not)
    all_bands = list(dict.fromkeys(band for band_list in band_lists for band in band_list))
    if keep_all_channels:
        return all_bands
    band_sets = [set(band_list) for band_list in band_lists]
    return [band for band in all_bands if all(band in band_set for band_set in band_sets)]


def _nodata_equal(first: object, second: object) -> bool:
    """Compare two nodata values, treating NaN as equal to NaN."""
    import math

    try:
        if math.isnan(first) and math.isnan(second):
            return True
    except TypeError:
        # None or another non-numeric value: fall back to plain equality
        pass
    return bool(first == second)


def get_nodata(datasets: list[Item], nodata: Optional[Union[int, float]] = None) -> Union[int, float]:
    """Get nodata from multiple datasets.

    Every asset of every dataset is opened and its nodata value is read. The first value
    that is not None becomes the reference (unless `nodata` is already given), and every
    following file must agree with it.

    Raises
    ------
    ValueError
        If a file has a nodata value that differs from the reference one.
    """
    for ds in datasets:
        for asset in ds.assets.values():
            with rxr.open_rasterio(asset.href, chunks=True, lock=True) as img:
                file_nodata = img.rio.nodata
            if nodata is None:
                nodata = file_nodata
            elif not _nodata_equal(nodata, file_nodata):
                # A plain `!=` would reject files that all use NaN as nodata
                raise ValueError(f"Nodata value of {ds.id} is different from the other files.")
    return nodata


def order(
    datasets: list[Item],
    nodata: Optional[Union[int, float]] = None,
    clip: Optional[FilePath] = None,
) -> list[Item]:
    """Sort datasets in order from least nodata to most nodata.

    Only the first asset of each dataset is inspected. If `nodata` is not given, the nodata
    value of each file is used for that file. Datasets whose nodata share cannot be computed
    get a share of 0.0. Datasets with equal shares keep their input order.
    """
    nodata_shares = []
    for ds in datasets:
        band = next(iter(ds.assets.keys()))
        with rxr.open_rasterio(ds.assets[band].href, chunks=True, lock=True) as img:
            file_nodata = img.rio.nodata if nodata is None else nodata
            if file_nodata is None:
                nodata_shares.append(0.0)
                continue
            try:
                # nodata-1 data-0
                mask = (img == file_nodata).astype("uint8")
                mask = mask.rio.write_nodata(0)
                if clip is not None:
                    # nodata-1 data-0 outside_nodata-0
                    mask = clipf(mask, clip)
                nodata_shares.append((da.count_nonzero(mask) / mask.size).compute())
            except Exception:
                # Best effort: ordering is only a heuristic, so an unreadable file
                # must not abort the whole mosaic
                nodata_shares.append(0.0)
    # Stable sort over pairs: no need for datasets to be hashable, and repeated items are kept
    ranked = sorted(zip(nodata_shares, datasets, strict=True), key=lambda pair: pair[0])
    return [ds for _, ds in ranked]