package jsonrpc import ( "encoding/json" "errors" "fmt" "io" "log" "reflect" "sync" "sync/atomic" "time" ) type JSONRPCRouter struct { writer io.Writer handlers map[string]*RPCHandler nextId atomic.Int64 responseChannelsMutex sync.Mutex responseChannels map[int64]chan JSONRPCResponse } func NewJSONRPCRouter(writer io.Writer, handlers map[string]*RPCHandler) *JSONRPCRouter { return &JSONRPCRouter{ writer: writer, handlers: handlers, responseChannels: make(map[int64]chan JSONRPCResponse), } } func (s *JSONRPCRouter) Request(method string, params map[string]interface{}, result interface{}) *JSONRPCResponseError { id := s.nextId.Add(1) request := JSONRPCRequest{ JSONRPC: "2.0", Method: method, Params: params, ID: id, } requestBytes, err := json.Marshal(request) if err != nil { return &JSONRPCResponseError{ Code: -32700, Message: "Parse error", Data: err, } } // log.Printf("Sending RPC request: Method=%s, Params=%v, ID=%d", method, params, id) responseChan := make(chan JSONRPCResponse, 1) s.responseChannelsMutex.Lock() s.responseChannels[id] = responseChan s.responseChannelsMutex.Unlock() defer func() { s.responseChannelsMutex.Lock() delete(s.responseChannels, id) s.responseChannelsMutex.Unlock() }() _, err = s.writer.Write(requestBytes) if err != nil { return &JSONRPCResponseError{ Code: -32603, Message: "Internal error", Data: err, } } timeout := time.After(5 * time.Second) select { case response := <-responseChan: if response.Error != nil { return response.Error } rawResult, err := json.Marshal(response.Result) if err != nil { return &JSONRPCResponseError{ Code: -32603, Message: "Internal error", Data: err, } } if err := json.Unmarshal(rawResult, result); err != nil { return &JSONRPCResponseError{ Code: -32603, Message: "Internal error", Data: err, } } return nil case <-timeout: return &JSONRPCResponseError{ Code: -32603, Message: "Internal error", Data: "timeout waiting for response", } } } type JSONRPCMessage struct { Method *string `json:"method,omitempty"` ID *int64 `json:"id,omitempty"` } func (s *JSONRPCRouter) HandleMessage(data []byte) error { // Data will either be a JSONRPCRequest or JSONRPCResponse object // We need to determine which one it is var raw JSONRPCMessage err := json.Unmarshal(data, &raw) if err != nil { errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: &JSONRPCResponseError{ Code: -32700, Message: "Parse error", }, ID: 0, } return s.writeResponse(errorResponse) } if raw.Method == nil && raw.ID != nil { var resp JSONRPCResponse if err := json.Unmarshal(data, &resp); err != nil { fmt.Println("error unmarshalling response", err) return err } s.responseChannelsMutex.Lock() responseChan, ok := s.responseChannels[*raw.ID] s.responseChannelsMutex.Unlock() if ok { responseChan <- resp } else { log.Println("No response channel found for ID", resp.ID) } return nil } var request JSONRPCRequest err = json.Unmarshal(data, &request) if err != nil { errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: &JSONRPCResponseError{ Code: -32700, Message: "Parse error", }, ID: 0, } return s.writeResponse(errorResponse) } //log.Printf("Received RPC request: Method=%s, Params=%v, ID=%d", request.Method, request.Params, request.ID) handler, ok := s.handlers[request.Method] if !ok { errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: &JSONRPCResponseError{ Code: -32601, Message: "Method not found", }, ID: request.ID, } return s.writeResponse(errorResponse) } result, err := callRPCHandler(handler, request.Params) if err != nil { errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: &JSONRPCResponseError{ Code: -32603, Message: "Internal error", Data: err.Error(), }, ID: request.ID, } return s.writeResponse(errorResponse) } response := JSONRPCResponse{ JSONRPC: "2.0", Result: result, ID: request.ID, } return s.writeResponse(response) } func (s *JSONRPCRouter) writeResponse(response JSONRPCResponse) error { responseBytes, err := json.Marshal(response) if err != nil { return err } _, err = s.writer.Write(responseBytes) return err } func callRPCHandler(handler *RPCHandler, params map[string]interface{}) (interface{}, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() if handlerType.Kind() != reflect.Func { return nil, errors.New("handler is not a function") } numParams := handlerType.NumIn() args := make([]reflect.Value, numParams) // Get the parameter names from the RPCHandler paramNames := handler.Params if len(paramNames) != numParams { return nil, errors.New("mismatch between handler parameters and defined parameter names") } for i := 0; i < numParams; i++ { paramType := handlerType.In(i) paramName := paramNames[i] paramValue, ok := params[paramName] if !ok { return nil, errors.New("missing parameter: " + paramName) } convertedValue := reflect.ValueOf(paramValue) if !convertedValue.Type().ConvertibleTo(paramType) { if paramType.Kind() == reflect.Slice && (convertedValue.Kind() == reflect.Slice || convertedValue.Kind() == reflect.Array) { newSlice := reflect.MakeSlice(paramType, convertedValue.Len(), convertedValue.Len()) for j := 0; j < convertedValue.Len(); j++ { elemValue := convertedValue.Index(j) if elemValue.Kind() == reflect.Interface { elemValue = elemValue.Elem() } if !elemValue.Type().ConvertibleTo(paramType.Elem()) { // Handle float64 to uint8 conversion if elemValue.Kind() == reflect.Float64 && paramType.Elem().Kind() == reflect.Uint8 { intValue := int(elemValue.Float()) if intValue < 0 || intValue > 255 { return nil, fmt.Errorf("value out of range for uint8: %v", intValue) } newSlice.Index(j).SetUint(uint64(intValue)) } else { fromType := elemValue.Type() toType := paramType.Elem() return nil, fmt.Errorf("invalid element type in slice for parameter %s: from %v to %v", paramName, fromType, toType) } } else { newSlice.Index(j).Set(elemValue.Convert(paramType.Elem())) } } args[i] = newSlice } else if paramType.Kind() == reflect.Struct && convertedValue.Kind() == reflect.Map { jsonData, err := json.Marshal(convertedValue.Interface()) if err != nil { return nil, fmt.Errorf("failed to marshal map to JSON: %v", err) } newStruct := reflect.New(paramType).Interface() if err := json.Unmarshal(jsonData, newStruct); err != nil { return nil, fmt.Errorf("failed to unmarshal JSON into struct: %v", err) } args[i] = reflect.ValueOf(newStruct).Elem() } else { return nil, fmt.Errorf("invalid parameter type for: %s", paramName) } } else { args[i] = convertedValue.Convert(paramType) } } results := handlerValue.Call(args) if len(results) == 0 { return nil, nil } if len(results) == 1 { if results[0].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { if !results[0].IsNil() { return nil, results[0].Interface().(error) } return nil, nil } return results[0].Interface(), nil } if len(results) == 2 && results[1].Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) { if !results[1].IsNil() { return nil, results[1].Interface().(error) } return results[0].Interface(), nil } return nil, errors.New("unexpected return values from handler") }