diff --git a/.gitignore b/.gitignore index 21dbc3d8..741ab569 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.coverage .idea* .vscode *pycache* @@ -6,12 +7,9 @@ *.png *.log *.nc -.coverage coverage.xml miniconda -# output from unittest -test_stats.csv -test_stats_csv.csv -test_stats_csv.dat +# output from pytest random_tolerance.csv +random_tolerance_None.csv diff --git a/AUTHORS.md b/AUTHORS.md index 37d6c2f5..f28d533c 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -6,6 +6,7 @@ | Marek Jacob | DWD | | Jonas Jucker | C2SM | | Annika Lauber | C2SM | +| Andrea Leuthard | MeteoSwiss | | Giacomo Serafini | MeteoSwiss | | Mikael Stellio | C2SM | | Ben Weber | MeteoSwiss | diff --git a/README.md b/README.md index 7b3a07f9..f0bf08bd 100644 --- a/README.md +++ b/README.md @@ -100,10 +100,13 @@ The pinned requirements can be installed by ```console ./setup_env.sh ``` +which is recommended for users. + The unpinned requirements and updating the environment can be done by ```console ./setup_env.sh -u -e ``` +which is recommended for developers and required for adding new requirements. ### The init command @@ -138,7 +141,17 @@ here should refer to your experiment script: ```console cd icon-base-dir/reference-build -python ../externals/probtest/probtest.py init --codebase-install $PWD --experiment-name exp_name --reference $PWD --file-id NetCDF "*atm_3d_ml*.nc" --file-id NetCDF "*atm_3d_il*.nc" --file-id NetCDF "*atm_3d_hl*.nc" --file-id NetCDF "*atm_3d_pl*.nc" --file-id latlon "*atm_2d_ll*.nc" --file-id meteogram "Meteogram*.nc" +python ../externals/probtest/probtest.py init \ + --codebase-install $PWD \ + --experiment-name exp_name \ + --reference $PWD \ + --file-id NetCDF "*atm_3d_ml*.nc" \ + --file-id NetCDF "*atm_3d_il*.nc" \ + --file-id NetCDF "*atm_3d_hl*.nc" \ + --file-id NetCDF "*atm_3d_pl*.nc" \ + --file-id latlon "*atm_2d_ll*.nc" \ + --file-id meteogram "Meteogram*.nc" \ + --file-id GRIB "lfff0*z" ``` You might need to update the used account in the json file. The perturbation amplitude may also need to be changed in the json file diff --git a/engine/cdo_table.py b/engine/cdo_table.py index 95d0d113..70ee82ba 100644 --- a/engine/cdo_table.py +++ b/engine/cdo_table.py @@ -17,7 +17,7 @@ from util.click_util import CommaSeperatedInts, cli_help from util.constants import cdo_bins from util.dataframe_ops import df_from_file_ids -from util.file_system import file_names_from_pattern +from util.file_system import get_file_names_from_pattern from util.log_handler import logger @@ -139,21 +139,21 @@ def cdo_table( assert isinstance(file_specification, dict), "must be dict" # save original method and restore at the end of this module - dataframe_from_ncfile_orig = model_output_parser.dataframe_from_ncfile + dataframe_from_ncfile_orig = model_output_parser.create_statistics_dataframe # modify netcdf parse method: - model_output_parser.dataframe_from_ncfile = rel_diff_stats + model_output_parser.create_statistics_dataframe = rel_diff_stats # step 1: compute rel-diff netcdf files with tempfile.TemporaryDirectory() as tmpdir: for _, file_pattern in file_id: - ref_files, err = file_names_from_pattern(model_output_dir, file_pattern) + ref_files, err = get_file_names_from_pattern(model_output_dir, file_pattern) if err > 0: logger.info( "did not find any files for pattern %s. Continue.", file_pattern ) continue ref_files.sort() - perturb_files, err = file_names_from_pattern( + perturb_files, err = get_file_names_from_pattern( perturbed_model_output_dir.format(member_id=member_id), file_pattern ) if err > 0: @@ -202,4 +202,4 @@ def cdo_table( Path(cdo_table_file).parent.mkdir(parents=True, exist_ok=True) df.to_csv(cdo_table_file) - model_output_parser.dataframe_from_ncfile = dataframe_from_ncfile_orig + model_output_parser.create_statistics_dataframe = dataframe_from_ncfile_orig diff --git a/engine/performance.py b/engine/performance.py index 6edf2bce..82814e9d 100644 --- a/engine/performance.py +++ b/engine/performance.py @@ -15,7 +15,7 @@ import click from util.click_util import cli_help -from util.file_system import file_names_from_pattern +from util.file_system import get_file_names_from_pattern from util.icon.extract_timings import read_logfile from util.log_handler import logger from util.tree import TREEFILE_TEMPLATE, TimingTree @@ -36,7 +36,9 @@ def performance(timing_regex, timing_database, append_time): if timing_file_name_base == "": timing_file_name_base = "." - timing_file_name, err = file_names_from_pattern(timing_file_name_base, timing_regex) + timing_file_name, err = get_file_names_from_pattern( + timing_file_name_base, timing_regex + ) if err > 0: logger.info("Did not find any files for regex %s", timing_regex) sys.exit(1) diff --git a/requirements/environment.yml b/requirements/environment.yml index c4070ab1..4d975866 100644 --- a/requirements/environment.yml +++ b/requirements/environment.yml @@ -5,76 +5,129 @@ channels: dependencies: - _libgcc_mutex=0.1 - _openmp_mutex=4.5 - - alsa-lib=1.2.12 + - alsa-lib=1.2.13 - annotated-types=0.7.0 + - array-api-compat=1.9.1 - astroid=3.3.5 - asttokens=2.4.1 - - black=24.8.0 + - attrs=24.2.0 + - aws-c-auth=0.8.0 + - aws-c-cal=0.8.0 + - aws-c-common=0.10.3 + - aws-c-compression=0.3.0 + - aws-c-event-stream=0.5.0 + - aws-c-http=0.9.1 + - aws-c-io=0.15.2 + - aws-c-mqtt=0.11.0 + - aws-c-s3=0.7.1 + - aws-c-sdkutils=0.2.1 + - aws-checksums=0.2.2 + - aws-crt-cpp=0.29.5 + - aws-sdk-cpp=1.11.449 + - azure-core-cpp=1.14.0 + - azure-identity-cpp=1.10.0 + - azure-storage-blobs-cpp=12.13.0 + - azure-storage-common-cpp=12.8.0 + - azure-storage-files-datalake-cpp=12.12.0 + - black=24.10.0 - blosc=1.21.6 + - bokeh=3.6.1 - brotli=1.1.0 - brotli-bin=1.1.0 + - brotli-python=1.1.0 - bzip2=1.0.8 - - c-ares=1.33.1 + - c-ares=1.34.3 - ca-certificates=2024.8.30 - cairo=1.18.0 - certifi=2024.8.30 - cffi=1.17.1 + - cfgrib=0.9.14.1 - cfgv=3.3.1 - cftime=1.6.4 + - charset-normalizer=3.4.0 - click=8.1.7 + - cloudpickle=3.1.0 - codespell=2.3.0 - colorama=0.4.6 - - contourpy=1.3.0 + - contourpy=1.3.1 - cycler=0.12.1 - cyrus-sasl=2.1.27 + - cytoolz=1.0.0 + - dask=2024.11.2 + - dask-core=2024.11.2 + - dask-expr=1.1.19 - dbus=1.13.6 - decorator=5.1.1 - dill=0.3.9 - - distlib=0.3.8 + - distlib=0.3.9 + - distributed=2024.11.2 - docutils=0.21.2 - double-conversion=3.3.0 + - earthkit-data=0.10.11 + - earthkit-geo=0.3.0 + - eccodes=2.38.3 + - entrypoints=0.4 - exceptiongroup=1.2.2 - executing=2.1.0 - - expat=2.6.3 + - expat=2.6.4 - filelock=3.16.1 + - findlibs=0.0.5 - flake8=7.1.1 - flake8-black=0.3.6 - font-ttf-dejavu-sans-mono=2.37 - font-ttf-inconsolata=3.000 - font-ttf-source-code-pro=2.038 - font-ttf-ubuntu=0.83 - - fontconfig=2.14.2 + - fontconfig=2.15.0 - fonts-conda-ecosystem=1 - fonts-conda-forge=1 - - fonttools=4.54.1 + - fonttools=4.55.0 + - freeglut=3.2.2 - freetype=2.12.1 + - fsspec=2024.10.0 + - gflags=2.2.2 + - glog=0.7.1 - graphite2=1.3.13 + - h2=4.1.0 - harfbuzz=9.0.0 - hdf4=4.2.15 - hdf5=1.14.4 + - hpack=4.0.0 + - hyperframe=6.0.1 - icu=75.1 - - identify=2.6.1 + - identify=2.6.2 + - idna=3.10 - importlib-metadata=8.5.0 + - importlib_resources=6.4.5 - iniconfig=2.0.0 - ipdb=0.13.13 - - ipython=8.28.0 + - ipython=8.29.0 - isort=5.13.2 - - jedi=0.19.1 + - jasper=4.2.4 + - jedi=0.19.2 - jinja2=3.1.4 + - jsonschema=4.23.0 + - jsonschema-specifications=2024.10.1 - keyutils=1.6.1 - kiwisolver=1.4.7 - krb5=1.21.3 - lcms2=2.16 - ld_impl_linux-64=2.43 - lerc=4.0.0 + - libabseil=20240722.0 - libaec=1.1.3 + - libarrow=18.0.0 + - libarrow-acero=18.0.0 + - libarrow-dataset=18.0.0 + - libarrow-substrait=18.0.0 - libblas=3.9.0 - libbrotlicommon=1.1.0 - libbrotlidec=1.1.0 - libbrotlienc=1.1.0 - libcblas=3.9.0 - - libclang-cpp19.1=19.1.0 - - libclang13=19.1.0 + - libclang-cpp19.1=19.1.4 + - libclang13=19.1.4 + - libcrc32c=1.1.2 - libcups=2.3.3 - libcurl=8.10.1 - libdeflate=1.22 @@ -82,138 +135,170 @@ dependencies: - libedit=3.1.20191231 - libegl=1.7.0 - libev=4.33 - - libexpat=2.6.3 + - libevent=2.1.12 + - libexpat=2.6.4 - libffi=3.4.2 - - libgcc=14.1.0 - - libgcc-ng=14.1.0 - - libgfortran=14.1.0 - - libgfortran-ng=14.1.0 - - libgfortran5=14.1.0 + - libgcc=14.2.0 + - libgcc-ng=14.2.0 + - libgfortran=14.2.0 + - libgfortran5=14.2.0 - libgl=1.7.0 - - libglib=2.82.1 + - libglib=2.82.2 + - libglu=9.0.3 - libglvnd=1.7.0 - libglx=1.7.0 - - libgomp=14.1.0 + - libgomp=14.2.0 + - libgoogle-cloud=2.31.0 + - libgoogle-cloud-storage=2.31.0 + - libgrpc=1.67.1 - libiconv=1.17 - libjpeg-turbo=3.0.0 - liblapack=3.9.0 - - libllvm19=19.1.1 + - libllvm19=19.1.4 - libnetcdf=4.9.2 - - libnghttp2=1.58.0 + - libnghttp2=1.64.0 - libnsl=2.0.1 - libntlm=1.4 - - libopenblas=0.3.27 + - libopenblas=0.3.28 - libopengl=1.7.0 + - libparquet=18.0.0 - libpciaccess=0.18 - libpng=1.6.44 - - libpq=17.0 - - libsqlite=3.46.1 + - libpq=17.2 + - libprotobuf=5.28.2 + - libre2-11=2024.07.02 + - libsqlite=3.47.0 - libssh2=1.11.0 - - libstdcxx=14.1.0 - - libstdcxx-ng=14.1.0 + - libstdcxx=14.2.0 + - libstdcxx-ng=14.2.0 + - libthrift=0.21.0 - libtiff=4.7.0 + - libutf8proc=2.8.0 - libuuid=2.38.1 - libwebp-base=1.4.0 - libxcb=1.17.0 - libxkbcommon=1.7.0 - - libxml2=2.12.7 + - libxml2=2.13.5 - libxslt=1.1.39 - - libzip=1.11.1 + - libzip=1.11.2 - libzlib=1.3.1 + - locket=1.0.0 + - lru-dict=1.3.0 + - lz4=4.3.3 - lz4-c=1.9.4 + - markdown=3.6 - markdown-it-py=3.0.0 - - markupsafe=3.0.0 + - markupsafe=3.0.2 - matplotlib=3.9.2 - matplotlib-base=3.9.2 - matplotlib-inline=0.1.7 - mccabe=0.7.0 - mdurl=0.1.2 + - msgpack-python=1.1.0 - munkres=1.1.4 - - mypy=1.11.2 + - mypy=1.13.0 - mypy_extensions=1.0.0 - mysql-common=9.0.1 - mysql-libs=9.0.1 - ncurses=6.5 - - netcdf4=1.7.1 + - netcdf4=1.7.2 - nodeenv=1.9.1 - - numpy=2.1.2 + - numpy=2.1.3 - openjpeg=2.5.2 - openldap=2.6.8 - - openssl=3.3.2 - - packaging=24.1 + - openssl=3.4.0 + - orc=2.0.3 + - packaging=24.2 - pandas=2.2.3 - parso=0.8.4 + - partd=1.4.2 - pathlib=1.0.1 - pathspec=0.12.1 - pcre2=10.44 + - pdbufr=0.11.0 - pexpect=4.9.0 - pickleshare=0.7.5 - - pillow=10.4.0 - - pip=24.2 + - pillow=11.0.0 + - pip=24.3.1 - pixman=0.43.2 + - pkgutil-resolve-name=1.3.10 - platformdirs=4.3.6 - pluggy=1.5.0 - - pre-commit=4.0.0 + - pre-commit=4.0.1 - pre-commit-hooks=5.0.0 - prompt-toolkit=3.0.48 - - psutil=6.0.0 + - psutil=6.1.0 - pthread-stubs=0.4 - ptyprocess=0.7.0 - pure_eval=0.2.3 + - pyarrow=18.0.0 + - pyarrow-core=18.0.0 - pycodestyle=2.12.1 - pycparser=2.22 - - pydantic=2.9.2 - - pydantic-core=2.23.4 + - pydantic=2.10.0 + - pydantic-core=2.27.0 - pydocstyle=6.3.0 - pyflakes=3.2.0 - pygments=2.18.0 - pylint=3.3.1 - - pyparsing=3.1.4 - - pyside6=6.7.3 + - pyparsing=3.2.0 + - pyside6=6.8.0.2 + - pysocks=1.7.1 - pytest=8.3.3 - python=3.10.8 - - python-dateutil=2.9.0 + - python-dateutil=2.9.0.post0 + - python-eccodes=2.37.0 - python-tzdata=2024.2 - python_abi=3.10 - pytoolconfig=1.2.5 - pytz=2024.1 - pyyaml=6.0.2 - qhull=2020.2 - - qt6-main=6.7.3 + - qt6-main=6.8.0 + - re2=2024.07.02 - readline=8.2 - - regex=2024.9.11 - - rich=13.9.2 + - referencing=0.35.1 + - regex=2024.11.6 + - requests=2.32.3 + - rich=13.9.4 - rope=1.13.0 + - rpds-py=0.21.0 - rstcheck=6.2.4 - rstcheck-core=1.2.1 - ruamel.yaml=0.18.6 - ruamel.yaml.clib=0.2.8 + - s2n=1.5.9 - scipy=1.14.1 - - setuptools=75.1.0 + - setuptools=75.6.0 - shellingham=1.5.4 - six=1.16.0 - snappy=1.2.1 - snowballstemmer=2.2.0 + - sortedcontainers=2.4.0 - stack_data=0.6.2 + - tblib=3.0.0 - tk=8.6.13 - toml=0.10.2 - - tomli=2.0.2 + - tomli=2.1.0 - tomlkit=0.13.2 + - toolz=1.0.0 - tornado=6.4.1 + - tqdm=4.67.0 - traitlets=5.14.3 - - typer=0.12.5 - - typer-slim=0.12.5 - - typer-slim-standard=0.12.5 + - typer=0.13.1 + - typer-slim=0.13.1 + - typer-slim-standard=0.13.1 - typing-extensions=4.12.2 - typing_extensions=4.12.2 - tzdata=2024b - ukkonen=1.0.1 - unicodedata2=15.1.0 - - virtualenv=20.26.6 + - urllib3=2.2.3 + - virtualenv=20.27.1 - wayland=1.23.1 - wcwidth=0.2.13 - - wheel=0.44.0 - - xarray=2024.9.0 + - wheel=0.45.0 + - xarray=2024.10.0 - xcb-util=0.4.1 - xcb-util-cursor=0.1.5 - xcb-util-image=0.4.0 @@ -226,7 +311,7 @@ dependencies: - xorg-libx11=1.8.10 - xorg-libxau=1.0.11 - xorg-libxcomposite=0.4.6 - - xorg-libxcursor=1.2.2 + - xorg-libxcursor=1.2.3 - xorg-libxdamage=1.1.6 - xorg-libxdmcp=1.1.5 - xorg-libxext=1.3.6 @@ -237,12 +322,15 @@ dependencies: - xorg-libxtst=1.2.5 - xorg-libxxf86vm=1.1.5 - xorg-xorgproto=2024.1 + - xyzservices=2024.9.0 - xz=5.2.6 - yaml=0.2.5 - - zipp=3.20.2 + - zict=3.0.0 + - zipp=3.21.0 - zlib=1.3.1 + - zstandard=0.23.0 - zstd=1.5.6 - pip: - - coverage==7.6.1 + - coverage==7.6.7 - flake8-pyproject==1.2.3 - - pytest-cov==5.0.0 + - pytest-cov==6.0.0 diff --git a/requirements/requirements.yml b/requirements/requirements.yml index a982d2e7..943fcbe8 100644 --- a/requirements/requirements.yml +++ b/requirements/requirements.yml @@ -5,7 +5,10 @@ dependencies: - python==3.10.8 - pip>=22.3 # runtime + - array-api-compat - click>=7.1.2 + - earthkit-data<0.11 + - eccodes>=2.38.0 - Jinja2>=3.0.1 - matplotlib>=3.2.1 - netCDF4>=1.5.3 diff --git a/setup_env.sh b/setup_env.sh index a2020bfb..41744f78 100755 --- a/setup_env.sh +++ b/setup_env.sh @@ -10,23 +10,22 @@ fi # Default env names DEFAULT_ENV_NAME="probtest" +CONDA=conda # Default options ENV_NAME="${DEFAULT_ENV_NAME}" PYVERSION=3.10.8 PINNED=true EXPORT=false -CONDA=conda HELP=false -help_msg="Usage: $(basename "${0}") [-n NAME] [-p VER] [-u] [-e] [-m] [-h] +help_msg="Usage: $(basename "${0}") [-n NAME] [-p VER] [-u] [-e] [-h] Options: -n NAME Env name [default: ${DEFAULT_ENV_NAME} -p VER Python version [default: ${PYVERSION}] -u Use unpinned requirements (minimal version restrictions) -e Export environment files (requires -u) - -m Use mamba instead of conda -h Print this help message and exit " @@ -37,7 +36,6 @@ while getopts n:p:defhimu flag; do p) PYVERSION=${OPTARG};; e) EXPORT=true;; h) HELP=true;; - m) CONDA=mamba;; u) PINNED=false;; ?) echo -e "\n${help_msg}" >&2; exit 1;; esac @@ -68,3 +66,24 @@ else ${CONDA} env export --name ${ENV_NAME} --no-builds | \grep -v '^prefix:' > requirements/environment.yml || exit fi fi + + +# Setting ECCODES_DEFINITION_PATH: +${CONDA} activate ${ENV_NAME} + +CONDA_LOC=${CONDA_PREFIX} +DEFINITION_VERSION="v2.36.0.2" +DEFINITION_PATH_DEFAULT=${CONDA_LOC}/share/eccodes +DEFINITION_PATH_RESOURCES=${CONDA_LOC}/share/eccodes-cosmo-resources_${DEFINITION_VERSION} + +git clone -b ${DEFINITION_VERSION} https://github.com/COSMO-ORG/eccodes-cosmo-resources.git ${DEFINITION_PATH_RESOURCES} || exit + +${CONDA} env config vars set ECCODES_DEFINITION_PATH=${DEFINITION_PATH_DEFAULT}/definitions:${DEFINITION_PATH_RESOURCES}/definitions +${CONDA} env config vars set ECCODES_SAMPLES_PATH=${DEFINITION_PATH_DEFAULT}/samples +${CONDA} env config vars set GRIB_DEFINITION_PATH=${DEFINITION_PATH_DEFAULT}/definitions:${DEFINITION_PATH_RESOURCES}/definitions +${CONDA} env config vars set GRIB_SAMPLES_PATH=${DEFINITION_PATH_DEFAULT}/samples + +echo "Variables saved to environment: " +${CONDA} env config vars list + +${CONDA} deactivate diff --git a/templates/ICON.jinja b/templates/ICON.jinja index b97ade1e..2ff2cec7 100644 --- a/templates/ICON.jinja +++ b/templates/ICON.jinja @@ -11,6 +11,7 @@ "member_type": "{{member_type}}", "factor": 5, "file_specification": [{ + "GRIB": {"format": "GRIB", "time_dim": "step", "horizontal_dims": ["values"], "var_excl": ["tlon", "tlat", "vlon", "vlat", "ulon", "ulat", "h", "slor", "anor", "isor", "sdor"], "fill_value_key": "_FillValue"}, "latlon": { "format": "netcdf", "time_dim": "time", "horizontal_dims": ["lat:lon"] }, "meteogram": { "format": "netcdf", "time_dim": "time", "horizontal_dims": ["max_nlevs:nstations", "nstations"] }, "dace":{ "format": "netcdf", "time_dim": null, "horizontal_dims": ["d_body"]}, diff --git a/tests/engine/test_perturb.py b/tests/engine/test_perturb.py index f78b5c69..a420bbe6 100644 --- a/tests/engine/test_perturb.py +++ b/tests/engine/test_perturb.py @@ -12,8 +12,8 @@ from engine.perturb import perturb_array -atype = np.float32 -AMPLITUDE = atype(1e-14) +ATYPE = np.float32 +AMPLITUDE = ATYPE(1e-14) ARRAY_DIM = 100 @@ -33,8 +33,8 @@ def fixture_create_nc_files(tmp_dir): def test_perturb_array(): # create two arrays, perturb one. - x1 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype) - x2 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=atype) + x1 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=ATYPE) + x2 = np.ones((ARRAY_DIM, ARRAY_DIM), dtype=ATYPE) x_perturbed = perturb_array(x2, 10, AMPLITUDE) # compute some stats and do assertions diff --git a/tests/engine/test_stats.py b/tests/engine/test_stats.py index a4701695..0893d73c 100644 --- a/tests/engine/test_stats.py +++ b/tests/engine/test_stats.py @@ -4,10 +4,9 @@ dataframes from both NetCDF and CSV files. """ -import os -import unittest - +import eccodes import numpy as np +import pytest from netCDF4 import Dataset # pylint: disable=no-name-in-module from engine.stats import create_stats_dataframe @@ -16,6 +15,16 @@ HOR_DIM_SIZE = 100 HEIGHT_DIM_SIZE = 5 +TIME_DIM_GRIB_SIZE = 1 +STEP_DIM_SIZE = 1 +HEIGHT_DIM_GRIB_SIZE = 1 +HORIZONTAL_DIM_GRIB_SIZE = 6114 + +GRIB_FILE_NAME = "test_stats_grib.grib" +STATS_FILE_NAME = "test_stats.csv" +NC_FILE_NAME = "test_stats.nc" +NC_FILE_GLOB = "test_s*.nc" + def initialize_dummy_netcdf_file(name): data = Dataset(name, "w") @@ -35,147 +44,212 @@ def initialize_dummy_netcdf_file(name): return data -class TestStatsNetcdf(unittest.TestCase): - """ - Unit test class for validating statistical calculations from NetCDF files. - - This class tests the accuracy of statistical calculations (mean, max, min) - performed on data extracted from NetCDF files. - It ensures that the statistics DataFrame produced from the NetCDF data - matches expected values. - """ - - nc_file_name = "test_stats.nc" - nc_file_glob = "test_s*.nc" - stats_file_names = "test_stats.csv" - - def setUp(self): - data = initialize_dummy_netcdf_file(self.nc_file_name) - - data.createVariable("v1", np.float64, dimensions=("t", "z", "x")) - data.variables["v1"][:] = np.ones( - (TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE) - ) - data.variables["v1"][:, :, 0] = 0 - data.variables["v1"][:, :, -1] = 2 - - data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42) - data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2 - data.variables["v2"][:, 0] = 1 - data.variables["v2"][:, 1] = 42 # shall be ignored in max-statistic - data.variables["v2"][:, -1] = 3 - - data.createVariable("v3", np.float64, dimensions=("t", "x")) - data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3 - data.variables["v3"][:, 0] = 2 - data.variables["v3"][:, -1] = 4 - - data.close() - - def tear_down(self): - os.remove(self.nc_file_name) - os.remove(self.stats_file_names) - - def test_stats(self): - file_specification = { - "Test data": { - "format": "netcdf", - "time_dim": "t", - "horizontal_dims": ["x"], - "fill_value_key": "_FillValue", # should be the name for fill_value - }, - } - - df = create_stats_dataframe( - input_dir=".", - file_id=[["Test data", self.nc_file_glob]], - stats_file_name=self.stats_file_names, - file_specification=file_specification, +def add_variable_to_grib(filename, dict_data): + with open(filename, "wb") as f_out: + for short_name in list(dict_data.keys()): + gid = eccodes.codes_grib_new_from_samples("reduced_rotated_gg_sfc_grib2") + eccodes.codes_set(gid, "edition", 2) + eccodes.codes_set(gid, "centre", "lssw") + eccodes.codes_set(gid, "dataDate", 20230913) + eccodes.codes_set(gid, "dataTime", 0) + eccodes.codes_set(gid, "stepRange", 0) + eccodes.codes_set(gid, "typeOfLevel", "surface") + eccodes.codes_set(gid, "level", 0) + eccodes.codes_set(gid, "shortName", short_name) + eccodes.codes_set_values(gid, dict_data[short_name]) + eccodes.codes_write(gid, f_out) + eccodes.codes_release(gid) + + +@pytest.fixture +def setup_grib_file(tmp_path): + array_v = np.ones( + ( + TIME_DIM_GRIB_SIZE, + STEP_DIM_SIZE, + HEIGHT_DIM_GRIB_SIZE, + HORIZONTAL_DIM_GRIB_SIZE, ) - - # check that the mean/max/min are correct - expected = np.array( - [ - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], - [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], - [2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0], - [3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0], - ] + ) + array_v[:, :, :, 0] = 0 + array_v[:, :, :, -1] = 2 + + array_t = ( + np.ones( + ( + TIME_DIM_GRIB_SIZE, + STEP_DIM_SIZE, + HEIGHT_DIM_GRIB_SIZE, + HORIZONTAL_DIM_GRIB_SIZE, + ) ) + * 3 + ) + array_t[:, :, :, 0] = 2 + array_t[:, :, :, -1] = 4 + + dict_data = {"t": array_t, "v": array_v} + + # This would be where your grib file is created + add_variable_to_grib(tmp_path / GRIB_FILE_NAME, dict_data) + + +@pytest.mark.usefixtures("setup_grib_file") +def test_stats_grib(tmp_path): + file_specification = { + "Test data": { + "format": "grib", + "time_dim": "step", + "horizontal_dims": ["values"], + "var_excl": [], + "fill_value_key": "_FillValue", # This should be the name for fill_value. + }, + } + + df = create_stats_dataframe( + input_dir=str(tmp_path), + file_id=[["Test data", GRIB_FILE_NAME]], + stats_file_name=tmp_path / STATS_FILE_NAME, + file_specification=file_specification, + ) + + # check that the mean/max/min are correct + expected = np.array( + [ + [1.0, 2.0, 0.0], + [3.0, 4.0, 2.0], + ] + ) + + assert np.array_equal( + df.values, expected + ), f"Stats dataframe incorrect. Difference:\n{df.values == expected}" + + +@pytest.fixture(name="setup_netcdf_file") +def fixture_setup_netcdf_file(tmp_path): + """Fixture to create and initialize a dummy NetCDF file for testing.""" + + data = initialize_dummy_netcdf_file(tmp_path / NC_FILE_NAME) + + # Creating variable "v1" with specified dimensions and setting its values + data.createVariable("v1", np.float64, dimensions=("t", "z", "x")) + data.variables["v1"][:] = np.ones((TIME_DIM_SIZE, HEIGHT_DIM_SIZE, HOR_DIM_SIZE)) + data.variables["v1"][:, :, 0] = 0 + data.variables["v1"][:, :, -1] = 2 + + # Creating variable "v2" with fill_value, and setting its values + data.createVariable("v2", np.float64, dimensions=("t", "x"), fill_value=42) + data.variables["v2"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 2 + data.variables["v2"][:, 0] = 1 + data.variables["v2"][:, 1] = 42 # should be ignored in max-statistic + data.variables["v2"][:, -1] = 3 + + # Creating variable "v3" and setting its values + data.createVariable("v3", np.float64, dimensions=("t", "x")) + data.variables["v3"][:] = np.ones((TIME_DIM_SIZE, HOR_DIM_SIZE)) * 3 + data.variables["v3"][:, 0] = 2 + data.variables["v3"][:, -1] = 4 + + data.close() + + yield + + +def test_stats_netcdf(setup_netcdf_file, tmp_path): # pylint: disable=unused-argument + """Test that the statistics generated from the NetCDF file match the + expected values.""" + + file_specification = { + "Test data": { + "format": "netcdf", + "time_dim": "t", + "horizontal_dims": ["x"], + "fill_value_key": "_FillValue", # should be the name for fill_value + }, + } + + # Call the function to generate the statistics dataframe + df = create_stats_dataframe( + input_dir=str(tmp_path), + file_id=[["Test data", NC_FILE_GLOB]], + stats_file_name=tmp_path / STATS_FILE_NAME, + file_specification=file_specification, + ) + + # Define the expected values for comparison + expected = np.array( + [ + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], + [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0], + [2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0], + [3.0, 4.0, 2.0, 3.0, 4.0, 2.0, 3.0, 4.0, 2.0], + ] + ) + + # Check that the dataframe matches the expected values + assert np.array_equal( + df.values, expected + ), f"Stats dataframe incorrect. Difference:\n{df.values == expected}" + + +@pytest.fixture(name="setup_csv_file") +def fixture_setup_csv_file(tmp_path): + """ + Fixture to set up a temporary CSV file. + """ + dat_file_name = tmp_path / "test_stats_csv.dat" - self.assertTrue( - np.array_equal(df.values, expected), - f"stats dataframe incorrect. Difference:\n{df.values == expected}", - ) + # Create the CSV file with the necessary content + lines = ( + "time v1 v2 v3 v4 v5", + "10 1.4 15 16 17 18", + "20 2.4 25 26 27 28", + "30 3.4 35 36 37 38", + ) + with open(dat_file_name, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) -class TestStatsCsv(unittest.TestCase): +def test_stats_csv(setup_csv_file, tmp_path): # pylint: disable=unused-argument """ - Test suite for validating statistical calculations and CSV file handling. - - This class contains unit tests for creating and validating statistics from a - CSV file. - The primary focus is on ensuring that the statistics calculated from the - input data match the expected values. - The CSV file used for testing is created and cleaned up during the test - lifecycle. + Test that the statistics generated from the CSV file match the expected values. """ - dat_file_name = "test_stats_csv.dat" - stats_file_name = "test_stats_csv.csv" - - def setUp(self): - lines = ( - "time v1 v2 v3 v4 v5", - "10 1.4 15 16 17 18", - "20 2.4 25 26 27 28", - "30 3.4 35 36 37 38", - ) - with open(self.dat_file_name, "w", encoding="utf-8") as f: - f.write("\n".join(lines)) - - def tear_down(self): - os.remove(self.dat_file_name) - os.remove(self.stats_file_name) - - def test_stats(self): - file_specification = { - "Test data": { - "format": "csv", - "parser_args": { - "delimiter": "\\s+", - "header": 0, - "index_col": 0, - }, + file_specification = { + "Test data": { + "format": "csv", + "parser_args": { + "delimiter": "\\s+", + "header": 0, + "index_col": 0, }, - } - - df = create_stats_dataframe( - input_dir=".", - file_id=[["Test data", self.dat_file_name]], - stats_file_name=self.stats_file_name, - file_specification=file_specification, - ) - - # check that the mean/max/min are correct (i.e. the same as in CSV) - expected = np.array( - [ - [1.4, 1.4, 1.4, 2.4, 2.4, 2.4, 3.4, 3.4, 3.4], - [15, 15, 15, 25, 25, 25, 35, 35, 35], - [16, 16, 16, 26, 26, 26, 36, 36, 36], - [17, 17, 17, 27, 27, 27, 37, 37, 37], - [18, 18, 18, 28, 28, 28, 38, 38, 38], - ], - ) - - self.assertTrue( - np.array_equal(df.values, expected), - f"stats dataframe incorrect. Difference:\n{df.values == expected}", - ) - - -if __name__ == "__main__": - unittest.main() + }, + } + + # Call the function that creates the stats DataFrame + df = create_stats_dataframe( + input_dir=str(tmp_path), + file_id=[["Test data", "test_stats_csv.dat"]], + stats_file_name=tmp_path / "test_stats_csv.csv", + file_specification=file_specification, + ) + + # Expected result + expected = np.array( + [ + [1.4, 1.4, 1.4, 2.4, 2.4, 2.4, 3.4, 3.4, 3.4], + [15, 15, 15, 25, 25, 25, 35, 35, 35], + [16, 16, 16, 26, 26, 26, 36, 36, 36], + [17, 17, 17, 27, 27, 27, 37, 37, 37], + [18, 18, 18, 28, 28, 28, 38, 38, 38], + ] + ) + + # Assert the DataFrame matches the expected values + assert np.array_equal( + df.values, expected + ), f"Stats DataFrame incorrect. Difference:\n{df.values != expected}" diff --git a/tests/util/test_dataframe_ops.py b/tests/util/test_dataframe_ops.py index ec58928c..306e2b24 100644 --- a/tests/util/test_dataframe_ops.py +++ b/tests/util/test_dataframe_ops.py @@ -4,9 +4,40 @@ from unittest.mock import patch +import numpy as np import pandas as pd -from util.dataframe_ops import parse_check +from util.dataframe_ops import adjust_time_index, parse_check + + +def test_adjust_time_index(): + # Create a sample DataFrame with MultiIndex for 'time' and 'statistic' + index = pd.MultiIndex.from_product( + [["0 days 00:00:00", "0 days 00:01:00"], ["mean", "max", "min"]], + names=["time", "statistic"], + ) + data = np.random.random((5, 6)) # Sample data + df = pd.DataFrame(data, columns=index) + + # Apply the function to a list of DataFrames (in this case, just one DataFrame) + result = adjust_time_index([df])[0] + + # Create the expected new time index based on unique statistics + expected_time_values = np.repeat( + range(2), 3 + ) # Two unique time points, three statistics per time + expected_index = pd.MultiIndex.from_arrays( + [expected_time_values, ["mean", "max", "min"] * 2], + names=["time", "statistic"], + ) + + # Verify that the new MultiIndex matches the expected index + assert result.columns.equals( + expected_index + ), "The time index was not adjusted correctly." + + # Verify that the data remains unchanged + pd.testing.assert_frame_equal(df, result, check_like=True) @patch("util.dataframe_ops.parse_probtest_csv") diff --git a/tests/util/test_model_output_parser.py b/tests/util/test_model_output_parser.py new file mode 100644 index 00000000..fbda4b65 --- /dev/null +++ b/tests/util/test_model_output_parser.py @@ -0,0 +1,114 @@ +""" +This module contains unit tests for the `model_output_parser` module. +""" + +from unittest.mock import MagicMock + +import pytest + +from util.model_output_parser import get_dataset + + +@pytest.fixture(name="mock_ds_grib") +def fixture_mock_ds_grib(): + """ + Fixture that creates a mock GRIB object to simulate different cases of data + selection and exceptions that `get_ds` needs to handle. + """ + ds_grib = MagicMock() + + # Simulating the successful selection and conversion to xarray + ds_grib.sel.return_value.to_xarray.return_value = "valid_xarray" + + # Simulating the metadata retrieval (unique values) + ds_grib.sel.return_value.metadata.side_effect = [ + ["forecast"], # stepType + [100], # numberOfPoints + ["hours"], # stepUnits + ["analysis"], # dataType + ["regular_ll"], # gridType + ] + + return ds_grib + + +def test_get_ds_success(mock_ds_grib): + """ + Test case where get_ds successfully retrieves the dataset on the first attempt. + """ + pid = 1 + lev = "surface" + + result = get_dataset(mock_ds_grib, pid, lev) + + # Ensure the dataset is selected once + mock_ds_grib.sel.assert_called_once_with(paramId=pid, typeOfLevel=lev) + + # The result should contain the mocked xarray dataset + assert result == ["valid_xarray"] + + +@pytest.mark.parametrize( + "to_xarray_return_value, expected_result", + [ + ((KeyError(), "valid_stepType_xarray"), ["valid_stepType_xarray"]), + ( + (KeyError(), KeyError(), "valid_numberOfPoints_xarray"), + ["valid_numberOfPoints_xarray"], + ), + ( + (KeyError(), KeyError(), KeyError(), "valid_stepUnits_xarray"), + ["valid_stepUnits_xarray"], + ), + ( + (KeyError(), KeyError(), KeyError(), KeyError(), "valid_dataType_xarray"), + ["valid_dataType_xarray"], + ), + ( + ( + KeyError(), + KeyError(), + KeyError(), + KeyError(), + KeyError(), + "valid_gridType_xarray", + ), + ["valid_gridType_xarray"], + ), + ], +) +def test_get_ds_recursive_selection( + mock_ds_grib, to_xarray_return_value, expected_result +): + """ + Test case where get_ds recursively selects the dataset by metadata fields. + """ + pid = 1 + lev = "surface" + + mock_ds_grib.sel.return_value.to_xarray.side_effect = to_xarray_return_value + + result = get_dataset(mock_ds_grib, pid, lev) + + # Ensure the recursive logic is triggered by calling sel multiple times + assert mock_ds_grib.sel.call_count >= len(to_xarray_return_value) + + # The result should contain the mocked xarray dataset + assert result == expected_result + + +def test_get_ds_keyerror_handling(caplog, mock_ds_grib): + """ + Test case where get_ds fails to retrieve data and handles multiple KeyErrors. + """ + pid = 1 + lev = "surface" + + # Simulate KeyErrors for all attempts to select datasets + mock_ds_grib.sel.return_value.to_xarray.side_effect = KeyError() + + result = get_dataset(mock_ds_grib, pid, lev) + + # Assert that the warning was logged + assert "GRIB file of level surface and paramId 1 cannot be read." in caplog.text + assert not result diff --git a/util/dataframe_ops.py b/util/dataframe_ops.py index 711ac0e0..3a573f0f 100644 --- a/util/dataframe_ops.py +++ b/util/dataframe_ops.py @@ -12,7 +12,7 @@ import pandas as pd from util.constants import CHECK_THRESHOLD, compute_statistics -from util.file_system import file_names_from_pattern +from util.file_system import get_file_names_from_pattern from util.log_handler import logger from util.model_output_parser import model_output_parser @@ -88,6 +88,13 @@ def read_input_file(label, file_name, specification): def df_from_file_ids(file_id, input_dir, file_specification): """ + Collect data frames for each combination of file id (fid) and specification + (spec). + Frames for the same fid and spec represent different timestamps and have to + be concatenated along time-axis (axis=1). + Time-concatenated frames from different ids and/or specifications will be + concatenated along variable-axis (axis=0). + file_id: [[file_type, file_pattern], [file_type, file_pattern], ...] List of 2-tuples. The 2-tuple combines two strings. The first sets the file_type and must be a key in file_specification. The second string @@ -111,14 +118,9 @@ def df_from_file_ids(file_id, input_dir, file_specification): separated by ":". """ - # Collect data frames for each combination of file id (fid) and - # specification (spec). Frames for the same fid and spec represent - # different timestamps and have to be concatenated along time-axis (axis=1). - # Time-concatenated frames from different ids and/or specifications will be - # concatenated along variable-axis (axis=0). fid_dfs = [] for file_type, file_pattern in file_id: - input_files, err = file_names_from_pattern(input_dir, file_pattern) + input_files, err = get_file_names_from_pattern(input_dir, file_pattern) if err > 0: logger.info( "Can not find any files for file_pattern %s. Continue.", file_pattern @@ -152,7 +154,12 @@ def df_from_file_ids(file_id, input_dir, file_specification): logger.error("Could not find any file.") sys.exit(2) - fid_dfs = unify_time_index(fid_dfs) + # workaround for not properly set time column + try: + fid_dfs = unify_time_index(fid_dfs) + except ValueError: + fid_dfs = adjust_time_index(fid_dfs) + # different file IDs will have different variables but with same timestamps: # concatenate along variable axis df = pd.concat(fid_dfs, axis=0) @@ -185,6 +192,58 @@ def unify_time_index(fid_dfs): return fid_dfs_out +def adjust_time_index(fid_dfs): + """ + Adjust the 'time' level of the MultiIndex in DataFrame columns by replacing + it with a sequential range based on the number of unique 'statistic' values. + + Parameters: + ----------- + fid_dfs : list of pandas.DataFrame + A list of DataFrames with MultiIndex columns containing 'time' and + 'statistic' levels. + + Returns: + -------- + fid_dfs_out : list of pandas.DataFrame + DataFrames with corrected 'time' values in their MultiIndex columns. + """ + fid_dfs_out = [] + for df in fid_dfs: + # Get the existing MultiIndex + current_multiindex = df.columns + + # Find the number of unique values in the 'statistic' level + unique_statistic_count = current_multiindex.get_level_values( + "statistic" + ).nunique() + + # Create a sequential integer range for 'time' values (based on the + # number of unique statistics) + new_time_values = list( + range( + len(current_multiindex.get_level_values("time")) + // unique_statistic_count + ) + ) + + # Repeat these integer values to match the length of your columns + new_time_repeated = np.repeat(new_time_values, unique_statistic_count) + + # Construct a new MultiIndex with updated 'time' values + new_multiindex = pd.MultiIndex.from_arrays( + [new_time_repeated, current_multiindex.get_level_values("statistic")], + names=["time", "statistic"], + ) + + # Assign the new MultiIndex to the DataFrame + df.columns = new_multiindex + + fid_dfs_out.append(df) + + return fid_dfs_out + + def check_intersection(df_ref, df_cur): # Check if variable names in reference and test case have any intersection # Check if numbers of time steps agree diff --git a/util/file_system.py b/util/file_system.py index 99ad3baf..8dda9e73 100644 --- a/util/file_system.py +++ b/util/file_system.py @@ -7,7 +7,7 @@ from util.log_handler import logger -def file_names_from_pattern(dir_name, file_pattern): +def get_file_names_from_pattern(dir_name, file_pattern): """ Search for all file patching file_pattern in directory dir_name @@ -38,5 +38,6 @@ def file_names_from_pattern(dir_name, file_pattern): for f in Path(dir_name).glob("*"): logger.debug(f.name) err = 1 + file_names = sorted(file_names) return file_names, err diff --git a/util/model_output_parser.py b/util/model_output_parser.py index 882fb1af..4a2d8ea4 100644 --- a/util/model_output_parser.py +++ b/util/model_output_parser.py @@ -26,6 +26,7 @@ import sys from collections.abc import Iterable +import earthkit.data import numpy as np import pandas as pd import xarray @@ -48,7 +49,7 @@ def parse_netcdf(file_id, filename, specification): var_dfs = [] for v in var_tmp: - sub_df = dataframe_from_ncfile( + sub_df = create_statistics_dataframe( file_id=file_id, filename=filename, varname=v, @@ -63,6 +64,108 @@ def parse_netcdf(file_id, filename, specification): return var_dfs +def parse_grib(file_id, filename, specification): + logger.debug("parse GRIB file %s", filename) + time_dim = specification["time_dim"] + horizontal_dims = specification["horizontal_dims"] + fill_value_key = specification.get("fill_value_key", None) + + ds_grib = earthkit.data.from_source("file", filename) + short_name_excl = specification["var_excl"] + + short_names = np.unique(ds_grib.metadata("shortName")) + short_names = short_names[ + np.isin(short_names, short_name_excl, invert=True, assume_unique=True) + ].tolist() + + level_types = np.unique(ds_grib.metadata("typeOfLevel")).tolist() + + var_dfs = [] + for lev in level_types: + param_ids = np.unique( + ds_grib.sel(typeOfLevel=lev, shortName=short_names).metadata("paramId") + ).tolist() + for pid in param_ids: + ds_temp_list = get_dataset(ds_grib, pid, lev) + for ds_temp in ds_temp_list: + v = list(ds_temp.keys())[0] + + dim_to_squeeze = [ + dim + for dim, size in zip(ds_temp[v].dims, ds_temp[v].shape) + if size == 1 and dim != time_dim + ] + ds = ds_temp.squeeze(dim=dim_to_squeeze) + + sub_df = create_statistics_dataframe( + file_id=file_id, + filename=filename, + varname=v, + time_dim=time_dim, + horizontal_dims=horizontal_dims, + xarray_ds=ds, + fill_value_key=fill_value_key, + ) + var_dfs.append(sub_df) + + return var_dfs + + +def get_dataset(ds_grib, pid, lev): + """ + Retrieve datasets from a GRIB file based on specified parameters and + hierarchical metadata. + + This function attempts to extract data from the GRIB file by selecting + fields that match the given `paramId` and `typeOfLevel`. If the initial + selection fails due to missing or mismatched metadata, the function + will explore other metadata fields such as `stepType`, `numberOfPoints`, + `stepUnits`, `dataType`, and `gridType` to find matching datasets. + + Parameters: + ----------- + ds_grib : GRIB object + The GRIB file object to extract data from. + pid : int + The parameter ID (`paramId`) to select in the GRIB file. + lev : str + The level type (`typeOfLevel`) to select in the GRIB file. + + Returns: + -------- + ds_list : list + A list of xarray datasets that match the specified parameter and level, + with additional filtering based on hierarchical metadata fields. + """ + ds_list = [] + selectors = {"paramId": pid, "typeOfLevel": lev} + metadata_keys = ["stepType", "numberOfPoints", "stepUnits", "dataType", "gridType"] + + def recursive_select(selects, depth=0): + try: + ds = ds_grib.sel(**selects).to_xarray() + ds_list.append(ds) + except KeyError: + if depth == len(metadata_keys): # No more metadata keys to try + return + key = metadata_keys[depth] + try: + values = np.unique(ds_grib.sel(**selects).metadata(key)).tolist() + for value in values: + selects[key] = value + recursive_select(selects, depth + 1) # Recurse to next level + except KeyError: + pass + + # Try initial selection + recursive_select(selectors) + + if not ds_list: + logger.warning("GRIB file of level %s and paramId %s cannot be read.", lev, pid) + + return ds_list + + def __get_variables(data, time_dim, horizontal_dims): # return a list of variable names from the dataset data that have a time dimension # and horizontal dimension or in case there is no time dimension just the variables @@ -105,9 +208,41 @@ def __get_variables(data, time_dim, horizontal_dims): return variables -def dataframe_from_ncfile( +def create_statistics_dataframe( file_id, filename, varname, time_dim, horizontal_dims, xarray_ds, fill_value_key ): # pylint: disable=too-many-positional-arguments + """ + Create a DataFrame of statistical values for a given variable from an xarray + dataset. + + This function computes statistics (mean, max, min, etc.) over horizontal + dimensions and organizes them into a pandas DataFrame, indexed by file ID, + variable name, and height (if applicable). + The columns represent time and the computed statistics. + + Parameters: + ----------- + file_id : str + Identifier for the file. + filename : str + Name of the input file. + varname : str + Name of the variable to process. + time_dim : str + Name of the time dimension. + horizontal_dims : list + List of dimensions to compute statistics over. + xarray_ds : xarray.Dataset + The xarray dataset containing the data. + fill_value_key : str + Key for the fill value in the dataset. + + Returns: + -------- + pd.DataFrame + DataFrame with the computed statistics indexed by file ID, variable, and + height. + """ statistics = statistics_over_horizontal_dim( xarray_ds[varname], horizontal_dims, @@ -232,4 +367,5 @@ def parse_csv(file_id, filename, specification): model_output_parser = { # global lookup dict "netcdf": parse_netcdf, "csv": parse_csv, + "grib": parse_grib, }