diff options
Diffstat (limited to 'merd.c')
-rw-r--r-- | merd.c | 472 |
1 files changed, 271 insertions, 201 deletions
@@ -28,10 +28,12 @@ #include <sys/epoll.h> #include <sys/socket.h> #include <sys/types.h> +#include <sys/ioctl.h> #include <sys/un.h> #include <ifaddrs.h> #include <linux/if_ether.h> #include <linux/if_packet.h> +#include <linux/ip.h> #include <linux/ipv6.h> #include <linux/tcp.h> #include <linux/udp.h> @@ -45,47 +47,15 @@ #include <string.h> #include <errno.h> #include <linux/ip.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> #include "merd.h" +#include "arp.h" +#include "dhcp.h" +#include "util.h" #define EPOLL_EVENTS 10 -#define CT_SIZE 4096 - -/** - * struct ct4 - IPv4 connection tracking entry - * @p: IANA protocol number - * @sa: Source address (as seen from tap interface) - * @da: Destination address - * @sp: Source port, network order - * @dp: Destination port, network order - * @hd: Destination MAC address - * @hs: Source MAC address - * @fd: File descriptor for corresponding AF_INET socket - */ -struct ct4 { - uint8_t p; - uint32_t sa; - uint32_t da; - uint16_t sp; - uint16_t dp; - unsigned char hd[ETH_ALEN]; - unsigned char hs[ETH_ALEN]; - int fd; -}; - -/** - * struct ctx - Execution context - * @epollfd: file descriptor for epoll instance - * @ext_addr4: IPv4 address for external, routable interface - * @fd_unix: AF_UNIX socket for tap file descriptor - * @map4: Connection tracking table - */ -struct ctx { - int epollfd; - unsigned long ext_addr4; - int fd_unix; - struct ct4 map4[CT_SIZE]; -}; /** * sock_unix() - Create and bind AF_UNIX socket, add to epoll list @@ -94,13 +64,12 @@ struct ctx { */ static int sock_unix(void) { + int fd = socket(AF_UNIX, SOCK_STREAM, 0); struct sockaddr_un addr = { .sun_family = AF_UNIX, .sun_path = UNIX_SOCK_PATH, }; - int fd; - fd = socket(AF_UNIX, SOCK_STREAM, 0); if (fd < 0) { perror("UNIX socket"); exit(EXIT_FAILURE); @@ -115,23 +84,118 @@ static int sock_unix(void) } /** - * getaddrs_ext() - Fetch IP addresses of external routable interface + * struct nl_request - Netlink request filled and sent by get_routes() + * @nlh: Netlink message header + * @rtm: Routing Netlink message + */ +struct nl_request { + struct nlmsghdr nlh; + struct rtmsg rtm; +}; + +/** + * get_routes() - Get default route and fill in routable interface name + * @c: Execution context + */ +static void get_routes(struct ctx *c) +{ + struct nl_request req = { + .nlh.nlmsg_type = RTM_GETROUTE, + .nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP | NLM_F_EXCL, + .nlh.nlmsg_len = sizeof(struct nl_request), + .nlh.nlmsg_seq = 1, + .rtm.rtm_family = AF_INET, + .rtm.rtm_table = RT_TABLE_MAIN, + .rtm.rtm_scope = RT_SCOPE_UNIVERSE, + .rtm.rtm_type = RTN_UNICAST, + }; + struct sockaddr_nl addr = { + .nl_family = AF_NETLINK, + }; + int s, n, na, found = 0; + struct nlmsghdr *nlh; + struct rtattr *rta; + struct rtmsg *rtm; + char buf[BUFSIZ]; + + s = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if (s < 0) { + perror("netlink socket"); + goto out; + } + + if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + perror("netlink bind"); + goto out; + } + + if (send(s, &req, sizeof(req), 0) < 0) { + perror("netlink send"); + goto out; + } + + n = recv(s, &buf, sizeof(buf), 0); + if (n < 0) { + perror("netlink recv"); + goto out; + } + + nlh = (struct nlmsghdr *)buf; + if (nlh->nlmsg_type == NLMSG_DONE) + goto out; + + for ( ; NLMSG_OK(nlh, n) && found < 2; NLMSG_NEXT(nlh, n)) { + rtm = (struct rtmsg *)NLMSG_DATA(nlh); + + if (rtm->rtm_dst_len) + continue; + + rta = (struct rtattr *)RTM_RTA(rtm); + na = RTM_PAYLOAD(nlh); + for ( ; RTA_OK(rta, na) && found < 2; rta = RTA_NEXT(rta, na)) { + if (rta->rta_type == RTA_GATEWAY) { + memcpy(&c->gw4, RTA_DATA(rta), sizeof(c->gw4)); + found++; + } + + if (rta->rta_type == RTA_OIF) { + if_indextoname(*(unsigned *)RTA_DATA(rta), + c->ifn); + found++; + } + } + } + +out: + close(s); + + if (found < 2) { + fprintf(stderr, "No routing information\n"); + exit(EXIT_FAILURE); + } +} + +/** + * get_addrs() - Fetch MAC, IP addresses, masks of external routable interface * @c: Execution context - * @ifn: Name of external interface */ -static void getaddrs_ext(struct ctx *c, const char *ifn) +static void get_addrs(struct ctx *c) { + struct ifreq ifr = { + .ifr_addr.sa_family = AF_INET, + }; struct ifaddrs *ifaddr, *ifa; + int s; if (getifaddrs(&ifaddr) == -1) { perror("getifaddrs"); - exit(EXIT_FAILURE); + goto out; } - for (ifa = ifaddr; ifa; ifa = ifa->ifa_next) { + for (ifa = ifaddr; ifa && !c->addr4; ifa = ifa->ifa_next) { struct sockaddr_in *in_addr; - if (strcmp(ifa->ifa_name, ifn)) + if (strcmp(ifa->ifa_name, c->ifn)) continue; if (!ifa->ifa_addr) @@ -141,13 +205,61 @@ static void getaddrs_ext(struct ctx *c, const char *ifn) continue; in_addr = (struct sockaddr_in *)ifa->ifa_addr; - c->ext_addr4 = in_addr->sin_addr.s_addr; - freeifaddrs(ifaddr); - return; + c->addr4 = in_addr->sin_addr.s_addr; + in_addr = (struct sockaddr_in *)ifa->ifa_netmask; + c->mask4 = in_addr->sin_addr.s_addr; } - fprintf(stderr, "Couldn't get IPv4 address for external interface\n"); freeifaddrs(ifaddr); + + s = socket(AF_INET, SOCK_DGRAM, 0); + if (s < 0) { + perror("socket SIOCGIFHWADDR"); + goto out; + } + + strncpy(ifr.ifr_name, c->ifn, IF_NAMESIZE); + if (ioctl(s, SIOCGIFHWADDR, &ifr) < 0) { + perror("SIOCGIFHWADDR"); + goto out; + } + + close(s); + memcpy(c->mac, ifr.ifr_hwaddr.sa_data, ETH_ALEN); + + return; +out: + fprintf(stderr, "Couldn't get addresses for routable interface\n"); + exit(EXIT_FAILURE); +} + +/** + * get_dns() - Get nameserver addresses from local /etc/resolv.conf + * @c: Execution context + */ +static void get_dns(struct ctx *c) +{ + char buf[BUFSIZ], *p, *nl; + int dns4 = 0; + FILE *r; + + r = fopen("/etc/resolv.conf", "r"); + while (fgets(buf, BUFSIZ, r) && !dns4) { + if (!strstr(buf, "nameserver ")) + continue; + p = strrchr(buf, ' '); + nl = strchr(buf, '\n'); + if (nl) + *nl = 0; + if (p && inet_pton(AF_INET, p + 1, &c->dns4)) + dns4 = 1; + } + + fclose(r); + if (dns4) + return; + + fprintf(stderr, "Couldn't get IPv4 nameserver address\n"); exit(EXIT_FAILURE); } @@ -164,7 +276,7 @@ static int sock4_l4(struct ctx *c, uint16_t proto, uint16_t port) struct sockaddr_in addr = { .sin_family = AF_INET, .sin_port = port, - .sin_addr = { .s_addr = c->ext_addr4 }, + .sin_addr = { .s_addr = c->addr4 }, }; struct epoll_event ev = { 0 }; int fd; @@ -176,7 +288,7 @@ static int sock4_l4(struct ctx *c, uint16_t proto, uint16_t port) } if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) { - perror("bind"); + perror("L4 bind"); close(fd); return -1; } @@ -184,7 +296,7 @@ static int sock4_l4(struct ctx *c, uint16_t proto, uint16_t port) ev.events = EPOLLIN; ev.data.fd = fd; if (epoll_ctl(c->epollfd, EPOLL_CTL_ADD, fd, &ev) == -1) { - perror("epoll_ctl"); + perror("L4 epoll_ctl"); return -1; } @@ -192,61 +304,50 @@ static int sock4_l4(struct ctx *c, uint16_t proto, uint16_t port) } /** - * usage() - Print usage and exit - * @name: Executable name - */ -void usage(const char *name) -{ - fprintf(stderr, "Usage: %s IF_EXT\n", name); - - exit(EXIT_FAILURE); -} - -/** * lookup4() - Look up socket entry from tap-sourced packet, create if missing * @c: Execution context - * @in: Packet buffer, L2 headers + * @eh: Packet buffer, Ethernet header * * Return: -1 for unsupported or too many sockets, matching socket otherwise */ -static int lookup4(struct ctx *c, const char *in) +static int lookup4(struct ctx *c, const struct ethhdr *eh) { + struct iphdr *iph = (struct iphdr *)(eh + 1); + struct tcphdr *th = (struct tcphdr *)((char *)iph + iph->ihl * 4); char buf_s[BUFSIZ], buf_d[BUFSIZ]; struct ct4 *ct = c->map4; - struct tcphdr *th; - struct iphdr *iph; - struct ethhdr *eh; - int i; - - eh = (struct ethhdr *)in; - iph = (struct iphdr *)(in + ETH_HLEN); - th = (struct tcphdr *)(iph + 1); + int i, one_icmp_fd = 0; - switch (iph->protocol) { - case IPPROTO_ICMP: - case IPPROTO_TCP: - case IPPROTO_UDP: - break; - default: + if (iph->protocol != IPPROTO_ICMP && iph->protocol != IPPROTO_TCP && + iph->protocol != IPPROTO_UDP) return -1; - } for (i = 0; i < CT_SIZE; i++) { - if (ct[i].p == iph->protocol && - ct[i].sa == iph->saddr && ct[i].da == iph->daddr && - (ct[i].p == IPPROTO_ICMP || - (ct[i].sp == th->source && ct[i].dp == th->dest)) && + if (ct[i].p == iph->protocol && ct[i].sa == iph->saddr && + ((ct[i].p == IPPROTO_ICMP && ct[i].da == iph->daddr) + || ct[i].sp == th->source) && !memcmp(ct[i].hd, eh->h_dest, ETH_ALEN) && - !memcmp(ct[i].hs, eh->h_source, ETH_ALEN)) + !memcmp(ct[i].hs, eh->h_source, ETH_ALEN)) { + if (iph->protocol != IPPROTO_ICMP) { + ct[i].da = iph->daddr; + ct[i].dp = th->dest; + } return ct[i].fd; + } } - for (i = 0; i < CT_SIZE && ct[i].p; i++); + for (i = 0; i < CT_SIZE && ct[i].p; i++) { + if (iph->protocol == IPPROTO_ICMP) + one_icmp_fd = ct[i].fd; + } if (i == CT_SIZE) { fprintf(stderr, "\nToo many sockets, aborting "); } else { - ct[i].fd = sock4_l4(c, iph->protocol, th->source); + if (iph->protocol == IPPROTO_ICMP && one_icmp_fd) + ct[i].fd = one_icmp_fd; + else + ct[i].fd = sock4_l4(c, iph->protocol, th->source); fprintf(stderr, "\n(socket %i) New ", ct[i].fd); ct[i].p = iph->protocol; @@ -279,21 +380,19 @@ static int lookup4(struct ctx *c, const char *in) /** * lookup4_r4() - Reverse look up connection tracking entry from incoming packet * @ct: Connection tracking table - * @in: Packet buffer, L3 headers + * @fd: File descriptor that received the packet + * @iph: Packet buffer, IP header * * Return: matching entry if any, NULL otherwise */ -struct ct4 *lookup_r4(struct ct4 *ct, const char *in) +struct ct4 *lookup_r4(struct ct4 *ct, int fd, struct iphdr *iph) { - struct tcphdr *th; - struct iphdr *iph; + struct tcphdr *th = (struct tcphdr *)((char *)iph + iph->ihl * 4); int i; - iph = (struct iphdr *)in; - th = (struct tcphdr *)(iph + 1); - for (i = 0; i < CT_SIZE; i++) { - if (iph->protocol == ct[i].p && + if (ct[i].fd == fd && + iph->protocol == ct[i].p && iph->saddr == ct[i].da && (iph->protocol == IPPROTO_ICMP || (th->source == ct[i].dp && th->dest == ct[i].sp))) @@ -306,95 +405,49 @@ struct ct4 *lookup_r4(struct ct4 *ct, const char *in) /** * nat4_out() - Perform outgoing IPv4 address translation * @addr: Source address to be used - * @in: Packet buffer, L3 headers + * @iph: IP header */ -static void nat4_out(unsigned long addr, const char *in) +static void nat4_out(unsigned long addr, struct iphdr *iph) { - struct iphdr *iph = (struct iphdr *)in; - iph->saddr = addr; } /** * nat4_in() - Perform incoming IPv4 address translation * @addr: Original destination address to be used - * @in: Packet buffer, L3 headers + * @iph: IP header */ -static void nat_in(unsigned long addr, const char *in) +static void nat_in(unsigned long addr, struct iphdr *iph) { - struct iphdr *iph = (struct iphdr *)in; - iph->daddr = addr; } /** - * csum_fold() - Fold long sum for IP and TCP checksum - * @sum: Original long sum - * - * Return: 16-bit folded sum - */ -static uint16_t csum_fold(uint32_t sum) -{ - while (sum >> 16) - sum = (sum & 0xffff) + (sum >> 16); - - return sum; -} - -/** - * csum_ipv4() - Calculate IPv4 checksum - * @buf: Packet buffer, L3 headers - * @len: Total L3 packet length - * - * Return: 16-bit IPv4-style checksum - */ -static uint16_t csum_ip4(void *buf, size_t len) -{ - uint32_t sum = 0; - uint16_t *p = buf; - size_t len1 = len / 2; - size_t off; - - for (off = 0; off < len1; off++, p++) - sum += *p; - - if (len % 2) - sum += *p & 0xff; - - return ~csum_fold(sum); -} - -/** * csum_ipv4() - Calculate TCP checksum for IPv4 and set in place - * @in: Packet buffer, L3 headers + * @iph: Packet buffer, IP header */ -static void csum_tcp4(uint16_t *in) +static void csum_tcp4(struct iphdr *iph) { - struct iphdr *iph = (struct iphdr *)in; - struct tcphdr *th; - uint16_t tcp_len; + struct tcphdr *th = (struct tcphdr *)((char *)iph + iph->ihl * 4); + uint16_t tlen = ntohs(iph->tot_len) - iph->ihl * 4, *p = (uint16_t *)th; uint32_t sum = 0; - tcp_len = ntohs(iph->tot_len) - (iph->ihl << 2); - th = (struct tcphdr *)(iph + 1); - in = (uint16_t *)th; - sum += (iph->saddr >> 16) & 0xffff; sum += iph->saddr & 0xffff; sum += (iph->daddr >> 16) & 0xffff; sum += iph->daddr & 0xffff; sum += htons(IPPROTO_TCP); - sum += htons(tcp_len); + sum += htons(tlen); th->check = 0; - while (tcp_len > 1) { - sum += *in++; - tcp_len -= 2; + while (tlen > 1) { + sum += *p++; + tlen -= 2; } - if (tcp_len > 0) { - sum += *in & htons(0xff00); + if (tlen > 0) { + sum += *p & htons(0xff00); } th->check = (uint16_t)~csum_fold(sum); @@ -408,9 +461,10 @@ static void csum_tcp4(uint16_t *in) */ static void tap4_handler(struct ctx *c, int len, char *in) { - struct iphdr *iph = (struct iphdr *)(in + ETH_HLEN); - struct tcphdr *th = (struct tcphdr *)(iph + 1); - struct udphdr *uh = (struct udphdr *)(iph + 1); + struct ethhdr *eh = (struct ethhdr *)in; + struct iphdr *iph = (struct iphdr *)(eh + 1); + struct tcphdr *th = (struct tcphdr *)((char *)iph + iph->ihl * 4); + struct udphdr *uh = (struct udphdr *)th; struct sockaddr_in addr = { .sin_family = AF_INET, .sin_port = th->dest, @@ -419,7 +473,10 @@ static void tap4_handler(struct ctx *c, int len, char *in) char buf_s[BUFSIZ], buf_d[BUFSIZ]; int fd; - fd = lookup4(c, in); + if (arp(c, len, eh) || dhcp(c, len, eh)) + return; + + fd = lookup4(c, eh); if (fd == -1) return; @@ -438,65 +495,56 @@ static void tap4_handler(struct ctx *c, int len, char *in) fd); } - nat4_out(c->ext_addr4, in + ETH_HLEN); - - switch (iph->protocol) { - case IPPROTO_TCP: - csum_tcp4((uint16_t *)(in + ETH_HLEN)); - break; - case IPPROTO_UDP: + if (iph->protocol == IPPROTO_TCP) + csum_tcp4(iph); + else if (iph->protocol == IPPROTO_UDP) uh->check = 0; - break; - case IPPROTO_ICMP: - break; - default: + else if (iph->protocol != IPPROTO_ICMP) return; - } - if (sendto(fd, in + sizeof(struct ethhdr) + sizeof(struct iphdr), - len - sizeof(struct ethhdr) - 4 * iph->ihl, 0, + nat4_out(c->addr4, iph); + + if (sendto(fd, (void *)th, len - sizeof(*eh) - iph->ihl * 4, 0, (struct sockaddr *)&addr, sizeof(addr)) < 0) perror("sendto"); } /** - * tap4_handler() - Packet handler for external routable interface + * ext4_handler() - Packet handler for external routable interface * @c: Execution context + * @fd: File descriptor that received the packet * @len: Total L3 packet length * @in: Packet buffer, L3 headers */ -static void ext4_handler(struct ctx *c, int len, char *in) +static void ext4_handler(struct ctx *c, int fd, int len, char *in) { struct iphdr *iph = (struct iphdr *)in; - struct tcphdr *th = (struct tcphdr *)(iph + 1); - char buf_s[BUFSIZ], buf_d[BUFSIZ]; - struct ethhdr *eh; + struct tcphdr *th = (struct tcphdr *)((char *)iph + iph->ihl * 4); + struct udphdr *uh = (struct udphdr *)th; + char buf_s[BUFSIZ], buf_d[BUFSIZ], buf[ETH_MAX_MTU]; + struct ethhdr *eh = (struct ethhdr *)buf; struct ct4 *entry; - char buf[1 << 16]; - entry = lookup_r4(c->map4, in); + entry = lookup_r4(c->map4, fd, iph); if (!entry) return; - nat_in(entry->sa, in); + nat_in(entry->sa, iph); iph->check = 0; - iph->check = csum_ip4(iph, 4 * iph->ihl); + iph->check = csum_ip4(iph, iph->ihl * 4); if (iph->protocol == IPPROTO_TCP) - csum_tcp4((uint16_t *)in); - else if (iph->protocol == IPPROTO_UDP) { - struct udphdr *uh = (struct udphdr *)(iph + 1); + csum_tcp4(iph); + else if (iph->protocol == IPPROTO_UDP) uh->check = 0; - } - eh = (struct ethhdr *)buf; memcpy(eh->h_dest, entry->hs, ETH_ALEN); memcpy(eh->h_source, entry->hd, ETH_ALEN); eh->h_proto = ntohs(ETH_P_IP); - memcpy(buf + sizeof(struct ethhdr), in, len); + memcpy(eh + 1, in, len); if (iph->protocol == IPPROTO_ICMP) { fprintf(stderr, "icmp (socket %i) to tap: %s -> %s\n", @@ -513,11 +561,22 @@ static void ext4_handler(struct ctx *c, int len, char *in) ntohs(th->dest)); } - if (send(c->fd_unix, buf, len + sizeof(struct ethhdr), 0) < 0) + if (send(c->fd_unix, buf, len + sizeof(*eh), 0) < 0) perror("send"); } /** + * usage() - Print usage and exit + * @name: Executable name + */ +void usage(const char *name) +{ + fprintf(stderr, "Usage: %s\n", name); + + exit(EXIT_FAILURE); +} + +/** * main() - Entry point and main loop * @argc: Argument count * @argv: Interface names @@ -527,18 +586,30 @@ static void ext4_handler(struct ctx *c, int len, char *in) int main(int argc, char **argv) { struct epoll_event events[EPOLL_EVENTS]; + char buf4[4][sizeof("255.255.255.255")]; struct epoll_event ev = { 0 }; + char buf[ETH_MAX_MTU]; struct ctx c = { 0 }; - const char *if_ext; - char buf[1 << 16]; int nfds, i, len; int fd_unix; - if (argc != 2) + if (argc != 1) usage(argv[0]); - if_ext = argv[1]; - getaddrs_ext(&c, if_ext); + get_routes(&c); + get_addrs(&c); + get_dns(&c); + + fprintf(stderr, "ARP:\n"); + fprintf(stderr, "\taddress: %02x:%02x:%02x:%02x:%02x:%02x from %s\n", + c.mac[0], c.mac[1], c.mac[2], c.mac[3], c.mac[4], c.mac[5], + c.ifn); + fprintf(stderr, "DHCP:\n"); + fprintf(stderr, "\tassign: %s, mask: %s, router: %s, DNS: %s\n\n", + inet_ntop(AF_INET, &c.addr4, buf4[0], sizeof(buf4[0])), + inet_ntop(AF_INET, &c.mask4, buf4[1], sizeof(buf4[1])), + inet_ntop(AF_INET, &c.gw4, buf4[2], sizeof(buf4[2])), + inet_ntop(AF_INET, &c.dns4, buf4[3], sizeof(buf4[3]))); c.epollfd = epoll_create1(0); if (c.epollfd == -1) { @@ -549,18 +620,17 @@ int main(int argc, char **argv) fd_unix = sock_unix(); listen: listen(fd_unix, 1); - fprintf(stderr, - "You can now start qrap:\n\t" - "./qrap 42 kvm ... -net tap,fd=42 -net nic,model=virtio ...\n"); - c.fd_unix = accept(fd_unix, NULL, NULL); ev.events = EPOLLIN; ev.data.fd = c.fd_unix; epoll_ctl(c.epollfd, EPOLL_CTL_ADD, c.fd_unix, &ev); + fprintf(stderr, + "You can now start qrap:\n\t" + "./qrap 42 kvm ... -net tap,fd=42 -net nic,model=virtio\n\n"); loop: nfds = epoll_wait(c.epollfd, events, EPOLL_EVENTS, -1); - if (nfds == -1) { + if (nfds == -1 && errno != EINTR) { perror("epoll_wait"); exit(EXIT_FAILURE); } @@ -574,7 +644,7 @@ loop: goto listen; } - if (len == 0) + if (len == 0 || (len < 0 && errno == EINTR)) continue; if (len < 0) { @@ -586,7 +656,7 @@ loop: if (events[i].data.fd == c.fd_unix) tap4_handler(&c, len, buf); else - ext4_handler(&c, len, buf); + ext4_handler(&c, events[i].data.fd, len, buf); } goto loop; |