package mapset import ( "encoding/json" "fmt" "strconv" "strings" ) // New 返回 [MapSet] func New[T comparable](keys ...T) MapSet[T] { s := make(MapSet[T], len(keys)) s.AddMany(keys...) return s } // MapSet 是集合的范型实现 type MapSet[T comparable] map[T]struct{} // Cardinality 返回集合的元素个数 func (s MapSet[T]) Cardinality() int { return len(s) } // AddOne 添加元素到集合 func (s MapSet[T]) AddOne(key T) { s[key] = struct{}{} } // AddOneOk 添加元素到集合并返回是否添加成功 func (s MapSet[T]) AddOneOk(key T) bool { size := len(s) s[key] = struct{}{} return len(s) > size } // AddMany 添加多个元素到集合 func (s MapSet[T]) AddMany(keys ...T) { for _, key := range keys { s[key] = struct{}{} } } // AddManyOk 添加多个元素到集合并返回添加成功的个数 func (s MapSet[T]) AddManyOk(keys ...T) int { size := len(s) for _, key := range keys { s[key] = struct{}{} } return len(s) - size } // DeleteOne 从集合中删除元素 func (s MapSet[T]) DeleteOne(key T) { delete(s, key) } // DeleteOneOk 从集合中删除元素并返回是否删除成功 func (s MapSet[T]) DeleteOneOk(key T) bool { size := len(s) delete(s, key) return size > len(s) } // DeleteMany 从集合删除多个元素 func (s MapSet[T]) DeleteMany(keys ...T) { for _, key := range keys { delete(s, key) } } // DeleteManyOk 从集合删除多个元素并返回成功删除的个数 func (s MapSet[T]) DeleteManyOk(keys ...T) int { size := len(s) for _, key := range keys { delete(s, key) } return size - len(s) } // PopOne 从集合中删除1个元素, 并把被删除的元素返回 func (s MapSet[T]) PopOne() (key T, ok bool) { for k := range s { key, ok = k, true return } return } // Clear 清空集合 func (s MapSet[T]) Clear() { for key := range s { delete(s, key) } } // Clone 返回集合的浅拷贝 func (s MapSet[T]) Clone() MapSet[T] { n := make(MapSet[T], len(s)) for key := range s { n[key] = struct{}{} } return n } // IsEmpty 判断集合是否为空 func (s MapSet[T]) IsEmpty() bool { return len(s) == 0 } // Equal 判断集合是否和另1个集合相等 func (s MapSet[T]) Equal(o MapSet[T]) bool { if len(s) != len(o) { return false } for key := range s { if _, exists := o[key]; !exists { return false } } return true } // ContainsOne 判断集合是否包含1个元素 func (s MapSet[T]) ContainsOne(key T) bool { _, exists := s[key] return exists } // ContainsAll 判断集合是否包含全部元素 // 参数为空时返回 true func (s MapSet[T]) ContainsAll(keys ...T) bool { for _, key := range keys { if _, exists := s[key]; !exists { return false } } return true } // ContainsAny 判断集合是否包含任意元素 // 参数为空时返回 false func (s MapSet[T]) ContainsAny(keys ...T) bool { for _, key := range keys { if _, exists := s[key]; exists { return true } } return false } // IsSubsetOf 判断集合是否是另1个集合的子集 func (s MapSet[T]) IsSubsetOf(o MapSet[T]) bool { if len(s) > len(o) { return false } for key := range s { if _, exists := o[key]; !exists { return false } } return true } // IsProperSubsetOf 判断集合是否是另1个集合的真子集 func (s MapSet[T]) IsProperSubsetOf(o MapSet[T]) bool { if len(s) >= len(o) { return false } for key := range s { if _, exists := o[key]; !exists { return false } } return true } // Union 返回集合和另1个集合的并集 func (s MapSet[T]) Union(o MapSet[T]) MapSet[T] { r := make(MapSet[T], maxInt(len(s), len(o))) for key := range s { r[key] = struct{}{} } for key := range o { r[key] = struct{}{} } return r } // Intersection 返回集合和另1个集合的交集 func (s MapSet[T]) Intersection(o MapSet[T]) MapSet[T] { a, b := s, o if len(a) > len(b) { a, b = b, a } r := make(MapSet[T], len(a)) for key := range a { if _, exists := b[key]; exists { r[key] = struct{}{} } } return r } // Difference 返回集合和另1个集合的差集 func (s MapSet[T]) Difference(o MapSet[T]) MapSet[T] { r := make(MapSet[T]) for key := range s { if _, exists := o[key]; !exists { r[key] = struct{}{} } } return r } // SymmetricDifference 返回集合和另1个集合相互差集的并集 func (s MapSet[T]) SymmetricDifference(o MapSet[T]) MapSet[T] { r := make(MapSet[T]) for key := range s { if _, exists := o[key]; !exists { r[key] = struct{}{} } } for key := range o { if _, exists := s[key]; !exists { r[key] = struct{}{} } } return r } // ToSlice 把集合转换成切片 func (s MapSet[T]) ToSlice() []T { keys := make([]T, 0, len(s)) for key := range s { keys = append(keys, key) } return keys } // ToAnySlice 把集合转换成 []any func (s MapSet[T]) ToAnySlice() []any { keys := make([]any, 0, len(s)) for key := range s { keys = append(keys, key) } return keys } // MarshalJSON 实现 [encoding/json.Marshaler] func (s MapSet[T]) MarshalJSON() ([]byte, error) { return json.Marshal(s.ToSlice()) } // UnmarshalJSON 实现 [encoding/json.Unmarshaler] func (s *MapSet[T]) UnmarshalJSON(b []byte) error { var keys []T if *s == nil { *s = New[T]() } if err := json.Unmarshal(b, &keys); err != nil { return err } s.AddMany(keys...) return nil } func (s MapSet[T]) String() string { size := s.Cardinality() if size > 64 { size = 64 } keys := make([]string, 0, size) for key := range s { switch any(key).(type) { case string: keys = append(keys, fmt.Sprintf("%q", any(key))) default: keys = append(keys, fmt.Sprintf("%v", any(key))) } size-- if size == 0 { break } } return "MapSet(" + strconv.Itoa(s.Cardinality()) + "){" + strings.Join(keys, ",") + "}" } func maxInt(a, b int) int { if a > b { return a } return b }