Xylo-Imu Manual Mode

Xylo-Imu Manual Mode

Manual mode is the debug/test mode of Xylo-IMU. In this mode you can input spikes, trigger processing, and read or write registers and memory at any time.

In manual mode, typically everything is controlled manually. Here is an example based on the following packages:

- samna                 0.33.1
import samna

def initialize_board():
    """Open the Xylo-IMU test board and attach one sink and one source to it.

    Returns:
        A ``(device, sink, source)`` tuple: the opened device, a sink node
        fed by the device's model source node, and a source node feeding the
        device's model sink node.
    """
    board = samna.device.open_device("XyloImuTestBoard:0")
    # Events coming out of the device end up in this sink.
    from_board = samna.graph.sink_from(board.get_model_source_node())
    # Events written to this source are delivered to the device.
    to_board = samna.graph.source_to(board.get_model_sink_node())
    return board, from_board, to_board

# Open the board once for the whole script: `dk` is the device handle, `buf` is
# the sink attached to the device's model source node, and `source` is the
# source node attached to the device's model sink node.
dk, buf, source = initialize_board()

def build_event_type_filters(dk, graph):
    """Add one type-filtered sink per output event type to ``graph``.

    For every event type of interest, a branch
    ``device source node -> XyloImuOutputEventTypeFilter -> JitSink`` is added
    to the graph and the filter is set to pass only that type.

    Args:
        dk: Opened device whose model source node feeds every branch.
        graph: Filter graph the branches are added to (not started here).

    Returns:
        A tuple of six sinks, in this order: register value, readout,
        interrupt, membrane potential, synaptic current, hidden spike count.
    """
    # Order matters: it defines the order of the returned sinks.
    desired_types = (
        'xyloImu::event::RegisterValue',
        'xyloImu::event::Readout',
        'xyloImu::event::Interrupt',
        'xyloImu::event::MembranePotential',
        'xyloImu::event::SynapticCurrent',
        'xyloImu::event::HiddenSpikeCount',
    )

    sinks = []
    for type_name in desired_types:
        _, type_filter, sink = graph.sequential(
            [dk.get_model_source_node(), "XyloImuOutputEventTypeFilter", samna.graph.JitSink()]
        )
        type_filter.set_desired_type(type_name)
        sinks.append(sink)

    return tuple(sinks)

# Build the filter graph that splits the device's output into one sink per event type.
# NOTE: keep a reference to `graph` alive for as long as events are being received;
# if Python releases the object, no events will be received.
graph = samna.graph.EventFilterGraph()
register_value_buf, readout_buf, interrupt_buf, membrane_potential_buf, synaptic_current_buf, hidden_spike_buf = build_event_type_filters(dk, graph)
graph.start()       # The graph does nothing until it is started explicitly.

def read_register(address):
    """Read one register from the chip and return its value.

    Args:
        address: Address of the register to read.

    Returns:
        The ``data`` field of the ``RegisterValue`` event sent back by the chip.

    Raises:
        RuntimeError: If exactly one ``RegisterValue`` event is not received
            within 2 seconds.
    """
    buf.get_events()  # Fetch and discard whatever accumulated in the unfiltered sink.
    source.write([samna.xyloImu.event.ReadRegisterValue(address = address)])
    events = register_value_buf.get_n_events(1, 2000) # Try to get 1 event in 2 seconds.
    # Raise explicitly rather than `assert`: asserts are stripped under `python -O`,
    # which would turn a timeout into an opaque IndexError on the next line.
    if len(events) != 1:
        raise RuntimeError(
            "Expected 1 RegisterValue event for register {:#x}, got {}".format(address, len(events))
        )
    return events[0].data

def trigger_processing():
    """Trigger one processing step and block until the chip signals completion.

    Raises:
        Exception: If no interrupt event is received within 2 seconds.
    """
    buf.get_events()
    source.write([samna.xyloImu.event.TriggerProcessing()])

    # Wait up to 2 seconds for a single interrupt event.
    received = interrupt_buf.get_n_events(1, 2000)
    if received:
        return

    # By default the chip raises an interrupt once processing is done, so
    # getting nothing back means something went wrong.
    raise Exception("No interrupt occurs after processing done!")

def _read_per_neuron_twice(make_request, event_buf, count, what):
    """Request ``count`` per-neuron values twice and return the second batch of events.

    Due to a bug on chip, memory has to be read twice to ensure it is correct,
    so the first batch is only checked for completeness and then dropped.

    Args:
        make_request: Callable mapping a neuron id to the read-request event.
        event_buf: Sink that receives the matching response events.
        count: Number of neurons to read (ids ``0 .. count - 1``).
        what: Human-readable name of the value, used in error messages.

    Raises:
        RuntimeError: If either read does not yield exactly ``count`` events
            within 5 seconds.
    """
    events = []
    for _ in range(2):
        source.write([make_request(i) for i in range(count)])
        events = event_buf.get_n_events(count, 5000)
        if len(events) != count:
            raise RuntimeError(
                "Expected {} {} events, got {}".format(count, what, len(events))
            )
    return events


def request_readout(hidden_count, output_count):
    """Trigger a readout and fill in the state that manual mode does not provide.

    Args:
        hidden_count: Number of hidden neurons in the applied configuration.
        output_count: Number of output neurons in the applied configuration.

    Returns:
        The ``Readout`` event, with ``neuron_v_mems``, ``neuron_i_syns``,
        ``hidden_spikes`` and ``output_spikes`` populated from explicit
        memory/register reads.

    Raises:
        RuntimeError: If the chip does not answer any of the requests with the
            expected number of events before the timeout.
    """
    buf.get_events()
    source.write([samna.xyloImu.event.TriggerReadout()])
    readouts = readout_buf.get_n_events(1, 2000)

    # Only two attributes of the `Readout` event are available in manual mode:
    # `timestep` and `output_v_mems`. Everything else has to be read manually.
    # Explicit raises are used instead of `assert`, which is stripped under `python -O`.
    if len(readouts) != 1:
        raise RuntimeError("Expected 1 Readout event, got {}".format(len(readouts)))
    readout = readouts[0]

    neuron_count = hidden_count + output_count

    # Read all membrane potentials
    membrane_potentials = _read_per_neuron_twice(
        lambda i: samna.xyloImu.event.ReadMembranePotential(neuron_id = i),
        membrane_potential_buf,
        neuron_count,
        "membrane potential",
    )
    readout.neuron_v_mems = [e.value for e in membrane_potentials]

    # Read all synaptic currents
    synaptic_currents = _read_per_neuron_twice(
        lambda i: samna.xyloImu.event.ReadSynapticCurrent(neuron_id = i),
        synaptic_current_buf,
        neuron_count,
        "synaptic current",
    )
    readout.neuron_i_syns = [e.value for e in synaptic_currents]

    # Read all hidden spike counts (read once, unlike the memory reads above)
    source.write([samna.xyloImu.event.ReadHiddenSpikeCount(neuron_id = i) for i in range(hidden_count)])
    hidden_spikes = hidden_spike_buf.get_n_events(hidden_count, 5000)
    if len(hidden_spikes) != hidden_count:
        raise RuntimeError(
            "Expected {} hidden spike count events, got {}".format(hidden_count, len(hidden_spikes))
        )
    readout.hidden_spikes = [e.count for e in hidden_spikes]

    # Read output spikes from register: bit i of the value is output neuron i's spike flag.
    stat_reg_addr = 0x4B
    stat = read_register(stat_reg_addr)
    readout.output_spikes = [1 if stat & (1 << i) else 0 for i in range(output_count)]

    return readout

def apply_configuration(input_count=3, hidden_count=5, output_count=2):
    """Build a minimal all-ones network in manual mode and apply it to the board.

    Args:
        input_count: Number of input channels (rows of the input weight matrix).
        hidden_count: Number of hidden neurons.
        output_count: Number of output neurons.

    Returns:
        ``(xylo_config, input_count, hidden_count, output_count)`` — the applied
        configuration together with the network sizes used to build it.
    """
    xylo_config = samna.xyloImu.configuration.XyloConfiguration()
    xylo_config.operation_mode = samna.xyloImu.OperationMode.Manual

    # Build every row and every neuron as a distinct object. `[row] * n` would
    # repeat references to one object, so editing a single row or neuron before
    # assignment would silently change all of them.
    xylo_config.input.weights = [[1] * hidden_count for _ in range(input_count)]
    xylo_config.hidden.weights = [[1] * hidden_count for _ in range(hidden_count)]
    hidden_neurons = [samna.xyloImu.configuration.HiddenNeuron() for _ in range(hidden_count)]
    xylo_config.hidden.neurons = hidden_neurons
    output_neurons = [samna.xyloImu.configuration.OutputNeuron() for _ in range(output_count)]
    xylo_config.readout.neurons = output_neurons
    xylo_config.readout.weights = [[1] * output_count for _ in range(hidden_count)]

    dk.get_model().apply_configuration(xylo_config)
    return xylo_config, input_count, hidden_count, output_count

# Apply the example configuration and keep the network sizes at module level;
# `evolve` and the readout calls below use `hidden_count` and `output_count`.
xylo_config, input_count, hidden_count, output_count = apply_configuration()

def send_spikes(neurons):
    """Send one input spike event per entry of ``neurons`` to the chip.

    Args:
        neurons: Iterable of input neuron ids; repeating an id sends several
            spikes to it. An empty iterable writes an empty event list.
    """
    def _spike_for(neuron_id):
        spike = samna.xyloImu.event.Spike()
        spike.neuron_id = neuron_id
        return spike

    # All spikes are written in a single batch.
    source.write([_spike_for(n) for n in neurons])

def evolve(input_neurons):
    """Run one timestep: feed spikes, process them, then print the full state.

    Args:
        input_neurons: Input neuron ids to spike during this timestep.
    """
    # Queue the input spikes for this timestep.
    send_spikes(input_neurons)
    # Returns only once the chip reports that processing is done.
    trigger_processing()
    # Read back all state after processing, for debugging.
    state = request_readout(hidden_count, output_count)
    print("Readout after processing: ", state)

# Run the demo inside try/finally so the filter graph is stopped even if a
# readout or processing step raises.
try:
    # Request manually
    readout = request_readout(hidden_count, output_count)
    print("Initial readout: ", readout)

    # Process spikes and read state after processing
    evolve([0,1,2])             # timestep 0
    evolve([1,1,2,2,2,0,0])     # timestep 1
    evolve([])                  # timestep 2
    evolve([2,2,1,1,0])         # timestep 3
finally:
    graph.stop()