Implemented dispatching tree modifiers
[project/luci.git] / libs / web / luasrc / dispatcher.lua
index b74c5bd..b48b584 100644 (file)
@@ -25,12 +25,13 @@ limitations under the License.
 ]]--
 
 --- LuCI web dispatcher.
-module("luci.dispatcher", package.seeall)
-require("luci.init")
-require("luci.http")
-require("luci.sys")
-require("luci.fs")
+local fs = require "luci.fs"
+local sys = require "luci.sys"
+local init = require "luci.init"
+local util = require "luci.util"
+local http = require "luci.http"
 
+module("luci.dispatcher", package.seeall)
 context = luci.util.threadlocal()
 
 authenticator = {}
@@ -78,20 +79,20 @@ function error500(message)
        return false
 end
 
-function authenticator.htmlauth(validator, default)
+function authenticator.htmlauth(validator, accs, default)
        local user = luci.http.formvalue("username")
        local pass = luci.http.formvalue("password")
-       
+
        if user and validator(user, pass) then
                return user
        end
-       
+
        require("luci.i18n")
        require("luci.template")
        context.path = {}
        luci.template.render("sysauth", {duser=default, fuser=user})
        return false
-       
+
 end
 
 --- Dispatch an HTTP request.
@@ -105,36 +106,49 @@ function httpdispatch(request)
                table.insert(context.request, node)
        end
 
-       dispatch(context.request)
+       local stat, err = util.copcall(dispatch, context.request)
+       if not stat then
+               luci.util.perror(err)
+               error500(err)
+       end
+
        luci.http.close()
+
+       --context._disable_memtrace()
 end
 
 --- Dispatches a LuCI virtual path.
 -- @param request      Virtual path
 function dispatch(request)
-       context.path = request
-       
-       require("luci.i18n")
-       luci.i18n.setlanguage(require("luci.config").main.lang)
-       
-       if not context.tree then
-               createtree()
+       --context._disable_memtrace = require "luci.debug".trap_memtrace()
+       local ctx = context
+       ctx.path = request
+
+       require "luci.i18n".setlanguage(require "luci.config".main.lang)
+
+       local c = ctx.tree
+       local stat
+       if not c then
+               c = createtree()
        end
-       
-       local c = context.tree
+
        local track = {}
        local args = {}
+       ctx.args = args
+       ctx.requestargs = ctx.requestargs or args
        local n
 
        for i, s in ipairs(request) do
                c = c.nodes[s]
                n = i
-               if not c or c.leaf then
+               if not c then
                        break
                end
 
-               for k, v in pairs(c) do
-                       track[k] = v
+               util.update(track, c)
+
+               if c.leaf then
+                       break
                end
        end
 
@@ -147,42 +161,63 @@ function dispatch(request)
        if track.i18n then
                require("luci.i18n").loadc(track.i18n)
        end
-       
+
        -- Init template engine
-       local tpl = require("luci.template")
-       local viewns = {}
-       tpl.context.viewns = viewns
-       viewns.write       = luci.http.write
-       viewns.translate   = function(...) return require("luci.i18n").translate(...) end
-       viewns.controller  = luci.http.getenv("SCRIPT_NAME")
-       viewns.media       = luci.config.main.mediaurlbase
-       viewns.resource    = luci.config.main.resourcebase
-       viewns.REQUEST_URI = luci.http.getenv("SCRIPT_NAME") .. (luci.http.getenv("PATH_INFO") or "")
-       
-       if track.dependent then
-               local stat, err = pcall(assert, not track.auto)
-               if not stat then
-                       error500(err)
-                       return
+       if (c and c.index) or not track.notemplate then
+               local tpl = require("luci.template")
+               local media = track.mediaurlbase or luci.config.main.mediaurlbase
+               if not pcall(tpl.Template, "themes/%s/header" % fs.basename(media)) then
+                       media = nil
+                       for name, theme in pairs(luci.config.themes) do
+                               if name:sub(1,1) ~= "." and pcall(tpl.Template,
+                                "themes/%s/header" % fs.basename(theme)) then
+                                       media = theme
+                               end
+                       end
+                       assert(media, "No valid theme found")
                end
+
+               local viewns = setmetatable({}, {__index=_G})
+               tpl.context.viewns = viewns
+               viewns.write       = luci.http.write
+               viewns.include     = function(name) tpl.Template(name):render(getfenv(2)) end
+               viewns.translate   = function(...) return require("luci.i18n").translate(...) end
+               viewns.striptags   = util.striptags
+               viewns.controller  = luci.http.getenv("SCRIPT_NAME")
+               viewns.media       = media
+               viewns.theme       = fs.basename(media)
+               viewns.resource    = luci.config.main.resourcebase
+               viewns.REQUEST_URI = (luci.http.getenv("SCRIPT_NAME") or "") .. (luci.http.getenv("PATH_INFO") or "")
        end
-       
+
+       track.dependent = (track.dependent ~= false)
+       assert(not track.dependent or not track.auto, "Access Violation")
+
        if track.sysauth then
-               require("luci.sauth")
-               local authen = authenticator[track.sysauth_authenticator]
+               local sauth = require "luci.sauth"
+
+               local authen = type(track.sysauth_authenticator) == "function"
+                and track.sysauth_authenticator
+                or authenticator[track.sysauth_authenticator]
+
                local def  = (type(track.sysauth) == "string") and track.sysauth
                local accs = def and {track.sysauth} or track.sysauth
-               local user = luci.sauth.read(luci.http.getcookie("sysauth"))
-               
-               if not luci.util.contains(accs, user) then
+               local sess = ctx.authsession or luci.http.getcookie("sysauth")
+               sess = sess and sess:match("^[A-F0-9]+$")
+               local user = sauth.read(sess)
+
+               if not util.contains(accs, user) then
                        if authen then
-                               local user = authen(luci.sys.user.checkpasswd, def)
-                               if not user or not luci.util.contains(accs, user) then
+                               local user, sess = authen(luci.sys.user.checkpasswd, accs, def)
+                               if not user or not util.contains(accs, user) then
                                        return
                                else
-                                       local sid = luci.sys.uniqueid(16)
+                                       local sid = sess or luci.sys.uniqueid(16)
                                        luci.http.header("Set-Cookie", "sysauth=" .. sid.."; path=/")
-                                       luci.sauth.write(sid, user)
+                                       if not sess then
+                                               sauth.write(sid, user)
+                                       end
+                                       ctx.authsession = sid
                                end
                        else
                                luci.http.status(403, "Forbidden")
@@ -199,17 +234,33 @@ function dispatch(request)
                luci.sys.process.setuser(track.setuser)
        end
 
-       if c and type(c.target) == "function" then
-               context.dispatched = c
-               stat, mod = luci.util.copcall(require, c.module)
-               if stat then
-                       luci.util.updfenv(c.target, mod)
-               end
-               
-               stat, err = luci.util.copcall(c.target, unpack(args))
-               if not stat then
-                       error500(err)
+       if c and (c.index or type(c.target) == "function") then
+               ctx.dispatched = c
+               ctx.requested = ctx.requested or ctx.dispatched
+       end
+
+       if c and c.index then
+               local tpl = require "luci.template"
+
+               if util.copcall(tpl.render, "indexer", {}) then
+                       return true
                end
+       end
+
+       if c and type(c.target) == "function" then
+               util.copcall(function()
+                       local oldenv = getfenv(c.target)
+                       local module = require(c.module)
+                       local env = setmetatable({}, {__index=
+
+                       function(tbl, key)
+                               return rawget(tbl, key) or module[key] or oldenv[key]
+                       end})
+
+                       setfenv(c.target, env)
+               end)
+
+               c.target(unpack(args))
        else
                error404()
        end
@@ -219,7 +270,7 @@ end
 function createindex()
        local path = luci.util.libpath() .. "/controller/"
        local suff = ".lua"
-       
+
        if luci.util.copcall(require, "luci.fastindex") then
                createindex_fastindex(path, suff)
        else
@@ -232,14 +283,14 @@ end
 -- @param suffix       Controller file suffix
 function createindex_fastindex(path, suffix)
        index = {}
-               
+
        if not fi then
                fi = luci.fastindex.new("index")
                fi.add(path .. "*" .. suffix)
                fi.add(path .. "*/*" .. suffix)
        end
        fi.scan()
-       
+
        for k, v in pairs(fi.indexes) do
                index[v[2]] = v[1]
        end
@@ -249,51 +300,49 @@ end
 -- @param path         Controller base directory
 -- @param suffix       Controller file suffix
 function createindex_plain(path, suffix)
-       index = {}
-
-       local cache = nil 
-       
-       local controllers = luci.util.combine(
+       local controllers = util.combine(
                luci.fs.glob(path .. "*" .. suffix) or {},
                luci.fs.glob(path .. "*/*" .. suffix) or {}
        )
-       
+
        if indexcache then
-               cache = luci.fs.mtime(indexcache)
-               
-               if not cache then
-                       luci.fs.mkdir(indexcache)
-                       luci.fs.chmod(indexcache, "a=,u=rwx")
-                       cache = luci.fs.mtime(indexcache)
+               local cachedate = fs.mtime(indexcache)
+               if cachedate then
+                       local realdate = 0
+                       for _, obj in ipairs(controllers) do
+                               local omtime = fs.mtime(path .. "/" .. obj)
+                               realdate = (omtime and omtime > realdate) and omtime or realdate
+                       end
+
+                       if cachedate > realdate then
+                               assert(
+                                       sys.process.info("uid") == fs.stat(indexcache, "uid")
+                                       and fs.stat(indexcache, "mode") == "rw-------",
+                                       "Fatal: Indexcache is not sane!"
+                               )
+
+                               index = loadfile(indexcache)()
+                               return index
+                       end
                end
        end
 
+       index = {}
+
        for i,c in ipairs(controllers) do
                local module = "luci.controller." .. c:sub(#path+1, #c-#suffix):gsub("/", ".")
-               local cachefile
-               local stime
-               local ctime
-               
-               if cache then
-                       cachefile = indexcache .. "/" .. module
-                       stime = luci.fs.mtime(c) or 0
-                       ctime = luci.fs.mtime(cachefile) or 0
-               end
-               
-               if not cache or stime > ctime then 
-                       stat, mod = luci.util.copcall(require, module)
-       
-                       if stat and mod and type(mod.index) == "function" then
-                               index[module] = mod.index
-                               
-                               if cache then
-                                       luci.fs.writefile(cachefile, luci.util.get_bytecode(mod.index))
-                               end
-                       end
-               else
-                       index[module] = loadfile(cachefile)
+               local mod = require(module)
+               local idx = mod.index
+
+               if type(idx) == "function" then
+                       index[module] = idx
                end
        end
+
+       if indexcache then
+               fs.writefile(indexcache, util.get_bytecode(index))
+               fs.chmod(indexcache, "a-rwx,u+rw")
+       end
 end
 
 --- Create the dispatching tree from the index.
@@ -302,31 +351,48 @@ function createtree()
        if not index then
                createindex()
        end
-       
-       context.tree = {nodes={}}
-       require("luci.i18n")
-               
+
+       local ctx  = context
+       local tree = {nodes={}}
+       local modi = {}
+
+       ctx.treecache = setmetatable({}, {__mode="v"})
+       ctx.tree = tree
+       ctx.modifiers = modi
+
        -- Load default translation
-       luci.i18n.loadc("default")
-       
-       local scope = luci.util.clone(_G)
-       for k,v in pairs(luci.dispatcher) do
-               if type(v) == "function" then
-                       scope[k] = v
-               end
-       end
+       require "luci.i18n".loadc("default")
+
+       local scope = setmetatable({}, {__index = luci.dispatcher})
 
        for k, v in pairs(index) do
                scope._NAME = k
                setfenv(v, scope)
+               v()
+       end
 
-               local stat, err = luci.util.copcall(v)
-               if not stat then
-                       error500("createtree failed: " .. k .. ": " .. err)
-                       luci.http.close()
-                       os.exit(1)
-               end
+       local function modisort(a,b)
+               return modi[a].order < modi[b].order
+       end
+
+       for _, v in util.spairs(modi, modisort) do
+               scope._NAME = v.module
+               setfenv(v.func, scope)
+               v.func()
        end
+
+       return tree
+end
+
+--- Register a tree modifier.
+-- @param      func    Modifier function
+-- @param      order   Modifier order value (optional)
+function modifier(func, order)
+       context.modifiers[#context.modifiers+1] = {
+               func = func,
+               order = order or 0,
+               module = getfenv(2)._NAME
+       }
 end
 
 --- Clone a node of the dispatching tree to another position.
@@ -339,33 +405,24 @@ function assign(path, clone, title, order)
        local obj  = node(unpack(path))
        obj.nodes  = nil
        obj.module = nil
-       
+
        obj.title = title
        obj.order = order
-       
-       local c = context.tree
-       for k, v in ipairs(clone) do
-               if not c.nodes[v] then
-                       c.nodes[v] = {nodes={}}
-               end
 
-               c = c.nodes[v]
-       end
-       
-       setmetatable(obj, {__index = c})
-       
+       setmetatable(obj, {__index = _create_node(clone)})
+
        return obj
 end
 
 --- Create a new dispatching node and define common parameters.
 -- @param      path    Virtual path
--- @param      target  Target function to call when dispatched. 
+-- @param      target  Target function to call when dispatched.
 -- @param      title   Destination node title
 -- @param      order   Destination node order value (optional)
 -- @return                     Dispatching tree node
 function entry(path, target, title, order)
        local c = node(unpack(path))
-       
+
        c.target = target
        c.title  = title
        c.order  = order
@@ -378,16 +435,7 @@ end
 -- @param      ...             Virtual path
 -- @return                     Dispatching tree node
 function node(...)
-       local c = context.tree
-       arg.n = nil
-
-       for k,v in ipairs(arg) do
-               if not c.nodes[v] then
-                       c.nodes[v] = {nodes={}, auto=true}
-               end
-
-               c = c.nodes[v]
-       end
+       local c = _create_node({...})
 
        c.module = getfenv(2)._NAME
        c.path = arg
@@ -396,13 +444,40 @@ function node(...)
        return c
 end
 
+function _create_node(path, cache)
+       if #path == 0 then
+               return context.tree
+       end
+
+       cache = cache or context.treecache
+       local name = table.concat(path, ".")
+       local c = cache[name]
+
+       if not c then
+               local last = table.remove(path)
+               c = _create_node(path, cache)
+
+               local new = {nodes={}, auto=true}
+               c.nodes[last] = new
+               cache[name] = new
+
+               return new
+       else
+               return c
+       end
+end
+
 -- Subdispatchers --
 
 --- Create a redirect to another dispatching node.
 -- @param      ...             Virtual path destination
 function alias(...)
-       local req = arg
-       return function()
+       local req = {...}
+       return function(...)
+               for _, r in ipairs({...}) do
+                       req[#req+1] = r
+               end
+
                dispatch(req)
        end
 end
@@ -411,85 +486,126 @@ end
 -- @param      n               Number of path values to replace
 -- @param      ...             Virtual path to replace removed path values with
 function rewrite(n, ...)
-       local req = arg
-       return function()
-               for i=1,n do 
-                       table.remove(context.path, 1)
+       local req = {...}
+       return function(...)
+               local dispatched = util.clone(context.dispatched)
+
+               for i=1,n do
+                       table.remove(dispatched, 1)
+               end
+
+               for i, r in ipairs(req) do
+                       table.insert(dispatched, i, r)
                end
-               
-               for i,r in ipairs(req) do
-                       table.insert(context.path, i, r)
+
+               for _, r in ipairs({...}) do
+                       dispatched[#dispatched+1] = r
                end
-               
-               dispatch()
+
+               dispatch(dispatched)
        end
 end
 
 --- Create a function-call dispatching target.
--- @param      name    Target function of local controller 
+-- @param      name    Target function of local controller
 -- @param      ...             Additional parameters passed to the function
 function call(name, ...)
        local argv = {...}
-       return function() return getfenv()[name](unpack(argv)) end
+       return function(...)
+               if #argv > 0 then 
+                       return getfenv()[name](unpack(argv), ...)
+               else
+                       return getfenv()[name](...)
+               end
+       end
 end
 
 --- Create a template render dispatching target.
 -- @param      name    Template to be rendered
 function template(name)
-       require("luci.template")
-       return function() luci.template.render(name) end
+       return function()
+               require("luci.template")
+               luci.template.render(name)
+       end
 end
 
 --- Create a CBI model dispatching target.
--- @param      model   CBI model tpo be rendered
-function cbi(model)
-       require("luci.cbi")
-       require("luci.template")
-
+-- @param      model   CBI model to be rendered
+function cbi(model, config)
+       config = config or {}
        return function(...)
-               local stat, maps = luci.util.copcall(luci.cbi.load, model, ...)
-               if not stat then
-                       error500(maps)
-                       return true
-               end
+               require("luci.cbi")
+               require("luci.template")
+               local http = require "luci.http"
+
+               maps = luci.cbi.load(model, ...)
+
+               local state = nil
 
                for i, res in ipairs(maps) do
-                       local stat, err = luci.util.copcall(res.parse, res)
-                       if not stat then
-                               error500(err)
-                               return true
+                       if config.autoapply then
+                               res.autoapply = config.autoapply
+                       end
+                       local cstate = res:parse()
+                       if not state or cstate < state then
+                               state = cstate
                        end
                end
 
-               luci.template.render("cbi/header")
+               if config.on_valid_to and state and state > 0 and state < 2 then
+                       luci.http.redirect(config.on_valid_to)
+                       return
+               end
+
+               if config.on_changed_to and state and state > 1 then
+                       luci.http.redirect(config.on_changed_to)
+                       return
+               end
+
+               if config.on_success_to and state and state > 0 then
+                       luci.http.redirect(config.on_success_to)
+                       return
+               end
+
+               if config.state_handler then
+                       if not config.state_handler(state, maps) then
+                               return
+                       end
+               end
+
+               local pageaction = true
+               http.header("X-CBI-State", state or 0)
+               luci.template.render("cbi/header", {state = state})
                for i, res in ipairs(maps) do
                        res:render()
+                       if res.pageaction == false then
+                               pageaction = false
+                       end
                end
-               luci.template.render("cbi/footer")
+               luci.template.render("cbi/footer", {pageaction=pageaction, state = state, autoapply = config.autoapply})
        end
 end
 
 --- Create a CBI form model dispatching target.
 -- @param      model   CBI form model tpo be rendered
 function form(model)
-       require("luci.cbi")
-       require("luci.template")
-
        return function(...)
-               local stat, maps = luci.util.copcall(luci.cbi.load, model, ...)
-               if not stat then
-                       error500(maps)
-                       return true
-               end
+               require("luci.cbi")
+               require("luci.template")
+               local http = require "luci.http"
+
+               maps = luci.cbi.load(model, ...)
+
+               local state = nil
 
                for i, res in ipairs(maps) do
-                       local stat, err = luci.util.copcall(res.parse, res)
-                       if not stat then
-                               error500(err)
-                               return true
+                       local cstate = res:parse()
+                       if not state or cstate < state then
+                               state = cstate
                        end
                end
 
+               http.header("X-CBI-State", state or 0)
                luci.template.render("header")
                for i, res in ipairs(maps) do
                        res:render()