Skip to content

Commit 070c3a7

Browse files
authored
Merge pull request brucefan1983#906 from mushroomfire/mini
Fix a bug for Minimizer
2 parents 8c68ddf + 8210e0c commit 070c3a7

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

src/minimize/minimizer_fire_box_change.cu

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Reference: PhysRevLett 97, 170201 (2006)
2121

2222
#include "minimizer_fire_box_change.cuh"
2323
#include "utilities/gpu_macro.cuh"
24+
#include <algorithm>
25+
#include <cmath>
2426
#include <cstring>
2527

2628
namespace
@@ -151,8 +153,8 @@ void get_force_temp(
151153
template <int N>
152154
void solveLinearEquation(const double* A, const double* B, double* X)
153155
{
154-
155156
double a[N][N], b[N][N];
157+
156158
for (int j = 0; j < N; ++j) {
157159
for (int i = 0; i < N; ++i) {
158160
a[i][j] = A[j * N + i];
@@ -161,27 +163,43 @@ void solveLinearEquation(const double* A, const double* B, double* X)
161163
}
162164

163165
for (int col = 0; col < N; ++col) {
164-
for (int i = 0; i < N; ++i) {
165-
if (i == col) {
166-
double diag = a[i][col];
167-
if (fabs(diag) < 1e-9) {
168-
printf("Matrix is singular or nearly singular!\n");
169-
return;
170-
}
171-
for (int j = 0; j < N; ++j) {
172-
a[i][j] /= diag;
173-
b[i][j] /= diag;
174-
}
175-
} else {
176-
double factor = a[i][col];
166+
int pivot_row = col;
167+
for (int i = col + 1; i < N; ++i) {
168+
if (fabs(a[i][col]) > fabs(a[pivot_row][col])) {
169+
pivot_row = i;
170+
}
171+
}
172+
173+
if (fabs(a[pivot_row][col]) < 1e-9) {
174+
printf("Matrix is singular or nearly singular!\n");
175+
return;
176+
}
177+
178+
if (pivot_row != col) {
179+
for (int j = 0; j < N; ++j) {
180+
std::swap(a[col][j], a[pivot_row][j]);
181+
std::swap(b[col][j], b[pivot_row][j]);
182+
}
183+
}
184+
185+
double diag = a[col][col];
186+
for (int j = 0; j < N; ++j) {
187+
a[col][j] /= diag;
188+
b[col][j] /= diag;
189+
}
190+
191+
for (int row = 0; row < N; ++row) {
192+
if (row != col) {
193+
double factor = a[row][col];
177194
for (int j = 0; j < N; ++j) {
178-
a[i][j] -= factor * a[col][j];
179-
b[i][j] -= factor * b[col][j];
195+
a[row][j] -= factor * a[col][j];
196+
b[row][j] -= factor * b[col][j];
180197
}
181198
}
182199
}
183200
}
184201

202+
// 将计算得到的结果存储回 X 中(行主序)
185203
for (int i = 0; i < N; ++i) {
186204
for (int j = 0; j < N; ++j) {
187205
X[i * N + j] = b[i][j];

0 commit comments

Comments
 (0)