Coverage for comm/protocol.py: 98%
210 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-07-10 13:43 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-07-10 13:43 +0000
1# The MIT License (MIT)
2#
3# Copyright (c) 2021 RSK Labs Ltd
4#
5# Permission is hereby granted, free of charge, to any person obtaining a copy of
6# this software and associated documentation files (the "Software"), to deal in
7# the Software without restriction, including without limitation the rights to
8# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
9# of the Software, and to permit persons to whom the Software is furnished to do
10# so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice shall be included in all
13# copies or substantial portions of the Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21# SOFTWARE.
23import logging
24from .bip32 import BIP32Path
25from .utils import \
26 is_nonempty_hex_string, is_hex_string_of_length, \
27 has_nonempty_hex_field, has_hex_field_of_length, \
28 has_field_of_type
# Name of the logger that HSM2Protocol obtains through logging.getLogger()
# in its constructor.
30LOGGER_NAME = "protocol"
class HSM2ProtocolError(RuntimeError):
    """Runtime error type dedicated to the HSM2 protocol layer."""
class HSM2ProtocolInterrupt(Exception):
    """Exception type dedicated to interrupting HSM2 protocol processing."""
class HSM2Protocol:
    """Request dispatcher and generic input validator for the HSM2 protocol.

    A request is a dictionary carrying at least a "command" key and, for every
    command other than "version", a "version" key that must equal VERSION.
    handle_request() checks that envelope, runs the command-specific input
    validation and then delegates to the matching operation.

    The operations themselves (all except "version") are not implemented here:
    they log a warning and raise NotImplementedError, and are meant to be
    overridden in concrete classes.
    """

    # Request/Response keys
    COMMAND_KEY = "command"
    ERROR_CODE_KEY = "errorcode"
    VERSION_KEY = "version"

    # Success error codes
    ERROR_CODE_OK = 0
    ERROR_CODE_OK_PARTIAL = 1

    # Auth-related error codes
    ERROR_CODE_INVALID_AUTH = -101
    ERROR_CODE_INVALID_MESSAGE = -102
    ERROR_CODE_INVALID_KEYID = -103

    # Blockchain bookkeeping error codes
    ERROR_CODE_CHAINING_MISMATCH = -201
    ERROR_CODE_POW_INVALID = -202
    ERROR_CODE_TIP_MISMATCH = -203
    ERROR_CODE_INVALID_INPUT_BLOCKS = -204
    ERROR_CODE_INVALID_BROTHERS = -205

    # Heartbeat error codes
    ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE = -301

    # Generic error codes
    ERROR_CODE_FORMAT_ERROR = -901
    ERROR_CODE_INVALID_REQUEST = -902
    ERROR_CODE_COMMAND_UNKNOWN = -903
    ERROR_CODE_WRONG_VERSION = -904
    ERROR_CODE_DEVICE = -905
    ERROR_CODE_UNKNOWN = -906

    # Protocol version
    VERSION = 5

    # Commands
    VERSION_COMMAND = "version"
    SIGN_COMMAND = "sign"
    GETPUBKEY_COMMAND = "getPubKey"
    ADVANCE_BLOCKCHAIN_COMMAND = "advanceBlockchain"
    RESET_ADVANCE_BLOCKCHAIN_COMMAND = "resetAdvanceBlockchain"
    BLOCKCHAIN_STATE_COMMAND = "blockchainState"
    UPDATE_ANCESTOR_BLOCK_COMMAND = "updateAncestorBlock"
    GET_BLOCKCHAIN_PARAMETERS = "blockchainParameters"
    SIGNER_HEARTBEAT = "signerHeartbeat"
    UI_HEARTBEAT = "uiHeartbeat"

    # Minimum number of blocks to update the ancestor block
    MINIMUM_UPDATE_ANCESTOR_BLOCKS = 1

    # Signer and UI heartbeat user-defined value sizes
    SIGNER_HBT_UD_VALUE_SIZE = 16  # bytes
    UI_HBT_UD_VALUE_SIZE = 32  # bytes

    def __init__(self):
        """Set up the logger and the command dispatch tables."""
        self.logger = logging.getLogger(LOGGER_NAME)
        self._init_mappings()

    def handle_request(self, request):
        """Process a request and return the response dictionary.

        Both the incoming request and the outgoing response are logged at
        INFO level. The response always contains the "errorcode" key.
        """
        self.logger.info("In %s", request)
        response = self.__internal_handle_request(request)
        self.logger.info("Out %s", response)
        return response

    def __internal_handle_request(self, request):
        """Validate the request envelope and dispatch to the operation.

        Returns a dictionary with at least the "errorcode" key. Exceptions
        raised by the operation itself (e.g. NotImplementedError) propagate
        to the caller.
        """
        if not isinstance(request, dict):
            return self.format_error()

        if self.COMMAND_KEY not in request:
            return self._invalid_request()

        command = request[self.COMMAND_KEY]

        # Only the "version" command may omit the protocol version
        if command != self.VERSION_COMMAND and self.VERSION_KEY not in request:
            return self._invalid_request()

        if self.VERSION_KEY in request and request[self.VERSION_KEY] != self.VERSION:
            return self._wrong_version()

        self.logger.debug("Cmd: %s", command)
        # Every known command is a string. Checking the type first prevents a
        # TypeError on the membership test when the client sends an
        # unhashable value (e.g. a list or an object) as the command.
        if not isinstance(command, str) or command not in self._known_commands:
            return self._command_unknown()

        # Perform generic input validation
        validation_result = self._validation_mappings[command](request)
        if validation_result < 0:
            return {self.ERROR_CODE_KEY: validation_result}

        # Operations MUST return a tuple with TWO elements.
        # First element MUST be an integer representing the outcome of the
        # operation.
        # Second element MUST be a dictionary with the result (if the
        # operation is successful) or None if the operation failed.
        # In the first element, a nonnegative integer indicates success, and
        # then the result is the second element of the tuple with added
        # protocol overhead.
        # A negative integer indicates failure.
        operation_result = self._mappings[command](request)
        result = operation_result[0]
        if result < 0:
            return {self.ERROR_CODE_KEY: result}

        # The operation's own dictionary is reused as the response
        output = operation_result[1]
        output[self.ERROR_CODE_KEY] = result
        return output

    def initialize_device(self):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented("initialize_device")

    def device_error(self):
        """Return a response carrying ERROR_CODE_DEVICE."""
        self.logger.debug("Generic error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_DEVICE}

    def unknown_error(self):
        """Return a response carrying ERROR_CODE_UNKNOWN."""
        self.logger.debug("Generic error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_UNKNOWN}

    def format_error(self):
        """Return a response carrying ERROR_CODE_FORMAT_ERROR."""
        self.logger.debug("Format error")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_FORMAT_ERROR}

    def _invalid_request(self):
        """Return a response carrying ERROR_CODE_INVALID_REQUEST."""
        self.logger.debug("Invalid request")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_INVALID_REQUEST}

    def _wrong_version(self):
        """Return a response carrying ERROR_CODE_WRONG_VERSION."""
        self.logger.debug("Invalid version")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_WRONG_VERSION}

    def _command_unknown(self):
        """Return a response carrying ERROR_CODE_COMMAND_UNKNOWN."""
        self.logger.debug("Command unknown")
        return {self.ERROR_CODE_KEY: self.ERROR_CODE_COMMAND_UNKNOWN}

    def _version(self, request):
        """Implement the "version" operation: report the protocol version."""
        return (0, {self.VERSION_KEY: self.VERSION})

    def _validate_advance_blockchain(self, request):
        """Validate the "blocks" and "brothers" fields of "advanceBlockchain".

        Returns ERROR_CODE_OK, ERROR_CODE_INVALID_INPUT_BLOCKS or
        ERROR_CODE_INVALID_BROTHERS.
        """
        # Validate blocks presence, type and minimum length
        if (
            "blocks" not in request
            or not isinstance(request["blocks"], list)
            or len(request["blocks"]) == 0
        ):
            self.logger.info("Blocks field not present, not an array or empty")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate blocks elements are strings
        if not all(isinstance(item, str) for item in request["blocks"]):
            self.logger.info("Some of the blocks elements are not strings")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate brothers presence, type and length
        if (
            "brothers" not in request
            or not isinstance(request["brothers"], list)
            or len(request["brothers"]) != len(request["blocks"])
        ):
            self.logger.info("Brothers field not present, not an array or "
                             "different in length to Blocks field")
            return self.ERROR_CODE_INVALID_BROTHERS

        # Validate brother elements are lists of nonempty hex strings.
        # The list check must come first: the nested iteration below relies
        # on every element of "brothers" being a list.
        if not all(isinstance(item, list) for item in request["brothers"]) or \
           not all(isinstance(item, str) and is_nonempty_hex_string(item)
                   for brother_list in request["brothers"]
                   for item in brother_list):
            self.logger.info("Some of the brother list elements are not strings")
            return self.ERROR_CODE_INVALID_BROTHERS

        return self.ERROR_CODE_OK

    def _advance_blockchain(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.ADVANCE_BLOCKCHAIN_COMMAND)

    def _reset_advance_blockchain(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.RESET_ADVANCE_BLOCKCHAIN_COMMAND)

    def _blockchain_state(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.BLOCKCHAIN_STATE_COMMAND)

    def _validate_update_ancestor_block(self, request):
        """Validate the "blocks" field of "updateAncestorBlock".

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_INPUT_BLOCKS.
        """
        # Validate blocks presence, type and minimum length
        if (
            "blocks" not in request
            or not isinstance(request["blocks"], list)
            or len(request["blocks"]) < self.MINIMUM_UPDATE_ANCESTOR_BLOCKS
        ):
            self.logger.info(
                "Blocks field not present, not an array or shorter than the minimum "
                "(%d blocks)",
                self.MINIMUM_UPDATE_ANCESTOR_BLOCKS,
            )
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        # Validate blocks elements are strings
        if not all(isinstance(item, str) for item in request["blocks"]):
            self.logger.info("Some of the blocks elements are not strings")
            return self.ERROR_CODE_INVALID_INPUT_BLOCKS

        return self.ERROR_CODE_OK

    def _update_ancestor_block(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.UPDATE_ANCESTOR_BLOCK_COMMAND)

    def _validate_key_id(self, request):
        """Validate the "keyId" field and convert it to a BIP32Path.

        SIDE EFFECT: on success request["keyId"] is replaced by a BIP32Path
        instance built from the original string.

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_KEYID.
        """
        # The keyId field must be present
        if "keyId" not in request or not isinstance(request["keyId"], str):
            self.logger.info("Key ID field not present")
            return self.ERROR_CODE_INVALID_KEYID

        try:
            # This overrides the "keyId" within the request itself, which
            # might not be the best idea. Nevertheless, the only possible
            # thing to do with this key id (which should be a BIP32 path every
            # time) is validate it and then use it as a BIP32Path. The original
            # string won't be needed and can always be retrieved using the
            # BIP32Path instance.
            request["keyId"] = BIP32Path(request["keyId"])
        except ValueError as e:
            self.logger.info("Invalid Key ID: %s", str(e))
            return self.ERROR_CODE_INVALID_KEYID

        return self.ERROR_CODE_OK

    def _validate_auth(self, request, mandatory):
        """Validate the "auth" field of a request.

        The authorization field must either:
        - Not be present if mandatory == False
        - Be a dictionary with all the required fields

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_AUTH.
        """
        if "auth" not in request:
            return self.ERROR_CODE_OK if not mandatory else self.ERROR_CODE_INVALID_AUTH

        auth = request["auth"]

        # Validate auth field is a dictionary (object)
        if not isinstance(auth, dict):
            self.logger.info("Authorization field not an object")
            return self.ERROR_CODE_INVALID_AUTH

        # Validate receipt presence and type
        if (
            "receipt" not in auth
            or not isinstance(auth["receipt"], str)
            or not is_nonempty_hex_string(auth["receipt"])
        ):
            self.logger.info(
                "Transaction receipt field not present or not a nonempty hex string"
            )
            return self.ERROR_CODE_INVALID_AUTH

        # Validate receipt merkle proof inclusion presence, type and minimum length
        if (
            "receipt_merkle_proof" not in auth
            or not isinstance(auth["receipt_merkle_proof"], list)
            or len(auth["receipt_merkle_proof"]) == 0
        ):
            self.logger.info(
                "Receipt merkle proof field not present or not a nonempty array"
            )
            return self.ERROR_CODE_INVALID_AUTH

        # Validate merkle proof elements are nonempty hex strings
        if not all(
            isinstance(item, str) and is_nonempty_hex_string(item)
            for item in auth["receipt_merkle_proof"]
        ):
            self.logger.info(
                "Some of the receipt merkle proof elements are not nonempty hex strings"
            )
            return self.ERROR_CODE_INVALID_AUTH

        return self.ERROR_CODE_OK

    def _validate_message(self, request, what):
        """Validate the "message" field of a request.

        Message field must always be present and a dictionary.
        Also, it must:
        - Contain exactly a "hash" element of type string (1) that must be a
          32-byte hex (what is "any" or "hash")
        - Contain exactly a "tx" element of type string that must be a hex
          string; an "input" element of type int; a "sighashComputationMode"
          element of type string that contains exactly either "legacy" (2a)
          or "segwit" (2b); and, if the latter contains "segwit", then
          additionally:
            o A "witnessScript" element of type string that must be a hex
              string
            o An "outpointValue" element of type int that must be greater
              than 0 and at most 0xffffffffffffffff
          (what is "any" or "tx")

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_MESSAGE.
        """
        # Validate message presence and components
        if "message" not in request or not isinstance(request["message"], dict):
            self.logger.info("Message field not present or not an object")
            return self.ERROR_CODE_INVALID_MESSAGE

        message = request["message"]

        # (1)
        if (
            what in ("any", "hash")
            and len(message) == 1
            and has_hex_field_of_length(message, "hash", 32)
        ):
            return self.ERROR_CODE_OK

        # (2a)
        if (
            what in ("any", "tx")
            and len(message) == 3
            and has_nonempty_hex_field(message, "tx")
            and has_field_of_type(message, "input", int)
            and has_field_of_type(message, "sighashComputationMode", str)
            and message["sighashComputationMode"] == "legacy"
        ):
            return self.ERROR_CODE_OK

        # (2b)
        if (
            what in ("any", "tx")
            and len(message) == 5
            and has_nonempty_hex_field(message, "tx")
            and has_field_of_type(message, "input", int)
            and has_field_of_type(message, "sighashComputationMode", str)
            and message["sighashComputationMode"] == "segwit"
            and has_nonempty_hex_field(message, "witnessScript")
            and has_field_of_type(message, "outpointValue", int)
            and message["outpointValue"] > 0
            and message["outpointValue"] <= 0xffffffffffffffff
        ):
            return self.ERROR_CODE_OK

        self.logger.info("Message field for expected message of type '%s' invalid", what)
        return self.ERROR_CODE_INVALID_MESSAGE

    def _validate_get_pubkey(self, request):
        """Validate a "getPubKey" request.

        SIDE EFFECT: request["keyId"] is turned into a comm.bip32.BIP32Path
        instance.
        """
        # Validate key id
        keyid_validation = self._validate_key_id(request)
        if keyid_validation < self.ERROR_CODE_OK:
            return keyid_validation

        return self.ERROR_CODE_OK

    # In concrete classes, this should implement the "getPubKey" operation
    # The parameters of the operation are within the "request" dictionary:
    # keyId: a BIP32Path instance
    def _get_pubkey(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.GETPUBKEY_COMMAND)

    def _validate_sign(self, request):
        """Validate a "sign" request: key id, optional auth and message.

        SIDE EFFECT: request["keyId"] is turned into a comm.bip32.BIP32Path
        instance.

        Returns the first negative validation code found, or ERROR_CODE_OK.
        """
        # Validate key id
        keyid_validation = self._validate_key_id(request)
        if keyid_validation < self.ERROR_CODE_OK:
            return keyid_validation

        # Validate auth fields
        auth_validation = self._validate_auth(request, mandatory=False)
        if auth_validation < self.ERROR_CODE_OK:
            return auth_validation

        # Validate message fields
        message_validation = self._validate_message(request, what="any")
        if message_validation < self.ERROR_CODE_OK:
            return message_validation

        return self.ERROR_CODE_OK

    # In concrete classes, this should implement the "sign" operation
    # The parameters of the operation are within the "request" dictionary:
    # keyId: a BIP32Path instance
    # message: an object that can contain "tx" (str), "input" (int) or "hash" (str)
    # objects within.
    # auth: an object that can contain "receipt" and "receipt_merkle_proof"
    # (all str) objects within.
    def _sign(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.SIGN_COMMAND)

    def _get_blockchain_parameters(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.GET_BLOCKCHAIN_PARAMETERS)

    def _validate_signer_heartbeat(self, request):
        """Validate the "udValue" field of a "signerHeartbeat" request.

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE.
        """
        # Validate UD value presence, type and length
        if (
            "udValue" not in request
            or not isinstance(request["udValue"], str)
            or not is_hex_string_of_length(request["udValue"],
                                           self.SIGNER_HBT_UD_VALUE_SIZE)
        ):
            self.logger.info(
                "User defined value field not present or not a "
                "%d-byte hex string",
                self.SIGNER_HBT_UD_VALUE_SIZE,
            )
            return self.ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE

        return self.ERROR_CODE_OK

    def _signer_heartbeat(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.SIGNER_HEARTBEAT)

    def _validate_ui_heartbeat(self, request):
        """Validate the "udValue" field of a "uiHeartbeat" request.

        Returns ERROR_CODE_OK or ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE.
        """
        # Validate UD value presence, type and length
        if (
            "udValue" not in request
            or not isinstance(request["udValue"], str)
            or not is_hex_string_of_length(request["udValue"], self.UI_HBT_UD_VALUE_SIZE)
        ):
            self.logger.info(
                "User defined value field not present or not a "
                "%d-byte hex string",
                self.UI_HBT_UD_VALUE_SIZE,
            )
            return self.ERROR_CODE_INVALID_HEARTBEAT_UD_VALUE

        return self.ERROR_CODE_OK

    def _ui_heartbeat(self, request):
        """Not implemented here: logs a warning and raises NotImplementedError."""
        self._not_implemented(self.UI_HEARTBEAT)

    def _not_implemented(self, funcname):
        """Log a warning naming *funcname* and raise NotImplementedError."""
        self.logger.warning("%s not implemented", funcname)
        raise NotImplementedError(funcname)

    def _init_mappings(self):
        """Build the command -> operation and command -> validation tables."""
        self.logger.debug("Initializing mappings")
        self._mappings = {
            self.VERSION_COMMAND: self._version,
            self.SIGN_COMMAND: self._sign,
            self.GETPUBKEY_COMMAND: self._get_pubkey,
            self.ADVANCE_BLOCKCHAIN_COMMAND: self._advance_blockchain,
            self.RESET_ADVANCE_BLOCKCHAIN_COMMAND: self._reset_advance_blockchain,
            self.BLOCKCHAIN_STATE_COMMAND: self._blockchain_state,
            self.UPDATE_ANCESTOR_BLOCK_COMMAND: self._update_ancestor_block,
            self.GET_BLOCKCHAIN_PARAMETERS: self._get_blockchain_parameters,
            self.SIGNER_HEARTBEAT: self._signer_heartbeat,
            self.UI_HEARTBEAT: self._ui_heartbeat,
        }

        # Command input validations
        # Commands without specific input requirements always validate to 0
        self._validation_mappings = {
            self.VERSION_COMMAND: lambda r: 0,
            self.SIGN_COMMAND: self._validate_sign,
            self.GETPUBKEY_COMMAND: self._validate_get_pubkey,
            self.ADVANCE_BLOCKCHAIN_COMMAND: self._validate_advance_blockchain,
            self.RESET_ADVANCE_BLOCKCHAIN_COMMAND: lambda r: 0,
            self.BLOCKCHAIN_STATE_COMMAND: lambda r: 0,
            self.UPDATE_ANCESTOR_BLOCK_COMMAND: self._validate_update_ancestor_block,
            self.GET_BLOCKCHAIN_PARAMETERS: lambda r: 0,
            self.SIGNER_HEARTBEAT: self._validate_signer_heartbeat,
            self.UI_HEARTBEAT: self._validate_ui_heartbeat,
        }
        self._known_commands = self._mappings.keys()