qjit function execution occurs asynchronously #885

Open
jay-selby opened this issue Jun 28, 2024 · 7 comments

@jay-selby
Contributor

jay-selby commented Jun 28, 2024

Context

Thread-Level Speculation is a technique that has been used in various research projects to speed up general-purpose programs by speculatively executing code downstream of a function call. The idea here is to do this in a manner similar to JAX; see Asynchronous Dispatch in the JAX docs.

JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a DeviceArray value, which is a future, i.e., a value that will be produced in the future on an accelerator device but isn’t necessarily available immediately. Only when the value of the DeviceArray is queried is a blocking call generated.

Consider the following code snippet. Here, x, the device array returned as the result of evaluating f, is a future DeviceArray, and blocking only occurs when the user requests the value of x in Python.

>>> f = qjit(hybrid_func)
>>> x = f(0.54)
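The future semantics described above can be mimicked in plain Python. The sketch below uses a hypothetical `async_dispatch` decorator and `FutureValue` wrapper (neither is a Catalyst or JAX API): each call is submitted to a thread pool and returns immediately, and the caller only blocks when it converts the result to a number. `slow_square` is a stand-in for a compiled hybrid function, with `time.sleep` standing in for device execution time.

```python
import concurrent.futures as cf
import functools
import time

# Shared pool used by the hypothetical async_dispatch decorator.
_POOL = cf.ThreadPoolExecutor(max_workers=4)


class FutureValue:
    """Hypothetical wrapper around a Future; blocks only when the value is needed."""

    def __init__(self, future):
        self._future = future

    def ready(self):
        # Non-blocking check, analogous to asking whether the device is done.
        return self._future.done()

    def block_until_ready(self):
        # Block and return the underlying result.
        return self._future.result()

    def __float__(self):
        # Converting to a Python float is a "query" and therefore blocks.
        return float(self.block_until_ready())


def async_dispatch(fn):
    """Hypothetical decorator: run fn in the pool and return a FutureValue."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return FutureValue(_POOL.submit(fn, *args, **kwargs))

    return wrapper


@async_dispatch
def slow_square(x):
    time.sleep(0.2)  # stand-in for device execution time
    return x * x


start = time.perf_counter()
x = slow_square(0.54)                 # returns immediately with a FutureValue
dispatch_time = time.perf_counter() - start
value = float(x)                      # blocks here until the result exists
total_time = time.perf_counter() - start
print(f"dispatch {dispatch_time:.3f}s, total {total_time:.3f}s, value {value:.4f}")
```

The design choice is that dispatch and synchronization are separate steps: the call site never waits, and only operations that need the concrete value (here `float()`) pay the blocking cost.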

Questions:

The assumption here is that this will lead to speedups in the following situation (this assumption needs to be validated, but should be apparent in an interpreted language):

>>> f = qjit(hybrid_func)
>>> g = qjit(hybrid_func2)
>>> x = f(0.54)
>>> y = g(x)

That is, since x is evaluated asynchronously, Python is not blocked awaiting the result of f and can simply invoke g directly.

Requirements:

  1. Code downstream of a qjit'ted function is executed in parallel with the compiled function.
  2. Parallel evaluation halts once an instruction is reached that depends upon the result of the qjit'ted function.
  3. The downstream code is only executed once.
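Assuming the compiled call can simply be handed to a worker thread, the three requirements above map onto ordinary futures, as in the sketch below. `compiled_f` and `independent_work` are hypothetical stand-ins that sleep to simulate work. Note that the blocking point is placed by hand here; automatically detecting the first instruction that depends on the result (requirement 2) is the part this sketch does not attempt.

```python
import concurrent.futures as cf
import time

calls = {"downstream": 0}


def compiled_f(x):
    """Stand-in for a qjit-compiled function that takes a while to run."""
    time.sleep(0.3)
    return 2.0 * x


def independent_work():
    """Downstream code that does not depend on f's result."""
    calls["downstream"] += 1
    time.sleep(0.3)
    return "done"


with cf.ThreadPoolExecutor(max_workers=1) as pool:
    start = time.perf_counter()
    fut = pool.submit(compiled_f, 0.54)  # (1) f runs in the background
    status = independent_work()          # (1) downstream code runs concurrently, (3) once
    x = fut.result()                     # (2) first use of f's result: block here
    y = x + 1.0                          # dependent code runs only after the block
    elapsed = time.perf_counter() - start

# Both pieces of work take ~0.3 s each, but they overlap, so elapsed is ~0.3 s.
print(status, y, calls["downstream"], f"{elapsed:.2f}s")
```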

Installation Help

Refer to the Catalyst installation guide for how to install a source build of the project.

@mwasfy

mwasfy commented Jul 2, 2024

The attached PDF below explains the different test cases (scenarios) and includes some analysis and comments.

Coding_challenge_report.pdf

###Scenario A (without qjit)###

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
from functools import partial
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

# Decorator that submits each call to the executor and returns a Future
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        return executor.submit(f, *args, **kwargs)
    return wrap

#Serial function to capture the baseline performance without any parallelization 
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

@async_task
def parallel_func_1(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(1))

@async_task
def parallel_func_2(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)

#Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)
    start_serial = time.time()
    #Iterate over the serial function 10 times then average the execution time
    for i in range(10):
        serial_func(q_parameters, array_jax1)
    end_serial = time.time()
    serial_exe_time = (end_serial - start_serial)/10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################

    ######### Parallel region (start) ################
    start_parallel = time.time()

    # Calling two parallel functions
    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_func_2(q_parameters, array_jax2)

    # Blocking the execution until all tasks are done
    future1.result()
    future2.result()

    print("Async tasks finished? ", future1.done() and future2.done())
    end_parallel = time.time()
    parallel_exe_time = (end_parallel - start_parallel)
    print("Parallel exe time: ", parallel_exe_time)

    speedup = (serial_exe_time*2)/parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()
    pass

if __name__ == '__main__':
    main()
###Scenario A (qjit)###

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
from functools import partial
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

#Serial function to capture the baseline performance without any parallelization 
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

def parallel_func_1(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(1))

def parallel_func_2(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)

#Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():

    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)
    start_serial = time.time()
    #Iterate over the serial function 10 times then average the execution time
    for i in range(10):
        serial_func(q_parameters, array_jax1)
    end_serial = time.time()
    serial_exe_time = (end_serial - start_serial)/10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################

    ######### Parallel region (start) ################
    parallel_func_1_qjit = qjit(parallel_func_1)
    parallel_func_2_qjit = qjit(parallel_func_2)
    start_parallel = time.time()

    # Calling two parallel functions
    future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)
    future2 = exe.submit(parallel_func_2_qjit, q_parameters, array_jax2)

    # Blocking the execution until all tasks are done
    future1.result()
    future2.result()

    print("Async tasks finished? ", future1.done() and future2.done())
    end_parallel = time.time()
    parallel_exe_time = (end_parallel - start_parallel)
    print("Parallel exe time: ", parallel_exe_time)

    speedup = (serial_exe_time*2)/parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()
    pass

if __name__ == '__main__':
    main()
###Scenario B (without qjit)###

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

# Decorator that submits each call to the executor and returns a Future
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        return executor.submit(f, *args, **kwargs)
    return wrap

# Serial function to capture the baseline performance without any parallelization 
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

# Serial function to run in the parallel region. 
# This simulates some computation in the main thread.  
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

@async_task
def parallel_func_1(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(1))

@async_task
def parallel_func_2(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

@async_task
def parallel_matmul(c_arg):
    return jnp.matmul(c_arg,c_arg)

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)

# Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():

    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)
    start_serial = time.time()
    #Iterate over the serial function 10 times then average the execution time
    for i in range(10):
        serial_func(q_parameters, array_jax1)
    end_serial = time.time()
    serial_exe_time = (end_serial - start_serial)/10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################

    ######### Parallel region (start) ################
    start_parallel = time.time()

    # Running data independent parallel and serial functions
    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_matmul(array_jax3)
    serial_func2(q_parameters, array_jax2)

    # Block until the matmul result is ready, then pass the result (not the Future)
    matmul_res = future2.result()
    future3 = parallel_func_2(q_parameters, matmul_res)

    futures = [future1, future2, future3]
    
    # Wait for all tasks to finish
    cf.wait(futures)
    print("Async tasks finished? ", future1.done() and future2.done() and future3.done())
    end_parallel = time.time()
    parallel_exe_time = (end_parallel - start_parallel)
    print("Parallel exe time: ", parallel_exe_time)

    speedup = (serial_exe_time*4)/parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()
    pass

if __name__ == '__main__':
    main()
###Scenario B (qjit)###

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

# Serial function to capture the baseline performance without any parallelization 
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

# Serial function to run in the parallel region. 
# This simulates some computation in the main thread.  
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

@qml.qnode(qml.device("lightning.kokkos", wires=3))
def parallel_func_1(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(1))

def parallel_func_2(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

def parallel_matmul(c_arg):
    return jnp.matmul(c_arg,c_arg)

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)

# Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():

    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)
    start_serial = time.time()
    #Iterate over the serial function 10 times then average the execution time
    for i in range(10):
        serial_func(q_parameters, array_jax1)
    end_serial = time.time()
    serial_exe_time = (end_serial - start_serial)/10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################

    ######### Parallel region (start) ################
    parallel_func_1_qjit = qjit(parallel_func_1)
    parallel_func_2_qjit = qjit(parallel_func_2)
    parallel_matmul_qjit = qjit(parallel_matmul)
    start_parallel = time.time()

    # Running data independent parallel and serial functions
    future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)
    future2 = exe.submit(parallel_matmul_qjit, array_jax3)
    serial_func2(q_parameters, array_jax2)

    # Block until the matmul result is ready, then pass the result (not the Future)
    matmul_res = future2.result()
    future3 = exe.submit(parallel_func_2_qjit, q_parameters, matmul_res)

    futures = [future1, future2, future3]
    
    # Wait for all tasks to finish
    cf.wait(futures)
    print("Async tasks finished? ", future1.done() and future2.done() and future3.done())
    end_parallel = time.time()
    parallel_exe_time = (end_parallel - start_parallel)
    print("Parallel exe time: ", parallel_exe_time)

    speedup = (serial_exe_time*4)/parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()
    pass

if __name__ == '__main__':
    main()
###Scenario C (without qjit)###

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
import functools
from functools import wraps
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

# Decorator that submits each call to the executor and returns a Future
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        return executor.submit(f, *args, **kwargs)
    return wrap

# Serial function to capture the baseline performance without any parallelization 
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

# Serial function to run in the parallel region. 
# This simulates some computation in the main thread.  
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(2))

@async_task
def parallel_func_1(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(1))

@async_task
def parallel_func_2(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

@async_task
def parallel_func_3(q_arg, c_arg):
    matmul_res = jnp.matmul(c_arg,c_arg)
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0,1,2])
    return qml.expval(qml.PauliZ(0))

@async_task
def parallel_matmul(c_arg):
    return jnp.matmul(c_arg,c_arg)

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)

# Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():

    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)
    start_serial = time.time()
    #Iterate over the serial function 10 times then average the execution time
    for i in range(10):
        serial_func(q_parameters, array_jax1)
    end_serial = time.time()
    serial_exe_time = (end_serial - start_serial)/10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################

    ######### Parallel region (start) ################
    start_parallel = time.time()

    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_matmul(array_jax3)
    serial_func2(q_parameters, array_jax2)

    # add_done_callback passes the finished Future to the callback, so unwrap it
    # with fut.result() before launching parallel_func_2 on the data it needs
    future2.add_done_callback(lambda fut: parallel_func_2(q_parameters, fut.result()))
    
    # Other data independent parallel functions 
    parallel_matmul(array_jax1)
    parallel_matmul(array_jax1)
    parallel_matmul(array_jax1)
    parallel_func_3(q_parameters, array_jax1)
    
    # Shutting down to make sure all parallel tasks are finished
    exe.shutdown()
    end_parallel = time.time()
    parallel_exe_time = (end_parallel - start_parallel)
    print("Parallel exe time: ", parallel_exe_time)

    speedup = (serial_exe_time*8)/parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    pass

if __name__ == '__main__':
    main()

@dime10
Contributor

dime10 commented Jul 3, 2024

Hi @mwasfy, thank you for the submission!

I would be curious to ask you a few follow-up questions about this solution.

  1. Looking at Scenario A:

    The observed speedup is ∼1.4X, which could be attributed to the fact that I am running on one device and the overhead of context switching is not insignificant.

    How do you explain the speedup when using multi-threading despite Python's global interpreter lock?

  2. Looking at Scenario A (qjit):

    I noticed that the base case (serial) does not use qjit, whereas the parallel case does. Aren't we comparing different things then since one version is compiled and the other is not?

    To answer one of your comments: one reason we are seeing a slowdown in the qjit version is that it has to compile the function before running it, and the parallel functions do not receive a "warm-up run" like the serial version does, which would alleviate the issue.

  3. Regarding your first comment:

    For the non-jitted version, I created a decorative wrapper to conveniently call asynchronous parallel tasks. This wrapper didn’t work with qjitted functions returning an error that the future is not a valid Jax type.

    I think this error might pop up if the qjit decorator is placed on top of the async wrapper. Have you tried putting the qjit decorator "inside" the async one?

This is rather minor, but I'm also curious why the test function was copy-pasted a few times rather than reusing one definition. Was there a specific reason for this?
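Following the warm-up point in item 2 and the suggestion to reuse one definition, a fairer comparison times the same function in both paths and warms up both before measuring. The sketch below is a hypothetical harness, not code from the submission; `workload` uses `time.sleep` as a placeholder for a call that releases the GIL. With a qjit function, the first `fn(*args)` call would be where compilation happens, on the main thread.

```python
import concurrent.futures as cf
import statistics
import time


def workload(seconds):
    """Single shared definition used by both the serial and parallel paths.
    time.sleep is a placeholder for a call that releases the GIL."""
    time.sleep(seconds)
    return seconds


def time_serial(fn, args, n_tasks):
    start = time.perf_counter()
    for _ in range(n_tasks):
        fn(*args)
    return time.perf_counter() - start


def time_parallel(fn, args, n_tasks, pool):
    start = time.perf_counter()
    futures = [pool.submit(fn, *args) for _ in range(n_tasks)]
    for f in futures:
        f.result()  # re-raises any exception from the worker
    return time.perf_counter() - start


def benchmark(fn, args, n_tasks=2, repeats=5):
    """Return median(serial time) / median(parallel time) for n_tasks calls."""
    with cf.ThreadPoolExecutor(max_workers=n_tasks) as pool:
        # Warm-up: trigger any first-call compilation on the main thread
        # and start the worker threads before timing anything.
        fn(*args)
        time_parallel(fn, args, n_tasks, pool)
        serial = [time_serial(fn, args, n_tasks) for _ in range(repeats)]
        parallel = [time_parallel(fn, args, n_tasks, pool) for _ in range(repeats)]
    return statistics.median(serial) / statistics.median(parallel)


speedup = benchmark(workload, (0.05,), n_tasks=2, repeats=3)
print(f"speedup ~ {speedup:.2f}x")
```

Using `time.perf_counter` and the median of several repeats makes single-run noise less likely to dominate the reported speedup.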

@mwasfy

mwasfy commented Jul 4, 2024

Hi @dime10, thank you for your insightful comments. They cleared up many of my concerns.

  1. I think it's because, under the hood, numpy releases the GIL and runs its own machine code.
  2. I agree that both the serial and parallel versions should be qjitted. I tested that, and now the qjitted and non-qjitted versions show consistent speedups, which resolved the speedup issue. When I first tested this, I was using the @qjit decorator rather than f_qjit = qjit(f). With the @qjit decorator, calling the qjitted function multiple times always raised the error “All measurements must be returned in the order they are measured.” I don’t get that error when I use f_qjit = qjit(f).
  3. I was in fact using @qjit on top of @async_task; however, this ordering worked:
@async_task
@qjit
def parallel_func_1(q_arg, c_arg):
    ...  # body as in parallel_func_1 above

Copying and pasting functions: Mainly, I didn’t want to run into the “All measurements must be returned in the order they are measured.” error. Hence, I used different functions with slightly different syntax.
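The decorator ordering reported to work above follows from the fact that Python applies decorators bottom-up: `@async_task` over `@qjit` means `async_task(qjit(f))`, so compilation sees the plain function and only the compiled callable is submitted. The reversed order, `qjit(async_task(f))`, asks the compiler to trace a function whose return value is a `Future`, which is consistent with the "not a valid Jax type" error. In the sketch below, `toy_jit` is a hypothetical stand-in for a tracing compiler that only accepts numeric outputs; it is not Catalyst's `qjit`.

```python
import concurrent.futures as cf
import functools

pool = cf.ThreadPoolExecutor(max_workers=2)


def async_task(f, executor=pool):
    """Submit each call to the executor and return a Future."""
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        return executor.submit(f, *args, **kwargs)
    return wrap


def toy_jit(f):
    """Hypothetical stand-in for a tracing compiler: like a JIT, it only
    accepts functions whose outputs are numeric values."""
    @functools.wraps(f)
    def wrap(*args, **kwargs):
        out = f(*args, **kwargs)
        if not isinstance(out, (int, float)):
            raise TypeError(f"{type(out).__name__} is not a valid traced type")
        return out
    return wrap


@async_task   # applied last: dispatches the "compiled" function
@toy_jit      # applied first: sees the plain numeric function
def good(x):
    return x * 3.0


@toy_jit      # applied last: sees a function that returns a Future
@async_task
def bad(x):
    return x * 3.0


good_result = good(2.0).result()
try:
    bad(2.0)
    error_message = None
except TypeError as err:
    error_message = str(err)

print(good_result, error_message)
```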

@dime10
Contributor

dime10 commented Jul 4, 2024

Hi @mwasfy, thanks for your reply!

Copying and pasting functions: Mainly, I didn’t want to run into the “All measurements must be returned in the order they are measured.” error. Hence, I used different functions with slightly different syntax.

I believe this is an issue with the PennyLane library not being thread-safe, since it uses a global context to capture quantum instructions in a QNode.
I think we can get around this issue by making only the execution of a qjit-ed function asynchronous, but not its capture/compilation.
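One way to express "asynchronous execution, synchronous capture/compilation" is a wrapper that performs the first call under a lock on the caller's thread and only submits later calls to a pool. `AsyncExecution` and `hybrid` below are hypothetical names for illustration; with Catalyst, the wrapped function would be a `qjit`-ed callable, and the sketch assumes that concurrent executions of already-compiled code are themselves thread-safe.

```python
import concurrent.futures as cf
import threading
import time


class AsyncExecution:
    """Hypothetical wrapper: the first call (capture/compilation) runs
    synchronously under a lock in the calling thread; later calls are
    submitted to a thread pool. Every call returns a Future."""

    def __init__(self, fn, executor):
        self._fn = fn
        self._executor = executor
        self._lock = threading.Lock()
        self._warmed_up = False

    def __call__(self, *args, **kwargs):
        with self._lock:
            if not self._warmed_up:
                # Capture/compile on the caller's thread, never in parallel.
                result = self._fn(*args, **kwargs)
                self._warmed_up = True
                done = cf.Future()
                done.set_result(result)  # already-completed Future
                return done
        # Execution only: dispatched to the pool (assumed thread-safe).
        return self._executor.submit(self._fn, *args, **kwargs)


def hybrid(x):
    time.sleep(0.05)  # stand-in for compiled execution
    return x + 1


with cf.ThreadPoolExecutor(max_workers=2) as pool:
    f = AsyncExecution(hybrid, pool)
    first = f(1.0)                   # runs synchronously (the "compile" call)
    first_was_done = first.done()
    futures = [f(float(i)) for i in range(4)]  # dispatched to the pool
    results = [first.result()] + [fut.result() for fut in futures]

print(first_was_done, results)
```

Returning a completed `Future` from the first call keeps the calling code uniform: callers always call `.result()`, whether or not that particular call was the one that compiled.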

Regarding point 1:

I think it's because, under the hood, numpy releases the GIL and runs its own machine code.

This is actually a good point; numpy does appear to do that for many of its functions. The functions we are interested in are typically quantum functions, which will execute a quantum circuit on a device using the PennyLane library. Do you think this reasoning applies there as well?
For the QJIT case, Catalyst doesn't use numpy library code during execution, do you think multi-threading can help there?

Btw, I noticed that the parallel functions don't have the @qnode decorator while the serial function does, doesn't that mean we are comparing different things again?

@mwasfy

mwasfy commented Jul 4, 2024

Hi @dime10, thanks for getting back to me.

I think we can get around this issue by making only the execution of a qjit-ed function asynchronous, but not its capture/compilation.

Actually, I think that is how it was implemented. Please take a look at the following code snippet. Isn't that the workaround you are describing (unless I am not understanding it correctly)? BTW, no decorators were used here for qjit or async.

parallel_func_1_qjit = qjit(parallel_func_1)
future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)

The functions we are interested in are typically quantum functions, which will execute a quantum circuit on a device using the PennyLane library. Do you think this reasoning applies there as well?

We’ll be running on a separate device with its own machine code, so yes, I think the same reasoning applies. In fact, executing on a separate “quantum” device underscores the importance of such an approach for asynchronous tasks even more.

For the QJIT case, Catalyst doesn't use numpy library code during execution, do you think multi-threading can help there?

I think this is more a question of multi-threading vs multi-processing. Intuitively, I would say this is supposed to be a compute-intensive function, so the answer would be multi-processing. However, if the quantum device acts like an attached co-processor where we send inputs and wait for outputs, it could be considered an IO-bound computation from the main thread’s perspective, in which case multi-threading would be better. A specific answer to that question requires more internal knowledge of how Catalyst and PennyLane work, how the quantum device is connected to the host, and how they interact with each other.
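The IO-bound vs compute-bound distinction can be checked empirically with the standard library. In the sketch below, `time.sleep` stands in for a call that releases the GIL while waiting (as a device call might, which is the open question), and a pure-Python loop stands in for work that needs the interpreter. Whether a given Catalyst or PennyLane call behaves like the first or the second case is an assumption this sketch does not settle.

```python
import concurrent.futures as cf
import time


def wait_for_device(seconds):
    """IO-like task: time.sleep releases the GIL, like waiting on a device."""
    time.sleep(seconds)


def python_cpu_loop(n):
    """Pure-Python CPU work: under the GIL, only one thread at a time
    can execute this bytecode."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def run(fn, arg, workers):
    """Run `workers` copies of fn(arg) on a thread pool and time them."""
    start = time.perf_counter()
    with cf.ThreadPoolExecutor(max_workers=workers) as pool:
        list(pool.map(fn, [arg] * workers))
    return time.perf_counter() - start


io_single = run(wait_for_device, 0.2, 1)
io_four = run(wait_for_device, 0.2, 4)
cpu_single = run(python_cpu_loop, 2_000_000, 1)
cpu_four = run(python_cpu_loop, 2_000_000, 4)

# Sleeping threads overlap: four waits take about as long as one.
print(f"IO-like:   1 task {io_single:.2f}s, 4 tasks {io_four:.2f}s")
# On a GIL-enabled CPython build, the four loops largely serialize.
print(f"CPU-bound: 1 task {cpu_single:.2f}s, 4 tasks {cpu_four:.2f}s")
```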

Btw, I noticed that the parallel functions don't have the @qnode decorator while the serial function does, doesn't that mean we are comparing different things again?

As a matter of fact, I did test all the parallel functions with @qnode, and there was no difference in performance, so I left the decorator out when writing up the final test scenarios I submitted. lightning.qubit and lightning.kokkos didn’t show any difference in execution time for these circuits either. But then again, this may be because these are toy circuits, not circuits with enough depth for a meaningful difference in performance. That is why I opted to insert matmul operations to simulate long processing times (I didn’t want to use sleep functions). Another reason for inserting matmul was that I wanted to create a data dependency. To be honest, I am not sure how to create a data dependency between two quantum functions (maybe by using the evaluation of one circuit to initialize the qubits of another circuit, but I am not sure).

@dime10
Contributor

dime10 commented Jul 4, 2024

Actually, I think that is how it was implemented. Please take a look at the following code snippet. Isn't that the workaround you are describing (unless I am not understanding it correctly)? BTW, no decorators were used here for qjit or async.

Both the @qnode and @qjit decorators work in a way similar to "just-in-time compilation". That is, there is a big difference between the first run of the function and subsequent runs. For the QNode, the first run constructs a data structure of the quantum circuit by executing the Python code in the function; only then does it send this data structure to a device for execution. All subsequent runs then jump directly to the second part (the execution).
So if the first call is executed asynchronously, then the "capture phase" (the construction of the data structure) might happen in parallel, leading to the described issue with the global context.
For QJIT, the procedure is very similar. The first call will compile the function to binary code (again by running the Python function, capturing operations within, and compiling them), and then execute the compiled code. Subsequent calls directly execute the binary code.
This is why it's so important that all test cases (serial or parallel) use the same setup; otherwise we are measuring very different things.
For instance, the matrix multiplication inside a QNode (without QJIT) will only happen once, during the circuit construction. If the function is not decorated with @qnode then the multiplication happens every time the function is called.
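The capture-then-execute behavior described above can be illustrated with a toy model. `toy_capture_jit` below is hypothetical and is not how Catalyst or PennyLane are implemented: on the first call it runs the Python body once to record operations, and later calls only replay the recording. Any Python-level side effect in the body (the analogue of the matmul here) therefore happens once, during capture.

```python
import functools


def toy_capture_jit(fn):
    """Toy model of capture-then-execute (not Catalyst's implementation).

    First call: run the Python body once while recording operations.
    Later calls: replay the recorded operations without re-running Python.
    Note: the `if not recorded` check is not thread-safe; two threads making
    the first call at the same time could both capture, which mirrors the
    global-context problem with asynchronous first calls."""
    recorded = []  # the captured "program"

    @functools.wraps(fn)
    def wrapper(x):
        if not recorded:
            fn(x, recorded)      # capture phase: the Python body executes
        value = x
        for op in recorded:      # execution phase: replay recorded ops
            value = op(value)
        return value

    return wrapper


stats = {"body_runs": 0}


@toy_capture_jit
def program(x, tape):
    stats["body_runs"] += 1      # side effect only happens during capture
    tape.append(lambda v: v * 2)  # record "operations" instead of running them
    tape.append(lambda v: v + 1)


results = [program(x) for x in (1, 2, 3)]
print(results, stats["body_runs"])
```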

To be honest, I am not sure how to create a data dependency between two quantum functions (maybe by using the evaluation of one circuit to initialize the qubits of another circuit, but I am not sure).

Since quantum functions always have classical inputs (e.g. rotation angles) and outputs (e.g. expectation values), a dependency can also be created by using the output of one quantum function as the input of another:

import pennylane as qml

@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(phi):
    qml.RX(phi, wires=0)
    qml.CNOT(wires=[0,2])
    return qml.expval(qml.PauliZ(2))

phi_in = 0.7
output = circuit(phi_in)
_ = circuit(output)  # the second call depends on the first call's output
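This output-as-input dependency is exactly where requirement 2 says parallel evaluation should halt. The sketch below replaces the QNode with a classical stand-in so it runs without PennyLane: for RX(phi) on wire 0 followed by CNOT(0, 2), the expectation value of PauliZ on wire 2 is cos(phi). Independent calls are dispatched right away and overlap, while the dependent call has to wait for its input.

```python
import concurrent.futures as cf
import math
import time


def circuit(phi):
    """Classical stand-in for the QNode above: for RX(phi) on wire 0
    followed by CNOT(0, 2), <Z_2> = cos(phi)."""
    time.sleep(0.1)  # stand-in for device execution time
    return math.cos(phi)


with cf.ThreadPoolExecutor(max_workers=2) as pool:
    phi0 = 0.7
    first = pool.submit(circuit, phi0)      # independent call: dispatch now
    unrelated = pool.submit(circuit, 0.1)   # another independent call overlaps
    # Dependent call: its input only exists once `first` has finished.
    second = pool.submit(circuit, first.result())
    outputs = (first.result(), second.result(), unrelated.result())

print(outputs)
```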

@mwasfy

mwasfy commented Jul 4, 2024

Thanks for the clarification. From what you are describing, I guess Catalyst uses lazy compilation. Does it support eager compilation as well?
That said, I tested the setup you suggested in the code below, running both the serial and parallel functions once before starting the profiling. The parallel async version is actually slower, which suggests that the GIL does in fact get in the way.

import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
import concurrent.futures as cf

# Create a global thread pool (by default, max_workers = min(32, os.cpu_count() + 4) on Python 3.8+)
exe = cf.ThreadPoolExecutor()

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(0))

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def parallel_func(q_arg, c_arg):
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0,2])
    matmul_res = jnp.matmul(c_arg,c_arg)
    return qml.expval(qml.PauliZ(1))

# 1000x1000 matrix instantiation and initialization with random numbers for testing
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
#Parameter list for quantum gates
q_parameters = jnp.array([0.011, 0.012, 0.13])

def main():
    ######### Serial region (start) ################
    # First call triggers compilation and execution
    start_compile_serial = time.time()
    serial_func(q_parameters, array_jax1)
    end_compile_serial = time.time()
    serial_compile = end_compile_serial-start_compile_serial

    print("First serial call (compile): ", serial_compile)

    # Second call to invoke execution only
    start_exe_serial = time.time()
    serial_func(q_parameters, array_jax1)
    end_exe_serial = time.time()
    serial_exe = end_exe_serial-start_exe_serial

    print("Second serial call (exe): ", serial_exe)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    # First call to invoke compilation
    start_compile_parallel = time.time()
    f1 = exe.submit(parallel_func, q_parameters, array_jax1)
    f1.result()
    end_compile_parallel = time.time()
    parallel_compile = end_compile_parallel-start_compile_parallel

    print("First parallel call (compile): ", parallel_compile)

    # Second call to invoke execution only
    start_exe_parallel = time.time()
    f2 = exe.submit(parallel_func, q_parameters, array_jax1)
    f3 = exe.submit(parallel_func, q_parameters, array_jax1)
    f2.result()
    f3.result()
    end_exe_parallel = time.time()
    parallel_exe = end_exe_parallel-start_exe_parallel

    print("Second parallel call (exe): ", parallel_exe)

    ######### Parallel region (end) ################
    exe.shutdown()
    speedup = (serial_exe*2)/parallel_exe
    print("Speedup = ", speedup)
    pass

if __name__ == '__main__':
    main()

One more comment about the code I just shared: sometimes it runs perfectly and sometimes it fails with the error: Error in Catalyst Runtime: Invalid use of the global driver before initialization

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants