/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.math.MutableInnerProductModule;
import breeze.optimize.CachedDiffFunction;
import breeze.optimize.DiffFunction;
import breeze.optimize.FirstOrderMinimizer;
import breeze.optimize.LBFGS;
import breeze.optimize.StochasticDiffFunction;
import java.io.IOException;
import java.io.Serializable;
import org.apache.spark.SparkException;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleArrayParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.shared.HasAggregationDepth;
import org.apache.spark.ml.param.shared.HasFitIntercept;
import org.apache.spark.ml.param.shared.HasMaxIter;
import org.apache.spark.ml.param.shared.HasTol;
import org.apache.spark.ml.regression.AFTCostFun;
import org.apache.spark.ml.regression.AFTPoint;
import org.apache.spark.ml.regression.AFTSurvivalRegression$;
import org.apache.spark.ml.regression.AFTSurvivalRegressionModel;
import org.apache.spark.ml.regression.AFTSurvivalRegressionParams;
import org.apache.spark.ml.regression.Regressor;
import org.apache.spark.ml.stat.Summarizer$;
import org.apache.spark.ml.stat.SummarizerBuffer;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.Instrumentation$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Some;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.mutable.ArrayBuilder;
import scala.collection.mutable.ArrayBuilder$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005Mg\u0001B\u000b\u0017\u0001\u0005B\u0001b\u0010\u0001\u0003\u0006\u0004%\t\u0005\u0011\u0005\t/\u0002\u0011\t\u0011)A\u0005\u0003\")\u0011\f\u0001C\u00015\")\u0011\f\u0001C\u0001=\")\u0001\r\u0001C\u0001C\")a\r\u0001C\u0001O\")\u0011\u000f\u0001C\u0001e\")Q\u000f\u0001C\u0001m\")A\u0010\u0001C\u0001{\"9\u0011q\u0001\u0001\u0005\u0002\u0005%\u0001bBA\b\u0001\u0011\u0005\u0011\u0011\u0003\u0005\t\u00037\u0001A\u0011\u0003\r\u0002\u001e!9\u0011\u0011\f\u0001\u0005R\u0005m\u0003bBA5\u0001\u0011\u0005\u00131\u000e\u0005\b\u0003\u007f\u0002A\u0011IAA\u000f\u001d\t9J\u0006E\u0001\u000333a!\u0006\f\t\u0002\u0005m\u0005BB-\u0012\t\u0003\ty\u000bC\u0004\u00022F!\t%a-\t\u0013\u0005m\u0016#!A\u0005\n\u0005u&!F!G)N+(O^5wC2\u0014Vm\u001a:fgNLwN\u001c\u0006\u0003/a\t!B]3he\u0016\u001c8/[8o\u0015\tI\"$\u0001\u0002nY*\u00111\u0004H\u0001\u0006gB\f'o\u001b\u0006\u0003;y\ta!\u00199bG\",'\"A\u0010\u0002\u0007=\u0014xm\u0001\u0001\u0014\u000b\u0001\u0011\u0003gM\u001d\u0011\u000b\r\"c\u0005L\u0017\u000e\u0003YI!!\n\f\u0003\u0013I+wM]3tg>\u0014\bCA\u0014+\u001b\u0005A#BA\u0015\u0019\u0003\u0019a\u0017N\\1mO&\u00111\u0006\u000b\u0002\u0007-\u0016\u001cGo\u001c:\u0011\u0005\r\u0002\u0001CA\u0012/\u0013\tycC\u0001\u000eB\rR\u001bVO\u001d<jm\u0006d'+Z4sKN\u001c\u0018n\u001c8N_\u0012,G\u000e\u0005\u0002$c%\u0011!G\u0006\u0002\u001c\u0003\u001a#6+\u001e:wSZ\fGNU3he\u0016\u001c8/[8o!\u0006\u0014\u0018-\\:\u0011\u0005Q:T\"A\u001b\u000b\u0005YB\u0012\u0001B;uS2L!\u0001O\u001b\u0003+\u0011+g-Y;miB\u000b'/Y7t/JLG/\u00192mKB\u0011!(P\u0007\u0002w)\u0011AHG\u0001\tS:$XM\u001d8bY&\u0011ah\u000f\u0002\b\u0019><w-\u001b8h\u0003\r)\u0018\u000eZ\u000b\u0002\u0003B\u0011!i\u0013\b\u0003\u0007&\u0003\"\u0001R$\u000e\u0003\u0015S!A\u0012\u0011\u0002\rq\u0012xn\u001c;?\u0015\u0005A\u0015!B:dC2\f\u0017B\u0001&H\u0003\u0019\u0001&/\u001a3fM&\u0011A*\u0014\u0002\u0007'R\u0014\u0018N\\4\u000b\u0005);\u0005fA\u0001P+B\u0011\u0001kU\u0007\u0002#*\u0011!KG\u0001\u000bC:tw\u00
0e^1uS>t\u0017B\u0001+R\u0005\u0015\u0019\u0016N\\2fC\u00051\u0016!B\u0019/m9\u0002\u0014\u0001B;jI\u0002B3AA(V\u0003\u0019a\u0014N\\5u}Q\u0011Af\u0017\u0005\u0006\u007f\r\u0001\r!\u0011\u0015\u00047>+\u0006fA\u0002P+R\tA\u0006K\u0002\u0005\u001fV\u000bAb]3u\u0007\u0016t7o\u001c:D_2$\"AY2\u000e\u0003\u0001AQ\u0001Z\u0003A\u0002\u0005\u000bQA^1mk\u0016D3!B(V\u0003a\u0019X\r^)vC:$\u0018\u000e\\3Qe>\u0014\u0017MY5mSRLWm\u001d\u000b\u0003E\"DQ\u0001\u001a\u0004A\u0002%\u00042A[6n\u001b\u00059\u0015B\u00017H\u0005\u0015\t%O]1z!\tQg.\u0003\u0002p\u000f\n1Ai\\;cY\u0016D3AB(V\u0003=\u0019X\r^)vC:$\u0018\u000e\\3t\u0007>dGC\u00012t\u0011\u0015!w\u00011\u0001BQ\r9q*V\u0001\u0010g\u0016$h)\u001b;J]R,'oY3qiR\u0011!m\u001e\u0005\u0006I\"\u0001\r\u0001\u001f\t\u0003UfL!A_$\u0003\u000f\t{w\u000e\\3b]\"\u001a\u0001bT+\u0002\u0015M,G/T1y\u0013R,'\u000f\u0006\u0002c}\")A-\u0003a\u0001\u007fB\u0019!.!\u0001\n\u0007\u0005\rqIA\u0002J]RD3!C(V\u0003\u0019\u0019X\r\u001e+pYR\u0019!-a\u0003\t\u000b\u0011T\u0001\u0019A7)\u0007)yU+A\ntKR\fum\u001a:fO\u0006$\u0018n\u001c8EKB$\b\u000eF\u0002c\u0003'AQ\u0001Z\u0006A\u0002}DCaC(\u0002\u0018\u0005\u0012\u0011\u0011D\u0001\u0006e9\nd\u0006M\u0001\u0011Kb$(/Y2u\u0003\u001a#\u0006k\\5oiN$B!a\b\u00022A1\u0011\u0011EA\u0014\u0003Wi!!a\t\u000b\u0007\u0005\u0015\"$A\u0002sI\u0012LA!!\u000b\u0002$\t\u0019!\u000b\u0012#\u0011\u0007\r\ni#C\u0002\u00020Y\u0011\u0001\"\u0011$U!>Lg\u000e\u001e\u0005\b\u0003ga\u0001\u0019AA\u001b\u0003\u001d!\u0017\r^1tKR\u0004D!a\u000e\u0002HA1\u0011\u0011HA 
\u0003\u0007j!!a\u000f\u000b\u0007\u0005u\"$A\u0002tc2LA!!\u0011\u0002<\t9A)\u0019;bg\u0016$\b\u0003BA#\u0003\u000fb\u0001\u0001\u0002\u0007\u0002J\u0005E\u0012\u0011!A\u0001\u0006\u0003\tYEA\u0002`IE\nB!!\u0014\u0002TA\u0019!.a\u0014\n\u0007\u0005EsIA\u0004O_RD\u0017N\\4\u0011\u0007)\f)&C\u0002\u0002X\u001d\u00131!\u00118z\u0003\u0015!(/Y5o)\ri\u0013Q\f\u0005\b\u0003gi\u0001\u0019AA0a\u0011\t\t'!\u001a\u0011\r\u0005e\u0012qHA2!\u0011\t)%!\u001a\u0005\u0019\u0005\u001d\u0014QLA\u0001\u0002\u0003\u0015\t!a\u0013\u0003\u0007}##'A\bue\u0006t7OZ8s[N\u001b\u0007.Z7b)\u0011\ti'!\u001f\u0011\t\u0005=\u0014QO\u0007\u0003\u0003cRA!a\u001d\u0002<\u0005)A/\u001f9fg&!\u0011qOA9\u0005)\u0019FO];diRK\b/\u001a\u0005\b\u0003wr\u0001\u0019AA7\u0003\u0019\u00198\r[3nC\"\u001aabT+\u0002\t\r|\u0007/\u001f\u000b\u0004Y\u0005\r\u0005bBAC\u001f\u0001\u0007\u0011qQ\u0001\u0006Kb$(/\u0019\t\u0005\u0003\u0013\u000by)\u0004\u0002\u0002\f*\u0019\u0011Q\u0012\r\u0002\u000bA\f'/Y7\n\t\u0005E\u00151\u0012\u0002\t!\u0006\u0014\u0018-\\'ba\"\u001aqbT+)\u0007\u0001yU+A\u000bB\rR\u001bVO\u001d<jm\u0006d'+Z4sKN\u001c\u0018n\u001c8\u0011\u0005\r\n2cB\t\u0002\u001e\u0006\r\u0016\u0011\u0016\t\u0004U\u0006}\u0015bAAQ\u000f\n1\u0011I\\=SK\u001a\u0004B\u0001NASY%\u0019\u0011qU\u001b\u0003+\u0011+g-Y;miB\u000b'/Y7t%\u0016\fG-\u00192mKB\u0019!.a+\n\u0007\u00055vI\u0001\u0007TKJL\u0017\r\\5{C\ndW\r\u0006\u0002\u0002\u001a\u0006!An\\1e)\ra\u0013Q\u0017\u0005\u0007\u0003o\u001b\u0002\u0019A!\u0002\tA\fG\u000f\u001b\u0015\u0004'=+\u0016a\u0003:fC\u0012\u0014Vm]8mm\u0016$\"!a0\u0011\t\u0005\u0005\u00171Z\u0007\u0003\u0003\u0007TA!!2\u0002H\u0006!A.\u00198h\u0015\t\tI-\u0001\u0003kCZ\f\u0017\u0002BAg\u0003\u0007\u0014aa\u00142kK\u000e$\bfA\tP+\"\u001a\u0001cT+")
/*
 * Accelerated failure time (AFT) survival regression estimator.
 *
 * NOTE(review): this is CFR output decompiled from a Scala class, not hand-written Java.
 * Several constructs are decompilation artifacts that do not compile as plain Java (for
 * example, the private final fields below are reassigned inside the synthetic
 * "_setter_$..._$eq" methods). Treat the original Scala source as authoritative.
 *
 * Contract visible in this class:
 *  - the features column holds Vector values; the label and censor columns are cast to double;
 *  - train(...) optimizes the AFT cost with L-BFGS and returns an AFTSurvivalRegressionModel
 *    carrying coefficients, an intercept and a scale;
 *  - defaults set in the constructor: fitIntercept=true, maxIter=100, tol=1e-6,
 *    aggregationDepth=2.
 */
public class AFTSurvivalRegression
extends Regressor<Vector, AFTSurvivalRegression, AFTSurvivalRegressionModel>
implements AFTSurvivalRegressionParams,
DefaultParamsWritable {
    // Unique identifier of this estimator instance (also passed on to the fitted model).
    private final String uid;
    // Param fields declared by the mixed-in Scala traits. They are presumably populated by the
    // trait $init$ calls in the constructor through the synthetic *_setter_$*_$eq methods below.
    private final Param<String> censorCol;
    private final DoubleArrayParam quantileProbabilities;
    private final Param<String> quantilesCol;
    private final IntParam aggregationDepth;
    private final BooleanParam fitIntercept;
    private final DoubleParam tol;
    private final IntParam maxIter;

    /** Loads a saved estimator from {@code string} (a path) via the Scala companion object. */
    public static AFTSurvivalRegression load(String string) {
        return AFTSurvivalRegression$.MODULE$.load(string);
    }

    /** Returns the reader for this estimator type, obtained from the Scala companion object. */
    public static MLReader<AFTSurvivalRegression> read() {
        return AFTSurvivalRegression$.MODULE$.read();
    }

    /** Returns the writer provided by the {@code DefaultParamsWritable} trait. */
    @Override
    public MLWriter write() {
        return DefaultParamsWritable.write$(this);
    }

    /** Saves this estimator to {@code path} using the {@code MLWritable} trait implementation. */
    @Override
    public void save(String path) throws IOException {
        MLWritable.save$(this, path);
    }

    // ---- Param getters: each delegates to the static forwarder of the defining Scala trait. ----

    @Override
    public String getCensorCol() {
        return AFTSurvivalRegressionParams.getCensorCol$(this);
    }

    @Override
    public double[] getQuantileProbabilities() {
        return AFTSurvivalRegressionParams.getQuantileProbabilities$(this);
    }

    @Override
    public String getQuantilesCol() {
        return AFTSurvivalRegressionParams.getQuantilesCol$(this);
    }

    @Override
    public boolean hasQuantilesCol() {
        return AFTSurvivalRegressionParams.hasQuantilesCol$(this);
    }

    /**
     * Validates {@code schema} and returns the transformed schema, delegating to
     * {@code AFTSurvivalRegressionParams}.
     *
     * @param schema  input schema
     * @param fitting whether the check is for training (true) or for transformation
     */
    @Override
    public StructType validateAndTransformSchema(StructType schema, boolean fitting) {
        return AFTSurvivalRegressionParams.validateAndTransformSchema$(this, schema, fitting);
    }

    @Override
    public final int getAggregationDepth() {
        return HasAggregationDepth.getAggregationDepth$(this);
    }

    @Override
    public final boolean getFitIntercept() {
        return HasFitIntercept.getFitIntercept$(this);
    }

    @Override
    public final double getTol() {
        return HasTol.getTol$(this);
    }

    @Override
    public final int getMaxIter() {
        return HasMaxIter.getMaxIter$(this);
    }

    // ---- Param accessors and synthetic trait setters (Scala trait field encoding). ----

    @Override
    public final Param<String> censorCol() {
        return this.censorCol;
    }

    @Override
    public final DoubleArrayParam quantileProbabilities() {
        return this.quantileProbabilities;
    }

    @Override
    public final Param<String> quantilesCol() {
        return this.quantilesCol;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$censorCol_$eq(Param<String> x$1) {
        this.censorCol = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantileProbabilities_$eq(DoubleArrayParam x$1) {
        this.quantileProbabilities = x$1;
    }

    @Override
    public final void org$apache$spark$ml$regression$AFTSurvivalRegressionParams$_setter_$quantilesCol_$eq(Param<String> x$1) {
        this.quantilesCol = x$1;
    }

    @Override
    public final IntParam aggregationDepth() {
        return this.aggregationDepth;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_$eq(IntParam x$1) {
        this.aggregationDepth = x$1;
    }

    @Override
    public final BooleanParam fitIntercept() {
        return this.fitIntercept;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam x$1) {
        this.fitIntercept = x$1;
    }

    @Override
    public final DoubleParam tol() {
        return this.tol;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam x$1) {
        this.tol = x$1;
    }

    @Override
    public final IntParam maxIter() {
        return this.maxIter;
    }

    @Override
    public final void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam x$1) {
        this.maxIter = x$1;
    }

    @Override
    public String uid() {
        return this.uid;
    }

    // ---- Fluent setters: each sets the param on this instance and returns {@code this}. ----
    // Primitive values are boxed with BoxesRunTime because Param values are stored as objects.

    public AFTSurvivalRegression setCensorCol(String value) {
        return (AFTSurvivalRegression)this.set(this.censorCol(), value);
    }

    public AFTSurvivalRegression setQuantileProbabilities(double[] value) {
        return (AFTSurvivalRegression)this.set(this.quantileProbabilities(), value);
    }

    public AFTSurvivalRegression setQuantilesCol(String value) {
        return (AFTSurvivalRegression)this.set(this.quantilesCol(), value);
    }

    public AFTSurvivalRegression setFitIntercept(boolean value) {
        return (AFTSurvivalRegression)this.set(this.fitIntercept(), BoxesRunTime.boxToBoolean((boolean)value));
    }

    public AFTSurvivalRegression setMaxIter(int value) {
        return (AFTSurvivalRegression)this.set(this.maxIter(), BoxesRunTime.boxToInteger((int)value));
    }

    public AFTSurvivalRegression setTol(double value) {
        return (AFTSurvivalRegression)this.set(this.tol(), BoxesRunTime.boxToDouble((double)value));
    }

    public AFTSurvivalRegression setAggregationDepth(int value) {
        return (AFTSurvivalRegression)this.set(this.aggregationDepth(), BoxesRunTime.boxToInteger((int)value));
    }

    /**
     * Converts {@code dataset} into an RDD of {@link AFTPoint}.
     *
     * <p>Selects exactly three columns: features, label (cast to double) and censor (cast to
     * double), in that order. Each row becomes {@code new AFTPoint(features, label, censor)}.
     *
     * <p>The labelled blocks below are the decompiled form of a Scala pattern match. Any row that
     * does not match (not exactly 3 fields, features not a {@code Vector}, or label/censor not a
     * {@code Double}) raises {@code scala.MatchError}. A null label or censor is not
     * {@code instanceof Double}, so nulls also end up in the MatchError branch.
     *
     * @param dataset input data containing the features, label and censor columns
     * @return lazily-evaluated RDD of AFT points
     */
    public RDD<AFTPoint> extractAFTPoints(Dataset<?> dataset) {
        return dataset.select((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Column[]{functions$.MODULE$.col(this.$(this.featuresCol())), functions$.MODULE$.col(this.$(this.labelCol())).cast((DataType)DoubleType$.MODULE$), functions$.MODULE$.col(this.$(this.censorCol())).cast((DataType)DoubleType$.MODULE$)})).rdd().map((Function1 & Serializable & scala.Serializable)x0$1 -> {
            double d;
            Vector vector;
            Object censor;
            block3: {
                Row row;
                block2: {
                    row = x0$1;
                    Some some = Row$.MODULE$.unapplySeq(row);
                    if (some.isEmpty() || some.get() == null || ((SeqLike)some.get()).lengthCompare(3) != 0) break block2;
                    Object features = ((SeqLike)some.get()).apply(0);
                    Object label = ((SeqLike)some.get()).apply(1);
                    censor = ((SeqLike)some.get()).apply(2);
                    if (!(features instanceof Vector)) break block2;
                    vector = (Vector)features;
                    if (!(label instanceof Double)) break block2;
                    d = BoxesRunTime.unboxToDouble((Object)label);
                    // Only a fully matching row reaches block3's exit; everything else falls through to the throw.
                    if (censor instanceof Double) break block3;
                }
                throw new MatchError((Object)row);
            }
            double d2 = BoxesRunTime.unboxToDouble((Object)censor);
            AFTPoint aFTPoint = new AFTPoint(vector, d, d2);
            return aFTPoint;
        }, ClassTag$.MODULE$.apply(AFTPoint.class));
    }

    /**
     * Fits an AFT survival regression model on {@code dataset} inside an instrumentation scope.
     *
     * <p>Steps:
     * <ol>
     *   <li>extract AFT points and persist them (MEMORY_AND_DISK) only if the dataset itself is
     *       not persisted;</li>
     *   <li>compute per-feature mean/std/count with a tree aggregation;</li>
     *   <li>log params and dataset info, warning about constant nonzero columns without an
     *       intercept;</li>
     *   <li>minimize {@code AFTCostFun} with breeze L-BFGS starting from a zero vector of length
     *       numFeatures + 2;</li>
     *   <li>map the optimum back to (scale, intercept, coefficients) on the original feature
     *       scale.</li>
     * </ol>
     *
     * <p>Parameter vector layout (from the unpacking at the end): index 0 is log(scale), index 1 is
     * the intercept, and indices 2.. are coefficients for standardized features. Those
     * coefficients are divided by the feature std; features with zero std get coefficient 0.
     *
     * <p>NOTE(review): the cleanup is not in a {@code finally}. If the optimizer yields no state
     * (SparkException below) or throws during iteration, {@code instances} stays persisted and
     * {@code bcFeaturesStd} is never destroyed.
     *
     * @param dataset training data with features, label and censor columns
     * @return the fitted model
     * @throws SparkException if the optimizer produces no iteration state
     */
    @Override
    public AFTSurvivalRegressionModel train(Dataset<?> dataset) {
        return (AFTSurvivalRegressionModel)Instrumentation$.MODULE$.instrumented((Function1 & Serializable & scala.Serializable)instr -> {
            RDD<AFTPoint> instances = this.extractAFTPoints(dataset);
            StorageLevel storageLevel = dataset.storageLevel();
            StorageLevel storageLevel2 = StorageLevel$.MODULE$.NONE();
            // Decompiled null-safe equality: true iff dataset.storageLevel() equals StorageLevel.NONE,
            // i.e. the caller has not cached the input, so this method caches the derived RDD itself.
            boolean handlePersistence = !(storageLevel != null ? !storageLevel.equals(storageLevel2) : storageLevel2 != null);
            // Decompiler artifact: the result of this conditional is never read.
            Object object = handlePersistence ? instances.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK()) : BoxedUnit.UNIT;
            // Single distributed pass collecting per-feature mean and std plus the row count.
            SummarizerBuffer featuresSummarizer = (SummarizerBuffer)instances.treeAggregate((Object)Summarizer$.MODULE$.createSummarizerBuffer((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"mean", "std", "count"})), (Function2 & Serializable & scala.Serializable)(c, v) -> c.add(v.features()), (Function2 & Serializable & scala.Serializable)(c1, c2) -> c1.merge((SummarizerBuffer)c2), BoxesRunTime.unboxToInt((Object)this.$(this.aggregationDepth())), ClassTag$.MODULE$.apply(SummarizerBuffer.class));
            double[] featuresStd = featuresSummarizer.std().toArray();
            int numFeatures = new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(featuresStd)).size();
            instr.logPipelineStage(this);
            instr.logDataset(dataset);
            instr.logParams(this, (Seq<Param<?>>)Predef$.MODULE$.wrapRefArray((Object[])new Param[]{this.labelCol(), this.featuresCol(), this.censorCol(), this.predictionCol(), this.quantilesCol(), this.fitIntercept(), this.maxIter(), this.tol(), this.aggregationDepth()}));
            instr.logNamedValue("quantileProbabilities.size", this.$(this.quantileProbabilities()).length);
            instr.logNumFeatures(numFeatures);
            instr.logNumExamples(featuresSummarizer.count());
            // Without an intercept, a constant nonzero column (std == 0, mean != 0) receives coefficient 0
            // in the rescaling step below, so warn the user about it.
            if (!BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())) && RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), numFeatures).exists((Function1)(JFunction1.mcZI.sp & Serializable & scala.Serializable)i -> featuresStd[i] == 0.0 && featuresSummarizer.mean().apply(i) != 0.0)) {
                instr.logWarning((Function0<String>)(Function0 & Serializable & scala.Serializable)() -> "Fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg.");
            }
            // Ship the per-feature std to executors for use by the cost function.
            Broadcast bcFeaturesStd = instances.context().broadcast((Object)featuresStd, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE)));
            AFTCostFun costFun = new AFTCostFun(instances, BoxesRunTime.unboxToBoolean((Object)this.$(this.fitIntercept())), (Broadcast<double[]>)bcFeaturesStd, BoxesRunTime.unboxToInt((Object)this.$(this.aggregationDepth())));
            // Arguments: maxIter, 10 (presumably the L-BFGS history size m; confirm against breeze), tol, vector space.
            LBFGS optimizer = new LBFGS(BoxesRunTime.unboxToInt((Object)this.$(this.maxIter())), 10, BoxesRunTime.unboxToDouble((Object)this.$(this.tol())), (MutableInnerProductModule)DenseVector$.MODULE$.space_Double());
            // Two extra slots: [0] log(scale), [1] intercept (see unpacking below).
            Vector initialParameters = Vectors$.MODULE$.zeros(numFeatures + 2);
            Iterator states = optimizer.iterations((StochasticDiffFunction)new CachedDiffFunction((DiffFunction)costFun, DenseVector$.MODULE$.canCopyDenseVector(ClassTag$.MODULE$.Double())), (Object)initialParameters.asBreeze().toDenseVector$mcD$sp(ClassTag$.MODULE$.Double()));
            // Collects the per-iteration adjusted loss. NOTE(review): never read after the loop in this method.
            ArrayBuilder arrayBuilder = ArrayBuilder$.MODULE$.make(ClassTag$.MODULE$.Double());
            FirstOrderMinimizer.State state = null;
            // Drain the optimizer; the last state holds the solution.
            while (states.hasNext()) {
                state = (FirstOrderMinimizer.State)states.next();
                arrayBuilder.$plus$eq((Object)BoxesRunTime.boxToDouble((double)state.adjustedValue()));
            }
            if (state == null) {
                String msg = new StringBuilder(8).append(optimizer.getClass().getName()).append(" failed.").toString();
                throw new SparkException(msg);
            }
            double[] parameters = (double[])((DenseVector)state.x()).toArray$mcD$sp(ClassTag$.MODULE$.Double()).clone();
            bcFeaturesStd.destroy();
            // Decompiler artifact: the result of this conditional is never read.
            Object object2 = handlePersistence ? instances.unpersist(instances.unpersist$default$1()) : BoxedUnit.UNIT;
            double[] rawCoefficients = (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(parameters)).slice(2, parameters.length);
            // Convert coefficients from the standardized feature space back to the original scale;
            // zero-variance features get coefficient 0.
            for (int i2 = 0; i2 < numFeatures; ++i2) {
                int n = i2;
                rawCoefficients[n] = rawCoefficients[n] * (featuresStd[i2] != 0.0 ? 1.0 / featuresStd[i2] : 0.0);
            }
            Vector coefficients = Vectors$.MODULE$.dense(rawCoefficients);
            double intercept = parameters[1];
            // parameters[0] is optimized in log space, so exponentiate to recover the scale.
            double scale = package$.MODULE$.exp(parameters[0]);
            return new AFTSurvivalRegressionModel(this.uid(), coefficients, intercept, scale);
        });
    }

    /** Validates {@code schema} for fitting and returns the resulting output schema. */
    @Override
    public StructType transformSchema(StructType schema) {
        return this.validateAndTransformSchema(schema, true);
    }

    /** Returns a copy of this estimator with {@code extra} params applied, via {@code defaultCopy}. */
    @Override
    public AFTSurvivalRegression copy(ParamMap extra) {
        return (AFTSurvivalRegression)this.defaultCopy(extra);
    }

    /**
     * Creates an estimator with the given uid.
     *
     * <p>The trait {@code $init$} calls reproduce the Scala mixin initialization order and must run
     * before the {@code setDefault} calls, which read the param fields. Defaults set here:
     * fitIntercept=true, maxIter=100, tol=1e-6, aggregationDepth=2.
     *
     * @param uid unique identifier for this estimator
     */
    public AFTSurvivalRegression(String uid) {
        this.uid = uid;
        HasMaxIter.$init$(this);
        HasTol.$init$(this);
        HasFitIntercept.$init$(this);
        HasAggregationDepth.$init$(this);
        AFTSurvivalRegressionParams.$init$(this);
        MLWritable.$init$(this);
        DefaultParamsWritable.$init$(this);
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.fitIntercept().$minus$greater(BoxesRunTime.boxToBoolean((boolean)true))}));
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.maxIter().$minus$greater(BoxesRunTime.boxToInteger((int)100))}));
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.tol().$minus$greater(BoxesRunTime.boxToDouble((double)1.0E-6))}));
        this.setDefault((Seq<ParamPair<?>>)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.aggregationDepth().$minus$greater(BoxesRunTime.boxToInteger((int)2))}));
    }

    /** Creates an estimator with a random uid prefixed with "aftSurvReg". */
    public AFTSurvivalRegression() {
        this(Identifiable$.MODULE$.randomUID("aftSurvReg"));
    }
}

