Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added Monocraft.ttf
Binary file not shown.
16 changes: 8 additions & 8 deletions train.odin
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ load_mnist_data :: proc(path: string, size: int) -> (ret: [dynamic]MnistRecord,
if ferr != 0 do return
defer os.close(f)

r: bufio.Reader
r: bufio.Reader
buffer: [1024]byte
bufio.reader_init_with_buf(&r, os.stream_from_handle(f), buffer[:])
defer bufio.reader_destroy(&r)
Expand All @@ -26,18 +26,18 @@ load_mnist_data :: proc(path: string, size: int) -> (ret: [dynamic]MnistRecord,

i := 0
ret = make([dynamic]MnistRecord, size)
for {
defer i += 1
for {
defer i += 1
line, err := bufio.reader_read_string(&r, '\n', context.temp_allocator)
if err != nil || i >= size - 1 {
break
}
if err != nil || i >= size - 1 {
break
}

// Process line
values := split_u8_string(line)
ret[i].label = values[0]
for j in 1..=MNIST_IMG_DATA_LEN {
ret[i].pixels[j-1] = f32 (values[j]) / 255.0
for j in 0..<MNIST_IMG_DATA_LEN {
ret[i].pixels[j] = f32 (values[j]) / 255.0
}
}

Expand Down
237 changes: 207 additions & 30 deletions viz.odin
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package main
import "core:c"
import "core:fmt"
import "core:math"
import "core:math/linalg"
import sort "core:sort"
import rand "core:math/rand"
import rl "vendor:raylib"
Expand All @@ -16,8 +17,10 @@ COLOR_ACTIVATION :: rl.ORANGE
COLOR_GRAD :: rl.GREEN

// MARK: Globals

g_font_20: rl.Font
g_font_30: rl.Font
g_camera3d: rl.Camera3D
g_saved_mouse_position: rl.Vector2
g_cam_angle: f32 = 0
g_img_input: MnistRecord
g_flags: Flags
Expand Down Expand Up @@ -75,14 +78,31 @@ SceneObject :: struct {
@private
viz_init :: proc() -> (err: bool) {
// Raylib init
rl.SetWindowState({.MSAA_4X_HINT})
rl.InitWindow(WINDOW_W, WINDOW_H, "NN")
rl.SetWindowState(rl.ConfigFlags{.WINDOW_RESIZABLE})
rl.SetTargetFPS(FPS)

// Font Setup
codepoints: [126-32+1]rune
for i in 32..=126 {
codepoints[i-32] = rune(i)
}
g_font_20 = rl.LoadFontEx("Monocraft.ttf", 20, &codepoints[0], len(codepoints))
g_font_30 = rl.LoadFontEx("Monocraft.ttf", 30, &codepoints[0], len(codepoints))

// Do not mem init the loaded net
net_err := net_load(&g_net)
if net_err do return true
if net_err do return true

// Init cam and settings
viz_init_transient()

return false
}

@private
viz_init_transient :: proc() {
// Cam setup
g_camera3d.position = rl.Vector3{0, 30, 0}
g_camera3d.target = rl.Vector3{0, 0, 0}
Expand All @@ -96,8 +116,6 @@ viz_init :: proc() -> (err: bool) {
g_flags.draw_cubes = true
g_flags.draw_cube_lines = true
g_flags.load_test_imgs = true

return false
}

@private
Expand All @@ -120,6 +138,119 @@ viz_update :: proc(test_img: ^MnistRecord) {
g_cam_angle += CAM_REVOLUTION_SPEED * rl.GetFrameTime()
g_camera3d.position.x = math.cos(g_cam_angle) * CAM_REVOLUTION_RADIUS
g_camera3d.position.z = math.sin(g_cam_angle) * CAM_REVOLUTION_RADIUS
} else {
// Camera Variables

dist := linalg.length(g_camera3d.position - g_camera3d.target)
norm := linalg.normalize(g_camera3d.position - g_camera3d.target)
cross := linalg.cross(norm, rl.Vector3{0, 1, 0})

// Orbital Camera Movement - Mouse

if rl.IsMouseButtonPressed(.MIDDLE) {
rl.HideCursor()
g_saved_mouse_position = rl.GetMousePosition()
}
if rl.IsMouseButtonDown(.MIDDLE) {
mouse_delta := rl.GetMouseDelta()
delta_x := mouse_delta.x * (1.0/f32(rl.GetScreenHeight()))
delta_y := mouse_delta.y * (1.0/f32(rl.GetScreenHeight()))

g_camera3d.position += delta_x * cross * 60
g_camera3d.position.y += delta_y * 60
new_norm := linalg.normalize(g_camera3d.position - g_camera3d.target)
g_camera3d.position = g_camera3d.target + dist * new_norm

rl.SetMousePosition(i32(g_saved_mouse_position.x), i32(g_saved_mouse_position.y))
}
if rl.IsMouseButtonReleased(.MIDDLE) {
rl.ShowCursor()
}

// Rotational Camera Movement - Mouse

if rl.IsMouseButtonPressed(.RIGHT) {
rl.HideCursor()
g_saved_mouse_position = rl.GetMousePosition()
}
if rl.IsMouseButtonDown(.RIGHT) {
mouse_delta := rl.GetMouseDelta()
delta_x := mouse_delta.x * (1.0/f32(rl.GetScreenHeight()))
delta_y := mouse_delta.y * (1.0/f32(rl.GetScreenHeight()))

rl.UpdateCameraPro(&g_camera3d, {}, {delta_x * 30, delta_y * 30, 0}, 0)

rl.SetMousePosition(i32(g_saved_mouse_position.x), i32(g_saved_mouse_position.y))
}
if rl.IsMouseButtonReleased(.RIGHT) {
rl.ShowCursor()
}

// Camera Zoom - Mouse

mouse_wheel := rl.GetMouseWheelMoveV()
rl.UpdateCameraPro(&g_camera3d, 0, 0, -mouse_wheel.y * 5)

// Camera Zoom - Keyboard
if rl.IsKeyDown(.PAGE_UP) {
rl.UpdateCameraPro(&g_camera3d, 0, 0, -1)
}

if rl.IsKeyDown(.PAGE_DOWN) {
rl.UpdateCameraPro(&g_camera3d, 0, 0, 1)
}

// Free Camera Movement - Keyboard

if rl.IsKeyDown(.W) {
rl.UpdateCameraPro(&g_camera3d, {1, 0, 0}, 0, 0)
}

if rl.IsKeyDown(.A) {
rl.UpdateCameraPro(&g_camera3d, {0, -1, 0}, 0, 0)
}

if rl.IsKeyDown(.S) {
rl.UpdateCameraPro(&g_camera3d, {-1, 0, 0}, 0, 0)
}

if rl.IsKeyDown(.D) {
rl.UpdateCameraPro(&g_camera3d, {0, 1, 0}, 0, 0)
}

if rl.IsKeyDown(.LEFT_SHIFT) {
rl.UpdateCameraPro(&g_camera3d, {0, 0, 1}, 0, 0)
}

if rl.IsKeyDown(.LEFT_CONTROL) {
rl.UpdateCameraPro(&g_camera3d, {0, 0, -1}, 0, 0)
}

if rl.IsKeyDown(.Z) {
g_camera3d.position.y -= 1
new_norm := linalg.normalize(g_camera3d.position - g_camera3d.target)
g_camera3d.position = g_camera3d.target + dist * new_norm
}

if rl.IsKeyDown(.X) {
g_camera3d.position.y += 1
new_norm := linalg.normalize(g_camera3d.position - g_camera3d.target)
g_camera3d.position = g_camera3d.target + dist * new_norm
}

if rl.IsKeyDown(.Q) {
g_camera3d.position += cross
}

if rl.IsKeyDown(.E) {
g_camera3d.position -= cross
}

// Reset

if rl.IsKeyPressed(.K) {
viz_init_transient()
}
}
if g_flags.load_test_imgs {
g_img_input.pixels = test_img.pixels
Expand Down Expand Up @@ -156,13 +287,20 @@ viz_update :: proc(test_img: ^MnistRecord) {
// MARK: Draw Root

draw_2d :: proc(pred_idx: int, pred_accuracy: f32) {
rl.DrawRectangle(0, 0, 285, rl.GetScreenHeight(), rl.Fade(rl.DARKGRAY, 0.8))
draw_settings()
draw_2d_image_input_grid(30, 250)
rl.DrawFPS(30, 550)
if ui_button("Clear Input", {30, 490}) {
g_img_input.pixels = {}
}
rl.DrawFPS(30, 590)
if ui_button("Reset", {30, 650}) {
viz_init_transient()
}

result := fmt.tprintf("RES: %d: %.2f%%", pred_idx, pred_accuracy * 100)
fps := fmt.tprintf("RES: %d: %.2f%%", pred_idx, pred_accuracy * 100)
rl.DrawText(cstring(raw_data(result)), 30, 500, 30, rl.WHITE)
rl.DrawTextEx(g_font_30, cstring(raw_data(result)), {30, 550}, 30, 0, rl.WHITE)

// // Draw inference results
// START_X :: 700
Expand Down Expand Up @@ -537,24 +675,16 @@ collect_output_layer_shapes :: proc(shapes: ^[dynamic]Shape, prediction_idx: int
// MARK: !! 2D !!

draw_settings :: proc() {
ui_checkbox("Rotate Cam", {30, 30}, g_flags.cam_rotate, proc() {
g_flags.cam_rotate = !g_flags.cam_rotate
})
ui_checkbox("Show Connections", {30, 60}, g_flags.draw_connections, proc() {
g_flags.draw_connections = !g_flags.draw_connections
})
ui_checkbox("Show Cube Lines", {30, 90}, g_flags.draw_cube_lines, proc() {
g_flags.draw_cube_lines = !g_flags.draw_cube_lines
})
ui_checkbox("Show Cubes", {30, 120}, g_flags.draw_cubes, proc() {
g_flags.draw_cubes = !g_flags.draw_cubes
})
ui_checkbox("Show Weight Cloud", {30, 150}, g_flags.draw_weight_cloud, proc() {
g_flags.draw_weight_cloud = !g_flags.draw_weight_cloud
})
ui_checkbox("Load Test Images", {30, 180}, g_flags.load_test_imgs, proc() {
g_flags.load_test_imgs = !g_flags.load_test_imgs
})
temp := g_flags.cam_rotate
if ui_checkbox("Rotate Cam", {30, 30}, &g_flags.cam_rotate) && !temp {
g_cam_angle = 0
g_camera3d.target = {}
}
ui_checkbox("Show Connections", {30, 60}, &g_flags.draw_connections)
ui_checkbox("Show Cube Lines", {30, 90}, &g_flags.draw_cube_lines)
ui_checkbox("Show Cubes", {30, 120}, &g_flags.draw_cubes)
ui_checkbox("Show Weight Cloud", {30, 150}, &g_flags.draw_weight_cloud)
ui_checkbox("Load Test Images", {30, 180}, &g_flags.load_test_imgs)
}

draw_2d_image_input_grid :: proc(x_offset: int, y_offset: int) {
Expand Down Expand Up @@ -651,7 +781,7 @@ handle_keyboard_input :: proc() {

// MARK: UI

ui_checkbox :: proc(label: string, pos: rl.Vector2, is_enabled: bool, onClick: proc()) {
ui_checkbox :: proc(label: string, pos: rl.Vector2, is_enabled: ^bool) -> bool {
checkbox_size: i32 = 20
checkbox_enabled_size: i32 = 14
enabled_size_diff: i32 = (checkbox_size - checkbox_enabled_size) / 2
Expand All @@ -666,7 +796,7 @@ ui_checkbox :: proc(label: string, pos: rl.Vector2, is_enabled: bool, onClick: p
checkbox_size, checkbox_size,
rl.WHITE
)
if is_enabled {
if is_enabled^ {
rl.DrawRectangle(
i32(pos.x) + enabled_size_diff, i32(pos.y) + enabled_size_diff,
checkbox_enabled_size, checkbox_enabled_size,
Expand All @@ -677,9 +807,12 @@ ui_checkbox :: proc(label: string, pos: rl.Vector2, is_enabled: bool, onClick: p
// Draw text
text_pos_x := i32(pos.x) + checkbox_size + 10
text_pos_y := i32(pos.y) + checkbox_size / 2 - 10
rl.DrawText(
rl.DrawTextEx(
g_font_20,
cstring(raw_data(label)),
text_pos_x, text_pos_y, text_size,
{f32(text_pos_x), f32(text_pos_y)},
f32(text_size),
0,
rl.WHITE
)

Expand All @@ -689,9 +822,53 @@ ui_checkbox :: proc(label: string, pos: rl.Vector2, is_enabled: bool, onClick: p
rl.GetMousePosition(),
{pos.x, pos.y, total_width, total_height}
)
if is_mouse_on_area && rl.IsMouseButtonPressed(.LEFT) {
onClick()

is_clicked := is_mouse_on_area && rl.IsMouseButtonPressed(.LEFT)
if is_clicked {
is_enabled^ = !(is_enabled^)
}
return is_clicked
}

ui_button :: proc(label: string, pos: rl.Vector2) -> bool {
button_size: f32 = 20
button_padding: f32 = 4

// Calculate text dimensions
text_size: f32 = 20
text_width := rl.MeasureTextEx(g_font_20, cstring(raw_data(label)), text_size, 0).x

total_width := text_width + f32(button_padding*2)
total_height := max(button_size, text_size + button_padding*2)
is_mouse_on_area := rl.CheckCollisionPointRec(
rl.GetMousePosition(),
{pos.x, pos.y, total_width, total_height}
)

// Draw button
rl.DrawRectangle(
i32(pos.x), i32(pos.y),
i32(total_width), i32(total_height),
is_mouse_on_area ? rl.GRAY : rl.DARKGRAY
)
rl.DrawRectangleLines(
i32(pos.x), i32(pos.y),
i32(total_width), i32(total_height),
rl.WHITE
)

// Draw text
rl.DrawTextEx(
g_font_20,
cstring(raw_data(label)),
{pos.x+1, pos.y} + button_padding,
f32(text_size),
0,
rl.WHITE
)

is_clicked := is_mouse_on_area && rl.IsMouseButtonPressed(.LEFT)
return is_clicked
}

// MARK: Utils
Expand Down