-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
113 additions
and
54 deletions.
There are no files selected for viewing
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,59 @@ | ||
from torch.utils.data import Dataset | ||
""" | ||
Pytorch dataset for synthetic AKR. | ||
Author: Allen Chang | ||
Date Created: 08/02/2022 | ||
""" | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
from data import simulate | ||
|
||
|
||
class AKRDataset(Dataset): | ||
def __init__(self, simulate, n_dataset, brush_range=(None, None), scale: bool = True): | ||
self.img_true = [] | ||
self.img_perturbed = [] | ||
self.simulate = simulate | ||
brushes = simulate.read_brushes('data', False)[brush_range[0]:brush_range[1]] | ||
""" | ||
Dataset class for loading and scaling synthetic AKR. | ||
""" | ||
def __init__(self, n_dataset, args): | ||
""" | ||
Initializes the dataset class. | ||
:param n_dataset: Size of the dataset. | ||
:param args: Command line arguments. | ||
""" | ||
self.ground_truths = [] | ||
self.observations = [] | ||
brushes = simulate.read_brushes(args) | ||
|
||
for i in range(n_dataset): | ||
y = simulate.ground_truth(brushes) | ||
x = simulate.noise(simulate.noise(y)) | ||
for _ in range(n_dataset): | ||
y = simulate.ground_truth(brushes, args) | ||
x = simulate.noise(y, args) | ||
|
||
if scale: | ||
sc = StandardScaler() | ||
shape = y.shape | ||
x = torch.tensor(sc.fit_transform(x.reshape(-1, 1)).reshape(-1, shape[0], shape[1]) * 0.2 + 0.2).float() | ||
y = torch.tensor(sc.transform(y.reshape(-1, 1)).reshape(-1, shape[0], shape[1]) * 0.2 + 0.2).float() | ||
# Reshape and scale | ||
if args.disable_dataset_scaling: | ||
x = x[None, :, :] | ||
y = y[None, :, :] | ||
else: | ||
x = torch.tensor(x[None, :, :]).float() | ||
y = torch.tensor(y[None, :, :]).float() | ||
# Scale to N(0, 1) | ||
sc = StandardScaler() | ||
x = sc.fit_transform(x.reshape(-1, 1)).reshape(-1, args.img_size[0], args.img_size[1]) | ||
y = sc.transform(y.reshape(-1, 1)).reshape(-1, args.img_size[0], args.img_size[1]) | ||
# Rescale with mean and std | ||
x = x * args.dataset_intensity_scale[1] + args.dataset_intensity_scale[0] | ||
y = y * args.dataset_intensity_scale[1] + args.dataset_intensity_scale[0] | ||
|
||
self.img_true.append(y) | ||
self.img_perturbed.append(x) | ||
# Add to dataset | ||
self.ground_truths.append(torch.tensor(y).float()) | ||
self.observations.append(torch.tensor(x).float()) | ||
|
||
def __len__(self): | ||
return len(self.img_true) | ||
""" | ||
:return: The length of the AKR dataset. | ||
""" | ||
return len(self.ground_truths) | ||
|
||
def __getitem__(self, idx): | ||
return self.img_perturbed[idx], self.img_true[idx] | ||
""" | ||
:return: A tuple of the (x, y) of the dataset. | ||
""" | ||
return self.observations[idx], self.ground_truths[idx] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,92 @@ | ||
from datetime import datetime | ||
import numpy as np | ||
import calendar | ||
|
||
""" | ||
Helper file to plot spectrograms. | ||
Author: Allen Chang | ||
Date Created: 08/02/2022 | ||
""" | ||
|
||
def spect_simple(tensor, frequency=None, time=None, start=None, cmap='jet', dpi=400, fs=4, vmin=0, vmax=1): | ||
import matplotlib.pyplot as plt | ||
import matplotlib | ||
import numpy as np | ||
|
||
font = {'family': 'DejaVu Sans', | ||
'weight': 'normal', | ||
'size': fs} | ||
import numpy as np | ||
import matplotlib | ||
import matplotlib.pyplot as plt | ||
from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
|
||
|
||
def spect_simple(data, | ||
frequency=None, | ||
time=None, | ||
start=None, | ||
cmap='jet', | ||
dpi=400, | ||
fs=4, | ||
vmin=0, | ||
vmax=1): | ||
""" | ||
Visualizes a simple spectrogram. | ||
:param data: Input spectrogram. | ||
:param frequency: An array of tick values for the frequency axis of the spectrogram. | ||
:param time: An array of the unix time values for the time axis of the spectrogram. | ||
:param start: Starting unix time for the title of the spectrogram. | ||
:param cmap: Colormap to use for the spectrogram. | ||
:param dpi: Resolution of the figure. | ||
:param fs: Font size. | ||
:param vmin: Lower value bound of the cmap. | ||
:param vmax: Upper value bound of the cmap. | ||
:return: Matplotlib axis object. | ||
""" | ||
# Font to use | ||
font = {'family': 'DejaVu Sans', 'weight': 'normal', 'size': fs} | ||
matplotlib.rc('font', **font) | ||
|
||
# Create plot | ||
fig, ax = plt.subplots(figsize=(5, 5), dpi=dpi) | ||
|
||
spect = [np.arange(0, 1e6, 1e6 / tensor.shape[0]), np.arange(0, 1e3, 1e3 / tensor.shape[1]), tensor] | ||
|
||
spect = [np.arange(0, 1e6, 1e6 / data.shape[0]), np.arange(0, 1e3, 1e3 / data.shape[1]), data] | ||
im = ax.imshow(spect[2], cmap=cmap, origin='lower', vmin=vmin, vmax=vmax, interpolation="nearest") | ||
|
||
from mpl_toolkits.axes_grid1 import make_axes_locatable | ||
# Colorbar | ||
divider = make_axes_locatable(ax) | ||
cax = divider.append_axes("right", size="5%", pad=0.05) | ||
|
||
plt.colorbar(im, cax=cax) | ||
|
||
# Axis and Labeling | ||
if frequency is not None: | ||
ax.set_yticks(np.arange(0, len(frequency), len(frequency) // 8)) | ||
ax.set_yticklabels(frequency[::len(frequency) // 8]) | ||
if time is not None: | ||
import time as t | ||
from datetime import datetime | ||
assert start is not None | ||
|
||
assert (start is not None), "Time should not be plotted without a starting time." | ||
|
||
start_struct = t.gmtime(int(start)) | ||
x_ticks = np.arange(0, len(time), len(time) // 5) | ||
ax.set_xticks(x_ticks) | ||
ax.set_xticklabels([datetime.utcfromtimestamp(np.real(x)).strftime('%H:%M:%S') for x in time[x_ticks]]) | ||
ax.set_title(t.strftime('%Y-%m-%d %H:%M:%S UTC', start_struct), fontsize=5) | ||
|
||
plt.show() | ||
|
||
return ax | ||
|
||
|
||
def spects(rows, nrows_ncols, img_shape=(256, 384)): | ||
import matplotlib.pyplot as plt | ||
from mpl_toolkits.axes_grid1 import ImageGrid | ||
def spects(rows, | ||
nrows_ncols): | ||
""" | ||
Plots multiple spectrograms in an ImageGrid. | ||
:param rows: List of tensors to plot. | ||
:param nrows_ncols: Number of rows, number of columns to plot. | ||
""" | ||
import torch | ||
from mpl_toolkits.axes_grid1 import ImageGrid | ||
|
||
# Create figure | ||
fig = plt.figure(figsize=(20., 20.)) | ||
rows = [(col.detach() if type(col) == torch.Tensor else torch.tensor(col)) for col in rows] | ||
img_shape = rows[0].shape[-2:] | ||
axis = 0 if (nrows_ncols[0] == 1 or nrows_ncols[1] == 1) else 1 | ||
rows = torch.cat(rows, axis=axis).view(nrows_ncols[0], nrows_ncols[1], img_shape[0], img_shape[1]) | ||
grid = ImageGrid(fig, 111, | ||
nrows_ncols=nrows_ncols, # creates 2x2 grid of axes | ||
axes_pad=0.1, # pad between axes | ||
) | ||
grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols, axes_pad=0.1) | ||
|
||
# Place the images on the grid | ||
df_grid = rows.reshape(-1, rows.shape[-2], rows.shape[-1]) | ||
for ax, im in zip(grid, df_grid): | ||
if len(im.shape) > 2: | ||
im = im.reshape(im.shape[-2:]) | ||
ax.imshow(im, origin='lower', cmap='jet', vmin=0, vmax=1) | ||
|
||
plt.show() | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
argparse | ||
digitalrf | ||
digital_rf | ||
imageio | ||
matplotlib | ||
numpy | ||
pandas | ||
pandas | ||
scikit-learn | ||
torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters