blob: e432f092f2fb685911b87039797fb78e6c624559 [file] [log] [blame]
#include <linux/types.h>
#include <linux/skbuff.h>
#include <linux/socket.h>
#include <linux/sysctl.h>
#include <linux/net.h>
#include <linux/module.h>
#include <linux/if_arp.h>
#include <linux/ipv6.h>
#include <linux/mpls.h>
#include <net/ip.h>
#include <net/dst.h>
#include <net/sock.h>
#include <net/arp.h>
#include <net/ip_fib.h>
#include <net/netevent.h>
#include <net/netns/generic.h>
#include "internal.h"
/* Sentinel for "no incoming label requested": one past the largest 20-bit
 * label value, so it can never name a real table slot.
 */
#define LABEL_NOT_SPECIFIED (1<<20)
/* Maximum number of labels one route may push (size of rt_label[]) */
#define MAX_NEW_LABELS 2
/* This maximum hardware-address length is copied from the definition of
 * struct neighbour.
 */
#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
/*
 * struct mpls_route - next hop label forwarding entry.
 *
 * One of these is referenced from each populated slot of
 * net->mpls.platform_label[] and is freed via kfree_rcu() on rt_rcu.
 * The next-hop address is stored inline after the fixed fields.
 */
struct mpls_route { /* next hop label forwarding entry */
	struct net_device	*rt_dev;	/* output device; NULLed by mpls_ifdown() */
	struct rcu_head		rt_rcu;		/* deferred free, see mpls_rt_free() */
	u32			rt_label[MAX_NEW_LABELS]; /* labels to push; [0] ends up outermost */
	u8			rt_protocol;	/* routing protocol that set this entry */
	u8			rt_labels:2,	/* number of valid rt_label[] entries */
				rt_via_alen:6;	/* number of bytes in rt_via[] */
	unsigned short		rt_via_family;	/* address family of rt_via, passed to neigh_xmit() */
	u8			rt_via[];	/* next hop address (C99 flexible array member) */
};
/* Bounds handed to proc_dointvec_minmax() for the platform_labels sysctl */
static int zero;			/* static storage: implicitly 0 */
static int label_limit = 0xfffff;	/* (1 << 20) - 1, largest 20-bit label */
static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
{
struct mpls_route *rt = NULL;
if (index < net->mpls.platform_labels) {
struct mpls_route __rcu **platform_label =
rcu_dereference(net->mpls.platform_label);
rt = rcu_dereference(platform_label[index]);
}
return rt;
}
static bool mpls_output_possible(const struct net_device *dev)
{
return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
}
/* Bytes of layer 2.5 label stack this route pushes: one shim per label */
static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
{
	unsigned int shim_len = sizeof(struct mpls_shim_hdr);

	return shim_len * rt->rt_labels;
}
/* How much data one layer 2 frame on @dev can carry, i.e. its MTU */
static unsigned int mpls_dev_mtu(const struct net_device *dev)
{
	unsigned int l2_capacity = dev->mtu;

	return l2_capacity;
}
/*
 * mpls_pkt_too_big - can @skb not be sent over a link of @mtu bytes?
 *
 * A packet fits if it is no longer than @mtu, or if it is a GSO packet
 * whose individual network-layer segments are no longer than @mtu.
 */
static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
{
	bool fits_whole = skb->len <= mtu;

	if (fits_whole)
		return false;

	return !(skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu);
}
/*
 * mpls_egress - turn a packet whose last label was popped back into IP.
 * @rt:  route in use (not consulted by this function)
 * @skb: packet; its network header was reset to the payload after the pop
 * @dec: decoded label entry; dec.ttl (already decremented by the caller)
 *       is written into the IPv4 TTL / IPv6 hop limit
 *
 * Sets skb->protocol to match the payload.  Returns true for IPv4/IPv6
 * payloads, false when the payload is unrecognised or too short (the
 * caller then drops the packet).
 */
static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
			struct mpls_entry_decoded dec)
{
	/* RFC4385 and RFC5586 encode other packets in mpls such that
	 * they don't conflict with the ip version number, making
	 * decoding by examining the ip version correct in everything
	 * except for the strangest cases.
	 *
	 * The strange cases if we choose to support them will require
	 * manual configuration.
	 */
	struct iphdr *hdr4;
	bool success = true;

	/* The IPv4 path below touches everything up to and including the
	 * header checksum (bytes 10-11), and the IPv6 path the hop limit
	 * (byte 7).  Make sure those 12 bytes are present in the linear
	 * area before dereferencing ip_hdr()/ipv6_hdr(); otherwise a short
	 * or paged payload would be read (and written) out of bounds.  Any
	 * real IPv4 (>= 20 bytes) or IPv6 (40 bytes) header satisfies this.
	 */
	if (!pskb_may_pull(skb, 12))
		return false;

	hdr4 = ip_hdr(skb);
	if (hdr4->version == 4) {
		skb->protocol = htons(ETH_P_IP);
		/* Keep the IPv4 header checksum valid across the TTL change */
		csum_replace2(&hdr4->check,
			      htons(hdr4->ttl << 8),
			      htons(dec.ttl << 8));
		hdr4->ttl = dec.ttl;
	} else if (hdr4->version == 6) {
		struct ipv6hdr *hdr6 = ipv6_hdr(skb);

		skb->protocol = htons(ETH_P_IPV6);
		hdr6->hop_limit = dec.ttl;
	} else {
		/* version 0 and version 1 are used by pseudo wires */
		success = false;
	}
	return success;
}
/*
 * mpls_forward - receive handler for ETH_P_MPLS_UC frames.
 * @skb:      received frame; its top shim header is read via mpls_hdr()
 * @dev:      receiving device; its namespace selects the label table
 * @pt:       our packet_type registration (unused)
 * @orig_dev: original receiving device (unused)
 *
 * Pops the top label, looks it up in the platform label table and
 * decrements the TTL.  It then either pushes the route's outgoing labels
 * or, for penultimate hop popping, hands the payload on as IPv4/IPv6 via
 * mpls_egress().  The skb is always consumed.
 *
 * Returns 0 once the packet was passed to neigh_xmit() (a transmit error
 * is only logged), NET_RX_DROP when it was dropped here.
 */
static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
			struct packet_type *pt, struct net_device *orig_dev)
{
	struct net *net = dev_net(dev);
	struct mpls_shim_hdr *hdr;
	struct mpls_route *rt;
	struct mpls_entry_decoded dec;
	struct net_device *out_dev;
	unsigned int hh_len;
	unsigned int new_header_size;
	unsigned int mtu;
	int err;

	/* Careful this entire function runs inside of an rcu critical section */

	if (skb->pkt_type != PACKET_HOST)
		goto drop;

	if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
		goto drop;

	if (!pskb_may_pull(skb, sizeof(*hdr)))
		goto drop;

	/* Read and decode the label */
	hdr = mpls_hdr(skb);
	dec = mpls_entry_decode(hdr);

	/* Pop the label */
	skb_pull(skb, sizeof(*hdr));
	skb_reset_network_header(skb);

	skb_orphan(skb);

	rt = mpls_route_input_rcu(net, dec.label);
	if (!rt)
		goto drop;

	/* Find the output device */
	out_dev = rt->rt_dev;
	if (!mpls_output_possible(out_dev))
		goto drop;

	if (skb_warn_if_lro(skb))
		goto drop;

	skb_forward_csum(skb);

	/* Verify ttl is valid.  A label TTL of 0 or 1 would be 0 after the
	 * decrement, and such packets must not be forwarded any further.
	 * A TTL of 2 is still valid and leaves this hop as 1.
	 */
	if (dec.ttl <= 1)
		goto drop;
	dec.ttl -= 1;

	/* Verify the destination can hold the packet */
	new_header_size = mpls_rt_header_size(rt);
	mtu = mpls_dev_mtu(out_dev);
	if (mpls_pkt_too_big(skb, mtu - new_header_size))
		goto drop;

	hh_len = LL_RESERVED_SPACE(out_dev);
	if (!out_dev->header_ops)
		hh_len = 0;

	/* Ensure there is enough space for the headers in the skb */
	if (skb_cow(skb, hh_len + new_header_size))
		goto drop;

	skb->dev = out_dev;
	skb->protocol = htons(ETH_P_MPLS_UC);

	if (unlikely(!new_header_size && dec.bos)) {
		/* Penultimate hop popping */
		if (!mpls_egress(rt, skb, dec))
			goto drop;
	} else {
		bool bos;
		int i;

		skb_push(skb, new_header_size);
		skb_reset_network_header(skb);
		/* Push the new labels, innermost first, so that only the
		 * innermost pushed label inherits the popped label's BoS bit.
		 */
		hdr = mpls_hdr(skb);
		bos = dec.bos;
		for (i = rt->rt_labels - 1; i >= 0; i--) {
			hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
			bos = false;
		}
	}

	err = neigh_xmit(rt->rt_via_family, out_dev, rt->rt_via, skb);
	if (err)
		net_dbg_ratelimited("%s: packet transmission failed: %d\n",
				    __func__, err);
	return 0;

drop:
	kfree_skb(skb);
	return NET_RX_DROP;
}
/* Receive hook registered by mpls_init(): frames with ethertype
 * ETH_P_MPLS_UC are delivered to mpls_forward().
 */
static struct packet_type mpls_packet_type __read_mostly = {
	.func	= mpls_forward,
	.type	= cpu_to_be16(ETH_P_MPLS_UC),
};
/* Describes a route to install or remove; consumed by mpls_route_add()
 * and mpls_route_del().
 */
struct mpls_route_config {
u32 rc_protocol; /* routing protocol, copied into rt_protocol */
u32 rc_ifindex; /* ifindex of the output device */
u16 rc_via_family; /* address family of the next hop in rc_via */
u16 rc_via_alen; /* number of valid bytes in rc_via */
u8 rc_via[MAX_VIA_ALEN]; /* next hop address */
u32 rc_label; /* incoming label, or LABEL_NOT_SPECIFIED to let mpls_route_add() pick one */
u32 rc_output_labels; /* number of used entries in rc_output_label[] */
u32 rc_output_label[MAX_NEW_LABELS]; /* labels to push, copied to rt_label[] */
u32 rc_nlflags; /* NLM_F_* flags selecting create/replace/exclusive semantics */
struct nl_info rc_nlinfo; /* request context; nl_net selects the namespace */
};
/*
 * mpls_rt_alloc - allocate a zeroed route with room for @alen bytes of
 * next-hop address in the trailing rt_via[] array.
 *
 * Returns the route with rt_via_alen set, or NULL if allocation failed.
 */
static struct mpls_route *mpls_rt_alloc(size_t alen)
{
	struct mpls_route *rt;

	/* kzalloc() takes (size, flags).  The reversed call requested a
	 * buffer of GFP_KERNEL bytes and passed the struct size as GFP flags.
	 */
	rt = kzalloc(sizeof(*rt) + alen, GFP_KERNEL);
	if (rt)
		rt->rt_via_alen = alen;
	return rt;
}
/* Free @rt once current RCU readers are done with it; NULL is ignored */
static void mpls_rt_free(struct mpls_route *rt)
{
	if (!rt)
		return;
	kfree_rcu(rt, rt_rcu);
}
/*
 * mpls_route_update - publish @new in label slot @index and free the
 * previous route.  Caller holds RTNL.
 *
 * With @dev == NULL the slot is replaced unconditionally.  Otherwise it is
 * only replaced if it currently holds a route whose rt_dev is @dev.
 * @info is accepted but not used here.
 */
static void mpls_route_update(struct net *net, unsigned index,
			      struct net_device *dev, struct mpls_route *new,
			      const struct nl_info *info)
{
	struct mpls_route *cur;

	ASSERT_RTNL();

	cur = net->mpls.platform_label[index];

	/* A device filter that does not match leaves the slot untouched */
	if (dev && !(cur && cur->rt_dev == dev))
		return;

	rcu_assign_pointer(net->mpls.platform_label[index], new);

	/* If we removed a route free it now */
	mpls_rt_free(cur);
}
static unsigned find_free_label(struct net *net)
{
unsigned index;
for (index = 16; index < net->mpls.platform_labels; index++) {
if (!net->mpls.platform_label[index])
return index;
}
return LABEL_NOT_SPECIFIED;
}
static int mpls_route_add(struct mpls_route_config *cfg)
{
struct net *net = cfg->rc_nlinfo.nl_net;
struct net_device *dev = NULL;
struct mpls_route *rt, *old;
unsigned index;
int i;
int err = -EINVAL;
index = cfg->rc_label;
/* If a label was not specified during insert pick one */
if ((index == LABEL_NOT_SPECIFIED) &&
(cfg->rc_nlflags & NLM_F_CREATE)) {
index = find_free_label(net);
}
/* The first 16 labels are reserved, and may not be set */
if (index < 16)
goto errout;
/* The full 20 bit range may not be supported. */
if (index >= net->mpls.platform_labels)
goto errout;
/* Ensure only a supported number of labels are present */
if (cfg->rc_output_labels > MAX_NEW_LABELS)
goto errout;
err = -ENODEV;
dev = dev_get_by_index(net, cfg->rc_ifindex);
if (!dev)
goto errout;
/* For now just support ethernet devices */
err = -EINVAL;
if ((dev->type != ARPHRD_ETHER) && (dev->type != ARPHRD_LOOPBACK))
goto errout;
err = -EINVAL;
if ((cfg->rc_via_family == AF_PACKET) &&
(dev->addr_len != cfg->rc_via_alen))
goto errout;
/* Append makes no sense with mpls */
err = -EINVAL;
if (cfg->rc_nlflags & NLM_F_APPEND)
goto errout;
err = -EEXIST;
old = net->mpls.platform_label[index];
if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
goto errout;
err = -EEXIST;
if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
goto errout;
err = -ENOENT;
if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
goto errout;
err = -ENOMEM;
rt = mpls_rt_alloc(cfg->rc_via_alen);
if (!rt)
goto errout;
rt->rt_labels = cfg->rc_output_labels;
for (i = 0; i < rt->rt_labels; i++)
rt->rt_label[i] = cfg->rc_output_label[i];
rt->rt_protocol = cfg->rc_protocol;
rt->rt_dev = dev;
rt->rt_via_family = cfg->rc_via_family;
memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
dev_put(dev);
return 0;
errout:
if (dev)
dev_put(dev);
return err;
}
/*
 * mpls_route_del - remove the route for label cfg->rc_label.
 *
 * Returns -EINVAL for a reserved label (0-15) or one outside the platform
 * table, 0 otherwise (including when the slot was already empty).
 */
static int mpls_route_del(struct mpls_route_config *cfg)
{
	struct net *net = cfg->rc_nlinfo.nl_net;
	unsigned label = cfg->rc_label;

	if (label < 16 || label >= net->mpls.platform_labels)
		return -EINVAL;

	mpls_route_update(net, label, NULL, NULL, &cfg->rc_nlinfo);
	return 0;
}
/* Detach @dev from every route in its namespace that transmits through it */
static void mpls_ifdown(struct net_device *dev)
{
	struct net *net = dev_net(dev);
	unsigned i;

	for (i = 0; i < net->mpls.platform_labels; i++) {
		struct mpls_route *rt = net->mpls.platform_label[i];

		if (rt && rt->rt_dev == dev)
			rt->rt_dev = NULL;
	}
}
/*
 * mpls_dev_notify - netdevice notifier callback.
 *
 * Only NETDEV_UNREGISTER is acted on.  Routes do not hold a device
 * reference, so they must stop pointing at a device that is going away.
 */
static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
			   void *ptr)
{
	struct net_device *dev = netdev_notifier_info_to_dev(ptr);

	if (event == NETDEV_UNREGISTER)
		mpls_ifdown(dev);

	return NOTIFY_OK;
}

static struct notifier_block mpls_dev_notifier = {
	.notifier_call	= mpls_dev_notify,
};
/*
 * resize_platform_label_table - replace the namespace's label table with
 * one of @limit entries.
 * @net:   namespace to resize
 * @limit: new number of labels (0 leaves no table)
 *
 * Routes for labels >= @limit are removed and freed, and surviving slots are
 * copied over.  When the table first grows past LABEL_IPV4_EXPLICIT_NULL /
 * LABEL_IPV6_EXPLICIT_NULL, kernel routes to the loopback device are
 * installed in those slots.  The old table is freed after an RCU grace
 * period.
 *
 * Returns 0 on success or -ENOMEM.
 */
static int resize_platform_label_table(struct net *net, size_t limit)
{
	size_t size = sizeof(struct mpls_route *) * limit;
	size_t old_limit;
	size_t cp_size;
	struct mpls_route __rcu **labels = NULL, **old;
	struct mpls_route *rt0 = NULL, *rt2 = NULL;
	unsigned index;

	if (size) {
		labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
		if (!labels)
			labels = vzalloc(size);

		if (!labels)
			goto nolabels;
	}

	/* In case the predefined labels need to be populated */
	if (limit > LABEL_IPV4_EXPLICIT_NULL) {
		struct net_device *lo = net->loopback_dev;

		rt0 = mpls_rt_alloc(lo->addr_len);
		if (!rt0)
			goto nort0;
		rt0->rt_dev = lo;
		rt0->rt_protocol = RTPROT_KERNEL;
		rt0->rt_via_family = AF_PACKET;
		memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
	}
	if (limit > LABEL_IPV6_EXPLICIT_NULL) {
		struct net_device *lo = net->loopback_dev;

		rt2 = mpls_rt_alloc(lo->addr_len);
		if (!rt2)
			goto nort2;
		rt2->rt_dev = lo;
		rt2->rt_protocol = RTPROT_KERNEL;
		rt2->rt_via_family = AF_PACKET;
		memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
	}

	rtnl_lock();
	/* Remember the original table */
	old = net->mpls.platform_label;
	old_limit = net->mpls.platform_labels;

	/* Free any labels beyond the new table */
	for (index = limit; index < old_limit; index++)
		mpls_route_update(net, index, NULL, NULL, NULL);

	/* Copy over the old labels */
	cp_size = size;
	if (old_limit < limit)
		cp_size = old_limit * sizeof(struct mpls_route *);

	memcpy(labels, old, cp_size);

	/* If needed set the predefined labels.  The new table is not yet
	 * visible to readers, so no barrier is needed for these stores.
	 */
	if ((old_limit <= LABEL_IPV6_EXPLICIT_NULL) &&
	    (limit > LABEL_IPV6_EXPLICIT_NULL)) {
		RCU_INIT_POINTER(labels[LABEL_IPV6_EXPLICIT_NULL], rt2);
		rt2 = NULL;
	}

	if ((old_limit <= LABEL_IPV4_EXPLICIT_NULL) &&
	    (limit > LABEL_IPV4_EXPLICIT_NULL)) {
		RCU_INIT_POINTER(labels[LABEL_IPV4_EXPLICIT_NULL], rt0);
		rt0 = NULL;
	}

	/* Update the global pointers.  rcu_assign_pointer() orders all the
	 * stores above (copied slots, rt0/rt2 contents) before the new table
	 * becomes reachable from mpls_route_input_rcu().  A plain store would
	 * let a reader find the new table but see uninitialised slots or
	 * route fields.
	 *
	 * NOTE(review): platform_labels and platform_label are still updated
	 * as two separate stores, which readers load without an ordering
	 * guarantee between them.  Confirm whether a grow racing with a
	 * lookup needs further reader-side ordering.
	 */
	net->mpls.platform_labels = limit;
	rcu_assign_pointer(net->mpls.platform_label, labels);

	rtnl_unlock();

	/* Predefined routes that were not needed (slots already existed) */
	mpls_rt_free(rt2);
	mpls_rt_free(rt0);

	if (old) {
		synchronize_rcu();
		kvfree(old);
	}
	return 0;

nort2:
	mpls_rt_free(rt0);
nort0:
	kvfree(labels);
nolabels:
	return -ENOMEM;
}
/*
 * mpls_platform_labels - proc handler for net/mpls/platform_labels.
 *
 * Reads or writes the label-table size through a temporary ctl_table that
 * bounds it to [zero, label_limit].  A successful write resizes the table.
 */
static int mpls_platform_labels(struct ctl_table *table, int write,
				void __user *buffer, size_t *lenp, loff_t *ppos)
{
	struct net *net = table->data;
	int labels = net->mpls.platform_labels;
	struct ctl_table bounded = {
		.procname	= table->procname,
		.mode		= table->mode,
		.data		= &labels,
		.maxlen		= sizeof(int),
		.extra1		= &zero,
		.extra2		= &label_limit,
	};
	int err;

	err = proc_dointvec_minmax(&bounded, write, buffer, lenp, ppos);
	if (err || !write)
		return err;

	return resize_platform_label_table(net, labels);
}
/* Template for the per-namespace net/mpls sysctl directory.  .data is left
 * zero here and set to the owning struct net by mpls_net_init().
 */
static struct ctl_table mpls_table[] = {
	{
		.procname	= "platform_labels",
		.maxlen		= sizeof(int),
		.mode		= 0644,
		.proc_handler	= mpls_platform_labels,
	},
	{ }	/* terminator */
};
/*
 * mpls_net_init - per-namespace setup: start with an empty label table and
 * register this namespace's copy of the net/mpls sysctls.
 *
 * Returns 0 on success or -ENOMEM.
 */
static int mpls_net_init(struct net *net)
{
	struct ctl_table *table;

	net->mpls.platform_labels = 0;
	net->mpls.platform_label = NULL;

	table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
	if (table == NULL)
		return -ENOMEM;

	/* mpls_platform_labels() finds its namespace through .data */
	table[0].data = net;
	net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
	if (net->mpls.ctl == NULL) {
		/* The copy is only released by mpls_net_exit() through
		 * ctl_table_arg, which does not exist if registration failed.
		 */
		kfree(table);
		return -ENOMEM;
	}

	return 0;
}
/*
 * mpls_net_exit - per-namespace teardown: drop the sysctls, then clear and
 * free every route and finally the label table itself.
 */
static void mpls_net_exit(struct net *net)
{
	struct ctl_table *sysctl_table = net->mpls.ctl->ctl_table_arg;
	unsigned int label;

	unregister_net_sysctl_table(net->mpls.ctl);
	kfree(sysctl_table);

	/* An rcu grace period has elapsed since there was a device in
	 * the network namespace (and thus the last in-flight packet)
	 * left this network namespace.  This is because
	 * unregister_netdevice_many and netdev_run_todo have completed
	 * for each network device that was in this network namespace.
	 *
	 * As such no additional rcu synchronization is necessary when
	 * freeing the platform_label table.
	 */
	rtnl_lock();
	for (label = 0; label < net->mpls.platform_labels; label++) {
		struct mpls_route *victim = net->mpls.platform_label[label];

		rcu_assign_pointer(net->mpls.platform_label[label], NULL);
		mpls_rt_free(victim);
	}
	rtnl_unlock();

	kvfree(net->mpls.platform_label);
}

static struct pernet_operations mpls_net_ops = {
	.init	= mpls_net_init,
	.exit	= mpls_net_exit,
};
/*
 * mpls_init - module init.  Registers per-namespace state, then the device
 * notifier, and only then starts receiving MPLS frames.  A failed step
 * undoes the earlier ones.
 */
static int __init mpls_init(void)
{
	int rc;

	BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);

	rc = register_pernet_subsys(&mpls_net_ops);
	if (rc)
		return rc;

	rc = register_netdevice_notifier(&mpls_dev_notifier);
	if (rc) {
		unregister_pernet_subsys(&mpls_net_ops);
		return rc;
	}

	dev_add_pack(&mpls_packet_type);
	return 0;
}
module_init(mpls_init);
/* Module teardown: undo mpls_init() in reverse order.  First remove the
 * receive hook, then the device notifier, then the per-namespace state.
 */
static void __exit mpls_exit(void)
{
dev_remove_pack(&mpls_packet_type);
unregister_netdevice_notifier(&mpls_dev_notifier);
unregister_pernet_subsys(&mpls_net_ops);
}
module_exit(mpls_exit);
MODULE_DESCRIPTION("MultiProtocol Label Switching");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_NETPROTO(PF_MPLS);