diff --git a/lib/PicoMQTT/src/PicoMQTT.h b/lib/PicoMQTT/src/PicoMQTT.h new file mode 100644 index 00000000..7fa7f667 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#if defined(ESP8266) + +#if (ARDUINO_ESP8266_MAJOR != 3) || (ARDUINO_ESP8266_MINOR < 1) +#error PicoMQTT requires ESP8266 board core version >= 3.1 +#endif + +#elif defined(ESP32) + +//#if ESP_ARDUINO_VERSION < ESP_ARDUINO_VERSION_VAL(2, 0, 7) +//#error PicoMQTT requires ESP32 board core version >= 2.0.7 +//#endif + +#endif + +#include "PicoMQTT/client.h" +#include "PicoMQTT/server.h" diff --git a/lib/PicoMQTT/src/PicoMQTT/autoid.h b/lib/PicoMQTT/src/PicoMQTT/autoid.h new file mode 100644 index 00000000..a3b17643 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/autoid.h @@ -0,0 +1,21 @@ +#pragma once + +namespace PicoMQTT { + +class AutoId { + public: + typedef unsigned int Id; + + AutoId(): id(generate_id()) {} + AutoId(const AutoId &) = default; + + const Id id; + + private: + static Id generate_id() { + static Id next_id = 1; + return next_id++; + } +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/client.cpp b/lib/PicoMQTT/src/PicoMQTT/client.cpp new file mode 100644 index 00000000..ee242826 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/client.cpp @@ -0,0 +1,288 @@ +#include "client.h" +#include "debug.h" + +namespace PicoMQTT { + +BasicClient::BasicClient(unsigned long keep_alive_seconds, unsigned long socket_timeout_seconds) + : Connection(keep_alive_seconds, socket_timeout_seconds) { + TRACE_FUNCTION +} + +BasicClient::BasicClient(const ::WiFiClient & client, unsigned long keep_alive_seconds, + unsigned long socket_timeout_seconds) + : Connection(client, keep_alive_seconds, socket_timeout_seconds) { + TRACE_FUNCTION +} + +bool BasicClient::connect( + const char * host, + uint16_t port, + const char * id, + const char * user, + const char * pass, + const char * will_topic, + const char * will_message, + const size_t will_message_length, + uint8_t will_qos, + bool will_retain, + const bool clean_session, + ConnectReturnCode * connect_return_code) { + TRACE_FUNCTION + + if (connect_return_code) { + *connect_return_code = CRC_UNDEFINED; + } + + client.stop(); + + if (!client.connect(host, port)) { + return false; + } + + message_id_generator.reset(); + + const bool will = will_topic && will_message; + + const uint8_t connect_flags = + (user ? 1 : 0) << 7 + | (user && pass ? 1 : 0) << 6 + | (will && will_retain ? 1 : 0) << 5 + | (will && will_qos ? 1 : 0) << 3 + | (will ? 1 : 0) << 2 + | (clean_session ? 1 : 0) << 1; + + const size_t client_id_length = strlen(id); + const size_t will_topic_length = (will && will_topic) ? strlen(will_topic) : 0; + const size_t user_length = user ? strlen(user) : 0; + const size_t pass_length = pass ? strlen(pass) : 0; + + const size_t total_size = 6 // protocol name + + 1 // protocol level + + 1 // connect flags + + 2 // keep-alive + + client_id_length + 2 + + (will ? will_topic_length + 2 : 0) + + (will ? will_message_length + 2 : 0) + + (user ? user_length + 2 : 0) + + (user && pass ? pass_length + 2 : 0); + + auto packet = build_packet(Packet::CONNECT, 0, total_size); + packet.write_string("MQTT", 4); + packet.write_u8(4); + packet.write_u8(connect_flags); + packet.write_u16(keep_alive_millis / 1000); + packet.write_string(id, client_id_length); + + if (will) { + packet.write_string(will_topic, will_topic_length); + packet.write_string(will_message, will_message_length); + } + + if (user) { + packet.write_string(user, user_length); + if (pass) { + packet.write_string(pass, pass_length); + } + } + + if (!packet.send()) { + return false; + } + + wait_for_reply(Packet::CONNACK, [this, connect_return_code](IncomingPacket & packet) { + TRACE_FUNCTION + if (packet.size != 2) { + on_protocol_violation(); + return; + } + + /* const uint8_t connect_ack_flags = */ packet.read_u8(); + const uint8_t crc = packet.read_u8(); + + if (connect_return_code) { + *connect_return_code = (ConnectReturnCode) crc; + } + + if (crc != 0) { + // connection refused + client.stop(); + } + }); + + return client.connected(); +} + +void BasicClient::loop() { + TRACE_FUNCTION + + if (client.connected() && get_millis_since_last_write() >= keep_alive_millis) { + // ping time! + build_packet(Packet::PINGREQ).send(); + wait_for_reply(Packet::PINGRESP, [](IncomingPacket &) {}); + } + + Connection::loop(); +} + +Publisher::Publish BasicClient::begin_publish(const char * topic, const size_t payload_size, + uint8_t qos, bool retain, uint16_t message_id) { + TRACE_FUNCTION + return Publish( + *this, + client.status() ? client : PrintMux(), + topic, payload_size, + (qos >= 1) ? 1 : 0, + retain, + message_id, // dup if message_id is non-zero + message_id ? message_id : message_id_generator.generate() // generate only if message_id == 0 + ); +} + +bool BasicClient::on_publish_complete(const Publish & publish) { + TRACE_FUNCTION + if (publish.qos == 0) { + return true; + } + + bool confirmed = false; + wait_for_reply(Packet::PUBACK, [&publish, &confirmed](IncomingPacket & puback) { + confirmed |= (puback.read_u16() == publish.message_id); + }); + + return confirmed; +} + +bool BasicClient::subscribe(const String & topic, uint8_t qos, uint8_t * qos_granted) { + TRACE_FUNCTION + if (qos > 1) { + return false; + } + + const size_t topic_size = topic.length(); + const uint16_t message_id = message_id_generator.generate(); + + auto packet = build_packet(Packet::SUBSCRIBE, 0b0010, 2 + 2 + topic_size + 1); + packet.write_u16(message_id); + packet.write_string(topic.c_str(), topic_size); + packet.write_u8(qos); + packet.send(); + + uint8_t code = 0x80; + + wait_for_reply(Packet::SUBACK, [this, message_id, &code](IncomingPacket & packet) { + if (packet.read_u16() != message_id) { + on_protocol_violation(); + } else { + code = packet.read_u8(); + } + }); + + if (code == 0x80) { + return false; + } + + if (qos_granted) { + *qos_granted = code; + } + + return client.connected(); +} + +bool BasicClient::unsubscribe(const String & topic) { + TRACE_FUNCTION + + const size_t topic_size = topic.length(); + const uint16_t message_id = message_id_generator.generate(); + + auto packet = build_packet(Packet::UNSUBSCRIBE, 0b0010, 2 + 2 + topic_size); + packet.write_u16(message_id); + packet.write_string(topic.c_str(), topic_size); + packet.send(); + + wait_for_reply(Packet::UNSUBACK, [this, message_id](IncomingPacket & packet) { + if (packet.read_u16() != message_id) { + on_protocol_violation(); + } + }); + + return client.connected(); +} + +Client::Client(const char * host, uint16_t port, const char * id, const char * user, const char * password, + unsigned long reconnect_interval_millis) + : host(host), port(port), client_id(id), username(user), password(password), + will({"", "", 0, false}), +reconnect_interval_millis(reconnect_interval_millis), +last_reconnect_attempt(millis() - reconnect_interval_millis) { + TRACE_FUNCTION +} + +Client::SubscriptionId Client::subscribe(const String & topic_filter, MessageCallback callback) { + TRACE_FUNCTION + BasicClient::subscribe(topic_filter); + return SubscribedMessageListener::subscribe(topic_filter, callback); +} + +void Client::unsubscribe(const String & topic_filter) { + TRACE_FUNCTION + BasicClient::unsubscribe(topic_filter); + SubscribedMessageListener::unsubscribe(topic_filter); +} + +void Client::on_message(const char * topic, IncomingPacket & packet) { + SubscribedMessageListener::fire_message_callbacks(topic, packet); +} + +void Client::loop() { + TRACE_FUNCTION + if (!client.connected()) { + if (host.isEmpty() || !port) { + return; + } + + if (millis() - last_reconnect_attempt < reconnect_interval_millis) { + return; + } + + const bool connection_established = connect(host.c_str(), port, + client_id.isEmpty() ? "" : client_id.c_str(), + username.isEmpty() ? nullptr : username.c_str(), + password.isEmpty() ? nullptr : password.c_str(), + will.topic.isEmpty() ? nullptr : will.topic.c_str(), + will.payload.isEmpty() ? nullptr : will.payload.c_str(), + will.payload.isEmpty() ? 0 : will.payload.length(), + will.qos, will.retain); + + last_reconnect_attempt = millis(); + + if (!connection_established) { + return; + } + + for (const auto & kv : subscriptions) { + BasicClient::subscribe(kv.first.c_str()); + } + + on_connect(); + } + + BasicClient::loop(); +} + +void Client::on_connect() { + TRACE_FUNCTION + BasicClient::on_connect(); + if (connected_callback) { + connected_callback(); + } +} + +void Client::on_disconnect() { + TRACE_FUNCTION + BasicClient::on_disconnect(); + if (disconnected_callback) { + connected_callback(); + } +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/client.h b/lib/PicoMQTT/src/PicoMQTT/client.h new file mode 100644 index 00000000..4d0ab542 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/client.h @@ -0,0 +1,82 @@ +#pragma once + +#include + +#include "connection.h" +#include "incoming_packet.h" +#include "outgoing_packet.h" +#include "pico_interface.h" +#include "publisher.h" +#include "subscriber.h" + +namespace PicoMQTT { + +class BasicClient: public PicoMQTTInterface, public Connection, public Publisher { + public: + BasicClient(unsigned long keep_alive_seconds = 60, unsigned long socket_timeout_seconds = 10); + + BasicClient(const ::WiFiClient & client, unsigned long keep_alive_seconds = 60, + unsigned long socket_timeout_seconds = 10); + + bool connect( + const char * host, uint16_t port = 1883, + const char * id = "", const char * user = nullptr, const char * pass = nullptr, + const char * will_topic = nullptr, const char * will_message = nullptr, + const size_t will_message_length = 0, uint8_t willQos = 0, bool willRetain = false, + const bool cleanSession = true, + ConnectReturnCode * connect_return_code = nullptr); + + using Publisher::begin_publish; + virtual Publish begin_publish(const char * topic, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override; + + bool subscribe(const String & topic, uint8_t qos = 0, uint8_t * qos_granted = nullptr); + bool unsubscribe(const String & topic); + + void loop() override; + + virtual void on_connect() {} + + private: + virtual bool on_publish_complete(const Publish & publish) override; +}; + +class Client: public BasicClient, public SubscribedMessageListener { + public: + Client(const char * host = nullptr, uint16_t port = 1883, const char * id = nullptr, const char * user = nullptr, + const char * password = nullptr, unsigned long reconnect_interval_millis = 5 * 1000); + + using SubscribedMessageListener::subscribe; + virtual SubscriptionId subscribe(const String & topic_filter, MessageCallback callback) override; + virtual void unsubscribe(const String & topic_filter) override; + + virtual void loop() override; + + String host; + uint16_t port; + + String client_id; + String username; + String password; + + struct { + String topic; + String payload; + uint8_t qos; + bool retain; + } will; + + unsigned long reconnect_interval_millis; + + std::function connected_callback; + std::function disconnected_callback; + + virtual void on_connect() override; + virtual void on_disconnect() override; + + protected: + unsigned long last_reconnect_attempt; + virtual void on_message(const char * topic, IncomingPacket & packet) override; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/client_wrapper.cpp b/lib/PicoMQTT/src/PicoMQTT/client_wrapper.cpp new file mode 100644 index 00000000..c2c37b1a --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/client_wrapper.cpp @@ -0,0 +1,119 @@ +#include "Arduino.h" + +#include "client_wrapper.h" + +#include "debug.h" + +namespace PicoMQTT { + +ClientWrapper::ClientWrapper(unsigned long socket_timeout_seconds): socket_timeout_millis( + socket_timeout_seconds * 1000) { + TRACE_FUNCTION +} + +ClientWrapper::ClientWrapper(const ::WiFiClient & client, unsigned long socket_timeout_seconds): + WiFiClient(client), socket_timeout_millis(socket_timeout_seconds * 1000) { + TRACE_FUNCTION +} + +// reads +int ClientWrapper::available_wait(unsigned long timeout) { + TRACE_FUNCTION + const unsigned long start_millis = millis(); + while (true) { + const int ret = available(); + if (ret > 0) { + return ret; + } + if (!status()) { + // A disconnected client might still have unread data waiting in buffers. Don't move this check earlier. + return 0; + } + const unsigned long elapsed = millis() - start_millis; + if (elapsed > timeout) { + return 0; + } + yield(); + } +} + +int ClientWrapper::read(uint8_t * buf, size_t size) { + TRACE_FUNCTION + const unsigned long start_millis = millis(); + size_t ret = 0; + + while (ret < size) { + const unsigned long now_millis = millis(); + const unsigned long elapsed_millis = now_millis - start_millis; + + if (elapsed_millis > socket_timeout_millis) { + // timeout + abort(); + break; + } + + const unsigned long remaining_millis = socket_timeout_millis - elapsed_millis; + + const int available_size = available_wait(remaining_millis); + if (available_size <= 0) { + // timeout + abort(); + break; + } + + const int chunk_size = size - ret < (size_t) available_size ? size - ret : (size_t) available_size; + + const int bytes_read = WiFiClient::read(buf + ret, chunk_size); + if (bytes_read <= 0) { + // connection error + abort(); + break; + } + + ret += bytes_read; + } + + return ret; +} + +int ClientWrapper::read() { + TRACE_FUNCTION + if (!available_wait(socket_timeout_millis)) { + return -1; + } + return WiFiClient::read(); +} + +int ClientWrapper::peek() { + TRACE_FUNCTION + if (!available_wait(socket_timeout_millis)) { + return -1; + } + return WiFiClient::peek(); +} + +// writes +size_t ClientWrapper::write(const uint8_t * buffer, size_t size) { + TRACE_FUNCTION + size_t ret = 0; + + while (status() && ret < size) { + const int bytes_written = WiFiClient::write(buffer + ret, size - ret); + if (bytes_written <= 0) { + // connection error + abort(); + return 0; + } + + ret += bytes_written; + } + + return ret; +} + +size_t ClientWrapper::write(uint8_t value) { + TRACE_FUNCTION + return write(&value, 1); +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/client_wrapper.h b/lib/PicoMQTT/src/PicoMQTT/client_wrapper.h new file mode 100644 index 00000000..148efec1 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/client_wrapper.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace PicoMQTT { + +class ClientWrapper: public ::WiFiClient { + public: + ClientWrapper(unsigned long socket_timeout_seconds); + ClientWrapper(const WiFiClient & client, unsigned long socket_timeout_seconds); + ClientWrapper(const ClientWrapper &) = default; + + virtual int peek() override; + virtual int read() override; + virtual int read(uint8_t * buf, size_t size) override; + virtual size_t write(const uint8_t * buffer, size_t size) override; + virtual size_t write(uint8_t value) override final; + +#ifdef ESP32 + // these methods are only available in WiFiClient on ESP8266 + uint8_t status() { return connected(); } + void abort() { stop(); } +#endif + + const unsigned long socket_timeout_millis; + + protected: + int available_wait(unsigned long timeout); +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/config.h b/lib/PicoMQTT/src/PicoMQTT/config.h new file mode 100644 index 00000000..79a6384a --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/config.h @@ -0,0 +1,29 @@ +#pragma once + +#ifndef PICOMQTT_MAX_TOPIC_SIZE +#define PICOMQTT_MAX_TOPIC_SIZE 256 +#endif + +#ifndef PICOMQTT_MAX_MESSAGE_SIZE +#define PICOMQTT_MAX_MESSAGE_SIZE 1024 +#endif + +#ifndef PICOMQTT_MAX_CLIENT_ID_SIZE +/* + * The MQTT standard requires brokers to accept client ids that are + * 1-23 chars long, but allows longer client IDs to be accepted too. + */ +#define PICOMQTT_MAX_CLIENT_ID_SIZE 64 +#endif + +#ifndef PICOMQTT_MAX_USERPASS_SIZE +#define PICOMQTT_MAX_USERPASS_SIZE 256 +#endif + +#ifndef PICOMQTT_OUTGOING_BUFFER_SIZE +#define PICOMQTT_OUTGOING_BUFFER_SIZE 128 +#endif + +// #define PICOMQTT_DEBUG + +// #define PICOMQTT_DEBUG_TRACE_FUNCTIONS diff --git a/lib/PicoMQTT/src/PicoMQTT/connection.cpp b/lib/PicoMQTT/src/PicoMQTT/connection.cpp new file mode 100644 index 00000000..54164933 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/connection.cpp @@ -0,0 +1,176 @@ +#include "config.h" +#include "connection.h" +#include "debug.h" + +namespace PicoMQTT { + +Connection::Connection(unsigned long keep_alive_seconds, unsigned long socket_timeout_seconds) : + client(socket_timeout_seconds), + keep_alive_millis(keep_alive_seconds * 1000), + last_read(millis()), last_write(millis()) { + TRACE_FUNCTION +} + +Connection::Connection(const ::WiFiClient & client, unsigned long keep_alive_seconds, + unsigned long socket_timeout_seconds) : + client(client, socket_timeout_seconds), + keep_alive_millis(keep_alive_seconds * 1000), + last_read(millis()), last_write(millis()) { + TRACE_FUNCTION +} + +OutgoingPacket Connection::build_packet(Packet::Type type, uint8_t flags, size_t length) { + TRACE_FUNCTION + last_write = millis(); + auto ret = OutgoingPacket(client, type, flags, length); + ret.write_header(); + return ret; +} + +void Connection::on_timeout() { + TRACE_FUNCTION + client.abort(); + on_disconnect(); +} + +void Connection::on_protocol_violation() { + TRACE_FUNCTION + on_disconnect(); +} + +void Connection::on_disconnect() { + TRACE_FUNCTION + client.stop(); +} + +void Connection::disconnect() { + TRACE_FUNCTION + build_packet(Packet::DISCONNECT).send(); + client.stop(); +} + +bool Connection::connected() { + TRACE_FUNCTION + return client.connected(); +} + +void Connection::wait_for_reply(Packet::Type type, std::function handler) { + TRACE_FUNCTION + + const unsigned long start = millis(); + + while (client.connected() && (millis() - start < client.socket_timeout_millis)) { + + IncomingPacket packet(client); + if (!packet) { + break; + } + + last_read = millis(); + + if (packet.get_type() == type) { + handler(packet); + return; + } + + handle_packet(packet); + + } + + if (client.connected()) { + on_timeout(); + } +} + +void Connection::send_ack(Packet::Type ack_type, uint16_t msg_id) { + TRACE_FUNCTION + auto ack = build_packet(ack_type, 0, 2); + ack.write_u16(msg_id); + ack.send(); +} + +void Connection::handle_packet(IncomingPacket & packet) { + TRACE_FUNCTION + + switch (packet.get_type()) { + case Packet::PUBLISH: { + const uint16_t topic_size = packet.read_u16(); + + // const bool dup = (packet.get_flags() >> 3) & 0b1; + const uint8_t qos = (packet.get_flags() >> 1) & 0b11; + // const bool retain = packet.get_flags() & 0b1; + + uint16_t msg_id = 0; + + if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) { + packet.ignore(topic_size); + on_topic_too_long(packet); + if (qos) { + msg_id = packet.read_u16(); + } + } else { + char topic[topic_size + 1]; + if (!packet.read_string(topic, topic_size)) { + // connection error + return; + } + if (qos) { + msg_id = packet.read_u16(); + } + on_message(topic, packet); + } + + if (msg_id) { + send_ack(qos == 1 ? Packet::PUBACK : Packet::PUBREC, msg_id); + } + + break; + }; + + case Packet::PUBREC: + send_ack(Packet::PUBREL, packet.read_u16()); + break; + + case Packet::PUBREL: + send_ack(Packet::PUBCOMP, packet.read_u16()); + break; + + case Packet::PUBCOMP: + // ignore + break; + + case Packet::DISCONNECT: + on_disconnect(); + break; + + default: + on_protocol_violation(); + break; + } +} + +unsigned long Connection::get_millis_since_last_read() const { + TRACE_FUNCTION + return millis() - last_read; +} + +unsigned long Connection::get_millis_since_last_write() const { + TRACE_FUNCTION + return millis() - last_write; +} + +void Connection::loop() { + TRACE_FUNCTION + + // only handle 10 packets max in one go to not starve other connections + for (unsigned int i = 0; (i < 10) && client.available(); ++i) { + IncomingPacket packet(client); + if (!packet.is_valid()) { + return; + } + last_read = millis(); + handle_packet(packet); + } +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/connection.h b/lib/PicoMQTT/src/PicoMQTT/connection.h new file mode 100644 index 00000000..ec8c6925 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/connection.h @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include + +#include "client_wrapper.h" +#include "incoming_packet.h" +#include "outgoing_packet.h" + +namespace PicoMQTT { + +enum ConnectReturnCode : uint8_t { + CRC_ACCEPTED = 0, + CRC_UNACCEPTABLE_PROTOCOL_VERSION = 1, + CRC_IDENTIFIER_REJECTED = 2, + CRC_SERVER_UNAVAILABLE = 3, + CRC_BAD_USERNAME_OR_PASSWORD = 4, + CRC_NOT_AUTHORIZED = 5, + + // internal + CRC_UNDEFINED = 255, +}; + +class Connection { + public: + Connection(unsigned long keep_alive_seconds = 0, unsigned long socket_timeout_seconds = 15); + Connection(const ::WiFiClient & client, unsigned long keep_alive_seconds = 0, + unsigned long socket_timeout_seconds = 15); + Connection(const Connection &) = default; + + virtual ~Connection() {} + + bool connected(); + void disconnect(); + + virtual void loop(); + + protected: + class MessageIdGenerator { + public: + MessageIdGenerator(): value(0) {} + uint16_t generate() { + if (++value == 0) { value = 1; } + return value; + } + + void reset() { value = 0; } + + protected: + uint16_t value; + } message_id_generator; + + OutgoingPacket build_packet(Packet::Type type, uint8_t flags = 0, size_t length = 0); + + void wait_for_reply(Packet::Type type, std::function handler); + + virtual void on_topic_too_long(const IncomingPacket & packet) {} + virtual void on_message(const char * topic, IncomingPacket & packet) {} + + virtual void on_timeout(); + virtual void on_protocol_violation(); + virtual void on_disconnect(); + + ClientWrapper client; + uint16_t keep_alive_millis; + + virtual void handle_packet(IncomingPacket & packet); + + protected: + unsigned long get_millis_since_last_read() const; + unsigned long get_millis_since_last_write() const; + + private: + unsigned long last_read; + unsigned long last_write; + void send_ack(Packet::Type ack_type, uint16_t msg_id); +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/debug.h b/lib/PicoMQTT/src/PicoMQTT/debug.h new file mode 100644 index 00000000..01e39a95 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/debug.h @@ -0,0 +1,50 @@ +#pragma once + +#include "config.h" + +#ifdef PICOMQTT_DEBUG_TRACE_FUNCTIONS + +#include + +namespace PicoMQTT { + +class FunctionTracer { + public: + FunctionTracer(const char * function_name) : function_name(function_name) { + indent(1); + Serial.print(F("CALL ")); + Serial.println(function_name); + } + + ~FunctionTracer() { + indent(-1); + Serial.print(F("RETURN ")); + Serial.println(function_name); + } + + const char * const function_name; + + protected: + void indent(int delta) { + static int depth = 0; + if (delta < 0) { + depth += delta; + } + for (int i = 0; i < depth; ++i) { + Serial.print(" "); + } + if (delta > 0) { + depth += delta; + } + } +}; + +} + +#define TRACE_FUNCTION FunctionTracer _function_tracer(__PRETTY_FUNCTION__); + +#else + +#define TRACE_FUNCTION + +#endif diff --git a/lib/PicoMQTT/src/PicoMQTT/incoming_packet.cpp b/lib/PicoMQTT/src/PicoMQTT/incoming_packet.cpp new file mode 100644 index 00000000..1b7e60df --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/incoming_packet.cpp @@ -0,0 +1,171 @@ +#include "incoming_packet.h" +#include "debug.h" + +namespace PicoMQTT { + +IncomingPacket::IncomingPacket(Client & client) + : Packet(read_header(client)), client(client) { + TRACE_FUNCTION +} + +IncomingPacket::IncomingPacket(IncomingPacket && other) + : Packet(other), client(other.client) { + TRACE_FUNCTION + other.pos = size; +} + +IncomingPacket::~IncomingPacket() { + TRACE_FUNCTION +#ifdef PICOMQTT_DEBUG + if (pos != size) { + Serial.print(F("IncomingPacket read incorrect number of bytes: ")); + Serial.print(pos); + Serial.print(F("/")); + Serial.println(size); + } +#endif + // read and ignore remaining data + while (get_remaining_size() && (read() >= 0)); +} + +// disabled functions +int IncomingPacket::connect(IPAddress ip, uint16_t port) { + TRACE_FUNCTION; + return 0; +} + +int IncomingPacket::connect(const char * host, uint16_t port) { + TRACE_FUNCTION; + return 0; +} + +size_t IncomingPacket::write(const uint8_t * buffer, size_t size) { + TRACE_FUNCTION + return 0; +} + +size_t IncomingPacket::write(uint8_t value) { + TRACE_FUNCTION + return 0; +} + +void IncomingPacket::flush() { + TRACE_FUNCTION +} + +void IncomingPacket::stop() { + TRACE_FUNCTION +} + + +// extended functions +int IncomingPacket::available() { + TRACE_FUNCTION; + return get_remaining_size(); +} + +int IncomingPacket::peek() { + TRACE_FUNCTION + if (!get_remaining_size()) { +#if PICOMQTT_DEBUG + Serial.println(F("Attempt to peek beyond end of IncomingPacket.")); +#endif + return -1; + } + return client.peek(); +} + +int IncomingPacket::read() { + TRACE_FUNCTION + if (!get_remaining_size()) { +#if PICOMQTT_DEBUG + Serial.println(F("Attempt to read beyond end of IncomingPacket.")); +#endif + return -1; + } + const int ret = client.read(); + if (ret >= 0) { + ++pos; + } + return ret; +} + +int IncomingPacket::read(uint8_t * buf, size_t size) { + TRACE_FUNCTION + const size_t remaining = get_remaining_size(); + const size_t read_size = remaining < size ? remaining : size; +#if PICOMQTT_DEBUG + if (size > remaining) { + Serial.println(F("Attempt to read buf beyond end of IncomingPacket.")); + } +#endif + const int ret = client.read(buf, read_size); + if (ret > 0) { + pos += ret; + } + return ret; +} + +IncomingPacket::operator bool() { + TRACE_FUNCTION + return is_valid() && bool(client); +} + +uint8_t IncomingPacket::connected() { + TRACE_FUNCTION + return is_valid() && client.connected(); +} + +// extra functions +uint8_t IncomingPacket::read_u8() { + TRACE_FUNCTION; + return get_remaining_size() ? read() : 0; +} + +uint16_t IncomingPacket::read_u16() { + TRACE_FUNCTION; + return ((uint16_t) read_u8()) << 8 | ((uint16_t) read_u8()); +} + +bool IncomingPacket::read_string(char * buffer, size_t len) { + if (read((uint8_t *) buffer, len) != (int) len) { + return false; + } + buffer[len] = '\0'; + return true; +} + +void IncomingPacket::ignore(size_t len) { + while (len--) { + read(); + } +} + +Packet IncomingPacket::read_header(Client & client) { + TRACE_FUNCTION + const int head = client.read(); + if (head <= 0) { + return Packet(); + } + + uint32_t size = 0; + for (size_t length_size = 0; ; ++length_size) { + if (length_size >= 5) { + return Packet(); + } + const int digit = client.read(); + if (digit < 0) { + return Packet(); + } + + size |= (digit & 0x7f) << (7 * length_size); + + if (!(digit & 0x80)) { + break; + } + } + + return Packet(head, size); +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/incoming_packet.h b/lib/PicoMQTT/src/PicoMQTT/incoming_packet.h new file mode 100644 index 00000000..b89527fa --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/incoming_packet.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +#include "packet.h" + +namespace PicoMQTT { + +class IncomingPacket: public Packet, public Client { + public: + IncomingPacket(Client & client); + IncomingPacket(IncomingPacket &&); + + IncomingPacket(const IncomingPacket &) = delete; + const IncomingPacket & operator=(const IncomingPacket &) = delete; + + ~IncomingPacket(); + + virtual int available() override; + virtual int connect(IPAddress ip, uint16_t port) override; + virtual int connect(const char * host, uint16_t port) override; + virtual int peek() override; + virtual int read() override; + virtual int read(uint8_t * buf, size_t size) override; + // This operator is not marked explicit in the Client base class. Still, we're marking it explicit here + // to block implicit conversions to integer types. + virtual explicit operator bool() override; + virtual size_t write(const uint8_t * buffer, size_t size) override; + virtual size_t write(uint8_t value) override final; + virtual uint8_t connected() override; + virtual void flush() override; + virtual void stop() override; + + uint8_t read_u8(); + uint16_t read_u16(); + bool read_string(char * buffer, size_t len); + void ignore(size_t len); + + protected: + static Packet read_header(Client & client); + + Client & client; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.cpp b/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.cpp new file mode 100644 index 00000000..333c3522 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.cpp @@ -0,0 +1,224 @@ +#include +#include + +#include "debug.h" +#include "outgoing_packet.h" + +namespace PicoMQTT { + +OutgoingPacket::OutgoingPacket(Print & print, Packet::Type type, uint8_t flags, size_t payload_size) + : Packet(type, flags, payload_size), print(print), +#ifndef PICOMQTT_UNBUFFERED + buffer_position(0), +#endif + state(State::ok) { + TRACE_FUNCTION +} + +OutgoingPacket::OutgoingPacket(OutgoingPacket && other) + : OutgoingPacket(other) { + TRACE_FUNCTION + other.state = State::dead; +} + +OutgoingPacket::~OutgoingPacket() { + TRACE_FUNCTION +#ifdef PICOMQTT_DEBUG +#ifndef PICOMQTT_UNBUFFERED + if (buffer_position) { + Serial.printf("OutgoingPacket has unsent data in the buffer (pos=%u)\n", buffer_position); + } +#endif + switch (state) { + case State::ok: + Serial.println(F("Unsent OutgoingPacket")); + break; + case State::sent: + if (pos != size) { + Serial.print(F("OutgoingPacket sent incorrect number of bytes: ")); + Serial.print(pos); + Serial.print(F("/")); + Serial.println(size); + } + break; + default: + break; + } +#endif +} + +size_t OutgoingPacket::write_from_client(::Client & client, size_t length) { + TRACE_FUNCTION + size_t written = 0; +#ifndef PICOMQTT_UNBUFFERED + while (written < length) { + const size_t remaining = length - written; + const size_t remaining_buffer_space = PICOMQTT_OUTGOING_BUFFER_SIZE - buffer_position; + const size_t chunk_size = remaining < remaining_buffer_space ? remaining : remaining_buffer_space; + + const int read_size = client.read(buffer + buffer_position, chunk_size); + if (read_size <= 0) { + break; + } + + buffer_position += (size_t) read_size; + written += (size_t) read_size; + + if (buffer_position >= PICOMQTT_OUTGOING_BUFFER_SIZE) { + flush(); + } + } +#else + uint8_t buffer[128] __attribute__((aligned(4))); + while (written < length) { + const size_t remain = length - written; + const size_t chunk_size = sizeof(buffer) < remain ? sizeof(buffer) : remain; + const int read_size = client.read(buffer, chunk_size); + if (read_size <= 0) { + break; + } + const size_t write_size = print.write(buffer, read_size); + written += write_size; + if (!write_size) { + break; + } + } +#endif + pos += written; + return written; +} + +size_t OutgoingPacket::write_zero(size_t length) { + TRACE_FUNCTION + for (size_t written = 0; written < length; ++written) { + write_u8('0'); + } + return length; +} + +#ifndef PICOMQTT_UNBUFFERED +size_t OutgoingPacket::write(const void * data, size_t remaining, void * (*memcpy_fn)(void *, const void *, size_t n)) { + TRACE_FUNCTION + + const char * src = (const char *) data; + + while (remaining) { + const size_t remaining_buffer_space = PICOMQTT_OUTGOING_BUFFER_SIZE - buffer_position; + const size_t chunk_size = remaining < remaining_buffer_space ? remaining : remaining_buffer_space; + + memcpy_fn(buffer + buffer_position, src, chunk_size); + + buffer_position += chunk_size; + src += chunk_size; + remaining -= chunk_size; + + if (buffer_position >= PICOMQTT_OUTGOING_BUFFER_SIZE) { + flush(); + } + } + + const size_t written = src - (const char *) data; + pos += written; + return written; +} +#endif + +size_t OutgoingPacket::write(const uint8_t * data, size_t length) { + TRACE_FUNCTION +#ifndef PICOMQTT_UNBUFFERED + return write(data, length, memcpy); +#else + const size_t written = print.write(data, length); + pos += written; + return written; +#endif +} + +size_t OutgoingPacket::write_P(PGM_P data, size_t length) { + TRACE_FUNCTION +#ifndef PICOMQTT_UNBUFFERED + return write(data, length, memcpy_P); +#else + // here we will need a buffer + uint8_t buffer[128] __attribute__((aligned(4))); + size_t written = 0; + while (written < length) { + const size_t remain = length - written; + const size_t chunk_size = sizeof(buffer) < remain ? sizeof(buffer) : remain; + memcpy_P(buffer, data, chunk_size); + const size_t write_size = print.write(buffer, chunk_size); + written += write_size; + data += write_size; + if (!write_size) { + break; + } + } + pos += written; + return written; +#endif +} + +size_t OutgoingPacket::write_u8(uint8_t c) { + TRACE_FUNCTION + return write(&c, 1); +} + +size_t OutgoingPacket::write_u16(uint16_t value) { + TRACE_FUNCTION + return write_u8(value >> 8) + write_u8(value & 0xff); +} + +size_t OutgoingPacket::write_string(const char * string, uint16_t size) { + TRACE_FUNCTION + return write_u16(size) + write((const uint8_t *) string, size); +} + +size_t OutgoingPacket::write_packet_length(size_t length) { + TRACE_FUNCTION + size_t ret = 0; + do { + const uint8_t digit = length & 127; // digit := length % 128 + length >>= 7; // length := length / 128 + ret += write_u8(digit | (length ? 0x80 : 0)); + } while (length); + return ret; +} + +size_t OutgoingPacket::write_header() { + TRACE_FUNCTION + const size_t ret = write_u8(head) + write_packet_length(size); + // we've just written the header, payload starts now + pos = 0; + return ret; +} + +void OutgoingPacket::flush() { + TRACE_FUNCTION +#ifndef PICOMQTT_UNBUFFERED + print.write(buffer, buffer_position); + buffer_position = 0; +#endif +} + +bool OutgoingPacket::send() { + TRACE_FUNCTION + const size_t remaining_size = get_remaining_size(); + if (remaining_size) { +#ifdef PICOMQTT_DEBUG + Serial.printf("OutgoingPacket sent called on incomplete payload (%u / %u), filling with zeros.\n", pos, size); +#endif + write_zero(remaining_size); + } + flush(); + switch (state) { + case State::ok: + // print.flush(); + state = State::sent; + case State::sent: + return true; + default: + return false; + } +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.h b/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.h new file mode 100644 index 00000000..f3776d5f --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/outgoing_packet.h @@ -0,0 +1,64 @@ +#pragma once + +// #define MQTT_OUTGOING_PACKET_DEBUG + +#include + +#include "config.h" +#include "packet.h" + +class Print; +class Client; + +#if PICOMQTT_OUTGOING_BUFFER_SIZE == 0 +#define PICOMQTT_UNBUFFERED +#endif + +namespace PicoMQTT { + +class OutgoingPacket: public Packet, public Print { + public: + OutgoingPacket(Print & print, Type type, uint8_t flags, size_t payload_size); + virtual ~OutgoingPacket(); + + const OutgoingPacket & operator=(const OutgoingPacket &) = delete; + OutgoingPacket(OutgoingPacket && other); + + virtual size_t write(const uint8_t * data, size_t length) override; + virtual size_t write(uint8_t value) override final { return write(&value, 1); } + + size_t write_P(PGM_P data, size_t length); + size_t write_u8(uint8_t value); + size_t write_u16(uint16_t value); + size_t write_string(const char * string, uint16_t size); + size_t write_header(); + + size_t write_from_client(::Client & c, size_t length); + size_t write_zero(size_t count); + + virtual void flush() override; + virtual bool send(); + + protected: + OutgoingPacket(const OutgoingPacket &) = default; + + size_t write(const void * data, size_t length, void * (*memcpy_fn)(void *, const void *, size_t n)); + size_t write_packet_length(size_t length); + + Print & print; + +#ifndef PICOMQTT_UNBUFFERED + uint8_t buffer[PICOMQTT_OUTGOING_BUFFER_SIZE] __attribute__((aligned(4))); + + size_t buffer_position; +#endif + + enum class State { + ok, + sent, + error, + dead, + } state; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/packet.h b/lib/PicoMQTT/src/PicoMQTT/packet.h new file mode 100644 index 00000000..5efd6a7e --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/packet.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +namespace PicoMQTT { + +class Packet { + public: + enum Type : uint8_t { + ERROR = 0, + CONNECT = 1 << 4, // Client request to connect to Server + CONNACK = 2 << 4, // Connect Acknowledgment + PUBLISH = 3 << 4, // Publish message + PUBACK = 4 << 4, // Publish Acknowledgment + PUBREC = 5 << 4, // Publish Received (assured delivery part 1) + PUBREL = 6 << 4, // Publish Release (assured delivery part 2) + PUBCOMP = 7 << 4, // Publish Complete (assured delivery part 3) + SUBSCRIBE = 8 << 4, // Client Subscribe request + SUBACK = 9 << 4, // Subscribe Acknowledgment + UNSUBSCRIBE = 10 << 4, // Client Unsubscribe request + UNSUBACK = 11 << 4, // Unsubscribe Acknowledgment + PINGREQ = 12 << 4, // PING Request + PINGRESP = 13 << 4, // PING Response + DISCONNECT = 14 << 4, // Client is Disconnecting + }; + + Packet(uint8_t head, size_t size) + : head(head), size(size), pos(0) {} + + Packet(Type type = ERROR, const uint8_t flags = 0, size_t size = 0) + : Packet((uint8_t) type | (flags & 0xf), size) { + } + + virtual ~Packet() {} + + Type get_type() const { return Type(head & 0xf0); } + uint8_t get_flags() const { return head & 0x0f; } + + bool is_valid() { return get_type() != ERROR; } + size_t get_remaining_size() const { return pos < size ? size - pos : 0; } + + const uint8_t head; + const size_t size; + + protected: + size_t pos; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/pico_interface.h b/lib/PicoMQTT/src/PicoMQTT/pico_interface.h new file mode 100644 index 00000000..2f823b6f --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/pico_interface.h @@ -0,0 +1,13 @@ +#pragma once + +namespace PicoMQTT { + +class PicoMQTTInterface { + public: + virtual ~PicoMQTTInterface() {} + virtual void begin() {} + virtual void stop() {} + virtual void loop() {} +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/print_mux.cpp b/lib/PicoMQTT/src/PicoMQTT/print_mux.cpp new file mode 100644 index 00000000..bbe970f6 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/print_mux.cpp @@ -0,0 +1,29 @@ +#include "print_mux.h" +#include "debug.h" + +namespace PicoMQTT { + +size_t PrintMux::write(uint8_t value) { + TRACE_FUNCTION + for (auto print_ptr : prints) { + print_ptr->write(value); + } + return 1; +} + +size_t PrintMux::write(const uint8_t * buffer, size_t size) { + TRACE_FUNCTION + for (auto print_ptr : prints) { + print_ptr->write(buffer, size); + } + return size; +} + +void PrintMux::flush() { + TRACE_FUNCTION + for (auto print_ptr : prints) { + print_ptr->flush(); + } +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/print_mux.h b/lib/PicoMQTT/src/PicoMQTT/print_mux.h new file mode 100644 index 00000000..b1dc688d --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/print_mux.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include + +namespace PicoMQTT { + +class PrintMux: public ::Print { + public: + PrintMux() {} + + PrintMux(Print & print) : prints({&print}) {} + + void add(Print & print) { + prints.push_back(&print); + } + + virtual size_t write(uint8_t) override; + virtual size_t write(const uint8_t * buffer, size_t size) override; + virtual void flush(); + + size_t size() const { return prints.size(); } + + protected: + std::vector prints; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/publisher.cpp b/lib/PicoMQTT/src/PicoMQTT/publisher.cpp new file mode 100644 index 00000000..506afd70 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/publisher.cpp @@ -0,0 +1,56 @@ +#include "publisher.h" +#include "debug.h" + +namespace PicoMQTT { + +Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print, + uint8_t flags, size_t total_size, + const char * topic, size_t topic_size, + uint16_t message_id) + : + OutgoingPacket(this->print, Packet::PUBLISH, flags, total_size), + qos((flags >> 1) & 0b11), + message_id(message_id), + print(print), + publisher(publisher) { + TRACE_FUNCTION + + OutgoingPacket::write_header(); + write_string(topic, topic_size); + if (qos) { + write_u16(message_id); + } +} + +Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print, + const char * topic, size_t topic_size, size_t payload_size, + uint8_t qos, bool retain, bool dup, uint16_t message_id) + : Publish( + publisher, print, + (dup ? 0b1000 : 0) | ((qos & 0b11) << 1) | (retain ? 1 : 0), // flags + 2 + topic_size + (qos ? 2 : 0) + payload_size, // total size + topic, topic_size, // topic + message_id) { + TRACE_FUNCTION +} + +Publisher::Publish::Publish(Publisher & publisher, const PrintMux & print, + const char * topic, size_t payload_size, + uint8_t qos, bool retain, bool dup, uint16_t message_id) + : Publish( + publisher, print, + topic, strlen(topic), payload_size, + qos, retain, dup, message_id) { + TRACE_FUNCTION +} + +Publisher::Publish::~Publish() { + TRACE_FUNCTION +} + +bool Publisher::Publish::send() { + TRACE_FUNCTION + return OutgoingPacket::send() && publisher.on_publish_complete(*this); +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/publisher.h b/lib/PicoMQTT/src/PicoMQTT/publisher.h new file mode 100644 index 00000000..56e3a790 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/publisher.h @@ -0,0 +1,90 @@ +#pragma once + +#include + +#include + +#include "debug.h" +#include "outgoing_packet.h" +#include "print_mux.h" + +namespace PicoMQTT { + +class Publisher { + public: + class Publish: public OutgoingPacket { + private: + Publish(Publisher & publisher, const PrintMux & print, + uint8_t flags, size_t total_size, + const char * topic, size_t topic_size, + uint16_t message_id); + + public: + Publish(Publisher & publisher, const PrintMux & print, + const char * topic, size_t topic_size, size_t payload_size, + uint8_t qos = 0, bool retain = false, bool dup = false, uint16_t message_id = 0); + + Publish(Publisher & publisher, const PrintMux & print, + const char * topic, size_t payload_size, + uint8_t qos = 0, bool retain = false, bool dup = false, uint16_t message_id = 0); + + ~Publish(); + + virtual bool send() override; + + const uint8_t qos; + const uint16_t message_id; + PrintMux print; + Publisher & publisher; + }; + + virtual Publish begin_publish(const char * topic, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) = 0; + + Publish begin_publish(const String & topic, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) { + return begin_publish(topic.c_str(), payload_size, qos, retain, message_id); + } + + template + bool publish(TopicStringType topic, const void * payload, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) { + TRACE_FUNCTION + auto packet = begin_publish(get_c_str(topic), payload_size, qos, retain, message_id); + packet.write((const uint8_t *) payload, payload_size); + return packet.send(); + } + + template + bool publish_P(TopicStringType topic, PGM_P payload, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) { + TRACE_FUNCTION; + auto packet = begin_publish(get_c_str(topic), payload_size, qos, retain, message_id); + packet.write_P(payload, payload_size); + return packet.send(); + } + + template + bool publish(TopicStringType topic, PayloadStringType payload, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) { + return publish(topic, (const void *) get_c_str(payload), get_c_str_len(payload), + qos, retain, message_id); + } + + template + bool publish_P(TopicStringType topic, PGM_P payload, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) { + return publish_P(topic, payload, strlen_P(payload), + qos, retain, message_id); + } + + protected: + virtual bool on_publish_complete(const Publish & publish) { return true; } + + static const char * get_c_str(const char * string) { return string; } + static const char * get_c_str(const String & string) { return string.c_str(); } + static size_t get_c_str_len(const char * string) { return strlen(string); } + static size_t get_c_str_len(const String & string) { return string.length(); } +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/server.cpp b/lib/PicoMQTT/src/PicoMQTT/server.cpp new file mode 100644 index 00000000..c3c82f64 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/server.cpp @@ -0,0 +1,365 @@ +#include "config.h" +#include "debug.h" +#include "server.h" + +namespace PicoMQTT { + +BasicServer::Client::Client(const BasicServer::Client & other) + : Connection(other.client, 0), + server(other.server), + client_id(other.client_id) { + TRACE_FUNCTION +} + +BasicServer::Client::Client(BasicServer & server, const WiFiClient & client) + : Connection(client, 0, server.socket_timeout_seconds), server(server), client_id("") { + TRACE_FUNCTION + wait_for_reply(Packet::CONNECT, [this](IncomingPacket & packet) { + TRACE_FUNCTION + + auto connack = [this](ConnectReturnCode crc) { + TRACE_FUNCTION + auto connack = build_packet(Packet::CONNACK, 0, 2); + connack.write_u8(0); /* session present always set to zero */ + connack.write_u8(crc); + connack.send(); + if (crc != CRC_ACCEPTED) { + this->client.stop(); + } + }; + + { + // MQTT protocol identifier + char buf[4]; + + if (packet.read_u16() != 4) { + on_protocol_violation(); + return; + } + + packet.read((uint8_t *) buf, 4); + + if (memcmp(buf, "MQTT", 4) != 0) { + on_protocol_violation(); + return; + } + } + + const uint8_t protocol_level = packet.read_u8(); + if (protocol_level != 4) { + on_protocol_violation(); + return; + } + + const uint8_t connect_flags = packet.read_u8(); + const bool has_user = connect_flags & (1 << 7); + const bool has_pass = connect_flags & (1 << 6); + const bool will_retain = connect_flags & (1 << 5); + const uint8_t will_qos = (connect_flags >> 3) & 0b11; + const bool has_will = connect_flags & (1 << 2); + /* const bool clean_session = connect_flags & (1 << 1); */ + + if ((has_pass && !has_user) + || (will_qos > 2) + || (!has_will && ((will_qos > 0) || will_retain))) { + on_protocol_violation(); + return; + } + + const uint16_t keep_alive_seconds = packet.read_u16(); + keep_alive_millis = keep_alive_seconds ? (keep_alive_seconds + this->server.keep_alive_tolerance_seconds) * 1000 : 0; + + { + const size_t client_id_size = packet.read_u16(); + if (client_id_size > PICOMQTT_MAX_CLIENT_ID_SIZE) { + connack(CRC_IDENTIFIER_REJECTED); + return; + } + + char client_id_buffer[client_id_size + 1]; + packet.read_string(client_id_buffer, client_id_size); + client_id = client_id_buffer; + } + + if (client_id.isEmpty()) { + client_id = String((unsigned int)(this), HEX); + } + + if (has_will) { + packet.ignore(packet.read_u16()); // will topic + packet.ignore(packet.read_u16()); // will payload + } + + // read username + const size_t user_size = has_user ? packet.read_u16() : 0; + if (user_size > PICOMQTT_MAX_USERPASS_SIZE) { + connack(CRC_BAD_USERNAME_OR_PASSWORD); + return; + } + char user[user_size + 1]; + if (user_size && !packet.read_string(user, user_size)) { + on_timeout(); + return; + } + + // read password + const size_t pass_size = has_pass ? packet.read_u16() : 0; + if (pass_size > PICOMQTT_MAX_USERPASS_SIZE) { + connack(CRC_BAD_USERNAME_OR_PASSWORD); + return; + } + char pass[pass_size + 1]; + if (pass_size && !packet.read_string(pass, pass_size)) { + on_timeout(); + return; + } + + const auto connect_return_code = this->server.auth( + client_id.c_str(), + has_user ? user : nullptr, has_pass ? pass : nullptr); + + connack(connect_return_code); + }); +} + +void BasicServer::Client::on_message(const char * topic, IncomingPacket & packet) { + TRACE_FUNCTION + + const size_t payload_size = packet.get_remaining_size(); + auto publish = server.begin_publish(topic, payload_size); + + // Always notify the server about the message + { + IncomingPublish incoming_publish(packet, publish); + server.on_message(topic, incoming_publish); + } + + publish.send(); +} + +void BasicServer::Client::on_subscribe(IncomingPacket & subscribe) { + TRACE_FUNCTION + const uint16_t message_id = subscribe.read_u16(); + + if ((subscribe.get_flags() != 0b0010) || !message_id) { + on_protocol_violation(); + return; + } + + std::list suback_codes; + + while (subscribe.get_remaining_size()) { + const size_t topic_size = subscribe.read_u16(); + if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) { + subscribe.ignore(topic_size); + subscribe.read_u8(); + suback_codes.push_back(0x80); + } else { + char topic[topic_size + 1]; + if (!subscribe.read_string(topic, topic_size)) { + // connection error + return; + } + uint8_t qos = subscribe.read_u8(); + if (qos > 2) { + on_protocol_violation(); + return; + } + this->subscribe(topic); + server.on_subscribe(client_id.c_str(), topic); + suback_codes.push_back(0); + } + } + + auto suback = build_packet(Packet::SUBACK, 0, 2 + suback_codes.size()); + suback.write_u16(message_id); + for (uint8_t code : suback_codes) { + suback.write_u8(code); + } + suback.send(); +} + +void BasicServer::Client::on_unsubscribe(IncomingPacket & unsubscribe) { + TRACE_FUNCTION + const uint16_t message_id = unsubscribe.read_u16(); + + if ((unsubscribe.get_flags() != 0b0010) || !message_id) { + on_protocol_violation(); + return; + } + + while (unsubscribe.get_remaining_size()) { + const size_t topic_size = unsubscribe.read_u16(); + if (topic_size > PICOMQTT_MAX_TOPIC_SIZE) { + unsubscribe.ignore(topic_size); + } else { + char topic[topic_size + 1]; + if (!unsubscribe.read_string(topic, topic_size)) { + // connection error + return; + } + server.on_unsubscribe(client_id.c_str(), topic); + this->unsubscribe(topic); + } + } + + auto unsuback = build_packet(Packet::UNSUBACK, 0, 2); + unsuback.write_u16(message_id); + unsuback.send(); +} + +const char * BasicServer::Client::get_subscription_pattern(BasicServer::Client::SubscriptionId id) const { + for (const auto & pattern : subscriptions) + if (pattern.id == id) { + return pattern.c_str(); + } + return nullptr; +} + +Server::SubscriptionId BasicServer::Client::get_subscription(const char * topic) const { + TRACE_FUNCTION + for (const auto & pattern : subscriptions) + if (topic_matches(pattern.c_str(), topic)) { + return pattern.id; + } + return 0; +} + +BasicServer::Client::SubscriptionId BasicServer::Client::subscribe(const String & topic_filter) { + TRACE_FUNCTION + const Subscription subscription(topic_filter.c_str()); + subscriptions.insert(subscription); + return subscription.id; +} + +void BasicServer::Client::unsubscribe(const String & topic_filter) { + TRACE_FUNCTION + subscriptions.erase(topic_filter.c_str()); +} + +void BasicServer::Client::handle_packet(IncomingPacket & packet) { + TRACE_FUNCTION + + switch (packet.get_type()) { + case Packet::PINGREQ: + build_packet(Packet::PINGRESP).send(); + return; + + case Packet::SUBSCRIBE: + on_subscribe(packet); + return; + + case Packet::UNSUBSCRIBE: + on_unsubscribe(packet); + return; + + default: + Connection::handle_packet(packet); + return; + } +} + +void BasicServer::Client::loop() { + TRACE_FUNCTION + if (keep_alive_millis && (get_millis_since_last_read() > keep_alive_millis)) { + // ping timeout + on_timeout(); + return; + } + + Connection::loop(); +} + +BasicServer::IncomingPublish::IncomingPublish(IncomingPacket & packet, Publish & publish) + : IncomingPacket(std::move(packet)), publish(publish) { + TRACE_FUNCTION +} + +BasicServer::IncomingPublish::~IncomingPublish() { + TRACE_FUNCTION + pos += publish.write_from_client(client, get_remaining_size()); +} + +int BasicServer::IncomingPublish::read(uint8_t * buf, size_t size) { + TRACE_FUNCTION + const int ret = IncomingPacket::read(buf, size); + if (ret > 0) { + publish.write(buf, ret); + } + return ret; +} + +int BasicServer::IncomingPublish::read() { + TRACE_FUNCTION + const int ret = IncomingPacket::read(); + if (ret >= 0) { + publish.write(ret); + } + return ret; +} + +BasicServer::BasicServer(uint16_t port, unsigned long keep_alive_tolerance_seconds, + unsigned long socket_timeout_seconds) + : server(port), keep_alive_tolerance_seconds(keep_alive_tolerance_seconds), + socket_timeout_seconds(socket_timeout_seconds) { + TRACE_FUNCTION +} + +void BasicServer::begin() { + TRACE_FUNCTION + server.begin(); +} + +void BasicServer::stop() { + TRACE_FUNCTION + server.stop(); + clients.clear(); +} + +void BasicServer::loop() { + TRACE_FUNCTION + + while (server.hasClient()) { + auto client = Client(*this, server.accept()); + clients.push_back(client); + on_connected(client.get_client_id()); + } + + for (auto it = clients.begin(); it != clients.end();) { + it->loop(); + + if (!it->connected()) { + on_disconnected(it->get_client_id()); + clients.erase(it++); + } else { + ++it; + } + } +} + +PrintMux BasicServer::get_subscribed(const char * topic) { + TRACE_FUNCTION + PrintMux ret; + for (auto & client : clients) { + if (client.get_subscription(topic)) { + ret.add(client.get_print()); + } + } + return ret; +} + +Publisher::Publish BasicServer::begin_publish(const char * topic, const size_t payload_size, + uint8_t, bool, uint16_t) { + TRACE_FUNCTION + return Publish(*this, get_subscribed(topic), topic, payload_size); +} + +void BasicServer::on_message(const char * topic, IncomingPacket & packet) { +} + +void Server::on_message(const char * topic, IncomingPacket & packet) { + TRACE_FUNCTION + fire_message_callbacks(topic, packet); +} + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/server.h b/lib/PicoMQTT/src/PicoMQTT/server.h new file mode 100644 index 00000000..35516e5e --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/server.h @@ -0,0 +1,105 @@ +#pragma once + +#include +#include + +#include + +#if defined(ESP32) +#include +#elif defined(ESP8266) +#include +#else +#error "This board is not supported." +#endif + +#include "debug.h" +#include "incoming_packet.h" +#include "connection.h" +#include "publisher.h" +#include "subscriber.h" +#include "pico_interface.h" + +namespace PicoMQTT { + +class BasicServer: public PicoMQTTInterface, public Publisher { + public: + class Client: public Connection, public Subscriber { + public: + Client(BasicServer & server, const WiFiClient & client); + Client(const Client &); + + void on_message(const char * topic, IncomingPacket & packet) override; + + Print & get_print() { return client; } + const char * get_client_id() const { return client_id.c_str(); } + + virtual void loop() override; + + virtual const char * get_subscription_pattern(SubscriptionId id) const override; + virtual SubscriptionId get_subscription(const char * topic) const override; + virtual SubscriptionId subscribe(const String & topic_filter) override; + virtual void unsubscribe(const String & topic_filter) override; + + protected: + BasicServer & server; + String client_id; + std::set subscriptions; + + virtual void on_subscribe(IncomingPacket & packet); + virtual void on_unsubscribe(IncomingPacket & packet); + + virtual void handle_packet(IncomingPacket & packet) override; + }; + + class IncomingPublish: public IncomingPacket { + public: + IncomingPublish(IncomingPacket & packet, Publish & publish); + IncomingPublish(const IncomingPublish &) = delete; + ~IncomingPublish(); + + virtual int read(uint8_t * buf, size_t size) override; + virtual int read() override; + + protected: + Publish & publish; + }; + + BasicServer(uint16_t port = 1883, unsigned long keep_alive_tolerance_seconds = 10, + unsigned long socket_timeout_seconds = 5); + + void begin() override; + void stop() override; + void loop() override; + + using Publisher::begin_publish; + virtual Publish begin_publish(const char * topic, const size_t payload_size, + uint8_t qos = 0, bool retain = false, uint16_t message_id = 0) override; + + protected: + virtual void on_message(const char * topic, IncomingPacket & packet); + virtual ConnectReturnCode auth(const char * client_id, const char * username, const char * password) { return CRC_ACCEPTED; } + + virtual void on_connected(const char * client_id) {} + virtual void on_disconnected(const char * client_id) {} + + virtual void on_subscribe(const char * client_id, const char * topic) {} + virtual void on_unsubscribe(const char * client_id, const char * topic) {} + + virtual PrintMux get_subscribed(const char * topic); + + WiFiServer server; + std::list clients; + + const unsigned long keep_alive_tolerance_seconds; + const unsigned long socket_timeout_seconds; + +}; + +class Server: public BasicServer, public SubscribedMessageListener { + public: + using BasicServer::BasicServer; + virtual void on_message(const char * topic, IncomingPacket & packet) override; +}; + +} diff --git a/lib/PicoMQTT/src/PicoMQTT/subscriber.cpp b/lib/PicoMQTT/src/PicoMQTT/subscriber.cpp new file mode 100644 index 00000000..971195ea --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/subscriber.cpp @@ -0,0 +1,174 @@ +#include "subscriber.h" +#include "incoming_packet.h" +#include "debug.h" + +namespace PicoMQTT { + +String Subscriber::get_topic_element(const char * topic, size_t index) { + + while (index && topic[0]) { + if (topic++[0] == '/') { + --index; + } + } + + if (!topic[0]) { + return ""; + } + + const char * end = topic; + while (*end && *end != '/') { + ++end; + } + + String ret; + ret.concat(topic, end - topic); + return ret; +} + +String Subscriber::get_topic_element(const String & topic, size_t index) { + TRACE_FUNCTION + return get_topic_element(topic.c_str(), index); +} + +bool Subscriber::topic_matches(const char * p, const char * t) { + TRACE_FUNCTION + // TODO: Special handling of the $ prefix + while (true) { + switch (*p) { + case '\0': + // end of pattern reached + // TODO: check for '/#' suffix + return (*t == '\0'); + + case '#': + // multilevel wildcard + if (*t == '\0') { + return false; + } + return true; + + case '+': + // single level wildcard + while (*t && *t != '/') { + ++t; + } + ++p; + break; + + default: + // regular match + if (*p != *t) { + if (*t == '\0') + { + if (*p == '/') + { + ++p; + if (*p == '#') + { + ++p; + if (*p == '\0') + return true; + } + } + } + return false; + } + ++p; + ++t; + } + } +} + +const char * SubscribedMessageListener::get_subscription_pattern(SubscriptionId id) const { + TRACE_FUNCTION + for (const auto & kv : subscriptions) { + if (kv.first.id == id) { + return kv.first.c_str(); + } + } + return nullptr; +} + +Subscriber::SubscriptionId SubscribedMessageListener::get_subscription(const char * topic) const { + TRACE_FUNCTION + for (const auto & kv : subscriptions) { + if (topic_matches(kv.first.c_str(), topic)) { + return kv.first.id; + } + } + return 0; +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter) { + TRACE_FUNCTION + return subscribe(topic_filter, [this](const char * topic, IncomingPacket & packet) { on_extra_message(topic, packet); }); +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, MessageCallback callback) { + TRACE_FUNCTION + unsubscribe(topic_filter); + auto pair = subscriptions.emplace(std::make_pair(Subscription(topic_filter), callback)); + return pair.first->first.id; +} + +void SubscribedMessageListener::unsubscribe(const String & topic_filter) { + TRACE_FUNCTION + subscriptions.erase(topic_filter); +} + +void SubscribedMessageListener::fire_message_callbacks(const char * topic, IncomingPacket & packet) { + TRACE_FUNCTION + for (const auto & kv : subscriptions) { + if (topic_matches(kv.first.c_str(), topic)) { + kv.second((char *) topic, packet); + return; + } + } + on_extra_message(topic, packet); +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, + std::function callback, size_t max_size) { + TRACE_FUNCTION + return subscribe(topic_filter, [this, callback, max_size](char * topic, IncomingPacket & packet) { + const size_t payload_size = packet.get_remaining_size(); + if (payload_size >= max_size) { + on_message_too_big(topic, packet); + return; + } + char payload[payload_size + 1]; + if (packet.read((uint8_t *) payload, payload_size) != (int) payload_size) { + // connection error, ignore + return; + } + payload[payload_size] = '\0'; + callback(topic, payload, payload_size); + }); +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, + std::function callback, size_t max_size) { + TRACE_FUNCTION + return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) { + callback(topic, (char *) payload); + }); +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, + std::function callback, size_t max_size) { + TRACE_FUNCTION + return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) { + callback((char *) payload); + }); +} + +Subscriber::SubscriptionId SubscribedMessageListener::subscribe(const String & topic_filter, + std::function callback, size_t max_size) { + TRACE_FUNCTION + return subscribe(topic_filter, [callback](char * topic, void * payload, size_t payload_size) { + callback(payload, payload_size); + }); +} + +}; diff --git a/lib/PicoMQTT/src/PicoMQTT/subscriber.h b/lib/PicoMQTT/src/PicoMQTT/subscriber.h new file mode 100644 index 00000000..1fe5ace5 --- /dev/null +++ b/lib/PicoMQTT/src/PicoMQTT/subscriber.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include + +#include "autoid.h" +#include "config.h" + +namespace PicoMQTT { + +class IncomingPacket; + +class Subscriber { + public: + typedef AutoId::Id SubscriptionId; + + static bool topic_matches(const char * topic_filter, const char * topic); + static String get_topic_element(const char * topic, size_t index); + static String get_topic_element(const String & topic, size_t index); + + virtual const char * get_subscription_pattern(SubscriptionId id) const = 0; + virtual SubscriptionId get_subscription(const char * topic) const = 0; + + virtual SubscriptionId subscribe(const String & topic_filter) = 0; + + virtual void unsubscribe(const String & topic_filter) = 0; + void unsubscribe(SubscriptionId id) { unsubscribe(get_subscription_pattern(id)); } + + protected: + class Subscription: public String, public AutoId { + public: + using String::String; + Subscription(const String & str): Subscription(str.c_str()) {} + }; + +}; + +class SubscribedMessageListener: public Subscriber { + public: + // NOTE: None of the callback functions use const arguments for wider compatibility. It's still OK (and + // recommended) to use callbacks which take const arguments. Similarly with Strings. + typedef std::function MessageCallback; + + virtual const char * get_subscription_pattern(SubscriptionId id) const override; + virtual SubscriptionId get_subscription(const char * topic) const override; + + virtual SubscriptionId subscribe(const String & topic_filter) override; + virtual SubscriptionId subscribe(const String & topic_filter, MessageCallback callback); + + SubscriptionId subscribe(const String & topic_filter, std::function callback, + size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE); + + SubscriptionId subscribe(const String & topic_filter, std::function callback, + size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE); + SubscriptionId subscribe(const String & topic_filter, std::function callback, + size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE); + SubscriptionId subscribe(const String & topic_filter, std::function callback, + size_t max_size = PICOMQTT_MAX_MESSAGE_SIZE); + + virtual void unsubscribe(const String & topic_filter) override; + + virtual void on_extra_message(const char * topic, IncomingPacket & packet) {} + virtual void on_message_too_big(const char * topic, IncomingPacket & packet) {} + + protected: + void fire_message_callbacks(const char * topic, IncomingPacket & packet); + + std::map subscriptions; +}; + +} diff --git a/src/modules/exec/BrokerMQTT/modinfo.json b/src/modules/exec/BrokerMQTT/modinfo.json index 7cad1e16..3a5725f4 100644 --- a/src/modules/exec/BrokerMQTT/modinfo.json +++ b/src/modules/exec/BrokerMQTT/modinfo.json @@ -36,7 +36,6 @@ "defActive": false, "usedLibs": { "esp32_4mb3f": [], - "esp32*": [], - "esp82*": [] + "esp32*": [] } } \ No newline at end of file