Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ public fun <T> DataFrame<T>.add(body: AddDsl<T>.() -> Unit): DataFrame<T> {
return dataFrameOf([email protected]() + dsl.columns).cast()
}

@Refine
@Interpretable("GroupByAdd")
public inline fun <reified R, T, G> GroupBy<T, G>.add(
name: String,
infer: Infer = Infer.Nulls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ class FunctionCallTransformer(
val groupMarker = rootMarkers[1]

val (keySchema, groupSchema) = if (groupBy != null) {
val keySchema = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop)
val groupSchema = PluginDataFrameSchema(groupBy.df.columns())
val keySchema = groupBy.keys
val groupSchema = groupBy.groups
keySchema to groupSchema
} else {
PluginDataFrameSchema.EMPTY to PluginDataFrameSchema.EMPTY
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl

import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter.*
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
import kotlin.properties.PropertyDelegateProvider
Expand Down Expand Up @@ -35,3 +36,7 @@ internal fun <T> AbstractInterpreter<T>.ignore(
): ExpectedArgumentProvider<Nothing?> =
arg(name, lens = Interpreter.Id, defaultValue = Present(null))

internal fun <T> AbstractInterpreter<T>.groupBy(
name: ArgumentName? = null
): ExpectedArgumentProvider<GroupBy> = arg(name, lens = Interpreter.GroupBy)

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ interface Interpreter<T> {

data object Schema : Lens

data object GroupBy : Lens

data object Id : Lens

// required to compute whether resulting schema should be inheritor of previous class or a new class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ data class PluginDataFrameSchema(
}
}

fun PluginDataFrameSchema.add(name: String, type: ConeKotlinType, context: KotlinTypeFacade): PluginDataFrameSchema {
return PluginDataFrameSchema(columns() + context.simpleColumnOf(name, type))
}

private fun List<SimpleCol>.asString(indent: String = ""): String {
return joinToString("\n") {
val col = when (it) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,22 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.add
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.groupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf
import org.jetbrains.kotlinx.dataframe.plugin.impl.type

class GroupBy(val df: PluginDataFrameSchema, val keys: List<ColumnWithPathApproximation>, val moveToTop: Boolean)
class GroupBy(val keys: PluginDataFrameSchema, val groups: PluginDataFrameSchema)

class DataFrameGroupBy : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: PluginDataFrameSchema by dataFrame()
val Arguments.moveToTop: Boolean by arg(defaultValue = Present(true))
val Arguments.cols: ColumnsResolver by arg()

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver, cols.resolve(receiver), moveToTop)
return GroupBy(keys = createPluginDataFrameSchema(cols.resolve(receiver), moveToTop), groups = receiver)
}
}

Expand All @@ -52,7 +55,7 @@ class GroupByInto : AbstractInterpreter<Unit>() {
}

class Aggregate : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: GroupBy by arg()
val Arguments.receiver: GroupBy by groupBy()
val Arguments.body: FirAnonymousFunctionExpression by arg(lens = Interpreter.Id)
override fun Arguments.interpret(): PluginDataFrameSchema {
return aggregate(
Expand Down Expand Up @@ -87,7 +90,7 @@ fun KotlinTypeFacade.aggregate(
)
}

val cols = createPluginDataFrameSchema(groupBy.keys, groupBy.moveToTop).columns() + dsl.columns.map {
val cols = groupBy.keys.columns() + dsl.columns.map {
simpleColumnOf(it.name, it.type)
}
PluginDataFrameSchema(cols)
Expand Down Expand Up @@ -144,13 +147,23 @@ fun KotlinTypeFacade.createPluginDataFrameSchema(keys: List<ColumnWithPathApprox
}

class GroupByToDataFrame : AbstractSchemaModificationInterpreter() {
val Arguments.receiver: GroupBy by arg()
val Arguments.receiver: GroupBy by groupBy()
val Arguments.groupedColumnName: String? by arg(defaultValue = Present(null))

override fun Arguments.interpret(): PluginDataFrameSchema {
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.df.columns()))
val grouped = listOf(SimpleFrameColumn(groupedColumnName ?: "group", receiver.groups.columns()))
return PluginDataFrameSchema(
createPluginDataFrameSchema(receiver.keys, receiver.moveToTop).columns() + grouped
receiver.keys.columns() + grouped
)
}
}

class GroupByAdd : AbstractInterpreter<GroupBy>() {
val Arguments.receiver: GroupBy by groupBy()
val Arguments.name: String by arg()
val Arguments.type: TypeApproximation by type(name("expression"))

override fun Arguments.interpret(): GroupBy {
return GroupBy(receiver.keys, receiver.groups.add(name, type.type, context = this))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.jetbrains.kotlin.fir.references.resolved
import org.jetbrains.kotlin.fir.references.symbol
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.scopes.collectAllProperties
import org.jetbrains.kotlin.fir.scopes.getProperties
import org.jetbrains.kotlin.fir.scopes.impl.declaredMemberScope
Expand Down Expand Up @@ -78,6 +79,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ColumnsResolver
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.SingleColumnApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation

Expand Down Expand Up @@ -277,6 +279,17 @@ fun <T> KotlinTypeFacade.interpret(
}
}

is Interpreter.GroupBy -> {
assert(expectedReturnType.toString() == GroupBy::class.qualifiedName!!) {
"'$name' should be ${GroupBy::class.qualifiedName!!}, but plugin expect $expectedReturnType"
}

val resolvedType = it.expression.resolvedType.fullyExpandedType(session)
val keys = pluginDataFrameSchema(resolvedType.typeArguments[0])
val groups = pluginDataFrameSchema(resolvedType.typeArguments[1])
Interpreter.Success(GroupBy(keys, groups))
}

is Interpreter.Id -> {
Interpreter.Success(it.expression)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FillNulls0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Flatten0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FlattenDefault
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.FrameCols0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupByAdd
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MapToFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Move0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.MoveAfter0
Expand Down Expand Up @@ -275,6 +276,7 @@ internal inline fun <reified T> String.load(): T {
"MoveToLeft1" -> MoveToLeft1()
"MoveToRight0" -> MoveToRight0()
"MoveAfter0" -> MoveAfter0()
"GroupByAdd" -> GroupByAdd()
else -> error("$this")
} as T
}
42 changes: 42 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupByAdd.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.io.*

enum class State {
Idle,
Productive,
Maintenance,
}

class Event(val toolId: String, val state: State, val timestamp: Long)

fun box(): String {
val tool1 = "tool_1"
val tool2 = "tool_2"
val tool3 = "tool_3"

val events = listOf(
Event(tool1, State.Idle, 0),
Event(tool1, State.Productive, 5),
Event(tool2, State.Idle, 0),
Event(tool2, State.Maintenance, 10),
Event(tool2, State.Idle, 20),
Event(tool3, State.Idle, 0),
Event(tool3, State.Productive, 25),
).toDataFrame()

val lastTimestamp = events.maxOf { timestamp }
val groupBy = events
.groupBy { toolId }
.sortBy { timestamp }
.add("stateDuration") {
(next()?.timestamp ?: lastTimestamp) - timestamp
}.toDataFrame()

groupBy.group[0].stateDuration

groupBy.compareSchemas(strict = true)
return "OK"
}
14 changes: 14 additions & 0 deletions plugins/kotlin-dataframe/testData/box/groupBy_extractSchema.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import org.jetbrains.kotlinx.dataframe.*
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.api.*
import org.jetbrains.kotlinx.dataframe.io.*

fun box(): String {
val df = dataFrameOf("a", "b", "c")(1, 2, 3)

val groupBy = df.groupBy { a }

val df1 = groupBy.updateGroups { it.remove { a } }.toDataFrame()
df1.compileTimeSchema().print()
return "OK"
}
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,24 @@ public void testGroupBy() {
runTest("testData/box/groupBy.kt");
}

@Test
@TestMetadata("groupByAdd.kt")
public void testGroupByAdd() {
runTest("testData/box/groupByAdd.kt");
}

@Test
@TestMetadata("groupBy_DataRow.kt")
public void testGroupBy_DataRow() {
runTest("testData/box/groupBy_DataRow.kt");
}

@Test
@TestMetadata("groupBy_extractSchema.kt")
public void testGroupBy_extractSchema() {
runTest("testData/box/groupBy_extractSchema.kt");
}

@Test
@TestMetadata("groupBy_refine.kt")
public void testGroupBy_refine() {
Expand Down
Loading