"""Train a DDPG agent on rllab's normalized Box2D cart-pole environment.

Running this module builds the environment, the deterministic policy, the
Q-function and the exploration strategy, hands them to ``DDPG`` and starts
training immediately (there is no ``__main__`` guard, matching the original
script's behavior).
"""

import mxnet as mx
from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize

from ddpg import DDPG
from policies import DeterministicMLPPolicy
from qfuncs import ContinuousMLPQ
from strategies import OUStrategy
from utils import SEED

# Set environment, policy, qfunc, strategy.
# The policy, Q-function and strategy are each built from the environment's
# spec, so they all agree with the wrapped (normalized) environment.
env = normalize(CartpoleEnv())

policy = DeterministicMLPPolicy(env.spec)
qfunc = ContinuousMLPQ(env.spec)
# NOTE(review): "OU" presumably stands for Ornstein-Uhlenbeck exploration
# noise -- confirm against the `strategies` module.
strategy = OUStrategy(env.spec)

# Set the training algorithm and train.
algo = DDPG(
    env=env,
    policy=policy,
    qfunc=qfunc,
    strategy=strategy,
    # Runs on GPU device 0; this needs a GPU-enabled MXNet build.
    # Use mx.cpu() here when no GPU is available.
    ctx=mx.gpu(0),
    max_path_length=100,
    epoch_length=1000,
    memory_start_size=10000,
    n_epochs=1000,
    discount=0.99,
    # The Q-function learning rate is 10x the policy learning rate.
    qfunc_lr=1e-3,
    policy_lr=1e-4,
    seed=SEED)

algo.train()