package smile.data.formula;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.ObjectType;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/data/formula/Formula.class */
/**
 * A model formula of the form {@code response ~ predictors}.
 *
 * <p>A formula is bound to a {@link StructType} before use. Binding expands the
 * right-hand-side hyper-terms into concrete {@link Term}s, removes the terms named
 * by {@link Delete} predictors, and caches the resulting schemas. The cache lives in
 * transient fields that are updated without synchronization, so one instance should
 * not be bound from several threads at once.
 */
public class Formula implements Serializable {
    private static final long serialVersionUID = 2;

    /**
     * The response term, or {@code null} when the formula has no left-hand side.
     * Kept as a nullable reference because {@link Optional} is not
     * {@link Serializable}; {@link #response()} wraps it for callers.
     */
    private final Term response;
    /** The right-hand-side hyper-terms as supplied by the caller. */
    private final HyperTerm[] predictors;
    /** Schema of {@link #xy}; set by the most recent bind. */
    private transient StructType schema;
    /** Schema of {@link #x}; set by the most recent bind. */
    private transient StructType xschema;
    /** Bound predictor terms (response excluded). */
    private transient Term[] x;
    /** Bound response (at index 0, if present) followed by the predictor terms. */
    private transient Term[] xy;

    /**
     * Creates a formula whose response is the named variable and whose only
     * predictor is {@code All}.
     *
     * @param response the response variable name.
     */
    public Formula(String response) {
        this(new Variable(response));
    }

    /**
     * Creates a formula with the given response and {@code All} as the only predictor.
     *
     * @param response the response term; must not be {@code null}.
     */
    public Formula(Term response) {
        this(response, new HyperTerm[]{new All()});
    }

    /**
     * Creates a one-sided formula (no response).
     *
     * @param predictors the right-hand-side terms.
     */
    public Formula(HyperTerm[] predictors) {
        this.response = null;
        this.predictors = predictors;
    }

    /**
     * Creates a formula whose response is the named variable.
     *
     * @param response the response variable name.
     * @param predictors the right-hand-side terms.
     */
    public Formula(String response, HyperTerm[] predictors) {
        this(new Variable(response), predictors);
    }

    /**
     * Creates a two-sided formula.
     *
     * @param response the response term; must not be {@code null}.
     * @param predictors the right-hand-side terms.
     * @throws NullPointerException if {@code response} is {@code null}.
     */
    public Formula(Term response, HyperTerm[] predictors) {
        this.response = Objects.requireNonNull(response, "response");
        this.predictors = predictors;
    }

    /** Returns a one-sided formula holding only this formula's predictors. */
    public Formula predictors() {
        return rhs(this.predictors);
    }

    /** Returns the response term, or empty for a one-sided formula. */
    public Optional<Term> response() {
        return Optional.ofNullable(response);
    }

    /**
     * Renders the formula as {@code "lhs ~ rhs"}. Each predictor is joined with
     * {@code "+ "} unless its own text already starts with {@code "- "}; a leading
     * {@code "+ "} is dropped.
     */
    @Override
    public String toString() {
        String lhs = response == null ? "" : response.toString();
        String rhs = Arrays.stream(predictors)
                .map(term -> {
                    String text = term.toString();
                    return text.startsWith("- ") ? text : "+ " + text;
                })
                .collect(Collectors.joining(" "));
        if (rhs.startsWith("+ ")) {
            rhs = rhs.substring(2);
        }
        return String.format("%s ~ %s", lhs, rhs);
    }

    /** Returns {@code new Formula(response)}. */
    public static Formula lhs(String response) {
        return new Formula(response);
    }

    /** Returns {@code new Formula(response)}. */
    public static Formula lhs(Term response) {
        return new Formula(response);
    }

    /** Returns a one-sided formula whose predictors are the named variables. */
    public static Formula rhs(String... predictors) {
        return new Formula(toVariables(predictors));
    }

    /** Returns a one-sided formula with the given predictors. */
    public static Formula rhs(HyperTerm... predictors) {
        return new Formula(predictors);
    }

    /** Returns a formula with the named response and the named predictor variables. */
    public static Formula of(String response, String... predictors) {
        return new Formula(response, toVariables(predictors));
    }

    /** Returns a formula with the named response and the given predictors. */
    public static Formula of(String response, HyperTerm... predictors) {
        return new Formula(response, predictors);
    }

    /** Returns a formula with the given response and predictors. */
    public static Formula of(Term response, HyperTerm... predictors) {
        return new Formula(response, predictors);
    }

    /** Wraps each column name in a {@link Variable}. */
    private static HyperTerm[] toVariables(String[] names) {
        HyperTerm[] terms = new HyperTerm[names.length];
        for (int i = 0; i < names.length; i++) {
            terms[i] = new Variable(names[i]);
        }
        return terms;
    }

    /** Returns the schema of the last binding (response + predictors), or {@code null} if unbound. */
    public StructType schema() {
        return this.schema;
    }

    /** Returns the predictor-only schema of the last binding, or {@code null} if unbound. */
    public StructType xschema() {
        return this.xschema;
    }

    /**
     * Binds the formula to the given input schema, always recomputing.
     *
     * @param inputSchema the schema of the data the formula will be applied to.
     * @return the output schema (response first, if any, then predictors).
     */
    public StructType bind(StructType inputSchema) {
        return bind(inputSchema, true);
    }

    /**
     * Binds the formula to {@code inputSchema} and caches the expanded terms and schemas.
     *
     * @param inputSchema the input schema.
     * @param force if {@code false} and a schema is already cached, the cache is
     *              returned unchanged even if {@code inputSchema} differs.
     * @return the output schema.
     */
    private StructType bind(StructType inputSchema, boolean force) {
        if (schema != null && !force) {
            return schema;
        }

        if (response != null) {
            response.bind(inputSchema);
        }
        for (HyperTerm predictor : predictors) {
            predictor.bind(inputSchema);
        }

        // Names referenced explicitly (response variables and plain Variable
        // predictors); All must not add these a second time.
        Set<String> explicit = new HashSet<>();
        if (response != null) {
            explicit.addAll(response.variables());
        }
        for (HyperTerm predictor : predictors) {
            if (predictor instanceof All) {
                continue;
            }
            for (Term term : predictor.terms()) {
                if (term instanceof Variable) {
                    explicit.add(term.name());
                }
            }
        }

        List<Term> columns = new ArrayList<>();
        if (response != null) {
            columns.add(response);
        }
        for (HyperTerm predictor : predictors) {
            if (predictor instanceof Delete) {
                continue;
            }
            boolean expandsAll = predictor instanceof All;
            for (Term term : predictor.terms()) {
                if (expandsAll && explicit.contains(term.name())) {
                    continue;
                }
                columns.add(term);
            }
        }

        List<Term> deleted = new ArrayList<>();
        for (HyperTerm predictor : predictors) {
            if (predictor instanceof Delete) {
                deleted.addAll(predictor.terms());
            }
        }
        columns.removeAll(deleted);

        xy = columns.toArray(new Term[0]);
        StructField[] fields = new StructField[xy.length];
        for (int i = 0; i < xy.length; i++) {
            fields[i] = xy[i].field();
        }
        schema = DataTypes.struct(fields);

        if (response != null) {
            // The response was added first, so the predictors start at index 1.
            x = Arrays.copyOfRange(xy, 1, xy.length);
            xschema = DataTypes.struct(Arrays.copyOfRange(fields, 1, fields.length));
        } else {
            x = xy;
            xschema = schema;
        }
        return schema;
    }

    /**
     * Returns a lazy view of {@code tuple} through the response and predictor terms.
     * Binds to the tuple's schema only if the formula is not yet bound.
     */
    public Tuple apply(final Tuple tuple) {
        bind(tuple.schema(), false);
        return view(tuple, xy, schema);
    }

    /**
     * Returns a lazy view of {@code tuple} through the predictor terms only.
     * Binds to the tuple's schema only if the formula is not yet bound.
     */
    public Tuple x(final Tuple tuple) {
        bind(tuple.schema(), false);
        return view(tuple, x, xschema);
    }

    /**
     * Builds a tuple whose i-th field is {@code terms[i]} evaluated on {@code source}.
     * The terms and schema are captured at creation time, so a later re-bind of
     * the formula does not change an existing view.
     */
    private static Tuple view(final Tuple source, final Term[] terms, final StructType viewSchema) {
        return new AbstractTuple() {
            @Override
            public StructType schema() {
                return viewSchema;
            }

            @Override
            public Object get(int i) {
                return terms[i].apply(source);
            }

            @Override
            public int getInt(int i) {
                return terms[i].applyAsInt(source);
            }

            @Override
            public long getLong(int i) {
                return terms[i].applyAsLong(source);
            }

            @Override
            public float getFloat(int i) {
                return terms[i].applyAsFloat(source);
            }

            @Override
            public double getDouble(int i) {
                return terms[i].applyAsDouble(source);
            }

            @Override
            public String toString() {
                return viewSchema.toString(this);
            }
        };
    }

    /**
     * Evaluates the predictor terms on {@code tuple} as doubles.
     * Binds to the tuple's schema if the formula is not yet bound (including after
     * deserialization, when the cached terms are absent).
     */
    public double[] xarray(Tuple tuple) {
        bind(tuple.schema(), false);
        Term[] terms = x;
        double[] values = new double[terms.length];
        for (int i = 0; i < terms.length; i++) {
            values[i] = terms[i].applyAsDouble(tuple);
        }
        return values;
    }

    /** Re-binds to the data frame's schema and returns the response and predictor columns. */
    public DataFrame apply(DataFrame dataFrame) {
        bind(dataFrame.schema(), true);
        return select(dataFrame, xy);
    }

    /** Re-binds to the data frame's schema and returns the predictor columns. */
    public DataFrame x(DataFrame dataFrame) {
        bind(dataFrame.schema(), true);
        return select(dataFrame, x);
    }

    /** Evaluates each term on {@code data} and assembles the vectors into a data frame. */
    private static DataFrame select(DataFrame data, Term[] terms) {
        BaseVector[] vectors = new BaseVector[terms.length];
        for (int i = 0; i < terms.length; i++) {
            vectors[i] = terms[i].apply(data);
        }
        return DataFrame.of(vectors);
    }

    /** Returns the predictor design matrix without an intercept column. */
    public DenseMatrix matrix(DataFrame dataFrame) {
        return matrix(dataFrame, false);
    }

    /**
     * Returns the predictor design matrix.
     *
     * @param dataFrame the data; the formula is re-bound to its schema.
     * @param bias if {@code true}, a trailing column of ones is appended.
     * @throws UnsupportedOperationException if a predictor's type cannot be converted to double.
     */
    public DenseMatrix matrix(DataFrame dataFrame, boolean bias) {
        bind(dataFrame.schema(), true);
        Term[] terms = x;
        int nrows = dataFrame.nrows();
        int ncols = terms.length + (bias ? 1 : 0);
        DenseMatrix matrix = Matrix.of(nrows, ncols, 0.0d);
        if (bias) {
            for (int i = 0; i < nrows; i++) {
                matrix.set(i, ncols - 1, 1.0d);
            }
        }
        for (int j = 0; j < terms.length; j++) {
            fillColumn(matrix, j, terms[j], dataFrame, nrows);
        }
        return matrix;
    }

    /**
     * Writes {@code term} evaluated on {@code data} into column {@code j} of {@code matrix}.
     * Strings are parsed as doubles and Booleans map to 1/0; {@code null} becomes NaN.
     */
    private static void fillColumn(DenseMatrix matrix, int j, Term term, DataFrame data, int nrows) {
        BaseVector vector = term.apply(data);
        DataType type = term.type();
        switch (type.id()) {
            case Double:
            case Integer:
            case Float:
            case Long:
            case Boolean:
            case Byte:
            case Short:
            case Char:
                for (int i = 0; i < nrows; i++) {
                    matrix.set(i, j, vector.getDouble(i));
                }
                break;
            case String:
                for (int i = 0; i < nrows; i++) {
                    String text = (String) vector.get(i);
                    matrix.set(i, j, text == null ? Double.NaN : Double.parseDouble(text));
                }
                break;
            case Object: {
                Class<?> clazz = ((ObjectType) type).getObjectClass();
                if (clazz == Boolean.class) {
                    for (int i = 0; i < nrows; i++) {
                        Boolean flag = (Boolean) vector.get(i);
                        matrix.set(i, j, flag == null ? Double.NaN : (flag ? 1.0d : 0.0d));
                    }
                } else if (Number.class.isAssignableFrom(clazz)) {
                    for (int i = 0; i < nrows; i++) {
                        matrix.set(i, j, vector.getDouble(i));
                    }
                } else {
                    throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", type));
                }
                break;
            }
            default:
                throw new UnsupportedOperationException(String.format("DataFrame.toMatrix() doesn't support type %s", type));
        }
    }

    /** Binds the response to the data frame's schema and returns its column, or {@code null} if absent. */
    public BaseVector y(DataFrame dataFrame) {
        if (response == null) {
            return null;
        }
        response.bind(dataFrame.schema());
        return response.apply(dataFrame);
    }

    /** Returns the response value of {@code tuple} as a double, or {@code 0.0} if there is no response. */
    public double y(Tuple tuple) {
        if (response == null) {
            return 0.0d;
        }
        response.bind(tuple.schema());
        return response.applyAsDouble(tuple);
    }

    /** Returns the response value of {@code tuple} as an int, or {@code -1} if there is no response. */
    public int yint(Tuple tuple) {
        if (response == null) {
            return -1;
        }
        response.bind(tuple.schema());
        return response.applyAsInt(tuple);
    }
}
