-
Notifications
You must be signed in to change notification settings - Fork 0
Implicit fixed size #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Implicit fixed size #107
Changes from all commits
9b30f59
61c1441
15481a4
22fe21d
fe47df1
cab9876
e976de9
0da6fa3
80a7b79
273f561
dffd19f
ea75d33
d89be44
71a0015
ab3f3c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| #!/bin/python | ||
|
|
||
| import os | ||
| import glob | ||
| import numpy as np | ||
|
|
||
| import torch | ||
| from torch.utils import data | ||
| import matplotlib.pyplot as plt | ||
|
|
||
|
|
||
class Dataset(data.Dataset):
    """Dataset of .npy example files found recursively under a directory.

    Each item is the stored array with its last two axes swapped, i.e. a
    (C, W, H) array is returned as (C, H, W).
    """

    def __init__(self, path):
        # Walk the tree under `path`; keep only regular files.
        pattern = os.path.join(path, "**")
        candidates = glob.glob(pattern, recursive=True)
        self.examples = [p for p in candidates if os.path.isfile(p)]
        # Debug aid: show one discovered example path.
        print(self.examples[0])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # assumes each file holds a 3-D array — TODO confirm against the
        # data-generation side.
        arr = np.load(self.examples[index])
        return arr.transpose(0, 2, 1)
|
|
||
|
|
||
if __name__ == "__main__":
    # Smoke test: load one batch and show its shape, then stop.
    dataset = Dataset("../data/with_100_samples/")
    loader = data.DataLoader(
        dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    for batch in loader:
        print(batch.shape)
        exit()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| #!/bin/python3 | ||
|
|
||
| #!/bin/python | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
|
|
||
|
|
||
class ResidualBlock(nn.Module):
    """Pre-activation 1-D residual block.

    Computes ``x + conv2(relu(conv1(relu(x))))``, optionally with batch norm
    before each activation.  Padding is chosen so odd kernel sizes preserve
    the temporal length, keeping the residual sum well-shaped.
    """

    def __init__(self, num_channels, kernel_size, dilation, bn=False):
        super(ResidualBlock, self).__init__()
        self.num_channels = num_channels
        self.bn = bn
        # "Same" padding for odd kernel sizes at this dilation.
        same_pad = dilation * (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(
            num_channels,
            num_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=same_pad)
        self.conv2 = nn.Conv1d(
            num_channels,
            num_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=same_pad)
        if self.bn:
            self.bn1 = nn.BatchNorm1d(num_channels)
            self.bn2 = nn.BatchNorm1d(num_channels)

    def forward(self, x):
        # First pre-activation (optionally normalized) and conv.
        h = self.bn1(x) if self.bn else x
        h = self.conv1(F.relu(h))

        # Second pre-activation (optionally normalized) and conv.
        if self.bn:
            h = self.bn2(h)
        h = self.conv2(F.relu(h))

        # Identity shortcut.
        return x + h
|
|
||
|
|
||
class Path(nn.Module):
    """One convolutional feature-extraction path.

    Projects the input to ``num_channels[0]`` with an initial conv, then runs
    ``len(num_channels) - 1`` stages of (ResidualBlock -> avg-pool /2 ->
    widening conv + relu).  Each stage halves the temporal length, so the
    output length is ``input_len // 2 ** (len(num_channels) - 1)``.
    """

    def __init__(self,
                 in_channels=6,
                 num_channels=(32, 32, 64, 128),
                 kernel_size=3,
                 dilations=(1, 1, 1),
                 num_components=1,
                 bn=False):
        """Build the path.

        Args:
            in_channels: channels of the input sequence.
            num_channels: per-stage channel widths; defines len - 1 stages.
            kernel_size: conv kernel width (odd values preserve length).
            dilations: dilation per residual block; must have
                ``len(num_channels) - 1`` entries.  FIX: the old default
                ``(1, 1, 1, 1)`` had 4 entries while the default
                ``num_channels`` implies 3 blocks, so ``Path()`` with all
                defaults always raised ValueError.
            num_components: unused; kept for interface compatibility.
            bn: apply batch norm inside the residual blocks.

        Raises:
            ValueError: if ``len(dilations) != len(num_channels) - 1``.
        """
        super(Path, self).__init__()
        num_blocks = len(num_channels) - 1
        if len(dilations) != num_blocks:
            msg = ("Number of dilations must be equal to number of residual "
                   "blocks.")
            raise ValueError(msg)

        self.bn = bn
        self.num_blocks = num_blocks

        # Input projection; "same" padding for odd kernel sizes.
        self.conv1 = nn.Conv1d(
            in_channels,
            num_channels[0],
            kernel_size=kernel_size,
            dilation=1,
            padding=(kernel_size - 1) // 2)

        self.blocks = nn.ModuleList([
            ResidualBlock(
                num_channels[i],
                kernel_size=kernel_size,
                dilation=dilations[i],
                bn=self.bn) for i in range(self.num_blocks)
        ])

        # Channel-widening convs applied after each pooled residual block.
        self.convs = nn.ModuleList([
            nn.Conv1d(
                num_channels[i],
                num_channels[i + 1],
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2)
            for i in range(self.num_blocks)
        ])

        if self.bn:
            # FIX: was nn.BatchNorm1d(num_channels) — passing the whole
            # sequence raised TypeError whenever bn=True.  NOTE(review):
            # bn_out is never used in forward(); kept so existing
            # state-dicts with bn=True remain loadable — confirm intent.
            self.bn_out = nn.BatchNorm1d(num_channels[-1])

    def forward(self, x):
        """Return features of shape (B, num_channels[-1], L // 2**num_blocks)."""
        tap = self.conv1(x)

        for i in range(self.num_blocks):
            # Residual block, then halve the temporal resolution.
            tap = F.avg_pool1d(self.blocks[i](tap), 2, 2)
            tap = F.relu(self.convs[i](tap))

        return tap
|
|
||
|
|
||
class Discriminator(nn.Module):
    """Three-path convolutional discriminator with a spectral-normalized head.

    The input is split along dim 1 into three (B, in_channels, seq_len)
    sequences, each processed by its own Path; the concatenated, flattened
    features feed a single linear logit.
    """

    def __init__(self, in_channels=2, num_channels=(32, 32, 32, 32, 32),
                 seq_len=640):
        """Build the discriminator.

        Args:
            in_channels: channels of each of the three input sequences.
            num_channels: channel widths handed to each Path.
            seq_len: temporal length of the input sequences.  FIX: the linear
                layer's input size was hard-coded as
                ``num_channels[-1] * 3 * 40``, an implicit fixed size that
                only matched 640-sample inputs; it is now derived from
                ``seq_len`` (the default of 640 reproduces the old sizes).
        """
        super(Discriminator, self).__init__()
        self.a = Path(in_channels=in_channels, num_channels=num_channels)
        self.b = Path(in_channels=in_channels, num_channels=num_channels)
        self.c = Path(in_channels=in_channels, num_channels=num_channels)

        # Each Path halves the length once per residual stage.
        feat_len = seq_len // 2 ** (len(num_channels) - 1)
        self.fc = nn.Linear(num_channels[-1] * 3 * feat_len, 1)

        # Spectral-normalize every module that directly owns a 'weight'
        # parameter (the convs and the linear head).  Uses the public
        # named_parameters(recurse=False) instead of the old private
        # m._parameters lookup.
        for m in self.modules():
            if 'weight' in dict(m.named_parameters(recurse=False)):
                nn.utils.spectral_norm(m)

    def forward(self, x):
        """Map a (B, 3, in_channels, seq_len) input to (B, 1) logits."""
        a = x[:, 0, ...]
        b = x[:, 1, ...]
        c = x[:, 2, ...]

        # Per-path features, concatenated along channels and flattened.
        feats = torch.cat([self.a(a), self.b(b), self.c(c)],
                          dim=1).reshape(a.shape[0], -1)

        logits = self.fc(feats)

        return logits
|
|
||
|
|
||
if __name__ == '__main__':
    # Smoke test: run one forward pass on a random batch.
    # NOTE(review): the default Discriminator head appears sized for
    # 640-sample sequences, while this feeds length 160 — confirm the
    # intended input length.
    model = Discriminator()
    model(torch.randn(32, 3, 2, 160))
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.