Update Linux to v5.4.2

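The mlx5 memory-region code (drivers/infiniband/hw/mlx5/mr.c) picks up
the upstream rework included in v5.4.2. The notable changes in this
file are:

- reg_mr_callback() takes a struct mlx5_async_work; cached mkeys are
  created through mlx5_core_create_mkey_cb() with a per-device async
  context, and wait_for_async_commands() is dropped in favour of
  mlx5_cmd_cleanup_async_ctx().
- The mkey table radix tree is replaced by an xarray.
- #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING blocks become IS_ENABLED()
  checks; update_odp_mr() is removed and the ODP umem private pointer
  and mr->live are set directly in mlx5_ib_reg_user_mr().
- mr_umem_get() takes the device and udata and handles ODP umems via
  ib_umem_odp_get().
- mlx5_mr_cache_debugfs_init() no longer checks debugfs return values.
- mlx5_ib_get_memic_mr() is generalised to mlx5_ib_get_dm_mr(), also
  covering SW_ICM device memory, and mlx5_ib_advise_mr() is added for
  prefetch advice.
- Integrity MR support is added: IB_MR_TYPE_INTEGRITY allocation via
  mlx5_ib_alloc_mr_integrity() and the mlx5_ib_map_mr_sg_pi()
  PA/MTT/KLM mapping paths.
- set_mr_fileds() is renamed to set_mr_fields(), dev->rep becomes
  dev->is_rep, and mlx5_ib_dereg_mr()/mlx5_ib_alloc_mr() gain a udata
  parameter.
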
Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c
index 7df4a4f..7019c12 100644
--- a/drivers/infiniband/hw/mlx5/mr.c
+++ b/drivers/infiniband/hw/mlx5/mr.c
@@ -51,30 +51,19 @@
 static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static int mr_cache_max_order(struct mlx5_ib_dev *dev);
 static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
-static bool umr_can_modify_entity_size(struct mlx5_ib_dev *dev)
-{
-	return !MLX5_CAP_GEN(dev->mdev, umr_modify_entity_size_disabled);
-}
 
 static bool umr_can_use_indirect_mkey(struct mlx5_ib_dev *dev)
 {
 	return !MLX5_CAP_GEN(dev->mdev, umr_indirect_mkey_disabled);
 }
 
-static bool use_umr(struct mlx5_ib_dev *dev, int order)
-{
-	return order <= mr_cache_max_order(dev) &&
-		umr_can_modify_entity_size(dev);
-}
-
 static int destroy_mkey(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
 	int err = mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	/* Wait until all page fault handlers using the mr complete. */
-	synchronize_srcu(&dev->mr_srcu);
-#endif
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
+		/* Wait until all page fault handlers using the mr complete. */
+		synchronize_srcu(&dev->mr_srcu);
 
 	return err;
 }
@@ -95,44 +84,17 @@
 		length + (start & (MLX5_ADAPTER_PAGE_SIZE - 1));
 }
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-static void update_odp_mr(struct mlx5_ib_mr *mr)
+static void reg_mr_callback(int status, struct mlx5_async_work *context)
 {
-	if (mr->umem->odp_data) {
-		/*
-		 * This barrier prevents the compiler from moving the
-		 * setting of umem->odp_data->private to point to our
-		 * MR, before reg_umr finished, to ensure that the MR
-		 * initialization have finished before starting to
-		 * handle invalidations.
-		 */
-		smp_wmb();
-		mr->umem->odp_data->private = mr;
-		/*
-		 * Make sure we will see the new
-		 * umem->odp_data->private value in the invalidation
-		 * routines, before we can get page faults on the
-		 * MR. Page faults can happen once we put the MR in
-		 * the tree, below this line. Without the barrier,
-		 * there can be a fault handling and an invalidation
-		 * before umem->odp_data->private == mr is visible to
-		 * the invalidation handler.
-		 */
-		smp_wmb();
-	}
-}
-#endif
-
-static void reg_mr_callback(int status, void *context)
-{
-	struct mlx5_ib_mr *mr = context;
+	struct mlx5_ib_mr *mr =
+		container_of(context, struct mlx5_ib_mr, cb_work);
 	struct mlx5_ib_dev *dev = mr->dev;
 	struct mlx5_mr_cache *cache = &dev->cache;
 	int c = order2idx(dev, mr->order);
 	struct mlx5_cache_ent *ent = &cache->ent[c];
 	u8 key;
 	unsigned long flags;
-	struct mlx5_mkey_table *table = &dev->mdev->priv.mkey_table;
+	struct xarray *mkeys = &dev->mdev->priv.mkey_table;
 	int err;
 
 	spin_lock_irqsave(&ent->lock, flags);
@@ -160,12 +122,12 @@
 	ent->size++;
 	spin_unlock_irqrestore(&ent->lock, flags);
 
-	write_lock_irqsave(&table->lock, flags);
-	err = radix_tree_insert(&table->tree, mlx5_base_mkey(mr->mmkey.key),
-				&mr->mmkey);
+	xa_lock_irqsave(mkeys, flags);
+	err = xa_err(__xa_store(mkeys, mlx5_base_mkey(mr->mmkey.key),
+				&mr->mmkey, GFP_ATOMIC));
+	xa_unlock_irqrestore(mkeys, flags);
 	if (err)
 		pr_err("Error inserting to mkey tree. 0x%x\n", -err);
-	write_unlock_irqrestore(&table->lock, flags);
 
 	if (!completion_done(&ent->compl))
 		complete(&ent->compl);
@@ -216,9 +178,9 @@
 		ent->pending++;
 		spin_unlock_irq(&ent->lock);
 		err = mlx5_core_create_mkey_cb(dev->mdev, &mr->mmkey,
-					       in, inlen,
+					       &dev->async_ctx, in, inlen,
 					       mr->out, sizeof(mr->out),
-					       reg_mr_callback, mr);
+					       reg_mr_callback, &mr->cb_work);
 		if (err) {
 			spin_lock_irq(&ent->lock);
 			ent->pending--;
@@ -256,9 +218,8 @@
 		mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 	}
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	synchronize_srcu(&dev->mr_srcu);
-#endif
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
+		synchronize_srcu(&dev->mr_srcu);
 
 	list_for_each_entry_safe(mr, tmp_mr, &del_list, list) {
 		list_del(&mr->list);
@@ -548,14 +509,17 @@
 		return;
 
 	c = order2idx(dev, mr->order);
-	if (c < 0 || c >= MAX_MR_CACHE_ENTRIES) {
-		mlx5_ib_warn(dev, "order %d, cache index %d\n", mr->order, c);
+	WARN_ON(c < 0 || c >= MAX_MR_CACHE_ENTRIES);
+
+	if (unreg_umr(dev, mr)) {
+		mr->allocated_from_cache = false;
+		destroy_mkey(dev, mr);
+		ent = &cache->ent[c];
+		if (ent->cur < ent->limit)
+			queue_work(cache->wq, &ent->work);
 		return;
 	}
 
-	if (unreg_umr(dev, mr))
-		return;
-
 	ent = &cache->ent[c];
 	spin_lock_irq(&ent->lock);
 	list_add_tail(&mr->list, &ent->head);
@@ -603,59 +567,34 @@
 
 static void mlx5_mr_cache_debugfs_cleanup(struct mlx5_ib_dev *dev)
 {
-	if (!mlx5_debugfs_root || dev->rep)
+	if (!mlx5_debugfs_root || dev->is_rep)
 		return;
 
 	debugfs_remove_recursive(dev->cache.root);
 	dev->cache.root = NULL;
 }
 
-static int mlx5_mr_cache_debugfs_init(struct mlx5_ib_dev *dev)
+static void mlx5_mr_cache_debugfs_init(struct mlx5_ib_dev *dev)
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
+	struct dentry *dir;
 	int i;
 
-	if (!mlx5_debugfs_root || dev->rep)
-		return 0;
+	if (!mlx5_debugfs_root || dev->is_rep)
+		return;
 
 	cache->root = debugfs_create_dir("mr_cache", dev->mdev->priv.dbg_root);
-	if (!cache->root)
-		return -ENOMEM;
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
 		ent = &cache->ent[i];
 		sprintf(ent->name, "%d", ent->order);
-		ent->dir = debugfs_create_dir(ent->name,  cache->root);
-		if (!ent->dir)
-			goto err;
-
-		ent->fsize = debugfs_create_file("size", 0600, ent->dir, ent,
-						 &size_fops);
-		if (!ent->fsize)
-			goto err;
-
-		ent->flimit = debugfs_create_file("limit", 0600, ent->dir, ent,
-						  &limit_fops);
-		if (!ent->flimit)
-			goto err;
-
-		ent->fcur = debugfs_create_u32("cur", 0400, ent->dir,
-					       &ent->cur);
-		if (!ent->fcur)
-			goto err;
-
-		ent->fmiss = debugfs_create_u32("miss", 0600, ent->dir,
-						&ent->miss);
-		if (!ent->fmiss)
-			goto err;
+		dir = debugfs_create_dir(ent->name, cache->root);
+		debugfs_create_file("size", 0600, dir, ent, &size_fops);
+		debugfs_create_file("limit", 0600, dir, ent, &limit_fops);
+		debugfs_create_u32("cur", 0400, dir, &ent->cur);
+		debugfs_create_u32("miss", 0600, dir, &ent->miss);
 	}
-
-	return 0;
-err:
-	mlx5_mr_cache_debugfs_cleanup(dev);
-
-	return -ENOMEM;
 }
 
 static void delay_time_func(struct timer_list *t)
@@ -669,7 +608,6 @@
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
-	int err;
 	int i;
 
 	mutex_init(&dev->slow_path_mutex);
@@ -679,6 +617,7 @@
 		return -ENOMEM;
 	}
 
+	mlx5_cmd_init_async_ctx(dev->mdev, &dev->async_ctx);
 	timer_setup(&dev->delay_timer, delay_time_func, 0);
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
 		ent = &cache->ent[i];
@@ -705,7 +644,7 @@
 			   MLX5_IB_UMR_OCTOWORD;
 		ent->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
 		if ((dev->mdev->profile->mask & MLX5_PROF_MASK_MR_CACHE) &&
-		    !dev->rep &&
+		    !dev->is_rep &&
 		    mlx5_core_is_pf(dev->mdev))
 			ent->limit = dev->mdev->profile->mr_cache[i].limit;
 		else
@@ -713,45 +652,11 @@
 		queue_work(cache->wq, &ent->work);
 	}
 
-	err = mlx5_mr_cache_debugfs_init(dev);
-	if (err)
-		mlx5_ib_warn(dev, "cache debugfs failure\n");
-
-	/*
-	 * We don't want to fail driver if debugfs failed to initialize,
-	 * so we are not forwarding error to the user.
-	 */
+	mlx5_mr_cache_debugfs_init(dev);
 
 	return 0;
 }
 
-static void wait_for_async_commands(struct mlx5_ib_dev *dev)
-{
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent;
-	int total = 0;
-	int i;
-	int j;
-
-	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		ent = &cache->ent[i];
-		for (j = 0 ; j < 1000; j++) {
-			if (!ent->pending)
-				break;
-			msleep(50);
-		}
-	}
-	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		ent = &cache->ent[i];
-		total += ent->pending;
-	}
-
-	if (total)
-		mlx5_ib_warn(dev, "aborted while there are %d pending mr requests\n", total);
-	else
-		mlx5_ib_warn(dev, "done with all pending requests\n");
-}
-
 int mlx5_mr_cache_cleanup(struct mlx5_ib_dev *dev)
 {
 	int i;
@@ -763,12 +668,12 @@
 	flush_workqueue(dev->cache.wq);
 
 	mlx5_mr_cache_debugfs_cleanup(dev);
+	mlx5_cmd_cleanup_async_ctx(&dev->async_ctx);
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++)
 		clean_keys(dev, i);
 
 	destroy_workqueue(dev->cache.wq);
-	wait_for_async_commands(dev);
 	del_timer_sync(&dev->delay_timer);
 
 	return 0;
@@ -847,26 +752,43 @@
 	return MLX5_MAX_UMR_SHIFT;
 }
 
-static int mr_umem_get(struct ib_pd *pd, u64 start, u64 length,
-		       int access_flags, struct ib_umem **umem,
-		       int *npages, int *page_shift, int *ncont,
-		       int *order)
+static int mr_umem_get(struct mlx5_ib_dev *dev, struct ib_udata *udata,
+		       u64 start, u64 length, int access_flags,
+		       struct ib_umem **umem, int *npages, int *page_shift,
+		       int *ncont, int *order)
 {
-	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct ib_umem *u;
-	int err;
 
 	*umem = NULL;
 
-	u = ib_umem_get(pd->uobject->context, start, length, access_flags, 0);
-	err = PTR_ERR_OR_ZERO(u);
-	if (err) {
-		mlx5_ib_dbg(dev, "umem get failed (%d)\n", err);
-		return err;
+	if (access_flags & IB_ACCESS_ON_DEMAND) {
+		struct ib_umem_odp *odp;
+
+		odp = ib_umem_odp_get(udata, start, length, access_flags);
+		if (IS_ERR(odp)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n",
+				    PTR_ERR(odp));
+			return PTR_ERR(odp);
+		}
+
+		u = &odp->umem;
+
+		*page_shift = odp->page_shift;
+		*ncont = ib_umem_odp_num_pages(odp);
+		*npages = *ncont << (*page_shift - PAGE_SHIFT);
+		if (order)
+			*order = ilog2(roundup_pow_of_two(*ncont));
+	} else {
+		u = ib_umem_get(udata, start, length, access_flags, 0);
+		if (IS_ERR(u)) {
+			mlx5_ib_dbg(dev, "umem get failed (%ld)\n", PTR_ERR(u));
+			return PTR_ERR(u);
+		}
+
+		mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
+				   page_shift, ncont, order);
 	}
 
-	mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
-			   page_shift, ncont, order);
 	if (!*npages) {
 		mlx5_ib_warn(dev, "avoid zero region\n");
 		ib_umem_release(u);
@@ -1211,7 +1133,7 @@
 	return ERR_PTR(err);
 }
 
-static void set_mr_fileds(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
+static void set_mr_fields(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
 			  int npages, u64 length, int access_flags)
 {
 	mr->npages = npages;
@@ -1222,8 +1144,8 @@
 	mr->access_flags = access_flags;
 }
 
-static struct ib_mr *mlx5_ib_get_memic_mr(struct ib_pd *pd, u64 memic_addr,
-					  u64 length, int acc)
+static struct ib_mr *mlx5_ib_get_dm_mr(struct ib_pd *pd, u64 start_addr,
+				       u64 length, int acc, int mode)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
@@ -1245,9 +1167,8 @@
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 
-	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_MEMIC & 0x3);
-	MLX5_SET(mkc, mkc, access_mode_4_2,
-		 (MLX5_MKC_ACCESS_MODE_MEMIC >> 2) & 0x7);
+	MLX5_SET(mkc, mkc, access_mode_1_0, mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (mode >> 2) & 0x7);
 	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
 	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
 	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
@@ -1257,8 +1178,7 @@
 	MLX5_SET64(mkc, mkc, len, length);
 	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
 	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET64(mkc, mkc, start_addr,
-		   memic_addr - pci_resource_start(dev->mdev->pdev, 0));
+	MLX5_SET64(mkc, mkc, start_addr, start_addr);
 
 	err = mlx5_core_create_mkey(mdev, &mr->mmkey, in, inlen);
 	if (err)
@@ -1267,7 +1187,7 @@
 	kfree(in);
 
 	mr->umem = NULL;
-	set_mr_fileds(dev, mr, 0, length, acc);
+	set_mr_fields(dev, mr, 0, length, acc);
 
 	return &mr->ibmr;
 
@@ -1280,20 +1200,51 @@
 	return ERR_PTR(err);
 }
 
+int mlx5_ib_advise_mr(struct ib_pd *pd,
+		      enum ib_uverbs_advise_mr_advice advice,
+		      u32 flags,
+		      struct ib_sge *sg_list,
+		      u32 num_sge,
+		      struct uverbs_attr_bundle *attrs)
+{
+	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
+	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE)
+		return -EOPNOTSUPP;
+
+	return mlx5_ib_advise_mr_prefetch(pd, advice, flags,
+					 sg_list, num_sge);
+}
+
 struct ib_mr *mlx5_ib_reg_dm_mr(struct ib_pd *pd, struct ib_dm *dm,
 				struct ib_dm_mr_attr *attr,
 				struct uverbs_attr_bundle *attrs)
 {
 	struct mlx5_ib_dm *mdm = to_mdm(dm);
-	u64 memic_addr;
+	struct mlx5_core_dev *dev = to_mdev(dm->device)->mdev;
+	u64 start_addr = mdm->dev_addr + attr->offset;
+	int mode;
 
-	if (attr->access_flags & ~MLX5_IB_DM_ALLOWED_ACCESS)
+	switch (mdm->type) {
+	case MLX5_IB_UAPI_DM_TYPE_MEMIC:
+		if (attr->access_flags & ~MLX5_IB_DM_MEMIC_ALLOWED_ACCESS)
+			return ERR_PTR(-EINVAL);
+
+		mode = MLX5_MKC_ACCESS_MODE_MEMIC;
+		start_addr -= pci_resource_start(dev->pdev, 0);
+		break;
+	case MLX5_IB_UAPI_DM_TYPE_STEERING_SW_ICM:
+	case MLX5_IB_UAPI_DM_TYPE_HEADER_MODIFY_SW_ICM:
+		if (attr->access_flags & ~MLX5_IB_DM_SW_ICM_ALLOWED_ACCESS)
+			return ERR_PTR(-EINVAL);
+
+		mode = MLX5_MKC_ACCESS_MODE_SW_ICM;
+		break;
+	default:
 		return ERR_PTR(-EINVAL);
+	}
 
-	memic_addr = mdm->dev_addr + attr->offset;
-
-	return mlx5_ib_get_memic_mr(pd, memic_addr, attr->length,
-				    attr->access_flags);
+	return mlx5_ib_get_dm_mr(pd, start_addr, attr->length,
+				 attr->access_flags, mode);
 }
 
 struct ib_mr *mlx5_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
@@ -1302,7 +1253,7 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct mlx5_ib_mr *mr = NULL;
-	bool populate_mtts = false;
+	bool use_umr;
 	struct ib_umem *umem;
 	int page_shift;
 	int npages;
@@ -1316,48 +1267,46 @@
 	mlx5_ib_dbg(dev, "start 0x%llx, virt_addr 0x%llx, length 0x%llx, access_flags 0x%x\n",
 		    start, virt_addr, length, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	if (!start && length == U64_MAX) {
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING) && !start &&
+	    length == U64_MAX) {
 		if (!(access_flags & IB_ACCESS_ON_DEMAND) ||
 		    !(dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
 			return ERR_PTR(-EINVAL);
 
-		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), access_flags);
+		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), udata, access_flags);
 		if (IS_ERR(mr))
 			return ERR_CAST(mr);
 		return &mr->ibmr;
 	}
-#endif
 
-	err = mr_umem_get(pd, start, length, access_flags, &umem, &npages,
-			   &page_shift, &ncont, &order);
+	err = mr_umem_get(dev, udata, start, length, access_flags, &umem,
+			  &npages, &page_shift, &ncont, &order);
 
 	if (err < 0)
 		return ERR_PTR(err);
 
-	if (use_umr(dev, order)) {
+	use_umr = mlx5_ib_can_use_umr(dev, true);
+
+	if (order <= mr_cache_max_order(dev) && use_umr) {
 		mr = alloc_mr_from_cache(pd, umem, virt_addr, length, ncont,
 					 page_shift, order, access_flags);
 		if (PTR_ERR(mr) == -EAGAIN) {
 			mlx5_ib_dbg(dev, "cache empty for order %d\n", order);
 			mr = NULL;
 		}
-		populate_mtts = false;
 	} else if (!MLX5_CAP_GEN(dev->mdev, umr_extended_translation_offset)) {
 		if (access_flags & IB_ACCESS_ON_DEMAND) {
 			err = -EINVAL;
 			pr_err("Got MR registration for ODP MR > 512MB, not supported for Connect-IB\n");
 			goto error;
 		}
-		populate_mtts = true;
+		use_umr = false;
 	}
 
 	if (!mr) {
-		if (!umr_can_modify_entity_size(dev))
-			populate_mtts = true;
 		mutex_lock(&dev->slow_path_mutex);
 		mr = reg_create(NULL, pd, virt_addr, length, umem, ncont,
-				page_shift, access_flags, populate_mtts);
+				page_shift, access_flags, !use_umr);
 		mutex_unlock(&dev->slow_path_mutex);
 	}
 
@@ -1369,13 +1318,9 @@
 	mlx5_ib_dbg(dev, "mkey 0x%x\n", mr->mmkey.key);
 
 	mr->umem = umem;
-	set_mr_fileds(dev, mr, npages, length, access_flags);
+	set_mr_fields(dev, mr, npages, length, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	update_odp_mr(mr);
-#endif
-
-	if (!populate_mtts) {
+	if (use_umr) {
 		int update_xlt_flags = MLX5_IB_UPD_XLT_ENABLE;
 
 		if (access_flags & IB_ACCESS_ON_DEMAND)
@@ -1390,9 +1335,13 @@
 		}
 	}
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	mr->live = 1;
-#endif
+	if (is_odp_mr(mr)) {
+		to_ib_umem_odp(mr->umem)->private = mr;
+		atomic_set(&mr->num_pending_prefetch, 0);
+	}
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
+		smp_store_release(&mr->live, 1);
+
 	return &mr->ibmr;
 error:
 	ib_umem_release(umem);
@@ -1408,9 +1357,11 @@
 		return 0;
 
 	umrwr.wr.send_flags = MLX5_IB_SEND_UMR_DISABLE_MR |
-			      MLX5_IB_SEND_UMR_FAIL_IF_FREE;
+			      MLX5_IB_SEND_UMR_UPDATE_PD_ACCESS;
 	umrwr.wr.opcode = MLX5_IB_WR_UMR;
+	umrwr.pd = dev->umrc.pd;
 	umrwr.mkey = mr->mmkey.key;
+	umrwr.ignore_free_state = 1;
 
 	return mlx5_ib_post_send_wait(dev, &umrwr);
 }
@@ -1464,6 +1415,9 @@
 	if (!mr->umem)
 		return -EINVAL;
 
+	if (is_odp_mr(mr))
+		return -EOPNOTSUPP;
+
 	if (flags & IB_MR_REREG_TRANS) {
 		addr = virt_addr;
 		len = length;
@@ -1480,13 +1434,15 @@
 		flags |= IB_MR_REREG_TRANS;
 		ib_umem_release(mr->umem);
 		mr->umem = NULL;
-		err = mr_umem_get(pd, addr, len, access_flags, &mr->umem,
-				  &npages, &page_shift, &ncont, &order);
+		err = mr_umem_get(dev, udata, addr, len, access_flags,
+				  &mr->umem, &npages, &page_shift, &ncont,
+				  &order);
 		if (err)
 			goto err;
 	}
 
-	if (flags & IB_MR_REREG_TRANS && !use_umr_mtt_update(mr, addr, len)) {
+	if (!mlx5_ib_can_use_umr(dev, true) ||
+	    (flags & IB_MR_REREG_TRANS && !use_umr_mtt_update(mr, addr, len))) {
 		/*
 		 * UMR can't be used - MKey needs to be replaced.
 		 */
@@ -1507,9 +1463,6 @@
 		}
 
 		mr->allocated_from_cache = 0;
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-		mr->live = 1;
-#endif
 	} else {
 		/*
 		 * Send a UMR WQE
@@ -1536,18 +1489,14 @@
 			goto err;
 	}
 
-	set_mr_fileds(dev, mr, npages, len, access_flags);
+	set_mr_fields(dev, mr, npages, len, access_flags);
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	update_odp_mr(mr);
-#endif
 	return 0;
 
 err:
-	if (mr->umem) {
-		ib_umem_release(mr->umem);
-		mr->umem = NULL;
-	}
+	ib_umem_release(mr->umem);
+	mr->umem = NULL;
+
 	clean_mr(dev, mr);
 	return err;
 }
@@ -1615,10 +1564,10 @@
 		mr->sig = NULL;
 	}
 
-	mlx5_free_priv_descs(mr);
-
-	if (!allocated_from_cache)
+	if (!allocated_from_cache) {
 		destroy_mkey(dev, mr);
+		mlx5_free_priv_descs(mr);
+	}
 }
 
 static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
@@ -1626,16 +1575,27 @@
 	int npages = mr->npages;
 	struct ib_umem *umem = mr->umem;
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	if (umem && umem->odp_data) {
-		/* Prevent new page faults from succeeding */
-		mr->live = 0;
+	if (is_odp_mr(mr)) {
+		struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem);
+
+		/* Prevent new page faults and
+		 * prefetch requests from succeeding
+		 */
+		WRITE_ONCE(mr->live, 0);
+
 		/* Wait for all running page-fault handlers to finish. */
 		synchronize_srcu(&dev->mr_srcu);
+
+		/* dequeue pending prefetch requests for the mr */
+		if (atomic_read(&mr->num_pending_prefetch))
+			flush_workqueue(system_unbound_wq);
+		WARN_ON(atomic_read(&mr->num_pending_prefetch));
+
 		/* Destroy all page mappings */
-		if (umem->odp_data->page_list)
-			mlx5_ib_invalidate_range(umem, ib_umem_start(umem),
-						 ib_umem_end(umem));
+		if (!umem_odp->is_implicit_odp)
+			mlx5_ib_invalidate_range(umem_odp,
+						 ib_umem_start(umem_odp),
+						 ib_umem_end(umem_odp));
 		else
 			mlx5_ib_free_implicit_mr(mr);
 		/*
@@ -1643,13 +1603,13 @@
 		 * so that there will not be any invalidations in
 		 * flight, looking at the *mr struct.
 		 */
-		ib_umem_release(umem);
+		ib_umem_odp_release(umem_odp);
 		atomic_sub(npages, &dev->mdev->priv.reg_pages);
 
 		/* Avoid double-freeing the umem. */
 		umem = NULL;
 	}
-#endif
+
 	clean_mr(dev, mr);
 
 	/*
@@ -1657,29 +1617,215 @@
 	 * remove the DMA mapping.
 	 */
 	mlx5_mr_cache_free(dev, mr);
-	if (umem) {
-		ib_umem_release(umem);
+	ib_umem_release(umem);
+	if (umem)
 		atomic_sub(npages, &dev->mdev->priv.reg_pages);
-	}
+
 	if (!mr->allocated_from_cache)
 		kfree(mr);
 }
 
-int mlx5_ib_dereg_mr(struct ib_mr *ibmr)
+int mlx5_ib_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
 {
-	dereg_mr(to_mdev(ibmr->device), to_mmr(ibmr));
+	struct mlx5_ib_mr *mmr = to_mmr(ibmr);
+
+	if (ibmr->type == IB_MR_TYPE_INTEGRITY) {
+		dereg_mr(to_mdev(mmr->mtt_mr->ibmr.device), mmr->mtt_mr);
+		dereg_mr(to_mdev(mmr->klm_mr->ibmr.device), mmr->klm_mr);
+	}
+
+	dereg_mr(to_mdev(ibmr->device), mmr);
+
 	return 0;
 }
 
-struct ib_mr *mlx5_ib_alloc_mr(struct ib_pd *pd,
-			       enum ib_mr_type mr_type,
-			       u32 max_num_sg)
+static void mlx5_set_umr_free_mkey(struct ib_pd *pd, u32 *in, int ndescs,
+				   int access_mode, int page_shift)
+{
+	void *mkc;
+
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+
+	MLX5_SET(mkc, mkc, free, 1);
+	MLX5_SET(mkc, mkc, qpn, 0xffffff);
+	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
+	MLX5_SET(mkc, mkc, access_mode_1_0, access_mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (access_mode >> 2) & 0x7);
+	MLX5_SET(mkc, mkc, umr_en, 1);
+	MLX5_SET(mkc, mkc, log_page_size, page_shift);
+}
+
+static int _mlx5_alloc_mkey_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				  int ndescs, int desc_size, int page_shift,
+				  int access_mode, u32 *in, int inlen)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	int err;
+
+	mr->access_mode = access_mode;
+	mr->desc_size = desc_size;
+	mr->max_descs = ndescs;
+
+	err = mlx5_alloc_priv_descs(pd->device, mr, ndescs, desc_size);
+	if (err)
+		return err;
+
+	mlx5_set_umr_free_mkey(pd, in, ndescs, access_mode, page_shift);
+
+	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
+	if (err)
+		goto err_free_descs;
+
+	mr->mmkey.type = MLX5_MKEY_MR;
+	mr->ibmr.lkey = mr->mmkey.key;
+	mr->ibmr.rkey = mr->mmkey.key;
+
+	return 0;
+
+err_free_descs:
+	mlx5_free_priv_descs(mr);
+	return err;
+}
+
+static struct mlx5_ib_mr *mlx5_ib_alloc_pi_mr(struct ib_pd *pd,
+				u32 max_num_sg, u32 max_num_meta_sg,
+				int desc_size, int access_mode)
+{
+	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
+	int ndescs = ALIGN(max_num_sg + max_num_meta_sg, 4);
+	int page_shift = 0;
+	struct mlx5_ib_mr *mr;
+	u32 *in;
+	int err;
+
+	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+	if (!mr)
+		return ERR_PTR(-ENOMEM);
+
+	mr->ibmr.pd = pd;
+	mr->ibmr.device = pd->device;
+
+	in = kzalloc(inlen, GFP_KERNEL);
+	if (!in) {
+		err = -ENOMEM;
+		goto err_free;
+	}
+
+	if (access_mode == MLX5_MKC_ACCESS_MODE_MTT)
+		page_shift = PAGE_SHIFT;
+
+	err = _mlx5_alloc_mkey_descs(pd, mr, ndescs, desc_size, page_shift,
+				     access_mode, in, inlen);
+	if (err)
+		goto err_free_in;
+
+	mr->umem = NULL;
+	kfree(in);
+
+	return mr;
+
+err_free_in:
+	kfree(in);
+err_free:
+	kfree(mr);
+	return ERR_PTR(err);
+}
+
+static int mlx5_alloc_mem_reg_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				    int ndescs, u32 *in, int inlen)
+{
+	return _mlx5_alloc_mkey_descs(pd, mr, ndescs, sizeof(struct mlx5_mtt),
+				      PAGE_SHIFT, MLX5_MKC_ACCESS_MODE_MTT, in,
+				      inlen);
+}
+
+static int mlx5_alloc_sg_gaps_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				    int ndescs, u32 *in, int inlen)
+{
+	return _mlx5_alloc_mkey_descs(pd, mr, ndescs, sizeof(struct mlx5_klm),
+				      0, MLX5_MKC_ACCESS_MODE_KLMS, in, inlen);
+}
+
+static int mlx5_alloc_integrity_descs(struct ib_pd *pd, struct mlx5_ib_mr *mr,
+				      int max_num_sg, int max_num_meta_sg,
+				      u32 *in, int inlen)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	u32 psv_index[2];
+	void *mkc;
+	int err;
+
+	mr->sig = kzalloc(sizeof(*mr->sig), GFP_KERNEL);
+	if (!mr->sig)
+		return -ENOMEM;
+
+	/* create mem & wire PSVs */
+	err = mlx5_core_create_psv(dev->mdev, to_mpd(pd)->pdn, 2, psv_index);
+	if (err)
+		goto err_free_sig;
+
+	mr->sig->psv_memory.psv_idx = psv_index[0];
+	mr->sig->psv_wire.psv_idx = psv_index[1];
+
+	mr->sig->sig_status_checked = true;
+	mr->sig->sig_err_exists = false;
+	/* Next UMR, Arm SIGERR */
+	++mr->sig->sigerr_count;
+	mr->klm_mr = mlx5_ib_alloc_pi_mr(pd, max_num_sg, max_num_meta_sg,
+					 sizeof(struct mlx5_klm),
+					 MLX5_MKC_ACCESS_MODE_KLMS);
+	if (IS_ERR(mr->klm_mr)) {
+		err = PTR_ERR(mr->klm_mr);
+		goto err_destroy_psv;
+	}
+	mr->mtt_mr = mlx5_ib_alloc_pi_mr(pd, max_num_sg, max_num_meta_sg,
+					 sizeof(struct mlx5_mtt),
+					 MLX5_MKC_ACCESS_MODE_MTT);
+	if (IS_ERR(mr->mtt_mr)) {
+		err = PTR_ERR(mr->mtt_mr);
+		goto err_free_klm_mr;
+	}
+
+	/* Set bsf descriptors for mkey */
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	MLX5_SET(mkc, mkc, bsf_en, 1);
+	MLX5_SET(mkc, mkc, bsf_octword_size, MLX5_MKEY_BSF_OCTO_SIZE);
+
+	err = _mlx5_alloc_mkey_descs(pd, mr, 4, sizeof(struct mlx5_klm), 0,
+				     MLX5_MKC_ACCESS_MODE_KLMS, in, inlen);
+	if (err)
+		goto err_free_mtt_mr;
+
+	return 0;
+
+err_free_mtt_mr:
+	dereg_mr(to_mdev(mr->mtt_mr->ibmr.device), mr->mtt_mr);
+	mr->mtt_mr = NULL;
+err_free_klm_mr:
+	dereg_mr(to_mdev(mr->klm_mr->ibmr.device), mr->klm_mr);
+	mr->klm_mr = NULL;
+err_destroy_psv:
+	if (mlx5_core_destroy_psv(dev->mdev, mr->sig->psv_memory.psv_idx))
+		mlx5_ib_warn(dev, "failed to destroy mem psv %d\n",
+			     mr->sig->psv_memory.psv_idx);
+	if (mlx5_core_destroy_psv(dev->mdev, mr->sig->psv_wire.psv_idx))
+		mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
+			     mr->sig->psv_wire.psv_idx);
+err_free_sig:
+	kfree(mr->sig);
+
+	return err;
+}
+
+static struct ib_mr *__mlx5_ib_alloc_mr(struct ib_pd *pd,
+					enum ib_mr_type mr_type, u32 max_num_sg,
+					u32 max_num_meta_sg)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	int ndescs = ALIGN(max_num_sg, 4);
 	struct mlx5_ib_mr *mr;
-	void *mkc;
 	u32 *in;
 	int err;
 
@@ -1693,93 +1839,32 @@
 		goto err_free;
 	}
 
-	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
-	MLX5_SET(mkc, mkc, free, 1);
-	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	mr->ibmr.device = pd->device;
+	mr->umem = NULL;
 
-	if (mr_type == IB_MR_TYPE_MEM_REG) {
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
-		MLX5_SET(mkc, mkc, log_page_size, PAGE_SHIFT);
-		err = mlx5_alloc_priv_descs(pd->device, mr,
-					    ndescs, sizeof(struct mlx5_mtt));
-		if (err)
-			goto err_free_in;
-
-		mr->desc_size = sizeof(struct mlx5_mtt);
-		mr->max_descs = ndescs;
-	} else if (mr_type == IB_MR_TYPE_SG_GAPS) {
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_KLMS;
-
-		err = mlx5_alloc_priv_descs(pd->device, mr,
-					    ndescs, sizeof(struct mlx5_klm));
-		if (err)
-			goto err_free_in;
-		mr->desc_size = sizeof(struct mlx5_klm);
-		mr->max_descs = ndescs;
-	} else if (mr_type == IB_MR_TYPE_SIGNATURE) {
-		u32 psv_index[2];
-
-		MLX5_SET(mkc, mkc, bsf_en, 1);
-		MLX5_SET(mkc, mkc, bsf_octword_size, MLX5_MKEY_BSF_OCTO_SIZE);
-		mr->sig = kzalloc(sizeof(*mr->sig), GFP_KERNEL);
-		if (!mr->sig) {
-			err = -ENOMEM;
-			goto err_free_in;
-		}
-
-		/* create mem & wire PSVs */
-		err = mlx5_core_create_psv(dev->mdev, to_mpd(pd)->pdn,
-					   2, psv_index);
-		if (err)
-			goto err_free_sig;
-
-		mr->access_mode = MLX5_MKC_ACCESS_MODE_KLMS;
-		mr->sig->psv_memory.psv_idx = psv_index[0];
-		mr->sig->psv_wire.psv_idx = psv_index[1];
-
-		mr->sig->sig_status_checked = true;
-		mr->sig->sig_err_exists = false;
-		/* Next UMR, Arm SIGERR */
-		++mr->sig->sigerr_count;
-	} else {
+	switch (mr_type) {
+	case IB_MR_TYPE_MEM_REG:
+		err = mlx5_alloc_mem_reg_descs(pd, mr, ndescs, in, inlen);
+		break;
+	case IB_MR_TYPE_SG_GAPS:
+		err = mlx5_alloc_sg_gaps_descs(pd, mr, ndescs, in, inlen);
+		break;
+	case IB_MR_TYPE_INTEGRITY:
+		err = mlx5_alloc_integrity_descs(pd, mr, max_num_sg,
+						 max_num_meta_sg, in, inlen);
+		break;
+	default:
 		mlx5_ib_warn(dev, "Invalid mr type %d\n", mr_type);
 		err = -EINVAL;
-		goto err_free_in;
 	}
 
-	MLX5_SET(mkc, mkc, access_mode_1_0, mr->access_mode & 0x3);
-	MLX5_SET(mkc, mkc, access_mode_4_2, (mr->access_mode >> 2) & 0x7);
-	MLX5_SET(mkc, mkc, umr_en, 1);
-
-	mr->ibmr.device = pd->device;
-	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
 	if (err)
-		goto err_destroy_psv;
+		goto err_free_in;
 
-	mr->mmkey.type = MLX5_MKEY_MR;
-	mr->ibmr.lkey = mr->mmkey.key;
-	mr->ibmr.rkey = mr->mmkey.key;
-	mr->umem = NULL;
 	kfree(in);
 
 	return &mr->ibmr;
 
-err_destroy_psv:
-	if (mr->sig) {
-		if (mlx5_core_destroy_psv(dev->mdev,
-					  mr->sig->psv_memory.psv_idx))
-			mlx5_ib_warn(dev, "failed to destroy mem psv %d\n",
-				     mr->sig->psv_memory.psv_idx);
-		if (mlx5_core_destroy_psv(dev->mdev,
-					  mr->sig->psv_wire.psv_idx))
-			mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
-				     mr->sig->psv_wire.psv_idx);
-	}
-	mlx5_free_priv_descs(mr);
-err_free_sig:
-	kfree(mr->sig);
 err_free_in:
 	kfree(in);
 err_free:
@@ -1787,6 +1872,19 @@
 	return ERR_PTR(err);
 }
 
+struct ib_mr *mlx5_ib_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
+			       u32 max_num_sg, struct ib_udata *udata)
+{
+	return __mlx5_ib_alloc_mr(pd, mr_type, max_num_sg, 0);
+}
+
+struct ib_mr *mlx5_ib_alloc_mr_integrity(struct ib_pd *pd,
+					 u32 max_num_sg, u32 max_num_meta_sg)
+{
+	return __mlx5_ib_alloc_mr(pd, IB_MR_TYPE_INTEGRITY, max_num_sg,
+				  max_num_meta_sg);
+}
+
 struct ib_mw *mlx5_ib_alloc_mw(struct ib_pd *pd, enum ib_mw_type type,
 			       struct ib_udata *udata)
 {
@@ -1864,14 +1962,25 @@
 
 int mlx5_ib_dealloc_mw(struct ib_mw *mw)
 {
+	struct mlx5_ib_dev *dev = to_mdev(mw->device);
 	struct mlx5_ib_mw *mmw = to_mmw(mw);
 	int err;
 
-	err =  mlx5_core_destroy_mkey((to_mdev(mw->device))->mdev,
-				      &mmw->mmkey);
-	if (!err)
-		kfree(mmw);
-	return err;
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING)) {
+		xa_erase_irq(&dev->mdev->priv.mkey_table,
+			     mlx5_base_mkey(mmw->mmkey.key));
+		/*
+		 * pagefault_single_data_segment() may be accessing mmw under
+		 * SRCU if the user bound an ODP MR to this MW.
+		 */
+		synchronize_srcu(&dev->mr_srcu);
+	}
+
+	err = mlx5_core_destroy_mkey(dev->mdev, &mmw->mmkey);
+	if (err)
+		return err;
+	kfree(mmw);
+	return 0;
 }
 
 int mlx5_ib_check_mr_status(struct ib_mr *ibmr, u32 check_mask,
@@ -1916,16 +2025,53 @@
 }
 
 static int
+mlx5_ib_map_pa_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			int data_sg_nents, unsigned int *data_sg_offset,
+			struct scatterlist *meta_sg, int meta_sg_nents,
+			unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	unsigned int sg_offset = 0;
+	int n = 0;
+
+	mr->meta_length = 0;
+	if (data_sg_nents == 1) {
+		n++;
+		mr->ndescs = 1;
+		if (data_sg_offset)
+			sg_offset = *data_sg_offset;
+		mr->data_length = sg_dma_len(data_sg) - sg_offset;
+		mr->data_iova = sg_dma_address(data_sg) + sg_offset;
+		if (meta_sg_nents == 1) {
+			n++;
+			mr->meta_ndescs = 1;
+			if (meta_sg_offset)
+				sg_offset = *meta_sg_offset;
+			else
+				sg_offset = 0;
+			mr->meta_length = sg_dma_len(meta_sg) - sg_offset;
+			mr->pi_iova = sg_dma_address(meta_sg) + sg_offset;
+		}
+		ibmr->length = mr->data_length + mr->meta_length;
+	}
+
+	return n;
+}
+
+static int
 mlx5_ib_sg_to_klms(struct mlx5_ib_mr *mr,
 		   struct scatterlist *sgl,
 		   unsigned short sg_nents,
-		   unsigned int *sg_offset_p)
+		   unsigned int *sg_offset_p,
+		   struct scatterlist *meta_sgl,
+		   unsigned short meta_sg_nents,
+		   unsigned int *meta_sg_offset_p)
 {
 	struct scatterlist *sg = sgl;
 	struct mlx5_klm *klms = mr->descs;
 	unsigned int sg_offset = sg_offset_p ? *sg_offset_p : 0;
 	u32 lkey = mr->ibmr.pd->local_dma_lkey;
-	int i;
+	int i, j = 0;
 
 	mr->ibmr.iova = sg_dma_address(sg) + sg_offset;
 	mr->ibmr.length = 0;
@@ -1940,12 +2086,36 @@
 
 		sg_offset = 0;
 	}
-	mr->ndescs = i;
 
 	if (sg_offset_p)
 		*sg_offset_p = sg_offset;
 
-	return i;
+	mr->ndescs = i;
+	mr->data_length = mr->ibmr.length;
+
+	if (meta_sg_nents) {
+		sg = meta_sgl;
+		sg_offset = meta_sg_offset_p ? *meta_sg_offset_p : 0;
+		for_each_sg(meta_sgl, sg, meta_sg_nents, j) {
+			if (unlikely(i + j >= mr->max_descs))
+				break;
+			klms[i + j].va = cpu_to_be64(sg_dma_address(sg) +
+						     sg_offset);
+			klms[i + j].bcount = cpu_to_be32(sg_dma_len(sg) -
+							 sg_offset);
+			klms[i + j].key = cpu_to_be32(lkey);
+			mr->ibmr.length += sg_dma_len(sg) - sg_offset;
+
+			sg_offset = 0;
+		}
+		if (meta_sg_offset_p)
+			*meta_sg_offset_p = sg_offset;
+
+		mr->meta_ndescs = j;
+		mr->meta_length = mr->ibmr.length - mr->data_length;
+	}
+
+	return i + j;
 }
 
 static int mlx5_set_page(struct ib_mr *ibmr, u64 addr)
@@ -1962,6 +2132,181 @@
 	return 0;
 }
 
+static int mlx5_set_page_pi(struct ib_mr *ibmr, u64 addr)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	__be64 *descs;
+
+	if (unlikely(mr->ndescs + mr->meta_ndescs == mr->max_descs))
+		return -ENOMEM;
+
+	descs = mr->descs;
+	descs[mr->ndescs + mr->meta_ndescs++] =
+		cpu_to_be64(addr | MLX5_EN_RD | MLX5_EN_WR);
+
+	return 0;
+}
+
+static int
+mlx5_ib_map_mtt_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = mr->mtt_mr;
+	int n;
+
+	pi_mr->ndescs = 0;
+	pi_mr->meta_ndescs = 0;
+	pi_mr->meta_length = 0;
+
+	ib_dma_sync_single_for_cpu(ibmr->device, pi_mr->desc_map,
+				   pi_mr->desc_size * pi_mr->max_descs,
+				   DMA_TO_DEVICE);
+
+	pi_mr->ibmr.page_size = ibmr->page_size;
+	n = ib_sg_to_pages(&pi_mr->ibmr, data_sg, data_sg_nents, data_sg_offset,
+			   mlx5_set_page);
+	if (n != data_sg_nents)
+		return n;
+
+	pi_mr->data_iova = pi_mr->ibmr.iova;
+	pi_mr->data_length = pi_mr->ibmr.length;
+	pi_mr->ibmr.length = pi_mr->data_length;
+	ibmr->length = pi_mr->data_length;
+
+	if (meta_sg_nents) {
+		u64 page_mask = ~((u64)ibmr->page_size - 1);
+		u64 iova = pi_mr->data_iova;
+
+		n += ib_sg_to_pages(&pi_mr->ibmr, meta_sg, meta_sg_nents,
+				    meta_sg_offset, mlx5_set_page_pi);
+
+		pi_mr->meta_length = pi_mr->ibmr.length;
+		/*
+		 * PI address for the HW is the offset of the metadata address
+		 * relative to the first data page address.
+		 * It equals to first data page address + size of data pages +
+		 * metadata offset at the first metadata page
+		 */
+		pi_mr->pi_iova = (iova & page_mask) +
+				 pi_mr->ndescs * ibmr->page_size +
+				 (pi_mr->ibmr.iova & ~page_mask);
+		/*
+		 * In order to use one MTT MR for data and metadata, we register
+		 * also the gaps between the end of the data and the start of
+		 * the metadata (the sig MR will verify that the HW will access
+		 * to right addresses). This mapping is safe because we use
+		 * internal mkey for the registration.
+		 */
+		pi_mr->ibmr.length = pi_mr->pi_iova + pi_mr->meta_length - iova;
+		pi_mr->ibmr.iova = iova;
+		ibmr->length += pi_mr->meta_length;
+	}
+
+	ib_dma_sync_single_for_device(ibmr->device, pi_mr->desc_map,
+				      pi_mr->desc_size * pi_mr->max_descs,
+				      DMA_TO_DEVICE);
+
+	return n;
+}
+
+static int
+mlx5_ib_map_klm_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = mr->klm_mr;
+	int n;
+
+	pi_mr->ndescs = 0;
+	pi_mr->meta_ndescs = 0;
+	pi_mr->meta_length = 0;
+
+	ib_dma_sync_single_for_cpu(ibmr->device, pi_mr->desc_map,
+				   pi_mr->desc_size * pi_mr->max_descs,
+				   DMA_TO_DEVICE);
+
+	n = mlx5_ib_sg_to_klms(pi_mr, data_sg, data_sg_nents, data_sg_offset,
+			       meta_sg, meta_sg_nents, meta_sg_offset);
+
+	ib_dma_sync_single_for_device(ibmr->device, pi_mr->desc_map,
+				      pi_mr->desc_size * pi_mr->max_descs,
+				      DMA_TO_DEVICE);
+
+	/* This is zero-based memory region */
+	pi_mr->data_iova = 0;
+	pi_mr->ibmr.iova = 0;
+	pi_mr->pi_iova = pi_mr->data_length;
+	ibmr->length = pi_mr->ibmr.length;
+
+	return n;
+}
+
+int mlx5_ib_map_mr_sg_pi(struct ib_mr *ibmr, struct scatterlist *data_sg,
+			 int data_sg_nents, unsigned int *data_sg_offset,
+			 struct scatterlist *meta_sg, int meta_sg_nents,
+			 unsigned int *meta_sg_offset)
+{
+	struct mlx5_ib_mr *mr = to_mmr(ibmr);
+	struct mlx5_ib_mr *pi_mr = NULL;
+	int n;
+
+	WARN_ON(ibmr->type != IB_MR_TYPE_INTEGRITY);
+
+	mr->ndescs = 0;
+	mr->data_length = 0;
+	mr->data_iova = 0;
+	mr->meta_ndescs = 0;
+	mr->pi_iova = 0;
+	/*
+	 * As a performance optimization, if possible, there is no need to
+	 * perform UMR operation to register the data/metadata buffers.
+	 * First try to map the sg lists to PA descriptors with local_dma_lkey.
+	 * Fallback to UMR only in case of a failure.
+	 */
+	n = mlx5_ib_map_pa_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				    data_sg_offset, meta_sg, meta_sg_nents,
+				    meta_sg_offset);
+	if (n == data_sg_nents + meta_sg_nents)
+		goto out;
+	/*
+	 * As a performance optimization, if possible, there is no need to map
+	 * the sg lists to KLM descriptors. First try to map the sg lists to MTT
+	 * descriptors and fallback to KLM only in case of a failure.
+	 * It's more efficient for the HW to work with MTT descriptors
+	 * (especially in high load).
+	 * Use KLM (indirect access) only if it's mandatory.
+	 */
+	pi_mr = mr->mtt_mr;
+	n = mlx5_ib_map_mtt_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				     data_sg_offset, meta_sg, meta_sg_nents,
+				     meta_sg_offset);
+	if (n == data_sg_nents + meta_sg_nents)
+		goto out;
+
+	pi_mr = mr->klm_mr;
+	n = mlx5_ib_map_klm_mr_sg_pi(ibmr, data_sg, data_sg_nents,
+				     data_sg_offset, meta_sg, meta_sg_nents,
+				     meta_sg_offset);
+	if (unlikely(n != data_sg_nents + meta_sg_nents))
+		return -ENOMEM;
+
+out:
+	/* This is zero-based memory region */
+	ibmr->iova = 0;
+	mr->pi_mr = pi_mr;
+	if (pi_mr)
+		ibmr->sig_attrs->meta_length = pi_mr->meta_length;
+	else
+		ibmr->sig_attrs->meta_length = mr->meta_length;
+
+	return 0;
+}
+
 int mlx5_ib_map_mr_sg(struct ib_mr *ibmr, struct scatterlist *sg, int sg_nents,
 		      unsigned int *sg_offset)
 {
@@ -1975,7 +2320,8 @@
 				   DMA_TO_DEVICE);
 
 	if (mr->access_mode == MLX5_MKC_ACCESS_MODE_KLMS)
-		n = mlx5_ib_sg_to_klms(mr, sg, sg_nents, sg_offset);
+		n = mlx5_ib_sg_to_klms(mr, sg, sg_nents, sg_offset, NULL, 0,
+				       NULL);
 	else
 		n = ib_sg_to_pages(ibmr, sg, sg_nents, sg_offset,
 				mlx5_set_page);