Skip to content

Commit d871c34

Browse files
author
Hongwei
committedAug 10, 2022
add mpi support
1 parent d8cd1d4 commit d871c34

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed
 

‎vmc_jastrow_cp_exact_mpi.py

+135
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
from mpi4py import MPI
3+
import math
4+
import time
5+
import numba
6+
from matplotlib import pyplot
7+
8+
N = 4
9+
10+
@numba.njit
11+
def coefficient(state, alpha):
12+
13+
ssum = 0.0
14+
for i in range(N):
15+
for j in range(i+1, N):
16+
17+
deno = min(math.fabs(1.0*j - 1.0*i), N*1.0 - math.fabs(1.0*j - 1.0*i) )
18+
# print(state[i] * state[j] / deno)
19+
ssum += state[i] * state[j] / deno
20+
21+
return math.exp(-alpha * ssum)
22+
23+
@numba.njit
24+
def local_energy(state, coeff, alpha):
25+
26+
res = 0.0
27+
ssum = 0.0
28+
29+
for i in range(N):
30+
res += state[i] * state[(i+1)%N]
31+
32+
for i in range(N):
33+
if(state[i] * state[(i+1)%N] < 0.0):
34+
state_new = state.copy()
35+
# print(state_new)
36+
state_new[i] *= -1.0
37+
state_new[(i+1)%N] *= -1.0
38+
39+
ssum += coefficient(state_new, alpha)/coeff
40+
41+
return res - 0.5 * ssum
42+
43+
44+
@numba.njit
45+
def sampler(alpha, Nsample = 5000, Nskip = 3):
46+
47+
state = np.ones(N)
48+
state[: N//2] = -1
49+
50+
state *= 0.5
51+
state = state[np.random.permutation(N)]
52+
53+
ssum = 0.0
54+
# coeff_old = coefficient(state, alpha)
55+
56+
for i in range(Nsample):
57+
58+
for i in range(Nskip):
59+
60+
x = np.random.randint(low = 0, high = N)
61+
y = x
62+
63+
while(state[y] * state[x] > 0):
64+
y = np.random.randint(low = 0, high = N)
65+
66+
new_state = state.copy()
67+
new_state[x] *= -1.0
68+
new_state[y] *= -1.0
69+
70+
coeff_old = coefficient(state, alpha)
71+
coeff_new = coefficient(new_state, alpha)
72+
73+
if(np.random.random() < min(1.0, (coeff_new**2)/(coeff_old**2))):
74+
state = new_state.copy()
75+
coeff_old = coeff_new
76+
77+
tmp = local_energy(state, coeff_old, alpha)
78+
79+
80+
ssum += tmp
81+
82+
return ssum / Nsample
83+
84+
85+
if(__name__ == '__main__'):
86+
87+
comm = MPI.COMM_WORLD
88+
nprocs = comm.Get_size()
89+
rank = comm.Get_rank()
90+
91+
ns = 10000
92+
ns = ns // nprocs
93+
94+
95+
if(rank == 0):
96+
x, y = [], []
97+
98+
99+
t0 = time.time()
100+
101+
for i in range(-30, 40):
102+
103+
alpha = i * 0.1
104+
105+
# comm.Barrier()
106+
mpi_energy = sampler(alpha, ns) / nprocs
107+
108+
energy = comm.reduce(mpi_energy, root=0)
109+
110+
if(rank == 0):
111+
print("Alpha: %.2f, Energy: %.2f" % (alpha, energy))
112+
x.append(alpha)
113+
y.append(energy)
114+
115+
if(rank == 0):
116+
117+
t1 = time.time()
118+
print("Elapsed time: %.2f sec" % (t1 - t0))
119+
120+
pyplot.xlabel("alpha")
121+
pyplot.ylabel("Energy")
122+
pyplot.plot(x, y, 'o', label="VMC")
123+
pyplot.legend()
124+
pyplot.show()
125+
126+
127+
128+
129+
130+
131+
132+
133+
134+
135+

0 commit comments

Comments
 (0)
Please sign in to comment.