HTTPd: No keep-alive after POST request, because we do not check state
[project/luci.git] / libs / lucid-http / luasrc / lucid / http / server.lua
1 --[[
2 LuCId HTTP-Slave
3 (c) 2009 Steven Barth <steven@midlink.org>
4
5 Licensed under the Apache License, Version 2.0 (the "License");
6 you may not use this file except in compliance with the License.
7 You may obtain a copy of the License at
8
9         http://www.apache.org/licenses/LICENSE-2.0
10
11 $Id$
12 ]]--
13
14 local ipairs, pairs = ipairs, pairs
15 local tostring, tonumber = tostring, tonumber
16 local pcall, assert, type = pcall, assert, type
17
18 local os = require "os"
19 local nixio = require "nixio"
20 local util = require "luci.util"
21 local ltn12 = require "luci.ltn12"
22 local proto = require "luci.http.protocol"
23 local table = require "table"
24 local date = require "luci.http.protocol.date"
25
26 module "luci.lucid.http.server"
27
VERSION = "1.0"

-- Reason phrases keyed by HTTP status code; used when building status
-- lines for replies (see RFC 2616 section 6.1.1 for the standard phrases).
statusmsg = {
	[200] = "OK",
	[201] = "Created",
	[204] = "No Content",
	[206] = "Partial Content",
	[301] = "Moved Permanently",
	[302] = "Found",
	[303] = "See Other",
	[304] = "Not Modified",
	[307] = "Temporary Redirect",
	[400] = "Bad Request",
	[401] = "Unauthorized",
	[403] = "Forbidden",
	[404] = "Not Found",
	[405] = "Method Not Allowed",
	[408] = "Request Time-out",
	[411] = "Length Required",
	[412] = "Precondition Failed",
	[413] = "Request Entity Too Large",
	[416] = "Requested range not satisfiable",
	[500] = "Internal Server Error",
	[501] = "Not Implemented",
	[503] = "Service Unavailable",
}
48
-- File Resource
-- Response body backed by an open descriptor: Server.process sends it with
-- fd:copyz(client, len) and uses len as Content-Length when the handler
-- did not provide one.
IOResource = util.class()

function IOResource.__init__(self, fd, len)
	self.fd = fd
	self.len = len
end
55
56
-- Server handler implementation
-- Base class for request handlers; Handler.process dispatches a request to
-- a method named "handle_<REQUEST_METHOD>" (e.g. handle_GET) on the object.
Handler = util.class()

-- name: optional label, used as the realm of Basic-auth challenges;
-- when absent, the string form of the handler object is used instead.
function Handler.__init__(self, name)
	if name then
		self.name = name
	else
		self.name = tostring(self)
	end
end
63
-- Creates a plain-text failure reply.
-- code: HTTP status code; msg: body text (nil gives an empty body).
-- Non-string messages, such as an error object caught by pcall in
-- Handler.process, are converted with tostring() so the body source is
-- always built from a string.
-- Returns: status code, header table, ltn12 source for the body.
function Handler.failure(self, code, msg)
	if msg ~= nil and type(msg) ~= "string" then
		msg = tostring(msg)
	end
	return code, { ["Content-Type"] = "text/plain" }, ltn12.source.string(msg)
end
68
-- Access Restrictions
-- Appends one restriction table to this handler. Handler.checkrestricted
-- admits a request as soon as any registered restriction matches it.
function Handler.restrict(self, restriction)
	local list = self.restrictions
	if list then
		list[#list + 1] = restriction
	else
		self.restrictions = {restriction}
	end
end
77
-- Check restrictions
-- Returns nothing when no restrictions are registered or when at least one
-- registered restriction matches the request. Otherwise returns a complete
-- 401 reply (code, headers, body source) with a Basic-auth challenge whose
-- realm is self.name.
-- All fields of a single restriction must match:
--   interface: name of the entry in request.server.interfaces whose addr
--              equals the request's SERVER_ADDR
--   user:      account whose password hash (from nixio.getsp() when
--              available, else nixio.getpw()) must verify the password
--              sent in the Authorization: Basic header
function Handler.checkrestricted(self, request)
	if not self.restrictions then
		return
	end

	-- Computed lazily and shared by all restrictions
	local localif, user, pass

	for _, r in ipairs(self.restrictions) do
		local stat = true
		if stat and r.interface then	-- Interface restriction
			if not localif then
				for _, v in ipairs(request.server.interfaces) do
					if v.addr == request.env.SERVER_ADDR then
						localif = v.name
						break
					end
				end
			end

			if r.interface ~= localif then
				stat = false
			end
		end

		if stat and r.user then	-- User restriction
			local rh, pwe
			if not user then
				rh = (request.headers.Authorization or ""):match("Basic (.*)")
				rh = rh and nixio.bin.b64decode(rh) or ""
				-- RFC 7617: the user-id cannot contain a colon but the
				-- password may, so split at the FIRST colon.
				user, pass = rh:match("^([^:]*):(.*)$")
				pass = pass or ""
			end
			pwe = nixio.getsp and nixio.getsp(r.user) or nixio.getpw(r.user)
			local pwh = (user == r.user) and pwe and (pwe.pwdp or pwe.passwd)
			if not pwh or #pwh < 1 or nixio.crypt(pass, pwh) ~= pwh then
				stat = false
			end
		end

		if stat then
			return
		end
	end

	return 401, {
		["WWW-Authenticate"] = ('Basic realm=%q'):format(self.name),
		["Content-Type"] = 'text/plain'
	}, ltn12.source.string("Unauthorized")
end
128
-- Processes a request
-- Access restrictions are checked first; on denial the 401 reply built by
-- checkrestricted() is returned unchanged. Otherwise the request goes to
-- the callback "handle_<REQUEST_METHOD>" (e.g. handle_GET), run under pcall
-- so that an error raised inside it becomes a 500 reply. A method without
-- a callback yields 405.
-- Returns: status code, header table, ltn12 source for the body.
function Handler.process(self, request, sourcein)
	-- checkrestricted() returns nothing on success, or a full reply
	local deny_code, deny_hdr, deny_src = self:checkrestricted(request)
	if deny_code then	-- Access Denied
		return deny_code, deny_hdr, deny_src
	end

	-- Detect request Method
	local callback = self["handle_" .. request.env.REQUEST_METHOD]
	if not callback then
		return self:failure(405, statusmsg[405])
	end

	-- Run the handler; on error, code holds the error value
	local ok, code, hdr, sourceout = pcall(callback, self, request, sourcein)
	if not ok then
		return self:failure(500, code)
	end

	return code, hdr, sourceout
end
154
155
-- Virtual host: maps URI prefixes to handler objects
VHost = util.class()

-- Creates a virtual host with an empty prefix -> handler table.
function VHost.__init__(self)
	local handlers = {}
	self.handlers = handlers
end
161
-- Dispatches a request to the handler registered under the longest
-- matching URI prefix. A prefix matches when the URI equals it or continues
-- with "/" right after it. On a match SCRIPT_NAME becomes the prefix and
-- PATH_INFO the remainder of the URI; otherwise SCRIPT_NAME is "" and
-- PATH_INFO the full URI, and a 404 reply is returned.
function VHost.process(self, request, ...)
	local env = request.env
	local uri = env.SCRIPT_NAME
	local slash = ("/"):byte()
	local best, bestlen = nil, -1

	env.SCRIPT_NAME = ""
	env.PATH_INFO = uri

	for prefix, h in pairs(self.handlers) do
		local plen = #prefix
		if plen > bestlen then
			local exact = (uri == prefix)
			if exact or (uri:sub(1, plen) == prefix and uri:byte(plen + 1) == slash) then
				best, bestlen = h, plen
				env.SCRIPT_NAME = prefix
				env.PATH_INFO = uri:sub(plen + 1)
			end
		end
	end

	if not best then
		return 404, nil, ltn12.source.string("No such handler")
	end
	return best:process(request, ...)
end
191
-- Returns the live prefix -> handler table of this virtual host.
function VHost.get_handlers(self)
	local handlers = self.handlers
	return handlers
end

-- Maps URI prefix "match" to handler, replacing any previous mapping;
-- a nil handler removes the entry.
function VHost.set_handler(self, match, handler)
	local handlers = self.handlers
	handlers[match] = handler
end
199
200
-- Normalises an address string: an IPv4-mapped IPv6 address
-- ("::ffff:a.b.c.d") is reduced to its IPv4 part, anything else is
-- returned unchanged.
local function remapipv6(adr)
	local prefix = "::ffff:"
	local plen = #prefix
	if adr:sub(1, plen) ~= prefix then
		return adr
	end
	return adr:sub(plen + 1)
end
209
-- Returns an ltn12 source decoding an HTTP/1.1 chunked request body read
-- from sock. buffer holds body bytes already read from the socket and is
-- consumed first. Each call yields the data of one chunk; nil marks the
-- last (zero-size) chunk. On failure it returns nil followed by an error:
-- the socket's error code, or -1 plus a message for malformed or truncated
-- input. Chunk extensions (";...") are skipped; trailer headers after the
-- last chunk are not consumed.
local function chunksource(sock, buffer)
	buffer = buffer or ""
	return function()
		local output
		local _, endp, count = buffer:find("^([0-9a-fA-F]+);?.-\r\n")
		-- Read until a full chunk-size line is buffered (at most 1024 bytes).
		-- Strictly less than 1024, so recv() is never asked for 0 bytes.
		while not count and #buffer < 1024 do
			local newblock, code = sock:recv(1024 - #buffer)
			if not newblock then
				return nil, code
			elseif #newblock == 0 then	-- peer closed the connection
				return nil, -1, "unexpected end of stream"
			end
			buffer = buffer .. newblock
			_, endp, count = buffer:find("^([0-9a-fA-F]+);?.-\r\n")
		end
		-- tonumber(nil, 16) raises, so guard the missing-size case
		count = count and tonumber(count, 16)
		if not count then
			return nil, -1, "invalid encoding"
		elseif count == 0 then
			return nil
		end

		-- Bytes of this chunk (data + CRLF) that are already buffered
		local avail = #buffer - endp
		if count + 2 <= avail then
			output = buffer:sub(endp + 1, endp + count)
			buffer = buffer:sub(endp + count + 3)
			return output
		end

		-- Chunk extends past the buffer: take what we have, read the rest
		output = buffer:sub(endp + 1, endp + count)
		buffer = ""
		local crlflen = 2
		if avail < count then
			local remain, code = sock:recvall(count - avail)
			if not remain then
				return nil, code
			end
			output = output .. remain
		else
			-- All data is buffered; only part of the CRLF is missing
			crlflen = count + 2 - avail
		end
		local crlf, code = sock:recvall(crlflen)
		if not crlf then
			return nil, code
		end
		return output
	end
end
252
-- Returns an ltn12 sink writing data to sock in HTTP/1.1 chunked transfer
-- encoding. A nil chunk (end of data) emits the terminating zero-size
-- chunk. Empty chunks are skipped: encoding one would produce "0\r\n\r\n",
-- which the peer takes as the end of the body.
local function chunksink(sock)
	return function(chunk, err)
		if not chunk then
			return sock:writeall("0\r\n\r\n")
		elseif #chunk == 0 then
			return true
		else
			return sock:writeall(("%X\r\n%s\r\n"):format(#chunk, chunk))
		end
	end
end
262
-- HTTP server: holds the virtual hosts and serves client connections
Server = util.class()

-- Creates a server without any virtual hosts.
function Server.__init__(self)
	self.vhosts = {}
end

-- Returns the live table mapping Host header values to virtual hosts;
-- the key "" is the fallback used by Server.process.
function Server.get_vhosts(self)
	local vhosts = self.vhosts
	return vhosts
end

-- Registers vhost for requests whose Host header equals name.
function Server.set_vhost(self, name, vhost)
	local vhosts = self.vhosts
	vhosts[name] = vhost
end
276
-- Sends a minimal plain-text HTTP/1.0 error reply and closes the client.
-- code: HTTP status code; msg: optional detail written into the body.
-- Codes without an entry in statusmsg get an empty reason phrase rather
-- than raising while the status line is built.
function Server.error(self, client, code, msg)
	local hcode = tostring(code)
	local reason = statusmsg[code] or ""

	client:writeall( "HTTP/1.0 " .. hcode .. " " .. reason .. "\r\n" )
	client:writeall( "Connection: close\r\n" )
	client:writeall( "Content-Type: text/plain\r\n\r\n" )

	if msg then
		client:writeall( "HTTP-Error " .. hcode .. ": " .. tostring(msg) .. "\r\n" )
	end

	client:close()
end
291
-- Request headers that are also exported as CGI-style variables in
-- request.env (header name -> variable name). "Content-type" is listed
-- separately because table lookups are case-sensitive.
local hdr2env = {
	-- body description
	["Content-Type"]    = "CONTENT_TYPE",
	["Content-type"]    = "CONTENT_TYPE",
	["Content-Length"]  = "CONTENT_LENGTH",
	-- connection and routing
	["Host"]            = "HTTP_HOST",
	["Connection"]      = "HTTP_CONNECTION",
	["Referer"]         = "HTTP_REFERER",
	["User-Agent"]      = "HTTP_USER_AGENT",
	["Cookie"]          = "HTTP_COOKIE",
	-- content negotiation
	["Accept"]          = "HTTP_ACCEPT",
	["Accept-Charset"]  = "HTTP_ACCEPT_CHARSET",
	["Accept-Encoding"] = "HTTP_ACCEPT_ENCODING",
	["Accept-Language"] = "HTTP_ACCEPT_LANGUAGE",
}
306
-- Reads and parses an HTTP request head from a line source.
-- source: function returning one line per call, or nil, err on failure.
-- Leading empty lines are skipped. Returns a request table
--   { env = { REQUEST_METHOD, REQUEST_URI, SERVER_PROTOCOL, SCRIPT_NAME,
--             QUERY_STRING, plus the variables listed in hdr2env },
--     headers = { [name] = value, ... } }
-- or nil plus an error (the source's error, "invalid magic" or
-- "invalid header line").
-- Header names are case-insensitive (RFC 7230), so every header is stored
-- under the name as sent AND, if different, under its canonical form
-- (e.g. "transfer-encoding" -> "Transfer-Encoding"). Lookups such as
-- headers["Transfer-Encoding"] or the hdr2env mapping then work for any
-- client capitalisation.
function Server.parse_headers(self, source)
	local env = {}
	local req = {env = env, headers = {}}
	local line, err

	local function upper(s) return s:upper() end

	repeat	-- Ignore empty lines
		line, err = source()
		if not line then
			return nil, err
		end
	until #line > 0

	env.REQUEST_METHOD, env.REQUEST_URI, env.SERVER_PROTOCOL =
		line:match("^([A-Z]+) ([^ ]+) (HTTP/1%.[01])$")

	if not env.REQUEST_METHOD then
		return nil, "invalid magic"
	end

	local key, canon, envkey, val
	repeat
		line, err = source()
		if not line then
			return nil, err
		elseif #line > 0 then
			key, val = line:match("^([%w-]+)%s?:%s?(.*)")
			if key then
				-- Capitalise the first letter of every dash-separated word
				canon = ("-" .. key:lower()):gsub("%-%l", upper):sub(2)
				req.headers[key] = val
				if canon ~= key then
					req.headers[canon] = val
				end
				envkey = hdr2env[key] or hdr2env[canon]
				if envkey then
					env[envkey] = val
				end
			else
				return nil, "invalid header line"
			end
		else
			break
		end
	until false

	env.SCRIPT_NAME, env.QUERY_STRING = env.REQUEST_URI:match("([^?]*)%??(.*)")
	return req
end
350
351
-- Serves HTTP requests on one accepted client connection.
-- client: connected nixio socket; env: per-connection table providing at
-- least host, port and config (it is also handed to handlers as
-- request.server).
-- Each iteration does the following:
--   * parses one request head;
--   * sets up the request body source (POST only);
--   * dispatches to the virtual host selected by the Host header, falling
--     back to the vhost registered as "";
--   * writes the response.
-- The connection stays open between requests unless:
--   * the client asks to close it;
--   * keep-alive is disabled via config.nokeepalive;
--   * the request was a POST;
--   * an HTTP/1.0 response has no known length;
--   * sending fails.
-- Protocol errors are answered through Server.error(), which closes the
-- client itself.
function Server.process(self, client, env)
	local sourcein  = function() end
	local sourcehdr = client:linesource()
	local sinkout
	local buffer

	local close = false
	local stat, code, msg, message, err

	client:setsockopt("socket", "rcvtimeo", 5)
	client:setsockopt("socket", "sndtimeo", 5)

	repeat
		-- parse headers
		message, err = self:parse_headers(sourcehdr)

		-- any other error
		if not message or err then
			if err == 11 then	-- EAGAIN (e.g. the receive timeout above expired)
				break
			else
				return self:error(client, 400, err)
			end
		end

		-- Prepare sources and sinks
		-- NOTE(review): sourcehdr(true) appears to hand out the bytes read
		-- past the request head; they feed the request body below and are
		-- pushed back with sourcehdr(buffer) at the end of the loop - confirm.
		buffer = sourcehdr(true)
		sinkout = client:sink()
		message.server = env

		if client:is_tls_socket() then
			message.env.HTTPS = "on"
		end

		-- Addresses
		message.env.REMOTE_ADDR = remapipv6(env.host)
		message.env.REMOTE_PORT = env.port

		local srvaddr, srvport = client:getsockname()
		message.env.SERVER_ADDR = remapipv6(srvaddr)
		message.env.SERVER_PORT = srvport

		-- keep-alive; the Connection token is compared case-insensitively
		local connection = message.env.HTTP_CONNECTION
		connection = connection and connection:lower()
		if message.env.SERVER_PROTOCOL == "HTTP/1.1" then
			close = (connection == "close")
		else
			close = not connection or connection == "close"
		end

		-- Keep-alive can be disabled globally with the nokeepalive option
		close = close or env.config.nokeepalive

		if message.env.REQUEST_METHOD == "GET"
		or message.env.REQUEST_METHOD == "HEAD" then
			-- Be happy

		elseif message.env.REQUEST_METHOD == "POST" then
			-- If we have a HTTP/1.1 client and an Expect: 100-continue header
			-- respond with HTTP 100 Continue message
			if message.env.SERVER_PROTOCOL == "HTTP/1.1"
			and message.headers.Expect == '100-continue' then
				client:writeall("HTTP/1.1 100 Continue\r\n\r\n")
			end

			if message.headers['Transfer-Encoding'] and
			 message.headers['Transfer-Encoding'] ~= "identity" then
				sourcein = chunksource(client, buffer)
				buffer = nil
			elseif message.env.CONTENT_LENGTH then
				-- Only a plain non-negative decimal is a valid length;
				-- anything else would break the arithmetic below.
				local digits = message.env.CONTENT_LENGTH:match("^%s*(%d+)%s*$")
				local len = digits and tonumber(digits)
				if not len then
					return self:error(client, 400, "Invalid Content-Length")
				end
				if #buffer >= len then
					sourcein = ltn12.source.string(buffer:sub(1, len))
					buffer = buffer:sub(len+1)
				else
					sourcein = ltn12.source.cat(
						ltn12.source.string(buffer),
						client:blocksource(nil, len - #buffer)
					)
				end
			else
				return self:error(client, 411, statusmsg[411])
			end

			-- The connection state after a request body is not checked,
			-- so never keep the connection alive after a POST.
			close = true
		else
			return self:error(client, 405, statusmsg[405])
		end


		local host = self.vhosts[message.env.HTTP_HOST] or self.vhosts[""]
		if not host then
			return self:error(client, 404, "No virtual host found")
		end

		local code, headers, sourceout = host:process(message, sourcein)
		headers = headers or {}

		-- Post process response
		if sourceout then
			if util.instanceof(sourceout, IOResource) then
				if not headers["Content-Length"] then
					headers["Content-Length"] = sourceout.len
				end
			end
			if not headers["Content-Length"] then
				if message.env.SERVER_PROTOCOL == "HTTP/1.1" then
					headers["Transfer-Encoding"] = "chunked"
					sinkout = chunksink(client)
				else
					-- HTTP/1.0 without a length: the body ends with the connection
					close = true
				end
			end
		elseif message.env.REQUEST_METHOD ~= "HEAD" then
			headers["Content-Length"] = 0
		end

		if close then
			headers["Connection"] = "close"
		elseif message.env.SERVER_PROTOCOL == "HTTP/1.0" then
			headers["Connection"] = "Keep-Alive"
		end

		headers["Date"] = date.to_http(os.time())
		-- Codes without an entry in statusmsg get an empty reason phrase
		-- instead of failing on a nil concatenation.
		local header = {
			message.env.SERVER_PROTOCOL .. " " .. tostring(code) .. " "
				.. (statusmsg[code] or ""),
			"Server: LuCId-HTTPd/" .. VERSION
		}

		-- Table values produce one header line per element
		for k, v in pairs(headers) do
			if type(v) == "table" then
				for _, h in ipairs(v) do
					header[#header+1] = k .. ": " .. h
				end
			else
				header[#header+1] = k .. ": " .. v
			end
		end

		-- Two empty entries terminate the head with "\r\n\r\n"
		header[#header+1] = ""
		header[#header+1] = ""

		-- Output
		stat, code, msg = client:writeall(table.concat(header, "\r\n"))

		if sourceout and stat then
			if util.instanceof(sourceout, IOResource) then
				stat, code, msg = sourceout.fd:copyz(client, sourceout.len)
			else
				stat, msg = ltn12.pump.all(sourceout, sinkout)
			end
		end


		-- Write errors
		if not stat then
			if msg then
				nixio.syslog("err", "Error sending data to " .. env.host ..
					": " .. msg .. "\n")
			end
			break
		end

		-- Push unconsumed input back into the line source
		if buffer then
			sourcehdr(buffer)
		end
	until close

	client:shutdown()
	client:close()
end