diff --git a/pyscf/tools/trexio.py b/pyscf/tools/trexio.py index 33cfc979..9a2158c3 100644 --- a/pyscf/tools/trexio.py +++ b/pyscf/tools/trexio.py @@ -22,6 +22,7 @@ from pyscf import scf from pyscf import pbc from pyscf import fci +from pyscf import ao2mo import trexio @@ -307,7 +308,8 @@ def scf_from_trexio(filename): mf.mo_occ = mo_occ return mf -def write_eri(eri, filename, backend='h5'): +def write_eri(eri, filename, backend='h5', basis='mo'): + assert basis.upper() in ['MO','AO'] num_integrals = eri.size if eri.ndim == 4: n = eri.shape[0] @@ -330,15 +332,31 @@ def write_eri(eri, filename, backend='h5'): idx = idx[np.tril_indices(npair)] with trexio.File(filename, 'w', back_end=_mode(backend)) as tf: - trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel()) + if basis.upper() == 'MO': + trexio.write_mo_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel()) + else: + trexio.write_ao_2e_int_eri(tf, 0, num_integrals, idx, eri.ravel()) + +def write_scf_eri(mf, filename, backend='h5', basis='mo'): + assert basis.upper() in ['MO','AO'] + if basis.upper() == 'MO': + write_eri(ao2mo.kernel(mf._eri, mf.mo_coeff), filename, backend, basis) + else: + write_eri(mf._eri, filename, backend, basis) -def read_eri(filename): + +def read_eri(filename, basis='mo'): '''Read ERIs in AO basis, 8-fold symmetry is assumed''' + assert basis.upper() in ['MO','AO'] + basis_is_mo = (basis.upper() == 'MO') with trexio.File(filename, 'r', back_end=trexio.TREXIO_AUTO) as tf: - nmo = trexio.read_mo_num(tf) - nao_pair = nmo * (nmo+1) // 2 - eri_size = nao_pair * (nao_pair+1) // 2 - idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size) + norb = trexio.read_mo_num(tf) if basis_is_mo else trexio.read_ao_num(tf) + norb_pair = norb * (norb+1) // 2 + eri_size = norb_pair * (norb_pair+1) // 2 + if basis_is_mo: + idx, data, n_read, eof_flag = trexio.read_mo_2e_int_eri(tf, 0, eri_size) + else: + idx, data, n_read, eof_flag = trexio.read_ao_2e_int_eri(tf, 0, eri_size) eri = np.zeros(eri_size) x = idx[:,0]*(idx[:,0]+1)//2 + idx[:,1] y = idx[:,2]*(idx[:,2]+1)//2 + idx[:,3]