Skip to content

Commit cf34aa2

Browse files
BioTurboNickAndy FerrisvtjnashSeelengrab
authored
Round-trip reinterpret of all isbits types (#47116)
Hiding padding bytes in the process, to avoid undefined behavior if those are observed. Co-authored-by: Andy Ferris <[email protected]> Co-authored-by: Jameson Nash <[email protected]> Co-authored-by: Sukera <[email protected]>
1 parent 663c58d commit cf34aa2

File tree

4 files changed

+161
-19
lines changed

4 files changed

+161
-19
lines changed

base/reinterpretarray.jl

Lines changed: 131 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,8 @@ end
651651

652652
# Padding
653653
struct Padding
654-
offset::Int
655-
size::Int
654+
offset::Int # 0-indexed offset of the next valid byte; sizeof(T) indicates trailing padding
655+
size::Int # bytes of padding before a valid byte
656656
end
657657
function intersect(p1::Padding, p2::Padding)
658658
start = max(p1.offset, p2.offset)
@@ -696,20 +696,24 @@ function iterate(cp::CyclePadding, state::Tuple)
696696
end
697697

698698
"""
699-
Compute the location of padding in a type.
699+
Compute the location of padding in an isbits datatype. Recursive over the fields of that type.
700700
"""
701-
function padding(T)
702-
padding = Padding[]
703-
last_end::Int = 0
701+
@assume_effects :foldable function padding(T::DataType, baseoffset::Int = 0)
702+
pads = Padding[]
703+
last_end::Int = baseoffset
704704
for i = 1:fieldcount(T)
705-
offset = fieldoffset(T, i)
705+
offset = baseoffset + Int(fieldoffset(T, i))
706706
fT = fieldtype(T, i)
707+
append!(pads, padding(fT, offset))
707708
if offset != last_end
708-
push!(padding, Padding(offset, offset-last_end))
709+
push!(pads, Padding(offset, offset-last_end))
709710
end
710711
last_end = offset + sizeof(fT)
711712
end
712-
padding
713+
if 0 < last_end - baseoffset < sizeof(T)
714+
push!(pads, Padding(baseoffset + sizeof(T), sizeof(T) - last_end + baseoffset))
715+
end
716+
return Core.svec(pads...)
713717
end
714718

715719
function CyclePadding(T::DataType)
@@ -748,6 +752,124 @@ end
748752
return true
749753
end
750754

755+
@assume_effects :foldable function struct_subpadding(::Type{Out}, ::Type{In}) where {Out, In}
756+
padding(Out) == padding(In)
757+
end
758+
759+
@assume_effects :foldable function packedsize(::Type{T}) where T
760+
pads = padding(T)
761+
return sizeof(T) - sum((p.size for p pads), init = 0)
762+
end
763+
764+
@assume_effects :foldable ispacked(::Type{T}) where T = isempty(padding(T))
765+
766+
function _copytopacked!(ptr_out::Ptr{Out}, ptr_in::Ptr{In}) where {Out, In}
767+
writeoffset = 0
768+
for i 1:fieldcount(In)
769+
readoffset = fieldoffset(In, i)
770+
fT = fieldtype(In, i)
771+
if ispacked(fT)
772+
readsize = sizeof(fT)
773+
memcpy(ptr_out + writeoffset, ptr_in + readoffset, readsize)
774+
writeoffset += readsize
775+
else # nested padded type
776+
_copytopacked!(ptr_out + writeoffset, Ptr{fT}(ptr_in + readoffset))
777+
writeoffset += packedsize(fT)
778+
end
779+
end
780+
end
781+
782+
function _copyfrompacked!(ptr_out::Ptr{Out}, ptr_in::Ptr{In}) where {Out, In}
783+
readoffset = 0
784+
for i 1:fieldcount(Out)
785+
writeoffset = fieldoffset(Out, i)
786+
fT = fieldtype(Out, i)
787+
if ispacked(fT)
788+
writesize = sizeof(fT)
789+
memcpy(ptr_out + writeoffset, ptr_in + readoffset, writesize)
790+
readoffset += writesize
791+
else # nested padded type
792+
_copyfrompacked!(Ptr{fT}(ptr_out + writeoffset), ptr_in + readoffset)
793+
readoffset += packedsize(fT)
794+
end
795+
end
796+
end
797+
798+
"""
799+
reinterpret(::Type{Out}, x::In)
800+
801+
Reinterpret the valid non-padding bytes of an isbits value `x` as isbits type `Out`.
802+
803+
Both types must have the same amount of non-padding bytes. This operation is guaranteed
804+
to be reversible.
805+
806+
```jldoctest
807+
julia> reinterpret(NTuple{2, UInt8}, 0x1234)
808+
(0x34, 0x12)
809+
810+
julia> reinterpret(UInt16, (0x34, 0x12))
811+
0x1234
812+
813+
julia> reinterpret(Tuple{UInt16, UInt8}, (0x01, 0x0203))
814+
(0x0301, 0x02)
815+
```
816+
817+
!!! warning
818+
819+
Use caution if some combinations of bits in `Out` are not considered valid and would
820+
otherwise be prevented by the type's constructors and methods. Unexpected behavior
821+
may result without additional validation.
822+
"""
823+
@inline function reinterpret(::Type{Out}, x::In) where {Out, In}
824+
isbitstype(Out) || throw(ArgumentError("Target type for `reinterpret` must be isbits"))
825+
isbitstype(In) || throw(ArgumentError("Source type for `reinterpret` must be isbits"))
826+
if isprimitivetype(Out) && isprimitivetype(In)
827+
outsize = sizeof(Out)
828+
insize = sizeof(In)
829+
outsize == insize ||
830+
throw(ArgumentError("Sizes of types $Out and $In do not match; got $outsize \
831+
and $insize, respectively."))
832+
return bitcast(Out, x)
833+
end
834+
inpackedsize = packedsize(In)
835+
outpackedsize = packedsize(Out)
836+
inpackedsize == outpackedsize ||
837+
throw(ArgumentError("Packed sizes of types $Out and $In do not match; got $outpackedsize \
838+
and $inpackedsize, respectively."))
839+
in = Ref{In}(x)
840+
out = Ref{Out}()
841+
if struct_subpadding(Out, In)
842+
# if packed the same, just copy
843+
GC.@preserve in out begin
844+
ptr_in = unsafe_convert(Ptr{In}, in)
845+
ptr_out = unsafe_convert(Ptr{Out}, out)
846+
memcpy(ptr_out, ptr_in, sizeof(Out))
847+
end
848+
return out[]
849+
else
850+
# mismatched padding
851+
GC.@preserve in out begin
852+
ptr_in = unsafe_convert(Ptr{In}, in)
853+
ptr_out = unsafe_convert(Ptr{Out}, out)
854+
855+
if fieldcount(In) > 0 && ispacked(Out)
856+
_copytopacked!(ptr_out, ptr_in)
857+
elseif fieldcount(Out) > 0 && ispacked(In)
858+
_copyfrompacked!(ptr_out, ptr_in)
859+
else
860+
packed = Ref{NTuple{inpackedsize, UInt8}}()
861+
GC.@preserve packed begin
862+
ptr_packed = unsafe_convert(Ptr{NTuple{inpackedsize, UInt8}}, packed)
863+
_copytopacked!(ptr_packed, ptr_in)
864+
_copyfrompacked!(ptr_out, ptr_packed)
865+
end
866+
end
867+
end
868+
return out[]
869+
end
870+
end
871+
872+
751873
# Reductions with IndexSCartesian2
752874

753875
function _mapreduce(f::F, op::OP, style::IndexSCartesian2{K}, A::AbstractArrayOrBroadcasted) where {F,OP,K}

test/core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,7 @@ function f4528(A, B)
18971897
end
18981898
end
18991899
@test f4528(false, Int32(12)) === nothing
1900-
@test_throws ErrorException f4528(true, Int32(12))
1900+
@test_throws ArgumentError f4528(true, Int32(12))
19011901

19021902
# issue #4518
19031903
f4518(x, y::Union{Int32,Int64}) = 0

test/numbers.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2216,13 +2216,11 @@ end
22162216
@test round(Int16, -32768.1) === Int16(-32768)
22172217
end
22182218
# issue #7508
2219-
@test_throws ErrorException reinterpret(Int, 0x01)
2219+
@test_throws ArgumentError reinterpret(Int, 0x01)
22202220

22212221
@testset "issue #12832" begin
2222-
@test_throws ErrorException reinterpret(Float64, Complex{Int64}(1))
2223-
@test_throws ErrorException reinterpret(Float64, ComplexF32(1))
2224-
@test_throws ErrorException reinterpret(ComplexF32, Float64(1))
2225-
@test_throws ErrorException reinterpret(Int32, false)
2222+
@test_throws ArgumentError reinterpret(Float64, Complex{Int64}(1))
2223+
@test_throws ArgumentError reinterpret(Int32, false)
22262224
end
22272225
# issue #41
22282226
ndigf(n) = Float64(log(Float32(n)))

test/reinterpretarray.jl

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,10 @@ end
450450
SomeSingleton(x) = new()
451451
end
452452

453-
@test_throws ErrorException reinterpret(Int, nothing)
454-
@test_throws ErrorException reinterpret(Missing, 3)
455-
@test_throws ErrorException reinterpret(Missing, NotASingleton())
456-
@test_throws ErrorException reinterpret(NotASingleton, ())
453+
@test_throws ArgumentError reinterpret(Int, nothing)
454+
@test_throws ArgumentError reinterpret(Missing, 3)
455+
@test_throws ArgumentError reinterpret(Missing, NotASingleton())
456+
@test_throws ArgumentError reinterpret(NotASingleton, ())
457457

458458
@test_throws ArgumentError reinterpret(NotASingleton, fill(nothing, ()))
459459
@test_throws ArgumentError reinterpret(reshape, NotASingleton, fill(missing, 3))
@@ -513,3 +513,25 @@ end
513513
@test setindex!(x, SomeSingleton(:), 3, 5) == x2
514514
@test_throws MethodError x[2,4] = nothing
515515
end
516+
517+
# reinterpret of arbitrary bitstypes
518+
@testset "Reinterpret arbitrary bitstypes" begin
519+
struct Bytes15
520+
a::Int8
521+
b::Int16
522+
c::Int32
523+
d::Int64
524+
end
525+
526+
@test reinterpret(Float64, ComplexF32(1, 1)) === 0.007812501848093234
527+
@test reinterpret(ComplexF32, 0.007812501848093234) === ComplexF32(1, 1)
528+
@test reinterpret(Tuple{Float64, Float64}, ComplexF64(1, 1)) === (1.0, 1.0)
529+
@test reinterpret(ComplexF64, (1.0, 1.0)) === ComplexF64(1, 1)
530+
@test reinterpret(Tuple{Int8, Int16, Int32, Int64}, (Int64(1), Int32(2), Int16(3), Int8(4))) === (Int8(1), Int16(0), Int32(0), 288233674686595584)
531+
@test reinterpret(Tuple{Int8, Int16, Tuple{Int32, Int64}}, (Int64(1), Int32(2), Int16(3), Int8(4))) === (Int8(1), Int16(0), (Int32(0), 288233674686595584))
532+
@test reinterpret(Tuple{Int64, Int32, Int16, Int8}, (Int8(1), Int16(0), (Int32(0), 288233674686595584))) === (Int64(1), Int32(2), Int16(3), Int8(4))
533+
@test reinterpret(Tuple{Int8, Int16, Int32, Int64}, Bytes15(Int8(1), Int16(2), Int32(3), Int64(4))) === (Int8(1), Int16(2), Int32(3), Int64(4))
534+
@test reinterpret(Bytes15, (Int8(1), Int16(2), Int32(3), Int64(4))) == Bytes15(Int8(1), Int16(2), Int32(3), Int64(4))
535+
536+
@test_throws ArgumentError reinterpret(Tuple{Int32, Int64}, (Int16(1), Int64(4)))
537+
end

0 commit comments

Comments
 (0)