# blob: 2f3c4ec9ffd0527ee89c4efe6de4729ccf943740 [file] [log] [blame]
import os
import sys
import random
import shlex
import time
import re
from utils.utils import to_bool
from feat_readers.common import *
from feat_readers import stats
from feat_io import DataReadStream
class RegrDataReadStream(object):
    """Read paired input/output feature streams for regression.

    Wraps two ``DataReadStream`` objects: one built over
    ``dataset_args["input_lst_file"]`` and one over
    ``dataset_args["output_lst_file"]``. Every operation is forwarded to both
    streams in the same order. Methods that report a position (utterance id,
    block status, state) assert that the two streams agree.
    """

    def __init__(self, dataset_args, n_ins):
        """Create the input and output streams.

        Args:
            dataset_args: dict of ``DataReadStream`` options. Must contain
                ``"seed"``, ``"input_lst_file"`` and ``"output_lst_file"``.
                The dict is copied and is NOT modified.
            n_ins: forwarded unchanged to both underlying streams.
                NOTE(review): the output stream receives the same ``n_ins``
                as the input stream -- confirm this is intended when the
                output dimensionality differs.

        Raises:
            ValueError: if ``"seed"`` is missing from ``dataset_args``.
            KeyError: if either list-file key is missing.
        """
        # A seed is required, presumably so both streams order/shuffle their
        # data identically and input/output pairs stay aligned -- confirm
        # against DataReadStream.
        if "seed" not in dataset_args:
            raise ValueError('dataset_args must contain "seed"')

        # Build per-stream copies; only the copies are modified, so the
        # caller's dict is left untouched.
        input_args = dict(dataset_args)
        output_args = dict(dataset_args)
        for args in (input_args, output_args):
            # Regression targets come from the output stream, not labels.
            args["has_labels"] = False
        input_args["lst_file"] = dataset_args["input_lst_file"]
        output_args["lst_file"] = dataset_args["output_lst_file"]

        self.input = DataReadStream(input_args, n_ins)
        self.output = DataReadStream(output_args, n_ins)

    def read_by_part(self):
        """Call ``read_by_part`` on the input stream, then the output stream."""
        self.input.read_by_part()
        self.output.read_by_part()

    def read_by_matrix(self):
        """Call ``read_by_matrix`` on the input stream, then the output stream."""
        self.input.read_by_matrix()
        self.output.read_by_matrix()

    def make_shared(self):
        """Call ``make_shared`` on the input stream, then the output stream."""
        self.input.make_shared()
        self.output.make_shared()

    def get_shared(self):
        """Return ``(input_shared, output_shared)``.

        Each underlying ``get_shared()`` result is indexed as a pair. Its
        second element (presumably the label slot) must be ``None``, because
        both streams were built with ``has_labels=False``.
        """
        iret = self.input.get_shared()
        oret = self.output.get_shared()
        assert iret[1] is None
        assert oret[1] is None
        return iret[0], oret[0]

    def initialize_read(self):
        """Call ``initialize_read`` on the input stream, then the output stream."""
        self.input.initialize_read()
        self.output.initialize_read()

    def current_utt_id(self):
        """Return the current utterance id, asserting both streams agree."""
        a = self.input.current_utt_id()
        b = self.output.current_utt_id()
        assert a == b
        return a

    def load_next_block(self):
        """Advance both streams and return the input stream's result.

        Asserts that both streams returned the same value.
        """
        a = self.input.load_next_block()
        b = self.output.load_next_block()
        assert a == b
        return a

    def get_state(self):
        """Return the input stream's state after checking the output matches.

        Elements 0, 2, 3 and 4 are compared with ``==``. Element 1 is compared
        with ``numpy.array_equal`` (presumably an array). Because the two
        states must agree, the returned state is suitable for ``set_state``.
        """
        # numpy is not imported explicitly at module level; it may only
        # arrive via the star import from feat_readers.common. Import it here
        # so this method does not depend on that.
        import numpy

        a = self.input.get_state()
        b = self.output.get_state()
        assert a[0] == b[0]
        assert a[2] == b[2]
        assert a[3] == b[3]
        assert a[4] == b[4]
        assert numpy.array_equal(a[1], b[1])
        return a

    def set_state(self, state):
        """Apply the same ``state`` to both the input and output streams."""
        self.input.set_state(state)
        self.output.set_state(state)