linux/drivers/vhost/vsock.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID    VMADDR_CID_HOST
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others.
 */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
        VHOST_VSOCK_FEATURES = VHOST_FEATURES |
                               (1ULL << VIRTIO_F_ACCESS_PLATFORM)
};

enum {
        VHOST_VSOCK_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
        struct vhost_dev dev;
        struct vhost_virtqueue vqs[2];

        /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
        struct hlist_node hash;

        struct vhost_work send_pkt_work;
        spinlock_t send_pkt_list_lock;
        struct list_head send_pkt_list; /* host->guest pending packets */

        atomic_t queued_replies;

        u32 guest_cid;
};

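/* The well-known CID of the host side of every vhost-vsock connection. */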
static u32 vhost_transport_get_local_cid(void)
{
        return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
        struct vhost_vsock *vsock;

        hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
                u32 other_cid = vsock->guest_cid;

                /* Skip instances that have no CID yet */
                if (other_cid == 0)
                        continue;

                if (other_cid == guest_cid)
                        return vsock;
        }

        return NULL;
}

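/* Push queued host->guest packets from send_pkt_list into the guest's RX
 * virtqueue.  A packet larger than the available buffer is sent in pieces
 * and requeued until all of its payload has been copied.  Called from the
 * send worker and from the RX kick handler; takes vq->mutex itself.
 */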
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                            struct vhost_virtqueue *vq)
{
        struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
        int pkts = 0, total_len = 0;
        bool added = false;
        bool restart_tx = false;

        mutex_lock(&vq->mutex);

        if (!vhost_vq_get_backend(vq))
                goto out;

        if (!vq_meta_prefetch(vq))
                goto out;

        /* Avoid further vmexits, we're already processing the virtqueue */
        vhost_disable_notify(&vsock->dev, vq);

        do {
                struct virtio_vsock_pkt *pkt;
                struct iov_iter iov_iter;
                unsigned out, in;
                size_t nbytes;
                size_t iov_len, payload_len;
                int head;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        vhost_enable_notify(&vsock->dev, vq);
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                if (head == vq->num) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);

                        /* We cannot finish yet if more buffers snuck in while
                         * re-enabling notify.
                         */
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                if (out) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Expected 0 output buffers, got %u\n", out);
                        break;
                }

                iov_len = iov_length(&vq->iov[out], in);
                if (iov_len < sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
                        break;
                }

                iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
                payload_len = pkt->len - pkt->off;

                /* If the packet is greater than the space available in the
                 * buffer, we split it using multiple buffers.
                 */
                if (payload_len > iov_len - sizeof(pkt->hdr))
                        payload_len = iov_len - sizeof(pkt->hdr);

                /* Set the correct length in the header */
                pkt->hdr.len = cpu_to_le32(payload_len);

                nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
                if (nbytes != sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt hdr\n");
                        break;
                }

                nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
                                      &iov_iter);
                if (nbytes != payload_len) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt buf\n");
                        break;
                }

                /* Deliver to monitoring devices all packets that we
                 * will transmit.
                 */
                virtio_transport_deliver_tap_pkt(pkt);

                vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
                added = true;

                pkt->off += payload_len;
                total_len += payload_len;

                /* If we didn't send all the payload we can requeue the packet
                 * to send it with the next available buffer.
                 */
                if (pkt->off < pkt->len) {
                        /* We are queueing the same virtio_vsock_pkt to handle
                         * the remaining bytes, and we want to deliver it
                         * to monitoring devices in the next iteration.
                         */
                        pkt->tap_delivered = false;

                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                } else {
                        if (pkt->reply) {
                                int val;

                                val = atomic_dec_return(&vsock->queued_replies);

                                /* Do we have resources to resume tx
                                 * processing?
                                 */
                                if (val + 1 == tx_vq->num)
                                        restart_tx = true;
                        }

                        virtio_transport_free_pkt(pkt);
                }
        } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);

        if (restart_tx)
                vhost_poll_queue(&tx_vq->poll);
}

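/* Worker that flushes any queued host->guest packets to the RX virtqueue. */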
static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
        struct vhost_virtqueue *vq;
        struct vhost_vsock *vsock;

        vsock = container_of(work, struct vhost_vsock, send_pkt_work);
        vq = &vsock->vqs[VSOCK_VQ_RX];

        vhost_transport_do_send_pkt(vsock, vq);
}

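/* Queue a packet for the guest selected by hdr.dst_cid and kick the send
 * worker.  Returns the payload length, or -ENODEV if no vhost_vsock
 * instance currently owns that CID.
 */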
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
        struct vhost_vsock *vsock;
        int len = pkt->len;

        rcu_read_lock();

        /* Find the vhost_vsock according to guest context id */
        vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
        if (!vsock) {
                rcu_read_unlock();
                virtio_transport_free_pkt(pkt);
                return -ENODEV;
        }

        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

        rcu_read_unlock();
        return len;
}

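/* Drop all not-yet-transmitted packets belonging to @vsk, e.g. because the
 * socket is being released.  If dropping queued replies brings their count
 * back below the TX ring size, TX processing is resumed.
 */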
static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
        struct vhost_vsock *vsock;
        struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0;
        int ret = -ENODEV;
        LIST_HEAD(freeme);

        rcu_read_lock();

        /* Find the vhost_vsock according to guest context id */
        vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
        if (!vsock)
                goto out;

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
                if (pkt->vsk != vsk)
                        continue;
                list_move(&pkt->list, &freeme);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        list_for_each_entry_safe(pkt, n, &freeme, list) {
                if (pkt->reply)
                        cnt++;
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }

        if (cnt) {
                struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
                int new_cnt;

                new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
                if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
                        vhost_poll_queue(&tx_vq->poll);
        }

        ret = 0;
out:
        rcu_read_unlock();
        return ret;
}

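/* Build a virtio_vsock_pkt from the guest's descriptor chain: the chain must
 * contain only output buffers, starting with the packet header followed by
 * an optional payload of at most VIRTIO_VSOCK_MAX_PKT_BUF_SIZE bytes.
 * Returns NULL on malformed input or allocation failure.
 */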
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
{
        struct virtio_vsock_pkt *pkt;
        struct iov_iter iov_iter;
        size_t nbytes;
        size_t len;

        if (in != 0) {
                vq_err(vq, "Expected 0 input buffers, got %u\n", in);
                return NULL;
        }

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        len = iov_length(vq->iov, out);
        iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

        nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
        if (nbytes != sizeof(pkt->hdr)) {
                vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
                       sizeof(pkt->hdr), nbytes);
                kfree(pkt);
                return NULL;
        }

        if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
                pkt->len = le32_to_cpu(pkt->hdr.len);

        /* No payload */
        if (!pkt->len)
                return pkt;

        /* The pkt is too big */
        if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
        if (!pkt->buf) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf_len = pkt->len;

        nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
        if (nbytes != pkt->len) {
                vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
                       pkt->len, nbytes);
                virtio_transport_free_pkt(pkt);
                return NULL;
        }

        return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
        int val;

        smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
        val = atomic_read(&vsock->queued_replies);

        return val < vq->num;
}

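/* vsock transport callbacks.  Everything except packet delivery and
 * cancellation is shared with the other virtio transports through
 * virtio_transport_common.
 */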
static struct virtio_transport vhost_transport = {
        .transport = {
                .module                   = THIS_MODULE,

                .get_local_cid            = vhost_transport_get_local_cid,

                .init                     = virtio_transport_do_socket_init,
                .destruct                 = virtio_transport_destruct,
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = vhost_transport_cancel_pkt,

                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_allow              = virtio_transport_dgram_allow,

                .stream_enqueue           = virtio_transport_stream_enqueue,
                .stream_dequeue           = virtio_transport_stream_dequeue,
                .stream_has_data          = virtio_transport_stream_has_data,
                .stream_has_space         = virtio_transport_stream_has_space,
                .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
                .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
                .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
                .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
                .notify_send_init         = virtio_transport_notify_send_init,
                .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
                .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
                .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
                .notify_buffer_size       = virtio_transport_notify_buffer_size,
        },

        .send_pkt = vhost_transport_send_pkt,
};

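/* TX virtqueue kick handler: read packets coming from the guest, validate
 * their addressing and feed them to the core transport for delivery to the
 * matching host-side socket.
 */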
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
        struct virtio_vsock_pkt *pkt;
        int head, pkts = 0, total_len = 0;
        unsigned int out, in;
        bool added = false;

        mutex_lock(&vq->mutex);

        if (!vhost_vq_get_backend(vq))
                goto out;

        if (!vq_meta_prefetch(vq))
                goto out;

        vhost_disable_notify(&vsock->dev, vq);
        do {
                u32 len;

                if (!vhost_vsock_more_replies(vsock)) {
                        /* Stop tx until the device processes already
                         * pending replies.  Leave tx virtqueue
                         * callbacks disabled.
                         */
                        goto no_more_replies;
                }

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0)
                        break;

                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                pkt = vhost_vsock_alloc_pkt(vq, out, in);
                if (!pkt) {
                        vq_err(vq, "Faulted on pkt\n");
                        continue;
                }

                len = pkt->len;

                /* Deliver to monitoring devices all received packets */
                virtio_transport_deliver_tap_pkt(pkt);

                /* Only accept correctly addressed packets */
                if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
                    le64_to_cpu(pkt->hdr.dst_cid) ==
                    vhost_transport_get_local_cid())
                        virtio_transport_recv_pkt(&vhost_transport, pkt);
                else
                        virtio_transport_free_pkt(pkt);

                len += sizeof(pkt->hdr);
                vhost_add_used(vq, head, len);
                total_len += len;
                added = true;
        } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);
}

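/* RX virtqueue kick handler: the guest posted new receive buffers, so retry
 * sending any pending host->guest packets.
 */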
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);

        vhost_transport_do_send_pkt(vsock, vq);
}

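/* VHOST_VSOCK_SET_RUNNING(1): install this device as the backend of both
 * virtqueues and kick the send worker for packets queued before start.
 * Only the vhost device owner may do this.
 */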
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq;
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);

                if (!vhost_vq_access_ok(vq)) {
                        ret = -EFAULT;
                        goto err_vq;
                }

                if (!vhost_vq_get_backend(vq)) {
                        vhost_vq_set_backend(vq, vsock);
                        ret = vhost_vq_init_access(vq);
                        if (ret)
                                goto err_vq;
                }

                mutex_unlock(&vq->mutex);
        }

        /* Some packets may have been queued before the device was started,
         * let's kick the send worker to send them.
         */
        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

        mutex_unlock(&vsock->dev.mutex);
        return 0;

err_vq:
        vhost_vq_set_backend(vq, NULL);
        mutex_unlock(&vq->mutex);

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vhost_vq_set_backend(vq, NULL);
                mutex_unlock(&vq->mutex);
        }
err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

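/* VHOST_VSOCK_SET_RUNNING(0): detach the backend from both virtqueues so
 * the kick handlers become no-ops.  Only the vhost device owner may do this.
 */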
static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                struct vhost_virtqueue *vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vhost_vq_set_backend(vq, NULL);
                mutex_unlock(&vq->mutex);
        }

err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
        kvfree(vsock);
}

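/* Open of the vhost-vsock device: allocate an instance and set up both
 * virtqueues and the send worker.  The guest CID is assigned later via
 * VHOST_VSOCK_SET_GUEST_CID.
 */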
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
        struct vhost_virtqueue **vqs;
        struct vhost_vsock *vsock;
        int ret;

        /* This struct is large and allocation could fail, fall back to vmalloc
         * if there is no other way.
         */
        vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!vsock)
                return -ENOMEM;

        vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                ret = -ENOMEM;
                goto out;
        }

        vsock->guest_cid = 0; /* no CID assigned yet */

        atomic_set(&vsock->queued_replies, 0);

        vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
        vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
        vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
        vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

        vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
                       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
                       VHOST_VSOCK_WEIGHT, true, NULL);

        file->private_data = vsock;
        spin_lock_init(&vsock->send_pkt_list_lock);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
        return 0;

out:
        vhost_vsock_free(vsock);
        return ret;
}

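/* Wait for all queued virtqueue handlers and the send worker to finish. */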
static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
                if (vsock->vqs[i].handle_kick)
                        vhost_poll_flush(&vsock->vqs[i].poll);
        vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

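/* Called for each connected socket when an instance disappears: sockets whose
 * peer CID no longer has a backing vhost_vsock are reset with ECONNRESET.
 */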
static void vhost_vsock_reset_orphans(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        /* vmci_transport.c doesn't take sk_lock here either.  At least we're
         * under vsock_table_lock so the sock cannot disappear while we're
         * executing.
         */

        /* If the peer is still valid, no need to reset connection */
        if (vhost_vsock_get(vsk->remote_addr.svm_cid))
                return;

        /* If the close timeout is pending, let it expire.  This avoids races
         * with the timeout callback.
         */
        if (vsk->close_work_scheduled)
                return;

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = ECONNRESET;
        sk->sk_error_report(sk);
}

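/* Release of the vhost-vsock device: unpublish the CID, reset orphaned
 * sockets, stop the device and free any packets still queued for the guest.
 */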
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
        struct vhost_vsock *vsock = file->private_data;

        mutex_lock(&vhost_vsock_mutex);
        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);
        mutex_unlock(&vhost_vsock_mutex);

        /* Wait for other CPUs to finish using vsock */
        synchronize_rcu();

        /* Iterating over all connections for all CIDs to find orphans is
         * inefficient.  Room for improvement here.
         */
        vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

        vhost_vsock_stop(vsock);
        vhost_vsock_flush(vsock);
        vhost_dev_stop(&vsock->dev);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        while (!list_empty(&vsock->send_pkt_list)) {
                struct virtio_vsock_pkt *pkt;

                pkt = list_first_entry(&vsock->send_pkt_list,
                                struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_dev_cleanup(&vsock->dev);
        kfree(vsock->dev.vqs);
        vhost_vsock_free(vsock);
        return 0;
}

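/* VHOST_VSOCK_SET_GUEST_CID: validate the requested CID and publish this
 * instance in vhost_vsock_hash so packets addressed to it can be routed here.
 */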
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
        struct vhost_vsock *other;

        /* Refuse reserved CIDs */
        if (guest_cid <= VMADDR_CID_HOST ||
            guest_cid == U32_MAX)
                return -EINVAL;

        /* 64-bit CIDs are not yet supported */
        if (guest_cid > U32_MAX)
                return -EINVAL;

        /* Refuse if CID is assigned to the guest->host transport (i.e. nested
         * VM), to make the loopback work.
         */
        if (vsock_find_cid(guest_cid))
                return -EADDRINUSE;

        /* Refuse if CID is already in use */
        mutex_lock(&vhost_vsock_mutex);
        other = vhost_vsock_get(guest_cid);
        if (other && other != vsock) {
                mutex_unlock(&vhost_vsock_mutex);
                return -EADDRINUSE;
        }

        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);

        vsock->guest_cid = guest_cid;
        hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
        mutex_unlock(&vhost_vsock_mutex);

        return 0;
}

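/* VHOST_SET_FEATURES: accept only features we offer, propagate the acked set
 * to both virtqueues and enable the device IOTLB if VIRTIO_F_ACCESS_PLATFORM
 * was negotiated.
 */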
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
        struct vhost_virtqueue *vq;
        int i;

        if (features & ~VHOST_VSOCK_FEATURES)
                return -EOPNOTSUPP;

        mutex_lock(&vsock->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&vsock->dev)) {
                goto err;
        }

        if ((features & (1ULL << VIRTIO_F_ACCESS_PLATFORM))) {
                if (vhost_init_device_iotlb(&vsock->dev, true))
                        goto err;
        }

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];
                mutex_lock(&vq->mutex);
                vq->acked_features = features;
                mutex_unlock(&vq->mutex);
        }
        mutex_unlock(&vsock->dev.mutex);
        return 0;

err:
        mutex_unlock(&vsock->dev.mutex);
        return -EFAULT;
}

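/* Handle vhost-vsock specific ioctls; everything else is forwarded to the
 * generic vhost device and vring ioctl handlers.
 */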
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
                                  unsigned long arg)
{
        struct vhost_vsock *vsock = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 guest_cid;
        u64 features;
        int start;
        int r;

        switch (ioctl) {
        case VHOST_VSOCK_SET_GUEST_CID:
                if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
                        return -EFAULT;
                return vhost_vsock_set_cid(vsock, guest_cid);
        case VHOST_VSOCK_SET_RUNNING:
                if (copy_from_user(&start, argp, sizeof(start)))
                        return -EFAULT;
                if (start)
                        return vhost_vsock_start(vsock);
                else
                        return vhost_vsock_stop(vsock);
        case VHOST_GET_FEATURES:
                features = VHOST_VSOCK_FEATURES;
                if (copy_to_user(argp, &features, sizeof(features)))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, argp, sizeof(features)))
                        return -EFAULT;
                return vhost_vsock_set_features(vsock, features);
        case VHOST_GET_BACKEND_FEATURES:
                features = VHOST_VSOCK_BACKEND_FEATURES;
                if (copy_to_user(argp, &features, sizeof(features)))
                        return -EFAULT;
                return 0;
        case VHOST_SET_BACKEND_FEATURES:
                if (copy_from_user(&features, argp, sizeof(features)))
                        return -EFAULT;
                if (features & ~VHOST_VSOCK_BACKEND_FEATURES)
                        return -EOPNOTSUPP;
                vhost_set_backend_features(&vsock->dev, features);
                return 0;
        default:
                mutex_lock(&vsock->dev.mutex);
                r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
                else
                        vhost_vsock_flush(vsock);
                mutex_unlock(&vsock->dev.mutex);
                return r;
        }
}

static ssize_t vhost_vsock_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
        struct file *file = iocb->ki_filp;
        struct vhost_vsock *vsock = file->private_data;
        struct vhost_dev *dev = &vsock->dev;
        int noblock = file->f_flags & O_NONBLOCK;

        return vhost_chr_read_iter(dev, to, noblock);
}

static ssize_t vhost_vsock_chr_write_iter(struct kiocb *iocb,
                                        struct iov_iter *from)
{
        struct file *file = iocb->ki_filp;
        struct vhost_vsock *vsock = file->private_data;
        struct vhost_dev *dev = &vsock->dev;

        return vhost_chr_write_iter(dev, from);
}

static __poll_t vhost_vsock_chr_poll(struct file *file, poll_table *wait)
{
        struct vhost_vsock *vsock = file->private_data;
        struct vhost_dev *dev = &vsock->dev;

        return vhost_chr_poll(file, dev, wait);
}

static const struct file_operations vhost_vsock_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vsock_dev_open,
        .release        = vhost_vsock_dev_release,
        .llseek         = noop_llseek,
        .unlocked_ioctl = vhost_vsock_dev_ioctl,
        .compat_ioctl   = compat_ptr_ioctl,
        .read_iter      = vhost_vsock_chr_read_iter,
        .write_iter     = vhost_vsock_chr_write_iter,
        .poll           = vhost_vsock_chr_poll,
};

static struct miscdevice vhost_vsock_misc = {
        .minor = VHOST_VSOCK_MINOR,
        .name = "vhost-vsock",
        .fops = &vhost_vsock_fops,
};

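/* Register the host->guest (H2G) transport with the vsock core and expose the
 * misc character device that userspace VMMs drive.
 */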
static int __init vhost_vsock_init(void)
{
        int ret;

        ret = vsock_core_register(&vhost_transport.transport,
                                  VSOCK_TRANSPORT_F_H2G);
        if (ret < 0)
                return ret;
        return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
        misc_deregister(&vhost_vsock_misc);
        vsock_core_unregister(&vhost_transport.transport);
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");