Speck 2f proximity detection demo

Speck 2f proximity detection demo

An example that uses a Speck2fDevKit or Speck2fModuleDevkit to detect how people pass in front of the camera.

This example is based on the following packages:

- samna             0.30.9
- sinabs            1.2.2
- sinabs_dynapcnn   1.0.10
- torch             1.12.1

Files needed (all of them should be placed in the same directory as the Python script):

Structure:

Speck 2f proximity detection
from multiprocessing import Process

import samna
import samnagui
from gen_config import gen_config


def open_speck2f():
    """Open the first unopened Speck2f board and return its device handle.

    Scans ``samna.device.get_unopened_devices()`` for entries whose
    ``device_type_name`` starts with ``"Speck2f"`` and opens the first match
    with ``samna.device.open_device``.

    Returns:
        Whatever ``samna.device.open_device`` returns for the selected board.

    Raises:
        AssertionError: If no matching unopened board is found.
    """
    speck2f_boards = [
        device
        for device in samna.device.get_unopened_devices()
        if device.device_type_name.startswith("Speck2f")
    ]
    if not speck2f_boards:
        # Raised explicitly instead of via `assert`: assert statements are
        # stripped under `python -O`, which would turn a missing board into an
        # unhelpful IndexError on the line below.
        raise AssertionError("There is no available speck2f board.")
    return samna.device.open_device(speck2f_boards[0])


def open_visualizer(window_width, window_height, receiver_endpoint):
    """Launch the samna visualizer in its own process and return that process.

    The visualizer is started through ``multiprocessing.Process`` rather than
    as a sub process; per the original author's note this isolation is
    required on macOS.

    Args:
        window_width: Forwarded to ``samnagui.run_visualizer`` as the width.
        window_height: Forwarded to ``samnagui.run_visualizer`` as the height.
        receiver_endpoint: Endpoint the visualizer receives events on.

    Returns:
        The already-started ``Process`` running the visualizer.
    """
    # run_visualizer takes the endpoint first, then the window dimensions.
    visualizer_args = (receiver_endpoint, window_width, window_height)
    visualizer = Process(target=samnagui.run_visualizer, args=visualizer_args)
    visualizer.start()
    return visualizer


def post_processing_filter(jit_config):
    """Build the JIT-compiled filter that turns spike counts into readout icons.

    The C++ source below is a ``str.format`` template: doubled braces are
    literal C++ braces, while ``{silent}``, ``{near1}``, ``{near2}``,
    ``{near3}``, ``{away2}``, ``{away3}``, ``{left}`` and ``{right}`` are
    replaced with the matching values from ``jit_config``.

    For every input batch the generated ``GestureReadoutState`` filter:

    1. skips empty batches and batches without any ``ui::SpikeCount`` event;
    2. takes the ``ui::SpikeCount`` with the highest count (the first one on
       ties) and maps it to its ``feature``, or to ``-1`` when the count does
       not exceed ``SILENT_THRESH`` (30);
    3. appends that feature to a two-slot buffer, but only while the buffer is
       not full and does not already contain the value;
    4. matches the buffer against the transition table to pick a new
       prediction, if any;
    5. forwards a vector holding one ``ui::Readout``: the latest prediction
       while it is at most 500 ms old, otherwise the ``silent`` value.

    Args:
        jit_config: Mapping providing a value for each placeholder listed
            above. Values are substituted verbatim into the C++ source.

    Returns:
        The result of ``samna.graph.JitFilter("GestureReadoutState", source)``.

    Raises:
        KeyError: If ``jit_config`` lacks one of the placeholders.
    """
    # NOTE(review): the transition table itself is reproduced unchanged. A few
    # things in it look unintended and should be confirmed on hardware:
    #   * `reset(...)` sets `feature_buffer_size` to 0, so the values passed to
    #     it are overwritten by the next append;
    #   * branches that never call `reset` (e.g. `is(-1, 0)`) and pairs that
    #     match no branch leave the buffer full, and a full buffer ignores
    #     further appends;
    #   * `is(0, -1)` and the `is_combine` branch pairing -1 with 3/4 are
    #     shadowed by earlier branches.
    jit_src = """class GestureReadoutState : public iris::FilterInterface<std::shared_ptr<const std::vector<ui::Event>>, std::shared_ptr<const std::vector<ui::Event>>> {{
public:
    void apply() override
    {{
        while (const auto maybeEventsPtr = this->receiveInput()) {{
            if ((**maybeEventsPtr).empty()) {{
                continue;
            }}

            auto result = std::make_shared<std::vector<ui::Event>>();
            std::optional<uint32_t> new_predication = std::nullopt;

            const ui::SpikeCount* max_spike = nullptr;
            for (const auto& event : **maybeEventsPtr) {{
                if (const auto spike = std::get_if<ui::SpikeCount>(&event)) {{
                    if (!max_spike || max_spike->count < spike->count) {{
                        max_spike = spike;
                    }}
                }}
            }}

            if (!max_spike) {{
                continue;
            }}

            auto feature = max_spike->count > SILENT_THRESH ? max_spike->feature : -1;

            checked_append(feature);

            if (full()) {{
                if (is(-1, 3)) {{
                    new_predication = {left};
                    reset(feature_buffer[1]);
                }}
                else if (is(-1, 4)) {{
                    new_predication = {right};
                    reset(feature_buffer[1]);
                }}
                else if (is_combine({{0, 3, 4}}, {{-1}})) {{
                    new_predication = {silent};
                    reset(-1);
                }}
                else if (is(3, 4) || is(4, 3)) {{
                    reset();
                }}
                else if (is_combine({{3, 4}}, {{2}})) {{
                    reset(2);
                }}
                else if (is_combine({{3, 4}}, {{std::nullopt}}) ||
                        is_combine({{std::nullopt}}, {{3, 4}})) {{
                    reset();
                }}
                else if (is_combine({{-1}}, {{3, 4}}) ||
                        is_combine({{3, 4}}, {{-1}})) {{
                    reset();
                }}
                else if (is(-1, 0)) {{
                    new_predication = {near3};
                }}
                else if (is(0, -1)) {{
                    new_predication = {near2};
                    reset(1);
                }}
                else if (is(1, 2)) {{
                    new_predication = {near1};
                    reset(2);
                }}
                else if (is(2, 1)) {{
                    new_predication = {away2};
                    reset(1);
                }}
                else if (is(1, 0)) {{
                    new_predication = {away3};
                    reset(0);
                }}
            }}
            else if (is(-1)) {{
                new_predication = {silent};
            }}

            // A fresh prediction is recorded before the staleness check. If it
            // were recorded only after that check passed, lastReadoutTime could
            // never leave its zero initial value and every readout would be silent.
            const auto now = std::chrono::steady_clock::now();
            if (new_predication) {{
                predication = new_predication;
                lastReadoutTime = now;
            }}

            if (now - std::chrono::milliseconds(500) > lastReadoutTime) {{
                result->emplace_back(ui::Readout{{{silent}}});
            }}
            else if (predication) {{
                result->emplace_back(ui::Readout{{*predication}});
            }}

            this->forwardResult(std::move(result));
        }}
    }}

private:
    static constexpr auto SILENT_THRESH = 30;
    std::array<int, 2> feature_buffer = {{}};
    size_t feature_buffer_size = 0;
    std::optional<uint32_t> predication = {silent};
    std::chrono::steady_clock::time_point lastReadoutTime = {{}};

    auto begin()
    {{
        return feature_buffer.begin();
    }}
    auto end()
    {{
        return feature_buffer.begin() + feature_buffer_size;
    }}

    void append(int val)
    {{
        feature_buffer[feature_buffer_size++] = val;
    }}

    bool full()
    {{
        return feature_buffer_size == feature_buffer.size();
    }}

    void checked_append(int val)
    {{
        if (!full() && std::find(begin(), end(), val) == end()) {{
            append(val);
        }}
    }}

    template<typename... T>
    void reset(T&&... values)
    {{
        feature_buffer_size = 0;
        feature_buffer = {{values...}};
    }}

    template<typename... T>
    bool is(T&&... values)
    {{
        auto v = std::initializer_list<std::optional<int>>{{values...}};
        return std::equal(begin(), end(), v.begin(), [](auto&& lhs, auto&& rhs) {{
            return lhs == rhs.value_or(lhs);
        }});
    }}

    bool is_combine(std::vector<std::optional<int>> values1, std::vector<std::optional<int>> values2)
    {{
        for (auto value1 : values1) {{
            for (auto value2 : values2) {{
                if (is(value1, value2)) {{
                    return true;
                }}
            }}
        }}
        return false;
    }}
}};""".format(**jit_config)
    return samna.graph.JitFilter("GestureReadoutState", jit_src)


def samna_initialization(devkit):
    """Start the visualizer and wire the devkit's event streams into it.

    Launches the GUI through ``open_visualizer``, builds an
    ``EventFilterGraph`` that feeds DVS events, the post-processed readout,
    power measurements and UI configuration events into one
    ``VizEventStreamer``, starts the graph and sends the plot layout.

    Args:
        devkit: Device handle providing ``get_model_source_node()`` and
            ``get_power_monitor()``.

    Returns:
        Tuple ``(graph, gui_process)``: the started filter graph and the
        process running the visualizer.

    Raises:
        Exception: If no visualizer connects to the streamer endpoint. The
            visualizer process is terminated before the exception is raised.
    """
    streamer_endpoint = "tcp://0.0.0.0:40000"

    gui_process = open_visualizer(0.75, 0.75, streamer_endpoint)

    # The position of a name in this list is the readout index the JIT filter
    # emits for it, and also selects the icon shown by the readout plot.
    image_names = [
        "silent",
        "near1",
        "near2",
        "near3",
        "away2",
        "away3",
        "left",
        "right",
    ]

    visualizer_config = samna.ui.VisualizerConfiguration(
        # add plots to gui
        plots=[
            # add dvs plot
            samna.ui.ActivityPlotConfiguration(128, 128, "DVS Layer", [0, 0, 0.5, 0.8]),
            # add image plot
            samna.ui.ReadoutPlotConfiguration(
                "Readout Layer",
                [f"./readout_icons/{name}.png" for name in image_names],
                [0.5, 0, 1, 0.8],
            ),
            # add power measurement plot
            samna.ui.PowerMeasurementPlotConfiguration(
                title="Power Consumption",
                channel_count=5,
                line_names=["io", "ram", "logic", "vddd", "vdda"],
                layout=[0, 0.8, 1, 1],
                show_x_span=10,
                label_interval=2,
                max_y_rate=1.5,
                show_point_circle=False,
                default_y_max=1,
                y_label_name="power (mW)",
            ),
        ]
    )

    graph = samna.graph.EventFilterGraph()

    # DVS events -> visualizer
    _, _, streamer = graph.sequential(
        [devkit.get_model_source_node(), "Speck2fDvsToVizConverter", "VizEventStreamer"]
    )

    # Spike counts -> JIT post-processing -> visualizer (same streamer).
    jit_config = {name: index for index, name in enumerate(image_names)}
    _, spike_collection_filter, spike_count_filter, _, _ = graph.sequential(
        [
            devkit.get_model_source_node(),
            "Speck2fSpikeCollectionNode",
            "Speck2fSpikeCountNode",
            post_processing_filter(jit_config),
            streamer,
        ]
    )

    # divide according to this time period in milliseconds
    spike_collection_filter.set_interval_milli_sec(100)
    spike_count_filter.set_feature_count(6)

    # Power measurements -> visualizer.
    power = devkit.get_power_monitor()
    # NOTE(review): 20 is presumably the sampling rate; confirm against samna.
    power.start_auto_power_measurement(20)
    graph.sequential([power.get_source_node(), "MeasurementToVizConverter", streamer])

    # Source used below to push the plot layout to the visualizer.
    config_source, _ = graph.sequential([samna.BasicSourceNode_ui_event(), streamer])

    streamer.set_streamer_endpoint(streamer_endpoint)
    if streamer.wait_for_receiver_count() == 0:
        # Nobody else holds a reference to the GUI process on this path, so
        # shut it down here instead of leaving an orphaned visualizer behind.
        gui_process.terminate()
        gui_process.join()
        raise Exception(f'connecting to visualizer on {streamer_endpoint} fails')

    graph.start()

    config_source.write([visualizer_config])

    return graph, gui_process


def run_demo_main():
    """Configure the Speck2f board, run the demo, and block until the GUI exits.

    Loads the configuration from ``gen_config``, enables monitoring on the DVS
    layer and on the last CNN layer, applies the configuration to the first
    available Speck2f board, enables the stop watch and the 10 Hz slow clock,
    then starts the streaming graph and waits for the visualizer process to
    finish. The graph is stopped on the way out, including when the wait is
    interrupted.
    """
    last_layer_id, config = gen_config()

    # enable the last layer monitor
    config.dvs_layer.monitor_enable = True
    config.cnn_layers[last_layer_id].monitor_enable = True

    # open device
    devkit = open_speck2f()

    # start running on hardware
    devkit.get_model().apply_configuration(config)

    # set timestamp
    stop_watch = devkit.get_stop_watch()
    stop_watch.set_enable_value(True)

    # set io of the devkit
    dk_io = devkit.get_io_module()

    dk_io.set_slow_clk_rate(10)  # Hz
    dk_io.set_slow_clk(True)

    graph, gui_process = samna_initialization(devkit)

    try:
        # Blocks until the visualizer window/process ends.
        gui_process.join()
    finally:
        # Stop the graph even if the wait is interrupted (e.g. Ctrl+C).
        graph.stop()


# Entry-point guard: run the demo only when this file is executed as a script,
# not when it is imported. This file starts a multiprocessing.Process (see
# open_visualizer), so the guard also keeps the demo from being launched again
# if a child process imports this module.
if __name__ == "__main__":
    run_demo_main()