Skip to content

Commit

Permalink
Add implot bindings.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 16, 2019
1 parent 87692af commit 2eba25b
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 0 deletions.
25 changes: 25 additions & 0 deletions examples/fig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,28 @@ let () =
Ax.scatter ax ~c:Blue ~marker:'*' xys3;
Fig.suptitle fig "...scatter...";
Mpl.show ()

let () =
let fig, ax1, ax2 = Fig.create_with_two_axes `horizontal in
let data1 =
Array.init 128 ~f:(fun i ->
Array.init 128 ~f:(fun j ->
i + j))
|> Imshow_data.scalar Imshow_data.int
in
Ax.grid ax1 false;
Ax.imshow ax1 data1;
Ax.set_title ax1 "scalar - default cmap";
let data2 =
Array.init 128 ~f:(fun i ->
Array.init 128 ~f:(fun j ->
let i = Float.of_int i *. 0.1 in
let j = Float.of_int j *. 0.1 in
Float.(abs (cos i), abs (sin j), abs (cos (0.1 *. (i+.j))))))
|> Imshow_data.rgb Imshow_data.float
in
Ax.grid ax2 false;
Ax.imshow ax2 data2;
Ax.set_title ax2 "rgb";
Fig.suptitle fig "imshow";
Mpl.show ()
1 change: 1 addition & 0 deletions src/matplotlib/fig_ax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ module Ax = struct
let plot = Mpl.plot
let hist = Mpl.hist
let scatter = Mpl.scatter
let imshow = Mpl.imshow

module Expert = struct
let to_pyobject = Fn.id
Expand Down
6 changes: 6 additions & 0 deletions src/matplotlib/fig_ax.mli
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ module Ax : sig
-> (float * float) array
-> unit

val imshow
: t
-> ?cmap:string
-> Mpl.Imshow_data.t
-> unit

module Expert : sig
val to_pyobject : t -> Py.Object.t
end
Expand Down
1 change: 1 addition & 0 deletions src/matplotlib/matplotlib.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module Ax = Fig_ax.Ax
module Fig = Fig_ax.Fig
module Imshow_data = Mpl.Imshow_data
module Mpl = Mpl.Public
module Pyplot = Pyplot
51 changes: 51 additions & 0 deletions src/matplotlib/mpl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,54 @@ let scatter p ?s ?c ?marker ?alpha ?linewidths xys =
let xs = Py.List.of_array_map (fun (x, _) -> Py.Float.of_float x) xys in
let ys = Py.List.of_array_map (fun (_, y) -> Py.Float.of_float y) xys in
ignore (Py.Module.get_function_with_keywords p "scatter" [| xs; ys |] keywords)

module Imshow_data = struct
type 'a data =
| Scalar of 'a array array
| Rgb of ('a * 'a * 'a) array array
| Rgba of ('a * 'a * 'a * 'a) array array

type 'a typ_ =
| Int : int typ_
| Float : float typ_

let int = Int
let float = Float

type t = P : ('a data * 'a typ_) -> t

let scalar typ_ data = P (Scalar data, typ_)
let rgb typ_ data = P (Rgb data, typ_)
let rgba typ_ data = P (Rgba data, typ_)

let to_pyobject (type a) (P (data, typ_)) =
let to_pyobject ~scalar_to_pyobject =
match data with
| Scalar data ->
Py.List.of_array_map (Py.List.of_array_map scalar_to_pyobject) data
| Rgb data ->
let rgb_to_pyobject (r, g, b) =
(scalar_to_pyobject r, scalar_to_pyobject g, scalar_to_pyobject b)
|> Py.Tuple.of_tuple3
in
Py.List.of_array_map (Py.List.of_array_map rgb_to_pyobject) data
| Rgba data ->
let rgba_to_pyobject (r, g, b, a) =
(scalar_to_pyobject r, scalar_to_pyobject g, scalar_to_pyobject b, scalar_to_pyobject a)
|> Py.Tuple.of_tuple4
in
Py.List.of_array_map (Py.List.of_array_map rgba_to_pyobject) data
in
match typ_ with
| Int -> to_pyobject ~scalar_to_pyobject:Py.Int.of_int
| Float -> to_pyobject ~scalar_to_pyobject:Py.Float.of_float
end

let imshow p ?cmap data =
let keywords =
List.filter_opt
[ Option.map cmap ~f:(fun c -> "cmap", Py.String.of_string c)
]
in
let data = Imshow_data.to_pyobject data in
ignore (Py.Module.get_function_with_keywords p "imshow" [| data |] keywords)
18 changes: 18 additions & 0 deletions src/matplotlib/mpl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,21 @@ val scatter
-> ?linewidths:float
-> (float * float) array
-> unit

module Imshow_data : sig
type t

type 'a typ_
val int : int typ_
val float : float typ_

val scalar : 'a typ_ -> 'a array array -> t
val rgb : 'a typ_ -> ('a * 'a * 'a) array array -> t
val rgba : 'a typ_ -> ('a * 'a * 'a * 'a) array array -> t
end

val imshow
: Py.Object.t
-> ?cmap:string
-> Imshow_data.t
-> unit
4 changes: 4 additions & 0 deletions src/matplotlib/pyplot.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ let hist ?label ?color ?bins ?orientation ?histtype ?xs ys =
let scatter ?s ?c ?marker ?alpha ?linewidths xys =
let p = Mpl.pyplot_module () in
Mpl.scatter p ?s ?c ?marker ?alpha ?linewidths xys

let imshow ?cmap xys =
let p = Mpl.pyplot_module () in
Mpl.imshow p ?cmap xys
5 changes: 5 additions & 0 deletions src/matplotlib/pyplot.mli
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ val scatter
-> ?linewidths:float
-> (float * float) array
-> unit

val imshow
: ?cmap:string
-> Mpl.Imshow_data.t
-> unit

0 comments on commit 2eba25b

Please sign in to comment.