Demo Application

This is a demo smart-door-lock application for the Speck2b device. It uses samna to take observations from the camera and to predict whether the door should be opened.

First, set the variables that are used elsewhere in the code. Note that the image paths must be absolute: on macOS, relative paths raise exceptions in the C++ backend.

# Frame size unpacked into the DVS activity plot configuration.
dvs_size = (128, 128)
# One label per readout class; the count drives the spike-count plot lines.
feature_names = ["Closed", "Open"]
feature_count = 2
# Spike collection window in milliseconds (passed to set_interval_milli_sec).
spike_collection_interval = 500
# Spike-count threshold used by the readout filter.
silent_threshold = 60
# Readout icons, resolved to absolute paths (relative paths break on macOS).
images = [
    os.path.abspath(os.path.join("./smartdoor/icons/", f"{icon}.png"))
    for icon in ("closed", "open", "silent")
]

The visualizer is used to show the DVS output, the spike counts and the predictions. Because the visualizer must run in a separate process and is blocking, the Visualizer class is used as an abstraction. It first starts the visualizer in a separate process, and then uses samna to create the graph nodes for the inter-process communication. Once the visualizer is running and the connection is established, the three plots are added.

class Visualizer:
    """Run the samna GUI visualizer in a separate process and configure its plots.

    On construction this starts ``samnagui.run_visualizer`` in a child
    process, connects to it through a ``VizEventStreamer`` on a TCP endpoint,
    and sends a ``VisualizerConfiguration`` with three plots: DVS activity,
    readout images and per-feature spike counts.
    """

    def __init__(
        self, show_dvs_size, show_feature_count, show_feature_names, show_readout_images
    ):
        """Start the visualizer process and send it the plot configuration.

        Args:
            show_dvs_size: Size pair unpacked into ``ActivityPlotConfiguration``.
            show_feature_count: Number of lines in the spike-count plot.
            show_feature_names: One line name per spike-count channel.
            show_readout_images: Image paths shown by the readout plot.

        Raises:
            ConnectionError: If no receiver connects to the streamer endpoint.
        """
        self.endpoint = None
        self.gui_process = None
        # Holds the configuration graph so it stays referenced (and is not
        # garbage-collected) after __open_visualizer returns.
        self._config_graph = None

        self.__open_visualizer(
            show_dvs_size, show_feature_count, show_feature_names, show_readout_images
        )

    def get_endpoint(self):
        """Return the TCP endpoint string shared by the streamer and the GUI.

        The first call asks the OS for a free port by binding a throwaway
        socket to port 0; the result is cached for later calls. The probe
        socket is closed before the port is used, so another process could in
        principle take the port in between.
        """
        if self.endpoint:
            return self.endpoint
        # Context manager guarantees the probe socket is closed even if bind fails.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.bind(("0.0.0.0", 0))
            port = probe.getsockname()[1]
        self.endpoint = f"tcp://0.0.0.0:{port}"
        return self.endpoint

    def join(self):
        """Block until the visualizer process exits."""
        self.gui_process.join()

    def __open_visualizer(
        self, show_dvs_size, show_feature_count, show_feature_names, show_readout_images
    ):
        """Launch the GUI process, connect a streamer to it and send the plot layout."""
        # Visualizer window size and layout
        window_width = 3 / 4
        window_height = 9 / 16
        dvs_layout = [0, 0, 0.5, 1]
        readout_layout = [0.5, 0, 1, 0.5]
        spike_count_layout = [0.5, 0.5, 1, 1]

        self.gui_process = Visualizer.start_visualizer_process(
            window_width, window_height, self.get_endpoint()
        )

        graph = samna.graph.EventFilterGraph()

        config_source, streamer = graph.sequential(
            [samna.BasicSourceNode_ui_event(), "VizEventStreamer"]
        )

        streamer.set_streamer_endpoint(self.get_endpoint())
        if streamer.wait_for_receiver_count() == 0:
            # Don't leave an orphaned GUI process behind when the connection fails.
            self.gui_process.terminate()
            self.gui_process.join()
            raise ConnectionError(
                f"connecting to visualizer on {self.get_endpoint()} fails"
            )

        graph.start()
        # Keep the graph referenced so it is not destroyed right after the
        # configuration event below has been queued.
        self._config_graph = graph

        visualizer_config = samna.ui.VisualizerConfiguration(
            plots=[
                samna.ui.ActivityPlotConfiguration(
                    *show_dvs_size, "DVS Layer", dvs_layout
                ),
                samna.ui.ReadoutPlotConfiguration(
                    "Readout Layer", show_readout_images, readout_layout
                ),
                samna.ui.SpikeCountPlotConfiguration(
                    title="Spike Count",
                    channel_count=show_feature_count,
                    line_names=show_feature_names,
                    layout=spike_count_layout,
                    show_x_span=25,
                    label_interval=2.5,
                    max_y_rate=1.2,
                    show_point_circle=True,
                    default_y_max=10,
                ),
            ]
        )

        config_source.write([visualizer_config])

    @staticmethod
    def start_visualizer_process(window_width, window_height, receiver_endpoint):
        """Start ``samnagui.run_visualizer`` in a new process and return the process."""
        gui_process = Process(
            target=samnagui.run_visualizer,
            args=(receiver_endpoint, window_width, window_height),
        )
        gui_process.start()

        return gui_process

The readout function needed here is not built into the samna library as a filter, so a JIT filter is defined instead. For each spike-collection interval, the filter predicts the class with the most spikes if that count exceeds the previously defined silent_threshold; otherwise it predicts the silent class.

def _readout_filter_source(num_classes, silent_class, threshold):
    """Return the C++ source of the ``SimpleReadout`` JIT filter.

    Values are substituted by keyword so each one lands in its matching
    ``static constexpr`` member regardless of argument order.
    """
    src = """
    template <typename Spike>
    class SimpleReadout : public iris::FilterInterface<std::shared_ptr<const std::vector<Spike>>, std::shared_ptr<const std::vector<ui::Event>>> {{
    public:
        SimpleReadout() = default;

        void apply() override {{
            std::vector<std::shared_ptr<std::vector<ui::Event>>> results;

            // The input type here is 'std::optional<std::shared_ptr<const std::vector<Spike>>>' as mentioned in the documentation.
            // The optional will return false and exit the loop when there is no more input waiting
            while (const auto& spikes = this->receiveInput()) {{
                results.emplace_back(predict(**spikes));
            }}

            if (results.empty()) {{
                return;
            }}

            this->forwardResultsInBulk(std::move(results));
        }}

    private:
        auto predict(const std::vector<Spike>& spikes) {{
            // get each-class spikes number
            auto spike_counts_each_class = get_spikes_of_each_class(spikes);
            // find the class with the maximum spike number
            std::vector<std::pair<int, int>> sorted_results{{spike_counts_each_class.begin(), spike_counts_each_class.end()}};
            std::partial_sort(sorted_results.begin(), sorted_results.begin() + 1, sorted_results.end(),
                                [](const auto& a, const auto& b) {{ return a.second > b.second; }});
            // return prediction
            const auto total_spike_is_large_enough = sorted_results[0].second > silent_threshold;
            auto prediction = total_spike_is_large_enough ? sorted_results[0].first : silent_class;
            return std::make_shared<std::vector<ui::Event>>(std::vector<ui::Event>{{ui::Readout(prediction)}});
        }}

        std::map<int, int> get_spikes_of_each_class(const std::vector<Spike>& spikes) {{
            std::map<int, int> spike_counts_each_class;
            for (int i = 0; i < num_classes; i++) {{
                spike_counts_each_class[i] = 0;
            }}
            // loop all the spikes
            for (const auto& spike : spikes) {{
                // chip may generate wrong class
                if (spike.feature < num_classes) {{
                    ++spike_counts_each_class.at(spike.feature);
                }}
            }}
            return spike_counts_each_class;
        }}

        static constexpr auto num_classes = {num_classes};
        static constexpr auto silent_class = {silent_class};
        static constexpr auto silent_threshold = {silent_threshold};
    }};
    """
    # Named fields: the previous positional .format() swapped the values,
    # producing num_classes=60 and silent_threshold=2.
    return src.format(
        num_classes=num_classes,
        silent_class=silent_class,
        silent_threshold=threshold,
    )


def jit_readout_filter(num_classes=2, silent_class=None, threshold=None):
    """Build the ``SimpleReadout`` JIT filter that turns spike batches into readouts.

    For each incoming batch of spikes the filter counts spikes per class
    (features ``>= num_classes`` are ignored), takes the class with the most
    spikes, and emits ``ui::Readout(class)`` if that count is greater than
    ``threshold``; otherwise it emits ``ui::Readout(silent_class)``. Ties are
    resolved by ``std::partial_sort``, whose order among equals is unspecified.

    Args:
        num_classes: Number of output classes (default 2: "Closed", "Open").
        silent_class: Readout index emitted when no class is confident enough;
            defaults to ``num_classes`` (the index after the last class).
        threshold: Minimum spike count (exclusive) for a non-silent prediction;
            defaults to the module-level ``silent_threshold``.

    Returns:
        A ``samna.graph.JitFilter`` compiled from the generated C++ source.
    """
    if silent_class is None:
        silent_class = num_classes
    if threshold is None:
        threshold = silent_threshold

    src = _readout_filter_source(num_classes, silent_class, threshold)
    return samna.graph.JitFilter("SimpleReadout", src)

After defining the helper class and function above, everything can be connected together.

# Guard the script entry point: with the "spawn" multiprocessing start method
# (the default on macOS and Windows), the visualizer's child process re-imports
# the main module, and unguarded top-level code would try to start another
# visualizer process and open the device a second time.
if __name__ == "__main__":
    # Start the visualizer; it runs in a different process
    visualizer = Visualizer(dvs_size, feature_count, feature_names, images)

    # Open a device using the device string
    device = samna.device.open_device("Speck2bTestboard:0")

    # First we define a filter graph to process the device output
    graph = samna.graph.EventFilterGraph()

    # Take the DVS output of the device and stream it directly to the visualizer
    streamer = graph.sequential(
        [
            device.get_model_source_node(),
            "Speck2bDvsToVizConverter",
            samna.graph.JitZMQStreamer(),
        ]
    )[2]
    streamer.set_streamer_endpoint(visualizer.get_endpoint())

    # Collect the spikes from the device in a predefined interval for further processing
    spike_collection_filter = graph.sequential(
        [
            device.get_model_source_node(),
            samna.graph.JitSpikeCollection(samna.speck2b.event.Spike),
        ]
    )[1]
    spike_collection_filter.set_interval_milli_sec(spike_collection_interval)

    # Count the spikes in the last spike collection interval and stream them to the visualizer
    spike_count_filter = graph.sequential(
        [spike_collection_filter, samna.graph.JitSpikeCount(samna.ui.Event), streamer]
    )[1]
    spike_count_filter.set_feature_count(feature_count)

    # Read out the label of the last spike collection interval and stream it to the visualizer
    graph.sequential([spike_collection_filter, jit_readout_filter(), streamer])

    # Start the graph to process events
    graph.start()
    try:
        # Configure the device with the ML model
        config = get_config()
        device.get_model().apply_configuration(config)

        # Wait until visualizer window is destroyed
        visualizer.join()
    finally:
        # Release the processing threads and the device even if setup fails
        # or the user interrupts the wait.
        graph.stop()
        samna.device.close_device(device)