Issues with buildProtocol in the ServerFactory within the tests
So yet again I have been confronted with broken tests in Ubuntu One. As I have mentioned before, I have spent a significant amount of time ensuring that the tests of Ubuntu One (which use Twisted a lot) are deterministic and that we do not leave a dirty reactor behind. In order to do that, a few weeks ago I wrote the following code to help the rest of the team write such tests:
import os
import shutil
import tempfile

from twisted.internet import defer, endpoints, protocol
from twisted.spread import pb

from ubuntuone.devtools.testcases import BaseTestCase

# no init method + twisted common warnings
# pylint: disable=W0232, C0103, E1101


def server_protocol_factory(cls):
    """Return a subclass of ``cls`` that reports when its connection is lost.

    ``cls`` is the protocol class to wrap; ``protocol.Protocol`` is used when
    it is None. The returned class fires the factory's
    ``testserver_on_connection_lost`` deferred from ``connectionLost``, but
    only once the factory has been flagged as ``_disconnecting``.
    """
    if cls is None:
        cls = protocol.Protocol

    class ServerTidyProtocol(cls):
        """A tidy protocol."""

        def connectionLost(self, *args):
            """Lost the connection."""
            cls.connectionLost(self, *args)
            # lets tell everyone
            # pylint: disable=W0212
            if (self.factory._disconnecting
                    and self.factory.testserver_on_connection_lost is not None
                    and not self.factory.testserver_on_connection_lost.called):
                self.factory.testserver_on_connection_lost.callback(self)
            # pylint: enable=W0212

    return ServerTidyProtocol


def client_protocol_factory(cls):
    """Return a subclass of ``cls`` that reports when its connection is lost.

    ``cls`` is the protocol class to wrap; ``protocol.Protocol`` is used when
    it is None. The factory's ``testserver_on_connection_lost`` deferred is
    fired before the wrapped ``connectionLost`` is called.
    """
    if cls is None:
        cls = protocol.Protocol

    class ClientTidyProtocol(cls):
        """A tidy protocol."""

        def connectionLost(self, *a):
            """Connection lost."""
            # pylint: disable=W0212
            if (self.factory._disconnecting
                    and self.factory.testserver_on_connection_lost is not None
                    and not self.factory.testserver_on_connection_lost.called):
                self.factory.testserver_on_connection_lost.callback(self)
            # pylint: enable=W0212
            cls.connectionLost(self, *a)

    return ClientTidyProtocol


class TidySocketServer(object):
    """Ensure that twisted servers are correctly managed in tests.

    Closing a twisted server is a complicated matter. In order to do so you
    have to ensure that three different deferreds are fired:

        1. The server must stop listening.
        2. The client connection must disconnect.
        3. The server connection must disconnect.

    This class allows to create a server and a client that will ensure that
    the reactor is left clean by following the pattern described at
    http://mumak.net/stuff/twisted-disconnect.html
    """

    def __init__(self):
        """Create a new instance."""
        self.listener = None
        self.server_factory = None
        self.connector = None
        self.client_factory = None

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        raise NotImplementedError('To be implemented by child classes.')

    def get_client_endpoint(self):
        """Return the client endpoint description."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, server_class, *args, **kwargs):
        """Start listening on the endpoint given by get_server_endpoint.

        ``server_class`` is instantiated with ``*args`` and ``**kwargs`` to
        build the server factory, which is also the value the returned
        deferred fires with.
        """
        from twisted.internet import reactor
        self.server_factory = server_class(*args, **kwargs)
        self.server_factory._disconnecting = False
        # The deferred is created here, before any client has connected;
        # clean_up waits for it whenever a client connector exists.
        self.server_factory.testserver_on_connection_lost = defer.Deferred()
        self.server_factory.protocol = server_protocol_factory(
            self.server_factory.protocol)
        endpoint = endpoints.serverFromString(reactor,
                                              self.get_server_endpoint())
        self.listener = yield endpoint.listen(self.server_factory)
        defer.returnValue(self.server_factory)

    @defer.inlineCallbacks
    def connect_client(self, client_class, *args, **kwargs):
        """Connect a client to the server started by listen_server.

        ``client_class`` is instantiated with ``*args`` and ``**kwargs`` to
        build the client factory, which is also the value the returned
        deferred fires with. Raises ValueError when no server factory exists
        or the server is not listening yet.
        """
        from twisted.internet import reactor

        if self.server_factory is None:
            raise ValueError('Server Factory was not provided.')
        if self.listener is None:
            # interpolate the message; passing the factory as a second
            # argument would leave the '%s' placeholder unformatted
            raise ValueError('%s has not started listening.' %
                             (self.server_factory,))

        self.client_factory = client_class(*args, **kwargs)
        self.client_factory._disconnecting = False
        self.client_factory.protocol = client_protocol_factory(
            self.client_factory.protocol)
        self.client_factory.testserver_on_connection_lost = defer.Deferred()
        endpoint = endpoints.clientFromString(reactor,
                                              self.get_client_endpoint())
        self.connector = yield endpoint.connect(self.client_factory)
        defer.returnValue(self.client_factory)

    def clean_up(self):
        """Action to be performed for clean up.

        Returns a deferred that fires once the listener has stopped and, when
        a client was connected, once both connection-lost deferreds fired.
        """
        if self.server_factory is None or self.listener is None:
            # nothing to clean
            return defer.succeed(None)

        if self.listener and self.connector:
            # clean client and server
            self.server_factory._disconnecting = True
            self.client_factory._disconnecting = True
            self.connector.transport.loseConnection()
            d = defer.maybeDeferred(self.listener.stopListening)
            return defer.gatherResults(
                [d,
                 self.client_factory.testserver_on_connection_lost,
                 self.server_factory.testserver_on_connection_lost])
        if self.listener:
            # just clean the server since there is no client
            self.server_factory._disconnecting = True
            return defer.maybeDeferred(self.listener.stopListening)


class TidyTCPServer(TidySocketServer):
    """A tidy TCP server bound to 127.0.0.1."""

    client_endpoint_pattern = 'tcp:host=127.0.0.1:port=%s'
    server_endpoint_pattern = 'tcp:0:interface=127.0.0.1'

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        return self.server_endpoint_pattern

    def get_client_endpoint(self):
        """Return the client endpoint description.

        The port is read from the listener, so the server must already be
        listening; ValueError is raised otherwise.
        """
        if self.server_factory is None:
            raise ValueError('Server Factory was not provided.')
        if self.listener is None:
            raise ValueError('%s has not started listening.' %
                             (self.server_factory,))
        return self.client_endpoint_pattern % self.listener.getHost().port


class TidyUnixServer(TidySocketServer):
    """A tidy unix domain sockets server."""

    client_endpoint_pattern = 'unix:path=%s'
    server_endpoint_pattern = 'unix:%s'

    def __init__(self):
        """Create a new instance with a socket path in a fresh temp dir."""
        super(TidyUnixServer, self).__init__()
        self.temp_dir = tempfile.mkdtemp()
        self.path = os.path.join(self.temp_dir, 'tidy_unix_server')

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        return self.server_endpoint_pattern % self.path

    def get_client_endpoint(self):
        """Return the client endpoint description."""
        return self.client_endpoint_pattern % self.path

    def clean_up(self):
        """Action to be performed for clean up."""
        result = super(TidyUnixServer, self).clean_up()
        # remove the dir once we are disconnected
        result.addCallback(lambda _: shutil.rmtree(self.temp_dir))
        return result


class ServerTestCase(BaseTestCase):
    """Base test case for tidy servers."""

    @defer.inlineCallbacks
    def setUp(self):
        """Set the diff tests."""
        yield super(ServerTestCase, self).setUp()

        try:
            self.server_runner = self.get_server()
        except NotImplementedError:
            # child classes that do not provide a server run without one
            self.server_runner = None

        self.server_factory = None
        self.client_factory = None
        self.server_disconnected = None
        self.client_connected = None
        self.client_disconnected = None
        self.listener = None
        self.connector = None
        self.addCleanup(self.tear_down_server_client)

    def get_server(self):
        """Return the server to be used to run the tests."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, server_class, *args, **kwargs):
        """Listen a server.

        The method takes the server class and the arguments that should be
        passed to the server constructor.
        """
        self.server_factory = yield self.server_runner.listen_server(
            server_class, *args, **kwargs)
        self.server_disconnected = \
            self.server_factory.testserver_on_connection_lost
        self.listener = self.server_runner.listener

    @defer.inlineCallbacks
    def connect_client(self, client_class, *args, **kwargs):
        """Connect the client.

        The method takes the client factory class and the arguments that
        should be passed to the client constructor.
        """
        self.client_factory = yield self.server_runner.connect_client(
            client_class, *args, **kwargs)
        self.client_disconnected = \
            self.client_factory.testserver_on_connection_lost
        self.connector = self.server_runner.connector

    def tear_down_server_client(self):
        """Clean the server and client."""
        if self.server_runner:
            return self.server_runner.clean_up()


class TCPServerTestCase(ServerTestCase):
    """Test that uses a single twisted server over TCP."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyTCPServer()


class UnixServerTestCase(ServerTestCase):
    """Test that uses a single twisted server over Unix domain sockets."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyUnixServer()


class PbServerTestCase(ServerTestCase):
    """Test a pb server."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, *args, **kwargs):
        """Listen a pb server."""
        yield super(PbServerTestCase, self).listen_server(
            pb.PBServerFactory, *args, **kwargs)

    @defer.inlineCallbacks
    def connect_client(self, *args, **kwargs):
        """Connect a pb client."""
        yield super(PbServerTestCase, self).connect_client(
            pb.PBClientFactory, *args, **kwargs)


class TCPPbServerTestCase(PbServerTestCase):
    """Test a pb server over TCP."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyTCPServer()


class UnixPbServerTestCase(PbServerTestCase):
    """Test a pb server over Unix domain sockets."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyUnixServer()
The idea of the code is that developers do not need to worry about how to stop listening ports in their tests and just write tests like the following:
class TCPMultipleServersTestCase(TestCase):
    """Check that more than one tidy TCP server can run in a single test."""

    timeout = 2

    @defer.inlineCallbacks
    def setUp(self):
        """Create two tidy servers and the objects they will expose."""
        yield super(TCPMultipleServersTestCase, self).setUp()
        self.first_tcp_server = self.get_server()
        self.second_tcp_server = self.get_server()
        self.adder = Adder()
        self.calculator = Calculator(self.adder)
        self.echoer = Echoer()

    def get_server(self):
        """Build a fresh tidy TCP server."""
        return TidyTCPServer()

    @defer.inlineCallbacks
    def test_single_server(self):
        """A single server can be started, used and cleaned up."""
        left, right = 1, 2
        server = self.first_tcp_server

        yield server.listen_server(pb.PBServerFactory, self.calculator)
        self.addCleanup(server.clean_up)

        client_factory = yield server.connect_client(pb.PBClientFactory)
        remote_calculator = yield client_factory.getRootObject()
        remote_adder = yield remote_calculator.callRemote('get_adder')
        total = yield remote_adder.callRemote('add', left, right)

        self.assertEqual(left + right, total)

    @defer.inlineCallbacks
    def test_multiple_server(self):
        """Two servers can be started, used and cleaned up side by side."""
        left, right = 1, 2
        calc_server = self.first_tcp_server
        echo_server = self.second_tcp_server

        # start both servers, registering each clean up as soon as it listens
        yield calc_server.listen_server(pb.PBServerFactory, self.calculator)
        self.addCleanup(calc_server.clean_up)
        yield echo_server.listen_server(pb.PBServerFactory, self.echoer)
        self.addCleanup(echo_server.clean_up)

        # one client per server
        calc_client = yield calc_server.connect_client(pb.PBClientFactory)
        echo_client = yield echo_server.connect_client(pb.PBClientFactory)

        remote_calculator = yield calc_client.getRootObject()
        remote_adder = yield remote_calculator.callRemote('get_adder')
        total = yield remote_adder.callRemote('add', left, right)
        self.assertEqual(left + right, total)

        remote_echoer = yield echo_client.getRootObject()
        reply = yield remote_echoer.callRemote('say', 'hello')
        self.assertEqual(self.echoer.remote_say('hello'), reply)
As you can see, those tests do not give a rat's ass about ensuring that the clients lose their connection or that we stop listening on ports… Or so I thought, because the following code made this approach break on Mac OS X (although I suspect it was also broken on Linux and Windows and we simply never experienced it):
class NullProtocol(protocol.Protocol):
    """Protocol that closes its transport as soon as it is connected."""

    def connectionMade(self):
        """Drop the connection right away."""
        self.transport.loseConnection()


class PortDetectFactory(protocol.ClientFactory):
    """Client factory that reports whether something listens on a port."""

    protocol = NullProtocol

    def __init__(self):
        """Create the deferred that will carry the detection result."""
        self.d = defer.Deferred()

    def is_listening(self):
        """Return the deferred that fires with True if something listens."""
        return self.d

    def _report(self, listening):
        """Fire the result deferred with ``listening`` unless already fired."""
        if self.d.called:
            return
        self.d.callback(listening)

    def buildProtocol(self, addr):
        """Build the protocol and report that the port is in use."""
        built = protocol.ClientFactory.buildProtocol(self, addr)
        self._report(True)
        return built

    def clientConnectionLost(self, connector, reason):
        """Report False if no result was delivered before the loss."""
        self._report(False)

    def clientConnectionFailed(self, connector, reason):
        """Report False because the connection could not be made."""
        self._report(False)
The code used to test the above was written as:
@defer.inlineCallbacks
def test_is_already_running(self):
    """The is_already_running method returns True if already started."""
    server = self.get_server()
    self.addCleanup(server.clean_up)
    test_case = self

    class TestConnect(object):
        """Endpoint stand-in that connects through the tidy server."""

        @defer.inlineCallbacks
        def connect(my_self, factory):
            """Connect a PortDetectFactory and expose its result."""
            detector = yield server.connect_client(PortDetectFactory)
            # route the caller's factory to the detector that really connected
            test_case.patch(factory, 'is_listening',
                            lambda: detector.is_listening())
            defer.returnValue(detector)

    def fake_client_from_string(*args):
        """Ignore the endpoint description and return the stand-in."""
        return TestConnect()

    self.patch(tcpactivation, 'clientFromString', fake_client_from_string)

    yield server.listen_server(protocol.ServerFactory)

    # pylint: disable=E1101
    ad = ActivationDetector(self.config)
    already_running = yield ad.is_already_running()
    self.assertTrue(already_running, "It should be already running.")
While on all the other platforms the tests passed with no problems, on Mac OS X the tests would block in the clean_up method of the server because the deferred that is fired from connectionLost in the ServerTidyProtocol was never fired… Interesting… After digging into the code I realized that the approach taken by the clean_up code was wrong. The problem lies in the way in which the NullProtocol works. As you can see in the code, the protocol drops its connection as soon as it is made. This results in two possible outcomes:
- The server does know that we have a client connected and calls buildProtocol.
- The connection is lost so fast that buildProtocol on the ServerFactory does not get called.
When running the tests on Windows and Linux we were always facing the first scenario: buildProtocol was called, which meant that connectionLost in the server protocol would be called. On the other hand, on Mac OS X, 1 out of 10 runs of the tests would block in the clean up because we would be in the second scenario, that is, no protocol would be built by the ServerFactory, which results in connectionLost never being called because it was not needed. The workaround for this issue is quite simple once you understand what is going on. The ServerFactory has to be modified to create the deferred when buildProtocol is called and not before, so that when we clean up we check whether the deferred is None and, if it is not, we wait for it to be fired. The fixed version of the helper code is the following:
import os
import shutil
import tempfile

from twisted.internet import defer, endpoints, protocol
from twisted.spread import pb

from ubuntuone.devtools.testcases import BaseTestCase

# no init method + twisted common warnings
# pylint: disable=W0232, C0103, E1101


def server_protocol_factory(cls):
    """Return a subclass of ``cls`` that reports when its connection is lost.

    ``cls`` is the protocol class to wrap; ``protocol.Protocol`` is used when
    it is None. The returned class fires the factory's
    ``testserver_on_connection_lost`` deferred from ``connectionLost``, but
    only once the factory has been flagged as ``_disconnecting``.
    """
    if cls is None:
        cls = protocol.Protocol

    class ServerTidyProtocol(cls):
        """A tidy protocol."""

        def connectionLost(self, *args):
            """Lost the connection."""
            cls.connectionLost(self, *args)
            # lets tell everyone
            # pylint: disable=W0212
            if (self.factory._disconnecting
                    and self.factory.testserver_on_connection_lost is not None
                    and not self.factory.testserver_on_connection_lost.called):
                self.factory.testserver_on_connection_lost.callback(self)
            # pylint: enable=W0212

    return ServerTidyProtocol


def server_factory_factory(cls):
    """Return a subclass of ``cls`` that tracks server side disconnection.

    ``cls`` is the server factory class to wrap; ``protocol.ServerFactory``
    is used when it is None.
    """
    if cls is None:
        cls = protocol.ServerFactory

    class TidyServerFactory(cls):
        """A tidy factory."""

        # Stays None until buildProtocol runs, so clean up can tell whether
        # the server ever built a protocol for a connection.
        testserver_on_connection_lost = None

        def buildProtocol(self, addr):
            """Build the protocol and create the connection-lost deferred."""
            prot = cls.buildProtocol(self, addr)
            self.testserver_on_connection_lost = defer.Deferred()
            return prot

    return TidyServerFactory


def client_protocol_factory(cls):
    """Return a subclass of ``cls`` that reports when its connection is lost.

    ``cls`` is the protocol class to wrap; ``protocol.Protocol`` is used when
    it is None. The factory's ``testserver_on_connection_lost`` deferred is
    fired after the wrapped ``connectionLost`` has run.
    """
    if cls is None:
        cls = protocol.Protocol

    class ClientTidyProtocol(cls):
        """A tidy protocol."""

        def connectionLost(self, *a):
            """Connection lost."""
            cls.connectionLost(self, *a)
            # pylint: disable=W0212
            if (self.factory._disconnecting
                    and self.factory.testserver_on_connection_lost is not None
                    and not self.factory.testserver_on_connection_lost.called):
                self.factory.testserver_on_connection_lost.callback(self)
            # pylint: enable=W0212

    return ClientTidyProtocol


class TidySocketServer(object):
    """Ensure that twisted servers are correctly managed in tests.

    Closing a twisted server is a complicated matter. In order to do so you
    have to ensure that three different deferreds are fired:

        1. The server must stop listening.
        2. The client connection must disconnect.
        3. The server connection must disconnect.

    This class allows to create a server and a client that will ensure that
    the reactor is left clean by following the pattern described at
    http://mumak.net/stuff/twisted-disconnect.html
    """

    def __init__(self):
        """Create a new instance."""
        self.listener = None
        self.server_factory = None
        self.connector = None
        self.client_factory = None

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        raise NotImplementedError('To be implemented by child classes.')

    def get_client_endpoint(self):
        """Return the client endpoint description."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, server_class, *args, **kwargs):
        """Start listening on the endpoint given by get_server_endpoint.

        ``server_class`` is wrapped with server_factory_factory and then
        instantiated with ``*args`` and ``**kwargs``; the resulting factory
        is also the value the returned deferred fires with.
        """
        from twisted.internet import reactor
        tidy_class = server_factory_factory(server_class)
        self.server_factory = tidy_class(*args, **kwargs)
        self.server_factory._disconnecting = False
        self.server_factory.protocol = server_protocol_factory(
            self.server_factory.protocol)
        endpoint = endpoints.serverFromString(reactor,
                                              self.get_server_endpoint())
        self.listener = yield endpoint.listen(self.server_factory)
        defer.returnValue(self.server_factory)

    @defer.inlineCallbacks
    def connect_client(self, client_class, *args, **kwargs):
        """Connect a client to the server started by listen_server.

        ``client_class`` is instantiated with ``*args`` and ``**kwargs`` to
        build the client factory, which is also the value the returned
        deferred fires with. Raises ValueError when no server factory exists
        or the server is not listening yet.
        """
        from twisted.internet import reactor

        if self.server_factory is None:
            raise ValueError('Server Factory was not provided.')
        if self.listener is None:
            # interpolate the message; passing the factory as a second
            # argument would leave the '%s' placeholder unformatted
            raise ValueError('%s has not started listening.' %
                             (self.server_factory,))

        self.client_factory = client_class(*args, **kwargs)
        self.client_factory._disconnecting = False
        self.client_factory.protocol = client_protocol_factory(
            self.client_factory.protocol)
        self.client_factory.testserver_on_connection_lost = defer.Deferred()
        endpoint = endpoints.clientFromString(reactor,
                                              self.get_client_endpoint())
        self.connector = yield endpoint.connect(self.client_factory)
        defer.returnValue(self.client_factory)

    def clean_up(self):
        """Action to be performed for clean up.

        Always returns a deferred. It fires once the listener has stopped
        and, when a client was connected, once the client connection (and the
        server connection, if the server ever built a protocol) was lost.
        """
        if self.server_factory is None or self.listener is None:
            # nothing to clean
            return defer.succeed(None)

        # pylint: disable=W0201
        self.server_factory._disconnecting = True
        # pylint: enable=W0201

        if not self.connector:
            # just clean the server since there is no client
            return defer.maybeDeferred(self.listener.stopListening)

        # clean client and server
        self.client_factory._disconnecting = True
        stopped = defer.maybeDeferred(self.listener.stopListening)
        self.connector.transport.loseConnection()

        pending = [stopped,
                   self.client_factory.testserver_on_connection_lost]
        server_lost = self.server_factory.testserver_on_connection_lost
        if server_lost is not None:
            # only wait for the server side when buildProtocol was called,
            # otherwise connectionLost will never be called on the server
            pending.append(server_lost)
        return defer.gatherResults(pending)


class TidyTCPServer(TidySocketServer):
    """A tidy TCP server bound to 127.0.0.1."""

    client_endpoint_pattern = 'tcp:host=127.0.0.1:port=%s'
    server_endpoint_pattern = 'tcp:0:interface=127.0.0.1'

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        return self.server_endpoint_pattern

    def get_client_endpoint(self):
        """Return the client endpoint description.

        The port is read from the listener, so the server must already be
        listening; ValueError is raised otherwise.
        """
        if self.server_factory is None:
            raise ValueError('Server Factory was not provided.')
        if self.listener is None:
            raise ValueError('%s has not started listening.' %
                             (self.server_factory,))
        return self.client_endpoint_pattern % self.listener.getHost().port


class TidyUnixServer(TidySocketServer):
    """A tidy unix domain sockets server."""

    client_endpoint_pattern = 'unix:path=%s'
    server_endpoint_pattern = 'unix:%s'

    def __init__(self):
        """Create a new instance with a socket path in a fresh temp dir."""
        super(TidyUnixServer, self).__init__()
        self.temp_dir = tempfile.mkdtemp()
        self.path = os.path.join(self.temp_dir, 'tidy_unix_server')

    def get_server_endpoint(self):
        """Return the server endpoint description."""
        return self.server_endpoint_pattern % self.path

    def get_client_endpoint(self):
        """Return the client endpoint description."""
        return self.client_endpoint_pattern % self.path

    def clean_up(self):
        """Action to be performed for clean up."""
        result = super(TidyUnixServer, self).clean_up()
        # remove the dir once we are disconnected
        result.addCallback(lambda _: shutil.rmtree(self.temp_dir))
        return result


class ServerTestCase(BaseTestCase):
    """Base test case for tidy servers."""

    @defer.inlineCallbacks
    def setUp(self):
        """Set the diff tests."""
        yield super(ServerTestCase, self).setUp()

        try:
            self.server_runner = self.get_server()
        except NotImplementedError:
            # child classes that do not provide a server run without one
            self.server_runner = None

        self.server_factory = None
        self.client_factory = None
        self.server_disconnected = None
        self.client_connected = None
        self.client_disconnected = None
        self.listener = None
        self.connector = None
        self.addCleanup(self.tear_down_server_client)

    def get_server(self):
        """Return the server to be used to run the tests."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, server_class, *args, **kwargs):
        """Listen a server.

        The method takes the server class and the arguments that should be
        passed to the server constructor.
        """
        self.server_factory = yield self.server_runner.listen_server(
            server_class, *args, **kwargs)
        # NOTE: the tidy server factory only creates this deferred in
        # buildProtocol, so the value captured here is None until a client
        # has connected.
        self.server_disconnected = \
            self.server_factory.testserver_on_connection_lost
        self.listener = self.server_runner.listener

    @defer.inlineCallbacks
    def connect_client(self, client_class, *args, **kwargs):
        """Connect the client.

        The method takes the client factory class and the arguments that
        should be passed to the client constructor.
        """
        self.client_factory = yield self.server_runner.connect_client(
            client_class, *args, **kwargs)
        self.client_disconnected = \
            self.client_factory.testserver_on_connection_lost
        self.connector = self.server_runner.connector

    def tear_down_server_client(self):
        """Clean the server and client."""
        if self.server_runner:
            return self.server_runner.clean_up()


class TCPServerTestCase(ServerTestCase):
    """Test that uses a single twisted server over TCP."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyTCPServer()


class UnixServerTestCase(ServerTestCase):
    """Test that uses a single twisted server over Unix domain sockets."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyUnixServer()


class PbServerTestCase(ServerTestCase):
    """Test a pb server."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        raise NotImplementedError('To be implemented by child classes.')

    @defer.inlineCallbacks
    def listen_server(self, *args, **kwargs):
        """Listen a pb server."""
        yield super(PbServerTestCase, self).listen_server(
            pb.PBServerFactory, *args, **kwargs)

    @defer.inlineCallbacks
    def connect_client(self, *args, **kwargs):
        """Connect a pb client."""
        yield super(PbServerTestCase, self).connect_client(
            pb.PBClientFactory, *args, **kwargs)


class TCPPbServerTestCase(PbServerTestCase):
    """Test a pb server over TCP."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyTCPServer()


class UnixPbServerTestCase(PbServerTestCase):
    """Test a pb server over Unix domain sockets."""

    def get_server(self):
        """Return the server to be used to run the tests."""
        return TidyUnixServer()
I wonder if at some point I should share this code for the people out there… any opinions?




