More work on figuring out the DB stuff and converting EF code to linqtodb where it's easy

This commit is contained in:
Kwoth
2023-07-09 06:50:04 +00:00
parent be1d14d095
commit 842a8a2f71
17 changed files with 210 additions and 180 deletions

View File

@@ -6,73 +6,13 @@ using NadekoBot.Services.Database;
namespace NadekoBot.Services;
public class DbService
public abstract class DbService
{
private readonly IBotCredsProvider _creds;
/// <summary>
/// Call this to apply all migrations
/// </summary>
public abstract Task SetupAsync();
// these are props because creds can change at runtime
private string DbType => _creds.GetCreds().Db.Type.ToLowerInvariant().Trim();
private string ConnString => _creds.GetCreds().Db.ConnectionString;
public DbService(IBotCredsProvider creds)
{
LinqToDBForEFTools.Initialize();
Configuration.Linq.DisableQueryCache = true;
_creds = creds;
}
public async Task SetupAsync()
{
var dbType = DbType;
var connString = ConnString;
await using var context = CreateRawDbContext(dbType, connString);
// make sure sqlite db is in wal journal mode
if (context is SqliteContext)
{
await context.Database.ExecuteSqlRawAsync("PRAGMA journal_mode=WAL");
}
await context.Database.MigrateAsync();
}
private static NadekoContext CreateRawDbContext(string dbType, string connString)
{
switch (dbType)
{
case "postgresql":
case "postgres":
case "pgsql":
return new PostgreSqlContext(connString);
case "mysql":
return new MysqlContext(connString);
case "sqlite":
return new SqliteContext(connString);
default:
throw new NotSupportedException($"The database provide type of '{dbType}' is not supported.");
}
}
private NadekoContext GetDbContextInternal()
{
var dbType = DbType;
var connString = ConnString;
var context = CreateRawDbContext(dbType, connString);
if (context is SqliteContext)
{
var conn = context.Database.GetDbConnection();
conn.Open();
using var com = conn.CreateCommand();
com.CommandText = "PRAGMA synchronous=OFF";
com.ExecuteNonQuery();
}
return context;
}
public NadekoContext GetDbContext()
=> GetDbContextInternal();
public abstract DbContext CreateRawDbContext(string dbType, string connString);
public abstract DbContext GetDbContext();
}

View File

@@ -1,5 +1,7 @@
#nullable disable
using LinqToDB;
using LinqToDB.EntityFrameworkCore;
using NadekoBot.Db.Models;
using NadekoBot.Services.Currency;
namespace NadekoBot.Services;
@@ -52,7 +54,8 @@ public sealed class CurrencyService : ICurrencyService, INService
if (type == CurrencyType.Default)
{
await using var ctx = _db.GetDbContext();
await ctx.DiscordUser
await ctx
.GetTable<DiscordUser>()
.Where(x => userIds.Contains(x.UserId))
.UpdateAsync(du => new()
{

View File

@@ -1,5 +1,6 @@
using LinqToDB;
using LinqToDB.EntityFrameworkCore;
using NadekoBot.Db.Models;
using NadekoBot.Services.Database.Models;
namespace NadekoBot.Services.Currency;
@@ -19,11 +20,11 @@ public class DefaultWallet : IWallet
{
await using var ctx = _db.GetDbContext();
var userId = UserId;
return await ctx.DiscordUser
.ToLinqToDBTable()
.Where(x => x.UserId == userId)
.Select(x => x.CurrencyAmount)
.FirstOrDefaultAsync();
return await ctx
.GetTable<DiscordUser>()
.Where(x => x.UserId == userId)
.Select(x => x.CurrencyAmount)
.FirstOrDefaultAsync();
}
public async Task<bool> Take(long amount, TxData? txData)
@@ -34,12 +35,13 @@ public class DefaultWallet : IWallet
await using var ctx = _db.GetDbContext();
var userId = UserId;
var changed = await ctx.DiscordUser
.Where(x => x.UserId == userId && x.CurrencyAmount >= amount)
.UpdateAsync(x => new()
{
CurrencyAmount = x.CurrencyAmount - amount
});
var changed = await ctx
.GetTable<DiscordUser>()
.Where(x => x.UserId == userId && x.CurrencyAmount >= amount)
.UpdateAsync(x => new()
{
CurrencyAmount = x.CurrencyAmount - amount
});
if (changed == 0)
return false;
@@ -73,22 +75,23 @@ public class DefaultWallet : IWallet
await using (var tran = await ctx.Database.BeginTransactionAsync())
{
var changed = await ctx.DiscordUser
.Where(x => x.UserId == userId)
.UpdateAsync(x => new()
{
CurrencyAmount = x.CurrencyAmount + amount
});
var changed = await ctx
.GetTable<DiscordUser>()
.Where(x => x.UserId == userId)
.UpdateAsync(x => new()
{
CurrencyAmount = x.CurrencyAmount + amount
});
if (changed == 0)
{
await ctx.DiscordUser
.ToLinqToDBTable()
.Value(x => x.UserId, userId)
.Value(x => x.Username, "Unknown")
.Value(x => x.Discriminator, "????")
.Value(x => x.CurrencyAmount, amount)
.InsertAsync();
await ctx
.GetTable<DiscordUser>()
.Value(x => x.UserId, userId)
.Value(x => x.Username, "Unknown")
.Value(x => x.Discriminator, "????")
.Value(x => x.CurrencyAmount, amount)
.InsertAsync();
}
await tran.CommitAsync();

View File

@@ -1,7 +1,9 @@
#nullable disable
using Microsoft.EntityFrameworkCore;
using LinqToDB;
using LinqToDB.EntityFrameworkCore;
using NadekoBot.Common.ModuleBehaviors;
using NadekoBot.Db;
using NadekoBot.Db.Models;
using NadekoBot.Services.Database.Models;
namespace NadekoBot.Modules.Permissions.Services;
@@ -73,38 +75,36 @@ public sealed class BlacklistService : IExecOnMessage
public void Reload(bool publish = true)
{
using var uow = _db.GetDbContext();
var toPublish = uow.Blacklist.AsNoTracking().ToArray();
var toPublish = uow.GetTable<BlacklistEntry>().ToArray();
blacklist = toPublish;
if (publish)
_pubSub.Pub(_blPubKey, toPublish);
}
public void Blacklist(BlacklistType type, ulong id)
public async Task Blacklist(BlacklistType type, ulong id)
{
if (_creds.OwnerIds.Contains(id))
return;
using var uow = _db.GetDbContext();
var item = new BlacklistEntry
{
ItemId = id,
Type = type
};
uow.Blacklist.Add(item);
uow.SaveChanges();
await using var uow = _db.GetDbContext();
await uow
.GetTable<BlacklistEntry>()
.InsertAsync(() => new()
{
ItemId = id,
Type = type,
});
Reload();
}
public void UnBlacklist(BlacklistType type, ulong id)
public async Task UnBlacklist(BlacklistType type, ulong id)
{
using var uow = _db.GetDbContext();
var toRemove = uow.Blacklist.FirstOrDefault(bi => bi.ItemId == id && bi.Type == type);
if (toRemove is not null)
uow.Blacklist.Remove(toRemove);
uow.SaveChanges();
await using var uow = _db.GetDbContext();
await uow.GetTable<BlacklistEntry>()
.Where(bi => bi.ItemId == id && bi.Type == type)
.DeleteAsync();
Reload();
}
@@ -113,16 +113,21 @@ public sealed class BlacklistService : IExecOnMessage
{
using (var uow = _db.GetDbContext())
{
var bc = uow.Blacklist;
//blacklist the users
var bc = uow.Set<BlacklistEntry>();
bc.AddRange(toBlacklist.Select(x => new BlacklistEntry
{
ItemId = x,
Type = BlacklistType.User
}));
//clear their currencies
uow.DiscordUser.RemoveFromMany(toBlacklist);
// todo check if blacklist works and removes currency
uow.GetTable<DiscordUser>()
.UpdateAsync(x => toBlacklist.Contains(x.UserId),
_ => new()
{
CurrencyAmount = 0
});
uow.SaveChanges();
}

View File

@@ -18,10 +18,11 @@ public class DiscordPermOverrideService : INService, IExecPreCommand, IDiscordPe
_db = db;
_services = services;
using var uow = _db.GetDbContext();
_overrides = uow.DiscordPermOverrides.AsNoTracking()
.AsEnumerable()
.ToDictionary(o => (o.GuildId ?? 0, o.Command), o => o)
.ToConcurrent();
_overrides = uow.Set<DiscordPermOverride>()
.AsNoTracking()
.AsEnumerable()
.ToDictionary(o => (o.GuildId ?? 0, o.Command), o => o)
.ToConcurrent();
}
public bool TryGetOverrides(ulong guildId, string commandName, out Nadeko.Bot.Db.GuildPerm? perm)
@@ -52,18 +53,18 @@ public class DiscordPermOverrideService : INService, IExecPreCommand, IDiscordPe
commandName = commandName.ToLowerInvariant();
await using var uow = _db.GetDbContext();
var over = await uow.Set<DiscordPermOverride>()
.AsQueryable()
.FirstOrDefaultAsync(x => x.GuildId == guildId && commandName == x.Command);
.AsQueryable()
.FirstOrDefaultAsync(x => x.GuildId == guildId && commandName == x.Command);
if (over is null)
{
uow.Set<DiscordPermOverride>()
.Add(over = new()
{
Command = commandName,
Perm = (Nadeko.Bot.Db.GuildPerm)perm,
GuildId = guildId
});
.Add(over = new()
{
Command = commandName,
Perm = (Nadeko.Bot.Db.GuildPerm)perm,
GuildId = guildId
});
}
else
over.Perm = (Nadeko.Bot.Db.GuildPerm)perm;
@@ -77,10 +78,10 @@ public class DiscordPermOverrideService : INService, IExecPreCommand, IDiscordPe
{
await using var uow = _db.GetDbContext();
var overrides = await uow.Set<DiscordPermOverride>()
.AsQueryable()
.AsNoTracking()
.Where(x => x.GuildId == guildId)
.ToListAsync();
.AsQueryable()
.AsNoTracking()
.Where(x => x.GuildId == guildId)
.ToListAsync();
uow.RemoveRange(overrides);
await uow.SaveChangesAsync();
@@ -95,9 +96,9 @@ public class DiscordPermOverrideService : INService, IExecPreCommand, IDiscordPe
await using var uow = _db.GetDbContext();
var over = await uow.Set<DiscordPermOverride>()
.AsQueryable()
.AsNoTracking()
.FirstOrDefaultAsync(x => x.GuildId == guildId && x.Command == commandName);
.AsQueryable()
.AsNoTracking()
.FirstOrDefaultAsync(x => x.GuildId == guildId && x.Command == commandName);
if (over is null)
return;
@@ -112,10 +113,10 @@ public class DiscordPermOverrideService : INService, IExecPreCommand, IDiscordPe
{
await using var uow = _db.GetDbContext();
return await uow.Set<DiscordPermOverride>()
.AsQueryable()
.AsNoTracking()
.Where(x => x.GuildId == guildId)
.ToListAsync();
.AsQueryable()
.AsNoTracking()
.Where(x => x.GuildId == guildId)
.ToListAsync();
}
public async Task<bool> ExecPreCommandAsync(ICommandContext context, string moduleName, CommandInfo command)

View File

@@ -42,7 +42,7 @@ public static class DiscordUserExtensions
});
public static Task EnsureUserCreatedAsync(
this NadekoBaseContext ctx,
this DbContext ctx,
ulong userId)
=> ctx.GetTable<DiscordUser>()
.InsertOrUpdateAsync(
@@ -66,7 +66,7 @@ public static class DiscordUserExtensions
//temp is only used in updatecurrencystate, so that i don't overwrite real usernames/discrims with Unknown
public static DiscordUser GetOrCreateUser(
this NadekoBaseContext ctx,
this DbContext ctx,
ulong userId,
string username,
string discrim,

View File

@@ -65,7 +65,7 @@ public static class GuildConfigExtensions
/// <param name="includes">Use to manipulate the set however you want. Pass null to include everything</param>
/// <returns>Config for the guild</returns>
public static GuildConfig GuildConfigsForId(
this NadekoBaseContext ctx,
this DbContext ctx,
ulong guildId,
Func<DbSet<GuildConfig>, IQueryable<GuildConfig>> includes)
{
@@ -193,7 +193,7 @@ public static class GuildConfigExtensions
conf.CleverbotEnabled = cleverbotEnabled;
}
public static XpSettings XpSettingsFor(this NadekoBaseContext ctx, ulong guildId)
public static XpSettings XpSettingsFor(this DbContext ctx, ulong guildId)
{
var gc = ctx.GuildConfigsForId(guildId,
set => set.Include(x => x.XpSettings)

View File

@@ -9,9 +9,9 @@ namespace NadekoBot.Db;
public static class UserXpExtensions
{
public static UserXpStats GetOrCreateUserXpStats(this NadekoContext ctx, ulong guildId, ulong userId)
public static UserXpStats GetOrCreateUserXpStats(this DbContext ctx, ulong guildId, ulong userId)
{
var usr = ctx.UserXpStats.FirstOrDefault(x => x.UserId == userId && x.GuildId == guildId);
var usr = ctx.Set<UserXpStats>().FirstOrDefault(x => x.UserId == userId && x.GuildId == guildId);
if (usr is null)
{

View File

@@ -75,24 +75,24 @@ public static class WaifuExtensions
.Select(x => x.Waifu.UserId)
.FirstOrDefault();
public static async Task<WaifuInfoStats> GetWaifuInfoAsync(this NadekoContext ctx, ulong userId)
public static async Task<WaifuInfoStats> GetWaifuInfoAsync(this DbContext ctx, ulong userId)
{
await ctx.WaifuInfo
await ctx.Set<WaifuInfo>()
.ToLinqToDBTable()
.InsertOrUpdateAsync(() => new()
{
AffinityId = null,
ClaimerId = null,
Price = 1,
WaifuId = ctx.DiscordUser.Where(x => x.UserId == userId).Select(x => x.Id).First()
WaifuId = ctx.Set<DiscordUser>().Where(x => x.UserId == userId).Select(x => x.Id).First()
},
_ => new(),
() => new()
{
WaifuId = ctx.DiscordUser.Where(x => x.UserId == userId).Select(x => x.Id).First()
WaifuId = ctx.Set<DiscordUser>().Where(x => x.UserId == userId).Select(x => x.Id).First()
});
var toReturn = ctx.WaifuInfo.AsQueryable()
var toReturn = ctx.Set<WaifuInfo>().AsQueryable()
.Where(w => w.WaifuId
== ctx.Set<DiscordUser>()
.AsQueryable()
@@ -120,7 +120,7 @@ public static class WaifuExtensions
.Where(u => u.Id == w.AffinityId)
.Select(u => u.Username + "#" + u.Discriminator)
.FirstOrDefault(),
ClaimCount = ctx.WaifuInfo.AsQueryable().Count(x => x.ClaimerId == w.WaifuId),
ClaimCount = ctx.Set<WaifuInfo>().AsQueryable().Count(x => x.ClaimerId == w.WaifuId),
ClaimerName =
ctx.Set<DiscordUser>()
.AsQueryable()

View File

@@ -1,7 +0,0 @@
using Microsoft.EntityFrameworkCore;
namespace NadekoBot.Db;
public abstract class NadekoBaseContext : DbContext
{
}

View File

@@ -12,9 +12,9 @@ public static class PollExtensions
.Include(x => x.Votes)
.ToArray();
public static void RemovePoll(this NadekoContext ctx, int id)
public static void RemovePoll(this DbContext ctx, int id)
{
var p = ctx.Poll.Include(x => x.Answers).Include(x => x.Votes).FirstOrDefault(x => x.Id == id);
var p = ctx.Set<Poll>().Include(x => x.Answers).Include(x => x.Votes).FirstOrDefault(x => x.Id == id);
if (p is null)
return;
@@ -31,6 +31,6 @@ public static class PollExtensions
p.Answers.Clear();
}
ctx.Poll.Remove(p);
ctx.Set<Poll>().Remove(p);
}
}

View File

@@ -52,7 +52,7 @@ public class PollRunner
finally { _locker.Release(); }
await using var uow = _db.GetDbContext();
var trackedPoll = uow.Poll.FirstOrDefault(x => x.Id == Poll.Id);
var trackedPoll = uow.Set<Poll>().FirstOrDefault(x => x.Id == Poll.Id);
trackedPoll.Votes.Add(voteObj);
uow.SaveChanges();
return true;

View File

@@ -24,15 +24,15 @@ public class PollService : IExecOnMessage
_eb = eb;
using var uow = db.GetDbContext();
ActivePolls = uow.Poll.GetAllPolls()
.ToDictionary(x => x.GuildId,
x =>
{
var pr = new PollRunner(db, x);
pr.OnVoted += Pr_OnVoted;
return pr;
})
.ToConcurrent();
ActivePolls = uow.Set<Poll>().GetAllPolls()
.ToDictionary(x => x.GuildId,
x =>
{
var pr = new PollRunner(db, x);
pr.OnVoted += Pr_OnVoted;
return pr;
})
.ToConcurrent();
}
public Poll CreatePoll(ulong guildId, ulong channelId, string input)
@@ -44,10 +44,10 @@ public class PollService : IExecOnMessage
return null;
var col = new IndexedCollection<PollAnswer>(data.Skip(1)
.Select(x => new PollAnswer
{
Text = x
}));
.Select(x => new PollAnswer
{
Text = x
}));
return new()
{
@@ -66,7 +66,7 @@ public class PollService : IExecOnMessage
{
using (var uow = _db.GetDbContext())
{
uow.Poll.Add(p);
uow.Set<Poll>().Add(p);
uow.SaveChanges();
}
@@ -98,8 +98,13 @@ public class PollService : IExecOnMessage
var toDelete = await msg.Channel.SendConfirmAsync(_eb,
_strs.GetText(strs.poll_voted(Format.Bold(usr.ToString())), usr.GuildId));
toDelete.DeleteAfter(5);
try { await msg.DeleteAsync(); }
catch { }
try
{
await msg.DeleteAsync();
}
catch
{
}
}
public async Task<bool> ExecOnMessageAsync(IGuild guild, IUserMessage msg)

View File

@@ -124,9 +124,9 @@ public partial class Permissions
private async Task Blacklist(AddRemove action, ulong id, BlacklistType type)
{
if (action == AddRemove.Add)
_service.Blacklist(type, id);
await _service.Blacklist(type, id);
else
_service.UnBlacklist(type, id);
await _service.UnBlacklist(type, id);
if (action == AddRemove.Add)
{

View File

@@ -24,4 +24,8 @@
<ProjectReference Include="..\Nadeko.Bot.Generators.Cloneable\Nadeko.Bot.Generators.Cloneable.csproj" OutputItemType="Analyzer" />
</ItemGroup>
<ItemGroup>
<Folder Include="_Common\Db\" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,76 @@
using LinqToDB.Common;
using LinqToDB.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore;
namespace NadekoBot.Services.Database;
public sealed class NadekoDbService : DbService
{
private readonly IBotCredsProvider _creds;
// these are props because creds can change at runtime
private string DbType => _creds.GetCreds().Db.Type.ToLowerInvariant().Trim();
private string ConnString => _creds.GetCreds().Db.ConnectionString;
public NadekoDbService(IBotCredsProvider creds)
{
LinqToDBForEFTools.Initialize();
Configuration.Linq.DisableQueryCache = true;
_creds = creds;
}
public override async Task SetupAsync()
{
var dbType = DbType;
var connString = ConnString;
await using var context = CreateRawDbContext(dbType, connString);
// make sure sqlite db is in wal journal mode
if (context is SqliteContext)
{
await context.Database.ExecuteSqlRawAsync("PRAGMA journal_mode=WAL");
}
await context.Database.MigrateAsync();
}
public override NadekoContext CreateRawDbContext(string dbType, string connString)
{
switch (dbType)
{
case "postgresql":
case "postgres":
case "pgsql":
return new PostgreSqlContext(connString);
case "mysql":
return new MysqlContext(connString);
case "sqlite":
return new SqliteContext(connString);
default:
throw new NotSupportedException($"The database provide type of '{dbType}' is not supported.");
}
}
private NadekoContext GetDbContextInternal()
{
var dbType = DbType;
var connString = ConnString;
var context = CreateRawDbContext(dbType, connString);
if (context is SqliteContext)
{
var conn = context.Database.GetDbConnection();
conn.Open();
using var com = conn.CreateCommand();
com.CommandText = "PRAGMA synchronous=OFF";
com.ExecuteNonQuery();
}
return context;
}
public override NadekoContext GetDbContext()
=> GetDbContextInternal();
}