Skip to content
Merged
58 changes: 25 additions & 33 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,60 +38,52 @@
from logging import getLogger
from random import choice
from select import select
from time import perf_counter

from socket import (
AF_INET,
AF_INET6,
SHUT_RDWR,
SO_KEEPALIVE,
socket,
SOL_SOCKET,
SO_KEEPALIVE,
SHUT_RDWR,
timeout as SocketTimeout,
AF_INET,
AF_INET6,
)

from ssl import (
HAS_SNI,
SSLError,
)

from struct import (
pack as struct_pack,
)

from threading import (
Condition,
Lock,
RLock,
Condition,
)
from time import perf_counter

from neo4j.addressing import Address
from neo4j.conf import PoolConfig
from neo4j._exceptions import (
BoltHandshakeError,
BoltProtocolError,
BoltRoutingError,
BoltSecurityError,
BoltProtocolError,
BoltHandshakeError,
)
from neo4j.exceptions import (
ServiceUnavailable,
ClientError,
SessionExpired,
ReadServiceUnavailable,
WriteServiceUnavailable,
ConfigurationError,
UnsupportedServerProduct,
from neo4j.addressing import Address
from neo4j.api import (
READ_ACCESS,
Version,
WRITE_ACCESS,
)
from neo4j.routing import RoutingTable
from neo4j.conf import (
PoolConfig,
WorkspaceConfig,
)
from neo4j.api import (
READ_ACCESS,
WRITE_ACCESS,
Version,
from neo4j.exceptions import (
ClientError,
ConfigurationError,
ReadServiceUnavailable,
ServiceUnavailable,
SessionExpired,
UnsupportedServerProduct,
WriteServiceUnavailable,
)
from neo4j.routing import RoutingTable

# Set up logger
log = getLogger("neo4j")
Expand Down Expand Up @@ -258,7 +250,7 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_
except Exception as error:
log.debug("[#%04X] C: <CLOSE> %s", s.getsockname()[1], str(error))
_close_socket(s)
raise error
raise

return connection

Expand Down Expand Up @@ -522,7 +514,7 @@ def deactivate(self, address):
connections.remove(conn)
try:
conn.close()
except IOError:
except OSError:
pass
if not connections:
self.remove(address)
Expand All @@ -538,7 +530,7 @@ def remove(self, address):
for connection in self.connections.pop(address, ()):
try:
connection.close()
except IOError:
except OSError:
pass

def close(self):
Expand Down
97 changes: 47 additions & 50 deletions neo4j/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,49 @@
# limitations under the License.

from collections import deque
from logging import getLogger
from ssl import SSLSocket
from time import perf_counter

from neo4j._exceptions import (
BoltError,
BoltProtocolError,
)
from neo4j.addressing import Address
from neo4j.api import (
Version,
READ_ACCESS,
ServerInfo,
Version,
)
from neo4j.io._common import (
Inbox,
Outbox,
Response,
InitResponse,
CommitResponse,
)
from neo4j.meta import get_user_agent
from neo4j.exceptions import (
AuthError,
DatabaseUnavailable,
ConfigurationError,
DatabaseUnavailable,
DriverError,
ForbiddenOnReadOnlyDatabase,
IncompleteCommit,
NotALeader,
ServiceUnavailable,
SessionExpired,
)
from neo4j._exceptions import BoltProtocolError
from neo4j.packstream import (
Unpacker,
Packer,
)
from neo4j.io import (
check_supported_server_product,
Bolt,
BoltPool,
check_supported_server_product,
)
from neo4j.api import ServerInfo
from neo4j.addressing import Address
from neo4j.io._common import (
CommitResponse,
Inbox,
InitResponse,
Outbox,
Response,
)
from neo4j.meta import get_user_agent
from neo4j.packstream import (
Packer,
Unpacker,
)

from logging import getLogger
log = getLogger("neo4j")


Expand Down Expand Up @@ -85,7 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
self.outbox = Outbox()
self.inbox = Inbox(self.socket, on_error=self._set_defunct)
self.inbox = Inbox(self.socket, on_error=self._set_defunct_read)
self.packer = Packer(self.outbox)
self.unpacker = Unpacker(self.inbox)
self.responses = deque()
Expand Down Expand Up @@ -135,7 +140,7 @@ def der_encoded_server_certificate(self):
def local_port(self):
try:
return self.socket.getsockname()[1]
except IOError:
except OSError:
return 0

def get_base_headers(self):
Expand Down Expand Up @@ -292,7 +297,10 @@ def fail(metadata):
def _send_all(self):
data = self.outbox.view()
if data:
self.socket.sendall(data)
try:
self.socket.sendall(data)
except OSError as error:
self._set_defunct_write(error)
self.outbox.clear()

def send_all(self):
Expand All @@ -306,17 +314,7 @@ def send_all(self):
raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))

try:
self._send_all()
except (IOError, OSError) as error:
log.error("Failed to write data to connection "
"{!r} ({!r}); ({!r})".
format(self.unresolved_address,
self.server_info.address,
"; ".join(map(repr, error.args))))
if self.pool:
self.pool.deactivate(address=self.unresolved_address)
raise
self._send_all()

def fetch_message(self):
""" Receive at least one message from the server, if available.
Expand All @@ -336,17 +334,7 @@ def fetch_message(self):
return 0, 0

# Receive exactly one message
try:
details, summary_signature, summary_metadata = next(self.inbox)
except (IOError, OSError) as error:
log.error("Failed to read data from connection "
"{!r} ({!r}); ({!r})".
format(self.unresolved_address,
self.server_info.address,
"; ".join(map(repr, error.args))))
if self.pool:
self.pool.deactivate(address=self.unresolved_address)
raise
details, summary_signature, summary_metadata = next(self.inbox)

if details:
log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data
Expand Down Expand Up @@ -380,11 +368,20 @@ def fetch_message(self):

return len(details), 1

def _set_defunct(self, error=None):
direct_driver = isinstance(self.pool, BoltPool)
def _set_defunct_read(self, error=None):
message = "Failed to read from defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error)

message = ("Failed to read from defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))
def _set_defunct_write(self, error=None):
message = "Failed to write data to connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error)

def _set_defunct(self, message, error=None):
direct_driver = isinstance(self.pool, BoltPool)

if error:
log.error(str(error))
Expand Down Expand Up @@ -445,12 +442,12 @@ def close(self):
self._append(b"\x02", ())
try:
self._send_all()
except:
except (OSError, BoltError, DriverError):
pass
log.debug("[#%04X] C: <CLOSE>", self.local_port)
try:
self.socket.close()
except IOError:
except OSError:
pass
finally:
self._closed = True
Expand Down
Loading