Skip to content

Commit

Permalink
Merge pull request #1970 from minrk/default-type
Browse files Browse the repository at this point in the history
document Generic types
  • Loading branch information
minrk authored Apr 10, 2024
2 parents 49f45dc + f5a664d commit 04b8426
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 9 deletions.
23 changes: 23 additions & 0 deletions docs/source/api/zmq.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,29 @@

## Basic Classes

````{note}
For typing purposes, `zmq.Context` and `zmq.Socket` are Generics,
which means they will accept any Context or Socket implementation.
The base `zmq.Context()` constructor returns the type
`zmq.Context[zmq.Socket[bytes]]`.
If you are using type annotations and want to _exclude_ the async subclasses,
use the resolved types instead of the base Generics:
```python
ctx: zmq.Context[zmq.Socket[bytes]] = zmq.Context()
sock: zmq.Socket[bytes]
```
in pyzmq 26, these are available as the Type Aliases (not actual classes!):
```python
ctx: zmq.SyncContext = zmq.Context()
sock: zmq.SyncSocket
```
````

### {class}`Context`

```{eval-rst}
Expand Down
13 changes: 13 additions & 0 deletions zmq/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,19 @@ def recv_multipart(
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)

@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...

def recv( # type: ignore
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]:
Expand Down
9 changes: 9 additions & 0 deletions zmq/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,12 @@ def __getitem__(self, key):

class TypedDict(Dict): # type: ignore
pass


if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
try:
from typing_extensions import TypeAlias
except ImportError:
TypeAlias = type # type: ignore
2 changes: 1 addition & 1 deletion zmq/eventloop/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def __init__(self: Context, *args: Any, **kwargs: Any) -> None:
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
super().__init__(*args, **kwargs) # type: ignore
16 changes: 10 additions & 6 deletions zmq/sugar/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from weakref import WeakSet

import zmq
from zmq._typing import TypeAlias
from zmq.backend import Context as ContextBase
from zmq.constants import ContextOption, Errno, SocketOption
from zmq.error import ZMQError
from zmq.utils.interop import cast_int_addr

from .attrsettr import AttributeSetter, OptValT
from .socket import Socket
from .socket import Socket, SyncSocket

# notice when exiting, to avoid triggering term on exit
_exiting = False
Expand Down Expand Up @@ -78,18 +79,18 @@ class Context(ContextBase, AttributeSetter, Generic[_SocketType]):
_socket_class: type[_SocketType] = Socket # type: ignore

@overload
def __init__(self: Context[Socket], io_threads: int = 1): ...
def __init__(self: SyncContext, io_threads: int = 1): ...

@overload
def __init__(self: Context[Socket], io_threads: Context):
def __init__(self: SyncContext, io_threads: Context):
# this should be positional-only, but that requires 3.8
...

@overload
def __init__(self: Context[Socket], *, shadow: Context | int): ...
def __init__(self: SyncContext, *, shadow: Context | int): ...

def __init__(
self: Context[Socket],
self: SyncContext,
io_threads: int | Context = 1,
shadow: Context | int = 0,
) -> None:
Expand Down Expand Up @@ -415,4 +416,7 @@ def __delattr__(self, key: str) -> None:
del self.sockopts[opt]


__all__ = ['Context']
SyncContext: TypeAlias = Context[SyncSocket]


__all__ = ['Context', 'SyncContext']
6 changes: 4 additions & 2 deletions zmq/sugar/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from warnings import warn

import zmq
from zmq._typing import Literal
from zmq._typing import Literal, TypeAlias
from zmq.backend import Socket as SocketBase
from zmq.error import ZMQBindError, ZMQError
from zmq.utils import jsonapi
Expand Down Expand Up @@ -1107,4 +1107,6 @@ def disable_monitor(self) -> None:
self.monitor(None, 0)


__all__ = ['Socket']
SyncSocket: TypeAlias = Socket[bytes]

__all__ = ['Socket', 'SyncSocket']

0 comments on commit 04b8426

Please sign in to comment.