-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualize.py
61 lines (47 loc) · 1.54 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import matplotlib.pyplot as plt
import pickle
def load_training_logs(log_file):
    """Return the object stored in the pickle file ``log_file``.

    Args:
        log_file: Path (or anything accepted by ``open``) of the pickle file
            to read. The file is opened in binary mode.

    Returns:
        The unpickled object, exactly as it was written to the file.

    Note:
        ``pickle`` can run arbitrary code while loading, so only pass files
        that you produced yourself or otherwise trust.
    """
    with open(log_file, mode='rb') as pickled:
        return pickle.load(pickled)
# Read the pickled training records; the relative path is resolved against
# the current working directory.
logs = load_training_logs(log_file='training_logs.pkl')
def extract_metrics_from_logs(logs):
    """Turn per-episode log records into three parallel metric lists.

    Args:
        logs: Iterable of records, consumed exactly once. Each record is
            indexed with the keys ``'episode'``, ``'total_rewards'`` (an
            object exposing ``.values()``, whose values are summed) and
            ``'steps'`` (anything supporting ``len``).

    Returns:
        ``(episodes, rewards_history, steps_history)`` as lists that have
        the same order and length as ``logs``. ``rewards_history`` holds
        the summed rewards and ``steps_history`` holds the number of steps
        for each record.
    """
    # One pass over ``logs``, so a one-shot iterator works too. Tuple
    # elements evaluate left to right, so the keys are read in the order
    # 'episode', 'total_rewards', 'steps'.
    rows = [
        (
            record['episode'],
            sum(record['total_rewards'].values()),
            len(record['steps']),
        )
        for record in logs
    ]
    episodes = [row[0] for row in rows]
    rewards_history = [row[1] for row in rows]
    steps_history = [row[2] for row in rows]
    return episodes, rewards_history, steps_history
# Split the loaded records into per-episode series for plotting.
episodes, rewards_history, steps_history = extract_metrics_from_logs(
    logs=logs
)
def plot_metrics(episodes, rewards_history, steps_history):
    """Display reward and step curves in a two-row matplotlib figure.

    The top panel plots ``rewards_history`` and the bottom panel plots
    ``steps_history``, each against ``episodes``. The figure is shown with
    ``plt.show()``; whether that call blocks depends on the active
    matplotlib backend and interactive mode.

    Args:
        episodes: x-axis values shared by both panels.
        rewards_history: y values for the top ("Total Rewards") panel.
        steps_history: y values for the bottom ("Total Steps") panel.
    """
    fig, (reward_ax, steps_ax) = plt.subplots(2, 1, figsize=(12, 6))

    # Each row: (axes, y data, line label, colour, y-axis label, title).
    panels = (
        (reward_ax, rewards_history, 'Total Rewards', 'blue',
         'Rewards', 'Total Rewards Over Episodes'),
        (steps_ax, steps_history, 'Total Steps', 'orange',
         'Steps', 'Total Steps Over Episodes'),
    )
    for ax, series, line_label, colour, y_label, title in panels:
        ax.plot(episodes, series, label=line_label, color=colour)
        ax.set_xlabel('Episodes')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        # No-argument grid() toggles the grid, same as the pyplot call.
        ax.grid()

    fig.tight_layout()
    plt.show()
# Draw and display both training curves.
plot_metrics(
    episodes=episodes,
    rewards_history=rewards_history,
    steps_history=steps_history,
)