@@ -21,6 +21,8 @@ Reference: PhysRevLett 97, 170201 (2006)
21
21
22
22
#include " minimizer_fire_box_change.cuh"
23
23
#include " utilities/gpu_macro.cuh"
24
+ #include < algorithm>
25
+ #include < cmath>
24
26
#include < cstring>
25
27
26
28
namespace
@@ -151,8 +153,8 @@ void get_force_temp(
151
153
template <int N>
152
154
void solveLinearEquation (const double * A, const double * B, double * X)
153
155
{
154
-
155
156
double a[N][N], b[N][N];
157
+
156
158
for (int j = 0 ; j < N; ++j) {
157
159
for (int i = 0 ; i < N; ++i) {
158
160
a[i][j] = A[j * N + i];
@@ -161,27 +163,43 @@ void solveLinearEquation(const double* A, const double* B, double* X)
161
163
}
162
164
163
165
for (int col = 0 ; col < N; ++col) {
164
- for (int i = 0 ; i < N; ++i) {
165
- if (i == col) {
166
- double diag = a[i][col];
167
- if (fabs (diag) < 1e-9 ) {
168
- printf (" Matrix is singular or nearly singular!\n " );
169
- return ;
170
- }
171
- for (int j = 0 ; j < N; ++j) {
172
- a[i][j] /= diag;
173
- b[i][j] /= diag;
174
- }
175
- } else {
176
- double factor = a[i][col];
166
+ int pivot_row = col;
167
+ for (int i = col + 1 ; i < N; ++i) {
168
+ if (fabs (a[i][col]) > fabs (a[pivot_row][col])) {
169
+ pivot_row = i;
170
+ }
171
+ }
172
+
173
+ if (fabs (a[pivot_row][col]) < 1e-9 ) {
174
+ printf (" Matrix is singular or nearly singular!\n " );
175
+ return ;
176
+ }
177
+
178
+ if (pivot_row != col) {
179
+ for (int j = 0 ; j < N; ++j) {
180
+ std::swap (a[col][j], a[pivot_row][j]);
181
+ std::swap (b[col][j], b[pivot_row][j]);
182
+ }
183
+ }
184
+
185
+ double diag = a[col][col];
186
+ for (int j = 0 ; j < N; ++j) {
187
+ a[col][j] /= diag;
188
+ b[col][j] /= diag;
189
+ }
190
+
191
+ for (int row = 0 ; row < N; ++row) {
192
+ if (row != col) {
193
+ double factor = a[row][col];
177
194
for (int j = 0 ; j < N; ++j) {
178
- a[i ][j] -= factor * a[col][j];
179
- b[i ][j] -= factor * b[col][j];
195
+ a[row ][j] -= factor * a[col][j];
196
+ b[row ][j] -= factor * b[col][j];
180
197
}
181
198
}
182
199
}
183
200
}
184
201
202
+ // 将计算得到的结果存储回 X 中(行主序)
185
203
for (int i = 0 ; i < N; ++i) {
186
204
for (int j = 0 ; j < N; ++j) {
187
205
X[i * N + j] = b[i][j];
0 commit comments