diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt index c583857c81..c3dbacaee5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt @@ -49,23 +49,38 @@ internal fun KType.projectUpTo(superClass: KClass<*>): KType { return current.withNullability(isMarkedNullable) } -internal fun KType.replaceTypeParameters(): KType { - var replaced = false - val arguments = arguments.map { - val type = it.type - val newType = when { - type == null -> typeOf() - type.classifier is KTypeParameter -> { - replaced = true - (type.classifier as KTypeParameter).upperBounds.firstOrNull() ?: typeOf() - } +/** + * Changes generic type parameters to `Any?`, like `List -> List`. + * Works recursively as well. + */ +@PublishedApi +internal fun KType.eraseGenericTypeParameters(): KType { + fun KType.eraseRecursively(): Pair { + var replaced = false + val arguments = arguments.map { + val type = it.type + val (replacedDownwards, newType) = when { + type == null -> typeOf() + + type.classifier is KTypeParameter -> { + replaced = true + (type.classifier as KTypeParameter).upperBounds.firstOrNull() ?: typeOf() + } - else -> type + else -> type + }.eraseRecursively() + + if (replacedDownwards) replaced = true + + KTypeProjection.invariant(newType) } - KTypeProjection.invariant(newType) + return Pair( + first = replaced, + second = if (replaced) jvmErasure.createType(arguments, isMarkedNullable) else this, + ) } - return if (replaced) jvmErasure.createType(arguments, isMarkedNullable) - else this + + return eraseRecursively().second } internal fun inheritanceChain(subClass: KClass<*>, superClass: KClass<*>): List, KType>> { @@ -255,7 +270,7 @@ internal fun Iterable.commonTypeListifyValues(): KType { else -> { val kclass = commonParent(distinct.map { it.jvmErasure }) ?: return typeOf() - val projections = distinct.map { it.projectUpTo(kclass).replaceTypeParameters() } + val projections = distinct.map { it.projectUpTo(kclass).eraseGenericTypeParameters() } require(projections.all { it.jvmErasure == kclass }) val arguments = List(kclass.typeParameters.size) { i -> val projectionTypes = projections diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/Utils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/Utils.kt index 209ab3aa47..c1e933a33c 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/Utils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/Utils.kt @@ -146,7 +146,7 @@ internal fun Iterable.commonType(): KType { distinct.size == 1 -> distinct.single()!! else -> { val kclass = commonParent(distinct.map { it!!.jvmErasure }) ?: return typeOf() - val projections = distinct.map { it!!.projectUpTo(kclass).replaceTypeParameters() } + val projections = distinct.map { it!!.projectUpTo(kclass).eraseGenericTypeParameters() } require(projections.all { it.jvmErasure == kclass }) val arguments = List(kclass.typeParameters.size) { i -> val projectionTypes = projections diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/constructors.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/constructors.kt index 5f198810f3..6117845049 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/constructors.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/columns/constructors.kt @@ -32,6 +32,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnWithPath import org.jetbrains.kotlinx.dataframe.impl.DataFrameReceiver import org.jetbrains.kotlinx.dataframe.impl.DataRowImpl import org.jetbrains.kotlinx.dataframe.impl.asList +import org.jetbrains.kotlinx.dataframe.impl.eraseGenericTypeParameters import org.jetbrains.kotlinx.dataframe.impl.guessValueType import org.jetbrains.kotlinx.dataframe.index import org.jetbrains.kotlinx.dataframe.nrow @@ -58,9 +59,25 @@ internal fun ColumnsContainer.newColumn( ): DataColumn { val (nullable, values) = computeValues(this as DataFrame, expression) return when (infer) { - Infer.Nulls -> DataColumn.create(name, values, type.withNullability(nullable), Infer.None) - Infer.Type -> DataColumn.createWithTypeInference(name, values, nullable) - Infer.None -> DataColumn.create(name, values, type, Infer.None) + Infer.Nulls -> DataColumn.create( + name = name, + values = values, + type = type.withNullability(nullable).eraseGenericTypeParameters(), + infer = Infer.None, + ) + + Infer.Type -> DataColumn.createWithTypeInference( + name = name, + values = values, + nullable = nullable, + ) + + Infer.None -> DataColumn.create( + name = name, + values = values, + type = type.eraseGenericTypeParameters(), + infer = Infer.None, + ) } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt index f3065d03ea..4f164215df 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/add.kt @@ -2,7 +2,9 @@ package org.jetbrains.kotlinx.dataframe.api import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.AnyFrame import org.junit.Test +import kotlin.reflect.typeOf class AddTests { @@ -23,4 +25,12 @@ class AddTests { df.add("y") { next()?.newValue() ?: 1 } } } + + private fun AnyFrame.addValue(value: T) = add("value") { listOf(value) } + + @Test + fun `add with generic function`() { + val df = dataFrameOf("a")(1).addValue(2) + df["value"].type() shouldBe typeOf>() + } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/JupyterCodegenTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/JupyterCodegenTests.kt index 1064c456ce..47ac43181f 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/JupyterCodegenTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/jupyter/JupyterCodegenTests.kt @@ -4,12 +4,31 @@ import io.kotest.assertions.throwables.shouldNotThrowAny import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import org.intellij.lang.annotations.Language +import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.columns.ValueColumn +import org.jetbrains.kotlinx.dataframe.type import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult import org.jetbrains.kotlinx.jupyter.testkit.JupyterReplTestCase import org.junit.Test +import kotlin.reflect.typeOf class JupyterCodegenTests : JupyterReplTestCase() { + + @Test + fun `codegen adding column with generic type function`() { + @Language("kts") + val res1 = exec( + """ + fun AnyFrame.addValue(value: T) = add("value") { listOf(value) } + val df = dataFrameOf("a")(1).addValue(2) + """.trimIndent() + ) + res1 shouldBe Unit + val res2 = execRaw("df") as AnyFrame + + res2["value"].type shouldBe typeOf>() + } + @Test fun `codegen for enumerated frames`() { @Language("kts") @@ -78,6 +97,7 @@ class JupyterCodegenTests : JupyterReplTestCase() { @Test fun `codegen for chars that is forbidden in JVM identifiers`() { val forbiddenChar = ";" + @Language("kts") val res1 = exec( """ @@ -96,6 +116,7 @@ class JupyterCodegenTests : JupyterReplTestCase() { @Test fun `codegen for chars that is forbidden in JVM identifiers 1`() { val forbiddenChar = "\\\\" + @Language("kts") val res1 = exec( """