kmodloader: fix elf header parsing on 64bit machines
[project/ubox.git] / kmodloader.c
1 /*
2  * Copyright (C) 2013 Felix Fietkau <nbd@openwrt.org>
3  * Copyright (C) 2013 John Crispin <blogic@openwrt.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU Lesser General Public License version 2.1
7  * as published by the Free Software Foundation
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  */
14
15 #define _GNU_SOURCE
16 #include <sys/syscall.h>
17 #include <sys/mman.h>
18 #include <sys/utsname.h>
19
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <sys/syscall.h>
23 #include <sys/types.h>
24 #include <values.h>
25 #include <errno.h>
26 #include <stdio.h>
27 #include <string.h>
28 #include <sys/stat.h>
29 #include <fcntl.h>
30 #include <syslog.h>
31 #include <libgen.h>
32 #include <glob.h>
33 #include <elf.h>
34
35 #include <libubox/avl.h>
36 #include <libubox/avl-cmp.h>
37 #include <libubox/utils.h>
38
39 #define DEF_MOD_PATH "/lib/modules/%s/"
40
41 #define LOG(fmt, ...) do { \
42         syslog(LOG_INFO, fmt, ## __VA_ARGS__); \
43         printf("kmod: "fmt, ## __VA_ARGS__); \
44         } while (0)
45
46
47 enum {
48         SCANNED,
49         PROBE,
50         LOADED,
51 };
52
53 struct module {
54         struct avl_node avl;
55
56         char *name;
57         char *depends;
58
59         int size;
60         int usage;
61         int state;
62         int error;
63 };
64
65 static struct avl_tree modules;
66 static char *prefix = "";
67
68 static struct module *find_module(const char *name)
69 {
70         struct module *m;
71         return avl_find_element(&modules, name, m, avl);
72 }
73
74 static void free_modules(void)
75 {
76         struct module *m, *tmp;
77
78         avl_remove_all_elements(&modules, m, avl, tmp)
79                 free(m);
80 }
81
82 static char* get_module_path(char *name)
83 {
84         static char path[256];
85         struct utsname ver;
86         struct stat s;
87
88         if (!stat(name, &s))
89                 return name;
90
91         uname(&ver);
92         snprintf(path, 256, "%s" DEF_MOD_PATH "%s.ko", prefix, ver.release, name);
93
94         if (!stat(path, &s))
95                 return path;
96
97         return NULL;
98 }
99
100 static char* get_module_name(char *path)
101 {
102         static char name[32];
103         char *t;
104
105         strncpy(name, basename(path), sizeof(name));
106
107         t = strstr(name, ".ko");
108         if (t)
109                 *t = '\0';
110
111         return name;
112 }
113
114 static int elf64_find_section(char *map, const char *section, unsigned int *offset, unsigned int *size)
115 {
116         const char *secnames;
117         Elf64_Ehdr *e;
118         Elf64_Shdr *sh;
119         int i;
120
121         e = (Elf64_Ehdr *) map;
122         sh = (Elf64_Shdr *) (map + e->e_shoff);
123
124         secnames = map + sh[e->e_shstrndx].sh_offset;
125         for (i = 0; i < e->e_shnum; i++) {
126                 if (!strcmp(section, secnames + sh[i].sh_name)) {
127                         *size = sh[i].sh_size;
128                         *offset = sh[i].sh_offset;
129                         return 0;
130                 }
131         }
132
133         return -1;
134 }
135
136 static int elf32_find_section(char *map, const char *section, unsigned int *offset, unsigned int *size)
137 {
138         const char *secnames;
139         Elf32_Ehdr *e;
140         Elf32_Shdr *sh;
141         int i;
142
143         e = (Elf32_Ehdr *) map;
144         sh = (Elf32_Shdr *) (map + e->e_shoff);
145
146         secnames = map + sh[e->e_shstrndx].sh_offset;
147         for (i = 0; i < e->e_shnum; i++) {
148                 if (!strcmp(section, secnames + sh[i].sh_name)) {
149                         *size = sh[i].sh_size;
150                         *offset = sh[i].sh_offset;
151                         return 0;
152                 }
153         }
154
155         return -1;
156 }
157
158 static int elf_find_section(char *map, const char *section, unsigned int *offset, unsigned int *size)
159 {
160         int clazz = map[EI_CLASS];
161
162         if (clazz == ELFCLASS32)
163                 return elf32_find_section(map, section, offset, size);
164         else if (clazz == ELFCLASS64)
165                 return elf64_find_section(map, section, offset, size);
166
167         LOG("unknown elf format %d\n", clazz);
168
169         return -1;
170 }
171
172 static struct module *
173 alloc_module(const char *name, const char *depends, int size)
174 {
175         struct module *m;
176         char *_name, *_dep;
177
178         m = calloc_a(sizeof(*m),
179                 &_name, strlen(name) + 1,
180                 &_dep, depends ? strlen(depends) + 2 : 0);
181         if (!m)
182                 return NULL;
183
184         m->avl.key = m->name = strcpy(_name, name);
185
186         if (depends) {
187                 m->depends = strcpy(_dep, depends);
188                 while (*_dep) {
189                         if (*_dep == ',')
190                                 *_dep = '\0';
191                         _dep++;
192                 }
193         }
194
195         m->size = size;
196         avl_insert(&modules, &m->avl);
197
198         return m;
199 }
200
201 static int scan_loaded_modules(void)
202 {
203         size_t buf_len = 0;
204         char *buf = NULL;
205         FILE *fp;
206
207         fp = fopen("/proc/modules", "r");
208         if (!fp) {
209                 LOG("failed to open /proc/modules\n");
210                 return -1;
211         }
212
213         while (getline(&buf, &buf_len, fp) > 0) {
214                 struct module m;
215                 struct module *n;
216
217                 m.name = strtok(buf, " ");
218                 m.size = atoi(strtok(NULL, " "));
219                 m.usage = atoi(strtok(NULL, " "));
220                 m.depends = strtok(NULL, " ");
221
222                 if (!m.name || !m.depends)
223                         continue;
224
225                 n = alloc_module(m.name, m.depends, m.size);
226                 n->usage = m.usage;
227                 n->state = LOADED;
228         }
229         free(buf);
230         fclose(fp);
231
232         return 0;
233 }
234
235 static struct module* get_module_info(const char *module, const char *name)
236 {
237         int fd = open(module, O_RDONLY);
238         unsigned int offset, size;
239         char *map, *strings, *dep = NULL;
240         struct module *m;
241         struct stat s;
242
243         if (!fd) {
244                 LOG("failed to open %s\n", module);
245                 return NULL;
246         }
247
248         if (fstat(fd, &s) == -1) {
249                 LOG("failed to stat %s\n", module);
250                 return NULL;
251         }
252
253         map = mmap(NULL, s.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
254         if (map == MAP_FAILED) {
255                 LOG("failed to mmap %s\n", module);
256                 return NULL;
257         }
258
259         if (elf_find_section(map, ".modinfo", &offset, &size)) {
260                 LOG("failed to load the .modinfo section from %s\n", module);
261                 return NULL;
262         }
263
264         strings = map + offset;
265         while (strings && (strings < map + offset + size)) {
266                 char *sep;
267                 int len;
268
269                 while (!strings[0])
270                         strings++;
271                 sep = strstr(strings, "=");
272                 if (!sep)
273                         break;
274                 len = sep - strings;
275                 sep++;
276                 if (!strncmp(strings, "depends=", len + 1))
277                         dep = sep;
278                 strings = &sep[strlen(sep)];
279         }
280
281         m = alloc_module(name, dep, s.st_size);
282         if (!m)
283                 return NULL;
284
285         m->state = SCANNED;
286
287         return m;
288 }
289
290 static int scan_module_folder(void)
291 {
292         int gl_flags = GLOB_NOESCAPE | GLOB_MARK;
293         struct utsname ver;
294         char *path;
295         glob_t gl;
296         int j;
297
298         uname(&ver);
299         path = alloca(sizeof(DEF_MOD_PATH "*.ko") + strlen(prefix) + strlen(ver.release) + 1);
300         sprintf(path, "%s" DEF_MOD_PATH "*.ko", prefix, ver.release);
301
302         if (glob(path, gl_flags, NULL, &gl) < 0)
303                 return -1;
304
305         for (j = 0; j < gl.gl_pathc; j++) {
306                 char *name = get_module_name(gl.gl_pathv[j]);
307                 struct module *m;
308
309                 if (!name)
310                         continue;
311
312                 m = find_module(name);
313                 if (!m)
314                         get_module_info(gl.gl_pathv[j], name);
315         }
316
317         globfree(&gl);
318
319         return 0;
320 }
321
322 static int print_modinfo(char *module)
323 {
324         int fd = open(module, O_RDONLY);
325         unsigned int offset, size;
326         struct stat s;
327         char *map, *strings;
328
329         if (!fd) {
330                 LOG("failed to open %s\n", module);
331                 return -1;
332         }
333
334         if (fstat(fd, &s) == -1) {
335                 LOG("failed to stat %s\n", module);
336                 return -1;
337         }
338
339         map = mmap(NULL, s.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
340         if (map == MAP_FAILED) {
341                 LOG("failed to mmap %s\n", module);
342                 return -1;
343         }
344
345         if (elf_find_section(map, ".modinfo", &offset, &size)) {
346                 LOG("failed to load the .modinfo section from %s\n", module);
347                 return -1;
348         }
349
350         strings = map + offset;
351         printf("module:\t\t%s\n", module);
352         while (strings && (strings < map + offset + size)) {
353                 char *dup = NULL;
354                 char *sep;
355
356                 while (!strings[0])
357                         strings++;
358                 sep = strstr(strings, "=");
359                 if (!sep)
360                         break;
361                 dup = strndup(strings, sep - strings);
362                 sep++;
363                 if (strncmp(strings, "parm", 4)) {
364                         if (strlen(dup) < 7)
365                                 printf("%s:\t\t%s\n",  dup, sep);
366                         else
367                                 printf("%s:\t%s\n",  dup, sep);
368                 }
369                 strings = &sep[strlen(sep)];
370                 if (dup)
371                         free(dup);
372         }
373
374         return 0;
375 }
376
377 static int deps_available(struct module *m, int verbose)
378 {
379         char *dep;
380         int err = 0;
381
382         if (!strcmp(m->depends, "-") || !strcmp(m->depends, ""))
383                 return 0;
384
385         dep = m->depends;
386
387         while (*dep) {
388                 m = find_module(dep);
389
390                 if (verbose && !m)
391                         LOG("missing dependency %s\n", dep);
392                 if (verbose && m && (m->state != LOADED))
393                         LOG("dependency not loaded %s\n", dep);
394                 if (!m || (m->state != LOADED))
395                         err++;
396                 dep += strlen(dep) + 1;
397         }
398
399         return err;
400 }
401
402 static int insert_module(char *path, const char *options)
403 {
404         void *data = 0;
405         struct stat s;
406         int fd, ret = -1;
407
408         if (stat(path, &s)) {
409                 LOG("missing module %s\n", path);
410                 return ret;
411         }
412
413         fd = open(path, O_RDONLY);
414         if (!fd) {
415                 LOG("cannot open %s\n", path);
416                 return ret;
417         }
418
419         data = malloc(s.st_size);
420         if (read(fd, data, s.st_size) == s.st_size)
421                 ret = syscall(__NR_init_module, data, s.st_size, options);
422         else
423                 LOG("failed to read full module %s\n", path);
424
425         close(fd);
426         free(data);
427
428         return ret;
429 }
430
431 static void load_moddeps(struct module *_m)
432 {
433         char *dep;
434         struct module *m;
435
436         if (!strcmp(_m->depends, "-") || !strcmp(_m->depends, ""))
437                 return;
438
439         dep = _m->depends;
440
441         while (*dep) {
442                 m = find_module(dep);
443
444                 if (!m)
445                         LOG("failed to find dependency %s\n", dep);
446                 if (m && (m->state != LOADED)) {
447                         m->state = PROBE;
448                         load_moddeps(m);
449                 }
450
451                 dep = dep + strlen(dep) + 1;
452         }
453 }
454
455 static int iterations = 0;
456 static int load_modprobe(void)
457 {
458         int loaded, todo;
459         struct module *m;
460
461         avl_for_each_element(&modules, m, avl)
462                 if (m->state == PROBE)
463                         load_moddeps(m);
464
465         do {
466                 loaded = 0;
467                 todo = 0;
468                 avl_for_each_element(&modules, m, avl) {
469                         if ((m->state == PROBE) && (!deps_available(m, 0))) {
470                                 if (!insert_module(get_module_path(m->name), "")) {
471                                         m->state = LOADED;
472                                         m->error = 0;
473                                         loaded++;
474                                         continue;
475                                 }
476                                 m->error = 1;
477                         }
478
479                         if ((m->state == PROBE) || m->error)
480                                 todo++;
481                 }
482                 iterations++;
483         } while (loaded);
484
485         return todo;
486 }
487
488 static int print_insmod_usage(void)
489 {
490         LOG("Usage:\n\tinsmod filename [args]\n");
491
492         return -1;
493 }
494
495 static int print_usage(char *arg)
496 {
497         LOG("Usage:\n\t%s module\n", arg);
498
499         return -1;
500 }
501
502 static int main_insmod(int argc, char **argv)
503 {
504         char *name, *cur, *options;
505         int i, ret, len;
506
507         if (argc < 2)
508                 return print_insmod_usage();
509
510         name = get_module_name(argv[1]);
511         if (!name) {
512                 LOG("cannot find module - %s\n", argv[1]);
513                 return -1;
514         }
515
516         if (scan_loaded_modules())
517                 return -1;
518
519         if (find_module(name)) {
520                 LOG("module is already loaded - %s\n", name);
521                 return -1;
522
523         }
524
525         free_modules();
526
527         for (len = 0, i = 2; i < argc; i++)
528                 len += strlen(argv[i]) + 1;
529
530         options = malloc(len);
531         options[0] = 0;
532         cur = options;
533         for (i = 2; i < argc; i++) {
534                 if (options[0]) {
535                         *cur = ' ';
536                         cur++;
537                 }
538                 cur += sprintf(cur, "%s", argv[i]);
539         }
540
541         if (!get_module_path(name)) {
542                 fprintf(stderr, "Failed to find %s. Maybe it is a built in module ?\n", name);
543                 return -1;
544         }
545
546         ret = insert_module(get_module_path(name), options);
547         free(options);
548
549         if (ret)
550                 LOG("failed to insert %s\n", get_module_path(name));
551
552         return ret;
553 }
554
555 static int main_rmmod(int argc, char **argv)
556 {
557         struct module *m;
558         char *name;
559         int ret;
560
561         if (argc != 2)
562                 return print_usage("rmmod");
563
564         if (scan_loaded_modules())
565                 return -1;
566
567         name = get_module_name(argv[1]);
568         m = find_module(name);
569         if (!m) {
570                 LOG("module is not loaded\n");
571                 return -1;
572         }
573         ret = syscall(__NR_delete_module, m->name, 0);
574
575         if (ret)
576                 LOG("unloading the module failed\n");
577
578         free_modules();
579
580         return ret;
581 }
582
583 static int main_lsmod(int argc, char **argv)
584 {
585         struct module *m;
586
587         if (scan_loaded_modules())
588                 return -1;
589
590         avl_for_each_element(&modules, m, avl)
591                 if (m->state == LOADED)
592                         printf("%-20s%8d%3d %s\n",
593                                 m->name, m->size, m->usage,
594                                 (*m->depends == '-') ? ("") : (m->depends));
595
596         free_modules();
597
598         return 0;
599 }
600
601 static int main_modinfo(int argc, char **argv)
602 {
603         struct module *m;
604         char *name;
605
606         if (argc != 2)
607                 return print_usage("modinfo");
608
609         if (scan_module_folder())
610                 return -1;
611
612         name = get_module_name(argv[1]);
613         m = find_module(name);
614         if (!m) {
615                 LOG("cannot find module - %s\n", argv[1]);
616                 return -1;
617         }
618
619         name = get_module_path(m->name);
620         if (!name) {
621                 LOG("cannot find path of module - %s\n", m->name);
622                 return -1;
623         }
624
625         print_modinfo(name);
626
627         return 0;
628 }
629
630 static int main_modprobe(int argc, char **argv)
631 {
632         struct module *m;
633         char *name;
634
635         if (argc != 2)
636                 return print_usage("modprobe");
637
638         if (scan_loaded_modules())
639                 return -1;
640
641         if (scan_module_folder())
642                 return -1;
643
644         name = get_module_name(argv[1]);
645         m = find_module(name);
646         if (m && m->state == LOADED) {
647                 LOG("%s is already loaded\n", name);
648                 return -1;
649         } else if (!m) {
650                 LOG("failed to find a module named %s\n", name);
651         } else {
652                 int fail;
653
654                 m->state = PROBE;
655
656                 fail = load_modprobe();
657
658                 if (fail) {
659                         LOG("%d module%s could not be probed\n",
660                                         fail, (fail == 1) ? ("") : ("s"));
661
662                         avl_for_each_element(&modules, m, avl)
663                                 if ((m->state == PROBE) || m->error)
664                                         LOG("- %s\n", m->name);
665                 }
666         }
667
668         free_modules();
669
670         return 0;
671 }
672
673 static int main_loader(int argc, char **argv)
674 {
675         int gl_flags = GLOB_NOESCAPE | GLOB_MARK;
676         char *dir = "/etc/modules.d/*";
677         struct module *m;
678         glob_t gl;
679         char *path;
680         int fail, j;
681
682         if (argc > 1)
683                 dir = argv[1];
684
685         if (argc > 2)
686                 prefix = argv[2];
687
688         path = malloc(strlen(dir) + 2);
689         strcpy(path, dir);
690         strcat(path, "*");
691
692         if (scan_loaded_modules())
693                 return -1;
694
695         if (scan_module_folder())
696                 return -1;
697
698         syslog(0, "kmodloader: loading kernel modules from %s\n", path);
699
700         if (glob(path, gl_flags, NULL, &gl) < 0)
701                 goto out;
702
703         for (j = 0; j < gl.gl_pathc; j++) {
704                 FILE *fp = fopen(gl.gl_pathv[j], "r");
705                 size_t mod_len = 0;
706                 char *mod = NULL;
707
708                 if (!fp) {
709                         LOG("failed to open %s\n", gl.gl_pathv[j]);
710                         continue;
711                 }
712
713                 while (getline(&mod, &mod_len, fp) > 0) {
714                         char *nl = strchr(mod, '\n');
715                         struct module *m;
716                         char *opts;
717
718                         if (nl)
719                                 *nl = '\0';
720
721                         opts = strchr(mod, ' ');
722                         if (opts)
723                                 *opts++ = '\0';
724
725                         m = find_module(get_module_name(mod));
726                         if (!m || (m->state == LOADED))
727                                 continue;
728
729                         m->state = PROBE;
730                         if (basename(gl.gl_pathv[j])[0] - '0' <= 9)
731                                 load_modprobe();
732
733                 }
734                 free(mod);
735                 fclose(fp);
736         }
737
738         fail = load_modprobe();
739         LOG("ran %d iterations\n", iterations);
740
741         if (fail) {
742                 LOG("%d module%s could not be probed\n",
743                                 fail, (fail == 1) ? ("") : ("s"));
744
745                 avl_for_each_element(&modules, m, avl)
746                         if ((m->state == PROBE) || (m->error))
747                                 LOG("- %s - %d\n", m->name, deps_available(m, 1));
748         }
749
750 out:
751         globfree(&gl);
752         free(path);
753
754         return 0;
755 }
756
757 static int avl_modcmp(const void *k1, const void *k2, void *ptr)
758 {
759         const char *s1 = k1;
760         const char *s2 = k2;
761
762         while (*s1 && ((*s1 == *s2) ||
763                        ((*s1 == '_') && (*s2 == '-')) ||
764                        ((*s1 == '-') && (*s2 == '_'))))
765         {
766                 s1++;
767                 s2++;
768         }
769
770         return *(const unsigned char *)s1 - *(const unsigned char *)s2;
771 }
772
773 int main(int argc, char **argv)
774 {
775         char *exec = basename(*argv);
776
777         avl_init(&modules, avl_modcmp, false, NULL);
778         if (!strcmp(exec, "insmod"))
779                 return main_insmod(argc, argv);
780
781         if (!strcmp(exec, "rmmod"))
782                 return main_rmmod(argc, argv);
783
784         if (!strcmp(exec, "lsmod"))
785                 return main_lsmod(argc, argv);
786
787         if (!strcmp(exec, "modinfo"))
788                 return main_modinfo(argc, argv);
789
790         if (!strcmp(exec, "modprobe"))
791                 return main_modprobe(argc, argv);
792
793         return main_loader(argc, argv);
794 }