Skip to content

Commit 69f487d

Browse files
a-ilinstapelberg
authored andcommitted
set: Add set support for size specifier
Handle attribute NFTNL_SET_DESC_SIZE, as done in libnftnl: https://git.netfilter.org/libnftnl/tree/src/set.c#n424 Example: nft add set ip filter myset { type ipv4_addr\; size 65535\; flags dynamic\; }
1 parent b011eb1 commit 69f487d

File tree

3 files changed

+161
-2
lines changed

3 files changed

+161
-2
lines changed

nftables_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4103,6 +4103,46 @@ func TestSetElementsInterval(t *testing.T) {
41034103
}
41044104
}
41054105

4106+
func TestSetSizeConcat(t *testing.T) {
4107+
// Create a new network namespace to test these operations,
4108+
// and tear down the namespace at test completion.
4109+
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
4110+
defer nftest.CleanupSystemConn(t, newNS)
4111+
// Clear all rules at the beginning + end of the test.
4112+
c.FlushRuleset()
4113+
defer c.FlushRuleset()
4114+
4115+
filter := c.AddTable(&nftables.Table{
4116+
Family: nftables.TableFamilyIPv6,
4117+
Name: "filter",
4118+
})
4119+
4120+
set := &nftables.Set{
4121+
Name: "test-set",
4122+
Table: filter,
4123+
KeyType: nftables.MustConcatSetType(nftables.TypeIP6Addr, nftables.TypeInetService, nftables.TypeIP6Addr),
4124+
Dynamic: true,
4125+
Concatenation: true,
4126+
Size: 200,
4127+
}
4128+
4129+
if err := c.AddSet(set, nil); err != nil {
4130+
t.Errorf("c.AddSet(set) failed: %v", err)
4131+
}
4132+
4133+
if err := c.Flush(); err != nil {
4134+
t.Errorf("c.Flush() failed: %v", err)
4135+
}
4136+
4137+
sets, err := c.GetSets(filter)
4138+
if err != nil {
4139+
t.Errorf("c.GetSets() failed: %v", err)
4140+
}
4141+
if len(sets) != 1 {
4142+
t.Fatalf("len(sets) = %d, want 1", len(sets))
4143+
}
4144+
}
4145+
41064146
func TestCreateListFlowtable(t *testing.T) {
41074147
c, newNS := nftest.OpenSystemConn(t, *enableSysTests)
41084148
defer nftest.CleanupSystemConn(t, newNS)

set.go

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ type Set struct {
267267
// https://git.netfilter.org/nftables/tree/include/datatype.h?id=d486c9e626405e829221b82d7355558005b26d8a#n109
268268
KeyByteOrder binaryutil.ByteOrder
269269
Comment string
270+
// Indicates that the set has "size" specifier
271+
Size uint32
270272
}
271273

272274
// SetElement represents a data point within a set.
@@ -566,6 +568,21 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
566568
}
567569
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: numberOfElements})
568570
}
571+
572+
var descBytes []byte
573+
574+
if s.Size > 0 {
575+
// Marshal set size description
576+
descSizeBytes, err := netlink.MarshalAttributes([]netlink.Attribute{
577+
{Type: unix.NFTA_SET_DESC_SIZE, Data: binaryutil.BigEndian.PutUint32(s.Size)},
578+
})
579+
if err != nil {
580+
return fmt.Errorf("fail to marshal set size description: %w", err)
581+
}
582+
583+
descBytes = append(descBytes, descSizeBytes...)
584+
}
585+
569586
if s.Concatenation {
570587
// Length of concatenated types is a must, otherwise segfaults when executing nft list ruleset
571588
var concatDefinition []byte
@@ -592,8 +609,13 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error {
592609
if err != nil {
593610
return fmt.Errorf("fail to marshal concat definition %v", err)
594611
}
595-
// Marshal concat size description as set description
596-
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: concatBytes})
612+
613+
descBytes = append(descBytes, concatBytes...)
614+
}
615+
616+
if len(descBytes) > 0 {
617+
// Marshal set description
618+
tableInfo = append(tableInfo, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_SET_DESC, Data: descBytes})
597619
}
598620

599621
// https://git.netfilter.org/libnftnl/tree/include/udata.h#n17
@@ -776,6 +798,20 @@ func setsFromMsg(msg netlink.Message) (*Set, error) {
776798
data := ad.Bytes()
777799
value, ok := userdata.GetUint32(data, userdata.NFTNL_UDATA_SET_MERGE_ELEMENTS)
778800
set.AutoMerge = ok && value == 1
801+
case unix.NFTA_SET_DESC:
802+
nestedAD, err := netlink.NewAttributeDecoder(ad.Bytes())
803+
if err != nil {
804+
return nil, fmt.Errorf("nested NewAttributeDecoder() failed: %w", err)
805+
}
806+
for nestedAD.Next() {
807+
switch nestedAD.Type() {
808+
case unix.NFTA_SET_DESC_SIZE:
809+
set.Size = binary.BigEndian.Uint32(nestedAD.Bytes())
810+
}
811+
}
812+
if nestedAD.Err() != nil {
813+
return nil, fmt.Errorf("decoding set description: %w", nestedAD.Err())
814+
}
779815
}
780816
}
781817
return &set, nil

set_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package nftables
22

33
import (
4+
"reflect"
45
"testing"
6+
"time"
7+
8+
"github.com/mdlayher/netlink"
59
)
610

711
// unknownNFTMagic is an nftMagic value that's unhandled by this
@@ -185,3 +189,82 @@ func TestConcatSetTypeElements(t *testing.T) {
185189
})
186190
}
187191
}
192+
193+
func TestMarshalSet(t *testing.T) {
194+
t.Parallel()
195+
196+
tbl := &Table{
197+
Name: "ipv4table",
198+
Family: TableFamilyIPv4,
199+
}
200+
201+
c, err := New(WithTestDial(
202+
func(req []netlink.Message) ([]netlink.Message, error) {
203+
return req, nil
204+
}))
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
209+
c.AddTable(tbl)
210+
211+
// Ensure the table is added.
212+
const connMsgStart = 1
213+
if len(c.messages) != connMsgStart {
214+
t.Fatalf("AddSet() wrong start message count: %d, expected: %d", len(c.messages), connMsgStart)
215+
}
216+
217+
tests := []struct {
218+
name string
219+
set Set
220+
}{
221+
{
222+
name: "Set without flags",
223+
set: Set{
224+
Name: "test-set",
225+
ID: uint32(1),
226+
Table: tbl,
227+
KeyType: TypeIPAddr,
228+
},
229+
},
230+
{
231+
name: "Set with size, timeout, dynamic flag specified",
232+
set: Set{
233+
Name: "test-set",
234+
ID: uint32(2),
235+
HasTimeout: true,
236+
Dynamic: true,
237+
Size: 10,
238+
Table: tbl,
239+
KeyType: TypeIPAddr,
240+
Timeout: 30 * time.Second,
241+
},
242+
},
243+
}
244+
245+
for i, tt := range tests {
246+
t.Run(tt.name, func(t *testing.T) {
247+
if err := c.AddSet(&tt.set, nil); err != nil {
248+
t.Fatal(err)
249+
}
250+
251+
connMsgSetIdx := connMsgStart + i
252+
if len(c.messages) != connMsgSetIdx+1 {
253+
t.Fatalf("AddSet() wrong message count: %d, expected: %d", len(c.messages), connMsgSetIdx+1)
254+
}
255+
msg := c.messages[connMsgSetIdx]
256+
257+
nset, err := setsFromMsg(msg)
258+
if err != nil {
259+
t.Fatalf("setsFromMsg() error: %+v", err)
260+
}
261+
262+
// Table pointer is set after flush, which is not implemented in the test.
263+
tt.set.Table = nil
264+
265+
if !reflect.DeepEqual(&tt.set, nset) {
266+
t.Fatalf("original %+v and recovered %+v Set structs are different", tt.set, nset)
267+
}
268+
})
269+
}
270+
}

0 commit comments

Comments
 (0)