{"nbformat": 4, "cells": [{"source": "# Matrix Factorization\n\nIn a recommendation system, there is a group of users and a set of items. Given\nthat each users have rated some items in the system, we would like to predict\nhow the users would rate the items that they have not yet rated, such that we\ncan make recommendations to the users.\n\nMatrix factorization is one of the mainly used algorithm in recommendation\nsystems. It can be used to discover latent features underlying the interactions\nbetween two different kinds of entities.\n\nAssume we assign a k-dimensional vector to each user and a k-dimensional vector\nto each item such that the dot product of these two vectors gives the user's\nrating of that item. We can learn the user and item vectors directly, which is\nessentially performing SVD on the user-item matrix. We can also try to learn the\nlatent features using multi-layer neural networks.\n\nIn this tutorial, we will work though the steps to implement these ideas in\nMXNet.\n\n## Prepare Data\n\nWe use the [MovieLens](http://grouplens.org/datasets/movielens/) data here, but\nit can apply to other datasets as well. Each row of this dataset contains a\ntuple of user id, movie id, rating, and time stamp, we will only use the first\nthree items. We first define the a batch which contains n tuples. It also\nprovides name and shape information to MXNet about the data and label.", "cell_type": "markdown", "metadata": {}}, {"source": "class Batch(object):\n def __init__(self, data_names, data, label_names, label):\n self.data = data\n self.label = label\n self.data_names = data_names\n self.label_names = label_names\n\n @property\n def provide_data(self):\n return [(n, x.shape) for n, x in zip(self.data_names, self.data)]\n\n @property\n def provide_label(self):\n return [(n, x.shape) for n, x in zip(self.label_names, self.label)]", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Then we define a data iterator, which returns a batch of tuples each time.", "cell_type": "markdown", "metadata": {}}, {"source": "import mxnet as mx\nimport random\n\nclass Batch(object):\n def __init__(self, data_names, data, label_names, label):\n self.data = data\n self.label = label\n self.data_names = data_names\n self.label_names = label_names\n\n @property\n def provide_data(self):\n return [(n, x.shape) for n, x in zip(self.data_names, self.data)]\n\n @property\n def provide_label(self):\n return [(n, x.shape) for n, x in zip(self.label_names, self.label)]\n\nclass DataIter(mx.io.DataIter):\n def __init__(self, fname, batch_size):\n super(DataIter, self).__init__()\n self.batch_size = batch_size\n self.data = []\n for line in file(fname):\n tks = line.strip().split('\\t')\n if len(tks) != 4:\n continue\n self.data.append((int(tks[0]), int(tks[1]), float(tks[2])))\n self.provide_data = [('user', (batch_size, )), ('item', (batch_size, ))]\n self.provide_label = [('score', (self.batch_size, ))]\n\n def __iter__(self):\n for k in range(len(self.data) / self.batch_size):\n users = []\n items = []\n scores = []\n for i in range(self.batch_size):\n j = k * self.batch_size + i\n user, item, score = self.data[j]\n users.append(user)\n items.append(item)\n scores.append(score)\n\n data_all = [mx.nd.array(users), mx.nd.array(items)]\n label_all = [mx.nd.array(scores)]\n data_names = ['user', 'item']\n label_names = ['score']\n\n data_batch = Batch(data_names, data_all, label_names, label_all)\n yield data_batch\n\n def reset(self):\n random.shuffle(self.data)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Now we download the data and provide a function to obtain the data iterator:", "cell_type": "markdown", "metadata": {}}, {"source": "import os\nimport urllib\nimport zipfile\nif not os.path.exists('ml-100k.zip'):\n urllib.urlretrieve('http://files.grouplens.org/datasets/movielens/ml-100k.zip', 'ml-100k.zip')\nwith zipfile.ZipFile(\"ml-100k.zip\",\"r\") as f:\n f.extractall(\"./\")\ndef get_data(batch_size):\n return (DataIter('./ml-100k/u1.base', batch_size), DataIter('./ml-100k/u1.test', batch_size))", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Finally we calculate the numbers of users and items for later use.", "cell_type": "markdown", "metadata": {}}, {"source": "def max_id(fname):\n mu = 0\n mi = 0\n for line in file(fname):\n tks = line.strip().split('\\t')\n if len(tks) != 4:\n continue\n mu = max(mu, int(tks[0]))\n mi = max(mi, int(tks[1]))\n return mu + 1, mi + 1\nmax_user, max_item = max_id('./ml-100k/u.data')\n(max_user, max_item)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Optimization\n\nWe first implement the RMSE (root-mean-square error) measurement, which is\ncommonly used by matrix factorization.", "cell_type": "markdown", "metadata": {}}, {"source": "import math\ndef RMSE(label, pred):\n ret = 0.0\n n = 0.0\n pred = pred.flatten()\n for i in range(len(label)):\n ret += (label[i] - pred[i]) * (label[i] - pred[i])\n n += 1.0\n return math.sqrt(ret / n)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Then we define a general training module, which is borrowed from the image\nclassification application.", "cell_type": "markdown", "metadata": {}}, {"source": "def train(network, batch_size, num_epoch, learning_rate):\n model = mx.model.FeedForward(\n ctx = mx.gpu(0),\n symbol = network,\n num_epoch = num_epoch,\n learning_rate = learning_rate,\n wd = 0.0001,\n momentum = 0.9)\n\n batch_size = 64\n train, test = get_data(batch_size)\n\n import logging\n head = '%(asctime)-15s %(message)s'\n logging.basicConfig(level=logging.DEBUG)\n\n model.fit(X = train,\n eval_data = test,\n eval_metric = RMSE,\n batch_end_callback=mx.callback.Speedometer(batch_size, 20000/batch_size),)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "## Networks\n\nNow we try various networks. We first learn the latent vectors directly.", "cell_type": "markdown", "metadata": {}}, {"source": "def plain_net(k):\n # input\n user = mx.symbol.Variable('user')\n item = mx.symbol.Variable('item')\n score = mx.symbol.Variable('score')\n # user feature lookup\n user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)\n # item feature lookup\n item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n # predict by the inner product, which is elementwise product and then sum\n pred = user * item\n pred = mx.symbol.sum_axis(data = pred, axis = 1)\n pred = mx.symbol.Flatten(data = pred)\n # loss layer\n pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n return pred\n\ntrain(plain_net(64), batch_size=64, num_epoch=10, learning_rate=.05)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Next we try to use 2 layers neural network to learn the latent variables, which stack a fully connected layer above the embedding layers:", "cell_type": "markdown", "metadata": {}}, {"source": "def get_one_layer_mlp(hidden, k):\n # input\n user = mx.symbol.Variable('user')\n item = mx.symbol.Variable('item')\n score = mx.symbol.Variable('score')\n # user latent features\n user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)\n user = mx.symbol.Activation(data = user, act_type=\"relu\")\n user = mx.symbol.FullyConnected(data = user, num_hidden = hidden)\n # item latent features\n item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n item = mx.symbol.Activation(data = item, act_type=\"relu\")\n item = mx.symbol.FullyConnected(data = item, num_hidden = hidden)\n # predict by the inner product\n pred = user * item\n pred = mx.symbol.sum_axis(data = pred, axis = 1)\n pred = mx.symbol.Flatten(data = pred)\n # loss layer\n pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n return pred\n\ntrain(get_one_layer_mlp(64, 64), batch_size=64, num_epoch=10, learning_rate=.05)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "Adding dropout layers to relief the over-fitting.", "cell_type": "markdown", "metadata": {}}, {"source": "def get_one_layer_dropout_mlp(hidden, k):\n # input\n user = mx.symbol.Variable('user')\n item = mx.symbol.Variable('item')\n score = mx.symbol.Variable('score')\n # user latent features\n user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)\n user = mx.symbol.Activation(data = user, act_type=\"relu\")\n user = mx.symbol.FullyConnected(data = user, num_hidden = hidden)\n user = mx.symbol.Dropout(data=user, p=0.5)\n # item latent features\n item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)\n item = mx.symbol.Activation(data = item, act_type=\"relu\")\n item = mx.symbol.FullyConnected(data = item, num_hidden = hidden)\n item = mx.symbol.Dropout(data=item, p=0.5)\n # predict by the inner product\n pred = user * item\n pred = mx.symbol.sum_axis(data = pred, axis = 1)\n pred = mx.symbol.Flatten(data = pred)\n # loss layer\n pred = mx.symbol.LinearRegressionOutput(data = pred, label = score)\n return pred\ntrain(get_one_layer_mlp(256, 512), batch_size=64, num_epoch=10, learning_rate=.05)", "cell_type": "code", "execution_count": null, "outputs": [], "metadata": {}}, {"source": "\n<!-- INSERT SOURCE DOWNLOAD BUTTONS -->\n\n", "cell_type": "markdown", "metadata": {}}], "metadata": {"display_name": "", "name": "", "language": "python"}, "nbformat_minor": 2} |