MetNet-3: A state-of-the-art neural weather model available in Google products

Posted by Samier Merchant, Google Research, and Nal Kalchbrenner, Google DeepMind

Forecasting weather variables such as precipitation, temperature, and wind is key to numerous aspects of society, from daily planning and transportation to energy production. As we continue to see more extreme weather events such as floods, droughts, and heat waves, accurate forecasts can be essential to preparing for and mitigating their effects. The first 24 hours into the future are especially important as they are both highly predictable and actionable, which can help people make informed decisions in a timely manner and stay safe.

Today we present a new weather model called MetNet-3, developed by Google Research and Google DeepMind. Building on the earlier MetNet and MetNet-2 models, MetNet-3 provides high-resolution predictions up to 24 hours ahead for a larger set of core variables, including precipitation, surface temperature, wind speed and direction, and dew point. MetNet-3 creates a temporally smooth and highly granular forecast, with lead time intervals of 2 minutes and spatial resolutions of 1 to 4 kilometers. MetNet-3 achieves strong performance compared to traditional methods, outperforming the best single- and multi-member physics-based numerical weather prediction (NWP) models, such as High-Resolution Rapid Refresh (HRRR) and the ensemble forecast suite (ENS), for multiple regions up to 24 hours ahead.

Finally, we’ve integrated MetNet-3’s capabilities across various Google products and technologies where weather is relevant. Currently available in the contiguous United States and parts of Europe with a focus on 12-hour precipitation forecasts, MetNet-3 is helping bring accurate and reliable weather information to people in multiple countries and languages.

MetNet-3 precipitation output summarized into actionable forecasts in Google Search on mobile.

Densification of sparse observations

Many recent machine learning weather models use the atmospheric state generated by traditional methods (e.g., data assimilation from NWPs) as the primary starting point to build forecasts. In contrast, a defining feature of the MetNet models has been to use direct observations of the atmosphere for training and evaluation. The advantage of direct observations is that they often have higher fidelity and resolution. However, direct observations come from a large variety of sensors at different altitudes, including weather stations at the surface level and satellites in orbit, and can be of varying degrees of sparsity. For example, precipitation estimates derived from radar such as NOAA’s Multi-Radar/Multi-Sensor System (MRMS) are relatively dense images, whereas weather stations located on the ground that provide measurements for variables such as temperature and wind are mere points spread over a region.

In addition to the data sources used in previous MetNet models, MetNet-3 includes point measurements from weather stations as both inputs and targets with the goal of making a forecast at all locations. To this end, MetNet-3’s key innovation is a technique called densification, which merges the traditional two-step process of data assimilation and simulation found in physics-based models into a single pass through the neural network. The main components of densification are illustrated below. Although the densification technique applies to a specific stream of data individually, the resulting densified forecast benefits from all the other input streams that go into MetNet-3, including topographical, satellite, radar, and NWP analysis features. No NWP forecasts are included in MetNet-3’s default inputs.

A) During training, a fraction of the weather stations are masked out from the input while kept in the target. B) To evaluate generalization to untrained locations, a set of weather stations (represented by squares) is never used for training and is used only for evaluation. C) Held-out weather stations in sparsely covered areas are included during evaluation to determine prediction quality in those areas. D) The final forecasts use the full set of training weather stations as input and produce fully dense forecasts aided by spatial parameter sharing.
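The masking scheme in panel A can be illustrated with a small sketch. It assumes station observations have already been placed on a regular grid (NaN where there is no station) and that a random fraction of stations is hidden from the input but kept in the target. The function name `make_densification_example`, the 25% default mask rate, and the value/presence channel layout are illustrative assumptions, not MetNet-3’s actual data pipeline.

```python
import numpy as np


def make_densification_example(station_grid, mask_fraction=0.25, rng=None):
    """Build one (input, target) pair from a sparse station grid (hypothetical helper).

    station_grid: 2D array with station values at station pixels, NaN elsewhere.
    A random subset of stations is removed from the input but kept in the target,
    so a model trained on these pairs must infer values where it has no input.
    """
    rng = np.random.default_rng() if rng is None else rng
    station_pixels = np.argwhere(~np.isnan(station_grid))  # (num_stations, 2)
    num_masked = int(round(mask_fraction * len(station_pixels)))
    masked_idx = rng.choice(len(station_pixels), size=num_masked, replace=False)

    model_input = station_grid.copy()
    for row, col in station_pixels[masked_idx]:
        model_input[row, col] = np.nan  # hidden from the input...
    target = station_grid.copy()  # ...but still present in the target

    # NaN cannot be fed to a network, so split into value and presence channels.
    input_values = np.nan_to_num(model_input, nan=0.0)
    input_present = (~np.isnan(model_input)).astype(np.float32)
    return np.stack([input_values, input_present]), target
```

In this sketch, a training loss would be evaluated at every station present in `target`, including the masked ones; those masked stations are what push the model to produce values at locations it does not observe.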

High resolution in space and time

A central advantage of using direct observations is their high spatial and temporal resolution. For example, weather stations and ground radar stations provide measurements every few minutes at specific points and at 1 km resolutions, respectively; this is in stark contrast with the assimilation state from the state-of-the-art model ENS, which is generated every 6 hours at a resolution of 9 km with hour-by-hour forecasts. To handle such a high resolution, MetNet-3 preserves another of the defining features of this series of models, lead time conditioning. The lead time of the forecast in minutes is directly given as input to the neural network. This allows MetNet-3 to efficiently model the high temporal frequency of the observations for intervals as brief as 2 minutes. Densification combined with lead time conditioning and high resolution direct observations produces a fully dense 24 hour forecast with a temporal resolution of 2 minutes, while learning from just 1,000 points from the One Minute Observation (OMO) network of weather stations spread across the United States.
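One way to implement lead time conditioning is to encode the lead time and use it to modulate the network’s hidden features. The sketch below assumes a one-hot encoding over 2-minute steps (720 bins for 24 hours) and a FiLM-style per-channel scale and shift; the bin layout, the linear projections `w_scale` and `w_shift`, and the function names are illustrative assumptions rather than MetNet-3’s exact architecture.

```python
import numpy as np

MINUTES_PER_STEP = 2
MAX_LEAD_MINUTES = 24 * 60
NUM_LEAD_BINS = MAX_LEAD_MINUTES // MINUTES_PER_STEP  # 720 bins


def encode_lead_time(lead_minutes):
    """One-hot encode a lead time in 2-minute steps (bin 0 = 2 minutes ahead)."""
    if not (MINUTES_PER_STEP <= lead_minutes <= MAX_LEAD_MINUTES) or lead_minutes % MINUTES_PER_STEP:
        raise ValueError(f"unsupported lead time: {lead_minutes} min")
    one_hot = np.zeros(NUM_LEAD_BINS, dtype=np.float32)
    one_hot[lead_minutes // MINUTES_PER_STEP - 1] = 1.0
    return one_hot


def condition_features(features, lead_minutes, w_scale, w_shift):
    """Apply a lead-time-dependent scale and shift to each feature channel.

    features: (channels, height, width); w_scale, w_shift: (NUM_LEAD_BINS, channels).
    """
    code = encode_lead_time(lead_minutes)
    scale = code @ w_scale  # (channels,)
    shift = code @ w_shift  # (channels,)
    return features * (1.0 + scale)[:, None, None] + shift[:, None, None]
```

Because the lead time is just another input in this sketch, a single set of weights can be queried for any of the 720 lead times instead of training one model per horizon.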

MetNet-3 predicts a marginal multinomial probability distribution for each output variable and each location, which provides rich information beyond just the mean. This allows us to compare the probabilistic outputs of MetNet-3 with the outputs of advanced probabilistic ensemble NWP models, including the ensemble forecast ENS from the European Centre for Medium-Range Weather Forecasts (ECMWF) and the High Resolution Ensemble Forecast (HREF) from the US National Oceanic and Atmospheric Administration (NOAA). Because both MetNet-3 and these ensembles produce probabilistic outputs, we can compute scores such as the Continuous Ranked Probability Score (CRPS). The following graphics highlight densification results and illustrate that MetNet-3’s forecasts are not only of much higher resolution, but are also more accurate when evaluated at the overlapping lead times.

Top: MetNet-3’s forecast of wind speed at 2-minute intervals over the next 24 hours, with a spatial resolution of 4 km. Bottom: ENS’s hourly forecast with a spatial resolution of 18 km.
The two distinct regimes in spatial structure are primarily driven by the Colorado mountain ranges. Darker corresponds to higher wind speed.
Performance comparison between MetNet-3 and the NWP baseline for wind speed based on CRPS (lower is better). In the hyperlocal setting, values from the test weather stations are given as input to the network during evaluation; the results improve further, especially at early lead times.
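The CRPS comparisons above score forecasts expressed as probabilities over discrete bins. For such a forecast, CRPS can be approximated by comparing the predicted cumulative distribution with the step function at the observed value, weighting each squared difference by the bin width. The sketch below is one such discretization; the bin-edge convention and the name `crps_binned` are assumptions for illustration, not the evaluation code behind the reported numbers.

```python
import numpy as np


def crps_binned(probs, bin_edges, observation):
    """Discretized CRPS for a forecast given as probabilities over bins.

    probs: (K,) non-negative, summing to 1; bin k covers [bin_edges[k], bin_edges[k+1]).
    The predictive CDF is evaluated at each bin's upper edge and compared with the
    step function 1{observation <= x}; squared gaps are weighted by bin width.
    """
    probs = np.asarray(probs, dtype=np.float64)
    bin_edges = np.asarray(bin_edges, dtype=np.float64)
    cdf = np.cumsum(probs)  # F(upper edge of each bin)
    upper = bin_edges[1:]
    obs_step = (upper >= observation).astype(np.float64)
    widths = np.diff(bin_edges)
    return float(np.sum((cdf - obs_step) ** 2 * widths))
```

Lower is better: a forecast that places all of its mass in the bin containing the observation scores 0, while mass placed far from the observation is penalized in proportion to the distance.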

In contrast to weather station variables, precipitation estimates are denser because they come from ground radar. MetNet-3’s modeling of precipitation is similar to that of MetNet and MetNet-2, but it extends the high-resolution precipitation forecasts, with 1 km spatial granularity, to the same 24 hours of lead time as the other variables, as shown in the animation below. On precipitation, MetNet-3 achieves a better (lower) CRPS than ENS throughout the 24-hour range.

Case study for Thu Jan 17 2019 00:00 UTC showing the probability that the instantaneous precipitation rate exceeds 1 mm/h over the contiguous US (CONUS). Darker corresponds to a higher probability value. The dark blue contours show the prediction threshold obtained when optimizing for the Critical Success Index (CSI). This case study shows the formation of a new large precipitation pattern in the central US, so the model is doing more than forecasting the evolution of existing patterns.
Top: ENS’s hourly forecast. Center: Ground truth, source NOAA’s MRMS. Bottom: Probability map as predicted by MetNet-3.
Performance comparison between MetNet-3 and NWP baseline for instantaneous precipitation rate on CRPS (lower is better).
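The CSI contours in the case study turn a probability map into a yes/no forecast. For an event such as “rate above 1 mm/h”, CSI is hits / (hits + misses + false alarms). A minimal sketch of computing CSI and choosing the probability threshold that maximizes it over historical cases follows; the candidate grid and function names are hypothetical, and the post does not describe how its threshold was actually selected.

```python
import numpy as np


def critical_success_index(predicted, observed):
    """CSI = hits / (hits + misses + false alarms) for boolean event maps."""
    predicted = np.asarray(predicted, dtype=bool)
    observed = np.asarray(observed, dtype=bool)
    hits = np.sum(predicted & observed)
    misses = np.sum(~predicted & observed)
    false_alarms = np.sum(predicted & ~observed)
    denom = hits + misses + false_alarms
    return float(hits / denom) if denom > 0 else float("nan")


def best_csi_threshold(prob_maps, observed_maps, candidates=np.linspace(0.05, 0.95, 19)):
    """Pick the probability threshold that maximizes CSI over historical cases."""
    prob_maps = np.asarray(prob_maps)
    scores = [critical_success_index(prob_maps >= t, observed_maps) for t in candidates]
    best = int(np.nanargmax(scores))
    return float(candidates[best]), scores[best]
```

CSI ignores correct negatives, which suits rare events like heavy precipitation, where a forecast of “no rain anywhere” would otherwise look deceptively accurate.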

Delivering real-time ML forecasts

Training and evaluating a weather forecasting model like MetNet-3 on historical data is only a part of the process of delivering ML-powered forecasts to users. There are many considerations when developing a real-time ML system for weather forecasting, such as ingesting real-time input data from multiple distinct sources, running inference, implementing real-time validation of outputs, building insights from the rich output of the model that lead to an intuitive user experience, and serving the results at Google scale — all on a continuous cycle, refreshed every few minutes.

We developed such a real-time system that is capable of producing a precipitation forecast every few minutes for the entire contiguous United States and for 27 countries in Europe for a lead time of up to 12 hours.

Illustration of the process of generating precipitation forecasts using MetNet-3.

The system’s uniqueness stems from its use of near-continuous inference, which allows the model to constantly create full forecasts based on incoming data streams. This mode of inference is different from traditional inference systems, and is necessary due to the distinct characteristics of the incoming data. The model takes in various data sources as input, such as radar, satellite, and numerical weather prediction assimilations. Each of these inputs has a different refresh frequency and spatial and temporal resolution. Some data sources, such as weather observations and radar, have characteristics similar to a continuous stream of data, while others, such as NWP assimilations, are similar to batches of data. The system is able to align all of these data sources spatially and temporally, allowing the model to create an updated understanding of the next 12 hours of precipitation at a very high cadence.
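The temporal part of this alignment can be thought of as an “as-of” lookup: at each inference time, every source contributes its most recent frame, as long as that frame is not too old. The sketch below shows only this idea; the `Source` type, minute-based timestamps, and staleness limits are hypothetical, and the spatial alignment and real-time validation described above are not shown.

```python
from bisect import bisect_right
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Source:
    name: str
    timestamps: List[int]  # sorted frame times, in minutes since an arbitrary epoch
    max_staleness: int  # oldest acceptable frame age, in minutes


def align_inputs(sources: List[Source], inference_time: int) -> Dict[str, int]:
    """For each source, pick the newest frame at or before inference_time.

    Raises LookupError if a source has no sufficiently recent frame, so that a
    forecast is never assembled from stale data.
    """
    aligned = {}
    for src in sources:
        i = bisect_right(src.timestamps, inference_time)
        if i == 0:
            raise LookupError(f"{src.name}: no frame available yet")
        frame_time = src.timestamps[i - 1]
        age = inference_time - frame_time
        if age > src.max_staleness:
            raise LookupError(f"{src.name}: newest frame is {age} min old")
        aligned[src.name] = frame_time
    return aligned
```

With a radar-like stream arriving every 2 minutes and a batch-like assimilation arriving hourly, each new radar frame can trigger a fresh forecast that reuses the latest assimilation until the next batch lands.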

With the above process, the model is able to predict arbitrary discrete probability distributions. We developed novel techniques to transform this dense output space into user-friendly information that enables rich experiences throughout Google products and technologies.
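The post does not describe these techniques, so the following is only a hypothetical illustration of the general idea: collapse a per-bin distribution into one probability per lead time (here, the chance that the precipitation rate reaches a threshold), and then into a short message. The 0.2 mm/h threshold, the 50% cutoff, and the function names are assumptions made for this sketch.

```python
import numpy as np


def exceedance_probability(probs, bin_edges, threshold):
    """P(rate >= threshold) per lead time, from per-bin probabilities.

    probs: (num_lead_times, num_bins); bin k covers [bin_edges[k], bin_edges[k+1]).
    The threshold must coincide with one of the bin edges.
    """
    probs = np.asarray(probs)
    k = int(np.searchsorted(bin_edges, threshold))
    if k >= len(bin_edges) or not np.isclose(bin_edges[k], threshold):
        raise ValueError("threshold must be one of the bin edges")
    return probs[:, k:].sum(axis=1)


def rain_onset_message(probs, bin_edges, minutes_per_step=2, threshold=0.2, min_prob=0.5):
    """Turn a probabilistic forecast into a short, human-readable summary."""
    p = exceedance_probability(probs, bin_edges, threshold)
    likely = np.flatnonzero(p >= min_prob)
    if likely.size == 0:
        return "No rain expected"
    minutes = (likely[0] + 1) * minutes_per_step  # step 0 is 2 minutes ahead
    return f"Rain likely in about {minutes} minutes"
```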

Weather features in Google products

People around the world rely on Google every day to provide helpful, timely, and accurate information about the weather. This information is used for a variety of purposes, such as planning outdoor activities, packing for trips, and staying safe during severe weather events.

The state-of-the-art accuracy, high temporal and spatial resolution, and probabilistic nature of MetNet-3 make it possible to create unique hyperlocal weather insights. For the contiguous United States and Europe, MetNet-3 is operational and produces real-time 12-hour precipitation forecasts that are now served across Google products and technologies where weather is relevant, such as Search. The rich output from the model is synthesized into actionable information and instantly served to millions of users.

For example, a user who searches for weather information for a precise location from their mobile device will receive highly localized precipitation forecast data, including, depending on the product, timeline graphs with minute-level breakdowns.

MetNet-3 precipitation output in weather on the Google app on Android (left) and mobile web Search (right).

Conclusion

MetNet-3 is a new deep learning model for weather forecasting that outperforms state-of-the-art physics-based models for 24-hour forecasts of a core set of weather variables. It has the potential to create new possibilities for weather forecasting and to improve the safety and efficiency of many activities, such as transportation, agriculture, and energy production. MetNet-3 is operational and its forecasts are served across several Google products where weather is relevant.

Acknowledgements

Many people were involved in the development of this effort. We would like to especially thank those from Google DeepMind (Di Li, Jeremiah Harmsen, Lasse Espeholt, Marcin Andrychowicz, Zack Ontiveros), Google Research (Aaron Bell, Akib Uddin, Alex Merose, Carla Bromberg, Fred Zyda, Isalo Montacute, Jared Sisk, Jason Hickey, Luke Barrington, Mark Young, Maya Tohidi, Natalie Williams, Pramod Gupta, Shreya Agrawal, Thomas Turnbull, Tom Small, Tyler Russell), and Google Search (Agustin Pesciallo, Bill Myers, Danny Cheresnick, Lior Cohen, Maca Piombi, Maia Diamant, Max Kamenetsky, Maya Ekron, Mor Schlesinger, Neta Gefen-Doron, Nofar Peled Levi, Ofer Lehr, Or Hillel, Rotem Wertman, Vinay Ruelius Shah, Yechie Labai).

Best of both worlds: Achieving scalability and quality in text clustering

Posted by Sara Ahmadian and Mehran Kazemi, Research Scientists, Google Research

Clustering is a fundamental, ubiquitous problem in data mining and unsupervised machine learning, where the goal is to group together similar items. The standard forms of clustering are metric clustering and graph clustering. In metric clustering, a given metric space defines distances between data points, which are grouped together based on their separation. In graph clustering, a given graph connects similar data points through edges, and the clustering process groups data points together based on the connections between them. Both clustering forms are particularly useful for large corpora where class labels can’t be defined. Examples of such corpora are the ever-growing digital text collections of various internet platforms, with applications including organizing and searching documents, identifying patterns in text, and recommending relevant documents to users (see more examples in the following posts: clustering related queries based on user intent and practical differentially private clustering).

The choice of text clustering method often presents a dilemma. One approach is to use embedding models, such as BERT or RoBERTa, to define a metric clustering problem. Another is to use cross-attention (CA) models, such as PaLM or GPT, to define a graph clustering problem. CA models can provide highly accurate similarity scores, but constructing the input graph may require a prohibitive quadratic number of inference calls to the model. On the other hand, a metric space can be defined efficiently by distances between embeddings produced by embedding models. However, these similarity distances are typically of substantially lower quality than the similarity signals of CA models, and hence the resulting clustering can be of much lower quality.

An overview of the embedding-based and cross-attention–based similarity scoring functions and their scalability vs. quality dilemma.

Motivated by this, in “KwikBucks: Correlation Clustering with Cheap-Weak and Expensive-Strong Signals”, presented at ICLR 2023, we describe a novel clustering algorithm that effectively combines the scalability benefits of embedding models and the quality of CA models. This graph clustering algorithm has query access to both the CA model and the embedding model; however, we apply a budget on the number of queries made to the CA model. The algorithm uses the CA model to answer edge queries and benefits from unlimited access to similarity scores from the embedding model. We describe how this proposed setting bridges algorithm design and practical considerations, and how it can be applied to other clustering problems with similar available scoring functions, such as clustering images and other media. We demonstrate how this algorithm yields high-quality clusters with a nearly linear number of queries to the CA model. We have also open-sourced the data used in our experiments.

The clustering algorithm

The KwikBucks algorithm is an extension of the well-known KwikCluster algorithm (also known as the Pivot algorithm). The high-level idea is to first select a set of documents (i.e., centers) with no similarity edge between them, and then form clusters around these centers. To obtain the quality of CA models and the runtime efficiency of embedding models, we introduce the novel combo similarity oracle mechanism. In this approach, we use the embedding model to guide the selection of queries to be sent to the CA model. Given a set of center documents and a target document, the combo similarity oracle outputs a center from the set that is similar to the target document, if one exists. The combo similarity oracle saves budget by limiting the number of queries to the CA model when selecting centers and forming clusters. It does this by first ranking centers based on their embedding similarity to the target document, and then querying the CA model for each pair (i.e., target document and ranked center), as shown below.

A combo similarity oracle that for a set of documents and a target document, returns a similar document from the set, if present.
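The combo similarity oracle can be sketched as follows: cheap embedding scores decide the order in which centers are checked, and only the top few are confirmed with the expensive CA model, each confirmation being charged against a global budget. This is an illustrative reading of the description above, not the released KwikBucks code; the class name, the per-call cap `max_queries_per_call`, and the callable interfaces are assumptions.

```python
from typing import Callable, Optional, Sequence


class ComboSimilarityOracle:
    """Find a center similar to a target document: rank with cheap embedding
    scores, confirm with expensive cross-attention (CA) queries under a budget."""

    def __init__(self, embed_sim: Callable[[str, str], float],
                 ca_is_similar: Callable[[str, str], bool],
                 ca_budget: int, max_queries_per_call: int = 3):
        self.embed_sim = embed_sim
        self.ca_is_similar = ca_is_similar
        self.ca_budget = ca_budget
        self.max_queries_per_call = max_queries_per_call

    def find_similar_center(self, target: str, centers: Sequence[str]) -> Optional[str]:
        # Rank centers by cheap embedding similarity, most similar first.
        ranked = sorted(centers, key=lambda c: self.embed_sim(target, c), reverse=True)
        # Confirm the best-ranked candidates with the CA model, one query each.
        for center in ranked[: self.max_queries_per_call]:
            if self.ca_budget <= 0:
                break
            self.ca_budget -= 1
            if self.ca_is_similar(target, center):
                return center
        return None  # no similar center found; the caller may keep target as a singleton
```

The design trade-off is in `max_queries_per_call`: a small cap keeps CA usage close to linear in the number of documents, at the risk of missing a similar center that the embedding model ranked poorly.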

We then perform a post-processing step to merge two clusters if there is a strong connection between them, i.e., when the number of connecting edges between the two clusters is higher than the number of missing edges. Additionally, we apply the following steps for further computational savings on queries made to the CA model, and to improve performance at runtime:

We leverage query-efficient correlation clustering to form a set of centers from a set of randomly selected documents instead of selecting these centers from all the documents (in the illustration below, the center nodes are red).

We apply the combo similarity oracle mechanism to perform the cluster assignment step in parallel for all non-center documents, and leave documents with no similar center as singletons. In the illustration below, the assignments are depicted by blue arrows, and initially two (non-center) nodes are left as singletons because no similar center was found for them.

In the post-processing step, to ensure scalability, we use the embedding similarity scores to filter down the potential mergers (in the illustration below, the green dashed boundaries show these merged clusters).

Illustration of progress of the clustering algorithm on a given graph instance.
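The post-processing merge rule and its embedding-based filtering can be sketched as below. Two clusters are merged when the pairs judged similar between them outnumber the pairs judged dissimilar, and cheap embedding similarity is used first to shortlist which cluster pairs are worth checking. The function names, the `centroid_sim` callable, and the `min_sim` cutoff are hypothetical; in KwikBucks the pairwise judgments would come from budgeted CA queries.

```python
from itertools import product
from typing import Callable, List, Set, Tuple


def should_merge(cluster_a: Set[str], cluster_b: Set[str],
                 is_similar: Callable[[str, str], bool]) -> bool:
    """Merge when connecting (similar) pairs outnumber missing pairs between two clusters."""
    connecting = sum(is_similar(a, b) for a, b in product(cluster_a, cluster_b))
    missing = len(cluster_a) * len(cluster_b) - connecting
    return connecting > missing


def candidate_pairs(clusters: List[Set[str]],
                    centroid_sim: Callable[[Set[str], Set[str]], float],
                    min_sim: float) -> List[Tuple[int, int]]:
    """Shortlist cluster pairs whose cheap embedding similarity clears min_sim."""
    return [(i, j)
            for i in range(len(clusters))
            for j in range(i + 1, len(clusters))
            if centroid_sim(clusters[i], clusters[j]) >= min_sim]
```

Only shortlisted pairs would reach `should_merge`, which is where the expensive pairwise queries are spent; this mirrors the scalability motivation for filtering mergers with embedding scores.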

Results

We evaluate the novel clustering algorithm on various datasets with different properties using different embedding-based and cross-attention–based models. We compare the clustering algorithm’s performance with the two best-performing baselines (see the paper for more details).

To evaluate the quality of clustering, we use precision and recall. Precision is the percentage of co-clustered pairs that are similar, and recall is the percentage of similar pairs that are co-clustered. To measure the quality of the solutions obtained in our experiments, we use the F1-score, the harmonic mean of precision and recall, where 1.0 is the highest possible value (perfect precision and recall) and 0 is the lowest possible value (either precision or recall is zero). The table below reports the F1-score for KwikBucks and various baselines when only a linear number of queries to the CA model is allowed. KwikBucks offers a substantial boost in performance, with a 45% relative improvement over the best baseline when averaging across all datasets.

The figure below compares the clustering algorithm’s performance with baselines using different query budgets. We observe that KwikBucks consistently outperforms other baselines at various budgets.

A comparison of KwikBucks with top-2 baselines when allowed different budgets for querying the cross-attention model.

Conclusion

Text clustering often presents a dilemma in the choice of similarity function: embedding models are scalable but lack quality, while cross-attention models offer quality but substantially hurt scalability. We present KwikBucks, a clustering algorithm that offers the best of both worlds: the scalability of embedding models and the quality of cross-attention models. We validate it with an extensive set of experiments on various datasets with diverse properties. KwikBucks can also be applied to other clustering problems with multiple similarity oracles of varying accuracy levels. See the paper for more details.

Acknowledgements

This project was initiated during Sandeep Silwal’s summer internship at Google in 2022. We would like to express our gratitude to our co-authors, Andrew McCallum, Andrew Nystrom, Deepak Ramachandran, and Sandeep Silwal, for their valuable contributions to this work. We also thank Ravi Kumar and John Guilyard for assistance with this blog post.


Zero-shot adaptive prompting of large language models

Posted by Xingchen Wan, Student Researcher, and Ruoxi Sun, Research Scientist, Cloud AI Team

Recent advances in large language models (LLMs) are very promising, as reflected in their capability for general problem-solving in few-shot and zero-shot setups, even without explicit training on these tasks. This is impressive because, in the few-shot setup, LLMs are presented with only a few question-answer demonstrations before being given a test question. The zero-shot setup is even more challenging: the LLM is prompted with the test question only.

Even though the few-shot setup has dramatically reduced the amount of data required to adapt a model to a specific use case, generating sample prompts can still be challenging in some cases. Handcrafting even a small number of demos for the broad range of tasks covered by general-purpose models can be difficult or, for unseen tasks, impossible. Writing sample answers is especially hard for tasks like summarization of long articles, or for tasks that require domain knowledge (e.g., medical question answering). In such situations, models with high zero-shot performance are useful, since no manual prompt generation is required. However, zero-shot performance is typically weaker because the LLM is given no guidance and is therefore prone to spurious output.

In “Better Zero-shot Reasoning with Self-Adaptive Prompting”, published at ACL 2023, we propose Consistency-Based Self-Adaptive Prompting (COSP) to address this dilemma. COSP is a zero-shot automatic prompting method for reasoning problems. It carefully selects and constructs pseudo-demonstrations for LLMs using only unlabeled samples (which are typically easy to obtain) and the model’s own predictions. With COSP, we largely close the performance gap between zero-shot and few-shot prompting while retaining the desirable generality of zero-shot prompting. We follow this with “Universal Self-Adaptive Prompting” (USP), accepted at EMNLP 2023. In USP, we extend the idea to a wide range of general natural language understanding (NLU) and natural language generation (NLG) tasks and demonstrate its effectiveness.

Prompting LLMs with their own outputs

Knowing that LLMs benefit from demonstrations and have at least some zero-shot abilities, we wondered whether the model’s zero-shot outputs could serve as demonstrations for the model to prompt itself. The challenge is that zero-shot solutions are imperfect, and we risk giving LLMs poor-quality demonstrations, which could be worse than no demonstrations at all. Indeed, the figure below shows that adding a correct demonstration to a question can lead to a correct solution of the test question (Demo1 with the question). In contrast, adding a poor demonstration (Demo2 or Demo3 with the question) leads to repetitive or incorrect outputs. Therefore, we need to select reliable self-generated demonstrations.

Example inputs and outputs for reasoning tasks, illustrating the need for a carefully designed selection procedure for in-context demonstrations (MultiArith dataset and PaLM-62B model): (1) zero-shot chain-of-thought with no demo: correct logic but wrong answer; (2) correct demo (Demo1) and correct answer; (3) correct but repetitive demo (Demo2) leads to repetitive outputs; (4) erroneous demo (Demo3) leads to a wrong answer; but (5) combining Demo3 and Demo1 again leads to a correct answer.

COSP leverages a key observation about LLMs: confident and consistent predictions are more likely to be correct. This observation, of course, depends on how good the LLM’s uncertainty estimates are. Fortunately, previous work suggests that uncertainty estimates in large models are robust. Since measuring confidence requires only model predictions, not labels, we propose to use it as a zero-shot proxy for correctness. The high-confidence outputs and their inputs are then used as pseudo-demonstrations.

With this as our starting premise, we estimate the model’s confidence in its output based on its self-consistency, and we use this measure to select robust self-generated demonstrations. We ask the LLM the same question multiple times with zero-shot chain-of-thought (CoT) prompting. To guide the model to generate a range of possible rationales and final answers, we include randomness controlled by a “temperature” hyperparameter. In the extreme case where the model is 100% certain, it should output identical final answers each time. We then compute the entropy of the answers to gauge the uncertainty. Answers with high self-consistency, for which the LLM is more certain, are likely to be correct and will be selected.
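Here is a minimal sketch of the consistency measure. It assumes the final answers have already been extracted as strings from the temperature-sampled CoT outputs, and it omits the sampling call itself. Using the natural log with no normalization is an assumption, not COSP's exact definition. `answer_entropy` is a hypothetical helper name.

```python
import math
from collections import Counter
from typing import List, Tuple


def answer_entropy(answers: List[str]) -> Tuple[str, float]:
    """Majority answer plus entropy (nats) of the empirical answer distribution.

    Zero means every sample agreed (maximal self-consistency); larger values
    mean the samples were spread over more distinct answers.
    """
    counts = Counter(answers)
    n = len(answers)
    entropy = sum(-(c / n) * math.log(c / n) for c in counts.values())
    majority, _ = counts.most_common(1)[0]
    return majority, entropy


print(answer_entropy(["12", "12", "12", "12", "12"]))  # ('12', 0.0)
print(answer_entropy(["12", "12", "12", "15", "15"]))  # majority '12', entropy ~0.673
```

Low entropy marks the questions whose self-generated answers are candidates for pseudo-demonstrations.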

Assuming that we are presented with a collection of unlabeled questions, the COSP method is:

Input each unlabeled question into an LLM, obtaining multiple rationales and answers by sampling the model multiple times. The most frequent answers are highlighted, followed by a score that measures the consistency of answers across the sampled outputs (higher is better). In addition to favoring more consistent answers, we also penalize repetition within a response (i.e., repeated words or phrases) and encourage diversity among the selected demonstrations. We encode this preference for consistent, non-repetitive and diverse outputs in a scoring function: a weighted sum of the three scores, which we use to select the self-generated pseudo-demonstrations.
Concatenate the pseudo-demonstrations with each test question, feed the result to the LLM, and obtain a final predicted answer.
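The selection step can be sketched as a greedy loop over a weighted score. Everything concrete here is an assumption for illustration: the weights, the n-gram `repetition_ratio`, the word-overlap `jaccard` used as a diversity penalty, and the dictionary field names are hypothetical stand-ins. `consistency` is assumed to be precomputed (for example, from answer entropy), with higher values meaning more consistent. COSP's paper defines its own versions of these terms.

```python
from typing import Dict, List


def repetition_ratio(text: str, n: int = 3) -> float:
    """Share of word n-grams that repeat an earlier n-gram (0.0 = no repetition)."""
    words = text.lower().split()
    grams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    return 1.0 - len(set(grams)) / len(grams) if grams else 0.0


def jaccard(a: str, b: str) -> float:
    """Word-set overlap, used here as a crude similarity between questions."""
    sa, sb = set(a.lower().split()), set(b.lower().split())
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0


def demo_score(cand: Dict, selected: List[Dict], w_cons: float, w_rep: float, w_div: float) -> float:
    """Weighted sum: reward consistency, penalize repetition and overlap with earlier picks."""
    overlap = max((jaccard(cand["question"], s["question"]) for s in selected), default=0.0)
    return (w_cons * cand["consistency"]
            - w_rep * repetition_ratio(cand["rationale"])
            - w_div * overlap)


def select_demos(pool: List[Dict], k: int = 2, w_cons: float = 1.0,
                 w_rep: float = 1.0, w_div: float = 0.5) -> List[Dict]:
    """Greedily pick k pseudo-demos; the diversity term is re-scored after every pick."""
    pool, selected = list(pool), []  # copy so the caller's list is untouched
    while pool and len(selected) < k:
        best = max(range(len(pool)),
                   key=lambda i: demo_score(pool[i], selected, w_cons, w_rep, w_div))
        selected.append(pool.pop(best))
    return selected


pool = [
    {"question": "Tom has 3 apples and buys 2 more. How many apples?",
     "rationale": "3 plus 2 is 5.", "answer": "5", "consistency": 1.0},
    {"question": "Tom has 3 apples and buys 4 more. How many apples?",
     "rationale": "3 plus 4 is 7.", "answer": "7", "consistency": 0.9},
    {"question": "A train travels 60 km per hour for 2 hours. How far?",
     "rationale": "60 times 2 is 120 km.", "answer": "120", "consistency": 0.8},
    {"question": "Count the words.",
     "rationale": "the the the the the the", "answer": "6", "consistency": 1.0},
]
print([d["answer"] for d in select_demos(pool)])           # ['5', '120']
print([d["answer"] for d in select_demos(pool, w_div=0)])  # ['5', '7']
```

The highly repetitive candidate loses despite perfect consistency. With the diversity weight set to zero, the near-duplicate apples question takes the second slot. With the weight turned on, the greedy loop picks the dissimilar train question instead.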

Illustration of COSP: In Stage 1 (left), we run zero-shot CoT multiple times to generate a pool of demonstrations (each consisting of the question, generated rationale and prediction) and assign a score. In Stage 2 (right), we augment the current test question with pseudo-demos (blue boxes) and query the LLM again. A majority vote over outputs from both stages forms the final prediction.
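Stage 2 can be sketched as plain string assembly followed by a vote. The prompt template is an assumption for illustration, not COSP's published format: it uses a `Q:`/`A:` layout and reuses the zero-shot CoT trigger “Let’s think step by step”. The helper names `build_prompt` and `final_prediction` are also assumptions, and the LLM call is left out.

```python
from collections import Counter
from typing import Dict, List


def build_prompt(demos: List[Dict[str, str]], question: str) -> str:
    """Prepend pseudo-demos (question, rationale, answer) to the test question."""
    blocks = [
        f"Q: {d['question']}\nA: Let's think step by step. "
        f"{d['rationale']} The answer is {d['answer']}."
        for d in demos
    ]
    blocks.append(f"Q: {question}\nA: Let's think step by step.")
    return "\n\n".join(blocks)


def final_prediction(stage1_answers: List[str], stage2_answers: List[str]) -> str:
    """Majority vote over the answers sampled in both stages.

    Ties go to the answer seen first, because Counter.most_common orders
    equal counts by first occurrence.
    """
    return Counter(stage1_answers + stage2_answers).most_common(1)[0][0]


demo = {"question": "2 + 2?", "rationale": "2 plus 2 is 4.", "answer": "4"}
print(build_prompt([demo], "3 + 3?"))
print(final_prediction(["8", "7", "8"], ["7", "7"]))  # 7
```

Pooling the answers from both stages means the final answer does not rely solely on the pseudo-demo-augmented queries.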

COSP focuses on question-answering tasks with CoT prompting, for which it is easy to measure self-consistency because the questions have unique correct answers. This can be difficult for other tasks, such as open-ended question answering or generative tasks that don’t have unique answers (e.g., text summarization). To address this limitation, we introduce USP, which generalizes our approach to other NLP tasks:

Classification (CLS): Problems where we can compute the probability of each class from the neural network’s output logits. This lets us measure uncertainty without multiple sampling, by computing the entropy of the resulting class distribution.
Short-form generation (SFG): Problems like question answering where we can use the same procedure mentioned above for COSP, but, if necessary, without the rationale-generating step.
Long-form generation (LFG): Problems like summarization and translation, where the questions are often open-ended and the outputs are unlikely to be identical, even if the LLM is certain. In this case, we use an overlap metric: the average pairwise ROUGE score between the different outputs to the same query.
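The CLS and LFG confidence scores can be sketched in a few lines. Two assumptions apply here. First, the class distribution is taken as the softmax of the logits. Second, the post does not say which ROUGE variant is used, so `rouge1_f` is a simplified, whitespace-tokenized ROUGE-1 F-measure standing in for a proper ROUGE implementation. SFG reuses the COSP-style answer consistency and is omitted. All function names are hypothetical.

```python
import math
from collections import Counter
from itertools import combinations
from typing import List, Sequence


def logit_entropy(logits: Sequence[float]) -> float:
    """CLS: entropy (nats) of softmax(logits); lower means more confident."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract the max for numerical stability
    z = sum(exps)
    return sum(-(e / z) * math.log(e / z) for e in exps if e > 0)


def rouge1_f(a: str, b: str) -> float:
    """Simplified ROUGE-1 F-measure: clipped unigram overlap on whitespace tokens."""
    ca, cb = Counter(a.lower().split()), Counter(b.lower().split())
    overlap = sum((ca & cb).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / sum(cb.values()), overlap / sum(ca.values())
    return 2 * precision * recall / (precision + recall)


def avg_pairwise_rouge(outputs: List[str]) -> float:
    """LFG: mean ROUGE over all pairs of sampled outputs; higher means more agreement."""
    pairs = list(combinations(outputs, 2))
    return sum(rouge1_f(x, y) for x, y in pairs) / len(pairs) if pairs else 0.0


print(logit_entropy([0.0, 0.0]))  # ~0.693, i.e. ln 2 for two equally likely classes
print(avg_pairwise_rouge(["the cat sat", "the dog sat", "the cat sat"]))  # ~0.778
```

The two scores point in opposite directions. Lower logit entropy means higher confidence, whereas a higher average pairwise ROUGE means the sampled outputs agree more. The scores must therefore be oriented consistently before candidates are ranked.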

Illustration of USP in exemplary tasks (classification, QA and text summarization). Similar to COSP, the LLM first generates predictions on an unlabeled dataset whose outputs are scored with logit entropy, consistency or alignment, depending on the task type, and pseudo-demonstrations are selected from these input-output pairs. In Stage 2, the test instances are augmented with pseudo-demos for prediction.

We compute the confidence score appropriate to each task type on the aforementioned set of unlabeled test samples. After scoring, similar to COSP, we pick confident, diverse and less repetitive answers to form a model-generated pseudo-demonstration set. Finally, we query the LLM again in a few-shot format with these pseudo-demonstrations to obtain the final predictions on the entire test set.

Key results

For COSP, we focus on a set of six arithmetic and commonsense reasoning problems, and we compare against 0-shot-CoT (i.e., “Let’s think step by step” only). We use self-consistency in all baselines so that they use roughly the same amount of computational resources as COSP. Across three LLMs, zero-shot COSP significantly outperforms the standard zero-shot baseline.

USP improves significantly on 0-shot performance. “CLS” is an average of 15 classification tasks; “SFG” is the average of five short-form generation tasks; “LFG” is the average of two summarization tasks. “SFG (BBH)” is an average of all BIG-Bench Hard tasks, where each question is in SFG format.

For USP, we expand our analysis to a much wider range of tasks, including more than 25 classification, short-form generation, and long-form generation tasks. Using the state-of-the-art PaLM 2 models, we also evaluate on the BIG-Bench Hard suite of tasks, where LLMs have previously underperformed compared to people. In all cases, USP again outperforms the baselines and is competitive with prompting using golden examples.

Accuracy on BIG-Bench Hard tasks with PaLM 2-M (each line represents a task of the suite). The gain/loss of USP (green stars) over standard 0-shot (green triangles) is shown in percentages. “Human” refers to average human performance; “AutoCoT” and “Random demo” are baselines we compared against in the paper; and “3-shot” is the few-shot performance for three handcrafted demos in CoT format.

We also analyze the working mechanism of USP by validating the key observation above about the relation between confidence and correctness. In an overwhelming majority of cases, across all task types considered, USP picks confident predictions that are more likely to be better, as shown in the figure below.

USP picks confident predictions that are more likely to be better. Ground-truth performance metrics against USP confidence scores in selected tasks of various types (blue: CLS, orange: SFG, green: LFG) with PaLM-540B.
Conclusion

Zero-shot inference is a highly sought-after capability of modern LLMs, yet succeeding at it poses unique challenges. We propose COSP and USP, a family of versatile, zero-shot automatic prompting techniques applicable to a wide range of tasks. We show large improvements over state-of-the-art baselines across numerous task and model combinations.

Acknowledgements

This work was conducted by Xingchen Wan, Ruoxi Sun, Hootan Nakhost, Hanjun Dai, Julian Martin Eisenschlos, Sercan Ö. Arık, and Tomas Pfister. We would like to thank Jinsung Yoon and Xuezhi Wang for providing helpful reviews, and other colleagues at Google Cloud AI Research for their discussion and feedback.
