| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package tcp |
| |
| import ( |
| "bufio" |
| "context" |
| "fmt" |
| "net" |
| "sync" |
| "sync/atomic" |
| |
| "github.com/pkg/errors" |
| "github.com/rs/zerolog" |
| |
| "github.com/apache/plc4x/plc4go/spi/options" |
| "github.com/apache/plc4x/plc4go/spi/transports" |
| transportUtils "github.com/apache/plc4x/plc4go/spi/transports/utils" |
| "github.com/apache/plc4x/plc4go/spi/utils" |
| ) |
| |
| type TransportInstance struct { |
| transportUtils.DefaultBufferedTransportInstance |
| |
| RemoteAddress *net.TCPAddr |
| LocalAddress *net.TCPAddr |
| ConnectTimeout uint32 |
| |
| transport *Transport |
| |
| tcpConn net.Conn |
| reader *bufio.Reader |
| |
| connected atomic.Bool |
| stateChangeMutex sync.RWMutex |
| |
| log zerolog.Logger |
| } |
| |
| func NewTcpTransportInstance(remoteAddress *net.TCPAddr, connectTimeout uint32, transport *Transport, _options ...options.WithOption) *TransportInstance { |
| customLogger := options.ExtractCustomLoggerOrDefaultToGlobal(_options...) |
| transportInstance := &TransportInstance{ |
| RemoteAddress: remoteAddress, |
| ConnectTimeout: connectTimeout, |
| transport: transport, |
| |
| log: customLogger, |
| } |
| transportInstance.DefaultBufferedTransportInstance = transportUtils.NewDefaultBufferedTransportInstance(transportInstance, _options...) |
| return transportInstance |
| } |
| |
| func (m *TransportInstance) Connect() error { |
| return m.ConnectWithContext(context.Background()) |
| } |
| |
| func (m *TransportInstance) ConnectWithContext(ctx context.Context) error { |
| if m.connected.Load() { |
| return errors.New("already connected") |
| } |
| m.stateChangeMutex.Lock() |
| defer m.stateChangeMutex.Unlock() |
| if m.RemoteAddress == nil { |
| return errors.New("Required remote address missing") |
| } |
| var err error |
| var d net.Dialer |
| m.tcpConn, err = d.DialContext(ctx, "tcp", m.RemoteAddress.String()) |
| if err != nil { |
| return errors.Wrap(err, "error connecting to remote address") |
| } |
| |
| m.LocalAddress = m.tcpConn.LocalAddr().(*net.TCPAddr) |
| |
| m.reader = bufio.NewReaderSize(m.tcpConn, 100000) |
| |
| m.connected.Store(true) |
| return nil |
| } |
| |
| func (m *TransportInstance) Close() error { |
| defer utils.StopWarn(m.log)() |
| m.stateChangeMutex.Lock() |
| defer m.stateChangeMutex.Unlock() |
| if !m.connected.Load() { |
| return nil |
| } |
| if err := m.tcpConn.Close(); err != nil { |
| return errors.Wrap(err, "error closing connection") |
| } |
| m.connected.Store(false) |
| return nil |
| } |
| |
| func (m *TransportInstance) IsConnected() bool { |
| return m.connected.Load() |
| } |
| |
| func (m *TransportInstance) Write(data []byte) error { |
| if !m.connected.Load() { |
| return errors.New("error writing to transport. Not connected") |
| } |
| num, err := m.tcpConn.Write(data) |
| if err != nil { |
| return errors.Wrap(err, "error writing") |
| } |
| if num != len(data) { |
| return errors.New("error writing: not all bytes written") |
| } |
| return nil |
| } |
| |
| func (m *TransportInstance) GetReader() transports.ExtendedReader { |
| return m.reader |
| } |
| |
| func (m *TransportInstance) String() string { |
| localAddress := "" |
| if m.LocalAddress != nil { |
| localAddress = m.LocalAddress.String() + "->" |
| } |
| return fmt.Sprintf("tcp:%s%s", localAddress, m.RemoteAddress) |
| } |