blob: d95953f40451a70dc0e559544504be1024188414 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package datafu.test.pig.hash.lsh;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.adrianwalker.multilinestring.Multiline;
import org.apache.commons.math.linear.RealVector;
import org.apache.commons.math.random.JDKRandomGenerator;
import org.apache.commons.math.random.RandomData;
import org.apache.commons.math.random.RandomDataImpl;
import org.apache.commons.math.random.RandomGenerator;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.apache.pig.PigException;
import org.apache.pig.backend.executionengine.ExecException;
import org.apache.pig.data.DataBag;
import org.apache.pig.data.Tuple;
import org.apache.pig.pigunit.PigTest;
import org.apache.pig.tools.parameters.ParseException;
import org.testng.Assert;
import org.testng.annotations.Test;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import datafu.pig.hash.lsh.metric.Cosine;
import datafu.pig.hash.lsh.metric.L1;
import datafu.pig.hash.lsh.metric.L2;
import datafu.pig.hash.lsh.util.DataTypeUtil;
import datafu.test.pig.PigTests;
public class LSHPigTest extends PigTests
{
/**
 * Sets the JVM-wide system properties {@code mapred.map.child.java.opts},
 * {@code mapred.reduce.child.java.opts} and {@code io.sort.mb} before a test
 * runs its Pig script.
 */
private static void setMemorySettings()
{
  final String childJavaOpts = "-Xmx1G";
  System.setProperty("mapred.map.child.java.opts", childJavaOpts);
  System.setProperty("mapred.reduce.child.java.opts", childJavaOpts);
  System.setProperty("io.sort.mb", "10");
}
// Pig script used by testSparseVectors(): loads 'input' as PTS, a single bag of
// (idx:int, val:double) tuples per row, and stores PTS to 'output'.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
* PTS = LOAD 'input' AS (b:bag{t:tuple(idx:int, val:double)});
* STORE PTS INTO 'output';
*/
@Multiline private String sparseVectorTest;
/**
 * Writes generated vectors to 'input' in the sparse bag encoding produced by
 * {@code getSparseLines}, runs the {@code sparseVectorTest} script, and checks
 * that every tuple read back from alias {@code PTS} converts to a vector whose
 * entries match the vector that was written (within 1e-5).
 *
 * Rows are compared to the generated vectors by position, so the check relies
 * on Pig returning the rows in the order they were written.
 *
 * @throws IOException propagated from the file, Pig and conversion helpers
 * @throws ParseException propagated from the Pig test helpers
 */
@Test
public void testSparseVectors() throws IOException, ParseException
{
  final int n = 20;
  final int dimension = 3;
  final double tolerance = 1e-5;

  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);

  PigTest test = createPigTestFromString(sparseVectorTest);
  writeLinesToFile("input", getSparseLines(vectors));
  test.runScript();

  List<Tuple> neighbors = this.getLinesForAlias(test, "PTS");
  Assert.assertEquals(neighbors.size(), n);

  int idx = 0;
  for (Tuple t : neighbors)
  {
    // Check the arity first so a malformed row fails as an assertion rather
    // than as an exception thrown by t.get(0).
    Assert.assertEquals(t.size(), 1, "row " + idx + " should hold exactly one field");
    Assert.assertTrue(t.get(0) instanceof DataBag, "row " + idx + " should hold a bag");

    RealVector interpreted = DataTypeUtil.INSTANCE.convert(t, dimension);
    RealVector original = vectors.get(idx);

    // TestNG's assertEquals takes (actual, expected).
    Assert.assertEquals(interpreted.getDimension(), original.getDimension());
    for (int i = 0; i < interpreted.getDimension(); ++i)
    {
      double originalField = original.getEntry(i);
      double interpretedField = interpreted.getEntry(i);
      Assert.assertTrue(Math.abs(originalField - interpretedField) < tolerance,
                        "row " + idx + ", entry " + i + ": wrote " + originalField
                        + " but read back " + interpretedField);
    }
    idx++;
  }
}
// Pig script used by testRandomSeed(): loads three double columns from 'input',
// applies the LSH UDF to each point, and stores the flattened result as PTS_HASHED.
// The LSH definition here takes four constructor arguments, one fewer than the
// other scripts in this class, which pass an extra trailing '0'.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
define LSH datafu.pig.hash.lsh.L1PStableHash('3', '150', '1', '5');
define METRIC datafu.pig.hash.lsh.metric.L1('3');
PTS = LOAD 'input' AS (dim1:double, dim2:double, dim3:double);
PTS_HASHED = foreach PTS generate TOTUPLE(dim1, dim2, dim3) as pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)));
store PTS_HASHED INTO 'lsh_pts';
*/
@Multiline private String randomSeedTest;
/**
 * Runs the {@code randomSeedTest} script twice over the same input and checks
 * that more than 80% of the hash values differ between the two runs.
 *
 * Hashes from the first run are indexed by the string form of the point
 * followed by the string form of field 1 of the output tuple; each tuple of the
 * second run is looked up under the same key. A key that is missing from the
 * first run counts as a difference.
 *
 * @throws Exception propagated from the Pig and file helpers
 */
@Test
public void testRandomSeed() throws Exception
{
  setMemorySettings();
  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  int n = 10;
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);
  writeLinesToFile("input", getLines(vectors));

  // First run: remember the hash value produced for each key.
  Map<String, Long> hashes = new HashMap<String, Long>();
  List<Tuple> firstRun = runRandomSeedScript();
  int numHashes = firstRun.size();
  for (Tuple t : firstRun)
  {
    hashes.put(randomSeedKey(t), (Long) t.get(2));
  }
  // Without this guard an empty result makes the ratio below NaN and the final
  // assertion fails without saying why.
  Assert.assertTrue(numHashes > 0, "the first run produced no hashes");

  // Second run: count how many hash values changed.
  List<Tuple> secondRun = runRandomSeedScript();
  // TestNG's assertEquals takes (actual, expected).
  Assert.assertEquals(secondRun.size(), numHashes);
  int numDiff = 0;
  for (Tuple t : secondRun)
  {
    Long value = (Long) t.get(2);
    Long refValue = hashes.get(randomSeedKey(t));
    if (!value.equals(refValue))
    {
      numDiff++;
    }
  }

  // Assert that more than 80% of the hashes are different between runs due to
  // different seeds. This ensures that a random seed is actually being used.
  double fractionDifferent = 1.0 * numDiff / numHashes;
  System.out.println(fractionDifferent);
  Assert.assertTrue(fractionDifferent > .8,
                    "only " + numDiff + " of " + numHashes + " hashes differed between runs");
}

/** Runs the randomSeedTest script once and returns the tuples of alias PTS_HASHED. */
private List<Tuple> runRandomSeedScript() throws Exception
{
  PigTest test = createPigTestFromString(randomSeedTest);
  test.runScript();
  return this.getLinesForAlias(test, "PTS_HASHED");
}

/** Builds the lookup key for a PTS_HASHED tuple from its fields 0 and 1. */
private static String randomSeedKey(Tuple t) throws ExecException
{
  return ((Tuple) t.get(0)).toString() + t.get(1).toString();
}
// Pig script used by testL1UDFSparse(): hashes the sparse points in 'input' and the
// sparse points in 'queries' with the LSH UDF, joins queries to the point partitions
// on (lsh_id, hash), applies METRIC with 1000 as its second argument, drops rows
// with no matching points, and emits one NEIGHBOR_CNT row per query point holding
// the query point, COUNT(NOT_NULL) and the distinct matching_pts values.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
define LSH datafu.pig.hash.lsh.L1PStableHash('3', '150', '1', '5', '0');
define METRIC datafu.pig.hash.lsh.metric.L1('3');
PTS = LOAD 'input' AS (pt:bag{t:tuple(idx:int, val:double)});
PTS_HASHED = foreach PTS generate pt as pt
, FLATTEN(LSH(pt));
PARTITIONS = group PTS_HASHED by (lsh_id, hash);
QUERIES = LOAD 'queries' as (pt:bag{t:tuple(idx:int, val:double)});
QUERIES_HASHED = foreach QUERIES generate pt as query_pt
, FLATTEN(LSH(pt))
;
QUERIES_W_PARTS = join QUERIES_HASHED by (lsh_id, hash), PARTITIONS by (group.$0, group.$1);
NEAR_NEIGHBORS = foreach QUERIES_W_PARTS generate query_pt as query_pt
, METRIC(query_pt, 1000, PTS_HASHED) as neighbor
;
describe NEAR_NEIGHBORS;
NEIGHBORS_PROJ = foreach NEAR_NEIGHBORS {
generate TOTUPLE(query_pt) as query_pt, neighbor.pt as matching_pts;
};
describe NEIGHBORS_PROJ;
NOT_NULL = filter NEIGHBORS_PROJ by SIZE(matching_pts) > 0;
NEIGHBORS_GRP = group NOT_NULL by query_pt;
describe NEIGHBORS_GRP;
NEIGHBOR_CNT = foreach NEIGHBORS_GRP{
MATCHING_PTS = foreach NOT_NULL generate matching_pts;
DIST_MATCHING_PTS = DISTINCT MATCHING_PTS;
generate group as query_pt, COUNT(NOT_NULL), DIST_MATCHING_PTS;
};
STORE NEIGHBOR_CNT INTO 'neighbors';
*/
@Multiline private String l1SparseTest;
/**
 * Runs the {@code l1SparseTest} script over sparse-encoded points and queries
 * and checks the {@code NEIGHBOR_CNT} alias: one row per query, a count of at
 * least 3 in every row, and every returned neighbor within L1 distance 1000 of
 * its query point.
 *
 * @throws Exception propagated from the Pig and file helpers
 */
@Test
public void testL1UDFSparse() throws Exception
{
  final long minCount = 3;
  // Same value the Pig script passes as the second argument to METRIC.
  final double threshold = 1000;

  setMemorySettings();
  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  int n = 1000;
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);
  PigTest test = createPigTestFromString(l1SparseTest);
  writeLinesToFile("input", getSparseLines(vectors));
  List<RealVector> queries = LSHTest.getVectors(rd, 1000, 10);
  writeLinesToFile("queries", getSparseLines(queries));
  test.runScript();

  List<Tuple> neighbors = this.getLinesForAlias(test, "NEIGHBOR_CNT");
  // TestNG's assertEquals takes (actual, expected).
  Assert.assertEquals(neighbors.size(), queries.size());
  for (long cnt : getCounts(neighbors))
  {
    Assert.assertTrue(cnt >= minCount,
                      "count for a query was " + cnt + ", expected at least " + minCount);
  }

  Distance d = new Distance()
  {
    @Override
    public double distance(RealVector v1, RealVector v2)
    {
      return L1.distance(v1, v2);
    }
  };
  verifyPoints(neighbors, d, threshold);
}
// Pig script used by testL1UDF(): hashes the three-column points in 'input' and in
// 'queries' with the LSH UDF, joins queries to the point partitions on
// (lsh_id, hash), applies METRIC with 1000 as its second argument, drops rows with
// no matching points, and emits one NEIGHBOR_CNT row per query point holding the
// query point, COUNT(NOT_NULL) and the distinct flattened matching points.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
define LSH datafu.pig.hash.lsh.L1PStableHash('3', '150', '1', '5', '0');
define METRIC datafu.pig.hash.lsh.metric.L1('3');
PTS = LOAD 'input' AS (dim1:double, dim2:double, dim3:double);
PTS_HASHED = foreach PTS generate TOTUPLE(dim1, dim2, dim3) as pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)));
PARTITIONS = group PTS_HASHED by (lsh_id, hash);
QUERIES = LOAD 'queries' as (dim1:double, dim2:double, dim3:double);
QUERIES_HASHED = foreach QUERIES generate TOTUPLE(dim1, dim2, dim3) as query_pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)))
;
QUERIES_W_PARTS = join QUERIES_HASHED by (lsh_id, hash), PARTITIONS by (group.$0, group.$1);
NEAR_NEIGHBORS = foreach QUERIES_W_PARTS generate query_pt as query_pt
, METRIC(query_pt, 1000, PTS_HASHED) as neighbor
;
describe NEAR_NEIGHBORS;
NEIGHBORS_PROJ = foreach NEAR_NEIGHBORS {
generate query_pt as query_pt, neighbor.pt as matching_pts;
};
describe NEIGHBORS_PROJ;
NOT_NULL = filter NEIGHBORS_PROJ by SIZE(matching_pts) > 0;
NEIGHBORS_GRP = group NOT_NULL by query_pt;
describe NEIGHBORS_GRP;
NEIGHBOR_CNT = foreach NEIGHBORS_GRP{
MATCHING_PTS = foreach NOT_NULL generate FLATTEN(matching_pts);
DIST_MATCHING_PTS = DISTINCT MATCHING_PTS;
generate group as query_pt, COUNT(NOT_NULL), DIST_MATCHING_PTS;
};
STORE NEIGHBOR_CNT INTO 'neighbors';
*/
@Multiline private String l1Test;
/**
 * Runs the {@code l1Test} script over tab-separated three-column points and
 * queries and checks the {@code NEIGHBOR_CNT} alias: one row per query, a count
 * of at least 3 in every row, and every returned neighbor within L1 distance
 * 1000 of its query point.
 *
 * @throws Exception propagated from the Pig and file helpers
 */
@Test
public void testL1UDF() throws Exception
{
  final long minCount = 3;
  // Same value the Pig script passes as the second argument to METRIC.
  final double threshold = 1000;

  setMemorySettings();
  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  int n = 1000;
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);
  PigTest test = createPigTestFromString(l1Test);
  writeLinesToFile("input", getLines(vectors));
  List<RealVector> queries = LSHTest.getVectors(rd, 1000, 10);
  writeLinesToFile("queries", getLines(queries));
  test.runScript();

  List<Tuple> neighbors = this.getLinesForAlias(test, "NEIGHBOR_CNT");
  // TestNG's assertEquals takes (actual, expected).
  Assert.assertEquals(neighbors.size(), queries.size());
  for (long cnt : getCounts(neighbors))
  {
    Assert.assertTrue(cnt >= minCount,
                      "count for a query was " + cnt + ", expected at least " + minCount);
  }

  Distance d = new Distance()
  {
    @Override
    public double distance(RealVector v1, RealVector v2)
    {
      return L1.distance(v1, v2);
    }
  };
  verifyPoints(neighbors, d, threshold);
}
// Pig script used by testL2UDF(): same pipeline shape as the L1 dense script but
// with L2PStableHash as LSH and the L2 metric. It hashes points and queries, joins
// them on (lsh_id, hash), applies METRIC with 1000 as its second argument, drops
// rows with no matching points, and emits one NEIGHBOR_CNT row per query point
// holding the query point, COUNT(NOT_NULL) and the distinct flattened matching points.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
define LSH datafu.pig.hash.lsh.L2PStableHash('3', '200', '1', '5', '0');
define METRIC datafu.pig.hash.lsh.metric.L2('3');
PTS = LOAD 'input' AS (dim1:double, dim2:double, dim3:double);
PTS_HASHED = foreach PTS generate TOTUPLE(dim1, dim2, dim3) as pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)));
PARTITIONS = group PTS_HASHED by (lsh_id, hash);
QUERIES = LOAD 'queries' as (dim1:double, dim2:double, dim3:double);
QUERIES_HASHED = foreach QUERIES generate TOTUPLE(dim1, dim2, dim3) as query_pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)))
;
QUERIES_W_PARTS = join QUERIES_HASHED by (lsh_id, hash), PARTITIONS by (group.$0, group.$1);
NEAR_NEIGHBORS = foreach QUERIES_W_PARTS generate query_pt as query_pt
, METRIC(query_pt, 1000, PTS_HASHED) as neighbor
;
describe NEAR_NEIGHBORS;
NEIGHBORS_PROJ = foreach NEAR_NEIGHBORS {
generate query_pt as query_pt, neighbor.pt as matching_pts;
};
describe NEIGHBORS_PROJ;
NOT_NULL = filter NEIGHBORS_PROJ by SIZE(matching_pts) > 0;
NEIGHBORS_GRP = group NOT_NULL by query_pt;
describe NEIGHBORS_GRP;
NEIGHBOR_CNT = foreach NEIGHBORS_GRP{
MATCHING_PTS = foreach NOT_NULL generate FLATTEN(matching_pts);
DIST_MATCHING_PTS = DISTINCT MATCHING_PTS;
generate group as query_pt, COUNT(NOT_NULL), DIST_MATCHING_PTS;
};
STORE NEIGHBOR_CNT INTO 'neighbors';
*/
@Multiline private String l2Test;
/**
 * Runs the {@code l2Test} script over tab-separated three-column points and
 * queries and checks the {@code NEIGHBOR_CNT} alias: one row per query, a count
 * of at least 3 in every row, and every returned neighbor within L2 distance
 * 1000 of its query point.
 *
 * @throws Exception propagated from the Pig and file helpers
 */
@Test
public void testL2UDF() throws Exception
{
  final long minCount = 3;
  // Same value the Pig script passes as the second argument to METRIC.
  final double threshold = 1000;

  setMemorySettings();
  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  int n = 1000;
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);
  PigTest test = createPigTestFromString(l2Test);
  writeLinesToFile("input", getLines(vectors));
  List<RealVector> queries = LSHTest.getVectors(rd, 1000, 10);
  writeLinesToFile("queries", getLines(queries));
  test.runScript();

  List<Tuple> neighbors = this.getLinesForAlias(test, "NEIGHBOR_CNT");
  // TestNG's assertEquals takes (actual, expected).
  Assert.assertEquals(neighbors.size(), queries.size());
  for (long cnt : getCounts(neighbors))
  {
    Assert.assertTrue(cnt >= minCount,
                      "count for a query was " + cnt + ", expected at least " + minCount);
  }

  Distance d = new Distance()
  {
    @Override
    public double distance(RealVector v1, RealVector v2)
    {
      return L2.distance(v1, v2);
    }
  };
  verifyPoints(neighbors, d, threshold);
}
// Pig script used by testCosineUDF(): same pipeline shape as the L1/L2 dense scripts
// but with CosineDistanceHash as LSH (four constructor arguments) and the Cosine
// metric. It hashes points and queries, joins them on (lsh_id, hash), applies METRIC
// with .001 as its second argument, drops rows with no matching points, and emits
// one NEIGHBOR_CNT row per query point holding the query point, COUNT(NOT_NULL) and
// the distinct flattened matching points.
// NOTE(review): this field is never assigned in this class; @Multiline presumably
// populates it from the Javadoc comment below at compile time, which would make that
// comment text test input rather than documentation - confirm before editing it.
/**
define LSH datafu.pig.hash.lsh.CosineDistanceHash('3', '1500', '5', '0');
define METRIC datafu.pig.hash.lsh.metric.Cosine('3');
PTS = LOAD 'input' AS (dim1:double, dim2:double, dim3:double);
PTS_HASHED = foreach PTS generate TOTUPLE(dim1, dim2, dim3) as pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)));
PARTITIONS = group PTS_HASHED by (lsh_id, hash);
QUERIES = LOAD 'queries' as (dim1:double, dim2:double, dim3:double);
QUERIES_HASHED = foreach QUERIES generate TOTUPLE(dim1, dim2, dim3) as query_pt
, FLATTEN(LSH(TOTUPLE(dim1, dim2, dim3)))
;
describe QUERIES_HASHED;
QUERIES_W_PARTS = join QUERIES_HASHED by (lsh_id, hash), PARTITIONS by (group.$0, group.$1);
NEAR_NEIGHBORS = foreach QUERIES_W_PARTS generate query_pt as query_pt
, METRIC(query_pt, .001, PTS_HASHED) as neighbor
;
describe NEAR_NEIGHBORS;
NEIGHBORS_PROJ = foreach NEAR_NEIGHBORS {
generate query_pt as query_pt, neighbor.pt as matching_pts;
};
describe NEIGHBORS_PROJ;
NOT_NULL = filter NEIGHBORS_PROJ by SIZE(matching_pts) > 0;
NEIGHBORS_GRP = group NOT_NULL by query_pt;
describe NEIGHBORS_GRP;
NEIGHBOR_CNT = foreach NEIGHBORS_GRP{
MATCHING_PTS = foreach NOT_NULL generate FLATTEN(matching_pts);
DIST_MATCHING_PTS = DISTINCT MATCHING_PTS;
generate group as query_pt, COUNT(NOT_NULL), DIST_MATCHING_PTS;
};
describe NEIGHBOR_CNT;
STORE NEIGHBOR_CNT INTO 'neighbors';
*/
@Multiline private String cosTest;
/**
 * Runs the {@code cosTest} script over tab-separated three-column points and
 * queries and checks the {@code NEIGHBOR_CNT} alias: one row per query, a count
 * of at least 2 in every row, and every returned neighbor within cosine
 * distance .001 of its query point.
 *
 * @throws Exception propagated from the Pig and file helpers
 */
@Test
public void testCosineUDF() throws Exception
{
  final long minCount = 2;
  // Same value the Pig script passes as the second argument to METRIC.
  final double threshold = .001;

  setMemorySettings();
  RandomGenerator rg = new JDKRandomGenerator();
  rg.setSeed(0);
  RandomData rd = new RandomDataImpl(rg);
  int n = 1000;
  List<RealVector> vectors = LSHTest.getVectors(rd, 1000, n);
  PigTest test = createPigTestFromString(cosTest);
  writeLinesToFile("input", getLines(vectors));
  List<RealVector> queries = LSHTest.getVectors(rd, 1000, 10);
  writeLinesToFile("queries", getLines(queries));
  test.runScript();

  List<Tuple> neighbors = this.getLinesForAlias(test, "NEIGHBOR_CNT");
  // TestNG's assertEquals takes (actual, expected).
  Assert.assertEquals(neighbors.size(), queries.size());
  for (long cnt : getCounts(neighbors))
  {
    Assert.assertTrue(cnt >= minCount,
                      "count for a query was " + cnt + ", expected at least " + minCount);
  }

  Distance d = new Distance()
  {
    @Override
    public double distance(RealVector v1, RealVector v2)
    {
      return Cosine.distance(v1, v2);
    }
  };
  verifyPoints(neighbors, d, threshold);
}
private static interface Distance
{
public double distance(RealVector v1, RealVector v2);
}
/**
 * Asserts that every neighbor listed in each result row is strictly closer to
 * that row's query point than {@code threshold}.
 *
 * The query vector is obtained by converting the row itself; the neighbors are
 * read from the bag in field 2 of the row.
 *
 * @param neighbors result rows to check
 * @param d metric used to measure query-to-neighbor distance
 * @param threshold exclusive upper bound on the allowed distance
 * @throws PigException propagated from tuple access and vector conversion
 */
private void verifyPoints(List<Tuple> neighbors, Distance d, double threshold) throws PigException
{
  final int dimension = 3;
  for (Tuple row : neighbors)
  {
    final RealVector query = DataTypeUtil.INSTANCE.convert(row, dimension);
    final DataBag matches = (DataBag) row.get(2);
    for (Tuple match : matches)
    {
      final boolean withinThreshold =
          d.distance(query, DataTypeUtil.INSTANCE.convert(match, dimension)) < threshold;
      Assert.assertTrue(withinThreshold);
    }
  }
}
/**
 * Returns a lazy view over field 1 of each row, cast to {@code Long}.
 *
 * A failure to read the field is rethrown with its cause instead of being
 * replaced by a sentinel value, so a broken row shows up as the real error
 * rather than as a failed count assertion in the caller.
 *
 * @param neighbors result rows whose field 1 holds a {@code Long} count
 * @return an iterable that reads each count as it is iterated
 * @throws IllegalStateException during iteration if a row's field 1 cannot be read
 */
private Iterable<Long> getCounts(List<Tuple> neighbors)
{
  return Iterables.transform(neighbors, new Function<Tuple, Long>()
  {
    @Override
    public Long apply(Tuple in)
    {
      try
      {
        return (Long) in.get(1);
      }
      catch (ExecException e)
      {
        throw new IllegalStateException("Unable to read the count (field 1) from tuple " + in, e);
      }
    }
  });
}
/**
 * Renders each vector as one input line in the sparse encoding
 * {@code ({(0,x),(1,y),(2,z)})}, using entries 0, 1 and 2 of the vector.
 *
 * Numbers are formatted with {@link Locale#ROOT} so the decimal separator is
 * always '.'; with a comma-decimal default locale {@code %f} would otherwise
 * emit commas, which also separate the fields of each tuple.
 *
 * @param vectors vectors with at least three entries each
 * @return one formatted line per vector, in iteration order
 */
private String[] getSparseLines(List<RealVector> vectors)
{
  String[] input = new String[vectors.size()];
  int i = 0;
  for (RealVector vec : vectors)
  {
    input[i++] = String.format(Locale.ROOT, "({(%d,%f),(%d,%f),(%d,%f)})",
                               0, vec.getEntry(0), 1, vec.getEntry(1), 2, vec.getEntry(2));
  }
  return input;
}
/**
 * Renders each vector as one tab-separated input line holding entries 0, 1 and
 * 2 of the vector.
 *
 * Numbers are formatted with {@link Locale#ROOT} so the decimal separator is
 * always '.', regardless of the JVM's default locale; the Pig scripts load
 * these columns as doubles.
 *
 * @param vectors vectors with at least three entries each
 * @return one formatted line per vector, in iteration order
 */
private String[] getLines(List<RealVector> vectors)
{
  String[] input = new String[vectors.size()];
  int i = 0;
  for (RealVector vec : vectors)
  {
    input[i++] = String.format(Locale.ROOT, "%f\t%f\t%f",
                               vec.getEntry(0), vec.getEntry(1), vec.getEntry(2));
  }
  return input;
}
}