Git Product home page Git Product logo

Comments (5)

wuyu8512 avatar wuyu8512 commented on May 1, 2024 5

I have a new version, and it works well:

using AspNetCoreRateLimit;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Options;
using System;
using System.Threading.Tasks;

namespace Api.Internal
{
    /// <summary>
    /// SignalR hub filter that applies AspNetCoreRateLimit IP rate-limiting rules to every
    /// hub method invocation.
    /// </summary>
    /// <remarks>
    /// Each invocation is identified by the caller's remote IP, the hub method name (used as
    /// the rule path), the pseudo-verb <c>"ws"</c> and the connection's user identifier, so
    /// configured rules and whitelists must be written against those values.
    /// </remarks>
    public class SignalRLimitFilter : IHubFilter
    {
        private readonly IRateLimitProcessor _processor;

        /// <summary>
        /// Creates the filter and the <c>IpRateLimitProcessor</c> it uses to match rules and
        /// count calls.
        /// </summary>
        /// <param name="options">IP rate-limit options; must not be null.</param>
        /// <param name="processing">Strategy used to increment counters.</param>
        /// <param name="counterStore">Store holding the per-client counters.</param>
        /// <param name="rateLimitConfiguration">Rate-limit configuration passed to the processor.</param>
        /// <param name="policyStore">Store holding the IP policies.</param>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> or its value is null.</exception>
        public SignalRLimitFilter(
            IOptions<IpRateLimitOptions> options, IProcessingStrategy processing, IRateLimitCounterStore counterStore,
            IRateLimitConfiguration rateLimitConfiguration, IIpPolicyStore policyStore)
        {
            // Fail fast at construction rather than passing null options into the processor.
            var rateLimitOptions = options?.Value ?? throw new ArgumentNullException(nameof(options));
            _processor = new IpRateLimitProcessor(rateLimitOptions, counterStore, policyStore, rateLimitConfiguration, processing);
        }

        /// <summary>
        /// Counts the invocation against every matching rule and rejects it with a
        /// <see cref="HubException"/> once any rule's limit is exceeded.
        /// </summary>
        /// <param name="invocationContext">The hub invocation being filtered.</param>
        /// <param name="next">The next filter, or the hub method itself.</param>
        /// <returns>The result of <paramref name="next"/> when the call is allowed.</returns>
        /// <exception cref="HubException">
        /// A rule's limit is exceeded, or the caller's IP address cannot be determined.
        /// </exception>
        public async ValueTask<object> InvokeMethodAsync(
            HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
        {
            var client = new ClientRequestIdentity
            {
                ClientIp = GetClientIp(invocationContext.Context),
                Path = invocationContext.HubMethodName,
                HttpVerb = "ws",
                ClientId = invocationContext.Context.UserIdentifier
            };

            // Honour the configured IP/client/endpoint whitelists, as IpRateLimitMiddleware does.
            if (_processor.IsWhitelisted(client))
            {
                return await next(invocationContext);
            }

            foreach (var rule in await _processor.GetMatchingRulesAsync(client))
            {
                // Increments the counter for this client/rule pair and returns its current state.
                var counter = await _processor.ProcessRequestAsync(client, rule);

                // A limit of zero or less blocks every call, because Count is at least 1 here.
                if (counter.Count > rule.Limit)
                {
                    var retry = counter.Timestamp.RetryAfterFrom(rule);
                    throw new HubException($"call limit {retry}");
                }
            }

            return await next(invocationContext);
        }

        /// <summary>
        /// Returns the remote IP of the connection's HTTP context as a string.
        /// </summary>
        /// <exception cref="HubException">No HTTP context or remote IP address is available.</exception>
        private static string GetClientIp(HubCallerContext context)
        {
            var address = context.GetHttpContext()?.Connection.RemoteIpAddress;
            if (address == null)
            {
                // Reject rather than bypass: an unidentifiable caller must not escape rate limiting.
                throw new HubException("call limit: unable to determine client IP address");
            }

            return address.ToString();
        }

        /// <summary>Passes connection events through unchanged (optional filter method).</summary>
        public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
        {
            return next(context);
        }

        /// <summary>Passes disconnection events through unchanged (optional filter method).</summary>
        public Task OnDisconnectedAsync(
            HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
        {
            return next(context, exception);
        }
    }
}

And in Startup.cs:

// Load the rate-limit options and IP policies from configuration.
services.Configure<IpRateLimitOptions>(Configuration.GetSection("IpRateLimiting"));
services.Configure<IpRateLimitPolicies>(Configuration.GetSection("IpRateLimitPolicies"));
// Store policies and counters in the DistributedCache.
// NOTE(review): these stores presumably need an IDistributedCache registration
// (e.g. a distributed memory cache or Redis) elsewhere — confirm.
services.AddSingleton<IIpPolicyStore, DistributedCacheIpPolicyStore>();
services.AddSingleton<IRateLimitCounterStore, DistributedCacheRateLimitCounterStore>();
// New in AspNetCoreRateLimit 4+.
// NOTE(review): this extension may itself register the distributed stores, which would make
// the two registrations above redundant; also check whether AsyncKeyLockProcessingStrategy
// is appropriate when several instances share one distributed cache — verify for the version in use.
services.AddDistributedRateLimiting<AsyncKeyLockProcessingStrategy>();
// ASP.NET Core 3.0+ requires the following two registrations.
services.AddSingleton<IHttpContextAccessor, HttpContextAccessor>();
services.AddSingleton<IRateLimitConfiguration, RateLimitConfiguration>();
// SignalR: run every hub method invocation through SignalRLimitFilter and enable the MessagePack protocol.
services.AddSignalR(options => options.AddFilter<SignalRLimitFilter>()).AddMessagePackProtocol();

from aspnetcoreratelimit.

cristipufu avatar cristipufu commented on May 1, 2024 2

so this middleware will not work with SignalR? I.e. SignalR will be broken?

SignalR will work; you just can't throttle the ws:// requests.

from aspnetcoreratelimit.

replaysMike avatar replaysMike commented on May 1, 2024

I failed to RTFM and didn't have the configuration wired correctly. Now that it seems to be configured correctly, SignalR works fine, but none of my tests trigger throttling. I will keep digging; maybe this is an error on my part.

from aspnetcoreratelimit.

replaysMike avatar replaysMike commented on May 1, 2024

So the ASP.NET Core middleware pipeline doesn't see individual WebSocket messages. It seems that only WebSocket connection requests (plain ASP.NET Core WebSockets, not SignalR) pass through it, and even then possibly only during the connection phase.

I was able to integrate AspNetCoreRateLimit manually (much less ideal, but it works). I duplicated the logic of IpRateLimitMiddleware in the class below and call its throttle method from each endpoint in the hub. SignalR doesn't seem to offer any way to intercept messages through middleware or delegates, which is why this approach is much less ideal.

/// <summary>
/// Selects how a throttled SignalR call is reported to the caller.
/// </summary>
public enum ThrottleOption {
  /// <summary>
  /// Default: report throttling by returning a bad-request response instead of throwing.
  /// </summary>
  None = 0,

  /// <summary>
  /// Report throttling by throwing an exception.
  /// </summary>
  ThrowOnError = 1
}

/// <summary>
/// SignalR rate throttling provider.
/// </summary>
/// <remarks>
/// Reproduces the rule-processing loop of AspNetCoreRateLimit's IpRateLimitMiddleware so that it
/// can be called explicitly from hub methods, which ordinary HTTP middleware does not see.
/// NOTE(review): the identity is built from the <see cref="HttpContext"/> returned by
/// <c>GetHttpContext()</c> (the request associated with the connection), so all hub methods on a
/// connection probably share one path/verb counter — confirm this is intended.
/// </remarks>
public class SignalRRateThrottleProvider {
  private readonly IpRateLimitOptions _options;
  private readonly IIpAddressParser _ipParser;
  private readonly IpRateLimitProcessor _processor;
  private readonly ILogger<SignalRRateThrottleProvider> _logger;

  /// <summary>
  /// Create a SignalR rate throttling provider.
  /// </summary>
  /// <param name="logger">Logger used to record blocked requests.</param>
  /// <param name="rateLimit">Processor that matches rules, checks whitelists and increments counters.</param>
  /// <param name="options">Rate-limit options; <c>options.Value</c> is captured here.</param>
  /// <param name="ipParser">Parser used to extract the caller's IP from the HTTP context.</param>
  public SignalRRateThrottleProvider (ILogger<SignalRRateThrottleProvider> logger, IpRateLimitProcessor rateLimit, IOptions<IpRateLimitOptions> options, IIpAddressParser ipParser) {
    _logger = logger;
    _options = options?.Value;
    _ipParser = ipParser;
    _processor = rateLimit;
  }

  /// <summary>
  /// Perform rate throttling on a SignalR request, returning a response when throttled.
  /// </summary>
  /// <param name="callerContext">The hub caller context; may be null.</param>
  /// <returns>A bad-request response when the call is throttled; otherwise null.</returns>
  public async Task<IActionResult> ThrottleAsync (HubCallerContext callerContext) {
    return await ThrottleAsync (callerContext, ThrottleOption.None);
  }

  /// <summary>
  /// Perform rate throttling on a SignalR request.
  /// </summary>
  /// <param name="callerContext">The hub caller context; may be null.</param>
  /// <param name="options">The throttle options to use.</param>
  /// <returns>
  /// A bad-request response when the call is throttled and <paramref name="options"/> does not
  /// request throwing; otherwise null (the call is allowed).
  /// </returns>
  /// <exception cref="ThrottledException">The call is throttled and ThrowOnError is set.</exception>
  /// <exception cref="InvalidOperationException">The caller's IP address cannot be determined.</exception>
  public async Task<IActionResult> ThrottleAsync (HubCallerContext callerContext, ThrottleOption options) {
    // GetHttpContext() can return null when the connection has no associated HTTP context;
    // with no context there is nothing to identify the caller by, so allow through.
    var httpContext = callerContext?.GetHttpContext ();
    if (httpContext == null) {
      return null;
    }

    var identity = SetIdentity (httpContext);
    if (_processor.IsWhitelisted (identity)) {
      // allow through
      return null;
    }

    var rules = _processor.GetMatchingRules (identity);
    return await ProcessRulesAsync (rules, identity, httpContext, options);
  }

  /// <summary>
  /// Increments the counter for each matching rule and stops at the first rule whose limit is exceeded.
  /// </summary>
  /// <returns>The quota-exceeded response for the first violated rule, or null when all rules pass.</returns>
  private async Task<IActionResult> ProcessRulesAsync (List<RateLimitRule> rules, ClientRequestIdentity identity, HttpContext httpContext, ThrottleOption options) {
    foreach (var rule in rules) {
      // increment counter
      var counter = _processor.ProcessRequest (identity, rule);

      // a limit of zero or less blocks every request outright
      if (rule.Limit <= 0) {
        LogBlockedRequest (httpContext, identity, counter, rule);

        // Int32 max is used to represent "retry never"
        return await ReturnQuotaExceededResponseAsync (httpContext, rule, Int32.MaxValue.ToString (System.Globalization.CultureInfo.InvariantCulture), options);
      }

      // skip counters whose period has already elapsed
      if (counter.Timestamp + rule.PeriodTimespan.Value < DateTime.UtcNow) {
        continue;
      }

      // check if limit is reached
      if (counter.TotalRequests > rule.Limit) {
        var retryAfter = _processor.RetryAfterFrom (counter.Timestamp, rule);
        LogBlockedRequest (httpContext, identity, counter, rule);
        return await ReturnQuotaExceededResponseAsync (httpContext, rule, retryAfter, options);
      }
    }

    return null;
  }

  /// <summary>
  /// Builds the rate-limit identity (client id header, caller IP, lower-cased path and verb) for a request.
  /// </summary>
  /// <exception cref="InvalidOperationException">The caller's IP address cannot be determined.</exception>
  private ClientRequestIdentity SetIdentity (HttpContext httpContext) {
    // Header lookup on the request header dictionary is already case-insensitive; a missing
    // header name or an empty header value falls back to "anon" instead of throwing.
    var clientId = "anon";
    var clientIdHeader = _options.ClientIdHeader;
    if (!string.IsNullOrEmpty (clientIdHeader) &&
      httpContext.Request.Headers.TryGetValue (clientIdHeader, out var clientIdValues) &&
      clientIdValues.Count > 0) {
      clientId = clientIdValues[0];
    }

    string clientIp;
    try {
      clientIp = _ipParser.GetClientIp (httpContext)?.ToString ();
    } catch (Exception ex) {
      throw new InvalidOperationException ("IpRateLimitMiddleware can't parse caller IP", ex);
    }

    // Raised outside the try so it is not wrapped a second time by the catch above.
    if (clientIp == null) {
      throw new InvalidOperationException ("IpRateLimitMiddleware can't parse caller IP");
    }

    return new ClientRequestIdentity {
      ClientIp = clientIp,
      Path = httpContext.Request.Path.ToString ().ToLowerInvariant (),
      HttpVerb = httpContext.Request.Method.ToLowerInvariant (),
      ClientId = clientId
    };
  }

  /// <summary>
  /// Builds the quota-exceeded result, throwing instead when ThrowOnError is requested.
  /// </summary>
  /// <exception cref="ThrottledException">ThrowOnError is set in <paramref name="options"/>.</exception>
  private Task<IActionResult> ReturnQuotaExceededResponseAsync (HttpContext httpContext, RateLimitRule rule, string retryAfter, ThrottleOption options) {
    var message = string.IsNullOrEmpty (_options.QuotaExceededMessage) ? $"API calls quota exceeded! Maximum admitted {rule.Limit} per {rule.Period}." : _options.QuotaExceededMessage;
    var ex = new ThrottledException (message);
    if (options.HasFlag (ThrottleOption.ThrowOnError)) {
      throw ex;
    }

    return Task.FromResult<IActionResult> (new ApiBadRequestResponse (ex));
  }

  /// <summary>
  /// Logs a blocked request at Information level using a structured message template.
  /// </summary>
  private void LogBlockedRequest (HttpContext httpContext, ClientRequestIdentity identity, RateLimitCounter counter, RateLimitRule rule) {
    _logger.LogInformation (
      "Request {HttpVerb}:{Path} from IP {ClientIp} has been blocked, quota {Limit}/{Period} exceeded by {TotalRequests}. Blocked by rule {Endpoint}, TraceIdentifier {TraceIdentifier}.",
      identity.HttpVerb, identity.Path, identity.ClientIp, rule.Limit, rule.Period, counter.TotalRequests, rule.Endpoint, httpContext.TraceIdentifier);
  }
}

It is used as follows in the SignalR hub (MyHub.cs):

/// <summary>
/// Example hub method: rate-limits the call first, then delegates to the service.
/// </summary>
/// <param name="request">The request payload forwarded to the service.</param>
/// <returns>The throttling response if the call was rejected; otherwise the service's result.</returns>
public async Task<IActionResult> TestEndpointAsync(MyRequest request)
{
    // A non-null result from the throttle provider means the call was rejected.
    IActionResult rejection = await _rateThrottleProvider.ThrottleAsync(Context);

    // The right-hand side is only evaluated (and awaited) when the call was not rejected.
    return rejection ?? await MyService.ProcessRequest(Context, request);
}

from aspnetcoreratelimit.

dzmitry-lahoda avatar dzmitry-lahoda commented on May 1, 2024

so this middleware will not work with SignalR? I.e. SignalR will be broken?

from aspnetcoreratelimit.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.