| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import mxnet as mx
import numpy
| |
| class DQNOutput(mx.operator.CustomOp): |
| def __init__(self): |
| super(DQNOutput, self).__init__() |
| |
    def forward(self, is_train, req, in_data, out_data, aux):
        # Identity forward pass: the output is just the predicted Q-values;
        # the loss only shows up in backward.
        self.assign(out_data[0], req[0], in_data[0])
| |
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # TODO Computing the gradient directly with NDArray currently causes trouble,
        # see https://github.com/dmlc/mxnet/issues/1720, so fall back to NumPy here.
        x = out_data[0].asnumpy()                           # predicted Q-values, shape (batch, action_num)
        action = in_data[1].asnumpy().astype(numpy.int64)   # chosen action indices, shape (batch,)
        reward = in_data[2].asnumpy()                       # TD targets, shape (batch,)
        dx = in_grad[0]
        ret = numpy.zeros(shape=dx.shape, dtype=numpy.float32)
        # The gradient is the TD error clipped to [-1, 1], applied only at the chosen actions.
        ret[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
        self.assign(dx, req[0], ret)
| |
| |
| @mx.operator.register("DQNOutput") |
| class DQNOutputProp(mx.operator.CustomOpProp): |
| def __init__(self): |
| super(DQNOutputProp, self).__init__(need_top_grad=False) |
| |
| def list_arguments(self): |
| return ['data', 'action', 'reward'] |
| |
| def list_outputs(self): |
| return ['output'] |
| |
| def infer_shape(self, in_shape): |
| data_shape = in_shape[0] |
| action_shape = (in_shape[0][0],) |
| reward_shape = (in_shape[0][0],) |
| output_shape = in_shape[0] |
| return [data_shape, action_shape, reward_shape], [output_shape], [] |
| |
| def create_operator(self, ctx, shapes, dtypes): |
| return DQNOutput() |
| |
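# A minimal, hedged sketch (not part of the original example) of how the registered
# operator can be wired up with explicit action/reward inputs; the variable names
# 'qvalues', 'action' and 'reward' below are illustrative only:
#
#   qvalues = mx.symbol.Variable('qvalues')   # (batch, action_num) predicted Q-values
#   action = mx.symbol.Variable('action')     # (batch,) chosen action indices
#   reward = mx.symbol.Variable('reward')     # (batch,) TD targets
#   out = mx.symbol.Custom(data=qvalues, action=action, reward=reward,
#                          name='dqn', op_type='DQNOutput')
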
| |
| class DQNOutputNpyOp(mx.operator.NumpyOp): |
| def __init__(self): |
| super(DQNOutputNpyOp, self).__init__(need_top_grad=False) |
| |
| def list_arguments(self): |
| return ['data', 'action', 'reward'] |
| |
| def list_outputs(self): |
| return ['output'] |
| |
| def infer_shape(self, in_shape): |
| data_shape = in_shape[0] |
| action_shape = (in_shape[0][0],) |
| reward_shape = (in_shape[0][0],) |
| output_shape = in_shape[0] |
| return [data_shape, action_shape, reward_shape], [output_shape] |
| |
    def forward(self, in_data, out_data):
        # Identity forward pass: pass the predicted Q-values straight through.
        x = in_data[0]
        y = out_data[0]
        y[:] = x

    def backward(self, out_grad, in_data, out_data, in_grad):
        x = out_data[0]                           # predicted Q-values, shape (batch, action_num)
        action = in_data[1].astype(numpy.int64)   # chosen action indices, shape (batch,)
        reward = in_data[2]                       # TD targets, shape (batch,)
        dx = in_grad[0]
        dx[:] = 0
        # The gradient is the TD error clipped to [-1, 1], applied only at the chosen actions.
        dx[numpy.arange(action.shape[0]), action] \
            = numpy.clip(x[numpy.arange(action.shape[0]), action] - reward, -1, 1)
| |
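# Hedged sketch (not in the original file): the legacy NumpyOp interface is used by
# instantiating the class and calling the instance on the input symbols, e.g.
#
#   dqn_output = DQNOutputNpyOp()
#   net = dqn_output(data=qvalues, name='dqn')
#
# where `qvalues` is an illustrative symbol holding the predicted Q-values.
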
| |
| def dqn_sym_nips(action_num, data=None, name='dqn'): |
| """Structure of the Deep Q Network in the NIPS 2013 workshop paper: |
| Playing Atari with Deep Reinforcement Learning (https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf) |
| |
| Parameters |
| ---------- |
| action_num : int |
| data : mxnet.sym.Symbol, optional |
| name : str, optional |
| """ |
| if data is None: |
| net = mx.symbol.Variable('data') |
| else: |
| net = data |
| net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=16) |
| net = mx.symbol.Activation(data=net, name='relu1', act_type="relu") |
| net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=32) |
| net = mx.symbol.Activation(data=net, name='relu2', act_type="relu") |
| net = mx.symbol.Flatten(data=net) |
| net = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=256) |
| net = mx.symbol.Activation(data=net, name='relu3', act_type="relu") |
| net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=action_num) |
| net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput') |
| return net |
| |
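# Hedged usage sketch (not part of the original file): checking the shapes that flow
# through the NIPS-style network. The batch size 32 and the 4x84x84 input (4 stacked
# grayscale 84x84 frames) are assumptions taken from the paper's preprocessing.
#
#   sym = dqn_sym_nips(action_num=4)
#   arg_shapes, out_shapes, _ = sym.infer_shape(data=(32, 4, 84, 84))
#   print(dict(zip(sym.list_arguments(), arg_shapes)))
#   print(out_shapes)   # expected: [(32, 4)] -- one Q-value per action
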
| |
| def dqn_sym_nature(action_num, data=None, name='dqn'): |
| """Structure of the Deep Q Network in the Nature 2015 paper: |
| Human-level control through deep reinforcement learning |
| (http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html) |
| |
| Parameters |
| ---------- |
| action_num : int |
| data : mxnet.sym.Symbol, optional |
| name : str, optional |
| """ |
    if data is None:
        net = mx.symbol.Variable('data')
    else:
        net = data
| net = mx.symbol.Convolution(data=net, name='conv1', kernel=(8, 8), stride=(4, 4), num_filter=32) |
| net = mx.symbol.Activation(data=net, name='relu1', act_type="relu") |
| net = mx.symbol.Convolution(data=net, name='conv2', kernel=(4, 4), stride=(2, 2), num_filter=64) |
| net = mx.symbol.Activation(data=net, name='relu2', act_type="relu") |
| net = mx.symbol.Convolution(data=net, name='conv3', kernel=(3, 3), stride=(1, 1), num_filter=64) |
| net = mx.symbol.Activation(data=net, name='relu3', act_type="relu") |
| net = mx.symbol.Flatten(data=net) |
| net = mx.symbol.FullyConnected(data=net, name='fc4', num_hidden=512) |
| net = mx.symbol.Activation(data=net, name='relu4', act_type="relu") |
| net = mx.symbol.FullyConnected(data=net, name='fc5', num_hidden=action_num) |
| net = mx.symbol.Custom(data=net, name=name, op_type='DQNOutput') |
| return net |
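

# A minimal, hedged smoke test (not part of the original file); it only checks that the
# symbol builds and produces Q-values of the expected shape. The input shape
# (batch of 32, 4 stacked 84x84 frames) and the 6-action output are assumptions.
if __name__ == '__main__':
    sym = dqn_sym_nature(action_num=6)
    exe = sym.simple_bind(ctx=mx.cpu(), data=(32, 4, 84, 84), grad_req='null')
    # Fill every argument (input and weights) with small random values before the pass.
    for arr in exe.arg_dict.values():
        arr[:] = numpy.random.uniform(-0.1, 0.1, arr.shape)
    exe.forward(is_train=False)
    print(exe.outputs[0].shape)   # expected: (32, 6) -- one Q-value per action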