1 /**
2  * Serialization and deserialization of MQTT protocol messages
3  *
4  * Author:
5  * Tomáš Chaloupka <chalucha@gmail.com>
6  *
7  * License:
8  * Boost Software License 1.0 (BSL-1.0)
9  *
10  * Permission is hereby granted, free of charge, to any person or organization obtaining a copy
11  * of the software and accompanying documentation covered by this license (the "Software") to use,
12  * reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative
13  * works of the Software, and to permit third-parties to whom the Software is furnished to do so,
14  * all subject to the following:
15  *
16  * The copyright notices in the Software and this entire statement, including the above license
17  * grant, this restriction and the following disclaimer, must be included in all copies of the Software,
18  * in whole or in part, and all derivative works of the Software, unless such copies or derivative works
19  * are solely in the form of machine-executable object code generated by a source language processor.
20  *
21  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
22  * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
23  * PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE
24  * DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT,
25  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
26  * OTHER DEALINGS IN THE SOFTWARE.
27  */
28 module mqttd.serialization;
29 
30 import std..string : format;
31 import std.range;
32 
33 import mqttd.messages;
34 import mqttd.traits;
35 
36 debug import std.stdio;
37 
38 @safe:
39 
/// Serializes `item` into `output`.
///
/// Returns: the Serializer wrapping `output`, so callers can reach its `data`/`clear`
/// members when the output range provides them.
auto serialize(R, T)(auto ref R output, ref T item) if (canSerializeTo!(R))
{
	auto writer = serializer(output);
	writer.serialize(item);
	return writer;
}
46 
/// Creates a Serializer writing into `output`.
auto serializer(R)(auto ref R output) if (canSerializeTo!(R))
{
	auto wrapped = Serializer!R(output);
	return wrapped;
}
51 
/// Writes MQTT packets into the output range `R`.
///
/// Integers wider than a byte are written big-endian. Strings are prefixed with their
/// 2-byte length, and packets are written member by member in declaration order.
struct Serializer(R) if (canSerializeTo!(R))
{
	/// Wraps the given output range.
	this(R output)
	{
		_output = output;
	}

	/// Writes a single byte to the underlying output.
	@safe
	void put(in ubyte val)
	{
		_output.put(val);
	}

	/// Writes a byte slice to the underlying output.
	@safe
	void put(in ubyte[] val)
	{
		_output.put(val);
	}

	static if(__traits(hasMember, R, "data"))
	{
		/// Data written so far (only available when `R` has a `data` member).
		@safe
		@property auto data() nothrow
		{
			return _output.data();
		}
	}

	static if(__traits(hasMember, R, "clear"))
	{
		/// Clears the underlying output (only available when `R` has a `clear` member).
		void clear()
		{
			_output.clear();
		}
	}

	/// Serialize given Mqtt packet.
	///
	/// Computes the remaining length and stores it in `item.header.length`. It then
	/// validates the packet and writes every member present on the wire.
	///
	/// Throws: PacketFormatException when the packet fails validation or its remaining
	///         length cannot be encoded.
	void serialize(T)(ref T item) if (isMqttPacket!T)
	{
		static assert(hasFixedHeader!T, format("'%s' packet has no required header field!", T.stringof));

		mixin processMembersTemplate!(uint, `res += f.itemLength;`) L;
		mixin processMembersTemplate!(uint, `write(f);`) W;

		//set remaining packet length by checking packet conditions
		item.header.length = L.processMembers(item);

		static if (__traits(hasMember, R, "reserve")) // we can reserve required size to serialize packet
		{
			// 5 = max fixed header size (1 type/flags byte + up to 4 remaining length bytes)
			_output.reserve(item.header.length + 5);
		}

		//check if is valid
		try item.validate();
		catch (Exception ex)
			throw new PacketFormatException(format("'%s' packet is not valid: %s", T.stringof, ex.msg), ex);

		//write members to output writer
		W.processMembers(item);
	}

	/// Writes a single value of type `T` to the output.
	///
	/// Throws: Exception when a string is longer than 65535 bytes.
	///         PacketFormatException when a fixed header's remaining length needs more than 4 bytes.
	package ref Serializer write(T)(T val) if (canWrite!T)
	{
		import std.exception : enforce;
		import std.traits : isDynamicArray;

		bool handled = true;
		static if (is(T == FixedHeader)) // first to avoid implicit conversion to ubyte
		{
			put(val.flags);

			// Remaining length: 7 bits per byte, least significant group first, high bit set
			// on every byte except the last. At most 4 bytes (268435455) can be encoded.
			uint remaining = val.length;
			if (remaining > 268_435_455)
				throw new PacketFormatException(format("Remaining length %s exceeds the maximum of 268435455", remaining));
			do
			{
				ubyte digit = cast(ubyte)(remaining % 128);
				remaining /= 128;
				if (remaining > 0) digit |= 0x80;
				put(digit);
			} while (remaining > 0);
		}
		else static if (is(T:ubyte))
		{
			put(val);
		}
		else static if (is(T:ushort))
		{
			put(cast(ubyte) (val >> 8));
			put(cast(ubyte) val);
		}
		else static if (is(T:string))
		{
			import std.string : representation;

			// strings carry a 2-byte length prefix, so at most ushort.max bytes fit
			enforce(val.length <= ushort.max,
				format("String too long: %s bytes (max %s)", val.length, ushort.max));

			write((cast(ushort)val.length));
			put(val.representation);
		}
		else static if (isDynamicArray!T)
		{
			static if (is(ElementType!T == ubyte)) put(val);
			else foreach(ret; val) write(ret);
		}
		else static if (is(T == struct)) //write struct members individually
		{
			foreach(memberName; __traits(allMembers, T))
				write(__traits(getMember, val, memberName));
		}
		else
		{
			handled = false;
		}

		if (handled) return this;
		assert(0, "Not implemented write for: " ~ T.stringof);
	}

private:
	R _output;
}
172 
/// Deserializes a packet of type `T` from `input`, e.g. `deserialize!PacketType(bytes)`.
template deserialize(T)
{
	auto deserialize(R)(auto ref R input) if (canDeserializeFrom!(R))
	{
		auto reader = deserializer(input);
		return reader.deserialize!T();
	}
}
180 
/// Creates a Deserializer reading from `input`.
auto deserializer(R)(auto ref R input) if (canDeserializeFrom!(R))
{
	auto reader = Deserializer!R(input);
	return reader;
}
185 
/// Reads MQTT packets from the input range `R`.
///
/// The deserializer is itself an input range of bytes (`front`/`empty`/`popFront`).
/// It tracks how many bytes of the fixed header's remaining length are still unread.
struct Deserializer(R) if (canDeserializeFrom!(R))
{
	/// Wraps the given input range.
	this(R input)
	{
		_input = input;
	}

	/// Current input byte.
	@property ubyte front()
	{
		//debug writef("%.02x ", _input.front);
		return cast(ubyte)_input.front;
	}

	/// Whether the input is exhausted.
	@property bool empty()
	{
		//debug if (_input.empty) writeln();
		return _input.empty;
	}

	/// Advances one byte and decrements the remaining length (never below zero).
	void popFront()
	{
		_input.popFront();
		if(_remainingLen > 0) _remainingLen--; //decrease remaining length set from fixed header
		//debug writefln("Pop: %s", empty? "empty" : format("%.02x", front));
	}

	/// Deserializes a complete packet of type `T`.
	///
	/// The packet must consume all of the input; the result is then validated.
	///
	/// Throws: Exception when data remain after the packet.
	///         PacketFormatException when the data are malformed or truncated, or the packet is not valid.
	T deserialize(T)() if (isMqttPacket!T)
	{
		import std.exception : enforce;

		static assert(hasFixedHeader!T, format("'%s' packet has no required header field!", T.stringof));

		// read every member present on the wire, in declaration order
		mixin processMembersTemplate!(void, "item.tupleof[i] = read!(typeof(f))();");

		T res;

		processMembers(res);

		enforce(empty, "Some data are remaining after packet deserialization!");

		// validate initialized packet
		try res.validate();
		catch (Exception ex)
			throw new PacketFormatException(format("'%s' packet is not valid: %s", T.stringof, ex.msg), ex);

		return res;
	}

	/// Reads a single value of type `T` from the input.
	///
	/// Wire formats by type:
	/// - integers are big-endian;
	/// - strings are prefixed with a 2-byte length;
	/// - dynamic arrays consume the rest of the packet's remaining length;
	/// - structs are read member by member.
	///
	/// Throws: PacketFormatException on truncated input or a malformed remaining length.
	package T read(T)() if (canRead!T)
	{
		import std.traits : isDynamicArray;

		auto handled = true;
		T res;

		static if (is(T == FixedHeader)) // first to avoid implicit conversion to ubyte
		{
			res.flags = read!ubyte();
			res.length = 0;

			// Remaining length: 7 bits per byte, least significant group first; the high bit
			// marks that another byte follows. At most 4 bytes are allowed.
			uint multiplier = 1;
			ubyte digit;
			do
			{
				// Checked before reading, so a valid 4-byte length is accepted.
				// Only a continuation bit on the 4th byte is rejected.
				if (multiplier > 128*128*128) throw new PacketFormatException("Malformed remaining length");
				digit = read!ubyte();
				res.length += ((digit & 127) * multiplier);
				multiplier *= 128;
			} while ((digit & 128) != 0);

			//set remaining length for calculations
			_remainingLen = res.length;
		}
		else static if (is(T:ubyte))
		{
			if (empty) throw new PacketFormatException("Unexpected end of packet data");
			res = cast(T)front;
			popFront();
		}
		else static if (is(T:ushort))
		{
			res = cast(ushort) (read!ubyte() << 8);
			res |= cast(ushort) read!ubyte();
		}
		else static if (is(T:string))
		{
			import std.array : array;
			import std.algorithm : map;

			auto length = read!ushort();
			static if(hasSlicing!R)
			{
				if (_input.length < length) throw new PacketFormatException("Unexpected end of packet data");
				res = (cast(char[])_input[0..length]).idup;
				// saturate like popFront does, so a bogus length cannot wrap the counter
				_remainingLen = _remainingLen > length ? _remainingLen - length : 0;
				_input = _input.length > length ? _input[length..$] : R.init;
			}
			else
			{
				auto chars = (&this).take(length).map!(a => cast(immutable char)a).array;
				if (chars.length != length) throw new PacketFormatException("Unexpected end of packet data");
				res = chars;
			}
		}
		else static if (isDynamicArray!T)
		{
			res = T.init;
			static if (is(ElementType!T == ubyte) && hasSlicing!R) //slice it
			{
				// Take only the bytes still belonging to this packet, as the element-wise path
				// below does. Trailing data stays in the input so `deserialize` can reject it.
				immutable size_t count = _input.length < _remainingLen ? _input.length : _remainingLen;
				res = _input[0..count];
				_remainingLen -= cast(uint)count;
				_input = _input[count..$];
			}
			else
			{
				while(_remainingLen > 0) // read to end
				{
					res ~= read!(ElementType!T)();
				}
			}
		}
		else static if (is(T == struct)) //read struct members individually
		{
			foreach(memberName; __traits(allMembers, T))
				__traits(getMember, res, memberName) = read!(typeof(__traits(getMember, res, memberName)))();
		}
		else
		{
			handled = false;
		}

		if (handled) return res;
		assert(0, "Not implemented read for: " ~ T.stringof);
	}

private:
	R _input;
	uint _remainingLen;
}
320 
/// Gets required buffer size to encode `item` into (fixed header contributes nothing).
///
/// Sizes by type:
/// - single byte: 1; ushort: 2;
/// - string: 2 + its length;
/// - QoS level list or byte array: one byte per element;
/// - Topic: its filter length + 3;
/// - other dynamic arrays: the sum of their elements' sizes.
@safe @nogc
uint itemLength(T)(auto ref in T item) pure nothrow
{
	import std.traits : isDynamicArray;

	// order matters: FixedHeader must be matched before the implicit ubyte conversion check
	static if (is(T == FixedHeader))
	{
		return 0;
	}
	else static if (is(T : ubyte))
	{
		return 1;
	}
	else static if (is(T : ushort))
	{
		return 2;
	}
	else static if (is(T : string))
	{
		// 2-byte length prefix followed by the characters
		return cast(uint)(item.length + 2);
	}
	else static if (is(T == QoSLevel[]))
	{
		return cast(uint)item.length;
	}
	else static if (is(T == Topic))
	{
		// presumably 2-byte length prefix + 1 QoS byte around the filter — confirm against Topic
		return cast(uint)item.filter.length + 3u;
	}
	else static if (isDynamicArray!T && is(ElementType!T == ubyte))
	{
		return cast(uint)item.length;
	}
	else static if (isDynamicArray!T)
	{
		uint total = 0;
		foreach (ref element; item)
			total += element.itemLength();
		return total;
	}
	else
	{
		assert(0, "Not implemented itemLength for " ~ T.stringof);
	}
}
345 
/// Checks `packet` against the MQTT rules implemented for its type.
///
/// Rules exist for ConnectFlags, Connect, ConnAck and Subscribe. Any type with a
/// `clientIdentifier` member also gets the empty-ClientId/CleanSession check.
/// Types without matching rules pass unchanged.
///
/// Throws: Exception (raised through `enforce`) describing the first violated rule.
@safe
void validate(T)(auto ref in T packet) pure
{
	import std.string : format;
	import std.exception : enforce;

	static if (__traits(hasMember, T, "header"))
	{
		import std.typecons : Nullable;

		// Compares the high nibble (packet type) and low nibble (header flags) of the header
		// with `value`. A nibble is skipped when `mask` has no bits set in it; otherwise the
		// header bits selected by `mask` must equal that nibble of `value`. When `length` is
		// set, the remaining length must match it exactly.
		void checkHeader(ubyte value, ubyte mask = 0xFF, Nullable!uint length = Nullable!uint())
		{
			enforce((mask & 0xF0) == 0x00
				|| (packet.header & 0xF0 & mask) == (value & 0xF0),
				"Wrong packet type");

			enforce((mask & 0x0F) == 0x00
				|| (packet.header & 0x0F & mask) == (value & 0x0F),
				"Wrong fixed header flags");

			enforce(length.isNull || packet.header.length == length.get, "Wrong fixed header length");
		}
	}

	static if (__traits(hasMember, T, "clientIdentifier"))
	{
		if (packet.clientIdentifier.length == 0)
			enforce(packet.flags.cleanSession, "If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1");

		// note that some broker implementations MAY not support client identifiers with more than 23 encoded bytes - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349242
	}

	static if (is(T == ConnectFlags))
	{
		enforce(packet.will || (packet.willQoS == QoSLevel.QoS0 && !packet.willRetain),
			"WillQoS and Will Retain MUST be 0 if Will flag is not set");
		enforce(packet.userName || !packet.password, "Password MUST be set to 0 if User flag is 0");
	}
	else static if (is(T == Connect))
	{
		checkHeader(0x10);
		enforce(packet.header.length != 0, "Length must be set!");
		enforce(packet.protocolName == MQTT_PROTOCOL_NAME,
			format("Wrong protocol name '%s', must be '%s'", packet.protocolName, MQTT_PROTOCOL_NAME));
		enforce(packet.protocolLevel == MQTT_PROTOCOL_LEVEL_3_1_1,
			format("Unsuported protocol level '%d', must be '%d' (v3.1.1)", packet.protocolLevel, MQTT_PROTOCOL_LEVEL_3_1_1));
		packet.flags.validate();
		enforce(!packet.flags.userName || packet.userName.length > 0, "Username not set");
		// a password flag is only allowed together with the user name flag
		enforce(packet.flags.userName || !packet.flags.password, "Username not set, but password is");
	}
	else static if (is(T == ConnAck))
	{
		checkHeader(0x20, 0xFF, Nullable!uint(0x02));
		enforce(packet.flags <= 1, "Invalid Connect Acknowledge Flags");
		enforce(packet.returnCode <= 5, "Invalid return code");
	}
	else static if(is(T == Subscribe))
	{
		checkHeader(0x82, 0xFF);
		enforce(packet.topics.length > 0, "At least one topic filter MUST be provided");
		enforce(packet.header.length >= 5, "Invalid length");
	}
}
411 
/// Mixin providing `processMembers`, which walks the fields of an MQTT packet in declaration
/// order and runs the code string `fn` for each field present on the wire.
///
/// Names visible inside `fn`:
/// - `item`: the packet;
/// - `f`: a copy of the current field;
/// - `i`: the current field index;
/// - `res`: an accumulator of type `R`, only when `R` is not `void`.
///
/// `processMembers` returns `res` when there is one.
///
/// Skipped fields:
/// - Connect: willTopic/willMessage unless the Will flag is set, userName/password
///   unless their flags are set;
/// - Publish: packetId at QoS 0.
mixin template processMembersTemplate(R, string fn)
{
	private R processMembers(T)(ref T item) if (isMqttPacket!T)
	{
		enum hasReturn = !is(R == void);

		static if (hasReturn)
		{
			R res;
		}

		foreach(i, f; item.tupleof)
		{
			enum memberName = __traits(identifier, T.tupleof[i]);
			static if (is(T == Connect)) //special case for Connect packet
			{
				// optional Connect fields are only processed when their flag is set
				static if (memberName == "willTopic" || memberName == "willMessage")
				{
					if (!item.flags.will) continue;
				}
				else static if (memberName == "userName") { if (!item.flags.userName) continue; }
				else static if (memberName == "password") { if (!item.flags.password) continue; }
			}
			else static if(is(T == Publish)) //special case for Publish packet
			{
				// packetId is skipped at QoS 0
				static if (memberName == "packetId") if (item.header.qos == QoSLevel.QoS0) continue;
			}

			//debug writeln("processing ", memberName);
			mixin(fn);
		}

		static if (hasReturn)
		{
			return res;
		}
	}
}