Skip to content

Commit b21f3d0

Browse files
committed
Add all CSW related files for EGO Model
1 parent dc62af6 commit b21f3d0

File tree

5 files changed

+85
-1
lines changed

5 files changed

+85
-1
lines changed
File renamed without changes.

Diff for: test_models/EGO Model - CSW with Simple Integrator.py renamed to test_models/CSW/EGO Model - CSW with Simple Integrator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import matplotlib.pyplot as plt
1212
import torch
1313
import TestParams
14-
import DeclanParams
14+
import test_models.CSW.DeclanParams as DeclanParams
1515
import timeit
1616
import psyneulink as pnl
1717
torch.manual_seed(0)

Diff for: test_models/CSW/Environment.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import numpy as np
2+
import torch
3+
from torch.utils.data import dataset
4+
from torch import utils
5+
from random import randint
6+
7+
def one_hot_encode(labels, num_classes):
8+
"""
9+
One hot encode labels and convert to tensor.
10+
"""
11+
return torch.tensor((np.arange(num_classes) == labels[..., None]).astype(float),dtype=torch.float32)
12+
13+
class DeterministicCSWDataset(dataset.Dataset):
14+
def __init__(self, n_samples_per_context, contexts_to_load) -> None:
15+
super().__init__()
16+
raw_xs = np.array([
17+
[[9,1,3,5,7],[9,2,4,6,8]],
18+
[[10,1,4,5,8],[10,2,3,6,7]]
19+
])
20+
21+
item_indices = np.random.choice(raw_xs.shape[1],sum(n_samples_per_context),replace=True)
22+
task_names = [0,1] # Flexible so these can be renamed later
23+
task_indices = [task_names.index(name) for name in contexts_to_load]
24+
25+
context_indices = np.repeat(np.array(task_indices),n_samples_per_context)
26+
self.xs = one_hot_encode(raw_xs[context_indices,item_indices],11)
27+
28+
self.xs = self.xs.reshape((-1,11))
29+
self.ys = torch.cat([self.xs[1:],one_hot_encode(np.array([0]),11)],dim=0)
30+
context_indices = np.repeat(np.array(task_indices),[x*5 for x in n_samples_per_context])
31+
self.contexts = one_hot_encode(context_indices, len(task_names))
32+
33+
# Remove the last transition since there's no next state available
34+
self.xs = self.xs[:-1]
35+
self.ys = self.ys[:-1]
36+
self.contexts = self.contexts[:-1]
37+
38+
def __len__(self):
39+
return len(self.xs)
40+
41+
def __getitem__(self, idx):
42+
return self.xs[idx], self.contexts[idx], self.ys[idx]
43+
44+
def generate_dataset(condition='Blocked'):
45+
# Generate the dataset for either the blocked or interleaved condition
46+
if condition=='Blocked':
47+
contexts_to_load = [0,1,0,1] + [randint(0,1) for _ in range(40)]
48+
n_samples_per_context = [40,40,40,40] + [1]*40
49+
elif condition == 'Interleaved':
50+
contexts_to_load = [0,1]*80 + [randint(0,1) for _ in range(40)]
51+
n_samples_per_context = [1]*160 + [1]*40
52+
else:
53+
raise ValueError(f'Unknown dataset condition: {condition}')
54+
55+
return DeterministicCSWDataset(n_samples_per_context, contexts_to_load)

Diff for: test_models/CSW/ScriptControl.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from psyneulink.core.compositions.report import ReportOutput, ReportProgress
2+
3+
# Settings for running script:
4+
5+
MODEL_PARAMS = 'TestParams'
6+
# MODEL_PARAMS = 'DeclanParams'
7+
8+
CONSTRUCT_MODEL = True # THIS MUST BE SET TO True to run the script
9+
DISPLAY_MODEL = ( # Only one of the following can be uncommented:
10+
None # suppress display of model
11+
# { # show simple visual display of model
12+
# 'show_pytorch': True, # show pytorch graph of model
13+
# 'show_learning': True
14+
# # 'show_projections_not_in_composition': True,
15+
# # 'exclude_from_gradient_calc_style': 'dashed'# show target mechanisms for learning
16+
# # {'show_node_structure': True # show detailed view of node structures and projections
17+
# }
18+
)
19+
# RUN_MODEL = False # False => don't run the model
20+
RUN_MODEL = True, # True => run the model
21+
# REPORT_OUTPUT = ReportOutput.FULL # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
22+
REPORT_OUTPUT = ReportOutput.OFF # Sets console output during run [ReportOutput.ON, .TERSE OR .FULL]
23+
REPORT_PROGRESS = ReportProgress.OFF # Sets console progress bar during run
24+
PRINT_RESULTS = False # don't print model.results to console after execution
25+
# PRINT_RESULTS = True # print model.results to console after execution
26+
SAVE_RESULTS = False # save model.results to disk
27+
# PLOT_RESULTS = False # don't plot results (PREDICTIONS) vs. TARGETS
28+
PLOT_RESULTS = True # plot results (PREDICTIONS) vs. TARGETS
29+
ANIMATE = False # {UNIT:EXECUTION_SET} # Specifies whether to generate animation of execution
File renamed without changes.

0 commit comments

Comments
 (0)