Enhancement of PyTorch connector #787

Open
2 of 6 tasks
edoaltamura opened this issue Mar 11, 2024 · 2 comments
Labels: Connector: PyTorch 🔦 (Relevant to optional packages, such as external connectors) · type: enhancement ✨ (Features or aspects to improve)


edoaltamura commented Mar 11, 2024

What should we add?

Given the demand for robust coupling with PyTorch, we propose to enhance the TorchConnector module in two phases.

Phase 1 will involve no or only minor changes to the user-facing API. The changes will focus on refactoring the backend connector, bug fixes, and compatibility robustness.

Phase 2 will focus on upgrading the interoperability with PyTorch and may involve changes to the API and UX. These changes are likely to be introduced incrementally after the version 0.8 roll-out, leading up to a stable 1.0.

  • Adding version compatibility between Qiskit Machine Learning and PyTorch, with relevant deprecation warnings and supported features (see the version-gate sketch after this list)
  • Improving the test coverage
  • Improving the documentation
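
As an illustration of the first task, a minimal sketch of what a version gate with a deprecation warning could look like. The helper name and the minimum tested version are placeholders, not decisions:

import warnings

try:
    import torch
except ImportError:
    torch = None

# Hypothetical floor; the actual supported range is to be decided.
_MIN_TESTED_TORCH = (1, 13)

def _warn_on_untested_torch() -> None:
    """Warn when the installed PyTorch is older than the tested floor."""
    if torch is None:
        return
    major_minor = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
    if major_minor < _MIN_TESTED_TORCH:
        warnings.warn(
            f"PyTorch {torch.__version__} is older than the oldest version tested "
            "with TorchConnector; support may be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )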
edoaltamura (Author) commented:

This problem could be broken down into two steps.

Refactoring _TorchNNFunction as a standalone function

This is currently nested within TorchConnector. To improve modularity, we can move it out of the class and turn it into a pair of standalone functions used within TorchConnector:

from typing import Any, Tuple

import torch
from torch import Tensor

from qiskit_machine_learning.exceptions import QiskitMachineLearningError
from qiskit_machine_learning.neural_networks import NeuralNetwork


def torch_nn_function_forward(
    input_data: Tensor,
    weights: Tensor,
    neural_network: NeuralNetwork,
    sparse: bool,
) -> Tensor:
    """Forward pass computation."""
    if input_data.shape[-1] != neural_network.num_inputs:
        raise QiskitMachineLearningError(
            f"Invalid input dimension! Received {input_data.shape} and "
            f"expected input compatible to {neural_network.num_inputs}"
        )

    # Evaluate the (NumPy-based) network outside the autograd graph.
    result = neural_network.forward(
        input_data.detach().cpu().numpy(), weights.detach().cpu().numpy()
    )
    if sparse:
        if not neural_network.sparse:
            raise QiskitMachineLearningError(
                "TorchConnector configured as sparse, the network must be sparse as well"
            )
        # Handle sparse output, producing ``result_tensor``
        # (Implementation of sparse handling here)
        raise NotImplementedError("Sparse handling elided in this sketch.")
    else:
        if neural_network.sparse:
            # Handle conversion to dense if necessary
            # (Implementation of dense handling here)
            pass
        result_tensor = torch.as_tensor(result, dtype=torch.float)

    # Unbatched inputs yield unbatched outputs.
    if len(input_data.shape) == 1:
        result_tensor = result_tensor[0]

    return result_tensor.to(input_data.device)
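
As a side note, a possible shape for the two elided branches, assuming the network returns a pydata/sparse COO array whenever neural_network.sparse is true (this data format is an assumption, not a settled design point):

# Sparse branch: build a torch sparse tensor from the COO coordinates and data.
result_tensor = torch.sparse_coo_tensor(result.coords, result.data)

# Dense branch: densify a sparse result before wrapping it in a dense tensor.
if neural_network.sparse:
    result = result.todense()
result_tensor = torch.as_tensor(result, dtype=torch.float)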

def torch_nn_function_backward(
    ctx: Any,
    grad_output: Tensor,
) -> Tuple:
    """Backward pass computation."""
    input_data, weights = ctx.saved_tensors
    neural_network = ctx.neural_network

    if input_data.shape[-1] != neural_network.num_inputs:
        raise QiskitMachineLearningError(
            f"Invalid input dimension! Received {input_data.shape} and "
            f"expected input compatible to {neural_network.num_inputs}"
        )

    # (Implementation of backward pass here, producing ``input_grad`` and
    # ``weights_grad``; see the sketch below.)

    return input_grad, weights_grad, None, None
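
For orientation, a minimal sketch of what the elided body could look like. It assumes NeuralNetwork.backward(input_data, weights) returns the input and weight gradients as arrays of shape (batch, num_outputs, num_inputs) and (batch, num_outputs, num_weights), or None when a gradient is not required, so the chain rule reduces to an einsum contraction against grad_output:

# Sketch only; array shapes as assumed above.
input_grad, weights_grad = neural_network.backward(
    input_data.detach().cpu().numpy(), weights.detach().cpu().numpy()
)
if input_grad is not None:
    input_grad = torch.einsum(
        "ij,ijk->ik",
        grad_output.detach().cpu().to(torch.float),
        torch.as_tensor(input_grad, dtype=torch.float),
    ).to(input_data.device)
if weights_grad is not None:
    weights_grad = torch.einsum(
        "ij,ijk->k",
        grad_output.detach().cpu().to(torch.float),
        torch.as_tensor(weights_grad, dtype=torch.float),
    ).to(weights.device)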

Restructuring TorchConnector

Now, TorchConnector will use these functions for its forward and backward passes, so we can write the rest as:

from __future__ import annotations

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module


class TorchConnector(Module):
    def __init__(
        self,
        neural_network: NeuralNetwork,
        initial_weights: np.ndarray | Tensor | None = None,
        sparse: bool | None = None,
    ):
        super().__init__()
        self._neural_network = neural_network
        # Default to the network's own sparsity setting.
        if sparse is None:
            sparse = self._neural_network.sparse

        self._sparse = sparse

        if self._sparse and not self._neural_network.sparse:
            raise QiskitMachineLearningError(
                "TorchConnector configured as sparse, the network must be sparse as well"
            )

        # Register the trainable weights with the torch Module machinery.
        weight_param = torch.nn.Parameter(torch.zeros(neural_network.num_weights))
        self.register_parameter("weight", weight_param)
        self._weights = weight_param

        if initial_weights is None:
            self._weights.data.uniform_(-1, 1)
        else:
            self._weights.data = torch.tensor(initial_weights, dtype=torch.float)

    @property
    def neural_network(self) -> NeuralNetwork:
        return self._neural_network

    @property
    def weight(self) -> Tensor:
        return self._weights

    @property
    def sparse(self) -> bool | None:
        return self._sparse

    def forward(self, input_data: Tensor | None = None) -> Tensor:
        input_ = input_data if input_data is not None else torch.zeros(0)
        return torch_nn_function_forward(
            input_, self._weights, self._neural_network, self._sparse
        )

    def backward(self, ctx: Any, grad_output: Tensor) -> Tuple:
        return torch_nn_function_backward(ctx, grad_output)
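
To illustrate that the user-facing behaviour stays unchanged, a usage sketch assuming the refactored connector keeps today's call signature; EstimatorQNN is the existing class in qiskit_machine_learning.neural_networks, and the circuit and shapes are arbitrary examples:

from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap
from qiskit_machine_learning.neural_networks import EstimatorQNN

feature_map = ZZFeatureMap(2)
ansatz = RealAmplitudes(2, reps=1)
qnn = EstimatorQNN(
    circuit=feature_map.compose(ansatz),
    input_params=feature_map.parameters,
    weight_params=ansatz.parameters,
)

model = TorchConnector(qnn)
output = model(torch.rand(5, 2))  # batch of 5 samples with 2 features each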

Other points to note when updating

  • Replace Tuple from typing with the built-in generic syntax (PEP 585) for Python 3.12 compatibility, e.g. def torch_nn_function_backward(...) -> tuple[Tensor, Tensor, None, None]
  • Create dummy torch.<attribute> stand-ins when Torch is not installed, so that module-level references such as result_tensor = torch.as_tensor(result, dtype=torch.float) do not fail at import time. E.g.:
if _optionals.HAS_TORCH:
    import torch
else:
    class _TorchStub:  # minimal stand-in so module import does not fail
        float = float

        @staticmethod
        def as_tensor(*args, **kwargs):
            raise ImportError("This feature requires PyTorch.")

    torch = _TorchStub()
