@@ -729,6 +729,7 @@ def attribute_future(
729729 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
730730 perturbations_per_eval : int = 1 ,
731731 show_progress : bool = False ,
732+ enable_cross_tensor_attribution : bool = False ,
732733 ** kwargs : Any ,
733734 ) -> Future [TensorOrTupleOfTensorsGeneric ]:
734735 r"""
@@ -743,17 +744,18 @@ def attribute_future(
743744 formatted_additional_forward_args = _format_additional_forward_args (
744745 additional_forward_args
745746 )
746- num_examples = formatted_inputs [0 ].shape [0 ]
747747 formatted_feature_mask = _format_feature_mask (feature_mask , formatted_inputs )
748748
749749 assert (
750750 isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
751751 ), "Perturbations per evaluation must be an integer and at least 1."
752752 with torch .no_grad ():
753+ attr_progress = None
753754 if show_progress :
754755 attr_progress = self ._attribute_progress_setup (
755756 formatted_inputs ,
756757 formatted_feature_mask ,
758+ enable_cross_tensor_attribution ,
757759 ** kwargs ,
758760 perturbations_per_eval = perturbations_per_eval ,
759761 )
@@ -788,101 +790,135 @@ def attribute_future(
788790 )
789791 )
790792
791- # The will be the same amount futures as modified_eval down there,
792- # since we cannot add up the evaluation result adhoc under async mode.
793- all_modified_eval_futures : List [
794- List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
795- ] = [[] for _ in range (len (inputs ))]
796- # Iterate through each feature tensor for ablation
797- for i in range (len (formatted_inputs )):
798- # Skip any empty input tensors
799- if torch .numel (formatted_inputs [i ]) == 0 :
800- continue
801-
802- for (
803- current_inputs ,
804- current_add_args ,
805- current_target ,
806- current_mask ,
807- ) in self ._ith_input_ablation_generator (
808- i ,
793+ if enable_cross_tensor_attribution :
794+ raise NotImplementedError ("Not supported yet" )
795+ else :
796+ # pyre-fixme[7]: Expected`` Future[Variable[TensorOrTupleOfTensorsGeneric <:
797+ # [Tensor, typing.Tuple[Tensor, ...]]]]` but got `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
798+ return self ._attribute_with_independent_feature_masks_future (
809799 formatted_inputs ,
810800 formatted_additional_forward_args ,
811801 target ,
812802 baselines ,
813803 formatted_feature_mask ,
814804 perturbations_per_eval ,
805+ attr_progress ,
806+ processed_initial_eval_fut ,
807+ is_inputs_tuple ,
815808 ** kwargs ,
816- ):
817- # modified_eval has (n_feature_perturbed * n_outputs) elements
818- # shape:
819- # agg mode: (*initial_eval.shape)
820- # non-agg mode:
821- # (feature_perturbed * batch_size, *initial_eval.shape[1:])
822- modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
823- self .forward_func ,
824- current_inputs ,
825- current_target ,
826- current_add_args ,
827- )
809+ )
828810
829- if show_progress :
830- attr_progress .update ()
811+ def _attribute_with_independent_feature_masks_future (
812+ self ,
813+ formatted_inputs : Tuple [Tensor , ...],
814+ formatted_additional_forward_args : Optional [Tuple [object , ...]],
815+ target : TargetType ,
816+ baselines : BaselineType ,
817+ formatted_feature_mask : Tuple [Tensor , ...],
818+ perturbations_per_eval : int ,
819+ attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
820+ processed_initial_eval_fut : Future [
821+ Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]
822+ ],
823+ is_inputs_tuple : bool ,
824+ ** kwargs : Any ,
825+ ) -> Future [Tensor | Tuple [Tensor , ...]]:
826+ num_examples = formatted_inputs [0 ].shape [0 ]
827+ # The will be the same amount futures as modified_eval down there,
828+ # since we cannot add up the evaluation result adhoc under async mode.
829+ all_modified_eval_futures : List [
830+ List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
831+ ] = [[] for _ in range (len (formatted_inputs ))]
832+ # Iterate through each feature tensor for ablation
833+ for i in range (len (formatted_inputs )):
834+ # Skip any empty input tensors
835+ if torch .numel (formatted_inputs [i ]) == 0 :
836+ continue
831837
832- if not isinstance (modified_eval , torch .Future ):
833- raise AssertionError (
834- "when using attribute_future, modified_eval should have "
835- f"Future type rather than { type (modified_eval )} "
836- )
837- if processed_initial_eval_fut is None :
838- raise AssertionError (
839- "processed_initial_eval_fut should not be None"
840- )
838+ for (
839+ current_inputs ,
840+ current_add_args ,
841+ current_target ,
842+ current_mask ,
843+ ) in self ._ith_input_ablation_generator (
844+ i ,
845+ formatted_inputs ,
846+ formatted_additional_forward_args ,
847+ target ,
848+ baselines ,
849+ formatted_feature_mask ,
850+ perturbations_per_eval ,
851+ ** kwargs ,
852+ ):
853+ # modified_eval has (n_feature_perturbed * n_outputs) elements
854+ # shape:
855+ # agg mode: (*initial_eval.shape)
856+ # non-agg mode:
857+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
858+ modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
859+ self .forward_func ,
860+ current_inputs ,
861+ current_target ,
862+ current_add_args ,
863+ )
864+
865+ if attr_progress is not None :
866+ attr_progress .update ()
867+
868+ if not isinstance (modified_eval , torch .Future ):
869+ raise AssertionError (
870+ "when using attribute_future, modified_eval should have "
871+ f"Future type rather than { type (modified_eval )} "
872+ )
873+ if processed_initial_eval_fut is None :
874+ raise AssertionError (
875+ "processed_initial_eval_fut should not be None"
876+ )
841877
842- # Need to collect both initial eval and modified_eval
843- eval_futs : Future [
844- List [
845- Future [
846- Union [
847- Tuple [
848- List [Tensor ],
849- List [Tensor ],
850- Tensor ,
851- Tensor ,
852- int ,
853- dtype ,
854- ],
878+ # Need to collect both initial eval and modified_eval
879+ eval_futs : Future [
880+ List [
881+ Future [
882+ Union [
883+ Tuple [
884+ List [Tensor ],
885+ List [Tensor ],
886+ Tensor ,
855887 Tensor ,
856- ]
888+ int ,
889+ dtype ,
890+ ],
891+ Tensor ,
857892 ]
858893 ]
859- ] = collect_all (
860- [
861- processed_initial_eval_fut ,
862- modified_eval ,
863- ]
864- )
894+ ]
895+ ] = collect_all (
896+ [
897+ processed_initial_eval_fut ,
898+ modified_eval ,
899+ ]
900+ )
865901
866- ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
867- eval_futs .then (
868- lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
869- eval_futs = eval_futs ,
870- current_inputs = current_inputs ,
871- current_mask = current_mask ,
872- i = i ,
873- perturbations_per_eval = perturbations_per_eval ,
874- num_examples = num_examples ,
875- formatted_inputs = formatted_inputs ,
876- )
902+ ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
903+ eval_futs .then (
904+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
905+ eval_futs = eval_futs ,
906+ current_inputs = current_inputs ,
907+ current_mask = current_mask ,
908+ i = i ,
909+ perturbations_per_eval = perturbations_per_eval ,
910+ num_examples = num_examples ,
911+ formatted_inputs = formatted_inputs ,
877912 )
878913 )
914+ )
879915
880- all_modified_eval_futures [i ].append (ablated_out_fut )
916+ all_modified_eval_futures [i ].append (ablated_out_fut )
881917
882- if show_progress :
883- attr_progress .close ()
918+ if attr_progress is not None :
919+ attr_progress .close ()
884920
885- return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
921+ return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
886922
887923 # pyre-fixme[3] return type must be annotated
888924 def _attribute_progress_setup (
0 commit comments