Source code for remote_sensing_processor.common.match_hist

"""Histogram matching."""

from pydantic import validate_call
from typing import Optional, Union

import dask
import dask.array as da
import numpy as np
from xarray import DataArray, Dataset

from remote_sensing_processor.common.common_functions import create_path, persist
from remote_sensing_processor.common.common_raster import (
    check_dtype,
    load_dataset,
    prepare_nodata,
    restore_nodata_from_nan,
    set_nodata_to_nan,
    write_dataset,
)
from remote_sensing_processor.common.dataset import check_output, postprocess_dataset, read_dataset
from remote_sensing_processor.common.types import DirectoryPath, FilePath, NewPath, PystacItem


@validate_call
def match_hist(
    input_path: Union[FilePath, DirectoryPath, PystacItem],
    reference_raster: Union[FilePath, DirectoryPath, PystacItem],
    output_path: Optional[Union[FilePath, DirectoryPath, NewPath]] = None,
    nodata: Optional[Union[int, float]] = None,
    write_stac: Optional[bool] = True,
) -> NewPath:
    """
    Match the histogram of a raster to the histogram of a reference raster.

    Parameters
    ----------
    input_path : string or STAC Item
        Path to an input file, directory or a STAC dataset or a STAC Item (e.g., from Planetary Computer).
    reference_raster : string or STAC Item
        A raster whose histogram the input is matched to.
    output_path : string (optional)
        Path to an output file, directory, or STAC dataset. If not set, then will overwrite the input files.
        Must be set if input is a remote STAC Item.
    nodata : int or float (default = None)
        Nodata value. If not set, then is read from inputs.
    write_stac : bool (default = True)
        If True, then output metadata is saved to a STAC file.

    Returns
    -------
    pathlib.Path
        Path of the saved STAC JSON file when ``write_stac`` is truthy, otherwise the output path.

    Examples
    --------
        >>> import remote_sensing_processor as rsp
        >>> rsp.match_hist(
        ...     input_path="/home/rsp_test/image_1/sentinel_B1.tif",
        ...     output_path="/home/rsp_test/image_1/sentinel_B1_matched.tif",
        ...     reference_raster="/home/rsp_test/image_2/sentinel_B1.tif",
        ...     nodata=0,
        ... )
    """
    output_path = check_output(input_path, output_path)

    # Read the raster that is going to be adjusted
    stac_dataset = read_dataset(input_path)
    raster, nodata = prepare_nodata(load_dataset(stac_dataset), nodata)

    # Read the raster that provides the target histogram
    reference_dataset = read_dataset(reference_raster)
    reference_data = load_dataset(reference_dataset)

    # Match, materialise and settle the output dtype
    matched = histogram_match(raster, reference_data, nodata)
    matched = check_dtype(persist(matched))

    # Make sure the output location exists
    create_path(output_path)

    # Build the final STAC dataset describing the result
    stac_dataset, json_path = postprocess_dataset(stac_dataset, matched, output_path)

    # Write the raster data
    write_dataset(matched, stac_dataset, json_path)

    if not write_stac:
        return output_path

    # Write the JSON metadata file
    stac_dataset.save_object(dest_href=json_path.as_posix())
    return json_path
def histogram_match(image: Dataset, reference: Dataset, nodata: Optional[Union[int, float]] = None) -> Dataset:
    """Service function that matches histograms.

    This function is a dask-aware implementation of histogram matching.
    It computes the global histograms of the image and reference arrays
    without loading the entire data into memory, making it suitable for large,
    chunked dask arrays.

    Parameters
    ----------
    image : Dataset
        Dataset whose bands are adjusted.
    reference : Dataset
        Dataset providing the target histogram for every band.
    nodata : int or float (optional)
        Nodata value forwarded to ``prepare_nodata`` for the matched dataset.

    Returns
    -------
    Dataset
        The matched dataset, after ``restore_nodata_from_nan`` and ``persist``.

    Raises
    ------
    ValueError
        If the datasets have a different number of bands or different band names.
    """
    image = set_nodata_to_nan(image)
    reference = set_nodata_to_nan(reference)

    image_bands = list(image.data_vars)
    reference_bands = list(reference.data_vars)

    # Checking if the number of bands in both datasets is the same
    if len(image_bands) != len(reference_bands):
        raise ValueError("Datasets have different number of channels")

    # If there is only one data variable, align the reference band name with the image band name
    # (skipping the rename when the names already agree).
    if len(image_bands) == 1 and reference_bands[0] != image_bands[0]:
        reference = reference.rename({reference_bands[0]: image_bands[0]})

    if set(image.data_vars) != set(reference.data_vars):
        raise ValueError("Datasets have different bands")

    # Histogram matching, band by band
    image = image.map(match_single_band, reference=reference)
    image, _ = prepare_nodata(image, nodata)
    image = restore_nodata_from_nan(image)
    return persist(image)


def _normalized_cdf(hist: np.ndarray) -> np.ndarray:
    """Return the cumulative distribution of a histogram, or the histogram itself when it is empty."""
    total = hist.sum()
    return hist.cumsum() / total if total > 0 else hist


def match_single_band(image: DataArray, reference: Dataset) -> DataArray:
    """Does a histogram matching for a single band.

    Parameters
    ----------
    image : DataArray
        Band to adjust. It is not modified; a new DataArray is returned.
    reference : Dataset
        Dataset that contains a variable named ``image.name`` used as the target distribution.

    Returns
    -------
    DataArray
        Copy of ``image`` whose data is the lazily evaluated matched band, or ``image`` itself
        when neither band contains a finite value.
    """
    source_band = image.data
    reference_band = reference[image.name].data

    # Determine the min and max values to define histogram bins, ignoring NaN
    vmin = da.nanmin(da.stack([da.nanmin(source_band), da.nanmin(reference_band)]))
    vmax = da.nanmax(da.stack([da.nanmax(source_band), da.nanmax(reference_band)]))
    vmin, vmax = dask.compute(vmin, vmax)

    # Both bands are entirely NaN: there is no value range to bin, so nothing can be matched.
    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        return image

    target_dtype = image.dtype
    is_integer = np.issubdtype(target_dtype, np.integer)

    # Use one bin per integer value for integer data types.
    # The difference is taken on Python ints: vmin/vmax are NumPy scalars of the band dtype,
    # and e.g. uint8 255 - 0 + 1 would wrap around instead of giving 256.
    bins = int(vmax) - int(vmin) + 1 if is_integer else 256

    # Compute global histograms for source and reference bands
    source_hist, _ = da.histogram(source_band, bins=bins, range=[vmin, vmax])
    reference_hist, _ = da.histogram(reference_band, bins=bins, range=[vmin, vmax])

    # The histograms are small, so we can compute them into memory.
    source_hist, reference_hist = dask.compute(source_hist, reference_hist)

    # Compute Cumulative Distribution Functions (CDFs)
    source_cdf = _normalized_cdf(source_hist)
    reference_cdf = _normalized_cdf(reference_hist)

    # Compute the lookup table (LUT)
    bin_centers = np.linspace(vmin, vmax, bins)
    interp_values = np.interp(source_cdf, reference_cdf, bin_centers)

    def _apply_lut(block: np.ndarray) -> np.ndarray:
        """Map one chunk through the LUT and return it in the declared output dtype."""
        mapped = np.interp(block, bin_centers, interp_values)
        # np.interp always yields float64; cast so the chunks agree with the dtype declared below.
        if is_integer:
            mapped = np.rint(mapped)
        return mapped.astype(target_dtype, copy=False)

    # Apply the LUT to the source band using map_blocks for memory efficiency
    new_data = da.map_blocks(_apply_lut, source_band, dtype=target_dtype)

    # Return a new DataArray instead of assigning to image.data, so the caller's object is left untouched.
    return image.copy(data=new_data)


# NOTE: the string below holds an earlier histogram_match implementation; it is not executed.
'''
def histogram_match(image, reference, nodata):
    """Service function that matches histograms.

    As match_histograms is a function that is not natively dask-compatible;
    there are several hacks that try to fix that
    """
    if isinstance(image, xr.Dataset):
        image = image.to_dataarray("band")
        i_ds = True
    else:
        i_ds = False
    if isinstance(reference, xr.Dataset):
        reference = reference.to_dataarray("band")
    # TODO: find a way how to run it without rechunk
    try:
        # If an array consists of several chunks, we need to turn it to one chunk to calc the values for the whole array
        image = image.chunk(-1)
        reference = reference.chunk(-1)
        image.data = dask.array.map_blocks(match_with_nodata, image.data, reference.data, nodata)
        image = image.chunk("auto")
    except Exception:
        # If cannot rechunk, then process blockwise, but the output will contain errors
        warnings.warn(
            "Failed to rechunk the input array to single chunk. The result might be inconsistent.",
            stacklevel=1,
        )
        image.data = dask.array.map_blocks(match_with_nodata, image.data, reference.data, nodata)
    if i_ds:
        image = image.to_dataset("band").expand_dims({"band": 1})
    image, _ = prepare_nodata(image, nodata)
    return image


def match_with_nodata(image, reference, nodata):
    """Performs band-wise histogram matching on masked arrays.

    Adapted from https://gist.github.com/tayden/dcc83424ce55bfb970f60db3d4ddad18
    """
    image_mask = np.where(image == nodata, True, False)
    reference_mask = np.where(reference == nodata, True, False)
    masked_source_image = np.ma.array(image, mask=image_mask)
    masked_reference_image = np.ma.array(reference, mask=reference_mask)
    matched = np.ma.array(np.empty(image.shape, dtype=image.dtype), mask=image_mask, fill_value=nodata)
    for channel in range(masked_source_image.shape[0]):
        matched_channel = match_histograms(
            masked_source_image[channel].compressed(),
            masked_reference_image[channel].compressed(),
        )
        # Re-insert masked background
        mask_ch = image_mask[channel]
        matched[channel][~mask_ch] = matched_channel.ravel()
    return matched.filled()
'''