Coverage for comm/protocol.py: 98%

210 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-07-10 13:43 +0000

1# The MIT License (MIT) 

2# 

3# Copyright (c) 2021 RSK Labs Ltd 

4# 

5# Permission is hereby granted, free of charge, to any person obtaining a copy of 

6# this software and associated documentation files (the "Software"), to deal in 

7# the Software without restriction, including without limitation the rights to 

8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 

9# of the Software, and to permit persons to whom the Software is furnished to do 

10# so, subject to the following conditions: 

11# 

12# The above copyright notice and this permission notice shall be included in all 

13# copies or substantial portions of the Software. 

14# 

15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

21# SOFTWARE. 

22 

23import logging 

24from .bip32 import BIP32Path 

25from .utils import \ 

26 is_nonempty_hex_string, is_hex_string_of_length, \ 

27 has_nonempty_hex_field, has_hex_field_of_length, \ 

28 has_field_of_type 

29 

# Name passed to logging.getLogger() by every HSM2Protocol instance.
LOGGER_NAME: str = "protocol"

31 

32 

class HSM2ProtocolError(RuntimeError):
    """Error type for the HSM 2 protocol layer.

    A plain ``RuntimeError`` subclass with no extra state. It is not raised
    anywhere in the visible part of this module.
    """

35 

36 

class HSM2ProtocolInterrupt(Exception):
    """Interruption signal for the HSM 2 protocol layer.

    Derives directly from ``Exception`` rather than ``RuntimeError``, so
    handlers that catch ``HSM2ProtocolError`` or ``RuntimeError`` do not catch
    it. NOTE(review): presumably used to stop request processing; its raisers
    are outside the visible part of this module.
    """

39 

40 

class HSM2Protocol:
    """Request dispatcher and generic input validator for the HSM 2 protocol.

    A request is a dictionary with a ``command`` key. Every command other
    than ``version`` also needs a ``version`` key equal to ``VERSION``. A
    known command is first checked by its validator (``_validation_mappings``).
    If that succeeds, its operation (``_mappings``) runs. Apart from
    ``version``, the operations defined here only call ``_not_implemented``,
    which raises ``NotImplementedError``. Concrete subclasses are expected to
    override them.

    Every response is a dictionary containing the ``errorcode`` key.
    """

    # Request/Response keys
    COMMAND_KEY = "command"
    ERROR_CODE_KEY = "errorcode"
    VERSION_KEY = "version"

    # Success error codes
    ERROR_CODE_OK = 0
    ERROR_CODE_OK_PARTIAL = 1

    # Auth-related error codes
    ERROR_CODE_INVALID_AUTH = -101
    ERROR_CODE_INVALID_MESSAGE = -102
    ERROR_CODE_INVALID_KEYID = -103

    # Blockchain bookkeeping error codes
    ERROR_CODE_CHAINING_MISMATCH = -201
    ERROR_CODE_POW_INVALID = -202
    ERROR_CODE_TIP_MISMATCH = -203
    ERROR_CODE_INVALID_INPUT_BLOCKS = -204
    ERROR_CODE_INVALID_BROTHERS = -205

    # Heartbeat error codes
    ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE = -301

    # Generic error codes
    ERROR_CODE_FORMAT_ERROR = -901
    ERROR_CODE_INVALID_REQUEST = -902
    ERROR_CODE_COMMAND_UNKNOWN = -903
    ERROR_CODE_WRONG_VERSION = -904
    ERROR_CODE_DEVICE = -905
    ERROR_CODE_UNKNOWN = -906

    # Protocol version
    VERSION = 5

    # Commands
    VERSION_COMMAND = "version"
    SIGN_COMMAND = "sign"
    GETPUBKEY_COMMAND = "getPubKey"
    ADVANCE_BLOCKCHAIN_COMMAND = "advanceBlockchain"
    RESET_ADVANCE_BLOCKCHAIN_COMMAND = "resetAdvanceBlockchain"
    BLOCKCHAIN_STATE_COMMAND = "blockchainState"
    UPDATE_ANCESTOR_BLOCK_COMMAND = "updateAncestorBlock"
    GET_BLOCKCHAIN_PARAMETERS = "blockchainParameters"
    SIGNER_HEARTBEAT = "signerHeartbeat"
    UI_HEARTBEAT = "uiHeartbeat"

    # Minimum number of blocks to update the ancestor block
    MINIMUM_UPDATE_ANCESTOR_BLOCKS = 1

    # Signer and UI heartbeat user-defined value sizes
    SIGNER_HBT_UD_VALUE_SIZE = 16  # bytes
    UI_HBT_UD_VALUE_SIZE = 32  # bytes

    def __init__(self):
        self.logger = logging.getLogger(LOGGER_NAME)
        self._init_mappings()

    def handle_request(self, request):
        """Process one request and return its response dictionary.

        The incoming request and the outgoing response are both logged at
        INFO level.
        """
        self.logger.info("In %s", request)
        response = self.__internal_handle_request(request)
        self.logger.info("Out %s", response)
        return response

    def __internal_handle_request(self, request):
        """Validate the request envelope, then dispatch to the command.

        Returns an error response for a malformed envelope (not a dict,
        missing command or version, wrong version, unknown command) or a
        failed validation. Otherwise returns the operation's output with
        ``errorcode`` added.
        """
        if type(request) != dict:
            return self.format_error()

        if self.COMMAND_KEY not in request:
            return self._invalid_request()

        # Only the "version" command may omit the protocol version
        if (
            request[self.COMMAND_KEY] != self.VERSION_COMMAND
            and self.VERSION_KEY not in request
        ):
            return self._invalid_request()

        if self.VERSION_KEY in request and request[self.VERSION_KEY] != self.VERSION:
            return self._wrong_version()

        command = request[self.COMMAND_KEY]
        self.logger.debug("Cmd: %s", command)
        # Every known command is a string. Rejecting other types up front
        # matches what the membership test already answers for them. It also
        # stops an unhashable value (e.g. a JSON array or object) from raising
        # TypeError in the dict-view lookup.
        if not isinstance(command, str) or command not in self._known_commands:
            return self._command_unknown()

        # Perform generic input validation
        validation_result = self._validation_mappings[command](request)
        if validation_result < 0:
            return {self.ERROR_CODE_KEY: validation_result}

        # Operations MUST return a tuple with TWO elements:
        # - First: an integer outcome. A nonnegative value means success; a
        #   negative value means failure.
        # - Second: on success, a dictionary with the result, to which the
        #   protocol overhead (the error code) is added. None if the
        #   operation failed.
        operation_result = self._mappings[command](request)
        result = operation_result[0]
        if result < 0:
            return {self.ERROR_CODE_KEY: result}

        output = operation_result[1]
        output[self.ERROR_CODE_KEY] = result
        return output

    def initialize_device(self):
        """Device initialization hook; raises NotImplementedError here."""
        self._not_implemented("initialize_device")

    def device_error(self):
        """Return a response carrying ERROR_CODE_DEVICE."""
        self.logger.debug("Generic error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_DEVICE}

    def unknown_error(self):
        """Return a response carrying ERROR_CODE_UNKNOWN."""
        self.logger.debug("Generic error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_UNKNOWN}

    def format_error(self):
        """Return a response carrying ERROR_CODE_FORMAT_ERROR."""
        self.logger.debug("Format error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_FORMAT_ERROR}

    def _invalid_request(self):
        """Return a response carrying ERROR_CODE_INVALID_REQUEST."""
        self.logger.debug("Invalid request")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_INVALID_REQUEST}

    def _wrong_version(self):
        """Return a response carrying ERROR_CODE_WRONG_VERSION."""
        self.logger.debug("Invalid version")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_WRONG_VERSION}

    def _command_unknown(self):
        """Return a response carrying ERROR_CODE_COMMAND_UNKNOWN."""
        self.logger.debug("Command unknown")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_COMMAND_UNKNOWN}

    def _version(self, request):
        """'version' operation: always succeeds with the protocol version."""
        return (0, {self.VERSION_KEY: self.VERSION})

    def _validate_advance_blockchain(self, request):
        """Validate the 'blocks' and 'brothers' fields of advanceBlockchain.

        Returns ERROR_CODE_OK, ERROR_CODE_INVALID_INPUT_BLOCKS or
        ERROR_CODE_INVALID_BROTHERS.
        """
        # Validate blocks presence, type and minimum length
        if (
            "blocks" not in request
            or type(request["blocks"]) != list
            or len(request["blocks"]) == 0
        ):
            self.logger.info("Blocks field not present, not an array or empty")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate blocks elements are strings
        if not all(type(item) == str for item in request["blocks"]):
            self.logger.info("Some of the blocks elements are not strings")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate brothers presence, type and length (one list per block)
        if (
            "brothers" not in request
            or type(request["brothers"]) != list
            or len(request["brothers"]) != len(request["blocks"])
        ):
            self.logger.info("Brothers field not present, not an array or "
                             "different in length to Blocks field")
            return self.ERROR_CODE_INVALID_BROTHERS

        # Validate brother elements are lists of nonempty hex strings
        if not all(type(item) == list for item in request["brothers"]) or \
           not all(type(item) == str and is_nonempty_hex_string(item)
                   for brother_list in request["brothers"]
                   for item in brother_list):
            self.logger.info("Some of the brother list elements are not strings")
            return self.ERROR_CODE_INVALID_BROTHERS

        return self.ERROR_CODE_OK

    def _advance_blockchain(self, request):
        """'advanceBlockchain' operation; raises NotImplementedError here."""
        self._not_implemented(self.ADVANCE_BLOCKCHAIN_COMMAND)

    def _reset_advance_blockchain(self, request):
        """'resetAdvanceBlockchain' operation; raises NotImplementedError here."""
        self._not_implemented(self.RESET_ADVANCE_BLOCKCHAIN_COMMAND)

    def _blockchain_state(self, request):
        """'blockchainState' operation; raises NotImplementedError here."""
        self._not_implemented(self.BLOCKCHAIN_STATE_COMMAND)

    def _validate_update_ancestor_block(self, request):
        """Validate the 'blocks' field of updateAncestorBlock.

        Requires a list of at least MINIMUM_UPDATE_ANCESTOR_BLOCKS strings.
        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_INPUT_BLOCKS.
        """
        # Validate blocks presence, type and minimum length
        if (
            "blocks" not in request
            or type(request["blocks"]) != list
            or len(request["blocks"]) < self.MINIMUM_UPDATE_ANCESTOR_BLOCKS
        ):
            # Lazy logging arguments: the message is only built if emitted
            self.logger.info(
                "Blocks field not present, not an array or shorter than the minimum "
                "(%d blocks)", self.MINIMUM_UPDATE_ANCESTOR_BLOCKS
            )
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate blocks elements are strings
        if not all(type(item) == str for item in request["blocks"]):
            self.logger.info("Some of the blocks elements are not strings")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        return self.ERROR_CODE_OK

    def _update_ancestor_block(self, request):
        """'updateAncestorBlock' operation; raises NotImplementedError here."""
        self._not_implemented(self.UPDATE_ANCESTOR_BLOCK_COMMAND)

    def _validate_key_id(self, request):
        """Validate 'keyId' and replace it in place with a BIP32Path.

        Returns ERROR_CODE_OK, or ERROR_CODE_INVALID_KEYID when the field is
        missing, not a string, or rejected by BIP32Path with ValueError.
        """
        # The keyId field must be present
        if "keyId" not in request or type(request["keyId"]) != str:
            self.logger.info("Key ID field not present")
            return self.ERROR_CODE_INVALID_KEYID

        try:
            # This overrides the "keyId" within the request itself, which
            # might not be the best idea. Nevertheless, the only possible
            # thing to do with this key id (which should be a BIP32 path every
            # time) is validate it and then use it as a BIP32Path. The
            # original string won't be needed and can always be retrieved
            # using the BIP32Path instance.
            request["keyId"] = BIP32Path(request["keyId"])
        except ValueError as e:
            self.logger.info("Invalid Key ID: %s", str(e))
            return self.ERROR_CODE_INVALID_KEYID

        return self.ERROR_CODE_OK

    def _validate_auth(self, request, mandatory):
        """Validate the optional or mandatory 'auth' object.

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_AUTH.
        """
        # The authorization field must either:
        # - Not be present if mandatory == False
        # - Be a dictionary with all the required fields
        if "auth" not in request:
            return self.ERROR_CODE_OK if not mandatory else self.ERROR_CODE_INVALID_AUTH

        auth = request["auth"]

        # Validate auth field is a dictionary (object)
        if type(auth) != dict:
            self.logger.info("Authorization field not an object")
            return self.ERROR_CODE_INVALID_AUTH

        # Validate receipt presence and type
        if (
            "receipt" not in auth
            or type(auth["receipt"]) != str
            or not is_nonempty_hex_string(auth["receipt"])
        ):
            self.logger.info(
                "Transaction receipt field not present or not a nonempty hex string"
            )
            return self.ERROR_CODE_INVALID_AUTH

        # Validate receipt merkle proof inclusion presence, type and minimum length
        if (
            "receipt_merkle_proof" not in auth
            or type(auth["receipt_merkle_proof"]) != list
            or len(auth["receipt_merkle_proof"]) == 0
        ):
            self.logger.info(
                "Receipt merkle proof field not present or not a nonempty array"
            )
            return self.ERROR_CODE_INVALID_AUTH

        # Validate merkle proof elements are nonempty hex strings
        if not all(
            type(item) == str and is_nonempty_hex_string(item)
            for item in auth["receipt_merkle_proof"]
        ):
            self.logger.info(
                "Some of the receipt merkle proof elements are not nonempty hex strings"
            )
            return self.ERROR_CODE_INVALID_AUTH

        return self.ERROR_CODE_OK

    def _validate_message(self, request, what):
        """Validate the 'message' object against the accepted shapes.

        ``what`` selects the allowed shapes: "hash", "tx" or "any".
        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_MESSAGE.
        """
        # Message field must always be present and a dictionary
        # Also, it must:
        # - Contain exactly a "hash" element of type string (1) that must be a
        #   32-byte hex (what is "any" or "hash")
        # - Contain exactly a "tx" element of type string that must be a hex
        #   string; an "input" element of type int; a "sighashComputationMode"
        #   element of type string that contains exactly either "legacy" (2a)
        #   or "segwit" (2b); and, if the latter contains "segwit", then
        #   additionally:
        #   o A "witnessScript" element of type string that must be a hex string
        #   o An "outpointValue" element of type int that must be greater than
        #     0 and at most 0xffffffffffffffff
        #   (what is "any" or "tx")

        # Validate message presence and components
        if "message" not in request or type(request["message"]) != dict:
            self.logger.info("Message field not present or not an object")
            return self.ERROR_CODE_INVALID_MESSAGE

        message = request["message"]

        # (1)
        if (
            what in ["any", "hash"]
            and len(message) == 1
            and has_hex_field_of_length(message, "hash", 32)
        ):
            return self.ERROR_CODE_OK

        # (2a)
        if (
            what in ["any", "tx"]
            and len(message) == 3
            and has_nonempty_hex_field(message, "tx")
            and has_field_of_type(message, "input", int)
            and has_field_of_type(message, "sighashComputationMode", str)
            and message["sighashComputationMode"] == "legacy"
        ):
            return self.ERROR_CODE_OK

        # (2b)
        if (
            what in ["any", "tx"]
            and len(message) == 5
            and has_nonempty_hex_field(message, "tx")
            and has_field_of_type(message, "input", int)
            and has_field_of_type(message, "sighashComputationMode", str)
            and message["sighashComputationMode"] == "segwit"
            and has_nonempty_hex_field(message, "witnessScript")
            and has_field_of_type(message, "outpointValue", int)
            and message["outpointValue"] > 0
            and message["outpointValue"] <= 0xffffffffffffffff
        ):
            return self.ERROR_CODE_OK

        self.logger.info("Message field for expected message of type '%s' invalid", what)
        return self.ERROR_CODE_INVALID_MESSAGE

    def _validate_get_pubkey(self, request):
        """Validate a getPubKey request (only 'keyId' is checked)."""
        # Validate key id
        # SIDE EFFECT: request["keyId"] is turned into a comm.bip32.BIP32Path instance
        keyid_validation = self._validate_key_id(request)
        if keyid_validation < self.ERROR_CODE_OK:
            return keyid_validation

        return self.ERROR_CODE_OK

    # In concrete classes, this should implement the "getPubKey" operation
    # The parameters of the operation are within the "request" dictionary:
    # keyId: a BIP32Path instance
    def _get_pubkey(self, request):
        self._not_implemented(self.GETPUBKEY_COMMAND)

    def _validate_sign(self, request):
        """Validate a sign request: 'keyId', optional 'auth' and 'message'."""
        # Validate key id
        # SIDE EFFECT: request["keyId"] is turned into a comm.bip32.BIP32Path instance
        keyid_validation = self._validate_key_id(request)
        if keyid_validation < self.ERROR_CODE_OK:
            return keyid_validation

        # Validate auth fields
        auth_validation = self._validate_auth(request, mandatory=False)
        if auth_validation < self.ERROR_CODE_OK:
            return auth_validation

        # Validate message fields
        message_validation = self._validate_message(request, what="any")
        if message_validation < self.ERROR_CODE_OK:
            return message_validation

        return self.ERROR_CODE_OK

    # In concrete classes, this should implement the "sign" operation
    # The parameters of the operation are within the "request" dictionary:
    # keyId: a BIP32Path instance
    # message: an object containing either "hash" (str) alone, or "tx" (str),
    #   "input" (int) and "sighashComputationMode" (str, "legacy"/"segwit"),
    #   plus "witnessScript" (str) and "outpointValue" (int) for "segwit".
    # auth: optional; if present, an object with "receipt" (str) and
    #   "receipt_merkle_proof" (list of str).
    def _sign(self, request):
        self._not_implemented(self.SIGN_COMMAND)

    def _get_blockchain_parameters(self, request):
        """'blockchainParameters' operation; raises NotImplementedError here."""
        self._not_implemented(self.GET_BLOCKCHAIN_PARAMETERS)

    def _validate_signer_heartbeat(self, request):
        """Check 'udValue' is a SIGNER_HBT_UD_VALUE_SIZE-byte hex string."""
        # Validate UD value presence, type and length
        if (
            "udValue" not in request
            or type(request["udValue"]) != str
            or not is_hex_string_of_length(request["udValue"],
                                           self.SIGNER_HBT_UD_VALUE_SIZE)
        ):
            self.logger.info(
                "User defined value field not present or not a "
                f"{self.SIGNER_HBT_UD_VALUE_SIZE}-byte hex string"
            )
            return self.ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE

        return self.ERROR_CODE_OK

    def _signer_heartbeat(self, request):
        """'signerHeartbeat' operation; raises NotImplementedError here."""
        self._not_implemented(self.SIGNER_HEARTBEAT)

    def _validate_ui_heartbeat(self, request):
        """Check 'udValue' is a UI_HBT_UD_VALUE_SIZE-byte hex string."""
        # Validate UD value presence, type and length
        if (
            "udValue" not in request
            or type(request["udValue"]) != str
            or not is_hex_string_of_length(request["udValue"], self.UI_HBT_UD_VALUE_SIZE)
        ):
            self.logger.info(
                "User defined value field not present or not a "
                f"{self.UI_HBT_UD_VALUE_SIZE}-byte hex string"
            )
            return self.ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE

        return self.ERROR_CODE_OK

    def _ui_heartbeat(self, request):
        """'uiHeartbeat' operation; raises NotImplementedError here."""
        self._not_implemented(self.UI_HEARTBEAT)

    def _not_implemented(self, funcname):
        """Log a warning and raise NotImplementedError(funcname)."""
        self.logger.warning("%s not implemented", funcname)
        raise NotImplementedError(funcname)

    def _init_mappings(self):
        """Build the command -> operation and command -> validator tables."""
        self.logger.debug("Initializing mappings")
        self._mappings = {
            self.VERSION_COMMAND: self._version,
            self.SIGN_COMMAND: self._sign,
            self.GETPUBKEY_COMMAND: self._get_pubkey,
            self.ADVANCE_BLOCKCHAIN_COMMAND: self._advance_blockchain,
            self.RESET_ADVANCE_BLOCKCHAIN_COMMAND: self._reset_advance_blockchain,
            self.BLOCKCHAIN_STATE_COMMAND: self._blockchain_state,
            self.UPDATE_ANCESTOR_BLOCK_COMMAND: self._update_ancestor_block,
            self.GET_BLOCKCHAIN_PARAMETERS: self._get_blockchain_parameters,
            self.SIGNER_HEARTBEAT: self._signer_heartbeat,
            self.UI_HEARTBEAT: self._ui_heartbeat,
        }

        # Command input validations (commands without inputs always pass)
        self._validation_mappings = {
            self.VERSION_COMMAND: lambda r: 0,
            self.SIGN_COMMAND: self._validate_sign,
            self.GETPUBKEY_COMMAND: self._validate_get_pubkey,
            self.ADVANCE_BLOCKCHAIN_COMMAND: self._validate_advance_blockchain,
            self.RESET_ADVANCE_BLOCKCHAIN_COMMAND: lambda r: 0,
            self.BLOCKCHAIN_STATE_COMMAND: lambda r: 0,
            self.UPDATE_ANCESTOR_BLOCK_COMMAND: self._validate_update_ancestor_block,
            self.GET_BLOCKCHAIN_PARAMETERS: lambda r: 0,
            self.SIGNER_HEARTBEAT: self._validate_signer_heartbeat,
            self.UI_HEARTBEAT: self._validate_ui_heartbeat,
        }
        # Live view over the operation table's keys
        self._known_commands = self._mappings.keys()