36bbaaf47f6e25659f039301b09920da039b0584
[project/luci.git] / modules / luci-base / luasrc / util.lua
1 -- Copyright 2008 Steven Barth <steven@midlink.org>
2 -- Licensed to the public under the Apache License 2.0.
3
4 local io = require "io"
5 local math = require "math"
6 local table = require "table"
7 local debug = require "debug"
8 local ldebug = require "luci.debug"
9 local string = require "string"
10 local coroutine = require "coroutine"
11 local tparser = require "luci.template.parser"
12 local json = require "luci.jsonc"
13 local lhttp = require "lucihttp"
14
15 local _ubus = require "ubus"
16 local _ubus_connection = nil
17
18 local getmetatable, setmetatable = getmetatable, setmetatable
19 local rawget, rawset, unpack = rawget, rawset, unpack
20 local tostring, type, assert, error = tostring, type, assert, error
21 local ipairs, pairs, next, loadstring = ipairs, pairs, next, loadstring
22 local require, pcall, xpcall = require, pcall, xpcall
23 local collectgarbage, get_memory_limit = collectgarbage, get_memory_limit
24
25 module "luci.util"
26
27 --
28 -- Pythonic string formatting extension
29 --
30 getmetatable("").__mod = function(a, b)
31         local ok, res
32
33         if not b then
34                 return a
35         elseif type(b) == "table" then
36                 local k, _
37                 for k, _ in pairs(b) do if type(b[k]) == "userdata" then b[k] = tostring(b[k]) end end
38
39                 ok, res = pcall(a.format, a, unpack(b))
40                 if not ok then
41                         error(res, 2)
42                 end
43                 return res
44         else
45                 if type(b) == "userdata" then b = tostring(b) end
46
47                 ok, res = pcall(a.format, a, b)
48                 if not ok then
49                         error(res, 2)
50                 end
51                 return res
52         end
53 end
54
55
56 --
57 -- Class helper routines
58 --
59
60 -- Instantiates a class
61 local function _instantiate(class, ...)
62         local inst = setmetatable({}, {__index = class})
63
64         if inst.__init__ then
65                 inst:__init__(...)
66         end
67
68         return inst
69 end
70
71 -- The class object can be instantiated by calling itself.
72 -- Any class functions or shared parameters can be attached to this object.
73 -- Attaching a table to the class object makes this table shared between
74 -- all instances of this class. For object parameters use the __init__ function.
75 -- Classes can inherit member functions and values from a base class.
76 -- Class can be instantiated by calling them. All parameters will be passed
77 -- to the __init__ function of this class - if such a function exists.
78 -- The __init__ function must be used to set any object parameters that are not shared
79 -- with other objects of this class. Any return values will be ignored.
80 function class(base)
81         return setmetatable({}, {
82                 __call  = _instantiate,
83                 __index = base
84         })
85 end
86
87 function instanceof(object, class)
88         local meta = getmetatable(object)
89         while meta and meta.__index do
90                 if meta.__index == class then
91                         return true
92                 end
93                 meta = getmetatable(meta.__index)
94         end
95         return false
96 end
97
98
99 --
100 -- Scope manipulation routines
101 --
102
103 local tl_meta = {
104         __mode = "k",
105
106         __index = function(self, key)
107                 local t = rawget(self, coxpt[coroutine.running()]
108                  or coroutine.running() or 0)
109                 return t and t[key]
110         end,
111
112         __newindex = function(self, key, value)
113                 local c = coxpt[coroutine.running()] or coroutine.running() or 0
114                 local r = rawget(self, c)
115                 if not r then
116                         rawset(self, c, { [key] = value })
117                 else
118                         r[key] = value
119                 end
120         end
121 }
122
123 -- the current active coroutine. A thread local store is private a table object
124 -- whose values can't be accessed from outside of the running coroutine.
125 function threadlocal(tbl)
126         return setmetatable(tbl or {}, tl_meta)
127 end
128
129
130 --
131 -- Debugging routines
132 --
133
134 function perror(obj)
135         return io.stderr:write(tostring(obj) .. "\n")
136 end
137
138 function dumptable(t, maxdepth, i, seen)
139         i = i or 0
140         seen = seen or setmetatable({}, {__mode="k"})
141
142         for k,v in pairs(t) do
143                 perror(string.rep("\t", i) .. tostring(k) .. "\t" .. tostring(v))
144                 if type(v) == "table" and (not maxdepth or i < maxdepth) then
145                         if not seen[v] then
146                                 seen[v] = true
147                                 dumptable(v, maxdepth, i+1, seen)
148                         else
149                                 perror(string.rep("\t", i) .. "*** RECURSION ***")
150                         end
151                 end
152         end
153 end
154
155
156 --
157 -- String and data manipulation routines
158 --
159
160 function pcdata(value)
161         return value and tparser.pcdata(tostring(value))
162 end
163
164 function urlencode(value)
165         if value ~= nil then
166                 local str = tostring(value)
167                 return lhttp.urlencode(str, lhttp.ENCODE_IF_NEEDED + lhttp.ENCODE_FULL)
168                         or str
169         end
170         return nil
171 end
172
173 function urldecode(value, decode_plus)
174         if value ~= nil then
175                 local flag = decode_plus and lhttp.DECODE_PLUS or 0
176                 local str = tostring(value)
177                 return lhttp.urldecode(str, lhttp.DECODE_IF_NEEDED + flag)
178                         or str
179         end
180         return nil
181 end
182
183 function striptags(value)
184         return value and tparser.striptags(tostring(value))
185 end
186
187 function shellquote(value)
188         return string.format("'%s'", string.gsub(value or "", "'", "'\\''"))
189 end
190
191 -- for bash, ash and similar shells single-quoted strings are taken
192 -- literally except for single quotes (which terminate the string)
193 -- (and the exception noted below for dash (-) at the start of a
194 -- command line parameter).
195 function shellsqescape(value)
196    local res
197    res, _ = string.gsub(value, "'", "'\\''")
198    return res
199 end
200
201 -- bash, ash and other similar shells interpret a dash (-) at the start
202 -- of a command-line parameters as an option indicator regardless of
203 -- whether it is inside a single-quoted string.  It must be backlash
204 -- escaped to resolve this.  This requires in some funky special-case
205 -- handling.  It may actually be a property of the getopt function
206 -- rather than the shell proper.
207 function shellstartsqescape(value)
208    res, _ = string.gsub(value, "^\-", "\\-")
209    res, _ = string.gsub(res, "^-", "\-")
210    return shellsqescape(value)
211 end
212
213 -- containing the resulting substrings. The optional max parameter specifies
214 -- the number of bytes to process, regardless of the actual length of the given
215 -- string. The optional last parameter, regex, specifies whether the separator
216 -- sequence is interpreted as regular expression.
217 --                                      pattern as regular expression (optional, default is false)
218 function split(str, pat, max, regex)
219         pat = pat or "\n"
220         max = max or #str
221
222         local t = {}
223         local c = 1
224
225         if #str == 0 then
226                 return {""}
227         end
228
229         if #pat == 0 then
230                 return nil
231         end
232
233         if max == 0 then
234                 return str
235         end
236
237         repeat
238                 local s, e = str:find(pat, c, not regex)
239                 max = max - 1
240                 if s and max < 0 then
241                         t[#t+1] = str:sub(c)
242                 else
243                         t[#t+1] = str:sub(c, s and s - 1)
244                 end
245                 c = e and e + 1 or #str + 1
246         until not s or max < 0
247
248         return t
249 end
250
251 function trim(str)
252         return (str:gsub("^%s*(.-)%s*$", "%1"))
253 end
254
255 function cmatch(str, pat)
256         local count = 0
257         for _ in str:gmatch(pat) do count = count + 1 end
258         return count
259 end
260
261 -- one token per invocation, the tokens are separated by whitespace. If the
262 -- input value is a table, it is transformed into a string first. A nil value
263 -- will result in a valid interator which aborts with the first invocation.
264 function imatch(v)
265         if type(v) == "table" then
266                 local k = nil
267                 return function()
268                         k = next(v, k)
269                         return v[k]
270                 end
271
272         elseif type(v) == "number" or type(v) == "boolean" then
273                 local x = true
274                 return function()
275                         if x then
276                                 x = false
277                                 return tostring(v)
278                         end
279                 end
280
281         elseif type(v) == "userdata" or type(v) == "string" then
282                 return tostring(v):gmatch("%S+")
283         end
284
285         return function() end
286 end
287
288 -- value or 0 if the unit is unknown. Upper- or lower case is irrelevant.
289 -- Recognized units are:
290 --      o "y"   - one year   (60*60*24*366)
291 --  o "m"       - one month  (60*60*24*31)
292 --  o "w"       - one week   (60*60*24*7)
293 --  o "d"       - one day    (60*60*24)
294 --  o "h"       - one hour       (60*60)
295 --  o "min"     - one minute (60)
296 --  o "kb"  - one kilobyte (1024)
297 --  o "mb"      - one megabyte (1024*1024)
298 --  o "gb"      - one gigabyte (1024*1024*1024)
299 --  o "kib" - one si kilobyte (1000)
300 --  o "mib"     - one si megabyte (1000*1000)
301 --  o "gib"     - one si gigabyte (1000*1000*1000)
302 function parse_units(ustr)
303
304         local val = 0
305
306         -- unit map
307         local map = {
308                 -- date stuff
309                 y   = 60 * 60 * 24 * 366,
310                 m   = 60 * 60 * 24 * 31,
311                 w   = 60 * 60 * 24 * 7,
312                 d   = 60 * 60 * 24,
313                 h   = 60 * 60,
314                 min = 60,
315
316                 -- storage sizes
317                 kb  = 1024,
318                 mb  = 1024 * 1024,
319                 gb  = 1024 * 1024 * 1024,
320
321                 -- storage sizes (si)
322                 kib = 1000,
323                 mib = 1000 * 1000,
324                 gib = 1000 * 1000 * 1000
325         }
326
327         -- parse input string
328         for spec in ustr:lower():gmatch("[0-9%.]+[a-zA-Z]*") do
329
330                 local num = spec:gsub("[^0-9%.]+$","")
331                 local spn = spec:gsub("^[0-9%.]+", "")
332
333                 if map[spn] or map[spn:sub(1,1)] then
334                         val = val + num * ( map[spn] or map[spn:sub(1,1)] )
335                 else
336                         val = val + num
337                 end
338         end
339
340
341         return val
342 end
343
344 -- also register functions above in the central string class for convenience
345 string.pcdata      = pcdata
346 string.striptags   = striptags
347 string.split       = split
348 string.trim        = trim
349 string.cmatch      = cmatch
350 string.parse_units = parse_units
351
352
353 function append(src, ...)
354         for i, a in ipairs({...}) do
355                 if type(a) == "table" then
356                         for j, v in ipairs(a) do
357                                 src[#src+1] = v
358                         end
359                 else
360                         src[#src+1] = a
361                 end
362         end
363         return src
364 end
365
366 function combine(...)
367         return append({}, ...)
368 end
369
370 function contains(table, value)
371         for k, v in pairs(table) do
372                 if value == v then
373                         return k
374                 end
375         end
376         return false
377 end
378
379 -- Both table are - in fact - merged together.
380 function update(t, updates)
381         for k, v in pairs(updates) do
382                 t[k] = v
383         end
384 end
385
386 function keys(t)
387         local keys = { }
388         if t then
389                 for k, _ in kspairs(t) do
390                         keys[#keys+1] = k
391                 end
392         end
393         return keys
394 end
395
396 function clone(object, deep)
397         local copy = {}
398
399         for k, v in pairs(object) do
400                 if deep and type(v) == "table" then
401                         v = clone(v, deep)
402                 end
403                 copy[k] = v
404         end
405
406         return setmetatable(copy, getmetatable(object))
407 end
408
409
410 function dtable()
411         return setmetatable({}, { __index =
412                 function(tbl, key)
413                         return rawget(tbl, key)
414                          or rawget(rawset(tbl, key, dtable()), key)
415                 end
416         })
417 end
418
419
420 -- Serialize the contents of a table value.
421 function _serialize_table(t, seen)
422         assert(not seen[t], "Recursion detected.")
423         seen[t] = true
424
425         local data  = ""
426         local idata = ""
427         local ilen  = 0
428
429         for k, v in pairs(t) do
430                 if type(k) ~= "number" or k < 1 or math.floor(k) ~= k or ( k - #t ) > 3 then
431                         k = serialize_data(k, seen)
432                         v = serialize_data(v, seen)
433                         data = data .. ( #data > 0 and ", " or "" ) ..
434                                 '[' .. k .. '] = ' .. v
435                 elseif k > ilen then
436                         ilen = k
437                 end
438         end
439
440         for i = 1, ilen do
441                 local v = serialize_data(t[i], seen)
442                 idata = idata .. ( #idata > 0 and ", " or "" ) .. v
443         end
444
445         return idata .. ( #data > 0 and #idata > 0 and ", " or "" ) .. data
446 end
447
448 -- with loadstring().
449 function serialize_data(val, seen)
450         seen = seen or setmetatable({}, {__mode="k"})
451
452         if val == nil then
453                 return "nil"
454         elseif type(val) == "number" then
455                 return val
456         elseif type(val) == "string" then
457                 return "%q" % val
458         elseif type(val) == "boolean" then
459                 return val and "true" or "false"
460         elseif type(val) == "function" then
461                 return "loadstring(%q)" % get_bytecode(val)
462         elseif type(val) == "table" then
463                 return "{ " .. _serialize_table(val, seen) .. " }"
464         else
465                 return '"[unhandled data type:' .. type(val) .. ']"'
466         end
467 end
468
469 function restore_data(str)
470         return loadstring("return " .. str)()
471 end
472
473
474 --
475 -- Byte code manipulation routines
476 --
477
478 -- will be stripped before it is returned.
479 function get_bytecode(val)
480         local code
481
482         if type(val) == "function" then
483                 code = string.dump(val)
484         else
485                 code = string.dump( loadstring( "return " .. serialize_data(val) ) )
486         end
487
488         return code -- and strip_bytecode(code)
489 end
490
491 -- numbers and debugging numbers will be discarded. Original version by
492 -- Peter Cawley (http://lua-users.org/lists/lua-l/2008-02/msg01158.html)
493 function strip_bytecode(code)
494         local version, format, endian, int, size, ins, num, lnum = code:byte(5, 12)
495         local subint
496         if endian == 1 then
497                 subint = function(code, i, l)
498                         local val = 0
499                         for n = l, 1, -1 do
500                                 val = val * 256 + code:byte(i + n - 1)
501                         end
502                         return val, i + l
503                 end
504         else
505                 subint = function(code, i, l)
506                         local val = 0
507                         for n = 1, l, 1 do
508                                 val = val * 256 + code:byte(i + n - 1)
509                         end
510                         return val, i + l
511                 end
512         end
513
514         local function strip_function(code)
515                 local count, offset = subint(code, 1, size)
516                 local stripped = { string.rep("\0", size) }
517                 local dirty = offset + count
518                 offset = offset + count + int * 2 + 4
519                 offset = offset + int + subint(code, offset, int) * ins
520                 count, offset = subint(code, offset, int)
521                 for n = 1, count do
522                         local t
523                         t, offset = subint(code, offset, 1)
524                         if t == 1 then
525                                 offset = offset + 1
526                         elseif t == 4 then
527                                 offset = offset + size + subint(code, offset, size)
528                         elseif t == 3 then
529                                 offset = offset + num
530                         elseif t == 254 or t == 9 then
531                                 offset = offset + lnum
532                         end
533                 end
534                 count, offset = subint(code, offset, int)
535                 stripped[#stripped+1] = code:sub(dirty, offset - 1)
536                 for n = 1, count do
537                         local proto, off = strip_function(code:sub(offset, -1))
538                         stripped[#stripped+1] = proto
539                         offset = offset + off - 1
540                 end
541                 offset = offset + subint(code, offset, int) * int + int
542                 count, offset = subint(code, offset, int)
543                 for n = 1, count do
544                         offset = offset + subint(code, offset, size) + size + int * 2
545                 end
546                 count, offset = subint(code, offset, int)
547                 for n = 1, count do
548                         offset = offset + subint(code, offset, size) + size
549                 end
550                 stripped[#stripped+1] = string.rep("\0", int * 3)
551                 return table.concat(stripped), offset
552         end
553
554         return code:sub(1,12) .. strip_function(code:sub(13,-1))
555 end
556
557
558 --
559 -- Sorting iterator functions
560 --
561
562 function _sortiter( t, f )
563         local keys = { }
564
565         local k, v
566         for k, v in pairs(t) do
567                 keys[#keys+1] = k
568         end
569
570         local _pos = 0
571
572         table.sort( keys, f )
573
574         return function()
575                 _pos = _pos + 1
576                 if _pos <= #keys then
577                         return keys[_pos], t[keys[_pos]], _pos
578                 end
579         end
580 end
581
582 -- the provided callback function.
583 function spairs(t,f)
584         return _sortiter( t, f )
585 end
586
587 -- The table pairs are sorted by key.
588 function kspairs(t)
589         return _sortiter( t )
590 end
591
592 -- The table pairs are sorted by value.
593 function vspairs(t)
594         return _sortiter( t, function (a,b) return t[a] < t[b] end )
595 end
596
597
598 --
599 -- System utility functions
600 --
601
602 function bigendian()
603         return string.byte(string.dump(function() end), 7) == 0
604 end
605
606 function exec(command)
607         local pp   = io.popen(command)
608         local data = pp:read("*a")
609         pp:close()
610
611         return data
612 end
613
614 function execi(command)
615         local pp = io.popen(command)
616
617         return pp and function()
618                 local line = pp:read()
619
620                 if not line then
621                         pp:close()
622                 end
623
624                 return line
625         end
626 end
627
628 -- Deprecated
629 function execl(command)
630         local pp   = io.popen(command)
631         local line = ""
632         local data = {}
633
634         while true do
635                 line = pp:read()
636                 if (line == nil) then break end
637                 data[#data+1] = line
638         end
639         pp:close()
640
641         return data
642 end
643
644 function ubus(object, method, data)
645         if not _ubus_connection then
646                 _ubus_connection = _ubus.connect()
647                 assert(_ubus_connection, "Unable to establish ubus connection")
648         end
649
650         if object and method then
651                 if type(data) ~= "table" then
652                         data = { }
653                 end
654                 return _ubus_connection:call(object, method, data)
655         elseif object then
656                 return _ubus_connection:signatures(object)
657         else
658                 return _ubus_connection:objects()
659         end
660 end
661
662 function serialize_json(x, cb)
663         local js = json.stringify(x)
664         if type(cb) == "function" then
665                 cb(js)
666         else
667                 return js
668         end
669 end
670
671
672 function libpath()
673         return require "nixio.fs".dirname(ldebug.__file__)
674 end
675
676 function checklib(fullpathexe, wantedlib)
677         local fs = require "nixio.fs"
678         local haveldd = fs.access('/usr/bin/ldd')
679         local haveexe = fs.access(fullpathexe)
680         if not haveldd or not haveexe then
681                 return false
682         end
683         local libs = exec(string.format("/usr/bin/ldd %s", shellquote(fullpathexe)))
684         if not libs then
685                 return false
686         end
687         for k, v in ipairs(split(libs)) do
688                 if v:find(wantedlib) then
689                         return true
690                 end
691         end
692         return false
693 end
694
695 --
696 -- Coroutine safe xpcall and pcall versions modified for Luci
697 -- original version:
698 -- coxpcall 1.13 - Copyright 2005 - Kepler Project (www.keplerproject.org)
699 --
700 -- Copyright © 2005 Kepler Project.
701 -- Permission is hereby granted, free of charge, to any person obtaining a
702 -- copy of this software and associated documentation files (the "Software"),
703 -- to deal in the Software without restriction, including without limitation
704 -- the rights to use, copy, modify, merge, publish, distribute, sublicense,
705 -- and/or sell copies of the Software, and to permit persons to whom the
706 -- Software is furnished to do so, subject to the following conditions:
707 --
708 -- The above copyright notice and this permission notice shall be
709 -- included in all copies or substantial portions of the Software.
710 --
711 -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
712 -- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
713 -- OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
714 -- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
715 -- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
716 -- TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
717 -- OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
718
719 local performResume, handleReturnValue
720 local oldpcall, oldxpcall = pcall, xpcall
721 coxpt = {}
722 setmetatable(coxpt, {__mode = "kv"})
723
724 -- Identity function for copcall
725 local function copcall_id(trace, ...)
726   return ...
727 end
728
729 --                              values of either the function or the error handler
730 function coxpcall(f, err, ...)
731         local res, co = oldpcall(coroutine.create, f)
732         if not res then
733                 local params = {...}
734                 local newf = function() return f(unpack(params)) end
735                 co = coroutine.create(newf)
736         end
737         local c = coroutine.running()
738         coxpt[co] = coxpt[c] or c or 0
739
740         return performResume(err, co, ...)
741 end
742
743 --                              values of the function or the error object
744 function copcall(f, ...)
745         return coxpcall(f, copcall_id, ...)
746 end
747
748 -- Handle return value of protected call
749 function handleReturnValue(err, co, status, ...)
750         if not status then
751                 return false, err(debug.traceback(co, (...)), ...)
752         end
753
754         if coroutine.status(co) ~= 'suspended' then
755                 return true, ...
756         end
757
758         return performResume(err, co, coroutine.yield(...))
759 end
760
761 -- Resume execution of protected function call
762 function performResume(err, co, ...)
763         return handleReturnValue(err, co, coroutine.resume(co, ...))
764 end