Coverage for comm/server.py: 95%

120 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-07-10 13:43 +0000

1# The MIT License (MIT) 

2# 

3# Copyright (c) 2021 RSK Labs Ltd 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy of 

6# this software and associated documentation files (the "Software"), to deal in 

7# the Software without restriction, including without limitation the rights to 

8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 

9# of the Software, and to permit persons to whom the Software is furnished to do 

10# so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23import socketserver 

24import threading 

25import socket 

26import json 

27import logging 

28from comm.protocol import HSM2ProtocolError, HSM2ProtocolInterrupt 

29 

# Name of the logger that TCPServer obtains via logging.getLogger().
# NOTE(review): the "srver" spelling looks deliberate (short fixed-width
# logger name?) and logging configuration may refer to it -- confirm before
# renaming.
LOGGER_NAME: str = "srver"

31 

32 

class RequestHandlerError(RuntimeError):
    """Raised by _RequestHandler when a request could not be served properly."""

35 

36 

class RequestHandlerShutdown(RuntimeError):
    """Raised by _RequestHandler when the protocol requests a server shutdown."""

39 

40 

41class _RequestHandler: 

42 ENCODING = "utf-8" 

43 

44 def __init__(self, protocol, logger): 

45 self.protocol = protocol 

46 self.logger = logger 

47 

48 def handle(self, client_address, rfile, wfile): 

49 try: 

50 line = rfile.readline().strip() 

51 data = line.decode(self.ENCODING) 

52 except UnicodeDecodeError: 

53 output = json.dumps(self.protocol.format_error(), sort_keys=True) 

54 self.logger.info( 

55 "<= [%s]: invalid encoding input - 0x%s", client_address, line.hex() 

56 ) 

57 success = self._reply(wfile, output) 

58 if success: 

59 self.logger.info("=> [%s]: %s", client_address, output) 

60 return 

61 

62 self.logger.info("<= [%s]: %s", client_address, data) 

63 try: 

64 response = {} 

65 request = json.loads(data) 

66 self.logger.debug("Delivering request") 

67 response = self.protocol.handle_request(request) 

68 self.logger.debug("Got response: %s", response) 

69 except json.decoder.JSONDecodeError as e: 

70 self.logger.debug("JSON error: %s", e) 

71 response = self.protocol.format_error() 

72 except NotImplementedError as e: 

73 self.logger.critical("Not implemented: %s", e) 

74 except HSM2ProtocolError as e: 

75 response = self.protocol.unknown_error() 

76 raise RequestHandlerError(format(e)) 

77 except HSM2ProtocolInterrupt as e: 

78 raise RequestHandlerShutdown(format(e)) 

79 except Exception as e: 

80 message = "Unknown exception while handling request: %s" % format(e) 

81 self.logger.critical(message) 

82 raise RequestHandlerError(message) 

83 finally: 

84 output = json.dumps(response, sort_keys=True) 

85 success = self._reply(wfile, output) 

86 if success: 

87 self.logger.info("=> [%s]: %s", client_address, output) 

88 

89 def _reply(self, wfile, output): 

90 try: 

91 wfile.write(output.encode(self.ENCODING)) 

92 wfile.write("\n".encode(self.ENCODING)) 

93 return True 

94 except Exception as e: 

95 self.logger.warning("Error replying: %s", str(e)) 

96 return False 

97 

98 

99class _TCPServerRequestHandler(socketserver.StreamRequestHandler): 

100 def handle(self): 

101 try: 

102 handler = _RequestHandler(self.server.protocol, self.server.logger) 

103 handler.handle(self.client_address[0], self.rfile, self.wfile) 

104 except RequestHandlerError as e: 

105 # Log the error and shutdown 

106 self.server.logger.critical("Error handling request: %s", format(e)) 

107 self.shutdown() 

108 except RequestHandlerShutdown as e: 

109 # A shutdown has been requested, log and shutdown 

110 self.server.logger.info("Shutting down: %s", format(e)) 

111 self.shutdown() 

112 except ConnectionError as e: 

113 # A connection issue should log as an error 

114 # cause it is not common or expected 

115 self.server.logger.error("Connection error while serving request: %s", 

116 format(e)) 

117 except Exception as e: 

118 # Any unknown exception should log as critical 

119 self.server.logger.critical("UNKNOWN error serving request: %s", format(e)) 

120 

121 def shutdown(self): 

122 def tgt(): 

123 return self._do_shutdown() 

124 

125 threading.Thread(target=tgt).start() 

126 

127 def _do_shutdown(self): 

128 self.server.shutdown() 

129 

130 

class TCPServerError(RuntimeError):
    """Raised by TCPServer.run() when the server cannot start or keep running."""

133 

134 

135class TCPServer: 

136 def __init__(self, host, port, protocol): 

137 self.host = host 

138 self.port = port 

139 self.protocol = protocol 

140 self.logger = logging.getLogger(LOGGER_NAME) 

141 self.server = None 

142 

143 def run(self): 

144 try: 

145 self.logger.info("Initializing device") 

146 self.protocol.initialize_device() 

147 self.logger.info("Initializing server") 

148 socketserver.TCPServer.allow_reuse_address = True 

149 self.server = socketserver.TCPServer( 

150 (self.host, self.port), _TCPServerRequestHandler 

151 ) 

152 self.server.protocol = self.protocol 

153 self.server.logger = self.logger 

154 self.logger.info("Listening on %s:%d" % (self.host, self.port)) 

155 self.server.serve_forever() 

156 except socket.error as e: 

157 message = "Error running server: %s" % format(e) 

158 self.logger.critical(message) 

159 raise TCPServerError(message) 

160 except NotImplementedError as e: 

161 message = "Not implemented: %s" % format(e) 

162 self.logger.critical(message) 

163 raise TCPServerError(message) 

164 except KeyboardInterrupt: 

165 self.logger.info("Interrupted by user!") 

166 except HSM2ProtocolInterrupt: 

167 self.logger.info("Interrupted by HSM2 protocol!") 

168 except HSM2ProtocolError as e: 

169 message = "Error in device initialization: %s" % format(e) 

170 self.logger.critical(message) 

171 raise TCPServerError(message) 

172 finally: 

173 if self.server is not None: 

174 self.logger.info("Terminating server") 

175 self.server.server_close()