/**
 * Serialization and deserialization of MQTT protocol messages
 *
 * Author:
 * Tomáš Chaloupka <chalucha@gmail.com>
 *
 * License:
 * Boost Software License 1.0 (BSL-1.0)
 *
 * Permission is hereby granted, free of charge, to any person or organization obtaining a copy
 * of the software and accompanying documentation covered by this license (the "Software") to use,
 * reproduce, display, distribute, execute, and transmit the Software, and to prepare derivative
 * works of the Software, and to permit third-parties to whom the Software is furnished to do so,
 * all subject to the following:
 *
 * The copyright notices in the Software and this entire statement, including the above license
 * grant, this restriction and the following disclaimer, must be included in all copies of the Software,
 * in whole or in part, and all derivative works of the Software, unless such copies or derivative works
 * are solely in the form of machine-executable object code generated by a source language processor.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
 * PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR ANYONE
 * DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 */
module mqttd.serialization;

import std.string : format;
import std.range;

import mqttd.messages;
import mqttd.traits;

debug import std.stdio;

@safe:

/// Largest value the fixed header "remaining length" field can hold (4 bytes x 7 bits).
private enum uint MAX_REMAINING_LENGTH = 268_435_455;

/**
 * Serializes `item` into `output`.
 *
 * Returns: the `Serializer` wrapping `output` (gives access to `data`/`clear` when `R` has them).
 */
auto serialize(R, T)(auto ref R output, ref T item) if (canSerializeTo!(R))
{
    auto ser = serializer(output);
    ser.serialize(item);
    return ser;
}

/// Creates a `Serializer` that writes to `output`.
auto serializer(R)(auto ref R output) if (canSerializeTo!(R))
{
    return Serializer!R(output);
}

/**
 * Writes MQTT packets to the output `R`.
 *
 * `data` and `clear` are forwarded only when `R` provides members of that name.
 */
struct Serializer(R) if (canSerializeTo!(R))
{
    this(R output)
    {
        _output = output;
    }

    /// Writes a single byte to the output.
    @safe
    void put(in ubyte val)
    {
        _output.put(val);
    }

    /// Writes a byte array to the output as-is (no length prefix).
    @safe
    void put(in ubyte[] val)
    {
        _output.put(val);
    }

    static if(__traits(hasMember, R, "data"))
    {
        /// Returns the output's `data()`.
        @safe
        @property auto data() nothrow
        {
            return _output.data();
        }
    }

    static if(__traits(hasMember, R, "clear"))
    {
        /// Calls the output's `clear()`.
        void clear()
        {
            _output.clear();
        }
    }

    /**
     * Serialize given Mqtt packet.
     *
     * Computes the remaining length and stores it in `item.header.length`,
     * validates the packet and then writes its members to the output.
     *
     * Throws: `PacketFormatException` when the packet is not valid or its
     *         remaining length cannot be encoded.
     */
    void serialize(T)(ref T item) if (isMqttPacket!T)
    {
        static assert(hasFixedHeader!T, format("'%s' packet has no required header field!", T.stringof));

        mixin processMembersTemplate!(uint, `res += f.itemLength;`) L;
        mixin processMembersTemplate!(uint, `write(f);`) W;

        //set remaining packet length by checking packet conditions
        item.header.length = L.processMembers(item);

        static if (__traits(hasMember, R, "reserve")) // we can reserve required size to serialize packet
        {
            _output.reserve(item.header.length + 4); // 4 = max header size
        }

        //check if is valid
        try item.validate();
        catch (Exception ex)
            throw new PacketFormatException(format("'%s' packet is not valid: %s", T.stringof, ex.msg), ex);

        //write members to output writer
        W.processMembers(item);
    }

    /**
     * Writes one value in MQTT wire format:
     * - `FixedHeader`: flags byte, then the remaining length in the variable
     *   length encoding (7 bits per byte, bit 7 = continuation)
     * - types convertible to `ubyte`: one byte
     * - types convertible to `ushort`: two bytes, high byte first
     * - `string`: `ushort` length prefix followed by the string bytes
     * - dynamic arrays: `ubyte[]` written raw, other element types one by one
     * - structs: every member listed by `__traits(allMembers)`
     *
     * Throws: `PacketFormatException` when a remaining length exceeds
     *         268,435,455; `Exception` when a string exceeds 65535 bytes.
     */
    package ref Serializer write(T)(T val) if (canWrite!T)
    {
        import std.traits : isDynamicArray;

        bool handled = true;
        static if (is(T == FixedHeader)) // first to avoid implicit conversion to ubyte
        {
            put(val.flags);

            uint remaining = val.length;
            // anything larger would need a 5th length byte, which MQTT forbids
            if (remaining > MAX_REMAINING_LENGTH)
                throw new PacketFormatException(format("Remaining length %d exceeds maximum of %d",
                    remaining, MAX_REMAINING_LENGTH));

            do
            {
                ubyte digit = cast(ubyte)(remaining % 128);
                remaining /= 128;
                if (remaining > 0) digit |= 0x80; // more length bytes follow
                put(digit);
            } while (remaining > 0);
        }
        else static if (is(T:ubyte))
        {
            put(val);
        }
        else static if (is(T:ushort))
        {
            put(cast(ubyte) (val >> 8));
            put(cast(ubyte) val);
        }
        else static if (is(T:string))
        {
            import std.string : representation;
            import std.exception : enforce;

            // the length prefix is a ushort, so strings are limited to 65535 bytes
            enforce(val.length <= ushort.max,
                format("String too long: %d bytes (max %d)", val.length, ushort.max));

            write((cast(ushort)val.length));
            put(val.representation);
        }
        else static if (isDynamicArray!T)
        {
            static if (is(ElementType!T == ubyte)) put(val);
            else foreach(ret; val) write(ret);
        }
        else static if (is(T == struct)) //write struct members individually
        {
            foreach(memberName; __traits(allMembers, T))
                write(__traits(getMember, val, memberName));
        }
        else
        {
            handled = false;
        }

        if (handled) return this;
        assert(0, "Not implemented write for: " ~ T.stringof);
    }

private:
    R _output;
}

/// Deserializes a packet of type `T` from `input`.
template deserialize(T)
{
    auto deserialize(R)(auto ref R input) if (canDeserializeFrom!(R))
    {
        return deserializer(input).deserialize!T();
    }
}

/// Creates a `Deserializer` reading from `input`.
auto deserializer(R)(auto ref R input) if (canDeserializeFrom!(R))
{
    return Deserializer!R(input);
}

/**
 * Reads MQTT packets from the input `R`.
 *
 * It is itself an input range of `ubyte`. It also tracks how many bytes of
 * the current packet are left, as declared by the fixed header's remaining length.
 */
struct Deserializer(R) if (canDeserializeFrom!(R))
{
    this(R input)
    {
        _input = input;
    }

    /// Current input byte.
    @property ubyte front()
    {
        //debug writef("%.02x ", _input.front);
        return cast(ubyte)_input.front;
    }

    /// True when the input is exhausted.
    @property bool empty()
    {
        //debug if (_input.empty) writeln();
        return _input.empty;
    }

    /// Advances the input and decrements the remaining packet length (never below 0).
    void popFront()
    {
        _input.popFront();
        if(_remainingLen > 0) _remainingLen--; //decrease remaining length set from fixed header
        //debug writefln("Pop: %s", empty? "empty" : format("%.02x", front));
    }

    /**
     * Reads and validates a packet of type `T`.
     *
     * Throws:
     *   `PacketFormatException` when the data are truncated or malformed, or
     *   the packet fails validation.
     *   `Exception` when input remains after the packet has been read.
     */
    T deserialize(T)() if (isMqttPacket!T)
    {
        import std.exception : enforce;

        static assert(hasFixedHeader!T, format("'%s' packet has no required header field!", T.stringof));

        mixin processMembersTemplate!(void, "item.tupleof[i] = read!(typeof(f))();");

        T res;

        processMembers(res);

        enforce(empty, "Some data are remaining after packet deserialization!");

        // validate initialized packet
        try res.validate();
        catch (Exception ex)
            throw new PacketFormatException(format("'%s' packet is not valid: %s", T.stringof, ex.msg), ex);

        return res;
    }

    /**
     * Reads one value in MQTT wire format (mirror of `Serializer.write`).
     *
     * Reading a `FixedHeader` also resets the remaining packet length.
     * Dynamic arrays:
     * - `ubyte[]` over a sliceable input takes all remaining input.
     * - Other arrays read elements until the remaining packet length reaches 0.
     *
     * Throws: `PacketFormatException` on truncated input or a remaining
     *         length encoded in more than 4 bytes.
     */
    package T read(T)() if (canRead!T)
    {
        import std.traits : isDynamicArray;

        auto handled = true;
        T res;

        static if (is(T == FixedHeader)) // first to avoid implicit conversion to ubyte
        {
            res.flags = read!ubyte();
            res.length = 0;

            // variable length decoding, at most 4 bytes (MQTT 3.1.1, 2.2.3)
            uint multiplier = 1;
            ubyte digit;
            do
            {
                // check before reading, so that a 4th length byte is still accepted
                if (multiplier > 128*128*128) throw new PacketFormatException("Malformed remaining length");
                digit = read!ubyte();
                res.length += ((digit & 127) * multiplier);
                multiplier *= 128;
            } while ((digit & 128) != 0);

            //set remaining length for calculations
            _remainingLen = res.length;
        }
        else static if (is(T:ubyte))
        {
            if (empty) throw new PacketFormatException("Unexpected end of packet data");
            res = cast(T)front;
            popFront();
        }
        else static if (is(T:ushort))
        {
            res = cast(ushort) (read!ubyte() << 8);
            res |= cast(ushort) read!ubyte();
        }
        else static if (is(T:string))
        {
            import std.array : array;
            import std.algorithm : map;

            auto length = read!ushort();
            static if(hasSlicing!R)
            {
                if (_input.length < length) throw new PacketFormatException("Unexpected end of packet data");
                res = (cast(char[])_input[0..length]).idup;
                // saturate like popFront does instead of wrapping around on malformed lengths
                _remainingLen = _remainingLen > length ? _remainingLen - length : 0;
                _input = _input.length > length ? _input[length..$] : R.init;
            }
            else
            {
                // take() stops at end of input, so a short result means truncated data
                res = (&this).take(length).map!(a => cast(immutable char)a).array;
                if (res.length < length) throw new PacketFormatException("Unexpected end of packet data");
            }
        }
        else static if (isDynamicArray!T)
        {
            res = T.init;
            static if (is(ElementType!T == ubyte) && hasSlicing!R) //slice it
            {
                res = _input[0..$];
                _remainingLen -= res.length;
                _input = R.init;
            }
            else
            {
                while(_remainingLen > 0) // read to end
                {
                    res ~= read!(ElementType!T)();
                }
            }
        }
        else static if (is(T == struct)) //read struct members individually
        {
            foreach(memberName; __traits(allMembers, T))
                __traits(getMember, res, memberName) = read!(typeof(__traits(getMember, res, memberName)))();
        }
        else
        {
            handled = false;
        }

        if (handled) return res;
        assert(0, "Not implemented read for: " ~ T.stringof);
    }

private:
    R _input;
    uint _remainingLen;
}

/**
 * Gets the number of bytes `item` occupies when encoded.
 *
 * `FixedHeader` counts as 0, because the result is used to compute the
 * header's own remaining length.
 */
@safe @nogc
uint itemLength(T)(auto ref in T item) pure nothrow
{
    import std.traits : isDynamicArray;

    static if (is(T == FixedHeader)) return 0;
    else static if (is(T:ubyte)) return 1;
    else static if (is(T:ushort)) return 2;
    else static if (is(T:string)) return cast(uint)(2 + item.length);
    else static if (is(T == QoSLevel[])) return cast(uint)item.length;
    else static if (is(T == Topic)) return 3u + cast(uint)item.filter.length; // 2 B length + filter + 1 B QoS
    else static if (isDynamicArray!T)
    {
        static if (is(ElementType!T == ubyte)) return cast(uint)item.length;
        else
        {
            uint len;
            foreach(ref e; item) len += e.itemLength();
            return len;
        }
    }
    else assert(0, "Not implemented itemLength for " ~ T.stringof);
}

/**
 * Checks the packet (or `ConnectFlags`) against the constraints implemented below.
 *
 * Throws: `Exception` (via `enforce`) describing the first violated rule.
 */
@safe
void validate(T)(auto ref in T packet) pure
{
    import std.string : format;
    import std.exception : enforce;

    static if (__traits(hasMember, T, "header"))
    {
        import std.typecons : Nullable;

        // Compares the header's high nibble (packet type) and low nibble (flags) with `value`,
        // restricted to the bits set in `mask`; optionally also checks an exact length.
        void checkHeader(ubyte value, ubyte mask = 0xFF, Nullable!uint length = Nullable!uint())
        {
            enforce((mask & 0xF0) == 0x00
                || (packet.header & 0xF0 & mask) == (value & 0xF0),
                "Wrong packet type");

            enforce((mask & 0x0F) == 0x00
                || (packet.header & 0x0F & mask) == (value & 0x0F),
                "Wrong fixed header flags");

            enforce(length.isNull || packet.header.length == length.get, "Wrong fixed header length");
        }
    }

    static if (__traits(hasMember, T, "clientIdentifier"))
    {
        if (packet.clientIdentifier.length == 0)
            enforce(packet.flags.cleanSession, "If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1");

        // note that some broker implementations MAY not support client identifiers with more than 23 encoded bytes - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349242
    }

    static if (is(T == ConnectFlags))
    {
        enforce(packet.will || (packet.willQoS == QoSLevel.QoS0 && !packet.willRetain),
            "WillQoS and Will Retain MUST be 0 if Will flag is not set");
        enforce(packet.userName || !packet.password, "Password MUST be set to 0 if User flag is 0");
    }
    else static if (is(T == Connect))
    {
        checkHeader(0x10);
        enforce(packet.header.length != 0, "Length must be set!");
        enforce(packet.protocolName == MQTT_PROTOCOL_NAME,
            format("Wrong protocol name '%s', must be '%s'", packet.protocolName, MQTT_PROTOCOL_NAME));
        enforce(packet.protocolLevel == MQTT_PROTOCOL_LEVEL_3_1_1,
            format("Unsuported protocol level '%d', must be '%d' (v3.1.1)", packet.protocolLevel, MQTT_PROTOCOL_LEVEL_3_1_1));
        packet.flags.validate();
        enforce(!packet.flags.userName || packet.userName.length > 0, "Username not set");
        enforce(packet.flags.userName || !packet.flags.password, "Username not set, but password is");
    }
    else static if (is(T == ConnAck))
    {
        checkHeader(0x20, 0xFF, Nullable!uint(0x02));
        enforce(packet.flags <= 1, "Invalid Connect Acknowledge Flags");
        enforce(packet.returnCode <= 5, "Invalid return code");
    }
    else static if(is(T == Subscribe))
    {
        checkHeader(0x82, 0xFF);
        enforce(packet.topics.length > 0, "At least one topic filter MUST be provided");
        enforce(packet.header.length >= 5, "Invalid length");
    }
}

/**
 * Mixin that provides `processMembers(item)`. It evaluates the code string `fn`
 * once for every field of the packet `item`.
 *
 * Inside `fn` the following are in scope:
 * - `i`: the field index
 * - `f`: a copy of the field value
 * - `item`: the packet
 * - `res`: an accumulator of type `R`, present only when `R` is not `void`;
 *   `processMembers` returns it.
 *
 * Fields that are absent on the wire are skipped:
 * - `Connect`: `willTopic`/`willMessage` unless `flags.will` is set,
 *   `userName` unless `flags.userName` is set, and `password` unless
 *   `flags.password` is set.
 * - `Publish`: `packetId` when the header QoS is `QoS0`.
 */
mixin template processMembersTemplate(R, string fn)
{
    private R processMembers(T)(ref T item) if (isMqttPacket!T)
    {
        enum hasReturn = !is(R == void);

        static if (hasReturn)
        {
            R res;
        }

        import std.typetuple;

        foreach(i, f; item.tupleof)
        {
            enum memberName = __traits(identifier, T.tupleof[i]);
            static if (is(T == Connect)) //special case for Connect packet
            {
                static if (memberName == "willTopic" || memberName == "willMessage")
                {
                    if (!item.flags.will) continue;
                }
                else static if (memberName == "userName") { if (!item.flags.userName) continue; }
                else static if (memberName == "password") { if (!item.flags.password) continue; }
            }
            else static if(is(T == Publish)) //special case for Publish packet
            {
                static if (memberName == "packetId") if (item.header.qos == QoSLevel.QoS0) continue;
            }

            //debug writeln("processing ", memberName);
            mixin(fn);
        }

        static if (hasReturn)
        {
            return res;
        }
    }
}