@@ -126,7 +126,7 @@ void index_put__batch_rule(
126126 auto values_ = moveBatchDimToFront (values, values_bdim);
127127 TORCH_INTERNAL_ASSERT (indices.size () == indices_bdims.size ());
128128 std::vector<optional<Tensor>> indices_ = batchIndices (indices, indices_bdims, self_.size (0 ), self_bdim, values_bdim);
129- at::index_put_ (self_, List<optional<Tensor>>(indices_), values , accumulate);
129+ at::index_put_ (self_, List<optional<Tensor>>(indices_), values_ , accumulate);
130130}
131131
132132// plumbing done since we don't support List<optional<Tensor>> in codegen
@@ -158,6 +158,54 @@ Tensor& index_put__plumbing(Tensor & self, const List<optional<Tensor>> & indice
158158 return self;
159159}
160160
161+ void _index_put_impl__batch_rule (
162+ Tensor& self,
163+ optional<int64_t > self_bdim,
164+ ArrayRef<optional<Tensor>> indices,
165+ ArrayRef<optional<int64_t >> indices_bdims,
166+ const Tensor& values,
167+ optional<int64_t > values_bdim,
168+ bool accumulate,
169+ bool unsafe) {
170+ if (!self_bdim.has_value ()) {
171+ vmapIncompatibleInplaceError (" _index_put_impl_" );
172+ }
173+ auto self_ = moveBatchDimToFront (self, self_bdim);
174+ auto values_ = moveBatchDimToFront (values, values_bdim);
175+ TORCH_INTERNAL_ASSERT (indices.size () == indices_bdims.size ());
176+ std::vector<optional<Tensor>> indices_ = batchIndices (indices, indices_bdims, self_.size (0 ), self_bdim, values_bdim);
177+ at::_index_put_impl_ (self_, List<optional<Tensor>>(indices_), values_, accumulate, unsafe);
178+ }
179+
180+ // plumbing done since we don't support List<optional<Tensor>> in codegen
181+ Tensor& _index_put_impl__plumbing (Tensor & self, const List<optional<Tensor>> & indices
182+ , const Tensor & values, bool accumulate, bool unsafe) {
183+ c10::impl::ExcludeDispatchKeyGuard guard (kBatchedKey );
184+ auto maybe_layer = maybeCurrentDynamicLayer ();
185+ TORCH_INTERNAL_ASSERT (maybe_layer.has_value ());
186+ int64_t cur_level = maybe_layer->layerId ();
187+ Tensor self_value;
188+ optional<int64_t > self_bdim;
189+ std::tie (self_value, self_bdim) = unwrapTensorAtLevel (self, cur_level);
190+ std::vector<optional<Tensor>> indices_value;
191+ std::vector<optional<int64_t >> indices_bdims;
192+ for (const auto && indRef : indices) {
193+ optional<Tensor> ind = indRef;
194+ optional<Tensor> index;
195+ optional<int64_t > index_bdim;
196+ if (ind.has_value ()) {
197+ std::tie (index, index_bdim) = unwrapTensorAtLevel (ind.value (), cur_level);
198+ }
199+ indices_value.push_back (index);
200+ indices_bdims.push_back (index_bdim);
201+ }
202+ Tensor values_value;
203+ optional<int64_t > values_bdim;
204+ std::tie (values_value, values_bdim) = unwrapTensorAtLevel (values, cur_level);
205+ _index_put_impl__batch_rule (self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
206+ return self;
207+ }
208+
161209namespace {
162210
163211template <typename Func, typename ...Args>
@@ -496,6 +544,7 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
496544TORCH_LIBRARY_IMPL (aten, FT_BATCHED_KEY, m) {
497545 m.impl (" index.Tensor" , index_plumbing);
498546 m.impl (" index_put_" , index_put__plumbing);
547+ m.impl (" _index_put_impl_" , _index_put_impl__plumbing);
499548 m.impl (" slice_scatter" , slice_scatter_decomp);
500549 m.impl (" select_scatter" , select_scatter_decomp);
501550 m.impl (" index_copy" , index_copy_decomp);
0 commit comments