From 81d344efb76b69c3f1a1a15c4dffda2b785e3dfa Mon Sep 17 00:00:00 2001
From: mihalicyn <alexander@mihalicyn.com>
Date: Thu, 21 Mar 2019 23:57:18 +0300
Subject: [PATCH] Several changes: * possible memory leaks fixed * code clean
 up * now driver supports running several lua scripts in threads

---
 README.md |  61 +++-------------
 luadrv.c  | 204 ++++++++++++++++++++++++++++++++++--------------------
 2 files changed, 138 insertions(+), 127 deletions(-)

diff --git a/README.md b/README.md
index 98b50b2..4513753 100644
--- a/README.md
+++ b/README.md
@@ -1,63 +1,18 @@
 # poc-driver - Linux kernel driver for lunatik
 
-## Compiling
+## Compiling and running
 To build module you need to:
-1. Clone sources for
-Driver
-https://github.com/luainkernel/poc-driver
-Lua (kernel port)
-https://github.com/luainkernel/lunatik
+1. Clone sources
+git clone --recursive https://github.com/luainkernel/poc-driver
 2. Assume that you have kernel tree sources in /usr/src/linux
-3. create symlinks to lunatik and poc-driver in /usr/src/linux/drivers with corresponding names
+3. Then compile:
 ```
-ln -s /where_you_put_lunatik_src /usr/src/linux/drivers/lunatik
-ln -s /where_you_put_poc-driver_src /usr/src/linux/drivers/poc-driver
+make
 ```
-4. edit drivers/Kconfig to add following:
+4. Run (as root):
 ```
-source drivers/lunatik/Kconfig
-```
-5. lunatik/Kconfig contents:
-```
-config LUNATIK
-    tristate "Lunatik"
-    
-config LUNATIK_POC
-    bool "Use poc driver"
-    depends on LUNATIK
-    default y
-```
-6. edit drivers/Makefile to add:
-```
-obj-$(CONFIG_LUNATIK) += lunatik/
-```
-7. lunatik/Makefile example code:
-```
-EXTRA_CFLAGS += -D_KERNEL
-# for poc-driver:
-EXTRA_CFLAGS += -I$(src)
-
-obj-$(CONFIG_LUNATIK) += lunatik.o
-
-lunatik-objs := lua/lapi.o lua/lcode.o lua/lctype.o lua/ldebug.o lua/ldo.o \
-         lua/ldump.o lua/lfunc.o lua/lgc.o lua/llex.o lua/lmem.o \
-	 lua/lobject.o lua/lopcodes.o lua/lparser.o lua/lstate.o \
-         lua/lstring.o lua/ltable.o lua/ltm.o \
-	 lua/lundump.o lua/lvm.o lua/lzio.o lua/lauxlib.o lua/lbaselib.o \
-         lua/lbitlib.o lua/lcorolib.o lua/ldblib.o lua/lstrlib.o \
-	 lua/ltablib.o lua/lutf8lib.o lua/loslib.o lua/lmathlib.o lua/linit.o
-
-lunatik-objs += arch/$(ARCH)/setjmp.o
-
-lunatik-${CONFIG_LUNATIK_POC} += ../poc-driver/luadrv.o
-```
-8. Then:
-```
-cd /usr/src/linux
-#compile
-make modules -j4 ARCH=x86_64
-#load
-modprobe -v lunatik
+insmod ./dependencies/lunatik/lunatik.ko
+insmod poc-driver.ko
 ```
 
 ## Usage
diff --git a/luadrv.c b/luadrv.c
index 8f46909..22757bc 100644
--- a/luadrv.c
+++ b/luadrv.c
@@ -6,6 +6,7 @@
 #include <linux/device.h>
 #include <linux/cdev.h>
 #include <linux/uaccess.h>
+#include <linux/kthread.h>
 
 #include <lua.h>
 #include <lualib.h>
@@ -13,22 +14,38 @@
 
 MODULE_LICENSE("Dual MIT/GPL");
 MODULE_AUTHOR("Pedro Tammela <pctammela@gmail.com>");
-MODULE_DESCRIPTION("sample driver for lunatik proof of concepts");
+MODULE_AUTHOR("lunatik team (https://github.com/luainkernel)");
+MODULE_DESCRIPTION("Basic kernel module that provides /dev/luadrv character device to load and execute arbitrary lua code");
 
 #define DEVICE_NAME "luadrv"
 #define CLASS_NAME "lua"
+#define NSTATES 4
+// currently supported only one device
 #define LUA_MAX_MINORS  1
 
 #define raise_err(msg) pr_warn("[lua] %s - %s\n", __func__, msg);
+#define print_info(msg) pr_warn("[lua] %s\n", msg);
 
-static DEFINE_MUTEX(mtx);
+typedef struct device_data {
+	dev_t dev;
+	struct device *luadev;
+	struct class *luaclass;
+	struct cdev luacdev;
+	struct mutex lock;
+} device_data;
 
-static lua_State *L;
-static bool hasreturn = 0; /* does the lua state have anything for us? */
-static dev_t dev;
-static struct device *luadev;
-static struct class *luaclass;
-static struct cdev luacdev;
+static device_data devs[LUA_MAX_MINORS];
+
+typedef struct lua_exec {
+	int id;
+	lua_State *L;
+	int stacktop;
+	char *script;
+	struct task_struct *kthread;
+	struct mutex lock;
+} lua_exec;
+
+static lua_exec lua_states[NSTATES];
 
 static int dev_open(struct inode*, struct file*);
 static int dev_release(struct inode*, struct file*);
@@ -45,66 +62,85 @@ static struct file_operations fops =
 
 static int __init luadrv_init(void)
 {
-	int ret;
+	int ret, i, j;
+	device_data *dev = &devs[0];
 
-	ret = alloc_chrdev_region(&dev, 0, LUA_MAX_MINORS, "lua");
+	ret = alloc_chrdev_region(&dev->dev, 0, LUA_MAX_MINORS, "lua");
 	if (ret) {
 		raise_err("alloc_chrdev_region failed");
 		goto error;
 	}
 
-	cdev_init(&luacdev, &fops);
-	ret = cdev_add(&luacdev, dev, LUA_MAX_MINORS);
+	cdev_init(&dev->luacdev, &fops);
+	ret = cdev_add(&dev->luacdev, dev->dev, LUA_MAX_MINORS);
 	if (ret) {
 		raise_err("cdev_add failed");
 		goto error_free_region;
 	}
 
-	luaclass = class_create(THIS_MODULE, CLASS_NAME);
-	if (IS_ERR(luaclass)) {
+	dev->luaclass = class_create(THIS_MODULE, CLASS_NAME);
+	if (IS_ERR(dev->luaclass)) {
 		raise_err("class_create failed");
-		ret = PTR_ERR(luaclass);
+		ret = PTR_ERR(dev->luaclass);
 		goto error_free_cdev;
 	}
 
-	luadev = device_create(luaclass, NULL, dev,
+	dev->luadev = device_create(dev->luaclass, NULL, dev->dev,
 			NULL, "%s", DEVICE_NAME);
-	if (IS_ERR(luadev)) {
+	if (IS_ERR(dev->luadev)) {
 		raise_err("device_create failed");
-		ret = PTR_ERR(luadev);
+		ret = PTR_ERR(dev->luadev);
 		goto error_free_class;
 	}
 
-	L = luaL_newstate();
-	if (L == NULL) {
-		raise_err("no memory");
-		ret = -ENOMEM;
-		goto error_free_device;
+	mutex_init(&dev->lock);
+
+	for (i = 0; i < NSTATES; i++) {
+		lua_states[i].id = i;
+		lua_states[i].L = luaL_newstate();
+
+		if (lua_states[i].L == NULL) {
+			raise_err("no memory");
+			ret = -ENOMEM;
+
+			for (j = 0; j < i; j++) {
+				lua_close(lua_states[j].L);
+			}
+
+			goto error_free_device;
+		}
+
+		luaL_openlibs(lua_states[i].L);
+		mutex_init(&lua_states[i].lock);
 	}
-	luaL_openlibs(L);
 
 	return 0;
 
 error_free_device:
-	device_destroy(luaclass, dev);
+	device_destroy(dev->luaclass, dev->dev);
 error_free_class:
-	class_destroy(luaclass);
+	class_destroy(dev->luaclass);
 error_free_cdev:
-	cdev_del(&luacdev);
+	cdev_del(&dev->luacdev);
 error_free_region:
-	unregister_chrdev_region(dev, LUA_MAX_MINORS);
+	unregister_chrdev_region(dev->dev, LUA_MAX_MINORS);
 error:
 	return ret;
 }
 
 static void __exit luadrv_exit(void)
 {
-	lua_close(L);
+	int i;
+	device_data *dev = &devs[0];
+
+	for (i = 0; i < NSTATES; i++) {
+		lua_close(lua_states[i].L);
+	}
 
-	device_destroy(luaclass, dev);
-	class_destroy(luaclass);
-	cdev_del(&luacdev);
-	unregister_chrdev_region(dev, LUA_MAX_MINORS);
+	device_destroy(dev->luaclass, dev->dev);
+	class_destroy(dev->luaclass);
+	cdev_del(&dev->luacdev);
+	unregister_chrdev_region(dev->dev, LUA_MAX_MINORS);
 }
 
 static int dev_open(struct inode *i, struct file *f)
@@ -114,75 +150,95 @@ static int dev_open(struct inode *i, struct file *f)
 
 static ssize_t dev_read(struct file *f, char *buf, size_t len, loff_t *off)
 {
-	const char *msg = "Nothing yet.\n";
-	int msglen;
-	int err;
-	mutex_lock(&mtx);
-	if (hasreturn) {
-		msg = lua_tostring(L, -1);
-		hasreturn = false;
-	}
-	if ((err = copy_to_user(buf, msg, len)) < 0) {
-		raise_err("copy to user failed");
-		mutex_unlock(&mtx);
-		return -ECANCELED;
-	}
-	mutex_unlock(&mtx);
-	msglen = strlen(msg);
-	return msglen < len ? msglen : len;
+	return 0;
 }
 
-static int flushL(void)
+static lua_State* flush(lua_State *L)
 {
 	lua_close(L);
 	L = luaL_newstate();
 	if (L == NULL) {
 		raise_err("flushL failed, giving up");
-		mutex_unlock(&mtx);
-		return 1;
+		return NULL;
 	}
 	luaL_openlibs(L);
 	raise_err("lua state flushed");
-	return 0;
+	return L;
+}
+
+static int thread_fn(void *arg)
+{
+	int ret = 0;
+	lua_exec *lua = arg;
+	set_current_state(TASK_INTERRUPTIBLE);
+
+	printk("[lua] running thread %d\n", lua->id);
+	if (luaL_dostring(lua->L, lua->script)) {
+		raise_err("script error, flushing the state\n");
+		printk("%s\n", lua_tostring(lua->L, -1));
+		lua->L = flush(lua->L);
+		ret = -ECANCELED;
+	} else if (lua_gettop(lua->L) > lua->stacktop) {
+		printk("[lua] thread %d result: %s\n", lua->id, lua_tostring(lua->L, -1));
+	}
+
+	kfree(lua->script);
+	mutex_unlock(&lua->lock);
+
+	printk("[lua] thread %d finished\n", lua->id);
+	return ret;
 }
 
 static ssize_t dev_write(struct file *f, const char *buf, size_t len,
 		loff_t* off)
 {
+	device_data *dev = &devs[0];
+	int ret, i;
 	char *script = NULL;
-	int idx = lua_gettop(L);
-	int err;
-	mutex_lock(&mtx);
-	script = kmalloc(len, GFP_KERNEL);
+
+	mutex_lock(&dev->lock);
+
+	script = kmalloc(len + 1, GFP_KERNEL);
 	if (script == NULL) {
 		raise_err("no memory");
-		return -ENOMEM;
+		ret = -ENOMEM;
+		goto return_unlock;
 	}
-	if ((err = copy_from_user(script, buf, len)) < 0) {
+
+	if (copy_from_user(script, buf, len) < 0) {
 		raise_err("copy from user failed");
-		mutex_unlock(&mtx);
-		return -ECANCELED;
+		ret = -ECANCELED;
+		goto return_free;
 	}
-	script[len - 1] = '\0';
-	if (luaL_dostring(L, script)) {
-		raise_err(lua_tostring(L, -1));
-		if (flushL()) {
-			return -ECANCELED;
+	script[len] = '\0';
+
+	for (i = 0; i < NSTATES; i++) {
+		if (lua_states[i].L != NULL && mutex_trylock(&lua_states[i].lock)) {
+			lua_states[i].stacktop = lua_gettop(lua_states[i].L);
+			lua_states[i].script = script;
+			lua_states[i].kthread = kthread_run(thread_fn, &lua_states[i], "lua kthread %d", lua_states[i].id);
+			if(IS_ERR(lua_states[i].kthread)) {
+				ret = PTR_ERR(lua_states[i].kthread);
+				goto return_free;
+			}
+
+			ret = len;
+			goto return_unlock;
 		}
-		mutex_unlock(&mtx);
-		return -ECANCELED;
 	}
+
+	raise_err("all lua states are busy");
+	ret = -EBUSY;
+
+return_free:
 	kfree(script);
-	hasreturn = lua_gettop(L) > idx ? true : false;
-	mutex_unlock(&mtx);
-	return len;
+return_unlock:
+	mutex_unlock(&dev->lock);
+	return ret;
 }
 
 static int dev_release(struct inode *i, struct file *f)
 {
-	mutex_lock(&mtx);
-	hasreturn = false;
-	mutex_unlock(&mtx);
 	return 0;
 }