Issue with vmap when using lx.linear_solve on SparseMatrixOperator with multi-column RHS #53
Labels: question
I'm encountering a ValueError when using `vmap` to map the `lx.linear_solve` operation across the columns of a right-hand side (RHS) matrix with a `SparseMatrixOperator`. I expect the `_solve_beta` function, when passed a `SparseMatrixOperator` and a multi-column RHS, to solve the linear system for each column of the RHS, just as it does when a `MatrixLinearOperator` is used.
With a `SparseMatrixOperator`, the linear solve works fine for individual columns of `y`, but it fails when solving for multiple columns with `vmap`, producing the following error:

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ... (trimmed for brevity)
Here is a minimal working example specific to this problem: