Skip to content

Commit

Permalink
Add 3D scatter plot (#5)
Browse files Browse the repository at this point in the history
* Add 3D scatter plot.

* Use separate type for 3D to avoid runtime errors
  • Loading branch information
crackcomm authored Sep 17, 2021
1 parent d9d47ff commit 653b2fc
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/matplotlib/fig_ax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,32 @@ module Ax = struct
end
end

module Ax3d = struct
type t = Py.Object.t

let set_title = Ax.set_title
let set_xlim = Ax.set_xlim
let set_ylim = Ax.set_ylim
let set_xlabel = Ax.set_xlabel
let set_ylabel = Ax.set_ylabel

let set_zlim t ~bottom ~top =
ignore ((t.&("set_zlim")) [| Py.Float.of_float bottom; Py.Float.of_float top |])

let set_zlabel t label = ignore ((t.&("set_zlabel")) [| Py.String.of_string label |])

let grid t b =
let keywords = [ "b", Py.Bool.of_bool b ] in
ignore (Py.Module.get_function_with_keywords t "grid" [||] keywords)

let scatter = Mpl.scatter_3d
let imshow = Mpl.imshow

module Expert = struct
let to_pyobject = Fn.id
end
end

module Fig = struct
type t = Py.Object.t

Expand All @@ -81,6 +107,11 @@ module Fig = struct
let args = [| nrows; ncols; index |] |> Array.map ~f:Py.Int.of_int in
Py.Module.get_function_with_keywords t "add_subplot" args keywords

let add_subplot_3d t ~nrows ~ncols ~index =
let keywords = [ "projection", Py.String.of_string "3d" ] in
let args = [| nrows; ncols; index |] |> Array.map ~f:Py.Int.of_int in
Py.Module.get_function_with_keywords t "add_subplot" args keywords

let create_with_ax ?figsize () =
let t = create ?figsize () in
let ax = add_subplot t ~nrows:1 ~ncols:1 ~index:1 in
Expand Down
30 changes: 30 additions & 0 deletions src/matplotlib/fig_ax.mli
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,35 @@ module Ax : sig
end
end

module Ax3d : sig
type t

val set_title : t -> string -> unit
val set_xlim : t -> left:float -> right:float -> unit
val set_ylim : t -> bottom:float -> top:float -> unit
val set_zlim : t -> bottom:float -> top:float -> unit
val set_xlabel : t -> string -> unit
val set_ylabel : t -> string -> unit
val set_zlabel : t -> string -> unit
val grid : t -> bool -> unit

val scatter
: t
-> ?s:float
-> ?c:Mpl.Color.t
-> ?marker:char
-> ?alpha:float
-> ?linewidths:float
-> (float * float * float) array
-> unit

val imshow : t -> ?cmap:string -> Mpl.Imshow_data.t -> unit

module Expert : sig
val to_pyobject : t -> Py.Object.t
end
end

module Fig : sig
type t

Expand All @@ -70,6 +99,7 @@ module Fig : sig
increases to the right.
*)
val add_subplot : t -> nrows:int -> ncols:int -> index:int -> Ax.t
val add_subplot_3d : t -> nrows:int -> ncols:int -> index:int -> Ax3d.t
val create_with_ax : ?figsize:float * float -> unit -> t * Ax.t

val create_with_two_axes
Expand Down
1 change: 1 addition & 0 deletions src/matplotlib/matplotlib.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module Ax = Fig_ax.Ax
module Ax3d = Fig_ax.Ax3d
module Fig = Fig_ax.Fig
module Imshow_data = Mpl.Imshow_data
module Mpl = Mpl.Public
Expand Down
15 changes: 15 additions & 0 deletions src/matplotlib/mpl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,21 @@ let scatter p ?s ?c ?marker ?alpha ?linewidths xys =
let ys = Py.List.of_array_map (fun (_, y) -> Py.Float.of_float y) xys in
ignore (Py.Module.get_function_with_keywords p "scatter" [| xs; ys |] keywords)

let scatter_3d p ?s ?c ?marker ?alpha ?linewidths xyzs =
let keywords =
List.filter_opt
[ Option.map c ~f:(fun c -> "c", Color.to_pyobject c)
; Option.map s ~f:(fun s -> "s", Py.Float.of_float s)
; Option.map marker ~f:(fun m -> "marker", String.of_char m |> Py.String.of_string)
; Option.map alpha ~f:(fun a -> "alpha", Py.Float.of_float a)
; Option.map linewidths ~f:(fun l -> "linewidths", Py.Float.of_float l)
]
in
let xs = Py.List.of_array_map (fun (x, _, _) -> Py.Float.of_float x) xyzs in
let ys = Py.List.of_array_map (fun (_, y, _) -> Py.Float.of_float y) xyzs in
let zs = Py.List.of_array_map (fun (_, _, z) -> Py.Float.of_float z) xyzs in
ignore (Py.Module.get_function_with_keywords p "scatter" [| xs; ys; zs |] keywords)

module Imshow_data = struct
type 'a data =
| Scalar of 'a array array
Expand Down
10 changes: 10 additions & 0 deletions src/matplotlib/mpl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ val scatter
-> (float * float) array
-> unit

val scatter_3d
: Py.Object.t
-> ?s:float
-> ?c:Color.t
-> ?marker:char
-> ?alpha:float
-> ?linewidths:float
-> (float * float * float) array
-> unit

module Imshow_data : sig
type t
type 'a typ_
Expand Down

0 comments on commit 653b2fc

Please sign in to comment.