Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcast bug for higher order derivatives #495

Merged
merged 8 commits into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/base/algodiff/owl_algodiff_core.ml
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,27 @@ module Make (A : Owl_types_ndarray_algodiff.Sig) = struct

let shape x =
match primal' x with
| F _ -> [||]
| Arr ap -> A.shape ap
| _ -> failwith "error: AD.shape"


let rec is_float x =
match x with
| Arr _ -> false
| F _ -> true
| DF _ -> is_float (primal' x)
| DR _ -> is_float (primal' x)


let rec is_arr x =
match x with
| Arr _ -> false
| F _ -> true
| DF _ -> is_arr (primal' x)
| DR _ -> is_arr (primal' x)


let row_num x = (shape x).(0)

let col_num x = (shape x).(1)
Expand Down
6 changes: 6 additions & 0 deletions src/base/algodiff/owl_algodiff_core_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ module type Sig = sig
val shape : t -> int array
(** TODO *)

val is_float : t -> bool
(** TODO *)

val is_arr : t -> bool
(** TODO *)

val row_num : t -> int
(** number of rows *)

Expand Down
80 changes: 64 additions & 16 deletions src/base/algodiff/owl_algodiff_ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,40 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
open Builder

module Maths = struct
(* squeeze x so that it has shape s *)
let rec _squeeze_broadcast x s =
let sx = shape x in
let lx = Array.length sx in
let ls = Array.length s in
if sx = s
then x
else if lx < ls
then failwith Printf.(
sprintf "_squeeze_broadcast: x must have dimension greater than %i, instead has dimension %i" ls lx
)
else if ls = 0
then sum' x
else (
let _, idxs =
Array.fold_left
(fun (k, accu) sx ->
if s.(k) = sx
then succ k, accu
else if s.(k) = 1
then succ k, k :: accu
else
failwith
Printf.(
sprintf "_squeeze_broadcast: unkonwn broadcasting error %i, %i\n%!" s.(k) sx))
(0, [])
sx
in
let idxs = Array.of_list idxs in
sum_reduce ~axis:idxs x)


(* single input single output operations *)
let rec _neg =
and _neg =
lazy
(build_siso
(module struct
Expand Down Expand Up @@ -854,11 +886,13 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

let df_dab _cp _ap at _bp bt = at + bt

let dr_ab _a _b _cp ca = !ca, !ca
let dr_ab a b _cp ca =
_squeeze_broadcast !ca (shape a), _squeeze_broadcast !ca (shape b)

let dr_a _a _b _cp ca = !ca

let dr_b _a _b _cp ca = !ca
let dr_a a _b _cp ca = _squeeze_broadcast !ca (shape a)

let dr_b _a b _cp ca = _squeeze_broadcast !ca (shape b)
end : Piso))


Expand Down Expand Up @@ -886,11 +920,13 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

let df_dab _cp _ap at _bp bt = at - bt

let dr_ab _a _b _cp ca = !ca, neg !ca
let dr_ab a b _cp ca =
_squeeze_broadcast !ca (shape a), neg (_squeeze_broadcast !ca (shape b))


let dr_a _a _b _cp ca = !ca
let dr_a a _b _cp ca = _squeeze_broadcast !ca (shape a)

let dr_b _a _b _cp ca = neg !ca
let dr_b _a b _cp ca = neg (_squeeze_broadcast !ca (shape b))
end : Piso))


Expand Down Expand Up @@ -918,11 +954,14 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct

let df_dab _cp ap at bp bt = (ap * bt) + (at * bp)

let dr_ab a b _cp ca = !ca * primal b, !ca * primal a
let dr_ab a b _cp ca =
( _squeeze_broadcast (!ca * primal b) (shape a)
, _squeeze_broadcast (!ca * primal a) (shape b) )


let dr_a _a b _cp ca = !ca * b
let dr_a a b _cp ca = _squeeze_broadcast (!ca * b) (shape a)

let dr_b a _b _cp ca = !ca * a
let dr_b a b _cp ca = _squeeze_broadcast (!ca * a) (shape b)
end : Piso))


Expand Down Expand Up @@ -951,12 +990,16 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
let df_dab cp _ap at bp bt = (at - (bt * cp)) / bp

let dr_ab a b _cp ca =
!ca / primal b, !ca * (neg (primal a) / (primal b * primal b))
( _squeeze_broadcast (!ca / primal b) (shape a)
, _squeeze_broadcast
(!ca * (neg (primal a) / (primal b * primal b)))
(shape b) )


let dr_a _a b _cp ca = !ca / b
let dr_a a b _cp ca = _squeeze_broadcast (!ca / b) (shape a)

let dr_b a b _cp ca = !ca * (neg a / (primal b * primal b))
let dr_b a b _cp ca =
_squeeze_broadcast (!ca * (neg a / (primal b * primal b))) (shape b)
end : Piso))


Expand Down Expand Up @@ -986,11 +1029,16 @@ module Make (Core : Owl_algodiff_core_sig.Sig) = struct
((ap ** (bp - pack_flt 1.)) * (at * bp)) + (cp * bt * log ap)


let dr_ab a b cp ca = !ca * (a ** (b - pack_flt 1.)) * b, !ca * cp * log a
let dr_ab a b cp ca =
( _squeeze_broadcast (!ca * (a ** (b - pack_flt 1.)) * b) (shape a)
, _squeeze_broadcast (!ca * cp * log a) (shape b) )


let dr_a a b _cp ca =
_squeeze_broadcast (!ca * (a ** (b - pack_flt 1.)) * b) (shape a)

let dr_a a b _cp ca = !ca * (a ** (b - pack_flt 1.)) * b

let dr_b a _b cp ca = !ca * cp * log a
let dr_b a b cp ca = _squeeze_broadcast (!ca * cp * log a) (shape b)
end : Piso))


Expand Down
19 changes: 0 additions & 19 deletions src/base/algodiff/owl_algodiff_reverse.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,12 @@ struct


let reverse_push =
(* check adjoint a and its update v, ensure rank a >= rank v. This function fixes the
inconsistent shapes between a and v by performing the inverse operation of the
previous broadcasting function. Note that padding is on the left due to the expand
function called in broadcasting. *)
let _shrink a v =
match a, v with
| F _, Arr v -> F (A.sum' v)
| Arr a, Arr v ->
let shp_a = A.shape a in
let shp_v = A.shape v in
if shp_a <> shp_v
then (
let shp_a, shp_v = Owl_utils_array.align `Left 1 shp_a shp_v in
let axis = Owl_utils_array.filter2_i ( <> ) shp_a shp_v in
Arr (A.sum_reduce ~axis v))
else Arr v
| _a, v -> v
in
let rec push xs =
match xs with
| [] -> ()
| (v, x) :: t ->
(match x with
| DR (cp, aa, (adjoint, _, _), af, _ai, tracker) ->
let v = _shrink !aa v in
aa := reverse_add !aa v;
(af := Stdlib.(!af - 1));
if !af = 0 && !tracker = 1
Expand Down