@@ -177,7 +177,7 @@ def from_tensordict_pair(
177177        collate_fn : Callable [[Any ], Any ] |  None  =  None ,
178178        write_fn : Callable [[Any , Any ], Any ] |  None  =  None ,
179179        consolidated : bool  |  None  =  None ,
180-     ):
180+     )  ->   TensorDictMap :
181181        """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. 
182182
183183        Args: 
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308308        if  not  self ._has_lazy_out_keys ():
309309            # TODO: make this work with pytrees and avoid calling select if keys match 
310310            value  =  value .select (* self .out_keys , strict = False )
311+         item , value  =  self ._maybe_add_batch (item , value )
312+         index  =  self ._to_index (item , extend = True )
313+         if  index .unique ().numel () <  index .numel ():
314+             # If multiple values point to the same place in the storage, we cannot process them by batch 
315+             # There could be a better way to deal with this, using unique ids. 
316+             vals  =  []
317+             for  it , val  in  zip (item .split (1 ), value .split (1 )):
318+                 self [it ] =  val 
319+                 vals .append (val )
320+             # __setitem__ may affect the content of the input data 
321+             value .update (TensorDictBase .lazy_stack (vals ))
322+             return 
311323        if  self .write_fn  is  not None :
324+             # We use this block in the following context: the value written in the storage is already present, 
325+             # but it needs to be updated. 
326+             # We first check if the value is already there using `contains`. If so, we pass the new value and the 
327+             # previous one to write_fn. The values that are not present are passed alone. 
312328            if  len (self ):
313329                modifiable  =  self .contains (item )
314330                if  modifiable .any ():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322338                    value  =  self .write_fn (value )
323339            else :
324340                value  =  self .write_fn (value )
325-         item , value  =  self ._maybe_add_batch (item , value )
326-         index  =  self ._to_index (item , extend = True )
327341        self .storage .set (index , value )
328342
329343    def  __len__ (self ):
0 commit comments