@@ -729,6 +729,7 @@ def attribute_future(
729729 feature_mask : Union [None , Tensor , Tuple [Tensor , ...]] = None ,
730730 perturbations_per_eval : int = 1 ,
731731 show_progress : bool = False ,
732+ enable_cross_tensor_attribution : bool = False ,
732733 ** kwargs : Any ,
733734 ) -> Future [TensorOrTupleOfTensorsGeneric ]:
734735 r"""
@@ -743,17 +744,18 @@ def attribute_future(
743744 formatted_additional_forward_args = _format_additional_forward_args (
744745 additional_forward_args
745746 )
746- num_examples = formatted_inputs [0 ].shape [0 ]
747747 formatted_feature_mask = _format_feature_mask (feature_mask , formatted_inputs )
748748
749749 assert (
750750 isinstance (perturbations_per_eval , int ) and perturbations_per_eval >= 1
751751 ), "Perturbations per evaluation must be an integer and at least 1."
752752 with torch .no_grad ():
753+ attr_progress = None
753754 if show_progress :
754755 attr_progress = self ._attribute_progress_setup (
755756 formatted_inputs ,
756757 formatted_feature_mask ,
758+ enable_cross_tensor_attribution ,
757759 ** kwargs ,
758760 perturbations_per_eval = perturbations_per_eval ,
759761 )
@@ -768,7 +770,7 @@ def attribute_future(
768770 formatted_additional_forward_args ,
769771 )
770772
771- if show_progress :
773+ if attr_progress is not None :
772774 attr_progress .update ()
773775
774776 processed_initial_eval_fut : Optional [
@@ -788,101 +790,136 @@ def attribute_future(
788790 )
789791 )
790792
791- # The will be the same amount futures as modified_eval down there,
792- # since we cannot add up the evaluation result adhoc under async mode.
793- all_modified_eval_futures : List [
794- List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
795- ] = [[] for _ in range (len (inputs ))]
796- # Iterate through each feature tensor for ablation
797- for i in range (len (formatted_inputs )):
798- # Skip any empty input tensors
799- if torch .numel (formatted_inputs [i ]) == 0 :
800- continue
801-
802- for (
803- current_inputs ,
804- current_add_args ,
805- current_target ,
806- current_mask ,
807- ) in self ._ith_input_ablation_generator (
808- i ,
793+ if enable_cross_tensor_attribution :
794+ raise NotImplementedError ("Not supported yet" )
795+ else :
796+ # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
797+ # <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
798+ # `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
799+ return self ._attribute_with_independent_feature_masks_future ( # type: ignore # noqa: E501 line too long
809800 formatted_inputs ,
810801 formatted_additional_forward_args ,
811802 target ,
812803 baselines ,
813804 formatted_feature_mask ,
814805 perturbations_per_eval ,
806+ attr_progress ,
807+ processed_initial_eval_fut ,
808+ is_inputs_tuple ,
815809 ** kwargs ,
816- ):
817- # modified_eval has (n_feature_perturbed * n_outputs) elements
818- # shape:
819- # agg mode: (*initial_eval.shape)
820- # non-agg mode:
821- # (feature_perturbed * batch_size, *initial_eval.shape[1:])
822- modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
823- self .forward_func ,
824- current_inputs ,
825- current_target ,
826- current_add_args ,
827- )
810+ )
828811
829- if show_progress :
830- attr_progress .update ()
812+ def _attribute_with_independent_feature_masks_future (
813+ self ,
814+ formatted_inputs : Tuple [Tensor , ...],
815+ formatted_additional_forward_args : Optional [Tuple [object , ...]],
816+ target : TargetType ,
817+ baselines : BaselineType ,
818+ formatted_feature_mask : Tuple [Tensor , ...],
819+ perturbations_per_eval : int ,
820+ attr_progress : Optional [Union [SimpleProgress [IterableType ], tqdm ]],
821+ processed_initial_eval_fut : Future [
822+ Tuple [List [Tensor ], List [Tensor ], Tensor , Tensor , int , dtype ]
823+ ],
824+ is_inputs_tuple : bool ,
825+ ** kwargs : Any ,
826+ ) -> Future [Union [Tensor , Tuple [Tensor , ...]]]:
827+ num_examples = formatted_inputs [0 ].shape [0 ]
828+ # The will be the same amount futures as modified_eval down there,
829+ # since we cannot add up the evaluation result adhoc under async mode.
830+ all_modified_eval_futures : List [
831+ List [Future [Tuple [List [Tensor ], List [Tensor ]]]]
832+ ] = [[] for _ in range (len (formatted_inputs ))]
833+ # Iterate through each feature tensor for ablation
834+ for i in range (len (formatted_inputs )):
835+ # Skip any empty input tensors
836+ if torch .numel (formatted_inputs [i ]) == 0 :
837+ continue
831838
832- if not isinstance (modified_eval , torch .Future ):
833- raise AssertionError (
834- "when using attribute_future, modified_eval should have "
835- f"Future type rather than { type (modified_eval )} "
836- )
837- if processed_initial_eval_fut is None :
838- raise AssertionError (
839- "processed_initial_eval_fut should not be None"
840- )
839+ for (
840+ current_inputs ,
841+ current_add_args ,
842+ current_target ,
843+ current_mask ,
844+ ) in self ._ith_input_ablation_generator (
845+ i ,
846+ formatted_inputs ,
847+ formatted_additional_forward_args ,
848+ target ,
849+ baselines ,
850+ formatted_feature_mask ,
851+ perturbations_per_eval ,
852+ ** kwargs ,
853+ ):
854+ # modified_eval has (n_feature_perturbed * n_outputs) elements
855+ # shape:
856+ # agg mode: (*initial_eval.shape)
857+ # non-agg mode:
858+ # (feature_perturbed * batch_size, *initial_eval.shape[1:])
859+ modified_eval : Union [Tensor , Future [Tensor ]] = _run_forward (
860+ self .forward_func ,
861+ current_inputs ,
862+ current_target ,
863+ current_add_args ,
864+ )
841865
842- # Need to collect both initial eval and modified_eval
843- eval_futs : Future [
844- List [
845- Future [
846- Union [
847- Tuple [
848- List [Tensor ],
849- List [Tensor ],
850- Tensor ,
851- Tensor ,
852- int ,
853- dtype ,
854- ],
866+ if attr_progress is not None :
867+ attr_progress .update ()
868+
869+ if not isinstance (modified_eval , torch .Future ):
870+ raise AssertionError (
871+ "when using attribute_future, modified_eval should have "
872+ f"Future type rather than { type (modified_eval )} "
873+ )
874+ if processed_initial_eval_fut is None :
875+ raise AssertionError (
876+ "processed_initial_eval_fut should not be None"
877+ )
878+
879+ # Need to collect both initial eval and modified_eval
880+ eval_futs : Future [
881+ List [
882+ Future [
883+ Union [
884+ Tuple [
885+ List [Tensor ],
886+ List [Tensor ],
887+ Tensor ,
855888 Tensor ,
856- ]
889+ int ,
890+ dtype ,
891+ ],
892+ Tensor ,
857893 ]
858894 ]
859- ] = collect_all (
860- [
861- processed_initial_eval_fut ,
862- modified_eval ,
863- ]
864- )
895+ ]
896+ ] = collect_all (
897+ [
898+ processed_initial_eval_fut ,
899+ modified_eval ,
900+ ]
901+ )
865902
866- ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
867- eval_futs .then (
868- lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
869- eval_futs = eval_futs ,
870- current_inputs = current_inputs ,
871- current_mask = current_mask ,
872- i = i ,
873- perturbations_per_eval = perturbations_per_eval ,
874- num_examples = num_examples ,
875- formatted_inputs = formatted_inputs ,
876- )
903+ ablated_out_fut : Future [Tuple [List [Tensor ], List [Tensor ]]] = (
904+ eval_futs .then (
905+ lambda eval_futs , current_inputs = current_inputs , current_mask = current_mask , i = i : self ._eval_fut_to_ablated_out_fut ( # type: ignore # noqa: E501 line too long
906+ eval_futs = eval_futs ,
907+ current_inputs = current_inputs ,
908+ current_mask = current_mask ,
909+ i = i ,
910+ perturbations_per_eval = perturbations_per_eval ,
911+ num_examples = num_examples ,
912+ formatted_inputs = formatted_inputs ,
877913 )
878914 )
915+ )
879916
880- all_modified_eval_futures [i ].append (ablated_out_fut )
917+ all_modified_eval_futures [i ].append (ablated_out_fut )
881918
882- if show_progress :
883- attr_progress .close ()
919+ if attr_progress is not None :
920+ attr_progress .close ()
884921
885- return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
922+ return self ._generate_async_result (all_modified_eval_futures , is_inputs_tuple ) # type: ignore # noqa: E501 line too long
886923
887924 # pyre-fixme[3] return type must be annotated
888925 def _attribute_progress_setup (
0 commit comments