1 /**
2  *
3  * /home/tomas/workspace/mqtt-d/source/mqttd/client.d
4  *
5  * Author:
6  * Tomáš Chaloupka <chalucha@gmail.com>
7  *
8  * Copyright (c) 2015 Tomáš Chaloupka
9  *
10  * Boost Software License 1.0 (BSL-1.0)
11  *
12  * Permission is hereby granted, free of charge, to any person or organization obtaining a copy
13  * of the software and accompanying documentation covered by this license (the "Software") to use,
14  * reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative
15  * works of the Software, and to permit third-parties to whom the Software is furnished to do so,
16  * all subject to the following:
17  *
18  * The copyright notices in the Software and this entire statement, including the above license
19  * grant, this restriction and the following disclaimer, must be included in all copies of the Software,
20  * in whole or in part, and all derivative works of the Software, unless such copies or derivative works
21  * are solely in the form of machine-executable object code generated by a source language processor.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
24  * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
25  * PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE
26  * DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT,
27  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
28  * OTHER DEALINGS IN THE SOFTWARE.
29  */
30 module mqttd.client;
31 
32 debug import std.stdio;
33 
34 import mqttd.traits;
35 import mqttd.messages;
36 import mqttd.serialization;
37 
38 import vibe.core.log;
39 import vibe.core.net: TCPConnection;
40 import vibe.core.stream;
41 import vibe.core.task;
42 import vibe.core.concurrency;
43 import vibe.utils.array : FixedRingBuffer;
44 
45 import std.datetime;
46 import std.exception;
47 import std..string : format;
48 import std.traits;
49 
/// Default TCP port of an MQTT broker
enum uint MQTT_BROKER_DEFAULT_PORT = 1883;
/// Default SSL/TLS port of an MQTT broker
enum uint MQTT_BROKER_DEFAULT_SSL_PORT = 8883;
/// Capacity of the session packet container (see `SessionContainer`)
enum ushort MQTT_SESSION_MAX_PACKETS = ushort.max;
/// default client identifier
enum string MQTT_CLIENT_ID = "vibe-mqtt";
/// delay for retry publish, subscribe and unsubscribe for QoS Level 1 or 2 [ms]
enum uint MQTT_RETRY_DELAY = 10_000;
/// max publish, subscribe and unsubscribe retry for QoS Level 1 or 2
enum uint MQTT_RETRY_ATTEMPTS = 3;

/// Fixed capacity ring buffer used by `Session` to hold packet contexts
alias SessionContainer = FixedRingBuffer!(PacketContext, MQTT_SESSION_MAX_PACKETS);
59 
60 /// MqttClient settings
/// MqttClient settings
struct Settings
{
	/// message broker address
	string host = "127.0.0.1";
	/// message broker port
	ushort port = MQTT_BROKER_DEFAULT_PORT;
	/// Client Id to identify within message broker (must be unique); MqttClient uses the host name when empty
	string clientId = MQTT_CLIENT_ID;
	/// optional user name to login with
	string userName;
	/// user password (sent only together with a user name)
	string password;
	/// delay for retry of publish, subscribe and unsubscribe packets [ms]
	int retryDelay = MQTT_RETRY_DELAY;
	/// how many times will client try to resend QoS1 and QoS2 packets
	int retryAttempts = MQTT_RETRY_ATTEMPTS;
}
71 
72 /// MQTT packet state
/// MQTT packet state
enum PacketState
{
	/// QOS = 0, Message queued
	queuedQos0,
	/// QOS = 1, Message queued
	queuedQos1,
	/// QOS = 2, Message queued
	queuedQos2,

	/// QOS = 1, PUBLISH sent, wait for PUBACK
	waitForPuback,
	/// QOS = 2, PUBLISH sent, wait for PUBREC
	waitForPubrec,
	/// QOS = 2, PUBREC sent, wait for PUBREL
	waitForPubrel,
	/// QOS = 2, PUBREL sent, wait for PUBCOMP
	waitForPubcomp,

	/// QOS = 2, start first phase handshake send PUBREC
	sendPubrec,
	/// QOS = 2, start second phase handshake send PUBREL
	sendPubrel,
	/// QOS = 2, end second phase handshake send PUBCOMP
	sendPubcomp,
	/// QOS = 1, PUBLISH received, send PUBACK
	sendPuback,

	/// (QOS = 1), SUBSCRIBE sent, wait for SUBACK
	waitForSuback,
	/// (QOS = 1), UNSUBSCRIBE sent, wait for UNSUBACK
	waitForUnsuback,

	/// Wildcard used only for search purposes (never a real packet state)
	any
}
90 
91 /// Context for MQTT packet stored in Session
/**
 * Context for MQTT packet stored in Session.
 *
 * Holds one Publish, Subscribe or Unsubscribe packet in a union; `packetType`
 * tells which union member is the active one.
 */
struct PacketContext
{
	/// Type of the held packet (selects the active union member)
	PacketType packetType;
	/// MQTT packet state
	PacketState state;
	/// Timestamp (for retry)
	public SysTime timestamp;
	/// Attempt counter (for retry)
	public uint attempt;

	/// MQTT packet id of the held packet; asserts for packet types other than PUBLISH, SUBSCRIBE, UNSUBSCRIBE
	@property ushort packetId()
	{
		if (packetType == PacketType.PUBLISH) return publish.packetId;
		if (packetType == PacketType.SUBSCRIBE) return subscribe.packetId;
		if (packetType == PacketType.UNSUBSCRIBE) return unsubscribe.packetId;
		assert(0, "Unsupported packet type");
	}

	/// Context can hold different packet types
	union
	{
		Publish publish; /// Publish packet
		Subscribe subscribe; /// Subscribe packet
		Unsubscribe unsubscribe; /// Unsubscribe packet
	}

	/// Copies all fields of `ctx`, including only its active union member
	ref PacketContext opAssign(ref PacketContext ctx)
	{
		packetType = ctx.packetType;
		state = ctx.state;
		timestamp = ctx.timestamp;
		attempt = ctx.attempt;

		switch (ctx.packetType)
		{
			case PacketType.PUBLISH: publish = ctx.publish; break;
			case PacketType.SUBSCRIBE: subscribe = ctx.subscribe; break;
			case PacketType.UNSUBSCRIBE: unsubscribe = ctx.unsubscribe; break;
			default: assert(0, "Unsupported packet type");
		}

		return this;
	}

	/// Stores a Publish packet (state, timestamp and attempt are left untouched)
	void opAssign(Publish pub)
	{
		publish = pub;
		packetType = PacketType.PUBLISH;
	}

	/// Stores a Subscribe packet (state, timestamp and attempt are left untouched)
	void opAssign(Subscribe sub)
	{
		subscribe = sub;
		packetType = PacketType.SUBSCRIBE;
	}

	/// Stores an Unsubscribe packet (state, timestamp and attempt are left untouched)
	void opAssign(Unsubscribe unsub)
	{
		unsubscribe = unsub;
		packetType = PacketType.UNSUBSCRIBE;
	}
}
165 
166 /// MQTT session status holder
/// MQTT session status holder
struct Session
{
	/**
	 * Adds packet to Session.
	 *
	 * A new packet id is generated and written into `packet` (visible to the caller when
	 * `packet` is passed as an lvalue). A Publish packet added with the `queuedQos0` state
	 * keeps its packet id unchanged, as QoS 0 publish packets carry none.
	 * When the session is full, the oldest stored packet is dropped to make room.
	 *
	 * Params:
	 *     packet = Publish, Subscribe or Unsubscribe packet to store
	 *     state = initial state of the stored packet
	 *
	 * Returns: packet id of the stored packet
	 */
	auto add(T)(auto ref T packet, PacketState state)
		if (is(T == Publish) || is(T == Subscribe) || is(T == Unsubscribe))
	{
		// assign packet id
		static if (is(T == Publish))
		{
			if (state != PacketState.queuedQos0) packet.packetId = nextPacketId();
		}
		else packet.packetId = nextPacketId();

		auto ctx = PacketContext();
		ctx = packet;
		ctx.timestamp = Clock.currTime;
		ctx.state = state;

		// packets are appended at the back, so the front holds the oldest one
		if (_packets.full())
			_packets.popFront(); // make place by oldest one removal

		_packets.put(ctx);

		return packet.packetId;
	}

	/// Removes the stored PacketContext at position `idx` (as reported by `canFind`)
	void removeAt(size_t idx)
	{
		_packets.removeAt(_packets[idx..idx+1]);
	}

	/**
	 * Finds package context stored in session.
	 *
	 * Params:
	 *     packetId = id of the packet to look for
	 *     ctx = receives a copy of the first matching context
	 *     idx = receives the position of the first matching context
	 *     state = required state of the context, `PacketState.any` matches every state
	 *
	 * Returns: true if a matching context was found
	 */
	auto canFind(ushort packetId, out PacketContext ctx, out size_t idx, PacketState state = PacketState.any) @trusted
	{
		size_t i;
		foreach(ref c; _packets)
		{
			if(c.packetId == packetId && (state == PacketState.any || c.state == state))
			{
				ctx = c;
				idx = i;
				return true;
			}
			++i;
		}

		return false;
	}

	/// The oldest stored packet context; must not be called on an empty session (check `packetCount`)
	@property ref PacketContext front()
	{
		return _packets.front();
	}

	/// Removes the oldest stored packet context
	void popFront()
	{
		_packets.popFront();
	}

	@safe @nogc pure nothrow:

	/// Clears cached messages
	void clear()
	{
		_packets.clear();
	}

	/// Number of packets to process
	@property auto packetCount() const
	{
		return _packets.length;
	}

	/// Gets next packet id (cycles through 1 .. MQTT_SESSION_MAX_PACKETS, never 0)
	@property auto nextPacketId()
	{
		//TODO: Is this ok or should we check with session packets?
		//packet id can't be 0!
		_packetId = cast(ushort)((_packetId % MQTT_SESSION_MAX_PACKETS) != 0 ? _packetId + 1 : 1);
		return _packetId;
	}

private:
	/// Packets to handle
	SessionContainer _packets;
	/// Last assigned packet id
	ushort _packetId = 0u;
}
255 
256 /// MQTT Client implementation
/// MQTT Client implementation
class MqttClient
{
	import std.array : Appender;

	/**
	 * Creates the client; no connection is opened until `connect` is called.
	 *
	 * Params:
	 *     settings = client settings; when `clientId` is empty, the local host name is used instead
	 */
	this(Settings settings)
	{
		import std.socket : Socket;

		_settings = settings;
		if (_settings.clientId.length == 0) // set clientId if not provided
			_settings.clientId = Socket.hostName;

		_readBuffer.capacity = 4 * 1024;
	}

	final
	{
		/**
		 * Connects to the specified broker and sends it the Connect packet.
		 *
		 * Starts the listener (receiving) and dispatcher (sending) tasks for the new connection.
		 * The Connect packet requests a clean session and carries the user name (and the
		 * password, only together with a user name) when they are set in the settings.
		 */
		void connect()
		in { assert(_con is null || !_con.connected); }
		body
		{
			import vibe.core.net: connectTCP;
			import vibe.core.core: runTask;

			//cleanup before reconnects
			_readBuffer.clear();
			// reset before the tasks are started, in case one of them reports a disconnect
			// before this method continues (resetting afterwards would lose that report)
			onDisconnectCalled = false;

			_con = connectTCP(_settings.host, _settings.port);
			_listener = runTask(&listener);
			_dispatcher = runTask(&dispatcher);

			version(MqttDebug) logDebug("MQTT Broker Connecting");

			auto con = Connect();
			con.clientIdentifier = _settings.clientId;
			con.flags.cleanSession = true;
			if (_settings.userName.length > 0)
			{
				con.flags.userName = true;
				con.userName = _settings.userName;
				if (_settings.password.length > 0)
				{
					con.flags.password = true;
					con.password = _settings.password;
				}
			}

			send(con);
		}

		/// Sends Disconnect packet to the broker and closes the underlying connection
		void disconnect()
		in { assert(_con !is null); }
		body
		{
			version(MqttDebug) logDebug("MQTT Disconnecting from Broker");

			if (_con.connected)
			{
				send(Disconnect());
				_con.flush();
				_con.close();

				// the listener may call this itself (e.g. from a handler) - it can't join itself
				if(Task.getThis !is _listener)
					_listener.join;
			}
		}

		/**
		 * Return true, if client is in a connected state
		 */
		@property bool connected() const
		{
			return _con !is null && _con.connected;
		}

		/**
		 * Publishes the message on the specified topic
		 *
		 * The message is queued in the session and sent by the dispatcher task.
		 *
		 * Params:
		 *     topic = Topic to send message to
		 *     payload = Content of the message
		 *     qos = Required QoSLevel to handle message (default is QoSLevel.AtMostOnce)
		 *     retain = If true, the server must store the message so that it can be delivered to future subscribers
		 *
		 */
		void publish(T)(in string topic, in T payload, QoSLevel qos = QoSLevel.QoS0, bool retain = false)
			if (isSomeString!T || (isArray!T && is(ForeachType!T : ubyte)))
		{
			auto pub = Publish();
			pub.header.qos = qos;
			pub.header.retain = retain;
			pub.topic = topic;
			pub.payload = cast(ubyte[]) payload;

			_session.add(pub, qos == QoSLevel.QoS0 ?
				PacketState.queuedQos0 :
				(qos == QoSLevel.QoS1 ? PacketState.queuedQos1 : PacketState.queuedQos2));

			_dispatcher.send(true);
		}

		/**
		 * Subscribes to the specified topics
		 *
		 * The Subscribe packet is queued in the session and sent by the dispatcher task,
		 * which resends it after `Settings.retryDelay` until a SUBACK arrives
		 * (at most `Settings.retryAttempts` times).
		 *
		 * Params:
		 *      topics = Array of topic filters to subscribe to
		 *      qos = This gives the maximum QoS level at which the Server can send Application Messages to the Client.
		 *
		 */
		void subscribe(const string[] topics, QoSLevel qos = QoSLevel.QoS0)
		{
			import std.algorithm : map;
			import std.array : array;

			auto sub = Subscribe();
			sub.topics = topics.map!(a => Topic(a, qos)).array;

			_session.add(sub, PacketState.waitForSuback);
			_dispatcher.send(true);
		}
	}

	/// Called for a received CONNACK packet.
	/// Throws: Exception when the broker refused the connection.
	void onConnAck(ConnAck packet)
	{
		version(MqttDebug) logDebug("MQTT onConnAck - %s", packet);

		if(packet.returnCode == ConnectReturnCode.ConnectionAccepted)
		{
			version(MqttDebug) logDebug("MQTT Connection accepted");
		}
		else throw new Exception(format("Connection refused: %s", packet.returnCode));
	}

	/// Called for a received PINGRESP packet
	void onPingResp(PingResp packet)
	{
		version(MqttDebug) logDebug("MQTT onPingResp - %s", packet);
	}

	/// Called for a received PUBACK packet; completes the matching QoS 1 publish
	void onPubAck(PubAck packet)
	{
		version(MqttDebug) logDebug("MQTT onPubAck - %s", packet);

		PacketContext ctx;
		size_t idx;
		if(_session.canFind(packet.packetId, ctx, idx, PacketState.waitForPuback)) // QoS 1
		{
			//treat the PUBLISH Packet as "unacknowledged" until corresponding PUBACK received
			_session.removeAt(idx);
			_dispatcher.send(true);
		}
	}

	/// Called for a received PUBREC packet (not handled yet)
	void onPubRec(PubRec packet)
	{
		version(MqttDebug) logDebug("MQTT onPubRec - %s", packet);
	}

	/// Called for a received PUBREL packet (not handled yet)
	void onPubRel(PubRel packet)
	{
		version(MqttDebug) logDebug("MQTT onPubRel - %s", packet);
	}

	/// Called for a received PUBCOMP packet (not handled yet)
	void onPubComp(PubComp packet)
	{
		version(MqttDebug) logDebug("MQTT onPubComp - %s", packet);
	}

	/// Called for a received PUBLISH packet; acknowledges QoS 1 messages with PUBACK
	void onPublish(Publish packet)
	{
		version(MqttDebug) logDebug("MQTT onPublish - %s", packet);

		//MUST respond with a PUBACK Packet containing the Packet Identifier from the incoming PUBLISH Packet
		if (packet.header.qos == QoSLevel.QoS1)
		{
			auto ack = PubAck();
			ack.packetId = packet.packetId;

			send(ack);
		}
	}

	/// Called for a received SUBACK packet; completes the matching subscription
	void onSubAck(SubAck packet)
	{
		version(MqttDebug) logDebug("MQTT onSubAck - %s", packet);

		PacketContext ctx;
		size_t idx;
		if(_session.canFind(packet.packetId, ctx, idx, PacketState.waitForSuback))
		{
			_session.removeAt(idx);
			_dispatcher.send(true);
		}
	}

	/// Called for a received UNSUBACK packet (not handled yet)
	void onUnsubAck(UnsubAck packet)
	{
		version(MqttDebug) logDebug("MQTT onUnsubAck - %s", packet);
	}

	/// Called once when the connection is found closed by the listener or dispatcher task
	void onDisconnect()
	{
		version(MqttDebug) logDebug("MQTT onDisconnect");
	}

private:
	Settings _settings;
	TCPConnection _con;
	Session _session;
	Task _listener, _dispatcher;
	Serializer!(Appender!(ubyte[])) _sendBuffer;
	FixedRingBuffer!ubyte _readBuffer;
	ubyte[] _packetBuffer;
	bool onDisconnectCalled;

final:

	/**
	 * Processes data in read buffer.
	 *
	 * Every complete packet in the buffer is delegated to its handler; an incomplete
	 * trailing packet stays buffered until more data arrives.
	 *
	 * Throws: PacketFormatException on a malformed remaining length,
	 *         Exception on an unexpected packet type
	 */
	void proccessData(in ubyte[] data)
	{
		version(MqttDebug) logDebug("MQTT IN: %(%.02x %)", data);

		if (_readBuffer.freeSpace < data.length) // ensure all fits to the buffer
			_readBuffer.capacity = _readBuffer.capacity + data.length;
		_readBuffer.put(data);

		// one read can carry several packets - handle every complete one
		while (_readBuffer.length > 0)
		{
			// try read packet header
			FixedHeader header = _readBuffer[0]; // type + flags

			// try read remaining length (MQTT 3.1.1 algorithm, 1 to 4 bytes)
			uint pos;
			uint multiplier = 1;
			ubyte digit;
			do
			{
				if (++pos >= _readBuffer.length) return; // not enough data
				digit = _readBuffer[pos];
				header.length += ((digit & 127) * multiplier);
				// check before multiplying, so that a valid 4th length byte is still accepted
				if (multiplier > 128*128*128) throw new PacketFormatException("Malformed remaining length");
				multiplier *= 128;
			} while ((digit & 128) != 0);

			if (_readBuffer.length < header.length + pos + 1) return; // not enough data

			// we've got the whole packet to handle
			_packetBuffer.length = 1 + pos + header.length; // packet type byte + remaining size bytes + remaining size
			_readBuffer.read(_packetBuffer); // read whole packet from read buffer

			handlePacket(header.type);
		}
	}

	/// Deserializes the complete packet stored in `_packetBuffer` and passes it to its handler
	void handlePacket(PacketType ptype)
	{
		import mqttd.serialization;

		with (PacketType)
		{
			switch (ptype)
			{
				case CONNACK:
					onConnAck(_packetBuffer.deserialize!ConnAck());
					break;
				case PINGRESP:
					onPingResp(_packetBuffer.deserialize!PingResp());
					break;
				case PUBACK:
					onPubAck(_packetBuffer.deserialize!PubAck());
					break;
				case PUBREC:
					onPubRec(_packetBuffer.deserialize!PubRec());
					break;
				case PUBREL:
					onPubRel(_packetBuffer.deserialize!PubRel());
					break;
				case PUBCOMP:
					onPubComp(_packetBuffer.deserialize!PubComp());
					break;
				case PUBLISH:
					onPublish(_packetBuffer.deserialize!Publish());
					break;
				case SUBACK:
					onSubAck(_packetBuffer.deserialize!SubAck());
					break;
				case UNSUBACK:
					onUnsubAck(_packetBuffer.deserialize!UnsubAck());
					break;
				default:
					throw new Exception(format("Unexpected packet type '%s'", ptype));
			}
		}
	}

	/// loop to receive packets
	void listener()
	in { assert(_con && _con.connected); }
	body
	{
		version(MqttDebug) logDebug("MQTT Entering listening loop");

		auto buffer = new ubyte[4096];

		while (_con.connected && !_con.empty)
		{
			auto size = cast(size_t)_con.leastSize;
			if (size > 0)
			{
				if (size > buffer.length) size = buffer.length;
				_con.read(buffer[0..size]);
				proccessData(buffer[0..size]);
			}
		}

		if (!_con.connected) callOnDisconnect();

		version(MqttDebug) logDebug("MQTT Exiting listening loop");
	}

	/**
	 * loop to dispatch in session stored packets
	 *
	 * Wakes up on every `true` message (new packet or acknowledgement) or after
	 * `Settings.retryDelay`; a `false` message ends the loop.
	 */
	void dispatcher()
	in { assert(_con && _con.connected); }
	body
	{
		version(MqttDebug) logDebug("MQTT Entering dispatch loop");

		bool exit = false;
		while (_con.connected && !exit)
		{
			// wait for info about change or timeout
			receiveTimeout(_settings.retryDelay.msecs,
				(bool msg) { exit = !msg; });

			// handle packets in order until one has to wait for a response
			while (_session.packetCount > 0 && _con.connected && dispatchFront()) {}
		}

		if (!_con.connected) callOnDisconnect();

		version(MqttDebug) logDebug("MQTT Exiting dispatch loop");
	}

	/**
	 * Handles the packet at the front of the session.
	 *
	 * Returns: true when the front packet was removed from the session, so the next one
	 *          can be handled right away; false when it has to stay at the front
	 */
	bool dispatchFront()
	{
		auto ctx = &_session.front();
		final switch (ctx.state)
		{
			case PacketState.queuedQos0: // just send it
				assert(ctx.packetType == PacketType.PUBLISH);
				send(ctx.publish);
				_session.popFront(); // remove it from session
				return true;
			case PacketState.queuedQos1:
				//treat the Packet as "unacknowledged" until the corresponding PUBACK packet received
				assert(ctx.packetType == PacketType.PUBLISH);
				assert(ctx.publish.header.qos == QoSLevel.QoS1);
				send(ctx.publish);
				ctx.state = PacketState.waitForPuback; // change to next state
				break;
			case PacketState.queuedQos2:
				//TODO
				break;
			case PacketState.waitForSuback:
				assert(ctx.packetType == PacketType.SUBSCRIBE);
				if (ctx.attempt > 0)
				{
					// already sent: resend only once the retry delay elapsed, not on every wake-up
					if (Clock.currTime - ctx.timestamp < _settings.retryDelay.msecs) break;
					if (cast(long) ctx.attempt > _settings.retryAttempts)
					{
						logWarn("MQTT No SUBACK for packet %s received, giving up", ctx.packetId);
						_session.popFront(); // drop it so packets queued behind it are not blocked
						return true;
					}
				}
				send(ctx.subscribe);
				ctx.attempt++;
				ctx.timestamp = Clock.currTime;
				break;
			case PacketState.sendPuback, PacketState.sendPubcomp, PacketState.sendPubrec,
				PacketState.sendPubrel, PacketState.waitForPuback, PacketState.waitForPubcomp,
				PacketState.waitForPubrec, PacketState.waitForPubrel, PacketState.waitForUnsuback:
				// not handled by the dispatcher (yet)
				break;
			case PacketState.any:
				assert(0, "Invalid state");
		}
		return false;
	}

	/// Serializes `msg` and writes it to the connection (dropped when not connected)
	void send(T)(auto ref T msg) if (isMqttPacket!T)
	{
		_sendBuffer.clear(); // clear to write new
		_sendBuffer.serialize(msg);

		if (_con.connected)
		{
			version(MqttDebug) logDebug("MQTT OUT: %(%.02x %)", _sendBuffer.data);
			_con.write(_sendBuffer.data);
		}
	}

	/// Calls `onDisconnect` at most once per connection
	void callOnDisconnect()
	{
		if (!onDisconnectCalled)
		{
			onDisconnectCalled = true;
			onDisconnect();
		}
	}
}
670 
unittest
{
	Session s;

	auto pub = Publish();
	auto id = s.add(pub, PacketState.waitForPuback);

	assert(s.packetCount == 1);
	assert(pub.packetId == id); // the id is written back to the (lvalue) packet

	PacketContext ctx;
	size_t idx;
	assert(id != 0);
	assert(s.canFind(id, ctx, idx));
	assert(idx == 0);
	assert(s.packetCount == 1);

	assert(ctx.packetType == PacketType.PUBLISH);
	assert(ctx.state == PacketState.waitForPuback);
	assert(ctx.attempt == 0);
	assert(ctx.publish != Publish.init);
	assert(ctx.timestamp != SysTime.init);

	// state filter and unknown ids
	assert(!s.canFind(id, ctx, idx, PacketState.waitForSuback));
	assert(!s.canFind(cast(ushort)(id + 1), ctx, idx));
	assert(s.canFind(id, ctx, idx, PacketState.waitForPuback));

	s.removeAt(idx);
	assert(s.packetCount == 0);

	// QoS 0 publish keeps its packet id
	auto pub0 = Publish();
	assert(s.add(pub0, PacketState.queuedQos0) == Publish.init.packetId);
	assert(s.packetCount == 1);
	assert(s.front.packetType == PacketType.PUBLISH);
	assert(s.front.state == PacketState.queuedQos0);
	s.popFront();
	assert(s.packetCount == 0);

	// subscribe packets get a fresh id
	auto sub = Subscribe();
	auto subId = s.add(sub, PacketState.waitForSuback);
	assert(subId != 0 && subId != id);
	assert(s.canFind(subId, ctx, idx, PacketState.waitForSuback));
	assert(ctx.packetType == PacketType.SUBSCRIBE);

	s.clear();
	assert(s.packetCount == 0);

	// packet ids wrap around and never become 0
	s._packetId = cast(ushort)(ushort.max - 1);
	assert(s.nextPacketId == ushort.max);
	assert(s.nextPacketId == 1);
}