minio/pkg/s3select/sql/evaluate.go
Aditya Manthramurthy 41a9d1d778 Fix S3Select SQL column reference handling (#11957)
This change fixes handling of these types of queries:

- Double quoted column names with special characters:
    SELECT "column.name" FROM s3object
- Double quoted column names with reserved keywords:
    SELECT "CAST" FROM s3object
- Table name as prefix for column names:
    SELECT S3Object."CAST" FROM s3object
2021-04-06 12:40:28 -07:00

491 lines
12 KiB
Go

/*
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package sql
import (
"encoding/json"
"errors"
"fmt"
"math"
"github.com/bcicen/jstream"
"github.com/minio/simdjson-go"
)
var (
errInvalidASTNode = errors.New("invalid AST Node")
errExpectedBool = errors.New("expected bool")
errLikeNonStrArg = errors.New("LIKE clause requires string arguments")
errLikeInvalidEscape = errors.New("LIKE clause has invalid ESCAPE character")
errNotImplemented = errors.New("not implemented")
)
// AST Node Evaluation functions
//
// During evaluation, the query is known to be valid, as analysis is
// complete. The only errors possible are due to value type
// mismatches, etc.
//
// If an aggregation node is present as a descendant (when
// e.prop.isAggregation is true), we call evalNode on all child nodes,
// check for errors, but do not perform any combining of the results
// of child nodes. The final result row is returned after all rows are
// processed, and the `getAggregate` function is called.
func (e *AliasedExpression) evalNode(r Record, tableAlias string) (*Value, error) {
return e.Expression.evalNode(r, tableAlias)
}
func (e *Expression) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.And) == 1 {
// In this case, result is not required to be boolean
// type.
return e.And[0].evalNode(r, tableAlias)
}
// Compute OR of conditions
result := false
for _, ex := range e.And {
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result || b
}
return FromBool(result), nil
}
func (e *AndCondition) evalNode(r Record, tableAlias string) (*Value, error) {
if len(e.Condition) == 1 {
// In this case, result does not have to be boolean
return e.Condition[0].evalNode(r, tableAlias)
}
// Compute AND of conditions
result := true
for _, ex := range e.Condition {
res, err := ex.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
result = result && b
}
return FromBool(result), nil
}
func (e *Condition) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Operand != nil {
// In this case, result does not have to be boolean
return e.Operand.evalNode(r, tableAlias)
}
// Compute NOT of condition
res, err := e.Not.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
b, ok := res.ToBool()
if !ok {
return nil, errExpectedBool
}
return FromBool(!b), nil
}
func (e *ConditionOperand) evalNode(r Record, tableAlias string) (*Value, error) {
opVal, opErr := e.Operand.evalNode(r, tableAlias)
if opErr != nil || e.ConditionRHS == nil {
return opVal, opErr
}
// Need to evaluate the ConditionRHS
switch {
case e.ConditionRHS.Compare != nil:
cmpRight, cmpRErr := e.ConditionRHS.Compare.Operand.evalNode(r, tableAlias)
if cmpRErr != nil {
return nil, cmpRErr
}
b, err := opVal.compareOp(e.ConditionRHS.Compare.Operator, cmpRight)
return FromBool(b), err
case e.ConditionRHS.Between != nil:
return e.ConditionRHS.Between.evalBetweenNode(r, opVal, tableAlias)
case e.ConditionRHS.Like != nil:
return e.ConditionRHS.Like.evalLikeNode(r, opVal, tableAlias)
case e.ConditionRHS.In != nil:
return e.ConditionRHS.In.evalInNode(r, opVal, tableAlias)
default:
return nil, errInvalidASTNode
}
}
func (e *Between) evalBetweenNode(r Record, arg *Value, tableAlias string) (*Value, error) {
stVal, stErr := e.Start.evalNode(r, tableAlias)
if stErr != nil {
return nil, stErr
}
endVal, endErr := e.End.evalNode(r, tableAlias)
if endErr != nil {
return nil, endErr
}
part1, err1 := stVal.compareOp(opLte, arg)
if err1 != nil {
return nil, err1
}
part2, err2 := arg.compareOp(opLte, endVal)
if err2 != nil {
return nil, err2
}
result := part1 && part2
if e.Not {
result = !result
}
return FromBool(result), nil
}
func (e *Like) evalLikeNode(r Record, arg *Value, tableAlias string) (*Value, error) {
inferTypeAsString(arg)
s, ok := arg.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
pattern, err1 := e.Pattern.evalNode(r, tableAlias)
if err1 != nil {
return nil, err1
}
// Infer pattern as string (in case it is untyped)
inferTypeAsString(pattern)
patternStr, ok := pattern.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
escape := runeZero
if e.EscapeChar != nil {
escapeVal, err2 := e.EscapeChar.evalNode(r, tableAlias)
if err2 != nil {
return nil, err2
}
inferTypeAsString(escapeVal)
escapeStr, ok := escapeVal.ToString()
if !ok {
err := errLikeNonStrArg
return nil, errLikeInvalidInputs(err)
}
if len([]rune(escapeStr)) > 1 {
err := errLikeInvalidEscape
return nil, errLikeInvalidInputs(err)
}
}
matchResult, err := evalSQLLike(s, patternStr, escape)
if err != nil {
return nil, err
}
if e.Not {
matchResult = !matchResult
}
return FromBool(matchResult), nil
}
func (e *ListExpr) evalNode(r Record, tableAlias string) (*Value, error) {
res := make([]Value, len(e.Elements))
if len(e.Elements) == 1 {
// If length 1, treat as single value.
return e.Elements[0].evalNode(r, tableAlias)
}
for i, elt := range e.Elements {
v, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
res[i] = *v
}
return FromArray(res), nil
}
const floatCmpTolerance = 0.000001
func (e *In) evalInNode(r Record, lhs *Value, tableAlias string) (*Value, error) {
// Compare two values in terms of in-ness.
var cmp func(a, b Value) bool
cmp = func(a, b Value) bool {
// Convert if needed.
inferTypesForCmp(&a, &b)
if a.Equals(b) {
return true
}
// If elements, compare each.
aA, aOK := a.ToArray()
bA, bOK := b.ToArray()
if aOK && bOK {
if len(aA) != len(bA) {
return false
}
for i := range aA {
if !cmp(aA[i], bA[i]) {
return false
}
}
return true
}
// Try as numbers
aF, aOK := a.ToFloat()
bF, bOK := b.ToFloat()
diff := math.Abs(aF - bF)
return aOK && bOK && diff < floatCmpTolerance
}
var rhs Value
if elt := e.ListExpression; elt != nil {
eltVal, err := elt.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
rhs = *eltVal
}
// If RHS is array compare each element.
if arr, ok := rhs.ToArray(); ok {
for _, element := range arr {
// If we have an array we are on the wrong level.
if cmp(element, *lhs) {
return FromBool(true), nil
}
}
return FromBool(false), nil
}
return FromBool(cmp(rhs, *lhs)), nil
}
func (e *Operand) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process remaining child nodes - result must be
// numeric. This AST node is for terms separated by + or -
// symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *MultOp) evalNode(r Record, tableAlias string) (*Value, error) {
lval, lerr := e.Left.evalNode(r, tableAlias)
if lerr != nil || len(e.Right) == 0 {
return lval, lerr
}
// Process other child nodes - result must be numeric. This
// AST node is for terms separated by *, / or % symbols.
for _, rightTerm := range e.Right {
op := rightTerm.Op
rval, rerr := rightTerm.Right.evalNode(r, tableAlias)
if rerr != nil {
return nil, rerr
}
err := lval.arithOp(op, rval)
if err != nil {
return nil, err
}
}
return lval, nil
}
func (e *UnaryTerm) evalNode(r Record, tableAlias string) (*Value, error) {
if e.Negated == nil {
return e.Primary.evalNode(r, tableAlias)
}
v, err := e.Negated.Term.evalNode(r, tableAlias)
if err != nil {
return nil, err
}
inferTypeForArithOp(v)
v.negate()
if v.isNumeric() {
return v, nil
}
return nil, errArithMismatchedTypes
}
func (e *JSONPath) evalNode(r Record, tableAlias string) (*Value, error) {
alias := tableAlias
if tableAlias == "" {
alias = baseTableName
}
pathExpr := e.StripTableAlias(alias)
_, rawVal := r.Raw()
switch rowVal := rawVal.(type) {
case jstream.KVS, simdjson.Object:
if len(pathExpr) == 0 {
pathExpr = []*JSONPathElement{{Key: &ObjectKey{ID: e.BaseKey}}}
}
result, _, err := jsonpathEval(pathExpr, rowVal)
if err != nil {
return nil, err
}
return jsonToValue(result)
default:
if pathExpr[len(pathExpr)-1].Key == nil {
return nil, errInvalidKeypath
}
return r.Get(pathExpr[len(pathExpr)-1].Key.keyString())
}
}
// jsonToValue will convert the json value to an internal value.
func jsonToValue(result interface{}) (*Value, error) {
switch rval := result.(type) {
case string:
return FromString(rval), nil
case float64:
return FromFloat(rval), nil
case int64:
return FromInt(rval), nil
case uint64:
if rval <= math.MaxInt64 {
return FromInt(int64(rval)), nil
}
return FromFloat(float64(rval)), nil
case bool:
return FromBool(rval), nil
case jstream.KVS:
bs, err := json.Marshal(result)
if err != nil {
return nil, err
}
return FromBytes(bs), nil
case []interface{}:
dst := make([]Value, len(rval))
for i := range rval {
v, err := jsonToValue(rval[i])
if err != nil {
return nil, err
}
dst[i] = *v
}
return FromArray(dst), nil
case simdjson.Object:
o := rval
elems, err := o.Parse(nil)
if err != nil {
return nil, err
}
bs, err := elems.MarshalJSON()
if err != nil {
return nil, err
}
return FromBytes(bs), nil
case []Value:
return FromArray(rval), nil
case nil:
return FromNull(), nil
}
return nil, fmt.Errorf("Unhandled value type: %T", result)
}
func (e *PrimaryTerm) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch {
case e.Value != nil:
return e.Value.evalNode(r)
case e.JPathExpr != nil:
return e.JPathExpr.evalNode(r, tableAlias)
case e.ListExpr != nil:
return e.ListExpr.evalNode(r, tableAlias)
case e.SubExpression != nil:
return e.SubExpression.evalNode(r, tableAlias)
case e.FuncCall != nil:
return e.FuncCall.evalNode(r, tableAlias)
}
return nil, errInvalidASTNode
}
func (e *FuncExpr) evalNode(r Record, tableAlias string) (res *Value, err error) {
switch e.getFunctionName() {
case aggFnCount, aggFnAvg, aggFnMax, aggFnMin, aggFnSum:
return e.getAggregate()
default:
return e.evalSQLFnNode(r, tableAlias)
}
}
// evalNode on a literal value is independent of the node being an
// aggregation or a row function - it always returns a value.
func (e *LitValue) evalNode(_ Record) (res *Value, err error) {
switch {
case e.Int != nil:
if *e.Int < math.MaxInt64 && *e.Int > math.MinInt64 {
return FromInt(int64(*e.Int)), nil
}
return FromFloat(*e.Int), nil
case e.Float != nil:
return FromFloat(*e.Float), nil
case e.String != nil:
return FromString(string(*e.String)), nil
case e.Boolean != nil:
return FromBool(bool(*e.Boolean)), nil
}
return FromNull(), nil
}