aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Documentation/netlink/specs/psp.yaml187
-rw-r--r--Documentation/networking/index.rst1
-rw-r--r--Documentation/networking/psp.rst183
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/Kconfig11
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/Makefile2
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en.h6
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en/fs.h2
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en/params.c4
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/en_accel.h50
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_rxtx.h2
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.c952
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.h61
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.c200
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.h121
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_main.c9
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_rx.c49
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/en_tx.c10
-rw-r--r--drivers/net/ethernet/mellanox/mlx5/core/lib/crypto.h1
-rw-r--r--include/linux/netdevice.h4
-rw-r--r--include/linux/skbuff.h3
-rw-r--r--include/net/dropreason-core.h6
-rw-r--r--include/net/inet_timewait_sock.h8
-rw-r--r--include/net/psp.h12
-rw-r--r--include/net/psp/functions.h210
-rw-r--r--include/net/psp/types.h184
-rw-r--r--include/net/sock.h26
-rw-r--r--include/uapi/linux/psp.h66
-rw-r--r--net/Kconfig1
-rw-r--r--net/Makefile1
-rw-r--r--net/core/dev.c32
-rw-r--r--net/core/gro.c2
-rw-r--r--net/core/skbuff.c4
-rw-r--r--net/ipv4/af_inet.c2
-rw-r--r--net/ipv4/inet_timewait_sock.c5
-rw-r--r--net/ipv4/ip_output.c5
-rw-r--r--net/ipv4/tcp.c2
-rw-r--r--net/ipv4/tcp_ipv4.c18
-rw-r--r--net/ipv4/tcp_minisocks.c20
-rw-r--r--net/ipv4/tcp_output.c17
-rw-r--r--net/ipv6/ipv6_sockglue.c6
-rw-r--r--net/ipv6/tcp_ipv6.c17
-rw-r--r--net/psp/Kconfig15
-rw-r--r--net/psp/Makefile5
-rw-r--r--net/psp/psp-nl-gen.c119
-rw-r--r--net/psp/psp-nl-gen.h39
-rw-r--r--net/psp/psp.h54
-rw-r--r--net/psp/psp_main.c321
-rw-r--r--net/psp/psp_nl.c505
-rw-r--r--net/psp/psp_sock.c295
-rw-r--r--tools/net/ynl/Makefile.deps1
50 files changed, 3801 insertions, 55 deletions
diff --git a/Documentation/netlink/specs/psp.yaml b/Documentation/netlink/specs/psp.yaml
new file mode 100644
index 000000000000..944429e5c9a8
--- /dev/null
+++ b/Documentation/netlink/specs/psp.yaml
@@ -0,0 +1,187 @@
+# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
+---
+name: psp
+
+doc:
+ PSP Security Protocol Generic Netlink family.
+
+definitions:
+ -
+ type: enum
+ name: version
+ entries: [hdr0-aes-gcm-128, hdr0-aes-gcm-256,
+ hdr0-aes-gmac-128, hdr0-aes-gmac-256]
+
+attribute-sets:
+ -
+ name: dev
+ attributes:
+ -
+ name: id
+ doc: PSP device ID.
+ type: u32
+ checks:
+ min: 1
+ -
+ name: ifindex
+ doc: ifindex of the main netdevice linked to the PSP device.
+ type: u32
+ -
+ name: psp-versions-cap
+ doc: Bitmask of PSP versions supported by the device.
+ type: u32
+ enum: version
+ enum-as-flags: true
+ -
+ name: psp-versions-ena
+ doc: Bitmask of currently enabled (accepted on Rx) PSP versions.
+ type: u32
+ enum: version
+ enum-as-flags: true
+ -
+ name: assoc
+ attributes:
+ -
+ name: dev-id
+ doc: PSP device ID.
+ type: u32
+ checks:
+ min: 1
+ -
+ name: version
+ doc: |
+ PSP versions (AEAD and protocol version) used by this association,
+ dictates the size of the key.
+ type: u32
+ enum: version
+ -
+ name: rx-key
+ type: nest
+ nested-attributes: keys
+ -
+ name: tx-key
+ type: nest
+ nested-attributes: keys
+ -
+ name: sock-fd
+ doc: Sockets which should be bound to the association immediately.
+ type: u32
+ -
+ name: keys
+ attributes:
+ -
+ name: key
+ type: binary
+ -
+ name: spi
+ doc: Security Parameters Index (SPI) of the association.
+ type: u32
+
+operations:
+ list:
+ -
+ name: dev-get
+ doc: Get / dump information about PSP capable devices on the system.
+ attribute-set: dev
+ do:
+ request:
+ attributes:
+ - id
+ reply: &dev-all
+ attributes:
+ - id
+ - ifindex
+ - psp-versions-cap
+ - psp-versions-ena
+ pre: psp-device-get-locked
+ post: psp-device-unlock
+ dump:
+ reply: *dev-all
+ -
+ name: dev-add-ntf
+ doc: Notification about device appearing.
+ notify: dev-get
+ mcgrp: mgmt
+ -
+ name: dev-del-ntf
+ doc: Notification about device disappearing.
+ notify: dev-get
+ mcgrp: mgmt
+ -
+ name: dev-set
+ doc: Set the configuration of a PSP device.
+ attribute-set: dev
+ do:
+ request:
+ attributes:
+ - id
+ - psp-versions-ena
+ reply:
+ attributes: []
+ pre: psp-device-get-locked
+ post: psp-device-unlock
+ -
+ name: dev-change-ntf
+ doc: Notification about device configuration being changed.
+ notify: dev-get
+ mcgrp: mgmt
+
+ -
+ name: key-rotate
+ doc: Rotate the device key.
+ attribute-set: dev
+ do:
+ request:
+ attributes:
+ - id
+ reply:
+ attributes:
+ - id
+ pre: psp-device-get-locked
+ post: psp-device-unlock
+ -
+ name: key-rotate-ntf
+ doc: Notification about device key getting rotated.
+ notify: key-rotate
+ mcgrp: use
+
+ -
+ name: rx-assoc
+ doc: Allocate a new Rx key + SPI pair, associate it with a socket.
+ attribute-set: assoc
+ do:
+ request:
+ attributes:
+ - dev-id
+ - version
+ - sock-fd
+ reply:
+ attributes:
+ - dev-id
+ - rx-key
+ pre: psp-assoc-device-get-locked
+ post: psp-device-unlock
+ -
+ name: tx-assoc
+ doc: Add a PSP Tx association.
+ attribute-set: assoc
+ do:
+ request:
+ attributes:
+ - dev-id
+ - version
+ - tx-key
+ - sock-fd
+ reply:
+ attributes: []
+ pre: psp-assoc-device-get-locked
+ post: psp-device-unlock
+
+mcast-groups:
+ list:
+ -
+ name: mgmt
+ -
+ name: use
+
+...
diff --git a/Documentation/networking/index.rst b/Documentation/networking/index.rst
index b7a4969e9bc9..c775cababc8c 100644
--- a/Documentation/networking/index.rst
+++ b/Documentation/networking/index.rst
@@ -101,6 +101,7 @@ Contents:
ppp_generic
proc_net_tcp
pse-pd/index
+ psp
radiotap-headers
rds
regulatory
diff --git a/Documentation/networking/psp.rst b/Documentation/networking/psp.rst
new file mode 100644
index 000000000000..4ac09e64e95a
--- /dev/null
+++ b/Documentation/networking/psp.rst
@@ -0,0 +1,183 @@
+.. SPDX-License-Identifier: GPL-2.0-only
+
+=====================
+PSP Security Protocol
+=====================
+
+Protocol
+========
+
+PSP Security Protocol (PSP) was defined at Google and published in:
+
+https://raw.githubusercontent.com/google/psp/main/doc/PSP_Arch_Spec.pdf
+
+This section briefly covers protocol aspects crucial for understanding
+the kernel API. Refer to the protocol specification for further details.
+
+Note that the kernel implementation and documentation uses the term
+"device key" in place of "master key", it is both less confusing
+to an average developer and is less likely to run afoul any naming
+guidelines.
+
+Derived Rx keys
+---------------
+
+PSP borrows some terms and mechanisms from IPsec. PSP was designed
+with HW offloads in mind. The key feature of PSP is that Rx keys for every
+connection do not have to be stored by the receiver but can be derived
+from device key and information present in packet headers.
+This makes it possible to implement receivers which require a constant
+amount of memory regardless of the number of connections (``O(1)`` scaling).
+
+Tx keys have to be stored like with any other protocol, but Tx is much
+less latency sensitive than Rx, and delays in fetching keys from slow
+memory is less likely to cause packet drops. Preferably, the Tx keys
+should be provided with the packet (e.g. as part of the descriptors).
+
+Key rotation
+------------
+
+The device key known only to the receiver is fundamental to the design.
+Per specification this state cannot be directly accessible (it must be
+impossible to read it out of the hardware of the receiver NIC).
+Moreover, it has to be "rotated" periodically (usually daily). Rotation
+means that new device key gets generated (by a random number generator
+of the device), and used for all new connections. To avoid disrupting
+old connections the old device key remains in the NIC. A phase bit
+carried in the packet headers indicates which generation of device key
+the packet has been encrypted with.
+
+User facing API
+===============
+
+PSP is designed primarily for hardware offloads. There is currently
+no software fallback for systems which do not have PSP capable NICs.
+There is also no standard (or otherwise defined) way of establishing
+a PSP-secured connection or exchanging the symmetric keys.
+
+The expectation is that higher layer protocols will take care of
+protocol and key negotiation. For example one may use TLS key exchange,
+announce the PSP capability, and switch to PSP if both endpoints
+are PSP-capable.
+
+All configuration of PSP is performed via the PSP netlink family.
+
+Device discovery
+----------------
+
+The PSP netlink family defines operations to retrieve information
+about the PSP devices available on the system, configure them and
+access PSP related statistics.
+
+Securing a connection
+---------------------
+
+PSP encryption is currently only supported for TCP connections.
+Rx and Tx keys are allocated separately. First the ``rx-assoc``
+Netlink command needs to be issued, specifying a target TCP socket.
+Kernel will allocate a new PSP Rx key from the NIC and associate it
+with given socket. At this stage socket will accept both PSP-secured
+and plain text TCP packets.
+
+Tx keys are installed using the ``tx-assoc`` Netlink command.
+Once the Tx keys are installed, all data read from the socket will
+be PSP-secured. In other words act of installing Tx keys has a secondary
+effect on the Rx direction.
+
+There is an intermediate period after ``tx-assoc`` successfully
+returns and before the TCP socket encounters it's first PSP
+authenticated packet, where the TCP stack will allow certain nondata
+packets, i.e. ACKs, FINs, and RSTs, to enter TCP receive processing
+even if not PSP authenticated. During the ``tx-assoc`` call, the TCP
+socket's ``rcv_nxt`` field is recorded. At this point, ACKs and RSTs
+will be accepted with any sequence number, while FINs will only be
+accepted at the latched value of ``rcv_nxt``. Once the TCP stack
+encounters the first TCP packet containing PSP authenticated data, the
+other end of the connection must have executed the ``tx-assoc``
+command, so any TCP packet, including those without data, will be
+dropped before receive processing if it is not successfully
+authenticated. This is summarized in the table below. The
+aforementioned state of rejecting all non-PSP packets is labeled "PSP
+Full".
+
++----------------+------------+------------+-------------+-------------+
+| Event | Normal TCP | Rx PSP | Tx PSP | PSP Full |
++================+============+============+=============+=============+
+| Rx plain | accept | accept | drop | drop |
+| (data) | | | | |
++----------------+------------+------------+-------------+-------------+
+| Rx plain | accept | accept | accept | drop |
+| (ACK|FIN|RST) | | | | |
++----------------+------------+------------+-------------+-------------+
+| Rx PSP (good) | drop | accept | accept | accept |
++----------------+------------+------------+-------------+-------------+
+| Rx PSP (bad | drop | drop | drop | drop |
+| crypt, !=SPI) | | | | |
++----------------+------------+------------+-------------+-------------+
+| Tx | plain text | plain text | encrypted | encrypted |
+| | | | (excl. rtx) | (excl. rtx) |
++----------------+------------+------------+-------------+-------------+
+
+To ensure that any data read from the socket after the ``tx-assoc``
+call returns success has been authenticated, the kernel will scan the
+receive and ofo queues of the socket at ``tx-assoc`` time. If any
+enqueued packet was received in clear text, the Tx association will
+fail, and the application should retry installing the Tx key after
+draining the socket (this should not be necessary if both endpoints
+are well behaved).
+
+Because TCP sequence numbers are not integrity protected prior to
+upgrading to PSP, it is possible that a MITM could offset sequence
+numbers in a way that deletes a prefix of the PSP protected part of
+the TCP stream. If userspace cares to mitigate this type of attack, a
+special "start of PSP" message should be exchanged after ``tx-assoc``.
+
+Rotation notifications
+----------------------
+
+The rotations of device key happen asynchronously and are usually
+performed by management daemons, not under application control.
+The PSP netlink family will generate a notification whenever keys
+are rotated. The applications are expected to re-establish connections
+before keys are rotated again.
+
+Kernel implementation
+=====================
+
+Driver notes
+------------
+
+Drivers are expected to start with no PSP enabled (``psp-versions-ena``
+in ``dev-get`` set to ``0``) whenever possible. The user space should
+not depend on this behavior, as future extension may necessitate creation
+of devices with PSP already enabled, nonetheless drivers should not enable
+PSP by default. Enabling PSP should be the responsibility of the system
+component which also takes care of key rotation.
+
+Note that ``psp-versions-ena`` is expected to be used only for enabling
+receive processing. The device is not expected to reject transmit requests
+after ``psp-versions-ena`` has been disabled. User may also disable
+``psp-versions-ena`` while there are active associations, which will
+break all PSP Rx processing.
+
+Drivers are expected to ensure that a device key is usable and secure
+upon init, without explicit key rotation by the user space. It must be
+possible to allocate working keys, and that no duplicate keys must be
+generated. If the device allows the host to request the key for an
+arbitrary SPI - driver should discard both device keys (rotate the
+device key twice), to avoid potentially using a SPI+key which previous
+OS instance already had access to.
+
+Drivers must use ``psp_skb_get_assoc_rcu()`` to check if PSP Tx offload
+was requested for given skb. On Rx drivers should allocate and populate
+the ``SKB_EXT_PSP`` skb extension, and set the skb->decrypted bit to 1.
+
+Kernel implementation notes
+---------------------------
+
+PSP implementation follows the TLS offload more closely than the IPsec
+offload, with per-socket state, and the use of skb->decrypted to prevent
+clear text leaks.
+
+PSP device is separate from netdev, to make it possible to "delegate"
+PSP offload capabilities to software devices (e.g. ``veth``).
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/Kconfig b/drivers/net/ethernet/mellanox/mlx5/core/Kconfig
index 8ef2ac2060ba..3c3e84100d5a 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/Kconfig
+++ b/drivers/net/ethernet/mellanox/mlx5/core/Kconfig
@@ -207,3 +207,14 @@ config MLX5_DPLL
help
DPLL support in Mellanox Technologies ConnectX NICs.
+config MLX5_EN_PSP
+ bool "Mellanox Technologies support for PSP cryptography-offload acceleration"
+ depends on INET_PSP
+ depends on MLX5_CORE_EN
+ default y
+ help
+ mlx5 device offload support for Google PSP Security Protocol offload.
+ Adds support for PSP encryption offload and for SPI and key generation
+ interfaces to PSP Stack which supports PSP crypto offload.
+
+ If unsure, say Y.
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/Makefile b/drivers/net/ethernet/mellanox/mlx5/core/Makefile
index d77696f46eb5..8ffa286a18f5 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/Makefile
+++ b/drivers/net/ethernet/mellanox/mlx5/core/Makefile
@@ -112,6 +112,8 @@ mlx5_core-$(CONFIG_MLX5_EN_TLS) += en_accel/ktls_stats.o \
en_accel/fs_tcp.o en_accel/ktls.o en_accel/ktls_txrx.o \
en_accel/ktls_tx.o en_accel/ktls_rx.o
+mlx5_core-$(CONFIG_MLX5_EN_PSP) += en_accel/psp.o en_accel/psp_rxtx.o
+
#
# SW Steering
#
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en.h b/drivers/net/ethernet/mellanox/mlx5/core/en.h
index f1aa2b2ce10b..4ffbc263d60f 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en.h
@@ -47,6 +47,7 @@
#include <linux/rhashtable.h>
#include <net/udp_tunnel.h>
#include <net/switchdev.h>
+#include <net/psp/types.h>
#include <net/xdp.h>
#include <linux/dim.h>
#include <linux/bits.h>
@@ -68,7 +69,7 @@ struct page_pool;
#define MLX5E_METADATA_ETHER_TYPE (0x8CE4)
#define MLX5E_METADATA_ETHER_LEN 8
-#define MLX5E_ETH_HARD_MTU (ETH_HLEN + VLAN_HLEN + ETH_FCS_LEN)
+#define MLX5E_ETH_HARD_MTU (ETH_HLEN + PSP_ENCAP_HLEN + PSP_TRL_SIZE + VLAN_HLEN + ETH_FCS_LEN)
#define MLX5E_HW2SW_MTU(params, hwmtu) ((hwmtu) - ((params)->hard_mtu))
#define MLX5E_SW2HW_MTU(params, swmtu) ((swmtu) + ((params)->hard_mtu))
@@ -938,6 +939,9 @@ struct mlx5e_priv {
#ifdef CONFIG_MLX5_EN_IPSEC
struct mlx5e_ipsec *ipsec;
#endif
+#ifdef CONFIG_MLX5_EN_PSP
+ struct mlx5e_psp *psp;
+#endif
#ifdef CONFIG_MLX5_EN_TLS
struct mlx5e_tls *tls;
#endif
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en/fs.h b/drivers/net/ethernet/mellanox/mlx5/core/en/fs.h
index 9560fcba643f..85a53e8bcbc7 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en/fs.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en/fs.h
@@ -88,7 +88,7 @@ enum {
#ifdef CONFIG_MLX5_EN_ARFS
MLX5E_ARFS_FT_LEVEL = MLX5E_INNER_TTC_FT_LEVEL + 1,
#endif
-#ifdef CONFIG_MLX5_EN_IPSEC
+#if defined(CONFIG_MLX5_EN_IPSEC) || defined(CONFIG_MLX5_EN_PSP)
MLX5E_ACCEL_FS_ESP_FT_LEVEL = MLX5E_INNER_TTC_FT_LEVEL + 1,
MLX5E_ACCEL_FS_ESP_FT_ERR_LEVEL,
MLX5E_ACCEL_FS_POL_FT_LEVEL,
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en/params.c b/drivers/net/ethernet/mellanox/mlx5/core/en/params.c
index 596440c8c364..3692298e10f2 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en/params.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en/params.c
@@ -6,6 +6,7 @@
#include "en/port.h"
#include "en_accel/en_accel.h"
#include "en_accel/ipsec.h"
+#include "en_accel/psp.h"
#include <linux/dim.h>
#include <net/page_pool/types.h>
#include <net/xdp_sock_drv.h>
@@ -1004,7 +1005,8 @@ void mlx5e_build_sq_param(struct mlx5_core_dev *mdev,
bool allow_swp;
allow_swp = mlx5_geneve_tx_allowed(mdev) ||
- (mlx5_ipsec_device_caps(mdev) & MLX5_IPSEC_CAP_CRYPTO);
+ (mlx5_ipsec_device_caps(mdev) & MLX5_IPSEC_CAP_CRYPTO) ||
+ mlx5_is_psp_device(mdev);
mlx5e_build_sq_param_common(mdev, param);
MLX5_SET(wq, wq, log_wq_sz, params->log_sq_size);
MLX5_SET(sqc, sqc, allow_swp, allow_swp);
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/en_accel.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/en_accel.h
index 33e32584b07f..8bef99e8367e 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/en_accel.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/en_accel.h
@@ -42,6 +42,8 @@
#include <en_accel/macsec.h>
#include "en.h"
#include "en/txrx.h"
+#include "en_accel/psp.h"
+#include "en_accel/psp_rxtx.h"
#if IS_ENABLED(CONFIG_GENEVE)
#include <net/geneve.h>
@@ -119,6 +121,9 @@ struct mlx5e_accel_tx_state {
#ifdef CONFIG_MLX5_EN_IPSEC
struct mlx5e_accel_tx_ipsec_state ipsec;
#endif
+#ifdef CONFIG_MLX5_EN_PSP
+ struct mlx5e_accel_tx_psp_state psp_st;
+#endif
};
static inline bool mlx5e_accel_tx_begin(struct net_device *dev,
@@ -137,6 +142,13 @@ static inline bool mlx5e_accel_tx_begin(struct net_device *dev,
return false;
#endif
+#ifdef CONFIG_MLX5_EN_PSP
+ if (mlx5e_psp_is_offload(skb, dev)) {
+ if (unlikely(!mlx5e_psp_handle_tx_skb(dev, skb, &state->psp_st)))
+ return false;
+ }
+#endif
+
#ifdef CONFIG_MLX5_EN_IPSEC
if (test_bit(MLX5E_SQ_STATE_IPSEC, &sq->state) && xfrm_offload(skb)) {
if (unlikely(!mlx5e_ipsec_handle_tx_skb(dev, skb, &state->ipsec)))
@@ -157,8 +169,14 @@ static inline bool mlx5e_accel_tx_begin(struct net_device *dev,
}
static inline unsigned int mlx5e_accel_tx_ids_len(struct mlx5e_txqsq *sq,
+ struct sk_buff *skb,
struct mlx5e_accel_tx_state *state)
{
+#ifdef CONFIG_MLX5_EN_PSP
+ if (mlx5e_psp_is_offload_state(&state->psp_st))
+ return mlx5e_psp_tx_ids_len(&state->psp_st);
+#endif
+
#ifdef CONFIG_MLX5_EN_IPSEC
if (test_bit(MLX5E_SQ_STATE_IPSEC, &sq->state))
return mlx5e_ipsec_tx_ids_len(&state->ipsec);
@@ -172,8 +190,14 @@ static inline unsigned int mlx5e_accel_tx_ids_len(struct mlx5e_txqsq *sq,
static inline void mlx5e_accel_tx_eseg(struct mlx5e_priv *priv,
struct sk_buff *skb,
+ struct mlx5e_accel_tx_state *accel,
struct mlx5_wqe_eth_seg *eseg, u16 ihs)
{
+#ifdef CONFIG_MLX5_EN_PSP
+ if (mlx5e_psp_is_offload_state(&accel->psp_st))
+ mlx5e_psp_tx_build_eseg(priv, skb, &accel->psp_st, eseg);
+#endif
+
#ifdef CONFIG_MLX5_EN_IPSEC
if (xfrm_offload(skb))
mlx5e_ipsec_tx_build_eseg(priv, skb, eseg);
@@ -199,6 +223,11 @@ static inline void mlx5e_accel_tx_finish(struct mlx5e_txqsq *sq,
mlx5e_ktls_handle_tx_wqe(&wqe->ctrl, &state->tls);
#endif
+#ifdef CONFIG_MLX5_EN_PSP
+ if (mlx5e_psp_is_offload_state(&state->psp_st))
+ mlx5e_psp_handle_tx_wqe(wqe, &state->psp_st, inlseg);
+#endif
+
#ifdef CONFIG_MLX5_EN_IPSEC
if (test_bit(MLX5E_SQ_STATE_IPSEC, &sq->state) &&
state->ipsec.xo && state->ipsec.tailen)
@@ -208,21 +237,40 @@ static inline void mlx5e_accel_tx_finish(struct mlx5e_txqsq *sq,
static inline int mlx5e_accel_init_rx(struct mlx5e_priv *priv)
{
- return mlx5e_ktls_init_rx(priv);
+ int err;
+
+ err = mlx5_accel_psp_fs_init_rx_tables(priv);
+ if (err)
+ goto out;
+
+ err = mlx5e_ktls_init_rx(priv);
+ if (err)
+ mlx5_accel_psp_fs_cleanup_rx_tables(priv);
+
+out:
+ return err;
}
static inline void mlx5e_accel_cleanup_rx(struct mlx5e_priv *priv)
{
mlx5e_ktls_cleanup_rx(priv);
+ mlx5_accel_psp_fs_cleanup_rx_tables(priv);
}
static inline int mlx5e_accel_init_tx(struct mlx5e_priv *priv)
{
+ int err;
+
+ err = mlx5_accel_psp_fs_init_tx_tables(priv);
+ if (err)
+ return err;
+
return mlx5e_ktls_init_tx(priv);
}
static inline void mlx5e_accel_cleanup_tx(struct mlx5e_priv *priv)
{
mlx5e_ktls_cleanup_tx(priv);
+ mlx5_accel_psp_fs_cleanup_tx_tables(priv);
}
#endif /* __MLX5E_EN_ACCEL_H__ */
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_rxtx.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_rxtx.h
index 3cc640669247..45b0d19e735c 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_rxtx.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_rxtx.h
@@ -40,7 +40,7 @@
#include "en/txrx.h"
/* Bit31: IPsec marker, Bit30: reserved, Bit29-24: IPsec syndrome, Bit23-0: IPsec obj id */
-#define MLX5_IPSEC_METADATA_MARKER(metadata) (((metadata) >> 31) & 0x1)
+#define MLX5_IPSEC_METADATA_MARKER(metadata) ((((metadata) >> 30) & 0x3) == 0x2)
#define MLX5_IPSEC_METADATA_SYNDROM(metadata) (((metadata) >> 24) & GENMASK(5, 0))
#define MLX5_IPSEC_METADATA_HANDLE(metadata) ((metadata) & GENMASK(23, 0))
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.c
new file mode 100644
index 000000000000..b4cb131c5f81
--- /dev/null
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.c
@@ -0,0 +1,952 @@
+// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
+/* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
+#include <linux/mlx5/device.h>
+#include <net/psp.h>
+#include <linux/psp.h>
+#include "mlx5_core.h"
+#include "psp.h"
+#include "lib/crypto.h"
+#include "en_accel/psp.h"
+#include "fs_core.h"
+
+enum accel_fs_psp_type {
+ ACCEL_FS_PSP4,
+ ACCEL_FS_PSP6,
+ ACCEL_FS_PSP_NUM_TYPES,
+};
+
+enum accel_psp_syndrome {
+ PSP_OK = 0,
+ PSP_ICV_FAIL,
+ PSP_BAD_TRAILER,
+};
+
+struct mlx5e_psp_tx {
+ struct mlx5_flow_namespace *ns;
+ struct mlx5_flow_table *ft;
+ struct mlx5_flow_group *fg;
+ struct mlx5_flow_handle *rule;
+ struct mutex mutex; /* Protect PSP TX steering */
+ u32 refcnt;
+};
+
+struct mlx5e_psp_rx_err {
+ struct mlx5_flow_table *ft;
+ struct mlx5_flow_handle *rule;
+ struct mlx5_flow_handle *drop_rule;
+ struct mlx5_modify_hdr *copy_modify_hdr;
+};
+
+struct mlx5e_accel_fs_psp_prot {
+ struct mlx5_flow_table *ft;
+ struct mlx5_flow_group *miss_group;
+ struct mlx5_flow_handle *miss_rule;
+ struct mlx5_flow_destination default_dest;
+ struct mlx5e_psp_rx_err rx_err;
+ u32 refcnt;
+ struct mutex prot_mutex; /* protect ESP4/ESP6 protocol */
+ struct mlx5_flow_handle *def_rule;
+};
+
+struct mlx5e_accel_fs_psp {
+ struct mlx5e_accel_fs_psp_prot fs_prot[ACCEL_FS_PSP_NUM_TYPES];
+};
+
+struct mlx5e_psp_fs {
+ struct mlx5_core_dev *mdev;
+ struct mlx5e_psp_tx *tx_fs;
+ /* Rx manage */
+ struct mlx5e_flow_steering *fs;
+ struct mlx5e_accel_fs_psp *rx_fs;
+};
+
+/* PSP RX flow steering */
+static enum mlx5_traffic_types fs_psp2tt(enum accel_fs_psp_type i)
+{
+ if (i == ACCEL_FS_PSP4)
+ return MLX5_TT_IPV4_UDP;
+
+ return MLX5_TT_IPV6_UDP;
+}
+
+static void accel_psp_fs_rx_err_del_rules(struct mlx5e_psp_fs *fs,
+ struct mlx5e_psp_rx_err *rx_err)
+{
+ if (rx_err->drop_rule) {
+ mlx5_del_flow_rules(rx_err->drop_rule);
+ rx_err->drop_rule = NULL;
+ }
+
+ if (rx_err->rule) {
+ mlx5_del_flow_rules(rx_err->rule);
+ rx_err->rule = NULL;
+ }
+
+ if (rx_err->copy_modify_hdr) {
+ mlx5_modify_header_dealloc(fs->mdev, rx_err->copy_modify_hdr);
+ rx_err->copy_modify_hdr = NULL;
+ }
+}
+
+static void accel_psp_fs_rx_err_destroy_ft(struct mlx5e_psp_fs *fs,
+ struct mlx5e_psp_rx_err *rx_err)
+{
+ accel_psp_fs_rx_err_del_rules(fs, rx_err);
+
+ if (rx_err->ft) {
+ mlx5_destroy_flow_table(rx_err->ft);
+ rx_err->ft = NULL;
+ }
+}
+
+static void accel_psp_setup_syndrome_match(struct mlx5_flow_spec *spec,
+ enum accel_psp_syndrome syndrome)
+{
+ void *misc_params_2;
+
+ spec->match_criteria_enable |= MLX5_MATCH_MISC_PARAMETERS_2;
+ misc_params_2 = MLX5_ADDR_OF(fte_match_param, spec->match_criteria, misc_parameters_2);
+ MLX5_SET_TO_ONES(fte_match_set_misc2, misc_params_2, psp_syndrome);
+ misc_params_2 = MLX5_ADDR_OF(fte_match_param, spec->match_value, misc_parameters_2);
+ MLX5_SET(fte_match_set_misc2, misc_params_2, psp_syndrome, syndrome);
+}
+
+static int accel_psp_fs_rx_err_add_rule(struct mlx5e_psp_fs *fs,
+ struct mlx5e_accel_fs_psp_prot *fs_prot,
+ struct mlx5e_psp_rx_err *rx_err)
+{
+ u8 action[MLX5_UN_SZ_BYTES(set_add_copy_action_in_auto)] = {};
+ struct mlx5_core_dev *mdev = fs->mdev;
+ struct mlx5_flow_act flow_act = {};
+ struct mlx5_modify_hdr *modify_hdr;
+ struct mlx5_flow_handle *fte;
+ struct mlx5_flow_spec *spec;
+ int err = 0;
+
+ spec = kzalloc(sizeof(*spec), GFP_KERNEL);
+ if (!spec)
+ return -ENOMEM;
+
+ /* Action to copy 7 bit psp_syndrome to regB[23:29] */
+ MLX5_SET(copy_action_in, action, action_type, MLX5_ACTION_TYPE_COPY);
+ MLX5_SET(copy_action_in, action, src_field, MLX5_ACTION_IN_FIELD_PSP_SYNDROME);
+ MLX5_SET(copy_action_in, action, src_offset, 0);
+ MLX5_SET(copy_action_in, action, length, 7);
+ MLX5_SET(copy_action_in, action, dst_field, MLX5_ACTION_IN_FIELD_METADATA_REG_B);
+ MLX5_SET(copy_action_in, action, dst_offset, 23);
+
+ modify_hdr = mlx5_modify_header_alloc(mdev, MLX5_FLOW_NAMESPACE_KERNEL,
+ 1, action);
+ if (IS_ERR(modify_hdr)) {
+ err = PTR_ERR(modify_hdr);
+ mlx5_core_err(mdev,
+ "fail to alloc psp copy modify_header_id err=%d\n", err);
+ goto out_spec;
+ }
+
+ accel_psp_setup_syndrome_match(spec, PSP_OK);
+ /* create fte */
+ flow_act.action = MLX5_FLOW_CONTEXT_ACTION_MOD_HDR |
+ MLX5_FLOW_CONTEXT_ACTION_FWD_DEST;
+ flow_act.modify_hdr = modify_hdr;
+ fte = mlx5_add_flow_rules(rx_err->ft, spec, &flow_act,
+ &fs_prot->default_dest, 1);
+ if (IS_ERR(fte)) {
+ err = PTR_ERR(fte);
+ mlx5_core_err(mdev, "fail to add psp rx err copy rule err=%d\n", err);
+ goto out;
+ }
+ rx_err->rule = fte;
+
+ /* add default drop rule */
+ memset(spec, 0, sizeof(*spec));
+ memset(&flow_act, 0, sizeof(flow_act));
+ /* create fte */
+ flow_act.action = MLX5_FLOW_CONTEXT_ACTION_DROP;
+ fte = mlx5_add_flow_rules(rx_err->ft, spec, &flow_act, NULL, 0);
+ if (IS_ERR(fte)) {
+ err = PTR_ERR(fte);
+ mlx5_core_err(mdev, "fail to add psp rx err drop rule err=%d\n", err);
+ goto out_drop_rule;
+ }
+ rx_err->drop_rule = fte;
+ rx_err->copy_modify_hdr = modify_hdr;
+
+ goto out_spec;
+
+out_drop_rule:
+ mlx5_del_flow_rules(rx_err->rule);
+ rx_err->rule = NULL;
+out:
+ mlx5_modify_header_dealloc(mdev, modify_hdr);
+out_spec:
+ kfree(spec);
+ return err;
+}
+
+static int accel_psp_fs_rx_err_create_ft(struct mlx5e_psp_fs *fs,
+ struct mlx5e_accel_fs_psp_prot *fs_prot,
+ struct mlx5e_psp_rx_err *rx_err)
+{
+ struct mlx5_flow_namespace *ns = mlx5e_fs_get_ns(fs->fs, false);
+ struct mlx5_flow_table_attr ft_attr = {};
+ struct mlx5_flow_table *ft;
+ int err;
+
+ ft_attr.max_fte = 2;
+ ft_attr.autogroup.max_num_groups = 2;
+ ft_attr.level = MLX5E_ACCEL_FS_ESP_FT_ERR_LEVEL; // MLX5E_ACCEL_FS_TCP_FT_LEVEL
+ ft_attr.prio = MLX5E_NIC_PRIO;
+ ft = mlx5_create_auto_grouped_flow_table(ns, &ft_attr);
+ if (IS_ERR(ft)) {
+ err = PTR_ERR(ft);
+ mlx5_core_err(fs->mdev, "fail to create psp rx inline ft err=%d\n", err);
+ return err;
+ }
+
+ rx_err->ft = ft;
+ err = accel_psp_fs_rx_err_add_rule(fs, fs_prot, rx_err);
+ if (err)
+ goto out_err;
+
+ return 0;
+
+out_err:
+ mlx5_destroy_flow_table(ft);
+ rx_err->ft = NULL;
+ return err;
+}
+
+static void accel_psp_fs_rx_fs_destroy(struct mlx5e_accel_fs_psp_prot *fs_prot)
+{
+ if (fs_prot->def_rule) {
+ mlx5_del_flow_rules(fs_prot->def_rule);
+ fs_prot->def_rule = NULL;
+ }
+
+ if (fs_prot->miss_rule) {
+ mlx5_del_flow_rules(fs_prot->miss_rule);
+ fs_prot->miss_rule = NULL;
+ }
+
+ if (fs_prot->miss_group) {
+ mlx5_destroy_flow_group(fs_prot->miss_group);
+ fs_prot->miss_group = NULL;
+ }
+
+ if (fs_prot->ft) {
+ mlx5_destroy_flow_table(fs_prot->ft);
+ fs_prot->ft = NULL;
+ }
+}
+
+static void setup_fte_udp_psp(struct mlx5_flow_spec *spec, u16 udp_port)
+{
+ spec->match_criteria_enable |= MLX5_MATCH_OUTER_HEADERS;
+ MLX5_SET(fte_match_set_lyr_2_4, spec->match_criteria, udp_dport, 0xffff);
+ MLX5_SET(fte_match_set_lyr_2_4, spec->match_value, udp_dport, udp_port);
+ MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, spec->match_criteria, ip_protocol);
+ MLX5_SET(fte_match_set_lyr_2_4, spec->match_value, ip_protocol, IPPROTO_UDP);
+}
+
+static int accel_psp_fs_rx_create_ft(struct mlx5e_psp_fs *fs,
+ struct mlx5e_accel_fs_psp_prot *fs_prot)
+{
+ struct mlx5_flow_namespace *ns = mlx5e_fs_get_ns(fs->fs, false);
+ u8 action[MLX5_UN_SZ_BYTES(set_add_copy_action_in_auto)] = {};
+ int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
+ struct mlx5_modify_hdr *modify_hdr = NULL;
+ struct mlx5_flow_table_attr ft_attr = {};
+ struct mlx5_flow_destination dest = {};
+ struct mlx5_core_dev *mdev = fs->mdev;
+ struct mlx5_flow_group *miss_group;
+ MLX5_DECLARE_FLOW_ACT(flow_act);
+ struct mlx5_flow_handle *rule;
+ struct mlx5_flow_spec *spec;
+ struct mlx5_flow_table *ft;
+ u32 *flow_group_in;
+ int err = 0;
+
+ flow_group_in = kvzalloc(inlen, GFP_KERNEL);
+ spec = kvzalloc(sizeof(*spec), GFP_KERNEL);
+ if (!flow_group_in || !spec) {
+ err = -ENOMEM;
+ goto out;
+ }
+
+ /* Create FT */
+ ft_attr.max_fte = 2;
+ ft_attr.level = MLX5E_ACCEL_FS_ESP_FT_LEVEL;
+ ft_attr.prio = MLX5E_NIC_PRIO;
+ ft_attr.autogroup.num_reserved_entries = 1;
+ ft_attr.autogroup.max_num_groups = 1;
+ ft = mlx5_create_auto_grouped_flow_table(ns, &ft_attr);
+ if (IS_ERR(ft)) {
+ err = PTR_ERR(ft);
+ mlx5_core_err(mdev, "fail to create psp rx ft err=%d\n", err);
+ goto out_err;
+ }
+ fs_prot->ft = ft;
+
+ /* Create miss_group */
+ MLX5_SET(create_flow_group_in, flow_group_in, start_flow_index, ft->max_fte - 1);
+ MLX5_SET(create_flow_group_in, flow_group_in, end_flow_index, ft->max_fte - 1);
+ miss_group = mlx5_create_flow_group(ft, flow_group_in);
+ if (IS_ERR(miss_group)) {
+ err = PTR_ERR(miss_group);
+ mlx5_core_err(mdev, "fail to create psp rx miss_group err=%d\n", err);
+ goto out_err;
+ }
+ fs_prot->miss_group = miss_group;
+
+ /* Create miss rule */
+ rule = mlx5_add_flow_rules(ft, spec, &flow_act, &fs_prot->default_dest, 1);
+ if (IS_ERR(rule)) {
+ err = PTR_ERR(rule);
+ mlx5_core_err(mdev, "fail to create psp rx miss_rule err=%d\n", err);
+ goto out_err;
+ }
+ fs_prot->miss_rule = rule;
+
+ /* Add default Rx psp rule */
+ setup_fte_udp_psp(spec, PSP_DEFAULT_UDP_PORT);
+ flow_act.crypto.type = MLX5_FLOW_CONTEXT_ENCRYPT_DECRYPT_TYPE_PSP;
+ /* Set bit[31, 30] PSP marker */
+ /* Set bit[29-23] psp_syndrome is set in error FT */
+#define MLX5E_PSP_MARKER_BIT (BIT(30) | BIT(31))
+ MLX5_SET(set_action_in, action, action_type, MLX5_ACTION_TYPE_SET);
+ MLX5_SET(set_action_in, action, field, MLX5_ACTION_IN_FIELD_METADATA_REG_B);
+ MLX5_SET(set_action_in, action, data, MLX5E_PSP_MARKER_BIT);
+ MLX5_SET(set_action_in, action, offset, 0);
+ MLX5_SET(set_action_in, action, length, 32);
+
+ modify_hdr = mlx5_modify_header_alloc(mdev, MLX5_FLOW_NAMESPACE_KERNEL, 1, action);
+ if (IS_ERR(modify_hdr)) {
+ err = PTR_ERR(modify_hdr);
+ mlx5_core_err(mdev, "fail to alloc psp set modify_header_id err=%d\n", err);
+ modify_hdr = NULL;
+ goto out_err;
+ }
+
+ flow_act.action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST |
+ MLX5_FLOW_CONTEXT_ACTION_CRYPTO_DECRYPT |
+ MLX5_FLOW_CONTEXT_ACTION_MOD_HDR;
+ flow_act.modify_hdr = modify_hdr;
+ dest.type = MLX5_FLOW_DESTINATION_TYPE_FLOW_TABLE;
+ dest.ft = fs_prot->rx_err.ft;
+ rule = mlx5_add_flow_rules(fs_prot->ft, spec, &flow_act, &dest, 1);
+ if (IS_ERR(rule)) {
+ err = PTR_ERR(rule);
+ mlx5_core_err(mdev,
+ "fail to add psp rule Rx decryption, err=%d, flow_act.action = %#04X\n",
+ err, flow_act.action);
+ goto out_err;
+ }
+
+ fs_prot->def_rule = rule;
+ goto out;
+
+out_err:
+ accel_psp_fs_rx_fs_destroy(fs_prot);
+out:
+ kvfree(flow_group_in);
+ kvfree(spec);
+ return err;
+}
+
+static int accel_psp_fs_rx_destroy(struct mlx5e_psp_fs *fs, enum accel_fs_psp_type type)
+{
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5e_accel_fs_psp *accel_psp;
+
+ accel_psp = fs->rx_fs;
+
+ /* The netdev unreg already happened, so all offloaded rule are already removed */
+ fs_prot = &accel_psp->fs_prot[type];
+
+ accel_psp_fs_rx_fs_destroy(fs_prot);
+
+ accel_psp_fs_rx_err_destroy_ft(fs, &fs_prot->rx_err);
+
+ return 0;
+}
+
+static int accel_psp_fs_rx_create(struct mlx5e_psp_fs *fs, enum accel_fs_psp_type type)
+{
+ struct mlx5_ttc_table *ttc = mlx5e_fs_get_ttc(fs->fs, false);
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5e_accel_fs_psp *accel_psp;
+ int err;
+
+ accel_psp = fs->rx_fs;
+ fs_prot = &accel_psp->fs_prot[type];
+
+ fs_prot->default_dest = mlx5_ttc_get_default_dest(ttc, fs_psp2tt(type));
+
+ err = accel_psp_fs_rx_err_create_ft(fs, fs_prot, &fs_prot->rx_err);
+ if (err)
+ return err;
+
+ err = accel_psp_fs_rx_create_ft(fs, fs_prot);
+ if (err)
+ accel_psp_fs_rx_err_destroy_ft(fs, &fs_prot->rx_err);
+
+ return err;
+}
+
+static int accel_psp_fs_rx_ft_get(struct mlx5e_psp_fs *fs, enum accel_fs_psp_type type)
+{
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5_flow_destination dest = {};
+ struct mlx5e_accel_fs_psp *accel_psp;
+ struct mlx5_ttc_table *ttc;
+ int err = 0;
+
+ if (!fs || !fs->rx_fs)
+ return -EINVAL;
+
+ ttc = mlx5e_fs_get_ttc(fs->fs, false);
+ accel_psp = fs->rx_fs;
+ fs_prot = &accel_psp->fs_prot[type];
+ mutex_lock(&fs_prot->prot_mutex);
+ if (fs_prot->refcnt++)
+ goto out;
+
+ /* create FT */
+ err = accel_psp_fs_rx_create(fs, type);
+ if (err) {
+ fs_prot->refcnt--;
+ goto out;
+ }
+
+ /* connect */
+ dest.type = MLX5_FLOW_DESTINATION_TYPE_FLOW_TABLE;
+ dest.ft = fs_prot->ft;
+ mlx5_ttc_fwd_dest(ttc, fs_psp2tt(type), &dest);
+
+out:
+ mutex_unlock(&fs_prot->prot_mutex);
+ return err;
+}
+
+static void accel_psp_fs_rx_ft_put(struct mlx5e_psp_fs *fs, enum accel_fs_psp_type type)
+{
+ struct mlx5_ttc_table *ttc = mlx5e_fs_get_ttc(fs->fs, false);
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5e_accel_fs_psp *accel_psp;
+
+ accel_psp = fs->rx_fs;
+ fs_prot = &accel_psp->fs_prot[type];
+ mutex_lock(&fs_prot->prot_mutex);
+ if (--fs_prot->refcnt)
+ goto out;
+
+ /* disconnect */
+ mlx5_ttc_fwd_default_dest(ttc, fs_psp2tt(type));
+
+ /* remove FT */
+ accel_psp_fs_rx_destroy(fs, type);
+
+out:
+ mutex_unlock(&fs_prot->prot_mutex);
+}
+
+static void accel_psp_fs_cleanup_rx(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5e_accel_fs_psp *accel_psp;
+ enum accel_fs_psp_type i;
+
+ if (!fs->rx_fs)
+ return;
+
+ accel_psp = fs->rx_fs;
+ for (i = 0; i < ACCEL_FS_PSP_NUM_TYPES; i++) {
+ fs_prot = &accel_psp->fs_prot[i];
+ mutex_destroy(&fs_prot->prot_mutex);
+ WARN_ON(fs_prot->refcnt);
+ }
+ kfree(fs->rx_fs);
+ fs->rx_fs = NULL;
+}
+
+static int accel_psp_fs_init_rx(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5e_accel_fs_psp_prot *fs_prot;
+ struct mlx5e_accel_fs_psp *accel_psp;
+ enum accel_fs_psp_type i;
+
+ accel_psp = kzalloc(sizeof(*accel_psp), GFP_KERNEL);
+ if (!accel_psp)
+ return -ENOMEM;
+
+ for (i = 0; i < ACCEL_FS_PSP_NUM_TYPES; i++) {
+ fs_prot = &accel_psp->fs_prot[i];
+ mutex_init(&fs_prot->prot_mutex);
+ }
+
+ fs->rx_fs = accel_psp;
+
+ return 0;
+}
+
+void mlx5_accel_psp_fs_cleanup_rx_tables(struct mlx5e_priv *priv)
+{
+ int i;
+
+ if (!priv->psp)
+ return;
+
+ for (i = 0; i < ACCEL_FS_PSP_NUM_TYPES; i++)
+ accel_psp_fs_rx_ft_put(priv->psp->fs, i);
+}
+
+int mlx5_accel_psp_fs_init_rx_tables(struct mlx5e_priv *priv)
+{
+ struct mlx5e_psp_fs *fs;
+ int err, i;
+
+ if (!priv->psp)
+ return 0;
+
+ fs = priv->psp->fs;
+ for (i = 0; i < ACCEL_FS_PSP_NUM_TYPES; i++) {
+ err = accel_psp_fs_rx_ft_get(fs, i);
+ if (err)
+ goto out_err;
+ }
+
+ return 0;
+
+out_err:
+ i--;
+ while (i >= 0) {
+ accel_psp_fs_rx_ft_put(fs, i);
+ --i;
+ }
+
+ return err;
+}
+
+static int accel_psp_fs_tx_create_ft_table(struct mlx5e_psp_fs *fs)
+{
+ int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
+ struct mlx5_flow_table_attr ft_attr = {};
+ struct mlx5_core_dev *mdev = fs->mdev;
+ struct mlx5_flow_act flow_act = {};
+ u32 *in, *mc, *outer_headers_c;
+ struct mlx5_flow_handle *rule;
+ struct mlx5_flow_spec *spec;
+ struct mlx5e_psp_tx *tx_fs;
+ struct mlx5_flow_table *ft;
+ struct mlx5_flow_group *fg;
+ int err = 0;
+
+ spec = kvzalloc(sizeof(*spec), GFP_KERNEL);
+ in = kvzalloc(inlen, GFP_KERNEL);
+ if (!spec || !in) {
+ err = -ENOMEM;
+ goto out;
+ }
+
+ ft_attr.max_fte = 1;
+#define MLX5E_PSP_PRIO 0
+ ft_attr.prio = MLX5E_PSP_PRIO;
+#define MLX5E_PSP_LEVEL 0
+ ft_attr.level = MLX5E_PSP_LEVEL;
+ ft_attr.autogroup.max_num_groups = 1;
+
+ tx_fs = fs->tx_fs;
+ ft = mlx5_create_flow_table(tx_fs->ns, &ft_attr);
+ if (IS_ERR(ft)) {
+ err = PTR_ERR(ft);
+ mlx5_core_err(mdev, "PSP: fail to add psp tx flow table, err = %d\n", err);
+ goto out;
+ }
+
+ mc = MLX5_ADDR_OF(create_flow_group_in, in, match_criteria);
+ outer_headers_c = MLX5_ADDR_OF(fte_match_param, mc, outer_headers);
+ MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, outer_headers_c, ip_protocol);
+ MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, outer_headers_c, udp_dport);
+ MLX5_SET_CFG(in, match_criteria_enable, MLX5_MATCH_OUTER_HEADERS);
+ fg = mlx5_create_flow_group(ft, in);
+ if (IS_ERR(fg)) {
+ err = PTR_ERR(fg);
+ mlx5_core_err(mdev, "PSP: fail to add psp tx flow group, err = %d\n", err);
+ goto err_create_fg;
+ }
+
+ setup_fte_udp_psp(spec, PSP_DEFAULT_UDP_PORT);
+ flow_act.crypto.type = MLX5_FLOW_CONTEXT_ENCRYPT_DECRYPT_TYPE_PSP;
+ flow_act.flags |= FLOW_ACT_NO_APPEND;
+ flow_act.action = MLX5_FLOW_CONTEXT_ACTION_ALLOW |
+ MLX5_FLOW_CONTEXT_ACTION_CRYPTO_ENCRYPT;
+ rule = mlx5_add_flow_rules(ft, spec, &flow_act, NULL, 0);
+ if (IS_ERR(rule)) {
+ err = PTR_ERR(rule);
+ mlx5_core_err(mdev, "PSP: fail to add psp tx flow rule, err = %d\n", err);
+ goto err_add_flow_rule;
+ }
+
+ tx_fs->ft = ft;
+ tx_fs->fg = fg;
+ tx_fs->rule = rule;
+ goto out;
+
+err_add_flow_rule:
+ mlx5_destroy_flow_group(fg);
+err_create_fg:
+ mlx5_destroy_flow_table(ft);
+out:
+ kvfree(in);
+ kvfree(spec);
+ return err;
+}
+
+static void accel_psp_fs_tx_destroy(struct mlx5e_psp_tx *tx_fs)
+{
+ if (!tx_fs->ft)
+ return;
+
+ mlx5_del_flow_rules(tx_fs->rule);
+ mlx5_destroy_flow_group(tx_fs->fg);
+ mlx5_destroy_flow_table(tx_fs->ft);
+}
+
+static int accel_psp_fs_tx_ft_get(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5e_psp_tx *tx_fs = fs->tx_fs;
+ int err = 0;
+
+ mutex_lock(&tx_fs->mutex);
+ if (tx_fs->refcnt++)
+ goto out;
+
+ err = accel_psp_fs_tx_create_ft_table(fs);
+ if (err)
+ tx_fs->refcnt--;
+out:
+ mutex_unlock(&tx_fs->mutex);
+ return err;
+}
+
+static void accel_psp_fs_tx_ft_put(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5e_psp_tx *tx_fs = fs->tx_fs;
+
+ mutex_lock(&tx_fs->mutex);
+ if (--tx_fs->refcnt)
+ goto out;
+
+ accel_psp_fs_tx_destroy(tx_fs);
+out:
+ mutex_unlock(&tx_fs->mutex);
+}
+
+static void accel_psp_fs_cleanup_tx(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5e_psp_tx *tx_fs = fs->tx_fs;
+
+ if (!tx_fs)
+ return;
+
+ mutex_destroy(&tx_fs->mutex);
+ WARN_ON(tx_fs->refcnt);
+ kfree(tx_fs);
+ fs->tx_fs = NULL;
+}
+
+static int accel_psp_fs_init_tx(struct mlx5e_psp_fs *fs)
+{
+ struct mlx5_flow_namespace *ns;
+ struct mlx5e_psp_tx *tx_fs;
+
+ ns = mlx5_get_flow_namespace(fs->mdev, MLX5_FLOW_NAMESPACE_EGRESS_IPSEC);
+ if (!ns)
+ return -EOPNOTSUPP;
+
+ tx_fs = kzalloc(sizeof(*tx_fs), GFP_KERNEL);
+ if (!tx_fs)
+ return -ENOMEM;
+
+ mutex_init(&tx_fs->mutex);
+ tx_fs->ns = ns;
+ fs->tx_fs = tx_fs;
+ return 0;
+}
+
+void mlx5_accel_psp_fs_cleanup_tx_tables(struct mlx5e_priv *priv)
+{
+ if (!priv->psp)
+ return;
+
+ accel_psp_fs_tx_ft_put(priv->psp->fs);
+}
+
+int mlx5_accel_psp_fs_init_tx_tables(struct mlx5e_priv *priv)
+{
+ if (!priv->psp)
+ return 0;
+
+ return accel_psp_fs_tx_ft_get(priv->psp->fs);
+}
+
+static void mlx5e_accel_psp_fs_cleanup(struct mlx5e_psp_fs *fs)
+{
+ accel_psp_fs_cleanup_rx(fs);
+ accel_psp_fs_cleanup_tx(fs);
+ kfree(fs);
+}
+
+static struct mlx5e_psp_fs *mlx5e_accel_psp_fs_init(struct mlx5e_priv *priv)
+{
+ struct mlx5e_psp_fs *fs;
+ int err = 0;
+
+ fs = kzalloc(sizeof(*fs), GFP_KERNEL);
+ if (!fs)
+ return ERR_PTR(-ENOMEM);
+
+ fs->mdev = priv->mdev;
+ err = accel_psp_fs_init_tx(fs);
+ if (err)
+ goto err_tx;
+
+ fs->fs = priv->fs;
+ err = accel_psp_fs_init_rx(fs);
+ if (err)
+ goto err_rx;
+
+ return fs;
+
+err_rx:
+ accel_psp_fs_cleanup_tx(fs);
+err_tx:
+ kfree(fs);
+ return ERR_PTR(err);
+}
+
+static int
+mlx5e_psp_set_config(struct psp_dev *psd, struct psp_dev_config *conf,
+ struct netlink_ext_ack *extack)
+{
+ return 0; /* TODO: this should actually do things to the device */
+}
+
+static int
+mlx5e_psp_generate_key_spi(struct mlx5_core_dev *mdev,
+ enum mlx5_psp_gen_spi_in_key_size keysz,
+ unsigned int keysz_bytes,
+ struct psp_key_parsed *key)
+{
+ u32 out[MLX5_ST_SZ_DW(psp_gen_spi_out) + MLX5_ST_SZ_DW(key_spi)] = {};
+ u32 in[MLX5_ST_SZ_DW(psp_gen_spi_in)] = {};
+ void *outkey;
+ int err;
+
+ WARN_ON_ONCE(keysz_bytes > PSP_MAX_KEY);
+
+ MLX5_SET(psp_gen_spi_in, in, opcode, MLX5_CMD_OP_PSP_GEN_SPI);
+ MLX5_SET(psp_gen_spi_in, in, key_size, keysz);
+ MLX5_SET(psp_gen_spi_in, in, num_of_spi, 1);
+ err = mlx5_cmd_exec(mdev, in, sizeof(in), out, sizeof(out));
+ if (err)
+ return err;
+
+ outkey = MLX5_ADDR_OF(psp_gen_spi_out, out, key_spi);
+ key->spi = cpu_to_be32(MLX5_GET(key_spi, outkey, spi));
+ memcpy(key->key, MLX5_ADDR_OF(key_spi, outkey, key) + 32 - keysz_bytes,
+ keysz_bytes);
+
+ return 0;
+}
+
+static int
+mlx5e_psp_rx_spi_alloc(struct psp_dev *psd, u32 version,
+ struct psp_key_parsed *assoc,
+ struct netlink_ext_ack *extack)
+{
+ struct mlx5e_priv *priv = netdev_priv(psd->main_netdev);
+ enum mlx5_psp_gen_spi_in_key_size keysz;
+ u8 keysz_bytes;
+
+ switch (version) {
+ case PSP_VERSION_HDR0_AES_GCM_128:
+ keysz = MLX5_PSP_GEN_SPI_IN_KEY_SIZE_128;
+ keysz_bytes = 16;
+ break;
+ case PSP_VERSION_HDR0_AES_GCM_256:
+ keysz = MLX5_PSP_GEN_SPI_IN_KEY_SIZE_256;
+ keysz_bytes = 32;
+ break;
+ default:
+ return -EINVAL;
+ }
+
+ return mlx5e_psp_generate_key_spi(priv->mdev, keysz, keysz_bytes, assoc);
+}
+
+struct psp_key {
+ u32 id;
+};
+
+static int mlx5e_psp_assoc_add(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack)
+{
+ struct mlx5e_priv *priv = netdev_priv(psd->main_netdev);
+ struct mlx5_core_dev *mdev = priv->mdev;
+ struct psp_key_parsed *tx = &pas->tx;
+ struct mlx5e_psp *psp = priv->psp;
+ struct psp_key *nkey;
+ int err;
+
+ mdev = priv->mdev;
+ nkey = (struct psp_key *)pas->drv_data;
+
+ err = mlx5_create_encryption_key(mdev, tx->key,
+ psp_key_size(pas->version),
+ MLX5_ACCEL_OBJ_PSP_KEY,
+ &nkey->id);
+ if (err) {
+ mlx5_core_err(mdev, "Failed to create encryption key (err = %d)\n", err);
+ return err;
+ }
+
+ atomic_inc(&psp->tx_key_cnt);
+ return 0;
+}
+
+static void mlx5e_psp_assoc_del(struct psp_dev *psd, struct psp_assoc *pas)
+{
+ struct mlx5e_priv *priv = netdev_priv(psd->main_netdev);
+ struct mlx5e_psp *psp = priv->psp;
+ struct psp_key *nkey;
+
+ nkey = (struct psp_key *)pas->drv_data;
+ mlx5_destroy_encryption_key(priv->mdev, nkey->id);
+ atomic_dec(&psp->tx_key_cnt);
+}
+
+static int mlx5e_psp_rotate_key(struct mlx5_core_dev *mdev)
+{
+ u32 in[MLX5_ST_SZ_DW(psp_rotate_key_in)] = {};
+ u32 out[MLX5_ST_SZ_DW(psp_rotate_key_out)];
+
+ MLX5_SET(psp_rotate_key_in, in, opcode,
+ MLX5_CMD_OP_PSP_ROTATE_KEY);
+
+ return mlx5_cmd_exec(mdev, in, sizeof(in), out, sizeof(out));
+}
+
+static int
+mlx5e_psp_key_rotate(struct psp_dev *psd, struct netlink_ext_ack *exack)
+{
+ struct mlx5e_priv *priv = netdev_priv(psd->main_netdev);
+
+ /* no support for protecting against external rotations */
+ psd->generation = 0;
+
+ return mlx5e_psp_rotate_key(priv->mdev);
+}
+
+static struct psp_dev_ops mlx5_psp_ops = {
+ .set_config = mlx5e_psp_set_config,
+ .rx_spi_alloc = mlx5e_psp_rx_spi_alloc,
+ .tx_key_add = mlx5e_psp_assoc_add,
+ .tx_key_del = mlx5e_psp_assoc_del,
+ .key_rotate = mlx5e_psp_key_rotate,
+};
+
+void mlx5e_psp_unregister(struct mlx5e_priv *priv)
+{
+ if (!priv->psp || !priv->psp->psp)
+ return;
+
+ psp_dev_unregister(priv->psp->psp);
+}
+
+void mlx5e_psp_register(struct mlx5e_priv *priv)
+{
+ /* FW Caps missing */
+ if (!priv->psp)
+ return;
+
+ priv->psp->caps.assoc_drv_spc = sizeof(u32);
+ priv->psp->caps.versions = 1 << PSP_VERSION_HDR0_AES_GCM_128;
+ if (MLX5_CAP_PSP(priv->mdev, psp_crypto_esp_aes_gcm_256_encrypt) &&
+ MLX5_CAP_PSP(priv->mdev, psp_crypto_esp_aes_gcm_256_decrypt))
+ priv->psp->caps.versions |= 1 << PSP_VERSION_HDR0_AES_GCM_256;
+
+ priv->psp->psp = psp_dev_create(priv->netdev, &mlx5_psp_ops,
+ &priv->psp->caps, NULL);
+ if (IS_ERR(priv->psp->psp))
+ mlx5_core_err(priv->mdev, "PSP failed to register due to %pe\n",
+ priv->psp->psp);
+}
+
+int mlx5e_psp_init(struct mlx5e_priv *priv)
+{
+ struct mlx5_core_dev *mdev = priv->mdev;
+ struct mlx5e_psp_fs *fs;
+ struct mlx5e_psp *psp;
+ int err;
+
+ if (!mlx5_is_psp_device(mdev)) {
+ mlx5_core_dbg(mdev, "PSP offload not supported\n");
+ return -EOPNOTSUPP;
+ }
+
+ if (!MLX5_CAP_ETH(mdev, swp)) {
+ mlx5_core_dbg(mdev, "SWP not supported\n");
+ return -EOPNOTSUPP;
+ }
+
+ if (!MLX5_CAP_ETH(mdev, swp_csum)) {
+ mlx5_core_dbg(mdev, "SWP checksum not supported\n");
+ return -EOPNOTSUPP;
+ }
+
+ if (!MLX5_CAP_ETH(mdev, swp_csum_l4_partial)) {
+ mlx5_core_dbg(mdev, "SWP L4 partial checksum not supported\n");
+ return -EOPNOTSUPP;
+ }
+
+ if (!MLX5_CAP_ETH(mdev, swp_lso)) {
+ mlx5_core_dbg(mdev, "PSP LSO not supported\n");
+ return -EOPNOTSUPP;
+ }
+
+ psp = kzalloc(sizeof(*psp), GFP_KERNEL);
+ if (!psp)
+ return -ENOMEM;
+
+ priv->psp = psp;
+ fs = mlx5e_accel_psp_fs_init(priv);
+ if (IS_ERR(fs)) {
+ err = PTR_ERR(fs);
+ goto out_err;
+ }
+
+ psp->fs = fs;
+
+ mlx5_core_dbg(priv->mdev, "PSP attached to netdevice\n");
+ return 0;
+
+out_err:
+ priv->psp = NULL;
+ kfree(psp);
+ return err;
+}
+
+void mlx5e_psp_cleanup(struct mlx5e_priv *priv)
+{
+ struct mlx5e_psp *psp = priv->psp;
+
+ if (!psp)
+ return;
+
+ WARN_ON(atomic_read(&psp->tx_key_cnt));
+ mlx5e_accel_psp_fs_cleanup(psp->fs);
+ priv->psp = NULL;
+ kfree(psp);
+}
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.h
new file mode 100644
index 000000000000..42bb671fb2cb
--- /dev/null
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp.h
@@ -0,0 +1,61 @@
+/* SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB */
+/* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
+
+#ifndef __MLX5E_ACCEL_PSP_H__
+#define __MLX5E_ACCEL_PSP_H__
+#if IS_ENABLED(CONFIG_MLX5_EN_PSP)
+#include <net/psp/types.h>
+#include "en.h"
+
+struct mlx5e_psp {
+ struct psp_dev *psp;
+ struct psp_dev_caps caps;
+ struct mlx5e_psp_fs *fs;
+ atomic_t tx_key_cnt;
+};
+
+static inline bool mlx5_is_psp_device(struct mlx5_core_dev *mdev)
+{
+ if (!MLX5_CAP_GEN(mdev, psp))
+ return false;
+
+ if (!MLX5_CAP_PSP(mdev, psp_crypto_offload) ||
+ !MLX5_CAP_PSP(mdev, psp_crypto_esp_aes_gcm_128_encrypt) ||
+ !MLX5_CAP_PSP(mdev, psp_crypto_esp_aes_gcm_128_decrypt))
+ return false;
+
+ return true;
+}
+
+int mlx5_accel_psp_fs_init_rx_tables(struct mlx5e_priv *priv);
+void mlx5_accel_psp_fs_cleanup_rx_tables(struct mlx5e_priv *priv);
+int mlx5_accel_psp_fs_init_tx_tables(struct mlx5e_priv *priv);
+void mlx5_accel_psp_fs_cleanup_tx_tables(struct mlx5e_priv *priv);
+void mlx5e_psp_register(struct mlx5e_priv *priv);
+void mlx5e_psp_unregister(struct mlx5e_priv *priv);
+int mlx5e_psp_init(struct mlx5e_priv *priv);
+void mlx5e_psp_cleanup(struct mlx5e_priv *priv);
+#else
+static inline int mlx5_accel_psp_fs_init_rx_tables(struct mlx5e_priv *priv)
+{
+ return 0;
+}
+
+static inline void mlx5_accel_psp_fs_cleanup_rx_tables(struct mlx5e_priv *priv) { }
+static inline int mlx5_accel_psp_fs_init_tx_tables(struct mlx5e_priv *priv)
+{
+ return 0;
+}
+
+static inline void mlx5_accel_psp_fs_cleanup_tx_tables(struct mlx5e_priv *priv) { }
+static inline bool mlx5_is_psp_device(struct mlx5_core_dev *mdev)
+{
+ return false;
+}
+
+static inline void mlx5e_psp_register(struct mlx5e_priv *priv) { }
+static inline void mlx5e_psp_unregister(struct mlx5e_priv *priv) { }
+static inline int mlx5e_psp_init(struct mlx5e_priv *priv) { return 0; }
+static inline void mlx5e_psp_cleanup(struct mlx5e_priv *priv) { }
+#endif /* CONFIG_MLX5_EN_PSP */
+#endif /* __MLX5E_ACCEL_PSP_H__ */
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.c
new file mode 100644
index 000000000000..828bff1137af
--- /dev/null
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.c
@@ -0,0 +1,200 @@
+// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
+/* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
+
+#include <linux/skbuff.h>
+#include <linux/ip.h>
+#include <linux/udp.h>
+#include <net/protocol.h>
+#include <net/udp.h>
+#include <net/ip6_checksum.h>
+#include <net/psp/types.h>
+
+#include "en.h"
+#include "psp.h"
+#include "en_accel/psp_rxtx.h"
+#include "en_accel/psp.h"
+
+enum {
+ MLX5E_PSP_OFFLOAD_RX_SYNDROME_DECRYPTED,
+ MLX5E_PSP_OFFLOAD_RX_SYNDROME_AUTH_FAILED,
+ MLX5E_PSP_OFFLOAD_RX_SYNDROME_BAD_TRAILER,
+};
+
+static void mlx5e_psp_set_swp(struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_eth_seg *eseg)
+{
+ /* Tunnel Mode:
+ * SWP: OutL3 InL3 InL4
+ * Pkt: MAC IP ESP IP L4
+ *
+ * Transport Mode:
+ * SWP: OutL3 OutL4
+ * Pkt: MAC IP ESP L4
+ *
+ * Tunnel(VXLAN TCP/UDP) over Transport Mode
+ * SWP: OutL3 InL3 InL4
+ * Pkt: MAC IP ESP UDP VXLAN IP L4
+ */
+ u8 inner_ipproto = 0;
+ struct ethhdr *eth;
+
+ /* Shared settings */
+ eseg->swp_outer_l3_offset = skb_network_offset(skb) / 2;
+ if (skb->protocol == htons(ETH_P_IPV6))
+ eseg->swp_flags |= MLX5_ETH_WQE_SWP_OUTER_L3_IPV6;
+
+ if (skb->inner_protocol_type == ENCAP_TYPE_IPPROTO) {
+ inner_ipproto = skb->inner_ipproto;
+ /* Set SWP additional flags for packet of type IP|UDP|PSP|[ TCP | UDP ] */
+ switch (inner_ipproto) {
+ case IPPROTO_UDP:
+ eseg->swp_flags |= MLX5_ETH_WQE_SWP_INNER_L4_UDP;
+ fallthrough;
+ case IPPROTO_TCP:
+ eseg->swp_inner_l4_offset = skb_inner_transport_offset(skb) / 2;
+ break;
+ default:
+ break;
+ }
+ } else {
+ /* IP in IP tunneling like vxlan*/
+ if (skb->inner_protocol_type != ENCAP_TYPE_ETHER)
+ return;
+
+ eth = (struct ethhdr *)skb_inner_mac_header(skb);
+ switch (ntohs(eth->h_proto)) {
+ case ETH_P_IP:
+ inner_ipproto = ((struct iphdr *)((char *)skb->data +
+ skb_inner_network_offset(skb)))->protocol;
+ break;
+ case ETH_P_IPV6:
+ inner_ipproto = ((struct ipv6hdr *)((char *)skb->data +
+ skb_inner_network_offset(skb)))->nexthdr;
+ break;
+ default:
+ break;
+ }
+
+ /* Tunnel(VXLAN TCP/UDP) over Transport Mode PSP i.e. PSP payload is vxlan tunnel */
+ switch (inner_ipproto) {
+ case IPPROTO_UDP:
+ eseg->swp_flags |= MLX5_ETH_WQE_SWP_INNER_L4_UDP;
+ fallthrough;
+ case IPPROTO_TCP:
+ eseg->swp_inner_l3_offset = skb_inner_network_offset(skb) / 2;
+ eseg->swp_inner_l4_offset =
+ (skb->csum_start + skb->head - skb->data) / 2;
+ if (skb->protocol == htons(ETH_P_IPV6))
+ eseg->swp_flags |= MLX5_ETH_WQE_SWP_INNER_L3_IPV6;
+ break;
+ default:
+ break;
+ }
+
+ psp_st->inner_ipproto = inner_ipproto;
+ }
+}
+
+static bool mlx5e_psp_set_state(struct mlx5e_priv *priv,
+ struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st)
+{
+ struct psp_assoc *pas;
+ bool ret = false;
+
+ rcu_read_lock();
+ pas = psp_skb_get_assoc_rcu(skb);
+ if (!pas)
+ goto out;
+
+ ret = true;
+ psp_st->tailen = PSP_TRL_SIZE;
+ psp_st->spi = pas->tx.spi;
+ psp_st->ver = pas->version;
+ psp_st->keyid = *(u32 *)pas->drv_data;
+
+out:
+ rcu_read_unlock();
+ return ret;
+}
+
+bool mlx5e_psp_offload_handle_rx_skb(struct net_device *netdev, struct sk_buff *skb,
+ struct mlx5_cqe64 *cqe)
+{
+ u32 psp_meta_data = be32_to_cpu(cqe->ft_metadata);
+ struct mlx5e_priv *priv = netdev_priv(netdev);
+ u16 dev_id = priv->psp->psp->id;
+ bool strip_icv = true;
+ u8 generation = 0;
+
+ /* TBD: report errors as SW counters to ethtool, any further handling ? */
+ if (MLX5_PSP_METADATA_SYNDROME(psp_meta_data) != MLX5E_PSP_OFFLOAD_RX_SYNDROME_DECRYPTED)
+ goto drop;
+
+ if (psp_dev_rcv(skb, dev_id, generation, strip_icv))
+ goto drop;
+
+ skb->decrypted = 1;
+ return false;
+
+drop:
+ kfree_skb(skb);
+ return true;
+}
+
+void mlx5e_psp_tx_build_eseg(struct mlx5e_priv *priv, struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_eth_seg *eseg)
+{
+ if (!mlx5_is_psp_device(priv->mdev))
+ return;
+
+ if (unlikely(skb->protocol != htons(ETH_P_IP) &&
+ skb->protocol != htons(ETH_P_IPV6)))
+ return;
+
+ mlx5e_psp_set_swp(skb, psp_st, eseg);
+ /* Special WA for PSP LSO in ConnectX7 */
+ eseg->swp_outer_l3_offset = 0;
+ eseg->swp_inner_l3_offset = 0;
+
+ eseg->flow_table_metadata |= cpu_to_be32(psp_st->keyid);
+ eseg->trailer |= cpu_to_be32(MLX5_ETH_WQE_INSERT_TRAILER) |
+ cpu_to_be32(MLX5_ETH_WQE_TRAILER_HDR_OUTER_L4_ASSOC);
+}
+
+void mlx5e_psp_handle_tx_wqe(struct mlx5e_tx_wqe *wqe,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_inline_seg *inlseg)
+{
+ inlseg->byte_count = cpu_to_be32(psp_st->tailen | MLX5_INLINE_SEG);
+}
+
+bool mlx5e_psp_handle_tx_skb(struct net_device *netdev,
+ struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st)
+{
+ struct mlx5e_priv *priv = netdev_priv(netdev);
+ struct net *net = sock_net(skb->sk);
+ const struct ipv6hdr *ip6;
+ struct tcphdr *th;
+
+ if (!mlx5e_psp_set_state(priv, skb, psp_st))
+ return true;
+
+ /* psp_encap of the packet */
+ if (!psp_dev_encapsulate(net, skb, psp_st->spi, psp_st->ver, 0)) {
+ kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT);
+ return false;
+ }
+ if (skb_is_gso(skb)) {
+ ip6 = ipv6_hdr(skb);
+ th = inner_tcp_hdr(skb);
+
+ th->check = ~tcp_v6_check(skb_shinfo(skb)->gso_size + inner_tcp_hdrlen(skb), &ip6->saddr,
+ &ip6->daddr, 0);
+ }
+
+ return true;
+}
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.h
new file mode 100644
index 000000000000..70289c921bd6
--- /dev/null
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/psp_rxtx.h
@@ -0,0 +1,121 @@
+/* SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB */
+/* Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
+
+#ifndef __MLX5E_PSP_RXTX_H__
+#define __MLX5E_PSP_RXTX_H__
+
+#include <linux/skbuff.h>
+#include <net/xfrm.h>
+#include <net/psp.h>
+#include "en.h"
+#include "en/txrx.h"
+
+/* Bit30: PSP marker, Bit29-23: PSP syndrome, Bit22-0: PSP obj id */
+#define MLX5_PSP_METADATA_MARKER(metadata) ((((metadata) >> 30) & 0x3) == 0x3)
+#define MLX5_PSP_METADATA_SYNDROME(metadata) (((metadata) >> 23) & GENMASK(6, 0))
+#define MLX5_PSP_METADATA_HANDLE(metadata) ((metadata) & GENMASK(22, 0))
+
+struct mlx5e_accel_tx_psp_state {
+ u32 tailen;
+ u32 keyid;
+ __be32 spi;
+ u8 inner_ipproto;
+ u8 ver;
+};
+
+#ifdef CONFIG_MLX5_EN_PSP
+static inline bool mlx5e_psp_is_offload_state(struct mlx5e_accel_tx_psp_state *psp_state)
+{
+ return (psp_state->tailen != 0);
+}
+
+static inline bool mlx5e_psp_is_offload(struct sk_buff *skb, struct net_device *netdev)
+{
+ bool ret;
+
+ rcu_read_lock();
+ ret = !!psp_skb_get_assoc_rcu(skb);
+ rcu_read_unlock();
+ return ret;
+}
+
+bool mlx5e_psp_handle_tx_skb(struct net_device *netdev,
+ struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st);
+
+void mlx5e_psp_tx_build_eseg(struct mlx5e_priv *priv, struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_eth_seg *eseg);
+
+void mlx5e_psp_handle_tx_wqe(struct mlx5e_tx_wqe *wqe,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_inline_seg *inlseg);
+
+static inline bool mlx5e_psp_txwqe_build_eseg_csum(struct mlx5e_txqsq *sq, struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_eth_seg *eseg)
+{
+ u8 inner_ipproto;
+
+ if (!mlx5e_psp_is_offload_state(psp_st))
+ return false;
+
+ inner_ipproto = psp_st->inner_ipproto;
+ eseg->cs_flags = MLX5_ETH_WQE_L3_CSUM;
+ if (inner_ipproto) {
+ eseg->cs_flags |= MLX5_ETH_WQE_L3_INNER_CSUM;
+ if (inner_ipproto == IPPROTO_TCP || inner_ipproto == IPPROTO_UDP)
+ eseg->cs_flags |= MLX5_ETH_WQE_L4_INNER_CSUM;
+ if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
+ sq->stats->csum_partial_inner++;
+ } else if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) {
+ eseg->cs_flags |= MLX5_ETH_WQE_L4_INNER_CSUM;
+ sq->stats->csum_partial_inner++;
+ }
+
+ return true;
+}
+
+static inline unsigned int mlx5e_psp_tx_ids_len(struct mlx5e_accel_tx_psp_state *psp_st)
+{
+ return psp_st->tailen;
+}
+
+static inline bool mlx5e_psp_is_rx_flow(struct mlx5_cqe64 *cqe)
+{
+ return MLX5_PSP_METADATA_MARKER(be32_to_cpu(cqe->ft_metadata));
+}
+
+bool mlx5e_psp_offload_handle_rx_skb(struct net_device *netdev, struct sk_buff *skb,
+ struct mlx5_cqe64 *cqe);
+#else
+static inline bool mlx5e_psp_is_offload_state(struct mlx5e_accel_tx_psp_state *psp_state)
+{
+ return false;
+}
+
+static inline bool mlx5e_psp_is_offload(struct sk_buff *skb, struct net_device *netdev)
+{
+ return false;
+}
+
+static inline bool mlx5e_psp_txwqe_build_eseg_csum(struct mlx5e_txqsq *sq, struct sk_buff *skb,
+ struct mlx5e_accel_tx_psp_state *psp_st,
+ struct mlx5_wqe_eth_seg *eseg)
+{
+ return false;
+}
+
+static inline bool mlx5e_psp_is_rx_flow(struct mlx5_cqe64 *cqe)
+{
+ return false;
+}
+
+static inline bool mlx5e_psp_offload_handle_rx_skb(struct net_device *netdev,
+ struct sk_buff *skb,
+ struct mlx5_cqe64 *cqe)
+{
+ return false;
+}
+#endif /* CONFIG_MLX5_EN_PSP */
+#endif /* __MLX5E_PSP_RXTX_H__ */
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
index 6aec5edc204e..9d502d40d864 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
@@ -53,6 +53,7 @@
#include "en_tc.h"
#include "en_rep.h"
#include "en_accel/ipsec.h"
+#include "en_accel/psp.h"
#include "en_accel/macsec.h"
#include "en_accel/en_accel.h"
#include "en_accel/ktls.h"
@@ -5919,6 +5920,10 @@ static int mlx5e_nic_init(struct mlx5_core_dev *mdev,
}
priv->fs = fs;
+ err = mlx5e_psp_init(priv);
+ if (err)
+ mlx5_core_err(mdev, "PSP initialization failed, %d\n", err);
+
err = mlx5e_ktls_init(priv);
if (err)
mlx5_core_err(mdev, "TLS initialization failed, %d\n", err);
@@ -5931,6 +5936,7 @@ static int mlx5e_nic_init(struct mlx5_core_dev *mdev,
if (take_rtnl)
rtnl_lock();
+ mlx5e_psp_register(priv);
/* update XDP supported features */
mlx5e_set_xdp_feature(netdev);
@@ -5943,7 +5949,9 @@ static int mlx5e_nic_init(struct mlx5_core_dev *mdev,
static void mlx5e_nic_cleanup(struct mlx5e_priv *priv)
{
mlx5e_health_destroy_reporters(priv);
+ mlx5e_psp_unregister(priv);
mlx5e_ktls_cleanup(priv);
+ mlx5e_psp_cleanup(priv);
mlx5e_fs_cleanup(priv->fs);
debugfs_remove_recursive(priv->dfs_root);
priv->fs = NULL;
@@ -6791,6 +6799,7 @@ static void _mlx5e_remove(struct auxiliary_device *adev)
* is already unregistered before changing to NIC profile.
*/
if (priv->netdev->reg_state == NETREG_REGISTERED) {
+ mlx5e_psp_unregister(priv);
unregister_netdev(priv->netdev);
_mlx5e_suspend(adev, false);
} else {
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
index 2925ece136c4..4ed43ee9aa35 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_rx.c
@@ -51,6 +51,7 @@
#include "ipoib/ipoib.h"
#include "en_accel/ipsec.h"
#include "en_accel/macsec.h"
+#include "en_accel/psp_rxtx.h"
#include "en_accel/ipsec_rxtx.h"
#include "en_accel/ktls_txrx.h"
#include "en/xdp.h"
@@ -1521,6 +1522,11 @@ static inline void mlx5e_handle_csum(struct net_device *netdev,
skb->ip_summed = CHECKSUM_COMPLETE;
skb->csum = csum_unfold((__force __sum16)cqe->check_sum);
+ if (unlikely(mlx5e_psp_is_rx_flow(cqe))) {
+ /* TBD: PSP csum complete corrections for now chose csum_unnecessary path */
+ goto csum_unnecessary;
+ }
+
if (test_bit(MLX5E_RQ_STATE_CSUM_FULL, &rq->state))
return; /* CQE csum covers all received bytes */
@@ -1549,7 +1555,7 @@ csum_none:
#define MLX5E_CE_BIT_MASK 0x80
-static inline void mlx5e_build_rx_skb(struct mlx5_cqe64 *cqe,
+static inline bool mlx5e_build_rx_skb(struct mlx5_cqe64 *cqe,
u32 cqe_bcnt,
struct mlx5e_rq *rq,
struct sk_buff *skb)
@@ -1563,6 +1569,11 @@ static inline void mlx5e_build_rx_skb(struct mlx5_cqe64 *cqe,
if (unlikely(get_cqe_tls_offload(cqe)))
mlx5e_ktls_handle_rx_skb(rq, skb, cqe, &cqe_bcnt);
+ if (unlikely(mlx5e_psp_is_rx_flow(cqe))) {
+ if (mlx5e_psp_offload_handle_rx_skb(netdev, skb, cqe))
+ return true;
+ }
+
if (unlikely(mlx5_ipsec_is_rx_flow(cqe)))
mlx5e_ipsec_offload_handle_rx_skb(netdev, skb,
be32_to_cpu(cqe->ft_metadata));
@@ -1608,9 +1619,11 @@ static inline void mlx5e_build_rx_skb(struct mlx5_cqe64 *cqe,
if (unlikely(mlx5e_skb_is_multicast(skb)))
stats->mcast_packets++;
+
+ return false;
}
-static void mlx5e_shampo_complete_rx_cqe(struct mlx5e_rq *rq,
+static bool mlx5e_shampo_complete_rx_cqe(struct mlx5e_rq *rq,
struct mlx5_cqe64 *cqe,
u32 cqe_bcnt,
struct sk_buff *skb)
@@ -1620,16 +1633,20 @@ static void mlx5e_shampo_complete_rx_cqe(struct mlx5e_rq *rq,
stats->packets++;
stats->bytes += cqe_bcnt;
if (NAPI_GRO_CB(skb)->count != 1)
- return;
- mlx5e_build_rx_skb(cqe, cqe_bcnt, rq, skb);
+ return false;
+
+ if (mlx5e_build_rx_skb(cqe, cqe_bcnt, rq, skb))
+ return true;
+
skb_reset_network_header(skb);
if (!skb_flow_dissect_flow_keys(skb, &rq->hw_gro_data->fk, 0)) {
napi_gro_receive(rq->cq.napi, skb);
rq->hw_gro_data->skb = NULL;
}
+ return false;
}
-static inline void mlx5e_complete_rx_cqe(struct mlx5e_rq *rq,
+static inline bool mlx5e_complete_rx_cqe(struct mlx5e_rq *rq,
struct mlx5_cqe64 *cqe,
u32 cqe_bcnt,
struct sk_buff *skb)
@@ -1638,7 +1655,7 @@ static inline void mlx5e_complete_rx_cqe(struct mlx5e_rq *rq,
stats->packets++;
stats->bytes += cqe_bcnt;
- mlx5e_build_rx_skb(cqe, cqe_bcnt, rq, skb);
+ return mlx5e_build_rx_skb(cqe, cqe_bcnt, rq, skb);
}
static inline
@@ -1854,7 +1871,8 @@ static void mlx5e_handle_rx_cqe(struct mlx5e_rq *rq, struct mlx5_cqe64 *cqe)
goto wq_cyc_pop;
}
- mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb);
+ if (mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb))
+ goto wq_cyc_pop;
if (mlx5e_cqe_regb_chain(cqe))
if (!mlx5e_tc_update_skb_nic(cqe, skb)) {
@@ -1901,7 +1919,8 @@ static void mlx5e_handle_rx_cqe_rep(struct mlx5e_rq *rq, struct mlx5_cqe64 *cqe)
goto wq_cyc_pop;
}
- mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb);
+ if (mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb))
+ goto wq_cyc_pop;
if (rep->vlan && skb_vlan_tag_present(skb))
skb_vlan_pop(skb);
@@ -1950,7 +1969,8 @@ static void mlx5e_handle_rx_cqe_mpwrq_rep(struct mlx5e_rq *rq, struct mlx5_cqe64
if (!skb)
goto mpwrq_cqe_out;
- mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb);
+ if (mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb))
+ goto mpwrq_cqe_out;
mlx5e_rep_tc_receive(cqe, rq, skb);
@@ -2387,7 +2407,10 @@ static void mlx5e_handle_rx_cqe_mpwrq_shampo(struct mlx5e_rq *rq, struct mlx5_cq
stats->hds_nosplit_bytes += data_bcnt;
}
- mlx5e_shampo_complete_rx_cqe(rq, cqe, cqe_bcnt, *skb);
+ if (mlx5e_shampo_complete_rx_cqe(rq, cqe, cqe_bcnt, *skb)) {
+ *skb = NULL;
+ goto free_hd_entry;
+ }
if (flush && rq->hw_gro_data->skb)
mlx5e_shampo_flush_skb(rq, cqe, match);
free_hd_entry:
@@ -2445,7 +2468,8 @@ static void mlx5e_handle_rx_cqe_mpwrq(struct mlx5e_rq *rq, struct mlx5_cqe64 *cq
if (!skb)
goto mpwrq_cqe_out;
- mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb);
+ if (mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb))
+ goto mpwrq_cqe_out;
if (mlx5e_cqe_regb_chain(cqe))
if (!mlx5e_tc_update_skb_nic(cqe, skb)) {
@@ -2778,7 +2802,8 @@ static void mlx5e_trap_handle_rx_cqe(struct mlx5e_rq *rq, struct mlx5_cqe64 *cqe
if (!skb)
goto wq_cyc_pop;
- mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb);
+ if (mlx5e_complete_rx_cqe(rq, cqe, cqe_bcnt, skb))
+ goto wq_cyc_pop;
skb_push(skb, ETH_HLEN);
mlx5_devlink_trap_report(rq->mdev, trap_id, skb,
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_tx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_tx.c
index 319061d31602..10d50ca637f1 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_tx.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_tx.c
@@ -39,6 +39,7 @@
#include "ipoib/ipoib.h"
#include "en_accel/en_accel.h"
#include "en_accel/ipsec_rxtx.h"
+#include "en_accel/psp_rxtx.h"
#include "en_accel/macsec.h"
#include "en/ptp.h"
#include <net/ipv6.h>
@@ -120,6 +121,11 @@ mlx5e_txwqe_build_eseg_csum(struct mlx5e_txqsq *sq, struct sk_buff *skb,
struct mlx5e_accel_tx_state *accel,
struct mlx5_wqe_eth_seg *eseg)
{
+#ifdef CONFIG_MLX5_EN_PSP
+ if (unlikely(mlx5e_psp_txwqe_build_eseg_csum(sq, skb, &accel->psp_st, eseg)))
+ return;
+#endif
+
if (unlikely(mlx5e_ipsec_txwqe_build_eseg_csum(sq, skb, eseg)))
return;
@@ -297,7 +303,7 @@ static void mlx5e_sq_xmit_prepare(struct mlx5e_txqsq *sq, struct sk_buff *skb,
stats->packets++;
}
- attr->insz = mlx5e_accel_tx_ids_len(sq, accel);
+ attr->insz = mlx5e_accel_tx_ids_len(sq, skb, accel);
stats->bytes += attr->num_bytes;
}
@@ -661,7 +667,7 @@ static void mlx5e_txwqe_build_eseg(struct mlx5e_priv *priv, struct mlx5e_txqsq *
struct sk_buff *skb, struct mlx5e_accel_tx_state *accel,
struct mlx5_wqe_eth_seg *eseg, u16 ihs)
{
- mlx5e_accel_tx_eseg(priv, skb, eseg, ihs);
+ mlx5e_accel_tx_eseg(priv, skb, accel, eseg, ihs);
mlx5e_txwqe_build_eseg_csum(sq, skb, accel, eseg);
if (unlikely(sq->ptpsq))
mlx5e_cqe_ts_id_eseg(sq->ptpsq, skb, eseg);
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/lib/crypto.h b/drivers/net/ethernet/mellanox/mlx5/core/lib/crypto.h
index c819c047bb9c..4821163a547f 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/lib/crypto.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/lib/crypto.h
@@ -8,6 +8,7 @@ enum {
MLX5_ACCEL_OBJ_TLS_KEY = MLX5_GENERAL_OBJECT_TYPE_ENCRYPTION_KEY_PURPOSE_TLS,
MLX5_ACCEL_OBJ_IPSEC_KEY = MLX5_GENERAL_OBJECT_TYPE_ENCRYPTION_KEY_PURPOSE_IPSEC,
MLX5_ACCEL_OBJ_MACSEC_KEY = MLX5_GENERAL_OBJECT_TYPE_ENCRYPTION_KEY_PURPOSE_MACSEC,
+ MLX5_ACCEL_OBJ_PSP_KEY = MLX5_GENERAL_OBJECT_TYPE_ENCRYPTION_KEY_PURPOSE_PSP,
MLX5_ACCEL_OBJ_TYPE_KEY_NUM,
};
diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index f5a840c07cf1..1c54d44805fa 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -1906,6 +1906,7 @@ enum netdev_reg_state {
* device struct
* @mpls_ptr: mpls_dev struct pointer
* @mctp_ptr: MCTP specific data
+ * @psp_dev: PSP crypto device registered for this netdev
*
* @dev_addr: Hw address (before bcast,
* because most packets are unicast)
@@ -2310,6 +2311,9 @@ struct net_device {
#if IS_ENABLED(CONFIG_MCTP)
struct mctp_dev __rcu *mctp_ptr;
#endif
+#if IS_ENABLED(CONFIG_INET_PSP)
+ struct psp_dev __rcu *psp_dev;
+#endif
/*
* Cache lines mostly used on receive path (including eth_type_trans())
diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index 62e7addccdf6..78ecfa7d00d0 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -4902,6 +4902,9 @@ enum skb_ext_id {
#if IS_ENABLED(CONFIG_MCTP_FLOWS)
SKB_EXT_MCTP,
#endif
+#if IS_ENABLED(CONFIG_INET_PSP)
+ SKB_EXT_PSP,
+#endif
SKB_EXT_NUM, /* must be last */
};
diff --git a/include/net/dropreason-core.h b/include/net/dropreason-core.h
index d8ff24a33459..58d91ccc56e0 100644
--- a/include/net/dropreason-core.h
+++ b/include/net/dropreason-core.h
@@ -127,6 +127,8 @@
FN(CANXL_RX_INVALID_FRAME) \
FN(PFMEMALLOC) \
FN(DUALPI2_STEP_DROP) \
+ FN(PSP_INPUT) \
+ FN(PSP_OUTPUT) \
FNe(MAX)
/**
@@ -610,6 +612,10 @@ enum skb_drop_reason {
* threshold of DualPI2 qdisc.
*/
SKB_DROP_REASON_DUALPI2_STEP_DROP,
+ /** @SKB_DROP_REASON_PSP_INPUT: PSP input checks failed */
+ SKB_DROP_REASON_PSP_INPUT,
+ /** @SKB_DROP_REASON_PSP_OUTPUT: PSP output checks failed */
+ SKB_DROP_REASON_PSP_OUTPUT,
/**
* @SKB_DROP_REASON_MAX: the maximum of core drop reasons, which
* shouldn't be used as a real 'reason' - only for tracing code gen
diff --git a/include/net/inet_timewait_sock.h b/include/net/inet_timewait_sock.h
index 67a313575780..3a31c74c9e15 100644
--- a/include/net/inet_timewait_sock.h
+++ b/include/net/inet_timewait_sock.h
@@ -81,6 +81,14 @@ struct inet_timewait_sock {
struct timer_list tw_timer;
struct inet_bind_bucket *tw_tb;
struct inet_bind2_bucket *tw_tb2;
+#if IS_ENABLED(CONFIG_INET_PSP)
+ struct psp_assoc __rcu *psp_assoc;
+#endif
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+ struct sk_buff* (*tw_validate_xmit_skb)(struct sock *sk,
+ struct net_device *dev,
+ struct sk_buff *skb);
+#endif
};
#define tw_tclass tw_tos
diff --git a/include/net/psp.h b/include/net/psp.h
new file mode 100644
index 000000000000..33bb4d1dc46e
--- /dev/null
+++ b/include/net/psp.h
@@ -0,0 +1,12 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __NET_PSP_ALL_H
+#define __NET_PSP_ALL_H
+
+#include <uapi/linux/psp.h>
+#include <net/psp/functions.h>
+#include <net/psp/types.h>
+
+/* Do not add any code here. Put it in the sub-headers instead. */
+
+#endif /* __NET_PSP_ALL_H */
diff --git a/include/net/psp/functions.h b/include/net/psp/functions.h
new file mode 100644
index 000000000000..91ba06733321
--- /dev/null
+++ b/include/net/psp/functions.h
@@ -0,0 +1,210 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __NET_PSP_HELPERS_H
+#define __NET_PSP_HELPERS_H
+
+#include <linux/skbuff.h>
+#include <linux/rcupdate.h>
+#include <linux/udp.h>
+#include <net/sock.h>
+#include <net/tcp.h>
+#include <net/psp/types.h>
+
+struct inet_timewait_sock;
+
+/* Driver-facing API */
+struct psp_dev *
+psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops,
+ struct psp_dev_caps *psd_caps, void *priv_ptr);
+void psp_dev_unregister(struct psp_dev *psd);
+bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, __be16 sport);
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv);
+
+/* Kernel-facing API */
+void psp_assoc_put(struct psp_assoc *pas);
+
+static inline void *psp_assoc_drv_data(struct psp_assoc *pas)
+{
+ return pas->drv_data;
+}
+
+#if IS_ENABLED(CONFIG_INET_PSP)
+unsigned int psp_key_size(u32 version);
+void psp_sk_assoc_free(struct sock *sk);
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk);
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw);
+void psp_reply_set_decrypted(struct sk_buff *skb);
+
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk));
+}
+
+static inline void
+psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+
+ pas = psp_sk_assoc(sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
+}
+
+static inline unsigned long
+__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
+ unsigned long diffs)
+{
+ struct psp_skb_ext *a, *b;
+
+ a = skb_ext_find(one, SKB_EXT_PSP);
+ b = skb_ext_find(two, SKB_EXT_PSP);
+
+ diffs |= (!!a) ^ (!!b);
+ if (!diffs && unlikely(a))
+ diffs |= memcmp(a, b, sizeof(*a));
+ return diffs;
+}
+
+static inline bool
+psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN);
+ u32 end_seq = TCP_SKB_CB(skb)->end_seq;
+ u32 seq = TCP_SKB_CB(skb)->seq;
+ bool pure_fin;
+
+ pure_fin = fin && end_seq - seq == 1;
+
+ return seq == end_seq || (pure_fin && seq == pas->upgrade_seq);
+}
+
+static inline bool
+psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas)
+{
+ return pse && pas->rx.spi == pse->spi &&
+ pas->generation == pse->generation &&
+ pas->version == pse->version &&
+ pas->dev_id == pse->dev_id;
+}
+
+static inline enum skb_drop_reason
+__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP);
+
+ if (!pas)
+ return pse ? SKB_DROP_REASON_PSP_INPUT : 0;
+
+ if (likely(psp_pse_matches_pas(pse, pas))) {
+ if (unlikely(!pas->peer_tx))
+ pas->peer_tx = 1;
+
+ return 0;
+ }
+
+ if (!pse) {
+ if (!pas->tx.spi ||
+ (!pas->peer_tx && psp_is_allowed_nondata(skb, pas)))
+ return 0;
+ }
+
+ return SKB_DROP_REASON_PSP_INPUT;
+}
+
+static inline enum skb_drop_reason
+psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
+{
+ return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk));
+}
+
+static inline enum skb_drop_reason
+psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
+{
+ return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc));
+}
+
+static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk)
+{
+ struct inet_timewait_sock *tw;
+ struct psp_assoc *pas;
+ int state;
+
+ state = 1 << READ_ONCE(sk->sk_state);
+ if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV)
+ return NULL;
+
+ tw = inet_twsk(sk);
+ pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) :
+ rcu_dereference(sk->psp_assoc);
+ return pas;
+}
+
+static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
+{
+ if (!skb->decrypted || !skb->sk)
+ return NULL;
+
+ return psp_sk_get_assoc_rcu(skb->sk);
+}
+
+static inline unsigned int psp_sk_overhead(const struct sock *sk)
+{
+ int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE;
+ bool has_psp = rcu_access_pointer(sk->psp_assoc);
+
+ return has_psp ? psp_encap : 0;
+}
+#else
+static inline void psp_sk_assoc_free(struct sock *sk) { }
+static inline void
+psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { }
+static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { }
+static inline void
+psp_reply_set_decrypted(struct sk_buff *skb) { }
+
+static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk)
+{
+ return NULL;
+}
+
+static inline void
+psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
+
+static inline unsigned long
+__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two,
+ unsigned long diffs)
+{
+ return diffs;
+}
+
+static inline enum skb_drop_reason
+psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb)
+{
+ return 0;
+}
+
+static inline enum skb_drop_reason
+psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb)
+{
+ return 0;
+}
+
+static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb)
+{
+ return NULL;
+}
+
+static inline unsigned int psp_sk_overhead(const struct sock *sk)
+{
+ return 0;
+}
+#endif
+
+static inline unsigned long
+psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two)
+{
+ return __psp_skb_coalesce_diff(one, two, 0);
+}
+
+#endif /* __NET_PSP_HELPERS_H */
diff --git a/include/net/psp/types.h b/include/net/psp/types.h
new file mode 100644
index 000000000000..d9688e66cf09
--- /dev/null
+++ b/include/net/psp/types.h
@@ -0,0 +1,184 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __NET_PSP_H
+#define __NET_PSP_H
+
+#include <linux/mutex.h>
+#include <linux/refcount.h>
+
+struct netlink_ext_ack;
+
+#define PSP_DEFAULT_UDP_PORT 1000
+
+struct psphdr {
+ u8 nexthdr;
+ u8 hdrlen;
+ u8 crypt_offset;
+ u8 verfl;
+ __be32 spi;
+ __be64 iv;
+ __be64 vc[]; /* optional */
+};
+
+#define PSP_ENCAP_HLEN (sizeof(struct udphdr) + sizeof(struct psphdr))
+
+#define PSP_SPI_KEY_ID GENMASK(30, 0)
+#define PSP_SPI_KEY_PHASE BIT(31)
+
+#define PSPHDR_CRYPT_OFFSET GENMASK(5, 0)
+
+#define PSPHDR_VERFL_SAMPLE BIT(7)
+#define PSPHDR_VERFL_DROP BIT(6)
+#define PSPHDR_VERFL_VERSION GENMASK(5, 2)
+#define PSPHDR_VERFL_VIRT BIT(1)
+#define PSPHDR_VERFL_ONE BIT(0)
+
+#define PSP_HDRLEN_NOOPT ((sizeof(struct psphdr) - 8) / 8)
+
+/**
+ * struct psp_dev_config - PSP device configuration
+ * @versions: PSP versions enabled on the device
+ */
+struct psp_dev_config {
+ u32 versions;
+};
+
+/**
+ * struct psp_dev - PSP device struct
+ * @main_netdev: original netdevice of this PSP device
+ * @ops: driver callbacks
+ * @caps: device capabilities
+ * @drv_priv: driver priv pointer
+ * @lock: instance lock, protects all fields
+ * @refcnt: reference count for the instance
+ * @id: instance id
+ * @generation: current generation of the device key
+ * @config: current device configuration
+ * @active_assocs: list of registered associations
+ * @prev_assocs: associations which use old (but still usable)
+ * device key
+ * @stale_assocs: associations which use a rotated out key
+ *
+ * @rcu: RCU head for freeing the structure
+ */
+struct psp_dev {
+ struct net_device *main_netdev;
+
+ struct psp_dev_ops *ops;
+ struct psp_dev_caps *caps;
+ void *drv_priv;
+
+ struct mutex lock;
+ refcount_t refcnt;
+
+ u32 id;
+
+ u8 generation;
+
+ struct psp_dev_config config;
+
+ struct list_head active_assocs;
+ struct list_head prev_assocs;
+ struct list_head stale_assocs;
+
+ struct rcu_head rcu;
+};
+
+#define PSP_GEN_VALID_MASK 0x7f
+
+/**
+ * struct psp_dev_caps - PSP device capabilities
+ */
+struct psp_dev_caps {
+ /**
+ * @versions: mask of supported PSP versions
+ * Set this field to 0 to indicate PSP is not supported at all.
+ */
+ u32 versions;
+
+ /**
+ * @assoc_drv_spc: size of driver-specific state in Tx assoc
+ * Determines the size of struct psp_assoc::drv_spc
+ */
+ u32 assoc_drv_spc;
+};
+
+#define PSP_MAX_KEY 32
+
+#define PSP_HDR_SIZE 16 /* We don't support optional fields, yet */
+#define PSP_TRL_SIZE 16 /* AES-GCM/GMAC trailer size */
+
+struct psp_skb_ext {
+ __be32 spi;
+ u16 dev_id;
+ u8 generation;
+ u8 version;
+};
+
+struct psp_key_parsed {
+ __be32 spi;
+ u8 key[PSP_MAX_KEY];
+};
+
+struct psp_assoc {
+ struct psp_dev *psd;
+
+ u16 dev_id;
+ u8 generation;
+ u8 version;
+ u8 peer_tx;
+
+ u32 upgrade_seq;
+
+ struct psp_key_parsed tx;
+ struct psp_key_parsed rx;
+
+ refcount_t refcnt;
+ struct rcu_head rcu;
+ struct work_struct work;
+ struct list_head assocs_list;
+
+ u8 drv_data[] __aligned(8);
+};
+
+/**
+ * struct psp_dev_ops - netdev driver facing PSP callbacks
+ */
+struct psp_dev_ops {
+ /**
+ * @set_config: set configuration of a PSP device
+ * Driver can inspect @psd->config for the previous configuration.
+ * Core will update @psd->config with @config on success.
+ */
+ int (*set_config)(struct psp_dev *psd, struct psp_dev_config *conf,
+ struct netlink_ext_ack *extack);
+
+ /**
+ * @key_rotate: rotate the device key
+ */
+ int (*key_rotate)(struct psp_dev *psd, struct netlink_ext_ack *extack);
+
+ /**
+ * @rx_spi_alloc: allocate an Rx SPI+key pair
+ * Allocate an Rx SPI and resulting derived key.
+ * This key should remain valid until key rotation.
+ */
+ int (*rx_spi_alloc)(struct psp_dev *psd, u32 version,
+ struct psp_key_parsed *assoc,
+ struct netlink_ext_ack *extack);
+
+ /**
+ * @tx_key_add: add a Tx key to the device
+ * Install an association in the device. Core will allocate space
+ * for the driver to use at drv_data.
+ */
+ int (*tx_key_add)(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack);
+ /**
+ * @tx_key_del: remove a Tx key from the device
+ * Remove an association from the device.
+ */
+ void (*tx_key_del)(struct psp_dev *psd, struct psp_assoc *pas);
+};
+
+#endif /* __NET_PSP_H */
diff --git a/include/net/sock.h b/include/net/sock.h
index 82bcdb7d7e67..e58c6120271f 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -249,6 +249,7 @@ struct sk_filter;
* @sk_dst_cache: destination cache
* @sk_dst_pending_confirm: need to confirm neighbour
* @sk_policy: flow policy
+ * @psp_assoc: PSP association, if socket is PSP-secured
* @sk_receive_queue: incoming packets
* @sk_wmem_alloc: transmit queue bytes committed
* @sk_tsq_flags: TCP Small Queues flags
@@ -451,6 +452,9 @@ struct sock {
#ifdef CONFIG_XFRM
struct xfrm_policy __rcu *sk_policy[2];
#endif
+#if IS_ENABLED(CONFIG_INET_PSP)
+ struct psp_assoc __rcu *psp_assoc;
+#endif
__cacheline_group_end(sock_read_rxtx);
__cacheline_group_begin(sock_write_rxtx);
@@ -2956,28 +2960,6 @@ sk_requests_wifi_status(struct sock *sk)
return sk && sk_fullsock(sk) && sock_flag(sk, SOCK_WIFI_STATUS);
}
-/* Checks if this SKB belongs to an HW offloaded socket
- * and whether any SW fallbacks are required based on dev.
- * Check decrypted mark in case skb_orphan() cleared socket.
- */
-static inline struct sk_buff *sk_validate_xmit_skb(struct sk_buff *skb,
- struct net_device *dev)
-{
-#ifdef CONFIG_SOCK_VALIDATE_XMIT
- struct sock *sk = skb->sk;
-
- if (sk && sk_fullsock(sk) && sk->sk_validate_xmit_skb) {
- skb = sk->sk_validate_xmit_skb(sk, dev, skb);
- } else if (unlikely(skb_is_decrypted(skb))) {
- pr_warn_ratelimited("unencrypted skb with no associated socket - dropping\n");
- kfree_skb(skb);
- skb = NULL;
- }
-#endif
-
- return skb;
-}
-
/* This helper checks if a socket is a LISTEN or NEW_SYN_RECV
* SYNACK messages can be attached to either ones (depending on SYNCOOKIE)
*/
diff --git a/include/uapi/linux/psp.h b/include/uapi/linux/psp.h
new file mode 100644
index 000000000000..607c42c39ba5
--- /dev/null
+++ b/include/uapi/linux/psp.h
@@ -0,0 +1,66 @@
+/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/psp.yaml */
+/* YNL-GEN uapi header */
+
+#ifndef _UAPI_LINUX_PSP_H
+#define _UAPI_LINUX_PSP_H
+
+#define PSP_FAMILY_NAME "psp"
+#define PSP_FAMILY_VERSION 1
+
+enum psp_version {
+ PSP_VERSION_HDR0_AES_GCM_128,
+ PSP_VERSION_HDR0_AES_GCM_256,
+ PSP_VERSION_HDR0_AES_GMAC_128,
+ PSP_VERSION_HDR0_AES_GMAC_256,
+};
+
+enum {
+ PSP_A_DEV_ID = 1,
+ PSP_A_DEV_IFINDEX,
+ PSP_A_DEV_PSP_VERSIONS_CAP,
+ PSP_A_DEV_PSP_VERSIONS_ENA,
+
+ __PSP_A_DEV_MAX,
+ PSP_A_DEV_MAX = (__PSP_A_DEV_MAX - 1)
+};
+
+enum {
+ PSP_A_ASSOC_DEV_ID = 1,
+ PSP_A_ASSOC_VERSION,
+ PSP_A_ASSOC_RX_KEY,
+ PSP_A_ASSOC_TX_KEY,
+ PSP_A_ASSOC_SOCK_FD,
+
+ __PSP_A_ASSOC_MAX,
+ PSP_A_ASSOC_MAX = (__PSP_A_ASSOC_MAX - 1)
+};
+
+enum {
+ PSP_A_KEYS_KEY = 1,
+ PSP_A_KEYS_SPI,
+
+ __PSP_A_KEYS_MAX,
+ PSP_A_KEYS_MAX = (__PSP_A_KEYS_MAX - 1)
+};
+
+enum {
+ PSP_CMD_DEV_GET = 1,
+ PSP_CMD_DEV_ADD_NTF,
+ PSP_CMD_DEV_DEL_NTF,
+ PSP_CMD_DEV_SET,
+ PSP_CMD_DEV_CHANGE_NTF,
+ PSP_CMD_KEY_ROTATE,
+ PSP_CMD_KEY_ROTATE_NTF,
+ PSP_CMD_RX_ASSOC,
+ PSP_CMD_TX_ASSOC,
+
+ __PSP_CMD_MAX,
+ PSP_CMD_MAX = (__PSP_CMD_MAX - 1)
+};
+
+#define PSP_MCGRP_MGMT "mgmt"
+#define PSP_MCGRP_USE "use"
+
+#endif /* _UAPI_LINUX_PSP_H */
diff --git a/net/Kconfig b/net/Kconfig
index d5865cf19799..4b563aea4c23 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -82,6 +82,7 @@ config NET_CRC32C
menu "Networking options"
source "net/packet/Kconfig"
+source "net/psp/Kconfig"
source "net/unix/Kconfig"
source "net/tls/Kconfig"
source "net/xfrm/Kconfig"
diff --git a/net/Makefile b/net/Makefile
index aac960c41db6..90e3d72bf58b 100644
--- a/net/Makefile
+++ b/net/Makefile
@@ -18,6 +18,7 @@ obj-$(CONFIG_INET) += ipv4/
obj-$(CONFIG_TLS) += tls/
obj-$(CONFIG_XFRM) += xfrm/
obj-$(CONFIG_UNIX) += unix/
+obj-$(CONFIG_INET_PSP) += psp/
obj-y += ipv6/
obj-$(CONFIG_PACKET) += packet/
obj-$(CONFIG_NET_KEY) += key/
diff --git a/net/core/dev.c b/net/core/dev.c
index 2522d9d8f0e4..5e22d062bac5 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -3907,6 +3907,38 @@ sw_checksum:
}
EXPORT_SYMBOL(skb_csum_hwoffload_help);
+/* Checks if this SKB belongs to an HW offloaded socket
+ * and whether any SW fallbacks are required based on dev.
+ * Check decrypted mark in case skb_orphan() cleared socket.
+ */
+static struct sk_buff *sk_validate_xmit_skb(struct sk_buff *skb,
+ struct net_device *dev)
+{
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+ struct sk_buff *(*sk_validate)(struct sock *sk, struct net_device *dev,
+ struct sk_buff *skb);
+ struct sock *sk = skb->sk;
+
+ sk_validate = NULL;
+ if (sk) {
+ if (sk_fullsock(sk))
+ sk_validate = sk->sk_validate_xmit_skb;
+ else if (sk_is_inet(sk) && sk->sk_state == TCP_TIME_WAIT)
+ sk_validate = inet_twsk(sk)->tw_validate_xmit_skb;
+ }
+
+ if (sk_validate) {
+ skb = sk_validate(sk, dev, skb);
+ } else if (unlikely(skb_is_decrypted(skb))) {
+ pr_warn_ratelimited("unencrypted skb with no associated socket - dropping\n");
+ kfree_skb(skb);
+ skb = NULL;
+ }
+#endif
+
+ return skb;
+}
+
static struct sk_buff *validate_xmit_unreadable_skb(struct sk_buff *skb,
struct net_device *dev)
{
diff --git a/net/core/gro.c b/net/core/gro.c
index b350e5b69549..5ba4504cfd28 100644
--- a/net/core/gro.c
+++ b/net/core/gro.c
@@ -1,4 +1,5 @@
// SPDX-License-Identifier: GPL-2.0-or-later
+#include <net/psp.h>
#include <net/gro.h>
#include <net/dst_metadata.h>
#include <net/busy_poll.h>
@@ -376,6 +377,7 @@ static void gro_list_prepare(const struct list_head *head,
diffs |= skb_get_nfct(p) ^ skb_get_nfct(skb);
diffs |= gro_list_prepare_tc_ext(skb, p, diffs);
+ diffs |= __psp_skb_coalesce_diff(skb, p, diffs);
}
NAPI_GRO_CB(p)->same_flow = !diffs;
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index 23b776cd9879..d331e607edfb 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -79,6 +79,7 @@
#include <net/mptcp.h>
#include <net/mctp.h>
#include <net/page_pool/helpers.h>
+#include <net/psp/types.h>
#include <net/dropreason.h>
#include <linux/uaccess.h>
@@ -5062,6 +5063,9 @@ static const u8 skb_ext_type_len[] = {
#if IS_ENABLED(CONFIG_MCTP_FLOWS)
[SKB_EXT_MCTP] = SKB_EXT_CHUNKSIZEOF(struct mctp_flow),
#endif
+#if IS_ENABLED(CONFIG_INET_PSP)
+ [SKB_EXT_PSP] = SKB_EXT_CHUNKSIZEOF(struct psp_skb_ext),
+#endif
};
static __always_inline unsigned int skb_ext_total_length(void)
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 76e38092cd8a..e298dacb4a06 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -102,6 +102,7 @@
#include <net/gro.h>
#include <net/gso.h>
#include <net/tcp.h>
+#include <net/psp.h>
#include <net/udp.h>
#include <net/udplite.h>
#include <net/ping.h>
@@ -158,6 +159,7 @@ void inet_sock_destruct(struct sock *sk)
kfree(rcu_dereference_protected(inet->inet_opt, 1));
dst_release(rcu_dereference_protected(sk->sk_dst_cache, 1));
dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1));
+ psp_sk_assoc_free(sk);
}
EXPORT_SYMBOL(inet_sock_destruct);
diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c
index 5b5426b8ee92..2ca2912f61f4 100644
--- a/net/ipv4/inet_timewait_sock.c
+++ b/net/ipv4/inet_timewait_sock.c
@@ -16,6 +16,7 @@
#include <net/inet_timewait_sock.h>
#include <net/ip.h>
#include <net/tcp.h>
+#include <net/psp.h>
/**
* inet_twsk_bind_unhash - unhash a timewait socket from bind hash
@@ -211,6 +212,9 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
atomic64_set(&tw->tw_cookie, atomic64_read(&sk->sk_cookie));
twsk_net_set(tw, sock_net(sk));
timer_setup(&tw->tw_timer, tw_timer_handler, 0);
+#ifdef CONFIG_SOCK_VALIDATE_XMIT
+ tw->tw_validate_xmit_skb = NULL;
+#endif
/*
* Because we use RCU lookups, we should not set tw_refcnt
* to a non null value before everything is setup for this
@@ -219,6 +223,7 @@ struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk,
refcount_set(&tw->tw_refcnt, 0);
__module_get(tw->tw_prot->owner);
+ psp_twsk_init(tw, sk);
}
return tw;
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 2b96651d719b..5ca97ede979c 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -84,6 +84,7 @@
#include <linux/netfilter_bridge.h>
#include <linux/netlink.h>
#include <linux/tcp.h>
+#include <net/psp.h>
static int
ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb,
@@ -1665,8 +1666,10 @@ void ip_send_unicast_reply(struct sock *sk, const struct sock *orig_sk,
arg->csumoffset) = csum_fold(csum_add(nskb->csum,
arg->csum));
nskb->ip_summed = CHECKSUM_NONE;
- if (orig_sk)
+ if (orig_sk) {
skb_set_owner_edemux(nskb, (struct sock *)orig_sk);
+ psp_reply_set_decrypted(nskb);
+ }
if (transmit_time)
nskb->tstamp_type = SKB_CLOCK_MONOTONIC;
if (txhash)
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index 5b5c655ded1d..d6d0d970e014 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -277,6 +277,7 @@
#include <net/proto_memory.h>
#include <net/xfrm.h>
#include <net/ip.h>
+#include <net/psp.h>
#include <net/sock.h>
#include <net/rstreason.h>
@@ -705,6 +706,7 @@ void tcp_skb_entail(struct sock *sk, struct sk_buff *skb)
tcb->seq = tcb->end_seq = tp->write_seq;
tcb->tcp_flags = TCPHDR_ACK;
__skb_header_release(skb);
+ psp_enqueue_set_decrypted(sk, skb);
tcp_add_write_queue_tail(sk, skb);
sk_wmem_queued_add(sk, skb->truesize);
sk_mem_charge(sk, skb->truesize);
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index 6a63be1f6461..b1fcf3e4e1ce 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -75,6 +75,7 @@
#include <net/secure_seq.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
#include <linux/inet.h>
#include <linux/ipv6.h>
@@ -293,9 +294,9 @@ int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
inet->inet_dport = usin->sin_port;
sk_daddr_set(sk, daddr);
- inet_csk(sk)->icsk_ext_hdr_len = 0;
+ inet_csk(sk)->icsk_ext_hdr_len = psp_sk_overhead(sk);
if (inet_opt)
- inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen;
+ inet_csk(sk)->icsk_ext_hdr_len += inet_opt->opt.optlen;
tp->rx_opt.mss_clamp = TCP_MSS_DEFAULT;
@@ -1907,6 +1908,10 @@ int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb)
enum skb_drop_reason reason;
struct sock *rsk;
+ reason = psp_sk_rx_policy_check(sk, skb);
+ if (reason)
+ goto err_discard;
+
if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */
struct dst_entry *dst;
@@ -1968,6 +1973,7 @@ csum_err:
reason = SKB_DROP_REASON_TCP_CSUM;
trace_tcp_bad_csum(skb);
TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS);
+err_discard:
TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS);
goto discard;
}
@@ -2069,7 +2075,9 @@ bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb,
(TCPHDR_ECE | TCPHDR_CWR | TCPHDR_AE)) ||
!tcp_skb_can_collapse_rx(tail, skb) ||
thtail->doff != th->doff ||
- memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)))
+ memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th)) ||
+ /* prior to PSP Rx policy check, retain exact PSP metadata */
+ psp_skb_coalesce_diff(tail, skb))
goto no_coalesce;
__skb_pull(skb, hdrlen);
@@ -2437,6 +2445,10 @@ do_time_wait:
__this_cpu_write(tcp_tw_isn, isn);
goto process;
}
+
+ drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb);
+ if (drop_reason)
+ break;
}
/* to ACK */
fallthrough;
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index 327095ef95ef..2ec8c6f1cdcc 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -24,6 +24,7 @@
#include <net/xfrm.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
static bool tcp_in_window(u32 seq, u32 end_seq, u32 s_win, u32 e_win)
{
@@ -104,9 +105,16 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw);
u32 rcv_nxt = READ_ONCE(tcptw->tw_rcv_nxt);
struct tcp_options_received tmp_opt;
+ enum skb_drop_reason psp_drop;
bool paws_reject = false;
int ts_recent_stamp;
+ /* Instead of dropping immediately, wait to see what value is
+ * returned. We will accept a non psp-encapsulated syn in the
+ * case where TCP_TW_SYN is returned.
+ */
+ psp_drop = psp_twsk_rx_policy_check(tw, skb);
+
tmp_opt.saw_tstamp = 0;
ts_recent_stamp = READ_ONCE(tcptw->tw_ts_recent_stamp);
if (th->doff > (sizeof(*th) >> 2) && ts_recent_stamp) {
@@ -124,6 +132,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
if (READ_ONCE(tw->tw_substate) == TCP_FIN_WAIT2) {
/* Just repeat all the checks of tcp_rcv_state_process() */
+ if (psp_drop)
+ goto out_put;
+
/* Out of window, send ACK */
if (paws_reject ||
!tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq,
@@ -194,6 +205,9 @@ tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb,
(TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq || th->rst))) {
/* In window segment, it may be only reset or bare ack. */
+ if (psp_drop)
+ goto out_put;
+
if (th->rst) {
/* This is TIME_WAIT assassination, in two flavors.
* Oh well... nobody has a sufficient solution to this
@@ -247,6 +261,9 @@ kill:
return TCP_TW_SYN;
}
+ if (psp_drop)
+ goto out_put;
+
if (paws_reject) {
*drop_reason = SKB_DROP_REASON_TCP_RFC7323_TW_PAWS;
__NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWS_TW_REJECTED);
@@ -265,6 +282,8 @@ kill:
return tcp_timewait_check_oow_rate_limit(
tw, skb, LINUX_MIB_TCPACKSKIPPEDTIMEWAIT);
}
+
+out_put:
inet_twsk_put(tw);
return TCP_TW_SUCCESS;
}
@@ -392,6 +411,7 @@ void tcp_twsk_destructor(struct sock *sk)
}
#endif
tcp_ao_destroy_sock(sk, true);
+ psp_twsk_assoc_free(inet_twsk(sk));
}
void tcp_twsk_purge(struct list_head *net_exit_list)
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index 388c45859469..223d7feeb19d 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -41,6 +41,7 @@
#include <net/tcp_ecn.h>
#include <net/mptcp.h>
#include <net/proto_memory.h>
+#include <net/psp.h>
#include <linux/compiler.h>
#include <linux/gfp.h>
@@ -358,13 +359,15 @@ static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb,
/* Constructs common control bits of non-data skb. If SYN/FIN is present,
* auto increment end seqno.
*/
-static void tcp_init_nondata_skb(struct sk_buff *skb, u32 seq, u16 flags)
+static void tcp_init_nondata_skb(struct sk_buff *skb, struct sock *sk,
+ u32 seq, u16 flags)
{
skb->ip_summed = CHECKSUM_PARTIAL;
TCP_SKB_CB(skb)->tcp_flags = flags;
tcp_skb_pcount_set(skb, 1);
+ psp_enqueue_set_decrypted(sk, skb);
TCP_SKB_CB(skb)->seq = seq;
if (flags & (TCPHDR_SYN | TCPHDR_FIN))
@@ -1656,6 +1659,7 @@ static void tcp_queue_skb(struct sock *sk, struct sk_buff *skb)
/* Advance write_seq and place onto the write_queue. */
WRITE_ONCE(tp->write_seq, TCP_SKB_CB(skb)->end_seq);
__skb_header_release(skb);
+ psp_enqueue_set_decrypted(sk, skb);
tcp_add_write_queue_tail(sk, skb);
sk_wmem_queued_add(sk, skb->truesize);
sk_mem_charge(sk, skb->truesize);
@@ -3778,7 +3782,7 @@ void tcp_send_fin(struct sock *sk)
skb_reserve(skb, MAX_TCP_HEADER);
sk_forced_mem_schedule(sk, skb->truesize);
/* FIN eats a sequence byte, write_seq advanced by tcp_queue_skb(). */
- tcp_init_nondata_skb(skb, tp->write_seq,
+ tcp_init_nondata_skb(skb, sk, tp->write_seq,
TCPHDR_ACK | TCPHDR_FIN);
tcp_queue_skb(sk, skb);
}
@@ -3806,7 +3810,7 @@ void tcp_send_active_reset(struct sock *sk, gfp_t priority,
/* Reserve space for headers and prepare control bits. */
skb_reserve(skb, MAX_TCP_HEADER);
- tcp_init_nondata_skb(skb, tcp_acceptable_seq(sk),
+ tcp_init_nondata_skb(skb, sk, tcp_acceptable_seq(sk),
TCPHDR_ACK | TCPHDR_RST);
tcp_mstamp_refresh(tcp_sk(sk));
/* Send it off. */
@@ -4303,7 +4307,7 @@ int tcp_connect(struct sock *sk)
/* SYN eats a sequence byte, write_seq updated by
* tcp_connect_queue_skb().
*/
- tcp_init_nondata_skb(buff, tp->write_seq, TCPHDR_SYN);
+ tcp_init_nondata_skb(buff, sk, tp->write_seq, TCPHDR_SYN);
tcp_mstamp_refresh(tp);
tp->retrans_stamp = tcp_time_stamp_ts(tp);
tcp_connect_queue_skb(sk, buff);
@@ -4428,7 +4432,8 @@ void __tcp_send_ack(struct sock *sk, u32 rcv_nxt, u16 flags)
/* Reserve space for headers and prepare control bits. */
skb_reserve(buff, MAX_TCP_HEADER);
- tcp_init_nondata_skb(buff, tcp_acceptable_seq(sk), TCPHDR_ACK | flags);
+ tcp_init_nondata_skb(buff, sk,
+ tcp_acceptable_seq(sk), TCPHDR_ACK | flags);
/* We do not want pure acks influencing TCP Small Queues or fq/pacing
* too much.
@@ -4474,7 +4479,7 @@ static int tcp_xmit_probe_skb(struct sock *sk, int urgent, int mib)
* end to send an ack. Don't queue or clone SKB, just
* send it.
*/
- tcp_init_nondata_skb(skb, tp->snd_una - !urgent, TCPHDR_ACK);
+ tcp_init_nondata_skb(skb, sk, tp->snd_una - !urgent, TCPHDR_ACK);
NET_INC_STATS(sock_net(sk), mib);
return tcp_transmit_skb(sk, skb, 0, (__force gfp_t)0);
}
diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c
index e66ec623972e..a61e742794f9 100644
--- a/net/ipv6/ipv6_sockglue.c
+++ b/net/ipv6/ipv6_sockglue.c
@@ -49,6 +49,7 @@
#include <net/xfrm.h>
#include <net/compat.h>
#include <net/seg6.h>
+#include <net/psp.h>
#include <linux/uaccess.h>
@@ -107,7 +108,10 @@ struct ipv6_txoptions *ipv6_update_options(struct sock *sk,
!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
inet_sk(sk)->inet_daddr != LOOPBACK4_IPV6) {
struct inet_connection_sock *icsk = inet_csk(sk);
- icsk->icsk_ext_hdr_len = opt->opt_flen + opt->opt_nflen;
+
+ icsk->icsk_ext_hdr_len =
+ psp_sk_overhead(sk) +
+ opt->opt_flen + opt->opt_nflen;
icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
}
}
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index c7271f6359db..d1e5b2a186fb 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -62,6 +62,7 @@
#include <net/hotdata.h>
#include <net/busy_poll.h>
#include <net/rstreason.h>
+#include <net/psp.h>
#include <linux/proc_fs.h>
#include <linux/seq_file.h>
@@ -301,10 +302,10 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
sk->sk_gso_type = SKB_GSO_TCPV6;
ip6_dst_store(sk, dst, false, false);
- icsk->icsk_ext_hdr_len = 0;
+ icsk->icsk_ext_hdr_len = psp_sk_overhead(sk);
if (opt)
- icsk->icsk_ext_hdr_len = opt->opt_flen +
- opt->opt_nflen;
+ icsk->icsk_ext_hdr_len += opt->opt_flen +
+ opt->opt_nflen;
tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr);
@@ -973,6 +974,7 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32
if (sk) {
/* unconstify the socket only to attach it to buff with care. */
skb_set_owner_edemux(buff, (struct sock *)sk);
+ psp_reply_set_decrypted(buff);
if (sk->sk_state == TCP_TIME_WAIT)
mark = inet_twsk(sk)->tw_mark;
@@ -1605,6 +1607,10 @@ int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
if (skb->protocol == htons(ETH_P_IP))
return tcp_v4_do_rcv(sk, skb);
+ reason = psp_sk_rx_policy_check(sk, skb);
+ if (reason)
+ goto err_discard;
+
/*
* socket locking is here for SMP purposes as backlog rcv
* is currently called with bh processing disabled.
@@ -1684,6 +1690,7 @@ csum_err:
reason = SKB_DROP_REASON_TCP_CSUM;
trace_tcp_bad_csum(skb);
TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS);
+err_discard:
TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS);
goto discard;
@@ -1988,6 +1995,10 @@ do_time_wait:
__this_cpu_write(tcp_tw_isn, isn);
goto process;
}
+
+ drop_reason = psp_twsk_rx_policy_check(inet_twsk(sk), skb);
+ if (drop_reason)
+ break;
}
/* to ACK */
fallthrough;
diff --git a/net/psp/Kconfig b/net/psp/Kconfig
new file mode 100644
index 000000000000..a7d24691a7e1
--- /dev/null
+++ b/net/psp/Kconfig
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# PSP configuration
+#
+config INET_PSP
+ bool "PSP Security Protocol support"
+ depends on INET
+ select SKB_DECRYPTED
+ select SOCK_VALIDATE_XMIT
+ help
+ Enable kernel support for the PSP protocol.
+ For more information see:
+ https://raw.githubusercontent.com/google/psp/main/doc/PSP_Arch_Spec.pdf
+
+ If unsure, say N.
diff --git a/net/psp/Makefile b/net/psp/Makefile
new file mode 100644
index 000000000000..eb5ff3c5bfb2
--- /dev/null
+++ b/net/psp/Makefile
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+obj-$(CONFIG_INET_PSP) += psp.o
+
+psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o
diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c
new file mode 100644
index 000000000000..9fdd6f831803
--- /dev/null
+++ b/net/psp/psp-nl-gen.c
@@ -0,0 +1,119 @@
+// SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/psp.yaml */
+/* YNL-GEN kernel source */
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include "psp-nl-gen.h"
+
+#include <uapi/linux/psp.h>
+
+/* Common nested types */
+const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = {
+ [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, },
+ [PSP_A_KEYS_SPI] = { .type = NLA_U32, },
+};
+
+/* PSP_CMD_DEV_GET - do */
+static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+};
+
+/* PSP_CMD_DEV_SET - do */
+static const struct nla_policy psp_dev_set_nl_policy[PSP_A_DEV_PSP_VERSIONS_ENA + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_DEV_PSP_VERSIONS_ENA] = NLA_POLICY_MASK(NLA_U32, 0xf),
+};
+
+/* PSP_CMD_KEY_ROTATE - do */
+static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = {
+ [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+};
+
+/* PSP_CMD_RX_ASSOC - do */
+static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
+/* PSP_CMD_TX_ASSOC - do */
+static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = {
+ [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1),
+ [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3),
+ [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy),
+ [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, },
+};
+
+/* Ops table for psp */
+static const struct genl_split_ops psp_nl_ops[] = {
+ {
+ .cmd = PSP_CMD_DEV_GET,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_dev_get_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_dev_get_nl_policy,
+ .maxattr = PSP_A_DEV_ID,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_DEV_GET,
+ .dumpit = psp_nl_dev_get_dumpit,
+ .flags = GENL_CMD_CAP_DUMP,
+ },
+ {
+ .cmd = PSP_CMD_DEV_SET,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_dev_set_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_dev_set_nl_policy,
+ .maxattr = PSP_A_DEV_PSP_VERSIONS_ENA,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_KEY_ROTATE,
+ .pre_doit = psp_device_get_locked,
+ .doit = psp_nl_key_rotate_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_key_rotate_nl_policy,
+ .maxattr = PSP_A_DEV_ID,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_RX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_rx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_rx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
+ {
+ .cmd = PSP_CMD_TX_ASSOC,
+ .pre_doit = psp_assoc_device_get_locked,
+ .doit = psp_nl_tx_assoc_doit,
+ .post_doit = psp_device_unlock,
+ .policy = psp_tx_assoc_nl_policy,
+ .maxattr = PSP_A_ASSOC_SOCK_FD,
+ .flags = GENL_CMD_CAP_DO,
+ },
+};
+
+static const struct genl_multicast_group psp_nl_mcgrps[] = {
+ [PSP_NLGRP_MGMT] = { "mgmt", },
+ [PSP_NLGRP_USE] = { "use", },
+};
+
+struct genl_family psp_nl_family __ro_after_init = {
+ .name = PSP_FAMILY_NAME,
+ .version = PSP_FAMILY_VERSION,
+ .netnsok = true,
+ .parallel_ops = true,
+ .module = THIS_MODULE,
+ .split_ops = psp_nl_ops,
+ .n_split_ops = ARRAY_SIZE(psp_nl_ops),
+ .mcgrps = psp_nl_mcgrps,
+ .n_mcgrps = ARRAY_SIZE(psp_nl_mcgrps),
+};
diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h
new file mode 100644
index 000000000000..25268ed11fb5
--- /dev/null
+++ b/net/psp/psp-nl-gen.h
@@ -0,0 +1,39 @@
+/* SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause) */
+/* Do not edit directly, auto-generated from: */
+/* Documentation/netlink/specs/psp.yaml */
+/* YNL-GEN kernel header */
+
+#ifndef _LINUX_PSP_GEN_H
+#define _LINUX_PSP_GEN_H
+
+#include <net/netlink.h>
+#include <net/genetlink.h>
+
+#include <uapi/linux/psp.h>
+
+/* Common nested types */
+extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1];
+
+int psp_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info);
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info);
+void
+psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
+ struct genl_info *info);
+
+int psp_nl_dev_get_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb);
+int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info);
+
+enum {
+ PSP_NLGRP_MGMT,
+ PSP_NLGRP_USE,
+};
+
+extern struct genl_family psp_nl_family;
+
+#endif /* _LINUX_PSP_GEN_H */
diff --git a/net/psp/psp.h b/net/psp/psp.h
new file mode 100644
index 000000000000..0f34e1a23fdd
--- /dev/null
+++ b/net/psp/psp.h
@@ -0,0 +1,54 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+
+#ifndef __PSP_PSP_H
+#define __PSP_PSP_H
+
+#include <linux/list.h>
+#include <linux/lockdep.h>
+#include <linux/mutex.h>
+#include <net/netns/generic.h>
+#include <net/psp.h>
+#include <net/sock.h>
+
+extern struct xarray psp_devs;
+extern struct mutex psp_devs_lock;
+
+void psp_dev_destroy(struct psp_dev *psd);
+int psp_dev_check_access(struct psp_dev *psd, struct net *net);
+
+void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd);
+
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd);
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk);
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas);
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack);
+void psp_assocs_key_rotated(struct psp_dev *psd);
+
+static inline void psp_dev_get(struct psp_dev *psd)
+{
+ refcount_inc(&psd->refcnt);
+}
+
+static inline bool psp_dev_tryget(struct psp_dev *psd)
+{
+ return refcount_inc_not_zero(&psd->refcnt);
+}
+
+static inline void psp_dev_put(struct psp_dev *psd)
+{
+ if (refcount_dec_and_test(&psd->refcnt))
+ psp_dev_destroy(psd);
+}
+
+static inline bool psp_dev_is_registered(struct psp_dev *psd)
+{
+ lockdep_assert_held(&psd->lock);
+ return !!psd->ops;
+}
+
+#endif /* __PSP_PSP_H */
diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c
new file mode 100644
index 000000000000..b4b756f87382
--- /dev/null
+++ b/net/psp/psp_main.c
@@ -0,0 +1,321 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/bitfield.h>
+#include <linux/list.h>
+#include <linux/netdevice.h>
+#include <linux/xarray.h>
+#include <net/net_namespace.h>
+#include <net/psp.h>
+#include <net/udp.h>
+
+#include "psp.h"
+#include "psp-nl-gen.h"
+
+DEFINE_XARRAY_ALLOC1(psp_devs);
+struct mutex psp_devs_lock;
+
+/**
+ * DOC: PSP locking
+ *
+ * psp_devs_lock protects the psp_devs xarray.
+ * Ordering is take the psp_devs_lock and then the instance lock.
+ * Each instance is protected by RCU, and has a refcount.
+ * When driver unregisters the instance gets flushed, but struct sticks around.
+ */
+
+/**
+ * psp_dev_check_access() - check if user in a given net ns can access PSP dev
+ * @psd: PSP device structure user is trying to access
+ * @net: net namespace user is in
+ *
+ * Return: 0 if PSP device should be visible in @net, errno otherwise.
+ */
+int psp_dev_check_access(struct psp_dev *psd, struct net *net)
+{
+ if (dev_net(psd->main_netdev) == net)
+ return 0;
+ return -ENOENT;
+}
+
+/**
+ * psp_dev_create() - create and register PSP device
+ * @netdev: main netdevice
+ * @psd_ops: driver callbacks
+ * @psd_caps: device capabilities
+ * @priv_ptr: back-pointer to driver private data
+ *
+ * Return: pointer to allocated PSP device, or ERR_PTR.
+ */
+struct psp_dev *
+psp_dev_create(struct net_device *netdev,
+ struct psp_dev_ops *psd_ops, struct psp_dev_caps *psd_caps,
+ void *priv_ptr)
+{
+ struct psp_dev *psd;
+ static u32 last_id;
+ int err;
+
+ if (WARN_ON(!psd_caps->versions ||
+ !psd_ops->set_config ||
+ !psd_ops->key_rotate ||
+ !psd_ops->rx_spi_alloc ||
+ !psd_ops->tx_key_add ||
+ !psd_ops->tx_key_del))
+ return ERR_PTR(-EINVAL);
+
+ psd = kzalloc(sizeof(*psd), GFP_KERNEL);
+ if (!psd)
+ return ERR_PTR(-ENOMEM);
+
+ psd->main_netdev = netdev;
+ psd->ops = psd_ops;
+ psd->caps = psd_caps;
+ psd->drv_priv = priv_ptr;
+
+ mutex_init(&psd->lock);
+ INIT_LIST_HEAD(&psd->active_assocs);
+ INIT_LIST_HEAD(&psd->prev_assocs);
+ INIT_LIST_HEAD(&psd->stale_assocs);
+ refcount_set(&psd->refcnt, 1);
+
+ mutex_lock(&psp_devs_lock);
+ err = xa_alloc_cyclic(&psp_devs, &psd->id, psd, xa_limit_16b,
+ &last_id, GFP_KERNEL);
+ if (err) {
+ mutex_unlock(&psp_devs_lock);
+ kfree(psd);
+ return ERR_PTR(err);
+ }
+ mutex_lock(&psd->lock);
+ mutex_unlock(&psp_devs_lock);
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_ADD_NTF);
+
+ rcu_assign_pointer(netdev->psp_dev, psd);
+
+ mutex_unlock(&psd->lock);
+
+ return psd;
+}
+EXPORT_SYMBOL(psp_dev_create);
+
+void psp_dev_destroy(struct psp_dev *psd)
+{
+ mutex_lock(&psp_devs_lock);
+ xa_erase(&psp_devs, psd->id);
+ mutex_unlock(&psp_devs_lock);
+
+ mutex_destroy(&psd->lock);
+ kfree_rcu(psd, rcu);
+}
+
+/**
+ * psp_dev_unregister() - unregister PSP device
+ * @psd: PSP device structure
+ */
+void psp_dev_unregister(struct psp_dev *psd)
+{
+ struct psp_assoc *pas, *next;
+
+ mutex_lock(&psp_devs_lock);
+ mutex_lock(&psd->lock);
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_DEL_NTF);
+
+ /* Wait until psp_dev_destroy() to call xa_erase() to prevent a
+ * different psd from being added to the xarray with this id, while
+ * there are still references to this psd being held.
+ */
+ xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL);
+ mutex_unlock(&psp_devs_lock);
+
+ list_splice_init(&psd->active_assocs, &psd->prev_assocs);
+ list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
+ list_for_each_entry_safe(pas, next, &psd->stale_assocs, assocs_list)
+ psp_dev_tx_key_del(psd, pas);
+
+ rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
+
+ psd->ops = NULL;
+ psd->drv_priv = NULL;
+
+ mutex_unlock(&psd->lock);
+
+ psp_dev_put(psd);
+}
+EXPORT_SYMBOL(psp_dev_unregister);
+
+unsigned int psp_key_size(u32 version)
+{
+ switch (version) {
+ case PSP_VERSION_HDR0_AES_GCM_128:
+ case PSP_VERSION_HDR0_AES_GMAC_128:
+ return 16;
+ case PSP_VERSION_HDR0_AES_GCM_256:
+ case PSP_VERSION_HDR0_AES_GMAC_256:
+ return 32;
+ default:
+ return 0;
+ }
+}
+EXPORT_SYMBOL(psp_key_size);
+
+static void psp_write_headers(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, unsigned int udp_len, __be16 sport)
+{
+ struct udphdr *uh = udp_hdr(skb);
+ struct psphdr *psph = (struct psphdr *)(uh + 1);
+
+ uh->dest = htons(PSP_DEFAULT_UDP_PORT);
+ uh->source = udp_flow_src_port(net, skb, 0, 0, false);
+ uh->check = 0;
+ uh->len = htons(udp_len);
+
+ psph->nexthdr = IPPROTO_TCP;
+ psph->hdrlen = PSP_HDRLEN_NOOPT;
+ psph->crypt_offset = 0;
+ psph->verfl = FIELD_PREP(PSPHDR_VERFL_VERSION, ver) |
+ FIELD_PREP(PSPHDR_VERFL_ONE, 1);
+ psph->spi = spi;
+ memset(&psph->iv, 0, sizeof(psph->iv));
+}
+
+/* Encapsulate a TCP packet with PSP by adding the UDP+PSP headers and filling
+ * them in.
+ */
+bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi,
+ u8 ver, __be16 sport)
+{
+ u32 network_len = skb_network_header_len(skb);
+ u32 ethr_len = skb_mac_header_len(skb);
+ u32 bufflen = ethr_len + network_len;
+
+ if (skb_cow_head(skb, PSP_ENCAP_HLEN))
+ return false;
+
+ skb_push(skb, PSP_ENCAP_HLEN);
+ skb->mac_header -= PSP_ENCAP_HLEN;
+ skb->network_header -= PSP_ENCAP_HLEN;
+ skb->transport_header -= PSP_ENCAP_HLEN;
+ memmove(skb->data, skb->data + PSP_ENCAP_HLEN, bufflen);
+
+ if (skb->protocol == htons(ETH_P_IP)) {
+ ip_hdr(skb)->protocol = IPPROTO_UDP;
+ be16_add_cpu(&ip_hdr(skb)->tot_len, PSP_ENCAP_HLEN);
+ ip_hdr(skb)->check = 0;
+ ip_hdr(skb)->check =
+ ip_fast_csum((u8 *)ip_hdr(skb), ip_hdr(skb)->ihl);
+ } else if (skb->protocol == htons(ETH_P_IPV6)) {
+ ipv6_hdr(skb)->nexthdr = IPPROTO_UDP;
+ be16_add_cpu(&ipv6_hdr(skb)->payload_len, PSP_ENCAP_HLEN);
+ } else {
+ return false;
+ }
+
+ skb_set_inner_ipproto(skb, IPPROTO_TCP);
+ skb_set_inner_transport_header(skb, skb_transport_offset(skb) +
+ PSP_ENCAP_HLEN);
+ skb->encapsulation = 1;
+ psp_write_headers(net, skb, spi, ver,
+ skb->len - skb_transport_offset(skb), sport);
+
+ return true;
+}
+EXPORT_SYMBOL(psp_dev_encapsulate);
+
+/* Receive handler for PSP packets.
+ *
+ * Presently it accepts only already-authenticated packets and does not
+ * support optional fields, such as virtualization cookies. The caller should
+ * ensure that skb->data is pointing to the mac header, and that skb->mac_len
+ * is set.
+ */
+int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
+{
+ int l2_hlen = 0, l3_hlen, encap;
+ struct psp_skb_ext *pse;
+ struct psphdr *psph;
+ struct ethhdr *eth;
+ struct udphdr *uh;
+ __be16 proto;
+ bool is_udp;
+
+ eth = (struct ethhdr *)skb->data;
+ proto = __vlan_get_protocol(skb, eth->h_proto, &l2_hlen);
+ if (proto == htons(ETH_P_IP))
+ l3_hlen = sizeof(struct iphdr);
+ else if (proto == htons(ETH_P_IPV6))
+ l3_hlen = sizeof(struct ipv6hdr);
+ else
+ return -EINVAL;
+
+ if (unlikely(!pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN)))
+ return -EINVAL;
+
+ if (proto == htons(ETH_P_IP)) {
+ struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+ is_udp = iph->protocol == IPPROTO_UDP;
+ l3_hlen = iph->ihl * 4;
+ if (l3_hlen != sizeof(struct iphdr) &&
+ !pskb_may_pull(skb, l2_hlen + l3_hlen + PSP_ENCAP_HLEN))
+ return -EINVAL;
+ } else {
+ struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+ is_udp = ipv6h->nexthdr == IPPROTO_UDP;
+ }
+
+ if (unlikely(!is_udp))
+ return -EINVAL;
+
+ uh = (struct udphdr *)(skb->data + l2_hlen + l3_hlen);
+ if (unlikely(uh->dest != htons(PSP_DEFAULT_UDP_PORT)))
+ return -EINVAL;
+
+ pse = skb_ext_add(skb, SKB_EXT_PSP);
+ if (!pse)
+ return -EINVAL;
+
+ psph = (struct psphdr *)(skb->data + l2_hlen + l3_hlen +
+ sizeof(struct udphdr));
+ pse->spi = psph->spi;
+ pse->dev_id = dev_id;
+ pse->generation = generation;
+ pse->version = FIELD_GET(PSPHDR_VERFL_VERSION, psph->verfl);
+
+ encap = PSP_ENCAP_HLEN;
+ encap += strip_icv ? PSP_TRL_SIZE : 0;
+
+ if (proto == htons(ETH_P_IP)) {
+ struct iphdr *iph = (struct iphdr *)(skb->data + l2_hlen);
+
+ iph->protocol = psph->nexthdr;
+ iph->tot_len = htons(ntohs(iph->tot_len) - encap);
+ iph->check = 0;
+ iph->check = ip_fast_csum((u8 *)iph, iph->ihl);
+ } else {
+ struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + l2_hlen);
+
+ ipv6h->nexthdr = psph->nexthdr;
+ ipv6h->payload_len = htons(ntohs(ipv6h->payload_len) - encap);
+ }
+
+ memmove(skb->data + PSP_ENCAP_HLEN, skb->data, l2_hlen + l3_hlen);
+ skb_pull(skb, PSP_ENCAP_HLEN);
+
+ if (strip_icv)
+ pskb_trim(skb, skb->len - PSP_TRL_SIZE);
+
+ return 0;
+}
+EXPORT_SYMBOL(psp_dev_rcv);
+
+static int __init psp_init(void)
+{
+ mutex_init(&psp_devs_lock);
+
+ return genl_register_family(&psp_nl_family);
+}
+
+subsys_initcall(psp_init);
diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c
new file mode 100644
index 000000000000..8aaca62744c3
--- /dev/null
+++ b/net/psp/psp_nl.c
@@ -0,0 +1,505 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/skbuff.h>
+#include <linux/xarray.h>
+#include <net/genetlink.h>
+#include <net/psp.h>
+#include <net/sock.h>
+
+#include "psp-nl-gen.h"
+#include "psp.h"
+
+/* Netlink helpers */
+
+static struct sk_buff *psp_nl_reply_new(struct genl_info *info)
+{
+ struct sk_buff *rsp;
+ void *hdr;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return NULL;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr) {
+ nlmsg_free(rsp);
+ return NULL;
+ }
+
+ return rsp;
+}
+
+static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info)
+{
+ /* Note that this *only* works with a single message per skb! */
+ nlmsg_end(rsp, (struct nlmsghdr *)rsp->data);
+
+ return genlmsg_reply(rsp, info);
+}
+
+/* Device stuff */
+
+static struct psp_dev *
+psp_device_get_and_lock(struct net *net, struct nlattr *dev_id)
+{
+ struct psp_dev *psd;
+ int err;
+
+ mutex_lock(&psp_devs_lock);
+ psd = xa_load(&psp_devs, nla_get_u32(dev_id));
+ if (!psd) {
+ mutex_unlock(&psp_devs_lock);
+ return ERR_PTR(-ENODEV);
+ }
+
+ mutex_lock(&psd->lock);
+ mutex_unlock(&psp_devs_lock);
+
+ err = psp_dev_check_access(psd, net);
+ if (err) {
+ mutex_unlock(&psd->lock);
+ return ERR_PTR(err);
+ }
+
+ return psd;
+}
+
+int psp_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID))
+ return -EINVAL;
+
+ info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info),
+ info->attrs[PSP_A_DEV_ID]);
+ return PTR_ERR_OR_ZERO(info->user_ptr[0]);
+}
+
+void
+psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+
+ mutex_unlock(&psd->lock);
+ if (socket)
+ sockfd_put(socket);
+}
+
+static int
+psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp,
+ const struct genl_info *info)
+{
+ void *hdr;
+
+ hdr = genlmsg_iput(rsp, info);
+ if (!hdr)
+ return -EMSGSIZE;
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(rsp, PSP_A_DEV_IFINDEX, psd->main_netdev->ifindex) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_CAP, psd->caps->versions) ||
+ nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions))
+ goto err_cancel_msg;
+
+ genlmsg_end(rsp, hdr);
+ return 0;
+
+err_cancel_msg:
+ genlmsg_cancel(rsp, hdr);
+ return -EMSGSIZE;
+}
+
+void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd)
+{
+ struct genl_info info;
+ struct sk_buff *ntf;
+
+ if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev),
+ PSP_NLGRP_MGMT))
+ return;
+
+ ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!ntf)
+ return;
+
+ genl_info_init_ntf(&info, &psp_nl_family, cmd);
+ if (psp_nl_dev_fill(psd, ntf, &info)) {
+ nlmsg_free(ntf);
+ return;
+ }
+
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_MGMT, GFP_KERNEL);
+}
+
+int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct sk_buff *rsp;
+ int err;
+
+ rsp = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_nl_dev_fill(psd, rsp, info);
+ if (err)
+ goto err_free_msg;
+
+ return genlmsg_reply(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
+
+static int
+psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb,
+ struct psp_dev *psd)
+{
+ if (psp_dev_check_access(psd, sock_net(rsp->sk)))
+ return 0;
+
+ return psp_nl_dev_fill(psd, rsp, genl_info_dump(cb));
+}
+
+int psp_nl_dev_get_dumpit(struct sk_buff *rsp, struct netlink_callback *cb)
+{
+ struct psp_dev *psd;
+ int err = 0;
+
+ mutex_lock(&psp_devs_lock);
+ xa_for_each_start(&psp_devs, cb->args[0], psd, cb->args[0]) {
+ mutex_lock(&psd->lock);
+ err = psp_nl_dev_get_dumpit_one(rsp, cb, psd);
+ mutex_unlock(&psd->lock);
+ if (err)
+ break;
+ }
+ mutex_unlock(&psp_devs_lock);
+
+ return err;
+}
+
+int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_dev_config new_config;
+ struct sk_buff *rsp;
+ int err;
+
+ memcpy(&new_config, &psd->config, sizeof(new_config));
+
+ if (info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]) {
+ new_config.versions =
+ nla_get_u32(info->attrs[PSP_A_DEV_PSP_VERSIONS_ENA]);
+ if (new_config.versions & ~psd->caps->versions) {
+ NL_SET_ERR_MSG(info->extack, "Requested PSP versions not supported by the device");
+ return -EINVAL;
+ }
+ } else {
+ NL_SET_ERR_MSG(info->extack, "No settings present");
+ return -EINVAL;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ if (memcmp(&new_config, &psd->config, sizeof(new_config))) {
+ err = psd->ops->set_config(psd, &new_config, info->extack);
+ if (err)
+ goto err_free_rsp;
+
+ memcpy(&psd->config, &new_config, sizeof(new_config));
+ }
+
+ psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct psp_dev *psd = info->user_ptr[0];
+ struct genl_info ntf_info;
+ struct sk_buff *ntf, *rsp;
+ u8 prev_gen;
+ int err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ genl_info_init_ntf(&ntf_info, &psp_nl_family, PSP_CMD_KEY_ROTATE_NTF);
+ ntf = psp_nl_reply_new(&ntf_info);
+ if (!ntf) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+
+ if (nla_put_u32(rsp, PSP_A_DEV_ID, psd->id) ||
+ nla_put_u32(ntf, PSP_A_DEV_ID, psd->id)) {
+ err = -EMSGSIZE;
+ goto err_free_ntf;
+ }
+
+ /* suggest the next gen number, driver can override */
+ prev_gen = psd->generation;
+ psd->generation = (prev_gen + 1) & PSP_GEN_VALID_MASK;
+
+ err = psd->ops->key_rotate(psd, info->extack);
+ if (err)
+ goto err_free_ntf;
+
+ WARN_ON_ONCE((psd->generation && psd->generation == prev_gen) ||
+ psd->generation & ~PSP_GEN_VALID_MASK);
+
+ psp_assocs_key_rotated(psd);
+
+ nlmsg_end(ntf, (struct nlmsghdr *)ntf->data);
+ genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf,
+ 0, PSP_NLGRP_USE, GFP_KERNEL);
+ return psp_nl_reply_send(rsp, info);
+
+err_free_ntf:
+ nlmsg_free(ntf);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+/* Key etc. */
+
+int psp_assoc_device_get_locked(const struct genl_split_ops *ops,
+ struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket;
+ struct psp_dev *psd;
+ struct nlattr *id;
+ int fd, err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD))
+ return -EINVAL;
+
+ fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ socket = sockfd_lookup(fd, &err);
+ if (!socket)
+ return err;
+
+ if (!sk_is_tcp(socket->sk)) {
+ NL_SET_ERR_MSG_ATTR(info->extack,
+ info->attrs[PSP_A_ASSOC_SOCK_FD],
+ "Unsupported socket family and type");
+ err = -EOPNOTSUPP;
+ goto err_sock_put;
+ }
+
+ psd = psp_dev_get_for_sock(socket->sk);
+ if (psd) {
+ err = psp_dev_check_access(psd, genl_info_net(info));
+ if (err) {
+ psp_dev_put(psd);
+ psd = NULL;
+ }
+ }
+
+ if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) {
+ err = -EINVAL;
+ goto err_sock_put;
+ }
+
+ id = info->attrs[PSP_A_ASSOC_DEV_ID];
+ if (psd) {
+ mutex_lock(&psd->lock);
+ if (id && psd->id != nla_get_u32(id)) {
+ mutex_unlock(&psd->lock);
+ NL_SET_ERR_MSG_ATTR(info->extack, id,
+ "Device id vs socket mismatch");
+ err = -EINVAL;
+ goto err_psd_put;
+ }
+
+ psp_dev_put(psd);
+ } else {
+ psd = psp_device_get_and_lock(genl_info_net(info), id);
+ if (IS_ERR(psd)) {
+ err = PTR_ERR(psd);
+ goto err_sock_put;
+ }
+ }
+
+ info->user_ptr[0] = psd;
+ info->user_ptr[1] = socket;
+
+ return 0;
+
+err_psd_put:
+ psp_dev_put(psd);
+err_sock_put:
+ sockfd_put(socket);
+ return err;
+}
+
+static int
+psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key,
+ unsigned int key_sz)
+{
+ struct nlattr *nest = info->attrs[attr];
+ struct nlattr *tb[PSP_A_KEYS_SPI + 1];
+ u32 spi;
+ int err;
+
+ err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest,
+ psp_keys_nl_policy, info->extack);
+ if (err)
+ return err;
+
+ if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) ||
+ NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI))
+ return -EINVAL;
+
+ if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "incorrect key length");
+ return -EINVAL;
+ }
+
+ spi = nla_get_u32(tb[PSP_A_KEYS_SPI]);
+ if (!(spi & PSP_SPI_KEY_ID)) {
+ NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY],
+ "invalid SPI: lower 31b must be non-zero");
+ return -EINVAL;
+ }
+
+ key->spi = cpu_to_be32(spi);
+ memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz);
+
+ return 0;
+}
+
+static int
+psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version,
+ struct psp_key_parsed *key)
+{
+ int key_sz = psp_key_size(version);
+ void *nest;
+
+ nest = nla_nest_start(skb, attr);
+
+ if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) ||
+ nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) {
+ nla_nest_cancel(skb, nest);
+ return -EMSGSIZE;
+ }
+
+ nla_nest_end(skb, nest);
+
+ return 0;
+}
+
+int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct psp_assoc *pas;
+ struct sk_buff *rsp;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ pas = psp_assoc_create(psd);
+ if (!pas) {
+ err = -ENOMEM;
+ goto err_free_rsp;
+ }
+ pas->version = version;
+
+ err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack);
+ if (err)
+ goto err_free_pas;
+
+ if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) ||
+ psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) {
+ err = -EMSGSIZE;
+ goto err_free_pas;
+ }
+
+ err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack);
+ if (err) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]);
+ goto err_free_pas;
+ }
+ psp_assoc_put(pas);
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_pas:
+ psp_assoc_put(pas);
+err_free_rsp:
+ nlmsg_free(rsp);
+ return err;
+}
+
+int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info)
+{
+ struct socket *socket = info->user_ptr[1];
+ struct psp_dev *psd = info->user_ptr[0];
+ struct psp_key_parsed key;
+ struct sk_buff *rsp;
+ unsigned int key_sz;
+ u32 version;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) ||
+ GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY))
+ return -EINVAL;
+
+ version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]);
+ if (!(psd->caps->versions & (1 << version))) {
+ NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]);
+ return -EOPNOTSUPP;
+ }
+
+ key_sz = psp_key_size(version);
+ if (!key_sz)
+ return -EINVAL;
+
+ err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz);
+ if (err < 0)
+ return err;
+
+ rsp = psp_nl_reply_new(info);
+ if (!rsp)
+ return -ENOMEM;
+
+ err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key,
+ info->extack);
+ if (err)
+ goto err_free_msg;
+
+ return psp_nl_reply_send(rsp, info);
+
+err_free_msg:
+ nlmsg_free(rsp);
+ return err;
+}
diff --git a/net/psp/psp_sock.c b/net/psp/psp_sock.c
new file mode 100644
index 000000000000..afa966c6b69d
--- /dev/null
+++ b/net/psp/psp_sock.c
@@ -0,0 +1,295 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/file.h>
+#include <linux/net.h>
+#include <linux/rcupdate.h>
+#include <linux/tcp.h>
+
+#include <net/ip.h>
+#include <net/psp.h>
+#include "psp.h"
+
+struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
+{
+ struct dst_entry *dst;
+ struct psp_dev *psd;
+
+ dst = sk_dst_get(sk);
+ if (!dst)
+ return NULL;
+
+ rcu_read_lock();
+ psd = rcu_dereference(dst->dev->psp_dev);
+ if (psd && !psp_dev_tryget(psd))
+ psd = NULL;
+ rcu_read_unlock();
+
+ dst_release(dst);
+
+ return psd;
+}
+
+static struct sk_buff *
+psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+ bool good;
+
+ rcu_read_lock();
+ pas = psp_skb_get_assoc_rcu(skb);
+ good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
+ rcu_read_unlock();
+ if (!good) {
+ kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT);
+ return NULL;
+ }
+
+ return skb;
+}
+
+struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
+{
+ struct psp_assoc *pas;
+
+ lockdep_assert_held(&psd->lock);
+
+ pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
+ GFP_KERNEL_ACCOUNT);
+ if (!pas)
+ return NULL;
+
+ pas->psd = psd;
+ pas->dev_id = psd->id;
+ pas->generation = psd->generation;
+ psp_dev_get(psd);
+ refcount_set(&pas->refcnt, 1);
+
+ list_add_tail(&pas->assocs_list, &psd->active_assocs);
+
+ return pas;
+}
+
+static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
+{
+ struct psp_dev *psd = pas->psd;
+ size_t sz;
+
+ lockdep_assert_held(&psd->lock);
+
+ sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
+ return kmemdup(pas, sz, GFP_KERNEL);
+}
+
+static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
+ struct netlink_ext_ack *extack)
+{
+ return psd->ops->tx_key_add(psd, pas, extack);
+}
+
+void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
+{
+ if (pas->tx.spi)
+ psd->ops->tx_key_del(psd, pas);
+ list_del(&pas->assocs_list);
+}
+
+static void psp_assoc_free(struct work_struct *work)
+{
+ struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
+ struct psp_dev *psd = pas->psd;
+
+ mutex_lock(&psd->lock);
+ if (psd->ops)
+ psp_dev_tx_key_del(psd, pas);
+ mutex_unlock(&psd->lock);
+ psp_dev_put(psd);
+ kfree(pas);
+}
+
+static void psp_assoc_free_queue(struct rcu_head *head)
+{
+ struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);
+
+ INIT_WORK(&pas->work, psp_assoc_free);
+ schedule_work(&pas->work);
+}
+
+/**
+ * psp_assoc_put() - release a reference on a PSP association
+ * @pas: association to release
+ */
+void psp_assoc_put(struct psp_assoc *pas)
+{
+ if (pas && refcount_dec_and_test(&pas->refcnt))
+ call_rcu(&pas->rcu, psp_assoc_free_queue);
+}
+
+void psp_sk_assoc_free(struct sock *sk)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);
+
+ rcu_assign_pointer(sk->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
+ struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ int err;
+
+ memcpy(&pas->rx, key, sizeof(*key));
+
+ lock_sock(sk);
+
+ if (psp_sk_assoc(sk)) {
+ NL_SET_ERR_MSG(extack, "Socket already has PSP state");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(sk->psp_assoc, pas);
+ err = 0;
+
+exit_unlock:
+ release_sock(sk);
+
+ return err;
+}
+
+static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
+{
+ struct psp_skb_ext *pse;
+ struct sk_buff *skb;
+
+ skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+
+ skb_queue_walk(&sk->sk_receive_queue, skb) {
+ pse = skb_ext_find(skb, SKB_EXT_PSP);
+ if (!psp_pse_matches_pas(pse, pas))
+ return -EBUSY;
+ }
+ return 0;
+}
+
+int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
+ u32 version, struct psp_key_parsed *key,
+ struct netlink_ext_ack *extack)
+{
+ struct inet_connection_sock *icsk;
+ struct psp_assoc *pas, *dummy;
+ int err;
+
+ lock_sock(sk);
+
+ pas = psp_sk_assoc(sk);
+ if (!pas) {
+ NL_SET_ERR_MSG(extack, "Socket has no Rx key");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->psd != psd) {
+ NL_SET_ERR_MSG(extack, "Rx key from different device");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->version != version) {
+ NL_SET_ERR_MSG(extack,
+ "PSP version mismatch with existing state");
+ err = -EINVAL;
+ goto exit_unlock;
+ }
+ if (pas->tx.spi) {
+ NL_SET_ERR_MSG(extack, "Tx key already set");
+ err = -EBUSY;
+ goto exit_unlock;
+ }
+
+ err = psp_sock_recv_queue_check(sk, pas);
+ if (err) {
+ NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
+ goto exit_unlock;
+ }
+
+ /* Pass a fake association to drivers to make sure they don't
+ * try to store pointers to it. For re-keying we'll need to
+ * re-allocate the assoc structures.
+ */
+ dummy = psp_assoc_dummy(pas);
+ if (!dummy) {
+ err = -ENOMEM;
+ goto exit_unlock;
+ }
+
+ memcpy(&dummy->tx, key, sizeof(*key));
+ err = psp_dev_tx_key_add(psd, dummy, extack);
+ if (err)
+ goto exit_free_dummy;
+
+ memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
+ memcpy(&pas->tx, key, sizeof(*key));
+
+ WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
+ tcp_write_collapse_fence(sk);
+ pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;
+
+ icsk = inet_csk(sk);
+ icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
+ icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);
+
+exit_free_dummy:
+ kfree(dummy);
+exit_unlock:
+ release_sock(sk);
+ return err;
+}
+
+void psp_assocs_key_rotated(struct psp_dev *psd)
+{
+ struct psp_assoc *pas, *next;
+
+ /* Mark the stale associations as invalid, they will no longer
+ * be able to Rx any traffic.
+ */
+ list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
+ pas->generation |= ~PSP_GEN_VALID_MASK;
+ list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
+ list_splice_init(&psd->active_assocs, &psd->prev_assocs);
+
+ /* TODO: we should inform the sockets that got shut down */
+}
+
+void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
+{
+ struct psp_assoc *pas = psp_sk_assoc(sk);
+
+ if (pas)
+ refcount_inc(&pas->refcnt);
+ rcu_assign_pointer(tw->psp_assoc, pas);
+ tw->tw_validate_xmit_skb = psp_validate_xmit;
+}
+
+void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
+{
+ struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);
+
+ rcu_assign_pointer(tw->psp_assoc, NULL);
+ psp_assoc_put(pas);
+}
+
+void psp_reply_set_decrypted(struct sk_buff *skb)
+{
+ struct psp_assoc *pas;
+
+ rcu_read_lock();
+ pas = psp_sk_get_assoc_rcu(skb->sk);
+ if (pas && pas->tx.spi)
+ skb->decrypted = 1;
+ rcu_read_unlock();
+}
+EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);
diff --git a/tools/net/ynl/Makefile.deps b/tools/net/ynl/Makefile.deps
index 90686e241157..865fd2e8519e 100644
--- a/tools/net/ynl/Makefile.deps
+++ b/tools/net/ynl/Makefile.deps
@@ -31,6 +31,7 @@ CFLAGS_ovpn:=$(call get_hdr_inc,_LINUX_OVPN_H,ovpn.h)
CFLAGS_ovs_datapath:=$(call get_hdr_inc,__LINUX_OPENVSWITCH_H,openvswitch.h)
CFLAGS_ovs_flow:=$(call get_hdr_inc,__LINUX_OPENVSWITCH_H,openvswitch.h)
CFLAGS_ovs_vport:=$(call get_hdr_inc,__LINUX_OPENVSWITCH_H,openvswitch.h)
+CFLAGS_psp:=$(call get_hdr_inc,_LINUX_PSP_H,psp.h)
CFLAGS_rt-addr:=$(call get_hdr_inc,__LINUX_RTNETLINK_H,rtnetlink.h) \
$(call get_hdr_inc,__LINUX_IF_ADDR_H,if_addr.h)
CFLAGS_rt-link:=$(call get_hdr_inc,__LINUX_RTNETLINK_H,rtnetlink.h) \