Merge pull request #311 from myui/feature/v0.4.2-rc.2
Updated hivemall version to v0.4.2-rc.2
diff --git a/VERSION b/VERSION
index e7bbf8c..24db844 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.4.2-rc.1
+0.4.2-rc.2
diff --git a/core/pom.xml b/core/pom.xml
index fa58096..3d341d4 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -5,7 +5,7 @@
<parent>
<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
- <version>0.4.2-rc.1</version>
+ <version>0.4.2-rc.2</version>
<relativePath>../pom.xml</relativePath>
</parent>
diff --git a/core/src/main/java/hivemall/HivemallConstants.java b/core/src/main/java/hivemall/HivemallConstants.java
index 096661b..505cd2e 100644
--- a/core/src/main/java/hivemall/HivemallConstants.java
+++ b/core/src/main/java/hivemall/HivemallConstants.java
@@ -20,9 +20,10 @@
public final class HivemallConstants {
- public static final String VERSION = "0.4.2-rc.1";
+ public static final String VERSION = "0.4.2-rc.2";
public static final String BIAS_CLAUSE = "0";
+ public static final int BIAS_CLAUSE_HASHVAL = 0;
public static final String CONFKEY_RAND_AMPLIFY_SEED = "hivemall.amplify.seed";
// org.apache.hadoop.hive.serde.Constants (hive 0.9)
diff --git a/core/src/main/java/hivemall/fm/FMArrayModel.java b/core/src/main/java/hivemall/fm/FMArrayModel.java
index d0aca30..30ee95e 100644
--- a/core/src/main/java/hivemall/fm/FMArrayModel.java
+++ b/core/src/main/java/hivemall/fm/FMArrayModel.java
@@ -36,14 +36,7 @@
super(params);
this._p = params.numFeatures;
this._w = new float[params.numFeatures + 1];
- this._V = new float[params.numFeatures][params.factors];
- }
-
- @Override
- protected void initLearningParams() {
- for (int i = 0; i < _p; i++) {
- _V[i] = initV();
- }
+ this._V = new float[params.numFeatures][];
}
@Override
@@ -92,29 +85,31 @@
}
@Override
- protected float[] getV(int i) {
+ protected float[] getV(int i, boolean init) {
if (i < 1 || i > _p) {
throw new IllegalArgumentException("Index i should be in range [1," + _p + "]: " + i);
}
- return _V[i - 1];
+ final int idx = i - 1;
+ float[] v = _V[idx];
+ if (v == null && init) {
+ v = initV();
+ _V[idx] = v;
+ }
+ return v;
}
@Override
public float getV(@Nonnull final Feature x, int f) {
final int i = x.getFeatureIndex();
- if (i < 1 || i > _p) {
- throw new IllegalArgumentException("Index i should be in range [1," + _p + "]: " + i);
- }
- return _V[i - 1][f];
+ float[] v = getV(i, true);
+ return v[f];
}
@Override
protected void setV(@Nonnull Feature x, int f, float nextVif) {
final int i = x.getFeatureIndex();
- if (i < 1 || i > _p) {
- throw new IllegalArgumentException("Index i should be in range [1," + _p + "]: " + i);
- }
- _V[i - 1][f] = nextVif;
+ float[] v = getV(i, true);
+ v[f] = nextVif;
}
@Override
diff --git a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
index 9f16407..e23b33f 100644
--- a/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
+++ b/core/src/main/java/hivemall/fm/FMIntFeatureMapModel.java
@@ -101,7 +101,7 @@
}
@Override
- protected float[] getV(int i) {
+ protected float[] getV(int i, boolean init) {
assert (i >= 1) : i;
return _V.get(i);
}
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
index 9c93480..396328a 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineModel.java
@@ -65,12 +65,8 @@
this._lambdaW = params.lambdaW;
this._lambdaV = new float[params.factors];
Arrays.fill(_lambdaV, params.lambdaV);
-
- initLearningParams();
}
- protected void initLearningParams() {}
-
public abstract int getSize();
protected int getMinIndex() {
@@ -100,7 +96,7 @@
* @param i index value >= 1
*/
@Nullable
- protected float[] getV(int i) {
+ protected float[] getV(int i, boolean init) {
throw new UnsupportedOperationException();
}
diff --git a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
index 18d8c41..2388689 100644
--- a/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FactorizationMachineUDTF.java
@@ -100,7 +100,7 @@
// ----------------------------------------
- protected FactorizationMachineModel _model;
+ protected transient FactorizationMachineModel _model;
/**
* The number of training examples processed
@@ -197,11 +197,6 @@
return cl;
}
- @Nonnull
- protected FactorizationMachineModel getModel() {
- return _model;
- }
-
@Override
public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
if (argOIs.length != 2 && argOIs.length != 3) {
@@ -215,9 +210,9 @@
this._yOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]);
this._params = newHyperParameters();
- CommandLine cl = processOptions(argOIs);
+ processOptions(argOIs);
- this._model = initModel(cl, _params);
+ this._model = null;
this._t = 0L;
if (LOG.isInfoEnabled()) {
@@ -251,21 +246,28 @@
}
@Nonnull
- protected FactorizationMachineModel initModel(@Nullable CommandLine cl,
- @Nonnull FMHyperParameters params) throws UDFArgumentException {
+ protected FactorizationMachineModel initModel(@Nonnull FMHyperParameters params)
+ throws UDFArgumentException {
+ final FactorizationMachineModel model;
if (params.parseFeatureAsInt) {
if (params.numFeatures == -1) {
- return new FMIntFeatureMapModel(params);
+ model = new FMIntFeatureMapModel(params);
} else {
- return new FMArrayModel(params);
+ model = new FMArrayModel(params);
}
} else {
- return new FMStringFeatureMapModel(params);
+ model = new FMStringFeatureMapModel(params);
}
+ this._model = model;
+ return model;
}
@Override
public void process(Object[] args) throws HiveException {
+ if (_model == null) {
+ this._model = initModel(_params);
+ }
+
Feature[] x = parseFeatures(args[0]);
if (x == null) {
return;
@@ -463,7 +465,7 @@
forwardObjs[2] = Arrays.asList(f_Vi);
for (int i = model.getMinIndex(), maxIdx = model.getMaxIndex(); i <= maxIdx; i++) {
- final float[] vi = model.getV(i);
+ final float[] vi = model.getV(i, false);
if (vi == null) {
continue;
}
diff --git a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
index 619ca3d..d99bee9 100644
--- a/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
+++ b/core/src/main/java/hivemall/fm/FieldAwareFactorizationMachineUDTF.java
@@ -68,11 +68,11 @@
private int _numFields;
// ----------------------------------------
- private FFMStringFeatureMapModel _ffmModel;
+ private transient FFMStringFeatureMapModel _ffmModel;
- private IntArrayList _fieldList;
+ private transient IntArrayList _fieldList;
@Nullable
- private DoubleArray3D _sumVfX;
+ private transient DoubleArray3D _sumVfX;
public FieldAwareFactorizationMachineUDTF() {
super();
@@ -156,8 +156,8 @@
}
@Override
- protected FFMStringFeatureMapModel initModel(@Nullable CommandLine cl,
- @Nonnull FMHyperParameters params) throws UDFArgumentException {
+ protected FFMStringFeatureMapModel initModel(@Nonnull FMHyperParameters params)
+ throws UDFArgumentException {
FFMHyperParameters ffmParams = (FFMHyperParameters) params;
FFMStringFeatureMapModel model = new FFMStringFeatureMapModel(ffmParams);
diff --git a/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java b/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
index 4bc2179..c5f9ce0 100644
--- a/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
+++ b/core/src/main/java/hivemall/ftvec/hashing/FeatureHashingUDF.java
@@ -17,6 +17,7 @@
*/
package hivemall.ftvec.hashing;
+import hivemall.HivemallConstants;
import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.MurmurHash3;
@@ -146,18 +147,28 @@
}
@Nonnull
- private static String featureHashing(@Nonnull final String fv, final int numFeatures) {
+ static String featureHashing(@Nonnull final String fv, final int numFeatures) {
final int headPos = fv.indexOf(':');
if (headPos == -1) {
+ if (fv.equals(HivemallConstants.BIAS_CLAUSE)) {
+ return fv;
+ }
int h = mhash(fv, numFeatures);
return String.valueOf(h);
} else {
final int tailPos = fv.lastIndexOf(':');
if (headPos == tailPos) {
String f = fv.substring(0, headPos);
+ String tail = fv.substring(headPos);
+ if (f.equals(HivemallConstants.BIAS_CLAUSE)) {
+ String v = fv.substring(headPos + 1);
+ double d = Double.parseDouble(v);
+ if (d == 1.d) {
+ return fv;
+ }
+ }
int h = mhash(f, numFeatures);
- String v = fv.substring(headPos);
- return h + v;
+ return h + tail;
} else {
String field = fv.substring(0, headPos + 1);
String f = fv.substring(headPos + 1, tailPos);
@@ -168,7 +179,7 @@
}
}
- private static int mhash(@Nonnull final String word, final int numFeatures) {
+ static int mhash(@Nonnull final String word, final int numFeatures) {
int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
if (r < 0) {
r += numFeatures;
diff --git a/core/src/main/java/hivemall/ftvec/hashing/MurmurHash3UDF.java b/core/src/main/java/hivemall/ftvec/hashing/MurmurHash3UDF.java
index ec25de3..7006186 100644
--- a/core/src/main/java/hivemall/ftvec/hashing/MurmurHash3UDF.java
+++ b/core/src/main/java/hivemall/ftvec/hashing/MurmurHash3UDF.java
@@ -22,6 +22,9 @@
import java.util.List;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
@@ -33,11 +36,13 @@
@UDFType(deterministic = true, stateful = false)
public final class MurmurHash3UDF extends UDF {
- public IntWritable evaluate(final String word) throws UDFArgumentException {
+ @Nullable
+ public IntWritable evaluate(@Nullable final String word) throws UDFArgumentException {
return evaluate(word, MurmurHash3.DEFAULT_NUM_FEATURES);
}
- public IntWritable evaluate(final String word, final int numFeatures)
+ @Nullable
+ public IntWritable evaluate(@Nullable final String word, final int numFeatures)
throws UDFArgumentException {
if (word == null) {
return null;
@@ -46,11 +51,13 @@
return new IntWritable(h);
}
- public IntWritable evaluate(final List<String> words) throws UDFArgumentException {
+ @Nullable
+ public IntWritable evaluate(@Nullable final List<String> words) throws UDFArgumentException {
return evaluate(words, MurmurHash3.DEFAULT_NUM_FEATURES);
}
- public IntWritable evaluate(final List<String> words, final int numFeatures)
+ @Nullable
+ public IntWritable evaluate(@Nullable final List<String> words, final int numFeatures)
throws UDFArgumentException {
if (words == null) {
return null;
@@ -70,11 +77,11 @@
return evaluate(s, numFeatures);
}
- public static int mhash(final String word) {
+ public static int mhash(@Nonnull final String word) {
return mhash(word, MurmurHash3.DEFAULT_NUM_FEATURES);
}
- public static int mhash(final String word, final int numFeatures) {
+ public static int mhash(@Nonnull final String word, final int numFeatures) {
int r = MurmurHash3.murmurhash3_x86_32(word, 0, word.length(), 0x9747b28c) % numFeatures;
if (r < 0) {
r += numFeatures;
diff --git a/core/src/main/java/hivemall/mf/MFPredictionUDF.java b/core/src/main/java/hivemall/mf/MFPredictionUDF.java
index 366fb00..ee6627c 100644
--- a/core/src/main/java/hivemall/mf/MFPredictionUDF.java
+++ b/core/src/main/java/hivemall/mf/MFPredictionUDF.java
@@ -20,10 +20,14 @@
import java.util.List;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.FloatWritable;
@Description(
@@ -32,51 +36,70 @@
@UDFType(deterministic = true, stateful = false)
public final class MFPredictionUDF extends UDF {
- public FloatWritable evaluate(List<Float> Pu, List<Float> Qi) throws HiveException {
- return evaluate(Pu, Qi, 0.d);
+ @Nonnull
+ public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
+ @Nullable List<FloatWritable> Qi) throws HiveException {
+ return evaluate(Pu, Qi, null);
}
- public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double mu) throws HiveException {
+ @Nonnull
+ public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
+ @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable mu) throws HiveException {
+ final double muValue = (mu == null) ? 0.d : mu.get();
if (Pu == null || Qi == null) {
- return new FloatWritable((float) mu);
+ return new DoubleWritable(muValue);
}
final int PuSize = Pu.size();
final int QiSize = Qi.size();
// workaround for TD
if (PuSize == 0) {
- return new FloatWritable((float) mu);
+ return new DoubleWritable(muValue);
} else if (QiSize == 0) {
- return new FloatWritable((float) mu);
+ return new DoubleWritable(muValue);
}
if (QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
- float ret = (float) mu;
+ double ret = muValue;
for (int k = 0; k < PuSize; k++) {
- ret += Pu.get(k) * Qi.get(k);
+ FloatWritable Pu_k = Pu.get(k);
+ if (Pu_k == null) {
+ continue;
+ }
+ FloatWritable Qi_k = Qi.get(k);
+ if (Qi_k == null) {
+ continue;
+ }
+ ret += Pu_k.get() * Qi_k.get();
}
- return new FloatWritable(ret);
+ return new DoubleWritable(ret);
}
- public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bu, double Bi)
- throws HiveException {
- return evaluate(Pu, Qi, Bu, Bi, 0.d);
+ @Nonnull
+ public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
+ @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
+ @Nullable DoubleWritable Bi) throws HiveException {
+ return evaluate(Pu, Qi, Bu, Bi, null);
}
- public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bu, double Bi, double mu)
- throws HiveException {
+ @Nonnull
+ public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
+ @Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
+ @Nullable DoubleWritable Bi, @Nullable DoubleWritable mu) throws HiveException {
+ final double muValue = (mu == null) ? 0.d : mu.get();
if (Pu == null && Qi == null) {
- return new FloatWritable((float) mu);
+ return new DoubleWritable(muValue);
}
+ final double BiValue = (Bi == null) ? 0.d : Bi.get();
+ final double BuValue = (Bu == null) ? 0.d : Bu.get();
if (Pu == null) {
- float ret = (float) (mu + Bi);
- return new FloatWritable(ret);
+ double ret = muValue + BiValue;
+ return new DoubleWritable(ret);
} else if (Qi == null) {
- float ret = (float) (mu + Bu);
- return new FloatWritable(ret);
+ return new DoubleWritable(muValue + BuValue); // unknown item factors: fall back to mu + user bias, mirroring the QiSize == 0 branch
}
final int PuSize = Pu.size();
@@ -84,25 +107,33 @@
// workaround for TD
if (PuSize == 0) {
if (QiSize == 0) {
- return new FloatWritable((float) mu);
+ return new DoubleWritable(muValue);
} else {
- float ret = (float) (mu + Bi);
- return new FloatWritable(ret);
+ double ret = muValue + BiValue;
+ return new DoubleWritable(ret);
}
} else if (QiSize == 0) {
- float ret = (float) (mu + Bu);
- return new FloatWritable(ret);
+ double ret = muValue + BuValue;
+ return new DoubleWritable(ret);
}
if (QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}
- float ret = (float) (mu + Bu + Bi);
+ double ret = muValue + BuValue + BiValue;
for (int k = 0; k < PuSize; k++) {
- ret += Pu.get(k) * Qi.get(k);
+ FloatWritable Pu_k = Pu.get(k);
+ if (Pu_k == null) {
+ continue;
+ }
+ FloatWritable Qi_k = Qi.get(k);
+ if (Qi_k == null) {
+ continue;
+ }
+ ret += Pu_k.get() * Qi_k.get();
}
- return new FloatWritable(ret);
+ return new DoubleWritable(ret);
}
}
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 76b438c..3228517 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -50,8 +50,8 @@
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
@@ -66,7 +66,7 @@
public final class TreePredictUDF extends GenericUDF {
private boolean classification;
- private IntObjectInspector modelTypeOI;
+ private PrimitiveObjectInspector modelTypeOI;
private StringObjectInspector stringOI;
private ListObjectInspector featureListOI;
private PrimitiveObjectInspector featureElemOI;
@@ -94,7 +94,7 @@
throw new UDFArgumentException("_FUNC_ takes 4 or 5 arguments");
}
- this.modelTypeOI = HiveUtils.asIntOI(argOIs[1]);
+ this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
this.stringOI = HiveUtils.asStringOI(argOIs[2]);
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
this.featureListOI = listOI;
@@ -124,7 +124,7 @@
String modelId = arg0.toString();
Object arg1 = arguments[1].get();
- int modelTypeId = modelTypeOI.get(arg1);
+ int modelTypeId = PrimitiveObjectInspectorUtils.getInt(arg1, modelTypeOI);
ModelType modelType = ModelType.resolve(modelTypeId);
Object arg2 = arguments[2].get();
diff --git a/core/src/test/java/hivemall/fm/ArrayModelTest.java b/core/src/test/java/hivemall/fm/ArrayModelTest.java
index 4c4b32a..9706dbd 100644
--- a/core/src/test/java/hivemall/fm/ArrayModelTest.java
+++ b/core/src/test/java/hivemall/fm/ArrayModelTest.java
@@ -48,9 +48,9 @@
DoubleObjectInspector yOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
ObjectInspector paramOI = ObjectInspectorUtils.getConstantObjectInspector(
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
- "-adareg -factors 20 -classification -seed 31 -iters 100 -int_feature -p " + COL);
+ "-adareg -factors 20 -classification -seed 31 -iters 10 -int_feature -p " + COL);
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMArrayModel);
@@ -110,8 +110,8 @@
accuracy = bingo / (float) total;
println("Accuracy = " + accuracy);
}
- udtf.runTrainingIteration(100);
- Assert.assertTrue(accuracy > 0.95f);
+ udtf.runTrainingIteration(10);
+ Assert.assertTrue("Expected accuracy greater than 0.95f: " + accuracy, accuracy > 0.95f);
}
@Test
@@ -125,7 +125,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
"-factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1 -int_feature -p " + COL);
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMArrayModel);
diff --git a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
index a418272..81c1858 100644
--- a/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FactorizationMachineUDTFTest.java
@@ -31,7 +31,7 @@
"-factors 5 -min 1 -max 5 -iters 1 -init_v gaussian -eta0 0.01 -seed 31")};
udtf.initialize(argOIs);
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMStringFeatureMapModel);
diff --git a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
index 2760b30..cc1fcf3 100644
--- a/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
+++ b/core/src/test/java/hivemall/fm/FieldAwareFactorizationMachineUDTFTest.java
@@ -80,7 +80,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector, testOptions)};
udtf.initialize(argOIs);
- FieldAwareFactorizationMachineModel model = (FieldAwareFactorizationMachineModel) udtf.getModel();
+ FieldAwareFactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FFMStringFeatureMapModel);
diff --git a/core/src/test/java/hivemall/fm/IntFeatureMapModelTest.java b/core/src/test/java/hivemall/fm/IntFeatureMapModelTest.java
index e9196b2..017c25b 100644
--- a/core/src/test/java/hivemall/fm/IntFeatureMapModelTest.java
+++ b/core/src/test/java/hivemall/fm/IntFeatureMapModelTest.java
@@ -50,7 +50,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
"-adareg -int_feature -factors 20 -classification -seed 31 -iters 10");
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMIntFeatureMapModel);
@@ -125,7 +125,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
"-int_feature -factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1");
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMIntFeatureMapModel);
diff --git a/core/src/test/java/hivemall/fm/StringFeatureMapModelTest.java b/core/src/test/java/hivemall/fm/StringFeatureMapModelTest.java
index 484c2a1..56fa137 100644
--- a/core/src/test/java/hivemall/fm/StringFeatureMapModelTest.java
+++ b/core/src/test/java/hivemall/fm/StringFeatureMapModelTest.java
@@ -50,7 +50,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
"-adareg -init_v gaussian -factors 20 -classification -seed 31 -iters 10");
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMStringFeatureMapModel);
@@ -125,7 +125,7 @@
PrimitiveObjectInspectorFactory.javaStringObjectInspector,
"-factors 20 -seed 31 -eta 0.001 -lambda0 0.1 -sigma 0.1");
udtf.initialize(new ObjectInspector[] {xOI, yOI, paramOI});
- FactorizationMachineModel model = udtf.getModel();
+ FactorizationMachineModel model = udtf.initModel(udtf._params);
Assert.assertTrue("Actual class: " + model.getClass().getName(),
model instanceof FMStringFeatureMapModel);
diff --git a/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java
new file mode 100644
index 0000000..70e0b9c
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/hashing/FeatureHashingUDFTest.java
@@ -0,0 +1,44 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.ftvec.hashing;
+
+import hivemall.utils.hashing.MurmurHash3;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class FeatureHashingUDFTest {
+
+ @Test
+ public void testBias() {
+ String expected = "0:1.0";
+ String actual = FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES);
+ Assert.assertEquals(expected, actual);
+
+ expected = "0";
+ actual = FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES);
+ Assert.assertEquals(expected, actual);
+
+ expected = "0:1.1";
+ actual = FeatureHashingUDF.featureHashing(expected, MurmurHash3.DEFAULT_NUM_FEATURES);
+ Assert.assertEquals(
+ FeatureHashingUDF.mhash("0", MurmurHash3.DEFAULT_NUM_FEATURES) + ":1.1", actual);
+ }
+
+}
diff --git a/mixserv/pom.xml b/mixserv/pom.xml
index e89bf38..83c5022 100644
--- a/mixserv/pom.xml
+++ b/mixserv/pom.xml
@@ -5,7 +5,7 @@
<parent>
<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
- <version>0.4.2-rc.1</version>
+ <version>0.4.2-rc.2</version>
<relativePath>../pom.xml</relativePath>
</parent>
diff --git a/nlp/pom.xml b/nlp/pom.xml
index dd2414a..27622cd 100644
--- a/nlp/pom.xml
+++ b/nlp/pom.xml
@@ -5,7 +5,7 @@
<parent>
<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
- <version>0.4.2-rc.1</version>
+ <version>0.4.2-rc.2</version>
<relativePath>../pom.xml</relativePath>
</parent>
diff --git a/pom.xml b/pom.xml
index d918cbd..3ad81a4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,7 +4,7 @@
<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
- <version>0.4.2-rc.1</version>
+ <version>0.4.2-rc.2</version>
<name>Hivemall</name>
<description>Scalable Machine Learning Library for Apache Hive</description>
diff --git a/spark/.classpath b/spark/.classpath
new file mode 100644
index 0000000..534b5e5
--- /dev/null
+++ b/spark/.classpath
@@ -0,0 +1,36 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<classpath>
+ <classpathentry kind="src" output="target/classes" path="src/main/java">
+ <attributes>
+ <attribute name="optional" value="true"/>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry excluding="**" kind="src" output="target/classes" path="src/main/resources">
+ <attributes>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry kind="src" output="target/test-classes" path="src/test/java">
+ <attributes>
+ <attribute name="optional" value="true"/>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry excluding="**" kind="src" output="target/test-classes" path="src/test/resources">
+ <attributes>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER/org.eclipse.jdt.internal.debug.ui.launcher.StandardVMType/JavaSE-1.6">
+ <attributes>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry kind="con" path="org.eclipse.m2e.MAVEN2_CLASSPATH_CONTAINER">
+ <attributes>
+ <attribute name="maven.pomderived" value="true"/>
+ </attributes>
+ </classpathentry>
+ <classpathentry kind="output" path="target/classes"/>
+</classpath>
diff --git a/spark/build.sbt b/spark/build.sbt
index 8038569..f5da043 100644
--- a/spark/build.sbt
+++ b/spark/build.sbt
@@ -31,8 +31,8 @@
libraryDependencies ++= Seq(
"org.apache.commons" % "commons-compress" % "1.8",
- "io.github.myui" % "hivemall-core" % "0.4.2-rc.1",
- "io.github.myui" % "hivemall-mixserv" % "0.4.2-rc.1",
+ "io.github.myui" % "hivemall-core" % "0.4.2-rc.2",
+ "io.github.myui" % "hivemall-mixserv" % "0.4.2-rc.2",
"org.scalatest" % "scalatest_2.11" % "2.2.4" % "provided",
"org.xerial" % "xerial-core" % "3.2.3" % "provided"
)
diff --git a/spark/pom.xml b/spark/pom.xml
index 50fe0ab..b025db9 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -5,12 +5,12 @@
<parent>
<groupId>io.github.myui</groupId>
<artifactId>hivemall</artifactId>
- <version>0.4.2-rc.1</version>
+ <version>0.4.2-rc.2</version>
<relativePath>../pom.xml</relativePath>
</parent>
<artifactId>hivemall-spark</artifactId>
- <name>Hivemall in Spark</name>
+ <name>Hivemall on Spark</name>
<packaging>jar</packaging>
@@ -194,7 +194,7 @@
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
<filereports>WDF TestSuite.txt</filereports>
- <argLine>-Xms256m -Xmx1024m -XX:MaxPermSize=1024m -XX:+CMSClassUnloadingEnabled</argLine>
+ <argLine>-Xms256m -Xmx1024m -XX:MaxPermSize=1024m -XX:+CMSClassUnloadingEnabled</argLine>
</configuration>
<executions>
<execution>
diff --git a/spark/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala b/spark/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
index 5f24fef..64f0382 100644
--- a/spark/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/hive/HiveUdfSuite.scala
@@ -33,7 +33,7 @@
checkAnswer(
sql(s"SELECT DISTINCT hivemall_version()"),
- Row("0.4.2-rc.1")
+ Row("0.4.2-rc.2")
)
// sql("DROP TEMPORARY FUNCTION IF EXISTS hivemall_version")
diff --git a/spark/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 45a8ee3..ace8123 100644
--- a/spark/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -244,7 +244,7 @@
}
test("misc - hivemall_version") {
- assert(DummyInputData.select(hivemall_version()).collect.toSet === Set(Row("0.4.2-rc.1")))
+ assert(DummyInputData.select(hivemall_version()).collect.toSet === Set(Row("0.4.2-rc.2")))
/**
* TODO: Why a test below does fail?
*