diff --git a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala index e01f17402..b6067969f 100644 --- a/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala +++ b/commons-benchmark/jvm/src/main/scala/com/avsystem/commons/ser/GenCodecBenchmarks.scala @@ -39,5 +39,7 @@ object DummyInput extends Input { def readList() = ignored def readBoolean() = ignored def readDouble() = ignored + def readBigInt() = ignored + def readBigDecimal() = ignored def skip() = () } diff --git a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala index dc1617c9c..ca104edee 100644 --- a/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala +++ b/commons-benchmark/src/main/scala/com/avsystem/commons/ser/CirceJsonInputOutput.scala @@ -22,7 +22,9 @@ class CirceJsonOutput(consumer: Json => Any) extends Output { def writeInt(int: Int): Unit = consumer(Json.fromInt(int)) def writeLong(long: Long): Unit = consumer(Json.fromLong(long)) def writeDouble(double: Double): Unit = consumer(Json.fromDoubleOrString(double)) - def writeBinary(binary: Array[Byte]): Unit = ??? + def writeBigInt(bigInt: BigInt): Unit = consumer(Json.fromBigInt(bigInt)) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = consumer(Json.fromBigDecimal(bigDecimal)) + def writeBinary(binary: Array[Byte]): Unit = consumer(Json.fromValues(binary.map(Json.fromInt(_)))) def writeList(): ListOutput = new CirceJsonListOutput(consumer) def writeObject(): ObjectOutput = new CirceJsonObjectOutput(consumer) override def writeFloat(float: Float): Unit = consumer(Json.fromFloatOrString(float)) @@ -67,7 +69,10 @@ class CirceJsonInput(json: Json) extends Input { def readInt(): Int = asNumber.toInt.getOrElse(failNot("int")) def readLong(): Long = asNumber.toLong.getOrElse(failNot("long")) def readDouble(): Double = asNumber.toDouble - def readBinary(): Array[Byte] = ??? + def readBigInt(): BigInt = asNumber.toBigInt.getOrElse(failNot("bigInteger")) + def readBigDecimal(): BigDecimal = asNumber.toBigDecimal.getOrElse(failNot("bigDecimal")) + def readBinary(): Array[Byte] = json.asArray.getOrElse(failNot("array")).iterator + .map(_.asNumber.flatMap(_.toByte).getOrElse(failNot("byte"))).toArray def readList(): ListInput = new CirceJsonListInput(json.asArray.getOrElse(failNot("array"))) def readObject(): ObjectInput = new CirceJsonObjectInput(json.asObject.getOrElse(failNot("object"))) def skip(): Unit = () diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala index d658b317b..8257ca0ab 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/GenCodec.scala @@ -287,8 +287,8 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val LongCodec: GenCodec[Long] = create(_.readLong(), _ writeLong _) implicit lazy val FloatCodec: GenCodec[Float] = create(_.readFloat(), _ writeFloat _) implicit lazy val DoubleCodec: GenCodec[Double] = create(_.readDouble(), _ writeDouble _) - implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(i => BigInt(i.readString()), (o, v) => o.writeString(v.toString)) - implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(i => BigDecimal(i.readString()), (o, v) => o.writeString(v.toString)) + implicit lazy val BigIntCodec: GenCodec[BigInt] = createNullable(_.readBigInt(), _ writeBigInt _) + implicit lazy val BigDecimalCodec: GenCodec[BigDecimal] = createNullable(_.readBigDecimal(), _ writeBigDecimal _) implicit lazy val JBooleanCodec: GenCodec[JBoolean] = createNullable(_.readBoolean(), _ writeBoolean _) implicit lazy val JCharacterCodec: GenCodec[JCharacter] = createNullable(_.readChar(), _ writeChar _) @@ -298,8 +298,10 @@ object GenCodec extends RecursiveAutoCodecs with TupleGenCodecs { implicit lazy val JLongCodec: GenCodec[JLong] = createNullable(_.readLong(), _ writeLong _) implicit lazy val JFloatCodec: GenCodec[JFloat] = createNullable(_.readFloat(), _ writeFloat _) implicit lazy val JDoubleCodec: GenCodec[JDouble] = createNullable(_.readDouble(), _ writeDouble _) - implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = createNullable(i => new JBigInteger(i.readString()), (o, v) => o.writeString(v.toString)) - implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = createNullable(i => new JBigDecimal(i.readString()), (o, v) => o.writeString(v.toString)) + implicit lazy val JBigIntegerCodec: GenCodec[JBigInteger] = + createNullable(_.readBigInt().bigInteger, (o, v) => o.writeBigInt(BigInt(v))) + implicit lazy val JBigDecimalCodec: GenCodec[JBigDecimal] = + createNullable(_.readBigDecimal().bigDecimal, (o, v) => o.writeBigDecimal(BigDecimal(v))) implicit lazy val JDateCodec: GenCodec[JDate] = createNullable(i => new JDate(i.readTimestamp()), (o, d) => o.writeTimestamp(d.getTime)) implicit lazy val StringCodec: GenCodec[String] = createNullable(_.readString(), _ writeString _) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala index 559da7085..5c1d83b1a 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/InputOutput.scala @@ -22,6 +22,8 @@ trait Output extends Any { def writeTimestamp(millis: Long): Unit = writeLong(millis) def writeFloat(float: Float): Unit = writeDouble(float) def writeDouble(double: Double): Unit + def writeBigInt(bigInt: BigInt): Unit + def writeBigDecimal(bigDecimal: BigDecimal): Unit def writeBinary(binary: Array[Byte]): Unit def writeList(): ListOutput def writeObject(): ObjectOutput @@ -139,6 +141,8 @@ trait Input extends Any { def readTimestamp(): Long = readLong() def readFloat(): Float = readDouble().toFloat def readDouble(): Double + def readBigInt(): BigInt + def readBigDecimal(): BigDecimal def readBinary(): Array[Byte] def readList(): ListInput def readObject(): ObjectInput diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala index ac7571160..3b0d48f84 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/SimpleValueInputOutput.scala @@ -28,6 +28,8 @@ object SimpleValueOutput { * - `Int` * - `Long` * - `Double` + * - `BigInt` + * - `BigDecimal` * - `Boolean` * - `String` * - `Array[Byte]` @@ -49,27 +51,27 @@ class SimpleValueOutput( def this(consumer: Any => Unit) = this(consumer, new MHashMap[String, Any], new ListBuffer[Any]) - def writeBinary(binary: Array[Byte]) = consumer(binary) - def writeString(str: String) = consumer(str) - def writeDouble(double: Double) = consumer(double) - def writeInt(int: Int) = consumer(int) - - def writeList() = new ListOutput { + def writeNull(): Unit = consumer(null) + def writeBoolean(boolean: Boolean): Unit = consumer(boolean) + def writeString(str: String): Unit = consumer(str) + def writeInt(int: Int): Unit = consumer(int) + def writeLong(long: Long): Unit = consumer(long) + def writeDouble(double: Double): Unit = consumer(double) + def writeBigInt(bigInt: BigInt): Unit = consumer(bigInt) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = consumer(bigDecimal) + def writeBinary(binary: Array[Byte]): Unit = consumer(binary) + + def writeList(): ListOutput = new ListOutput { private val buffer = newListRepr def writeElement() = new SimpleValueOutput(buffer += _, newObjectRepr, newListRepr) - def finish() = consumer(buffer.result()) + def finish(): Unit = consumer(buffer.result()) } - def writeBoolean(boolean: Boolean) = consumer(boolean) - - def writeObject() = new ObjectOutput { + def writeObject(): ObjectOutput = new ObjectOutput { private val result = newObjectRepr def writeField(key: String) = new SimpleValueOutput(v => result += ((key, v)), newObjectRepr, newListRepr) - def finish() = consumer(result) + def finish(): Unit = consumer(result) } - - def writeLong(long: Long) = consumer(long) - def writeNull() = consumer(null) } object SimpleValueInput { @@ -91,41 +93,42 @@ class SimpleValueInput(value: Any) extends Input { case _ => throw new ReadFailure(s"Expected ${classTag[B].runtimeClass} but got ${value.getClass}") } - def inputType = value match { + def inputType: InputType = value match { case null => InputType.Null case _: BSeq[Any] => InputType.List case _: BMap[_, Any] => InputType.Object case _ => InputType.Simple } - def readBinary() = doRead[Array[Byte]] - def readLong() = doReadUnboxed[Long, JLong] - def readNull() = if (value == null) null else throw new ReadFailure("not null") - def readObject() = + def readNull(): Null = if (value == null) null else throw new ReadFailure("not null") + def readBoolean(): Boolean = doReadUnboxed[Boolean, JBoolean] + def readString(): String = doRead[String] + def readInt(): Int = doReadUnboxed[Int, JInteger] + def readLong(): Long = doReadUnboxed[Long, JLong] + def readDouble(): Double = doReadUnboxed[Double, JDouble] + def readBigInt(): BigInt = doRead[JBigInteger] + def readBigDecimal(): BigDecimal = doRead[JBigDecimal] + def readBinary(): Array[Byte] = doRead[Array[Byte]] + + def readObject(): ObjectInput = new ObjectInput { private val map = doRead[BMap[String, Any]] private val it = map.iterator.map { case (k, v) => new SimpleValueFieldInput(k, v) } - def nextField() = it.next() - override def peekField(name: String) = map.getOpt(name).map(new SimpleValueFieldInput(name, _)) - def hasNext = it.hasNext + def nextField(): SimpleValueFieldInput = it.next() + override def peekField(name: String): Opt[SimpleValueFieldInput] = map.getOpt(name).map(new SimpleValueFieldInput(name, _)) + def hasNext: Boolean = it.hasNext } - def readInt() = doReadUnboxed[Int, JInteger] - def readString() = doRead[String] - - def readList() = + def readList(): ListInput = new ListInput { private val it = doRead[BSeq[Any]].iterator.map(new SimpleValueInput(_)) - def nextElement() = it.next() - def hasNext = it.hasNext + def nextElement(): SimpleValueInput = it.next() + def hasNext: Boolean = it.hasNext } - def readBoolean() = doReadUnboxed[Boolean, JBoolean] - def readDouble() = doReadUnboxed[Double, JDouble] - - def skip() = () + def skip(): Unit = () } class SimpleValueFieldInput(val fieldName: String, value: Any) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala index 7cab613f9..86a2f5519 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/StreamInputOutput.scala @@ -29,6 +29,8 @@ private object FormatConstants { final val ObjectStartMarker: Byte = 11 final val ListEndMarker: Byte = 12 final val ObjectEndMarker: Byte = 13 + final val BitIntMarker: Byte = 14 + final val BigDecimalMarker: Byte = 15 } import com.avsystem.commons.serialization.FormatConstants._ @@ -47,53 +49,68 @@ class StreamInput(is: DataInputStream) extends Input { InputType.Simple } - def readNull(): Null = if (markerByte == NullMarker) - null - else - throw new ReadFailure(s"Expected null, but $markerByte found") - - def readString(): String = if (markerByte == StringMarker) - is.readUTF() - else - throw new ReadFailure(s"Expected string, but $markerByte found") - - def readBoolean(): Boolean = if (markerByte == BooleanMarker) - is.readBoolean() - else - throw new ReadFailure(s"Expected boolean, but $markerByte found") - - def readInt(): Int = if (markerByte == IntMarker) - is.readInt() - else - throw new ReadFailure(s"Expected int, but $markerByte found") - - def readLong(): Long = if (markerByte == LongMarker) - is.readLong() - else - throw new ReadFailure(s"Expected long, but $markerByte found") - - def readDouble(): Double = if (markerByte == DoubleMarker) - is.readDouble() - else - throw new ReadFailure(s"Expected double, but $markerByte found") - - def readBinary(): Array[Byte] = if (markerByte == ByteArrayMarker) { - val binary = Array.ofDim[Byte](is.readInt()) - is.readFully(binary) - binary - } else { - throw new ReadFailure(s"Expected binary array, but $markerByte found") - } + def readNull(): Null = + if (markerByte == NullMarker) null + else throw new ReadFailure(s"Expected null but $markerByte found") + + def readString(): String = + if (markerByte == StringMarker) is.readUTF() + else throw new ReadFailure(s"Expected string but $markerByte found") + + def readBoolean(): Boolean = + if (markerByte == BooleanMarker) is.readBoolean() + else throw new ReadFailure(s"Expected boolean but $markerByte found") + + def readInt(): Int = + if (markerByte == IntMarker) is.readInt() + else throw new ReadFailure(s"Expected int but $markerByte found") + + def readLong(): Long = + if (markerByte == LongMarker) is.readLong() + else throw new ReadFailure(s"Expected long but $markerByte found") + + def readDouble(): Double = + if (markerByte == DoubleMarker) is.readDouble() + else throw new ReadFailure(s"Expected double but $markerByte found") + + def readBigInt(): BigInt = + if (markerByte == BitIntMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + BigInt(bytes) + } else { + throw new ReadFailure(s"Expected big integer but $markerByte found") + } - def readList(): ListInput = if (markerByte == ListStartMarker) - new StreamListInput(is) - else - throw new ReadFailure(s"Expected list, but $markerByte found") + def readBigDecimal(): BigDecimal = + if (markerByte == BigDecimalMarker) { + val len = is.readInt() + val bytes = new Array[Byte](len) + is.read(bytes) + val unscaled = BigInt(bytes) + val scale = is.readInt() + BigDecimal(unscaled, scale) + } else { + throw new ReadFailure(s"Expected big decimal but $markerByte found") + } - def readObject(): ObjectInput = if (markerByte == ObjectStartMarker) - new StreamObjectInput(is) - else - throw new ReadFailure(s"Expected object, but $markerByte found") + def readBinary(): Array[Byte] = + if (markerByte == ByteArrayMarker) { + val binary = Array.ofDim[Byte](is.readInt()) + is.readFully(binary) + binary + } else { + throw new ReadFailure(s"Expected binary array but $markerByte found") + } + + def readList(): ListInput = + if (markerByte == ListStartMarker) new StreamListInput(is) + else throw new ReadFailure(s"Expected list but $markerByte found") + + def readObject(): ObjectInput = + if (markerByte == ObjectStartMarker) new StreamObjectInput(is) + else throw new ReadFailure(s"Expected object but $markerByte found") def skip(): Unit = markerByte match { case NullMarker => @@ -121,6 +138,10 @@ class StreamInput(is: DataInputStream) extends Input { new StreamListInput(is).skipRemaining() case ObjectStartMarker => new StreamObjectInput(is).skipRemaining() + case BitIntMarker => + is.skipBytes(is.readInt()) + case BigDecimalMarker => + is.skipBytes(is.readInt() + Integer.BYTES) case unexpected => throw new ReadFailure(s"Unexpected marker byte: $unexpected") } @@ -182,14 +203,16 @@ private object StreamObjectInput { case class EmptyFieldInput(name: String) extends FieldInput { private def nope: Nothing = throw new ReadFailure(s"Something went horribly wrong ($name)") - def fieldName: String = nope def inputType: InputType = nope + def fieldName: String = nope def readNull(): Null = nope def readString(): String = nope def readBoolean(): Boolean = nope def readInt(): Int = nope def readLong(): Long = nope def readDouble(): Double = nope + def readBigInt(): BigInt = nope + def readBigDecimal(): BigDecimal = nope def readBinary(): Array[Byte] = nope def readList(): ListInput = nope def readObject(): ObjectInput = nope @@ -232,6 +255,21 @@ class StreamOutput(os: DataOutputStream) extends Output { os.writeDouble(double) } + def writeBigInt(bigInt: BigInt): Unit = { + os.writeByte(BitIntMarker) + val bytes = bigInt.toByteArray + os.writeInt(bytes.length) + os.write(bytes) + } + + def writeBigDecimal(bigDecimal: BigDecimal): Unit = { + os.writeByte(BigDecimalMarker) + val bytes = bigDecimal.bigDecimal.unscaledValue.toByteArray + os.writeInt(bytes.length) + os.write(bytes) + os.writeInt(bigDecimal.scale) + } + def writeBinary(binary: Array[Byte]): Unit = { os.writeByte(ByteArrayMarker) os.writeInt(binary.length) diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala index f065c0584..f6f813c6e 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringInput.scala @@ -33,7 +33,8 @@ class JsonStringInput(reader: JsonReader, callback: AfterElement = AfterElementN case _ => afterElement() } - private def expectedError(tpe: JsonType) = throw new ReadFailure(s"Expected $tpe but got ${reader.jsonType}: ${reader.currentValue}") + private def expectedError(tpe: JsonType) = + throw new ReadFailure(s"Expected $tpe but got ${reader.jsonType}: ${reader.currentValue}") private def checkedValue[T](jsonType: JsonType): T = { if (reader.jsonType != jsonType) expectedError(jsonType) @@ -64,6 +65,8 @@ class JsonStringInput(reader: JsonReader, callback: AfterElement = AfterElementN def readInt(): Int = matchNumericString(_.toInt) def readLong(): Long = matchNumericString(_.toLong) def readDouble(): Double = matchNumericString(_.toDouble) + def readBigInt(): BigInt = matchNumericString(BigInt(_)) + def readBigDecimal(): BigDecimal = matchNumericString(BigDecimal(_)) def readBinary(): Array[Byte] = { val hex = checkedValue[String](JsonType.string) val result = new Array[Byte](hex.length / 2) @@ -213,7 +216,7 @@ final class JsonReader(val json: String) { private def readHex(): Int = fromHex(read()) - private def parseNumber(): Any = { + private def parseNumber(): String = { val start = i if (isNext('-')) { diff --git a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala index 4213c1a2c..32c2470eb 100644 --- a/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala +++ b/commons-core/src/main/scala/com/avsystem/commons/serialization/json/JsonStringOutput.scala @@ -50,6 +50,8 @@ final class JsonStringOutput(builder: JStringBuilder) extends BaseJsonOutput wit writeString(double.toString) else builder.append(double.toString) } + def writeBigInt(bigInt: BigInt): Unit = builder.append(bigInt.toString) + def writeBigDecimal(bigDecimal: BigDecimal): Unit = builder.append(bigDecimal.toString) def writeBinary(binary: Array[Byte]): Unit = { builder.append('"') var i = 0 diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala index 63a8f4502..4432838f1 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/StreamInputOutputTest.scala @@ -22,9 +22,11 @@ case class FieldTypes( i: Int, j: Long, k: Double, - l: Array[Byte], - m: Obj, - n: List[List[Obj]] + l: BigInt, + m: BigDecimal, + n: Array[Byte], + o: Obj, + p: List[List[Obj]] ) class StreamInputOutputTest extends FunSuite { @@ -42,6 +44,8 @@ class StreamInputOutputTest extends FunSuite { -5, -6, -7.3, + BigInt("5345224654563123434325343"), + BigDecimal(BigInt("2356342454564522135435"), 150), Array[Byte](1, 2, 4, 2), Obj(10, "x"), List( @@ -90,8 +94,8 @@ class StreamInputOutputTest extends FunSuite { test("encode and decode all field types in a complicated structure") { val encoded = encDec(fieldTypesInstance) - assert(fieldTypesInstance.l sameElements encoded.l) - assert(fieldTypesInstance == encoded.copy(l = fieldTypesInstance.l)) + assert(fieldTypesInstance.n sameElements encoded.n) + assert(fieldTypesInstance == encoded.copy(n = fieldTypesInstance.n)) } test("raw API usage") { diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala index de1f499bf..fca6b1ff2 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/JsonStringInputOutputTest.scala @@ -199,6 +199,8 @@ class JsonStringInputOutputTest extends FunSuite with SerializationTestUtils wit deserialized.i2.long shouldBe item.i2.long deserialized.i2.float shouldBe item.i2.float deserialized.i2.double shouldBe item.i2.double + deserialized.i2.bigInt shouldBe item.i2.bigInt + deserialized.i2.bigDecimal shouldBe item.i2.bigDecimal deserialized.i2.binary shouldBe item.i2.binary deserialized.i2.list shouldBe item.i2.list deserialized.i2.set shouldBe item.i2.set @@ -207,7 +209,6 @@ class JsonStringInputOutputTest extends FunSuite with SerializationTestUtils wit } } - test("serialize and deserialize huge case classes") { implicit val arbTree: Arbitrary[DeepNestedTestCC] = Arbitrary { diff --git a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala index 9ba346748..00a65f2f8 100644 --- a/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala +++ b/commons-core/src/test/scala/com/avsystem/commons/serialization/json/SerializationTestUtils.scala @@ -26,8 +26,8 @@ trait SerializationTestUtils { case class CompleteItem( unit: Unit, string: String, char: Char, boolean: Boolean, byte: Byte, short: Short, int: Int, - long: Long, float: Float, double: Double, binary: Array[Byte], list: List[String], - set: Set[String], obj: TestCC, map: Map[String, Int] + long: Long, float: Float, double: Double, bigInt: BigInt, bigDecimal: BigDecimal, + binary: Array[Byte], list: List[String], set: Set[String], obj: TestCC, map: Map[String, Int] ) object CompleteItem extends HasGenCodec[CompleteItem] { implicit val arb: Arbitrary[CompleteItem] = Arbitrary(for { @@ -41,11 +41,13 @@ trait SerializationTestUtils { l <- arbitrary[Long] f <- arbitrary[Float] d <- arbitrary[Double] + bi <- arbitrary[BigInt] + bd <- arbitrary[BigDecimal] binary <- arbitrary[Array[Byte]] list <- arbitrary[List[String]] set <- arbitrary[Set[String]] obj <- arbitrary[TestCC] map <- arbitrary[Map[String, Int]] - } yield CompleteItem(u, str, c, bool, b, s, i, l, f, d, binary, list, set, obj, map)) + } yield CompleteItem(u, str, c, bool, b, s, i, l, f, d, bi, bd, binary, list, set, obj, map)) } } diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala index dd8f34e8e..ed9c41768 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonInputOutput.scala @@ -1,6 +1,8 @@ package com.avsystem.commons package mongo +import java.nio.ByteBuffer + import com.avsystem.commons.serialization.{Input, Output} import org.bson.types.ObjectId @@ -8,6 +10,24 @@ trait BsonInput extends Any with Input { def readObjectId(): ObjectId } +object BsonInput { + def bigDecimalFromBytes(bytes: Array[Byte]): BigDecimal = { + val buf = ByteBuffer.wrap(bytes) + val unscaledBytes = new Array[Byte](bytes.length - Integer.BYTES) + buf.get(unscaledBytes) + val unscaled = BigInt(unscaledBytes) + val scale = buf.getInt + BigDecimal(unscaled, scale) + } +} + trait BsonOutput extends Any with Output { def writeObjectId(objectId: ObjectId): Unit } + +object BsonOutput { + def bigDecimalBytes(bigDecimal: BigDecimal): Array[Byte] = { + val unscaledBytes = bigDecimal.bigDecimal.unscaledValue.toByteArray + ByteBuffer.allocate(unscaledBytes.length + Integer.BYTES).put(unscaledBytes).putInt(bigDecimal.scale).array + } +} diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala index 4c36887e6..77b256014 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonReaderInput.scala @@ -24,6 +24,8 @@ class BsonReaderInput(br: BsonReader) extends BsonInput { override def readLong(): Long = br.readInt64() override def readTimestamp(): Long = br.readDateTime() override def readDouble(): Double = br.readDouble() + override def readBigInt(): BigInt = BigInt(br.readBinaryData().getData) + override def readBigDecimal(): BigDecimal = BsonInput.bigDecimalFromBytes(br.readBinaryData().getData) override def readBinary(): Array[Byte] = br.readBinaryData().getData override def readList(): BsonReaderListInput = { br.readStartArray() diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala index c3ed12341..c9f59531f 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonValueOutput.scala @@ -33,6 +33,8 @@ final class BsonValueOutput(receiver: BsonValue => Unit = _ => ()) extends BsonO override def writeLong(long: Long): Unit = setValue(new BsonInt64(long)) override def writeTimestamp(millis: Long): Unit = setValue(new BsonDateTime(millis)) override def writeDouble(double: Double): Unit = setValue(new BsonDouble(double)) + override def writeBigInt(bigInt: BigInt): Unit = setValue(new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = setValue(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = setValue(new BsonBinary(binary)) override def writeList(): ListOutput = new BsonValueListOutput(setValue) override def writeObject(): ObjectOutput = new BsonValueObjectOutput(setValue) diff --git a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala index 12e373d57..5490c9e05 100644 --- a/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala +++ b/commons-mongo/src/main/scala/com/avsystem/commons/mongo/BsonWriterOutput.scala @@ -13,6 +13,10 @@ final class BsonWriterOutput(bw: BsonWriter) extends BsonOutput { override def writeLong(long: Long): Unit = bw.writeInt64(long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(millis) override def writeDouble(double: Double): Unit = bw.writeDouble(double) + override def writeBigInt(bigInt: BigInt): Unit = + bw.writeBinaryData(new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = + bw.writeBinaryData(new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { bw.writeStartArray() @@ -33,6 +37,10 @@ final class BsonWriterNamedOutput(escapedName: String, bw: BsonWriter) extends B override def writeLong(long: Long): Unit = bw.writeInt64(escapedName, long) override def writeTimestamp(millis: Long): Unit = bw.writeDateTime(escapedName, millis) override def writeDouble(double: Double): Unit = bw.writeDouble(escapedName, double) + override def writeBigInt(bigInt: BigInt): Unit = + bw.writeBinaryData(escapedName, new BsonBinary(bigInt.toByteArray)) + override def writeBigDecimal(bigDecimal: BigDecimal): Unit = + bw.writeBinaryData(escapedName, new BsonBinary(BsonOutput.bigDecimalBytes(bigDecimal))) override def writeBinary(binary: Array[Byte]): Unit = bw.writeBinaryData(escapedName, new BsonBinary(binary)) override def writeList(): BsonWriterListOutput = { bw.writeStartArray(escapedName) diff --git a/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala new file mode 100644 index 000000000..a77014a1e --- /dev/null +++ b/commons-mongo/src/test/scala/com/avsystem/commons/mongo/BigDecimalEncodingTest.scala @@ -0,0 +1,13 @@ +package com.avsystem.commons +package mongo + +import org.scalatest.FunSuite +import org.scalatest.prop.PropertyChecks + +class BigDecimalEncodingTest extends FunSuite with PropertyChecks { + test("BigDecimal BSON encoding") { + forAll { value: BigDecimal => + assert(value == BsonInput.bigDecimalFromBytes(BsonOutput.bigDecimalBytes(value))) + } + } +}