kvm/internal/jsonrpc/router.go

301 lines
7.5 KiB
Go

package jsonrpc
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"reflect"
"sync"
"sync/atomic"
"time"
)
// JSONRPCRouter is a bidirectional JSON-RPC 2.0 endpoint: it dispatches
// incoming requests to registered handlers and matches incoming responses to
// requests previously issued through Request.
type JSONRPCRouter struct {
	// writer receives every outgoing message, both requests and responses.
	writer io.Writer
	// handlers maps a method name to the handler invoked for that method.
	handlers map[string]*RPCHandler

	// nextId is incremented atomically to produce the ID of each outgoing
	// request.
	nextId atomic.Int64

	// responseChannelsMutex guards responseChannels.
	responseChannelsMutex sync.Mutex
	// responseChannels holds one channel per in-flight outgoing request,
	// keyed by request ID.
	responseChannels map[int64]chan JSONRPCResponse
}
// NewJSONRPCRouter builds a router that sends all outgoing messages to writer
// and serves incoming requests from handlers, keyed by method name. The
// returned router starts with no in-flight requests.
func NewJSONRPCRouter(writer io.Writer, handlers map[string]*RPCHandler) *JSONRPCRouter {
	router := new(JSONRPCRouter)
	router.writer = writer
	router.handlers = handlers
	router.responseChannels = map[int64]chan JSONRPCResponse{}
	return router
}
// Request sends a JSON-RPC 2.0 request for method with params to the peer and
// blocks until the matching response arrives or five seconds have elapsed.
//
// On success the response's result is decoded into result, which must be a
// pointer accepted by json.Unmarshal; pass nil to discard the result. The
// return value is nil on success. Otherwise it is either the error object sent
// by the peer or a locally generated one: code -32700 when the request cannot
// be encoded, code -32603 when writing fails, the result cannot be decoded, or
// the wait times out. Locally generated errors carry the failure text as a
// string in Data so the object stays JSON-serialisable.
func (s *JSONRPCRouter) Request(method string, params map[string]interface{}, result interface{}) *JSONRPCResponseError {
	// How long to wait for the peer's response before giving up.
	const responseTimeout = 5 * time.Second

	internalError := func(data interface{}) *JSONRPCResponseError {
		return &JSONRPCResponseError{
			Code:    -32603,
			Message: "Internal error",
			Data:    data,
		}
	}

	id := s.nextId.Add(1)
	request := JSONRPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  params,
		ID:      id,
	}
	requestBytes, err := json.Marshal(request)
	if err != nil {
		return &JSONRPCResponseError{
			Code:    -32700,
			Message: "Parse error",
			Data:    err.Error(),
		}
	}
	// log.Printf("Sending RPC request: Method=%s, Params=%v, ID=%d", method, params, id)

	// Register the channel before writing so that a fast response always
	// finds its destination. The buffer of one lets HandleMessage deliver
	// without waiting for this goroutine to reach the select below.
	responseChan := make(chan JSONRPCResponse, 1)
	s.responseChannelsMutex.Lock()
	s.responseChannels[id] = responseChan
	s.responseChannelsMutex.Unlock()
	defer func() {
		s.responseChannelsMutex.Lock()
		delete(s.responseChannels, id)
		s.responseChannelsMutex.Unlock()
	}()

	if _, err := s.writer.Write(requestBytes); err != nil {
		return internalError(err.Error())
	}

	// Use an explicit timer so it can be released as soon as the response
	// arrives instead of lingering until it fires.
	timer := time.NewTimer(responseTimeout)
	defer timer.Stop()

	select {
	case response := <-responseChan:
		if response.Error != nil {
			return response.Error
		}
		if result == nil {
			// The caller is not interested in the result payload.
			return nil
		}
		// Round-trip through JSON to convert the loosely typed result into
		// the caller's concrete type.
		rawResult, err := json.Marshal(response.Result)
		if err != nil {
			return internalError(err.Error())
		}
		if err := json.Unmarshal(rawResult, result); err != nil {
			return internalError(err.Error())
		}
		return nil
	case <-timer.C:
		return internalError("timeout waiting for response")
	}
}
// JSONRPCMessage is a minimal view of a JSON-RPC message that decodes only the
// members needed to tell a request from a response. Each field is nil when the
// corresponding member is absent or null in the JSON.
type JSONRPCMessage struct {
	// Method is the "method" member of the message.
	Method *string `json:"method,omitempty"`
	// ID is the "id" member of the message.
	ID *int64 `json:"id,omitempty"`
}
// HandleMessage processes a single incoming JSON-RPC 2.0 message.
//
// A message carrying an "id" but no "method" is treated as the response to a
// request issued through Request and is handed to the waiting caller. Any
// other message is treated as a request: the named handler is invoked and its
// result, or a JSON-RPC error object, is written back to the peer.
//
// The returned error is non-nil only when a response body cannot be decoded or
// when encoding or writing the reply fails. Protocol-level failures (parse
// error -32700, unknown method -32601, handler failure -32603) are reported to
// the peer instead of being returned.
func (s *JSONRPCRouter) HandleMessage(data []byte) error {
	// Data will either be a JSONRPCRequest or JSONRPCResponse object.
	// Peek at the method and id members to determine which one it is.
	var raw JSONRPCMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return s.writeResponse(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCResponseError{
				Code:    -32700,
				Message: "Parse error",
			},
			ID: 0,
		})
	}

	if raw.Method == nil && raw.ID != nil {
		var resp JSONRPCResponse
		if err := json.Unmarshal(data, &resp); err != nil {
			return fmt.Errorf("unmarshalling response %d: %w", *raw.ID, err)
		}
		s.responseChannelsMutex.Lock()
		responseChan, ok := s.responseChannels[*raw.ID]
		s.responseChannelsMutex.Unlock()
		if !ok {
			log.Println("No response channel found for ID", *raw.ID)
			return nil
		}
		// Request creates each channel with a buffer of one, so the first
		// response always fits. Never block here: a duplicate response from
		// the peer must not stall message handling.
		select {
		case responseChan <- resp:
		default:
			log.Println("Dropping duplicate response for ID", *raw.ID)
		}
		return nil
	}

	var request JSONRPCRequest
	if err := json.Unmarshal(data, &request); err != nil {
		return s.writeResponse(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCResponseError{
				Code:    -32700,
				Message: "Parse error",
			},
			ID: 0,
		})
	}
	//log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID)

	handler, ok := s.handlers[request.Method]
	if !ok {
		return s.writeResponse(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCResponseError{
				Code:    -32601,
				Message: "Method not found",
			},
			ID: request.ID,
		})
	}

	result, err := callRPCHandler(handler, request.Params)
	if err != nil {
		return s.writeResponse(JSONRPCResponse{
			JSONRPC: "2.0",
			Error: &JSONRPCResponseError{
				Code:    -32603,
				Message: "Internal error",
				Data:    err.Error(),
			},
			ID: request.ID,
		})
	}

	return s.writeResponse(JSONRPCResponse{
		JSONRPC: "2.0",
		Result:  result,
		ID:      request.ID,
	})
}
// writeResponse encodes response as JSON and sends it to the router's writer
// in a single Write call. It returns the encoding error or the write error,
// whichever occurs first, and nil when both steps succeed.
func (s *JSONRPCRouter) writeResponse(response JSONRPCResponse) error {
	payload, marshalErr := json.Marshal(response)
	if marshalErr != nil {
		return marshalErr
	}
	if _, writeErr := s.writer.Write(payload); writeErr != nil {
		return writeErr
	}
	return nil
}
// callRPCHandler invokes handler.Func through reflection, taking one argument
// per name in handler.Params from params, in that order.
//
// Each argument is converted to the handler's parameter type: directly when
// the Go types are convertible, element by element for slices, and through a
// JSON round trip when a map is supplied for a struct parameter. A nil value
// (JSON null) is passed as the zero value to interface, pointer, slice and map
// parameters and rejected for every other kind.
//
// The handler may return nothing, a single value, a single error, or a
// (value, error) pair. callRPCHandler returns the value and error accordingly.
// It returns an error without calling the handler when the handler is not a
// usable function, a parameter is missing, or a parameter cannot be converted.
func callRPCHandler(handler *RPCHandler, params map[string]interface{}) (interface{}, error) {
	if handler == nil {
		return nil, errors.New("handler is not a function")
	}
	handlerValue := reflect.ValueOf(handler.Func)
	// A nil Func yields an invalid Value whose Type method would panic, and a
	// typed nil function would panic inside Call.
	if !handlerValue.IsValid() || handlerValue.Kind() != reflect.Func || handlerValue.IsNil() {
		return nil, errors.New("handler is not a function")
	}
	handlerType := handlerValue.Type()

	numParams := handlerType.NumIn()
	// Get the parameter names from the RPCHandler.
	paramNames := handler.Params
	if len(paramNames) != numParams {
		return nil, errors.New("mismatch between handler parameters and defined parameter names")
	}

	args := make([]reflect.Value, numParams)
	for i := 0; i < numParams; i++ {
		paramName := paramNames[i]
		paramValue, ok := params[paramName]
		if !ok {
			return nil, errors.New("missing parameter: " + paramName)
		}
		arg, err := convertRPCParam(paramName, paramValue, handlerType.In(i))
		if err != nil {
			return nil, err
		}
		args[i] = arg
	}

	results := handlerValue.Call(args)
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	switch len(results) {
	case 0:
		return nil, nil
	case 1:
		if results[0].Type().Implements(errorType) {
			return nil, rpcResultError(results[0])
		}
		return results[0].Interface(), nil
	case 2:
		if results[1].Type().Implements(errorType) {
			if err := rpcResultError(results[1]); err != nil {
				return nil, err
			}
			return results[0].Interface(), nil
		}
	}
	return nil, errors.New("unexpected return values from handler")
}

// convertRPCParam converts one decoded JSON value into a reflect.Value of type
// target, suitable for passing to the handler. name is used in error messages.
func convertRPCParam(name string, value interface{}, target reflect.Type) (reflect.Value, error) {
	source := reflect.ValueOf(value)
	if !source.IsValid() {
		// JSON null decodes to a nil interface, which has no reflect.Type.
		switch target.Kind() {
		case reflect.Interface, reflect.Pointer, reflect.Slice, reflect.Map:
			return reflect.Zero(target), nil
		}
		return reflect.Value{}, fmt.Errorf("invalid parameter type for: %s", name)
	}
	if source.Type().ConvertibleTo(target) {
		return source.Convert(target), nil
	}

	switch {
	case target.Kind() == reflect.Slice && (source.Kind() == reflect.Slice || source.Kind() == reflect.Array):
		return convertRPCSlice(name, source, target)
	case target.Kind() == reflect.Struct && source.Kind() == reflect.Map:
		// Let encoding/json do the field mapping by round-tripping the map.
		jsonData, err := json.Marshal(source.Interface())
		if err != nil {
			return reflect.Value{}, fmt.Errorf("failed to marshal map to JSON: %v", err)
		}
		newStruct := reflect.New(target)
		if err := json.Unmarshal(jsonData, newStruct.Interface()); err != nil {
			return reflect.Value{}, fmt.Errorf("failed to unmarshal JSON into struct: %v", err)
		}
		return newStruct.Elem(), nil
	}
	return reflect.Value{}, fmt.Errorf("invalid parameter type for: %s", name)
}

// convertRPCSlice builds a slice of type target from the elements of source,
// converting each element to target's element type.
func convertRPCSlice(name string, source reflect.Value, target reflect.Type) (reflect.Value, error) {
	elemType := target.Elem()
	length := source.Len()
	out := reflect.MakeSlice(target, length, length)
	for j := 0; j < length; j++ {
		elem := source.Index(j)
		if elem.Kind() == reflect.Interface {
			elem = elem.Elem()
		}
		switch {
		case !elem.IsValid():
			// A null element has no type to convert from.
			return reflect.Value{}, fmt.Errorf("invalid element type in slice for parameter %s: from <nil> to %v", name, elemType)
		case elem.Kind() == reflect.Float64 && elemType.Kind() == reflect.Uint8:
			// Handle float64 to uint8 conversion. This must be checked before
			// the generic conversion below: float64 is convertible to uint8,
			// but that conversion does not range-check the value.
			f := elem.Float()
			// Written as a negated conjunction so that NaN is rejected too.
			if !(f > -1 && f < 256) {
				return reflect.Value{}, fmt.Errorf("value out of range for uint8: %v", f)
			}
			out.Index(j).SetUint(uint64(int64(f)))
		case elem.Type().ConvertibleTo(elemType):
			out.Index(j).Set(elem.Convert(elemType))
		default:
			return reflect.Value{}, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", name, elem.Type(), elemType)
		}
	}
	return out, nil
}

// rpcResultError extracts the error held by v, whose type must implement
// error. It returns nil when v holds a nil value.
func rpcResultError(v reflect.Value) error {
	switch v.Kind() {
	case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		// IsNil is only defined for these kinds and panics for the rest.
		if v.IsNil() {
			return nil
		}
	}
	return v.Interface().(error)
}