diff --git a/cadquery/occ_impl/assembly.py b/cadquery/occ_impl/assembly.py index b2efc6994..9fa75add2 100644 --- a/cadquery/occ_impl/assembly.py +++ b/cadquery/occ_impl/assembly.py @@ -117,6 +117,14 @@ def toTuple(self) -> Tuple[float, float, float, float]: return (rgb.Red(), rgb.Green(), rgb.Blue(), a) + def __getstate__(self) -> Tuple[float, float, float, float]: + + return self.toTuple() + + def __setstate__(self, data: Tuple[float, float, float, float]): + + self.wrapped = Quantity_ColorRGBA(*data) + class AssemblyProtocol(Protocol): @property diff --git a/cadquery/occ_impl/geom.py b/cadquery/occ_impl/geom.py index 89bd90c08..3a2bc88ef 100644 --- a/cadquery/occ_impl/geom.py +++ b/cadquery/occ_impl/geom.py @@ -2,6 +2,8 @@ from typing import overload, Sequence, Union, Tuple, Type, Optional, Iterator +from io import BytesIO + from OCP.gp import ( gp_Vec, gp_Ax1, @@ -22,6 +24,7 @@ from OCP.BRepMesh import BRepMesh_IncrementalMesh from OCP.TopoDS import TopoDS_Shape from OCP.TopLoc import TopLoc_Location +from OCP.BinTools import BinTools_LocationSet from ..types import Real from ..utils import multimethod @@ -256,6 +259,16 @@ def transform(self, T: "Matrix") -> "Vector": return Vector(gp_Vec(pnt_t.XYZ())) + def __getstate__(self) -> tuple[float, float, float]: + + return (self.x, self.y, self.z) + + def __setstate__(self, state: tuple[float, float, float]): + + self._wrapped = gp_Vec() + + self.x, self.y, self.z = state + class Matrix: """A 3d , 4x4 transformation matrix. @@ -400,6 +413,19 @@ def __repr__(self) -> str: matrix_str = ",\n ".join(str(matrix_transposed[i::4]) for i in range(4)) return f"Matrix([{matrix_str}])" + def __getstate__(self) -> list[list[float]]: + + trsf = self.wrapped + return [[trsf.Value(i, j) for j in range(1, 5)] for i in range(1, 4)] + + def __setstate__(self, state: list[list[float]]): + + trsf = self.wrapped = gp_GTrsf() + + for i in range(3): + for j in range(4): + trsf.SetValue(i + 1, j + 1, state[i][j]) + class Plane(object): """A 2D coordinate system in space @@ -577,6 +603,7 @@ def __init__( xDir = Vector(xDir) if xDir.Length == 0.0: raise ValueError("xDir should be non null") + self._setPlaneDir(xDir) self.origin = Vector(origin) @@ -787,6 +814,14 @@ def toPln(self) -> gp_Pln: return gp_Pln(gp_Ax3(self.origin.toPnt(), self.zDir.toDir(), self.xDir.toDir())) + def __getstate__(self) -> Tuple[Vector, Vector, Vector, Vector]: + + return (self.xDir, self.yDir, self.zDir, self._origin) + + def __setstate__(self, data: Tuple[Vector, Vector, Vector, Vector]): + + self.xDir, self.yDir, self.zDir, self.origin = data + class BoundBox(object): """A BoundingBox for an object or set of objects. Wraps the OCP one""" @@ -1065,3 +1100,22 @@ def toTuple(self) -> Tuple[Tuple[float, float, float], Tuple[float, float, float rx, ry, rz = rot.GetEulerAngles(gp_EulerSequence.gp_Extrinsic_XYZ) return rv_trans, (degrees(rx), degrees(ry), degrees(rz)) + + def __getstate__(self) -> BytesIO: + + rv = BytesIO() + + ls = BinTools_LocationSet() + ls.Add(self.wrapped) + ls.Write(rv) + + rv.seek(0) + + return rv + + def __setstate__(self, data: BytesIO): + + ls = BinTools_LocationSet() + ls.Read(data) + + self.wrapped = ls.Location(1) diff --git a/cadquery/occ_impl/shapes.py b/cadquery/occ_impl/shapes.py index d800aa3e8..bb9b731a0 100644 --- a/cadquery/occ_impl/shapes.py +++ b/cadquery/occ_impl/shapes.py @@ -1650,6 +1650,24 @@ def export( self, fname, tolerance=tolerance, angularTolerance=angularTolerance, opt=opt ) + def __getstate__(self) -> Tuple[BytesIO, bool]: + + data = BytesIO() + + BinTools.Write_s(self.wrapped, data) + data.seek(0) + + return (data, self.forConstruction) + + def __setstate__(self, data: Tuple[BytesIO, bool]): + + wrapped = TopoDS_Shape() + + BinTools.Read_s(wrapped, data[0]) + + self.wrapped = wrapped + self.forConstruction = data[1] + class ShapeProtocol(Protocol): @property diff --git a/tests/test_pickle.py b/tests/test_pickle.py new file mode 100644 index 000000000..5c038fda2 --- /dev/null +++ b/tests/test_pickle.py @@ -0,0 +1,47 @@ +from pickle import loads, dumps + +from cadquery import ( + Vector, + Matrix, + Plane, + Location, + Shape, + Sketch, + Assembly, + Color, + Workplane, +) +from cadquery.func import box + +from pytest import mark + + +@mark.parametrize( + "obj", + [ + Vector(2, 3, 4), + Matrix(), + Plane((-2, 1, 1)), + Location(1, 2, 4), + Sketch().rect(1, 1), + Color("red"), + Workplane().sphere(1), + ], +) +def test_simple(obj): + + assert isinstance(loads(dumps(obj)), type(obj)) + + +def test_shape(): + + s = Shape(box(1, 1, 1).wrapped) + + assert isinstance(loads(dumps(s)), Shape) + + +def test_assy(): + + assy = Assembly().add(box(1, 1, 1), color=Color("blue")).add(box(2, 2, 2)) + + assert isinstance(loads(dumps(assy)), Assembly)