diff --git a/examples/nbody/nbody.cpp b/examples/nbody/nbody.cpp index 45beeb3cb1..8334886b73 100644 --- a/examples/nbody/nbody.cpp +++ b/examples/nbody/nbody.cpp @@ -210,7 +210,7 @@ namespace usellama return FP{1} / sqrt(distSixth); } }(); - const auto sts = (pj(tag::Mass{}) * timestep) * invDistCube; + const auto sts = pj(tag::Mass{}) * timestep * invDistCube; pi(tag::Vel{}) += dist * sts; } @@ -416,7 +416,7 @@ namespace manualAoS const FP distSqr = eps2 + distance.x + distance.y + distance.z; const FP distSixth = distSqr * distSqr * distSqr; const FP invDistCube = 1.0f / std::sqrt(distSixth); - const FP sts = pj.mass * invDistCube * timestep; + const FP sts = pj.mass * timestep * invDistCube; pi.vel += distance * sts; } @@ -504,7 +504,7 @@ namespace manualSoA const FP distSqr = eps2 + xdistance + ydistance + zdistance; const FP distSixth = distSqr * distSqr * distSqr; const FP invDistCube = 1.0f / std::sqrt(distSixth); - const FP sts = pjmass * invDistCube * timestep; + const FP sts = pjmass * timestep * invDistCube; pivelx += xdistance * sts; pively += ydistance * sts; pivelz += zdistance * sts; @@ -631,7 +631,7 @@ namespace manualAoSoA const FP distSqr = eps2 + xdistance + ydistance + zdistance; const FP distSixth = distSqr * distSqr * distSqr; const FP invDistCube = 1.0f / std::sqrt(distSixth); - const FP sts = pjmass * invDistCube * timestep; + const FP sts = pjmass * timestep * invDistCube; pivelx += xdistance * sts; pively += ydistance * sts; pivelz += zdistance * sts; @@ -849,7 +849,8 @@ namespace manualAoSoAManualAVX else return _mm256_div_ps(_mm256_set1_ps(1.0f), _mm256_sqrt_ps(distSixth)); }(); - const __m256 sts = _mm256_mul_ps(_mm256_mul_ps(pjmass, invDistCube), vTIMESTEP); + // FIXME(bgruber): we could keep pjmass scalar for the update8 version: + const __m256 sts = _mm256_mul_ps(_mm256_mul_ps(pjmass, vTIMESTEP), invDistCube); pivelx = _mm256_fmadd_ps(xdistanceSqr, sts, pivelx); pively = _mm256_fmadd_ps(ydistanceSqr, sts, pively); pivelz = _mm256_fmadd_ps(zdistanceSqr, sts, pivelz); @@ -1018,7 +1019,7 @@ namespace manualAoSoASIMD }; - template + template inline void pPInteraction( Simd piposx, Simd piposy, @@ -1029,7 +1030,7 @@ namespace manualAoSoASIMD Simd pjposx, Simd pjposy, Simd pjposz, - Simd pjmass) + SimdOrScalar pjmass) { const Simd xdistance = piposx - pjposx; const Simd ydistance = piposy - pjposy; @@ -1060,7 +1061,7 @@ namespace manualAoSoASIMD return FP{1} / xsimd::sqrt(distSixth); } }(); - const Simd sts = pjmass * invDistCube * timestep; + const Simd sts = pjmass * timestep * invDistCube; pivelx = xdistanceSqr * sts + pivelx; pively = ydistanceSqr * sts + pively; pivelz = zdistanceSqr * sts + pivelz; @@ -1091,7 +1092,7 @@ namespace manualAoSoASIMD const Simd pjposx = blockJ.pos.x.get(j); const Simd pjposy = blockJ.pos.y.get(j); const Simd pjposz = blockJ.pos.z.get(j); - const Simd pjmass = blockJ.mass.get(j); + const FP pjmass = blockJ.mass.get(j); pPInteraction(piposx, piposy, piposz, pivelx, pively, pivelz, pjposx, pjposy, pjposz, pjmass); } @@ -1346,7 +1347,7 @@ namespace manualAoSSIMD const Simd pjposx(pj.pos.x); const Simd pjposy(pj.pos.y); const Simd pjposz(pj.pos.z); - const Simd pjmass(pj.mass); + const FP pjmass(pj.mass); pPInteraction(piposx, piposy, piposz, pivelx, pively, pivelz, pjposx, pjposy, pjposz, pjmass); } @@ -1440,7 +1441,7 @@ namespace manualSoASIMD Simd(posx[j]), Simd(posy[j]), Simd(posz[j]), - Simd(mass[j])); + mass[j]); pivelx.store_aligned(velx + i); pively.store_aligned(vely + i); pivelz.store_aligned(velz + i);