Skip to content

Commit

Permalink
editing example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Isha Puri committed Sep 27, 2022
1 parent c0ba34f commit 4744d36
Showing 1 changed file with 14 additions and 35 deletions.
49 changes: 14 additions & 35 deletions examples/cofrnet/cofrnet_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@
" order 13 interactions. \n",
"'''\n",
"network_depth = 13\n",
"input_size = 30\n",
"output_size = 2 \n",
"input_size = 40\n",
"output_size = 3\n",
"cofrnet_version = \"diag_ladder_of_ladder_combined\"\n",
"#Create CoFrNet\n",
"model = CoFrNet_Model(generate_connections(network_depth, \n",
" input_size, \n",
" output_size, \n",
"model = CoFrNet_Model(generate_connections(network_depth,\n",
" input_size,\n",
" output_size,\n",
" cofrnet_version))\n"
]
},
Expand All @@ -99,37 +98,17 @@
"metadata": {},
"outputs": [],
"source": [
"data = load_breast_cancer()\n",
"X = torch.from_numpy(data['data'])\n",
"y = torch.from_numpy(data['target'])\n",
"from sklearn.preprocessing import LabelEncoder\n",
"le = LabelEncoder()\n",
"y = le.fit_transform(y)\n",
"from sklearn.preprocessing import StandardScaler\n",
"sc = MinMaxScaler(feature_range=(0,1))\n",
"X = sc.fit_transform(X)\n",
"X.argmax()\n",
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(X,\n",
" y,\n",
" test_size = 0.3,\n",
" random_state = 100,\n",
" shuffle = True)\n",
"X_train, X_val, y_train, y_val = train_test_split(X_train,\n",
" y_train,\n",
" test_size=0.05,\n",
" random_state=100,\n",
" shuffle = True)\n",
"#CONVERTING TO TENSOR\n",
"tensor_x_train = torch.Tensor(X_train)\n",
"tensor_x_val = torch.Tensor(X_val)\n",
"tensor_x_test = torch.Tensor(X_test)\n",
"tensor_y_val = torch.Tensor(y_val).long()\n",
"tensor_y_train = torch.Tensor(y_train).long()\n",
"tensor_y_test = torch.Tensor(y_test).long()\n",
"first_column_csv = 0\n",
"last_column_csv = -1\n",
"\n",
"\n",
"web_link = 'http://www.dropbox.com/s/qtdv1teptf097zl/waveformnoise.csv?dl=1'\n",
"tensor_x_train, tensor_y_train, tensor_x_val, tensor_y_val, tensor_x_test, y_test = process_data(first_column_csv = first_column_csv, \n",
" last_column_csv = last_column_csv, \n",
" web_link=web_link)\n",
"\n",
"train_dataset = OnlyTabularDataset(tensor_x_train, \n",
" tensor_y_train)\n",
" tensor_y_train)\n",
"\n",
"batch_size = 100\n",
"dataloader = DataLoader(train_dataset, batch_size) "
Expand Down

0 comments on commit 4744d36

Please sign in to comment.