| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import argparse
import logging
import os
import sys
import time

import mxnet as mx
import mxnet.ndarray as nd
import numpy

from base import Base
from operators import *
from atari_game import AtariGame
from utils import *
| |
| root = logging.getLogger() |
| root.setLevel(logging.DEBUG) |
| |
| ch = logging.StreamHandler(sys.stdout) |
| ch.setLevel(logging.DEBUG) |
| formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
| ch.setFormatter(formatter) |
| root.addHandler(ch) |
| mx.random.seed(100) |
| npy_rng = get_numpy_rng() |
| |
| |
| class DQNInitializer(mx.initializer.Xavier): |
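    # Xavier-initialize the weights; biases are set to a small positive
    # constant (0.1) instead of zero.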
| def _init_bias(self, _, arr): |
| arr[:] = .1 |
| |
| |
| def main(): |
| parser = argparse.ArgumentParser(description='Script to test the trained network on a game.') |
| parser.add_argument('-r', '--rom', required=False, type=str, |
| default=os.path.join('roms', 'breakout.bin'), |
| help='Path of the ROM File.') |
| parser.add_argument('-v', '--visualization', required=False, type=int, default=0, |
| help='Visualize the runs.') |
| parser.add_argument('--lr', required=False, type=float, default=0.01, |
| help='Learning rate of the AdaGrad optimizer') |
| parser.add_argument('--eps', required=False, type=float, default=0.01, |
| help='Eps of the AdaGrad optimizer') |
| parser.add_argument('--clip-gradient', required=False, type=float, default=None, |
| help='Clip threshold of the AdaGrad optimizer') |
| parser.add_argument('--double-q', action='store_true', |
                        help='Use Double DQN when specified')
| parser.add_argument('--wd', required=False, type=float, default=0.0, |
| help='Weight of the L2 Regularizer') |
| parser.add_argument('-c', '--ctx', required=False, type=str, default='gpu', |
                        help='Running context, e.g. `-c gpu`, `-c gpu1` or `-c cpu`')
| parser.add_argument('-d', '--dir-path', required=False, type=str, default='', |
| help='Saving directory of model files.') |
| parser.add_argument('--start-eps', required=False, type=float, default=1.0, |
| help='Eps of the epsilon-greedy policy at the beginning') |
| parser.add_argument('--replay-start-size', required=False, type=int, default=50000, |
                        help='Number of steps to collect in the replay memory before training starts')
| parser.add_argument('--kvstore-update-period', required=False, type=int, default=1, |
                        help='The period at which the worker updates the parameters from the server')
| parser.add_argument('--kv-type', required=False, type=str, default=None, |
                        help='Type of kvstore; by default no kvstore is used. Can also be dist_async')
| parser.add_argument('--optimizer', required=False, type=str, default="adagrad", |
| help='type of optimizer') |
| args = parser.parse_args() |
| |
| if args.dir_path == '': |
| rom_name = os.path.splitext(os.path.basename(args.rom))[0] |
| args.dir_path = 'dqn-%s-lr%g' % (rom_name, args.lr) |
| replay_start_size = args.replay_start_size |
| max_start_nullops = 30 |
| replay_memory_size = 1000000 |
| history_length = 4 |
| rows = 84 |
| cols = 84 |
| |
| ctx = parse_ctx(args.ctx) |
| q_ctx = mx.Context(*ctx[0]) |
| |
| game = AtariGame(rom_path=args.rom, resize_mode='scale', replay_start_size=replay_start_size, |
| resized_rows=rows, resized_cols=cols, max_null_op=max_start_nullops, |
| replay_memory_size=replay_memory_size, display_screen=args.visualization, |
| history_length=history_length) |
| |
    # Hyperparameters following the Nature DQN setup
| freeze_interval = 10000 |
| epoch_num = 200 |
| steps_per_epoch = 250000 |
| update_interval = 4 |
| discount = 0.99 |
| |
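    # Epsilon-greedy exploration schedule: epsilon is annealed linearly from
    # `start_eps` down to 0.1 over the first one million action selections.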
| eps_start = args.start_eps |
| eps_min = 0.1 |
| eps_decay = (eps_start - eps_min) / 1000000 |
| eps_curr = eps_start |
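    # `freeze_interval` is specified in environment steps; since the Q-network
    # is only updated every `update_interval` steps, convert it to a number of
    # training updates before comparing against `training_steps` below.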
    freeze_interval //= update_interval
| minibatch_size = 32 |
| action_num = len(game.action_set) |
| |
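    # The network input is a stack of `history_length` resized (84x84) frames;
    # `dqn_action` holds the sampled actions and `dqn_reward` the bootstrapped
    # target values used by the loss.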
| data_shapes = {'data': (minibatch_size, history_length) + (rows, cols), |
| 'dqn_action': (minibatch_size,), 'dqn_reward': (minibatch_size,)} |
| dqn_sym = dqn_sym_nature(action_num) |
| qnet = Base(data_shapes=data_shapes, sym_gen=dqn_sym, name='QNet', |
| initializer=DQNInitializer(factor_type="in"), |
| ctx=q_ctx) |
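    # The target network is a periodically synchronized copy of the Q-network;
    # keeping it fixed between syncs stabilizes the bootstrapped targets.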
| target_qnet = qnet.copy(name="TargetQNet", ctx=q_ctx) |
| |
| use_easgd = False |
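    # Create the optimizer (AdaGrad by default) and the updater that applies
    # the computed gradients to the Q-network parameters.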
| optimizer = mx.optimizer.create(name=args.optimizer, learning_rate=args.lr, eps=args.eps, |
| clip_gradient=args.clip_gradient, |
| rescale_grad=1.0, wd=args.wd) |
| updater = mx.optimizer.get_updater(optimizer) |
| |
| qnet.print_stat() |
| target_qnet.print_stat() |
| |
| # Begin Playing Game |
| training_steps = 0 |
| total_steps = 0 |
| for epoch in range(epoch_num): |
| # Run Epoch |
| steps_left = steps_per_epoch |
| episode = 0 |
| epoch_reward = 0 |
| start = time.time() |
| game.start() |
| while steps_left > 0: |
| # Running New Episode |
| episode += 1 |
| episode_loss = 0.0 |
| episode_q_value = 0.0 |
| episode_update_step = 0 |
| episode_action_step = 0 |
| time_episode_start = time.time() |
| game.begin_episode(steps_left) |
| while not game.episode_terminate: |
| # 1. We need to choose a new action based on the current game status |
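                # Epsilon-greedy: with probability `eps_curr` take a random
                # action, otherwise act greedily w.r.t. the current Q-network.
                # Before the replay memory is warmed up, actions are always random.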
| if game.state_enabled and game.replay_memory.sample_enabled: |
| do_exploration = (npy_rng.rand() < eps_curr) |
| eps_curr = max(eps_curr - eps_decay, eps_min) |
| if do_exploration: |
| action = npy_rng.randint(action_num) |
| else: |
                        # TODO Here we could in fact play multiple game instances simultaneously
                        # and choose actions for all of them by stacking their current_state()
                        # arrays and predicting in a single batch. Currently we block on the
                        # forward call, which makes the program slow.
                        # TODO Profile the speed of this part!
| current_state = game.current_state() |
                        state = nd.array(current_state.reshape((1,) + current_state.shape),
                                         ctx=q_ctx) / 255.0
| qval_npy = qnet.forward(is_train=False, data=state)[0].asnumpy() |
| action = numpy.argmax(qval_npy) |
| episode_q_value += qval_npy[0, action] |
| episode_action_step += 1 |
| else: |
| action = npy_rng.randint(action_num) |
| |
                # 2. Play the game for a single mega-step (inside the game, the action may be repeated several times)
| game.play(action) |
| total_steps += 1 |
| |
| # 3. Update our Q network if we can start sampling from the replay memory |
| # Also, we update every `update_interval` |
| if total_steps % update_interval == 0 and game.replay_memory.sample_enabled: |
| # 3.1 Draw sample from the replay_memory |
| training_steps += 1 |
| episode_update_step += 1 |
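                    # Experience replay: draw a random minibatch of
                    # (state, action, reward, next_state, terminal) transitions
                    # so consecutive updates are decorrelated.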
| states, actions, rewards, next_states, terminate_flags \ |
| = game.replay_memory.sample(batch_size=minibatch_size) |
                    states = nd.array(states, ctx=q_ctx) / 255.0
                    next_states = nd.array(next_states, ctx=q_ctx) / 255.0
| actions = nd.array(actions, ctx=q_ctx) |
| rewards = nd.array(rewards, ctx=q_ctx) |
| terminate_flags = nd.array(terminate_flags, ctx=q_ctx) |
| |
| # 3.2 Use the target network to compute the scores and |
| # get the corresponding target rewards |
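                    # Standard DQN target: r + gamma * max_a Q_target(s', a).
                    # Double DQN target: r + gamma * Q_target(s', argmax_a Q(s', a)),
                    # i.e. the online network selects the action and the target
                    # network evaluates it, which reduces overestimation bias.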
| if not args.double_q: |
| target_qval = target_qnet.forward(is_train=False, data=next_states)[0] |
| target_rewards = rewards + nd.choose_element_0index(target_qval, |
| nd.argmax_channel(target_qval))\ |
| * (1.0 - terminate_flags) * discount |
| else: |
| target_qval = target_qnet.forward(is_train=False, data=next_states)[0] |
| qval = qnet.forward(is_train=False, data=next_states)[0] |
| |
| target_rewards = rewards + nd.choose_element_0index(target_qval, |
| nd.argmax_channel(qval))\ |
| * (1.0 - terminate_flags) * discount |
| outputs = qnet.forward(is_train=True, |
| data=states, |
| dqn_action=actions, |
| dqn_reward=target_rewards) |
| qnet.backward() |
| qnet.update(updater=updater) |
| |
                    # 3.3 Calculate the clipped (Huber) loss for logging
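                    # This is the Huber loss with delta = 1: quadratic for
                    # |error| <= 1 and linear beyond. It is computed only for
                    # the episode statistics; the gradient itself comes from
                    # qnet.backward() above.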
| diff = nd.abs(nd.choose_element_0index(outputs[0], actions) - target_rewards) |
| quadratic_part = nd.clip(diff, -1, 1) |
| loss = 0.5 * nd.sum(nd.square(quadratic_part)).asnumpy()[0] +\ |
| nd.sum(diff - quadratic_part).asnumpy()[0] |
| episode_loss += loss |
| |
                    # 3.4 Update the target network every `freeze_interval` training steps
| if training_steps % freeze_interval == 0: |
| qnet.copy_params_to(target_qnet) |
| steps_left -= game.episode_step |
| time_episode_end = time.time() |
| # Update the statistics |
| epoch_reward += game.episode_reward |
| info_str = "Epoch:%d, Episode:%d, Steps Left:%d/%d, Reward:%f, fps:%f, Exploration:%f" \ |
| % (epoch, episode, steps_left, steps_per_epoch, game.episode_reward, |
| game.episode_step / (time_episode_end - time_episode_start), eps_curr) |
| if episode_update_step > 0: |
| info_str += ", Avg Loss:%f/%d" % (episode_loss / episode_update_step, |
| episode_update_step) |
| if episode_action_step > 0: |
| info_str += ", Avg Q Value:%f/%d" % (episode_q_value / episode_action_step, |
| episode_action_step) |
| if episode % 100 == 0: |
| logging.info(info_str) |
| end = time.time() |
| fps = steps_per_epoch / (end - start) |
| qnet.save_params(dir_path=args.dir_path, epoch=epoch) |
| logging.info("Epoch:%d, FPS:%f, Avg Reward: %f/%d" |
| % (epoch, fps, epoch_reward / float(episode), episode)) |
| |
| if __name__ == '__main__': |
| main() |