adapter.go 7.68 KB
Newer Older
1 2 3
package pgadapter

import (
khoipham's avatar
khoipham committed
4
	"fmt"
5
	"strings"
khoipham's avatar
khoipham committed
6 7 8 9 10

	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	"github.com/go-pg/pg/v9"
	"github.com/go-pg/pg/v9/orm"
11
	"github.com/mmcloughlin/meow"
12 13 14 15 16 17 18 19
)

// tableExistsErrorCode is the prefix expected on the text of the error
// returned when the rule table already exists (PostgreSQL SQLSTATE 42P07,
// duplicate_table).
const tableExistsErrorCode = "ERROR #42P07"

// CasbinRule represents a rule in Casbin.
//
// A row stores the policy type (PType, e.g. "p" or "g") and up to six rule
// values in V0..V5. ID is filled in by savePolicyLine via policyID, so the
// same ptype+values always produce the same ID.
type CasbinRule struct {
	ID    string
	PType string // stored in the p_type column
	V0    string
	V1    string
	V2    string
	V3    string
	V4    string
	V5    string
}

// Filter selects the rules loaded by LoadFilteredPolicy.
//
// P and G are matched positionally against the v0..v5 columns of rows whose
// p_type is "p" and "g" respectively; an empty string matches any value.
// A nil slice skips that rule type entirely, whereas a non-nil empty slice
// loads every rule of that type. A non-empty value beyond the sixth position
// makes the load fail.
type Filter struct {
	P []string
	G []string
}

35 36
// Adapter represents the github.com/go-pg/pg adapter for policy storage.
type Adapter struct {
37 38
	db       *pg.DB
	filtered bool
39 40 41
}

// NewAdapter is the constructor for Adapter.
khoipham's avatar
khoipham committed
42 43 44 45
// arg should be a PostgreS URL string or of type *pg.Options
// The adapter will create a DB named "casbin" if it doesn't exist
func NewAdapter(arg interface{}) (*Adapter, error) {
	db, err := createCasbinDatabase(arg)
46
	if err != nil {
khoipham's avatar
khoipham committed
47
		return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
48 49
	}

khoipham's avatar
khoipham committed
50
	a := &Adapter{db: db}
51

khoipham's avatar
khoipham committed
52 53
	if err := a.createTable(); err != nil {
		return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
54 55
	}

khoipham's avatar
khoipham committed
56
	return a, nil
57 58
}

59 60 61 62 63 64 65 66 67 68
// NewAdapterByDB creates new Adapter by using existing DB connection
// creates table from CasbinRule struct if it doesn't exist
func NewAdapterByDB(db *pg.DB) (*Adapter, error) {
	a := &Adapter{db: db}
	if err := a.createTable(); err != nil {
		return nil, fmt.Errorf("pgadapter.NewAdapter: %v", err)
	}
	return a, nil
}

khoipham's avatar
khoipham committed
69 70
func createCasbinDatabase(arg interface{}) (*pg.DB, error) {
	var opts *pg.Options
71
	var err error
khoipham's avatar
khoipham committed
72 73 74 75 76 77 78 79 80 81
	if connURL, ok := arg.(string); ok {
		opts, err = pg.ParseURL(connURL)
		if err != nil {
			return nil, err
		}
	} else {
		opts, ok = arg.(*pg.Options)
		if !ok {
			return nil, fmt.Errorf("must pass in a PostgreS URL string or an instance of *pg.Options, received %T instead", arg)
		}
82 83
	}

khoipham's avatar
khoipham committed
84 85
	db := pg.Connect(opts)
	defer db.Close()
86

khoipham's avatar
khoipham committed
87 88 89 90 91
	_, err = db.Exec("CREATE DATABASE casbin")
	db.Close()

	opts.Database = "casbin"
	db = pg.Connect(opts)
92

khoipham's avatar
khoipham committed
93
	return db, nil
94 95
}

khoipham's avatar
khoipham committed
96 97 98 99
// Close close database connection
func (a *Adapter) Close() error {
	if a != nil && a.db != nil {
		return a.db.Close()
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
	}
	return nil
}

// createTable creates the table backing CasbinRule. A "table already exists"
// error is treated as success; any other error is returned.
func (a *Adapter) createTable() error {
	err := a.db.CreateTable(&CasbinRule{}, &orm.CreateTableOptions{
		Temp: false,
	})
	if err == nil {
		return nil
	}
	// Prefer the SQLSTATE code (42P07 = duplicate_table). The severity word in
	// the formatted message may be localized by the server, so the text is
	// only a fallback.
	if pgErr, ok := err.(pg.Error); ok && pgErr.Field('C') == "42P07" {
		return nil
	}
	// HasPrefix instead of slicing err.Error()[0:12]: a short message (e.g. an
	// I/O error such as "EOF") would make the slice panic.
	if strings.HasPrefix(err.Error(), tableExistsErrorCode) {
		return nil
	}
	return err
}

117
func (r *CasbinRule) String() string {
118 119 120
	const prefixLine = ", "
	var sb strings.Builder

121 122
	sb.WriteString(r.PType)
	if len(r.V0) > 0 {
123
		sb.WriteString(prefixLine)
124
		sb.WriteString(r.V0)
125
	}
126
	if len(r.V1) > 0 {
127
		sb.WriteString(prefixLine)
128
		sb.WriteString(r.V1)
129
	}
130
	if len(r.V2) > 0 {
131
		sb.WriteString(prefixLine)
132
		sb.WriteString(r.V2)
133
	}
134
	if len(r.V3) > 0 {
135
		sb.WriteString(prefixLine)
136
		sb.WriteString(r.V3)
137
	}
138
	if len(r.V4) > 0 {
139
		sb.WriteString(prefixLine)
140
		sb.WriteString(r.V4)
141
	}
142
	if len(r.V5) > 0 {
143
		sb.WriteString(prefixLine)
144
		sb.WriteString(r.V5)
145 146
	}

147
	return sb.String()
148 149 150 151 152 153 154 155 156 157 158
}

// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
	var lines []*CasbinRule

	if _, err := a.db.Query(&lines, `SELECT * FROM casbin_rules`); err != nil {
		return err
	}

	for _, line := range lines {
159
		persist.LoadPolicyLine(line.String(), model)
160 161
	}

162 163
	a.filtered = false

164 165 166
	return nil
}

khoipham's avatar
khoipham committed
167 168
func policyID(ptype string, rule []string) string {
	data := strings.Join(append([]string{ptype}, rule...), ",")
169
	sum := meow.Checksum(0, []byte(data))
khoipham's avatar
khoipham committed
170 171 172
	return fmt.Sprintf("%x", sum)
}

173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
func savePolicyLine(ptype string, rule []string) *CasbinRule {
	line := &CasbinRule{PType: ptype}

	l := len(rule)
	if l > 0 {
		line.V0 = rule[0]
	}
	if l > 1 {
		line.V1 = rule[1]
	}
	if l > 2 {
		line.V2 = rule[2]
	}
	if l > 3 {
		line.V3 = rule[3]
	}
	if l > 4 {
		line.V4 = rule[4]
	}
	if l > 5 {
		line.V5 = rule[5]
	}

khoipham's avatar
khoipham committed
196
	line.ID = policyID(ptype, rule)
khoipham's avatar
khoipham committed
197

198 199 200 201 202
	return line
}

// SavePolicy saves policy to database.
func (a *Adapter) SavePolicy(model model.Model) error {
203 204 205 206 207
	_, err := a.db.Model((*CasbinRule)(nil)).Where("id IS NOT NULL").Delete()
	if err != nil {
		return err
	}

208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
	var lines []*CasbinRule

	for ptype, ast := range model["p"] {
		for _, rule := range ast.Policy {
			line := savePolicyLine(ptype, rule)
			lines = append(lines, line)
		}
	}

	for ptype, ast := range model["g"] {
		for _, rule := range ast.Policy {
			line := savePolicyLine(ptype, rule)
			lines = append(lines, line)
		}
	}

224 225 226 227 228 229 230
	if len(lines) > 0 {
		_, err = a.db.Model(&lines).
			OnConflict("DO NOTHING").
			Insert()
		return err
	}
	return nil
231 232 233 234 235
}

// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
	line := savePolicyLine(ptype, rule)
khoipham's avatar
khoipham committed
236 237 238
	_, err := a.db.Model(line).
		OnConflict("DO NOTHING").
		Insert()
239 240 241 242 243 244 245 246 247 248 249 250
	return err
}

// RemovePolicy removes a policy rule from the storage.
//
// The rule is rebuilt with savePolicyLine, which also fills in its ID, and
// handed to go-pg's DB.Delete. sec is not used.
func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
	return a.db.Delete(savePolicyLine(ptype, rule))
}

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
khoipham's avatar
khoipham committed
251
	query := a.db.Model((*CasbinRule)(nil)).Where("p_type = ?", ptype)
252 253

	idx := fieldIndex + len(fieldValues)
khoipham's avatar
khoipham committed
254 255
	if fieldIndex <= 0 && idx > 0 && fieldValues[0-fieldIndex] != "" {
		query = query.Where("v0 = ?", fieldValues[0-fieldIndex])
256
	}
khoipham's avatar
khoipham committed
257 258
	if fieldIndex <= 1 && idx > 1 && fieldValues[1-fieldIndex] != "" {
		query = query.Where("v1 = ?", fieldValues[1-fieldIndex])
259
	}
khoipham's avatar
khoipham committed
260 261
	if fieldIndex <= 2 && idx > 2 && fieldValues[2-fieldIndex] != "" {
		query = query.Where("v2 = ?", fieldValues[2-fieldIndex])
262
	}
khoipham's avatar
khoipham committed
263 264
	if fieldIndex <= 3 && idx > 3 && fieldValues[3-fieldIndex] != "" {
		query = query.Where("v3 = ?", fieldValues[3-fieldIndex])
265
	}
khoipham's avatar
khoipham committed
266 267
	if fieldIndex <= 4 && idx > 4 && fieldValues[4-fieldIndex] != "" {
		query = query.Where("v4 = ?", fieldValues[4-fieldIndex])
268
	}
khoipham's avatar
khoipham committed
269 270
	if fieldIndex <= 5 && idx > 5 && fieldValues[5-fieldIndex] != "" {
		query = query.Where("v5 = ?", fieldValues[5-fieldIndex])
271 272
	}

khoipham's avatar
khoipham committed
273
	_, err := query.Delete()
274 275
	return err
}
276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359

func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
	if filter == nil {
		return a.LoadPolicy(model)
	}

	filterValue, ok := filter.(*Filter)
	if !ok {
		return fmt.Errorf("invalid filter type")
	}
	err := a.loadFilteredPolicy(model, filterValue, persist.LoadPolicyLine)
	if err != nil {
		return err
	}
	a.filtered = true
	return nil
}

// buildQuery adds one equality condition per non-empty entry of values, with
// values[i] matched against column v<i>. Empty entries are wildcards. A
// non-empty entry at index 6 or later yields a nil query and an error.
func buildQuery(query *orm.Query, values []string) (*orm.Query, error) {
	columns := [...]string{"v0", "v1", "v2", "v3", "v4", "v5"}
	for i, value := range values {
		if value == "" {
			continue
		}
		if i >= len(columns) {
			return nil, fmt.Errorf("filter has more values than expected, should not exceed 6 values")
		}
		query = query.Where(columns[i]+" = ?", value)
	}
	return query, nil
}

// loadFilteredPolicy runs one query for filter.P (rows with p_type 'p') and
// then one for filter.G (rows with p_type 'g'), skipping any nil slice. Each
// matching row, rendered by CasbinRule.String, is passed to handler. It
// returns the first query-building or query error.
func (a *Adapter) loadFilteredPolicy(model model.Model, filter *Filter, handler func(string, model.Model)) error {
	steps := []struct {
		where  string
		values *[]string
	}{
		{"p_type = 'p'", &filter.P},
		{"p_type = 'g'", &filter.G},
	}

	for _, step := range steps {
		values := *step.values
		if values == nil {
			continue
		}

		rows := []*CasbinRule{}
		query, err := buildQuery(a.db.Model(&rows).Where(step.where), values)
		if err != nil {
			return err
		}
		if err = query.Select(); err != nil {
			return err
		}

		for _, row := range rows {
			handler(row.String(), model)
		}
	}
	return nil
}

// IsFiltered reports whether the currently loaded policy came from a filtered
// load. LoadFilteredPolicy sets the flag after a successful filtered load,
// and LoadPolicy clears it.
func (a *Adapter) IsFiltered() bool {
	return a.filtered
}