11# pylint: disable=too-many-branches
22"""Utility function for variable selection and bart interpretability."""
33
4+ import base64
45import warnings
56from collections .abc import Callable
67from typing import Any , TypeVar
@@ -708,7 +709,7 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
708709 """
709710 n_vars = X .shape [1 ]
710711 vi_xarray = idata ["sample_stats" ]["variable_inclusion" ]
711- if " variable_inclusion_dim_0" in vi_xarray . coords :
712+ if vi_xarray . variable_inclusion_dim_0 . size > 1 :
712713 if model is None or bart_var_name is None :
713714 raise ValueError (
714715 "The InfereceData was generated from a model with multiple BART variables, \n "
@@ -727,13 +728,13 @@ def get_variable_inclusion(idata, X, model=None, bart_var_name=None, labels=None
727728 n_vars = len (indices )
728729
729730 if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
730- labels = list (X .columns )
731+ labels = list (X .columns [ indices ] )
731732
732733 if labels is None :
733- labels = [str (i ) for i in range ( n_vars ) ]
734+ labels = [str (i ) for i in indices ]
734735
735736 if to_kulprit :
736- return [labels [:idx ] for idx in range (n_vars )]
737+ return [labels [:idx ] for idx in range (n_vars + 1 )]
737738 else :
738739 return VI_norm [indices ], labels
739740
@@ -884,7 +885,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
884885
885886 if method in ["VI" , "backward_VI" ]:
886887 vi_xarray = idata ["sample_stats" ]["variable_inclusion" ]
887- if " variable_inclusion_dim_0" in vi_xarray . coords :
888+ if vi_xarray . variable_inclusion_dim_0 . size > 1 :
888889 if model is None :
889890 raise ValueError (
890891 "The InfereceData was generated from a model with multiple BART variables, \n "
@@ -968,7 +969,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
968969
969970 # Save values for plotting later
970971 r2_mean [i_var - init ] = max_r_2
971- r2_hdi [i_var - init ] = array_stats .hdi (r_2_without_least_important_vars )
972+ r2_hdi [i_var - init ] = array_stats .hdi (
973+ r_2_without_least_important_vars , prob = rcParams ["stats.ci_prob" ]
974+ )
972975 preds [i_var - init ] = least_important_samples .squeeze ()
973976
974977 # extend current list of least important variable
@@ -1282,37 +1285,34 @@ def _plot_hdi(x, y, smooth, color, alpha, smooth_kwargs, ax):
12821285 return ax
12831286
12841287
def _decode_vi(s: str, length: int) -> list[int]:
    """
    Decode a base64 string produced by ``_encode_vi`` back into a list of integers.

    The base64 payload is a sequence of little-endian base-128 varints: every
    byte carries 7 value bits, and the high bit (``0x80``) signals that more
    bytes of the same number follow.

    Parameters
    ----------
    s : str
        Base64-encoded payload.
    length : int
        Maximum number of integers to decode. Decoding stops as soon as this
        many values have been read, even if bytes remain.

    Returns
    -------
    list[int]
        The decoded values. May hold fewer than ``length`` items when the
        payload contains fewer complete numbers.

    Raises
    ------
    ValueError
        If the payload ends in the middle of a number (the last byte still has
        its continuation bit set), which indicates truncated or corrupt data.
    binascii.Error
        If ``s`` is not valid base64 (raised by ``base64.b64decode``).
    """
    data = base64.b64decode(s)
    result: list[int] = []
    num = 0
    shift = 0
    for byte in data:
        if len(result) >= length:
            break
        # Low 7 bits are payload, least-significant group first.
        num |= (byte & 0x7F) << shift
        if byte & 0x80:
            # Continuation bit set: the next byte holds higher-order bits.
            shift += 7
        else:
            result.append(num)
            num = 0
            shift = 0
    # ``shift`` is only non-zero here when the data ran out mid-number; the
    # early ``break`` above can only trigger right after a number completed.
    if shift:
        raise ValueError("Truncated variable inclusion encoding: incomplete trailing value.")
    return result
13031305
13041306
def _encode_vi(vec: list[int]) -> str:
    """
    Encode a vector of non-negative integers into a compact base64 string.

    Each value is written as a little-endian base-128 varint: 7 value bits per
    byte, with the high bit (``0x80``) set on every byte except the last one of
    a number. The resulting bytes are then base64-encoded so the payload is
    plain ASCII text.

    Parameters
    ----------
    vec : list[int]
        Values to encode. Any iterable of integer-like items (including NumPy
        integer scalars) is accepted.

    Returns
    -------
    str
        ASCII base64 string; the empty string for an empty ``vec``.

    Raises
    ------
    ValueError
        If an element is negative or not an integral value. Such values cannot
        be represented by this encoding and would otherwise be silently
        corrupted (for example ``-1`` would round-trip as ``127``).
    """
    result = bytearray()
    for num in vec:
        n = int(num)
        if n != num or n < 0:
            raise ValueError(
                f"Variable inclusion values must be non-negative integers, got {num!r}."
            )
        # Emit 7 bits at a time, flagging every byte that has a successor.
        while n > 0x7F:
            result.append((n & 0x7F) | 0x80)
            n >>= 7
        result.append(n)
    return base64.b64encode(bytes(result)).decode("ascii")
0 commit comments