From 5358f607fbe4fbb80db8b058da32701a3643f767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Sat, 20 Jul 2024 13:26:44 +0200 Subject: [PATCH 1/2] Add type hints to EpochsArray constructor --- mne/io/array/array.py | 13 ++++++++++++- mne/io/base.py | 6 ++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mne/io/array/array.py b/mne/io/array/array.py index dda73b80a23..5c1a5b187a6 100644 --- a/mne/io/array/array.py +++ b/mne/io/array/array.py @@ -5,8 +5,12 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from typing import Literal + import numpy as np +from numpy.typing import NDArray +from ..._fiff.meas_info import Info from ...utils import _check_option, _validate_type, fill_doc, logger, verbose from ..base import BaseRaw @@ -52,7 +56,14 @@ class RawArray(BaseRaw): """ @verbose - def __init__(self, data, info, first_samp=0, copy="auto", verbose=None): + def __init__( + self, + data: NDArray[np.float64], + info: Info, + first_samp: int = 0, + copy: Literal["data", "info", "both", "auto"] | None = "auto", + verbose=None, + ): _validate_type(info, "info", "info") _check_option("copy", copy, ("data", "info", "both", "auto", None)) dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 diff --git a/mne/io/base.py b/mne/io/base.py index a1485e6caf6..88d3c707ecd 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -22,11 +22,13 @@ from pathlib import Path import numpy as np +from numpy.typing import NDArray from .._fiff.compensator import make_compensator, set_current_comp from .._fiff.constants import FIFF from .._fiff.meas_info import ( ContainsMixin, + Info, SetChannelsMixin, _ensure_infos_match, _unit2human, @@ -194,7 +196,7 @@ class BaseRaw( @verbose def __init__( self, - info, + info: Info, preload=False, first_samps=(0,), last_samps=None, @@ -890,7 +892,7 @@ def get_data( tmin=None, tmax=None, verbose=None, - ): + ) -> NDArray[np.float64] | tuple[NDArray[np.float64], NDArray[np.float64]]: """Get data in the given range. Parameters From 9a7c1407c02d7aa97d1be5b7f87a37a3f4a1603b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Richard=20H=C3=B6chenberger?= Date: Sat, 20 Jul 2024 13:26:44 +0200 Subject: [PATCH 2/2] Forgot future imports --- mne/io/array/array.py | 15 ++++++++++++++- mne/io/base.py | 8 ++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mne/io/array/array.py b/mne/io/array/array.py index dda73b80a23..41502aef927 100644 --- a/mne/io/array/array.py +++ b/mne/io/array/array.py @@ -5,8 +5,14 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from __future__ import annotations + +from typing import Literal + import numpy as np +from numpy.typing import NDArray +from ..._fiff.meas_info import Info from ...utils import _check_option, _validate_type, fill_doc, logger, verbose from ..base import BaseRaw @@ -52,7 +58,14 @@ class RawArray(BaseRaw): """ @verbose - def __init__(self, data, info, first_samp=0, copy="auto", verbose=None): + def __init__( + self, + data: NDArray[np.float64], + info: Info, + first_samp: int = 0, + copy: Literal["data", "info", "both", "auto"] | None = "auto", + verbose=None, + ): _validate_type(info, "info", "info") _check_option("copy", copy, ("data", "info", "both", "auto", None)) dtype = np.complex128 if np.any(np.iscomplex(data)) else np.float64 diff --git a/mne/io/base.py b/mne/io/base.py index a1485e6caf6..674122e5d4e 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -10,6 +10,8 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. +from __future__ import annotations + import os import os.path as op import shutil @@ -22,11 +24,13 @@ from pathlib import Path import numpy as np +from numpy.typing import NDArray from .._fiff.compensator import make_compensator, set_current_comp from .._fiff.constants import FIFF from .._fiff.meas_info import ( ContainsMixin, + Info, SetChannelsMixin, _ensure_infos_match, _unit2human, @@ -194,7 +198,7 @@ class BaseRaw( @verbose def __init__( self, - info, + info: Info, preload=False, first_samps=(0,), last_samps=None, @@ -890,7 +894,7 @@ def get_data( tmin=None, tmax=None, verbose=None, - ): + ) -> NDArray[np.float64] | tuple[NDArray[np.float64], NDArray[np.float64]]: """Get data in the given range. Parameters