1- --!native
21--!optimize 2
32--!strict
43
54local BASE64_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
6- local BASE64_ENCODING_LUT = table.create (4096 )
7- local BASE64_DECODING_LUT = table .create (255 , 0 )
5+ local BASE64_ENCODING_LUT = table.create (4096 ) :: { number }
6+ local BASE64_DECODING_LUT = buffer .create (256 )
87
98do
109 for i = 0 , 4095 do
1413 bit32.bor (string.byte (BASE64_ALPHABET , hi ), bit32.lshift (string.byte (BASE64_ALPHABET , lo ), 8 ))
1514 end
1615
16+ -- prefill lookup table with invalid base64 value (0xFF)
17+ buffer.fill (BASE64_DECODING_LUT , 0 , 0xFF )
1718 for i = 1 , # BASE64_ALPHABET do
18- BASE64_DECODING_LUT [ string.byte (BASE64_ALPHABET , i )] = i - 1
19+ buffer.writeu8 ( BASE64_DECODING_LUT , string.byte (BASE64_ALPHABET , i ), i - 1 )
1920 end
2021end
2122
23+ @ native
2224local function encode (input_buffer : buffer ): buffer
2325 assert (typeof (input_buffer ) == "buffer" , "Expected input to be a buffer" )
2426
@@ -71,6 +73,7 @@ local function encode(input_buffer: buffer): buffer
7173 return output
7274end
7375
76+ @ native
7477local function decode (input_buffer : buffer ): buffer
7578 assert (typeof (input_buffer ) == "buffer" , "Expected input to be a buffer" )
7679
@@ -80,49 +83,56 @@ local function decode(input_buffer: buffer): buffer
8083 return buffer.create (0 )
8184 end
8285
83- local padding_size = 0
84- if input_length >= 2 and buffer.readu16 (input_buffer , input_length - 2 ) == 0x3D3D then
85- padding_size = 2
86- elseif input_length >= 1 and buffer.readu8 (input_buffer , input_length - 1 ) == 0x3D then
87- padding_size = 1
86+ -- strip padding (rfc-4648 section 3.3: the excess pad characters MAY also be ignored)
87+ while input_length > 0 and buffer.readu8 (input_buffer , input_length - 1 ) == 0x3D do
88+ input_length -= 1
89+ end
90+
91+ -- rfc-4648 section 3.3 forbids padding that isn't preceded by at least one Base64 digit
92+ if input_length == 0 then
93+ error ("Invalid base64 input" , 2 )
8894 end
8995
9096 -- get correct output size
91- local output_length = (( input_length / 4 ) * 3 ) - padding_size
97+ local output_length = (3 * input_length ) // 4
9298 local output = buffer.create (output_length )
93- local chunks = input_length // 4
94-
95- for chunk_idx = 1 , chunks do
96- local index = (chunk_idx - 1 ) * 4
97- local out_index = (chunk_idx - 1 ) * 3
98-
99- local value1 = BASE64_DECODING_LUT [buffer.readu8 (input_buffer , index )]
100- local value2 = BASE64_DECODING_LUT [buffer.readu8 (input_buffer , index + 1 )]
101- local value3 = BASE64_DECODING_LUT [buffer.readu8 (input_buffer , index + 2 )]
102- local value4 = BASE64_DECODING_LUT [buffer.readu8 (input_buffer , index + 3 )]
103-
104- local chunk = bit32.bor (bit32.lshift (value1 , 18 ), bit32.lshift (value2 , 12 ), bit32.lshift (value3 , 6 ), value4 )
105-
106- local character1 = bit32.rshift (chunk , 16 )
107- local character2 = bit32.band (bit32.rshift (chunk , 8 ), 0b11111111 )
108- local character3 = bit32.band (chunk , 0b11111111 )
10999
110- -- always write the first byte
111- if out_index < output_length then
112- buffer.writeu8 (output , out_index , character1 )
113- end
114-
115- -- write second byte if have space (+padding)
116- if out_index + 1 < output_length then
117- buffer.writeu8 (output , out_index + 1 , character2 )
100+ local read_offset = 0
101+ local write_offset = 0
102+ -- loop invariant: at least 4 bytes to write
103+ while write_offset + 4 <= output_length do
104+ local b4 = buffer.readu8 (BASE64_DECODING_LUT , buffer.readu8 (input_buffer , read_offset + 3 ))
105+ local b3 = buffer.readu8 (BASE64_DECODING_LUT , buffer.readu8 (input_buffer , read_offset + 2 ))
106+ local b2 = buffer.readu8 (BASE64_DECODING_LUT , buffer.readu8 (input_buffer , read_offset + 1 ))
107+ local b1 = buffer.readu8 (BASE64_DECODING_LUT , buffer.readu8 (input_buffer , read_offset ))
108+ if bit32.bor (b1 , b2 , b3 , b4 ) >= 64 then
109+ error ("Invalid base64 input" , 2 )
118110 end
111+ read_offset += 4
112+ -- u32 BE: [B1, B2, B3, 0] = b1<<26|b2<<20|b3<<14|b4<<8, trailing 0 will be overwritten next iteration
113+ buffer.writeu32 (output , write_offset , bit32.byteswap (b1 * 0x4000000 + b2 * 0x100000 + b3 * 0x4000 + b4 * 0x100 ))
114+ write_offset += 3
115+ end
119116
120- -- Write third byte if we have space (+padding)
121- if out_index + 2 < output_length then
122- buffer.writeu8 (output , out_index + 2 , character3 )
117+ local u24be , nbits = 0 , 0
118+ while read_offset < input_length do
119+ local b = buffer.readu8 (BASE64_DECODING_LUT , buffer.readu8 (input_buffer , read_offset ))
120+ read_offset += 1
121+ if b >= 64 then
122+ error ("Invalid base64 input" , 2 )
123123 end
124+ u24be = u24be * 0x40 + b
125+ nbits += 6
126+ end
127+ while nbits >= 8 do
128+ buffer.writeu8 (output , write_offset , bit32.rshift (u24be , nbits - 8 ))
129+ nbits -= 8
130+ write_offset += 1
131+ end
132+ -- 2 or 4 leftover bits must be zero
133+ if nbits == 6 or (nbits == 2 and bit32.btest (u24be , 0x03 )) or (nbits == 4 and bit32.btest (u24be , 0x0F )) then
134+ error ("Invalid base64 input" , 2 )
124135 end
125-
126136 return output
127137end
128138
0 commit comments