blob: 2a4dcf3e242008e23631435c23fd0e8115f03763 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.functions
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
import org.junit.Assert.fail
import org.junit.{After, Before, Test, Rule}
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.apache.flink.api.scala._
import org.apache.flink.api.common.InvalidProgramException
/* The test cases are originally from the Apache Spark project. Like the ClosureCleaner itself. */
@RunWith(classOf[Parameterized])
class ClosureCleanerITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
val _tempFolder = new TemporaryFolder()
var resultPath: String = _
var result: String = _
@Rule
def tempFolder = _tempFolder
@Before
def before(): Unit = {
resultPath = tempFolder.newFile().toURI.toString
}
@After
def after(): Unit = {
TestBaseUtils.compareResultsByLinesInMemory(result, resultPath)
}
@Test
def testObject(): Unit = {
TestObject.run(resultPath)
result = "30"
}
@Test
def testClass(): Unit = {
val obj = new TestClass
obj.run(resultPath)
result = "30"
}
@Test
def testClassWithoutDefaulConstructor(): Unit = {
val obj = new TestClassWithoutDefaultConstructor(5)
obj.run(resultPath)
result = "30"
}
@Test
def testClassWithoutFieldAccess(): Unit = {
val obj = new TestClassWithoutFieldAccess
obj.run(resultPath)
result = "30" // 6 + 7 + 8 + 9
}
@Test
def testObjectWithNesting(): Unit = {
TestObjectWithNesting.run(resultPath)
result = "27"
}
@Test
def testClassWithNesting(): Unit = {
val obj = new TestClassWithNesting(1)
obj.run(resultPath)
result = "27"
}
@Test
def testObjectWithBogusReturns(): Unit = {
TestObjectWithBogusReturns.run(resultPath)
result = "1"
}
@Test
def testObjectWithNestedReturns(): Unit = {
TestObjectWithNestedReturns.run(resultPath)
result = "1"
}
}
// A non-serializable class we create in closures to make sure that we aren't
// keeping references to unneeded variables from our outer closures.
class NonSerializable {}
object TestObject {
def run(resultPath: String): Unit = {
var nonSer = new NonSerializable
val x = 5
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 3, 4)
nums.map(_ + x).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
class TestClass extends Serializable {
var x = 5
def getX = x
def run(resultPath: String): Unit = {
var nonSer = new NonSerializable
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 3, 4)
nums.map(_ + getX).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
class TestClassWithoutDefaultConstructor(x: Int) extends Serializable {
def getX = x
def run(resultPath: String): Unit = {
var nonSer = new NonSerializable
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 3, 4)
nums.map(_ + getX).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
// This class is not serializable, but we aren't using any of its fields in our
// closures, so they won't have a $outer pointing to it and should still work.
class TestClassWithoutFieldAccess {
var nonSer = new NonSerializable
def run(resultPath: String): Unit = {
var nonSer2 = new NonSerializable
val x = 5
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 3, 4)
nums.map(_ + x).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
object TestObjectWithBogusReturns {
def run(resultPath: String): Unit = {
var nonSer2 = new NonSerializable
val x = 5
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1)
// this return is invalid since it will transfer control outside the closure
try {
nums.map { x => return 1; x * 2}.print()
} catch {
case inv: ReturnStatementInClosureException => // all good
case _: Throwable => fail("Bogus return statement not detected.")
}
nums.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
object TestObjectWithNestedReturns {
def run(resultPath: String): Unit = {
var nonSer2 = new NonSerializable
val x = 5
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1)
nums.map { x =>
// this return is fine since it will not transfer control outside the closure
def foo(): Int = { return 5; 1 }
foo()
}
nums.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
object TestObjectWithNesting {
def run(resultPath: String): Unit = {
var nonSer = new NonSerializable
var answer = 0
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 1)
var y = 1
val result = nums.iterate(4) {
in =>
var nonSer2 = new NonSerializable
var x = y + 3
in.map(_ + x + y).reduce(_ + _).withBroadcastSet(nums, "nums")
}
result.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}
class TestClassWithNesting(val y: Int) extends Serializable {
def getY = y
def run(resultPath: String): Unit = {
var nonSer = new NonSerializable
var answer = 0
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 1)
val result = nums.iterate(4) {
in =>
var nonSer2 = new NonSerializable
var x = y + 3
in.map(_ + x + getY).reduce(_ + _).withBroadcastSet(nums, "nums")
}
result.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
}
}