Update Linux to v5.10.109

Sourced from [1]

[1] https://cdn.kernel.org/pub/linux/kernel/v5.x/linux-5.10.109.tar.xz

Change-Id: I19bca9fc6762d4e63bcf3e4cba88bbe560d9c76c
Signed-off-by: Olivier Deprez <olivier.deprez@arm.com>
diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c
index 99d563d..1934669 100644
--- a/drivers/infiniband/hw/mlx5/mr.c
+++ b/drivers/infiniband/hw/mlx5/mr.c
@@ -47,10 +47,69 @@
 
 #define MLX5_UMR_ALIGN 2048
 
+static void
+create_mkey_callback(int status, struct mlx5_async_work *context);
+
+static void set_mkc_access_pd_addr_fields(void *mkc, int acc, u64 start_addr,
+					  struct ib_pd *pd)
+{
+	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+
+	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
+	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
+	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
+	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
+	MLX5_SET(mkc, mkc, lr, 1);
+
+	if (MLX5_CAP_GEN(dev->mdev, relaxed_ordering_write))
+		MLX5_SET(mkc, mkc, relaxed_ordering_write,
+			 !!(acc & IB_ACCESS_RELAXED_ORDERING));
+	if (MLX5_CAP_GEN(dev->mdev, relaxed_ordering_read))
+		MLX5_SET(mkc, mkc, relaxed_ordering_read,
+			 !!(acc & IB_ACCESS_RELAXED_ORDERING));
+
+	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	MLX5_SET(mkc, mkc, qpn, 0xffffff);
+	MLX5_SET64(mkc, mkc, start_addr, start_addr);
+}
+
+static void
+assign_mkey_variant(struct mlx5_ib_dev *dev, struct mlx5_core_mkey *mkey,
+		    u32 *in)
+{
+	u8 key = atomic_inc_return(&dev->mkey_var);
+	void *mkc;
+
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	MLX5_SET(mkc, mkc, mkey_7_0, key);
+	mkey->key = key;
+}
+
+static int
+mlx5_ib_create_mkey(struct mlx5_ib_dev *dev, struct mlx5_core_mkey *mkey,
+		    u32 *in, int inlen)
+{
+	assign_mkey_variant(dev, mkey, in);
+	return mlx5_core_create_mkey(dev->mdev, mkey, in, inlen);
+}
+
+static int
+mlx5_ib_create_mkey_cb(struct mlx5_ib_dev *dev,
+		       struct mlx5_core_mkey *mkey,
+		       struct mlx5_async_ctx *async_ctx,
+		       u32 *in, int inlen, u32 *out, int outlen,
+		       struct mlx5_async_work *context)
+{
+	MLX5_SET(create_mkey_in, in, opcode, MLX5_CMD_OP_CREATE_MKEY);
+	assign_mkey_variant(dev, mkey, in);
+	return mlx5_cmd_exec_cb(async_ctx, in, inlen, out, outlen,
+				create_mkey_callback, context);
+}
+
 static void clean_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static void dereg_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
 static int mr_cache_max_order(struct mlx5_ib_dev *dev);
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr);
+static void queue_adjust_cache_locked(struct mlx5_cache_ent *ent);
 
 static bool umr_can_use_indirect_mkey(struct mlx5_ib_dev *dev)
 {
@@ -59,85 +118,79 @@
 
 static int destroy_mkey(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	int err = mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
+	WARN_ON(xa_load(&dev->odp_mkeys, mlx5_base_mkey(mr->mmkey.key)));
 
-	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
-		/* Wait until all page fault handlers using the mr complete. */
-		synchronize_srcu(&dev->mr_srcu);
-
-	return err;
+	return mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 }
 
-static int order2idx(struct mlx5_ib_dev *dev, int order)
-{
-	struct mlx5_mr_cache *cache = &dev->cache;
-
-	if (order < cache->ent[0].order)
-		return 0;
-	else
-		return order - cache->ent[0].order;
-}
-
-static bool use_umr_mtt_update(struct mlx5_ib_mr *mr, u64 start, u64 length)
+static inline bool mlx5_ib_pas_fits_in_mr(struct mlx5_ib_mr *mr, u64 start,
+					  u64 length)
 {
 	return ((u64)1 << mr->order) * MLX5_ADAPTER_PAGE_SIZE >=
 		length + (start & (MLX5_ADAPTER_PAGE_SIZE - 1));
 }
 
-static void reg_mr_callback(int status, struct mlx5_async_work *context)
+static void create_mkey_callback(int status, struct mlx5_async_work *context)
 {
 	struct mlx5_ib_mr *mr =
 		container_of(context, struct mlx5_ib_mr, cb_work);
 	struct mlx5_ib_dev *dev = mr->dev;
-	struct mlx5_mr_cache *cache = &dev->cache;
-	int c = order2idx(dev, mr->order);
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	u8 key;
+	struct mlx5_cache_ent *ent = mr->cache_ent;
 	unsigned long flags;
-	struct xarray *mkeys = &dev->mdev->priv.mkey_table;
-	int err;
 
-	spin_lock_irqsave(&ent->lock, flags);
-	ent->pending--;
-	spin_unlock_irqrestore(&ent->lock, flags);
 	if (status) {
 		mlx5_ib_warn(dev, "async reg mr failed. status %d\n", status);
 		kfree(mr);
-		dev->fill_delay = 1;
+		spin_lock_irqsave(&ent->lock, flags);
+		ent->pending--;
+		WRITE_ONCE(dev->fill_delay, 1);
+		spin_unlock_irqrestore(&ent->lock, flags);
 		mod_timer(&dev->delay_timer, jiffies + HZ);
 		return;
 	}
 
 	mr->mmkey.type = MLX5_MKEY_MR;
-	spin_lock_irqsave(&dev->mdev->priv.mkey_lock, flags);
-	key = dev->mdev->priv.mkey_key++;
-	spin_unlock_irqrestore(&dev->mdev->priv.mkey_lock, flags);
-	mr->mmkey.key = mlx5_idx_to_mkey(MLX5_GET(create_mkey_out, mr->out, mkey_index)) | key;
+	mr->mmkey.key |= mlx5_idx_to_mkey(
+		MLX5_GET(create_mkey_out, mr->out, mkey_index));
 
-	cache->last_add = jiffies;
+	WRITE_ONCE(dev->cache.last_add, jiffies);
 
 	spin_lock_irqsave(&ent->lock, flags);
 	list_add_tail(&mr->list, &ent->head);
-	ent->cur++;
-	ent->size++;
+	ent->available_mrs++;
+	ent->total_mrs++;
+	/* If we are doing fill_to_high_water then keep going. */
+	queue_adjust_cache_locked(ent);
+	ent->pending--;
 	spin_unlock_irqrestore(&ent->lock, flags);
-
-	xa_lock_irqsave(mkeys, flags);
-	err = xa_err(__xa_store(mkeys, mlx5_base_mkey(mr->mmkey.key),
-				&mr->mmkey, GFP_ATOMIC));
-	xa_unlock_irqrestore(mkeys, flags);
-	if (err)
-		pr_err("Error inserting to mkey tree. 0x%x\n", -err);
-
-	if (!completion_done(&ent->compl))
-		complete(&ent->compl);
 }
 
-static int add_keys(struct mlx5_ib_dev *dev, int c, int num)
+static struct mlx5_ib_mr *alloc_cache_mr(struct mlx5_cache_ent *ent, void *mkc)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
+	struct mlx5_ib_mr *mr;
+
+	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+	if (!mr)
+		return NULL;
+	mr->order = ent->order;
+	mr->cache_ent = ent;
+	mr->dev = ent->dev;
+
+	set_mkc_access_pd_addr_fields(mkc, 0, 0, ent->dev->umrc.pd);
+	MLX5_SET(mkc, mkc, free, 1);
+	MLX5_SET(mkc, mkc, umr_en, 1);
+	MLX5_SET(mkc, mkc, access_mode_1_0, ent->access_mode & 0x3);
+	MLX5_SET(mkc, mkc, access_mode_4_2, (ent->access_mode >> 2) & 0x7);
+
+	MLX5_SET(mkc, mkc, translations_octword_size, ent->xlt);
+	MLX5_SET(mkc, mkc, log_page_size, ent->page);
+	return mr;
+}
+
+/* Asynchronously schedule new MRs to be populated in the cache. */
+static int add_keys(struct mlx5_cache_ent *ent, unsigned int num)
+{
+	size_t inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -150,42 +203,29 @@
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 	for (i = 0; i < num; i++) {
-		if (ent->pending >= MAX_PENDING_REG_MR) {
-			err = -EAGAIN;
-			break;
-		}
-
-		mr = kzalloc(sizeof(*mr), GFP_KERNEL);
+		mr = alloc_cache_mr(ent, mkc);
 		if (!mr) {
 			err = -ENOMEM;
 			break;
 		}
-		mr->order = ent->order;
-		mr->allocated_from_cache = 1;
-		mr->dev = dev;
-
-		MLX5_SET(mkc, mkc, free, 1);
-		MLX5_SET(mkc, mkc, umr_en, 1);
-		MLX5_SET(mkc, mkc, access_mode_1_0, ent->access_mode & 0x3);
-		MLX5_SET(mkc, mkc, access_mode_4_2,
-			 (ent->access_mode >> 2) & 0x7);
-
-		MLX5_SET(mkc, mkc, qpn, 0xffffff);
-		MLX5_SET(mkc, mkc, translations_octword_size, ent->xlt);
-		MLX5_SET(mkc, mkc, log_page_size, ent->page);
-
 		spin_lock_irq(&ent->lock);
+		if (ent->pending >= MAX_PENDING_REG_MR) {
+			err = -EAGAIN;
+			spin_unlock_irq(&ent->lock);
+			kfree(mr);
+			break;
+		}
 		ent->pending++;
 		spin_unlock_irq(&ent->lock);
-		err = mlx5_core_create_mkey_cb(dev->mdev, &mr->mmkey,
-					       &dev->async_ctx, in, inlen,
-					       mr->out, sizeof(mr->out),
-					       reg_mr_callback, &mr->cb_work);
+		err = mlx5_ib_create_mkey_cb(ent->dev, &mr->mmkey,
+					     &ent->dev->async_ctx, in, inlen,
+					     mr->out, sizeof(mr->out),
+					     &mr->cb_work);
 		if (err) {
 			spin_lock_irq(&ent->lock);
 			ent->pending--;
 			spin_unlock_irq(&ent->lock);
-			mlx5_ib_warn(dev, "create mkey failed %d\n", err);
+			mlx5_ib_warn(ent->dev, "create mkey failed %d\n", err);
 			kfree(mr);
 			break;
 		}
@@ -195,35 +235,89 @@
 	return err;
 }
 
-static void remove_keys(struct mlx5_ib_dev *dev, int c, int num)
+/* Synchronously create a MR in the cache */
+static struct mlx5_ib_mr *create_cache_mr(struct mlx5_cache_ent *ent)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent = &cache->ent[c];
-	struct mlx5_ib_mr *tmp_mr;
+	size_t inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
 	struct mlx5_ib_mr *mr;
-	LIST_HEAD(del_list);
-	int i;
+	void *mkc;
+	u32 *in;
+	int err;
 
-	for (i = 0; i < num; i++) {
-		spin_lock_irq(&ent->lock);
-		if (list_empty(&ent->head)) {
-			spin_unlock_irq(&ent->lock);
-			break;
-		}
-		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
-		list_move(&mr->list, &del_list);
-		ent->cur--;
-		ent->size--;
-		spin_unlock_irq(&ent->lock);
-		mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
+	in = kzalloc(inlen, GFP_KERNEL);
+	if (!in)
+		return ERR_PTR(-ENOMEM);
+	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+
+	mr = alloc_cache_mr(ent, mkc);
+	if (!mr) {
+		err = -ENOMEM;
+		goto free_in;
 	}
 
-	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
-		synchronize_srcu(&dev->mr_srcu);
+	err = mlx5_core_create_mkey(ent->dev->mdev, &mr->mmkey, in, inlen);
+	if (err)
+		goto free_mr;
 
-	list_for_each_entry_safe(mr, tmp_mr, &del_list, list) {
-		list_del(&mr->list);
-		kfree(mr);
+	mr->mmkey.type = MLX5_MKEY_MR;
+	WRITE_ONCE(ent->dev->cache.last_add, jiffies);
+	spin_lock_irq(&ent->lock);
+	ent->total_mrs++;
+	spin_unlock_irq(&ent->lock);
+	kfree(in);
+	return mr;
+free_mr:
+	kfree(mr);
+free_in:
+	kfree(in);
+	return ERR_PTR(err);
+}
+
+static void remove_cache_mr_locked(struct mlx5_cache_ent *ent)
+{
+	struct mlx5_ib_mr *mr;
+
+	lockdep_assert_held(&ent->lock);
+	if (list_empty(&ent->head))
+		return;
+	mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
+	list_del(&mr->list);
+	ent->available_mrs--;
+	ent->total_mrs--;
+	spin_unlock_irq(&ent->lock);
+	mlx5_core_destroy_mkey(ent->dev->mdev, &mr->mmkey);
+	kfree(mr);
+	spin_lock_irq(&ent->lock);
+}
+
+static int resize_available_mrs(struct mlx5_cache_ent *ent, unsigned int target,
+				bool limit_fill)
+{
+	int err;
+
+	lockdep_assert_held(&ent->lock);
+
+	while (true) {
+		if (limit_fill)
+			target = ent->limit * 2;
+		if (target == ent->available_mrs + ent->pending)
+			return 0;
+		if (target > ent->available_mrs + ent->pending) {
+			u32 todo = target - (ent->available_mrs + ent->pending);
+
+			spin_unlock_irq(&ent->lock);
+			err = add_keys(ent, todo);
+			if (err == -EAGAIN)
+				usleep_range(3000, 5000);
+			spin_lock_irq(&ent->lock);
+			if (err) {
+				if (err != -EAGAIN)
+					return err;
+			} else
+				return 0;
+		} else {
+			remove_cache_mr_locked(ent);
+		}
 	}
 }
 
@@ -231,37 +325,38 @@
 			  size_t count, loff_t *pos)
 {
 	struct mlx5_cache_ent *ent = filp->private_data;
-	struct mlx5_ib_dev *dev = ent->dev;
-	char lbuf[20] = {0};
-	u32 var;
+	u32 target;
 	int err;
-	int c;
 
-	count = min(count, sizeof(lbuf) - 1);
-	if (copy_from_user(lbuf, buf, count))
-		return -EFAULT;
+	err = kstrtou32_from_user(buf, count, 0, &target);
+	if (err)
+		return err;
 
-	c = order2idx(dev, ent->order);
-
-	if (sscanf(lbuf, "%u", &var) != 1)
-		return -EINVAL;
-
-	if (var < ent->limit)
-		return -EINVAL;
-
-	if (var > ent->size) {
-		do {
-			err = add_keys(dev, c, var - ent->size);
-			if (err && err != -EAGAIN)
-				return err;
-
-			usleep_range(3000, 5000);
-		} while (err);
-	} else if (var < ent->size) {
-		remove_keys(dev, c, ent->size - var);
+	/*
+	 * Target is the new value of total_mrs the user requests, however we
+	 * cannot free MRs that are in use. Compute the target value for
+	 * available_mrs.
+	 */
+	spin_lock_irq(&ent->lock);
+	if (target < ent->total_mrs - ent->available_mrs) {
+		err = -EINVAL;
+		goto err_unlock;
 	}
+	target = target - (ent->total_mrs - ent->available_mrs);
+	if (target < ent->limit || target > ent->limit*2) {
+		err = -EINVAL;
+		goto err_unlock;
+	}
+	err = resize_available_mrs(ent, target, false);
+	if (err)
+		goto err_unlock;
+	spin_unlock_irq(&ent->lock);
 
 	return count;
+
+err_unlock:
+	spin_unlock_irq(&ent->lock);
+	return err;
 }
 
 static ssize_t size_read(struct file *filp, char __user *buf, size_t count,
@@ -271,7 +366,7 @@
 	char lbuf[20];
 	int err;
 
-	err = snprintf(lbuf, sizeof(lbuf), "%d\n", ent->size);
+	err = snprintf(lbuf, sizeof(lbuf), "%d\n", ent->total_mrs);
 	if (err < 0)
 		return err;
 
@@ -289,32 +384,23 @@
 			   size_t count, loff_t *pos)
 {
 	struct mlx5_cache_ent *ent = filp->private_data;
-	struct mlx5_ib_dev *dev = ent->dev;
-	char lbuf[20] = {0};
 	u32 var;
 	int err;
-	int c;
 
-	count = min(count, sizeof(lbuf) - 1);
-	if (copy_from_user(lbuf, buf, count))
-		return -EFAULT;
+	err = kstrtou32_from_user(buf, count, 0, &var);
+	if (err)
+		return err;
 
-	c = order2idx(dev, ent->order);
-
-	if (sscanf(lbuf, "%u", &var) != 1)
-		return -EINVAL;
-
-	if (var > ent->size)
-		return -EINVAL;
-
+	/*
+	 * Upon set we immediately fill the cache to high water mark implied by
+	 * the limit.
+	 */
+	spin_lock_irq(&ent->lock);
 	ent->limit = var;
-
-	if (ent->cur < ent->limit) {
-		err = add_keys(dev, c, 2 * ent->limit - ent->cur);
-		if (err)
-			return err;
-	}
-
+	err = resize_available_mrs(ent, 0, true);
+	spin_unlock_irq(&ent->lock);
+	if (err)
+		return err;
 	return count;
 }
 
@@ -339,68 +425,119 @@
 	.read	= limit_read,
 };
 
-static int someone_adding(struct mlx5_mr_cache *cache)
+static bool someone_adding(struct mlx5_mr_cache *cache)
 {
-	int i;
+	unsigned int i;
 
 	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
-		if (cache->ent[i].cur < cache->ent[i].limit)
-			return 1;
-	}
+		struct mlx5_cache_ent *ent = &cache->ent[i];
+		bool ret;
 
-	return 0;
+		spin_lock_irq(&ent->lock);
+		ret = ent->available_mrs < ent->limit;
+		spin_unlock_irq(&ent->lock);
+		if (ret)
+			return true;
+	}
+	return false;
+}
+
+/*
+ * Check if the bucket is outside the high/low water mark and schedule an async
+ * update. The cache refill has hysteresis, once the low water mark is hit it is
+ * refilled up to the high mark.
+ */
+static void queue_adjust_cache_locked(struct mlx5_cache_ent *ent)
+{
+	lockdep_assert_held(&ent->lock);
+
+	if (ent->disabled || READ_ONCE(ent->dev->fill_delay))
+		return;
+	if (ent->available_mrs < ent->limit) {
+		ent->fill_to_high_water = true;
+		queue_work(ent->dev->cache.wq, &ent->work);
+	} else if (ent->fill_to_high_water &&
+		   ent->available_mrs + ent->pending < 2 * ent->limit) {
+		/*
+		 * Once we start populating due to hitting a low water mark
+		 * continue until we pass the high water mark.
+		 */
+		queue_work(ent->dev->cache.wq, &ent->work);
+	} else if (ent->available_mrs == 2 * ent->limit) {
+		ent->fill_to_high_water = false;
+	} else if (ent->available_mrs > 2 * ent->limit) {
+		/* Queue deletion of excess entries */
+		ent->fill_to_high_water = false;
+		if (ent->pending)
+			queue_delayed_work(ent->dev->cache.wq, &ent->dwork,
+					   msecs_to_jiffies(1000));
+		else
+			queue_work(ent->dev->cache.wq, &ent->work);
+	}
 }
 
 static void __cache_work_func(struct mlx5_cache_ent *ent)
 {
 	struct mlx5_ib_dev *dev = ent->dev;
 	struct mlx5_mr_cache *cache = &dev->cache;
-	int i = order2idx(dev, ent->order);
 	int err;
 
-	if (cache->stopped)
-		return;
+	spin_lock_irq(&ent->lock);
+	if (ent->disabled)
+		goto out;
 
-	ent = &dev->cache.ent[i];
-	if (ent->cur < 2 * ent->limit && !dev->fill_delay) {
-		err = add_keys(dev, i, 1);
-		if (ent->cur < 2 * ent->limit) {
-			if (err == -EAGAIN) {
-				mlx5_ib_dbg(dev, "returned eagain, order %d\n",
-					    i + 2);
-				queue_delayed_work(cache->wq, &ent->dwork,
-						   msecs_to_jiffies(3));
-			} else if (err) {
-				mlx5_ib_warn(dev, "command failed order %d, err %d\n",
-					     i + 2, err);
+	if (ent->fill_to_high_water &&
+	    ent->available_mrs + ent->pending < 2 * ent->limit &&
+	    !READ_ONCE(dev->fill_delay)) {
+		spin_unlock_irq(&ent->lock);
+		err = add_keys(ent, 1);
+		spin_lock_irq(&ent->lock);
+		if (ent->disabled)
+			goto out;
+		if (err) {
+			/*
+			 * EAGAIN only happens if pending is positive, so we
+			 * will be rescheduled from create_mkey_callback(). The
+			 * only failure path here is ENOMEM.
+			 */
+			if (err != -EAGAIN) {
+				mlx5_ib_warn(
+					dev,
+					"command failed order %d, err %d\n",
+					ent->order, err);
 				queue_delayed_work(cache->wq, &ent->dwork,
 						   msecs_to_jiffies(1000));
-			} else {
-				queue_work(cache->wq, &ent->work);
 			}
 		}
-	} else if (ent->cur > 2 * ent->limit) {
+	} else if (ent->available_mrs > 2 * ent->limit) {
+		bool need_delay;
+
 		/*
-		 * The remove_keys() logic is performed as garbage collection
-		 * task. Such task is intended to be run when no other active
-		 * processes are running.
+		 * The remove_cache_mr_locked() logic is performed as garbage
+		 * collection task. Such task is intended to be run when no
+		 * other active processes are running.
 		 *
 		 * The need_resched() will return TRUE if there are user tasks
 		 * to be activated in near future.
 		 *
-		 * In such case, we don't execute remove_keys() and postpone
-		 * the garbage collection work to try to run in next cycle,
-		 * in order to free CPU resources to other tasks.
+		 * In such case, we don't execute remove_cache_mr_locked() and
+		 * postpone the garbage collection work to try to run in next
+		 * cycle, in order to free CPU resources to other tasks.
 		 */
-		if (!need_resched() && !someone_adding(cache) &&
-		    time_after(jiffies, cache->last_add + 300 * HZ)) {
-			remove_keys(dev, i, 1);
-			if (ent->cur > ent->limit)
-				queue_work(cache->wq, &ent->work);
-		} else {
+		spin_unlock_irq(&ent->lock);
+		need_delay = need_resched() || someone_adding(cache) ||
+			     !time_after(jiffies,
+					 READ_ONCE(cache->last_add) + 300 * HZ);
+		spin_lock_irq(&ent->lock);
+		if (ent->disabled)
+			goto out;
+		if (need_delay)
 			queue_delayed_work(cache->wq, &ent->dwork, 300 * HZ);
-		}
+		remove_cache_mr_locked(ent);
+		queue_adjust_cache_locked(ent);
 	}
+out:
+	spin_unlock_irq(&ent->lock);
 }
 
 static void delayed_cache_work_func(struct work_struct *work)
@@ -419,117 +556,101 @@
 	__cache_work_func(ent);
 }
 
-struct mlx5_ib_mr *mlx5_mr_cache_alloc(struct mlx5_ib_dev *dev, int entry)
+/* Allocate a special entry from the cache */
+struct mlx5_ib_mr *mlx5_mr_cache_alloc(struct mlx5_ib_dev *dev,
+				       unsigned int entry, int access_flags)
 {
 	struct mlx5_mr_cache *cache = &dev->cache;
 	struct mlx5_cache_ent *ent;
 	struct mlx5_ib_mr *mr;
-	int err;
 
-	if (entry < 0 || entry >= MAX_MR_CACHE_ENTRIES) {
-		mlx5_ib_err(dev, "cache entry %d is out of range\n", entry);
+	if (WARN_ON(entry <= MR_CACHE_LAST_STD_ENTRY ||
+		    entry >= ARRAY_SIZE(cache->ent)))
 		return ERR_PTR(-EINVAL);
-	}
+
+	/* Matches access in alloc_cache_mr() */
+	if (!mlx5_ib_can_reconfig_with_umr(dev, 0, access_flags))
+		return ERR_PTR(-EOPNOTSUPP);
 
 	ent = &cache->ent[entry];
-	while (1) {
-		spin_lock_irq(&ent->lock);
-		if (list_empty(&ent->head)) {
-			spin_unlock_irq(&ent->lock);
-
-			err = add_keys(dev, entry, 1);
-			if (err && err != -EAGAIN)
-				return ERR_PTR(err);
-
-			wait_for_completion(&ent->compl);
-		} else {
-			mr = list_first_entry(&ent->head, struct mlx5_ib_mr,
-					      list);
-			list_del(&mr->list);
-			ent->cur--;
-			spin_unlock_irq(&ent->lock);
-			if (ent->cur < ent->limit)
-				queue_work(cache->wq, &ent->work);
+	spin_lock_irq(&ent->lock);
+	if (list_empty(&ent->head)) {
+		spin_unlock_irq(&ent->lock);
+		mr = create_cache_mr(ent);
+		if (IS_ERR(mr))
 			return mr;
-		}
+	} else {
+		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
+		list_del(&mr->list);
+		ent->available_mrs--;
+		queue_adjust_cache_locked(ent);
+		spin_unlock_irq(&ent->lock);
 	}
+	mr->access_flags = access_flags;
+	return mr;
 }
 
-static struct mlx5_ib_mr *alloc_cached_mr(struct mlx5_ib_dev *dev, int order)
+/* Return a MR already available in the cache */
+static struct mlx5_ib_mr *get_cache_mr(struct mlx5_cache_ent *req_ent)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
+	struct mlx5_ib_dev *dev = req_ent->dev;
 	struct mlx5_ib_mr *mr = NULL;
-	struct mlx5_cache_ent *ent;
-	int last_umr_cache_entry;
-	int c;
-	int i;
+	struct mlx5_cache_ent *ent = req_ent;
 
-	c = order2idx(dev, order);
-	last_umr_cache_entry = order2idx(dev, mr_cache_max_order(dev));
-	if (c < 0 || c > last_umr_cache_entry) {
-		mlx5_ib_warn(dev, "order %d, cache index %d\n", order, c);
-		return NULL;
-	}
-
-	for (i = c; i <= last_umr_cache_entry; i++) {
-		ent = &cache->ent[i];
-
-		mlx5_ib_dbg(dev, "order %d, cache index %d\n", ent->order, i);
+	/* Try larger MR pools from the cache to satisfy the allocation */
+	for (; ent != &dev->cache.ent[MR_CACHE_LAST_STD_ENTRY + 1]; ent++) {
+		mlx5_ib_dbg(dev, "order %u, cache index %zu\n", ent->order,
+			    ent - dev->cache.ent);
 
 		spin_lock_irq(&ent->lock);
 		if (!list_empty(&ent->head)) {
 			mr = list_first_entry(&ent->head, struct mlx5_ib_mr,
 					      list);
 			list_del(&mr->list);
-			ent->cur--;
+			ent->available_mrs--;
+			queue_adjust_cache_locked(ent);
 			spin_unlock_irq(&ent->lock);
-			if (ent->cur < ent->limit)
-				queue_work(cache->wq, &ent->work);
 			break;
 		}
+		queue_adjust_cache_locked(ent);
 		spin_unlock_irq(&ent->lock);
-
-		queue_work(cache->wq, &ent->work);
 	}
 
 	if (!mr)
-		cache->ent[c].miss++;
+		req_ent->miss++;
 
 	return mr;
 }
 
+static void detach_mr_from_cache(struct mlx5_ib_mr *mr)
+{
+	struct mlx5_cache_ent *ent = mr->cache_ent;
+
+	mr->cache_ent = NULL;
+	spin_lock_irq(&ent->lock);
+	ent->total_mrs--;
+	spin_unlock_irq(&ent->lock);
+}
+
 void mlx5_mr_cache_free(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	struct mlx5_mr_cache *cache = &dev->cache;
-	struct mlx5_cache_ent *ent;
-	int shrink = 0;
-	int c;
+	struct mlx5_cache_ent *ent = mr->cache_ent;
 
-	if (!mr->allocated_from_cache)
+	if (!ent)
 		return;
 
-	c = order2idx(dev, mr->order);
-	WARN_ON(c < 0 || c >= MAX_MR_CACHE_ENTRIES);
-
-	if (unreg_umr(dev, mr)) {
-		mr->allocated_from_cache = false;
+	if (mlx5_mr_cache_invalidate(mr)) {
+		detach_mr_from_cache(mr);
 		destroy_mkey(dev, mr);
-		ent = &cache->ent[c];
-		if (ent->cur < ent->limit)
-			queue_work(cache->wq, &ent->work);
+		kfree(mr);
 		return;
 	}
 
-	ent = &cache->ent[c];
 	spin_lock_irq(&ent->lock);
 	list_add_tail(&mr->list, &ent->head);
-	ent->cur++;
-	if (ent->cur > 2 * ent->limit)
-		shrink = 1;
+	ent->available_mrs++;
+	queue_adjust_cache_locked(ent);
 	spin_unlock_irq(&ent->lock);
-
-	if (shrink)
-		queue_work(cache->wq, &ent->work);
 }
 
 static void clean_keys(struct mlx5_ib_dev *dev, int c)
@@ -549,16 +670,12 @@
 		}
 		mr = list_first_entry(&ent->head, struct mlx5_ib_mr, list);
 		list_move(&mr->list, &del_list);
-		ent->cur--;
-		ent->size--;
+		ent->available_mrs--;
+		ent->total_mrs--;
 		spin_unlock_irq(&ent->lock);
 		mlx5_core_destroy_mkey(dev->mdev, &mr->mmkey);
 	}
 
-#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
-	synchronize_srcu(&dev->mr_srcu);
-#endif
-
 	list_for_each_entry_safe(mr, tmp_mr, &del_list, list) {
 		list_del(&mr->list);
 		kfree(mr);
@@ -592,7 +709,7 @@
 		dir = debugfs_create_dir(ent->name, cache->root);
 		debugfs_create_file("size", 0600, dir, ent, &size_fops);
 		debugfs_create_file("limit", 0600, dir, ent, &limit_fops);
-		debugfs_create_u32("cur", 0400, dir, &ent->cur);
+		debugfs_create_u32("cur", 0400, dir, &ent->available_mrs);
 		debugfs_create_u32("miss", 0600, dir, &ent->miss);
 	}
 }
@@ -601,7 +718,7 @@
 {
 	struct mlx5_ib_dev *dev = from_timer(dev, t, delay_timer);
 
-	dev->fill_delay = 0;
+	WRITE_ONCE(dev->fill_delay, 0);
 }
 
 int mlx5_mr_cache_init(struct mlx5_ib_dev *dev)
@@ -627,7 +744,6 @@
 		ent->dev = dev;
 		ent->limit = 0;
 
-		init_completion(&ent->compl);
 		INIT_WORK(&ent->work, cache_work_func);
 		INIT_DELAYED_WORK(&ent->dwork, delayed_cache_work_func);
 
@@ -644,12 +760,14 @@
 			   MLX5_IB_UMR_OCTOWORD;
 		ent->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
 		if ((dev->mdev->profile->mask & MLX5_PROF_MASK_MR_CACHE) &&
-		    !dev->is_rep &&
-		    mlx5_core_is_pf(dev->mdev))
+		    !dev->is_rep && mlx5_core_is_pf(dev->mdev) &&
+		    mlx5_ib_can_load_pas_with_umr(dev, 0))
 			ent->limit = dev->mdev->profile->mr_cache[i].limit;
 		else
 			ent->limit = 0;
-		queue_work(cache->wq, &ent->work);
+		spin_lock_irq(&ent->lock);
+		queue_adjust_cache_locked(ent);
+		spin_unlock_irq(&ent->lock);
 	}
 
 	mlx5_mr_cache_debugfs_init(dev);
@@ -659,13 +777,20 @@
 
 int mlx5_mr_cache_cleanup(struct mlx5_ib_dev *dev)
 {
-	int i;
+	unsigned int i;
 
 	if (!dev->cache.wq)
 		return 0;
 
-	dev->cache.stopped = 1;
-	flush_workqueue(dev->cache.wq);
+	for (i = 0; i < MAX_MR_CACHE_ENTRIES; i++) {
+		struct mlx5_cache_ent *ent = &dev->cache.ent[i];
+
+		spin_lock_irq(&ent->lock);
+		ent->disabled = true;
+		spin_unlock_irq(&ent->lock);
+		cancel_work_sync(&ent->work);
+		cancel_delayed_work_sync(&ent->dwork);
+	}
 
 	mlx5_mr_cache_debugfs_cleanup(dev);
 	mlx5_cmd_cleanup_async_ctx(&dev->async_ctx);
@@ -683,7 +808,6 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -702,18 +826,10 @@
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_PA);
-	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
-
 	MLX5_SET(mkc, mkc, length64, 1);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET64(mkc, mkc, start_addr, 0);
+	set_mkc_access_pd_addr_fields(mkc, acc, 0, pd);
 
-	err = mlx5_core_create_mkey(mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err)
 		goto err_in;
 
@@ -752,10 +868,9 @@
 	return MLX5_MAX_UMR_SHIFT;
 }
 
-static int mr_umem_get(struct mlx5_ib_dev *dev, struct ib_udata *udata,
-		       u64 start, u64 length, int access_flags,
-		       struct ib_umem **umem, int *npages, int *page_shift,
-		       int *ncont, int *order)
+static int mr_umem_get(struct mlx5_ib_dev *dev, u64 start, u64 length,
+		       int access_flags, struct ib_umem **umem, int *npages,
+		       int *page_shift, int *ncont, int *order)
 {
 	struct ib_umem *u;
 
@@ -764,7 +879,8 @@
 	if (access_flags & IB_ACCESS_ON_DEMAND) {
 		struct ib_umem_odp *odp;
 
-		odp = ib_umem_odp_get(udata, start, length, access_flags);
+		odp = ib_umem_odp_get(&dev->ib_dev, start, length, access_flags,
+				      &mlx5_mn_ops);
 		if (IS_ERR(odp)) {
 			mlx5_ib_dbg(dev, "umem get failed (%ld)\n",
 				    PTR_ERR(odp));
@@ -779,7 +895,7 @@
 		if (order)
 			*order = ilog2(roundup_pow_of_two(*ncont));
 	} else {
-		u = ib_umem_get(udata, start, length, access_flags, 0);
+		u = ib_umem_get(&dev->ib_dev, start, length, access_flags);
 		if (IS_ERR(u)) {
 			mlx5_ib_dbg(dev, "umem get failed (%ld)\n", PTR_ERR(u));
 			return PTR_ERR(u);
@@ -846,31 +962,42 @@
 	return err;
 }
 
-static struct mlx5_ib_mr *alloc_mr_from_cache(
-				  struct ib_pd *pd, struct ib_umem *umem,
-				  u64 virt_addr, u64 len, int npages,
-				  int page_shift, int order, int access_flags)
+static struct mlx5_cache_ent *mr_cache_ent_from_order(struct mlx5_ib_dev *dev,
+						      unsigned int order)
+{
+	struct mlx5_mr_cache *cache = &dev->cache;
+
+	if (order < cache->ent[0].order)
+		return &cache->ent[0];
+	order = order - cache->ent[0].order;
+	if (order > MR_CACHE_LAST_STD_ENTRY)
+		return NULL;
+	return &cache->ent[order];
+}
+
+static struct mlx5_ib_mr *
+alloc_mr_from_cache(struct ib_pd *pd, struct ib_umem *umem, u64 virt_addr,
+		    u64 len, int npages, int page_shift, unsigned int order,
+		    int access_flags)
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	struct mlx5_cache_ent *ent = mr_cache_ent_from_order(dev, order);
 	struct mlx5_ib_mr *mr;
-	int err = 0;
-	int i;
 
-	for (i = 0; i < 1; i++) {
-		mr = alloc_cached_mr(dev, order);
-		if (mr)
-			break;
+	if (!ent)
+		return ERR_PTR(-E2BIG);
 
-		err = add_keys(dev, order2idx(dev, order), 1);
-		if (err && err != -EAGAIN) {
-			mlx5_ib_warn(dev, "add_keys failed, err %d\n", err);
-			break;
-		}
+	/* Matches access in alloc_cache_mr() */
+	if (!mlx5_ib_can_reconfig_with_umr(dev, 0, access_flags))
+		return ERR_PTR(-EOPNOTSUPP);
+
+	mr = get_cache_mr(ent);
+	if (!mr) {
+		mr = create_cache_mr(ent);
+		if (IS_ERR(mr))
+			return mr;
 	}
 
-	if (!mr)
-		return ERR_PTR(-EAGAIN);
-
 	mr->ibmr.pd = pd;
 	mr->umem = umem;
 	mr->access_flags = access_flags;
@@ -882,36 +1009,6 @@
 	return mr;
 }
 
-static inline int populate_xlt(struct mlx5_ib_mr *mr, int idx, int npages,
-			       void *xlt, int page_shift, size_t size,
-			       int flags)
-{
-	struct mlx5_ib_dev *dev = mr->dev;
-	struct ib_umem *umem = mr->umem;
-
-	if (flags & MLX5_IB_UPD_XLT_INDIRECT) {
-		if (!umr_can_use_indirect_mkey(dev))
-			return -EPERM;
-		mlx5_odp_populate_klm(xlt, idx, npages, mr, flags);
-		return npages;
-	}
-
-	npages = min_t(size_t, npages, ib_umem_num_pages(umem) - idx);
-
-	if (!(flags & MLX5_IB_UPD_XLT_ZAP)) {
-		__mlx5_ib_populate_pas(dev, umem, page_shift,
-				       idx, npages, xlt,
-				       MLX5_IB_MTT_PRESENT);
-		/* Clear padding after the pages
-		 * brought from the umem.
-		 */
-		memset(xlt + (npages * sizeof(struct mlx5_mtt)), 0,
-		       size - npages * sizeof(struct mlx5_mtt));
-	}
-
-	return npages;
-}
-
 #define MLX5_MAX_UMR_CHUNK ((1 << (MLX5_MAX_UMR_SHIFT + 4)) - \
 			    MLX5_UMR_MTT_ALIGNMENT)
 #define MLX5_SPARE_UMR_CHUNK 0x10000
@@ -935,6 +1032,7 @@
 	size_t pages_mapped = 0;
 	size_t pages_to_map = 0;
 	size_t pages_iter = 0;
+	size_t size_to_map = 0;
 	gfp_t gfp;
 	bool use_emergency_page = false;
 
@@ -981,6 +1079,15 @@
 		goto free_xlt;
 	}
 
+	if (mr->umem->is_odp) {
+		if (!(flags & MLX5_IB_UPD_XLT_INDIRECT)) {
+			struct ib_umem_odp *odp = to_ib_umem_odp(mr->umem);
+			size_t max_pages = ib_umem_odp_num_pages(odp) - idx;
+
+			pages_to_map = min_t(size_t, pages_to_map, max_pages);
+		}
+	}
+
 	sg.addr = dma;
 	sg.lkey = dev->umrc.pd->local_dma_lkey;
 
@@ -1003,14 +1110,22 @@
 	     pages_mapped < pages_to_map && !err;
 	     pages_mapped += pages_iter, idx += pages_iter) {
 		npages = min_t(int, pages_iter, pages_to_map - pages_mapped);
+		size_to_map = npages * desc_size;
 		dma_sync_single_for_cpu(ddev, dma, size, DMA_TO_DEVICE);
-		npages = populate_xlt(mr, idx, npages, xlt,
-				      page_shift, size, flags);
-
+		if (mr->umem->is_odp) {
+			mlx5_odp_populate_xlt(xlt, idx, npages, mr, flags);
+		} else {
+			__mlx5_ib_populate_pas(dev, mr->umem, page_shift, idx,
+					       npages, xlt,
+					       MLX5_IB_MTT_PRESENT);
+			/* Clear padding after the pages
+			 * brought from the umem.
+			 */
+			memset(xlt + size_to_map, 0, size - size_to_map);
+		}
 		dma_sync_single_for_device(ddev, dma, size, DMA_TO_DEVICE);
 
-		sg.length = ALIGN(npages * desc_size,
-				  MLX5_UMR_MTT_ALIGNMENT);
+		sg.length = ALIGN(size_to_map, MLX5_UMR_MTT_ALIGNMENT);
 
 		if (pages_mapped + pages_iter >= pages_to_map) {
 			if (flags & MLX5_IB_UPD_XLT_ENABLE)
@@ -1078,38 +1193,37 @@
 		goto err_1;
 	}
 	pas = (__be64 *)MLX5_ADDR_OF(create_mkey_in, in, klm_pas_mtt);
-	if (populate && !(access_flags & IB_ACCESS_ON_DEMAND))
+	if (populate) {
+		if (WARN_ON(access_flags & IB_ACCESS_ON_DEMAND)) {
+			err = -EINVAL;
+			goto err_2;
+		}
 		mlx5_ib_populate_pas(dev, umem, page_shift, pas,
 				     pg_cap ? MLX5_IB_MTT_PRESENT : 0);
+	}
 
 	/* The pg_access bit allows setting the access flags
 	 * in the page list submitted with the command. */
 	MLX5_SET(create_mkey_in, in, pg_access, !!(pg_cap));
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
+	set_mkc_access_pd_addr_fields(mkc, access_flags, virt_addr,
+				      populate ? pd : dev->umrc.pd);
 	MLX5_SET(mkc, mkc, free, !populate);
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_MTT);
-	MLX5_SET(mkc, mkc, a, !!(access_flags & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(access_flags & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(access_flags & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(access_flags & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
 	MLX5_SET(mkc, mkc, umr_en, 1);
 
-	MLX5_SET64(mkc, mkc, start_addr, virt_addr);
 	MLX5_SET64(mkc, mkc, len, length);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
 	MLX5_SET(mkc, mkc, bsf_octword_size, 0);
 	MLX5_SET(mkc, mkc, translations_octword_size,
 		 get_octo_len(virt_addr, length, page_shift));
 	MLX5_SET(mkc, mkc, log_page_size, page_shift);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
 	if (populate) {
 		MLX5_SET(create_mkey_in, in, translations_octword_actual_size,
 			 get_octo_len(virt_addr, length, page_shift));
 	}
 
-	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err) {
 		mlx5_ib_warn(dev, "create mkey failed\n");
 		goto err_2;
@@ -1134,10 +1248,8 @@
 }
 
 static void set_mr_fields(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
-			  int npages, u64 length, int access_flags)
+			  u64 length, int access_flags)
 {
-	mr->npages = npages;
-	atomic_add(npages, &dev->mdev->priv.reg_pages);
 	mr->ibmr.lkey = mr->mmkey.key;
 	mr->ibmr.rkey = mr->mmkey.key;
 	mr->ibmr.length = length;
@@ -1149,7 +1261,6 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_ib_mr *mr;
 	void *mkc;
 	u32 *in;
@@ -1169,25 +1280,16 @@
 
 	MLX5_SET(mkc, mkc, access_mode_1_0, mode & 0x3);
 	MLX5_SET(mkc, mkc, access_mode_4_2, (mode >> 2) & 0x7);
-	MLX5_SET(mkc, mkc, a, !!(acc & IB_ACCESS_REMOTE_ATOMIC));
-	MLX5_SET(mkc, mkc, rw, !!(acc & IB_ACCESS_REMOTE_WRITE));
-	MLX5_SET(mkc, mkc, rr, !!(acc & IB_ACCESS_REMOTE_READ));
-	MLX5_SET(mkc, mkc, lw, !!(acc & IB_ACCESS_LOCAL_WRITE));
-	MLX5_SET(mkc, mkc, lr, 1);
-
 	MLX5_SET64(mkc, mkc, len, length);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET64(mkc, mkc, start_addr, start_addr);
+	set_mkc_access_pd_addr_fields(mkc, acc, start_addr, pd);
 
-	err = mlx5_core_create_mkey(mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err)
 		goto err_in;
 
 	kfree(in);
 
-	mr->umem = NULL;
-	set_mr_fields(dev, mr, 0, length, acc);
+	set_mr_fields(dev, mr, length, acc);
 
 	return &mr->ibmr;
 
@@ -1208,7 +1310,8 @@
 		      struct uverbs_attr_bundle *attrs)
 {
 	if (advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH &&
-	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE)
+	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE &&
+	    advice != IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_NO_FAULT)
 		return -EOPNOTSUPP;
 
 	return mlx5_ib_advise_mr_prefetch(pd, advice, flags,
@@ -1253,7 +1356,7 @@
 {
 	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct mlx5_ib_mr *mr = NULL;
-	bool use_umr;
+	bool xlt_with_umr;
 	struct ib_umem *umem;
 	int page_shift;
 	int npages;
@@ -1267,8 +1370,15 @@
 	mlx5_ib_dbg(dev, "start 0x%llx, virt_addr 0x%llx, length 0x%llx, access_flags 0x%x\n",
 		    start, virt_addr, length, access_flags);
 
+	xlt_with_umr = mlx5_ib_can_load_pas_with_umr(dev, length);
+	/* ODP requires xlt update via umr to work. */
+	if (!xlt_with_umr && (access_flags & IB_ACCESS_ON_DEMAND))
+		return ERR_PTR(-EINVAL);
+
 	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING) && !start &&
 	    length == U64_MAX) {
+		if (virt_addr != start)
+			return ERR_PTR(-EINVAL);
 		if (!(access_flags & IB_ACCESS_ON_DEMAND) ||
 		    !(dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
 			return ERR_PTR(-EINVAL);
@@ -1279,34 +1389,23 @@
 		return &mr->ibmr;
 	}
 
-	err = mr_umem_get(dev, udata, start, length, access_flags, &umem,
+	err = mr_umem_get(dev, start, length, access_flags, &umem,
 			  &npages, &page_shift, &ncont, &order);
 
 	if (err < 0)
 		return ERR_PTR(err);
 
-	use_umr = mlx5_ib_can_use_umr(dev, true);
-
-	if (order <= mr_cache_max_order(dev) && use_umr) {
+	if (xlt_with_umr) {
 		mr = alloc_mr_from_cache(pd, umem, virt_addr, length, ncont,
 					 page_shift, order, access_flags);
-		if (PTR_ERR(mr) == -EAGAIN) {
-			mlx5_ib_dbg(dev, "cache empty for order %d\n", order);
+		if (IS_ERR(mr))
 			mr = NULL;
-		}
-	} else if (!MLX5_CAP_GEN(dev->mdev, umr_extended_translation_offset)) {
-		if (access_flags & IB_ACCESS_ON_DEMAND) {
-			err = -EINVAL;
-			pr_err("Got MR registration for ODP MR > 512MB, not supported for Connect-IB\n");
-			goto error;
-		}
-		use_umr = false;
 	}
 
 	if (!mr) {
 		mutex_lock(&dev->slow_path_mutex);
 		mr = reg_create(NULL, pd, virt_addr, length, umem, ncont,
-				page_shift, access_flags, !use_umr);
+				page_shift, access_flags, !xlt_with_umr);
 		mutex_unlock(&dev->slow_path_mutex);
 	}
 
@@ -1318,17 +1417,20 @@
 	mlx5_ib_dbg(dev, "mkey 0x%x\n", mr->mmkey.key);
 
 	mr->umem = umem;
-	set_mr_fields(dev, mr, npages, length, access_flags);
+	mr->npages = npages;
+	atomic_add(mr->npages, &dev->mdev->priv.reg_pages);
+	set_mr_fields(dev, mr, length, access_flags);
 
-	if (use_umr) {
+	if (xlt_with_umr && !(access_flags & IB_ACCESS_ON_DEMAND)) {
+		/*
+		 * If the MR was created with reg_create then it will be
+		 * configured properly but left disabled. It is safe to go ahead
+		 * and configure it again via UMR while enabling it.
+		 */
 		int update_xlt_flags = MLX5_IB_UPD_XLT_ENABLE;
 
-		if (access_flags & IB_ACCESS_ON_DEMAND)
-			update_xlt_flags |= MLX5_IB_UPD_XLT_ZAP;
-
 		err = mlx5_ib_update_xlt(mr, 0, ncont, page_shift,
 					 update_xlt_flags);
-
 		if (err) {
 			dereg_mr(dev, mr);
 			return ERR_PTR(err);
@@ -1337,10 +1439,22 @@
 
 	if (is_odp_mr(mr)) {
 		to_ib_umem_odp(mr->umem)->private = mr;
-		atomic_set(&mr->num_pending_prefetch, 0);
+		init_waitqueue_head(&mr->q_deferred_work);
+		atomic_set(&mr->num_deferred_work, 0);
+		err = xa_err(xa_store(&dev->odp_mkeys,
+				      mlx5_base_mkey(mr->mmkey.key), &mr->mmkey,
+				      GFP_KERNEL));
+		if (err) {
+			dereg_mr(dev, mr);
+			return ERR_PTR(err);
+		}
+
+		err = mlx5_ib_init_odp_mr(mr, xlt_with_umr);
+		if (err) {
+			dereg_mr(dev, mr);
+			return ERR_PTR(err);
+		}
 	}
-	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING))
-		smp_store_release(&mr->live, 1);
 
 	return &mr->ibmr;
 error:
@@ -1348,22 +1462,29 @@
 	return ERR_PTR(err);
 }
 
-static int unreg_umr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
+/**
+ * mlx5_mr_cache_invalidate - Fence all DMA on the MR
+ * @mr: The MR to fence
+ *
+ * Upon return the NIC will not be doing any DMA to the pages under the MR,
+ * and any DMA inprogress will be completed. Failure of this function
+ * indicates the HW has failed catastrophically.
+ */
+int mlx5_mr_cache_invalidate(struct mlx5_ib_mr *mr)
 {
-	struct mlx5_core_dev *mdev = dev->mdev;
 	struct mlx5_umr_wr umrwr = {};
 
-	if (mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
+	if (mr->dev->mdev->state == MLX5_DEVICE_STATE_INTERNAL_ERROR)
 		return 0;
 
 	umrwr.wr.send_flags = MLX5_IB_SEND_UMR_DISABLE_MR |
 			      MLX5_IB_SEND_UMR_UPDATE_PD_ACCESS;
 	umrwr.wr.opcode = MLX5_IB_WR_UMR;
-	umrwr.pd = dev->umrc.pd;
+	umrwr.pd = mr->dev->umrc.pd;
 	umrwr.mkey = mr->mmkey.key;
 	umrwr.ignore_free_state = 1;
 
-	return mlx5_ib_post_send_wait(dev, &umrwr);
+	return mlx5_ib_post_send_wait(mr->dev, &umrwr);
 }
 
 static int rereg_umr(struct ib_pd *pd, struct mlx5_ib_mr *mr,
@@ -1410,8 +1531,6 @@
 	mlx5_ib_dbg(dev, "start 0x%llx, virt_addr 0x%llx, length 0x%llx, access_flags 0x%x\n",
 		    start, virt_addr, length, access_flags);
 
-	atomic_sub(mr->npages, &dev->mdev->priv.reg_pages);
-
 	if (!mr->umem)
 		return -EINVAL;
 
@@ -1432,24 +1551,30 @@
 		 * used.
 		 */
 		flags |= IB_MR_REREG_TRANS;
+		atomic_sub(mr->npages, &dev->mdev->priv.reg_pages);
+		mr->npages = 0;
 		ib_umem_release(mr->umem);
 		mr->umem = NULL;
-		err = mr_umem_get(dev, udata, addr, len, access_flags,
-				  &mr->umem, &npages, &page_shift, &ncont,
-				  &order);
+
+		err = mr_umem_get(dev, addr, len, access_flags, &mr->umem,
+				  &npages, &page_shift, &ncont, &order);
 		if (err)
 			goto err;
+		mr->npages = ncont;
+		atomic_add(mr->npages, &dev->mdev->priv.reg_pages);
 	}
 
-	if (!mlx5_ib_can_use_umr(dev, true) ||
-	    (flags & IB_MR_REREG_TRANS && !use_umr_mtt_update(mr, addr, len))) {
+	if (!mlx5_ib_can_reconfig_with_umr(dev, mr->access_flags,
+					   access_flags) ||
+	    !mlx5_ib_can_load_pas_with_umr(dev, len) ||
+	    (flags & IB_MR_REREG_TRANS &&
+	     !mlx5_ib_pas_fits_in_mr(mr, addr, len))) {
 		/*
 		 * UMR can't be used - MKey needs to be replaced.
 		 */
-		if (mr->allocated_from_cache)
-			err = unreg_umr(dev, mr);
-		else
-			err = destroy_mkey(dev, mr);
+		if (mr->cache_ent)
+			detach_mr_from_cache(mr);
+		err = destroy_mkey(dev, mr);
 		if (err)
 			goto err;
 
@@ -1461,8 +1586,6 @@
 			mr = to_mmr(ib_mr);
 			goto err;
 		}
-
-		mr->allocated_from_cache = 0;
 	} else {
 		/*
 		 * Send a UMR WQE
@@ -1489,7 +1612,7 @@
 			goto err;
 	}
 
-	set_mr_fields(dev, mr, npages, len, access_flags);
+	set_mr_fields(dev, mr, len, access_flags);
 
 	return 0;
 
@@ -1549,8 +1672,6 @@
 
 static void clean_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr)
 {
-	int allocated_from_cache = mr->allocated_from_cache;
-
 	if (mr->sig) {
 		if (mlx5_core_destroy_psv(dev->mdev,
 					  mr->sig->psv_memory.psv_idx))
@@ -1560,11 +1681,12 @@
 					  mr->sig->psv_wire.psv_idx))
 			mlx5_ib_warn(dev, "failed to destroy wire psv %d\n",
 				     mr->sig->psv_wire.psv_idx);
+		xa_erase(&dev->sig_mrs, mlx5_base_mkey(mr->mmkey.key));
 		kfree(mr->sig);
 		mr->sig = NULL;
 	}
 
-	if (!allocated_from_cache) {
+	if (!mr->cache_ent) {
 		destroy_mkey(dev, mr);
 		mlx5_free_priv_descs(mr);
 	}
@@ -1575,54 +1697,20 @@
 	int npages = mr->npages;
 	struct ib_umem *umem = mr->umem;
 
-	if (is_odp_mr(mr)) {
-		struct ib_umem_odp *umem_odp = to_ib_umem_odp(umem);
+	/* Stop all DMA */
+	if (is_odp_mr(mr))
+		mlx5_ib_fence_odp_mr(mr);
+	else
+		clean_mr(dev, mr);
 
-		/* Prevent new page faults and
-		 * prefetch requests from succeeding
-		 */
-		WRITE_ONCE(mr->live, 0);
-
-		/* Wait for all running page-fault handlers to finish. */
-		synchronize_srcu(&dev->mr_srcu);
-
-		/* dequeue pending prefetch requests for the mr */
-		if (atomic_read(&mr->num_pending_prefetch))
-			flush_workqueue(system_unbound_wq);
-		WARN_ON(atomic_read(&mr->num_pending_prefetch));
-
-		/* Destroy all page mappings */
-		if (!umem_odp->is_implicit_odp)
-			mlx5_ib_invalidate_range(umem_odp,
-						 ib_umem_start(umem_odp),
-						 ib_umem_end(umem_odp));
-		else
-			mlx5_ib_free_implicit_mr(mr);
-		/*
-		 * We kill the umem before the MR for ODP,
-		 * so that there will not be any invalidations in
-		 * flight, looking at the *mr struct.
-		 */
-		ib_umem_odp_release(umem_odp);
-		atomic_sub(npages, &dev->mdev->priv.reg_pages);
-
-		/* Avoid double-freeing the umem. */
-		umem = NULL;
-	}
-
-	clean_mr(dev, mr);
-
-	/*
-	 * We should unregister the DMA address from the HCA before
-	 * remove the DMA mapping.
-	 */
-	mlx5_mr_cache_free(dev, mr);
-	ib_umem_release(umem);
-	if (umem)
-		atomic_sub(npages, &dev->mdev->priv.reg_pages);
-
-	if (!mr->allocated_from_cache)
+	if (mr->cache_ent)
+		mlx5_mr_cache_free(dev, mr);
+	else
 		kfree(mr);
+
+	ib_umem_release(umem);
+	atomic_sub(npages, &dev->mdev->priv.reg_pages);
+
 }
 
 int mlx5_ib_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
@@ -1634,6 +1722,11 @@
 		dereg_mr(to_mdev(mmr->klm_mr->ibmr.device), mmr->klm_mr);
 	}
 
+	if (is_odp_mr(mmr) && to_ib_umem_odp(mmr->umem)->is_implicit_odp) {
+		mlx5_ib_free_implicit_mr(mmr);
+		return 0;
+	}
+
 	dereg_mr(to_mdev(ibmr->device), mmr);
 
 	return 0;
@@ -1646,9 +1739,9 @@
 
 	mkc = MLX5_ADDR_OF(create_mkey_in, in, memory_key_mkey_entry);
 
+	/* This is only used from the kernel, so setting the PD is OK. */
+	set_mkc_access_pd_addr_fields(mkc, 0, 0, pd);
 	MLX5_SET(mkc, mkc, free, 1);
-	MLX5_SET(mkc, mkc, qpn, 0xffffff);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
 	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
 	MLX5_SET(mkc, mkc, access_mode_1_0, access_mode & 0x3);
 	MLX5_SET(mkc, mkc, access_mode_4_2, (access_mode >> 2) & 0x7);
@@ -1673,7 +1766,7 @@
 
 	mlx5_set_umr_free_mkey(pd, in, ndescs, access_mode, page_shift);
 
-	err = mlx5_core_create_mkey(dev->mdev, &mr->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mr->mmkey, in, inlen);
 	if (err)
 		goto err_free_descs;
 
@@ -1797,8 +1890,15 @@
 	if (err)
 		goto err_free_mtt_mr;
 
+	err = xa_err(xa_store(&dev->sig_mrs, mlx5_base_mkey(mr->mmkey.key),
+			      mr->sig, GFP_KERNEL));
+	if (err)
+		goto err_free_descs;
 	return 0;
 
+err_free_descs:
+	destroy_mkey(dev, mr);
+	mlx5_free_priv_descs(mr);
 err_free_mtt_mr:
 	dereg_mr(to_mdev(mr->mtt_mr->ibmr.device), mr->mtt_mr);
 	mr->mtt_mr = NULL;
@@ -1873,7 +1973,7 @@
 }
 
 struct ib_mr *mlx5_ib_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
-			       u32 max_num_sg, struct ib_udata *udata)
+			       u32 max_num_sg)
 {
 	return __mlx5_ib_alloc_mr(pd, mr_type, max_num_sg, 0);
 }
@@ -1885,12 +1985,11 @@
 				  max_num_meta_sg);
 }
 
-struct ib_mw *mlx5_ib_alloc_mw(struct ib_pd *pd, enum ib_mw_type type,
-			       struct ib_udata *udata)
+int mlx5_ib_alloc_mw(struct ib_mw *ibmw, struct ib_udata *udata)
 {
-	struct mlx5_ib_dev *dev = to_mdev(pd->device);
+	struct mlx5_ib_dev *dev = to_mdev(ibmw->device);
 	int inlen = MLX5_ST_SZ_BYTES(create_mkey_in);
-	struct mlx5_ib_mw *mw = NULL;
+	struct mlx5_ib_mw *mw = to_mmw(ibmw);
 	u32 *in = NULL;
 	void *mkc;
 	int ndescs;
@@ -1903,21 +2002,20 @@
 
 	err = ib_copy_from_udata(&req, udata, min(udata->inlen, sizeof(req)));
 	if (err)
-		return ERR_PTR(err);
+		return err;
 
 	if (req.comp_mask || req.reserved1 || req.reserved2)
-		return ERR_PTR(-EOPNOTSUPP);
+		return -EOPNOTSUPP;
 
 	if (udata->inlen > sizeof(req) &&
 	    !ib_is_udata_cleared(udata, sizeof(req),
 				 udata->inlen - sizeof(req)))
-		return ERR_PTR(-EOPNOTSUPP);
+		return -EOPNOTSUPP;
 
 	ndescs = req.num_klms ? roundup(req.num_klms, 4) : roundup(1, 4);
 
-	mw = kzalloc(sizeof(*mw), GFP_KERNEL);
 	in = kzalloc(inlen, GFP_KERNEL);
-	if (!mw || !in) {
+	if (!in) {
 		err = -ENOMEM;
 		goto free;
 	}
@@ -1926,61 +2024,62 @@
 
 	MLX5_SET(mkc, mkc, free, 1);
 	MLX5_SET(mkc, mkc, translations_octword_size, ndescs);
-	MLX5_SET(mkc, mkc, pd, to_mpd(pd)->pdn);
+	MLX5_SET(mkc, mkc, pd, to_mpd(ibmw->pd)->pdn);
 	MLX5_SET(mkc, mkc, umr_en, 1);
 	MLX5_SET(mkc, mkc, lr, 1);
 	MLX5_SET(mkc, mkc, access_mode_1_0, MLX5_MKC_ACCESS_MODE_KLMS);
-	MLX5_SET(mkc, mkc, en_rinval, !!((type == IB_MW_TYPE_2)));
+	MLX5_SET(mkc, mkc, en_rinval, !!((ibmw->type == IB_MW_TYPE_2)));
 	MLX5_SET(mkc, mkc, qpn, 0xffffff);
 
-	err = mlx5_core_create_mkey(dev->mdev, &mw->mmkey, in, inlen);
+	err = mlx5_ib_create_mkey(dev, &mw->mmkey, in, inlen);
 	if (err)
 		goto free;
 
 	mw->mmkey.type = MLX5_MKEY_MW;
-	mw->ibmw.rkey = mw->mmkey.key;
+	ibmw->rkey = mw->mmkey.key;
 	mw->ndescs = ndescs;
 
-	resp.response_length = min(offsetof(typeof(resp), response_length) +
-				   sizeof(resp.response_length), udata->outlen);
+	resp.response_length =
+		min(offsetofend(typeof(resp), response_length), udata->outlen);
 	if (resp.response_length) {
 		err = ib_copy_to_udata(udata, &resp, resp.response_length);
-		if (err) {
-			mlx5_core_destroy_mkey(dev->mdev, &mw->mmkey);
-			goto free;
-		}
+		if (err)
+			goto free_mkey;
+	}
+
+	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING)) {
+		err = xa_err(xa_store(&dev->odp_mkeys,
+				      mlx5_base_mkey(mw->mmkey.key), &mw->mmkey,
+				      GFP_KERNEL));
+		if (err)
+			goto free_mkey;
 	}
 
 	kfree(in);
-	return &mw->ibmw;
+	return 0;
 
+free_mkey:
+	mlx5_core_destroy_mkey(dev->mdev, &mw->mmkey);
 free:
-	kfree(mw);
 	kfree(in);
-	return ERR_PTR(err);
+	return err;
 }
 
 int mlx5_ib_dealloc_mw(struct ib_mw *mw)
 {
 	struct mlx5_ib_dev *dev = to_mdev(mw->device);
 	struct mlx5_ib_mw *mmw = to_mmw(mw);
-	int err;
 
 	if (IS_ENABLED(CONFIG_INFINIBAND_ON_DEMAND_PAGING)) {
-		xa_erase_irq(&dev->mdev->priv.mkey_table,
-			     mlx5_base_mkey(mmw->mmkey.key));
+		xa_erase(&dev->odp_mkeys, mlx5_base_mkey(mmw->mmkey.key));
 		/*
 		 * pagefault_single_data_segment() may be accessing mmw under
 		 * SRCU if the user bound an ODP MR to this MW.
 		 */
-		synchronize_srcu(&dev->mr_srcu);
+		synchronize_srcu(&dev->odp_srcu);
 	}
 
-	err = mlx5_core_destroy_mkey(dev->mdev, &mmw->mmkey);
-	if (err)
-		return err;
-	kfree(mmw);
-	return 0;
+	return mlx5_core_destroy_mkey(dev->mdev, &mmw->mmkey);
 }
 
 int mlx5_ib_check_mr_status(struct ib_mr *ibmr, u32 check_mask,