// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2020, Red Hat, Inc.
 */

#define pr_fmt(fmt) "MPTCP: " fmt
8
9#include <linux/inet.h>
10#include <linux/kernel.h>
11#include <net/tcp.h>
12#include <net/netns/generic.h>
13#include <net/mptcp.h>
14#include <net/genetlink.h>
15#include <uapi/linux/mptcp.h>
16
17#include "protocol.h"
18#include "mib.h"
19
20
21static struct genl_family mptcp_genl_family;
22
23static int pm_nl_pernet_id;
24
25struct mptcp_pm_addr_entry {
26 struct list_head list;
27 struct mptcp_addr_info addr;
28 u8 flags;
29 int ifindex;
30 struct socket *lsk;
31};
32
33struct mptcp_pm_add_entry {
34 struct list_head list;
35 struct mptcp_addr_info addr;
36 struct timer_list add_timer;
37 struct mptcp_sock *sock;
38 u8 retrans_times;
39};
40
41#define MAX_ADDR_ID 255
42#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)
43
44struct pm_nl_pernet {
 /* protects pernet updates */
46 spinlock_t lock;
47 struct list_head local_addr_list;
48 unsigned int addrs;
49 unsigned int add_addr_signal_max;
50 unsigned int add_addr_accept_max;
51 unsigned int local_addr_max;
52 unsigned int subflows_max;
53 unsigned int next_id;
54 unsigned long id_bitmap[BITMAP_SZ];
55};
56
57#define MPTCP_PM_ADDR_MAX 8
58#define ADD_ADDR_RETRANS_MAX 3
59
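/* compare two addresses; an IPv4 address and its v4-mapped IPv6 form are
 * considered equal, the port is compared only when @use_port is set
 */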
60static bool addresses_equal(const struct mptcp_addr_info *a,
61 struct mptcp_addr_info *b, bool use_port)
62{
63 bool addr_equals = false;
64
65 if (a->family == b->family) {
66 if (a->family == AF_INET)
67 addr_equals = a->addr.s_addr == b->addr.s_addr;
68#if IS_ENABLED(CONFIG_MPTCP_IPV6)
69 else
70 addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
71 } else if (a->family == AF_INET) {
72 if (ipv6_addr_v4mapped(&b->addr6))
73 addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
74 } else if (b->family == AF_INET) {
75 if (ipv6_addr_v4mapped(&a->addr6))
76 addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
77#endif
78 }
79
80 if (!addr_equals)
81 return false;
82 if (!use_port)
83 return true;
84
85 return a->port == b->port;
86}
87
88static bool address_zero(const struct mptcp_addr_info *addr)
89{
90 struct mptcp_addr_info zero;
91
92 memset(&zero, 0, sizeof(zero));
93 zero.family = addr->family;
94
95 return addresses_equal(addr, &zero, true);
96}
97
98static void local_address(const struct sock_common *skc,
99 struct mptcp_addr_info *addr)
100{
101 addr->family = skc->skc_family;
102 addr->port = htons(skc->skc_num);
103 if (addr->family == AF_INET)
104 addr->addr.s_addr = skc->skc_rcv_saddr;
105#if IS_ENABLED(CONFIG_MPTCP_IPV6)
106 else if (addr->family == AF_INET6)
107 addr->addr6 = skc->skc_v6_rcv_saddr;
108#endif
109}
110
111static void remote_address(const struct sock_common *skc,
112 struct mptcp_addr_info *addr)
113{
114 addr->family = skc->skc_family;
115 addr->port = skc->skc_dport;
116 if (addr->family == AF_INET)
117 addr->addr.s_addr = skc->skc_daddr;
118#if IS_ENABLED(CONFIG_MPTCP_IPV6)
119 else if (addr->family == AF_INET6)
120 addr->addr6 = skc->skc_v6_daddr;
121#endif
122}
123
124static bool lookup_subflow_by_saddr(const struct list_head *list,
125 struct mptcp_addr_info *saddr)
126{
127 struct mptcp_subflow_context *subflow;
128 struct mptcp_addr_info cur;
129 struct sock_common *skc;
130
131 list_for_each_entry(subflow, list, node) {
132 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
133
134 local_address(skc, &cur);
135 if (addresses_equal(&cur, saddr, saddr->port))
136 return true;
137 }
138
139 return false;
140}
141
142static bool lookup_subflow_by_daddr(const struct list_head *list,
143 struct mptcp_addr_info *daddr)
144{
145 struct mptcp_subflow_context *subflow;
146 struct mptcp_addr_info cur;
147 struct sock_common *skc;
148
149 list_for_each_entry(subflow, list, node) {
150 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
151
152 remote_address(skc, &cur);
153 if (addresses_equal(&cur, daddr, daddr->port))
154 return true;
155 }
156
157 return false;
158}
159
160static struct mptcp_pm_addr_entry *
161select_local_address(const struct pm_nl_pernet *pernet,
162 struct mptcp_sock *msk)
163{
164 struct mptcp_pm_addr_entry *entry, *ret = NULL;
165 struct sock *sk = (struct sock *)msk;
166
167 msk_owned_by_me(msk);
168
169 rcu_read_lock();
170 __mptcp_flush_join_list(msk);
171 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
172 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
173 continue;
174
175 if (entry->addr.family != sk->sk_family) {
176#if IS_ENABLED(CONFIG_MPTCP_IPV6)
177 if ((entry->addr.family == AF_INET &&
178 !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
179 (sk->sk_family == AF_INET &&
180 !ipv6_addr_v4mapped(&entry->addr.addr6)))
181#endif
182 continue;
183 }
184
 /* avoid any address already in use by subflows and
 * pending joins
 */
188 if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
189 ret = entry;
190 break;
191 }
192 }
193 rcu_read_unlock();
194 return ret;
195}
196
197static struct mptcp_pm_addr_entry *
198select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
199{
200 struct mptcp_pm_addr_entry *entry, *ret = NULL;
201 int i = 0;
202
203 rcu_read_lock();
 /* do not keep any additional per socket state, just signal
 * the address list in order.
 * Note: removal from the local address list during the msk life-cycle
 * can lead to additional addresses not being announced.
 */
209 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
210 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
211 continue;
212 if (i++ == pos) {
213 ret = entry;
214 break;
215 }
216 }
217 rcu_read_unlock();
218 return ret;
219}
220
221unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
222{
223 struct pm_nl_pernet *pernet;
224
225 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
226 return READ_ONCE(pernet->add_addr_signal_max);
227}
228EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
229
230unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
231{
232 struct pm_nl_pernet *pernet;
233
234 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
235 return READ_ONCE(pernet->add_addr_accept_max);
236}
237EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
238
239unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
240{
241 struct pm_nl_pernet *pernet;
242
243 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
244 return READ_ONCE(pernet->subflows_max);
245}
246EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
247
248unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
249{
250 struct pm_nl_pernet *pernet;
251
252 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
253 return READ_ONCE(pernet->local_addr_max);
254}
255EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
256
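/* clear the PM work_pending flag once every 'signal' endpoint has been
 * announced and no additional subflow can be created
 */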
257static void check_work_pending(struct mptcp_sock *msk)
258{
259 if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
260 (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
261 msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
262 WRITE_ONCE(msk->pm.work_pending, false);
263}
264
265struct mptcp_pm_add_entry *
266mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk,
267 struct mptcp_addr_info *addr)
268{
269 struct mptcp_pm_add_entry *entry;
270
271 lockdep_assert_held(&msk->pm.lock);
272
273 list_for_each_entry(entry, &msk->pm.anno_list, list) {
274 if (addresses_equal(&entry->addr, addr, true))
275 return entry;
276 }
277
278 return NULL;
279}
280
281bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
282{
283 struct mptcp_pm_add_entry *entry;
284 struct mptcp_addr_info saddr;
285 bool ret = false;
286
287 local_address((struct sock_common *)sk, &saddr);
288
289 spin_lock_bh(&msk->pm.lock);
290 list_for_each_entry(entry, &msk->pm.anno_list, list) {
291 if (addresses_equal(&entry->addr, &saddr, true)) {
292 ret = true;
293 goto out;
294 }
295 }
296
297out:
298 spin_unlock_bh(&msk->pm.lock);
299 return ret;
300}
301
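/* ADD_ADDR retransmission timer: re-announce the address until the peer
 * echoes it or ADD_ADDR_RETRANS_MAX attempts have been performed
 */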
302static void mptcp_pm_add_timer(struct timer_list *timer)
303{
304 struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
305 struct mptcp_sock *msk = entry->sock;
306 struct sock *sk = (struct sock *)msk;
307
308 pr_debug("msk=%p", msk);
309
310 if (!msk)
311 return;
312
313 if (inet_sk_state_load(sk) == TCP_CLOSE)
314 return;
315
316 if (!entry->addr.id)
317 return;
318
319 if (mptcp_pm_should_add_signal(msk)) {
320 sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
321 goto out;
322 }
323
324 spin_lock_bh(&msk->pm.lock);
325
326 if (!mptcp_pm_should_add_signal(msk)) {
327 pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
328 mptcp_pm_announce_addr(msk, &entry->addr, false);
329 mptcp_pm_add_addr_send_ack(msk);
330 entry->retrans_times++;
331 }
332
333 if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
334 sk_reset_timer(sk, timer,
335 jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));
336
337 spin_unlock_bh(&msk->pm.lock);
338
339 if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
340 mptcp_pm_subflow_established(msk);
341
342out:
343 __sock_put(sk);
344}
345
346struct mptcp_pm_add_entry *
347mptcp_pm_del_add_timer(struct mptcp_sock *msk,
348 struct mptcp_addr_info *addr, bool check_id)
349{
350 struct mptcp_pm_add_entry *entry;
351 struct sock *sk = (struct sock *)msk;
352
353 spin_lock_bh(&msk->pm.lock);
354 entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
355 if (entry && (!check_id || entry->addr.id == addr->id))
356 entry->retrans_times = ADD_ADDR_RETRANS_MAX;
357 spin_unlock_bh(&msk->pm.lock);
358
359 if (entry && (!check_id || entry->addr.id == addr->id))
360 sk_stop_timer_sync(sk, &entry->add_timer);
361
362 return entry;
363}
364
365static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
366 struct mptcp_pm_addr_entry *entry)
367{
368 struct mptcp_pm_add_entry *add_entry = NULL;
369 struct sock *sk = (struct sock *)msk;
370 struct net *net = sock_net(sk);
371
372 lockdep_assert_held(&msk->pm.lock);
373
374 if (mptcp_lookup_anno_list_by_saddr(msk, &entry->addr))
375 return false;
376
377 add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
378 if (!add_entry)
379 return false;
380
381 list_add(&add_entry->list, &msk->pm.anno_list);
382
383 add_entry->addr = entry->addr;
384 add_entry->sock = msk;
385 add_entry->retrans_times = 0;
386
387 timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
388 sk_reset_timer(sk, &add_entry->add_timer,
389 jiffies + mptcp_get_add_addr_timeout(net));
390
391 return true;
392}
393
394void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
395{
396 struct mptcp_pm_add_entry *entry, *tmp;
397 struct sock *sk = (struct sock *)msk;
398 LIST_HEAD(free_list);
399
400 pr_debug("msk=%p", msk);
401
402 spin_lock_bh(&msk->pm.lock);
403 list_splice_init(&msk->pm.anno_list, &free_list);
404 spin_unlock_bh(&msk->pm.lock);
405
406 list_for_each_entry_safe(entry, tmp, &free_list, list) {
407 sk_stop_timer_sync(sk, &entry->add_timer);
408 kfree(entry);
409 }
410}
411
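/* core of the in-kernel path manager: announce the next 'signal' endpoint
 * and/or create a new subflow from a 'subflow' endpoint, within the
 * configured limits; called with the PM spinlock held
 */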
412static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
413{
414 struct sock *sk = (struct sock *)msk;
415 struct mptcp_pm_addr_entry *local;
416 unsigned int add_addr_signal_max;
417 unsigned int local_addr_max;
418 struct pm_nl_pernet *pernet;
419 unsigned int subflows_max;
420
421 pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
422
423 add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
424 local_addr_max = mptcp_pm_get_local_addr_max(msk);
425 subflows_max = mptcp_pm_get_subflows_max(msk);
426
427 pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
428 msk->pm.local_addr_used, local_addr_max,
429 msk->pm.add_addr_signaled, add_addr_signal_max,
430 msk->pm.subflows, subflows_max);
431
 /* check first for announce */
433 if (msk->pm.add_addr_signaled < add_addr_signal_max) {
434 local = select_signal_address(pernet,
435 msk->pm.add_addr_signaled);
436
437 if (local) {
438 if (mptcp_pm_alloc_anno_list(msk, local)) {
439 msk->pm.add_addr_signaled++;
440 mptcp_pm_announce_addr(msk, &local->addr, false);
441 mptcp_pm_nl_addr_send_ack(msk);
442 }
443 } else {
 /* pick failed, avoid further attempts later */
445 msk->pm.local_addr_used = add_addr_signal_max;
446 }
447
448 check_work_pending(msk);
449 }
450
 /* check if should create a new subflow */
452 if (msk->pm.local_addr_used < local_addr_max &&
453 msk->pm.subflows < subflows_max &&
454 !READ_ONCE(msk->pm.remote_deny_join_id0)) {
455 local = select_local_address(pernet, msk);
456 if (local) {
457 struct mptcp_addr_info remote = { 0 };
458
459 msk->pm.local_addr_used++;
460 msk->pm.subflows++;
461 check_work_pending(msk);
462 remote_address((struct sock_common *)sk, &remote);
463 spin_unlock_bh(&msk->pm.lock);
464 __mptcp_subflow_connect(sk, &local->addr, &remote,
465 local->flags, local->ifindex);
466 spin_lock_bh(&msk->pm.lock);
467 return;
468 }
469
 /* lookup failed, avoid further attempts later */
471 msk->pm.local_addr_used = local_addr_max;
472 check_work_pending(msk);
473 }
474}
475
476static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
477{
478 mptcp_pm_create_subflow_or_signal_addr(msk);
479}
480
481static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
482{
483 mptcp_pm_create_subflow_or_signal_addr(msk);
484}
485
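/* handle an ADD_ADDR from the peer: if the limits allow it, create a new
 * subflow towards the announced address, then echo the ADD_ADDR back
 */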
486static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
487{
488 struct sock *sk = (struct sock *)msk;
489 unsigned int add_addr_accept_max;
490 struct mptcp_addr_info remote;
491 struct mptcp_addr_info local;
492 unsigned int subflows_max;
493
494 add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
495 subflows_max = mptcp_pm_get_subflows_max(msk);
496
497 pr_debug("accepted %d:%d remote family %d",
498 msk->pm.add_addr_accepted, add_addr_accept_max,
499 msk->pm.remote.family);
500
501 if (lookup_subflow_by_daddr(&msk->conn_list, &msk->pm.remote))
502 goto add_addr_echo;
503
504 msk->pm.add_addr_accepted++;
505 msk->pm.subflows++;
506 if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
507 msk->pm.subflows >= subflows_max)
508 WRITE_ONCE(msk->pm.accept_addr, false);
509
 /* connect to the specified remote address, using whatever
 * local address the routing configuration will pick.
 */
513 remote = msk->pm.remote;
514 if (!remote.port)
515 remote.port = sk->sk_dport;
516 memset(&local, 0, sizeof(local));
517 local.family = remote.family;
518
519 spin_unlock_bh(&msk->pm.lock);
520 __mptcp_subflow_connect(sk, &local, &remote, 0, 0);
521 spin_lock_bh(&msk->pm.lock);
522
523add_addr_echo:
524 mptcp_pm_announce_addr(msk, &msk->pm.remote, true);
525 mptcp_pm_nl_addr_send_ack(msk);
526}
527
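/* push an ack on the first subflow, so that a pending ADD_ADDR or RM_ADDR
 * option gets transmitted
 */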
528void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
529{
530 struct mptcp_subflow_context *subflow;
531
532 msk_owned_by_me(msk);
533 lockdep_assert_held(&msk->pm.lock);
534
535 if (!mptcp_pm_should_add_signal(msk) &&
536 !mptcp_pm_should_rm_signal(msk))
537 return;
538
539 __mptcp_flush_join_list(msk);
540 subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
541 if (subflow) {
542 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
543 bool slow;
544
545 spin_unlock_bh(&msk->pm.lock);
546 pr_debug("send ack for %s%s%s",
547 mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr",
548 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
549 mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");
550
551 slow = lock_sock_fast(ssk);
552 tcp_send_ack(ssk);
553 unlock_sock_fast(ssk, slow);
554 spin_lock_bh(&msk->pm.lock);
555 }
556}
557
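/* update the backup flag of the subflow bound to @addr and send an ack
 * carrying the MP_PRIO option on it
 */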
558int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
559 struct mptcp_addr_info *addr,
560 u8 bkup)
561{
562 struct mptcp_subflow_context *subflow;
563
564 pr_debug("bkup=%d", bkup);
565
566 mptcp_for_each_subflow(msk, subflow) {
567 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
568 struct sock *sk = (struct sock *)msk;
569 struct mptcp_addr_info local;
570 bool slow;
571
572 local_address((struct sock_common *)ssk, &local);
573 if (!addresses_equal(&local, addr, addr->port))
574 continue;
575
576 subflow->backup = bkup;
577 subflow->send_mp_prio = 1;
578 subflow->request_bkup = bkup;
579 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
580
581 spin_unlock_bh(&msk->pm.lock);
582 pr_debug("send ack for mp_prio");
583 slow = lock_sock_fast(ssk);
584 tcp_send_ack(ssk);
585 unlock_sock_fast(ssk, slow);
586 spin_lock_bh(&msk->pm.lock);
587
588 return 0;
589 }
590
591 return -EINVAL;
592}
593
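/* close the subflows matching the ids in @rm_list and update the PM
 * accounting; RM_ADDR matches on the remote id, subflow removal on the
 * local one
 */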
594static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
595 const struct mptcp_rm_list *rm_list,
596 enum linux_mptcp_mib_field rm_type)
597{
598 struct mptcp_subflow_context *subflow, *tmp;
599 struct sock *sk = (struct sock *)msk;
600 u8 i;
601
602 pr_debug("%s rm_list_nr %d",
603 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);
604
605 msk_owned_by_me(msk);
606
607 if (!rm_list->nr)
608 return;
609
610 if (list_empty(&msk->conn_list))
611 return;
612
613 for (i = 0; i < rm_list->nr; i++) {
614 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
615 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
616 int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
617 u8 id = subflow->local_id;
618
619 if (rm_type == MPTCP_MIB_RMADDR)
620 id = subflow->remote_id;
621
622 if (rm_list->ids[i] != id)
623 continue;
624
625 pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
626 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
627 i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
628 spin_unlock_bh(&msk->pm.lock);
629 mptcp_subflow_shutdown(sk, ssk, how);
630 mptcp_close_ssk(sk, ssk, subflow);
631 spin_lock_bh(&msk->pm.lock);
632
633 if (rm_type == MPTCP_MIB_RMADDR) {
634 msk->pm.add_addr_accepted--;
635 WRITE_ONCE(msk->pm.accept_addr, true);
636 } else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
637 msk->pm.local_addr_used--;
638 }
639 msk->pm.subflows--;
640 __MPTCP_INC_STATS(sock_net(sk), rm_type);
641 }
642 }
643}
644
645static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
646{
647 mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
648}
649
650void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
651 const struct mptcp_rm_list *rm_list)
652{
653 mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
654}
655
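/* process the PM events pending on this msk; called by the msk worker
 * with the msk socket lock held
 */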
656void mptcp_pm_nl_work(struct mptcp_sock *msk)
657{
658 struct mptcp_pm_data *pm = &msk->pm;
659
660 msk_owned_by_me(msk);
661
662 spin_lock_bh(&msk->pm.lock);
663
664 pr_debug("msk=%p status=%x", msk, pm->status);
665 if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
666 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
667 mptcp_pm_nl_add_addr_received(msk);
668 }
669 if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
670 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
671 mptcp_pm_nl_addr_send_ack(msk);
672 }
673 if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
674 pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
675 mptcp_pm_nl_rm_addr_received(msk);
676 }
677 if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
678 pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
679 mptcp_pm_nl_fully_established(msk);
680 }
681 if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
682 pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
683 mptcp_pm_nl_subflow_established(msk);
684 }
685
686 spin_unlock_bh(&msk->pm.lock);
687}
688
689static bool address_use_port(struct mptcp_pm_addr_entry *entry)
690{
691 return (entry->flags &
692 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
693 MPTCP_PM_ADDR_FLAG_SIGNAL;
694}
695
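/* add @entry to the pernet endpoint list, allocating a new unique id if
 * none was provided; returns the endpoint id or a negative error code
 */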
696static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
697 struct mptcp_pm_addr_entry *entry)
698{
699 struct mptcp_pm_addr_entry *cur;
700 unsigned int addr_max;
701 int ret = -EINVAL;
702
703 spin_lock_bh(&pernet->lock);
 /* to keep the code simple, don't do IDR-like allocation for
 * the unique id, just use the address counter instead
 */
707 if (pernet->next_id == MAX_ADDR_ID)
708 pernet->next_id = 1;
709 if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
710 goto out;
711 if (test_bit(entry->addr.id, pernet->id_bitmap))
712 goto out;
713
 /* do not insert duplicate addresses; compare the port only when
 * both endpoints are signal-only ones
 */
717 list_for_each_entry(cur, &pernet->local_addr_list, list) {
718 if (addresses_equal(&cur->addr, &entry->addr,
719 address_use_port(entry) &&
720 address_use_port(cur)))
721 goto out;
722 }
723
724 if (!entry->addr.id) {
725find_next:
726 entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
727 MAX_ADDR_ID + 1,
728 pernet->next_id);
729 if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
730 pernet->next_id != 1) {
731 pernet->next_id = 1;
732 goto find_next;
733 }
734 }
735
736 if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
737 goto out;
738
739 __set_bit(entry->addr.id, pernet->id_bitmap);
740 if (entry->addr.id > pernet->next_id)
741 pernet->next_id = entry->addr.id;
742
743 if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
744 addr_max = pernet->add_addr_signal_max;
745 WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
746 }
747 if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
748 addr_max = pernet->local_addr_max;
749 WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
750 }
751
752 pernet->addrs++;
753 list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
754 ret = entry->addr.id;
755
756out:
757 spin_unlock_bh(&pernet->lock);
758 return ret;
759}
760
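/* create and bind an in-kernel MPTCP listening socket for an endpoint
 * configured with an explicit port
 */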
761static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
762 struct mptcp_pm_addr_entry *entry)
763{
764 struct sockaddr_storage addr;
765 struct mptcp_sock *msk;
766 struct socket *ssock;
767 int backlog = 1024;
768 int err;
769
770 err = sock_create_kern(sock_net(sk), entry->addr.family,
771 SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
772 if (err)
773 return err;
774
775 msk = mptcp_sk(entry->lsk->sk);
776 if (!msk) {
777 err = -EINVAL;
778 goto out;
779 }
780
781 ssock = __mptcp_nmpc_socket(msk);
782 if (!ssock) {
783 err = -EINVAL;
784 goto out;
785 }
786
787 mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
788 err = kernel_bind(ssock, (struct sockaddr *)&addr,
789 sizeof(struct sockaddr_in));
790 if (err) {
791 pr_warn("kernel_bind error, err=%d", err);
792 goto out;
793 }
794
795 err = kernel_listen(ssock, backlog);
796 if (err) {
797 pr_warn("kernel_listen error, err=%d", err);
798 goto out;
799 }
800
801 return 0;
802
803out:
804 sock_release(entry->lsk);
805 return err;
806}
807
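/* map the local address of @skc to an endpoint id: the msk local address
 * maps to id 0, unknown addresses are added to the pernet list
 */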
808int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
809{
810 struct mptcp_pm_addr_entry *entry;
811 struct mptcp_addr_info skc_local;
812 struct mptcp_addr_info msk_local;
813 struct pm_nl_pernet *pernet;
814 int ret = -1;
815
816 if (WARN_ON_ONCE(!msk))
817 return -1;
818
 /* the 0 id mapping is defined by the first subflow, copied into
 * the msk local address
 */
822 local_address((struct sock_common *)msk, &msk_local);
823 local_address((struct sock_common *)skc, &skc_local);
824 if (addresses_equal(&msk_local, &skc_local, false))
825 return 0;
826
827 if (address_zero(&skc_local))
828 return 0;
829
830 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
831
832 rcu_read_lock();
833 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
834 if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
835 ret = entry->addr.id;
836 break;
837 }
838 }
839 rcu_read_unlock();
840 if (ret >= 0)
841 return ret;
842
 /* address not found, add a new entry to the local list */
844 entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
845 if (!entry)
846 return -ENOMEM;
847
848 entry->addr = skc_local;
849 entry->addr.id = 0;
850 entry->addr.port = 0;
851 entry->ifindex = 0;
852 entry->flags = 0;
853 entry->lsk = NULL;
854 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
855 if (ret < 0)
856 kfree(entry);
857
858 return ret;
859}
860
861void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
862{
863 struct mptcp_pm_data *pm = &msk->pm;
864 bool subflows;
865
866 subflows = !!mptcp_pm_get_subflows_max(msk);
867 WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
868 !!mptcp_pm_get_add_addr_signal_max(msk));
869 WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
870 WRITE_ONCE(pm->accept_subflow, subflows);
871}
872
873#define MPTCP_PM_CMD_GRP_OFFSET 0
874#define MPTCP_PM_EV_GRP_OFFSET 1
875
876static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
877 [MPTCP_PM_CMD_GRP_OFFSET] = { .name = MPTCP_PM_CMD_GRP_NAME, },
878 [MPTCP_PM_EV_GRP_OFFSET] = { .name = MPTCP_PM_EV_GRP_NAME,
879 .flags = GENL_UNS_ADMIN_PERM,
880 },
881};
882
883static const struct nla_policy
884mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
885 [MPTCP_PM_ADDR_ATTR_FAMILY] = { .type = NLA_U16, },
886 [MPTCP_PM_ADDR_ATTR_ID] = { .type = NLA_U8, },
887 [MPTCP_PM_ADDR_ATTR_ADDR4] = { .type = NLA_U32, },
888 [MPTCP_PM_ADDR_ATTR_ADDR6] =
889 NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
890 [MPTCP_PM_ADDR_ATTR_PORT] = { .type = NLA_U16 },
891 [MPTCP_PM_ADDR_ATTR_FLAGS] = { .type = NLA_U32 },
892 [MPTCP_PM_ADDR_ATTR_IF_IDX] = { .type = NLA_S32 },
893};
894
895static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
896 [MPTCP_PM_ATTR_ADDR] =
897 NLA_POLICY_NESTED(mptcp_pm_addr_policy),
898 [MPTCP_PM_ATTR_RCV_ADD_ADDRS] = { .type = NLA_U32, },
899 [MPTCP_PM_ATTR_SUBFLOWS] = { .type = NLA_U32, },
900};
901
902static int mptcp_pm_family_to_addr(int family)
903{
904#if IS_ENABLED(CONFIG_MPTCP_IPV6)
905 if (family == AF_INET6)
906 return MPTCP_PM_ADDR_ATTR_ADDR6;
907#endif
908 return MPTCP_PM_ADDR_ATTR_ADDR4;
909}
910
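/* parse a nested MPTCP_PM_ATTR_ADDR attribute into @entry */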
911static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
912 bool require_family,
913 struct mptcp_pm_addr_entry *entry)
914{
915 struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
916 int err, addr_addr;
917
918 if (!attr) {
919 GENL_SET_ERR_MSG(info, "missing address info");
920 return -EINVAL;
921 }
922
 /* no validation needed - was already done via nested policy */
924 err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
925 mptcp_pm_addr_policy, info->extack);
926 if (err)
927 return err;
928
929 memset(entry, 0, sizeof(*entry));
930 if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
931 if (!require_family)
932 goto skip_family;
933
934 NL_SET_ERR_MSG_ATTR(info->extack, attr,
935 "missing family");
936 return -EINVAL;
937 }
938
939 entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
940 if (entry->addr.family != AF_INET
941#if IS_ENABLED(CONFIG_MPTCP_IPV6)
942 && entry->addr.family != AF_INET6
943#endif
944 ) {
945 NL_SET_ERR_MSG_ATTR(info->extack, attr,
946 "unknown address family");
947 return -EINVAL;
948 }
949 addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
950 if (!tb[addr_addr]) {
951 NL_SET_ERR_MSG_ATTR(info->extack, attr,
952 "missing address data");
953 return -EINVAL;
954 }
955
956#if IS_ENABLED(CONFIG_MPTCP_IPV6)
957 if (entry->addr.family == AF_INET6)
958 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
959 else
960#endif
961 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
962
963skip_family:
964 if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
965 u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
966
967 entry->ifindex = val;
968 }
969
970 if (tb[MPTCP_PM_ADDR_ATTR_ID])
971 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
972
973 if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
974 entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
975
976 if (tb[MPTCP_PM_ADDR_ATTR_PORT]) {
977 if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
978 NL_SET_ERR_MSG_ATTR(info->extack, attr,
979 "flags must have signal when using port");
980 return -EINVAL;
981 }
982 entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
983 }
984
985 return 0;
986}
987
988static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
989{
990 return net_generic(genl_info_net(info), pm_nl_pernet_id);
991}
992
993static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
994{
995 struct mptcp_sock *msk;
996 long s_slot = 0, s_num = 0;
997
998 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
999 struct sock *sk = (struct sock *)msk;
1000
1001 if (!READ_ONCE(msk->fully_established))
1002 goto next;
1003
1004 lock_sock(sk);
1005 spin_lock_bh(&msk->pm.lock);
1006 mptcp_pm_create_subflow_or_signal_addr(msk);
1007 spin_unlock_bh(&msk->pm.lock);
1008 release_sock(sk);
1009
1010next:
1011 sock_put(sk);
1012 cond_resched();
1013 }
1014
1015 return 0;
1016}
1017
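/* MPTCP_PM_CMD_ADD_ADDR: add a new endpoint and try to use it on the
 * already established connections
 */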
1018static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
1019{
1020 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1021 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1022 struct mptcp_pm_addr_entry addr, *entry;
1023 int ret;
1024
1025 ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1026 if (ret < 0)
1027 return ret;
1028
1029 entry = kmalloc(sizeof(*entry), GFP_KERNEL);
1030 if (!entry) {
1031 GENL_SET_ERR_MSG(info, "can't allocate addr");
1032 return -ENOMEM;
1033 }
1034
1035 *entry = addr;
1036 if (entry->addr.port) {
1037 ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
1038 if (ret) {
1039 GENL_SET_ERR_MSG(info, "create listen socket error");
1040 kfree(entry);
1041 return ret;
1042 }
1043 }
1044 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1045 if (ret < 0) {
1046 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
1047 if (entry->lsk)
1048 sock_release(entry->lsk);
1049 kfree(entry);
1050 return ret;
1051 }
1052
1053 mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
1054
1055 return 0;
1056}
1057
1058static struct mptcp_pm_addr_entry *
1059__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
1060{
1061 struct mptcp_pm_addr_entry *entry;
1062
1063 list_for_each_entry(entry, &pernet->local_addr_list, list) {
1064 if (entry->addr.id == id)
1065 return entry;
1066 }
1067 return NULL;
1068}
1069
1070static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1071 struct mptcp_addr_info *addr)
1072{
1073 struct mptcp_pm_add_entry *entry;
1074
1075 entry = mptcp_pm_del_add_timer(msk, addr, false);
1076 if (entry) {
1077 list_del(&entry->list);
1078 kfree(entry);
1079 return true;
1080 }
1081
1082 return false;
1083}
1084
1085static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1086 struct mptcp_addr_info *addr,
1087 bool force)
1088{
1089 struct mptcp_rm_list list = { .nr = 0 };
1090 bool ret;
1091
1092 list.ids[list.nr++] = addr->id;
1093
1094 ret = remove_anno_list_by_saddr(msk, addr);
1095 if (ret || force) {
1096 spin_lock_bh(&msk->pm.lock);
1097 mptcp_pm_remove_addr(msk, &list);
1098 spin_unlock_bh(&msk->pm.lock);
1099 }
1100 return ret;
1101}
1102
1103static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
1104 struct mptcp_addr_info *addr)
1105{
1106 struct mptcp_sock *msk;
1107 long s_slot = 0, s_num = 0;
1108 struct mptcp_rm_list list = { .nr = 0 };
1109
1110 pr_debug("remove_id=%d", addr->id);
1111
1112 list.ids[list.nr++] = addr->id;
1113
1114 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1115 struct sock *sk = (struct sock *)msk;
1116 bool remove_subflow;
1117
1118 if (list_empty(&msk->conn_list)) {
1119 mptcp_pm_remove_anno_addr(msk, addr, false);
1120 goto next;
1121 }
1122
1123 lock_sock(sk);
1124 remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
1125 mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
1126 if (remove_subflow)
1127 mptcp_pm_remove_subflow(msk, &list);
1128 release_sock(sk);
1129
1130next:
1131 sock_put(sk);
1132 cond_resched();
1133 }
1134
1135 return 0;
1136}
1137
/* caller must ensure the RCU grace period is already elapsed */
1139static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)
1140{
1141 if (entry->lsk)
1142 sock_release(entry->lsk);
1143 kfree(entry);
1144}
1145
1146static int mptcp_nl_remove_id_zero_address(struct net *net,
1147 struct mptcp_addr_info *addr)
1148{
1149 struct mptcp_rm_list list = { .nr = 0 };
1150 long s_slot = 0, s_num = 0;
1151 struct mptcp_sock *msk;
1152
1153 list.ids[list.nr++] = 0;
1154
1155 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1156 struct sock *sk = (struct sock *)msk;
1157 struct mptcp_addr_info msk_local;
1158
1159 if (list_empty(&msk->conn_list))
1160 goto next;
1161
1162 local_address((struct sock_common *)msk, &msk_local);
1163 if (!addresses_equal(&msk_local, addr, addr->port))
1164 goto next;
1165
1166 lock_sock(sk);
1167 spin_lock_bh(&msk->pm.lock);
1168 mptcp_pm_remove_addr(msk, &list);
1169 mptcp_pm_nl_rm_subflow_received(msk, &list);
1170 spin_unlock_bh(&msk->pm.lock);
1171 release_sock(sk);
1172
1173next:
1174 sock_put(sk);
1175 cond_resched();
1176 }
1177
1178 return 0;
1179}
1180
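/* MPTCP_PM_CMD_DEL_ADDR: remove an endpoint, tearing down the related
 * announcements and subflows on the existing connections
 */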
1181static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
1182{
1183 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1184 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1185 struct mptcp_pm_addr_entry addr, *entry;
1186 unsigned int addr_max;
1187 int ret;
1188
1189 ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1190 if (ret < 0)
1191 return ret;
1192
 /* the zero id address is special: it refers to the msk initial
 * address and is not stored in the pernet endpoint list, handle
 * its removal separately
 */
1198 if (addr.addr.id == 0)
1199 return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);
1200
1201 spin_lock_bh(&pernet->lock);
1202 entry = __lookup_addr_by_id(pernet, addr.addr.id);
1203 if (!entry) {
1204 GENL_SET_ERR_MSG(info, "address not found");
1205 spin_unlock_bh(&pernet->lock);
1206 return -EINVAL;
1207 }
1208 if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
1209 addr_max = pernet->add_addr_signal_max;
1210 WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
1211 }
1212 if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
1213 addr_max = pernet->local_addr_max;
1214 WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
1215 }
1216
1217 pernet->addrs--;
1218 list_del_rcu(&entry->list);
1219 __clear_bit(entry->addr.id, pernet->id_bitmap);
1220 spin_unlock_bh(&pernet->lock);
1221
1222 mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
1223 synchronize_rcu();
1224 __mptcp_pm_release_addr_entry(entry);
1225
1226 return ret;
1227}
1228
1229static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
1230 struct list_head *rm_list)
1231{
1232 struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
1233 struct mptcp_pm_addr_entry *entry;
1234
1235 list_for_each_entry(entry, rm_list, list) {
1236 if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
1237 alist.nr < MPTCP_RM_IDS_MAX &&
1238 slist.nr < MPTCP_RM_IDS_MAX) {
1239 alist.ids[alist.nr++] = entry->addr.id;
1240 slist.ids[slist.nr++] = entry->addr.id;
1241 } else if (remove_anno_list_by_saddr(msk, &entry->addr) &&
1242 alist.nr < MPTCP_RM_IDS_MAX) {
1243 alist.ids[alist.nr++] = entry->addr.id;
1244 }
1245 }
1246
1247 if (alist.nr) {
1248 spin_lock_bh(&msk->pm.lock);
1249 mptcp_pm_remove_addr(msk, &alist);
1250 spin_unlock_bh(&msk->pm.lock);
1251 }
1252 if (slist.nr)
1253 mptcp_pm_remove_subflow(msk, &slist);
1254}
1255
1256static void mptcp_nl_remove_addrs_list(struct net *net,
1257 struct list_head *rm_list)
1258{
1259 long s_slot = 0, s_num = 0;
1260 struct mptcp_sock *msk;
1261
1262 if (list_empty(rm_list))
1263 return;
1264
1265 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1266 struct sock *sk = (struct sock *)msk;
1267
1268 lock_sock(sk);
1269 mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
1270 release_sock(sk);
1271
1272 sock_put(sk);
1273 cond_resched();
1274 }
1275}
1276
/* caller must ensure the RCU grace period is already elapsed */
1278static void __flush_addrs(struct list_head *list)
1279{
1280 while (!list_empty(list)) {
1281 struct mptcp_pm_addr_entry *cur;
1282
1283 cur = list_entry(list->next,
1284 struct mptcp_pm_addr_entry, list);
1285 list_del_rcu(&cur->list);
1286 __mptcp_pm_release_addr_entry(cur);
1287 }
1288}
1289
1290static void __reset_counters(struct pm_nl_pernet *pernet)
1291{
1292 WRITE_ONCE(pernet->add_addr_signal_max, 0);
1293 WRITE_ONCE(pernet->add_addr_accept_max, 0);
1294 WRITE_ONCE(pernet->local_addr_max, 0);
1295 pernet->addrs = 0;
1296}
1297
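/* MPTCP_PM_CMD_FLUSH_ADDRS: remove all the endpoints and reset the
 * pernet counters
 */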
1298static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1299{
1300 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1301 LIST_HEAD(free_list);
1302
1303 spin_lock_bh(&pernet->lock);
1304 list_splice_init(&pernet->local_addr_list, &free_list);
1305 __reset_counters(pernet);
1306 pernet->next_id = 1;
1307 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1308 spin_unlock_bh(&pernet->lock);
1309 mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
1310 synchronize_rcu();
1311 __flush_addrs(&free_list);
1312 return 0;
1313}
1314
1315static int mptcp_nl_fill_addr(struct sk_buff *skb,
1316 struct mptcp_pm_addr_entry *entry)
1317{
1318 struct mptcp_addr_info *addr = &entry->addr;
1319 struct nlattr *attr;
1320
1321 attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
1322 if (!attr)
1323 return -EMSGSIZE;
1324
1325 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
1326 goto nla_put_failure;
1327 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
1328 goto nla_put_failure;
1329 if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
1330 goto nla_put_failure;
1331 if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
1332 goto nla_put_failure;
1333 if (entry->ifindex &&
1334 nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
1335 goto nla_put_failure;
1336
1337 if (addr->family == AF_INET &&
1338 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
1339 addr->addr.s_addr))
1340 goto nla_put_failure;
1341#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1342 else if (addr->family == AF_INET6 &&
1343 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
1344 goto nla_put_failure;
1345#endif
1346 nla_nest_end(skb, attr);
1347 return 0;
1348
1349nla_put_failure:
1350 nla_nest_cancel(skb, attr);
1351 return -EMSGSIZE;
1352}
1353
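/* MPTCP_PM_CMD_GET_ADDR: look up a single endpoint by id and dump it */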
1354static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
1355{
1356 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1357 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1358 struct mptcp_pm_addr_entry addr, *entry;
1359 struct sk_buff *msg;
1360 void *reply;
1361 int ret;
1362
1363 ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1364 if (ret < 0)
1365 return ret;
1366
1367 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1368 if (!msg)
1369 return -ENOMEM;
1370
1371 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1372 info->genlhdr->cmd);
1373 if (!reply) {
1374 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1375 ret = -EMSGSIZE;
1376 goto fail;
1377 }
1378
1379 spin_lock_bh(&pernet->lock);
1380 entry = __lookup_addr_by_id(pernet, addr.addr.id);
1381 if (!entry) {
1382 GENL_SET_ERR_MSG(info, "address not found");
1383 ret = -EINVAL;
1384 goto unlock_fail;
1385 }
1386
1387 ret = mptcp_nl_fill_addr(msg, entry);
1388 if (ret)
1389 goto unlock_fail;
1390
1391 genlmsg_end(msg, reply);
1392 ret = genlmsg_reply(msg, info);
1393 spin_unlock_bh(&pernet->lock);
1394 return ret;
1395
1396unlock_fail:
1397 spin_unlock_bh(&pernet->lock);
1398
1399fail:
1400 nlmsg_free(msg);
1401 return ret;
1402}
1403
1404static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
1405 struct netlink_callback *cb)
1406{
1407 struct net *net = sock_net(msg->sk);
1408 struct mptcp_pm_addr_entry *entry;
1409 struct pm_nl_pernet *pernet;
1410 int id = cb->args[0];
1411 void *hdr;
1412 int i;
1413
1414 pernet = net_generic(net, pm_nl_pernet_id);
1415
1416 spin_lock_bh(&pernet->lock);
1417 for (i = id; i < MAX_ADDR_ID + 1; i++) {
1418 if (test_bit(i, pernet->id_bitmap)) {
1419 entry = __lookup_addr_by_id(pernet, i);
1420 if (!entry)
1421 break;
1422
1423 if (entry->addr.id <= id)
1424 continue;
1425
1426 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
1427 cb->nlh->nlmsg_seq, &mptcp_genl_family,
1428 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
1429 if (!hdr)
1430 break;
1431
1432 if (mptcp_nl_fill_addr(msg, entry) < 0) {
1433 genlmsg_cancel(msg, hdr);
1434 break;
1435 }
1436
1437 id = entry->addr.id;
1438 genlmsg_end(msg, hdr);
1439 }
1440 }
1441 spin_unlock_bh(&pernet->lock);
1442
1443 cb->args[0] = id;
1444 return msg->len;
1445}
1446
1447static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1448{
1449 struct nlattr *attr = info->attrs[id];
1450
1451 if (!attr)
1452 return 0;
1453
1454 *limit = nla_get_u32(attr);
1455 if (*limit > MPTCP_PM_ADDR_MAX) {
1456 GENL_SET_ERR_MSG(info, "limit greater than maximum");
1457 return -EINVAL;
1458 }
1459 return 0;
1460}
1461
1462static int
1463mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1464{
1465 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1466 unsigned int rcv_addrs, subflows;
1467 int ret;
1468
1469 spin_lock_bh(&pernet->lock);
1470 rcv_addrs = pernet->add_addr_accept_max;
1471 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1472 if (ret)
1473 goto unlock;
1474
1475 subflows = pernet->subflows_max;
1476 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1477 if (ret)
1478 goto unlock;
1479
1480 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1481 WRITE_ONCE(pernet->subflows_max, subflows);
1482
1483unlock:
1484 spin_unlock_bh(&pernet->lock);
1485 return ret;
1486}
1487
1488static int
1489mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1490{
1491 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1492 struct sk_buff *msg;
1493 void *reply;
1494
1495 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1496 if (!msg)
1497 return -ENOMEM;
1498
1499 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1500 MPTCP_PM_CMD_GET_LIMITS);
1501 if (!reply)
1502 goto fail;
1503
1504 if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1505 READ_ONCE(pernet->add_addr_accept_max)))
1506 goto fail;
1507
1508 if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1509 READ_ONCE(pernet->subflows_max)))
1510 goto fail;
1511
1512 genlmsg_end(msg, reply);
1513 return genlmsg_reply(msg, info);
1514
1515fail:
1516 GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1517 nlmsg_free(msg);
1518 return -EMSGSIZE;
1519}
1520
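/* walk all MPTCP sockets and send an MP_PRIO update on the subflows
 * bound to @addr
 */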
1521static int mptcp_nl_addr_backup(struct net *net,
1522 struct mptcp_addr_info *addr,
1523 u8 bkup)
1524{
1525 long s_slot = 0, s_num = 0;
1526 struct mptcp_sock *msk;
1527 int ret = -EINVAL;
1528
1529 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1530 struct sock *sk = (struct sock *)msk;
1531
1532 if (list_empty(&msk->conn_list))
1533 goto next;
1534
1535 lock_sock(sk);
1536 spin_lock_bh(&msk->pm.lock);
1537 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
1538 spin_unlock_bh(&msk->pm.lock);
1539 release_sock(sk);
1540
1541next:
1542 sock_put(sk);
1543 cond_resched();
1544 }
1545
1546 return ret;
1547}
1548
1549static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1550{
1551 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1552 struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1553 struct mptcp_pm_addr_entry addr, *entry;
1554 struct net *net = sock_net(skb->sk);
1555 u8 bkup = 0;
1556 int ret;
1557
1558 ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1559 if (ret < 0)
1560 return ret;
1561
1562 if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1563 bkup = 1;
1564
1565 list_for_each_entry(entry, &pernet->local_addr_list, list) {
1566 if (addresses_equal(&entry->addr, &addr.addr, true)) {
1567 ret = mptcp_nl_addr_backup(net, &entry->addr, bkup);
1568 if (ret)
1569 return ret;
1570
1571 if (bkup)
1572 entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
1573 else
1574 entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
1575 }
1576 }
1577
1578 return 0;
1579}
1580
1581static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
1582{
1583 genlmsg_multicast_netns(&mptcp_genl_family, net,
1584 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
1585}
1586
1587static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
1588{
1589 const struct inet_sock *issk = inet_sk(ssk);
1590 const struct mptcp_subflow_context *sf;
1591
1592 if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
1593 return -EMSGSIZE;
1594
1595 switch (ssk->sk_family) {
1596 case AF_INET:
1597 if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
1598 return -EMSGSIZE;
1599 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
1600 return -EMSGSIZE;
1601 break;
1602#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1603 case AF_INET6: {
1604 const struct ipv6_pinfo *np = inet6_sk(ssk);
1605
1606 if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
1607 return -EMSGSIZE;
1608 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
1609 return -EMSGSIZE;
1610 break;
1611 }
1612#endif
1613 default:
1614 WARN_ON_ONCE(1);
1615 return -EMSGSIZE;
1616 }
1617
1618 if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
1619 return -EMSGSIZE;
1620 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
1621 return -EMSGSIZE;
1622
1623 sf = mptcp_subflow_ctx(ssk);
1624 if (WARN_ON_ONCE(!sf))
1625 return -EINVAL;
1626
1627 if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
1628 return -EMSGSIZE;
1629
1630 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
1631 return -EMSGSIZE;
1632
1633 return 0;
1634}
1635
1636static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
1637 const struct mptcp_sock *msk,
1638 const struct sock *ssk)
1639{
1640 const struct sock *sk = (const struct sock *)msk;
1641 const struct mptcp_subflow_context *sf;
1642 u8 sk_err;
1643
1644 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1645 return -EMSGSIZE;
1646
1647 if (mptcp_event_add_subflow(skb, ssk))
1648 return -EMSGSIZE;
1649
1650 sf = mptcp_subflow_ctx(ssk);
1651 if (WARN_ON_ONCE(!sf))
1652 return -EINVAL;
1653
1654 if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
1655 return -EMSGSIZE;
1656
1657 if (ssk->sk_bound_dev_if &&
1658 nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
1659 return -EMSGSIZE;
1660
1661 sk_err = ssk->sk_err;
1662 if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
1663 nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
1664 return -EMSGSIZE;
1665
1666 return 0;
1667}
1668
1669static int mptcp_event_sub_established(struct sk_buff *skb,
1670 const struct mptcp_sock *msk,
1671 const struct sock *ssk)
1672{
1673 return mptcp_event_put_token_and_ssk(skb, msk, ssk);
1674}
1675
1676static int mptcp_event_sub_closed(struct sk_buff *skb,
1677 const struct mptcp_sock *msk,
1678 const struct sock *ssk)
1679{
1680 const struct mptcp_subflow_context *sf;
1681
1682 if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
1683 return -EMSGSIZE;
1684
1685 sf = mptcp_subflow_ctx(ssk);
1686 if (!sf->reset_seen)
1687 return 0;
1688
1689 if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
1690 return -EMSGSIZE;
1691
1692 if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
1693 return -EMSGSIZE;
1694
1695 return 0;
1696}
1697
1698static int mptcp_event_created(struct sk_buff *skb,
1699 const struct mptcp_sock *msk,
1700 const struct sock *ssk)
1701{
1702 int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
1703
1704 if (err)
1705 return err;
1706
1707 return mptcp_event_add_subflow(skb, ssk);
1708}
1709
1710void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
1711{
1712 struct net *net = sock_net((const struct sock *)msk);
1713 struct nlmsghdr *nlh;
1714 struct sk_buff *skb;
1715
1716 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1717 return;
1718
1719 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1720 if (!skb)
1721 return;
1722
1723 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
1724 if (!nlh)
1725 goto nla_put_failure;
1726
1727 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1728 goto nla_put_failure;
1729
1730 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
1731 goto nla_put_failure;
1732
1733 genlmsg_end(skb, nlh);
1734 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1735 return;
1736
1737nla_put_failure:
1738 kfree_skb(skb);
1739}
1740
1741void mptcp_event_addr_announced(const struct mptcp_sock *msk,
1742 const struct mptcp_addr_info *info)
1743{
1744 struct net *net = sock_net((const struct sock *)msk);
1745 struct nlmsghdr *nlh;
1746 struct sk_buff *skb;
1747
1748 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1749 return;
1750
1751 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1752 if (!skb)
1753 return;
1754
1755 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
1756 MPTCP_EVENT_ANNOUNCED);
1757 if (!nlh)
1758 goto nla_put_failure;
1759
1760 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1761 goto nla_put_failure;
1762
1763 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
1764 goto nla_put_failure;
1765
1766 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
1767 goto nla_put_failure;
1768
1769 switch (info->family) {
1770 case AF_INET:
1771 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
1772 goto nla_put_failure;
1773 break;
1774#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1775 case AF_INET6:
1776 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
1777 goto nla_put_failure;
1778 break;
1779#endif
1780 default:
1781 WARN_ON_ONCE(1);
1782 goto nla_put_failure;
1783 }
1784
1785 genlmsg_end(skb, nlh);
1786 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1787 return;
1788
1789nla_put_failure:
1790 kfree_skb(skb);
1791}
1792
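/* generate and multicast a PM netlink event of the given @type; the
 * ANNOUNCED and REMOVED events are built by dedicated helpers instead
 */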
1793void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
1794 const struct sock *ssk, gfp_t gfp)
1795{
1796 struct net *net = sock_net((const struct sock *)msk);
1797 struct nlmsghdr *nlh;
1798 struct sk_buff *skb;
1799
1800 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1801 return;
1802
1803 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
1804 if (!skb)
1805 return;
1806
1807 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
1808 if (!nlh)
1809 goto nla_put_failure;
1810
1811 switch (type) {
1812 case MPTCP_EVENT_UNSPEC:
1813 WARN_ON_ONCE(1);
1814 break;
1815 case MPTCP_EVENT_CREATED:
1816 case MPTCP_EVENT_ESTABLISHED:
1817 if (mptcp_event_created(skb, msk, ssk) < 0)
1818 goto nla_put_failure;
1819 break;
1820 case MPTCP_EVENT_CLOSED:
1821 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
1822 goto nla_put_failure;
1823 break;
1824 case MPTCP_EVENT_ANNOUNCED:
1825 case MPTCP_EVENT_REMOVED:
 /* generated via mptcp_event_addr_announced()/_removed() instead */
1827 WARN_ON_ONCE(1);
1828 break;
1829 case MPTCP_EVENT_SUB_ESTABLISHED:
1830 case MPTCP_EVENT_SUB_PRIORITY:
1831 if (mptcp_event_sub_established(skb, msk, ssk) < 0)
1832 goto nla_put_failure;
1833 break;
1834 case MPTCP_EVENT_SUB_CLOSED:
1835 if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
1836 goto nla_put_failure;
1837 break;
1838 }
1839
1840 genlmsg_end(skb, nlh);
1841 mptcp_nl_mcast_send(net, skb, gfp);
1842 return;
1843
1844nla_put_failure:
1845 kfree_skb(skb);
1846}
1847
1848static const struct genl_small_ops mptcp_pm_ops[] = {
1849 {
1850 .cmd = MPTCP_PM_CMD_ADD_ADDR,
1851 .doit = mptcp_nl_cmd_add_addr,
1852 .flags = GENL_ADMIN_PERM,
1853 },
1854 {
1855 .cmd = MPTCP_PM_CMD_DEL_ADDR,
1856 .doit = mptcp_nl_cmd_del_addr,
1857 .flags = GENL_ADMIN_PERM,
1858 },
1859 {
1860 .cmd = MPTCP_PM_CMD_FLUSH_ADDRS,
1861 .doit = mptcp_nl_cmd_flush_addrs,
1862 .flags = GENL_ADMIN_PERM,
1863 },
1864 {
1865 .cmd = MPTCP_PM_CMD_GET_ADDR,
1866 .doit = mptcp_nl_cmd_get_addr,
1867 .dumpit = mptcp_nl_cmd_dump_addrs,
1868 },
1869 {
1870 .cmd = MPTCP_PM_CMD_SET_LIMITS,
1871 .doit = mptcp_nl_cmd_set_limits,
1872 .flags = GENL_ADMIN_PERM,
1873 },
1874 {
1875 .cmd = MPTCP_PM_CMD_GET_LIMITS,
1876 .doit = mptcp_nl_cmd_get_limits,
1877 },
1878 {
1879 .cmd = MPTCP_PM_CMD_SET_FLAGS,
1880 .doit = mptcp_nl_cmd_set_flags,
1881 .flags = GENL_ADMIN_PERM,
1882 },
1883};
1884
1885static struct genl_family mptcp_genl_family __ro_after_init = {
1886 .name = MPTCP_PM_NAME,
1887 .version = MPTCP_PM_VER,
1888 .maxattr = MPTCP_PM_ATTR_MAX,
1889 .policy = mptcp_pm_policy,
1890 .netnsok = true,
1891 .module = THIS_MODULE,
1892 .small_ops = mptcp_pm_ops,
1893 .n_small_ops = ARRAY_SIZE(mptcp_pm_ops),
1894 .mcgrps = mptcp_pm_mcgrps,
1895 .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps),
1896};
1897
1898static int __net_init pm_nl_init_net(struct net *net)
1899{
1900 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
1901
1902 INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
1903 pernet->next_id = 1;
1904 spin_lock_init(&pernet->lock);
 /* no need to initialize the other pernet fields, the per-netns
 * storage is zeroed at allocation time
 */

1910 return 0;
1911}
1912
1913static void __net_exit pm_nl_exit_net(struct list_head *net_list)
1914{
1915 struct net *net;
1916
1917 list_for_each_entry(net, net_list, exit_list) {
1918 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
 /* net is removed from the namespace list, it can't race with other
 * modifiers, and the netns core already waited for a RCU grace
 * period
 */
1924 __flush_addrs(&pernet->local_addr_list);
1925 }
1926}
1927
1928static struct pernet_operations mptcp_pm_pernet_ops = {
1929 .init = pm_nl_init_net,
1930 .exit_batch = pm_nl_exit_net,
1931 .id = &pm_nl_pernet_id,
1932 .size = sizeof(struct pm_nl_pernet),
1933};
1934
1935void __init mptcp_pm_nl_init(void)
1936{
1937 if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
1938 panic("Failed to register MPTCP PM pernet subsystem.\n");
1939
1940 if (genl_register_family(&mptcp_genl_family))
1941 panic("Failed to register MPTCP PM netlink family\n");
1942}
1943