Lower value_and_grad to MLIR
rauletorresc committed Feb 29, 2024
1 parent f134bc6 commit 558e65a
Showing 11 changed files with 713 additions and 102 deletions.
2 changes: 2 additions & 0 deletions frontend/catalyst/__init__.py
@@ -79,6 +79,7 @@
jvp,
measure,
mitigate_with_zne,
value_and_grad,
vjp,
while_loop,
)
@@ -173,6 +174,7 @@
"ctrl",
"measure",
"grad",
"value_and_grad",
"jacobian",
"vjp",
"jvp",
69 changes: 68 additions & 1 deletion frontend/catalyst/jax_primitives.py
@@ -33,7 +33,7 @@
from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp
from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp
from mlir_quantum.dialects.catalyst import PrintOp
from mlir_quantum.dialects.gradient import GradOp, JVPOp, VJPOp
from mlir_quantum.dialects.gradient import GradOp, JVPOp, ValueAndGradOp, VJPOp
from mlir_quantum.dialects.mitigation import ZneOp
from mlir_quantum.dialects.quantum import (
AdjointOp,
@@ -200,6 +200,8 @@ def _obs_lowering(aval):
for_p.multiple_results = True
grad_p = core.Primitive("grad")
grad_p.multiple_results = True
value_and_grad_p = core.Primitive("value_and_grad")
value_and_grad_p.multiple_results = True
func_p = core.CallPrimitive("func")
func_p.multiple_results = True
jvp_p = core.Primitive("jvp")
@@ -381,6 +383,70 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
).results


@value_and_grad_p.def_impl
def _value_and_grad_def_impl(ctx, *args, jaxpr, fn, grad_params): # pragma: no cover
raise NotImplementedError()


@value_and_grad_p.def_abstract_eval
def _value_and_grad_abstract(*args, jaxpr, fn, grad_params):
"""This function is called with abstract arguments for tracing.
Note: argument names must match those of `_value_and_grad_lowering`."""
return jaxpr.out_avals + jaxpr.out_avals


def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
"""Lowering function to value and gradient.
Args:
ctx: the MLIR context
args: the points in the function in which we are to calculate the derivative
jaxpr: the jaxpr representation of the value and grad op
fn(Grad): the function to be differentiated
method: the method used for differentiation
h: the difference for finite difference. May be None when fn is not finite difference.
argnum: argument indices which define over which arguments to
differentiate.
"""
method, h, argnum = grad_params.method, grad_params.h, grad_params.argnum
mlir_ctx = ctx.module_context.context
finiteDiffParam = None
if h:
f64 = ir.F64Type.get(mlir_ctx)
finiteDiffParam = ir.FloatAttr.get(f64, h)
offset = len(jaxpr.consts)
new_argnum = [num + offset for num in argnum]
argnum_numpy = np.array(new_argnum)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)

_func_lowering(ctx, *args, call_jaxpr=jaxpr.eqns[0].params["call_jaxpr"], fn=fn, call=False)
symbol_name = mlir_fn_cache[fn]
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)

num_flat_output_types = len(flat_output_types)
assert (
num_flat_output_types % 2 == 0
), f"The total number of result tensors is expected to be even, not {num_flat_output_types}"

# ``ir.DenseElementsAttr.get()`` constructs a dense elements attribute from an array of
# element values. This doesn't support ``jaxlib.xla_extension.Array``, so we have to cast
# such constants to numpy array types.
constants = [
ConstantOp(ir.DenseElementsAttr.get(np.asarray(const))).results for const in jaxpr.consts
]
args_and_consts = constants + list(args)

return ValueAndGradOp(
flat_output_types[: num_flat_output_types // 2],
flat_output_types[num_flat_output_types // 2 :],
ir.StringAttr.get(method),
ir.FlatSymbolRefAttr.get(symbol_name),
mlir.flatten_lowering_ir_args(args_and_consts),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results


#
# vjp/jvp
#
@@ -1534,6 +1600,7 @@ def _adjoint_lowering(
mlir.register_lowering(while_p, _while_loop_lowering)
mlir.register_lowering(for_p, _for_loop_lowering)
mlir.register_lowering(grad_p, _grad_lowering)
mlir.register_lowering(value_and_grad_p, _value_and_grad_lowering)
mlir.register_lowering(func_p, _func_lowering)
mlir.register_lowering(jvp_p, _jvp_lowering)
mlir.register_lowering(vjp_p, _vjp_lowering)
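For orientation, here is a small, self-contained sketch (illustrative only, not part of this commit) of the result convention the new primitive establishes: `_value_and_grad_abstract` doubles the output avals, so the flat result list carries the function values in the first half and the gradients in the second half, mirroring the `$calleeResults`/`$gradients` result groups of `gradient.value_and_grad`. The helper name below is hypothetical.

```python
# Illustrative sketch of the values-then-gradients result convention.
# `split_value_and_grad` is a hypothetical helper, not part of this commit.
from typing import List, Sequence, Tuple


def split_value_and_grad(flat_results: Sequence[float], num_outputs: int) -> Tuple[List[float], List[float]]:
    """Split a flat result list into (values, gradients)."""
    assert len(flat_results) == 2 * num_outputs, "expected a doubled result list"
    return list(flat_results[:num_outputs]), list(flat_results[num_outputs:])


# Example: one output, f(x) = x**2 evaluated at x = 2.0 -> value 4.0, gradient 4.0
values, gradients = split_value_and_grad([4.0, 4.0], num_outputs=1)
print(values, gradients)  # [4.0] [4.0]
```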
10 changes: 7 additions & 3 deletions frontend/catalyst/pennylane_extensions.py
@@ -71,6 +71,7 @@
jvp_p,
probs_p,
qmeasure_p,
value_and_grad_p,
vjp_p,
while_p,
zne_p,
@@ -263,7 +264,7 @@ def traverse_children(jaxpr):
primitive = eqn.primitive
if primitive is func_p:
child_jaxpr = eqn.params.get("call_jaxpr")
elif primitive is grad_p:
elif primitive in [grad_p, value_and_grad_p]:
child_jaxpr = eqn.params.get("jaxpr")
else:
continue
@@ -283,7 +284,7 @@ def traverse_children(jaxpr):
def _check_primitive_is_differentiable(primitive, method):
"""Verify restriction on primitives in the call graph of a Grad operation."""

if primitive is grad_p and method != "fd":
if primitive in [grad_p, value_and_grad_p] and method != "fd":
raise DifferentiableCompileError(
"Only finite difference can compute higher order derivatives."
)
@@ -404,8 +405,11 @@ def __call__(self, *args, **kwargs):

args_data, _ = tree_flatten(args)

# Choose between value_and_grad_p and plain grad_p
grad_func = value_and_grad_p if grad_params.with_value else grad_p

# It always returns a list, as required by Catalyst control flow
results = grad_p.bind(*args_data, jaxpr=jaxpr, fn=fn, grad_params=grad_params)
results = grad_func.bind(*args_data, jaxpr=jaxpr, fn=fn, grad_params=grad_params)
results = _unflatten_derivatives(
results, out_tree, self.grad_params.argnum, len(jaxpr.out_avals)
)
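For context, a minimal usage sketch of the frontend entry point wired up above, assuming `catalyst.value_and_grad` follows the same calling convention as `catalyst.grad` (the full signature is not shown in this diff):

```python
# Hypothetical usage sketch -- assumes value_and_grad mirrors catalyst.grad's
# calling convention; the exact API surface is not shown in this diff.
from catalyst import qjit, value_and_grad


def square(x: float):
    return x * x


@qjit
def workflow(x: float):
    # Returns the function value and its derivative at x.
    return value_and_grad(square)(x)


print(workflow(2.0))  # expected: value 4.0 and gradient 4.0
```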
48 changes: 48 additions & 0 deletions mlir/include/Gradient/IR/GradientOps.td
@@ -63,6 +63,54 @@ def GradOp : Gradient_Op<"grad", [DeclareOpInterfaceMethods<CallOpInterface>,
}];
}

def ValueAndGradOp : Gradient_Op<"value_and_grad", [
SameVariadicResultSize,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Compute the value and gradient of a function.";
let description = [{
The `gradient.value_and_grad` operation computes both the value and the gradient
of a function using the specified differentiation method.

This operation acts much like the `func.call` operation, taking a
symbol reference and arguments to the original function as input.
The results of the function and the gradient of the function are returned.

Example:

```mlir
func.func @foo(%arg0: f64) -> f64 {
%res = arith.mulf %arg0, %arg0 : f64
func.return %res : f64
}

%0 = arith.constant 2.0 : f64
%val, %grad = gradient.value_and_grad "fd" @foo(%0) : (f64) -> (f64, f64)
```
}];

let arguments = (ins
StrAttr:$method,
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$operands,
OptionalAttr<AnyIntElementsAttr>:$diffArgIndices,
OptionalAttr<Builtin_FloatAttr>:$finiteDiffParam
);
let results = (outs
Variadic<AnyTypeOf<[AnyFloat, RankedTensorOf<[AnyFloat]>]>>:$calleeResults,
Variadic<AnyTypeOf<[AnyFloat, RankedTensorOf<[AnyFloat]>]>>:$gradients
);

let assemblyFormat = [{
$method $callee `(` $operands `)`
attr-dict `:` functional-type($operands, results)
}];


let hasVerifier = 1;
}

def AdjointOp : Gradient_Op<"adjoint", [AttrSizedOperandSegments]> {
let summary = "Perform quantum AD using the adjoint method on a device.";

3 changes: 3 additions & 0 deletions mlir/include/Gradient/Utils/GradientShape.h
@@ -27,6 +27,9 @@ bool isDifferentiable(mlir::Type type);
std::vector<mlir::Type> computeResultTypes(mlir::func::FuncOp callee,
const std::vector<size_t> &diffArgIndices);

std::vector<mlir::Type> computeValueAndGradTypes(mlir::func::FuncOp callee,
const std::vector<size_t> &diffArgIndices);

std::vector<mlir::Type> computeQGradTypes(mlir::func::FuncOp callee);

std::vector<mlir::Type> computeBackpropTypes(mlir::func::FuncOp callee,
114 changes: 114 additions & 0 deletions mlir/lib/Gradient/IR/GradientOps.cpp
@@ -101,6 +101,43 @@ LogicalResult verifyGradOutputs(OpState *op_state, func::FuncOp fn,
return success();
}

// Value-and-gradient output checker
LogicalResult verifyValueAndGradOutputs(OpState *op_state, func::FuncOp fn,
const std::vector<size_t> &diff_arg_indices,
TypeRange result_types)
{
const std::vector<Type> &expectedTypes = computeValueAndGradTypes(fn, diff_arg_indices);

// Verify that the number of results matches the expected value-and-gradient shape.
// The results should contain the callee results followed by one set of gradient
// results (equal in size to the number of function results) for each differentiable argument.
if (result_types.size() != expectedTypes.size())
return op_state->emitOpError(
"incorrect number of results in the value and gradient of the callee, ")
<< "expected " << expectedTypes.size() << " results "
<< "but got " << result_types.size();

// Verify the shape of each result. The numeric type should match the numeric type
// of the corresponding function result. The shape is given by grouping the differentiated
// argument shape with the corresponding function result shape.
TypeRange gradResultTypes = result_types;
for (unsigned i = 0; i < expectedTypes.size(); i++) {
op_state->emitRemark("Expected: ") << i << " = " << expectedTypes[i] << '\n';
}

for (unsigned i = 0; i < gradResultTypes.size(); i++) {
op_state->emitRemark("Obtained: ") << i << " = " << gradResultTypes[i] << '\n';
}

for (unsigned i = 0; i < expectedTypes.size(); i++) {
if (gradResultTypes[i] != expectedTypes[i])
return op_state->emitOpError("invalid result type: grad result at position ")
<< i << " must be " << expectedTypes[i] << " but got " << gradResultTypes[i];
}

return success();
}

//===----------------------------------------------------------------------===//
// GradOp, CallOpInterface
//===----------------------------------------------------------------------===//
@@ -153,6 +190,83 @@ LogicalResult GradOp::verify()

MutableOperandRange GradOp::getArgOperandsMutable() { return getOperandsMutable(); }

//===----------------------------------------------------------------------===//
// ValueAndGradOp, CallOpInterface
//===----------------------------------------------------------------------===//

CallInterfaceCallable ValueAndGradOp::getCallableForCallee() { return getCalleeAttr(); }

void ValueAndGradOp::setCalleeFromCallable(CallInterfaceCallable callee)
{
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
};

Operation::operand_range ValueAndGradOp::getArgOperands() { return getOperands(); }

//===----------------------------------------------------------------------===//
// ValueAndGradOp, SymbolUserOpInterface
//===----------------------------------------------------------------------===//

LogicalResult ValueAndGradOp::verifySymbolUses(SymbolTableCollection &symbolTable)
{
// Check that the callee attribute refers to a valid function.
auto callee = ({
auto callee = this->getCalleeAttr();
func::FuncOp fn =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(this->getOperation(), callee);
if (!fn)
return this->emitOpError("invalid function name specified: ") << callee;
fn;
});

auto diffArgIndices = computeDiffArgIndices(this->getDiffArgIndices());
auto r1 = ::verifyGradInputs(this, callee, this->getArgOperands(), diffArgIndices);
if (r1.failed()) {
return r1;
}

if (this->getNumResults() != 2 * callee.getFunctionType().getNumResults()) {
return this->emitOpError(
"invalid number of results: must be twice the number of callee results")
<< " which is " << 2 * callee.getFunctionType().getNumResults() << " but got "
<< this->getNumResults();
}

std::vector<Type> gradient_types;
{
for (auto s : this->getCalleeResults()) {
gradient_types.push_back(s.getType());
}
}

for (size_t i = 0; i < callee.getFunctionType().getNumResults(); i++) {
auto calleeRtype = callee.getFunctionType().getResult(i);
auto gradientRtype = gradient_types[i];
if (calleeRtype != gradientRtype) {
return this->emitOpError("result types do not match")
<< " result " << i << " should match "
<< " was expected to match the type " << gradientRtype << " but got "
<< calleeRtype;
}
}

return success();
}

//===----------------------------------------------------------------------===//
// ValueAndGradOp Extra methods
//===----------------------------------------------------------------------===//

LogicalResult ValueAndGradOp::verify()
{
StringRef method = this->getMethod();
if (method != "fd" && method != "auto")
return emitOpError("got invalid differentiation method: ") << method;
return success();
}

MutableOperandRange ValueAndGradOp::getArgOperandsMutable() { return getOperandsMutable(); }

//===----------------------------------------------------------------------===//
// JVPOp, CallOpInterface
//===----------------------------------------------------------------------===//
