This commit is contained in:
@@ -2,32 +2,27 @@ package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"example.com/m/internal/config"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var Redis *redis.Client
|
||||
|
||||
func InitRedis(cfg *config.Config) (*redis.Client, error) {
|
||||
opt, err := redis.ParseURL(cfg.RedisURL)
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("impossible to parse redis url: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to parse redis url: %w", err)
|
||||
}
|
||||
|
||||
Redis = redis.NewClient(opt)
|
||||
client := redis.NewClient(opt)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := Redis.Ping(ctx).Err(); err != nil {
|
||||
log.Fatalf("impossible to conenct to redis db: %v", err)
|
||||
return nil, err
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to redis: %w", err)
|
||||
}
|
||||
|
||||
return Redis, nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package handlers
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
@@ -47,7 +48,7 @@ func (h *LinksHandler) CreateLink(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("http://%v:%s/r/%v", h.host, h.port, id)
|
||||
address := fmt.Sprintf("https://%s/r/%s", h.host, id)
|
||||
|
||||
response := CreateLinkResponse{
|
||||
Status: "success",
|
||||
@@ -89,5 +90,44 @@ func NormalizeURL(raw string) (string, error) {
|
||||
return "", errors.New("invalid host in URL")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if isPrivateHost(host) {
|
||||
return "", errors.New("URLs pointing to private/internal addresses are not allowed")
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func isPrivateHost(host string) bool {
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
return false
|
||||
}
|
||||
ip = net.ParseIP(addrs[0])
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
privateRanges := []string{
|
||||
"127.0.0.0/8",
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"169.254.0.0/16",
|
||||
"0.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
}
|
||||
|
||||
for _, cidr := range privateRanges {
|
||||
_, network, _ := net.ParseCIDR(cidr)
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"example.com/m/internal/handlers"
|
||||
"example.com/m/internal/middleware"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -20,7 +20,12 @@ type Server struct {
|
||||
func NewServer(address string, lh *handlers.LinksHandler) *Server {
|
||||
engine := gin.New()
|
||||
|
||||
engine.Use(gin.Logger())
|
||||
engine.Use(gin.Recovery())
|
||||
engine.Use(middleware.SecurityHeaders())
|
||||
|
||||
rl := middleware.NewRateLimiter(10, 20)
|
||||
engine.Use(rl.Middleware())
|
||||
|
||||
return &Server{
|
||||
address: address,
|
||||
@@ -32,11 +37,11 @@ func NewServer(address string, lh *handlers.LinksHandler) *Server {
|
||||
func (s *Server) Start() error {
|
||||
s.routes()
|
||||
|
||||
fmt.Printf("starting server on %s\n", s.address)
|
||||
|
||||
if err := s.engine.Run(s.address); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("error occured while starting the server: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("error occurred while starting the server: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("server started: &v", s.address)
|
||||
return nil
|
||||
}
|
||||
|
||||
74
internal/middleware/ratelimit.go
Normal file
74
internal/middleware/ratelimit.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
limiter *rate.Limiter
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
clients map[string]*client
|
||||
mu sync.Mutex
|
||||
rps rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
func NewRateLimiter(requestsPerSecond float64, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
clients: make(map[string]*client),
|
||||
rps: rate.Limit(requestsPerSecond),
|
||||
burst: burst,
|
||||
}
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) getClient(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if c, exists := rl.clients[ip]; exists {
|
||||
c.lastSeen = time.Now()
|
||||
return c.limiter
|
||||
}
|
||||
|
||||
limiter := rate.NewLimiter(rl.rps, rl.burst)
|
||||
rl.clients[ip] = &client{limiter: limiter, lastSeen: time.Now()}
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
rl.mu.Lock()
|
||||
for ip, c := range rl.clients {
|
||||
if time.Since(c.lastSeen) > 3*time.Minute {
|
||||
delete(rl.clients, ip)
|
||||
}
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
limiter := rl.getClient(ip)
|
||||
if !limiter.Allow() {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "failed",
|
||||
"message": "rate limit exceeded",
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
14
internal/middleware/security.go
Normal file
14
internal/middleware/security.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package middleware
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'none'")
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"example.com/m/internal/repositories"
|
||||
@@ -12,17 +13,23 @@ type LinksService struct {
|
||||
sqid *sqids.Sqids
|
||||
}
|
||||
|
||||
func NewLinksService(repo *repositories.LinksRepository) *LinksService {
|
||||
s, _ := sqids.New()
|
||||
func NewLinksService(repo *repositories.LinksRepository) (*LinksService, error) {
|
||||
s, err := sqids.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize sqids: %w", err)
|
||||
}
|
||||
return &LinksService{
|
||||
repo: repo,
|
||||
sqid: s,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *LinksService) CreateLink(original string) (string, error) {
|
||||
id, _ := s.sqid.Encode([]uint64{rand.Uint64()})
|
||||
err := s.repo.CreateLink(id, original)
|
||||
id, err := s.sqid.Encode([]uint64{rand.Uint64()})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encode link id: %w", err)
|
||||
}
|
||||
err = s.repo.CreateLink(id, original)
|
||||
return id, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user