From 46c835ff856bc2075307083dad7de01ff085e8c8 Mon Sep 17 00:00:00 2001 From: Chouffe Date: Fri, 1 Mar 2019 15:24:14 +0100 Subject: [PATCH] [clojure-package] fix docstrings in `normal.clj` * Fixed documentation string in `normal` function * Added spec to catch `high` < `low` in `uniform` * Added spec to catch `scale` <= 0 in `normal` * Added unit tests --- .../src/org/apache/clojure_mxnet/random.clj | 70 +++++++++++-------- .../org/apache/clojure_mxnet/random_test.clj | 4 +- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj index 0ec2039ba79b..1261e659e6dc 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj @@ -16,70 +16,84 @@ ;; (ns org.apache.clojure-mxnet.random + "Random Number interface of mxnet." (:require - [org.apache.clojure-mxnet.shape :as mx-shape] - [org.apache.clojure-mxnet.context :as context] [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.shape :as mx-shape] [org.apache.clojure-mxnet.util :as util]) (:import (org.apache.mxnet Context Random))) (s/def ::low number?) (s/def ::high number?) +(s/def ::low-high (fn [[low high]] (<= low high))) (s/def ::shape-vec (s/coll-of pos-int? :kind vector?)) (s/def ::ctx #(instance? Context %)) (s/def ::uniform-opts (s/keys :opt-un [::ctx])) (defn uniform - "Generate uniform distribution in [low, high) with shape. - low: The lower bound of distribution. - high: The upper bound of distribution. - shape-vec: vector shape of the ndarray generated. - opts-map { - ctx: Context of output ndarray, will use default context if not specified. - out: Output place holder} - returns: The result ndarray with generated result./" + "Generate uniform distribution in [`low`, `high`) with shape. + `low`: The lower bound of distribution. + `high`: The upper bound of distribution. + `shape-vec`: vector shape of the ndarray generated. + `opts-map` { + `ctx`: Context of output ndarray, will use default context if not specified. + `out`: Output place holder} + returns: The result ndarray with generated result. + Ex: + (uniform 0 1 [1 10]) + (uniform -10 10 [100 100])" ([low high shape-vec {:keys [ctx out] :as opts}] - (util/validate! ::uniform-opts opts "Incorrect random uniform parameters") + (util/validate! ::uniform-opts opts "Incorrect random uniform parameters") (util/validate! ::low low "Incorrect random uniform parameter") (util/validate! ::high high "Incorrect random uniform parameters") + (util/validate! ::low-high [low high] "Incorrect random uniform parameters") (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters") (Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out)) ([low high shape-vec] (uniform low high shape-vec {}))) (s/def ::loc number?) -(s/def ::scale number?) +(s/def ::scale (s/and number? pos?)) (s/def ::normal-opts (s/keys :opt-un [::ctx])) (defn normal - "Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape. - loc: The standard deviation of the normal distribution - scale: The upper bound of distribution. - shape-vec: vector shape of the ndarray generated. - opts-map { - ctx: Context of output ndarray, will use default context if not specified. - out: Output place holder} - returns: The result ndarray with generated result./" + "Generate normal (Gaussian) distribution N(mean, stdvar^^2) with shape. + `loc`: Mean (centre) of the distribution. + `scale`: Standard deviation (spread or width) of the distribution. + `shape-vec`: vector shape of the ndarray generated. + `opts-map` { + `ctx`: Context of output ndarray, will use default context if not specified. + `out`: Output place holder} + returns: The result ndarray with generated result. + Ex: + (normal 0 1 [10 10]) + (normal -5 4 [2 3])" ([loc scale shape-vec {:keys [ctx out] :as opts}] (util/validate! ::normal-opts opts "Incorrect random normal parameters") (util/validate! ::loc loc "Incorrect random normal parameters") (util/validate! ::scale scale "Incorrect random normal parameters") (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters") - (Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out)) + (Random/normal (float loc) + (float scale) + (mx-shape/->shape shape-vec) ctx out)) ([loc scale shape-vec] (normal loc scale shape-vec {}))) (s/def ::seed-state number?) (defn seed - " Seed the random number generators in mxnet. - This seed will affect behavior of functions in this module, - as well as results from executors that contains Random number - such as Dropout operators. + "Seed the random number generators in mxnet. + This seed will affect behavior of functions in this module, + as well as results from executors that contains Random number + such as Dropout operators. - seed-state: The random number seed to set to all devices. + `seed-state`: The random number seed to set to all devices. note: The random number generator of mxnet is by default device specific. This means if you set the same seed, the random number sequence - generated from GPU0 can be different from CPU." + generated from GPU0 can be different from CPU. + Ex: + (seed-state 42) + (seed-state 42.0)" [seed-state] (util/validate! ::seed-state seed-state "Incorrect seed parameters") - (Random/seed (int seed-state))) \ No newline at end of file + (Random/seed (int seed-state))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj index 6952335c1390..ca1dcc9430dc 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj @@ -56,6 +56,8 @@ (is (thrown? Exception (fn_ 'a 2 []))) (is (thrown? Exception (fn_ 1 'b []))) (is (thrown? Exception (fn_ 1 2 [-1]))) + (is (thrown? Exception (fn_ 1 0 [1 2]))) + (is (thrown? Exception (fn_ 1 -1 [1 2]))) (is (thrown? Exception (fn_ 1 2 [2 3 0]))) (is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"}))) (let [ctx (context/default-context)] @@ -64,4 +66,4 @@ (deftest test-random-parameters-specs (random-or-normal random/normal) (random-or-normal random/uniform) - (is (thrown? Exception (random/seed "a")))) \ No newline at end of file + (is (thrown? Exception (random/seed "a"))))