Files
qwsdcvghyu89 580ceb8c3c Add FollowRedirects option to downloader
Introduces a FollowRedirects property to UnitDownloaderOptions and its builder, allowing control over HTTP redirect behavior. Updates UnitDownloader to use this option, following redirects when enabled and reporting progress accordingly.
2025-11-16 01:11:22 +11:00

201 lines
8.6 KiB
C#

using System.Diagnostics.CodeAnalysis;
using System.Text;
using Beam.Abstractions;
using Beam.Models;
using HtmlAgilityPack;
using File = System.IO.File;
namespace Beam.Downloaders {
/// <summary>
/// A download managing class that manages a single download with failure detection and exponential-backoff retries.
/// This class is safe to instantiate per request.
/// </summary>
/// <typeparam name="OutType">The type produced by <see cref="Transformer"/> from a downloaded <see cref="ByteDocument"/>.</typeparam>
/// <param name="options">
/// Configuration for the download: HTTP client, transformer, failure/skip predicates, buffer size,
/// optional download folder and redirect mode.
/// </param>
public class UnitDownloader<OutType>(UnitDownloaderOptions<OutType> options) : IUnitDownloader<OutType> {
    /// <summary>The options this downloader was constructed with.</summary>
    public UnitDownloaderOptions<OutType> Options { get; } = options;

    /// <summary>The HTTP client used for every request (taken from <see cref="Options"/>).</summary>
    public HttpClient Client => Options.Client;

    /// <summary>Converts a downloaded document into <typeparamref name="OutType"/>.</summary>
    public virtual AsyncTransformer<ByteDocument, OutType> Transformer => Options.AsyncTransformer;

    /// <summary>Predicates that flag a downloaded document as a failed download; <c>null</c> when none are configured.</summary>
    public virtual AsyncDownloadFailurePredicate<ByteDocument>?[]? FailurePredicates =>
        Options?.FailurePredicateOptions?.AsyncDownloadFailurePredicates;

    /// <summary>Number of links consumed per download; <see cref="TryDownload"/> only uses the first link.</summary>
    public int LinksPerDownload { get; } = 1;

    /// <summary>
    /// Downloads <paramref name="url"/> into <paramref name="destinationStream"/>, reporting progress along the way.
    /// </summary>
    /// <param name="url">The URL to fetch.</param>
    /// <param name="bufferSize">Size in bytes of the copy buffer used by the chunked (non-redirect) path.</param>
    /// <param name="destinationStream">Stream the response body is written to.</param>
    /// <param name="progress">Optional progress sink.</param>
    /// <param name="ct">Cancellation token flowed to every HTTP and stream operation.</param>
    /// <exception cref="HttpRequestException">The request failed or returned a non-success status code.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="bufferSize"/> is not positive (chunked path only).</exception>
    protected virtual async Task DownloadToStream(string url, int bufferSize, Stream destinationStream, IProgress<IDownloadReport> progress,
        CancellationToken ct) {
        // NOTE(review): HttpClient's redirect handling is decided by its message handler (AllowAutoRedirect),
        // not by GetAsync vs GetStreamAsync. As written, FollowRedirects only selects between a single bulk copy
        // and the chunked, progress-reporting loop below — confirm the client in Options is configured to match.
        if (Options.FollowRedirects) {
            // ResponseHeadersRead streams the body into destinationStream instead of buffering it all first.
            using var response = await Client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, ct);
            // Match the chunked path: GetStreamAsync throws on non-success status codes.
            response.EnsureSuccessStatusCode();
            await response.Content.CopyToAsync(destinationStream, ct);
            progress?.Report(new DownloadReport() {
                BytesDownloaded = destinationStream.Length,
                BytesRemaining = 0
            });
            return;
        }

        // A zero-length buffer would make the first read return 0 and silently produce an empty download.
        ArgumentOutOfRangeException.ThrowIfNegativeOrZero(bufferSize);

        await using var source = await Client.GetStreamAsync(url, ct);

        // Network streams usually cannot report Length; probe once instead of throwing on every chunk.
        long? totalLength;
        try {
            totalLength = source.Length;
        }
        catch (NotSupportedException) {
            totalLength = null;
        }

        var buffer = new byte[bufferSize];
        long downloaded = 0;
        int read;
        while ((read = await source.ReadAsync(buffer.AsMemory(), ct)) > 0) {
            downloaded += read;
            await destinationStream.WriteAsync(buffer.AsMemory(0, read), ct);
            progress?.Report(new DownloadReport() {
                // Size of this chunk; BytesRemaining is relative to the running total.
                BytesDownloaded = read,
                BytesRemaining = totalLength - downloaded
            });
            ct.ThrowIfCancellationRequested();
        }
    }

    /// <summary>
    /// Downloads <paramref name="url"/> into the file at <paramref name="path"/>, replacing any existing file.
    /// </summary>
    /// <exception cref="InvalidOperationException">The directory component of <paramref name="path"/> does not exist.</exception>
    protected virtual async Task DownloadToFile(string url, int bufferSize, string path,
        IProgress<IDownloadReport> progress, CancellationToken ct) {
        var directory = Path.GetDirectoryName(path);
        // A bare file name has an empty directory component, which means the current directory.
        if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
            throw new InvalidOperationException(
                string.Format(Exceptions.Exceptions.unit_download_directory_nonexistant, path));
        // File.Create truncates an existing file; File.OpenWrite would keep stale trailing bytes
        // whenever the new download is shorter than the previous file.
        await using var file = File.Create(path);
        await DownloadToStream(url, bufferSize, file, progress, ct);
    }

    /// <summary>
    /// Downloads <paramref name="url"/> into memory and wraps the bytes in a <see cref="ByteDocument"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The memory stream's buffer could not be exposed.</exception>
    protected virtual async Task<ByteDocument> DownloadToMemory(string url, int bufferSize,
        IProgress<IDownloadReport> progress, CancellationToken ct) {
        await using var ms = new MemoryStream();
        await DownloadToStream(url, bufferSize, ms, progress, ct);
        // TryGetBuffer exposes the written bytes without copying; not expected to fail for a parameterless MemoryStream.
        if (!ms.TryGetBuffer(out var bytes))
            throw new InvalidOperationException(Exceptions.Exceptions.unit_download_invalid_memory_stream);
        return new ByteDocument(url, bytes);
    }

    /// <summary>
    /// Runs <see cref="FailurePredicates"/> against <paramref name="doc"/>, sequentially or in parallel
    /// depending on <c>FailurePredicateOptions.ProcessInParallel</c>.
    /// </summary>
    /// <returns><c>true</c> if any non-null predicate reports a failure.</returns>
    protected virtual async Task<bool> IsFailure(ByteDocument doc, CancellationToken ct) {
        var predicates = FailurePredicates;
        if (predicates is null)
            return false;

        if (!(Options?.FailurePredicateOptions?.ProcessInParallel ?? false)) {
            foreach (var pred in predicates) {
                if (pred is not null && await pred(doc))
                    return true;
            }
            return false;
        }

        var failed = false;
        await Parallel.ForEachAsync(predicates, new ParallelOptions() {
            MaxDegreeOfParallelism = Options?.FailurePredicateOptions?.ParallelThreads ?? 4,
            CancellationToken = ct
        },
        async (predicate, token) => {
            // Once any predicate has failed the outcome is decided, so remaining ones are skipped.
            if (token.IsCancellationRequested || failed || predicate is null)
                return;
            if (await predicate(doc))
                failed = true; // only ever set to true, so concurrent writers cannot disagree
        });
        return failed;
    }

    /// <summary>
    /// Downloads the document pointed to by <paramref name="link"/>: into <c>Options.DownloadFolder</c> when set,
    /// otherwise into memory.
    /// </summary>
    /// <returns>
    /// In memory mode, the downloaded bytes. In folder mode, a document whose bytes are the UTF-8 encoded path
    /// of the saved file (not its contents).
    /// </returns>
    protected virtual async Task<ByteDocument> _Download(string link, IProgress<IDownloadReport> progress, CancellationToken ct) {
        if (Options.DownloadFolder is null)
            return await DownloadToMemory(link, Options.BufferSize, progress, ct);

        var path = Path.Combine(Options.DownloadFolder, Options.GetFileNameForDownload(link, []));
        await DownloadToFile(link, Options.BufferSize, path, progress, ct);
        return new ByteDocument(link, Encoding.UTF8.GetBytes(path));
    }

    /// <summary>
    /// Checks <paramref name="download"/> against the failure predicates and, if it passes, runs <see cref="Transformer"/>.
    /// </summary>
    /// <returns>
    /// <c>(true, value)</c> when the transform ran; <c>(false, default)</c> when a predicate flagged a failure or a
    /// predicate/transformer threw.
    /// </returns>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    protected virtual async Task<(bool, OutType?)> Transform(ByteDocument download, CancellationToken ct) {
        try {
            if (FailurePredicates is null || !(await IsFailure(download, ct)))
                return (true, await Transformer(download));
            return (false, default);
        }
        catch (OperationCanceledException) when (ct.IsCancellationRequested) {
            // Cancellation must reach the caller instead of being treated as a failed attempt and retried.
            throw;
        }
        catch (Exception) {
            // Any other predicate/transformer error counts as a failed attempt.
            return (false, default);
        }
    }

    /// <summary>
    /// Downloads and transforms <c>link[0]</c>. It retries with exponential backoff (2, 4, 8, … seconds between
    /// attempts) when a failure predicate flags the download, the transformer throws, or it yields <c>null</c>.
    /// </summary>
    /// <param name="link">Candidate links; only the first one is used.</param>
    /// <param name="ct">Cancels the download, the transformation and the backoff waits.</param>
    /// <param name="maximumRetryCount">Maximum number of attempts.</param>
    /// <param name="downProgress">Optional byte-level progress sink.</param>
    /// <param name="tryProgress">Optional sink notified after every failed attempt.</param>
    /// <returns>
    /// <c>(true, value)</c> on success or when a skip predicate supplies a value; otherwise <c>(false, last transform result)</c>.
    /// Exceptions thrown while downloading (e.g. HTTP errors) are not retried and propagate to the caller.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="link"/> is <c>null</c>.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="ct"/> was cancelled.</exception>
    public async Task<(bool, OutType?)> TryDownload(IOrdered<string>[] link, CancellationToken ct, int maximumRetryCount = 7, IProgress<IDownloadReport>? downProgress = null, IProgress<IRetryReport>? tryProgress = null) {
        ArgumentNullException.ThrowIfNull(link);
        if (link.Length == 0)
            return (false, default);

        var url = link[0].Data;
        downProgress ??= new Progress<IDownloadReport>();
        if (ShouldSkip(url, out var defaultType))
            return (true, defaultType);

        OutType? ot = default;
        for (int attempt = 1; attempt <= maximumRetryCount; attempt++) {
            ct.ThrowIfCancellationRequested();
            var document = await _Download(url, downProgress, ct);
            (var success, ot) = await Transform(document, ct);
            if (success && ot != null)
                return (true, ot);

            tryProgress?.Report(new RetryReport(attempt, url));
            // No point waiting after the final attempt; the wait itself honours cancellation.
            if (attempt < maximumRetryCount)
                await Task.Delay(BackoffDelay(attempt), ct);
        }
        return (false, ot);
    }

    /// <summary>
    /// Backoff before the next attempt: 2^<paramref name="attempt"/> seconds, clamped to <see cref="int.MaxValue"/>
    /// milliseconds so large retry counts cannot overflow the delay.
    /// </summary>
    private static TimeSpan BackoffDelay(int attempt) =>
        TimeSpan.FromMilliseconds(Math.Min(Math.Pow(2, attempt) * 1000, int.MaxValue));

    /// <summary>
    /// Evaluates the configured skip predicates for <paramref name="link"/>.
    /// </summary>
    /// <param name="link">The link about to be downloaded.</param>
    /// <param name="outType">The value supplied by the skip predicate that matched.</param>
    /// <returns><c>true</c> if a skip predicate matched and the download should be skipped.</returns>
    private bool ShouldSkip(string link, [NotNullWhen(true)] out OutType? outType) {
        outType = default;
        var skipOptions = Options.SkipPredicateOptions;
        if (skipOptions?.SkipPredicates is null)
            return false;

        if (!skipOptions.ProcessInParallel) {
            foreach (var pred in skipOptions.SkipPredicates) {
                if (pred is not null && pred(link, out outType))
                    return true;
            }
            return false;
        }

        var gate = new object();
        var shouldSkip = false;
        OutType? skipValue = default;
        Parallel.ForEach(skipOptions.SkipPredicates, new ParallelOptions() {
            // NOTE(review): this reuses FailurePredicateOptions.ParallelThreads for the skip predicates;
            // confirm whether SkipPredicateOptions should carry its own limit.
            MaxDegreeOfParallelism = Options?.FailurePredicateOptions?.ParallelThreads ?? 4
        },
        (predicate, loopState) => {
            if (loopState.ShouldExitCurrentIteration || predicate is null)
                return;
            if (predicate(link, out var candidate)) {
                // A lock (instead of Interlocked.CompareExchange) works for any OutType, including
                // non-primitive structs, and records the flag and its value together.
                lock (gate) {
                    if (!shouldSkip) {
                        shouldSkip = true;
                        skipValue = candidate;
                    }
                }
                loopState.Break();
            }
        });
        outType = skipValue;
        return shouldSkip;
    }
}
}