| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| import pandas as pd |
| import numpy as np |
| import copy |
| |
| import pickle |
| import tempfile |
| import os |
| import json |
| |
| from bisect import bisect |
| |
class Oracle:
    """
    Oracle handles generating the true answers and evaluating/collecting the sketch's answers.
    By default, this assumes all query answers are real valued.

    The exact method for a problem should be implemented in a subclass by
    overriding ``add``, ``query`` and (optionally) ``eval_error``.

    Concrete subclasses are expected to define a class attribute ``name``;
    ``getID`` reads ``self.name`` and this base class does not provide one.
    """

    def __init__(self, workload=None, answer_file=None, read_cache=False, save_dir=None, as_json=False, **kwargs):
        """
        Currently, every oracle's init must have a kwargs argument.

        This uses kwargs in a less than ideal way to handle different Oracles having
        different signatures in the initialization. This init is called when loading an
        Oracle's results from the cache.

        Parameters:
            workload: object supplying the data and the queries (see ``prepare``).
            answer_file: path of a file holding previously computed answers, if any.
            read_cache: when true, ``prepare`` first tries to load ``answer_file``.
            save_dir: when ``None`` answers are written to a temporary file,
                otherwise to ``Answers_<id>.json`` (see ``getAnswerFile``).
            as_json: serialize answers as JSON instead of pickle.
            **kwargs: accepted and ignored.
        """
        self.workload = workload
        self.answers = []
        self.answer_file = answer_file
        self.read_cache = read_cache
        self._prepared = False
        self.save_dir = save_dir
        self.as_json = as_json

    def setWorkload(self, workload):
        """Attach a new workload and mark the oracle as needing ``prepare`` again."""
        self._prepared = False
        self.workload = workload

    def getID(self):
        """Return an identifier built from the oracle's ``name`` and the workload's ID."""
        return f"Oracle_{self.name}_{self.workload.getID()}"

    def getAnswer(self, qid):
        """Return the stored true answer for query number ``qid``."""
        return self.answers[qid]

    def eval_sketch_answer(self, qid, answer):
        """Return the error of a sketch's ``answer`` for query ``qid`` (delegates to ``eval_error``)."""
        return self.eval_error(qid, answer)

    #
    # These are the main functions that need to be implemented for each new problem
    #
    def eval_error(self, qid, answer):
        """
        By default, assume errors are real-valued and can be added.

        Returns ``answer - truth`` where ``truth`` is the stored answer for ``qid``.
        """
        truth = self.answers[qid]
        return answer - truth

    def add(self, x):
        """Consume one data item. Must be overridden by subclasses."""
        # NotImplementedError is still an Exception, so existing handlers keep working.
        raise NotImplementedError("Unimplemented")

    def query(self, query, parameters):
        """
        Compute the exact answer to a query. Must be overridden by subclasses.

        NOTE(review): ``prepare`` calls ``self.query(data_idx, query, parameters)``
        with three arguments, which is the signature the subclasses in this file
        implement; this base signature is kept as-is for compatibility.
        """
        raise NotImplementedError("Unimplemented")

    #
    # Functions to write/read oracle answers to disk
    #
    def getAnswerFile(self):
        """
        Choose the file that answers are written to / read from and remember it
        in ``self.answer_file``.

        Returns the chosen filename. With ``save_dir is None`` a fresh temporary
        file is created on every call.
        """
        prefix = self.getID()

        if self.save_dir is None:
            fd, filename = tempfile.mkstemp(prefix=prefix)
            # mkstemp hands back an open OS-level descriptor; the file is always
            # reopened by name later, so release the descriptor instead of leaking it.
            os.close(fd)
        else:
            # NOTE(review): save_dir only selects this branch; the returned path
            # does not include it, so the file lands in the current directory.
            filename = f"Answers_{prefix}.json"

        self.answer_file = filename

        return filename

    def prepareFromCached(self):
        """
        Try to load previously written answers from ``self.answer_file``.

        Returns True when a non-empty set of answers was loaded, False otherwise.
        Any failure to open or decode the file is treated as a cache miss.
        """
        if self.answer_file is None:
            self.answer_file = self.getAnswerFile()

        print("prep from cache oracle", self.answer_file)
        try:
            if self.as_json:
                with open(self.answer_file, "r") as file:
                    self.answers = json.load(file)
            else:
                # SECURITY: unpickling can execute arbitrary code; only point
                # answer_file at cache files this program wrote itself.
                with open(self.answer_file, "rb") as file:
                    self.answers = pickle.load(file)

            if len(self.answers) > 0:
                return True
        except Exception:
            # Deliberately best-effort: a missing or unreadable cache just means
            # the answers get recomputed by prepare().
            pass

        print(f"Cannot find {self.answer_file}")
        return False

    def writeToCache(self):
        """Serialize ``self.answers`` (JSON or pickle, per ``as_json``) to the answer file."""
        answer_file = self.getAnswerFile()
        self.answer_file = answer_file
        if self.as_json:
            with open(answer_file, "w") as file:
                json.dump(self.answers, file)
        else:
            with open(answer_file, "wb") as file:
                pickle.dump(self.answers, file=file)

        # The answer file itself is never removed here; cleanup is left to the caller.

    def printAnswers(self):
        """Print every stored answer next to the query that produced it."""
        print("answers:")
        for a, q in zip(self.answers, self.workload.genQueries()):
            print(q, ":", a)

    def prepare(self, **kwargs):
        """
        Iterate through the data and populate the pre-prepared answers.

        Does nothing if already prepared. When ``read_cache`` is set, cached
        answers are used if they can be loaded. Otherwise each data item is fed
        to ``add`` and every query whose ``data_idx`` equals the current item
        index is answered right after that item has been added. The results are
        printed and written to the cache.
        """
        if self._prepared:
            return

        if self.read_cache:
            self._prepared = self.prepareFromCached()
            if self._prepared:
                print("read from cache")
                return

        self.workload.prepare()

        print("reset oracle answers")
        self.answers = []
        query_iter = self.workload.genQueries()
        # Default to None so a workload without any queries does not raise StopIteration.
        q = next(query_iter, None)
        for i, x in enumerate(self.workload.genData()):
            self.add(x)
            # Several queries may be scheduled at the same data index.
            while q and i == q.data_idx:
                answer = self.query(q.data_idx, q.query, q.parameters)
                # Copy so later updates to the oracle's state cannot alter a stored answer.
                self.answers.append(copy.deepcopy(answer))
                assert len(self.answers) == q.qid + 1
                q = next(query_iter, None)

        self.printAnswers()

        self.writeToCache()  # note: I should not write to cache if not using parallel processes
        self._prepared = True

    def reset(self, workload):
        """Switch to a new workload; the oracle must be prepared again afterwards."""
        self.setWorkload(workload)

    def prepareForPickle(self):
        """
        This should remove any large objects.

        The base implementation only forwards the request to the workload.
        """
        self.workload.prepareForPickle()
| |
| ############################################################################################################## |
| |
| # simple distinct count testing when workload always consists of unique items |
class DistinctStreamOracle(Oracle):
    """
    Distinct-count oracle for workloads in which every stream item is unique.

    NOTE(review): ``query`` reports the data index it is given, not
    ``self.counter``; ``counter`` is maintained by ``add`` but never read in
    this class. Confirm which of the two is the intended ground truth.
    """
    name = 'DistinctStream'

    def __init__(self, workload, **kwargs):
        """Forward ``workload`` and ``kwargs`` to ``Oracle`` and start the item counter at zero."""
        super().__init__(workload, **kwargs)
        self.counter = 0

    def add(self, x):
        """Count one more stream item; the value of ``x`` itself is not used."""
        self.counter = self.counter + 1

    def query(self, idx, query, params):
        """Return ``idx`` unchanged as the answer; ``query`` and ``params`` are ignored."""
        return idx

    def eval_error(self, qid, answer):
        """
        Return the signed relative error of ``answer`` in percent,
        i.e. ``(answer - truth) / truth * 100`` for the stored truth of ``qid``.
        """
        expected = self.answers[qid]
        deviation = answer - expected
        return deviation / expected * 100.

    def getCached(self):
        """Return this oracle itself."""
        return self
| |
| |
class TopKOracle(Oracle):
    """
    Exact top-k oracle: keeps a full frequency table of the items seen so far
    and answers queries for the ``k`` most frequent ones.
    """
    name = "TopK"

    def __init__(self, workload=None, **kwargs):
        """Forward ``workload`` and ``kwargs`` to ``Oracle`` and start with an empty table."""
        super().__init__(workload, **kwargs)
        self.table = {}

    def add(self, x):
        """Increase the occurrence count of item ``x`` by one."""
        self.table[x] = self.table.get(x, 0) + 1

    # get all top k
    def query(self, idx, query, k):
        """
        Return the ``k`` most frequent items as a list of ``(item, count)``
        pairs, most frequent first.

        Pairs are ordered by ``(count, item)`` descending, so items with equal
        counts must be mutually comparable. ``idx`` and ``query`` are ignored.
        A non-positive ``k`` yields an empty list.
        """
        # Guard explicitly: a tail slice with k == 0 would select the whole list.
        if k <= 0:
            return []
        ranked = sorted(((w, x) for x, w in self.table.items()), reverse=True)
        return [(x, w) for w, x in ranked[:k]]

    def eval_error(self, qid, answer):
        """
        Returns the number of missed items in the result set.
        Note that the sketch's answer can include more than k items.

        Both the stored truth and ``answer`` are iterables of ``(item, weight)``
        pairs; only the items are compared.
        """
        truth = self.answers[qid]
        expected = {x for x, w in truth}
        reported = {x for x, w in answer}

        return len(expected - reported)

    def reset(self, *args, **kwargs):
        """
        Switch to a new workload and clear the frequency table.

        Arguments are forwarded to ``Oracle.reset``, so the workload may be
        given positionally or as ``workload=...``.
        """
        super().reset(*args, **kwargs)
        self.table = {}

    def prepareForPickle(self):
        """Drop the frequency table (in addition to the base-class cleanup) before pickling."""
        super().prepareForPickle()
        self.table = None
| |
class QuantileOracle(Oracle):
    """
    Exact quantile / CDF oracle: stores every data point and answers queries
    from the sorted data.
    """
    name = "Quantile"

    def __init__(self, workload=None, **kwargs):
        """Forward ``workload`` and ``kwargs`` to ``Oracle`` and start with no data."""
        super().__init__(workload, **kwargs)
        self.dat = []
        self.is_sorted = False

    def add(self, x):
        """Store data point ``x``; the data is re-sorted lazily on the next query."""
        self.dat.append(x)
        self.is_sorted = False

    def query(self, idx, query, q):
        """
        Answer a quantile or a CDF query on the data seen so far.

        This sorts and gets the quantile q.
        The quantile is defined to be the lower semicontinuous inverse CDF.
        That is, it does no interpolation and F^-1(y) = sup {x: F(x) <= y}
        where the sup is taken over data points.

        Parameters:
            idx: ignored.
            query: ``'quantile'`` to get the data point at rank ``int(q * n)``;
                any other value returns the fraction of data points ``<= q``.
            q: the quantile level (for ``'quantile'``) or the value whose CDF
                is requested.

        Raises:
            IndexError / ZeroDivisionError: if no data has been added yet.
        """
        if not self.is_sorted:
            self.dat.sort()
            self.is_sorted = True

        n = len(self.dat)
        if query == 'quantile':
            # Clamp the rank to the valid range: q == 1.0 would otherwise index
            # one past the end, and a negative q would wrap around to the tail.
            rank = min(max(int(q * n), 0), n - 1)
            return self.dat[rank]

        i = bisect(self.dat, q)
        return i / n

    def eval_error(self, qid, answer):
        """Return ``answer - truth`` for the stored truth of query ``qid``."""
        truth = self.answers[qid]
        return answer - truth

    def reset(self, *args, **kwargs):
        """
        Switch to a new workload and discard the stored data.

        Arguments are forwarded to ``Oracle.reset``, so the workload may be
        given positionally or as ``workload=...``.
        """
        super().reset(*args, **kwargs)
        self.dat = []
        self.is_sorted = False

    def prepareForPickle(self):
        """Drop the stored data (in addition to the base-class cleanup) before pickling."""
        super().prepareForPickle()
        self.dat = None
| |