From e7f4358326915c474951874277ca575186c48fe8 Mon Sep 17 00:00:00 2001
From: Siddhesh Agarwal
Date: Thu, 12 Oct 2023 17:00:01 +0530
Subject: [PATCH 01/16] added typing - 1

---
 .prettierrc.toml  |   2 +-
 xesmf/backend.py  |  14 ++++--
 xesmf/data.py     |  31 +++++++------
 xesmf/frontend.py | 116 ++++++++++++++++++++++++++--------------------
 xesmf/smm.py      | 109 ++++++++++++++++++++++++++-----------------
 xesmf/util.py     |  86 ++++++++++++++++++++++++----------
 6 files changed, 222 insertions(+), 136 deletions(-)

diff --git a/.prettierrc.toml b/.prettierrc.toml
index addd6d36..24a4663a 100644
--- a/.prettierrc.toml
+++ b/.prettierrc.toml
@@ -1,3 +1,3 @@
-tabWidth = 2
+tabWidth = 4
 semi = false
 singleQuote = true
diff --git a/xesmf/backend.py b/xesmf/backend.py
index 55294181..7b87cc60 100644
--- a/xesmf/backend.py
+++ b/xesmf/backend.py
@@ -26,7 +26,7 @@
 import numpy.lib.recfunctions as nprec
 
 
-def warn_f_contiguous(a):
+def warn_f_contiguous(a: np.ndarray) -> None:
     """
     Give a warning if input array if not Fortran-ordered.
 
@@ -41,7 +41,7 @@ def warn_f_contiguous(a):
         warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.')
 
 
-def warn_lat_range(lat):
+def warn_lat_range(lat: np.ndarray) -> None:
     """
     Give a warning if latitude is outside of [-90, 90]
 
@@ -58,7 +58,13 @@ def warn_lat_range(lat):
 
 class Grid(ESMF.Grid):
     @classmethod
-    def from_xarray(cls, lon, lat, periodic=False, mask=None):
+    def from_xarray(
+        cls,
+        lon: np.ndarray[float, int],
+        lat: np.ndarray[float, int],
+        periodic: bool = False,
+        mask=None,
+    ):
         """
         Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.
 
@@ -158,7 +164,7 @@ def from_xarray(cls, lon, lat):
     Parameters
     ----------
     lon, lat : 1D numpy array
-        Longitute/Latitude of cell centers.
+        Longitude/Latitude of cell centers.
 
     Returns
     -------
diff --git a/xesmf/data.py b/xesmf/data.py
index e534c3a1..5d011ebc 100644
--- a/xesmf/data.py
+++ b/xesmf/data.py
@@ -3,42 +3,45 @@
 """
 
 import numpy as np
+import xarray
 
 
-def wave_smooth(lon, lat):
-    r"""
+def wave_smooth(  # type: ignore
+    lon: np.ndarray[float] | xarray.DataArray,  # type: ignore
+    lat: np.ndarray[float] | xarray.DataArray,  # type: ignore
+) -> np.ndarray[float] | xarray.DataArray:  # type: ignore
+    """
     Spherical harmonic with low frequency.
 
     Parameters
     ----------
     lon, lat : 2D numpy array or xarray DataArray
-        Longitute/Latitude of cell centers
+        Longitute/Latitude of cell centers
 
     Returns
     -------
-    f : 2D numpy array or xarray DataArray depending on input
-        2D wave field
+    f : 2D numpy array or xarray DataArray depending on input
+        2D wave field
 
     Notes
     -------
     Equation from [1]_ [2]_:
 
-    .. math:: Y_2^2 = 2 + \cos^2(\\theta) \cos(2 \phi)
+    .. math:: Y_2^2 = 2 + cos^2(lat) * cos(2 * lon)
 
     References
     ----------
     .. [1] Jones, P. W. (1999). First-and second-order conservative remapping
-        schemes for grids in spherical coordinates. Monthly Weather Review,
-        127(9), 2204-2210.
+        schemes for grids in spherical coordinates. Monthly Weather Review,
+        127(9), 2204-2210.
     .. [2] Ullrich, P. A., Lauritzen, P. H., & Jablonowski, C. (2009).
-        Geometrically exact conservative remapping (GECoRe): regular
-        latitude–longitude and cubed-sphere grids. Monthly Weather Review,
-        137(6), 1721-1741.
+        Geometrically exact conservative remapping (GECoRe): regular
+        latitude-longitude and cubed-sphere grids. Monthly Weather Review,
+        137(6), 1721-1741.
""" # degree to radius, make a copy - lat = lat / 180.0 * np.pi - lon = lon / 180.0 * np.pi + lat *= np.pi / 180.0 # type: ignore + lon *= np.pi / 180.0 # type: ignore - f = 2 + np.cos(lat) ** 2 * np.cos(2 * lon) + f = 2 + pow(np.cos(lat), 2) * np.cos(2 * lon) # type: ignore return f diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 267ed3d9..2582ed24 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -8,9 +8,9 @@ import numpy as np import sparse as sps import xarray as xr -from shapely.geometry import LineString +from shapely.geometry import LineString, Polygon from xarray import DataArray, Dataset - +from typing import Any, Literal, Optional, Tuple from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize from .smm import ( _combine_weight_multipoly, @@ -31,7 +31,15 @@ def subset_regridder( - ds_out, ds_in, method, in_dims, out_dims, locstream_in, locstream_out, periodic, **kwargs + ds_out, + ds_in, + method, + in_dims, + out_dims, + locstream_in, + locstream_out, + periodic, + **kwargs, ): """Compute subset of weights""" kwargs.pop('filename', None) # Don't save subset of weights @@ -234,20 +242,20 @@ def polys_to_ESMFmesh(polys): class BaseRegridder(object): def __init__( self, - grid_in, - grid_out, - method, - filename=None, - reuse_weights=False, - extrap_method=None, - extrap_dist_exponent=None, - extrap_num_src_pnts=None, - weights=None, - ignore_degenerate=None, - input_dims=None, - output_dims=None, - unmapped_to_nan=False, - parallel=False, + grid_in: Grid, + grid_out: Grid, + method: str, + filename: Optional[str] = None, + reuse_weights: bool = False, + extrap_method: Optional[Literal['inverse_dist', 'nearest_s2d']] = None, + extrap_dist_exponent: Optional[float] = None, + extrap_num_src_pnts: Optional[int] = None, + weights: Optional[Any] = None, + ignore_degenerate: bool = False, + input_dims: Optional[Tuple[str, ...]] = None, + output_dims: Optional[Tuple[str, ...]] = None, + unmapped_to_nan: bool = False, + parallel: bool = False, ): """ Base xESMF regridding class supporting ESMF objects: `Grid`, `Mesh` and `LocStream`. @@ -298,10 +306,10 @@ def __init__( weights : None, coo_matrix, dict, str, Dataset, Path, Regridding weights, stored as - - a scipy.sparse COO matrix, - - a dictionary with keys `row_dst`, `col_src` and `weights`, - - an xarray Dataset with data variables `col`, `row` and `S`, - - or a path to a netCDF file created by ESMF. + - a scipy.sparse COO matrix, + - a dictionary with keys `row_dst`, `col_src` and `weights`, + - an xarray Dataset with data variables `col`, `row` and `S`, + - or a path to a netCDF file created by ESMF. If None, compute the weights. 
ignore_degenerate : bool, optional
@@ -626,7 +634,12 @@ def regrid_dask(self, indata, **kwargs):
         return self.regrid_array(indata, self.weights.data, **kwargs)
 
     def regrid_dataarray(
-        self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None
+        self,
+        dr_in,
+        keep_attrs: bool = False,
+        skipna: bool = False,
+        na_thres: float = 1.0,
+        output_chunks=None,
     ):
         """See __call__()."""
 
@@ -646,7 +659,12 @@ def regrid_dataarray(
         return self._format_xroutput(dr_out, temp_horiz_dims)
 
     def regrid_dataset(
-        self, ds_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None
+        self,
+        ds_in,
+        keep_attrs: bool = False,
+        skipna: bool = False,
+        na_thres: float = 1.0,
+        output_chunks=None,
     ):
         """See __call__()."""
 
@@ -734,7 +752,7 @@ def __repr__(self):
 
         return info
 
-    def to_netcdf(self, filename=None):
+    def to_netcdf(self, filename: Optional[str] = None):
         """Save weights to disk as a netCDF file."""
         if filename is None:
             filename = self.filename
@@ -750,13 +768,15 @@ def to_netcdf(self, filename=None):
 
 class Regridder(BaseRegridder):
     def __init__(
         self,
-        ds_in,
-        ds_out,
-        method,
-        locstream_in=False,
-        locstream_out=False,
-        periodic=False,
-        parallel=False,
+        ds_in: xr.DataArray | xr.Dataset | dict,
+        ds_out: xr.DataArray | xr.Dataset | dict,
+        method: Literal[
+            'bilinear', 'conservative', 'conservative_normed', 'patch', 'nearest_s2d', 'nearest_d2s'
+        ],
+        locstream_in: bool = False,
+        locstream_out: bool = False,
+        periodic: bool = False,
+        parallel: bool = False,
         **kwargs,
     ):
         """
@@ -833,10 +853,10 @@ def __init__(
 
         weights : None, coo_matrix, dict, str, Dataset, Path,
             Regridding weights, stored as
 
-                - a scipy.sparse COO matrix,
-                - a dictionary with keys `row_dst`, `col_src` and `weights`,
-                - an xarray Dataset with data variables `col`, `row` and `S`,
-                - or a path to a netCDF file created by ESMF.
+              - a scipy.sparse COO matrix,
+              - a dictionary with keys `row_dst`, `col_src` and `weights`,
+              - an xarray Dataset with data variables `col`, `row` and `S`,
+              - or a path to a netCDF file created by ESMF.
 
             If None, compute the weights.
 
@@ -1215,7 +1235,7 @@ def __init__(
         )
 
     @staticmethod
-    def _check_polys_length(polys, threshold=1):
+    def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None:
         # Check length of polys segments, issue warning if too long
         check_polys, check_holes, _, _ = split_polygons_and_holes(polys)
         check_polys.extend(check_holes)
@@ -1231,7 +1251,7 @@ def _check_polys_length(polys, threshold=1):
             stacklevel=2,
         )
 
-    def _compute_weights_and_area(self, mesh_out):
+    def _compute_weights_and_area(self, mesh_out) -> tuple[DataArray, Any]:
         """Return the weights and the area of the destination mesh cells."""
 
         # Build the regrid object
@@ -1253,12 +1273,12 @@ def _compute_weights_and_area(self, mesh_out):
         esmf_regrid_finalize(regrid)
         return w, dstarea
 
-    def _compute_weights(self):
+    def _compute_weights(self) -> DataArray:
         """Return weight sparse matrix.
 
         This function first explodes the geometries into a flat list of Polygon exterior objects:
-          - Polygon -> polygon.exterior
-          - MultiPolygon -> list of polygon.exterior
+            - Polygon -> polygon.exterior
+            - MultiPolygon -> list of polygon.exterior
 
         and a list of Polygon.interiors (holes).
 
@@ -1310,7 +1330,7 @@ def w(self) -> xr.DataArray:
         dims = self.geom_dim_name, 'y_in', 'x_in'
         return xr.DataArray(data, dims=dims)
 
-    def _get_default_filename(self):
+    def _get_default_filename(self) -> str:
         # e.g. 
bilinear_400x600_300x400.nc filename = 'spatialavg_{0}x{1}_{2}.nc'.format( self.shape_in[0], self.shape_in[1], self.n_out @@ -1318,15 +1338,13 @@ def _get_default_filename(self): return filename - def __repr__(self): + def __repr__(self) -> str: info = ( - 'xESMF SpatialAverager \n' - 'Weight filename: {} \n' - 'Reuse pre-computed weights? {} \n' - 'Input grid shape: {} \n' - 'Output list length: {} \n'.format( - self.filename, self.reuse_weights, self.shape_in, self.n_out - ) + f'xESMF SpatialAverager \n' + f'Weight filename: {self.filename} \n' + f'Reuse pre-computed weights? {self.reuse_weights} \n' + f'Input grid shape: {self.shape_in} \n' + f'Output list length: {self.n_out} \n' ) return info diff --git a/xesmf/smm.py b/xesmf/smm.py index a94bacb6..3e355cee 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -1,16 +1,21 @@ """ Sparse matrix multiplication (SMM) using scipy.sparse library. """ +from typing import Any, Tuple import warnings from pathlib import Path -import numba as nb +import numba as nb # type: ignore[import] import numpy as np -import sparse as sps +import sparse as sps # type: ignore[import] import xarray as xr -def read_weights(weights, n_in, n_out): +def read_weights( + weights: str | Path | xr.Dataset | xr.DataArray | sps.COO | dict, # type: ignore[no-untyped-def] + n_in: int, + n_out: int, +) -> xr.DataArray: """ Read regridding weights into a DataArray (sparse COO matrix). @@ -25,8 +30,8 @@ def read_weights(weights, n_in, n_out): ``(N_out, N_in)`` will be the shape of the returning sparse matrix. They are the total number of grid boxes in input and output grids:: - N_in = Nx_in * Ny_in - N_out = Nx_out * Ny_out + N_in = Nx_in * Ny_in + N_out = Nx_out * Ny_out We need them because the shape cannot always be inferred from the largest column and row indices, due to unmapped grid boxes. @@ -34,45 +39,49 @@ def read_weights(weights, n_in, n_out): Returns ------- xr.DataArray - A DataArray backed by a sparse.COO array, with dims ('out_dim', 'in_dim') - and size (n_out, n_in). + A DataArray backed by a sparse.COO array, with dims ('out_dim', 'in_dim') + and size (n_out, n_in). """ if isinstance(weights, (str, Path, xr.Dataset, dict)): - weights = _parse_coords_and_values(weights, n_in, n_out) + return _parse_coords_and_values(weights, n_in, n_out) - elif isinstance(weights, sps.COO): - weights = xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') + if isinstance(weights, sps.COO): + return xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') - elif not isinstance(weights, xr.DataArray): - raise ValueError(f'Weights of type {type(weights)} not understood.') + if isinstance(weights, xr.DataArray): # type: ignore[no-untyped-def] + return weights - return weights + raise ValueError(f'Weights of type {type(weights)} not understood.') -def _parse_coords_and_values(indata, n_in, n_out): +def _parse_coords_and_values( + indata: str | Path | xr.Dataset | dict, # type: ignore[no-untyped-def] + n_in: int, + n_out: int, +) -> xr.DataArray: """Creates a sparse.COO array from weights stored in a dict-like fashion. Parameters ---------- indata: str, Path, xr.Dataset or dict - A dictionary as returned by ESMF.Regrid.get_weights_dict - or an xarray Dataset (or its path) as saved by xESMF. + A dictionary as returned by ESMF.Regrid.get_weights_dict + or an xarray Dataset (or its path) as saved by xESMF. n_in : int - The number of points in the input grid. + The number of points in the input grid. n_out : int - The number of points in the output grid. 
+ The number of points in the output grid. Returns ------- sparse.COO - Sparse array in the COO format. + Sparse array in the COO format. """ if isinstance(indata, (str, Path, xr.Dataset)): if not isinstance(indata, xr.Dataset): if not Path(indata).exists(): raise IOError(f'Weights file not found on disk.\n{indata}') - ds_w = xr.open_dataset(indata) + ds_w = xr.open_dataset(indata) # type: ignore[no-untyped-def] else: ds_w = indata @@ -82,9 +91,9 @@ def _parse_coords_and_values(indata, n_in, n_out): 'values of weights.' ) - col = ds_w['col'].values - 1 # Python starts with 0 - row = ds_w['row'].values - 1 - s = ds_w['S'].values + col = ds_w['col'].values - 1 # type: ignore[no-untyped-def] + row = ds_w['row'].values - 1 # type: ignore[no-untyped-def] + s = ds_w['S'].values # type: ignore[no-untyped-def] elif isinstance(indata, dict): if not {'col_src', 'row_dst', 'weights'}.issubset(indata.keys()): @@ -100,28 +109,33 @@ def _parse_coords_and_values(indata, n_in, n_out): return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights') -def check_shapes(indata, weights, shape_in, shape_out): +def check_shapes( + indata: np.ndarray, # type: ignore[no-untyped-def] + weights: np.ndarray, # type: ignore[no-untyped-def] + shape_in: Tuple[int, int], + shape_out: Tuple[int, int], +) -> None: """Compare the shapes of the input array, the weights and the regridder and raises potential errors. Parameters ---------- indata : array - Input array with the two spatial dimensions at the end, - which should fit shape_in. + Input array with the two spatial dimensions at the end, + which should fit shape_in. weights : array - Weights 2D array of shape (out_dim, in_dim). - First element should be the product of shape_out. - Second element should be the product of shape_in. + Weights 2D array of shape (out_dim, in_dim). + First element should be the product of shape_out. + Second element should be the product of shape_in. shape_in : 2-tuple of int - Shape of the input of the Regridder. + Shape of the input of the Regridder. shape_out : 2-tuple of int - Shape of the output of the Regridder. + Shape of the output of the Regridder. Raises ------ ValueError - If any of the conditions is not respected. + If any of the conditions is not respected. """ # COO matrix is fast with F-ordered array but slow with C-array, so we # take in a C-ordered and then transpose) @@ -155,14 +169,19 @@ def check_shapes(indata, weights, shape_in, shape_out): raise ValueError('ny_out * nx_out should equal to weights.shape[0]') -def apply_weights(weights, indata, shape_in, shape_out): +def apply_weights( + weights: np.ndarray, # type: ignore[no-untyped-def] + indata: np.ndarray, # type: ignore[no-untyped-def] + shape_in: Tuple[int, int], + shape_out: Tuple[int, int], +) -> np.ndarray[Any, np.dtype[Any]]: """ Apply regridding weights to data. Parameters ---------- weights : sparse COO matrix - Regridding weights. + Regridding weights. indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``. Should be C-ordered. Will be then tranposed to F-ordered. shape_in, shape_out : tuple of two integers @@ -200,7 +219,7 @@ def apply_weights(weights, indata, shape_in, shape_out): return outdata -def add_nans_to_weights(weights): +def add_nans_to_weights(weights: xr.DataArray) -> xr.DataArray: """Add NaN in empty rows of the regridding weights sparse matrix. By default, empty rows in the weights sparse matrix are interpreted as zeroes. 
This can become problematic @@ -210,12 +229,12 @@ def add_nans_to_weights(weights): Parameters ---------- weights : DataArray backed by a sparse.COO array - Sparse weights matrix. + Sparse weights matrix. Returns ------- DataArray backed by a sparse.COO array - Sparse weights matrix. + Sparse weights matrix. """ # Taken from @trondkr and adapted by @raphaeldussin to use `lil`. @@ -231,7 +250,11 @@ def add_nans_to_weights(weights): return weights -def _combine_weight_multipoly(weights, areas, indexes): +def _combine_weight_multipoly( + weights: xr.DataArray, + areas: np.ndarray[Any, np.dtype[Any]], + indexes: np.ndarray[Any, np.dtype[Any]], +) -> xr.DataArray: """Reduce a weight sparse matrix (csc format) by combining (adding) columns. This is used to sum individual weight matrices from multi-part geometries. @@ -239,17 +262,17 @@ def _combine_weight_multipoly(weights, areas, indexes): Parameters ---------- weights : DataArray - Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim') + Usually backed by a sparse.COO array, with dims ('out_dim', 'in_dim') areas : np.array - Array of destination areas, following same order as weights. + Array of destination areas, following same order as weights. indexes : array of integers - Columns with the same "index" will be summed into a single column at this - index in the output matrix. + Columns with the same "index" will be summed into a single column at this + index in the output matrix. Returns ------- sparse matrix (CSC) - Sum of weights from individual geometries. + Sum of weights from individual geometries. """ sub_weights = weights.rename(out_dim='subgeometries') diff --git a/xesmf/util.py b/xesmf/util.py index 4790b112..b78db6bd 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -1,3 +1,4 @@ +from typing import Any, Generator, List, Literal, Tuple import warnings import numpy as np @@ -8,7 +9,7 @@ LAT_CF_ATTRS = {'standard_name': 'latitude', 'units': 'degrees_north'} -def _grid_1d(start_b, end_b, step): +def _grid_1d(start_b: float, end_b: float, step: float): """ 1D grid centers and bounds @@ -33,7 +34,14 @@ def _grid_1d(start_b, end_b, step): return centers, bounds -def grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat): +def grid_2d( + lon0_b: float, + lon1_b: float, + d_lon: float, + lat0_b: float, + lat1_b: float, + d_lat: float, +) -> xr.Dataset: """ 2D rectilinear grid centers and bounds @@ -75,7 +83,14 @@ def grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat): return ds -def cf_grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat): +def cf_grid_2d( + lon0_b: float, + lon1_b: float, + d_lon: float, + lat0_b: float, + lat1_b: float, + d_lat: float, +) -> xr.Dataset: """ CF compliant 2D rectilinear grid centers and bounds. @@ -126,21 +141,26 @@ def cf_grid_2d(lon0_b, lon1_b, d_lon, lat0_b, lat1_b, d_lat): return ds -def grid_global(d_lon, d_lat, cf=False, lon1=180): +def grid_global( + d_lon: float, + d_lat: float, + cf: bool = False, + lon1: Literal[180, 360] = 180, +) -> xr.Dataset: """ Global 2D rectilinear grid centers and bounds Parameters ---------- d_lon : float - Longitude step size, i.e. grid resolution + Longitude step size, i.e. grid resolution d_lat : float - Latitude step size, i.e. grid resolution + Latitude step size, i.e. grid resolution cf : bool - Return a CF compliant grid. + Return a CF compliant grid. lon1 : {180, 360} - Right longitude bound. According to which convention is used longitudes will - vary from -180 to 180 or from 0 to 360. + Right longitude bound. 
According to which convention is used longitudes will + vary from -180 to 180 or from 0 to 360. Returns ------- @@ -168,7 +188,9 @@ def grid_global(d_lon, d_lat, cf=False, lon1=180): return grid_2d(lon0, lon1, d_lon, -90, 90, d_lat) -def _flatten_poly_list(polys): +def _flatten_poly_list( + polys: List[Polygon], +) -> Generator[tuple[int, Any] | tuple[int, Polygon], Any, None]: """Iterator flattening MultiPolygons.""" for i, poly in enumerate(polys): if isinstance(poly, MultiPolygon): @@ -178,7 +200,9 @@ def _flatten_poly_list(polys): yield (i, poly) -def split_polygons_and_holes(polys): +def split_polygons_and_holes( + polys: List[Polygon], +) -> Tuple[List[Polygon], List[Polygon], List[int], List[int]]: """Split the exterior boundaries and the holes for a list of polygons. If MultiPolygons are encountered in the list, they are flattened out @@ -195,14 +219,14 @@ def split_polygons_and_holes(polys): holes : list of Polygons Holes of the polygons as polygons i_ext : list of integers - The index in `polys` of each polygon in `exteriors`. + The index in `polys` of each polygon in `exteriors`. i_hol : list of integers - The index in `polys` of the owner of each hole in `holes`. + The index in `polys` of the owner of each hole in `holes`. """ - exteriors = [] - holes = [] - i_ext = [] - i_hol = [] + exteriors: List[Polygon] = [] + holes: List[Polygon] = [] + i_ext: List[int] = [] + i_hol: List[int] = [] for i, poly in _flatten_poly_list(polys): exteriors.append(Polygon(poly.exterior)) i_ext.append(i) @@ -218,19 +242,19 @@ def split_polygons_and_holes(polys): HUGE = 1.0e30 -def simple_tripolar_grid(nlons, nlats, lat_cap=60, lon_cut=-300): +def simple_tripolar_grid(nlons: int, nlats: int, lat_cap: float = 60, lon_cut: float = -300): """Generate a simple tripolar grid, regular under `lat_cap`. Parameters ---------- nlons: int - Number of longitude points. + Number of longitude points. nlats: int - Number of latitude points. + Number of latitude points. lat_cap: float - Latitude of the northern cap. + Latitude of the northern cap. lon_cut: float - Longitude of the periodic boundary. + Longitude of the periodic boundary. """ @@ -258,7 +282,13 @@ def simple_tripolar_grid(nlons, nlats, lat_cap=60, lon_cut=-300): # rather than using the package as a dependency -def _bipolar_projection(lamg, phig, lon_bp, rp, metrics_only=False): +def _bipolar_projection( + lamg: float, + phig: float, + lon_bp: float, + rp: float, + metrics_only: bool = False, +): """Makes a stereographic bipolar projection of the input coordinate mesh (lamg,phig) Returns the projected coordinate mesh and their metric coefficients (h^-1). 
The input mesh must be a regular spherical grid capping the pole with: @@ -328,7 +358,13 @@ def _bipolar_projection(lamg, phig, lon_bp, rp, metrics_only=False): return h_i_inv, h_j_inv -def _generate_bipolar_cap_mesh(Ni, Nj_ncap, lat0_bp, lon_bp, ensure_nj_even=True): +def _generate_bipolar_cap_mesh( + Ni: float, + Nj_ncap: float, + lat0_bp: float, + lon_bp: float, + ensure_nj_even: bool = True, +): # Define a (lon,lat) coordinate mesh on the Northern hemisphere of the globe sphere # such that the resolution of latg matches the desired resolution of the final grid along the symmetry meridian print('Generating bipolar grid bounded at latitude ', lat0_bp) @@ -350,7 +386,7 @@ def _generate_bipolar_cap_mesh(Ni, Nj_ncap, lat0_bp, lon_bp, ensure_nj_even=True return lams, phis, h_i_inv, h_j_inv -def _mdist(x1, x2): +def _mdist(x1: float, x2: float) -> float: """Returns positive distance modulo 360.""" return np.minimum(np.mod(x1 - x2, 360.0), np.mod(x2 - x1, 360.0)) From f89305712195334318d400cd7c8252fcaebf5bc0 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Thu, 12 Oct 2023 21:27:24 +0530 Subject: [PATCH 02/16] fixed frontend --- xesmf/frontend.py | 58 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 2582ed24..ebdbfc32 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -10,7 +10,7 @@ import xarray as xr from shapely.geometry import LineString, Polygon from xarray import DataArray, Dataset -from typing import Any, Literal, Optional, Tuple +from typing import Any, Hashable, List, Literal, Optional, Tuple from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize from .smm import ( _combine_weight_multipoly, @@ -426,7 +426,7 @@ def w(self) -> xr.DataArray: dims = 'y_out', 'x_out', 'y_in', 'x_in' return xr.DataArray(data, dims=dims) - def _get_default_filename(self): + def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc filename = '{0}_{1}x{2}_{3}x{4}'.format( self.method, @@ -458,7 +458,14 @@ def _compute_weights(self): esmf_regrid_finalize(regrid) # only need weights, not regrid object return w - def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None): + def __call__( + self, + indata: np.ndarray | dask_array_type | xr.DataArray | xr.Dataset, + keep_attrs: bool = False, + skipna: bool = False, + na_thres: float = 1.0, + output_chunks=None, + ): """ Apply regridding to input data. 
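As a reading aid for the typed `__call__` signature above, here is a minimal usage sketch. It is not part of the patch: the grid resolutions and the synthetic field are illustrative assumptions, and only entry points that appear elsewhere in this series (`grid_global`, `Regridder`) are used.

    import numpy as np
    import xesmf as xe

    ds_in = xe.util.grid_global(20, 15)   # coarse source grid
    ds_out = xe.util.grid_global(5, 4)    # finer destination grid

    # Synthetic smooth field, built without mutating the coordinate arrays.
    ds_in['data'] = 2 + np.cos(np.deg2rad(ds_in['lat'])) ** 2 * np.cos(
        2 * np.deg2rad(ds_in['lon'])
    )

    regridder = xe.Regridder(ds_in, ds_out, 'bilinear', periodic=True)
    # na_thres is the tolerated fraction of missing source data per
    # destination point, matching the keyword typed above.
    dr_out = regridder(ds_in['data'], skipna=True, na_thres=1.0)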
@@ -561,7 +568,15 @@ def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_
         raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!')
 
     @staticmethod
-    def _regrid(indata, weights, *, shape_in, shape_out, skipna, na_thres):
+    def _regrid(
+        indata: np.ndarray,
+        weights: sps.coo_matrix,
+        *,
+        shape_in: Tuple[int, int],
+        shape_out: Tuple[int, int],
+        skipna: bool,
+        na_thres: float,
+    ) -> np.ndarray:
         # skipna: set missing values to zero
         if skipna:
             missing = np.isnan(indata)
@@ -580,7 +595,14 @@ def _regrid(indata, weights, *, shape_in, shape_out, skipna, na_thres):
 
         return outdata
 
-    def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunks=None):
+    def regrid_array(
+        self,
+        indata: np.ndarray | dask_array_type,
+        weights: sps.coo_matrix,
+        skipna: bool = False,
+        na_thres: float = 1.0,
+        output_chunks: Optional[Tuple[int, ...]] = None,
+    ):
         """See __call__()."""
         if self.sequence_in:
             indata = np.reshape(indata, (*indata.shape[:-1], 1, indata.shape[-1]))
@@ -619,14 +641,14 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk
             outdata = self._regrid(indata, weights, **kwargs)
         return outdata
 
-    def regrid_numpy(self, indata, **kwargs):
+    def regrid_numpy(self, indata: dask_array_type, **kwargs):
         warnings.warn(
             '`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.',
             category=FutureWarning,
         )
         return self.regrid_array(indata, self.weights.data, **kwargs)
 
-    def regrid_dask(self, indata, **kwargs):
+    def regrid_dask(self, indata: dask_array_type, **kwargs):
         warnings.warn(
             '`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.',
             category=FutureWarning,
         )
         return self.regrid_array(indata, self.weights.data, **kwargs)
 
     def regrid_dataarray(
         self,
-        dr_in,
+        dr_in: xr.DataArray,
         keep_attrs: bool = False,
         skipna: bool = False,
         na_thres: float = 1.0,
-        output_chunks=None,
+        output_chunks: Optional[Tuple[int, ...]] = None,
     ):
         """See __call__()."""
 
     def regrid_dataset(
         self,
-        ds_in,
+        ds_in: xr.Dataset,
         keep_attrs: bool = False,
         skipna: bool = False,
         na_thres: float = 1.0,
-        output_chunks=None,
+        output_chunks: Optional[Tuple[int, ...]] = None,
     ):
         """See __call__()."""
 
@@ -693,7 +715,9 @@ def regrid_dataset(
 
         return self._format_xroutput(ds_out, temp_horiz_dims)
 
-    def _parse_xrinput(self, dr_in):
+    def _parse_xrinput(
+        self, dr_in: xr.DataArray | xr.Dataset
+    ) -> Tuple[Tuple[Hashable, ...], List[str]]:
         # dr could be a DataArray or a Dataset
         # Get input horiz dim names and set output horiz dim names
         if self.in_horiz_dims is not None and all(dim in dr_in.dims for dim in self.in_horiz_dims):
@@ -720,15 +744,17 @@ def _parse_xrinput(self, dr_in):
             )
 
         if self.sequence_out:
-            temp_horiz_dims = ['dummy', 'locations']
+            temp_horiz_dims: List[str] = ['dummy', 'locations']
         else:
-            temp_horiz_dims = [s + '_new' for s in input_horiz_dims]
+            temp_horiz_dims: List[str] = [s + '_new' for s in input_horiz_dims]
 
         if self.sequence_in and not self.sequence_out:
             temp_horiz_dims = ['dummy_new'] + temp_horiz_dims
 
         return input_horiz_dims, temp_horiz_dims
 
-    def _format_xroutput(self, out, new_dims=None):
+    def _format_xroutput(
+        self, out: xr.DataArray | xr.Dataset, new_dims: List[str]
+    ) -> xr.DataArray | xr.Dataset:
         out.attrs['regrid_method'] = self.method
         return out
 
@@ -752,7 +778,7 @@ def __repr__(self):
 
         return info
 
-    def to_netcdf(self, filename: Optional[str] = None):
+    def to_netcdf(self, filename: Optional[str] = None) 
-> str: """Save weights to disk as a netCDF file.""" if filename is None: filename = self.filename From be821eb120ffc27991d4b70d43baa89acf23a2d0 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Sat, 28 Oct 2023 16:27:20 +0530 Subject: [PATCH 03/16] added typing --- xesmf/backend.py | 105 ++++++---- xesmf/data.py | 16 +- xesmf/frontend.py | 496 ++++++++++++++++++++++++++++------------------ xesmf/smm.py | 108 +++++----- xesmf/util.py | 79 +++++--- 5 files changed, 471 insertions(+), 333 deletions(-) diff --git a/xesmf/backend.py b/xesmf/backend.py index 7b87cc60..ff1733dc 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -16,6 +16,7 @@ """ import os +from typing import Optional import warnings try: @@ -37,8 +38,8 @@ def warn_f_contiguous(a: np.ndarray) -> None: ---------- a : numpy array """ - if not a.flags['F_CONTIGUOUS']: - warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.') + if not a.flags["F_CONTIGUOUS"]: + warnings.warn("Input array is not F_CONTIGUOUS. " "Will affect performance.") def warn_lat_range(lat: np.ndarray) -> None: @@ -53,17 +54,17 @@ def warn_lat_range(lat: np.ndarray) -> None: lat : numpy array """ if (lat.max() > 90.0) or (lat.min() < -90.0): - warnings.warn('Latitude is outside of [-90, 90]') + warnings.warn("Latitude is outside of [-90, 90]") class Grid(ESMF.Grid): @classmethod def from_xarray( cls, - lon: np.ndarray[float, int], - lat: np.ndarray[float, int], + lon: np.ndarray, + lat: np.ndarray, periodic: bool = False, - mask=None, + mask: Optional[np.ndarray] = None, ): """ Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid. @@ -103,8 +104,8 @@ def from_xarray( # ESMF.Grid can actually take 3D array (lon, lat, radius), # but regridding only works for 2D array - assert lon.ndim == 2, 'Input grid must be 2D array' - assert lon.shape == lat.shape, 'lon and lat must have same shape' + assert lon.ndim == 2, "Input grid must be 2D array" + assert lon.shape == lat.shape, "lon and lat must have same shape" staggerloc = ESMF.StaggerLoc.CENTER # actually just integer 0 @@ -142,10 +143,13 @@ def from_xarray( grid_mask = mask.astype(np.int32) if not (grid_mask.shape == lon.shape): raise ValueError( - 'mask must have the same shape as the latitude/longitude' - 'coordinates, got: mask.shape = %s, lon.shape = %s' % (mask.shape, lon.shape) + "mask must have the same shape as the latitude/longitude" + "coordinates, got: mask.shape = %s, lon.shape = %s" + % (mask.shape, lon.shape) ) - grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False) + grid.add_item( + ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False + ) grid.mask[0][:] = grid_mask return grid @@ -157,7 +161,7 @@ def get_shape(self, loc=ESMF.StaggerLoc.CENTER): class LocStream(ESMF.LocStream): @classmethod - def from_xarray(cls, lon, lat): + def from_xarray(cls, lon: np.ndarray, lat: np.ndarray) -> ESMF.LocStream: """ Create an ESMF.LocStream object, for contrusting ESMF.Field and ESMF.Regrid @@ -172,9 +176,9 @@ def from_xarray(cls, lon, lat): """ if len(lon.shape) > 1: - raise ValueError('lon can only be 1d') + raise ValueError("lon can only be 1d") if len(lat.shape) > 1: - raise ValueError('lat can only be 1d') + raise ValueError("lat can only be 1d") assert lon.shape == lat.shape @@ -182,8 +186,8 @@ def from_xarray(cls, lon, lat): locstream = cls(location_count, coord_sys=ESMF.CoordSys.SPH_DEG) - locstream['ESMF:Lon'] = lon.astype(np.dtype('f8')) - locstream['ESMF:Lat'] = lat.astype(np.dtype('f8')) + 
locstream["ESMF:Lon"] = lon.astype(np.dtype("f8")) + locstream["ESMF:Lat"] = lat.astype(np.dtype("f8")) return locstream @@ -218,12 +222,14 @@ def add_corner(grid, lon_b, lat_b): warn_lat_range(lat_b) - assert lon_b.ndim == 2, 'Input grid must be 2D array' - assert lon_b.shape == lat_b.shape, 'lon_b and lat_b must have same shape' - assert np.array_equal(lon_b.shape, grid.max_index + 1), 'lon_b should be size (Nx+1, Ny+1)' + assert lon_b.ndim == 2, "Input grid must be 2D array" + assert lon_b.shape == lat_b.shape, "lon_b and lat_b must have same shape" + assert np.array_equal( + lon_b.shape, grid.max_index + 1 + ), "lon_b should be size (Nx+1, Ny+1)" assert (grid.num_peri_dims == 0) and ( grid.periodic_dim is None - ), 'Cannot add corner for periodic grid' + ), "Cannot add corner for periodic grid" grid.add_coords(staggerloc=staggerloc) @@ -236,7 +242,7 @@ def add_corner(grid, lon_b, lat_b): class Mesh(ESMF.Mesh): @classmethod - def from_polygons(cls, polys, element_coords='centroid'): + def from_polygons(cls, polys, element_coords="centroid"): """ Create an ESMF.Mesh object from a list of polygons. @@ -260,13 +266,13 @@ def from_polygons(cls, polys, element_coords='centroid'): node_num = sum(len(e.exterior.coords) - 1 for e in polys) elem_num = len(polys) # Pre alloc arrays. Special structure for coords makes the code faster. - crd_dt = np.dtype([('x', np.float32), ('y', np.float32)]) + crd_dt = np.dtype([("x", np.float32), ("y", np.float32)]) node_coords = np.empty(node_num, dtype=crd_dt) node_coords[:] = (np.nan, np.nan) # Fill with impossible values element_types = np.empty(elem_num, dtype=np.uint32) element_conn = np.empty(node_num, dtype=np.uint32) # Flag for centroid calculation - calc_centroid = isinstance(element_coords, str) and element_coords == 'centroid' + calc_centroid = isinstance(element_coords, str) and element_coords == "centroid" if calc_centroid: element_coords = np.empty(elem_num, dtype=crd_dt) inode = 0 @@ -309,7 +315,7 @@ def from_polygons(cls, polys, element_coords='centroid'): ) except ValueError as err: raise ValueError( - 'ESMF failed to create the Mesh, this usually happen when some polygons are invalid (test with `poly.is_valid`)' + "ESMF failed to create the Mesh, this usually happen when some polygons are invalid (test with `poly.is_valid`)" ) from err return mesh @@ -394,56 +400,66 @@ def esmf_regrid_build( # use shorter, clearer names for options in ESMF.RegridMethod method_dict = { - 'bilinear': ESMF.RegridMethod.BILINEAR, - 'conservative': ESMF.RegridMethod.CONSERVE, - 'conservative_normed': ESMF.RegridMethod.CONSERVE, - 'patch': ESMF.RegridMethod.PATCH, - 'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD, - 'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS, + "bilinear": ESMF.RegridMethod.BILINEAR, + "conservative": ESMF.RegridMethod.CONSERVE, + "conservative_normed": ESMF.RegridMethod.CONSERVE, + "patch": ESMF.RegridMethod.PATCH, + "nearest_s2d": ESMF.RegridMethod.NEAREST_STOD, + "nearest_d2s": ESMF.RegridMethod.NEAREST_DTOS, } try: esmf_regrid_method = method_dict[method] except Exception: - raise ValueError('method should be chosen from ' '{}'.format(list(method_dict.keys()))) + raise ValueError( + "method should be chosen from " "{}".format(list(method_dict.keys())) + ) # use shorter, clearer names for options in ESMF.ExtrapMethod extrap_dict = { - 'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG, - 'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD, + "inverse_dist": ESMF.ExtrapMethod.NEAREST_IDAVG, + "nearest_s2d": ESMF.ExtrapMethod.NEAREST_STOD, None: None, } 
try: esmf_extrap_method = extrap_dict[extrap_method] except KeyError: raise KeyError( - '`extrap_method` should be chosen from ' '{}'.format(list(extrap_dict.keys())) + "`extrap_method` should be chosen from " "{}".format( + list(extrap_dict.keys()) + ) ) # until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible # if files are written on disk if (extrap_method is not None) & (filename is not None): - raise ValueError('`extrap_method` cannot be used along with `filename`.') + raise ValueError("`extrap_method` cannot be used along with `filename`.") # conservative regridding needs cell corner information - if method in ['conservative', 'conservative_normed']: + if method in ["conservative", "conservative_normed"]: if not isinstance(sourcegrid, ESMF.Mesh) and not sourcegrid.has_corners: raise ValueError( - 'source grid has no corner information. ' 'cannot use conservative regridding.' + "source grid has no corner information. " + "cannot use conservative regridding." ) if not isinstance(destgrid, ESMF.Mesh) and not destgrid.has_corners: raise ValueError( - 'destination grid has no corner information. ' 'cannot use conservative regridding.' + "destination grid has no corner information. " + "cannot use conservative regridding." ) # ESMF.Regrid requires Field (Grid+data) as input, not just Grid. # Extra dimensions are specified when constructing the Field objects, # not when constructing the Regrid object later on. if isinstance(sourcegrid, ESMF.Mesh): - sourcefield = ESMF.Field(sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) + sourcefield = ESMF.Field( + sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims + ) else: sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims) if isinstance(destgrid, ESMF.Mesh): - destfield = ESMF.Field(destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) + destfield = ESMF.Field( + destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims + ) else: destfield = ESMF.Field(destgrid, ndbounds=extra_dims) @@ -460,11 +476,11 @@ def esmf_regrid_build( if filename is not None: assert not os.path.exists( filename - ), 'Weight file already exists! Please remove it or use a new name.' + ), "Weight file already exists! Please remove it or use a new name." # re-normalize conservative regridding results # https://github.com/JiaweiZhuang/xESMF/issues/17 - if method == 'conservative_normed': + if method == "conservative_normed": norm_type = ESMF.NormType.FRACAREA else: norm_type = ESMF.NormType.DSTAREA @@ -571,7 +587,7 @@ def esmf_regrid_finalize(regrid): def esmf_locstream(lon, lat): warnings.warn( - '`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`', + "`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`", DeprecationWarning, ) return LocStream.from_xarray(lon, lat) @@ -579,6 +595,7 @@ def esmf_locstream(lon, lat): def esmf_grid(lon, lat, periodic=False, mask=None): warnings.warn( - '`esmf_grid` is being deprecated in favor of `Grid.from_xarray`', DeprecationWarning + "`esmf_grid` is being deprecated in favor of `Grid.from_xarray`", + DeprecationWarning, ) return Grid.from_xarray(lon, lat) diff --git a/xesmf/data.py b/xesmf/data.py index 5d011ebc..6b9a8ac3 100644 --- a/xesmf/data.py +++ b/xesmf/data.py @@ -2,21 +2,23 @@ Standard test data for regridding benchmark. 
""" +from typing import Any import numpy as np +import numpy.typing as npt import xarray def wave_smooth( # type: ignore - lon: np.ndarray[float] | xarray.DataArray, # type: ignore - lat: np.ndarray[float] | xarray.DataArray, # type: ignore -) -> np.ndarray[float] | xarray.DataArray: # type: ignore + lon: npt.NDArray[np.floating[Any]] | xarray.DataArray, + lat: npt.NDArray[np.floating[Any]] | xarray.DataArray, +) -> npt.NDArray[np.floating[Any]] | xarray.DataArray: """ Spherical harmonic with low frequency. Parameters ---------- lon, lat : 2D numpy array or xarray DataArray - Longitute/Latitude of cell centers + Longitude/Latitude of cell centers Returns ------- @@ -40,8 +42,8 @@ def wave_smooth( # type: ignore 137(6), 1721-1741. """ # degree to radius, make a copy - lat *= np.pi / 180.0 # type: ignore - lon *= np.pi / 180.0 # type: ignore + lat *= np.pi / 180.0 + lon *= np.pi / 180.0 - f = 2 + pow(np.cos(lat), 2) * np.cos(2 * lon) # type: ignore + f = 2 + pow(np.cos(lat), 2) * np.cos(2 * lon) return f diff --git a/xesmf/frontend.py b/xesmf/frontend.py index ebdbfc32..d4e5621d 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -6,12 +6,20 @@ import cf_xarray as cfxr import numpy as np +import numpy.typing as npt import sparse as sps import xarray as xr -from shapely.geometry import LineString, Polygon +from shapely.geometry import LineString, Polygon, MultiPolygon from xarray import DataArray, Dataset -from typing import Any, Hashable, List, Literal, Optional, Tuple -from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize +from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple +from .backend import ( + Grid, + LocStream, + Mesh, + add_corner, + esmf_regrid_build, + esmf_regrid_finalize, +) from .smm import ( _combine_weight_multipoly, _parse_coords_and_values, @@ -31,95 +39,125 @@ def subset_regridder( - ds_out, - ds_in, - method, + ds_in: DataArray | Dataset | dict[str, DataArray], + ds_out: DataArray | Dataset | dict[str, DataArray], + method: Literal[ + "bilinear", + "conservative", + "conservative_normed", + "patch", + "nearest_s2d", + "nearest_d2s", + ], in_dims, out_dims, - locstream_in, - locstream_out, - periodic, + locstream_in: bool, + locstream_out: bool, + periodic: bool, **kwargs, ): """Compute subset of weights""" - kwargs.pop('filename', None) # Don't save subset of weights - kwargs.pop('reuse_weights', None) + kwargs.pop("filename", None) # Don't save subset of weights + kwargs.pop("reuse_weights", None) # Renaming dims to original names for the subset regridding if locstream_in: - ds_in = ds_in.rename({'x_in': in_dims[0]}) + ds_in = ds_in.rename({"x_in": in_dims[0]}) else: - ds_in = ds_in.rename({'y_in': in_dims[0], 'x_in': in_dims[1]}) + ds_in = ds_in.rename({"y_in": in_dims[0], "x_in": in_dims[1]}) if locstream_out: - ds_out = ds_out.rename({'x_out': out_dims[1]}) + ds_out = ds_out.rename({"x_out": out_dims[1]}) else: - ds_out = ds_out.rename({'y_out': out_dims[0], 'x_out': out_dims[1]}) + ds_out = ds_out.rename({"y_out": out_dims[0], "x_out": out_dims[1]}) regridder = Regridder( - ds_in, ds_out, method, locstream_in, locstream_out, periodic, parallel=False, **kwargs + ds_in, + ds_out, + method, + locstream_in, + locstream_out, + periodic, + parallel=False, + **kwargs, ) return regridder.w -def as_2d_mesh(lon, lat): +def as_2d_mesh( + lon: DataArray | npt.NDArray[np.float256], + lat: DataArray | npt.NDArray[np.float256], +) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | 
npt.NDArray[np.float256]]: if (lon.ndim, lat.ndim) == (2, 2): - assert lon.shape == lat.shape, 'lon and lat should have same shape' + assert lon.shape == lat.shape, "lon and lat should have same shape" elif (lon.ndim, lat.ndim) == (1, 1): lon, lat = np.meshgrid(lon, lat) else: - raise ValueError('lon and lat should be both 1D or 2D') + raise ValueError("lon and lat should be both 1D or 2D") return lon, lat -def _get_lon_lat(ds): +def _get_lon_lat( + ds: Dataset | Dict[str, npt.NDArray[np.float256]] +) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | npt.NDArray[np.float256]]: """Return lon and lat extracted from ds.""" - if ('lat' in ds and 'lon' in ds) or ('lat' in ds.coords and 'lon' in ds.coords): + if ("lat" in ds and "lon" in ds) or ("lat" in ds.coords and "lon" in ds.coords): # Old way. - return ds['lon'], ds['lat'] + return ds["lon"], ds["lat"] # else : cf-xarray way try: - lon = ds.cf['longitude'] - lat = ds.cf['latitude'] + lon = ds.cf["longitude"] + lat = ds.cf["latitude"] except (KeyError, AttributeError, ValueError): # KeyError if cfxr doesn't detect the coords # AttributeError if ds is a dict - raise ValueError('dataset must include lon/lat or be CF-compliant') + raise ValueError("dataset must include lon/lat or be CF-compliant") return lon, lat -def _get_lon_lat_bounds(ds): +def _get_lon_lat_bounds( + ds: Dataset | Dict[str, npt.NDArray[np.float256]] +) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | npt.NDArray[np.float256]]: """Return bounds of lon and lat extracted from ds.""" - if 'lat_b' in ds and 'lon_b' in ds: + if "lat_b" in ds and "lon_b" in ds: # Old way. - return ds['lon_b'], ds['lat_b'] + return ds["lon_b"], ds["lat_b"] # else : cf-xarray way - if 'longitude' not in ds.cf.coordinates: + if "longitude" not in ds.cf.coordinates: # If we are here, _get_lon_lat() didn't fail, thus we should be able to guess the coords. ds = ds.cf.guess_coord_axis() try: - lon_bnds = ds.cf.get_bounds('longitude') - lat_bnds = ds.cf.get_bounds('latitude') + lon_bnds = ds.cf.get_bounds("longitude") + lat_bnds = ds.cf.get_bounds("latitude") except KeyError: # bounds are not already present - if ds.cf['longitude'].ndim > 1: + if ds.cf["longitude"].ndim > 1: # We cannot infer 2D bounds, raise KeyError as custom "lon_b" is missing. - raise KeyError('lon_b') - lon_name = ds.cf['longitude'].name - lat_name = ds.cf['latitude'].name + raise KeyError("lon_b") + lon_name = ds.cf["longitude"].name + lat_name = ds.cf["latitude"].name ds = ds.cf.add_bounds([lon_name, lat_name]) - lon_bnds = ds.cf.get_bounds('longitude') - lat_bnds = ds.cf.get_bounds('latitude') + lon_bnds = ds.cf.get_bounds("longitude") + lat_bnds = ds.cf.get_bounds("latitude") # Convert from CF bounds to xESMF bounds. # order=None is because we don't want to assume the dimension order for 2D bounds. - lon_b = cfxr.bounds_to_vertices(lon_bnds, ds.cf.get_bounds_dim_name('longitude'), order=None) - lat_b = cfxr.bounds_to_vertices(lat_bnds, ds.cf.get_bounds_dim_name('latitude'), order=None) + lon_b = cfxr.bounds_to_vertices( + lon_bnds, ds.cf.get_bounds_dim_name("longitude"), order=None + ) + lat_b = cfxr.bounds_to_vertices( + lat_bnds, ds.cf.get_bounds_dim_name("latitude"), order=None + ) return lon_b, lat_b -def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): +def ds_to_ESMFgrid( + ds: Dataset | Dict[str, npt.NDArray[np.float256]], + need_bounds: bool = False, + periodic: bool = False, + append=None, +): """ Convert xarray DataSet or dictionary to ESMF.Grid object. 
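For reviewers of the `Dataset | Dict[...]` union above: xESMF also accepts plain dictionaries as grids, which is what this annotation spells out. A small sketch (values are illustrative; 1D lon/lat are expanded to 2D by `as_2d_mesh`, and `"mask"` is the optional extra key read by `ds_to_ESMFgrid`):

    import numpy as np
    import xesmf as xe

    # Dict-based grids: "lon"/"lat" are required, "mask" is optional.
    grid_in = {"lon": np.arange(-180, 180, 10.0), "lat": np.arange(-85, 90, 10.0)}
    grid_out = {"lon": np.arange(-180, 180, 5.0), "lat": np.arange(-85, 90, 5.0)}

    # nearest_s2d needs no cell bounds, so bare cell centers are enough here.
    regridder = xe.Regridder(grid_in, grid_out, "nearest_s2d")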
@@ -151,7 +189,7 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): # use np.asarray(dr) instead of dr.values, so it also works for dictionary lon, lat = _get_lon_lat(ds) - if hasattr(lon, 'dims'): + if hasattr(lon, "dims"): if lon.ndim == 1: dim_names = lat.dims + lon.dims else: @@ -160,8 +198,8 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): dim_names = None lon, lat = as_2d_mesh(np.asarray(lon), np.asarray(lat)) - if 'mask' in ds: - mask = np.asarray(ds['mask']) + if "mask" in ds: + mask = np.asarray(ds["mask"]) else: mask = None @@ -195,16 +233,16 @@ def ds_to_ESMFlocstream(ds): """ lon, lat = _get_lon_lat(ds) - if hasattr(lon, 'dims'): + if hasattr(lon, "dims"): dim_names = lon.dims else: dim_names = None lon, lat = np.asarray(lon), np.asarray(lat) if len(lon.shape) > 1: - raise ValueError('lon can only be 1d') + raise ValueError("lon can only be 1d") if len(lat.shape) > 1: - raise ValueError('lat can only be 1d') + raise ValueError("lat can only be 1d") assert lon.shape == lat.shape @@ -213,7 +251,7 @@ def ds_to_ESMFlocstream(ds): return locstream, (1,) + lon.shape, dim_names -def polys_to_ESMFmesh(polys): +def polys_to_ESMFmesh(polys) -> tuple[Mesh, tuple[Literal[1], int]]: """ Convert a sequence of shapely Polygons to a ESMF.Mesh object. @@ -234,7 +272,7 @@ def polys_to_ESMFmesh(polys): ext, holes, _, _ = split_polygons_and_holes(polys) if len(holes) > 0: warnings.warn( - 'Some passed polygons have holes, those are not represented in the returned Mesh.' + "Some passed polygons have holes, those are not represented in the returned Mesh." ) return Mesh.from_polygons(ext), (1, len(ext)) @@ -242,12 +280,12 @@ def polys_to_ESMFmesh(polys): class BaseRegridder(object): def __init__( self, - grid_in: Grid, - grid_out: Grid, + grid_in: Grid | LocStream | Mesh, + grid_out: Grid | LocStream | Mesh, method: str, filename: Optional[str] = None, reuse_weights: bool = False, - extrap_method: Optional[Literal['inverse_dist', 'nearest_s2d']] = None, + extrap_method: Optional[Literal["inverse_dist", "nearest_s2d"]] = None, extrap_dist_exponent: Optional[float] = None, extrap_num_src_pnts: Optional[int] = None, weights: Optional[Any] = None, @@ -256,7 +294,7 @@ def __init__( output_dims: Optional[Tuple[str, ...]] = None, unmapped_to_nan: bool = False, parallel: bool = False, - ): + ) -> None: """ Base xESMF regridding class supporting ESMF objects: `Grid`, `Mesh` and `LocStream`. @@ -350,17 +388,22 @@ def __init__( self.extrap_dist_exponent = extrap_dist_exponent self.extrap_num_src_pnts = extrap_num_src_pnts self.ignore_degenerate = ignore_degenerate - self.periodic = getattr(self.grid_in, 'periodic_dim', None) is not None + self.periodic = getattr(self.grid_in, "periodic_dim", None) is not None self.sequence_in = isinstance(self.grid_in, (LocStream, Mesh)) self.sequence_out = isinstance(self.grid_out, (LocStream, Mesh)) if input_dims is not None and len(input_dims) != int(not self.sequence_in) + 1: - raise ValueError(f'Wrong number of dimension names in `input_dims` ({len(input_dims)}.') + raise ValueError( + f"Wrong number of dimension names in `input_dims` ({len(input_dims)}." + ) self.in_horiz_dims = input_dims - if output_dims is not None and len(output_dims) != int(not self.sequence_out) + 1: + if ( + output_dims is not None + and len(output_dims) != int(not self.sequence_out) + 1 + ): raise ValueError( - f'Wrong number of dimension names in `output dims` ({len(output_dims)}.' 
+ f"Wrong number of dimension names in `output dims` ({len(output_dims)}." ) self.out_horiz_dims = output_dims @@ -373,7 +416,9 @@ def __init__( # some logic about reusing weights with either filename or weights args if reuse_weights and (filename is None) and (weights is None): - raise ValueError('To reuse weights, you need to provide either filename or weights.') + raise ValueError( + "To reuse weights, you need to provide either filename or weights." + ) if not parallel: if not reuse_weights and weights is None: @@ -397,13 +442,15 @@ def __init__( self.to_netcdf(filename=filename) # set default weights filename if none given - self.filename = self._get_default_filename() if filename is None else filename + self.filename = ( + self._get_default_filename() if filename is None else filename + ) @property - def A(self): + def A(self) -> DataArray: message = ( - 'regridder.A is deprecated and will be removed in future versions. ' - 'Use regridder.weights instead.' + "regridder.A is deprecated and will be removed in future versions. " + "Use regridder.weights instead." ) warnings.warn(message, DeprecationWarning) @@ -423,12 +470,12 @@ def w(self) -> xr.DataArray: # TODO: Add coords ? s = self.shape_out + self.shape_in data = self.weights.data.reshape(s) - dims = 'y_out', 'x_out', 'y_in', 'x_in' + dims = "y_out", "x_out", "y_in", "x_in" return xr.DataArray(data, dims=dims) def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc - filename = '{0}_{1}x{2}_{3}x{4}'.format( + filename = "{0}_{1}x{2}_{3}x{4}".format( self.method, self.shape_in[0], self.shape_in[1], @@ -437,9 +484,9 @@ def _get_default_filename(self) -> str: ) if self.periodic: - filename += '_peri.nc' + filename += "_peri.nc" else: - filename += '.nc' + filename += ".nc" return filename @@ -460,11 +507,11 @@ def _compute_weights(self): def __call__( self, - indata: np.ndarray | dask_array_type | xr.DataArray | xr.Dataset, + indata: npt.NDArray[np.float256] | dask_array_type | xr.DataArray | xr.Dataset, keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks=None, + output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, ): """ Apply regridding to input data. @@ -565,18 +612,20 @@ def __call__( output_chunks=output_chunks, ) else: - raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!') + raise TypeError( + "input must be numpy array, dask array, xarray DataArray or Dataset!" + ) @staticmethod def _regrid( - indata: np.ndarray, + indata: npt.NDArray[np.float256], weights: sps.coo_matrix, *, shape_in: Tuple[int, int], shape_out: Tuple[int, int], skipna: bool, na_thresh: float, - ) -> np.ndarray: + ) -> npt.NDArray[np.float256]: # skipna: set missing values to zero if skipna: missing = np.isnan(indata) @@ -587,7 +636,9 @@ def _regrid( # skipna: Compute the influence of missing data at each interpolation point and filter those not meeting acceptable threshold. 
if skipna: - fraction_valid = apply_weights(weights, (~missing).astype('d'), shape_in, shape_out) + fraction_valid = apply_weights( + weights, (~missing).astype("d"), shape_in, shape_out + ) tol = 1e-6 bad = fraction_valid < np.clip(1 - na_thres, tol, 1 - tol) fraction_valid[bad] = 1 @@ -597,11 +648,11 @@ def _regrid( def regrid_array( self, - indata: np.ndarray | dask_array_type, + indata: npt.NDArray[np.float256] | dask_array_type, weights: sps.coo_matrix, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Tuple[int, ...]] = None, + output_chunks: Optional[Tuple[int, ...] | Dict[str, int]] = None, ): """See __call__().""" if self.sequence_in: @@ -609,11 +660,11 @@ def regrid_array( # If output_chunk is dict, order output chunks to match order of out_horiz_dims and convert to tuple if isinstance(output_chunks, dict): - output_chunks = tuple([output_chunks.get(key) for key in self.out_horiz_dims]) + output_chunks = tuple(map(output_chunks.get, self.out_horiz_dims)) kwargs = { - 'shape_in': self.shape_in, - 'shape_out': self.shape_out, + "shape_in": self.shape_in, + "shape_out": self.shape_out, } check_shapes(indata, weights, **kwargs) @@ -624,33 +675,40 @@ def regrid_array( if isinstance(indata, dask_array_type): # dask if output_chunks is None: output_chunks = tuple( - [min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:])] + min(shp, inchnk) + for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:]) ) if len(output_chunks) != len(self.shape_out): if len(output_chunks) == 1 and self.sequence_out: output_chunks = (1, output_chunks[0]) else: raise ValueError( - f'output_chunks must have same dimension as ds_out,' - f' output_chunks dimension ({len(output_chunks)}) does not ' - f'match ds_out dimension ({len(self.shape_out)})' + f"output_chunks must have same dimension as ds_out," + f" output_chunks dimension ({len(output_chunks)}) does not " + f"match ds_out dimension ({len(self.shape_out)})" ) - weights = da.from_array(weights, chunks=(output_chunks + indata.chunksize[-2:])) + weights = da.from_array( + weights, chunks=(output_chunks + indata.chunksize[-2:]) + ) outdata = self._regrid(indata, weights, **kwargs) else: # numpy outdata = self._regrid(indata, weights, **kwargs) return outdata - def regrid_numpy(self, indata: dask_array_type, **kwargs): + def regrid_numpy( + self, indata: npt.NDArray[np.float256], **kwargs + ) -> npt.NDArray[np.float256]: warnings.warn( - '`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.', + "`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) - def regrid_dask(self, indata: dask_array_type, **kwargs): + def regrid_dask( + self, indata: npt.NDArray[np.float256], **kwargs + ) -> npt.NDArray[np.float256]: warnings.warn( - '`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.', + "`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) @@ -662,7 +720,7 @@ def regrid_dataarray( skipna: bool = False, na_thres: float = 1.0, output_chunks: Optional[Tuple[int, ...]] = None, - ): + ) -> DataArray | Dataset: """See __call__().""" input_horiz_dims, temp_horiz_dims = self._parse_xrinput(dr_in) @@ -672,9 +730,9 @@ def regrid_dataarray( dr_in, self.weights, kwargs=kwargs, - input_core_dims=[input_horiz_dims, ('out_dim', 'in_dim')], + 
input_core_dims=[input_horiz_dims, ("out_dim", "in_dim")], output_core_dims=[temp_horiz_dims], - dask='allowed', + dask="allowed", keep_attrs=keep_attrs, ) @@ -687,7 +745,7 @@ def regrid_dataset( skipna: bool = False, na_thres: float = 1.0, output_chunks: Optional[Tuple[int, ...]] = None, - ): + ) -> DataArray | Dataset: """See __call__().""" # get the first data variable to infer input_core_dims @@ -707,9 +765,9 @@ def regrid_dataset( ds_in, self.weights, kwargs=kwargs, - input_core_dims=[input_horiz_dims, ('out_dim', 'in_dim')], + input_core_dims=[input_horiz_dims, ("out_dim", "in_dim")], output_core_dims=[temp_horiz_dims], - dask='allowed', + dask="allowed", keep_attrs=keep_attrs, ) @@ -720,7 +778,9 @@ def _parse_xrinput( ) -> Tuple[Tuple[Hashable, ...], List[str]]: # dr could be a DataArray or a Dataset # Get input horiz dim names and set output horiz dim names - if self.in_horiz_dims is not None and all(dim in dr_in.dims for dim in self.in_horiz_dims): + if self.in_horiz_dims is not None and all( + dim in dr_in.dims for dim in self.in_horiz_dims + ): input_horiz_dims = self.in_horiz_dims else: if isinstance(dr_in, Dataset): @@ -737,36 +797,36 @@ def _parse_xrinput( # help user debugging invalid horizontal dimensions warnings.warn( ( - f'Using dimensions {input_horiz_dims} from data variable {name} ' - 'as the horizontal dimensions for the regridding.' + f"Using dimensions {input_horiz_dims} from data variable {name} " + "as the horizontal dimensions for the regridding." ), UserWarning, ) if self.sequence_out: - temp_horiz_dims: List[str] = ['dummy', 'locations'] + temp_horiz_dims: List[str] = ["dummy", "locations"] else: - temp_horiz_dims: List[str] = [s + '_new' for s in input_horiz_dims] + temp_horiz_dims: List[str] = [s + "_new" for s in input_horiz_dims] if self.sequence_in and not self.sequence_out: - temp_horiz_dims = ['dummy_new'] + temp_horiz_dims + temp_horiz_dims = ["dummy_new"] + temp_horiz_dims return input_horiz_dims, temp_horiz_dims def _format_xroutput( self, out: xr.DataArray | xr.Dataset, new_dims: List[str] ) -> xr.DataArray | xr.Dataset: - out.attrs['regrid_method'] = self.method + out.attrs["regrid_method"] = self.method return out - def __repr__(self): + def __repr__(self) -> str: info = ( - 'xESMF Regridder \n' - 'Regridding algorithm: {} \n' - 'Weight filename: {} \n' - 'Reuse pre-computed weights? {} \n' - 'Input grid shape: {} \n' - 'Output grid shape: {} \n' - 'Periodic in longitude? {}'.format( + "xESMF Regridder \n" + "Regridding algorithm: {} \n" + "Weight filename: {} \n" + "Reuse pre-computed weights? {} \n" + "Input grid shape: {} \n" + "Output grid shape: {} \n" + "Periodic in longitude? 
{}".format( self.method, self.filename, self.reuse_weights, @@ -783,9 +843,13 @@ def to_netcdf(self, filename: Optional[str] = None) -> str: if filename is None: filename = self.filename w = self.weights.data - dim = 'n_s' + dim = "n_s" ds = xr.Dataset( - {'S': (dim, w.data), 'col': (dim, w.coords[1, :] + 1), 'row': (dim, w.coords[0, :] + 1)} + { + "S": (dim, w.data), + "col": (dim, w.coords[1, :] + 1), + "row": (dim, w.coords[0, :] + 1), + } ) ds.to_netcdf(filename) return filename @@ -794,17 +858,22 @@ def to_netcdf(self, filename: Optional[str] = None) -> str: class Regridder(BaseRegridder): def __init__( self, - ds_in: xr.DataArray | xr.Dataset | dict, - ds_out: xr.DataArray | xr.Dataset | dict, + ds_in: xr.DataArray | xr.Dataset | dict[str, xr.DataArray], + ds_out: xr.DataArray | xr.Dataset | dict[str, xr.DataArray], method: Literal[ - 'bilinear', 'conservative', 'conservative_normed', 'patch', 'nearest_s2d', 'nearest_d2s' + "bilinear", + "conservative", + "conservative_normed", + "patch", + "nearest_s2d", + "nearest_d2s", ], locstream_in: bool = False, locstream_out: bool = False, periodic: bool = False, parallel: bool = False, **kwargs, - ): + ) -> None: """ Make xESMF regridder @@ -900,30 +969,30 @@ def __init__( ------- regridder : xESMF regridder object """ - methods_avail_ls_in = ['nearest_s2d', 'nearest_d2s'] - methods_avail_ls_out = ['bilinear', 'patch'] + methods_avail_ls_in + methods_avail_ls_in = ["nearest_s2d", "nearest_d2s"] + methods_avail_ls_out = ["bilinear", "patch"] + methods_avail_ls_in if locstream_in and method not in methods_avail_ls_in: raise ValueError( - f'locstream input is only available for method in {methods_avail_ls_in}' + f"locstream input is only available for method in {methods_avail_ls_in}" ) if locstream_out and method not in methods_avail_ls_out: raise ValueError( - f'locstream output is only available for method in {methods_avail_ls_out}' + f"locstream output is only available for method in {methods_avail_ls_out}" ) - reuse_weights = kwargs.get('reuse_weights', False) + reuse_weights = kwargs.get("reuse_weights", False) - weights = kwargs.get('weights', None) + weights = kwargs.get("weights", None) if parallel and (reuse_weights or weights is not None): parallel = False warnings.warn( - 'Cannot use parallel=True when reuse_weights=True or when weights is not None. Building Regridder normally.' + "Cannot use parallel=True when reuse_weights=True or when weights is not None. Building Regridder normally." 
) # Record basic switches - if method in ['conservative', 'conservative_normed']: + if method in ["conservative", "conservative_normed"]: need_bounds = True periodic = False # bound shape will not be N+1 for periodic grid else: @@ -946,7 +1015,9 @@ def __init__( if locstream_out: grid_out, shape_out, output_dims = ds_to_ESMFlocstream(ds_out) else: - grid_out, shape_out, output_dims = ds_to_ESMFgrid(ds_out, need_bounds=need_bounds) + grid_out, shape_out, output_dims = ds_to_ESMFgrid( + ds_out, need_bounds=need_bounds + ) # Create the BaseRegridder super().__init__( @@ -963,24 +1034,28 @@ def __init__( lon_out, lat_out = _get_lon_lat(ds_out) if not isinstance(lon_out, DataArray): if lon_out.ndim == 2: - dims = [('y', 'x'), ('y', 'x')] + dims = [("y", "x"), ("y", "x")] elif self.sequence_out: - dims = [('locations',), ('locations',)] + dims = [("locations",), ("locations",)] else: - dims = [('lon',), ('lat',)] - lon_out = xr.DataArray(lon_out, dims=dims[0], name='lon', attrs=LON_CF_ATTRS) - lat_out = xr.DataArray(lat_out, dims=dims[1], name='lat', attrs=LAT_CF_ATTRS) + dims = [("lon",), ("lat",)] + lon_out = xr.DataArray( + lon_out, dims=dims[0], name="lon", attrs=LON_CF_ATTRS + ) + lat_out = xr.DataArray( + lat_out, dims=dims[1], name="lat", attrs=LAT_CF_ATTRS + ) if lat_out.ndim == 2: self.out_horiz_dims = lat_out.dims elif self.sequence_out: if lat_out.dims != lon_out.dims: raise ValueError( - 'Regridder expects a locstream output, but the passed longitude ' - 'and latitude are not specified along the same dimension. ' - f'(lon: {lon_out.dims}, lat: {lat_out.dims})' + "Regridder expects a locstream output, but the passed longitude " + "and latitude are not specified along the same dimension. " + f"(lon: {lon_out.dims}, lat: {lat_out.dims})" ) - self.out_horiz_dims = ('dummy',) + lat_out.dims + self.out_horiz_dims = ("dummy",) + lat_out.dims else: self.out_horiz_dims = (lat_out.dims[0], lon_out.dims[0]) @@ -991,53 +1066,66 @@ def __init__( if set(self.out_horiz_dims).issuperset(crd.dims) } grid_mapping = { - var.attrs['grid_mapping'] + var.attrs["grid_mapping"] for var in ds_out.data_vars.values() - if 'grid_mapping' in var.attrs + if "grid_mapping" in var.attrs } if grid_mapping: - self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out}) + self.out_coords.update( + {gm: ds_out[gm] for gm in grid_mapping if gm in ds_out} + ) else: self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out} if parallel: self._init_para_regrid(ds_in, ds_out, kwargs) - def _init_para_regrid(self, ds_in, ds_out, kwargs): + def _init_para_regrid( + self, + ds_in: xr.Dataset, + ds_out: xr.Dataset, + kwargs: dict, + ): # Check if we have bounds as variable and not coords, and add them to coords in both datasets - if 'lon_b' in ds_out.data_vars: - ds_out = ds_out.set_coords(['lon_b', 'lat_b']) - if 'lon_b' in ds_in.data_vars: - ds_in = ds_in.set_coords(['lon_b', 'lat_b']) - if not (set(self.out_horiz_dims) - {'dummy'}).issubset(ds_out.chunksizes.keys()): + if "lon_b" in ds_out.data_vars: + ds_out = ds_out.set_coords(["lon_b", "lat_b"]) + if "lon_b" in ds_in.data_vars: + ds_in = ds_in.set_coords(["lon_b", "lat_b"]) + if not (set(self.out_horiz_dims) - {"dummy"}).issubset( + ds_out.chunksizes.keys() + ): raise ValueError( - 'Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. ' - 'If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks.' 
+ "Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. " + "If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks." ) # Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data if self.sequence_out: ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars) ds_out = ds_out.drop_dims(ds_out_dims_drop) else: - if 'mask' in ds_out: + if "mask" in ds_out: mask = ds_out.mask ds_out = ds_out.coords.to_dataset() - ds_out['mask'] = mask + ds_out["mask"] = mask else: - ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims]) + ds_out_chunks = tuple( + [ds_out.chunksizes[i] for i in self.out_horiz_dims] + ) ds_out = ds_out.coords.to_dataset() mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks) - ds_out['mask'] = (self.out_horiz_dims, mask) + ds_out["mask"] = (self.out_horiz_dims, mask) ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference( - ['longitude', 'latitude'] + ["longitude", "latitude"] ) ds_out = ds_out.cf.drop_dims(ds_out_dims_drop) # Drop unnecessary variables in ds_in to save memory if not self.sequence_in: # Drop unnecessary dims - ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(['longitude', 'latitude']) + ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference( + ["longitude", "latitude"] + ) ds_in = ds_in.cf.drop_dims(ds_in_dims_drop) # Drop unnecessary vars @@ -1048,35 +1136,37 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs): ds_in = ds_in.compute() # if bounds in ds_out, we switch to cf bounds for map_blocks - if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim): + if "lon_b" in ds_out and (ds_out.lon_b.ndim == ds_out.cf["longitude"].ndim): ds_out = ds_out.assign_coords( lon_bounds=cfxr.vertices_to_bounds( - ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims) + ds_out.lon_b, ("bounds", *ds_out.cf["longitude"].dims) ), lat_bounds=cfxr.vertices_to_bounds( - ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims) + ds_out.lat_b, ("bounds", *ds_out.cf["latitude"].dims) ), ) # Make cf-xarray aware of the new bounds - ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds' - ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds' + ds_out[ds_out.cf["longitude"].name].attrs["bounds"] = "lon_bounds" + ds_out[ds_out.cf["latitude"].name].attrs["bounds"] = "lat_bounds" ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims) # rename dims to avoid map_blocks confusing ds_in and ds_out dims. 
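        # (Editor's aside, not part of the patch: a minimal usage sketch of the
        # parallel path this method prepares. `xe.util.grid_global` and
        # `parallel=True` are existing xESMF API; the resolutions and chunk
        # sizes below are illustrative assumptions only.
        #
        #     import xesmf as xe
        #
        #     ds_in = xe.util.grid_global(2, 2)  # source grid, with lon_b/lat_b
        #     ds_out = xe.util.grid_global(1, 1).chunk({'y': 45, 'x': 90})  # chunked target
        #     regridder = xe.Regridder(ds_in, ds_out, 'bilinear', parallel=True)
        #
        # With parallel=True the weights are assembled chunk by chunk over the
        # output grid via xarray's map_blocks, which is why ds_out must carry
        # chunks along its spatial dimensions.)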
if self.sequence_in: - ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'}) + ds_in = ds_in.rename({self.in_horiz_dims[0]: "x_in"}) else: - ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'}) + ds_in = ds_in.rename( + {self.in_horiz_dims[0]: "y_in", self.in_horiz_dims[1]: "x_in"} + ) if self.sequence_out: - ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'}) - out_chunks = ds_out.chunks.get('x_out') + ds_out = ds_out.rename({self.out_horiz_dims[1]: "x_out"}) + out_chunks = ds_out.chunks.get("x_out") else: ds_out = ds_out.rename( - {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'} + {self.out_horiz_dims[0]: "y_out", self.out_horiz_dims[1]: "x_out"} ) - out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']] + out_chunks = [ds_out.chunks.get(k) for k in ["y_out", "x_out"]] - weights_dims = ('y_out', 'x_out', 'y_in', 'x_in') + weights_dims = ("y_out", "x_out", "y_in", "x_in") templ = sps.zeros((self.shape_out + self.shape_in)) w_templ = xr.DataArray(templ, dims=weights_dims).chunk( out_chunks @@ -1097,14 +1187,14 @@ def _init_para_regrid(self, ds_in, ds_out, kwargs): kwargs=kwargs, template=w_templ, ) - w = w.compute(scheduler='processes') + w = w.compute(scheduler="processes") weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:]) - weights.name = 'weights' + weights.name = "weights" self.weights = weights # follows legacy logic of writing weights if filename is provided - if 'filename' in kwargs: - filename = kwargs['filename'] + if "filename" in kwargs: + filename = kwargs["filename"] else: filename = None if filename is not None and not self.reuse_weights: @@ -1119,10 +1209,10 @@ def _format_xroutput(self, out, new_dims=None): out = out.rename({nd: od for nd, od in zip(new_dims, self.out_horiz_dims)}) out = out.assign_coords(**self.out_coords) - out.attrs['regrid_method'] = self.method + out.attrs["regrid_method"] = self.method if self.sequence_out: - out = out.squeeze(dim='dummy') + out = out.squeeze(dim="dummy") return out @@ -1130,15 +1220,15 @@ def _format_xroutput(self, out, new_dims=None): class SpatialAverager(BaseRegridder): def __init__( self, - ds_in, - polys, - ignore_holes=False, - periodic=False, - filename=None, - reuse_weights=False, - weights=None, - ignore_degenerate=False, - geom_dim_name='geom', + ds_in: xr.DataArray | xr.Dataset | dict, + polys: Sequence[Polygon | MultiPolygon], + ignore_holes: bool = False, + periodic: bool = False, + filename: Optional[str] = None, + reuse_weights: bool = False, + weights: Optional[sps.coo_matrix | dict | str | Dataset | "Path"] = None, + ignore_degenerate: bool = False, + geom_dim_name: str = "geom", ): """Compute the exact average of a gridded array over a geometry. @@ -1223,17 +1313,19 @@ def __init__( if isinstance(ds_in, xr.DataArray): ds_in = ds_in._to_temp_dataset() - grid_in, shape_in, input_dims = ds_to_ESMFgrid(ds_in, need_bounds=True, periodic=periodic) + grid_in, shape_in, input_dims = ds_to_ESMFgrid( + ds_in, need_bounds=True, periodic=periodic + ) # Create an output locstream so that the regridder knows the output shape and coords. # Latitude and longitude coordinates are the polygon centroid. 
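        # (Editor's aside, not part of the patch: a minimal, hypothetical sketch
        # of how SpatialAverager is driven; the polygon coordinates and the
        # variable name 'tas' are made up for illustration.
        #
        #     import numpy as np
        #     import xesmf as xe
        #     from shapely.geometry import Polygon
        #
        #     ds_in = xe.util.grid_global(1, 1)  # provides the lon_b/lat_b bounds required here
        #     ds_in['tas'] = (('y', 'x'), np.random.rand(180, 360))
        #     poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])  # lon/lat degrees
        #     savg = xe.SpatialAverager(ds_in, [poly], geom_dim_name='geom')
        #     regional_mean = savg(ds_in['tas'])  # DataArray with a new 'geom' dim
        # )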
lon_out, lat_out = _get_lon_lat(ds_in) - if hasattr(lon_out, 'name'): + if hasattr(lon_out, "name"): self._lon_out_name = lon_out.name self._lat_out_name = lat_out.name else: - self._lon_out_name = 'lon' - self._lat_out_name = 'lat' + self._lon_out_name = "lon" + self._lat_out_name = "lat" # Check length of polys segments self._check_polys_length(polys) @@ -1244,14 +1336,14 @@ def __init__( # We put names 'lon' and 'lat' so ds_to_ESMFlocstream finds them easily. # _lon_out_name and _lat_out_name are used on the output anyway. - ds_out = {'lon': self._lon_out, 'lat': self._lat_out} + ds_out = {"lon": self._lon_out, "lat": self._lat_out} locstream_out, shape_out, _ = ds_to_ESMFlocstream(ds_out) # BaseRegridder with custom-computed weights and dummy out grid super().__init__( grid_in, locstream_out, - 'conservative', + "conservative", input_dims=input_dims, weights=weights, filename=filename, @@ -1261,7 +1353,7 @@ def __init__( ) @staticmethod - def _check_polys_length(polys: List[Polygons], threshold: int = 1) -> None: + def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None: # Check length of polys segments, issue warning if too long check_polys, check_holes, _, _ = split_polygons_and_holes(polys) check_polys.extend(check_holes) @@ -1269,22 +1361,24 @@ def _check_polys_length(polys: List[Polygons], threshold: int = 1) -> None: for check_poly in check_polys: b = check_poly.boundary.coords # Length of each segment - poly_segments.extend([LineString(b[k : k + 2]).length for k in range(len(b) - 1)]) + poly_segments.extend( + [LineString(b[k : k + 2]).length for k in range(len(b) - 1)] + ) if np.any(np.array(poly_segments) > threshold): warnings.warn( - f'`polys` contains large (> {threshold}°) segments. This could lead to errors over large regions. For a more accurate average, segmentize (densify) your shapes with `shapely.segmentize(polys, {threshold})`', + f"`polys` contains large (> {threshold}°) segments. This could lead to errors over large regions. For a more accurate average, segmentize (densify) your shapes with `shapely.segmentize(polys, {threshold})`", UserWarning, stacklevel=2, ) - def _compute_weights_and_area(self, mesh_out) -> tuple[DataArray, Any]: + def _compute_weights_and_area(self, mesh_out: Mesh) -> tuple[DataArray, Any]: """Return the weights and the area of the destination mesh cells.""" # Build the regrid object regrid = esmf_regrid_build( self.grid_in, mesh_out, - method='conservative', + method="conservative", ignore_degenerate=self.ignore_degenerate, ) @@ -1335,7 +1429,7 @@ def _compute_weights(self) -> DataArray: w_int, area_int = self._compute_weights_and_area(mesh_int) # Append weights from holes as negative weights - w = xr.concat((w, -w_int), 'out_dim') + w = xr.concat((w, -w_int), "out_dim") # Append areas area = np.concatenate([area, area_int]) @@ -1353,12 +1447,12 @@ def w(self) -> xr.DataArray: """ s = self.shape_out[1:2] + self.shape_in data = self.weights.data.reshape(s) - dims = self.geom_dim_name, 'y_in', 'x_in' + dims = self.geom_dim_name, "y_in", "x_in" return xr.DataArray(data, dims=dims) def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc - filename = 'spatialavg_{0}x{1}_{2}.nc'.format( + filename = "spatialavg_{0}x{1}_{2}.nc".format( self.shape_in[0], self.shape_in[1], self.n_out ) @@ -1366,24 +1460,30 @@ def _get_default_filename(self) -> str: def __repr__(self) -> str: info = ( - f'xESMF SpatialAverager \n' - f'Weight filename: {self.filename} \n' - f'Reuse pre-computed weights? 
{self.reuse_weights} \n' - f'Input grid shape: {self.shape_in} \n' - f'Output list length: {self.n_out} \n' + f"xESMF SpatialAverager \n" + f"Weight filename: {self.filename} \n" + f"Reuse pre-computed weights: {self.reuse_weights} \n" + f"Input grid shape: {self.shape_in} \n" + f"Output list length: {self.n_out} \n" ) return info - def _format_xroutput(self, out, new_dims=None): - out = out.squeeze(dim='dummy') + def _format_xroutput( + self, out: DataArray | Dataset, new_dims=None + ) -> DataArray | Dataset: + out = out.squeeze(dim="dummy") # rename dimension name to match output grid out = out.rename(locations=self.geom_dim_name) # append output horizontal coordinate values # extra coordinates are automatically tracked by apply_ufunc - out.coords[self._lon_out_name] = xr.DataArray(self._lon_out, dims=(self.geom_dim_name,)) - out.coords[self._lat_out_name] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,)) - out.attrs['regrid_method'] = self.method + out.coords[self._lon_out_name] = xr.DataArray( + self._lon_out, dims=(self.geom_dim_name,) + ) + out.coords[self._lat_out_name] = xr.DataArray( + self._lat_out, dims=(self.geom_dim_name,) + ) + out.attrs["regrid_method"] = self.method return out diff --git a/xesmf/smm.py b/xesmf/smm.py index 3e355cee..ca2d8d47 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -7,12 +7,13 @@ import numba as nb # type: ignore[import] import numpy as np +import numpy.typing as npt import sparse as sps # type: ignore[import] import xarray as xr def read_weights( - weights: str | Path | xr.Dataset | xr.DataArray | sps.COO | dict, # type: ignore[no-untyped-def] + weights: str | Path | xr.Dataset | xr.DataArray | sps.COO | dict[str, Any], n_in: int, n_out: int, ) -> xr.DataArray: @@ -22,7 +23,7 @@ def read_weights( Parameters ---------- weights : str, Path, xr.Dataset, xr.DataArray, sparse.COO - Weights generated by ESMF. Can be a path to a netCDF file generated by ESMF, an xarray.Dataset, + Weights generated by ESMF. Can be a path to a netCDF file generated by ESMF, an xr.Dataset, a dictionary created by `ESMPy.api.Regrid.get_weights_dict` or directly the sparse array as returned by this function. @@ -46,16 +47,16 @@ def read_weights( return _parse_coords_and_values(weights, n_in, n_out) if isinstance(weights, sps.COO): - return xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') + return xr.DataArray(weights, dims=("out_dim", "in_dim"), name="weights") if isinstance(weights, xr.DataArray): # type: ignore[no-untyped-def] return weights - raise ValueError(f'Weights of type {type(weights)} not understood.') + raise ValueError(f"Weights of type {type(weights)} not understood.") def _parse_coords_and_values( - indata: str | Path | xr.Dataset | dict, # type: ignore[no-untyped-def] + indata: str | Path | xr.Dataset | dict[str, Any], n_in: int, n_out: int, ) -> xr.DataArray: @@ -80,38 +81,40 @@ def _parse_coords_and_values( if isinstance(indata, (str, Path, xr.Dataset)): if not isinstance(indata, xr.Dataset): if not Path(indata).exists(): - raise IOError(f'Weights file not found on disk.\n{indata}') + raise IOError(f"Weights file not found on disk.\n{indata}") ds_w = xr.open_dataset(indata) # type: ignore[no-untyped-def] else: ds_w = indata - if not {'col', 'row', 'S'}.issubset(ds_w.variables): + if not {"col", "row", "S"}.issubset(ds_w.variables): raise ValueError( - 'Weights dataset should have variables `col`, `row` and `S` storing the indices and ' - 'values of weights.' 
+                "Weights dataset should have variables `col`, `row` and `S` storing the indices "
+                "and values of weights."
             )
 
-    col = ds_w['col'].values - 1  # type: ignore[no-untyped-def]
-    row = ds_w['row'].values - 1  # type: ignore[no-untyped-def]
-    s = ds_w['S'].values  # type: ignore[no-untyped-def]
+    col = ds_w["col"].values - 1  # type: ignore[no-untyped-def]
+    row = ds_w["row"].values - 1  # type: ignore[no-untyped-def]
+    s = ds_w["S"].values  # type: ignore[no-untyped-def]
 
-    elif isinstance(indata, dict):
-        if not {'col_src', 'row_dst', 'weights'}.issubset(indata.keys()):
+    elif isinstance(indata, dict):  # type: ignore
+        if not {"col_src", "row_dst", "weights"}.issubset(indata.keys()):
             raise ValueError(
-                'Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing the '
-                'indices and values of weights.'
+                "Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing "
+                "the indices and values of weights."
             )
-        col = indata['col_src'] - 1
-        row = indata['row_dst'] - 1
-        s = indata['weights']
+        col = indata["col_src"] - 1
+        row = indata["row_dst"] - 1
+        s = indata["weights"]
 
     crds = np.stack([row, col])
-    return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights')
+    return xr.DataArray(
+        sps.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights"
+    )
 
 
 def check_shapes(
-    indata: np.ndarray,  # type: ignore[no-untyped-def]
-    weights: np.ndarray,  # type: ignore[no-untyped-def]
+    indata: npt.NDArray[Any],
+    weights: npt.NDArray[Any],
     shape_in: Tuple[int, int],
     shape_out: Tuple[int, int],
 ) -> None:
@@ -140,17 +143,17 @@ def check_shapes(
     # COO matrix is fast with F-ordered array but slow with C-array, so we
     # take in a C-ordered and then transpose)
     # (CSR or CRS matrix is fast with C-ordered array but slow with F-array)
-    if hasattr(indata, 'flags') and not indata.flags['C_CONTIGUOUS']:
-        warnings.warn('Input array is not C_CONTIGUOUS. ' 'Will affect performance.')
+    if hasattr(indata, "flags") and not indata.flags["C_CONTIGUOUS"]:
+        warnings.warn("Input array is not C_CONTIGUOUS. " "Will affect performance.")
 
     # Limitation from numba : some big-endian dtypes are not supported.
     try:
-        nb.from_dtype(indata.dtype)
-        nb.from_dtype(weights.dtype)
-    except (NotImplementedError, nb.core.errors.NumbaError):
+        nb.from_dtype(indata.dtype)  # type: ignore
+        nb.from_dtype(weights.dtype)  # type: ignore
+    except (NotImplementedError, nb.core.errors.NumbaError):  # type: ignore
         warnings.warn(
-            'Input array has a dtype not supported by sparse and numba.'
-            'Computation will fall back to scipy.'
+            "Input array has a dtype not supported by sparse and numba. "
+            "Computation will fall back to scipy."
         )
 
     # get input shape information
     shape_horiz = indata.shape[-2:]
 
     if shape_horiz != shape_in:
         raise ValueError(
-            f'The horizontal shape of input data is {shape_horiz}, different from that '
-            f'of the regridder {shape_in}!'
+            f"The horizontal shape of input data is {shape_horiz}, different from that "
+            f"of the regridder {shape_in}!"
) if shape_in[0] * shape_in[1] != weights.shape[1]: - raise ValueError('ny_in * nx_in should equal to weights.shape[1]') + raise ValueError("ny_in * nx_in should equal to weights.shape[1]") if shape_out[0] * shape_out[1] != weights.shape[0]: - raise ValueError('ny_out * nx_out should equal to weights.shape[0]') + raise ValueError("ny_out * nx_out should equal to weights.shape[0]") def apply_weights( - weights: np.ndarray, # type: ignore[no-untyped-def] - indata: np.ndarray, # type: ignore[no-untyped-def] + weights: npt.NDArray[Any], + indata: npt.NDArray[Any], shape_in: Tuple[int, int], shape_out: Tuple[int, int], -) -> np.ndarray[Any, np.dtype[Any]]: +) -> npt.NDArray[Any]: """ Apply regridding weights to data. @@ -183,7 +186,7 @@ def apply_weights( weights : sparse COO matrix Regridding weights. indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``. - Should be C-ordered. Will be then tranposed to F-ordered. + Should be C-ordered. Will be then transposed to F-ordered. shape_in, shape_out : tuple of two integers Input/output data shape. For rectilinear grid, it is just ``(n_lat, n_lon)``. @@ -199,10 +202,10 @@ def apply_weights( # Limitation from numba : some big-endian dtypes are not supported. indata_dtype = indata.dtype try: - nb.from_dtype(indata.dtype) - nb.from_dtype(weights.dtype) - except (NotImplementedError, nb.core.errors.NumbaError): - indata = indata.astype(' xr.DataArray: for krow in range(len(m.rows)): m.rows[krow] = [0] if m.rows[krow] == [] else m.rows[krow] m.data[krow] = [np.NaN] if m.data[krow] == [] else m.data[krow] + # update regridder weights (in COO) - weights = weights.copy(data=sps.COO.from_scipy_sparse(m)) + weights = weights.copy(data=sps.COO.from_scipy_sparse(m)) # type: ignore return weights -def _combine_weight_multipoly( +def _combine_weight_multipoly( # type: ignore weights: xr.DataArray, - areas: np.ndarray[Any, np.dtype[Any]], - indexes: np.ndarray[Any, np.dtype[Any]], + areas: npt.NDArray[np.integer[Any]], + indexes: npt.NDArray[np.integer[Any]], ) -> xr.DataArray: """Reduce a weight sparse matrix (csc format) by combining (adding) columns. @@ -275,7 +279,7 @@ def _combine_weight_multipoly( Sum of weights from individual geometries. 
""" - sub_weights = weights.rename(out_dim='subgeometries') + sub_weights = weights.rename(out_dim="subgeometries") # Create a sparse DataArray with the mesh areas # This ties the `out_dim` (the dimension for the original geometries) to the @@ -283,19 +287,21 @@ def _combine_weight_multipoly( crds = np.stack([indexes, np.arange(len(indexes))]) a = xr.DataArray( sps.COO(crds, areas, (indexes.max() + 1, len(indexes)), fill_value=0), - dims=('out_dim', 'subgeometries'), - name='area', + dims=("out_dim", "subgeometries"), + name="area", ) # Weight the regridding weights by the area of the destination polygon and sum over sub-geometries - out = (sub_weights * a).sum(dim='subgeometries') + out = (sub_weights * a).sum(dim="subgeometries") # Renormalize weights along in_dim - wsum = out.sum('in_dim') + wsum = out.sum("in_dim") # Change the fill_value to 1 wsum = wsum.copy( - data=sps.COO(wsum.data.coords, wsum.data.data, shape=wsum.data.shape, fill_value=1) + data=sps.COO( + wsum.data.coords, wsum.data.data, shape=wsum.data.shape, fill_value=1 + ) ) return out / wsum diff --git a/xesmf/util.py b/xesmf/util.py index b78db6bd..5e24bbb7 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -2,14 +2,17 @@ import warnings import numpy as np +import numpy.typing as npt import xarray as xr from shapely.geometry import MultiPolygon, Polygon -LON_CF_ATTRS = {'standard_name': 'longitude', 'units': 'degrees_east'} -LAT_CF_ATTRS = {'standard_name': 'latitude', 'units': 'degrees_north'} +LON_CF_ATTRS = {"standard_name": "longitude", "units": "degrees_east"} +LAT_CF_ATTRS = {"standard_name": "latitude", "units": "degrees_north"} -def _grid_1d(start_b: float, end_b: float, step: float): +def _grid_1d( + start_b: float, end_b: float, step: float +) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.floating[Any]]]: """ 1D grid centers and bounds @@ -73,10 +76,10 @@ def grid_2d( ds = xr.Dataset( coords={ - 'lon': (['y', 'x'], lon, {'standard_name': 'longitude'}), - 'lat': (['y', 'x'], lat, {'standard_name': 'latitude'}), - 'lon_b': (['y_b', 'x_b'], lon_b), - 'lat_b': (['y_b', 'x_b'], lat_b), + "lon": (["y", "x"], lon, {"standard_name": "longitude"}), + "lat": (["y", "x"], lat, {"standard_name": "latitude"}), + "lon_b": (["y_b", "x_b"], lon_b), + "lat_b": (["y_b", "x_b"], lat_b), } ) @@ -120,21 +123,21 @@ def cf_grid_2d( ds = xr.Dataset( coords={ - 'lon': ( - 'lon', + "lon": ( + "lon", lon_1d, - {'bounds': 'lon_bounds', **LON_CF_ATTRS}, + {"bounds": "lon_bounds", **LON_CF_ATTRS}, ), - 'lat': ( - 'lat', + "lat": ( + "lat", lat_1d, - {'bounds': 'lat_bounds', **LAT_CF_ATTRS}, + {"bounds": "lat_bounds", **LAT_CF_ATTRS}, ), - 'latitude_longitude': xr.DataArray(), + "latitude_longitude": xr.DataArray(), }, data_vars={ - 'lon_bounds': vertices_to_bounds(lon_b_1d, ('bound', 'lon')), - 'lat_bounds': vertices_to_bounds(lat_b_1d, ('bound', 'lat')), + "lon_bounds": vertices_to_bounds(lon_b_1d, ("bound", "lon")), + "lat_bounds": vertices_to_bounds(lat_b_1d, ("bound", "lat")), }, ) @@ -170,14 +173,12 @@ def grid_global( if not np.isclose(360 / d_lon, 360 // d_lon): warnings.warn( - '360 cannot be divided by d_lon = {}, ' - 'might not cover the globe uniformly'.format(d_lon) + f"360 cannot be divided by d_lon = {d_lon}, might not cover the globe uniformly" ) if not np.isclose(180 / d_lat, 180 // d_lat): warnings.warn( - '180 cannot be divided by d_lat = {}, ' - 'might not cover the globe uniformly'.format(d_lat) + f"180 cannot be divided by d_lat = {d_lat}, might not cover the globe uniformly" ) lon0 = lon1 - 360 @@ -242,7 
+243,12 @@ def split_polygons_and_holes( HUGE = 1.0e30 -def simple_tripolar_grid(nlons: int, nlats: int, lat_cap: float = 60, lon_cut: float = -300): +def simple_tripolar_grid( + nlons: int, + nlats: int, + lat_cap: float = 60, + lon_cut: float = -300, +) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[Any]]: """Generate a simple tripolar grid, regular under `lat_cap`. Parameters @@ -259,7 +265,7 @@ def simple_tripolar_grid(nlons: int, nlats: int, lat_cap: float = 60, lon_cut: f """ # first generate the bipolar cap for north poles - nj_cap = np.rint(nlats * lat_cap / 180.0).astype('int') + nj_cap = np.rint(nlats * lat_cap / 180.0).astype("int") lams, phis, _, _ = _generate_bipolar_cap_mesh( nlons, nj_cap, lat_cap, lon_cut, ensure_nj_even=True @@ -283,9 +289,9 @@ def simple_tripolar_grid(nlons: int, nlats: int, lat_cap: float = 60, lon_cut: f def _bipolar_projection( - lamg: float, - phig: float, - lon_bp: float, + lamg: npt.NDArray[np.floating[Any]], + phig: npt.NDArray[np.floating[Any]], + lon_bp: npt.NDArray[np.floating[Any]], rp: float, metrics_only: bool = False, ): @@ -314,7 +320,9 @@ def _bipolar_projection( # One way is simply to demand lamc to be continuous with lam on the equator phi=0 # I am sure there is a more mathematically concrete way to do this. lamc = np.where((lamg - lon_bp > 90) & (lamg - lon_bp <= 180), 180 - lamc, lamc) - lamc = np.where((lamg - lon_bp > 180) & (lamg - lon_bp <= 270), 180 + lamc, lamc) + lamc = np.where( + (lamg - lon_bp > 180) & (lamg - lon_bp <= 270), 180 + lamc, lamc + ) lamc = np.where((lamg - lon_bp > 270), 360 - lamc, lamc) # Along symmetry meridian choose lamc lamc = np.where( @@ -339,7 +347,9 @@ def _bipolar_projection( N_inv = 1 / N cos2phis = (np.cos(phis * PI_180)) ** 2 - h_j_inv_t1 = cos2phis * alpha2 * (1 - alpha2) * beta2_inv * (1 + beta2_inv) * (rden**2) + h_j_inv_t1 = ( + cos2phis * alpha2 * (1 - alpha2) * beta2_inv * (1 + beta2_inv) * (rden**2) + ) h_j_inv_t2 = M_inv * M_inv * (1 - alpha2) * rden h_j_inv = h_j_inv_t1 + h_j_inv_t2 @@ -347,7 +357,10 @@ def _bipolar_projection( h_j_inv = np.where(np.abs(beta2_inv) > HUGE, M_inv * M_inv, h_j_inv) h_j_inv = np.sqrt(h_j_inv) * N_inv - h_i_inv = cos2phis * (1 + beta2_inv) * (rden**2) + M_inv * M_inv * alpha2 * beta2_inv * rden + h_i_inv = ( + cos2phis * (1 + beta2_inv) * (rden**2) + + M_inv * M_inv * alpha2 * beta2_inv * rden + ) # Deal with beta=0 h_i_inv = np.where(np.abs(beta2_inv) > HUGE, M_inv * M_inv, h_i_inv) h_i_inv = np.sqrt(h_i_inv) @@ -367,22 +380,22 @@ def _generate_bipolar_cap_mesh( ): # Define a (lon,lat) coordinate mesh on the Northern hemisphere of the globe sphere # such that the resolution of latg matches the desired resolution of the final grid along the symmetry meridian - print('Generating bipolar grid bounded at latitude ', lat0_bp) + print("Generating bipolar grid bounded at latitude ", lat0_bp) if Nj_ncap % 2 != 0 and ensure_nj_even: - print(' Supergrid has an odd number of area cells!') + print(" Supergrid has an odd number of area cells!") if ensure_nj_even: print(" The number of j's is not even. 
Fixing this by cutting one row.") Nj_ncap = Nj_ncap - 1 lon_g = lon_bp + np.arange(Ni + 1) * 360.0 / float(Ni) - lamg = np.tile(lon_g, (Nj_ncap + 1, 1)) + lamg: float = np.tile(lon_g, (Nj_ncap + 1, 1)) latg0_cap = lat0_bp + np.arange(Nj_ncap + 1) * (90 - lat0_bp) / float(Nj_ncap) - phig = np.tile(latg0_cap.reshape((Nj_ncap + 1, 1)), (1, Ni + 1)) + phig: float = np.tile(latg0_cap.reshape((Nj_ncap + 1, 1)), (1, Ni + 1)) rp = np.tan(0.5 * (90 - lat0_bp) * PI_180) lams, phis, h_i_inv, h_j_inv = _bipolar_projection(lamg, phig, lon_bp, rp) h_i_inv = h_i_inv[:, :-1] * 2 * np.pi / float(Ni) h_j_inv = h_j_inv[:-1, :] * PI_180 * (90 - lat0_bp) / float(Nj_ncap) - print(' number of js=', phis.shape[0]) + print(" number of js=", phis.shape[0]) return lams, phis, h_i_inv, h_j_inv From 28521304ac1c0ef6f2bde48fb11d8b4fe625e357 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Sat, 28 Oct 2023 16:30:48 +0530 Subject: [PATCH 04/16] [Formatting] isort + black --- xesmf/backend.py | 36 +++++--------- xesmf/data.py | 1 + xesmf/frontend.py | 118 ++++++++++++---------------------------------- xesmf/smm.py | 10 ++-- xesmf/util.py | 15 ++---- 5 files changed, 49 insertions(+), 131 deletions(-) diff --git a/xesmf/backend.py b/xesmf/backend.py index ff1733dc..61d7db42 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -16,13 +16,14 @@ """ import os -from typing import Optional import warnings +from typing import Optional try: import esmpy as ESMF except ImportError: import ESMF + import numpy as np import numpy.lib.recfunctions as nprec @@ -144,12 +145,9 @@ def from_xarray( if not (grid_mask.shape == lon.shape): raise ValueError( "mask must have the same shape as the latitude/longitude" - "coordinates, got: mask.shape = %s, lon.shape = %s" - % (mask.shape, lon.shape) + "coordinates, got: mask.shape = %s, lon.shape = %s" % (mask.shape, lon.shape) ) - grid.add_item( - ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False - ) + grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False) grid.mask[0][:] = grid_mask return grid @@ -224,9 +222,7 @@ def add_corner(grid, lon_b, lat_b): assert lon_b.ndim == 2, "Input grid must be 2D array" assert lon_b.shape == lat_b.shape, "lon_b and lat_b must have same shape" - assert np.array_equal( - lon_b.shape, grid.max_index + 1 - ), "lon_b should be size (Nx+1, Ny+1)" + assert np.array_equal(lon_b.shape, grid.max_index + 1), "lon_b should be size (Nx+1, Ny+1)" assert (grid.num_peri_dims == 0) and ( grid.periodic_dim is None ), "Cannot add corner for periodic grid" @@ -410,9 +406,7 @@ def esmf_regrid_build( try: esmf_regrid_method = method_dict[method] except Exception: - raise ValueError( - "method should be chosen from " "{}".format(list(method_dict.keys())) - ) + raise ValueError("method should be chosen from " "{}".format(list(method_dict.keys()))) # use shorter, clearer names for options in ESMF.ExtrapMethod extrap_dict = { @@ -424,9 +418,7 @@ def esmf_regrid_build( esmf_extrap_method = extrap_dict[extrap_method] except KeyError: raise KeyError( - "`extrap_method` should be chosen from " "{}".format( - list(extrap_dict.keys()) - ) + "`extrap_method` should be chosen from " "{}".format(list(extrap_dict.keys())) ) # until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible @@ -438,28 +430,22 @@ def esmf_regrid_build( if method in ["conservative", "conservative_normed"]: if not isinstance(sourcegrid, ESMF.Mesh) and not sourcegrid.has_corners: raise ValueError( - "source grid has no corner information. 
" - "cannot use conservative regridding." + "source grid has no corner information. " "cannot use conservative regridding." ) if not isinstance(destgrid, ESMF.Mesh) and not destgrid.has_corners: raise ValueError( - "destination grid has no corner information. " - "cannot use conservative regridding." + "destination grid has no corner information. " "cannot use conservative regridding." ) # ESMF.Regrid requires Field (Grid+data) as input, not just Grid. # Extra dimensions are specified when constructing the Field objects, # not when constructing the Regrid object later on. if isinstance(sourcegrid, ESMF.Mesh): - sourcefield = ESMF.Field( - sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims - ) + sourcefield = ESMF.Field(sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) else: sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims) if isinstance(destgrid, ESMF.Mesh): - destfield = ESMF.Field( - destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims - ) + destfield = ESMF.Field(destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) else: destfield = ESMF.Field(destgrid, ndbounds=extra_dims) diff --git a/xesmf/data.py b/xesmf/data.py index 6b9a8ac3..0505ae4e 100644 --- a/xesmf/data.py +++ b/xesmf/data.py @@ -3,6 +3,7 @@ """ from typing import Any + import numpy as np import numpy.typing as npt import xarray diff --git a/xesmf/frontend.py b/xesmf/frontend.py index d4e5621d..ff163b05 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -3,23 +3,17 @@ """ import warnings +from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple import cf_xarray as cfxr import numpy as np import numpy.typing as npt import sparse as sps import xarray as xr -from shapely.geometry import LineString, Polygon, MultiPolygon +from shapely.geometry import LineString, MultiPolygon, Polygon from xarray import DataArray, Dataset -from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple -from .backend import ( - Grid, - LocStream, - Mesh, - add_corner, - esmf_regrid_build, - esmf_regrid_finalize, -) + +from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize from .smm import ( _combine_weight_multipoly, _parse_coords_and_values, @@ -143,12 +137,8 @@ def _get_lon_lat_bounds( # Convert from CF bounds to xESMF bounds. # order=None is because we don't want to assume the dimension order for 2D bounds. - lon_b = cfxr.bounds_to_vertices( - lon_bnds, ds.cf.get_bounds_dim_name("longitude"), order=None - ) - lat_b = cfxr.bounds_to_vertices( - lat_bnds, ds.cf.get_bounds_dim_name("latitude"), order=None - ) + lon_b = cfxr.bounds_to_vertices(lon_bnds, ds.cf.get_bounds_dim_name("longitude"), order=None) + lat_b = cfxr.bounds_to_vertices(lat_bnds, ds.cf.get_bounds_dim_name("latitude"), order=None) return lon_b, lat_b @@ -393,15 +383,10 @@ def __init__( self.sequence_out = isinstance(self.grid_out, (LocStream, Mesh)) if input_dims is not None and len(input_dims) != int(not self.sequence_in) + 1: - raise ValueError( - f"Wrong number of dimension names in `input_dims` ({len(input_dims)}." - ) + raise ValueError(f"Wrong number of dimension names in `input_dims` ({len(input_dims)}.") self.in_horiz_dims = input_dims - if ( - output_dims is not None - and len(output_dims) != int(not self.sequence_out) + 1 - ): + if output_dims is not None and len(output_dims) != int(not self.sequence_out) + 1: raise ValueError( f"Wrong number of dimension names in `output dims` ({len(output_dims)}." 
) @@ -416,9 +401,7 @@ def __init__( # some logic about reusing weights with either filename or weights args if reuse_weights and (filename is None) and (weights is None): - raise ValueError( - "To reuse weights, you need to provide either filename or weights." - ) + raise ValueError("To reuse weights, you need to provide either filename or weights.") if not parallel: if not reuse_weights and weights is None: @@ -442,9 +425,7 @@ def __init__( self.to_netcdf(filename=filename) # set default weights filename if none given - self.filename = ( - self._get_default_filename() if filename is None else filename - ) + self.filename = self._get_default_filename() if filename is None else filename @property def A(self) -> DataArray: @@ -612,9 +593,7 @@ def __call__( output_chunks=output_chunks, ) else: - raise TypeError( - "input must be numpy array, dask array, xarray DataArray or Dataset!" - ) + raise TypeError("input must be numpy array, dask array, xarray DataArray or Dataset!") @staticmethod def _regrid( @@ -636,9 +615,7 @@ def _regrid( # skipna: Compute the influence of missing data at each interpolation point and filter those not meeting acceptable threshold. if skipna: - fraction_valid = apply_weights( - weights, (~missing).astype("d"), shape_in, shape_out - ) + fraction_valid = apply_weights(weights, (~missing).astype("d"), shape_in, shape_out) tol = 1e-6 bad = fraction_valid < np.clip(1 - na_thres, tol, 1 - tol) fraction_valid[bad] = 1 @@ -675,8 +652,7 @@ def regrid_array( if isinstance(indata, dask_array_type): # dask if output_chunks is None: output_chunks = tuple( - min(shp, inchnk) - for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:]) + min(shp, inchnk) for shp, inchnk in zip(self.shape_out, indata.chunksize[-2:]) ) if len(output_chunks) != len(self.shape_out): if len(output_chunks) == 1 and self.sequence_out: @@ -687,26 +663,20 @@ def regrid_array( f" output_chunks dimension ({len(output_chunks)}) does not " f"match ds_out dimension ({len(self.shape_out)})" ) - weights = da.from_array( - weights, chunks=(output_chunks + indata.chunksize[-2:]) - ) + weights = da.from_array(weights, chunks=(output_chunks + indata.chunksize[-2:])) outdata = self._regrid(indata, weights, **kwargs) else: # numpy outdata = self._regrid(indata, weights, **kwargs) return outdata - def regrid_numpy( - self, indata: npt.NDArray[np.float256], **kwargs - ) -> npt.NDArray[np.float256]: + def regrid_numpy(self, indata: npt.NDArray[np.float256], **kwargs) -> npt.NDArray[np.float256]: warnings.warn( "`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) - def regrid_dask( - self, indata: npt.NDArray[np.float256], **kwargs - ) -> npt.NDArray[np.float256]: + def regrid_dask(self, indata: npt.NDArray[np.float256], **kwargs) -> npt.NDArray[np.float256]: warnings.warn( "`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, @@ -778,9 +748,7 @@ def _parse_xrinput( ) -> Tuple[Tuple[Hashable, ...], List[str]]: # dr could be a DataArray or a Dataset # Get input horiz dim names and set output horiz dim names - if self.in_horiz_dims is not None and all( - dim in dr_in.dims for dim in self.in_horiz_dims - ): + if self.in_horiz_dims is not None and all(dim in dr_in.dims for dim in self.in_horiz_dims): input_horiz_dims = self.in_horiz_dims else: if isinstance(dr_in, Dataset): @@ -1015,9 +983,7 @@ def __init__( if locstream_out: grid_out, shape_out, 
output_dims = ds_to_ESMFlocstream(ds_out) else: - grid_out, shape_out, output_dims = ds_to_ESMFgrid( - ds_out, need_bounds=need_bounds - ) + grid_out, shape_out, output_dims = ds_to_ESMFgrid(ds_out, need_bounds=need_bounds) # Create the BaseRegridder super().__init__( @@ -1039,12 +1005,8 @@ def __init__( dims = [("locations",), ("locations",)] else: dims = [("lon",), ("lat",)] - lon_out = xr.DataArray( - lon_out, dims=dims[0], name="lon", attrs=LON_CF_ATTRS - ) - lat_out = xr.DataArray( - lat_out, dims=dims[1], name="lat", attrs=LAT_CF_ATTRS - ) + lon_out = xr.DataArray(lon_out, dims=dims[0], name="lon", attrs=LON_CF_ATTRS) + lat_out = xr.DataArray(lat_out, dims=dims[1], name="lat", attrs=LAT_CF_ATTRS) if lat_out.ndim == 2: self.out_horiz_dims = lat_out.dims @@ -1071,9 +1033,7 @@ def __init__( if "grid_mapping" in var.attrs } if grid_mapping: - self.out_coords.update( - {gm: ds_out[gm] for gm in grid_mapping if gm in ds_out} - ) + self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out}) else: self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out} @@ -1091,9 +1051,7 @@ def _init_para_regrid( ds_out = ds_out.set_coords(["lon_b", "lat_b"]) if "lon_b" in ds_in.data_vars: ds_in = ds_in.set_coords(["lon_b", "lat_b"]) - if not (set(self.out_horiz_dims) - {"dummy"}).issubset( - ds_out.chunksizes.keys() - ): + if not (set(self.out_horiz_dims) - {"dummy"}).issubset(ds_out.chunksizes.keys()): raise ValueError( "Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. " "If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks." @@ -1108,9 +1066,7 @@ def _init_para_regrid( ds_out = ds_out.coords.to_dataset() ds_out["mask"] = mask else: - ds_out_chunks = tuple( - [ds_out.chunksizes[i] for i in self.out_horiz_dims] - ) + ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims]) ds_out = ds_out.coords.to_dataset() mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks) ds_out["mask"] = (self.out_horiz_dims, mask) @@ -1123,9 +1079,7 @@ def _init_para_regrid( # Drop unnecessary variables in ds_in to save memory if not self.sequence_in: # Drop unnecessary dims - ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference( - ["longitude", "latitude"] - ) + ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(["longitude", "latitude"]) ds_in = ds_in.cf.drop_dims(ds_in_dims_drop) # Drop unnecessary vars @@ -1153,9 +1107,7 @@ def _init_para_regrid( if self.sequence_in: ds_in = ds_in.rename({self.in_horiz_dims[0]: "x_in"}) else: - ds_in = ds_in.rename( - {self.in_horiz_dims[0]: "y_in", self.in_horiz_dims[1]: "x_in"} - ) + ds_in = ds_in.rename({self.in_horiz_dims[0]: "y_in", self.in_horiz_dims[1]: "x_in"}) if self.sequence_out: ds_out = ds_out.rename({self.out_horiz_dims[1]: "x_out"}) @@ -1313,9 +1265,7 @@ def __init__( if isinstance(ds_in, xr.DataArray): ds_in = ds_in._to_temp_dataset() - grid_in, shape_in, input_dims = ds_to_ESMFgrid( - ds_in, need_bounds=True, periodic=periodic - ) + grid_in, shape_in, input_dims = ds_to_ESMFgrid(ds_in, need_bounds=True, periodic=periodic) # Create an output locstream so that the regridder knows the output shape and coords. # Latitude and longitude coordinates are the polygon centroid. 
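(Editor's aside, not part of the patch: the `_check_polys_length` hunk below reformats a warning about polygon segments longer than about one degree. The remedy that warning itself suggests can be sketched as follows; the polygon is a made-up example and the 1-degree cap mirrors the function's default threshold.

    import shapely
    from shapely.geometry import Polygon

    poly = Polygon([(0, 0), (120, 0), (120, 60), (0, 60)])  # edges span many degrees
    dense = shapely.segmentize(poly, 1.0)  # insert vertices so no segment exceeds ~1 degree

Densifying the shapes this way keeps the weight computation accurate over large regions.)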
@@ -1361,9 +1311,7 @@ def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None: for check_poly in check_polys: b = check_poly.boundary.coords # Length of each segment - poly_segments.extend( - [LineString(b[k : k + 2]).length for k in range(len(b) - 1)] - ) + poly_segments.extend([LineString(b[k : k + 2]).length for k in range(len(b) - 1)]) if np.any(np.array(poly_segments) > threshold): warnings.warn( f"`polys` contains large (> {threshold}°) segments. This could lead to errors over large regions. For a more accurate average, segmentize (densify) your shapes with `shapely.segmentize(polys, {threshold})`", @@ -1469,9 +1417,7 @@ def __repr__(self) -> str: return info - def _format_xroutput( - self, out: DataArray | Dataset, new_dims=None - ) -> DataArray | Dataset: + def _format_xroutput(self, out: DataArray | Dataset, new_dims=None) -> DataArray | Dataset: out = out.squeeze(dim="dummy") # rename dimension name to match output grid @@ -1479,11 +1425,7 @@ def _format_xroutput( # append output horizontal coordinate values # extra coordinates are automatically tracked by apply_ufunc - out.coords[self._lon_out_name] = xr.DataArray( - self._lon_out, dims=(self.geom_dim_name,) - ) - out.coords[self._lat_out_name] = xr.DataArray( - self._lat_out, dims=(self.geom_dim_name,) - ) + out.coords[self._lon_out_name] = xr.DataArray(self._lon_out, dims=(self.geom_dim_name,)) + out.coords[self._lat_out_name] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,)) out.attrs["regrid_method"] = self.method return out diff --git a/xesmf/smm.py b/xesmf/smm.py index ca2d8d47..f82ea8d4 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -1,9 +1,9 @@ """ Sparse matrix multiplication (SMM) using scipy.sparse library. """ -from typing import Any, Tuple import warnings from pathlib import Path +from typing import Any, Tuple import numba as nb # type: ignore[import] import numpy as np @@ -107,9 +107,7 @@ def _parse_coords_and_values( s = indata["weights"] crds = np.stack([row, col]) - return xr.DataArray( - sps.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights" - ) + return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights") def check_shapes( @@ -299,9 +297,7 @@ def _combine_weight_multipoly( # type: ignore # Change the fill_value to 1 wsum = wsum.copy( - data=sps.COO( - wsum.data.coords, wsum.data.data, shape=wsum.data.shape, fill_value=1 - ) + data=sps.COO(wsum.data.coords, wsum.data.data, shape=wsum.data.shape, fill_value=1) ) return out / wsum diff --git a/xesmf/util.py b/xesmf/util.py index 5e24bbb7..0b7c57fe 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -1,5 +1,5 @@ -from typing import Any, Generator, List, Literal, Tuple import warnings +from typing import Any, Generator, List, Literal, Tuple import numpy as np import numpy.typing as npt @@ -320,9 +320,7 @@ def _bipolar_projection( # One way is simply to demand lamc to be continuous with lam on the equator phi=0 # I am sure there is a more mathematically concrete way to do this. 
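    # (Editor's note, an inference from the reflections below: the inverse-trig
    # step that produces lamc only returns a principal value, so the quadrant of
    # lamc relative to lon_bp is ambiguous. The three np.where branches restore
    # it quadrant by quadrant, choosing the image that stays continuous with
    # lamg along the equator.)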
lamc = np.where((lamg - lon_bp > 90) & (lamg - lon_bp <= 180), 180 - lamc, lamc) - lamc = np.where( - (lamg - lon_bp > 180) & (lamg - lon_bp <= 270), 180 + lamc, lamc - ) + lamc = np.where((lamg - lon_bp > 180) & (lamg - lon_bp <= 270), 180 + lamc, lamc) lamc = np.where((lamg - lon_bp > 270), 360 - lamc, lamc) # Along symmetry meridian choose lamc lamc = np.where( @@ -347,9 +345,7 @@ def _bipolar_projection( N_inv = 1 / N cos2phis = (np.cos(phis * PI_180)) ** 2 - h_j_inv_t1 = ( - cos2phis * alpha2 * (1 - alpha2) * beta2_inv * (1 + beta2_inv) * (rden**2) - ) + h_j_inv_t1 = cos2phis * alpha2 * (1 - alpha2) * beta2_inv * (1 + beta2_inv) * (rden**2) h_j_inv_t2 = M_inv * M_inv * (1 - alpha2) * rden h_j_inv = h_j_inv_t1 + h_j_inv_t2 @@ -357,10 +353,7 @@ def _bipolar_projection( h_j_inv = np.where(np.abs(beta2_inv) > HUGE, M_inv * M_inv, h_j_inv) h_j_inv = np.sqrt(h_j_inv) * N_inv - h_i_inv = ( - cos2phis * (1 + beta2_inv) * (rden**2) - + M_inv * M_inv * alpha2 * beta2_inv * rden - ) + h_i_inv = cos2phis * (1 + beta2_inv) * (rden**2) + M_inv * M_inv * alpha2 * beta2_inv * rden # Deal with beta=0 h_i_inv = np.where(np.abs(beta2_inv) > HUGE, M_inv * M_inv, h_i_inv) h_i_inv = np.sqrt(h_i_inv) From d736c3c088bb146d4e9d57cd01755332887915af Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Wed, 1 Nov 2023 09:40:40 +0530 Subject: [PATCH 05/16] fixed more typing --- xesmf/frontend.py | 47 +++++++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index ff163b05..8677c7be 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -13,7 +13,14 @@ from shapely.geometry import LineString, MultiPolygon, Polygon from xarray import DataArray, Dataset -from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize +from .backend import ( + Grid, + LocStream, + Mesh, + add_corner, + esmf_regrid_build, + esmf_regrid_finalize, +) from .smm import ( _combine_weight_multipoly, _parse_coords_and_values, @@ -79,9 +86,9 @@ def subset_regridder( def as_2d_mesh( - lon: DataArray | npt.NDArray[np.float256], - lat: DataArray | npt.NDArray[np.float256], -) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | npt.NDArray[np.float256]]: + lon: DataArray | npt.NDArray[np.floating[Any]], + lat: DataArray | npt.NDArray[np.floating[Any]], +) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: if (lon.ndim, lat.ndim) == (2, 2): assert lon.shape == lat.shape, "lon and lat should have same shape" elif (lon.ndim, lat.ndim) == (1, 1): @@ -93,8 +100,8 @@ def as_2d_mesh( def _get_lon_lat( - ds: Dataset | Dict[str, npt.NDArray[np.float256]] -) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | npt.NDArray[np.float256]]: + ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] +) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: """Return lon and lat extracted from ds.""" if ("lat" in ds and "lon" in ds) or ("lat" in ds.coords and "lon" in ds.coords): # Old way. @@ -112,8 +119,8 @@ def _get_lon_lat( def _get_lon_lat_bounds( - ds: Dataset | Dict[str, npt.NDArray[np.float256]] -) -> tuple[DataArray | npt.NDArray[np.float256], DataArray | npt.NDArray[np.float256]]: + ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] +) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: """Return bounds of lon and lat extracted from ds.""" if "lat_b" in ds and "lon_b" in ds: # Old way. 
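(Editor's aside, not part of the patch: the `Dataset | Dict[...]` union annotated above reflects the two grid descriptions xESMF accepts, sketched below with illustrative shapes.

    import numpy as np
    import xarray as xr

    lon1d = np.arange(0.0, 360.0, 2.0)
    lat1d = np.arange(-89.0, 90.0, 2.0)
    lon2d, lat2d = np.meshgrid(lon1d, lat1d)  # both shaped (n_lat, n_lon)

    grid_dict = {'lon': lon2d, 'lat': lat2d}  # the "old way": a bare dict
    grid_ds = xr.Dataset(coords={'lon': (('y', 'x'), lon2d), 'lat': (('y', 'x'), lat2d)})

Either form satisfies the `"lat" in ds and "lon" in ds` branch; a Dataset may instead carry CF-named coordinates that cf-xarray resolves.)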
@@ -143,7 +150,7 @@ def _get_lon_lat_bounds( def ds_to_ESMFgrid( - ds: Dataset | Dict[str, npt.NDArray[np.float256]], + ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]], need_bounds: bool = False, periodic: bool = False, append=None, @@ -488,7 +495,7 @@ def _compute_weights(self): def __call__( self, - indata: npt.NDArray[np.float256] | dask_array_type | xr.DataArray | xr.Dataset, + indata: npt.NDArray[np.floating[Any]] | dask_array_type | xr.DataArray | xr.Dataset, keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, @@ -597,17 +604,17 @@ def __call__( @staticmethod def _regrid( - indata: npt.NDArray[np.float256], + indata: npt.NDArray[np.floating[Any]], weights: sps.coo_matrix, *, shape_in: Tuple[int, int], shape_out: Tuple[int, int], skipna: bool, na_thresh: float, - ) -> npt.NDArray[np.float256]: + ) -> npt.NDArray[np.floating[Any]]: # skipna: set missing values to zero if skipna: - missing = np.isnan(indata) + missing: npt.NDArray[np.bool_] = np.isnan(indata) indata = np.where(missing, 0.0, indata) # apply weights @@ -625,7 +632,7 @@ def _regrid( def regrid_array( self, - indata: npt.NDArray[np.float256] | dask_array_type, + indata: npt.NDArray[np.floating[Any]] | dask_array_type, weights: sps.coo_matrix, skipna: bool = False, na_thres: float = 1.0, @@ -669,14 +676,18 @@ def regrid_array( outdata = self._regrid(indata, weights, **kwargs) return outdata - def regrid_numpy(self, indata: npt.NDArray[np.float256], **kwargs) -> npt.NDArray[np.float256]: + def regrid_numpy( + self, indata: npt.NDArray[np.floating[Any]], **kwargs + ) -> npt.NDArray[np.floating[Any]]: warnings.warn( "`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) - def regrid_dask(self, indata: npt.NDArray[np.float256], **kwargs) -> npt.NDArray[np.float256]: + def regrid_dask( + self, indata: npt.NDArray[np.floating[Any]], **kwargs + ) -> npt.NDArray[np.floating[Any]]: warnings.warn( "`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.", category=FutureWarning, @@ -689,7 +700,7 @@ def regrid_dataarray( keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Tuple[int, ...]] = None, + output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, ) -> DataArray | Dataset: """See __call__().""" @@ -714,7 +725,7 @@ def regrid_dataset( keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Tuple[int, ...]] = None, + output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, ) -> DataArray | Dataset: """See __call__().""" From 9e9e5a140ecc900096068678458f2bdf9ad9b6ec Mon Sep 17 00:00:00 2001 From: David Huard Date: Wed, 1 Nov 2023 08:58:05 -0400 Subject: [PATCH 06/16] Apply suggestions from code review Co-authored-by: Charles Gauthier <85585494+charlesgauthier-udm@users.noreply.github.com> --- xesmf/frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 8677c7be..35ba517e 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -792,7 +792,7 @@ def _parse_xrinput( return input_horiz_dims, temp_horiz_dims def _format_xroutput( - self, out: xr.DataArray | xr.Dataset, new_dims: List[str] + self, out: xr.DataArray | xr.Dataset, new_dims: Optional[List[str]] = None ) -> xr.DataArray | xr.Dataset: out.attrs["regrid_method"] = self.method return out From 06a6fb4335057183f8ccaac3aed1ea9bdf72e6b7 Mon Sep 17 
00:00:00 2001 From: David Huard Date: Wed, 1 Nov 2023 09:04:24 -0400 Subject: [PATCH 07/16] Update xesmf/smm.py --- xesmf/smm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xesmf/smm.py b/xesmf/smm.py index f82ea8d4..f226705c 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -171,7 +171,7 @@ def check_shapes( def apply_weights( - weights: npt.NDArray[Any], + weights: sps.COO, indata: npt.NDArray[Any], shape_in: Tuple[int, int], shape_out: Tuple[int, int], From e317564f5c4c95d73485b6f5ed8a0c8fba83fc14 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Wed, 1 Nov 2023 19:00:22 +0530 Subject: [PATCH 08/16] fixed code --- xesmf/frontend.py | 34 ++++++++++------------------------ xesmf/util.py | 6 +----- 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 8677c7be..407c16c4 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -13,14 +13,7 @@ from shapely.geometry import LineString, MultiPolygon, Polygon from xarray import DataArray, Dataset -from .backend import ( - Grid, - LocStream, - Mesh, - add_corner, - esmf_regrid_build, - esmf_regrid_finalize, -) +from .backend import Grid, LocStream, Mesh, add_corner, esmf_regrid_build, esmf_regrid_finalize from .smm import ( _combine_weight_multipoly, _parse_coords_and_values, @@ -624,7 +617,7 @@ def _regrid( if skipna: fraction_valid = apply_weights(weights, (~missing).astype("d"), shape_in, shape_out) tol = 1e-6 - bad = fraction_valid < np.clip(1 - na_thres, tol, 1 - tol) + bad = fraction_valid < np.clip(1 - na_thresh, tol, 1 - tol) fraction_valid[bad] = 1 outdata = np.where(bad, np.nan, outdata / fraction_valid) @@ -792,7 +785,7 @@ def _parse_xrinput( return input_horiz_dims, temp_horiz_dims def _format_xroutput( - self, out: xr.DataArray | xr.Dataset, new_dims: List[str] + self, out: xr.DataArray | xr.Dataset, new_dims: Optional[List[str]] = None ) -> xr.DataArray | xr.Dataset: out.attrs["regrid_method"] = self.method return out @@ -800,19 +793,12 @@ def _format_xroutput( def __repr__(self) -> str: info = ( "xESMF Regridder \n" - "Regridding algorithm: {} \n" - "Weight filename: {} \n" - "Reuse pre-computed weights? {} \n" - "Input grid shape: {} \n" - "Output grid shape: {} \n" - "Periodic in longitude? {}".format( - self.method, - self.filename, - self.reuse_weights, - self.shape_in, - self.shape_out, - self.periodic, - ) + f"Regridding algorithm: {self.method} \n" + f"Weight filename: {self.filename} \n" + f"Reuse pre-computed weights? {self.reuse_weights} \n" + f"Input grid shape: {self.shape_in} \n" + f"Output grid shape: {self.shape_out} \n" + f"Periodic in longitude? 
{self.periodic}" ) return info @@ -1189,7 +1175,7 @@ def __init__( periodic: bool = False, filename: Optional[str] = None, reuse_weights: bool = False, - weights: Optional[sps.coo_matrix | dict | str | Dataset | "Path"] = None, + weights: Optional[sps.coo_matrix | dict | str | Dataset] = None, ignore_degenerate: bool = False, geom_dim_name: str = "geom", ): diff --git a/xesmf/util.py b/xesmf/util.py index 0b7c57fe..861cdb3f 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -289,11 +289,7 @@ def simple_tripolar_grid( def _bipolar_projection( - lamg: npt.NDArray[np.floating[Any]], - phig: npt.NDArray[np.floating[Any]], - lon_bp: npt.NDArray[np.floating[Any]], - rp: float, - metrics_only: bool = False, + lamg: float, phig: float, lon_bp: float, rp: float, metrics_only: bool = False ): """Makes a stereographic bipolar projection of the input coordinate mesh (lamg,phig) Returns the projected coordinate mesh and their metric coefficients (h^-1). From c001ab12bfbeac1d76aa9277bc8da556a6afeacc Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Fri, 10 Nov 2023 12:41:59 +0530 Subject: [PATCH 09/16] fixed typing + formatting --- xesmf/backend.py | 168 ++++++++++++++----------- xesmf/frontend.py | 308 +++++++++++++++++++++++----------------------- xesmf/smm.py | 62 +++++----- xesmf/util.py | 46 +++---- 4 files changed, 303 insertions(+), 281 deletions(-) diff --git a/xesmf/backend.py b/xesmf/backend.py index 61d7db42..ff2aaa92 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -17,7 +17,7 @@ import os import warnings -from typing import Optional +from typing import Any, List, Literal, Union try: import esmpy as ESMF @@ -25,10 +25,11 @@ import ESMF import numpy as np +import numpy.typing as npt import numpy.lib.recfunctions as nprec -def warn_f_contiguous(a: np.ndarray) -> None: +def warn_f_contiguous(a: npt.NDArray[np.floating[Any]]) -> None: """ Give a warning if the input array is not Fortran-ordered. @@ -39,11 +40,11 @@ def warn_f_contiguous(a: np.ndarray) -> None: ---------- a : numpy array """ - if not a.flags["F_CONTIGUOUS"]: - warnings.warn("Input array is not F_CONTIGUOUS. " "Will affect performance.") + if not a.flags['F_CONTIGUOUS']: + warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.') -def warn_lat_range(lat: np.ndarray) -> None: +def warn_lat_range(lat: npt.NDArray[np.floating[Any]]) -> None: """ Give a warning if latitude is outside of [-90, 90] @@ -55,17 +56,17 @@ def warn_lat_range(lat: np.ndarray) -> None: lat : numpy array """ if (lat.max() > 90.0) or (lat.min() < -90.0): - warnings.warn("Latitude is outside of [-90, 90]") + warnings.warn('Latitude is outside of [-90, 90]') class Grid(ESMF.Grid): @classmethod def from_xarray( cls, - lon: np.ndarray, - lat: np.ndarray, + lon: npt.NDArray[np.floating[Any]], + lat: npt.NDArray[np.floating[Any]], periodic: bool = False, - mask: Optional[np.ndarray] = None, + mask: Union[npt.NDArray[np.integer[Any]], None] = None, ): """ Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid. @@ -73,12 +74,12 @@ def from_xarray( Parameters ---------- lon, lat : 2D numpy array - Longitute/Latitude of cell centers. + Longitude/Latitude of cell centers. - Recommend Fortran-ordering to match ESMPy internal. + Recommend Fortran-ordering to match ESMPy internal. - Shape should be ``(Nlon, Nlat)`` for rectilinear grid, + Shape should be ``(Nlon, Nlat)`` for rectilinear grid, - or ``(Nx, Ny)`` for general quadrilateral grid. + or ``(Nx, Ny)`` for general quadrilateral grid.
periodic : bool, optional Periodic in longitude? Default to False. @@ -105,8 +106,8 @@ # ESMF.Grid can actually take 3D array (lon, lat, radius), # but regridding only works for 2D array - assert lon.ndim == 2, "Input grid must be 2D array" - assert lon.shape == lat.shape, "lon and lat must have same shape" + assert lon.ndim == 2, 'Input grid must be 2D array' + assert lon.shape == lat.shape, 'lon and lat must have same shape' staggerloc = ESMF.StaggerLoc.CENTER # actually just integer 0 @@ -144,8 +145,8 @@ grid_mask = mask.astype(np.int32) if not (grid_mask.shape == lon.shape): raise ValueError( - "mask must have the same shape as the latitude/longitude" - "coordinates, got: mask.shape = %s, lon.shape = %s" % (mask.shape, lon.shape) + 'mask must have the same shape as the latitude/longitude coordinates, ' + f'got: mask.shape = {mask.shape}, lon.shape = {lon.shape}' ) grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False) grid.mask[0][:] = grid_mask @@ -159,7 +160,11 @@ class LocStream(ESMF.LocStream): @classmethod - def from_xarray(cls, lon: np.ndarray, lat: np.ndarray) -> ESMF.LocStream: + def from_xarray( + cls, + lon: npt.NDArray[np.floating[Any]], + lat: npt.NDArray[np.floating[Any]], + ) -> ESMF.LocStream: """ Create an ESMF.LocStream object, for constructing ESMF.Field and ESMF.Regrid @@ -174,9 +179,9 @@ """ if len(lon.shape) > 1: - raise ValueError("lon can only be 1d") + raise ValueError('lon can only be 1d') if len(lat.shape) > 1: - raise ValueError("lat can only be 1d") + raise ValueError('lat can only be 1d') assert lon.shape == lat.shape location_count = len(lon) locstream = cls(location_count, coord_sys=ESMF.CoordSys.SPH_DEG) - locstream["ESMF:Lon"] = lon.astype(np.dtype("f8")) - locstream["ESMF:Lat"] = lat.astype(np.dtype("f8")) + locstream['ESMF:Lon'] = lon.astype(np.dtype('f8')) + locstream['ESMF:Lat'] = lat.astype(np.dtype('f8')) return locstream @@ -194,7 +199,7 @@ def get_shape(self): return (self.size, 1) -def add_corner(grid, lon_b, lat_b): +def add_corner(grid: Grid, lon_b, lat_b): """ Add corner information to ESMF.Grid for conservative regridding.
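For orientation, here is a minimal usage sketch of the two constructors annotated above, `Grid.from_xarray` and `add_corner`, based only on the signatures and docstrings shown in these hunks; the toy coordinate arrays and grid sizes are illustrative assumptions, not part of the patch.

```python
import numpy as np

from xesmf.backend import Grid, add_corner

# Cell centers with shape (Nx, Ny) = (10, 5); Fortran order avoids the
# F_CONTIGUOUS performance warning raised by warn_f_contiguous above.
lon, lat = np.meshgrid(np.arange(0.5, 10.5), np.arange(0.5, 5.5), indexing='ij')
grid = Grid.from_xarray(np.asfortranarray(lon), np.asfortranarray(lat))

# Cell corners with shape (Nx + 1, Ny + 1) = (11, 6), matching the assertion
# in add_corner; corner information is what enables conservative regridding.
lon_b, lat_b = np.meshgrid(np.arange(0.0, 11.0), np.arange(0.0, 6.0), indexing='ij')
add_corner(grid, np.asfortranarray(lon_b), np.asfortranarray(lat_b))
```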
@@ -212,7 +217,7 @@ def add_corner(grid, lon_b, lat_b): """ # codes here are almost the same as Grid.from_xarray(), - # except for the "staggerloc" keyword + # except for the 'staggerloc' keyword staggerloc = ESMF.StaggerLoc.CORNER # actually just integer 3 for a in [lon_b, lat_b]: @@ -220,12 +225,12 @@ def add_corner(grid, lon_b, lat_b): warn_lat_range(lat_b) - assert lon_b.ndim == 2, "Input grid must be 2D array" - assert lon_b.shape == lat_b.shape, "lon_b and lat_b must have same shape" - assert np.array_equal(lon_b.shape, grid.max_index + 1), "lon_b should be size (Nx+1, Ny+1)" + assert lon_b.ndim == 2, 'Input grid must be 2D array' + assert lon_b.shape == lat_b.shape, 'lon_b and lat_b must have same shape' + assert np.array_equal(lon_b.shape, grid.max_index + 1), 'lon_b should be size (Nx+1, Ny+1)' assert (grid.num_peri_dims == 0) and ( grid.periodic_dim is None - ), "Cannot add corner for periodic grid" + ), 'Cannot add corner for periodic grid' grid.add_coords(staggerloc=staggerloc) @@ -238,7 +243,11 @@ def add_corner(grid, lon_b, lat_b): class Mesh(ESMF.Mesh): @classmethod - def from_polygons(cls, polys, element_coords="centroid"): + def from_polygons( + cls, + polys: npt.NDArray[np.object_], + element_coords: Union[Literal['centroid'], npt.NDArray[np.floating[Any]]] = 'centroid', + ): """ Create an ESMF.Mesh object from a list of polygons. @@ -248,9 +257,9 @@ def from_polygons(cls, polys, element_coords="centroid"): Parameters ---------- polys : sequence of shapely Polygon - Holes are not represented by the Mesh. - element_coords : array or "centroid", optional - If "centroid", the polygon centroids will be used (default) + Holes are not represented by the Mesh. + element_coords : array or 'centroid', optional + If 'centroid', the polygon centroids will be used (default) If an array of shape (len(polys), 2) : the element coordinates of the mesh. If None, the Mesh's elements will not have coordinates. @@ -262,13 +271,13 @@ def from_polygons(cls, polys, element_coords="centroid"): node_num = sum(len(e.exterior.coords) - 1 for e in polys) elem_num = len(polys) # Pre alloc arrays. Special structure for coords makes the code faster. 
- crd_dt = np.dtype([("x", np.float32), ("y", np.float32)]) + crd_dt = np.dtype([('x', np.float32), ('y', np.float32)]) node_coords = np.empty(node_num, dtype=crd_dt) node_coords[:] = (np.nan, np.nan) # Fill with impossible values element_types = np.empty(elem_num, dtype=np.uint32) element_conn = np.empty(node_num, dtype=np.uint32) # Flag for centroid calculation - calc_centroid = isinstance(element_coords, str) and element_coords == "centroid" + calc_centroid = isinstance(element_coords, str) and element_coords == 'centroid' if calc_centroid: element_coords = np.empty(elem_num, dtype=crd_dt) inode = 0 @@ -311,7 +320,7 @@ ) except ValueError as err: raise ValueError( - "ESMF failed to create the Mesh, this usually happen when some polygons are invalid (test with `poly.is_valid`)" + 'ESMF failed to create the Mesh, this usually happens when some polygons are invalid (test with `poly.is_valid`)' ) from err return mesh @@ -321,15 +330,22 @@ def get_shape(self, loc=ESMF.MeshLoc.ELEMENT): def esmf_regrid_build( - sourcegrid, - destgrid, - method, - filename=None, - extra_dims=None, - extrap_method=None, - extrap_dist_exponent=None, - extrap_num_src_pnts=None, - ignore_degenerate=None, + sourcegrid: Union[Grid, Mesh], + destgrid: Union[Grid, Mesh], + method: Literal[ + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', + ], + filename: Union[str, None] = None, + extra_dims: Union[List[int], None] = None, + extrap_method: Union[Literal['inverse_dist', 'nearest_s2d'], None] = None, + extrap_dist_exponent: float = 2.0, + extrap_num_src_pnts: int = 8, + ignore_degenerate: bool = False, ): """ Create an ESMF.Regrid object, containing regridding weights. @@ -395,56 +411,54 @@ """ # use shorter, clearer names for options in ESMF.RegridMethod - method_dict = { - "bilinear": ESMF.RegridMethod.BILINEAR, - "conservative": ESMF.RegridMethod.CONSERVE, - "conservative_normed": ESMF.RegridMethod.CONSERVE, - "patch": ESMF.RegridMethod.PATCH, - "nearest_s2d": ESMF.RegridMethod.NEAREST_STOD, - "nearest_d2s": ESMF.RegridMethod.NEAREST_DTOS, + method_dict: dict[str, int] = { + 'bilinear': ESMF.RegridMethod.BILINEAR, + 'conservative': ESMF.RegridMethod.CONSERVE, + 'conservative_normed': ESMF.RegridMethod.CONSERVE, + 'patch': ESMF.RegridMethod.PATCH, + 'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD, + 'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS, } try: esmf_regrid_method = method_dict[method] except Exception: - raise ValueError("method should be chosen from " "{}".format(list(method_dict.keys()))) + raise ValueError(f'method should be chosen from {list(method_dict.keys())}') # use shorter, clearer names for options in ESMF.ExtrapMethod extrap_dict = { - "inverse_dist": ESMF.ExtrapMethod.NEAREST_IDAVG, - "nearest_s2d": ESMF.ExtrapMethod.NEAREST_STOD, + 'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG, + 'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD, None: None, } try: esmf_extrap_method = extrap_dict[extrap_method] except KeyError: - raise KeyError( "`extrap_method` should be chosen from " "{}".format(list(extrap_dict.keys())) ) + raise KeyError(f'`extrap_method` should be chosen from {list(extrap_dict.keys())}') # until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible # if files are written on disk if (extrap_method is not None) & (filename is not None): - raise ValueError("`extrap_method` cannot be used along with `filename`.") + raise ValueError('`extrap_method` 
cannot be used along with `filename`.') # conservative regridding needs cell corner information - if method in ["conservative", "conservative_normed"]: - if not isinstance(sourcegrid, ESMF.Mesh) and not sourcegrid.has_corners: + if method in ['conservative', 'conservative_normed']: + if not isinstance(sourcegrid, Mesh) and not sourcegrid.has_corners: raise ValueError( - "source grid has no corner information. " "cannot use conservative regridding." + 'source grid has no corner information. ' 'cannot use conservative regridding.' ) - if not isinstance(destgrid, ESMF.Mesh) and not destgrid.has_corners: + if not isinstance(destgrid, Mesh) and not destgrid.has_corners: raise ValueError( - "destination grid has no corner information. " "cannot use conservative regridding." + 'destination grid has no corner information. ' 'cannot use conservative regridding.' ) # ESMF.Regrid requires Field (Grid+data) as input, not just Grid. # Extra dimensions are specified when constructing the Field objects, # not when constructing the Regrid object later on. - if isinstance(sourcegrid, ESMF.Mesh): + if isinstance(sourcegrid, Mesh): sourcefield = ESMF.Field(sourcegrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) else: sourcefield = ESMF.Field(sourcegrid, ndbounds=extra_dims) - if isinstance(destgrid, ESMF.Mesh): + if isinstance(destgrid, Mesh): destfield = ESMF.Field(destgrid, meshloc=ESMF.MeshLoc.ELEMENT, ndbounds=extra_dims) else: destfield = ESMF.Field(destgrid, ndbounds=extra_dims) @@ -462,11 +476,11 @@ def esmf_regrid_build( if filename is not None: assert not os.path.exists( filename - ), "Weight file already exists! Please remove it or use a new name." + ), 'Weight file already exists! Please remove it or use a new name.' # re-normalize conservative regridding results # https://github.com/JiaweiZhuang/xESMF/issues/17 - if method == "conservative_normed": + if method == 'conservative_normed': norm_type = ESMF.NormType.FRACAREA else: norm_type = ESMF.NormType.DSTAREA @@ -493,7 +507,7 @@ def esmf_regrid_build( return regrid -def esmf_regrid_apply(regrid, indata): +def esmf_regrid_apply(regrid: ESMF.Regrid, indata): """ Apply existing regridding weights to the data field, using ESMPy's built-in functionality. @@ -541,7 +555,7 @@ def esmf_regrid_apply(regrid, indata): return destfield.data -def esmf_regrid_finalize(regrid): +def esmf_regrid_finalize(regrid: ESMF.Regrid): """ Free the underlying Fortran array to avoid memory leak. 
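Taken together, the helpers in this file follow a build, apply, finalize lifecycle. Below is a minimal sketch of that flow, assuming `grid_in`, `grid_out` and a matching Fortran-ordered `indata` already exist; none of these objects are defined in the patch itself.

```python
from xesmf.backend import esmf_regrid_apply, esmf_regrid_build, esmf_regrid_finalize

# Build the weights once; `method` must be one of the method_dict keys above.
regrid = esmf_regrid_build(grid_in, grid_out, method='bilinear')

# The same Regrid object can then be applied to as many data fields as needed.
outdata = esmf_regrid_apply(regrid, indata)

# Free the underlying Fortran arrays to avoid the memory leak noted above.
esmf_regrid_finalize(regrid)
```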
@@ -571,17 +585,25 @@ def esmf_regrid_finalize(regrid): # Deprecated as of version 0.5.0 -def esmf_locstream(lon, lat): +def esmf_locstream( + lon: npt.NDArray[np.floating[Any]], + lat: npt.NDArray[np.floating[Any]], +) -> LocStream: warnings.warn( - "`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`", + '`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`', DeprecationWarning, ) return LocStream.from_xarray(lon, lat) -def esmf_grid(lon, lat, periodic=False, mask=None): +def esmf_grid( + lon: npt.NDArray[np.floating[Any]], + lat: npt.NDArray[np.floating[Any]], + periodic: bool = False, + mask: Union[npt.NDArray[np.integer[Any]], None] = None, +) -> Grid: warnings.warn( - "`esmf_grid` is being deprecated in favor of `Grid.from_xarray`", + '`esmf_grid` is being deprecated in favor of `Grid.from_xarray`', DeprecationWarning, ) return Grid.from_xarray(lon, lat) diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 407c16c4..87b4adcd 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -36,12 +36,12 @@ def subset_regridder( ds_in: DataArray | Dataset | dict[str, DataArray], ds_out: DataArray | Dataset | dict[str, DataArray], method: Literal[ - "bilinear", - "conservative", - "conservative_normed", - "patch", - "nearest_s2d", - "nearest_d2s", + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', ], in_dims, out_dims, @@ -51,19 +51,19 @@ def subset_regridder( **kwargs, ): """Compute subset of weights""" - kwargs.pop("filename", None) # Don't save subset of weights - kwargs.pop("reuse_weights", None) + kwargs.pop('filename', None) # Don't save subset of weights + kwargs.pop('reuse_weights', None) # Renaming dims to original names for the subset regridding if locstream_in: - ds_in = ds_in.rename({"x_in": in_dims[0]}) + ds_in = ds_in.rename({'x_in': in_dims[0]}) else: - ds_in = ds_in.rename({"y_in": in_dims[0], "x_in": in_dims[1]}) + ds_in = ds_in.rename({'y_in': in_dims[0], 'x_in': in_dims[1]}) if locstream_out: - ds_out = ds_out.rename({"x_out": out_dims[1]}) + ds_out = ds_out.rename({'x_out': out_dims[1]}) else: - ds_out = ds_out.rename({"y_out": out_dims[0], "x_out": out_dims[1]}) + ds_out = ds_out.rename({'y_out': out_dims[0], 'x_out': out_dims[1]}) regridder = Regridder( ds_in, @@ -83,11 +83,11 @@ def as_2d_mesh( lat: DataArray | npt.NDArray[np.floating[Any]], ) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: if (lon.ndim, lat.ndim) == (2, 2): - assert lon.shape == lat.shape, "lon and lat should have same shape" + assert lon.shape == lat.shape, 'lon and lat should have same shape' elif (lon.ndim, lat.ndim) == (1, 1): lon, lat = np.meshgrid(lon, lat) else: - raise ValueError("lon and lat should be both 1D or 2D") + raise ValueError('lon and lat should be both 1D or 2D') return lon, lat @@ -96,17 +96,17 @@ def _get_lon_lat( ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] ) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: """Return lon and lat extracted from ds.""" - if ("lat" in ds and "lon" in ds) or ("lat" in ds.coords and "lon" in ds.coords): + if ('lat' in ds and 'lon' in ds) or ('lat' in ds.coords and 'lon' in ds.coords): # Old way. 
- return ds["lon"], ds["lat"] + return ds['lon'], ds['lat'] # else : cf-xarray way try: - lon = ds.cf["longitude"] - lat = ds.cf["latitude"] + lon = ds.cf['longitude'] + lat = ds.cf['latitude'] except (KeyError, AttributeError, ValueError): # KeyError if cfxr doesn't detect the coords # AttributeError if ds is a dict - raise ValueError("dataset must include lon/lat or be CF-compliant") + raise ValueError('dataset must include lon/lat or be CF-compliant') return lon, lat @@ -115,30 +115,30 @@ def _get_lon_lat_bounds( ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] ) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: """Return bounds of lon and lat extracted from ds.""" - if "lat_b" in ds and "lon_b" in ds: + if 'lat_b' in ds and 'lon_b' in ds: # Old way. - return ds["lon_b"], ds["lat_b"] + return ds['lon_b'], ds['lat_b'] # else : cf-xarray way - if "longitude" not in ds.cf.coordinates: + if 'longitude' not in ds.cf.coordinates: # If we are here, _get_lon_lat() didn't fail, thus we should be able to guess the coords. ds = ds.cf.guess_coord_axis() try: - lon_bnds = ds.cf.get_bounds("longitude") - lat_bnds = ds.cf.get_bounds("latitude") + lon_bnds = ds.cf.get_bounds('longitude') + lat_bnds = ds.cf.get_bounds('latitude') except KeyError: # bounds are not already present - if ds.cf["longitude"].ndim > 1: - # We cannot infer 2D bounds, raise KeyError as custom "lon_b" is missing. - raise KeyError("lon_b") - lon_name = ds.cf["longitude"].name - lat_name = ds.cf["latitude"].name + if ds.cf['longitude'].ndim > 1: + # We cannot infer 2D bounds, raise KeyError as custom 'lon_b' is missing. + raise KeyError('lon_b') + lon_name = ds.cf['longitude'].name + lat_name = ds.cf['latitude'].name ds = ds.cf.add_bounds([lon_name, lat_name]) - lon_bnds = ds.cf.get_bounds("longitude") - lat_bnds = ds.cf.get_bounds("latitude") + lon_bnds = ds.cf.get_bounds('longitude') + lat_bnds = ds.cf.get_bounds('latitude') # Convert from CF bounds to xESMF bounds. # order=None is because we don't want to assume the dimension order for 2D bounds. 
- lon_b = cfxr.bounds_to_vertices(lon_bnds, ds.cf.get_bounds_dim_name("longitude"), order=None) - lat_b = cfxr.bounds_to_vertices(lat_bnds, ds.cf.get_bounds_dim_name("latitude"), order=None) + lon_b = cfxr.bounds_to_vertices(lon_bnds, ds.cf.get_bounds_dim_name('longitude'), order=None) + lat_b = cfxr.bounds_to_vertices(lat_bnds, ds.cf.get_bounds_dim_name('latitude'), order=None) return lon_b, lat_b @@ -179,7 +179,7 @@ def ds_to_ESMFgrid( # use np.asarray(dr) instead of dr.values, so it also works for dictionary lon, lat = _get_lon_lat(ds) - if hasattr(lon, "dims"): + if hasattr(lon, 'dims'): if lon.ndim == 1: dim_names = lat.dims + lon.dims else: @@ -188,8 +188,8 @@ dim_names = None lon, lat = as_2d_mesh(np.asarray(lon), np.asarray(lat)) - if "mask" in ds: - mask = np.asarray(ds["mask"]) + if 'mask' in ds: + mask = np.asarray(ds['mask']) else: mask = None @@ -223,16 +223,16 @@ def ds_to_ESMFlocstream(ds): """ lon, lat = _get_lon_lat(ds) - if hasattr(lon, "dims"): + if hasattr(lon, 'dims'): dim_names = lon.dims else: dim_names = None lon, lat = np.asarray(lon), np.asarray(lat) if len(lon.shape) > 1: - raise ValueError("lon can only be 1d") + raise ValueError('lon can only be 1d') if len(lat.shape) > 1: - raise ValueError("lat can only be 1d") + raise ValueError('lat can only be 1d') assert lon.shape == lat.shape @@ -262,7 +262,7 @@ def polys_to_ESMFmesh(polys) -> tuple[Mesh, tuple[Literal[1], int]]: ext, holes, _, _ = split_polygons_and_holes(polys) if len(holes) > 0: warnings.warn( - "Some passed polygons have holes, those are not represented in the returned Mesh." + 'Some passed polygons have holes, those are not represented in the returned Mesh.' ) return Mesh.from_polygons(ext), (1, len(ext)) @@ -275,7 +275,7 @@ def __init__( method: str, filename: Optional[str] = None, reuse_weights: bool = False, - extrap_method: Optional[Literal["inverse_dist", "nearest_s2d"]] = None, + extrap_method: Optional[Literal['inverse_dist', 'nearest_s2d']] = None, extrap_dist_exponent: Optional[float] = None, extrap_num_src_pnts: Optional[int] = None, weights: Optional[Any] = None, @@ -378,17 +378,17 @@ self.extrap_dist_exponent = extrap_dist_exponent self.extrap_num_src_pnts = extrap_num_src_pnts self.ignore_degenerate = ignore_degenerate - self.periodic = getattr(self.grid_in, "periodic_dim", None) is not None + self.periodic = getattr(self.grid_in, 'periodic_dim', None) is not None self.sequence_in = isinstance(self.grid_in, (LocStream, Mesh)) self.sequence_out = isinstance(self.grid_out, (LocStream, Mesh)) if input_dims is not None and len(input_dims) != int(not self.sequence_in) + 1: - raise ValueError(f"Wrong number of dimension names in `input_dims` ({len(input_dims)}.") + raise ValueError(f'Wrong number of dimension names in `input_dims` ({len(input_dims)}).') self.in_horiz_dims = input_dims if output_dims is not None and len(output_dims) != int(not self.sequence_out) + 1: raise ValueError( - f"Wrong number of dimension names in `output dims` ({len(output_dims)}." + f'Wrong number of dimension names in `output_dims` ({len(output_dims)}).'
) self.out_horiz_dims = output_dims @@ -401,7 +401,7 @@ def __init__( # some logic about reusing weights with either filename or weights args if reuse_weights and (filename is None) and (weights is None): - raise ValueError("To reuse weights, you need to provide either filename or weights.") + raise ValueError('To reuse weights, you need to provide either filename or weights.') if not parallel: if not reuse_weights and weights is None: @@ -430,8 +430,8 @@ def __init__( @property def A(self) -> DataArray: message = ( - "regridder.A is deprecated and will be removed in future versions. " - "Use regridder.weights instead." + 'regridder.A is deprecated and will be removed in future versions. ' + 'Use regridder.weights instead.' ) warnings.warn(message, DeprecationWarning) @@ -451,12 +451,12 @@ def w(self) -> xr.DataArray: # TODO: Add coords ? s = self.shape_out + self.shape_in data = self.weights.data.reshape(s) - dims = "y_out", "x_out", "y_in", "x_in" + dims = 'y_out', 'x_out', 'y_in', 'x_in' return xr.DataArray(data, dims=dims) def _get_default_filename(self) -> str: # e.g. bilinear_400x600_300x400.nc - filename = "{0}_{1}x{2}_{3}x{4}".format( + filename = '{0}_{1}x{2}_{3}x{4}'.format( self.method, self.shape_in[0], self.shape_in[1], @@ -465,9 +465,9 @@ def _get_default_filename(self) -> str: ) if self.periodic: - filename += "_peri.nc" + filename += '_peri.nc' else: - filename += ".nc" + filename += '.nc' return filename @@ -593,7 +593,7 @@ def __call__( output_chunks=output_chunks, ) else: - raise TypeError("input must be numpy array, dask array, xarray DataArray or Dataset!") + raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!') @staticmethod def _regrid( @@ -615,7 +615,7 @@ def _regrid( # skipna: Compute the influence of missing data at each interpolation point and filter those not meeting acceptable threshold. 
if skipna: - fraction_valid = apply_weights(weights, (~missing).astype("d"), shape_in, shape_out) + fraction_valid = apply_weights(weights, (~missing).astype('d'), shape_in, shape_out) tol = 1e-6 bad = fraction_valid < np.clip(1 - na_thresh, tol, 1 - tol) fraction_valid[bad] = 1 @@ -640,8 +640,8 @@ def regrid_array( output_chunks = tuple(map(output_chunks.get, self.out_horiz_dims)) kwargs = { - "shape_in": self.shape_in, - "shape_out": self.shape_out, + 'shape_in': self.shape_in, + 'shape_out': self.shape_out, } check_shapes(indata, weights, **kwargs) @@ -659,9 +659,9 @@ def regrid_array( output_chunks = (1, output_chunks[0]) else: raise ValueError( - f"output_chunks must have same dimension as ds_out," - f" output_chunks dimension ({len(output_chunks)}) does not " - f"match ds_out dimension ({len(self.shape_out)})" + f'output_chunks must have same dimension as ds_out,' + f' output_chunks dimension ({len(output_chunks)}) does not ' + f'match ds_out dimension ({len(self.shape_out)})' ) weights = da.from_array(weights, chunks=(output_chunks + indata.chunksize[-2:])) outdata = self._regrid(indata, weights, **kwargs) @@ -673,7 +673,7 @@ def regrid_numpy( self, indata: npt.NDArray[np.floating[Any]], **kwargs ) -> npt.NDArray[np.floating[Any]]: warnings.warn( - "`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.", + '`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) @@ -682,7 +682,7 @@ def regrid_dask( self, indata: npt.NDArray[np.floating[Any]], **kwargs ) -> npt.NDArray[np.floating[Any]]: warnings.warn( - "`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.", + '`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) @@ -704,9 +704,9 @@ def regrid_dataarray( dr_in, self.weights, kwargs=kwargs, - input_core_dims=[input_horiz_dims, ("out_dim", "in_dim")], + input_core_dims=[input_horiz_dims, ('out_dim', 'in_dim')], output_core_dims=[temp_horiz_dims], - dask="allowed", + dask='allowed', keep_attrs=keep_attrs, ) @@ -739,9 +739,9 @@ def regrid_dataset( ds_in, self.weights, kwargs=kwargs, - input_core_dims=[input_horiz_dims, ("out_dim", "in_dim")], + input_core_dims=[input_horiz_dims, ('out_dim', 'in_dim')], output_core_dims=[temp_horiz_dims], - dask="allowed", + dask='allowed', keep_attrs=keep_attrs, ) @@ -769,36 +769,36 @@ def _parse_xrinput( # help user debugging invalid horizontal dimensions warnings.warn( ( - f"Using dimensions {input_horiz_dims} from data variable {name} " - "as the horizontal dimensions for the regridding." + f'Using dimensions {input_horiz_dims} from data variable {name} ' + 'as the horizontal dimensions for the regridding.' 
), UserWarning, ) if self.sequence_out: - temp_horiz_dims: List[str] = ["dummy", "locations"] + temp_horiz_dims: List[str] = ['dummy', 'locations'] else: - temp_horiz_dims: List[str] = [s + "_new" for s in input_horiz_dims] + temp_horiz_dims: List[str] = [s + '_new' for s in input_horiz_dims] if self.sequence_in and not self.sequence_out: - temp_horiz_dims = ["dummy_new"] + temp_horiz_dims + temp_horiz_dims = ['dummy_new'] + temp_horiz_dims return input_horiz_dims, temp_horiz_dims def _format_xroutput( self, out: xr.DataArray | xr.Dataset, new_dims: Optional[List[str]] = None ) -> xr.DataArray | xr.Dataset: - out.attrs["regrid_method"] = self.method + out.attrs['regrid_method'] = self.method return out def __repr__(self) -> str: info = ( - "xESMF Regridder \n" - f"Regridding algorithm: {self.method} \n" - f"Weight filename: {self.filename} \n" - f"Reuse pre-computed weights? {self.reuse_weights} \n" - f"Input grid shape: {self.shape_in} \n" - f"Output grid shape: {self.shape_out} \n" - f"Periodic in longitude? {self.periodic}" + 'xESMF Regridder \n' + f'Regridding algorithm: {self.method} \n' + f'Weight filename: {self.filename} \n' + f'Reuse pre-computed weights? {self.reuse_weights} \n' + f'Input grid shape: {self.shape_in} \n' + f'Output grid shape: {self.shape_out} \n' + f'Periodic in longitude? {self.periodic}' ) return info @@ -808,12 +808,12 @@ def to_netcdf(self, filename: Optional[str] = None) -> str: if filename is None: filename = self.filename w = self.weights.data - dim = "n_s" + dim = 'n_s' ds = xr.Dataset( { - "S": (dim, w.data), - "col": (dim, w.coords[1, :] + 1), - "row": (dim, w.coords[0, :] + 1), + 'S': (dim, w.data), + 'col': (dim, w.coords[1, :] + 1), + 'row': (dim, w.coords[0, :] + 1), } ) ds.to_netcdf(filename) @@ -826,12 +826,12 @@ def __init__( ds_in: xr.DataArray | xr.Dataset | dict[str, xr.DataArray], ds_out: xr.DataArray | xr.Dataset | dict[str, xr.DataArray], method: Literal[ - "bilinear", - "conservative", - "conservative_normed", - "patch", - "nearest_s2d", - "nearest_d2s", + 'bilinear', + 'conservative', + 'conservative_normed', + 'patch', + 'nearest_s2d', + 'nearest_d2s', ], locstream_in: bool = False, locstream_out: bool = False, @@ -934,30 +934,30 @@ def __init__( ------- regridder : xESMF regridder object """ - methods_avail_ls_in = ["nearest_s2d", "nearest_d2s"] - methods_avail_ls_out = ["bilinear", "patch"] + methods_avail_ls_in + methods_avail_ls_in = ['nearest_s2d', 'nearest_d2s'] + methods_avail_ls_out = ['bilinear', 'patch'] + methods_avail_ls_in if locstream_in and method not in methods_avail_ls_in: raise ValueError( - f"locstream input is only available for method in {methods_avail_ls_in}" + f'locstream input is only available for method in {methods_avail_ls_in}' ) if locstream_out and method not in methods_avail_ls_out: raise ValueError( - f"locstream output is only available for method in {methods_avail_ls_out}" + f'locstream output is only available for method in {methods_avail_ls_out}' ) - reuse_weights = kwargs.get("reuse_weights", False) + reuse_weights = kwargs.get('reuse_weights', False) - weights = kwargs.get("weights", None) + weights = kwargs.get('weights', None) if parallel and (reuse_weights or weights is not None): parallel = False warnings.warn( - "Cannot use parallel=True when reuse_weights=True or when weights is not None. Building Regridder normally." + 'Cannot use parallel=True when reuse_weights=True or when weights is not None. Building Regridder normally.' 
) # Record basic switches - if method in ["conservative", "conservative_normed"]: + if method in ['conservative', 'conservative_normed']: need_bounds = True periodic = False # bound shape will not be N+1 for periodic grid else: @@ -997,24 +997,24 @@ def __init__( lon_out, lat_out = _get_lon_lat(ds_out) if not isinstance(lon_out, DataArray): if lon_out.ndim == 2: - dims = [("y", "x"), ("y", "x")] + dims = [('y', 'x'), ('y', 'x')] elif self.sequence_out: - dims = [("locations",), ("locations",)] + dims = [('locations',), ('locations',)] else: - dims = [("lon",), ("lat",)] - lon_out = xr.DataArray(lon_out, dims=dims[0], name="lon", attrs=LON_CF_ATTRS) - lat_out = xr.DataArray(lat_out, dims=dims[1], name="lat", attrs=LAT_CF_ATTRS) + dims = [('lon',), ('lat',)] + lon_out = xr.DataArray(lon_out, dims=dims[0], name='lon', attrs=LON_CF_ATTRS) + lat_out = xr.DataArray(lat_out, dims=dims[1], name='lat', attrs=LAT_CF_ATTRS) if lat_out.ndim == 2: self.out_horiz_dims = lat_out.dims elif self.sequence_out: if lat_out.dims != lon_out.dims: raise ValueError( - "Regridder expects a locstream output, but the passed longitude " - "and latitude are not specified along the same dimension. " - f"(lon: {lon_out.dims}, lat: {lat_out.dims})" + 'Regridder expects a locstream output, but the passed longitude ' + 'and latitude are not specified along the same dimension. ' + f'(lon: {lon_out.dims}, lat: {lat_out.dims})' ) - self.out_horiz_dims = ("dummy",) + lat_out.dims + self.out_horiz_dims = ('dummy',) + lat_out.dims else: self.out_horiz_dims = (lat_out.dims[0], lon_out.dims[0]) @@ -1025,9 +1025,9 @@ def __init__( if set(self.out_horiz_dims).issuperset(crd.dims) } grid_mapping = { - var.attrs["grid_mapping"] + var.attrs['grid_mapping'] for var in ds_out.data_vars.values() - if "grid_mapping" in var.attrs + if 'grid_mapping' in var.attrs } if grid_mapping: self.out_coords.update({gm: ds_out[gm] for gm in grid_mapping if gm in ds_out}) @@ -1044,39 +1044,39 @@ def _init_para_regrid( kwargs: dict, ): # Check if we have bounds as variable and not coords, and add them to coords in both datasets - if "lon_b" in ds_out.data_vars: - ds_out = ds_out.set_coords(["lon_b", "lat_b"]) - if "lon_b" in ds_in.data_vars: - ds_in = ds_in.set_coords(["lon_b", "lat_b"]) - if not (set(self.out_horiz_dims) - {"dummy"}).issubset(ds_out.chunksizes.keys()): + if 'lon_b' in ds_out.data_vars: + ds_out = ds_out.set_coords(['lon_b', 'lat_b']) + if 'lon_b' in ds_in.data_vars: + ds_in = ds_in.set_coords(['lon_b', 'lat_b']) + if not (set(self.out_horiz_dims) - {'dummy'}).issubset(ds_out.chunksizes.keys()): raise ValueError( - "Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. " - "If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks." + 'Using `parallel=True` requires the output grid to have chunks along all spatial dimensions. ' + 'If the dataset has no variables, consider adding an all-True spatial mask with appropriate chunks.' ) # Drop everything in ds_out except mask or create mask if None. 
This is to prevent map_blocks loading unnecessary large data if self.sequence_out: ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars) ds_out = ds_out.drop_dims(ds_out_dims_drop) else: - if "mask" in ds_out: + if 'mask' in ds_out: mask = ds_out.mask ds_out = ds_out.coords.to_dataset() - ds_out["mask"] = mask + ds_out['mask'] = mask else: ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims]) ds_out = ds_out.coords.to_dataset() mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks) - ds_out["mask"] = (self.out_horiz_dims, mask) + ds_out['mask'] = (self.out_horiz_dims, mask) ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference( - ["longitude", "latitude"] + ['longitude', 'latitude'] ) ds_out = ds_out.cf.drop_dims(ds_out_dims_drop) # Drop unnecessary variables in ds_in to save memory if not self.sequence_in: # Drop unnecessary dims - ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(["longitude", "latitude"]) + ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(['longitude', 'latitude']) ds_in = ds_in.cf.drop_dims(ds_in_dims_drop) # Drop unnecessary vars @@ -1087,35 +1087,35 @@ def _init_para_regrid( ds_in = ds_in.compute() # if bounds in ds_out, we switch to cf bounds for map_blocks - if "lon_b" in ds_out and (ds_out.lon_b.ndim == ds_out.cf["longitude"].ndim): + if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim): ds_out = ds_out.assign_coords( lon_bounds=cfxr.vertices_to_bounds( - ds_out.lon_b, ("bounds", *ds_out.cf["longitude"].dims) + ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims) ), lat_bounds=cfxr.vertices_to_bounds( - ds_out.lat_b, ("bounds", *ds_out.cf["latitude"].dims) + ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims) ), ) # Make cf-xarray aware of the new bounds - ds_out[ds_out.cf["longitude"].name].attrs["bounds"] = "lon_bounds" - ds_out[ds_out.cf["latitude"].name].attrs["bounds"] = "lat_bounds" + ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds' + ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds' ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims) # rename dims to avoid map_blocks confusing ds_in and ds_out dims. 
if self.sequence_in: - ds_in = ds_in.rename({self.in_horiz_dims[0]: "x_in"}) + ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'}) else: - ds_in = ds_in.rename({self.in_horiz_dims[0]: "y_in", self.in_horiz_dims[1]: "x_in"}) + ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'}) if self.sequence_out: - ds_out = ds_out.rename({self.out_horiz_dims[1]: "x_out"}) - out_chunks = ds_out.chunks.get("x_out") + ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'}) + out_chunks = ds_out.chunks.get('x_out') else: ds_out = ds_out.rename( - {self.out_horiz_dims[0]: "y_out", self.out_horiz_dims[1]: "x_out"} + {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'} ) - out_chunks = [ds_out.chunks.get(k) for k in ["y_out", "x_out"]] + out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']] - weights_dims = ("y_out", "x_out", "y_in", "x_in") + weights_dims = ('y_out', 'x_out', 'y_in', 'x_in') templ = sps.zeros((self.shape_out + self.shape_in)) w_templ = xr.DataArray(templ, dims=weights_dims).chunk( out_chunks @@ -1136,14 +1136,14 @@ def _init_para_regrid( kwargs=kwargs, template=w_templ, ) - w = w.compute(scheduler="processes") + w = w.compute(scheduler='processes') weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:]) - weights.name = "weights" + weights.name = 'weights' self.weights = weights # follows legacy logic of writing weights if filename is provided - if "filename" in kwargs: - filename = kwargs["filename"] + if 'filename' in kwargs: + filename = kwargs['filename'] else: filename = None if filename is not None and not self.reuse_weights: @@ -1158,10 +1158,10 @@ def _format_xroutput(self, out, new_dims=None): out = out.rename({nd: od for nd, od in zip(new_dims, self.out_horiz_dims)}) out = out.assign_coords(**self.out_coords) - out.attrs["regrid_method"] = self.method + out.attrs['regrid_method'] = self.method if self.sequence_out: - out = out.squeeze(dim="dummy") + out = out.squeeze(dim='dummy') return out @@ -1177,7 +1177,7 @@ def __init__( reuse_weights: bool = False, weights: Optional[sps.coo_matrix | dict | str | Dataset] = None, ignore_degenerate: bool = False, - geom_dim_name: str = "geom", + geom_dim_name: str = 'geom', ): """Compute the exact average of a gridded array over a geometry. @@ -1267,12 +1267,12 @@ def __init__( # Create an output locstream so that the regridder knows the output shape and coords. # Latitude and longitude coordinates are the polygon centroid. lon_out, lat_out = _get_lon_lat(ds_in) - if hasattr(lon_out, "name"): + if hasattr(lon_out, 'name'): self._lon_out_name = lon_out.name self._lat_out_name = lat_out.name else: - self._lon_out_name = "lon" - self._lat_out_name = "lat" + self._lon_out_name = 'lon' + self._lat_out_name = 'lat' # Check length of polys segments self._check_polys_length(polys) @@ -1283,14 +1283,14 @@ def __init__( # We put names 'lon' and 'lat' so ds_to_ESMFlocstream finds them easily. # _lon_out_name and _lat_out_name are used on the output anyway. 
- ds_out = {"lon": self._lon_out, "lat": self._lat_out} + ds_out = {'lon': self._lon_out, 'lat': self._lat_out} locstream_out, shape_out, _ = ds_to_ESMFlocstream(ds_out) # BaseRegridder with custom-computed weights and dummy out grid super().__init__( grid_in, locstream_out, - "conservative", + 'conservative', input_dims=input_dims, weights=weights, filename=filename, @@ -1311,7 +1311,7 @@ def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None: poly_segments.extend([LineString(b[k : k + 2]).length for k in range(len(b) - 1)]) if np.any(np.array(poly_segments) > threshold): warnings.warn( - f"`polys` contains large (> {threshold}°) segments. This could lead to errors over large regions. For a more accurate average, segmentize (densify) your shapes with `shapely.segmentize(polys, {threshold})`", + f'`polys` contains large (> {threshold}°) segments. This could lead to errors over large regions. For a more accurate average, segmentize (densify) your shapes with `shapely.segmentize(polys, {threshold})`', UserWarning, stacklevel=2, ) @@ -1323,7 +1323,7 @@ def _compute_weights_and_area(self, mesh_out: Mesh) -> tuple[DataArray, Any]: regrid = esmf_regrid_build( self.grid_in, mesh_out, - method="conservative", + method='conservative', ignore_degenerate=self.ignore_degenerate, ) @@ -1374,7 +1374,7 @@ def _compute_weights(self) -> DataArray: w_int, area_int = self._compute_weights_and_area(mesh_int) # Append weights from holes as negative weights - w = xr.concat((w, -w_int), "out_dim") + w = xr.concat((w, -w_int), 'out_dim') # Append areas area = np.concatenate([area, area_int]) @@ -1392,12 +1392,12 @@ def w(self) -> xr.DataArray: """ s = self.shape_out[1:2] + self.shape_in data = self.weights.data.reshape(s) - dims = self.geom_dim_name, "y_in", "x_in" + dims = self.geom_dim_name, 'y_in', 'x_in' return xr.DataArray(data, dims=dims) def _get_default_filename(self) -> str: # e.g. 
bilinear_400x600_300x400.nc - filename = "spatialavg_{0}x{1}_{2}.nc".format( + filename = 'spatialavg_{0}x{1}_{2}.nc'.format( self.shape_in[0], self.shape_in[1], self.n_out ) @@ -1405,17 +1405,17 @@ def _get_default_filename(self) -> str: def __repr__(self) -> str: info = ( - f"xESMF SpatialAverager \n" - f"Weight filename: {self.filename} \n" - f"Reuse pre-computed weights: {self.reuse_weights} \n" - f"Input grid shape: {self.shape_in} \n" - f"Output list length: {self.n_out} \n" + f'xESMF SpatialAverager \n' + f'Weight filename: {self.filename} \n' + f'Reuse pre-computed weights: {self.reuse_weights} \n' + f'Input grid shape: {self.shape_in} \n' + f'Output list length: {self.n_out} \n' ) return info def _format_xroutput(self, out: DataArray | Dataset, new_dims=None) -> DataArray | Dataset: - out = out.squeeze(dim="dummy") + out = out.squeeze(dim='dummy') # rename dimension name to match output grid out = out.rename(locations=self.geom_dim_name) @@ -1424,5 +1424,5 @@ def _format_xroutput(self, out: DataArray | Dataset, new_dims=None) -> DataArray # extra coordinates are automatically tracked by apply_ufunc out.coords[self._lon_out_name] = xr.DataArray(self._lon_out, dims=(self.geom_dim_name,)) out.coords[self._lat_out_name] = xr.DataArray(self._lat_out, dims=(self.geom_dim_name,)) - out.attrs["regrid_method"] = self.method + out.attrs['regrid_method'] = self.method return out diff --git a/xesmf/smm.py b/xesmf/smm.py index f226705c..f46b0f86 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -47,12 +47,12 @@ def read_weights( return _parse_coords_and_values(weights, n_in, n_out) if isinstance(weights, sps.COO): - return xr.DataArray(weights, dims=("out_dim", "in_dim"), name="weights") + return xr.DataArray(weights, dims=('out_dim', 'in_dim'), name='weights') if isinstance(weights, xr.DataArray): # type: ignore[no-untyped-def] return weights - raise ValueError(f"Weights of type {type(weights)} not understood.") + raise ValueError(f'Weights of type {type(weights)} not understood.') def _parse_coords_and_values( @@ -81,33 +81,33 @@ def _parse_coords_and_values( if isinstance(indata, (str, Path, xr.Dataset)): if not isinstance(indata, xr.Dataset): if not Path(indata).exists(): - raise IOError(f"Weights file not found on disk.\n{indata}") + raise IOError(f'Weights file not found on disk.\n{indata}') ds_w = xr.open_dataset(indata) # type: ignore[no-untyped-def] else: ds_w = indata - if not {"col", "row", "S"}.issubset(ds_w.variables): + if not {'col', 'row', 'S'}.issubset(ds_w.variables): raise ValueError( - "Weights dataset should have variables `col`, `row` and `S` storing the indices " - "and values of weights." + 'Weights dataset should have variables `col`, `row` and `S` storing the indices ' + 'and values of weights.' ) - col = ds_w["col"].values - 1 # type: ignore[no-untyped-def] - row = ds_w["row"].values - 1 # type: ignore[no-untyped-def] - s = ds_w["S"].values # type: ignore[no-untyped-def] + col = ds_w['col'].values - 1 # type: ignore[no-untyped-def] + row = ds_w['row'].values - 1 # type: ignore[no-untyped-def] + s = ds_w['S'].values # type: ignore[no-untyped-def] elif isinstance(indata, dict): # type: ignore - if not {"col_src", "row_dst", "weights"}.issubset(indata.keys()): + if not {'col_src', 'row_dst', 'weights'}.issubset(indata.keys()): raise ValueError( - "Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing " - "the indices and values of weights." 
+ 'Weights dictionary should have keys `col_src`, `row_dst` and `weights` storing ' + 'the indices and values of weights.' ) - col = indata["col_src"] - 1 - row = indata["row_dst"] - 1 - s = indata["weights"] + col = indata['col_src'] - 1 + row = indata['row_dst'] - 1 + s = indata['weights'] crds = np.stack([row, col]) - return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=("out_dim", "in_dim"), name="weights") + return xr.DataArray(sps.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights') def check_shapes( @@ -141,8 +141,8 @@ def check_shapes( # COO matrix is fast with F-ordered array but slow with C-array, so we # take in a C-ordered and then transpose) # (CSR or CRS matrix is fast with C-ordered array but slow with F-array) - if hasattr(indata, "flags") and not indata.flags["C_CONTIGUOUS"]: - warnings.warn("Input array is not C_CONTIGUOUS. " "Will affect performance.") + if hasattr(indata, 'flags') and not indata.flags['C_CONTIGUOUS']: + warnings.warn('Input array is not C_CONTIGUOUS. ' 'Will affect performance.') # Limitation from numba : some big-endian dtypes are not supported. try: @@ -150,8 +150,8 @@ def check_shapes( nb.from_dtype(weights.dtype) # type: ignore except (NotImplementedError, nb.core.errors.NumbaError): # type: ignore warnings.warn( - "Input array has a dtype not supported by sparse and numba." - "Computation will fall back to scipy." + 'Input array has a dtype not supported by sparse and numba.' + 'Computation will fall back to scipy.' ) # get input shape information @@ -159,15 +159,15 @@ def check_shapes( if shape_horiz != shape_in: raise ValueError( - f"The horizontal shape of input data is {shape_horiz}, different from that " - f"of the regridder {shape_in}!" + f'The horizontal shape of input data is {shape_horiz}, different from that ' + f'of the regridder {shape_in}!' ) if shape_in[0] * shape_in[1] != weights.shape[1]: - raise ValueError("ny_in * nx_in should equal to weights.shape[1]") + raise ValueError('ny_in * nx_in should equal to weights.shape[1]') if shape_out[0] * shape_out[1] != weights.shape[0]: - raise ValueError("ny_out * nx_out should equal to weights.shape[0]") + raise ValueError('ny_out * nx_out should equal to weights.shape[0]') def apply_weights( @@ -203,7 +203,7 @@ def apply_weights( nb.from_dtype(indata.dtype) # type: ignore nb.from_dtype(weights.dtype) # type: ignore except (NotImplementedError, nb.core.errors.NumbaError): # type: ignore - indata = indata.astype(" HUGE, 0.0, B) lamc = np.arcsin(B) / PI_180 # But this equation accepts 4 solutions for a given B, {l, 180-l, l+180, 360-l } - # We have to pickup the "correct" root. + # We have to pickup the 'correct' root. # One way is simply to demand lamc to be continuous with lam on the equator phi=0 # I am sure there is a more mathematically concrete way to do this. lamc = np.where((lamg - lon_bp > 90) & (lamg - lon_bp <= 180), 180 - lamc, lamc) @@ -369,11 +369,11 @@ def _generate_bipolar_cap_mesh( ): # Define a (lon,lat) coordinate mesh on the Northern hemisphere of the globe sphere # such that the resolution of latg matches the desired resolution of the final grid along the symmetry meridian - print("Generating bipolar grid bounded at latitude ", lat0_bp) + print('Generating bipolar grid bounded at latitude ', lat0_bp) if Nj_ncap % 2 != 0 and ensure_nj_even: - print(" Supergrid has an odd number of area cells!") + print(' Supergrid has an odd number of area cells!') if ensure_nj_even: - print(" The number of j's is not even. 
Fixing this by cutting one row.") + print(' The number of j's is not even. Fixing this by cutting one row.') Nj_ncap = Nj_ncap - 1 lon_g = lon_bp + np.arange(Ni + 1) * 360.0 / float(Ni) @@ -384,7 +384,7 @@ def _generate_bipolar_cap_mesh( lams, phis, h_i_inv, h_j_inv = _bipolar_projection(lamg, phig, lon_bp, rp) h_i_inv = h_i_inv[:, :-1] * 2 * np.pi / float(Ni) h_j_inv = h_j_inv[:-1, :] * PI_180 * (90 - lat0_bp) / float(Nj_ncap) - print(" number of js=", phis.shape[0]) + print(' number of js=', phis.shape[0]) return lams, phis, h_i_inv, h_j_inv From 3c494bdc4fc2a378b23db14b66b523c29e5c55e7 Mon Sep 17 00:00:00 2001 From: Siddhesh Agarwal Date: Fri, 10 Nov 2023 14:32:41 +0530 Subject: [PATCH 10/16] formatted --- xesmf/backend.py | 2 +- xesmf/util.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xesmf/backend.py b/xesmf/backend.py index ff2aaa92..ad0871dc 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -25,8 +25,8 @@ import ESMF import numpy as np -import numpy.typing as npt import numpy.lib.recfunctions as nprec +import numpy.typing as npt def warn_f_contiguous(a: npt.NDArray[np.floating[Any]]) -> None: diff --git a/xesmf/util.py b/xesmf/util.py index 3fdc5018..5f68cf59 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -373,13 +373,13 @@ def _generate_bipolar_cap_mesh( if Nj_ncap % 2 != 0 and ensure_nj_even: print(' Supergrid has an odd number of area cells!') if ensure_nj_even: - print(' The number of j's is not even. Fixing this by cutting one row.') + print(' The number of j\'s is not even. Fixing this by cutting one row.') Nj_ncap = Nj_ncap - 1 lon_g = lon_bp + np.arange(Ni + 1) * 360.0 / float(Ni) - lamg: float = np.tile(lon_g, (Nj_ncap + 1, 1)) + lamg = np.tile(lon_g, (Nj_ncap + 1, 1)) latg0_cap = lat0_bp + np.arange(Nj_ncap + 1) * (90 - lat0_bp) / float(Nj_ncap) - phig: float = np.tile(latg0_cap.reshape((Nj_ncap + 1, 1)), (1, Ni + 1)) + phig = np.tile(latg0_cap.reshape((Nj_ncap + 1, 1)), (1, Ni + 1)) rp = np.tan(0.5 * (90 - lat0_bp) * PI_180) lams, phis, h_i_inv, h_j_inv = _bipolar_projection(lamg, phig, lon_bp, rp) h_i_inv = h_i_inv[:, :-1] * 2 * np.pi / float(Ni) From 5be7f20489a24eab6af5fb06db44e71c9512adc0 Mon Sep 17 00:00:00 2001 From: David Huard Date: Fri, 10 Nov 2023 08:41:14 -0500 Subject: [PATCH 11/16] run pre-commit --- .github/dependabot.yml | 18 +-- .github/workflows/ci.yaml | 176 ++++++++++++------------- .github/workflows/linting.yaml | 26 ++-- .github/workflows/pypi.yaml | 96 +++++++------- .pre-commit-config.yaml | 114 ++++++++-------- binder/environment.yml | 26 ++-- ci/environment-upstream-dev.yml | 34 ++--- ci/environment.yml | 32 ++--- codecov.yml | 26 ++-- doc/notebooks/Compare_algorithms.ipynb | 47 +++---- doc/notebooks/Masking.ipynb | 14 +- doc/notebooks/Pure_numpy.ipynb | 14 +- readthedocs.yml | 14 +- 13 files changed, 319 insertions(+), 318 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b4b3fa44..2eacebc2 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,11 +1,11 @@ version: 2 updates: - # - package-ecosystem: pip - # directory: "/" - # schedule: - # interval: daily - - package-ecosystem: 'github-actions' - directory: '/' - schedule: - # Check for updates once a week - interval: 'weekly' + # - package-ecosystem: pip + # directory: "/" + # schedule: + # interval: daily + - package-ecosystem: 'github-actions' + directory: '/' + schedule: + # Check for updates once a week + interval: 'weekly' diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 
cc84cb70..9b970a3c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,94 +1,94 @@ name: CI on: - push: - branches: - - master - pull_request: - branches: - - '*' - schedule: - - cron: '0 0 * * *' # Daily “At 00:00” - workflow_dispatch: # allows you to trigger manually + push: + branches: + - master + pull_request: + branches: + - '*' + schedule: + - cron: '0 0 * * *' # Daily “At 00:00” + workflow_dispatch: # allows you to trigger manually jobs: - build: - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - strategy: - fail-fast: false - matrix: - include: - # Warning: Unless in quotations, numbers below are read as floats. 3.10 < 3.2 - - python-version: '3.8' - esmf-version: 8.2 - - python-version: '3.9' - esmf-version: 8.3 - - python-version: '3.10' - esmf-version: 8.4 - - python-version: '3.11' - esmf-version: 8.4 - steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.12.0 - with: - access_token: ${{ github.token }} - - name: Checkout source - uses: actions/checkout@v4 - - name: Create conda environment - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - micromamba-version: 'latest' - environment-file: ci/environment.yml - extra-specs: | - python=${{ matrix.python-version }} - esmpy=${{ matrix.esmf-version }} - - name: Install Xesmf (editable) - run: | - python -m pip install --no-deps -e . - - name: Conda list information - run: | - conda env list - conda list - - name: Run tests - run: | - python -m pytest --cov=./ --cov-report=xml --verbose - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3.1.3 - with: - file: ./coverage.xml - fail_ci_if_error: false + build: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + include: + # Warning: Unless in quotations, numbers below are read as floats. 3.10 < 3.2 + - python-version: '3.8' + esmf-version: 8.2 + - python-version: '3.9' + esmf-version: 8.3 + - python-version: '3.10' + esmf-version: 8.4 + - python-version: '3.11' + esmf-version: 8.4 + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.12.0 + with: + access_token: ${{ github.token }} + - name: Checkout source + uses: actions/checkout@v4 + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@main + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment.yml + extra-specs: | + python=${{ matrix.python-version }} + esmpy=${{ matrix.esmf-version }} + - name: Install Xesmf (editable) + run: | + python -m pip install --no-deps -e . + - name: Conda list information + run: | + conda env list + conda list + - name: Run tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3.1.3 + with: + file: ./coverage.xml + fail_ci_if_error: false - upstream-dev: - name: upstream-dev - runs-on: ubuntu-latest - defaults: - run: - shell: bash -l {0} - steps: - - name: Cancel previous runs - uses: styfle/cancel-workflow-action@0.12.0 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Create conda environment - uses: mamba-org/provision-with-micromamba@v16 - with: - cache-downloads: true - micromamba-version: 'latest' - environment-file: ci/environment-upstream-dev.yml - extra-specs: | - python=3.10 - - name: Install Xesmf (editable) - run: | - python -m pip install -e . 
- - name: Conda list information - run: | - conda env list - conda list - - name: Run tests - run: | - python -m pytest --cov=./ --cov-report=xml --verbose + upstream-dev: + name: upstream-dev + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.12.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Create conda environment + uses: mamba-org/provision-with-micromamba@v16 + with: + cache-downloads: true + micromamba-version: 'latest' + environment-file: ci/environment-upstream-dev.yml + extra-specs: | + python=3.10 + - name: Install Xesmf (editable) + run: | + python -m pip install -e . + - name: Conda list information + run: | + conda env list + conda list + - name: Run tests + run: | + python -m pytest --cov=./ --cov-report=xml --verbose diff --git a/.github/workflows/linting.yaml b/.github/workflows/linting.yaml index 2391406f..c7c018ea 100644 --- a/.github/workflows/linting.yaml +++ b/.github/workflows/linting.yaml @@ -1,18 +1,18 @@ name: linting on: - push: - branches: - - master - pull_request: - branches: '*' + push: + branches: + - master + pull_request: + branches: '*' jobs: - linting: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - uses: pre-commit/action@v3.0.0 + linting: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.x' + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/pypi.yaml b/.github/workflows/pypi.yaml index fde7fd5c..962a1b7b 100644 --- a/.github/workflows/pypi.yaml +++ b/.github/workflows/pypi.yaml @@ -1,55 +1,55 @@ name: Publish to PyPI on: - pull_request: - push: - branches: - - master - release: - types: - - published + pull_request: + push: + branches: + - master + release: + types: + - published defaults: - run: - shell: bash + run: + shell: bash jobs: - packages: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - - name: Get tags - run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* - - - name: Install build tools - run: | - python -m pip install --upgrade build - - - name: Build binary wheel - run: python -m build --sdist --wheel . --outdir dist - - - name: CheckFiles - run: | - ls dist - python -m pip install --upgrade check-manifest - check-manifest --verbose - - - name: Test wheels - run: | - # We cannot run this step b/c esmpy is not available on PyPI - # cd dist && python -m pip install *.whl && cd .. - python -m pip install --upgrade build twine - python -m twine check dist/* - - - name: Publish a Python distribution to PyPI - if: success() && github.event_name == 'release' - uses: pypa/gh-action-pypi-publish@release/v1 - with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN }} + packages: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Get tags + run: git fetch --depth=1 origin +refs/tags/*:refs/tags/* + + - name: Install build tools + run: | + python -m pip install --upgrade build + + - name: Build binary wheel + run: python -m build --sdist --wheel . 
--outdir dist + + - name: CheckFiles + run: | + ls dist + python -m pip install --upgrade check-manifest + check-manifest --verbose + + - name: Test wheels + run: | + # We cannot run this step b/c esmpy is not available on PyPI + # cd dist && python -m pip install *.whl && cd .. + python -m pip install --upgrade build twine + python -m twine check dist/* + + - name: Publish a Python distribution to PyPI + if: success() && github.event_name == 'release' + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.PYPI_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a190da46..f148794f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,70 +1,70 @@ default_language_version: - python: python3 + python: python3 repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - - id: check-docstring-first - - id: check-json - - id: check-yaml - - id: double-quote-string-fixer + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-json + - id: check-yaml + - id: double-quote-string-fixer - - repo: https://github.com/psf/black - rev: 23.9.1 - hooks: - - id: black + - repo: https://github.com/psf/black + rev: 23.9.1 + hooks: + - id: black - - repo: https://github.com/keewis/blackdoc - rev: v0.3.8 - hooks: - - id: blackdoc + - repo: https://github.com/keewis/blackdoc + rev: v0.3.8 + hooks: + - id: blackdoc - - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 - - repo: https://github.com/asottile/seed-isort-config - rev: v2.2.0 - hooks: - - id: seed-isort-config - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 - hooks: - - id: prettier + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.3 + hooks: + - id: prettier - - repo: https://github.com/deathbeds/prenotebook - rev: f5bdb72a400f1a56fe88109936c83aa12cc349fa - hooks: - - id: prenotebook - args: - [ - '--keep-output', - '--keep-metadata', - '--keep-execution-count', - '--keep-empty', - ] + - repo: https://github.com/deathbeds/prenotebook + rev: f5bdb72a400f1a56fe88109936c83aa12cc349fa + hooks: + - id: prenotebook + args: + [ + '--keep-output', + '--keep-metadata', + '--keep-execution-count', + '--keep-empty', + ] - - repo: https://github.com/tox-dev/pyproject-fmt - rev: 1.2.0 - hooks: - - id: pyproject-fmt + - repo: https://github.com/tox-dev/pyproject-fmt + rev: 1.2.0 + hooks: + - id: pyproject-fmt ci: - autofix_commit_msg: | - [pre-commit.ci] auto fixes from pre-commit.com hooks + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks - for more information, see https://pre-commit.ci - autofix_prs: true - autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' - autoupdate_schedule: monthly - skip: [] - submodules: false + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' + autoupdate_schedule: monthly + skip: [] + submodules: false diff --git 
a/binder/environment.yml b/binder/environment.yml index bcb6ee6e..42e1e14b 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -1,15 +1,15 @@ channels: - - conda-forge + - conda-forge dependencies: - - python=3.7 - - esmpy==7.1.0r - - xarray - - dask - - numpy - - scipy - - shapely - - matplotlib - - cartopy - - cf_xarray>=0.3.1 - - pip: - - xesmf==0.2.2 + - python=3.7 + - esmpy==7.1.0r + - xarray + - dask + - numpy + - scipy + - shapely + - matplotlib + - cartopy + - cf_xarray>=0.3.1 + - pip: + - xesmf==0.2.2 diff --git a/ci/environment-upstream-dev.yml b/ci/environment-upstream-dev.yml index 659aee2a..3d0f2590 100644 --- a/ci/environment-upstream-dev.yml +++ b/ci/environment-upstream-dev.yml @@ -1,20 +1,20 @@ name: xesmf channels: - - conda-forge + - conda-forge dependencies: - - cftime - - codecov - - dask - - esmpy - - numba - - numpy - - pip - - pre-commit - - pydap - - pytest - - pytest-cov - - shapely - - sparse>=0.8.0 - - pip: - - git+https://github.com/pydata/xarray.git - - git+https://github.com/xarray-contrib/cf-xarray.git + - cftime + - codecov + - dask + - esmpy + - numba + - numpy + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - shapely + - sparse>=0.8.0 + - pip: + - git+https://github.com/pydata/xarray.git + - git+https://github.com/xarray-contrib/cf-xarray.git diff --git a/ci/environment.yml b/ci/environment.yml index b596fcaa..8f57d1ba 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -1,19 +1,19 @@ name: xesmf channels: - - conda-forge + - conda-forge dependencies: - - cf_xarray>=0.3.1 - - cftime - - codecov - - dask - - esmpy - - numba - - numpy - - pip - - pre-commit - - pydap - - pytest - - pytest-cov - - shapely - - sparse>=0.8.0 - - xarray>=0.17.0 + - cf_xarray>=0.3.1 + - cftime + - codecov + - dask + - esmpy + - numba + - numpy + - pip + - pre-commit + - pydap + - pytest + - pytest-cov + - shapely + - sparse>=0.8.0 + - xarray>=0.17.0 diff --git a/codecov.yml b/codecov.yml index 1e11ad52..d241151b 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,20 +1,20 @@ codecov: - require_ci_to_pass: no - max_report_age: off + require_ci_to_pass: no + max_report_age: off comment: false ignore: - - 'xesmf/tests/*' - - 'setup.py' + - 'xesmf/tests/*' + - 'setup.py' coverage: - precision: 2 - round: down - status: - project: - default: - target: 95 - informational: true - patch: off - changes: off + precision: 2 + round: down + status: + project: + default: + target: 95 + informational: true + patch: off + changes: off diff --git a/doc/notebooks/Compare_algorithms.ipynb b/doc/notebooks/Compare_algorithms.ipynb index 0e194893..f5dcb746 100644 --- a/doc/notebooks/Compare_algorithms.ipynb +++ b/doc/notebooks/Compare_algorithms.ipynb @@ -8,12 +8,12 @@ "\n", "xESMF exposes five different regridding algorithms from the ESMF library:\n", "\n", - "- `bilinear`: `ESMF.RegridMethod.BILINEAR`\n", - "- `conservative`: `ESMF.RegridMethod.CONSERVE`\n", - "- `conservative_normed`: `ESMF.RegridMethod.CONSERVE`\n", - "- `patch`: `ESMF.RegridMethod.PATCH`\n", - "- `nearest_s2d`: `ESMF.RegridMethod.NEAREST_STOD`\n", - "- `nearest_d2s`: `ESMF.RegridMethod.NEAREST_DTOS`\n", + "- `bilinear`: `ESMF.RegridMethod.BILINEAR`\n", + "- `conservative`: `ESMF.RegridMethod.CONSERVE`\n", + "- `conservative_normed`: `ESMF.RegridMethod.CONSERVE`\n", + "- `patch`: `ESMF.RegridMethod.PATCH`\n", + "- `nearest_s2d`: `ESMF.RegridMethod.NEAREST_STOD`\n", + "- `nearest_d2s`: `ESMF.RegridMethod.NEAREST_DTOS`\n", "\n", "where `conservative_normed` is just the `conservative` method with 
the\n", "normalization set to `ESMF.NormType.FRACAREA` instead of the default\n", @@ -23,26 +23,27 @@ "\n", "## Notes\n", "\n", - "- `bilinear` and `conservative` should be the most commonly used methods. They\n", - " are both monotonic (i.e. will not create new maximum/minimum).\n", - "- Nearest neighbour methods, either source to destination (s2d) or destination\n", - " to source (d2s), could be useful in special cases. Keep in mind that d2s is\n", - " highly non-monotonic.\n", - "- Patch is ESMF's unique method, producing highly smooth results but quite slow.\n", - "- From the ESMF documentation:\n", + "- `bilinear` and `conservative` should be the most commonly used methods. They\n", + " are both monotonic (i.e. will not create new maximum/minimum).\n", + "- Nearest neighbour methods, either source to destination (s2d) or destination\n", + " to source (d2s), could be useful in special cases. Keep in mind that d2s is\n", + " highly non-monotonic.\n", + "- Patch is ESMF's unique method, producing highly smooth results but quite\n", + " slow.\n", + "- From the ESMF documentation:\n", "\n", - " > The weight $w_{ij}$ for a particular source cell $i$ and destination cell\n", - " > $j$ are calculated as $w_{ij}=f_{ij} * A_{si}/A_{dj}$. In this equation\n", - " > $f_{ij}$ is the fraction of the source cell $i$ contributing to destination\n", - " > cell $j$, and $A_{si}$ and $A_{dj}$ are the areas of the source and\n", - " > destination cells.\n", + " > The weight $w_{ij}$ for a particular source cell $i$ and destination cell\n", + " > $j$ are calculated as $w_{ij}=f_{ij} * A_{si}/A_{dj}$. In this equation\n", + " > $f_{ij}$ is the fraction of the source cell $i$ contributing to\n", + " > destination cell $j$, and $A_{si}$ and $A_{dj}$ are the areas of the\n", + " > source and destination cells.\n", "\n", - " For `conservative_normed`,\n", + " For `conservative_normed`,\n", "\n", - " > ... then the weights are further divided by the destination fraction. In\n", - " > other words, in that case $w_{ij}=f_{ij} * A_{si}/(A_{dj}*D_j)$ where $D_j$\n", - " > is fraction of the destination cell that intersects the unmasked source\n", - " > grid.\n", + " > ... then the weights are further divided by the destination fraction. In\n", + " > other words, in that case $w_{ij}=f_{ij} * A_{si}/(A_{dj}*D_j)$ where\n", + " > $D_j$ is fraction of the destination cell that intersects the unmasked\n", + " > source grid.\n", "\n", "Detailed explanations are available on\n", "[ESMPy documentation](http://www.earthsystemmodeling.org/esmf_releases/last_built/esmpy_doc/html/api.html#regridding).\n", diff --git a/doc/notebooks/Masking.ipynb b/doc/notebooks/Masking.ipynb index 0dcc213d..8c134190 100644 --- a/doc/notebooks/Masking.ipynb +++ b/doc/notebooks/Masking.ipynb @@ -499,15 +499,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "- mask can only be 2D (ESMF design) so regridding a 3D field requires to\n", - " generate regridding weights for each vertical level.\n", + "- mask can only be 2D (ESMF design) so regridding a 3D field requires to\n", + " generate regridding weights for each vertical level.\n", "\n", - "- conservative method will give you a normalization by the total area of the\n", - " target cell. Except for some specific cases, you probably want to use\n", - " conservative_normed.\n", + "- conservative method will give you a normalization by the total area of the\n", + " target cell. 
Except for some specific cases, you probably want to use\n", + " conservative_normed.\n", "\n", - "- results with other methods (e.g. bilinear) may not give masks consistent with\n", - " the coarse grid.\n" + "- results with other methods (e.g. bilinear) may not give masks consistent\n", + " with the coarse grid.\n" ] }, { diff --git a/doc/notebooks/Pure_numpy.ipynb b/doc/notebooks/Pure_numpy.ipynb index a04ab748..308eef62 100644 --- a/doc/notebooks/Pure_numpy.ipynb +++ b/doc/notebooks/Pure_numpy.ipynb @@ -236,10 +236,10 @@ "We use the previous input data, but now assume it is on a curvilinear grid\n", "described by 2D arrays. We also computed the cell corners, for two purposes:\n", "\n", - "- Visualization with `plt.pcolormesh` (using cell centers will miss one\n", - " row&column)\n", - "- Conservative regridding with xESMF (corner information is required for\n", - " conservative method)\n" + "- Visualization with `plt.pcolormesh` (using cell centers will miss one\n", + " row&column)\n", + "- Conservative regridding with xESMF (corner information is required for\n", + " conservative method)\n" ] }, { @@ -446,9 +446,9 @@ "source": [ "All $2 \\times 2\\times 2 = 8$ combinations would work:\n", "\n", - "- Input grid: `xarray.DataSet` or `dict`\n", - "- Output grid: `xarray.DataSet` or `dict`\n", - "- Input data: `xarray.DataArray` or `numpy.ndarray`\n", + "- Input grid: `xarray.DataSet` or `dict`\n", + "- Output grid: `xarray.DataSet` or `dict`\n", + "- Input data: `xarray.DataArray` or `numpy.ndarray`\n", "\n", "The output data type will be the same as input data.\n" ] diff --git a/readthedocs.yml b/readthedocs.yml index f0c4a4ae..c605ae18 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,12 +1,12 @@ version: 2 build: - os: ubuntu-22.04 - tools: - python: '3.9' + os: ubuntu-22.04 + tools: + python: '3.9' python: - install: - - requirements: doc/requirements.txt - - method: pip - path: . + install: + - requirements: doc/requirements.txt + - method: pip + path: . From 76ff3f9f01fa48911cb8ea9989b81fbf4304d989 Mon Sep 17 00:00:00 2001 From: David Huard Date: Fri, 10 Nov 2023 09:28:31 -0500 Subject: [PATCH 12/16] Remove typing statements not supported by 3.8 --- xesmf/data.py | 8 ++++---- xesmf/frontend.py | 14 +++++++------- xesmf/smm.py | 6 +++--- xesmf/util.py | 10 ++++------ 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/xesmf/data.py b/xesmf/data.py index 0505ae4e..a20918ac 100644 --- a/xesmf/data.py +++ b/xesmf/data.py @@ -2,7 +2,7 @@ Standard test data for regridding benchmark. """ -from typing import Any +from typing import Union import numpy as np import numpy.typing as npt @@ -10,9 +10,9 @@ def wave_smooth( # type: ignore - lon: npt.NDArray[np.floating[Any]] | xarray.DataArray, - lat: npt.NDArray[np.floating[Any]] | xarray.DataArray, -) -> npt.NDArray[np.floating[Any]] | xarray.DataArray: + lon: Union[npt.NDArray, xarray.DataArray], + lat: Union[npt.NDArray, xarray.DataArray], +) -> Union[npt.NDArray, xarray.DataArray]: """ Spherical harmonic with low frequency. 
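[Editor's context for PATCH 12, not part of the patch itself: the PEP 604 `X | Y` union syntax is evaluated when a function is defined, so it raises `TypeError` at import time on Python 3.8/3.9 (the operator exists only from 3.10 on), while `typing.Union` is valid on every supported version. A minimal sketch of the failure mode, with illustrative names:]

```python
# On Python 3.8/3.9 the `|` form below fails at import time, because the
# annotation is evaluated eagerly and `type.__or__` does not exist before
# 3.10 (unless `from __future__ import annotations` defers evaluation):
#
#     def wave(lon: np.ndarray | float) -> np.ndarray | float: ...
#     TypeError: unsupported operand type(s) for |: 'type' and 'type'
#
# The Union spelling is equivalent and works on all supported versions:
from typing import Union

import numpy as np


def wave(lon: Union[np.ndarray, float]) -> Union[np.ndarray, float]:
    return np.cos(2 * np.asarray(lon))
```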
diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 87b4adcd..467bd7df 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -3,7 +3,7 @@ """ import warnings -from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple +from typing import Any, Dict, Hashable, List, Literal, Optional, Sequence, Tuple, Union import cf_xarray as cfxr import numpy as np @@ -270,8 +270,8 @@ def polys_to_ESMFmesh(polys) -> tuple[Mesh, tuple[Literal[1], int]]: class BaseRegridder(object): def __init__( self, - grid_in: Grid | LocStream | Mesh, - grid_out: Grid | LocStream | Mesh, + grid_in: Union[Grid, LocStream, Mesh], + grid_out: Union[Grid, LocStream, Mesh], method: str, filename: Optional[str] = None, reuse_weights: bool = False, @@ -492,7 +492,7 @@ def __call__( keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, ): """ Apply regridding to input data. @@ -629,7 +629,7 @@ def regrid_array( weights: sps.coo_matrix, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Tuple[int, ...] | Dict[str, int]] = None, + output_chunks: Optional[Union[Tuple[int, ...], Dict[str, int]]] = None, ): """See __call__().""" if self.sequence_in: @@ -693,7 +693,7 @@ def regrid_dataarray( keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, ) -> DataArray | Dataset: """See __call__().""" @@ -718,7 +718,7 @@ def regrid_dataset( keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, - output_chunks: Optional[Dict[str, int] | Tuple[int, ...]] = None, + output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, ) -> DataArray | Dataset: """See __call__().""" diff --git a/xesmf/smm.py b/xesmf/smm.py index f46b0f86..97cc89b0 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -3,7 +3,7 @@ """ import warnings from pathlib import Path -from typing import Any, Tuple +from typing import Any, Dict, Tuple, Union import numba as nb # type: ignore[import] import numpy as np @@ -13,7 +13,7 @@ def read_weights( - weights: str | Path | xr.Dataset | xr.DataArray | sps.COO | dict[str, Any], + weights: Union[str, Path, xr.Dataset, xr.DataArray, sps.COO, Dict[str, Any]], n_in: int, n_out: int, ) -> xr.DataArray: @@ -56,7 +56,7 @@ def read_weights( def _parse_coords_and_values( - indata: str | Path | xr.Dataset | dict[str, Any], + indata: Union[str, Path, xr.Dataset, Dict[str, Any]], n_in: int, n_out: int, ) -> xr.DataArray: diff --git a/xesmf/util.py b/xesmf/util.py index 5f68cf59..fa921b45 100644 --- a/xesmf/util.py +++ b/xesmf/util.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Generator, List, Literal, Tuple +from typing import Any, Generator, List, Literal, Tuple, Union import numpy as np import numpy.typing as npt @@ -10,9 +10,7 @@ LAT_CF_ATTRS = {'standard_name': 'latitude', 'units': 'degrees_north'} -def _grid_1d( - start_b: float, end_b: float, step: float -) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.floating[Any]]]: +def _grid_1d(start_b: float, end_b: float, step: float) -> Tuple[npt.NDArray, npt.NDArray]: """ 1D grid centers and bounds @@ -191,7 +189,7 @@ def grid_global( def _flatten_poly_list( polys: List[Polygon], -) -> Generator[tuple[int, Any] | tuple[int, Polygon], Any, None]: +) -> Generator[Union[Tuple[int, Any], 
Tuple[int, Polygon]], Any, None]:
     """Iterator flattening MultiPolygons."""
     for i, poly in enumerate(polys):
         if isinstance(poly, MultiPolygon):
@@ -248,7 +246,7 @@ def simple_tripolar_grid(
     nlats: int,
     lat_cap: float = 60,
     lon_cut: float = -300,
-) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[Any]]:
+) -> Tuple[npt.NDArray, npt.NDArray]:
     """Generate a simple tripolar grid, regular under `lat_cap`.

     Parameters

From 01dab15d94532c7d8f47d061431b0aa281e4205d Mon Sep 17 00:00:00 2001
From: David Huard
Date: Fri, 10 Nov 2023 09:41:10 -0500
Subject: [PATCH 13/16] Yet other typing adjustments

---
 xesmf/backend.py  | 22 ++++++++---------
 xesmf/frontend.py | 62 +++++++++++++++++++++++------------------------
 2 files changed, 41 insertions(+), 43 deletions(-)

diff --git a/xesmf/backend.py b/xesmf/backend.py
index ad0871dc..34701f75 100644
--- a/xesmf/backend.py
+++ b/xesmf/backend.py
@@ -29,7 +29,7 @@ import numpy.typing as npt


-def warn_f_contiguous(a: npt.NDArray[np.floating[Any]]) -> None:
+def warn_f_contiguous(a: npt.NDArray) -> None:
     """
     Give a warning if input array is not Fortran-ordered.

@@ -44,7 +44,7 @@ def warn_f_contiguous(a: npt.NDArray[np.floating[Any]]) -> None:
         warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.')


-def warn_lat_range(lat: npt.NDArray[np.floating[Any]]) -> None:
+def warn_lat_range(lat: npt.NDArray) -> None:
     """
     Give a warning if latitude is outside of [-90, 90]

@@ -63,8 +63,8 @@ class Grid(ESMF.Grid):
     @classmethod
     def from_xarray(
         cls,
-        lon: npt.NDArray[np.floating[Any]],
-        lat: npt.NDArray[np.floating[Any]],
+        lon: npt.NDArray,
+        lat: npt.NDArray,
         periodic: bool = False,
         mask: Union[npt.NDArray[np.integer[Any]], None] = None,
     ):
@@ -162,8 +162,8 @@ class LocStream(ESMF.LocStream):
     @classmethod
     def from_xarray(
         cls,
-        lon: npt.NDArray[np.floating[Any]],
-        lat: npt.NDArray[np.floating[Any]],
+        lon: npt.NDArray,
+        lat: npt.NDArray,
     ) -> ESMF.LocStream:
         """
         Create an ESMF.LocStream object, for constructing ESMF.Field and ESMF.Regrid
@@ -246,7 +246,7 @@ class Mesh(ESMF.Mesh):
     def from_polygons(
         cls,
         polys: npt.NDArray[np.object_],
-        element_coords: Union[Literal['centroid'], npt.NDArray[np.floating[Any]]] = 'centroid',
+        element_coords: Union[Literal['centroid'], npt.NDArray] = 'centroid',
     ):
         """
         Create an ESMF.Mesh object from a list of polygons.
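[Editor's sketch of the `polys` argument annotated above: an object-dtype NumPy array of shapely geometries. Names are illustrative; the `Mesh.from_polygons` call is left commented out because it needs a working esmpy installation:]

```python
import numpy as np
from shapely.geometry import Polygon

tri = Polygon([(0.0, 0.0), (10.0, 0.0), (10.0, 10.0)])
box = Polygon([(20.0, 20.0), (30.0, 20.0), (30.0, 30.0), (20.0, 30.0)])
polys = np.array([tri, box], dtype=object)  # matches npt.NDArray[np.object_]

# mesh = Mesh.from_polygons(polys)  # requires esmpy; uses centroids by default
```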
@@ -586,8 +586,8 @@ def esmf_regrid_finalize(regrid: ESMF.Regrid): def esmf_locstream( - lon: npt.NDArray[np.floating[Any]], - lat: npt.NDArray[np.floating[Any]], + lon: npt.NDArray, + lat: npt.NDArray, ) -> LocStream: warnings.warn( '`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`', @@ -597,8 +597,8 @@ def esmf_locstream( def esmf_grid( - lon: npt.NDArray[np.floating[Any]], - lat: npt.NDArray[np.floating[Any]], + lon: npt.NDArray, + lat: npt.NDArray, periodic: bool = False, mask: Union[npt.NDArray[np.integer[Any]], None] = None, ) -> Grid: diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 467bd7df..bf2b66b5 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -33,8 +33,8 @@ def subset_regridder( - ds_in: DataArray | Dataset | dict[str, DataArray], - ds_out: DataArray | Dataset | dict[str, DataArray], + ds_in: Union[DataArray, Dataset, Dict[str, DataArray]], + ds_out: Union[DataArray, Dataset, Dict[str, DataArray]], method: Literal[ 'bilinear', 'conservative', @@ -79,9 +79,9 @@ def subset_regridder( def as_2d_mesh( - lon: DataArray | npt.NDArray[np.floating[Any]], - lat: DataArray | npt.NDArray[np.floating[Any]], -) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: + lon: Union[DataArray, npt.NDArray], + lat: Union[DataArray, npt.NDArray], +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: if (lon.ndim, lat.ndim) == (2, 2): assert lon.shape == lat.shape, 'lon and lat should have same shape' elif (lon.ndim, lat.ndim) == (1, 1): @@ -93,8 +93,8 @@ def as_2d_mesh( def _get_lon_lat( - ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] -) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: + ds: Union[Dataset, Dict[str, npt.NDArray]] +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: """Return lon and lat extracted from ds.""" if ('lat' in ds and 'lon' in ds) or ('lat' in ds.coords and 'lon' in ds.coords): # Old way. @@ -112,8 +112,8 @@ def _get_lon_lat( def _get_lon_lat_bounds( - ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]] -) -> tuple[DataArray | npt.NDArray[np.floating[Any]], DataArray | npt.NDArray[np.floating[Any]]]: + ds: Union[Dataset, Dict[str, npt.NDArray]] +) -> Tuple[Union[DataArray, npt.NDArray], Union[DataArray, npt.NDArray]]: """Return bounds of lon and lat extracted from ds.""" if 'lat_b' in ds and 'lon_b' in ds: # Old way. 
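[Editor's note on the `Union[Dataset, Dict[str, npt.NDArray]]` hints above: they encode xESMF's duck typing, where a grid may be either a plain dict or an xarray Dataset, since both support the same item access. A sketch with illustrative names:]

```python
import numpy as np
import xarray as xr

lon = np.arange(0.0, 360.0, 2.0)
lat = np.arange(-89.0, 90.0, 2.0)

grid_as_dict = {'lon': lon, 'lat': lat}                   # Dict[str, npt.NDArray]
grid_as_ds = xr.Dataset(coords={'lon': lon, 'lat': lat})  # xr.Dataset

for grid in (grid_as_dict, grid_as_ds):
    lo, la = grid['lon'], grid['lat']  # identical item access on either type
```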
@@ -143,7 +143,7 @@ def _get_lon_lat_bounds( def ds_to_ESMFgrid( - ds: Dataset | Dict[str, npt.NDArray[np.floating[Any]]], + ds: Union[Dataset, Dict[str, npt.NDArray]], need_bounds: bool = False, periodic: bool = False, append=None, @@ -488,7 +488,7 @@ def _compute_weights(self): def __call__( self, - indata: npt.NDArray[np.floating[Any]] | dask_array_type | xr.DataArray | xr.Dataset, + indata: Union[npt.NDArray, dask_array_type, xr.DataArray, xr.Dataset], keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, @@ -597,17 +597,17 @@ def __call__( @staticmethod def _regrid( - indata: npt.NDArray[np.floating[Any]], + indata: npt.NDArray, weights: sps.coo_matrix, *, shape_in: Tuple[int, int], shape_out: Tuple[int, int], skipna: bool, na_thresh: float, - ) -> npt.NDArray[np.floating[Any]]: + ) -> npt.NDArray: # skipna: set missing values to zero if skipna: - missing: npt.NDArray[np.bool_] = np.isnan(indata) + missing: npt.NDArray = np.isnan(indata) indata = np.where(missing, 0.0, indata) # apply weights @@ -625,7 +625,7 @@ def _regrid( def regrid_array( self, - indata: npt.NDArray[np.floating[Any]] | dask_array_type, + indata: Union[npt.NDArray, dask_array_type], weights: sps.coo_matrix, skipna: bool = False, na_thres: float = 1.0, @@ -669,18 +669,14 @@ def regrid_array( outdata = self._regrid(indata, weights, **kwargs) return outdata - def regrid_numpy( - self, indata: npt.NDArray[np.floating[Any]], **kwargs - ) -> npt.NDArray[np.floating[Any]]: + def regrid_numpy(self, indata: npt.NDArray, **kwargs) -> npt.NDArray: warnings.warn( '`regrid_numpy()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, ) return self.regrid_array(indata, self.weights.data, **kwargs) - def regrid_dask( - self, indata: npt.NDArray[np.floating[Any]], **kwargs - ) -> npt.NDArray[np.floating[Any]]: + def regrid_dask(self, indata: npt.NDArray, **kwargs) -> npt.NDArray: warnings.warn( '`regrid_dask()` will be removed in xESMF 0.7, please use `regrid_array` instead.', category=FutureWarning, @@ -694,7 +690,7 @@ def regrid_dataarray( skipna: bool = False, na_thres: float = 1.0, output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, - ) -> DataArray | Dataset: + ) -> Union[DataArray, Dataset]: """See __call__().""" input_horiz_dims, temp_horiz_dims = self._parse_xrinput(dr_in) @@ -719,7 +715,7 @@ def regrid_dataset( skipna: bool = False, na_thres: float = 1.0, output_chunks: Optional[Union[Dict[str, int], Tuple[int, ...]]] = None, - ) -> DataArray | Dataset: + ) -> Union[DataArray, Dataset]: """See __call__().""" # get the first data variable to infer input_core_dims @@ -748,7 +744,7 @@ def regrid_dataset( return self._format_xroutput(ds_out, temp_horiz_dims) def _parse_xrinput( - self, dr_in: xr.DataArray | xr.Dataset + self, dr_in: Union[xr.DataArray, xr.Dataset] ) -> Tuple[Tuple[Hashable, ...], List[str]]: # dr could be a DataArray or a Dataset # Get input horiz dim names and set output horiz dim names @@ -785,8 +781,8 @@ def _parse_xrinput( return input_horiz_dims, temp_horiz_dims def _format_xroutput( - self, out: xr.DataArray | xr.Dataset, new_dims: Optional[List[str]] = None - ) -> xr.DataArray | xr.Dataset: + self, out: Union[xr.DataArray, xr.Dataset], new_dims: Optional[List[str]] = None + ) -> Union[xr.DataArray, xr.Dataset]: out.attrs['regrid_method'] = self.method return out @@ -823,8 +819,8 @@ def to_netcdf(self, filename: Optional[str] = None) -> str: class Regridder(BaseRegridder): def __init__( self, - ds_in: xr.DataArray | xr.Dataset | 
dict[str, xr.DataArray], - ds_out: xr.DataArray | xr.Dataset | dict[str, xr.DataArray], + ds_in: Union[xr.DataArray, xr.Dataset, Dict[str, xr.DataArray]], + ds_out: Union[xr.DataArray, xr.Dataset, Dict[str, xr.DataArray]], method: Literal[ 'bilinear', 'conservative', @@ -1169,13 +1165,13 @@ def _format_xroutput(self, out, new_dims=None): class SpatialAverager(BaseRegridder): def __init__( self, - ds_in: xr.DataArray | xr.Dataset | dict, - polys: Sequence[Polygon | MultiPolygon], + ds_in: Union[xr.DataArray, xr.Dataset, dict], + polys: Sequence[Union[Polygon, MultiPolygon]], ignore_holes: bool = False, periodic: bool = False, filename: Optional[str] = None, reuse_weights: bool = False, - weights: Optional[sps.coo_matrix | dict | str | Dataset] = None, + weights: Optional[Union[sps.coo_matrix, dict, str, Dataset]] = None, ignore_degenerate: bool = False, geom_dim_name: str = 'geom', ): @@ -1414,7 +1410,9 @@ def __repr__(self) -> str: return info - def _format_xroutput(self, out: DataArray | Dataset, new_dims=None) -> DataArray | Dataset: + def _format_xroutput( + self, out: Union[DataArray, Dataset], new_dims=None + ) -> Union[DataArray, Dataset]: out = out.squeeze(dim='dummy') # rename dimension name to match output grid From 2dcaf6fdf46e4d753a050c426f4e28e562f5997c Mon Sep 17 00:00:00 2001 From: David Huard Date: Fri, 10 Nov 2023 10:32:48 -0500 Subject: [PATCH 14/16] Typing bug fixes --- xesmf/backend.py | 9 +++++---- xesmf/frontend.py | 14 +++++++------- xesmf/smm.py | 12 ++++++------ 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/xesmf/backend.py b/xesmf/backend.py index 34701f75..1ce05424 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -17,7 +17,7 @@ import os import warnings -from typing import Any, List, Literal, Union +from typing import List, Literal, Optional, Sequence, Union try: import esmpy as ESMF @@ -27,6 +27,7 @@ import numpy as np import numpy.lib.recfunctions as nprec import numpy.typing as npt +from shapely.geometry import Polygon def warn_f_contiguous(a: npt.NDArray) -> None: @@ -66,7 +67,7 @@ def from_xarray( lon: npt.NDArray, lat: npt.NDArray, periodic: bool = False, - mask: Union[npt.NDArray[np.integer[Any]], None] = None, + mask: Optional[npt.NDArray] = None, ): """ Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid. @@ -245,7 +246,7 @@ class Mesh(ESMF.Mesh): @classmethod def from_polygons( cls, - polys: npt.NDArray[np.object_], + polys: Sequence[Polygon], element_coords: Union[Literal['centroid'], npt.NDArray] = 'centroid', ): """ @@ -600,7 +601,7 @@ def esmf_grid( lon: npt.NDArray, lat: npt.NDArray, periodic: bool = False, - mask: Union[npt.NDArray[np.integer[Any]], None] = None, + mask: Optional[npt.NDArray] = None, ) -> Grid: warnings.warn( '`esmf_grid` is being deprecated in favor of `Grid.from_xarray`', diff --git a/xesmf/frontend.py b/xesmf/frontend.py index bf2b66b5..fd0e50f5 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -241,7 +241,7 @@ def ds_to_ESMFlocstream(ds): return locstream, (1,) + lon.shape, dim_names -def polys_to_ESMFmesh(polys) -> tuple[Mesh, tuple[Literal[1], int]]: +def polys_to_ESMFmesh(polys) -> Tuple[Mesh, Tuple[Literal[1], int]]: """ Convert a sequence of shapely Polygons to a ESMF.Mesh object. 
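[Editor's note on the `sps.coo_matrix` → `sps.COO` changes in this commit: they fix a genuine naming bug, since `sps` here is the pydata `sparse` package, which has no `coo_matrix` — that class lives in `scipy.sparse`. The two interoperate directly, as this sketch shows:]

```python
import numpy as np
import scipy.sparse
import sparse

rows = np.array([0, 1])
cols = np.array([2, 3])
vals = np.array([0.25, 0.75])

w_scipy = scipy.sparse.coo_matrix((vals, (rows, cols)), shape=(4, 4))
w_coo = sparse.COO.from_scipy_sparse(w_scipy)  # the `sps.COO` type used above
assert (w_coo.todense() == w_scipy.toarray()).all()
```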
@@ -488,7 +488,7 @@ def _compute_weights(self): def __call__( self, - indata: Union[npt.NDArray, dask_array_type, xr.DataArray, xr.Dataset], + indata: Union[npt.NDArray, 'da.Array', xr.DataArray, xr.Dataset], keep_attrs: bool = False, skipna: bool = False, na_thres: float = 1.0, @@ -598,7 +598,7 @@ def __call__( @staticmethod def _regrid( indata: npt.NDArray, - weights: sps.coo_matrix, + weights: sps.COO, *, shape_in: Tuple[int, int], shape_out: Tuple[int, int], @@ -625,8 +625,8 @@ def _regrid( def regrid_array( self, - indata: Union[npt.NDArray, dask_array_type], - weights: sps.coo_matrix, + indata: Union[npt.NDArray, 'da.Array'], + weights: sps.COO, skipna: bool = False, na_thres: float = 1.0, output_chunks: Optional[Union[Tuple[int, ...], Dict[str, int]]] = None, @@ -1171,7 +1171,7 @@ def __init__( periodic: bool = False, filename: Optional[str] = None, reuse_weights: bool = False, - weights: Optional[Union[sps.coo_matrix, dict, str, Dataset]] = None, + weights: Optional[Union[sps.COO, dict, str, Dataset]] = None, ignore_degenerate: bool = False, geom_dim_name: str = 'geom', ): @@ -1312,7 +1312,7 @@ def _check_polys_length(polys: List[Polygon], threshold: int = 1) -> None: stacklevel=2, ) - def _compute_weights_and_area(self, mesh_out: Mesh) -> tuple[DataArray, Any]: + def _compute_weights_and_area(self, mesh_out: Mesh) -> Tuple[DataArray, Any]: """Return the weights and the area of the destination mesh cells.""" # Build the regrid object diff --git a/xesmf/smm.py b/xesmf/smm.py index 97cc89b0..9531fc45 100644 --- a/xesmf/smm.py +++ b/xesmf/smm.py @@ -111,8 +111,8 @@ def _parse_coords_and_values( def check_shapes( - indata: npt.NDArray[Any], - weights: npt.NDArray[Any], + indata: npt.NDArray, + weights: npt.NDArray, shape_in: Tuple[int, int], shape_out: Tuple[int, int], ) -> None: @@ -172,10 +172,10 @@ def check_shapes( def apply_weights( weights: sps.COO, - indata: npt.NDArray[Any], + indata: npt.NDArray, shape_in: Tuple[int, int], shape_out: Tuple[int, int], -) -> npt.NDArray[Any]: +) -> npt.NDArray: """ Apply regridding weights to data. @@ -254,8 +254,8 @@ def add_nans_to_weights(weights: xr.DataArray) -> xr.DataArray: def _combine_weight_multipoly( # type: ignore weights: xr.DataArray, - areas: npt.NDArray[np.integer[Any]], - indexes: npt.NDArray[np.integer[Any]], + areas: npt.NDArray, + indexes: npt.NDArray, ) -> xr.DataArray: """Reduce a weight sparse matrix (csc format) by combining (adding) columns. 
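[Editor's sketch of the weight-application path that several hunks in this commit annotate: regridding reduces to one sparse matrix multiply over flattened horizontal dimensions. This is a minimal illustration only — the real `smm.apply_weights` also validates shapes and handles dtype conversion:]

```python
import numpy as np
import sparse


def apply_weights_sketch(weights, indata, shape_in, shape_out):
    """Regrid by sparse matmul: weights is (n_out, n_in), rows = destination."""
    extra_shape = indata.shape[:-2]  # leading dims, e.g. (time, lev)
    indata_flat = indata.reshape(-1, shape_in[0] * shape_in[1])
    w = weights.to_scipy_sparse()  # 2-D COO converts losslessly to scipy
    outdata_flat = w.dot(indata_flat.T).T
    return outdata_flat.reshape(extra_shape + shape_out)


# Two destination cells, each averaging a pair of the nine source cells:
coords = np.array([[0, 0, 1, 1], [1, 2, 3, 4]])
weights = sparse.COO(coords, np.array([0.5, 0.5, 0.5, 0.5]), shape=(4, 9))
field = np.ones((5, 3, 3))  # (time, y, x)
out = apply_weights_sketch(weights, field, (3, 3), (2, 2))
assert out.shape == (5, 2, 2)
```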
From 067ef8aa078766eef6323d547e3b39453002475e Mon Sep 17 00:00:00 2001
From: David Huard
Date: Fri, 10 Nov 2023 10:47:55 -0500
Subject: [PATCH 15/16] Fix typo: na_thresh -> na_thres

---
 xesmf/frontend.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xesmf/frontend.py b/xesmf/frontend.py
index fd0e50f5..93d98250 100644
--- a/xesmf/frontend.py
+++ b/xesmf/frontend.py
@@ -603,7 +603,7 @@ def _regrid(
         shape_in: Tuple[int, int],
         shape_out: Tuple[int, int],
         skipna: bool,
-        na_thresh: float,
+        na_thres: float,
     ) -> npt.NDArray:
         # skipna: set missing values to zero
         if skipna:
@@ -617,7 +617,7 @@ def _regrid(
         if skipna:
             fraction_valid = apply_weights(weights, (~missing).astype('d'), shape_in, shape_out)
             tol = 1e-6
-            bad = fraction_valid < np.clip(1 - na_thresh, tol, 1 - tol)
+            bad = fraction_valid < np.clip(1 - na_thres, tol, 1 - tol)
             fraction_valid[bad] = 1
             outdata = np.where(bad, np.nan, outdata / fraction_valid)

From 74cc81eeeb3d12907bbe1f48d5b3d66d92cb6ac7 Mon Sep 17 00:00:00 2001
From: David Huard
Date: Fri, 10 Nov 2023 11:22:31 -0500
Subject: [PATCH 16/16] Set periodic default to None instead of False, as before

---
 xesmf/frontend.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/xesmf/frontend.py b/xesmf/frontend.py
index 93d98250..91570a16 100644
--- a/xesmf/frontend.py
+++ b/xesmf/frontend.py
@@ -66,12 +66,12 @@ def subset_regridder(
     ds_out = ds_out.rename({'y_out': out_dims[0], 'x_out': out_dims[1]})

     regridder = Regridder(
-        ds_in,
-        ds_out,
-        method,
-        locstream_in,
-        locstream_out,
-        periodic,
+        ds_in=ds_in,
+        ds_out=ds_out,
+        method=method,
+        locstream_in=locstream_in,
+        locstream_out=locstream_out,
+        periodic=periodic,
         parallel=False,
         **kwargs,
     )
@@ -145,7 +145,7 @@ def _get_lon_lat_bounds(
 def ds_to_ESMFgrid(
     ds: Union[Dataset, Dict[str, npt.NDArray]],
     need_bounds: bool = False,
-    periodic: bool = False,
+    periodic: Optional[bool] = None,
     append=None,
 ):
     """
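[Editor's typing note on the final hunk: under PEP 484 a `None` default does not implicitly widen a `bool` annotation, so the explicit spelling is `Optional[bool] = None` — mypy enforces this by default since 0.990 (`no_implicit_optional`). A hypothetical sketch of why `None` is a useful default distinct from `False`:]

```python
from typing import Optional


def grid_builder_sketch(need_bounds: bool = False, periodic: Optional[bool] = None):
    # Hypothetical downstream logic: None means "caller did not choose",
    # which an explicit `periodic=False` would not convey.
    if periodic is None:
        periodic = False
    return need_bounds, periodic
```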