Skip to content

Commit e39a161

Browse files
committed
Draft bitinfo codec
1 parent b429e4b commit e39a161

File tree

2 files changed

+273
-4
lines changed

2 files changed

+273
-4
lines changed

numcodecs/bitinfo.py

+263
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
import numpy as np
2+
3+
from .compat import ensure_ndarray_like
4+
from .bitround import BitRound
5+
6+
# The size in bits of the mantissa/significand for the various floating types
7+
# You cannot keep more bits of data than you have available
8+
# https://en.wikipedia.org/wiki/IEEE_754
9+
10+
NMBITS = {64: 12, 32: 9, 16: 6} # number of non mantissa bits for given dtype
11+
12+
class BitInfo(BitRound):
13+
"""Floating-point bit information codec
14+
15+
Drops bits from the floating point mantissa, leaving an array more amenable
16+
to compression. The number of bits to keep is determined using the approach
17+
from Klöwer et al. 2021 (https://www.nature.com/articles/s43588-021-00156-2).
18+
See https://github.com/zarr-developers/numcodecs/issues/298 for discussion
19+
and the original implementation in Julia referred to at
20+
https://github.com/milankl/BitInformation.jl
21+
22+
Parameters
23+
----------
24+
25+
inflevel: float
26+
The number of bits of the mantissa to keep. The range allowed
27+
depends on the dtype input data. If keepbits is
28+
equal to the maximum allowed for the data type, this is equivalent
29+
to no transform.
30+
31+
axes: int or list of int, optional
32+
Axes along which to calculate the bit information. If None, all axes
33+
are used.
34+
"""
35+
36+
codec_id = 'bitinfo'
37+
38+
def __init__(self, inflevel: float, axes=None):
39+
if (inflevel < 0) or (inflevel > 1.0):
40+
raise ValueError("Please provide `inflevel` from interval [0.,1.]")
41+
42+
self.inflevel = inflevel
43+
self.axes = axes
44+
45+
def encode(self, buf):
46+
"""Create int array by rounding floating-point data
47+
48+
The itemsize will be preserved, but the output should be much more
49+
compressible.
50+
"""
51+
a = ensure_ndarray_like(buf)
52+
if not a.dtype.kind == "f" or a.dtype.itemsize > 8:
53+
raise TypeError("Only float arrays (16-64bit) can be bit-rounded")
54+
55+
if self.axes is None:
56+
axes = range(a.ndim)
57+
58+
itemsize = a.dtype.itemsize
59+
astype = f"u{itemsize}"
60+
if a.dtype in (np.float16, np.float32, np.float64):
61+
a = signed_exponent(a)
62+
63+
a = a.astype(astype)
64+
keepbits = []
65+
66+
for ax in axes:
67+
info_per_bit = bitinformation(a, axis=ax)
68+
keepbits.append(get_keepbits(info_per_bit, self.inflevel))
69+
70+
keepbits = max(keepbits)
71+
72+
return BitRound._bitround(a, keepbits)
73+
74+
75+
def exponent_bias(dtype):
76+
"""
77+
Returns the exponent bias for a given floating-point dtype.
78+
79+
Example
80+
-------
81+
>>> exponent_bias("f4")
82+
127
83+
>>> exponent_bias("f8")
84+
1023
85+
"""
86+
info = np.finfo(dtype)
87+
exponent_bits = info.bits - info.nmant - 1
88+
return 2 ** (exponent_bits - 1) - 1
89+
90+
91+
def exponent_mask(dtype):
92+
"""
93+
Returns exponent mask for a given floating-point dtype.
94+
95+
Example
96+
-------
97+
>>> np.binary_repr(exponent_mask(np.float32), width=32)
98+
'01111111100000000000000000000000'
99+
>>> np.binary_repr(exponent_mask(np.float16), width=16)
100+
'0111110000000000'
101+
"""
102+
if dtype == np.float16:
103+
mask = 0x7C00
104+
elif dtype == np.float32:
105+
mask = 0x7F80_0000
106+
elif dtype == np.float64:
107+
mask = 0x7FF0_0000_0000_0000
108+
return mask
109+
110+
111+
def signed_exponent(A):
112+
"""
113+
Transform biased exponent notation to signed exponent notation.
114+
115+
Parameters
116+
----------
117+
A : :py:class:`numpy.array`
118+
Array to transform
119+
120+
Returns
121+
-------
122+
B : :py:class:`numpy.array`
123+
124+
Example
125+
-------
126+
>>> A = np.array(0.03125, dtype="float32")
127+
>>> np.binary_repr(A.view("uint32"), width=32)
128+
'00111101000000000000000000000000'
129+
>>> np.binary_repr(signed_exponent(A), width=32)
130+
'01000010100000000000000000000000'
131+
>>> A = np.array(0.03125, dtype="float64")
132+
>>> np.binary_repr(A.view("uint64"), width=64)
133+
'0011111110100000000000000000000000000000000000000000000000000000'
134+
>>> np.binary_repr(signed_exponent(A), width=64)
135+
'0100000001010000000000000000000000000000000000000000000000000000'
136+
"""
137+
itemsize = A.dtype.itemsize
138+
uinttype = f"u{itemsize}"
139+
inttype = f"i{itemsize}"
140+
141+
sign_mask = 1 << np.finfo(A.dtype).bits - 1
142+
sfmask = sign_mask | (1 << np.finfo(A.dtype).nmant) - 1
143+
emask = exponent_mask(A.dtype)
144+
esignmask = sign_mask >> 1
145+
146+
sbits = np.finfo(A.dtype).nmant
147+
if itemsize == 8:
148+
sbits = np.uint64(sbits)
149+
emask = np.uint64(emask)
150+
bias = exponent_bias(A.dtype)
151+
152+
ui = A.view(uinttype)
153+
sf = ui & sfmask
154+
e = ((ui & emask) >> sbits).astype(inttype) - bias
155+
max_eabs = np.iinfo(A.view(uinttype).dtype).max >> sbits
156+
eabs = abs(e) % (max_eabs + 1)
157+
esign = np.where(e < 0, esignmask, 0)
158+
if itemsize == 8:
159+
eabs = np.uint64(eabs)
160+
esign = np.uint64(esign)
161+
esigned = esign | (eabs << sbits)
162+
B = (sf | esigned).view(np.int64)
163+
return B
164+
165+
166+
def bitpaircount_u1(a, b):
167+
assert a.dtype == "u1"
168+
assert b.dtype == "u1"
169+
unpack_a = np.unpackbits(a.flatten()).astype("u1")
170+
unpack_b = np.unpackbits(b.flatten()).astype("u1")
171+
172+
index = ((unpack_a << 1) | unpack_b).reshape(-1, 8)
173+
174+
selection = np.array([0, 1, 2, 3], dtype="u1")
175+
sel = np.where((index[..., np.newaxis]) == selection, True, False)
176+
return sel.sum(axis=0).reshape([8, 2, 2])
177+
178+
179+
def bitpaircount(a, b):
180+
assert a.dtype.kind == "u"
181+
assert b.dtype.kind == "u"
182+
nbytes = max(a.dtype.itemsize, b.dtype.itemsize)
183+
184+
a, b = np.broadcast_arrays(a, b)
185+
186+
bytewise_counts = []
187+
for i in range(nbytes):
188+
s = (nbytes - 1 - i) * 8
189+
bitc = bitpaircount_u1((a >> s).astype("u1"), (b >> s).astype("u1"))
190+
bytewise_counts.append(bitc)
191+
return np.concatenate(bytewise_counts, axis=0)
192+
193+
194+
def mutual_information(a, b, base=2):
195+
"""Calculate the mutual information between two arrays.
196+
"""
197+
size = np.prod(np.broadcast_shapes(a.shape, b.shape))
198+
counts = bitpaircount(a, b)
199+
200+
p = counts.astype("float") / size
201+
p = np.ma.masked_equal(p, 0)
202+
pr = p.sum(axis=-1)[..., np.newaxis]
203+
ps = p.sum(axis=-2)[..., np.newaxis, :]
204+
mutual_info = (p * np.ma.log(p / (pr * ps))).sum(axis=(-1, -2)) / np.log(base)
205+
return mutual_info
206+
207+
208+
def bitinformation(a, axis=0):
209+
"""Get the information content of each bit in the array.
210+
211+
Parameters
212+
----------
213+
a : array
214+
Array to calculate the bit information.
215+
axis : int
216+
Axis along which to calculate the bit information.
217+
218+
Returns
219+
-------
220+
info_per_bit : array
221+
"""
222+
sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape)))
223+
sb = tuple(
224+
slice(1, None) if i == axis else slice(None) for i in range(len(a.shape))
225+
)
226+
return mutual_information(a[sa], a[sb])
227+
228+
229+
def get_keepbits(info_per_bit, inflevel=0.99):
230+
"""Get the number of mantissa bits to keep.
231+
232+
Parameters
233+
----------
234+
info_per_bit : array
235+
Information content of each bit from `get_bitinformation`.
236+
237+
inflevel : float
238+
Level of information that shall be preserved.
239+
240+
Returns
241+
-------
242+
keepbits : int
243+
Number of mantissa bits to keep
244+
245+
"""
246+
if (inflevel < 0) or (inflevel > 1.0):
247+
raise ValueError("Please provide `inflevel` from interval [0.,1.]")
248+
249+
cdf = _cdf_from_info_per_bit(info_per_bit)
250+
bitdim_non_mantissa_bits = NMBITS[len(info_per_bit)]
251+
keepmantissabits = (
252+
(cdf > inflevel).argmax() + 1 - bitdim_non_mantissa_bits
253+
)
254+
255+
return keepmantissabits
256+
257+
258+
def _cdf_from_info_per_bit(info_per_bit):
259+
"""Convert info_per_bit to cumulative distribution function"""
260+
tol = info_per_bit[-4:].max() * 1.5
261+
info_per_bit[info_per_bit < tol] = 0
262+
cdf = info_per_bit.cumsum()
263+
return cdf / cdf[-1]

numcodecs/bitround.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,20 @@ def encode(self, buf):
5454
raise TypeError("Only float arrays (16-64bit) can be bit-rounded")
5555
bits = max_bits[str(a.dtype)]
5656
# cast float to int type of same width (preserve endianness)
57-
a_int_dtype = np.dtype(a.dtype.str.replace("f", "i"))
58-
all_set = np.array(-1, dtype=a_int_dtype)
5957
if self.keepbits == bits:
6058
return a
6159
if self.keepbits > bits:
6260
raise ValueError("Keepbits too large for given dtype")
63-
b = a.view(a_int_dtype)
64-
maskbits = bits - self.keepbits
61+
62+
return self._bitround(a, self.keepbits)
63+
64+
@staticmethod
65+
def _bitround(buf, keepbits):
66+
bits = max_bits[str(buf.dtype)]
67+
a_int_dtype = np.dtype(buf.dtype.str.replace("f", "i"))
68+
all_set = np.array(-1, dtype=a_int_dtype)
69+
b = buf.view(a_int_dtype)
70+
maskbits = bits - keepbits
6571
mask = (all_set >> maskbits) << maskbits
6672
half_quantum1 = (1 << (maskbits - 1)) - 1
6773
b += ((b >> maskbits) & 1) + half_quantum1

0 commit comments

Comments
 (0)