linux/drivers/vhost/net.c
<<
>>
Prefs
   1/* Copyright (C) 2009 Red Hat, Inc.
   2 * Author: Michael S. Tsirkin <mst@redhat.com>
   3 *
   4 * This work is licensed under the terms of the GNU GPL, version 2.
   5 *
   6 * virtio-net server in host kernel.
   7 */
   8
   9#include <linux/compat.h>
  10#include <linux/eventfd.h>
  11#include <linux/vhost.h>
  12#include <linux/virtio_net.h>
  13#include <linux/miscdevice.h>
  14#include <linux/module.h>
  15#include <linux/moduleparam.h>
  16#include <linux/mutex.h>
  17#include <linux/workqueue.h>
  18#include <linux/rcupdate.h>
  19#include <linux/file.h>
  20#include <linux/slab.h>
  21
  22#include <linux/net.h>
  23#include <linux/if_packet.h>
  24#include <linux/if_arp.h>
  25#include <linux/if_tun.h>
  26#include <linux/if_macvlan.h>
  27
  28#include <net/sock.h>
  29
  30#include "vhost.h"
  31
/* Zero-copy TX is off unless experimental_zcopytx is set non-zero when the
 * module is loaded (permission 0444: the value is read-only afterwards).
 * Per-socket use additionally requires SOCK_ZEROCOPY; see vhost_sock_zcopy(). */
static int experimental_zcopytx;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Experimental Zero Copy TX");

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

/* MAX number of TX used buffers for outstanding zerocopy */
#define VHOST_MAX_PEND 128
/* TX packets shorter than this are sent by copy even in zero-copy mode,
 * so they need not wait for DMA completion. */
#define VHOST_GOODCOPY_LEN 256
  43
  44enum {
  45        VHOST_NET_VQ_RX = 0,
  46        VHOST_NET_VQ_TX = 1,
  47        VHOST_NET_VQ_MAX = 2,
  48};
  49
  50enum vhost_net_poll_state {
  51        VHOST_NET_POLL_DISABLED = 0,
  52        VHOST_NET_POLL_STARTED = 1,
  53        VHOST_NET_POLL_STOPPED = 2,
  54};
  55
/* Per-open-file state of the vhost-net device (stored in f->private_data
 * by vhost_net_open()). */
struct vhost_net {
	struct vhost_dev dev;
	/* One virtqueue per direction, indexed by VHOST_NET_VQ_RX / _TX. */
	struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
	/* Backend socket polls with the same indexing: POLLIN for RX,
	 * POLLOUT for TX (set up in vhost_net_open()). */
	struct vhost_poll poll[VHOST_NET_VQ_MAX];
	/* Tells us whether we are polling a socket for TX.
	 * We only do this when socket buffer fills up.
	 * Protected by tx vq lock. */
	enum vhost_net_poll_state tx_poll_state;
};
  65
  66static bool vhost_sock_zcopy(struct socket *sock)
  67{
  68        return unlikely(experimental_zcopytx) &&
  69                sock_flag(sock->sk, SOCK_ZEROCOPY);
  70}
  71
  72/* Pop first len bytes from iovec. Return number of segments used. */
  73static int move_iovec_hdr(struct iovec *from, struct iovec *to,
  74                          size_t len, int iov_count)
  75{
  76        int seg = 0;
  77        size_t size;
  78
  79        while (len && seg < iov_count) {
  80                size = min(from->iov_len, len);
  81                to->iov_base = from->iov_base;
  82                to->iov_len = size;
  83                from->iov_len -= size;
  84                from->iov_base += size;
  85                len -= size;
  86                ++from;
  87                ++to;
  88                ++seg;
  89        }
  90        return seg;
  91}
  92/* Copy iovec entries for len bytes from iovec. */
  93static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
  94                           size_t len, int iovcount)
  95{
  96        int seg = 0;
  97        size_t size;
  98
  99        while (len && seg < iovcount) {
 100                size = min(from->iov_len, len);
 101                to->iov_base = from->iov_base;
 102                to->iov_len = size;
 103                len -= size;
 104                ++from;
 105                ++to;
 106                ++seg;
 107        }
 108}
 109
 110/* Caller must have TX VQ lock */
 111static void tx_poll_stop(struct vhost_net *net)
 112{
 113        if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
 114                return;
 115        vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
 116        net->tx_poll_state = VHOST_NET_POLL_STOPPED;
 117}
 118
 119/* Caller must have TX VQ lock */
 120static void tx_poll_start(struct vhost_net *net, struct socket *sock)
 121{
 122        if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
 123                return;
 124        vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
 125        net->tx_poll_state = VHOST_NET_POLL_STARTED;
 126}
 127
 128/* Expects to be always run from workqueue - which acts as
 129 * read-size critical section for our kind of RCU. */
 130static void handle_tx(struct vhost_net *net)
 131{
 132        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
 133        unsigned out, in, s;
 134        int head;
 135        struct msghdr msg = {
 136                .msg_name = NULL,
 137                .msg_namelen = 0,
 138                .msg_control = NULL,
 139                .msg_controllen = 0,
 140                .msg_iov = vq->iov,
 141                .msg_flags = MSG_DONTWAIT,
 142        };
 143        size_t len, total_len = 0;
 144        int err, wmem;
 145        size_t hdr_size;
 146        struct socket *sock;
 147        struct vhost_ubuf_ref *uninitialized_var(ubufs);
 148        bool zcopy;
 149
 150        /* TODO: check that we are running from vhost_worker? */
 151        sock = rcu_dereference_check(vq->private_data, 1);
 152        if (!sock)
 153                return;
 154
 155        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
 156        if (wmem >= sock->sk->sk_sndbuf) {
 157                mutex_lock(&vq->mutex);
 158                tx_poll_start(net, sock);
 159                mutex_unlock(&vq->mutex);
 160                return;
 161        }
 162
 163        mutex_lock(&vq->mutex);
 164        vhost_disable_notify(&net->dev, vq);
 165
 166        if (wmem < sock->sk->sk_sndbuf / 2)
 167                tx_poll_stop(net);
 168        hdr_size = vq->vhost_hlen;
 169        zcopy = vhost_sock_zcopy(sock);
 170
 171        for (;;) {
 172                /* Release DMAs done buffers first */
 173                if (zcopy)
 174                        vhost_zerocopy_signal_used(vq);
 175
 176                head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
 177                                         ARRAY_SIZE(vq->iov),
 178                                         &out, &in,
 179                                         NULL, NULL);
 180                /* On error, stop handling until the next kick. */
 181                if (unlikely(head < 0))
 182                        break;
 183                /* Nothing new?  Wait for eventfd to tell us they refilled. */
 184                if (head == vq->num) {
 185                        int num_pends;
 186
 187                        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
 188                        if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
 189                                tx_poll_start(net, sock);
 190                                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
 191                                break;
 192                        }
 193                        /* If more outstanding DMAs, queue the work.
 194                         * Handle upend_idx wrap around
 195                         */
 196                        num_pends = likely(vq->upend_idx >= vq->done_idx) ?
 197                                    (vq->upend_idx - vq->done_idx) :
 198                                    (vq->upend_idx + UIO_MAXIOV - vq->done_idx);
 199                        if (unlikely(num_pends > VHOST_MAX_PEND)) {
 200                                tx_poll_start(net, sock);
 201                                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
 202                                break;
 203                        }
 204                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 205                                vhost_disable_notify(&net->dev, vq);
 206                                continue;
 207                        }
 208                        break;
 209                }
 210                if (in) {
 211                        vq_err(vq, "Unexpected descriptor format for TX: "
 212                               "out %d, int %d\n", out, in);
 213                        break;
 214                }
 215                /* Skip header. TODO: support TSO. */
 216                s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
 217                msg.msg_iovlen = out;
 218                len = iov_length(vq->iov, out);
 219                /* Sanity check */
 220                if (!len) {
 221                        vq_err(vq, "Unexpected header len for TX: "
 222                               "%zd expected %zd\n",
 223                               iov_length(vq->hdr, s), hdr_size);
 224                        break;
 225                }
 226                /* use msg_control to pass vhost zerocopy ubuf info to skb */
 227                if (zcopy) {
 228                        vq->heads[vq->upend_idx].id = head;
 229                        if (len < VHOST_GOODCOPY_LEN) {
 230                                /* copy don't need to wait for DMA done */
 231                                vq->heads[vq->upend_idx].len =
 232                                                        VHOST_DMA_DONE_LEN;
 233                                msg.msg_control = NULL;
 234                                msg.msg_controllen = 0;
 235                                ubufs = NULL;
 236                        } else {
 237                                struct ubuf_info *ubuf = &vq->ubuf_info[head];
 238
 239                                vq->heads[vq->upend_idx].len = len;
 240                                ubuf->callback = vhost_zerocopy_callback;
 241                                ubuf->arg = vq->ubufs;
 242                                ubuf->desc = vq->upend_idx;
 243                                msg.msg_control = ubuf;
 244                                msg.msg_controllen = sizeof(ubuf);
 245                                ubufs = vq->ubufs;
 246                                kref_get(&ubufs->kref);
 247                        }
 248                        vq->upend_idx = (vq->upend_idx + 1) % UIO_MAXIOV;
 249                }
 250                /* TODO: Check specific error and bomb out unless ENOBUFS? */
 251                err = sock->ops->sendmsg(NULL, sock, &msg, len);
 252                if (unlikely(err < 0)) {
 253                        if (zcopy) {
 254                                if (ubufs)
 255                                        vhost_ubuf_put(ubufs);
 256                                vq->upend_idx = ((unsigned)vq->upend_idx - 1) %
 257                                        UIO_MAXIOV;
 258                        }
 259                        vhost_discard_vq_desc(vq, 1);
 260                        tx_poll_start(net, sock);
 261                        break;
 262                }
 263                if (err != len)
 264                        pr_debug("Truncated TX packet: "
 265                                 " len %d != %zd\n", err, len);
 266                if (!zcopy)
 267                        vhost_add_used_and_signal(&net->dev, vq, head, 0);
 268                total_len += len;
 269                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
 270                        vhost_poll_queue(&vq->poll);
 271                        break;
 272                }
 273        }
 274
 275        mutex_unlock(&vq->mutex);
 276}
 277
 278static int peek_head_len(struct sock *sk)
 279{
 280        struct sk_buff *head;
 281        int len = 0;
 282        unsigned long flags;
 283
 284        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
 285        head = skb_peek(&sk->sk_receive_queue);
 286        if (likely(head))
 287                len = head->len;
 288        spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
 289        return len;
 290}
 291
 292/* This is a multi-buffer version of vhost_get_desc, that works if
 293 *      vq has read descriptors only.
 294 * @vq          - the relevant virtqueue
 295 * @datalen     - data length we'll be reading
 296 * @iovcount    - returned count of io vectors we fill
 297 * @log         - vhost log
 298 * @log_num     - log offset
 299 * @quota       - headcount quota, 1 for big buffer
 300 *      returns number of buffer heads allocated, negative on error
 301 */
 302static int get_rx_bufs(struct vhost_virtqueue *vq,
 303                       struct vring_used_elem *heads,
 304                       int datalen,
 305                       unsigned *iovcount,
 306                       struct vhost_log *log,
 307                       unsigned *log_num,
 308                       unsigned int quota)
 309{
 310        unsigned int out, in;
 311        int seg = 0;
 312        int headcount = 0;
 313        unsigned d;
 314        int r, nlogs = 0;
 315
 316        while (datalen > 0 && headcount < quota) {
 317                if (unlikely(seg >= UIO_MAXIOV)) {
 318                        r = -ENOBUFS;
 319                        goto err;
 320                }
 321                d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
 322                                      ARRAY_SIZE(vq->iov) - seg, &out,
 323                                      &in, log, log_num);
 324                if (d == vq->num) {
 325                        r = 0;
 326                        goto err;
 327                }
 328                if (unlikely(out || in <= 0)) {
 329                        vq_err(vq, "unexpected descriptor format for RX: "
 330                                "out %d, in %d\n", out, in);
 331                        r = -EINVAL;
 332                        goto err;
 333                }
 334                if (unlikely(log)) {
 335                        nlogs += *log_num;
 336                        log += *log_num;
 337                }
 338                heads[headcount].id = d;
 339                heads[headcount].len = iov_length(vq->iov + seg, in);
 340                datalen -= heads[headcount].len;
 341                ++headcount;
 342                seg += in;
 343        }
 344        heads[headcount - 1].len += datalen;
 345        *iovcount = seg;
 346        if (unlikely(log))
 347                *log_num = nlogs;
 348        return headcount;
 349err:
 350        vhost_discard_vq_desc(vq, headcount);
 351        return r;
 352}
 353
 354/* Expects to be always run from workqueue - which acts as
 355 * read-size critical section for our kind of RCU. */
 356static void handle_rx(struct vhost_net *net)
 357{
 358        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
 359        unsigned uninitialized_var(in), log;
 360        struct vhost_log *vq_log;
 361        struct msghdr msg = {
 362                .msg_name = NULL,
 363                .msg_namelen = 0,
 364                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
 365                .msg_controllen = 0,
 366                .msg_iov = vq->iov,
 367                .msg_flags = MSG_DONTWAIT,
 368        };
 369        struct virtio_net_hdr_mrg_rxbuf hdr = {
 370                .hdr.flags = 0,
 371                .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
 372        };
 373        size_t total_len = 0;
 374        int err, headcount, mergeable;
 375        size_t vhost_hlen, sock_hlen;
 376        size_t vhost_len, sock_len;
 377        /* TODO: check that we are running from vhost_worker? */
 378        struct socket *sock = rcu_dereference_check(vq->private_data, 1);
 379
 380        if (!sock)
 381                return;
 382
 383        mutex_lock(&vq->mutex);
 384        vhost_disable_notify(&net->dev, vq);
 385        vhost_hlen = vq->vhost_hlen;
 386        sock_hlen = vq->sock_hlen;
 387
 388        vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
 389                vq->log : NULL;
 390        mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);
 391
 392        while ((sock_len = peek_head_len(sock->sk))) {
 393                sock_len += sock_hlen;
 394                vhost_len = sock_len + vhost_hlen;
 395                headcount = get_rx_bufs(vq, vq->heads, vhost_len,
 396                                        &in, vq_log, &log,
 397                                        likely(mergeable) ? UIO_MAXIOV : 1);
 398                /* On error, stop handling until the next kick. */
 399                if (unlikely(headcount < 0))
 400                        break;
 401                /* OK, now we need to know about added descriptors. */
 402                if (!headcount) {
 403                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 404                                /* They have slipped one in as we were
 405                                 * doing that: check again. */
 406                                vhost_disable_notify(&net->dev, vq);
 407                                continue;
 408                        }
 409                        /* Nothing new?  Wait for eventfd to tell us
 410                         * they refilled. */
 411                        break;
 412                }
 413                /* We don't need to be notified again. */
 414                if (unlikely((vhost_hlen)))
 415                        /* Skip header. TODO: support TSO. */
 416                        move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
 417                else
 418                        /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
 419                         * needed because recvmsg can modify msg_iov. */
 420                        copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
 421                msg.msg_iovlen = in;
 422                err = sock->ops->recvmsg(NULL, sock, &msg,
 423                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
 424                /* Userspace might have consumed the packet meanwhile:
 425                 * it's not supposed to do this usually, but might be hard
 426                 * to prevent. Discard data we got (if any) and keep going. */
 427                if (unlikely(err != sock_len)) {
 428                        pr_debug("Discarded rx packet: "
 429                                 " len %d, expected %zd\n", err, sock_len);
 430                        vhost_discard_vq_desc(vq, headcount);
 431                        continue;
 432                }
 433                if (unlikely(vhost_hlen) &&
 434                    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
 435                                      vhost_hlen)) {
 436                        vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
 437                               vq->iov->iov_base);
 438                        break;
 439                }
 440                /* TODO: Should check and handle checksum. */
 441                if (likely(mergeable) &&
 442                    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
 443                                      offsetof(typeof(hdr), num_buffers),
 444                                      sizeof hdr.num_buffers)) {
 445                        vq_err(vq, "Failed num_buffers write");
 446                        vhost_discard_vq_desc(vq, headcount);
 447                        break;
 448                }
 449                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
 450                                            headcount);
 451                if (unlikely(vq_log))
 452                        vhost_log_write(vq, vq_log, log, vhost_len);
 453                total_len += vhost_len;
 454                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
 455                        vhost_poll_queue(&vq->poll);
 456                        break;
 457                }
 458        }
 459
 460        mutex_unlock(&vq->mutex);
 461}
 462
 463static void handle_tx_kick(struct vhost_work *work)
 464{
 465        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 466                                                  poll.work);
 467        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
 468
 469        handle_tx(net);
 470}
 471
 472static void handle_rx_kick(struct vhost_work *work)
 473{
 474        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 475                                                  poll.work);
 476        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
 477
 478        handle_rx(net);
 479}
 480
 481static void handle_tx_net(struct vhost_work *work)
 482{
 483        struct vhost_net *net = container_of(work, struct vhost_net,
 484                                             poll[VHOST_NET_VQ_TX].work);
 485        handle_tx(net);
 486}
 487
 488static void handle_rx_net(struct vhost_work *work)
 489{
 490        struct vhost_net *net = container_of(work, struct vhost_net,
 491                                             poll[VHOST_NET_VQ_RX].work);
 492        handle_rx(net);
 493}
 494
 495static int vhost_net_open(struct inode *inode, struct file *f)
 496{
 497        struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
 498        struct vhost_dev *dev;
 499        int r;
 500
 501        if (!n)
 502                return -ENOMEM;
 503
 504        dev = &n->dev;
 505        n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
 506        n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
 507        r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
 508        if (r < 0) {
 509                kfree(n);
 510                return r;
 511        }
 512
 513        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
 514        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
 515        n->tx_poll_state = VHOST_NET_POLL_DISABLED;
 516
 517        f->private_data = n;
 518
 519        return 0;
 520}
 521
 522static void vhost_net_disable_vq(struct vhost_net *n,
 523                                 struct vhost_virtqueue *vq)
 524{
 525        if (!vq->private_data)
 526                return;
 527        if (vq == n->vqs + VHOST_NET_VQ_TX) {
 528                tx_poll_stop(n);
 529                n->tx_poll_state = VHOST_NET_POLL_DISABLED;
 530        } else
 531                vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
 532}
 533
 534static void vhost_net_enable_vq(struct vhost_net *n,
 535                                struct vhost_virtqueue *vq)
 536{
 537        struct socket *sock;
 538
 539        sock = rcu_dereference_protected(vq->private_data,
 540                                         lockdep_is_held(&vq->mutex));
 541        if (!sock)
 542                return;
 543        if (vq == n->vqs + VHOST_NET_VQ_TX) {
 544                n->tx_poll_state = VHOST_NET_POLL_STOPPED;
 545                tx_poll_start(n, sock);
 546        } else
 547                vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
 548}
 549
 550static struct socket *vhost_net_stop_vq(struct vhost_net *n,
 551                                        struct vhost_virtqueue *vq)
 552{
 553        struct socket *sock;
 554
 555        mutex_lock(&vq->mutex);
 556        sock = rcu_dereference_protected(vq->private_data,
 557                                         lockdep_is_held(&vq->mutex));
 558        vhost_net_disable_vq(n, vq);
 559        rcu_assign_pointer(vq->private_data, NULL);
 560        mutex_unlock(&vq->mutex);
 561        return sock;
 562}
 563
 564static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
 565                           struct socket **rx_sock)
 566{
 567        *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
 568        *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
 569}
 570
 571static void vhost_net_flush_vq(struct vhost_net *n, int index)
 572{
 573        vhost_poll_flush(n->poll + index);
 574        vhost_poll_flush(&n->dev.vqs[index].poll);
 575}
 576
 577static void vhost_net_flush(struct vhost_net *n)
 578{
 579        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
 580        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
 581}
 582
 583static int vhost_net_release(struct inode *inode, struct file *f)
 584{
 585        struct vhost_net *n = f->private_data;
 586        struct socket *tx_sock;
 587        struct socket *rx_sock;
 588
 589        vhost_net_stop(n, &tx_sock, &rx_sock);
 590        vhost_net_flush(n);
 591        vhost_dev_cleanup(&n->dev);
 592        if (tx_sock)
 593                fput(tx_sock->file);
 594        if (rx_sock)
 595                fput(rx_sock->file);
 596        /* We do an extra flush before freeing memory,
 597         * since jobs can re-queue themselves. */
 598        vhost_net_flush(n);
 599        kfree(n);
 600        return 0;
 601}
 602
 603static struct socket *get_raw_socket(int fd)
 604{
 605        struct {
 606                struct sockaddr_ll sa;
 607                char  buf[MAX_ADDR_LEN];
 608        } uaddr;
 609        int uaddr_len = sizeof uaddr, r;
 610        struct socket *sock = sockfd_lookup(fd, &r);
 611
 612        if (!sock)
 613                return ERR_PTR(-ENOTSOCK);
 614
 615        /* Parameter checking */
 616        if (sock->sk->sk_type != SOCK_RAW) {
 617                r = -ESOCKTNOSUPPORT;
 618                goto err;
 619        }
 620
 621        r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
 622                               &uaddr_len, 0);
 623        if (r)
 624                goto err;
 625
 626        if (uaddr.sa.sll_family != AF_PACKET) {
 627                r = -EPFNOSUPPORT;
 628                goto err;
 629        }
 630        return sock;
 631err:
 632        fput(sock->file);
 633        return ERR_PTR(r);
 634}
 635
 636static struct socket *get_tap_socket(int fd)
 637{
 638        struct file *file = fget(fd);
 639        struct socket *sock;
 640
 641        if (!file)
 642                return ERR_PTR(-EBADF);
 643        sock = tun_get_socket(file);
 644        if (!IS_ERR(sock))
 645                return sock;
 646        sock = macvtap_get_socket(file);
 647        if (IS_ERR(sock))
 648                fput(file);
 649        return sock;
 650}
 651
 652static struct socket *get_socket(int fd)
 653{
 654        struct socket *sock;
 655
 656        /* special case to disable backend */
 657        if (fd == -1)
 658                return NULL;
 659        sock = get_raw_socket(fd);
 660        if (!IS_ERR(sock))
 661                return sock;
 662        sock = get_tap_socket(fd);
 663        if (!IS_ERR(sock))
 664                return sock;
 665        return ERR_PTR(-ENOTSOCK);
 666}
 667
 668static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 669{
 670        struct socket *sock, *oldsock;
 671        struct vhost_virtqueue *vq;
 672        struct vhost_ubuf_ref *ubufs, *oldubufs = NULL;
 673        int r;
 674
 675        mutex_lock(&n->dev.mutex);
 676        r = vhost_dev_check_owner(&n->dev);
 677        if (r)
 678                goto err;
 679
 680        if (index >= VHOST_NET_VQ_MAX) {
 681                r = -ENOBUFS;
 682                goto err;
 683        }
 684        vq = n->vqs + index;
 685        mutex_lock(&vq->mutex);
 686
 687        /* Verify that ring has been setup correctly. */
 688        if (!vhost_vq_access_ok(vq)) {
 689                r = -EFAULT;
 690                goto err_vq;
 691        }
 692        sock = get_socket(fd);
 693        if (IS_ERR(sock)) {
 694                r = PTR_ERR(sock);
 695                goto err_vq;
 696        }
 697
 698        /* start polling new socket */
 699        oldsock = rcu_dereference_protected(vq->private_data,
 700                                            lockdep_is_held(&vq->mutex));
 701        if (sock != oldsock) {
 702                ubufs = vhost_ubuf_alloc(vq, sock && vhost_sock_zcopy(sock));
 703                if (IS_ERR(ubufs)) {
 704                        r = PTR_ERR(ubufs);
 705                        goto err_ubufs;
 706                }
 707                oldubufs = vq->ubufs;
 708                vq->ubufs = ubufs;
 709                vhost_net_disable_vq(n, vq);
 710                rcu_assign_pointer(vq->private_data, sock);
 711                vhost_net_enable_vq(n, vq);
 712
 713                r = vhost_init_used(vq);
 714                if (r)
 715                        goto err_vq;
 716        }
 717
 718        mutex_unlock(&vq->mutex);
 719
 720        if (oldubufs) {
 721                vhost_ubuf_put_and_wait(oldubufs);
 722                mutex_lock(&vq->mutex);
 723                vhost_zerocopy_signal_used(vq);
 724                mutex_unlock(&vq->mutex);
 725        }
 726
 727        if (oldsock) {
 728                vhost_net_flush_vq(n, index);
 729                fput(oldsock->file);
 730        }
 731
 732        mutex_unlock(&n->dev.mutex);
 733        return 0;
 734
 735err_ubufs:
 736        fput(sock->file);
 737err_vq:
 738        mutex_unlock(&vq->mutex);
 739err:
 740        mutex_unlock(&n->dev.mutex);
 741        return r;
 742}
 743
 744static long vhost_net_reset_owner(struct vhost_net *n)
 745{
 746        struct socket *tx_sock = NULL;
 747        struct socket *rx_sock = NULL;
 748        long err;
 749
 750        mutex_lock(&n->dev.mutex);
 751        err = vhost_dev_check_owner(&n->dev);
 752        if (err)
 753                goto done;
 754        vhost_net_stop(n, &tx_sock, &rx_sock);
 755        vhost_net_flush(n);
 756        err = vhost_dev_reset_owner(&n->dev);
 757done:
 758        mutex_unlock(&n->dev.mutex);
 759        if (tx_sock)
 760                fput(tx_sock->file);
 761        if (rx_sock)
 762                fput(rx_sock->file);
 763        return err;
 764}
 765
 766static int vhost_net_set_features(struct vhost_net *n, u64 features)
 767{
 768        size_t vhost_hlen, sock_hlen, hdr_len;
 769        int i;
 770
 771        hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
 772                        sizeof(struct virtio_net_hdr_mrg_rxbuf) :
 773                        sizeof(struct virtio_net_hdr);
 774        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
 775                /* vhost provides vnet_hdr */
 776                vhost_hlen = hdr_len;
 777                sock_hlen = 0;
 778        } else {
 779                /* socket provides vnet_hdr */
 780                vhost_hlen = 0;
 781                sock_hlen = hdr_len;
 782        }
 783        mutex_lock(&n->dev.mutex);
 784        if ((features & (1 << VHOST_F_LOG_ALL)) &&
 785            !vhost_log_access_ok(&n->dev)) {
 786                mutex_unlock(&n->dev.mutex);
 787                return -EFAULT;
 788        }
 789        n->dev.acked_features = features;
 790        smp_wmb();
 791        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 792                mutex_lock(&n->vqs[i].mutex);
 793                n->vqs[i].vhost_hlen = vhost_hlen;
 794                n->vqs[i].sock_hlen = sock_hlen;
 795                mutex_unlock(&n->vqs[i].mutex);
 796        }
 797        vhost_net_flush(n);
 798        mutex_unlock(&n->dev.mutex);
 799        return 0;
 800}
 801
 802static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
 803                            unsigned long arg)
 804{
 805        struct vhost_net *n = f->private_data;
 806        void __user *argp = (void __user *)arg;
 807        u64 __user *featurep = argp;
 808        struct vhost_vring_file backend;
 809        u64 features;
 810        int r;
 811
 812        switch (ioctl) {
 813        case VHOST_NET_SET_BACKEND:
 814                if (copy_from_user(&backend, argp, sizeof backend))
 815                        return -EFAULT;
 816                return vhost_net_set_backend(n, backend.index, backend.fd);
 817        case VHOST_GET_FEATURES:
 818                features = VHOST_FEATURES;
 819                if (copy_to_user(featurep, &features, sizeof features))
 820                        return -EFAULT;
 821                return 0;
 822        case VHOST_SET_FEATURES:
 823                if (copy_from_user(&features, featurep, sizeof features))
 824                        return -EFAULT;
 825                if (features & ~VHOST_FEATURES)
 826                        return -EOPNOTSUPP;
 827                return vhost_net_set_features(n, features);
 828        case VHOST_RESET_OWNER:
 829                return vhost_net_reset_owner(n);
 830        default:
 831                mutex_lock(&n->dev.mutex);
 832                r = vhost_dev_ioctl(&n->dev, ioctl, arg);
 833                vhost_net_flush(n);
 834                mutex_unlock(&n->dev.mutex);
 835                return r;
 836        }
 837}
 838
 839#ifdef CONFIG_COMPAT
 840static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
 841                                   unsigned long arg)
 842{
 843        return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
 844}
 845#endif
 846
 847static const struct file_operations vhost_net_fops = {
 848        .owner          = THIS_MODULE,
 849        .release        = vhost_net_release,
 850        .unlocked_ioctl = vhost_net_ioctl,
 851#ifdef CONFIG_COMPAT
 852        .compat_ioctl   = vhost_net_compat_ioctl,
 853#endif
 854        .open           = vhost_net_open,
 855        .llseek         = noop_llseek,
 856};
 857
 858static struct miscdevice vhost_net_misc = {
 859        MISC_DYNAMIC_MINOR,
 860        "vhost-net",
 861        &vhost_net_fops,
 862};
 863
 864static int vhost_net_init(void)
 865{
 866        if (experimental_zcopytx)
 867                vhost_enable_zcopy(VHOST_NET_VQ_TX);
 868        return misc_register(&vhost_net_misc);
 869}
 870module_init(vhost_net_init);
 871
 872static void vhost_net_exit(void)
 873{
 874        misc_deregister(&vhost_net_misc);
 875}
 876module_exit(vhost_net_exit);
 877
/* Module metadata. */
MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
 882
lxr.linux.no kindly hosted by Redpill Linpro AS, provider of Linux consulting and operations services since 1995.