Allow Basic-Auth pass-through
[project/luci.git] / libs / lucid-http / luasrc / lucid / http / server.lua
1 --[[
2 LuCId HTTP-Slave
3 (c) 2009 Steven Barth <steven@midlink.org>
4
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8
9         http://www.apache.org/licenses/LICENSE-2.0
10
11 $Id$
12 ]]--
13
14 local ipairs, pairs = ipairs, pairs
15 local tostring, tonumber = tostring, tonumber
16 local pcall, assert, type = pcall, assert, type
17 local set_memory_limit = set_memory_limit
18
19 local os = require "os"
20 local nixio = require "nixio"
21 local util = require "luci.util"
22 local ltn12 = require "luci.ltn12"
23 local proto = require "luci.http.protocol"
24 local table = require "table"
25 local date = require "luci.http.protocol.date"
26
--- HTTP Daemon: request parsing, virtual-host dispatch and request handlers.
-- @cstyle instance
module("luci.lucid.http.server")

--- Version string advertised in the "Server:" response header.
VERSION = "1.0"
32
--- Reason phrases for HTTP status codes, indexed by numeric code.
-- Used to build status lines; callers should still tolerate codes that are
-- missing from this table.
statusmsg = {
	[200] = "OK",
	[201] = "Created",
	[204] = "No Content",
	[206] = "Partial Content",
	[301] = "Moved Permanently",
	[302] = "Found",
	[304] = "Not Modified",
	[307] = "Temporary Redirect",
	[400] = "Bad Request",
	[401] = "Unauthorized",
	[403] = "Forbidden",
	[404] = "Not Found",
	[405] = "Method Not Allowed",
	[408] = "Request Time-out",
	[411] = "Length Required",
	[412] = "Precondition Failed",
	[416] = "Requested range not satisfiable",
	[500] = "Internal Server Error",
	[501] = "Not Implemented",
	-- RFC 7231 section 6.6.4 names 503 "Service Unavailable"
	[503] = "Service Unavailable",
}
51
--- Create a new IO resource response.
-- A response body backed by an open file descriptor. Server.process sends it
-- with fd:copyz(), or wraps it in a block source for chunked replies.
-- @class function
-- @param fd File descriptor
-- @param len Length of data
-- @return IO resource
IOResource = util.class()

function IOResource.__init__(self, fd, len)
	self.fd = fd
	self.len = len
end
62
63
--- Create a server handler.
-- @class function
-- @param name Name; falls back to tostring() of the handler object. It is
--             also used as the Basic-Auth realm when access is denied.
-- @return Handler
Handler = util.class()

function Handler.__init__(self, name)
	if name then
		self.name = name
	else
		self.name = tostring(self)
	end
end
73
--- Create a plain-text failure reply.
-- @param code HTTP status code
-- @param msg Status message. Non-string values (e.g. an error object raised
--            inside a handler and forwarded by Handler.process) are converted
--            with tostring() before being used as the body.
-- @return status code, header table, response source
function Handler.failure(self, code, msg)
	if msg ~= nil and type(msg) ~= "string" then
		msg = tostring(msg)
	end
	return code, { ["Content-Type"] = "text/plain" }, ltn12.source.string(msg)
end
81
--- Add an access restriction.
-- Restrictions are evaluated in insertion order by checkrestricted, and the
-- first one that is fully satisfied admits the request.
-- @param restriction Restriction specification (table; checkrestricted reads
--                    its optional "interface" and "user" fields)
function Handler.restrict(self, restriction)
	local list = self.restrictions
	if list then
		list[#list + 1] = restriction
	else
		self.restrictions = { restriction }
	end
end
91
--- Enforce access restrictions.
-- Each restriction may require that the request arrived on a given local
-- interface ("interface") and/or that it carries HTTP Basic credentials
-- matching a system account ("user"). The first restriction that is fully
-- satisfied admits the request, and the decoded credentials are then exposed
-- to the handler as request.env.HTTP_AUTH_USER / HTTP_AUTH_PASS.
-- @param request Request object
-- @return nil or HTTP statuscode, table of headers, response source
function Handler.checkrestricted(self, request)
	if not self.restrictions then
		return
	end

	local localif, user, pass

	for _, r in ipairs(self.restrictions) do
		local stat = true
		if stat and r.interface then	-- Interface restriction
			if not localif then
				-- Map the local address the request arrived on to an interface
				-- name; tolerate a server environment without an interface list
				for _, v in ipairs(request.server.interfaces or {}) do
					if v.addr == request.env.SERVER_ADDR then
						localif = v.name
						break
					end
				end
			end

			if r.interface ~= localif then
				stat = false
			end
		end

		if stat and r.user then	-- User restriction
			if not user then
				local rh = (request.headers.Authorization or ""):match("Basic (.*)")
				rh = rh and nixio.bin.b64decode(rh) or ""
				-- RFC 7617: the user-id cannot contain a colon, so split the
				-- decoded "user:password" at the FIRST colon; the password
				-- itself may legitimately contain colons.
				user, pass = rh:match("^([^:]*):(.*)$")
				pass = pass or ""
			end
			local pwe = nixio.getsp and nixio.getsp(r.user) or nixio.getpw(r.user)
			local pwh = (user == r.user) and pwe and (pwe.pwdp or pwe.passwd)
			if not pwh or #pwh < 1 or nixio.crypt(pass, pwh) ~= pwh then
				stat = false
			end
		end

		if stat then
			request.env.HTTP_AUTH_USER, request.env.HTTP_AUTH_PASS = user, pass
			return
		end
	end

	return 401, {
		["WWW-Authenticate"] = ('Basic realm=%q'):format(self.name),
		["Content-Type"] = 'text/plain'
	}, ltn12.source.string("Unauthorized")
end
145
--- Process a request.
-- Access restrictions are enforced first. The request is then dispatched to
-- the method named "handle_" .. REQUEST_METHOD (e.g. handle_GET); an error
-- raised by that method is caught and turned into a 500 reply.
-- @param request Request object
-- @param sourcein Request data source
-- @return HTTP statuscode, table of headers, response source
function Handler.process(self, request, sourcein)
	-- Access denied: checkrestricted already built the complete reply
	local dcode, dheaders, dsource = self:checkrestricted(request)
	if dcode then
		return dcode, dheaders, dsource
	end

	-- Detect request method
	local method = self["handle_" .. request.env.REQUEST_METHOD]
	if not method then
		return self:failure(405, statusmsg[405])
	end

	-- Run the handler; on failure the second value is the raised error
	local ok, code, hdr, sourceout = pcall(method, self, request, sourcein)
	if not ok then
		return self:failure(500, code)
	end

	return code, hdr, sourceout
end
174
175
--- Create a Virtual Host.
-- A virtual host maps URI prefixes to handler objects (see set_handler).
-- @class function
-- @return Virtual Host
VHost = util.class()

function VHost.__init__(self)
	local handlers = {}
	self.handlers = handlers
end
184
--- Process a request and invoke the appropriate handler.
-- Picks the handler whose registered prefix is the longest one matching
-- request.env.SCRIPT_NAME, either exactly or followed by a "/". On a match,
-- SCRIPT_NAME is set to the prefix and PATH_INFO to the rest of the URI.
-- @param request Request object
-- @param ... Additional parameters passed to the handler
-- @return HTTP statuscode, table of headers, response source
function VHost.process(self, request, ...)
	local env = request.env
	local uri = env.SCRIPT_NAME
	local slash = ("/"):byte()
	local best, bestlen = nil, -1

	-- Until a prefix matches, the whole URI counts as path info
	env.SCRIPT_NAME = ""
	env.PATH_INFO = uri

	for prefix, candidate in pairs(self.handlers) do
		local plen = #prefix
		local hit = plen > bestlen and (uri == prefix
			or (uri:sub(1, plen) == prefix and uri:byte(plen + 1) == slash))
		if hit then
			best, bestlen = candidate, plen
			env.SCRIPT_NAME = prefix
			env.PATH_INFO = uri:sub(plen + 1)
		end
	end

	if not best then
		return 404, nil, ltn12.source.string("No such handler")
	end
	return best:process(request, ...)
end
218
--- Get the registered handlers.
-- Returns the live table (not a copy), keyed by URI prefix.
-- @return Table of handlers
function VHost.get_handlers(self)
	local handlers = self.handlers
	return handlers
end
224
--- Register a handler under a given URI prefix.
-- Registering the same prefix again replaces the previous handler.
-- @param match URI prefix
-- @param handler Handler object
function VHost.set_handler(self, match, handler)
	local handlers = self.handlers
	handlers[match] = handler
end
231
-- Remap IPv4-mapped IPv6 addresses ("::ffff:a.b.c.d") back to plain IPv4
-- addresses by stripping the "::ffff:" prefix; any other address is returned
-- unchanged.
local function remapipv6(adr)
	local prefix = "::ffff:"
	if adr:sub(1, #prefix) ~= prefix then
		return adr
	end
	return adr:sub(#prefix + 1)
end
241
-- Create a source that decodes chunked-encoded data from a socket.
-- @param sock   socket object providing recv(n) and recvall(n)
-- @param buffer data already read from the socket past the request headers
-- @return ltn12 source yielding the decoded chunk payloads, nil at the
--         terminating zero-length chunk, or nil plus error information
local function chunksource(sock, buffer)
	buffer = buffer or ""
	return function()
		local output
		local _, endp, count = buffer:find("^([0-9a-fA-F]+);?.-\r\n")
		-- Read until a complete chunk-size line is buffered; give up once
		-- 1024 bytes are buffered without one (strict "<" so that recv is
		-- never asked for 0 bytes).
		while not count and #buffer < 1024 do
			local newblock, code = sock:recv(1024 - #buffer)
			if not newblock then
				return nil, code
			elseif #newblock == 0 then
				-- An empty read means the peer closed the connection; without
				-- this check the loop would spin forever.
				return nil, "unexpected end of chunked stream"
			end
			buffer = buffer .. newblock
			_, endp, count = buffer:find("^([0-9a-fA-F]+);?.-\r\n")
		end
		-- tonumber(nil, 16) raises an error, so only convert a real match
		count = count and tonumber(count, 16)
		if not count then
			return nil, -1, "invalid encoding"
		elseif count == 0 then
			return nil
		end

		local avail = #buffer - endp	-- buffered bytes after the size line
		if count + 2 <= avail then
			-- Whole chunk plus its trailing CRLF is already buffered
			output = buffer:sub(endp + 1, endp + count)
			buffer = buffer:sub(endp + count + 3)
			return output
		end

		-- The chunk data and/or its trailing CRLF extend past the buffer
		output = buffer:sub(endp + 1, endp + count)
		buffer = ""
		local rest, code
		if count > #output then
			local remain
			remain, code = sock:recvall(count - #output)
			if not remain then
				return nil, code
			end
			output = output .. remain
			rest, code = sock:recvall(2)
		else
			-- Data is complete and only (part of) the CRLF is missing. Use
			-- the length saved before the buffer was cleared; computing it
			-- from the emptied buffer over-reads by endp bytes or more.
			rest, code = sock:recvall(count + 2 - avail)
		end
		if not rest then
			return nil, code
		end
		return output
	end
end
285
-- Create a sink that chunk-encodes data and writes it on a given socket.
-- A nil chunk ends the body with the terminating zero-length chunk.
-- @param sock socket object providing writeall(data)
-- @return ltn12 sink
local function chunksink(sock)
	return function(chunk, err)
		if not chunk then
			return sock:writeall("0\r\n\r\n")
		elseif #chunk == 0 then
			-- An empty chunk would be encoded as "0\r\n\r\n", which is the
			-- end-of-body marker; skip it so the response is not cut short.
			return 1
		else
			return sock:writeall(("%X\r\n%s\r\n"):format(#chunk, tostring(chunk)))
		end
	end
end
296
297
--- Create a server object.
-- Virtual hosts are registered by Host header value through set_vhost.
-- @class function
-- @return Server object
Server = util.class()

function Server.__init__(self)
	local vhosts = {}
	self.vhosts = vhosts
end
306
--- Get the registered virtual hosts.
-- Returns the live table (not a copy), keyed by host name.
-- @return Table of virtual hosts
function Server.get_vhosts(self)
	local vhosts = self.vhosts
	return vhosts
end
312
--- Register a virtual host under a given name.
-- The name is compared against the Host request header; the name "" is the
-- fallback host used by Server.process.
-- @param name Hostname
-- @param vhost Virtual host object
function Server.set_vhost(self, name, vhost)
	local vhosts = self.vhosts
	vhosts[name] = vhost
end
319
--- Send a fatal error message to given client and close the connection.
-- @param client Client socket
-- @param code HTTP status code
-- @param msg Optional detail text, written into the plain-text body
function Server.error(self, client, code, msg)
	-- Must be local: a bare assignment would leak into the module table
	local hcode = tostring(code)
	-- Fall back to an empty reason phrase for codes missing from statusmsg
	-- instead of failing on a nil concatenation
	local reason = statusmsg[code] or ""

	client:writeall("HTTP/1.0 " .. hcode .. " " .. reason .. "\r\n")
	client:writeall("Connection: close\r\n")
	client:writeall("Content-Type: text/plain\r\n\r\n")

	if msg then
		client:writeall("HTTP-Error " .. hcode .. ": " .. tostring(msg) .. "\r\n")
	end

	client:close()
end
338
-- Request headers that are mirrored into CGI-style environment variables.
local hdr2env = {
	["Content-Length"] = "CONTENT_LENGTH",
	["Content-Type"] = "CONTENT_TYPE",
	["Content-type"] = "CONTENT_TYPE",
	["Accept"] = "HTTP_ACCEPT",
	["Accept-Charset"] = "HTTP_ACCEPT_CHARSET",
	["Accept-Encoding"] = "HTTP_ACCEPT_ENCODING",
	["Accept-Language"] = "HTTP_ACCEPT_LANGUAGE",
	["Connection"] = "HTTP_CONNECTION",
	["Cookie"] = "HTTP_COOKIE",
	["Host"] = "HTTP_HOST",
	["Referer"] = "HTTP_REFERER",
	["User-Agent"] = "HTTP_USER_AGENT"
}

-- Header field names are case-insensitive (RFC 7230 section 3.2), so the
-- environment mapping is looked up by lower-cased name.
local hdr2env_lower = {}
for name, envname in pairs(hdr2env) do
	hdr2env_lower[name:lower()] = envname
end

--- Parse the request headers and prepare the environment.
-- Headers are stored in req.headers under the name as sent by the client;
-- the ones listed in hdr2env are additionally copied into req.env regardless
-- of the capitalisation the client used.
-- @param source line-based input source
-- @return Request object, or nil and an error
function Server.parse_headers(self, source)
	local env = {}
	local req = {env = env, headers = {}}
	local line, err

	repeat	-- Ignore empty lines
		line, err = source()
		if not line then
			return nil, err
		end
	until #line > 0

	env.REQUEST_METHOD, env.REQUEST_URI, env.SERVER_PROTOCOL =
		line:match("^([A-Z]+) ([^ ]+) (HTTP/1%.[01])$")

	if not env.REQUEST_METHOD then
		return nil, "invalid magic"
	end

	local key, envkey, val
	repeat
		line, err = source()
		if not line then
			return nil, err
		elseif #line > 0 then
			key, val = line:match("^([%w-]+)%s?:%s?(.*)")
			if key then
				req.headers[key] = val
				envkey = hdr2env_lower[key:lower()]
				if envkey then
					env[envkey] = val
				end
			else
				return nil, "invalid header line"
			end
		else
			break
		end
	until false

	env.SCRIPT_NAME, env.QUERY_STRING = env.REQUEST_URI:match("([^?]*)%??(.*)")
	return req
end
400
--- Handle a new client connection.
-- Requests are read and answered in a loop until the connection is marked
-- for closing (keep-alive). Only GET, HEAD and POST are accepted; after a
-- POST the connection is always closed.
-- @param client client socket
-- @param env superserver environment
function Server.process(self, client, env)
	local sourcein  = function() end
	local sourcehdr = client:linesource()
	local sinkout
	local buffer

	local close = false
	local stat, code, msg, message, err

	env.config.memlimit = tonumber(env.config.memlimit)
	if env.config.memlimit and set_memory_limit then
		set_memory_limit(env.config.memlimit)
	end

	client:setsockopt("socket", "rcvtimeo", 5)
	client:setsockopt("socket", "sndtimeo", 5)

	repeat
		-- parse headers
		message, err = self:parse_headers(sourcehdr)

		-- any other error
		if not message or err then
			if err == 11 then	-- EAGAIN (receive timeout): end quietly
				break
			else
				return self:error(client, 400, err)
			end
		end

		-- Prepare sources and sinks; buffer receives whatever the line
		-- source had read beyond the header block
		buffer = sourcehdr(true)
		sinkout = client:sink()
		message.server = env

		if client:is_tls_socket() then
			message.env.HTTPS = "on"
		end

		-- Addresses
		message.env.REMOTE_ADDR = remapipv6(env.host)
		message.env.REMOTE_PORT = env.port

		local srvaddr, srvport = client:getsockname()
		message.env.SERVER_ADDR = remapipv6(srvaddr)
		message.env.SERVER_PORT = srvport

		-- keep-alive; connection options are case-insensitive (RFC 7230 6.1)
		local connhdr = message.env.HTTP_CONNECTION
		connhdr = connhdr and connhdr:lower()
		if message.env.SERVER_PROTOCOL == "HTTP/1.1" then
			close = (connhdr == "close")
		else
			close = not connhdr or connhdr == "close"
		end

		-- The "nokeepalive" config option forces one request per connection
		close = close or env.config.nokeepalive

		if message.env.REQUEST_METHOD == "GET"
		or message.env.REQUEST_METHOD == "HEAD" then
			-- Be happy

		elseif message.env.REQUEST_METHOD == "POST" then
			-- If we have a HTTP/1.1 client and an Expect: 100-continue header
			-- respond with HTTP 100 Continue message
			if message.env.SERVER_PROTOCOL == "HTTP/1.1"
			and message.headers.Expect == '100-continue' then
				client:writeall("HTTP/1.1 100 Continue\r\n\r\n")
			end

			if message.headers['Transfer-Encoding'] and
			 message.headers['Transfer-Encoding'] ~= "identity" then
				sourcein = chunksource(client, buffer)
				buffer = nil
			elseif message.env.CONTENT_LENGTH then
				-- Accept only a plain decimal length: tonumber() would also
				-- take "0x10", "1e3" or "-5", and a non-number would make the
				-- comparison below raise a Lua error
				local digits = message.env.CONTENT_LENGTH:match("^%s*(%d+)%s*$")
				local len = digits and tonumber(digits)
				if not len then
					return self:error(client, 400, "Invalid Content-Length")
				end
				if #buffer >= len then
					sourcein = ltn12.source.string(buffer:sub(1, len))
					buffer = buffer:sub(len+1)
				else
					sourcein = ltn12.source.cat(
						ltn12.source.string(buffer),
						client:blocksource(nil, len - #buffer)
					)
				end
			else
				return self:error(client, 411, statusmsg[411])
			end

			close = true
		else
			return self:error(client, 405, statusmsg[405])
		end


		local host = self.vhosts[message.env.HTTP_HOST] or self.vhosts[""]
		if not host then
			return self:error(client, 404, "No virtual host found")
		end

		local code, headers, sourceout = host:process(message, sourcein)
		headers = headers or {}

		-- Post process response
		if sourceout then
			if util.instanceof(sourceout, IOResource) then
				if not headers["Content-Length"] then
					headers["Content-Length"] = sourceout.len
				end
			end
			if not headers["Content-Length"] and not close then
				if message.env.SERVER_PROTOCOL == "HTTP/1.1" then
					headers["Transfer-Encoding"] = "chunked"
					sinkout = chunksink(client)
				else
					close = true
				end
			end
		elseif message.env.REQUEST_METHOD ~= "HEAD" then
			headers["Content-Length"] = 0
		end

		if close then
			headers["Connection"] = "close"
		elseif message.env.SERVER_PROTOCOL == "HTTP/1.0" then
			headers["Connection"] = "Keep-Alive"
		end

		headers["Date"] = date.to_http(os.time())
		local header = {
			-- Handlers may return codes missing from statusmsg; an empty
			-- reason phrase is valid and avoids a nil concatenation error
			message.env.SERVER_PROTOCOL .. " " .. tostring(code) .. " "
				.. (statusmsg[code] or ""),
			"Server: LuCId-HTTPd/" .. VERSION
		}


		for k, v in pairs(headers) do
			if type(v) == "table" then
				for _, h in ipairs(v) do
					header[#header+1] = k .. ": " .. h
				end
			else
				header[#header+1] = k .. ": " .. v
			end
		end

		header[#header+1] = ""
		header[#header+1] = ""

		-- Output
		stat, code, msg = client:writeall(table.concat(header, "\r\n"))

		if sourceout and stat then
			if util.instanceof(sourceout, IOResource) then
				if not headers["Transfer-Encoding"] then
					stat, code, msg = sourceout.fd:copyz(client, sourceout.len)
					sourceout = nil
				else
					sourceout = sourceout.fd:blocksource(nil, sourceout.len)
				end
			end

			if sourceout then
				stat, msg = ltn12.pump.all(sourceout, sinkout)
			end
		end


		-- Write errors
		if not stat then
			if msg then
				nixio.syslog("err", "Error sending data to " .. env.host ..
					": " .. msg .. "\n")
			end
			break
		end

		-- Push unread bytes back into the line source for the next request
		if buffer then
			sourcehdr(buffer)
		end
	until close

	client:shutdown()
	client:close()
end