feat: increased security
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
2026-02-21 21:05:47 +03:00
parent 04830681b9
commit 8befcc11c1
14 changed files with 233 additions and 50 deletions

View File

@@ -2,32 +2,27 @@ package db
import (
"context"
"log"
"fmt"
"time"
"example.com/m/internal/config"
"github.com/redis/go-redis/v9"
)
var Redis *redis.Client
func InitRedis(cfg *config.Config) (*redis.Client, error) {
opt, err := redis.ParseURL(cfg.RedisURL)
if err != nil {
log.Fatalf("impossible to parse redis url: %v", err)
return nil, err
return nil, fmt.Errorf("failed to parse redis url: %w", err)
}
Redis = redis.NewClient(opt)
client := redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := Redis.Ping(ctx).Err(); err != nil {
log.Fatalf("impossible to conenct to redis db: %v", err)
return nil, err
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to redis: %w", err)
}
return Redis, nil
return client, nil
}

View File

@@ -3,6 +3,7 @@ package handlers
import (
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
@@ -47,7 +48,7 @@ func (h *LinksHandler) CreateLink(c *gin.Context) {
return
}
address := fmt.Sprintf("http://%v:%s/r/%v", h.host, h.port, id)
address := fmt.Sprintf("https://%s/r/%s", h.host, id)
response := CreateLinkResponse{
Status: "success",
@@ -89,5 +90,44 @@ func NormalizeURL(raw string) (string, error) {
return "", errors.New("invalid host in URL")
}
host := u.Hostname()
if isPrivateHost(host) {
return "", errors.New("URLs pointing to private/internal addresses are not allowed")
}
return u.String(), nil
}
func isPrivateHost(host string) bool {
ip := net.ParseIP(host)
if ip == nil {
addrs, err := net.LookupHost(host)
if err != nil || len(addrs) == 0 {
return false
}
ip = net.ParseIP(addrs[0])
if ip == nil {
return false
}
}
privateRanges := []string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"0.0.0.0/8",
"::1/128",
"fc00::/7",
"fe80::/10",
}
for _, cidr := range privateRanges {
_, network, _ := net.ParseCIDR(cidr)
if network.Contains(ip) {
return true
}
}
return false
}

View File

@@ -2,10 +2,10 @@ package http
import (
"fmt"
"log"
"net/http"
"example.com/m/internal/handlers"
"example.com/m/internal/middleware"
"github.com/gin-gonic/gin"
)
@@ -20,7 +20,12 @@ type Server struct {
func NewServer(address string, lh *handlers.LinksHandler) *Server {
engine := gin.New()
engine.Use(gin.Logger())
engine.Use(gin.Recovery())
engine.Use(middleware.SecurityHeaders())
rl := middleware.NewRateLimiter(10, 20)
engine.Use(rl.Middleware())
return &Server{
address: address,
@@ -32,11 +37,11 @@ func NewServer(address string, lh *handlers.LinksHandler) *Server {
func (s *Server) Start() error {
s.routes()
fmt.Printf("starting server on %s\n", s.address)
if err := s.engine.Run(s.address); err != nil && err != http.ErrServerClosed {
log.Fatalf("error occured while starting the server: %v", err)
return err
return fmt.Errorf("error occurred while starting the server: %w", err)
}
fmt.Println("server started: &v", s.address)
return nil
}

View File

@@ -0,0 +1,74 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type client struct {
limiter *rate.Limiter
lastSeen time.Time
}
type RateLimiter struct {
clients map[string]*client
mu sync.Mutex
rps rate.Limit
burst int
}
func NewRateLimiter(requestsPerSecond float64, burst int) *RateLimiter {
rl := &RateLimiter{
clients: make(map[string]*client),
rps: rate.Limit(requestsPerSecond),
burst: burst,
}
go rl.cleanup()
return rl
}
func (rl *RateLimiter) getClient(ip string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
if c, exists := rl.clients[ip]; exists {
c.lastSeen = time.Now()
return c.limiter
}
limiter := rate.NewLimiter(rl.rps, rl.burst)
rl.clients[ip] = &client{limiter: limiter, lastSeen: time.Now()}
return limiter
}
func (rl *RateLimiter) cleanup() {
for {
time.Sleep(time.Minute)
rl.mu.Lock()
for ip, c := range rl.clients {
if time.Since(c.lastSeen) > 3*time.Minute {
delete(rl.clients, ip)
}
}
rl.mu.Unlock()
}
}
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
ip := c.ClientIP()
limiter := rl.getClient(ip)
if !limiter.Allow() {
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
"error": "failed",
"message": "rate limit exceeded",
})
return
}
c.Next()
}
}

View File

@@ -0,0 +1,14 @@
package middleware
import "github.com/gin-gonic/gin"
func SecurityHeaders() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Content-Security-Policy", "default-src 'none'")
c.Next()
}
}

View File

@@ -1,6 +1,7 @@
package services
import (
"fmt"
"math/rand"
"example.com/m/internal/repositories"
@@ -12,17 +13,23 @@ type LinksService struct {
sqid *sqids.Sqids
}
func NewLinksService(repo *repositories.LinksRepository) *LinksService {
s, _ := sqids.New()
func NewLinksService(repo *repositories.LinksRepository) (*LinksService, error) {
s, err := sqids.New()
if err != nil {
return nil, fmt.Errorf("failed to initialize sqids: %w", err)
}
return &LinksService{
repo: repo,
sqid: s,
}
}, nil
}
func (s *LinksService) CreateLink(original string) (string, error) {
id, _ := s.sqid.Encode([]uint64{rand.Uint64()})
err := s.repo.CreateLink(id, original)
id, err := s.sqid.Encode([]uint64{rand.Uint64()})
if err != nil {
return "", fmt.Errorf("failed to encode link id: %w", err)
}
err = s.repo.CreateLink(id, original)
return id, err
}