diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
index 772bfdbdeb9c601f93354a654f02163c6db3fc17..06a592fb62bf201f0df592a11f6cb16c1c4a0019 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_main.c
@@ -63,6 +63,7 @@
 #include "en/xsk/rx.h"
 #include "en/xsk/tx.h"
 #include "en/hv_vhca_stats.h"
+#include "lib/mlx5.h"
 
 
 bool mlx5e_check_fragmented_striding_rq_cap(struct mlx5_core_dev *mdev)
@@ -5427,6 +5428,7 @@ static void *mlx5e_add(struct mlx5_core_dev *mdev)
 		return NULL;
 	}
 
+	dev_net_set(netdev, mlx5_core_net(mdev));
 	priv = netdev_priv(netdev);
 
 	err = mlx5e_attach(mdev, priv);
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_rep.c b/drivers/net/ethernet/mellanox/mlx5/core/en_rep.c
index cd9bb7c7b3413651a9835925d69305299d007dcf..c7f98f1fd9b12e11fb0557f4d92a0e7307e790e8 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_rep.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_rep.c
@@ -47,6 +47,7 @@
 #include "en/tc_tun.h"
 #include "fs_core.h"
 #include "lib/port_tun.h"
+#include "lib/mlx5.h"
 #define CREATE_TRACE_POINTS
 #include "diag/en_rep_tracepoint.h"
 
@@ -1877,6 +1878,7 @@ mlx5e_vport_rep_load(struct mlx5_core_dev *dev, struct mlx5_eswitch_rep *rep)
 		return -EINVAL;
 	}
 
+	dev_net_set(netdev, mlx5_core_net(dev));
 	rpriv->netdev = netdev;
 	rep->rep_data[REP_ETH].priv = rpriv;
 	INIT_LIST_HEAD(&rpriv->vport_sqs_list);
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/lib/mlx5.h b/drivers/net/ethernet/mellanox/mlx5/core/lib/mlx5.h
index b99d469e4e6457e0dda72f42d164e3c0e829c38e..249539247e2e76841232355121fd4b742256a135 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/lib/mlx5.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/lib/mlx5.h
@@ -84,4 +84,9 @@ int mlx5_create_encryption_key(struct mlx5_core_dev *mdev,
 			       void *key, u32 sz_bytes, u32 *p_key_id);
 void mlx5_destroy_encryption_key(struct mlx5_core_dev *mdev, u32 key_id);
 
+static inline struct net *mlx5_core_net(struct mlx5_core_dev *dev)
+{
+	return devlink_net(priv_to_devlink(dev));
+}
+
 #endif