Skip to content

Commit e852f4c

Browse files
committed
tuntest: split out testing package
This code is useful to other packages writing tests. Signed-off-by: David Crawshaw <[email protected]>
1 parent d127a16 commit e852f4c

File tree

2 files changed

+155
-141
lines changed

2 files changed

+155
-141
lines changed

device/device_test.go

Lines changed: 5 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,12 @@ package device
88
import (
99
"bufio"
1010
"bytes"
11-
"encoding/binary"
12-
"io"
1311
"net"
14-
"os"
1512
"strings"
1613
"testing"
1714
"time"
1815

19-
"golang.zx2c4.com/wireguard/tun"
16+
"golang.zx2c4.com/wireguard/tun/tuntest"
2017
)
2118

2219
func TestTwoDevicePing(t *testing.T) {
@@ -29,7 +26,7 @@ protocol_version=1
2926
replace_allowed_ips=true
3027
allowed_ip=1.0.0.2/32
3128
endpoint=127.0.0.1:53512`
32-
tun1 := NewChannelTUN()
29+
tun1 := tuntest.NewChannelTUN()
3330
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
3431
dev1.Up()
3532
defer dev1.Close()
@@ -45,7 +42,7 @@ protocol_version=1
4542
replace_allowed_ips=true
4643
allowed_ip=1.0.0.1/32
4744
endpoint=127.0.0.1:53511`
48-
tun2 := NewChannelTUN()
45+
tun2 := tuntest.NewChannelTUN()
4946
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
5047
dev2.Up()
5148
defer dev2.Close()
@@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511`
5451
}
5552

5653
t.Run("ping 1.0.0.1", func(t *testing.T) {
57-
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
54+
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
5855
tun2.Outbound <- msg2to1
5956
select {
6057
case msgRecv := <-tun1.Inbound:
@@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511`
6764
})
6865

6966
t.Run("ping 1.0.0.2", func(t *testing.T) {
70-
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
67+
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
7168
tun1.Outbound <- msg1to2
7269
select {
7370
case msgRecv := <-tun2.Inbound:
@@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511`
8077
})
8178
}
8279

83-
func ping(dst, src net.IP) []byte {
84-
localPort := uint16(1337)
85-
seq := uint16(0)
86-
87-
payload := make([]byte, 4)
88-
binary.BigEndian.PutUint16(payload[0:], localPort)
89-
binary.BigEndian.PutUint16(payload[2:], seq)
90-
91-
return genICMPv4(payload, dst, src)
92-
}
93-
94-
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
95-
func checksum(buf []byte, initial uint16) uint16 {
96-
v := uint32(initial)
97-
for i := 0; i < len(buf)-1; i += 2 {
98-
v += uint32(binary.BigEndian.Uint16(buf[i:]))
99-
}
100-
if len(buf)%2 == 1 {
101-
v += uint32(buf[len(buf)-1]) << 8
102-
}
103-
for v > 0xffff {
104-
v = (v >> 16) + (v & 0xffff)
105-
}
106-
return ^uint16(v)
107-
}
108-
109-
func genICMPv4(payload []byte, dst, src net.IP) []byte {
110-
const (
111-
icmpv4ProtocolNumber = 1
112-
icmpv4Echo = 8
113-
icmpv4ChecksumOffset = 2
114-
icmpv4Size = 8
115-
ipv4Size = 20
116-
ipv4TotalLenOffset = 2
117-
ipv4ChecksumOffset = 10
118-
ttl = 65
119-
)
120-
121-
hdr := make([]byte, ipv4Size+icmpv4Size)
122-
123-
ip := hdr[0:ipv4Size]
124-
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
125-
126-
// https://tools.ietf.org/html/rfc792
127-
icmpv4[0] = icmpv4Echo // type
128-
icmpv4[1] = 0 // code
129-
chksum := ^checksum(icmpv4, checksum(payload, 0))
130-
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
131-
132-
// https://tools.ietf.org/html/rfc760 section 3.1
133-
length := uint16(len(hdr) + len(payload))
134-
ip[0] = (4 << 4) | (ipv4Size / 4)
135-
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
136-
ip[8] = ttl
137-
ip[9] = icmpv4ProtocolNumber
138-
copy(ip[12:], src.To4())
139-
copy(ip[16:], dst.To4())
140-
chksum = ^checksum(ip[:], 0)
141-
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
142-
143-
var v []byte
144-
v = append(v, hdr...)
145-
v = append(v, payload...)
146-
return []byte(v)
147-
}
148-
149-
// TODO(crawshaw): find a reusable home for this. package devicetest?
150-
type ChannelTUN struct {
151-
Inbound chan []byte // incoming packets, closed on TUN close
152-
Outbound chan []byte // outbound packets, blocks forever on TUN close
153-
154-
closed chan struct{}
155-
events chan tun.Event
156-
tun chTun
157-
}
158-
159-
func NewChannelTUN() *ChannelTUN {
160-
c := &ChannelTUN{
161-
Inbound: make(chan []byte),
162-
Outbound: make(chan []byte),
163-
closed: make(chan struct{}),
164-
events: make(chan tun.Event, 1),
165-
}
166-
c.tun.c = c
167-
c.events <- tun.EventUp
168-
return c
169-
}
170-
171-
func (c *ChannelTUN) TUN() tun.Device {
172-
return &c.tun
173-
}
174-
175-
type chTun struct {
176-
c *ChannelTUN
177-
}
178-
179-
func (t *chTun) File() *os.File { return nil }
180-
181-
func (t *chTun) Read(data []byte, offset int) (int, error) {
182-
select {
183-
case <-t.c.closed:
184-
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
185-
case msg := <-t.c.Outbound:
186-
return copy(data[offset:], msg), nil
187-
}
188-
}
189-
190-
// Write is called by the wireguard device to deliver a packet for routing.
191-
func (t *chTun) Write(data []byte, offset int) (int, error) {
192-
if offset == -1 {
193-
close(t.c.closed)
194-
close(t.c.events)
195-
return 0, io.EOF
196-
}
197-
msg := make([]byte, len(data)-offset)
198-
copy(msg, data[offset:])
199-
select {
200-
case <-t.c.closed:
201-
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
202-
case t.c.Inbound <- msg:
203-
return len(data) - offset, nil
204-
}
205-
}
206-
207-
func (t *chTun) Flush() error { return nil }
208-
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
209-
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
210-
func (t *chTun) Events() chan tun.Event { return t.c.events }
211-
func (t *chTun) Close() error {
212-
t.Write(nil, -1)
213-
return nil
214-
}
215-
21680
func assertNil(t *testing.T, err error) {
21781
if err != nil {
21882
t.Fatal(err)

tun/tuntest/tuntest.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* SPDX-License-Identifier: MIT
2+
*
3+
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
4+
*/
5+
6+
package tuntest
7+
8+
import (
9+
"encoding/binary"
10+
"io"
11+
"net"
12+
"os"
13+
14+
"golang.zx2c4.com/wireguard/tun"
15+
)
16+
17+
func Ping(dst, src net.IP) []byte {
18+
localPort := uint16(1337)
19+
seq := uint16(0)
20+
21+
payload := make([]byte, 4)
22+
binary.BigEndian.PutUint16(payload[0:], localPort)
23+
binary.BigEndian.PutUint16(payload[2:], seq)
24+
25+
return genICMPv4(payload, dst, src)
26+
}
27+
28+
// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
29+
func checksum(buf []byte, initial uint16) uint16 {
30+
v := uint32(initial)
31+
for i := 0; i < len(buf)-1; i += 2 {
32+
v += uint32(binary.BigEndian.Uint16(buf[i:]))
33+
}
34+
if len(buf)%2 == 1 {
35+
v += uint32(buf[len(buf)-1]) << 8
36+
}
37+
for v > 0xffff {
38+
v = (v >> 16) + (v & 0xffff)
39+
}
40+
return ^uint16(v)
41+
}
42+
43+
func genICMPv4(payload []byte, dst, src net.IP) []byte {
44+
const (
45+
icmpv4ProtocolNumber = 1
46+
icmpv4Echo = 8
47+
icmpv4ChecksumOffset = 2
48+
icmpv4Size = 8
49+
ipv4Size = 20
50+
ipv4TotalLenOffset = 2
51+
ipv4ChecksumOffset = 10
52+
ttl = 65
53+
)
54+
55+
hdr := make([]byte, ipv4Size+icmpv4Size)
56+
57+
ip := hdr[0:ipv4Size]
58+
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
59+
60+
// https://tools.ietf.org/html/rfc792
61+
icmpv4[0] = icmpv4Echo // type
62+
icmpv4[1] = 0 // code
63+
chksum := ^checksum(icmpv4, checksum(payload, 0))
64+
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
65+
66+
// https://tools.ietf.org/html/rfc760 section 3.1
67+
length := uint16(len(hdr) + len(payload))
68+
ip[0] = (4 << 4) | (ipv4Size / 4)
69+
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
70+
ip[8] = ttl
71+
ip[9] = icmpv4ProtocolNumber
72+
copy(ip[12:], src.To4())
73+
copy(ip[16:], dst.To4())
74+
chksum = ^checksum(ip[:], 0)
75+
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
76+
77+
var v []byte
78+
v = append(v, hdr...)
79+
v = append(v, payload...)
80+
return []byte(v)
81+
}
82+
83+
// TODO(crawshaw): find a reusable home for this. package devicetest?
84+
type ChannelTUN struct {
85+
Inbound chan []byte // incoming packets, closed on TUN close
86+
Outbound chan []byte // outbound packets, blocks forever on TUN close
87+
88+
closed chan struct{}
89+
events chan tun.Event
90+
tun chTun
91+
}
92+
93+
func NewChannelTUN() *ChannelTUN {
94+
c := &ChannelTUN{
95+
Inbound: make(chan []byte),
96+
Outbound: make(chan []byte),
97+
closed: make(chan struct{}),
98+
events: make(chan tun.Event, 1),
99+
}
100+
c.tun.c = c
101+
c.events <- tun.EventUp
102+
return c
103+
}
104+
105+
func (c *ChannelTUN) TUN() tun.Device {
106+
return &c.tun
107+
}
108+
109+
type chTun struct {
110+
c *ChannelTUN
111+
}
112+
113+
func (t *chTun) File() *os.File { return nil }
114+
115+
func (t *chTun) Read(data []byte, offset int) (int, error) {
116+
select {
117+
case <-t.c.closed:
118+
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
119+
case msg := <-t.c.Outbound:
120+
return copy(data[offset:], msg), nil
121+
}
122+
}
123+
124+
// Write is called by the wireguard device to deliver a packet for routing.
125+
func (t *chTun) Write(data []byte, offset int) (int, error) {
126+
if offset == -1 {
127+
close(t.c.closed)
128+
close(t.c.events)
129+
return 0, io.EOF
130+
}
131+
msg := make([]byte, len(data)-offset)
132+
copy(msg, data[offset:])
133+
select {
134+
case <-t.c.closed:
135+
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
136+
case t.c.Inbound <- msg:
137+
return len(data) - offset, nil
138+
}
139+
}
140+
141+
const DefaultMTU = 1420
142+
143+
func (t *chTun) Flush() error { return nil }
144+
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
145+
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
146+
func (t *chTun) Events() chan tun.Event { return t.c.events }
147+
func (t *chTun) Close() error {
148+
t.Write(nil, -1)
149+
return nil
150+
}

0 commit comments

Comments
 (0)