diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index af399cac5205402480300ae7b7d6fbb2094095fb..242bce1cf0f3a66b1806469ab3e733363638cbf2 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -1741,6 +1741,20 @@ static void fanout_release(struct sock *sk)
 		kfree_rcu(po->rollover, rcu);
 }
 
+static bool packet_extra_vlan_len_allowed(const struct net_device *dev,
+					  struct sk_buff *skb)
+{
+	/* Earlier code assumed this would be a VLAN pkt, double-check
+	 * this now that we have the actual packet in hand. We can only
+	 * do this check on Ethernet devices.
+	 */
+	if (unlikely(dev->type != ARPHRD_ETHER))
+		return false;
+
+	skb_reset_mac_header(skb);
+	return likely(eth_hdr(skb)->h_proto == htons(ETH_P_8021Q));
+}
+
 static const struct proto_ops packet_ops;
 
 static const struct proto_ops packet_ops_spkt;
@@ -1902,18 +1916,10 @@ static int packet_sendmsg_spkt(struct socket *sock, struct msghdr *msg,
 		goto retry;
 	}
 
-	if (len > (dev->mtu + dev->hard_header_len + extra_len)) {
-		/* Earlier code assumed this would be a VLAN pkt,
-		 * double-check this now that we have the actual
-		 * packet in hand.
-		 */
-		struct ethhdr *ehdr;
-		skb_reset_mac_header(skb);
-		ehdr = eth_hdr(skb);
-		if (ehdr->h_proto != htons(ETH_P_8021Q)) {
-			err = -EMSGSIZE;
-			goto out_unlock;
-		}
+	if (len > (dev->mtu + dev->hard_header_len + extra_len) &&
+	    !packet_extra_vlan_len_allowed(dev, skb)) {
+		err = -EMSGSIZE;
+		goto out_unlock;
 	}
 
 	skb->protocol = proto;
@@ -2332,6 +2338,15 @@ static bool ll_header_truncated(const struct net_device *dev, int len)
 	return false;
 }
 
+static void tpacket_set_protocol(const struct net_device *dev,
+				 struct sk_buff *skb)
+{
+	if (dev->type == ARPHRD_ETHER) {
+		skb_reset_mac_header(skb);
+		skb->protocol = eth_hdr(skb)->h_proto;
+	}
+}
+
 static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
 		void *frame, struct net_device *dev, int size_max,
 		__be16 proto, unsigned char *addr, int hlen)
@@ -2368,8 +2383,6 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
 	skb_reserve(skb, hlen);
 	skb_reset_network_header(skb);
 
-	if (!packet_use_direct_xmit(po))
-		skb_probe_transport_header(skb, 0);
 	if (unlikely(po->tp_tx_has_off)) {
 		int off_min, off_max, off;
 		off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll);
@@ -2415,6 +2428,8 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
 				dev->hard_header_len);
 		if (unlikely(err))
 			return err;
+		if (!skb->protocol)
+			tpacket_set_protocol(dev, skb);
 
 		data += dev->hard_header_len;
 		to_write -= dev->hard_header_len;
@@ -2449,6 +2464,8 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
 		len = ((to_write > len_max) ? len_max : to_write);
 	}
 
+	skb_probe_transport_header(skb, 0);
+
 	return tp_len;
 }
 
@@ -2493,12 +2510,13 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 	if (unlikely(!(dev->flags & IFF_UP)))
 		goto out_put;
 
-	reserve = dev->hard_header_len + VLAN_HLEN;
+	if (po->sk.sk_socket->type == SOCK_RAW)
+		reserve = dev->hard_header_len;
 	size_max = po->tx_ring.frame_size
 		- (po->tp_hdrlen - sizeof(struct sockaddr_ll));
 
-	if (size_max > dev->mtu + reserve)
-		size_max = dev->mtu + reserve;
+	if (size_max > dev->mtu + reserve + VLAN_HLEN)
+		size_max = dev->mtu + reserve + VLAN_HLEN;
 
 	do {
 		ph = packet_current_frame(po, &po->tx_ring,
@@ -2525,18 +2543,10 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
 		tp_len = tpacket_fill_skb(po, skb, ph, dev, size_max, proto,
 					  addr, hlen);
 		if (likely(tp_len >= 0) &&
-		    tp_len > dev->mtu + dev->hard_header_len) {
-			struct ethhdr *ehdr;
-			/* Earlier code assumed this would be a VLAN pkt,
-			 * double-check this now that we have the actual
-			 * packet in hand.
-			 */
+		    tp_len > dev->mtu + reserve &&
+		    !packet_extra_vlan_len_allowed(dev, skb))
+			tp_len = -EMSGSIZE;
 
-			skb_reset_mac_header(skb);
-			ehdr = eth_hdr(skb);
-			if (ehdr->h_proto != htons(ETH_P_8021Q))
-				tp_len = -EMSGSIZE;
-		}
 		if (unlikely(tp_len < 0)) {
 			if (po->tp_loss) {
 				__packet_set_status(po, ph,
@@ -2765,18 +2775,10 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
 	sock_tx_timestamp(sk, &skb_shinfo(skb)->tx_flags);
 
-	if (!gso_type && (len > dev->mtu + reserve + extra_len)) {
-		/* Earlier code assumed this would be a VLAN pkt,
-		 * double-check this now that we have the actual
-		 * packet in hand.
-		 */
-		struct ethhdr *ehdr;
-		skb_reset_mac_header(skb);
-		ehdr = eth_hdr(skb);
-		if (ehdr->h_proto != htons(ETH_P_8021Q)) {
-			err = -EMSGSIZE;
-			goto out_free;
-		}
+	if (!gso_type && (len > dev->mtu + reserve + extra_len) &&
+	    !packet_extra_vlan_len_allowed(dev, skb)) {
+		err = -EMSGSIZE;
+		goto out_free;
 	}
 
 	skb->protocol = proto;
@@ -2807,8 +2809,8 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 		len += vnet_hdr_len;
 	}
 
-	if (!packet_use_direct_xmit(po))
-		skb_probe_transport_header(skb, reserve);
+	skb_probe_transport_header(skb, reserve);
+
 	if (unlikely(extra_len == 4))
 		skb->no_fcs = 1;