diff --git a/docs/Project.toml b/docs/Project.toml index e1f963380..8c9732713 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" +FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Integrals = "de52edbc-65ea-441a-8357-d3a637375a31" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" diff --git a/docs/pages.jl b/docs/pages.jl index 54b706f28..15bb065a6 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -3,6 +3,7 @@ pages = ["index.md", "Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md", "PINNs DAEs" => "tutorials/dae.md", "Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md", + "Improved PINNs for Inverse problems in ODEs" => "tutorials/data_collocation_Inverse.md", "Physics informed Neural Operator ODEs" => "tutorials/pino_ode.md", "Deep Galerkin Method" => "tutorials/dgm.md" #"examples/nnrode_example.md", # currently incorrect ], diff --git a/docs/src/index.md b/docs/src/index.md index 0511445f8..ee7e49b44 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -34,7 +34,7 @@ Pkg.add("NeuralPDE") ## Contributing - Please refer to the - [SciML ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://github.com/SciML/ColPrac/blob/master/README.md) + [SciML ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://github.com/SciML/ColPrac) for guidance on PRs, issues, and other matters relating to contributing to SciML. - See the [SciML Style Guide](https://github.com/SciML/SciMLStyle) for common coding practices and other style decisions. diff --git a/docs/src/tutorials/data_collocation_Inverse.md b/docs/src/tutorials/data_collocation_Inverse.md new file mode 100644 index 000000000..62f8d45b7 --- /dev/null +++ b/docs/src/tutorials/data_collocation_Inverse.md @@ -0,0 +1,142 @@ +# Model Improvement in Physics-Informed Neural Networks for solving Inverse problems in ODEs. + +Consider an Inverse problem setting for the [lotka volterra system](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations). Here we want to optimize parameters $\alpha$, $\beta$, $\gamma$ and $\delta$ and also solve a parametric Lotka Volterra system. +PINNs are especially useful in these types of problems and are preferred over conventional solvers, due to their ability to learn from observations - the underlying physics governing the distribution of observations. + +We start by defining the problem, with a random and non informative initialization for parameters: + +```@example improv_param_estim +using NeuralPDE, OrdinaryDiffEq, Lux, Random, OptimizationOptimJL, LineSearches, + Distributions, Plots +using FastGaussQuadrature +using Test # hide + +function lv(u, p, t) + u₁, u₂ = u + α, β, γ, δ = p + du₁ = α * u₁ - β * u₁ * u₂ + du₂ = δ * u₁ * u₂ - γ * u₂ + [du₁, du₂] +end + +tspan = (0.0, 5.0) +u0 = [5.0, 5.0] +initialization = [-5.0, 8.0, 5.0, -7.0] +prob = ODEProblem(lv, u0, tspan, initialization) +``` + +We require a set of observations before we train the PINN. +Considering we want robust results even for cases where measurement values are sparse and limited in number. +We simulate a system that uses the true parameter `true_p` values and record phenomena/solution (`u`) values algorithmically at only `N=20` pre-decided timepoints in the system's time domain. + +The value for `N` can be incremented based on the non linearity (~ `N` degree polynomial) in the measured phenomenon, this tutorial's setting shows that even with minimal but systematically chosen data-points we can extract excellent results. + +```@example improv_param_estim +true_p = [1.5, 1.0, 3.0, 1.0] +prob_data = remake(prob, p = true_p) + +N = 20 +x, w = gausslobatto(N) +a = tspan[1] +b = tspan[2] +``` + +Now scale the weights and the gauss-lobatto/clenshaw-curtis/gauss-legendre quadrature points to fit in `tspan`. + +```@example improv_param_estim +t = map((x) -> (x * (b - a) + (b + a)) / 2, x) +W = map((x) -> x * (b - a) / 2, w) +``` + +We now have our dataset of `20` measurements in our `tspan` and corresponding weights. Using this we can now use the Data Quadrature loss function by passing `estim_collocate` = `true` in [`NNODE`](@ref). + +```@example improv_param_estim +sol_data = solve(prob_data, Tsit5(); saveat = t) +t_ = sol_data.t +u_ = sol_data.u +u1_ = [u_[i][1] for i in eachindex(t_)] +u2_ = [u_[i][2] for i in eachindex(t_)] +dataset = [u1_, u2_, t_, W] +``` + +Now, let's define a neural network for the PINN using [Lux.jl](https://lux.csail.mit.edu/). + +```@example improv_param_estim +rng = Random.default_rng() +Random.seed!(rng, 0) +n = 7 +chain = Chain(Dense(1, n, tanh), Dense(n, n, tanh), Dense(n, 2)) +ps, st = Lux.setup(rng, chain) |> f64 +``` + +!!! note + + While solving Inverse problems, when we specify `param_estim = true` in [`NNODE`](@ref) or [`BNNODE`](@ref), an L2 loss function measuring how the neural network's predictions fit the provided `dataset` is used internally during Maximum Likelihood Estimation. + Therefore, the `additional_loss` mentioned in the [ODE parameter estimation tutorial](https://docs.sciml.ai/NeuralPDE/stable/tutorials/ode_parameter_estimation/) is not limited to an L2 loss function against data. + +We now define the optimizer and [`NNODE`](@ref) - the ODE solving PINN algorithm, for the old PINN model and the proposed new PINN formulation which uses a Data Quadrature loss. +This optimizer and respective algorithms are plugged into the `solve` calls for comparing results between the new and old PINN models. + +```@example improv_param_estim +opt = LBFGS(linesearch = BackTracking()) + +alg_old = NNODE( + chain, opt; strategy = GridTraining(0.01), dataset = dataset, param_estim = true) + +alg_new = NNODE(chain, opt; strategy = GridTraining(0.01), param_estim = true, + dataset = dataset, estim_collocate = true) +``` + +Now we have all the pieces to solve the optimization problem. + +```@example improv_param_estim +sol_old = solve( + prob, alg_old; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01) + +sol_new = solve( + prob, alg_new; verbose = true, abstol = 1e-12, maxiters = 5000, saveat = 0.01) + +sol = solve(prob_data, Tsit5(); saveat = 0.01) +sol_points = hcat(sol.u...) +sol_old_points = hcat(sol_old.u...) +sol_new_points = hcat(sol_new.u...) +``` + +Let's plot the predictions from the PINN models, data used and compare it to the ideal system solution. +First the old model. + +```@example improv_param_estim +plot(sol, labels = ["u1" "u2"]) +plot!(sol_old, labels = ["u1_pinn_old" "u2_pinn_old"]) +scatter!(sol_data, labels = ["u1_data" "u2_data"]) +``` + +Clearly the old model cannot optimize given a realistic, tougher initialization of parameters especially with such limited data. It only seems to work when initial values are close to `true_p` and we have around `500` points for our `tspan`, as seen in the [ODE parameter estimation tutorial](https://docs.sciml.ai/NeuralPDE/stable/tutorials/ode_parameter_estimation/). + +Lets move on to the proposed new model... + +```@example improv_param_estim +plot(sol, labels = ["u1" "u2"]) +plot!(sol_new, labels = ["u1_pinn_new" "u2_pinn_new"]) +scatter!(sol_data, labels = ["u1_data" "u2_data"]) +``` + +We can see that it is a good fit! Now let's see what the estimated parameters of the equation tell us in both cases. + +```@example improv_param_estim +sol_old.k.u.p +@test any(true_p .- sol_old.k.u.p .> 0.5 .* true_p) # hide +``` + +Nowhere near the true [1.5, 1.0, 3.0, 1.0]. But the new model gives : + +```@example improv_param_estim +sol_new.k.u.p +@test sol_new.k.u.p≈true_p rtol=2e-2 norm=Base.Fix1(maximum, abs) # hide +``` + +This is indeed close to the true values [1.5, 1.0, 3.0, 1.0]. + +!!! note + + This feature for using a Data collocation loss is also available for BPINNs solving Inverse problems in ODEs. Use a `dataset` of the form as described in this tutorial and set `estim_collocate`=`true` and you are good to go. diff --git a/docs/src/tutorials/ode_parameter_estimation.md b/docs/src/tutorials/ode_parameter_estimation.md index 784828ee7..2db70df06 100644 --- a/docs/src/tutorials/ode_parameter_estimation.md +++ b/docs/src/tutorials/ode_parameter_estimation.md @@ -54,7 +54,7 @@ Next we define the optimizer and [`NNODE`](@ref) which is then plugged into the ```@example param_estim_lv opt = LBFGS(linesearch = BackTracking()) alg = NNODE(chain, opt, ps; strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 500), - param_estim = true, additional_loss) + param_estim = true, additional_loss = additional_loss) ``` Now we have all the pieces to solve the optimization problem. diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl index a9b272c39..da01b63ef 100644 --- a/src/PDE_BPINN.jl +++ b/src/PDE_BPINN.jl @@ -284,7 +284,7 @@ end each dependant variable of interest. * `phystd`: Vector of standard deviations of BPINN prediction against Chosen Underlying PDE equations. -* `phynewstd`: Vector of standard deviations of new loss term. +* `phynewstd`: Vector of standard deviations of the Data Quadrature loss term. * `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default. * `param`: Vector of chosen PDE's parameter's Distributions in case of Inverse problems. @@ -437,7 +437,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; @printf("Current SSE against dataset Log-likelihood : %g\n", L2LossData(ℓπ, initial_θ)) if !(newloss isa Nothing) - @printf("Current new loss : %g\n", + @printf("Current Data Quadrature loss : %g\n", ℓπ.L2_loss2(setparameters(ℓπ, initial_θ), ℓπ.phynewstd)) end @@ -495,7 +495,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization; @printf("Final SSE against dataset Log-likelihood : %g\n", L2LossData(ℓπ, samples[end])) if !(newloss isa Nothing) - @printf("Final new loss : %g\n", + @printf("Final Data Quadrature loss : %g\n", ℓπ.L2_loss2(setparameters(ℓπ, samples[end]), ℓπ.phynewstd)) end diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index b1a27df88..c2feba533 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -336,7 +336,7 @@ Incase you are only solving Non parametric ODE Equations for a solution, do not * `strategy`: The training strategy used to choose the points for the evaluations. By default GridTraining is used with given physdt discretization. * `dataset`: Is either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are dependant variable observations, `t` are time points and `W` are quadrature weights for domain. - The dataset is used to compute the L2 loss against the data and also for the new loss function. + The dataset is used to compute the L2 loss against the data and also for the Data Duadrature loss function. For multiple dependant variables, there will be multiple vectors with the last two vectors in dataset still being for `t`, `W`. Is empty by default assuming a forward problem is being solved. * `init_params`: initial parameter values for BPINN (ideally for multiple chains different @@ -346,7 +346,7 @@ Incase you are only solving Non parametric ODE Equations for a solution, do not ~2/3 of draw samples) * `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset * `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System -* `phynewstd`: A function that gives the standard deviation of the new loss function at each iteration. +* `phynewstd`: A function that gives the standard deviation of the Data Quadrature loss function at each iteration. It takes the ODE parameters as input and returns a vector of standard deviations. Is (ode_params) -> [0.05] by default. * `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of @@ -368,7 +368,7 @@ Incase you are only solving Non parametric ODE Equations for a solution, do not * `max_depth`: Maximum doubling tree depth (NUTS) * `Δ_max`: Maximum divergence during doubling tree (NUTS) Refer: https://turinglang.org/AdvancedHMC.jl/stable/ -* `estim_collocate`: A boolean value to indicate whether to use the new loss function or not. This is only relevant for ODE parameter estimation. +* `estim_collocate`: A boolean value to indicate whether to use the Data Quadrature loss function or not. This is only relevant for ODE parameter estimation. * `progress`: controls whether to show the progress meter or not. * `verbose`: controls the verbosity. (Sample call args in AHMC) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index 89cac8b33..9f789e9eb 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -30,7 +30,7 @@ standard `ODEProblem`. * `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions, θ are the weights of the neural network(s). * `dataset`: Is either an empty Vector or a nested Vector of the form `[x̂, t, W]` where `x̂` are dependant variable observations, `t` are time points and `W` are quadrature weights for domain. - The dataset is used to compute a L2 loss against the data and also for the new loss function. + The dataset is used to compute a L2 loss against the data and also for the Data Quadrature loss function. For multiple dependant variables, there will be multiple vectors with the last two vectors in dataset still being for `t`, `W`. Is empty by default assuming a forward problem is being solved. * `autodiff`: The switch between automatic and numerical differentiation for @@ -49,7 +49,7 @@ standard `ODEProblem`. * `strategy`: The training strategy used to choose the points for the evaluations. Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no `dt` is given, and `GridTraining` is used with `dt` if given. -* `estim_collocate`: A boolean value to indicate whether to use the new loss function or not. This is only relevant for ODE parameter estimation. +* `estim_collocate`: A boolean value to indicate whether to use the Data Quadrature loss function or not. This is only relevant for ODE parameter estimation. * `kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call. ## Examples @@ -282,7 +282,7 @@ function generate_L2lossData(dataset, phi, n_output) end """ -new loss +Data Quadrature loss function (provides very accurate solution, parameter estimates and a method for algorithmic sampling of a minimal set of data points for Inverse problems). """ function generate_L2loss2(f, autodiff, dataset, phi, n_output) isempty(dataset) && return 0 @@ -393,7 +393,7 @@ function SciMLBase.__solve( if isempty(dataset) && param_estim && isnothing(additional_loss) error("Dataset or an additional loss is required for Inverse problems performing Parameter Estimation.") elseif isempty(dataset) && estim_collocate - error("Dataset is required for Inverse problems performing Parameter Estimation using the new loss.") + error("Dataset is required for Inverse problems performing Parameter Estimation using the Data Quadrature loss function.") end n_output = length(u0) @@ -406,7 +406,7 @@ function SciMLBase.__solve( if param_estim && estim_collocate L2_loss = L2_loss + L2lossData(θ, phi) + L2loss2(θ, phi) - elseif param_estim + elseif param_estim && !isempty(dataset) L2_loss = L2_loss + L2lossData(θ, phi) end if additional_loss !== nothing diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 4b63a8992..d7849a469 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -219,7 +219,7 @@ end n = 8 luxchain = Chain(Dense(1, n, σ), Dense(n, n, σ), Dense(n, 3)) - # this example is especially easy for new loss. + # this example is especially easy for the Data Quadrature loss. # even with ~2 observed data points, we can exactly calculate the physics parameters (even before solve call). N = 7 # x, w = gausslegendre(N) # does not include endpoints