Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add java priority queue, set, deque, collection coders #5520

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ package com.spotify.scio.coders.instances
import java.io.{InputStream, OutputStream}
import java.math.{BigDecimal, BigInteger}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import java.util.UUID
import com.spotify.scio.IsJavaBean
import com.spotify.scio.coders.{Coder, CoderGrammar}
import com.spotify.scio.schemas.Schema
import com.spotify.scio.transforms.BaseAsyncLookupDoFn
import com.spotify.scio.util.ScioUtil
import org.apache.beam.sdk.coders.{Coder => _, _}
import org.apache.beam.sdk.coders.Coder.NonDeterministicException
import org.apache.beam.sdk.coders.{Coder => BCoder, _}
import org.apache.beam.sdk.schemas.SchemaCoder
import org.apache.beam.sdk.values.TypeDescriptor
import org.apache.beam.sdk.{coders => bcoders}
Expand All @@ -42,15 +42,60 @@ private[coders] object VoidCoder extends AtomicCoder[Void] {
override def structuralValue(value: Void): AnyRef = AnyRef
}

final private[coders] class JArrayListCoder[T](bc: BCoder[T])
extends IterableLikeCoder[T, java.util.ArrayList[T]](bc, "ArrayList") {

override def decodeToIterable(decodedElements: java.util.List[T]): java.util.ArrayList[T] =
decodedElements match {
case al: java.util.ArrayList[T] => al
case _ => new java.util.ArrayList[T](decodedElements)
}

override def consistentWithEquals(): Boolean = getElemCoder.consistentWithEquals()

override def verifyDeterministic(): Unit =
BCoder.verifyDeterministic(
this,
"JArrayListCoder element coder must be deterministic",
getElemCoder
)
}

final private[coders] class JPriorityQueueCoder[T](
bc: BCoder[T],
ordering: Ordering[T] // use Ordering instead of Comparator for serialization
) extends IterableLikeCoder[T, java.util.PriorityQueue[T]](bc, "PriorityQueue") {

override def decodeToIterable(decodedElements: java.util.List[T]): java.util.PriorityQueue[T] = {
val pq = new java.util.PriorityQueue[T](ordering)
pq.addAll(decodedElements)
pq
}

override def encode(value: java.util.PriorityQueue[T], os: OutputStream): Unit = {
require(
value.comparator() == ordering,
"PriorityQueue comparator does not match JPriorityQueueCoder comparator"
)
super.encode(value, os)
}

override def verifyDeterministic(): Unit =
throw new NonDeterministicException(
this,
"Ordering of elements in a priority queue may be non-deterministic."
)
}

//
// Java Coders
//
trait JavaCoders extends CoderGrammar with JavaBeanCoders {
implicit lazy val voidCoder: Coder[Void] = beam[Void](VoidCoder)

implicit lazy val uuidCoder: Coder[UUID] =
implicit lazy val uuidCoder: Coder[java.util.UUID] =
xmap(Coder[(Long, Long)])(
{ case (msb, lsb) => new UUID(msb, lsb) },
{ case (msb, lsb) => new java.util.UUID(msb, lsb) },
uuid => (uuid.getMostSignificantBits, uuid.getLeastSignificantBits)
)

Expand All @@ -63,11 +108,26 @@ trait JavaCoders extends CoderGrammar with JavaBeanCoders {
implicit def jIterableCoder[T](implicit c: Coder[T]): Coder[java.lang.Iterable[T]] =
transform(c)(bc => beam(bcoders.IterableCoder.of(bc)))

implicit def jCollectionCoder[T](implicit c: Coder[T]): Coder[java.util.Collection[T]] =
transform(c)(bc => beam(bcoders.CollectionCoder.of(bc)))

implicit def jListCoder[T](implicit c: Coder[T]): Coder[java.util.List[T]] =
transform(c)(bc => beam(bcoders.ListCoder.of(bc)))

implicit def jArrayListCoder[T](implicit c: Coder[T]): Coder[java.util.ArrayList[T]] =
xmap(jListCoder[T])(new java.util.ArrayList(_), identity)
transform(c)(bc => beam(new JArrayListCoder[T](bc)))

implicit def jSetCoder[T](implicit c: Coder[T]): Coder[java.util.Set[T]] =
transform(c)(bc => beam(bcoders.SetCoder.of(bc)))

implicit def jDequeCoder[T](implicit c: Coder[T]): Coder[java.util.Deque[T]] =
transform(c)(bc => beam(bcoders.DequeCoder.of(bc)))

implicit def jPriorityQueueCoder[T](implicit
c: Coder[T],
ord: Ordering[T]
): Coder[java.util.PriorityQueue[T]] =
transform(c)(bc => beam(new JPriorityQueueCoder[T](bc, ord)))

implicit def jMapCoder[K, V](implicit ck: Coder[K], cv: Coder[V]): Coder[java.util.Map[K, V]] =
transform(ck)(bk => transform(cv)(bv => beam(bcoders.MapCoder.of(bk, bv))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,29 @@ private[coders] class MutableSetCoder[T](bc: BCoder[T]) extends SeqLikeCoder[m.S
}

private class SortedSetCoder[T: Ordering](bc: BCoder[T]) extends SeqLikeCoder[SortedSet, T](bc) {

override def encode(value: SortedSet[T], os: OutputStream): Unit = {
require(
value.ordering == Ordering[T],
"SortedSet ordering does not match SortedSetCoder ordering"
)
super.encode(value, os)
}

override def decode(inStream: InputStream): SortedSet[T] =
decode(inStream, SortedSet.newBuilder[T])
}

private class MutablePriorityQueueCoder[T: Ordering](bc: BCoder[T])
private[coders] class MutablePriorityQueueCoder[T: Ordering](bc: BCoder[T])
extends SeqLikeCoder[m.PriorityQueue, T](bc) {
override def consistentWithEquals(): Boolean = false // PriorityQueue does not define equality
override def encode(value: m.PriorityQueue[T], os: OutputStream): Unit = {
require(
value.ord == Ordering[T],
"PriorityQueue ordering does not match MutablePriorityQueueCoder ordering"
)
super.encode(value, os)
}
override def decode(inStream: InputStream): m.PriorityQueue[T] =
decode(inStream, m.PriorityQueue.newBuilder[T])
}
Expand Down
131 changes: 110 additions & 21 deletions scio-core/src/test/scala/com/spotify/scio/coders/CoderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.apache.beam.sdk.util.SerializableUtils
import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder
import org.apache.beam.sdk.schemas.SchemaCoder
import org.apache.commons.collections.IteratorUtils
import org.apache.commons.io.output.NullOutputStream
import org.scalactic.Equality
import org.scalatest.Assertion
import org.scalatest.exceptions.TestFailedException
import org.scalatest.flatspec.AnyFlatSpec
Expand All @@ -43,9 +45,9 @@ import java.io.{ByteArrayInputStream, ObjectOutputStream, ObjectStreamClass}
import java.nio.charset.Charset
import java.time._
import java.util.UUID

import scala.collection.{mutable => mut}
import scala.collection.compat._
import scala.collection.compat.immutable.ArraySeq
import scala.collection.immutable.SortedMap
import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -154,6 +156,23 @@ final class CoderTest extends AnyFlatSpec with Matchers {
beOfType[CoderTransform[_, _]] and
materializeTo[ArrayCoder[_]] and
beFullyCompliantNotConsistentWithEquals()

{
// custom ordering must have stable equal after serialization
implicit val pqOrd: Ordering[String] = FlippedStringOrdering
Copy link
Contributor

@RustedBones RustedBones Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really happy about this

val pq = new mut.PriorityQueue[String]()(pqOrd)
pq ++= s

implicit val pqEq: Equality[mut.PriorityQueue[String]] = {
case (a: mut.PriorityQueue[String], b: mut.PriorityQueue[_]) => a.toList == b.toList
case _ => false
}

pq coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[MutablePriorityQueueCoder[_]] and
beFullyCompliantNotConsistentWithEquals()
}
}

it should "support Scala enumerations" in {
Expand Down Expand Up @@ -321,29 +340,99 @@ final class CoderTest extends AnyFlatSpec with Matchers {
}

it should "support Java collections" in {
import java.util.{ArrayList => jArrayList, List => jList, Map => jMap}
val is = 1 to 10
val s: jList[String] = is.map(_.toString).asJava
val m: jMap[String, Int] = is
.map(v => v.toString -> v)
.toMap
.asJava
val arrayList = new jArrayList(s)
import java.lang.{Iterable => JIterable}
import java.util.{
ArrayList => JArrayList,
Collection => JCollection,
List => JList,
Set => JSet,
Map => JMap,
PriorityQueue => JPriorityQueue
}

s coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.ListCoder[_]] and
beFullyCompliant()
val elems = (1 to 10).map(_.toString)

m coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[org.apache.beam.sdk.coders.MapCoder[_, _]] and
beFullyCompliantNonDeterministic()
{
val i: JIterable[String] = (elems: Iterable[String]).asJava
implicit val iEq: Equality[JIterable[String]] = {
case (xs: JIterable[String], ys: JIterable[String]) =>
IteratorUtils.toArray(xs.iterator()) sameElements IteratorUtils.toArray(ys.iterator())
case _ => false
}

arrayList coderShould roundtrip() and
beOfType[Transform[_, _]] and
materializeToTransformOf[beam.ListCoder[_]] and
beFullyCompliant()
i coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.IterableCoder[_]] and
beNotConsistentWithEquals()
}

{
val c: JCollection[String] = elems.asJavaCollection
implicit val iEq: Equality[JCollection[String]] = {
case (xs: JCollection[String], ys: JCollection[String]) =>
IteratorUtils.toArray(xs.iterator()) sameElements IteratorUtils.toArray(ys.iterator())
case _ => false
}
c coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.CollectionCoder[_]] and
beNotConsistentWithEquals()
}

{
val l: JList[String] = elems.asJava
l coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.ListCoder[_]] and
beFullyCompliant()
}

{
val al: JArrayList[String] = new JArrayList(elems.asJava)
al coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[JArrayListCoder[_]] and
beFullyCompliant()
}

{
val s: JSet[String] = elems.toSet.asJava
s coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.SetCoder[_]] and
structuralValueConsistentWithEquals() and
beSerializable()
}

{
val m: JMap[String, Int] = (1 to 10)
.map(v => v.toString -> v)
.toMap
.asJava
m coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.MapCoder[_, _]] and
beFullyCompliantNonDeterministic()
}

{
// custom ordering must have stable equal after serialization
implicit val pqOrd: Ordering[String] = FlippedStringOrdering
val pq = new JPriorityQueue[String](pqOrd)
pq.addAll(elems.asJavaCollection)

implicit val pqEq: Equality[java.util.PriorityQueue[String]] = {
case (a: JPriorityQueue[String], b: JPriorityQueue[_]) =>
a.toArray sameElements b.toArray
case _ => false
}

pq coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[JPriorityQueueCoder[_]] and
beSerializable() and
structuralValueConsistentWithEquals()
}
}

it should "Derive serializable coders" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import scala.jdk.CollectionConverters._

object CoderTestUtils {

object FlippedStringOrdering extends Ordering[String] {
override def compare(x: String, y: String): Int = x.reverse.compareTo(y.reverse)
}

def testRoundTrip[T](coder: BCoder[T], value: T): Boolean =
testRoundTrip(coder, coder, value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,26 @@ object CoderAssertions {

/** Passes all checks on Beam coder */
def beFullyCompliant[T <: Object: ClassTag](): CoderAssertion[T] = ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beDeterministic()(ctx)
}

def beFullyCompliantNonDeterministic[T <: Object: ClassTag](): CoderAssertion[T] =
ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beNonDeterministic()(ctx)
}

def beFullyCompliantNotConsistentWithEquals[T <: Object: ClassTag](): CoderAssertion[T] =
ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beNotConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beDeterministic()(ctx)
Expand Down
Loading