Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add java priority queue, set, deque, collection coders #5520

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ package com.spotify.scio.coders.instances
import java.io.{InputStream, OutputStream}
import java.math.{BigDecimal, BigInteger}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, LocalTime, Period}
import java.util.UUID
import com.spotify.scio.IsJavaBean
import com.spotify.scio.coders.{Coder, CoderGrammar}
import com.spotify.scio.schemas.Schema
import com.spotify.scio.transforms.BaseAsyncLookupDoFn
import com.spotify.scio.util.ScioUtil
import org.apache.beam.sdk.coders.{Coder => _, _}
import org.apache.beam.sdk.coders.Coder.NonDeterministicException
import org.apache.beam.sdk.coders.{Coder => BCoder, _}
import org.apache.beam.sdk.schemas.SchemaCoder
import org.apache.beam.sdk.values.TypeDescriptor
import org.apache.beam.sdk.{coders => bcoders}
Expand All @@ -42,15 +42,55 @@ private[coders] object VoidCoder extends AtomicCoder[Void] {
override def structuralValue(value: Void): AnyRef = AnyRef
}

/**
 * Beam coder for `java.util.ArrayList[T]`.
 *
 * Extends `IterableLikeCoder` with `bc` as the element coder; this class only decides how the
 * decoded elements are turned back into an `ArrayList` and how equality/determinism are reported.
 */
final private[coders] class JArrayListCoder[T](bc: BCoder[T])
    extends IterableLikeCoder[T, java.util.ArrayList[T]](bc, "ArrayList") {

  /** Returns the decoded list as-is when it already is an `ArrayList`, otherwise copies it. */
  override def decodeToIterable(decodedElements: java.util.List[T]): java.util.ArrayList[T] =
    decodedElements match {
      case existing: java.util.ArrayList[T] => existing
      case other                            => new java.util.ArrayList[T](other)
    }

  /** Reports whatever the element coder reports for `consistentWithEquals`. */
  override def consistentWithEquals(): Boolean = {
    val elementCoder = getElemCoder
    elementCoder.consistentWithEquals()
  }

  /** Delegates the determinism check to the element coder. */
  override def verifyDeterministic(): Unit = {
    val reason = "JArrayListCoder element coder must be deterministic"
    BCoder.verifyDeterministic(this, reason, getElemCoder)
  }
}

/**
 * Beam coder for `java.util.PriorityQueue[T]`.
 *
 * Extends `IterableLikeCoder` with `bc` as the element coder. Decoded queues are always rebuilt
 * with the `ordering` supplied at construction.
 */
final private[coders] class JPriorityQueueCoder[T](
  bc: BCoder[T],
  ordering: Ordering[T] // use Ordering instead of Comparator for serialization
) extends IterableLikeCoder[T, java.util.PriorityQueue[T]](bc, "PriorityQueue") {

  /** Creates a new queue ordered by `ordering` and inserts every decoded element into it. */
  override def decodeToIterable(decodedElements: java.util.List[T]): java.util.PriorityQueue[T] = {
    val queue = new java.util.PriorityQueue[T](ordering)
    queue.addAll(decodedElements)
    queue
  }

  /** Encodes the queue with the inherited implementation. */
  override def encode(value: java.util.PriorityQueue[T], os: OutputStream): Unit =
    super.encode(value, os)

  /** Always fails: this coder never claims to be deterministic. */
  override def verifyDeterministic(): Unit = {
    val reason = "Ordering of elements in a priority queue may be non-deterministic."
    throw new NonDeterministicException(this, reason)
  }
}

//
// Java Coders
//
trait JavaCoders extends CoderGrammar with JavaBeanCoders {
/** Implicit coder for `Void`, backed by the `VoidCoder` Beam coder. */
implicit lazy val voidCoder: Coder[Void] = beam[Void](VoidCoder)

implicit lazy val uuidCoder: Coder[UUID] =
implicit lazy val uuidCoder: Coder[java.util.UUID] =
xmap(Coder[(Long, Long)])(
{ case (msb, lsb) => new UUID(msb, lsb) },
{ case (msb, lsb) => new java.util.UUID(msb, lsb) },
uuid => (uuid.getMostSignificantBits, uuid.getLeastSignificantBits)
)

Expand All @@ -63,11 +103,26 @@ trait JavaCoders extends CoderGrammar with JavaBeanCoders {
/** Coder for `java.lang.Iterable[T]`: wraps the element coder in Beam's `IterableCoder`. */
implicit def jIterableCoder[T](implicit c: Coder[T]): Coder[java.lang.Iterable[T]] =
  transform(c) { elemCoder =>
    beam(bcoders.IterableCoder.of(elemCoder))
  }

/** Coder for `java.util.Collection[T]`: wraps the element coder in Beam's `CollectionCoder`. */
implicit def jCollectionCoder[T](implicit c: Coder[T]): Coder[java.util.Collection[T]] =
  transform(c) { elemCoder =>
    beam(bcoders.CollectionCoder.of(elemCoder))
  }

/** Coder for `java.util.List[T]`: wraps the element coder in Beam's `ListCoder`. */
implicit def jListCoder[T](implicit c: Coder[T]): Coder[java.util.List[T]] =
  transform(c) { elemCoder =>
    beam(bcoders.ListCoder.of(elemCoder))
  }

implicit def jArrayListCoder[T](implicit c: Coder[T]): Coder[java.util.ArrayList[T]] =
xmap(jListCoder[T])(new java.util.ArrayList(_), identity)
transform(c)(bc => beam(new JArrayListCoder[T](bc)))

/** Coder for `java.util.Set[T]`: wraps the element coder in Beam's `SetCoder`. */
implicit def jSetCoder[T](implicit c: Coder[T]): Coder[java.util.Set[T]] =
  transform(c) { elemCoder =>
    beam(bcoders.SetCoder.of(elemCoder))
  }

/** Coder for `java.util.Deque[T]`: wraps the element coder in Beam's `DequeCoder`. */
implicit def jDequeCoder[T](implicit c: Coder[T]): Coder[java.util.Deque[T]] =
  transform(c) { elemCoder =>
    beam(bcoders.DequeCoder.of(elemCoder))
  }

/**
 * Coder for `java.util.PriorityQueue[T]`: builds a `JPriorityQueueCoder` from the element coder
 * and the implicit `Ordering[T]`.
 */
implicit def jPriorityQueueCoder[T](implicit
  c: Coder[T],
  ord: Ordering[T]
): Coder[java.util.PriorityQueue[T]] =
  transform(c) { elemCoder =>
    beam(new JPriorityQueueCoder[T](elemCoder, ord))
  }

/** Coder for `java.util.Map[K, V]`: combines the key and value coders in Beam's `MapCoder`. */
implicit def jMapCoder[K, V](implicit ck: Coder[K], cv: Coder[V]): Coder[java.util.Map[K, V]] =
  transform(ck) { keyCoder =>
    transform(cv) { valueCoder =>
      beam(bcoders.MapCoder.of(keyCoder, valueCoder))
    }
  }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,19 @@ private[coders] class MutableSetCoder[T](bc: BCoder[T]) extends SeqLikeCoder[m.S
}

/**
 * Coder for `SortedSet[T]` on top of `SeqLikeCoder`. Decoding fills a `SortedSet` builder
 * created with the implicit `Ordering[T]` in scope for this coder.
 */
private class SortedSetCoder[T: Ordering](bc: BCoder[T]) extends SeqLikeCoder[SortedSet, T](bc) {

  /** Encodes the set with the inherited implementation. */
  override def encode(value: SortedSet[T], os: OutputStream): Unit =
    super.encode(value, os)

  /** Decodes the elements from `inStream` into a new `SortedSet`. */
  override def decode(inStream: InputStream): SortedSet[T] = {
    val builder = SortedSet.newBuilder[T]
    decode(inStream, builder)
  }
}

private class MutablePriorityQueueCoder[T: Ordering](bc: BCoder[T])
private[coders] class MutablePriorityQueueCoder[T: Ordering](bc: BCoder[T])
extends SeqLikeCoder[m.PriorityQueue, T](bc) {
override def consistentWithEquals(): Boolean = false // PriorityQueue does not define equality
override def encode(value: m.PriorityQueue[T], os: OutputStream): Unit =
super.encode(value, os)
override def decode(inStream: InputStream): m.PriorityQueue[T] =
decode(inStream, m.PriorityQueue.newBuilder[T])
}
Expand Down Expand Up @@ -361,13 +367,8 @@ private class MutableMapCoder[K, V](kc: BCoder[K], vc: BCoder[V])
private[coders] class SortedMapCoder[K: Ordering, V](kc: BCoder[K], vc: BCoder[V])
extends MapLikeCoder[K, V, SortedMap](kc, vc) {

override def encode(value: SortedMap[K, V], os: OutputStream): Unit = {
require(
value.ordering == Ordering[K],
"SortedMap ordering does not match SortedMapCoder ordering"
)
override def encode(value: SortedMap[K, V], os: OutputStream): Unit =
super.encode(value, os)
}

override def decode(is: InputStream): SortedMap[K, V] =
decode(is, SortedMap.newBuilder[K, V])
Expand Down
138 changes: 108 additions & 30 deletions scio-core/src/test/scala/com/spotify/scio/coders/CoderTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.apache.beam.sdk.util.SerializableUtils
import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder
import org.apache.beam.sdk.schemas.SchemaCoder
import org.apache.commons.collections.IteratorUtils
import org.apache.commons.io.output.NullOutputStream
import org.scalactic.Equality
import org.scalatest.Assertion
import org.scalatest.exceptions.TestFailedException
import org.scalatest.flatspec.AnyFlatSpec
Expand All @@ -43,9 +45,9 @@ import java.io.{ByteArrayInputStream, ObjectOutputStream, ObjectStreamClass}
import java.nio.charset.Charset
import java.time._
import java.util.UUID

import scala.collection.{mutable => mut}
import scala.collection.compat._
import scala.collection.compat.immutable.ArraySeq
import scala.collection.immutable.SortedMap
import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -154,6 +156,22 @@ final class CoderTest extends AnyFlatSpec with Matchers {
beOfType[CoderTransform[_, _]] and
materializeTo[ArrayCoder[_]] and
beFullyCompliantNotConsistentWithEquals()

{
implicit val pqOrd: Ordering[String] = Ordering.String.on(_.reverse)
val pq = new mut.PriorityQueue[String]()(pqOrd)
pq ++= s

implicit val pqEq: Equality[mut.PriorityQueue[String]] = {
case (a: mut.PriorityQueue[String], b: mut.PriorityQueue[_]) => a.toList == b.toList
case _ => false
}

pq coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[MutablePriorityQueueCoder[_]] and
beFullyCompliantNotConsistentWithEquals()
}
}

it should "support Scala enumerations" in {
Expand Down Expand Up @@ -321,29 +339,98 @@ final class CoderTest extends AnyFlatSpec with Matchers {
}

it should "support Java collections" in {
import java.util.{ArrayList => jArrayList, List => jList, Map => jMap}
val is = 1 to 10
val s: jList[String] = is.map(_.toString).asJava
val m: jMap[String, Int] = is
.map(v => v.toString -> v)
.toMap
.asJava
val arrayList = new jArrayList(s)
import java.lang.{Iterable => JIterable}
import java.util.{
ArrayList => JArrayList,
Collection => JCollection,
List => JList,
Set => JSet,
Map => JMap,
PriorityQueue => JPriorityQueue
}

s coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.ListCoder[_]] and
beFullyCompliant()
val elems = (1 to 10).map(_.toString)

m coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[org.apache.beam.sdk.coders.MapCoder[_, _]] and
beFullyCompliantNonDeterministic()
{
val i: JIterable[String] = (elems: Iterable[String]).asJava
implicit val iEq: Equality[JIterable[String]] = {
case (xs: JIterable[String], ys: JIterable[String]) =>
IteratorUtils.toArray(xs.iterator()) sameElements IteratorUtils.toArray(ys.iterator())
case _ => false
}

arrayList coderShould roundtrip() and
beOfType[Transform[_, _]] and
materializeToTransformOf[beam.ListCoder[_]] and
beFullyCompliant()
i coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.IterableCoder[_]] and
beNotConsistentWithEquals()
}

{
val c: JCollection[String] = elems.asJavaCollection
implicit val iEq: Equality[JCollection[String]] = {
case (xs: JCollection[String], ys: JCollection[String]) =>
IteratorUtils.toArray(xs.iterator()) sameElements IteratorUtils.toArray(ys.iterator())
case _ => false
}
c coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.CollectionCoder[_]] and
beNotConsistentWithEquals()
}

{
val l: JList[String] = elems.asJava
l coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.ListCoder[_]] and
beFullyCompliant()
}

{
val al: JArrayList[String] = new JArrayList(elems.asJava)
al coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[JArrayListCoder[_]] and
beFullyCompliant()
}

{
val s: JSet[String] = elems.toSet.asJava
s coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.SetCoder[_]] and
structuralValueConsistentWithEquals() and
beSerializable()
}

{
val m: JMap[String, Int] = (1 to 10)
.map(v => v.toString -> v)
.toMap
.asJava
m coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[beam.MapCoder[_, _]] and
beFullyCompliantNonDeterministic()
}

{
implicit val pqOrd: Ordering[String] = Ordering.String.on(_.reverse)
val pq = new JPriorityQueue[String](pqOrd)
pq.addAll(elems.asJavaCollection)

implicit val pqEq: Equality[java.util.PriorityQueue[String]] = {
case (a: JPriorityQueue[String], b: JPriorityQueue[_]) =>
a.toArray sameElements b.toArray
case _ => false
}

pq coderShould roundtrip() and
beOfType[CoderTransform[_, _]] and
materializeTo[JPriorityQueueCoder[_]] and
beSerializable() and
structuralValueConsistentWithEquals()
}
}

it should "Derive serializable coders" in {
Expand Down Expand Up @@ -906,15 +993,6 @@ final class CoderTest extends AnyFlatSpec with Matchers {
)
}

it should "not accept SortedMap when ordering doesn't match with coder" in {
val sm = SortedMap(1 -> "1", 2 -> "2")(Ordering[Int].reverse)
// implicit SortedMapCoder will use implicit default Int ordering
val e = the[IllegalArgumentException] thrownBy {
sm coderShould roundtrip()
}
e.getMessage shouldBe "requirement failed: SortedMap ordering does not match SortedMapCoder ordering"
}

/*
* Case class nested inside another class. Do not move outside
* */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,26 @@ object CoderAssertions {

/** Passes all checks on Beam coder */
def beFullyCompliant[T <: Object: ClassTag](): CoderAssertion[T] = ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beDeterministic()(ctx)
}

def beFullyCompliantNonDeterministic[T <: Object: ClassTag](): CoderAssertion[T] =
ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beNonDeterministic()(ctx)
}

def beFullyCompliantNotConsistentWithEquals[T <: Object: ClassTag](): CoderAssertion[T] =
ctx => {
structuralValueConsistentWithEquals()(ctx)
beSerializable()(ctx)
structuralValueConsistentWithEquals()(ctx)
beNotConsistentWithEquals()(ctx)
bytesCountTested[T]().apply(ctx)
beDeterministic()(ctx)
Expand Down