@@ -85,7 +85,7 @@ def __init__(
8585 * ,
8686 input_values : List [TInputValue ],
8787 target_names : List [str ],
88- target_values : Optional [npt .ArrayLike ] = None ,
88+ target_values : Optional [Union [ npt .ArrayLike , List [ TTargetValue ]] ] = None ,
8989 aggregate_attr : npt .ArrayLike ,
9090 element_attr : Optional [npt .ArrayLike ] = None ,
9191 aggregate_descriptor : str = "Aggregate" ,
@@ -104,11 +104,11 @@ def aggregate_attr(self) -> Tensor:
104104 return self ._aggregate_attr
105105
106106 @aggregate_attr .setter
107- def aggregate_attr (self , seq_attr : npt .ArrayLike ) -> None :
108- if isinstance (seq_attr , Tensor ):
109- self ._aggregate_attr = seq_attr
107+ def aggregate_attr (self , aggregate_attr : npt .ArrayLike ) -> None :
108+ if isinstance (aggregate_attr , Tensor ):
109+ self ._aggregate_attr = aggregate_attr
110110 else :
111- self ._aggregate_attr = torch .tensor (seq_attr )
111+ self ._aggregate_attr = torch .tensor (aggregate_attr )
112112 # IDEA: in the future we might want to support higher dim seq_attr
113113 # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
114114 assert len (self ._aggregate_attr .shape ) == 1 , "seq_attr must be a 1D tensor"
@@ -121,13 +121,13 @@ def element_attr(self) -> Optional[Tensor]:
121121 return self ._element_attr
122122
123123 @element_attr .setter
124- def element_attr (self , token_attr : Optional [npt .ArrayLike ]) -> None :
125- if token_attr is None :
124+ def element_attr (self , element_attr : Optional [npt .ArrayLike ]) -> None :
125+ if element_attr is None :
126126 self ._element_attr = None
127- elif isinstance (token_attr , Tensor ):
128- self ._element_attr = token_attr
127+ elif isinstance (element_attr , Tensor ):
128+ self ._element_attr = element_attr
129129 else :
130- self ._element_attr = torch .tensor (token_attr )
130+ self ._element_attr = torch .tensor (element_attr )
131131
132132 if self ._element_attr is not None :
133133 # IDEA: in the future we might want to support higher dim seq_attr
@@ -146,7 +146,9 @@ def target_values(self) -> Optional[List[TTargetValue]]:
146146 return self ._target_values
147147
148148 @target_values .setter
149- def target_values (self , target_values : Optional [npt .ArrayLike ]) -> None :
149+ def target_values (
150+ self , target_values : Optional [Union [npt .ArrayLike , List [TTargetValue ]]]
151+ ) -> None :
150152 if target_values is None :
151153 self ._target_values = None
152154 elif isinstance (target_values , (Tensor , np .ndarray )):
0 commit comments