|
1 | 1 | import dotenv |
| 2 | +import numpy as np |
2 | 3 | import os |
3 | | - |
4 | 4 | from typing import Optional |
5 | 5 |
|
6 | | -import numpy as np |
7 | | - |
8 | 6 | import gempy_engine |
9 | | -from gempy_engine.core.backend_tensor import BackendTensor |
10 | 7 | from gempy.API.gp2_gp3_compatibility.gp3_to_gp2_input import gempy3_to_gempy2 |
11 | 8 | from gempy_engine.config import AvailableBackends |
| 9 | +from gempy_engine.core.backend_tensor import BackendTensor |
12 | 10 | from gempy_engine.core.data import Solutions |
13 | | -from gempy_engine.core.data.interpolation_input import InterpolationInput |
14 | 11 | from .grid_API import set_custom_grid |
| 12 | +from ..core.data import StructuralGroup |
15 | 13 | from ..core.data.gempy_engine_config import GemPyEngineConfig |
16 | 14 | from ..core.data.geo_model import GeoModel |
17 | | -from ..modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame |
| 15 | +from ..modules.data_manipulation import interpolation_input_from_structural_frame |
| 16 | +from ..modules.optimize_nuggets import nugget_optimizer |
18 | 17 | from ..optional_dependencies import require_gempy_legacy |
19 | 18 |
|
20 | 19 | dotenv.load_dotenv() |
@@ -92,91 +91,29 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray, |
92 | 91 | return sol.raw_arrays.custom |
93 | 92 |
|
94 | 93 |
|
95 | | -def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
96 | | - convergence_criteria: float = 1e5): |
| 94 | +def optimize_nuggets(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
| 95 | + convergence_criteria: float = 1e5, only_groups:list[StructuralGroup] | None = None) -> GeoModel: |
| 96 | + """ |
| 97 | + Optimize the nuggets of the interpolation input of the provided model. |
| 98 | + """ |
| 99 | + |
97 | 100 | if engine_config.backend != AvailableBackends.PYTORCH: |
98 | 101 | raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}') |
99 | | - |
100 | | - BackendTensor.change_backend_gempy( |
101 | | - engine_backend=engine_config.backend, |
102 | | - use_gpu=engine_config.use_gpu, |
103 | | - dtype=engine_config.dtype |
104 | | - ) |
105 | | - |
106 | | - import torch |
107 | | - from gempy_engine.core.data.continue_epoch import ContinueEpoch |
108 | | - interpolation_input: InterpolationInput = interpolation_input_from_structural_frame(geo_model) |
109 | | - |
110 | | - geo_model.taped_interpolation_input = interpolation_input |
111 | | - |
112 | | - nugget_effect_scalar: torch.Tensor = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar |
113 | | - |
114 | | - optimizer = torch.optim.Adam( |
115 | | - params=[nugget_effect_scalar], |
116 | | - lr=0.01, |
| 102 | + |
| 103 | + geo_model = nugget_optimizer( |
| 104 | + target_cond_num=convergence_criteria, |
| 105 | + engine_cfg=engine_config, |
| 106 | + model=geo_model, |
| 107 | + max_epochs=max_epochs, |
| 108 | + only_groups=only_groups |
117 | 109 | ) |
118 | 110 |
|
119 | | - # Optimization loop |
120 | | - geo_model.interpolation_options.kernel_options.optimizing_condition_number = True |
121 | | - |
122 | | - def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5): |
123 | | - reached_conditional_target = conditional_number < conditional_number_target |
124 | | - if reached_conditional_target == False and epoch > 10: |
125 | | - condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old |
126 | | - if condition_number_change < 0.01: |
127 | | - reached_conditional_target = True |
128 | | - return reached_conditional_target |
129 | | - |
130 | | - previous_condition_number = 0 |
131 | | - for epoch in range(max_epochs): |
132 | | - optimizer.zero_grad() |
133 | | - try: |
134 | | - # geo_model.taped_interpolation_input.grid = geo_model.interpolation_input_copy.grid |
135 | | - |
136 | | - gempy_engine.compute_model( |
137 | | - interpolation_input=geo_model.taped_interpolation_input, |
138 | | - options=geo_model.interpolation_options, |
139 | | - data_descriptor=geo_model.input_data_descriptor, |
140 | | - geophysics_input=geo_model.geophysics_input, |
141 | | - ) |
142 | | - except ContinueEpoch: |
143 | | - # Get absolute values of gradients |
144 | | - grad_magnitudes = torch.abs(nugget_effect_scalar.grad) |
145 | | - |
146 | | - # Get indices of the 10 largest gradients |
147 | | - grad_magnitudes.size |
148 | | - |
149 | | - # * This ignores 90 percent of the gradients |
150 | | - # To int |
151 | | - n_values = int(grad_magnitudes.size()[0] * 0.9) |
152 | | - _, indices = torch.topk(grad_magnitudes, n_values, largest=False) |
153 | | - |
154 | | - # Zero out gradients that are not in the top 10 |
155 | | - mask = torch.ones_like(nugget_effect_scalar.grad) |
156 | | - mask[indices] = 0 |
157 | | - nugget_effect_scalar.grad *= mask |
158 | | - |
159 | | - # Update the vector |
160 | | - optimizer.step() |
161 | | - nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0 |
162 | | - |
163 | | - # optimizer.zero_grad() |
164 | | - # Monitor progress |
165 | | - if epoch % 1 == 0: |
166 | | - # print(f"Epoch {epoch}: Condition Number = {condition_number.item()}") |
167 | | - print(f"Epoch {epoch}") |
168 | | - |
169 | | - if _check_convergence_criterion( |
170 | | - conditional_number=geo_model.interpolation_options.kernel_options.condition_number, |
171 | | - condition_number_old=previous_condition_number, |
172 | | - conditional_number_target=convergence_criteria, |
173 | | - ): |
174 | | - break |
175 | | - previous_condition_number = geo_model.interpolation_options.kernel_options.condition_number |
176 | | - continue |
177 | | - |
178 | | - geo_model.interpolation_options.kernel_options.optimizing_condition_number = False |
| 111 | + return geo_model |
179 | 112 |
|
| 113 | +def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10, |
| 114 | + convergence_criteria: float = 1e5): |
| 115 | + |
| 116 | + optimize_nuggets(geo_model, engine_config, max_epochs, convergence_criteria) |
180 | 117 | geo_model.solutions = gempy_engine.compute_model( |
181 | 118 | interpolation_input=geo_model.taped_interpolation_input, |
182 | 119 | options=geo_model.interpolation_options, |
|
0 commit comments