Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compilation of Jacobian with for loops and vmap #332

Open
wants to merge 38 commits into
base: main
Choose a base branch
from

Conversation

rmoyard
Copy link
Contributor

@rmoyard rmoyard commented Oct 25, 2023

Context:
Some gradients/jacobians do not compile because the insertion point in the lowering is not correctly set.

Description of the Change:
The insertion is the call op and not the body of the function.

Benefits:
We can compile and run more derivatives of vmap and for loop.

Drawback
It compiles gradient acting on vmap but returns wrong results, see xfailed test.

[sc-59758]

@maliasadi maliasadi added the compiler Pull requests that update the compiler label Dec 6, 2023
@rmoyard rmoyard marked this pull request as ready for review March 19, 2024 14:43
Copy link

Hello. You may have forgotten to update the changelog!
Please edit doc/changelog.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@rmoyard rmoyard changed the title Fix compilation of Jacobian with for loops Fix compilation of Jacobian with for loops and vmap Mar 19, 2024
frontend/catalyst/jit.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

frontend/test/pytest/test_vmap.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_vmap.py Outdated Show resolved Hide resolved
@dime10
Copy link
Collaborator

dime10 commented May 7, 2024

Testing this fix on the following example results in very long (infinite?) compile or runtime:

import pennylane as qml
import jax
from jax import numpy as jnp
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev_name = "lightning.qubit"
dev = qml.device(dev_name, wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    for i in range(n_wires):
        qml.RY(data[i], wires=i)

    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

# try broadcasting
jit_circuit = catalyst.qjit(catalyst.vmap(circuit, in_axes = (1, None)))

def my_model(data, weights, bias):
    # works with default.qubit
    if dev_name == "default.qubit":
        return circuit(data, weights) + bias

    # works with lightning.qubit, not broadcasted
    # return jnp.array([circuit(jnp.array(d), weights) for d in data.T])

    # only works with loss_fn, fails at grad step
    return jit_circuit(data, weights) + bias

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss


weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))  # runs for > 20 minutes

@rmoyard
Copy link
Contributor Author

rmoyard commented Jul 4, 2024

[sc-49763]

@rmoyard
Copy link
Contributor Author

rmoyard commented Jul 23, 2024

Closes #294

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
compiler Pull requests that update the compiler
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants