Support complex-to-real optimization #76

Open
1 of 3 tasks
Randl opened this issue Aug 17, 2024 · 11 comments

@Randl
Contributor

Randl commented Aug 17, 2024

@Randl
Contributor Author

Randl commented Aug 17, 2024

@patrick-kidger Specifically, the third point requires a design decision that I don't yet know how to realize. Should we add complex-to-real operator support to Lineax?

@patrick-kidger
Owner

I think JAX has a sharply limited notion of complex linearity, to be honest. They're pretty up-front about the fact that they're really just treating things as functions between vector spaces over the reals.

Thinking out loud: linearity is the property that f(a + λb) = f(a) + λf(b). If f maps between vector spaces over different fields then this isn't defined, and you need to start reaching for some generalised notion of linearity.
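To spell out why the C->R case in particular cannot be linear in the complex sense: view R^M as a subset of C^M and suppose f satisfied the homogeneity half of the definition above for every complex λ. Taking λ = i forces f to vanish. This is a short worked argument, not something from the JAX docs.

```latex
\text{Suppose } f:\mathbb{C}^N \to \mathbb{R}^M \subset \mathbb{C}^M \text{ satisfies } f(\lambda b) = \lambda f(b) \ \ \forall \lambda \in \mathbb{C}. \\
\text{Take } \lambda = i:\quad \underbrace{f(ib)}_{\in\,\mathbb{R}^M} = i\,\underbrace{f(b)}_{\in\,\mathbb{R}^M} \ \in\ i\,\mathbb{R}^M. \\
\mathbb{R}^M \cap i\,\mathbb{R}^M = \{0\} \ \Longrightarrow\ f(b) = 0 \quad \forall b \in \mathbb{C}^N.
```

So the only "C-linear" map C^N -> R^M is zero, and any useful notion has to be R-linearity, which is what JAX works with.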

We should follow JAX's lead on this one, I think, if we're going to interface well with the rest of its ecosystem.

@Randl
Contributor Author

Randl commented Aug 17, 2024

You are correct that a linear map between vector spaces over different fields doesn't make sense, certainly not when the field of the input has elements that are not contained in the field of the output.

If I understand correctly, jax.linearize is equivalent to the JVP, which, according to the JAX docs, is linear in c and d separately but not in c + id (the two notions agree for addition but not for multiplication by a complex constant). This is why nothing fails in Lineax, at least until you try to materialize the operator.
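A minimal sketch of this R-linear-but-not-C-linear behavior, assuming a toy function f(z) = sum |z|^2 chosen only for illustration (it is not from this thread). Scaling the tangent by a real number scales the linearized output, but scaling it by 1j does not multiply the output by 1j; it cannot, since the output is real.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Toy C^2 -> R map: squared Euclidean norm of a complex vector.
    return jnp.sum(jnp.real(z * jnp.conj(z)))


z0 = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
# f_jvp is the map dz -> 2 * Re(sum(conj(z0) * dz)); nothing is materialized.
_, f_jvp = jax.linearize(f, z0)

dz = jnp.array([0.5 + 0.0j, 0.0 + 1.0j])

# Real scalars commute with f_jvp: it is R-linear.
r_linear = bool(jnp.allclose(f_jvp(2.0 * dz), 2.0 * f_jvp(dz)))

# Complex scalars do not: f_jvp(dz) is -1.0 and f_jvp(1j * dz) is -4.0,
# whereas C-linearity would require 1j * f_jvp(dz), i.e. -1j.
print(f_jvp(dz), f_jvp(1j * dz), 1j * f_jvp(dz))
```

Because applying f_jvp only ever pushes tangents through JAX, operators built this way work fine until someone asks for a matrix.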

I'm not familiar with the internals, but building an R->R map under the hood makes the most sense to me.

@Randl
Contributor Author

Randl commented Aug 18, 2024

So, thinking out loud: the only place we are really affected is materialization. As long as we do not materialize the matrix, JAX's internals take care of everything for us.

When we want to materialize, we have a problem, since materializing a C->R operator as a complex matrix is impossible (it's not a complex-linear operator). Instead, we can materialize the real Mx2N matrix (M outputs, 2N real inputs) for the Jacobian of a C^N->R^M map, i.e. the Jacobian of the equivalent R^(2N)->R^M map. It would then be the user's responsibility to convert the inputs from C to R^2.
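A sketch of this materialization, assuming a hypothetical helper `as_real_function` (not part of JAX or Lineax) that packs a complex input as x = [Re z, Im z]. Since the wrapped function has a real input, `jax.jacfwd` accepts it and returns the real M x 2N Jacobian; the last lines show the user-side bookkeeping that makes it agree with the complex JVP.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Example C^2 -> R^3 map, chosen only for illustration.
    return jnp.stack([
        jnp.real(z[0] * jnp.conj(z[1])),
        jnp.imag(z[0]),
        jnp.abs(z[1]) ** 2,
    ])


def as_real_function(f, n):
    # Hypothetical helper: view f: C^n -> R^m as R^(2n) -> R^m via x = [Re z, Im z].
    def g(x):
        return f(x[:n] + 1j * x[n:])
    return g


z0 = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
n = z0.shape[0]
x0 = jnp.concatenate([jnp.real(z0), jnp.imag(z0)])

# Real input, so jacfwd is fine; J has shape (M, 2N) = (3, 4).
J = jax.jacfwd(as_real_function(f, n))(x0)

# User-side bookkeeping: a complex tangent dz must become [Re dz, Im dz].
dz = jnp.array([0.5 + 0.0j, 0.0 + 1.0j])
dx = jnp.concatenate([jnp.real(dz), jnp.imag(dz)])
_, jvp_out = jax.jvp(f, (z0,), (dz,))
print(J.shape, jnp.allclose(J @ dx, jvp_out, atol=1e-5))
```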

However, if we want a uniform interface, I think we should allow representing a C->R operator as an R^2->R operator for any operator class. One possible solution is a flag indicating that the inputs are complex even though the matrix is real; the operator would then perform the C->R^2 conversion internally.

So the first step would be to make Lineax C->R compatible:

  • Materialize functional operators (Jacobian, function) as R^2->R operators
  • Add a flag indicating that the R^2->R operator is in fact C->R
  • Adjust the behavior of the solvers correspondingly (return a C vector)
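A rough sketch of how the flag and the solver adjustment could fit together. `RealMatrixOperator` and `solve` are hypothetical stand-ins written for illustration, not Lineax's API, and the solver is plain `jnp.linalg.lstsq` rather than a Lineax solver.

```python
import jax.numpy as jnp
from typing import NamedTuple


class RealMatrixOperator(NamedTuple):
    """Hypothetical stand-in for a matrix-backed operator; not Lineax's API."""

    # Real matrix of shape (M, 2N) when complex_input is True, else (M, N).
    matrix: jnp.ndarray
    # The proposed flag: the operator's inputs are really complex.
    complex_input: bool = False

    def mv(self, v):
        if self.complex_input:
            # Convert C^N -> R^(2N) as [Re v, Im v] before the real matvec.
            v = jnp.concatenate([jnp.real(v), jnp.imag(v)])
        return self.matrix @ v


def solve(op, b):
    # Least-squares solve in the real representation (minimum norm if underdetermined).
    x, *_ = jnp.linalg.lstsq(op.matrix, b)
    if op.complex_input:
        # Repackage R^(2N) -> C^N so the caller gets a complex vector back.
        n = x.shape[0] // 2
        return x[:n] + 1j * x[n:]
    return x


# Usage: M = 2 real outputs, N = 2 complex inputs.
A = jnp.array([[1.0, 0.0, 0.0, 2.0],
               [0.0, 1.0, 1.0, 0.0]])
op = RealMatrixOperator(A, complex_input=True)
z = jnp.array([1.0 + 0.5j, -2.0 + 0.0j])
b = op.mv(z)
z_hat = solve(op, b)
```

Here the real system is underdetermined, so z_hat is the minimum-norm solution and need not equal z; it only reproduces b.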

Then, it should be possible to use those for optimization. What do you think?

@patrick-kidger
Owner

I like identifying materialization as the place where things go wrong. I think that's plausibly the main, or only, problem for us.

What does jax.jacfwd(some_complex_to_real_function) do here? This is the native JAX equivalent. We should be able to do whatever they do.

@Randl
Contributor Author

Randl commented Aug 18, 2024

JAX raises an error and suggests using jvp directly:
https://github.com/google/jax/blob/b957f8baab287f1a0e1e880b885f89b1f4272b50/jax/_src/api.py#L846-L850
I'm not sure that's an option for Lineax solvers, and therefore for Optimistix.
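A small sketch of both halves of this, assuming the toy function f(z) = |z|^2 (illustrative only). In the JAX versions I'm aware of, non-holomorphic `jax.jacfwd` rejects complex inputs with a `TypeError`; "using jvp directly" then amounts to pushing forward the two real directions 1 and 1j yourself.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Scalar C -> R example: |z|^2.
    return jnp.real(z * jnp.conj(z))


z0 = jnp.array(1.0 + 2.0j)

try:
    jax.jacfwd(f)(z0)
    raised = False
except TypeError as err:
    # Non-holomorphic jacfwd requires real inputs and points to jax.jvp instead.
    raised = True
    print(err)

# "Use jvp directly": one JVP per real direction of the complex input.
_, d_re = jax.jvp(f, (z0,), (jnp.array(1.0 + 0.0j),))
_, d_im = jax.jvp(f, (z0,), (jnp.array(0.0 + 1.0j),))
# [df/dRe z, df/dIm z], which is [2 Re z0, 2 Im z0] for this f.
real_jacobian_row = jnp.stack([d_re, d_im])
```

Doing this inside a solver means the solver itself has to know which inputs are complex, which is exactly the bookkeeping in question.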

@NeilGirdhar
Contributor

It would be really cool to have this in Optimistix!

> jax gives an error and suggests to use jvp directly:

That's funny, I think I actually added that to the error message in this pull request, which incidentally is about supporting heterogeneous pytrees (with complex and real values).

> Add flag indicating that the R^2->R operator is in fact C->R

I just wonder: why do you need the flag? Wouldn't it be better to support any pytree input? It might be worth taking a look at my pull request to see how I did it, and whether you can adapt my solution (or, ideally, call into it somehow) in Optimistix.

@Randl
Contributor Author

Randl commented Aug 27, 2024

I agree that support for arbitrary pytrees would be nice.
I'm not sure what exactly you propose to call, since we want an equivalent of jacfwd, which doesn't support C->R in JAX.
We probably do not need the flag explicitly, as the input structure already records which parts are complex.
What we then need is to convert the complex input into a real one in two places: when applying an operator stored as a matrix, and when materializing a Jacobian operator.
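A sketch of using the input structure in place of a flag, with hypothetical helpers `to_real`/`from_real` (not an existing Lineax or Optimistix API): complex leaves get a trailing (real, imag) axis, and the original pytree is passed back in to say which leaves to recombine. Splitting first also matters for flattening, since as far as I know `jax.flatten_util.ravel_pytree` promotes mixed leaves to a common dtype, so a single complex leaf would otherwise make the whole flat vector complex.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def to_real(tree):
    # Hypothetical: split every complex leaf into a trailing (real, imag) axis.
    def split(leaf):
        if jnp.iscomplexobj(leaf):
            return jnp.stack([jnp.real(leaf), jnp.imag(leaf)], axis=-1)
        return leaf
    return jax.tree_util.tree_map(split, tree)


def from_real(real_tree, like):
    # Hypothetical inverse; `like` records which leaves were complex, so no flag is needed.
    def merge(leaf, ref):
        if jnp.iscomplexobj(ref):
            return leaf[..., 0] + 1j * leaf[..., 1]
        return leaf
    return jax.tree_util.tree_map(merge, real_tree, like)


# Heterogeneous input: one complex leaf, one real leaf.
params = {"phase": jnp.array([1.0 + 2.0j, 0.5j]), "scale": jnp.array(3.0)}
flat, unravel = ravel_pytree(to_real(params))  # real vector of length 2 * 2 + 1
roundtrip = from_real(unravel(flat), params)
```

An operator stored as a real matrix would act on `flat`, and materializing a Jacobian would differentiate with respect to it, with `from_real` restoring complex leaves at the end.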

@NeilGirdhar
Contributor

NeilGirdhar commented Aug 27, 2024

> We do not probably need the flag explicitly, as input structure has the required information of which part is complex.

Totally agree!

> What we need is then during operator application in case of operator stored as matrix and during materialization of Jacobian operator convert the complex input into real one.

Maybe. Is it possible to fix jax.jacfwd? Then Patrick's wish "We should be able to do whatever they do" would come true?

@Randl
Contributor Author

Randl commented Aug 27, 2024

> Is it possible to fix jax.jacfwd?

That depends on your expectations of jacfwd. You can't expect to get a complex matrix that you multiply the tangent vector by to get the JVP (see the discussion of complex linearity above). We can make it output R^2 for each C input, but the bookkeeping of transforming the input is still on the user. I'm not sure JAX would want this change; I'd say exposing it to end users is controversial (I'm more OK with it for internal use).
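To make "R^2 for each C input" and the leftover bookkeeping concrete, here is a sketch with a hypothetical helper `jacfwd_c2r` (not part of JAX), built from `jax.vmap` over `jax.jvp`. The Jacobian gets a trailing axis of size 2 (derivative along Re and along Im of each input), and the caller still has to split a complex tangent into (Re, Im) before contracting.

```python
import jax
import jax.numpy as jnp


def jacfwd_c2r(f):
    # Hypothetical: Jacobian of f: C^N -> R^M with shape (M, N, 2),
    # where [..., 0] is d/dRe z_n and [..., 1] is d/dIm z_n.
    def jac(z):
        n = z.shape[0]
        eye = jnp.eye(n, dtype=z.dtype)
        # Tangent directions: unit real parts, then unit imaginary parts.
        tangents = jnp.concatenate([eye, 1j * eye])  # (2N, N)

        def push(t):
            return jax.jvp(f, (z,), (t,))[1]

        cols = jax.vmap(push)(tangents)  # (2N, M)
        return jnp.stack([cols[:n], cols[n:]], axis=-1).transpose(1, 0, 2)
    return jac


def f(z):
    # Example C^2 -> R^2 map, chosen only for illustration.
    return jnp.stack([jnp.real(z[0] * jnp.conj(z[1])), jnp.abs(z[1]) ** 2])


z0 = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
J = jacfwd_c2r(f)(z0)  # shape (2, 2, 2)

# The bookkeeping left to the user: split dz into (Re, Im) before contracting.
dz = jnp.array([0.5 + 0.0j, 0.0 + 1.0j])
dz_pair = jnp.stack([jnp.real(dz), jnp.imag(dz)], axis=-1)  # (N, 2)
via_matrix = jnp.einsum("mnk,nk->m", J, dz_pair)
_, exact = jax.jvp(f, (z0,), (dz,))
```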

@Randl
Contributor Author

Randl commented Sep 13, 2024

A related tutorial that discusses similar issues with JVPs of C->R functions:

https://arxiv.org/abs/2409.06752
