Fully moved to Ninject, fixed issues with medusa (un)loadabaility

This commit is contained in:
Kwoth
2023-03-10 01:11:43 +01:00
parent ff066b6473
commit e91646594f
10 changed files with 283 additions and 111 deletions

View File

@@ -8,7 +8,9 @@ using NadekoBot.Modules.Utility;
using NadekoBot.Services.Database.Models; using NadekoBot.Services.Database.Models;
using Ninject; using Ninject;
using Ninject.Extensions.Conventions; using Ninject.Extensions.Conventions;
using Ninject.Extensions.Conventions.Syntax;
using Ninject.Infrastructure.Language; using Ninject.Infrastructure.Language;
using Ninject.Planning;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
using System.Reflection; using System.Reflection;
@@ -102,7 +104,14 @@ public sealed class Bot
AllGuildConfigs = uow.GuildConfigs.GetAllGuildConfigs(startingGuildIdList).ToImmutableArray(); AllGuildConfigs = uow.GuildConfigs.GetAllGuildConfigs(startingGuildIdList).ToImmutableArray();
} }
var kernel = new StandardKernel(); var kernel = new StandardKernel(new NinjectSettings()
{
ThrowOnGetServiceNotFound = true,
ActivationCacheDisabled = true,
});
kernel.Components.Remove<IPlanner, Planner>();
kernel.Components.Add<IPlanner, RemovablePlanner>();
kernel.Bind<IBotCredentials>().ToMethod(_ => _credsProvider.GetCreds()).InTransientScope(); kernel.Bind<IBotCredentials>().ToMethod(_ => _credsProvider.GetCreds()).InTransientScope();
@@ -122,7 +131,6 @@ public sealed class Bot
.AddCache(_creds) .AddCache(_creds)
.AddHttpClients(); .AddHttpClients();
if (Environment.GetEnvironmentVariable("NADEKOBOT_IS_COORDINATED") != "1") if (Environment.GetEnvironmentVariable("NADEKOBOT_IS_COORDINATED") != "1")
{ {
kernel.Bind<ICoordinator>().To<SingleProcessCoordinator>().InSingletonScope(); kernel.Bind<ICoordinator>().To<SingleProcessCoordinator>().InSingletonScope();
@@ -134,35 +142,25 @@ public sealed class Bot
kernel.Bind(scan => kernel.Bind(scan =>
{ {
var classes = scan.FromThisAssembly() scan.FromThisAssembly()
.SelectAllClasses() .SelectAllClasses()
.Where(c => (c.IsAssignableTo(typeof(INService)) .Where(c => (c.IsAssignableTo(typeof(INService))
|| c.IsAssignableTo(typeof(IExecOnMessage)) || c.IsAssignableTo(typeof(IExecOnMessage))
|| c.IsAssignableTo(typeof(IInputTransformer)) || c.IsAssignableTo(typeof(IInputTransformer))
|| c.IsAssignableTo(typeof(IExecPreCommand)) || c.IsAssignableTo(typeof(IExecPreCommand))
|| c.IsAssignableTo(typeof(IExecPostCommand)) || c.IsAssignableTo(typeof(IExecPostCommand))
|| c.IsAssignableTo(typeof(IExecNoCommand))) || c.IsAssignableTo(typeof(IExecNoCommand)))
&& !c.HasAttribute<DontAddToIocContainerAttribute>() && !c.HasAttribute<DontAddToIocContainerAttribute>()
#if GLOBAL_NADEKO #if GLOBAL_NADEK
&& !c.HasAttribute<NoPublicBotAttribute>() && !c.HasAttribute<NoPublicBotAttribute>()
#endif #endif
); )
classes .BindToSelfWithInterfaces()
.BindAllInterfaces()
.Configure(c => c.InSingletonScope()); .Configure(c => c.InSingletonScope());
classes.BindToSelf()
.Configure(c => c.InSingletonScope());
}); });
kernel.Bind<IServiceProvider>().ToConstant(kernel).InSingletonScope(); kernel.Bind<IServiceProvider>().ToConstant(kernel).InSingletonScope();
var services = kernel.GetServices(typeof(INService));
foreach (var s in services)
{
Console.WriteLine(s.GetType().FullName);
}
//initialize Services //initialize Services
Services = kernel; Services = kernel;
Services.GetRequiredService<IBehaviorHandler>().Initialize(); Services.GetRequiredService<IBehaviorHandler>().Initialize();
@@ -187,16 +185,7 @@ public sealed class Bot
private IEnumerable<object> LoadTypeReaders(Assembly assembly) private IEnumerable<object> LoadTypeReaders(Assembly assembly)
{ {
Type[] allTypes; var allTypes = assembly.GetTypes();
try
{
allTypes = assembly.GetTypes();
}
catch (ReflectionTypeLoadException ex)
{
Log.Warning(ex.LoaderExceptions[0], "Error getting types");
return Enumerable.Empty<object>();
}
var filteredTypes = allTypes.Where(x => x.IsSubclassOf(typeof(TypeReader)) var filteredTypes = allTypes.Where(x => x.IsSubclassOf(typeof(TypeReader))
&& x.BaseType?.GetGenericArguments().Length > 0 && x.BaseType?.GetGenericArguments().Length > 0
@@ -205,10 +194,12 @@ public sealed class Bot
var toReturn = new List<object>(); var toReturn = new List<object>();
foreach (var ft in filteredTypes) foreach (var ft in filteredTypes)
{ {
var x = (TypeReader)ActivatorUtilities.CreateInstance(Services, ft);
var baseType = ft.BaseType; var baseType = ft.BaseType;
if (baseType is null) if (baseType is null)
continue; continue;
var x = (TypeReader)ActivatorUtilities.CreateInstance(Services, ft);
var typeArgs = baseType.GetGenericArguments(); var typeArgs = baseType.GetGenericArguments();
_commandService.AddTypeReader(typeArgs[0], x); _commandService.AddTypeReader(typeArgs[0], x);
toReturn.Add(x); toReturn.Add(x);

View File

@@ -3,34 +3,33 @@ using System.Runtime.Loader;
namespace Nadeko.Medusa; namespace Nadeko.Medusa;
public sealed class MedusaAssemblyLoadContext : AssemblyLoadContext public class MedusaAssemblyLoadContext : AssemblyLoadContext
{ {
private readonly AssemblyDependencyResolver _depResolver; private readonly AssemblyDependencyResolver _resolver;
public MedusaAssemblyLoadContext(string pluginPath) : base(isCollectible: true) public MedusaAssemblyLoadContext(string folderPath) : base(isCollectible: true)
{ => _resolver = new(folderPath);
_depResolver = new(pluginPath);
} // public Assembly MainAssembly { get; private set; }
protected override Assembly? Load(AssemblyName assemblyName) protected override Assembly? Load(AssemblyName assemblyName)
{ {
var assemblyPath = _depResolver.ResolveAssemblyToPath(assemblyName); var assemblyPath = _resolver.ResolveAssemblyToPath(assemblyName);
if (assemblyPath != null) if (assemblyPath != null)
{ {
return LoadFromAssemblyPath(assemblyPath); Assembly assembly = LoadFromAssemblyPath(assemblyPath);
LoadDependencies(assembly);
return assembly;
} }
return null; return null;
} }
protected override IntPtr LoadUnmanagedDll(string unmanagedDllName) public void LoadDependencies(Assembly assembly)
{ {
var libraryPath = _depResolver.ResolveUnmanagedDllToPath(unmanagedDllName); foreach (var reference in assembly.GetReferencedAssemblies())
if (libraryPath != null)
{ {
return LoadUnmanagedDllFromPath(libraryPath); Load(reference);
} }
return IntPtr.Zero;
} }
} }

View File

@@ -1,50 +1,95 @@
using Ninject.Modules; using System.Reflection;
using Ninject.Extensions.Conventions; using Ninject;
using System.Reflection; using Ninject.Activation;
using Ninject.Activation.Caching;
using Ninject.Modules;
using Ninject.Planning;
using System.Text.Json;
namespace Nadeko.Medusa; public sealed class MedusaNinjectModule : NinjectModule
public sealed class MedusaIoCKernelModule : NinjectModule
{ {
private Assembly _a;
public override string Name { get; } public override string Name { get; }
private volatile bool _isLoaded = false;
private readonly Dictionary<Type, Type[]> _types;
public MedusaIoCKernelModule(string name, Assembly a) public MedusaNinjectModule(Assembly assembly, string name)
{ {
Name = name; Name = name;
_a = a; _types = assembly.GetExportedTypes()
.Where(t => t.IsClass)
.Where(t => t.GetCustomAttribute<svcAttribute>() is not null)
.ToDictionary(x => x,
type => type.GetInterfaces().ToArray());
} }
public override void Load() public override void Load()
{ {
// todo cehck for duplicate registrations with ninject.extensions.convention if (_isLoaded)
Kernel.Bind(conf => return;
foreach (var (type, data) in _types)
{ {
var transient = conf.From(_a) var attribute = type.GetCustomAttribute<svcAttribute>()!;
.SelectAllClasses() var scope = GetScope(attribute.Lifetime);
.WithAttribute<svcAttribute>(x => x.Lifetime == Lifetime.Transient);
transient.BindAllInterfaces().Configure(x => x.InTransientScope()); Bind(type)
transient.BindToSelf().Configure(x => x.InTransientScope()); .ToSelf()
.InScope(scope);
var singleton = conf.From(_a) foreach (var inter in data)
.SelectAllClasses() {
.WithAttribute<svcAttribute>(x => x.Lifetime == Lifetime.Singleton); Bind(inter)
.ToMethod(x => x.Kernel.Get(type))
.InScope(scope);
}
}
singleton.BindAllInterfaces().Configure(x => x.InSingletonScope()); _isLoaded = true;
singleton.BindToSelf().Configure(x => x.InSingletonScope());
});
} }
private Func<IContext, object?> GetScope(Lifetime lt)
=> _ => lt switch
{
Lifetime.Singleton => this,
Lifetime.Transient => null,
};
public override void Unload() public override void Unload()
{ {
// todo implement unload if (!_isLoaded)
// Kernel.Unbind(); return;
}
public override void Dispose(bool disposing) var planner = (RemovablePlanner)Kernel.Components.Get<IPlanner>();
{ var cache = Kernel.Components.Get<ICache>();
_a = null!; foreach (var binding in this.Bindings)
base.Dispose(disposing); {
Kernel.RemoveBinding(binding);
}
foreach (var type in _types.SelectMany(x => x.Value).Concat(_types.Keys))
{
var binds = Kernel.GetBindings(type);
if (!binds.Any())
{
Unbind(type);
planner.RemovePlan(type);
}
}
Bindings.Clear();
cache.Clear(this);
_types.Clear();
// in case the library uses System.Text.Json
var assembly = typeof(JsonSerializerOptions).Assembly;
var updateHandlerType = assembly.GetType("System.Text.Json.JsonSerializerOptionsUpdateHandler");
var clearCacheMethod = updateHandlerType?.GetMethod("ClearCache", BindingFlags.Static | BindingFlags.Public);
clearCacheMethod?.Invoke(null, new object?[] { null });
_isLoaded = false;
} }
} }

View File

@@ -237,8 +237,10 @@ public sealed class MedusaLoaderService : IMedusaLoaderService, IReadyExecutor,
SnekInfos: snekData.ToImmutableArray(), SnekInfos: snekData.ToImmutableArray(),
strings, strings,
typeReaders, typeReaders,
execs, execs)
kernelModule); {
KernelModule = kernelModule
};
_medusaConfig.AddLoadedMedusa(safeName); _medusaConfig.AddLoadedMedusa(safeName);
@@ -319,18 +321,28 @@ public sealed class MedusaLoaderService : IMedusaLoaderService, IReadyExecutor,
ctxWr = null; ctxWr = null;
snekData = null; snekData = null;
var path = $"{BASE_DIR}/{safeName}/{safeName}.dll"; var path = Path.GetFullPath($"{BASE_DIR}/{safeName}/{safeName}.dll");
strings = MedusaStrings.CreateDefault($"{BASE_DIR}/{safeName}"); var dir = Path.GetFullPath($"{BASE_DIR}/{safeName}");
if (!Directory.Exists(dir))
throw new DirectoryNotFoundException($"Medusa folder not found: {dir}");
if (!File.Exists(path))
throw new FileNotFoundException($"Medusa dll not found: {path}");
strings = MedusaStrings.CreateDefault(dir);
var ctx = new MedusaAssemblyLoadContext(Path.GetDirectoryName(path)!); var ctx = new MedusaAssemblyLoadContext(Path.GetDirectoryName(path)!);
var a = ctx.LoadFromAssemblyPath(Path.GetFullPath(path)); var a = ctx.LoadFromAssemblyPath(Path.GetFullPath(path));
ctx.LoadDependencies(a);
// load services // load services
ninjectModule = new MedusaIoCKernelModule(safeName, a); ninjectModule = new MedusaNinjectModule(a, safeName);
_kernel.Load(ninjectModule); _kernel.Load(ninjectModule);
var sis = LoadSneksFromAssembly(safeName, a); var sis = LoadSneksFromAssembly(safeName, a);
typeReaders = LoadTypeReadersFromAssembly(a, strings); typeReaders = LoadTypeReadersFromAssembly(a, strings);
// todo allow this
if (sis.Count == 0) if (sis.Count == 0)
{ {
_kernel.Unload(safeName); _kernel.Unload(safeName);
@@ -590,8 +602,14 @@ public sealed class MedusaLoaderService : IMedusaLoaderService, IReadyExecutor,
await DisposeSnekInstances(lsi); await DisposeSnekInstances(lsi);
var lc = lsi.LoadContext; var lc = lsi.LoadContext;
var km = lsi.KernelModule;
lsi.KernelModule = null!;
_kernel.Unload(km.Name);
if (km is IDisposable d)
d.Dispose();
// lsi.KernelModule = null!;
lsi = null; lsi = null;
_medusaConfig.RemoveLoadedMedusa(name); _medusaConfig.RemoveLoadedMedusa(name);
@@ -650,7 +668,9 @@ public sealed class MedusaLoaderService : IMedusaLoaderService, IReadyExecutor,
private void UnloadContext(WeakReference<MedusaAssemblyLoadContext> lsiLoadContext) private void UnloadContext(WeakReference<MedusaAssemblyLoadContext> lsiLoadContext)
{ {
if (lsiLoadContext.TryGetTarget(out var ctx)) if (lsiLoadContext.TryGetTarget(out var ctx))
{
ctx.Unload(); ctx.Unload();
}
} }
private void GcCleanup() private void GcCleanup()

View File

@@ -9,7 +9,8 @@ public sealed record ResolvedMedusa(
IImmutableList<SnekInfo> SnekInfos, IImmutableList<SnekInfo> SnekInfos,
IMedusaStrings Strings, IMedusaStrings Strings,
Dictionary<Type, TypeReader> TypeReaders, Dictionary<Type, TypeReader> TypeReaders,
IReadOnlyCollection<ICustomBehavior> Execs, IReadOnlyCollection<ICustomBehavior> Execs
INinjectModule KernelModule) )
{ {
public INinjectModule KernelModule { get; set; }
} }

View File

@@ -0,0 +1,122 @@
//-------------------------------------------------------------------------------
// <copyright file="Planner.cs" company="Ninject Project Contributors">
// Copyright (c) 2007-2009, Enkari, Ltd.
// Copyright (c) 2009-2011 Ninject Project Contributors
// Authors: Nate Kohari (nate@enkari.com)
// Remo Gloor (remo.gloor@gmail.com)
//
// Dual-licensed under the Apache License, Version 2.0, and the Microsoft Public License (Ms-PL).
// you may not use this file except in compliance with one of the Licenses.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// or
// http://www.microsoft.com/opensource/licenses.mspx
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// </copyright>
//-------------------------------------------------------------------------------
// ReSharper disable all
#pragma warning disable
namespace Ninject.Planning;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Ninject.Components;
using Ninject.Infrastructure.Language;
using Ninject.Planning.Strategies;
/// <summary>
/// Generates plans for how to activate instances.
/// </summary>
public class RemovablePlanner : NinjectComponent, IPlanner
{
private readonly ReaderWriterLock plannerLock = new ReaderWriterLock();
private readonly Dictionary<Type, IPlan> plans = new Dictionary<Type, IPlan>();
/// <summary>
/// Initializes a new instance of the <see cref="RemovablePlanner"/> class.
/// </summary>
/// <param name="strategies">The strategies to execute during planning.</param>
public RemovablePlanner(IEnumerable<IPlanningStrategy> strategies)
{
this.Strategies = strategies.ToList();
}
/// <summary>
/// Gets the strategies that contribute to the planning process.
/// </summary>
public IList<IPlanningStrategy> Strategies { get; private set; }
/// <summary>
/// Gets or creates an activation plan for the specified type.
/// </summary>
/// <param name="type">The type for which a plan should be created.</param>
/// <returns>The type's activation plan.</returns>
public IPlan GetPlan(Type type)
{
this.plannerLock.AcquireReaderLock(Timeout.Infinite);
try
{
IPlan plan;
return this.plans.TryGetValue(type, out plan) ? plan : this.CreateNewPlan(type);
}
finally
{
this.plannerLock.ReleaseReaderLock();
}
}
/// <summary>
/// Creates an empty plan for the specified type.
/// </summary>
/// <param name="type">The type for which a plan should be created.</param>
/// <returns>The created plan.</returns>
protected virtual IPlan CreateEmptyPlan(Type type)
{
return new Plan(type);
}
/// <summary>
/// Creates a new plan for the specified type.
/// This method requires an active reader lock!
/// </summary>
/// <param name="type">The type.</param>
/// <returns>The newly created plan.</returns>
private IPlan CreateNewPlan(Type type)
{
var lockCooki = this.plannerLock.UpgradeToWriterLock(Timeout.Infinite);
try
{
IPlan plan;
if (this.plans.TryGetValue(type, out plan))
{
return plan;
}
plan = this.CreateEmptyPlan(type);
this.plans.Add(type, plan);
this.Strategies.Map(s => s.Execute(plan));
return plan;
}
finally
{
this.plannerLock.DowngradeFromWriterLock(ref lockCooki);
}
}
public void RemovePlan(Type type)
{
plans.Remove(type);
plans.TrimExcess();
}
}

View File

@@ -482,10 +482,10 @@ public abstract class NadekoContext : DbContext
#endregion #endregion
} }
#if DEBUG // #if DEBUG
private static readonly ILoggerFactory _debugLoggerFactory = LoggerFactory.Create(x => x.AddConsole()); // private static readonly ILoggerFactory _debugLoggerFactory = LoggerFactory.Create(x => x.AddConsole());
//
protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) // protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
=> optionsBuilder.UseLoggerFactory(_debugLoggerFactory); // => optionsBuilder.UseLoggerFactory(_debugLoggerFactory);
#endif // #endif
} }

View File

@@ -1,7 +0,0 @@
<Project>
<ItemDefinitionGroup>
<ProjectReference>
<PrivateAssets>all</PrivateAssets>
</ProjectReference>
</ItemDefinitionGroup>
</Project>

View File

@@ -55,6 +55,7 @@ public class CommandHandler : INService, IReadyExecutor
_prefixes = bot.AllGuildConfigs.Where(x => x.Prefix is not null) _prefixes = bot.AllGuildConfigs.Where(x => x.Prefix is not null)
.ToDictionary(x => x.GuildId, x => x.Prefix) .ToDictionary(x => x.GuildId, x => x.Prefix)
.ToConcurrent(); .ToConcurrent();
} }
public async Task OnReadyAsync() public async Task OnReadyAsync()

View File

@@ -5,6 +5,7 @@ using NadekoBot.Modules.Music.Resolvers;
using NadekoBot.Modules.Music.Services; using NadekoBot.Modules.Music.Services;
using Ninject; using Ninject;
using Ninject.Extensions.Conventions; using Ninject.Extensions.Conventions;
using Ninject.Extensions.Conventions.Syntax;
using StackExchange.Redis; using StackExchange.Redis;
using System.Net; using System.Net;
using System.Reflection; using System.Reflection;
@@ -39,10 +40,7 @@ public static class ServiceCollectionExtensions
.SelectAllClasses() .SelectAllClasses()
.Where(f => f.IsAssignableToGenericType(typeof(ConfigServiceBase<>))); .Where(f => f.IsAssignableToGenericType(typeof(ConfigServiceBase<>)));
// todo check for duplicates configs.BindToSelfWithInterfaces()
configs.BindToSelf()
.Configure(c => c.InSingletonScope());
configs.BindAllInterfaces()
.Configure(c => c.InSingletonScope()); .Configure(c => c.InSingletonScope());
}); });
@@ -64,7 +62,7 @@ public static class ServiceCollectionExtensions
kernel.Bind<ILocalTrackResolver>().To<LocalTrackResolver>().InSingletonScope(); kernel.Bind<ILocalTrackResolver>().To<LocalTrackResolver>().InSingletonScope();
kernel.Bind<IRadioResolver>().To<RadioResolver>().InSingletonScope(); kernel.Bind<IRadioResolver>().To<RadioResolver>().InSingletonScope();
kernel.Bind<ITrackCacher>().To<TrackCacher>().InSingletonScope(); kernel.Bind<ITrackCacher>().To<TrackCacher>().InSingletonScope();
kernel.Bind<YtLoader>().ToSelf().InSingletonScope(); // kernel.Bind<YtLoader>().ToSelf().InSingletonScope();
return kernel; return kernel;
} }
@@ -77,8 +75,7 @@ public static class ServiceCollectionExtensions
.SelectAllClasses() .SelectAllClasses()
.Where(c => c.IsPublic && c.IsNested && baseType.IsAssignableFrom(baseType)); .Where(c => c.IsPublic && c.IsNested && baseType.IsAssignableFrom(baseType));
classes.BindAllInterfaces().Configure(x => x.InSingletonScope()); classes.BindToSelfWithInterfaces().Configure(x => x.InSingletonScope());
classes.BindToSelf().Configure(x => x.InSingletonScope());
}); });
return kernel; return kernel;
@@ -125,4 +122,7 @@ public static class ServiceCollectionExtensions
return kernel; return kernel;
} }
public static IConfigureSyntax BindToSelfWithInterfaces(this IJoinExcludeIncludeBindSyntax matcher)
=> matcher.BindSelection((type, types) => types.Append(type));
} }