Coverage for tests/comm/test_server.py: 100%
171 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-07-10 13:43 +0000
1# The MIT License (MIT)
2#
3# Copyright (c) 2021 RSK Labs Ltd
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy of
6# this software and associated documentation files (the "Software"), to deal in
7# the Software without restriction, including without limitation the rights to
8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
9# of the Software, and to permit persons to whom the Software is furnished to do
10# so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23from unittest import TestCase
24from unittest.mock import Mock, call, ANY, patch
25from comm.server import (
26 TCPServer,
27 TCPServerError,
28 _RequestHandler,
29 RequestHandlerError,
30 RequestHandlerShutdown,
31)
32from comm.protocol import HSM2ProtocolError, HSM2ProtocolInterrupt
33import socketserver
34import socket
36import logging
# Suppress every log record of severity CRITICAL and below for this module's
# tests (presumably to keep test output free of the server's log noise).
logging.disable(level=logging.CRITICAL)
class TestTCPServer(TestCase):
    """Tests for ``TCPServer`` construction and its ``run()`` lifecycle.

    Every ``run()`` test patches ``socketserver.TCPServer`` so no real socket
    is bound; the patched class's ``return_value`` stands in for the
    underlying server instance.
    """

    def setUp(self):
        self.protocol = Mock()
        self.server = TCPServer("a-host", 1234, self.protocol)

    def test_init_ok(self):
        self.assertEqual(self.server.host, "a-host")
        self.assertEqual(self.server.port, 1234)
        self.assertEqual(self.server.protocol, self.protocol)

    @patch("socketserver.TCPServer")
    def test_run_ok(self, TCPServerMock):
        TCPServerMock.return_value = Mock()

        self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_interrupt(self, TCPServerMock):
        # A KeyboardInterrupt from the serve loop must not escape run(),
        # and the server must still be closed.
        TCPServerMock.return_value = Mock()
        TCPServerMock.return_value.serve_forever = Mock(side_effect=KeyboardInterrupt)

        self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_socket_error(self, TCPServerMock):
        # A socket error from the serve loop surfaces as TCPServerError,
        # and the server is still closed.
        TCPServerMock.return_value = Mock()
        TCPServerMock.return_value.serve_forever = Mock(side_effect=socket.error)

        with self.assertRaises(TCPServerError):
            self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_initialize_device_not_implemented(self, TCPServerMock):
        # If the protocol cannot initialize the device, run() fails with
        # TCPServerError and no server instance is kept.
        TCPServerMock.return_value = Mock()
        self.protocol.initialize_device.side_effect = NotImplementedError()

        with self.assertRaises(TCPServerError):
            self.server.run()

        # call_args_list (not call_args, which only holds the last call) so
        # that "exactly one call, with no arguments" is what gets checked.
        self.assertEqual(self.protocol.initialize_device.call_args_list, [call()])
        self.assertIsNone(self.server.server)

    def assert_server_setup_ok(self, TCPServerMock):
        """Assert run() wired up the underlying socketserver correctly.

        Checks that ``socketserver.TCPServer`` was built exactly once with
        ``("a-host", 1234)`` as its address (the handler class argument is
        not checked here), that the instance is kept on
        ``self.server.server`` with the protocol attached, and that the
        device was initialized exactly once with no arguments.
        """
        self.assertEqual(TCPServerMock.call_args_list, [call(("a-host", 1234), ANY)])
        self.assertEqual(self.server.server, TCPServerMock.return_value)
        self.assertEqual(self.server.server.protocol, self.protocol)
        self.assertEqual(self.protocol.initialize_device.call_args_list, [call()])
@patch("socketserver.TCPServer")
@patch("comm.server._RequestHandler")
class TestTCPServerRequestHandler(TestCase):
    """Tests for the request-handler class that ``TCPServer.run()`` hands to
    ``socketserver.TCPServer``.

    Both patches apply to every ``test_*`` method. The innermost decorator's
    mock is passed first, hence the ``(RequestHandlerMock, TCPServerMock)``
    argument order.
    """

    def prepare(self, RequestHandlerMock, TCPServerMock):
        """Run a ``TCPServer`` against the mocks and capture its handler class."""
        self.protocol = Mock()
        self.server = TCPServer("a-host", 1234, self.protocol)
        TCPServerMock.return_value = Mock()
        self.server.run()
        # The recorded constructor call is ((host, port), handler_class);
        # keep the second positional argument.
        ctor_args, _ = TCPServerMock.call_args
        self.handler_klass = ctor_args[1]
        # socketserver.StreamRequestHandler needs a request object and a
        # client address to be constructed.
        self.request = Mock()
        self.client_address = [Mock()]

    def handle(self, RequestHandlerMock, TCPServerMock):
        """Instantiate the captured handler class.

        socketserver request handlers service their request from within the
        constructor, so this performs one full request cycle.
        """
        fake_server = TCPServerMock.return_value
        self.handler = self.handler_klass(
            self.request,
            self.client_address,
            fake_server,
        )

    def _shutdown_after(self, RequestHandlerMock, TCPServerMock, error):
        """Serve one request whose inner handler raises *error*.

        Returns whether ``shutdown`` was called on the (mocked) server.
        """
        self.prepare(RequestHandlerMock, TCPServerMock)
        RequestHandlerMock.return_value.handle.side_effect = error
        self.handle(RequestHandlerMock, TCPServerMock)
        return self.handler.server.shutdown.called

    def test_handles_ok(self, RequestHandlerMock, TCPServerMock):
        self.prepare(RequestHandlerMock, TCPServerMock)
        self.handle(RequestHandlerMock, TCPServerMock)

        # Exactly one _RequestHandler is built, from the server's protocol and
        # logger, and it is asked to handle this connection's streams.
        self.assertEqual(
            RequestHandlerMock.call_args_list,
            [call(self.protocol, self.server.logger)],
        )
        inner_handle = RequestHandlerMock.return_value.handle
        expected_streams = call(
            self.client_address[0], self.handler.rfile, self.handler.wfile
        )
        self.assertEqual(inner_handle.call_args_list, [expected_streams])

    def test_handler_correct_subclass(self, RequestHandlerMock, TCPServerMock):
        self.prepare(RequestHandlerMock, TCPServerMock)
        self.handle(RequestHandlerMock, TCPServerMock)

        self.assertIsInstance(self.handler, socketserver.StreamRequestHandler)

    def test_handle_request_handler_error_shutsdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertTrue(self._shutdown_after(
            RequestHandlerMock, TCPServerMock, RequestHandlerError()))

    def test_handle_request_handler_shutdown_shutsdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertTrue(self._shutdown_after(
            RequestHandlerMock, TCPServerMock, RequestHandlerShutdown()))

    def test_handle_request_handler_connection_error_doesnotshutdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertFalse(self._shutdown_after(
            RequestHandlerMock, TCPServerMock, ConnectionError()))

    def test_handle_request_handler_other_error_doesnotshutdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertFalse(self._shutdown_after(
            RequestHandlerMock, TCPServerMock, Exception()))
class TestRequestHandler(TestCase):
    """Tests for ``_RequestHandler.handle``: one request line in, one reply out.

    ``rfile`` and ``wfile`` are plain mocks. The request is whatever
    ``rfile.readline()`` returns, and the reply is checked as the exact
    sequence of ``wfile.write`` calls.
    """

    def setUp(self):
        self.protocol, self.logger = Mock(), Mock()
        self.handler = _RequestHandler(self.protocol, self.logger)
        self.rfile, self.wfile = Mock(), Mock()

    # -- request/response helpers -----------------------------------------

    def mock_request(self, line):
        """Make ``rfile.readline()`` return *line* encoded as UTF-8."""
        self._feed_raw(line.encode("utf-8"))

    def _feed_raw(self, raw):
        """Make ``rfile.readline()`` return the bytes *raw* unchanged."""
        self.rfile.readline.return_value = raw

    def _feed_valid_request(self):
        self.mock_request('{"another": "valid", "json": "request"}')

    def do_request(self):
        """Have the handler service one request from ``rfile`` into ``wfile``."""
        self.handler.handle("an-address", self.rfile, self.wfile)

    def _assert_dispatched(self, payload):
        """Assert the protocol received exactly one request, equal to *payload*."""
        self.assertEqual(
            self.protocol.handle_request.call_args_list,
            [call(payload)],
        )

    def _assert_valid_request_dispatched(self):
        self._assert_dispatched({"another": "valid", "json": "request"})

    def _assert_written(self, *chunks):
        """Assert ``wfile.write`` got one call per chunk, in order, each UTF-8 encoded."""
        expected = [call(chunk.encode("utf-8")) for chunk in chunks]
        self.assertEqual(self.wfile.write.call_args_list, expected)

    # -- tests -------------------------------------------------------------

    def test_logger_protocol(self):
        self.assertEqual(self.protocol, self.handler.protocol)
        self.assertEqual(self.logger, self.handler.logger)

    def test_handler_correct_encoding(self):
        self.assertEqual(self.handler.ENCODING, "utf-8")

    def test_handle_ok(self):
        self.mock_request('{"this-is": "legal", "json": "format"}')
        self.protocol.handle_request.return_value = {"this": "is", "the": "result"}

        self.do_request()

        self._assert_dispatched({"this-is": "legal", "json": "format"})
        # Reply keys are expected in sorted order, then a newline terminator.
        self._assert_written('{"the": "result", "this": "is"}', "\n")

    def test_handle_broken_pipe_reply(self):
        self.mock_request('{"this-is": "legal", "json": "format"}')
        self.protocol.handle_request.return_value = {"this": "is", "the": "result"}
        self.wfile.write.side_effect = BrokenPipeError()

        # The first write fails: the error must not escape, and the newline
        # write is never attempted.
        self.do_request()

        self._assert_dispatched({"this-is": "legal", "json": "format"})
        self._assert_written('{"the": "result", "this": "is"}')

    def test_handle_json_error(self):
        self.mock_request("this-is-not-json")
        self.protocol.format_error.return_value = {"format": "error", "a": "bad"}

        self.do_request()

        # Unparseable input never reaches the protocol; the format_error()
        # payload is written back instead.
        self.assertFalse(self.protocol.handle_request.called)
        self._assert_written('{"a": "bad", "format": "error"}', "\n")

    def test_handle_notimplemented_error(self):
        self._feed_valid_request()
        self.protocol.handle_request.side_effect = NotImplementedError(
            "method not implemented")

        self.do_request()

        self._assert_valid_request_dispatched()
        self._assert_written("{}", "\n")

    def test_handle_protocol_error(self):
        self._feed_valid_request()
        self.protocol.unknown_error.return_value = {"an": "unknown", "e": "rror"}
        self.protocol.handle_request.side_effect = HSM2ProtocolError("protocol error")

        with self.assertRaises(RequestHandlerError):
            self.do_request()

        self._assert_valid_request_dispatched()
        self._assert_written('{"an": "unknown", "e": "rror"}', "\n")

    def test_handle_protocol_shutdown(self):
        self._feed_valid_request()
        self.protocol.handle_request.side_effect = HSM2ProtocolInterrupt(
            "protocol interrupt")

        with self.assertRaises(RequestHandlerShutdown):
            self.do_request()

        self._assert_valid_request_dispatched()
        self._assert_written("{}", "\n")

    def test_handle_unknown_exception(self):
        self._feed_valid_request()
        self.protocol.handle_request.side_effect = ValueError("unexpected")

        with self.assertRaises(RequestHandlerError):
            self.do_request()

        self._assert_valid_request_dispatched()
        self._assert_written("{}", "\n")

    def test_handle_invalid_encoding(self):
        # Bytes that are not valid UTF-8.
        self._feed_raw(b"\xff\xfa\xf0")
        self.protocol.format_error.return_value = {"encoding": "error", "a": "bad"}

        self.do_request()

        self.assertFalse(self.protocol.handle_request.called)
        self._assert_written('{"a": "bad", "encoding": "error"}', "\n")

    def test_handle_invalid_encoding_broken_pipe(self):
        self._feed_raw(b"\xff\xfa\xf0")
        self.protocol.format_error.return_value = {"encoding": "error", "a": "bad"}
        self.wfile.write.side_effect = BrokenPipeError()

        self.do_request()

        self.assertFalse(self.protocol.handle_request.called)
        self._assert_written('{"a": "bad", "encoding": "error"}')