11import base64
22import binascii
3- import struct
43from struct import unpack
5- from typing import Any , Dict , List , Optional
4+ from typing import List , Literal , Optional , Union , cast
65
76from Crypto .Hash import keccak
87from loguru import logger
@@ -164,7 +163,7 @@ def __str__(self):
164163
165164
166165# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/encoding.ts#L24
167- def encode_vaa_for_chain (vaa , vaa_format , buffer = False ):
166+ def encode_vaa_for_chain (vaa : str , vaa_format : str , buffer = False ) -> Union [ bytes , str ] :
168167 # check if vaa is already in vaa_format
169168 if isinstance (vaa , str ):
170169 if vaa_format == DEFAULT_VAA_ENCODING :
@@ -197,7 +196,7 @@ def encode_vaa_for_chain(vaa, vaa_format, buffer=False):
197196
198197# Referenced from https://github.com/wormhole-foundation/wormhole/blob/main/sdk/js/src/vaa/wormhole.ts#L26-L56
199198def parse_vaa (vaa , encoding ):
200- vaa = encode_vaa_for_chain (vaa , encoding , buffer = True )
199+ vaa = cast ( bytes , encode_vaa_for_chain (vaa , encoding , buffer = True ) )
201200
202201 num_signers = vaa [5 ]
203202 sig_length = 66
@@ -284,7 +283,7 @@ def parse_batch_price_attestation(bytes_):
284283 offset += 2
285284
286285 price_attestations = []
287- for i in range (batch_len ):
286+ for _ in range (batch_len ):
288287 price_attestations .append (
289288 parse_price_attestation (bytes_ [offset : offset + attestation_size ])
290289 )
@@ -401,13 +400,13 @@ def is_accumulator_update(vaa, encoding=DEFAULT_VAA_ENCODING) -> bool:
401400 Returns:
402401 bool: True if the VAA is an accumulator update, False otherwise.
403402 """
404- if encode_vaa_for_chain (vaa , encoding , buffer = True )[:4 ].hex () == ACCUMULATOR_MAGIC :
403+ if cast ( bytes , encode_vaa_for_chain (vaa , encoding , buffer = True ) )[:4 ].hex () == ACCUMULATOR_MAGIC :
405404 return True
406405 return False
407406
408407
409408# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/110caed6be3be7885773d2f6070b143cc13fb0ee/price_service/server/src/rest.ts#L139
410- def vaa_to_price_infos (vaa , encoding = DEFAULT_VAA_ENCODING ) -> List [PriceInfo ]:
409+ def vaa_to_price_infos (vaa , encoding : Literal [ "hex" , "base64" ] = DEFAULT_VAA_ENCODING ) -> Optional [ List [PriceInfo ] ]:
411410 if is_accumulator_update (vaa , encoding ):
412411 return extract_price_info_from_accumulator_update (vaa , encoding )
413412 parsed_vaa = parse_vaa (vaa , encoding )
@@ -425,7 +424,7 @@ def vaa_to_price_infos(vaa, encoding=DEFAULT_VAA_ENCODING) -> List[PriceInfo]:
425424 return price_infos
426425
427426
428- def vaa_to_price_info (id , vaa , encoding = DEFAULT_VAA_ENCODING ) -> Optional [PriceInfo ]:
427+ def vaa_to_price_info (id : str , vaa : str , encoding : Literal [ "hex" , "base64" ] = DEFAULT_VAA_ENCODING ) -> Optional [PriceInfo ]:
429428 """
430429 This function retrieves a specific PriceInfo object from a given VAA.
431430
@@ -502,14 +501,21 @@ def price_attestation_to_price_feed(price_attestation):
502501
503502# Referenced from https://github.com/pyth-network/pyth-crosschain/blob/1a00598334e52fc5faf967eb1170d7fc23ad828b/price_service/server/src/rest.ts#L137
504503def extract_price_info_from_accumulator_update (
505- update_data , encoding
506- ) -> Optional [Dict [str , Any ]]:
504+ update_data : str ,
505+ encoding : Literal ["hex" , "base64" ]
506+ ) -> Optional [List [PriceInfo ]]:
507507 parsed_update_data = parse_accumulator_update (update_data , encoding )
508+ if parsed_update_data is None :
509+ return None
510+
508511 vaa_buffer = parsed_update_data .vaa
509512 if encoding == "hex" :
510513 vaa_str = vaa_buffer .hex ()
511514 elif encoding == "base64" :
512515 vaa_str = base64 .b64encode (vaa_buffer ).decode ("ascii" )
516+ else :
517+ raise ValueError (f"Invalid encoding: { encoding } " )
518+
513519 parsed_vaa = parse_vaa (vaa_str , encoding )
514520 price_infos = []
515521 for update in parsed_update_data .updates :
@@ -581,7 +587,6 @@ def extract_price_info_from_accumulator_update(
581587
582588 return price_infos
583589
584-
585590def compress_accumulator_update (update_data_list , encoding ) -> List [str ]:
586591 """
587592 This function compresses a list of accumulator update data by combining those with the same VAA.
@@ -593,17 +598,21 @@ def compress_accumulator_update(update_data_list, encoding) -> List[str]:
593598
594599 Returns:
595600 List[str]: A list of serialized accumulator update data. Each item in the list is a hexadecimal string representing
596- an accumulator update data. The updates with the same VAA are combined and split into chunks of 255 updates each.
601+ an accumulator update data. The updates with the same VAA payload are combined and split into chunks of 255 updates each.
597602 """
598603 parsed_data_dict = {} # Use a dictionary for O(1) lookup
599604 # Combine the ones with the same VAA to a list
600605 for update_data in update_data_list :
601606 parsed_update_data = parse_accumulator_update (update_data , encoding )
602- vaa = parsed_update_data .vaa
603607
604- if vaa not in parsed_data_dict :
605- parsed_data_dict [vaa ] = []
606- parsed_data_dict [vaa ].append (parsed_update_data )
608+ if parsed_update_data is None :
609+ raise ValueError (f"Invalid accumulator update data: { update_data } " )
610+
611+ payload = parse_vaa (parsed_update_data .vaa .hex (), "hex" )["payload" ]
612+
613+ if payload not in parsed_data_dict :
614+ parsed_data_dict [payload ] = []
615+ parsed_data_dict [payload ].append (parsed_update_data )
607616 parsed_data_list = list (parsed_data_dict .values ())
608617
609618 # Combines accumulator update data with the same VAA into a single dictionary
@@ -698,7 +707,7 @@ def serialize_accumulator_update(data, encoding):
698707 return base64 .b64encode (serialized_data ).decode ("ascii" )
699708
700709
701- def parse_accumulator_update (update_data , encoding ) :
710+ def parse_accumulator_update (update_data : str , encoding : str ) -> Optional [ AccumulatorUpdate ] :
702711 """
703712 This function parses an accumulator update data.
704713
@@ -724,7 +733,8 @@ def parse_accumulator_update(update_data, encoding):
724733
725734 If the update type is not 0, the function logs an info message and returns None.
726735 """
727- encoded_update_data = encode_vaa_for_chain (update_data , encoding , buffer = True )
736+ encoded_update_data = cast (bytes , encode_vaa_for_chain (update_data , encoding , buffer = True ))
737+
728738 offset = 0
729739 magic = encoded_update_data [offset : offset + 4 ]
730740 offset += 4
0 commit comments