
Why are step_ts and jump_ts treated differently here? #483

Open
andyElking opened this issue Aug 10, 2024 · 10 comments
Labels
question User queries

Comments

@andyElking
Contributor

andyElking commented Aug 10, 2024

Hi Patrick,

Am I correct in saying that the only differences between step_ts and jump_ts are the following:

  • jump_ts cannot be integers, but must be floats (so that you can do prevbefore and nextafter)
  • _clip_jump_ts also returns made_jump, which is used to determine whether we need to do _t1 = nextafter(nextafter(t1)).

Beyond those discrepancies, Diffrax seems to treat them differently in one more way, which I am not sure I understand. Namely, the line below uses prev_dt=prev_dt if the step was clipped due to a jump, but prev_dt=t1-t0 if it was clipped due to step_ts. I don't see why we should distinguish between these two cases.

prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)

I would go even further and say that the line should just read prev_dt=t1-t0 in all cases. This is because the error of the current step depends on t1-t0, rather than on prev_dt, so keeping prev_dt in the controller state seems unnecessary. Here is what could go wrong with the current setup:

Say prev_dt=0.1, but due to jump_ts the step was clipped to t1-t0 = 0.01. Also assume the error was large, the step gets rejected, and the controller computes factor=0.5. Then the next step-size proposal will be 0.05, which is bigger than the step just taken, so it will again be clipped by jump_ts to 0.01, resulting in an infinite loop. Instead, the new step-size proposal should just be (t1-t0)*factor = 0.005, which would presumably result in a smaller error and let the solver move forward.

On the other hand, if the step was clipped to a much smaller size than intended (i.e. t1-t0 << prev_dt), then this will usually be reflected in a correspondingly small error, resulting in a large factor. This means that (t1-t0)*factor would again be a reasonably large step-size proposal, whereas prev_dt*factor would be disproportionately massive.
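To make this concrete, here is a plain-Python sketch of the two scenarios; all numbers are the hypothetical values from this discussion, not Diffrax output:

```python
# Hypothetical values from the discussion above.
prev_dt = 0.1
actual_dt = 0.01  # t1 - t0 after clipping by jump_ts

# Scenario 1: step rejected, controller computes factor = 0.5.
reject_factor = 0.5
from_prev = prev_dt * reject_factor      # 0.05 -- still bigger than the failed 0.01 step
from_actual = actual_dt * reject_factor  # 0.005 -- genuinely smaller, so progress is made

# Scenario 2: the tiny clipped step has a tiny error, so factor is large, say 5.0.
accept_factor = 5.0
big_proposal = prev_dt * accept_factor     # 0.5 -- disproportionately massive
sane_proposal = actual_dt * accept_factor  # 0.05 -- a reasonable next step
```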

Let me know if I missed something.

@patrick-kidger
Owner

So I don't think this should ever be an infinite loop -- as the next time around then prev_dt (pulled out of the controller state) should be 0.05, and we'll keep on shrinking the step.

I do take your point that if the previous step was rejected, we shouldn't use the prev_dt out of the controller state. In this case we should use the t1 - t0 that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)

As for why we continue to use prev_dt, the reason is to handle the case in which prev_dt=<large>, but jump_ts clips it to <small>, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given prev_dt, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped t1 - t0, then we'd end up tripling the number of steps required.
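A toy step counter (not Diffrax's actual controller) illustrates this cost, assuming every step is accepted and, hypothetically, the controller can grow dt by at most a factor of 2 per step up to the "proper" step size dt_cap:

```python
def steps_to_cross(gap, dt0, dt_cap, factormax=2.0):
    """Count steps needed to cover `gap` between two jumps, starting from dt0."""
    t, dt, n = 0.0, dt0, 0
    while t < gap:
        t += min(dt, gap - t)           # hypothetical: every step is accepted
        dt = min(dt * factormax, dt_cap)  # grow by at most factormax, capped
        n += 1
    return n

# Restarting from prev_dt vs. from the clipped t1 - t0 after the last jump:
print(steps_to_cross(1.0, 0.5, 0.5))   # 2 steps
print(steps_to_cross(1.0, 0.05, 0.5))  # 5 steps: slowly working back up
```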

Whilst we're here I will note that there is one other difference between jump_ts and step_ts, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.

@patrick-kidger patrick-kidger added the question User queries label Aug 11, 2024
@andyElking
Contributor Author

So I don't think this should ever be an infinite loop -- as the next time around then prev_dt (pulled out of the controller state) should be 0.05, and we'll keep on shrinking the step.

Oh right, I missed that.

I do take your point that if the previous step was rejected, we shouldn't use the prev_dt out of the controller state. In this case we should use the t1 - t0 that actually caused a step rejection. I'd be happy to have a tweak to this effect. (I think we need a test for this too! This is subtle.)

Sounds good, I'll include that change in the JumpStepControllerWrapper and we can discuss in more detail once I make the PR.

As for why we continue to use prev_dt, the reason is to handle the case in which prev_dt=<large>, but jump_ts clips it to <small>, and the step is accepted. In this case, we don't want the steps after the jump to have to slowly work their way back up to a good step size. This is particularly troublesome for problems which have many jumps. If it would only take 1 step to move between jumps given prev_dt, but we'd need to spend 3 steps working back up to the 'proper step size' given the clipped t1 - t0, then we'd end up tripling the number of steps required.

I understand and I agree. However:

  1. If we choose to use prev_dt when the step was clipped due to jump_ts, I feel like we should also use it when it was clipped due to step_ts.
  2. Have you considered using something like max((t1-t0)*factor, prev_dt) in this case? After all, I feel like three steps that are too small but accepted are still better than three steps that are too big and rejected. And if the error was small, then factor should already be large anyway, no?

Whilst we're here I will note that there is one other difference between jump_ts and step_ts, and that is that the former causes the FSAL property of e.g. a Runge--Kutta method to discard the saved function evaluation, whilst in the latter case we can continue to use it.

Thanks, that is very useful to know.

@patrick-kidger
Owner

  1. That also sounds reasonable to me.
  2. Indeed, I think other heuristics could be deployed here too. I'm a bit wary about adding heuristics that only trigger after jumps though (which IIUC is what you're suggesting) -- that just sounds like it's getting a bit tricky to reason about / to debug.

@andyElking
Contributor Author

andyElking commented Aug 12, 2024

I would summarise the complete behaviour in 3 rules:

  1. We always have t1-t0 <= prev_dt (we can explicitly check this with an eqx.error_if), with inequality only when the step was clipped or when we hit the end of the integration interval (we do not explicitly check that, but I see no other way inequality could arise here).
  2. If the step was accepted, then next_dt must be >= prev_dt.
  3. If the step was rejected, then next_dt must be < t1-t0.

These can be implemented in a very simple way:

dt_proposal = factor*(t1 - t0)  # note that if step is rejected, then factor<1
# Here comes the clipping between dt_min and dt_max

eqx.error_if(prev_dt, prev_dt < t1-t0, "prev_dt must be >= t1-t0")

dt_proposal = jnp.where(keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal)
new_prev_dt = dt_proposal  # this goes into controller state as prev_dt

# Here comes the clipping due to step_ts and jump_ts and the whole nextafter(nextafter()) business

This has the nice property that it factors well into a controller (which does the first two lines) and a JumpStepWrapper which does all the rest. That means that made_jump (used only for the purposes of the nextafter business) and prev_dt are both kept in the state of the JumpStepWrapper and the inner controller never sees them.
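For concreteness, the sketch above can be written as a self-contained function (plain Python standing in for jnp.where/jnp.maximum; the dt_min/dt_max clipping, the step_ts/jump_ts clipping, and the error_if check are elided, and all names come from the pseudocode in this thread):

```python
def propose_dt(t0, t1, prev_dt, factor, keep_step):
    dt_proposal = factor * (t1 - t0)  # rule 3: if rejected, factor < 1 shrinks below t1 - t0
    if keep_step:
        # Rule 2: an accepted step never proposes less than prev_dt.
        dt_proposal = max(dt_proposal, prev_dt)
    new_prev_dt = dt_proposal  # stored in the wrapper's state, preserving rule 1
    return dt_proposal, new_prev_dt

# Accepted step that was clipped by jump_ts (t1 - t0 = 0.01 << prev_dt = 0.1):
print(propose_dt(0.0, 0.01, 0.1, 5.0, True)[0])   # 0.1: no slow climb back up
# Rejected step: the proposal shrinks relative to the step actually taken:
print(propose_dt(0.0, 0.01, 0.1, 0.5, False)[0])  # 0.005
```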

WDYT?

@patrick-kidger
Owner

patrick-kidger commented Aug 17, 2024

On eqx.error_if: I try to use these really rarely, as they carry a performance overhead. In this case I don't think it's important enough to use one. (Besides that, note that you have to use its return value if you want the check to run.)

I think what you've got sounds reasonable. Mulling it over, I think it should be possible to do something even simpler: change this line:

prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)

with

- prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)
+ prev_dt = jnp.where(made_jump & keep_step, prev_dt, t1 - t0)

(and if need be I think this still factors apart, as you describe).

@andyElking
Contributor Author

Thanks, that's good to know. I will keep the number of eqx.error_ifs to a minimum.

In practice it seems that your proposal doesn't lead to desirable behaviour. I compared our two approaches on a very simple example ODE and I was surprised how precisely the experiment echoed the issue I described in my first comment:

On the other hand if the step was clipped to a much smaller size than was intended (i.e. t1-t0 << prev_dt), then this will usually reflect in the error being small accordingly, resulting in a large factor (and the step being accepted). This means that (t1-t0)*factor would be again a reasonable step-size proposal, whereas prev_dt*factor would be disproportionately massive.

In addition the experiment shows that my solution completely fixes this issue. You can find the experiment here. And here you can see why my proposal makes it easier to separate the jump_ts and step_ts into a wrapper.

@patrick-kidger
Owner

Ah, I see what you mean.

Okay, in this case I think maybe what we should do is simply remove dt from the controller state altogether. Just always use t1 - t0, nothing else.

I believe your suggestion amounts to preventing the step size from shrinking after an accepted step. For context, some PID implementations exhibit this behaviour (e.g. torchdiffeq does this), but I recall deciding against it for Diffrax. It's a heuristic that I think helps some problems but hurts others.

@andyElking
Contributor Author

Fair enough, I'll get rid of it then. On the flip side, as I mentioned in #484, it seems it was possible for the step size to increase after a rejected step (I discovered this in an unrelated experiment, where it just seemed to go on forever until max_steps was reached). Was this intentional, or is it good that I have now capped dt_proposal at self.safety*(t1-t0) when keep_step=False?

@patrick-kidger
Owner

patrick-kidger commented Aug 18, 2024

That is definitely pretty weird! I'm willing to believe that it happens, though.

In fact, here's something interesting I came across whilst looking at this just now:

factormin = jnp.where(keep_step, 1, self.factormin)

It seems we do prevent step shrinking after an accepted step after all! 😅

In light of this, maybe we should fix the case you just mentioned by also adding factormax = jnp.where(keep_step, self.factormax, 1)?
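Combining the existing factormin guard with the suggested factormax guard would look roughly like this (plain Python standing in for jnp.where; the factormin/factormax defaults are hypothetical, not Diffrax's):

```python
def clamp_factor(factor, keep_step, factormin=0.2, factormax=10.0):
    lo = 1.0 if keep_step else factormin  # accepted step: never shrink
    hi = factormax if keep_step else 1.0  # rejected step: never grow
    return min(max(factor, lo), hi)

print(clamp_factor(0.5, keep_step=True))   # 1.0: accepted step cannot shrink
print(clamp_factor(5.0, keep_step=False))  # 1.0: rejected step cannot grow
```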

(EDIT: I've now seen that you mentioned this in #484. Ignore me, you're way ahead of me!)

@andyElking
Contributor Author

Well, technically this only prevents it from shrinking below t1-t0, which can still be smaller than prev_dt. But as you said before, each of these choices probably has pluses and minuses in different cases.

Haha yes, I was just about to point you to #484 😊
