# blob: 982cc50f9781313fc6e193eaf8c5a498a405af6a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
class DataInput:
    """Single-pass iterator yielding zero-padded mini-batches of ``data``.

    Every record ``t`` of ``data`` is read positionally: ``t[0]``, ``t[2]``
    and ``t[3]`` are collected as-is, and ``t[1]`` is a sequence that is
    copied into one row of a zero-padded ``int64`` matrix.
    NOTE(review): presumably ``(user, item history, target item, label)`` --
    confirm against the code that builds ``data``.

    The cursor ``self.i`` is never reset, so an instance can be iterated
    only once; create a new instance for every pass over the data.
    """

    def __init__(self, data, batch_size):
        """Store the data set and compute the number of batches.

        Args:
            data: Sized, sliceable sequence of records (see class docstring).
            batch_size: Maximum number of records per batch; must be > 0.

        Raises:
            ValueError: If ``batch_size`` is not strictly positive. Without
                this check 0 fails with a bare ``ZeroDivisionError`` and a
                negative value yields an iterator that never terminates
                cleanly.
        """
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be a positive integer, got %r" % (batch_size,))
        self.batch_size = batch_size
        self.data = data
        # Ceiling division: the last batch may hold fewer records.
        self.epoch_size = (len(self.data) + batch_size - 1) // batch_size
        self.i = 0

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def next(self):
        """Python 2 style alias of :meth:`__next__`."""
        return self.__next__()

    def __next__(self):
        """Return the next batch.

        Returns:
            ``(batch_number, (u, i, y, hist_i, sl))`` where ``batch_number``
            is 1-based, ``u``/``i``/``y`` are lists of ``t[0]``/``t[2]``/
            ``t[3]``, ``sl`` is the list of ``len(t[1])`` and ``hist_i`` is
            an ``int64`` array of shape ``(len(batch), max(sl))`` holding
            every ``t[1]`` left-aligned and padded with zeros.

        Raises:
            StopIteration: Once ``epoch_size`` batches have been produced.
        """
        if self.i == self.epoch_size:
            raise StopIteration

        start = self.i * self.batch_size
        stop = min(start + self.batch_size, len(self.data))
        ts = self.data[start:stop]
        self.i += 1

        u = [t[0] for t in ts]
        i = [t[2] for t in ts]
        y = [t[3] for t in ts]
        sl = [len(t[1]) for t in ts]

        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, (t, length) in enumerate(zip(ts, sl)):
            # One slice assignment per record; empty histories stay all-zero.
            if length:
                hist_i[row, :length] = t[1]

        return self.i, (u, i, y, hist_i, sl)
class DataInputTest:
    """Single-pass iterator yielding zero-padded evaluation mini-batches.

    Every record ``t`` of ``data`` is read positionally: ``t[0]`` is
    collected as-is, ``t[1]`` is a sequence copied into one row of a
    zero-padded ``int64`` matrix, and ``t[2]`` is a pair whose two elements
    are collected into separate lists.
    NOTE(review): presumably ``(user, item history, (positive item,
    negative item))`` -- confirm against the code that builds ``data``.

    The cursor ``self.i`` is never reset, so an instance can be iterated
    only once; create a new instance for every pass over the data.
    """

    def __init__(self, data, batch_size):
        """Store the data set and compute the number of batches.

        Args:
            data: Sized, sliceable sequence of records (see class docstring).
            batch_size: Maximum number of records per batch; must be > 0.

        Raises:
            ValueError: If ``batch_size`` is not strictly positive. Without
                this check 0 fails with a bare ``ZeroDivisionError`` and a
                negative value yields an iterator that never terminates
                cleanly.
        """
        if batch_size <= 0:
            raise ValueError(
                "batch_size must be a positive integer, got %r" % (batch_size,))
        self.batch_size = batch_size
        self.data = data
        # Ceiling division: the last batch may hold fewer records.
        self.epoch_size = (len(self.data) + batch_size - 1) // batch_size
        self.i = 0

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def next(self):
        """Python 2 style alias of :meth:`__next__`."""
        return self.__next__()

    def __next__(self):
        """Return the next batch.

        Returns:
            ``(batch_number, (u, i, j, hist_i, sl))`` where ``batch_number``
            is 1-based, ``u`` is the list of ``t[0]``, ``i`` and ``j`` are
            the lists of ``t[2][0]`` and ``t[2][1]``, ``sl`` is the list of
            ``len(t[1])`` and ``hist_i`` is an ``int64`` array of shape
            ``(len(batch), max(sl))`` holding every ``t[1]`` left-aligned
            and padded with zeros.

        Raises:
            StopIteration: Once ``epoch_size`` batches have been produced.
        """
        if self.i == self.epoch_size:
            raise StopIteration

        start = self.i * self.batch_size
        stop = min(start + self.batch_size, len(self.data))
        ts = self.data[start:stop]
        self.i += 1

        u = [t[0] for t in ts]
        i = [t[2][0] for t in ts]
        j = [t[2][1] for t in ts]
        sl = [len(t[1]) for t in ts]

        hist_i = np.zeros([len(ts), max(sl)], np.int64)
        for row, (t, length) in enumerate(zip(ts, sl)):
            # One slice assignment per record; empty histories stay all-zero.
            if length:
                hist_i[row, :length] = t[1]

        return self.i, (u, i, j, hist_i, sl)