blob: d1f59dc5082ed720a5248d053e57f30aa1c90493 [file]
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns dev.generator
(:require [t6.from-scala.core :as scala]
[t6.from-scala.core :refer [$ $$] :as $]
[clojure.reflect :as r]
[clojure.pprint]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet NDArray NDArrayAPI
Symbol SymbolAPI
Base Base$RefInt Base$RefLong Base$RefFloat Base$RefString)
(scala.collection.mutable ListBuffer ArrayBuffer))
(:gen-class))
(defn clojure-case
"Transforms a scala string (function name) to clojure case"
[string]
(-> string
(clojure.string/replace #"(\s+)([A-Z][a-z]+)" "$1-$2")
(clojure.string/replace #"([A-Z]+)([A-Z][a-z]+)" "$1-$2")
(clojure.string/replace #"([a-z0-9])([A-Z])" "$1-$2")
(clojure.string/lower-case)
(clojure.string/replace #"\_" "-")
(clojure.string/replace #"\/" "div")))
(defn transform-param-names [coerce-fn parameter-types]
(->> parameter-types
(map str)
(map (fn [x] (or (coerce-fn x) x)))
(map (fn [x] (last (clojure.string/split x #"\."))))))
(defn symbol-transform-param-name [parameter-types]
(transform-param-names util/symbol-param-coerce parameter-types))
(defn ndarray-transform-param-name [parameter-types]
(transform-param-names util/ndarray-param-coerce parameter-types))
(defn has-variadic? [params]
(->> params
(map str)
(filter (fn [s] (re-find #"\&" s)))
count
pos?))
(defn increment-param-name [pname]
(if-let [num-str (re-find #"-\d" pname)]
(str
(first (clojure.string/split pname #"-"))
"-"
(inc (Integer/parseInt (last (clojure.string/split num-str #"-")))))
(str pname "-" 1)))
(defn rename-duplicate-params [pnames]
(->> (reduce
(fn [pname-counts n]
(let [rn (if (pname-counts n) (str n "-" (pname-counts n)) n)
inc-pname-counts (update-in pname-counts [n] (fnil inc 0))]
(update-in inc-pname-counts [:params] conj rn)))
{:params []}
pnames)
:params))
(defn get-public-no-default-methods [obj]
(->> (r/reflect obj)
:members
(map #(into {} %))
(filter #(-> % :flags :public))
(remove #(re-find #"org\$apache\$mxnet" (str (:name %))))
(remove #(re-find #"\$default" (str (:name %))))))
(defn get-public-to-gen-methods [public-to-hand-gen public-no-default]
(let [public-to-hand-gen-names
(into #{} (mapv (comp str :name) public-to-hand-gen))]
(remove #(-> % :name str public-to-hand-gen-names) public-no-default)))
(defn public-by-name-and-param-count [public-reflect-info]
(->> public-reflect-info
(group-by :name)
(map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)]))
(into {})))
(def license
(str
";; Licensed to the Apache Software Foundation (ASF) under one or more\n"
";; contributor license agreements. See the NOTICE file distributed with\n"
";; this work for additional information regarding copyright ownership.\n"
";; The ASF licenses this file to You under the Apache License, Version 2.0\n"
";; (the \"License\"); you may not use this file except in compliance with\n"
";; the License. You may obtain a copy of the License at\n"
";;\n"
";; http://www.apache.org/licenses/LICENSE-2.0\n"
";;\n"
";; Unless required by applicable law or agreed to in writing, software\n"
";; distributed under the License is distributed on an \"AS IS\" BASIS,\n"
";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
";; See the License for the specific language governing permissions and\n"
";; limitations under the License.\n"
";;\n"))
(defn write-to-file [functions ns-gen fname]
(with-open [w (clojure.java.io/writer fname)]
(.write w ns-gen)
(.write w "\n\n")
(.write w ";; Do not edit - this is auto-generated")
(.write w "\n\n")
(.write w license)
(.write w "\n\n")
(.write w "\n\n")
(doseq [f functions]
(let [fstr (-> f
clojure.pprint/pprint
with-out-str
(clojure.string/replace #"\\n\\n" "\n"))]
(.write w fstr))
(.write w "\n"))))
(defn remove-prefix
[prefix s]
(let [regex (re-pattern (str prefix "(.*)"))
replacement "$1"]
(clojure.string/replace s regex replacement)))
(defn in-namespace-random? [op-name]
(or (clojure.string/includes? op-name "random_")
(clojure.string/includes? op-name "sample_")))
(defn op-name->namespace-type [op-name]
(cond
(#{"uniform" "normal"} op-name) :deprecated
(clojure.string/includes? op-name "random_") :random
(clojure.string/includes? op-name "sample_") :random
:else :core))
;;;;;;; Common operations
(def libinfo (Base/_LIB))
(def op-names
(let [l ($ ListBuffer/empty)]
(.mxListAllOpNames libinfo l)
(->> l
(util/buffer->vec)
(remove #(or (= "Custom" %) (re-matches #"^_.*" %))))))
(defn- parse-arg-type [s]
(let [[_ var-arg-type _ set-arg-type arg-spec _ type-req _ default-val] (re-find #"(([\w-\[\]\s]+)|\{([^}]+)\})\s*(\([^)]+\))?(,\s*(optional|required)(,\s*default=(.*))?)?" s)]
{:type (clojure.string/trim (or set-arg-type var-arg-type))
:spec arg-spec
:optional? (or (= "optional" type-req)
(= "boolean" var-arg-type))
:default default-val
:orig s}))
(defn- get-op-handle [op-name]
(let [ref (new Base$RefLong 0)]
(do (.nnGetOpHandle libinfo op-name ref)
(.value ref))))
(defn gen-op-info [op-name]
(let [handle (get-op-handle op-name)
name (new Base$RefString nil)
desc (new Base$RefString nil)
key-var-num-args (new Base$RefString nil)
num-args (new Base$RefInt 0)
arg-names ($ ListBuffer/empty)
arg-types ($ ListBuffer/empty)
arg-descs ($ ListBuffer/empty)]
(do (.mxSymbolGetAtomicSymbolInfo libinfo
handle
name
desc
num-args
arg-names
arg-types
arg-descs
key-var-num-args)
{:fn-name (clojure-case (.value name))
:fn-description (.value desc)
:args (mapv (fn [t n d] (assoc t :name n :description d))
(mapv parse-arg-type (util/buffer->vec arg-types))
(mapv clojure-case (util/buffer->vec arg-names))
(util/buffer->vec arg-descs))
:key-var-num-args (clojure-case (.value key-var-num-args))})))
;;;;;;; Symbol
(def symbol-public-no-default
(get-public-no-default-methods Symbol))
(into #{} (mapcat :parameter-types symbol-public-no-default))
;; #{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>}
(def symbol-hand-gen-set
#{"scala.Option"
"scala.Enumeration$Value"
"org.apache.mxnet.Context"
"scala.Tuple2"
"scala.collection.Traversable"})
;;; min and max have a conflicting arity of 2 with the auto gen signatures
(def symbol-filter-name-set #{"max" "min"})
(defn is-symbol-hand-gen? [info]
(or
(->> (:name info)
str
(get symbol-filter-name-set))
(->> (map str (:parameter-types info))
(into #{})
(clojure.set/intersection symbol-hand-gen-set)
count
pos?)))
(def symbol-public-to-hand-gen
(filter is-symbol-hand-gen? symbol-public-no-default))
(def symbol-public-to-gen
(get-public-to-gen-methods symbol-public-to-hand-gen
symbol-public-no-default))
(count symbol-public-to-hand-gen) ;=> 35 mostly bind!
(count symbol-public-to-gen) ;=> 307
(into #{} (map :name symbol-public-to-hand-gen))
;;=> #{arange bind ones zeros simpleBind Variable}
(defn symbol-vector-args []
`(if (map? ~'kwargs-map-or-vec-or-sym)
(~'util/empty-list)
(~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"})))
(defn symbol-map-args []
`(if (map? ~'kwargs-map-or-vec-or-sym)
(util/convert-symbol-map ~'kwargs-map-or-vec-or-sym)
nil))
(defn add-symbol-arities [params function-name]
(if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"]
(mapv str params))
[`([~'sym-name ~'attr-map ~'kwargs-map]
(~function-name ~'sym-name (~'util/convert-symbol-map ~'attr-map) (~'util/empty-list) (~'util/convert-symbol-map ~'kwargs-map)))
`([~'sym-name ~'kwargs-map-or-vec-or-sym]
(~function-name ~'sym-name nil ~(symbol-vector-args) ~(symbol-map-args)))
`([~'kwargs-map-or-vec-or-sym]
(~function-name nil nil ~(symbol-vector-args) ~(symbol-map-args)))]))
(defn gen-symbol-function-arity [op-name op-values function-name]
(mapcat
(fn [[param-count info]]
(let [targets (->> (mapv :parameter-types info)
(apply interleave)
(mapv str)
(partition (count info))
(mapv set))
pnames (->> (mapv :parameter-types info)
(mapv symbol-transform-param-name)
(apply interleave)
(partition (count info))
(mapv #(clojure.string/join "-or-" %))
(rename-duplicate-params)
(mapv symbol))
coerced-params (mapv (fn [p t] `(~'util/nil-or-coerce-param ~(symbol (clojure.string/replace p #"\& " "")) ~t)) pnames targets)
params (if (= #{:public :static} (:flags (first info)))
pnames
(into ['sym] pnames))
function-body (if (= #{:public :static} (:flags (first info)))
`(~'util/coerce-return (~(symbol (str "Symbol/" op-name)) ~@coerced-params))
`(~'util/coerce-return (~(symbol (str "." op-name)) ~'sym ~@coerced-params)
))]
(when (not (and (> param-count 1) (has-variadic? params)))
`[(
~params
~function-body
)
~@(add-symbol-arities params function-name)])))
op-values))
(def all-symbol-functions
(for [operation (sort (public-by-name-and-param-count symbol-public-to-gen))]
(let [[op-name op-values] operation
function-name (-> op-name
str
scala/decode-scala-symbol
clojure-case
symbol)]
`(~'defn ~function-name
~@(remove nil? (gen-symbol-function-arity op-name op-values function-name))))))
(def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle ref])
(:require [org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Symbol)))")
(defn generate-symbol-file []
(println "Generating symbol file")
(write-to-file all-symbol-functions
symbol-gen-ns
"src/org/apache/clojure_mxnet/gen/symbol.clj"))
;;;;;;; NDArray
(def ndarray-public-no-default
(get-public-no-default-methods NDArray))
(def ndarray-hand-gen-set
#{"org.apache.mxnet.NDArrayFuncReturn"
"org.apache.mxnet.Context"
"scala.Enumeration$Value"
"scala.Tuple2"
"scala.collection.Traversable"})
(defn is-ndarray-hand-gen? [info]
(->> (map str (:parameter-types info))
(into #{})
(clojure.set/intersection ndarray-hand-gen-set)
count
pos?))
(def ndarray-public-to-hand-gen
(filter is-ndarray-hand-gen? ndarray-public-no-default))
(def ndarray-public-to-gen
(get-public-to-gen-methods ndarray-public-to-hand-gen
ndarray-public-no-default))
(count ndarray-public-to-hand-gen) ;=> 15
(count ndarray-public-to-gen) ;=> 486
(->> ndarray-public-to-hand-gen (map :name) (into #{}))
(defn gen-ndarray-function-arity [op-name op-values]
(for [[param-count info] op-values]
(let [targets (->> (mapv :parameter-types info)
(apply interleave)
(mapv str)
(partition (count info))
(mapv set))
pnames (->> (mapv :parameter-types info)
(mapv ndarray-transform-param-name)
(apply interleave)
(partition (count info))
(mapv #(clojure.string/join "-or-" %))
(rename-duplicate-params)
(mapv symbol))
coerced-params (mapv (fn [p t] `(~'util/coerce-param ~(symbol (clojure.string/replace p #"\& " "")) ~t)) pnames targets)
params (if (= #{:public :static} (:flags (first info)))
pnames
(into ['ndarray] pnames))
function-body (if (= #{:public :static} (:flags (first info)))
`(~'util/coerce-return (~(symbol (str "NDArray/" op-name)) ~@coerced-params))
`(~'util/coerce-return (~(symbol (str "." op-name)) ~'ndarray ~@coerced-params)
))]
(when (not (and (> param-count 1) (has-variadic? params)))
`(
~params
~function-body
)))))
(defn gen-ndarray-functions [public-to-gen-methods]
(for [operation (sort (public-by-name-and-param-count public-to-gen-methods))]
(let [[op-name op-values] operation
function-name (-> op-name
str
scala/decode-scala-symbol
clojure-case
symbol)]
`(~'defn ~function-name
~@(remove nil? (gen-ndarray-function-arity op-name op-values))))))
(def all-ndarray-functions
(gen-ndarray-functions ndarray-public-to-gen))
(def ndarray-gen-ns
"(ns org.apache.clojure-mxnet.ndarray
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:import (org.apache.mxnet NDArray Shape)))")
(defn generate-ndarray-file []
(println "Generating ndarray file")
(write-to-file all-ndarray-functions
ndarray-gen-ns
"src/org/apache/clojure_mxnet/gen/ndarray.clj"))
;;;;;;; SymbolAPI
(defn fn-name->random-fn-name
[fn-name]
(cond
(clojure.string/starts-with? fn-name "-random-")
(remove-prefix "-random-" fn-name)
(clojure.string/starts-with? fn-name "-sample-")
(str (remove-prefix "-sample-" fn-name) "-like")
:else fn-name))
(defn symbol-api-coerce-param
[{:keys [name sym type optional?]}]
(let [coerced-param (case type
"Shape" `(when ~sym (~'mx-shape/->shape ~sym))
"NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
"Map[String, String]"
`(when ~sym
(->> ~sym
(mapv (fn [[~'k ~'v]] [~'k (str ~'v)]))
(into {})
~'util/convert-map))
sym)
nil-param-allowed? (#{"name" "attr"} name)]
(if (and optional? (not nil-param-allowed?))
`(~'util/->option ~coerced-param)
coerced-param)))
(defn gen-symbol-api-doc [fn-description params]
(let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
(str "`" name "`: "
description
(when optional? " (optional)")
"\n"))
params)]
(str fn-description "\n\n"
(apply str param-descriptions))))
(defn gen-symbol-api-default-arity [op-name params]
(let [opt-params (filter :optional? params)
coerced-params (mapv symbol-api-coerce-param params)
default-args (array-map :keys (mapv :sym params)
:or (into {}
(mapv (fn [{:keys [sym]}] [sym nil])
opt-params))
:as 'opts)]
`([~default-args]
(~'util/coerce-return
(~(symbol (str "SymbolAPI/" op-name))
~@coerced-params)))))
(defn symbol-api-gen-ns
[random-namespace?]
(str
"(ns\n"
" ^{:doc \"Experimental\"}\n"
(if random-namespace?
" org.apache.clojure-mxnet.symbol-random-api\n"
" org.apache.clojure-mxnet.symbol-api\n")
" (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max\n"
" min repeat reverse set sort take to-array empty sin\n"
" get apply shuffle ref])\n"
" (:require [org.apache.clojure-mxnet.util :as util]\n"
" [org.apache.clojure-mxnet.shape :as mx-shape])\n"
" (:import (org.apache.mxnet SymbolAPI)))"))
(defn make-gen-symbol-api-function
[{:keys [fn-name->fn-name] :or {fn-name->fn-name identity}}]
(fn [op-name]
(let [{:keys [fn-name fn-description args]}
(-> op-name (gen-op-info) (update :fn-name fn-name->fn-name))
params (mapv (fn [{:keys [name type optional?] :as opts}]
(assoc opts
:sym (symbol name)
:optional? (or optional?
(= "NDArray-or-Symbol" type))))
(conj args
{:name "name"
:type "String"
:optional? true
:description "Name of the symbol"}
{:name "attr"
:type "Map[String, String]"
:optional? true
:description "Attributes of the symbol"}))
doc (clojure.string/join
"\n\n "
(-> (gen-symbol-api-doc fn-description params)
(clojure.string/split #"\n")))
default-call (gen-symbol-api-default-arity op-name params)]
`(~'defn ~(symbol fn-name)
~doc
~@default-call))))
(def gen-symbol-api-function
(make-gen-symbol-api-function {}))
(def gen-symbol-random-api-function
(make-gen-symbol-api-function {:fn-name->fn-name fn-name->random-fn-name}))
(defn all-symbol-api-functions [op-names]
(->> op-names
(filter #(= :core (op-name->namespace-type %)))
(mapv gen-symbol-api-function)))
(count (all-symbol-api-functions op-names)) ;215
(defn all-symbol-random-api-functions [op-names]
(->> op-names
(filter #(= :random (op-name->namespace-type %)))
(mapv gen-symbol-random-api-function)))
(count (all-symbol-random-api-functions op-names)) ;16
(defn generate-symbol-api-file [op-names]
(println "Generating symbol-api file")
(write-to-file (all-symbol-api-functions op-names)
(symbol-api-gen-ns false)
"src/org/apache/clojure_mxnet/gen/symbol_api.clj"))
(defn generate-symbol-random-api-file [op-names]
(println "Generating symbol-random-api file")
(write-to-file (all-symbol-random-api-functions op-names)
(symbol-api-gen-ns true)
"src/org/apache/clojure_mxnet/gen/symbol_random_api.clj"))
;;;;;;; NDArrayAPI
(defn ndarray-api-coerce-param
[{:keys [sym type optional?]}]
(let [coerced-param (case type
"Shape" `(when ~sym (~'mx-shape/->shape ~sym))
"NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym)
sym)]
(if optional?
`(~'util/->option ~coerced-param)
coerced-param)))
(defn gen-ndarray-api-doc [fn-description params]
(let [param-descriptions (mapv (fn [{:keys [name description optional?]}]
(str "`" name "`: "
description
(when optional? " (optional)")
"\n"))
params)]
(str fn-description "\n\n"
(apply str param-descriptions))))
(defn gen-ndarray-api-default-arity [op-name params]
(let [opt-params (filter :optional? params)
coerced-params (mapv ndarray-api-coerce-param params)
default-args (array-map :keys (mapv :sym params)
:or (into {}
(mapv (fn [{:keys [sym]}] [sym nil])
opt-params))
:as 'opts)]
`([~default-args]
(~'util/coerce-return
(~(symbol (str "NDArrayAPI/" op-name))
~@coerced-params)))))
(defn gen-ndarray-api-required-arity [fn-name req-params]
(let [req-args (->> req-params
(mapv (fn [{:keys [sym]}] [(keyword sym) sym]))
(into {}))]
`(~(mapv :sym req-params)
(~(symbol fn-name) ~req-args))))
(defn make-gen-ndarray-api-function
[{:keys [fn-name->fn-name] :or {fn-name->fn-name identity}}]
(fn [op-name]
(let [{:keys [fn-name fn-description args]}
(-> op-name (gen-op-info) (update :fn-name fn-name->fn-name))
params (mapv (fn [{:keys [name] :as opts}]
(assoc opts :sym (symbol name)))
(conj args {:name "out"
:type "NDArray-or-Symbol"
:optional? true
:description "Output array."}))
doc (clojure.string/join
"\n\n "
(-> (gen-ndarray-api-doc fn-description params)
(clojure.string/split #"\n")))
opt-params (filter :optional? params)
req-params (remove :optional? params)
req-call (gen-ndarray-api-required-arity fn-name req-params)
default-call (gen-ndarray-api-default-arity op-name params)]
(if (= 1 (count req-params))
`(~'defn ~(symbol fn-name)
~doc
~@default-call)
`(~'defn ~(symbol fn-name)
~doc
~req-call
~default-call)))))
(def gen-ndarray-api-function
(make-gen-ndarray-api-function {}))
(def gen-ndarray-random-api-function
(make-gen-ndarray-api-function {:fn-name->fn-name fn-name->random-fn-name}))
(defn all-ndarray-api-functions [op-names]
(->> op-names
(filter #(= :core (op-name->namespace-type %)))
(mapv gen-ndarray-api-function)))
(count (all-ndarray-api-functions op-names)) ; 213
(defn all-ndarray-random-api-functions [op-names]
(->> op-names
(filter #(= :random (op-name->namespace-type %)))
(mapv gen-ndarray-random-api-function)))
(count (all-ndarray-random-api-functions op-names)) ;16
(defn ndarray-api-gen-ns [random-namespace?]
(str
"(ns\n"
" ^{:doc \"Experimental\"}\n"
(if random-namespace?
" org.apache.clojure-mxnet.ndarray-random-api\n"
" org.apache.clojure-mxnet.ndarray-api\n")
" (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max\n"
" min repeat reverse set sort take to-array empty shuffle\n"
" ref])\n"
" (:require [org.apache.clojure-mxnet.shape :as mx-shape]\n"
" [org.apache.clojure-mxnet.util :as util])\n"
" (:import (org.apache.mxnet NDArrayAPI)))"))
(defn generate-ndarray-api-file [op-names]
(println "Generating ndarray-api file")
(write-to-file (all-ndarray-api-functions op-names)
(ndarray-api-gen-ns false)
"src/org/apache/clojure_mxnet/gen/ndarray_api.clj"))
(defn generate-ndarray-random-api-file [op-names]
(println "Generating ndarray-random-api file")
(write-to-file (all-ndarray-random-api-functions op-names)
(ndarray-api-gen-ns true)
"src/org/apache/clojure_mxnet/gen/ndarray_random_api.clj"))
;;; autogen the files
(do
(generate-ndarray-file)
;; NDArrayAPI
(generate-ndarray-api-file op-names)
(generate-ndarray-random-api-file op-names)
(generate-symbol-file)
;; SymbolAPI
(generate-symbol-api-file op-names)
(generate-symbol-random-api-file op-names))
(comment
(gen-op-info "ElementWiseSum")
(gen-ndarray-api-function "Activation")
(gen-symbol-api-function "Activation")
(gen-ndarray-random-api-function "random_randint")
(gen-ndarray-random-api-function "sample_normal")
(gen-symbol-random-api-function "random_poisson")
;; This generates a file with the bulk of the nd-array functions
(generate-ndarray-file)
;; This generates a file with the bulk of the symbol functions
(generate-symbol-file))