Compare commits

...

16 Commits

Author SHA1 Message Date
Andreas Auernhammer 404d2ebe3f
set SSE headers in put-part response (#12008)
This commit fixes a bug in the put-part
implementation. The SSE headers should be
set as specified by AWS - See:
https://docs.aws.amazon.com/AmazonS3/latest/API/API_UploadPart.html

Now, the MinIO server should set SSE-C headers,
like `x-amz-server-side-encryption-customer-algorithm`.

Fixes #11991
2021-04-07 14:50:28 -07:00
Minio Trusted 46964eb764 Update yaml files to latest version RELEASE.2021-04-06T23-11-00Z 2021-04-06 23:35:33 +00:00
Poorna Krishnamoorthy bfab990c33 Improve error message from SetRemoteTargetHandler (#11909) 2021-04-06 12:42:30 -07:00
Harshavardhana 94018588fe unmarshal both LegalHold and ObjectLockLegalHold XML types (#11921)
Because of silly AWS S3 behavior we to handle both types.

fixes #11920
2021-04-06 12:41:56 -07:00
Anis Elleuch 8b76ba8d5d crawling: Apply lifecycle then decide healing action (#11563)
It is inefficient to decide to heal an object before checking its
lifecycle for expiration or transition. This commit will just reverse
the order of action: evaluate lifecycle and heal only if asked and
lifecycle resulted a NoneAction.
2021-04-06 12:41:51 -07:00
Harshavardhana 7eb7f65e48 add policy conditions support for signatureVersion and authType (#11947)
https://docs.aws.amazon.com/AmazonS3/latest/API/bucket-policy-s3-sigv4-conditions.html

fixes #11944
2021-04-06 12:41:31 -07:00
Harshavardhana c608c0688a fix: properly close leaking bandwidth monitor channel (#11967)
This PR fixes

- close leaking bandwidth report channel leakage
- remove the closer requirement for bandwidth monitor
  instead if Read() fails remember the error and return
  error for all subsequent reads.
- use locking for usage-cache.bin updates, with inline
  data we cannot afford to have concurrent writes to
  usage-cache.bin corrupting xl.meta
2021-04-06 12:40:42 -07:00
Aditya Manthramurthy 41a9d1d778 Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
2021-04-06 12:40:28 -07:00
Klaus Post e21e80841e Fix data race when connecting disks (#11983)
Multiple disks from the same set would be writing concurrently.

```
WARNING: DATA RACE
Write at 0x00c002100ce0 by goroutine 166:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks.func1()
      d:/minio/minio/cmd/erasure-sets.go:254 +0x82f

Previous write at 0x00c002100ce0 by goroutine 129:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks.func1()
      d:/minio/minio/cmd/erasure-sets.go:254 +0x82f

Goroutine 166 (running) created at:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks()
      d:/minio/minio/cmd/erasure-sets.go:210 +0x324
  github.com/minio/minio/cmd.(*erasureSets).monitorAndConnectEndpoints()
      d:/minio/minio/cmd/erasure-sets.go:288 +0x244

Goroutine 129 (finished) created at:
  github.com/minio/minio/cmd.(*erasureSets).connectDisks()
      d:/minio/minio/cmd/erasure-sets.go:210 +0x324
  github.com/minio/minio/cmd.(*erasureSets).monitorAndConnectEndpoints()
      d:/minio/minio/cmd/erasure-sets.go:288 +0x244
```
2021-04-06 12:39:59 -07:00
Klaus Post 98c792bbeb Fix disk info race (#11984)
Protect updated members in xlStorage.

```
WARNING: DATA RACE
Write at 0x00c004b4ee78 by goroutine 1491:
  github.com/minio/minio/cmd.(*xlStorage).GetDiskID()
      d:/minio/minio/cmd/xl-storage.go:590 +0x1078
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).checkDiskStale()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:195 +0x84
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).StatVol()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:284 +0x16a
  github.com/minio/minio/cmd.erasureObjects.getBucketInfo.func1()
      d:/minio/minio/cmd/erasure-bucket.go:100 +0x1a5
  github.com/minio/minio/pkg/sync/errgroup.(*Group).Go.func1()
      d:/minio/minio/pkg/sync/errgroup/errgroup.go:122 +0xd7

Previous read at 0x00c004b4ee78 by goroutine 1087:
  github.com/minio/minio/cmd.(*xlStorage).CheckFile.func1()
      d:/minio/minio/cmd/xl-storage.go:1699 +0x384
  github.com/minio/minio/cmd.(*xlStorage).CheckFile()
      d:/minio/minio/cmd/xl-storage.go:1726 +0x13c
  github.com/minio/minio/cmd.(*xlStorageDiskIDCheck).CheckFile()
      d:/minio/minio/cmd/xl-storage-disk-id-check.go:446 +0x23b
  github.com/minio/minio/cmd.erasureObjects.parentDirIsObject.func1()
      d:/minio/minio/cmd/erasure-common.go:173 +0x194
  github.com/minio/minio/pkg/sync/errgroup.(*Group).Go.func1()
      d:/minio/minio/pkg/sync/errgroup/errgroup.go:122 +0xd7
```
2021-04-06 12:39:57 -07:00
Klaus Post f687ba53bc Fix Access Key requests (#11979)
Fix accessing claims when auth error is unchecked.

Only replaced when unchecked and when clearly without side effects.

Fixes #11959
2021-04-06 11:03:55 -07:00
Harshavardhana e3da59c923 fix possible crash in bucket bandwidth monitor (#11986) 2021-04-06 11:03:41 -07:00
Harshavardhana 781b9b051c fix: service accounts policy enforcement regression (#11910)
service accounts were not inheriting parent policies
anymore due to refactors in the PolicyDBGet() from
the latest release, fix this behavior properly.
2021-04-06 08:58:05 -07:00
Harshavardhana 438becfde8 fix: delete/delete marker replication versions consistent (#11932)
replication didn't work as expected when deletion of
delete markers was requested in DeleteMultipleObjects
API, this is due to incorrect lookup elements being
used to look for delete markers.
2021-04-06 08:57:36 -07:00
Harshavardhana 16ef338649 fix: notify parent user in notification events (#11934)
fixes #11885
2021-04-06 08:55:37 -07:00
Harshavardhana 3242847ec0 avoid network read errors crashing CreateFile call (#11939)
Thanks to @dvaldivia for reproducing this
2021-04-06 08:55:30 -07:00
37 changed files with 602 additions and 274 deletions

View File

@ -172,7 +172,12 @@ func (a adminAPIHandlers) SetRemoteTargetHandler(w http.ResponseWriter, r *http.
}
if err = globalBucketTargetSys.SetTarget(ctx, bucket, &target, update); err != nil {
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL)
switch err.(type) {
case BucketRemoteConnectionErr:
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErrWithErr(ErrReplicationRemoteConnectionError, err), r.URL)
default:
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL)
}
return
}
targets, err := globalBucketTargetSys.ListBucketTargets(ctx, bucket)

View File

@ -24,6 +24,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"os"
@ -1470,30 +1471,33 @@ func (a adminAPIHandlers) BandwidthMonitorHandler(w http.ResponseWriter, r *http
return
}
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
setEventStreamHeaders(w)
reportCh := make(chan bandwidth.Report, 1)
reportCh := make(chan bandwidth.Report)
keepAliveTicker := time.NewTicker(500 * time.Millisecond)
defer keepAliveTicker.Stop()
bucketsRequestedString := r.URL.Query().Get("buckets")
bucketsRequested := strings.Split(bucketsRequestedString, ",")
go func() {
defer close(reportCh)
for {
reportCh <- globalNotificationSys.GetBandwidthReports(ctx, bucketsRequested...)
select {
case <-ctx.Done():
return
default:
time.Sleep(2 * time.Second)
case reportCh <- globalNotificationSys.GetBandwidthReports(ctx, bucketsRequested...):
time.Sleep(time.Duration(rnd.Float64() * float64(2*time.Second)))
}
}
}()
for {
select {
case report := <-reportCh:
enc := json.NewEncoder(w)
err := enc.Encode(report)
if err != nil {
writeErrorResponseJSON(ctx, w, errorCodes.ToAPIErr(ErrInternalError), r.URL)
case report, ok := <-reportCh:
if !ok {
return
}
if err := json.NewEncoder(w).Encode(report); err != nil {
writeErrorResponseJSON(ctx, w, toAPIError(ctx, err), r.URL)
return
}
w.(http.Flusher).Flush()

View File

@ -496,7 +496,7 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
object.PurgeTransitioned = goi.TransitionStatus
}
if replicateDeletes {
delMarker, replicate, repsync := checkReplicateDelete(ctx, bucket, ObjectToDelete{
replicate, repsync := checkReplicateDelete(ctx, bucket, ObjectToDelete{
ObjectName: object.ObjectName,
VersionID: object.VersionID,
}, goi, gerr)
@ -511,9 +511,6 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
}
if object.VersionID != "" {
object.VersionPurgeStatus = Pending
if delMarker {
object.DeleteMarkerVersionID = object.VersionID
}
} else {
object.DeleteMarkerReplicationStatus = string(replication.Pending)
}
@ -557,13 +554,18 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
})
deletedObjects := make([]DeletedObject, len(deleteObjects.Objects))
for i := range errs {
dindex := objectsToDelete[ObjectToDelete{
// DeleteMarkerVersionID is not used specifically to avoid
// lookup errors, since DeleteMarkerVersionID is only
// created during DeleteMarker creation when client didn't
// specify a versionID.
objToDel := ObjectToDelete{
ObjectName: dObjects[i].ObjectName,
VersionID: dObjects[i].VersionID,
VersionPurgeStatus: dObjects[i].VersionPurgeStatus,
DeleteMarkerReplicationStatus: dObjects[i].DeleteMarkerReplicationStatus,
PurgeTransitioned: dObjects[i].PurgeTransitioned,
}]
}
dindex := objectsToDelete[objToDel]
if errs[i] == nil || isErrObjectNotFound(errs[i]) || isErrVersionNotFound(errs[i]) {
if replicateDeletes {
dObjects[i].DeleteMarkerReplicationStatus = deleteList[i].DeleteMarkerReplicationStatus
@ -619,12 +621,12 @@ func (api objectAPIHandlers) DeleteMultipleObjectsHandler(w http.ResponseWriter,
eventName := event.ObjectRemovedDelete
objInfo := ObjectInfo{
Name: dobj.ObjectName,
VersionID: dobj.VersionID,
Name: dobj.ObjectName,
VersionID: dobj.VersionID,
DeleteMarker: dobj.DeleteMarker,
}
if dobj.DeleteMarker {
objInfo.DeleteMarker = dobj.DeleteMarker
if objInfo.DeleteMarker {
objInfo.VersionID = dobj.DeleteMarkerVersionID
eventName = event.ObjectRemovedDeleteMarkerCreated
}

View File

@ -83,17 +83,38 @@ func getConditionValues(r *http.Request, lc string, username string, claims map[
}
}
authType := getRequestAuthType(r)
var signatureVersion string
switch authType {
case authTypeSignedV2, authTypePresignedV2:
signatureVersion = signV2Algorithm
case authTypeSigned, authTypePresigned, authTypeStreamingSigned, authTypePostPolicy:
signatureVersion = signV4Algorithm
}
var authtype string
switch authType {
case authTypePresignedV2, authTypePresigned:
authtype = "REST-QUERY-STRING"
case authTypeSignedV2, authTypeSigned, authTypeStreamingSigned:
authtype = "REST-HEADER"
case authTypePostPolicy:
authtype = "POST"
}
args := map[string][]string{
"CurrentTime": {currTime.Format(time.RFC3339)},
"EpochTime": {strconv.FormatInt(currTime.Unix(), 10)},
"SecureTransport": {strconv.FormatBool(r.TLS != nil)},
"SourceIp": {handlers.GetSourceIP(r)},
"UserAgent": {r.UserAgent()},
"Referer": {r.Referer()},
"principaltype": {principalType},
"userid": {username},
"username": {username},
"versionid": {vid},
"CurrentTime": {currTime.Format(time.RFC3339)},
"EpochTime": {strconv.FormatInt(currTime.Unix(), 10)},
"SecureTransport": {strconv.FormatBool(r.TLS != nil)},
"SourceIp": {handlers.GetSourceIP(r)},
"UserAgent": {r.UserAgent()},
"Referer": {r.Referer()},
"principaltype": {principalType},
"userid": {username},
"username": {username},
"versionid": {vid},
"signatureversion": {signatureVersion},
"authType": {authtype},
}
if lc != "" {

View File

@ -175,10 +175,10 @@ func isStandardHeader(matchHeaderKey string) bool {
}
// returns whether object version is a deletemarker and if object qualifies for replication
func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelete, oi ObjectInfo, gerr error) (dm, replicate, sync bool) {
func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelete, oi ObjectInfo, gerr error) (replicate, sync bool) {
rcfg, err := getReplicationConfig(ctx, bucket)
if err != nil || rcfg == nil {
return false, false, sync
return false, sync
}
opts := replication.ObjectOpts{
Name: dobj.ObjectName,
@ -198,19 +198,19 @@ func checkReplicateDelete(ctx context.Context, bucket string, dobj ObjectToDelet
validReplStatus = true
}
if oi.DeleteMarker && (validReplStatus || replicate) {
return oi.DeleteMarker, true, sync
return true, sync
}
// can be the case that other cluster is down and duplicate `mc rm --vid`
// is issued - this still needs to be replicated back to the other target
return oi.DeleteMarker, oi.VersionPurgeStatus == Pending || oi.VersionPurgeStatus == Failed, sync
return oi.VersionPurgeStatus == Pending || oi.VersionPurgeStatus == Failed, sync
}
tgt := globalBucketTargetSys.GetRemoteTargetClient(ctx, rcfg.RoleArn)
// the target online status should not be used here while deciding
// whether to replicate deletes as the target could be temporarily down
if tgt == nil {
return oi.DeleteMarker, false, false
return false, false
}
return oi.DeleteMarker, replicate, tgt.replicateSync
return replicate, tgt.replicateSync
}
// replicate deletes to the designated replication target if replication configuration
@ -697,19 +697,25 @@ func replicateObject(ctx context.Context, objInfo ObjectInfo, objectAPI ObjectLa
if totalNodesCount == 0 {
totalNodesCount = 1 // For standalone erasure coding
}
b := target.BandwidthLimit / int64(totalNodesCount)
var headerSize int
for k, v := range putOpts.Header() {
headerSize += len(k) + len(v)
}
// r takes over closing gr.
r := bandwidth.NewMonitoredReader(ctx, globalBucketMonitor, objInfo.Bucket, objInfo.Name, gr, headerSize, b, target.BandwidthLimit)
opts := &bandwidth.MonitorReaderOptions{
Bucket: objInfo.Bucket,
Object: objInfo.Name,
HeaderSize: headerSize,
BandwidthBytesPerSec: target.BandwidthLimit / int64(totalNodesCount),
ClusterBandwidth: target.BandwidthLimit,
}
r := bandwidth.NewMonitoredReader(ctx, globalBucketMonitor, gr, opts)
if _, err = c.PutObject(ctx, dest.Bucket, object, r, size, "", "", putOpts); err != nil {
replicationStatus = replication.Failed
logger.LogIf(ctx, fmt.Errorf("Unable to replicate for object %s/%s(%s): %w", bucket, objInfo.Name, objInfo.VersionID, err))
}
defer r.Close()
}
objInfo.UserDefined[xhttp.AmzBucketReplicationStatus] = replicationStatus.String()

View File

@ -100,7 +100,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
if minio.ToErrorResponse(err).Code == "NoSuchBucket" {
return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket}
}
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket}
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
}
if tgt.Type == madmin.ReplicationService {
if !globalIsErasure {
@ -111,7 +111,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
}
vcfg, err := clnt.GetBucketVersioning(ctx, tgt.TargetBucket)
if err != nil {
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket}
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
}
if vcfg.Status != string(versioning.Enabled) {
return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket}
@ -124,7 +124,7 @@ func (sys *BucketTargetSys) SetTarget(ctx context.Context, bucket string, tgt *m
if minio.ToErrorResponse(err).Code == "NoSuchBucket" {
return BucketRemoteTargetNotFound{Bucket: tgt.TargetBucket}
}
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket}
return BucketRemoteConnectionErr{Bucket: tgt.TargetBucket, Err: err}
}
if vcfg.Status != string(versioning.Enabled) {
return BucketRemoteTargetNotVersioned{Bucket: tgt.TargetBucket}

View File

@ -797,42 +797,39 @@ type actionMeta struct {
var applyActionsLogPrefix = color.Green("applyActions:")
// applyActions will apply lifecycle checks on to a scanned item.
// The resulting size on disk will always be returned.
// The metadata will be compared to consensus on the object layer before any changes are applied.
// If no metadata is supplied, -1 is returned if no action is taken.
func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta actionMeta) (size int64) {
func (i *scannerItem) applyHealing(ctx context.Context, o ObjectLayer, meta actionMeta) (size int64) {
if i.debug {
if meta.oi.VersionID != "" {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v v(%s)\n", i.bucket, i.objectPath(), meta.oi.VersionID)
} else {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v\n", i.bucket, i.objectPath())
}
}
healOpts := madmin.HealOpts{Remove: healDeleteDangling}
if meta.bitRotScan {
healOpts.ScanMode = madmin.HealDeepScan
}
res, err := o.HealObject(ctx, i.bucket, i.objectPath(), meta.oi.VersionID, healOpts)
if isErrObjectNotFound(err) || isErrVersionNotFound(err) {
return 0
}
if err != nil && !errors.Is(err, NotImplemented{}) {
logger.LogIf(ctx, err)
return 0
}
return res.ObjectSize
}
func (i *scannerItem) applyLifecycle(ctx context.Context, o ObjectLayer, meta actionMeta) (applied bool, size int64) {
size, err := meta.oi.GetActualSize()
if i.debug {
logger.LogIf(ctx, err)
}
if i.heal {
if i.debug {
if meta.oi.VersionID != "" {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v v(%s)\n", i.bucket, i.objectPath(), meta.oi.VersionID)
} else {
console.Debugf(applyActionsLogPrefix+" heal checking: %v/%v\n", i.bucket, i.objectPath())
}
}
healOpts := madmin.HealOpts{Remove: healDeleteDangling}
if meta.bitRotScan {
healOpts.ScanMode = madmin.HealDeepScan
}
res, err := o.HealObject(ctx, i.bucket, i.objectPath(), meta.oi.VersionID, healOpts)
if isErrObjectNotFound(err) || isErrVersionNotFound(err) {
return 0
}
if err != nil && !errors.Is(err, NotImplemented{}) {
logger.LogIf(ctx, err)
return 0
}
size = res.ObjectSize
}
if i.lifeCycle == nil {
if i.debug {
console.Debugf(applyActionsLogPrefix+" no lifecycle rules to apply: %q\n", i.objectPath())
}
return size
return false, size
}
versionID := meta.oi.VersionID
@ -866,7 +863,7 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if i.debug {
console.Debugf(applyActionsLogPrefix+" object not expirable: %q\n", i.objectPath())
}
return size
return false, size
}
obj, err := o.GetObjectInfo(ctx, i.bucket, i.objectPath(), ObjectOptions{
@ -878,19 +875,18 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if !obj.DeleteMarker { // if this is not a delete marker log and return
// Do nothing - heal in the future.
logger.LogIf(ctx, err)
return size
return false, size
}
case ObjectNotFound, VersionNotFound:
// object not found or version not found return 0
return 0
return false, 0
default:
// All other errors proceed.
logger.LogIf(ctx, err)
return size
return false, size
}
}
var applied bool
action = evalActionFromLifecycle(ctx, *i.lifeCycle, obj, i.debug)
if action != lifecycle.NoneAction {
applied = applyLifecycleAction(ctx, action, o, obj)
@ -899,9 +895,26 @@ func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta acti
if applied {
switch action {
case lifecycle.TransitionAction, lifecycle.TransitionVersionAction:
default: // for all lifecycle actions that remove data
return 0
return true, size
}
// For all other lifecycle actions that remove data
return true, 0
}
return false, size
}
// applyActions will apply lifecycle checks on to a scanned item.
// The resulting size on disk will always be returned.
// The metadata will be compared to consensus on the object layer before any changes are applied.
// If no metadata is supplied, -1 is returned if no action is taken.
func (i *scannerItem) applyActions(ctx context.Context, o ObjectLayer, meta actionMeta) int64 {
applied, size := i.applyLifecycle(ctx, o, meta)
// For instance, an applied lifecycle means we remove/transitioned an object
// from the current deployment, which means we don't have to call healing
// routine even if we are asked to do via heal flag.
if !applied && i.heal {
size = i.applyHealing(ctx, o, meta)
}
return size
}

View File

@ -522,7 +522,7 @@ func (d *dataUsageCache) save(ctx context.Context, store objectIO, name string)
dataUsageBucket,
name,
NewPutObjReader(r),
ObjectOptions{NoLock: true})
ObjectOptions{})
if isErrBucketNotFound(err) {
return nil
}

View File

@ -250,8 +250,8 @@ func (s *erasureSets) connectDisks() {
}
disk.SetDiskLoc(s.poolIndex, setIndex, diskIndex)
s.endpointStrings[setIndex*s.setDriveCount+diskIndex] = disk.String()
s.erasureDisksMu.Unlock()
setsJustConnected[setIndex] = true
s.erasureDisksMu.Unlock()
}(endpoint)
}

View File

@ -233,10 +233,15 @@ func extractReqParams(r *http.Request) map[string]string {
region := globalServerRegion
cred := getReqAccessCred(r, region)
principalID := cred.AccessKey
if cred.ParentUser != "" {
principalID = cred.ParentUser
}
// Success.
m := map[string]string{
"region": region,
"accessKey": cred.AccessKey,
"principalId": principalID,
"sourceIPAddress": handlers.GetSourceIP(r),
// Add more fields here.
}

View File

@ -1704,7 +1704,7 @@ func (sys *IAMSys) PolicyDBGet(name string, isGroup bool, groups ...string) ([]s
// information in IAM (i.e sys.iam*Map) - this info is stored only in the STS
// generated credentials. Thus we skip looking up group memberships, user map,
// and group map and check the appropriate policy maps directly.
func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
func (sys *IAMSys) policyDBGet(name string, isGroup bool) (policies []string, err error) {
if isGroup {
if sys.usersSysType == MinIOUsersSysType {
g, ok := sys.iamGroupsMap[name]
@ -1719,8 +1719,7 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
}
}
mp := sys.iamGroupPolicyMap[name]
return mp.toSlice(), nil
return sys.iamGroupPolicyMap[name].toSlice(), nil
}
var u auth.Credentials
@ -1738,8 +1737,6 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
}
}
var policies []string
mp, ok := sys.iamUserPolicyMap[name]
if !ok {
if u.ParentUser != "" {
@ -1757,8 +1754,7 @@ func (sys *IAMSys) policyDBGet(name string, isGroup bool) ([]string, error) {
continue
}
p := sys.iamGroupPolicyMap[group]
policies = append(policies, p.toSlice()...)
policies = append(policies, sys.iamGroupPolicyMap[group].toSlice()...)
}
return policies, nil
@ -1788,8 +1784,9 @@ func (sys *IAMSys) IsAllowedServiceAccount(args iampolicy.Args, parent string) b
}
// Check policy for this service account.
svcPolicies, err := sys.PolicyDBGet(args.AccountName, false)
svcPolicies, err := sys.PolicyDBGet(parent, false, args.Groups...)
if err != nil {
logger.LogIf(GlobalContext, err)
return false
}
@ -2072,7 +2069,7 @@ func (sys *IAMSys) IsAllowed(args iampolicy.Args) bool {
}
// Continue with the assumption of a regular user
policies, err := sys.PolicyDBGet(args.AccountName, false)
policies, err := sys.PolicyDBGet(args.AccountName, false, args.Groups...)
if err != nil {
return false
}

View File

@ -81,6 +81,15 @@ type MapClaims struct {
jwtgo.MapClaims
}
// GetAccessKey will return the access key.
// If nil an empty string will be returned.
func (c *MapClaims) GetAccessKey() string {
if c == nil {
return ""
}
return c.AccessKey
}
// NewStandardClaims - initializes standard claims
func NewStandardClaims() *StandardClaims {
return &StandardClaims{}

View File

@ -1368,7 +1368,7 @@ func (args eventArgs) ToEvent(escape bool) event.Event {
AwsRegion: args.ReqParams["region"],
EventTime: eventTime.Format(event.AMZTimeFormat),
EventName: args.EventName,
UserIdentity: event.Identity{PrincipalID: args.ReqParams["accessKey"]},
UserIdentity: event.Identity{PrincipalID: args.ReqParams["principalId"]},
RequestParameters: args.ReqParams,
ResponseElements: respElements,
S3: event.Metadata{
@ -1376,7 +1376,7 @@ func (args eventArgs) ToEvent(escape bool) event.Event {
ConfigurationID: "Config",
Bucket: event.Bucket{
Name: args.BucketName,
OwnerIdentity: event.Identity{PrincipalID: args.ReqParams["accessKey"]},
OwnerIdentity: event.Identity{PrincipalID: args.ReqParams["principalId"]},
ARN: policy.ResourceARNPrefix + args.BucketName,
},
Object: event.Object{

View File

@ -426,7 +426,7 @@ func (e BucketRemoteTargetNotFound) Error() string {
type BucketRemoteConnectionErr GenericError
func (e BucketRemoteConnectionErr) Error() string {
return "Remote service endpoint or target bucket not available: " + e.Bucket
return fmt.Sprintf("Remote service endpoint or target bucket not available: %s \n\t%s", e.Bucket, e.Err.Error())
}
// BucketRemoteAlreadyExists remote already exists for this target type.

View File

@ -2371,8 +2371,20 @@ func (api objectAPIHandlers) PutObjectPartHandler(w http.ResponseWriter, r *http
}
etag := partInfo.ETag
if isEncrypted {
etag = tryDecryptETag(objectEncryptionKey[:], partInfo.ETag, crypto.SSEC.IsRequested(r.Header))
switch kind, encrypted := crypto.IsEncrypted(mi.UserDefined); {
case encrypted:
switch kind {
case crypto.S3:
w.Header().Set(xhttp.AmzServerSideEncryption, xhttp.AmzEncryptionAES)
etag = tryDecryptETag(objectEncryptionKey[:], etag, false)
case crypto.SSEC:
w.Header().Set(xhttp.AmzServerSideEncryptionCustomerAlgorithm, r.Header.Get(xhttp.AmzServerSideEncryptionCustomerAlgorithm))
w.Header().Set(xhttp.AmzServerSideEncryptionCustomerKeyMD5, r.Header.Get(xhttp.AmzServerSideEncryptionCustomerKeyMD5))
if len(etag) >= 32 && strings.Count(etag, "-") != 1 {
etag = etag[len(etag)-32:]
}
}
}
// We must not use the http.Header().Set method here because some (broken)
@ -2817,7 +2829,8 @@ func (api objectAPIHandlers) DeleteObjectHandler(w http.ResponseWriter, r *http.
VersionID: opts.VersionID,
})
}
_, replicateDel, replicateSync := checkReplicateDelete(ctx, bucket, ObjectToDelete{ObjectName: object, VersionID: opts.VersionID}, goi, gerr)
replicateDel, replicateSync := checkReplicateDelete(ctx, bucket, ObjectToDelete{ObjectName: object, VersionID: opts.VersionID}, goi, gerr)
if replicateDel {
if opts.VersionID != "" {
opts.VersionPurgeStatus = Pending
@ -2825,6 +2838,7 @@ func (api objectAPIHandlers) DeleteObjectHandler(w http.ResponseWriter, r *http.
opts.DeleteMarkerReplicationStatus = string(replication.Pending)
}
}
vID := opts.VersionID
if r.Header.Get(xhttp.AmzBucketReplicationStatus) == replication.Replica.String() {
// check if replica has permission to be deleted.

View File

@ -340,9 +340,8 @@ func (client *storageRESTClient) CreateFile(ctx context.Context, volume, path st
if err != nil {
return err
}
waitReader, err := waitForHTTPResponse(respBody)
defer http.DrainBody(ioutil.NopCloser(waitReader))
defer respBody.Close()
_, err = waitForHTTPResponse(respBody)
defer http.DrainBody(respBody)
return err
}

View File

@ -226,7 +226,7 @@ func (web *webAPIHandlers) MakeBucket(r *http.Request, args *MakeBucketArgs, rep
reply.UIVersion = Version
reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey
reqParams["accessKey"] = claims.GetAccessKey()
sendEvent(eventArgs{
EventName: event.BucketCreated,
@ -723,7 +723,7 @@ func (web *webAPIHandlers) RemoveObject(r *http.Request, args *RemoveObjectArgs,
)
reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey
reqParams["accessKey"] = claims.GetAccessKey()
sourceIP := handlers.GetSourceIP(r)
next:
@ -767,7 +767,7 @@ next:
}
if hasReplicationRules(ctx, args.BucketName, []ObjectToDelete{{ObjectName: objectName}}) || hasLifecycleConfig {
goi, gerr = getObjectInfoFn(ctx, args.BucketName, objectName, opts)
if _, replicateDel, replicateSync = checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{
if replicateDel, replicateSync = checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{
ObjectName: objectName,
VersionID: goi.VersionID,
}, goi, gerr); replicateDel {
@ -903,7 +903,7 @@ next:
}
}
}
_, replicateDel, _ := checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{ObjectName: obj.Name, VersionID: obj.VersionID}, obj, nil)
replicateDel, _ := checkReplicateDelete(ctx, args.BucketName, ObjectToDelete{ObjectName: obj.Name, VersionID: obj.VersionID}, obj, nil)
// since versioned delete is not available on web browser, yet - this is a simple DeleteMarker replication
objToDel := ObjectToDelete{ObjectName: obj.Name}
if replicateDel {
@ -1340,7 +1340,7 @@ func (web *webAPIHandlers) Upload(w http.ResponseWriter, r *http.Request) {
}
reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey
reqParams["accessKey"] = claims.GetAccessKey()
// Notify object created event.
sendEvent(eventArgs{
@ -1529,7 +1529,7 @@ func (web *webAPIHandlers) Download(w http.ResponseWriter, r *http.Request) {
}
reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey
reqParams["accessKey"] = claims.GetAccessKey()
// Notify object accessed via a GET request.
sendEvent(eventArgs{
@ -1684,7 +1684,7 @@ func (web *webAPIHandlers) DownloadZip(w http.ResponseWriter, r *http.Request) {
defer archive.Close()
reqParams := extractReqParams(r)
reqParams["accessKey"] = claims.AccessKey
reqParams["accessKey"] = claims.GetAccessKey()
respElements := extractRespElements(w)
for i, object := range args.Objects {

View File

@ -347,6 +347,8 @@ func (s *xlStorage) IsLocal() bool {
// Retrieve location indexes.
func (s *xlStorage) GetDiskLoc() (poolIdx, setIdx, diskIdx int) {
s.RLock()
defer s.RUnlock()
// If unset, see if we can locate it.
if s.poolIndex < 0 || s.setIndex < 0 || s.diskIndex < 0 {
return getXLDiskLoc(s.diskID)
@ -1615,6 +1617,9 @@ func (s *xlStorage) CheckFile(ctx context.Context, volume string, path string) e
if err != nil {
return err
}
s.RLock()
formatLegacy := s.formatLegacy
s.RUnlock()
var checkFile func(p string) error
checkFile = func(p string) error {
@ -1626,10 +1631,10 @@ func (s *xlStorage) CheckFile(ctx context.Context, volume string, path string) e
if err := checkPathLength(filePath); err != nil {
return err
}
st, _ := Lstat(filePath)
if st == nil {
if !s.formatLegacy {
if !formatLegacy {
return errPathNotFound
}
@ -1880,10 +1885,13 @@ func (s *xlStorage) RenameData(ctx context.Context, srcVolume, srcPath, dataDir,
legacyPreserved = true
}
} else {
s.RLock()
formatLegacy := s.formatLegacy
s.RUnlock()
// It is possible that some drives may not have `xl.meta` file
// in such scenarios verify if atleast `part.1` files exist
// to verify for legacy version.
if s.formatLegacy {
if formatLegacy {
// We only need this code if we are moving
// from `xl.json` to `xl.meta`, we can avoid
// one extra readdir operation here for all

View File

@ -5,7 +5,7 @@ version: '3.7'
# it through port 9000.
services:
minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes:
- data1-1:/data1
- data1-2:/data2
@ -22,7 +22,7 @@ services:
retries: 3
minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes:
- data2-1:/data1
- data2-2:/data2
@ -39,7 +39,7 @@ services:
retries: 3
minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes:
- data3-1:/data1
- data3-2:/data2
@ -56,7 +56,7 @@ services:
retries: 3
minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
volumes:
- data4-1:/data1
- data4-2:/data2

View File

@ -2,7 +2,7 @@ version: '3.7'
services:
minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio1
volumes:
- minio1-data:/export
@ -29,7 +29,7 @@ services:
retries: 3
minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio2
volumes:
- minio2-data:/export
@ -56,7 +56,7 @@ services:
retries: 3
minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio3
volumes:
- minio3-data:/export
@ -83,7 +83,7 @@ services:
retries: 3
minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio4
volumes:
- minio4-data:/export

View File

@ -2,7 +2,7 @@ version: '3.7'
services:
minio1:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio1
volumes:
- minio1-data:/export
@ -33,7 +33,7 @@ services:
retries: 3
minio2:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio2
volumes:
- minio2-data:/export
@ -64,7 +64,7 @@ services:
retries: 3
minio3:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio3
volumes:
- minio3-data:/export
@ -95,7 +95,7 @@ services:
retries: 3
minio4:
image: minio/minio:RELEASE.2021-03-17T02-33-02Z
image: minio/minio:RELEASE.2021-04-06T23-11-00Z
hostname: minio4
volumes:
- minio4-data:/export

1
go.mod
View File

@ -77,6 +77,7 @@ require (
github.com/tidwall/gjson v1.6.8
github.com/tidwall/sjson v1.0.4
github.com/tinylib/msgp v1.1.3
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 // indirect
github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a
github.com/willf/bitset v1.1.11 // indirect
github.com/willf/bloom v2.0.3+incompatible

2
go.sum
View File

@ -596,6 +596,8 @@ github.com/tinylib/msgp v1.1.3 h1:3giwAkmtaEDLSV0MdO1lDLuPgklgPzmk8H9+So2BVfA=
github.com/tinylib/msgp v1.1.3/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8 h1:ndzgwNDnKIqyCvHTXaCqh9KlOWKvBry6nuXMJmonVsE=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31 h1:OXcKh35JaYsGMRzpvFkLv/MEyPuL49CThT1pZ8aSml4=
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=

View File

@ -123,8 +123,12 @@ func (m *Monitor) getReport(selectBucket SelectionFunction) *bandwidth.Report {
if !selectBucket(bucket) {
continue
}
bucketThrottle, ok := m.bucketThrottle[bucket]
if !ok {
continue
}
report.BucketStats[bucket] = bandwidth.Details{
LimitInBytesPerSecond: m.bucketThrottle[bucket].clusterBandwidth,
LimitInBytesPerSecond: bucketThrottle.clusterBandwidth,
CurrentBandwidthInBytesPerSecond: bucketMeasurement.getExpMovingAvgBytesPerSecond(),
}
}

View File

@ -25,62 +25,61 @@ import (
// MonitoredReader monitors the bandwidth
type MonitoredReader struct {
bucket string // Token to track bucket
opts *MonitorReaderOptions
bucketMeasurement *bucketMeasurement // bucket measurement object
object string // Token to track object
reader io.ReadCloser // Reader to wrap
reader io.Reader // Reader to wrap
lastStop time.Time // Last timestamp for a measurement
headerSize int // Size of the header not captured by reader
throttle *throttle // throttle the rate at which replication occur
monitor *Monitor // Monitor reference
closed bool // Reader is closed
lastErr error // last error reported, if this non-nil all reads will fail.
}
// NewMonitoredReader returns a io.ReadCloser that reports bandwidth details.
// The supplied reader will be closed.
func NewMonitoredReader(ctx context.Context, monitor *Monitor, bucket string, object string, reader io.ReadCloser, headerSize int, bandwidthBytesPerSecond int64, clusterBandwidth int64) *MonitoredReader {
// MonitorReaderOptions provides configurable options for monitor reader implementation.
type MonitorReaderOptions struct {
Bucket string
Object string
HeaderSize int
BandwidthBytesPerSec int64
ClusterBandwidth int64
}
// NewMonitoredReader returns a io.Reader that reports bandwidth details.
func NewMonitoredReader(ctx context.Context, monitor *Monitor, reader io.Reader, opts *MonitorReaderOptions) *MonitoredReader {
timeNow := time.Now()
b := monitor.track(bucket, object, timeNow)
b := monitor.track(opts.Bucket, opts.Object, timeNow)
return &MonitoredReader{
bucket: bucket,
object: object,
opts: opts,
bucketMeasurement: b,
reader: reader,
lastStop: timeNow,
headerSize: headerSize,
throttle: monitor.throttleBandwidth(ctx, bucket, bandwidthBytesPerSecond, clusterBandwidth),
throttle: monitor.throttleBandwidth(ctx, opts.Bucket, opts.BandwidthBytesPerSec, opts.ClusterBandwidth),
monitor: monitor,
}
}
// Read wraps the read reader
func (m *MonitoredReader) Read(p []byte) (n int, err error) {
if m.closed {
err = io.ErrClosedPipe
if m.lastErr != nil {
err = m.lastErr
return
}
p = p[:m.throttle.GetLimitForBytes(int64(len(p)))]
n, err = m.reader.Read(p)
stop := time.Now()
update := uint64(n + m.headerSize)
update := uint64(n + m.opts.HeaderSize)
m.bucketMeasurement.incrementBytes(update)
m.lastStop = stop
unused := len(p) - (n + m.headerSize)
m.headerSize = 0 // Set to 0 post first read
unused := len(p) - (n + m.opts.HeaderSize)
m.opts.HeaderSize = 0 // Set to 0 post first read
if unused > 0 {
m.throttle.ReleaseUnusedBandwidth(int64(unused))
}
if err != nil {
m.lastErr = err
}
return
}
// Close stops tracking the io
func (m *MonitoredReader) Close() error {
if m.closed {
return nil
}
m.closed = true
return m.reader.Close()
}

View File

@ -18,6 +18,7 @@ package lifecycle
import (
"encoding/xml"
"fmt"
"io"
"strings"
"time"
@ -71,7 +72,8 @@ func (lc *Lifecycle) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err e
switch start.Name.Local {
case "LifecycleConfiguration", "BucketLifecycleConfiguration":
default:
return errUnknownXMLTag
return xml.UnmarshalError(fmt.Sprintf("expected element type <LifecycleConfiguration>/<BucketLifecycleConfiguration> but have <%s>",
start.Name.Local))
}
for {
// Read tokens from the XML document in a stream.
@ -93,7 +95,7 @@ func (lc *Lifecycle) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err e
}
lc.Rules = append(lc.Rules, r)
default:
return errUnknownXMLTag
return xml.UnmarshalError(fmt.Sprintf("expected element type <Rule> but have <%s>", se.Name.Local))
}
}
}

View File

@ -489,6 +489,41 @@ type ObjectLegalHold struct {
Status LegalHoldStatus `xml:"Status,omitempty"`
}
// UnmarshalXML - decodes XML data.
func (l *ObjectLegalHold) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err error) {
switch start.Name.Local {
case "LegalHold", "ObjectLockLegalHold":
default:
return xml.UnmarshalError(fmt.Sprintf("expected element type <LegalHold>/<ObjectLockLegalHold> but have <%s>",
start.Name.Local))
}
for {
// Read tokens from the XML document in a stream.
t, err := d.Token()
if err != nil {
if err == io.EOF {
break
}
return err
}
switch se := t.(type) {
case xml.StartElement:
switch se.Name.Local {
case "Status":
var st LegalHoldStatus
if err = d.DecodeElement(&st, &se); err != nil {
return err
}
l.Status = st
default:
return xml.UnmarshalError(fmt.Sprintf("expected element type <Status> but have <%s>", se.Name.Local))
}
}
}
return nil
}
// IsEmpty returns true if struct is empty
func (l *ObjectLegalHold) IsEmpty() bool {
return !l.Status.Valid()

View File

@ -18,6 +18,7 @@ package lock
import (
"encoding/xml"
"errors"
"fmt"
"net/http"
"reflect"
@ -467,6 +468,23 @@ func TestParseObjectLegalHold(t *testing.T) {
expectedErr: nil,
expectErr: false,
},
{
value: `<?xml version="1.0" encoding="UTF-8"?><ObjectLockLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>ON</Status></ObjectLockLegalHold>`,
expectedErr: nil,
expectErr: false,
},
// invalid Status key
{
value: `<?xml version="1.0" encoding="UTF-8"?><ObjectLockLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><MyStatus>ON</MyStatus></ObjectLockLegalHold>`,
expectedErr: errors.New("expected element type <Status> but have <MyStatus>"),
expectErr: true,
},
// invalid XML attr
{
value: `<?xml version="1.0" encoding="UTF-8"?><UnknownLegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>ON</Status></UnknownLegalHold>`,
expectedErr: errors.New("expected element type <LegalHold>/<ObjectLockLegalHold> but have <UnknownLegalHold>"),
expectErr: true,
},
{
value: `<?xml version="1.0" encoding="UTF-8"?><LegalHold xmlns="http://s3.amazonaws.com/doc/2006-03-01/"><Status>On</Status></LegalHold>`,
expectedErr: ErrMalformedXML,

View File

@ -110,10 +110,18 @@ const (
// AWSUsername - user friendly name, in MinIO this value is same as your user Access Key.
AWSUsername Key = "aws:username"
// S3SignatureVersion - identifies the version of AWS Signature that you want to support for authenticated requests.
S3SignatureVersion = "s3:signatureversion"
// S3AuthType - optionally use this condition key to restrict incoming requests to use a specific authentication method.
S3AuthType = "s3:authType"
)
// AllSupportedKeys - is list of all all supported keys.
var AllSupportedKeys = append([]Key{
S3SignatureVersion,
S3AuthType,
S3XAmzCopySource,
S3XAmzServerSideEncryption,
S3XAmzServerSideEncryptionCustomerAlgorithm,
@ -144,6 +152,8 @@ var AllSupportedKeys = append([]Key{
// CommonKeys - is list of all common condition keys.
var CommonKeys = append([]Key{
S3SignatureVersion,
S3AuthType,
S3XAmzContentSha256,
S3LocationConstraint,
AWSReferer,

View File

@ -739,6 +739,152 @@ func TestCSVQueries2(t *testing.T) {
}
}
func TestCSVQueries3(t *testing.T) {
input := `na.me,qty,CAST
apple,1,true
mango,3,false
`
var testTable = []struct {
name string
query string
requestXML []byte // override request XML
wantResult string
}{
{
name: "Select a column containing dot",
query: `select "na.me" from S3Object s`,
wantResult: `apple
mango`,
},
{
name: "Select column containing dot with table name prefix",
query: `select count(S3Object."na.me") from S3Object`,
wantResult: `2`,
},
{
name: "Select column containing dot with table alias prefix",
query: `select s."na.me" from S3Object as s`,
wantResult: `apple
mango`,
},
{
name: "Select column simplest",
query: `select qty from S3Object`,
wantResult: `1
3`,
},
{
name: "Select column with table name prefix",
query: `select S3Object.qty from S3Object`,
wantResult: `1
3`,
},
{
name: "Select column without table alias",
query: `select qty from S3Object s`,
wantResult: `1
3`,
},
{
name: "Select column with table alias",
query: `select s.qty from S3Object s`,
wantResult: `1
3`,
},
{
name: "Select reserved word column",
query: `select "CAST" from s3object`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select S3Object."CAST" from s3object`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with unused table alias",
query: `select "CAST" from s3object s`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select s."CAST" from s3object s`,
wantResult: `true
false`,
},
{
name: "Select reserved word column with table alias",
query: `select NOT CAST(s."CAST" AS Bool) from s3object s`,
wantResult: `false
true`,
},
}
defRequest := `<?xml version="1.0" encoding="UTF-8"?>
<SelectObjectContentRequest>
<Expression>%s</Expression>
<ExpressionType>SQL</ExpressionType>
<InputSerialization>
<CompressionType>NONE</CompressionType>
<CSV>
<FileHeaderInfo>USE</FileHeaderInfo>
<QuoteCharacter>"</QuoteCharacter>
</CSV>
</InputSerialization>
<OutputSerialization>
<CSV/>
</OutputSerialization>
<RequestProgress>
<Enabled>FALSE</Enabled>
</RequestProgress>
</SelectObjectContentRequest>`
for _, testCase := range testTable {
t.Run(testCase.name, func(t *testing.T) {
testReq := testCase.requestXML
if len(testReq) == 0 {
testReq = []byte(fmt.Sprintf(defRequest, testCase.query))
}
s3Select, err := NewS3Select(bytes.NewReader(testReq))
if err != nil {
t.Fatal(err)
}
if err = s3Select.Open(func(offset, length int64) (io.ReadCloser, error) {
return ioutil.NopCloser(bytes.NewBufferString(input)), nil
}); err != nil {
t.Fatal(err)
}
w := &testResponseWriter{}
s3Select.Evaluate(w)
s3Select.Close()
resp := http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(bytes.NewReader(w.response)),
ContentLength: int64(len(w.response)),
}
res, err := minio.NewSelectResults(&resp, "testbucket")
if err != nil {
t.Error(err)
return
}
got, err := ioutil.ReadAll(res)
if err != nil {
t.Error(err)
return
}
gotS := strings.TrimSpace(string(got))
if gotS != testCase.wantResult {
t.Errorf("received response does not match with expected reply.\nQuery: %s\n=====\ngot: %s\n=====\nwant: %s\n=====\n", testCase.query, gotS, testCase.wantResult)
}
})
}
}
func TestCSVInput(t *testing.T) {
var testTable = []struct {
requestXML []byte

View File

@ -63,7 +63,7 @@ func newAggVal(fn FuncName) *aggVal {
// current row and stores the result.
//
// On success, it returns (nil, nil).
func (e *FuncExpr) evalAggregationNode(r Record) error {
func (e *FuncExpr) evalAggregationNode(r Record, tableAlias string) error {
// It is assumed that this function is called only when
// `e` is an aggregation function.
@ -77,13 +77,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return nil
}
val, err = e.Count.ExprArg.evalNode(r)
val, err = e.Count.ExprArg.evalNode(r, tableAlias)
if err != nil {
return err
}
} else {
// Evaluate the (only) argument
val, err = e.SFunc.ArgsList[0].evalNode(r)
val, err = e.SFunc.ArgsList[0].evalNode(r, tableAlias)
if err != nil {
return err
}
@ -149,13 +149,13 @@ func (e *FuncExpr) evalAggregationNode(r Record) error {
return err
}
func (e *AliasedExpression) aggregateRow(r Record) error {
return e.Expression.aggregateRow(r)
func (e *AliasedExpression) aggregateRow(r Record, tableAlias string) error {
return e.Expression.aggregateRow(r, tableAlias)
}
func (e *Expression) aggregateRow(r Record) error {
func (e *Expression) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.And {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -163,9 +163,9 @@ func (e *Expression) aggregateRow(r Record) error {
return nil
}
func (e *ListExpr) aggregateRow(r Record) error {
func (e *ListExpr) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Elements {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -173,9 +173,9 @@ func (e *ListExpr) aggregateRow(r Record) error {
return nil
}
func (e *AndCondition) aggregateRow(r Record) error {
func (e *AndCondition) aggregateRow(r Record, tableAlias string) error {
for _, ex := range e.Condition {
err := ex.aggregateRow(r)
err := ex.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -183,15 +183,15 @@ func (e *AndCondition) aggregateRow(r Record) error {
return nil
}
func (e *Condition) aggregateRow(r Record) error {
func (e *Condition) aggregateRow(r Record, tableAlias string) error {
if e.Operand != nil {
return e.Operand.aggregateRow(r)
return e.Operand.aggregateRow(r, tableAlias)
}
return e.Not.aggregateRow(r)
return e.Not.aggregateRow(r, tableAlias)
}
func (e *ConditionOperand) aggregateRow(r Record) error {
err := e.Operand.aggregateRow(r)
func (e *ConditionOperand) aggregateRow(r Record, tableAlias string) error {
err := e.Operand.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -202,38 +202,38 @@ func (e *ConditionOperand) aggregateRow(r Record) error {
switch {
case e.ConditionRHS.Compare != nil:
return e.ConditionRHS.Compare.Operand.aggregateRow(r)
return e.ConditionRHS.Compare.Operand.aggregateRow(r, tableAlias)
case e.ConditionRHS.Between != nil:
err = e.ConditionRHS.Between.Start.aggregateRow(r)
err = e.ConditionRHS.Between.Start.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Between.End.aggregateRow(r)
return e.ConditionRHS.Between.End.aggregateRow(r, tableAlias)
case e.ConditionRHS.In != nil:
elt := e.ConditionRHS.In.ListExpression
err = elt.aggregateRow(r)
err = elt.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return nil
case e.ConditionRHS.Like != nil:
err = e.ConditionRHS.Like.Pattern.aggregateRow(r)
err = e.ConditionRHS.Like.Pattern.aggregateRow(r, tableAlias)
if err != nil {
return err
}
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r)
return e.ConditionRHS.Like.EscapeChar.aggregateRow(r, tableAlias)
default:
return errInvalidASTNode
}
}
func (e *Operand) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
func (e *Operand) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -241,13 +241,13 @@ func (e *Operand) aggregateRow(r Record) error {
return nil
}
func (e *MultOp) aggregateRow(r Record) error {
err := e.Left.aggregateRow(r)
func (e *MultOp) aggregateRow(r Record, tableAlias string) error {
err := e.Left.aggregateRow(r, tableAlias)
if err != nil {
return err
}
for _, rt := range e.Right {
err = rt.Right.aggregateRow(r)
err = rt.Right.aggregateRow(r, tableAlias)
if err != nil {
return err
}
@ -255,29 +255,29 @@ func (e *MultOp) aggregateRow(r Record) error {
return nil
}
func (e *UnaryTerm) aggregateRow(r Record) error {
func (e *UnaryTerm) aggregateRow(r Record, tableAlias string) error {
if e.Negated != nil {
return e.Negated.Term.aggregateRow(r)
return e.Negated.Term.aggregateRow(r, tableAlias)
}
return e.Primary.aggregateRow(r)
return e.Primary.aggregateRow(r, tableAlias)
}
func (e *PrimaryTerm) aggregateRow(r Record) error {
func (e *PrimaryTerm) aggregateRow(r Record, tableAlias string) error {
switch {
case e.ListExpr != nil:
return e.ListExpr.aggregateRow(r)
return e.ListExpr.aggregateRow(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.aggregateRow(r)
return e.SubExpression.aggregateRow(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.aggregateRow(r)
return e.FuncCall.aggregateRow(r, tableAlias)
}
return nil
}
func (e *FuncExpr) aggregateRow(r Record) error {
func (e *FuncExpr) aggregateRow(r Record, tableAlias string) error {
switch e.getFunctionName() {
case aggFnAvg, aggFnSum, aggFnMax, aggFnMin, aggFnCount:
return e.evalAggregationNode(r)
return e.evalAggregationNode(r, tableAlias)
default:
// TODO: traverse arguments and call aggregateRow on
// them if they could be an ancestor of an

View File

@ -19,6 +19,7 @@ package sql
import (
"errors"
"fmt"
"strings"
)
// Query analysis - The query is analyzed to determine if it involves
@ -177,7 +178,7 @@ func (e *PrimaryTerm) analyze(s *Select) (result qProp) {
case e.JPathExpr != nil:
// Check if the path expression is valid
if len(e.JPathExpr.PathExpr) > 0 {
if e.JPathExpr.BaseKey.String() != s.From.As {
if e.JPathExpr.BaseKey.String() != s.From.As && strings.ToLower(e.JPathExpr.BaseKey.String()) != baseTableName {
result = qProp{err: errInvalidKeypath}
return
}

View File

@ -21,7 +21,6 @@ import (
"errors"
"fmt"
"math"
"strings"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
@ -47,21 +46,21 @@ var (
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record) (*Value, error) {
return e.Expression.evalNode(r)
func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
return e.Expression.evalNode(r, tableAlias)
}
func (e *Expression) evalNode(r Record) (*Value, error) {
func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r)
return e.And[0].evalNode(r, tableAlias)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r)
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -74,16 +73,16 @@ func (e *Expression) evalNode(r Record) (*Value, error) {
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record) (*Value, error) {
func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r)
return e.Condition[0].evalNode(r, tableAlias)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r)
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -96,14 +95,14 @@ func (e *AndCondition) evalNode(r Record) (*Value, error) {
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record) (*Value, error) {
func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r)
return e.Operand.evalNode(r, tableAlias)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r)
res, err := e.Not.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -114,8 +113,8 @@ func (e *Condition) evalNode(r Record) (*Value, error) {
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r)
func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r, tableAlias)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
@ -123,7 +122,7 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r)
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
if cmpRErr != nil {
return nil, cmpRErr
}
@ -132,26 +131,26 @@ func (e *ConditionOperand) evalNode(r Record) (*Value, error) {
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal)
return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal)
return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal)
return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
stVal, stErr := e.Start.evalNode(r)
func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
stVal, stErr := e.Start.evalNode(r, tableAlias)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r)
endVal, endErr := e.End.evalNode(r, tableAlias)
if endErr != nil {
return nil, endErr
}
@ -174,7 +173,7 @@ func (e *Between) evalBetweenNode(r Record, arg *Value) (*Value, error) {
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
@ -183,7 +182,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r)
pattern, err1 := e.Pattern.evalNode(r, tableAlias)
if err1 != nil {
return nil, err1
}
@ -199,7 +198,7 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r)
escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
if err2 != nil {
return nil, err2
}
@ -230,14 +229,14 @@ func (e *Like) evalLikeNode(r Record, arg *Value) (*Value, error) {
return FromBool(matchResult), nil
}
func (e *ListExpr) evalNode(r Record) (*Value, error) {
func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
res := make([]Value, len(e.Elements))
if len(e.Elements) == 1 {
// If length 1, treat as single value.
return e.Elements[0].evalNode(r)
return e.Elements[0].evalNode(r, tableAlias)
}
for i, elt := range e.Elements {
v, err := elt.evalNode(r)
v, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -248,7 +247,7 @@ func (e *ListExpr) evalNode(r Record) (*Value, error) {
const floatCmpTolerance = 0.000001
func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
// Compare two values in terms of in-ness.
var cmp func(a, b Value) bool
cmp = func(a, b Value) bool {
@ -283,7 +282,7 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
var rhs Value
if elt := e.ListExpression; elt != nil {
eltVal, err := elt.evalNode(r)
eltVal, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -304,8 +303,8 @@ func (e *In) evalInNode(r Record, lhs *Value) (*Value, error) {
return FromBool(cmp(rhs, *lhs)), nil
}
func (e *Operand) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
@ -315,7 +314,7 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
@ -327,8 +326,8 @@ func (e *Operand) evalNode(r Record) (*Value, error) {
return lval, nil
}
func (e *MultOp) evalNode(r Record) (*Value, error) {
lval, lerr := e.Left.evalNode(r)
func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
@ -337,7 +336,7 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r)
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
@ -350,12 +349,12 @@ func (e *MultOp) evalNode(r Record) (*Value, error) {
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r)
return e.Primary.evalNode(r, tableAlias)
}
v, err := e.Negated.Term.evalNode(r)
v, err := e.Negated.Term.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -368,19 +367,15 @@ func (e *UnaryTerm) evalNode(r Record) (*Value, error) {
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record) (*Value, error) {
// Strip the table name from the keypath.
keypath := e.String()
if strings.Contains(keypath, ".") {
ps := strings.SplitN(keypath, ".", 2)
if len(ps) == 2 {
keypath = ps[1]
}
func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
alias := tableAlias
if tableAlias == "" {
alias = baseTableName
}
pathExpr := e.StripTableAlias(alias)
_, rawVal := r.Raw()
switch rowVal := rawVal.(type) {
case jstream.KVS, simdjson.Object:
pathExpr := e.PathExpr
if len(pathExpr) == 0 {
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
}
@ -392,7 +387,10 @@ func (e *JSONPath) evalNode(r Record) (*Value, error) {
return jsonToValue(result)
default:
return r.Get(keypath)
if pathExpr[len(pathExpr)-1].Key == nil {
return nil, errInvalidKeypath
}
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
}
}
@ -447,28 +445,28 @@ func jsonToValue(result interface{}) (*Value, error) {
return nil, fmt.Errorf("Unhandled value type: %T", result)
}
func (e *PrimaryTerm) evalNode(r Record) (res *Value, err error) {
func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r)
return e.JPathExpr.evalNode(r, tableAlias)
case e.ListExpr != nil:
return e.ListExpr.evalNode(r)
return e.ListExpr.evalNode(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r)
return e.SubExpression.evalNode(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r)
return e.FuncCall.evalNode(r, tableAlias)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record) (res *Value, err error) {
func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r)
return e.evalSQLFnNode(r, tableAlias)
}
}

View File

@ -84,35 +84,35 @@ func (e *FuncExpr) getFunctionName() FuncName {
// evalSQLFnNode assumes that the FuncExpr is not an aggregation
// function.
func (e *FuncExpr) evalSQLFnNode(r Record) (res *Value, err error) {
func (e *FuncExpr) evalSQLFnNode(r Record, tableAlias string) (res *Value, err error) {
// Handle functions that have phrase arguments
switch e.getFunctionName() {
case sqlFnCast:
expr := e.Cast.Expr
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType))
res, err = expr.castTo(r, strings.ToUpper(e.Cast.CastType), tableAlias)
return
case sqlFnSubstring:
return handleSQLSubstring(r, e.Substring)
return handleSQLSubstring(r, e.Substring, tableAlias)
case sqlFnExtract:
return handleSQLExtract(r, e.Extract)
return handleSQLExtract(r, e.Extract, tableAlias)
case sqlFnTrim:
return handleSQLTrim(r, e.Trim)
return handleSQLTrim(r, e.Trim, tableAlias)
case sqlFnDateAdd:
return handleDateAdd(r, e.DateAdd)
return handleDateAdd(r, e.DateAdd, tableAlias)
case sqlFnDateDiff:
return handleDateDiff(r, e.DateDiff)
return handleDateDiff(r, e.DateDiff, tableAlias)
}
// For all simple argument functions, we evaluate the arguments here
argVals := make([]*Value, len(e.SFunc.ArgsList))
for i, arg := range e.SFunc.ArgsList {
argVals[i], err = arg.evalNode(r)
argVals[i], err = arg.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -219,8 +219,8 @@ func upperCase(v *Value) (*Value, error) {
return FromString(strings.ToUpper(s)), nil
}
func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
q, err := d.Quantity.evalNode(r)
func handleDateAdd(r Record, d *DateAddFunc, tableAlias string) (*Value, error) {
q, err := d.Quantity.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -230,7 +230,7 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
return nil, fmt.Errorf("QUANTITY must be a numeric argument to %s()", sqlFnDateAdd)
}
ts, err := d.Timestamp.evalNode(r)
ts, err := d.Timestamp.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -245,8 +245,8 @@ func handleDateAdd(r Record, d *DateAddFunc) (*Value, error) {
return dateAdd(strings.ToUpper(d.DatePart), qty, t)
}
func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
tval1, err := d.Timestamp1.evalNode(r)
func handleDateDiff(r Record, d *DateDiffFunc, tableAlias string) (*Value, error) {
tval1, err := d.Timestamp1.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -258,7 +258,7 @@ func handleDateDiff(r Record, d *DateDiffFunc) (*Value, error) {
return nil, fmt.Errorf("%s() expects two timestamp arguments", sqlFnDateDiff)
}
tval2, err := d.Timestamp2.evalNode(r)
tval2, err := d.Timestamp2.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -277,12 +277,12 @@ func handleUTCNow() (*Value, error) {
return FromTimestamp(time.Now().UTC()), nil
}
func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
func handleSQLSubstring(r Record, e *SubstringFunc, tableAlias string) (val *Value, err error) {
// Both forms `SUBSTRING('abc' FROM 2 FOR 1)` and
// SUBSTRING('abc', 2, 1) are supported.
// Evaluate the string argument
v1, err := e.Expr.evalNode(r)
v1, err := e.Expr.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -301,7 +301,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
}
// Evaluate the FROM argument
v2, err := arg2.evalNode(r)
v2, err := arg2.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -315,7 +315,7 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
length := -1
// Evaluate the optional FOR argument
if arg3 != nil {
v3, err := arg3.evalNode(r)
v3, err := arg3.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
@ -336,11 +336,11 @@ func handleSQLSubstring(r Record, e *SubstringFunc) (val *Value, err error) {
return FromString(res), err
}
func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
func handleSQLTrim(r Record, e *TrimFunc, tableAlias string) (res *Value, err error) {
chars := ""
ok := false
if e.TrimChars != nil {
charsV, cerr := e.TrimChars.evalNode(r)
charsV, cerr := e.TrimChars.evalNode(r, tableAlias)
if cerr != nil {
return nil, cerr
}
@ -351,7 +351,7 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
}
}
fromV, ferr := e.TrimFrom.evalNode(r)
fromV, ferr := e.TrimFrom.evalNode(r, tableAlias)
if ferr != nil {
return nil, ferr
}
@ -368,8 +368,8 @@ func handleSQLTrim(r Record, e *TrimFunc) (res *Value, err error) {
return FromString(result), nil
}
func handleSQLExtract(r Record, e *ExtractFunc) (res *Value, err error) {
timeVal, verr := e.From.evalNode(r)
func handleSQLExtract(r Record, e *ExtractFunc, tableAlias string) (res *Value, err error) {
timeVal, verr := e.From.evalNode(r, tableAlias)
if verr != nil {
return nil, verr
}
@ -406,8 +406,8 @@ const (
castTimestamp = "TIMESTAMP"
)
func (e *Expression) castTo(r Record, castType string) (res *Value, err error) {
v, err := e.evalNode(r)
func (e *Expression) castTo(r Record, castType string, tableAlias string) (res *Value, err error) {
v, err := e.evalNode(r, tableAlias)
if err != nil {
return nil, err
}

View File

@ -119,7 +119,9 @@ type JSONPath struct {
PathExpr []*JSONPathElement `parser:"(@@)*"`
// Cached values:
pathString string
pathString string
strippedTableAlias string
strippedPathExpr []*JSONPathElement
}
// AliasedExpression is an expression that can be optionally named

View File

@ -46,6 +46,9 @@ type SelectStatement struct {
// Count of rows that have been output.
outputCount int64
// Table alias
tableAlias string
}
// ParseSelectStatement - parses a select query from the given string
@ -107,6 +110,9 @@ func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
if err != nil {
err = errQueryAnalysisFailure(err)
}
// Set table alias
stmt.tableAlias = selectAST.From.As
return
}
@ -226,7 +232,7 @@ func (e *SelectStatement) IsAggregated() bool {
// records have been processed. Applies only to aggregation queries.
func (e *SelectStatement) AggregateResult(output Record) error {
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(nil)
v, err := expr.evalNode(nil, e.tableAlias)
if err != nil {
return err
}
@ -246,7 +252,7 @@ func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
if e.selectAST.Where == nil {
return true, nil
}
value, err := e.selectAST.Where.evalNode(input)
value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
if err != nil {
return false, err
}
@ -272,7 +278,7 @@ func (e *SelectStatement) AggregateRow(input Record) error {
}
for _, expr := range e.selectAST.Expression.Expressions {
err := expr.aggregateRow(input)
err := expr.aggregateRow(input, e.tableAlias)
if err != nil {
return err
}
@ -302,7 +308,7 @@ func (e *SelectStatement) Eval(input, output Record) (Record, error) {
}
for i, expr := range e.selectAST.Expression.Expressions {
v, err := expr.evalNode(input)
v, err := expr.evalNode(input, e.tableAlias)
if err != nil {
return nil, err
}

View File

@ -36,6 +36,27 @@ func (e *JSONPath) String() string {
return e.pathString
}
// StripTableAlias removes a table alias from the path. The result is also
// cached for repeated lookups during SQL query evaluation.
func (e *JSONPath) StripTableAlias(tableAlias string) []*JSONPathElement {
if e.strippedTableAlias == tableAlias {
return e.strippedPathExpr
}
hasTableAlias := e.BaseKey.String() == tableAlias || strings.ToLower(e.BaseKey.String()) == baseTableName
var pathExpr []*JSONPathElement
if hasTableAlias {
pathExpr = e.PathExpr
} else {
pathExpr = make([]*JSONPathElement, len(e.PathExpr)+1)
pathExpr[0] = &JSONPathElement{Key: &ObjectKey{ID: e.BaseKey}}
copy(pathExpr[1:], e.PathExpr)
}
e.strippedTableAlias = tableAlias
e.strippedPathExpr = pathExpr
return e.strippedPathExpr
}
func (e *JSONPathElement) String() string {
switch {
case e.Key != nil: