# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
| |
def matrix_fact_net(factor_size, num_hidden, max_user, max_item, dense):
    """Construct the matrix factorization network symbol.

    factor_size : dimensionality of the user/item embedding factors
    num_hidden  : number of hidden units in the fully connected layers
    max_user    : number of distinct user IDs (embedding input dimension)
    max_item    : number of distinct item IDs (embedding input dimension)
    dense       : if True, use dense weights and gradients; otherwise use
                  row_sparse embedding weights with sparse gradients
    """
    # input variables: user/item indices and the target score
    user = mx.sym.Variable('user')
    item = mx.sym.Variable('item')
    score = mx.sym.Variable('score')
    # choose dense or row_sparse storage for the embedding weights
    stype = 'default' if dense else 'row_sparse'
    sparse_grad = not dense
    user_weight = mx.sym.Variable('user_weight', stype=stype)
    item_weight = mx.sym.Variable('item_weight', stype=stype)
    # user feature lookup
    user = mx.sym.Embedding(data=user, weight=user_weight, sparse_grad=sparse_grad,
                            input_dim=max_user, output_dim=factor_size)
    # item feature lookup
    item = mx.sym.Embedding(data=item, weight=item_weight, sparse_grad=sparse_grad,
                            input_dim=max_item, output_dim=factor_size)

    # non-linear transformation of user features
    user = mx.sym.Activation(data=user, act_type='relu')
    user_act = mx.sym.FullyConnected(data=user, num_hidden=num_hidden)
    # non-linear transformation of item features
    item = mx.sym.Activation(data=item, act_type='relu')
    item_act = mx.sym.FullyConnected(data=item, num_hidden=num_hidden)
    # predict with the inner product of the user and item representations,
    # i.e. an elementwise product followed by a sum over the hidden axis
    pred = user_act * item_act
    pred = mx.sym.sum(data=pred, axis=1)
    pred = mx.sym.Flatten(data=pred)
    # squared-error loss against the observed score
    pred = mx.sym.LinearRegressionOutput(data=pred, label=score)
    return pred
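

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original example): the
# sizes, random toy data, and training settings below are assumptions made
# for demonstration. The symbol above expects 'user' and 'item' as data and
# 'score' as the label, which is how the iterator and module are wired here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    num_users, num_items, batch_size = 1000, 2000, 64
    net = matrix_fact_net(factor_size=64, num_hidden=64,
                          max_user=num_users, max_item=num_items, dense=True)

    # random toy data, just to show the expected input/label shapes
    users = mx.nd.array(np.random.randint(0, num_users, size=(batch_size * 4,)))
    items = mx.nd.array(np.random.randint(0, num_items, size=(batch_size * 4,)))
    scores = mx.nd.array(np.random.uniform(1, 5, size=(batch_size * 4,)))
    train_iter = mx.io.NDArrayIter(data={'user': users, 'item': items},
                                   label={'score': scores},
                                   batch_size=batch_size)

    mod = mx.mod.Module(symbol=net, data_names=['user', 'item'],
                        label_names=['score'])
    mod.fit(train_iter, num_epoch=1, eval_metric='MSE',
            optimizer='sgd', optimizer_params={'learning_rate': 0.01})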