#nullable disable
using LinqToDB;
using LinqToDB.Data;
using LinqToDB.EntityFrameworkCore;
using NadekoBot.Common.ModuleBehaviors;
using NadekoBot.Db.Models;
using System.Collections.Frozen;

namespace NadekoBot.Modules.Permissions.Services;

// NOTE(review): the generic type arguments in this file were reconstructed from usage
// (ulong ids, BlacklistEntry rows, bool pub/sub payload, Task<bool> results).
// The IReadOnlyList<> wrapper returned by GetBlacklist and the DiscordUser table type
// are best guesses - confirm against the callers and NadekoBot.Db.Models.

/// <summary>
/// Keeps in-memory sets of blacklisted guild, channel and user ids, reports whether an
/// incoming message originates from one of them, and exposes methods that add or remove
/// blacklist entries in the database.
/// </summary>
public sealed class BlacklistService : IExecOnMessage, IReadyExecutor
{
    /// <summary>
    /// Always <see cref="int.MaxValue"/>.
    /// NOTE(review): presumably orders this handler relative to other
    /// <see cref="IExecOnMessage"/> implementations - confirm the ordering semantics.
    /// </summary>
    public int Priority
        => int.MaxValue;

    private readonly DbService _db;
    private readonly IPubSub _pubSub;
    private readonly IBotCreds _creds;
    private readonly DiscordSocketClient _client;

    // Each set is replaced wholesale by OnReload; it is never mutated in place.
    private FrozenSet<ulong> blacklistedGuilds = FrozenSet<ulong>.Empty;
    private FrozenSet<ulong> blacklistedUsers = FrozenSet<ulong>.Empty;
    private FrozenSet<ulong> blacklistedChannels = FrozenSet<ulong>.Empty;

    // Pub/sub key used to signal that the blacklist should be reloaded.
    private readonly TypedKey<bool> _blPubKey = new("blacklist.reload");

    /// <summary>
    /// Stores the dependencies and subscribes to the "blacklist.reload" key so that a
    /// published message triggers <see cref="Reload"/> without publishing again.
    /// </summary>
    public BlacklistService(
        DbService db,
        IPubSub pubSub,
        IBotCreds creds,
        DiscordSocketClient client)
    {
        _db = db;
        _pubSub = pubSub;
        _creds = creds;
        _client = client;

        _pubSub.Sub(_blPubKey, async _ => await Reload(false));
    }

    /// <summary>
    /// Registers a <c>JoinedGuild</c> handler that leaves any guild whose id is
    /// blacklisted, then loads the blacklist without publishing a reload message.
    /// </summary>
    public async Task OnReadyAsync()
    {
        _client.JoinedGuild += async (g) =>
        {
            if (blacklistedGuilds.Contains(g.Id))
            {
                await g.LeaveAsync();
            }
        };

        await Reload(false);
    }

    /// <summary>
    /// Rebuilds the three in-memory id sets from the supplied entries.
    /// Does nothing when <paramref name="newBlacklist"/> is null.
    /// </summary>
    /// <param name="newBlacklist">Entries to split by <see cref="BlacklistType"/>.</param>
    private ValueTask OnReload(BlacklistEntry[] newBlacklist)
    {
        if (newBlacklist is null)
        {
            return default;
        }

        // Collects the ids of all entries of one type into a frozen set.
        FrozenSet<ulong> IdsOf(BlacklistType type)
            => newBlacklist
               .Where(x => x.Type == type)
               .Select(x => x.ItemId)
               .ToFrozenSet();

        blacklistedGuilds = IdsOf(BlacklistType.Server);
        blacklistedChannels = IdsOf(BlacklistType.Channel);
        blacklistedUsers = IdsOf(BlacklistType.User);

        return default;
    }

    /// <summary>
    /// Checks whether a message comes from a blacklisted guild, channel or user.
    /// </summary>
    /// <param name="guild">Guild the message was sent in.</param>
    /// <param name="usrMsg">The message being checked.</param>
    /// <returns>
    /// A completed task holding <c>true</c> when the guild, channel or author is
    /// blacklisted (the block is logged), otherwise <c>false</c>.
    /// </returns>
    public Task<bool> ExecOnMessageAsync(IGuild guild, IUserMessage usrMsg)
    {
        // NOTE(review): guild is presumably null for messages sent outside a guild;
        // the null check avoids a NullReferenceException in that case - confirm with callers.
        if (guild is not null && blacklistedGuilds.Contains(guild.Id))
        {
            Log.Information("Blocked input from blacklisted guild: {GuildName} [{GuildId}]",
                guild.Name,
                guild.Id.ToString());

            return Task.FromResult(true);
        }

        if (blacklistedChannels.Contains(usrMsg.Channel.Id))
        {
            Log.Information("Blocked input from blacklisted channel: {ChannelName} [{ChannelId}]",
                usrMsg.Channel.Name,
                usrMsg.Channel.Id.ToString());

            // This return was missing: the block was logged but the message was still let through.
            return Task.FromResult(true);
        }

        if (blacklistedUsers.Contains(usrMsg.Author.Id))
        {
            Log.Information("Blocked input from blacklisted user: {UserName} [{UserId}]",
                usrMsg.Author.ToString(),
                usrMsg.Author.Id.ToString());

            return Task.FromResult(true);
        }

        return Task.FromResult(false);
    }

    /// <summary>
    /// Reads all blacklist entries of the given type from the database.
    /// </summary>
    /// <param name="type">Entry type to filter by.</param>
    /// <returns>The matching entries.</returns>
    public async Task<IReadOnlyList<BlacklistEntry>> GetBlacklist(BlacklistType type)
    {
        await using var uow = _db.GetDbContext();

        return await uow
                     .GetTable<BlacklistEntry>()
                     .Where(x => x.Type == type)
                     .ToListAsync();
    }

    /// <summary>
    /// Loads the blacklist from the database and replaces the in-memory sets.
    /// Server entries are included only when <c>Linq2DbExpressions.GuildOnShard</c>
    /// accepts them for this client's shard; all other entry types are always included.
    /// </summary>
    /// <param name="publish">
    /// When true, a message is also published on the "blacklist.reload" key.
    /// </param>
    public async Task Reload(bool publish = true)
    {
        var totalShards = _creds.TotalShards;
        await using var uow = _db.GetDbContext();

        var items = uow.GetTable<BlacklistEntry>()
                       .Where(x => x.Type != BlacklistType.Server
                                   || (x.Type == BlacklistType.Server
                                       && Linq2DbExpressions.GuildOnShard(x.ItemId, totalShards, _client.ShardId)))
                       .ToArray();

        if (publish)
        {
            await _pubSub.Pub(_blPubKey, true);
        }

        await OnReload(items);
    }

    /// <summary>
    /// Inserts a blacklist entry and reloads the blacklist. Ids contained in the
    /// configured owner ids are ignored. Blacklisting a user also sets that user's
    /// <c>CurrencyAmount</c> to 0.
    /// </summary>
    /// <param name="type">Kind of item being blacklisted.</param>
    /// <param name="id">Id of the guild, channel or user.</param>
    public async Task Blacklist(BlacklistType type, ulong id)
    {
        if (_creds.OwnerIds.Contains(id))
        {
            return;
        }

        await using var uow = _db.GetDbContext();

        await uow
              .GetTable<BlacklistEntry>()
              .InsertAsync(() => new()
              {
                  ItemId = id,
                  Type = type,
              });

        if (type == BlacklistType.User)
        {
            await uow.GetTable<DiscordUser>()
                     .Where(x => x.UserId == id)
                     .UpdateAsync(_ => new()
                     {
                         CurrencyAmount = 0
                     });
        }

        await Reload();
    }

    /// <summary>
    /// Deletes the blacklist entries matching the given id and type, then reloads the blacklist.
    /// </summary>
    /// <param name="type">Kind of item being removed.</param>
    /// <param name="id">Id of the guild, channel or user.</param>
    public async Task UnBlacklist(BlacklistType type, ulong id)
    {
        await using var uow = _db.GetDbContext();

        await uow.GetTable<BlacklistEntry>()
                 .Where(bi => bi.ItemId == id && bi.Type == type)
                 .DeleteAsync();

        await Reload();
    }

    /// <summary>
    /// Bulk-inserts user blacklist entries for the given ids, sets the
    /// <c>CurrencyAmount</c> of those users to 0 and reloads the blacklist.
    /// </summary>
    /// <param name="toBlacklist">User ids to blacklist.</param>
    public async Task BlacklistUsers(IReadOnlyCollection<ulong> toBlacklist)
    {
        await using var uow = _db.GetDbContext();

        var bc = uow.GetTable<BlacklistEntry>();
        await bc.BulkCopyAsync(toBlacklist.Select(uid => new BlacklistEntry
        {
            ItemId = uid,
            Type = BlacklistType.User
        }));

        var blList = toBlacklist.ToList();
        await uow.GetTable<DiscordUser>()
                 .Where(x => blList.Contains(x.UserId))
                 .UpdateAsync(_ => new()
                 {
                     CurrencyAmount = 0
                 });

        await Reload();
    }
}