/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <dirent.h>
#include <sys/mount.h>
#include <sys/stat.h>
#include <sys/utsname.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <modprobe/modprobe.h>
#include <private/android_filesystem_config.h>

#include "virt_test.h"

// Prefix prepended to the progress/error messages that main() and the
// CHECKCALL macro write to stderr.
static constexpr const char LOG_TAG[] = "guest: ";

// Kernel command-line key whose value names the test to run, i.e. the
// `virt_test_name=<name>` token looked up in /proc/cmdline.
static constexpr const char CMDLINE_TEST_NAME_PARAM[] = "virt_test_name";

// Root directory searched for kernel modules. Kept as a macro (not a
// constexpr array) because it is concatenated with other string literals,
// e.g. MODULE_BASE_DIR "/".
#define MODULE_BASE_DIR "/lib/modules"

// Chooses which module list file to use inside |dir_path|.
//
// Returns the bare file name (not a full path): "modules.load.recovery" when
// |recovery| is set and |dir_path|/modules.load.recovery can be stat()ed,
// otherwise the default "modules.load".
std::string GetModuleLoadList(bool recovery, const std::string& dir_path) {
    if (!recovery) {
        return "modules.load";
    }

    struct stat info;
    const std::string candidate = dir_path + "/modules.load.recovery";
    return stat(candidate.c_str(), &info) == 0 ? "modules.load.recovery" : "modules.load";
}

// Loads the kernel modules listed for the running kernel.
//
// Scans MODULE_BASE_DIR for subdirectories whose name starts with the running
// kernel's "<major>.<minor>" (from uname()), e.g. "5.4" or "5.4-gki", and tries
// them in sorted order. The first directory from which at least one module is
// loaded decides the result. If none of them loads anything, MODULE_BASE_DIR
// itself is tried as a fallback.
//
// @param recovery     prefer modules.load.recovery where present
//                     (see GetModuleLoadList()).
// @param want_console forwarded, negated, to Modprobe::LoadListedModules().
// @return EXIT_SUCCESS when modules were loaded without error or there was
//         nothing to load (including when MODULE_BASE_DIR cannot be opened);
//         EXIT_FAILURE when the kernel version cannot be determined or module
//         loading reported an error.
int LoadKernelModules(bool recovery, bool want_console) {
    struct utsname uts;
    if (uname(&uts)) {
        std::cerr << "Failed to get kernel version." << std::endl;
        return EXIT_FAILURE;
    }
    int major = 0;
    int minor = 0;
    if (sscanf(uts.release, "%d.%d", &major, &minor) != 2) {
        std::cerr << "Failed to parse kernel version " << uts.release << std::endl;
        return EXIT_FAILURE;
    }

    std::unique_ptr<DIR, decltype(&closedir)> base_dir(opendir(MODULE_BASE_DIR), closedir);
    if (!base_dir) {
        // Matches the message: a missing module directory is not an error, just
        // like the "nothing loaded" case at the end of this function.
        std::cerr << "Unable to open /lib/modules, skipping module loading." << std::endl;
        return EXIT_SUCCESS;
    }
    std::vector<std::string> module_dirs;
    for (dirent* entry = readdir(base_dir.get()); entry != nullptr;
         entry = readdir(base_dir.get())) {
        if (entry->d_type != DT_DIR) {
            continue;
        }
        // sscanf only parses a prefix, so "5.4-gki" matches kernel 5.4 too.
        int dir_major = 0;
        int dir_minor = 0;
        if (sscanf(entry->d_name, "%d.%d", &dir_major, &dir_minor) != 2 || dir_major != major ||
            dir_minor != minor) {
            continue;
        }
        module_dirs.emplace_back(entry->d_name);
    }

    // Sort the directories so they are iterated over during module loading
    // in a consistent order. Alphabetical sorting is fine here because the
    // kernel version at the beginning of the directory name must match the
    // current kernel version, so the sort only applies to a label that
    // follows the kernel version, for example /lib/modules/5.4 vs.
    // /lib/modules/5.4-gki.
    std::sort(module_dirs.begin(), module_dirs.end());

    for (const auto& module_dir : module_dirs) {
        std::string dir_path = MODULE_BASE_DIR "/";
        dir_path.append(module_dir);
        Modprobe m({dir_path}, GetModuleLoadList(recovery, dir_path));
        bool retval = m.LoadListedModules(!want_console);
        int modules_loaded = m.GetModuleCount();
        if (modules_loaded > 0) {
            return retval ? EXIT_SUCCESS : EXIT_FAILURE;
        }
    }

    // Fallback: modules placed directly in MODULE_BASE_DIR.
    Modprobe m({MODULE_BASE_DIR}, GetModuleLoadList(recovery, MODULE_BASE_DIR));
    bool retval = m.LoadListedModules(!want_console);
    int modules_loaded = m.GetModuleCount();
    if (modules_loaded > 0) {
        return retval ? EXIT_SUCCESS : EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}

// Extracts the test name from the kernel command line.
//
// Reads the first line of /proc/cmdline, splits it on spaces and looks for a
// `virt_test_name=<value>` token (CMDLINE_TEST_NAME_PARAM). Tokens without
// '=' are ignored; the first matching token wins.
//
// @param out receives <value> on success; untouched on failure.
// @return EXIT_SUCCESS if the parameter was found, EXIT_FAILURE if
//         /proc/cmdline cannot be opened, is empty, or lacks the parameter.
int ReadCommandLineTestName(std::string *out) {
    std::ifstream ifs("/proc/cmdline");
    if (!ifs) {
        std::cerr << "ERROR: Could not read command line params" << std::endl;
        return EXIT_FAILURE;
    }

    // eof() is never set before the first read, so emptiness has to be
    // detected from the outcome of the read itself.
    std::string cmdline;
    if (!std::getline(ifs, cmdline, '\n') || cmdline.empty()) {
        std::cerr << "ERROR: Command line params empty" << std::endl;
        return EXIT_FAILURE;
    }

    size_t start = 0;
    while (start < cmdline.length()) {
        size_t end = cmdline.find(' ', start);
        if (end == std::string::npos) {
            end = cmdline.length();
        }

        const std::string param = cmdline.substr(start, end - start);
        start = end + 1;

        const size_t equals = param.find('=');
        if (equals == std::string::npos) {
            continue;
        }

        if (param.substr(0, equals) == CMDLINE_TEST_NAME_PARAM) {
            out->assign(param.substr(equals + 1));
            return EXIT_SUCCESS;
        }
    }

    std::cerr << "ERROR: Command line params do not contain test name" << std::endl;
    return EXIT_FAILURE;
}

// Evaluates |x| exactly once. A zero result lets execution continue; any other
// result logs "<LOG_TAG>ERROR: <x as written>" to stderr and makes the
// enclosing function return EXIT_FAILURE. The do { } while (0) wrapper makes
// the macro behave as a single statement (safe inside unbraced if/else).
#define CHECKCALL(x)                                                    \
    do {                                                                \
        if ((x) == 0) {                                                 \
            break;                                                      \
        }                                                               \
        std::cerr << LOG_TAG << "ERROR: " << #x << std::endl;           \
        return EXIT_FAILURE;                                            \
    } while (0)

// Guest-side entry point.
//
// Steps: clear the environment, load kernel modules, ensure /proc exists and
// mount procfs on it, read the `virt_test_name=` kernel command-line
// parameter, then run the test registered under that name. Tests are
// registered via the VIRT_TEST(FN) X-macro expanded over "virt_test_list.h";
// each FN must return 0 on success (it is checked with CHECKCALL).
//
// Every path returns EXIT_FAILURE, including a passing test.
// NOTE(review): presumably the host learns the result from the stderr log,
// not from this exit status — confirm before relying on the return value.
int main() {
    std::string test_name;

    CHECKCALL(clearenv());

    std::cerr << LOG_TAG << "Loading kernel modules..." << std::endl;
    CHECKCALL(LoadKernelModules(false, false));

    std::cerr << LOG_TAG << "Mounting filesystems..." << std::endl;
    // An already-present /proc (e.g. shipped in the rootfs image) is fine;
    // only a genuine mkdir failure is fatal.
    if (mkdir("/proc", 0777) != 0 && errno != EEXIST) {
        std::cerr << LOG_TAG << "ERROR: " << "mkdir(\"/proc\", 0777)" << std::endl;
        return EXIT_FAILURE;
    }
    CHECKCALL(mount("proc", "/proc", "proc", 0, nullptr));

    std::cerr << LOG_TAG << "Reading command line params..." << std::endl;
    CHECKCALL(ReadCommandLineTestName(&test_name));

    // Every matching entry returns from main(), so falling out of this
    // expansion means no registered test has the requested name.
#define VIRT_TEST(FN)                                                       \
    if (test_name == #FN) {                                                 \
        std::cerr << LOG_TAG << "Invoking " << #FN << "..." << std::endl;   \
        CHECKCALL(FN());                                                    \
        return EXIT_FAILURE;                                                \
    }
#include "virt_test_list.h"
#undef VIRT_TEST

    std::cerr << LOG_TAG << "ERROR: Unknown test name: " << test_name << std::endl;
    std::cerr << LOG_TAG << "Exiting VM..." << std::endl;
    return EXIT_FAILURE;
}
