ar71xx: check squashfs signature in TP-Link mtd parser
[openwrt.git] / target / linux / ar71xx / files / drivers / mtd / tplinkpart.c
1 /*
2  * Copyright (C) 2011 Gabor Juhos <juhosg@openwrt.org>
3  *
4  * This program is free software; you can redistribute it and/or modify it
5  * under the terms of the GNU General Public License version 2 as published
6  * by the Free Software Foundation.
7  *
8  */
9
10 #include <linux/kernel.h>
11 #include <linux/slab.h>
12 #include <linux/vmalloc.h>
13 #include <linux/magic.h>
14
15 #include <linux/mtd/mtd.h>
16 #include <linux/mtd/partitions.h>
17
18 #define TPLINK_NUM_PARTS        5
19 #define TPLINK_HEADER_V1        0x01000000
20 #define MD5SUM_LEN              16
21
22 #define TPLINK_ART_LEN          0x10000
23 #define TPLINK_KERNEL_OFFS      0x20000
24
25 struct tplink_fw_header {
26         uint32_t        version;        /* header version */
27         char            vendor_name[24];
28         char            fw_version[36];
29         uint32_t        hw_id;          /* hardware id */
30         uint32_t        hw_rev;         /* hardware revision */
31         uint32_t        unk1;
32         uint8_t         md5sum1[MD5SUM_LEN];
33         uint32_t        unk2;
34         uint8_t         md5sum2[MD5SUM_LEN];
35         uint32_t        unk3;
36         uint32_t        kernel_la;      /* kernel load address */
37         uint32_t        kernel_ep;      /* kernel entry point */
38         uint32_t        fw_length;      /* total length of the firmware */
39         uint32_t        kernel_ofs;     /* kernel data offset */
40         uint32_t        kernel_len;     /* kernel data length */
41         uint32_t        rootfs_ofs;     /* rootfs data offset */
42         uint32_t        rootfs_len;     /* rootfs data length */
43         uint32_t        boot_ofs;       /* bootloader data offset */
44         uint32_t        boot_len;       /* bootloader data length */
45         uint8_t         pad[360];
46 } __attribute__ ((packed));
47
48 static struct tplink_fw_header *
49 tplink_read_header(struct mtd_info *mtd, size_t offset)
50 {
51         struct tplink_fw_header *header;
52         size_t header_len;
53         size_t retlen;
54         int ret;
55         u32 t;
56
57         header = vmalloc(sizeof(*header));
58         if (!header)
59                 goto err;
60
61         header_len = sizeof(struct tplink_fw_header);
62         ret = mtd->read(mtd, offset, header_len, &retlen,
63                         (unsigned char *) header);
64         if (ret)
65                 goto err_free_header;
66
67         if (retlen != header_len)
68                 goto err_free_header;
69
70         /* sanity checks */
71         t = be32_to_cpu(header->version);
72         if (t != TPLINK_HEADER_V1)
73                 goto err_free_header;
74
75         t = be32_to_cpu(header->kernel_ofs);
76         if (t != header_len)
77                 goto err_free_header;
78
79         return header;
80
81 err_free_header:
82         vfree(header);
83 err:
84         return NULL;
85 }
86
87 static int tplink_check_squashfs_magic(struct mtd_info *mtd, size_t offset)
88 {
89         u32 magic;
90         size_t retlen;
91         int ret;
92
93         ret = mtd->read(mtd, offset, sizeof(magic), &retlen,
94                         (unsigned char *) &magic);
95         if (ret)
96                 return ret;
97
98         if (retlen != sizeof(magic))
99                 return -EIO;
100
101         if (le32_to_cpu(magic) != SQUASHFS_MAGIC)
102                 return -EINVAL;
103
104         return 0;
105 }
106
107 static int tplink_parse_partitions(struct mtd_info *master,
108                                    struct mtd_partition **pparts,
109                                    unsigned long origin)
110 {
111         struct mtd_partition *parts;
112         struct tplink_fw_header *header;
113         int nr_parts;
114         size_t offset;
115         size_t art_offset;
116         size_t rootfs_offset;
117         size_t squashfs_offset;
118         int ret;
119
120         nr_parts = TPLINK_NUM_PARTS;
121         parts = kzalloc(nr_parts * sizeof(struct mtd_partition), GFP_KERNEL);
122         if (!parts) {
123                 ret = -ENOMEM;
124                 goto err;
125         }
126
127         offset = TPLINK_KERNEL_OFFS;
128
129         header = tplink_read_header(master, offset);
130         if (!header) {
131                 pr_notice("%s: no TP-Link header found\n", master->name);
132                 ret = -ENODEV;
133                 goto err_free_parts;
134         }
135
136         squashfs_offset = offset + sizeof(struct tplink_fw_header) +
137                           be32_to_cpu(header->kernel_len);
138
139         ret = tplink_check_squashfs_magic(master, squashfs_offset);
140         if (ret == 0)
141                 rootfs_offset = squashfs_offset;
142         else
143                 rootfs_offset = offset + be32_to_cpu(header->rootfs_ofs);
144
145         art_offset = master->size - TPLINK_ART_LEN;
146
147         parts[0].name = "u-boot";
148         parts[0].offset = 0;
149         parts[0].size = offset;
150         parts[0].mask_flags = MTD_WRITEABLE;
151
152         parts[1].name = "kernel";
153         parts[1].offset = offset;
154         parts[1].size = rootfs_offset - offset;
155
156         parts[2].name = "rootfs";
157         parts[2].offset = rootfs_offset;
158         parts[2].size = art_offset - rootfs_offset;
159
160         parts[3].name = "art";
161         parts[3].offset = art_offset;
162         parts[3].size = TPLINK_ART_LEN;
163         parts[3].mask_flags = MTD_WRITEABLE;
164
165         parts[4].name = "firmware";
166         parts[4].offset = offset;
167         parts[4].size = art_offset - offset;
168
169         vfree(header);
170
171         *pparts = parts;
172         return nr_parts;
173
174 err_free_parts:
175         kfree(parts);
176 err:
177         *pparts = NULL;
178         return ret;
179 }
180
181 static struct mtd_part_parser tplink_parser = {
182         .owner          = THIS_MODULE,
183         .parse_fn       = tplink_parse_partitions,
184         .name           = "tp-link",
185 };
186
187 static int __init tplink_parser_init(void)
188 {
189         return register_mtd_parser(&tplink_parser);
190 }
191
192 module_init(tplink_parser_init);
193
194 MODULE_LICENSE("GPL v2");
195 MODULE_AUTHOR("Gabor Juhos <juhosg@openwrt.org>");