Source code for disdrodb.utils.netcdf

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2022 DISDRODB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------.

import logging
import numpy as np
import pandas as pd
import xarray as xr
from typing import Tuple
from disdrodb.utils.logger import log_info, log_warning, log_error

logger = logging.getLogger(__name__)


####---------------------------------------------------------------------------.
def _sort_datasets_by_dim(list_ds: list, fpaths: str, dim: str = "time") -> Tuple[list, list]:
    """Sort a list of xarray.Dataset and corresponding file paths by the starting value of a specified dimension.

    Parameters
    ----------
    fpaths : list
        List of netCDFs file paths.
    list_ds : list
        List of xarray Dataset.
    dim : str, optional
        Dimension name. The default is "time".

    Returns
    -------
    tuple
        Tuple of sorted list of xarray datasets and sorted list of file paths.
    """
    start_values = [ds[dim].values[0] for ds in list_ds]
    sorted_idx = np.argsort(start_values)
    sorted_list_ds = [list_ds[i] for i in sorted_idx]
    sorted_fpaths = [fpaths[i] for i in sorted_idx]
    return sorted_list_ds, sorted_fpaths


def _get_dim_values_index(list_ds: list, dim: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Get list and dataset indices associated to the dimension values."""
    dim_values = np.concatenate([ds[dim].values for ds in list_ds])
    list_index = np.concatenate([np.ones(len(ds[dim])) * i for i, ds in enumerate(list_ds)])
    list_index = list_index.astype(int)
    ds_index = np.concatenate([np.arange(0, len(ds[dim])) for i, ds in enumerate(list_ds)])
    return dim_values, list_index, ds_index


def _get_non_monotonic_indices_to_remove(dim_values: np.ndarray) -> np.ndarray:
    """Returns the indices that cause a non-monotonic increasing series of values.

    Assume that duplicated values, if present, occurs consecutively !
    """
    diff_dim_values = np.diff(dim_values)
    indices_decreasing = np.where(diff_dim_values.astype(float) <= 0)[0] + 1
    if len(indices_decreasing) == 0:
        return []
    idx_start_decreasing = indices_decreasing[0]
    idx_restart_increase = np.max(np.where(dim_values <= dim_values[idx_start_decreasing - 1])[0])
    idx_to_remove = np.arange(idx_start_decreasing, idx_restart_increase + 1)
    return idx_to_remove


def _get_duplicated_indices(x, keep="first"):
    """Return the indices to remove for duplicated values in x such that there is only one value occurence.

    Parameters
    ----------
    x :  np.array
        Array of values.
    keep : str, optional
        The value to keep, either 'first', 'last' or False.
        The default is 'first'.
        ‘first’ : Mark duplicates as True except for the first occurrence.
        ‘last’ : Mark duplicates as True except for the last occurrence.
        False : Mark all duplicates as True.

    Returns
    -------
    np.array
        Array of indices to remove.
    """
    # Check 'keep' argument
    # if not isinstance(keep, str):
    #     raise TypeError("`keep` must be a string. Either first or last.")
    # if not np.isin(keep, ["first", "last"]):
    #     raise ValueError("Invalid value for argument keep. Only 'first' and 'last' are accepted.")
    # # Get

    # x_indices = np.arange(len(x))
    # unique_values, unique_counts = np.unique(x, return_counts=True)
    # duplicated_values = unique_values[unique_counts > 1]

    # duplicated_indices = np.array([], dtype=np.int32)
    # if keep == 'first':
    #     for value in duplicated_values:
    #         indices = np.where(x == value)[0]
    #         duplicated_indices = np.concatenate([duplicated_indices, indices[1:]])
    # elif keep == 'last':
    #     indices = np.where(x == value)[0]
    #     duplicated_indices = np.concatenate([duplicated_indices, indices[:-1]])
    # return duplicated_indices

    # Get duplicate indices
    idx_duplicated = pd.Index(x).duplicated(keep=keep)
    return np.where(idx_duplicated)[0]


def _get_bad_info_dict(
    idx_to_remove: np.ndarray,
    list_index: np.ndarray,
    dim_values: np.ndarray,
    ds_index: np.ndarray,
) -> Tuple[dict, dict]:
    """Return two dictionaries mapping, for each dataset, the bad values and indices to remove.

    Parameters
    ----------
    idx_to_remove : np.ndarray
        Indices to be removed to ensure monotonic dimension.
    list_index : np.ndarray
        Indices corresponding to the file in the `list_ds` parameter.
    ds_index : np.ndarray
        Indices corresponding to the dataset dimension index in the `list_ds` parameter.

    Returns
    -------
    dict
        A dictionary mapping the dimension values to remove for each file.
    dict
        A dictionary mapping the dataset dimension indices to remove for each file.
    """
    list_index_bad = list_index[idx_to_remove]
    ds_index_bad = ds_index[idx_to_remove]
    dim_values_bad = dim_values[idx_to_remove]
    # Retrieve dictionary with the bad values in each dataset
    dict_ds_bad_values = {k: dim_values_bad[np.where(list_index_bad == k)[0]] for k in np.unique(list_index_bad)}
    # Retrieve dictionary with the index with the bad values in each dataset
    dict_ds_bad_idx = {k: ds_index_bad[np.where(list_index_bad == k)[0]] for k in np.unique(list_index_bad)}
    return dict_ds_bad_values, dict_ds_bad_idx


def _remove_dataset_bad_values(list_ds, fpaths, dict_ds_bad_idx, dim):
    """Remove portions of xarray Datasets corresponding to duplicated values.

    Parameters
    ----------
    list_ds : list
        List of xarray Dataset.
    dict_ds_bad_idx : dict
        Dictionary with the dimension indices corresponding to bad values in each xarray Dataset.

    Returns
    -------

    list_ds : list
        List of xarray Dataset without bad values.
    """
    list_index_valid = list(range(len(list_ds)))
    for list_index_bad, bad_idx in dict_ds_bad_idx.items():
        # Get dataset
        ds = list_ds[list_index_bad]
        # If resulting in a empty dataset, drop index from list_index_valid
        if len(bad_idx) == len(list_ds[list_index_bad][dim]):
            list_index_valid.remove(list_index_bad)
        # Remove unvalid indices
        list_ds[list_index_bad] = ds.drop_isel({dim: bad_idx})

    # Keep non-empty datasets
    new_list_ds = [list_ds[idx] for idx in list_index_valid]
    new_fpaths = [fpaths[idx] for idx in list_index_valid]
    return new_list_ds, new_fpaths


def ensure_unique_dimension_values(
    list_ds: list, fpaths: list, dim: str = "time", verbose: bool = False
) -> Tuple[list, list]:
    """Ensure that a list of xr.Dataset has non duplicated dimension values.

    Datasets are first sorted by their starting ``dim`` value. Duplicated values
    are then dropped, keeping the first occurrence in that sorted order. Datasets
    whose values are all duplicates are excluded.

    Parameters
    ----------
    list_ds : list
        List of xarray Dataset.
    fpaths : list
        List of netCDFs file paths, aligned with ``list_ds``.
    dim : str, optional
        Dimension name. The default is "time".
    verbose : bool, optional
        Verbosity flag forwarded to ``log_warning``. The default is False.

    Returns
    -------
    list
        List of xarray Dataset.
    list
        List of netCDFs file paths.
    """
    # Nothing to deduplicate (and np.concatenate would fail on an empty list)
    if len(list_ds) == 0:
        return list_ds, fpaths

    # Reorder the files and filepaths by the starting dimension value (time)
    list_ds, fpaths = _sort_datasets_by_dim(list_ds=list_ds, fpaths=fpaths, dim=dim)

    # Get the datasets dimension values array (and associated list_ds/xr.Dataset indices)
    dim_values, list_index, ds_index = _get_dim_values_index(list_ds, dim=dim)

    # Get duplicated indices
    idx_duplicated = _get_duplicated_indices(dim_values, keep="first")
    if len(idx_duplicated) == 0:
        return list_ds, fpaths

    # Retrieve dictionary providing bad values and indexes for each dataset
    dict_ds_bad_values, dict_ds_bad_idx = _get_bad_info_dict(
        idx_to_remove=idx_duplicated,
        list_index=list_index,
        dim_values=dim_values,
        ds_index=ds_index,
    )

    # Report for each dataset, the duplicates values occurring
    for list_index_bad, bad_values in dict_ds_bad_values.items():
        # Retrieve dataset filepath
        fpath = fpaths[list_index_bad]
        # If all values inside the file are duplicated, report it
        if len(bad_values) == len(list_ds[list_index_bad][dim]):
            msg = (
                f"{fpath} is excluded from concatenation. All {dim} values are already present in some other file."
            )
            log_warning(logger=logger, msg=msg, verbose=verbose)
        else:
            if np.issubdtype(bad_values.dtype, np.datetime64):
                bad_values = bad_values.astype("M8[s]")
            msg = f"In {fpath}, dropping {dim} values {bad_values} to avoid duplicated {dim} values."
            log_warning(logger=logger, msg=msg, verbose=verbose)

    # Remove duplicated values
    list_ds, fpaths = _remove_dataset_bad_values(
        list_ds=list_ds, fpaths=fpaths, dict_ds_bad_idx=dict_ds_bad_idx, dim=dim
    )
    return list_ds, fpaths


def ensure_monotonic_dimension(
    list_ds: list, fpaths: list, dim: str = "time", verbose: bool = False
) -> Tuple[list, list]:
    """Ensure that a list of xr.Dataset has a monotonic increasing (non duplicated) dimension values.

    Datasets are sorted by their starting ``dim`` value. Then, while the
    concatenated ``dim`` values are not strictly increasing, the first offending
    segment is dropped and the check is repeated. Datasets whose values are all
    dropped are excluded.

    Parameters
    ----------
    list_ds : list
        List of xarray Dataset.
    fpaths : list
        List of netCDFs file paths, aligned with ``list_ds``.
    dim : str, optional
        Dimension name. The default is "time".
    verbose : bool, optional
        Verbosity flag forwarded to ``log_warning``. The default is False.

    Returns
    -------
    list
        List of xarray Dataset.
    list
        List of netCDFs file paths.
    """
    # TODO: should maybe also split by non-continuous time ...
    # Loop instead of recursion: each pass drops at least one value, so it terminates.
    while len(list_ds) > 0:
        # Reorder the files and filepaths by the starting dimension value (time).
        # Redone at each pass because dropping leading values can change a dataset's start value.
        list_ds, fpaths = _sort_datasets_by_dim(list_ds=list_ds, fpaths=fpaths, dim=dim)

        # Get the datasets dimension values array (and associated list_ds/xr.Dataset indices)
        dim_values, list_index, ds_index = _get_dim_values_index(list_ds, dim=dim)

        # Identify the indices to remove to ensure monotonic values
        idx_to_remove = _get_non_monotonic_indices_to_remove(dim_values)
        if len(idx_to_remove) == 0:
            break

        # Retrieve dictionary providing bad values and indexes for each dataset
        dict_ds_bad_values, dict_ds_bad_idx = _get_bad_info_dict(
            idx_to_remove=idx_to_remove,
            list_index=list_index,
            dim_values=dim_values,
            ds_index=ds_index,
        )

        # Report for each dataset, the values to be dropped
        for list_index_bad, bad_values in dict_ds_bad_values.items():
            # Retrieve dataset filepath
            fpath = fpaths[list_index_bad]
            # If all values inside the file should be dropped, report it
            if len(bad_values) == len(list_ds[list_index_bad][dim]):
                msg = (
                    f"{fpath} is excluded from concatenation. All {dim} values cause the dimension to be non-monotonic."
                )
                log_warning(logger=logger, msg=msg, verbose=verbose)
            else:
                if np.issubdtype(bad_values.dtype, np.datetime64):
                    bad_values = bad_values.astype("M8[s]")
                msg = f"In {fpath}, dropping {dim} values {bad_values} to ensure monotonic {dim} dimension."
                log_warning(logger=logger, msg=msg, verbose=verbose)

        # Remove the values breaking monotonicity
        list_ds, fpaths = _remove_dataset_bad_values(
            list_ds=list_ds, fpaths=fpaths, dict_ds_bad_idx=dict_ds_bad_idx, dim=dim
        )
    return list_ds, fpaths


# Example inputs for manually debugging the helpers above
# (dataset 0 holds values [0, 1, 5, 5]; dataset 1 holds values [5, 5, 6, 7, 8]):
# ds_index = np.array([0, 1, 2, 3, 0, 1, 2, 3, 4])
# list_index = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1])
# dim_values = np.array([0, 1, 5, 5, 5, 5, 6, 7, 8])
####---------------------------------------------------------------------------
def get_list_ds(fpaths: list) -> list:
    """Get list of xarray datasets from file paths.

    Files are opened one at a time and fully loaded into memory.

    Parameters
    ----------
    fpaths : list
        List of netCDFs file paths.

    Returns
    -------
    list
        List of xarray datasets, in the same order as ``fpaths``.
    """
    list_ds = []
    for fpath in fpaths:
        # This context manager is required to avoid random HDF locking:
        # the file is closed right after its content has been loaded.
        # cache=True would store data in memory to avoid reading back from disk,
        # but the LRU cache might cause the netCDF to not be closed. Hence
        # cache=False, with an explicit .load() to bring the data in memory.
        with xr.open_dataset(fpath, cache=False) as data:
            ds = data.load()
        list_ds.append(ds)
    return list_ds


# NOTE: reading the files in parallel (dask.delayed wrapping xr.open_dataset, with
# HDF5_USE_FILE_LOCKING="FALSE") was tried and caused HDF lock errors, so
# get_list_ds reads the files sequentially.
####---------------------------------------------------------------------------
def _concatenate_datasets(list_ds, dim="time", verbose=False):
    """Concatenate xarray datasets along ``dim`` with ``xr.concat``.

    Parameters
    ----------
    list_ds : list
        List of xarray Dataset.
    dim : str, optional
        Dimension along which to concatenate. The default is "time".
    verbose : bool, optional
        Verbosity flag forwarded to ``log_info``. The default is False.

    Returns
    -------
    xr.Dataset
        The concatenated dataset.

    Raises
    ------
    ValueError
        If ``xr.concat`` fails. The original exception is chained as the cause.
    """
    msg = "Start concatenating with xr.concat."
    log_info(logger=logger, msg=msg, verbose=verbose)
    try:
        ds = xr.concat(list_ds, dim=dim, coords="minimal", compat="override")
    except Exception as e:
        msg = f"Concatenation with xr.concat failed. Error is {e}."
        log_error(logger=logger, msg=msg, verbose=False)
        raise ValueError(msg) from e
    msg = "Concatenation with xr.concat has been successful."
    log_info(logger=logger, msg=msg, verbose=verbose)
    return ds


def xr_concat_datasets(fpaths: list, verbose=False) -> xr.Dataset:
    """Read netCDF files and concatenate them along the time dimension.

    Before concatenation, the time dimension is cleaned:

    1. time values duplicated across files are dropped;
    2. time values breaking a monotonic increasing time dimension are dropped.

    Files whose time values are all dropped are excluded from the concatenation.

    Parameters
    ----------
    fpaths : list
        List of netCDFs file paths.
    verbose : bool, optional
        Verbosity flag forwarded to the logging helpers. The default is False.

    Returns
    -------
    xr.Dataset
        A single xarray dataset.

    Raises
    ------
    ValueError
        Error if the merging/concatenation operations can not be achieved.
    """
    # --------------------------------------.
    # Read all files into memory (one at a time, see get_list_ds)
    list_ds = get_list_ds(fpaths)

    # --------------------------------------.
    # Ensure time dimension contains no duplicated values
    list_ds, fpaths = ensure_unique_dimension_values(list_ds=list_ds, fpaths=fpaths, dim="time", verbose=verbose)

    # Ensure time dimension is monotonic increasingly
    list_ds, fpaths = ensure_monotonic_dimension(list_ds=list_ds, fpaths=fpaths, dim="time", verbose=verbose)

    # --------------------------------------.
    # Concatenate all netCDFs
    ds = _concatenate_datasets(list_ds=list_ds, dim="time", verbose=verbose)

    # --------------------------------------.
    # Return xr.Dataset
    return ds