kmodloader: use avl tree and calloc_a
[project/ubox.git] / kmodloader.c
1 /*
2  * Copyright (C) 2013 Felix Fietkau <nbd@openwrt.org>
3  * Copyright (C) 2013 John Crispin <blogic@openwrt.org>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU Lesser General Public License version 2.1
7  * as published by the Free Software Foundation
8  *
9  * This program is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU General Public License for more details.
13  */
14
15 #define _GNU_SOURCE
16 #include <sys/syscall.h>
17 #include <sys/mman.h>
18 #include <sys/utsname.h>
19
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <sys/syscall.h>
23 #include <sys/types.h>
24 #include <values.h>
25 #include <errno.h>
26 #include <stdio.h>
27 #include <string.h>
28 #include <sys/stat.h>
29 #include <fcntl.h>
30 #include <syslog.h>
31 #include <glob.h>
32 #include <elf.h>
33
34 #include <libubox/avl.h>
35 #include <libubox/avl-cmp.h>
36 #include <libubox/utils.h>
37
38 #define DEF_MOD_PATH "/lib/modules/%s/"
39
40 enum {
41         SCANNED,
42         PROBE,
43         LOADED,
44         FAILED,
45 };
46
47 struct module {
48         struct avl_node avl;
49
50         char *name;
51         char *depends;
52
53         int size;
54         int usage;
55         int state;
56 };
57
58 static struct avl_tree modules;
59
60 static struct module *find_module(const char *name)
61 {
62         struct module *m;
63         return avl_find_element(&modules, name, m, avl);
64 }
65
66 static void free_modules(void)
67 {
68         struct module *m, *tmp;
69
70         avl_remove_all_elements(&modules, m, avl, tmp)
71                 free(m);
72 }
73
74 static char* get_module_path(char *name)
75 {
76         static char path[256];
77         struct utsname ver;
78         struct stat s;
79         char *t;
80
81         if (!stat(name, &s))
82                 return name;
83
84         uname(&ver);
85         snprintf(path, 256, DEF_MOD_PATH "%s.ko", ver.release, name);
86
87         if (!stat(path, &s))
88                 return path;
89
90         t = name;
91         while (t && *t) {
92                 if (*t == '_')
93                         *t = '-';
94                 t++;
95         }
96
97         snprintf(path, 256, DEF_MOD_PATH "%s.ko", ver.release, name);
98
99         if (!stat(path, &s))
100                 return path;
101
102         return NULL;
103 }
104
105 static char* get_module_name(char *path)
106 {
107         static char name[32];
108         char *t;
109
110         strncpy(name, basename(path), sizeof(name));
111
112         t = strstr(name, ".ko");
113         if (t)
114                 *t = '\0';
115         t = name;
116         while (t && *t) {
117                 if (*t == '-')
118                         *t = '_';
119                 t++;
120         }
121
122         return name;
123 }
124
125 #if __WORDSIZE == 64
126 static int elf_find_section(char *map, const char *section, unsigned int *offset, unsigned int *size)
127 {
128         const char *secnames;
129         Elf64_Ehdr *e;
130         Elf64_Shdr *sh;
131         int i;
132
133         e = (Elf64_Ehdr *) map;
134         sh = (Elf64_Shdr *) (map + e->e_shoff);
135
136         secnames = map + sh[e->e_shstrndx].sh_offset;
137         for (i = 0; i < e->e_shnum; i++) {
138                 if (!strcmp(section, secnames + sh[i].sh_name)) {
139                         *size = sh[i].sh_size;
140                         *offset = sh[i].sh_offset;
141                         return 0;
142                 }
143         }
144
145         return -1;
146 }
147 #else
148 static int elf_find_section(char *map, const char *section, unsigned int *offset, unsigned int *size)
149 {
150         const char *secnames;
151         Elf32_Ehdr *e;
152         Elf32_Shdr *sh;
153         int i;
154
155         e = (Elf32_Ehdr *) map;
156         sh = (Elf32_Shdr *) (map + e->e_shoff);
157
158         secnames = map + sh[e->e_shstrndx].sh_offset;
159         for (i = 0; i < e->e_shnum; i++) {
160                 if (!strcmp(section, secnames + sh[i].sh_name)) {
161                         *size = sh[i].sh_size;
162                         *offset = sh[i].sh_offset;
163                         return 0;
164                 }
165         }
166
167         return -1;
168 }
169 #endif
170
171 static struct module *
172 alloc_module(const char *name, const char *depends, int size)
173 {
174         struct module *m;
175         char *_name, *_dep;
176
177         m = calloc_a(sizeof(*m),
178                 &_name, strlen(name) + 1,
179                 &_dep, depends ? strlen(depends) + 1 : 0);
180         if (!m)
181                 return NULL;
182
183         m->avl.key = m->name = strcpy(_name, name);
184         if (depends)
185                 m->depends = strcpy(_dep, depends);
186
187         m->size = size;
188         avl_insert(&modules, &m->avl);
189         return m;
190 }
191
192 static int scan_loaded_modules(void)
193 {
194         FILE *fp = fopen("/proc/modules", "r");
195         char buf[256];
196
197         if (!fp) {
198                 fprintf(stderr, "failed to open /proc/modules\n");
199                 return -1;
200         }
201
202         while (fgets(buf, sizeof(buf), fp)) {
203                 struct module m;
204                 struct module *n;
205
206                 m.name = strtok(buf, " ");
207                 m.size = atoi(strtok(NULL, " "));
208                 m.usage = atoi(strtok(NULL, " "));
209                 m.depends = strtok(NULL, " ");
210
211                 if (!m.name || !m.depends)
212                         continue;
213
214                 n = alloc_module(m.name, m.depends, m.size);
215                 n->usage = m.usage;
216                 n->state = LOADED;
217         }
218
219         return 0;
220 }
221
222 static struct module* get_module_info(const char *module, const char *name)
223 {
224         int fd = open(module, O_RDONLY);
225         unsigned int offset, size;
226         char *map, *strings, *dep = NULL;
227         struct module *m;
228         struct stat s;
229
230         if (!fd) {
231                 fprintf(stderr, "failed to open %s\n", module);
232                 return NULL;
233         }
234
235         if (fstat(fd, &s) == -1) {
236                 fprintf(stderr, "failed to stat %s\n", module);
237                 return NULL;
238         }
239
240         map = mmap(NULL, s.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
241         if (map == MAP_FAILED) {
242                 fprintf(stderr, "failed to mmap %s\n", module);
243                 return NULL;
244         }
245
246         if (elf_find_section(map, ".modinfo", &offset, &size)) {
247                 fprintf(stderr, "failed to load the .modinfo section from %s\n", module);
248                 return NULL;
249         }
250
251         strings = map + offset;
252         while (strings && (strings < map + offset + size)) {
253                 char *sep;
254                 int len;
255
256                 while (!strings[0])
257                         strings++;
258                 sep = strstr(strings, "=");
259                 if (!sep)
260                         break;
261                 len = sep - strings;
262                 sep++;
263                 if (!strncmp(strings, "depends=", len + 1))
264                         dep = sep;
265                 strings = &sep[strlen(sep)];
266         }
267
268         m = alloc_module(name, dep, s.st_size);
269         if (!m)
270                 return NULL;
271
272         m->state = SCANNED;
273
274         return m;
275 }
276
277 static int scan_module_folder(char *dir)
278 {
279         int gl_flags = GLOB_NOESCAPE | GLOB_MARK;
280         int j;
281         glob_t gl;
282
283         if (glob(dir, gl_flags, NULL, &gl) < 0)
284                 return -1;
285
286         for (j = 0; j < gl.gl_pathc; j++) {
287                 char *name = get_module_name(gl.gl_pathv[j]);
288                 struct module *m;
289
290                 if (!name)
291                         continue;
292
293                 m = find_module(name);
294                 if (!m)
295                         get_module_info(gl.gl_pathv[j], name);
296         }
297
298         globfree(&gl);
299
300         return 0;
301 }
302
303 static int print_modinfo(char *module)
304 {
305         int fd = open(module, O_RDONLY);
306         unsigned int offset, size;
307         struct stat s;
308         char *map, *strings;
309
310         if (!fd) {
311                 fprintf(stderr, "failed to open %s\n", module);
312                 return -1;
313         }
314
315         if (fstat(fd, &s) == -1) {
316                 fprintf(stderr, "failed to stat %s\n", module);
317                 return -1;
318         }
319
320         map = mmap(NULL, s.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
321         if (map == MAP_FAILED) {
322                 fprintf(stderr, "failed to mmap %s\n", module);
323                 return -1;
324         }
325
326         if (elf_find_section(map, ".modinfo", &offset, &size)) {
327                 fprintf(stderr, "failed to load the .modinfo section from %s\n", module);
328                 return -1;
329         }
330
331         strings = map + offset;
332         printf("module:\t\t%s\n", module);
333         while (strings && (strings < map + offset + size)) {
334                 char *dup = NULL;
335                 char *sep;
336
337                 while (!strings[0])
338                         strings++;
339                 sep = strstr(strings, "=");
340                 if (!sep)
341                         break;
342                 dup = strndup(strings, sep - strings);
343                 sep++;
344                 if (strncmp(strings, "parm", 4)) {
345                         if (strlen(dup) < 7)
346                                 printf("%s:\t\t%s\n",  dup, sep);
347                         else
348                                 printf("%s:\t%s\n",  dup, sep);
349                 }
350                 strings = &sep[strlen(sep)];
351                 if (dup)
352                         free(dup);
353         }
354
355         return 0;
356 }
357
358 static int insert_module(char *path, const char *options)
359 {
360         void *data = 0;
361         struct stat s;
362         int fd, ret = -1;
363
364         if (stat(path, &s)) {
365                 fprintf(stderr, "missing module %s\n", path);
366                 return ret;
367         }
368
369         fd = open(path, O_RDONLY);
370         if (!fd) {
371                 fprintf(stderr, "cannot open %s\n", path);
372                 return ret;
373         }
374
375         data = malloc(s.st_size);
376         if (read(fd, data, s.st_size) == s.st_size) {
377                 ret = syscall(__NR_init_module, data, s.st_size, options);
378                 if (ret)
379                         fprintf(stderr, "failed to insert %s\n", path);
380         } else {
381                 fprintf(stderr, "failed to read full module %s\n", path);
382         }
383
384         close(fd);
385         free(data);
386
387         return ret;
388 }
389
390 static int deps_available(struct module *m)
391 {
392         char *deps = m->depends;
393         char *comma;
394
395         if (!strcmp(deps, "-"))
396                 return 0;
397         while (*deps && (NULL != ((comma = strstr(deps, ","))))) {
398                 *comma = '\0';
399
400                 m = find_module(deps);
401
402                 if (!m || (m->state != LOADED))
403                         return -1;
404
405                 deps = ++comma;
406         }
407
408         return 0;
409 }
410
411 static int load_depmod(void)
412 {
413         int loaded, todo;
414         struct module *m;
415
416         do {
417                 loaded = 0;
418                 todo = 0;
419                 avl_for_each_element(&modules, m, avl) {
420                         if ((m->state == PROBE) && (!deps_available(m))) {
421                                 if (!insert_module(get_module_path(m->name), "")) {
422                                         m->state = LOADED;
423                                         loaded++;
424                                         continue;
425                                 }
426                                 m->state = FAILED;
427                         } else if (m->state == PROBE) {
428                                 todo++;
429                         }
430                 }
431 //              printf("loaded %d modules this pass\n", loaded);
432         } while (loaded);
433
434 //      printf("missing todos %d\n", todo);
435
436         return -todo;
437 }
438
439 static int print_insmod_usage(void)
440 {
441         fprintf(stderr, "Usage:\n\tinsmod filename [args]\n");
442
443         return -1;
444 }
445
446 static int print_usage(char *arg)
447 {
448         fprintf(stderr, "Usage:\n\t%s module\n", arg);
449
450         return -1;
451 }
452
453 static int main_insmod(int argc, char **argv)
454 {
455         char options[256] = "";
456         char *name;
457         int i;
458
459         if (argc < 2)
460                 return print_insmod_usage();
461
462         name = get_module_name(argv[1]);
463         if (!name) {
464                 fprintf(stderr, "cannot find module - %s\n", argv[1]);
465                 return -1;
466         }
467
468         if (scan_loaded_modules())
469                 return -1;
470
471         if (find_module(name)) {
472                 fprintf(stderr, "module is already loaded - %s\n", name);
473                 return -1;
474
475         }
476
477         free_modules();
478
479         for (i = 2; i < argc; i++)
480                 if (snprintf(options, sizeof(options), "%s %s", options, argv[i]) >= sizeof(options)) {
481                         fprintf(stderr, "argument line too long - %s\n", options);
482                         return -1;
483                 }
484
485         return insert_module(get_module_path(name), options);
486 }
487
488 static int main_rmmod(int argc, char **argv)
489 {
490         struct module *m;
491         char *name;
492         int ret;
493
494         if (argc != 2)
495                 return print_usage("rmmod");
496
497         if (scan_loaded_modules())
498                 return -1;
499
500         name = get_module_name(argv[1]);
501         m = find_module(name);
502         if (!m) {
503                 fprintf(stderr, "module is not loaded\n");
504                 return -1;
505         }
506         free_modules();
507
508         ret = syscall(__NR_delete_module, name, 0);
509
510         if (ret)
511                 fprintf(stderr, "unloading the module failed\n");
512
513         return ret;
514 }
515
516 static int main_lsmod(int argc, char **argv)
517 {
518         struct module *m;
519
520         if (scan_loaded_modules())
521                 return -1;
522
523         avl_for_each_element(&modules, m, avl)
524                 if (m->state == LOADED)
525                         printf("%-20s%8d%3d %s\n",
526                                 m->name, m->size, m->usage,
527                                 (*m->depends == '-') ? ("") : (m->depends));
528
529         free_modules();
530
531         return 0;
532 }
533
534 static int main_modinfo(int argc, char **argv)
535 {
536         char *module;
537
538         if (argc != 2)
539                 return print_usage("modinfo");
540
541         module = get_module_path(argv[1]);
542         if (!module) {
543                 fprintf(stderr, "cannot find module - %s\n", argv[1]);
544                 return -1;
545         }
546
547         print_modinfo(module);
548
549         return 0;
550 }
551
552 static int main_depmod(int argc, char **argv)
553 {
554         struct utsname ver;
555         struct module *m;
556         char path[128];
557         char *name;
558
559         if (argc != 2)
560                 return print_usage("depmod");
561
562         if (scan_loaded_modules())
563                 return -1;
564
565         uname(&ver);
566         snprintf(path, sizeof(path), DEF_MOD_PATH "*.ko", ver.release);
567
568         scan_module_folder(path);
569
570         name = get_module_name(argv[1]);
571         m = find_module(name);
572         if (m && m->state == LOADED) {
573                 fprintf(stderr, "%s is already loaded\n", name);
574                 return -1;
575         } else if (!m) {
576                 fprintf(stderr, "failed to find a module named %s\n", name);
577         } else {
578                 m->state = PROBE;
579                 load_depmod();
580         }
581
582         free_modules();
583
584         return 0;
585 }
586
587 static int main_loader(int argc, char **argv)
588 {
589         int gl_flags = GLOB_NOESCAPE | GLOB_MARK;
590         char *dir = "/etc/modules.d/*";
591         glob_t gl;
592         char *path;
593
594         if (argc > 1)
595                 dir = argv[1];
596
597         path = malloc(strlen(dir) + 2);
598         strcpy(path, dir);
599         strcat(path, "*");
600
601         scan_loaded_modules();
602
603         syslog(0, "kmodloader: loading kernel modules from %s\n", path);
604
605         if (glob(path, gl_flags, NULL, &gl) >= 0) {
606                 int j;
607
608                 for (j = 0; j < gl.gl_pathc; j++) {
609                         FILE *fp = fopen(gl.gl_pathv[j], "r");
610
611                         if (!fp) {
612                                 fprintf(stderr, "failed to open %s\n", gl.gl_pathv[j]);
613                         } else {
614                                 char mod[256];
615
616                                 while (fgets(mod, sizeof(mod), fp)) {
617                                         char *nl = strchr(mod, '\n');
618                                         struct module *m;
619                                         char *opts;
620
621                                         if (nl)
622                                                 *nl = '\0';
623
624                                         opts = strchr(mod, ' ');
625                                         if (opts)
626                                                 *opts++ = '\0';
627
628                                         m = find_module(get_module_name(mod));
629                                         if (m)
630                                                 continue;
631                                         insert_module(get_module_path(mod), (opts) ? (opts) : (""));
632                                 }
633                                 fclose(fp);
634                         }
635                 }
636         }
637
638         globfree(&gl);
639         free(path);
640
641         return 0;
642 }
643
644 int main(int argc, char **argv)
645 {
646         char *exec = basename(*argv);
647
648         avl_init(&modules, avl_strcmp, false, NULL);
649         if (!strcmp(exec, "insmod"))
650                 return main_insmod(argc, argv);
651
652         if (!strcmp(exec, "rmmod"))
653                 return main_rmmod(argc, argv);
654
655         if (!strcmp(exec, "lsmod"))
656                 return main_lsmod(argc, argv);
657
658         if (!strcmp(exec, "modinfo"))
659                 return main_modinfo(argc, argv);
660
661         if (!strcmp(exec, "depmod"))
662                 return main_depmod(argc, argv);
663
664         return main_loader(argc, argv);
665 }