blob: e2b9c3e1a5a28d174d5f7712a02db6755f4f69eb [file] [log] [blame] [edit]
// SPDX-License-Identifier: GPL-2.0
/*
* Copyright (C) 2023 Google, Inc.
*/
#include <linux/cred.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/pid.h>
#include <linux/poll.h>
#include <linux/proc_fs.h>
#include <linux/rcupdate.h>
#include <linux/sched.h>
#include <linux/seq_file.h>
#include <linux/slab.h>
#include <linux/timekeeping.h>
#include <linux/tracepoint.h>
#include <linux/workqueue.h>
#include <trace/events/oom.h>
/* Name of the proc directory; module init passes it to proc_mkdir() with a NULL parent */
#define MEMHEALTH_DIRECTORY "memhealth"
/* Name of the file created inside MEMHEALTH_DIRECTORY that lists the recorded OOM victims */
#define OOM_VICTIM_LIST_ENTRY "oom_victim_list"
/* Wait queue that poll() on the victim list file registers on */
static wait_queue_head_t memhealth_wq;
/* Handle returned by proc_mkdir(); parent of the OOM_VICTIM_LIST_ENTRY file */
static struct proc_dir_entry *proc_memhealth_dir;
/*
 * List of oom victims, oldest at the head: new entries are spliced onto the
 * tail and evictions take the first entry. The proc file iterates over it.
 */
static struct list_head oom_victim_list;
/* Running total of victims ever moved into oom_victim_list; only ever increases */
static size_t oom_victim_count;
/*
 * Running total of victims evicted from oom_victim_list, so that
 * oom_victim_count - oom_victim_removed_count is the current list length.
 */
static size_t oom_victim_removed_count;
/*
 * Lock protecting oom_victim_list, oom_victim_count and
 * oom_victim_removed_count
 */
static DEFINE_MUTEX(memhealth_mutex);
/*
 * List of new oom victims not yet added into oom_victim_list. The tracepoint
 * probe appends here; the work item later splices the whole list over.
 */
static struct list_head new_oom_victims_list;
/* Number of entries currently queued on new_oom_victims_list */
static size_t new_oom_victims_count;
/* Lock protecting new_oom_victims_list and new_oom_victims_count */
static DEFINE_SPINLOCK(memhealth_spin_lock);
/*
 * One recorded OOM killer victim. Entries are allocated by the tracepoint
 * path, queued on new_oom_victims_list and later moved to oom_victim_list.
 */
struct oom_victim {
/* Pid exactly as delivered to the mark_victim probe */
pid_t pid;
/* Raw uid value taken from the victim task's credentials (cred->uid.val) */
uid_t uid;
/* Copy of the task's comm, NUL-padded to TASK_COMM_LEN */
char process_name[TASK_COMM_LEN];
/* ktime_get() reading taken when the probe fired */
ktime_t timestamp;
/* Copy of task->signal->oom_score_adj at the time of recording */
short oom_score_adj;
/* Linkage into new_oom_victims_list or oom_victim_list */
struct list_head list;
};
/*
 * Number of struct oom_victim objects that fit in one page. The eviction
 * loop drops the oldest entries while the list length is >= this value, so
 * the list settles at strictly fewer entries than this.
 */
#define OOM_VICTIM_LIST_MAX_SIZE ((PAGE_SIZE/sizeof(struct oom_victim)))
static void oom_list_move_victims(struct work_struct *work)
{
struct oom_victim *head;
size_t total_new_nodes;
mutex_lock(&memhealth_mutex);
spin_lock(&memhealth_spin_lock);
if (list_empty(&new_oom_victims_list)) {
spin_unlock(&memhealth_spin_lock);
mutex_unlock(&memhealth_mutex);
return;
}
total_new_nodes = new_oom_victims_count;
list_splice_tail_init(&new_oom_victims_list, &oom_victim_list);
new_oom_victims_count = 0;
spin_unlock(&memhealth_spin_lock);
oom_victim_count += total_new_nodes;
while (oom_victim_count - oom_victim_removed_count >=
OOM_VICTIM_LIST_MAX_SIZE) {
head = list_first_entry(&oom_victim_list, struct oom_victim, list);
list_del(&head->list);
kfree(head);
oom_victim_removed_count++;
}
mutex_unlock(&memhealth_mutex);
}
static DECLARE_WORK(memhealth_oom_work, oom_list_move_victims);
/*
 * Record the task identified by @pid as an OOM victim.
 *
 * Allocates a struct oom_victim, fills it from the task's credentials,
 * signal struct and comm, queues it on new_oom_victims_list and schedules
 * memhealth_oom_work to move it onto the list read by the proc file.
 *
 * @pid:       pid of the victim as delivered by the mark_victim tracepoint.
 * @timestamp: time at which the victim was marked.
 *
 * Return: 0 on success, -ENOMEM if the node cannot be allocated, -EINVAL if
 * the task, its credentials or its signal struct cannot be found, or the
 * negative error returned by strscpy_pad().
 */
static int add_oom_victim_to_list(pid_t pid, ktime_t timestamp)
{
	struct oom_victim *new_node;
	struct task_struct *task;
	const struct cred *cred;
	struct pid *pid_struct;
	int ret = -EINVAL;

	/*
	 * Function could be called while a spinlock is being held, we want to
	 * prevent blocking while allocating for the new node
	 */
	new_node = kmalloc(sizeof(*new_node), GFP_ATOMIC);
	if (!new_node) {
		pr_err("memhealth failed to create new oom node for pid %d\n", pid);
		return -ENOMEM;
	}

	/*
	 * NOTE(review): the OOM killer is understood to report the victim's
	 * global pid (tsk->pid), while this probe runs in the context of
	 * whichever task triggered the OOM. Resolve the number in the initial
	 * pid namespace rather than in current's namespace, otherwise the
	 * lookup fails or finds an unrelated task when current lives in a
	 * child pid namespace. Confirm against the target kernel's
	 * mark_victim tracepoint.
	 *
	 * find_pid_ns() takes no reference, so the struct pid is only valid
	 * inside the RCU read-side section.
	 */
	rcu_read_lock();
	pid_struct = find_pid_ns(pid, &init_pid_ns);
	if (!pid_struct) {
		rcu_read_unlock();
		pr_err("memhealth failed to find pid %d\n", pid);
		goto err_free_node;
	}
	task = get_pid_task(pid_struct, PIDTYPE_PID);
	rcu_read_unlock();
	if (!task) {
		pr_err("memhealth failed to find task with pid %d\n", pid);
		goto err_free_node;
	}

	cred = get_task_cred(task);
	if (!cred) {
		pr_err("memhealth failed to find credentials\n");
		goto err_put_task;
	}
	if (!task->signal) {
		pr_err("memhealth failed to find signal in task\n");
		goto err_put_cred;
	}

	new_node->oom_score_adj = task->signal->oom_score_adj;
	ret = strscpy_pad(new_node->process_name, task->comm, TASK_COMM_LEN);
	if (ret < 0) {
		pr_err("memhealth failed to copy process name to new oom victim node\n");
		goto err_put_cred;
	}
	new_node->pid = pid;
	new_node->timestamp = timestamp;
	new_node->uid = cred->uid.val;

	/* Everything needed has been copied; drop the references. */
	put_cred(cred);
	put_task_struct(task);

	spin_lock(&memhealth_spin_lock);
	/*
	 * We add to `new_oom_victims_list` so the file read does not
	 * block the caller of mark victim
	 */
	list_add_tail(&new_node->list, &new_oom_victims_list);
	new_oom_victims_count++;
	spin_unlock(&memhealth_spin_lock);

	schedule_work(&memhealth_oom_work);
	return 0;

err_put_cred:
	put_cred(cred);
err_put_task:
	put_task_struct(task);
err_free_node:
	kfree(new_node);
	return ret;
}
/*
 * Probe attached to the mark_victim tracepoint.
 *
 * Stamps the event with the current ktime_get() value and hands it to
 * add_oom_victim_to_list(). On success the waiters on memhealth_wq are
 * woken; on failure an error is logged and nobody is woken.
 *
 * @data: private pointer supplied at registration time (unused).
 * @pid:  pid of the task that was marked as an OOM victim.
 */
static void mark_victim_probe(void *data, pid_t pid)
{
	ktime_t now = ktime_get();
	int err = add_oom_victim_to_list(pid, now);

	if (err >= 0)
		wake_up_interruptible(&memhealth_wq);
	else
		pr_err("memhealth failed to add pid(%d) as new OOM killer victim\n", pid);
}
/*
 * seq_file ->start callback.
 *
 * Takes memhealth_mutex, which stays held until the matching ->stop callback
 * releases it, and returns whatever seq_list_start() yields for position
 * *pos in oom_victim_list.
 */
static void *oom_victim_list_seq_start(struct seq_file *s, loff_t *pos)
{
	struct list_head *cursor;

	mutex_lock(&memhealth_mutex);
	cursor = seq_list_start(&oom_victim_list, *pos);

	return cursor;
}
/*
 * seq_file ->next callback: step from element @v of oom_victim_list to the
 * following one via seq_list_next(), which also updates *pos.
 */
static void *oom_victim_list_seq_next(struct seq_file *s, void *v, loff_t *pos)
{
	struct list_head *following;

	following = seq_list_next(v, &oom_victim_list, pos);

	return following;
}
/*
 * seq_file ->stop callback: release memhealth_mutex, which the ->start
 * callback acquired before walking oom_victim_list.
 */
static void oom_victim_list_seq_stop(struct seq_file *s, void *v)
{
mutex_unlock(&memhealth_mutex);
}
static int oom_victim_list_seq_show(struct seq_file *s, void *v)
{
struct oom_victim *entry = list_entry(v, struct oom_victim, list);
seq_printf(s, "%llu %d %d %i %s\n",
ktime_to_ms(entry->timestamp), entry->pid,
entry->uid, entry->oom_score_adj, entry->process_name
);
return 0;
}
/* Iterator callbacks used by seq_file to walk and print oom_victim_list */
static const struct seq_operations oom_victim_list_seq_ops = {
	.start	= oom_victim_list_seq_start,
	.next	= oom_victim_list_seq_next,
	.stop	= oom_victim_list_seq_stop,
	.show	= oom_victim_list_seq_show,
};
/*
 * Open handler for the victim list proc file.
 *
 * Only callers holding CAP_SYS_PTRACE may open the file. On a successful
 * seq_open() the per-open poll_event counter is reset to 0, so the first
 * poll() reports EPOLLPRI if oom_victim_count is already above zero.
 *
 * Return: 0 on success, -EPERM without CAP_SYS_PTRACE, or the error
 * returned by seq_open().
 */
static int oom_victim_list_seq_open(struct inode *inode, struct file *file)
{
	struct seq_file *m;
	int err;

	if (!capable(CAP_SYS_PTRACE))
		return -EPERM;

	err = seq_open(file, &oom_victim_list_seq_ops);
	if (err) {
		pr_err("memhealth failed opening OOM sequential file\n");
		return err;
	}

	/* The seq_file for this open is read back from file->private_data. */
	m = file->private_data;
	m->poll_event = 0;

	return 0;
}
/*
 * Poll handler for the victim list proc file.
 *
 * Registers the caller on memhealth_wq and always reports DEFAULT_POLLMASK.
 * EPOLLPRI is added when oom_victim_count has moved past the value this open
 * file last observed; that value is then stored in the seq_file's poll_event
 * so the same victims are not reported twice.
 */
static __poll_t oom_victim_list_poll(struct file *filp, poll_table *wait)
{
	struct seq_file *m = filp->private_data;
	__poll_t events = DEFAULT_POLLMASK;
	bool has_new;

	poll_wait(filp, &memhealth_wq, wait);

	/* oom_victim_count is protected by memhealth_mutex. */
	mutex_lock(&memhealth_mutex);
	has_new = m->poll_event < oom_victim_count;
	if (has_new)
		m->poll_event = oom_victim_count;
	mutex_unlock(&memhealth_mutex);

	if (has_new)
		events |= EPOLLPRI;

	return events;
}
/*
 * procfs operations for the victim list file: open and poll are handled by
 * this module, read/lseek/release are the generic seq_file helpers.
 */
static const struct proc_ops oom_victims_list_proc_ops = {
	.proc_open	= oom_victim_list_seq_open,
	.proc_read	= seq_read,
	.proc_lseek	= seq_lseek,
	.proc_poll	= oom_victim_list_poll,
	.proc_release	= seq_release,
};
/*
 * Module init: set up the victim bookkeeping, create
 * MEMHEALTH_DIRECTORY/OOM_VICTIM_LIST_ENTRY in procfs and attach
 * mark_victim_probe to the mark_victim tracepoint.
 *
 * Return: 0 on success, -ENOMEM if a proc entry cannot be created, or the
 * error returned by register_trace_mark_victim().
 */
static int __init memhealthmod_start(void)
{
	struct proc_dir_entry *entry;
	int ret;

	/*
	 * Initialise all state before the proc file exists: as soon as
	 * proc_create() returns, userspace can open, read and poll the file,
	 * and those paths walk oom_victim_list and wait on memhealth_wq.
	 */
	INIT_LIST_HEAD(&oom_victim_list);
	INIT_LIST_HEAD(&new_oom_victims_list);
	init_waitqueue_head(&memhealth_wq);
	oom_victim_count = 0;
	oom_victim_removed_count = 0;
	new_oom_victims_count = 0;

	proc_memhealth_dir = proc_mkdir(MEMHEALTH_DIRECTORY, NULL);
	if (!proc_memhealth_dir) {
		pr_err("memhealth failed to create directory (%s)\n", MEMHEALTH_DIRECTORY);
		return -ENOMEM;
	}

	entry = proc_create(OOM_VICTIM_LIST_ENTRY, 0444,
			    proc_memhealth_dir, &oom_victims_list_proc_ops);
	if (!entry) {
		pr_err("memhealth failed to create proc entry: %s\n",
		       OOM_VICTIM_LIST_ENTRY);
		ret = -ENOMEM;
		goto err_create_oom_entry;
	}

	/* Registered last so the probe never runs against missing state. */
	ret = register_trace_mark_victim(mark_victim_probe, NULL);
	if (ret) {
		pr_err("memhealth failed to hook a probe to the mark_victim tracepoint\n");
		goto err_register_victim;
	}
	return 0;

err_register_victim:
	remove_proc_entry(OOM_VICTIM_LIST_ENTRY, proc_memhealth_dir);
err_create_oom_entry:
	remove_proc_entry(MEMHEALTH_DIRECTORY, NULL);
	return ret;
}
/*
 * Module exit: undo memhealthmod_start() in an order that leaves no code or
 * data of this module reachable once it returns.
 *
 *  1. Detach the probe and wait for probe calls already in flight, so no new
 *     victims are queued and no new work is scheduled.
 *  2. Wait for a pending or running memhealth_oom_work to finish.
 *  3. Remove the proc entries so that the list can no longer be reached
 *     from userspace.
 *  4. Free every remaining victim, including those still waiting on
 *     new_oom_victims_list that the cancelled work never moved.
 */
static void __exit memhealthmod_end(void)
{
	struct oom_victim *entry, *tmp;

	if (unregister_trace_mark_victim(mark_victim_probe, NULL))
		pr_warn("memhealth failed to unhook a probe from the mark_victim tracepoint\n");
	/* A probe may still be executing on another CPU after unregistering. */
	tracepoint_synchronize_unregister();

	/* The work function lives in this module; it must not run after unload. */
	cancel_work_sync(&memhealth_oom_work);

	remove_proc_entry(OOM_VICTIM_LIST_ENTRY, proc_memhealth_dir);
	remove_proc_entry(MEMHEALTH_DIRECTORY, NULL);

	list_for_each_entry_safe(entry, tmp, &oom_victim_list, list) {
		list_del(&entry->list);
		kfree(entry);
	}
	list_for_each_entry_safe(entry, tmp, &new_oom_victims_list, list) {
		list_del(&entry->list);
		kfree(entry);
	}
}
/* Module entry/exit points and metadata */
module_init(memhealthmod_start);
module_exit(memhealthmod_end);
MODULE_DESCRIPTION("Records OOM killer victims and lists them in /proc/memhealth/oom_victim_list");
MODULE_LICENSE("GPL");