[FLINK-36653] Fix OnlineLogisticRegressionModel updating logic
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
index 36d1555..c38efd5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
@@ -145,6 +145,10 @@
LogisticRegressionModelData modelData = streamRecord.getValue();
coefficient = modelData.coefficient;
modelDataVersion = modelData.modelVersion;
+ servable =
+ new LogisticRegressionModelServable(
+ new LogisticRegressionModelData(coefficient, modelDataVersion));
+ ParamUtils.updateExistingParams(servable, params);
for (Row dataPoint : bufferedPointsState.get()) {
processElement(new StreamRecord<>(dataPoint));
}
@@ -160,7 +164,7 @@
if (servable == null) {
servable =
new LogisticRegressionModelServable(
- new LogisticRegressionModelData(coefficient, 0L));
+ new LogisticRegressionModelData(coefficient, modelDataVersion));
ParamUtils.updateExistingParams(servable, params);
}
Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
index cac9473..a0bb5e8 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
@@ -61,9 +61,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
@@ -72,6 +77,7 @@
import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+@RunWith(Parameterized.class)
public class OnlineLogisticRegressionTest extends TestLogger {
@Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
@@ -142,10 +148,21 @@
Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.)
};
- private static final int defaultParallelism = 4;
+ /** Supplies the parallelism values (1 and 4) that every test in this class runs with. */
+ @Parameters(name = "parallelism = {0}")
+ public static Collection<Object[]> data() {
+     return Arrays.asList(new Object[][] {{1}, {4}});
+ }
+
+ @Parameter public int defaultParallelism;
private static final int numTaskManagers = 2;
private static final int numSlotsPerTaskManager = 2;
-
+ // Passed to the metric reporter in beforeClass() and to each test's environment in before().
+ // A plain instance plus a static initializer avoids the anonymous Configuration subclass.
+ private static final Configuration config = new Configuration();
+ static {
+     config.set(RestOptions.BIND_PORT, "18081-19091");
+     config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ }
private long currentModelDataVersion;
private InMemorySourceFunction<Row> trainDenseSource;
@@ -170,9 +187,6 @@
@BeforeClass
public static void beforeClass() throws Exception {
- Configuration config = new Configuration();
- config.set(RestOptions.BIND_PORT, "18081-19091");
- config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
reporter = InMemoryReporter.create();
reporter.addToConfiguration(config);
@@ -184,17 +198,17 @@
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
.build());
miniCluster.start();
+ }
+ @Before
+ public void before() throws Exception {
env = StreamExecutionEnvironment.getExecutionEnvironment(config);
env.getConfig().enableObjectReuse();
env.setParallelism(defaultParallelism);
env.enableCheckpointing(100);
env.setRestartStrategy(RestartStrategies.noRestart());
tEnv = StreamTableEnvironment.create(env);
- }
- @Before
- public void before() throws Exception {
currentModelDataVersion = 0;
trainDenseSource = new InMemorySourceFunction<>();
@@ -562,6 +576,10 @@
@Test
public void testBatchSizeLessThanParallelism() {
+ if (defaultParallelism < 2) {
+ return;
+ }
+
try {
new OnlineLogisticRegression()
.setInitialModelData(initDenseModel)