From 385f25a20c537acf97b3c5c008d1173a290a5ff0 Mon Sep 17 00:00:00 2001 From: Ken Schneider <103530259+ken-schneider@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:54:42 -0500 Subject: [PATCH] [NETPATH-371] Move common functions to separate package, create separate testutils package (#31819) --- pkg/networkpath/traceroute/common/common.go | 160 ++++++++++++ .../traceroute/common/common_test.go | 117 +++++++++ pkg/networkpath/traceroute/runner/runner.go | 3 +- pkg/networkpath/traceroute/tcp/tcpv4.go | 22 -- pkg/networkpath/traceroute/tcp/tcpv4_unix.go | 13 +- .../traceroute/tcp/tcpv4_windows.go | 13 +- pkg/networkpath/traceroute/tcp/utils.go | 125 +--------- pkg/networkpath/traceroute/tcp/utils_test.go | 235 +----------------- pkg/networkpath/traceroute/tcp/utils_unix.go | 11 +- .../traceroute/tcp/utils_unix_test.go | 10 +- .../traceroute/tcp/utils_windows.go | 9 +- .../traceroute/tcp/utils_windows_test.go | 5 +- pkg/networkpath/traceroute/testutils/doc.go | 7 + .../traceroute/testutils/testutils.go | 150 +++++++++++ 14 files changed, 481 insertions(+), 399 deletions(-) create mode 100644 pkg/networkpath/traceroute/common/common.go create mode 100644 pkg/networkpath/traceroute/common/common_test.go create mode 100644 pkg/networkpath/traceroute/testutils/doc.go create mode 100644 pkg/networkpath/traceroute/testutils/testutils.go diff --git a/pkg/networkpath/traceroute/common/common.go b/pkg/networkpath/traceroute/common/common.go new file mode 100644 index 0000000000000..c7622fa391891 --- /dev/null +++ b/pkg/networkpath/traceroute/common/common.go @@ -0,0 +1,160 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +// Package common contains common functionality for both TCP and UDP +// traceroute implementations +package common + +import ( + "fmt" + "net" + "strconv" + "time" + + "github.com/DataDog/datadog-agent/pkg/util/log" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv4" +) + +const ( + // IPProtoICMP is the IP protocol number for ICMP + // we create our own constant here because there are + // different imports for the constant in different + // operating systems + IPProtoICMP = 1 +) + +type ( + // Results encapsulates a response from the + // traceroute + Results struct { + Source net.IP + SourcePort uint16 + Target net.IP + DstPort uint16 + Hops []*Hop + } + + // Hop encapsulates information about a single + // hop in a traceroute + Hop struct { + IP net.IP + Port uint16 + ICMPType layers.ICMPv4TypeCode + RTT time.Duration + IsDest bool + } + + // CanceledError is sent when a listener + // is canceled + CanceledError string + + // ICMPResponse encapsulates the data from + // an ICMP response packet needed for matching + ICMPResponse struct { + SrcIP net.IP + DstIP net.IP + TypeCode layers.ICMPv4TypeCode + InnerSrcIP net.IP + InnerDstIP net.IP + InnerSrcPort uint16 + InnerDstPort uint16 + InnerSeqNum uint32 + } +) + +func (c CanceledError) Error() string { + return string(c) +} + +// LocalAddrForHost takes in a destionation IP and port and returns the local +// address that should be used to connect to the destination +func LocalAddrForHost(destIP net.IP, destPort uint16) (*net.UDPAddr, error) { + // this is a quick way to get the local address for connecting to the host + // using UDP as the network type to avoid actually creating a connection to + // the host, just get the OS to give us a local IP and local ephemeral port + conn, err := net.Dial("udp4", net.JoinHostPort(destIP.String(), strconv.Itoa(int(destPort)))) + if err != nil { + return nil, err + } + defer conn.Close() + localAddr := conn.LocalAddr() + + localUDPAddr, ok := localAddr.(*net.UDPAddr) + if !ok { + return nil, fmt.Errorf("invalid address type for %s: want %T, got %T", localAddr, localUDPAddr, localAddr) + } + + return localUDPAddr, nil +} + +// ParseICMP takes in an IPv4 header and payload and tries to convert to an ICMP +// message, it returns all the fields from the packet we need to validate it's the response +// we're looking for +func ParseICMP(header *ipv4.Header, payload []byte) (*ICMPResponse, error) { + // in addition to parsing, it is probably not a bad idea to do some validation + // so we can ignore the ICMP packets we don't care about + icmpResponse := ICMPResponse{} + + if header.Protocol != IPProtoICMP || header.Version != 4 || + header.Src == nil || header.Dst == nil { + return nil, fmt.Errorf("invalid IP header for ICMP packet: %+v", header) + } + icmpResponse.SrcIP = header.Src + icmpResponse.DstIP = header.Dst + + var icmpv4Layer layers.ICMPv4 + decoded := []gopacket.LayerType{} + icmpParser := gopacket.NewDecodingLayerParser(layers.LayerTypeICMPv4, &icmpv4Layer) + icmpParser.IgnoreUnsupported = true // ignore unsupported layers, we will decode them in the next step + if err := icmpParser.DecodeLayers(payload, &decoded); err != nil { + return nil, fmt.Errorf("failed to decode ICMP packet: %w", err) + } + // since we ignore unsupported layers, we need to check if we actually decoded + // anything + if len(decoded) < 1 { + return nil, fmt.Errorf("failed to decode ICMP packet, no layers decoded") + } + icmpResponse.TypeCode = icmpv4Layer.TypeCode + + var icmpPayload []byte + if len(icmpv4Layer.Payload) < 40 { + log.Tracef("Payload length %d is less than 40, extending...\n", len(icmpv4Layer.Payload)) + icmpPayload = make([]byte, 40) + copy(icmpPayload, icmpv4Layer.Payload) + // we have to set this in order for the TCP + // parser to work + icmpPayload[32] = 5 << 4 // set data offset + } else { + icmpPayload = icmpv4Layer.Payload + } + + // a separate parser is needed to decode the inner IP and TCP headers because + // gopacket doesn't support this type of nesting in a single decoder + var innerIPLayer layers.IPv4 + var innerTCPLayer layers.TCP + innerIPParser := gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &innerIPLayer, &innerTCPLayer) + if err := innerIPParser.DecodeLayers(icmpPayload, &decoded); err != nil { + return nil, fmt.Errorf("failed to decode inner ICMP payload: %w", err) + } + icmpResponse.InnerSrcIP = innerIPLayer.SrcIP + icmpResponse.InnerDstIP = innerIPLayer.DstIP + icmpResponse.InnerSrcPort = uint16(innerTCPLayer.SrcPort) + icmpResponse.InnerDstPort = uint16(innerTCPLayer.DstPort) + icmpResponse.InnerSeqNum = innerTCPLayer.Seq + + return &icmpResponse, nil +} + +// ICMPMatch checks if an ICMP response matches the expected response +// based on the local and remote IP, port, and sequence number +func ICMPMatch(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32, response *ICMPResponse) bool { + return localIP.Equal(response.InnerSrcIP) && + remoteIP.Equal(response.InnerDstIP) && + localPort == response.InnerSrcPort && + remotePort == response.InnerDstPort && + seqNum == response.InnerSeqNum +} diff --git a/pkg/networkpath/traceroute/common/common_test.go b/pkg/networkpath/traceroute/common/common_test.go new file mode 100644 index 0000000000000..aa6f6cfdf24cb --- /dev/null +++ b/pkg/networkpath/traceroute/common/common_test.go @@ -0,0 +1,117 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build test + +package common + +import ( + "net" + "testing" + + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/testutils" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/ipv4" +) + +var ( + srcIP = net.ParseIP("1.2.3.4") + dstIP = net.ParseIP("5.6.7.8") + + innerSrcIP = net.ParseIP("10.0.0.1") + innerDstIP = net.ParseIP("192.168.1.1") +) + +func Test_parseICMP(t *testing.T) { + ipv4Header := testutils.CreateMockIPv4Header(srcIP, dstIP, 1) + icmpLayer := testutils.CreateMockICMPLayer(layers.ICMPv4CodeTTLExceeded) + innerIPv4Layer := testutils.CreateMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP) + innerTCPLayer := testutils.CreateMockTCPLayer(12345, 443, 28394, 12737, true, true, true) + + tt := []struct { + description string + inHeader *ipv4.Header + inPayload []byte + expected *ICMPResponse + errMsg string + }{ + { + description: "empty IPv4 layer should return an error", + inHeader: &ipv4.Header{}, + inPayload: []byte{}, + expected: nil, + errMsg: "invalid IP header for ICMP packet", + }, + { + description: "missing ICMP layer should return an error", + inHeader: ipv4Header, + inPayload: []byte{}, + expected: nil, + errMsg: "failed to decode ICMP packet", + }, + { + description: "missing inner layers should return an error", + inHeader: ipv4Header, + inPayload: testutils.CreateMockICMPPacket(nil, icmpLayer, nil, nil, false), + expected: nil, + errMsg: "failed to decode inner ICMP payload", + }, + { + description: "ICMP packet with partial TCP header should create icmpResponse", + inHeader: ipv4Header, + inPayload: testutils.CreateMockICMPPacket(nil, icmpLayer, innerIPv4Layer, innerTCPLayer, true), + expected: &ICMPResponse{ + SrcIP: srcIP, + DstIP: dstIP, + InnerSrcIP: innerSrcIP, + InnerDstIP: innerDstIP, + InnerSrcPort: 12345, + InnerDstPort: 443, + InnerSeqNum: 28394, + }, + errMsg: "", + }, + { + description: "full ICMP packet should create icmpResponse", + inHeader: ipv4Header, + inPayload: testutils.CreateMockICMPPacket(nil, icmpLayer, innerIPv4Layer, innerTCPLayer, true), + expected: &ICMPResponse{ + SrcIP: srcIP, + DstIP: dstIP, + InnerSrcIP: innerSrcIP, + InnerDstIP: innerDstIP, + InnerSrcPort: 12345, + InnerDstPort: 443, + InnerSeqNum: 28394, + }, + errMsg: "", + }, + } + + for _, test := range tt { + t.Run(test.description, func(t *testing.T) { + actual, err := ParseICMP(test.inHeader, test.inPayload) + if test.errMsg != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), test.errMsg) + assert.Nil(t, actual) + return + } + require.Nil(t, err) + require.NotNil(t, actual) + // assert.Equal doesn't handle net.IP well + assert.Equal(t, testutils.StructFieldCount(test.expected), testutils.StructFieldCount(actual)) + assert.Truef(t, test.expected.SrcIP.Equal(actual.SrcIP), "mismatch source IPs: expected %s, got %s", test.expected.SrcIP.String(), actual.SrcIP.String()) + assert.Truef(t, test.expected.DstIP.Equal(actual.DstIP), "mismatch dest IPs: expected %s, got %s", test.expected.DstIP.String(), actual.DstIP.String()) + assert.Truef(t, test.expected.InnerSrcIP.Equal(actual.InnerSrcIP), "mismatch inner source IPs: expected %s, got %s", test.expected.InnerSrcIP.String(), actual.InnerSrcIP.String()) + assert.Truef(t, test.expected.InnerDstIP.Equal(actual.InnerDstIP), "mismatch inner dest IPs: expected %s, got %s", test.expected.InnerDstIP.String(), actual.InnerDstIP.String()) + assert.Equal(t, test.expected.InnerSrcPort, actual.InnerSrcPort) + assert.Equal(t, test.expected.InnerDstPort, actual.InnerDstPort) + assert.Equal(t, test.expected.InnerSeqNum, actual.InnerSeqNum) + }) + } +} diff --git a/pkg/networkpath/traceroute/runner/runner.go b/pkg/networkpath/traceroute/runner/runner.go index 9087b66b4eb8a..b8240ee6db3cd 100644 --- a/pkg/networkpath/traceroute/runner/runner.go +++ b/pkg/networkpath/traceroute/runner/runner.go @@ -24,6 +24,7 @@ import ( "github.com/DataDog/datadog-agent/pkg/config/setup" "github.com/DataDog/datadog-agent/pkg/network" "github.com/DataDog/datadog-agent/pkg/networkpath/payload" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/common" "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/config" "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/tcp" "github.com/DataDog/datadog-agent/pkg/process/util" @@ -222,7 +223,7 @@ func (r *Runner) runTCP(cfg config.Config, hname string, target net.IP, maxTTL u return pathResult, nil } -func (r *Runner) processTCPResults(res *tcp.Results, hname string, destinationHost string, destinationPort uint16, destinationIP net.IP) (payload.NetworkPath, error) { +func (r *Runner) processTCPResults(res *common.Results, hname string, destinationHost string, destinationPort uint16, destinationIP net.IP) (payload.NetworkPath, error) { traceroutePath := payload.NetworkPath{ AgentVersion: version.AgentVersion, PathtraceID: payload.NewPathtraceID(), diff --git a/pkg/networkpath/traceroute/tcp/tcpv4.go b/pkg/networkpath/traceroute/tcp/tcpv4.go index 64484f9c0ad60..ca38edf21cf4c 100644 --- a/pkg/networkpath/traceroute/tcp/tcpv4.go +++ b/pkg/networkpath/traceroute/tcp/tcpv4.go @@ -9,8 +9,6 @@ package tcp import ( "net" "time" - - "github.com/google/gopacket/layers" ) type ( @@ -27,26 +25,6 @@ type ( Delay time.Duration // delay between sending packets (not applicable if we go the serial send/receive route) Timeout time.Duration // full timeout for all packets } - - // Results encapsulates a response from the TCP - // traceroute - Results struct { - Source net.IP - SourcePort uint16 - Target net.IP - DstPort uint16 - Hops []*Hop - } - - // Hop encapsulates information about a single - // hop in a TCP traceroute - Hop struct { - IP net.IP - Port uint16 - ICMPType layers.ICMPv4TypeCode - RTT time.Duration - IsDest bool - } ) // Close doesn't to anything yet, but we should diff --git a/pkg/networkpath/traceroute/tcp/tcpv4_unix.go b/pkg/networkpath/traceroute/tcp/tcpv4_unix.go index 32cf7e19ee11e..f5859d056e00e 100644 --- a/pkg/networkpath/traceroute/tcp/tcpv4_unix.go +++ b/pkg/networkpath/traceroute/tcp/tcpv4_unix.go @@ -16,15 +16,16 @@ import ( "golang.org/x/net/ipv4" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/common" "github.com/DataDog/datadog-agent/pkg/util/log" ) // TracerouteSequential runs a traceroute sequentially where a packet is // sent and we wait for a response before sending the next packet -func (t *TCPv4) TracerouteSequential() (*Results, error) { +func (t *TCPv4) TracerouteSequential() (*common.Results, error) { // Get local address for the interface that connects to this // host and store in in the probe - addr, err := localAddrForHost(t.Target, t.DestPort) + addr, err := common.LocalAddrForHost(t.Target, t.DestPort) if err != nil { return nil, fmt.Errorf("failed to get local address for target: %w", err) } @@ -71,7 +72,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { } // hops should be of length # of hops - hops := make([]*Hop, 0, t.MaxTTL-t.MinTTL) + hops := make([]*common.Hop, 0, t.MaxTTL-t.MinTTL) for i := int(t.MinTTL); i <= int(t.MaxTTL); i++ { seqNumber := rand.Uint32() @@ -88,7 +89,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { } } - return &Results{ + return &common.Results{ Source: t.srcIP, SourcePort: t.srcPort, Target: t.Target, @@ -97,7 +98,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { }, nil } -func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawConn, ttl int, seqNum uint32, timeout time.Duration) (*Hop, error) { +func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawConn, ttl int, seqNum uint32, timeout time.Duration) (*common.Hop, error) { tcpHeader, tcpPacket, err := createRawTCPSyn(t.srcIP, t.srcPort, t.Target, t.DestPort, seqNum, ttl) if err != nil { log.Errorf("failed to create TCP packet with TTL: %d, error: %s", ttl, err.Error()) @@ -122,7 +123,7 @@ func (t *TCPv4) sendAndReceive(rawIcmpConn *ipv4.RawConn, rawTCPConn *ipv4.RawCo rtt = end.Sub(start) } - return &Hop{ + return &common.Hop{ IP: hopIP, Port: hopPort, ICMPType: icmpType, diff --git a/pkg/networkpath/traceroute/tcp/tcpv4_windows.go b/pkg/networkpath/traceroute/tcp/tcpv4_windows.go index 3067695b0e559..e9eb74cae702f 100644 --- a/pkg/networkpath/traceroute/tcp/tcpv4_windows.go +++ b/pkg/networkpath/traceroute/tcp/tcpv4_windows.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sys/windows" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/common" "github.com/DataDog/datadog-agent/pkg/util/log" ) @@ -67,14 +68,14 @@ func createRawSocket() (*winrawsocket, error) { // TracerouteSequential runs a traceroute sequentially where a packet is // sent and we wait for a response before sending the next packet -func (t *TCPv4) TracerouteSequential() (*Results, error) { +func (t *TCPv4) TracerouteSequential() (*common.Results, error) { log.Debugf("Running traceroute to %+v", t) // Get local address for the interface that connects to this // host and store in in the probe // // TODO: do this once for the probe and hang on to the // listener until we decide to close the probe - addr, err := localAddrForHost(t.Target, t.DestPort) + addr, err := common.LocalAddrForHost(t.Target, t.DestPort) if err != nil { return nil, fmt.Errorf("failed to get local address for target: %w", err) } @@ -87,7 +88,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { } defer rs.close() - hops := make([]*Hop, 0, int(t.MaxTTL-t.MinTTL)+1) + hops := make([]*common.Hop, 0, int(t.MaxTTL-t.MinTTL)+1) for i := int(t.MinTTL); i <= int(t.MaxTTL); i++ { seqNumber := rand.Uint32() @@ -104,7 +105,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { } } - return &Results{ + return &common.Results{ Source: t.srcIP, SourcePort: t.srcPort, Target: t.Target, @@ -113,7 +114,7 @@ func (t *TCPv4) TracerouteSequential() (*Results, error) { }, nil } -func (t *TCPv4) sendAndReceive(rs *winrawsocket, ttl int, seqNum uint32, timeout time.Duration) (*Hop, error) { +func (t *TCPv4) sendAndReceive(rs *winrawsocket, ttl int, seqNum uint32, timeout time.Duration) (*common.Hop, error) { _, buffer, _, err := createRawTCPSynBuffer(t.srcIP, t.srcPort, t.Target, t.DestPort, seqNum, ttl) if err != nil { log.Errorf("failed to create TCP packet with TTL: %d, error: %s", ttl, err.Error()) @@ -138,7 +139,7 @@ func (t *TCPv4) sendAndReceive(rs *winrawsocket, ttl int, seqNum uint32, timeout rtt = end.Sub(start) } - return &Hop{ + return &common.Hop{ IP: hopIP, Port: hopPort, ICMPType: icmpType, diff --git a/pkg/networkpath/traceroute/tcp/utils.go b/pkg/networkpath/traceroute/tcp/utils.go index ae7d507f23940..3bd9e437e90c7 100644 --- a/pkg/networkpath/traceroute/tcp/utils.go +++ b/pkg/networkpath/traceroute/tcp/utils.go @@ -8,46 +8,14 @@ package tcp import ( "fmt" "net" - "strconv" + "syscall" - "github.com/DataDog/datadog-agent/pkg/util/log" "github.com/google/gopacket" "github.com/google/gopacket/layers" "golang.org/x/net/ipv4" ) -const ( - // ACK is the acknowledge TCP flag - ACK = 1 << 4 - // RST is the reset TCP flag - RST = 1 << 2 - // SYN is the synchronization TCP flag - SYN = 1 << 1 - - // IPProtoICMP is the ICMP protocol number - IPProtoICMP = 1 - // IPProtoTCP is the TCP protocol number - IPProtoTCP = 6 -) - type ( - // canceledError is sent when a listener - // is canceled - canceledError string - - // icmpResponse encapsulates the data from - // an ICMP response packet needed for matching - icmpResponse struct { - SrcIP net.IP - DstIP net.IP - TypeCode layers.ICMPv4TypeCode - InnerSrcIP net.IP - InnerDstIP net.IP - InnerSrcPort uint16 - InnerDstPort uint16 - InnerSeqNum uint32 - } - // tcpResponse encapsulates the data from a // TCP response needed for matching tcpResponse struct { @@ -57,25 +25,6 @@ type ( } ) -func localAddrForHost(destIP net.IP, destPort uint16) (*net.UDPAddr, error) { - // this is a quick way to get the local address for connecting to the host - // using UDP as the network type to avoid actually creating a connection to - // the host, just get the OS to give us a local IP and local ephemeral port - conn, err := net.Dial("udp4", net.JoinHostPort(destIP.String(), strconv.Itoa(int(destPort)))) - if err != nil { - return nil, err - } - defer conn.Close() - localAddr := conn.LocalAddr() - - localUDPAddr, ok := localAddr.(*net.UDPAddr) - if !ok { - return nil, fmt.Errorf("invalid address type for %s: want %T, got %T", localAddr, localUDPAddr, localAddr) - } - - return localUDPAddr, nil -} - // reserveLocalPort reserves an ephemeral TCP port // and returns both the listener and port because the // listener should be held until the port is no longer @@ -145,64 +94,6 @@ func createRawTCPSynBuffer(sourceIP net.IP, sourcePort uint16, destIP net.IP, de return &ipHdr, packet, 20, nil } -// parseICMP takes in an IPv4 header and payload and tries to convert to an ICMP -// message, it returns all the fields from the packet we need to validate it's the response -// we're looking for -func parseICMP(header *ipv4.Header, payload []byte) (*icmpResponse, error) { - // in addition to parsing, it is probably not a bad idea to do some validation - // so we can ignore the ICMP packets we don't care about - icmpResponse := icmpResponse{} - - if header.Protocol != IPProtoICMP || header.Version != 4 || - header.Src == nil || header.Dst == nil { - return nil, fmt.Errorf("invalid IP header for ICMP packet: %+v", header) - } - icmpResponse.SrcIP = header.Src - icmpResponse.DstIP = header.Dst - - var icmpv4Layer layers.ICMPv4 - decoded := []gopacket.LayerType{} - icmpParser := gopacket.NewDecodingLayerParser(layers.LayerTypeICMPv4, &icmpv4Layer) - icmpParser.IgnoreUnsupported = true // ignore unsupported layers, we will decode them in the next step - if err := icmpParser.DecodeLayers(payload, &decoded); err != nil { - return nil, fmt.Errorf("failed to decode ICMP packet: %w", err) - } - // since we ignore unsupported layers, we need to check if we actually decoded - // anything - if len(decoded) < 1 { - return nil, fmt.Errorf("failed to decode ICMP packet, no layers decoded") - } - icmpResponse.TypeCode = icmpv4Layer.TypeCode - - var icmpPayload []byte - if len(icmpv4Layer.Payload) < 40 { - log.Tracef("Payload length %d is less than 40, extending...\n", len(icmpv4Layer.Payload)) - icmpPayload = make([]byte, 40) - copy(icmpPayload, icmpv4Layer.Payload) - // we have to set this in order for the TCP - // parser to work - icmpPayload[32] = 5 << 4 // set data offset - } else { - icmpPayload = icmpv4Layer.Payload - } - - // a separate parser is needed to decode the inner IP and TCP headers because - // gopacket doesn't support this type of nesting in a single decoder - var innerIPLayer layers.IPv4 - var innerTCPLayer layers.TCP - innerIPParser := gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &innerIPLayer, &innerTCPLayer) - if err := innerIPParser.DecodeLayers(icmpPayload, &decoded); err != nil { - return nil, fmt.Errorf("failed to decode inner ICMP payload: %w", err) - } - icmpResponse.InnerSrcIP = innerIPLayer.SrcIP - icmpResponse.InnerDstIP = innerIPLayer.DstIP - icmpResponse.InnerSrcPort = uint16(innerTCPLayer.SrcPort) - icmpResponse.InnerDstPort = uint16(innerTCPLayer.DstPort) - icmpResponse.InnerSeqNum = innerTCPLayer.Seq - - return &icmpResponse, nil -} - type tcpParser struct { layer layers.TCP decoded []gopacket.LayerType @@ -216,7 +107,7 @@ func newTCPParser() *tcpParser { } func (tp *tcpParser) parseTCP(header *ipv4.Header, payload []byte) (*tcpResponse, error) { - if header.Protocol != IPProtoTCP || header.Version != 4 || + if header.Protocol != syscall.IPPROTO_TCP || header.Version != 4 || header.Src == nil || header.Dst == nil { return nil, fmt.Errorf("invalid IP header for TCP packet: %+v", header) } @@ -236,14 +127,6 @@ func (tp *tcpParser) parseTCP(header *ipv4.Header, payload []byte) (*tcpResponse return resp, nil } -func icmpMatch(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32, response *icmpResponse) bool { - return localIP.Equal(response.InnerSrcIP) && - remoteIP.Equal(response.InnerDstIP) && - localPort == response.InnerSrcPort && - remotePort == response.InnerDstPort && - seqNum == response.InnerSeqNum -} - func tcpMatch(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint16, seqNum uint32, response *tcpResponse) bool { flagsCheck := (response.TCPResponse.SYN && response.TCPResponse.ACK) || response.TCPResponse.RST sourcePort := uint16(response.TCPResponse.SrcPort) @@ -256,7 +139,3 @@ func tcpMatch(localIP net.IP, localPort uint16, remoteIP net.IP, remotePort uint seqNum == response.TCPResponse.Ack-1 && flagsCheck } - -func (c canceledError) Error() string { - return string(c) -} diff --git a/pkg/networkpath/traceroute/tcp/utils_test.go b/pkg/networkpath/traceroute/tcp/utils_test.go index d79cda12ac5da..5b344df71ec16 100644 --- a/pkg/networkpath/traceroute/tcp/utils_test.go +++ b/pkg/networkpath/traceroute/tcp/utils_test.go @@ -10,15 +10,14 @@ package tcp import ( "fmt" "net" - "reflect" "runtime" "testing" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" + + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/testutils" ) var ( @@ -113,102 +112,12 @@ func Test_createRawTCPSynBuffer(t *testing.T) { assert.Equal(t, expectedPktBytes, pktBytes) } -func Test_parseICMP(t *testing.T) { - ipv4Header := createMockIPv4Header(srcIP, dstIP, 1) - icmpLayer := createMockICMPLayer(layers.ICMPv4CodeTTLExceeded) - innerIPv4Layer := createMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP) - innerTCPLayer := createMockTCPLayer(12345, 443, 28394, 12737, true, true, true) - - tt := []struct { - description string - inHeader *ipv4.Header - inPayload []byte - expected *icmpResponse - errMsg string - }{ - { - description: "empty IPv4 layer should return an error", - inHeader: &ipv4.Header{}, - inPayload: []byte{}, - expected: nil, - errMsg: "invalid IP header for ICMP packet", - }, - { - description: "missing ICMP layer should return an error", - inHeader: ipv4Header, - inPayload: []byte{}, - expected: nil, - errMsg: "failed to decode ICMP packet", - }, - { - description: "missing inner layers should return an error", - inHeader: ipv4Header, - inPayload: createMockICMPPacket(nil, icmpLayer, nil, nil, false), - expected: nil, - errMsg: "failed to decode inner ICMP payload", - }, - { - description: "ICMP packet with partial TCP header should create icmpResponse", - inHeader: ipv4Header, - inPayload: createMockICMPPacket(nil, icmpLayer, innerIPv4Layer, innerTCPLayer, true), - expected: &icmpResponse{ - SrcIP: srcIP, - DstIP: dstIP, - InnerSrcIP: innerSrcIP, - InnerDstIP: innerDstIP, - InnerSrcPort: 12345, - InnerDstPort: 443, - InnerSeqNum: 28394, - }, - errMsg: "", - }, - { - description: "full ICMP packet should create icmpResponse", - inHeader: ipv4Header, - inPayload: createMockICMPPacket(nil, icmpLayer, innerIPv4Layer, innerTCPLayer, true), - expected: &icmpResponse{ - SrcIP: srcIP, - DstIP: dstIP, - InnerSrcIP: innerSrcIP, - InnerDstIP: innerDstIP, - InnerSrcPort: 12345, - InnerDstPort: 443, - InnerSeqNum: 28394, - }, - errMsg: "", - }, - } - - for _, test := range tt { - t.Run(test.description, func(t *testing.T) { - actual, err := parseICMP(test.inHeader, test.inPayload) - if test.errMsg != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), test.errMsg) - assert.Nil(t, actual) - return - } - require.Nil(t, err) - require.NotNil(t, actual) - // assert.Equal doesn't handle net.IP well - assert.Equal(t, structFieldCount(test.expected), structFieldCount(actual)) - assert.Truef(t, test.expected.SrcIP.Equal(actual.SrcIP), "mismatch source IPs: expected %s, got %s", test.expected.SrcIP.String(), actual.SrcIP.String()) - assert.Truef(t, test.expected.DstIP.Equal(actual.DstIP), "mismatch dest IPs: expected %s, got %s", test.expected.DstIP.String(), actual.DstIP.String()) - assert.Truef(t, test.expected.InnerSrcIP.Equal(actual.InnerSrcIP), "mismatch inner source IPs: expected %s, got %s", test.expected.InnerSrcIP.String(), actual.InnerSrcIP.String()) - assert.Truef(t, test.expected.InnerDstIP.Equal(actual.InnerDstIP), "mismatch inner dest IPs: expected %s, got %s", test.expected.InnerDstIP.String(), actual.InnerDstIP.String()) - assert.Equal(t, test.expected.InnerSrcPort, actual.InnerSrcPort) - assert.Equal(t, test.expected.InnerDstPort, actual.InnerDstPort) - assert.Equal(t, test.expected.InnerSeqNum, actual.InnerSeqNum) - }) - } -} - func Test_parseTCP(t *testing.T) { - ipv4Header := createMockIPv4Header(srcIP, dstIP, 6) // 6 is TCP - tcpLayer := createMockTCPLayer(12345, 443, 28394, 12737, true, true, true) + ipv4Header := testutils.CreateMockIPv4Header(srcIP, dstIP, 6) // 6 is TCP + tcpLayer := testutils.CreateMockTCPLayer(12345, 443, 28394, 12737, true, true, true) // full packet - encodedTCPLayer, fullTCPPacket := createMockTCPPacket(ipv4Header, tcpLayer, false) + encodedTCPLayer, fullTCPPacket := testutils.CreateMockTCPPacket(ipv4Header, tcpLayer, false) tt := []struct { description string @@ -257,7 +166,7 @@ func Test_parseTCP(t *testing.T) { require.Nil(t, err) require.NotNil(t, actual) // assert.Equal doesn't handle net.IP well - assert.Equal(t, structFieldCount(test.expected), structFieldCount(actual)) + assert.Equal(t, testutils.StructFieldCount(test.expected), testutils.StructFieldCount(actual)) assert.Truef(t, test.expected.SrcIP.Equal(actual.SrcIP), "mismatch source IPs: expected %s, got %s", test.expected.SrcIP.String(), actual.SrcIP.String()) assert.Truef(t, test.expected.DstIP.Equal(actual.DstIP), "mismatch dest IPs: expected %s, got %s", test.expected.DstIP.String(), actual.DstIP.String()) assert.Equal(t, test.expected.TCPResponse, actual.TCPResponse) @@ -266,11 +175,11 @@ func Test_parseTCP(t *testing.T) { } func BenchmarkParseTCP(b *testing.B) { - ipv4Header := createMockIPv4Header(srcIP, dstIP, 6) // 6 is TCP - tcpLayer := createMockTCPLayer(12345, 443, 28394, 12737, true, true, true) + ipv4Header := testutils.CreateMockIPv4Header(srcIP, dstIP, 6) // 6 is TCP + tcpLayer := testutils.CreateMockTCPLayer(12345, 443, 28394, 12737, true, true, true) // full packet - _, fullTCPPacket := createMockTCPPacket(ipv4Header, tcpLayer, false) + _, fullTCPPacket := testutils.CreateMockTCPPacket(ipv4Header, tcpLayer, false) tp := newTCPParser() @@ -282,129 +191,3 @@ func BenchmarkParseTCP(b *testing.B) { } } } - -func createMockIPv4Header(srcIP, dstIP net.IP, protocol int) *ipv4.Header { - return &ipv4.Header{ - Version: 4, - Src: srcIP, - Dst: dstIP, - Protocol: protocol, - TTL: 64, - Len: 8, - } -} - -func createMockICMPPacket(ipLayer *layers.IPv4, icmpLayer *layers.ICMPv4, innerIP *layers.IPv4, innerTCP *layers.TCP, partialTCPHeader bool) []byte { - innerBuf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - - innerLayers := make([]gopacket.SerializableLayer, 0, 2) - if innerIP != nil { - innerLayers = append(innerLayers, innerIP) - } - if innerTCP != nil { - innerLayers = append(innerLayers, innerTCP) - if innerIP != nil { - innerTCP.SetNetworkLayerForChecksum(innerIP) - } - } - - gopacket.SerializeLayers(innerBuf, opts, - innerLayers..., - ) - payload := innerBuf.Bytes() - - // if partialTCP is set, truncate - // the payload to include only the - // first 8 bytes of the TCP header - if partialTCPHeader { - payload = payload[:32] - } - - buf := gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buf, opts, - icmpLayer, - gopacket.Payload(payload), - ) - - icmpBytes := buf.Bytes() - if ipLayer == nil { - return icmpBytes - } - - buf = gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buf, opts, - ipLayer, - gopacket.Payload(icmpBytes), - ) - - return buf.Bytes() -} - -func createMockTCPPacket(ipHeader *ipv4.Header, tcpLayer *layers.TCP, includeHeader bool) (*layers.TCP, []byte) { - ipLayer := &layers.IPv4{ - Version: 4, - SrcIP: ipHeader.Src, - DstIP: ipHeader.Dst, - Protocol: layers.IPProtocol(ipHeader.Protocol), - TTL: 64, - Length: 8, - } - tcpLayer.SetNetworkLayerForChecksum(ipLayer) - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} - if includeHeader { - gopacket.SerializeLayers(buf, opts, - ipLayer, - tcpLayer, - ) - } else { - gopacket.SerializeLayers(buf, opts, - tcpLayer, - ) - } - - pkt := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeTCP, gopacket.Default) - - // return encoded TCP layer here - return pkt.Layer(layers.LayerTypeTCP).(*layers.TCP), buf.Bytes() -} - -func createMockIPv4Layer(srcIP, dstIP net.IP, protocol layers.IPProtocol) *layers.IPv4 { - return &layers.IPv4{ - SrcIP: srcIP, - DstIP: dstIP, - Version: 4, - Protocol: protocol, - } -} - -func createMockICMPLayer(typeCode layers.ICMPv4TypeCode) *layers.ICMPv4 { - return &layers.ICMPv4{ - TypeCode: typeCode, - } -} - -func createMockTCPLayer(srcPort uint16, dstPort uint16, seqNum uint32, ackNum uint32, syn bool, ack bool, rst bool) *layers.TCP { - return &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - Seq: seqNum, - Ack: ackNum, - SYN: syn, - ACK: ack, - RST: rst, - } -} - -func structFieldCount(v interface{}) int { - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { - return -1 - } - - return val.NumField() -} diff --git a/pkg/networkpath/traceroute/tcp/utils_unix.go b/pkg/networkpath/traceroute/tcp/utils_unix.go index 2a52e5f8bea88..4fd1b2e4b251d 100644 --- a/pkg/networkpath/traceroute/tcp/utils_unix.go +++ b/pkg/networkpath/traceroute/tcp/utils_unix.go @@ -14,6 +14,7 @@ import ( "sync" "time" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/common" "github.com/DataDog/datadog-agent/pkg/util/log" "github.com/google/gopacket/layers" "go.uber.org/multierr" @@ -68,8 +69,8 @@ func listenPackets(icmpConn rawConnWrapper, tcpConn rawConnWrapper, timeout time wg.Wait() if tcpErr != nil && icmpErr != nil { - _, tcpCanceled := tcpErr.(canceledError) - _, icmpCanceled := icmpErr.(canceledError) + _, tcpCanceled := tcpErr.(common.CanceledError) + _, icmpCanceled := icmpErr.(common.CanceledError) if icmpCanceled && tcpCanceled { log.Trace("timed out waiting for responses") return net.IP{}, 0, 0, time.Time{}, nil @@ -103,7 +104,7 @@ func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, lo for { select { case <-ctx.Done(): - return net.IP{}, 0, 0, time.Time{}, canceledError("listener canceled") + return net.IP{}, 0, 0, time.Time{}, common.CanceledError("listener canceled") default: } now := time.Now() @@ -127,12 +128,12 @@ func handlePackets(ctx context.Context, conn rawConnWrapper, listener string, lo // TODO: remove listener constraint and parse all packets // in the same function return a succinct struct here if listener == "icmp" { - icmpResponse, err := parseICMP(header, packet) + icmpResponse, err := common.ParseICMP(header, packet) if err != nil { log.Tracef("failed to parse ICMP packet: %s", err) continue } - if icmpMatch(localIP, localPort, remoteIP, remotePort, seqNum, icmpResponse) { + if common.ICMPMatch(localIP, localPort, remoteIP, remotePort, seqNum, icmpResponse) { return icmpResponse.SrcIP, 0, icmpResponse.TypeCode, received, nil } } else if listener == "tcp" { diff --git a/pkg/networkpath/traceroute/tcp/utils_unix_test.go b/pkg/networkpath/traceroute/tcp/utils_unix_test.go index 731f5affe1380..db78310723d28 100644 --- a/pkg/networkpath/traceroute/tcp/utils_unix_test.go +++ b/pkg/networkpath/traceroute/tcp/utils_unix_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" + + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/testutils" ) type ( @@ -40,7 +42,7 @@ type ( ) func Test_handlePackets(t *testing.T) { - _, tcpBytes := createMockTCPPacket(createMockIPv4Header(dstIP, srcIP, 6), createMockTCPLayer(443, 12345, 28394, 28395, true, true, true), false) + _, tcpBytes := testutils.CreateMockTCPPacket(testutils.CreateMockIPv4Header(dstIP, srcIP, 6), testutils.CreateMockTCPLayer(443, 12345, 28394, 28395, true, true, true), false) tt := []struct { description string @@ -120,8 +122,8 @@ func Test_handlePackets(t *testing.T) { description: "successful ICMP parsing returns IP, port, and type code", ctxTimeout: 500 * time.Millisecond, conn: &mockRawConn{ - header: createMockIPv4Header(srcIP, dstIP, 1), - payload: createMockICMPPacket(nil, createMockICMPLayer(layers.ICMPv4CodeTTLExceeded), createMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP), createMockTCPLayer(12345, 443, 28394, 12737, true, true, true), false), + header: testutils.CreateMockIPv4Header(srcIP, dstIP, 1), + payload: testutils.CreateMockICMPPacket(nil, testutils.CreateMockICMPLayer(layers.ICMPv4CodeTTLExceeded), testutils.CreateMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP), testutils.CreateMockTCPLayer(12345, 443, 28394, 12737, true, true, true), false), }, localIP: innerSrcIP, localPort: 12345, @@ -137,7 +139,7 @@ func Test_handlePackets(t *testing.T) { description: "successful TCP parsing returns IP, port, and type code", ctxTimeout: 500 * time.Millisecond, conn: &mockRawConn{ - header: createMockIPv4Header(dstIP, srcIP, 6), + header: testutils.CreateMockIPv4Header(dstIP, srcIP, 6), payload: tcpBytes, }, localIP: srcIP, diff --git a/pkg/networkpath/traceroute/tcp/utils_windows.go b/pkg/networkpath/traceroute/tcp/utils_windows.go index 077495f43203e..483167109b6e4 100644 --- a/pkg/networkpath/traceroute/tcp/utils_windows.go +++ b/pkg/networkpath/traceroute/tcp/utils_windows.go @@ -16,6 +16,7 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/sys/windows" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/common" "github.com/DataDog/datadog-agent/pkg/util/log" "github.com/google/gopacket/layers" ) @@ -49,7 +50,7 @@ func (w *winrawsocket) listenPackets(timeout time.Duration, localIP net.IP, loca wg.Wait() if icmpErr != nil { - _, icmpCanceled := icmpErr.(canceledError) + _, icmpCanceled := icmpErr.(common.CanceledError) if icmpCanceled { log.Trace("timed out waiting for responses") return net.IP{}, 0, 0, time.Time{}, nil @@ -74,7 +75,7 @@ func (w *winrawsocket) handlePackets(ctx context.Context, localIP net.IP, localP for { select { case <-ctx.Done(): - return net.IP{}, 0, 0, time.Time{}, canceledError("listener canceled") + return net.IP{}, 0, 0, time.Time{}, common.CanceledError("listener canceled") default: } @@ -110,12 +111,12 @@ func (w *winrawsocket) handlePackets(ctx context.Context, localIP net.IP, localP // TODO: remove listener constraint and parse all packets // in the same function return a succinct struct here if header.Protocol == windows.IPPROTO_ICMP { - icmpResponse, err := parseICMP(header, packet) + icmpResponse, err := common.ParseICMP(header, packet) if err != nil { log.Tracef("failed to parse ICMP packet: %s", err.Error()) continue } - if icmpMatch(localIP, localPort, remoteIP, remotePort, seqNum, icmpResponse) { + if common.ICMPMatch(localIP, localPort, remoteIP, remotePort, seqNum, icmpResponse) { return icmpResponse.SrcIP, 0, icmpResponse.TypeCode, received, nil } } else if header.Protocol == windows.IPPROTO_TCP { diff --git a/pkg/networkpath/traceroute/tcp/utils_windows_test.go b/pkg/networkpath/traceroute/tcp/utils_windows_test.go index 6e5b2a1c81ba4..6fbd7d0cc860b 100644 --- a/pkg/networkpath/traceroute/tcp/utils_windows_test.go +++ b/pkg/networkpath/traceroute/tcp/utils_windows_test.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/windows" + "github.com/DataDog/datadog-agent/pkg/networkpath/traceroute/testutils" "github.com/google/gopacket/layers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,7 +35,7 @@ type ( ) func Test_handlePackets(t *testing.T) { - _, tcpBytes := createMockTCPPacket(createMockIPv4Header(dstIP, srcIP, 6), createMockTCPLayer(443, 12345, 28394, 28395, true, true, true), true) + _, tcpBytes := testutils.CreateMockTCPPacket(testutils.CreateMockIPv4Header(dstIP, srcIP, 6), testutils.CreateMockTCPLayer(443, 12345, 28394, 28395, true, true, true), true) tt := []struct { description string @@ -92,7 +93,7 @@ func Test_handlePackets(t *testing.T) { description: "successful ICMP parsing returns IP, port, and type code", ctxTimeout: 500 * time.Millisecond, conn: &mockRawConn{ - payload: createMockICMPPacket(createMockIPv4Layer(srcIP, dstIP, layers.IPProtocolICMPv4), createMockICMPLayer(layers.ICMPv4CodeTTLExceeded), createMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP), createMockTCPLayer(12345, 443, 28394, 12737, true, true, true), false), + payload: testutils.CreateMockICMPPacket(testutils.CreateMockIPv4Layer(srcIP, dstIP, layers.IPProtocolICMPv4), testutils.CreateMockICMPLayer(layers.ICMPv4CodeTTLExceeded), testutils.CreateMockIPv4Layer(innerSrcIP, innerDstIP, layers.IPProtocolTCP), testutils.CreateMockTCPLayer(12345, 443, 28394, 12737, true, true, true), false), }, localIP: innerSrcIP, localPort: 12345, diff --git a/pkg/networkpath/traceroute/testutils/doc.go b/pkg/networkpath/traceroute/testutils/doc.go new file mode 100644 index 0000000000000..2a31e324c585c --- /dev/null +++ b/pkg/networkpath/traceroute/testutils/doc.go @@ -0,0 +1,7 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +// Package testutils contains utilities for testing traceroute code +package testutils diff --git a/pkg/networkpath/traceroute/testutils/testutils.go b/pkg/networkpath/traceroute/testutils/testutils.go new file mode 100644 index 0000000000000..e412d8971372b --- /dev/null +++ b/pkg/networkpath/traceroute/testutils/testutils.go @@ -0,0 +1,150 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2016-present Datadog, Inc. + +//go:build test + +package testutils + +import ( + "net" + "reflect" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv4" +) + +// CreateMockIPv4Header creates a mock IPv4 header for testing +func CreateMockIPv4Header(srcIP, dstIP net.IP, protocol int) *ipv4.Header { + return &ipv4.Header{ + Version: 4, + Src: srcIP, + Dst: dstIP, + Protocol: protocol, + TTL: 64, + Len: 8, + } +} + +// CreateMockICMPPacket creates a mock ICMP packet for testing +func CreateMockICMPPacket(ipLayer *layers.IPv4, icmpLayer *layers.ICMPv4, innerIP *layers.IPv4, innerTCP *layers.TCP, partialTCPHeader bool) []byte { + innerBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + + innerLayers := make([]gopacket.SerializableLayer, 0, 2) + if innerIP != nil { + innerLayers = append(innerLayers, innerIP) + } + if innerTCP != nil { + innerLayers = append(innerLayers, innerTCP) + if innerIP != nil { + innerTCP.SetNetworkLayerForChecksum(innerIP) // nolint: errcheck + } + } + + gopacket.SerializeLayers(innerBuf, opts, // nolint: errcheck + innerLayers..., + ) + payload := innerBuf.Bytes() + + // if partialTCP is set, truncate + // the payload to include only the + // first 8 bytes of the TCP header + if partialTCPHeader { + payload = payload[:32] + } + + buf := gopacket.NewSerializeBuffer() + gopacket.SerializeLayers(buf, opts, // nolint: errcheck + icmpLayer, + gopacket.Payload(payload), + ) + + icmpBytes := buf.Bytes() + if ipLayer == nil { + return icmpBytes + } + + buf = gopacket.NewSerializeBuffer() + gopacket.SerializeLayers(buf, opts, // nolint: errcheck + ipLayer, + gopacket.Payload(icmpBytes), + ) + + return buf.Bytes() +} + +// CreateMockTCPPacket creates a mock TCP packet for testing +func CreateMockTCPPacket(ipHeader *ipv4.Header, tcpLayer *layers.TCP, includeHeader bool) (*layers.TCP, []byte) { + ipLayer := &layers.IPv4{ + Version: 4, + SrcIP: ipHeader.Src, + DstIP: ipHeader.Dst, + Protocol: layers.IPProtocol(ipHeader.Protocol), + TTL: 64, + Length: 8, + } + tcpLayer.SetNetworkLayerForChecksum(ipLayer) // nolint: errcheck + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if includeHeader { + gopacket.SerializeLayers(buf, opts, // nolint: errcheck + ipLayer, + tcpLayer, + ) + } else { + gopacket.SerializeLayers(buf, opts, // nolint: errcheck + tcpLayer, + ) + } + + pkt := gopacket.NewPacket(buf.Bytes(), layers.LayerTypeTCP, gopacket.Default) + + // return encoded TCP layer here + return pkt.Layer(layers.LayerTypeTCP).(*layers.TCP), buf.Bytes() +} + +// CreateMockIPv4Layer creates a mock IPv4 layer for testing +func CreateMockIPv4Layer(srcIP, dstIP net.IP, protocol layers.IPProtocol) *layers.IPv4 { + return &layers.IPv4{ + SrcIP: srcIP, + DstIP: dstIP, + Version: 4, + Protocol: protocol, + } +} + +// CreateMockICMPLayer creates a mock ICMP layer for testing +func CreateMockICMPLayer(typeCode layers.ICMPv4TypeCode) *layers.ICMPv4 { + return &layers.ICMPv4{ + TypeCode: typeCode, + } +} + +// CreateMockTCPLayer creates a mock TCP layer for testing +func CreateMockTCPLayer(srcPort uint16, dstPort uint16, seqNum uint32, ackNum uint32, syn bool, ack bool, rst bool) *layers.TCP { + return &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Seq: seqNum, + Ack: ackNum, + SYN: syn, + ACK: ack, + RST: rst, + } +} + +// StructFieldCount returns the number of fields in a struct +func StructFieldCount(v interface{}) int { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if val.Kind() != reflect.Struct { + return -1 + } + + return val.NumField() +}