Skip to content

Commit

Permalink
add docstring and typing stubs (#6)
Browse files Browse the repository at this point in the history
* add docstring

* fix

* add pyi

* update

---------

Co-authored-by: tang zhixiong <[email protected]>
  • Loading branch information
district10 and zhixiong-tang authored Oct 1, 2024
1 parent 225ae16 commit 3b69952
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ cython_debug/

_skbuild/
.pyodide-xbuildenv/
stubs
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ python_sdist:
python_test: pytest
.PHONY: build

restub:
pybind11-stubgen fast_viterbi._core -o stubs
cp stubs/fast_viterbi/_core.pyi src/fast_viterbi

# conda create -y -n py37 python=3.7
# conda create -y -n py38 python=3.8
# conda create -y -n py39 python=3.9
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "fast_viterbi"
version = "0.1.1"
version = "0.1.2"
description="a viterbi algo collection"
readme = "README.md"
authors = [
Expand Down
104 changes: 104 additions & 0 deletions src/fast_viterbi/_core.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Pybind11 example plugin
-----------------------
.. currentmodule:: scikit_build_example
.. autosummary::
:toctree: _generate
add
subtract
"""

from __future__ import annotations

import typing

__all__ = ["FastViterbi", "add", "subtract"]

class FastViterbi:
def __init__(
self,
K: int,
N: int,
scores: dict[tuple[tuple[int, int], tuple[int, int]], float],
) -> None:
"""
Initialize FastViterbi object.
Args:
K (int): Number of nodes per layer.
N (int): Number of layers.
scores (dict): Scores for node transitions.
"""
def all_road_paths(self) -> list[list[int]]:
"""
Get all road paths.
Returns:
list: All road paths in the graph.
"""
@typing.overload
def inference(self) -> tuple[float, list[int]]:
"""
Perform inference without a road path.
Returns:
tuple: Best path and its score.
"""
@typing.overload
def inference(self, road_path: list[int]) -> tuple[float, list[int], list[int]]:
"""
Perform inference with a given road path.
Args:
road_path (list): List of road indices representing a path.
Returns:
tuple: Best path and its score.
"""
def scores(self, node_path: list[int]) -> list[float]:
"""
Get scores for a given node path.
Args:
node_path (list): List of node indices representing a path.
Returns:
float: Total score for the given path.
"""
def setup_roads(self, roads: list[list[int]]) -> bool:
"""
Set up roads for the Viterbi algorithm.
Args:
roads (list): List of road sequences.
"""
def setup_shortest_road_paths(
self, sp_paths: dict[tuple[tuple[int, int], tuple[int, int]], list[int]]
) -> bool:
"""
Set up shortest road paths.
Args:
sp_paths (dict): Dictionary of shortest paths between nodes.
"""

def add(arg0: int, arg1: int) -> int:
"""
Add two numbers
Some other explanation about the add function.
"""

def subtract(arg0: int, arg1: int) -> int:
"""
Subtract two numbers
Some other explanation about the subtract function.
"""

__version__: str = "0.1.2"
Empty file added src/fast_viterbi/py.typed
Empty file.
63 changes: 56 additions & 7 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,19 +398,68 @@ PYBIND11_MODULE(_core, m) {
using NodeIndex = FastViterbi::NodeIndex;
py::class_<FastViterbi>(m, "FastViterbi", py::module_local(), py::dynamic_attr()) //
.def(py::init<int, int, const std::map<std::tuple<NodeIndex, NodeIndex>, double> &>(), //
"K"_a, "N"_a, "scores"_a)
"K"_a, "N"_a, "scores"_a,
R"pbdoc(
Initialize FastViterbi object.
Args:
K (int): Number of nodes per layer.
N (int): Number of layers.
scores (dict): Scores for node transitions.
)pbdoc")
//
.def("scores", &FastViterbi::scores, "node_path")
.def("scores", &FastViterbi::scores, "node_path"_a,
R"pbdoc(
Get scores for a given node path.
Args:
node_path (list): List of node indices representing a path.
Returns:
float: Total score for the given path.
)pbdoc")
//
.def("inference", py::overload_cast<>(&FastViterbi::inference, py::const_))
.def("inference", py::overload_cast<>(&FastViterbi::inference, py::const_),
R"pbdoc(
Perform inference without a road path.
Returns:
tuple: Best path and its score.
)pbdoc")
//
.def("setup_roads", &FastViterbi::setup_roads, "roads"_a)
.def("setup_shortest_road_paths", &FastViterbi::setup_shortest_road_paths, "sp_paths"_a)
.def("setup_roads", &FastViterbi::setup_roads, "roads"_a,
R"pbdoc(
Set up roads for the Viterbi algorithm.
Args:
roads (list): List of road sequences.
)pbdoc")
.def("setup_shortest_road_paths", &FastViterbi::setup_shortest_road_paths, "sp_paths"_a,
R"pbdoc(
Set up shortest road paths.
Args:
sp_paths (dict): Dictionary of shortest paths between nodes.
)pbdoc")
//
.def("all_road_paths", &FastViterbi::all_road_paths)
.def("all_road_paths", &FastViterbi::all_road_paths,
R"pbdoc(
Get all road paths.
Returns:
list: All road paths in the graph.
)pbdoc")
.def("inference", py::overload_cast<const std::vector<int64_t> &>(&FastViterbi::inference, py::const_),
"road_path"_a, py::call_guard<py::gil_scoped_release>())
"road_path"_a, py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Perform inference with a given road path.
Args:
road_path (list): List of road indices representing a path.
Returns:
tuple: Best path and its score.
)pbdoc")
//
;

Expand Down
2 changes: 1 addition & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def test_version():
assert m.__version__ == "0.1.1"
assert m.__version__ == "0.1.2"


def test_add():
Expand Down

0 comments on commit 3b69952

Please sign in to comment.