blob: 1f0418951f3a6eb3202b42cb9541869313035361 [file]
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;
(ns dev.generator
(:require [t6.from-scala.core :as scala]
[clojure.reflect :as r]
[org.apache.clojure-mxnet.util :as util]
[clojure.pprint])
(:import (org.apache.mxnet NDArray Symbol))
(:gen-class))
(defn clojure-case
  "Turns a camelCase / PascalCase / snake_case identifier into a lower-case,
  hyphen-separated name. A `/` in the input is spelled out as `div`."
  [string]
  (let [;; hyphenate a capitalised word that follows whitespace
        after-space   (clojure.string/replace string #"(\s+)([A-Z][a-z]+)" "$1-$2")
        ;; split a run of capitals from a following capitalised word (NDArray -> ND-Array)
        after-acronym (clojure.string/replace after-space #"([A-Z]+)([A-Z][a-z]+)" "$1-$2")
        ;; split lower-case/digit followed by a capital (simpleBind -> simple-Bind)
        after-hump    (clojure.string/replace after-acronym #"([a-z0-9])([A-Z])" "$1-$2")
        lowered       (clojure.string/lower-case after-hump)
        hyphenated    (clojure.string/replace lowered #"\_" "-")]
    (clojure.string/replace hyphenated #"\/" "div")))
(defn symbol-transform-param-name
  "Derives a short parameter name for each entry of `parameter-types`.
  The string form of a type is looked up with `util/symbol-param-coerce`; when
  that yields nothing the type name itself is used. Only the part after the
  last `.` is kept. Returns a lazy seq of strings."
  [parameter-types]
  (map (fn [param-type]
         (let [type-name (str param-type)
               coerced   (or (util/symbol-param-coerce type-name) type-name)]
           (last (clojure.string/split coerced #"\."))))
       parameter-types))
(defn ndarray-transform-param-name
  "Derives a short parameter name for each entry of `parameter-types`.
  The string form of a type is looked up with `util/ndarray-param-coerce`; when
  that yields nothing the type name itself is used. Only the part after the
  last `.` is kept. Returns a lazy seq of strings."
  [parameter-types]
  (map (fn [param-type]
         (let [type-name (str param-type)
               coerced   (or (util/ndarray-param-coerce type-name) type-name)]
           (last (clojure.string/split coerced #"\."))))
       parameter-types))
(defn has-variadic?
  "Returns true when the string form of at least one element of `params`
  contains `&`, false otherwise (including for an empty collection)."
  [params]
  (boolean (some #(re-find #"\&" (str %)) params)))
(defn increment-param-name
  "Returns `pname` with its numeric suffix bumped by one.

  - A name ending in `-<digits>` keeps everything before that suffix and gets
    the number incremented: \"kwargs-map-1\" -> \"kwargs-map-2\", \"num-9\" -> \"num-10\".
  - Any other name gets `-1` appended: \"kwargs-map\" -> \"kwargs-map-1\".

  The previous implementation looked only at the first `-<digit>` and kept only
  the text before the first hyphen, so hyphenated names lost their middle
  segments and multi-digit suffixes were truncated."
  [pname]
  ;; re-matches anchors the whole string, so only a trailing numeric suffix counts
  (if-let [[_ base suffix] (re-matches #"(.*)-(\d+)" pname)]
    (str base "-" (inc (Long/parseLong suffix)))
    (str pname "-" 1)))
(defn rename-duplicate-params
  "Returns a vector with the names of `params` in order, made pairwise distinct.

  A name that has not been emitted yet is kept as is. A name that collides with
  an already emitted one is passed through `increment-param-name` repeatedly
  until the result is unused. The first repeat therefore gets the same name as
  before this fix; previously a third or later repeat received the same
  suffixed name as the second, which produced duplicate (shadowing) parameter
  symbols in the generated arglists."
  [params]
  (reduce (fn [known-names n]
            (let [taken (set known-names)]
              ;; `iterate` starts with n itself, so a fresh name passes through unchanged
              (conj known-names
                    (first (remove taken (iterate increment-param-name n))))))
          []
          params))
;;;;;;; symbol
;; Every reflected member of the `Symbol` class, converted to a plain map.
(def symbol-reflect-info
  (map (partial into {}) (:members (r/reflect Symbol))))

;; The members whose :flags contain :public.
(def symbol-public
  (filter (comp :public :flags) symbol-reflect-info))

;; Public members minus those whose name contains `org$apache$mxnet`
;; or `$default`.
(def symbol-public-no-default
  (->> symbol-public
       (remove #(re-find #"org\$apache\$mxnet" (str (:name %))))
       (remove #(re-find #"\$default" (str (:name %))))))
(comment
  ;; REPL exploration, kept for reference only. It used to be a bare top-level
  ;; form, so it was evaluated (and its result thrown away) on every load.
  ;; Lists every distinct parameter type used by the candidate Symbol members.
  (into #{} (mapcat :parameter-types symbol-public-no-default))
  ;; Recorded output. NOTE(review): it shows the ml.dmlc.mxnet package prefix,
  ;; presumably captured before the move to org.apache.mxnet -- confirm.
  ;#{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>}
  )
;; Parameter type names that exclude a reflected Symbol member from automatic
;; generation. `is-symbol-hand-gen?` compares each entry with the string form
;; of one parameter type at a time.
;; NOTE(review): "int org.apache.mxnet.Executor" holds two type names in one
;; string, so it looks unable to equal any single parameter type; possibly
;; `int` and the Executor type were fused when a printed set was copied.
;; Confirm before changing it -- a fix would alter which members are generated.
(def symbol-hand-gen-set #{"scala.Option"
"int org.apache.mxnet.Executor"
"scala.Enumeration$Value"
"org.apache.mxnet.Context"
"scala.Tuple2"
"scala.collection.Traversable"} )
;;; Member names that are always left to hand-written code:
;;; min and max have a conflicting arity of 2 with the auto gen signatures
(def symbol-filter-name-set #{"max" "min"})
(defn is-symbol-hand-gen?
  "Returns true when the reflected member `info` must be written by hand:
  either its name is in `symbol-filter-name-set`, or the string form of at
  least one of its parameter types is in `symbol-hand-gen-set`.

  Uses only clojure.core (the earlier version called clojure.set/intersection
  although this namespace never requires clojure.set) and always returns a
  boolean rather than a string in the name branch."
  [info]
  (boolean
   (or (contains? symbol-filter-name-set (str (:name info)))
       ;; the set acts as the predicate: yields the first matching type name
       (some symbol-hand-gen-set (map str (:parameter-types info))))))
;; Reflected Symbol members that need a hand-written wrapper.
(def symbol-public-to-hand-gen
  (filter is-symbol-hand-gen? symbol-public-no-default))

;; Reflected Symbol members to generate automatically. Exclusion is by NAME:
;; if any overload of a name is hand-gen, every overload of that name is
;; dropped. The name set is built once here; it used to be rebuilt inside the
;; predicate for every member examined.
(def symbol-public-to-gen
  (let [hand-gen-names (set (map (comp str :name) symbol-public-to-hand-gen))]
    (remove #(contains? hand-gen-names (str (:name %)))
            symbol-public-no-default)))
(comment
  ;; REPL checks, kept for reference only. They used to be bare top-level
  ;; forms evaluated on every load with their results thrown away.
  (count symbol-public-to-hand-gen) ;=> 35 mostly bind!
  (count symbol-public-to-gen) ;=> 307
  (into #{} (map :name symbol-public-to-hand-gen)) ;=> #{arange bind ones zeros simpleBind Variable}
  )
(defn public-by-name-and-param-count
  "Indexes reflected members two levels deep: the outer map is keyed by
  `:name`, each inner map by the number of `:parameter-types`, and each leaf is
  the vector of members sharing that name and parameter count."
  [public-reflect-info]
  (reduce-kv (fn [index member-name overloads]
               (assoc index
                      member-name
                      (group-by (comp count :parameter-types) overloads)))
             {}
             (group-by :name public-reflect-info)))
(defn symbol-vector-args
  "Returns the code form used for the symbol-list argument of the extra
  arities: an empty list when `kwargs-map-or-vec-or-sym` is a map, otherwise
  that value coerced with `util/coerce-param` for \"scala.collection.Seq\"."
  []
  ;; A plain quoted form: nothing is spliced in, so syntax-quote is not needed.
  '(if (clojure.core/map? kwargs-map-or-vec-or-sym)
     (util/empty-list)
     (util/coerce-param kwargs-map-or-vec-or-sym #{"scala.collection.Seq"})))
(defn symbol-map-args
  "Returns the code form used for the kwargs argument of the extra arities:
  `kwargs-map-or-vec-or-sym` converted with `util/convert-symbol-map` when it
  is a map, otherwise nil.

  The call is emitted through the `util` alias, like every other `util/...`
  call this generator writes. Previously the symbol was not `~'`-quoted inside
  the syntax-quote, so it alone was written out fully namespace-qualified."
  []
  '(if (clojure.core/map? kwargs-map-or-vec-or-sym)
     (util/convert-symbol-map kwargs-map-or-vec-or-sym)
     nil))
(defn add-symbol-arities
  "When the string forms of `params` are exactly
  [\"sym-name\" \"kwargs-map\" \"symbol-list\" \"kwargs-map-1\"], returns a vector of
  three extra arity forms that call `function-name` with its four arguments:
    [sym-name attr-map kwargs-map]
    [sym-name kwargs-map-or-vec-or-sym]
    [kwargs-map-or-vec-or-sym]
  Returns nil for any other parameter list."
  [params function-name]
  (when (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] (mapv str params))
    (let [vec-args (symbol-vector-args)
          map-args (symbol-map-args)]
      [(list '[sym-name attr-map kwargs-map]
             (list function-name
                   'sym-name
                   '(util/convert-symbol-map attr-map)
                   '(util/empty-list)
                   '(util/convert-symbol-map kwargs-map)))
       (list '[sym-name kwargs-map-or-vec-or-sym]
             (list function-name 'sym-name nil vec-args map-args))
       (list '[kwargs-map-or-vec-or-sym]
             (list function-name nil nil vec-args map-args))])))
(defn gen-symbol-function-arity
  "Builds the arity forms of one generated Symbol function.

  op-name       - reflected member name; used to build the `Symbol/<op>` or
                  `.<op>` call.
  op-values     - map of parameter count -> reflected members (overloads)
                  with that count.
  function-name - symbol of the function being generated; passed on to
                  `add-symbol-arities`.

  For each parameter count, overloads are merged position by position: the
  parameter name joins the per-overload names with \"-or-\", and the coercion
  target is the set of type names seen at that position. Members whose :flags
  are exactly #{:public :static} are called statically, all others as methods
  on an extra leading `sym` parameter. A parameter count above 1 whose
  parameter list is variadic produces nothing. Returns a lazy seq of arity
  forms."
  [op-name op-values function-name]
  (mapcat
   (fn [[param-count info]]
     (let [overload-types (mapv :parameter-types info)
           static?        (= #{:public :static} (:flags (first info)))
           targets        (->> overload-types
                               (apply interleave)
                               (mapv str)
                               (partition (count info))
                               (mapv set))
           ;; Names stay strings here so the "& " marker can be stripped with
           ;; clojure.string/replace; the old code handed it a Symbol.
           name-strs      (->> overload-types
                               (mapv symbol-transform-param-name)
                               (apply interleave)
                               (partition (count info))
                               (mapv #(clojure.string/join "-or-" %))
                               (rename-duplicate-params))
           pnames         (mapv symbol name-strs)
           coerced-params (mapv (fn [name-str target]
                                  `(~'util/nil-or-coerce-param
                                    ~(symbol (clojure.string/replace name-str #"\& " ""))
                                    ~target))
                                name-strs
                                targets)
           params         (if static? pnames (into ['sym] pnames))
           function-body  (if static?
                            `(~'util/coerce-return
                              (~(symbol (str "Symbol/" op-name)) ~@coerced-params))
                            `(~'util/coerce-return
                              (~(symbol (str "." op-name)) ~'sym ~@coerced-params)))]
       (when-not (and (> param-count 1) (has-variadic? params))
         `[(~params ~function-body)
           ~@(add-symbol-arities params function-name)])))
   op-values))
;; Lazy seq of `(defn <name> <arity>...)` forms, one per generated Symbol
;; function, in sorted order of the reflected member entries.
(def all-symbol-functions
  (map (fn [[op-name op-values]]
         (let [function-name (symbol (clojure-case (scala/decode-scala-symbol (str op-name))))
               arities       (remove nil? (gen-symbol-function-arity op-name op-values function-name))]
           (list* 'defn function-name arities)))
       (sort (public-by-name-and-param-count symbol-public-to-gen))))
;; Apache license header written into every generated file; each line is a
;; Clojure comment terminated by a newline.
(def license
  (->> [";; Licensed to the Apache Software Foundation (ASF) under one or more"
        ";; contributor license agreements. See the NOTICE file distributed with"
        ";; this work for additional information regarding copyright ownership."
        ";; The ASF licenses this file to You under the Apache License, Version 2.0"
        ";; (the \"License\"); you may not use this file except in compliance with"
        ";; the License. You may obtain a copy of the License at"
        ";;"
        ";; http://www.apache.org/licenses/LICENSE-2.0"
        ";;"
        ";; Unless required by applicable law or agreed to in writing, software"
        ";; distributed under the License is distributed on an \"AS IS\" BASIS,"
        ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied."
        ";; See the License for the specific language governing permissions and"
        ";; limitations under the License."
        ";;"]
       (map #(str % "\n"))
       (apply str)))
(defn write-to-file
  "Writes a generated source file to `fname`: the `ns-gen` header, a
  do-not-edit notice, the `license` text, then every form in `functions`
  pretty-printed and followed by a newline. The writer is closed on exit."
  [functions ns-gen fname]
  (let [gap      "\n\n"
        preamble [ns-gen gap ";; Do not edit - this is auto-generated" gap license gap gap]]
    (with-open [out (clojure.java.io/writer fname)]
      (doseq [chunk preamble]
        (.write out chunk))
      (doseq [form functions]
        (clojure.pprint/pprint form out)
        (.write out "\n")))))
;; Text of the `ns` form written at the top of the generated symbol file.
;; It excludes the clojure.core names listed below, aliases the util namespace
;; as `util`, and imports the Symbol class. This is runtime text: whitespace
;; and line breaks inside the string are emitted verbatim.
(def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle])
(:require [org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Symbol)))")
(defn generate-symbol-file
  "Prints a progress line, then writes `all-symbol-functions` under the
  `symbol-gen-ns` header to the generated symbol source file."
  []
  (let [target "src/org/apache/clojure_mxnet/gen/symbol.clj"]
    (println "Generating symbol file")
    (write-to-file all-symbol-functions symbol-gen-ns target)))
;;;;;;;;NDARRAY
;; Every reflected member of the `NDArray` class, converted to a plain map.
(def ndarray-reflect-info
  (map (partial into {}) (:members (r/reflect NDArray))))

;; The members whose :flags contain :public.
(def ndarray-public
  (filter (comp :public :flags) ndarray-reflect-info))

;; Public members minus those whose name contains `org$apache$mxnet`
;; or `$default`.
(def ndarray-public-no-default
  (->> ndarray-public
       (remove #(re-find #"org\$apache\$mxnet" (str (:name %))))
       (remove #(re-find #"\$default" (str (:name %))))))
;; Parameter type names that exclude a reflected NDArray member from automatic
;; generation. `is-ndarray-hand-gen?` compares each entry with the string form
;; of one parameter type at a time.
(def ndarray-hand-gen-set #{"org.apache.mxnet.NDArrayFuncReturn"
"org.apache.mxnet.Context"
"scala.Enumeration$Value"
"scala.Tuple2"
"scala.collection.Traversable"} )
(defn is-ndarray-hand-gen?
  "Returns true when the string form of at least one parameter type of the
  reflected member `info` is in `ndarray-hand-gen-set`, false otherwise.

  Uses only clojure.core; the earlier version called clojure.set/intersection
  although this namespace never requires clojure.set."
  [info]
  ;; the set acts as the predicate: yields the first matching type name
  (boolean (some ndarray-hand-gen-set (map str (:parameter-types info)))))
;; Reflected NDArray members that need a hand-written wrapper.
(def ndarray-public-to-hand-gen
  (filter is-ndarray-hand-gen? ndarray-public-no-default))

;; Reflected NDArray members to generate automatically. Exclusion is by NAME:
;; if any overload of a name is hand-gen, every overload of that name is
;; dropped. The name set is built once here; it used to be rebuilt inside the
;; predicate for every member examined.
(def ndarray-public-to-gen
  (let [hand-gen-names (set (map (comp str :name) ndarray-public-to-hand-gen))]
    (remove #(contains? hand-gen-names (str (:name %)))
            ndarray-public-no-default)))
(comment
  ;; REPL checks, kept for reference only. They used to be bare top-level
  ;; forms evaluated on every load with their results thrown away.
  (count ndarray-public-to-hand-gen) ;=> 15
  (count ndarray-public-to-gen) ;=> 486
  (map :name ndarray-public-to-hand-gen)
  )
(defn gen-ndarray-function-arity
  "Builds the arity forms of one generated NDArray function.

  op-name   - reflected member name; used to build the `NDArray/<op>` or
              `.<op>` call.
  op-values - map of parameter count -> reflected members (overloads) with
              that count.

  For each parameter count, overloads are merged position by position: the
  parameter name joins the per-overload names with \"-or-\", and the coercion
  target is the set of type names seen at that position. Members whose :flags
  are exactly #{:public :static} are called statically, all others as methods
  on an extra leading `ndarray` parameter. Returns a lazy seq with one element
  per parameter count: an arity form, or nil when the count is above 1 and the
  parameter list is variadic."
  [op-name op-values]
  (for [[param-count info] op-values]
    (let [overload-types (mapv :parameter-types info)
          static?        (= #{:public :static} (:flags (first info)))
          targets        (->> overload-types
                              (apply interleave)
                              (mapv str)
                              (partition (count info))
                              (mapv set))
          ;; Names stay strings here so the "& " marker can be stripped with
          ;; clojure.string/replace; the old code handed it a Symbol.
          name-strs      (->> overload-types
                              (mapv ndarray-transform-param-name)
                              (apply interleave)
                              (partition (count info))
                              (mapv #(clojure.string/join "-or-" %))
                              (rename-duplicate-params))
          pnames         (mapv symbol name-strs)
          coerced-params (mapv (fn [name-str target]
                                 `(~'util/coerce-param
                                   ~(symbol (clojure.string/replace name-str #"\& " ""))
                                   ~target))
                               name-strs
                               targets)
          params         (if static? pnames (into ['ndarray] pnames))
          function-body  (if static?
                           `(~'util/coerce-return
                             (~(symbol (str "NDArray/" op-name)) ~@coerced-params))
                           `(~'util/coerce-return
                             (~(symbol (str "." op-name)) ~'ndarray ~@coerced-params)))]
      (when-not (and (> param-count 1) (has-variadic? params))
        `(~params ~function-body)))))
;; Lazy seq of `(defn <name> <arity>...)` forms, one per generated NDArray
;; function, in sorted order of the reflected member entries. Arities that
;; `gen-ndarray-function-arity` reports as nil are dropped.
(def all-ndarray-functions
  (map (fn [[op-name op-values]]
         (let [function-name (symbol (clojure-case (scala/decode-scala-symbol (str op-name))))
               arities       (remove nil? (gen-ndarray-function-arity op-name op-values))]
           (list* 'defn function-name arities)))
       (sort (public-by-name-and-param-count ndarray-public-to-gen))))
;; Text of the `ns` form written at the top of the generated ndarray file.
;; It excludes the clojure.core names listed below and imports NDArray and
;; Shape. This is runtime text: whitespace and line breaks inside the string
;; are emitted verbatim.
;; NOTE(review): unlike the symbol header, no :require for the util namespace
;; appears here, although the ndarray arities built in this file call
;; `util/...` functions. Presumably the alias is already in place wherever the
;; generated file is loaded -- confirm.
(def ndarray-gen-ns "(ns org.apache.clojure-mxnet.ndarray
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle])
(:import (org.apache.mxnet NDArray Shape)))")
(defn generate-ndarray-file
  "Prints a progress line, then writes `all-ndarray-functions` under the
  `ndarray-gen-ns` header to the generated ndarray source file."
  []
  (let [target "src/org/apache/clojure_mxnet/gen/ndarray.clj"]
    (println "Generating ndarray file")
    (write-to-file all-ndarray-functions ndarray-gen-ns target)))
;;; autogen the files
;; These run whenever this namespace is loaded: the ndarray file is written
;; first, then the symbol file.
(generate-ndarray-file)
(generate-symbol-file)
(comment
;; This generates a file with the bulk of the nd-array functions
(generate-ndarray-file)
;; This generates a file with the bulk of the symbol functions
(generate-symbol-file) )