load running state after lock is acquired
[project/firewall3.git] / main.c
diff --git a/main.c b/main.c
index 17d71d1..241da62 100644 (file)
--- a/main.c
+++ b/main.c
@@ -45,12 +45,10 @@ build_state(bool runtime)
        struct uci_package *p = NULL;
        FILE *sf;
 
-       state = malloc(sizeof(*state));
-
+       state = calloc(1, sizeof(*state));
        if (!state)
                error("Out of memory");
 
-       memset(state, 0, sizeof(*state));
        state->uci = uci_alloc_context();
 
        if (!state->uci)
@@ -169,8 +167,6 @@ family_set(struct fw3_state *state, enum fw3_family family, bool set)
 static int
 stop(bool complete)
 {
-       FILE *ct;
-
        int rv = 1;
        enum fw3_family family;
        enum fw3_table table;
@@ -226,13 +222,8 @@ stop(bool complete)
        if (run_state)
                fw3_destroy_ipsets(run_state);
 
-       if (complete && (ct = fopen("/proc/net/nf_conntrack", "w")) != NULL)
-       {
-               info(" * Flushing conntrack table ...");
-
-               fwrite("f\n", 2, 1, ct);
-               fclose(ct);
-       }
+       if (complete)
+               fw3_flush_conntrack(NULL);
 
        if (!rv && run_state)
                fw3_write_statefile(run_state);
@@ -306,6 +297,7 @@ start(void)
 
        if (!rv)
        {
+               fw3_flush_conntrack(run_state);
                fw3_set_defaults(cfg_state);
 
                if (!print_family)
@@ -397,6 +389,8 @@ start:
 
        if (!rv)
        {
+               fw3_flush_conntrack(run_state);
+
                fw3_set_defaults(cfg_state);
                fw3_run_includes(cfg_state, true);
                fw3_hotplug_zones(cfg_state, true);
@@ -407,6 +401,35 @@ start:
 }
 
 static int
+gc(void)
+{
+       enum fw3_family family;
+       enum fw3_table table;
+       struct fw3_ipt_handle *handle;
+
+       for (family = FW3_FAMILY_V4; family <= FW3_FAMILY_V6; family++)
+       {
+               if (family == FW3_FAMILY_V6 && cfg_state->defaults.disable_ipv6)
+                       continue;
+
+               for (table = FW3_TABLE_FILTER; table <= FW3_TABLE_RAW; table++)
+               {
+                       if (!fw3_has_table(family == FW3_FAMILY_V6, fw3_flag_names[table]))
+                               continue;
+
+                       if (!(handle = fw3_ipt_open(family, table)))
+                               continue;
+
+                       fw3_ipt_gc(handle);
+                       fw3_ipt_commit(handle);
+                       fw3_ipt_close(handle);
+               }
+       }
+
+       return 0;
+}
+
+static int
 lookup_network(const char *net)
 {
        struct fw3_zone *z;
@@ -523,7 +546,6 @@ int main(int argc, char **argv)
        }
 
        build_state(false);
-       build_state(true);
        defs = &cfg_state->defaults;
 
        if (optind >= argc)
@@ -554,12 +576,18 @@ int main(int argc, char **argv)
                print_family = family;
                fw3_pr_debug = true;
 
-               rv = start();
+               if (fw3_lock())
+               {
+                       build_state(true);
+                       rv = start();
+                       fw3_unlock();
+               }
        }
        else if (!strcmp(argv[optind], "start"))
        {
                if (fw3_lock())
                {
+                       build_state(true);
                        rv = start();
                        fw3_unlock();
                }
@@ -568,6 +596,7 @@ int main(int argc, char **argv)
        {
                if (fw3_lock())
                {
+                       build_state(true);
                        rv = stop(false);
                        fw3_unlock();
                }
@@ -576,6 +605,7 @@ int main(int argc, char **argv)
        {
                if (fw3_lock())
                {
+                       build_state(true);
                        rv = stop(true);
                        fw3_unlock();
                }
@@ -584,6 +614,7 @@ int main(int argc, char **argv)
        {
                if (fw3_lock())
                {
+                       build_state(true);
                        stop(true);
                        rv = start();
                        fw3_unlock();
@@ -593,10 +624,19 @@ int main(int argc, char **argv)
        {
                if (fw3_lock())
                {
+                       build_state(true);
                        rv = reload();
                        fw3_unlock();
                }
        }
+       else if (!strcmp(argv[optind], "gc"))
+       {
+               if (fw3_lock())
+               {
+                       rv = gc();
+                       fw3_unlock();
+               }
+       }
        else if (!strcmp(argv[optind], "network") && (optind + 1) < argc)
        {
                rv = lookup_network(argv[optind + 1]);