diff --git a/pynetdicom/transport.py b/pynetdicom/transport.py index d160c8b337..dd1f881062 100644 --- a/pynetdicom/transport.py +++ b/pynetdicom/transport.py @@ -1,21 +1,18 @@ """Implementation of the Transport Service.""" +import asyncio from copy import deepcopy from datetime import datetime import logging -import select import socket -try: - from SocketServer import TCPServer, ThreadingMixIn, BaseRequestHandler -except ImportError: - from socketserver import TCPServer, ThreadingMixIn, BaseRequestHandler + +import pynetdicom + try: import ssl _HAS_SSL = True except ImportError: _HAS_SSL = False -from struct import pack -import threading from pynetdicom import evt, _config from pynetdicom._globals import MODE_ACCEPTOR @@ -28,13 +25,12 @@ LOGGER = logging.getLogger('pynetdicom.transport') -class AssociationSocket: - """A wrapper for a `socket - `_ object. +class AssociationStream: + """A wrapper for a StreamReader and StreamWriter objects. .. versionadded:: 1.2 - Provides an interface for ``socket`` that is integrated + Provides an interface for transport that is integrated nicely with an :class:`~pynetdicom.association.Association` instance and the state machine. @@ -45,44 +41,29 @@ class AssociationSocket: :meth:`ready` will block for (default ``0.5``). A value of ``0`` specifies a poll and never blocks. A value of ``None`` blocks until a connection is ready. - socket : socket.socket or None - The wrapped socket, will be ``None`` if :meth:`close` is called. + transport : asyncio.transport or None + The wrapped transport, will be ``None`` if :meth:`close` is called. """ - def __init__(self, assoc, client_socket=None, address=('', 0)): - """Create a new :class:`AssociationSocket`. + def __init__(self, assoc, reader, writer): + """Create a new :class:`AssociationStream`. Parameters ---------- assoc : association.Association The :class:`~pynetdicom.association.Association` instance that will be using the socket to communicate. - client_socket : socket.socket, optional - The ``socket.socket`` to wrap, - if not supplied then a new socket will be created instead. - address : 2-tuple, optional - If *client_socket* is ``None`` then this is the ``(host, port)`` to - bind the newly created socket to, which by default will be - ``('', 0)``. + transport : asyncio.transport, optional + The ``asyncio.transport`` to wrap """ self._assoc = assoc - if client_socket is not None and address != ('', 0): - LOGGER.warning( - "AssociationSocket instantiated with both a 'client_socket' " - "and bind 'address'. The original socket will not be rebound" - ) + self._is_connected = True - self._ready = threading.Event() + self._reader = reader + self._writer = writer - if client_socket is None: - self.socket = self._create_socket(address) - self._is_connected = False - else: - self.socket = client_socket - self._is_connected = True - self._ready.set() - # Evt5: Transport connection indication - self.event_queue.put('Evt5') + # Evt5: Transport connection indication + self.event_queue.put('Evt5') self._tls_args = None self.select_timeout = 0.5 @@ -97,131 +78,22 @@ def assoc(self): def close(self): """Close the connection to the peer and shutdown the socket. - Sets :attr:`AssociationSocket.socket` to ``None`` once complete. + Sets :attr:`AssociationStream.socket` to ``None`` once complete. **Events Emitted** - Evt17: Transport connection closed """ - if self.socket is None or self._is_connected is False: + if self._writer is None or self._is_connected is False: return - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.socket.close() - self.socket = None + self._writer.close() + self._writer = None + self._reader = None self._is_connected = False # Evt17: Transport connection closed self.event_queue.put('Evt17') - def connect(self, address): - """Try and connect to a remote at `address`. - - **Events Emitted** - - - Evt2: Transport connection confirmed - - Evt17: Transport connection closed - - Parameters - ---------- - address : 2-tuple - The ``(host, port)`` IPv4 address to connect to. - """ - if self.socket is None: - self.socket = self._create_socket() - - try: - if self.tls_args: - context, server_hostname = self.tls_args - self.socket = context.wrap_socket( - self.socket, - server_side=False, - server_hostname=server_hostname, - ) - # Set ae connection timeout - self.socket.settimeout(self.assoc.connection_timeout) - # Try and connect to remote at (address, port) - # raises socket.error if connection refused - self.socket.connect(address) - # Clear ae connection timeout - self.socket.settimeout(None) - # Trigger event - connection open - evt.trigger(self.assoc, evt.EVT_CONN_OPEN, {'address' : address}) - self._is_connected = True - # Evt2: Transport connection confirmation - self.event_queue.put('Evt2') - except OSError as exc: - # Log connection failure - LOGGER.error( - "Association request failed: unable to connect to remote" - ) - LOGGER.error(f"TCP Initialisation Error: {exc}") - # Log exception if TLS issue to help with troubleshooting - if isinstance(exc, ssl.SSLError): - LOGGER.exception(exc) - - # Don't be tempted to replace this with a self.close() call - - # it doesn't work because `_is_connected` is False - if self.socket: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except: - pass - self.socket.close() - self.socket = None - self.event_queue.put('Evt17') - finally: - self._ready.set() - - def _create_socket(self, address=('', 0)): - """Create a new IPv4 TCP socket and set it up for use. - - *Socket Options* - - - ``SO_REUSEADDR`` is 1 - - ``SO_RCVTIMEO`` is set to the Association's ``network_timeout`` - value. - - Parameters - ---------- - address : 2-tuple, optional - The ``(host, port)`` to bind the socket to. By default the socket - is bound to ``('', 0)``, i.e. the first available port. - - Returns - ------- - socket.socket - A bound and unconnected socket instance. - """ - # AF_INET: IPv4, SOCK_STREAM: TCP socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # SO_REUSEADDR: reuse the socket in TIME_WAIT state without - # waiting for its natural timeout to expire - # Allows local address reuse - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # If no timeout is set then recv() will block forever if - # the connection is kept alive with no data sent - # SO_RCVTIMEO: the timeout on receive calls in seconds - # set using a packed binary string containing two uint32s as - # (seconds, microseconds) - if self.assoc.network_timeout is not None: - timeout_seconds = int(self.assoc.network_timeout) - timeout_microsec = int(self.assoc.network_timeout % 1 * 1000) - sock.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVTIMEO, - pack('ll', timeout_seconds, timeout_microsec) - ) - - sock.bind(address) - - self._is_connected = False - - return sock - @property def event_queue(self): """Return the :class:`~pynetdicom.association.Association`'s event @@ -266,25 +138,13 @@ def ready(self): ``True`` if the socket has data ready to be read, ``False`` otherwise. """ - if self.socket is None or self._is_connected is False: - return False - - try: - # Use a timeout of 0 so we get an "instant" result - ready, _, _ = select.select([self.socket], [], [], 0) - except (socket.error, socket.timeout, ValueError): - # Evt17: Transport connection closed - self.event_queue.put('Evt17') + if self._writer is None or self._is_connected is False: return False - # An SSLSocket may have buffered data available that `select` - # is unaware of - see #528 - if _HAS_SSL and isinstance(self.socket, ssl.SSLSocket): - return bool(ready) or bool(self.socket.pending()) - - return bool(ready) + # TODO: not sure if this is a good idea + return self._writer._transport.is_reading() - def recv(self, nr_bytes): + async def recv(self, nr_bytes): """Read `nr_bytes` from the socket. *Events Emitted* @@ -301,31 +161,7 @@ def recv(self, nr_bytes): bytearray The data read from the socket. """ - bytestream = bytearray() - nr_read = 0 - # socket.recv() returns when the network buffer has been emptied - # not necessarily when the number of bytes requested have been - # read. Its up to us to keep calling recv() until we have all the - # data we want - # **BLOCKING** until either all the data is read or an error occurs - while nr_read < nr_bytes: - # Python docs recommend reading a relatively small power of 2 - # such as 4096 - bufsize = 4096 - if (nr_bytes - nr_read) < bufsize: - bufsize = nr_bytes - nr_read - - bytes_read = self.socket.recv(bufsize) - - # If socket.recv() reads 0 bytes then the connection has been - # broken, so return what we have so far - if not bytes_read: - return bytestream - - bytestream.extend(bytes_read) - nr_read += len(bytes_read) - - return bytestream + await self._reader.read(nr_bytes) def send(self, bytestream): """Try and send the data in `bytestream` to the remote. @@ -340,23 +176,10 @@ def send(self, bytestream): bytestream : bytes The data to send to the remote. """ - total_sent = 0 - length_data = len(bytestream) - try: - while total_sent < length_data: - # Returns the number of bytes sent - nr_sent = self.socket.send(bytestream[total_sent:]) - total_sent += nr_sent - - evt.trigger(self.assoc, evt.EVT_DATA_SENT, {'data' : bytestream}) - except (socket.error, socket.timeout): - # Evt17: Transport connection closed - self.event_queue.put('Evt17') - - def __str__(self): - """Return the string output for ``socket``.""" - return self.socket.__str__() + self._writer.write(bytestream) + evt.trigger(self.assoc, evt.EVT_DATA_SENT, {'data' : bytestream}) + # TODO: make tls work @property def tls_args(self): """Return the TLS context and hostname (if set) or ``None``. @@ -384,71 +207,68 @@ def tls_args(self, tls_args): self._tls_args = tls_args -class RequestHandler(BaseRequestHandler): - """Connection request handler for the ``AssociationServer``. +class AssociationProtocol(asyncio.streams.StreamReaderProtocol): + def __init__(self, server): + self._assoc = None + self._stream = None + self._server = server - .. versionadded:: 1.2 - - Attributes - ---------- - client_address : 2-tuple - The ``(host, port)`` of the remote. - request : socket.socket - The (unaccepted) client socket. - server : transport.AssociationServer or transport.ThreadedAssociationServer - The server that received the connection request. - """ - @property - def ae(self): - """Return the server's parent AE.""" - return self.server.ae + super().__init__( + stream_reader=asyncio.StreamReader(), + client_connected_cb=self._handle_connection + ) - def handle(self): + async def _handle_connection(self, reader, writer): """Handle an association request. * Creates a new Association acceptor instance and configures it. * Sets the Association's socket to the request's socket. * Starts the Association reactor. """ - assoc = self._create_association() - + self._create_association(reader=reader, writer=writer) # Trigger must be after binding the events evt.trigger( - assoc, evt.EVT_CONN_OPEN, {'address' : self.client_address} + self._assoc, evt.EVT_CONN_OPEN, {'address' : self.remote} ) + self._server._active_connections.append(self) + await self._assoc.start() - assoc.start() + @property + def ae(self): + """Return the server's parent AE.""" + return self._server.ae @property def local(self): """Return a 2-tuple of the local server's ``(host, port)`` address.""" - return self.server.server_address + return self._server.address @property def remote(self): """Return a 2-tuple of the remote client's ``(host, port)`` address.""" - return self.client_address + return self._stream._writer.get_extra_info('peername') - def _create_association(self): + def _create_association(self, reader, writer): """Create an :class:`Association` object for the current request. .. versionadded:: 1.5 """ from pynetdicom.association import Association - assoc = Association(self.ae, MODE_ACCEPTOR) - assoc._server = self.server + self._assoc = assoc = Association(self.ae, MODE_ACCEPTOR) + assoc._server = self._server # Set the thread name timestamp = datetime.strftime(datetime.now(), "%Y%m%d%H%M%S") assoc.name = f"AcceptorThread@{timestamp}" - sock = AssociationSocket(assoc, client_socket=self.request) - assoc.set_socket(sock) + self._stream = stream = AssociationStream(assoc, reader, writer) + # TODO: create stream + # assoc.set_stream(stream) # Association Acceptor object -> local AE assoc.acceptor.maximum_length = self.ae.maximum_pdu_size - assoc.acceptor.ae_title = self.server.ae_title + assoc.acceptor.ae_title = self._server.ae_title assoc.acceptor.address = self.local[0] assoc.acceptor.port = self.local[1] assoc.acceptor.implementation_class_uid = ( @@ -457,24 +277,30 @@ def _create_association(self): assoc.acceptor.implementation_version_name = ( self.ae.implementation_version_name ) - assoc.acceptor.supported_contexts = deepcopy(self.server.contexts) + assoc.acceptor.supported_contexts = deepcopy(self._server.contexts) # Association Requestor object -> remote AE assoc.requestor.address = self.remote[0] assoc.requestor.port = self.remote[1] # Bind events to handlers - for event in self.server._handlers: + for event in self._server._handlers: # Intervention events - if event.is_intervention and self.server._handlers[event]: - assoc.bind(event, *self.server._handlers[event]) + if event.is_intervention and self._server._handlers[event]: + assoc.bind(event, *self._server._handlers[event]) elif event.is_notification: - for handler in self.server._handlers[event]: + for handler in self._server._handlers[event]: assoc.bind(event, *handler) - return assoc + return assoc, stream + + def connection_lost(self, exc): + super().connection_lost(exc) + self._stream.close() + self._server._active_connections.remove(self) -class AssociationServer(TCPServer): + +class AssociationServer: """An Association server implementation. .. versionadded:: 1.2 @@ -533,11 +359,10 @@ def __init__(self, ae, address, ae_title, contexts, ssl_context=None, self.ae_title = ae_title self.contexts = contexts self.ssl_context = ssl_context - self.allow_reuse_address = True - self.socket = None + self.address = address - request_handler = request_handler or RequestHandler - super().__init__(address, request_handler, bind_and_activate=True) + # active connections (aka protocols) are stored here + self._active_connections = [] self.timeout = 60 @@ -550,6 +375,10 @@ def __init__(self, ae, address, ae_title, contexts, ssl_context=None, for evt_hh_args in (evt_handlers or {}): self.bind(*evt_hh_args) + + def __call__(self): + return AssociationProtocol(server=self) + def bind(self, event, handler, args=None): """Bind a callable `handler` to an `event`. @@ -604,11 +433,7 @@ def active_associations(self): """Return the server's running :class:`~pynetdicom.association.Association` acceptor instances """ - # Find all AcceptorThreads with `_server` as self - threads = [ - tt for tt in threading.enumerate() if 'AcceptorThread' in tt.name - ] - return [tt for tt in threads if tt._server is self] + return [x._assoc for x in self._active_connections if x._is_connected] def get_events(self): """Return a list of currently bound events. @@ -646,81 +471,6 @@ def get_handlers(self, event): return self._handlers[event] - def get_request(self): - """Handle a connection request. - - If :attr:`~AssociationServer.ssl_context` is set then the client socket - will be wrapped using - :meth:`SSLContext.wrap_socket()`. - - Returns - ------- - client_socket : socket.socket - The connection request. - address : 2-tuple - The client's address as ``(host, port)``. - """ - client_socket, address = self.socket.accept() - if self.ssl_context: - client_socket = self.ssl_context.wrap_socket( - client_socket, server_side=True - ) - - return client_socket, address - - def process_request(self, request, client_address): - """Process a connection request.""" - self.finish_request(request, client_address) - - def server_bind(self): - """Bind the socket and set the socket options. - - - ``socket.SO_REUSEADDR`` is set to ``1`` - - ``socket.SO_RCVTIMEO`` is set to - :attr:`AE.network_timeout - ` unless the - value is ``None`` in which case it will be left unset. - """ - # SO_REUSEADDR: reuse the socket in TIME_WAIT state without - # waiting for its natural timeout to expire - # Allows local address reuse - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # If no timeout is set then recv() will block forever if - # the connection is kept alive with no data sent - # SO_RCVTIMEO: the timeout on receive calls in seconds - # set using a packed binary string containing two uint32s as - # (seconds, microseconds) - if self.ae.network_timeout is not None: - timeout_seconds = int(self.ae.network_timeout) - timeout_microsec = int(self.ae.network_timeout % 1 * 1000) - self.socket.setsockopt( - socket.SOL_SOCKET, - socket.SO_RCVTIMEO, - pack('ll', timeout_seconds, timeout_microsec) - ) - - # Bind the socket to an (address, port) - # If address is '' then the socket is reachable by any - # address the machine may have, otherwise is visible only on that - # address - self.socket.bind(self.server_address) - self.server_address = self.socket.getsockname() - - def server_close(self): - """Close the server.""" - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - - self.socket.close() - - def shutdown(self): - """Completely shutdown the server and close it's socket.""" - super().shutdown() - self.server_close() - self.ae._servers.remove(self) - @property def ssl_context(self): """Return the :class:`ssl.SSLContext` (if available). @@ -779,17 +529,10 @@ def unbind(self, event, handler): for assoc in self.active_associations: assoc.unbind(event, handler) - -class ThreadedAssociationServer(ThreadingMixIn, AssociationServer): - """An :class:`AssociationServer` suitable for threading. - - .. versionadded:: 1.2 - """ - def process_request_thread(self, request, client_address): - """Process a connection request.""" - # pylint: disable=broad-except - try: - self.finish_request(request, client_address) - except Exception: - self.handle_error(request, client_address) - self.shutdown_request(request) + async def serve_forever(self): + loop = asyncio.get_event_loop() + server = await loop.create_server( + self, '127.0.0.1', 4242 + ) + async with server: + await server.serve_forever()