diff --git a/.gitignore b/.gitignore index 4c44546..5fdb523 100644 --- a/.gitignore +++ b/.gitignore @@ -103,4 +103,10 @@ venv.bak/ # mypy .mypy_cache/ -temp \ No newline at end of file +temp + +#vscode settings +settings.json + + +*_Project diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9e6f6c3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,40 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 # Use the ref you want to point at + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 24.8.0 + hooks: + - id: black + args: [--line-length=120] + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + args: [--strict, --ignore-missing-imports] + +- repo: https://github.com/pycqa/flake8 + rev: 7.1.1 + hooks: + - id: flake8 + additional_dependencies: [flake8-typing-imports==1.12.0] + args: [--max-line-length=120] + +- repo: https://github.com/asottile/reorder_python_imports + rev: v3.13.0 + hooks: + - id: reorder-python-imports + args: [--py37-plus, --add-import, 'from __future__ import annotations'] + +- repo: https://github.com/asottile/setup-cfg-fmt + rev: v2.5.0 + hooks: + - id: setup-cfg-fmt diff --git a/Example_Project/Add_node.py b/Example_Project/Add_node.py new file mode 100644 index 0000000..dcee04f --- /dev/null +++ b/Example_Project/Add_node.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Optional +from typing import Tuple + +from node_editor.node import Node + + +class Add_Node(Node): + def __init__(self) -> None: + super().__init__() + + self.title_text: str = "Add" + self.type_text: str = "Logic Nodes" + self.set_color(title_color=(0, 128, 0)) + + self.add_pin(name="Ex In", is_output=False, execution=True) + self.add_pin(name="Ex Out", is_output=True, execution=True) + + self.add_pin(name="input A", is_output=False) + self.add_pin(name="input B", is_output=False) + self.add_pin(name="output", is_output=True) + self.build() + + def set_color( + self, + title_color: Tuple[int, int, int], + background_color: Optional[Tuple[int, int, int]] = None, + ) -> None: + super().set_color(title_color, background_color) + + def add_pin(self, name: str, is_output: bool, execution: bool = False) -> None: + # Assuming add_pin is defined in the parent class + super().add_pin(name=name, is_output=is_output, execution=execution) + + def build(self) -> None: + # Assuming build is defined in the parent class + super().build() diff --git a/Example_Project/Button_node.py b/Example_Project/Button_node.py new file mode 100644 index 0000000..23b501d --- /dev/null +++ b/Example_Project/Button_node.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from PySide6 import QtWidgets + +from node_editor.node import Node + + +class Button_Node(Node): + def __init__(self) -> None: + super().__init__() + + self.title_text = "Button" + self.type_text = "Inputs" + self.set_color(title_color=(128, 0, 0)) + + self.add_pin(name="Ex Out", is_output=True, execution=True) + + self.build() + + def init_widget(self) -> None: + self.widget = QtWidgets.QWidget() + layout = QtWidgets.QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + + btn = QtWidgets.QPushButton("Button test") + btn.clicked.connect(self.btn_cmd) + layout.addWidget(btn) + self.widget.setLayout(layout) + + proxy = QtWidgets.QGraphicsProxyWidget() + proxy.setWidget(self.widget) + proxy.setParentItem(self) + + super().init_widget() + + def btn_cmd(self) -> None: + print("btn command") + self.execute() diff --git a/Example_Project/Print_node.py b/Example_Project/Print_node.py new file mode 100644 index 0000000..3476c37 --- /dev/null +++ b/Example_Project/Print_node.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Optional +from typing import Tuple + +from node_editor.node import Node + + +class Print_Node(Node): + def __init__(self) -> None: + super().__init__() + + self.title_text: str = "Print" + self.type_text: str = "Debug Nodes" + self.set_color(title_color=(160, 32, 240)) + + self.add_pin(name="Ex In", is_output=False, execution=True) + + self.add_pin(name="input", is_output=False) + self.build() + + def set_color( + self, + title_color: Tuple[int, int, int], + background_color: Optional[Tuple[int, int, int]] = None, + ) -> None: + super().set_color(title_color, background_color) + + def add_pin(self, name: str, is_output: bool, execution: bool = False) -> None: + # Assuming add_pin is defined in the parent class + super().add_pin(name=name, is_output=is_output, execution=execution) + + def build(self) -> None: + # Assuming build is defined in the parent class + super().build() diff --git a/Example_Project/Scaler_node.py b/Example_Project/Scaler_node.py new file mode 100644 index 0000000..0a93802 --- /dev/null +++ b/Example_Project/Scaler_node.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from PySide6 import QtWidgets + +from Example_Project.common_widgets import FloatLineEdit +from node_editor.node import Node + + +class Scaler_Node(Node): + def __init__(self) -> None: + super().__init__() + + self.title_text = "Scaler" + self.type_text = "Constants" + self.set_color(title_color=(255, 165, 0)) + + self.add_pin(name="value", is_output=True) + + self.build() + + def init_widget(self) -> None: + self.widget = QtWidgets.QWidget() + self.widget.setFixedWidth(100) + layout = QtWidgets.QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + self.scaler_line = FloatLineEdit() + layout.addWidget(self.scaler_line) + self.widget.setLayout(layout) + + proxy = QtWidgets.QGraphicsProxyWidget() + proxy.setWidget(self.widget) + proxy.setParentItem(self) + + super().init_widget() diff --git a/Example_Project/__init__.py b/Example_Project/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Example_Project/common_widgets.py b/Example_Project/common_widgets.py new file mode 100644 index 0000000..0eaede4 --- /dev/null +++ b/Example_Project/common_widgets.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Optional +from typing import Tuple + +from PySide6 import QtGui +from PySide6 import QtWidgets +from PySide6.QtCore import Qt + + +class FloatLineEdit(QtWidgets.QLineEdit): # type: ignore + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(parent) + self.setValidator(FloatValidator()) + + def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: + if event.key() == Qt.Key_Space: + event.ignore() + else: + super().keyPressEvent(event) + + +class FloatValidator(QtGui.QDoubleValidator): # type: ignore + def __init__(self, parent: Optional[QtGui.QObject] = None) -> None: + super().__init__(parent) + + def validate(self, input_str: str, pos: int) -> Tuple[QtGui.QValidator.State, str, int]: + state, num, pos = super().validate(input_str, pos) + if state == QtGui.QValidator.Acceptable: + return QtGui.QValidator.Acceptable, num, pos + if str(num).count(".") > 1: + return QtGui.QValidator.Invalid, num, pos + if input_str[pos - 1] == ".": + return QtGui.QValidator.Acceptable, num, pos + return QtGui.QValidator.Invalid, num, pos diff --git a/Example_Project/test.json b/Example_Project/test.json new file mode 100644 index 0000000..5f4aee2 --- /dev/null +++ b/Example_Project/test.json @@ -0,0 +1,66 @@ +{ + "nodes": [ + { + "type": "Scaler_Node", + "x": 4765, + "y": 5098, + "index": "4" + }, + { + "type": "Scaler_Node", + "x": 4752, + "y": 5000, + "index": "3" + }, + { + "type": "Print_Node", + "x": 5121, + "y": 4905, + "index": "2" + }, + { + "type": "Button_Node", + "x": 4777, + "y": 4894, + "index": "1" + }, + { + "type": "Add_Node", + "x": 4971, + "y": 4955, + "index": "0" + } + ], + "connections": [ + { + "start_id": "4", + "end_id": "0", + "start_pin": "value", + "end_pin": "input B" + }, + { + "start_id": "3", + "end_id": "0", + "start_pin": "value", + "end_pin": "input A" + }, + { + "start_id": "0", + "end_id": "2", + "start_pin": "output", + "end_pin": "input" + }, + { + "start_id": "0", + "end_id": "2", + "start_pin": "Ex Out", + "end_pin": "Ex In" + }, + { + "start_id": "1", + "end_id": "0", + "start_pin": "Ex Out", + "end_pin": "Ex In" + } + ] +} diff --git a/README.md b/README.md index 29bb159..9d849ab 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,29 @@ -# Logic Node Editor +# Python Node Editor -A very basic minimal code for implementing a node graph or editor using PySide2. In this case we are building with logic nodes. All nodes are built using QGraphics items. +This is a node based Python tool used for visual scripting that is designed to be used for composing high level Python code into reusable blocks. Nodes look and function similar to Unreal Engine blueprints. Each node consists of connection pins and a widget section enabling the developer to write a full custom PySide GUI for each node type. -Example video: https://www.youtube.com/watch?v=DOsFJ8lm9dU +The tool is designed to allow you to write Python code in individual files per class/node. This means that your code is self-contained, easily modifiable, and reusable across multiple projects. Additionally, the GUI is designed to be familiar to those who have used Unreal Engine's blueprinting system, making it easy to learn and use. -![nodes](https://github.com/bhowiebkr/simple-node-editor/blob/master/images/node_editor.jpg) +My goal with this project is to provide a new and innovative way of organizing and working with Python code. While the tool is still in the development phase, I am constantly working to improve its functionality and features. + +Visual scripting using nodes does have some benefits and drawbacks and it’s up to the end developer to decide when such a system is beneficial or not. + +![nodes](https://github.com/bhowiebkr/simple-node-editor/blob/master/images/node_editor2.jpg) + +Use it for: +- high level composing/configurable code. If a given system consists of many similar components but have a unique set of steps or requirements on similar tasks. Example a VFX or game pipeline. +- readability for non programmers as a dependency graph with built-in functionality +- enabling non-programmers a simple system to assemble blocks of logic +- networks that require a high level of feedback throughout that network and not just the end result. Example shader building, sound synthesizing, machine learning, robotics and sensors. Each node can have a custom visual feedback such as images, graphs, sound timelines, spreadsheets etc. +- prototyping logic. +- Generator scripts. Taking an input or building up a result that gets saved for other uses. Example textures, images, sound, ML training data. + +Don’t use it for +- Anything complex. 40 nodes or less. This is because the user not only needs to think of how nodes are logically connected, but also the visual composure of nodes in the graph. It’s always best to refactor code when a graph gets too complex to make sense of. +- code that needs to run fast. The overhead of node based tools will increase processing in almost all cases. +- Code that doesn’t need a GUI/human interface to use. + +For minimal GUI code for creating a node network see [GUI-nodes-only](https://github.com/bhowiebkr/simple-node-editor/tree/GUI-nodes-only) branch. + + +[![Video](http://img.youtube.com/vi/DOsFJ8lm9dU/0.jpg)](http://www.youtube.com/watch?v=DOsFJ8lm9dU) diff --git a/images/node_editor.jpg b/images/node_editor.jpg deleted file mode 100644 index 25b0e7a..0000000 Binary files a/images/node_editor.jpg and /dev/null differ diff --git a/images/node_editor2.jpg b/images/node_editor2.jpg new file mode 100644 index 0000000..9007b0a Binary files /dev/null and b/images/node_editor2.jpg differ diff --git a/launch.bat b/launch.bat new file mode 100644 index 0000000..d8ebe60 --- /dev/null +++ b/launch.bat @@ -0,0 +1,7 @@ +@echo off + +:: Activate the virtual environment +call venv\Scripts\activate + +:: Start the three Python programs in separate command windows +start cmd /k python .\main.py diff --git a/main.py b/main.py index e87fa46..b23ac46 100644 --- a/main.py +++ b/main.py @@ -1,54 +1,173 @@ -import sys - -from PySide6 import QtWidgets, QtCore, QtGui +from __future__ import annotations +import importlib +import inspect import logging -import os - -from node_editor.gui.node_widget import NodeWidget -from node_editor.gui.palette import palette +import sys +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Optional + +import qdarktheme +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtWidgets +from PySide6.QtCore import QByteArray # Or from PySide2.QtCore import QByteArray + +from node_editor.compute_graph import compute_dag_nodes +from node_editor.connection import Connection from node_editor.gui.node_list import NodeList +from node_editor.gui.node_widget import NodeWidget +from node_editor.node import Node logging.basicConfig(level=logging.DEBUG) +""" +A simple Node Editor application that allows the user to create, modify and connect nodes of various types. + +The application consists of a main window that contains a splitter with a Node List and a Node Widget. The Node List +shows a list of available node types, while the Node Widget is where the user can create, edit and connect nodes. -class NodeEditor(QtWidgets.QMainWindow): - def __init__(self, parent=None): - super(NodeEditor, self).__init__(parent) - self.settings = None +This application uses PySide6 as a GUI toolkit. + +Author: Bryan Howard +Repo: https://github.com/bhowiebkr/simple-node-editor +""" + + +class NodeEditor(QtWidgets.QMainWindow): # type: ignore + OnProjectPathUpdate = QtCore.Signal(Path) + + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(parent) + self.settings: Optional[QtCore.QSettings] = None + self.project_path: Optional[Path] = None + self.imports: Optional[Dict[str, Dict[str, Any]]] = ( + None # we will store the project import node types here for now. + ) icon = QtGui.QIcon("resources\\app.ico") self.setWindowIcon(icon) - self.setWindowTitle("Logic Node Editor") + self.setWindowTitle("Simple Node Editor") settings = QtCore.QSettings("node-editor", "NodeEditor") - # Layouts - main_widget = QtWidgets.QWidget() - main_layout = QtWidgets.QHBoxLayout() + # create a "File" menu and add an "Export CSV" action to it + file_menu = QtWidgets.QMenu("File", self) + self.menuBar().addMenu(file_menu) + + load_action = QtGui.QAction("Load Project", self) + load_action.triggered.connect(self.get_project_path) + file_menu.addAction(load_action) - self.node_list = NodeList() - self.splitter = QtWidgets.QSplitter() + save_action = QtGui.QAction("Save Project", self) + save_action.triggered.connect(self.save_project) + file_menu.addAction(save_action) + # Layouts + main_widget = QtWidgets.QWidget() self.setCentralWidget(main_widget) + main_layout = QtWidgets.QHBoxLayout() main_widget.setLayout(main_layout) - - self.node_widget = NodeWidget(self) - main_layout.addWidget(self.splitter) - self.splitter.addWidget(self.node_list) + left_layout = QtWidgets.QVBoxLayout() + left_layout.setContentsMargins(0, 0, 0, 0) + + # Widgets + self.node_list: NodeList = NodeList(self) + left_widget = QtWidgets.QWidget() + self.splitter: QtWidgets.QSplitter = QtWidgets.QSplitter() + execute_button = QtWidgets.QPushButton("Execute Graph") + execute_button.setFixedHeight(40) + execute_button.clicked.connect(self.execute_graph) + self.node_widget: NodeWidget = NodeWidget(self) + + # Add Widgets to layouts + self.splitter.addWidget(left_widget) self.splitter.addWidget(self.node_widget) + left_widget.setLayout(left_layout) + left_layout.addWidget(self.node_list) + left_layout.addWidget(execute_button) + main_layout.addWidget(self.splitter) + + # Load the example project + example_project_path = Path(__file__).parent.resolve() / "Example_project" + self.load_project(example_project_path) + + # Restore GUI from last state + if settings.contains("geometry"): + self.restoreGeometry(QByteArray(settings.value("geometry"))) - try: - self.restoreGeometry(settings.value("geometry")) s = settings.value("splitterSize") self.splitter.restoreState(s) - except AttributeError as e: - logging.warning( - "Unable to load settings. First time opening the tool?\n" + str(e) - ) + def execute_graph(self) -> None: + print("Executing Graph:") + + # Get a list of the nodes in the view + nodes = self.node_widget.scene.get_items_by_type(Node) + edges = self.node_widget.scene.get_items_by_type(Connection) + # sort them + compute_dag_nodes(nodes, edges) + + def save_project(self) -> None: + file_dialog = QtWidgets.QFileDialog() + file_dialog.setAcceptMode(QtWidgets.QFileDialog.AcceptSave) + file_dialog.setDefaultSuffix("json") + file_dialog.setNameFilter("JSON files (*.json)") + file_path, _ = file_dialog.getSaveFileName() + self.node_widget.save_project(file_path) + + def load_project(self, project_path: Optional[Path] = None) -> None: + if not project_path: + return + + project_path = Path(project_path) + if project_path.exists() and project_path.is_dir(): + self.project_path = project_path + + self.imports = {} + + for file in project_path.glob("*.py"): + if not file.stem.endswith("_node"): + print("file:", file.stem) + continue + spec = importlib.util.spec_from_file_location(file.stem, file) # type: ignore + module = importlib.util.module_from_spec(spec) # type: ignore + spec.loader.exec_module(module) + + for name, obj in inspect.getmembers(module): + if not name.endswith("_Node"): + continue + if inspect.isclass(obj): + self.imports[obj.__name__] = {"class": obj, "module": module} + # break + + self.node_list.update_project(self.imports) + + # work on just the first json file. add the ablitity to work on multiple json files later + for json_path in project_path.glob("*.json"): + self.node_widget.load_scene(str(json_path), self.imports) + break + + def get_project_path(self) -> None: + project_path = QtWidgets.QFileDialog.getExistingDirectory(None, "Select Project Folder", "") + if not project_path: + return + + self.load_project(Path(project_path)) + + def closeEvent(self, event: QtGui.QCloseEvent) -> None: + """ + Handles the close event by saving the GUI state and closing the application. + + Args: + event: Close event. + + Returns: + None. + """ - def closeEvent(self, event): self.settings = QtCore.QSettings("node-editor", "NodeEditor") self.settings.setValue("geometry", self.saveGeometry()) self.settings.setValue("splitterSize", self.splitter.saveState()) @@ -56,9 +175,11 @@ def closeEvent(self, event): if __name__ == "__main__": + app = QtWidgets.QApplication(sys.argv) app.setWindowIcon(QtGui.QIcon("resources\\app.ico")) - app.setPalette(palette) + qdarktheme.setup_theme() + launcher = NodeEditor() launcher.show() app.exec() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..37c9065 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,4 @@ +[mypy] +python_version = 3.8 +ignore_missing_imports = True +strict = True diff --git a/node_editor/__init__.py b/node_editor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/node_editor/common.py b/node_editor/common.py new file mode 100644 index 0000000..4d49a5f --- /dev/null +++ b/node_editor/common.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from enum import Enum + + +class Node_Status(Enum): + CLEAN = 1 + DIRTY = 2 + ERROR = 3 diff --git a/node_editor/compute_graph.py b/node_editor/compute_graph.py new file mode 100644 index 0000000..a4a482c --- /dev/null +++ b/node_editor/compute_graph.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import List + +from node_editor.connection import Connection +from node_editor.node import Node + + +def topologicalSortUtil(v: int, adj: List[List[int]], visited: List[bool], stack: List[int]) -> None: + # Mark the current node as visited + visited[v] = True + + # Recur for all adjacent vertices + for i in adj[v]: + if not visited[i]: + topologicalSortUtil(i, adj, visited, stack) + + # Push current vertex to stack which stores the result + stack.append(v) + + +# Function to perform Topological Sort +def topologicalSort(adj: List[List[int]], num_nodes: int) -> List[int]: + # Stack to store the result + stack: List[int] = [] + + visited = [False] * num_nodes + + # Call the recursive helper function to store + # Topological Sort starting from all vertices one by + # one + for i in range(num_nodes): + if not visited[i]: + topologicalSortUtil(i, adj, visited, stack) + + # Print contents of stack + print("Topological sorting of the graph:", end=" ") + + topological_order = [] + while stack: + # print(stack.pop(), end=" ") + topological_order.append(stack.pop()) + + return topological_order + + +def compute_dag_nodes(nodes: List[Node], connections: List[Connection]) -> None: + print("Compute DAG Nodes") + + num_nodes = len(nodes) + # Get the edges + edges = [] + for connection in connections: + edges.append([int(node.index) for node in connection.nodes() if node is not None]) + + # Adjacency List + adjacency: List[List[int]] = [[] for _ in range(num_nodes)] + + for edge in edges: + adjacency[edge[0]].append(edge[1]) + + print("adjacency:\n\n", adjacency) + + topological_order = topologicalSort(adjacency, num_nodes) + + print(topological_order) diff --git a/node_editor/connection.py b/node_editor/connection.py new file mode 100644 index 0000000..fd3110b --- /dev/null +++ b/node_editor/connection.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Optional +from typing import Tuple + +from PySide6.QtCore import QPointF +from PySide6.QtWidgets import QGraphicsScene + +from node_editor.gui.connection_graphics import Connection_Graphics +from node_editor.node import Node +from node_editor.pin import Pin + + +class Connection(Connection_Graphics): + def __init__(self, parent: Optional[Connection_Graphics]) -> None: + super().__init__(parent) + self.start_pin: Optional[Pin] = None + self.end_pin: Optional[Pin] = None + self.start_pos: QPointF = QPointF() + self.end_pos: QPointF = QPointF() + + def delete(self) -> None: + for pin in (self.start_pin, self.end_pin): + if pin is not None: + pin.connection = None + self.start_pin = None + self.end_pin = None + scene = self.scene() + if scene is not None: + scene.removeItem(self) + + def set_start_pin(self, pin: Pin) -> None: + self.start_pin = pin + pin.connection = self + + def set_end_pin(self, pin: Pin) -> None: + self.end_pin = pin + pin.connection = self + + def nodes(self) -> Tuple[Optional[Node], Optional[Node]]: + return ( + self.start_pin.node if self.start_pin is not None else None, + self.end_pin.node if self.end_pin is not None else None, + ) + + def update_start_and_end_pos(self) -> None: + if self.start_pin is not None and not self.start_pin.is_output: + self.start_pin, self.end_pin = self.end_pin, self.start_pin + + if self.start_pin is not None: + self.start_pos = self.start_pin.scenePos() + + if self.end_pin is not None: + self.end_pos = self.end_pin.scenePos() + + self.update_path() + + def scene(self) -> Optional[QGraphicsScene]: + return super().scene() + + def update_path(self) -> None: + super().update_path() diff --git a/node_editor/gui/connection.py b/node_editor/gui/connection.py deleted file mode 100644 index 2faa4c4..0000000 --- a/node_editor/gui/connection.py +++ /dev/null @@ -1,98 +0,0 @@ -from PySide6 import QtWidgets, QtGui, QtCore - - -class Connection(QtWidgets.QGraphicsPathItem): - def __init__(self, parent): - super(Connection, self).__init__(parent) - - self.setFlag(QtWidgets.QGraphicsPathItem.ItemIsSelectable) - - self.setPen(QtGui.QPen(QtGui.QColor(200, 200, 200), 2)) - self.setBrush(QtCore.Qt.NoBrush) - self.setZValue(-1) - - self._start_port = None - self._end_port = None - - self.start_pos = QtCore.QPointF() - self.end_pos = QtCore.QPointF() - - self._do_highlight = False - - def delete(self): - for port in (self._start_port, self._end_port): - if port: - # port.remove_connection(self) - port.connection = None - port = None - - self.scene().removeItem(self) - - @property - def start_port(self): - return self._start_port - - @property - def end_port(self): - return self._end_port - - @start_port.setter - def start_port(self, port): - self._start_port = port - self._start_port.connection = self - - @end_port.setter - def end_port(self, port): - self._end_port = port - self._end_port.connection = self - - def nodes(self): - return (self._start_port().node(), self._end_port().node()) - - def update_start_and_end_pos(self): - """Update the ends of the connection - - Get the start and end ports and use them to set the start and end positions. - """ - - if self.start_port and not self.start_port.is_output(): - print("flipping connection") - temp = self.end_port - self._end_port = self.start_port - self._start_port = temp - - if self._start_port: - self.start_pos = self._start_port.scenePos() - - # if we are pulling off an exiting connection we skip code below - if self._end_port: - self.end_pos = self._end_port.scenePos() - - self.update_path() - - def update_path(self): - """Draw a smooth cubic curve from the start to end ports - """ - path = QtGui.QPainterPath() - path.moveTo(self.start_pos) - - dx = self.end_pos.x() - self.start_pos.x() - dy = self.end_pos.y() - self.start_pos.y() - - ctr1 = QtCore.QPointF(self.start_pos.x() + dx * 0.5, self.start_pos.y()) - ctr2 = QtCore.QPointF(self.start_pos.x() + dx * 0.5, self.start_pos.y() + dy) - path.cubicTo(ctr1, ctr2, self.end_pos) - - self.setPath(path) - - def paint(self, painter, option=None, widget=None): - """ - Override the default paint method depending on if the object is selected - """ - if self.isSelected() or self._do_highlight: - painter.setPen(QtGui.QPen(QtGui.QColor(255, 102, 0), 3)) - else: - painter.setPen(QtGui.QPen(QtGui.QColor(0, 128, 255), 2)) - - painter.drawPath(self.path()) - diff --git a/node_editor/gui/connection_graphics.py b/node_editor/gui/connection_graphics.py new file mode 100644 index 0000000..5cdd030 --- /dev/null +++ b/node_editor/gui/connection_graphics.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from typing import cast +from typing import Optional +from typing import Tuple + +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtWidgets + +from node_editor.node import Node +from node_editor.pin import Pin + + +class Connection_Graphics(QtWidgets.QGraphicsPathItem): # type: ignore + """ + A Connection represents a graphical connection between two NodePorts in a PySide6 application. + + Attributes: + start_pin (NodePort): The NodePort where the connection starts. + end_pin (NodePort): The NodePort where the connection ends. + start_pos (QPointF): The starting position of the connection. + end_pos (QPointF): The ending position of the connection. + + Methods: + delete(): Deletes the connection. + nodes(): Returns a tuple of the two connected nodes. + update_start_and_end_pos(): Updates the starting and ending positions of the connection. + update_path(): Draws a smooth cubic curve from the starting to ending position of the connection. + paint(painter, option=None, widget=None): Override the default paint method depending on if the object is selected. + + Example: + conn = Connection(parent) + conn.start_pin = start_pin + conn.end_pin = end_pin + conn.update_start_and_end_pos() + conn.update_path() + """ + + def __init__(self, parent: Optional[QtWidgets.QGraphicsItem] = None) -> None: + super().__init__(parent) + + self.setFlag(QtWidgets.QGraphicsItem.GraphicsItemFlag.ItemIsSelectable) + + self.setPen(QtGui.QPen(QtGui.QColor(200, 200, 200), 2)) + self.setBrush(QtCore.Qt.BrushStyle.NoBrush) + self.setZValue(-1) + + self.start_pos: QtCore.QPointF = QtCore.QPointF() + self.end_pos: QtCore.QPointF = QtCore.QPointF() + self.start_pin: Optional[Pin] + self.end_pin: Optional[Pin] + + self._do_highlight: bool = False + + def update_path(self) -> None: + """ + Draws a smooth cubic curve from the start to end pins. + """ + path = QtGui.QPainterPath() + path.moveTo(self.start_pos) + + dx = self.end_pos.x() - self.start_pos.x() + dy = self.end_pos.y() - self.start_pos.y() + + ctr1 = QtCore.QPointF(self.start_pos.x() + dx * 0.5, self.start_pos.y()) + ctr2 = QtCore.QPointF(self.start_pos.x() + dx * 0.5, self.start_pos.y() + dy) + path.cubicTo(ctr1, ctr2, self.end_pos) + + self.setPath(path) + + def paint( + self, + painter: QtGui.QPainter, + option: Optional[QtWidgets.QStyleOptionGraphicsItem] = None, + widget: Optional[QtWidgets.QWidget] = None, + ) -> None: + """ + Override the default paint method depending on if the object is selected. + + Args: + painter (QPainter): The QPainter object used to paint the Connection. + option (QStyleOptionGraphicsItem): The style options for the Connection. + widget (QWidget): The widget used to paint the Connection. + """ + + thickness = 0 + color = QtGui.QColor(0, 128, 255) + if self.start_pin: + if self.start_pin.execution: + thickness = 3 + color = QtGui.QColor(255, 255, 255) + + if self.isSelected() or self._do_highlight: + painter.setPen(QtGui.QPen(color.lighter(), thickness + 2)) + else: + painter.setPen(QtGui.QPen(color, thickness)) + + painter.drawPath(self.path()) + + def delete(self) -> None: + pass + + def nodes(self) -> Tuple[Optional[Node], Optional[Node]]: + # Implement the logic to return the connected nodes + if self.start_pin and self.end_pin: + start_node = cast(Optional[Node], self.start_pin.node) if self.start_pin else None + end_node = cast(Optional[Node], self.end_pin.node) if self.end_pin else None + return (start_node, end_node) + raise ValueError("Both start_pin and end_pin must be set") + + def update_start_and_end_pos(self) -> None: + pass diff --git a/node_editor/gui/node.py b/node_editor/gui/node.py deleted file mode 100644 index 1614075..0000000 --- a/node_editor/gui/node.py +++ /dev/null @@ -1,208 +0,0 @@ -from PySide6 import QtWidgets, QtGui, QtCore - -from node_editor.gui.port import Port - - -class Node(QtWidgets.QGraphicsPathItem): - def __init__(self): - super(Node, self).__init__() - - self.setFlag(QtWidgets.QGraphicsPathItem.ItemIsMovable) - self.setFlag(QtWidgets.QGraphicsPathItem.ItemIsSelectable) - - self._title_text = "Title" - self._type_text = "base" - - self._width = 30 # The Width of the node - self._height = 30 # the height of the node - self._ports = [] # A list of ports - - self.node_color = QtGui.QColor(20, 20, 20, 200) - - self.title_path = QtGui.QPainterPath() # The path for the title - self.type_path = QtGui.QPainterPath() # The path for the type - self.misc_path = QtGui.QPainterPath() # a bunch of other stuff - - self.horizontal_margin = 30 # horizontal margin - self.vertical_margin = 15 # vertical margin - - @property - def title(self): - return self._title_text - - @title.setter - def title(self, title): - self._title_text = title - - @property - def type_text(self): - return self._type_text - - @type_text.setter - def type_text(self, type_text): - self._type_text = type_text - - def paint(self, painter, option=None, widget=None): - if self.isSelected(): - painter.setPen(QtGui.QPen(QtGui.QColor(241, 175, 0), 2)) - painter.setBrush(self.node_color) - else: - painter.setPen(self.node_color.lighter()) - painter.setBrush(self.node_color) - - painter.drawPath(self.path()) - painter.setPen(QtCore.Qt.NoPen) - painter.setBrush(QtCore.Qt.white) - - painter.drawPath(self.title_path) - painter.drawPath(self.type_path) - painter.drawPath(self.misc_path) - - def add_port(self, name, is_output=False, flags=0, ptr=None): - port = Port(self, self.scene()) - port.set_is_output(is_output) - port.set_name(name) - port.set_node(node=self) - port.set_port_flags(flags) - port.set_ptr(ptr) - - self._ports.append(port) - - def build(self): - """ Build the node - """ - - self.title_path = QtGui.QPainterPath() # reset - self.type_path = QtGui.QPainterPath() # The path for the type - self.misc_path = QtGui.QPainterPath() # a bunch of other stuff - - total_width = 0 - total_height = 0 - path = QtGui.QPainterPath() # The main path - - # The fonts what will be used - title_font = QtGui.QFont("Lucida Sans Unicode", pointSize=16) - title_type_font = QtGui.QFont("Lucida Sans Unicode", pointSize=8) - port_font = QtGui.QFont("Lucida Sans Unicode") - - # Get the dimentions of the title and type - title_dim = { - "w": QtGui.QFontMetrics(title_font).horizontalAdvance(self._title_text), - "h": QtGui.QFontMetrics(title_font).height(), - } - - title_type_dim = { - "w": QtGui.QFontMetrics(title_type_font).horizontalAdvance("(" + self._type_text + ")"), - "h": QtGui.QFontMetrics(title_type_font).height(), - } - - # Get the max width - for dim in [title_dim["w"], title_type_dim["w"]]: - if dim > total_width: - total_width = dim - - # Add both the title and type height together for the total height - for dim in [title_dim["h"], title_type_dim["h"]]: - total_height += dim - - # Add the heigth for each of the ports - for port in self._ports: - port_dim = { - "w": QtGui.QFontMetrics(port_font).horizontalAdvance(port.name()), - "h": QtGui.QFontMetrics(port_font).height(), - } - - if port_dim["w"] > total_width: - total_width = port_dim["w"] - - total_height += port_dim["h"] - - # Add the margin to the total_width - total_width += self.horizontal_margin - total_height += self.vertical_margin - - # Draw the background rectangle - path.addRoundedRect( - -total_width / 2, -total_height / 2, total_width, total_height, 5, 5 - ) - - # Draw the title - self.title_path.addText( - -title_dim["w"] / 2, - (-total_height / 2) + title_dim["h"], - title_font, - self._title_text, - ) - - # Draw the type - self.type_path.addText( - -title_type_dim["w"] / 2, - (-total_height / 2) + title_dim["h"] + title_type_dim["h"], - title_type_font, - "(" + self._type_text + ")", - ) - - y = (-total_height / 2) + title_dim["h"] + title_type_dim["h"] + port_dim["h"] - - for port in self._ports: - if port.is_output(): - port.setPos(total_width / 2 - 10, y) - else: - port.setPos(-total_width / 2 + 10, y) - y += port_dim["h"] - - self.setPath(path) - - self._width = total_width - self._height = total_height - - def select_connections(self, value): - for port in self._ports: - if port.connection: - port.connection._do_highlight = value - port.connection.update_path() - # for connection in port.connections(): - # connection._do_highlight = value - # connection.update_path() - - def contextMenuEvent(self, event): - menu = QtWidgets.QMenu(self) - pos = event.pos() - - # actions - delete_node = QtWidgets.QAction("Delete Node") - edit_node = QtWidgets.QAction("Edit Node") - menu.addAction(delete_node) - - action = menu.exec_(self.mapToGlobal(pos)) - - if action == delete_node: - item_name = self.selectedItems()[0].text() - - if item_name not in ["And", "Not", "Input", "Output"]: - print(f"delete node: {item_name}") - else: - print("Cannot delete default nodes") - - elif action == edit_node: - print("editing node") - - # confirm to open in the editor replacing what is existing - - def delete(self): - """Delete the connection. - Remove any found connections ports by calling :any:`Port.remove_connection`. After connections - have been removed set the stored :any:`Port` to None. Lastly call :any:`QGraphicsScene.removeItem` - on the scene to remove this widget. - """ - - to_delete = [] - - for port in self._ports: - if port.connection: - to_delete.append(port.connection) - - for connection in to_delete: - connection.delete() - - self.scene().removeItem(self) diff --git a/node_editor/gui/node_editor.py b/node_editor/gui/node_editor.py index f87cd71..9963d04 100644 --- a/node_editor/gui/node_editor.py +++ b/node_editor/gui/node_editor.py @@ -1,78 +1,121 @@ -from PySide6 import QtWidgets, QtCore +from __future__ import annotations + +from contextlib import suppress +from typing import Optional + +from PySide6 import QtCore +from PySide6 import QtWidgets +from PySide6.QtWidgets import QGraphicsItem +from PySide6.QtWidgets import QGraphicsScene + +from node_editor.connection import Connection +from node_editor.node import Node +from node_editor.pin import Pin + + +class NodeEditor(QtWidgets.QWidget): # type: ignore + """ + The main class of the node editor. This class handles the logic for creating, connecting, and deleting + nodes and connections. + :ivar connection: A Connection object representing the current connection being created. + :vartype connection: Connection + :ivar port: A Pin object representing the current port being clicked for a new connection. + :vartype port: Pin + :ivar scene: The QGraphicsScene on which the nodes and connections are drawn. + :vartype scene: QGraphicsScene + :ivar _last_selected: The last Node object that was selected. + :vartype _last_selected: Node + """ + + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + """ + Constructor for NodeEditor. + + :param parent: The parent widget. + :type parent: QWidget + """ + + super().__init__(parent) + self.setWindowTitle("Node Editor") + self.setGeometry(100, 100, 800, 600) + self.connection: Optional[Connection] = None + self.port: Optional[Pin] = None + self.scene: QGraphicsScene + self._last_selected: Optional[Node] = None + + def install(self, scene: QGraphicsScene) -> None: + """ + Installs the NodeEditor into a QGraphicsScene. + + :param scene: The QGraphicsScene to install the NodeEditor into. + :type scene: QGraphicsScene + """ -from node_editor.gui.connection import Connection -from node_editor.gui.node import Node -from node_editor.gui.port import Port - - -class NodeEditor(QtCore.QObject): - def __init__(self, parent): - super(NodeEditor, self).__init__(parent) - self.connection = None - self.port = None - self.scene = None - self._last_selected = None - - def install(self, scene): self.scene = scene self.scene.installEventFilter(self) - def item_at(self, position): - items = self.scene.items( - QtCore.QRectF(position - QtCore.QPointF(1, 1), QtCore.QSizeF(3, 3)) - ) - - if items: - return items[0] - return None - - def eventFilter(self, watched, event): - if type(event) == QtWidgets.QWidgetItem: + def item_at(self, position: QtCore.QPointF) -> Optional[QGraphicsItem]: + """ + Returns the QGraphicsItem at the given position. + + :param position: The position to check for a QGraphicsItem. + :type position: QPoint + :return: The QGraphicsItem at the position, or None if no item is found. + :rtype: QGraphicsItem + """ + + if self.scene is None: + return None + + items = self.scene.items(QtCore.QRectF(position - QtCore.QPointF(1, 1), QtCore.QSizeF(3, 3))) + return items[0] if items else None + + def eventFilter(self, watched: QtCore.QObject, event: QtCore.QEvent) -> bool: + """ + Filters events from the QGraphicsScene. + + :param watched: The object that is watched. + :type watched: QObject + :param event: The event that is being filtered. + :type event: QEvent + :return: True if the event was filtered, False otherwise. + :rtype: bool + """ + if type(event) is QtWidgets.QWidgetItem: return False if event.type() == QtCore.QEvent.GraphicsSceneMousePress: - if event.button() == QtCore.Qt.LeftButton: item = self.item_at(event.scenePos()) - if isinstance(item, Port): + if isinstance(item, Pin): self.connection = Connection(None) self.scene.addItem(self.connection) - # self.connection.start_port = item self.port = item self.connection.start_pos = item.scenePos() self.connection.end_pos = event.scenePos() self.connection.update_path() return True - elif isinstance(item, Connection): + if isinstance(item, Connection): + print("selected a Connection") self.connection = Connection(None) self.connection.start_pos = item.start_pos self.scene.addItem(self.connection) - # self.connection.start_port = item.start_port - self.port = item.start_port + self.port = item.start_pin self.connection.end_pos = event.scenePos() self.connection.update_start_and_end_pos() # to fix the offset return True - elif isinstance(item, Node): - if self._last_selected: - # If we clear the scene, we loose the last selection - try: - self._last_selected.select_connections(False) - except RuntimeError: - pass + if self._last_selected: + # If we clear the scene, we loose the last selection + with suppress(RuntimeError): + self._last_selected.select_connections(False) + if isinstance(item, Node): item.select_connections(True) self._last_selected = item - else: - try: - if self._last_selected: - self._last_selected.select_connections(False) - except RuntimeError: - pass - self._last_selected = None elif event.button() == QtCore.Qt.RightButton: @@ -81,11 +124,11 @@ def eventFilter(self, watched, event): elif event.type() == QtCore.QEvent.KeyPress: if event.key() == QtCore.Qt.Key_Delete: - for item in self.scene.selectedItems(): - - if isinstance(item, (Connection, Node)): - item.delete() + if isinstance(item, Connection): + self.scene.delete_connection(item) + elif isinstance(item, Node): + self.scene.delete_node_and_reorder(item) return True @@ -100,9 +143,9 @@ def eventFilter(self, watched, event): item = self.item_at(event.scenePos()) # connecting a port - if isinstance(item, Port): + if isinstance(item, Pin) and self.port: if self.port.can_connect_to(item): - print("Making connection") + # print("Making connection") # delete existing connection on the new port if item.connection: @@ -112,16 +155,14 @@ def eventFilter(self, watched, event): self.port.clear_connection() item.clear_connection() - self.connection.start_port = self.port - - self.connection.end_port = item - + self.connection.set_start_pin(self.port) + self.connection.set_end_pin(item) self.connection.update_start_and_end_pos() - self.connection = None else: - print("Deleting connection") + # print("Deleting connection") self.connection.delete() - self.connection = None + + self.connection = None if self.connection: self.connection.delete() @@ -129,4 +170,4 @@ def eventFilter(self, watched, event): self.port = None return True - return super(NodeEditor, self).eventFilter(watched, event) + return bool(super().eventFilter(watched, event)) diff --git a/node_editor/gui/node_graphics.py b/node_editor/gui/node_graphics.py new file mode 100644 index 0000000..dc0958c --- /dev/null +++ b/node_editor/gui/node_graphics.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Tuple + +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtWidgets +from PySide6.QtCore import Qt + +from node_editor.common import Node_Status +from node_editor.pin import Pin + + +class Node_Graphics(QtWidgets.QGraphicsItem): # type: ignore + def __init__(self) -> None: + super().__init__() + + self.setFlag(QtWidgets.QGraphicsItem.GraphicsItemFlag.ItemIsMovable) + self.setFlag(QtWidgets.QGraphicsItem.GraphicsItemFlag.ItemIsSelectable) + + self.title_text = "Title" + self.title_color = QtGui.QColor(123, 33, 177) + self.size = QtCore.QRectF() # Size of + self.status = Node_Status.DIRTY + + self.widget = QtWidgets.QWidget() + self.widget.resize(0, 0) + + self.type_text = "base" + + self._width = 20 # The Width of the node + self._height = 20 # the height of the node + self._pins: List[Pin] = [] # A list of pins + self.index: int # An identifier to used when saving and loading the scene + + self.node_color = QtGui.QColor(20, 20, 20, 200) + + self.title_path = QtGui.QPainterPath() # The path for the title + self.type_path = QtGui.QPainterPath() # The path for the type + self.misc_path = QtGui.QPainterPath() # a bunch of other stuff + self.status_path = QtGui.QPainterPath() # A path showing the status of the node + + self.horizontal_margin = 15 # horizontal margin + self.vertical_margin = 15 # vertical margin + + def get_status_color(self) -> QtGui.QColor: + if self.status == Node_Status.CLEAN: + return QtGui.QColor(0, 255, 0) + elif self.status == Node_Status.DIRTY: + return QtGui.QColor(255, 165, 0) + elif self.status == Node_Status.ERROR: + return QtGui.QColor(255, 0, 0) + + def boundingRect(self) -> QtCore.QRectF: + return self.size + + def set_color( + self, title_color: Tuple[int, int, int], background_color: Optional[Tuple[int, int, int]] = None + ) -> None: + + # default title_color (123, 33, 177) + + if background_color is None: + background_color = (20, 20, 20) + self.title_color = QtGui.QColor(title_color[0], title_color[1], title_color[2]) + self.node_color = QtGui.QColor(background_color[0], background_color[1], background_color[2]) + + def paint( + self, + painter: QtGui.QPainter, + option: QtWidgets.QStyleOptionGraphicsItem = None, + widget: QtWidgets.QWidget = None, + ) -> None: + """ + Paints the node on the given painter. + + Args: + painter (QtGui.QPainter): The painter to use for drawing the node. + option (QStyleOptionGraphicsItem): The style options to use for drawing the node (optional). + widget (QWidget): The widget to use for drawing the node (optional). + """ + + painter.setPen(self.node_color.lighter()) + painter.setBrush(self.node_color) + painter.drawPath(self.path) + + gradient = QtGui.QLinearGradient() + gradient.setStart(0, -90) + gradient.setFinalStop(0, 0) + gradient.setColorAt(0, self.title_color) # Start color (white) + gradient.setColorAt(1, self.title_color.darker()) # End color (blue) + + painter.setBrush(QtGui.QBrush(gradient)) + painter.setPen(self.title_color) + painter.drawPath(self.title_bg_path.simplified()) + + painter.setPen(QtCore.Qt.NoPen) + painter.setBrush(QtCore.Qt.white) + + painter.drawPath(self.title_path) + painter.drawPath(self.type_path) + painter.drawPath(self.misc_path) + + # Status path + painter.setBrush(self.get_status_color()) + painter.setPen(self.get_status_color().darker()) + painter.drawPath(self.status_path.simplified()) + + # Draw the highlight + if self.isSelected(): + painter.setPen(QtGui.QPen(self.title_color.lighter(), 2)) + painter.setBrush(Qt.NoBrush) + painter.drawPath(self.path) + + def build(self) -> None: + """ + Builds the node by constructing its graphical representation. + + This method calculates the dimensions of the node, sets the fonts for various elements, and adds the necessary + graphical components to the node, such as the title, type, and pins. Once the graphical representation of the + node is constructed, the `setPath` method is called to set the path for the node. + + Returns: + None. + """ + # configure the widget side of things. We need to get the size of the widget beforebuilding the rest of the node + self.init_widget() + self.widget.setStyleSheet("background-color: " + self.node_color.name() + ";") + self.title_path = QtGui.QPainterPath() # reset + self.type_path = QtGui.QPainterPath() # The path for the type + self.misc_path = QtGui.QPainterPath() # a bunch of other stuff + + bg_height = 35 # background title height + + total_width = self.widget.size().width() + self.path = QtGui.QPainterPath() # The main path + # The fonts what will be used + title_font = QtGui.QFont("Lucida Sans Unicode", pointSize=12) + title_type_font = QtGui.QFont("Lucida Sans Unicode", pointSize=8) + pin_font = QtGui.QFont("Lucida Sans Unicode") + + # Get the dimentions of the title and type + title_dim = { + "w": QtGui.QFontMetrics(title_font).horizontalAdvance(self.title_text), + "h": QtGui.QFontMetrics(title_font).height(), + } + + title_type_dim = { + "w": QtGui.QFontMetrics(title_type_font).horizontalAdvance(f"{self.type_text}"), + "h": QtGui.QFontMetrics(title_type_font).height(), + } + + # Get the max width + for dim in [title_dim["w"], title_type_dim["w"]]: + if dim > total_width: + total_width = dim + + # Add the width for the pins + total_width += self.horizontal_margin # Increased width for spacing + + # Add both the title and type height together for the total height + # total_height = sum([title_dim["h"], title_type_dim["h"]]) + self.widget.size().height() + total_height = bg_height + self.widget.size().height() + + pin_dim = None + # Add the heigth for each of the pins + exec_height_added = False + for pin in self._pins: + pin_dim = { + "w": QtGui.QFontMetrics(pin_font).horizontalAdvance(pin.name), + "h": QtGui.QFontMetrics(pin_font).height(), + } + + if pin_dim["w"] > total_width: + total_width = pin_dim["w"] + + if pin.execution and not exec_height_added or not pin.execution: + total_height += pin_dim["h"] + exec_height_added = True + + # Add the margin to the total_width + total_width += self.horizontal_margin + # total_height += self.vertical_margin + + # Draw the background rectangle + self.size = QtCore.QRectF(-total_width / 2, -total_height / 2, total_width, total_height) + self.path.addRoundedRect(-total_width / 2, -total_height / 2, total_width, total_height + 10, 5, 5) + + # Draw the status rectangle + self.status_path.setFillRule(Qt.FillRule.WindingFill) + self.status_path.addRoundedRect(total_width / 2 - 12, -total_height / 2 + 2, 10, 10, 2, 2) + # self.status_path.addRect(total_width / 2 - 10, -total_height / 2, 5, 5) + # self.status_path.addRect(total_width / 2 - 10, -total_height / 2 + 15, 5, 5) + # self.status_path.addRect(total_width / 2 - 5, -total_height / 2 + 15, 5, 5) + + # The color on the title + self.title_bg_path = QtGui.QPainterPath() # The title background path + self.title_bg_path.setFillRule(Qt.FillRule.WindingFill) + self.title_bg_path.addRoundedRect(-total_width / 2, -total_height / 2, total_width, bg_height, 5, 5) + self.title_bg_path.addRect(-total_width / 2, -total_height / 2 + bg_height - 10, 10, 10) # bottom left corner + self.title_bg_path.addRect( + total_width / 2 - 10, -total_height / 2 + bg_height - 10, 10, 10 + ) # bottom right corner + + # Draw the title + self.title_path.addText( + -total_width / 2 + 5, + (-total_height / 2) + title_dim["h"] / 2 + 5, + title_font, + self.title_text, + ) + + # Draw the type + self.type_path.addText( + -total_width / 2 + 5, + (-total_height / 2) + title_dim["h"] + 5, + title_type_font, + f"{self.type_text}", + ) + + # Position the widget in the center + self.widget.move(int(-self.widget.size().width() / 2), int(-self.widget.size().height() / 2)) + + # Position the pins. Execution pins stay on the same row + if pin_dim: + # y = (-total_height / 2) + title_dim["h"] + title_type_dim["h"] + 5 + y = bg_height - total_height / 2 - 10 + + # Do the execution pins + exe_shifted = False + for pin in self._pins: + if not pin.execution: + continue + if not exe_shifted: + y += pin_dim["h"] + exe_shifted = True + if pin.is_output: + pin.setPos(total_width / 2 - 10, y) + else: + pin.setPos(-total_width / 2 + 10, y) + + # Do the rest of the pins + for pin in self._pins: + if pin.execution: + continue + y += pin_dim["h"] + + if pin.is_output: + pin.setPos(total_width / 2 - 10, y) + else: + pin.setPos(-total_width / 2 + 10, y) + + self._width = total_width + self._height = total_height + + # move the widget to the bottom + self.widget.move(int(-self.widget.size().width() / 2), int(total_height / 2 - self.widget.size().height() + 5)) diff --git a/node_editor/gui/node_list.py b/node_editor/gui/node_list.py index 6ab29d8..05624cc 100644 --- a/node_editor/gui/node_list.py +++ b/node_editor/gui/node_list.py @@ -1,51 +1,71 @@ -from PySide6 import QtWidgets, QtCore, QtGui +from __future__ import annotations +from types import ModuleType +from typing import Any +from typing import Dict +from typing import Optional +from typing import Union -class NodeList(QtWidgets.QListWidget): - def __init__(self, parent=None): - super(NodeList, self).__init__(parent) +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtWidgets - self.addItem("Input") - self.addItem("Output") - self.addItem("And") - self.addItem("Not") - self.addItem("Nor") - self.setDragEnabled(True) # enable dragging +class CustomQListWidgetItem(QtWidgets.QListWidgetItem): # type: ignore + module: Any + class_name: Any + + +class CustomQMimeData(QtCore.QMimeData): # type: ignore + item: CustomQListWidgetItem - def contextMenuEvent(self, event): - menu = QtWidgets.QMenu(self) - pos = event.pos() - # actions - delete_node = QtWidgets.QAction("Delete Node") - edit_node = QtWidgets.QAction("Edit Node") - menu.addAction(delete_node) +class ImportData: + def __init__(self, module: str, class_: str): + self.module: Any = module + self.class_: Any = class_ - action = menu.exec_(self.mapToGlobal(pos)) - if action == delete_node: - item_name = self.selectedItems()[0].text() +# class NodeList(QtWidgets.QListWidget): # type: ignore +# module: str +# class_: str - if item_name not in ["And", "Not", "Input", "Output"]: - print(f"delete node: {item_name}") - else: - print("Cannot delete default nodes") - elif action == edit_node: - print("editing node") +class NodeList(QtWidgets.QListWidget): # type: ignore + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + super().__init__(parent) + self.setDragEnabled(True) # enable dragging + + def update_project(self, imports: Dict[str, Dict[str, Union[type, ModuleType]]]) -> None: - # confirm to open in the editor replacing what is existing + # make an item for each custom class + for name, data in imports.items(): + name = name.replace("_Node", "") - def mousePressEvent(self, event): + item = CustomQListWidgetItem(name) + + item.module = data["module"] + item.class_name = data["class"] + self.addItem(item) + + def mousePressEvent(self, event: QtGui.QMouseEvent) -> None: item = self.itemAt(event.pos()) - name = item.text() - drag = QtGui.QDrag(self) - mime_data = QtCore.QMimeData() + if isinstance(item, CustomQListWidgetItem) and item.text(): + name = item.text() + + drag = QtGui.QDrag(self) + mime_data = CustomQMimeData() + mime_data.setText(name) + mime_data.item = item + drag.setMimeData(mime_data) + + # Drag needs a pixmap or else it'll error due to a null pixmap + pixmap = QtGui.QPixmap(16, 16) + pixmap.fill(QtGui.QColor("darkgray")) + drag.setPixmap(pixmap) + drag.exec_() - mime_data.setText(name) - drag.setMimeData(mime_data) - drag.exec_() + print("Inside drag event from node list") - super(NodeList, self).mousePressEvent(event) + super().mousePressEvent(event) diff --git a/node_editor/gui/node_widget.py b/node_editor/gui/node_widget.py index 3c3c100..6390aef 100644 --- a/node_editor/gui/node_widget.py +++ b/node_editor/gui/node_widget.py @@ -1,85 +1,86 @@ -# from PySide6.QtWidgets import QWidget, QVBoxLayout, QGraphicsScene +from __future__ import annotations -from PySide6 import QtWidgets, QtGui +import json +from typing import Any +from typing import Dict +from typing import List +from typing import Optional -from node_editor.gui.view import View -from node_editor.gui.node import Node +from PySide6 import QtGui +from PySide6 import QtWidgets + +from node_editor.connection import Connection from node_editor.gui.node_editor import NodeEditor +from node_editor.gui.view import View +from node_editor.node import Node +from node_editor.pin import Pin -# import lorem -# import random - - -def create_input(): - node = Node() - node.title = "A" - node.type_text = "input" - node.add_port(name="output", is_output=True) - node.build() - return node - - -def create_output(): - node = Node() - node.title = "A" - node.type_text = "output" - node.add_port(name="input", is_output=False) - node.build() - return node - - -def create_and(): - node = Node() - node.title = "AND" - node.type_text = "built-in" - node.add_port(name="input A", is_output=False) - node.add_port(name="input B", is_output=False) - node.add_port(name="output", is_output=True) - node.build() - return node - - -def create_not(): - node = Node() - node.title = "NOT" - node.type_text = "built-in" - node.add_port(name="input", is_output=False) - node.add_port(name="output", is_output=True) - node.build() - return node - - -def create_nor(): - node = Node() - node.title = "NOR" - node.type_text = "built-in" - node.add_port(name="input", is_output=False) - node.add_port(name="output", is_output=True) - node.build() - return node - - -class NodeScene(QtWidgets.QGraphicsScene): - def dragEnterEvent(self, e): + +class NodeScene(QtWidgets.QGraphicsScene): # type: ignore + def dragEnterEvent(self, e: QtGui.QDragEnterEvent) -> None: e.acceptProposedAction() - def dropEvent(self, e): - # find item at these coordinates - item = self.itemAt(e.scenePos()) - if item.setAcceptDrops == True: - # pass on event to item at the coordinates - try: - item.dropEvent(e) - except RuntimeError: - pass # This will supress a Runtime Error generated when dropping into a widget with no ProxyWidget + def dropEvent(self, e: QtGui.QDropEvent) -> None: + item = self.itemAt(e.scenePos(), QtGui.QTransform()) + if item and hasattr(item, "setAcceptDrops"): + item.dropEvent(e) - def dragMoveEvent(self, e): + def dragMoveEvent(self, e: QtGui.QDragMoveEvent) -> None: e.acceptProposedAction() + def get_items_by_type(self, item_class: type) -> List[Any]: + items = [] + for item in self.items(): + print(f"current item: {item}, class: {item_class}") + if isinstance(item, item_class): + items.append(item) + return items + + def get_total_nodes(self) -> int: + return len(self.get_items_by_type(Node)) + + # TODO Scene should delete the node + # TODO Scene should reorder Node indexes after Node delete + + def delete_node_and_reorder(self, node_to_delete: Node) -> None: + # Delete the node + node_to_delete.delete() + + # Make a mapping of the new indexes for the nodes + nodes = self.get_items_by_type(Node) + new_index_mapping = {node.index: new_index for new_index, node in enumerate(nodes)} + + # Reindex the nodes + for node in nodes: + node.index = new_index_mapping[str(node.index)] + + def delete_connection(self, connection_to_delete: Connection) -> None: + # We an safly delete a connection without having to do anything extra + connection_to_delete.delete() + -class NodeWidget(QtWidgets.QWidget): - def __init__(self, parent): - super(NodeWidget, self).__init__(parent) +class NodeWidget(QtWidgets.QWidget): # type: ignore + """ + Widget for creating and displaying a node editor. + + Attributes: + node_editor (NodeEditor): The node editor object. + scene (NodeScene): The scene object for the node editor. + view (View): The view object for the node editor. + """ + + def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: + """ + Initializes the NodeWidget object. + + Args: + parent (QWidget): The parent widget. + """ + super().__init__(parent) + + self.node_lookup: Dict[int, Node] = ( + {} + ) # A dictionary of nodes, by uuids for faster looking up. Refactor this in the future main_layout = QtWidgets.QVBoxLayout() main_layout.setContentsMargins(0, 0, 0, 0) self.setLayout(main_layout) @@ -95,22 +96,104 @@ def __init__(self, parent): self.view.request_node.connect(self.create_node) - def create_node(self, name): - print("creating node:", name) - - if name == "Input": - node = create_input() - - elif name == "Output": - node = create_output() - elif name == "And": - node = create_and() - elif name == "Not": - node = create_not() - elif name == "Nor": - node = create_nor() - + def create_node(self, node: Node, index: int) -> None: + node.index = index self.scene.addItem(node) - pos = self.view.mapFromGlobal(QtGui.QCursor.pos()) node.setPos(self.view.mapToScene(pos)) + + def load_scene(self, json_path: str, imports: Dict[str, Dict[str, Node]]) -> None: + # load the scene json file + data = None + with open(json_path) as f: + data = json.load(f) + + # clear out the node lookup + self.node_lookup = {} + + # Add the nodes + if data: + for node in data["nodes"]: + try: + info = imports[node["type"]] + except KeyError: + continue + node_item = info["class"]() + node_item.index = node["index"] + self.scene.addItem(node_item) + node_item.setPos(node["x"], node["y"]) + + self.node_lookup[node["index"]] = node_item + + # Add the connections + for c in data["connections"]: + connection = Connection(None) + self.scene.addItem(connection) + + try: + start_pin = self.node_lookup[c["start_id"]].get_pin(c["start_pin"]) + end_pin = self.node_lookup[c["end_id"]].get_pin(c["end_pin"]) + except KeyError: # Node might be missing so we skip it + continue + + if start_pin: + connection.set_start_pin(start_pin) + + if end_pin: + connection.set_end_pin(end_pin) + connection.update_start_and_end_pos() + + def save_project(self, json_path: str) -> None: + # from collections import OrderedDict + + # TODO possibly an ordered dict so things stay in order (better for git changes, and manual editing) + # Maybe connections will need an index for each so they can be sorted and kept in order. + scene: Dict[str, List[Any]] = {"nodes": [], "connections": []} + + # Need the nodes, and connections of ports to nodes + for item in self.scene.items(): + # Connections + if isinstance(item, Connection): + # print(f"Name: {item}") + nodes = item.nodes() + if nodes[0]: + start_id = str(nodes[0].index) + else: + continue + end_id = str(nodes[1].index) # type: ignore + start_pin = item.start_pin.name # type: ignore + end_pin = item.end_pin.name # type: ignore + # print(f"Node ids {start_id, end_id}") + # print(f"connected ports {item.start_pin.name(), item.end_pin.name()}") + + connection = { + "start_id": start_id, + "end_id": end_id, + "start_pin": start_pin, + "end_pin": end_pin, + } + scene["connections"].append(connection) + continue + + # Pins + if isinstance(item, Pin): + continue + + # Nodes + if isinstance(item, Node): + # print("found node") + pos = item.pos().toPoint() + x, y = pos.x(), pos.y() + # print(f"pos: {x, y}") + + obj_type = type(item).__name__ + # print(f"node type: {obj_type}") + + node_id = str(item.index) + + node = {"type": obj_type, "x": x, "y": y, "index": node_id} + scene["nodes"].append(node) + + # Write the items_info dictionary to a JSON file + with open(json_path, "w") as f: + json.dump(scene, f, indent=4) diff --git a/node_editor/gui/palette.py b/node_editor/gui/palette.py deleted file mode 100644 index 89e107c..0000000 --- a/node_editor/gui/palette.py +++ /dev/null @@ -1,23 +0,0 @@ -from PySide6 import QtGui -from PySide6.QtGui import QPalette, QColor -from PySide6.QtCore import Qt - -palette = QtGui.QPalette() - -palette.setColor(QPalette.Window, QColor(27, 35, 38)) -palette.setColor(QPalette.WindowText, QColor(234, 234, 234)) -palette.setColor(QPalette.Base, QColor(27, 35, 38)) -palette.setColor(QPalette.Disabled, QPalette.Base, QColor(27 + 5, 35 + 5, 38 + 5)) -palette.setColor(QPalette.AlternateBase, QColor(12, 15, 16)) -palette.setColor(QPalette.ToolTipBase, QColor(27, 35, 38)) -palette.setColor(QPalette.ToolTipText, Qt.white) -palette.setColor(QPalette.Text, QColor(200, 200, 200)) -palette.setColor(QPalette.Disabled, QPalette.Text, QColor(100, 100, 100)) -palette.setColor(QPalette.Button, QColor(27, 35, 38)) -palette.setColor(QPalette.ButtonText, Qt.white) -palette.setColor(QPalette.BrightText, QColor(100, 215, 222)) -palette.setColor(QPalette.Link, QColor(126, 71, 130)) -palette.setColor(QPalette.Highlight, QColor(126, 71, 130)) -palette.setColor(QPalette.HighlightedText, Qt.white) -palette.setColor(QPalette.Disabled, QPalette.Light, Qt.black) -palette.setColor(QPalette.Disabled, QPalette.Shadow, QColor(12, 15, 16)) diff --git a/node_editor/gui/pin_graphics.py b/node_editor/gui/pin_graphics.py new file mode 100644 index 0000000..31981ab --- /dev/null +++ b/node_editor/gui/pin_graphics.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import Optional +from typing import Union + +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtWidgets +from PySide6.QtCore import Qt + + +class Pin_Graphics(QtWidgets.QGraphicsPathItem): # type: ignore + def __init__(self, parent: Optional[QtWidgets.QGraphicsItem], scene: QtWidgets.QGraphicsScene) -> None: + super().__init__(parent) + + self.radius_: float = 5 + self.margin: int = 2 + + self.execution: bool = False + + path: QtGui.QPainterPath = QtGui.QPainterPath() + path.addEllipse(-self.radius_, -self.radius_, 2 * self.radius_, 2 * self.radius_) + + self.setPath(path) + + self.setFlag(QtWidgets.QGraphicsPathItem.ItemSendsScenePositionChanges) + self.font: QtGui.QFont = QtGui.QFont() + self.font_metrics: QtGui.QFontMetrics = QtGui.QFontMetrics(self.font) + + self.pin_text_height: int = self.font_metrics.height() + + self.is_output: bool = False + + self.text_path: QtGui.QPainterPath = QtGui.QPainterPath() + + self.name: str = "" # Add this line to define self.name + self.connection: Optional[Union[QtWidgets.QGraphicsItem, QtCore.QObject]] = ( + None # Add this line to define self.connection + ) + + def set_execution(self, execution: bool) -> None: + if execution: + path: QtGui.QPainterPath = QtGui.QPainterPath() + + points: list[QtCore.QPointF] = [ + QtCore.QPointF(-6, -7), + QtCore.QPointF(-6, 7), + QtCore.QPointF(-2, 7), + QtCore.QPointF(6, 0), + QtCore.QPointF(-2, -7), + QtCore.QPointF(-6, -7), + ] + path.addPolygon(QtGui.QPolygonF(points)) + self.setPath(path) + + def set_name(self, name: str) -> None: + self.name = name # Add this line to set self.name + nice_name: str = self.name.replace("_", " ").title() + self.pin_text_width: int = self.font_metrics.horizontalAdvance(nice_name) + + if self.is_output: + x = -self.radius_ - self.margin - self.pin_text_width + else: + x = self.radius_ + self.margin + + y: float = self.pin_text_height / 4 + + self.text_path.addText(x, y, self.font, nice_name) + + def paint( + self, + painter: QtGui.QPainter, + option: Optional[QtWidgets.QStyleOptionGraphicsItem] = None, + widget: Optional[QtWidgets.QWidget] = None, + ) -> None: + if self.execution: + painter.setPen(Qt.GlobalColor.white) + else: + painter.setPen(Qt.GlobalColor.green) + + if self.is_connected(): + if self.execution: + painter.setBrush(Qt.GlobalColor.white) + else: + painter.setBrush(Qt.GlobalColor.green) + else: + painter.setBrush(Qt.BrushStyle.NoBrush) + + painter.drawPath(self.path()) + + # Draw text + if not self.execution: + painter.setPen(Qt.NoPen) + painter.setBrush(Qt.white) + painter.drawPath(self.text_path) + + def itemChange(self, change: QtWidgets.QGraphicsItem.GraphicsItemChange, value: object) -> object: + if change == QtWidgets.QGraphicsItem.GraphicsItemChange.ItemScenePositionHasChanged and self.connection: + if hasattr(self.connection, "update_start_and_end_pos"): + self.connection.update_start_and_end_pos() + return value + + def is_connected(self) -> bool: + # Add this method to resolve the 'is_connected' call in the paint method + return self.connection is not None diff --git a/node_editor/gui/port.py b/node_editor/gui/port.py deleted file mode 100644 index d71775f..0000000 --- a/node_editor/gui/port.py +++ /dev/null @@ -1,105 +0,0 @@ -from PySide6 import QtWidgets, QtGui, QtCore - - -class Port(QtWidgets.QGraphicsPathItem): - def __init__(self, parent, scene): - super(Port, self).__init__(parent) - - self.radius_ = 5 - self.margin = 2 - - path = QtGui.QPainterPath() - path.addEllipse( - -self.radius_, -self.radius_, 2 * self.radius_, 2 * self.radius_ - ) - self.setPath(path) - - self.setFlag(QtWidgets.QGraphicsPathItem.ItemSendsScenePositionChanges) - self.font = QtGui.QFont() - self.font_metrics = QtGui.QFontMetrics(self.font) - - self.port_text_height = self.font_metrics.height() - - self._is_output = False - self._name = None - self.margin = 2 - - self.m_node = None - self.connection = None - - self.text_path = QtGui.QPainterPath() - - def set_is_output(self, is_output): - self._is_output = is_output - - def set_name(self, name): - self._name = name - nice_name = self._name.replace("_", " ").title() - self.port_text_width = self.font_metrics.horizontalAdvance(nice_name) - - if self._is_output: - x = -self.radius_ - self.margin - self.port_text_width - y = self.port_text_height / 4 - - self.text_path.addText(x, y, self.font, nice_name) - - else: - x = self.radius_ + self.margin - y = self.port_text_height / 4 - - self.text_path.addText(x, y, self.font, nice_name) - - def set_node(self, node): - self.m_node = node - - def set_port_flags(self, flags): - self.m_port_flags = flags - - def set_ptr(self, ptr): - self.m_ptr = ptr - - def name(self): - return self._name - - def is_output(self): - return self._is_output - - def node(self): - return self.m_node - - def paint(self, painter, option=None, widget=None): - painter.setPen(QtGui.QPen(1)) - painter.setBrush(QtCore.Qt.green) - painter.drawPath(self.path()) - - painter.setPen(QtCore.Qt.NoPen) - painter.setBrush(QtCore.Qt.white) - painter.drawPath(self.text_path) - - def clear_connection(self): - if self.connection: - self.connection.delete() - - def can_connect_to(self, port): - print(port.node(), self.node()) - if not port: - return False - if port.node() == self.node(): - return False - - if self._is_output == port._is_output: - return False - - return True - - def is_connected(self): - if self.connection: - return True - return False - - def itemChange(self, change, value): - if change == QtWidgets.QGraphicsItem.ItemScenePositionHasChanged: - if self.connection: - self.connection.update_start_and_end_pos() - - return value diff --git a/node_editor/gui/view.py b/node_editor/gui/view.py index 31d6dfe..2a41e04 100644 --- a/node_editor/gui/view.py +++ b/node_editor/gui/view.py @@ -1,38 +1,49 @@ -from PySide6 import QtCore, QtGui, QtWidgets, QtOpenGLWidgets +from __future__ import annotations -from node_editor.gui.connection import Connection -from node_editor.gui.node import Node +from typing import Optional +from PySide6 import QtCore +from PySide6 import QtGui +from PySide6 import QtOpenGLWidgets +from PySide6 import QtWidgets -class View(QtWidgets.QGraphicsView): - _background_color = QtGui.QColor(38, 38, 38) +from node_editor.node import Node - _grid_pen_s = QtGui.QPen(QtGui.QColor(52, 52, 52, 255), 0.5) - _grid_pen_l = QtGui.QPen(QtGui.QColor(22, 22, 22, 255), 1.0) - _grid_size_fine = 15 - _grid_size_course = 150 +class View(QtWidgets.QGraphicsView): # type: ignore + """ + View class for node editor. + """ - _mouse_wheel_zoom_rate = 0.0015 + _background_color: QtGui.QColor = QtGui.QColor(38, 38, 38) - request_node = QtCore.Signal(str) + _grid_pen_s: QtGui.QPen = QtGui.QPen(QtGui.QColor(52, 52, 52, 255), 0.5) + _grid_pen_l: QtGui.QPen = QtGui.QPen(QtGui.QColor(22, 22, 22, 255), 1.0) - def __init__(self, parent): - super(View, self).__init__(parent) + _grid_size_fine: int = 15 + _grid_size_course: int = 150 + + _mouse_wheel_zoom_rate: float = 0.0015 + + request_node = QtCore.Signal(object, int) + + def __init__(self, parent: QtWidgets.QWidget) -> None: + super().__init__(parent) self.setRenderHint(QtGui.QPainter.Antialiasing) - self._manipulationMode = 0 + self._manipulationMode: int = 0 - gl_format = QtGui.QSurfaceFormat() + gl_format: QtGui.QSurfaceFormat = QtGui.QSurfaceFormat() gl_format.setSamples(10) QtGui.QSurfaceFormat.setDefaultFormat(gl_format) - gl_widget = QtOpenGLWidgets.QOpenGLWidget() + gl_widget: QtOpenGLWidgets.QOpenGLWidget = QtOpenGLWidgets.QOpenGLWidget() - self.currentScale = 1 - self._pan = False - self._pan_start_x = 0 - self._pan_start_y = 0 - self._numScheduledScalings = 0 - self.lastMousePos = QtCore.QPoint() + self.currentScale: float = 1 + self._pan: bool = False + self._pan_start_x: int = 0 + self._pan_start_y: int = 0 + self._numScheduledScalings: int = 0 + self.lastMousePos: QtCore.QPoint = QtCore.QPoint() + self.anim: Optional[QtCore.QTimeLine] = None self.setViewport(gl_widget) @@ -42,14 +53,18 @@ def __init__(self, parent): self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) self.setFrameShape(QtWidgets.QFrame.NoFrame) - def wheelEvent(self, event): + def wheelEvent(self, event: QtGui.QWheelEvent) -> None: + """ + Handles the wheel events, e.g. zoom in/out. + :param event: Wheel event. + """ # sometimes you can triger the wheen when panning so we disable when panning if self._pan: return - num_degrees = event.delta() / 8.0 - num_steps = num_degrees / 5.0 + num_degrees = event.angleDelta() / 8.0 + num_steps = num_degrees.y() / 5.0 self._numScheduledScalings += num_steps # If the user moved the wheel another direction, we reset previously scheduled scalings @@ -63,21 +78,34 @@ def wheelEvent(self, event): self.anim.finished.connect(self.anim_finished) self.anim.start() - def scaling_time(self, x): + def scaling_time(self, x: float) -> None: + """ + Updates the current scale based on the wheel events. + + :param x: The value of the current time. + """ factor = 1.0 + self._numScheduledScalings / 300.0 self.currentScale *= factor self.scale(factor, factor) - def anim_finished(self): + def anim_finished(self) -> None: + """ + Called when the zoom animation is finished. + """ if self._numScheduledScalings > 0: self._numScheduledScalings -= 1 else: self._numScheduledScalings += 1 - def drawBackground(self, painter, rect): + def drawBackground(self, painter: QtGui.QPainter, rect: QtCore.QRectF) -> None: + """ + Draws the background for the node editor view. + :param painter: The painter to draw with. + :param rect: The rectangle to be drawn. + """ painter.fillRect(rect, self._background_color) left = int(rect.left()) - (int(rect.left()) % self._grid_size_fine) @@ -123,69 +151,73 @@ def drawBackground(self, painter, rect): y += self._grid_size_course painter.drawLines(gridLines) - return super(View, self).drawBackground(painter, rect) + super().drawBackground(painter, rect) - def contextMenuEvent(self, event): - cursor = QtGui.QCursor() - # origin = self.mapFromGlobal(cursor.pos()) - pos = self.mapFromGlobal(cursor.pos()) + def contextMenuEvent(self, event: QtGui.QContextMenuEvent) -> None: + """ + This method is called when a context menu event is triggered in the view. It finds the item at the + event position and shows a context menu if the item is a Node. + """ item = self.itemAt(event.pos()) if item: if isinstance(item, Node): - print("Found Node", item) - menu = QtWidgets.QMenu(self) - - hello_action = QtWidgets.QAction("Hello", self) - - menu.addAction(hello_action) - action = menu.exec_(self.mapToGlobal(pos)) - - if action == hello_action: - print("Hello") - - def dragEnterEvent(self, e): - + def dragEnterEvent(self, e: QtGui.QDragEnterEvent) -> None: + """ + This method is called when a drag and drop event enters the view. It checks if the mime data format + is "text/plain" and accepts or ignores the event accordingly. + """ if e.mimeData().hasFormat("text/plain"): e.accept() else: e.ignore() - def dropEvent(self, e): - drop_node_name = e.mimeData().text() - self.request_node.emit(drop_node_name) - - def mousePressEvent(self, event): + def dropEvent(self, e: QtGui.QDropEvent) -> None: + """ + This method is called when a drag and drop event is dropped onto the view. It retrieves the name of the + dropped node from the mime data and emits a signal to request the creation of the corresponding node. + """ + node = e.mimeData().item.class_name + next_index = self.scene().get_total_nodes() + self.request_node.emit(node(), next_index) + + def mousePressEvent(self, event: QtGui.QMouseEvent) -> None: + """ + This method is called when a mouse press event occurs in the view. It sets the cursor to a closed + hand cursor and enables panning if the middle mouse button is pressed. + """ if event.button() == QtCore.Qt.MiddleButton: self._pan = True self._pan_start_x = event.x() self._pan_start_y = event.y() self.setCursor(QtCore.Qt.ClosedHandCursor) - return super(View, self).mousePressEvent(event) + super().mousePressEvent(event) - def mouseReleaseEvent(self, event): + def mouseReleaseEvent(self, event: QtGui.QMouseEvent) -> None: + """ + This method is called when a mouse release event occurs in the view. It sets the cursor back to the + arrow cursor and disables panning if the middle mouse button is released. + """ if event.button() == QtCore.Qt.MiddleButton: self._pan = False self.setCursor(QtCore.Qt.ArrowCursor) - return super(View, self).mouseReleaseEvent(event) + super().mouseReleaseEvent(event) - def mouseMoveEvent(self, event): + def mouseMoveEvent(self, event: QtGui.QMouseEvent) -> None: + """ + This method is called when a mouse move event occurs in the view. It pans the view if the middle mouse button is + pressed and moves the mouse. + """ if self._pan: + self.horizontalScrollBar().setValue(self.horizontalScrollBar().value() - (event.x() - self._pan_start_x)) - self.horizontalScrollBar().setValue( - self.horizontalScrollBar().value() - (event.x() - self._pan_start_x) - ) - - self.verticalScrollBar().setValue( - self.verticalScrollBar().value() - (event.y() - self._pan_start_y) - ) + self.verticalScrollBar().setValue(self.verticalScrollBar().value() - (event.y() - self._pan_start_y)) self._pan_start_x = event.x() self._pan_start_y = event.y() - return super(View, self).mouseMoveEvent(event) - + super().mouseMoveEvent(event) diff --git a/node_editor/node.py b/node_editor/node.py new file mode 100644 index 0000000..ceae71d --- /dev/null +++ b/node_editor/node.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import List +from typing import Optional + +from node_editor.gui.node_graphics import Node_Graphics +from node_editor.pin import Pin + +# from PySide6 import QtCore +# from PySide6 import QtGui +# from PySide6 import QtWidgets +# from PySide6.QtCore import Qt +# rom node_editor.common import Node_Status + + +class Node(Node_Graphics): + def __init__(self) -> None: + super().__init__() + self._pins: List[Pin] = [] + + # Override me + def init_widget(self) -> None: + pass + + def compute(self) -> None: + raise NotImplementedError("compute is not implemented") + + def execute(self) -> None: + # Get the values from the input pins + self.execute_inputs() + + # Compute the value + self.compute() + + # execute nodes connected to output + self.execute_outputs() + + def execute_inputs(self) -> None: + pass + + def execute_outputs(self) -> None: + pass + + def delete(self) -> None: + """Deletes the connection. + + This function removes any connected pins by calling :any:`Port.remove_connection` for each pin + connected to this connection. After all connections have been removed, the stored :any:`Port` + references are set to None. Finally, :any:`QGraphicsScene.removeItem` is called on the scene to + remove this widget. + + Returns: + None + """ + + to_delete = [pin.connection for pin in self._pins if pin.connection] + for connection in to_delete: + connection.delete() + + self.scene().removeItem(self) + + def get_pin(self, name: str) -> Optional[Pin]: + for pin in self._pins: + if pin.name == name: + return pin + return None + + def add_pin(self, name: str, is_output: bool, execution: bool = False) -> None: + """ + Adds a new pin to the node. + + Args: + name (str): The name of the new pin. + is_output (bool, optional): True if the new pin is an output pin, False if it's an input pin. Default is + False. + flags (int, optional): A set of flags to apply to the new pin. Default is 0. + ptr (Any, optional): A pointer to associate with the new pin. Default is None. + + Returns: + None: This method doesn't return anything. + + """ + pin = Pin(self, self.scene()) + pin.is_output = is_output + pin.set_name(name) + pin.node = self + pin.set_execution(execution) + + self._pins.append(pin) + + def select_connections(self, value: bool) -> None: + """ + Sets the highlighting of all connected pins to the specified value. + + This method takes a boolean value `value` as input and sets the `_do_highlight` attribute of all + connected pins to this value. If a pin is not connected, this method does nothing for that pin. + After setting the `_do_highlight` attribute for all connected pins, the `update_path` method is + called for each connection. + + Args: + value: A boolean value indicating whether to highlight the connected pins or not. + + Returns: + None. + """ + + for pin in self._pins: + if pin.connection: + pin.connection._do_highlight = value + pin.connection.update_path() diff --git a/node_editor/pin.py b/node_editor/pin.py new file mode 100644 index 0000000..4cb1345 --- /dev/null +++ b/node_editor/pin.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any +from typing import Optional + +from node_editor.gui.pin_graphics import Pin_Graphics + + +class Pin(Pin_Graphics): + def __init__(self, parent: Any, scene: Any) -> None: + super().__init__(parent, scene) + + self.name: str = "" + self.node: Optional[Any] = None + self.connection: Optional[Any] = None + self.execution: bool = False + + def set_execution(self, execution: bool) -> None: + self.execution = execution + super().set_execution(execution) + + def set_name(self, name: str) -> None: + self.name = name + super().set_name(name) + + def clear_connection(self) -> None: + if self.connection: + self.connection.delete() + + def can_connect_to(self, pin: Optional[Pin]) -> bool: + if not pin: + return False + if pin.node == self.node: + return False + + return self.is_output != pin.is_output + + def is_connected(self) -> bool: + return bool(self.connection) + + def get_data(self) -> Any: + pass + # Get a list of nodes in the order to be computed. Forward evaluation by default. + # def get_node_compute_order(node, forward=False): + # Create a set to keep track of visited nodes + # visited = set() + # Create a stack to keep track of nodes to visit + # stack = [node] + # Create a list to store the evaluation order + # order = [] + + # Get the next nodes that this node is dependent on + # def get_next_input_node(node): + # pass + + # Get the next nodes that is affected by the input node. + # def get_next_output_node(node): + # pass + + # if pin isn't connected, return it current data + + # get the evalutation order of the owning node of the pin + + # loop over each node and process it + + # return the pin's data diff --git a/pre-commit_check.ps1 b/pre-commit_check.ps1 new file mode 100644 index 0000000..58a5cd2 --- /dev/null +++ b/pre-commit_check.ps1 @@ -0,0 +1,3 @@ +clear-history +clear +pre-commit run --all-files diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..f97a60c --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,5 @@ +pre-commit +coverage +pytest +pytest-qt +types-requests diff --git a/requirements.txt b/requirements.txt index e3f31e2..8a16d08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -PySide6 \ No newline at end of file +pyqtdarktheme +PySide6