@@ -177,7 +177,7 @@ def from_tensordict_pair(
177177 collate_fn : Callable [[Any ], Any ] | None = None ,
178178 write_fn : Callable [[Any , Any ], Any ] | None = None ,
179179 consolidated : bool | None = None ,
180- ):
180+ ) -> TensorDictMap :
181181 """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
182182
183183 Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308308 if not self ._has_lazy_out_keys ():
309309 # TODO: make this work with pytrees and avoid calling select if keys match
310310 value = value .select (* self .out_keys , strict = False )
311+ item , value = self ._maybe_add_batch (item , value )
312+ index = self ._to_index (item , extend = True )
313+ if index .unique ().numel () < index .numel ():
314+ # If multiple values point to the same place in the storage, we cannot process them by batch
315+ # There could be a better way to deal with this, using unique ids.
316+ vals = []
317+ for it , val in zip (item .split (1 ), value .split (1 )):
318+ self [it ] = val
319+ vals .append (val )
320+ # __setitem__ may affect the content of the input data
321+ value .update (TensorDictBase .lazy_stack (vals ))
322+ return
311323 if self .write_fn is not None :
324+ # We use this block in the following context: the value written in the storage is already present,
325+ # but it needs to be updated.
326+ # We first check if the value is already there using `contains`. If so, we pass the new value and the
327+ # previous one to write_fn. The values that are not present are passed alone.
312328 if len (self ):
313329 modifiable = self .contains (item )
314330 if modifiable .any ():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322338 value = self .write_fn (value )
323339 else :
324340 value = self .write_fn (value )
325- item , value = self ._maybe_add_batch (item , value )
326- index = self ._to_index (item , extend = True )
327341 self .storage .set (index , value )
328342
329343 def __len__ (self ):
0 commit comments