[SPARK-45618][CORE] Remove BaseErrorHandler
### What changes were proposed in this pull request?

This patch removes the workaround trait `BaseErrorHandler`, which was added long ago (SPARK-25535) to work around [CRYPTO-141](https://issues.apache.org/jira/browse/CRYPTO-141), an issue that was fixed five years ago.
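
In practice, each of the four factory helpers in `CryptoStreamUtils` now returns the commons-crypto stream directly instead of wrapping it. As a minimal sketch (not verbatim source: the method signature and surrounding declarations are assumed from the existing helper, while the body follows the diff below), the output-stream case looks roughly like this:

```scala
// Sketch only. This method lives inside CryptoStreamUtils and relies on its
// private helpers (CryptoParams, createInitializationVector); the signature is
// assumed here, and the body mirrors the diff below.
def createCryptoOutputStream(
    os: OutputStream,
    sparkConf: SparkConf,
    key: Array[Byte]): OutputStream = {
  val params = new CryptoParams(key, sparkConf)
  val iv = createInitializationVector(params.conf)
  os.write(iv)  // the IV is still written to the underlying stream first
  // Previously this CryptoOutputStream was wrapped in an ErrorHandlingOutputStream;
  // it is now returned as-is.
  new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
    new IvParameterSpec(iv))
}
```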

### Why are the changes needed?

Removes workaround code that is no longer necessary. The wrappers existed only to guard against CRYPTO-141, where commons-crypto could throw `InternalError` and leave its stream wrappers in an unusable state, and the code comment on `BaseErrorHandler` already said it should be removed once that issue was fixed and Spark upgraded its commons-crypto dependency.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#43468 from viirya/remove_workaround.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
viirya authored and dongjoon-hyun committed Oct 24, 2023
1 parent 2daa66e commit 4385273
Showing 2 changed files with 10 additions and 162 deletions.
core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala (9 additions, 128 deletions)
@@ -16,7 +16,7 @@
*/
package org.apache.spark.security

- import java.io.{Closeable, InputStream, IOException, OutputStream}
+ import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Properties
@@ -55,10 +55,8 @@ private[spark] object CryptoStreamUtils extends Logging {
val params = new CryptoParams(key, sparkConf)
val iv = createInitializationVector(params.conf)
os.write(iv)
- new ErrorHandlingOutputStream(
- new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
- new IvParameterSpec(iv)),
- os)
+ new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec,
+ new IvParameterSpec(iv))
}

/**
@@ -73,10 +71,8 @@ private[spark] object CryptoStreamUtils extends Logging {
val helper = new CryptoHelperChannel(channel)

helper.write(ByteBuffer.wrap(iv))
- new ErrorHandlingWritableChannel(
- new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
- new IvParameterSpec(iv)),
- helper)
+ new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec,
+ new IvParameterSpec(iv))
}

/**
@@ -89,10 +85,8 @@ private[spark] object CryptoStreamUtils extends Logging {
val iv = new Array[Byte](IV_LENGTH_IN_BYTES)
ByteStreams.readFully(is, iv)
val params = new CryptoParams(key, sparkConf)
- new ErrorHandlingInputStream(
- new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
- new IvParameterSpec(iv)),
- is)
+ new CryptoInputStream(params.transformation, params.conf, is, params.keySpec,
+ new IvParameterSpec(iv))
}

/**
@@ -107,10 +101,8 @@ private[spark] object CryptoStreamUtils extends Logging {
JavaUtils.readFully(channel, buf)

val params = new CryptoParams(key, sparkConf)
- new ErrorHandlingReadableChannel(
- new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
- new IvParameterSpec(iv)),
- channel)
+ new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec,
+ new IvParameterSpec(iv))
}

def toCryptoConf(conf: SparkConf): Properties = {
@@ -166,117 +158,6 @@ private[spark] object CryptoStreamUtils extends Logging {

}

- /**
- * SPARK-25535. The commons-crypto library will throw InternalError if something goes
- * wrong, and leave bad state behind in the Java wrappers, so it's not safe to use them
- * afterwards. This wrapper detects that situation and avoids further calls into the
- * commons-crypto code, while still allowing the underlying streams to be closed.
- *
- * This should be removed once CRYPTO-141 is fixed (and Spark upgrades its commons-crypto
- * dependency).
- */
- trait BaseErrorHandler extends Closeable {
-
- private var closed = false
-
- /** The encrypted stream that may get into an unhealthy state. */
- protected def cipherStream: Closeable
-
- /**
- * The underlying stream that is being wrapped by the encrypted stream, so that it can be
- * closed even if there's an error in the crypto layer.
- */
- protected def original: Closeable
-
- protected def safeCall[T](fn: => T): T = {
- if (closed) {
- throw new IOException("Cipher stream is closed.")
- }
- try {
- fn
- } catch {
- case ie: InternalError =>
- closed = true
- original.close()
- throw ie
- }
- }
-
- override def close(): Unit = {
- if (!closed) {
- cipherStream.close()
- }
- }
-
- }
-
- // Visible for testing.
- class ErrorHandlingReadableChannel(
- protected val cipherStream: ReadableByteChannel,
- protected val original: ReadableByteChannel)
- extends ReadableByteChannel with BaseErrorHandler {
-
- override def read(src: ByteBuffer): Int = safeCall {
- cipherStream.read(src)
- }
-
- override def isOpen(): Boolean = cipherStream.isOpen()
-
- }
-
- private class ErrorHandlingInputStream(
- protected val cipherStream: InputStream,
- protected val original: InputStream)
- extends InputStream with BaseErrorHandler {
-
- override def read(b: Array[Byte]): Int = safeCall {
- cipherStream.read(b)
- }
-
- override def read(b: Array[Byte], off: Int, len: Int): Int = safeCall {
- cipherStream.read(b, off, len)
- }
-
- override def read(): Int = safeCall {
- cipherStream.read()
- }
- }
-
- private class ErrorHandlingWritableChannel(
- protected val cipherStream: WritableByteChannel,
- protected val original: WritableByteChannel)
- extends WritableByteChannel with BaseErrorHandler {
-
- override def write(src: ByteBuffer): Int = safeCall {
- cipherStream.write(src)
- }
-
- override def isOpen(): Boolean = cipherStream.isOpen()
-
- }
-
- private class ErrorHandlingOutputStream(
- protected val cipherStream: OutputStream,
- protected val original: OutputStream)
- extends OutputStream with BaseErrorHandler {
-
- override def flush(): Unit = safeCall {
- cipherStream.flush()
- }
-
- override def write(b: Array[Byte]): Unit = safeCall {
- cipherStream.write(b)
- }
-
- override def write(b: Array[Byte], off: Int, len: Int): Unit = safeCall {
- cipherStream.write(b, off, len)
- }
-
- override def write(b: Int): Unit = safeCall {
- cipherStream.write(b)
- }
- }

private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) {

val keySpec = new SecretKeySpec(key, "AES")
core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala (1 addition, 34 deletions)
@@ -17,15 +17,12 @@
package org.apache.spark.security

import java.io._
- import java.nio.ByteBuffer
- import java.nio.channels.{Channels, ReadableByteChannel}
+ import java.nio.channels.Channels
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files
import java.util.{Arrays, Random, UUID}

import com.google.common.io.ByteStreams
- import org.mockito.ArgumentMatchers.any
- import org.mockito.Mockito._

import org.apache.spark._
import org.apache.spark.internal.config._
@@ -167,36 +164,6 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
}
}

test("error handling wrapper") {
val wrapped = mock(classOf[ReadableByteChannel])
val decrypted = mock(classOf[ReadableByteChannel])
val errorHandler = new CryptoStreamUtils.ErrorHandlingReadableChannel(decrypted, wrapped)

when(decrypted.read(any(classOf[ByteBuffer])))
.thenThrow(new IOException())
.thenThrow(new InternalError())
.thenReturn(1)

val out = ByteBuffer.allocate(1)
intercept[IOException] {
errorHandler.read(out)
}
intercept[InternalError] {
errorHandler.read(out)
}

val e = intercept[IOException] {
errorHandler.read(out)
}
assert(e.getMessage().contains("is closed"))
errorHandler.close()

verify(decrypted, times(2)).read(any(classOf[ByteBuffer]))
verify(wrapped, never()).read(any(classOf[ByteBuffer]))
verify(decrypted, never()).close()
verify(wrapped, times(1)).close()
}

private def createConf(extra: (String, String)*): SparkConf = {
val conf = new SparkConf()
extra.foreach { case (k, v) => conf.set(k, v) }
