Skip to content

二维数组合并

约 1355 字大约 5 分钟

数组合并Go

2022-09-19

前言

对接港交所(HKEX)时,HKEX提供了一组线路,一共三条:实时(UDP)、刷新(UDP)、重传(TCP)。 实时和刷新线路采用UDP协议,采用组播的方式传输,因此可以支持多个客户端接入。而UDP丢包的概率比较大,就需要发起重传,请求指定的消息包序号及消息包数量,HKEX返回这一段的消息包,但是消息包数量上限为 10000。当请求的消息包数量超过 10000 时,HKEX会通过刷新线路返回当前的消息快照。 但HKEX提供的线路只有这么一组,当有多个客户端需要接入HKEX的数据时,这一组线路就不够用了,再申请一组线路的成本又太大。因此就需要开发一个重传代理服务,代理和HKEX保持TCP连接,多客户端和代理保持TCP连接。 但客户端增多时,丢包的概率会成倍增加,重传的次数也会成倍增加,频繁的重传HKEX会警告,严重则会封号。。。那么每次HKEX的重传响应就必须缓存起来,客户端下次再重传时优先检查缓存,尽量减少向HKEX真实发起重传请求。

需求

TCP包内的消息序号一定是连续的,但缓存的多个 TCP 包不一定是连续的,如:[[3,4,5],[7,8],[12,13,14]]。若下一个 TCP 包中包含的消息序号为:[6],则缓存应该合并:[[3,4,5,6,7,8],[12,13,14]]。如果新的消息包交叉则取并集,若下一次 TCP 包中的消息序号为:[13,14,15,16],则缓存应该合并为:[[3,4,5,6,7,8],[12,13,14,15,16]]。

1. 实现

缓存代码cache.go

package cache

import (
	"fmt"
	"hkex/omdc/model"
	"hkex/omdc/packet"
	"math"
)

// 缓存容器
var cacheM = make(map[uint16][]*Message)

// 每个 channel 的 message 缓存上限
const sizeLimit = 10000

func Clear() {
	cacheM = make(map[uint16][]*Message)
}

// 缓存的消息结构
type Message struct {
	Start  uint32
	Count  int
	Models []model.Message
}

// 缓存的大小
func Size(msgs []*Message) int {
	l := 0
	for _, m := range msgs {
		l += m.Count
	}
	return l
}

type CheckResult struct {
	Begin   uint32 // 当需要发起重传请求时的起始序号
	End     uint32 // 当需要发起重传请求时的截止序号
	Request bool   // 是否需要发起请求
}

// 判断是否需要发起重传请求
func Check(channel uint16, begin, end uint32) CheckResult {
	msgs := cacheM[channel]
	r := CheckResult{
		Begin:   begin,
		End:     end,
		Request: true,
	}
	for _, msg := range msgs {
		msgEnd := int(msg.Start) + msg.Count - 1
		if end < msg.Start {
			return r
		} else if end >= msg.Start && int(end) <= msgEnd {
			if begin < msg.Start {
				r.End = msg.Start - 1
				return r
			} else {
				// 无需请求
				r.Request = false
				return r
			}
		} else if int(end) > msgEnd {
			if begin < msg.Start {
				return r
			} else if begin >= msg.Start && int(begin) <= msgEnd {
				r.Begin = uint32(msgEnd + 1)
				return r
			} else {
				continue
			}
		}
	}
	return r
}

// 抽取消息包封装返回。业务逻辑处理到这里是一定能够取到的
func Get(channel uint16, begin, end uint32) []*packet.Buffer {
	msgs := cacheM[channel]
	for _, msg := range msgs {
		msgEnd := int(msg.Start) + msg.Count - 1
		if msg.Start <= begin && msgEnd >= int(end) {
			// 偏移量
			offset := begin - msg.Start
			size := end - begin + 1
			messages := msg.Models[offset : offset+size]
			// packet.Buffer 中 num 的数据类型为 uint8,值范围 0 ~ 255,因此每个数据包最多有 255 个 message
			arr := make([]model.Message, 0, 255)
			count := math.Ceil(float64(size) / 255)
			pkts := make([]*packet.Buffer, 0, int(count))

			seq := begin
			for i := 1; i <= len(messages); i++ {
				arr = append(arr, messages[i-1])
				if i%255 == 0 {
					p := new(packet.Buffer)
					for _, m := range arr {
						_ = p.WriteMessage(m)
					}
					p.SeqNum = seq
					// 加入
					pkts = append(pkts, p)
					// 更新序号
					seq += uint32(i)
					// 重置
					arr = make([]model.Message, 0, 255)
				}
			}
			if len(arr) > 0 {
				p := new(packet.Buffer)
				for _, m := range arr {
					_ = p.WriteMessage(m)
				}
				p.SeqNum = seq
				pkts = append(pkts, p)
			}
			return pkts // 返回
		}
		continue
	}
	return nil
}

func Put(channel uint16, insert *Message) {
	if msgs, ok := cacheM[channel]; ok {
		// 合并
		merge := Merge(insert, msgs)
		// 限制大小
		merge = Truncate(merge)
		cacheM[channel] = merge
	} else {
		ms := make([]*Message, 0, 1)
		ms = append(ms, insert)
		cacheM[channel] = ms
	}
}

func Merge(insert *Message, msgs []*Message) []*Message {
	rets := make([]*Message, 0)
	// 找到间隙
	for i, msg := range msgs {
		insEnd := int(insert.Start) + insert.Count - 1 // end 序号
		msgEnd := int(msg.Start) + msg.Count - 1       // end 序号

		if insEnd < int(msg.Start)-1 { // 间隙
			rets = append(rets, insert)
			rets = append(rets, msgs[i:]...)
			return rets
		} else if insEnd == int(msg.Start)-1 { // 连续
			insert.Models = append(insert.Models, msg.Models...)
			insert.Count += msg.Count
			rets = append(rets, insert)
			rets = append(rets, msgs[i+1:]...)
			return rets
		} else if insEnd > int(msg.Start) && insEnd < msgEnd {
			if insert.Start < msg.Start { // 交叉
				skip := insEnd - int(msg.Start)
				insert.Models = append(insert.Models, msg.Models[skip+1:]...)
				insert.Count += msg.Count - skip - 1
				rets = append(rets, insert)
				rets = append(rets, msgs[i+1:]...)
				return rets
			} else {
				// message 忽略
				return msgs
			}
		} else if insEnd >= msgEnd {
			if insert.Start <= msg.Start {
				continue
			} else if insert.Start > msg.Start && int(insert.Start) <= msgEnd {
				skip := msgEnd - int(insert.Start)
				msg.Models = append(msg.Models, insert.Models[skip+1:]...)
				msg.Count += insert.Count - skip - 1
				insert = msg
			} else if int(insert.Start) == msgEnd+1 { // 连续
				msg.Models = append(msg.Models, insert.Models...)
				msg.Count += insert.Count
				insert = msg
			} else { // 间隙
				rets = append(rets, msg)
			}
		}
	}
	rets = append(rets, insert)
	return rets
}

// 淘汰序号最靠前的消息
func Truncate(msgs []*Message) []*Message {
	// 总大小
	size := Size(msgs)
	if size <= sizeLimit {
		return msgs
	}

	// 丢弃数量
	throw := size - sizeLimit
	for i, msg := range msgs {
		if msg.Count > throw {
			msg.Models = msg.Models[msg.Count-throw:]
			msg.Count -= throw
			msgs[i] = msg
			msgs = msgs[i:]
			break
		} else if msg.Count == throw {
			msgs = msgs[i+1:]
			break
		} else {
			throw -= msg.Count
		}
	}
	return msgs
}

// 测试用
func Console(channel uint16) {
	msgs := cacheM[channel]
	for _, msg := range msgs {
		fmt.Println("start = ", msg.Start, "end = ", int(msg.Start)+msg.Count-1, "count = ", msg.Count)
	}
}

测试cache_test.go

package cache

import (
	"hkex/omdc/model"
	"testing"
)

func TestMerge(t *testing.T) {
	var channel uint16 = 1
	m1 := &Message{
		Start:  2,
		Count:  2,
		Models: []model.Message{model.NewLogonResponse(), model.NewRetransmissionResponse()},
	}
	Put(channel, m1)

	m2 := &Message{
		Start:  5,
		Count:  3,
		Models: []model.Message{model.NewLogonResponse(), model.NewRetransmissionResponse(), model.NewLogon()},
	}
	Put(channel, m2)

	m3 := &Message{
		Start:  1,
		Count:  8,
		Models: []model.Message{model.NewLogonResponse(), model.NewRetransmissionResponse(), model.NewRetransmissionResponse(), model.NewLogon(), model.NewLogonResponse(), model.NewRetransmissionResponse(), model.NewRetransmissionResponse(), model.NewLogon()},
	}
	Put(channel, m3)

	Console(channel)
}

测试结果:

=== RUN   TestMerge
start =  1 end =  8 count =  8
--- PASS: TestMerge (0.00s)
PASS

附件

香港交易所市场平台-证券市场(OMD-C)