blob: f4c667621f707b0c37804f7e4d2a7aab45e4bad7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Contains a class for handling multi-process data generation"""
from __future__ import print_function
from ctypes import c_bool
import multiprocessing as mp
try:
from queue import Full as QFullExcept
from queue import Empty as QEmptyExcept
except ImportError:
from Queue import Full as QFullExcept
from Queue import Empty as QEmptyExcept
class MPData(object):
"""Handles multi-process data generation.
Operation:
- call start() to start the data generation
- call get() (blocking) to read one sample
- call reset() to stop data generation
"""
def __init__(self, num_processes, max_queue_size, fn):
"""Parameters
----------
num_processes: int
Number of processes to spawn
max_queue_size: int
Maximum samples in the queue before processes wait
fn: function
function that generates samples, executed on separate processes.
"""
self.queue = mp.Queue(maxsize=int(max_queue_size))
self.alive = mp.Value(c_bool, False, lock=False)
self.num_proc = num_processes
self.proc = list()
self.fn = fn
def start(self):
"""Starts the processes"""
self._init_proc()
@staticmethod
def _proc_loop(proc_id, alive, queue, fn):
"""Thread loop for generating data
Parameters
----------
proc_id: int
Process id
alive: multiprocessing.Value
variable for signaling whether process should continue or not
queue: multiprocessing.Queue
queue for passing data back
fn: function
function object that returns a sample to be pushed into the queue
"""
print("proc {} started".format(proc_id))
try:
while alive.value:
data = fn()
put_success = False
while alive.value and not put_success:
try:
queue.put(data, timeout=0.5)
put_success = True
except QFullExcept:
# print("Queue Full")
pass
except KeyboardInterrupt:
print("W: interrupt received, stopping process {} ...".format(proc_id))
print("Closing process {}".format(proc_id))
queue.close()
def _init_proc(self):
"""Start processes if not already started"""
if not self.proc:
self.proc = [
mp.Process(target=self._proc_loop, args=(i, self.alive, self.queue, self.fn))
for i in range(self.num_proc)
]
self.alive.value = True
for p in self.proc:
p.start()
def get(self):
"""Get a datum from the queue
Returns
-------
np.ndarray
A captcha image, normalized to [0, 1]
"""
self._init_proc()
return self.queue.get()
def reset(self):
"""Resets the generator by stopping all processes"""
self.alive.value = False
qsize = 0
try:
while True:
self.queue.get(timeout=0.1)
qsize += 1
except QEmptyExcept:
pass
print("Queue size on reset: {}".format(qsize))
for i, p in enumerate(self.proc):
p.join()
self.proc.clear()