Coverage for comm/server.py: 95%

116 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-04-05 20:41 +0000

1# The MIT License (MIT) 

2# 

3# Copyright (c) 2021 RSK Labs Ltd 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy of 

6# this software and associated documentation files (the "Software"), to deal in 

7# the Software without restriction, including without limitation the rights to 

8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 

9# of the Software, and to permit persons to whom the Software is furnished to do 

10# so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23import socketserver 

24import threading 

25import socket 

26import json 

27import logging 

28from comm.protocol import HSM2ProtocolError, HSM2ProtocolInterrupt 

29 

# Name of the logger shared by TCPServer and the request handlers it spawns.
# NOTE(review): "srver" looks like a typo for "server", but it is a runtime
# identifier (logging configuration may key on it), so it is kept as-is.
LOGGER_NAME: str = "srver"

31 

32 

class RequestHandlerError(RuntimeError):
    """Raised by _RequestHandler when handling a request fails.

    The TCP request handler logs it as critical and shuts the server down.
    """

35 

36 

class RequestHandlerShutdown(RuntimeError):
    """Raised by _RequestHandler when the protocol signals an interrupt.

    The TCP request handler logs it and shuts the server down.
    """

39 

40 

class _RequestHandler:
    """Serve one newline-terminated JSON request/response exchange.

    A single line is read from the client, decoded as UTF-8, parsed as JSON
    and passed to the protocol. The protocol's response is written back as
    one JSON line (keys sorted).
    """

    ENCODING = "utf-8"

    def __init__(self, protocol, logger):
        """Keep the protocol that answers requests and the logger to use."""
        self.protocol = protocol
        self.logger = logger

    def handle(self, client_address, rfile, wfile):
        """Read one request line from ``rfile`` and reply on ``wfile``.

        :param client_address: client identifier, used only in log messages.
        :param rfile: binary file-like object the request line is read from.
        :param wfile: binary file-like object the reply line is written to.
        :raises RequestHandlerError: when the protocol raises
            HSM2ProtocolError or any unexpected exception occurs; the
            original exception is chained as ``__cause__``.
        :raises RequestHandlerShutdown: when the protocol raises
            HSM2ProtocolInterrupt (chained as ``__cause__``).

        A reply is written even when one of the above is raised.
        """
        try:
            line = rfile.readline().strip()
            data = line.decode(self.ENCODING)
        except UnicodeDecodeError:
            # Non-UTF-8 input cannot be a valid request: answer with the
            # protocol's format error and stop.
            output = json.dumps(self.protocol.format_error(), sort_keys=True)
            self.logger.info(
                "<= [%s]: invalid encoding input - 0x%s", client_address, line.hex()
            )
            self.logger.info("=> [%s]: %s", client_address, output)
            self._reply(wfile, output)
            return

        self.logger.info("<= [%s]: %s", client_address, data)
        # Reply sent when no other response gets assigned below (e.g. on
        # NotImplementedError or when handle_request itself raises).
        response = {}
        try:
            request = json.loads(data)
            self.logger.debug("Delivering request")
            response = self.protocol.handle_request(request)
            self.logger.debug("Got response: %s", response)
        except json.decoder.JSONDecodeError as e:
            self.logger.debug("JSON error: %s", e)
            response = self.protocol.format_error()
        except NotImplementedError as e:
            self.logger.critical("Not implemented: %s", e)
        except HSM2ProtocolError as e:
            response = self.protocol.unknown_error()
            raise RequestHandlerError(format(e)) from e
        except HSM2ProtocolInterrupt as e:
            raise RequestHandlerShutdown(format(e)) from e
        except Exception as e:
            message = "Unknown exception while handling request: %s" % format(e)
            self.logger.critical(message)
            raise RequestHandlerError(message) from e
        finally:
            # Always answer the client, even while an exception propagates.
            output = json.dumps(response, sort_keys=True)
            self.logger.info("=> [%s]: %s", client_address, output)
            self._reply(wfile, output)

    def _reply(self, wfile, output):
        """Write ``output`` followed by a newline to ``wfile``.

        Write failures are logged as warnings and not propagated.
        """
        try:
            wfile.write(output.encode(self.ENCODING))
            wfile.write("\n".encode(self.ENCODING))
        except Exception as e:
            self.logger.warning("Error replying: %s", str(e))

93 

94 

class _TCPServerRequestHandler(socketserver.StreamRequestHandler):
    """socketserver handler that serves one JSON request per connection.

    The exchange itself is delegated to _RequestHandler, using the
    ``protocol`` and ``logger`` attributes attached to the owning server.
    """

    def handle(self):
        """Serve one request and turn failures into log entries.

        RequestHandlerError and RequestHandlerShutdown also stop the server;
        connection errors and other exceptions are only logged.
        """
        try:
            request_handler = _RequestHandler(self.server.protocol, self.server.logger)
            request_handler.handle(self.client_address[0], self.rfile, self.wfile)
        except (RequestHandlerError, RequestHandlerShutdown) as err:
            if isinstance(err, RequestHandlerError):
                # Handling failed: record it and stop the server
                self.server.logger.critical("Error handling request: %s", format(err))
            else:
                # The protocol asked to stop: record it and stop the server
                self.server.logger.info("Shutting down: %s", format(err))
            self.shutdown()
        except ConnectionError as err:
            # Connection problems are unusual, hence logged at error level
            self.server.logger.error(
                "Connection error while serving request: %s", format(err)
            )
        except Exception as err:
            # Anything unforeseen is logged as critical
            self.server.logger.critical(
                "UNKNOWN error serving request: %s", format(err)
            )

    def shutdown(self):
        """Stop the owning server from a helper thread.

        BaseServer.shutdown() waits for serve_forever() to return. When this
        handler runs under TCPServer (a plain, single-threaded
        socketserver.TCPServer), it executes inside that loop, so calling
        shutdown() directly would deadlock.
        """
        threading.Thread(target=lambda: self._do_shutdown()).start()

    def _do_shutdown(self):
        """Tell the owning server's serve_forever() loop to stop."""
        self.server.shutdown()

125 

126 

class TCPServerError(RuntimeError):
    """Raised by TCPServer.run when the server cannot be started or kept running.

    This covers socket errors, unimplemented protocol features and
    device-initialization failures.
    """

129 

130 

class TCPServer:
    """TCP front end that answers line-delimited JSON requests via a protocol.

    Each connection is served by _TCPServerRequestHandler, which reads one
    request line, hands it to ``protocol`` and writes back one response line.
    """

    def __init__(self, host, port, protocol):
        """Store the listening address and protocol; no socket is opened yet.

        :param host: address to bind to.
        :param port: TCP port to listen on (logged with ``%d``).
        :param protocol: object whose ``initialize_device()`` is called by
            run() and which is made available to every request handler.
        """
        self.host = host
        self.port = port
        self.protocol = protocol
        self.logger = logging.getLogger(LOGGER_NAME)
        # socketserver.TCPServer instance, created by run()
        self.server = None

    def run(self):
        """Initialize the device, then serve requests until stopped.

        Requests are served one at a time by serve_forever() on the calling
        thread. The call returns normally on KeyboardInterrupt, on
        HSM2ProtocolInterrupt, or once a request handler shuts the server
        down. The listening socket is always closed before returning.

        :raises TCPServerError: on socket errors, NotImplementedError, or an
            HSM2ProtocolError during device initialization (the original
            exception is chained as ``__cause__``).
        """
        try:
            self.logger.info("Initializing device")
            self.protocol.initialize_device()
            self.logger.info("Initializing server")
            # Enable address reuse on this instance only. Setting the flag on
            # socketserver.TCPServer itself would change every TCPServer in
            # the process. The flag is read by server_bind(), so the server
            # is created unbound and bound explicitly afterwards.
            self.server = socketserver.TCPServer(
                (self.host, self.port),
                _TCPServerRequestHandler,
                bind_and_activate=False,
            )
            self.server.allow_reuse_address = True
            # Attributes read by _TCPServerRequestHandler via self.server
            self.server.protocol = self.protocol
            self.server.logger = self.logger
            # A bind/listen failure raises socket.error; the finally clause
            # then closes the socket since self.server is already set.
            self.server.server_bind()
            self.server.server_activate()
            self.logger.info("Listening on %s:%d", self.host, self.port)
            self.server.serve_forever()
        except socket.error as e:
            message = "Error running server: %s" % format(e)
            self.logger.critical(message)
            raise TCPServerError(message) from e
        except NotImplementedError as e:
            message = "Not implemented: %s" % format(e)
            self.logger.critical(message)
            raise TCPServerError(message) from e
        except KeyboardInterrupt:
            self.logger.info("Interrupted by user!")
        except HSM2ProtocolInterrupt:
            self.logger.info("Interrupted by HSM2 protocol!")
        except HSM2ProtocolError as e:
            message = "Error in device initialization: %s" % format(e)
            self.logger.critical(message)
            raise TCPServerError(message) from e
        finally:
            if self.server is not None:
                self.logger.info("Terminating server")
                self.server.server_close()