Skip to content

Commit

Permalink
Add dataset and plot in lib
Browse files Browse the repository at this point in the history
  • Loading branch information
Cylumn committed Aug 2, 2022
1 parent 186d03c commit 4aa6088
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 54 deletions.
Binary file removed data/brushes/32.jpg
Binary file not shown.
1 change: 0 additions & 1 deletion data/brushes/brushes.csv

Large diffs are not rendered by default.

66 changes: 46 additions & 20 deletions lib/dataset.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,59 @@
from torch.utils.data import Dataset
"""
Pytorch dataset for synthetic AKR.
Author: Allen Chang
Date Created: 08/02/2022
"""

import torch
from torch.utils.data import Dataset
from sklearn.preprocessing import StandardScaler

from data import simulate


class AKRDataset(Dataset):
def __init__(self, simulate, n_dataset, brush_range=(None, None), scale: bool = True):
self.img_true = []
self.img_perturbed = []
self.simulate = simulate
brushes = simulate.read_brushes('data', False)[brush_range[0]:brush_range[1]]
"""
Dataset class for loading and scaling synthetic AKR.
"""
def __init__(self, n_dataset, args):
"""
Initializes the dataset class.
:param n_dataset: Size of the dataset.
:param args: Command line arguments.
"""
self.ground_truths = []
self.observations = []
brushes = simulate.read_brushes(args)

for i in range(n_dataset):
y = simulate.ground_truth(brushes)
x = simulate.noise(simulate.noise(y))
for _ in range(n_dataset):
y = simulate.ground_truth(brushes, args)
x = simulate.noise(y, args)

if scale:
sc = StandardScaler()
shape = y.shape
x = torch.tensor(sc.fit_transform(x.reshape(-1, 1)).reshape(-1, shape[0], shape[1]) * 0.2 + 0.2).float()
y = torch.tensor(sc.transform(y.reshape(-1, 1)).reshape(-1, shape[0], shape[1]) * 0.2 + 0.2).float()
# Reshape and scale
if args.disable_dataset_scaling:
x = x[None, :, :]
y = y[None, :, :]
else:
x = torch.tensor(x[None, :, :]).float()
y = torch.tensor(y[None, :, :]).float()
# Scale to N(0, 1)
sc = StandardScaler()
x = sc.fit_transform(x.reshape(-1, 1)).reshape(-1, args.img_size[0], args.img_size[1])
y = sc.transform(y.reshape(-1, 1)).reshape(-1, args.img_size[0], args.img_size[1])
# Rescale with mean and std
x = x * args.dataset_intensity_scale[1] + args.dataset_intensity_scale[0]
y = y * args.dataset_intensity_scale[1] + args.dataset_intensity_scale[0]

self.img_true.append(y)
self.img_perturbed.append(x)
# Add to dataset
self.ground_truths.append(torch.tensor(y).float())
self.observations.append(torch.tensor(x).float())

def __len__(self):
return len(self.img_true)
"""
:return: The length of the AKR dataset.
"""
return len(self.ground_truths)

def __getitem__(self, idx):
return self.img_perturbed[idx], self.img_true[idx]
"""
:return: A tuple of the (x, y) of the dataset.
"""
return self.observations[idx], self.ground_truths[idx]
84 changes: 54 additions & 30 deletions lib/plot.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,92 @@
from datetime import datetime
import numpy as np
import calendar

"""
Helper file to plot spectrograms.
Author: Allen Chang
Date Created: 08/02/2022
"""

def spect_simple(tensor, frequency=None, time=None, start=None, cmap='jet', dpi=400, fs=4, vmin=0, vmax=1):
import matplotlib.pyplot as plt
import matplotlib
import numpy as np

font = {'family': 'DejaVu Sans',
'weight': 'normal',
'size': fs}
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def spect_simple(data,
frequency=None,
time=None,
start=None,
cmap='jet',
dpi=400,
fs=4,
vmin=0,
vmax=1):
"""
Visualizes a simple spectrogram.
:param data: Input spectrogram.
:param frequency: An array of tick values for the frequency axis of the spectrogram.
:param time: An array of the unix time values for the time axis of the spectrogram.
:param start: Starting unix time for the title of the spectrogram.
:param cmap: Colormap to use for the spectrogram.
:param dpi: Resolution of the figure.
:param fs: Font size.
:param vmin: Lower value bound of the cmap.
:param vmax: Upper value bound of the cmap.
:return: Matplotlib axis object.
"""
# Font to use
font = {'family': 'DejaVu Sans', 'weight': 'normal', 'size': fs}
matplotlib.rc('font', **font)

# Create plot
fig, ax = plt.subplots(figsize=(5, 5), dpi=dpi)

spect = [np.arange(0, 1e6, 1e6 / tensor.shape[0]), np.arange(0, 1e3, 1e3 / tensor.shape[1]), tensor]

spect = [np.arange(0, 1e6, 1e6 / data.shape[0]), np.arange(0, 1e3, 1e3 / data.shape[1]), data]
im = ax.imshow(spect[2], cmap=cmap, origin='lower', vmin=vmin, vmax=vmax, interpolation="nearest")

from mpl_toolkits.axes_grid1 import make_axes_locatable
# Colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

plt.colorbar(im, cax=cax)

# Axis and Labeling
if frequency is not None:
ax.set_yticks(np.arange(0, len(frequency), len(frequency) // 8))
ax.set_yticklabels(frequency[::len(frequency) // 8])
if time is not None:
import time as t
from datetime import datetime
assert start is not None

assert (start is not None), "Time should not be plotted without a starting time."

start_struct = t.gmtime(int(start))
x_ticks = np.arange(0, len(time), len(time) // 5)
ax.set_xticks(x_ticks)
ax.set_xticklabels([datetime.utcfromtimestamp(np.real(x)).strftime('%H:%M:%S') for x in time[x_ticks]])
ax.set_title(t.strftime('%Y-%m-%d %H:%M:%S UTC', start_struct), fontsize=5)

plt.show()

return ax


def spects(rows, nrows_ncols, img_shape=(256, 384)):
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
def spects(rows,
nrows_ncols):
"""
Plots multiple spectrograms in an ImageGrid.
:param rows: List of tensors to plot.
:param nrows_ncols: Number of rows, number of columns to plot.
"""
import torch
from mpl_toolkits.axes_grid1 import ImageGrid

# Create figure
fig = plt.figure(figsize=(20., 20.))
rows = [(col.detach() if type(col) == torch.Tensor else torch.tensor(col)) for col in rows]
img_shape = rows[0].shape[-2:]
axis = 0 if (nrows_ncols[0] == 1 or nrows_ncols[1] == 1) else 1
rows = torch.cat(rows, axis=axis).view(nrows_ncols[0], nrows_ncols[1], img_shape[0], img_shape[1])
grid = ImageGrid(fig, 111,
nrows_ncols=nrows_ncols, # creates 2x2 grid of axes
axes_pad=0.1, # pad between axes
)
grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols, axes_pad=0.1)

# Place the images on the grid
df_grid = rows.reshape(-1, rows.shape[-2], rows.shape[-1])
for ax, im in zip(grid, df_grid):
if len(im.shape) > 2:
im = im.reshape(im.shape[-2:])
ax.imshow(im, origin='lower', cmap='jet', vmin=0, vmax=1)

plt.show()
plt.show()
7 changes: 5 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
argparse
digitalrf
digital_rf
imageio
matplotlib
numpy
pandas
pandas
scikit-learn
torch
9 changes: 8 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ def get_args():
# Options
parser.add_argument('--model_name', default='daare_v1', type=str, help='Name of the model when logging and saving.')
parser.add_argument('--verbose', action='store_true', help='Trains with debugging outputs and print statements.')
parser.add_argument('--disable_logs', action='store_false', help='Disables logging to the output log directory.')
parser.add_argument('--disable_logs', action='store_true', help='Disables logging to the output log directory.')
parser.add_argument('--refresh_brushes_file', action='store_true',
help='Rereads brush images and saves them to data/brushes.csv')

# Simulation parameters
# > Ground truth
parser.add_argument('--theta_bg_intensity', default=[0, 0.6], type=float, nargs=2,
help='Bounds of the uniform distribution to draw background intensity.')
parser.add_argument('--theta_n_akr', default=8, type=int,
help='Expected number of akr from the poisson distribution.')
parser.add_argument('--theta_akr_intensity', default=[0, 0.15], type=float, nargs=2,
help='(Before absolute value) mean and std of AKR intensity.')
# > Noise
parser.add_argument('--theta_gaussian_intensity', default=[0.01, 0.04], type=float, nargs=2,
help='Bounds of the uniform distribution to determine the intensity of gaussian noise.')
parser.add_argument('--theta_overall_channel_intensity', default=[0.3, 0.6], type=float, nargs=2,
Expand All @@ -40,6 +42,11 @@ def get_args():
help='Expected *half* height of the channel from the exponential distribution.')
parser.add_argument('--theta_channel_intensity', default=[0.1, 0.8], type=float, nargs=2,
help='Bounds of the uniform distribution to determine the individual intensity of channels.')
# > Simulation scaling
parser.add_argument('--disable_dataset_scaling', action='store_true',
help='Disables scaling of synthetic AKR in the dataset.')
parser.add_argument('--dataset_intensity_scale', default=[0.2, 0.2], type=float, nargs=2,
help='Mean and standard deviation to scale the images to.')

# Model parameters
parser.add_argument('--img_size', default=[256, 384], type=int, nargs=2, help='Input size to DAARE.')
Expand Down

0 comments on commit 4aa6088

Please sign in to comment.