iov_iter: Turn iov_iter_fault_in_readable into fault_in_iov_iter_readable
Turn iov_iter_fault_in_readable into a function that returns the number
of bytes not faulted in, similar to copy_to_user, instead of returning a
non-zero value when any of the requested pages couldn't be faulted in.
This supports the existing users that require all pages to be faulted in
as well as new users that are happy if any pages can be faulted in.
Rename iov_iter_fault_in_readable to fault_in_iov_iter_readable so that
any caller not converted to the new return value semantics fails to build
instead of silently misbehaving.
Signed-off-by: Andreas Gruenbacher <agruenba@redhat.com>
diff --git a/lib/iov_iter.c b/lib/iov_iter.c
index c88908f..ce3d4f6 100644
--- a/lib/iov_iter.c
+++ b/lib/iov_iter.c
@@ -430,33 +430,42 @@ static size_t copy_page_to_iter_pipe(struct page *page, size_t offset, size_t by
}
/*
- * Fault in one or more iovecs of the given iov_iter, to a maximum length of
- * bytes. For each iovec, fault in each page that constitutes the iovec.
+ * fault_in_iov_iter_readable - fault in iov iterator for reading
+ * @i: iterator
+ * @size: maximum length
*
- * Return 0 on success, or non-zero if the memory could not be accessed (i.e.
- * because it is an invalid address).
+ * Fault in one or more iovecs of the given iov_iter, to a maximum length of
+ * @size. For each iovec, fault in each page that constitutes the iovec.
+ *
+ * Returns the number of bytes not faulted in (like copy_to_user() and
+ * copy_from_user()).
+ *
+ * Always returns 0 for non-userspace iterators.
*/
-int iov_iter_fault_in_readable(const struct iov_iter *i, size_t bytes)
+size_t fault_in_iov_iter_readable(const struct iov_iter *i, size_t size)
{
if (iter_is_iovec(i)) {
+ size_t count = min(size, iov_iter_count(i)); /* request clamped to what the iterator holds */
const struct iovec *p;
size_t skip;
- if (bytes > i->count)
- bytes = i->count;
- for (p = i->iov, skip = i->iov_offset; bytes; p++, skip = 0) {
- size_t len = min(bytes, p->iov_len - skip);
+ size -= count; /* size is now the excess beyond the iterator; it is never faulted in */
+ for (p = i->iov, skip = i->iov_offset; count; p++, skip = 0) {
+ size_t len = min(count, p->iov_len - skip); /* part of this iovec that is still wanted */
+ size_t ret;
if (unlikely(!len))
continue;
- if (fault_in_readable(p->iov_base + skip, len))
- return -EFAULT;
- bytes -= len;
+ ret = fault_in_readable(p->iov_base + skip, len); /* used below as "bytes not faulted in" - NOTE(review): confirm against fault_in_readable() */
+ count -= len - ret; /* deduct only the bytes that were faulted in */
+ if (ret)
+ break; /* stop at the first iovec that could not be fully faulted in */
}
+ return count + size; /* unfaulted remainder inside the iterator plus the excess beyond it */
}
return 0;
}
-EXPORT_SYMBOL(iov_iter_fault_in_readable);
+EXPORT_SYMBOL(fault_in_iov_iter_readable);
void iov_iter_init(struct iov_iter *i, unsigned int direction,
const struct iovec *iov, unsigned long nr_segs,