Merged (changes from 1 commit)
@@ -30,6 +30,9 @@ import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.catalyst.KotlinReflection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.collection.Seq
@@ -42,6 +45,7 @@ import java.time.Instant
import java.time.LocalDate
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.isSubclassOf
@@ -159,6 +163,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
reduceGroups(ReduceFunction(func))
.map { t -> t._1 to t._2 }

inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> = flatMapGroups(
FlatMapGroupsFunction(func),
encoder<U>()
)

fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value)
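
// Illustrative sketch, not part of this diff: the getValue/setValue operators above let a
// stateful handler treat GroupState<S> as a nullable local variable via `by`. The function
// name and Dataset shape below are hypothetical.
fun countPerKey(events: KeyValueGroupedDataset<Int, Pair<Int, String>>): Dataset<Pair<Int, Int>> =
    events.mapGroupsWithState { key, values, state: GroupState<Int> ->
        var count by state                                  // getValue -> getOrNull()
        count = (count ?: 0) + values.asSequence().count()  // setValue -> update(...)
        key to count!!
    }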


inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>()
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> = flatMapGroupsWithState(
FlatMapGroupsWithStateFunction(func),
outputMode,
encoder<S>(),
encoder<U>(),
timeoutConf
)
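
// Illustrative sketch, not part of this diff: flatMapGroupsWithState emits zero or more
// rows per group and requires the caller to pick the OutputMode and timeout. The dedup
// logic and names below are hypothetical.
fun firstBatchOnly(events: KeyValueGroupedDataset<Int, Pair<Int, String>>): Dataset<Pair<Int, String>> =
    events.flatMapGroupsWithState(OutputMode.Append(), GroupStateTimeout.NoTimeout()) { _, values, state: GroupState<Int> ->
        var seen by state
        val fresh = if (seen == null) values.asSequence().toList() else emptyList()
        seen = (seen ?: 0) + fresh.size  // remember how many rows this key already emitted
        fresh.iterator()
    }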

inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
other: KeyValueGroupedDataset<K, U>,
noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> = cogroup(
other,
CoGroupFunction(func),
encoder<R>()
)
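
// Illustrative sketch, not part of this diff: cogroup pairs up both groupings by key and
// passes the two value iterators side by side. Names below are hypothetical.
fun concatPerKey(
    left: KeyValueGroupedDataset<Int, Pair<Int, String>>,
    right: KeyValueGroupedDataset<Int, Pair<Int, String>>,
): Dataset<Pair<Int, List<String>>> =
    left.cogroup(right) { key, l, r ->
        listOf(key to (l.asSequence() + r.asSequence()).map { it.second }.toList()).iterator()
    }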

inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
@@ -21,6 +21,9 @@ import ch.tutteli.atrium.api.fluent.en_GB.*
import ch.tutteli.atrium.domain.builders.migration.asExpect
import ch.tutteli.atrium.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.shouldBe
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import java.io.Serializable
import java.time.LocalDate

@@ -156,6 +159,76 @@ class ApiTest : ShouldSpec({

expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
}
should("perform operations on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMapped = groupedDataset.flatMapGroups { key, values ->
val collected = values.asSequence().toList()

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMapped.count() shouldBe 2

val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithStateTimeoutConf.count() shouldBe 2

val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithState.count() shouldBe 2

val flatMappedWithState = groupedDataset.flatMapGroupsWithState(org.apache.spark.sql.streaming.OutputMode.Append(), GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
[Review comment from a Contributor]: Shouldn't these be moved to separate tests?
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMappedWithState.count() shouldBe 2
}
should("be able to cogroup grouped datasets") {
val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
.toDS()
.groupByKey { it.first }

val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
listOf(
key to (left.asSequence() + right.asSequence())
.map { it.second }
.toList()
).iterator()
}

cogrouped.count() shouldBe 4
}
}
}
})
@@ -27,6 +27,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.*
import org.apache.spark.sql.Encoders.*
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.*
import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
import scala.reflect.ClassTag
@@ -38,6 +41,7 @@ import java.time.Instant
import java.time.LocalDate
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.isSubclassOf
@@ -149,6 +153,58 @@ inline fun <reified KEY, reified VALUE> KeyValueGroupedDataset<KEY, VALUE>.reduc
reduceGroups(ReduceFunction(func))
.map { t -> t._1 to t._2 }

inline fun <K, V, reified U> KeyValueGroupedDataset<K, V>.flatMapGroups(
noinline func: (key: K, values: Iterator<V>) -> Iterator<U>
): Dataset<U> = flatMapGroups(
FlatMapGroupsFunction(func),
encoder<U>()
)

fun <S> GroupState<S>.getOrNull(): S? = if (exists()) get() else null

operator fun <S> GroupState<S>.getValue(thisRef: Any?, property: KProperty<*>): S? = getOrNull()
operator fun <S> GroupState<S>.setValue(thisRef: Any?, property: KProperty<*>, value: S?): Unit = update(value)


inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>()
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.mapGroupsWithState(
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> U
): Dataset<U> = mapGroupsWithState(
MapGroupsWithStateFunction(func),
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, reified S, reified U> KeyValueGroupedDataset<K, V>.flatMapGroupsWithState(
outputMode: OutputMode,
timeoutConf: GroupStateTimeout,
noinline func: (key: K, values: Iterator<V>, state: GroupState<S>) -> Iterator<U>
): Dataset<U> = flatMapGroupsWithState(
FlatMapGroupsWithStateFunction(func),
outputMode,
encoder<S>(),
encoder<U>(),
timeoutConf
)

inline fun <K, V, U, reified R> KeyValueGroupedDataset<K, V>.cogroup(
other: KeyValueGroupedDataset<K, U>,
noinline func: (key: K, left: Iterator<V>, right: Iterator<U>) -> Iterator<R>
): Dataset<R> = cogroup(
other,
CoGroupFunction(func),
encoder<R>()
)

inline fun <T, reified R> Dataset<T>.downcast(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.`as`(): Dataset<R> = `as`(encoder<R>())
inline fun <reified R> Dataset<*>.to(): Dataset<R> = `as`(encoder<R>())
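
// Illustrative sketch, not part of this diff: `as`/`to` re-encode the same rows under a
// new type, e.g. viewing Pair<Int, String> rows as a matching data class. The data class
// and column mapping below are hypothetical.
data class KeyedValue(val first: Int, val second: String)

fun widen(ds: Dataset<Pair<Int, String>>): Dataset<KeyedValue> = ds.to<KeyedValue>()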
@@ -21,6 +21,9 @@ import ch.tutteli.atrium.api.fluent.en_GB.*
import ch.tutteli.atrium.domain.builders.migration.asExpect
import ch.tutteli.atrium.verbs.expect
import io.kotest.core.spec.style.ShouldSpec
import io.kotest.matchers.shouldBe
import org.apache.spark.sql.streaming.GroupState
import org.apache.spark.sql.streaming.GroupStateTimeout
import java.io.Serializable
import java.time.LocalDate

@@ -169,6 +172,76 @@ class ApiTest : ShouldSpec({

expect(result).asExpect().contains.inOrder.only.values(3, 5, 7, 9, 11)
}
should("perform operations on grouped datasets") {
val groupedDataset = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val flatMapped = groupedDataset.flatMapGroups { key, values ->
val collected = values.asSequence().toList()

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMapped.count() shouldBe 2

val mappedWithStateTimeoutConf = groupedDataset.mapGroupsWithState(GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithStateTimeoutConf.count() shouldBe 2

val mappedWithState = groupedDataset.mapGroupsWithState { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

s!! to collected.map { it.second }
}

mappedWithState.count() shouldBe 2

val flatMappedWithState = groupedDataset.flatMapGroupsWithState(org.apache.spark.sql.streaming.OutputMode.Append(), GroupStateTimeout.NoTimeout()) { key, values, state: GroupState<Int> ->
var s by state
val collected = values.asSequence().toList()

s = key
s shouldBe key

if (collected.size > 1) collected.iterator()
else emptyList<Pair<Int, String>>().iterator()
}

flatMappedWithState.count() shouldBe 2
}
should("be able to cogroup grouped datasets") {
val groupedDataset1 = listOf(1 to "a", 1 to "b", 2 to "c")
.toDS()
.groupByKey { it.first }

val groupedDataset2 = listOf(1 to "d", 5 to "e", 3 to "f")
.toDS()
.groupByKey { it.first }

val cogrouped = groupedDataset1.cogroup(groupedDataset2) { key, left, right ->
listOf(
key to (left.asSequence() + right.asSequence())
.map { it.second }
.toList()
).iterator()
}

cogrouped.count() shouldBe 4
}
}
}
})