Coverage for tests/comm/test_server.py: 100%

171 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-07-10 13:43 +0000

1# The MIT License (MIT) 

2# 

3# Copyright (c) 2021 RSK Labs Ltd 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy of 

6# this software and associated documentation files (the "Software"), to deal in 

7# the Software without restriction, including without limitation the rights to 

8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 

9# of the Software, and to permit persons to whom the Software is furnished to do 

10# so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23from unittest import TestCase 

24from unittest.mock import Mock, call, ANY, patch 

25from comm.server import ( 

26 TCPServer, 

27 TCPServerError, 

28 _RequestHandler, 

29 RequestHandlerError, 

30 RequestHandlerShutdown, 

31) 

32from comm.protocol import HSM2ProtocolError, HSM2ProtocolInterrupt 

33import socketserver 

34import socket 

35 

36import logging 

37 

# Keep test output clean: suppress every log record at CRITICAL and below
# for the whole test module.
logging.disable(level=logging.CRITICAL)

39 

40 

class TestTCPServer(TestCase):
    """Tests for ``TCPServer`` construction and its ``run`` lifecycle.

    The ``run`` tests patch ``socketserver.TCPServer``, so no real socket is
    bound; assertions inspect the mock that ``run`` stored on
    ``self.server.server``.
    """

    def setUp(self):
        self.protocol = Mock()
        self.server = TCPServer("a-host", 1234, self.protocol)

    def test_init_ok(self):
        self.assertEqual(self.server.host, "a-host")
        self.assertEqual(self.server.port, 1234)
        self.assertEqual(self.server.protocol, self.protocol)

    @patch("socketserver.TCPServer")
    def test_run_ok(self, TCPServerMock):
        TCPServerMock.return_value = Mock()

        self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_interrupt(self, TCPServerMock):
        # A KeyboardInterrupt raised by serve_forever is expected not to
        # escape run(), and the socket server must still be closed.
        TCPServerMock.return_value = Mock()
        TCPServerMock.return_value.serve_forever = Mock(side_effect=KeyboardInterrupt)

        self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_socket_error(self, TCPServerMock):
        # A socket error from serve_forever is surfaced as TCPServerError,
        # and the socket server is still closed.
        TCPServerMock.return_value = Mock()
        TCPServerMock.return_value.serve_forever = Mock(side_effect=socket.error)

        with self.assertRaises(TCPServerError):
            self.server.run()

        self.assert_server_setup_ok(TCPServerMock)

        self.assertEqual(self.server.server.serve_forever.call_count, 1)
        self.assertEqual(self.server.server.server_close.call_count, 1)

    @patch("socketserver.TCPServer")
    def test_run_initialize_device_not_implemented(self, TCPServerMock):
        # If the protocol cannot initialize the device, run() fails with
        # TCPServerError and no socket server is kept on the instance.
        TCPServerMock.return_value = Mock()
        self.protocol.initialize_device.side_effect = NotImplementedError()

        with self.assertRaises(TCPServerError):
            self.server.run()

        # Compare the full call history (not just the last call) so a
        # missing or repeated initialization is detected.
        self.assertEqual(self.protocol.initialize_device.call_args_list, [call()])
        self.assertIsNone(self.server.server)

    def assert_server_setup_ok(self, TCPServerMock):
        """Assert that ``run`` set up the socket server as expected.

        Checks that ``socketserver.TCPServer`` was constructed once with
        ``("a-host", 1234)`` and some handler class, that the resulting server
        is stored and carries the protocol, and that the device was
        initialized exactly once with no arguments.
        """
        self.assertEqual(TCPServerMock.call_args_list, [call(("a-host", 1234), ANY)])
        self.assertEqual(self.server.server, TCPServerMock.return_value)
        self.assertEqual(self.server.server.protocol, self.protocol)
        # call_args_list, consistent with the TCPServerMock check above;
        # comparing call_args (a single call) to a list only matched by
        # accident of mock's call-equality rules.
        self.assertEqual(self.protocol.initialize_device.call_args_list, [call()])

103 

104 

@patch("socketserver.TCPServer")
@patch("comm.server._RequestHandler")
class TestTCPServerRequestHandler(TestCase):
    """Tests for the socketserver handler class installed by ``TCPServer.run``.

    The class-level patches replace ``socketserver.TCPServer`` and
    ``comm.server._RequestHandler``. Every ``test_*`` method therefore
    receives ``(RequestHandlerMock, TCPServerMock)`` and never opens a real
    socket.
    """

    def prepare(self, RequestHandlerMock, TCPServerMock):
        """Run a TCPServer against the mocks and capture its handler class."""
        self.protocol = Mock()
        self.server = TCPServer("a-host", 1234, self.protocol)
        TCPServerMock.return_value = Mock()
        self.server.run()
        # The handler class is the second positional argument that run()
        # passed to the (mocked) socketserver.TCPServer constructor.
        constructor_args = TCPServerMock.call_args[0]
        self.handler_klass = constructor_args[1]
        # Both are required by socketserver.StreamRequestHandler.
        self.request = Mock()
        self.client_address = [Mock()]

    def handle(self, RequestHandlerMock, TCPServerMock):
        """Instantiate the captured handler class.

        The StreamRequestHandler constructor drives the request handling.
        """
        mocked_server = TCPServerMock.return_value
        self.handler = self.handler_klass(
            self.request, self.client_address, mocked_server
        )

    def _shutdown_requested_after(self, error, RequestHandlerMock, TCPServerMock):
        """Handle one request whose inner handler raises *error*.

        Returns whether the socket server's ``shutdown`` was called.
        """
        self.prepare(RequestHandlerMock, TCPServerMock)
        RequestHandlerMock.return_value.handle.side_effect = error
        self.handle(RequestHandlerMock, TCPServerMock)
        return self.handler.server.shutdown.called

    def test_handles_ok(self, RequestHandlerMock, TCPServerMock):
        self.prepare(RequestHandlerMock, TCPServerMock)
        self.handle(RequestHandlerMock, TCPServerMock)

        expected_construction = [call(self.protocol, self.server.logger)]
        self.assertEqual(RequestHandlerMock.call_args_list, expected_construction)

        expected_handling = [
            call(self.client_address[0], self.handler.rfile, self.handler.wfile)
        ]
        inner_handle = RequestHandlerMock.return_value.handle
        self.assertEqual(inner_handle.call_args_list, expected_handling)

    def test_handler_correct_subclass(self, RequestHandlerMock, TCPServerMock):
        self.prepare(RequestHandlerMock, TCPServerMock)
        self.handle(RequestHandlerMock, TCPServerMock)

        self.assertIsInstance(self.handler, socketserver.StreamRequestHandler)

    def test_handle_request_handler_error_shutsdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertTrue(self._shutdown_requested_after(
            RequestHandlerError(), RequestHandlerMock, TCPServerMock))

    def test_handle_request_handler_shutdown_shutsdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertTrue(self._shutdown_requested_after(
            RequestHandlerShutdown(), RequestHandlerMock, TCPServerMock))

    def test_handle_request_handler_connection_error_doesnotshutdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertFalse(self._shutdown_requested_after(
            ConnectionError(), RequestHandlerMock, TCPServerMock))

    def test_handle_request_handler_other_error_doesnotshutdown(
            self, RequestHandlerMock, TCPServerMock):
        self.assertFalse(self._shutdown_requested_after(
            Exception(), RequestHandlerMock, TCPServerMock))

172 

173 

class TestRequestHandler(TestCase):
    """Tests for ``_RequestHandler.handle`` reading one line and writing a reply.

    ``rfile`` and ``wfile`` are mocks: the request line is injected through
    ``rfile.readline`` and the reply is checked through ``wfile.write`` calls.
    """

    # Request line shared by the error-path tests, and its decoded form.
    VALID_REQUEST_LINE = '{"another": "valid", "json": "request"}'
    VALID_REQUEST = {"another": "valid", "json": "request"}

    def setUp(self):
        self.protocol = Mock()
        self.logger = Mock()
        self.handler = _RequestHandler(self.protocol, self.logger)
        self.rfile = Mock()
        self.wfile = Mock()

    def test_logger_protocol(self):
        self.assertEqual(self.protocol, self.handler.protocol)
        self.assertEqual(self.logger, self.handler.logger)

    def test_handler_correct_encoding(self):
        self.assertEqual(self.handler.ENCODING, "utf-8")

    def test_handle_ok(self):
        self.mock_request('{"this-is": "legal", "json": "format"}')
        self.protocol.handle_request.return_value = {"this": "is", "the": "result"}

        self.do_request()

        self.assert_handled_once_with({"this-is": "legal", "json": "format"})
        self.assert_written('{"the": "result", "this": "is"}', "\n")

    def test_handle_broken_pipe_reply(self):
        self.mock_request('{"this-is": "legal", "json": "format"}')
        self.protocol.handle_request.return_value = {"this": "is", "the": "result"}
        self.wfile.write.side_effect = BrokenPipeError()

        self.do_request()

        self.assert_handled_once_with({"this-is": "legal", "json": "format"})
        # The first write fails, so the trailing newline is never attempted.
        self.assert_written('{"the": "result", "this": "is"}')

    def test_handle_json_error(self):
        self.mock_request("this-is-not-json")
        self.protocol.format_error.return_value = {"format": "error", "a": "bad"}

        self.do_request()

        self.assertFalse(self.protocol.handle_request.called)
        self.assert_written('{"a": "bad", "format": "error"}', "\n")

    def test_handle_notimplemented_error(self):
        self.mock_request(self.VALID_REQUEST_LINE)
        self.protocol.handle_request.side_effect = NotImplementedError(
            "method not implemented")

        self.do_request()

        self.assert_handled_once_with(self.VALID_REQUEST)
        self.assert_written("{}", "\n")

    def test_handle_protocol_error(self):
        self.mock_request(self.VALID_REQUEST_LINE)
        self.protocol.unknown_error.return_value = {"an": "unknown", "e": "rror"}
        self.protocol.handle_request.side_effect = HSM2ProtocolError("protocol error")

        with self.assertRaises(RequestHandlerError):
            self.do_request()

        self.assert_handled_once_with(self.VALID_REQUEST)
        self.assert_written('{"an": "unknown", "e": "rror"}', "\n")

    def test_handle_protocol_shutdown(self):
        self.mock_request(self.VALID_REQUEST_LINE)
        self.protocol.handle_request.side_effect = HSM2ProtocolInterrupt(
            "protocol interrupt")

        with self.assertRaises(RequestHandlerShutdown):
            self.do_request()

        self.assert_handled_once_with(self.VALID_REQUEST)
        self.assert_written("{}", "\n")

    def test_handle_unknown_exception(self):
        self.mock_request(self.VALID_REQUEST_LINE)
        self.protocol.handle_request.side_effect = ValueError("unexpected")

        with self.assertRaises(RequestHandlerError):
            self.do_request()

        self.assert_handled_once_with(self.VALID_REQUEST)
        self.assert_written("{}", "\n")

    def test_handle_invalid_encoding(self):
        # Bytes that are not valid UTF-8, injected without mock_request.
        self.rfile.readline.return_value = b"\xff\xfa\xf0"
        self.protocol.format_error.return_value = {"encoding": "error", "a": "bad"}

        self.do_request()

        self.assertFalse(self.protocol.handle_request.called)
        self.assert_written('{"a": "bad", "encoding": "error"}', "\n")

    def test_handle_invalid_encoding_broken_pipe(self):
        self.rfile.readline.return_value = b"\xff\xfa\xf0"
        self.protocol.format_error.return_value = {"encoding": "error", "a": "bad"}
        self.wfile.write.side_effect = BrokenPipeError()

        self.do_request()

        self.assertFalse(self.protocol.handle_request.called)
        self.assert_written('{"a": "bad", "encoding": "error"}')

    def assert_handled_once_with(self, request):
        """Assert the protocol received exactly one request equal to *request*."""
        self.assertEqual(self.protocol.handle_request.call_args_list, [call(request)])

    def assert_written(self, *chunks):
        """Assert ``wfile.write`` got exactly *chunks*, UTF-8 encoded, in order."""
        expected = [call(chunk.encode("utf-8")) for chunk in chunks]
        self.assertEqual(self.wfile.write.call_args_list, expected)

    def mock_request(self, line):
        self.rfile.readline.return_value = line.encode("utf-8")

    def do_request(self):
        self.handler.handle("an-address", self.rfile, self.wfile)