Skip to content

Commit

Permalink
Additional MQTT options & support for external brokers/TLS (#1005)
Browse files Browse the repository at this point in the history
* Additional MQTT options & support for external brokers/TLS

* Run fix

* Remove comment

* Fewer options, some with defaults

* Test fixes

* Re-add DiscoveryPrefix
  • Loading branch information
scottt732 authored Jan 5, 2024
1 parent 9819e15 commit 27f3599
Show file tree
Hide file tree
Showing 12 changed files with 309 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License. See https://go.microsoft.com/fwlink/?linkid=2090316 for license information.
#-------------------------------------------------------------------------------------------------------------

FROM mcr.microsoft.com/dotnet/sdk:7.0.100-bullseye-slim-amd64
FROM mcr.microsoft.com/dotnet/sdk:8.0-bookworm-slim-amd64

RUN apt-get update && apt-get install -y ssh
# This Dockerfile adds a non-root 'vscode' user with sudo access. However, for Linux,
Expand Down
26 changes: 21 additions & 5 deletions Docker/rootfs/etc/services.d/netdaemon_addon/run
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,37 @@
# Optional MQTT configuration
declare Mqtt__Host
declare Mqtt__Port
declare Mqtt__Username
declare Mqtt__UserName
declare Mqtt__Password
declare Mqtt__DiscoveryPrefix
declare Mqtt__UseTls
declare Mqtt__AllowUntrustedCertificates

# Set configuration values to environment variables
NetDaemon__ApplicationAssembly=$(bashio::config 'app_assembly')
Logging__Loglevel__Default=$(bashio::config 'log_level')
NetDaemon__ApplicationConfigurationFolder=$(bashio::config 'app_config_folder')

if bashio::services.available "mqtt"; then
if bashio::config.has_value "mqtt_host"; then
Mqtt__Host=$(bashio::config 'mqtt_host')
Mqtt__Port=$(bashio::config 'mqtt_port' '1883')
Mqtt__UserName=$(bashio::config 'mqtt_username' '')
Mqtt__Password=$(bashio::config 'mqtt_password' '')
elif bashio::services.available "mqtt"; then
Mqtt__Host=$(bashio::services 'mqtt' 'host')
Mqtt__Port=$(bashio::services 'mqtt' 'port')
Mqtt__Username=$(bashio::services 'mqtt' 'username')
Mqtt__UserName=$(bashio::services 'mqtt' 'username')
Mqtt__Password=$(bashio::services 'mqtt' 'password')
else
bashio::log.warning \
"No MQTT add-on installed, MQTT features will not work."
"No MQTT add-on installed and manual broker configuration was not specified. MQTT features will not work."
fi

# Optional settings - Toggle "Show unused optional configuration options" to reveal
Mqtt__DiscoveryPrefix=$(bashio::config 'mqtt_discovery_prefix' 'homeassistant')
Mqtt__UseTls=$(bashio::config 'mqtt_use_tls' 'false')
Mqtt__AllowUntrustedCertificates=$(bashio::config 'mqtt_allow_untrusted_certificates' 'false')

export \
HomeAssistant__Host="supervisor" \
HomeAssistant__WebsocketPath="core/websocket" \
Expand All @@ -29,8 +42,11 @@ export \
HomeAssistant__Token="${SUPERVISOR_TOKEN}" \
Mqtt__Host \
Mqtt__Port \
Mqtt__Username \
Mqtt__UserName \
Mqtt__Password \
Mqtt__DiscoveryPrefix \
Mqtt__UseTls \
Mqtt__AllowUntrustedCertificates \
NetDaemon__ApplicationAssembly \
Logging__Loglevel__Default \
NetDaemon__ApplicationConfigurationFolder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,36 @@ public async Task CanGetClient()

var mqttClient = new Mock<IManagedMqttClient>();
var mqttFactory = new MqttFactoryWrapper(mqttClient.Object);
var mqttClientOptionsFactory = new Mock<IMqttClientOptionsFactory>();
var mqttConfigurationOptions = new Mock<IOptions<MqttConfiguration>>();

var conn = new AssuredMqttConnection(logger.Object, mqttFactory, GetMockOptions());
ConfigureMockOptions(mqttConfigurationOptions);

mqttClientOptionsFactory.Setup(f => f.CreateClientOptions(It.Is<MqttConfiguration>(o => o.Host == "localhost" && o.UserName == "id")))
.Returns(new ManagedMqttClientOptions())
.Verifiable(Times.Once);

var conn = new AssuredMqttConnection(logger.Object, mqttClientOptionsFactory.Object, mqttFactory, mqttConfigurationOptions.Object);
var returnedClient = await conn.GetClientAsync();

returnedClient.Should().Be(mqttClient.Object);

mqttClientOptionsFactory.VerifyAll();
mqttConfigurationOptions.VerifyAll();
}

private static IOptions<MqttConfiguration> GetMockOptions()
private static void ConfigureMockOptions(Mock<IOptions<MqttConfiguration>> mockOptions, Action<MqttConfiguration>? configuration = null)
{
var options = new Mock<IOptions<MqttConfiguration>>();
var mqttConfiguration = new MqttConfiguration
{
Host = "localhost",
UserName = "id"
};

options.Setup(o => o.Value)
.Returns(() => new MqttConfiguration
{
Host = "localhost", UserName = "id"
});
configuration?.Invoke(mqttConfiguration);

return options.Object;
mockOptions.SetupGet(o => o.Value)
.Returns(() => mqttConfiguration)
.Verifiable(Times.Once);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
using MQTTnet.Client;
using NetDaemon.Extensions.MqttEntityManager;

namespace NetDaemon.HassClient.Tests.ExtensionsTest.MqttEntityManagerTests;

public class MqttClientOptionsFactoryTests
{
private MqttClientOptionsFactory MqttClientOptionsFactory { get; } = new();

[Fact]
public void CreatesDefaultConfiguration()
{
// This is the bare minimum necessary to establish a connection to an MQTT broker that doesn't use TLS
// or require authentication. The default port is 1883 and a TCP connection is used.
var mqttConfiguration = new MqttConfiguration
{
Host = "broker",
};

var mqttClientOptions = MqttClientOptionsFactory.CreateClientOptions(mqttConfiguration);

mqttClientOptions.Should().NotBeNull();

mqttClientOptions.ClientOptions.ChannelOptions.Should().NotBeNull();
mqttClientOptions.ClientOptions.ChannelOptions.Should().BeOfType<MqttClientTcpOptions>();

var mqttClientChannelOptions = (MqttClientTcpOptions)mqttClientOptions.ClientOptions.ChannelOptions;
mqttClientChannelOptions.Server.Should().Be("broker");
mqttClientChannelOptions.Port.Should().Be(1883);

mqttClientOptions.ClientOptions.Credentials.Should().BeNull();

mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.UseTls.Should().BeFalse();
mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.AllowUntrustedCertificates.Should().BeFalse();
}

[Fact]
public void CreatesDefaultConfigurationWithTls()
{
var mqttConfiguration = new MqttConfiguration
{
Host = "broker",
UseTls = true
};

var mqttClientOptions = MqttClientOptionsFactory.CreateClientOptions(mqttConfiguration);

mqttClientOptions.Should().NotBeNull();

mqttClientOptions.ClientOptions.ChannelOptions.Should().NotBeNull();
mqttClientOptions.ClientOptions.ChannelOptions.Should().BeOfType<MqttClientTcpOptions>();

var mqttClientChannelOptions = (MqttClientTcpOptions)mqttClientOptions.ClientOptions.ChannelOptions;
mqttClientChannelOptions.Server.Should().Be("broker");
mqttClientChannelOptions.Port.Should().Be(1883);

mqttClientOptions.ClientOptions.Credentials.Should().BeNull();

mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.UseTls.Should().BeTrue();

// This would only get set to true if it and UseTls are both true
mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.AllowUntrustedCertificates.Should().BeFalse();
}

[Fact]
public void IgnoresTlsCustomizationIfTlsIsntEnabled()
{
var mqttConfiguration = new MqttConfiguration
{
Host = "broker",
UseTls = false,
AllowUntrustedCertificates = true
};

var mqttClientOptions = MqttClientOptionsFactory.CreateClientOptions(mqttConfiguration);

mqttClientOptions.Should().NotBeNull();

mqttClientOptions.ClientOptions.ChannelOptions.Should().NotBeNull();
mqttClientOptions.ClientOptions.ChannelOptions.Should().BeOfType<MqttClientTcpOptions>();

var mqttClientChannelOptions = (MqttClientTcpOptions)mqttClientOptions.ClientOptions.ChannelOptions;
mqttClientChannelOptions.Server.Should().Be("broker");
mqttClientChannelOptions.Port.Should().Be(1883);

mqttClientOptions.ClientOptions.Credentials.Should().BeNull();

mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.UseTls.Should().BeFalse();

// This would only get set to true if it and UseTls are both true
mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.AllowUntrustedCertificates.Should().BeFalse();
}

[Fact]
public void CreatesFullyCustomizedConfiguration()
{
var mqttConfiguration = new MqttConfiguration
{
Host = "broker",
Port = 1234,
UserName = "testuser",
Password = "testpassword",
UseTls = true,
AllowUntrustedCertificates = true
};

var mqttClientOptions = MqttClientOptionsFactory.CreateClientOptions(mqttConfiguration);

mqttClientOptions.Should().NotBeNull();

mqttClientOptions.ClientOptions.ChannelOptions.Should().NotBeNull();
mqttClientOptions.ClientOptions.ChannelOptions.Should().BeOfType<MqttClientTcpOptions>();

var mqttClientChannelOptions = (MqttClientTcpOptions)mqttClientOptions.ClientOptions.ChannelOptions;
mqttClientChannelOptions.Server.Should().Be("broker");
mqttClientChannelOptions.Port.Should().Be(1234);

mqttClientOptions.ClientOptions.Credentials.Should().NotBeNull();
mqttClientOptions.ClientOptions.Credentials.Should().BeOfType<MqttClientCredentials>();

mqttClientOptions.ClientOptions.Credentials.GetUserName(mqttClientOptions.ClientOptions).Should().Be("testuser");
mqttClientOptions.ClientOptions.Credentials.GetPassword(mqttClientOptions.ClientOptions).Should().BeEquivalentTo(Encoding.UTF8.GetBytes("testpassword"));

mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.UseTls.Should().BeTrue();
mqttClientOptions.ClientOptions.ChannelOptions.TlsOptions.AllowUntrustedCertificates.Should().BeTrue();
}

[Fact]
void ThrowsArgumentNullExceptionIfMqttConfigIsNull()
{
Assert.Throws<ArgumentNullException>(() => MqttClientOptionsFactory.CreateClientOptions(null!));
}

[Theory]
[InlineData(null)]
[InlineData("")]
void ThrowsArgumentExceptionIfMqttConfigHasNullOrEmptyHost(string? host)
{
var mqttConfiguration = new MqttConfiguration
{
Host = host!,
};

Assert.Throws<ArgumentException>(() => MqttClientOptionsFactory.CreateClientOptions(mqttConfiguration));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public async Task CreateWithNoOptionsSetsBaseConfig()
var entityManager = new MqttEntityManager(mqttSetup.MessageSender, null!, GetOptions());

await entityManager.CreateAsync("domain.sensor");
var payload = PayloadToDictionary(mqttSetup.LastPublishedMessage.PayloadSegment.Array ?? Array.Empty<byte>() );
var payload = PayloadToDictionary(mqttSetup.LastPublishedMessage.PayloadSegment.Array ?? Array.Empty<byte>());

payload?.Count.Should().Be(6);
payload?["name"].ToString().Should().Be("sensor");
Expand Down Expand Up @@ -269,7 +269,9 @@ private static IOptions<MqttConfiguration> GetOptions()
options.Setup(o => o.Value)
.Returns(() => new MqttConfiguration
{
Host = "localhost", UserName = "id", DiscoveryPrefix = "homeassistant"
Host = "localhost",
UserName = "id",
DiscoveryPrefix = "homeassistant"
});

return options.Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ internal sealed class MockMqttMessageSenderSetup
{
public AssuredMqttConnection Connection { get; private set; } = null!;
public Mock<IManagedMqttClient> MqttClient { get; private set; } = null!;
public Mock<IMqttClientOptionsFactory> MqttClientOptionsFactory { get; private set; } = null!;

public MqttFactoryWrapper MqttFactory { get; private set; } = null!;
public MessageSender MessageSender { get; private set; } = null!;
public MqttApplicationMessage LastPublishedMessage { get; set; } = null!;
Expand Down Expand Up @@ -41,18 +43,29 @@ public void SetupMessageReceiver()
/// <returns></returns>
private void SetupMockMqtt()
{
var mqttConfiguration = new MqttConfiguration
{
Host = "localhost",
UserName = "id"
};

var options = new Mock<IOptions<MqttConfiguration>>();

options.Setup(o => o.Value)
.Returns(() => new MqttConfiguration
{
Host = "localhost", UserName = "id"
});
.Returns(() => mqttConfiguration);

MqttClient = new Mock<IManagedMqttClient>();
MqttClientOptionsFactory = new Mock<IMqttClientOptionsFactory>();
MqttFactory = new MqttFactoryWrapper(MqttClient.Object);

Connection = new AssuredMqttConnection(new Mock<ILogger<AssuredMqttConnection>>().Object, MqttFactory,
MqttClientOptionsFactory
.Setup(o => o.CreateClientOptions(mqttConfiguration))
.Returns(new ManagedMqttClientOptions());

Connection = new AssuredMqttConnection(
new Mock<ILogger<AssuredMqttConnection>>().Object,
MqttClientOptionsFactory.Object,
MqttFactory,
options.Object);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,26 @@ namespace NetDaemon.Extensions.MqttEntityManager;
internal class AssuredMqttConnection : IAssuredMqttConnection, IDisposable
{
private readonly ILogger<AssuredMqttConnection> _logger;
private readonly IMqttClientOptionsFactory _mqttClientOptionsFactory;
private readonly Task _connectionTask;
private IManagedMqttClient? _mqttClient;
private bool _disposed;

/// <summary>
/// Wrapper to assure an MQTT connection
/// Initializes a new instance of the <see cref="AssuredMqttConnection"/> class.
/// </summary>
/// <param name="logger"></param>
/// <param name="mqttFactory"></param>
/// <param name="mqttConfig"></param>
/// <param name="logger">The logger.</param>
/// <param name="mqttClientOptionsFactory">The MQTT client options factory.</param>
/// <param name="mqttFactory">The MQTT factory wrapper.</param>
/// <param name="mqttConfig">The MQTT configuration.</param>
public AssuredMqttConnection(
ILogger<AssuredMqttConnection> logger,
IMqttClientOptionsFactory mqttClientOptionsFactory,
IMqttFactoryWrapper mqttFactory,
IOptions<MqttConfiguration> mqttConfig)
{
_logger = logger;

_mqttClientOptionsFactory = mqttClientOptionsFactory;
_logger.LogTrace("MQTT initiating connection");
_connectionTask = Task.Run(() => ConnectAsync(mqttConfig.Value, mqttFactory));
}
Expand All @@ -53,12 +56,7 @@ private async Task ConnectAsync(MqttConfiguration mqttConfig, IMqttFactoryWrappe
_logger.LogTrace("Connecting to MQTT broker at {Host}:{Port}/{UserName}",
mqttConfig.Host, mqttConfig.Port, mqttConfig.UserName);

var clientOptions = new ManagedMqttClientOptionsBuilder()
.WithAutoReconnectDelay(TimeSpan.FromSeconds(5))
.WithClientOptions(new MqttClientOptionsBuilder()
.WithTcpServer(mqttConfig.Host, mqttConfig.Port)
.WithCredentials(mqttConfig.UserName, mqttConfig.Password))
.Build();
var clientOptions = _mqttClientOptionsFactory.CreateClientOptions(mqttConfig);

_mqttClient = mqttFactory.CreateManagedMqttClient();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public static IHostBuilder UseNetDaemonMqttEntityManagement(this IHostBuilder ho
return hostBuilder.ConfigureServices((context, services) =>
{
services.AddSingleton<IMqttFactory, MqttFactoryFactory>();
services.AddSingleton<IMqttClientOptionsFactory, MqttClientOptionsFactory>();
services.AddSingleton<IMqttFactoryWrapper, MqttFactoryWrapper>();
services.AddSingleton<IMqttEntityManager, MqttEntityManager>();
services.AddSingleton<IAssuredMqttConnection, AssuredMqttConnection>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using MQTTnet.Extensions.ManagedClient;

namespace NetDaemon.Extensions.MqttEntityManager;

/// <summary>
/// Represents a factory for creating MQTT client options.
/// </summary>
public interface IMqttClientOptionsFactory
{
/// <summary>
/// Creates the client options for MQTT connection from the supplied configuration.
/// /// </summary>
/// <param name="mqttConfig">The MQTT configuration.</param>
/// <returns>The managed MQTT client options.</returns>
ManagedMqttClientOptions CreateClientOptions(MqttConfiguration mqttConfig);
}
Loading

0 comments on commit 27f3599

Please sign in to comment.