From 367bbe1f3e82d48b1d5eb558346d9801f839b802 Mon Sep 17 00:00:00 2001 From: tmartins Date: Mon, 29 Aug 2022 17:05:43 -0300 Subject: [PATCH] allow id_field to be customizable when feeding a data frame --- vespa/application.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vespa/application.py b/vespa/application.py index 889e0572..5dc0c501 100644 --- a/vespa/application.py +++ b/vespa/application.py @@ -63,25 +63,26 @@ def parse_labeled_data(df): return labeled_data -def parse_feed_df(df: DataFrame, include_id): +def parse_feed_df(df: DataFrame, include_id, id_field="id"): """ Convert a df into batch format for feeding :param df: DataFrame with the following required columns ["id"]. Additional columns are assumed to be fields. :param include_id: Include id on the fields to be fed. + :param id_field: Name of the column containing the id field. :return: List of Dict containing 'id' and 'fields'. """ - required_columns = ["id"] + required_columns = [id_field] assert all( [x in list(df.columns) for x in required_columns] ), "DataFrame needs at least the following columns: {}".format(required_columns) records = df.to_dict(orient="records") batch = [ { - "id": record["id"], + "id": record[id_field], "fields": record if include_id - else {k: v for k, v in record.items() if k not in ["id"]}, + else {k: v for k, v in record.items() if k not in [id_field]}, } for record in records ] @@ -543,16 +544,17 @@ def feed_batch( ) return batch_http_responses - def feed_df(self, df: DataFrame, include_id: bool = True, **kwargs): + def feed_df(self, df: DataFrame, include_id: bool = True, id_field="id", **kwargs): """ Feed data contained in a DataFrame. :param df: A DataFrame containing a required 'id' column and the remaining fields to be fed. :param include_id: Include id on the fields to be fed. Default to True. + :param id_field: Name of the column containing the id field. :param kwargs: Additional parameters are passed to :func:`feed_batch`. :return: List of HTTP POST responses """ - batch = parse_feed_df(df=df, include_id=include_id) + batch = parse_feed_df(df=df, include_id=include_id, id_field=id_field) return self.feed_batch(batch=batch, **kwargs) def delete_data(