blob: 9c5211081ef4c64d2bde471a98007934673b3648 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import queue
import threading
import requests
import time
import torch
from src.logger import logger
class StreamingDataLoader:
"""
This will preoritically query data from cache-service
"""
def __init__(self, cache_svc_url, table_name, name_space):
self.last_fetch_time = 0
self.table_name = table_name
# train, valid, test
self.name_space = name_space
self.end_signal = "end_position"
self.cache_svc_url = cache_svc_url
self.data_queue = queue.Queue(maxsize=10)
self.stop_event = threading.Event()
self.thread = threading.Thread(target=self.fetch_data, daemon=True)
self.thread.start()
def fetch_data(self):
while not self.stop_event.is_set():
response = requests.get(
f'{self.cache_svc_url}/',
params={
'table_name': self.table_name,
'name_space': self.name_space})
if response.status_code == 200:
batch = response.json()
# in trianing, we use iteraiton-per-epoch to control the end
if batch == self.end_signal:
if self.name_space == "valid":
# end_signal in inference, stop !
logger.info("[StreamingDataLoader]: last iteration in valid is meet!")
self.data_queue.put({self.end_signal: True})
else:
# end_signal in trianing, then keep training
continue
else:
# convert to tensor again
id_tensor = torch.LongTensor(batch['id'])
value_tensor = torch.FloatTensor(batch['value'])
y_tensor = torch.FloatTensor(batch['y'])
data_tensor = {'id': id_tensor, 'value': value_tensor, 'y': y_tensor}
self.data_queue.put(data_tensor)
else:
print(response.json())
time.sleep(5)
def __iter__(self):
return self
def __next__(self):
print("compute time = ", time.time() - self.last_fetch_time)
self.last_fetch_time = time.time()
if self.data_queue.empty() and not self.thread.is_alive():
raise StopIteration
else:
data = self.data_queue.get(block=True)
if self.end_signal in data:
raise StopIteration
else:
return data
def __len__(self):
return self.data_queue.qsize()
def stop(self):
self.stop_event.set()
self.thread.join()