Skip to content
3 changes: 2 additions & 1 deletion back/src/net_utils/vlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ipmininet.ipovs_switch import IPOVSSwitch
from ipmininet.ipswitch import IPSwitch
from network_schema import Node, NodeInterface
from node_types import NodeType


def setup_vlans(net: IPNet, nodes: list[Node]) -> None:
Expand All @@ -14,7 +15,7 @@ def setup_vlans(net: IPNet, nodes: list[Node]) -> None:
"""

for node in nodes:
if node.config.type == "l2_switch":
if node.config.type == NodeType.SWITCH:
switch = net.get(node.data.id)
add_bridge(switch, node.interface)

Expand Down
3 changes: 2 additions & 1 deletion back/src/net_utils/vxlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ipmininet.ipnet import IPNet
from network_schema import Node
from node_types import NodeType


def setup_vtep_interfaces(net: IPNet, nodes: list[Node]) -> None:
Expand All @@ -13,7 +14,7 @@ def setup_vtep_interfaces(net: IPNet, nodes: list[Node]) -> None:
nodes (list[Node]): A list of nodes to configure.
"""
for node in nodes:
if node.config.type == "router":
if node.config.type == NodeType.ROUTER:
router = net.get(node.data.id)

# Configure VXLAN network interfaces (connection_type == 1)
Expand Down
14 changes: 10 additions & 4 deletions back/src/network_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ipmininet.router.config import RouterConfig
from network_schema import Network, Node, NodeConfig, NodeInterface
from pkt_parser import is_ipv4_address
from node_types import NodeType


class MiminetTopology(IPTopo):
Expand Down Expand Up @@ -47,14 +48,17 @@ def __handle_node(self, node: Node):
node_type: str = config.type # network device type
node_id: str = node.data.id # network device name(label)

if node_type == "l2_switch":
if node_type == NodeType.SWITCH:
self.__handle_l2_switch(node_id, config)
elif node_type in ("host", "server"):
elif node_type in (NodeType.HOST, NodeType.SERVER):
self.__handle_host_or_server(node_id, config)
elif node_type == "l1_hub":
elif node_type == NodeType.HUB:
self.__handle_l1_hub(node_id)
elif node_type == "router":
elif node_type == NodeType.ROUTER:
self.__handle_router(node_id, config)
else:
print(f"Unknown node type: {node_type}")
return

def __handle_l2_switch(self, node_id: str, config: NodeConfig):
assert config.stp in (0, 1, 2), "Incorrect STP mode"
Expand Down Expand Up @@ -127,6 +131,8 @@ def build(self, *args, **kwargs):
interfaces = []

for node in self.__network.nodes:
if node.config.type == "textbox":
continue
# Caches node by ID for quick lookup later
self.__id_to_node[node.data.id] = node

Expand Down
13 changes: 13 additions & 0 deletions back/src/node_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class NodeType(str, Enum):
"""
Types of all functional network nodes
"""

HOST = "host"
SERVER = "server"
SWITCH = "l2_switch"
HUB = "l1_hub"
ROUTER = "router"
28 changes: 25 additions & 3 deletions back/src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import os
import signal

from marshmallow import Schema
import marshmallow_dataclass

from node_types import NodeType
from celery_app import (
SEND_NETWORK_RESPONSE_EXCHANGE,
SEND_NETWORK_RESPONSE_ROUTING_KEY,
Expand All @@ -12,6 +15,25 @@
from mininet.log import error, setLogLevel
from network_schema import Network

_network_schema: Schema | None = None


def _filter_unknown_nodes(data: dict) -> dict:
allowed = set(NodeType)
data["nodes"] = [
node
for node in data.get("nodes", [])
if node.get("config", {}).get("type") in allowed
]
return data


def get_network_schema() -> Schema:
global _network_schema
if _network_schema is None:
_network_schema = marshmallow_dataclass.class_schema(Network)()
return _network_schema


def run_miminet(network_json: str):
"""Load network from JSON and start emulation safely.
Expand All @@ -30,9 +52,9 @@ def run_miminet(network_json: str):
print("Set default handler to SIGCHLD")
signal.signal(signal.SIGCHLD, signal.SIG_IGN)

jnet = json.loads(network_json)
network_schema = marshmallow_dataclass.class_schema(Network)()
network_json = network_schema.load(jnet, unknown="include")
jnet = _filter_unknown_nodes(json.loads(network_json))
schema = get_network_schema()
network_json = schema.load(jnet, unknown="include")

for _ in range(4):
try:
Expand Down
4 changes: 4 additions & 0 deletions front/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
save_edge_config,
save_host_config,
save_hub_config,
save_textbox_config,
save_router_config,
save_server_config,
save_switch_config,
Expand Down Expand Up @@ -315,6 +316,9 @@ def get_database_uri(mode):
app.add_url_rule(
"/host/hub_save_config", methods=["GET", "POST"], view_func=save_hub_config
)
app.add_url_rule(
"/host/textbox_save_config", methods=["GET", "POST"], view_func=save_textbox_config
)
app.add_url_rule(
"/host/switch_save_config", methods=["GET", "POST"], view_func=save_switch_config
)
Expand Down
Loading
Loading