qbittorrent-multiplexer/handlers.go

418 lines
8.7 KiB
Go

package main
import (
"bytes"
"context"
"errors"
"io"
"log"
"net/http"
"net/url"
"slices"
"strings"
"github.com/W-Floyd/qbittorrent-docker-multiplexer/qbittorrent"
"github.com/W-Floyd/qbittorrent-docker-multiplexer/state"
"golang.org/x/sync/errgroup"
"github.com/Jeffail/gabs/v2"
)
func (c *Config) HandlerPassthrough(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.RequestURI(), "/api/") {
log.Println("Warning - Passthrough on API call: " + r.URL.RequestURI())
}
i := state.AppState.BalancerCount
resp, err := c.MakeRequest(r, &i)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
c.MakeResponse(resp, w)
state.AppState.BalancerCount += 1
state.AppState.BalancerCount %= state.AppState.NumberOfClients
}
func (c *Config) MakeRequest(r *http.Request, i *uint) (*http.Response, error) {
body, err := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(body))
proxy, err := qbittorrent.GetProxy(&c.QBittorrent, i)
if err != nil {
return nil, err
}
req, err := http.NewRequest(
r.Method,
qbittorrent.URL(&c.QBittorrent, i, r.URL).String(),
bytes.NewBuffer(body),
)
if err != nil {
return nil, err
}
req.Header = r.Header.Clone()
err = PrepareRequest(req, c, i)
if err != nil {
return nil, err
}
oldDirector := proxy.Director
proxy.Director = func(r *http.Request) {
rewriteRequestURL(req, qbittorrent.URL(&c.QBittorrent, i, r.URL))
r.Header.Del("Referer")
}
// if strings.HasPrefix(r.URL.RequestURI(), "/api/v2/torrents/add") {
// log.Println(req)
// }
resp, err := proxy.Transport.RoundTrip(req)
proxy.Director = oldDirector
return resp, err
}
func (c Config) MakeResponse(resp *http.Response, w http.ResponseWriter) {
for header := range resp.Header {
w.Header().Add(header, resp.Header.Get(header))
}
io.Copy(w, resp.Body)
}
func (c *Config) HandlerHashFinder(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(body))
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
hashes := []string{}
key := "hash"
if r.Form.Has("hash") {
hashes = append(hashes, r.Form.Get("hash"))
}
if r.Form.Has("hashes") {
key = "hashes"
hashes = append(hashes, strings.Split(r.Form.Get("hashes"), "|")...)
}
state.AppState.Locks.Torrents.Lock()
defer state.AppState.Locks.Torrents.Unlock()
resps := []*http.Response{}
for _, hash := range hashes {
form, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
form.Del("hash")
form.Del("hashes")
form.Set(key, hash)
r.URL.RawQuery = form.Encode()
instance, ok := state.AppState.Torrents[hash]
if ok {
resp, err := c.MakeRequest(r, &instance)
if err != nil || resp.StatusCode != http.StatusOK {
ok = false
} else {
resps = append(resps, resp)
}
r.Form.Del("")
} else if !ok {
log.Println("hash not known, looking up: " + hash)
for i := uint(0); i < state.AppState.NumberOfClients; i++ {
resp, err := c.MakeRequest(r, &i)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
} else if resp.StatusCode == http.StatusOK {
resps = append(resps, resp)
state.AppState.Torrents[hash] = i
}
}
}
}
if len(resps) == 1 {
c.MakeResponse(resps[0], w)
} else if len(resps) > 0 {
baseResp := resps[0]
baseBody := ""
for _, resp := range resps {
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
baseBody += string(body)
}
baseResp.Body = io.NopCloser(bytes.NewBuffer([]byte(baseBody)))
c.MakeResponse(baseResp, w)
} else {
http.Error(w, "no responses", http.StatusInternalServerError)
}
}
func (c *Config) Parallel(r *http.Request) (output []*struct {
response *gabs.Container
instance uint
}, err error) {
g, ctx := errgroup.WithContext(context.Background())
responses := make(chan struct {
response *http.Response
instance uint
})
body, err := io.ReadAll(r.Body)
if err != nil {
log.Println(err)
return
}
for i := uint(0); i < state.AppState.NumberOfClients; i++ {
g.Go(func() (err error) {
req, err := http.NewRequest(
r.Method,
qbittorrent.URL(&c.QBittorrent, &i, r.URL).String(),
bytes.NewBuffer(body),
)
if err != nil {
return err
}
resp, err := c.MakeRequest(req, &i)
if err != nil {
return err
} else if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return errors.New("Error" + string(body))
}
select {
case responses <- struct {
response *http.Response
instance uint
}{
response: resp,
instance: i,
}:
return nil
case <-ctx.Done():
return ctx.Err()
}
})
}
go func() {
g.Wait()
close(responses)
}()
for resp := range responses {
cont, err := gabs.ParseJSONBuffer(resp.response.Body)
if err != nil {
return nil, err
}
output = append(output, &struct {
response *gabs.Container
instance uint
}{
response: cont,
instance: resp.instance,
})
}
err = g.Wait()
if err != nil {
return
}
return output, nil
}
func (c *Config) HandlerFetchMergeArray(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v2/torrents/info" {
state.AppState.Locks.Torrents.Lock()
defer state.AppState.Locks.Torrents.Unlock()
state.AppState.Torrents = map[string]uint{}
}
resps, err := c.Parallel(r)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
output := []*gabs.Container{}
for _, resp := range resps {
for _, child := range resp.response.Children() {
output = append(output, child)
if r.URL.Path == "/api/v2/torrents/info" {
hash := child.Path("hash").Data().(string)
state.AppState.Torrents[hash] = resp.instance
log.Println("Stored hash", hash)
}
}
}
if r.URL.Path == "/api/v2/torrents/info" {
slices.SortStableFunc(output, func(a *gabs.Container, b *gabs.Container) int {
return strings.Compare(a.Path("added_on").String(), b.Path("added_on").String())
})
}
w.Write(gabs.Wrap(output).BytesIndent("", " "))
}
func (c *Config) HandlerFetchMerge(w http.ResponseWriter, r *http.Request) {
resps, err := c.Parallel(r)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
output := &gabs.Container{}
for _, resp := range resps {
output.MergeFn(resp.response, func(dest, source interface{}) interface{} {
destArr, destIsArray := dest.([]interface{})
sourceArr, sourceIsArray := source.([]interface{})
if destIsArray {
if sourceIsArray {
return append(destArr, sourceArr...)
}
return append(destArr, source)
}
if sourceIsArray {
return append(append([]interface{}{}, dest), sourceArr...)
}
return source
})
}
w.Write(gabs.Wrap(output).BytesIndent("", " "))
}
// Prepepares and rewrites headers required to auth with qBittorrent
func PrepareRequest(r *http.Request, c *Config, i *uint) (err error) {
if r == nil {
return errors.New("empty request given")
}
if c == nil {
return errors.New("empty config given")
}
if i == nil {
return errors.New("empty instance given")
}
url := qbittorrent.URL(&c.QBittorrent, i, r.URL)
r.Host = url.String()
r.Header.Del("Referer")
r.Header.Del("Origin")
r.Header.Del("Accept-Encoding")
r.Header.Set("Cookie", state.AppState.Cookies[*i].Raw)
return
}
func rewriteRequestURL(req *http.Request, target *url.URL) {
targetQuery := target.RawQuery
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL)
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
}
}
func joinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}