aboutgitcodebugslistschat
diff options
context:
space:
mode:
-rw-r--r--tcp.c107
-rw-r--r--tcp_conn.h2
-rw-r--r--util.h28
3 files changed, 81 insertions, 56 deletions
diff --git a/tcp.c b/tcp.c
index 108006e..4e3a134 100644
--- a/tcp.c
+++ b/tcp.c
@@ -573,22 +573,12 @@ static unsigned int tcp6_l2_flags_buf_used;
#define CONN(idx) (&(FLOW(idx)->tcp))
-/** conn_at_idx() - Find a connection by index, if present
- * @idx: Index of connection to lookup
- *
- * Return: pointer to connection, or NULL if @idx is out of bounds
- */
-static inline struct tcp_tap_conn *conn_at_idx(unsigned idx)
-{
- if (idx >= FLOW_MAX)
- return NULL;
- ASSERT(CONN(idx)->f.type == FLOW_TCP);
- return CONN(idx);
-}
-
/* Table for lookup from remote address, local port, remote port */
static struct tcp_tap_conn *tc_hash[TCP_HASH_TABLE_SIZE];
+static_assert(ARRAY_SIZE(tc_hash) >= FLOW_MAX,
+ "Safe linear probing requires hash table larger than connection table");
+
/* Pools for pre-opened sockets (in init) */
int init_sock_pool4 [TCP_SOCK_POOL_SIZE];
int init_sock_pool6 [TCP_SOCK_POOL_SIZE];
@@ -1198,20 +1188,36 @@ static unsigned int tcp_conn_hash(const struct ctx *c,
}
/**
+ * tcp_hash_probe() - Find hash bucket for a connection
+ * @c: Execution context
+ * @conn: Connection to find bucket for
+ *
+ * Return: bucket currently holding @conn if it is in the table, otherwise the
+ * first free bucket found on its probe sequence.
+ */
+static inline unsigned tcp_hash_probe(const struct ctx *c,
+ const struct tcp_tap_conn *conn)
+{
+ unsigned b = tcp_conn_hash(c, conn);
+
+ /* Linear probing: step down one bucket at a time, wrapping below zero */
+ while (tc_hash[b] && tc_hash[b] != conn)
+ b = mod_sub(b, 1, TCP_HASH_TABLE_SIZE);
+
+ return b;
+}
+
+/**
* tcp_hash_insert() - Insert connection into hash table, chain link
* @c: Execution context
* @conn: Connection pointer
*/
static void tcp_hash_insert(const struct ctx *c, struct tcp_tap_conn *conn)
{
- int b;
+ unsigned b = tcp_hash_probe(c, conn);
- b = tcp_hash(c, &conn->faddr, conn->eport, conn->fport);
- conn->next_index = tc_hash[b] ? FLOW_IDX(tc_hash[b]) : -1U;
tc_hash[b] = conn;
-
- flow_dbg(conn, "hash table insert: sock %i, bucket: %i, next: %p",
- conn->sock, b, (void *)conn_at_idx(conn->next_index));
+ flow_dbg(conn, "hash table insert: sock %i, bucket: %u", conn->sock, b);
}
/**
@@ -1222,23 +1228,27 @@ static void tcp_hash_insert(const struct ctx *c, struct tcp_tap_conn *conn)
static void tcp_hash_remove(const struct ctx *c,
const struct tcp_tap_conn *conn)
{
- struct tcp_tap_conn *entry, *prev = NULL;
- int b = tcp_conn_hash(c, conn);
+ unsigned b = tcp_hash_probe(c, conn), s;
- for (entry = tc_hash[b]; entry;
- prev = entry, entry = conn_at_idx(entry->next_index)) {
- if (entry == conn) {
- if (prev)
- prev->next_index = conn->next_index;
- else
- tc_hash[b] = conn_at_idx(conn->next_index);
- break;
+ if (!tc_hash[b])
+ return; /* Redundant remove */
+
+ flow_dbg(conn, "hash table remove: sock %i, bucket: %u", conn->sock, b);
+
+ /* Scan the remainder of the cluster */
+ for (s = mod_sub(b, 1, TCP_HASH_TABLE_SIZE); tc_hash[s];
+ s = mod_sub(s, 1, TCP_HASH_TABLE_SIZE)) {
+ unsigned h = tcp_conn_hash(c, tc_hash[s]);
+
+ if (!mod_between(h, s, b, TCP_HASH_TABLE_SIZE)) {
+ /* tc_hash[s] can live in tc_hash[b]'s slot */
+ debug("hash table remove: shuffle %u -> %u", s, b);
+ tc_hash[b] = tc_hash[s];
+ b = s;
}
}
- flow_dbg(conn, "hash table remove: sock %i, bucket: %i, new: %p",
- conn->sock, b,
- (void *)(prev ? conn_at_idx(prev->next_index) : tc_hash[b]));
+ tc_hash[b] = NULL;
}
/**
@@ -1251,24 +1261,15 @@ void tcp_tap_conn_update(const struct ctx *c, struct tcp_tap_conn *old,
struct tcp_tap_conn *new)
{
- struct tcp_tap_conn *entry, *prev = NULL;
- int b = tcp_conn_hash(c, old);
+ unsigned b = tcp_hash_probe(c, old);
- for (entry = tc_hash[b]; entry;
- prev = entry, entry = conn_at_idx(entry->next_index)) {
- if (entry == old) {
- if (prev)
- prev->next_index = FLOW_IDX(new);
- else
- tc_hash[b] = new;
- break;
- }
- }
+ if (!tc_hash[b])
+ return; /* Not in hash table, nothing to update */
+
+ tc_hash[b] = new;
debug("TCP: hash table update: old index %u, new index %u, sock %i, "
- "bucket: %i, old: %p, new: %p",
- FLOW_IDX(old), FLOW_IDX(new), new->sock, b,
- (void *)old, (void *)new);
+ "bucket: %u", FLOW_IDX(old), FLOW_IDX(new), new->sock, b);
tcp_epoll_ctl(c, new);
}
@@ -1288,17 +1289,15 @@ static struct tcp_tap_conn *tcp_hash_lookup(const struct ctx *c,
in_port_t eport, in_port_t fport)
{
union inany_addr aany;
- struct tcp_tap_conn *conn;
- int b;
+ unsigned b;
inany_from_af(&aany, af, faddr);
+
b = tcp_hash(c, &aany, eport, fport);
- for (conn = tc_hash[b]; conn; conn = conn_at_idx(conn->next_index)) {
- if (tcp_hash_match(conn, &aany, eport, fport))
- return conn;
- }
+ while (tc_hash[b] && !tcp_hash_match(tc_hash[b], &aany, eport, fport))
+ b = mod_sub(b, 1, TCP_HASH_TABLE_SIZE);
- return NULL;
+ return tc_hash[b];
}
/**
diff --git a/tcp_conn.h b/tcp_conn.h
index 3900305..e3400bb 100644
--- a/tcp_conn.h
+++ b/tcp_conn.h
@@ -13,7 +13,6 @@
* struct tcp_tap_conn - Descriptor for a TCP connection (not spliced)
* @f: Generic flow information
* @in_epoll: Is the connection in the epoll set?
- * @next_index: Connection index of next item in hash chain, -1 for none
* @tap_mss: MSS advertised by tap/guest, rounded to 2 ^ TCP_MSS_BITS
* @sock: Socket descriptor number
* @events: Connection events, implying connection states
@@ -40,7 +39,6 @@ struct tcp_tap_conn {
struct flow_common f;
bool in_epoll :1;
- unsigned next_index :FLOW_INDEX_BITS + 2;
#define TCP_RETRANS_BITS 3
unsigned int retrans :TCP_RETRANS_BITS;
diff --git a/util.h b/util.h
index 53bb54b..9446ea7 100644
--- a/util.h
+++ b/util.h
@@ -227,6 +227,34 @@ int __daemon(int pidfile_fd, int devnull_fd);
int fls(unsigned long x);
int write_file(const char *path, const char *buf);
+/**
+ * mod_sub() - Modular arithmetic subtraction
+ * @a: Minuend, unsigned value < @m
+ * @b: Subtrahend, unsigned value < @m
+ * @m: Modulus, must be less than (UINT_MAX / 2)
+ *
+ * Return: (@a - @b) mod @m, correctly handling unsigned underflows.
+ */
+static inline unsigned mod_sub(unsigned a, unsigned b, unsigned m)
+{
+ if (a < b)
+ a += m;
+ return a - b;
+}
+
+/**
+ * mod_between() - Determine if a value is in a cyclic range
+ * @x, @i, @j: Unsigned values < @m
+ * @m: Modulus
+ *
+ * Return: true iff @x is in the cyclic range of values from @i..@j (mod @m),
+ * inclusive of @i, exclusive of @j; the range is empty when @i == @j.
+ */
+static inline bool mod_between(unsigned x, unsigned i, unsigned j, unsigned m)
+{
+ return mod_sub(x, i, m) < mod_sub(j, i, m);
+}
+
/*
* Workarounds for https://github.com/llvm/llvm-project/issues/58992
*