diff --git a/drivers/gpu/drm/virtio/virtgpu_drv.h b/drivers/gpu/drm/virtio/virtgpu_drv.h
index c1824bdf2418218793b04c0e11ab0e7fefdbc3fa..7879ff58236f1472afb38d55de3ec1782e41fd7b 100644
--- a/drivers/gpu/drm/virtio/virtgpu_drv.h
+++ b/drivers/gpu/drm/virtio/virtgpu_drv.h
@@ -221,6 +221,7 @@ struct virtio_gpu_fpriv {
 /* virtio_ioctl.c */
 #define DRM_VIRTIO_NUM_IOCTLS 10
 extern struct drm_ioctl_desc virtio_gpu_ioctls[DRM_VIRTIO_NUM_IOCTLS];
+void virtio_gpu_create_context(struct drm_device *dev, struct drm_file *file);
 
 /* virtio_kms.c */
 int virtio_gpu_init(struct drm_device *dev);
diff --git a/drivers/gpu/drm/virtio/virtgpu_gem.c b/drivers/gpu/drm/virtio/virtgpu_gem.c
index 0d6152c99a27190f538be52eb777d7d4ce973cf8..f0d5a897467752aedbc1ba0ee8dbae6380da3bd2 100644
--- a/drivers/gpu/drm/virtio/virtgpu_gem.c
+++ b/drivers/gpu/drm/virtio/virtgpu_gem.c
@@ -39,6 +39,9 @@ int virtio_gpu_gem_create(struct drm_file *file,
 	int ret;
 	u32 handle;
 
+	if (vgdev->has_virgl_3d)
+		virtio_gpu_create_context(dev, file);
+
 	ret = virtio_gpu_object_create(vgdev, params, &obj, NULL);
 	if (ret < 0)
 		return ret;
diff --git a/drivers/gpu/drm/virtio/virtgpu_ioctl.c b/drivers/gpu/drm/virtio/virtgpu_ioctl.c
index 336cc9143205de921c4307f076903530abe58e97..70ba13e8e58a82032777ea03c9d3c72318bbaf40 100644
--- a/drivers/gpu/drm/virtio/virtgpu_ioctl.c
+++ b/drivers/gpu/drm/virtio/virtgpu_ioctl.c
@@ -33,8 +33,7 @@
 
 #include "virtgpu_drv.h"
 
-static void virtio_gpu_create_context(struct drm_device *dev,
-				      struct drm_file *file)
+void virtio_gpu_create_context(struct drm_device *dev, struct drm_file *file)
 {
 	struct virtio_gpu_device *vgdev = dev->dev_private;
 	struct virtio_gpu_fpriv *vfpriv = file->driver_priv;