Skip to content

Remove from pool on connection failure #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 14, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions neo4j/v1/bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@

class BufferingSocket(object):

def __init__(self, socket):
self.address = socket.getpeername()
self.socket = socket
def __init__(self, connection):
self.connection = connection
self.socket = connection.socket
self.address = self.socket.getpeername()
self.buffer = bytearray()

def fill(self):
Expand All @@ -96,6 +97,10 @@ def fill(self):
self.buffer[len(self.buffer):] = received
else:
if ready_to_read is not None:
# If this connection fails, remove this address from the
# connection pool to which this connection belongs.
if self.connection.pool:
self.connection.pool.remove(self.address)
raise ServiceUnavailable("Failed to read from connection %r" % (self.address,))

def read_message(self):
Expand Down Expand Up @@ -211,9 +216,12 @@ class Connection(object):
.. note:: logs at INFO level
"""

#: The pool of which this connection is a member
pool = None

def __init__(self, sock, **config):
self.socket = sock
self.buffering_socket = BufferingSocket(sock)
self.buffering_socket = BufferingSocket(self)
self.address = sock.getpeername()
self.channel = ChunkChannel(sock)
self.packer = Packer(self.channel)
Expand Down Expand Up @@ -411,6 +419,7 @@ def acquire(self, address):
connection.in_use = True
return connection
connection = self.connector(address)
connection.pool = self
connection.in_use = True
connections.append(connection)
return connection
Expand Down
28 changes: 23 additions & 5 deletions neo4j/v1/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,22 +263,40 @@ def refresh_routing_table(self):
def acquire_for_read(self):
""" Acquire a connection to a read server.
"""
self.refresh_routing_table()
return self.acquire(next(self.routing_table.readers))
while True:
address = None
while address is None:
self.refresh_routing_table()
address = next(self.routing_table.readers)
try:
connection = self.acquire(address)
except ServiceUnavailable:
self.remove(address)
else:
return connection

def acquire_for_write(self):
""" Acquire a connection to a write server.
"""
self.refresh_routing_table()
return self.acquire(next(self.routing_table.writers))
while True:
address = None
while address is None:
self.refresh_routing_table()
address = next(self.routing_table.writers)
try:
connection = self.acquire(address)
except ServiceUnavailable:
self.remove(address)
else:
return connection

def remove(self, address):
""" Remove an address from the connection pool, if present, closing
all connections to that address. Also remove from the routing table.
"""
super(RoutingConnectionPool, self).remove(address)
# We use `discard` instead of `remove` here since the former
# will not fail if the address has already been removed.
self.routing_table.routers.discard(address)
self.routing_table.readers.discard(address)
self.routing_table.writers.discard(address)
super(RoutingConnectionPool, self).remove(address)
4 changes: 4 additions & 0 deletions test/resources/fail_on_init.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
!: AUTO INIT
!: AUTO RESET

S: <EXIT>
8 changes: 8 additions & 0 deletions test/resources/router_with_multiple_writers.script
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
!: AUTO INIT
!: AUTO RESET

C: RUN "CALL dbms.cluster.routing.getServers" {}
PULL_ALL
S: SUCCESS {"fields": ["ttl", "servers"]}
RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"]},{"role":"READ","addresses":["127.0.0.1:9004","127.0.0.1:9005"]},{"role":"WRITE","addresses":["127.0.0.1:9006","127.0.0.1:9007"]}]]
SUCCESS {}
22 changes: 22 additions & 0 deletions test/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,17 @@ def test_connected_to_reader(self):
connection = pool.acquire_for_read()
assert connection.address in pool.routing_table.readers

def test_should_retry_if_first_reader_fails(self):
with StubCluster({9001: "router.script",
9004: "fail_on_init.script",
9005: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire_for_read()
assert ("127.0.0.1", 9004) not in pool.routing_table.readers
assert ("127.0.0.1", 9005) in pool.routing_table.readers


class RoutingConnectionPoolAcquireForWriteTestCase(ServerTestCase):

Expand All @@ -596,6 +607,17 @@ def test_connected_to_writer(self):
connection = pool.acquire_for_write()
assert connection.address in pool.routing_table.writers

def test_should_retry_if_first_writer_fails(self):
with StubCluster({9001: "router_with_multiple_writers.script",
9006: "fail_on_init.script",
9007: "empty.script"}):
address = ("127.0.0.1", 9001)
with RoutingConnectionPool(connector, address) as pool:
assert not pool.routing_table.is_fresh()
_ = pool.acquire_for_write()
assert ("127.0.0.1", 9006) not in pool.routing_table.writers
assert ("127.0.0.1", 9007) in pool.routing_table.writers


class RoutingConnectionPoolRemoveTestCase(ServerTestCase):

Expand Down
1 change: 0 additions & 1 deletion test/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class ServerTestCase(TestCase):

known_hosts = KNOWN_HOSTS
known_hosts_backup = known_hosts + ".backup"
servers = []

def setUp(self):
if isfile(self.known_hosts):
Expand Down