Selaa lähdekoodia

Fix concurrency problems

Billy Barrow 4 vuotta sitten
vanhempi
sitoutus
791a42372c

+ 2 - 1
src/lib/Networks/PeerInfo.vala

@@ -1,12 +1,13 @@
 using GLib;
 using Gee;
+using LibPeer.Util;
 
 namespace LibPeer.Networks
 {
     
     public abstract class PeerInfo {
 
-        private static HashMap<Bytes, Type> info_types = new HashMap<Bytes, Type>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+        private static ConcurrentHashMap<Bytes, Type> info_types = new ConcurrentHashMap<Bytes, Type>((a) => a.hash(), (a, b) => a.compare(b) == 0);
         
         protected abstract void build(uint8 data_length, InputStream stream) throws IOError, Error;
         

+ 2 - 1
src/lib/Networks/Simulation/Conduit.vala

@@ -1,4 +1,5 @@
 using LibPeer.Networks;
+using LibPeer.Util;
 
 using Gee;
 
@@ -6,7 +7,7 @@ namespace LibPeer.Networks.Simulation {
 
     public class Conduit {
 
-        private HashMap<Bytes, NetSim> interfaces = new HashMap<Bytes, NetSim>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+        private ConcurrentHashMap<Bytes, NetSim> interfaces = new ConcurrentHashMap<Bytes, NetSim>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
         private int count = 0;
 

+ 1 - 1
src/lib/Networks/Simulation/NetSim.vala

@@ -91,7 +91,7 @@ namespace LibPeer.Networks.Simulation {
             // Create the packet
             var packet = new Packet(peer_info, data);
 
-            print(@"NET: $(data.get(0)) $(data.get(1)) $(data.get(2))\n");
+            //  print(@"NET: $(origin.get(0)) $(origin.get(1)) $(origin.get(2)) to $(identifier.get(0)) $(identifier.get(1)) $(identifier.get(2))\n");
 
             // Add packet to queue
             packet_queue.push(new QueueCommand<Packet>.with_payload(packet));

+ 1 - 5
src/lib/Protocols/MX2/Frame.vala

@@ -46,13 +46,11 @@ namespace LibPeer.Protocols.Mx2 {
             // Encrypt the signed payload
             uint8[] encrypted_signed_payload = Asymmetric.Sealing.seal(signed_payload, destination.public_key);
 
-            print(@"TX_PAYLOAD: $(new ByteComposer().add_byte_array(payload).to_string())\n");
-
             // Write the signed and encrypted payload
             stream.write(encrypted_signed_payload);
         }
 
-        public Frame.from_stream(InputStream stream, HashMap<InstanceReference, Instance> instances) throws IOError, Error{
+        public Frame.from_stream(InputStream stream, ConcurrentHashMap<InstanceReference, Instance> instances) throws IOError, Error{
             // Read the magic number
             uint8[] magic = new uint8[3];
             stream.read(magic);
@@ -93,8 +91,6 @@ namespace LibPeer.Protocols.Mx2 {
 
             // Verify the signature and get plaintext message
             uint8[]? payload = Asymmetric.Signing.verify(signed_payload, origin.verification_key);
-                                
-            print(@"RX_PAYLOAD: $(new ByteComposer().add_byte_array(payload).to_string())\n");
 
             if (payload == null) {
                 throw new IOError.FAILED("Payload signature is invalid");

+ 1 - 1
src/lib/Protocols/MX2/InstanceAccessInfo.vala

@@ -2,7 +2,7 @@ using LibPeer.Networks;
 
 namespace LibPeer.Protocols.Mx2 {
 
-    internal struct InstanceAccessInfo {
+    internal class InstanceAccessInfo {
 
         public Network network;
 

+ 7 - 6
src/lib/Protocols/MX2/InstanceReference.vala

@@ -3,8 +3,8 @@ namespace LibPeer.Protocols.Mx2 {
 
     public class InstanceReference {
 
-        public uint8[] verification_key { get; protected set; }
-        public uint8[] public_key { get; protected set; }
+        public uint8[] verification_key { get; private set; }
+        public uint8[] public_key { get; private set; }
 
         public InstanceReference(uint8[] verification_key, uint8[] public_key) 
         requires (verification_key.length == 32)
@@ -16,7 +16,7 @@ namespace LibPeer.Protocols.Mx2 {
 
         public InstanceReference.from_stream(InputStream stream) throws IOError {
             verification_key = new uint8[32];
-            stream.read(verification_key);
+            stream.read(_verification_key);
 
             public_key = new uint8[32];
             stream.read(public_key);
@@ -28,9 +28,10 @@ namespace LibPeer.Protocols.Mx2 {
         }
 
         private Bytes combined_bytes () {
-            uint8[] combined = new uint8[64];
-            MemoryOutputStream stream = new MemoryOutputStream(combined);
-            serialise(stream);
+            uint8[] combined = new Util.ByteComposer()
+                .add_byte_array(verification_key)
+                .add_byte_array(public_key)
+                .to_byte_array();
             return new Bytes(combined);
         }
 

+ 9 - 9
src/lib/Protocols/MX2/Muxer.vala

@@ -13,15 +13,17 @@ namespace LibPeer.Protocols.Mx2 {
 
         private const int FALLBACK_PING_VALUE = 120000;
         
-        private HashMap<Bytes, HashSet<Network>> networks = new HashMap<Bytes, HashSet<Network>>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+        private ConcurrentHashMap<Bytes, HashSet<Network>> networks = new ConcurrentHashMap<Bytes, HashSet<Network>>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
-        private HashMap<InstanceReference, Instance> instances = new HashMap<InstanceReference, Instance>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+        private ConcurrentHashMap<InstanceReference, Instance> instances = new ConcurrentHashMap<InstanceReference, Instance>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
-        private HashMap<InstanceReference, InstanceAccessInfo?> remote_instance_mapping = new HashMap<InstanceReference, InstanceAccessInfo>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+        private ConcurrentHashMap<InstanceReference, InstanceAccessInfo> remote_instance_mapping = new ConcurrentHashMap<InstanceReference, InstanceAccessInfo>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+
+        private ConcurrentHashMap<Bytes, Inquiry> inquiries = new ConcurrentHashMap<Bytes, Inquiry>((a) => a.hash(), (a, b) => a.compare(b) == 0);
+
+        private ConcurrentHashMap<InstanceReference, int> pings = new ConcurrentHashMap<InstanceReference, int>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
-        private HashMap<Bytes, Inquiry> inquiries = new HashMap<Bytes, Inquiry>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
-        private HashMap<InstanceReference, int> pings = new HashMap<InstanceReference, int>((a) => a.hash(), (a, b) => a.compare(b) == 0);
 
         
         public void register_network(Network network) {
@@ -188,12 +190,10 @@ namespace LibPeer.Protocols.Mx2 {
                 .add_byte_array(frame.payload[17:frame.payload.length])
                 .to_string();
 
-            print(@"NAMESPACE: $(application_namespace)\n");
-
             // Does the application namespace match the instance's
             if (instance.application_namespace == application_namespace) {
                 // Yes, save this instance's information locally for use later
-                remote_instance_mapping.set(frame.origin, InstanceAccessInfo() { 
+                remote_instance_mapping.set(frame.origin, new InstanceAccessInfo() { 
                     network = receiption.network,
                     peer_info = receiption.peer_info,
                     path_info = frame.via.return_path
@@ -216,7 +216,7 @@ namespace LibPeer.Protocols.Mx2 {
             // Have we received one from this instance before?
             if (!remote_instance_mapping.has_key(frame.origin)) {
                 // No, this is the first (and therefore least latent) method of reaching this instance
-                remote_instance_mapping.set(frame.origin, InstanceAccessInfo() { 
+                remote_instance_mapping.set(frame.origin, new InstanceAccessInfo() { 
                     network = receiption.network,
                     peer_info = receiption.peer_info,
                     path_info = frame.via.return_path

+ 125 - 0
src/lib/Util/ConcurrentHashMap.vala

@@ -0,0 +1,125 @@
+using Gee;
+
+namespace LibPeer.Util {
+
+    public class ConcurrentHashMap<K, V> : AbstractMap<K,V> {
+
+        private HashMap<K, V> _map;
+
+        private HashDataFunc<K>? key_hash_func = null;
+        private EqualDataFunc<K>? key_equal_func = null;
+        private EqualDataFunc<V>? value_equal_func = null;
+
+        public ConcurrentHashMap (owned HashDataFunc<K>? key_hash_func = null, owned EqualDataFunc<K>? key_equal_func = null, owned EqualDataFunc<V>? value_equal_func = null) {
+            _map = new HashMap<K, V> (key_hash_func, key_equal_func, value_equal_func);
+            this.key_hash_func = key_hash_func;
+            this.key_equal_func = key_equal_func;
+            this.value_equal_func = value_equal_func;
+        }
+
+        private HashMap<K, V> copy() {
+            lock(_map) {
+                HashMap<K, V> copy = new HashMap<K, V> (key_hash_func, key_equal_func, value_equal_func);
+                copy.set_all(_map);
+                return copy;
+            }
+        }
+
+        public override void clear () {
+            lock(_map) {
+                clear();
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override new V @get (K key) {
+            lock(_map) {
+                return _map.get (key);
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override bool has (K key, V value) {
+            lock(_map) {
+                return _map.has (key, value);
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override bool has_key (K key)  {
+            lock(_map) {
+                return _map.has_key (key);
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override Gee.MapIterator<K,V> map_iterator () {
+            lock(_map) {
+                return copy().map_iterator();
+            }
+        }
+
+		/**
+		 * {@inheritDoc}
+		 */
+		public override new void @set (K key, V value)  {
+            lock(_map) {
+                _map.set (key, value);
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override bool unset (K key, out V value = null)  {
+            lock(_map) {
+                return _map.unset (key, out value);
+            }
+        }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override Gee.Set<Gee.Map.Entry<K,V>> entries { owned get {
+            lock(_map) {
+                return copy().entries;
+            }
+        } }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override Gee.Set<K> keys { owned get {
+            lock(_map) {
+                return copy().keys;
+            }
+        } }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override bool read_only { get {
+            lock(_map) {
+                return _map.read_only;
+            }
+        } }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override int size { get {
+            lock(_map) {
+                return _map.size;
+            }
+        } }
+		/**
+		 * {@inheritDoc}
+		 */
+		public override Gee.Collection<V> values { owned get {
+            lock(_map) {
+                return copy().values;
+            }
+        } }
+
+    }
+
+}

+ 1 - 0
src/lib/meson.build

@@ -35,6 +35,7 @@ sources += files('Protocols/MX2/PathStrategy.vala')
 sources += files('Util/ByteComposer.vala')
 sources += files('Util/QueueCommand.vala')
 sources += files('Util/ThreadTimer.vala')
+sources += files('Util/ConcurrentHashMap.vala')
 
 libpeer = library('peer', sources, dependencies: dependencies)
 libpeer_dep = declare_dependency(link_with: libpeer, include_directories: include_directories('.'))

+ 3 - 3
src/lib/vapi/libsodium.vapi

@@ -49,7 +49,7 @@
      private void key_gen([CCode (array_length = false)]uint8[] key);
  
      public uint8[] generate_key() {
-       uint8[KEY_BYTES] key = new uint8[KEY_BYTES];
+       uint8[] key = new uint8[KEY_BYTES];
        key_gen(key);
        return key;
      }
@@ -68,7 +68,7 @@
      {
        // Initialise array for ciphertext
        size_t ciphertext_size = MAC_BYTES + message.length;
-       uint8[ciphertext_size] ciphertext = new uint8[ciphertext_size];
+       uint8[] ciphertext = new uint8[ciphertext_size];
  
        // Encrypt
        secretbox(ciphertext, message, nonce, key);
@@ -92,7 +92,7 @@
      {
        // Initialise array for message
        size_t message_size = ciphertext.length - MAC_BYTES;
-       uint8[message_size] message = new uint8[message_size];
+       uint8[] message = new uint8[message_size];
  
        // Decrypt
        int status = secretbox_open(message, ciphertext, nonce, key);

+ 3 - 2
src/toys/exponential_pinger/Main.vala

@@ -6,11 +6,12 @@ namespace ExponentialPinger {
 
         public static int main(string[] args) {
             print("Exponential Pinger\n");
+            int count = int.parse(args[1]);
 
             Conduit conduit = new Conduit();
 
-            Pinger[] pingas = new Pinger[10];
-            for (int i = 0; i < 10; i++){
+            Pinger[] pingas = new Pinger[count];
+            for (int i = 0; i < count; i++){
                 pingas[i] = new Pinger(conduit);
             }
 

+ 9 - 15
src/toys/exponential_pinger/Pinger.vala

@@ -11,7 +11,7 @@ namespace ExponentialPinger {
         private Muxer muxer = new Muxer();
         private Network network;
         private Instance instance;
-        private HashSet<InstanceReference> peers = new HashSet<InstanceReference>((m) => m.hash(), (a, b) => a.compare(b) == 0);
+        private ConcurrentSet<InstanceReference> peers = new ConcurrentSet<InstanceReference>((a, b) => a.compare(b));
 
         public Pinger(Conduit conduit) throws Error, IOError {
             network = conduit.get_interface();
@@ -28,28 +28,22 @@ namespace ExponentialPinger {
         }
 
         private void rx_advertisement(Advertisement adv) throws Error, IOError {
-            lock (peers) {
-                if(!peers.contains(adv.instance_reference)) {
-                    var peer_info = new GLib.List<PeerInfo>();
-                    peer_info.append(adv.peer_info);
-                    muxer.inquire(instance, adv.instance_reference, peer_info);
-                }
+            if(!peers.contains(adv.instance_reference)) {
+                var peer_info = new GLib.List<PeerInfo>();
+                peer_info.append(adv.peer_info);
+                muxer.inquire(instance, adv.instance_reference, peer_info);
             }
         }
 
         private void rx_greeting(InstanceReference origin) throws Error, IOError {
-            lock (peers) {
-                peers.add(origin);
-            }
+            peers.add(origin);
             muxer.send(instance, origin, "Hello World!".data);
         }
 
         private void rx_data(Packet packet) throws Error, IOError {
-            lock (peers) {
-                peers.add(packet.origin);
-                network.advertise(instance.reference);
-                print(@"RX DATA, I have $(peers.size) peers\n");
-            }
+            peers.add(packet.origin);
+            network.advertise(instance.reference);
+            print(@"RX DATA, I have $(peers.size) peers\n");
 
             uint8[] data = new uint8[13];
             packet.stream.read(data);