/*
 *****************************************************************************
 * Copyright (C) 2017, Cisco Systems
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 *****************************************************************************
 *
 *  File:    utils.h
 *  Author:  Koushik Chakravarty <kouchakr@cisco.com>
 *
 ****************************************************************************
 *
 *  This file contains the utility apis
 *
 **************************************************************************
 */

#ifndef _UTILS_H_
#define _UTILS_H_

#include "defines.h"
#ifdef NVM_BPF_USERSPACE
#include <string.h>
#include <sys/time.h>
#include <unistd.h>

#include <iterator>
#include <new>
#include <unordered_map>

/*
 * Definition of struct task_struct for the userspace build; it is compiled
 * only when NVM_BPF_USERSPACE is defined and carries just the three fields
 * below.
 */
struct task_struct {
	/* NOTE(review): presumably the thread-group (process) id - confirm */
	pid_t tgid;
	/* NOTE(review): presumably the command name; size 16 looks like the kernel's TASK_COMM_LEN - confirm */
	char comm[16];
	/* NOTE(review): presumably the process start time; units are not visible in this header */
	uint64_t start_time;
};
#else
#include <linux/module.h>
#include <linux/version.h>
#include <linux/hashtable.h>
#endif

/*
 * Declared here, defined in another translation unit.
 * NOTE(review): presumably the fallback name used when a task name cannot be
 * resolved - confirm against the definition.
 */
extern const char *default_name;

/*
 * Task methods
 *
 * Only the prototypes are visible in this header. Notes marked NOTE(review)
 * are inferred from the names and signatures and must be confirmed against
 * the implementations.
 */
struct task_struct;

/* NOTE(review): presumably returns the task of the calling context. */
struct task_struct *get_curr_task(void);
/* NOTE(review): presumably returns the pid associated with `task`. */
pid_t get_pid_of_task(struct task_struct *task);
/* NOTE(review): presumably looks up the task for `uPid`; whether NULL is
 * returned on failure is not visible here. */
struct task_struct *get_task_from_pid(pid_t uPid);
/* NOTE(review): presumably returns the parent task of `task`. */
struct task_struct *get_parent(struct task_struct *task);
/* Declared only when NVM_BPF_USERSPACE is not defined (kernel build). */
#ifndef NVM_BPF_USERSPACE
uint16_t get_exepath_from_curr_task(struct task_struct *task, char *path_buffer,
				    uint16_t buffer_size);
#endif
/* Takes a task, an output buffer and that buffer's size.
 * NOTE(review): the uint16_t result is presumably the length written to
 * `path_buffer` - confirm. */
uint16_t get_exepath_from_task(struct task_struct *task, char *path_buffer,
			       uint16_t buffer_size);
/* Result is delivered through the `time` out-parameter.
 * NOTE(review): units/epoch are not visible here. */
void get_process_creation_time(struct task_struct *task, uint64_t *time);
/* NOTE(review): presumably drops a reference taken by one of the lookup
 * functions above - confirm which ones require it. */
void unref_task(struct task_struct *task);

/* API to get the current time.
 * NOTE(review): the name suggests seconds since the Unix epoch, returned as a
 * 32-bit value - confirm against the implementation. */
uint32_t get_unix_systime(void);
/* Get the task name with the path into `name_buffer` (capacity
 * `buffer_size`).
 * NOTE(review): the uint16_t result is presumably the length written. */
uint16_t get_taskname(struct task_struct *task, char *name_buffer,
		      uint16_t buffer_size);

/*
 * Socket send method: takes a socket, a destination address and a byte
 * buffer of `buff_len` bytes, and reports the outcome as an error_code.
 * NOTE(review): error_code and struct sockaddr_in are not declared in this
 * header; presumably they come from "defines.h" or the platform headers.
 */
struct socket;
error_code socket_sendto(struct socket *local, struct sockaddr_in *dest,
			   const uint8_t *pBuffer, size_t buff_len);

/*
 * Hash list related macros
 */

#ifdef NVM_BPF_USERSPACE
// Userspace hash list node: pairs a payload pointer with the key it is filed
// under and a back-pointer to the map that holds it.
struct hash_list {
    void *data;      // payload pointer handed in at construction
    uint32_t key;    // key associated with this node
    // Map that contains this node, or nullptr when none was recorded.
    std::unordered_multimap<uint32_t, hash_list*>* parent_map;

    // Build a node for `payload` under `hash_key`; `owner` is optional and
    // defaults to nullptr.
    hash_list(void* payload, uint32_t hash_key,
              std::unordered_multimap<uint32_t, hash_list*>* owner = nullptr)
        : data(payload), key(hash_key), parent_map(owner) {}

    // Releases nothing: the node frees neither `data` nor `parent_map`.
    ~hash_list() {}
};

// Expands to nothing in the userspace build: the iteration macros below
// declare their own iterators, so no extra cursor variable is needed.
#define HLIST_ITER 

// Visit every node stored under `key_expr` in `table`, assigning each one to
// `node_ptr` in turn.
//
// The successor iterator is captured BEFORE the loop body runs, so the body
// may erase the current entry (e.g. via hlist_del) without the loop advancing
// an invalidated iterator. Erasing any other entry of the same key from
// inside the body is still not supported.
//
// Iteration stops early if a stored node pointer is null. `member` is unused
// here; it is accepted so the call shape matches the kernel variant.
//
// _cur.first is the current position, _cur.second the saved successor;
// _range.second is the end of the matching range.
#define for_each_hlist_match(table, node_ptr, member, key_expr) \
    for (auto _range = (table).equal_range(key_expr), _cur = _range; \
         _cur.first != _range.second && \
         (_cur.second = std::next(_cur.first), true) && \
         ((node_ptr) = _cur.first->second) != nullptr; \
         _cur.first = _cur.second)

// Visit every node in `table`, assigning each one to `node_ptr` in turn.
//
// The successor iterator (_next) is saved before the loop body runs, so the
// body may erase the current entry without breaking the traversal. Iteration
// stops early if a stored node pointer is null.
//
// `bkt`, `tmp` and `member` are unused here; they are accepted so the call
// shape matches the kernel variant. `table` and `node_ptr` are parenthesized
// so arbitrary expressions (e.g. `*table_ptr`) can be passed.
#define for_each_hlist(table, bkt, tmp, node_ptr, member) \
    for (auto _it = (table).begin(), _next = _it; \
         _it != (table).end() && (_next = std::next(_it), true) && \
         ((node_ptr) = _it->second) != nullptr; \
         _it = _next)

// Add `entry` to `hashtable` under `entry_key`.
//
// Allocates a hash_list node (recording `hashtable` as its parent map) and
// inserts it. If the allocation fails the macro silently inserts nothing.
//
// `entry_key` is evaluated exactly once, and the locals use _hl_-prefixed
// names so that a caller argument spelled `node` or `key` is not captured by
// the expansion.
#define hlist_add(hashtable, entry, entry_key) \
    do { \
        const uint32_t _hl_key = (entry_key); \
        hash_list *_hl_node = \
            new (std::nothrow) hash_list((entry), _hl_key, &(hashtable)); \
        if (_hl_node != nullptr) { \
            (hashtable).emplace(_hl_key, _hl_node); \
        } \
    } while (0)


// Remove `node` from the map recorded in its parent_map and free it.
//
// `node` must be an lvalue of type hash_list*; it is bound once by reference
// so the argument expression is evaluated a single time and may be any
// lvalue expression (e.g. `*node_ptr_ptr`).
//
// Behaviour:
//  - null node, or node without a parent map: nothing happens and the
//    pointer is left untouched;
//  - otherwise the first map entry under node->key that points at this node
//    is erased and the node is deleted; the caller's pointer is then set to
//    nullptr (also when no matching entry was found, in which case the node
//    is not deleted).
#define hlist_del(node) \
    do { \
        auto &_hl_victim = (node); \
        if (_hl_victim != nullptr && _hl_victim->parent_map != nullptr) { \
            auto _hl_span = \
                _hl_victim->parent_map->equal_range(_hl_victim->key); \
            for (auto _hl_pos = _hl_span.first; \
                 _hl_pos != _hl_span.second; ++_hl_pos) { \
                if (_hl_pos->second == _hl_victim) { \
                    _hl_victim->parent_map->erase(_hl_pos); \
                    delete _hl_victim; \
                    break; /* only the first matching entry is removed */ \
                } \
            } \
            _hl_victim = nullptr; \
        } \
    } while (0)

// Reset `table` to empty.
//
// Despite the name, this is destructive on a populated table: every node
// pointer stored in it is deleted and the map is then cleared. On an empty
// table it is a no-op. `table` is parenthesized so any expression yielding
// the map (e.g. `*table_ptr`) can be passed.
#define hash_init(table) \
    do { \
        for (auto& _hl_entry : (table)) { \
            delete _hl_entry.second; \
        } \
        (table).clear(); \
    } while (0)

// Declare a hash table variable called `name`: an unordered_multimap from
// uint32_t keys to hash_list node pointers, constructed with a bucket-count
// argument of 1 << bits.
// NOTE(review): `1 << (bits)` is evaluated as int, so `bits` must stay below
// the width of int.
#define DECLARE_HASHTABLE(name, bits) std::unordered_multimap<uint32_t, hash_list*> name = std::unordered_multimap<uint32_t, hash_list*>(1 << (bits))

#else
/*
 *   \brief wrapper to kernel hashtable
 *
 *   Node type that the hlist_add()/hlist_del() macros in this header
 *   allocate, link into the kernel hashtable and free.
 */

/* custom hash list node */
struct hash_list {
	/* linkage handed to hash_add()/hash_del() */
	struct hlist_node next;
	/* entry pointer supplied by the caller of hlist_add() */
	void *data;
	/* key the node was added under */
	uint32_t key;
};

/*
 * True when x equals 4. The argument is parenthesized so that expressions
 * containing lower-precedence operators (e.g. `hdr & 0x0f`) are compared as
 * a whole rather than being split by `==`.
 */
#define IS_IP_VERSION4(x) (4 == (x))
/*
 * 0x0100007f: the byte sequence 7f 00 00 01 (127.0.0.1) when this value is
 * stored on a little-endian host.
 * NOTE(review): assumes a little-endian host - confirm for other targets.
 */
#define LOOPBACK_ADDR_V4_BE  (0x0100007fU)

/*
 * Kernel hashtable iteration wrappers.
 *
 * Before Linux 3.9 the hash_for_each_* macros took an extra
 * struct hlist_node cursor argument; from 3.9 on that argument was dropped.
 * The test is therefore "older than 3.9.0" so that every 3.8.y stable
 * release (whose version code is greater than KERNEL_VERSION(3, 8, 0)) still
 * selects the old-API branch.
 */
#if (LINUX_VERSION_CODE < KERNEL_VERSION(3, 9, 0))

/* Old API: callers place HLIST_ITER in their declarations to get the extra
 * cursor named `temp` that the two wrappers below pass through. */
#define HLIST_ITER struct hlist_node *temp

/* Walk every node; the wrapped kernel macro is the deletion-safe variant. */
#define for_each_hlist(table, bkt, tmp, node, member) \
	hash_for_each_safe(table, bkt, temp, tmp, node, member)

/* Walk the nodes in the bucket that `key` hashes to. */
#define for_each_hlist_match(table, node, member, key) \
	hash_for_each_possible(table, node, temp, member, key)

#else				/*LINUX VERSION */

/* New API: no extra cursor is needed, so HLIST_ITER expands to nothing. */
#define HLIST_ITER

/* Walk every node; the wrapped kernel macro is the deletion-safe variant. */
#define for_each_hlist(table, bkt, tmp, node, member) \
	hash_for_each_safe(table, bkt, tmp, node, member)

/* Walk the nodes in the bucket that `key` hashes to. */
#define for_each_hlist_match(table, node, member, key) \
	hash_for_each_possible(table, node, member, key)

#endif				/*LINUX VERSION */

/*
 * Allocate a hash_list node for `entry`, zero it, record `entry_key` and add
 * it to `hashtable`.
 *
 * WARNING: on allocation failure this macro executes a bare `return;`, so it
 * leaves the CALLING function; it can only be used inside functions
 * returning void.
 *
 * The expansion stays a plain brace block (not do/while) so existing call
 * sites keep compiling unchanged. The local uses an _hl_ prefix so a caller
 * argument cannot be captured by it, and the value arguments are
 * parenthesized.
 */
#define hlist_add(hashtable, entry, entry_key) \
	{ \
	struct hash_list *_hl_node = KMALLOC(sizeof(*_hl_node)); \
	if (NULL == _hl_node) \
		return; \
	memset(_hl_node, 0, sizeof(*_hl_node)); \
	_hl_node->data = (entry); \
	_hl_node->key = (entry_key); \
	hash_add(hashtable, &_hl_node->next, _hl_node->key); \
	}

/*
 * Unlink node `a` from the kernel hashtable, free it with KFREE and set the
 * caller's pointer to NULL. `a` must be a non-NULL, assignable
 * struct hash_list pointer; it is parenthesized so any lvalue expression
 * (e.g. `*node_pp`) can be passed.
 *
 * The expansion stays a plain brace block so existing call sites keep
 * compiling unchanged.
 */
#define hlist_del(a) \
	{ \
	hash_del(&(a)->next); \
	KFREE(a); \
	(a) = NULL; \
	}
#endif

#endif
