-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmain.py
More file actions
59 lines (51 loc) · 2.02 KB
/
main.py
File metadata and controls
59 lines (51 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import pickle
import sys
from datetime import datetime

from rlmm.environment.openmmEnv import OpenMMEnv
from rlmm.rl.Expert import ExpertPolicy, RandomPolicy
from rlmm.utils.config import Config
def setup_temp_files(config):
    """Create a fresh, timestamped run directory and share config across sections.

    ``config.configs['tempdir']`` is treated as a base directory. It is created
    (together with any missing parent directories) if necessary and normalised
    to end in ``/``. It is then replaced in place by a new per-run subdirectory
    named ``rlmm_<DD>_<MM>_<YYYY>T<HHMMSS>/``, which is created on disk.

    Finally, every top-level entry of ``config.configs`` is pushed into each of
    the ``actions``, ``systemloader``, ``openmmWrapper`` and ``obsmethods``
    sections (when present) via ``section.update(key, value)``. Each section
    receives every key except its own, including the updated ``tempdir`` and
    the other sections.

    Args:
        config: object whose ``configs`` attribute is a dict containing at
            least ``'tempdir'`` (a str path). The dict is mutated in place.

    Raises:
        SystemExit: with a non-zero status if the per-run directory already
            exists. The name has one-second resolution, so two runs started
            in the same second collide.
    """
    configs = config.configs

    # Base directory: create it and any missing parents; it may already exist.
    os.makedirs(configs['tempdir'], exist_ok=True)
    if not configs['tempdir'].endswith('/'):
        configs['tempdir'] = configs['tempdir'] + "/"

    run_dir = configs['tempdir'] + "{}/".format(
        datetime.now().strftime("rlmm_%d_%m_%YT%H%M%S"))
    configs['tempdir'] = run_dir
    try:
        os.mkdir(run_dir)
    except FileExistsError:
        # Never reuse another run's directory. sys.exit(msg) reports the
        # message on stderr and exits with status 1. The previous bare exit()
        # signalled success and depends on the site module having installed it.
        sys.exit("Somehow the directory already exists... exiting")

    shared_sections = ('actions', 'systemloader', 'openmmWrapper', 'obsmethods')
    for name, section in configs.items():
        if name not in shared_sections:
            continue
        for other_name, other_value in configs.items():
            if other_name != name:
                section.update(other_name, other_value)

    # Debug dump of the final configuration.
    print("?", config.configs)
def test_load_test_system():
    """Run the expert policy against the bundled test system for 100 steps.

    Loads ``rlmm/tests/test_config.yaml`` and prepares a per-run temp
    directory with ``setup_temp_files``. It then copies the YAML into that
    directory as ``config.yaml`` and builds an ``OpenMMEnv`` driven by an
    ``ExpertPolicy``.

    After each step, ``env.data`` is pickled to ``rundata.pkl`` in the
    current working directory, overwriting the previous dump.
    """
    import logging
    import warnings
    import shutil
    from openeye import oechem

    # Quiet the toolkits: OpenEye at warning level, openforcefield at critical.
    oechem.OEThrow.SetLevel(oechem.OEErrorLevel_Warning)
    logging.getLogger(__name__).setLevel(logging.DEBUG)
    logging.getLogger('openforcefield').setLevel(logging.CRITICAL)
    warnings.filterwarnings("ignore")

    yaml_path = 'rlmm/tests/test_config.yaml'
    config = Config.load_yaml(yaml_path)
    setup_temp_files(config)
    settings = config.configs
    shutil.copy(yaml_path, settings['tempdir'] + "config.yaml")

    env = OpenMMEnv(OpenMMEnv.Config(settings))
    pdb_name = settings['systemloader'].pdb_file_name
    policy = ExpertPolicy(env, num_returns=-1, step_size=1.0, orig_pdb=pdb_name)
    env.reset()

    energy_log = []
    for _step in range(100):
        action = policy.choose_action()
        print("Action taken: ", action[1])
        # env.step yields a 4-tuple; only the trailing info dict is used here.
        _, _, _, info = env.step(action)
        energy_log.append(info['energies'])
        # Rewrite the dump every step so the latest env.data is on disk.
        with open("rundata.pkl", 'wb') as handle:
            pickle.dump(env.data, handle)
# Script entry point: run the expert-policy smoke run only when executed directly.
if __name__ == "__main__":
    test_load_test_system()