diff --git a/drivers/iommu/io-pgtable-arm.c b/drivers/iommu/io-pgtable-arm.c
index 72dcdd468cf30d6ec32eaf0795463549076309b0..f7828a7aad410d4406de3d85a97cd419bc29bd0e 100644
--- a/drivers/iommu/io-pgtable-arm.c
+++ b/drivers/iommu/io-pgtable-arm.c
@@ -188,20 +188,28 @@ static dma_addr_t __arm_lpae_dma_addr(void *pages)
 }
 
 static void *__arm_lpae_alloc_pages(size_t size, gfp_t gfp,
-				    struct io_pgtable_cfg *cfg)
+				    struct io_pgtable_cfg *cfg,
+				    void *cookie)
 {
 	struct device *dev = cfg->iommu_dev;
 	int order = get_order(size);
-	struct page *p;
 	dma_addr_t dma;
 	void *pages;
 
 	VM_BUG_ON((gfp & __GFP_HIGHMEM));
-	p = alloc_pages_node(dev_to_node(dev), gfp | __GFP_ZERO, order);
-	if (!p)
+
+	if (cfg->alloc) {
+		pages = cfg->alloc(cookie, size, gfp);
+	} else {
+		struct page *p;
+
+		p = alloc_pages_node(dev_to_node(dev), gfp | __GFP_ZERO, order);
+		pages = p ? page_address(p) : NULL;
+	}
+
+	if (!pages)
 		return NULL;
 
-	pages = page_address(p);
 	if (!cfg->coherent_walk) {
 		dma = dma_map_single(dev, pages, size, DMA_TO_DEVICE);
 		if (dma_mapping_error(dev, dma))
@@ -220,18 +228,28 @@ static void *__arm_lpae_alloc_pages(size_t size, gfp_t gfp,
 out_unmap:
 	dev_err(dev, "Cannot accommodate DMA translation for IOMMU page tables\n");
 	dma_unmap_single(dev, dma, size, DMA_TO_DEVICE);
+
 out_free:
-	__free_pages(p, order);
+	if (cfg->free)
+		cfg->free(cookie, pages, size);
+	else
+		free_pages((unsigned long)pages, order);
+
 	return NULL;
 }
 
 static void __arm_lpae_free_pages(void *pages, size_t size,
-				  struct io_pgtable_cfg *cfg)
+				  struct io_pgtable_cfg *cfg,
+				  void *cookie)
 {
 	if (!cfg->coherent_walk)
 		dma_unmap_single(cfg->iommu_dev, __arm_lpae_dma_addr(pages),
 				 size, DMA_TO_DEVICE);
-	free_pages((unsigned long)pages, get_order(size));
+
+	if (cfg->free)
+		cfg->free(cookie, pages, size);
+	else
+		free_pages((unsigned long)pages, get_order(size));
 }
 
 static void __arm_lpae_sync_pte(arm_lpae_iopte *ptep, int num_entries,
@@ -373,13 +391,13 @@ static int __arm_lpae_map(struct arm_lpae_io_pgtable *data, unsigned long iova,
 	/* Grab a pointer to the next level */
 	pte = READ_ONCE(*ptep);
 	if (!pte) {
-		cptep = __arm_lpae_alloc_pages(tblsz, gfp, cfg);
+		cptep = __arm_lpae_alloc_pages(tblsz, gfp, cfg, data->iop.cookie);
 		if (!cptep)
 			return -ENOMEM;
 
 		pte = arm_lpae_install_table(cptep, ptep, 0, data);
 		if (pte)
-			__arm_lpae_free_pages(cptep, tblsz, cfg);
+			__arm_lpae_free_pages(cptep, tblsz, cfg, data->iop.cookie);
 	} else if (!cfg->coherent_walk && !(pte & ARM_LPAE_PTE_SW_SYNC)) {
 		__arm_lpae_sync_pte(ptep, 1, cfg);
 	}
@@ -524,7 +542,7 @@ static void __arm_lpae_free_pgtable(struct arm_lpae_io_pgtable *data, int lvl,
 		__arm_lpae_free_pgtable(data, lvl + 1, iopte_deref(pte, data));
 	}
 
-	__arm_lpae_free_pages(start, table_size, &data->iop.cfg);
+	__arm_lpae_free_pages(start, table_size, &data->iop.cfg, data->iop.cookie);
 }
 
 static void arm_lpae_free_pgtable(struct io_pgtable *iop)
@@ -552,7 +570,7 @@ static size_t arm_lpae_split_blk_unmap(struct arm_lpae_io_pgtable *data,
 	if (WARN_ON(lvl == ARM_LPAE_MAX_LEVELS))
 		return 0;
 
-	tablep = __arm_lpae_alloc_pages(tablesz, GFP_ATOMIC, cfg);
+	tablep = __arm_lpae_alloc_pages(tablesz, GFP_ATOMIC, cfg, data->iop.cookie);
 	if (!tablep)
 		return 0; /* Bytes unmapped */
 
@@ -575,7 +593,7 @@ static size_t arm_lpae_split_blk_unmap(struct arm_lpae_io_pgtable *data,
 
 	pte = arm_lpae_install_table(tablep, ptep, blk_pte, data);
 	if (pte != blk_pte) {
-		__arm_lpae_free_pages(tablep, tablesz, cfg);
+		__arm_lpae_free_pages(tablep, tablesz, cfg, data->iop.cookie);
 		/*
 		 * We may race against someone unmapping another part of this
 		 * block, but anything else is invalid. We can't misinterpret
@@ -882,7 +900,7 @@ arm_64_lpae_alloc_pgtable_s1(struct io_pgtable_cfg *cfg, void *cookie)
 
 	/* Looking good; allocate a pgd */
 	data->pgd = __arm_lpae_alloc_pages(ARM_LPAE_PGD_SIZE(data),
-					   GFP_KERNEL, cfg);
+					   GFP_KERNEL, cfg, cookie);
 	if (!data->pgd)
 		goto out_free_data;
 
@@ -984,7 +1002,7 @@ arm_64_lpae_alloc_pgtable_s2(struct io_pgtable_cfg *cfg, void *cookie)
 
 	/* Allocate pgd pages */
 	data->pgd = __arm_lpae_alloc_pages(ARM_LPAE_PGD_SIZE(data),
-					   GFP_KERNEL, cfg);
+					   GFP_KERNEL, cfg, cookie);
 	if (!data->pgd)
 		goto out_free_data;
 
@@ -1059,7 +1077,7 @@ arm_mali_lpae_alloc_pgtable(struct io_pgtable_cfg *cfg, void *cookie)
 		 << ARM_LPAE_MAIR_ATTR_SHIFT(ARM_LPAE_MAIR_ATTR_IDX_DEV));
 
 	data->pgd = __arm_lpae_alloc_pages(ARM_LPAE_PGD_SIZE(data), GFP_KERNEL,
-					   cfg);
+					   cfg, cookie);
 	if (!data->pgd)
 		goto out_free_data;
 
@@ -1080,26 +1098,31 @@ arm_mali_lpae_alloc_pgtable(struct io_pgtable_cfg *cfg, void *cookie)
 }
 
 struct io_pgtable_init_fns io_pgtable_arm_64_lpae_s1_init_fns = {
+	.caps	= IO_PGTABLE_CAP_CUSTOM_ALLOCATOR,
 	.alloc	= arm_64_lpae_alloc_pgtable_s1,
 	.free	= arm_lpae_free_pgtable,
 };
 
 struct io_pgtable_init_fns io_pgtable_arm_64_lpae_s2_init_fns = {
+	.caps	= IO_PGTABLE_CAP_CUSTOM_ALLOCATOR,
 	.alloc	= arm_64_lpae_alloc_pgtable_s2,
 	.free	= arm_lpae_free_pgtable,
 };
 
 struct io_pgtable_init_fns io_pgtable_arm_32_lpae_s1_init_fns = {
+	.caps	= IO_PGTABLE_CAP_CUSTOM_ALLOCATOR,
 	.alloc	= arm_32_lpae_alloc_pgtable_s1,
 	.free	= arm_lpae_free_pgtable,
 };
 
 struct io_pgtable_init_fns io_pgtable_arm_32_lpae_s2_init_fns = {
+	.caps	= IO_PGTABLE_CAP_CUSTOM_ALLOCATOR,
 	.alloc	= arm_32_lpae_alloc_pgtable_s2,
 	.free	= arm_lpae_free_pgtable,
 };
 
 struct io_pgtable_init_fns io_pgtable_arm_mali_lpae_init_fns = {
+	.caps	= IO_PGTABLE_CAP_CUSTOM_ALLOCATOR,
 	.alloc	= arm_mali_lpae_alloc_pgtable,
 	.free	= arm_lpae_free_pgtable,
 };