# blob: 47e9bc843ac07c884eb68c4ece62e65367414473 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import numpy as np
class ReplayMem(object):
    """Fixed-capacity circular buffer of (obs, act, rwd, end) samples.

    Samples are written into preallocated NumPy arrays.  Once
    ``memory_size`` samples have been stored, each new sample overwrites
    the oldest one.  ``get_batch`` draws random stored samples together
    with the observation that was stored immediately after each of them.
    """

    def __init__(
            self,
            obs_dim,
            act_dim,
            memory_size=1000000):
        """Allocate zero-filled storage for ``memory_size`` samples.

        Args:
            obs_dim: Length of one observation vector.
            act_dim: Length of one action vector.
            memory_size: Maximum number of samples kept at any time.
        """
        # allocate space for memory cells
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.obss = np.zeros((memory_size, obs_dim))
        self.acts = np.zeros((memory_size, act_dim))
        self.rwds = np.zeros((memory_size, ))
        self.ends = np.zeros(memory_size, dtype='uint8')
        self.memory_size = memory_size
        # Slot of the most recently written sample; -1 means nothing has
        # been written yet.
        self.top = -1
        # Number of slots that currently hold a sample.
        self.size = 0

    def add_sample(self, obs, act, rwd, end):
        """Store one sample, overwriting the oldest one when full.

        Args:
            obs: Observation, assignable to a row of ``self.obss``.
            act: Action, assignable to a row of ``self.acts``.
            rwd: Scalar reward.
            end: Flag stored in ``self.ends`` (uint8) for this sample.
        """
        # Advance the write cursor, wrapping at the end of the storage.
        self.top = (self.top + 1) % self.memory_size
        self.obss[self.top] = obs
        self.acts[self.top] = act
        self.rwds[self.top] = rwd
        self.ends[self.top] = end
        if self.size < self.memory_size:
            self.size += 1

    def get_batch(self, batch_size):
        """Sample ``batch_size`` stored samples, with replacement.

        Every stored slot except the most recently written one
        (``self.top``) is drawn with equal probability; the newest sample
        is excluded because no later observation has been stored for it.
        Samples whose ``end`` flag is set are not filtered out, so the
        caller must use the returned ``ends`` to decide whether the paired
        next observation is meaningful.

        Args:
            batch_size: Number of samples to draw.

        Returns:
            Tuple ``(obss, acts, rwds, ends, next_obss)`` of arrays whose
            first dimension is ``batch_size``.  ``next_obss[i]`` is the
            observation stored right after ``obss[i]``.

        Raises:
            AssertionError: If fewer than ``batch_size`` samples are stored.
            ValueError: If ``batch_size`` is positive but fewer than two
                samples are stored, so that no sample has a successor, or
                if ``batch_size`` is negative.
        """
        assert self.size >= batch_size, \
            "requested batch_size exceeds the number of stored samples"
        if batch_size > 0:
            if self.size < 2:
                # The only stored slot is ``top``; a rejection loop over it
                # would never terminate, so fail fast instead.
                raise ValueError(
                    "at least two stored samples are required to sample a "
                    "transition, got %d" % self.size)
            # Draw uniformly from the ``size - 1`` valid slots, then shift
            # every draw at or above ``top`` up by one so that ``top``
            # itself is never selected.
            draws = np.random.randint(0, self.size - 1, size=batch_size)
            indices = draws + (draws >= self.top)
        else:
            # Empty batch; np.zeros raises ValueError for a negative size.
            indices = np.zeros(batch_size, dtype="int64")
        # The successor lives in the next slot, wrapping around the buffer.
        transit_indices = (indices + 1) % self.memory_size
        return (self.obss[indices],
                self.acts[indices],
                self.rwds[indices],
                self.ends[indices],
                self.obss[transit_indices])
if __name__ == "__main__":
    # Manual smoke run: store eight identical non-terminal samples followed
    # by one terminal sample, dump the raw storage, then draw a batch.
    memory = ReplayMem(2, 1, memory_size=10)
    for _ in range(8):
        memory.add_sample(np.array([2, 2]), np.array([2]), 10, 0)
    memory.add_sample(np.array([1, 1]), np.array([1]), 100, 1)
    for stored in (memory.obss, memory.acts, memory.rwds, memory.ends):
        print(stored)
    print(memory.get_batch(5))