Skip to content

Commit

Permalink
Add legend function to pyplot (#2)
Browse files Browse the repository at this point in the history
* Make it possible to call Legend without figure

* Use Pyplot.legend in examples/pyplot.ml

* Makefile typo fix

* Update examples/fig.ml
  • Loading branch information
mknbv committed May 21, 2020
1 parent 54c779b commit 64bef1f
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 42 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ all:
dune build @install

plot: .FORCE
dune build examples/plot.exe
_build/default/examples/plot.exe
dune build examples/pyplot.exe
_build/default/examples/pyplot.exe

clean:
rm -Rf _build
Expand Down
2 changes: 1 addition & 1 deletion examples/fig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ let left_graph ax =
Ax.grid ax true;
Ax.plot ax ~label:"sin1" ~color:Red ~xs ys1;
Ax.plot ax ~label:"sin2" ~color:Green ~linestyle:Dotted ~linewidth:2. ~xs ys2;
Ax.legend ax
Ax.legend ax ()

let right_graph ax =
let rnds =
Expand Down
3 changes: 2 additions & 1 deletion examples/pyplot.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ let () =
let ys1 = Array.of_list ys1 in
let ys2 = Array.of_list ys2 in
Pyplot.xlabel "x";
Pyplot.ylabel "sin(x)";
Pyplot.ylabel "y";
Pyplot.grid true;
Pyplot.plot ~color:Red ~xs ys1;
Pyplot.plot ~color:Green ~linestyle:Dotted ~linewidth:2. ~xs ys2;
Pyplot.legend ~labels:[|"$y=\\sin(x/20)$"; "$y=\\cos(x/12)$"|] ();
Mpl.savefig "test.png";
let data = Mpl.plot_data `png in
Stdio.Out_channel.write_all "test2.png" ~data;
Expand Down
25 changes: 1 addition & 24 deletions src/matplotlib/fig_ax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,11 @@ module Ax = struct
in
ignore (Py.Module.get_function_with_keywords t "grid" [||] keywords)

let legend ?loc t =
let keywords =
let loc =
Option.map loc ~f:(fun loc ->
let loc =
match loc with
| `best -> "best"
| `upper_right -> "upper right"
| `upper_left -> "upper left"
| `lower_left -> "lower left"
| `lower_right -> "lower right"
| `right -> "right"
| `center_left -> "center left"
| `center_right -> "center right"
| `lower_center -> "lower center"
| `upper_center -> "upper center"
| `center -> "center"
in
"loc", Py.String.of_string loc)
in
List.filter_opt [ loc ]
in
ignore (Py.Module.get_function_with_keywords t "legend" [||] keywords)

let plot = Mpl.plot
let hist = Mpl.hist
let scatter = Mpl.scatter
let imshow = Mpl.imshow
let legend = Mpl.legend

module Expert = struct
let to_pyobject = Fn.id
Expand Down
15 changes: 1 addition & 14 deletions src/matplotlib/fig_ax.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,7 @@ module Ax : sig
-> unit

val legend
: ?loc:[ `best
| `upper_right
| `upper_left
| `lower_left
| `lower_right
| `right
| `center_left
| `center_right
| `lower_center
| `upper_center
| `center
]
-> t
-> unit
: t -> ?labels:string array -> ?loc:Mpl.Loc.t -> unit -> unit

val plot
: t
Expand Down
39 changes: 39 additions & 0 deletions src/matplotlib/mpl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,36 @@ module Linestyle = struct
Py.String.of_string str
end

module Loc = struct
type t =
| Best
| UpperRight
| UpperLeft
| LowerLeft
| LowerRight
| Right
| CenterLeft
| CenterRight
| LowerCenter
| UpperCenter
| Center

let to_pyobject t =
let str = match t with
| Best -> "best"
| UpperRight -> "upper right"
| UpperLeft -> "upper left"
| LowerLeft -> "lower left"
| LowerRight -> "lower right"
| Right -> "right"
| CenterLeft -> "center left"
| CenterRight -> "center right"
| LowerCenter -> "lower center"
| UpperCenter -> "upper center"
| Center -> "center"
in Py.String.of_string str
end

let savefig filename =
let p = pyplot_module () in
ignore ((p.&("savefig")) [| Py.String.of_string filename |])
Expand Down Expand Up @@ -122,6 +152,7 @@ module Public = struct
module Backend = Backend
module Color = Color
module Linestyle = Linestyle
module Loc = Loc

let set_backend = set_backend
let show = show
Expand Down Expand Up @@ -245,3 +276,11 @@ let imshow p ?cmap data =
in
let data = Imshow_data.to_pyobject data in
ignore (Py.Module.get_function_with_keywords p "imshow" [| data |] keywords)

let legend p ?labels ?loc () =
let keywords = List.filter_opt
[ Option.map labels ~f:(fun labels -> "labels",
Py.List.of_array_map Py.String.of_string labels)
; Option.map loc ~f:(fun loc -> "loc", Loc.to_pyobject loc)
] in
ignore (Py.Module.get_function_with_keywords p "legend" [||] keywords)
40 changes: 40 additions & 0 deletions src/matplotlib/mpl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ module Linestyle : sig
val to_pyobject : t -> Py.Object.t
end

module Loc : sig
type t =
| Best
| UpperRight
| UpperLeft
| LowerLeft
| LowerRight
| Right
| CenterLeft
| CenterRight
| LowerCenter
| UpperCenter
| Center

val to_pyobject : t -> Py.Object.t
end


(* [set_backend] has to be called before any other operation. *)
val set_backend : Backend.t -> unit
val pyplot_module : unit -> Py.Object.t
Expand Down Expand Up @@ -62,6 +80,21 @@ module Public : sig
| Other of string
end

module Loc : sig
type t =
| Best
| UpperRight
| UpperLeft
| LowerLeft
| LowerRight
| Right
| CenterLeft
| CenterRight
| LowerCenter
| UpperCenter
| Center
end

(* [set_backend] has to be called before any other operation. *)
val set_backend : Backend.t -> unit
val show : unit -> unit
Expand Down Expand Up @@ -118,3 +151,10 @@ module Imshow_data : sig
end

val imshow : Py.Object.t -> ?cmap:string -> Imshow_data.t -> unit

val legend
: Py.Object.t
-> ?labels:string array
-> ?loc:Loc.t
-> unit
-> unit
4 changes: 4 additions & 0 deletions src/matplotlib/pyplot.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ let scatter ?s ?c ?marker ?alpha ?linewidths xys =
let imshow ?cmap xys =
let p = Mpl.pyplot_module () in
Mpl.imshow p ?cmap xys

let legend ?labels ?loc () =
let p = Mpl.pyplot_module () in
Mpl.legend p ?labels ?loc ()
2 changes: 2 additions & 0 deletions src/matplotlib/pyplot.mli
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@ val scatter
-> unit

val imshow : ?cmap:string -> Mpl.Imshow_data.t -> unit

val legend : ?labels:(string array) -> ?loc:Mpl.Loc.t -> unit -> unit

0 comments on commit 64bef1f

Please sign in to comment.