| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN |
| source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN |
| |
| # Training hyper-parameters, all taken from the script arguments. |
| num_epochs = $epochs |
| batch_sz = $batch_size |
| eta = $eta |
| update_type = $utype |
| update_freq = $freq |
| net = $network_type |
| |
| # Extra configuration that is forwarded only to the CNN training function. |
| c_in = $channels |
| h_in = $hin |
| w_in = $win |
| |
| # Parameters that are currently ignored; fixed values are passed along. |
| num_workers = 1 |
| data_scheme = "DISJOINT_CONTIGUOUS" |
| ps_mode = "LOCAL" |
| |
| # Empty 0x0 matrices for the third and fourth training arguments |
| # (presumably validation features/labels - confirm against TwoNN.dml / CNN.dml). |
| empty_x = matrix(0, rows=0, cols=0) |
| empty_y = matrix(0, rows=0, cols=0) |
| |
| # Federated feature matrix assembled from the two workers $X0 and $X1. |
| # Each address is paired with two entries of 'ranges'; the second worker's |
| # entries start at row $examples_per_worker, where the first worker's end. |
| fed_X = federated(addresses=list($X0, $X1), |
|   ranges=list( |
|     list(0, 0), list($examples_per_worker, $num_features), |
|     list($examples_per_worker, 0), list(2 * $examples_per_worker, $num_features))) |
| |
| # Federated label matrix, laid out the same way from $y0 and $y1. |
| fed_y = federated(addresses=list($y0, $y1), |
|   ranges=list( |
|     list(0, 0), list($examples_per_worker, $num_labels), |
|     list($examples_per_worker, 0), list(2 * $examples_per_worker, $num_labels))) |
| |
| # Any network type other than "TwoNN" is trained with the CNN implementation. |
| if(net != "TwoNN") { |
|   model = CNN::train_paramserv(fed_X, fed_y, empty_x, empty_y, |
|     c_in, h_in, w_in, num_epochs, num_workers, update_type, update_freq, |
|     batch_sz, data_scheme, ps_mode, eta) |
| } |
| else { |
|   model = TwoNN::train_paramserv(fed_X, fed_y, empty_x, empty_y, |
|     num_epochs, num_workers, update_type, update_freq, |
|     batch_sz, data_scheme, ps_mode, eta) |
| } |
| |
| # Print the trained model. |
| print(toString(model)) |