diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt index 37a009dfc5..cdafcc1dab 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sort.kt @@ -33,6 +33,15 @@ public interface SortDsl : ColumnsSelectionDsl { public fun KProperty.nullsLast(flag: Boolean = true): ColumnSet = toColumnAccessor().nullsLast(flag) } +/** + * [SortColumnsSelector] is used to express or select multiple columns to sort by, represented by [ColumnSet]``, + * using the context of [SortDsl]`` as `this` and `it`. + * + * So: + * ```kotlin + * SortDsl.(it: SortDsl) -> ColumnSet + * ``` + */ public typealias SortColumnsSelector = Selector, ColumnSet> // region DataColumn diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DataFrameReceiver.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DataFrameReceiver.kt index cc8ab39487..3345f87fe9 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DataFrameReceiver.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DataFrameReceiver.kt @@ -36,12 +36,17 @@ internal open class DataFrameReceiver( private val unresolvedColumnsPolicy: UnresolvedColumnsPolicy ) : DataFrameReceiverBase(source.unbox()), SingleColumn> { - private fun DataColumn?.check(path: ColumnPath): DataColumn? = + private fun DataColumn?.check(path: ColumnPath): DataColumn = when (this) { null -> when (unresolvedColumnsPolicy) { - UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup(path, this@DataFrameReceiver).asDataColumn().cast() + UnresolvedColumnsPolicy.Create, UnresolvedColumnsPolicy.Skip -> MissingColumnGroup( + path, + this@DataFrameReceiver + ).asDataColumn().cast() + UnresolvedColumnsPolicy.Fail -> error("Column $path not found") } + is MissingDataColumn -> this is ColumnGroup<*> -> ColumnGroupWithParent(null, this).asDataColumn().cast() else -> this diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/sort.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/sort.kt index f5d9b75ca7..c0a7ae71db 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/sort.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/api/sort.kt @@ -16,22 +16,29 @@ import org.jetbrains.kotlinx.dataframe.columns.UnresolvedColumnsPolicy import org.jetbrains.kotlinx.dataframe.columns.ValueColumn import org.jetbrains.kotlinx.dataframe.impl.columns.addPath import org.jetbrains.kotlinx.dataframe.impl.columns.assertIsComparable +import org.jetbrains.kotlinx.dataframe.impl.columns.missing.MissingColumnGroup import org.jetbrains.kotlinx.dataframe.impl.columns.resolve import org.jetbrains.kotlinx.dataframe.impl.columns.toColumns import org.jetbrains.kotlinx.dataframe.kind import org.jetbrains.kotlinx.dataframe.nrow -internal fun GroupBy.sortByImpl(columns: SortColumnsSelector): GroupBy { - return toDataFrame() +@Suppress("UNCHECKED_CAST", "RemoveExplicitTypeArguments") +internal fun GroupBy.sortByImpl(columns: SortColumnsSelector): GroupBy = + toDataFrame() + + // sort the individual groups by the columns specified .update { groups } .with { it.sortByImpl(UnresolvedColumnsPolicy.Skip, columns) } + + // sort the groups by the columns specified (must be either be the keys column or "groups") + // will do nothing if the columns specified are not the keys column or "groups" .sortByImpl(UnresolvedColumnsPolicy.Skip, columns as SortColumnsSelector) - .asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn() } -} + + .asGroupBy { it.getFrameColumn(groups.name()).castFrameColumn() } internal fun DataFrame.sortByImpl( unresolvedColumnsPolicy: UnresolvedColumnsPolicy = UnresolvedColumnsPolicy.Fail, - columns: SortColumnsSelector + columns: SortColumnsSelector, ): DataFrame { val sortColumns = getSortColumns(columns, unresolvedColumnsPolicy) if (sortColumns.isEmpty()) return this @@ -61,9 +68,10 @@ internal fun AnyCol.createComparator(nullsLast: Boolean): java.util.Comparator DataFrame.getSortColumns( columns: SortColumnsSelector, - unresolvedColumnsPolicy: UnresolvedColumnsPolicy -): List> { - return columns.toColumns().resolve(this, unresolvedColumnsPolicy) + unresolvedColumnsPolicy: UnresolvedColumnsPolicy, +): List> = + columns.toColumns().resolve(this, unresolvedColumnsPolicy) + .filterNot { it.data is MissingColumnGroup<*> } // can appear using [DataColumn?.check] with UnresolvedColumnsPolicy.Skip .map { when (val col = it.data) { is SortColumnDescriptor<*> -> col @@ -71,7 +79,6 @@ internal fun DataFrame.getSortColumns( else -> throw IllegalStateException("Can not use ${col.kind} as sort column") } } -} internal enum class SortFlag { Reversed, NullsLast } @@ -86,12 +93,14 @@ internal fun ColumnWithPath.addFlag(flag: SortFlag): ColumnWithPath { SortFlag.NullsLast -> SortColumnDescriptor(col.column, col.direction, true) } } + is ValueColumn -> { when (flag) { SortFlag.Reversed -> SortColumnDescriptor(col, SortDirection.Desc) SortFlag.NullsLast -> SortColumnDescriptor(col, SortDirection.Asc, true) } } + else -> throw IllegalArgumentException("Can not apply sort flag to column kind ${col.kind}") }.addPath(path) } @@ -103,7 +112,7 @@ internal class ColumnsWithSortFlag(val column: ColumnSet, val flag: SortFl internal class SortColumnDescriptor( val column: ValueColumn, val direction: SortDirection = SortDirection.Asc, - val nullsLast: Boolean = false + val nullsLast: Boolean = false, ) : ValueColumn by column internal enum class SortDirection { Asc, Desc } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/sortGroupedDataframe.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/sortGroupedDataframe.kt new file mode 100644 index 0000000000..76d7b07edb --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/sortGroupedDataframe.kt @@ -0,0 +1,72 @@ +package org.jetbrains.kotlinx.dataframe.api + +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.alsoDebug +import org.jetbrains.kotlinx.dataframe.io.read +import org.junit.Test + +class SortGroupedDataframeTests { + + @Test + fun `Sorted grouped iris dataset`() { + val irisData = DataFrame.read("src/test/resources/irisDataset.csv") + irisData.alsoDebug() + + irisData.groupBy("variety").let { + it.sortBy("petal.length").toString() shouldBe + it.sortBy { it["petal.length"] }.toString() + } + } + + enum class State { + Idle, Productive, Maintenance + } + + @Test + fun test4() { + class Event(val toolId: String, val state: State, val timestamp: Long) + + val tool1 = "tool_1" + val tool2 = "tool_2" + val tool3 = "tool_3" + + val events = listOf( + Event(tool1, State.Idle, 0), + Event(tool1, State.Productive, 5), + Event(tool2, State.Idle, 0), + Event(tool2, State.Maintenance, 10), + Event(tool2, State.Idle, 20), + Event(tool3, State.Idle, 0), + Event(tool3, State.Productive, 25), + ).toDataFrame() + + val lastTimestamp = events.maxOf { getValue("timestamp") } + val groupBy = events + .groupBy("toolId") + .sortBy("timestamp") + .add("stateDuration") { + (next()?.getValue("timestamp") ?: lastTimestamp) - getValue("timestamp") + } + + groupBy.toDataFrame().alsoDebug() + groupBy.schema().print() + groupBy.keys.print() + groupBy.keys[0].print() + + val df1 = groupBy.updateGroups { + val missingValues = State.values().asList().toDataFrame { + "state" from { it } + } + + val df = it + .fullJoin(missingValues, "state") + .fillNulls("stateDuration") + .with { 100L } + + df.groupBy("state").sumFor("stateDuration") + } + + df1.toDataFrame().alsoDebug().isNotEmpty() shouldBe true + } +} diff --git a/core/src/test/resources/irisDataset.csv b/core/src/test/resources/irisDataset.csv new file mode 100644 index 0000000000..bf14e161bf --- /dev/null +++ b/core/src/test/resources/irisDataset.csv @@ -0,0 +1,151 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +5,3.4,1.5,.2,"Setosa" +4.4,2.9,1.4,.2,"Setosa" +4.9,3.1,1.5,.1,"Setosa" +5.4,3.7,1.5,.2,"Setosa" +4.8,3.4,1.6,.2,"Setosa" +4.8,3,1.4,.1,"Setosa" +4.3,3,1.1,.1,"Setosa" +5.8,4,1.2,.2,"Setosa" +5.7,4.4,1.5,.4,"Setosa" +5.4,3.9,1.3,.4,"Setosa" +5.1,3.5,1.4,.3,"Setosa" +5.7,3.8,1.7,.3,"Setosa" +5.1,3.8,1.5,.3,"Setosa" +5.4,3.4,1.7,.2,"Setosa" +5.1,3.7,1.5,.4,"Setosa" +4.6,3.6,1,.2,"Setosa" +5.1,3.3,1.7,.5,"Setosa" +4.8,3.4,1.9,.2,"Setosa" +5,3,1.6,.2,"Setosa" +5,3.4,1.6,.4,"Setosa" +5.2,3.5,1.5,.2,"Setosa" +5.2,3.4,1.4,.2,"Setosa" +4.7,3.2,1.6,.2,"Setosa" +4.8,3.1,1.6,.2,"Setosa" +5.4,3.4,1.5,.4,"Setosa" +5.2,4.1,1.5,.1,"Setosa" +5.5,4.2,1.4,.2,"Setosa" +4.9,3.1,1.5,.2,"Setosa" +5,3.2,1.2,.2,"Setosa" +5.5,3.5,1.3,.2,"Setosa" +4.9,3.6,1.4,.1,"Setosa" +4.4,3,1.3,.2,"Setosa" +5.1,3.4,1.5,.2,"Setosa" +5,3.5,1.3,.3,"Setosa" +4.5,2.3,1.3,.3,"Setosa" +4.4,3.2,1.3,.2,"Setosa" +5,3.5,1.6,.6,"Setosa" +5.1,3.8,1.9,.4,"Setosa" +4.8,3,1.4,.3,"Setosa" +5.1,3.8,1.6,.2,"Setosa" +4.6,3.2,1.4,.2,"Setosa" +5.3,3.7,1.5,.2,"Setosa" +5,3.3,1.4,.2,"Setosa" +7,3.2,4.7,1.4,"Versicolor" +6.4,3.2,4.5,1.5,"Versicolor" +6.9,3.1,4.9,1.5,"Versicolor" +5.5,2.3,4,1.3,"Versicolor" +6.5,2.8,4.6,1.5,"Versicolor" +5.7,2.8,4.5,1.3,"Versicolor" +6.3,3.3,4.7,1.6,"Versicolor" +4.9,2.4,3.3,1,"Versicolor" +6.6,2.9,4.6,1.3,"Versicolor" +5.2,2.7,3.9,1.4,"Versicolor" +5,2,3.5,1,"Versicolor" +5.9,3,4.2,1.5,"Versicolor" +6,2.2,4,1,"Versicolor" +6.1,2.9,4.7,1.4,"Versicolor" +5.6,2.9,3.6,1.3,"Versicolor" +6.7,3.1,4.4,1.4,"Versicolor" +5.6,3,4.5,1.5,"Versicolor" +5.8,2.7,4.1,1,"Versicolor" +6.2,2.2,4.5,1.5,"Versicolor" +5.6,2.5,3.9,1.1,"Versicolor" +5.9,3.2,4.8,1.8,"Versicolor" +6.1,2.8,4,1.3,"Versicolor" +6.3,2.5,4.9,1.5,"Versicolor" +6.1,2.8,4.7,1.2,"Versicolor" +6.4,2.9,4.3,1.3,"Versicolor" +6.6,3,4.4,1.4,"Versicolor" +6.8,2.8,4.8,1.4,"Versicolor" +6.7,3,5,1.7,"Versicolor" +6,2.9,4.5,1.5,"Versicolor" +5.7,2.6,3.5,1,"Versicolor" +5.5,2.4,3.8,1.1,"Versicolor" +5.5,2.4,3.7,1,"Versicolor" +5.8,2.7,3.9,1.2,"Versicolor" +6,2.7,5.1,1.6,"Versicolor" +5.4,3,4.5,1.5,"Versicolor" +6,3.4,4.5,1.6,"Versicolor" +6.7,3.1,4.7,1.5,"Versicolor" +6.3,2.3,4.4,1.3,"Versicolor" +5.6,3,4.1,1.3,"Versicolor" +5.5,2.5,4,1.3,"Versicolor" +5.5,2.6,4.4,1.2,"Versicolor" +6.1,3,4.6,1.4,"Versicolor" +5.8,2.6,4,1.2,"Versicolor" +5,2.3,3.3,1,"Versicolor" +5.6,2.7,4.2,1.3,"Versicolor" +5.7,3,4.2,1.2,"Versicolor" +5.7,2.9,4.2,1.3,"Versicolor" +6.2,2.9,4.3,1.3,"Versicolor" +5.1,2.5,3,1.1,"Versicolor" +5.7,2.8,4.1,1.3,"Versicolor" +6.3,3.3,6,2.5,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +7.1,3,5.9,2.1,"Virginica" +6.3,2.9,5.6,1.8,"Virginica" +6.5,3,5.8,2.2,"Virginica" +7.6,3,6.6,2.1,"Virginica" +4.9,2.5,4.5,1.7,"Virginica" +7.3,2.9,6.3,1.8,"Virginica" +6.7,2.5,5.8,1.8,"Virginica" +7.2,3.6,6.1,2.5,"Virginica" +6.5,3.2,5.1,2,"Virginica" +6.4,2.7,5.3,1.9,"Virginica" +6.8,3,5.5,2.1,"Virginica" +5.7,2.5,5,2,"Virginica" +5.8,2.8,5.1,2.4,"Virginica" +6.4,3.2,5.3,2.3,"Virginica" +6.5,3,5.5,1.8,"Virginica" +7.7,3.8,6.7,2.2,"Virginica" +7.7,2.6,6.9,2.3,"Virginica" +6,2.2,5,1.5,"Virginica" +6.9,3.2,5.7,2.3,"Virginica" +5.6,2.8,4.9,2,"Virginica" +7.7,2.8,6.7,2,"Virginica" +6.3,2.7,4.9,1.8,"Virginica" +6.7,3.3,5.7,2.1,"Virginica" +7.2,3.2,6,1.8,"Virginica" +6.2,2.8,4.8,1.8,"Virginica" +6.1,3,4.9,1.8,"Virginica" +6.4,2.8,5.6,2.1,"Virginica" +7.2,3,5.8,1.6,"Virginica" +7.4,2.8,6.1,1.9,"Virginica" +7.9,3.8,6.4,2,"Virginica" +6.4,2.8,5.6,2.2,"Virginica" +6.3,2.8,5.1,1.5,"Virginica" +6.1,2.6,5.6,1.4,"Virginica" +7.7,3,6.1,2.3,"Virginica" +6.3,3.4,5.6,2.4,"Virginica" +6.4,3.1,5.5,1.8,"Virginica" +6,3,4.8,1.8,"Virginica" +6.9,3.1,5.4,2.1,"Virginica" +6.7,3.1,5.6,2.4,"Virginica" +6.9,3.1,5.1,2.3,"Virginica" +5.8,2.7,5.1,1.9,"Virginica" +6.8,3.2,5.9,2.3,"Virginica" +6.7,3.3,5.7,2.5,"Virginica" +6.7,3,5.2,2.3,"Virginica" +6.3,2.5,5,1.9,"Virginica" +6.5,3,5.2,2,"Virginica" +6.2,3.4,5.4,2.3,"Virginica" +5.9,3,5.1,1.8,"Virginica"