properly deal with only v4 or only v6 start/stop/restart
authorJo-Philipp Wich <jow@openwrt.org>
Tue, 19 Feb 2013 00:22:52 +0000 (01:22 +0100)
committerJo-Philipp Wich <jow@openwrt.org>
Tue, 19 Feb 2013 15:38:14 +0000 (16:38 +0100)
defaults.c
forwards.c
ipsets.c
main.c
options.h
redirects.c
rules.c
utils.c
utils.h
zones.c

index 1a25eaf..556d3b9 100644 (file)
@@ -142,6 +142,8 @@ fw3_load_defaults(struct fw3_state *state, struct uci_package *p)
        defs->tcp_window_scaling   = true;
        defs->custom_chains        = true;
 
+       defs->has_flag = (1 << FW3_DEFAULT_IPV4_LOADED);
+
        uci_foreach_element(&p->sections, e)
        {
                s = uci_to_section(e);
@@ -162,11 +164,14 @@ fw3_load_defaults(struct fw3_state *state, struct uci_package *p)
                check_policy(e, &defs->policy_output, "output");
                check_policy(e, &defs->policy_forward, "forward");
 
+               if (!defs->disable_ipv6)
+                       setbit(defs->has_flag, FW3_DEFAULT_IPV6_LOADED);
+
                if (defs->custom_chains)
-                       defs->has_flag |= (1 << FW3_DEFAULT_CUSTOM_CHAINS);
+                       setbit(defs->has_flag, FW3_DEFAULT_CUSTOM_CHAINS);
 
                if (defs->syn_flood)
-                       defs->has_flag |= (1 << FW3_DEFAULT_SYN_FLOOD);
+                       setbit(defs->has_flag, FW3_DEFAULT_SYN_FLOOD);
        }
 }
 
index d3750b5..8f899bd 100644 (file)
@@ -76,7 +76,7 @@ fw3_load_forwards(struct fw3_state *state, struct uci_package *p)
 
                if (forward->_dest)
                {
-                       forward->_dest->has_dest_target |= (1 << FW3_TARGET_ACCEPT);
+                       setbit(forward->_dest->has_dest_target, FW3_TARGET_ACCEPT);
 
                        if (forward->_src &&
                            (forward->_src->conntrack || forward->_dest->conntrack))
index 215be73..f4253c7 100644 (file)
--- a/ipsets.c
+++ b/ipsets.c
@@ -355,17 +355,20 @@ fw3_destroy_ipsets(struct list_head *statefile)
 {
        struct fw3_statefile_entry *e;
 
-       info("Destroying ipsets ...");
-
-       list_for_each_entry(e, statefile, list)
+       if (statefile)
        {
-               if (e->type != FW3_TYPE_IPSET)
-                       continue;
+               info("Destroying ipsets ...");
 
-               info(" * %s", e->name);
+               list_for_each_entry(e, statefile, list)
+               {
+                       if (e->type != FW3_TYPE_IPSET)
+                               continue;
 
-               fw3_pr("flush %s\n", e->name);
-               fw3_pr("destroy %s\n", e->name);
+                       info(" * %s", e->name);
+
+                       fw3_pr("flush %s\n", e->name);
+                       fw3_pr("destroy %s\n", e->name);
+               }
        }
 }
 
diff --git a/main.c b/main.c
index fdb901a..a58f582 100644 (file)
--- a/main.c
+++ b/main.c
 
 
 static bool print_rules = false;
-static bool skip_family[FW3_FAMILY_V6 + 1] = { false };
+static enum fw3_family use_family = FW3_FAMILY_ANY;
+
+static const char *families[] = {
+       "(bug)",
+       "IPv4",
+       "IPv6",
+};
+
+static const char *tables[] = {
+       "filter",
+       "nat",
+       "mangle",
+       "raw",
+};
 
 
 static struct fw3_state *
@@ -69,12 +82,6 @@ build_state(void)
        fw3_load_redirects(state, p);
        fw3_load_forwards(state, p);
 
-       if (state->defaults.disable_ipv6 && !skip_family[FW3_FAMILY_V6])
-       {
-               warn("IPv6 rules globally disabled in configuration");
-               skip_family[FW3_FAMILY_V6] = true;
-       }
-
        return state;
 }
 
@@ -124,27 +131,76 @@ restore_pipe(enum fw3_family family, bool silent)
        return true;
 }
 
+#define family_flag(f) \
+       (f == FW3_FAMILY_V4 ? FW3_DEFAULT_IPV4_LOADED : FW3_DEFAULT_IPV6_LOADED)
+
+static bool
+family_running(struct list_head *statefile, enum fw3_family family)
+{
+       struct fw3_statefile_entry *e;
+
+       if (statefile)
+       {
+               list_for_each_entry(e, statefile, list)
+               {
+                       if (e->type != FW3_TYPE_DEFAULTS)
+                               continue;
+
+                       return hasbit(e->flags[0], family_flag(family));
+               }
+       }
+
+       return false;
+}
+
+static bool
+family_used(enum fw3_family family)
+{
+       return (use_family == FW3_FAMILY_ANY) || (use_family == family);
+}
+
+static bool
+family_loaded(struct fw3_state *state, enum fw3_family family)
+{
+       return hasbit(state->defaults.has_flag, family_flag(family));
+}
+
+static void
+family_set(struct fw3_state *state, enum fw3_family family, bool set)
+{
+       if (set)
+               setbit(state->defaults.has_flag, family_flag(family));
+       else
+               delbit(state->defaults.has_flag, family_flag(family));
+}
+
 static int
-stop(struct fw3_state *state, bool complete, bool ipsets)
+stop(struct fw3_state *state, bool complete, bool restart)
 {
+       int rv = 1;
        enum fw3_family family;
        enum fw3_table table;
 
-       struct list_head *statefile = fw3_read_state();
+       struct list_head *statefile = fw3_read_statefile();
 
-       const char *tables[] = {
-               "filter",
-               "nat",
-               "mangle",
-               "raw",
-       };
+       if (!complete && !statefile)
+       {
+               if (!restart)
+                       warn("The firewall appears to be stopped. "
+                                "Use the 'flush' command to forcefully purge all rules.");
+
+               return rv;
+       }
 
        for (family = FW3_FAMILY_V4; family <= FW3_FAMILY_V6; family++)
        {
-               if (skip_family[family] || !restore_pipe(family, true))
+               if (!complete && !family_running(statefile, family))
+                       continue;
+
+               if (!family_used(family) || !restore_pipe(family, true))
                        continue;
 
-               info("Removing IPv%d rules ...", family == FW3_FAMILY_V4 ? 4 : 6);
+               info("Removing %s rules ...", families[family]);
 
                for (table = FW3_TABLE_FILTER; table <= FW3_TABLE_RAW; table++)
                {
@@ -175,33 +231,41 @@ stop(struct fw3_state *state, bool complete, bool ipsets)
                }
 
                fw3_command_close();
+
+               if (!restart)
+                       family_set(state, family, false);
+
+               rv = 0;
        }
 
-       if (ipsets && fw3_command_pipe(false, "ipset", "-exist", "-"))
+       if (!restart &&
+           !family_loaded(state, FW3_FAMILY_V4) &&
+           !family_loaded(state, FW3_FAMILY_V6) &&
+           fw3_command_pipe(false, "ipset", "-exist", "-"))
        {
                fw3_destroy_ipsets(statefile);
                fw3_command_close();
        }
 
-       fw3_free_state(statefile);
+       fw3_free_statefile(statefile);
 
-       return 0;
+       if (!rv)
+               fw3_write_statefile(state);
+
+       return rv;
 }
 
 static int
-start(struct fw3_state *state)
+start(struct fw3_state *state, bool restart)
 {
+       int rv = 1;
        enum fw3_family family;
        enum fw3_table table;
 
-       const char *tables[] = {
-               "filter",
-               "nat",
-               "mangle",
-               "raw",
-       };
+       struct list_head *statefile = fw3_read_statefile();
 
-       if (!print_rules && fw3_command_pipe(false, "ipset", "-exist", "-"))
+       if (!print_rules && !restart &&
+           fw3_command_pipe(false, "ipset", "-exist", "-"))
        {
                fw3_create_ipsets(state);
                fw3_command_close();
@@ -209,10 +273,22 @@ start(struct fw3_state *state)
 
        for (family = FW3_FAMILY_V4; family <= FW3_FAMILY_V6; family++)
        {
-               if (skip_family[family] || !restore_pipe(family, false))
+               if (!family_used(family))
+                       continue;
+
+               if (!family_loaded(state, family) || !restore_pipe(family, false))
+                       continue;
+
+               if (!restart && family_running(statefile, family))
+               {
+                       warn("The %s firewall appears to be started already. "
+                            "If it is indeed empty, remove the %s file and retry.",
+                            families[family], FW3_STATEFILE);
+
                        continue;
+               }
 
-               info("Constructing IPv%d rules ...", family == FW3_FAMILY_V4 ? 4 : 6);
+               info("Constructing %s rules ...", families[family]);
 
                for (table = FW3_TABLE_FILTER; table <= FW3_TABLE_RAW; table++)
                {
@@ -234,9 +310,17 @@ start(struct fw3_state *state)
                }
 
                fw3_command_close();
+               family_set(state, family, true);
+
+               rv = 0;
        }
 
-       return 0;
+       fw3_free_statefile(statefile);
+
+       if (!rv)
+               fw3_write_statefile(state);
+
+       return rv;
 }
 
 static int
@@ -296,19 +380,18 @@ int main(int argc, char **argv)
 {
        int ch, rv = 1;
        struct fw3_state *state = NULL;
+       struct fw3_defaults *defs = NULL;
 
        while ((ch = getopt(argc, argv, "46qh")) != -1)
        {
                switch (ch)
                {
                case '4':
-                       skip_family[FW3_FAMILY_V4] = false;
-                       skip_family[FW3_FAMILY_V6] = true;
+                       use_family = FW3_FAMILY_V4;
                        break;
 
                case '6':
-                       skip_family[FW3_FAMILY_V4] = true;
-                       skip_family[FW3_FAMILY_V6] = false;
+                       use_family = FW3_FAMILY_V6;
                        break;
 
                case 'q':
@@ -325,6 +408,7 @@ int main(int argc, char **argv)
                error("Failed to connect to ubus");
 
        state = build_state();
+       defs = &state->defaults;
 
        if (!fw3_lock())
                goto out;
@@ -335,6 +419,9 @@ int main(int argc, char **argv)
                goto out;
        }
 
+       if (use_family == FW3_FAMILY_V6 && defs->disable_ipv6)
+               warn("IPv6 rules globally disabled in configuration");
+
        if (!strcmp(argv[optind], "print"))
        {
                freopen("/dev/null", "w", stderr);
@@ -342,56 +429,24 @@ int main(int argc, char **argv)
                state->disable_ipsets = true;
                print_rules = true;
 
-               if (!skip_family[FW3_FAMILY_V4] && !skip_family[FW3_FAMILY_V6])
-                       skip_family[FW3_FAMILY_V6] = true;
-
-               rv = start(state);
+               rv = start(state, false);
        }
        else if (!strcmp(argv[optind], "start"))
        {
-               if (fw3_has_state())
-               {
-                       warn("The firewall appears to be started already. "
-                                "If it is indeed empty, remove the %s file and retry.",
-                                FW3_STATEFILE);
-
-                       goto out;
-               }
-
-               rv = start(state);
-               fw3_write_state(state);
+               rv = start(state, false);
        }
        else if (!strcmp(argv[optind], "stop"))
        {
-               if (!fw3_has_state())
-               {
-                       warn("The firewall appears to be stopped. "
-                                "Use the 'flush' command to forcefully purge all rules.");
-
-                       goto out;
-               }
-
-               rv = stop(state, false, true);
-
-               fw3_remove_state();
+               rv = stop(state, false, false);
        }
        else if (!strcmp(argv[optind], "flush"))
        {
-               rv = stop(state, true, true);
-
-               if (fw3_has_state())
-                       fw3_remove_state();
+               rv = stop(state, true, false);
        }
        else if (!strcmp(argv[optind], "restart"))
        {
-               if (fw3_has_state())
-               {
-                       stop(state, false, false);
-                       fw3_remove_state();
-               }
-
-               rv = start(state);
-               fw3_write_state(state);
+               rv = stop(state, false, true);
+               rv = start(state, !rv);
        }
        else if (!strcmp(argv[optind], "network") && (optind + 1) < argc)
        {
index ee2c008..e4a507f 100644 (file)
--- a/options.h
+++ b/options.h
@@ -77,6 +77,8 @@ enum fw3_default
        FW3_DEFAULT_SYN_FLOOD     = 2,
        FW3_DEFAULT_MTU_FIX       = 3,
        FW3_DEFAULT_DROP_INVALID  = 4,
+       FW3_DEFAULT_IPV4_LOADED   = 5,
+       FW3_DEFAULT_IPV6_LOADED   = 6,
 };
 
 enum fw3_limit_unit
index 1fc81f0..f8eaed3 100644 (file)
@@ -133,16 +133,16 @@ fw3_load_redirects(struct fw3_state *state, struct uci_package *p)
                                warn_elem(e, "has no source specified");
                        else
                        {
-                               redir->_src->has_dest_target |= (1 << redir->target);
+                               setbit(redir->_src->has_dest_target, redir->target);
                                redir->_src->conntrack = true;
                                valid = true;
                        }
 
                        if (redir->reflection && redir->_dest && redir->_src->masq)
                        {
-                               redir->_dest->has_dest_target |= (1 << FW3_TARGET_ACCEPT);
-                               redir->_dest->has_dest_target |= (1 << FW3_TARGET_DNAT);
-                               redir->_dest->has_dest_target |= (1 << FW3_TARGET_SNAT);
+                               setbit(redir->_dest->has_dest_target, FW3_TARGET_ACCEPT);
+                               setbit(redir->_dest->has_dest_target, FW3_TARGET_DNAT);
+                               setbit(redir->_dest->has_dest_target, FW3_TARGET_SNAT);
                        }
                }
                else
@@ -155,7 +155,7 @@ fw3_load_redirects(struct fw3_state *state, struct uci_package *p)
                                warn_elem(e, "has no src_dip option specified");
                        else
                        {
-                               redir->_dest->has_dest_target |= (1 << redir->target);
+                               setbit(redir->_dest->has_dest_target, redir->target);
                                redir->_dest->conntrack = true;
                                valid = true;
                        }
diff --git a/rules.c b/rules.c
index 12c04c9..d29bba2 100644 (file)
--- a/rules.c
+++ b/rules.c
@@ -142,7 +142,7 @@ fw3_load_rules(struct fw3_state *state, struct uci_package *p)
                }
 
                if (rule->_dest)
-                       rule->_dest->has_dest_target |= (1 << rule->target);
+                       setbit(rule->_dest->has_dest_target, rule->target);
 
                list_add_tail(&rule->list, &state->rules);
                continue;
diff --git a/utils.c b/utils.c
index 4691fe1..c3dbb3d 100644 (file)
--- a/utils.c
+++ b/utils.c
@@ -332,15 +332,8 @@ fw3_unlock(void)
 }
 
 
-bool
-fw3_has_state(void)
-{
-       struct stat s;
-       return !stat(FW3_STATEFILE, &s);
-}
-
 struct list_head *
-fw3_read_state(void)
+fw3_read_statefile(void)
 {
        FILE *sf;
 
@@ -351,6 +344,11 @@ fw3_read_state(void)
        struct list_head *state;
        struct fw3_statefile_entry *entry;
 
+       sf = fopen(FW3_STATEFILE, "r");
+
+       if (!sf)
+               return NULL;
+
        state = malloc(sizeof(*state));
 
        if (!state)
@@ -358,16 +356,6 @@ fw3_read_state(void)
 
        INIT_LIST_HEAD(state);
 
-       sf = fopen(FW3_STATEFILE, "r");
-
-       if (!sf)
-       {
-               warn("Cannot open state %s: %s", FW3_STATEFILE, strerror(errno));
-               free(state);
-
-               return NULL;
-       }
-
        while (fgets(line, sizeof(line), sf))
        {
                entry = malloc(sizeof(*entry));
@@ -407,14 +395,7 @@ fw3_read_state(void)
 }
 
 void
-fw3_free_state(struct list_head *statefile)
-{
-       fw3_free_list(statefile);
-       free(statefile);
-}
-
-void
-fw3_write_state(void *state)
+fw3_write_statefile(void *state)
 {
        FILE *sf;
        struct fw3_state *s = state;
@@ -422,6 +403,17 @@ fw3_write_state(void *state)
        struct fw3_zone *z;
        struct fw3_ipset *i;
 
+       int mask = (1 << FW3_DEFAULT_IPV4_LOADED) | (1 << FW3_DEFAULT_IPV6_LOADED);
+
+       if (!(d->has_flag & mask))
+       {
+               if (unlink(FW3_STATEFILE))
+                       warn("Unable to remove state %s: %s",
+                            FW3_STATEFILE, strerror(errno));
+
+               return;
+       }
+
        sf = fopen(FW3_STATEFILE, "w");
 
        if (!sf)
@@ -450,8 +442,19 @@ fw3_write_state(void *state)
 }
 
 void
-fw3_remove_state(void)
+fw3_free_statefile(struct list_head *statefile)
 {
-       if (unlink(FW3_STATEFILE))
-               warn("Unable to remove state %s: %s", FW3_STATEFILE, strerror(errno));
+       struct fw3_statefile_entry *e, *tmp;
+
+       if (!statefile)
+               return;
+
+       list_for_each_entry_safe(e, tmp, statefile, list)
+       {
+               list_del(&e->list);
+               free(e->name);
+               free(e);
+       }
+
+       free(statefile);
 }
diff --git a/utils.h b/utils.h
index 2178b5a..590895a 100644 (file)
--- a/utils.h
+++ b/utils.h
@@ -40,6 +40,10 @@ void warn(const char *format, ...);
 void error(const char *format, ...);
 void info(const char *format, ...);
 
+#define setbit(field, flag) field |= (1 << (flag))
+#define delbit(field, flag) field &= ~(1 << (flag))
+#define hasbit(field, flag) (field & (1 << (flag)))
+
 #define fw3_foreach(p, h)                                                  \
        for (p = list_empty(h) ? NULL : list_first_entry(h, typeof(*p), list); \
          list_empty(h) ? (p == NULL) : (&p->list != (h));                  \
@@ -75,10 +79,6 @@ bool fw3_has_table(bool ipv6, const char *table);
 bool fw3_lock(void);
 void fw3_unlock(void);
 
-bool fw3_has_state(void);
-void fw3_write_state(void *state);
-void fw3_remove_state(void);
-
 
 enum fw3_statefile_type
 {
@@ -91,11 +91,12 @@ struct fw3_statefile_entry
 {
        struct list_head list;
        enum fw3_statefile_type type;
-       const char *name;
+       char *name;
        uint32_t flags[2];
 };
 
-struct list_head * fw3_read_state(void);
-void fw3_free_state(struct list_head *statefile);
+struct list_head * fw3_read_statefile(void);
+void fw3_write_statefile(void *state);
+void fw3_free_statefile(struct list_head *statefile);
 
 #endif
diff --git a/zones.c b/zones.c
index 3ab17bf..434d758 100644 (file)
--- a/zones.c
+++ b/zones.c
@@ -201,13 +201,13 @@ fw3_load_zones(struct fw3_state *state, struct uci_package *p)
 
                if (zone->masq)
                {
-                       zone->has_dest_target |= (1 << FW3_TARGET_SNAT);
+                       setbit(zone->has_dest_target, FW3_TARGET_SNAT);
                        zone->conntrack = true;
                }
 
-               zone->has_src_target  |= (1 << zone->policy_input);
-               zone->has_dest_target |= (1 << zone->policy_output);
-               zone->has_dest_target |= (1 << zone->policy_forward);
+               setbit(zone->has_src_target, zone->policy_input);
+               setbit(zone->has_dest_target, zone->policy_output);
+               setbit(zone->has_dest_target, zone->policy_forward);
 
                list_add_tail(&zone->list, &state->zones);
        }
@@ -224,7 +224,7 @@ print_zone_chain(enum fw3_table table, enum fw3_family family,
                return;
 
        if (!zone->conntrack && !disable_notrack)
-               zone->has_dest_target |= (1 << FW3_TARGET_NOTRACK);
+               setbit(zone->has_dest_target, FW3_TARGET_NOTRACK);
 
        s = print_chains(table, family, ":%s - [0:0]\n", zone->name,
                         zone->has_src_target, src_chains, ARRAY_SIZE(src_chains));