* luci/libs: fix off-by-one bug in luci.ip
[project/luci.git] / libs / core / luasrc / ip.lua
1 --[[
2
3 LuCI ip calculation libarary
4 (c) 2008 Jo-Philipp Wich <xm@leipzig.freifunk.net>
5 (c) 2008 Steven Barth <steven@midlink.org>
6
7 Licensed under the Apache License, Version 2.0 (the "License");
8 you may not use this file except in compliance with the License.
9 You may obtain a copy of the License at
10
11         http://www.apache.org/licenses/LICENSE-2.0
12
13 $Id$
14
15 ]]--
16
17 module( "luci.ip", package.seeall )
18
19 require("bit")
20 require("luci.util")
21
22 LITTLE_ENDIAN = not luci.util.bigendian()
23 BIG_ENDIAN    = not LITTLE_ENDIAN
24
25 FAMILY_INET4  = 0x04
26 FAMILY_INET6  = 0x06
27
28
29 local function __bless(x)
30         return setmetatable( x, {
31                 __index = luci.ip.cidr,
32                 __add   = luci.ip.cidr.add,
33                 __lt    = luci.ip.cidr.lower,
34                 __eq    = luci.ip.cidr.equal,
35                 __le    =
36                         function(...)
37                                 return luci.ip.cidr.equal(...) or luci.ip.cidr.lower(...)
38                         end
39         } )
40 end
41
42 local function __mask16(bits)
43         return bit.lshift(
44                 bit.rshift( 0xFFFF, 16 - bits % 16 ),
45                 16 - bits % 16
46         )
47 end
48
49 local function __length(family)
50         if family == FAMILY_INET4 then
51                 return 32
52         else
53                 return 128
54         end
55 end
56
57 -- htons(), htonl(), ntohs(), ntohl()
58
59 function htons(x)
60         if LITTLE_ENDIAN then
61                 return bit.bor(
62                         bit.rshift( x, 8 ),
63                         bit.band( bit.lshift( x, 8 ), 0xFF00 )
64                 )
65         else
66                 return x
67         end
68 end
69
70 function htonl(x)
71         if LITTLE_ENDIAN then
72                 return bit.bor(
73                         bit.lshift( htons( bit.band( x, 0xFFFF ) ), 16 ),
74                         htons( bit.rshift( x, 16 ) )
75                 )
76         else
77                 return x
78         end
79 end
80
81 ntohs = htons
82 ntohl = htonl
83
84
85 function IPv4(address, netmask)
86         address = address or "0.0.0.0/0"
87
88         local obj = __bless({ FAMILY_INET4 })
89
90         local data = {}
91         local prefix = address:match("/(.+)")
92
93         if netmask then
94                 prefix = obj:prefix(netmask)
95         elseif prefix then
96                 address = address:gsub("/.+","")
97                 prefix = tonumber(prefix)
98                 if not prefix or prefix < 0 or prefix > 32 then return nil end
99         else
100                 prefix = 32
101         end
102
103         local b1, b2, b3, b4 = address:match("^(%d+)%.(%d+)%.(%d+)%.(%d+)$")
104
105         b1 = tonumber(b1)
106         b2 = tonumber(b2)
107         b3 = tonumber(b3)
108         b4 = tonumber(b4)
109
110         if b1 and b1 <= 255 and
111            b2 and b2 <= 255 and
112            b3 and b3 <= 255 and
113            b4 and b4 <= 255 and
114            prefix
115         then
116                 table.insert(obj, { b1 * 256 + b2, b3 * 256 + b4 })
117                 table.insert(obj, prefix)
118                 return obj
119         end
120 end
121
122 function IPv6(address, netmask)
123         address = address or "::/0"
124
125         local obj = __bless({ FAMILY_INET6 })
126
127         local data = {}
128         local prefix = address:match("/(.+)")
129
130         if netmask then
131                 prefix = obj:prefix(netmask)
132         elseif prefix then
133                 address = address:gsub("/.+","")
134                 prefix = tonumber(prefix)
135                 if not prefix or prefix < 0 or prefix > 128 then return nil end
136         else
137                 prefix = 128
138         end
139
140         local borderl = address:sub(1, 1) == ":" and 2 or 1
141         local borderh, zeroh, chunk, block
142
143         if #address > 45 then return nil end
144
145         repeat
146                 borderh = address:find(":", borderl, true)
147                 if not borderh then break end
148
149                 block = tonumber(address:sub(borderl, borderh - 1), 16)
150                 if block and block <= 65535 then
151                         table.insert(data, block)
152                 else
153                         if zeroh or borderh - borderl > 1 then return nil end
154                         zeroh = #data + 1
155                 end
156
157                 borderl = borderh + 1
158         until #data == 7
159
160         chunk = address:sub(borderl)
161         if #chunk > 0 and #chunk <= 4 then
162                 block = tonumber(chunk, 16)
163                 if not block or block > 65535 then return nil end
164
165                 table.insert(data, block)
166         elseif #chunk > 4 then
167                 if #data == 7 or #chunk > 15 then return nil end
168                 borderl = 1
169                 for i=1, 4 do
170                         borderh = chunk:find(".", borderl, true)
171                         if not borderh and i < 4 then return nil end
172                         borderh = borderh and borderh - 1
173
174                         block = tonumber(chunk:sub(borderl, borderh))
175                         if not block or block > 255 then return nil end
176
177                         if i == 1 or i == 3 then
178                                 table.insert(data, block * 256)
179                         else
180                                 data[#data] = data[#data] + block
181                         end
182
183                         borderl = borderh and borderh + 2
184                 end
185         end
186
187         if zeroh then
188                 if #data == 8 then return nil end
189                 while #data < 8 do
190                         table.insert(data, zeroh, 0)
191                 end
192         end
193
194         if #data == 8 and prefix then
195                 table.insert(obj, data)
196                 table.insert(obj, prefix)
197                 return obj
198         end
199 end
200
201 function Hex( hex, prefix, family, swap )
202         family = ( family ~= nil ) and family or FAMILY_INET4
203         swap   = ( swap   == nil ) and true   or swap
204         prefix = prefix or __length(family)
205
206         local len  = __length(family)
207         local tmp  = ""
208         local data = { }
209
210         for i = 1, (len/4) - #hex do tmp = tmp .. '0' end
211
212         if swap and LITTLE_ENDIAN then
213                 for i = #hex, 1, -2 do tmp = tmp .. hex:sub( i - 1, i ) end
214         else
215                 tmp = tmp .. hex
216         end
217
218         hex = tmp
219
220         for i = 1, ( len / 4 ), 4 do
221                 local n = tonumber( hex:sub( i, i+3 ), 16 )
222                 if n then
223                         table.insert( data, n )
224                 else
225                         return nil
226                 end
227         end
228
229         return __bless({ family, data, len })
230 end
231
232
233 cidr = luci.util.class()
234
235 function cidr.is4( self )
236         return self[1] == FAMILY_INET4
237 end
238
239 function cidr.is6( self )
240         return self[1] == FAMILY_INET6
241 end
242
243 function cidr.string( self )
244         local str
245         if self:is4() then
246                 str = string.format(
247                         "%d.%d.%d.%d",
248                         bit.rshift(self[2][1], 8), bit.band(self[2][1], 0xFF),
249                         bit.rshift(self[2][2], 8), bit.band(self[2][2], 0xFF)
250                 )
251                 if self[3] < 32 then
252                         str = str .. "/" .. self[3]
253                 end
254         elseif self:is6() then
255                 str = string.format( "%X:%X:%X:%X:%X:%X:%X:%X", unpack(self[2]) )
256                 if self[3] < 128 then
257                         str = str .. "/" .. self[3]
258                 end
259         end
260         return str
261 end
262
263 function cidr.lower( self, addr )
264         assert( self[1] == addr[1], "Can't compare IPv4 and IPv6 addresses" )
265         for i = 1, #self[2] do
266                 if self[2][i] ~= addr[2][i] then
267                         return self[2][i] < addr[2][i]
268                 end
269         end
270         return false
271 end
272
273 function cidr.higher( self, addr )
274         assert( self[1] == addr[1], "Can't compare IPv4 and IPv6 addresses" )
275         for i = 1, #self[2] do
276                 if self[2][i] ~= addr[2][i] then
277                         return self[2][i] > addr[2][i]
278                 end
279         end
280         return false
281 end
282
283 function cidr.equal( self, addr )
284         assert( self[1] == addr[1], "Can't compare IPv4 and IPv6 addresses" )
285         for i = 1, #self[2] do
286                 if self[2][i] ~= addr[2][i] then
287                         return false
288                 end
289         end
290         return true
291 end
292
293 function cidr.prefix( self, mask )
294         local prefix = self[3]
295
296         if mask then
297                 prefix = 0
298                 local stop = false
299                 local obj = self:is4() and IPv4(mask) or IPv6(mask)
300
301                 if not obj then
302                         return nil
303                 end
304
305                 for i, block in ipairs(obj[2]) do
306                         local pos = bit.lshift(1, 15)
307                         for i=15, 0, -1 do
308                                 if bit.band(block, pos) == pos then
309                                         if not stop then
310                                                 prefix = prefix + 1
311                                         else
312                                                 return nil
313                                         end
314                                 else
315                                         stop = true
316                                 end
317                                 pos = bit.rshift(pos, 1)
318                         end
319                 end
320         end
321
322         return prefix
323 end
324
325 function cidr.network( self, bits )
326         local data = { }
327         bits = bits or self[3]
328
329         for i = 1, math.floor( bits / 16 ) do
330                 table.insert( data, self[2][i] )
331         end
332
333         if #data < #self[2] then
334                 table.insert( data, bit.band( self[2][1+#data], __mask16(bits) ) )
335
336                 for i = #data + 1, #self[2] do
337                         table.insert( data, 0 )
338                 end
339         end
340
341         return __bless({ self[1], data, __length(self[1]) })
342 end
343
344 function cidr.host( self )
345         return __bless({ self[1], data, __length(self[1]) })
346 end
347
348 function cidr.mask( self, bits )
349         local data = { }
350         bits = bits or self[3]
351
352         for i = 1, math.floor( bits / 16 ) do
353                 table.insert( data, 0xFFFF )
354         end
355
356         if #data < #self[2] then
357                 table.insert( data, __mask16(bits) )
358
359                 for i = #data + 1, #self[2] do
360                         table.insert( data, 0 )
361                 end
362         end
363
364         return __bless({ self[1], data, __length(self[1]) })
365 end
366
367 function cidr.contains( self, addr )
368         assert( self[1] == addr[1], "Can't compare IPv4 and IPv6 addresses" )
369
370         if self:prefix() <= addr:prefix() then
371                 return self:network() == addr:network(self:prefix())
372         end
373
374         return false
375 end
376
377 function cidr.add( self, amount )
378         local shorts = { bit.rshift(amount, 16), bit.band(amount, 0xFFFF) }
379         local data   = { unpack(self[2]) }
380
381         for pos = #data, 1, -1 do
382                 local add = ( #shorts > 0 ) and table.remove( shorts, #shorts ) or 0
383                 if ( data[pos] + add ) > 0xFFFF then
384                         data[pos] = ( data[pos] + add ) % 0xFFFF
385                         if pos > 2 then
386                                 data[pos-1] = data[pos-1] + ( add - data[pos] )
387                         end
388                 else
389                         data[pos] = data[pos] + add
390                 end
391         end
392
393         return __bless({ self[1], data, self[3] })
394 end